├── .gitignore ├── MANIFEST.in ├── README.md ├── bpe_surgery_script.py ├── config.ini ├── demo.ipynb ├── fine_tune_title_generation.py ├── nmatheg ├── __init__.py ├── config.ini ├── configs.py ├── dataset.py ├── datasets.ini ├── models.py ├── ner_utils.py ├── nmatheg.py ├── preprocess_ner.py ├── preprocess_qa.py ├── qa_utils.py ├── tests.py └── utils.py ├── nmatheg_logo.PNG ├── playground.ipynb ├── predict.py ├── requirements.txt ├── script.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .ipynb_checkpoints/ 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # data files 142 | *data*.txt 143 | data/ 144 | ckpts/ 145 | tmp -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include nmatheg/datasets.ini 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | 4 |

5 | 6 | 7 | # nmatheg 8 | 9 | nmatheg (`نماذج`, Arabic for "models") is an easy strategy for training Arabic NLP models on Hugging Face datasets. Just specify the dataset, the preprocessing, the tokenization and the training procedure in the config file to train an NLP model for that task. 10 | 11 | ## Install 12 | 13 | ```pip install nmatheg``` 14 | 15 | ## Configuration 16 | 17 | Set up a config file for the training strategy; a short sketch of reading this file programmatically appears at the end of this README. 18 | 19 | ```ini 20 | [dataset] 21 | dataset_name = ajgt_twitter_ar 22 | 23 | [preprocessing] 24 | segment = False 25 | remove_special_chars = False 26 | remove_english = False 27 | normalize = False 28 | remove_diacritics = False 29 | excluded_chars = [] 30 | remove_tatweel = False 31 | remove_html_elements = False 32 | remove_links = False 33 | remove_twitter_meta = False 34 | remove_long_words = False 35 | remove_repeated_chars = False 36 | 37 | [tokenization] 38 | tokenizer_name = WordTokenizer 39 | vocab_size = 1000 40 | max_tokens = 128 41 | 42 | [model] 43 | model_name = rnn 44 | 45 | [log] 46 | print_every = 10 47 | 48 | [train] 49 | save_dir = . 50 | epochs = 10 51 | batch_size = 256 52 | ``` 53 | 54 | ### Main Sections 55 | 56 | - `dataset` describes the dataset and the task type (see the supported tasks below). 57 | - `preprocessing` a set of cleaning functions, mainly using our library [tnkeeh](https://github.com/ARBML/tnkeeh). 58 | - `tokenization` describes the tokenizer used for encoding the dataset. It uses our library [tkseem](https://github.com/ARBML/tkseem). 59 | - `train` the training parameters, like the number of epochs and the batch size. 60 | 61 | ## Usage 62 | 63 | ### Config Files 64 | ```python 65 | import nmatheg as nm 66 | strategy = nm.TrainStrategy('config.ini') 67 | strategy.start() 68 | ``` 69 | ### Benchmarking on multiple datasets and models 70 | ```python 71 | import nmatheg as nm 72 | strategy = nm.TrainStrategy( 73 | datasets = 'arsentd_lev,caner,arcd', 74 | models = 'qarib/bert-base-qarib,aubmindlab/bert-base-arabertv01', 75 | mode = 'finetune', 76 | runs = 5, 77 | lr = 1e-4, 78 | epochs = 1, 79 | batch_size = 8, 80 | max_tokens = 128, 81 | max_train_samples = 1024 82 | ) 83 | strategy.start() 84 | ``` 85 | 86 | ## Datasets 87 | We support Hugging Face datasets for Arabic. You can find the supported datasets [here](https://github.com/ARBML/nmatheg/blob/main/nmatheg/datasets.ini). 88 | 89 | | Dataset | Description | 90 | | --- | --- | 91 | | [ajgt_twitter_ar](https://huggingface.co/datasets/ajgt_twitter_ar) | The Arabic Jordanian General Tweets (AJGT) Corpus consists of 1,800 tweets annotated as positive or negative, written in Modern Standard Arabic (MSA) or Jordanian dialect. | 92 | | [metrec](https://huggingface.co/datasets/metrec) | The dataset contains verses and their corresponding meter classes, represented as numbers from 0 to 13. It can be highly useful for further research on Arabic poem meter classification. The train set contains 47,124 records and the test set contains 8,316 records. | 93 | |[labr](https://huggingface.co/datasets/labr) |This dataset contains over 63,000 book reviews in Arabic. It is the largest sentiment analysis dataset for Arabic to date. The book reviews were harvested from the Goodreads website during the month of March 2013. Each book review comes with the Goodreads review id, the user id, the book id, the rating (1 to 5) and the text of the review. 
| 94 | |[ar_res_reviews](https://huggingface.co/datasets/ar_res_reviews)|A dataset of 8,364 restaurant reviews from qaym.com in Arabic for sentiment analysis.| 95 | |[arsentd_lev](https://huggingface.co/datasets/arsentd_lev)|The Arabic Sentiment Twitter Dataset for Levantine dialect (ArSenTD-LEV) contains 4,000 tweets written in Arabic, retrieved in equal parts from Jordan, Lebanon, Palestine and Syria.| 96 | |[oclar](https://huggingface.co/datasets/oclar)|The OCLAR corpus (Marwan et al., 2019) gathers Arabic customer reviews from the Zomato [website](https://www.zomato.com/lebanon) across a wide range of domains, including restaurants, hotels, hospitals, local shops, etc. The corpus contains 3,916 reviews on a 5-star rating scale; the positive class covers ratings from 3 to 5 (3,465 reviews) and the negative class covers ratings of 1 and 2 (about 451 reviews).| 97 | |[emotone_ar](https://huggingface.co/datasets/emotone_ar)|A dataset of 10,065 Arabic tweets for emotion detection.| 98 | |[hard](https://huggingface.co/datasets/hard)|This dataset contains 93,700 hotel reviews in Arabic. The reviews were collected from the Booking.com website during June/July 2016 and are expressed in Modern Standard Arabic as well as dialectal Arabic.| 99 | |[caner](https://huggingface.co/datasets/caner)|The Classical Arabic Named Entity Recognition corpus is a corpus of tagged data that can be useful for handling the issues in the recognition of Arabic named entities.| 100 | |[arcd](https://huggingface.co/datasets/arcd)|The Arabic Reading Comprehension Dataset (ARCD) is composed of 1,395 questions posed by crowdworkers on Wikipedia articles.| 101 | |[mlqa](https://huggingface.co/datasets/mlqa)|MLQA contains QA instances in 7 languages: English, Arabic, German, Spanish, Hindi, Vietnamese and Simplified Chinese.| 102 | |[xnli](https://huggingface.co/datasets/xnli)|XNLI is a subset of a few thousand examples from MNLI which has been translated into 14 different languages (some low-ish resource).| 103 | |[tatoeba_mt](https://huggingface.co/datasets/Helsinki-NLP/tatoeba_mt)|The Tatoeba Translation Challenge is a multilingual dataset of machine translation benchmarks derived from user-contributed translations collected by Tatoeba.org and provided as a parallel corpus from OPUS.| 104 | ## Tasks 105 | 106 | Currently we support text classification, named entity recognition, question answering, machine translation and natural language inference. 107 | 108 | ## Demo 109 | Check this [colab notebook](https://colab.research.google.com/github/ARBML/nmatheg/blob/main/demo.ipynb) for a quick demo. 
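## Reading the config file

As a quick sanity check of the configuration format shown above, the `.ini` values can be inspected with Python's standard `configparser` before the file is handed to `nm.TrainStrategy`. The snippet below is a minimal illustrative sketch, not part of the nmatheg API; it only assumes the section and key names from the example config.

```python
import configparser

# Minimal sketch: read the example config from the Configuration section.
config = configparser.ConfigParser()
config.read('config.ini')

dataset_name = config['dataset']['dataset_name']           # 'ajgt_twitter_ar'
vocab_size = config.getint('tokenization', 'vocab_size')   # 1000
epochs = config.getint('train', 'epochs')                  # 10
batch_size = config.getint('train', 'batch_size')          # 256
segment = config.getboolean('preprocessing', 'segment')    # False

print(dataset_name, vocab_size, epochs, batch_size, segment)
```

Note that `configparser` returns every value as a string, so numeric and boolean fields need explicit conversions such as `getint` and `getboolean`.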
110 | -------------------------------------------------------------------------------- /bpe_surgery_script.py: -------------------------------------------------------------------------------- 1 | import nmatheg as nm 2 | strategy = nm.TrainStrategy( 3 | datasets = 'ajgt_twitter_ar,caner,xnli', 4 | models = 'birnn', 5 | tokenizers = 'BPE,MaT-BPE,Seg-BPE', 6 | vocab_sizes = '250,500,1000,5000,10000', 7 | runs = 10, 8 | lr = 1e-3, 9 | epochs = 20, 10 | batch_size = 128, 11 | max_tokens = 128, 12 | mode = 'pretrain' 13 | ) 14 | output = strategy.start() -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | 2 | [dataset] 3 | dataset_name = ajgt_twitter_ar 4 | task = classification 5 | 6 | [preprocessing] 7 | segment = False 8 | remove_special_chars = False 9 | remove_english = False 10 | normalize = False 11 | remove_diacritics = False 12 | excluded_chars = [] 13 | remove_tatweel = False 14 | remove_html_elements = False 15 | remove_links = False 16 | remove_twitter_meta = False 17 | remove_long_words = False 18 | remove_repeated_chars = False 19 | 20 | [tokenization] 21 | tokenizer_name = WordTokenizer 22 | vocab_size = 1000 23 | max_tokens = 128 24 | 25 | [model] 26 | model_name = rnn,aubmindlab/bert-base-arabertv01 27 | 28 | [log] 29 | print_every = 10 30 | 31 | [train] 32 | save_dir = . 33 | epochs = 10 34 | batch_size = 256 35 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "id": "Yr3ZFtPMr22x" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%%capture\n", 12 | "!pip install git+https://github.com/ARBML/nmatheg" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "id": "g9liZeykvsfe" 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "import nmatheg as nm\n", 24 | "strategy = nm.TrainStrategy(datasets = 'arsentd_lev,ajgt_twitter_ar,ar_res_reviews,arcd,caner',\n", 25 | " models = 'qarib/bert-base-qarib,\\\n", 26 | " aubmindlab/bert-base-arabertv01,\\\n", 27 | " CAMeL-Lab/bert-base-arabic-camelbert-da,\\\n", 28 | " UBC-NLP/MARBERT,\\\n", 29 | " bashar-talafha/multi-dialect-bert-base-arabic',\n", 30 | " epochs = 5,\n", 31 | " lr = 1e-3,\n", 32 | " batch_size = 8,)\n", 33 | "strategy.start()" 34 | ] 35 | } 36 | ], 37 | "metadata": { 38 | "accelerator": "GPU", 39 | "colab": { 40 | "name": "demo.ipynb", 41 | "provenance": [] 42 | }, 43 | "kernelspec": { 44 | "display_name": "Python 3.9.5 64-bit", 45 | "language": "python", 46 | "name": "python3" 47 | }, 48 | "language_info": { 49 | "codemirror_mode": { 50 | "name": "ipython", 51 | "version": 3 52 | }, 53 | "file_extension": ".py", 54 | "mimetype": "text/x-python", 55 | "name": "python", 56 | "nbconvert_exporter": "python", 57 | "pygments_lexer": "ipython3", 58 | "version": "3.9.5" 59 | }, 60 | "orig_nbformat": 2, 61 | "vscode": { 62 | "interpreter": { 63 | "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" 64 | } 65 | } 66 | } 67 | }
"_model_name": "LayoutModel", 75 | "_view_count": null, 76 | "_view_module": "@jupyter-widgets/base", 77 | "_view_module_version": "1.2.0", 78 | "_view_name": "LayoutView", 79 | "align_content": null, 80 | "align_items": null, 81 | "align_self": null, 82 | "border": null, 83 | "bottom": null, 84 | "display": null, 85 | "flex": null, 86 | "flex_flow": null, 87 | "grid_area": null, 88 | "grid_auto_columns": null, 89 | "grid_auto_flow": null, 90 | "grid_auto_rows": null, 91 | "grid_column": null, 92 | "grid_gap": null, 93 | "grid_row": null, 94 | "grid_template_areas": null, 95 | "grid_template_columns": null, 96 | "grid_template_rows": null, 97 | "height": null, 98 | "justify_content": null, 99 | "justify_items": null, 100 | "left": null, 101 | "margin": null, 102 | "max_height": null, 103 | "max_width": null, 104 | "min_height": null, 105 | "min_width": null, 106 | "object_fit": null, 107 | "object_position": null, 108 | "order": null, 109 | "overflow": null, 110 | "overflow_x": null, 111 | "overflow_y": null, 112 | "padding": null, 113 | "right": null, 114 | "top": null, 115 | "visibility": null, 116 | "width": null 117 | } 118 | }, 119 | "02f25412874b417892d6ff746d65d0e7": { 120 | "model_module": "@jupyter-widgets/base", 121 | "model_name": "LayoutModel", 122 | "state": { 123 | "_model_module": "@jupyter-widgets/base", 124 | "_model_module_version": "1.2.0", 125 | "_model_name": "LayoutModel", 126 | "_view_count": null, 127 | "_view_module": "@jupyter-widgets/base", 128 | "_view_module_version": "1.2.0", 129 | "_view_name": "LayoutView", 130 | "align_content": null, 131 | "align_items": null, 132 | "align_self": null, 133 | "border": null, 134 | "bottom": null, 135 | "display": null, 136 | "flex": null, 137 | "flex_flow": null, 138 | "grid_area": null, 139 | "grid_auto_columns": null, 140 | "grid_auto_flow": null, 141 | "grid_auto_rows": null, 142 | "grid_column": null, 143 | "grid_gap": null, 144 | "grid_row": null, 145 | "grid_template_areas": null, 146 | "grid_template_columns": null, 147 | "grid_template_rows": null, 148 | "height": null, 149 | "justify_content": null, 150 | "justify_items": null, 151 | "left": null, 152 | "margin": null, 153 | "max_height": null, 154 | "max_width": null, 155 | "min_height": null, 156 | "min_width": null, 157 | "object_fit": null, 158 | "object_position": null, 159 | "order": null, 160 | "overflow": null, 161 | "overflow_x": null, 162 | "overflow_y": null, 163 | "padding": null, 164 | "right": null, 165 | "top": null, 166 | "visibility": null, 167 | "width": null 168 | } 169 | }, 170 | "07b9064cf6804ebbbb26419522c49700": { 171 | "model_module": "@jupyter-widgets/base", 172 | "model_name": "LayoutModel", 173 | "state": { 174 | "_model_module": "@jupyter-widgets/base", 175 | "_model_module_version": "1.2.0", 176 | "_model_name": "LayoutModel", 177 | "_view_count": null, 178 | "_view_module": "@jupyter-widgets/base", 179 | "_view_module_version": "1.2.0", 180 | "_view_name": "LayoutView", 181 | "align_content": null, 182 | "align_items": null, 183 | "align_self": null, 184 | "border": null, 185 | "bottom": null, 186 | "display": null, 187 | "flex": null, 188 | "flex_flow": null, 189 | "grid_area": null, 190 | "grid_auto_columns": null, 191 | "grid_auto_flow": null, 192 | "grid_auto_rows": null, 193 | "grid_column": null, 194 | "grid_gap": null, 195 | "grid_row": null, 196 | "grid_template_areas": null, 197 | "grid_template_columns": null, 198 | "grid_template_rows": null, 199 | "height": null, 200 | "justify_content": null, 201 | "justify_items": null, 
202 | "left": null, 203 | "margin": null, 204 | "max_height": null, 205 | "max_width": null, 206 | "min_height": null, 207 | "min_width": null, 208 | "object_fit": null, 209 | "object_position": null, 210 | "order": null, 211 | "overflow": null, 212 | "overflow_x": null, 213 | "overflow_y": null, 214 | "padding": null, 215 | "right": null, 216 | "top": null, 217 | "visibility": null, 218 | "width": null 219 | } 220 | }, 221 | "0c44bce3963d4b73bad9a2f43034da6e": { 222 | "model_module": "@jupyter-widgets/controls", 223 | "model_name": "HTMLModel", 224 | "state": { 225 | "_dom_classes": [], 226 | "_model_module": "@jupyter-widgets/controls", 227 | "_model_module_version": "1.5.0", 228 | "_model_name": "HTMLModel", 229 | "_view_count": null, 230 | "_view_module": "@jupyter-widgets/controls", 231 | "_view_module_version": "1.5.0", 232 | "_view_name": "HTMLView", 233 | "description": "", 234 | "description_tooltip": null, 235 | "layout": "IPY_MODEL_985c2e847ae2468aa32f734fad8cac15", 236 | "placeholder": "​", 237 | "style": "IPY_MODEL_b9f50346f8164d2c99722a21646b4b05", 238 | "value": " 2/2 [00:01<00:00, 1.18ba/s]" 239 | } 240 | }, 241 | "169dd26ec2c34af2b23624f81afcebc3": { 242 | "model_module": "@jupyter-widgets/controls", 243 | "model_name": "HTMLModel", 244 | "state": { 245 | "_dom_classes": [], 246 | "_model_module": "@jupyter-widgets/controls", 247 | "_model_module_version": "1.5.0", 248 | "_model_name": "HTMLModel", 249 | "_view_count": null, 250 | "_view_module": "@jupyter-widgets/controls", 251 | "_view_module_version": "1.5.0", 252 | "_view_name": "HTMLView", 253 | "description": "", 254 | "description_tooltip": null, 255 | "layout": "IPY_MODEL_f2c5b053dc2e4a7590710c091ca5aa1e", 256 | "placeholder": "​", 257 | "style": "IPY_MODEL_9eefe0eb1b02400598908bf39be13998", 258 | "value": " 576/576 [00:03<00:00, 190B/s]" 259 | } 260 | }, 261 | "180e7a8558b74bea8ebcec9e4a06138a": { 262 | "model_module": "@jupyter-widgets/base", 263 | "model_name": "LayoutModel", 264 | "state": { 265 | "_model_module": "@jupyter-widgets/base", 266 | "_model_module_version": "1.2.0", 267 | "_model_name": "LayoutModel", 268 | "_view_count": null, 269 | "_view_module": "@jupyter-widgets/base", 270 | "_view_module_version": "1.2.0", 271 | "_view_name": "LayoutView", 272 | "align_content": null, 273 | "align_items": null, 274 | "align_self": null, 275 | "border": null, 276 | "bottom": null, 277 | "display": null, 278 | "flex": null, 279 | "flex_flow": null, 280 | "grid_area": null, 281 | "grid_auto_columns": null, 282 | "grid_auto_flow": null, 283 | "grid_auto_rows": null, 284 | "grid_column": null, 285 | "grid_gap": null, 286 | "grid_row": null, 287 | "grid_template_areas": null, 288 | "grid_template_columns": null, 289 | "grid_template_rows": null, 290 | "height": null, 291 | "justify_content": null, 292 | "justify_items": null, 293 | "left": null, 294 | "margin": null, 295 | "max_height": null, 296 | "max_width": null, 297 | "min_height": null, 298 | "min_width": null, 299 | "object_fit": null, 300 | "object_position": null, 301 | "order": null, 302 | "overflow": null, 303 | "overflow_x": null, 304 | "overflow_y": null, 305 | "padding": null, 306 | "right": null, 307 | "top": null, 308 | "visibility": null, 309 | "width": null 310 | } 311 | }, 312 | "2367e98e314848fcbf34969075b276b3": { 313 | "model_module": "@jupyter-widgets/controls", 314 | "model_name": "FloatProgressModel", 315 | "state": { 316 | "_dom_classes": [], 317 | "_model_module": "@jupyter-widgets/controls", 318 | "_model_module_version": "1.5.0", 319 | 
"_model_name": "FloatProgressModel", 320 | "_view_count": null, 321 | "_view_module": "@jupyter-widgets/controls", 322 | "_view_module_version": "1.5.0", 323 | "_view_name": "ProgressView", 324 | "bar_style": "success", 325 | "description": "100%", 326 | "description_tooltip": null, 327 | "layout": "IPY_MODEL_77046ee7556242aeb281a62a7305eeb7", 328 | "max": 2, 329 | "min": 0, 330 | "orientation": "horizontal", 331 | "style": "IPY_MODEL_4684f0aba9f347d1938ce3ed125c56be", 332 | "value": 2 333 | } 334 | }, 335 | "2733a8c195944629a0cb9a0007588358": { 336 | "model_module": "@jupyter-widgets/controls", 337 | "model_name": "ProgressStyleModel", 338 | "state": { 339 | "_model_module": "@jupyter-widgets/controls", 340 | "_model_module_version": "1.5.0", 341 | "_model_name": "ProgressStyleModel", 342 | "_view_count": null, 343 | "_view_module": "@jupyter-widgets/base", 344 | "_view_module_version": "1.2.0", 345 | "_view_name": "StyleView", 346 | "bar_color": null, 347 | "description_width": "initial" 348 | } 349 | }, 350 | "276a05360cb0423a80b0afe2c211e69c": { 351 | "model_module": "@jupyter-widgets/controls", 352 | "model_name": "ProgressStyleModel", 353 | "state": { 354 | "_model_module": "@jupyter-widgets/controls", 355 | "_model_module_version": "1.5.0", 356 | "_model_name": "ProgressStyleModel", 357 | "_view_count": null, 358 | "_view_module": "@jupyter-widgets/base", 359 | "_view_module_version": "1.2.0", 360 | "_view_name": "StyleView", 361 | "bar_color": null, 362 | "description_width": "initial" 363 | } 364 | }, 365 | "29fba2547a7e478ca97792eca1835286": { 366 | "model_module": "@jupyter-widgets/controls", 367 | "model_name": "FloatProgressModel", 368 | "state": { 369 | "_dom_classes": [], 370 | "_model_module": "@jupyter-widgets/controls", 371 | "_model_module_version": "1.5.0", 372 | "_model_name": "FloatProgressModel", 373 | "_view_count": null, 374 | "_view_module": "@jupyter-widgets/controls", 375 | "_view_module_version": "1.5.0", 376 | "_view_name": "ProgressView", 377 | "bar_style": "success", 378 | "description": "Downloading: 100%", 379 | "description_tooltip": null, 380 | "layout": "IPY_MODEL_dbd206abbcac414d8730ff7cf4ef55d8", 381 | "max": 2697421, 382 | "min": 0, 383 | "orientation": "horizontal", 384 | "style": "IPY_MODEL_ec259b6cb7654b819f6e33be6b99a58e", 385 | "value": 2697421 386 | } 387 | }, 388 | "2b239899205b4e35a1fe9a24f578e2c2": { 389 | "model_module": "@jupyter-widgets/controls", 390 | "model_name": "HTMLModel", 391 | "state": { 392 | "_dom_classes": [], 393 | "_model_module": "@jupyter-widgets/controls", 394 | "_model_module_version": "1.5.0", 395 | "_model_name": "HTMLModel", 396 | "_view_count": null, 397 | "_view_module": "@jupyter-widgets/controls", 398 | "_view_module_version": "1.5.0", 399 | "_view_name": "HTMLView", 400 | "description": "", 401 | "description_tooltip": null, 402 | "layout": "IPY_MODEL_3a1c2001ae784d8c9beb5ce39d1fa185", 403 | "placeholder": "​", 404 | "style": "IPY_MODEL_96376ad744dd4e1f97a4df143d164843", 405 | "value": " 2/2 [00:00<00:00, 2.03ba/s]" 406 | } 407 | }, 408 | "307c92aaf92b48cdb7bcf6c2a5a4fc94": { 409 | "model_module": "@jupyter-widgets/controls", 410 | "model_name": "FloatProgressModel", 411 | "state": { 412 | "_dom_classes": [], 413 | "_model_module": "@jupyter-widgets/controls", 414 | "_model_module_version": "1.5.0", 415 | "_model_name": "FloatProgressModel", 416 | "_view_count": null, 417 | "_view_module": "@jupyter-widgets/controls", 418 | "_view_module_version": "1.5.0", 419 | "_view_name": "ProgressView", 420 | "bar_style": 
"success", 421 | "description": "100%", 422 | "description_tooltip": null, 423 | "layout": "IPY_MODEL_fabc123e8ab14205a81fee5e8e6484dc", 424 | "max": 2, 425 | "min": 0, 426 | "orientation": "horizontal", 427 | "style": "IPY_MODEL_4ec1f854d65c4abeaa257fd6f165dc56", 428 | "value": 2 429 | } 430 | }, 431 | "39863d81dd304318beee0df7d409e9e9": { 432 | "model_module": "@jupyter-widgets/controls", 433 | "model_name": "DescriptionStyleModel", 434 | "state": { 435 | "_model_module": "@jupyter-widgets/controls", 436 | "_model_module_version": "1.5.0", 437 | "_model_name": "DescriptionStyleModel", 438 | "_view_count": null, 439 | "_view_module": "@jupyter-widgets/base", 440 | "_view_module_version": "1.2.0", 441 | "_view_name": "StyleView", 442 | "description_width": "" 443 | } 444 | }, 445 | "39d781ce862343cb833d481cdf2237fb": { 446 | "model_module": "@jupyter-widgets/controls", 447 | "model_name": "HTMLModel", 448 | "state": { 449 | "_dom_classes": [], 450 | "_model_module": "@jupyter-widgets/controls", 451 | "_model_module_version": "1.5.0", 452 | "_model_name": "HTMLModel", 453 | "_view_count": null, 454 | "_view_module": "@jupyter-widgets/controls", 455 | "_view_module_version": "1.5.0", 456 | "_view_name": "HTMLView", 457 | "description": "", 458 | "description_tooltip": null, 459 | "layout": "IPY_MODEL_fb595227f20f4a18aaef097807dede5b", 460 | "placeholder": "​", 461 | "style": "IPY_MODEL_5a46917da090478f850d3aa0f06c5226", 462 | "value": " 780k/780k [00:02<00:00, 321kB/s]" 463 | } 464 | }, 465 | "3a1c2001ae784d8c9beb5ce39d1fa185": { 466 | "model_module": "@jupyter-widgets/base", 467 | "model_name": "LayoutModel", 468 | "state": { 469 | "_model_module": "@jupyter-widgets/base", 470 | "_model_module_version": "1.2.0", 471 | "_model_name": "LayoutModel", 472 | "_view_count": null, 473 | "_view_module": "@jupyter-widgets/base", 474 | "_view_module_version": "1.2.0", 475 | "_view_name": "LayoutView", 476 | "align_content": null, 477 | "align_items": null, 478 | "align_self": null, 479 | "border": null, 480 | "bottom": null, 481 | "display": null, 482 | "flex": null, 483 | "flex_flow": null, 484 | "grid_area": null, 485 | "grid_auto_columns": null, 486 | "grid_auto_flow": null, 487 | "grid_auto_rows": null, 488 | "grid_column": null, 489 | "grid_gap": null, 490 | "grid_row": null, 491 | "grid_template_areas": null, 492 | "grid_template_columns": null, 493 | "grid_template_rows": null, 494 | "height": null, 495 | "justify_content": null, 496 | "justify_items": null, 497 | "left": null, 498 | "margin": null, 499 | "max_height": null, 500 | "max_width": null, 501 | "min_height": null, 502 | "min_width": null, 503 | "object_fit": null, 504 | "object_position": null, 505 | "order": null, 506 | "overflow": null, 507 | "overflow_x": null, 508 | "overflow_y": null, 509 | "padding": null, 510 | "right": null, 511 | "top": null, 512 | "visibility": null, 513 | "width": null 514 | } 515 | }, 516 | "451ce5e7a0644ecdbac61b6995192e02": { 517 | "model_module": "@jupyter-widgets/controls", 518 | "model_name": "HTMLModel", 519 | "state": { 520 | "_dom_classes": [], 521 | "_model_module": "@jupyter-widgets/controls", 522 | "_model_module_version": "1.5.0", 523 | "_model_name": "HTMLModel", 524 | "_view_count": null, 525 | "_view_module": "@jupyter-widgets/controls", 526 | "_view_module_version": "1.5.0", 527 | "_view_name": "HTMLView", 528 | "description": "", 529 | "description_tooltip": null, 530 | "layout": "IPY_MODEL_5ed2d167e0e641e290a46d457c8930d5", 531 | "placeholder": "​", 532 | "style": 
"IPY_MODEL_39863d81dd304318beee0df7d409e9e9", 533 | "value": " 2.70M/2.70M [00:01<00:00, 1.89MB/s]" 534 | } 535 | }, 536 | "4684f0aba9f347d1938ce3ed125c56be": { 537 | "model_module": "@jupyter-widgets/controls", 538 | "model_name": "ProgressStyleModel", 539 | "state": { 540 | "_model_module": "@jupyter-widgets/controls", 541 | "_model_module_version": "1.5.0", 542 | "_model_name": "ProgressStyleModel", 543 | "_view_count": null, 544 | "_view_module": "@jupyter-widgets/base", 545 | "_view_module_version": "1.2.0", 546 | "_view_name": "StyleView", 547 | "bar_color": null, 548 | "description_width": "initial" 549 | } 550 | }, 551 | "47b5070966da480cab27ed1332f25aa4": { 552 | "model_module": "@jupyter-widgets/controls", 553 | "model_name": "DescriptionStyleModel", 554 | "state": { 555 | "_model_module": "@jupyter-widgets/controls", 556 | "_model_module_version": "1.5.0", 557 | "_model_name": "DescriptionStyleModel", 558 | "_view_count": null, 559 | "_view_module": "@jupyter-widgets/base", 560 | "_view_module_version": "1.2.0", 561 | "_view_name": "StyleView", 562 | "description_width": "" 563 | } 564 | }, 565 | "47c3a693d7e9401cbb20bed9ca7fe7ef": { 566 | "model_module": "@jupyter-widgets/controls", 567 | "model_name": "FloatProgressModel", 568 | "state": { 569 | "_dom_classes": [], 570 | "_model_module": "@jupyter-widgets/controls", 571 | "_model_module_version": "1.5.0", 572 | "_model_name": "FloatProgressModel", 573 | "_view_count": null, 574 | "_view_module": "@jupyter-widgets/controls", 575 | "_view_module_version": "1.5.0", 576 | "_view_name": "ProgressView", 577 | "bar_style": "success", 578 | "description": "Downloading: 100%", 579 | "description_tooltip": null, 580 | "layout": "IPY_MODEL_4dc76c6f669f4514b9144e6c7b020dbd", 581 | "max": 780034, 582 | "min": 0, 583 | "orientation": "horizontal", 584 | "style": "IPY_MODEL_94e063404d83486190d12c8846d0882f", 585 | "value": 780034 586 | } 587 | }, 588 | "4d2b0fbfe0d5421b93e529f34adeefb3": { 589 | "model_module": "@jupyter-widgets/base", 590 | "model_name": "LayoutModel", 591 | "state": { 592 | "_model_module": "@jupyter-widgets/base", 593 | "_model_module_version": "1.2.0", 594 | "_model_name": "LayoutModel", 595 | "_view_count": null, 596 | "_view_module": "@jupyter-widgets/base", 597 | "_view_module_version": "1.2.0", 598 | "_view_name": "LayoutView", 599 | "align_content": null, 600 | "align_items": null, 601 | "align_self": null, 602 | "border": null, 603 | "bottom": null, 604 | "display": null, 605 | "flex": null, 606 | "flex_flow": null, 607 | "grid_area": null, 608 | "grid_auto_columns": null, 609 | "grid_auto_flow": null, 610 | "grid_auto_rows": null, 611 | "grid_column": null, 612 | "grid_gap": null, 613 | "grid_row": null, 614 | "grid_template_areas": null, 615 | "grid_template_columns": null, 616 | "grid_template_rows": null, 617 | "height": null, 618 | "justify_content": null, 619 | "justify_items": null, 620 | "left": null, 621 | "margin": null, 622 | "max_height": null, 623 | "max_width": null, 624 | "min_height": null, 625 | "min_width": null, 626 | "object_fit": null, 627 | "object_position": null, 628 | "order": null, 629 | "overflow": null, 630 | "overflow_x": null, 631 | "overflow_y": null, 632 | "padding": null, 633 | "right": null, 634 | "top": null, 635 | "visibility": null, 636 | "width": null 637 | } 638 | }, 639 | "4dc76c6f669f4514b9144e6c7b020dbd": { 640 | "model_module": "@jupyter-widgets/base", 641 | "model_name": "LayoutModel", 642 | "state": { 643 | "_model_module": "@jupyter-widgets/base", 644 | 
"_model_module_version": "1.2.0", 645 | "_model_name": "LayoutModel", 646 | "_view_count": null, 647 | "_view_module": "@jupyter-widgets/base", 648 | "_view_module_version": "1.2.0", 649 | "_view_name": "LayoutView", 650 | "align_content": null, 651 | "align_items": null, 652 | "align_self": null, 653 | "border": null, 654 | "bottom": null, 655 | "display": null, 656 | "flex": null, 657 | "flex_flow": null, 658 | "grid_area": null, 659 | "grid_auto_columns": null, 660 | "grid_auto_flow": null, 661 | "grid_auto_rows": null, 662 | "grid_column": null, 663 | "grid_gap": null, 664 | "grid_row": null, 665 | "grid_template_areas": null, 666 | "grid_template_columns": null, 667 | "grid_template_rows": null, 668 | "height": null, 669 | "justify_content": null, 670 | "justify_items": null, 671 | "left": null, 672 | "margin": null, 673 | "max_height": null, 674 | "max_width": null, 675 | "min_height": null, 676 | "min_width": null, 677 | "object_fit": null, 678 | "object_position": null, 679 | "order": null, 680 | "overflow": null, 681 | "overflow_x": null, 682 | "overflow_y": null, 683 | "padding": null, 684 | "right": null, 685 | "top": null, 686 | "visibility": null, 687 | "width": null 688 | } 689 | }, 690 | "4ec1f854d65c4abeaa257fd6f165dc56": { 691 | "model_module": "@jupyter-widgets/controls", 692 | "model_name": "ProgressStyleModel", 693 | "state": { 694 | "_model_module": "@jupyter-widgets/controls", 695 | "_model_module_version": "1.5.0", 696 | "_model_name": "ProgressStyleModel", 697 | "_view_count": null, 698 | "_view_module": "@jupyter-widgets/base", 699 | "_view_module_version": "1.2.0", 700 | "_view_name": "StyleView", 701 | "bar_color": null, 702 | "description_width": "initial" 703 | } 704 | }, 705 | "523f17ee95144d40b65d4549c98e6cd5": { 706 | "model_module": "@jupyter-widgets/controls", 707 | "model_name": "FloatProgressModel", 708 | "state": { 709 | "_dom_classes": [], 710 | "_model_module": "@jupyter-widgets/controls", 711 | "_model_module_version": "1.5.0", 712 | "_model_name": "FloatProgressModel", 713 | "_view_count": null, 714 | "_view_module": "@jupyter-widgets/controls", 715 | "_view_module_version": "1.5.0", 716 | "_view_name": "ProgressView", 717 | "bar_style": "success", 718 | "description": "Downloading: 100%", 719 | "description_tooltip": null, 720 | "layout": "IPY_MODEL_a918de7d6fc141069ec93489dbb2cdb3", 721 | "max": 543450723, 722 | "min": 0, 723 | "orientation": "horizontal", 724 | "style": "IPY_MODEL_2733a8c195944629a0cb9a0007588358", 725 | "value": 543450723 726 | } 727 | }, 728 | "5a37e3838ef146d2aef32cddfaf8daf4": { 729 | "model_module": "@jupyter-widgets/controls", 730 | "model_name": "ProgressStyleModel", 731 | "state": { 732 | "_model_module": "@jupyter-widgets/controls", 733 | "_model_module_version": "1.5.0", 734 | "_model_name": "ProgressStyleModel", 735 | "_view_count": null, 736 | "_view_module": "@jupyter-widgets/base", 737 | "_view_module_version": "1.2.0", 738 | "_view_name": "StyleView", 739 | "bar_color": null, 740 | "description_width": "initial" 741 | } 742 | }, 743 | "5a46917da090478f850d3aa0f06c5226": { 744 | "model_module": "@jupyter-widgets/controls", 745 | "model_name": "DescriptionStyleModel", 746 | "state": { 747 | "_model_module": "@jupyter-widgets/controls", 748 | "_model_module_version": "1.5.0", 749 | "_model_name": "DescriptionStyleModel", 750 | "_view_count": null, 751 | "_view_module": "@jupyter-widgets/base", 752 | "_view_module_version": "1.2.0", 753 | "_view_name": "StyleView", 754 | "description_width": "" 755 | } 756 | }, 757 | 
"5c15eae703e24258a4fa9272c55e93f3": { 758 | "model_module": "@jupyter-widgets/controls", 759 | "model_name": "DescriptionStyleModel", 760 | "state": { 761 | "_model_module": "@jupyter-widgets/controls", 762 | "_model_module_version": "1.5.0", 763 | "_model_name": "DescriptionStyleModel", 764 | "_view_count": null, 765 | "_view_module": "@jupyter-widgets/base", 766 | "_view_module_version": "1.2.0", 767 | "_view_name": "StyleView", 768 | "description_width": "" 769 | } 770 | }, 771 | "5ed2d167e0e641e290a46d457c8930d5": { 772 | "model_module": "@jupyter-widgets/base", 773 | "model_name": "LayoutModel", 774 | "state": { 775 | "_model_module": "@jupyter-widgets/base", 776 | "_model_module_version": "1.2.0", 777 | "_model_name": "LayoutModel", 778 | "_view_count": null, 779 | "_view_module": "@jupyter-widgets/base", 780 | "_view_module_version": "1.2.0", 781 | "_view_name": "LayoutView", 782 | "align_content": null, 783 | "align_items": null, 784 | "align_self": null, 785 | "border": null, 786 | "bottom": null, 787 | "display": null, 788 | "flex": null, 789 | "flex_flow": null, 790 | "grid_area": null, 791 | "grid_auto_columns": null, 792 | "grid_auto_flow": null, 793 | "grid_auto_rows": null, 794 | "grid_column": null, 795 | "grid_gap": null, 796 | "grid_row": null, 797 | "grid_template_areas": null, 798 | "grid_template_columns": null, 799 | "grid_template_rows": null, 800 | "height": null, 801 | "justify_content": null, 802 | "justify_items": null, 803 | "left": null, 804 | "margin": null, 805 | "max_height": null, 806 | "max_width": null, 807 | "min_height": null, 808 | "min_width": null, 809 | "object_fit": null, 810 | "object_position": null, 811 | "order": null, 812 | "overflow": null, 813 | "overflow_x": null, 814 | "overflow_y": null, 815 | "padding": null, 816 | "right": null, 817 | "top": null, 818 | "visibility": null, 819 | "width": null 820 | } 821 | }, 822 | "5f26156efe6d49379ee39e75bae41840": { 823 | "model_module": "@jupyter-widgets/base", 824 | "model_name": "LayoutModel", 825 | "state": { 826 | "_model_module": "@jupyter-widgets/base", 827 | "_model_module_version": "1.2.0", 828 | "_model_name": "LayoutModel", 829 | "_view_count": null, 830 | "_view_module": "@jupyter-widgets/base", 831 | "_view_module_version": "1.2.0", 832 | "_view_name": "LayoutView", 833 | "align_content": null, 834 | "align_items": null, 835 | "align_self": null, 836 | "border": null, 837 | "bottom": null, 838 | "display": null, 839 | "flex": null, 840 | "flex_flow": null, 841 | "grid_area": null, 842 | "grid_auto_columns": null, 843 | "grid_auto_flow": null, 844 | "grid_auto_rows": null, 845 | "grid_column": null, 846 | "grid_gap": null, 847 | "grid_row": null, 848 | "grid_template_areas": null, 849 | "grid_template_columns": null, 850 | "grid_template_rows": null, 851 | "height": null, 852 | "justify_content": null, 853 | "justify_items": null, 854 | "left": null, 855 | "margin": null, 856 | "max_height": null, 857 | "max_width": null, 858 | "min_height": null, 859 | "min_width": null, 860 | "object_fit": null, 861 | "object_position": null, 862 | "order": null, 863 | "overflow": null, 864 | "overflow_x": null, 865 | "overflow_y": null, 866 | "padding": null, 867 | "right": null, 868 | "top": null, 869 | "visibility": null, 870 | "width": null 871 | } 872 | }, 873 | "5f7bea12320e41e2a5b16da0987d14fa": { 874 | "model_module": "@jupyter-widgets/controls", 875 | "model_name": "DescriptionStyleModel", 876 | "state": { 877 | "_model_module": "@jupyter-widgets/controls", 878 | "_model_module_version": 
"1.5.0", 879 | "_model_name": "DescriptionStyleModel", 880 | "_view_count": null, 881 | "_view_module": "@jupyter-widgets/base", 882 | "_view_module_version": "1.2.0", 883 | "_view_name": "StyleView", 884 | "description_width": "" 885 | } 886 | }, 887 | "62f3278b280b4c86aa34638c097d05fd": { 888 | "model_module": "@jupyter-widgets/controls", 889 | "model_name": "HBoxModel", 890 | "state": { 891 | "_dom_classes": [], 892 | "_model_module": "@jupyter-widgets/controls", 893 | "_model_module_version": "1.5.0", 894 | "_model_name": "HBoxModel", 895 | "_view_count": null, 896 | "_view_module": "@jupyter-widgets/controls", 897 | "_view_module_version": "1.5.0", 898 | "_view_name": "HBoxView", 899 | "box_style": "", 900 | "children": [ 901 | "IPY_MODEL_47c3a693d7e9401cbb20bed9ca7fe7ef", 902 | "IPY_MODEL_39d781ce862343cb833d481cdf2237fb" 903 | ], 904 | "layout": "IPY_MODEL_180e7a8558b74bea8ebcec9e4a06138a" 905 | } 906 | }, 907 | "63615e3977904ddc8204f70ddfc54dd6": { 908 | "model_module": "@jupyter-widgets/controls", 909 | "model_name": "HTMLModel", 910 | "state": { 911 | "_dom_classes": [], 912 | "_model_module": "@jupyter-widgets/controls", 913 | "_model_module_version": "1.5.0", 914 | "_model_name": "HTMLModel", 915 | "_view_count": null, 916 | "_view_module": "@jupyter-widgets/controls", 917 | "_view_module_version": "1.5.0", 918 | "_view_name": "HTMLView", 919 | "description": "", 920 | "description_tooltip": null, 921 | "layout": "IPY_MODEL_02e45e7277bf417d896d92cfad3e7db3", 922 | "placeholder": "​", 923 | "style": "IPY_MODEL_5c15eae703e24258a4fa9272c55e93f3", 924 | "value": " 543M/543M [00:10<00:00, 52.7MB/s]" 925 | } 926 | }, 927 | "644806c6298940c892b9e1f34a5127b9": { 928 | "model_module": "@jupyter-widgets/controls", 929 | "model_name": "HBoxModel", 930 | "state": { 931 | "_dom_classes": [], 932 | "_model_module": "@jupyter-widgets/controls", 933 | "_model_module_version": "1.5.0", 934 | "_model_name": "HBoxModel", 935 | "_view_count": null, 936 | "_view_module": "@jupyter-widgets/controls", 937 | "_view_module_version": "1.5.0", 938 | "_view_name": "HBoxView", 939 | "box_style": "", 940 | "children": [ 941 | "IPY_MODEL_307c92aaf92b48cdb7bcf6c2a5a4fc94", 942 | "IPY_MODEL_0c44bce3963d4b73bad9a2f43034da6e" 943 | ], 944 | "layout": "IPY_MODEL_c5a25438e62945ca84bf8f386729811b" 945 | } 946 | }, 947 | "69705ae4e97b484393c78d35b3d1ced4": { 948 | "model_module": "@jupyter-widgets/controls", 949 | "model_name": "HBoxModel", 950 | "state": { 951 | "_dom_classes": [], 952 | "_model_module": "@jupyter-widgets/controls", 953 | "_model_module_version": "1.5.0", 954 | "_model_name": "HBoxModel", 955 | "_view_count": null, 956 | "_view_module": "@jupyter-widgets/controls", 957 | "_view_module_version": "1.5.0", 958 | "_view_name": "HBoxView", 959 | "box_style": "", 960 | "children": [ 961 | "IPY_MODEL_29fba2547a7e478ca97792eca1835286", 962 | "IPY_MODEL_451ce5e7a0644ecdbac61b6995192e02" 963 | ], 964 | "layout": "IPY_MODEL_d79e114bea844160bc941d81d4581b46" 965 | } 966 | }, 967 | "77046ee7556242aeb281a62a7305eeb7": { 968 | "model_module": "@jupyter-widgets/base", 969 | "model_name": "LayoutModel", 970 | "state": { 971 | "_model_module": "@jupyter-widgets/base", 972 | "_model_module_version": "1.2.0", 973 | "_model_name": "LayoutModel", 974 | "_view_count": null, 975 | "_view_module": "@jupyter-widgets/base", 976 | "_view_module_version": "1.2.0", 977 | "_view_name": "LayoutView", 978 | "align_content": null, 979 | "align_items": null, 980 | "align_self": null, 981 | "border": null, 982 | "bottom": null, 983 | 
"display": null, 984 | "flex": null, 985 | "flex_flow": null, 986 | "grid_area": null, 987 | "grid_auto_columns": null, 988 | "grid_auto_flow": null, 989 | "grid_auto_rows": null, 990 | "grid_column": null, 991 | "grid_gap": null, 992 | "grid_row": null, 993 | "grid_template_areas": null, 994 | "grid_template_columns": null, 995 | "grid_template_rows": null, 996 | "height": null, 997 | "justify_content": null, 998 | "justify_items": null, 999 | "left": null, 1000 | "margin": null, 1001 | "max_height": null, 1002 | "max_width": null, 1003 | "min_height": null, 1004 | "min_width": null, 1005 | "object_fit": null, 1006 | "object_position": null, 1007 | "order": null, 1008 | "overflow": null, 1009 | "overflow_x": null, 1010 | "overflow_y": null, 1011 | "padding": null, 1012 | "right": null, 1013 | "top": null, 1014 | "visibility": null, 1015 | "width": null 1016 | } 1017 | }, 1018 | "8b9043ffe8fb4313869fd9ca78021a93": { 1019 | "model_module": "@jupyter-widgets/controls", 1020 | "model_name": "FloatProgressModel", 1021 | "state": { 1022 | "_dom_classes": [], 1023 | "_model_module": "@jupyter-widgets/controls", 1024 | "_model_module_version": "1.5.0", 1025 | "_model_name": "FloatProgressModel", 1026 | "_view_count": null, 1027 | "_view_module": "@jupyter-widgets/controls", 1028 | "_view_module_version": "1.5.0", 1029 | "_view_name": "ProgressView", 1030 | "bar_style": "success", 1031 | "description": "Downloading: 100%", 1032 | "description_tooltip": null, 1033 | "layout": "IPY_MODEL_02f25412874b417892d6ff746d65d0e7", 1034 | "max": 112, 1035 | "min": 0, 1036 | "orientation": "horizontal", 1037 | "style": "IPY_MODEL_e24e4163978a4f068ec80beab9ad0362", 1038 | "value": 112 1039 | } 1040 | }, 1041 | "8d5f2f8c43884e49bbec86cf2bc22d26": { 1042 | "model_module": "@jupyter-widgets/controls", 1043 | "model_name": "HBoxModel", 1044 | "state": { 1045 | "_dom_classes": [], 1046 | "_model_module": "@jupyter-widgets/controls", 1047 | "_model_module_version": "1.5.0", 1048 | "_model_name": "HBoxModel", 1049 | "_view_count": null, 1050 | "_view_module": "@jupyter-widgets/controls", 1051 | "_view_module_version": "1.5.0", 1052 | "_view_name": "HBoxView", 1053 | "box_style": "", 1054 | "children": [ 1055 | "IPY_MODEL_523f17ee95144d40b65d4549c98e6cd5", 1056 | "IPY_MODEL_63615e3977904ddc8204f70ddfc54dd6" 1057 | ], 1058 | "layout": "IPY_MODEL_be044d27d62342cb91d633eeddf52a58" 1059 | } 1060 | }, 1061 | "90dc4e646e3b4bf39f7d2de7e6e28d15": { 1062 | "model_module": "@jupyter-widgets/base", 1063 | "model_name": "LayoutModel", 1064 | "state": { 1065 | "_model_module": "@jupyter-widgets/base", 1066 | "_model_module_version": "1.2.0", 1067 | "_model_name": "LayoutModel", 1068 | "_view_count": null, 1069 | "_view_module": "@jupyter-widgets/base", 1070 | "_view_module_version": "1.2.0", 1071 | "_view_name": "LayoutView", 1072 | "align_content": null, 1073 | "align_items": null, 1074 | "align_self": null, 1075 | "border": null, 1076 | "bottom": null, 1077 | "display": null, 1078 | "flex": null, 1079 | "flex_flow": null, 1080 | "grid_area": null, 1081 | "grid_auto_columns": null, 1082 | "grid_auto_flow": null, 1083 | "grid_auto_rows": null, 1084 | "grid_column": null, 1085 | "grid_gap": null, 1086 | "grid_row": null, 1087 | "grid_template_areas": null, 1088 | "grid_template_columns": null, 1089 | "grid_template_rows": null, 1090 | "height": null, 1091 | "justify_content": null, 1092 | "justify_items": null, 1093 | "left": null, 1094 | "margin": null, 1095 | "max_height": null, 1096 | "max_width": null, 1097 | "min_height": null, 
1098 | "min_width": null, 1099 | "object_fit": null, 1100 | "object_position": null, 1101 | "order": null, 1102 | "overflow": null, 1103 | "overflow_x": null, 1104 | "overflow_y": null, 1105 | "padding": null, 1106 | "right": null, 1107 | "top": null, 1108 | "visibility": null, 1109 | "width": null 1110 | } 1111 | }, 1112 | "93d29cab94e84cae9b7a6f8ecccacba6": { 1113 | "model_module": "@jupyter-widgets/controls", 1114 | "model_name": "HTMLModel", 1115 | "state": { 1116 | "_dom_classes": [], 1117 | "_model_module": "@jupyter-widgets/controls", 1118 | "_model_module_version": "1.5.0", 1119 | "_model_name": "HTMLModel", 1120 | "_view_count": null, 1121 | "_view_module": "@jupyter-widgets/controls", 1122 | "_view_module_version": "1.5.0", 1123 | "_view_name": "HTMLView", 1124 | "description": "", 1125 | "description_tooltip": null, 1126 | "layout": "IPY_MODEL_90dc4e646e3b4bf39f7d2de7e6e28d15", 1127 | "placeholder": "​", 1128 | "style": "IPY_MODEL_47b5070966da480cab27ed1332f25aa4", 1129 | "value": " 112/112 [00:01<00:00, 86.3B/s]" 1130 | } 1131 | }, 1132 | "94e063404d83486190d12c8846d0882f": { 1133 | "model_module": "@jupyter-widgets/controls", 1134 | "model_name": "ProgressStyleModel", 1135 | "state": { 1136 | "_model_module": "@jupyter-widgets/controls", 1137 | "_model_module_version": "1.5.0", 1138 | "_model_name": "ProgressStyleModel", 1139 | "_view_count": null, 1140 | "_view_module": "@jupyter-widgets/base", 1141 | "_view_module_version": "1.2.0", 1142 | "_view_name": "StyleView", 1143 | "bar_color": null, 1144 | "description_width": "initial" 1145 | } 1146 | }, 1147 | "96376ad744dd4e1f97a4df143d164843": { 1148 | "model_module": "@jupyter-widgets/controls", 1149 | "model_name": "DescriptionStyleModel", 1150 | "state": { 1151 | "_model_module": "@jupyter-widgets/controls", 1152 | "_model_module_version": "1.5.0", 1153 | "_model_name": "DescriptionStyleModel", 1154 | "_view_count": null, 1155 | "_view_module": "@jupyter-widgets/base", 1156 | "_view_module_version": "1.2.0", 1157 | "_view_name": "StyleView", 1158 | "description_width": "" 1159 | } 1160 | }, 1161 | "985c2e847ae2468aa32f734fad8cac15": { 1162 | "model_module": "@jupyter-widgets/base", 1163 | "model_name": "LayoutModel", 1164 | "state": { 1165 | "_model_module": "@jupyter-widgets/base", 1166 | "_model_module_version": "1.2.0", 1167 | "_model_name": "LayoutModel", 1168 | "_view_count": null, 1169 | "_view_module": "@jupyter-widgets/base", 1170 | "_view_module_version": "1.2.0", 1171 | "_view_name": "LayoutView", 1172 | "align_content": null, 1173 | "align_items": null, 1174 | "align_self": null, 1175 | "border": null, 1176 | "bottom": null, 1177 | "display": null, 1178 | "flex": null, 1179 | "flex_flow": null, 1180 | "grid_area": null, 1181 | "grid_auto_columns": null, 1182 | "grid_auto_flow": null, 1183 | "grid_auto_rows": null, 1184 | "grid_column": null, 1185 | "grid_gap": null, 1186 | "grid_row": null, 1187 | "grid_template_areas": null, 1188 | "grid_template_columns": null, 1189 | "grid_template_rows": null, 1190 | "height": null, 1191 | "justify_content": null, 1192 | "justify_items": null, 1193 | "left": null, 1194 | "margin": null, 1195 | "max_height": null, 1196 | "max_width": null, 1197 | "min_height": null, 1198 | "min_width": null, 1199 | "object_fit": null, 1200 | "object_position": null, 1201 | "order": null, 1202 | "overflow": null, 1203 | "overflow_x": null, 1204 | "overflow_y": null, 1205 | "padding": null, 1206 | "right": null, 1207 | "top": null, 1208 | "visibility": null, 1209 | "width": null 1210 | } 1211 | }, 
1212 | "9eefe0eb1b02400598908bf39be13998": { 1213 | "model_module": "@jupyter-widgets/controls", 1214 | "model_name": "DescriptionStyleModel", 1215 | "state": { 1216 | "_model_module": "@jupyter-widgets/controls", 1217 | "_model_module_version": "1.5.0", 1218 | "_model_name": "DescriptionStyleModel", 1219 | "_view_count": null, 1220 | "_view_module": "@jupyter-widgets/base", 1221 | "_view_module_version": "1.2.0", 1222 | "_view_name": "StyleView", 1223 | "description_width": "" 1224 | } 1225 | }, 1226 | "a1b586360cc046e2bcb50f8e32c54134": { 1227 | "model_module": "@jupyter-widgets/controls", 1228 | "model_name": "HBoxModel", 1229 | "state": { 1230 | "_dom_classes": [], 1231 | "_model_module": "@jupyter-widgets/controls", 1232 | "_model_module_version": "1.5.0", 1233 | "_model_name": "HBoxModel", 1234 | "_view_count": null, 1235 | "_view_module": "@jupyter-widgets/controls", 1236 | "_view_module_version": "1.5.0", 1237 | "_view_name": "HBoxView", 1238 | "box_style": "", 1239 | "children": [ 1240 | "IPY_MODEL_b021c96a3a144159a1de8b9b526a9cf0", 1241 | "IPY_MODEL_dbaf10b01e9b42299f67b4095b210499" 1242 | ], 1243 | "layout": "IPY_MODEL_4d2b0fbfe0d5421b93e529f34adeefb3" 1244 | } 1245 | }, 1246 | "a3be3ff323e842b1a0927a62eddf462d": { 1247 | "model_module": "@jupyter-widgets/base", 1248 | "model_name": "LayoutModel", 1249 | "state": { 1250 | "_model_module": "@jupyter-widgets/base", 1251 | "_model_module_version": "1.2.0", 1252 | "_model_name": "LayoutModel", 1253 | "_view_count": null, 1254 | "_view_module": "@jupyter-widgets/base", 1255 | "_view_module_version": "1.2.0", 1256 | "_view_name": "LayoutView", 1257 | "align_content": null, 1258 | "align_items": null, 1259 | "align_self": null, 1260 | "border": null, 1261 | "bottom": null, 1262 | "display": null, 1263 | "flex": null, 1264 | "flex_flow": null, 1265 | "grid_area": null, 1266 | "grid_auto_columns": null, 1267 | "grid_auto_flow": null, 1268 | "grid_auto_rows": null, 1269 | "grid_column": null, 1270 | "grid_gap": null, 1271 | "grid_row": null, 1272 | "grid_template_areas": null, 1273 | "grid_template_columns": null, 1274 | "grid_template_rows": null, 1275 | "height": null, 1276 | "justify_content": null, 1277 | "justify_items": null, 1278 | "left": null, 1279 | "margin": null, 1280 | "max_height": null, 1281 | "max_width": null, 1282 | "min_height": null, 1283 | "min_width": null, 1284 | "object_fit": null, 1285 | "object_position": null, 1286 | "order": null, 1287 | "overflow": null, 1288 | "overflow_x": null, 1289 | "overflow_y": null, 1290 | "padding": null, 1291 | "right": null, 1292 | "top": null, 1293 | "visibility": null, 1294 | "width": null 1295 | } 1296 | }, 1297 | "a918de7d6fc141069ec93489dbb2cdb3": { 1298 | "model_module": "@jupyter-widgets/base", 1299 | "model_name": "LayoutModel", 1300 | "state": { 1301 | "_model_module": "@jupyter-widgets/base", 1302 | "_model_module_version": "1.2.0", 1303 | "_model_name": "LayoutModel", 1304 | "_view_count": null, 1305 | "_view_module": "@jupyter-widgets/base", 1306 | "_view_module_version": "1.2.0", 1307 | "_view_name": "LayoutView", 1308 | "align_content": null, 1309 | "align_items": null, 1310 | "align_self": null, 1311 | "border": null, 1312 | "bottom": null, 1313 | "display": null, 1314 | "flex": null, 1315 | "flex_flow": null, 1316 | "grid_area": null, 1317 | "grid_auto_columns": null, 1318 | "grid_auto_flow": null, 1319 | "grid_auto_rows": null, 1320 | "grid_column": null, 1321 | "grid_gap": null, 1322 | "grid_row": null, 1323 | "grid_template_areas": null, 1324 | 
"grid_template_columns": null, 1325 | "grid_template_rows": null, 1326 | "height": null, 1327 | "justify_content": null, 1328 | "justify_items": null, 1329 | "left": null, 1330 | "margin": null, 1331 | "max_height": null, 1332 | "max_width": null, 1333 | "min_height": null, 1334 | "min_width": null, 1335 | "object_fit": null, 1336 | "object_position": null, 1337 | "order": null, 1338 | "overflow": null, 1339 | "overflow_x": null, 1340 | "overflow_y": null, 1341 | "padding": null, 1342 | "right": null, 1343 | "top": null, 1344 | "visibility": null, 1345 | "width": null 1346 | } 1347 | }, 1348 | "b021c96a3a144159a1de8b9b526a9cf0": { 1349 | "model_module": "@jupyter-widgets/controls", 1350 | "model_name": "FloatProgressModel", 1351 | "state": { 1352 | "_dom_classes": [], 1353 | "_model_module": "@jupyter-widgets/controls", 1354 | "_model_module_version": "1.5.0", 1355 | "_model_name": "FloatProgressModel", 1356 | "_view_count": null, 1357 | "_view_module": "@jupyter-widgets/controls", 1358 | "_view_module_version": "1.5.0", 1359 | "_view_name": "ProgressView", 1360 | "bar_style": "success", 1361 | "description": "Downloading: 100%", 1362 | "description_tooltip": null, 1363 | "layout": "IPY_MODEL_c2d4e9c398b647f1965ca6c818446c28", 1364 | "max": 379, 1365 | "min": 0, 1366 | "orientation": "horizontal", 1367 | "style": "IPY_MODEL_5a37e3838ef146d2aef32cddfaf8daf4", 1368 | "value": 379 1369 | } 1370 | }, 1371 | "b63e42d403bd4c2bb68f8c4709ae1dc4": { 1372 | "model_module": "@jupyter-widgets/controls", 1373 | "model_name": "HBoxModel", 1374 | "state": { 1375 | "_dom_classes": [], 1376 | "_model_module": "@jupyter-widgets/controls", 1377 | "_model_module_version": "1.5.0", 1378 | "_model_name": "HBoxModel", 1379 | "_view_count": null, 1380 | "_view_module": "@jupyter-widgets/controls", 1381 | "_view_module_version": "1.5.0", 1382 | "_view_name": "HBoxView", 1383 | "box_style": "", 1384 | "children": [ 1385 | "IPY_MODEL_2367e98e314848fcbf34969075b276b3", 1386 | "IPY_MODEL_2b239899205b4e35a1fe9a24f578e2c2" 1387 | ], 1388 | "layout": "IPY_MODEL_a3be3ff323e842b1a0927a62eddf462d" 1389 | } 1390 | }, 1391 | "b8fda7b5b9cc46f0865fe124fdaed067": { 1392 | "model_module": "@jupyter-widgets/controls", 1393 | "model_name": "FloatProgressModel", 1394 | "state": { 1395 | "_dom_classes": [], 1396 | "_model_module": "@jupyter-widgets/controls", 1397 | "_model_module_version": "1.5.0", 1398 | "_model_name": "FloatProgressModel", 1399 | "_view_count": null, 1400 | "_view_module": "@jupyter-widgets/controls", 1401 | "_view_module_version": "1.5.0", 1402 | "_view_name": "ProgressView", 1403 | "bar_style": "success", 1404 | "description": "Downloading: 100%", 1405 | "description_tooltip": null, 1406 | "layout": "IPY_MODEL_07b9064cf6804ebbbb26419522c49700", 1407 | "max": 576, 1408 | "min": 0, 1409 | "orientation": "horizontal", 1410 | "style": "IPY_MODEL_276a05360cb0423a80b0afe2c211e69c", 1411 | "value": 576 1412 | } 1413 | }, 1414 | "b9f50346f8164d2c99722a21646b4b05": { 1415 | "model_module": "@jupyter-widgets/controls", 1416 | "model_name": "DescriptionStyleModel", 1417 | "state": { 1418 | "_model_module": "@jupyter-widgets/controls", 1419 | "_model_module_version": "1.5.0", 1420 | "_model_name": "DescriptionStyleModel", 1421 | "_view_count": null, 1422 | "_view_module": "@jupyter-widgets/base", 1423 | "_view_module_version": "1.2.0", 1424 | "_view_name": "StyleView", 1425 | "description_width": "" 1426 | } 1427 | }, 1428 | "be044d27d62342cb91d633eeddf52a58": { 1429 | "model_module": "@jupyter-widgets/base", 1430 | 
"model_name": "LayoutModel", 1431 | "state": { 1432 | "_model_module": "@jupyter-widgets/base", 1433 | "_model_module_version": "1.2.0", 1434 | "_model_name": "LayoutModel", 1435 | "_view_count": null, 1436 | "_view_module": "@jupyter-widgets/base", 1437 | "_view_module_version": "1.2.0", 1438 | "_view_name": "LayoutView", 1439 | "align_content": null, 1440 | "align_items": null, 1441 | "align_self": null, 1442 | "border": null, 1443 | "bottom": null, 1444 | "display": null, 1445 | "flex": null, 1446 | "flex_flow": null, 1447 | "grid_area": null, 1448 | "grid_auto_columns": null, 1449 | "grid_auto_flow": null, 1450 | "grid_auto_rows": null, 1451 | "grid_column": null, 1452 | "grid_gap": null, 1453 | "grid_row": null, 1454 | "grid_template_areas": null, 1455 | "grid_template_columns": null, 1456 | "grid_template_rows": null, 1457 | "height": null, 1458 | "justify_content": null, 1459 | "justify_items": null, 1460 | "left": null, 1461 | "margin": null, 1462 | "max_height": null, 1463 | "max_width": null, 1464 | "min_height": null, 1465 | "min_width": null, 1466 | "object_fit": null, 1467 | "object_position": null, 1468 | "order": null, 1469 | "overflow": null, 1470 | "overflow_x": null, 1471 | "overflow_y": null, 1472 | "padding": null, 1473 | "right": null, 1474 | "top": null, 1475 | "visibility": null, 1476 | "width": null 1477 | } 1478 | }, 1479 | "c2d4e9c398b647f1965ca6c818446c28": { 1480 | "model_module": "@jupyter-widgets/base", 1481 | "model_name": "LayoutModel", 1482 | "state": { 1483 | "_model_module": "@jupyter-widgets/base", 1484 | "_model_module_version": "1.2.0", 1485 | "_model_name": "LayoutModel", 1486 | "_view_count": null, 1487 | "_view_module": "@jupyter-widgets/base", 1488 | "_view_module_version": "1.2.0", 1489 | "_view_name": "LayoutView", 1490 | "align_content": null, 1491 | "align_items": null, 1492 | "align_self": null, 1493 | "border": null, 1494 | "bottom": null, 1495 | "display": null, 1496 | "flex": null, 1497 | "flex_flow": null, 1498 | "grid_area": null, 1499 | "grid_auto_columns": null, 1500 | "grid_auto_flow": null, 1501 | "grid_auto_rows": null, 1502 | "grid_column": null, 1503 | "grid_gap": null, 1504 | "grid_row": null, 1505 | "grid_template_areas": null, 1506 | "grid_template_columns": null, 1507 | "grid_template_rows": null, 1508 | "height": null, 1509 | "justify_content": null, 1510 | "justify_items": null, 1511 | "left": null, 1512 | "margin": null, 1513 | "max_height": null, 1514 | "max_width": null, 1515 | "min_height": null, 1516 | "min_width": null, 1517 | "object_fit": null, 1518 | "object_position": null, 1519 | "order": null, 1520 | "overflow": null, 1521 | "overflow_x": null, 1522 | "overflow_y": null, 1523 | "padding": null, 1524 | "right": null, 1525 | "top": null, 1526 | "visibility": null, 1527 | "width": null 1528 | } 1529 | }, 1530 | "c5a25438e62945ca84bf8f386729811b": { 1531 | "model_module": "@jupyter-widgets/base", 1532 | "model_name": "LayoutModel", 1533 | "state": { 1534 | "_model_module": "@jupyter-widgets/base", 1535 | "_model_module_version": "1.2.0", 1536 | "_model_name": "LayoutModel", 1537 | "_view_count": null, 1538 | "_view_module": "@jupyter-widgets/base", 1539 | "_view_module_version": "1.2.0", 1540 | "_view_name": "LayoutView", 1541 | "align_content": null, 1542 | "align_items": null, 1543 | "align_self": null, 1544 | "border": null, 1545 | "bottom": null, 1546 | "display": null, 1547 | "flex": null, 1548 | "flex_flow": null, 1549 | "grid_area": null, 1550 | "grid_auto_columns": null, 1551 | "grid_auto_flow": null, 1552 | 
"grid_auto_rows": null, 1553 | "grid_column": null, 1554 | "grid_gap": null, 1555 | "grid_row": null, 1556 | "grid_template_areas": null, 1557 | "grid_template_columns": null, 1558 | "grid_template_rows": null, 1559 | "height": null, 1560 | "justify_content": null, 1561 | "justify_items": null, 1562 | "left": null, 1563 | "margin": null, 1564 | "max_height": null, 1565 | "max_width": null, 1566 | "min_height": null, 1567 | "min_width": null, 1568 | "object_fit": null, 1569 | "object_position": null, 1570 | "order": null, 1571 | "overflow": null, 1572 | "overflow_x": null, 1573 | "overflow_y": null, 1574 | "padding": null, 1575 | "right": null, 1576 | "top": null, 1577 | "visibility": null, 1578 | "width": null 1579 | } 1580 | }, 1581 | "ca650af9b9954d569e2523731121cb88": { 1582 | "model_module": "@jupyter-widgets/controls", 1583 | "model_name": "HBoxModel", 1584 | "state": { 1585 | "_dom_classes": [], 1586 | "_model_module": "@jupyter-widgets/controls", 1587 | "_model_module_version": "1.5.0", 1588 | "_model_name": "HBoxModel", 1589 | "_view_count": null, 1590 | "_view_module": "@jupyter-widgets/controls", 1591 | "_view_module_version": "1.5.0", 1592 | "_view_name": "HBoxView", 1593 | "box_style": "", 1594 | "children": [ 1595 | "IPY_MODEL_b8fda7b5b9cc46f0865fe124fdaed067", 1596 | "IPY_MODEL_169dd26ec2c34af2b23624f81afcebc3" 1597 | ], 1598 | "layout": "IPY_MODEL_cce331dde52f402083d93a66f5b30b8f" 1599 | } 1600 | }, 1601 | "cce331dde52f402083d93a66f5b30b8f": { 1602 | "model_module": "@jupyter-widgets/base", 1603 | "model_name": "LayoutModel", 1604 | "state": { 1605 | "_model_module": "@jupyter-widgets/base", 1606 | "_model_module_version": "1.2.0", 1607 | "_model_name": "LayoutModel", 1608 | "_view_count": null, 1609 | "_view_module": "@jupyter-widgets/base", 1610 | "_view_module_version": "1.2.0", 1611 | "_view_name": "LayoutView", 1612 | "align_content": null, 1613 | "align_items": null, 1614 | "align_self": null, 1615 | "border": null, 1616 | "bottom": null, 1617 | "display": null, 1618 | "flex": null, 1619 | "flex_flow": null, 1620 | "grid_area": null, 1621 | "grid_auto_columns": null, 1622 | "grid_auto_flow": null, 1623 | "grid_auto_rows": null, 1624 | "grid_column": null, 1625 | "grid_gap": null, 1626 | "grid_row": null, 1627 | "grid_template_areas": null, 1628 | "grid_template_columns": null, 1629 | "grid_template_rows": null, 1630 | "height": null, 1631 | "justify_content": null, 1632 | "justify_items": null, 1633 | "left": null, 1634 | "margin": null, 1635 | "max_height": null, 1636 | "max_width": null, 1637 | "min_height": null, 1638 | "min_width": null, 1639 | "object_fit": null, 1640 | "object_position": null, 1641 | "order": null, 1642 | "overflow": null, 1643 | "overflow_x": null, 1644 | "overflow_y": null, 1645 | "padding": null, 1646 | "right": null, 1647 | "top": null, 1648 | "visibility": null, 1649 | "width": null 1650 | } 1651 | }, 1652 | "cf79d2f4a1d84c2bae5b07fb2839f0ab": { 1653 | "model_module": "@jupyter-widgets/controls", 1654 | "model_name": "HBoxModel", 1655 | "state": { 1656 | "_dom_classes": [], 1657 | "_model_module": "@jupyter-widgets/controls", 1658 | "_model_module_version": "1.5.0", 1659 | "_model_name": "HBoxModel", 1660 | "_view_count": null, 1661 | "_view_module": "@jupyter-widgets/controls", 1662 | "_view_module_version": "1.5.0", 1663 | "_view_name": "HBoxView", 1664 | "box_style": "", 1665 | "children": [ 1666 | "IPY_MODEL_8b9043ffe8fb4313869fd9ca78021a93", 1667 | "IPY_MODEL_93d29cab94e84cae9b7a6f8ecccacba6" 1668 | ], 1669 | "layout": 
"IPY_MODEL_5f26156efe6d49379ee39e75bae41840" 1670 | } 1671 | }, 1672 | "d79e114bea844160bc941d81d4581b46": { 1673 | "model_module": "@jupyter-widgets/base", 1674 | "model_name": "LayoutModel", 1675 | "state": { 1676 | "_model_module": "@jupyter-widgets/base", 1677 | "_model_module_version": "1.2.0", 1678 | "_model_name": "LayoutModel", 1679 | "_view_count": null, 1680 | "_view_module": "@jupyter-widgets/base", 1681 | "_view_module_version": "1.2.0", 1682 | "_view_name": "LayoutView", 1683 | "align_content": null, 1684 | "align_items": null, 1685 | "align_self": null, 1686 | "border": null, 1687 | "bottom": null, 1688 | "display": null, 1689 | "flex": null, 1690 | "flex_flow": null, 1691 | "grid_area": null, 1692 | "grid_auto_columns": null, 1693 | "grid_auto_flow": null, 1694 | "grid_auto_rows": null, 1695 | "grid_column": null, 1696 | "grid_gap": null, 1697 | "grid_row": null, 1698 | "grid_template_areas": null, 1699 | "grid_template_columns": null, 1700 | "grid_template_rows": null, 1701 | "height": null, 1702 | "justify_content": null, 1703 | "justify_items": null, 1704 | "left": null, 1705 | "margin": null, 1706 | "max_height": null, 1707 | "max_width": null, 1708 | "min_height": null, 1709 | "min_width": null, 1710 | "object_fit": null, 1711 | "object_position": null, 1712 | "order": null, 1713 | "overflow": null, 1714 | "overflow_x": null, 1715 | "overflow_y": null, 1716 | "padding": null, 1717 | "right": null, 1718 | "top": null, 1719 | "visibility": null, 1720 | "width": null 1721 | } 1722 | }, 1723 | "dbaf10b01e9b42299f67b4095b210499": { 1724 | "model_module": "@jupyter-widgets/controls", 1725 | "model_name": "HTMLModel", 1726 | "state": { 1727 | "_dom_classes": [], 1728 | "_model_module": "@jupyter-widgets/controls", 1729 | "_model_module_version": "1.5.0", 1730 | "_model_name": "HTMLModel", 1731 | "_view_count": null, 1732 | "_view_module": "@jupyter-widgets/controls", 1733 | "_view_module_version": "1.5.0", 1734 | "_view_name": "HTMLView", 1735 | "description": "", 1736 | "description_tooltip": null, 1737 | "layout": "IPY_MODEL_f3d35e8d289f4be9870e4c55e695bda8", 1738 | "placeholder": "​", 1739 | "style": "IPY_MODEL_5f7bea12320e41e2a5b16da0987d14fa", 1740 | "value": " 379/379 [00:00<00:00, 1.96kB/s]" 1741 | } 1742 | }, 1743 | "dbd206abbcac414d8730ff7cf4ef55d8": { 1744 | "model_module": "@jupyter-widgets/base", 1745 | "model_name": "LayoutModel", 1746 | "state": { 1747 | "_model_module": "@jupyter-widgets/base", 1748 | "_model_module_version": "1.2.0", 1749 | "_model_name": "LayoutModel", 1750 | "_view_count": null, 1751 | "_view_module": "@jupyter-widgets/base", 1752 | "_view_module_version": "1.2.0", 1753 | "_view_name": "LayoutView", 1754 | "align_content": null, 1755 | "align_items": null, 1756 | "align_self": null, 1757 | "border": null, 1758 | "bottom": null, 1759 | "display": null, 1760 | "flex": null, 1761 | "flex_flow": null, 1762 | "grid_area": null, 1763 | "grid_auto_columns": null, 1764 | "grid_auto_flow": null, 1765 | "grid_auto_rows": null, 1766 | "grid_column": null, 1767 | "grid_gap": null, 1768 | "grid_row": null, 1769 | "grid_template_areas": null, 1770 | "grid_template_columns": null, 1771 | "grid_template_rows": null, 1772 | "height": null, 1773 | "justify_content": null, 1774 | "justify_items": null, 1775 | "left": null, 1776 | "margin": null, 1777 | "max_height": null, 1778 | "max_width": null, 1779 | "min_height": null, 1780 | "min_width": null, 1781 | "object_fit": null, 1782 | "object_position": null, 1783 | "order": null, 1784 | "overflow": null, 1785 | 
"overflow_x": null, 1786 | "overflow_y": null, 1787 | "padding": null, 1788 | "right": null, 1789 | "top": null, 1790 | "visibility": null, 1791 | "width": null 1792 | } 1793 | }, 1794 | "e24e4163978a4f068ec80beab9ad0362": { 1795 | "model_module": "@jupyter-widgets/controls", 1796 | "model_name": "ProgressStyleModel", 1797 | "state": { 1798 | "_model_module": "@jupyter-widgets/controls", 1799 | "_model_module_version": "1.5.0", 1800 | "_model_name": "ProgressStyleModel", 1801 | "_view_count": null, 1802 | "_view_module": "@jupyter-widgets/base", 1803 | "_view_module_version": "1.2.0", 1804 | "_view_name": "StyleView", 1805 | "bar_color": null, 1806 | "description_width": "initial" 1807 | } 1808 | }, 1809 | "ec259b6cb7654b819f6e33be6b99a58e": { 1810 | "model_module": "@jupyter-widgets/controls", 1811 | "model_name": "ProgressStyleModel", 1812 | "state": { 1813 | "_model_module": "@jupyter-widgets/controls", 1814 | "_model_module_version": "1.5.0", 1815 | "_model_name": "ProgressStyleModel", 1816 | "_view_count": null, 1817 | "_view_module": "@jupyter-widgets/base", 1818 | "_view_module_version": "1.2.0", 1819 | "_view_name": "StyleView", 1820 | "bar_color": null, 1821 | "description_width": "initial" 1822 | } 1823 | }, 1824 | "f2c5b053dc2e4a7590710c091ca5aa1e": { 1825 | "model_module": "@jupyter-widgets/base", 1826 | "model_name": "LayoutModel", 1827 | "state": { 1828 | "_model_module": "@jupyter-widgets/base", 1829 | "_model_module_version": "1.2.0", 1830 | "_model_name": "LayoutModel", 1831 | "_view_count": null, 1832 | "_view_module": "@jupyter-widgets/base", 1833 | "_view_module_version": "1.2.0", 1834 | "_view_name": "LayoutView", 1835 | "align_content": null, 1836 | "align_items": null, 1837 | "align_self": null, 1838 | "border": null, 1839 | "bottom": null, 1840 | "display": null, 1841 | "flex": null, 1842 | "flex_flow": null, 1843 | "grid_area": null, 1844 | "grid_auto_columns": null, 1845 | "grid_auto_flow": null, 1846 | "grid_auto_rows": null, 1847 | "grid_column": null, 1848 | "grid_gap": null, 1849 | "grid_row": null, 1850 | "grid_template_areas": null, 1851 | "grid_template_columns": null, 1852 | "grid_template_rows": null, 1853 | "height": null, 1854 | "justify_content": null, 1855 | "justify_items": null, 1856 | "left": null, 1857 | "margin": null, 1858 | "max_height": null, 1859 | "max_width": null, 1860 | "min_height": null, 1861 | "min_width": null, 1862 | "object_fit": null, 1863 | "object_position": null, 1864 | "order": null, 1865 | "overflow": null, 1866 | "overflow_x": null, 1867 | "overflow_y": null, 1868 | "padding": null, 1869 | "right": null, 1870 | "top": null, 1871 | "visibility": null, 1872 | "width": null 1873 | } 1874 | }, 1875 | "f3d35e8d289f4be9870e4c55e695bda8": { 1876 | "model_module": "@jupyter-widgets/base", 1877 | "model_name": "LayoutModel", 1878 | "state": { 1879 | "_model_module": "@jupyter-widgets/base", 1880 | "_model_module_version": "1.2.0", 1881 | "_model_name": "LayoutModel", 1882 | "_view_count": null, 1883 | "_view_module": "@jupyter-widgets/base", 1884 | "_view_module_version": "1.2.0", 1885 | "_view_name": "LayoutView", 1886 | "align_content": null, 1887 | "align_items": null, 1888 | "align_self": null, 1889 | "border": null, 1890 | "bottom": null, 1891 | "display": null, 1892 | "flex": null, 1893 | "flex_flow": null, 1894 | "grid_area": null, 1895 | "grid_auto_columns": null, 1896 | "grid_auto_flow": null, 1897 | "grid_auto_rows": null, 1898 | "grid_column": null, 1899 | "grid_gap": null, 1900 | "grid_row": null, 1901 | 
"grid_template_areas": null, 1902 | "grid_template_columns": null, 1903 | "grid_template_rows": null, 1904 | "height": null, 1905 | "justify_content": null, 1906 | "justify_items": null, 1907 | "left": null, 1908 | "margin": null, 1909 | "max_height": null, 1910 | "max_width": null, 1911 | "min_height": null, 1912 | "min_width": null, 1913 | "object_fit": null, 1914 | "object_position": null, 1915 | "order": null, 1916 | "overflow": null, 1917 | "overflow_x": null, 1918 | "overflow_y": null, 1919 | "padding": null, 1920 | "right": null, 1921 | "top": null, 1922 | "visibility": null, 1923 | "width": null 1924 | } 1925 | }, 1926 | "fabc123e8ab14205a81fee5e8e6484dc": { 1927 | "model_module": "@jupyter-widgets/base", 1928 | "model_name": "LayoutModel", 1929 | "state": { 1930 | "_model_module": "@jupyter-widgets/base", 1931 | "_model_module_version": "1.2.0", 1932 | "_model_name": "LayoutModel", 1933 | "_view_count": null, 1934 | "_view_module": "@jupyter-widgets/base", 1935 | "_view_module_version": "1.2.0", 1936 | "_view_name": "LayoutView", 1937 | "align_content": null, 1938 | "align_items": null, 1939 | "align_self": null, 1940 | "border": null, 1941 | "bottom": null, 1942 | "display": null, 1943 | "flex": null, 1944 | "flex_flow": null, 1945 | "grid_area": null, 1946 | "grid_auto_columns": null, 1947 | "grid_auto_flow": null, 1948 | "grid_auto_rows": null, 1949 | "grid_column": null, 1950 | "grid_gap": null, 1951 | "grid_row": null, 1952 | "grid_template_areas": null, 1953 | "grid_template_columns": null, 1954 | "grid_template_rows": null, 1955 | "height": null, 1956 | "justify_content": null, 1957 | "justify_items": null, 1958 | "left": null, 1959 | "margin": null, 1960 | "max_height": null, 1961 | "max_width": null, 1962 | "min_height": null, 1963 | "min_width": null, 1964 | "object_fit": null, 1965 | "object_position": null, 1966 | "order": null, 1967 | "overflow": null, 1968 | "overflow_x": null, 1969 | "overflow_y": null, 1970 | "padding": null, 1971 | "right": null, 1972 | "top": null, 1973 | "visibility": null, 1974 | "width": null 1975 | } 1976 | }, 1977 | "fb595227f20f4a18aaef097807dede5b": { 1978 | "model_module": "@jupyter-widgets/base", 1979 | "model_name": "LayoutModel", 1980 | "state": { 1981 | "_model_module": "@jupyter-widgets/base", 1982 | "_model_module_version": "1.2.0", 1983 | "_model_name": "LayoutModel", 1984 | "_view_count": null, 1985 | "_view_module": "@jupyter-widgets/base", 1986 | "_view_module_version": "1.2.0", 1987 | "_view_name": "LayoutView", 1988 | "align_content": null, 1989 | "align_items": null, 1990 | "align_self": null, 1991 | "border": null, 1992 | "bottom": null, 1993 | "display": null, 1994 | "flex": null, 1995 | "flex_flow": null, 1996 | "grid_area": null, 1997 | "grid_auto_columns": null, 1998 | "grid_auto_flow": null, 1999 | "grid_auto_rows": null, 2000 | "grid_column": null, 2001 | "grid_gap": null, 2002 | "grid_row": null, 2003 | "grid_template_areas": null, 2004 | "grid_template_columns": null, 2005 | "grid_template_rows": null, 2006 | "height": null, 2007 | "justify_content": null, 2008 | "justify_items": null, 2009 | "left": null, 2010 | "margin": null, 2011 | "max_height": null, 2012 | "max_width": null, 2013 | "min_height": null, 2014 | "min_width": null, 2015 | "object_fit": null, 2016 | "object_position": null, 2017 | "order": null, 2018 | "overflow": null, 2019 | "overflow_x": null, 2020 | "overflow_y": null, 2021 | "padding": null, 2022 | "right": null, 2023 | "top": null, 2024 | "visibility": null, 2025 | "width": null 2026 | } 2027 | 
2027 | }
2028 | }
2029 | }
2030 | },
2031 | "nbformat": 4,
2032 | "nbformat_minor": 0
2033 | }
2034 |
-------------------------------------------------------------------------------- /fine_tune_title_generation.py: --------------------------------------------------------------------------------
1 | import nmatheg as nm
2 | strategy = nm.TrainStrategy(
3 |     datasets = 'ARGEN_title_generation',
4 |     models = 'UBC-NLP/AraT5-base',
5 |     runs = 10,
6 |     lr = 5e-5,
7 |     epochs = 20,
8 |     batch_size = 4,
9 |     max_tokens = 128,
10 | )
11 | output = strategy.start()
-------------------------------------------------------------------------------- /nmatheg/__init__.py: --------------------------------------------------------------------------------
1 | from nmatheg.nmatheg import TrainStrategy
2 | from nmatheg.nmatheg import predict_from_run
-------------------------------------------------------------------------------- /nmatheg/config.ini: --------------------------------------------------------------------------------
1 | [dataset]
2 | dataset_name = ajgt_twitter_ar
3 |
4 | [preprocessing]
5 | segment = False
6 | remove_special_chars = False
7 | remove_english = False
8 | normalize = False
9 | remove_diacritics = False
10 | excluded_chars = []
11 | remove_tatweel = False
12 | remove_html_elements = False
13 | remove_links = False
14 | remove_twitter_meta = False
15 | remove_long_words = False
16 | remove_repeated_chars = False
17 |
18 | [tokenization]
19 | tokenizer_name = WordTokenizer
20 | vocab_size = 10000
21 | max_tokens = 128
22 |
23 | [model]
24 | model_name = bert-base-arabertv01
25 |
26 | [train]
27 | epochs = 10
28 | batch_size = 256
29 | save_dir = .
30 |
-------------------------------------------------------------------------------- /nmatheg/configs.py: --------------------------------------------------------------------------------
1 | import configparser
2 | def create_default_config(batch_size = 64, epochs = 5, lr = 5e-5, runs = 10, max_tokens = 64,
3 |                           max_train_samples = -1, preprocessing = {}, ckpt = 'ckpts'):
4 |     config = configparser.ConfigParser()
5 |
6 |     config['preprocessing'] = {
7 |         'segment' : False,
8 |         'remove_special_chars' : False,
9 |         'remove_english' : False,
10 |         'normalize' : False,
11 |         'remove_diacritics' : False,
12 |         'excluded_chars' : [],
13 |         'remove_tatweel' : False,
14 |         'remove_html_elements' : False,
15 |         'remove_links' : False,
16 |         'remove_twitter_meta' : False,
17 |         'remove_long_words' : False,
18 |         'remove_repeated_chars' : False,
19 |     }
20 |
21 |     for arg in preprocessing:
22 |         config['preprocessing'][arg] = preprocessing[arg]
23 |
24 |     config['tokenization'] = {
25 |         'max_tokens' : max_tokens,
26 |         'tok_save_path': 'ckpts',
27 |         'max_train_samples': max_train_samples
28 |     }
29 |
30 |     config['log'] = {'print_every':10}
31 |
32 |     config['train'] = {
33 |         'save_dir' : ckpt,
34 |         'epochs' : epochs,
35 |         'batch_size' : batch_size,
36 |         'lr': lr,
37 |         'runs': runs
38 |     }
39 |     return config
-------------------------------------------------------------------------------- /nmatheg/dataset.py: --------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import tnkeeh as tn
5 | from datasets import load_dataset, load_from_disk
6 | try:
7 |     import bpe_surgery
8 | except ImportError:
9 |     pass
10 |
11 | import os
12 | from .utils import get_preprocessing_args, get_tokenizer
13 | from transformers import AutoTokenizer
14 | import torch
15 | from .preprocess_ner import aggregate_tokens, tokenize_and_align_labels
16 | from .preprocess_qa import prepare_features
17 | import copy
18 |
19 | def split_dataset(dataset, data_config, seed = 42, max_train_samples = -1):
20 |     split_names = ['train', 'valid', 'test']
21 |
22 |     for i, split_name in enumerate(['train', 'valid', 'test']):
23 |         if split_name in data_config:
24 |             split_names[i] = data_config[split_name]
25 |             dataset[split_name] = dataset[split_names[i]]
26 |
27 |     if max_train_samples < len(dataset['train']) and max_train_samples != -1:
28 |         print(f"truncating train samples from {len(dataset['train'])} to {max_train_samples}")
29 |         dataset['train'] = dataset['train'].select(range(max_train_samples))
30 |
31 |     #create validation split
32 |     if 'valid' not in dataset:
33 |         train_valid_dataset = dataset['train'].train_test_split(test_size=0.1, seed = seed)
34 |         dataset['valid'] = train_valid_dataset.pop('test')
35 |         dataset['train'] = train_valid_dataset['train']
36 |
37 |     #create test split
38 |     if 'test' not in dataset:
39 |         train_valid_dataset = dataset['train'].train_test_split(test_size=0.1, seed = seed)
40 |         dataset['test'] = train_valid_dataset.pop('test')
41 |         dataset['train'] = train_valid_dataset['train']
42 |
43 |     columns = list(dataset.keys())
44 |     for key in columns:
45 |         if key not in ['train', 'valid', 'test']:
46 |             del dataset[key]
47 |     return dataset
48 |
49 |
50 | def clean_dataset(dataset, config, data_config, task = 'cls'):
51 |     if task == 'mt':
52 |         sourceString, targetString = data_config['text'].split(',')
53 |         args = get_preprocessing_args(config)
54 |         cleaner = tn.Tnkeeh(**args)
55 |         dataset = cleaner.clean_hf_dataset(dataset, targetString)
56 |         return dataset
57 |     elif task == 'qa':
58 |         question, context = data_config['text'].split(',')
59 |         args = get_preprocessing_args(config)
60 |         cleaner = tn.Tnkeeh(**args)
61 |         dataset = cleaner.clean_hf_dataset(dataset, question)
62 |         return dataset
63 |     elif task == 'nli':
64 |         premise, hypothesis = data_config['text'].split(',')
65 |         args = get_preprocessing_args(config)
66 |         cleaner = tn.Tnkeeh(**args)
67 |         dataset = cleaner.clean_hf_dataset(dataset, premise)
68 |         dataset = cleaner.clean_hf_dataset(dataset, hypothesis)
69 |         return dataset
70 |     else:
71 |         args = get_preprocessing_args(config)
72 |         cleaner = tn.Tnkeeh(**args)
73 |         dataset = cleaner.clean_hf_dataset(dataset, data_config['text'])
74 |         return dataset
75 |
76 | def write_data_for_train(dataset, text, path, task = 'cls'):
77 |     data = []
78 |     if task == 'cls':
79 |         for sample in dataset:
80 |             data.append(sample[text])
81 |     elif task == 'nli':
82 |         for sample in dataset:
83 |             premise, hypothesis = text.split(",")
84 |             data.append(sample[premise]+" "+sample[hypothesis])
85 |     elif task == 'ner':
86 |         for sample in dataset:
87 |             data.append(' '.join(sample[text]))
88 |     elif task == 'qa':
89 |         for sample in dataset:
90 |             question, context = text.split(",")
91 |             data.append(sample[question]+" "+sample[context])
92 |     elif task == 'mt':
93 |         for sample in dataset:
94 |             source, target = text.split(",")
95 |             data.append(sample[source]+" "+sample[target])
96 |
97 |     open(f'{path}/data.txt', 'w').write(('\n').join(data))
98 |
99 | def get_prev_tokenizer(save_dir, tokenizer_name, vocab_size, dataset_name, model_name):
100 |     prev_vocab_sizes = [int(v) for v in os.listdir(f"{save_dir}/{tokenizer_name}") if int(v) != vocab_size and dataset_name in os.listdir(f"{save_dir}/{tokenizer_name}/{v}")]
101 |
102 |     if len(prev_vocab_sizes) == 0:
103 |         return ""
104 |     else:
105 |         return f"{save_dir}/{tokenizer_name}/{max(prev_vocab_sizes)}/{dataset_name}/{model_name}/tokenizer"
f"{save_dir}/{tokenizer_name}/{max(prev_vocab_sizes)}/{dataset_name}/{model_name}/tokenizer" 106 | 107 | def create_dataset(config, data_config, vocab_size = 300, 108 | model_name = "birnn", tokenizer_name = "bpe", clean = True, mode = "finetune", 109 | tok_save_path = None, data_save_path = None): 110 | 111 | hf_dataset_name = data_config['name'] 112 | dataset_name = hf_dataset_name.split("/")[-1] #in case we have / in the name 113 | max_tokens = int(config['tokenization']['max_tokens']) 114 | max_train_samples = int(config['tokenization']['max_train_samples']) 115 | save_dir = config['train']['save_dir'] 116 | prev_tok_save_path = "" 117 | if mode == "pretrain": 118 | prev_tok_save_path = get_prev_tokenizer(save_dir, tokenizer_name, vocab_size, dataset_name, model_name) 119 | 120 | batch_size = int(config['train']['batch_size']) 121 | task_name = data_config['task'] 122 | 123 | if 'subset' in data_config: 124 | dataset = load_dataset(hf_dataset_name, data_config['subset']) 125 | else: 126 | dataset = load_dataset(hf_dataset_name) 127 | 128 | if task_name != "qa" and clean: 129 | dataset = clean_dataset(dataset, config, data_config, task = task_name) 130 | 131 | dataset = split_dataset(dataset, data_config, max_train_samples=max_train_samples) 132 | examples = copy.deepcopy(dataset) 133 | print(dataset) 134 | if 'birnn' in model_name: 135 | model_type = 'rnn' 136 | else: 137 | model_type = 'transformer' 138 | if task_name == 'cls': 139 | # tokenize data 140 | if 'birnn' not in model_name: 141 | tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, model_max_length = 512) 142 | if not os.path.isfile(f"{data_save_path}/dataset_dict.json"): 143 | dataset = dataset.map(lambda examples:tokenizer(examples[data_config['text']], truncation=True, padding='max_length'), batched=True) 144 | dataset = dataset.map(lambda examples:{'labels': examples[data_config['label']]}, batched=True) 145 | dataset.save_to_disk(data_save_path) 146 | else: 147 | dataset = load_from_disk(data_save_path) 148 | columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'] 149 | else: 150 | tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size) 151 | 152 | if os.path.isfile(f"{tok_save_path}/tok.model"): 153 | print('loading pretrained tokenizer') 154 | tokenizer.load(tok_save_path) 155 | dataset = load_from_disk(data_save_path) 156 | else: 157 | write_data_for_train(dataset['train'], data_config['text'], data_save_path) 158 | if prev_tok_save_path != "": 159 | tokenizer.load(prev_tok_save_path) 160 | else: 161 | print('training tokenizer from scratch') 162 | tokenizer.train(file_path = f'{data_save_path}/data.txt') 163 | tokenizer.save_model(f"{tok_save_path}/m.model") 164 | dataset = dataset.map(lambda examples:{'input_ids': tokenizer.encode_sentences(examples[data_config['text']], out_length= max_tokens)}, batched=True) 165 | dataset = dataset.map(lambda examples:{'labels': examples[data_config['label']]}, batched=True) 166 | dataset.save_to_disk(data_save_path) 167 | columns=['input_ids', 'labels'] 168 | 169 | elif task_name == 'nli': 170 | # tokenize data 171 | premise, hypothesis = data_config['text'].split(",") 172 | if 'birnn' not in model_name: 173 | tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, model_max_length = 512) 174 | def concat(examples): 175 | texts = (examples[premise], examples[hypothesis]) 176 | result = tokenizer(*texts, truncation=True, padding='max_length') 177 | return result 178 | 179 | if not 
179 |             if not os.path.isfile(f"{data_save_path}/dataset_dict.json"):
180 |                 dataset = dataset.map(concat, batched=True)
181 |                 dataset.save_to_disk(data_save_path)
182 |             else:
183 |                 dataset = load_from_disk(data_save_path)
184 |             columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels']
185 |         else:
186 |             tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size)
187 |             if os.path.isfile(f"{tok_save_path}/tok.model"):
188 |                 print('loading pretrained tokenizer')
189 |                 tokenizer.load(tok_save_path)
190 |                 dataset = load_from_disk(data_save_path)
191 |             else:
192 |
193 |                 write_data_for_train(dataset['train'], data_config['text'], data_save_path, task = 'nli')
194 |                 if prev_tok_save_path != "":
195 |                     tokenizer.load(prev_tok_save_path)
196 |                 else:
197 |                     print('training tokenizer from scratch')
198 |                     tokenizer.train(file_path = f"{data_save_path}/data.txt")
199 |                 tokenizer.save(tok_save_path)
200 |
201 |                 def concat(example):
202 |                     example["text"] = example[premise] + ' ' + example[hypothesis]
203 |                     return example
204 |
205 |                 dataset = dataset.map(lambda examples:{'input_ids': tokenizer.encode_sentences(sentences1 = examples[premise], sentences2 = examples[hypothesis], out_length= max_tokens)}, batched=True)
206 |                 dataset = dataset.map(lambda examples:{'labels': examples[data_config['label']]}, batched=True)
207 |                 dataset.save_to_disk(data_save_path)
208 |             columns=['input_ids', 'labels']
209 |
210 |     elif task_name in ['ner', 'pos']:
211 |         dataset = aggregate_tokens(dataset, config, data_config)
212 |         if 'birnn' not in model_name:
213 |             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
214 |             if not os.path.isfile(f"{data_save_path}/dataset_dict.json"):
215 |                 print('aligning the tokens ...')
216 |                 for split in dataset:
217 |                     dataset[split] = dataset[split].map(lambda x: tokenize_and_align_labels(x, tokenizer, data_config, model_type = model_type)
218 |                                                         , batched=True, remove_columns=dataset[split].column_names)
219 |                 dataset.save_to_disk(data_save_path)
220 |             else:
221 |                 dataset = load_from_disk(data_save_path)
222 |             columns=['input_ids', 'attention_mask', 'labels']
223 |         else:
224 |             tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size)
225 |
226 |             if os.path.isfile(f"{tok_save_path}/tok.model"):
227 |                 print('loading pretrained tokenizer')
228 |                 tokenizer.load(tok_save_path)
229 |                 dataset = load_from_disk(data_save_path)
230 |             else:
231 |                 write_data_for_train(dataset['train'], data_config['text'], data_save_path, task = task_name)
232 |                 if prev_tok_save_path != "":
233 |                     tokenizer.load(prev_tok_save_path)
234 |                 else:
235 |                     print('training tokenizer from scratch')
236 |                     tokenizer.train(file_path = f'{data_save_path}/data.txt')
237 |                 tokenizer.save(tok_save_path)
238 |                 print('aligning the tokens ...')
239 |                 for split in dataset:
240 |                     dataset[split] = dataset[split].map(lambda x: tokenize_and_align_labels(x, tokenizer, data_config, model_type = model_type)
241 |                                                         , batched=True, remove_columns=dataset[split].column_names)
242 |                 dataset.save_to_disk(data_save_path)
243 |
244 |             columns=['input_ids', 'labels']
245 |
246 |     elif task_name == 'qa':
247 |         if 'birnn' not in model_name:
248 |             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
249 |             if not os.path.isfile(f"{data_save_path}/dataset_dict.json"):
250 |                 for split in dataset:
251 |                     dataset[split] = dataset[split].map(lambda x: prepare_features(x, tokenizer, data_config, model_type = model_type)
252 |                                                         , batched=True, remove_columns=dataset[split].column_names)
253 |                 dataset.save_to_disk(data_save_path)
254 |             else:
255 |                 dataset = load_from_disk(data_save_path)
256 |             columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions']
257 |         else:
258 |             tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size)
259 |
260 |             if os.path.isfile(f"{tok_save_path}/tok.model"):
261 |                 print('loading pretrained tokenizer')
262 |                 tokenizer.load(tok_save_path)
263 |                 dataset = load_from_disk(data_save_path)
264 |             else:
265 |                 write_data_for_train(dataset['train'], data_config['text'], data_save_path, task = task_name)
266 |                 if prev_tok_save_path != "":
267 |                     tokenizer.load(prev_tok_save_path)
268 |                 else:
269 |                     print('training tokenizer from scratch')
270 |                     tokenizer.train(file_path = f'{data_save_path}/data.txt')
271 |                 tokenizer.save(tok_save_path)
272 |                 for split in dataset:
273 |                     dataset[split] = dataset[split].map(lambda x: prepare_features(x, tokenizer, data_config, model_type = model_type, max_len = max_tokens)
274 |                                                         , batched=True, remove_columns=dataset[split].column_names)
275 |                 dataset.save_to_disk(data_save_path)
276 |             columns=['input_ids', 'start_positions', 'end_positions']
277 |
278 |     elif task_name == 'mt':
279 |         prefix = "translate English to Arabic: "
280 |         src_lang, trg_lang = data_config['text'].split(",")
281 |
282 |         if 'birnn' not in model_name:
283 |
284 |             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
285 |             def preprocess(dataset):
286 |                 inputs = [prefix + ex for ex in dataset[src_lang]]
287 |                 targets = [ex for ex in dataset[trg_lang]]
288 |                 dataset = tokenizer(inputs, max_length=128, truncation=True, padding = 'max_length')
289 |
290 |                 # Setup the tokenizer for targets
291 |                 with tokenizer.as_target_tokenizer():
292 |                     labels = tokenizer(targets, max_length=128, truncation=True, padding = 'max_length')
293 |
294 |                 dataset["labels"] = labels["input_ids"]
295 |                 return dataset
296 |             if not os.path.isfile(f"{data_save_path}/dataset_dict.json"):
297 |                 dataset = dataset.map(preprocess, batched=True)
298 |                 dataset.save_to_disk(data_save_path)
299 |             else:
300 |                 dataset = load_from_disk(data_save_path)
301 |             columns = ['input_ids', 'attention_mask', 'labels']
302 |         else:
303 |             src_tokenizer = get_tokenizer('BPE', vocab_size= 1000)
304 |             trg_tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size)
305 |             src_tok_save_path = f"{save_dir}/{tokenizer_name}/1000/{dataset_name}/{model_name}/tokenizer"
306 |
307 |             if os.path.isfile(f"{tok_save_path}/trg_tok.model"):
308 |                 print('loading pretrained tokenizers')
309 |                 src_tokenizer.load(f"{src_tok_save_path}/", name = "src_tok")
310 |                 trg_tokenizer.load(f"{tok_save_path}/", name = "trg_tok")
311 |                 dataset = load_from_disk(data_save_path)
312 |             else:
313 |                 open(f'{data_save_path}/src_data.txt', 'w').write('\n'.join(dataset['train'][src_lang]))
314 |                 open(f'{data_save_path}/trg_data.txt', 'w').write('\n'.join(dataset['train'][trg_lang]))
315 |
316 |                 if not os.path.isfile(f"{src_tok_save_path}/src_tok.model"):
317 |                     src_tokenizer.train(file_path = f'{data_save_path}/src_data.txt')
318 |                     src_tokenizer.save(f"{tok_save_path}/", name = 'src_tok')
319 |
320 |                 if prev_tok_save_path != "":
321 |                     trg_tokenizer.load(prev_tok_save_path) # was tokenizer.load, but no `tokenizer` exists yet on this branch
322 |                 else:
323 |                     print('training tokenizer from scratch')
324 |
325 |                     trg_tokenizer.train(file_path = f'{data_save_path}/trg_data.txt')
326 |                 trg_tokenizer.save(f"{tok_save_path}/", name = 'trg_tok')
327 |
328 |                 def preprocess(dataset):
329 |                     inputs = [ex for ex in dataset[src_lang]]
330 |                     targets = [ex for ex in dataset[trg_lang]]
331 |
332 |                     input_ids = src_tokenizer.encode_sentences(inputs, out_length = max_tokens, add_boundry = True)
333 |                     labels = trg_tokenizer.encode_sentences(targets, out_length = max_tokens, add_boundry = True)
334 |                     dataset = dataset.add_column("input_ids", input_ids)
335 |                     dataset = dataset.add_column("labels", labels)
336 |                     return dataset
337 |
338 |                 for split in dataset:
339 |                     dataset[split] = preprocess(dataset[split])
340 |
341 |                 dataset.save_to_disk(data_save_path)
342 |
343 |             columns = ['input_ids', 'labels']
344 |             tokenizer = trg_tokenizer
345 |
346 |     elif task_name == 'sum':
347 |         prefix = ""
348 |         text, summary = data_config['text'].split(",")
349 |
350 |         if 'birnn' not in model_name:
351 |
352 |             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
353 |             def preprocess(dataset):
354 |                 inputs = [prefix + ex for ex in dataset[text]]
355 |                 targets = [ex for ex in dataset[summary]]
356 |                 dataset = tokenizer(inputs, max_length=128, truncation=True, padding = 'max_length')
357 |
358 |                 # Setup the tokenizer for targets
359 |                 with tokenizer.as_target_tokenizer():
360 |                     labels = tokenizer(targets, max_length=128, truncation=True, padding = 'max_length')
361 |                 labels["input_ids"] = [
362 |                     [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]]
363 |                 dataset["labels"] = labels["input_ids"]
364 |                 return dataset
365 |             if not os.path.isfile(f"{data_save_path}/dataset_dict.json"):
366 |                 dataset = dataset.map(preprocess, batched=True)
367 |                 dataset.save_to_disk(data_save_path)
368 |             else:
369 |                 dataset = load_from_disk(data_save_path)
370 |             columns = ['input_ids', 'attention_mask', 'labels']
371 |         else:
372 |             src_tokenizer = get_tokenizer('BPE', vocab_size= 1000)
373 |             trg_tokenizer = get_tokenizer(tokenizer_name, vocab_size= vocab_size)
374 |             src_tok_save_path = f"{save_dir}/{tokenizer_name}/1000/{dataset_name}/{model_name}/tokenizer"
375 |
376 |             if os.path.isfile(f"{tok_save_path}/trg_tok.model"):
377 |                 print('loading pretrained tokenizers')
378 |                 src_tokenizer.load(f"{src_tok_save_path}/", name = "src_tok")
379 |                 trg_tokenizer.load(f"{tok_save_path}/", name = "trg_tok")
380 |                 dataset = load_from_disk(data_save_path)
381 |             else:
382 |                 open(f'{data_save_path}/src_data.txt', 'w').write('\n'.join(dataset['train'][text]))
383 |                 open(f'{data_save_path}/trg_data.txt', 'w').write('\n'.join(dataset['train'][summary]))
384 |
385 |                 if not os.path.isfile(f"{src_tok_save_path}/src_tok.model"):
386 |                     src_tokenizer.train(file_path = f'{data_save_path}/src_data.txt')
387 |                     src_tokenizer.save(f"{tok_save_path}/", name = 'src_tok')
388 |
389 |                 if prev_tok_save_path != "":
390 |                     trg_tokenizer.load(prev_tok_save_path) # was tokenizer.load, but no `tokenizer` exists yet on this branch
391 |                 else:
392 |                     print('training tokenizer from scratch')
393 |
394 |                     trg_tokenizer.train(file_path = f'{data_save_path}/trg_data.txt')
395 |                 trg_tokenizer.save(f"{tok_save_path}/", name = 'trg_tok')
396 |
397 |                 def preprocess(dataset):
398 |                     inputs = [ex for ex in dataset[text]]
399 |                     targets = [ex for ex in dataset[summary]]
400 |
401 |                     input_ids = src_tokenizer.encode_sentences(inputs, out_length = max_tokens, add_boundry = True)
402 |                     labels = trg_tokenizer.encode_sentences(targets, out_length = max_tokens, add_boundry = True)
403 |                     dataset = dataset.add_column("input_ids", input_ids)
404 |                     dataset = dataset.add_column("labels", labels)
405 |                     return dataset
406 |
407 |                 for split in dataset:
408 |                     dataset[split] = preprocess(dataset[split])
409 |
410 |                 dataset.save_to_disk(data_save_path)
411 |
412 |             columns = ['input_ids', 'labels']
413 |             tokenizer = trg_tokenizer
414 |     #create loaders
415 |     if task_name != 'qa':
416 |         for split in dataset:
417 |             dataset[split].set_format(type='torch', columns=columns)
418 |             dataset[split] = torch.utils.data.DataLoader(dataset[split], batch_size=batch_size, shuffle = True)
419 |
420 |     return tokenizer, [dataset['train'], dataset['valid'], dataset['test']], [examples['train'], examples['valid'], examples['test']]
421 |
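To make the flow of `create_dataset` concrete, here is a minimal sketch of driving it by hand; in normal use `TrainStrategy` (nmatheg.py) builds the config and paths for you, so the hyperparameters, tokenizer name, and directories below are illustrative assumptions only.

```python
# Sketch only: call create_dataset() directly. TrainStrategy normally does this;
# every path and hyperparameter here is an assumption, and the save directories
# are assumed to already exist.
from nmatheg.configs import create_default_config
from nmatheg.dataset import create_dataset

config = create_default_config(batch_size=32, epochs=5, max_tokens=128)
data_config = {                      # mirrors the [ajgt_twitter_ar] section below
    'name': 'ajgt_twitter_ar',
    'task': 'cls',
    'text': 'text',
    'label': 'label',
    'train': 'train',
}
tokenizer, (train_dl, valid_dl, test_dl), examples = create_dataset(
    config, data_config,
    vocab_size=10000,
    model_name='birnn',              # rnn path: trains a bpe_surgery tokenizer
    tokenizer_name='bpe',            # assumed name understood by get_tokenizer()
    tok_save_path='ckpts/tok',
    data_save_path='ckpts/data',
)
```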
-------------------------------------------------------------------------------- /nmatheg/datasets.ini: --------------------------------------------------------------------------------
1 | [ajgt_twitter_ar]
2 | name = ajgt_twitter_ar
3 | text = text
4 | label = label
5 | num_labels = 2
6 | train = train
7 | task = cls
8 | labels = Negative,Positive
9 |
10 | [off-eval-ar]
11 | name = Zaid/off-eval-ar
12 | text = tweet
13 | label = label
14 | num_labels = 2
15 | train = train
16 | test = test
17 | task = cls
18 | labels = NOT,OFF
19 |
20 | [off-eval-en]
21 | name = Zaid/off-eval-en
22 | text = tweet
23 | label = label
24 | num_labels = 2
25 | train = train
26 | test = test
27 | task = cls
28 | labels = NOT,OFF
29 |
30 | [metrec]
31 | name = metrec
32 | text = text
33 | label = label
34 | num_labels = 14
35 | train = train
36 | test = test
37 | task = cls
38 | labels = saree,kamel,mutakareb,mutadarak,munsareh,madeed,mujtath,ramal,baseet,khafeef,taweel,wafer,hazaj,rajaz
39 |
40 | [labr]
41 | name = labr
42 | text = text
43 | label = label
44 | num_labels = 5
45 | train = train
46 | test = test
47 | task = cls
48 | labels = 1,2,3,4,5
49 |
50 | [ar_res_reviews]
51 | name = ar_res_reviews
52 | text = text
53 | label = polarity
54 | num_labels = 2
55 | split = train
56 | task = cls
57 | labels = negative,positive
58 |
59 | [arsentd_lev]
60 | name = arsentd_lev
61 | text = Tweet
62 | label = Sentiment
63 | num_labels = 5
64 | train = train
65 | task = cls
66 | labels = negative,neutral,positive,very_negative,very_positive
67 |
68 | [oclar]
69 | name = oclar
70 | text = review
71 | label = rating
72 | num_labels = 5
73 | train = train
74 | task = cls
75 | labels = 1,2,3,4,5
76 |
77 | [emotone_ar]
78 | name = emotone_ar
79 | text = tweet
80 | label = label
81 | num_labels = 8
82 | train = train
83 | task = cls
84 | labels = none,anger,joy,sadness,love,sympathy,surprise,fear
85 |
86 | [hard]
87 | name = hard
88 | text = text
89 | label = label
90 | num_labels = 5
91 | train = train
92 | task = cls
93 | labels = 1,2,3,4,5
94 |
95 | [ar_sarcasm]
96 | name = ar_sarcasm
97 | text = tweet
98 | label = sarcasm
99 | num_labels = 2
100 | train = train
101 | test = test
102 | task = cls
103 | labels = non-sarcastic,sarcastic
104 |
105 | [caner]
106 | name = caner
107 | subset = dummy
108 | text = token
109 | label = ner_tag
110 | num_labels = 21
111 | train = train
112 | task = ner
113 | labels = Allah,Book,Clan,Crime,Date,Day,Hell,Loc,Meas,Mon,Month,NatOb,Number,O,Org,Para,Pers,Prophet,Rlig,Sect,Time
114 |
115 | [arcd]
116 | name = arcd
117 | text = question,context
118 | label = answer
119 | num_labels = 2
120 | train = train
121 | test = validation
122 | task = qa
123 | labels = start_logits,end_logits
124 |
125 | [mlqa]
126 | name = mlqa
127 | subset = mlqa-translate-train.ar
128 | text = question,context
129 | label = answer
130 | num_labels = 2
131 | train = train
132 | test = validation
133 | task = qa
134 | labels = start_logits,end_logits
135 |
136 | [tatoeba_mt]
137 | name = Helsinki-NLP/tatoeba_mt
138 | subset = ara-eng
139 | text = targetString,sourceString
140 | num_labels = 0
141 | train = validation
142 | test = test
143 | task = mt
144 | labels = english,arabic
145 |
146 |
147 | [xnli]
148 | name = xnli
149 | subset = ar
150 | text = premise,hypothesis
151 | label = label
152 | num_labels = 3
153 | train = train
154 | valid = validation
155 | test = test
156 | task = nli
157 | labels = entailment,neutral,contradiction
158 |
159 | [xlsum]
160 | name = csebuetnlp/xlsum
161 | subset = arabic
162 | text = text,summary
163 | num_labels = 0
164 | train = train
165 | valid = validation
166 | test = test
167 | task = sum
168 | labels = text,summary
169 |
170 | [ARGEN_title_generation]
171 | name = arbml/ARGEN_title_generation
172 | text = document,title
173 | num_labels = 0
174 | train = train
175 | test = validation
176 | task = sum
177 | labels = document,title
178 |
179 | [wiki_lingua_ar]
180 | name = arbml/wiki_lingua_ar
181 | text = article,summary
182 | num_labels = 0
183 | train = train
184 | valid = validation
185 | test = test
186 | task = sum
187 | labels = article,summary
188 |
189 | [arabic_pos_dialect]
190 | name = arbml/arabic_pos_dialect
191 | subset = all
192 | text = words
193 | label = pos_tags
194 | num_labels = 22
195 | train = train
196 | task = pos
197 | labels = ADJ,ADV,CASE,CONJ,DET,EMOT,EOS,FOREIGN,FUT_PART,HASH,MENTION,NEG_PART,NOUN,NSUFF,NUM,PART,PREP,PROG_PART,PRON,PUNC,URL,V
198 |
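Each datasets.ini section above is plain configparser syntax, and a section is read into the dict-like data_config mapping that create_dataset consumes. A small sketch of inspecting one section (the file path is an assumption about running from the repo root):

```python
# Sketch only: load one dataset section the way nmatheg would consume it.
import configparser

ini = configparser.ConfigParser()
ini.read('nmatheg/datasets.ini')          # assumed relative path

data_config = ini['ajgt_twitter_ar']      # dict-like view of one section
print(data_config['name'])                # -> ajgt_twitter_ar
print(data_config['task'])                # -> cls
print(data_config['labels'].split(','))   # -> ['Negative', 'Positive']
```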
-------------------------------------------------------------------------------- /nmatheg/models.py: --------------------------------------------------------------------------------
1 | from transformers import (
2 |     AutoModelForSequenceClassification,
3 |     AutoConfig,
4 |     AutoModelForTokenClassification,
5 |     AutoModelForQuestionAnswering,
6 |     AutoModelForSeq2SeqLM,
7 |     get_linear_schedule_with_warmup)
8 | from evaluate import load
9 | import random
10 | import torch.nn.functional as F
11 | import os
12 | import time
13 | import numpy as np
14 | from tqdm.auto import tqdm
15 | import torch
16 | from torch.optim import AdamW
17 | import torch.nn as nn
18 | from accelerate import Accelerator
19 | from datasets import load_metric
20 | import copy
21 | from .ner_utils import get_labels
22 | from .qa_utils import evaluate_metric
23 | from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
24 | import nltk
25 | nltk.download('punkt')
26 |
27 | class BiRNN(nn.Module):
28 |     def __init__(self, vocab_size, num_labels, hidden_dim = 128):
29 |
30 |         super().__init__()
31 |
32 |         self.embedding = nn.Embedding(vocab_size, hidden_dim)
33 |         self.bigru1 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
34 |         self.bigru2 = nn.GRU(2*hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
35 |         self.bigru3 = nn.GRU(2*hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
36 |         self.fc = nn.Linear(2*hidden_dim, num_labels)
37 |         self.hidden_dim = hidden_dim
38 |         self.num_labels = num_labels
39 |
40 |     def forward(self,
41 |                 input_ids,
42 |                 labels = None):
43 |         embedded = self.embedding(input_ids)
44 |         out,h = self.bigru1(embedded)
45 |         out,h = self.bigru2(out)
46 |         out,h = self.bigru3(out)
47 |         logits = self.fc(out[:,0,:])
48 |         if labels is not None:
49 |             loss = self.compute_loss(logits, labels)
50 |             return {'loss':loss,
51 |                     'logits':logits}
52 |         return {'logits': logits}
53 |
54 |     def compute_loss(self, logits, labels):
55 |         loss_fct = nn.CrossEntropyLoss()
56 |         loss = loss_fct(logits, labels)
57 |         return loss
58 |
59 | class BaseTextClassficationModel:
60 |     def __init__(self, config):
61 |         self.model = nn.Module()
62 |         self.num_labels = config['num_labels']
63 |         self.model_name = config['model_name']
64 |         self.vocab_size = config['vocab_size']
65 |
66 |         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
67 |
68 |     def train(self, datasets, examples, **kwargs):
69 |         save_dir = kwargs['save_dir']
70 |         epochs = kwargs['epochs']
71 |         lr = kwargs['lr']
72 |
73 |         train_dataset, valid_dataset, test_dataset = datasets
74 |
75 |         self.optimizer = AdamW(self.model.parameters(), lr = lr)
76 |         filepath = os.path.join(save_dir, 'pytorch_model.bin')
77 |         best_accuracy = 0
78 |         pbar = tqdm(total=epochs * len(train_dataset), leave=True)
79 |         for epoch in range(epochs):
80 |             accuracy = 0
81 |             total_loss = 0
82 |             self.model.train().to(self.device)
83 |             for _, batch in enumerate(train_dataset):
84 |                 batch = {k: v.to(self.device) for k, v in batch.items()}
85 |                 self.optimizer.zero_grad()
86 |                 outputs = self.model(**batch)
87 |                 loss = outputs['loss']
88 |                 loss.backward()
89 |                 self.optimizer.step()
90 |                 labels = batch['labels'].cpu()
91 |                 preds = outputs['logits'].argmax(-1).cpu()
92 |                 accuracy += accuracy_score(labels, preds) / len(train_dataset)
93 |                 total_loss += loss.item() / len(train_dataset) # accumulate separately; `loss` itself is reassigned every batch
94 |                 batch = None
95 |                 pbar.update(1)
96 |             print(f"Epoch {epoch} Train Loss {total_loss:.4f} Train Accuracy {accuracy:.4f}")
97 |
98 |             self.model.eval().to(self.device)
99 |             results = self.evaluate_dataset(valid_dataset)
100 |             print(f"Epoch {epoch} Valid Loss {results['loss']:.4f} Valid Accuracy {results['accuracy']:.4f}")
101 |
102 |             val_accuracy = results['accuracy']
103 |             if val_accuracy >= best_accuracy:
104 |                 best_accuracy = val_accuracy
105 |                 torch.save(self.model.state_dict(), filepath)
106 |
107 |         # restore the best checkpoint before testing
108 |
109 |         self.model.load_state_dict(torch.load(filepath))
110 |         self.model.eval()
111 |         test_metrics = self.evaluate_dataset(test_dataset)
112 |         print(f"Test Loss {test_metrics['loss']:.4f} Test Accuracy {test_metrics['accuracy']:.4f}")
113 |         return test_metrics
114 |
115 |     def evaluate_dataset(self, dataset, desc = "Eval"):
116 |         accuracy = 0
117 |         total_loss = 0
118 |         pbar = tqdm(total=len(dataset), position=0, leave=False, desc=desc)
119 |         refs = []
120 |         preds = []
121 |         with torch.no_grad():
122 |             for _, batch in enumerate(dataset):
123 |                 batch = {k: v.to(self.device) for k, v in batch.items()}
124 |                 outputs = self.model(**batch)
125 |                 loss = outputs['loss']
126 |                 refs += batch['labels'].cpu()
127 |                 preds += outputs['logits'].argmax(-1).cpu()
128 |                 total_loss += loss / len(dataset)
129 |                 batch = None
130 |                 pbar.update(1)
131 |         return {
132 |             "loss":float(total_loss.cpu().detach().numpy()),
133 |             "precision": precision_score(refs, preds, average = "macro"),
134 |             "recall": recall_score(refs, preds, average = "macro"),
135 |             "f1": f1_score(refs, preds, average = "macro"),
136 |             "accuracy": accuracy_score(refs, preds),
137 |         }
138 |
139 | class SimpleClassificationModel(BaseTextClassficationModel):
140 |     def __init__(self, config):
141 |         BaseTextClassficationModel.__init__(self, config)
142 |         self.model = BiRNN(self.vocab_size, self.num_labels)
143 |         self.model.to(self.device)
144 |         # self.optimizer = AdamW(self.model.parameters(), lr = 5e-5)
145 |
146 |     def wipe_memory(self):
147 |         self.model = None
148 |         self.optimizer = None
149 |         torch.cuda.empty_cache()
150 |
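For intuition about the BiRNN baseline defined above, a standalone shape check (the batch and vocabulary sizes are arbitrary assumptions):

```python
# Sketch only: forward pass through the BiRNN classifier defined above.
import torch

model = BiRNN(vocab_size=10000, num_labels=2)
input_ids = torch.randint(0, 10000, (4, 128))   # (batch, seq_len)
labels = torch.tensor([0, 1, 0, 1])

out = model(input_ids, labels=labels)
print(out['logits'].shape)   # torch.Size([4, 2]): one score per class
print(out['loss'].item())    # scalar cross-entropy loss
```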
151 | class BERTTextClassificationModel(BaseTextClassficationModel):
152 |     def __init__(self, config):
153 |         BaseTextClassficationModel.__init__(self, config)
154 |         config = AutoConfig.from_pretrained(self.model_name,num_labels=self.num_labels)
155 |         self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, config = config)
156 |
157 |     def wipe_memory(self):
158 |         self.model = None
159 |         self.optimizer = None
160 |         torch.cuda.empty_cache()
161 |
162 | class BiRNNForTokenClassification(nn.Module):
163 |     def __init__(self, vocab_size, num_labels, hidden_dim = 128):
164 |
165 |         super().__init__()
166 |
167 |         self.embedding = nn.Embedding(vocab_size, hidden_dim)
168 |         self.bigru1 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
169 |         self.bigru2 = nn.GRU(2*hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
170 |         self.bigru3 = nn.GRU(2*hidden_dim, hidden_dim//2, bidirectional=True, batch_first = True)
171 |         self.fc = nn.Linear(hidden_dim, num_labels)
172 |         self.hidden_dim = hidden_dim
173 |         self.num_labels = num_labels
174 |
175 |     def forward(self,
176 |                 input_ids,
177 |                 labels = None):
178 |
179 |         embedded = self.embedding(input_ids)
180 |         out,h = self.bigru1(embedded)
181 |         out,h = self.bigru2(out)
182 |         out,h = self.bigru3(out)
183 |         logits = self.fc(out)
184 |         if labels is not None:
185 |             loss = self.compute_loss(logits, labels)
186 |             return {'loss':loss,
187 |                     'logits':logits}
188 |         else:
189 |             return {'logits':logits}
190 |
191 |     def compute_loss(self, logits, labels):
192 |         loss_fct = nn.CrossEntropyLoss()
193 |         loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
194 |         return loss
195 |
196 | class BaseTokenClassficationModel:
197 |     def __init__(self, config):
198 |         self.model = nn.Module()
199 |         self.num_labels = config['num_labels']
200 |         self.model_name = config['model_name']
201 |         self.vocab_size = config['vocab_size']
202 |         self.labels = config['labels']
203 |         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
204 |         self.metric = load_metric("seqeval")
205 |         self.accelerator = Accelerator()
206 |
207 |     def train(self, datasets, examples, **kwargs):
208 |         save_dir = kwargs['save_dir']
209 |         epochs = kwargs['epochs']
210 |         lr = kwargs['lr']
211 |         self.optimizer = AdamW(self.model.parameters(), lr = lr)
212 |
213 |         train_dataset, valid_dataset, test_dataset = datasets
214 |         filepath = os.path.join(save_dir, 'pytorch_model.bin')
215 |         best_accuracy = 0
216 |         pbar = tqdm(total=epochs * len(train_dataset), leave=True)
217 |         for epoch in range(epochs):
218 |             accuracy = 0
219 |             loss = 0
220 |             self.model.train().to(self.device)
221 |             predictions, true_labels = [], []
222 |             for _, batch in enumerate(train_dataset):
223 |                 batch = {k: v.to(self.device) for k, v in batch.items()}
224 |                 outputs = self.model(**batch)
225 |                 loss = outputs['loss']
226 |                 loss.backward()
227 |                 self.optimizer.step()
228 |                 self.optimizer.zero_grad()
229 |
230 |                 batch = None
231 |                 pbar.update(1)
232 |
233 |             train_metrics = self.evaluate_dataset(train_dataset)
234 |             print(f"Epoch {epoch} Train Loss {train_metrics['loss']:.4f} Train F1 {train_metrics['f1']:.4f}")
235 |
236 |             self.model.eval().to(self.device)
237 |             valid_metrics = self.evaluate_dataset(valid_dataset)
238 |             print(f"Epoch {epoch} Valid Loss {valid_metrics['loss']:.4f} Valid F1 {valid_metrics['f1']:.4f}")
239 |
240 |             val_accuracy = valid_metrics['f1']
241 |             if val_accuracy >= best_accuracy:
242 |                 best_accuracy = val_accuracy
243 |                 torch.save(self.model.state_dict(), filepath)
244 |
245 |         self.model.load_state_dict(torch.load(filepath))
246 |         self.model.eval()
247 |         test_metrics = self.evaluate_dataset(test_dataset)
248 |         print(f"Test Loss {test_metrics['loss']:.4f} Test F1 {test_metrics['f1']:.4f}")
249 |         return {
250 |             "precision": test_metrics["precision"],
251 |             "recall": test_metrics["recall"],
252 |             "f1": test_metrics["f1"],
253 |             "accuracy": test_metrics["accuracy"],
254 |         }
255 |
256 |     def evaluate_dataset(self, dataset, desc = "Eval"):
257 |         preds = []
258 |         refs = []
259 |
260 |         total_loss = 0
261 |         pbar = tqdm(total=len(dataset), position=0, leave=False, desc=desc)
262 |         for _, batch in enumerate(dataset):
263 |             batch = {k: v.to(self.device) for k, v in batch.items()}
264 |             outputs = self.model(**batch)
265 |             loss = outputs['loss']
266 |             labels = batch['labels']
267 |             predictions = outputs['logits'].argmax(dim=-1)
268 |
269 |             predictions_gathered = self.accelerator.gather(predictions)
270 |             labels_gathered = self.accelerator.gather(labels)
271 |             pred, ref = get_labels(predictions_gathered, labels_gathered, self.labels)
272 |             ref = [item for sublist in ref for item in sublist]
273 |             pred = [item for sublist in pred for item in sublist]
274 |             preds.append(pred)
275 |             refs.append(ref)
276 |
277 |             total_loss += loss / len(dataset)
278 |             batch = None
279 |             pbar.update(1)
280 |
281 |         refs = [item for sublist in refs for item in sublist]
282 |         preds = [item for sublist in preds for item in sublist]
283 |
284 |         return {
285 |             "loss":float(total_loss.cpu().detach().numpy()),
286 |             "precision": precision_score(refs, preds, average = "micro"),
287 |             "recall": recall_score(refs, preds, average = "micro"),
288 |             "f1": f1_score(refs, preds, average = "micro"),
289 |             "accuracy": accuracy_score(refs, preds),
290 |         }
291 |
292 | class SimpleTokenClassificationModel(BaseTokenClassficationModel):
293 |     def __init__(self, config):
294 |         BaseTokenClassficationModel.__init__(self, config)
295 |         self.model = BiRNNForTokenClassification(self.vocab_size, self.num_labels)
296 |         self.model.to(self.device)
297 |         # self.optimizer = AdamW(self.model.parameters(), lr = 5e-5)
298 |
299 |     def wipe_memory(self):
300 |         self.model = None
301 |         self.optimizer = None
302 |         torch.cuda.empty_cache()
303 |
304 | class BERTTokenClassificationModel(BaseTokenClassficationModel):
305 |     def __init__(self, config):
306 |         BaseTokenClassficationModel.__init__(self, config)
307 |         config = AutoConfig.from_pretrained(self.model_name,num_labels=self.num_labels)
308 |         self.model = AutoModelForTokenClassification.from_pretrained(self.model_name, config = config)
309 |
310 |     def wipe_memory(self):
311 |         self.model = None
312 |         self.optimizer = None
313 |         torch.cuda.empty_cache()
314 |
315 | class BaseQuestionAnsweringModel:
316 |     def __init__(self, config):
317 |         self.model = nn.Module()
318 |         self.model_name = config['model_name']
319 |         self.vocab_size = config['vocab_size']
320 |         self.num_labels = config['num_labels']
321 |         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
322 |         self.accelerator = Accelerator()
323 |         if 'bert' in self.model_name:
324 |             self.columns = ['input_ids', 'attention_mask', 'start_positions', 'end_positions']
325 |         else:
326 |             self.columns = ['input_ids', 'start_positions', 'end_positions']
327 |
328 |
329 |     def train(self, datasets, examples, **kwargs):
330 |         save_dir = kwargs['save_dir']
331 |         epochs = kwargs['epochs']
332 |         lr = kwargs['lr']
333 |         batch_size = kwargs['batch_size']
334 |         self.optimizer = AdamW(self.model.parameters(), lr = lr)
335 |
336 |         train_dataset, valid_dataset, test_dataset = datasets
337 |         train_examples, valid_examples, test_examples = examples
338 |         train_loader = copy.deepcopy(train_dataset)
339 |         train_loader.set_format(type='torch', columns=self.columns)
340 |         train_loader = torch.utils.data.DataLoader(train_loader, batch_size=batch_size, shuffle = True)
341 |         filepath = os.path.join(save_dir, 'pytorch_model.bin')
342 |         best_accuracy = 0
343 |         pbar = tqdm(total=epochs * len(train_loader), leave=True) # count batches, not raw examples
344 |
345 |         for epoch in range(epochs):
346 |             accuracy = 0
347 |             total_loss = 0
348 |             self.model.train().to(self.device)
349 |
350 |
351 |             for _, batch in enumerate(train_loader):
352 |                 batch = {k: v.to(self.device) for k, v in batch.items()}
353 |                 val = batch['input_ids']
354 |                 val[val==-100] = 0
355 |                 outputs = self.model(**batch)
356 |                 loss = outputs['loss']
357 |                 # start/end logits are only needed at evaluation time, so they are not collected here
358 |
359 |
360 |
361 |
362 |
363 |                 loss.backward()
364 |                 self.optimizer.step()
365 |                 self.optimizer.zero_grad()
366 |
367 |                 total_loss += loss.item() / len(train_loader)
368 |                 batch = None
369 |                 pbar.update(1)
370 |
371 |             train_metrics = self.evaluate_dataset(train_dataset, train_examples, batch_size=batch_size)
372 |             print(f"Epoch {epoch} Train Loss {total_loss:.4f} Train F1 {train_metrics['f1']:.4f}")
373 |
374 |             self.model.eval().to(self.device)
375 |             valid_metrics = self.evaluate_dataset(valid_dataset, valid_examples, batch_size=batch_size)
376 |             print(f"Epoch {epoch} Valid Loss {valid_metrics['loss']:.4f} Valid F1 {valid_metrics['f1']:.4f}")
377 |
378 |             val_accuracy = valid_metrics['f1']
379 |             if val_accuracy >= best_accuracy:
380 |                 best_accuracy = val_accuracy
381 |                 torch.save(self.model.state_dict(), filepath)
382 |
383 |         self.model.load_state_dict(torch.load(filepath))
384 |         self.model.eval()
385 |         test_metrics = self.evaluate_dataset(test_dataset, test_examples, batch_size=batch_size)
386 |         print(f"Test Loss {test_metrics['loss']:.4f} Test F1 {test_metrics['f1']:.4f}")
387 |         return {'f1':test_metrics['f1'], 'Exact Match':test_metrics['exact_match']}
388 |
389 |     def evaluate_dataset(self, dataset, examples, batch_size = 8, desc = "Eval"):
390 |         total_loss = 0
391 |         all_start_logits = []
392 |         all_end_logits = []
393 |         data_loader = copy.deepcopy(dataset)
394 |         data_loader.set_format(type='torch', columns=self.columns)
395 |         data_loader = torch.utils.data.DataLoader(data_loader, batch_size=batch_size)
396 |         pbar = tqdm(total=len(data_loader), position=0, leave=False, desc=desc)
397 |         for _, batch in enumerate(data_loader):
398 |             batch = {k: v.to(self.device) for k, v in batch.items()}
399 |             val = batch['input_ids']
400 |             val[val==-100] = 0
401 |             outputs = self.model(**batch)
402 |             loss = outputs['loss']
403 |             start_logits = outputs['start_logits']
404 |             end_logits = outputs['end_logits']
405 |
406 |             all_start_logits.append(self.accelerator.gather(start_logits).detach().cpu().numpy())
407 |             all_end_logits.append(self.accelerator.gather(end_logits).detach().cpu().numpy())
408 |
409 |             total_loss += loss / len(data_loader)
410 |             batch = None
411 |             pbar.update(1)
412 |         metric = evaluate_metric(dataset, examples, all_start_logits, all_end_logits)
413 |         return {'loss':total_loss, 'f1':metric['f1']/100, 'exact_match':metric['exact_match']/100}
414 |
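The QA models here emit per-token start/end logits; the real decoding lives in qa_utils.evaluate_metric with n-best lists and offset mapping, but a deliberately simplified greedy decoder conveys the idea:

```python
# Sketch only: greedy span decoding from start/end logits. The actual pipeline
# uses evaluate_metric() from qa_utils; this drops n-best and offset handling.
import torch

def greedy_span(start_logits, end_logits, max_answer_len=30):
    start = int(start_logits.argmax())
    # constrain the end index to a window at or after the start index
    end = start + int(end_logits[start:start + max_answer_len].argmax())
    return start, end

start_logits = torch.randn(384)   # one example of 384 tokens (assumed length)
end_logits = torch.randn(384)
print(greedy_span(start_logits, end_logits))
```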
415 | class BERTQuestionAnsweringModel(BaseQuestionAnsweringModel):
416 |     def __init__(self, config):
417 |         BaseQuestionAnsweringModel.__init__(self, config)
418 |         config = AutoConfig.from_pretrained(self.model_name)
419 |         self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name, config = config)
420 |         self.model.to(self.device)
421 |
422 |     def wipe_memory(self):
423 |         self.model = None
424 |         self.optimizer = None
425 |         torch.cuda.empty_cache()
426 |
427 | class BiRNNForQuestionAnswering(nn.Module):
428 |     def __init__(self, vocab_size, num_labels = 2, hidden_dim = 128):
429 |
430 |         super().__init__()
431 |
432 |         self.embedding = nn.Embedding(vocab_size, hidden_dim)
433 |         self.bigru1 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
434 |         self.bigru2 = nn.GRU(2*hidden_dim, hidden_dim, bidirectional=True, batch_first = True)
435 |         self.bigru3 = nn.GRU(2*hidden_dim, hidden_dim//2, bidirectional=True, batch_first = True)
436 |         self.qa_outputs = nn.Linear(hidden_dim, num_labels)
437 |         self.hidden_dim = hidden_dim
438 |         self.num_labels = num_labels
439 |
440 |     def forward(self,
441 |                 input_ids,
442 |                 start_positions = None,
443 |                 end_positions = None):
444 |
445 |         embedded = self.embedding(input_ids)
446 |         out,h = self.bigru1(embedded)
447 |         out,h = self.bigru2(out)
448 |         out,h = self.bigru3(out)
449 |         logits = self.qa_outputs(out) # (bs, max_query_len, 2)
450 |         start_logits, end_logits = logits.split(1, dim=-1)
451 |         start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len)
452 |         end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len)
453 |         if start_positions is not None:
454 |             loss = self.compute_loss(start_logits, end_logits, start_positions, end_positions)
455 |             return {'loss':loss,
456 |                     'logits':logits,
457 |                     'start_logits':start_logits,
458 |                     'end_logits':end_logits}
459 |         else:
460 |             return {'logits':logits,
461 |                     'start_logits':start_logits,
462 |                     'end_logits':end_logits}
463 |
464 |     def compute_loss(self, start_logits, end_logits, start_positions, end_positions):
465 |         loss_fct = nn.CrossEntropyLoss(ignore_index=0)
466 |         start_loss = loss_fct(start_logits, start_positions)
467 |         end_loss = loss_fct(end_logits, end_positions)
468 |         total_loss = (start_loss + end_loss) / 2
469 |         return total_loss
470 |
471 | class SimpleQuestionAnsweringModel(BaseQuestionAnsweringModel):
472 |     def __init__(self, config):
473 |         BaseQuestionAnsweringModel.__init__(self, config)
474 |         self.model = BiRNNForQuestionAnswering(self.vocab_size, self.num_labels)
475 |         self.model.to(self.device)
476 |         # self.optimizer = AdamW(self.model.parameters(), lr = 5e-5)
477 |
478 |     def wipe_memory(self):
479 |         self.model = None
480 |         self.optimizer = None
481 |         torch.cuda.empty_cache()
482 |
483 |
484 | class BaseSeq2SeqModel:
485 |     def __init__(self, config, tokenizer = None, task = ""):
486 |         self.model = nn.Module()
487 |         self.model_name = config['model_name']
488 |         self.vocab_size = config['vocab_size']
489 |         self.num_labels = config['num_labels']
490 |         self.tokenizer = tokenizer
491 |         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
492 |         self.task = task
493 |
494 |     def train(self, datasets, examples, **kwargs):
495 |         save_dir = kwargs['save_dir']
496 |         epochs = kwargs['epochs']
497 |         lr = kwargs['lr']
498 |         batch_size = kwargs['batch_size']
499 |
500 |         self.optimizer = AdamW(self.model.parameters(), lr = lr)
501 |         self.mt_metric = load("sacrebleu")
502 |         self.sum_metric = load("rouge")
503 |         train_dataset, valid_dataset, test_dataset = datasets
504 |
505 |         filepath = os.path.join(save_dir, 'pytorch_model.bin')
506 |         best_accuracy = 0
507 |         metric_name = "bleu" if self.task == "mt" else "rougeLsum"
== "mt" else "rougeLsum" 508 | pbar = tqdm(total=epochs * len(train_dataset), leave=True) 509 | 510 | for epoch in range(epochs): 511 | loss = 0 512 | self.model.train().to(self.device) 513 | for _, batch in enumerate(train_dataset): 514 | batch = {k: v.to(self.device) for k, v in batch.items()} 515 | outputs = self.model(**batch) 516 | loss = outputs['loss'] 517 | loss.backward() 518 | self.optimizer.step() 519 | self.optimizer.zero_grad() 520 | batch = None 521 | pbar.update(1) 522 | self.model.eval().to(self.device) 523 | train_loss, train_metrics = self.evaluate_dataset(train_dataset) 524 | print(f"Epoch {epoch} Train Loss {train_loss:.4f} Train {metric_name} {train_metrics[metric_name]:.4f}") 525 | 526 | valid_loss, valid_metrics = self.evaluate_dataset(valid_dataset) 527 | print(f"Epoch {epoch} Valid Loss {valid_loss:.4f} Valid {metric_name} {valid_metrics[metric_name]:.4f}") 528 | 529 | val_accuracy = valid_metrics[metric_name] 530 | if val_accuracy >= best_accuracy: 531 | best_accuracy = val_accuracy 532 | torch.save(self.model.state_dict(), filepath) 533 | 534 | self.model.load_state_dict(torch.load(filepath)) 535 | self.model.eval() 536 | test_loss, test_metrics = self.evaluate_dataset(test_dataset) 537 | print(f"Epoch {epoch} Test Loss {test_loss:.4f} Test {metric_name} {test_metrics[metric_name]:.4f}") 538 | return test_metrics 539 | 540 | def evaluate_dataset(self, dataset, desc="Eval"): 541 | total_loss = 0 542 | bleu_score = 0 543 | pbar = tqdm(total=len(dataset), position=0, leave=False, desc=desc) 544 | for _, batch in enumerate(dataset): 545 | batch = {k: v.to(self.device) for k, v in batch.items()} 546 | if 't5' in self.model_name.lower(): 547 | with torch.no_grad(): 548 | outputs = self.model(**batch) 549 | generated_tokens = self.model.generate(batch['input_ids']) 550 | else: 551 | with torch.no_grad(): 552 | outputs = self.model(**batch, mode ="generate") 553 | generated_tokens = outputs['outputs'] 554 | 555 | labels = batch['labels'] 556 | loss = outputs['loss'] 557 | total_loss += loss.cpu().numpy() / len(dataset) 558 | 559 | if self.task == "mt": 560 | metric = self.compute_metrics(generated_tokens.cpu(), labels.cpu()) 561 | bleu_score += metric['bleu'] / len(dataset) 562 | 563 | elif self.task == "sum": 564 | labels = labels.cpu().numpy() 565 | labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id) 566 | 567 | decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 568 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 569 | 570 | decoded_preds, decoded_labels = self.postprocess_text_sum(decoded_preds, decoded_labels) 571 | self.sum_metric.add_batch( 572 | predictions=decoded_preds, 573 | references=decoded_labels,) 574 | 575 | pbar.update(1) 576 | if self.task == "sum": 577 | result = self.sum_metric.compute(use_stemmer=True) 578 | result = {k: round(v * 100, 4) for k, v in result.items()} 579 | return loss, result 580 | 581 | elif self.task == "mt": 582 | return loss, {'bleu':bleu_score} 583 | 584 | def compute_metrics(self, preds, labels): 585 | if isinstance(preds, tuple): 586 | preds = preds[0] 587 | 588 | if 't5' in self.model_name.lower(): 589 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 590 | 591 | # Replace -100 in the labels as we can't decode them. 
592 | labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id) 593 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 594 | result = self.mt_metric.compute(predictions=decoded_preds, references=decoded_labels) 595 | result = {"bleu": result["score"]} 596 | result = {k: round(v, 4) for k, v in result.items()} 597 | return result 598 | else: 599 | 600 | preds = self.get_lists(preds) 601 | labels = self.get_lists(labels) 602 | 603 | decoded_preds = self.tokenizer.decode_sentences(preds) 604 | decoded_preds = [stmt.replace(" .", ".") for stmt in decoded_preds] 605 | 606 | decoded_labels = self.tokenizer.decode_sentences(labels) 607 | decoded_labels = [[stmt.replace(" .", ".")] for stmt in decoded_labels] 608 | 609 | result = self.mt_metric.compute(predictions=decoded_preds, references=decoded_labels) 610 | result = {"bleu": result["score"]} 611 | result = {k: round(v, 4) for k, v in result.items()} 612 | return result 613 | 614 | def postprocess_text_sum(self, preds, labels): 615 | preds = [pred.strip() for pred in preds] 616 | labels = [label.strip() for label in labels] 617 | 618 | # rougeLSum expects newline after each sentence 619 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 620 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 621 | 622 | return preds, labels 623 | 624 | def get_lists(self, inputs): 625 | inputs = inputs.cpu().detach().numpy().astype(int).tolist() 626 | output = [] 627 | for input in inputs: 628 | current_item = [] 629 | for item in input: 630 | if item == self.tokenizer.eos_idx: 631 | break 632 | else: 633 | current_item.append(item) 634 | output.append(current_item) 635 | return output 636 | 637 | 638 | 639 | class T5Seq2SeqModel(BaseSeq2SeqModel): 640 | def __init__(self, config, tokenizer = None, task = ""): 641 | BaseSeq2SeqModel.__init__(self, config, tokenizer = tokenizer, task = task) 642 | config = AutoConfig.from_pretrained(self.model_name) 643 | self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, config = config) 644 | 645 | def wipe_memory(self): 646 | self.model = None 647 | self.optimizer = None 648 | torch.cuda.empty_cache() 649 | 650 | #https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb#scrollTo=dCK3LIN25n_S 651 | class Encoder(nn.Module): 652 | def __init__(self, input_dim, emb_dim, hid_dim, n_layers, bidirectional = True): 653 | super().__init__() 654 | 655 | self.hid_dim = hid_dim 656 | self.n_layers = n_layers 657 | 658 | self.embedding = nn.Embedding(input_dim, emb_dim) 659 | 660 | self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, bidirectional = bidirectional) 661 | 662 | def forward(self, src): 663 | 664 | #src = [src len, batch size] 665 | 666 | embedded = self.embedding(src) 667 | 668 | #embedded = [src len, batch size, emb dim] 669 | 670 | outputs, hidden = self.rnn(embedded) 671 | #outputs = [src len, batch size, hid dim * n directions] 672 | #hidden = [n layers * n directions, batch size, hid dim] 673 | 674 | #outputs are always from the top hidden layer 675 | 676 | return hidden 677 | class Decoder(nn.Module): 678 | def __init__(self, output_dim, emb_dim, hid_dim, n_layers, bidirectional = True): 679 | super().__init__() 680 | 681 | self.output_dim = output_dim 682 | self.hid_dim = hid_dim 683 | self.n_layers = n_layers 684 | 685 | self.embedding = nn.Embedding(output_dim, emb_dim) 686 | 687 | self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, 
bidirectional = bidirectional) 688 | 689 | self.fc_out = nn.Linear(hid_dim, output_dim) 690 | 691 | 692 | def forward(self, input, hidden): 693 | 694 | #input = [batch size] 695 | #hidden = [n layers * n directions, batch size, hid dim] 696 | 697 | input = input.unsqueeze(0) 698 | 699 | #input = [1, batch size] 700 | 701 | embedded = self.embedding(input) 702 | 703 | #embedded = [1, batch size, emb dim] 704 | 705 | output, hidden = self.rnn(embedded, hidden) 706 | #seq len and n directions will always be 1 in the decoder, therefore: 707 | #output = [1, batch size, hid dim*2] 708 | #hidden = [n layers, batch size, hid dim] 709 | output = (output[:, :, :self.hid_dim] + output[:, :, self.hid_dim:]) 710 | prediction = self.fc_out(output.squeeze(0)) 711 | 712 | #prediction = [batch size, output dim] 713 | 714 | return prediction, hidden 715 | 716 | class Seq2SeqMachineTranslation(nn.Module): 717 | def __init__(self, vocab_size = 500, tokenizer = None): 718 | super().__init__() 719 | ENC_EMB_DIM = 128 720 | DEC_EMB_DIM = 128 721 | HID_DIM = 1024 722 | N_LAYERS = 2 723 | INPUT_DIM = vocab_size 724 | OUTPUT_DIM = vocab_size 725 | self.vocab_size = vocab_size 726 | self.encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS) 727 | self.decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS) 728 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 729 | self.tokenizer = tokenizer 730 | assert self.encoder.hid_dim == self.decoder.hid_dim, \ 731 | "Hidden dimensions of encoder and decoder must be equal!" 732 | assert self.encoder.n_layers == self.decoder.n_layers, \ 733 | "Encoder and decoder must have equal number of layers!" 734 | 735 | def forward(self, input_ids, labels = None, teacher_forcing_ratio = 0.5, mode = "train"): 736 | src = torch.transpose(input_ids, 0, 1) 737 | 738 | if labels is not None: 739 | trg = torch.transpose(labels, 0, 1) 740 | 741 | #src = [src len, batch size] 742 | #trg = [trg len, batch size] 743 | #teacher_forcing_ratio is probability to use teacher forcing 744 | #e.g. 
if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time 745 | 746 | batch_size = src.shape[1] 747 | trg_len = src.shape[0] 748 | 749 | trg_vocab_size = self.decoder.output_dim 750 | 751 | #tensor to store decoder outputs 752 | outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device) 753 | #last hidden state of the encoder is used as the initial hidden state of the decoder 754 | hidden = self.encoder(src) 755 | 756 | #first input to the decoder is the tokens 757 | input = torch.tensor([self.tokenizer.sos_idx]*batch_size).to(self.device) 758 | 759 | for t in range(1, trg_len): 760 | 761 | #insert input token embedding, previous hidden and previous cell states 762 | #receive output tensor (predictions) and new hidden and cell states 763 | output, hidden = self.decoder(input, hidden) 764 | 765 | #decide if we are going to use teacher forcing or not 766 | teacher_force = random.random() < teacher_forcing_ratio 767 | 768 | #get the highest predicted token from our predictions 769 | top1 = output.argmax(1) 770 | 771 | #if teacher forcing, use actual next token as next input 772 | #if not, use predicted token 773 | 774 | if mode == "train" and teacher_force: 775 | input = trg[t] 776 | else: 777 | input = top1 778 | 779 | outputs[t] = output 780 | 781 | if labels is not None: 782 | loss = self.compute_loss(outputs, trg) 783 | return {'loss':loss, 784 | 'outputs':torch.transpose(outputs.argmax(-1), 0, 1) 785 | } 786 | else: 787 | return {'outputs': torch.transpose(outputs.argmax(-1), 0, 1)} 788 | 789 | def compute_loss(self, output, trg): 790 | loss_fct = nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_idx) 791 | output_dim = output.shape[-1] 792 | output = output[1:].view(-1, output_dim) 793 | trg = trg[1:].reshape(-1) 794 | #trg = [(trg len - 1) * batch size] 795 | #output = [(trg len - 1) * batch size, output dim] 796 | 797 | loss = loss_fct(output, trg) 798 | return loss 799 | 800 | 801 | class SimpleMachineTranslationModel(BaseSeq2SeqModel): 802 | def __init__(self, config, tokenizer = None): 803 | BaseSeq2SeqModel.__init__(self, config, tokenizer = tokenizer) 804 | self.model = Seq2SeqMachineTranslation(vocab_size = self.vocab_size, tokenizer = tokenizer) 805 | self.model.to(self.device) 806 | # self.optimizer = AdamW(self.model.parameters(), lr = 5e-5) 807 | 808 | def wipe_memory(self): 809 | self.model = None 810 | self.optimizer = None 811 | torch.cuda.empty_cache() 812 | -------------------------------------------------------------------------------- /nmatheg/ner_utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/huggingface/transformers/blob/master/examples/pytorch/token-classification/run_ner_no_trainer.py 2 | def get_labels(predictions, references, labels): 3 | labels = labels.split(',') 4 | # Transform predictions and references tensos to numpy arrays 5 | y_pred = predictions.detach().cpu().clone().numpy() 6 | y_true = references.detach().cpu().clone().numpy() 7 | 8 | # Remove ignored index (special tokens) 9 | true_predictions = [ 10 | [labels[p] for (p, l) in zip(pred, gold_label) if l != -100] 11 | for pred, gold_label in zip(y_pred, y_true) 12 | ] 13 | true_labels = [ 14 | [labels[l] for (p, l) in zip(pred, gold_label) if l != -100] 15 | for pred, gold_label in zip(y_pred, y_true) 16 | ] 17 | return true_predictions, true_labels -------------------------------------------------------------------------------- /nmatheg/nmatheg.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | from .dataset import create_dataset 4 | from .models import SimpleClassificationModel, BERTTextClassificationModel\ 5 | ,BERTTokenClassificationModel,BERTQuestionAnsweringModel\ 6 | ,SimpleTokenClassificationModel,SimpleQuestionAnsweringModel\ 7 | ,SimpleMachineTranslationModel,T5Seq2SeqModel 8 | from .configs import create_default_config 9 | import configparser 10 | import json 11 | from .utils import save_json, get_tokenizer 12 | from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer,AutoModelForTokenClassification,AutoModelForQuestionAnswering,AutoModelForSeq2SeqLM 13 | from transformers import pipeline 14 | import pathlib 15 | 16 | import torch 17 | try: 18 | import bpe_surgery 19 | except: 20 | pass 21 | import numpy as np 22 | 23 | 24 | class TrainStrategy: 25 | def __init__(self, datasets, models, tokenizers= None, vocab_sizes=None, config_path= None, 26 | batch_size = 64, epochs = 5, lr = 5e-5, runs = 10, max_tokens = 128, max_train_samples = -1, 27 | preprocessing = {}, mode = 'finetune', ckpt= 'ckpts'): 28 | 29 | self.mode = mode 30 | modes = ['finetune', 'pretrain'] 31 | assert mode in modes , f"mode must be one of the following {modes}" 32 | if self.mode == 'pretrain': 33 | assert tokenizers is not None , "tokenizers must be set" 34 | assert vocab_sizes is not None, "vocab sizes must be set" 35 | 36 | if config_path == None: 37 | self.config = create_default_config(batch_size=batch_size, epochs = epochs, lr = lr, runs = runs, 38 | max_tokens=max_tokens, max_train_samples = max_train_samples, 39 | preprocessing = preprocessing, ckpt = ckpt) 40 | self.config['dataset'] = {'dataset_name' : datasets} 41 | self.config['model'] = {'model_name' : models} 42 | if self.mode == 'pretrain': 43 | self.config['tokenization']['vocab_size'] = vocab_sizes 44 | self.config['tokenization']['tokenizer_name'] = tokenizers 45 | else: 46 | self.config = configparser.ConfigParser() 47 | self.config.read(config_path) 48 | 49 | self.datasets_config = configparser.ConfigParser() 50 | rel_path = os.path.dirname(__file__) 51 | data_ini_path = os.path.join(rel_path, "datasets.ini") 52 | self.datasets_config.read(data_ini_path) 53 | self.preprocessing = preprocessing 54 | 55 | def start(self): 56 | model_names = [m.strip() for m in self.config['model']['model_name'].split(',')] 57 | dataset_names = [d.strip() for d in self.config['dataset']['dataset_name'].split(',')] 58 | if self.mode == 'pretrain': 59 | tokenizers = [t.strip() for t in self.config['tokenization']['tokenizer_name'].split(',')] 60 | vocab_sizes = [v.strip() for v in self.config['tokenization']['vocab_size'].split(',')] 61 | else: 62 | tokenizers = [m.strip() for m in self.config['model']['model_name'].split(',')] 63 | vocab_sizes = [str(AutoTokenizer.from_pretrained(v.strip()).vocab_size) for v in self.config['model']['model_name'].split(',')] 64 | runs = int(self.config['train']['runs']) 65 | max_tokens = int(self.config['tokenization']['max_tokens']) 66 | 67 | results = {} 68 | 69 | results_path = f"{self.config['train']['save_dir']}/results.json" 70 | if os.path.isfile(results_path): 71 | f = open(results_path) 72 | results = json.load(f) 73 | 74 | if self.mode == "finetune": 75 | for m, model_name in enumerate(model_names): 76 | if not model_name in results: 77 | results[model_name] = {} 78 | for d, dataset_name in enumerate(dataset_names): 79 | if not dataset_name in results[model_name]: 80 | 
results[model_name][dataset_name] = {} 81 | for run in range(runs): 82 | if os.path.isfile(results_path): 83 | if len(results[model_name][dataset_name].keys()) > 0: 84 | metric_name = list(results[model_name][dataset_name].keys())[0] 85 | curr_run = len(results[model_name][dataset_name][metric_name]) 86 | if run < curr_run: 87 | print(f"Run {run} already finished ") 88 | continue 89 | 90 | new_model_name = model_name.split("/")[-1] 91 | data_dir = f"{self.config['train']['save_dir']}/{new_model_name}/{dataset_name}/data" 92 | tokenizer_dir = f"{self.config['train']['save_dir']}/{new_model_name}/{dataset_name}/tokenizer" 93 | train_dir = f"{self.config['train']['save_dir']}/{new_model_name}/{dataset_name}/run_{run}" 94 | for path in [data_dir, tokenizer_dir, train_dir]: 95 | pathlib.Path(path).mkdir(parents=True, exist_ok=True) 96 | 97 | self.data_config = self.datasets_config[dataset_name] 98 | print(dict(self.data_config)) 99 | task_name = self.data_config['task'] 100 | vocab_size = vocab_sizes[m] 101 | tokenizer_name = tokenizers[m] 102 | tokenizer, self.datasets, self.examples = create_dataset(self.config, self.data_config, 103 | vocab_size = int(vocab_size), 104 | model_name = model_name, 105 | tokenizer_name = tokenizer_name, 106 | clean = True if len(self.preprocessing) else False, 107 | tok_save_path = tokenizer_dir, 108 | data_save_path = data_dir) 109 | self.model_config = {'model_name':model_name, 110 | 'vocab_size':int(vocab_size), 111 | 'num_labels':int(self.data_config['num_labels']), 112 | 'labels':self.data_config['labels']} 113 | 114 | print(self.model_config) 115 | if task_name in ['cls', 'nli']: 116 | self.model = BERTTextClassificationModel(self.model_config) 117 | elif task_name in ['ner', 'pos']: 118 | self.model = BERTTokenClassificationModel(self.model_config) 119 | 120 | elif task_name == 'qa': 121 | self.model = BERTQuestionAnsweringModel(self.model_config) 122 | elif task_name in ['mt', 'sum']: 123 | self.model = T5Seq2SeqModel(self.model_config, tokenizer = tokenizer, task = task_name) 124 | 125 | 126 | self.train_config = {'epochs':int(self.config['train']['epochs']), 127 | 'save_dir':train_dir, 128 | 'batch_size':int(self.config['train']['batch_size']), 129 | 'lr':float(self.config['train']['lr']), 130 | 'runs':run} 131 | self.tokenizer_config = {'name': tokenizer_name, 'vocab_size': vocab_size, 'max_tokens': max_tokens, 132 | 'save_path': tokenizer_dir} 133 | print(self.tokenizer_config) 134 | print(self.train_config) 135 | os.makedirs(self.train_config['save_dir'], exist_ok = True) 136 | 137 | if task_name == 'mt': 138 | metrics = self.model.train(self.datasets, self.examples, **self.train_config) 139 | else: 140 | metrics = self.model.train(self.datasets, self.examples, **self.train_config) 141 | 142 | save_json(self.train_config, f"{train_dir}/train_config.json") 143 | save_json(self.data_config, f"{data_dir}/data_config.json") 144 | save_json(self.model_config, f"{train_dir}/model_config.json") 145 | save_json(self.tokenizer_config, f"{tokenizer_dir}/tokenizer_config.json") 146 | 147 | for metric_name in metrics: 148 | if metric_name not in results[model_name][dataset_name]: 149 | results[model_name][dataset_name][metric_name] = [] 150 | results[model_name][dataset_name][metric_name].append(metrics[metric_name]) 151 | self.model.wipe_memory() 152 | with open(f"{self.config['train']['save_dir']}/results.json", 'w') as handle: 153 | json.dump(results, handle) 154 | 155 | elif self.mode == "pretrain": 156 | for t, tokenizer_name in enumerate(tokenizers): 157 
| if not tokenizer_name in results: 158 | results[tokenizer_name] = {} 159 | for v, vocab_size in enumerate(vocab_sizes): 160 | if self.mode == 'finetune' and v != t: 161 | continue 162 | if not vocab_size in results[tokenizer_name]: 163 | results[tokenizer_name][vocab_size] = {} 164 | for d, dataset_name in enumerate(dataset_names): 165 | if not dataset_name in results[tokenizer_name][vocab_size]: 166 | results[tokenizer_name][vocab_size][dataset_name] = {} 167 | for m, model_name in enumerate(model_names): 168 | if self.mode == 'finetune' and t != m: 169 | continue 170 | if not model_name in results[tokenizer_name][vocab_size][dataset_name]: 171 | results[tokenizer_name][vocab_size][dataset_name][model_name] = {} 172 | for run in range(runs): 173 | if os.path.isfile(results_path): 174 | if len(results[tokenizer_name][vocab_size][dataset_name][model_name].keys()) > 0: 175 | metric_name = list(results[tokenizer_name][vocab_size][dataset_name][model_name].keys())[0] 176 | curr_run = len(results[tokenizer_name][vocab_size][dataset_name][model_name][metric_name]) 177 | if run < curr_run: 178 | print(f"Run {run} already finished ") 179 | continue 180 | 181 | data_dir = f"{self.config['train']['save_dir']}/{tokenizer_name}/{vocab_size}/{dataset_name}/{model_name}/data" 182 | tokenizer_dir = f"{self.config['train']['save_dir']}/{tokenizer_name}/{vocab_size}/{dataset_name}/{model_name}/tokenizer" 183 | train_dir = f"{self.config['train']['save_dir']}/{tokenizer_name}/{vocab_size}/{dataset_name}/{model_name}/run_{run}" 184 | for path in [data_dir, tokenizer_dir, train_dir]: 185 | pathlib.Path(path).mkdir(parents=True, exist_ok=True) 186 | 187 | 188 | self.data_config = self.datasets_config[dataset_name] 189 | print(dict(self.data_config)) 190 | task_name = self.data_config['task'] 191 | tokenizer, self.datasets, self.examples = create_dataset(self.config, self.data_config, 192 | vocab_size = int(vocab_size), 193 | model_name = model_name, 194 | tokenizer_name = tokenizer_name, 195 | data_save_path = data_dir, 196 | tok_save_path = tokenizer_dir, 197 | clean = True if len(self.preprocessing) else False) 198 | self.model_config = {'model_name':model_name, 199 | 'vocab_size':int(vocab_size), 200 | 'num_labels':int(self.data_config['num_labels']), 201 | 'labels':self.data_config['labels']} 202 | 203 | print(self.model_config) 204 | if task_name in ['cls', 'nli']: 205 | self.model = SimpleClassificationModel(self.model_config) 206 | elif task_name == 'ner': 207 | self.model = SimpleTokenClassificationModel(self.model_config) 208 | 209 | elif task_name == 'qa': 210 | self.model = SimpleQuestionAnsweringModel(self.model_config) 211 | elif task_name == 'mt': 212 | self.model = SimpleMachineTranslationModel(self.model_config, tokenizer = tokenizer) 213 | 214 | self.train_config = {'epochs':int(self.config['train']['epochs']), 215 | 'save_dir':train_dir, 216 | 'batch_size':int(self.config['train']['batch_size']), 217 | 'lr':float(self.config['train']['lr']), 218 | 'runs':run} 219 | self.tokenizer_config = {'name': tokenizer_name, 'vocab_size': vocab_size, 'max_tokens': max_tokens, 220 | 'save_path': tokenizer_dir} 221 | print(self.tokenizer_config) 222 | print(self.train_config) 223 | os.makedirs(self.train_config['save_dir'], exist_ok = True) 224 | 225 | if task_name == 'mt': 226 | metrics = self.model.train(self.datasets, self.examples, **self.train_config) 227 | else: 228 | metrics = self.model.train(self.datasets, self.examples, **self.train_config) 229 | 230 | save_json(self.train_config, 
f"{train_dir}/train_config.json") 231 | save_json(self.data_config, f"{data_dir}/data_config.json") 232 | save_json(self.model_config, f"{train_dir}/model_config.json") 233 | save_json(self.tokenizer_config, f"{tokenizer_dir}/tokenizer_config.json") 234 | for metric_name in metrics: 235 | if metric_name not in results[tokenizer_name][vocab_size][dataset_name][model_name]: 236 | results[tokenizer_name][vocab_size][dataset_name][model_name][metric_name] = [] 237 | results[tokenizer_name][vocab_size][dataset_name][model_name][metric_name].append(metrics[metric_name]) 238 | self.model.wipe_memory() 239 | with open(f"{self.config['train']['save_dir']}/results.json", 'w') as handle: 240 | json.dump(results, handle) 241 | return results 242 | 243 | def predict_from_run(save_dir, run = 0, sentence = "", question = "", context = "", hypothesis = "", premise = ""): 244 | data_config = json.load(open(f"{save_dir}/data/data_config.json")) 245 | tokenizer_config = json.load(open(f"{save_dir}/tokenizer/tokenizer_config.json")) 246 | train_dir = f"{save_dir}/run_{run}" 247 | model_config = json.load(open(f"{train_dir}/model_config.json")) 248 | model_name = model_config["model_name"] 249 | task_name = data_config['task'] 250 | tokenizer_name = tokenizer_config["name"] 251 | tokenizer_save_path = tokenizer_config["save_path"] 252 | max_tokens = tokenizer_config["max_tokens"] 253 | vocab_size = tokenizer_config["vocab_size"] 254 | num_labels = model_config["num_labels"] 255 | 256 | if model_name == "birnn": 257 | if task_name == "mt": 258 | src_tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 259 | trg_tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 260 | 261 | src_tokenizer.load(tokenizer_save_path, name = "src_tok") 262 | trg_tokenizer.load(tokenizer_save_path, name = "trg_tok") 263 | 264 | model = SimpleMachineTranslationModel(model_config, tokenizer = trg_tokenizer) 265 | model.model.load_state_dict(torch.load(f"{train_dir}/pytorch_model.bin")) 266 | 267 | encoding = src_tokenizer.encode_sentences([sentence], add_boundry=True, out_length=max_tokens) 268 | out = model.model(torch.tensor(encoding).to('cuda'), mode = "generate") 269 | return trg_tokenizer.decode_sentences(out['outputs']) 270 | 271 | elif task_name == "cls": 272 | tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 273 | tokenizer.load_model(f"{tokenizer_save_path}/m.model") 274 | 275 | model = SimpleClassificationModel(model_config) 276 | model.model.load_state_dict(torch.load(f"{train_dir}/pytorch_model.bin")) 277 | 278 | encoding = tokenizer.encode_sentences([sentence], out_length=max_tokens) 279 | out = model.model(torch.tensor(encoding).to('cuda')) 280 | labels = data_config['labels'].split(",") 281 | return labels[out['logits'].argmax(-1)] 282 | 283 | elif task_name == "nli": 284 | tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 285 | tokenizer.load(tokenizer_save_path) 286 | 287 | model = SimpleClassificationModel(model_config) 288 | model.model.load_state_dict(torch.load(f"{train_dir}/pytorch_model.bin")) 289 | 290 | encoding = tokenizer.encode_sentences([premise + " "+ hypothesis], add_boundry=True, out_length=max_tokens) 291 | out = model.model(torch.tensor(encoding).to('cuda')) 292 | labels = data_config['labels'].split(",") 293 | return labels[out['logits'].argmax(-1)] 294 | 295 | elif task_name == "ner": 296 | tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 297 | tokenizer.load(tokenizer_save_path) 298 | 299 | model = 
SimpleTokenClassificationModel(model_config) 300 | model.model.load_state_dict(torch.load(f"{train_dir}/pytorch_model.bin")) 301 | output = [] 302 | labels = data_config['labels'].split(",") 303 | out_sentence = "" 304 | sentence_encoding = [] 305 | word_lens = [] 306 | words = sentence.split(' ') 307 | for word_id , word in enumerate(words): 308 | enc_words = tokenizer._encode_word(word) 309 | sentence_encoding += enc_words 310 | word_lens .append(len(enc_words)) 311 | 312 | while len(sentence_encoding) < max_tokens: 313 | sentence_encoding.append(0) 314 | out = model.model(torch.tensor(sentence_encoding).to('cuda'))['logits'].argmax(-1).cpu().numpy() 315 | i = 0 316 | j = 0 317 | while i < sum(word_lens): 318 | preds = out[i:i+word_lens[j]] 319 | counts = np.bincount(preds) 320 | mj_label = np.argmax(counts) 321 | out_sentence += " "+labels[mj_label] 322 | i += word_lens[j] 323 | j += 1 324 | output.append(out_sentence.strip()) 325 | return output 326 | 327 | elif task_name == "qa": 328 | tokenizer = get_tokenizer(tokenizer_name, vocab_size = vocab_size) 329 | tokenizer.load(tokenizer_save_path) 330 | 331 | model = SimpleQuestionAnsweringModel(model_config) 332 | model.model.load_state_dict(torch.load(f"{train_dir}/pytorch_model.bin")) 333 | question_encoding = tokenizer.encode_sentences([question])[0] 334 | context_encoding = tokenizer.encode_sentences([context])[0] 335 | pad_re = max_tokens - (len(question_encoding) + len(context_encoding) + 1) 336 | encoding = question_encoding +[0]+context_encoding + [0] * pad_re 337 | out = model.model(torch.tensor([encoding]).to('cuda')) 338 | start_preds = out['start_logits'].argmax(-1).cpu().numpy() 339 | end_preds = out['end_logits'].argmax(-1).cpu().numpy() 340 | return tokenizer.decode_sentences([encoding[start_preds[0]:end_preds[0]]]) 341 | else: 342 | 343 | 344 | if task_name == "cls": 345 | config = AutoConfig.from_pretrained(model_name, num_labels=num_labels) 346 | model = AutoModelForSequenceClassification.from_pretrained(train_dir, config = config) 347 | tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, model_max_length = 512) 348 | encoded_review = tokenizer.encode_plus( 349 | sentence, 350 | max_length=512, 351 | add_special_tokens=True, 352 | return_token_type_ids=False, 353 | pad_to_max_length=True, 354 | return_attention_mask=True, 355 | return_tensors='pt', 356 | ) 357 | 358 | input_ids = encoded_review['input_ids'] 359 | attention_mask = encoded_review['attention_mask'] 360 | output = model(input_ids, attention_mask) 361 | labels = data_config['labels'].split(",") 362 | return labels[output['logits'].argmax(-1)] 363 | 364 | elif task_name == "nli": 365 | config = AutoConfig.from_pretrained(model_name, num_labels=num_labels) 366 | model = AutoModelForSequenceClassification.from_pretrained(train_dir, config = config) 367 | tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, model_max_length = 512) 368 | encoded_review = tokenizer.encode_plus( 369 | premise, 370 | hypothesis, 371 | max_length=512, 372 | add_special_tokens=True, 373 | return_token_type_ids=False, 374 | pad_to_max_length=True, 375 | return_attention_mask=True, 376 | return_tensors='pt', 377 | ) 378 | 379 | input_ids = encoded_review['input_ids'] 380 | attention_mask = encoded_review['attention_mask'] 381 | output = model(input_ids, attention_mask) 382 | labels = data_config['labels'].split(",") 383 | return labels[output['logits'].argmax(-1)] 384 | 385 | elif task_name in ['ner', 'pos']: 386 | labels = 
data_config['labels'].split(",") 387 | config = AutoConfig.from_pretrained(model_name, num_labels = num_labels, id2label = labels) 388 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = 512) 389 | model = AutoModelForTokenClassification.from_pretrained(train_dir, config = config) 390 | nlp = pipeline(task_name, model=model, tokenizer=tokenizer) 391 | return nlp(sentence) 392 | 393 | elif task_name == "qa": 394 | config = AutoConfig.from_pretrained(model_name) 395 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = 512) 396 | model = AutoModelForQuestionAnswering.from_pretrained(train_dir, config = config) 397 | nlp = pipeline("question-answering", model=model, tokenizer=tokenizer) 398 | return nlp(question=question, context=context) 399 | 400 | elif task_name in ["mt", "sum"]: 401 | config = AutoConfig.from_pretrained(model_name) 402 | tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = 512) 403 | model = AutoModelForSeq2SeqLM.from_pretrained(train_dir, config = config) 404 | nlp = pipeline('text2text-generation', model=model, tokenizer=tokenizer) 405 | return nlp(sentence) 406 | -------------------------------------------------------------------------------- /nmatheg/preprocess_ner.py: -------------------------------------------------------------------------------- 1 | # Creating a class to pull the words from the columns and create them into sentences 2 | from datasets import Dataset, DatasetDict 3 | 4 | def aggregate_tokens(dataset, config, data_config, max_len = 128): 5 | new_dataset = {} 6 | token_col = data_config['text'] 7 | tag_col = data_config['label'] 8 | 9 | for split in dataset: 10 | sent_labels = [] 11 | sent_label = [] 12 | sentence = [] 13 | sentences = [] 14 | 15 | for i, item in enumerate(dataset[split]): 16 | token, label = item[token_col], item[tag_col] 17 | sent_label.append(label) 18 | sentence.append(token) 19 | if len(sentence) == max_len: 20 | sentences.append(sentence) 21 | sent_labels.append(sent_label) 22 | sentence = [] 23 | sent_label = [] 24 | new_dataset[split] = Dataset.from_dict({token_col:sentences, tag_col:sent_labels}) 25 | return DatasetDict(new_dataset) 26 | 27 | # https://github.com/huggingface/transformers/blob/44f5b260fe7a69cbd82be91b58c62a2879d530fa/examples/pytorch/token-classification/run_ner_no_trainer.py#L353 28 | def tokenize_and_align_labels(dataset, tokenizer, data_config, model_type = 'transformer', max_len = 128): 29 | 30 | token_col = data_config['text'] 31 | tag_col = data_config['label'] 32 | 33 | if 'transformer' in model_type: 34 | tokenized_inputs = tokenizer( 35 | dataset[token_col], 36 | max_length=max_len, 37 | padding='max_length', 38 | truncation=True, 39 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 
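# (added note) e.g. a row arrives pre-split as ['President', 'Obama', ...] with
# one tag per word; is_split_into_words=True makes the fast tokenizer track
# which subtokens came from which word, exposed through word_ids() below.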
40 | is_split_into_words=True, 41 | ) 42 | labels = [] 43 | for i, label in enumerate(dataset[tag_col]): 44 | word_ids = tokenized_inputs.word_ids(batch_index=i) 45 | previous_word_idx = None 46 | label_ids = [] 47 | for word_idx in word_ids: 48 | if word_idx is None: 49 | label_ids.append(-100) 50 | elif word_idx != previous_word_idx: 51 | label_ids.append(label[word_idx]) 52 | else: 53 | label_ids.append(label[word_idx] if True else -100) 54 | previous_word_idx = word_idx 55 | 56 | labels.append(label_ids) 57 | tokenized_inputs["labels"] = labels 58 | return tokenized_inputs 59 | else: 60 | labels = [] 61 | input_ids = [] 62 | for i, label in enumerate(dataset[tag_col]): 63 | word_ids = [] 64 | tokens = [] 65 | for j, word in enumerate(dataset[token_col][i]): 66 | token_ids = tokenizer._encode_word(word) 67 | for token_id in token_ids: 68 | tokens.append(token_id) 69 | word_ids.append(j) 70 | if len(tokens) > max_len: 71 | break 72 | 73 | while len(tokens) < max_len: 74 | tokens.append(0) 75 | word_ids.append(None) 76 | else: 77 | tokens = tokens[:max_len] 78 | word_ids = word_ids[:max_len] 79 | 80 | input_ids.append(tokens) 81 | previous_word_idx = None 82 | label_ids = [] 83 | for word_idx in word_ids: 84 | if word_idx is None: 85 | label_ids.append(-100) 86 | elif word_idx != previous_word_idx: 87 | label_ids.append(label[word_idx]) 88 | else: 89 | label_ids.append(label[word_idx] if True else -100) 90 | previous_word_idx = word_idx 91 | labels.append(label_ids) 92 | dataset["labels"] = labels 93 | dataset["input_ids"] = input_ids 94 | return dataset -------------------------------------------------------------------------------- /nmatheg/preprocess_qa.py: -------------------------------------------------------------------------------- 1 | # https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/run_qa_no_trainer.py 2 | import re 3 | import copy 4 | 5 | def overflow_to_sample_mapping(tokens, offsets, idx, max_len = 384, doc_stride = 128): 6 | fixed_tokens = [] 7 | fixed_offsets = [] 8 | sep_index = tokens.index(-100) 9 | question = tokens[:sep_index] 10 | context = tokens[sep_index+1:] 11 | q_offsets = offsets[:sep_index] 12 | c_offsets = offsets[sep_index+1:] 13 | q_len = len(question) 14 | c_len = len(context) 15 | st_idx = 0 16 | samplings = [] 17 | sequences = [] 18 | 19 | while True: 20 | ed_idx = st_idx+max_len-q_len-1 21 | pad_re = max_len - len(question+ [0] + context[st_idx:ed_idx]) 22 | 23 | if len(context[st_idx:ed_idx]) == 0: 24 | break 25 | curr_tokens = question+[0] + context[st_idx:ed_idx] + [0] * pad_re 26 | curr_offset = q_offsets+[(0,0)] + c_offsets[st_idx:ed_idx] + [(0,0)] * pad_re 27 | curr_seq = [0]*q_len+[None]+[1]*len(context[st_idx:ed_idx])+[None] * pad_re 28 | 29 | assert len(curr_tokens) == len(curr_offset) == len(curr_seq) == max_len, f"curr_tokens: {len(curr_tokens)}, curr_seq: {len(curr_seq)}" 30 | fixed_tokens.append(curr_tokens[:max_len]) 31 | fixed_offsets.append(curr_offset[:max_len]) 32 | samplings.append(idx) 33 | sequences.append(curr_seq) 34 | 35 | st_idx += doc_stride 36 | if pad_re > 0: 37 | break 38 | for i in range(len(fixed_tokens)): 39 | assert len(fixed_tokens[i]) == len(fixed_offsets[i]) 40 | return fixed_tokens, fixed_offsets, samplings, sequences 41 | 42 | def prepare_features(examples, tokenizer, data_config, model_type = 'transformer', max_len = 384): 43 | # Tokenize our examples with truncation and padding, but keep the overflows using a stride. 
This results 44 | # in one example possible giving several features when a context is long, each of those features having a 45 | # context that overlaps a bit the context of the previous feature. 46 | if 'transformer' in model_type: 47 | tokenized_examples = tokenizer( 48 | examples["question"], 49 | examples["context"], 50 | truncation="only_second", 51 | max_length=max_len, 52 | stride=128, 53 | return_overflowing_tokens=True, 54 | return_offsets_mapping=True, 55 | padding="max_length", 56 | ) 57 | # Since one example might give us several features if it has a long context, we need a map from a feature to 58 | # its corresponding example. This key gives us just that. 59 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 60 | # The offset mappings will give us a map from token to character position in the original context. This will 61 | # help us compute the start_positions and end_positions. 62 | offset_mapping = tokenized_examples["offset_mapping"] 63 | # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the 64 | # corresponding example_id and we will store the offset mappings. 65 | # Let's label those examples! 66 | tokenized_examples["start_positions"] = [] 67 | tokenized_examples["end_positions"] = [] 68 | 69 | for i, offsets in enumerate(offset_mapping): 70 | # We will label impossible answers with the index of the CLS token. 71 | input_ids = tokenized_examples["input_ids"][i] 72 | cls_index = input_ids.index(tokenizer.cls_token_id) 73 | 74 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 75 | sequence_ids = tokenized_examples.sequence_ids(i) 76 | 77 | # One example can give several spans, this is the index of the example containing this span of text. 78 | sample_index = sample_mapping[i] 79 | answers = examples["answers"][sample_index] 80 | # If no answers are given, set the cls_index as answer. 81 | if len(answers["answer_start"]) == 0: 82 | tokenized_examples["start_positions"].append(cls_index) 83 | tokenized_examples["end_positions"].append(cls_index) 84 | else: 85 | # Start/end character index of the answer in the text. 86 | start_char = answers["answer_start"][0] 87 | end_char = start_char + len(answers["text"][0]) 88 | 89 | # Start token index of the current span in the text. 90 | token_start_index = 0 91 | while sequence_ids[token_start_index] != 1: 92 | token_start_index += 1 93 | 94 | # End token index of the current span in the text. 95 | token_end_index = len(input_ids) - 1 96 | while sequence_ids[token_end_index] != 1: 97 | token_end_index -= 1 98 | 99 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). 100 | if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): 101 | tokenized_examples["start_positions"].append(cls_index) 102 | tokenized_examples["end_positions"].append(cls_index) 103 | else: 104 | # Otherwise move the token_start_index and token_end_index to the two ends of the answer. 105 | # Note: we could go after the last offset if the answer is the last word (edge case). 
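# (added illustration) e.g. for an answer spanning characters 10..18 with
# context token offsets [(0,4), (5,9), (10,13), (14,18)]: the first loop
# overshoots to the token starting at 14 and steps back one, landing on
# (10,13); the second loop moves left while a token still ends at or after
# 18, then steps back one, landing on (14,18).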
106 | while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: 107 | token_start_index += 1 108 | tokenized_examples["start_positions"].append(token_start_index - 1) 109 | while offsets[token_end_index][1] >= end_char: 110 | token_end_index -= 1 111 | tokenized_examples["end_positions"].append(token_end_index + 1) 112 | 113 | tokenized_examples["example_id"] = [] 114 | 115 | for i in range(len(tokenized_examples["input_ids"])): 116 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 117 | sequence_ids = tokenized_examples.sequence_ids(i) 118 | context_index = 1 119 | 120 | # One example can give several spans, this is the index of the example containing this span of text. 121 | sample_index = sample_mapping[i] 122 | tokenized_examples["example_id"].append(examples["id"][sample_index]) 123 | 124 | # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token 125 | # position is part of the context or not. 126 | tokenized_examples["offset_mapping"][i] = [ 127 | (o if sequence_ids[k] == context_index else None) 128 | for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 129 | ] 130 | 131 | return tokenized_examples 132 | 133 | else: 134 | question_col, context_col = data_config['text'].split(",") 135 | tokenized_examples = copy.deepcopy(examples) 136 | input_ids = [] 137 | offset_mapping = [] 138 | sequence_ids = [] 139 | sample_mapping = [] 140 | 141 | for i, (question, context) in enumerate(zip(examples[question_col], examples[context_col])): 142 | offsets = [] 143 | tokens = [] 144 | sequences = [] 145 | 146 | question_context = question + " "+context 147 | st = 0 148 | for word in question_context.split(" "): 149 | if len(word) == 0: 150 | st += 1 151 | continue 152 | 153 | word = word.strip() 154 | 155 | if word == "": 156 | offsets.append((0, 0)) 157 | tokens.append(-100) 158 | st = 0 159 | else: 160 | token_ids = tokenizer._encode_word(word) 161 | token_ids = [token_id for token_id in token_ids] 162 | token_strs = tokenizer._tokenize_word(word, remove_sow=True) 163 | if token_ids[0] == tokenizer.sow_idx: 164 | token_strs = [tokenizer.sow] + token_strs 165 | for j, token_id in enumerate(token_ids): 166 | token_str = token_strs[j] 167 | tokens.append(token_id) 168 | if token_str == tokenizer.sow: 169 | offsets.append((st, st)) 170 | else: 171 | offsets.append((st, st+len(token_str))) 172 | st += len(token_str) 173 | st += 1 # for space 174 | 175 | 176 | tokens, offsets, samplings, sequences = overflow_to_sample_mapping(tokens, offsets, i, max_len = max_len) 177 | 178 | 179 | sample_mapping += samplings 180 | input_ids += tokens 181 | offset_mapping += offsets 182 | sequence_ids += sequences 183 | 184 | tokenized_examples = {'input_ids':input_ids, 'sequence_ids':sequence_ids, 'offset_mapping': offset_mapping, 'overflow_to_sample_mapping': sample_mapping} 185 | tokenized_examples["start_positions"] = [] 186 | tokenized_examples["end_positions"] = [] 187 | 188 | for i, offsets in enumerate(offset_mapping): 189 | # We will label impossible answers with the index of the CLS token. 190 | input_ids = tokenized_examples["input_ids"][i] 191 | # cls_index = input_ids.index(tokenizer.cls_token_id) 192 | cls_index = 0 193 | 194 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 
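# (added note) sequence_ids here mirrors the HF convention, as built in
# overflow_to_sample_mapping above: 0 for question tokens, 1 for context
# tokens, None for the separator and padding, e.g. [0, 0, None, 1, 1, None].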
195 | sequence_ids = tokenized_examples['sequence_ids'][i] 196 | 197 | # One example can give several spans, this is the index of the example containing this span of text. 198 | sample_index = tokenized_examples["overflow_to_sample_mapping"][i] 199 | 200 | answers = examples["answers"][sample_index] 201 | # If no answers are given, set the cls_index as answer. 202 | if len(answers["answer_start"]) == 0: 203 | tokenized_examples["start_positions"].append(cls_index) 204 | tokenized_examples["end_positions"].append(cls_index) 205 | else: 206 | # Start/end character index of the answer in the text. 207 | start_char = answers["answer_start"][0] 208 | end_char = start_char + len(answers["text"][0]) 209 | 210 | # Start token index of the current span in the text. 211 | token_start_index = 0 212 | while sequence_ids[token_start_index] != 1: 213 | token_start_index += 1 214 | # End token index of the current span in the text. 215 | token_end_index = len(input_ids) - 1 216 | while sequence_ids[token_end_index] != 1: 217 | token_end_index -= 1 218 | 219 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). 220 | if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): 221 | tokenized_examples["start_positions"].append(cls_index) 222 | tokenized_examples["end_positions"].append(cls_index) 223 | 224 | else: 225 | # Otherwise move the token_start_index and token_end_index to the two ends of the answer. 226 | # Note: we could go after the last offset if the answer is the last word (edge case). 227 | while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: 228 | token_start_index += 1 229 | tokenized_examples["start_positions"].append(token_start_index - 1) 230 | while offsets[token_end_index][1] >= end_char: 231 | token_end_index -= 1 232 | tokenized_examples["end_positions"].append(token_end_index + 1) 233 | 234 | tokenized_examples["example_id"] = [] 235 | 236 | for i in range(len(tokenized_examples["input_ids"])): 237 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 238 | sequence_ids = tokenized_examples['sequence_ids'][i] 239 | context_index = 1 240 | 241 | # One example can give several spans, this is the index of the example containing this span of text. 242 | sample_index = sample_mapping[i] 243 | tokenized_examples["example_id"].append(examples["id"][sample_index]) 244 | 245 | # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token 246 | # position is part of the context or not. 
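# (added illustration) e.g. sequence_ids [0, 0, None, 1, 1, None] turns an
# offset_mapping [(0,3), (4,8), (0,0), (0,5), (6,9), (0,0)] into
# [None, None, None, (0,5), (6,9), None], so postprocess_qa_predictions can
# reject any candidate span whose start or end offset is None.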
247 | tokenized_examples["offset_mapping"][i] = [ 248 | (o if sequence_ids[k] == context_index else None) 249 | for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 250 | ] 251 | return tokenized_examples -------------------------------------------------------------------------------- /nmatheg/qa_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import logging 4 | import os 5 | from typing import Optional, Tuple 6 | from datasets import load_metric 7 | import numpy as np 8 | from tqdm.auto import tqdm 9 | 10 | def evaluate_metric(dataset, examples, all_start_logits, all_end_logits): 11 | metric = load_metric("squad") 12 | max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor 13 | 14 | # concatenate the numpy array 15 | start_logits_concat = create_and_fill_np_array(all_start_logits, dataset, max_len) 16 | end_logits_concat = create_and_fill_np_array(all_end_logits, dataset, max_len) 17 | 18 | # delete the list of numpy arrays 19 | del all_start_logits 20 | del all_end_logits 21 | 22 | outputs_numpy = (start_logits_concat, end_logits_concat) 23 | prediction = post_processing_function(examples, dataset, outputs_numpy) 24 | return metric.compute(predictions=prediction['predictions'], references=prediction['label_ids']) 25 | 26 | 27 | def post_processing_function(examples, features, predictions, stage="eval"): 28 | # Post-processing: we match the start logits and end logits to answers in the original context. 29 | predictions = postprocess_qa_predictions( 30 | examples=examples, 31 | features=features, 32 | predictions=predictions, 33 | prefix=stage, 34 | ) 35 | # Format the result to the format the metric expects. 36 | formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] 37 | 38 | references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] 39 | return {'predictions':formatted_predictions, 'label_ids':references} 40 | 41 | def create_and_fill_np_array(start_or_end_logits, dataset, max_len): 42 | """ 43 | Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor 44 | Args: 45 | start_or_end_logits(:obj:`tensor`): 46 | This is the output predictions of the model. We can only enter either start or end logits. 47 | eval_dataset: Evaluation dataset 48 | max_len(:obj:`int`): 49 | The maximum length of the output tensor. ( See the model.eval() part for more details ) 50 | """ 51 | 52 | step = 0 53 | # create a numpy array and fill it with -100. 
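# (added note) -100 is only a fill value for ragged tails: e.g. with
# max_len = 384, a batch of logits shaped (8, 380) fills columns 0..379 of
# its 8 rows and the remaining 4 columns stay at -100.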
54 | logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) 55 | # Now since we have create an array now we will populate it with the outputs gathered using accelerator.gather 56 | for i, output_logit in enumerate(start_or_end_logits): # populate columns 57 | # We have to fill it such that we have to take the whole tensor and replace it on the newly created array 58 | # And after every iteration we have to change the step 59 | 60 | batch_size = output_logit.shape[0] 61 | cols = output_logit.shape[1] 62 | 63 | if step + batch_size < len(dataset): 64 | logits_concat[step : step + batch_size, :cols] = output_logit 65 | else: 66 | logits_concat[step:, :cols] = output_logit[: len(dataset) - step] 67 | 68 | step += batch_size 69 | 70 | return logits_concat 71 | 72 | def postprocess_qa_predictions( 73 | examples, 74 | features, 75 | predictions: Tuple[np.ndarray, np.ndarray], 76 | version_2_with_negative: bool = False, 77 | n_best_size: int = 20, 78 | max_answer_length: int = 30, 79 | null_score_diff_threshold: float = 0.0, 80 | output_dir: Optional[str] = None, 81 | prefix: Optional[str] = None, 82 | log_level: Optional[int] = logging.WARNING, 83 | ): 84 | """ 85 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the 86 | original contexts. This is the base postprocessing functions for models that only return start and end logits. 87 | Args: 88 | examples: The non-preprocessed dataset (see the main script for more information). 89 | features: The processed dataset (see the main script for more information). 90 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 91 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 92 | first dimension must match the number of elements of :obj:`features`. 93 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 94 | Whether or not the underlying dataset contains examples with no answers. 95 | n_best_size (:obj:`int`, `optional`, defaults to 20): 96 | The total number of n-best predictions to generate when looking for an answer. 97 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 98 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 99 | are not conditioned on one another. 100 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): 101 | The threshold used to select the null answer: if the best answer has a score that is less than the score of 102 | the null answer minus this threshold, the null answer is selected for this example (note that the score of 103 | the null answer for an example giving several features is the minimum of the scores for the null answer on 104 | each feature: all features must be aligned on the fact they `want` to predict a null answer). 105 | Only useful when :obj:`version_2_with_negative` is :obj:`True`. 106 | output_dir (:obj:`str`, `optional`): 107 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 108 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 109 | answers, are saved in `output_dir`. 110 | prefix (:obj:`str`, `optional`): 111 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 
112 | log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): 113 | ``logging`` log level (e.g., ``logging.WARNING``) 114 | """ 115 | assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." 116 | all_start_logits, all_end_logits = predictions 117 | 118 | assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." 119 | 120 | # Build a map example to its corresponding features. 121 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 122 | features_per_example = collections.defaultdict(list) 123 | for i, feature in enumerate(features): 124 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 125 | 126 | # The dictionaries we have to fill. 127 | all_predictions = collections.OrderedDict() 128 | all_nbest_json = collections.OrderedDict() 129 | if version_2_with_negative: 130 | scores_diff_json = collections.OrderedDict() 131 | 132 | # Let's loop over all the examples! 133 | for example_index, example in enumerate(examples): 134 | # Those are the indices of the features associated to the current example. 135 | feature_indices = features_per_example[example_index] 136 | 137 | min_null_prediction = None 138 | prelim_predictions = [] 139 | 140 | # Looping through all the features associated to the current example. 141 | for feature_index in feature_indices: 142 | # We grab the predictions of the model for this feature. 143 | start_logits = all_start_logits[feature_index] 144 | end_logits = all_end_logits[feature_index] 145 | # This is what will allow us to map some the positions in our logits to span of texts in the original 146 | # context. 147 | offset_mapping = features[feature_index]["offset_mapping"] 148 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 149 | # available in the current feature. 150 | token_is_max_context = features[feature_index].get("token_is_max_context", None) 151 | 152 | # Update minimum null prediction. 153 | feature_null_score = start_logits[0] + end_logits[0] 154 | if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: 155 | min_null_prediction = { 156 | "offsets": (0, 0), 157 | "score": feature_null_score, 158 | "start_logit": start_logits[0], 159 | "end_logit": end_logits[0], 160 | } 161 | 162 | # Go through all possibilities for the `n_best_size` greater start and end logits. 163 | start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() 164 | end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() 165 | for start_index in start_indexes: 166 | for end_index in end_indexes: 167 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 168 | # to part of the input_ids that are not in the context. 169 | if ( 170 | start_index >= len(offset_mapping) 171 | or end_index >= len(offset_mapping) 172 | or offset_mapping[start_index] is None 173 | or offset_mapping[end_index] is None 174 | ): 175 | continue 176 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 177 | if end_index < start_index or end_index - start_index + 1 > max_answer_length: 178 | continue 179 | # Don't consider answer that don't have the maximum context available (if such information is 180 | # provided). 
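# (added note) token_is_max_context is optional metadata: when a token
# appears in several overlapping windows, it is True only in the window
# where that token has the most surrounding context, so the same span is
# not proposed once per window.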
181 | if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): 182 | continue 183 | prelim_predictions.append( 184 | { 185 | "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), 186 | "score": start_logits[start_index] + end_logits[end_index], 187 | "start_logit": start_logits[start_index], 188 | "end_logit": end_logits[end_index], 189 | } 190 | ) 191 | if version_2_with_negative: 192 | # Add the minimum null prediction 193 | prelim_predictions.append(min_null_prediction) 194 | null_score = min_null_prediction["score"] 195 | 196 | # Only keep the best `n_best_size` predictions. 197 | predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] 198 | 199 | # Add back the minimum null prediction if it was removed because of its low score. 200 | if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): 201 | predictions.append(min_null_prediction) 202 | 203 | # Use the offsets to gather the answer text in the original context. 204 | context = example["context"] 205 | for pred in predictions: 206 | offsets = pred.pop("offsets") 207 | pred["text"] = context[offsets[0] : offsets[1]] 208 | 209 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 210 | # failure. 211 | if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): 212 | predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) 213 | 214 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 215 | # the LogSumExp trick). 216 | scores = np.array([pred.pop("score") for pred in predictions]) 217 | exp_scores = np.exp(scores - np.max(scores)) 218 | probs = exp_scores / exp_scores.sum() 219 | 220 | # Include the probabilities in our predictions. 221 | for prob, pred in zip(probs, predictions): 222 | pred["probability"] = prob 223 | 224 | # Pick the best prediction. If the null answer is not possible, this is easy. 225 | if not version_2_with_negative: 226 | all_predictions[example["id"]] = predictions[0]["text"] 227 | else: 228 | # Otherwise we first need to find the best non-empty prediction. 229 | i = 0 230 | while predictions[i]["text"] == "": 231 | i += 1 232 | best_non_null_pred = predictions[i] 233 | 234 | # Then we compare to the null prediction using the threshold. 235 | score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] 236 | scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. 237 | if score_diff > null_score_diff_threshold: 238 | all_predictions[example["id"]] = "" 239 | else: 240 | all_predictions[example["id"]] = best_non_null_pred["text"] 241 | 242 | # Make `predictions` JSON-serializable by casting np.float back to float. 
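# (added illustration) e.g. np.float32(0.8731) -> 0.8731 as a plain Python
# float; without the cast, json.dump on this dict would raise
# "TypeError: Object of type float32 is not JSON serializable".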
243 |         all_nbest_json[example["id"]] = [
244 |             {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
245 |             for pred in predictions
246 |         ]
247 |     return all_predictions
--------------------------------------------------------------------------------
/nmatheg/tests.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | 
3 | # Smoke test: make sure the [preprocessing] section of the config is readable.
4 | config = configparser.ConfigParser()
5 | config.read('config.ini')
6 | print(config['preprocessing']['segment'])
--------------------------------------------------------------------------------
/nmatheg/utils.py:
--------------------------------------------------------------------------------
1 | import tkseem as tk
2 | # bpe_surgery is an optional dependency; without it, only the tkseem tokenizers are available.
3 | try:
4 |     import bpe_surgery
5 | except ImportError:
6 |     pass
7 | import json
8 | 
9 | def get_tokenizer(tok_name, vocab_size=300, lang='ar'):
10 |     # Map a tokenizer name from the config to a tkseem or bpe_surgery instance.
11 |     if tok_name == "WordTokenizer":
12 |         return tk.WordTokenizer(vocab_size=vocab_size)
13 |     elif tok_name == "SentencePieceTokenizer":
14 |         return tk.SentencePieceTokenizer(vocab_size=vocab_size)
15 |     elif tok_name == "CharacterTokenizer":
16 |         return tk.CharacterTokenizer(vocab_size=vocab_size)
17 |     elif tok_name == "RandomTokenizer":
18 |         return tk.RandomTokenizer(vocab_size=vocab_size)
19 |     elif tok_name == "DisjointLetterTokenizer":
20 |         return tk.DisjointLetterTokenizer(vocab_size=vocab_size)
21 |     elif tok_name == "MorphologicalTokenizer":
22 |         return tk.MorphologicalTokenizer(vocab_size=vocab_size)
23 |     elif tok_name == "BPE":
24 |         return bpe_surgery.bpe(vocab_size=vocab_size)
25 |     elif tok_name == "MaT-BPE":
26 |         return bpe_surgery.bpe(vocab_size=vocab_size, morph=True)
27 |     elif tok_name == "Seg-BPE":
28 |         return bpe_surgery.bpe(vocab_size=vocab_size, seg=True)
29 |     else:
30 |         raise ValueError('Unrecognized tokenizer name!')
31 | 
32 | def get_preprocessing_args(config):
33 |     # Collect the [preprocessing] section as kwargs, converting the string
34 |     # literals 'True', 'False' and '[]' to their Python equivalents.
35 |     args = {}
36 |     map_bool = {'True': True, 'False': False, '[]': []}
37 |     for key in config['preprocessing']:
38 |         val = config['preprocessing'][key]
39 |         if val in map_bool:
40 |             args[key] = map_bool[val]
41 |         else:
42 |             args[key] = val
43 |     return args
44 | 
45 | def save_json(ob, save_path):
46 |     with open(save_path, 'w') as handle:
47 |         json.dump(dict(ob), handle)
--------------------------------------------------------------------------------
/nmatheg_logo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ARBML/nmatheg/209285b0b30780e2bf2b4a6a272cc9b2ac8ba95b/nmatheg_logo.PNG
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from nmatheg import predict_from_run
2 | import argparse
3 | from datasets import load_dataset
4 | import json
5 | 
6 | # Parse the run directory and the index of the training example to predict on.
7 | my_parser = argparse.ArgumentParser()
8 | my_parser.add_argument('-p', '--path', type=str, required=True)
9 | my_parser.add_argument('-n', '--num', type=int, required=True)
10 | 
11 | args = my_parser.parse_args()
12 | data_config = json.load(open(f"{args.path}/data/data_config.json"))
13 | data = load_dataset(data_config["name"])
14 | src, trg = data_config['text'].split(',')
15 | out = predict_from_run(args.path, run=0, sentence=data['train'][args.num][src])
16 | print(out[0])
17 | print({'gold_text': data['train'][args.num][trg]})
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tnkeeh
2 | tkseem
3 | tfds-nightly
4 | datasets
5 | transformers[sentencepiece]
6 | accelerate
7 | seqeval
8 | sacrebleu
9 | rouge_score_ar @ git+https://github.com/ARBML/rouge_score_ar
10 | evaluate
11 | pandas
12 | fsspec==2021.10
13 | s3fs==2021.10
--------------------------------------------------------------------------------
/script.py:
--------------------------------------------------------------------------------
1 | import nmatheg as nm
2 | 
3 | # Minimal benchmark run: one BiRNN model on the 'caner' dataset with a BPE
4 | # tokenizer and a 1000-token vocabulary.
5 | strategy = nm.TrainStrategy(
6 |     datasets = 'caner',
7 |     models = 'birnn',
8 |     tokenizers = 'bpe',
9 |     vocab_sizes = '1000',
10 |     runs = 1,
11 |     lr = 1e-4,
12 |     epochs = 50,
13 |     batch_size = 128,
14 |     max_tokens = 128,
15 | )
16 | output = strategy.start()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | with open('requirements.txt') as f:
4 |     required = f.read().splitlines()
5 | 
6 | with open('README.md') as readme_file:
7 |     readme = readme_file.read()
8 | 
9 | setup(name='nmatheg',
10 |       version='0.0.4',
11 |       url='',
12 |       description="Arabic Training Strategy For NLP Models",
13 |       long_description=readme,
14 |       long_description_content_type='text/markdown',
15 |       author='Zaid Alyafeai, Maged Saeed',
16 |       author_email='arabicmachinelearning@gmail.com',
17 |       license='MIT',
18 |       packages=['nmatheg'],
19 |       install_requires=required,
20 |       python_requires=">=3.6",
21 |       include_package_data=True,
22 |       zip_safe=False,
23 |       )
--------------------------------------------------------------------------------
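Usage note (illustrative, not part of the repository): the snippet below sketches the data shapes that the QA postprocessing in nmatheg/qa_utils.py expects. It assumes the function is exposed as postprocess_qa_predictions with the Hugging Face-style defaults it appears to follow (version_2_with_negative=False, n_best_size=20, max_answer_length=30); the example, offsets, and logits are all made-up toy data.

import numpy as np
from datasets import Dataset
from nmatheg.qa_utils import postprocess_qa_predictions  # assumed name/path

# One example split into a single feature; token 0 plays the role of [CLS],
# so its (0, 0) offset corresponds to the null answer.
context = "Paris is the capital of France"
examples = Dataset.from_dict({"id": ["ex0"], "context": [context]})
offsets = [(0, 0), (0, 5), (6, 8), (9, 12), (13, 20), (21, 23), (24, 30)]
features = Dataset.from_dict({"example_id": ["ex0"], "offset_mapping": [offsets]})

# Fake logits that favour token 1 as both start and end of the answer,
# i.e. the span "Paris" (characters 0..5 of the context).
start_logits = np.array([[0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
end_logits = np.array([[0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

preds = postprocess_qa_predictions(examples, features, (start_logits, end_logits))
print(preds)  # expected: OrderedDict([('ex0', 'Paris')])

In the real pipeline these inputs come from the tokenized dataset and the model's raw outputs; the sketch only makes the expected column names and array shapes concrete.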