├── .gitignore ├── CLI.md ├── LICENSE.md ├── README.md ├── poetry.lock ├── pyproject.toml ├── smaug ├── __init__.py ├── _itertools.py ├── broadcast.py ├── cli │ ├── __init__.py │ ├── accelerator.py │ ├── augment.py │ ├── config.py │ ├── context.py │ ├── fmt.py │ ├── io.py │ ├── param.py │ ├── pipeline.py │ ├── processor.py │ ├── transform.py │ └── validation.py ├── core.py ├── frozen.py ├── functional.py ├── models │ ├── __init__.py │ ├── stanza.py │ └── transformers.py ├── more_functools.py ├── ops │ ├── __init__.py │ ├── detection.py │ ├── lang_model.py │ ├── masking.py │ ├── modification.py │ ├── nli.py │ ├── pos_tagging.py │ ├── sentence.py │ ├── sentence_comparison.py │ └── text_generation.py ├── perturb │ ├── __init__.py │ ├── delete_random_words.py │ ├── delete_span_between_punctuation.py │ ├── insert_text_span.py │ ├── negate.py │ ├── swap_named_entity.py │ ├── swap_number.py │ └── swap_poisson_span.py ├── promote.py └── random.py └── tests ├── __init__.py ├── test_broadcast.py ├── test_detection.py ├── test_frozen.py ├── test_itertools.py ├── test_masking.py ├── test_modification.py ├── test_more_functools.py ├── test_promotion.py ├── test_sentence.py ├── test_sentence_comparison.py └── test_transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # JetBrains 132 | .idea/ 133 | .fleet/ 134 | 135 | .vscode -------------------------------------------------------------------------------- /CLI.md: -------------------------------------------------------------------------------- 1 | ## Command Line Interface 2 | 3 | This document describes the command line interface provided by this package. There are three types of operations: Transforms, Validations and Utilities. 4 | 5 | Transforms take as input a sentence and produce one or multiple perturbed sentences. 6 | 7 | Validations receive an original sentence and a perturbed sentence and verify if the perturbed sentence complies with some requirement. 8 | 9 | > ***NOTE***: Each transform defines some *default* validations to be executed, which can be disabled with the `no-post-run` flag. 10 | 11 | Utilities offer functions to perform common operations, such as reading and writing files. 12 | 13 | ## Configuration File Specification 14 | 15 | The CLI tool can be used with a `yaml` configuration file as follows: 16 | 17 | ``` 18 | augment --cfg 19 | ``` 20 | 21 | An example of a configuration file is: 22 | 23 | ```yaml 24 | pipeline: 25 | - cmd: io-read-csv 26 | path: 27 | - cmd: transf-neg 28 | - cmd: transf-ins-text 29 | validations: 30 | - cmd: val-keep-geq-edit-dist 31 | distance: 8 32 | level: word 33 | - cmd: val-rm-pattern 34 | pattern: hello-world 35 | - cmd: io-write-json 36 | path: 37 | seed: 42 38 | no-post-run: False 39 | ``` 40 | 41 | The first section, `pipeline`, is mandatory and specifies the list of all commands to be executed. After that section, other CLI arguments can be specified (such as `seed` in this example). The arguments are the same as in the CLI command, but without the leading `--`. Boolean flags also omit the `--` and can take the value True or False. 42 | 43 | 44 | Inside the pipeline section, each command is identified by the `cmd` tag. The remaining tags in the command entry are the arguments for the command. 45 | 46 | Inside transforms, a special `validations` tag can be used to specify validations for that command only. Validations for all previous transforms can be specified as a regular command in the pipeline. In the above example, `val-keep-geq-edit-dist` is only applied to `transf-ins-text`, but `val-rm-pattern` is applied to both `transf-neg` and `transf-ins-text`. 47 | 48 | ## Transforms 49 | 50 | ### transf-swp-ne 51 | 52 | Detects a single named entity with a [Stanza NER model](https://stanfordnlp.github.io/stanza/available_models.html#available-ner-models) and swaps it for text generated with [Google's mT5](https://arxiv.org/abs/2010.11934). 53 | 54 | ### transf-swp-num 55 | 56 | Detects a single number with RegEx and swaps it for text generated with [Google's mT5](https://arxiv.org/abs/2010.11934).
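For intuition, the regex-based detection step behind this transform can be sketched as follows. This is an illustrative example only: the exact pattern and the logic for picking which match to replace in `smaug` may differ.

```python
import re

# Hypothetical pattern: matches integers and decimal numbers in a sentence.
NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?")

sentence = "The package was delivered in 3 days and weighed 2.5 kg."
matches = list(NUMBER_RE.finditer(sentence))

# One of the detected spans would then be masked and regenerated with mT5.
print([(m.group(), m.span()) for m in matches])  # [('3', (29, 30)), ('2.5', (48, 51))]
```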
57 | 58 | ### transf-swp-poisson-span 59 | 60 | Replaces a span of text, with size drawn from a Poisson distribution, with text generated by [Google's mT5](https://arxiv.org/abs/2010.11934). 61 | 62 | ### transf-neg 63 | 64 | Negates an English sentence using [PolyJuice](https://arxiv.org/abs/2101.00288) conditioned for negation. 65 | 66 | ### transf-ins-text 67 | 68 | Inserts random text in multiple places using [Google's mT5](https://arxiv.org/abs/2010.11934). 69 | 70 | ### transf-del-punct-span 71 | 72 | Removes a single span between two punctuation symbols `.,?!`. 73 | 74 | The following sections detail the available validations and utilities: 75 | 76 | ## Validations 77 | 78 | ### val-rm-eq 79 | 80 | Verifies if the perturbed sentence is different from the original sentence using string comparison. Useful if the transform may return the original sentence. 81 | 82 | ### val-rm-pattern 83 | 84 | Verifies if the perturbed sentence does not match a specific regular expression. Useful with language models that may leave special tokens behind. 85 | 86 | ### val-keep-contradiction 87 | 88 | Verifies if the perturbed sentence contradicts the original sentence. Relies on a [RoBERTa](https://arxiv.org/abs/1907.11692) model trained for MNLI. 89 | 90 | ### val-keep-eq-ne 91 | 92 | Verifies if the perturbed and original sentences have the same number of named entities, using a [Stanza NER model](https://stanfordnlp.github.io/stanza/available_models.html#available-ner-models) to detect them. 93 | 94 | ### val-keep-eq-num 95 | 96 | Verifies if the perturbed and original sentences have the same number of numbers, using RegEx to detect them. 97 | 98 | ### val-keep-leq-char-ins 99 | 100 | Verifies if the perturbed sentence has a number of specific character insertions below a threshold, when compared to the original. 101 | 102 | ### val-keep-geq-edit-dist 103 | 104 | Verifies if the perturbed and original sentences have a [minimum edit distance](https://web.stanford.edu/class/cs124/lec/med.pdf) above a threshold. 105 | 106 | ## Utilities 107 | 108 | ### io-read-lines 109 | 110 | Reads sentences from a text file, where each line is a sentence. 111 | 112 | ### io-read-csv 113 | 114 | Reads sentences from a CSV file. Each line of the file has the sentence language and the sentence to perturb, in the format `<language>,<sentence>`. 115 | 116 | ### io-write-json 117 | 118 | Writes the perturbed sentences in a human-readable JSON format. Each input sentence has a respective output JSON object (in the order of the input). Each JSON object has the original sentence, a dictionary with the perturbations identified by the transform name, and metadata for each transform (also identified by the transform name). -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity.
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SMAUG: Sentence-level Multilingual AUGmentation 2 | 3 | `smaug` is a package for multilingual data augmentation. It offers transformations focused on changing specific aspects of sentences, such as Named Entities, Numbers, etc. 
4 | 5 | # Getting Started 6 | 7 | To start using `smaug`, you can install it with `pip`: 8 | 9 | ``` 10 | pip install unbabel-smaug 11 | ``` 12 | 13 | To run a simple pipeline with all transforms and default validations, first create the following `yaml` file: 14 | 15 | ```yaml 16 | pipeline: 17 | - cmd: io-read-lines 18 | path: 19 | lang: 20 | - cmd: transf-swp-ne 21 | - cmd: transf-swp-num 22 | - cmd: transf-swp-poisson-span 23 | - cmd: transf-neg 24 | - cmd: transf-ins-text 25 | - cmd: transf-del-punct-span 26 | - cmd: io-write-json 27 | path: 28 | # Remove this line for no seed 29 | seed: 30 | ``` 31 | 32 | Then run the following command: 33 | 34 | ```shell 35 | augment --cfg 36 | ``` 37 | 38 | # Usage 39 | 40 | The `smaug` package can be used as a command line interface (CLI) or by directly importing and calling the package Python API. To use `smaug`, first install it by following these [instructions](#getting-started). 41 | 42 | ## Command Line Interface 43 | 44 | The CLI offers a way to read, transform, validate and write perturbed sentences to files. For more information, see the [full details](CLI.md). 45 | 46 | ### Configuration File 47 | 48 | The easiest way to run `smaug` is through a configuration file (see the [full specification](CLI.md#configuration-file-specification)) that specifies an entire pipeline (as shown in the [Getting Started](#getting-started) section), using the following command: 49 | 50 | ```shell 51 | augment --cfg 52 | ``` 53 | 54 | ### Single transform 55 | 56 | As an alternative, you can use the command line to directly specify the pipeline to apply. To apply a single transform to a set of sentences, execute the following command: 57 | 58 | ```shell 59 | augment io-read-lines -p -l io-write-json -p 60 | ``` 61 | 62 | > `` is the name of the transform to apply (see this [section](CLI.md#transforms) for a list of available transforms). 63 | > 64 | > `` is a text file with one sentence per line. 65 | > 66 | > `` is a two-character language code for the input sentences. 67 | > 68 | > `` is a JSON file to be created with the transformed sentences. 69 | 70 | ### Multiple Transforms 71 | 72 | To apply multiple transforms, just specify them in arbitrary order between the read and write operations: 73 | 74 | ```shell 75 | augment io-read-lines -p -l ... io-write-json -p 76 | ``` 77 | 78 | ### Multiple Inputs 79 | 80 | To read from multiple input files, also specify them in arbitrary order: 81 | 82 | ```shell 83 | augment io-read-lines -p -l io-read-lines -p -l ... ... io-write-json -p 84 | ``` 85 | 86 | You can further have multiple languages in a given file by structuring each line as `<language>,<sentence>` and using the following command: 87 | 88 | ```shell 89 | augment io-read-csv -p ... io-write-json -p 90 | ``` 91 | 92 | # Developing 93 | 94 | To develop this package, execute the following steps: 95 | 96 | * Install the [poetry](https://python-poetry.org/docs/#installation) tool for dependency management. 97 | 98 | * Clone this git repository and install the project.
99 | 100 | ``` 101 | git clone https://github.com/Unbabel/smaug.git 102 | cd smaug 103 | poetry install 104 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "unbabel-smaug" 3 | version = "0.1.3" 4 | description = "Sentence-level Multilingual Augmentation" 5 | license = "Apache-2.0" 6 | authors = ["Duarte Alves "] 7 | readme = "README.md" 8 | repository = "https://github.com/Unbabel/smaug" 9 | keywords = [ 10 | "Natural Language Processing", 11 | "Data Augmentation" 12 | ] 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Environment :: Console", 16 | "Intended Audience :: Science/Research", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 18 | ] 19 | packages = [ 20 | { include = "smaug" }, 21 | ] 22 | include = [ 23 | "README.md", 24 | "LICENSE.md", 25 | ] 26 | 27 | [tool.poetry.scripts] 28 | augment = "smaug.cli:augment" 29 | 30 | [tool.poetry.dependencies] 31 | python = "^3.8" 32 | numpy = "^1.21.4" 33 | pandas = "^1.3.4" 34 | sentencepiece = "!=0.1.96" 35 | stanza = "^1.3.0" 36 | torch = "^1.8.2,!=1.13.0" 37 | transformers = "^4.15.0" 38 | nltk = "^3.7" 39 | PyYAML = "^6.0" 40 | packaging = "^21.3" 41 | 42 | [tool.poetry.group.dev.dependencies] 43 | pytest = "^7" 44 | black = "^22.3.0" 45 | 46 | [build-system] 47 | requires = ["poetry-core>=1.0.0"] 48 | build-backend = "poetry.core.masonry.api" 49 | -------------------------------------------------------------------------------- /smaug/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.3" 2 | -------------------------------------------------------------------------------- /smaug/_itertools.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import typing 3 | 4 | 5 | def take(iterable: typing.Iterable, n: int) -> typing.List: 6 | """Return first n items of the iterable as a list. 7 | 8 | Based on the method in itertools recipes in 9 | https://docs.python.org/3/library/itertools.html 10 | """ 11 | return list(itertools.islice(iterable, n)) 12 | 13 | 14 | def repeat_items(iterable: typing.Iterable, n: int) -> typing.Iterable: 15 | """Repeats each item in an iterable n times. 16 | 17 | This function transforms ['A', 'B', 'C'] (n=2) -> ['A', 'A', 'B', 'B', 'C', 'C'] 18 | """ 19 | repeated_iterables = map(lambda x: itertools.repeat(x, n), iterable) 20 | return itertools.chain.from_iterable(repeated_iterables) 21 | 22 | 23 | def unique_everseen(iterable, key=None): 24 | """List unique elements, preserving order. Remember all elements ever seen. 
25 | 26 | unique_everseen('AAAABBBCCDAABBB') --> A B C D 27 | unique_everseen('ABBCcAD', str.lower) --> A B C D 28 | 29 | Based on the method in itertools recipes in 30 | https://docs.python.org/3/library/itertools.html 31 | """ 32 | 33 | seen = set() 34 | 35 | if key is None: 36 | key = lambda x: x 37 | 38 | for element in iterable: 39 | k = key(element) 40 | if k not in seen: 41 | seen.add(k) 42 | yield element 43 | -------------------------------------------------------------------------------- /smaug/broadcast.py: -------------------------------------------------------------------------------- 1 | from smaug import _itertools 2 | from smaug.core import Data 3 | 4 | from typing import Tuple 5 | 6 | 7 | def broadcast_data(*values: Data) -> Tuple[Data, ...]: 8 | """Broadcasts all values to the length of the longest Data object. 9 | 10 | All objects must either have length 1 or the length of the longest object. 11 | 12 | Args: 13 | *values (Data): values to broadcast. 14 | 15 | Raises: 16 | ValueError: If the length of some data object is neither the 17 | target length or 1. 18 | 19 | Returns: 20 | Tuple[Data, ...]: tuple with the broadcasted values. This tuple 21 | has one value for each received argument, corresponding to the 22 | respective broadcasted value. 23 | """ 24 | tgt_len = max(len(v) for v in values) 25 | failed = next((v for v in values if len(v) not in (1, tgt_len)), None) 26 | if failed: 27 | raise ValueError( 28 | f"Unable to broadcast Data of length {len(failed)} to length {tgt_len}: " 29 | f"received length must the same as target length or 1." 30 | ) 31 | broadcasted_values = ( 32 | v if len(v) == tgt_len else _itertools.repeat_items(v, tgt_len) for v in values 33 | ) 34 | return tuple(Data(bv) for bv in broadcasted_values) 35 | -------------------------------------------------------------------------------- /smaug/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from smaug.cli.augment import augment 2 | -------------------------------------------------------------------------------- /smaug/cli/accelerator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | 5 | def use_gpu(no_gpu: bool) -> bool: 6 | use_gpu = not no_gpu 7 | if use_gpu and not torch.cuda.is_available(): 8 | logging.warn("GPU requested but not available. Disabling GPU.") 9 | use_gpu = False 10 | return use_gpu 11 | -------------------------------------------------------------------------------- /smaug/cli/augment.py: -------------------------------------------------------------------------------- 1 | import click 2 | import typing 3 | 4 | from smaug import random 5 | from smaug.cli import config 6 | from smaug.cli import context 7 | from smaug.cli import processor 8 | from smaug.cli import io 9 | from smaug.cli import transform 10 | from smaug.cli import validation 11 | 12 | 13 | @click.group(chain=True, invoke_without_command=True) 14 | @click.option( 15 | "-c", "--cfg", type=str, help="Configuration file for the augmentation pipeline." 16 | ) 17 | @click.option( 18 | "--no-post-run", is_flag=True, help="Disable default post runs for processors." 19 | ) 20 | @click.option("-s", "--seed", type=int, help="Seed for reproducibility.") 21 | def augment(cfg, no_post_run, seed): 22 | """Executes an augmentation pipeline with multiple operations. 23 | 24 | Transform operations generate synthetic records from original records. 
25 | 26 | Validation operations verify if a synthetic record meets a desired criteria, removing it otherwise. 27 | 28 | Read and write operations for multiple formats are also available. 29 | """ 30 | pass 31 | 32 | 33 | @augment.result_callback() 34 | @click.pass_context 35 | def process_commands( 36 | ctx, processors, cfg: str, no_post_run: bool, seed: typing.Union[int, None] 37 | ): 38 | ctx.obj = context.Context() 39 | 40 | if cfg is not None: 41 | if len(processors) > 0: 42 | raise ValueError("No commands should be specified with --cfg argument.") 43 | _run_cfg_mode(cfg, no_post_run, seed) 44 | else: 45 | _run_chain_mode(ctx, processors, no_post_run, seed) 46 | 47 | 48 | def _run_chain_mode(ctx, processors, no_post_run: bool, seed: typing.Union[int, None]): 49 | post_run = not no_post_run 50 | 51 | if seed: 52 | click.echo(f"Seed set to {seed}.") 53 | random.seed_everything(seed) 54 | 55 | # Start with an empty iterable. 56 | stream = () 57 | 58 | # Pipe it through all stream processors. 59 | for proc in processors: 60 | stream = processor.call(ctx, proc, stream, post_run=post_run) 61 | 62 | # Evaluate the stream and throw away the items. 63 | _ = [r for r in stream] 64 | 65 | 66 | def _run_cfg_mode(cfg: str, no_post_run: bool, seed: typing.Union[int, None]): 67 | args = config.to_args(cfg, no_post_run, seed) 68 | click.echo(f"Executing: augment {' '.join(args)}") 69 | augment(args) 70 | 71 | 72 | augment.add_command(transform.delete_punct_span) 73 | augment.add_command(transform.insert_text) 74 | augment.add_command(transform.negate) 75 | augment.add_command(transform.swap_ne) 76 | augment.add_command(transform.swap_num) 77 | augment.add_command(transform.swap_poisson_span) 78 | 79 | augment.add_command(validation.keep_contradiction) 80 | augment.add_command(validation.keep_eq_ne_count) 81 | augment.add_command(validation.keep_eq_num_count) 82 | augment.add_command(validation.keep_geq_edit_dist) 83 | augment.add_command(validation.keep_leq_char_ins) 84 | augment.add_command(validation.rm_eq) 85 | augment.add_command(validation.rm_pattern) 86 | 87 | augment.add_command(io.read_csv) 88 | augment.add_command(io.read_lines) 89 | augment.add_command(io.write_json) 90 | -------------------------------------------------------------------------------- /smaug/cli/config.py: -------------------------------------------------------------------------------- 1 | import click 2 | import yaml 3 | 4 | from typing import Any, Dict, List, Union 5 | 6 | 7 | def to_args(cfg: str, no_post_run: bool, seed: Union[int, None]) -> List[str]: 8 | """Parses a config into a list of arguments to execute the augment command.""" 9 | with open(cfg, "r") as fp: 10 | cfg = yaml.safe_load(fp) 11 | 12 | augment_cfg = _build_augment_cfg(cfg) 13 | if seed is not None: 14 | augment_cfg.seed = seed 15 | if no_post_run: 16 | augment_cfg.no_post_run = no_post_run 17 | return augment_cfg.gen_args() 18 | 19 | 20 | def _build_augment_cfg(cfg: Dict[str, Any]) -> "_AugmentCfg": 21 | if "pipeline" not in cfg: 22 | click.echo("Please specify the desired pipeline commands.") 23 | return 24 | pipeline = cfg.pop("pipeline") 25 | cmd_cfgs = [] 26 | for cmd in pipeline: 27 | _extend_cmd_cfgs(cmd_cfgs, cmd) 28 | pipeline_cfg = _PipelineCfg(*cmd_cfgs) 29 | 30 | seed = cfg.get("seed", None) 31 | 32 | no_post_run = cfg.get("no-post-run", False) 33 | 34 | return _AugmentCfg(pipeline_cfg, seed, no_post_run) 35 | 36 | 37 | def _extend_cmd_cfgs(cmd_cfgs: "List[_CommandCfg]", cmd: Dict[str, Any]): 38 | name = cmd.pop("cmd", None) 39 | if name is 
None: 40 | click.echo("Plase specify cmd inside pipeline.") 41 | exit(1) 42 | 43 | validations = cmd.pop("validations", []) 44 | 45 | cmd_args = {k: v for k, v in cmd.items()} 46 | cmd_cfgs.append(_CommandCfg(name, **cmd_args)) 47 | 48 | for val in validations: 49 | val_name = val.pop("cmd", None) 50 | if val_name is None: 51 | click.echo(f"Plase specify cmd inside {name} validations.") 52 | exit(1) 53 | 54 | val_args = {k: v for k, v in val.items()} 55 | val_args["transform"] = name 56 | 57 | cmd_cfgs.append(_CommandCfg(val_name, **val_args)) 58 | 59 | 60 | class _AugmentCfg: 61 | def __init__( 62 | self, pipeline_cfg: "_PipelineCfg", seed: Union[int, None], no_post_run: bool 63 | ): 64 | self._pipeline_cfg = pipeline_cfg 65 | self.seed = seed 66 | self.no_post_run = no_post_run 67 | 68 | def gen_args(self) -> List[str]: 69 | args = [] 70 | if self.seed is not None: 71 | args.extend(("--seed", str(self.seed))) 72 | if self.no_post_run: 73 | args.append("--no-post-run") 74 | args.extend(self._pipeline_cfg.gen_args()) 75 | return args 76 | 77 | 78 | class _PipelineCfg: 79 | def __init__(self, *cmd_cfgs: "_CommandCfg"): 80 | self._cmd_cfgs = list(cmd_cfgs) 81 | 82 | def gen_args(self): 83 | for cfg in self._cmd_cfgs: 84 | yield from cfg.gen_args() 85 | 86 | 87 | class _CommandCfg: 88 | def __init__(self, name, **kwargs): 89 | self._name = name 90 | self._kwargs = kwargs 91 | 92 | def gen_args(self): 93 | yield self._name 94 | for k, v in self._kwargs.items(): 95 | yield f"--{str(k)}" 96 | yield str(v) 97 | -------------------------------------------------------------------------------- /smaug/cli/context.py: -------------------------------------------------------------------------------- 1 | class Context: 2 | def __init__(self) -> None: 3 | self.__transforms = [] 4 | 5 | def register_transform(self, name: str): 6 | self.__transforms.append(name) 7 | 8 | def iter_transforms(self): 9 | return iter(self.__transforms) 10 | -------------------------------------------------------------------------------- /smaug/cli/fmt.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from tqdm import tqdm 4 | 5 | 6 | class TqdmExtraFormat(tqdm): 7 | """Provides a `total_time` format parameter""" 8 | 9 | @property 10 | def format_dict(self): 11 | d = super(TqdmExtraFormat, self).format_dict 12 | total_time = d["elapsed"] * (d["total"] or 0) / max(d["n"], 1) 13 | d.update(total_time=self.format_interval(total_time)) 14 | return d 15 | 16 | 17 | PBAR_FORMAT = ( 18 | " {percentage:3.0f}% |{bar:40}| [{elapsed}/{total_time}, {rate_inv_fmt}]" 19 | ) 20 | 21 | DEFAULT_MAX_DESC_LEN = 50 22 | 23 | 24 | def print_desc(desc): 25 | click.echo(f"\n{desc}") 26 | 27 | 28 | def pbar_from_total(total: int, desc: str): 29 | print_desc(desc) 30 | return TqdmExtraFormat(total=total, bar_format=PBAR_FORMAT) 31 | 32 | 33 | def pbar_from_iterable(iterable, desc: str): 34 | print_desc(desc) 35 | return TqdmExtraFormat(iterable, bar_format=PBAR_FORMAT) 36 | 37 | 38 | def no_records_message(desc: str): 39 | return f"{desc}\n No records (skipping)." 
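These progress-bar helpers are used by the CLI read, transform and validation commands; a minimal usage sketch (illustrative only, not part of the module) might look like:

```python
from smaug.cli import fmt

# Prints the description and returns a tqdm-based bar whose format dict
# also exposes a `total_time` estimate.
pbar = fmt.pbar_from_total(total=3, desc="Demo step")
for _ in range(3):
    pbar.update(1)
pbar.close()
```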
40 | -------------------------------------------------------------------------------- /smaug/cli/io.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | import pandas as pd 4 | import typing 5 | 6 | from smaug import frozen 7 | from smaug import random 8 | from smaug.core import Sentence, SpanIndex 9 | from smaug.cli import fmt 10 | from smaug.cli import pipeline 11 | from smaug.cli import param 12 | from smaug.cli import processor 13 | 14 | 15 | @click.command("io-read-lines", short_help="Read sentences from a text file.") 16 | @click.option("-p", "--path", required=True, help="Path for file to read.") 17 | @click.option("-l", "--lang", required=True, help="Language for the sentences.") 18 | @click.option( 19 | "-s", 20 | "--sample", 21 | type=param.INT_OR_FLOAT, 22 | help="Number or percentage of sentences to sample. If not specified, all sentences are used.", 23 | ) 24 | @processor.make 25 | def read_lines(prev, path: str, lang: str, sample: typing.Union[int, float, None]): 26 | """Reads sentences from a text file. 27 | 28 | The file is expected to have one sentence per line. The language must 29 | be specified to enable language aware transformations. 30 | """ 31 | 32 | with open(path, "r") as fp: 33 | sentences = [l[:-1] for l in fp.readlines()] 34 | 35 | sentences = fmt.pbar_from_iterable(sentences, f"Read Sentences from {path}") 36 | records = [pipeline.State(original=s) for s in sentences] 37 | 38 | if sample is not None: 39 | if isinstance(sample, float): 40 | sample = int(sample * len(records)) 41 | 42 | if len(records) > sample: 43 | rng = random.numpy_seeded_rng() 44 | records = rng.choice(records, sample, replace=False).tolist() 45 | 46 | dataset = {"lang": lang, "records": records} 47 | 48 | stream = [el for el in prev] 49 | stream.append(dataset) 50 | return stream 51 | 52 | 53 | @click.command("io-read-csv", short_help="Read data from a CSV file.") 54 | @click.option("-p", "--path", required=True, help="Path for file to read.") 55 | @click.option( 56 | "-s", 57 | "--sample", 58 | type=param.INT_OR_FLOAT, 59 | help="Number or percentage of sentences to sample. If not specified, all sentences are used.", 60 | ) 61 | @processor.make 62 | def read_csv(prev, path, sample: typing.Union[int, None]): 63 | """Reads records from a csv file. 64 | 65 | The first file column will be interpreted as the language and the 66 | second as the sentence. 
67 | """ 68 | 69 | data = pd.read_csv(path, index_col=False, header=None) 70 | # Handle empty strings 71 | data[1].fillna("", inplace=True) 72 | rows = list(data.iterrows()) 73 | 74 | datasets = [] 75 | for idx, row in fmt.pbar_from_iterable(rows, f"Read CSV from {path}"): 76 | lang = row[0] 77 | sentence = row[1] 78 | 79 | if idx == 0 or rows[idx - 1][1][0] != lang: 80 | datasets.append({"lang": lang, "records": []}) 81 | 82 | # Always use last dataset 83 | datasets[-1]["records"].append(pipeline.State(original=sentence)) 84 | 85 | if sample is not None: 86 | for dataset in datasets: 87 | records = dataset["records"] 88 | if isinstance(sample, float): 89 | sample = int(sample * len(records)) 90 | 91 | if len(records) > sample: 92 | rng = random.numpy_seeded_rng() 93 | records = rng.choice(records, sample, replace=False).tolist() 94 | dataset["records"] = records 95 | 96 | stream = [el for el in prev] 97 | stream.extend(datasets) 98 | return stream 99 | 100 | 101 | @click.command("io-write-json", short_help="Write records to a JSON file.") 102 | @click.option("-p", "--path", required=True, help="File path to store the records.") 103 | @click.option( 104 | "--indent", 105 | default=2, 106 | type=int, 107 | show_default=True, 108 | help="Number of spaces to indent the file.", 109 | ) 110 | @processor.make 111 | def write_json(datasets, path, indent): 112 | """Writes all records to a JSON file. 113 | 114 | This is an utility operation to store the generated records. It converts the records 115 | to JSON objects and stores them in a file, with a format for easy reading. 116 | 117 | The records are stored in a non-compressed format that is more user friendly. 118 | If the objective is to reduce file size, another write format should be used. 119 | """ 120 | 121 | class _StateEncoder(json.JSONEncoder): 122 | def default(self, o: typing.Any) -> typing.Any: 123 | if isinstance(o, pipeline.State): 124 | return { 125 | "original": o.original, 126 | "perturbations": o.perturbations, 127 | "metadata": o.metadata, 128 | } 129 | if isinstance(o, frozen.frozenlist): 130 | return list(o) 131 | if isinstance(o, SpanIndex): 132 | return int(o.start), int(o.end) 133 | if isinstance(o, Sentence): 134 | return o.value 135 | return super().default(o) 136 | 137 | records = [] 138 | total_records = sum(len(dataset["records"]) for dataset in datasets) 139 | pbar = fmt.pbar_from_total(total_records, f"Write JSON to {path}") 140 | for dataset in datasets: 141 | records.extend(dataset["records"]) 142 | pbar.update(len(dataset["records"])) 143 | 144 | with open(path, "w") as fp: 145 | json.dump(records, fp, ensure_ascii=False, indent=indent, cls=_StateEncoder) 146 | return datasets 147 | -------------------------------------------------------------------------------- /smaug/cli/param.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from typing import Any, Optional 4 | 5 | 6 | class IntOrFloatParamType(click.ParamType): 7 | 8 | name = "int-or-float" 9 | 10 | def convert( 11 | self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] 12 | ) -> Any: 13 | if isinstance(value, (int, float)): 14 | return value 15 | 16 | if "." 
in value: 17 | try: 18 | return float(value) 19 | except ValueError: 20 | self.fail(f"{value!r} is not a valid float", param, ctx) 21 | else: 22 | try: 23 | return int(value) 24 | except ValueError: 25 | self.fail(f"{value!r} is not a valid int", param, ctx) 26 | 27 | 28 | INT_OR_FLOAT = IntOrFloatParamType() 29 | -------------------------------------------------------------------------------- /smaug/cli/pipeline.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, Callable, Dict, Optional 3 | 4 | from smaug import ops 5 | from smaug.core import Data, DataLike, Sentence, SentenceLike, Validation 6 | from smaug.promote import promote_to_data 7 | 8 | PerturbationId = str 9 | 10 | 11 | @dataclasses.dataclass(frozen=True) 12 | class State: 13 | """Represents the state of the perturbation process.""" 14 | 15 | # The original unmodified sentence. 16 | original: SentenceLike 17 | 18 | # The sentences with the perturbations, identified by their id. 19 | perturbations: Dict[PerturbationId, SentenceLike] = dataclasses.field( 20 | default_factory=dict 21 | ) 22 | 23 | # Other metadata that the perturbations can output, identified by their id. 24 | metadata: Dict[PerturbationId, Any] = dataclasses.field(default_factory=dict) 25 | 26 | 27 | PipelineOp = Callable[[DataLike[State]], Data[State]] 28 | 29 | 30 | def lift_transform( 31 | func: Callable[[DataLike[SentenceLike]], Data[Optional[Sentence]]], 32 | perturbation: PerturbationId, 33 | ) -> PipelineOp: 34 | def transform(records: DataLike[State]) -> Data[State]: 35 | records = promote_to_data(records) 36 | original = Data([r.original for r in records]) 37 | transformed = func(original) 38 | for orig, t in zip(records, transformed): 39 | if t is None: 40 | continue 41 | orig.perturbations[perturbation] = t.value 42 | if t.trace is not None: 43 | orig.metadata[perturbation] = ops.modified_spans_from_trace(t.trace) 44 | return records 45 | 46 | return transform 47 | 48 | 49 | def lift_validation(func: Validation, perturbation: PerturbationId) -> PipelineOp: 50 | def del_perturbation(state: State): 51 | if perturbation in state.perturbations: 52 | del state.perturbations[perturbation] 53 | if perturbation in state.metadata: 54 | del state.metadata[perturbation] 55 | 56 | def validation(records: DataLike[State]) -> Data[State]: 57 | records = promote_to_data(records) 58 | originals = Data([r.original for r in records]) 59 | transformed = Data([r.perturbations.get(perturbation, None) for r in records]) 60 | validated = func(originals, transformed) 61 | for r, v in zip(records, validated): 62 | if v is None: 63 | del_perturbation(r) 64 | return records 65 | 66 | return validation 67 | -------------------------------------------------------------------------------- /smaug/cli/processor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | def make(f): 5 | """Decorates function to return transform processor.""" 6 | 7 | def make_processor(*args, **kwargs): 8 | def processor(stream): 9 | return f(stream, *args, **kwargs) 10 | 11 | if hasattr(f, "__post_processors__"): 12 | processor.__post_processors__ = f.__post_processors__ 13 | 14 | return processor 15 | 16 | return functools.update_wrapper(make_processor, f) 17 | 18 | 19 | def post_run(make_proc_cmd, **kwargs): 20 | """Registers a processor to after the function. 21 | 22 | Post processors should be executed after this function. 
23 | """ 24 | 25 | def decorator(f): 26 | if not hasattr(f, "__post_processors__"): 27 | f.__post_processors__ = [] 28 | # Insert in position 0 since decorators are evaluated 29 | # in reverse order (from bottom to top) thus ensuring 30 | # the post processors appear in the right order. 31 | f.__post_processors__.insert(0, (make_proc_cmd, kwargs)) 32 | return f 33 | 34 | return decorator 35 | 36 | 37 | def call(ctx, processor, stream, post_run=True): 38 | stream = processor(stream) 39 | 40 | if post_run and hasattr(processor, "__post_processors__"): 41 | for make_proc_cmd, kwargs in processor.__post_processors__: 42 | post_processor = ctx.invoke(make_proc_cmd, **kwargs) 43 | stream = post_processor(stream) 44 | 45 | return stream 46 | -------------------------------------------------------------------------------- /smaug/cli/transform.py: -------------------------------------------------------------------------------- 1 | import click 2 | import functools 3 | 4 | from smaug import core 5 | from smaug import models 6 | from smaug import random 7 | from smaug import perturb 8 | from smaug.cli import accelerator 9 | from smaug.cli import pipeline 10 | from smaug.cli import fmt 11 | from smaug.cli import processor 12 | from smaug.cli import validation 13 | 14 | 15 | _SWAP_NUM_CMD = "transf-swp-num" 16 | _SWAP_NE_CMD = "transf-swp-ne" 17 | _SWAP_POISSON_SPAN_CMD = "transf-swp-poisson-span" 18 | _NEG_CMD = "transf-neg" 19 | _DEL_PUNCT_SPAN_CMD = "transf-del-punct-span" 20 | _INS_TEXT_CMD = "transf-ins-text" 21 | 22 | 23 | @click.command(_SWAP_NUM_CMD, short_help="Swap a number for text with regex and mT5.") 24 | @click.option( 25 | "--batch-size", 26 | default=16, 27 | show_default=True, 28 | help="Batch size when processing records.", 29 | ) 30 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 31 | @processor.make 32 | @processor.post_run(validation.rm_eq, cli_transforms=[_SWAP_NUM_CMD]) 33 | @processor.post_run( 34 | validation.rm_pattern, pattern=r"", cli_transforms=[_SWAP_NUM_CMD] 35 | ) 36 | @processor.post_run( 37 | validation.keep_leq_char_ins, 38 | chars="<>()[]{}_", 39 | max_insertions=0, 40 | cli_transforms=[_SWAP_NUM_CMD], 41 | ) 42 | @processor.post_run(validation.keep_eq_num_count, cli_transforms=[_SWAP_NUM_CMD]) 43 | @click.pass_context 44 | def swap_num(ctx, datasets, batch_size, no_gpu): 45 | """Swaps a number for text using regex and mT5. 46 | 47 | This operation is a transformation. 48 | It searches for numbers in the original records using regular expressions and 49 | then uses Google's mT5 to replace the one of the found expressions with text. 50 | 51 | The generated sentences are not guarantied to have a new number replacing the 52 | old one, as the model is free to generate any text. 53 | 54 | It is possible to have other validations to better ensure these conditions 55 | are met. 
56 | """ 57 | total_records = sum(len(datasets["records"]) for datasets in datasets) 58 | if total_records == 0: 59 | click.echo(fmt.no_records_message("Swap a Number for Text")) 60 | return datasets 61 | 62 | ctx.obj.register_transform(_SWAP_NUM_CMD) 63 | 64 | gpu = accelerator.use_gpu(no_gpu) 65 | 66 | rng = random.numpy_seeded_rng() 67 | 68 | model, tokenizer = models.mT5_load() 69 | 70 | transform_func = functools.partial( 71 | perturb.swap_number_transform, 72 | mt5_model=model, 73 | mt5_tokenizer=tokenizer, 74 | rng=rng, 75 | gpu=gpu, 76 | ) 77 | 78 | pipeline_func = pipeline.lift_transform(transform_func, _SWAP_NUM_CMD) 79 | 80 | processed = [] 81 | 82 | pbar = fmt.pbar_from_total(total_records, "Swap a Number for Text") 83 | for dataset in datasets: 84 | old_records = dataset["records"] 85 | new_records = [] 86 | 87 | for i in range(0, len(old_records), batch_size): 88 | batch = core.Data(old_records[i : i + batch_size]) 89 | records = pipeline_func(batch) 90 | new_records.extend(records) 91 | pbar.update(len(batch)) 92 | 93 | dataset["records"] = new_records 94 | 95 | processed.append(dataset) 96 | return processed 97 | 98 | 99 | @click.command( 100 | _SWAP_NE_CMD, 101 | short_help="Swap a named entity for text with named entity recognition and mT5.", 102 | ) 103 | @click.option( 104 | "--batch-size", 105 | default=16, 106 | show_default=True, 107 | help="Batch size when processing records.", 108 | ) 109 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 110 | @processor.make 111 | @processor.post_run(validation.rm_eq, cli_transforms=[_SWAP_NE_CMD]) 112 | @processor.post_run( 113 | validation.rm_pattern, pattern=r"", cli_transforms=[_SWAP_NE_CMD] 114 | ) 115 | @processor.post_run( 116 | validation.keep_leq_char_ins, 117 | chars="<>()[]{}_", 118 | max_insertions=0, 119 | cli_transforms=[_SWAP_NE_CMD], 120 | ) 121 | @processor.post_run(validation.keep_eq_ne_count, cli_transforms=[_SWAP_NE_CMD]) 122 | @click.pass_context 123 | def swap_ne(ctx, datasets, batch_size, no_gpu): 124 | """Swaps a single named entity for text using named entity recognition and mT5. 125 | 126 | This operation is a transformation. 127 | It searches for named entities in the original records using a stanza model and 128 | then uses Google's mT5 to replace one of the found expressions with text. 129 | 130 | The generated sentences are not guarantied to have a new named entity replacing the 131 | old one, as the model is free to generate any text. 132 | 133 | It is possible to have other validations to better ensure these conditions 134 | are met. 
135 | """ 136 | total_records = sum( 137 | len(dataset["records"]) 138 | for dataset in datasets 139 | if models.stanza_ner_lang_available(dataset["lang"]) 140 | ) 141 | if total_records == 0: 142 | click.echo(fmt.no_records_message("Swap a Named Entitiy for Text")) 143 | return datasets 144 | 145 | ctx.obj.register_transform(_SWAP_NE_CMD) 146 | 147 | gpu = accelerator.use_gpu(no_gpu) 148 | rng = random.numpy_seeded_rng() 149 | 150 | model, tokenizer = models.mT5_load() 151 | 152 | processed = [] 153 | 154 | pbar = fmt.pbar_from_total(total_records, "Swap a Named Entitiy for Text") 155 | for dataset in datasets: 156 | lang = dataset["lang"] 157 | if not models.stanza_ner_lang_available(lang): 158 | processed.append(dataset) 159 | continue 160 | ner_pipeline = models.stanza_ner_load(lang, gpu) 161 | 162 | transform_func = functools.partial( 163 | perturb.swap_named_entity_transform, 164 | ner_pipeline=ner_pipeline, 165 | mt5_model=model, 166 | mt5_tokenizer=tokenizer, 167 | rng=rng, 168 | gpu=gpu, 169 | ) 170 | 171 | pipeline_func = pipeline.lift_transform(transform_func, _SWAP_NE_CMD) 172 | 173 | old_records = dataset["records"] 174 | new_records = [] 175 | 176 | for i in range(0, len(old_records), batch_size): 177 | batch = core.Data(old_records[i : i + batch_size]) 178 | records = pipeline_func(batch) 179 | new_records.extend(records) 180 | pbar.update(len(batch)) 181 | 182 | dataset["records"] = new_records 183 | 184 | processed.append(dataset) 185 | return processed 186 | 187 | 188 | @click.command(_NEG_CMD, short_help="Negate the sentence with polyjuice.") 189 | @click.option( 190 | "--batch-size", 191 | default=16, 192 | show_default=True, 193 | help="Batch size when processing records.", 194 | ) 195 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 196 | @processor.make 197 | @processor.post_run(validation.rm_eq, cli_transforms=[_NEG_CMD]) 198 | @processor.post_run(validation.rm_pattern, pattern="EMPTY", cli_transforms=[_NEG_CMD]) 199 | @processor.post_run(validation.keep_contradiction, cli_transforms=[_NEG_CMD]) 200 | @click.pass_context 201 | def negate(ctx, datasets, batch_size, no_gpu): 202 | """Negates the received sentences with polyjuice. 203 | 204 | This operation is a transformation. 205 | It tries to negate the sentence if possible. 206 | This transformation is only available for to-english datasets. 
207 | """ 208 | total_records = sum( 209 | len(orig["records"]) for orig in datasets if orig["lang"] == "en" 210 | ) 211 | if total_records == 0: 212 | click.echo(fmt.no_records_message("Negate the Sentence")) 213 | return datasets 214 | 215 | ctx.obj.register_transform(_NEG_CMD) 216 | 217 | gpu = accelerator.use_gpu(no_gpu) 218 | 219 | rng = random.numpy_seeded_rng() 220 | pos_pipeline = models.stanza_pos_load("en", gpu) 221 | model, tokenizer = models.polyjuice_load() 222 | 223 | transform_func = functools.partial( 224 | perturb.negate_transform, 225 | pos_pipeline=pos_pipeline, 226 | polyjuice_model=model, 227 | polyjuice_tokenizer=tokenizer, 228 | rng=rng, 229 | gpu=gpu, 230 | ) 231 | 232 | pipeline_func = pipeline.lift_transform(transform_func, _NEG_CMD) 233 | 234 | pbar = fmt.pbar_from_total(total_records, "Negate the Sentence") 235 | 236 | processed = [] 237 | for orig in datasets: 238 | if orig["lang"] == "en": 239 | old_records = orig["records"] 240 | new_records = [] 241 | 242 | for i in range(0, len(old_records), batch_size): 243 | batch = core.Data(old_records[i : i + batch_size]) 244 | records = pipeline_func(batch) 245 | new_records.extend(records) 246 | pbar.update(len(batch)) 247 | 248 | orig["records"] = new_records 249 | 250 | processed.append(orig) 251 | 252 | return processed 253 | 254 | 255 | @click.command( 256 | _DEL_PUNCT_SPAN_CMD, short_help="Removes a span between two punctuation symbols." 257 | ) 258 | @click.option( 259 | "--low", 260 | "-l", 261 | type=int, 262 | default=4, 263 | help="minimum number of words for a span to be eligible for deletion.", 264 | show_default=True, 265 | ) 266 | @click.option( 267 | "--high", 268 | "-h", 269 | type=int, 270 | default=10, 271 | help="maximum number of words for a span to be eligible for deletion.", 272 | show_default=True, 273 | ) 274 | @processor.make 275 | @click.pass_context 276 | def delete_punct_span(ctx, datasets, low, high): 277 | """Removes a span between two punctuation symbols. 278 | 279 | This operation is a transformation. 280 | It detects the following symbols: ,.!? , and deletes a span between two of them. 281 | It also deletes the symbol to the right of the span. 282 | """ 283 | total_records = sum(len(datasets["records"]) for datasets in datasets) 284 | if total_records == 0: 285 | click.echo(fmt.no_records_message(f"Delete a span between punctuation matches.")) 286 | return datasets 287 | 288 | ctx.obj.register_transform(_DEL_PUNCT_SPAN_CMD) 289 | rng = random.numpy_seeded_rng() 290 | transform_func = functools.partial( 291 | perturb.delete_span_between_punctuation_transform, 292 | rng=rng, 293 | low=low, 294 | high=high, 295 | ) 296 | 297 | pipeline_func = pipeline.lift_transform(transform_func, _DEL_PUNCT_SPAN_CMD) 298 | 299 | processed = [] 300 | 301 | pbar = fmt.pbar_from_total( 302 | total_records, f"Delete a span between punctuation matches." 
303 | ) 304 | for dataset in datasets: 305 | old_records = core.Data(dataset["records"]) 306 | dataset["records"] = pipeline_func(old_records) 307 | 308 | pbar.update(len(old_records)) 309 | 310 | processed.append(dataset) 311 | return processed 312 | 313 | 314 | @click.command(_INS_TEXT_CMD, short_help="Insert random text with mT5.") 315 | @click.option( 316 | "--prob", 317 | "-p", 318 | default=0.1, 319 | show_default=True, 320 | help="Probability of inserting a mask between two tokens", 321 | ) 322 | @click.option( 323 | "--max-masks", 324 | default=3, 325 | show_default=True, 326 | help="Max masks to add for mT5 to fill.", 327 | ) 328 | @click.option( 329 | "--batch-size", 330 | default=16, 331 | show_default=True, 332 | help="Batch size when processing records.", 333 | ) 334 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 335 | @processor.make 336 | @processor.post_run(validation.rm_eq, cli_transforms=[_INS_TEXT_CMD]) 337 | @processor.post_run( 338 | validation.rm_pattern, pattern=r"", cli_transforms=[_INS_TEXT_CMD] 339 | ) 340 | @click.pass_context 341 | def insert_text(ctx, datasets, prob, max_masks, batch_size, no_gpu): 342 | """Randomly inserts text using mT5. 343 | 344 | This operation is a transformation. 345 | It randomly adds masks between the words of the original sentence and 346 | takes as the output the sentence with the masks filled with mT5. 347 | """ 348 | total_records = sum(len(datasets["records"]) for datasets in datasets) 349 | if total_records == 0: 350 | click.echo(fmt.no_records_message("Insert Text")) 351 | return datasets 352 | 353 | ctx.obj.register_transform(_INS_TEXT_CMD) 354 | 355 | gpu = accelerator.use_gpu(no_gpu) 356 | rng = random.numpy_seeded_rng() 357 | 358 | model, tokenizer = models.mT5_load() 359 | 360 | transform_func = functools.partial( 361 | perturb.insert_text_span_transform, 362 | mt5_model=model, 363 | mt5_tokenizer=tokenizer, 364 | rng=rng, 365 | p=prob, 366 | max_masks=max_masks, 367 | gpu=gpu, 368 | ) 369 | 370 | pipeline_func = pipeline.lift_transform(transform_func, _INS_TEXT_CMD) 371 | 372 | processed = [] 373 | 374 | pbar = fmt.pbar_from_total(total_records, _INS_TEXT_CMD) 375 | for dataset in datasets: 376 | old_records = dataset["records"] 377 | new_records = [] 378 | 379 | for i in range(0, len(old_records), batch_size): 380 | batch = core.Data(old_records[i : i + batch_size]) 381 | records = pipeline_func(batch) 382 | new_records.extend(records) 383 | pbar.update(len(batch)) 384 | 385 | dataset["records"] = new_records 386 | 387 | processed.append(dataset) 388 | return processed 389 | 390 | 391 | @click.command( 392 | _SWAP_POISSON_SPAN_CMD, short_help="Replace random spans of text with mT5." 393 | ) 394 | @click.option( 395 | "--batch-size", 396 | default=16, 397 | show_default=True, 398 | help="Batch size when processing records.", 399 | ) 400 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 401 | @processor.make 402 | @processor.post_run(validation.rm_eq, cli_transforms=[_SWAP_POISSON_SPAN_CMD]) 403 | @processor.post_run( 404 | validation.rm_pattern, 405 | pattern=r"", 406 | cli_transforms=[_SWAP_POISSON_SPAN_CMD], 407 | ) 408 | @click.pass_context 409 | def swap_poisson_span(ctx, datasets, batch_size, no_gpu): 410 | """Randomly replaces text spans with sizes following a Poisson distribution. 411 | 412 | This operation is a transformation. 
413 | It masks a span of text on the original sentence, where the number 414 | of masked words (can be 0) is given by the Poisson distribution, and 415 | takes as the output the sentence with the masks filled with mT5. 416 | """ 417 | total_records = sum(len(dataset["records"]) for dataset in datasets) 418 | if total_records == 0: 419 | click.echo(fmt.no_records_message("Swap Poisson Span")) 420 | return datasets 421 | 422 | ctx.obj.register_transform(_SWAP_POISSON_SPAN_CMD) 423 | 424 | gpu = accelerator.use_gpu(no_gpu) 425 | rng = random.numpy_seeded_rng() 426 | 427 | model, tokenizer = models.mT5_load() 428 | 429 | transform_func = functools.partial( 430 | perturb.swap_poisson_span_transform, 431 | mt5_model=model, 432 | mt5_tokenizer=tokenizer, 433 | rng=rng, 434 | gpu=gpu, 435 | ) 436 | 437 | pipeline_func = pipeline.lift_transform(transform_func, _SWAP_POISSON_SPAN_CMD) 438 | 439 | processed = [] 440 | 441 | pbar = fmt.pbar_from_total(total_records, _SWAP_POISSON_SPAN_CMD) 442 | for dataset in datasets: 443 | old_records = dataset["records"] 444 | new_records = [] 445 | 446 | for i in range(0, len(old_records), batch_size): 447 | batch = core.Data(old_records[i : i + batch_size]) 448 | records = pipeline_func(batch) 449 | new_records.extend(records) 450 | pbar.update(len(batch)) 451 | 452 | dataset["records"] = new_records 453 | 454 | processed.append(dataset) 455 | return processed 456 | -------------------------------------------------------------------------------- /smaug/cli/validation.py: -------------------------------------------------------------------------------- 1 | import click 2 | import functools 3 | import re 4 | 5 | from smaug import core 6 | from smaug import functional 7 | from smaug import models 8 | from smaug import ops 9 | from smaug.cli import accelerator 10 | from smaug.cli import pipeline 11 | from smaug.cli import fmt 12 | from smaug.cli import processor 13 | 14 | 15 | _RM_EQ_CMD = "val-rm-eq" 16 | _RM_PATTERN_CMD = "val-rm-pattern" 17 | _KEEP_CONTRADICTION_CMD = "val-keep-contradiction" 18 | _KEEP_EQ_NUM_CMD = "val-keep-eq-num" 19 | _KEEP_EQ_NE_CMD = "val-keep-eq-ne" 20 | _KEEP_GEQ_EDIT_DIST_CMD = "val-keep-geq-edit-dist" 21 | _KEEP_LEQ_CHAR_INSERT_CMD = "val-keep-leq-char-ins" 22 | 23 | 24 | @click.command(_RM_EQ_CMD, short_help="Remove synthetic records equal to the original.") 25 | @click.option( 26 | "--transform", 27 | "cli_transforms", 28 | multiple=True, 29 | help="Transforms to filter with this validation. If not specified all are validated.", 30 | ) 31 | @processor.make 32 | @click.pass_context 33 | def rm_eq(ctx, datasets, cli_transforms): 34 | """Validates if the generated records are not equal to the original. 35 | 36 | This operation is a validation. It ensures the generated record is different 37 | from the original one by performing a string comparison.
38 | """ 39 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 40 | 41 | total_records = sum(len(orig["records"]) for orig in datasets) 42 | if total_records == 0: 43 | click.echo(fmt.no_records_message("Remove Equal")) 44 | return datasets 45 | 46 | processed = [dataset for dataset in datasets] 47 | 48 | for transform in transforms: 49 | pbar = fmt.pbar_from_total(total_records, f"Remove Equal for {transform}") 50 | validation_func = functional.lift_boolean_validation(lambda o, p: o != p) 51 | pipeline_func = pipeline.lift_validation(validation_func, transform) 52 | 53 | for dataset in processed: 54 | not_validated = core.Data(dataset["records"]) 55 | dataset["records"] = pipeline_func(not_validated) 56 | pbar.update(len(not_validated)) 57 | 58 | return processed 59 | 60 | 61 | @click.command( 62 | _RM_PATTERN_CMD, short_help="Remove synthetic records that match a pattern." 63 | ) 64 | @click.option("-p", "--pattern", required=True, help="Pattern to search.") 65 | @click.option( 66 | "--transform", 67 | "cli_transforms", 68 | multiple=True, 69 | help="Transforms to filter with this validation. If not specified all are validated.", 70 | ) 71 | @processor.make 72 | @click.pass_context 73 | def rm_pattern(ctx, datasets, pattern, cli_transforms): 74 | """Validates if the generated records do not match a regular expression. 75 | 76 | This operations is a validation. It ensures generated records do not have 77 | the given pattern. 78 | 79 | This validation is particularly useful with language models that may leave 80 | unwanted tokens after generation (such as masks or special tokens) to filter 81 | these occurrences. 82 | """ 83 | 84 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 85 | 86 | total_records = sum(len(orig["records"]) for orig in datasets) 87 | if total_records == 0: 88 | click.echo(fmt.no_records_message(f'Remove Pattern "{pattern}"')) 89 | return datasets 90 | 91 | processed = [dataset for dataset in datasets] 92 | compiled_pattern = re.compile(pattern) 93 | for transform in transforms: 94 | pbar = fmt.pbar_from_total( 95 | total_records, f'Remove Pattern "{pattern}" for {transform}' 96 | ) 97 | validation_func = functional.lift_boolean_validation( 98 | lambda o, p: compiled_pattern.search(p.value) is None, 99 | ) 100 | pipeline_func = pipeline.lift_validation(validation_func, transform) 101 | for dataset in processed: 102 | not_validated = core.Data(dataset["records"]) 103 | dataset["records"] = pipeline_func(not_validated) 104 | pbar.update(len(not_validated)) 105 | 106 | return processed 107 | 108 | 109 | @click.command( 110 | _KEEP_CONTRADICTION_CMD, 111 | help="Keep synthetic records contradicting the original.", 112 | ) 113 | @click.option( 114 | "--transform", 115 | "cli_transforms", 116 | multiple=True, 117 | help="Transforms to filter with this validation. If not specified all are validated.", 118 | ) 119 | @click.option( 120 | "--batch-size", 121 | default=16, 122 | show_default=True, 123 | help="Batch size when processing records.", 124 | ) 125 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 126 | @processor.make 127 | @click.pass_context 128 | def keep_contradiction(ctx, datasets, cli_transforms, batch_size, no_gpu): 129 | """Validates if the synthetic records contradict the original records. 130 | 131 | This operation is a validation. It uses a RoBERTa model trained for NLI 132 | to ensure the generated records contradict the original ones. 
133 | """ 134 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 135 | 136 | total_records = sum(len(orig["records"]) for orig in datasets) 137 | if total_records == 0: 138 | click.echo(fmt.no_records_message("Keep Contradictions")) 139 | return datasets 140 | 141 | gpu = accelerator.use_gpu(no_gpu) 142 | 143 | model, tokenizer = models.roberta_mnli_load() 144 | predict_func = functools.partial( 145 | ops.roberta_mnli_predict, model=model, tokenizer=tokenizer, cuda=gpu 146 | ) 147 | contradiction_id = ops.roberta_mnli_contradiction_id(model) 148 | 149 | processed = [dataset for dataset in datasets] 150 | for transform in transforms: 151 | pbar = fmt.pbar_from_total( 152 | total_records, f"Keep Contradictions for {transform}" 153 | ) 154 | validation_func = functional.lift_boolean_validation( 155 | lambda o, p: predict_func(o, p).argmax().item() == contradiction_id, 156 | ) 157 | pipeline_func = pipeline.lift_validation(validation_func, transform) 158 | 159 | for dataset in processed: 160 | not_validated = dataset["records"] 161 | validated = [] 162 | for i in range(0, len(not_validated), batch_size): 163 | batch = core.Data(not_validated[i : i + batch_size]) 164 | validated.extend(pipeline_func(batch)) 165 | pbar.update(len(batch)) 166 | dataset["records"] = validated 167 | return processed 168 | 169 | 170 | @click.command( 171 | _KEEP_EQ_NUM_CMD, 172 | help="Keep perturbations with the same numbers count as the original.", 173 | ) 174 | @click.option( 175 | "--transform", 176 | "cli_transforms", 177 | multiple=True, 178 | help="Transforms to filter with this validation. If not specified all are validated.", 179 | ) 180 | @processor.make 181 | @click.pass_context 182 | def keep_eq_num_count(ctx, datasets, cli_transforms): 183 | """Validates if the synthetic records have an equal number count as the original records. 184 | 185 | This operation is a validation. It uses Regular Expressions to detect the numbers 186 | both in the original and the perturbed sentences. 187 | """ 188 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 189 | 190 | total_records = sum(len(orig["records"]) for orig in datasets) 191 | if total_records == 0: 192 | click.echo(fmt.no_records_message("Keep Equal Numbers Count")) 193 | return datasets 194 | 195 | processed = [dataset for dataset in datasets] 196 | for transform in transforms: 197 | pbar = fmt.pbar_from_total( 198 | total_records, f"Keep Equal Numbers Count for {transform}" 199 | ) 200 | validation_func = functional.lift_boolean_validation(ops.equal_numbers_count) 201 | pipeline_func = pipeline.lift_validation(validation_func, transform) 202 | for dataset in processed: 203 | not_validated = core.Data(dataset["records"]) 204 | dataset["records"] = pipeline_func(not_validated) 205 | pbar.update(len(not_validated)) 206 | return processed 207 | 208 | 209 | @click.command( 210 | _KEEP_EQ_NE_CMD, 211 | help="Keep perturbations with the same named entities count as the original.", 212 | ) 213 | @click.option( 214 | "--transform", 215 | "cli_transforms", 216 | multiple=True, 217 | help="Transforms to filter with this validation. 
If not specified all are validated.", 218 | ) 219 | @click.option( 220 | "--batch-size", 221 | default=16, 222 | show_default=True, 223 | help="Batch size when processing records.", 224 | ) 225 | @click.option("--no-gpu", is_flag=True, help="Disable gpu.") 226 | @processor.make 227 | @click.pass_context 228 | def keep_eq_ne_count(ctx, datasets, cli_transforms, batch_size, no_gpu): 229 | """Validates if the synthetic records have an equal named entity count as the original records. 230 | 231 | This operation is a validation. It uses a Stanza NER model to detect the named entities 232 | both in the original and the perturbed sentences. 233 | """ 234 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 235 | 236 | total_records = sum( 237 | len(dataset["records"]) 238 | for dataset in datasets 239 | if models.stanza_ner_lang_available(dataset["lang"]) 240 | ) 241 | if total_records == 0: 242 | click.echo(fmt.no_records_message("Keep Equal Named Entities Count")) 243 | return datasets 244 | 245 | gpu = accelerator.use_gpu(no_gpu) 246 | 247 | processed = [dataset for dataset in datasets] 248 | for transform in transforms: 249 | pbar = fmt.pbar_from_total( 250 | total_records, f"Keep Equal Named Entities Count for {transform}" 251 | ) 252 | for dataset in processed: 253 | lang = dataset["lang"] 254 | if not models.stanza_ner_lang_available(lang): 255 | continue 256 | ner_pipeline = models.stanza_ner_load(lang, gpu) 257 | validation_func = functools.partial( 258 | ops.equal_named_entities_count, 259 | ner_pipeline=ner_pipeline, 260 | ) 261 | validation_func = functional.lift_boolean_validation(validation_func) 262 | pipeline_func = pipeline.lift_validation(validation_func, transform) 263 | not_validated = dataset["records"] 264 | validated = [] 265 | for i in range(0, len(not_validated), batch_size): 266 | batch = core.Data(not_validated[i : i + batch_size]) 267 | validated.extend(pipeline_func(batch)) 268 | pbar.update(len(batch)) 269 | dataset["records"] = validated 270 | return processed 271 | 272 | 273 | @click.command( 274 | _KEEP_GEQ_EDIT_DIST_CMD, 275 | help="Keep perturbations with a minimum edit distance from the original above a threshold.", 276 | ) 277 | @click.option( 278 | "-d", "--distance", type=int, required=True, help="Minimum threshold to accept." 279 | ) 280 | @click.option( 281 | "-l", 282 | "--level", 283 | type=click.Choice(("char", "word"), case_sensitive=False), 284 | default="char", 285 | help="Level at which to measure the minimum edit distance.", 286 | ) 287 | @click.option( 288 | "--transform", 289 | "cli_transforms", 290 | multiple=True, 291 | help="Transforms to filter with this validation. If not specified all are validated.", 292 | ) 293 | @processor.make 294 | @click.pass_context 295 | def keep_geq_edit_dist(ctx, datasets, distance, level, cli_transforms): 296 | """Validates if the perturbations have a minimum edit distance higher than a threshold. 297 | 298 | This operation is a validation. It computes the minimum edit distance between the original 299 | and perturbed sentences. 
300 | """ 301 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 302 | 303 | total_records = sum(len(orig["records"]) for orig in datasets) 304 | if total_records == 0: 305 | click.echo(fmt.no_records_message(f"Keep Edit Distance above {distance}")) 306 | return datasets 307 | 308 | processed = [dataset for dataset in datasets] 309 | for transform in transforms: 310 | pbar = fmt.pbar_from_total( 311 | total_records, f"Keep Edit Distance above {distance} for {transform}" 312 | ) 313 | 314 | validation_func = functional.lift_boolean_validation( 315 | lambda s1, s2: ops.edit_distance(s1, s2, level) >= distance 316 | ) 317 | pipeline_func = pipeline.lift_validation(validation_func, transform) 318 | for dataset in processed: 319 | not_validated = core.Data(dataset["records"]) 320 | dataset["records"] = pipeline_func(not_validated) 321 | pbar.update(len(not_validated)) 322 | return processed 323 | 324 | 325 | @click.command( 326 | _KEEP_LEQ_CHAR_INSERT_CMD, 327 | help="Keep perturbations with a total of char insertions below a threshold.", 328 | ) 329 | @click.option( 330 | "-c", 331 | "--chars", 332 | default="<>()[]{}", 333 | show_default=True, 334 | help="Chars to consider (each individual char is considered)", 335 | ) 336 | @click.option( 337 | "-i", 338 | "--max-insertions", 339 | type=int, 340 | required=True, 341 | help="Maximum insertions to accept.", 342 | ) 343 | @click.option( 344 | "--transform", 345 | "cli_transforms", 346 | multiple=True, 347 | help="Transforms to filter with this validation. If not specified all are validated.", 348 | ) 349 | @processor.make 350 | @click.pass_context 351 | def keep_leq_char_ins(ctx, datasets, chars, max_insertions, cli_transforms): 352 | """Validates if the pertubrations have a maximum number of character insertions. 353 | 354 | This operation is a validation. It computes the number of insertions of specific characters 355 | in the perturbed sentences, and only allows perturbations with this number bellow a threshold. 356 | """ 357 | transforms = cli_transforms if cli_transforms else list(ctx.obj.iter_transforms()) 358 | 359 | total_records = sum(len(orig["records"]) for orig in datasets) 360 | if total_records == 0: 361 | click.echo( 362 | fmt.no_records_message(f"Keep {chars} insertions below {max_insertions}") 363 | ) 364 | return datasets 365 | 366 | processed = [dataset for dataset in datasets] 367 | for transform in transforms: 368 | pbar = fmt.pbar_from_total( 369 | total_records, 370 | f"Keep {chars} insertions below {max_insertions} for {transform}", 371 | ) 372 | validation_func = functional.lift_boolean_validation( 373 | lambda o, p: ops.character_insertions(o, p, chars) <= max_insertions 374 | ) 375 | pipeline_func = pipeline.lift_validation(validation_func, transform) 376 | for dataset in processed: 377 | not_validated = core.Data(dataset["records"]) 378 | dataset["records"] = pipeline_func(not_validated) 379 | pbar.update(len(not_validated)) 380 | return processed 381 | -------------------------------------------------------------------------------- /smaug/core.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from typing import Callable, Iterator, Optional, Tuple, TypeVar, Union 4 | 5 | from smaug import frozen 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | class Data(frozen.frozenlist[T]): 11 | """Represents a batch of data that can be iterated over. 12 | 13 | This object is immutable. 
14 | """ 15 | 16 | def item(self) -> T: 17 | if len(self) != 1: 18 | raise ValueError(f"item() can only be called for Data of length 1.") 19 | return self[0] 20 | 21 | def __repr__(self) -> str: 22 | values = [repr(el) for el in self] 23 | single_line = ", ".join(values) 24 | if len(single_line) <= 80: 25 | return f"Data[{single_line}]" 26 | lines = "".join(f"\t{v},\n" for v in values) 27 | return f"Data[\n" f"{lines}" f"]" 28 | 29 | 30 | DataLike = Union[Data[T], T] 31 | 32 | 33 | @dataclasses.dataclass(frozen=True, eq=True, order=True) 34 | class SpanIndex: 35 | 36 | start: int 37 | end: int 38 | 39 | def encloses(self, other: "SpanIndex") -> bool: 40 | """Verifies whether this span totally encloses the other. 41 | 42 | If a span A encloses a span B, then: 43 | A.start B.start B.end A.end 44 | ---|--------|--------|--------|--- 45 | """ 46 | return self.start <= other.start <= other.end <= self.end 47 | 48 | def partial_overlaps(self, other: "SpanIndex") -> bool: 49 | """Verifies whether this span partially overlaps the other. 50 | 51 | If a span A partially overlaps span B, then: 52 | A.start B.start A.end B.end 53 | ---|--------|--------|--------|--- 54 | or 55 | B.start A.start B.end A.end 56 | ---|--------|--------|--------|--- 57 | """ 58 | return ( 59 | self.start <= other.start <= self.end <= other.end 60 | or other.start <= self.start <= other.end <= self.end 61 | ) 62 | 63 | def intersects(self, other: "SpanIndex") -> bool: 64 | return ( 65 | self.encloses(other) 66 | or other.encloses(self) 67 | or self.partial_overlaps(other) 68 | or other.partial_overlaps(self) 69 | ) 70 | 71 | def __post_init__(self): 72 | if self.start < 0: 73 | raise ValueError(f"'start' must be positive but is {self.start}.") 74 | if self.end < 0: 75 | raise ValueError(f"'end' must be positive but is {self.end}.") 76 | if self.end < self.start: 77 | msg = f"'end' must be greater or equal to 'start': start={self.start}, end={self.end}" 78 | raise ValueError(msg) 79 | 80 | def __str__(self) -> str: 81 | return f"[{self.start}, {self.end}]" 82 | 83 | 84 | SpanIndexLike = Union[Tuple[int, int], SpanIndex] 85 | 86 | 87 | @dataclasses.dataclass(frozen=True, eq=True) 88 | class Modification: 89 | """Stores a modification that was applied to a given sentence. 90 | 91 | Attributes: 92 | old: The old span to be replaced by new. 93 | new: The new span to replace old. 94 | idx: Position where to start replacing. 95 | """ 96 | 97 | old: str 98 | new: str 99 | idx: int 100 | 101 | @property 102 | def old_span_idx(self) -> SpanIndex: 103 | return SpanIndex(self.idx, self.idx + len(self.old)) 104 | 105 | @property 106 | def new_span_idx(self) -> SpanIndex: 107 | return SpanIndex(self.idx, self.idx + len(self.new)) 108 | 109 | 110 | @dataclasses.dataclass(frozen=True) 111 | class ModificationTrace: 112 | """Stores the trace of multiple modifications in order.""" 113 | 114 | curr: Modification 115 | prev: Optional["ModificationTrace"] = dataclasses.field(default=None) 116 | 117 | @staticmethod 118 | def from_modifications(*modifications: Modification) -> "ModificationTrace": 119 | """Constructs a modification trace by considering the modifications in order. 120 | 121 | Args: 122 | modifications: Modifications to store. 123 | 124 | Raises: 125 | ValueError: If no modifications were provided. 126 | 127 | Returns: 128 | The modification trace. 
129 | """ 130 | curr = None 131 | for m in modifications: 132 | curr = ModificationTrace(m, curr) 133 | if curr is None: 134 | raise ValueError("at least on modification is expected.") 135 | return curr 136 | 137 | def __iter__(self) -> Iterator[Modification]: 138 | """Creates an iterator to visit the modifications from oldest to newest. 139 | 140 | Returns: 141 | The iterator object. 142 | """ 143 | 144 | def _yield_modifications(trace: "ModificationTrace") -> Iterator[Modification]: 145 | if trace.prev is not None: 146 | yield from _yield_modifications(trace.prev) 147 | yield trace.curr 148 | 149 | yield from _yield_modifications(self) 150 | 151 | 152 | @dataclasses.dataclass(frozen=True) 153 | class Sentence: 154 | """Represents a sentence that stores applied modifications. 155 | 156 | Each sentence stores its value and the modifications trace 157 | that were applied to this sentence. 158 | """ 159 | 160 | value: str 161 | 162 | trace: Optional[ModificationTrace] = dataclasses.field(default=None) 163 | 164 | def __iter__(self) -> Iterator[str]: 165 | return iter(self.value) 166 | 167 | def __eq__(self, o: object) -> bool: 168 | return isinstance(o, Sentence) and self.value == o.value 169 | 170 | def __len__(self) -> int: 171 | return len(self.value) 172 | 173 | def __str__(self) -> str: 174 | return self.value 175 | 176 | 177 | SentenceLike = Union[str, Sentence] 178 | 179 | Validation = Callable[ 180 | [DataLike[SentenceLike], DataLike[Optional[SentenceLike]]], Data[Optional[Sentence]] 181 | ] 182 | -------------------------------------------------------------------------------- /smaug/frozen.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | 3 | from typing import Generic, Callable, Iterator, Optional, SupportsIndex, Tuple, TypeVar 4 | 5 | _T = TypeVar("_T") 6 | 7 | 8 | class frozenlist(collections.abc.Sequence[_T], Generic[_T]): 9 | """An immutable variant of Python list.""" 10 | 11 | def __init__(self, *args, **kwargs) -> None: 12 | self._list = list(*args, **kwargs) 13 | self._hash = None 14 | 15 | def append(self, *add: _T) -> "frozenlist[_T]": 16 | return self._copy_and_apply(lambda new_list: new_list.extend(add)) 17 | 18 | def insert(self, index: SupportsIndex, object: _T) -> "frozenlist[_T]": 19 | return self._copy_and_apply(lambda new_list: new_list.insert(index, object)) 20 | 21 | def replace(self, index: SupportsIndex, object: _T) -> "frozenlist[_T]": 22 | new_list = list(self._list) 23 | new_list[index] = object 24 | new_self = type(self)(new_list) 25 | return new_self 26 | 27 | def pop(self, index: Optional[SupportsIndex] = None) -> Tuple["frozenlist[_T]", _T]: 28 | if index is None: 29 | index = -1 30 | value = self._list[index] 31 | new_self = self._copy_and_apply(lambda new_list: new_list.pop(index)) 32 | return new_self, value 33 | 34 | def index( 35 | self, 36 | value: _T, 37 | start: Optional[SupportsIndex] = None, 38 | stop: Optional[SupportsIndex] = None, 39 | ) -> int: 40 | if start is None: 41 | start = 0 42 | if stop is None: 43 | stop = len(self._list) 44 | return self._list.index(value, start, stop) 45 | 46 | def count(self, value: _T) -> int: 47 | return self._list.count(value) 48 | 49 | def __len__(self) -> int: 50 | return len(self._list) 51 | 52 | def __iter__(self) -> Iterator[_T]: 53 | return iter(self._list) 54 | 55 | def __hash__(self) -> int: 56 | if self._hash is None: 57 | h = 0 58 | for v in self._list: 59 | h ^= hash(v) 60 | self._hash = h 61 | return self._hash 62 | 63 | def 
__getitem__(self, i) -> _T: 64 | return self._list[i] 65 | 66 | def __setitem__(self, i, o) -> None: 67 | raise ValueError("frozenlist is immutable") 68 | 69 | def __delitem__(self, i) -> None: 70 | raise ValueError("frozenlist is immutable") 71 | 72 | def __contains__(self, o: object) -> bool: 73 | return self._list.__contains__(o) 74 | 75 | def __add__(self, x: "frozenlist[_T]") -> "frozenlist[_T]": 76 | return self.append(*x) 77 | 78 | def __str__(self) -> str: 79 | values = [str(el) for el in self] 80 | single_line = ", ".join(values) 81 | if len(single_line) <= 80: 82 | return f"[{single_line}]" 83 | lines = "".join(f"\t{v},\n" for v in values) 84 | return f"[\n{lines}]" 85 | 86 | def __repr__(self) -> str: 87 | values = [repr(el) for el in self] 88 | single_line = ", ".join(values) 89 | if len(single_line) <= 80: 90 | return f"frozenlist([{single_line}])" 91 | lines = "".join(f"\t{v},\n" for v in values) 92 | return f"frozenlist([\n{lines}])" 93 | 94 | def __eq__(self, other): 95 | return isinstance(other, frozenlist) and self._list == other._list 96 | 97 | def _copy_and_apply(self, func: Callable[[list], None]) -> "frozenlist[_T]": 98 | new_list = list(self._list) 99 | func(new_list) 100 | new_self = type(self)(new_list) 101 | return new_self 102 | -------------------------------------------------------------------------------- /smaug/functional.py: -------------------------------------------------------------------------------- 1 | from smaug.broadcast import broadcast_data 2 | from smaug.core import Data, DataLike, Sentence, SentenceLike, Validation 3 | from smaug.promote import promote_to_data, promote_to_sentence 4 | 5 | from typing import Callable, Optional 6 | 7 | 8 | def lift_boolean_validation( 9 | validation_func: Callable[[Sentence, Sentence], bool] 10 | ) -> Validation: 11 | def validate_single_perturbation( 12 | o: SentenceLike, p: Optional[SentenceLike] 13 | ) -> Optional[Sentence]: 14 | if p is None: 15 | return None 16 | o, p = promote_to_sentence(o), promote_to_sentence(p) 17 | return p if validation_func(o, p) else None 18 | 19 | def validate_all_perturbations( 20 | originals: DataLike[SentenceLike], 21 | perturbations: DataLike[Optional[SentenceLike]], 22 | ) -> Data[Optional[Sentence]]: 23 | originals = promote_to_data(originals) 24 | perturbations = promote_to_data(perturbations) 25 | originals, perturbations = broadcast_data(originals, perturbations) 26 | return Data( 27 | [ 28 | validate_single_perturbation(o, p) 29 | for o, p in zip(originals, perturbations) 30 | ] 31 | ) 32 | 33 | return validate_all_perturbations 34 | -------------------------------------------------------------------------------- /smaug/models/__init__.py: -------------------------------------------------------------------------------- 1 | from smaug.models.stanza import ( 2 | stanza_ner_load, 3 | stanza_ner_lang_available, 4 | stanza_pos_load, 5 | ) 6 | from smaug.models.transformers import mT5_load, polyjuice_load, roberta_mnli_load 7 | -------------------------------------------------------------------------------- /smaug/models/stanza.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines all information required to run the Stanza models. 3 | 4 | The documentation for the stanza NER system is available at 5 | https://stanfordnlp.github.io/stanza/available_models.html#available-ner-models. 
6 | """ 7 | 8 | import dataclasses 9 | import logging 10 | from typing import Tuple 11 | 12 | import stanza 13 | from packaging import version 14 | 15 | 16 | @dataclasses.dataclass 17 | class _StanzaNERModelInfo: 18 | 19 | lang: str 20 | tags: Tuple[str, ...] 21 | 22 | req_stanza_version: version.Version 23 | 24 | 25 | # Tags definitions from tag category notes. 26 | 27 | 28 | _FOUR_TAGS = ("PER", "LOC", "ORG", "MISC") 29 | 30 | _EIGHTEEN_TAGS = ( 31 | "PERSON", 32 | "NORP", # Nationalities / Religious / Political Group 33 | "FAC", # Facility 34 | "ORG", # Organization 35 | "GPE", # Countries / Cities / States 36 | "LOC", # Location 37 | "PRODUCT", 38 | "EVENT", 39 | "WORK_OF_ART", 40 | "LAW", 41 | "LANGUAGE", 42 | "DATE", 43 | "TIME", 44 | "PERCENT", 45 | "MONEY", 46 | "QUANTITY", 47 | "ORDINAL", 48 | "CARDINAL", 49 | ) 50 | _BULGARIAN_BLNSP_TAGS = ("EVENT", "LOCATION", "ORGANIZATION", "PERSON", "PRODUCT") 51 | 52 | _FINISH_TURKU_TAGS = ("EVENT", "DATE", "LOC", "ORG", "PER", "PRO") 53 | 54 | _ITALIAN_FBK_TAGS = ("LOC", "ORG", "PER") 55 | 56 | _JAPANESE_GSD_TAGS = ( 57 | "CARDINAL", 58 | "DATE", 59 | "EVENT", 60 | "FAC", 61 | "GPE", 62 | "LANGUAGE", 63 | "LAW", 64 | "LOC", 65 | "MONEY", 66 | "MOVEMENT", 67 | "NORP", 68 | "ORDINAL", 69 | "ORG", 70 | "PERCENT", 71 | "PERSON", 72 | "PET_NAME", 73 | "PHONE", 74 | "PRODUCT", 75 | "QUANTITY", 76 | "TIME", 77 | "TITLE_AFFIX", 78 | "WORK_OF_ART", 79 | ) 80 | 81 | # LOC (Location), NE (Misc), ORG (Organization), PNAME (Person) 82 | _MYANMAR_UCSY_TAGS = ("LOC", "NE", "ORG", "PNAME", "RACE", "TIME", "NUM") 83 | 84 | _NORWEGIAN_NORNE_TAGS = ("DRV", "EVT", "GPE", "LOC", "MISC", "ORG", "PER", "PROD") 85 | 86 | _PERSIAN_ARMAN_TAGS = ("event", "fac", "loc", "org", "pers", "pro") 87 | 88 | _SWEDISH_SUC3_TAGS = ( 89 | "animal", 90 | "inst", 91 | "myth", 92 | "person", 93 | "place", 94 | "product", 95 | "other", 96 | "work", 97 | ) 98 | 99 | _TURKISH_STARLANG_TAGS = ("LOCATION", "MONEY", "ORGANIZATION", "PERSON", "TIME") 100 | 101 | _VIETNAMESE_VLSP_TAGS = ("LOCATION", "MISCELLANEOUS", "ORGANIZATION", "PERSON") 102 | 103 | _STANZA_NER_MODEL_INFO = { 104 | # Afrikaans 105 | "af": _StanzaNERModelInfo( 106 | lang="af", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 107 | ), 108 | # Arabic 109 | "ar": _StanzaNERModelInfo( 110 | lang="ar", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 111 | ), 112 | # Bulgarian 113 | "bg": _StanzaNERModelInfo( 114 | lang="bg", 115 | tags=_BULGARIAN_BLNSP_TAGS, 116 | req_stanza_version=version.Version("1.2.1"), 117 | ), 118 | # Chinese 119 | "zh": _StanzaNERModelInfo( 120 | lang="zh", tags=_EIGHTEEN_TAGS, req_stanza_version=version.Version("1.0.0") 121 | ), 122 | # Danish 123 | "da": _StanzaNERModelInfo( 124 | lang="da", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.4.0") 125 | ), 126 | # Dutch 127 | "nl": _StanzaNERModelInfo( 128 | lang="nl", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 129 | ), 130 | # English 131 | "en": _StanzaNERModelInfo( 132 | lang="en", tags=_EIGHTEEN_TAGS, req_stanza_version=version.Version("1.0.0") 133 | ), 134 | # Finnish 135 | "fi": _StanzaNERModelInfo( 136 | lang="fi", tags=_FINISH_TURKU_TAGS, req_stanza_version=version.Version("1.2.1") 137 | ), 138 | # French 139 | "fr": _StanzaNERModelInfo( 140 | lang="fr", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 141 | ), 142 | # German 143 | "de": _StanzaNERModelInfo( 144 | lang="de", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 145 | ), 146 | # Hungarian 147 | "hu": 
_StanzaNERModelInfo( 148 | lang="hu", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.2.1") 149 | ), 150 | # Italian 151 | "it": _StanzaNERModelInfo( 152 | lang="it", tags=_ITALIAN_FBK_TAGS, req_stanza_version=version.Version("1.2.3") 153 | ), 154 | # Japanese 155 | "ja": _StanzaNERModelInfo( 156 | lang="ja", tags=_JAPANESE_GSD_TAGS, req_stanza_version=version.Version("1.4.0") 157 | ), 158 | # Myanmar 159 | "my": _StanzaNERModelInfo( 160 | lang="my", tags=_MYANMAR_UCSY_TAGS, req_stanza_version=version.Version("1.4.0") 161 | ), 162 | # Norwegian-Bokmaal 163 | "nb": _StanzaNERModelInfo( 164 | lang="nb", 165 | tags=_NORWEGIAN_NORNE_TAGS, 166 | req_stanza_version=version.Version("1.4.0"), 167 | ), 168 | # Norwegian-Nynorsk 169 | "nn": _StanzaNERModelInfo( 170 | lang="nn", 171 | tags=_NORWEGIAN_NORNE_TAGS, 172 | req_stanza_version=version.Version("1.4.0"), 173 | ), 174 | # Persian 175 | "fa": _StanzaNERModelInfo( 176 | lang="fa", tags=_PERSIAN_ARMAN_TAGS, req_stanza_version=version.Version("1.4.0") 177 | ), 178 | # Russian 179 | "ru": _StanzaNERModelInfo( 180 | lang="ru", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 181 | ), 182 | # Spanish 183 | "es": _StanzaNERModelInfo( 184 | lang="es", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 185 | ), 186 | # Swedish 187 | "sv": _StanzaNERModelInfo( 188 | lang="sv", tags=_SWEDISH_SUC3_TAGS, req_stanza_version=version.Version("1.4.0") 189 | ), 190 | # Turkish 191 | "tr": _StanzaNERModelInfo( 192 | lang="tr", 193 | tags=_TURKISH_STARLANG_TAGS, 194 | req_stanza_version=version.Version("1.4.0"), 195 | ), 196 | # Ukrainian 197 | "uk": _StanzaNERModelInfo( 198 | lang="uk", tags=_FOUR_TAGS, req_stanza_version=version.Version("1.0.0") 199 | ), 200 | # Vietnamese 201 | "vi": _StanzaNERModelInfo( 202 | lang="vi", 203 | tags=_VIETNAMESE_VLSP_TAGS, 204 | req_stanza_version=version.Version("1.2.1"), 205 | ), 206 | } 207 | _AVAILABLE_STANZA_VERSION = version.Version(stanza.__version__) 208 | 209 | 210 | def stanza_ner_tags(lang: str): 211 | return _STANZA_NER_MODEL_INFO[lang].tags 212 | 213 | 214 | def stanza_ner_lang_available(lang: str) -> bool: 215 | if lang not in _STANZA_NER_MODEL_INFO: 216 | return False 217 | model_info = _STANZA_NER_MODEL_INFO[lang] 218 | if model_info.req_stanza_version > _AVAILABLE_STANZA_VERSION: 219 | logging.warning( 220 | 'Required Stanza version for language "%s" is "%s" but found "%s".', 221 | lang, 222 | model_info.req_stanza_version, 223 | _AVAILABLE_STANZA_VERSION, 224 | ) 225 | return False 226 | return True 227 | 228 | 229 | def stanza_ner_load(lang: str = "en", use_gpu: bool = False) -> stanza.Pipeline: 230 | """Loads a new pipeline for a given language. 231 | 232 | Args: 233 | lang: Language of the pipeline. 234 | use_gpu: Specifies if a gpu should be used if available. 235 | 236 | Returns: 237 | stanza.Pipeline that performs tokenization and named entity 238 | recognition.
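        A minimal usage sketch (the first call downloads the model, so network
        access is assumed; the detected entities shown are illustrative):

            >>> nlp = stanza_ner_load("en", use_gpu=False)
            >>> doc = nlp("Ada Lovelace was born in London.")
            >>> [(ent.text, ent.type) for ent in doc.entities]
            [('Ada Lovelace', 'PERSON'), ('London', 'GPE')]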
239 | """ 240 | processors = "tokenize,ner" 241 | stanza.download(lang, processors=processors, logging_level="WARN") 242 | return stanza.Pipeline(lang, processors=processors, use_gpu=use_gpu) 243 | 244 | 245 | def stanza_pos_load(lang: str = "en", use_gpu: bool = False) -> stanza.Pipeline: 246 | processors = "tokenize,pos" 247 | stanza.download(lang=lang, processors=processors, logging_level="WARN") 248 | return stanza.Pipeline(lang=lang, processors=processors, use_gpu=use_gpu) 249 | -------------------------------------------------------------------------------- /smaug/models/transformers.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from typing import Tuple 4 | 5 | 6 | def mT5_load( 7 | model_name: str = "google/mt5-large", 8 | ) -> Tuple[transformers.MT5ForConditionalGeneration, transformers.T5Tokenizer]: 9 | """Loads mT5 model and tokenizer. 10 | 11 | Args: 12 | model_name (str, optional): name of the mT5 model to use. Defaults to "google/mt5-large". 13 | 14 | Returns: 15 | Tuple[transformers.MT5ForConditionalGeneration, transformers.T5Tokenizer]: mT5 model and tokenizer. 16 | """ 17 | model = transformers.MT5ForConditionalGeneration.from_pretrained(model_name) 18 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name) 19 | return model, tokenizer 20 | 21 | 22 | POLYJUICE_EOF_TOKEN = "<|endoftext|>" 23 | 24 | 25 | def polyjuice_load(): 26 | """Loads PolyJuice model for constrained generation. 27 | 28 | Returns: 29 | Tuple[transformers.AutoModelForCausalLM, transformers.AutoTokenizer]: PolyJuice model and tokenizer. 30 | """ 31 | model_path = "uw-hai/polyjuice" 32 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path) 33 | tokenizer = transformers.AutoTokenizer.from_pretrained( 34 | model_path, 35 | pad_token=POLYJUICE_EOF_TOKEN, 36 | ) 37 | return model, tokenizer 38 | 39 | 40 | def roberta_mnli_load() -> Tuple[ 41 | transformers.AutoModelForSequenceClassification, transformers.AutoTokenizer 42 | ]: 43 | """Loads RoBERTa finetuned for multilingual natural language inference. 44 | 45 | Returns: 46 | Tuple[transformers.AutoModelForSequenceClassification, transformers.AutoTokenizer]: RoBERTa model and tokenizer. 47 | """ 48 | name = "roberta-large-mnli" 49 | model = transformers.AutoModelForSequenceClassification.from_pretrained(name) 50 | tokenizer = transformers.AutoTokenizer.from_pretrained(name) 51 | return model, tokenizer 52 | -------------------------------------------------------------------------------- /smaug/more_functools.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from typing import Callable 4 | 5 | 6 | def pipe(*funcs: Callable) -> Callable: 7 | """Creates a new function by piping the results of the received functions. 8 | 9 | The functions are piped in order, meaning 10 | pipe(f, g, h)(x) = h(g(f(x))) 11 | 12 | Args: 13 | funcs: Functions to pipe. 
14 | 15 | Returns: 16 | Function piping the received functions""" 17 | return functools.reduce(lambda f, g: lambda x: g(f(x)), funcs) 18 | -------------------------------------------------------------------------------- /smaug/ops/__init__.py: -------------------------------------------------------------------------------- 1 | """This package specifies the core operations for performing data augmentation.""" 2 | 3 | from smaug.ops.modification import ( 4 | apply_modification, 5 | reverse_modification, 6 | apply_modification_trace, 7 | reverse_modification_trace, 8 | modified_spans_from_trace, 9 | ) 10 | from smaug.ops.sentence import ( 11 | insert, 12 | replace, 13 | delete, 14 | prepend, 15 | append, 16 | rstrip, 17 | find, 18 | startswith, 19 | endswith, 20 | ) 21 | from smaug.ops.detection import ( 22 | stanza_detect_named_entities, 23 | regex_detect_matches, 24 | regex_detect_numbers, 25 | regex_detect_spans_between_matches, 26 | regex_detect_spans_between_punctuation, 27 | ) 28 | from smaug.ops.pos_tagging import stanza_pos_predict 29 | from smaug.ops.masking import ( 30 | mask_intervals, 31 | mask_detections, 32 | mask_random_replace, 33 | mask_random_insert, 34 | mask_poisson_spans, 35 | ) 36 | from smaug.ops.lang_model import mT5_generate, mT5_masking_function 37 | from smaug.ops.nli import roberta_mnli_predict, roberta_mnli_contradiction_id 38 | from smaug.ops.text_generation import polyjuice_negate 39 | from smaug.ops.sentence_comparison import ( 40 | character_insertions, 41 | equal_numbers_count, 42 | equal_named_entities_count, 43 | edit_distance, 44 | ) 45 | -------------------------------------------------------------------------------- /smaug/ops/detection.py: -------------------------------------------------------------------------------- 1 | import re 2 | import stanza 3 | 4 | from smaug.core import Data, DataLike, Sentence, SentenceLike, SpanIndex 5 | from smaug.frozen import frozenlist 6 | from smaug.promote import promote_to_data, promote_to_sentence 7 | 8 | from typing import Iterable, Optional, Tuple 9 | 10 | 11 | def stanza_detect_named_entities( 12 | text: DataLike[SentenceLike], 13 | ner_pipeline: stanza.Pipeline, 14 | filter_entities: Optional[Iterable[str]] = None, 15 | ) -> Data[frozenlist[Tuple[int, int]]]: 16 | """Detects text spans with named entities using the Stanza NER pipeline. 17 | 18 | Args: 19 | text: Text to process. 20 | ner_pipeline: Stanza NER pipeline to apply. 21 | filter_entities: Entity types to accept. 22 | 23 | Returns: 24 | Spans of detected named entities. 25 | """ 26 | text = promote_to_data(text) 27 | sentences = map(promote_to_sentence, text) 28 | 29 | documents = [ner_pipeline(s.value) for s in sentences] 30 | 31 | def process_document(doc): 32 | detected_entities = doc.entities 33 | if filter_entities is not None: 34 | unique_entities = set(filter_entities) 35 | detected_entities = [ 36 | ent for ent in detected_entities if ent.type in unique_entities 37 | ] 38 | 39 | return frozenlist([(ent.start_char, ent.end_char) for ent in detected_entities]) 40 | 41 | return Data([process_document(doc) for doc in documents]) 42 | 43 | 44 | _DEFAULT_NUMBERS_REGEX = re.compile(r"[-+]?\.?(\d+[.,])*\d+") 45 | 46 | 47 | def regex_detect_numbers( 48 | text: DataLike[SentenceLike], 49 | ) -> Data[frozenlist[Tuple[int, int]]]: 50 | """Detects text spans with numbers according to a regular expression. 51 | 52 | Args: 53 | text: Text to process. 54 | 55 | Returns: 56 | Spans of detected matches. 
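    A minimal usage sketch (spans are (start, end) character offsets into the sentence):

        >>> regex_detect_numbers("Pay 12.5 now and 3 later.").item()
        frozenlist([(4, 8), (17, 18)])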
57 | """ 58 | return regex_detect_matches(text, _DEFAULT_NUMBERS_REGEX) 59 | 60 | 61 | def regex_detect_matches( 62 | text: DataLike[SentenceLike], 63 | regex: re.Pattern, 64 | ) -> Data[frozenlist[Tuple[int, int]]]: 65 | """Detects text spans that match a given regex. 66 | 67 | Args: 68 | text: Text to process. 69 | regex: Regular Expression to search. 70 | 71 | Returns: 72 | Spans of detected matches. 73 | """ 74 | text = promote_to_data(text) 75 | sentences = map(promote_to_sentence, text) 76 | 77 | def process_sentence(s: Sentence) -> frozenlist[Tuple[int, int]]: 78 | matches = regex.finditer(s.value) 79 | return frozenlist([m.span() for m in matches]) 80 | 81 | return Data([process_sentence(s) for s in sentences]) 82 | 83 | _DEFAULT_PUNCTUATION_REGEX = re.compile(r"[!?.,]+") 84 | 85 | def regex_detect_spans_between_punctuation( 86 | text: DataLike[SentenceLike], 87 | ) -> Data[frozenlist[SpanIndex]]: 88 | """Detects text spans between punctuation marks. 89 | 90 | Args: 91 | text: Text to process. 92 | 93 | Returns: 94 | Spans between detected punctuation marks. 95 | """ 96 | return regex_detect_spans_between_matches(text, _DEFAULT_PUNCTUATION_REGEX) 97 | 98 | def regex_detect_spans_between_matches( 99 | text: DataLike[SentenceLike], regex: re.Pattern, 100 | ) -> Data[frozenlist[SpanIndex]]: 101 | """Detects text spans between matches of a given regex. 102 | 103 | Args: 104 | text: Text to process. 105 | regex: Regular Expression to search. 106 | 107 | Returns: 108 | Spans between detected matches. 109 | """ 110 | text = promote_to_data(text) 111 | sentences = map(promote_to_sentence, text) 112 | 113 | def process_sentence(s: Sentence) -> frozenlist[SpanIndex]: 114 | matches = regex.finditer(s.value) 115 | spans_delims_idxs = [0] + [m.end() for m in matches] + [len(s)] 116 | # Transform indexes in iterable with (idx1,idx2), (idx2,idx3), ... 117 | pairwise = zip(spans_delims_idxs, spans_delims_idxs[1:]) 118 | return frozenlist(SpanIndex(s, e) for s, e in pairwise) 119 | 120 | return Data([process_sentence(s) for s in sentences]) -------------------------------------------------------------------------------- /smaug/ops/lang_model.py: -------------------------------------------------------------------------------- 1 | import re 2 | import transformers 3 | 4 | from smaug import ops 5 | from smaug.core import Data, DataLike, Sentence, SentenceLike 6 | from smaug.promote import promote_to_data, promote_to_sentence 7 | 8 | 9 | _MASK_REGEX = re.compile(r"") 10 | 11 | 12 | def mT5_generate( 13 | text: DataLike[SentenceLike], 14 | model: transformers.MT5ForConditionalGeneration, 15 | tokenizer: transformers.T5Tokenizer, 16 | clean_outputs: bool = True, 17 | cuda: bool = False, 18 | ) -> Data[Sentence]: 19 | """Generates with Google's mT5 model. 20 | 21 | Args: 22 | text: sentences to use as input. 23 | model: mT5 model to use. 24 | tokenizer: T5 tokenizer to use. 25 | clean_outputs: If replacing output, specifies whether small transformations should 26 | be applied to the output sentences to improve their quality. 27 | cuda: Whether to use cuda enabled gpu or not. 
28 | """ 29 | 30 | text = promote_to_data(text) 31 | sentences = Data(promote_to_sentence(t) for t in text) 32 | 33 | if cuda: 34 | model.cuda() 35 | 36 | tokenizer_input = [s.value for s in sentences] 37 | input_ids = tokenizer(tokenizer_input, padding=True, return_tensors="pt").input_ids 38 | if cuda: 39 | input_ids = input_ids.cuda() 40 | 41 | output_ids = model.generate( 42 | input_ids, 43 | max_new_tokens=model.config.max_length, 44 | do_sample=True, 45 | top_k=50, 46 | ) 47 | 48 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 49 | 50 | outputs = [_mT5_replace_masks(s, o) for s, o in zip(sentences, outputs)] 51 | 52 | if clean_outputs: 53 | outputs = [_mT5_clean_output(o) for o in outputs] 54 | 55 | return Data(outputs) 56 | 57 | 58 | def mT5_masking_function(idx: int): 59 | return f"" 60 | 61 | 62 | def _mT5_replace_masks(source: Sentence, output: str) -> Sentence: 63 | spans = _MASK_REGEX.split(output)[1:] 64 | 65 | mask_idx = 0 66 | for span in spans: 67 | no_space_start = len(span) > 0 and span[0] != " " 68 | # Avoid bad escape char by replacing single \ with \\ 69 | escaped_span = span.strip().replace("\\", "\\\\") 70 | 71 | mask = mT5_masking_function(mask_idx) 72 | mask_idx += 1 73 | 74 | if pattern_match := re.search(mask, source.value): 75 | first_idx = pattern_match.start() 76 | last_idx = first_idx + len(mask) 77 | # If we are replacing by a span that does not start by a space, 78 | # and there is a space before the mask then also remove that space 79 | # (e.g. near -> nearly instead of near -> near ly) 80 | if first_idx != 0 and source.value[first_idx - 1] == " " and no_space_start: 81 | first_idx -= 1 82 | replace_span = (first_idx, last_idx) 83 | source = ops.replace(source, escaped_span, replace_span) 84 | 85 | return source 86 | 87 | 88 | def _mT5_clean_output(output: Sentence) -> Sentence: 89 | while ops.startswith(output, (".", ",", "!", "?", " ")): 90 | output = ops.delete(output, (0, 1)) 91 | return ops.rstrip(output) 92 | -------------------------------------------------------------------------------- /smaug/ops/masking.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | 4 | 5 | from smaug import ops 6 | from smaug.broadcast import broadcast_data 7 | from smaug.core import Data, DataLike, Sentence, SentenceLike, SpanIndexLike 8 | from smaug.frozen import frozenlist 9 | from smaug.promote import promote_to_data, promote_to_sentence, promote_to_span_index 10 | 11 | from typing import Callable, Optional, Tuple 12 | 13 | MaskFunction = Callable[[int], str] 14 | """Retrieves the ith mask token given i. 15 | 16 | For methods where masks in the same sentences are different, this function 17 | will be called with 0, 1, 2, ... and should return the 1st, 2nd, 3rd, ... masks. 18 | 19 | For methods that do not distinguish masks, this function should always return 20 | the same value. 21 | 22 | Args: 23 | i: mask index. 24 | 25 | Returns: 26 | mask token to insert 27 | """ 28 | 29 | 30 | def mask_intervals( 31 | text: DataLike[SentenceLike], 32 | intervals: DataLike[frozenlist[SpanIndexLike]], 33 | func: MaskFunction, 34 | ) -> Data[Sentence]: 35 | """Masks a sentence according to intervals. 36 | 37 | Mask the given sentence according to the specified intervals. The characters 38 | in the specified intervals are replaced by the mask token. 39 | 40 | Args: 41 | text: text to mask. 42 | intervals: intervals to mask. 
Each interval should specify the 43 | (start, end) to index the sentence. 44 | func: masking function to mask the intervals. 45 | 46 | Returns: 47 | Masked text according to the given intervals. 48 | """ 49 | 50 | text = promote_to_data(text) 51 | intervals = promote_to_data(intervals) 52 | 53 | text, intervals = broadcast_data(text, intervals) 54 | 55 | sentences = map(promote_to_sentence, text) 56 | 57 | return Data( 58 | _mask_sentence_intervals(s, i, func) for s, i in zip(sentences, intervals) 59 | ) 60 | 61 | 62 | def _mask_sentence_intervals( 63 | sentence: Sentence, 64 | intervals: frozenlist[SpanIndexLike], 65 | func: MaskFunction, 66 | ) -> Sentence: 67 | 68 | if len(intervals) == 0: 69 | return sentence 70 | 71 | mask_idx = len(intervals) - 1 72 | # Go through intervals in reverse order as modifying 73 | # the sentence shifts all intervals greater than the 74 | # current. 75 | for interval in sorted(intervals, reverse=True): 76 | mask = func(mask_idx) 77 | span_index = promote_to_span_index(interval) 78 | sentence = ops.replace(sentence, mask, span_index) 79 | mask_idx -= 1 80 | 81 | return sentence 82 | 83 | 84 | def mask_detections( 85 | text: DataLike[SentenceLike], 86 | detect_func: Callable[[DataLike[SentenceLike]], Data[frozenlist[Tuple[int, int]]]], 87 | mask_func: MaskFunction, 88 | rng: np.random.Generator, 89 | p: float = 1, 90 | max_masks: Optional[int] = None, 91 | ) -> Data[Sentence]: 92 | """Masks the detected spans in a given text. 93 | 94 | Args: 95 | text: Text to apply the masks. 96 | detect_func: Function to detect possible spans to mask. 97 | mask_func: Masking function to apply. 98 | rng: Numpy random generator to use. 99 | p: Probability of applying a mask to a given detection. 100 | max_masks: Maximum masks to apply. If not specified all 101 | detections will be masked. 102 | 103 | Returns: 104 | Masked text. 105 | """ 106 | text = promote_to_data(text) 107 | mask_sentence_func = functools.partial( 108 | _mask_sentence_detections, 109 | detect_func=detect_func, 110 | mask_func=mask_func, 111 | rng=rng, 112 | p=p, 113 | max_masks=max_masks, 114 | ) 115 | sentences = map(promote_to_sentence, text) 116 | return Data(mask_sentence_func(s) for s in sentences) 117 | 118 | 119 | def _mask_sentence_detections( 120 | text: Sentence, 121 | detect_func: Callable[[DataLike[SentenceLike]], Data[frozenlist[Tuple[int, int]]]], 122 | mask_func: MaskFunction, 123 | rng: np.random.Generator, 124 | p: float = 1, 125 | max_masks: Optional[int] = None, 126 | ) -> Sentence: 127 | if p == 0: 128 | return text 129 | 130 | detections = detect_func(text).item() 131 | 132 | if p != 1: 133 | detections = filter(lambda _: rng.random() < p, detections) 134 | 135 | if max_masks: 136 | detections = list(detections) 137 | if len(detections) > max_masks: 138 | detections = rng.choice(detections, max_masks, replace=False) 139 | 140 | return mask_intervals(text, frozenlist(detections), mask_func).item() 141 | 142 | 143 | def mask_random_replace( 144 | text: DataLike[SentenceLike], 145 | func: MaskFunction, 146 | rng: np.random.Generator, 147 | p: float = 1, 148 | ) -> Data[Sentence]: 149 | """Randomly replaces words for masks. 150 | 151 | Args: 152 | text: Text to apply the masks. 153 | func: Masking function to apply. 154 | rng: Numpy random generator to use. 155 | p: Probability of replacing a word by a mask. 156 | 157 | Returns: 158 | Data[str]: masked text. 
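        A minimal sketch with an illustrative mask function (with p=1 every word
        is replaced, so the result is deterministic):

            >>> import numpy as np
            >>> rng = np.random.default_rng(0)
            >>> out = mask_random_replace("the quick brown fox", lambda i: f"[MASK_{i}]", rng, p=1)
            >>> out.item().value
            '[MASK_0] [MASK_1] [MASK_2] [MASK_3]'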
159 | """ 160 | text = promote_to_data(text) 161 | mask_sentence_func = functools.partial( 162 | _mask_sentence_random_replace, 163 | func=func, 164 | rng=rng, 165 | p=p, 166 | ) 167 | sentences = map(promote_to_sentence, text) 168 | return Data(mask_sentence_func(s) for s in sentences) 169 | 170 | 171 | def _mask_sentence_random_replace( 172 | sentence: Sentence, 173 | func: MaskFunction, 174 | rng: np.random.Generator, 175 | p: float = 1, 176 | ) -> Sentence: 177 | def next_word_delim(start: int): 178 | # Try to find next space 179 | word_delim_idx = ops.find(sentence, " ", start=start) 180 | if word_delim_idx == -1: 181 | # If not space, then we are at the last word 182 | # and return the remaining sentence. 183 | word_delim_idx = len(sentence) 184 | return word_delim_idx 185 | 186 | mask_idx = 0 187 | curr_idx = 0 188 | while curr_idx < len(sentence): 189 | word_delim_idx = next_word_delim(curr_idx) 190 | if rng.random() < p: 191 | mask = func(mask_idx) 192 | sentence = ops.replace(sentence, mask, (curr_idx, word_delim_idx)) 193 | mask_idx += 1 194 | curr_idx += len(mask) + 1 195 | else: 196 | curr_idx = word_delim_idx + 1 197 | return sentence 198 | 199 | 200 | def mask_poisson_spans( 201 | text: DataLike[SentenceLike], 202 | func: MaskFunction, 203 | rng: np.random.Generator, 204 | ) -> Data[Sentence]: 205 | """Masks spans of text with sizes following a poisson distribution. 206 | 207 | Args: 208 | text: Text to mask. 209 | func: Mask function to apply. 210 | rng: Numpy random generator to use. 211 | 212 | Returns: 213 | Masked text. 214 | """ 215 | text = promote_to_data(text) 216 | mask_sentence_func = functools.partial( 217 | _mask_poisson_spans, 218 | func=func, 219 | rng=rng, 220 | ) 221 | sentences = map(promote_to_sentence, text) 222 | return Data(mask_sentence_func(s) for s in sentences) 223 | 224 | 225 | def _mask_poisson_spans( 226 | text: Sentence, func: MaskFunction, rng: np.random.Generator 227 | ) -> Sentence: 228 | # Add plus 1 to indexes as they should index the charcter next 229 | # to the word limit. 230 | spaces = [i + 1 for i, c in enumerate(text.value) if c == " "] 231 | word_starts = [0] + spaces 232 | 233 | found = False 234 | while not found: 235 | num_masked_words = rng.poisson() 236 | start_word_idx = rng.choice(len(word_starts), 1)[0] 237 | if start_word_idx + num_masked_words <= len(word_starts): 238 | found = True 239 | 240 | start_idx = word_starts[start_word_idx] 241 | # We are masking until the end of the sentence. 242 | if start_word_idx + num_masked_words == len(word_starts): 243 | end_idx = len(text) 244 | # We are inserting words. 245 | elif num_masked_words == 0: 246 | end_idx = start_idx 247 | # We are masking words in the middle of the sentence. 248 | else: 249 | end_idx = word_starts[start_word_idx + num_masked_words] - 1 250 | 251 | # Only add space if inserting words. Otherwise, use available spaces. 252 | span = f"{func(0)} " if num_masked_words == 0 else func(0) 253 | return ops.replace(text, span, (start_idx, end_idx)) 254 | 255 | 256 | def mask_random_insert( 257 | text: DataLike[SentenceLike], 258 | func: MaskFunction, 259 | rng: np.random.Generator, 260 | p: float = 0.2, 261 | max_masks: Optional[int] = None, 262 | ) -> Data[Sentence]: 263 | """Inserts masks between random words in the text. 264 | 265 | Args: 266 | text: Text to apply the masks. 267 | func: Masking function to apply. 268 | rng: Numpy random generator to use. 269 | p: Probability of inserting a mask between two words. 270 | max_masks: Maximum masks to apply. 
If not specified all 271 | regular expression matches will be masked. 272 | 273 | Returns: 274 | Masked text. 275 | """ 276 | text = promote_to_data(text) 277 | mask_sentence_func = functools.partial( 278 | _mask_sentence_random_insert, 279 | func=func, 280 | rng=rng, 281 | p=p, 282 | max_masks=max_masks, 283 | ) 284 | sentences = map(promote_to_sentence, text) 285 | return Data(mask_sentence_func(s) for s in sentences) 286 | 287 | 288 | def _mask_sentence_random_insert( 289 | sentence: Sentence, 290 | func: MaskFunction, 291 | rng: np.random.Generator, 292 | p: float = 0.2, 293 | max_masks: Optional[int] = None, 294 | ) -> Sentence: 295 | 296 | if len(sentence) == 0: 297 | if rng.random() < p: 298 | sentence = ops.insert(sentence, func(0), 0) 299 | return sentence 300 | 301 | after_spaces = [i + 1 for i, c in enumerate(sentence.value) if c == " "] 302 | # Possible indexes where to start mask. 303 | possible_mask_starts = np.array([0] + after_spaces + [len(sentence)]) 304 | 305 | mask_idxs = rng.choice([False, True], size=len(possible_mask_starts), p=(1 - p, p)) 306 | (true_idxs,) = np.nonzero(mask_idxs) 307 | if max_masks is not None and len(true_idxs) > max_masks: 308 | true_idxs = rng.choice(true_idxs, size=max_masks) 309 | mask_idxs = np.full_like(mask_idxs, False) 310 | mask_idxs[true_idxs] = True 311 | 312 | mask_start = possible_mask_starts[mask_idxs] 313 | 314 | mask_idx = len(mask_start) - 1 315 | for idx in reversed(mask_start): 316 | mask = func(mask_idx) 317 | # Insert space before unless we are at the beginning, where we insert 318 | # a space after the mask. 319 | insert = f"{mask} " if idx != len(sentence) else f" {mask}" 320 | sentence = ops.insert(sentence, insert, idx) 321 | mask_idx -= 1 322 | 323 | return sentence 324 | -------------------------------------------------------------------------------- /smaug/ops/modification.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from smaug.core import Modification, ModificationTrace, SpanIndex 4 | from smaug.frozen import frozenlist 5 | 6 | 7 | def apply_modification(m: Modification, value: str) -> str: 8 | """Replaces old by new in the given string. 9 | 10 | Args: 11 | m: Modification to apply. 12 | value: String to apply the modification. 13 | 14 | Raises: 15 | ValueError: If the given value does not contain old at the given index. 16 | 17 | Returns: 18 | A string with the applied modification. 19 | """ 20 | if not value.startswith(m.old, m.idx): 21 | raise ValueError(f'str "{value}" does not have "{m.old}" at position {m.idx}.') 22 | replace_start = m.idx 23 | replace_end = replace_start + len(m.old) 24 | return f"{value[:replace_start]}{m.new}{value[replace_end:]}" 25 | 26 | 27 | def reverse_modification(m: Modification, value: str) -> str: 28 | """Reverses a modification in the given value. 29 | 30 | This operation performs the modification in the 31 | reverse direction, by replacing old by new. 32 | 33 | Args: 34 | m: Modification to reverse. 35 | value: String to apply the modification. 36 | 37 | Raises: 38 | ValueError: If the given value does not contain new at the given index. 39 | 40 | Returns: 41 | A string with the applied modification. 42 | """ 43 | reverse = Modification(old=m.new, new=m.old, idx=m.idx) 44 | return apply_modification(reverse, value) 45 | 46 | 47 | def apply_modification_trace(t: ModificationTrace, value: str) -> str: 48 | """Applies all modifications in order, from the oldest to the newest. 
49 | 50 | Args: 51 | t: Modification trace to apply. 52 | value: String to apply the modifications. 53 | 54 | Returns: 55 | Modified string. 56 | """ 57 | return functools.reduce(lambda acc, mod: apply_modification(mod, acc), t, value) 58 | 59 | 60 | def reverse_modification_trace(t: ModificationTrace, value: str) -> str: 61 | """Applies all modifications in reverse order, from the newest to the oldest. 62 | 63 | Args: 64 | t: Modification trace to reverse 65 | value: String to apply the modifications. 66 | 67 | Returns: 68 | Modified string. 69 | """ 70 | return functools.reduce( 71 | lambda acc, mod: reverse_modification(mod, acc), reversed(list(t)), value 72 | ) 73 | 74 | 75 | def modified_spans_from_trace(t: ModificationTrace) -> frozenlist[SpanIndex]: 76 | """Computes the spans modified by a trace. 77 | 78 | Args: 79 | t: Modification trace to process. 80 | 81 | Returns: 82 | Spans of modified indices. Deletions are represented with the empty span. 83 | """ 84 | 85 | def append_modified_spans( 86 | spans: frozenlist[SpanIndex], m: Modification 87 | ) -> frozenlist[SpanIndex]: 88 | # If the modification is a deletion completely on top of an older 89 | # modification it should be as if the older modification never existed. 90 | reverting = m.new == "" and any(old == m.old_span_idx for old in spans) 91 | if reverting: 92 | spans = [old for old in spans if old != m.old_span_idx] 93 | 94 | new_spans = [] 95 | offset = m.new_span_idx.end - m.old_span_idx.end 96 | new_span = m.new_span_idx 97 | for old in spans: 98 | # Modification after the old span. The old span is unchanged. 99 | if old.end < m.old_span_idx.start: 100 | new_spans.append(old) 101 | # Modification before the old span. The old span must be shifted. 102 | elif m.old_span_idx.end < old.start: 103 | shifted = SpanIndex(old.start + offset, old.end + offset) 104 | new_spans.append(shifted) 105 | # Modification intersects the old span. The old span must be merged 106 | # into the new span. 107 | else: 108 | new_start = min(old.start, new_span.start) 109 | new_end = max(old.end + offset, new_span.end) 110 | new_span = SpanIndex(new_start, new_end) 111 | 112 | # Only add new span if not reverting 113 | if not reverting: 114 | new_spans.append(new_span) 115 | 116 | return frozenlist(sorted(new_spans)) 117 | 118 | return functools.reduce(append_modified_spans, t, frozenlist()) 119 | -------------------------------------------------------------------------------- /smaug/ops/nli.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | from smaug.broadcast import broadcast_data 5 | from smaug.core import DataLike, SentenceLike 6 | from smaug.promote import promote_to_data 7 | 8 | 9 | def roberta_mnli_predict( 10 | premises: DataLike[SentenceLike], 11 | hypotheses: DataLike[SentenceLike], 12 | model: transformers.RobertaForSequenceClassification, 13 | tokenizer: transformers.PreTrainedTokenizerBase, 14 | cuda: bool = False, 15 | ) -> torch.FloatTensor: 16 | """Performs NLI with RoBERTA on the received sentences. 17 | 18 | Args: 19 | premises: Premises to process. 20 | hypotheses: Hypotheses to consider. 21 | model: RoBERTa model to use. 22 | tokenizer: RoBERTa tokenizer to use. 23 | cuda: Whether to use gpu or not. 24 | 25 | Returns: 26 | Logits for each class. 
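        A minimal usage sketch (the model download is assumed to succeed; the
        logits tensor has one row per premise/hypothesis pair and one column per
        NLI label):

            >>> from smaug.models import roberta_mnli_load
            >>> model, tokenizer = roberta_mnli_load()
            >>> logits = roberta_mnli_predict(
            ...     "A man is playing a guitar.", "Nobody is playing music.", model, tokenizer
            ... )
            >>> logits.shape
            torch.Size([1, 3])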
27 | """ 28 | premises = promote_to_data(premises) 29 | hypotheses = promote_to_data(hypotheses) 30 | premises, hypotheses = broadcast_data(premises, hypotheses) 31 | inputs = [f"{p} {h}" for p, h in zip(premises, hypotheses)] 32 | 33 | if cuda: 34 | model.cuda() 35 | with torch.no_grad(): 36 | input_ids = tokenizer( 37 | inputs, 38 | padding=True, 39 | return_tensors="pt", 40 | truncation=True, 41 | max_length=512, 42 | ).input_ids 43 | if cuda: 44 | input_ids = input_ids.cuda() 45 | return model(input_ids).logits 46 | 47 | 48 | def roberta_mnli_contradiction_id( 49 | model: transformers.RobertaForSequenceClassification, 50 | ) -> int: 51 | return model.config.label2id["CONTRADICTION"] 52 | -------------------------------------------------------------------------------- /smaug/ops/pos_tagging.py: -------------------------------------------------------------------------------- 1 | import stanza 2 | 3 | from smaug.core import Data, DataLike, SentenceLike 4 | from smaug.promote import promote_to_data, promote_to_sentence 5 | 6 | 7 | def stanza_pos_predict( 8 | text: DataLike[SentenceLike], pos_pipeline: stanza.Pipeline 9 | ) -> Data: 10 | """Predicts part-of-speech tags with stanza POS model.""" 11 | text = promote_to_data(text) 12 | sentences = [promote_to_sentence(t) for t in text] 13 | return Data(pos_pipeline(s.value) for s in sentences) 14 | -------------------------------------------------------------------------------- /smaug/ops/sentence.py: -------------------------------------------------------------------------------- 1 | from smaug.core import SpanIndexLike, Modification, ModificationTrace, Sentence 2 | from smaug.promote import promote_to_span_index 3 | from smaug.ops.modification import apply_modification 4 | 5 | 6 | def insert(s: Sentence, span: str, idx: int) -> Sentence: 7 | """Creates a new sentence by inserting the span at the given idx. 8 | 9 | Args: 10 | s: Sentence to modify. 11 | span: Span to insert. 12 | idx: Index where to start the span insertion. 13 | 14 | Returns: 15 | Sentence with the insertion operation applied. 16 | """ 17 | # An insertion is a replacement of the empty string at position idx 18 | # by the given span. 19 | # Set index to int as numpy.int64 is not serializable. 20 | 21 | return modify_sentence(s, Modification(old="", new=span, idx=int(idx))) 22 | 23 | 24 | def replace(s: Sentence, span: str, loc: SpanIndexLike) -> Sentence: 25 | """Creates a new sentence by replacing characters by a new span. 26 | 27 | The new sentence will have the characters indexed by loc replaced 28 | by the new span. 29 | 30 | Args: 31 | s: Sentence to modify. 32 | span: Span to insert. 33 | loc: Indexes to delimit the text to be replaced by the span. 34 | 35 | Returns: 36 | Sentence with the replacement operation applied. 37 | """ 38 | loc = promote_to_span_index(loc) 39 | old = s.value[loc.start : loc.end] 40 | return modify_sentence(s, Modification(old=old, new=span, idx=loc.start)) 41 | 42 | 43 | def delete(s: Sentence, loc: SpanIndexLike) -> Sentence: 44 | """Creates a new sentence by deleting the characters indexed by loc. 45 | 46 | Args: 47 | s: Sentence to modify. 48 | loc: Indexes to delimit the text to be deleted. 49 | 50 | Returns: 51 | Sentence with the deletion operation applied. 52 | """ 53 | loc = promote_to_span_index(loc) 54 | to_delete = s.value[loc.start : loc.end] 55 | # A deletion is a replacement of the span indexed by loc with the 56 | # empty string. 
57 | return modify_sentence(s, Modification(old=to_delete, new="", idx=loc.start)) 58 | 59 | 60 | def prepend(s: Sentence, span: str) -> Sentence: 61 | return insert(s, span, 0) 62 | 63 | 64 | def append(s: Sentence, span: str) -> Sentence: 65 | return insert(s, span, len(s)) 66 | 67 | 68 | def rstrip(s: Sentence) -> Sentence: 69 | last_space_idx = len(s) 70 | 71 | while last_space_idx > 0 and s.value[last_space_idx - 1] == " ": 72 | last_space_idx -= 1 73 | 74 | new_s = s 75 | if last_space_idx != len(s): 76 | new_s = delete(s, (last_space_idx, len(s))) 77 | 78 | return new_s 79 | 80 | 81 | def modify_sentence(s: Sentence, m: Modification) -> Sentence: 82 | """Creates a new sentence by applying a modification to this sentence. 83 | 84 | Args: 85 | s: Sentence to modify. 86 | m: Modification to apply. 87 | 88 | Returns: 89 | The new sentence. 90 | """ 91 | new_value = apply_modification(m, s.value) 92 | new_trace = ModificationTrace(m, s.trace) 93 | return Sentence(value=new_value, trace=new_trace) 94 | 95 | 96 | def find(s: Sentence, sub: str, start=None, end=None) -> int: 97 | return s.value.find(sub, start, end) 98 | 99 | 100 | def startswith(s: Sentence, prefix, start=None, end=None) -> bool: 101 | return s.value.startswith(prefix, start, end) 102 | 103 | 104 | def endswith(s: Sentence, suffix, start=None, end=None) -> bool: 105 | return s.value.endswith(suffix, start, end) 106 | -------------------------------------------------------------------------------- /smaug/ops/sentence_comparison.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import nltk 4 | import stanza 5 | 6 | from smaug.core import Sentence 7 | from smaug.ops import detection 8 | 9 | 10 | def equal_numbers_count(s1: Sentence, s2: Sentence) -> bool: 11 | s1_count = len(detection.regex_detect_numbers(s1).item()) 12 | s2_count = len(detection.regex_detect_numbers(s2).item()) 13 | return s1_count == s2_count 14 | 15 | 16 | def equal_named_entities_count( 17 | s1: Sentence, s2: Sentence, ner_pipeline: stanza.Pipeline 18 | ) -> bool: 19 | ner_func = functools.partial( 20 | detection.stanza_detect_named_entities, 21 | ner_pipeline=ner_pipeline, 22 | ) 23 | s1_count = len(ner_func(s1).item()) 24 | s2_count = len(ner_func(s2).item()) 25 | return s1_count == s2_count 26 | 27 | 28 | def character_insertions(original: Sentence, modified: Sentence, chars: str) -> int: 29 | """Returns the number of times the given characters were inserted. 30 | 31 | Args: 32 | original: Original sentence to perform comparison. 33 | modified: Sentence with modifications. 34 | chars: Characters to consider. 35 | 36 | Returns: 37 | The number of inserted characters. 38 | """ 39 | original_counts = collections.Counter(c for c in original if c in chars) 40 | modified_counts = collections.Counter(c for c in modified if c in chars) 41 | insertions = modified_counts - original_counts 42 | return sum(insertions.values()) 43 | 44 | 45 | def edit_distance(s1: Sentence, s2: Sentence, level: str) -> int: 46 | """Computes the edit distance between two sentences. 47 | 48 | Args: 49 | s1: First sentence. 50 | s2: Second sentence. 51 | level: Level at which to measure the minimum edit distance. Must be "word" or "char". 52 | 53 | Returns: 54 | Computed edit distance. 
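A small illustration of how the two levels differ, importing the helper directly from this module:

```python
from smaug.core import Sentence
from smaug.ops.sentence_comparison import edit_distance

s1 = Sentence("The cat sat on the mat")
s2 = Sentence("The dog sat on the mat")

# Exactly one word differs, so the word-level distance is 1, while at the
# character level the three letters of "cat" must be edited into "dog".
assert edit_distance(s1, s2, level="word") == 1
assert edit_distance(s1, s2, level="char") == 3
```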
55 |     """
56 | 
57 |     def char_val_func() -> int:
58 |         return nltk.edit_distance(s1.value, s2.value)
59 | 
60 |     def word_val_func() -> int:
61 |         return nltk.edit_distance(s1.value.split(), s2.value.split())
62 | 
63 |     levels = ("char", "word")
64 |     if level not in levels:
65 |         raise ValueError(f"Unknown level {level}: must be one of {levels}.")
66 |     cmp_func = char_val_func if level == "char" else word_val_func
67 |     return cmp_func()
68 | 
--------------------------------------------------------------------------------
/smaug/ops/text_generation.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 | import stanza
4 | import typing
5 | import torch
6 | import transformers
7 | 
8 | from typing import Optional, Tuple
9 | 
10 | from smaug import ops
11 | from smaug.core import Data, DataLike, Sentence, SentenceLike
12 | from smaug.promote import promote_to_data, promote_to_sentence
13 | 
14 | 
15 | _PERTURB_TOK = "<|perturb|>"
16 | _BLANK_TOK = "[BLANK]"
17 | _SEP_TOK = "[SEP]"
18 | _EMPTY_TOK = "[EMPTY]"
19 | _ANSWER_TOK = "[ANSWER]"
20 | 
21 | _NEGATION = "[negation]"
22 | 
23 | 
24 | def polyjuice_negate(
25 |     text: DataLike[SentenceLike],
26 |     pos_pipeline: stanza.Pipeline,
27 |     model: transformers.AutoModelForCausalLM,
28 |     tokenizer: transformers.PreTrainedTokenizerBase,
29 |     rng: np.random.Generator,
30 |     cuda: bool = False,
31 | ) -> Data[Optional[Sentence]]:
32 |     """Polyjuice model conditioned on negation.
33 | 
34 |     This model wraps the Polyjuice model presented in the paper
35 |     "Polyjuice: Generating Counterfactuals for Explaining, Evaluating, and Improving Models"
36 |     by Tongshuang Wu, Marco Tulio Ribeiro, Jeffrey Heer, Daniel S. Weld
37 |     at the Association for Computational Linguistics (ACL), 2021.
38 |     The code for this model is available at https://github.com/tongshuangwu/polyjuice.
39 | 
40 |     This function conditions the Polyjuice model on negation by masking verbs.
41 |     It also tries to mask auxiliary verbs that precede the selected verb.
42 | 
43 |     POS tagging is performed with the stanza POS tagger.
44 | 
45 |     Args:
46 |         text: Text input.
47 |         cuda: Whether to use a cuda enabled gpu or not.
48 | 
49 |     Returns:
50 |         Negated sentences.
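A minimal usage sketch follows; the `uw-hai/polyjuice` checkpoint name and the stanza pipeline configuration are assumptions about the environment (stanza's English models must have been downloaded beforehand):

```python
import numpy as np
import stanza
import transformers

from smaug import ops

# Assumed setup: a published Polyjuice checkpoint and an English stanza
# pipeline with POS tagging; run stanza.download("en") once beforehand.
pos = stanza.Pipeline(lang="en", processors="tokenize,pos")
model = transformers.AutoModelForCausalLM.from_pretrained("uw-hai/polyjuice")
tokenizer = transformers.AutoTokenizer.from_pretrained("uw-hai/polyjuice")

negated = ops.polyjuice_negate(
    "The movie was fun to watch.",
    pos_pipeline=pos,
    model=model,
    tokenizer=tokenizer,
    rng=np.random.default_rng(42),
)

# Data with one entry: a negated Sentence, or None if no verb was found.
result = negated.item()
print(result.value if result is not None else None)
```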
51 | """ 52 | text = promote_to_data(text) 53 | sentences = [promote_to_sentence(t) for t in text] 54 | 55 | if cuda: 56 | model.cuda() 57 | 58 | prompts = [_add_negation_prompt(pos_pipeline, rng, s) for s in sentences] 59 | with torch.no_grad(): 60 | polyjuice_func = functools.partial( 61 | _polyjuice_inference, 62 | tokenizer=tokenizer, 63 | model=model, 64 | cuda=cuda, 65 | ) 66 | outputs = [polyjuice_func(p.value) if p is not None else None for p in prompts] 67 | 68 | return Data( 69 | _extract_results(p, o) if p is not None else None 70 | for p, o in zip(prompts, outputs) 71 | ) 72 | 73 | 74 | def _add_negation_prompt( 75 | pos_pipeline: stanza.Pipeline, rng: np.random.Generator, sentence: Sentence 76 | ) -> Optional[Sentence]: 77 | tagged = ops.stanza_pos_predict(sentence, pos_pipeline).item() 78 | possible_mask_intervals = [] 79 | for tagged_sentence in tagged.sentences: 80 | for i, _ in enumerate(tagged_sentence.words): 81 | interval = _get_prev_aux_if_verb(tagged_sentence, i) 82 | if interval: 83 | possible_mask_intervals.append(interval) 84 | interval = _get_verb_if_verb(tagged_sentence, i) 85 | if interval: 86 | possible_mask_intervals.append(interval) 87 | 88 | if not possible_mask_intervals: 89 | return None 90 | 91 | mask_start, mask_end = rng.choice(possible_mask_intervals) 92 | masked = ops.replace(sentence, _BLANK_TOK, (mask_start, mask_end)) 93 | prompt = ops.append( 94 | ops.prepend(masked, f"{sentence} {_PERTURB_TOK} {_NEGATION} "), 95 | f" {_SEP_TOK}", 96 | ) 97 | 98 | return prompt 99 | 100 | 101 | def _polyjuice_inference( 102 | prompt: str, 103 | tokenizer: transformers.PreTrainedTokenizerBase, 104 | model: transformers.AutoModelForCausalLM, 105 | cuda: bool, 106 | ) -> str: 107 | inputs = tokenizer( 108 | prompt, 109 | padding=False, 110 | truncation=True, 111 | max_length=1024 - 1, 112 | return_tensors="pt", 113 | ) 114 | if cuda: 115 | inputs = {k: v.cuda() for k, v in inputs.items()} 116 | 117 | output_ids = model.generate( 118 | **inputs, 119 | num_beams=5, 120 | early_stopping=True, 121 | pad_token_id=tokenizer.eos_token_id, 122 | max_length=1024, 123 | do_sample=False, 124 | no_repeat_ngram_size=2, 125 | ) 126 | 127 | return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 128 | 129 | 130 | def _extract_results( 131 | prompt: Sentence, polyjuice_output: str 132 | ) -> typing.Optional[Sentence]: 133 | prompt_and_answers = polyjuice_output.split(_SEP_TOK) 134 | if len(prompt_and_answers) < 2: 135 | return None 136 | _, answers = prompt_and_answers 137 | 138 | negation_start = ops.find(prompt, _NEGATION) 139 | negation_end = negation_start + len(_NEGATION) 140 | # +1 to account for extra space 141 | prompt_no_prefix = ops.delete(prompt, (0, negation_end + 1)) 142 | sep_start = ops.find(prompt_no_prefix, _SEP_TOK) 143 | # -1 to account for extra space 144 | masked_sentence = ops.delete( 145 | prompt_no_prefix, (sep_start - 1, len(prompt_no_prefix)) 146 | ) 147 | 148 | for answer in answers.split(_ANSWER_TOK)[:-1]: 149 | # Avoid bad escape char by replacing single \ with \\ 150 | answer = answer.strip().replace("\\", "\\\\") 151 | answer = answer if answer != _EMPTY_TOK else "" 152 | blank_start = ops.find(masked_sentence, _BLANK_TOK) 153 | blank_end = blank_start + len(_BLANK_TOK) 154 | masked_sentence = ops.replace(masked_sentence, answer, (blank_start, blank_end)) 155 | 156 | return masked_sentence 157 | 158 | 159 | def _get_prev_aux_if_verb(sentence, i) -> Optional[Tuple]: 160 | if sentence.words[i].upos != "VERB" or i == 0: 161 | return None 
162 | last_aux_idx = i 163 | while last_aux_idx > 0 and sentence.words[last_aux_idx - 1].upos == "AUX": 164 | last_aux_idx -= 1 165 | if last_aux_idx == i: 166 | return None 167 | return sentence.words[last_aux_idx].start_char, sentence.words[i].end_char 168 | 169 | 170 | def _get_verb_if_verb(sentence, i) -> Optional[Tuple]: 171 | word = sentence.words[i] 172 | if word.upos != "VERB": 173 | return None 174 | return word.start_char, word.end_char 175 | -------------------------------------------------------------------------------- /smaug/perturb/__init__.py: -------------------------------------------------------------------------------- 1 | from smaug.perturb.delete_random_words import ( 2 | delete_random_words_transform, 3 | ) 4 | from smaug.perturb.delete_span_between_punctuation import ( 5 | delete_span_between_punctuation_transform, 6 | ) 7 | from smaug.perturb.insert_text_span import ( 8 | insert_text_span, 9 | insert_text_span_transform, 10 | insert_text_span_validation, 11 | ) 12 | from smaug.perturb.negate import ( 13 | negate, 14 | negate_transform, 15 | negate_validation, 16 | ) 17 | from smaug.perturb.swap_named_entity import ( 18 | swap_named_entity, 19 | swap_named_entity_transform, 20 | swap_named_entity_validation, 21 | ) 22 | from smaug.perturb.swap_number import ( 23 | swap_number, 24 | swap_number_transform, 25 | swap_number_validation, 26 | ) 27 | from smaug.perturb.swap_poisson_span import ( 28 | swap_poisson_span, 29 | swap_poisson_span_transform, 30 | swap_poisson_span_validation, 31 | ) 32 | -------------------------------------------------------------------------------- /smaug/perturb/delete_random_words.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from smaug import ops 4 | from smaug.core import Data, DataLike, Sentence, SentenceLike 5 | from smaug.promote import promote_to_data, promote_to_sentence 6 | 7 | 8 | def delete_random_words_transform( 9 | sentences: DataLike[SentenceLike], 10 | rng: np.random.Generator, 11 | p: float = 0.2, 12 | ) -> Data[Sentence]: 13 | """Deletes random words in the sentences. 14 | 15 | Args: 16 | sentences: Sentences to transform. 17 | rng: Numpy generator to use. 18 | p: Probability of deleting a word. 19 | 20 | Returns: 21 | Transformed sentences. 22 | """ 23 | 24 | def next_word_start(s: Sentence, start: int): 25 | # Try to find next space 26 | word_delim_idx = ops.find(s, " ", start=start) 27 | if word_delim_idx == -1: 28 | # If not space, then we are at the last word 29 | # and return the remaining sentence. 
30 | word_delim_idx = len(s) 31 | return word_delim_idx + 1 32 | 33 | def transform(s: SentenceLike) -> Sentence: 34 | s = promote_to_sentence(s) 35 | 36 | curr_idx = 0 37 | while curr_idx < len(s): 38 | word_start_idx = next_word_start(s, curr_idx) 39 | if rng.random() < p: 40 | s = ops.delete(s, (curr_idx, word_start_idx)) 41 | else: 42 | curr_idx = word_start_idx 43 | 44 | return s 45 | 46 | sentences = promote_to_data(sentences) 47 | return Data([transform(s) for s in sentences]) 48 | -------------------------------------------------------------------------------- /smaug/perturb/delete_span_between_punctuation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from smaug import ops 4 | from smaug.core import Data, DataLike, Sentence, SentenceLike 5 | from smaug.promote import promote_to_data, promote_to_sentence 6 | 7 | 8 | from typing import Optional 9 | 10 | 11 | def delete_span_between_punctuation_transform( 12 | sentences: DataLike[SentenceLike], 13 | rng: np.random.Generator, 14 | low: int = 4, 15 | high: int = 10, 16 | ) -> Data[Optional[Sentence]]: 17 | """Deletes a text span between two punctuation symbols. 18 | 19 | Args: 20 | sentences: Sentences to transform. 21 | rng: Numpy random number generator to use. 22 | low: Minimum number of words for considered span. 23 | high: Maximum number of words for considered spans. 24 | """ 25 | 26 | def delete_span(s: Sentence, possible_spans_idxs) -> Optional[Sentence]: 27 | possible_spans_idxs = [ 28 | span_idx 29 | for span_idx in possible_spans_idxs 30 | if span_idx.start > 0 and low < len(s.value[span_idx.start:span_idx.end].split()) < high 31 | ] 32 | if len(possible_spans_idxs) == 0: 33 | return None 34 | 35 | idx_to_drop = rng.choice(possible_spans_idxs) 36 | 37 | return ops.delete(s, idx_to_drop) 38 | 39 | def clean_sentence(s: Sentence) -> Sentence: 40 | s = ops.rstrip(s) 41 | # To increase credibility of generated sentence, 42 | # replace last "," with "." . 43 | if not ops.endswith(s, (".", "?", "!")): 44 | s = ops.replace(s, ".", (len(s) - 1, len(s))) 45 | return s 46 | 47 | sentences = promote_to_data(sentences) 48 | promoted = Data([promote_to_sentence(s) for s in sentences]) 49 | possible_spans_idxs = ops.regex_detect_spans_between_punctuation(promoted) 50 | deleted = [delete_span(s, p) for s, p in zip(promoted, possible_spans_idxs)] 51 | return Data([clean_sentence(s) if s is not None else None for s in deleted]) 52 | -------------------------------------------------------------------------------- /smaug/perturb/insert_text_span.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import transformers 4 | 5 | from smaug import functional 6 | from smaug import ops 7 | from smaug.core import Data, DataLike, Sentence, SentenceLike 8 | 9 | from typing import Optional 10 | 11 | 12 | def insert_text_span( 13 | sentences: DataLike[SentenceLike], 14 | mt5_model: transformers.MT5ForConditionalGeneration, 15 | mt5_tokenizer: transformers.T5Tokenizer, 16 | rng: np.random.Generator, 17 | p: float = 0.1, 18 | max_masks: int = 3, 19 | gpu: bool = False, 20 | ) -> Data[Optional[Sentence]]: 21 | """Inserts spans of text at random places in the sentence. 22 | 23 | This perturbation inserts masks at random places in the 24 | sentence and then uses mT5 to create new content. 25 | 26 | It also runs default validations to ensure a minimum quality 27 | level. 
28 | 29 | Args: 30 | sentences: Sentences to transform. 31 | mt5_model: mT5 model to use for generation. 32 | mt5_tokenizer: mT5 tokenizer for generation. 33 | rng: Numpy random generator to use. 34 | p: Probability of inserting a mask between two words. 35 | max_masks: Maximum number of masks to insert. 36 | gpu: Whether to use gpu. 37 | 38 | Returns: 39 | Perturbed sentences. Returns None for sentences for which 40 | the validations failed. 41 | """ 42 | transformed = insert_text_span_transform( 43 | sentences, 44 | mt5_model, 45 | mt5_tokenizer, 46 | rng, 47 | p, 48 | max_masks, 49 | gpu, 50 | ) 51 | return insert_text_span_validation(sentences, transformed) 52 | 53 | 54 | def insert_text_span_transform( 55 | sentences: DataLike[SentenceLike], 56 | mt5_model: transformers.MT5ForConditionalGeneration, 57 | mt5_tokenizer: transformers.T5Tokenizer, 58 | rng: np.random.Generator, 59 | p: float = 0.1, 60 | max_masks: int = 3, 61 | gpu: bool = False, 62 | ) -> Data[Sentence]: 63 | """Performs the transform phase of the insert_text_span perturbation. 64 | 65 | Args: 66 | sentences: Sentences to transform. 67 | mt5_model: mT5 model to use for generation. 68 | mt5_tokenizer: mT5 tokenizer for generation. 69 | rng: Numpy random generator to use. 70 | p: Probability of inserting a mask between two words. 71 | max_masks: Maximum number of masks to insert. 72 | gpu: Whether to use gpu. 73 | 74 | Returns: 75 | Transformed sentences. 76 | """ 77 | masked = ops.mask_random_insert( 78 | sentences, 79 | func=ops.mT5_masking_function, 80 | rng=rng, 81 | p=p, 82 | max_masks=max_masks, 83 | ) 84 | 85 | return ops.mT5_generate( 86 | masked, 87 | model=mt5_model, 88 | tokenizer=mt5_tokenizer, 89 | cuda=gpu, 90 | ) 91 | 92 | 93 | def insert_text_span_validation( 94 | originals: DataLike[SentenceLike], 95 | transformed: DataLike[Optional[SentenceLike]], 96 | ) -> Data[Optional[Sentence]]: 97 | """Performs basic validation for the insert_text_span perturbation. 98 | 99 | It validates that the generated sentences are different from 100 | the original, and ensures a basic quality level by removing 101 | sentences that match the mT5 masking pattern () 102 | and sentences with character insertions for <>()[]{}_, as they are 103 | likely model hallucinations. 104 | 105 | Args: 106 | originals: Original sentences. 107 | transformed: Transformed sentences. 108 | 109 | Returns: 110 | Validated sentences. Returns None for sentences for which 111 | the validations failed. 
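A minimal end-to-end sketch of this perturbation, assuming the `google/mt5-small` checkpoint (any mT5 variant should behave the same way here):

```python
import numpy as np
import transformers

from smaug.perturb import insert_text_span

# Assumed checkpoint names; substitute whichever mT5 model you use.
mt5_model = transformers.MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
mt5_tokenizer = transformers.T5Tokenizer.from_pretrained("google/mt5-small")

perturbed = insert_text_span(
    "The dog chased the ball across the park.",
    mt5_model=mt5_model,
    mt5_tokenizer=mt5_tokenizer,
    rng=np.random.default_rng(42),
)

# Entries are Sentence objects, or None when the validation rejected the output.
result = perturbed.item()
print(result.value if result is not None else None)
```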
112 |     """
113 | 
114 |     def val_func(o: Sentence, p: Sentence) -> bool:
115 |         return (
116 |             o != p
117 |             and re.search(r"<extra_id_\d+>", p.value) is None
118 |             and ops.character_insertions(o, p, "<>()[]{}_") == 0
119 |         )
120 | 
121 |     return functional.lift_boolean_validation(val_func)(originals, transformed)
122 | 
--------------------------------------------------------------------------------
/smaug/perturb/negate.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 | import stanza
4 | import re
5 | import transformers
6 | 
7 | from smaug import functional
8 | from smaug import ops
9 | from smaug.core import Data, DataLike, Sentence, SentenceLike
10 | 
11 | from typing import Optional
12 | 
13 | 
14 | def negate(
15 |     sentences: DataLike[SentenceLike],
16 |     pos_pipeline: stanza.Pipeline,
17 |     polyjuice_model: transformers.AutoModelForCausalLM,
18 |     polyjuice_tokenizer: transformers.PreTrainedTokenizerBase,
19 |     roberta_model: transformers.RobertaForSequenceClassification,
20 |     roberta_tokenizer: transformers.PreTrainedTokenizerBase,
21 |     rng: np.random.Generator,
22 |     gpu: bool = False,
23 | ) -> Data[Optional[Sentence]]:
24 |     """Negates a given sentence.
25 | 
26 |     This perturbation uses a POS tagger to identify verbs and their
27 |     preceding auxiliary verbs and then applies Polyjuice to negate
28 |     one of the detected spans.
29 | 
30 |     It also runs default validations to ensure both a minimum quality level,
31 |     and that the generated text contradicts the original sentences.
32 | 
33 |     Args:
34 |         sentences: Sentences to transform.
35 |         pos_pipeline: POS pipeline to detect verbs and auxiliary verbs.
36 |         polyjuice_model: Polyjuice model to use for negation.
37 |         polyjuice_tokenizer: Polyjuice tokenizer to use for negation.
38 |         roberta_model: RoBERTa model to use for NLI.
39 |         roberta_tokenizer: RoBERTa tokenizer to use for NLI.
40 |         rng: Numpy random generator to use.
41 |         gpu: Whether to use gpu.
42 | 
43 |     Returns:
44 |         Perturbed sentences. Returns None for sentences for which
45 |         the transform or the validations failed.
46 |     """
47 |     transformed = negate_transform(
48 |         sentences,
49 |         pos_pipeline,
50 |         polyjuice_model,
51 |         polyjuice_tokenizer,
52 |         rng,
53 |         gpu,
54 |     )
55 |     return negate_validation(
56 |         sentences, transformed, roberta_model, roberta_tokenizer, gpu
57 |     )
58 | 
59 | 
60 | def negate_transform(
61 |     sentences: DataLike[SentenceLike],
62 |     pos_pipeline: stanza.Pipeline,
63 |     polyjuice_model: transformers.AutoModelForCausalLM,
64 |     polyjuice_tokenizer: transformers.PreTrainedTokenizerBase,
65 |     rng: np.random.Generator,
66 |     gpu: bool = False,
67 | ) -> Data[Optional[Sentence]]:
68 |     """Performs the transform phase for the negate perturbation.
69 | 
70 |     Args:
71 |         sentences: Sentences to transform.
72 |         pos_pipeline: POS pipeline to detect verbs and auxiliary verbs.
73 |         polyjuice_model: Polyjuice model to use for negation.
74 |         polyjuice_tokenizer: Polyjuice tokenizer to use for negation.
75 |         rng: Numpy random generator to use.
76 |         gpu: Whether to use gpu.
77 | 
78 |     Returns:
79 |         Transformed sentences. Returns None for sentences for which
80 |         the transform failed.
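For reference, a sketch of the one-call `negate` entry point that chains this transform with the NLI-based validation below; every checkpoint name is an assumption about the setup, not something the module fixes:

```python
import numpy as np
import stanza
import transformers

from smaug.perturb import negate

# Assumed checkpoints: a Polyjuice causal LM and a RoBERTa MNLI classifier.
pos = stanza.Pipeline(lang="en", processors="tokenize,pos")
pj_model = transformers.AutoModelForCausalLM.from_pretrained("uw-hai/polyjuice")
pj_tokenizer = transformers.AutoTokenizer.from_pretrained("uw-hai/polyjuice")
nli_model = transformers.RobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
nli_tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-large-mnli")

negated = negate(
    "The committee approved the proposal.",
    pos_pipeline=pos,
    polyjuice_model=pj_model,
    polyjuice_tokenizer=pj_tokenizer,
    roberta_model=nli_model,
    roberta_tokenizer=nli_tokenizer,
    rng=np.random.default_rng(7),
)

# None if negation failed or the result does not contradict the original.
result = negated.item()
print(result.value if result is not None else None)
```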
81 | """ 82 | return ops.polyjuice_negate( 83 | sentences, 84 | pos_pipeline=pos_pipeline, 85 | model=polyjuice_model, 86 | tokenizer=polyjuice_tokenizer, 87 | rng=rng, 88 | cuda=gpu, 89 | ) 90 | 91 | 92 | def negate_validation( 93 | originals: DataLike[SentenceLike], 94 | transformed: DataLike[SentenceLike], 95 | roberta_model: transformers.RobertaForSequenceClassification, 96 | roberta_tokenizer: transformers.PreTrainedTokenizerBase, 97 | gpu: bool = False, 98 | ) -> Data[Optional[Sentence]]: 99 | """Performs the validation phase for the negate transform. 100 | 101 | Args: 102 | originals: Original sentences. 103 | transformed: Transformed sentences. 104 | roberta_model: RoBERTa model to use for NLI. 105 | roberta_tokenizer: RoBERTa tokenizer to use for NLI. 106 | gpu: Whether to use gpu. 107 | 108 | Returns: 109 | Validated sentences. Returns None for sentences for which 110 | the validations failed. 111 | """ 112 | 113 | def val_func(o: Sentence, p: Sentence) -> bool: 114 | return ( 115 | o != p 116 | and re.search("EMPTY", p.value) is None 117 | and roberta_predict_func(o, p).argmax().item() == roberta_contradiction_id 118 | ) 119 | 120 | roberta_predict_func = functools.partial( 121 | ops.roberta_mnli_predict, 122 | model=roberta_model, 123 | tokenizer=roberta_tokenizer, 124 | cuda=gpu, 125 | ) 126 | roberta_contradiction_id = ops.roberta_mnli_contradiction_id(roberta_model) 127 | 128 | return functional.lift_boolean_validation(val_func)(originals, transformed) 129 | -------------------------------------------------------------------------------- /smaug/perturb/swap_named_entity.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | import re 4 | import stanza 5 | import transformers 6 | 7 | from smaug import ops 8 | from smaug import functional 9 | from smaug.core import Data, DataLike, Sentence, SentenceLike 10 | 11 | from typing import Optional 12 | 13 | 14 | def swap_named_entity( 15 | sentences: DataLike[SentenceLike], 16 | ner_pipeline: stanza.Pipeline, 17 | mt5_model: transformers.MT5ForConditionalGeneration, 18 | mt5_tokenizer: transformers.T5Tokenizer, 19 | rng: np.random.Generator, 20 | gpu: bool = False, 21 | ) -> Data[Optional[Sentence]]: 22 | """Swaps a named entity in the received sentences. 23 | 24 | It searches for named entities in the original records using a 25 | ner model and then uses Google's mt5 to replace the one of the 26 | found expressions with text. 27 | 28 | It also ensures that the generated sentences are not equal to 29 | the original sentences and that they have the same number of 30 | named entities (by using the same NER model). 31 | 32 | Args: 33 | sentences: Records to perturb. 34 | ner_pipeline: stanza NER pipeline to use. 35 | mt5_model: mt5 model to use. 36 | mt5_tokenizer: mt5 tokenizer to use. 37 | rng: Numpy random generator to use. 38 | gpu: Whether to use gpu. 39 | 40 | Returns: 41 | The perturbed records. Returns None for sentences for which 42 | the validations failed. 
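A minimal usage sketch, assuming an English stanza NER pipeline and the `google/mt5-small` checkpoint:

```python
import numpy as np
import stanza
import transformers

from smaug.perturb import swap_named_entity

# Assumed models; run stanza.download("en") once before building the pipeline.
ner = stanza.Pipeline(lang="en", processors="tokenize,ner")
mt5_model = transformers.MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
mt5_tokenizer = transformers.T5Tokenizer.from_pretrained("google/mt5-small")

perturbed = swap_named_entity(
    "Maria moved from Lisbon to Berlin in 2019.",
    ner_pipeline=ner,
    mt5_model=mt5_model,
    mt5_tokenizer=mt5_tokenizer,
    rng=np.random.default_rng(0),
)

# A Sentence with one named entity swapped, or None if validation failed.
result = perturbed.item()
print(result.value if result is not None else None)
```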
43 |     """
44 | 
45 |     transformed = swap_named_entity_transform(
46 |         sentences, ner_pipeline, mt5_model, mt5_tokenizer, rng, gpu
47 |     )
48 |     return swap_named_entity_validation(sentences, transformed, ner_pipeline)
49 | 
50 | 
51 | def swap_named_entity_transform(
52 |     sentences: DataLike[SentenceLike],
53 |     ner_pipeline: stanza.Pipeline,
54 |     mt5_model: transformers.MT5ForConditionalGeneration,
55 |     mt5_tokenizer: transformers.T5Tokenizer,
56 |     rng: np.random.Generator,
57 |     gpu: bool = False,
58 | ) -> Data[Sentence]:
59 |     """Performs the transform phase for the swap_named_entity perturbation.
60 | 
61 |     Args:
62 |         sentences: Records to perturb.
63 |         ner_pipeline: stanza NER pipeline to use.
64 |         mt5_model: mt5 model to use.
65 |         mt5_tokenizer: mt5 tokenizer to use.
66 |         rng: Numpy random generator to use.
67 |         gpu: Whether to use gpu.
68 | 
69 |     Returns:
70 |         Transformed sentences.
71 |     """
72 |     ner_func = functools.partial(
73 |         ops.stanza_detect_named_entities,
74 |         ner_pipeline=ner_pipeline,
75 |     )
76 | 
77 |     masked = ops.mask_detections(
78 |         sentences,
79 |         detect_func=ner_func,
80 |         mask_func=ops.mT5_masking_function,
81 |         rng=rng,
82 |         max_masks=1,
83 |     )
84 | 
85 |     return ops.mT5_generate(
86 |         masked,
87 |         model=mt5_model,
88 |         tokenizer=mt5_tokenizer,
89 |         cuda=gpu,
90 |     )
91 | 
92 | 
93 | def swap_named_entity_validation(
94 |     originals: DataLike[SentenceLike],
95 |     transformed: DataLike[Optional[SentenceLike]],
96 |     ner_pipeline: stanza.Pipeline,
97 | ) -> Data[Optional[Sentence]]:
98 |     """Performs the validation phase for the swap_named_entity perturbation.
99 | 
100 |     It validates that the generated sentences are different from
101 |     the original, and ensures a basic quality level by removing
102 |     sentences that match the mT5 masking pattern (<extra_id_N>)
103 |     and sentences with character insertions for <>()[]{}_, as they are
104 |     likely model hallucinations.
105 | 
106 |     It also validates that the original and transformed sentences have
107 |     the same count of named entities to ensure the mT5 model generated
108 |     a named entity.
109 | 
110 |     Args:
111 |         originals: Original sentences.
112 |         transformed: Transformed sentences.
113 |         ner_pipeline: stanza NER pipeline to use.
114 | 
115 |     Returns:
116 |         Validated sentences. Returns None for sentences for which
117 |         the validations failed.
118 |     """
119 | 
120 |     def val_func(o: Sentence, p: Sentence) -> bool:
121 |         return (
122 |             o != p
123 |             and re.search(r"<extra_id_\d+>", p.value) is None
124 |             and ops.character_insertions(o, p, "<>()[]{}_") == 0
125 |             and ops.equal_named_entities_count(o, p, ner_pipeline)
126 |         )
127 | 
128 |     return functional.lift_boolean_validation(val_func)(originals, transformed)
129 | 
--------------------------------------------------------------------------------
/smaug/perturb/swap_number.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import transformers
4 | 
5 | from smaug import functional
6 | from smaug import ops
7 | from smaug.core import Data, DataLike, Sentence, SentenceLike
8 | 
9 | from typing import Optional
10 | 
11 | 
12 | def swap_number(
13 |     sentences: DataLike[SentenceLike],
14 |     mt5_model: transformers.MT5ForConditionalGeneration,
15 |     mt5_tokenizer: transformers.T5Tokenizer,
16 |     rng: np.random.Generator,
17 |     gpu: bool = False,
18 | ) -> Data[Optional[Sentence]]:
19 |     """Swaps a number in the received sentences.
20 | 21 | It searches for numbers in the original records using a regular expression and 22 | then uses Google's mt5 to replace the one of the found expressions with text. 23 | 24 | It also ensures that the generated sentences are not equal to the original 25 | sentences and that they have the same number of numbers (by using the same 26 | regular expression). 27 | 28 | Args: 29 | sentences: Records to perturb. 30 | mt5_model: mt5 model to use. 31 | mt5_tokenizer: mt5 tokenizer to use. 32 | rng: Numpy random generator to use. 33 | gpu: Whether to use gpu. 34 | 35 | Returns: 36 | The perturbed records. Returns None for sentences for which 37 | the validations failed. 38 | """ 39 | 40 | transformed = swap_number_transform(sentences, mt5_model, mt5_tokenizer, rng, gpu) 41 | return swap_number_validation(sentences, transformed) 42 | 43 | 44 | def swap_number_transform( 45 | sentences: DataLike[SentenceLike], 46 | mt5_model: transformers.MT5ForConditionalGeneration, 47 | mt5_tokenizer: transformers.T5Tokenizer, 48 | rng: np.random.Generator, 49 | gpu: bool = False, 50 | ) -> Data[Sentence]: 51 | """Performs the transform phase for the swap_number perturbation. 52 | 53 | Args: 54 | sentences: Records to perturb. 55 | mt5_model: mt5 model to use. 56 | mt5_tokenizer: mt5 tokenizer to use. 57 | rng: Numpy random generator to use. 58 | gpu: Whether to use gpu. 59 | 60 | Returns: 61 | Transformed sentences. 62 | """ 63 | masked = ops.mask_detections( 64 | sentences, 65 | detect_func=ops.regex_detect_numbers, 66 | mask_func=ops.mT5_masking_function, 67 | rng=rng, 68 | max_masks=1, 69 | ) 70 | 71 | return ops.mT5_generate( 72 | masked, 73 | model=mt5_model, 74 | tokenizer=mt5_tokenizer, 75 | cuda=gpu, 76 | ) 77 | 78 | 79 | def swap_number_validation( 80 | originals: DataLike[SentenceLike], 81 | transformed: DataLike[Optional[SentenceLike]], 82 | ) -> Data[Optional[Sentence]]: 83 | """Performs the validation phase for the swap_number perturbation. 84 | 85 | It validates that the generated sentences are different from 86 | the original, and ensures a basic quality level by removing 87 | sentences that match the mT5 masking pattern () 88 | and sentences with character insertions for <>()[]{}_, as they are 89 | likely model hallucinations. 90 | 91 | It also validates that the original and transformed sentences have 92 | the same count of numbers to ensure the mT5 model generated a number. 93 | 94 | Args: 95 | originals: Original sentences. 96 | transformed: Transformed sentences. 97 | 98 | Returns: 99 | Validated sentences. Returns None for sentences for which 100 | the validations failed. 
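To make the validation criteria concrete, here is a small sketch of the two helpers this phase relies on; the sentences are invented purely for illustration:

```python
from smaug import ops
from smaug.core import Sentence

original = Sentence("The invoice totals 1200 euros for 3 items.")
good = Sentence("The invoice totals 950 euros for 3 items.")
bad = Sentence("The invoice totals <extra_id_0> euros for 3 items.")

# Same count of numeric spans and no suspicious characters: accepted.
assert ops.equal_numbers_count(original, good)
assert ops.character_insertions(original, good, "<>()[]{}_") == 0

# Leftover sentinel text introduces "<", ">" and "_" characters: rejected.
assert ops.character_insertions(original, bad, "<>()[]{}_") > 0
```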
101 |     """
102 | 
103 |     def val_func(o: Sentence, p: Sentence) -> bool:
104 |         return (
105 |             o != p
106 |             and re.search(r"<extra_id_\d+>", p.value) is None
107 |             and ops.character_insertions(o, p, "<>()[]{}_") == 0
108 |             and ops.equal_numbers_count(o, p)
109 |         )
110 | 
111 |     return functional.lift_boolean_validation(val_func)(originals, transformed)
112 | 
--------------------------------------------------------------------------------
/smaug/perturb/swap_poisson_span.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import transformers
4 | 
5 | from smaug import functional
6 | from smaug import ops
7 | from smaug.core import Data, DataLike, Sentence, SentenceLike
8 | 
9 | from typing import Optional
10 | 
11 | 
12 | def swap_poisson_span(
13 |     sentences: DataLike[SentenceLike],
14 |     mt5_model: transformers.MT5ForConditionalGeneration,
15 |     mt5_tokenizer: transformers.T5Tokenizer,
16 |     rng: np.random.Generator,
17 |     gpu: bool = False,
18 | ) -> Data[Optional[Sentence]]:
19 |     """Replaces a text span with size determined by the Poisson distribution.
20 | 
21 |     This perturbation masks a text span with size determined by the Poisson
22 |     distribution and then uses Google's mT5 to fill the mask.
23 | 
24 |     It also runs default validations to ensure a minimum quality
25 |     level.
26 | 
27 |     Args:
28 |         sentences: Sentences to transform.
29 |         mt5_model: mT5 model to use for generation.
30 |         mt5_tokenizer: mT5 tokenizer for generation.
31 |         rng: Numpy random generator to use.
32 |         gpu: Whether to use gpu.
33 | 
34 |     Returns:
35 |         Perturbed sentences. Returns None for sentences for which
36 |         the validations failed.
37 |     """
38 |     transformed = swap_poisson_span_transform(
39 |         sentences,
40 |         mt5_model,
41 |         mt5_tokenizer,
42 |         rng,
43 |         gpu,
44 |     )
45 |     return swap_poisson_span_validation(sentences, transformed)
46 | 
47 | 
48 | def swap_poisson_span_transform(
49 |     sentences: DataLike[SentenceLike],
50 |     mt5_model: transformers.MT5ForConditionalGeneration,
51 |     mt5_tokenizer: transformers.T5Tokenizer,
52 |     rng: np.random.Generator,
53 |     gpu: bool = False,
54 | ) -> Data[Sentence]:
55 |     """Performs the transform phase for the swap_poisson_span perturbation.
56 | 
57 |     Args:
58 |         sentences: Sentences to transform.
59 |         mt5_model: mT5 model to use for generation.
60 |         mt5_tokenizer: mT5 tokenizer for generation.
61 |         rng: Numpy random generator to use.
62 |         gpu: Whether to use gpu.
63 | 
64 |     Returns:
65 |         Transformed sentences.
66 |     """
67 |     masked = ops.mask_poisson_spans(
68 |         sentences,
69 |         func=ops.mT5_masking_function,
70 |         rng=rng,
71 |     )
72 | 
73 |     return ops.mT5_generate(
74 |         masked,
75 |         model=mt5_model,
76 |         tokenizer=mt5_tokenizer,
77 |         cuda=gpu,
78 |     )
79 | 
80 | 
81 | def swap_poisson_span_validation(
82 |     originals: DataLike[SentenceLike],
83 |     transformed: DataLike[Optional[SentenceLike]],
84 | ) -> Data[Optional[Sentence]]:
85 |     """Performs the validation phase for the swap_poisson_span perturbation.
86 | 
87 |     It validates that the generated sentences are different from
88 |     the original, and ensures a basic quality level by removing
89 |     sentences that match the mT5 masking pattern (<extra_id_N>)
90 |     and sentences with character insertions for <>()[]{}_, as they are
91 |     likely model hallucinations.
92 | 
93 |     Args:
94 |         originals: Original sentences.
95 |         transformed: Transformed sentences.
96 | 
97 |     Returns:
98 |         Validated sentences. Returns None for sentences for which
99 |         the validations failed.
100 |     """
101 | 
102 |     def val_func(o: Sentence, p: Sentence) -> bool:
103 |         return (
104 |             o != p
105 |             and re.search(r"<extra_id_\d+>", p.value) is None
106 |             and ops.character_insertions(o, p, "<>()[]{}_") == 0
107 |         )
108 | 
109 |     return functional.lift_boolean_validation(val_func)(originals, transformed)
110 | 
--------------------------------------------------------------------------------
/smaug/promote.py:
--------------------------------------------------------------------------------
1 | from smaug.core import DataLike, Data, Sentence, SentenceLike, SpanIndex, SpanIndexLike
2 | 
3 | from typing import TypeVar
4 | 
5 | T = TypeVar("T")
6 | 
7 | 
8 | def promote_to_data(value: DataLike[T]) -> Data[T]:
9 |     """Promotes a value to data.
10 | 
11 |     The following promotion rules are applied:
12 |     * Data objects are returned as is.
13 |     * All other objects are wrapped in a Data object of length 1.
14 | 
15 |     Args:
16 |         value (DataLike[T]): Value to promote.
17 | 
18 |     Returns:
19 |         Data[T]: The Data object corresponding to the promoted value.
20 |     """
21 |     if isinstance(value, Data):
22 |         return value
23 |     return Data([value])
24 | 
25 | 
26 | def promote_to_span_index(s: SpanIndexLike) -> SpanIndex:
27 |     """Promotes a SpanIndexLike object to SpanIndex.
28 | 
29 |     Args:
30 |         s: Object to promote.
31 | 
32 |     Returns:
33 |         Promoted object.
34 |     """
35 |     if isinstance(s, SpanIndex):
36 |         return s
37 |     return SpanIndex(s[0], s[1])
38 | 
39 | 
40 | def promote_to_sentence(s: SentenceLike) -> Sentence:
41 |     """Promotes a sentence like object to sentence.
42 | 
43 |     Args:
44 |         s: Object to promote.
45 | 
46 |     Returns:
47 |         Promoted object.
48 |     """
49 |     if isinstance(s, Sentence):
50 |         return s
51 |     return Sentence(s)
52 | 
--------------------------------------------------------------------------------
/smaug/random.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import transformers
3 | import typing
4 | 
5 | _SEED: typing.Optional[int] = None
6 | 
7 | 
8 | def seed_everything(seed: int):
9 |     """Seeds every random based framework to allow for reproducibility.
10 | 
11 |     Args:
12 |         seed: Seed value to use.
13 |     """
14 |     global _SEED
15 |     _SEED = seed
16 |     transformers.set_seed(seed)
17 | 
18 | 
19 | def numpy_seeded_rng() -> np.random.Generator:
20 |     """Creates a numpy.random.Generator with the specified seed.
21 | 
22 |     If the seed is not defined, a generator without seed is created.
23 | 
24 |     Returns:
25 |         The created generator.
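A short sketch of how seeding and generator creation fit together; the module is aliased here only to avoid shadowing the standard library `random`:

```python
from smaug import random as smaug_random

# Seed transformers and record the seed used for numpy generators.
smaug_random.seed_everything(42)

# Generators created afterwards share the seed, so runs are reproducible.
rng_a = smaug_random.numpy_seeded_rng()
rng_b = smaug_random.numpy_seeded_rng()
assert rng_a.random() == rng_b.random()
```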
26 | """ 27 | if _SEED is None: 28 | return np.random.default_rng() 29 | return np.random.default_rng(_SEED) 30 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from smaug import _itertools 3 | 4 | 5 | @pytest.mark.parametrize( 6 | "iterable,n,expected", 7 | [ 8 | pytest.param( 9 | [], 10 | 10, 11 | [], 12 | id="empty iterable", 13 | ), 14 | pytest.param( 15 | ["A", "B", "C"], 16 | 1, 17 | ["A", "B", "C"], 18 | id="repeat once", 19 | ), 20 | pytest.param( 21 | ["A", "B", "C"], 22 | 3, 23 | ["A", "A", "A", "B", "B", "B", "C", "C", "C"], 24 | id="repeat three times", 25 | ), 26 | ], 27 | ) 28 | def test_repeat_items(iterable, n, expected): 29 | output = _itertools.repeat_items(iterable, n) 30 | assert expected == list(output) 31 | -------------------------------------------------------------------------------- /tests/test_broadcast.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug.broadcast import broadcast_data 4 | from smaug.core import Data 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "values,expected", 9 | [ 10 | pytest.param( 11 | (Data([1]), Data([2]), Data([3])), 12 | (Data([1]), Data([2]), Data([3])), 13 | id="All with length 1", 14 | ), 15 | pytest.param( 16 | (Data([1, 2, 3]), Data([2, 3, 4]), Data([3, 4, 5])), 17 | (Data([1, 2, 3]), Data([2, 3, 4]), Data([3, 4, 5])), 18 | id="All with length 3", 19 | ), 20 | pytest.param( 21 | (Data([1, 2, 3]), Data([2]), Data([3])), 22 | (Data([1, 2, 3]), Data([2, 2, 2]), Data([3, 3, 3])), 23 | id="One with length 3, two with length 1", 24 | ), 25 | pytest.param( 26 | (Data([1]), Data([2, 3, 4]), Data([3, 4, 5])), 27 | (Data([1, 1, 1]), Data([2, 3, 4]), Data([3, 4, 5])), 28 | id="Twi with length 3, one with length 1", 29 | ), 30 | ], 31 | ) 32 | def test_broadcast_data(values, expected): 33 | broadcasted = broadcast_data(*values) 34 | for expected_value, promoted_value in zip(expected, broadcasted): 35 | assert isinstance(promoted_value, Data) 36 | assert len(expected_value) == len(promoted_value) 37 | for e, p in zip(expected_value, promoted_value): 38 | assert e == p 39 | -------------------------------------------------------------------------------- /tests/test_detection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug import ops 4 | from smaug.core import Data 5 | from smaug.frozen import frozenlist 6 | 7 | from typing import Tuple 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "text,expected", 12 | [ 13 | ( 14 | "Test with 1.23,22 number and another 1.220.", 15 | Data([frozenlist([(10, 17), (37, 42)])]), 16 | ), 17 | ( 18 | "Test with .21234 number and another 2312234.", 19 | Data([frozenlist([(10, 16), (36, 43)])]), 20 | ), 21 | ], 22 | ) 23 | def test_detect_numbers(text: str, expected: Data[frozenlist[Tuple[int, int]]]): 24 | output = ops.regex_detect_numbers(text) 25 | assert isinstance(output, Data) 26 | assert len(expected) == len(output) 27 | for e, o in zip(expected, output): 28 | assert e == o 29 | -------------------------------------------------------------------------------- /tests/test_frozen.py: -------------------------------------------------------------------------------- 1 | from smaug import frozen 2 | 3 | 4 | def test_frozen_list(): 5 | a = frozen.frozenlist() 6 | assert len(a) == 0 7 | assert 1 not in a 8 | 9 | a = a.append(1) 10 
| assert len(a) == 1 11 | assert 1 in a 12 | 13 | b = frozen.frozenlist((2, 3)) 14 | assert len(b) == 2 15 | assert 2 in b 16 | assert 3 in b 17 | 18 | c = a + b 19 | assert len(c) == 3 20 | assert c[0] == 1 21 | assert c[1] == 2 22 | assert c[2] == 3 23 | 24 | d, val = c.pop() 25 | assert val == 3 26 | assert len(d) == 2 27 | assert d[0] == 1 28 | assert d[1] == 2 29 | 30 | e, val = c.pop(1) 31 | assert val == 2 32 | assert len(e) == 2 33 | assert e[0] == 1 34 | assert e[1] == 3 35 | -------------------------------------------------------------------------------- /tests/test_itertools.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from smaug import _itertools 3 | 4 | 5 | @pytest.mark.parametrize( 6 | "iterable,n,expected", 7 | [ 8 | pytest.param( 9 | [], 10 | 10, 11 | [], 12 | id="empty iterable", 13 | ), 14 | pytest.param( 15 | ["A", "B", "C"], 16 | 1, 17 | ["A", "B", "C"], 18 | id="repeat once", 19 | ), 20 | pytest.param( 21 | ["A", "B", "C"], 22 | 3, 23 | ["A", "A", "A", "B", "B", "B", "C", "C", "C"], 24 | id="repeat three times", 25 | ), 26 | ], 27 | ) 28 | def test_repeat_items(iterable, n, expected): 29 | output = _itertools.repeat_items(iterable, n) 30 | assert expected == list(output) 31 | -------------------------------------------------------------------------------- /tests/test_masking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from smaug.core import Data 5 | from smaug.frozen import frozenlist 6 | from smaug import ops 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "docs,intervals,func,expected", 11 | [ 12 | pytest.param( 13 | "Test string with some words", 14 | frozenlist([(0, 4), (10, 15)]), 15 | lambda _: "", 16 | Data([" strinh some words"]), 17 | id="single sentence with string mask", 18 | ), 19 | pytest.param( 20 | Data(["Test string with some words", "2nd string to mask."]), 21 | frozenlist([(0, 4), (10, 15)]), 22 | lambda _: "", 23 | Data([" strinh some words", "stringask."]), 24 | id="multiple sentences with string mask", 25 | ), 26 | pytest.param( 27 | "Test string with some words", 28 | frozenlist([(0, 4), (10, 15)]), 29 | lambda idx: f"", 30 | Data([" strinh some words"]), 31 | id="single sentence with masking function", 32 | ), 33 | pytest.param( 34 | Data(["Test string with some words", "2nd string to mask."]), 35 | frozenlist([(0, 4), (10, 15)]), 36 | lambda idx: f"", 37 | Data([" strinh some words", "stringask."]), 38 | id="multiple sentences with masking function", 39 | ), 40 | ], 41 | ) 42 | def test_mask_intervals(docs, intervals, func, expected): 43 | output = ops.mask_intervals(docs, intervals, func) 44 | assert isinstance(output, Data) 45 | assert len(expected) == len(output) 46 | for e, p in zip(expected, output): 47 | assert e == p.value 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "text,detect_func,mask_func,expected", 52 | [ 53 | pytest.param( 54 | "Test with 1.23,22 number and another 1.220.", 55 | lambda _: Data([frozenlist([(10, 17), (37, 42)])]), 56 | lambda _: "", 57 | Data(["Test with number and another ."]), 58 | id="single sentence with string mask", 59 | ), 60 | pytest.param( 61 | "Test with .21234 number and another 2312234.", 62 | lambda _: Data([frozenlist([(10, 16), (36, 43)])]), 63 | lambda idx: f"", 64 | Data(["Test with number and another ."]), 65 | id="single sentence with masking function", 66 | ), 67 | ], 68 | ) 69 | def test_mask_detections(text, detect_func, mask_func, expected): 70 | 
output = ops.mask_detections(text, detect_func, mask_func, np.random.default_rng()) 71 | assert isinstance(output, Data) 72 | assert len(expected) == len(output) 73 | for e, p in zip(expected, output): 74 | assert e == p.value 75 | 76 | 77 | @pytest.mark.parametrize( 78 | "text,detect_func,mask_func,expected_opts", 79 | [ 80 | pytest.param( 81 | "Test with 1.23,22 number and another 1.220.", 82 | lambda _: Data([frozenlist([(10, 17), (37, 42)])]), 83 | lambda _: "", 84 | [ 85 | Data(["Test with number and another 1.220."]), 86 | Data(["Test with 1.23,22 number and another ."]), 87 | ], 88 | id="single sentence with string mask", 89 | ), 90 | pytest.param( 91 | "Test with .21234 number and another 2312234.", 92 | lambda _: Data([frozenlist([(10, 16), (36, 43)])]), 93 | lambda idx: f"", 94 | [ 95 | Data(["Test with number and another 2312234."]), 96 | Data(["Test with .21234 number and another ."]), 97 | ], 98 | id="single sentence with masking function", 99 | ), 100 | ], 101 | ) 102 | def test_mask_detections_max_masks(text, detect_func, mask_func, expected_opts): 103 | def matches_func(expected): 104 | return ( 105 | isinstance(output, Data) 106 | and len(expected) == len(output) 107 | and all(e == p.value for e, p in zip(expected, output)) 108 | ) 109 | 110 | output = ops.mask_detections( 111 | text, detect_func, mask_func, np.random.default_rng(), max_masks=1 112 | ) 113 | assert any(matches_func(e) for e in expected_opts) 114 | 115 | 116 | @pytest.mark.parametrize( 117 | "text,func", 118 | [ 119 | pytest.param( 120 | "Test with 1 number and another 1.220.", 121 | lambda _: "", 122 | id="single sentence with string mask", 123 | ), 124 | pytest.param( 125 | "Test with 1 number and another 1.220.", 126 | lambda idx: f"", 127 | id="single sentence with masking function", 128 | ), 129 | ], 130 | ) 131 | def test_random_replace_mask(text: str, func): 132 | output = ops.mask_random_replace(text, func, np.random.default_rng(), p=0.5) 133 | t_splits = text.split() 134 | assert isinstance(output, Data) 135 | assert len(output) == 1 136 | o_splits = output.item().value.split() 137 | assert len(t_splits) == len(o_splits) 138 | 139 | mask_idx = 0 140 | for t_word, o_word in zip(t_splits, o_splits): 141 | # If words differ then a mask should have bee inserted. 
142 | if t_word != o_word: 143 | assert func(mask_idx) == o_word 144 | mask_idx += 1 145 | 146 | 147 | @pytest.mark.parametrize( 148 | "text,func", 149 | [ 150 | pytest.param( 151 | "Test with 1 number and another 1.220.", 152 | lambda _: "", 153 | id="single sentence with string mask", 154 | ), 155 | pytest.param( 156 | "Test with 1 number and another 1.220.", 157 | lambda idx: f"", 158 | id="single sentence with masking function", 159 | ), 160 | ], 161 | ) 162 | def test_mask_poisson_spans(text: str, func): 163 | def first_mismatch(list1, list2): 164 | for i in range(min(len(list1), len(list2))): 165 | if list1[i] != list2[i]: 166 | return i 167 | return -1 168 | 169 | output = ops.mask_poisson_spans(text, func, np.random.default_rng()) 170 | assert isinstance(output, Data) 171 | assert len(output) == 1 172 | 173 | o_splits = output.item().value.split() 174 | t_splits = text.split() 175 | 176 | num_splits_diff = len(t_splits) - len(o_splits) 177 | # Maximum one extra word is inserted 178 | assert num_splits_diff >= -1 179 | 180 | # Index of first mismatch going forward 181 | fwd_idx = first_mismatch(t_splits, o_splits) 182 | assert fwd_idx != -1 183 | # Mismatch must happen on mask 184 | assert o_splits[fwd_idx] == func(0) 185 | 186 | # Index of first mismatch going backwards. 187 | # This index only works for reversed splits. 188 | rev_idx = first_mismatch(t_splits[::-1], o_splits[::-1]) 189 | # Can happen if mask was inserted in the beginning while 190 | # masking 0 words 191 | if rev_idx == -1: 192 | rev_idx = len(o_splits) - 1 193 | 194 | # Rev index considering forward o_splits 195 | o_rev_idx = len(o_splits) - rev_idx - 1 196 | # Rev index considering forward t_splits 197 | t_rev_idx = len(t_splits) - rev_idx - 1 198 | 199 | # Mismatch must happen on mask 200 | assert o_splits[o_rev_idx] == func(0) 201 | 202 | # Difference in words must be the same as the difference 203 | # between the indexes. 204 | assert num_splits_diff == t_rev_idx - fwd_idx 205 | 206 | 207 | @pytest.mark.parametrize( 208 | "text,func", 209 | [ 210 | pytest.param( 211 | "Test with 1 number and another 1.220.", 212 | lambda _: "", 213 | id="single sentence with string mask", 214 | ), 215 | pytest.param( 216 | "Test with 1 number and another 1.220.", 217 | lambda idx: f"", 218 | id="single sentence with masking function", 219 | ), 220 | ], 221 | ) 222 | def test_random_insert_mask(text, func): 223 | output = ops.mask_random_insert(text, func, np.random.default_rng(), p=0.5) 224 | 225 | assert isinstance(output, Data) 226 | assert len(output) == 1 227 | t_splits = text.split() 228 | o_splits = output.item().value.split() 229 | assert len(t_splits) <= len(o_splits) 230 | 231 | mask_idx = 0 232 | t_idx = 0 233 | for o_word in o_splits: 234 | # No mask was inserted. Move d_idx forward. 235 | if t_idx < len(t_splits) and t_splits[t_idx] == o_word: 236 | t_idx += 1 237 | # Mask inserted. Verify it is correct. 238 | else: 239 | assert func(mask_idx) == o_word 240 | mask_idx += 1 241 | # All original words were matched. 
242 | assert t_idx == len(t_splits) 243 | -------------------------------------------------------------------------------- /tests/test_modification.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug import frozen 4 | from smaug import ops 5 | from smaug.core import Modification, ModificationTrace, SpanIndex 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "old,new,modification", 10 | [ 11 | pytest.param( 12 | 'Sentence with "old text to be replaced" in the middle.', 13 | 'Sentence with "replaced new text" in the middle.', 14 | Modification( 15 | '"old text to be replaced"', 16 | '"replaced new text"', 17 | 14, 18 | ), 19 | id="Modify the middle of the sentence.", 20 | ), 21 | pytest.param( 22 | '"Sentence with old text to be replaced" in the beginning.', 23 | '"Sentence with replaced new text" in the beginning.', 24 | Modification( 25 | '"Sentence with old text to be replaced"', 26 | '"Sentence with replaced new text"', 27 | 0, 28 | ), 29 | id="Modify the beginning of the sentence.", 30 | ), 31 | pytest.param( 32 | 'Sentence with "old text to be replaced in the end".', 33 | 'Sentence with "replaced new text in the end".', 34 | Modification( 35 | '"old text to be replaced in the end".', 36 | '"replaced new text in the end".', 37 | 14, 38 | ), 39 | id="Modify the end of the sentence.", 40 | ), 41 | ], 42 | ) 43 | def test_modification(old: str, new: str, modification: Modification): 44 | new_output = ops.apply_modification(modification, old) 45 | assert new == new_output 46 | old_output = ops.reverse_modification(modification, new_output) 47 | assert old == old_output 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "old,new,trace,expected_spans", 52 | [ 53 | pytest.param( 54 | 'Original Sentence with "text to modify" and "more text to modify".', 55 | 'Original Sentence with "modified text" and "more modifed text".', 56 | ModificationTrace.from_modifications( 57 | Modification('"text to modify"', '"modified text"', 23), 58 | Modification('"more text to modify"', '"more modifed text"', 43), 59 | ), 60 | frozen.frozenlist([SpanIndex(23, 38), SpanIndex(43, 62)]), 61 | id="No overlap", 62 | ), 63 | pytest.param( 64 | 'Original Sentence with "text to modify" and "more text to modify".', 65 | 'Original Sentence with "modified "overlapped text" modifed text".', 66 | ModificationTrace.from_modifications( 67 | Modification('"text to modify"', '"modified text"', 23), 68 | Modification('"more text to modify"', '"more modifed text"', 43), 69 | Modification('text" and "more', '"overlapped text"', 33), 70 | ), 71 | frozen.frozenlist([SpanIndex(23, 64)]), 72 | id="With overlap", 73 | ), 74 | ], 75 | ) 76 | def test_modification_trace( 77 | old: str, 78 | new: str, 79 | trace: ModificationTrace, 80 | expected_spans: frozen.frozenlist[SpanIndex], 81 | ): 82 | new_output = ops.apply_modification_trace(trace, old) 83 | assert new == new_output 84 | old_output = ops.reverse_modification_trace(trace, new_output) 85 | assert old == old_output 86 | 87 | modified_spans = ops.modified_spans_from_trace(trace) 88 | assert len(expected_spans) == len(modified_spans) 89 | for e, m in zip(expected_spans, modified_spans): 90 | assert e.start == m.start 91 | assert e.end == m.end 92 | -------------------------------------------------------------------------------- /tests/test_more_functools.py: -------------------------------------------------------------------------------- 1 | from smaug import more_functools 2 | 3 | 4 | def test_pipe(): 5 | def f1(x): 6 | return 
x * 2 7 | 8 | def f2(x): 9 | return x + 4 10 | 11 | def f3(x): 12 | return x * 3 13 | 14 | pipe_func = more_functools.pipe(f1, f2, f3) 15 | 16 | x = 1 17 | assert f3(f2(f1(x))) == pipe_func(x) 18 | assert f1(f2(f3(x))) == more_functools.pipe(f3, f2, f1)(x) 19 | -------------------------------------------------------------------------------- /tests/test_promotion.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug.core import Data 4 | from smaug.promote import promote_to_data 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "value,expected", 9 | [ 10 | pytest.param( 11 | Data([1, 2, 3]), 12 | Data([1, 2, 3]), 13 | id="Data of ints", 14 | ), 15 | pytest.param( 16 | Data([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 17 | Data([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 18 | id="Data of lists of ints", 19 | ), 20 | pytest.param( 21 | [[1, 2, 3], [4, 5, 6], [7, 8, 9]], 22 | Data([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), 23 | id="List of lists of ints", 24 | ), 25 | pytest.param( 26 | 1, 27 | Data([1]), 28 | id="Int", 29 | ), 30 | ], 31 | ) 32 | def test_promote_to_data(value, expected): 33 | promoted = promote_to_data(value) 34 | assert isinstance(promoted, Data) 35 | assert len(expected) == len(promoted) 36 | for e, p in zip(expected, promoted): 37 | assert e == p 38 | -------------------------------------------------------------------------------- /tests/test_sentence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug import ops 4 | from smaug.core import Sentence, Modification, ModificationTrace, SpanIndexLike 5 | 6 | from typing import Optional 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "original,span,idx,expected", 11 | [ 12 | pytest.param( 13 | Sentence("Original Sentence without modifications."), 14 | ", text to add at the middle,", 15 | 17, 16 | Sentence( 17 | "Original Sentence, text to add at the middle, without modifications.", 18 | trace=ModificationTrace.from_modifications( 19 | Modification("", ", text to add at the middle,", 17), 20 | ), 21 | ), 22 | id="Middle of sentence insertion.", 23 | ), 24 | pytest.param( 25 | Sentence("Original Sentence without modifications."), 26 | "Text to add at the beginning. ", 27 | 0, 28 | Sentence( 29 | "Text to add at the beginning. Original Sentence without modifications.", 30 | trace=ModificationTrace.from_modifications( 31 | Modification("", "Text to add at the beginning. ", 0), 32 | ), 33 | ), 34 | id="Beginning of sentence insertion.", 35 | ), 36 | pytest.param( 37 | Sentence("Original Sentence without modifications."), 38 | " Text to add at the end.", 39 | 40, 40 | Sentence( 41 | "Original Sentence without modifications. Text to add at the end.", 42 | trace=ModificationTrace.from_modifications( 43 | Modification("", " Text to add at the end.", 40), 44 | ), 45 | ), 46 | id="End of sentence replacement.", 47 | ), 48 | ], 49 | ) 50 | def test_insert(original: Sentence, span: str, idx: int, expected: Sentence): 51 | output = ops.insert(original, span, idx) 52 | _assert_equal_sentences(expected, output) 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "original,loc,expected", 57 | [ 58 | pytest.param( 59 | Sentence( 60 | "Original Sentence, text to delete at the middle, without modifications." 
61 | ), 62 | (17, 48), 63 | Sentence( 64 | "Original Sentence without modifications.", 65 | trace=ModificationTrace.from_modifications( 66 | Modification(", text to delete at the middle,", "", 17), 67 | ), 68 | ), 69 | id="Middle of sentence deletion.", 70 | ), 71 | pytest.param( 72 | Sentence( 73 | "Text to delete at the beginning. Original Sentence without modifications." 74 | ), 75 | (0, 33), 76 | Sentence( 77 | "Original Sentence without modifications.", 78 | trace=ModificationTrace.from_modifications( 79 | Modification("Text to delete at the beginning. ", "", 0), 80 | ), 81 | ), 82 | id="Beginning of sentence deletion.", 83 | ), 84 | pytest.param( 85 | Sentence( 86 | "Original Sentence without modifications. Text to delete at the end." 87 | ), 88 | (40, 67), 89 | Sentence( 90 | "Original Sentence without modifications.", 91 | trace=ModificationTrace.from_modifications( 92 | Modification(" Text to delete at the end.", "", 40), 93 | ), 94 | ), 95 | id="End of sentence deletion.", 96 | ), 97 | ], 98 | ) 99 | def test_deletion(original: Sentence, loc: SpanIndexLike, expected: Sentence): 100 | output = ops.delete(original, loc) 101 | _assert_equal_sentences(expected, output) 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "original,span,loc,expected", 106 | [ 107 | pytest.param( 108 | Sentence("Original Sentence without modifications."), 109 | ', text to replace " without",', 110 | (17, 25), 111 | Sentence( 112 | 'Original Sentence, text to replace " without", modifications.', 113 | trace=ModificationTrace.from_modifications( 114 | Modification(" without", ', text to replace " without",', 17), 115 | ), 116 | ), 117 | id="Middle of sentence replacement.", 118 | ), 119 | pytest.param( 120 | Sentence("Original Sentence without modifications."), 121 | "Text to replace Original.", 122 | (0, 8), 123 | Sentence( 124 | "Text to replace Original. Sentence without modifications.", 125 | trace=ModificationTrace.from_modifications( 126 | Modification("Original", "Text to replace Original.", 0), 127 | ), 128 | ), 129 | id="Beginning of sentence replacement.", 130 | ), 131 | pytest.param( 132 | Sentence("Original Sentence without modifications."), 133 | '. Text to replace " modifications.".', 134 | (25, 40), 135 | Sentence( 136 | 'Original Sentence without. Text to replace " modifications.".', 137 | trace=ModificationTrace.from_modifications( 138 | Modification( 139 | " modifications.", '. 
Text to replace " modifications.".', 25 140 | ), 141 | ), 142 | ), 143 | id="End of sentence replacement.", 144 | ), 145 | ], 146 | ) 147 | def test_replace(original: Sentence, span: str, loc: SpanIndexLike, expected: Sentence): 148 | output = ops.replace(original, span, loc) 149 | _assert_equal_sentences(expected, output) 150 | 151 | 152 | def _assert_equal_sentences(expected: Sentence, actual: Sentence): 153 | assert expected.value == actual.value 154 | _assert_equal_traces(expected.trace, actual.trace) 155 | 156 | 157 | def _assert_equal_traces( 158 | expected: Optional[ModificationTrace], 159 | actual: Optional[ModificationTrace], 160 | ): 161 | if expected is None: 162 | assert actual is None 163 | else: 164 | assert actual is not None 165 | assert expected.curr == actual.curr 166 | _assert_equal_traces(expected.prev, actual.prev) 167 | -------------------------------------------------------------------------------- /tests/test_sentence_comparison.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from smaug import ops 4 | from smaug.core import Sentence 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "original,modified,expected", 9 | [ 10 | ( 11 | Sentence("Some sentence without any special characters."), 12 | Sentence("Other sentence without chars."), 13 | 0, 14 | ), 15 | ( 16 | Sentence("Some sentence without any special characters."), 17 | Sentence("Other sentence with ."), 18 | 2, 19 | ), 20 | ( 21 | Sentence("Some sentence with some