├── .coveragerc ├── .gitignore ├── README.md ├── cli.py ├── commands ├── __init__.py ├── attack.py ├── dataset.py ├── evaluation.py └── utils │ ├── __init__.py │ ├── dataset.py │ ├── metrics.py │ └── preprocess.py ├── configs ├── conll2003-config.json └── personal │ └── .gitignore ├── experiments ├── analysis_utils.py ├── analyze-json.py └── attack_result.py ├── requirements.txt ├── scripts └── gcp.sh ├── seqattack ├── __init__.py ├── attacks │ ├── __init__.py │ ├── bae.py │ ├── bert_attack.py │ ├── deepwordbug.py │ ├── morpheus.py │ ├── ner_clare.py │ ├── scpn.py │ ├── seqattack_recipe.py │ └── textfooler.py ├── constraints │ ├── __init__.py │ ├── avoid_named_entities.py │ ├── model_errors.py │ ├── ner.py │ ├── skip_negations.py │ └── skip_non_ascii.py ├── datasets │ ├── __init__.py │ ├── huggingfacener.py │ └── ner.py ├── goal_functions │ ├── __init__.py │ ├── ner.py │ ├── ner_goal_function_result.py │ ├── targeted_ner.py │ ├── untargeted_ner.py │ └── untargeted_ner_strict.py ├── models │ ├── __init__.py │ ├── exceptions │ │ └── __init__.py │ └── ner.py ├── search │ ├── __init__.py │ ├── greedy.py │ └── greedy_ner_swap.py ├── transformations │ ├── __init__.py │ ├── paraphrase.py │ └── roberta_word_insert.py └── utils │ ├── __init__.py │ ├── attack.py │ ├── attack_runner.py │ ├── ner_attacked_text.py │ └── sequence.py └── tests ├── .gitignore ├── __init__.py ├── fixtures.py ├── mock └── attacked.json ├── test_avoid_named_entity.py ├── test_ner_attacked_text.py ├── test_non_ne_constraint.py ├── test_targeted_ner.py ├── test_untargeted_ner.py ├── test_untargeted_ner_strict.py ├── test_utils.py └── utils └── __init__.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | textattack/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Project specific files 141 | .ipynb_checkpoints/* 142 | TextAttack 143 | 144 | # Datasets 145 | conll2003/* 146 | datasets/* 147 | 148 | # TextAttack folder (local clone) 149 | textattack/* 150 | add_labels.py 151 | pg.py 152 | data/* 153 | openattack.ipynb 154 | models/* 155 | experiments/data/attack/* 156 | test.sh 157 | .vscode/* 158 | a.json -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeqAttack: a framework for adversarial attacks on token classification models 2 | 3 | SeqAttack is a framework for conducting adversarial attacks against Named Entity Recognition (NER) models and for data augmentation. This library is heavily based on the popular [TextAttack](https://github.com/QData/TextAttack) framework, and can similarly be used for: 4 | 5 | - Understanding models by running adversarial attacks against them and observing their shortcomings 6 | - Develop new attack strategies 7 | - Guided data augmentation, generating additional training samples that can be used to fix a model's shortcomings 8 | 9 | The **SeqAttack** paper is available [here](https://aclanthology.org/2021.emnlp-demo.35.pdf). 10 | 11 | ### Setup 12 | 13 | Run `pip install -r requirements.txt` and you're good to go! 
If you want to run experiments on a fresh virtual machine, check out `scripts/gcp.sh`, which installs all system dependencies for running the code.
14 | 
15 | The code was tested with `python 3.7`; if you're using a different version, your mileage may vary.
16 | 
17 | ### Usage
18 | 
19 | The main features of the framework are available via the command line interface, wrapped by `cli.py`. The following subsections describe the usage of the various commands.
20 | 
21 | #### Attack
22 | 
23 | Attacks are executed via the `python cli.py attack` subcommand. Attack commands are split into two parts:
24 | 
25 | - General setup: options common to all adversarial attacks (e.g. model, dataset...)
26 | - Attack-specific setup: options specific to a particular attack strategy
27 | 
28 | Thus, a typical attack command might look like the following:
29 | 
30 | ```sh
31 | python cli.py attack [general-options] attack-recipe [recipe-options]
32 | ```
33 | 
34 | For example, if we want to attack [dslim/bert-base-NER](https://huggingface.co/dslim/bert-base-NER), a NER model trained on CoNLL2003, using `deepwordbug` as the attack strategy, we might run:
35 | 
36 | ```sh
37 | python cli.py attack \
38 |     --model-name dslim/bert-base-NER \
39 |     --output-path output-dataset.json \
40 |     --cache \
41 |     --dataset-config configs/conll2003-config.json \
42 |     deepwordbug
43 | ```
44 | 
45 | The dataset configuration file, `configs/conll2003-config.json`, defines:
46 | 
47 | - The dataset path or name (in the latter case it will be downloaded from [HuggingFace](https://huggingface.co/datasets))
48 | - The split (e.g. train, test), used only for HuggingFace datasets
49 | - The human-readable label names (a mapping between numerical labels and textual labels), given as a list
50 | - A `labels_map`, used to remap the dataset's ground truth to align it with the model output as needed. This field can be `null` if no remapping is needed
51 | 
52 | In the example above, `labels_map` is used to align the dataset labels to the output from `dslim/bert-base-NER`. The dataset labels are the following:
53 | 
54 | `O (0), B-PER (1), I-PER (2), B-ORG (3), I-ORG (4), B-LOC (5), I-LOC (6), B-MISC (7), I-MISC (8)`
55 | 
56 | while the model labels are:
57 | 
58 | `O (0), B-MISC (1), I-MISC (2), B-PER (3), I-PER (4), B-ORG (5), I-ORG (6), B-LOC (7), I-LOC (8)`
59 | 
60 | Thus a remapping is needed, and `labels_map` takes care of it.
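For reference, the full configuration used in this example (`configs/conll2003-config.json`, shipped with the repository) is shown below. Each key in `labels_map` is a dataset label index and each value is the corresponding model label index — e.g. dataset label `1` (`B-PER`) maps to model index `3`, which is `B-PER` in the model's ordering:

```json
{
    "name": "conll2003",
    "split": "test",
    "labels": ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"],
    "labels_map": {
        "0": 0,
        "1": 3,
        "2": 4,
        "3": 5,
        "4": 6,
        "5": 7,
        "6": 8,
        "7": 1,
        "8": 2
    }
}
```

Note that the `labels` list here follows the model's label ordering, so that remapped ground-truth indices and model predictions share the same textual names.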
61 | 
62 | ---
63 | 
64 | The available attack strategies are the following:
65 | 
66 | | Attack Strategy | Transformation | Constraints | Paper |
67 | |-----------------|------------------------------------------------------------------|--------------------------------------------------------------------|--------------------------------------------------------|
68 | | BAE | word swap | USE sentence cosine similarity | https://arxiv.org/abs/2004.01970 |
69 | | BERT-Attack | word swap | USE sentence cosine similarity, Maximum words perturbed | https://arxiv.org/abs/2004.09984 |
70 | | CLARE | word swap and insertion | USE sentence cosine similarity | https://arxiv.org/abs/2009.07502 |
71 | | DeepWordBug | character insertion, deletion, swap (ab --> ba) and substitution | Levenshtein edit distance | https://arxiv.org/abs/1801.04354 |
72 | | Morpheus | inflection word swap | | https://www.aclweb.org/anthology/2020.acl-main.263.pdf |
73 | | SCPN | paraphrasing | | https://www.aclweb.org/anthology/N18-1170 |
74 | | TextFooler | word swap | USE sentence cosine similarity, POS match, word-embedding distance | https://arxiv.org/abs/1907.11932 |
75 | 
76 | The table above is based on [this table](https://github.com/QData/TextAttack#attacks-and-papers-implemented-attack-recipes-textattack-attack---recipe-recipe_name). In addition to the constraints shown above, the attack strategies **are also forbidden from modifying or inserting named entities by default**.
77 | 
78 | #### Evaluation
79 | 
80 | To evaluate a model against a standard dataset, run:
81 | 
82 | ```sh
83 | python cli.py evaluate \
84 |     --model dslim/bert-base-NER \
85 |     --dataset-config configs/conll2003-config.json \
86 |     --mode strict
87 | ```
88 | 
89 | To evaluate the effectiveness of an attack, run the following command:
90 | 
91 | ```sh
92 | python cli.py evaluate-attacked \
93 |     --model dslim/bert-base-NER \
94 |     --attacked-dataset experiments/deepwordbug.json \
95 |     --mode strict
96 | ```
97 | 
98 | The above command will compute and display the metrics for the original predictions and their adversarial counterparts.
99 | 
100 | The evaluation is based on [seqeval](https://github.com/chakki-works/seqeval).
101 | 
102 | #### Dataset selection
103 | 
104 | Given a dataset, our victim model may be able to predict some dataset samples perfectly, but it may produce significant errors on others. To evaluate an attack's effectiveness, we may want to select samples with a small initial misprediction score.
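Conceptually, the selection works like the following sketch — a simplified, hypothetical rendering of the loop implemented in `commands/dataset.py`, where `score_sample` is an illustrative stand-in for the goal function's scoring call (`goal_function._get_score_labels` in the actual code). The score is, roughly speaking, the fraction of entities the model already gets wrong on the unmodified sample, from 0 (all correct) to 1 (all mispredicted):

```python
selected_samples = []

for sample, ground_truth in dataset:
    if len(selected_samples) >= max_samples:
        break

    # Misprediction score of the *unmodified* sample; score_sample is a
    # hypothetical helper wrapping the goal function's scoring logic
    initial_score = score_sample(model, sample, ground_truth)

    # Keep only samples the model already handles reasonably well
    if initial_score <= max_initial_score:
        selected_samples.append((sample, ground_truth))
```

This can be done via the following command: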
This can be done via the following command:
106 | 
107 | ```sh
108 | # --max-initial-score is the maximum initial misprediction score
109 | python cli.py pick-samples \
110 |     --model dslim/bert-base-NER \
111 |     --dataset-config configs/conll2003-config.json \
112 |     --max-samples 256 \
113 |     --max-initial-score 0.5 \
114 |     --output-filename cherry-picked.json \
115 |     --goal-function untargeted
116 | ```
117 | 
118 | 
119 | ### Tests
120 | 
121 | Tests can be run with `pytest`.
122 | 
123 | ### Adversarial examples visualization
124 | 
125 | The output datasets can be visualized with [SeqAttack-Visualization](https://github.com/WalterSimoncini/SeqAttack-Visualization)
--------------------------------------------------------------------------------
/cli.py:
--------------------------------------------------------------------------------
1 | import click
2 | 
3 | from commands import (
4 |     attack,
5 |     evaluate_attacked,
6 |     evaluate,
7 |     pick_samples
8 | )
9 | 
10 | 
11 | @click.group()
12 | @click.pass_context
13 | def cli(ctx):
14 |     pass
15 | 
16 | 
17 | if __name__ == '__main__':
18 |     cli.add_command(attack)
19 |     cli.add_command(evaluate)
20 |     cli.add_command(pick_samples)
21 |     cli.add_command(evaluate_attacked)
22 | 
23 |     cli(obj={})
--------------------------------------------------------------------------------
/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from .attack import attack
2 | from .evaluation import evaluate, evaluate_attacked
3 | from .dataset import pick_samples
--------------------------------------------------------------------------------
/commands/attack.py:
--------------------------------------------------------------------------------
1 | import click
2 | import random
3 | 
4 | from seqattack.models import NERModelWrapper
5 | from seqattack.datasets import NERHuggingFaceDataset
6 | from seqattack.utils.attack_runner import AttackRunner
7 | from seqattack.goal_functions import get_goal_function
8 | 
9 | from seqattack.attacks import (
10 |     NERCLARE,
11 |     BertAttackNER,
12 |     NERBAEGarg2019,
13 |     NERSCPNParaphrase,
14 |     MorpheusTan2020NER,
15 |     NERTextFoolerJin2019,
16 |     NERDeepWordBugGao2018
17 | )
18 | 
19 | 
20 | @click.group()
21 | @click.option("--model-name", default="dslim/bert-base-NER")
22 | @click.option("--output-path", type=str)
23 | @click.option("--random-seed", default=567)
24 | @click.option("--num-examples", default=256)
25 | @click.option("--max-entities-mispredicted", default=0.8)
26 | @click.option("--cache/--no-cache", default=False)
27 | @click.option("--goal-function", default="untargeted")
28 | @click.option("--max-queries", default=512)
29 | @click.option("--attack-timeout", default=60)
30 | @click.option("--dataset-config", default=None, required=True)
31 | @click.pass_context
32 | def attack(
33 |     ctx,
34 |     model_name,
35 |     output_path,
36 |     random_seed,
37 |     num_examples,
38 |     max_entities_mispredicted,
39 |     cache,
40 |     goal_function,
41 |     max_queries,
42 |     attack_timeout,
43 |     dataset_config):
44 |     random.seed(random_seed)
45 | 
46 |     goal_function_cls = get_goal_function(goal_function)
47 |     dataset = NERHuggingFaceDataset.from_config_file(
48 |         dataset_config,
49 |         num_examples=num_examples
50 |     )
51 | 
52 |     # Load model and tokenizer
53 |     tokenizer, model = NERModelWrapper.load_huggingface_model(model_name)
54 | 
55 |     ctx.ensure_object(dict)
56 |     ctx.obj["attack_args"] = {
57 |         "random_seed": random_seed,
58 |         "model": model,
59 |         "model_name": model_name,
60 |         "tokenizer": tokenizer,
61 |         "goal_function_class": 
goal_function_cls, 62 | "use_cache": cache, 63 | "query_budget": max_queries, 64 | "dataset": dataset, 65 | "max_entities_mispredicted": max_entities_mispredicted, 66 | "output_path": output_path, 67 | "attack_timeout": attack_timeout, 68 | "num_examples": num_examples 69 | } 70 | 71 | 72 | @attack.command() 73 | @click.option("--max-words-perturbed", default=0.4) 74 | @click.option("--max-candidates", default=48) 75 | @click.pass_context 76 | def bert_attack(ctx, max_words_perturbed, max_candidates): 77 | bert_attack_args = { 78 | "recipe": "bert-attack", 79 | "max_perturbed_percent": max_words_perturbed, 80 | "max_candidates": max_candidates, 81 | "additional_constraints": BertAttackNER.get_ner_constraints(ctx.obj["attack_args"]["model_name"]) 82 | } 83 | 84 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **bert_attack_args} 85 | ctx.obj["attack_args"]["recipe_metadata"] = bert_attack_args 86 | 87 | attack = BertAttackNER.build(**ctx.obj["attack_args"]) 88 | 89 | AttackRunner( 90 | attack=attack, 91 | dataset=ctx.obj["attack_args"]["dataset"], 92 | output_filename=ctx.obj["attack_args"]["output_path"], 93 | attack_args=ctx.obj["attack_args"] 94 | ).run() 95 | 96 | 97 | @attack.command() 98 | @click.option("--max-candidates", default=48) 99 | @click.pass_context 100 | def clare(ctx, max_candidates): 101 | clare_attack_args = { 102 | "recipe": "clare", 103 | "max_candidates": max_candidates, 104 | "additional_constraints": NERCLARE.get_ner_constraints(ctx.obj["attack_args"]["model_name"]) 105 | } 106 | 107 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **clare_attack_args} 108 | ctx.obj["attack_args"]["recipe_metadata"] = clare_attack_args 109 | 110 | attack = NERCLARE.build(**ctx.obj["attack_args"]) 111 | 112 | AttackRunner( 113 | attack=attack, 114 | dataset=ctx.obj["attack_args"]["dataset"], 115 | output_filename=ctx.obj["attack_args"]["output_path"], 116 | attack_args=ctx.obj["attack_args"] 117 | ).run() 118 | 119 | 120 | @attack.command() 121 | @click.option("--max-edit-distance", default=50) 122 | @click.option("--preserve-named-entities/--no-preserve-named-entities", default=True) 123 | @click.pass_context 124 | def deepwordbug(ctx, max_edit_distance, preserve_named_entities): 125 | deepwordbug_args = { 126 | "recipe": "deepwordbug", 127 | "max_edit_distance": max_edit_distance, 128 | "additional_constraints": NERDeepWordBugGao2018.get_ner_constraints( 129 | ctx.obj["attack_args"]["model_name"], 130 | **{"preserve_named_entities": preserve_named_entities} 131 | ) 132 | } 133 | 134 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **deepwordbug_args} 135 | ctx.obj["attack_args"]["recipe_metadata"] = deepwordbug_args 136 | 137 | attack = NERDeepWordBugGao2018.build(**ctx.obj["attack_args"]) 138 | 139 | AttackRunner( 140 | attack=attack, 141 | dataset=ctx.obj["attack_args"]["dataset"], 142 | output_filename=ctx.obj["attack_args"]["output_path"], 143 | attack_args=ctx.obj["attack_args"] 144 | ).run() 145 | 146 | 147 | @attack.command() 148 | @click.pass_context 149 | def scpn(ctx): 150 | scpn_attack_args = { 151 | "recipe": "scpn", 152 | "additional_constraints": [] 153 | } 154 | 155 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **scpn_attack_args} 156 | ctx.obj["attack_args"]["recipe_metadata"] = scpn_attack_args 157 | 158 | attack = NERSCPNParaphrase.build(**ctx.obj["attack_args"]) 159 | 160 | AttackRunner( 161 | attack=attack, 162 | dataset=ctx.obj["attack_args"]["dataset"], 163 | output_filename=ctx.obj["attack_args"]["output_path"], 164 | 
attack_args=ctx.obj["attack_args"] 165 | ).run() 166 | 167 | 168 | @attack.command() 169 | @click.option("--max-candidates", default=50) 170 | @click.pass_context 171 | def textfooler(ctx, max_candidates): 172 | textfooler_attack_args = { 173 | "recipe": "textfooler", 174 | "max_candidates": max_candidates, 175 | "additional_constraints": NERTextFoolerJin2019.get_ner_constraints(ctx.obj["attack_args"]["model_name"]) 176 | } 177 | 178 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **textfooler_attack_args} 179 | ctx.obj["attack_args"]["recipe_metadata"] = textfooler_attack_args 180 | 181 | attack = NERTextFoolerJin2019.build(**ctx.obj["attack_args"]) 182 | 183 | AttackRunner( 184 | attack=attack, 185 | dataset=ctx.obj["attack_args"]["dataset"], 186 | output_filename=ctx.obj["attack_args"]["output_path"], 187 | attack_args=ctx.obj["attack_args"] 188 | ).run() 189 | 190 | 191 | @attack.command() 192 | @click.option("--max-candidates", default=50) 193 | @click.pass_context 194 | def bae(ctx, max_candidates): 195 | bae_attack_args = { 196 | "recipe": "bae", 197 | "max_candidates": max_candidates, 198 | "additional_constraints": NERBAEGarg2019.get_ner_constraints(ctx.obj["attack_args"]["model_name"]) 199 | } 200 | 201 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **bae_attack_args} 202 | ctx.obj["attack_args"]["recipe_metadata"] = bae_attack_args 203 | 204 | attack = NERBAEGarg2019.build(**ctx.obj["attack_args"]) 205 | 206 | AttackRunner( 207 | attack=attack, 208 | dataset=ctx.obj["attack_args"]["dataset"], 209 | output_filename=ctx.obj["attack_args"]["output_path"], 210 | attack_args=ctx.obj["attack_args"] 211 | ).run() 212 | 213 | 214 | @attack.command() 215 | @click.pass_context 216 | def morpheus(ctx): 217 | morpheus_args = { 218 | "recipe": "morpheus", 219 | "additional_constraints": MorpheusTan2020NER.get_ner_constraints( 220 | ctx.obj["attack_args"]["model_name"] 221 | ) 222 | } 223 | 224 | ctx.obj["attack_args"] = {**ctx.obj["attack_args"], **morpheus_args} 225 | ctx.obj["attack_args"]["recipe_metadata"] = morpheus_args 226 | 227 | attack = MorpheusTan2020NER.build(**ctx.obj["attack_args"]) 228 | 229 | AttackRunner( 230 | attack=attack, 231 | dataset=ctx.obj["attack_args"]["dataset"], 232 | output_filename=ctx.obj["attack_args"]["output_path"], 233 | attack_args=ctx.obj["attack_args"] 234 | ).run() 235 | -------------------------------------------------------------------------------- /commands/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import click 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from commands.utils import remap_negations 7 | 8 | from seqattack.models import NERModelWrapper 9 | from seqattack.datasets import NERHuggingFaceDataset 10 | from seqattack.goal_functions import get_goal_function 11 | from seqattack.utils import postprocess_ner_output 12 | from seqattack.utils.ner_attacked_text import NERAttackedText 13 | 14 | 15 | @click.command() 16 | @click.option("--model", required=True, type=str) 17 | @click.option("--dataset-config", default=None, type=str) 18 | @click.option("--max-samples", required=True, type=int) 19 | @click.option("--max-initial-score", required=True, type=float) 20 | @click.option("--output-filename", required=True, type=str) 21 | @click.option("--goal-function", default="untargeted") 22 | @click.pass_context 23 | def pick_samples( 24 | ctx, 25 | model, 26 | dataset_config, 27 | max_samples, 28 | max_initial_score, 29 | output_filename, 30 | goal_function): 
31 | """ 32 | Extracts a subset of samples from a dataset with a 33 | given maximum initial misprediction score 34 | """ 35 | tokenizer, model = NERModelWrapper.load_huggingface_model(model) 36 | dataset = NERHuggingFaceDataset.from_config_file(dataset_config) 37 | 38 | goal_function_cls = get_goal_function(goal_function) 39 | goal_function = goal_function_cls( 40 | model_wrapper=model, 41 | tokenizer=tokenizer, 42 | use_cache=True, 43 | ner_postprocess_func=postprocess_ner_output, 44 | label_names=dataset.label_names 45 | ) 46 | 47 | # Negations must be remapped to avoid prediction errors 48 | dataset.dataset = remap_negations(dataset.dataset) 49 | 50 | progress_bar = tqdm(total=max_samples) 51 | selected_samples, initial_scores = [], [] 52 | 53 | for sample, ground_truth in dataset.dataset: 54 | if len(selected_samples) >= max_samples: 55 | break 56 | 57 | attacked_text = NERAttackedText( 58 | sample, 59 | attack_attrs={"label_names": dataset.label_names}, 60 | ground_truth=ground_truth.tolist()) 61 | 62 | try: 63 | goal_function.init_attack_example(attacked_text, ground_truth) 64 | 65 | model_raw = model([sample])[0] 66 | model_preds = model.process_raw_output(model_raw, attacked_text.text) 67 | 68 | initial_score = goal_function._get_score_labels(model_preds, attacked_text) 69 | except Exception as ex: 70 | print(f"Error scoring {sample}: {ex}") 71 | 72 | if initial_score <= max_initial_score: 73 | progress_bar.update(1) 74 | initial_scores.append(initial_score) 75 | selected_samples.append((sample, ground_truth.tolist())) 76 | 77 | # Save samples 78 | with open(output_filename, "w") as out_file: 79 | out_file.write(json.dumps({ 80 | "meta": { 81 | "max_samples": max_samples, 82 | "max_initial_score": max_initial_score, 83 | "dataset": dataset.name, 84 | "split": dataset.split, 85 | "dataset_labels": dataset.label_names 86 | }, 87 | "samples": selected_samples 88 | })) 89 | 90 | print(f"Selected {len(selected_samples)} from {dataset.name}") 91 | print(f"Average misprediction score: {np.array(initial_scores).mean()}") 92 | -------------------------------------------------------------------------------- /commands/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import click 3 | 4 | from commands.utils import ( 5 | calculate_metrics, 6 | remap_negations, 7 | remap_negations_single 8 | ) 9 | 10 | from commands.utils import ( 11 | load_attacked_dataset, 12 | extract_attacked_dataset 13 | ) 14 | 15 | from seqattack.models import NERModelWrapper 16 | from seqattack.datasets import NERHuggingFaceDataset 17 | 18 | 19 | @click.command() 20 | @click.option("--model", required=True, type=str) 21 | @click.option("--dataset-config", default=None, type=str) 22 | @click.option("--mode", default=None, type=str) 23 | @click.pass_context 24 | def evaluate(ctx, model, dataset_config, mode): 25 | _, model = NERModelWrapper.load_huggingface_model(model) 26 | dataset = NERHuggingFaceDataset.from_config_file(dataset_config) 27 | 28 | # Negations must be remapped to avoid prediction errors 29 | dataset.dataset = remap_negations(dataset.dataset) 30 | labelled_dataset = [(sample[0], [dataset.label_names[x] for x in sample[1]]) for sample in dataset] 31 | 32 | original_str, _ = calculate_metrics( 33 | model, 34 | labelled_dataset, 35 | dataset.label_names, 36 | mode=mode) 37 | 38 | print() 39 | print(original_str) 40 | 41 | 42 | @click.command() 43 | @click.option("--model", type=str) 44 | @click.option("--attacked-dataset", default=None, 
type=str)
45 | @click.option("--output-filename", required=False, type=str)
46 | @click.option("--mode", default=None, type=str)
47 | @click.pass_context
48 | def evaluate_attacked(ctx, model, attacked_dataset, output_filename, mode):
49 |     _, model = NERModelWrapper.load_huggingface_model(model)
50 |     config, input_dataset = load_attacked_dataset(
51 |         attacked_dataset
52 |     )
53 | 
54 |     attacked_dataset = extract_attacked_dataset(input_dataset)
55 |     original_dataset = [
56 |         remap_negations_single(
57 |             sample["original_text"],
58 |             sample.get("ground_truth_labels", []),
59 |             str_labels=True
60 |         ) for sample in input_dataset
61 |     ]
62 | 
63 |     original_str, original_metrics = calculate_metrics(model, original_dataset, config["labels"], mode=mode)
64 |     attacked_str, attacked_metrics = calculate_metrics(model, attacked_dataset, config["labels"], mode=mode)
65 | 
66 |     print("Original metrics: \n")
67 |     print(original_str)
68 | 
69 |     print("Attacked metrics: \n")
70 |     print(attacked_str)
71 | 
72 |     if output_filename is not None:
73 |         with open(output_filename, "w") as out_file:
74 |             out_file.write(json.dumps({
75 |                 "original": original_metrics,
76 |                 "attacked": attacked_metrics
77 |             }))
--------------------------------------------------------------------------------
/commands/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import *
2 | from .preprocess import *
3 | from .dataset import *
--------------------------------------------------------------------------------
/commands/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from .preprocess import remap_negations_single
4 | 
5 | 
6 | def load_attacked_dataset(path):
7 |     json_data = json.loads(
8 |         open(path).read())
9 | 
10 |     samples_list, config = json_data, None
11 | 
12 |     if type(json_data) == dict:
13 |         # Extract the samples list and config from newer formats of
14 |         # the datasets. Older formats are plain lists of samples,
15 |         # without an attached config
16 |         samples_list = json_data["attacked_examples"]
17 |         config = json_data["config"]
18 | 
19 |     return config, samples_list
20 | 
21 | 
22 | def extract_attacked_dataset(attacked_dataset):
23 |     """
24 |     Creates a copy of the original dataset by replacing original
25 |     samples with their adversarial counterpart
26 |     """
27 |     out_attacked_dataset = []
28 | 
29 |     for sample in attacked_dataset:
30 |         ner_labels = sample.get("ground_truth_labels", [])
31 |         input_text = sample["original_text"]
32 | 
33 |         if sample["status"] == "Successful":
34 |             input_text = sample["perturbed_text"]
35 | 
36 |         if "final_ground_truth_labels" in sample:
37 |             ner_labels = sample["final_ground_truth_labels"]
38 | 
39 |         input_text, ner_labels = remap_negations_single(
40 |             input_text,
41 |             ner_labels,
42 |             str_labels=True
43 |         )
44 | 
45 |         out_attacked_dataset.append((
46 |             input_text, ner_labels
47 |         ))
48 | 
49 |     return out_attacked_dataset
--------------------------------------------------------------------------------
/commands/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | 
3 | from seqeval.scheme import IOB2
4 | from seqeval.metrics import classification_report
5 | 
6 | 
7 | def calculate_metrics(model, dataset, label_names, mode=None):
8 |     """
9 |     Given a dataset in the format (text_input, ground_truth_labels)
10 |     and a model, this function calculates the model prediction for
11 |     each text sample and uses seqeval to calculate the precision,
12 |     recall and F1 metrics
13 |     """
14 |     ground_truths, model_predictions = [], 
[]
15 | 
16 |     for i in tqdm(range(len(dataset))):
17 |         input_text, ground_truth_labels = dataset[i]
18 | 
19 |         predicted_labels = predict_labels(
20 |             model,
21 |             input_text,
22 |             label_names)
23 | 
24 |         ground_truths.append(ground_truth_labels)
25 |         model_predictions.append(predicted_labels)
26 | 
27 |     classification_dict = classification_report(
28 |         ground_truths,
29 |         model_predictions,
30 |         scheme=IOB2,
31 |         mode=mode,
32 |         output_dict=True)
33 | 
34 |     classification_str = classification_report(
35 |         ground_truths,
36 |         model_predictions,
37 |         scheme=IOB2,
38 |         mode=mode,
39 |         output_dict=False)
40 | 
41 |     return classification_str, serialize_metrics_dict(classification_dict)
42 | 
43 | 
44 | def serialize_metrics_dict(metrics_dict):
45 |     """
46 |     Converts the values of the dictionary output of seqeval
47 |     from numpy int64/float64 to floats
48 |     """
49 |     out_dict = {}
50 | 
51 |     for top_k in metrics_dict.keys():
52 |         out_dict[top_k] = {}
53 | 
54 |         for sub_k in metrics_dict[top_k].keys():
55 |             out_dict[top_k][sub_k] = float(metrics_dict[top_k][sub_k])
56 | 
57 |     return out_dict
58 | 
59 | 
60 | def predict_labels(model, sample: str, dataset_labels: list):
61 |     """
62 |     Predicts a single textual sample and returns a list of
63 |     classes, one per token
64 |     """
65 |     prediction = model([sample])[0]
66 |     prediction = model.process_raw_output(prediction, sample).tolist()
67 | 
68 |     return prediction_to_labels(prediction, dataset_labels)
69 | 
70 | 
71 | def prediction_to_labels(prediction, labels):
72 |     return [labels[x] for x in list(prediction)]
--------------------------------------------------------------------------------
/commands/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def remap_negations(dataset):
5 |     """
6 |     Given a dataset, maps negations to the form
7 |     root-n ' t (e.g. do not --> don ' t,
8 |     does n't --> doesn ' t)
9 |     """
10 |     out_dataset = []
11 | 
12 |     for sample_idx in range(len(dataset)):
13 |         sample_text, sample_truth = dataset[sample_idx]
14 | 
15 |         out_dataset.append(
16 |             remap_negations_single(sample_text, sample_truth)
17 |         )
18 | 
19 |     return out_dataset
20 | 
21 | 
22 | def remap_negations_single(text, truth_labels, str_labels=False):
23 |     def fix_negation(i, words, labels):
24 |         current_word = words[i]
25 |         # Use the correct type for the "O" (non-entity) label: the
26 |         # textual label when str_labels is True, the numerical
27 |         # label 0 otherwise
28 |         o_label = "O" if str_labels else 0
29 | 
30 |         words[i] = f"{current_word}n"
31 |         words[i + 1] = "'"
32 |         words.insert(i + 2, "t")
33 | 
34 |         labels[i] = o_label
35 |         labels[i + 1] = o_label
36 |         labels.insert(i + 2, o_label)
37 | 
38 |     if "do not" in text or "does not" in text or "n't" in text:
39 |         if str_labels:
40 |             words, labels = text.split(" "), truth_labels
41 |         else:
42 |             words, labels = text.split(" "), [int(x) for x in truth_labels]
43 | 
44 |         for i in range(len(words) - 1):
45 |             current_word, next_word = words[i], words[i + 1]
46 |             start_word_match = current_word in ["do", "does"]
47 | 
48 |             if start_word_match and (next_word == "not"):
49 |                 fix_negation(i, words, labels)
50 |             elif next_word == "n't":
51 |                 fix_negation(i, words, labels)
52 | 
53 |         return (
54 |             " ".join(words),
55 |             labels if str_labels else torch.tensor(labels)
56 |         )
57 | 
58 |     return (
59 |         text,
60 |         truth_labels if str_labels else torch.tensor(list(truth_labels))
61 |     )
--------------------------------------------------------------------------------
/configs/conll2003-config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "conll2003",
3 |     "split": "test",
4 |     "labels": ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"],
5 |     "labels_map": {
6 |         "0": 0,
7 |         "1": 3,
8 |         "2": 4,
9 |         "3": 5,
10 |         "4": 6,
11 |         "5": 7,
12 |         "6": 8,
13 |         "7": 1,
14 |         "8": 2
15 |     }
16 | }
--------------------------------------------------------------------------------
/configs/personal/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
--------------------------------------------------------------------------------
/experiments/analysis_utils.py:
--------------------------------------------------------------------------------
1 | from colored import fg, bg, attr
2 | 
3 | from attack_result import AttackResult
4 | 
5 | 
6 | def highlight_sequences_diff(original_seq, modified_seq, seq_joiner=" "):
7 |     out_sequence_mod = []
8 |     out_sequence_original = []
9 | 
10 |     for or_token, mod_token in zip(original_seq, modified_seq):
11 |         if or_token != mod_token:
12 |             out_sequence_mod.append(f"{fg(1)}{mod_token}{attr(0)}")
13 |             out_sequence_original.append(f"{fg(111)}{or_token}{attr(0)}")
14 |         else:
15 |             out_sequence_mod.append(mod_token)
16 |             out_sequence_original.append(or_token)
17 | 
18 |     return seq_joiner.join(out_sequence_mod), seq_joiner.join(out_sequence_original)
19 | 
20 | 
21 | def preprocess_text_sample(sample):
22 |     return list(filter(lambda x: x.strip() != "", sample.split("\n")))
23 | 
24 | 
25 | def process_attacked_text_sample(sample):
26 |     lines = preprocess_text_sample(sample)
27 | 
28 |     # Remove TF errors
29 |     lines = [line for line in lines if "tensorflow.org/guide" not in line]
30 | 
31 |     # Successful samples have 6 lines of text
32 |     successful = len(lines) == 6
33 | 
34 |     if successful:
35 |         original_labels, attacked_labels = lines[3].split(" --> ")
36 | 
37 |         return True, {
38 |             "sample": lines[0].replace("Attacking sample: ", ""),
39 |             "ground_truth": lines[1].replace("The ground truth labels are:\t", ""),
"model_pred": lines[2].replace("The model prediction is:\t", ""), 41 | "labels": { 42 | "original": original_labels.split(", "), 43 | "attacked": attacked_labels.split(", ") 44 | }, 45 | "post_attack_sample": lines[-1] 46 | } 47 | 48 | return False, None 49 | 50 | 51 | def print_attack_summary(samples_dict): 52 | print("\nSummary statistics:") 53 | 54 | total_samples = 0 55 | 56 | for k in samples_dict.keys(): 57 | total_samples += len(samples_dict[k]) 58 | 59 | successful_samples = len(samples_dict[AttackResult.SUCCESS.value]) 60 | failed_samples = len(samples_dict[AttackResult.FAILED.value]) 61 | errors_count = len(samples_dict[AttackResult.ERROR.value]) 62 | 63 | print(f" ▶ Successfully attacked {successful_samples}/{total_samples} samples") 64 | print(f" ▶ Failed to attack {failed_samples}/{total_samples} samples") 65 | 66 | if AttackResult.SKIPPED.value in samples_dict: 67 | skipped_samples = len(samples_dict[AttackResult.SKIPPED.value]) 68 | print(f" ▶ Skipped {skipped_samples}/{total_samples} samples") 69 | 70 | print(f" ▶ Could not attack {errors_count}/{total_samples} samples due to errors\n") 71 | print(f"Original words/labels are in {fg(111)}BLUE{attr(0)} and the attacked ones are in {fg(1)}RED{attr(0)}") 72 | -------------------------------------------------------------------------------- /experiments/analyze-json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | from colored import fg, bg, attr 5 | 6 | from attack_result import AttackResult 7 | from analysis_utils import highlight_sequences_diff, print_attack_summary 8 | 9 | 10 | parser = argparse.ArgumentParser(description="Process a JSON TextAttack output") 11 | parser.add_argument("--filename", type=str, help="The input file") 12 | 13 | args = parser.parse_args() 14 | 15 | attacked_samples = json.loads(open(args.filename).read())["attacked_samples"] 16 | samples_dict = { 17 | AttackResult.SUCCESS.value: [], 18 | AttackResult.FAILED.value: [], 19 | AttackResult.ERROR.value: [], 20 | AttackResult.SKIPPED.value: [] 21 | } 22 | 23 | for sample in attacked_samples: 24 | attack_status = AttackResult.from_textattack_class(sample["status"]) 25 | samples_dict[attack_status.value].append(sample) 26 | 27 | # Show the details of successful attacks 28 | if attack_status == AttackResult.SUCCESS: 29 | print("------------------------------------") 30 | 31 | diff_labels, orig_labels = highlight_sequences_diff( 32 | sample["original_labels"], 33 | sample["perturbed_labels"], 34 | seq_joiner="\t" 35 | ) 36 | 37 | diff_sample, orig_sample = highlight_sequences_diff( 38 | sample["attacked_sample"].split(" "), 39 | sample["perturbed_text"].split(" ") 40 | ) 41 | 42 | model_queries = sample["num_queries"] 43 | 44 | print(f"Original labels: {orig_sample}") 45 | print(f"Attacked labels: {diff_sample}") 46 | 47 | print("\n") 48 | 49 | print(f"Original labels: {orig_labels}") 50 | print(f"Attacked labels: {diff_labels}") 51 | 52 | print_attack_summary(samples_dict) 53 | -------------------------------------------------------------------------------- /experiments/attack_result.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class AttackResult(Enum): 5 | SUCCESS = 0 6 | SKIPPED = 1 7 | FAILED = 2 8 | ERROR = 3 9 | 10 | @classmethod 11 | def from_mapping_string(cls, map_string, sample): 12 | """ 13 | Computes the attack result from a string in the formats 14 | 15 | LABEL, ..., LABEL --> LABEL, ..., 
16 |         LABEL, ..., LABEL --> [SKIPPED/FAILED]
17 |         """
18 |         if "Could not attack sample:" in sample:
19 |             return AttackResult.FAILED
20 | 
21 |         _, dest = map_string.split(" --> ")
22 | 
23 |         if "FAILED" in dest:
24 |             return AttackResult.FAILED
25 |         elif "SKIPPED" in dest:
26 |             return AttackResult.SKIPPED
27 |         else:
28 |             return AttackResult.SUCCESS
29 | 
30 |     @classmethod
31 |     def from_textattack_class(cls, textattack_cls):
32 |         """
33 |         Computes the attack result from the name of a TextAttack
34 |         result class (e.g. SuccessfulAttackResult)
35 |         """
36 |         if textattack_cls == "ErrorAttackResult":
37 |             return AttackResult.ERROR
38 |         elif textattack_cls == "SkippedAttackResult":
39 |             return AttackResult.SKIPPED
40 |         elif "FailedAttackResult" in textattack_cls:
41 |             return AttackResult.FAILED
42 |         elif "SuccessfulAttackResult" in textattack_cls:
43 |             return AttackResult.SUCCESS
44 | 
45 |         raise Exception("The provided class is not mapped")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.0
2 | textattack==0.2.15
3 | tensorflow==2.3.4
4 | tensorflow-hub
5 | pygit2
6 | python-Levenshtein
7 | seqeval
8 | openattack
--------------------------------------------------------------------------------
/scripts/gcp.sh:
--------------------------------------------------------------------------------
1 | # Disk size should be ~50 GB
2 | 
3 | # Install python3
4 | sudo apt update
5 | sudo apt install python-is-python3 -y
6 | 
7 | # Install GCC
8 | sudo apt install build-essential -y
9 | 
10 | # Install anaconda
11 | # Remember to provide the install directory as /home/walter/anaconda3
12 | wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
13 | sh Anaconda3-2020.11-Linux-x86_64.sh -b
14 | 
15 | # Initialize conda
16 | /home/walter/anaconda3/bin/conda init
17 | 
18 | # Reload shell
19 | source ~/.bashrc
20 | 
21 | # Create python 3.7 environment
22 | conda create -n seqattack python=3.7 -y
23 | conda activate seqattack
24 | 
25 | # Install requirements
26 | # We are not using a requirements.txt file because it causes conflicts
27 | pip install torch==1.7.0
28 | pip install textattack==0.2.15
29 | pip install tensorflow==2.4.2 tensorflow-hub
30 | pip install pygit2 python-Levenshtein
31 | pip install seqeval
32 | pip install openattack
--------------------------------------------------------------------------------
/seqattack/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WalterSimoncini/SeqAttack/a0673613f489f355ddef37c89f0a635c89a500e9/seqattack/__init__.py
--------------------------------------------------------------------------------
/seqattack/attacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .bert_attack import BertAttackNER
2 | from .morpheus import MorpheusTan2020NER
3 | from .ner_clare import NERCLARE
4 | from .deepwordbug import NERDeepWordBugGao2018
5 | from .textfooler import NERTextFoolerJin2019
6 | from .bae import NERBAEGarg2019
7 | from .scpn import NERSCPNParaphrase
--------------------------------------------------------------------------------
/seqattack/attacks/bae.py:
--------------------------------------------------------------------------------
1 | """
2 | BAE (BAE: BERT-Based Adversarial Examples)
3 | 
============================================ 4 | 5 | """ 6 | from textattack.constraints.grammaticality import PartOfSpeech 7 | from textattack.constraints.pre_transformation import ( 8 | RepeatModification, 9 | StopwordModification, 10 | ) 11 | 12 | from textattack.transformations import WordSwapMaskedLM 13 | from textattack.constraints.semantics.sentence_encoders import \ 14 | UniversalSentenceEncoder 15 | 16 | from seqattack.constraints import NonNamedEntityConstraint, SkipNonASCII, SkipNegations 17 | from seqattack.search import NERGreedyWordSwapWIR 18 | from seqattack.utils import postprocess_ner_output 19 | 20 | from seqattack.utils.attack import NERAttack 21 | from .seqattack_recipe import SeqAttackRecipe 22 | 23 | 24 | class NERBAEGarg2019(SeqAttackRecipe): 25 | """Siddhant Garg and Goutham Ramakrishnan, 2019. 26 | 27 | BAE: BERT-based Adversarial Examples for Text Classification. 28 | 29 | https://arxiv.org/pdf/2004.01970 30 | 31 | This is "attack mode" 1 from the paper, BAE-R, word replacement. 32 | 33 | We present 4 attack modes for BAE based on the 34 | R and I operations, where for each token t in S: 35 | • BAE-R: Replace token t (See Algorithm 1) 36 | • BAE-I: Insert a token to the left or right of t 37 | • BAE-R/I: Either replace token t or insert a 38 | token to the left or right of t 39 | • BAE-R+I: First replace token t, then insert a 40 | token to the left or right of t 41 | """ 42 | 43 | @staticmethod 44 | def build( 45 | model, 46 | tokenizer, 47 | dataset, 48 | goal_function_class, 49 | max_candidates=50, 50 | additional_constraints=[], 51 | query_budget=2500, 52 | use_cache=False, 53 | attack_timeout=30, 54 | **kwargs): 55 | # "In this paper, we present a simple yet novel technique: BAE (BERT-based 56 | # Adversarial Examples), which uses a language model (LM) for token 57 | # replacement to best fit the overall context. We perturb an input sentence 58 | # by either replacing a token or inserting a new token in the sentence, by 59 | # means of masking a part of the input and using a LM to fill in the mask." 60 | # 61 | # We only consider the top K=50 synonyms from the MLM predictions. 62 | # 63 | # [from email correspondance with the author] 64 | # "When choosing the top-K candidates from the BERT masked LM, we filter out 65 | # the sub-words and only retain the whole words (by checking if they are 66 | # present in the GloVE vocabulary)" 67 | # 68 | transformation = WordSwapMaskedLM( 69 | method="bae", max_candidates=max_candidates, min_confidence=0.0 70 | ) 71 | # 72 | # Don't modify the same word twice or stopwords. 73 | # 74 | constraints = [ 75 | RepeatModification(), 76 | StopwordModification(), 77 | SkipNonASCII(), 78 | SkipNegations()] 79 | 80 | # For the R operations we add an additional check for 81 | # grammatical correctness of the generated adversarial example by filtering 82 | # out predicted tokens that do not form the same part of speech (POS) as the 83 | # original token t_i in the sentence. 84 | constraints.append(PartOfSpeech(allow_verb_noun_swap=True)) 85 | 86 | # "To ensure semantic similarity on introducing perturbations in the input 87 | # text, we filter the set of top-K masked tokens (K is a pre-defined 88 | # constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE) 89 | # (Cer et al., 2018)-based sentence similarity scorer." 90 | # 91 | # "[We] set a threshold of 0.8 for the cosine similarity between USE-based 92 | # embeddings of the adversarial and input text." 
93 | # 94 | # [from email correspondence with the author] 95 | # "For a fair comparison of the benefits of using a BERT-MLM in our paper, 96 | # we retained the majority of TextFooler's specifications. Thus we: 97 | # 1. Use the USE for comparison within a window of size 15 around the word 98 | # being replaced/inserted. 99 | # 2. Set the similarity score threshold to 0.1 for inputs shorter than the 100 | # window size (this translates roughly to almost always accepting the new text). 101 | # 3. Perform the USE similarity thresholding of 0.8 with respect to the text 102 | # just before the replacement/insertion and not the original text (For 103 | # example: at the 3rd R/I operation, we compute the USE score on a window 104 | # of size 15 of the text obtained after the first 2 R/I operations and not 105 | # the original text). 106 | # ... 107 | # To address point (3) from above, compare the USE with the original text 108 | # at each iteration instead of the current one (While doing this change 109 | # for the R-operation is trivial, doing it for the I-operation with the 110 | # window based USE comparison might be more involved)." 111 | # 112 | # Finally, since the BAE code is based on the TextFooler code, we need to 113 | # adjust the threshold to account for the missing / pi in the cosine 114 | # similarity comparison. So the final threshold is 1 - (1 - 0.8) / pi 115 | # = 1 - (0.2 / pi) = 0.936338023. 116 | use_constraint = UniversalSentenceEncoder( 117 | threshold=0.936338023, 118 | metric="cosine", 119 | compare_against_original=True, 120 | window_size=15, 121 | skip_text_shorter_than_window=True, 122 | ) 123 | 124 | constraints.append(use_constraint) 125 | 126 | # Add user-provided constraints 127 | constraints.extend(additional_constraints) 128 | 129 | # 130 | # Goal is untargeted classification. 131 | # 132 | goal_function = goal_function_class( 133 | model, 134 | tokenizer=tokenizer, 135 | use_cache=use_cache, 136 | query_budget=query_budget, 137 | ner_postprocess_func=postprocess_ner_output, 138 | label_names=dataset.label_names) 139 | # 140 | # "We estimate the token importance Ii of each token 141 | # t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the 142 | # decrease in probability of predicting the correct label y, similar 143 | # to (Jin et al., 2019). 144 | # 145 | # • "If there are multiple tokens can cause C to misclassify S when they 146 | # replace the mask, we choose the token which makes Sadv most similar to 147 | # the original S based on the USE score." 148 | # • "If no token causes misclassification, we choose the perturbation that 149 | # decreases the prediction probability P(C(Sadv)=y) the most." 
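        #
        # (In this implementation, the importance estimate described above
        # corresponds to wir_method="delete" below: each word is deleted in
        # turn and the resulting drop in the goal-function score determines
        # the order in which words are attacked.)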
150 |         #
151 |         search_method = NERGreedyWordSwapWIR(
152 |             dataset=dataset,
153 |             tokenizer=tokenizer,
154 |             wir_method="delete")
155 | 
156 |         return NERAttack(
157 |             goal_function,
158 |             constraints,
159 |             transformation,
160 |             search_method,
161 |             attack_timeout=attack_timeout)
--------------------------------------------------------------------------------
/seqattack/attacks/bert_attack.py:
--------------------------------------------------------------------------------
1 | from textattack.transformations import WordSwapMaskedLM
2 | from textattack.constraints.overlap import MaxWordsPerturbed
3 | from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
4 | from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
5 | 
6 | from seqattack.constraints import SkipNonASCII, SkipNegations
7 | from seqattack.search import NERGreedyWordSwapWIR
8 | from seqattack.utils import postprocess_ner_output
9 | from seqattack.utils.attack import NERAttack
10 | from .seqattack_recipe import SeqAttackRecipe
11 | 
12 | 
13 | class BertAttackNER(SeqAttackRecipe):
14 |     """
15 |     Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).
16 |     BERT-ATTACK: Adversarial Attack Against BERT Using BERT
17 | 
18 |     https://arxiv.org/abs/2004.09984
19 | 
20 |     This code is heavily based on the following TextAttack recipe
21 |     (refer to it for detailed documentation):
22 | 
23 |     https://textattack.readthedocs.io/en/latest/_modules/textattack/attack_recipes/bert_attack_li_2020.html#BERTAttackLi2020
24 |     """
25 |     @staticmethod
26 |     def build(
27 |         model,
28 |         tokenizer,
29 |         dataset,
30 |         goal_function_class,
31 |         max_perturbed_percent=0.4,
32 |         max_candidates=48,
33 |         additional_constraints=[],
34 |         query_budget=2500,
35 |         use_cache=False,
36 |         max_entities_mispredicted=0.8,
37 |         attack_timeout=30,
38 |         **kwargs):
39 |         transformation = WordSwapMaskedLM(
40 |             method="bert-attack",
41 |             max_candidates=max_candidates)
42 | 
43 |         constraints = [
44 |             # Do not modify already changed words
45 |             RepeatModification(),
46 |             # Do not modify stopwords
47 |             StopwordModification(),
48 |             SkipNonASCII(),
49 |             SkipNegations(),
50 |             MaxWordsPerturbed(max_percent=max_perturbed_percent)]
51 | 
52 |         constraints.extend(additional_constraints)
53 | 
54 |         use_constraint = UniversalSentenceEncoder(
55 |             threshold=0.2,
56 |             metric="cosine",
57 |             compare_against_original=True,
58 |             window_size=None)
59 |         constraints.append(use_constraint)
60 | 
61 |         goal_function = goal_function_class(
62 |             model,
63 |             tokenizer=tokenizer,
64 |             use_cache=use_cache,
65 |             query_budget=query_budget,
66 |             ner_postprocess_func=postprocess_ner_output,
67 |             label_names=dataset.label_names)
68 | 
69 |         # Select words with the most influence on output logits
70 |         # search_method = GreedyWordSwapWIR(wir_method="unk")
71 |         search_method = NERGreedyWordSwapWIR(
72 |             dataset=dataset,
73 |             tokenizer=tokenizer,
74 |             wir_method="unk")
75 | 
76 |         return NERAttack(
77 |             goal_function,
78 |             constraints,
79 |             transformation,
80 |             search_method,
81 |             attack_timeout=attack_timeout,
82 |             max_entities_mispredicted=max_entities_mispredicted)
--------------------------------------------------------------------------------
/seqattack/attacks/deepwordbug.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | DeepWordBug
4 | ========================================
5 | (Black-box Generation of Adversarial Text Sequences to Evade Deep Learning 
Classifiers) 6 | 7 | """ 8 | 9 | from textattack.constraints.overlap import LevenshteinEditDistance 10 | from textattack.constraints.pre_transformation import ( 11 | RepeatModification, 12 | StopwordModification, 13 | ) 14 | 15 | from textattack.transformations import ( 16 | CompositeTransformation, 17 | WordSwapNeighboringCharacterSwap, 18 | WordSwapRandomCharacterDeletion, 19 | WordSwapRandomCharacterInsertion, 20 | WordSwapRandomCharacterSubstitution, 21 | ) 22 | 23 | from seqattack.models import NERModelWrapper 24 | from seqattack.search import NERGreedyWordSwapWIR 25 | 26 | from seqattack.utils import postprocess_ner_output 27 | from seqattack.utils.attack import NERAttack 28 | from seqattack.constraints import SkipNegations 29 | from .seqattack_recipe import SeqAttackRecipe 30 | 31 | from seqattack.constraints import ( 32 | SkipModelErrors, 33 | AvoidNamedEntityConstraint, 34 | NonNamedEntityConstraint 35 | ) 36 | 37 | 38 | class NERDeepWordBugGao2018(SeqAttackRecipe): 39 | """ 40 | Gao, Lanchantin, Soffa, Qi. 41 | 42 | Black-box Generation of Adversarial Text Sequences to Evade Deep 43 | Learning Classifiers. 44 | 45 | https://arxiv.org/abs/1801.04354 46 | """ 47 | @staticmethod 48 | def build( 49 | model, 50 | tokenizer, 51 | dataset, 52 | goal_function_class, 53 | max_edit_distance=30, 54 | use_cache=True, 55 | query_budget=512, 56 | additional_constraints=[], 57 | use_all_transformations=True, 58 | attack_timeout=30, 59 | **kwargs): 60 | # 61 | # Swap characters out from words. Choose the best of four potential transformations. 62 | # 63 | if use_all_transformations: 64 | # We propose four similar methods: 65 | transformation = CompositeTransformation( 66 | [ 67 | # (1) Swap: Swap two adjacent letters in the word. 68 | WordSwapNeighboringCharacterSwap(), 69 | # (2) Substitution: Substitute a letter in the word with a random letter. 70 | WordSwapRandomCharacterSubstitution(), 71 | # (3) Deletion: Delete a random letter from the word. 72 | WordSwapRandomCharacterDeletion(), 73 | # (4) Insertion: Insert a random letter in the word. 74 | WordSwapRandomCharacterInsertion(), 75 | ] 76 | ) 77 | else: 78 | # We use the Combined Score and the Substitution Transformer to generate 79 | # adversarial samples, with the maximum edit distance difference of 30 80 | # (ϵ = 30). 81 | transformation = WordSwapRandomCharacterSubstitution() 82 | # 83 | # Don't modify the same word twice or stopwords 84 | # 85 | constraints = [ 86 | RepeatModification(), 87 | StopwordModification(), 88 | SkipNegations() 89 | ] 90 | # 91 | # In these experiments, we hold the maximum difference 92 | # on edit distance (ϵ) to a constant 93 | # 94 | constraints.append(LevenshteinEditDistance(max_edit_distance)) 95 | # Add extra constraints 96 | constraints.extend(additional_constraints) 97 | # 98 | # Goal is untargeted classification in the original paper 99 | # 100 | goal_function = goal_function_class( 101 | model, 102 | tokenizer=tokenizer, 103 | use_cache=use_cache, 104 | query_budget=query_budget, 105 | ner_postprocess_func=postprocess_ner_output, 106 | label_names=dataset.label_names) 107 | # 108 | # Greedily swap words with "Word Importance Ranking". 
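        # (With wir_method="unk", word importance is estimated by replacing
        # each word in turn with an "unknown" placeholder token and measuring
        # the resulting change in the goal-function score; words are then
        # perturbed in decreasing order of importance.)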
109 | # 110 | search_method = NERGreedyWordSwapWIR( 111 | dataset=dataset, 112 | tokenizer=tokenizer, 113 | wir_method="unk" 114 | ) 115 | 116 | return NERAttack( 117 | goal_function, 118 | constraints, 119 | transformation, 120 | search_method, 121 | attack_timeout=attack_timeout) 122 | 123 | @staticmethod 124 | def get_ner_constraints(model_name, **kwargs): 125 | preserve_named_entities = kwargs.get("preserve_named_entities", False) 126 | 127 | constraints_model_wrapper = NERModelWrapper.load_huggingface_model( 128 | model_name=model_name 129 | )[1] 130 | 131 | constraints = [ 132 | SkipModelErrors(model_wrapper=constraints_model_wrapper) 133 | ] 134 | 135 | if preserve_named_entities: 136 | constraints.append(AvoidNamedEntityConstraint(ner_model_wrapper=constraints_model_wrapper)) 137 | constraints.append(NonNamedEntityConstraint()) 138 | 139 | return constraints 140 | -------------------------------------------------------------------------------- /seqattack/attacks/morpheus.py: -------------------------------------------------------------------------------- 1 | """ 2 | MORPHEUS2020 3 | =============== 4 | (It’s Morphin’ Time! Combating Linguistic Discrimination with Inflectional Perturbations) 5 | """ 6 | from textattack.constraints.pre_transformation import ( 7 | RepeatModification, 8 | StopwordModification, 9 | ) 10 | 11 | from textattack.shared.attack import Attack 12 | from textattack.attack_recipes import AttackRecipe 13 | from textattack.transformations import WordSwapInflections 14 | 15 | from seqattack.constraints import NonNamedEntityConstraint, SkipNegations 16 | from seqattack.goal_functions import UntargetedNERGoalFunction 17 | 18 | from seqattack.search import GreedySearchNER 19 | from seqattack.utils import postprocess_ner_output 20 | 21 | from seqattack.utils.attack import NERAttack 22 | from .seqattack_recipe import SeqAttackRecipe 23 | 24 | 25 | class MorpheusTan2020NER(SeqAttackRecipe): 26 | """ 27 | Samson Tan, Shafiq Joty, Min-Yen Kan, Richard Socher. 28 | It’s Morphin’ Time! Combating Linguistic Discrimination with Inflectional Perturbations 29 | https://www.aclweb.org/anthology/2020.acl-main.263/ 30 | """ 31 | @staticmethod 32 | def build( 33 | model, 34 | tokenizer, 35 | dataset, 36 | goal_function_class, 37 | use_cache=True, 38 | query_budget=512, 39 | additional_constraints=[], 40 | attack_timeout=30, 41 | **kwargs): 42 | goal_function = goal_function_class( 43 | model, 44 | tokenizer, 45 | use_cache=use_cache, 46 | query_budget=query_budget, 47 | ner_postprocess_func=postprocess_ner_output, 48 | label_names=dataset.label_names) 49 | 50 | # Swap words with their inflections 51 | transformation = WordSwapInflections() 52 | 53 | # The POS mapping has some compatibility issues with the POS 54 | # output of AttackedText(s). Add these mappings to patch the 55 | # issue 56 | transformation._enptb_to_universal["ADJ"] = "ADJ" 57 | transformation._enptb_to_universal["NOUN"] = "NOUN" 58 | transformation._enptb_to_universal["VERB"] = "VERB" 59 | 60 | constraints = [ 61 | # Do not modify already changed words 62 | RepeatModification(), 63 | # Do not modify stopwords 64 | StopwordModification(), 65 | SkipNegations()] 66 | 67 | constraints.extend(additional_constraints) 68 | 69 | # Greedily swap words (see pseudocode, Algorithm 1 of the paper). 
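        # (Roughly: GreedySearchNER keeps, at each iteration, the candidate
        # inflection swap that most improves the goal-function score, and
        # stops once the goal is met or no candidate improves the score.)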
70 | search_method = GreedySearchNER() 71 | 72 | return NERAttack( 73 | goal_function, 74 | constraints, 75 | transformation, 76 | search_method, 77 | attack_timeout=attack_timeout) 78 | -------------------------------------------------------------------------------- /seqattack/attacks/ner_clare.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from textattack.shared.attack import Attack 4 | from textattack.attack_recipes import AttackRecipe 5 | from textattack.transformations import WordSwapMaskedLM, CompositeTransformation 6 | from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder 7 | from textattack.constraints.pre_transformation import RepeatModification, StopwordModification 8 | 9 | from seqattack.utils import postprocess_ner_output 10 | from seqattack.transformations import RoBERTaWordInsertionMaskedLM 11 | from seqattack.constraints import NonNamedEntityConstraint, SkipNonASCII, SkipNegations 12 | from seqattack.search import GreedySearchNER 13 | from seqattack.utils.attack import NERAttack 14 | from .seqattack_recipe import SeqAttackRecipe 15 | 16 | 17 | class NERCLARE(SeqAttackRecipe): 18 | """ 19 | Li, Zhang, Peng, Chen, Brockett, Sun, Dolan. 20 | 21 | "Contextualized Perturbation for Textual Adversarial Attack" (Li et al., 2020) 22 | 23 | https://arxiv.org/abs/2009.07502 24 | 25 | This method uses greedy search with replace and insertion transformations that leverage a 26 | pretrained language model. It also uses a Universal Sentence Encoder (USE) similarity constraint. 27 | """ 28 | @staticmethod 29 | def build( 30 | model, 31 | tokenizer, 32 | dataset, 33 | goal_function_class, 34 | max_candidates=50, 35 | additional_constraints=[], 36 | query_budget=2500, 37 | use_cache=False, 38 | attack_timeout=30, 39 | **kwargs): 40 | shared_masked_lm = transformers.AutoModelForCausalLM.from_pretrained("distilroberta-base") 41 | shared_tokenizer = transformers.AutoTokenizer.from_pretrained("distilroberta-base") 42 | 43 | transformation = CompositeTransformation( 44 | [ 45 | WordSwapMaskedLM( 46 | method="bae", 47 | masked_language_model=shared_masked_lm, 48 | tokenizer=shared_tokenizer, 49 | max_candidates=max_candidates, 50 | min_confidence=5e-4, 51 | ), 52 | RoBERTaWordInsertionMaskedLM( 53 | masked_language_model=shared_masked_lm, 54 | tokenizer=shared_tokenizer, 55 | max_candidates=max_candidates, 56 | min_confidence=0.0, 57 | ) 58 | ] 59 | ) 60 | 61 | constraints = [ 62 | # Do not modify already changed words 63 | RepeatModification(), 64 | # Do not modify stopwords 65 | StopwordModification(), 66 | SkipNonASCII(), 67 | SkipNegations() 68 | ] 69 | 70 | constraints.extend(additional_constraints) 71 | 72 | use_constraint = UniversalSentenceEncoder( 73 | threshold=0.7, 74 | metric="cosine", 75 | compare_against_original=True, 76 | # The original implementation uses a window of 15 and skips 77 | # samples shorter than that window 78 | window_size=None) 79 | 80 | constraints.append(use_constraint) 81 | 82 | goal_function = goal_function_class( 83 | model, 84 | tokenizer=tokenizer, 85 | use_cache=use_cache, 86 | query_budget=query_budget, 87 | ner_postprocess_func=postprocess_ner_output, 88 | label_names=dataset.label_names) 89 | 90 | search_method = GreedySearchNER() 91 | 92 | return NERAttack( 93 | goal_function, 94 | constraints, 95 | transformation, 96 | search_method, 97 | attack_timeout=attack_timeout) 98 | -------------------------------------------------------------------------------- /seqattack/attacks/scpn.py:
-------------------------------------------------------------------------------- 1 | """ 2 | SCPN (Syntactically Controlled Paraphrase Networks) 3 | ============================================ 4 | 5 | """ 6 | from seqattack.constraints import SkipNonASCII, SkipNegations 7 | from seqattack.search import GreedySearchNER 8 | from seqattack.utils import postprocess_ner_output 9 | 10 | from textattack.attack_recipes import AttackRecipe 11 | from seqattack.utils.attack import NERAttack 12 | from seqattack.transformations import ParaphraseTransformation 13 | from .seqattack_recipe import SeqAttackRecipe 14 | 15 | 16 | class NERSCPNParaphrase(SeqAttackRecipe): 17 | """ 18 | Adversarial Example Generation with Syntactically Controlled 19 | Paraphrase Networks. 20 | 21 | Mohit Iyyer, John Wieting, Kevin Gimpel, Luke Zettlemoyer. 22 | NAACL-HLT 2018. 23 | 24 | `[pdf] `__ 25 | `[code] `__ 26 | """ 27 | @staticmethod 28 | def build( 29 | model, 30 | tokenizer, 31 | dataset, 32 | goal_function_class, 33 | additional_constraints=[], 34 | query_budget=2500, 35 | use_cache=False, 36 | attack_timeout=30, 37 | **kwargs): 38 | transformation = ParaphraseTransformation() 39 | 40 | constraints = [ 41 | # Skip samples with non-ASCII characters and negations 42 | SkipNonASCII(), 43 | SkipNegations(), 44 | ] 45 | 46 | # Add user-provided constraints 47 | constraints.extend(additional_constraints) 48 | 49 | # FIXME: we might want to limit this to the 50 | # strict untargeted / targeted goal functions 51 | # since I <-> B swaps are easy to generate 52 | goal_function = goal_function_class( 53 | model, 54 | tokenizer=tokenizer, 55 | use_cache=use_cache, 56 | query_budget=query_budget, 57 | ner_postprocess_func=postprocess_ner_output, 58 | label_names=dataset.label_names) 59 | 60 | search_method = GreedySearchNER() 61 | 62 | return NERAttack( 63 | goal_function, 64 | constraints, 65 | transformation, 66 | search_method, 67 | attack_timeout=attack_timeout) 68 | 69 | @staticmethod 70 | def get_ner_constraints(model_name, **kwargs): 71 | return [] 72 | -------------------------------------------------------------------------------- /seqattack/attacks/seqattack_recipe.py: -------------------------------------------------------------------------------- 1 | from seqattack.models import NERModelWrapper 2 | from textattack.attack_recipes import AttackRecipe 3 | from seqattack.constraints import ( 4 | SkipModelErrors, 5 | AvoidNamedEntityConstraint, 6 | NonNamedEntityConstraint 7 | ) 8 | 9 | 10 | class SeqAttackRecipe(AttackRecipe): 11 | @staticmethod 12 | def get_ner_constraints(model_name, **kwargs): 13 | constraints_model_wrapper = NERModelWrapper.load_huggingface_model( 14 | model_name=model_name 15 | )[1] 16 | 17 | return [ 18 | SkipModelErrors(model_wrapper=constraints_model_wrapper), 19 | AvoidNamedEntityConstraint(ner_model_wrapper=constraints_model_wrapper), 20 | # Avoid modifying ground truth named entities 21 | NonNamedEntityConstraint() 22 | ] 23 | -------------------------------------------------------------------------------- /seqattack/attacks/textfooler.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | TextFooler (Is BERT Really Robust?)
4 | =================================================== 5 | A Strong Baseline for Natural Language Attack on Text Classification and Entailment) 6 | 7 | """ 8 | 9 | from textattack.constraints.grammaticality import PartOfSpeech 10 | from textattack.constraints.pre_transformation import ( 11 | InputColumnModification, 12 | RepeatModification, 13 | StopwordModification, 14 | ) 15 | 16 | from textattack.constraints.semantics import WordEmbeddingDistance 17 | from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder 18 | from textattack.shared.attack import Attack 19 | from textattack.transformations import WordSwapEmbedding 20 | 21 | from textattack.attack_recipes import AttackRecipe 22 | 23 | from seqattack.search import NERGreedyWordSwapWIR 24 | from seqattack.utils import postprocess_ner_output 25 | from seqattack.constraints import NonNamedEntityConstraint, SkipNonASCII, \ 26 | SkipNegations 27 | 28 | from seqattack.utils.attack import NERAttack 29 | from .seqattack_recipe import SeqAttackRecipe 30 | 31 | 32 | class NERTextFoolerJin2019(SeqAttackRecipe): 33 | """ 34 | Jin, D., Jin, Z., Zhou, J.T., & Szolovits, P. (2019). 35 | 36 | Is BERT Really Robust? Natural Language Attack on Text Classification and Entailment. 37 | 38 | https://arxiv.org/abs/1907.11932 39 | """ 40 | 41 | @staticmethod 42 | def build( 43 | model, 44 | tokenizer, 45 | dataset, 46 | goal_function_class, 47 | max_candidates=50, 48 | additional_constraints=[], 49 | query_budget=2500, 50 | use_cache=False, 51 | attack_timeout=30, 52 | **kwargs): 53 | # 54 | # Swap words with their 50 closest embedding nearest-neighbors. 55 | # Embedding: Counter-fitted PARAGRAM-SL999 vectors. 56 | # 57 | transformation = WordSwapEmbedding(max_candidates=max_candidates) 58 | # 59 | # Don't modify the same word twice or the stopwords defined 60 | # in the TextFooler public implementation. 
61 | # 62 | # fmt: off 63 | stopwords = set( 64 | ["a", "about", "above", "across", "after", "afterwards", "again", "against", "ain", "all", "almost", "alone", "along", "already", "also", "although", "am", "among", "amongst", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", "aren", "aren't", "around", "as", "at", "back", "been", "before", "beforehand", "behind", "being", "below", "beside", "besides", "between", "beyond", "both", "but", "by", "can", "cannot", "could", "couldn", "couldn't", "d", "didn", "didn't", "doesn", "doesn't", "don", "don't", "down", "due", "during", "either", "else", "elsewhere", "empty", "enough", "even", "ever", "everyone", "everything", "everywhere", "except", "first", "for", "former", "formerly", "from", "hadn", "hadn't", "hasn", "hasn't", "haven", "haven't", "he", "hence", "her", "here", "hereafter", "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", "how", "however", "hundred", "i", "if", "in", "indeed", "into", "is", "isn", "isn't", "it", "it's", "its", "itself", "just", "latter", "latterly", "least", "ll", "may", "me", "meanwhile", "mightn", "mightn't", "mine", "more", "moreover", "most", "mostly", "must", "mustn", "mustn't", "my", "myself", "namely", "needn", "needn't", "neither", "never", "nevertheless", "next", "no", "nobody", "none", "noone", "nor", "not", "nothing", "now", "nowhere", "o", "of", "off", "on", "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", "ours", "ourselves", "out", "over", "per", "please", "s", "same", "shan", "shan't", "she", "she's", "should've", "shouldn", "shouldn't", "somehow", "something", "sometime", "somewhere", "such", "t", "than", "that", "that'll", "the", "their", "theirs", "them", "themselves", "then", "thence", "there", "thereafter", "thereby", "therefore", "therein", "thereupon", "these", "they", "this", "those", "through", "throughout", "thru", "thus", "to", "too", "toward", "towards", "under", "unless", "until", "up", "upon", "used", "ve", "was", "wasn", "wasn't", "we", "were", "weren", "weren't", "what", "whatever", "when", "whence", "whenever", "where", "whereafter", "whereas", "whereby", "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", "who", "whoever", "whole", "whom", "whose", "why", "with", "within", "without", "won", "won't", "would", "wouldn", "wouldn't", "y", "yet", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves"] 65 | ) 66 | # fmt: on 67 | constraints = [ 68 | RepeatModification(), 69 | StopwordModification(stopwords=stopwords), 70 | SkipNonASCII(), 71 | SkipNegations()] 72 | 73 | # Minimum word embedding cosine similarity of 0.5. 74 | # (The paper claims 0.7, but analysis of the released code and some empirical 75 | # results show that it's 0.5.) 76 | # 77 | constraints.append(WordEmbeddingDistance(min_cos_sim=0.5)) 78 | # 79 | # Only replace words with the same part of speech (or nouns with verbs) 80 | # 81 | constraints.append(PartOfSpeech(allow_verb_noun_swap=True)) 82 | # 83 | # Universal Sentence Encoder with a minimum angular similarity of ε = 0.5. 84 | # 85 | # In the TextFooler code, they forget to divide the angle between the two 86 | # embeddings by pi. 
So if the original threshold was that 1 - sim >= 0.5, the 87 | # new threshold is 1 - (0.5) / pi = 0.840845057 88 | # 89 | use_constraint = UniversalSentenceEncoder( 90 | threshold=0.840845057, 91 | metric="angular", 92 | compare_against_original=False, 93 | window_size=15, 94 | skip_text_shorter_than_window=True, 95 | ) 96 | 97 | constraints.append(use_constraint) 98 | constraints.extend(additional_constraints) 99 | 100 | # 101 | # Goal is untargeted classification 102 | # 103 | goal_function = goal_function_class( 104 | model, 105 | tokenizer=tokenizer, 106 | use_cache=use_cache, 107 | query_budget=query_budget, 108 | ner_postprocess_func=postprocess_ner_output, 109 | label_names=dataset.label_names) 110 | 111 | # 112 | # Greedily swap words with "Word Importance Ranking". 113 | # 114 | search_method = NERGreedyWordSwapWIR( 115 | dataset=dataset, 116 | tokenizer=tokenizer, 117 | wir_method="delete") 118 | 119 | return NERAttack( 120 | goal_function, 121 | constraints, 122 | transformation, 123 | search_method, 124 | attack_timeout=attack_timeout) 125 | -------------------------------------------------------------------------------- /seqattack/constraints/__init__.py: -------------------------------------------------------------------------------- 1 | from .ner import NonNamedEntityConstraint 2 | from .avoid_named_entities import AvoidNamedEntityConstraint 3 | from .skip_non_ascii import SkipNonASCII 4 | from .model_errors import SkipModelErrors 5 | from .skip_negations import SkipNegations 6 | -------------------------------------------------------------------------------- /seqattack/constraints/avoid_named_entities.py: -------------------------------------------------------------------------------- 1 | from textattack.constraints import Constraint 2 | 3 | from seqattack.models import NERModelWrapper 4 | from seqattack.datasets import NERDataset 5 | 6 | from textattack.shared.attacked_text import AttackedText 7 | 8 | from seqattack.utils import elements_diff 9 | 10 | 11 | class AvoidNamedEntityConstraint(Constraint): 12 | """ 13 | This constraint makes sure that the altered words are 14 | not recognized as named entities 15 | """ 16 | def __init__(self, ner_model_wrapper): 17 | super().__init__(compare_against_original=False) 18 | self.ner_model_wrapper = ner_model_wrapper 19 | 20 | def _check_constraint(self, transformed_text, original_text): 21 | # Predict named entities for the original and transformed text 22 | insertion_index = transformed_text._ground_truth_inserted_index( 23 | transformed_text 24 | ) 25 | 26 | transformed_ground_truth = transformed_text.attack_attrs["ground_truth"] 27 | 28 | transformed_preds = self.ner_model_wrapper([ 29 | transformed_text.text 30 | ])[0] 31 | 32 | transformed_preds = self.ner_model_wrapper.process_raw_output( 33 | transformed_preds, 34 | transformed_text.text 35 | ).tolist() 36 | 37 | if insertion_index is not None: 38 | return transformed_preds[insertion_index] == 0 39 | else: 40 | diff_indices = transformed_text.all_words_diff(original_text) 41 | 42 | for idx in list(diff_indices): 43 | if transformed_preds[idx] > 0: 44 | # Introduced entity via entity insertion (e.g. phone --> Belgium) 45 | return False 46 | elif transformed_preds[idx] == 0 and transformed_ground_truth[idx] > 0: 47 | # Removed entity (e.g. 
Belgium --> phone) 48 | return False 49 | 50 | return True 51 | -------------------------------------------------------------------------------- /seqattack/constraints/model_errors.py: -------------------------------------------------------------------------------- 1 | from textattack.constraints import Constraint 2 | from seqattack.models.exceptions import PredictionError 3 | 4 | 5 | class SkipModelErrors(Constraint): 6 | """ 7 | A constraint that rejects texts which cause an error in the model 8 | inference (e.g. texts that have more tokens than output predictions) 9 | """ 10 | def __init__(self, model_wrapper): 11 | super().__init__(compare_against_original=True) 12 | self.model_wrapper = model_wrapper 13 | 14 | def _check_constraint(self, transformed_text, original_text): 15 | try: 16 | _ = self.model_wrapper([ 17 | transformed_text.text 18 | ], raise_excs=True) 19 | 20 | return True 21 | except PredictionError as ex: 22 | print(f"Rejected attacked text '{transformed_text.text}' due to prediction errors: {ex}") 23 | 24 | return False 25 | -------------------------------------------------------------------------------- /seqattack/constraints/ner.py: -------------------------------------------------------------------------------- 1 | from textattack.constraints import Constraint 2 | 3 | from seqattack.models import NERModelWrapper 4 | from seqattack.datasets import NERDataset 5 | 6 | 7 | class NonNamedEntityConstraint(Constraint): 8 | """ 9 | A constraint that prevents named entities from 10 | being replaced 11 | """ 12 | def __init__(self): 13 | super().__init__(compare_against_original=True) 14 | 15 | def _check_constraint(self, transformed_text, original_text): 16 | transformed_entities = self._entity_tokens(transformed_text) 17 | original_entities = self._entity_tokens(original_text) 18 | 19 | # Make sure the text contains the same named 20 | # entities in the same order 21 | return original_entities == transformed_entities 22 | 23 | def _entity_tokens(self, attacked_text): 24 | text_entities = [] 25 | 26 | tokens = attacked_text.text.split(" ") 27 | ground_truth = attacked_text.attack_attrs["ground_truth"] 28 | 29 | for token, label in zip(tokens, ground_truth): 30 | if label > 0: 31 | text_entities.append(token) 32 | 33 | return text_entities 34 | -------------------------------------------------------------------------------- /seqattack/constraints/skip_negations.py: -------------------------------------------------------------------------------- 1 | from textattack.constraints import Constraint 2 | 3 | 4 | class SkipNegations(Constraint): 5 | """ 6 | A constraint that rejects texts that contain negations, 7 | namely "do not", "does not" or any occurrence of "n't" 8 | """ 9 | def __init__(self): 10 | super().__init__(compare_against_original=True) 11 | 12 | def _check_constraint(self, transformed_text, original_text): 13 | targets = [ 14 | "do not", 15 | "n't", 16 | "does not" 17 | ] 18 | 19 | for target in targets: 20 | if target in transformed_text.text: 21 | return False 22 | 23 | return True 24 | -------------------------------------------------------------------------------- /seqattack/constraints/skip_non_ascii.py: -------------------------------------------------------------------------------- 1 | from textattack.constraints import Constraint 2 | from seqattack.utils import is_ascii 3 | 4 | 5 | class SkipNonASCII(Constraint): 6 | """ 7 | A constraint that rejects texts with non-ASCII characters 8 | """ 9 | def __init__(self): 10 |
super().__init__(compare_against_original=True) 11 | 12 | def _check_constraint(self, transformed_text, original_text): 13 | return is_ascii(transformed_text.text) 14 | -------------------------------------------------------------------------------- /seqattack/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ner import NERDataset 2 | from .huggingfacener import NERHuggingFaceDataset 3 | -------------------------------------------------------------------------------- /seqattack/datasets/huggingfacener.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from datasets import load_dataset 5 | 6 | from .ner import NERDataset 7 | 8 | 9 | class NERHuggingFaceDataset(NERDataset): 10 | """ 11 | Dataset wrapper to load named entity recognition datasets 12 | from https://huggingface.co/datasets 13 | """ 14 | def __init__( 15 | self, 16 | name: str, 17 | split: str, 18 | label_names: list, 19 | num_examples: int = None, 20 | labels_map: dict = None): 21 | """Loads a HuggingFace dataset, downloading it if needed. 22 | 23 | Parameters 24 | ---------- 25 | 26 | name: the dataset name 27 | split: the dataset split (e.g. train, test) 28 | label_names: list that maps prediction to labels (e.g. ["O", "B-ORG", ...]) 29 | num_examples: the maximum number of examples to be loaded 30 | labels_map: a dictionary (e.g. {0: 2, 1: 3, ...}) that remaps the ground truth, 31 | useful to align the model output and the dataset 32 | """ 33 | if os.path.isfile(name): 34 | dataset = json.loads(open(name).read()) 35 | 36 | dataset_tokens, dataset_ner_tags = zip(*dataset["samples"]) 37 | dataset_tokens = [sample.split(" ") for sample in dataset_tokens] 38 | else: 39 | dataset = load_dataset(name, None, split=split) 40 | 41 | dataset_ner_tags = dataset["ner_tags"] 42 | dataset_tokens = dataset["tokens"] 43 | 44 | if labels_map: 45 | examples_ner_tags = [ 46 | [labels_map[tag] for tag in tags] 47 | for tags in dataset_ner_tags 48 | ] 49 | else: 50 | examples_ner_tags = dataset_ner_tags 51 | 52 | dataset = list(zip( 53 | [" ".join(x) for x in dataset_tokens], 54 | examples_ner_tags 55 | )) 56 | 57 | self.dataset_split = split 58 | 59 | super().__init__( 60 | dataset, 61 | label_names, 62 | num_examples=num_examples, 63 | dataset_name=name 64 | ) 65 | 66 | @staticmethod 67 | def from_config_file(path, num_examples=None): 68 | with open(path) as config_file: 69 | config = json.loads(config_file.read()) 70 | 71 | labels_map = config["labels_map"] 72 | 73 | if labels_map is not None: 74 | # Convert keys to integers if a labels_map was provided 75 | labels_map = {int(k) : int(v) for k, v in labels_map.items()} 76 | 77 | return NERHuggingFaceDataset( 78 | name=config["name"], 79 | split=config["split"], 80 | label_names=config["labels"], 81 | num_examples=num_examples, 82 | labels_map=labels_map 83 | ) 84 | 85 | @property 86 | def split(self): 87 | """The dataset split if any""" 88 | return self.dataset_split 89 | -------------------------------------------------------------------------------- /seqattack/datasets/ner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | 5 | class NERDataset: 6 | def __init__( 7 | self, 8 | dataset, 9 | label_names: list, 10 | num_examples: int = None, 11 | dataset_name: str = None): 12 | """ 13 | Initializes a new named entity recognition dataset. 
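For example, with a hypothetical sentence and CoNLL-style tags:

    >>> ds = NERDataset(
    ...     [("EU rejects German call", [1, 0, 2, 0])],
    ...     label_names=["O", "B-ORG", "B-MISC"])
    >>> len(ds)
    1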
14 | 15 | The dataset is given as a list of tuples in the form 16 | (text, ner_tags), where ner_tags is an array of class 17 | labels (e.g. [0, 0, 2, ...]). 18 | """ 19 | if num_examples: 20 | self.dataset = dataset[:num_examples] 21 | else: 22 | self.dataset = dataset 23 | 24 | self.dataset_name = dataset_name 25 | self.current_idx = 0 26 | self.label_names = label_names 27 | 28 | def __iter__(self): 29 | return self 30 | 31 | def __next__(self): 32 | if self.current_idx < len(self.dataset): 33 | self.current_idx += 1 34 | 35 | return self.dataset[self.current_idx - 1] 36 | else: 37 | raise StopIteration 38 | 39 | def __getitem__(self, i): 40 | return self.dataset[i] 41 | 42 | def __len__(self): 43 | return len(self.dataset) 44 | 45 | def to_dict(self): 46 | predictions_mappings = {} 47 | 48 | for sentence, prediction in self.dataset: 49 | predictions_mappings[sentence] = prediction 50 | 51 | return predictions_mappings 52 | 53 | def filter(self, filter_function): 54 | self.dataset = list(filter(filter_function, self.dataset)) 55 | 56 | def shuffle(self): 57 | random.shuffle(self.dataset) 58 | 59 | def tolist(self): 60 | return copy.deepcopy(self.dataset) 61 | 62 | @property 63 | def name(self): 64 | """The dataset name if any""" 65 | return self.dataset_name 66 | 67 | @property 68 | def split(self): 69 | """The dataset split if any""" 70 | return None 71 | -------------------------------------------------------------------------------- /seqattack/goal_functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .ner_goal_function_result import NERGoalFunctionResult 2 | from .untargeted_ner import UntargetedNERGoalFunction 3 | from .targeted_ner import TargetedNERGoalFunction 4 | from .untargeted_ner_strict import StrictUntargetedNERGoalFunction 5 | 6 | 7 | def get_goal_function(goal_function): 8 | goal_functions = { 9 | "untargeted": UntargetedNERGoalFunction, 10 | "untargeted-strict": StrictUntargetedNERGoalFunction, 11 | "targeted": TargetedNERGoalFunction, 12 | } 13 | 14 | if goal_function not in goal_functions.keys(): 15 | raise ValueError(f"Invalid goal function {goal_function}. Valid values are {list(goal_functions.keys())}") 16 | 17 | return goal_functions[goal_function] 18 | -------------------------------------------------------------------------------- /seqattack/goal_functions/ner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .ner_goal_function_result import NERGoalFunctionResult 5 | 6 | from textattack.goal_functions.goal_function import GoalFunction 7 | 8 | from seqattack.utils import get_tokens, tensor_mask 9 | 10 | 11 | class NERGoalFunction(GoalFunction): 12 | """ 13 | Base goal function that determines whether an attack on 14 | named entity recognition was successful 15 | """ 16 | def _process_model_outputs(self, inputs, scores): 17 | """ 18 | Processes and validates a list of model outputs.
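The NER-specific per-token post-processing happens later (see
_preprocess_model_output below), so the scores are returned unchanged.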
19 | """ 20 | return scores 21 | 22 | def __init__( 23 | self, 24 | model_wrapper, 25 | tokenizer, 26 | maximizable=False, 27 | use_cache=True, 28 | query_budget=float("inf"), 29 | model_cache_size=2 ** 20, 30 | min_percent_entities_mispredicted=0.5, 31 | ner_postprocess_func=None, 32 | # Array of labels 33 | label_names=None 34 | ): 35 | super().__init__( 36 | model_wrapper, 37 | maximizable=maximizable, 38 | use_cache=use_cache, 39 | query_budget=query_budget, 40 | model_cache_size=model_cache_size) 41 | 42 | assert ner_postprocess_func is not None, "A post-processing function is required!" 43 | 44 | self.ner_postprocess_func = ner_postprocess_func 45 | self.tokenizer = tokenizer 46 | self.min_percent_entities_mispredicted = min_percent_entities_mispredicted 47 | self.label_names = label_names 48 | 49 | def _goal_function_result_type(self): 50 | """Returns the class of this goal function's results.""" 51 | return self._create_goal_result 52 | 53 | def _create_goal_result( 54 | self, 55 | attacked_text, 56 | raw_output, 57 | displayed_output, 58 | goal_status, 59 | goal_function_score, 60 | num_queries, 61 | ground_truth_output): 62 | """ 63 | Utility function that creates a NER Goal function result 64 | with both the raw and processed model outputs 65 | """ 66 | formatted_preds, _, _, _ = self._preprocess_model_output( 67 | raw_output, 68 | attacked_text) 69 | 70 | result = NERGoalFunctionResult( 71 | attacked_text, 72 | formatted_preds, 73 | displayed_output, 74 | goal_status, 75 | goal_function_score, 76 | self.num_queries, 77 | self.ground_truth_output, 78 | raw_output 79 | ) 80 | 81 | return result 82 | 83 | def extra_repr_keys(self): 84 | return [] 85 | 86 | def _get_displayed_output(self, raw_output): 87 | return int(raw_output.argmax()) 88 | 89 | def _preprocess_model_output(self, model_output, attacked_text): 90 | """ 91 | Given a raw model output and the input sample this function 92 | returns the predictions as a list of numeric labels, a list of 93 | the confidence scores for each predicted label, a binary 94 | mask of the named entities in the ground truth and the confidence 95 | of the no-entity class 96 | """ 97 | named_entity_mask = tensor_mask( 98 | self._preprocess_ground_truth(attacked_text) 99 | ) 100 | 101 | tokenized_input = get_tokens(attacked_text.text, self.tokenizer) 102 | 103 | _, preds, confidence_scores, all_labels_confidences = self.ner_postprocess_func( 104 | attacked_text.text, 105 | model_output, 106 | tokenized_input) 107 | 108 | return preds, confidence_scores, named_entity_mask, all_labels_confidences 109 | 110 | def _preprocess_ground_truth(self, attacked_text): 111 | return torch.tensor(attacked_text.attack_attrs["ground_truth"]) 112 | 113 | def set_min_percent_entities_mispredicted(self, new_value): 114 | assert (new_value > 0 and new_value <= 1.0), "Min percent entities mispredicted should be in the interval (0, 1]" 115 | self.min_percent_entities_mispredicted = new_value 116 | 117 | def class_for_label(self, label): 118 | return label.replace("I-", "").replace("B-", "") 119 | 120 | @property 121 | def name(self): 122 | return "NER goal function" 123 | -------------------------------------------------------------------------------- /seqattack/goal_functions/ner_goal_function_result.py: -------------------------------------------------------------------------------- 1 | from textattack.shared.utils import color_from_output, color_text 2 | from textattack.goal_function_results import GoalFunctionResult 3 | 4 | 5 | class 
NERGoalFunctionResult(GoalFunctionResult): 6 | """ 7 | Represents the result of a NER goal function. 8 | """ 9 | def __init__( 10 | self, 11 | attacked_text, 12 | raw_output, 13 | output, 14 | goal_status, 15 | score, 16 | num_queries, 17 | ground_truth_output, 18 | unprocessed_raw_output 19 | ): 20 | super().__init__( 21 | attacked_text=attacked_text, 22 | raw_output=raw_output, 23 | output=output, 24 | goal_status=goal_status, 25 | score=score, 26 | num_queries=num_queries, 27 | ground_truth_output=ground_truth_output) 28 | 29 | self.unprocessed_raw_output = unprocessed_raw_output 30 | 31 | @property 32 | def _processed_output(self): 33 | """ 34 | Takes a model output (like `1`) and returns the class labeled output 35 | (like `positive`), if possible. Also returns the associated color. 36 | """ 37 | if self.attacked_text.attack_attrs.get("label_names"): 38 | token_labels = [ 39 | self.attacked_text.attack_attrs["label_names"][x] for x in self.raw_output 40 | ] 41 | 42 | token_colors = [ 43 | color_from_output(label, class_id) for (label, class_id) in zip(token_labels, self.raw_output) 44 | ] 45 | 46 | return token_labels, token_colors 47 | else: 48 | raise Exception("The dataset has no labels!") 49 | 50 | def get_text_color_input(self): 51 | """ 52 | A string representing the color this result's changed portion should 53 | be if it represents the original input. 54 | """ 55 | return "red" 56 | 57 | def get_text_color_perturbed(self): 58 | """ 59 | A string representing the color this result's changed portion should 60 | be if it represents the perturbed input. 61 | """ 62 | return "blue" 63 | 64 | def get_colored_output(self, color_method=None): 65 | """ 66 | Returns a string representation of this result's output, colored 67 | according to `color_method`. 68 | """ 69 | colored_labels = [] 70 | token_labels, token_colors = self._processed_output 71 | 72 | for label, color in zip(token_labels, token_colors): 73 | colored_output = color_text(label, color=color, method=color_method) 74 | colored_labels.append(colored_output) 75 | 76 | return ", ".join(colored_labels) 77 | -------------------------------------------------------------------------------- /seqattack/goal_functions/targeted_ner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .ner_goal_function_result import NERGoalFunctionResult 5 | 6 | from textattack.goal_functions.goal_function import GoalFunction 7 | from seqattack.utils import diff_elements_count, tensor_mask 8 | 9 | from .ner import NERGoalFunction 10 | 11 | 12 | class TargetedNERGoalFunction(NERGoalFunction): 13 | """ 14 | Goal function for NER that attempts to coerce named entities 15 | to be classified as "O" - a.k.a. 
no entity 16 | """ 17 | def _is_goal_complete(self, model_output, attacked_text): 18 | preds, _, _, _ = self._preprocess_model_output( 19 | model_output, 20 | attacked_text) 21 | 22 | return self._get_score_labels(preds, attacked_text) >= self.min_percent_entities_mispredicted 23 | 24 | def _get_score(self, model_output, attacked_text): 25 | return self._get_score_confidence( 26 | model_output, 27 | attacked_text) 28 | 29 | def _get_score_labels(self, model_output, attacked_text): 30 | mapped_ground_truth = self._preprocess_ground_truth(attacked_text) 31 | named_entities_mask = tensor_mask(mapped_ground_truth) 32 | 33 | if named_entities_mask.sum() == 0: 34 | # All entities are already "O", nothing to do here 35 | return 1 36 | 37 | total_score = 0 38 | 39 | for truth, model_out in zip(named_entities_mask, model_output): 40 | if truth > 0 and model_out == 0: 41 | total_score += 1 42 | 43 | # Return the percentage of mispredicted entities 44 | return (total_score / named_entities_mask.sum()).item() 45 | 46 | def _get_score_confidence(self, model_output, attacked_text): 47 | _, _, truth_entities_mask, confidence_scores = self._preprocess_model_output( 48 | model_output, 49 | attacked_text) 50 | 51 | if truth_entities_mask.sum() == 0: 52 | # Nothing to do here, we have no named entities 53 | return 1 54 | 55 | no_entity_confidences = [confs[0] for confs in confidence_scores] 56 | 57 | total_score = 0 58 | 59 | for is_entity, no_entity_confidence in zip(truth_entities_mask, no_entity_confidences): 60 | if is_entity == 1: 61 | total_score += no_entity_confidence 62 | 63 | return float(total_score / truth_entities_mask.sum()) 64 | 65 | @property 66 | def name(self): 67 | return "Targeted NER goal function" 68 | -------------------------------------------------------------------------------- /seqattack/goal_functions/untargeted_ner.py: -------------------------------------------------------------------------------- 1 | from seqattack.utils import diff_elements_count, tensor_mask 2 | 3 | from .ner import NERGoalFunction 4 | 5 | 6 | class UntargetedNERGoalFunction(NERGoalFunction): 7 | """ 8 | Goal function that determines whether an attack on 9 | named entity recognition was successful. 
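A hypothetical illustration of the label-based score computed by
_get_score_labels below, with two entity tokens of which one is
mispredicted:

    >>> truth = [1, 0, 0, 2]
    >>> preds = [1, 0, 0, 0]
    >>> sum(1 for t, p in zip(truth, preds) if t > 0 and t != p) / 2
    0.5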
We consider an attack 10 | to be successful when at least one named entity token is mispredicted 11 | to any other class (switching between I-CLS and B-CLS IS allowed) 12 | """ 13 | def _is_goal_complete(self, model_output, attacked_text): 14 | preds, _, _, _ = self._preprocess_model_output( 15 | model_output, 16 | attacked_text) 17 | 18 | return self._get_score_labels(preds, attacked_text) >= self.min_percent_entities_mispredicted 19 | 20 | def _get_score(self, model_output, attacked_text): 21 | return self._get_score_confidence( 22 | model_output, 23 | attacked_text) 24 | 25 | def _get_score_confidence(self, model_output, attacked_text): 26 | total_score = 0 27 | 28 | preds, _, named_entities_mask, all_labels_confidences = self._preprocess_model_output( 29 | model_output, 30 | attacked_text) 31 | 32 | mapped_ground_truth = self._preprocess_ground_truth(attacked_text) 33 | 34 | pred_token_labels = [self.label_names[x] for x in preds] 35 | truth_token_labels = [self.label_names[x] for x in mapped_ground_truth] 36 | 37 | for pred, pred_label, conf, truth, truth_label in zip(preds, pred_token_labels, all_labels_confidences, mapped_ground_truth, truth_token_labels): 38 | total_score += self._score_per_token(int(pred), pred_label, conf, int(truth), truth_label) 39 | 40 | if named_entities_mask.sum() == 0: 41 | # Always return 1 if the input sample has no 42 | # named entities (nothing to do here) 43 | return 1 44 | else: 45 | return float(total_score / named_entities_mask.sum()) 46 | 47 | def _score_per_token(self, pred, pred_label, conf, truth, truth_label): 48 | if pred == truth: 49 | return 1 - conf[pred] 50 | elif truth == 0: 51 | # A named entity was introduced. No score added 52 | return 0 53 | else: 54 | return 1 55 | 56 | def _get_score_labels(self, model_output, attacked_text): 57 | mapped_ground_truth = self._preprocess_ground_truth(attacked_text) 58 | named_entities_ground_truth = tensor_mask(mapped_ground_truth) 59 | 60 | if named_entities_ground_truth.sum() == 0: 61 | # No entities = nothing we can do here. 62 | # Return the maximum score 63 | return 1 64 | 65 | total_score = 0 66 | 67 | for truth, model_out in zip(mapped_ground_truth, model_output): 68 | if truth > 0 and truth != model_out: 69 | total_score += 1 70 | 71 | # Return the percentage of mispredicted entities 72 | return (total_score / named_entities_ground_truth.sum()).item() 73 | 74 | @property 75 | def name(self): 76 | return "Untargeted NER goal function" 77 | -------------------------------------------------------------------------------- /seqattack/goal_functions/untargeted_ner_strict.py: -------------------------------------------------------------------------------- 1 | from seqattack.utils import diff_elements_count, tensor_mask 2 | 3 | from .untargeted_ner import UntargetedNERGoalFunction 4 | 5 | 6 | class StrictUntargetedNERGoalFunction(UntargetedNERGoalFunction): 7 | """ 8 | Goal function that determines whether an attack on 9 | named entity recognition was successful. 
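A hypothetical illustration of the strict comparison, which strips the
BIO prefixes (see class_for_label) and compares the bare classes:

    >>> strip = lambda label: label.replace("I-", "").replace("B-", "")
    >>> strip("B-PER") == strip("I-PER")   # I/B swap: not counted here
    True
    >>> strip("B-PER") == strip("B-ORG")   # class change: counted
    False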
We consider an attack 10 | to be successful when at least one named entity token is mispredicted 11 | to any other class (switching between I-CLS and B-CLS is NOT allowed) 12 | """ 13 | def _score_per_token(self, pred, pred_label, conf, truth, truth_label): 14 | """ 15 | Returns 1 in case of a class swap and 1 - pred_conf in case of adherence 16 | to the ground truth 17 | """ 18 | pred_label = pred_label.replace("I-", "").replace("B-", "") 19 | truth_label = truth_label.replace("I-", "").replace("B-", "") 20 | 21 | if pred_label == truth_label: 22 | return 1 - conf[truth] 23 | else: 24 | return 1 25 | 26 | def _get_score_labels(self, model_output, attacked_text): 27 | """ 28 | Returns the percentage of mispredicted labels classes 29 | (e.g. ORG instead of PER) in the attacked text 30 | """ 31 | mapped_ground_truth = self._preprocess_ground_truth(attacked_text) 32 | named_entities_ground_truth = tensor_mask(mapped_ground_truth) 33 | 34 | if named_entities_ground_truth.sum() == 0: 35 | # No entities = nothing we can do here. 36 | # Return the maximum score 37 | return 1 38 | 39 | pred_token_labels = [self.label_names[x] for x in model_output] 40 | truth_token_labels = [self.label_names[x] for x in mapped_ground_truth] 41 | 42 | pred_token_labels = [self.class_for_label(x) for x in pred_token_labels] 43 | truth_token_labels = [self.class_for_label(x) for x in truth_token_labels] 44 | 45 | mispredicted_tokens_count = diff_elements_count( 46 | truth_token_labels, 47 | pred_token_labels 48 | ) 49 | 50 | if mispredicted_tokens_count > 0: 51 | # We have some mispredictions 52 | # Return the percentage of mispredicted entities 53 | # FIXME: maybe we can change it with recall/precision? 54 | return (mispredicted_tokens_count / named_entities_ground_truth.sum()).item() 55 | 56 | return 0 57 | 58 | @property 59 | def name(self): 60 | return "Strict untargeted NER goal function" 61 | -------------------------------------------------------------------------------- /seqattack/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ner import NERModelWrapper 2 | -------------------------------------------------------------------------------- /seqattack/models/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | class PredictionError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /seqattack/models/ner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from textattack.models.wrappers import ModelWrapper 5 | from seqattack.utils import get_tokens 6 | 7 | from seqattack.utils import postprocess_ner_output 8 | from seqattack.models.exceptions import PredictionError 9 | 10 | from transformers import ( 11 | AutoTokenizer, 12 | AutoModelForTokenClassification 13 | ) 14 | 15 | 16 | class NERModelWrapper(ModelWrapper): 17 | def __init__(self, model, tokenizer, postprocess_func, name:str = None): 18 | self.model = model 19 | self.tokenizer = tokenizer 20 | 21 | # The post-processing function must accept three arguments: 22 | # 23 | # original_text: AttackedText instance of the original text 24 | # model_output: the model outputs 25 | # tokenized_input: the tokenized original text 26 | self.postprocess_func = postprocess_func 27 | self.name = name 28 | 29 | def __call__(self, text_inputs_list, raise_excs=False): 30 | """ 31 | Get the model predictions for the input 
texts list 32 | 33 | :param raise_excs: whether to raise an exception for a 34 | prediction error 35 | """ 36 | encoded = self.encode(text_inputs_list) 37 | 38 | with torch.no_grad(): 39 | outputs = [self._predict_single(x) for x in encoded] 40 | 41 | formatted_outputs = [] 42 | 43 | for model_output, original_text in zip(outputs, text_inputs_list): 44 | tokenized_input = get_tokens(original_text, self.tokenizer) 45 | 46 | # If fewer predictions than input tokens are returned, skip the sample 47 | if len(tokenized_input) != len(model_output): 48 | error_string = f"Tokenized text and model predictions differ in length! (preds: {len(model_output)} vs tokenized: {len(tokenized_input)}) for sample: {original_text}" 49 | 50 | print("Skipping sample") 51 | 52 | if raise_excs: 53 | raise PredictionError(error_string) 54 | else: 55 | continue 56 | 57 | formatted_outputs.append(model_output) 58 | 59 | return formatted_outputs 60 | 61 | def process_raw_output(self, raw_output, text_input): 62 | """ 63 | Returns the output as a list of numeric labels 64 | """ 65 | tokenized_input = get_tokens(text_input, self.tokenizer) 66 | 67 | return self.postprocess_func( 68 | text_input, 69 | raw_output, 70 | tokenized_input)[1] 71 | 72 | def _predict_single(self, encoded_sample): 73 | outputs = self.model(encoded_sample)[0][0].cpu().numpy() 74 | return np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True) 75 | 76 | def encode(self, inputs): 77 | return [self.tokenizer.encode(x, return_tensors="pt") for x in inputs] 78 | 79 | def _tokenize(self, inputs): 80 | return [ 81 | self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(x, return_tensors="pt")) 82 | for x in inputs 83 | ] 84 | 85 | @classmethod 86 | def load_huggingface_model(cls, model_name): 87 | tokenizer = AutoTokenizer.from_pretrained(model_name) 88 | model = AutoModelForTokenClassification.from_pretrained(model_name) 89 | 90 | return tokenizer, NERModelWrapper( 91 | model, 92 | tokenizer, 93 | postprocess_func=postprocess_ner_output, 94 | name=model_name) 95 | -------------------------------------------------------------------------------- /seqattack/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .greedy_ner_swap import NERGreedyWordSwapWIR 2 | from .greedy import GreedySearchNER -------------------------------------------------------------------------------- /seqattack/search/greedy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Greedy Search (with fix) 3 | ================= 4 | """ 5 | import numpy as np 6 | 7 | from textattack.search_methods import GreedySearch 8 | from textattack.goal_function_results import GoalFunctionResultStatus 9 | 10 | 11 | class GreedySearchNER(GreedySearch): 12 | def _perform_search(self, initial_result): 13 | beam = [initial_result.attacked_text] 14 | best_result = initial_result 15 | 16 | while not best_result.goal_status == GoalFunctionResultStatus.SUCCEEDED: 17 | potential_next_beam = [] 18 | for text in beam: 19 | transformations = self.get_transformations( 20 | text, original_text=initial_result.attacked_text 21 | ) 22 | potential_next_beam += transformations 23 | 24 | if len(potential_next_beam) == 0: 25 | # If we did not find any possible perturbations, give up. 26 | return best_result 27 | 28 | results, search_over = self.get_goal_results(potential_next_beam) 29 | 30 | if results is None or len(results) == 0: 31 | # Results may be empty if the query budget was previously
In this case return the initial result 33 | return initial_result 34 | 35 | scores = np.array([r.score for r in results]) 36 | best_result = results[scores.argmax()] 37 | if search_over: 38 | return best_result 39 | 40 | # Refill the beam. This works by sorting the scores 41 | # in descending order and filling the beam from there. 42 | best_indices = (-scores).argsort()[: self.beam_width] 43 | beam = [potential_next_beam[i] for i in best_indices] 44 | 45 | return best_result 46 | -------------------------------------------------------------------------------- /seqattack/search/greedy_ner_swap.py: -------------------------------------------------------------------------------- 1 | import string 2 | import numpy as np 3 | 4 | from seqattack.datasets import NERDataset 5 | from textattack.search_methods import GreedyWordSwapWIR 6 | from seqattack.utils import get_tokens 7 | 8 | 9 | class NERGreedyWordSwapWIR(GreedyWordSwapWIR): 10 | """ 11 | This search strategy works in the same fashion as 12 | its superclass, with the addition that entity tokens 13 | (identified using the provided dataset) are not considered 14 | for the search (a.k.a. they are immutable) 15 | """ 16 | def __init__(self, dataset: NERDataset, tokenizer, wir_method="unk", max_subtokens=3): 17 | # max_subtokens should be kept under 4 to avoid a huge search space 18 | super().__init__(wir_method=wir_method) 19 | 20 | self.tokenizer = tokenizer 21 | self.dataset = dataset.to_dict() 22 | self.max_subtokens = max_subtokens 23 | 24 | def _get_index_order(self, initial_text): 25 | index_order, search_over = None, None 26 | 27 | if self.wir_method == "delete": 28 | # If the WIR method is delete make sure we do not leave 29 | # double spaces 30 | leave_one_texts = [ 31 | initial_text.delete_word_at_index(i) 32 | for i in range(len(initial_text.words)) 33 | ] 34 | 35 | for text in leave_one_texts: 36 | # Remove double spaces and trailing/leading whitespaces 37 | text.strip_remove_double_spaces() 38 | 39 | leave_one_results, search_over = self.get_goal_results( 40 | leave_one_texts 41 | ) 42 | 43 | index_scores = np.array( 44 | [result.score for result in leave_one_results] 45 | ) 46 | 47 | index_order = (-index_scores).argsort() 48 | else: 49 | # index order refers to the word index 50 | index_order, search_over = super()._get_index_order(initial_text) 51 | 52 | # Filter out words which have too many subtokens. This is required 53 | # to keep the search space reasonably small 54 | index_order = self._filter_subtokenized_words( 55 | initial_text, 56 | index_order) 57 | 58 | return index_order, search_over 59 | 60 | def _filter_subtokenized_words(self, initial_text, candidate_words_indices): 61 | """ 62 | Given an AttackedText and a list of candidate words (for 63 | replacement) filters out the words that are split by the 64 | tokenizer in more than max_subtokens sub-tokens 65 | """ 66 | candidate_words = [w for i, w in enumerate(initial_text.words) if i in candidate_words_indices] 67 | # Tokenize each word and remov the [CLS] and [SEP] start/end tokens 68 | candidate_words_tokens = [get_tokens(w, self.tokenizer)[1:-1] for w in candidate_words] 69 | 70 | filtered_indices = [] 71 | 72 | for tokens, index in zip(candidate_words_tokens, candidate_words_indices): 73 | if len(tokens) <= self.max_subtokens: 74 | filtered_indices.append(index) 75 | 76 | return filtered_indices 77 | 78 | def _get_entity_indices(self, initial_text): 79 | """ 80 | Returns a list of indices representing the entity 81 | token in a given text. 
The text must exist in the 82 | dataset this class was initialized with 83 | """ 84 | entities_indices = [] 85 | 86 | tokenized_text = initial_text.text.split(" ") 87 | ground_truth_labels = self.dataset[initial_text.text] 88 | 89 | # Keep track of how many punctuation symbols were skipped 90 | punctuation_delta = 0 91 | 92 | for i, label in enumerate(ground_truth_labels): 93 | # Select any non-zero labels, which correspond to entities 94 | if label > 0: 95 | entities_indices.append(i - punctuation_delta) 96 | elif tokenized_text[i] in string.punctuation: 97 | # Skip punctuation symbols. 98 | # FIXME: We should process tokens in the same way the 99 | # `words` property does on AttackedText instances 100 | punctuation_delta += 1 101 | 102 | return entities_indices 103 | -------------------------------------------------------------------------------- /seqattack/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | from .roberta_word_insert import RoBERTaWordInsertionMaskedLM 2 | from .paraphrase import ParaphraseTransformation 3 | -------------------------------------------------------------------------------- /seqattack/transformations/paraphrase.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from OpenAttack.attackers import SCPNAttacker 4 | 5 | from textattack.transformations import Transformation 6 | from seqattack.utils.ner_attacked_text import NERAttackedText 7 | 8 | 9 | class ParaphraseTransformation(Transformation): 10 | """ 11 | Generates paraphrases of the input sentence, 12 | re-mapping the original ground truth and making 13 | sure named entities are preserved. 14 | 15 | Based on (as implemented in OpenAttack): 16 | 17 | Adversarial Example Generation with Syntactically Controlled 18 | Paraphrase Networks. 19 | 20 | Mohit Iyyer, John Wieting, Kevin Gimpel, Luke Zettlemoyer. 21 | NAACL-HLT 2018. 22 | 23 | `[pdf] `__ 24 | `[code] `__ 25 | """ 26 | def __init__(self): 27 | self.attacker = SCPNAttacker() 28 | 29 | def _get_transformations(self, current_text, indices_to_modify): 30 | """ 31 | Returns paraphrases of the input text, mapping named 32 | entities to the new ground truth.
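Hypothetical example: for "John lives in Berlin" with ground truth
[1, 0, 0, 2], a lowercased paraphrase "in berlin , john lives" is
accepted, re-cased to "in Berlin , John lives" and remapped to the
ground truth [0, 2, 0, 1, 0].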
33 | 34 | WARNING: indices_to_modify is ignored in this transformation 35 | """ 36 | # Extract entities from the input text 37 | 38 | # FIXME: this strategy might have problems 39 | # if we have two named entities with the same 40 | # name and a different label 41 | entities = {} 42 | tokens = current_text.text.split(" ") 43 | ground_truth = current_text.attack_attrs["ground_truth"] 44 | 45 | for token, truth in zip(tokens, ground_truth): 46 | if truth == 0: 47 | continue 48 | 49 | entities[token.lower()] = { 50 | "token": token, 51 | "truth": truth 52 | } 53 | 54 | entities_set = set(entities.keys()) 55 | 56 | candidates = self.attacker.gen_paraphrase( 57 | current_text.text, 58 | self.attacker.config["templates"] 59 | ) 60 | 61 | out_texts = [] 62 | 63 | for cnd in candidates: 64 | cnd_tokens = cnd.split(" ") 65 | 66 | if not entities_set.issubset(set(cnd_tokens)): 67 | # All entity tokens must still be there 68 | continue 69 | 70 | # Sample approved, remap the ground truth 71 | final_cnd_tokens, cnd_truth = [], [] 72 | 73 | for cnd_token in cnd_tokens: 74 | if cnd_token in entities: 75 | # Label named entities in the transformed text and 76 | # preserve capitalization 77 | final_cnd_tokens.append(entities[cnd_token]["token"]) 78 | cnd_truth.append(entities[cnd_token]["truth"]) 79 | else: 80 | # All other tokens are considered as having no class 81 | final_cnd_tokens.append(cnd_token) 82 | cnd_truth.append(0) 83 | 84 | attack_attrs = copy.deepcopy(current_text.attack_attrs) 85 | attack_attrs["ground_truth"] = cnd_truth 86 | 87 | final_text = " ".join(final_cnd_tokens) 88 | 89 | out_texts.append( 90 | NERAttackedText( 91 | final_text, 92 | attack_attrs=attack_attrs 93 | ) 94 | ) 95 | 96 | return out_texts 97 | -------------------------------------------------------------------------------- /seqattack/transformations/roberta_word_insert.py: -------------------------------------------------------------------------------- 1 | from textattack.transformations import WordInsertionMaskedLM 2 | 3 | 4 | class RoBERTaWordInsertionMaskedLM(WordInsertionMaskedLM): 5 | def _get_transformations(self, current_text, indices_to_modify): 6 | indices_to_modify = list(indices_to_modify) 7 | new_words = self._get_new_words(current_text, indices_to_modify) 8 | 9 | transformed_texts = [] 10 | 11 | for i in range(len(new_words)): 12 | index_to_modify = indices_to_modify[i] 13 | word_at_index = current_text.words[index_to_modify] 14 | 15 | for word in new_words[i]: 16 | if word != word_at_index: 17 | transformed_texts.append( 18 | current_text.insert_text_before_word_index( 19 | # Fix RoBERTa BPE artifacts 20 | index_to_modify, word.replace("Ġ", "") 21 | ) 22 | ) 23 | 24 | return transformed_texts 25 | -------------------------------------------------------------------------------- /seqattack/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | 5 | from textattack.shared import AttackedText 6 | 7 | from collections import defaultdict, namedtuple 8 | 9 | from .sequence import pad_sequence 10 | 11 | 12 | def get_tokens(text, tokenizer): 13 | return tokenizer.tokenize(tokenizer.decode(tokenizer.encode(text))) 14 | 15 | 16 | def predictions_per_character(out_tokens, predictions): 17 | single_char_text = [] 18 | single_char_predictions = [] 19 | single_char_raw_output = [] 20 | 21 | for out_token, pred in zip(out_tokens, predictions): 22 | prediction_label = np.argmax(pred) 23 | prediction_confidence =
pred[prediction_label] 24 | 25 | token_chars = list(out_token) 26 | 27 | single_char_text.extend(token_chars) 28 | single_char_predictions.extend( 29 | [(prediction_label, prediction_confidence)] * len(token_chars) 30 | ) 31 | 32 | single_char_raw_output.extend(list( 33 | np.tile(pred, (len(token_chars), 1)) 34 | )) 35 | 36 | return single_char_text, single_char_predictions, single_char_raw_output 37 | 38 | 39 | def postprocess_ner_output(original_text, predictions, out_tokens): 40 | """ 41 | Postprocesses a NER model's output, fixing the issue of a model 42 | predicting labels for sub-tokens. 43 | 44 | # TODO: We should write automated tests for this 45 | """ 46 | original_tokens = original_text.split(" ") 47 | 48 | # Remove ## (BERT) BPE artifacts 49 | original_tokens = [t.replace("##", "") for t in original_tokens] 50 | out_tokens = [t.replace("##", "") for t in out_tokens] 51 | 52 | # Remove [CLS] / [SEP] tokens 53 | out_tokens = out_tokens[1:-1] 54 | predictions = predictions[1:-1] 55 | 56 | single_char_text, single_char_predictions, single_char_raw_output = predictions_per_character( 57 | out_tokens, 58 | predictions 59 | ) 60 | 61 | prediction_idx = 0 62 | current_out_token = "" 63 | 64 | # current_out_token_predictions is an array of tuples in the form 65 | # (predicted_label, confidence, token_length) 66 | current_out_token_predictions, preds_per_token, confidence_per_token = [], [], [] 67 | 68 | # For each token in the original sentence find the corresponding 69 | # token (or list of subtokens) in the output predictions 70 | for token in original_tokens: 71 | token_start_index = prediction_idx 72 | 73 | while current_out_token != token: 74 | current_out_token = f"{current_out_token}{single_char_text[prediction_idx]}" 75 | 76 | prediction_label, prediction_confidence = single_char_predictions[prediction_idx] 77 | current_out_token_predictions.append((prediction_label, prediction_confidence, 1)) 78 | 79 | prediction_idx += 1 80 | 81 | # After finding all the subtokens choose the most 82 | # likely prediction for the whole token 83 | chosen_label, _ = choose_best_ner_prediction(current_out_token_predictions) 84 | 85 | # Calculate all the label confidences 86 | label_confidences = {} 87 | 88 | for label in range(0, len(single_char_raw_output[0])): 89 | label_confidences[label] = sum([ 90 | single_char_raw_output[i][label] for i in range(token_start_index, prediction_idx) 91 | ]) / len(token) 92 | 93 | # Get the label confidence 94 | chosen_label_confidence = label_confidences[chosen_label] 95 | 96 | preds_per_token.append( 97 | (chosen_label, chosen_label_confidence, label_confidences) 98 | ) 99 | 100 | current_out_token = "" 101 | current_out_token_predictions = [] 102 | 103 | final_tokens, final_labels = [], [] 104 | final_label_confidence, final_confidence_dicts = [], [] 105 | 106 | for token, (label, label_confidence, confidences) in zip(original_tokens, preds_per_token): 107 | final_tokens.append(token) 108 | final_labels.append(label) 109 | final_label_confidence.append(label_confidence) 110 | final_confidence_dicts.append(confidences) 111 | 112 | return final_tokens, torch.tensor(final_labels), torch.tensor(final_label_confidence), final_confidence_dicts 113 | 114 | 115 | def choose_best_ner_prediction(preds_list): 116 | """ 117 | Given a list of subtoken predictions, weight them by 118 | confidence score and length and choose the most likely 119 | prediction 120 | """ 121 | # Count how many characters each class covers for this token 122 | # and weight them by model confidence
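# (Hypothetical example: subtoken predictions [(1, 0.9, 2), (0, 0.5, 3)]
# yield weights {1: 1.8, 0: 1.5}, so label 1 wins even though label 0
# covers more characters.)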
123 | class_weights = defaultdict(int) 124 | 125 | for pred in preds_list: 126 | label, confidence, length = pred 127 | 128 | # Multiply the text length by the label's confidence 129 | class_weights[int(label)] += length * confidence 130 | 131 | weights = [(k, class_weights[k]) for k in class_weights.keys()] 132 | weights = sorted(weights, key=lambda x: x[1], reverse=True) 133 | 134 | return weights[0] 135 | 136 | 137 | def tensor_mask(tensor): 138 | """ 139 | Given a 1xN tensor returns a 1xN tensor where 140 | each element is 1 if its original counterpart 141 | is > 0 and 0 otherwise 142 | """ 143 | return torch.tensor([ 144 | 1 if x > 0 else 0 for x in tensor 145 | ]) 146 | 147 | 148 | def diff_elements_count(a: torch.tensor, b: torch.tensor): 149 | """ 150 | Returns the number of different elements between 151 | tensor a and tensor b 152 | """ 153 | return len(elements_diff(a, b)) 154 | 155 | 156 | def elements_diff(a: list, b: list): 157 | indices = set() 158 | 159 | for i in range(min(len(a), len(b))): 160 | if a[i] != b[i]: 161 | indices.add(i) 162 | 163 | return indices 164 | 165 | 166 | def is_ascii(s): 167 | return all(ord(c) < 128 for c in s) 168 | -------------------------------------------------------------------------------- /seqattack/utils/attack.py: -------------------------------------------------------------------------------- 1 | from seqattack.goal_functions.ner_goal_function_result import NERGoalFunctionResult 2 | import signal 3 | import numpy as np 4 | 5 | from textattack.shared import Attack 6 | from seqattack.utils.ner_attacked_text import NERAttackedText 7 | from textattack.goal_function_results import GoalFunctionResultStatus 8 | 9 | from textattack.attack_results import ( 10 | FailedAttackResult, 11 | SkippedAttackResult, 12 | SuccessfulAttackResult, 13 | ) 14 | 15 | 16 | class NERAttack(Attack): 17 | def __init__( 18 | self, 19 | goal_function=None, 20 | constraints=[], 21 | transformation=None, 22 | search_method=None, 23 | transformation_cache_size=2**15, 24 | constraint_cache_size=2**15, 25 | max_entities_mispredicted=0.8, 26 | search_step=0.1, 27 | attack_timeout=30): 28 | super().__init__( 29 | goal_function=goal_function, 30 | constraints=constraints, 31 | transformation=transformation, 32 | search_method=search_method, 33 | transformation_cache_size=transformation_cache_size, 34 | constraint_cache_size=constraint_cache_size 35 | ) 36 | 37 | self.search_step = search_step 38 | self.attack_timeout = attack_timeout 39 | self.max_entities_mispredicted = max_entities_mispredicted 40 | 41 | def _get_transformations_uncached(self, current_text, original_text=None, **kwargs): 42 | transformed_texts = super()._get_transformations_uncached( 43 | current_text, 44 | original_text=original_text, 45 | **kwargs) 46 | 47 | # Remove multiple spaces from samples 48 | for transformed in transformed_texts: 49 | transformed.strip_remove_double_spaces() 50 | 51 | return transformed_texts 52 | 53 | def timeout_hook(self, signum, frame): 54 | raise TimeoutError("Attack time expired") 55 | 56 | def attack_dataset(self, dataset, indices=None): 57 | # FIXME: Same as superclass 58 | examples = self._get_examples_from_dataset(dataset, indices=indices) 59 | 60 | for goal_function_result in examples: 61 | if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED: 62 | yield SkippedAttackResult(goal_function_result) 63 | else: 64 | result = self.attack_one(goal_function_result) 65 | yield result 66 | 67 | def attack_one(self, initial_result): 68 | attacked_text = 
self.goal_function.initial_attacked_text
69 |         # The initial (original) misprediction score
70 |         initial_score = self.goal_function._get_score(
71 |             attacked_text.attack_attrs["model_raw"],
72 |             attacked_text)
73 | 
74 |         best_result, best_score = None, 0
75 |         # List of misprediction targets [0.1, 0.2, ...]
76 |         target_scores = np.arange(
77 |             max(initial_score, self.search_step),
78 |             self.max_entities_mispredicted + self.search_step,
79 |             self.search_step
80 |         )
81 | 
82 |         try:
83 |             # Set the timeout for attacking a single sample
84 |             signal.signal(signal.SIGALRM, self.timeout_hook)
85 |             signal.alarm(self.attack_timeout)
86 | 
87 |             for target in target_scores:
88 |                 # FIXME: To speed up the search use the current best result
89 |                 # FIXME: This code is better suited to be in a search method
90 | 
91 |                 # Check if we can reach a misprediction of target %
92 |                 if target < best_score:
93 |                     # If we already obtained a sample with a score higher
94 |                     # than the current target, skip this iteration
95 |                     continue
96 | 
97 |                 self.goal_function.min_percent_entities_mispredicted = target
98 | 
99 |                 result = super().attack_one(initial_result)
100 | 
101 |                 current_score = self.goal_function._get_score(
102 |                     result.perturbed_result.unprocessed_raw_output,
103 |                     result.perturbed_result.attacked_text
104 |                 )
105 | 
106 |                 if type(result) == SuccessfulAttackResult:
107 |                     if best_result is None or current_score > best_score:
108 |                         best_result, best_score = result, current_score
109 |                 elif type(result) == FailedAttackResult:
110 |                     if best_result is None:
111 |                         best_result, best_score = result, current_score
112 | 
113 |                     # The attack failed, nothing else we can do
114 |                     break
115 | 
116 |             # Cancel the timeout
117 |             signal.alarm(0)
118 | 
119 |             return best_result
120 |         except Exception as ex:
121 |             # FIXME: Handle timeouts etc.
122 |             print(f"Could not attack sample: {ex}")
123 | 
124 |             return FailedAttackResult(
125 |                 initial_result,
126 |                 initial_result
127 |             )
128 | 
129 |     def _get_examples_from_dataset(self, dataset, indices=None):
130 |         # FIXME: indices is currently ignored
131 |         for example, ground_truth in dataset:
132 |             model_raw, _, valid_prediction = self.__is_example_valid(example)
133 | 
134 |             attacked_text = NERAttackedText(
135 |                 example,
136 |                 attack_attrs={
137 |                     "label_names": dataset.label_names,
138 |                     "model_raw": model_raw
139 |                 },
140 |                 # FIXME: is this needed?
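                # (it does appear to be needed: NERAttackedText uses the
                # ground truth to keep labels aligned when words are inserted)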
141 |                 ground_truth=[int(x) for x in ground_truth]
142 |             )
143 | 
144 |             if not valid_prediction:
145 |                 yield NERGoalFunctionResult(
146 |                     attacked_text=attacked_text,
147 |                     raw_output=None,
148 |                     output=None,
149 |                     goal_status=GoalFunctionResultStatus.SKIPPED,
150 |                     score=0,
151 |                     num_queries=0,
152 |                     ground_truth_output=None,
153 |                     unprocessed_raw_output=None
154 |                 )
155 | 
156 |                 # Move on to the next example: without this we would
157 |                 # yield two results for the same (skipped) sample
158 |                 continue
159 | 
160 |             # If the original prediction mispredicts more entities than
161 |             # max_entities_mispredicted, then we skip the example
162 |             self.goal_function.min_percent_entities_mispredicted = self.max_entities_mispredicted
163 |             goal_function_result, _ = self.goal_function.init_attack_example(
164 |                 attacked_text,
165 |                 ground_truth
166 |             )
167 | 
168 |             yield goal_function_result
169 | 
170 |     def __is_example_valid(self, sample):
171 |         """Checks whether the model can correctly predict the sample or not"""
172 |         model_raw = self.goal_function.model([sample])[0]
173 |         model_pred = self.goal_function.model.process_raw_output(model_raw, sample)
174 | 
175 |         return model_raw, model_pred, len(model_pred) > 0
176 | 
--------------------------------------------------------------------------------
/seqattack/utils/attack_runner.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | 
4 | class AttackRunner():
5 |     def __init__(
6 |         self,
7 |         attack,
8 |         dataset,
9 |         output_filename: str,
10 |         attack_args: dict = None) -> None:
11 |         self.attack = attack
12 |         self.dataset = dataset
13 |         self.output_filename = output_filename
14 |         self.attack_args = attack_args
15 | 
16 |     def run(self):
17 |         attack_results = []
18 |         attack_iterator = self.attack.attack_dataset(self.dataset)
19 | 
20 |         print("*****************************************")
21 |         print(f"Starting attack on {self.dataset.name}")
22 |         print("*****************************************")
23 | 
24 |         for sample, ground_truth in self.dataset.tolist():
25 |             sample_labels = self.__prediction_to_labels(
26 |                 ground_truth,
27 |                 self.dataset.label_names
28 |             )
29 | 
30 |             sample_labels = " ".join(sample_labels)
31 | 
32 |             print("-----------------------------------------")
33 |             print(f"Attacking sample: {sample}")
34 |             print(f"Labels: {sample_labels}")
35 | 
36 |             result = next(attack_iterator)
37 | 
38 |             print()
39 |             print(f"Result: {self.__result_status(result)}")
40 |             print()
41 | 
42 |             perturbed_text = result.perturbed_result.attacked_text.text
43 |             perturbed_labels = self.__prediction_to_labels(
44 |                 result.perturbed_result.raw_output.tolist(),
45 |                 self.dataset.label_names
46 |             )
47 | 
48 |             perturbed_labels = " ".join(perturbed_labels)
49 | 
50 |             print(f"Perturbed sample: {perturbed_text}")
51 |             print(f"Labels: {perturbed_labels}")
52 | 
53 |             attack_results.append(self.attack_result_to_json(result))
54 |             self.save_results(attack_results)
55 | 
56 |     def attack_result_to_json(self, result):
57 |         original_text = result.original_result.attacked_text
58 |         perturbed_text = result.perturbed_result.attacked_text
59 | 
60 |         labels_map = original_text.attack_attrs.get("label_names", None)
61 | 
62 |         original_pred = result.original_result.raw_output.tolist()
63 |         perturbed_pred = result.perturbed_result.raw_output.tolist()
64 | 
65 |         try:
66 |             perturbed_words = perturbed_text.words_diff_ratio(original_text)
67 |         except (ZeroDivisionError, AssertionError):
68 |             # Assign a difference of zero in case the input text has no words
69 |             # or the two texts have a different word count
70 |             perturbed_words = 0
71 | 
72 |         data = {
73 |             "original_text": original_text.text,
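            # Note: "ground_truth" and "final_ground_truth" below may differ
            # when words were inserted, since NERAttackedText extends the
            # ground truth on insertions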
"ground_truth": original_text.attack_attrs["ground_truth"], 75 | "original_pred": original_pred, 76 | "status": self.__result_status(result), 77 | "original_score": result.original_result.score, 78 | "perturbed_score": result.perturbed_result.score, 79 | "perturbed_pred": perturbed_pred, 80 | "perturbed_text": perturbed_text.text, 81 | "final_ground_truth": perturbed_text.attack_attrs["ground_truth"], 82 | # FIXME: This might be an issue with CLARE (and insertions in general) 83 | "percent_words_perturbed": perturbed_words, 84 | "num_queries": result.num_queries 85 | } 86 | 87 | if labels_map is not None: 88 | labelled_data = { 89 | "ground_truth_labels": self.__prediction_to_labels(original_text.attack_attrs["ground_truth"], labels_map), 90 | "original_pred_labels": self.__prediction_to_labels(original_pred, labels_map), 91 | "perturbed_pred_labels": self.__prediction_to_labels(perturbed_pred, labels_map), 92 | "final_ground_truth_labels": self.__prediction_to_labels(perturbed_text.attack_attrs["ground_truth"], labels_map) 93 | } 94 | 95 | data = {**data, **labelled_data} 96 | 97 | return data 98 | 99 | def __result_status(self, result): 100 | return str(type(result)).split(".")[-1].replace("'>", "").replace("AttackResult", "") 101 | 102 | def __prediction_to_labels(self, pred, labels_map): 103 | return [labels_map[x] for x in pred] 104 | 105 | def save_results(self, attack_results): 106 | with open(self.output_filename, "w") as out_file: 107 | recipe_metadata = self.attack_args["recipe_metadata"] 108 | recipe_metadata["additional_constraints"] = [ 109 | str(const) for const in recipe_metadata["additional_constraints"] 110 | ] 111 | 112 | out_file.write(json.dumps({ 113 | "config": { 114 | "goal_function": self.attack.goal_function.name, 115 | "model": self.attack_args["model"].name, 116 | "tokenizer": str(self.attack_args["tokenizer"]), 117 | "dataset": self.attack_args["dataset"].name, 118 | "cache": self.attack_args["use_cache"], 119 | "query_budget": self.attack_args["query_budget"], 120 | "random_seed": self.attack_args["random_seed"], 121 | "examples_count": self.attack_args["num_examples"], 122 | "max_entities_mispredicted": self.attack_args["max_entities_mispredicted"], 123 | "attack_timeout": self.attack_args["attack_timeout"], 124 | "split": self.attack_args["dataset"].split, 125 | "recipe": self.attack_args["recipe"], 126 | "recipe_args": recipe_metadata 127 | }, 128 | "attacked_examples": attack_results 129 | })) 130 | -------------------------------------------------------------------------------- /seqattack/utils/ner_attacked_text.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from textattack.shared.attacked_text import AttackedText 4 | from textattack.shared.utils.strings import words_from_text 5 | 6 | from .sequence import pad_sequence 7 | 8 | 9 | class NERAttackedText(AttackedText): 10 | """ 11 | Custom attacked text class that keeps track of a sequence 12 | ground truth 13 | """ 14 | def __init__(self, text_input, attack_attrs=None, ground_truth=None): 15 | super().__init__(text_input, attack_attrs=attack_attrs) 16 | 17 | if "ground_truth" not in self.attack_attrs: 18 | self.attack_attrs["ground_truth"] = ground_truth 19 | 20 | def insert_text_after_word_index(self, index, text): 21 | return self._update_ground_truth_insert( 22 | super().insert_text_after_word_index(index, text) 23 | ) 24 | 25 | def insert_text_before_word_index(self, index, text): 26 | return self._update_ground_truth_insert( 27 | 
super().insert_text_before_word_index(index, text)
28 |         )
29 | 
30 |     def _update_ground_truth_insert(self, attacked_text):
31 |         ground_truth_idx = self._ground_truth_inserted_index(attacked_text)
32 | 
33 |         # Insert a no-entity token in the ground truth to match the modified text
34 |         updated_truth = copy.deepcopy(attacked_text.attack_attrs["ground_truth"])
35 |         updated_truth.insert(ground_truth_idx, 0)
36 | 
37 |         attacked_text.attack_attrs["ground_truth"] = updated_truth
38 | 
39 |         return attacked_text
40 | 
41 |     def _ground_truth_inserted_index(self, attacked_text):
42 |         """
43 |             Returns the index of the word inserted in the input
44 |             attacked_text, or None if no word was inserted.
45 |         """
46 |         if "previous_attacked_text" not in attacked_text.attack_attrs:
47 |             return None
48 | 
49 |         new_tokens = attacked_text.text.split(" ")
50 |         old_tokens = attacked_text.attack_attrs["previous_attacked_text"].text.split(" ")
51 | 
52 |         if len(old_tokens) == len(new_tokens):
53 |             return None
54 | 
55 |         # Pad the old tokens sequence with None values
56 |         # to match the length of new_tokens
57 |         old_tokens = pad_sequence(old_tokens, len(new_tokens), filler=None)
58 | 
59 |         for i, (new, old) in enumerate(zip(new_tokens, old_tokens)):
60 |             if new != old:
61 |                 return i
62 | 
63 |         return None
64 | 
65 |     def generate_new_attacked_text(self, new_words):
66 |         new_attacked_text = super().generate_new_attacked_text(new_words)
67 | 
68 |         return NERAttackedText(
69 |             text_input=new_attacked_text.text,
70 |             attack_attrs=new_attacked_text.attack_attrs,
71 |             ground_truth=self.attack_attrs.get("ground_truth", None))
72 | 
73 |     def strip_remove_double_spaces(self):
74 |         """
75 |             Removes double spaces in the text and strips whitespace
76 |             at both ends, since leftover whitespace may lead to
77 |             errors downstream
78 |         """
79 |         self._text_input["text"] = " ".join(self._text_input["text"].split())
--------------------------------------------------------------------------------
/seqattack/utils/sequence.py:
--------------------------------------------------------------------------------
1 | def pad_sequence(seq, target_length, filler=None):
2 |     filler_count = 0
3 | 
4 |     if len(seq) < target_length:
5 |         filler_count = target_length - len(seq)
6 | 
7 |     return seq + [filler] * filler_count
8 | 
--------------------------------------------------------------------------------
/tests/.gitignore:
--------------------------------------------------------------------------------
1 | test_playground.py
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WalterSimoncini/SeqAttack/a0673613f489f355ddef37c89f0a635c89a500e9/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fixtures.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from transformers import AutoTokenizer
4 | from seqattack.models import NERModelWrapper
5 | 
6 | from seqattack.constraints import AvoidNamedEntityConstraint
7 | 
8 | 
9 | @pytest.fixture(scope="module")
10 | def ner_model_wrapper():
11 |     return NERModelWrapper.load_huggingface_model("dslim/bert-base-NER")[1]
12 | 
13 | 
14 | @pytest.fixture(scope="module")
15 | def ner_tokenizer():
16 |     return AutoTokenizer.from_pretrained("dslim/bert-base-NER")
17 | 
18 | 
19 | @pytest.fixture(scope="module")
20 | def conll2003_labels():
21 |     return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC",
"I-LOC"] 22 | 23 | 24 | @pytest.fixture() 25 | def avoid_named_entity_constraint(ner_model_wrapper): 26 | return AvoidNamedEntityConstraint(ner_model_wrapper=ner_model_wrapper) 27 | -------------------------------------------------------------------------------- /tests/mock/attacked.json: -------------------------------------------------------------------------------- 1 | { 2 | "attacked_sample": "A former cabinet minister in Central African Republic and his son were abducted from their home and murdered in growing ethnic violence in the capital Bangui , a government minister said on Friday .", 3 | "ground_truth": "0 0 0 0 0 7 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 0 0 0 0 0 0 0 0", 4 | "model_pred": "0 0 0 0 0 7 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 0 0 0 0 0 0 0 0", 5 | "raw_pred": [ 6 | [0.999916136264801, 1.5778290617163293e-05, 8.846631317283027e-06, 9.441051588510163e-06, 6.461807060986757e-06, 1.561868702992797e-05, 1.0672491043806076e-05, 6.900195785419783e-06, 1.0180865501752123e-05], 7 | [0.9041773080825806, 0.00228398060426116, 0.0012021164875477552, 0.013236983679234982, 0.000727494596503675, 0.06810734421014786, 0.003427548101171851, 0.005992302671074867, 0.0008448997396044433], 8 | [0.9999514222145081, 4.678279765357729e-06, 4.1425823837926146e-06, 3.85561315852101e-06, 3.2641846701153554e-06, 6.175295766297495e-06, 1.5973226254573092e-05, 3.046105121029541e-06, 7.56633789933403e-06], 9 | [0.9999657869338989, 5.1076958698104136e-06, 2.2691858703183243e-06, 3.5176301480532857e-06, 1.955574816747685e-06, 6.551269962074002e-06, 7.478341558453394e-06, 3.339742306707194e-06, 4.033683580928482e-06], 10 | [0.9999601244926453, 5.063598109700251e-06, 3.115646222795476e-06, 3.4549025258456822e-06, 2.6023001282737823e-06, 6.334867066470906e-06, 1.0976898010994773e-05, 3.234114956285339e-06, 5.162110028322786e-06], 11 | [0.9999697804450989, 4.1754783524083905e-06, 3.0315134154079715e-06, 2.576266751930234e-06, 2.192906549680629e-06, 3.835527877527056e-06, 6.893359568493906e-06, 2.506602641005884e-06, 5.043055352871306e-06], 12 | [5.202610191190615e-05, 0.00017135468078777194, 1.375401461700676e-05, 0.00014729528629686683, 1.4984335393819492e-05, 0.00012779149983543903, 2.1928372007096186e-05, 0.9994305372238159, 2.0374580344650894e-05], 13 | [1.391946898365859e-05, 4.492521838983521e-05, 0.001137213665060699, 2.138494346581865e-05, 0.000659289478790015, 2.3631533622392453e-05, 0.0003058563161175698, 0.00017772126011550426, 0.9976160526275635], 14 | [1.6999927538563497e-05, 2.4924391254899092e-05, 0.00045662769116461277, 1.4867559912090655e-05, 0.0003689534787554294, 1.5234886632242706e-05, 0.00015418973634950817, 0.00010200811811955646, 0.998846173286438], 15 | [0.9999706745147705, 3.0011740363988793e-06, 3.1759877856529783e-06, 2.2335725589073263e-06, 2.449981820973335e-06, 2.733172550506424e-06, 7.931082109280396e-06, 2.1342814306990476e-06, 5.693488674296532e-06], 16 | [0.9999735355377197, 4.596516646415694e-06, 2.328107711946359e-06, 2.999307071149815e-06, 2.1010025648138253e-06, 3.016348728124285e-06, 4.588056981447153e-06, 2.5843644380074693e-06, 4.323961093177786e-06], 17 | [0.9999716877937317, 4.6143209146976005e-06, 2.302268512721639e-06, 3.7066879485792015e-06, 2.668147999429493e-06, 3.1416236652148655e-06, 4.820748472411651e-06, 2.5659821858425857e-06, 4.442285444383742e-06], 18 | [0.9999778866767883, 3.1721567665954353e-06, 2.4845355710567674e-06, 2.248778400826268e-06, 2.373937377342372e-06, 2.4434932583972113e-06, 4.266842097422341e-06, 1.9873077690135688e-06, 
3.1393926747114165e-06], 19 | [0.9999786615371704, 3.5285470403323416e-06, 2.4651872081449255e-06, 2.1686660147679504e-06, 1.993497107832809e-06, 2.507828185116523e-06, 3.953193299821578e-06, 1.8410289612802444e-06, 2.9262525913509307e-06], 20 | [0.9999731779098511, 3.980601832154207e-06, 3.2825216749188257e-06, 2.0968300304957666e-06, 2.4452615434711333e-06, 2.9039349556114757e-06, 5.4601000556431245e-06, 2.3576167222927324e-06, 4.287025603844086e-06], 21 | [0.9999780654907227, 3.625517365435371e-06, 2.0426746232260484e-06, 2.265548573632259e-06, 2.148987505279365e-06, 2.49821846409759e-06, 3.5991267850477016e-06, 2.268912794534117e-06, 3.392780854483135e-06], 22 | [0.9999735355377197, 4.133124093641527e-06, 2.2095375697972486e-06, 2.5086214918701444e-06, 2.2539700239576632e-06, 3.052437477890635e-06, 5.385711574490415e-06, 2.5849271878541913e-06, 4.282187546778005e-06], 23 | [0.9999784231185913, 3.473813876553322e-06, 2.8553818083310034e-06, 1.962768010344007e-06, 1.975832901734975e-06, 2.1513674255402293e-06, 4.301295575714903e-06, 1.6098604191938648e-06, 3.181151896569645e-06], 24 | [0.9999785423278809, 3.937171641155146e-06, 2.2272884052654263e-06, 2.277504336234415e-06, 1.8647896240508999e-06, 2.6371460535301594e-06, 3.692332256832742e-06, 1.7860758134702337e-06, 3.013156856468413e-06], 25 | [0.9999788403511047, 3.641900093498407e-06, 2.6687371246225666e-06, 2.024200057348935e-06, 1.9716244423761964e-06, 2.157462176910485e-06, 4.14594023823156e-06, 1.6350151099686627e-06, 3.018918278030469e-06], 26 | [0.9999789595603943, 3.888396349793766e-06, 2.4751325327088125e-06, 2.063167357846396e-06, 1.6327845742125646e-06, 2.542944457673002e-06, 3.979101165896282e-06, 1.5876235011091921e-06, 2.9301377253432292e-06], 27 | [0.9999746680259705, 5.131441866979003e-06, 2.762496478680987e-06, 2.1566538634942845e-06, 1.8001596799877007e-06, 3.033660505025182e-06, 5.424758910521632e-06, 1.8644163901626598e-06, 3.0862827316013863e-06], 28 | [0.999973714351654, 4.028257990285056e-06, 3.236032171116676e-06, 2.2163710582390195e-06, 2.3914556095405715e-06, 3.024440502485959e-06, 5.924188826611498e-06, 1.920333488669712e-06, 3.522332463035127e-06], 29 | [0.9999758005142212, 4.000472927145893e-06, 2.7868150027643424e-06, 2.136120201612357e-06, 2.0413135644048452e-06, 2.928321009676438e-06, 4.698605607700301e-06, 2.0052416402904782e-06, 3.6259775697544683e-06], 30 | [0.999975860118866, 4.4236958274268545e-06, 2.4001067231438356e-06, 2.2225881366466638e-06, 1.7689055766823003e-06, 3.2379186905018287e-06, 4.682013695855858e-06, 2.1081903014419368e-06, 3.375255573700997e-06], 31 | [0.9998907446861267, 1.578375304234214e-05, 5.3668845794163644e-06, 1.2409099326760042e-05, 4.750932475872105e-06, 1.8000529962591827e-05, 2.7013156795874238e-05, 1.2127270565542858e-05, 1.374690964439651e-05], 32 | [0.003926347475498915, 0.002735032932832837, 8.571740181650966e-05, 0.9138336777687073, 0.0005940198316238821, 0.03783855959773064, 0.00022412843827623874, 0.04050194472074509, 0.00026056013302877545], 33 | [0.07985284924507141, 0.009872201830148697, 0.041320718824863434, 0.012770624831318855, 0.3983164131641388, 0.006530588027089834, 0.04126868396997452, 0.011114982888102531, 0.3989529609680176], 34 | [0.9999727606773376, 3.0492840323859127e-06, 4.278926098777447e-06, 1.7797259488361306e-06, 2.2884205463924445e-06, 2.3278419121197658e-06, 6.757989467587322e-06, 1.5671344044676516e-06, 5.32957255927613e-06], 35 | [0.9999780654907227, 3.371870434420998e-06, 2.2849153538118117e-06, 2.2981241727393353e-06, 
1.6529736512893578e-06, 2.649293946888065e-06, 4.260411969880806e-06, 1.8885518784372834e-06, 3.5209302495786687e-06], 36 | [0.9999715685844421, 4.4823377720604185e-06, 2.109321485477267e-06, 2.7282696919428417e-06, 1.7058819139492698e-06, 4.554907718556933e-06, 6.916285201441497e-06, 2.2532806269737193e-06, 3.6474220905802213e-06], 37 | [0.9999743700027466, 3.6195181110088015e-06, 2.6583768431009958e-06, 2.3016186787572224e-06, 1.993727437366033e-06, 2.8533349905046634e-06, 6.630017196584959e-06, 1.7421558595742681e-06, 3.860310243908316e-06], 38 | [0.9999783039093018, 4.104108484170865e-06, 2.1964040115562966e-06, 2.307585646121879e-06, 1.7826995417635771e-06, 2.418108351776027e-06, 3.8809712350484915e-06, 1.9073346493314602e-06, 3.0467688247881597e-06], 39 | [0.9999785423278809, 3.750041969396989e-06, 2.2953042844164884e-06, 1.9982642243121518e-06, 1.7815000319387764e-06, 2.492650310159661e-06, 4.245904165145475e-06, 1.708154854895838e-06, 3.200779929102282e-06], 40 | [0.9999630451202393, 9.121877155848779e-06, 2.033839109572e-06, 4.358249952929327e-06, 1.5242426343320403e-06, 6.7330302044865675e-06, 3.925659257220104e-06, 5.251821676210966e-06, 4.003347839898197e-06], 41 | [0.9999671578407288, 4.434484253579285e-06, 5.501009127328871e-06, 1.8976691080752062e-06, 2.5506999463686952e-06, 2.7574253635975765e-06, 8.314220394822769e-06, 1.764830017236818e-06, 5.609600975731155e-06], 42 | [0.9999758005142212, 4.4852217797597405e-06, 2.9977536541991867e-06, 2.03482409233402e-06, 1.7478064364695456e-06, 3.1193962968245614e-06, 4.811818598682294e-06, 1.6862202301126672e-06, 3.3220860586880008e-06], 43 | [0.9985728859901428, 0.0004173791385255754, 8.296168380184099e-05, 6.762178236385807e-05, 6.61568992654793e-05, 0.00017077356460504234, 0.0002180173760280013, 0.0002483870484866202, 0.0001558802032377571] 44 | ], 45 | "original_labels": ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "I-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-LOC", "O", "O", "O", "O", "O", "O", "O", "O"], 46 | "original_score": 0.0017228871583938599, 47 | "perturbed_score": 0.27519580721855164, 48 | "perturbed_labels": ["O", "O", "O", "O", "O", "B-LOC", "I-LOC", "I-LOC", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "B-PER", "O", "O", "O", "O", "O", "O", "O", "O"], 49 | "perturbed_pred": "0 0 0 0 0 7 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 0", 50 | "perturbed_text": "t of ministerial ministry in Central African Republic and his grandson were armed from their house and attacked in swelling ethnic reported in the department Bangui , a cabinet secretary said on sunday ." 
51 | } -------------------------------------------------------------------------------- /tests/test_avoid_named_entity.py: -------------------------------------------------------------------------------- 1 | from seqattack.utils.ner_attacked_text import NERAttackedText 2 | from tests.fixtures import ( 3 | # This fixture is not used directly, but if removed tests won't run 4 | ner_model_wrapper, 5 | avoid_named_entity_constraint 6 | ) 7 | 8 | 9 | def test_avoid_named_entities(avoid_named_entity_constraint): 10 | sample_1 = NERAttackedText( 11 | "PHOENIX 3 14 .176 10 1/2", 12 | ground_truth=[3, 0, 0, 0, 0, 0]) 13 | 14 | sample_1_change = sample_1.replace_word_at_index(2, "24") 15 | sample_1_add = sample_1.replace_word_at_index(4, "Fairfax") 16 | sample_1_remove = sample_1.replace_word_at_index(0, "phone") 17 | 18 | # Samples with token insertion 19 | sample_insert_entity = sample_1.insert_text_before_word_index(5, "Canada") 20 | sample_insert_normal = sample_1.insert_text_before_word_index(5, "74") 21 | 22 | # Check that samples introducing a named entity are not allowed 23 | assert avoid_named_entity_constraint._check_constraint( 24 | sample_1_add, 25 | sample_1 26 | ) == False 27 | 28 | assert avoid_named_entity_constraint._check_constraint( 29 | sample_insert_entity, 30 | sample_1 31 | ) == False 32 | 33 | # Check that samples removing a named entity are not allowed 34 | assert avoid_named_entity_constraint._check_constraint( 35 | sample_1_remove, 36 | sample_1 37 | ) == False 38 | 39 | # Check that text which does not change the prediction is allowed 40 | assert avoid_named_entity_constraint._check_constraint( 41 | sample_1_change, 42 | sample_1 43 | ) == True 44 | 45 | assert avoid_named_entity_constraint._check_constraint( 46 | sample_insert_normal, 47 | sample_1 48 | ) == True 49 | 50 | # Check that samples which change an entity prediction are allowed 51 | sample_2 = NERAttackedText( 52 | "Wimbledon 16 9 4 3 29 17 31", 53 | ground_truth=[3, 0, 0, 0, 0, 0, 0, 0]) 54 | 55 | # Swapping 16 with 'par' changed the prediction of 'Wimbledon' 56 | # from B-ORG to B-MISC. 
This is specific to the model used in this file 57 | sample_2_change = sample_2.replace_word_at_index(1, "par") 58 | 59 | assert avoid_named_entity_constraint._check_constraint( 60 | sample_2_change, 61 | sample_2 62 | ) == True 63 | -------------------------------------------------------------------------------- /tests/test_ner_attacked_text.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from seqattack.utils.ner_attacked_text import NERAttackedText 4 | 5 | 6 | def test_ground_truth_inserted_index(): 7 | """ 8 | Verifies that the function returns 9 | the correct last inserted token index 10 | """ 11 | attacked_text = NERAttackedText( 12 | "Germany published the new bond rates", 13 | ground_truth=[1, 0, 2, 0, 0, 0]) 14 | 15 | insert_middle = attacked_text.insert_text_after_word_index(1, "today") 16 | insert_beginning = attacked_text.insert_text_before_word_index(0, "Today") 17 | insert_end = attacked_text.insert_text_after_word_index(5, "specifications") 18 | 19 | replace_word = attacked_text.replace_word_at_index(2, "Lorem") 20 | 21 | assert insert_middle._ground_truth_inserted_index( 22 | insert_middle) == 2 23 | 24 | assert insert_beginning._ground_truth_inserted_index( 25 | insert_beginning) == 0 26 | 27 | assert insert_end._ground_truth_inserted_index( 28 | insert_end) == 6 29 | 30 | # Make sure that if no token was inserted the returned index is None 31 | assert attacked_text._ground_truth_inserted_index(attacked_text) is None 32 | assert replace_word._ground_truth_inserted_index(replace_word) is None 33 | 34 | # Make sure the ground truths were updated correctly 35 | assert insert_middle.attack_attrs["ground_truth"] == [1, 0, 0, 2, 0, 0, 0] 36 | assert insert_beginning.attack_attrs["ground_truth"] == [0, 1, 0, 2, 0, 0, 0] 37 | assert insert_end.attack_attrs["ground_truth"] == [1, 0, 2, 0, 0, 0, 0] 38 | -------------------------------------------------------------------------------- /tests/test_non_ne_constraint.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from seqattack.utils.ner_attacked_text import NERAttackedText 4 | from seqattack.constraints import NonNamedEntityConstraint 5 | 6 | 7 | @pytest.fixture 8 | def non_named_entity_constraint(): 9 | return NonNamedEntityConstraint() 10 | 11 | 12 | def test_check_constraint(non_named_entity_constraint): 13 | original_text = NERAttackedText( 14 | "Europe has announced a new budget for education", 15 | ground_truth=[1, 0, 0, 0, 0, 0, 0, 0] 16 | ) 17 | 18 | perturbed_no_entities = original_text.insert_text_before_word_index(7, "tertiary") 19 | perturbed_replace_entity = original_text.replace_word_at_index(0, "it") 20 | perturbed_replace_non_entity = original_text.replace_word_at_index(2, "proposed") 21 | 22 | assert non_named_entity_constraint._check_constraint( 23 | perturbed_no_entities, 24 | original_text) == True 25 | 26 | assert non_named_entity_constraint._check_constraint( 27 | perturbed_replace_entity, 28 | original_text) == False 29 | 30 | assert non_named_entity_constraint._check_constraint( 31 | perturbed_replace_non_entity, 32 | original_text) == True 33 | -------------------------------------------------------------------------------- /tests/test_targeted_ner.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from seqattack.utils.ner_attacked_text import NERAttackedText 5 | 6 | from tests.fixtures import ( 7 | 
ner_model_wrapper, 8 | ner_tokenizer, 9 | conll2003_labels 10 | ) 11 | 12 | from seqattack.utils import postprocess_ner_output 13 | from seqattack.goal_functions import TargetedNERGoalFunction 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def targeted_ner_goal_function(ner_model_wrapper, ner_tokenizer, conll2003_labels): 18 | return TargetedNERGoalFunction( 19 | model_wrapper=ner_model_wrapper, 20 | tokenizer=ner_tokenizer, 21 | min_percent_entities_mispredicted=0.5, 22 | ner_postprocess_func=postprocess_ner_output, 23 | label_names=conll2003_labels) 24 | 25 | 26 | def test_get_score_confidence(targeted_ner_goal_function): 27 | attacked_text = NERAttackedText( 28 | "The Netherlands is nice", 29 | ground_truth=[1, 1, 0, 0]) 30 | 31 | ground_truth = torch.tensor([1, 1, 0, 0]) 32 | targeted_ner_goal_function.set_min_percent_entities_mispredicted(0.5) 33 | 34 | model_ground_truth = torch.tensor([ 35 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 36 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 37 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 38 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 39 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 40 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 41 | ]) 42 | 43 | # Invalid misprediction - class change but not to "O"! 44 | model_invalid_misprediction = torch.tensor([ 45 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 46 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 47 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 48 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 49 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 50 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 51 | ]) 52 | 53 | # Partial misprediction, 50% weight on O for first entity-token 54 | model_invalid_misprediction_probs = torch.tensor([ 55 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 56 | [0.5, 0, 0, 0.5, 0, 0, 0, 0, 0], 57 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 58 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 59 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 60 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 61 | ]) 62 | 63 | # Partial misprediction, 60% weight on O for first entity-token 64 | model_valid_misprediction_probs = torch.tensor([ 65 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 66 | [0.6, 0, 0, 0.4, 0, 0, 0, 0, 0], 67 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 68 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 69 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 70 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 71 | ]) 72 | 73 | # One entity token disappeared 74 | model_valid_misprediction = torch.tensor([ 75 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 76 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 77 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 78 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 79 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 80 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 81 | ]) 82 | 83 | targeted_ner_goal_function.init_attack_example( 84 | attacked_text, 85 | ground_truth) 86 | 87 | # Same sample, full confidence. 
So 0% difference is expected
88 |     assert round(targeted_ner_goal_function._get_score_confidence(
89 |         model_ground_truth,
90 |         attacked_text
91 |     ), 2) == 0
92 | 
93 |     # Goal should be false, we're testing against the ground truth
94 |     assert targeted_ner_goal_function._is_goal_complete(
95 |         model_ground_truth,
96 |         attacked_text
97 |     ) == False
98 | 
99 |     # Same sample, full confidence but wrong non-O class, so 0% difference expected
100 |     assert round(targeted_ner_goal_function._get_score_confidence(
101 |         model_invalid_misprediction,
102 |         attacked_text
103 |     ), 2) == 0
104 | 
105 |     # Goal should be false, wrong destination class
106 |     assert targeted_ner_goal_function._is_goal_complete(
107 |         model_invalid_misprediction,
108 |         attacked_text
109 |     ) == False
110 | 
111 |     # Same sample, 50% confidence on O for one class, thus we expect 25% score (50 / 2)
112 |     assert round(targeted_ner_goal_function._get_score_confidence(
113 |         model_invalid_misprediction_probs,
114 |         attacked_text
115 |     ), 2) == 0.25
116 | 
117 |     # Goal should be true (50% probability mass on O)
118 |     assert targeted_ner_goal_function._is_goal_complete(
119 |         model_invalid_misprediction_probs,
120 |         attacked_text
121 |     ) == True
122 | 
123 |     # Same sample, 60% confidence on O for one class, thus we expect 30% score (60 / 2)
124 |     assert round(targeted_ner_goal_function._get_score_confidence(
125 |         model_valid_misprediction_probs,
126 |         attacked_text
127 |     ), 2) == 0.30
128 | 
129 |     # Goal should be true (O has a higher confidence than the class token)
130 |     assert targeted_ner_goal_function._is_goal_complete(
131 |         model_valid_misprediction_probs,
132 |         attacked_text
133 |     ) == True
134 | 
135 |     # "Removed" 1/2 entity tokens, 50% score expected
136 |     assert round(targeted_ner_goal_function._get_score_confidence(
137 |         model_valid_misprediction,
138 |         attacked_text
139 |     ), 2) == 0.5
140 | 
141 |     # 50% misprediction score, so the goal should be complete
142 |     assert targeted_ner_goal_function._is_goal_complete(
143 |         model_valid_misprediction,
144 |         attacked_text
145 |     ) == True
146 | 
--------------------------------------------------------------------------------
/tests/test_untargeted_ner.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import pytest
4 | 
5 | from seqattack.utils.ner_attacked_text import NERAttackedText
6 | 
7 | from tests.utils import numeric_string_to_tensor
8 | from tests.fixtures import (
9 |     ner_model_wrapper,
10 |     ner_tokenizer,
11 |     conll2003_labels
12 | )
13 | 
14 | from seqattack.utils import postprocess_ner_output
15 | from seqattack.goal_functions import UntargetedNERGoalFunction
16 | 
17 | 
18 | @pytest.fixture(scope="module")
19 | def untargeted_ner_goal_function(ner_model_wrapper, ner_tokenizer, conll2003_labels):
20 |     return UntargetedNERGoalFunction(
21 |         model_wrapper=ner_model_wrapper,
22 |         tokenizer=ner_tokenizer,
23 |         min_percent_entities_mispredicted=0.5,
24 |         ner_postprocess_func=postprocess_ner_output,
25 |         label_names=conll2003_labels)
26 | 
27 | 
28 | def test_ner_goal_function_get_score_confidence_example(untargeted_ner_goal_function):
29 |     json_data = json.loads(
30 |         open("tests/mock/attacked.json").read()
31 |     )
32 | 
33 |     ground_truth = numeric_string_to_tensor(json_data["ground_truth"])
34 |     attacked_text = NERAttackedText(
35 |         json_data["attacked_sample"],
36 |         ground_truth=[int(x) for x in ground_truth])
37 | 
38 |     perturbed_pred = numeric_string_to_tensor(json_data["perturbed_pred"])
39 | 
40 |     perturbed_model_pred = 
torch.tensor(json_data["raw_pred"]) 41 | 42 | untargeted_ner_goal_function.init_attack_example( 43 | attacked_text, 44 | ground_truth) 45 | 46 | # Only one entity is changed in the sample, thus we are 47 | # happy with a 25% misprediction 48 | untargeted_ner_goal_function.set_min_percent_entities_mispredicted(0.25) 49 | 50 | assert untargeted_ner_goal_function._is_goal_complete( 51 | perturbed_model_pred, 52 | attacked_text 53 | ) == True 54 | 55 | assert round(untargeted_ner_goal_function._get_score_confidence( 56 | perturbed_model_pred, 57 | attacked_text 58 | ), 2) == 0.28 59 | 60 | 61 | def test_ner_goal_function_get_score_confidence(untargeted_ner_goal_function): 62 | attacked_text = NERAttackedText( 63 | "The Netherlands is nice", 64 | ground_truth=[1, 1, 0, 0]) 65 | ground_truth = torch.tensor([1, 1, 0, 0]) 66 | untargeted_ner_goal_function.set_min_percent_entities_mispredicted(0.5) 67 | 68 | model_ground_truth = torch.tensor([ 69 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 70 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 71 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 72 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 73 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 74 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 75 | ]) 76 | 77 | model_valid_misprediction = torch.tensor([ 78 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 79 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 80 | [0, 1, 0, 0, 0, 0, 0, 0, 0], 81 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 82 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 83 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 84 | ]) 85 | 86 | model_confidence_misprediction = torch.tensor([ 87 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 88 | [0, 0.5, 0, 0, 0, 0, 0, 0, 0.5], 89 | [0, 0.5, 0, 0, 0, 0, 0, 0, 0.5], 90 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 91 | [1, 0, 0, 0, 0, 0, 0, 0, 0], 92 | [1, 0, 0, 0, 0, 0, 0, 0, 0] 93 | ]) 94 | 95 | untargeted_ner_goal_function.init_attack_example( 96 | attacked_text, 97 | ground_truth) 98 | 99 | # Same sample, full confidence. so 0% difference expected 100 | assert round(untargeted_ner_goal_function._get_score_confidence( 101 | model_ground_truth, 102 | attacked_text 103 | ), 2) == 0 104 | 105 | # One misprediction, full confidence. so 50% difference expected 106 | assert round(untargeted_ner_goal_function._get_score_confidence( 107 | model_valid_misprediction, 108 | attacked_text 109 | ), 2) == 0.50 110 | 111 | # The 50% misprediction should result in a successful attack 112 | assert untargeted_ner_goal_function._is_goal_complete( 113 | model_valid_misprediction, 114 | attacked_text 115 | ) == True 116 | 117 | # No misprediction, two entities with 50% confidence, 118 | # so we expect a 50% score 119 | assert round(untargeted_ner_goal_function._get_score_confidence( 120 | model_confidence_misprediction, 121 | attacked_text 122 | ), 2) == 0.50 123 | 124 | # But since no entity is actually mispredicted the 125 | # goal should NOT be complete 126 | assert untargeted_ner_goal_function._is_goal_complete( 127 | model_confidence_misprediction, 128 | attacked_text 129 | ) == False 130 | 131 | 132 | def test_ner_goal_function_get_score_labels(untargeted_ner_goal_function): 133 | attacked_text = NERAttackedText( 134 | "5. 
Jacqui Cooper ( Australia ) 156.52", 135 | ground_truth=[0, 3, 4, 0, 7, 0, 0]) 136 | 137 | ground_truth = torch.tensor([0, 3, 4, 0, 7, 0, 0]) 138 | valid_misprediction = torch.tensor([0, 1, 1, 0, 7, 0, 0]) 139 | 140 | untargeted_ner_goal_function.init_attack_example( 141 | attacked_text, 142 | ground_truth) 143 | 144 | # Same sample, so 0% difference expected 145 | assert round(untargeted_ner_goal_function._get_score_labels( 146 | ground_truth, 147 | attacked_text 148 | ), 2) == 0 149 | 150 | # 66% altered tensor, _get_score should return 0.67 (considering rounding) 151 | assert round(untargeted_ner_goal_function._get_score_labels( 152 | valid_misprediction, 153 | attacked_text 154 | ), 2) == 0.67 155 | 156 | # Make sure that if an entity is introduced in a 157 | # no-entities sample the returned score will be 1 158 | ground_truth_empty = torch.tensor([0, 0, 0, 0, 0, 0, 0]) 159 | add_entity_misprediction = torch.tensor([0, 1, 0, 0, 0, 0, 0]) 160 | 161 | untargeted_ner_goal_function.init_attack_example( 162 | attacked_text, 163 | ground_truth_empty) 164 | 165 | assert round(untargeted_ner_goal_function._get_score_labels( 166 | add_entity_misprediction, 167 | attacked_text 168 | ), 2) == 1 169 | -------------------------------------------------------------------------------- /tests/test_untargeted_ner_strict.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from textattack.shared import AttackedText 4 | 5 | from tests.fixtures import ( 6 | conll2003_labels, 7 | ner_model_wrapper, 8 | ner_tokenizer 9 | ) 10 | 11 | from seqattack.utils import postprocess_ner_output 12 | from seqattack.goal_functions.untargeted_ner_strict import StrictUntargetedNERGoalFunction 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def strict_untargeted_ner_goal_function(ner_model_wrapper, ner_tokenizer, conll2003_labels): 17 | return StrictUntargetedNERGoalFunction( 18 | model_wrapper=ner_model_wrapper, 19 | tokenizer=ner_tokenizer, 20 | min_percent_entities_mispredicted=0.5, 21 | ner_postprocess_func=postprocess_ner_output, 22 | label_names=conll2003_labels) 23 | 24 | 25 | def test_score_per_token(strict_untargeted_ner_goal_function): 26 | # Same class and numeric label, 0.9 confidence --> 0.1 expected 27 | assert round(strict_untargeted_ner_goal_function._score_per_token( 28 | 1, "I-PER", 29 | { 30 | 1: 0.9 31 | }, 32 | 1, "I-PER"), 2) == 0.1 33 | 34 | # Same class, different numeric label, 0.1 ground truth confidence 35 | assert round(strict_untargeted_ner_goal_function._score_per_token( 36 | 1, "B-PER", 37 | { 38 | 1: 0.2, 39 | 2: 0.7, 40 | 3: 0.1 41 | }, 42 | 2, "I-PER"), 2) == 0.3 43 | 44 | # Different numeric label and class, any confidence. 
1 expected 45 | assert round(strict_untargeted_ner_goal_function._score_per_token( 46 | 1, "I-PER", 47 | { 48 | 1: 0.1, 49 | 3: 0.9 50 | }, 51 | 3, "I-LOC"), 2) == 1 52 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from seqattack.utils import elements_diff, pad_sequence 2 | 3 | 4 | def test_elements_diff(): 5 | assert elements_diff( 6 | [0, 0, 0, 0, 0, 0], 7 | [0, 0, 0, 0, 0, 0] 8 | ) == set() 9 | 10 | assert elements_diff( 11 | [0, 0, 1, 0, 0, 0], 12 | [0, 0, 0, 0, 3, 0] 13 | ) == set([2, 4]) 14 | 15 | 16 | def test_pad_sequence(): 17 | original = [0, 1, 3, 0, 4] 18 | 19 | # Same length as sequence, no change expected 20 | assert pad_sequence(original, len(original)) == original 21 | 22 | # 2 pads needed 23 | expected = [0, 1, 3, 0, 4, 47, 47] 24 | assert pad_sequence(original, len(original) + 2, filler=47) == expected 25 | 26 | # Longer than required, so no padding needed 27 | assert pad_sequence(original, len(original) - 2, filler=47) == original 28 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def numeric_string_to_tensor(in_string): 5 | """ 6 | Converts a string in the format "0 1 56 7 8" 7 | to an array tensor 8 | """ 9 | labels_list = [int(x) for x in in_string.split(" ")] 10 | return torch.tensor(labels_list) 11 | --------------------------------------------------------------------------------