├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── README.rst ├── docs ├── Makefile ├── __pycache__ │ └── conf.cpython-39.pyc ├── _build │ ├── doctrees │ │ ├── environment.pickle │ │ ├── examples.doctree │ │ ├── index.doctree │ │ ├── modules.doctree │ │ ├── nlpboost.augmentation.doctree │ │ ├── nlpboost.doctree │ │ ├── notebooks.doctree │ │ └── readme.doctree │ └── html │ │ ├── .buildinfo │ │ ├── _images │ │ ├── nlpboost_diagram.png │ │ └── nlpboost_logo_3.png │ │ ├── _modules │ │ ├── index.html │ │ └── nlpboost │ │ │ ├── augmentation │ │ │ ├── TextAugmenterPipeline.html │ │ │ └── augmenter_config.html │ │ │ ├── autotrainer.html │ │ │ ├── ckpt_cleaner.html │ │ │ ├── dataset_config.html │ │ │ ├── default_param_spaces.html │ │ │ ├── hfdatasets_manager.html │ │ │ ├── hftransformers_manager.html │ │ │ ├── metrics.html │ │ │ ├── metrics_plotter.html │ │ │ ├── model_config.html │ │ │ ├── results_getter.html │ │ │ ├── tokenization_functions.html │ │ │ └── utils.html │ │ ├── _sources │ │ ├── examples.rst.txt │ │ ├── index.rst.txt │ │ ├── modules.rst.txt │ │ ├── nlpboost.augmentation.rst.txt │ │ ├── nlpboost.rst.txt │ │ ├── notebooks.rst.txt │ │ └── readme.rst.txt │ │ ├── _static │ │ ├── _sphinx_javascript_frameworks_compat.js │ │ ├── basic.css │ │ ├── css │ │ │ ├── badge_only.css │ │ │ ├── fonts │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.svg │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ ├── lato-normal.woff │ │ │ │ └── lato-normal.woff2 │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── file.png │ │ ├── jquery-3.6.0.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── badge_only.js │ │ │ ├── html5shiv-printshiv.min.js │ │ │ ├── html5shiv.min.js │ │ │ └── theme.js │ │ ├── language_data.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── sphinx_highlight.js │ │ ├── twemoji.css │ │ ├── twemoji.js │ │ ├── underscore-1.13.1.js │ │ └── underscore.js │ │ ├── examples.html │ │ ├── genindex.html │ │ ├── index.html │ │ ├── modules.html │ │ ├── nlpboost.augmentation.html │ │ ├── nlpboost.html │ │ ├── notebooks.html │ │ ├── objects.inv │ │ ├── py-modindex.html │ │ ├── readme.html │ │ ├── search.html │ │ └── searchindex.js ├── conf.py ├── examples.rst ├── index.rst ├── make.bat ├── modules.rst ├── nlpboost.augmentation.rst ├── nlpboost.rst ├── notebooks.rst ├── readme.rst └── requirements.txt ├── examples ├── NER │ └── train_spanish_ner.py ├── README.md ├── README.rst ├── classification │ ├── train_classification.py │ └── train_multilabel.py ├── extractive_qa │ └── train_sqac.py └── seq2seq │ ├── train_maria_encoder_decoder_marimari.py │ └── train_summarization_mlsum.py ├── imgs ├── nlpboost_diagram.png ├── nlpboost_logo.png ├── nlpboost_logo_2.png └── nlpboost_logo_3.png ├── notebooks ├── NER │ └── train_spanish_ner.ipynb ├── README.md ├── README.rst ├── classification │ ├── train_emotion_classification.ipynb │ └── train_multilabel.ipynb └── extractive_qa │ └── train_sqac.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py └── src └── nlpboost ├── __init__.py ├── 
augmentation ├── TextAugmenterPipeline.py ├── __init__.py ├── augmenter_config.py └── tests │ └── test_text_augmenter_pipeline.py ├── autotrainer.py ├── ckpt_cleaner.py ├── dataset_config.py ├── default_param_spaces.py ├── hfdatasets_manager.py ├── hftransformers_manager.py ├── metrics.py ├── metrics_plotter.py ├── model_config.py ├── results_getter.py ├── skip_mix.py ├── tests ├── test_autotrainer.py ├── test_ckpt_cleaner.py ├── test_dataset_config.py ├── test_general_utils.py ├── test_model_config.py └── test_tokenization_functions.py ├── tokenization_functions.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | *.xml 3 | .coverage 4 | cover/ 5 | *__pycache__* -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | 14 | 15 | # Build documentation in the docs/ directory with Sphinx 16 | sphinx: 17 | configuration: docs/conf.py 18 | 19 | # If using Sphinx, optionally build your docs in additional formats such as PDF 20 | # formats: 21 | # - pdf 22 | 23 | # Optionally declare the Python requirements required to build your docs 24 | python: 25 | install: 26 | - method: pip 27 | path: . 28 | - method: pip 29 | path: . -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alejandro Vaca 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.md 3 | include requirements.txt -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/__pycache__/conf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/__pycache__/conf.cpython-39.pyc -------------------------------------------------------------------------------- /docs/_build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/_build/doctrees/examples.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/examples.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/modules.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/modules.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/nlpboost.augmentation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/nlpboost.augmentation.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/nlpboost.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/nlpboost.doctree 
-------------------------------------------------------------------------------- /docs/_build/doctrees/notebooks.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/notebooks.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/readme.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/doctrees/readme.doctree -------------------------------------------------------------------------------- /docs/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 6604d45b75a36e898702f5d763d28cd3 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/_build/html/_images/nlpboost_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_images/nlpboost_diagram.png -------------------------------------------------------------------------------- /docs/_build/html/_images/nlpboost_logo_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_images/nlpboost_logo_3.png -------------------------------------------------------------------------------- /docs/_build/html/_modules/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Overview: module code — nlpboost documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 |
-------------------------------------------------------------------------------- /docs/_build/html/_modules/nlpboost/augmentation/augmenter_config.html: -------------------------------------------------------------------------------- nlpboost.augmentation.augmenter_config — nlpboost documentation

Source code for nlpboost.augmentation.augmenter_config

 72 | from dataclasses import dataclass, field
 73 | from typing import Dict, Any
 74 | import nlpaug.augmenter.char as nac
 75 | import nlpaug.augmenter.word as naw
 76 | import nlpaug.augmenter.sentence as nas
 77 | 
 78 | class_translator = {
 79 |     "ocr": nac.OcrAug,
 80 |     "contextual_w_e": naw.ContextualWordEmbsAug,
 81 |     "synonym": naw.SynonymAug,
 82 |     "backtranslation": naw.BackTranslationAug,
 83 |     "contextual_s_e": nas.ContextualWordEmbsForSentenceAug,
 84 |     "abstractive_summ": nas.AbstSummAug,
 85 | }
 86 | 
 87 | 
 88 | @dataclass
 89 | class NLPAugConfig:
 90 |     """
 91 |     Configuration for augmenters.
 92 | 
 93 |     Parameters
 94 |     ----------
 95 |     name : str
 96 |         Name of the data augmentation technique. Possible values currently are `ocr` (for OCR augmentation), `contextual_w_e`
 97 |         for Contextual Word Embedding augmentation, `synonym`, `backtranslation`, `contextual_s_e` for Contextual Word Embeddings for Sentence Augmentation,
 98 |         `abstractive_summ`. If using a custom augmenter class this can be a random name.
 99 |     augmenter_cls: Any
100 |         An optional augmenter class, from `nlpaug` library. Can be used instead of using an identifier name
101 |         for loading the class (see param `name` of this class).
102 |     proportion : float
103 |         Proportion of data augmentation.
104 |     aug_kwargs : Dict
105 |         Arguments for the data augmentation class. See https://github.com/makcedward/nlpaug/blob/master/example/textual_augmenter.ipynb
106 |     """
107 | 
108 |     name: str = field(metadata={"help": "Name of the data augmentation technique. If using a custom augmenter class this can be a random name."})
109 |     augmenter_cls: Any = field(
110 |         default=None,
111 |         metadata={"help": "An optional augmenter class, from `nlpaug` library. Can be used instead of using an identifier name for loading the class (see param `name` of this class)."}
112 |     )
113 |     proportion: float = field(
114 |         default=0.1, metadata={"help": "proportion of data augmentation"}
115 |     )
116 |     aug_kwargs: Dict = field(
117 |         default=None,
118 |         metadata={
119 |             "help": "Arguments for the data augmentation class. See https://github.com/makcedward/nlpaug/blob/master/example/textual_augmenter.ipynb"
120 |         },
121 |     )
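
The listing above shows the identifiers that `class_translator` maps to `nlpaug` augmenter classes, and the four fields of `NLPAugConfig`. As a rough illustration, constructing such configs might look like the sketch below; the import path, the `aug_kwargs` values and the `KeyboardAug` variant are assumptions made for the example, not taken from nlpboost's own documentation.

# Hypothetical usage sketch for NLPAugConfig (illustrative assumptions, not from the nlpboost docs).
# Assumes NLPAugConfig is re-exported from the nlpboost.augmentation package.
from nlpboost.augmentation import NLPAugConfig

import nlpaug.augmenter.char as nac

# Built-in identifier: "synonym" is resolved to naw.SynonymAug through class_translator.
synonym_cfg = NLPAugConfig(
    name="synonym",
    proportion=0.2,                     # augment roughly 20% of the examples
    aug_kwargs={"aug_src": "wordnet"},  # assumed kwargs, forwarded to the nlpaug augmenter
)

# Custom augmenter: pass the nlpaug class via augmenter_cls; the name is then free-form.
keyboard_cfg = NLPAugConfig(
    name="keyboard_noise",
    augmenter_cls=nac.KeyboardAug,
    proportion=0.1,
    aug_kwargs={"aug_char_p": 0.1},     # assumed kwargs for nlpaug's KeyboardAug
)
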
144 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/examples.rst.txt: -------------------------------------------------------------------------------- 1 | .. include:: ../examples/README.rst 2 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. nlpboost documentation master file, created by 2 | sphinx-quickstart on Fri Dec 30 02:02:16 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to nlpboost's documentation! 7 | ==================================== 8 | 9 | .. image:: ../imgs/nlpboost_logo_3.png 10 | :target: ../imgs/nlpboost_logo_3.png 11 | :width: 500 12 | :alt: nlpboost logo 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: Contents: 17 | 18 | readme 19 | examples 20 | notebooks 21 | modules 22 | 23 | 24 | Indices and tables 25 | ================== 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/modules.rst.txt: -------------------------------------------------------------------------------- 1 | nlpboost 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | nlpboost 8 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/nlpboost.augmentation.rst.txt: -------------------------------------------------------------------------------- 1 | nlpboost.augmentation package 2 | ============================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | nlpboost.augmentation.TextAugmenterPipeline module 8 | -------------------------------------------------- 9 | 10 | .. automodule:: nlpboost.augmentation.TextAugmenterPipeline 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | nlpboost.augmentation.augmenter\_config module 16 | ---------------------------------------------- 17 | 18 | .. automodule:: nlpboost.augmentation.augmenter_config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: nlpboost.augmentation 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/nlpboost.rst.txt: -------------------------------------------------------------------------------- 1 | nlpboost package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | nlpboost.augmentation 11 | 12 | Submodules 13 | ---------- 14 | 15 | nlpboost.autotrainer module 16 | --------------------------- 17 | 18 | .. automodule:: nlpboost.autotrainer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | nlpboost.ckpt\_cleaner module 24 | ----------------------------- 25 | 26 | .. automodule:: nlpboost.ckpt_cleaner 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | nlpboost.dataset\_config module 32 | ------------------------------- 33 | 34 | .. automodule:: nlpboost.dataset_config 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | nlpboost.default\_param\_spaces module 40 | -------------------------------------- 41 | 42 | .. 
automodule:: nlpboost.default_param_spaces 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | nlpboost.hfdatasets\_manager module 48 | ----------------------------------- 49 | 50 | .. automodule:: nlpboost.hfdatasets_manager 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | nlpboost.hftransformers\_manager module 56 | --------------------------------------- 57 | 58 | .. automodule:: nlpboost.hftransformers_manager 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | nlpboost.metrics module 64 | ----------------------- 65 | 66 | .. automodule:: nlpboost.metrics 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | nlpboost.metrics\_plotter module 72 | -------------------------------- 73 | 74 | .. automodule:: nlpboost.metrics_plotter 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | nlpboost.model\_config module 80 | ----------------------------- 81 | 82 | .. automodule:: nlpboost.model_config 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | nlpboost.results\_getter module 88 | ------------------------------- 89 | 90 | .. automodule:: nlpboost.results_getter 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | nlpboost.tokenization\_functions module 96 | --------------------------------------- 97 | 98 | .. automodule:: nlpboost.tokenization_functions 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | nlpboost.utils module 104 | --------------------- 105 | 106 | .. automodule:: nlpboost.utils 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | Module contents 112 | --------------- 113 | 114 | .. automodule:: nlpboost 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/notebooks.rst.txt: -------------------------------------------------------------------------------- 1 | .. include:: ../notebooks/README.rst 2 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/readme.rst.txt: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/_build/html/_static/_sphinx_javascript_frameworks_compat.js: -------------------------------------------------------------------------------- 1 | /* 2 | * _sphinx_javascript_frameworks_compat.js 3 | * ~~~~~~~~~~ 4 | * 5 | * Compatability shim for jQuery and underscores.js. 6 | * 7 | * WILL BE REMOVED IN Sphinx 6.0 8 | * xref RemovedInSphinx60Warning 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | 18 | /** 19 | * small helper function to urldecode strings 20 | * 21 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL 22 | */ 23 | jQuery.urldecode = function(x) { 24 | if (!x) { 25 | return x 26 | } 27 | return decodeURIComponent(x.replace(/\+/g, ' ')); 28 | }; 29 | 30 | /** 31 | * small helper function to urlencode strings 32 | */ 33 | jQuery.urlencode = encodeURIComponent; 34 | 35 | /** 36 | * This function returns the parsed url parameters of the 37 | * current request. Multiple values per key are supported, 38 | * it will always return arrays of strings for the value parts. 
39 | */ 40 | jQuery.getQueryParameters = function(s) { 41 | if (typeof s === 'undefined') 42 | s = document.location.search; 43 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 44 | var result = {}; 45 | for (var i = 0; i < parts.length; i++) { 46 | var tmp = parts[i].split('=', 2); 47 | var key = jQuery.urldecode(tmp[0]); 48 | var value = jQuery.urldecode(tmp[1]); 49 | if (key in result) 50 | result[key].push(value); 51 | else 52 | result[key] = [value]; 53 | } 54 | return result; 55 | }; 56 | 57 | /** 58 | * highlight a given string on a jquery object by wrapping it in 59 | * span elements with the given class name. 60 | */ 61 | jQuery.fn.highlightText = function(text, className) { 62 | function highlight(node, addItems) { 63 | if (node.nodeType === 3) { 64 | var val = node.nodeValue; 65 | var pos = val.toLowerCase().indexOf(text); 66 | if (pos >= 0 && 67 | !jQuery(node.parentNode).hasClass(className) && 68 | !jQuery(node.parentNode).hasClass("nohighlight")) { 69 | var span; 70 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 71 | if (isInSVG) { 72 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 73 | } else { 74 | span = document.createElement("span"); 75 | span.className = className; 76 | } 77 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 78 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 79 | document.createTextNode(val.substr(pos + text.length)), 80 | node.nextSibling)); 81 | node.nodeValue = val.substr(0, pos); 82 | if (isInSVG) { 83 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 84 | var bbox = node.parentElement.getBBox(); 85 | rect.x.baseVal.value = bbox.x; 86 | rect.y.baseVal.value = bbox.y; 87 | rect.width.baseVal.value = bbox.width; 88 | rect.height.baseVal.value = bbox.height; 89 | rect.setAttribute('class', className); 90 | addItems.push({ 91 | "parent": node.parentNode, 92 | "target": rect}); 93 | } 94 | } 95 | } 96 | else if (!jQuery(node).is("button, select, textarea")) { 97 | jQuery.each(node.childNodes, function() { 98 | highlight(this, addItems); 99 | }); 100 | } 101 | } 102 | var addItems = []; 103 | var result = this.each(function() { 104 | highlight(this, addItems); 105 | }); 106 | for (var i = 0; i < addItems.length; ++i) { 107 | jQuery(addItems[i].parent).before(addItems[i].target); 108 | } 109 | return result; 110 | }; 111 | 112 | /* 113 | * backward compatibility for jQuery.browser 114 | * This will be supported until firefox bug is fixed. 115 | */ 116 | if (!jQuery.browser) { 117 | jQuery.uaMatch = function(ua) { 118 | ua = ua.toLowerCase(); 119 | 120 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 121 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 122 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 123 | /(msie) ([\w.]+)/.exec(ua) || 124 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? 
rv:([\w.]+)|)/.exec(ua) || 125 | []; 126 | 127 | return { 128 | browser: match[ 1 ] || "", 129 | version: match[ 2 ] || "0" 130 | }; 131 | }; 132 | jQuery.browser = {}; 133 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 134 | } 135 | -------------------------------------------------------------------------------- /docs/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 
6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-bold-italic.woff 
-------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Base JavaScript utilities for all Sphinx HTML documentation. 6 | * 7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([ 14 | "TEXTAREA", 15 | "INPUT", 16 | "SELECT", 17 | "BUTTON", 18 | ]); 19 | 20 | const _ready = (callback) => { 21 | if (document.readyState !== "loading") { 22 | callback(); 23 | } else { 24 | document.addEventListener("DOMContentLoaded", callback); 25 | } 26 | }; 27 | 28 | /** 29 | * Small JavaScript module for the documentation. 
30 | */ 31 | const Documentation = { 32 | init: () => { 33 | Documentation.initDomainIndexTable(); 34 | Documentation.initOnKeyListeners(); 35 | }, 36 | 37 | /** 38 | * i18n support 39 | */ 40 | TRANSLATIONS: {}, 41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1), 42 | LOCALE: "unknown", 43 | 44 | // gettext and ngettext don't access this so that the functions 45 | // can safely bound to a different name (_ = Documentation.gettext) 46 | gettext: (string) => { 47 | const translated = Documentation.TRANSLATIONS[string]; 48 | switch (typeof translated) { 49 | case "undefined": 50 | return string; // no translation 51 | case "string": 52 | return translated; // translation exists 53 | default: 54 | return translated[0]; // (singular, plural) translation tuple exists 55 | } 56 | }, 57 | 58 | ngettext: (singular, plural, n) => { 59 | const translated = Documentation.TRANSLATIONS[singular]; 60 | if (typeof translated !== "undefined") 61 | return translated[Documentation.PLURAL_EXPR(n)]; 62 | return n === 1 ? singular : plural; 63 | }, 64 | 65 | addTranslations: (catalog) => { 66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages); 67 | Documentation.PLURAL_EXPR = new Function( 68 | "n", 69 | `return (${catalog.plural_expr})` 70 | ); 71 | Documentation.LOCALE = catalog.locale; 72 | }, 73 | 74 | /** 75 | * helper function to focus on search bar 76 | */ 77 | focusSearchBar: () => { 78 | document.querySelectorAll("input[name=q]")[0]?.focus(); 79 | }, 80 | 81 | /** 82 | * Initialise the domain index toggle buttons 83 | */ 84 | initDomainIndexTable: () => { 85 | const toggler = (el) => { 86 | const idNumber = el.id.substr(7); 87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); 88 | if (el.src.substr(-9) === "minus.png") { 89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; 90 | toggledRows.forEach((el) => (el.style.display = "none")); 91 | } else { 92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; 93 | toggledRows.forEach((el) => (el.style.display = "")); 94 | } 95 | }; 96 | 97 | const togglerElements = document.querySelectorAll("img.toggler"); 98 | togglerElements.forEach((el) => 99 | el.addEventListener("click", (event) => toggler(event.currentTarget)) 100 | ); 101 | togglerElements.forEach((el) => (el.style.display = "")); 102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); 103 | }, 104 | 105 | initOnKeyListeners: () => { 106 | // only install a listener if it is really needed 107 | if ( 108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS 110 | ) 111 | return; 112 | 113 | document.addEventListener("keydown", (event) => { 114 | // bail for input elements 115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 116 | // bail with special keys 117 | if (event.altKey || event.ctrlKey || event.metaKey) return; 118 | 119 | if (!event.shiftKey) { 120 | switch (event.key) { 121 | case "ArrowLeft": 122 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 123 | 124 | const prevLink = document.querySelector('link[rel="prev"]'); 125 | if (prevLink && prevLink.href) { 126 | window.location.href = prevLink.href; 127 | event.preventDefault(); 128 | } 129 | break; 130 | case "ArrowRight": 131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 132 | 133 | const nextLink = document.querySelector('link[rel="next"]'); 134 | if (nextLink && nextLink.href) { 135 | window.location.href = nextLink.href; 136 | event.preventDefault(); 137 | } 138 | 
break; 139 | } 140 | } 141 | 142 | // some keyboard layouts may need Shift to get / 143 | switch (event.key) { 144 | case "/": 145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 146 | Documentation.focusSearchBar(); 147 | event.preventDefault(); 148 | } 149 | }); 150 | }, 151 | }; 152 | 153 | // quick alias for translations 154 | const _ = Documentation.gettext; 155 | 156 | _ready(Documentation.init); 157 | -------------------------------------------------------------------------------- /docs/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'en', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false, 12 | SHOW_SEARCH_SUMMARY: true, 13 | ENABLE_SEARCH_SHORTCUTS: true, 14 | }; -------------------------------------------------------------------------------- /docs/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/_build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return 
e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof 
c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_build/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return 
n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp = re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | -------------------------------------------------------------------------------- /docs/_build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/minus.png 
-------------------------------------------------------------------------------- /docs/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #f8f8f8; } 8 | .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 10 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 11 | .highlight .o { color: #666666 } /* Operator */ 12 | .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ 13 | .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ 14 | .highlight .cp { color: #9C6500 } /* Comment.Preproc */ 15 | .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 16 | .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ 17 | .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ 18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 19 | .highlight .ge { font-style: italic } /* Generic.Emph */ 20 | .highlight .gr { color: #E40000 } /* Generic.Error */ 21 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 22 | .highlight .gi { color: #008400 } /* Generic.Inserted */ 23 | .highlight .go { color: #717171 } /* Generic.Output */ 24 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 25 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 26 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 27 | .highlight .gt { color: #0044DD } /* Generic.Traceback */ 28 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 29 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 30 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 31 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 32 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 33 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 34 | .highlight .m { color: #666666 } /* Literal.Number */ 35 | .highlight .s { color: #BA2121 } /* Literal.String */ 36 | .highlight .na { color: #687822 } /* Name.Attribute */ 37 | .highlight .nb { color: #008000 } /* Name.Builtin */ 38 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 39 | .highlight .no { color: #880000 } /* Name.Constant */ 40 | .highlight .nd { color: #AA22FF } /* Name.Decorator */ 41 | .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */ 42 | .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ 43 | .highlight .nf { color: 
#0000FF } /* Name.Function */ 44 | .highlight .nl { color: #767600 } /* Name.Label */ 45 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 46 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 47 | .highlight .nv { color: #19177C } /* Name.Variable */ 48 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 49 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 50 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */ 51 | .highlight .mf { color: #666666 } /* Literal.Number.Float */ 52 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */ 53 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */ 54 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */ 55 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 56 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ 57 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 58 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 59 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 60 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 61 | .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ 62 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 63 | .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ 64 | .highlight .sx { color: #008000 } /* Literal.String.Other */ 65 | .highlight .sr { color: #A45A77 } /* Literal.String.Regex */ 66 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ 67 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */ 68 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ 69 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */ 70 | .highlight .vc { color: #19177C } /* Name.Variable.Class */ 71 | .highlight .vg { color: #19177C } /* Name.Variable.Global */ 72 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */ 73 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */ 74 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/_build/html/_static/sphinx_highlight.js: -------------------------------------------------------------------------------- 1 | /* Highlighting utilities for Sphinx HTML documentation. */ 2 | "use strict"; 3 | 4 | const SPHINX_HIGHLIGHT_ENABLED = true 5 | 6 | /** 7 | * highlight a given string on a node by wrapping it in 8 | * span elements with the given class name. 
9 | */ 10 | const _highlight = (node, addItems, text, className) => { 11 | if (node.nodeType === Node.TEXT_NODE) { 12 | const val = node.nodeValue; 13 | const parent = node.parentNode; 14 | const pos = val.toLowerCase().indexOf(text); 15 | if ( 16 | pos >= 0 && 17 | !parent.classList.contains(className) && 18 | !parent.classList.contains("nohighlight") 19 | ) { 20 | let span; 21 | 22 | const closestNode = parent.closest("body, svg, foreignObject"); 23 | const isInSVG = closestNode && closestNode.matches("svg"); 24 | if (isInSVG) { 25 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 26 | } else { 27 | span = document.createElement("span"); 28 | span.classList.add(className); 29 | } 30 | 31 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 32 | parent.insertBefore( 33 | span, 34 | parent.insertBefore( 35 | document.createTextNode(val.substr(pos + text.length)), 36 | node.nextSibling 37 | ) 38 | ); 39 | node.nodeValue = val.substr(0, pos); 40 | 41 | if (isInSVG) { 42 | const rect = document.createElementNS( 43 | "http://www.w3.org/2000/svg", 44 | "rect" 45 | ); 46 | const bbox = parent.getBBox(); 47 | rect.x.baseVal.value = bbox.x; 48 | rect.y.baseVal.value = bbox.y; 49 | rect.width.baseVal.value = bbox.width; 50 | rect.height.baseVal.value = bbox.height; 51 | rect.setAttribute("class", className); 52 | addItems.push({ parent: parent, target: rect }); 53 | } 54 | } 55 | } else if (node.matches && !node.matches("button, select, textarea")) { 56 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); 57 | } 58 | }; 59 | const _highlightText = (thisNode, text, className) => { 60 | let addItems = []; 61 | _highlight(thisNode, addItems, text, className); 62 | addItems.forEach((obj) => 63 | obj.parent.insertAdjacentElement("beforebegin", obj.target) 64 | ); 65 | }; 66 | 67 | /** 68 | * Small JavaScript module for the documentation. 69 | */ 70 | const SphinxHighlight = { 71 | 72 | /** 73 | * highlight the search words provided in localstorage in the text 74 | */ 75 | highlightSearchWords: () => { 76 | if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight 77 | 78 | // get and clear terms from localstorage 79 | const url = new URL(window.location); 80 | const highlight = 81 | localStorage.getItem("sphinx_highlight_terms") 82 | || url.searchParams.get("highlight") 83 | || ""; 84 | localStorage.removeItem("sphinx_highlight_terms") 85 | url.searchParams.delete("highlight"); 86 | window.history.replaceState({}, "", url); 87 | 88 | // get individual terms from highlight string 89 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); 90 | if (terms.length === 0) return; // nothing to do 91 | 92 | // There should never be more than one element matching "div.body" 93 | const divBody = document.querySelectorAll("div.body"); 94 | const body = divBody.length ? 
divBody[0] : document.querySelector("body"); 95 | window.setTimeout(() => { 96 | terms.forEach((term) => _highlightText(body, term, "highlighted")); 97 | }, 10); 98 | 99 | const searchBox = document.getElementById("searchbox"); 100 | if (searchBox === null) return; 101 | searchBox.appendChild( 102 | document 103 | .createRange() 104 | .createContextualFragment( 105 | '" 109 | ) 110 | ); 111 | }, 112 | 113 | /** 114 | * helper function to hide the search marks again 115 | */ 116 | hideSearchWords: () => { 117 | document 118 | .querySelectorAll("#searchbox .highlight-link") 119 | .forEach((el) => el.remove()); 120 | document 121 | .querySelectorAll("span.highlighted") 122 | .forEach((el) => el.classList.remove("highlighted")); 123 | localStorage.removeItem("sphinx_highlight_terms") 124 | }, 125 | 126 | initEscapeListener: () => { 127 | // only install a listener if it is really needed 128 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; 129 | 130 | document.addEventListener("keydown", (event) => { 131 | // bail for input elements 132 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 133 | // bail with special keys 134 | if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; 135 | if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { 136 | SphinxHighlight.hideSearchWords(); 137 | event.preventDefault(); 138 | } 139 | }); 140 | }, 141 | }; 142 | 143 | _ready(SphinxHighlight.highlightSearchWords); 144 | _ready(SphinxHighlight.initEscapeListener); 145 | -------------------------------------------------------------------------------- /docs/_build/html/_static/twemoji.css: -------------------------------------------------------------------------------- 1 | img.emoji { 2 | height: 1em; 3 | width: 1em; 4 | margin: 0 .05em 0 .1em; 5 | vertical-align: -0.1em; 6 | } 7 | -------------------------------------------------------------------------------- /docs/_build/html/_static/twemoji.js: -------------------------------------------------------------------------------- 1 | function addEvent(element, eventName, fn) { 2 | if (element.addEventListener) 3 | element.addEventListener(eventName, fn, false); 4 | else if (element.attachEvent) 5 | element.attachEvent('on' + eventName, fn); 6 | } 7 | 8 | addEvent(window, 'load', function() { 9 | twemoji.parse(document.body, {'folder': 'svg', 'ext': '.svg'}); 10 | }); 11 | -------------------------------------------------------------------------------- /docs/_build/html/examples.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Example scripts of how to use nlpboost for each task — nlpboost documentation 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 53 | 54 |
58 | 59 |
60 |
61 |
62 |
    63 |
  • 64 | 65 |
  • 66 | View page source 67 |
  • 68 |
69 |
70 |
71 |
72 |
73 | 74 |
75 |

Example scripts of how to use nlpboost for each task

76 |

In the examples folder you will find example scripts showing how to fine-tune models for different tasks. The tasks are divided into directories, as you can see. Every script also shows how to use ResultsPlotter to save a figure comparing the metrics of the trained models.
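To make the shared structure easier to see, here is a minimal sketch condensed from examples/classification/train_classification.py further down in this repository: it wires a DatasetConfig and a ModelConfig into AutoTrainer and then plots the results with ResultsPlotter. The fixed_training_args dict is trimmed for brevity (the real script passes many more TrainingArguments fields), and hp_space_base is the default hyperparameter space imported from nlpboost.default_param_spaces, as in that script.

```python
# Minimal sketch of the pattern shared by all example scripts
# (condensed from examples/classification/train_classification.py).
from nlpboost import DatasetConfig, ModelConfig, AutoTrainer, ResultsPlotter
from nlpboost.default_param_spaces import hp_space_base

# 1. Describe the dataset: task type, text/label columns and how to load it from the HF Hub.
dataset_config = DatasetConfig(
    dataset_name="tweeteval",
    alias="tweeteval",
    task="classification",
    text_field="text",
    label_col="label",
    hf_load_kwargs={"path": "tweet_eval", "name": "emotion"},
    seed=44,
    direction_optimize="maximize",
    metric_optimize="eval_f1-score",
    # Trimmed for the sketch; the full script passes the complete TrainingArguments dict.
    fixed_training_args={"evaluation_strategy": "steps", "num_train_epochs": 10},
)

# 2. Describe each model: checkpoint name, save name, hyperparameter space and number of trials.
model_config = ModelConfig(
    name="bertin-project/bertin-roberta-base-spanish",
    save_name="bertin",
    hp_space=hp_space_base,
    n_trials=1,
)

# 3. Run hyperparameter search for every (model, dataset) pair and collect metrics.
autotrainer = AutoTrainer(
    model_configs=[model_config],
    dataset_configs=[dataset_config],
    metrics_dir="tweeteval_metrics",
)
results = autotrainer()

# 4. Plot a metrics comparison figure for the trained models.
plotter = ResultsPlotter(
    metrics_dir=autotrainer.metrics_dir,
    model_names=[mc.save_name for mc in autotrainer.model_configs],
    dataset_to_task_map={dc.alias: dc.task for dc in autotrainer.dataset_configs},
)
ax = plotter.plot_metrics()
ax.figure.savefig("results.png")
```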

77 |
    78 |
  • 79 |
    classification

    For classification there are 2 examples. train_classification.py shows how to train a BERTIN model for multi-class classification (emotion detection; check the tweet_eval: emotion dataset for more info). train_multilabel.py, on the other hand, shows how to train a model on a multilabel task.

    80 |
    81 |
    82 |
  • 83 |
  • 84 |
    extractive_qa

    For extractive QA there is only 1 example, as this type of task looks very similar in all cases: train_sqac.py shows how to train a MarIA-large (Spanish RoBERTa-large) model on the SQAC dataset, with hyperparameter search.

    85 |
    86 |
    87 |
  • 88 |
  • 89 |
    NER

    For NER, there is an example script showing how to train multiple models on multiple NER datasets with different formats, where a pre_func has to be applied to one of the datasets (a condensed sketch of that pre_func configuration follows this list). The script is called train_spanish_ner.py.

    90 |
    91 |
    92 |
  • 93 |
  • 94 |
    seq2seq

    For this task, check out train_maria_encoder_decoder_marimari.py, which shows how to train a seq2seq model when no encoder-decoder architecture is readily available for a certain language, in this case Spanish. Also check out train_summarization_mlsum.py to learn how to configure training for two multilingual encoder-decoder models on the MLSUM summarization task.

    95 |
    96 |
    97 |
  • 98 |
99 |
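As a quick illustration of the pre_func mechanism mentioned in the NER item above, below is a minimal sketch condensed from examples/NER/train_spanish_ner.py: the ehealth_kd dataset does not come in the format nlpboost expects for NER, so dict_to_list is attached as a pre_func. The trimmed fixed_training_args dict is illustrative only; the actual script passes the full set of training arguments and an early-stopping callback.

```python
# How a dataset with a non-standard format is adapted via pre_func
# (condensed from examples/NER/train_spanish_ner.py).
from functools import partial
from nlpboost import DatasetConfig, dict_to_list

ehealth_config = DatasetConfig(
    dataset_name="ehealth_kd",
    alias="ehealth",
    task="ner",
    text_field="token_list",
    label_col="label_list",
    hf_load_kwargs={"path": "ehealth_kd"},
    # pre_func is applied to every example before tokenization, so a dataset whose
    # annotations come in a different format can be converted to the expected one.
    pre_func=partial(dict_to_list, nulltoken=100),
    seed=44,
    direction_optimize="maximize",
    metric_optimize="eval_f1-score",
    # Trimmed for the sketch; the full script passes the complete TrainingArguments dict.
    fixed_training_args={"evaluation_strategy": "steps", "num_train_epochs": 10},
)
```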

Important: For more detailed tutorials in Jupyter-Notebook format, please check the nlpboost notebooks. These tutorials explain all of the configuration options, which helps in getting to know the tool better. They are intended to provide a deep understanding of the different configurations needed for each task, so that users can easily adapt the scripts to their own tasks and needs.

100 |
101 | 102 | 103 |
104 |
105 |
109 | 110 |
111 | 112 |
113 |

© Copyright 2022, Alejandro Vaca.

114 |
115 | 116 | Built with Sphinx using a 117 | theme 118 | provided by Read the Docs. 119 | 120 | 121 |
122 |
123 |
124 |
125 |
126 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /docs/_build/html/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Welcome to nlpboost’s documentation! — nlpboost documentation 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 | 52 | 53 |
57 | 58 |
59 |
60 |
61 |
    62 |
  • 63 | 64 |
  • 65 | View page source 66 |
  • 67 |
68 |
69 |
70 |
71 |
72 | 73 |
74 |

Welcome to nlpboost’s documentation!

75 | nlpboost logo 76 | 97 |
98 |
99 |

Indices and tables

100 | 105 |
106 | 107 | 108 |
109 |
110 |
113 | 114 |
115 | 116 |
117 |

© Copyright 2022, Alejandro Vaca.

118 |
119 | 120 | Built with Sphinx using a 121 | theme 122 | provided by Read the Docs. 123 | 124 | 125 |
126 |
127 |
128 |
129 |
130 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /docs/_build/html/notebooks.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Notebook Tutorials on how to use nlpboost — nlpboost documentation 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 53 | 54 |
58 | 59 |
60 |
61 |
62 |
    63 |
  • 64 | 65 |
  • 66 | View page source 67 |
  • 68 |
69 |
70 |
71 |
72 |
73 | 74 |
75 |

Notebook Tutorials on how to use nlpboost

76 |

In this folder you will find examples of how to use nlpboost in a tutorial format, with text explanations for each of the steps needed to configure the experiments. The notebooks can also be run in Google Colab.

77 |

The examples are very similar to those in the examples folder, in the sense that the datasets chosen for each task are the same. However, the notebooks train more models for each dataset and include clear explanations of some key aspects of each task's configuration that the raw scripts lack.

78 |
79 | 80 | 81 |
82 |
83 | 100 |
101 |
102 |
103 |
104 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /docs/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/docs/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/_build/html/py-modindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Python Module Index — nlpboost documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 53 | 54 |
58 | 59 |
60 |
61 |
62 |
    63 |
  • 64 | 65 |
  • 66 |
  • 67 |
68 |
69 |
70 |
71 |
72 | 73 | 74 |

Python Module Index

75 | 76 |
77 | n 78 |
79 | 80 | 81 | 82 | 84 | 85 | 87 | 90 | 91 | 92 | 95 | 96 | 97 | 100 | 101 | 102 | 105 | 106 | 107 | 110 | 111 | 112 | 115 | 116 | 117 | 120 | 121 | 122 | 125 | 126 | 127 | 130 | 131 | 132 | 135 | 136 | 137 | 140 | 141 | 142 | 145 | 146 | 147 | 150 | 151 | 152 | 155 | 156 | 157 | 160 | 161 | 162 | 165 |
 
83 | n
88 | nlpboost 89 |
    93 | nlpboost.augmentation 94 |
    98 | nlpboost.augmentation.augmenter_config 99 |
    103 | nlpboost.augmentation.TextAugmenterPipeline 104 |
    108 | nlpboost.autotrainer 109 |
    113 | nlpboost.ckpt_cleaner 114 |
    118 | nlpboost.dataset_config 119 |
    123 | nlpboost.default_param_spaces 124 |
    128 | nlpboost.hfdatasets_manager 129 |
    133 | nlpboost.hftransformers_manager 134 |
    138 | nlpboost.metrics 139 |
    143 | nlpboost.metrics_plotter 144 |
    148 | nlpboost.model_config 149 |
    153 | nlpboost.results_getter 154 |
    158 | nlpboost.tokenization_functions 159 |
    163 | nlpboost.utils 164 |
166 | 167 | 168 |
169 |
170 |
171 | 172 |
173 | 174 |
175 |

© Copyright 2022, Alejandro Vaca.

176 |
177 | 178 | Built with Sphinx using a 179 | theme 180 | provided by Read the Docs. 181 | 182 | 183 |
184 |
185 |
186 |
187 |
188 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /docs/_build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Search — nlpboost documentation 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 53 | 54 |
58 | 59 |
60 |
61 |
62 |
    63 |
  • 64 | 65 |
  • 66 |
  • 67 |
68 |
69 |
70 |
71 |
72 | 73 | 80 | 81 | 82 |
83 | 84 |
85 | 86 |
87 |
88 |
89 | 90 |
91 | 92 |
93 |

© Copyright 2022, Alejandro Vaca.

94 |
95 | 96 | Built with Sphinx using a 97 | theme 98 | provided by Read the Docs. 99 | 100 | 101 |
102 |
103 |
104 |
105 |
106 | 111 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | # sys.path.insert(0, os.path.abspath('../src/')) 12 | 13 | project = 'nlpboost' 14 | copyright = '2022, Alejandro Vaca' 15 | author = 'Alejandro Vaca' 16 | 17 | # -- General configuration --------------------------------------------------- 18 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 19 | 20 | extensions = ['sphinx.ext.napoleon', 'sphinx.ext.autodoc', 21 | 'sphinx.ext.viewcode', 'sphinxemoji.sphinxemoji'] 22 | 23 | napoleon_google_docstring = False 24 | napoleon_use_param = False 25 | napoleon_use_ivar = True 26 | 27 | templates_path = ['_templates'] 28 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 29 | 30 | 31 | # -- Options for HTML output ------------------------------------------------- 32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 33 | 34 | html_theme = 'sphinx_rtd_theme' 35 | html_static_path = ['_static'] 36 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../examples/README.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. nlpboost documentation master file, created by 2 | sphinx-quickstart on Fri Dec 30 02:02:16 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to nlpboost's documentation! 7 | ==================================== 8 | 9 | .. image:: ../imgs/nlpboost_logo_3.png 10 | :target: ../imgs/nlpboost_logo_3.png 11 | :width: 500 12 | :alt: nlpboost logo 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: Contents: 17 | 18 | readme 19 | examples 20 | notebooks 21 | modules 22 | 23 | 24 | Indices and tables 25 | ================== 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | nlpboost 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | nlpboost 8 | -------------------------------------------------------------------------------- /docs/nlpboost.augmentation.rst: -------------------------------------------------------------------------------- 1 | nlpboost.augmentation package 2 | ============================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | nlpboost.augmentation.TextAugmenterPipeline module 8 | -------------------------------------------------- 9 | 10 | .. automodule:: nlpboost.augmentation.TextAugmenterPipeline 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | nlpboost.augmentation.augmenter\_config module 16 | ---------------------------------------------- 17 | 18 | .. automodule:: nlpboost.augmentation.augmenter_config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: nlpboost.augmentation 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/nlpboost.rst: -------------------------------------------------------------------------------- 1 | nlpboost package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | nlpboost.augmentation 11 | 12 | Submodules 13 | ---------- 14 | 15 | nlpboost.autotrainer module 16 | --------------------------- 17 | 18 | .. automodule:: nlpboost.autotrainer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | nlpboost.ckpt\_cleaner module 24 | ----------------------------- 25 | 26 | .. automodule:: nlpboost.ckpt_cleaner 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | nlpboost.dataset\_config module 32 | ------------------------------- 33 | 34 | .. automodule:: nlpboost.dataset_config 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | nlpboost.default\_param\_spaces module 40 | -------------------------------------- 41 | 42 | .. automodule:: nlpboost.default_param_spaces 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | nlpboost.hfdatasets\_manager module 48 | ----------------------------------- 49 | 50 | .. automodule:: nlpboost.hfdatasets_manager 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | nlpboost.hftransformers\_manager module 56 | --------------------------------------- 57 | 58 | .. automodule:: nlpboost.hftransformers_manager 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | nlpboost.metrics module 64 | ----------------------- 65 | 66 | .. automodule:: nlpboost.metrics 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | nlpboost.metrics\_plotter module 72 | -------------------------------- 73 | 74 | .. automodule:: nlpboost.metrics_plotter 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | nlpboost.model\_config module 80 | ----------------------------- 81 | 82 | .. 
automodule:: nlpboost.model_config 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | nlpboost.results\_getter module 88 | ------------------------------- 89 | 90 | .. automodule:: nlpboost.results_getter 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | nlpboost.tokenization\_functions module 96 | --------------------------------------- 97 | 98 | .. automodule:: nlpboost.tokenization_functions 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | nlpboost.utils module 104 | --------------------- 105 | 106 | .. automodule:: nlpboost.utils 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | Module contents 112 | --------------- 113 | 114 | .. automodule:: nlpboost 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | -------------------------------------------------------------------------------- /docs/notebooks.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../notebooks/README.rst 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers[torch]>=4.16 2 | datasets>=1.11.0 3 | optuna>=3.0.2 4 | scikit-learn>=1.0.2 5 | nltk==3.7 6 | rouge_score==0.1.2 7 | tensorboard==2.10.1 8 | tensorboardX==2.5.1 9 | sentencepiece==0.1.97 10 | apscheduler==3.6.3 11 | seaborn==0.12.0 12 | nlpaug==1.1.11 13 | simpletransformers >= 0.61.10 14 | pandas>=1.3.5 15 | tqdm==4.64.1 16 | evaluate==0.2.2 17 | more-itertools==8.14.0 18 | polyfuzz==0.4.0 19 | seqeval==1.2.2 20 | sphinxemoji==0.2.0 21 | pytest==7.1.3 22 | sphinx==5.3.0 23 | sphinx-rtd-theme>=0.4.3 24 | sphinxcontrib-applehelp==1.0.2 25 | sphinxcontrib-devhelp==1.0.2 26 | sphinxcontrib-htmlhelp==2.0.0 27 | sphinxcontrib-jsmath==1.0.1 28 | sphinxcontrib-qthelp==1.0.3 29 | sphinxcontrib-serializinghtml==1.1.5 30 | sphinxemoji==0.2.0 31 | # -e . 
32 | # git+https://github.com/avacaondata/nlpboost.git -------------------------------------------------------------------------------- /examples/NER/train_spanish_ner.py: -------------------------------------------------------------------------------- 1 | from nlpboost import AutoTrainer, DatasetConfig, ModelConfig, dict_to_list, ResultsPlotter 2 | from transformers import EarlyStoppingCallback 3 | from functools import partial 4 | 5 | if __name__ == "__main__": 6 | fixed_train_args = { 7 | "evaluation_strategy": "steps", 8 | "num_train_epochs": 10, 9 | "do_train": True, 10 | "do_eval": True, 11 | "logging_strategy": "steps", 12 | "eval_steps": 1, 13 | "save_steps": 1, 14 | "logging_steps": 1, 15 | "save_strategy": "steps", 16 | "save_total_limit": 2, 17 | "seed": 69, 18 | "fp16": True, 19 | "no_cuda": False, 20 | "dataloader_num_workers": 2, 21 | "load_best_model_at_end": True, 22 | "per_device_eval_batch_size": 16, 23 | "adam_epsilon": 1e-6, 24 | "adam_beta1": 0.9, 25 | "adam_beta2": 0.999, 26 | "max_steps": 1 27 | } 28 | 29 | default_args_dataset = { 30 | "seed": 44, 31 | "direction_optimize": "maximize", 32 | "metric_optimize": "eval_f1-score", 33 | "callbacks": [EarlyStoppingCallback(1, 0.00001)], 34 | "fixed_training_args": fixed_train_args 35 | } 36 | 37 | conll2002_config = default_args_dataset.copy() 38 | conll2002_config.update( 39 | { 40 | "dataset_name": "conll2002", 41 | "alias": "conll2002", 42 | "task": "ner", 43 | "text_field": "tokens", 44 | "hf_load_kwargs": {"path": "conll2002", "name": "es"}, 45 | "label_col": "ner_tags", 46 | } 47 | ) 48 | 49 | conll2002_config = DatasetConfig(**conll2002_config) 50 | 51 | ehealth_config = default_args_dataset.copy() 52 | 53 | ehealth_config.update( 54 | { 55 | "dataset_name": "ehealth_kd", 56 | "alias": "ehealth", 57 | "task": "ner", 58 | "text_field": "token_list", 59 | "hf_load_kwargs": {"path": "ehealth_kd"}, 60 | "label_col": "label_list", 61 | "pre_func": partial(dict_to_list, nulltoken=100), 62 | } 63 | ) 64 | 65 | ehealth_config = DatasetConfig(**ehealth_config) 66 | 67 | dataset_configs = [ 68 | # conll2002_config, 69 | ehealth_config 70 | ] 71 | 72 | # AHORA PREPARAMOS LA CONFIGURACIÓN DE LOS MODELOS, EN ESTE CASO BSC Y BERTIN. 73 | 74 | def hp_space(trial): 75 | return { 76 | "learning_rate": trial.suggest_categorical( 77 | "learning_rate", [1.5e-5, 2e-5, 3e-5, 4e-5] 78 | ), 79 | "num_train_epochs": trial.suggest_categorical( 80 | "num_train_epochs", [1] 81 | ), 82 | "per_device_train_batch_size": trial.suggest_categorical( 83 | "per_device_train_batch_size", [1]), 84 | "per_device_eval_batch_size": trial.suggest_categorical( 85 | "per_device_eval_batch_size", [1]), 86 | "gradient_accumulation_steps": trial.suggest_categorical( 87 | "gradient_accumulation_steps", [1]), 88 | "warmup_steps": trial.suggest_categorical( 89 | "warmup_steps", [50, 100, 500, 1000] 90 | ), 91 | "weight_decay": trial.suggest_categorical( 92 | "weight_decay", [0.0] 93 | ) 94 | } 95 | 96 | bsc_config = ModelConfig( 97 | name="PlanTL-GOB-ES/roberta-base-bne", 98 | save_name="bsc@roberta", 99 | hp_space=hp_space, 100 | n_trials=1 101 | ) 102 | 103 | bertin_config = ModelConfig( 104 | name="bertin-project/bertin-roberta-base-spanish", 105 | save_name="bertin", 106 | hp_space=hp_space, 107 | n_trials=1 108 | ) 109 | 110 | model_configs = [bsc_config, bertin_config] 111 | 112 | # Y POR ÚLTIMO VAMOS A INICIALIZAR EL BENCHMARKER CON LA CONFIG DE MODELOS Y DATASETS, 113 | # Y LO LLAMAMOS PARA LLEVAR A CABO LA BÚSQUEDA DE PARÁMETROS. 
114 | autotrainer = AutoTrainer( 115 | model_configs=model_configs, 116 | dataset_configs=dataset_configs, 117 | metrics_dir="metrics_spanish_ner", 118 | ) 119 | 120 | results = autotrainer() 121 | print(results) 122 | 123 | plotter = ResultsPlotter( 124 | metrics_dir=autotrainer.metrics_dir, 125 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 126 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 127 | ) 128 | ax = plotter.plot_metrics() 129 | ax.figure.savefig("results.png") 130 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Example scripts of how to use nlpboost for each task 2 | 3 | In the examples folder you will find example scripts showing how to fine-tune models for different tasks. These tasks are divided in directories, as you see. In all scripts it is also shown how to use `ResultsPlotter` to save a metrics comparison figure of the models trained. 4 | 5 | * `classification` 6 | For classification we have 2 examples. [train_classification.py](https://github.com/avacaondata/nlpboost/blob/main/examples/classification/train_classification.py) shows how to train a BERTIN model for multi-class classification (emotion detection, check tweet_eval: emotion dataset for more info.). On the other hand, [train_multilabel.py](https://github.com/avacaondata/nlpboost/blob/main/examples/classification/train_multilabel.py) shows how to train a model on a multilabel task. 7 | 8 | * `extractive_qa` 9 | For extractive QA we have only 1 example, as this type of task is very similar in all cases: [train_sqac.py](https://github.com/avacaondata/nlpboost/blob/main/examples/extractive_qa/train_sqac.py) shows how to train a MarIA-large (Spanish Roberta-large) model on SQAC dataset, with hyperparameter search. 10 | 11 | * `NER` 12 | For NER, there is an example script, showing how to train multiple models on multiple NER datasets with different format, where we need to apply a `pre_func` to one of the datasets. The script is called [train_spanish_ner.py](https://github.com/avacaondata/nlpboost/blob/main/examples/NER/train_spanish_ner.py). 13 | 14 | * `seq2seq` 15 | For this task, check out [train_maria_encoder_decoder_marimari.py](https://github.com/avacaondata/nlpboost/blob/main/examples/seq2seq/train_maria_encoder_decoder_marimari.py), which shows how to train a seq2seq model when no encoder-decoder architecture is readily available for a certain language, in this case Spanish. On the other hand, check out [train_summarization_mlsum.py](https://github.com/avacaondata/nlpboost/blob/main/examples/seq2seq/train_summarization_mlsum.py) to learn how to configure training for two multilingual encoder-decoder models for MLSUM summarization task. 16 | 17 | **Important**: For more detailed tutorials in Jupyter-Notebook format, please check [nlpboost notebooks](https://github.com/avacaondata/nlpboost/tree/main/notebooks). These tutorials have explanations on all the configuration, which is helpful for getting to better know the tool. They are intended to provide a deep understanding on the different configurations that are needed for each of the tasks, so that the user can easily adapt the scripts for their own tasks and needs. 
-------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | 2 | Example scripts of how to use nlpboost for each task 3 | ==================================================== 4 | 5 | In the examples folder you will find example scripts showing how to fine-tune models for different tasks. These tasks are divided in directories, as you see. In all scripts it is also shown how to use ``ResultsPlotter`` to save a metrics comparison figure of the models trained. 6 | 7 | 8 | * 9 | ``classification`` 10 | For classification we have 2 examples. `train_classification.py `_ shows how to train a BERTIN model for multi-class classification (emotion detection, check tweet_eval: emotion dataset for more info.). On the other hand, `train_multilabel.py `_ shows how to train a model on a multilabel task. 11 | 12 | * 13 | ``extractive_qa`` 14 | For extractive QA we have only 1 example, as this type of task is very similar in all cases: `train_sqac.py `_ shows how to train a MarIA-large (Spanish Roberta-large) model on SQAC dataset, with hyperparameter search. 15 | 16 | * 17 | ``NER`` 18 | For NER, there is an example script, showing how to train multiple models on multiple NER datasets with different format, where we need to apply a ``pre_func`` to one of the datasets. The script is called `train_spanish_ner.py `_. 19 | 20 | * 21 | ``seq2seq`` 22 | For this task, check out `train_maria_encoder_decoder_marimari.py `_\ , which shows how to train a seq2seq model when no encoder-decoder architecture is readily available for a certain language, in this case Spanish. On the other hand, check out `train_summarization_mlsum.py `_ to learn how to configure training for two multilingual encoder-decoder models for MLSUM summarization task. 23 | 24 | **Important**\ : For more detailed tutorials in Jupyter-Notebook format, please check `nlpboost notebooks `_. These tutorials have explanations on all the configuration, which is helpful for getting to better know the tool. They are intended to provide a deep understanding on the different configurations that are needed for each of the tasks, so that the user can easily adapt the scripts for their own tasks and needs. 
25 | -------------------------------------------------------------------------------- /examples/classification/train_classification.py: -------------------------------------------------------------------------------- 1 | from nlpboost import DatasetConfig, ModelConfig, AutoTrainer, ResultsPlotter 2 | from nlpboost.default_param_spaces import hp_space_base 3 | 4 | if __name__ == "__main__": 5 | fixed_train_args = { 6 | "evaluation_strategy": "steps", 7 | "num_train_epochs": 10, 8 | "do_train": True, 9 | "do_eval": True, 10 | "logging_strategy": "steps", 11 | "eval_steps": 1, 12 | "save_steps": 1, 13 | "logging_steps": 1, 14 | "save_strategy": "steps", 15 | "save_total_limit": 2, 16 | "seed": 69, 17 | "fp16": False, 18 | "no_cuda": True, 19 | "load_best_model_at_end": True, 20 | "per_device_eval_batch_size": 16, 21 | "max_steps": 1 22 | } 23 | default_args_dataset = { 24 | "seed": 44, 25 | "direction_optimize": "maximize", 26 | "metric_optimize": "eval_f1-score", 27 | "retrain_at_end": False, 28 | "fixed_training_args": fixed_train_args 29 | } 30 | tweet_eval_config = default_args_dataset.copy() 31 | tweet_eval_config.update( 32 | { 33 | "dataset_name": "tweeteval", 34 | "alias": "tweeteval", 35 | "task": "classification", 36 | "text_field": "text", 37 | "label_col": "label", 38 | "hf_load_kwargs": {"path": "tweet_eval", "name": "emotion"} 39 | } 40 | ) 41 | tweet_eval_config = DatasetConfig(**tweet_eval_config) 42 | model_config = ModelConfig( 43 | name="bertin-project/bertin-roberta-base-spanish", 44 | save_name="bertin", 45 | hp_space=hp_space_base, 46 | n_trials=1, 47 | only_test=False 48 | ) 49 | autotrainer = AutoTrainer( 50 | model_configs=[model_config], 51 | dataset_configs=[tweet_eval_config], 52 | metrics_dir="tweeteval_metrics" 53 | ) 54 | 55 | results = autotrainer() 56 | print(results) 57 | 58 | plotter = ResultsPlotter( 59 | metrics_dir=autotrainer.metrics_dir, 60 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 61 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 62 | ) 63 | ax = plotter.plot_metrics() 64 | ax.figure.savefig("results.png") 65 | -------------------------------------------------------------------------------- /examples/classification/train_multilabel.py: -------------------------------------------------------------------------------- 1 | from nlpboost import DatasetConfig, ModelConfig, AutoTrainer, ResultsPlotter 2 | from nlpboost.default_param_spaces import hp_space_base 3 | 4 | 5 | def pre_parse_func(example): 6 | label_cols = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "L", "M", "N", "Z"] 7 | new_example = {"text": example["abstractText"]} 8 | for col in label_cols: 9 | new_example[f"label_{col}"] = example[col] 10 | return new_example 11 | 12 | 13 | if __name__ == "__main__": 14 | fixed_train_args = { 15 | "evaluation_strategy": "steps", 16 | "num_train_epochs": 10, 17 | "do_train": True, 18 | "do_eval": True, 19 | "logging_strategy": "steps", 20 | "eval_steps": 1, 21 | "save_steps": 1, 22 | "logging_steps": 1, 23 | "save_strategy": "steps", 24 | "save_total_limit": 2, 25 | "seed": 69, 26 | "fp16": False, 27 | "no_cuda": False, 28 | "load_best_model_at_end": True, 29 | "per_device_eval_batch_size": 16, 30 | "max_steps": 1 31 | } 32 | default_args_dataset = { 33 | "seed": 44, 34 | "direction_optimize": "maximize", 35 | "metric_optimize": "eval_f1-score", 36 | "retrain_at_end": False, 37 | "fixed_training_args": fixed_train_args 38 | } 39 | 
pubmed_config = default_args_dataset.copy() 40 | pubmed_config.update( 41 | { 42 | "dataset_name": "pubmed", 43 | "alias": "pubmed", 44 | "task": "classification", 45 | "is_multilabel": True, 46 | "multilabel_label_names": [f"label_{col}" for col in ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "L", "M", "N", "Z"]], 47 | "text_field": "text", 48 | "label_col": "label_A", 49 | "hf_load_kwargs": {"path": "owaiskha9654/PubMed_MultiLabel_Text_Classification_Dataset_MeSH"}, 50 | "pre_func": pre_parse_func, 51 | "remove_fields_pre_func": True, 52 | "config_num_labels": 14, # for multilabel we need to pass the number of labels for the config. 53 | "split": True # as the dataset only comes with train split, we need to split in train, val, test. 54 | } 55 | ) 56 | pubmed_config = DatasetConfig(**pubmed_config) 57 | model_config = ModelConfig( 58 | name="bertin-project/bertin-roberta-base-spanish", 59 | save_name="bertin", 60 | hp_space=hp_space_base, 61 | n_trials=1 62 | ) 63 | autotrainer = AutoTrainer( 64 | model_configs=[model_config], 65 | dataset_configs=[pubmed_config], 66 | metrics_dir="pubmed_metrics" 67 | ) 68 | 69 | results = autotrainer() 70 | print(results) 71 | 72 | plotter = ResultsPlotter( 73 | metrics_dir=autotrainer.metrics_dir, 74 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 75 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 76 | ) 77 | ax = plotter.plot_metrics() 78 | ax.figure.savefig("results.png") 79 | -------------------------------------------------------------------------------- /examples/extractive_qa/train_sqac.py: -------------------------------------------------------------------------------- 1 | from nlpboost import AutoTrainer, ModelConfig, DatasetConfig, ResultsPlotter 2 | from transformers import EarlyStoppingCallback 3 | 4 | if __name__ == "__main__": 5 | fixed_train_args = { 6 | "evaluation_strategy": "epoch", 7 | "num_train_epochs": 10, 8 | "do_train": True, 9 | "do_eval": True, 10 | "logging_strategy": "epoch", 11 | "save_strategy": "epoch", 12 | "save_total_limit": 2, 13 | "seed": 69, 14 | "bf16": True, 15 | "dataloader_num_workers": 8, 16 | "load_best_model_at_end": True, 17 | "per_device_eval_batch_size": 16, 18 | "adam_epsilon": 1e-6, 19 | "adam_beta1": 0.9, 20 | "adam_beta2": 0.999, 21 | } 22 | 23 | default_args_dataset = { 24 | "seed": 44, 25 | "direction_optimize": "minimize", 26 | "metric_optimize": "eval_loss", 27 | "callbacks": [EarlyStoppingCallback(1, 0.00001)], 28 | "fixed_training_args": fixed_train_args 29 | } 30 | 31 | sqac_config = default_args_dataset.copy() 32 | sqac_config.update( 33 | { 34 | "dataset_name": "sqac", 35 | "alias": "sqac", 36 | "task": "qa", 37 | "text_field": "context", 38 | "hf_load_kwargs": {"path": "PlanTL-GOB-ES/SQAC"}, 39 | "label_col": "question", 40 | } 41 | ) 42 | sqac_config = DatasetConfig(**sqac_config) 43 | 44 | def hp_space(trial): 45 | return { 46 | "learning_rate": trial.suggest_categorical( 47 | "learning_rate", [1.5e-5, 2e-5, 3e-5, 4e-5] 48 | ), 49 | "num_train_epochs": trial.suggest_categorical( 50 | "num_train_epochs", [1] 51 | ), 52 | "per_device_train_batch_size": trial.suggest_categorical( 53 | "per_device_train_batch_size", [16]), 54 | "per_device_eval_batch_size": trial.suggest_categorical( 55 | "per_device_eval_batch_size", [32]), 56 | "gradient_accumulation_steps": trial.suggest_categorical( 57 | "gradient_accumulation_steps", [1]), 58 | "warmup_steps": trial.suggest_categorical( 59 | 
"warmup_steps", [50, 100, 500, 1000] 60 | ), 61 | "weight_decay": trial.suggest_categorical( 62 | "weight_decay", [0.0] 63 | ) 64 | } 65 | 66 | bsc_config = ModelConfig( 67 | name="PlanTL-GOB-ES/roberta-base-bne", 68 | save_name="bsc@roberta", 69 | hp_space=hp_space, 70 | ) 71 | autotrainer = AutoTrainer( 72 | model_configs=[bsc_config], 73 | dataset_configs=[sqac_config], 74 | metrics_dir="spanish_qa_metrics", 75 | metrics_cleaner="spanish_qa_cleaner_metrics" 76 | ) 77 | 78 | experiment_results = autotrainer() 79 | print(experiment_results) 80 | 81 | plotter = ResultsPlotter( 82 | metrics_dir=autotrainer.metrics_dir, 83 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 84 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 85 | ) 86 | ax = plotter.plot_metrics() 87 | ax.figure.savefig("results.png") 88 | -------------------------------------------------------------------------------- /examples/seq2seq/train_maria_encoder_decoder_marimari.py: -------------------------------------------------------------------------------- 1 | from nlpboost import AutoTrainer, DatasetConfig, ModelConfig, ResultsPlotter 2 | from transformers import EarlyStoppingCallback 3 | import evaluate 4 | 5 | 6 | if __name__ == "__main__": 7 | 8 | fixed_train_args = { 9 | "evaluation_strategy": "epoch", 10 | "num_train_epochs": 10, 11 | "do_train": True, 12 | "do_eval": True, 13 | "logging_strategy": "epoch", 14 | "save_strategy": "epoch", 15 | "save_total_limit": 2, 16 | "max_steps": 1, 17 | "seed": 69, 18 | "no_cuda": True, 19 | "bf16": False, 20 | "dataloader_num_workers": 8, 21 | "load_best_model_at_end": True, 22 | "per_device_eval_batch_size": 48, 23 | "adam_epsilon": 1e-8, 24 | "adam_beta1": 0.9, 25 | "adam_beta2": 0.999, 26 | "group_by_length": True, 27 | "max_grad_norm": 1.0 28 | } 29 | 30 | mlsum_config = { 31 | "seed": 44, 32 | "direction_optimize": "maximize", 33 | "metric_optimize": "eval_rouge2", 34 | "callbacks": [EarlyStoppingCallback(1, 0.00001)], 35 | "fixed_training_args": fixed_train_args 36 | } 37 | 38 | mlsum_config.update( 39 | { 40 | "dataset_name": "mlsum", 41 | "alias": "mlsum", 42 | "task": "summarization", 43 | "hf_load_kwargs": {"path": "mlsum", "name": "es"}, 44 | "label_col": "summary", 45 | "num_proc": 8, "additional_metrics": [evaluate.load("meteor")]} 46 | ) 47 | 48 | mlsum_config = DatasetConfig(**mlsum_config) 49 | 50 | def hp_space(trial): 51 | return { 52 | "learning_rate": trial.suggest_categorical( 53 | "learning_rate", 3e-5, 7e-5, log=True 54 | ), 55 | "num_train_epochs": trial.suggest_categorical( 56 | "num_train_epochs", [7] 57 | ), 58 | "per_device_train_batch_size": trial.suggest_categorical( 59 | "per_device_train_batch_size", [16]), 60 | "per_device_eval_batch_size": trial.suggest_categorical( 61 | "per_device_eval_batch_size", [32]), 62 | "gradient_accumulation_steps": trial.suggest_categorical( 63 | "gradient_accumulation_steps", [2]), 64 | "warmup_steps": trial.suggest_categorical( 65 | "warmup_steps", [50, 100, 500, 1000] 66 | ), 67 | "weight_decay": trial.suggest_float( 68 | "weight_decay", 0.0, 0.1 69 | ), 70 | } 71 | 72 | def preprocess_function(examples, tokenizer, dataset_config): 73 | model_inputs = tokenizer( 74 | examples[dataset_config.text_field], 75 | truncation=True, 76 | max_length=tokenizer.model_max_length 77 | ) 78 | with tokenizer.as_target_tokenizer(): 79 | labels = tokenizer(examples[dataset_config.summary_field], 
max_length=dataset_config.max_length_summary, truncation=True) 80 | model_inputs["labels"] = labels["input_ids"] 81 | return model_inputs 82 | 83 | marimari_roberta2roberta_config = ModelConfig( 84 | name="marimari-r2r", 85 | save_name="marimari-r2r", 86 | hp_space=hp_space, 87 | encoder_name="BSC-TeMU/roberta-base-bne", 88 | decoder_name="BSC-TeMU/roberta-base-bne", 89 | num_beams=4, 90 | n_trials=1, 91 | random_init_trials=1, 92 | custom_tokenization_func=preprocess_function, 93 | only_test=False, 94 | ) 95 | autotrainer = AutoTrainer( 96 | model_configs=[marimari_roberta2roberta_config], 97 | dataset_configs=[mlsum_config], 98 | metrics_dir="mlsum_marimari" 99 | ) 100 | 101 | results = autotrainer() 102 | print(results) 103 | 104 | plotter = ResultsPlotter( 105 | metrics_dir=autotrainer.metrics_dir, 106 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 107 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 108 | metric_field="rouge2" 109 | ) 110 | ax = plotter.plot_metrics() 111 | ax.figure.savefig("results.png") 112 | -------------------------------------------------------------------------------- /examples/seq2seq/train_summarization_mlsum.py: -------------------------------------------------------------------------------- 1 | from nlpboost import AutoTrainer, DatasetConfig, ModelConfig, ResultsPlotter 2 | from transformers import EarlyStoppingCallback 3 | from transformers import Seq2SeqTrainer, MT5ForConditionalGeneration, XLMProphetNetForConditionalGeneration 4 | 5 | if __name__ == "__main__": 6 | 7 | fixed_train_args = { 8 | "evaluation_strategy": "epoch", 9 | "num_train_epochs": 10, 10 | "do_train": True, 11 | "do_eval": True, 12 | "logging_strategy": "epoch", 13 | "save_strategy": "epoch", 14 | "save_total_limit": 2, 15 | "seed": 69, 16 | "bf16": True, 17 | "dataloader_num_workers": 16, 18 | "load_best_model_at_end": True, 19 | "adafactor": True, 20 | } 21 | 22 | mlsum_config = { 23 | "seed": 44, 24 | "direction_optimize": "maximize", 25 | "metric_optimize": "eval_rouge2", 26 | "callbacks": [EarlyStoppingCallback(1, 0.00001)], 27 | "fixed_training_args": fixed_train_args 28 | } 29 | 30 | mlsum_config.update( 31 | { 32 | "dataset_name": "mlsum", 33 | "alias": "mlsum", 34 | "retrain_at_end": False, 35 | "task": "summarization", 36 | "hf_load_kwargs": {"path": "mlsum", "name": "es"}, 37 | "label_col": "summary", 38 | "num_proc": 16} 39 | ) 40 | 41 | mlsum_config = DatasetConfig(**mlsum_config) 42 | 43 | def hp_space(trial): 44 | return { 45 | "learning_rate": trial.suggest_categorical( 46 | "learning_rate", [3e-5, 5e-5, 7e-5, 2e-4] 47 | ), 48 | "num_train_epochs": trial.suggest_categorical( 49 | "num_train_epochs", [10] 50 | ), 51 | "per_device_train_batch_size": trial.suggest_categorical( 52 | "per_device_train_batch_size", [8]), 53 | "per_device_eval_batch_size": trial.suggest_categorical( 54 | "per_device_eval_batch_size", [8]), 55 | "gradient_accumulation_steps": trial.suggest_categorical( 56 | "gradient_accumulation_steps", [8]), 57 | "warmup_ratio": trial.suggest_categorical( 58 | "warmup_ratio", [0.08] 59 | ), 60 | } 61 | 62 | def preprocess_function(examples, tokenizer, dataset_config): 63 | model_inputs = tokenizer( 64 | examples[dataset_config.text_field], 65 | truncation=True, 66 | max_length=1024 67 | ) 68 | with tokenizer.as_target_tokenizer(): 69 | labels = tokenizer(examples[dataset_config.summary_field], max_length=dataset_config.max_length_summary, truncation=True) 
70 | model_inputs["labels"] = labels["input_ids"] 71 | return model_inputs 72 | 73 | mt5_config = ModelConfig( 74 | name="google/mt5-large", 75 | save_name="mt5-large", 76 | hp_space=hp_space, 77 | num_beams=4, 78 | trainer_cls_summarization=Seq2SeqTrainer, 79 | model_cls_summarization=MT5ForConditionalGeneration, 80 | custom_tokenization_func=preprocess_function, 81 | n_trials=1, 82 | random_init_trials=1 83 | ) 84 | xprophetnet_config = ModelConfig( 85 | name="microsoft/xprophetnet-large-wiki100-cased", 86 | save_name="xprophetnet", 87 | hp_space=hp_space, 88 | num_beams=4, 89 | trainer_cls_summarization=Seq2SeqTrainer, 90 | model_cls_summarization=XLMProphetNetForConditionalGeneration, 91 | custom_tokenization_func=preprocess_function, 92 | n_trials=1, 93 | random_init_trials=1 94 | ) 95 | autotrainer = AutoTrainer( 96 | model_configs=[mt5_config, xprophetnet_config], 97 | dataset_configs=[mlsum_config], 98 | metrics_dir="mlsum_multilingual_models", 99 | metrics_cleaner="metrics_mlsum" 100 | ) 101 | 102 | results = autotrainer() 103 | print(results) 104 | 105 | plotter = ResultsPlotter( 106 | metrics_dir=autotrainer.metrics_dir, 107 | model_names=[model_config.save_name for model_config in autotrainer.model_configs], 108 | dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs}, 109 | metric_field="rouge2" 110 | ) 111 | ax = plotter.plot_metrics() 112 | ax.figure.savefig("results.png") -------------------------------------------------------------------------------- /imgs/nlpboost_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/imgs/nlpboost_diagram.png -------------------------------------------------------------------------------- /imgs/nlpboost_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/imgs/nlpboost_logo.png -------------------------------------------------------------------------------- /imgs/nlpboost_logo_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/imgs/nlpboost_logo_2.png -------------------------------------------------------------------------------- /imgs/nlpboost_logo_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avacaondata/nlpboost/f2c9f33a3c69053c82a458e8f7f28781c64026e2/imgs/nlpboost_logo_3.png -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Notebook Tutorials on how to use nlpboost 2 | 3 | In [this folder](https://github.com/avacaondata/nlpboost/tree/main/notebooks) you will find examples on how to use `nlpboost` in a tutorial format, with text explanations for each of the steps needed to configure the experiments. Notebooks can also be run in Google Colab. 4 | 5 | The examples are very similar to those in [examples folder](https://github.com/avacaondata/nlpboost/tree/main/examples), in the sense that the datasets chosen for each task are the same. 
However, in notebooks more models are trained for each dataset, and there are clear explanations on some key aspects of the configuration of each task that the raw scripts lack. -------------------------------------------------------------------------------- /notebooks/README.rst: -------------------------------------------------------------------------------- 1 | 2 | Notebook Tutorials on how to use nlpboost 3 | ========================================= 4 | 5 | In `this folder `_ you will find examples on how to use ``nlpboost`` in a tutorial format, with text explanations for each of the steps needed to configure the experiments. Notebooks can also be run in Google Colab. 6 | 7 | The examples are very similar to those in `examples folder `_\ , in the sense that the datasets chosen for each task are the same. However, in notebooks more models are trained for each dataset, and there are clear explanations on some key aspects of the configuration of each task that the raw scripts lack. 8 | -------------------------------------------------------------------------------- /notebooks/extractive_qa/train_sqac.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/avacaondata/nlpboost/blob/main/notebooks/extractive_qa/train_sqac.ipynb)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Extractive Question Answering in Spanish: SQAC\n", 15 | "\n", 16 | "In this tutorial we will see how we can train multiple Spanish models on a QA dataset in that language: SQAC. " 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "We first import the needed modules or, if you are running this notebook in Google colab, please uncomment the cell below and run it before importing, in order to install `nlpboost`." 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from nlpboost import AutoTrainer, ModelConfig, DatasetConfig, ResultsPlotter\n", 33 | "from transformers import EarlyStoppingCallback\n", 34 | "from nlpboost.default_param_spaces import hp_space_base" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Configure the dataset\n", 42 | "\n", 43 | "The next step is to define the fixed train args, which will be the `transformers.TrainingArguments` passed to `transformers.Trainer` inside `nlpboost.AutoTrainer`. For a full list of arguments check [TrainingArguments documentation](https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/trainer#transformers.TrainingArguments). `DatasetConfig` expects these arguments in dictionary format.\n", 44 | "\n", 45 | "To save time, we set `max_steps` to 1; in a real setting we would need to define these arguments differently. However, that is out of scope for this tutorial. To learn how to work with Transformers, and how to configure the training arguments, please check Huggingface Course on NLP. 
" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "fixed_train_args = {\n", 55 | " \"evaluation_strategy\": \"epoch\",\n", 56 | " \"num_train_epochs\": 10,\n", 57 | " \"do_train\": True,\n", 58 | " \"do_eval\": True,\n", 59 | " \"logging_strategy\": \"epoch\",\n", 60 | " \"save_strategy\": \"epoch\",\n", 61 | " \"save_total_limit\": 2,\n", 62 | " \"seed\": 69,\n", 63 | " \"fp16\": True,\n", 64 | " \"dataloader_num_workers\": 8,\n", 65 | " \"load_best_model_at_end\": True,\n", 66 | " \"per_device_eval_batch_size\": 16,\n", 67 | " \"adam_epsilon\": 1e-6,\n", 68 | " \"adam_beta1\": 0.9,\n", 69 | " \"adam_beta2\": 0.999,\n", 70 | " \"max_steps\": 1\n", 71 | "}" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Then we define some common args for the dataset. In this case we minimize the loss, as for QA no compute metrics function is used during training. We use the loss to choose the best model and then compute metrics over the test set, which is not a straightforward process (that is the reason for not computing metrics in-training)." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "default_args_dataset = {\n", 88 | " \"seed\": 44,\n", 89 | " \"direction_optimize\": \"minimize\",\n", 90 | " \"metric_optimize\": \"eval_loss\",\n", 91 | " \"retrain_at_end\": False,\n", 92 | " \"callbacks\": [EarlyStoppingCallback(1, 0.00001)],\n", 93 | " \"fixed_training_args\": fixed_train_args\n", 94 | "}" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "We now define arguments specific of SQAC. In this case, the text field and the label col are not used, so we just set them to two string columns of the dataset. In QA tasks, `nlpboost` assumes the dataset is in SQUAD format." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "sqac_config = default_args_dataset.copy()\n", 111 | "sqac_config.update(\n", 112 | " {\n", 113 | " \"dataset_name\": \"sqac\",\n", 114 | " \"alias\": \"sqac\",\n", 115 | " \"task\": \"qa\",\n", 116 | " \"text_field\": \"context\",\n", 117 | " \"hf_load_kwargs\": {\"path\": \"PlanTL-GOB-ES/SQAC\"},\n", 118 | " \"label_col\": \"question\",\n", 119 | " }\n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "sqac_config = DatasetConfig(**sqac_config)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## Configure Models\n", 137 | "\n", 138 | "We will configure three Spanish models. As you see, we only need to define the `name`, which is the path to the model (either in HF Hub or locally), `save_name` which is an arbitrary name for the model, the hyperparameter space and the number of trials. There are more parameters, which you can check in the documentation." 
139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "bertin_config = ModelConfig(\n", 148 | " name=\"bertin-project/bertin-roberta-base-spanish\",\n", 149 | " save_name=\"bertin\",\n", 150 | " hp_space=hp_space_base,\n", 151 | " n_trials=1,\n", 152 | ")\n", 153 | "beto_config = ModelConfig(\n", 154 | " name=\"dccuchile/bert-base-spanish-wwm-cased\",\n", 155 | " save_name=\"beto\",\n", 156 | " hp_space=hp_space_base,\n", 157 | " n_trials=1,\n", 158 | ")\n", 159 | "albert_config = ModelConfig(\n", 160 | " name=\"CenIA/albert-tiny-spanish\",\n", 161 | " save_name=\"albert\",\n", 162 | " hp_space=hp_space_base,\n", 163 | " n_trials=1\n", 164 | ")" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "# Let's train! \n", 172 | "\n", 173 | "We can train now these three models on the SQAC dataset and see how well they perform (remember, if you really want to train them please remove the max steps to 1 in the fixed training arguments and the number of trials to 1 in the model configs)." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "autotrainer = AutoTrainer(\n", 183 | " model_configs=[bertin_config, beto_config, albert_config],\n", 184 | " dataset_configs=[sqac_config],\n", 185 | " metrics_dir=\"spanish_qa_metrics\",\n", 186 | " metrics_cleaner=\"spanish_qa_cleaner_metrics\"\n", 187 | ")" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "experiment_results = autotrainer()\n", 197 | "print(experiment_results)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "## Plot the results\n", 205 | "\n", 206 | "As in other tutorials, we can now plot the results with ResultsPlotter." 
207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "plotter = ResultsPlotter(\n", 216 | " metrics_dir=autotrainer.metrics_dir,\n", 217 | " model_names=[model_config.save_name for model_config in autotrainer.model_configs],\n", 218 | " dataset_to_task_map={dataset_config.alias: dataset_config.task for dataset_config in autotrainer.dataset_configs},\n", 219 | ")\n", 220 | "ax = plotter.plot_metrics()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3.9.13 ('largesum': conda)", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.9.13" 248 | }, 249 | "orig_nbformat": 4, 250 | "vscode": { 251 | "interpreter": { 252 | "hash": "bac692fd94dcfa608ba1aabbfbe7d5467f50ca857b57fe228a116df0c8b5b792" 253 | } 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } 259 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers[torch]>=4.16 2 | datasets>=1.11.0 3 | optuna>=3.0.2 4 | scikit-learn>=1.0.2 5 | nltk>=3.7 6 | rouge_score==0.1.2 7 | tensorboard==2.10.1 8 | tensorboardX==2.5.1 9 | sentencepiece==0.1.97 10 | apscheduler==3.6.3 11 | seaborn==0.12.0 12 | nlpaug==1.1.11 13 | simpletransformers >= 0.61.10 14 | pandas>=1.3.5 15 | tqdm==4.64.1 16 | evaluate==0.2.2 17 | more-itertools==8.14.0 18 | polyfuzz==0.4.0 19 | seqeval==1.2.2 20 | sphinxemoji==0.2.0 21 | pytest==7.1.3 22 | sphinx==5.3.0 23 | sphinx-rtd-theme>=1.1.1 24 | readthedocs-sphinx-search==0.1.2 25 | sphinxcontrib-applehelp==1.0.2 26 | sphinxcontrib-devhelp==1.0.2 27 | sphinxcontrib-htmlhelp==2.0.0 28 | sphinxcontrib-jsmath==1.0.1 29 | sphinxcontrib-qthelp==1.0.3 30 | sphinxcontrib-serializinghtml==1.1.5 31 | numpydoc==1.5.0 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | addopts = 3 | --cov-report html:cover 4 | --cov-report term 5 | --cov-config=.coveragerc 6 | --junitxml=report.xml 7 | --cov=src 8 | --doctest-modules 9 | --durations=20 10 | --ignore=doc/ 11 | --ignore=examples/ 12 | --instafail 13 | --pycodestyle 14 | --pydocstyle 15 | filterwarnings = 16 | ignore::PendingDeprecationWarning 17 | ignore::RuntimeWarning 18 | ignore::UserWarning 19 | 20 | [pycodestyle] 21 | max_line_length = 88 22 | ignore = E501, E203, W503, W605 23 | statistics = True 24 | 25 | [pydocstyle] 26 | convention = numpy 27 | add-ignore = D100, D103, D104 28 | 29 | [metadata] 30 | name = nlpboost 31 | version = 0.0.1 32 | author = Alejandro Vaca Serrano 33 | author_email = alejandro_vaca0@hotmail.com 
34 | description = a Python package for automatic training and comparison of transformer models 35 | url = https://github.com/avacaondata/nlpboost 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("README.md", "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="nlpboost", 9 | version="0.0.1", 10 | description="A package for automatic training of NLP (transformers) models", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/avacaondata/nlpboost", 14 | classifiers=[ 15 | "Development Status :: 5 - Production/Stable", 16 | "Intended Audience :: Science/Research", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3 :: Only", 20 | "Programming Language :: Python :: 3.7", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Topic :: Scientific/Engineering", 25 | ], 26 | package_dir={"": "src"}, 27 | packages=setuptools.find_packages(where="src"), 28 | python_requires=">=3.9.13,<3.11", 29 | install_requires=open("requirements.txt", "r").read().splitlines(), 30 | keywords="natural-language-processing, nlp, transformers, hyperparameter-tuning, automatic-training" 31 | ) 32 | -------------------------------------------------------------------------------- /src/nlpboost/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics_plotter import ResultsPlotter 2 | from . import augmentation 3 | from .autotrainer import AutoTrainer 4 | from .dataset_config import DatasetConfig 5 | from .model_config import ModelConfig 6 | from .utils import ( 7 | dict_to_list, 8 | joinpaths, 9 | match_questions_multiple_answers, 10 | get_windowed_match_context_answer, 11 | get_tags, 12 | _tokenize_dataset, 13 | _load_json, 14 | _save_json, 15 | ) 16 | from .tokenization_functions import ( 17 | tokenize_classification, 18 | tokenize_ner, 19 | tokenize_squad, 20 | tokenize_summarization, 21 | ) 22 | from .results_getter import ResultsGetter 23 | from .default_param_spaces import hp_space_base, hp_space_large 24 | from .skip_mix import SkipMix 25 | -------------------------------------------------------------------------------- /src/nlpboost/augmentation/TextAugmenterPipeline.py: -------------------------------------------------------------------------------- 1 | from .augmenter_config import class_translator 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | 6 | class NLPAugPipeline: 7 | """ 8 | Augment text data, with various forms of augmenting. It uses `nlpaug` in the background. 9 | 10 | The configuration of the augmentation pipeline is done with `nlpboost.augmentation.augmenter_config.NLPAugConfig`. 11 | NLPAugPipeline receives a list of configs of that type, where each config defines a type 12 | of augmentation technique to use, as well as the proportion of the train dataset that is 13 | to be augmented. 14 | 15 | Parameters 16 | ---------- 17 | steps: List[nlpboost.augmentation.augmenter_config.NLPAugConfig] 18 | List of steps. Each step must be a NLPAugConfig instance. 19 | text_field: str 20 | Name of the field in the dataset where texts are. 
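    Examples
    --------
    A minimal usage sketch, mirroring the pipeline's own test suite (the
    dataset, model and column names below are only illustrative)::

        from datasets import load_dataset
        from nlpboost.augmentation import NLPAugPipeline, NLPAugConfig

        dataset = load_dataset("avacaondata/wnli_tests")["train"]
        steps = [
            NLPAugConfig(
                name="contextual_w_e",
                proportion=0.1,
                aug_kwargs={
                    "model_path": "CenIA/albert-tiny-spanish",
                    "action": "insert",
                    "device": "cpu",
                },
            )
        ]
        pipeline = NLPAugPipeline(steps=steps, text_field="sentence1")
        augmented = dataset.map(pipeline.augment, batched=True)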
21 | """ 22 | 23 | def __init__(self, steps, text_field: str = "text"): 24 | self.text_field = text_field 25 | self.pipeline = { 26 | i: { 27 | "augmenter": class_translator[config.name](**config.aug_kwargs) if config.augmenter_cls is None else config.augmenter_cls(**config.aug_kwargs), 28 | "prop": config.proportion, 29 | } 30 | for i, config in enumerate(steps) 31 | } 32 | 33 | def augment(self, samples): 34 | """ 35 | Augment data for datasets samples following the configuration defined at init. 36 | 37 | Parameters 38 | ---------- 39 | samples: 40 | Samples from a datasets.Dataset 41 | 42 | Returns 43 | ------- 44 | samples: 45 | Samples from a datasets.Dataset but processed. 46 | """ 47 | fields = [k for k in samples.keys()] 48 | new_samples = {field: [] for field in fields} 49 | for augmenter in tqdm( 50 | self.pipeline, desc="Iterating over data augmentation methods..." 51 | ): 52 | samples_selection_idxs = np.random.choice( 53 | range(len(samples[fields[0]])), 54 | size=int(self.pipeline[augmenter]["prop"] * len(samples[fields[0]])), 55 | replace=False, 56 | ) 57 | texts_augment = [ 58 | samples[self.text_field][idx] for idx in samples_selection_idxs 59 | ] 60 | augmented_texts = self.pipeline[augmenter]["augmenter"].augment( 61 | texts_augment 62 | ) 63 | for example_idx, augmented_example in zip( 64 | samples_selection_idxs, augmented_texts 65 | ): 66 | for field in fields: 67 | if field == self.text_field: 68 | new_samples[field].append(augmented_example) 69 | else: 70 | new_samples[field].append(samples[field][example_idx]) 71 | for field in tqdm(fields, desc="Updating samples batch with augmented data..."): 72 | samples[field].extend(new_samples[field]) 73 | return samples 74 | -------------------------------------------------------------------------------- /src/nlpboost/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmenter_config import NLPAugConfig 2 | from .TextAugmenterPipeline import NLPAugPipeline 3 | -------------------------------------------------------------------------------- /src/nlpboost/augmentation/augmenter_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, Any 3 | import nlpaug.augmenter.char as nac 4 | import nlpaug.augmenter.word as naw 5 | import nlpaug.augmenter.sentence as nas 6 | 7 | class_translator = { 8 | "ocr": nac.OcrAug, 9 | "contextual_w_e": naw.ContextualWordEmbsAug, 10 | "synonym": naw.SynonymAug, 11 | "backtranslation": naw.BackTranslationAug, 12 | "contextual_s_e": nas.ContextualWordEmbsForSentenceAug, 13 | "abstractive_summ": nas.AbstSummAug, 14 | } 15 | 16 | 17 | @dataclass 18 | class NLPAugConfig: 19 | """ 20 | Configuration for augmenters. 21 | 22 | Parameters 23 | ---------- 24 | name : str 25 | Name of the data augmentation technique. Possible values currently are `ocr` (for OCR augmentation), `contextual_w_e` 26 | for Contextual Word Embedding augmentation, `synonym`, `backtranslation`, `contextual_s_e` for Contextual Word Embeddings for Sentence Augmentation, 27 | `abstractive_summ`. If using a custom augmenter class this can be a random name. 28 | augmenter_cls: Any 29 | An optional augmenter class, from `nlpaug` library. Can be used instead of using an identifier name 30 | for loading the class (see param `name` of this class). 31 | proportion : float 32 | Proportion of data augmentation. 
33 | aug_kwargs : Dict 34 | Arguments for the data augmentation class. See https://github.com/makcedward/nlpaug/blob/master/example/textual_augmenter.ipynb 35 | """ 36 | 37 | name: str = field(metadata={"help": "Name of the data augmentation technique. If using a custom augmenter class this can be a random name."}) 38 | augmenter_cls: Any = field( 39 | default=None, 40 | metadata={"help": "An optional augmenter class, from `nlpaug` library. Can be used instead of using an identifier name for loading the class (see param `name` of this class)."} 41 | ) 42 | proportion: float = field( 43 | default=0.1, metadata={"help": "proportion of data augmentation"} 44 | ) 45 | aug_kwargs: Dict = field( 46 | default=None, 47 | metadata={ 48 | "help": "Arguments for the data augmentation class. See https://github.com/makcedward/nlpaug/blob/master/example/textual_augmenter.ipynb" 49 | }, 50 | ) 51 | -------------------------------------------------------------------------------- /src/nlpboost/augmentation/tests/test_text_augmenter_pipeline.py: -------------------------------------------------------------------------------- 1 | from nlpboost.augmentation import NLPAugPipeline, NLPAugConfig 2 | from datasets import load_dataset 3 | 4 | 5 | def test_aug_pipeline(): 6 | """Test for text augmenter pipeline, test if it augments quantity of data.""" 7 | dataset = load_dataset("avacaondata/wnli_tests") 8 | dataset = dataset["train"] 9 | steps = [ 10 | NLPAugConfig( 11 | name="contextual_w_e", 12 | aug_kwargs={ 13 | "model_path": "CenIA/albert-tiny-spanish", 14 | "action": "insert", 15 | "device": "cpu", 16 | }, 17 | ), 18 | ] 19 | aug_pipeline = NLPAugPipeline(steps=steps, text_field="sentence1") 20 | augmented_dataset = dataset.map(aug_pipeline.augment, batched=True) 21 | assert len(augmented_dataset[:]["sentence1"]) > len( 22 | dataset[:]["sentence1"] 23 | ), "The dataset was not augmented." 24 | -------------------------------------------------------------------------------- /src/nlpboost/ckpt_cleaner.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from tqdm import tqdm 4 | from .utils import _load_json, _save_json 5 | from typing import List, Dict 6 | 7 | 8 | class CkptCleaner: 9 | """ 10 | Clean all checkpoints that are no longer useful. 11 | 12 | Use a metrics dictionary to check the results of all runs of a model 13 | for a dataset, then sort these metrics to decide which checkpoints are 14 | removable and which are among the four best. When called, only those 15 | are kept, and all the other checkpoints are removed. This enables the 16 | user to effectively use their computer resources, so there is no need to 17 | worry about the disk usage, which is a typical concern when running multiple 18 | transformer models. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | current_folder_clean: str, 24 | current_dataset_folder: str, 25 | metrics_save_dir: str, 26 | modelname: str, 27 | mode: str = "max", 28 | try_mode: bool = False, 29 | ): 30 | self.current_folder_clean = current_folder_clean 31 | self.current_dataset_folder = current_dataset_folder 32 | self.modelname = modelname 33 | self.metrics_save_dir = metrics_save_dir 34 | self.mode = mode 35 | self.try_mode = try_mode 36 | self.last_saved_ckpt = "" 37 | 38 | os.makedirs(self.metrics_save_dir, exist_ok=True) 39 | 40 | def __call__(self, skip_last: bool = True): 41 | """ 42 | Check the metrics folder and remove checkpoints of models not performing well (all except 4 best). 
43 | 44 | Called by a scheduler, to eventually remove the undesired checkpoints. 45 | """ 46 | metricsname = f"{self.modelname}.json" 47 | if metricsname in os.listdir(self.metrics_save_dir): 48 | metrics = _load_json(os.path.join(self.metrics_save_dir, metricsname)) 49 | else: 50 | metrics = {} 51 | lista = os.listdir(self.current_folder_clean) 52 | runs_dirs = [folder for folder in lista if "run-" in folder] 53 | runs_dirs = list( 54 | sorted( 55 | runs_dirs, 56 | key=lambda x: int(x.split("-")[-1]), 57 | ) 58 | ) 59 | if skip_last: 60 | runs_dirs = runs_dirs[:-2] 61 | for run_dir in tqdm(runs_dirs): 62 | checkpoint_dirs = [ 63 | folder 64 | for folder in os.listdir( 65 | os.path.join(self.current_folder_clean, run_dir) 66 | ) 67 | if "checkpoint-" in folder 68 | ] 69 | if len(checkpoint_dirs) > 0: 70 | checkpoint_dirs = list( 71 | sorted( 72 | checkpoint_dirs, 73 | key=lambda x: int(x.split("-")[-1]), 74 | ) 75 | ) 76 | last = checkpoint_dirs[-1] 77 | trainer_state = _load_json( 78 | os.path.join( 79 | self.current_folder_clean, run_dir, last, "trainer_state.json" 80 | ) 81 | ) 82 | best_model_checkpoint = trainer_state["best_model_checkpoint"] 83 | if best_model_checkpoint not in metrics: 84 | metrics[best_model_checkpoint] = float(trainer_state["best_metric"]) 85 | _save_json( 86 | metrics, os.path.join(self.metrics_save_dir, metricsname) 87 | ) 88 | checkpoint_dirs = [ 89 | os.path.join(self.current_folder_clean, run_dir, checkpoint) 90 | for checkpoint in checkpoint_dirs 91 | ] 92 | bname = self.get_best_name(metrics) 93 | checkpoint_dirs = [ 94 | ckpt 95 | for ckpt in checkpoint_dirs 96 | if ckpt 97 | not in [self.fix_dir(best_model_checkpoint), self.fix_dir(bname)] 98 | ] 99 | if bname != self.last_saved_ckpt: 100 | print("saving new best checkpoint...") 101 | # don't need to receive the target. 102 | _ = self.save_best(bname) 103 | self.last_saved_ckpt = bname 104 | else: 105 | print("will save nothing as best model has not changed...") 106 | assert ( 107 | bname not in checkpoint_dirs 108 | ), "best_model_checkpoint should not be in checkpoint dirs." 109 | assert ( 110 | best_model_checkpoint not in checkpoint_dirs 111 | ), "best_model_checkpoint should not be in checkpoint dirs." 112 | self.remove_dirs(checkpoint_dirs) 113 | sorted_metrics = sorted(metrics, key=metrics.get, reverse=self.mode == "max") 114 | if len(sorted_metrics) > 0: 115 | print( 116 | f"For model {self.current_folder_clean} the best metric is {metrics[sorted_metrics[0]]} and the worst is {metrics[sorted_metrics[-1]]}" 117 | ) 118 | best_ckpt = sorted_metrics[0] 119 | _ = self.save_best(best_ckpt) 120 | if len(sorted_metrics) > 4: 121 | dirs_to_remove = sorted_metrics[ 122 | 4: 123 | ] # REMOVE ALL BUT BEST 4 CHECKPOINTS. 124 | self.remove_dirs(dirs_to_remove) 125 | 126 | def get_best_name(self, metrics: Dict): 127 | """ 128 | Get the path of the best performing model. 129 | 130 | Parameters 131 | ---------- 132 | metrics: Dict 133 | Metrics of all models in a dictionary. 134 | 135 | Returns 136 | ------- 137 | best: str 138 | Path to the best performing model. 139 | """ 140 | sorted_metrics = sorted(metrics, key=metrics.get, reverse=self.mode == "max") 141 | best = sorted_metrics[0] 142 | return best 143 | 144 | def save_best( 145 | self, 146 | best_model: str, 147 | ): 148 | """ 149 | Save best model. 150 | 151 | Parameters 152 | ---------- 153 | best_model: str 154 | Path of the best performing model. 
155 | 156 | Returns 157 | ------- 158 | target: str 159 | Complete path to the target directory where the best model has been copied. 160 | """ 161 | target = os.path.join( 162 | self.current_dataset_folder, f"best_ckpt_{self.modelname}" 163 | ) 164 | if os.path.exists(target) and os.path.exists(best_model): 165 | if not self.try_mode: 166 | shutil.rmtree(target) 167 | else: 168 | print( 169 | f"Al estar en try mode se hace como que se elimina el directorio {target}" 170 | ) 171 | print(f"Copiando {best_model} a {target}") 172 | if os.path.exists(best_model): 173 | if not self.try_mode: 174 | shutil.copytree( 175 | best_model, target, ignore=shutil.ignore_patterns("*optimizer*") 176 | ) 177 | if not self.try_mode: 178 | assert os.path.exists(target), "TARGET DOES NOT EXIST..." 179 | return target 180 | 181 | def fix_dir(self, dir: str): 182 | """ 183 | Fix directory path for windows file systems. 184 | 185 | Parameters 186 | ---------- 187 | dir: str 188 | Directory to fix. 189 | 190 | Returns 191 | ------- 192 | dir: str 193 | Fixed directory. 194 | """ 195 | return dir.replace("D:\\", "D:") 196 | 197 | def remove_dirs(self, checkpoint_dirs: List): 198 | """ 199 | Delete checkpoint directories. 200 | 201 | Parameters 202 | ---------- 203 | checkpoint_dirs: List 204 | List with the checkpoint directories to remove. 205 | """ 206 | for ckpt_dir in tqdm(checkpoint_dirs, desc="deleting models..."): 207 | try: 208 | if not self.try_mode: 209 | shutil.rmtree(ckpt_dir) 210 | else: 211 | print( 212 | f"Al estar en try mode se hace como que se elimina el directorio {ckpt_dir}" 213 | ) 214 | except FileNotFoundError: 215 | print(f"Se intentó eliminar el directorio {ckpt_dir} y no se pudo") 216 | -------------------------------------------------------------------------------- /src/nlpboost/default_param_spaces.py: -------------------------------------------------------------------------------- 1 | def hp_space_base(trial): 2 | """Hyperparameter space in Optuna format for base-sized models (e.g. bert-base).""" 3 | return { 4 | "learning_rate": trial.suggest_float("learning_rate", 1e-5, 7e-5, log=True), 5 | "num_train_epochs": trial.suggest_categorical( 6 | "num_train_epochs", [3, 5, 7, 10, 15, 20, 30] 7 | ), 8 | "per_device_train_batch_size": trial.suggest_categorical( 9 | "per_device_train_batch_size", [8, 16] 10 | ), 11 | "per_device_eval_batch_size": trial.suggest_categorical( 12 | "per_device_eval_batch_size", [32] 13 | ), 14 | "gradient_accumulation_steps": trial.suggest_categorical( 15 | "gradient_accumulation_steps", [1, 2, 3, 4] 16 | ), 17 | "warmup_ratio": trial.suggest_float("warmup_ratio", 0.01, 0.10, log=True), 18 | "weight_decay": trial.suggest_float("weight_decay", 1e-10, 0.3, log=True), 19 | "adam_epsilon": trial.suggest_float("adam_epsilon", 1e-10, 1e-6, log=True), 20 | } 21 | 22 | 23 | def hp_space_large(trial): 24 | """Hyperparameter space in Optuna format for large-sized models (e.g. 
bert-large).""" 25 | return { 26 | "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), 27 | "num_train_epochs": trial.suggest_categorical( 28 | "num_train_epochs", [3, 5, 7, 10, 15, 20, 30, 40, 50] 29 | ), 30 | "per_device_train_batch_size": trial.suggest_categorical( 31 | "per_device_train_batch_size", [4] 32 | ), 33 | "per_device_eval_batch_size": trial.suggest_categorical( 34 | "per_device_eval_batch_size", [16] 35 | ), 36 | "gradient_accumulation_steps": trial.suggest_categorical( 37 | "gradient_accumulation_steps", [4, 8, 12, 16] 38 | ), 39 | "warmup_ratio": trial.suggest_float("warmup_ratio", 0.01, 0.10, log=True), 40 | "weight_decay": trial.suggest_float("weight_decay", 1e-10, 0.3, log=True), 41 | "adam_epsilon": trial.suggest_float("adam_epsilon", 1e-10, 1e-6, log=True), 42 | "adam_beta2": trial.suggest_float("adam_beta2", 0.98, 0.999, log=True), 43 | } 44 | -------------------------------------------------------------------------------- /src/nlpboost/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import classification_report 2 | import numpy as np 3 | import nltk 4 | import itertools 5 | from typing import List 6 | import torch 7 | import evaluate 8 | 9 | nltk.download("punkt") 10 | 11 | metric_sum = evaluate.load("rouge") 12 | metric_seqeval = evaluate.load("seqeval") 13 | 14 | 15 | def compute_metrics_classification( 16 | pred, tokenizer=None, id2tag=None, additional_metrics=None 17 | ): 18 | """ 19 | Compute metrics for classification (multi-class or binary) tasks. 20 | 21 | Parameters 22 | ---------- 23 | pred: transformers.EvalPrediction 24 | Prediction as output by transformers.Trainer 25 | tokenizer: transformers.Tokenizer 26 | Tokenizer from huggingface. 27 | id2tag: Dict 28 | Dictionary mapping label ids to label names. 29 | additional_metrics: List 30 | List with additional metrics to compute. 31 | 32 | Returns 33 | ------- 34 | metrics: Dict 35 | Dictionary with metrics. For information regarding the exact metrics 36 | received in it, see the documentation for sklearn.metrics.classification_report. 37 | """ 38 | preds, labels = pred.predictions, pred.label_ids 39 | preds = np.argmax(preds, axis=1) 40 | class_report = classification_report(labels, preds, output_dict=True) 41 | metrics = class_report["macro avg"] 42 | return metrics 43 | 44 | 45 | def compute_metrics_multilabel( 46 | pred, tokenizer=None, id2tag=None, additional_metrics=None 47 | ): 48 | """ 49 | Compute the metrics for a multilabel task. 50 | 51 | Parameters 52 | ---------- 53 | pred: transformers.EvalPrediction 54 | Prediction as output by transformers.Trainer 55 | tokenizer: transformers.Tokenizer 56 | Tokenizer from huggingface. 57 | id2tag: Dict 58 | Dictionary mapping label ids to label names. 59 | additional_metrics: List 60 | List with additional metrics to compute. 61 | 62 | Returns 63 | ------- 64 | best_metrics: Dict 65 | Dictionary with best metrics, after trying different thresholds. 
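    Notes
    -----
    Predictions are passed through a sigmoid and binarized with every
    threshold in ``np.arange(0.1, 0.9, 0.1)``; the macro-averaged metrics of
    the threshold achieving the highest F1-score are the ones returned.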
66 | """ 67 | preds, labels = pred.predictions, pred.label_ids 68 | preds = torch.sigmoid(torch.from_numpy(preds)).numpy() 69 | thresholds = np.arange(0.1, 0.9, 0.1) 70 | best_metrics, best_metric, best_threshold = {}, 0, None 71 | 72 | for thres in thresholds: 73 | preds = preds >= thres 74 | preds = preds.astype(np.int) 75 | labels = labels.astype(np.int) 76 | class_report = classification_report( 77 | labels, 78 | preds, 79 | output_dict=True, 80 | ) 81 | metrics = class_report["macro avg"] 82 | f1 = metrics["f1-score"] 83 | if f1 > best_metric: 84 | best_metrics = metrics 85 | best_metric = f1 86 | best_threshold = thres 87 | print(f"*** The best threshold is {best_threshold} ***") 88 | return best_metrics 89 | 90 | 91 | def compute_metrics_ner(p, tokenizer=None, id2tag=None, additional_metrics=None): 92 | """ 93 | Compute metrics for ner. 94 | 95 | Use seqeval metric from HF Evaluate. Get the predicted label for each instance, 96 | then skip padded tokens and finally use seqeval metric, which takes into account 97 | full entities, not individual tokens, when computing the metrics. 98 | 99 | Parameters 100 | ---------- 101 | p: transformers.EvalPrediction 102 | Instance of EvalPrediction from transformers. 103 | tokenizer: transformers.Tokenizer 104 | Tokenizer from huggingface. 105 | id2tag: Dict 106 | Dictionary mapping label ids to label names. 107 | additional_metrics: List 108 | List with additional metrics to compute. 109 | 110 | Returns 111 | ------- 112 | Metrics 113 | Complete dictionary with all computed metrics on eval data. 114 | """ 115 | predictions, labels = p.predictions, p.label_ids 116 | 117 | try: 118 | predictions = np.argmax(predictions, axis=2) 119 | except Exception: 120 | print("The output shape is not logits-like, but directly targets.") 121 | predictions = predictions.astype("int") 122 | 123 | # Remove ignored index (special tokens) 124 | true_predictions = [ 125 | [str(id2tag[p]) for (p, i) in zip(prediction, label) if i != -100] 126 | for prediction, label in zip(predictions, labels) 127 | ] 128 | true_labels = [ 129 | [str(id2tag[i]) for (p, i) in zip(prediction, label) if i != -100] 130 | for prediction, label in zip(predictions, labels) 131 | ] 132 | metrics = metric_seqeval.compute(predictions=true_predictions, references=true_labels) 133 | metrics["f1-score"] = metrics["overall_f1"] 134 | return metrics 135 | 136 | 137 | def compute_metrics_summarization( 138 | eval_pred, tokenizer, id2tag=None, additional_metrics: List = None 139 | ): 140 | """ 141 | Compute metrics for summarization tasks, by using rouge metrics in datasets library. 142 | 143 | Parameters 144 | ---------- 145 | eval_pred: transformers.EvalPrediction 146 | Prediction as output by transformers.Trainer 147 | tokenizer: 148 | Tokenizer from huggingface. 149 | id2tag: Dict 150 | Dictionary mapping label ids to label names. 151 | additional_metrics: List 152 | List with additional metrics to compute. 153 | 154 | Returns 155 | ------- 156 | metrics: Dict 157 | Dictionary with relevant metrics for summarization. 158 | """ 159 | predictions, labels = eval_pred.predictions, eval_pred.label_ids 160 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 161 | # Replace -100 in the labels as we can't decode them. 
162 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 163 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 164 | 165 | # Rouge expects a newline after each sentence 166 | decoded_preds = [ 167 | "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds 168 | ] 169 | decoded_labels = [ 170 | "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels 171 | ] 172 | 173 | result = metric_sum.compute( 174 | predictions=decoded_preds, references=decoded_labels, use_stemmer=True 175 | ) 176 | # Extract a few results 177 | result = {key: value * 100 for key, value in result.items()} 178 | 179 | # Add mean generated length 180 | prediction_lens = [ 181 | np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions 182 | ] 183 | result["gen_len"] = np.mean(prediction_lens) 184 | result = {k: round(v, 4) for k, v in result.items()} 185 | if additional_metrics: 186 | other_results = [] 187 | for metric in additional_metrics: 188 | subre = metric.compute(predictions=decoded_preds, references=decoded_labels) 189 | other_results.append(subre) 190 | print(f"Other results for this dataset: \n {other_results}") 191 | result["other_metrics"] = other_results 192 | return result 193 | -------------------------------------------------------------------------------- /src/nlpboost/metrics_plotter.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import pandas as pd 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | import ast 7 | import numpy as np 8 | from typing import List, Dict 9 | from .utils import joinpaths 10 | import os 11 | import re 12 | 13 | 14 | class ResultsPlotter: 15 | """ 16 | Tool for plotting the results of the models trained. 17 | 18 | Parameters 19 | ---------- 20 | metrics_dir: str 21 | Directory name with metrics. 22 | model_names: List 23 | List with the names of the models. 24 | dataset_to_task_map: Dict 25 | Dictionary that maps dataset names to tasks. Can be built with the list of DatasetConfigs. 26 | remove_strs: List 27 | List of strings to remove from filename. 28 | metric_field: str 29 | Name of the field with the objective metric. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | metrics_dir: str, 35 | model_names: List, 36 | dataset_to_task_map: Dict, 37 | remove_strs: List = [], 38 | metric_field: str = "f1-score", 39 | ): 40 | self.metrics_dir = metrics_dir 41 | self.model_names = model_names 42 | self.dataset_to_task_map = dataset_to_task_map 43 | self.remove_strs = remove_strs 44 | self.metric_field = metric_field 45 | 46 | def plot_metrics(self): 47 | """Plot the metrics as a barplot.""" 48 | df_metrics = self.read_metrics() 49 | df_metrics = df_metrics.groupby( 50 | ["dataset_name", "model_name"], as_index=False 51 | ).aggregate("max") 52 | df_metrics = df_metrics.append( 53 | df_metrics.groupby(by="model_name", as_index=False).aggregate("mean") 54 | ) 55 | df_metrics.loc[df_metrics["dataset_name"].isna(), "dataset_name"] = "AVERAGE" 56 | plot = self._make_plot(df_metrics) 57 | return plot 58 | 59 | def _make_plot(self, df): 60 | """ 61 | Build the plot with the dataset in the correct format. 62 | 63 | Parameters 64 | ---------- 65 | df: pd.DataFrame 66 | DataFrame with the metrics data. 67 | 68 | Returns 69 | ------- 70 | ax: matplotlib.axes.Axes 71 | ax object as returned by matplotlib. 
72 | """ 73 | plt.rcParams["figure.figsize"] = (20, 15) 74 | plt.rcParams["xtick.labelsize"] = "large" 75 | plt.rcParams["ytick.labelsize"] = "large" 76 | 77 | ax = sns.barplot( 78 | y="dataset_name", 79 | x=self.metric_field, 80 | data=df.sort_values(["model_name", "dataset_name"]), 81 | hue="model_name", 82 | ) 83 | ax.set_xticks(np.linspace(0.0, 1.0, 25)) 84 | plt.grid(True, color="#93a1a1", alpha=0.9, linestyle="--", which="both") 85 | plt.title( 86 | "Experiments Results", 87 | size=22, 88 | fontdict={ 89 | "fontstyle": "normal", 90 | "fontfamily": "serif", 91 | "fontweight": "bold", 92 | }, 93 | ) 94 | plt.ylabel("Dataset Name", size=18, fontdict={"fontfamily": "serif"}) 95 | plt.xlabel( 96 | f"{self.metric_field}", size=18, fontdict={"fontfamily": "serif"} 97 | ) 98 | sns.despine() 99 | plt.legend(bbox_to_anchor=(0.9, 0.98), loc=3, borderaxespad=0.0) 100 | return ax 101 | 102 | def read_metrics( 103 | self, 104 | ): 105 | """Read the metrics in the self.metrics_dir directory, creating a dataset with the data.""" 106 | dics = [] 107 | files = [joinpaths(self.metrics_dir, f) for f in os.listdir(self.metrics_dir)] 108 | for file in tqdm(files, desc="reading metrics files..."): 109 | try: 110 | if ".json" in file: 111 | with open(file, "r") as f: 112 | d = json.load(f) 113 | else: 114 | with open(file, "r") as f: 115 | d = f.read() 116 | d = ast.literal_eval(d) 117 | file = ( 118 | file.replace(self.metrics_dir, "") 119 | .replace("/", "") 120 | .replace("-dropout_0.0.json", "") 121 | ) 122 | for remove_str in self.remove_strs: 123 | 124 | file = re.sub(remove_str, "", file) 125 | dataset_name = ( 126 | file.replace(".json", "").replace(".txt", "").split("#")[-1] 127 | ) 128 | model_name = file.replace(".json", "").replace(".txt", "").split("#")[0] 129 | if dataset_name not in self.dataset_to_task_map: 130 | task = "qa" 131 | else: 132 | task = self.dataset_to_task_map[dataset_name] 133 | if task == "qa" and d["f1"] > 1.0: 134 | f1 = d["f1"] * 0.01 135 | elif task == "multiple_choice": 136 | f1 = d["accuracy"] 137 | else: 138 | f1 = d[self.metric_field] 139 | newdic = { 140 | "model_name": model_name, 141 | "dataset_name": dataset_name, 142 | f"{self.metric_field}": f1, 143 | "namefile": file, 144 | "task": task, 145 | } 146 | dics.append(newdic) 147 | except Exception as e: 148 | print(e) 149 | continue 150 | return pd.DataFrame(dics) 151 | -------------------------------------------------------------------------------- /src/nlpboost/skip_mix.py: -------------------------------------------------------------------------------- 1 | class SkipMix: 2 | """ 3 | Simple class to skip mix of dataset and model. 4 | 5 | Two properties: dataset to skip and model to skip. 6 | 7 | Parameters 8 | ---------- 9 | dataset_name: str 10 | Name of the dataset, as in alias parameter of DatasetConfig. 11 | model_name: str 12 | Name of the model, as in save_name parameter of ModelConfig. 
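    Examples
    --------
    The alias and save_name values below are purely illustrative:

    >>> skip = SkipMix(dataset_name="sqac", model_name="beto")
    >>> skip.model_name
    'beto'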
13 | """ 14 | 15 | def __init__(self, dataset_name: str, model_name: str): 16 | self.dataset_name = dataset_name 17 | self.model_name = model_name 18 | -------------------------------------------------------------------------------- /src/nlpboost/tests/test_ckpt_cleaner.py: -------------------------------------------------------------------------------- 1 | from nlpboost.ckpt_cleaner import CkptCleaner 2 | from nlpboost.utils import _save_json, joinpaths 3 | import os 4 | 5 | 6 | def _create_trainer_state(best_model_checkpoint, best_metric): 7 | """Create fake trainer_state dict for checkpoint cleaner to use.""" 8 | trainer_state = { 9 | "best_model_checkpoint": best_model_checkpoint, 10 | "best_metric": best_metric, 11 | } 12 | return trainer_state 13 | 14 | 15 | def _get_total_files_directory(directory): 16 | """Get total count of files in a directory.""" 17 | return len([x[0] for x in os.walk(directory)]) 18 | 19 | 20 | def _create_current_folder_clean(current_folder_clean): 21 | """Create a folder with fake runs.""" 22 | numruns = 5 23 | num_ckpts = 10 24 | os.makedirs(current_folder_clean, exist_ok=True) 25 | for run in range(numruns): 26 | trainer_state = _create_trainer_state( 27 | joinpaths( 28 | current_folder_clean, f"run-{numruns-1}", f"checkpoint-{num_ckpts-1}" 29 | ), 30 | 200, 31 | ) 32 | for ckpt in range(num_ckpts): 33 | path = joinpaths(current_folder_clean, f"run-{run}", f"checkpoint-{ckpt}") 34 | os.makedirs(path, exist_ok=True) 35 | _save_json(trainer_state, joinpaths(path, "trainer_state.json")) 36 | 37 | 38 | def test_ckpt_cleaner(): 39 | """Test that CkptCleaner removes the correct folders.""" 40 | folder_clean = "tmp_clean_folder" 41 | _create_current_folder_clean(folder_clean) 42 | prev_files_count = _get_total_files_directory(folder_clean) 43 | dataset_folder = "tmp_dataset_folder" 44 | os.makedirs(dataset_folder, exist_ok=True) 45 | ckpt_cleaner = CkptCleaner( 46 | current_folder_clean=folder_clean, 47 | current_dataset_folder=dataset_folder, 48 | modelname="test_modelname", 49 | metrics_save_dir="test_metricsdir", 50 | ) 51 | ckpt_cleaner() 52 | ckpt_cleaner() 53 | ckpt_cleaner() 54 | ckpt_cleaner() 55 | post_files_count = _get_total_files_directory(folder_clean) 56 | assert ( 57 | post_files_count < prev_files_count 58 | ), f"Count for post is {post_files_count}, for prev is {prev_files_count}" 59 | assert os.path.exists( 60 | joinpaths(dataset_folder, "best_ckpt_test_modelname") 61 | ), "Path for best ckpt should exist but it does not." 
62 | -------------------------------------------------------------------------------- /src/nlpboost/tests/test_dataset_config.py: -------------------------------------------------------------------------------- 1 | from nlpboost import DatasetConfig 2 | 3 | 4 | def test_dataset_config(): 5 | """Test that dataset config save the correct parameters.""" 6 | fixed_train_args = { 7 | "evaluation_strategy": "epoch", 8 | "num_train_epochs": 10, 9 | "do_train": True, 10 | "do_eval": False, 11 | "logging_strategy": "epoch", 12 | "save_strategy": "epoch", 13 | "save_total_limit": 10, 14 | "seed": 69, 15 | "bf16": True, 16 | "dataloader_num_workers": 16, 17 | "adam_epsilon": 1e-8, 18 | "adam_beta1": 0.9, 19 | "adam_beta2": 0.999, 20 | "group_by_length": True, 21 | "lr_scheduler_type": "linear", 22 | "learning_rate": 1e-4, 23 | "per_device_train_batch_size": 10, 24 | "per_device_eval_batch_size": 10, 25 | "gradient_accumulation_steps": 6, 26 | "warmup_ratio": 0.08, 27 | } 28 | 29 | dataset_config = { 30 | "seed": 44, 31 | "direction_optimize": "minimize", 32 | "metric_optimize": "eval_loss", 33 | "callbacks": [], 34 | "fixed_training_args": fixed_train_args, 35 | "dataset_name": "en_es_legal", 36 | "alias": "en_es_legal", 37 | "task": "summarization", 38 | "hf_load_kwargs": {"path": "avacaondata/wnli", "use_auth_token": False}, 39 | "label_col": "target_es", 40 | "retrain_at_end": False, 41 | "num_proc": 12, 42 | "custom_eval_func": lambda p: p, 43 | } 44 | dataset_config = DatasetConfig(**dataset_config) 45 | assert ( 46 | dataset_config.task == "summarization" 47 | ), f"Task should be summarization and is {dataset_config.task}" 48 | assert ( 49 | dataset_config.label_col == "target_es" 50 | ), f"Label col should be target_es and is {dataset_config.label_col}" 51 | assert ( 52 | dataset_config.fixed_training_args["evaluation_strategy"] == "epoch" 53 | ), "Evaluation strategy should be epoch." 54 | assert dataset_config.fixed_training_args["seed"] == 69, "Seed should be 69." 
55 | -------------------------------------------------------------------------------- /src/nlpboost/tests/test_general_utils.py: -------------------------------------------------------------------------------- 1 | from nlpboost.utils import ( 2 | dict_to_list, 3 | _tokenize_dataset, 4 | get_tags, 5 | _load_json, 6 | _save_json, 7 | _parse_modelname, 8 | _fix_json, 9 | joinpaths, 10 | filter_empty, 11 | get_windowed_match_context_answer, 12 | _save_metrics, 13 | _unwrap_reference, 14 | ) 15 | from nlpboost.tokenization_functions import ( 16 | tokenize_ner, 17 | tokenize_classification, 18 | tokenize_squad, 19 | tokenize_summarization, 20 | ) 21 | from nlpboost import ModelConfig, DatasetConfig 22 | from datasets import Dataset, DatasetDict, load_dataset 23 | import pandas as pd 24 | from transformers import AutoTokenizer 25 | from functools import partial 26 | import os 27 | 28 | tok_func_map = { 29 | "ner": tokenize_ner, 30 | "qa": tokenize_squad, 31 | "summarization": tokenize_summarization, 32 | "classification": tokenize_classification, 33 | } 34 | 35 | 36 | def _get_feature_names(dataset_split): 37 | """Get feature names for a dataset split.""" 38 | return [k for k in dataset_split.features.keys()] 39 | 40 | 41 | def _label_mapper_ner(example, dataset_config, tag2id): 42 | """Map the labels for NER to use ints.""" 43 | example[dataset_config.label_col] = [ 44 | tag2id[label] for label in example[dataset_config.label_col] 45 | ] 46 | return example 47 | 48 | 49 | def _create_fake_dataset(): 50 | """Create a fake dataset to test dict_to_list.""" 51 | data_dict = { 52 | "sentence": "Hola me llamo Pedro", 53 | "entities": [ 54 | { 55 | "start_character": 14, 56 | "end_character": 19, 57 | "ent_label": "PER", 58 | "ent_text": "Pedro", 59 | } 60 | ], 61 | } 62 | data_dict2 = { 63 | "sentence": "Hola me llamo Pedro y ya está.", 64 | "entities": [ 65 | { 66 | "start_character": 14, 67 | "end_character": 19, 68 | "ent_label": "PER", 69 | "ent_text": "Pedro", 70 | } 71 | ], 72 | } 73 | data_dict3 = { 74 | "sentence": "Hola me llamo Pedro y él Manuel.", 75 | "entities": [ 76 | { 77 | "start_character": 14, 78 | "end_character": 19, 79 | "ent_label": "PER", 80 | "ent_text": "Pedro", 81 | }, 82 | { 83 | "start_character": 25, 84 | "end_character": 31, 85 | "ent_label": "PER", 86 | "ent_text": "Pedro", 87 | }, 88 | ], 89 | } 90 | data_dict_empty = { 91 | "sentence": "Hola me llamo Pedro", 92 | "entities": [], 93 | } 94 | data = [data_dict] * 3 95 | data.append(data_dict_empty) # for testing empty data also. 96 | data.append(data_dict2) 97 | data.append(data_dict3) 98 | df = pd.DataFrame(data) 99 | dataset = Dataset.from_pandas(df) 100 | return dataset 101 | 102 | 103 | def test_save_load_json(): 104 | """Test load and save functions for jsons.""" 105 | d = {"a": "a"} 106 | _save_json(d, "prueba.json") 107 | assert os.path.exists("prueba.json"), "No funciona el guardado de json" 108 | b = _load_json("prueba.json") 109 | assert b == d, "The saved and loaded objects are not equal." 
110 | 111 | 112 | def test_joinpaths(): 113 | """Test whether joinpaths work correctly.""" 114 | p1 = "a" 115 | p2 = "b" 116 | ptotal = "a/b" 117 | result = joinpaths(p1, p2) 118 | assert result == ptotal, f"The obtained path: {result} doesn't coincide: {ptotal}" 119 | 120 | 121 | def test_filter_empty(): 122 | """Test that filter_empty filters empty chars.""" 123 | lista = ["a", "", " ", "b"] 124 | result = list(filter(filter_empty, lista)) 125 | assert len(result) < len(lista), "The length of the new list should be shorter." 126 | assert all([c not in result for c in ["", " "]]), "There are empty characters." 127 | 128 | 129 | def test_dict_to_list(): 130 | """Test dict_to_list function to parse NER tasks data.""" 131 | dataset = _create_fake_dataset() 132 | dataset = dataset.map(dict_to_list, batched=False) 133 | assert "token_list" in dataset[0], "token list should be in dataset." 134 | assert "label_list" in dataset[0], "label list should be in dataset." 135 | assert dataset[0]["token_list"][-1] == "Pedro", "Pedro should be the last token." 136 | assert dataset[0]["label_list"][0] == "O", "First label should be O." 137 | assert dataset[0]["label_list"][-1] == "PER", "Last label should be PER." 138 | 139 | 140 | def test_get_tags(): 141 | """Test get tags function.""" 142 | dataset = _create_fake_dataset() 143 | dataset = dataset.map(dict_to_list, batched=False) 144 | dataset = DatasetDict({"train": dataset}) 145 | dataconfig = DatasetConfig( 146 | dataset_name="prueba", 147 | alias="prueba", 148 | task="ner", 149 | fixed_training_args={}, 150 | num_proc=1, 151 | text_field="token_list", 152 | label_col="label_list", 153 | ) 154 | tags = get_tags(dataset, dataconfig) 155 | entities_should_be = ["O", "PER"] 156 | assert ( 157 | len(tags) == 2 158 | ), f"Only 2 different labels were presented, but length of tags is {len(tags)}" 159 | assert all( 160 | [ent in tags for ent in entities_should_be] 161 | ), "Not all entities were captured by get tags." 
162 | 163 | 164 | def test_tokenize_dataset(): 165 | """Test that _tokenize_dataset effectively tokenizes the dataset.""" 166 | tokenizer = AutoTokenizer.from_pretrained("CenIA/albert-tiny-spanish") 167 | modelconfig = ModelConfig( 168 | save_name="prueba_tokenize_dataset", 169 | name="prueba_tokenize_dataset", 170 | hp_space=lambda trial: trial, 171 | ) 172 | dataconfig = DatasetConfig( 173 | dataset_name="prueba", 174 | alias="prueba", 175 | task="ner", 176 | fixed_training_args={}, 177 | num_proc=1, 178 | text_field="token_list", 179 | label_col="label_list", 180 | ) 181 | dataset = _create_fake_dataset() 182 | dataset = DatasetDict({"train": dataset}) 183 | feat_names_prev = _get_feature_names(dataset["train"]) 184 | dataset = dataset.map(dict_to_list, batched=False) 185 | tags = get_tags(dataset, dataconfig) 186 | tag2id = {t: i for i, t in enumerate(sorted(tags))} 187 | dataset = dataset.map( 188 | partial(_label_mapper_ner, dataset_config=dataconfig, tag2id=tag2id) 189 | ) 190 | tokenized_dataset = _tokenize_dataset( 191 | tokenizer, tok_func_map, dataset, dataconfig, modelconfig 192 | ) 193 | feat_names_post = _get_feature_names(tokenized_dataset["train"]) 194 | assert ( 195 | feat_names_post != feat_names_prev 196 | ), f"Posterior names: \n {feat_names_post} \n should be different from pre: \n {feat_names_prev}" 197 | assert isinstance( 198 | tokenized_dataset["train"][0]["input_ids"][0], int 199 | ), f"Input ids should be ints" 200 | partial_custom_tok_func_call = partial( 201 | tokenize_ner, tokenizer=tokenizer, dataset_config=dataconfig 202 | ) 203 | setattr(modelconfig, "partial_custom_tok_func_call", partial_custom_tok_func_call) 204 | tokenized_alternative = _tokenize_dataset( 205 | tokenizer, tok_func_map, dataset, dataconfig, modelconfig 206 | ) 207 | feat_names_post2 = _get_feature_names(tokenized_alternative["train"]) 208 | assert ( 209 | feat_names_post2 != feat_names_prev 210 | ), f"Posterior names: \n {feat_names_post} \n should be different from pre: \n {feat_names_prev}" 211 | assert isinstance( 212 | tokenized_alternative["train"][0]["input_ids"][0], int 213 | ), "Input ids should be ints." 214 | 215 | 216 | def test_get_windowed_match_context_answer(): 217 | """Test that the matching of context-answer works.""" 218 | context = "La respuesta a cuál es el rey de España es Juan Carlos Mencía según dicen algunos expertos en la materia que se hacen llamar mencistas." 219 | answer = "Juan Carlos I." 220 | beg, end, new_answer = get_windowed_match_context_answer( 221 | context, answer, maxrange=4 222 | ) 223 | assert isinstance(beg, int), "Beginning index should be int." 224 | assert isinstance(end, int), "Ending index should be int." 225 | assert isinstance(new_answer, str), "The new answer should be a str" 226 | assert "Juan Carlos" in new_answer, "Juan Carlos should be in new answer." 227 | 228 | 229 | def test_fix_json(): 230 | """Test if jsons are fixed.""" 231 | metrics = [{"metric": 1}] 232 | metrics_fixed = _fix_json(metrics) 233 | assert isinstance( 234 | metrics_fixed[0]["metric"], float 235 | ), "Ints were not converted to float." 236 | metrics2 = [{"metric": {"metric": 1}}] 237 | metrics_fixed = _fix_json(metrics2) 238 | assert isinstance( 239 | metrics_fixed[0]["metric"]["metric"], float 240 | ), "Ints were not converted to float." 
241 | 242 | 243 | def test_parse_modelname(): 244 | """Test if model names are correctly parsed.""" 245 | modname = "hola/cocacola" 246 | parsed = _parse_modelname(modname) 247 | assert "/" not in parsed, "/ should not be in parsed name." 248 | 249 | 250 | def test_save_metrics(): 251 | """Test that metrics are saved.""" 252 | metrics = {"rouge2": 0.12, "rouge1": 0.30} 253 | metricsdir = "pruebametrics" 254 | os.makedirs(metricsdir, exist_ok=True) 255 | _save_metrics(metrics, "modelometrics", "datasetmetrics", metricsdir) 256 | assert os.path.exists( 257 | joinpaths(metricsdir, "modelometrics#datasetmetrics.json") 258 | ), "El fichero de metrics no ha sido guardado." 259 | try: 260 | _save_metrics(metrics, "modelometrics", "datasetmetrics", "metricsfalso") 261 | except Exception as e: 262 | print("Ha fallado save metrics donde tiene que fallar.") 263 | 264 | 265 | def test_unwrap_reference(): 266 | """Test the unwrapping of a QA reference.""" 267 | reference_simple = {"id": "A", "answers": "A"} 268 | unwrapped_simple = _unwrap_reference(reference_simple) 269 | assert isinstance( 270 | unwrapped_simple, list 271 | ), f"should return list when dict is passed but is: {type(reference_simple)}" 272 | assert ( 273 | unwrapped_simple[0] == reference_simple 274 | ), "This should just be a list around the dict." 275 | reference_multiple = [ 276 | {"id": "A", "answers": {"text": "a", "start": 0}}, 277 | {"id": "A", "answers": {"text": "b", "start": 2}}, 278 | ] 279 | unwrapped_complex = _unwrap_reference(reference_multiple) 280 | assert ( 281 | len(unwrapped_complex) == 2 282 | ), f"The length of unwrapped complex should be 2 and is {len(unwrapped_complex)}" 283 | -------------------------------------------------------------------------------- /src/nlpboost/tests/test_model_config.py: -------------------------------------------------------------------------------- 1 | from nlpboost import ModelConfig 2 | from transformers import MarianMTModel, Seq2SeqTrainer 3 | 4 | 5 | def test_model_config(): 6 | """Test that model config saves some parameters correctly.""" 7 | 8 | def tokenize_dataset(examples): 9 | return examples 10 | 11 | def hp_space(trial): 12 | return trial 13 | 14 | marianmt_config = { 15 | "max_length_summary": 512, 16 | "n_trials": 1, 17 | "save_dir": "prueba_marianmt_savedir", 18 | "random_init_trials": 1, 19 | "name": "Helsinki-NLP/opus-mt-en-es", 20 | "save_name": "testname", 21 | "hp_space": hp_space, 22 | "num_beams": 4, 23 | "trainer_cls_summarization": Seq2SeqTrainer, 24 | "model_cls_summarization": MarianMTModel, 25 | "custom_tokenization_func": tokenize_dataset, 26 | "only_test": False, 27 | } 28 | 29 | marianmt_config = ModelConfig(**marianmt_config) 30 | assert ( 31 | marianmt_config.save_name == "testname" 32 | ), f"The name should be testname and is {marianmt_config.save_name}" 33 | assert ( 34 | marianmt_config.num_beams == 4 35 | ), f"Number of beams should be 4 and is {marianmt_config.num_beams}" 36 | -------------------------------------------------------------------------------- /src/nlpboost/tokenization_functions.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tokenizers 3 | import evaluate 4 | import collections 5 | from tqdm import tqdm 6 | import numpy as np 7 | from functools import partial 8 | from .utils import match_questions_multiple_answers 9 | 10 | 11 | def tokenize_classification(examples, tokenizer, dataset_config): 12 | """ 13 | Tokenize classification datasets. 
14 | 15 | Given a dataset, a tokenizer and a dataset configuration, returns 16 | the tokenized dataset. 17 | 18 | Parameters 19 | ---------- 20 | examples: datasets.Dataset 21 | Samples from datasets.Dataset. 22 | tokenizer: tokenizers.Tokenizer 23 | Instance of hf's tokenizer. 24 | dataset_config: benchmarker.DatasetConfig 25 | Instance of a Dataset Config. 26 | 27 | Returns 28 | ------- 29 | tokenized: 30 | Tokenized samples. 31 | """ 32 | if dataset_config.is_2sents: 33 | tokenized = tokenizer( 34 | examples[dataset_config.sentence1_field], 35 | examples[dataset_config.sentence2_field], 36 | truncation=True, 37 | padding="longest", 38 | max_length=512, 39 | ) 40 | else: 41 | tokenized = tokenizer( 42 | examples[dataset_config.text_field], 43 | truncation=True, 44 | padding="longest", 45 | max_length=512, 46 | ) 47 | if not dataset_config.is_multilabel: 48 | tokenized["labels"] = examples[dataset_config.label_col] 49 | else: 50 | columns_not_text = list( 51 | sorted([col for col in examples if dataset_config.text_field not in col]) 52 | ) 53 | labels = [ 54 | [float(examples[col][i]) for col in columns_not_text] 55 | for i in range(len(examples[dataset_config.text_field])) 56 | ] 57 | tokenized["labels"] = labels 58 | return tokenized 59 | 60 | 61 | def tokenize_ner(examples, tokenizer, dataset_config): 62 | """ 63 | Tokenize a dataset or dataset split. 64 | 65 | This function is intended to be used inside the map method for the Dataset. 66 | 67 | Parameters 68 | ---------- 69 | examples: datasets.Dataset 70 | Samples from datasets.Dataset. 71 | tokenizer: tokenizers.Tokenizer 72 | Instance of hf's tokenizer. 73 | dataset_config: benchmarker.DatasetConfig 74 | Instance of a Dataset Config. 75 | 76 | Returns 77 | ------- 78 | tokenized: 79 | Tokenized samples. 80 | """ 81 | ignore_index = -100 82 | tokenized = tokenizer( 83 | examples[dataset_config.text_field], 84 | truncation=True, 85 | is_split_into_words=True, 86 | padding="longest", 87 | max_length=512, 88 | ) 89 | 90 | labels = [] 91 | for i, label in enumerate(examples[dataset_config.label_col]): 92 | word_ids = tokenized.word_ids(batch_index=i) 93 | label_ids = [] 94 | for word_idx in word_ids: 95 | # Special tokens have a word id that is None. We set the label to -100 so 96 | # they are automatically ignored in the loss function. 97 | if word_idx is None: 98 | label_ids.append(ignore_index) 99 | else: 100 | label_ids.append(label[word_idx]) 101 | labels.append(label_ids) 102 | 103 | tokenized["labels"] = labels 104 | return tokenized 105 | 106 | 107 | def tokenize_squad(examples, tokenizer, dataset_config=None, pad_on_right=True): 108 | """ 109 | Tokenize samples of squad-like datasets, on batches. 110 | 111 | It differentiates between BPE tokenizers and others 112 | as there are errors in these ones if they are processed in the conventional way. 113 | 114 | Parameters 115 | ---------- 116 | examples: datasets.Dataset 117 | Samples from datasets.Dataset. 118 | tokenizer: tokenizers.Tokenizer 119 | Instance of hf's tokenizer. 120 | pad_on_right: bool 121 | Whether or not to pad the samples on the right side. True for most models. 122 | 123 | Returns 124 | ------- 125 | tokenized_examples: 126 | Tokenized samples. 
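    Examples
    --------
    A minimal sketch of how this function can be applied over a SQuAD-formatted
    split with ``datasets.Dataset.map``. The model name and the ``dataset``
    variable are illustrative (any ``datasets.Dataset`` with ``question``,
    ``context`` and ``answers`` columns), and ``remove_columns`` is needed
    because long contexts may yield several features per example::

        from functools import partial
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
        tokenized = dataset.map(
            partial(tokenize_squad, tokenizer=tokenizer),
            batched=True,
            remove_columns=dataset.column_names,
        )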
127 | """ 128 | tokenized_examples = tokenizer( 129 | examples["question" if pad_on_right else "context"], 130 | examples["context" if pad_on_right else "question"], 131 | truncation="only_second" if pad_on_right else "only_first", 132 | max_length=512, 133 | stride=128, 134 | return_overflowing_tokens=True, 135 | return_offsets_mapping=True, 136 | padding="max_length", 137 | ) 138 | # Since one example might give us several features if it has a long context, we need a map from a feature to 139 | # its corresponding example. This key gives us just that. 140 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 141 | # The offset mappings will give us a map from token to character position in the original context. This will 142 | # help us compute the start_positions and end_positions. 143 | offset_mapping = tokenized_examples.pop("offset_mapping") 144 | # Let's label those examples! 145 | tokenized_examples["start_positions"] = [] 146 | tokenized_examples["end_positions"] = [] 147 | for i, offsets in enumerate(offset_mapping): 148 | # We will label impossible answers with the index of the CLS token. 149 | input_ids = tokenized_examples["input_ids"][i] 150 | cls_index = input_ids.index(tokenizer.cls_token_id) 151 | 152 | # Grab the sequence corresponding to that example (to know what is the context and what is the question). 153 | sequence_ids = tokenized_examples.sequence_ids(i) 154 | 155 | # One example can give several spans, this is the index of the example containing this span of text. 156 | sample_index = sample_mapping[i] 157 | answers = examples["answers"][sample_index] 158 | # If no answers are given, set the cls_index as answer. 159 | if len(answers["answer_start"]) == 0: 160 | tokenized_examples["start_positions"].append(cls_index) 161 | tokenized_examples["end_positions"].append(cls_index) 162 | else: 163 | # Start/end character index of the answer in the text. 164 | start_char = answers["answer_start"][0] 165 | end_char = start_char + len(answers["text"][0]) 166 | # Start token index of the current span in the text. 167 | token_start_index = 0 168 | while sequence_ids[token_start_index] != (1 if pad_on_right else 0): 169 | token_start_index += 1 170 | 171 | # End token index of the current span in the text. 172 | token_end_index = len(input_ids) - 1 173 | while sequence_ids[token_end_index] != (1 if pad_on_right else 0): 174 | token_end_index -= 1 175 | 176 | # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). 177 | if not ( 178 | offsets[token_start_index][0] <= start_char 179 | and offsets[token_end_index][1] >= end_char 180 | ): 181 | tokenized_examples["start_positions"].append(cls_index) 182 | tokenized_examples["end_positions"].append(cls_index) 183 | else: 184 | # Otherwise move the token_start_index and token_end_index to the two ends of the answer. 185 | # Note: we could go after the last offset if the answer is the last word (edge case). 186 | while ( 187 | token_start_index < len(offsets) 188 | and offsets[token_start_index][0] <= start_char 189 | ): 190 | token_start_index += 1 191 | tokenized_examples["start_positions"].append(token_start_index - 1) 192 | while offsets[token_end_index][1] >= end_char: 193 | token_end_index -= 1 194 | tokenized_examples["end_positions"].append(token_end_index + 1) 195 | return tokenized_examples 196 | 197 | 198 | def tokenize_summarization(examples, tokenizer, dataset_config): 199 | """ 200 | Tokenization function for summarization tasks. 
201 | 202 | Parameters 203 | ---------- 204 | examples: datasets.Dataset 205 | Samples from datasets.Dataset. 206 | tokenizer: tokenizers.Tokenizer 207 | Instance of hf's tokenizer. 208 | dataset_config: nlpboost.DatasetConfig 209 | Instance of a Dataset Config. 210 | 211 | Returns 212 | ------- 213 | model_inputs: 214 | Tokenized samples with all necessary fields (inputs plus summary labels). 215 | """ 216 | model_inputs = tokenizer( 217 | examples[dataset_config.text_field], 218 | truncation=True, 219 | max_length=tokenizer.model_max_length, 220 | ) 221 | with tokenizer.as_target_tokenizer(): 222 | labels = tokenizer( 223 | examples[dataset_config.summary_field], 224 | max_length=dataset_config.max_length_summary, 225 | truncation=True, 226 | ) 227 | model_inputs["labels"] = labels["input_ids"] 228 | return model_inputs 229 | --------------------------------------------------------------------------------
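Usage sketch: the tokenization functions above all take (examples, tokenizer, dataset_config) and are designed to be passed to datasets.map, typically through functools.partial, which the module already imports. The snippet below is a minimal, hypothetical example for tokenize_classification only; ToyDatasetConfig is an illustrative stand-in exposing just the attributes that function reads (in nlpboost this role is played by DatasetConfig), and the model name and sample data are arbitrary.

from dataclasses import dataclass
from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

from nlpboost.tokenization_functions import tokenize_classification


@dataclass
class ToyDatasetConfig:
    # Hypothetical stand-in for nlpboost's DatasetConfig; it only provides the
    # attributes read by tokenize_classification.
    text_field: str = "text"
    label_col: str = "label"
    is_2sents: bool = False
    is_multilabel: bool = False


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
data = Dataset.from_dict({"text": ["great movie", "terrible plot"], "label": [1, 0]})

# Bind tokenizer and config so datasets.map only has to supply each batch of examples.
tokenize_fn = partial(
    tokenize_classification, tokenizer=tokenizer, dataset_config=ToyDatasetConfig()
)
tokenized = data.map(tokenize_fn, batched=True, remove_columns=data.column_names)
print(tokenized.column_names)  # e.g. ['input_ids', 'token_type_ids', 'attention_mask', 'labels']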