├── .gitignore ├── LICENSE ├── README.md ├── augment.py ├── augmenters.py ├── configs.py ├── evaluate.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 QData 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A2T: Towards Improving Adversarial Training of NLP Models 2 | 3 | This is the source code for the EMNLP 2021 (Findings) paper ["Towards Improving Adversarial Training of NLP Models"](https://arxiv.org/abs/2109.00544). 4 | 5 | If you use the code, please cite the paper: 6 | ``` 7 | @misc{yoo2021improving, 8 | title={Towards Improving Adversarial Training of NLP Models}, 9 | author={Jin Yong Yoo and Yanjun Qi}, 10 | year={2021}, 11 | eprint={2109.00544}, 12 | archivePrefix={arXiv}, 13 | primaryClass={cs.CL} 14 | } 15 | ``` 16 | 17 | ## Prerequisites 18 | The work heavily relies on the [TextAttack](https://github.com/QData/TextAttack) package. In fact, the main training code is implemented in the TextAttack package. 19 | 20 | Required packages are listed in the `requirements.txt` file. 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## Data 26 | All of the data used for the paper are available from HuggingFace's [Datasets](https://huggingface.co/datasets). 27 | 28 | For IMDB and Yelp datasets, because there are no official validation splits, we randomly sampled 5k and 10k, respectively, from the training set and used them as valid splits. We provide the splits in this Google Drive [folder](https://drive.google.com/drive/folders/1-vvSXUzl1PzMzdyZzAWq2dB--m7tEERK?usp=sharing). To use them with the provided code, place each folder (e.g. `imdb`, `yelp`, `augmented_data`) inside `./data` (run `mkdir data`). 
29 | 30 | Also, augmented training data generated using SSMBA and back-translation are available in the same folder. 31 | 32 | ## Training 33 | To train a BERT model on the IMDB dataset with the A2T attack for 4 epochs and 1 clean epoch with gamma (`--num-adv-examples`) of 0.2: 34 | ``` 35 | python train.py \ 36 | --train imdb \ 37 | --eval imdb \ 38 | --model-type bert \ 39 | --model-save-path ./example \ 40 | --num-epochs 4 \ 41 | --num-clean-epochs 1 \ 42 | --num-adv-examples 0.2 \ 43 | --attack-epoch-interval 1 \ 44 | --attack a2t \ 45 | --learning-rate 5e-5 \ 46 | --num-warmup-steps 100 \ 47 | --grad-accumu-steps 1 \ 48 | --checkpoint-interval-epochs 1 \ 49 | --seed 42 50 | ``` 51 | 52 | You can also pass `roberta` to `--model-type` to train a RoBERTa model instead of a BERT model. To select other datasets from the paper, pass `rt` (MR), `yelp`, or `snli` for `--train` and `--eval`. 53 | 54 | This script simply runs the `Trainer` class from the TextAttack package. To see how training is performed, please check out the `Trainer` [class](https://github.com/QData/TextAttack/blob/master/textattack/trainer.py). 55 | 56 | ## Evaluation 57 | To evaluate the accuracy, robustness, and interpretability of our trained model from above, run 58 | ``` 59 | python evaluate.py \ 60 | --dataset imdb \ 61 | --model-type bert \ 62 | --checkpoint-paths ./example \ 63 | --epoch 4 \ 64 | --save-log \ 65 | --accuracy \ 66 | --robustness \ 67 | --attacks a2t a2t_mlm textfooler bae pwws pso \ 68 | --interpretability 69 | ``` 70 | 71 | This takes the last checkpoint model (`--epoch 4`) and evaluates its accuracy on both the IMDB and Yelp datasets (the latter for cross-domain accuracy). It also evaluates the model's robustness against A2T, A2T-MLM, TextFooler, BAE, PWWS, and PSO attacks. Lastly, with the `--interpretability` flag, AOPC scores are calculated.
72 | 73 | Note that you will have to run `--robustness` and `--interpretability` with `--accuracy` (or after you separately evaluate accuracy) since both robustness and interpretability evaluations rely on the accuracy evaluation to know which samples the model was able to predict correctly. 74 | By default, 1000 samples are attacked to evaluate robustness. Likewise, 1000 samples are used to calculate the AOPC score for interpretability. 75 | 76 | If you're evaluating multiple models for comparison, it's also advised that you provide all the checkpoint paths together to `--checkpoint-paths`. This is because the samples that are correctly predicted by each model will be different, so we first need to identify the intersection of all the correct predictions before using them to evaluate robustness for all the models. This allows a fairer comparison of the models' robustness than attacking different samples for each model. 77 | 78 | ## Data Augmentation 79 | Lastly, we also provide `augment.py`, which we used to perform data augmentation methods such as SSMBA and back-translation. 80 | 81 | The following is an example command for augmenting the IMDB dataset with the SSMBA method (`--output-path` is the path of the output TSV file). 82 | ``` 83 | python augment.py \ 84 | --dataset imdb \ 85 | --augmentation ssmba \ 86 | --output-path ./augmented_data/imdb_ssmba.tsv \ 87 | --seed 42 88 | ``` 89 | 90 | You can also pass `backtranslation` to `--augmentation`.
91 | -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datasets 4 | import tqdm 5 | import torch 6 | import random 7 | import numpy as np 8 | 9 | from augmenters import BackTranslationAugmenter, SSMBA 10 | from configs import DATASET_CONFIGS 11 | 12 | 13 | def set_seed(seed): 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | 19 | 20 | def augment(args, num_gpus, in_queue, out_queue): 21 | gpu_id = (torch.multiprocessing.current_process()._identity[0] - 1) % num_gpus 22 | set_seed(args.seed) 23 | torch.cuda.set_device(gpu_id) 24 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 25 | torch.cuda.set_device(gpu_id) 26 | 27 | if args.augmentation == "backtranslation": 28 | augmenter = BackTranslationAugmenter() 29 | elif args.augmentation == "ssmba": 30 | augmenter = SSMBA() 31 | else: 32 | raise ValueError(f"Unknown augmentation {augmentation}.") 33 | 34 | while True: 35 | try: 36 | i, inputs, label = in_queue.get() 37 | if i == "END" and example == "END" and ground_truth_output == "END": 38 | # End process when sentinel value is received 39 | break 40 | else: 41 | if isinstance(inputs, tuple): 42 | text_to_augment = inputs[1] 43 | else: 44 | text_to_augment = inputs 45 | 46 | augmented_text = "" 47 | tries = 0 48 | while augmented_text == "" and tries < 10: 49 | augmented_text = augmenter(text_to_augment) 50 | augmented_text = augmented_text.strip() 51 | tries += 1 52 | 53 | if isinstance(inputs, tuple): 54 | augmented_text = (inputs[0], augmented_text) 55 | 56 | out_queue.put((i, augmented_text, label)) 57 | except Exception as e: 58 | out_queue.put((i, e, e)) 59 | 60 | 61 | def main(args): 62 | if args.dataset not in DATASET_CONFIGS: 63 | raise ValueError(f"Unknown dataset {args.dataset}") 64 | dataset_config = 
DATASET_CONFIGS[args.dataset] 65 | 66 | if "local_path" in dataset_config: 67 | dataset = datasets.load_dataset( 68 | "csv", 69 | data_files=os.path.join(dataset_config["local_path"], "train.tsv"), 70 | delimiter="\t", 71 | )["train"] 72 | else: 73 | dataset = datasets.load_dataset(dataset_config["remote_name"], split="train") 74 | 75 | augmented_text = [] 76 | augmented_label = [] 77 | augmented_indices = [] 78 | num_workers = torch.cuda.device_count() 79 | assert num_workers >= 1, "You need at least one GPU to perform augmentation." 80 | 81 | torch.multiprocessing.set_start_method("spawn", force=True) 82 | torch.multiprocessing.set_sharing_strategy("file_system") 83 | 84 | in_queue = torch.multiprocessing.Queue() 85 | out_queue = torch.multiprocessing.Queue() 86 | input_columns = dataset_config["dataset_columns"][0] 87 | for i, row in enumerate(dataset): 88 | input_text = tuple(row[col] for col in input_columns) 89 | if len(input_text) == 1: 90 | input_text = input_text[0] 91 | in_queue.put((i, input_text, row["label"])) 92 | 93 | # Start workers. 
94 | worker_pool = torch.multiprocessing.Pool( 95 | num_workers, 96 | augment, 97 | ( 98 | args, 99 | num_workers, 100 | in_queue, 101 | out_queue, 102 | ), 103 | ) 104 | pbar = tqdm.tqdm(total=len(dataset), smoothing=0) 105 | for _ in range(len(dataset)): 106 | idx, aug_text, aug_label = out_queue.get(block=True) 107 | pbar.update() 108 | if isinstance(aug_text, Exception): 109 | continue 110 | if aug_text == "": 111 | continue 112 | augmented_indices.append(idx) 113 | augmented_text.append(aug_text) 114 | augmented_label.append(aug_label) 115 | 116 | # Send sentinel values to worker processes 117 | for _ in range(num_workers): 118 | in_queue.put(("END", "END", "END")) 119 | worker_pool.terminate() 120 | worker_pool.join() 121 | 122 | augmented_indices = np.array(augmented_indices) 123 | argsort_indices = np.argsort(augmented_indices) 124 | augmented_text = [augmented_text[i] for i in argsort_indices] 125 | augmented_label = [augmented_label[i] for i in argsort_indices] 126 | 127 | if isinstance(augmented_text[0], tuple): 128 | augmented_data = { 129 | col: [t[i] for t in augmented_text] for i, col in enumerate(input_columns) 130 | } 131 | augmented_data["label"] = augmented_label 132 | else: 133 | augmented_data = {input_columns[0]: augmented_text, "label": augmented_label} 134 | 135 | augmented_dataset = datasets.Dataset.from_dict(augmented_data) 136 | if not os.path.exists(os.path.dirname(args.output_path)): 137 | os.makedirs(os.path.dirname(args.output_path)) 138 | augmented_dataset.to_csv(args.output_path, sep="\t", index=False) 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument( 144 | "--augmentation", 145 | type=str, 146 | required=True, 147 | choices=["ssmba", "backtranslation"], 148 | help="Augmentation to use", 149 | ) 150 | parser.add_argument( 151 | "--dataset", 152 | type=str, 153 | required=True, 154 | choices=sorted(list(DATASET_CONFIGS.keys())), 155 | help="Name of dataset to augment", 
156 | ) 157 | parser.add_argument( 158 | "--output-path", 159 | type=str, 160 | required=True, 161 | help="Output path for augmented data (in TSV format).", 162 | ) 163 | parser.add_argument("--seed", type=int, default=42, help="Random seed.") 164 | args = parser.parse_args() 165 | main(args) 166 | -------------------------------------------------------------------------------- /augmenters.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import nltk 4 | import numpy as np 5 | 6 | 7 | class BackTranslationAugmenter: 8 | def __init__(self, sample_temp=0.8, batch_size=16): 9 | self.en2de_model = transformers.FSMTForConditionalGeneration.from_pretrained( 10 | "facebook/wmt19-en-de" 11 | ) 12 | self.en2de_tokenizer = transformers.FSMTTokenizer.from_pretrained( 13 | "facebook/wmt19-en-de" 14 | ) 15 | self.de2en_model = transformers.FSMTForConditionalGeneration.from_pretrained( 16 | "facebook/wmt19-de-en" 17 | ) 18 | self.de2en_tokenizer = transformers.FSMTTokenizer.from_pretrained( 19 | "facebook/wmt19-de-en" 20 | ) 21 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | self.en2de_model.eval() 23 | self.de2en_model.eval() 24 | self.en2de_model.to(self._device) 25 | self.de2en_model.to(self._device) 26 | self.sample_temp = sample_temp 27 | self.batch_size = batch_size 28 | 29 | def __call__(self, text): 30 | """ 31 | Generate list of augmented data based off of `text` 32 | 33 | Args: 34 | text (str): seed text 35 | Returns: 36 | augmented_text (list[str]): List of augmented text 37 | """ 38 | return self.batch_call([text])[0] 39 | 40 | def batch_call(self, texts): 41 | # First split paragraphs into sentences: 42 | texts_as_sent = [] 43 | num_sents = [] 44 | for text in texts: 45 | sentences = nltk.sent_tokenize(text) 46 | texts_as_sent.extend(sentences) 47 | num_sents.append(len(sentences)) 48 | 49 | i = 0 50 | translated_texts = [] 51 | while i < 
len(texts_as_sent): 52 | batch = texts_as_sent[i : i + self.batch_size] 53 | i += self.batch_size 54 | 55 | with torch.no_grad(): 56 | en_inputs = self.en2de_tokenizer.batch_encode_plus( 57 | batch, 58 | return_tensors="pt", 59 | padding=True, 60 | truncation=True, 61 | max_length=128, 62 | ).to(self._device) 63 | de_outputs = self.en2de_model.generate( 64 | en_inputs["input_ids"], do_sample=True, temperature=self.sample_temp 65 | ) 66 | de_texts = [ 67 | self.en2de_tokenizer.decode(output, skip_special_tokens=True) 68 | for output in de_outputs 69 | ] 70 | de_inputs = self.de2en_tokenizer.batch_encode_plus( 71 | de_texts, 72 | return_tensors="pt", 73 | padding=True, 74 | truncation=True, 75 | max_length=128, 76 | ).to(self._device) 77 | en_outputs = self.de2en_model.generate( 78 | de_inputs["input_ids"], do_sample=True, temperature=self.sample_temp 79 | ) 80 | en_texts = [ 81 | self.de2en_tokenizer.decode(output, skip_special_tokens=True) 82 | for output in en_outputs 83 | ] 84 | 85 | translated_texts.extend(en_texts) 86 | 87 | augmented_data = [] 88 | j = 0 89 | for n in num_sents: 90 | augmented_data.append(" ".join(translated_texts[j : j + n])) 91 | j += n 92 | 93 | return augmented_data 94 | 95 | 96 | class SSMBA: 97 | """ 98 | Data augmentation method proposed by "SSMBA: Self-Supervised Manifold Based Data Augmentation forImproving Out-of-Domain Robustness" (Ng et. al., 2020) 99 | Most of the code has been adapted or copied from https://github.com/nng555/ssmba 100 | 101 | Args: 102 | model (str): name of masked language model from Huggingface's `transformers` 103 | noise_prob (float): Probability for selecting a token for noising. Selected tokens are then masked, randomly replaced, or left the same. 104 | Default is 0.15. 105 | random_token_prob (float): Probability of a selected token being replaced randomly from the vocabulary. Default is 0.1 106 | leave_unmasked_prob (float): Probability of a selected o tken being left unmasked and unchanged. 
Default is 0.1 107 | max_tries (int): Num of tries to generate a unique sample before giving up Default is 10. 108 | num_samples (float): Number of augmented samples to generate for each sample. Default is 4. 109 | top_k (int): Top k to use for sampling reconstructed tokens from the BERT model. -1 indicates unrestricted sampling. Default is -1. 110 | min_seq_len (int): Minimum sequence length of the input for agumentation. Default is 4 111 | max_seq_len (int): Maximum sequence length of the input for augmentation. Default is 512 112 | """ 113 | 114 | def __init__( 115 | self, 116 | model="bert-base-uncased", 117 | noise_prob=0.15, 118 | random_token_prob=0.1, 119 | leave_unmasked_prob=0.1, 120 | max_tries=10, 121 | num_samples=1, 122 | top_k=-1, 123 | min_seq_len=4, 124 | max_seq_len=512, 125 | ): 126 | self.mlm_model = transformers.AutoModelForMaskedLM.from_pretrained(model).cuda() 127 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(model) 128 | self.noise_prob = noise_prob 129 | self.random_token_prob = random_token_prob 130 | self.leave_unmasked_prob = leave_unmasked_prob 131 | self.max_tries = max_tries 132 | self.num_samples = num_samples 133 | self.top_k = top_k 134 | self.min_seq_len = min_seq_len 135 | self.max_seq_len = max_seq_len 136 | 137 | self._softmax_mask = np.full(len(self.tokenizer.vocab), False) 138 | self._softmax_mask[self.tokenizer.all_special_ids] = True 139 | self._weights = np.ones(len(self.tokenizer.vocab)) 140 | self._weights[self.tokenizer.all_special_ids] = 0 141 | for k, v in self.tokenizer.vocab.items(): 142 | if "[unused" in k: 143 | self._softmax_mask[v] = True 144 | self._weights[v] = 0 145 | 146 | self._weights = self._weights / self._weights.sum() 147 | 148 | def _mask_and_corrupt(self, tokens): 149 | """ 150 | Main corruption function that (1) randomly masks tokens 151 | and (2) randomly switches tokens with another random token sampled from the vocabulary. 
152 | 153 | Args: 154 | tokens (np.ndarray): numpy array of input tokens 155 | Returns: 156 | masked_tokens, mask_targets (tuple[torch.Tensor, torch.Tensor]): 157 | `masked_tokens` is tensor of tokenized `text` after being corrupted, while `mask_targets` is a tensor storing the original values of tokens 158 | that have been corrupted. 159 | """ 160 | if self.noise_prob == 0.0: 161 | return tokens 162 | 163 | seq_len = len(tokens) 164 | mask = np.full(seq_len, False) 165 | # number of tokens to mask 166 | num_mask = int(self.noise_prob * seq_len + np.random.rand()) 167 | 168 | mask_choice_p = np.ones(seq_len) 169 | for i in range(seq_len): 170 | if tokens[i] in self.tokenizer.all_special_ids: 171 | mask_choice_p[i] = 0 172 | mask_choice_p = mask_choice_p / mask_choice_p.sum() 173 | 174 | mask[np.random.choice(seq_len, num_mask, replace=False, p=mask_choice_p)] = True 175 | 176 | # decide unmasking and random replacement 177 | rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob 178 | if rand_or_unmask_prob > 0.0: 179 | rand_or_unmask = mask & (np.random.rand(seq_len) < rand_or_unmask_prob) 180 | if self.random_token_prob == 0.0: 181 | unmask = rand_or_unmask 182 | rand_mask = None 183 | elif self.leave_unmasked_prob == 0.0: 184 | unmask = None 185 | rand_mask = rand_or_unmask 186 | else: 187 | unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob 188 | decision = np.random.rand(seq_len) < unmask_prob 189 | unmask = rand_or_unmask & decision 190 | rand_mask = rand_or_unmask & (~decision) 191 | else: 192 | unmask = rand_mask = None 193 | 194 | if unmask is not None: 195 | mask = mask ^ unmask 196 | 197 | tokens[mask] = self.tokenizer.mask_token_id 198 | if rand_mask is not None: 199 | num_rand = rand_mask.sum() 200 | if num_rand > 0: 201 | tokens[rand_mask] = np.random.choice( 202 | len(self.tokenizer.vocab), 203 | num_rand, 204 | p=self._weights, 205 | ) 206 | 207 | mask_targets = np.full(len(mask), self.tokenizer.pad_token_id) 208 | 
mask_targets[mask] = tokens[mask == 1] 209 | 210 | return torch.tensor(tokens).long(), torch.tensor(mask_targets).long() 211 | 212 | def _reconstruction_prob_tok(self, masked_tokens, target_tokens): 213 | single = masked_tokens.dim() == 1 214 | 215 | # expand batch size 1 216 | if single: 217 | masked_tokens = masked_tokens.unsqueeze(0) 218 | target_tokens = target_tokens.unsqueeze(0) 219 | 220 | masked_index = (target_tokens != self.tokenizer.pad_token_id).nonzero( 221 | as_tuple=True 222 | ) 223 | masked_orig_index = target_tokens[masked_index] 224 | 225 | # edge case of no masked tokens 226 | if len(masked_orig_index) == 0: 227 | return masked_tokens 228 | 229 | masked_orig_enum = [list(range(len(masked_orig_index))), masked_orig_index] 230 | 231 | masked_tokens = masked_tokens.cuda() 232 | target_tokens = target_tokens.cuda() 233 | outputs = self.mlm_model(masked_tokens, labels=target_tokens) 234 | 235 | features = outputs[1] 236 | 237 | logits = features[masked_index] 238 | 239 | for i in range(len(logits)): 240 | logits[i][self._softmax_mask] = float("-inf") 241 | probs = logits.softmax(dim=-1) 242 | 243 | # sample from topk 244 | if self.top_k != -1: 245 | values, indices = probs.topk(k=self.top_k, dim=-1) 246 | kprobs = values.softmax(dim=-1) 247 | if len(masked_index) > 1: 248 | samples = torch.cat( 249 | [ 250 | idx[torch.multinomial(kprob, 1)] 251 | for kprob, idx in zip(kprobs, indices) 252 | ] 253 | ) 254 | else: 255 | samples = indices[torch.multinomial(kprobs, 1)] 256 | 257 | # unrestricted sampling 258 | else: 259 | if len(masked_index) > 1: 260 | samples = torch.cat([torch.multinomial(prob, 1) for prob in probs]) 261 | else: 262 | samples = torch.multinomial(probs, 1) 263 | 264 | # set samples 265 | masked_tokens[masked_index] = samples 266 | 267 | if single: 268 | return masked_tokens[0] 269 | else: 270 | return masked_tokens 271 | 272 | def _decode_tokens(self, tokens): 273 | """ 274 | Decode tokens into string 275 | 276 | Args: 277 | tokens 
(torch.Tensor): tokens of ids 278 | Returns: 279 | text (str): decoded string 280 | """ 281 | # remove [CLS] and [SEP] tokens 282 | tokens = tokens[1:-1] 283 | # remove [PAD] tokens 284 | tokens = tokens[tokens != self.tokenizer.pad_token_id] 285 | return self.tokenizer.decode(tokens).strip() 286 | 287 | def __call__(self, text): 288 | """ 289 | Generate list of augmented data based off of `text` 290 | 291 | Args: 292 | text (str): seed text 293 | Returns: 294 | augmented_text (list[str]): List of augmented text 295 | """ 296 | tokens = self.tokenizer.encode( 297 | text, 298 | add_special_tokens=True, 299 | return_tensors="np", 300 | truncation=True, 301 | max_length=self.max_seq_len, 302 | )[0] 303 | if len(tokens) < self.min_seq_len or len(tokens) > self.max_seq_len: 304 | raise ValueError( 305 | f"Given input of sequence length {len(tokens)} is too short. Minimum sequence length is {self.min_seq_len} " 306 | f"and maximum sequence length is {self.max_seq_len}." 307 | ) 308 | 309 | num_tries = 0 310 | new_samples = [] 311 | 312 | while num_tries < self.max_tries: 313 | masked_tokens, target_tokens = self._mask_and_corrupt(np.copy(tokens)) 314 | new_sample = self._reconstruction_prob_tok(masked_tokens, target_tokens) 315 | num_tries += 1 316 | new_sample = self._decode_tokens(new_sample) 317 | 318 | # check if identical reconstruction or empty 319 | if new_sample != text and new_sample != "": 320 | new_samples.append(new_sample) 321 | break 322 | 323 | return new_samples[0] 324 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | DATASET_CONFIGS = { 2 | "imdb": { 3 | "labels": 2, 4 | "label_names": ["negative", "positive"], 5 | "dataset_columns": (["text"], "label"), 6 | "local_path": "./data/imdb", 7 | "eval_datasets": {"test": "imdb", "cross_domain": "yelp"}, 8 | }, 9 | "yelp": { 10 | "labels": 2, 11 | "label_names": ["negative", 
"positive"], 12 | "dataset_columns": (["text"], "label"), 13 | "local_path": "./data/yelp", 14 | "eval_datasets": {"test": "yelp", "cross_domain": "imdb"}, 15 | }, 16 | "rt": { 17 | "labels": 2, 18 | "label_names": ["negative", "positive"], 19 | "dataset_columns": (["text"], "label"), 20 | "remote_name": "rotten_tomatoes", 21 | "eval_datasets": {"test": "rt", "cross_domain": "yelp"}, 22 | }, 23 | "snli": { 24 | "labels": 3, 25 | "label_names": ["entailment", "neutral", "contradiction"], 26 | "dataset_columns": (["premise", "hypothesis"], "label"), 27 | "remote_name": "snli", 28 | "eval_datasets": {"test": "snli", "cross_domain": "mnli"}, 29 | }, 30 | "mnli": { 31 | "labels": 3, 32 | "label_names": ["entailment", "neutral", "contradiction"], 33 | "dataset_columns": (["premise", "hypothesis"], "label"), 34 | "remote_name": "multi_nli", 35 | "split": "validation_mismatched+validation_matched", 36 | }, 37 | } 38 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import json 4 | import os 5 | import random 6 | import math 7 | import multiprocessing as mp 8 | 9 | import datasets 10 | import numpy as np 11 | import textattack 12 | import torch 13 | import tqdm 14 | import transformers 15 | from lime.lime_text import LimeTextExplainer, IndexedString 16 | 17 | from configs import DATASET_CONFIGS 18 | 19 | 20 | NUM_SAMPLES_FOR_EVALUATION = 1000 21 | 22 | 23 | class AOPC: 24 | def __init__(self, model, tokenizer, labels): 25 | self.interpreter = LimeTextExplainer( 26 | class_names=labels, bow=False, mask_string=tokenizer.unk_token 27 | ) 28 | self.model = model 29 | self.model.cuda() 30 | self.model.eval() 31 | self.tokenizer = tokenizer 32 | self.K = 10 33 | self.num_samples = 1000 34 | 35 | def pred_fn_nli(self, premise, texts, batch_size=128): 36 | texts = [(premise, t) for t in texts] 37 | all_probs = 
[] 38 | for i in range(0, len(texts), batch_size): 39 | inputs = texts[i : i + batch_size] 40 | input_ids = self.tokenizer( 41 | inputs, 42 | padding="max_length", 43 | truncation=True, 44 | return_tensors="pt", 45 | max_length=512, 46 | ) 47 | input_ids.to("cuda") 48 | with torch.no_grad(): 49 | logits = self.model(**input_ids)[0] 50 | probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy() 51 | all_probs.append(probs) 52 | 53 | probs = np.concatenate(all_probs, axis=0) 54 | 55 | return probs 56 | 57 | def pred_fn(self, texts, batch_size=128): 58 | all_probs = [] 59 | for i in range(0, len(texts), batch_size): 60 | inputs = texts[i : i + batch_size] 61 | input_ids = self.tokenizer( 62 | inputs, padding="max_length", truncation=True, return_tensors="pt" 63 | ) 64 | input_ids.to("cuda") 65 | with torch.no_grad(): 66 | logits = self.model(**input_ids)[0] 67 | probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy() 68 | all_probs.append(probs) 69 | 70 | probs = np.concatenate(all_probs, axis=0) 71 | 72 | return probs 73 | 74 | def calc_aopc_dataset(self, dataset): 75 | aopc_scores = [] 76 | for row in tqdm.tqdm(dataset): 77 | if "content" in row: 78 | text = row["content"] 79 | elif "hypothesis" in row: 80 | text = (row["premise"], row["hypothesis"]) 81 | else: 82 | text = row["text"] 83 | label = row["label"] 84 | num_words = IndexedString(text, bow=False).num_words() 85 | K = min(max(self.K, math.ceil(num_words * 0.1)), num_words) 86 | exp = self.interpreter.explain_instance( 87 | text, self.pred_fn, num_features=K, num_samples=self.num_samples 88 | ) 89 | exp = exp.as_map()[1] 90 | perturbed_texts = [text] 91 | for k in range(1, K + 1): 92 | top_exp = [e[0] for e in exp[:k]] 93 | x = IndexedString(text, bow=False, mask_string="").inverse_removing( 94 | top_exp 95 | ) 96 | perturbed_texts.append(x) 97 | probs = self.pred_fn(perturbed_texts) 98 | probs_diff = (probs[0] - probs)[1:, label] 99 | aopc_scores.append(probs_diff.sum()) 100 | avg_aopc = 
sum(aopc_scores) / (len(aopc_scores) * (1 + self.K)) 101 | return avg_aopc 102 | 103 | def calc_aopc_instance(self, text, label, nli=False): 104 | if nli: 105 | premise = text[0] 106 | text = text[1] 107 | 108 | num_words = IndexedString(text, bow=False).num_words() 109 | K = min(self.K, num_words) 110 | if nli: 111 | pred_fn = functools.partial(self.pred_fn_nli, premise) 112 | else: 113 | pred_fn = self.pred_fn 114 | 115 | exp = self.interpreter.explain_instance( 116 | text, pred_fn, num_features=K, num_samples=self.num_samples 117 | ) 118 | 119 | exp = exp.as_map()[1] 120 | perturbed_texts = [text] 121 | for k in range(1, K + 1): 122 | top_exp = [e[0] for e in exp[:k]] 123 | x = IndexedString(text, bow=False, mask_string="").inverse_removing(top_exp) 124 | perturbed_texts.append(x) 125 | 126 | probs = pred_fn(perturbed_texts) 127 | probs_diff = (probs[0] - probs)[1:, label] 128 | score = probs_diff.sum() / (1 + K) 129 | return score 130 | 131 | 132 | def calc_aopc(model_type, model_path, labels, num_gpus, in_queue, out_queue, nli): 133 | gpu_id = (torch.multiprocessing.current_process()._identity[0] - 1) % num_gpus 134 | torch.cuda.set_device(gpu_id) 135 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 136 | torch.cuda.set_device(gpu_id) 137 | model = transformers.AutoModelForSequenceClassification.from_pretrained(model_path) 138 | if model_type == "roberta": 139 | model_type = "roberta-base" 140 | else: 141 | model_type = "bert-base-uncased" 142 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_type) 143 | aopc = AOPC(model, tokenizer, labels) 144 | while True: 145 | try: 146 | i, input_text, label = in_queue.get() 147 | if i == "END" and example == "END" and ground_truth_output == "END": 148 | # End process when sentinel value is received 149 | break 150 | else: 151 | aopc_score = aopc.calc_aopc_instance(input_text, label, nli=nli) 152 | out_queue.put((i, aopc_score)) 153 | except Exception as e: 154 | out_queue.put((i, e)) 155 | 156 | 157 | # 
# Helper functions for collating data
def collate_fn(input_columns, data):
    """Collate dataset rows into `(inputs, labels)`.

    Args:
        input_columns (list[str]): names of the input text columns.
        data (list[dict]): rows containing those columns and a "label" column.
    Returns:
        tuple[list, torch.Tensor]: per-row inputs (a string for one column,
        otherwise a tuple of column values) and a tensor of the labels.
    """
    input_texts = []
    labels = []
    for d in data:
        label = d["label"]
        _input = tuple(d[c] for c in input_columns)
        if len(_input) == 1:
            _input = _input[0]
        input_texts.append(_input)
        labels.append(label)
    return input_texts, torch.tensor(labels)


def load_dataset(name):
    """Load the evaluation split of dataset `name`, dropping rows whose label is -1.

    Reads `<local_path>/test.tsv` when the config has "local_path". Otherwise it
    loads `remote_name` from HuggingFace `datasets`, using the configured "split"
    if present and "test" otherwise.

    Raises:
        ValueError: if `name` is not in `DATASET_CONFIGS`.
    """
    if name not in DATASET_CONFIGS:
        raise ValueError(f"Unknown dataset {name}")
    dataset_config = DATASET_CONFIGS[name]
    if "local_path" in dataset_config:
        dataset = datasets.load_dataset(
            "csv",
            data_files=os.path.join(dataset_config["local_path"], "test.tsv"),
            delimiter="\t",
        )["train"]
    else:
        if "split" in dataset_config:
            dataset = datasets.load_dataset(
                dataset_config["remote_name"], split=dataset_config["split"]
            )
        else:
            dataset = datasets.load_dataset(dataset_config["remote_name"], split="test")

    dataset = dataset.filter(lambda x: x["label"] != -1)

    return dataset


def _rounded_mean(values):
    """Mean of `values` rounded to 2 decimals, or nan if `values` is empty."""
    if len(values) == 0:
        return float("nan")
    return round(float(np.mean(values)), 2)


def calc_attack_stats(results):
    """Summarize a list of TextAttack attack results.

    Any result that is neither a FailedAttackResult nor a SkippedAttackResult
    counts as a successful attack.

    Returns:
        tuple: (attack success rate as a percentage of all results,
        average number of queries over non-skipped results,
        average percentage of words perturbed over successful attacks with a
        non-zero percentage). Each value is rounded to 2 decimals, or nan when
        the set it is computed over is empty.
    """
    total_attacks = len(results)
    successful_attacks = 0
    perturbed_word_percentages = []

    for result in results:
        if isinstance(
            result,
            (
                textattack.attack_results.FailedAttackResult,
                textattack.attack_results.SkippedAttackResult,
            ),
        ):
            continue
        successful_attacks += 1
        original_text = result.original_result.attacked_text
        num_words = len(original_text.words)
        if num_words > 0:
            num_words_changed = len(
                original_text.all_words_diff(result.perturbed_result.attacked_text)
            )
            perturbed_word_percentages.append(num_words_changed * 100.0 / num_words)

    if total_attacks > 0:
        attack_success_rate = round(successful_attacks * 100.0 / total_attacks, 2)
    else:
        attack_success_rate = float("nan")

    # Zero percentages (e.g. nothing changed) are excluded from the average.
    average_perc_words_perturbed = _rounded_mean(
        [p for p in perturbed_word_percentages if p > 0]
    )

    avg_num_queries = _rounded_mean(
        [
            r.num_queries
            for r in results
            if not isinstance(r, textattack.attack_results.SkippedAttackResult)
        ]
    )

    return attack_success_rate, avg_num_queries, average_perc_words_perturbed


#####################################################################################


def evaluate_interpretability(args):
    if args.dataset not in DATASET_CONFIGS:
        raise ValueError()
    dataset_config = DATASET_CONFIGS[args.dataset]
    test_dataset = load_dataset(args.dataset)

    all_correct_indices = set(range(len(test_dataset)))
    for path in args.checkpoint_paths:
        with open(os.path.join(path, "test_logs.json"), "r") as f:
            logs = json.load(f)
        correct_indices = logs[f"checkpoint-epoch-{args.epoch}"][args.dataset][
            "correct_indices"
        ]
        all_correct_indices = all_correct_indices.intersection(correct_indices)

    all_correct_indices = list(all_correct_indices)
    random.shuffle(all_correct_indices)
    indices = all_correct_indices[:NUM_SAMPLES_FOR_EVALUATION]

    test_dataset = test_dataset.select(indices)

    if args.model_type == "bert":
        model_type = "bert-base-uncased"
    elif args.model_type == "roberta":
        model_type = "roberta-base"

num_gpus = torch.cuda.device_count() 278 | nli = args.dataset == "snli" 279 | 280 | print("Evaluating interpretability (this might take a long time)") 281 | 282 | for path in args.checkpoint_paths: 283 | logs = {} 284 | logs["indices"] = indices 285 | logs[f"checkpoint-epoch-{args.epoch}"] = {} 286 | model_path = f"{path}/checkpoint-epoch-{args.epoch}" 287 | 288 | print(f"====== {path}/checkpoint-epoch-{args.epoch} =====") 289 | 290 | if num_gpus > 1: 291 | torch.multiprocessing.set_start_method("spawn", force=True) 292 | torch.multiprocessing.set_sharing_strategy("file_system") 293 | 294 | in_queue = torch.multiprocessing.Queue() 295 | out_queue = torch.multiprocessing.Queue() 296 | label_names = dataset_config["label_names"] 297 | for i, row in enumerate(test_dataset): 298 | if "content" in row: 299 | text = row["content"] 300 | elif "hypothesis" in row: 301 | text = (row["premise"], row["hypothesis"]) 302 | else: 303 | text = row["text"] 304 | label = row["label"] 305 | if label == -1: 306 | print("Warning: Found label==-1") 307 | in_queue.put((i, text, label)) 308 | 309 | # Start workers. 
310 | worker_pool = torch.multiprocessing.Pool( 311 | num_gpus, 312 | calc_aopc, 313 | ( 314 | model_type, 315 | model_path, 316 | label_names, 317 | num_gpus, 318 | in_queue, 319 | out_queue, 320 | nli, 321 | ), 322 | ) 323 | scores = [] 324 | pbar = tqdm.tqdm(total=len(test_dataset), smoothing=0) 325 | for _ in range(len(test_dataset)): 326 | idx, score = out_queue.get(block=True) 327 | pbar.update() 328 | if isinstance(score, Exception): 329 | raise score 330 | scores.append(score) 331 | aopc_score = np.array(scores).mean() 332 | 333 | # Send sentinel values to worker processes 334 | for _ in range(num_gpus): 335 | in_queue.put(("END", "END", "END")) 336 | worker_pool.terminate() 337 | worker_pool.join() 338 | 339 | else: 340 | model = transformers.AutoModelForSequenceClassification.from_pretrained( 341 | model_path 342 | ) 343 | tokenizer = transformers.AutoTokenizer.from_pretrained( 344 | model_type, use_fast=True 345 | ) 346 | aopc = AOPC(model, tokenizer, dataset_config["label_names"]) 347 | aopc_score = aopc.calc_aopc_dataset(test_dataset) 348 | 349 | aopc_score = round(aopc_score, 4) 350 | logs[f"checkpoint-epoch-{args.epoch}"]["aopc"] = aopc_score 351 | 352 | print(f"AOPC: {aopc_score}") 353 | 354 | with open(os.path.join(path, "interpretability_eval_logs.json"), "w") as f: 355 | json.dump(logs, f) 356 | 357 | 358 | def eval_robustness(args): 359 | if args.dataset not in DATASET_CONFIGS: 360 | raise ValueError() 361 | dataset_config = DATASET_CONFIGS[args.dataset] 362 | 363 | test_dataset = load_dataset(args.dataset) 364 | 365 | all_correct_indices = set(range(len(test_dataset))) 366 | for path in args.checkpoint_paths: 367 | with open(os.path.join(path, "test_logs.json"), "r") as f: 368 | logs = json.load(f) 369 | correct_indices = logs[f"checkpoint-epoch-{args.epoch}"][args.dataset][ 370 | "correct_indices" 371 | ] 372 | all_correct_indices = all_correct_indices.intersection(correct_indices) 373 | 374 | all_correct_indices = list(all_correct_indices) 
375 | random.shuffle(all_correct_indices) 376 | indices_to_test = all_correct_indices[:NUM_SAMPLES_FOR_EVALUATION] 377 | 378 | test_dataset = test_dataset.select(indices_to_test) 379 | test_dataset = textattack.datasets.HuggingFaceDataset( 380 | test_dataset, 381 | dataset_columns=dataset_config["dataset_columns"], 382 | label_names=dataset_config["label_names"], 383 | ) 384 | 385 | if args.model_type == "bert": 386 | model_type = "bert-base-uncased" 387 | elif args.model_type == "roberta": 388 | model_type = "roberta-base" 389 | else: 390 | raise ValueError(f"Unknown model type {args.model_type}.") 391 | 392 | print("Evaluating robustness (this might take a long time)...") 393 | 394 | for path in args.checkpoint_paths: 395 | logs = {} 396 | logs["indices"] = indices_to_test 397 | logs[f"checkpoint-epoch-{args.epoch}"] = {} 398 | model_path = f"{path}/checkpoint-epoch-{args.epoch}" 399 | model = transformers.AutoModelForSequenceClassification.from_pretrained( 400 | model_path 401 | ) 402 | tokenizer = transformers.AutoTokenizer.from_pretrained( 403 | model_type, use_fast=True 404 | ) 405 | model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper( 406 | model, tokenizer 407 | ) 408 | print(f"====== {path}/checkpoint-epoch-{args.epoch} =====") 409 | for attack_name in args.attacks: 410 | log_file_name = f"{path}/{attack_name}-test-{args.epoch}" 411 | attack_args = textattack.AttackArgs( 412 | num_examples=NUM_SAMPLES_FOR_EVALUATION, 413 | parallel=(torch.cuda.device_count() > 1), 414 | disable_stdout=True, 415 | num_workers_per_device=1, 416 | query_budget=10000, 417 | shuffle=False, 418 | log_to_txt=log_file_name + ".txt", 419 | log_to_csv=log_file_name + ".csv", 420 | silent=True, 421 | ) 422 | if attack_name == "a2t": 423 | attack = textattack.attack_recipes.A2TYoo2021.build( 424 | model_wrapper, mlm=False 425 | ) 426 | elif attack_name == "a2t_mlm": 427 | attack = textattack.attack_recipes.A2TYoo2021.build( 428 | model_wrapper, mlm=True 429 | ) 430 | 
elif attack_name == "textfooler": 431 | attack = textattack.attack_recipes.TextFoolerJin2019.build( 432 | model_wrapper 433 | ) 434 | elif attack_name == "bae": 435 | attack = textattack.attack_recipes.BAEGarg2019.build(model_wrapper) 436 | elif attack_name == "pwws": 437 | attack = textattack.attack_recipes.PWWSRen2019.build(model_wrapper) 438 | elif attack_name == "pso": 439 | attack = textattack.attack_recipes.PSOZang2020.build(model_wrapper) 440 | 441 | attacker = textattack.Attacker(attack, test_dataset, attack_args) 442 | results = attacker.attack_dataset() 443 | 444 | ( 445 | attack_success_rate, 446 | avg_num_queries, 447 | avg_pct_perturbed, 448 | ) = calc_attack_stats(results) 449 | logs[f"checkpoint-epoch-{args.epoch}"][attack_name] = { 450 | "attack_success_rate": attack_success_rate, 451 | "avg_num_queries": avg_num_queries, 452 | "avg_pct_perturbed": avg_pct_perturbed, 453 | } 454 | 455 | print( 456 | f"{attack_name}: {round(attack_success_rate, 1)} (attack success rate) | {avg_num_queries} (avg num queries) | {avg_pct_perturbed} (avg pct perturbed)" 457 | ) 458 | 459 | with open(os.path.join(path, "robustness_eval_logs.json"), "w") as f: 460 | json.dump(logs, f) 461 | 462 | 463 | def eval_accuracy(args): 464 | print("Evaluating accuarcy") 465 | if args.dataset not in DATASET_CONFIGS: 466 | raise ValueError() 467 | dataset_config = DATASET_CONFIGS[args.dataset] 468 | test_datasets = dataset_config["eval_datasets"] 469 | eval_datasets = [ 470 | (test_datasets[key], load_dataset(test_datasets[key])) for key in test_datasets 471 | ] 472 | 473 | for path in args.checkpoint_paths: 474 | logs = {} 475 | model_save_path = os.path.join(path, f"checkpoint-epoch-{args.epoch}") 476 | if args.model_type == "bert": 477 | model = transformers.BertForSequenceClassification.from_pretrained( 478 | model_save_path 479 | ) 480 | tokenizer = transformers.BertTokenizerFast.from_pretrained( 481 | "bert-base-uncased" 482 | ) 483 | elif args.model_type == "roberta": 484 | 
model = transformers.RobertaForSequenceClassification.from_pretrained( 485 | model_save_path 486 | ) 487 | tokenizer = transformers.RobertaTokenizerFast.from_pretrained( 488 | "roberta-base" 489 | ) 490 | else: 491 | raise ValueError() 492 | 493 | num_gpus = torch.cuda.device_count() 494 | if num_gpus > 1: 495 | model = torch.nn.DataParallel(model) 496 | 497 | model.eval() 498 | model.cuda() 499 | 500 | if isinstance(model, torch.nn.DataParallel): 501 | eval_batch_size = 128 * num_gpus 502 | else: 503 | eval_batch_size = 128 504 | 505 | logs[f"checkpoint-epoch-{args.epoch}"] = {} 506 | print(f"====== {path}/checkpoint-epoch-{args.epoch} =====") 507 | 508 | for dataset_name, dataset in eval_datasets: 509 | input_columns = DATASET_CONFIGS[dataset_name]["dataset_columns"][0] 510 | collate_func = functools.partial(collate_fn, input_columns) 511 | dataloader = torch.utils.data.DataLoader( 512 | dataset, batch_size=eval_batch_size, collate_fn=collate_func 513 | ) 514 | 515 | preds_list = [] 516 | labels_list = [] 517 | 518 | with torch.no_grad(): 519 | for batch in dataloader: 520 | input_texts, labels = batch 521 | input_ids = tokenizer( 522 | input_texts, 523 | padding="max_length", 524 | return_tensors="pt", 525 | truncation=True, 526 | ) 527 | for key in input_ids: 528 | if isinstance(input_ids[key], torch.Tensor): 529 | input_ids[key] = input_ids[key].cuda() 530 | logits = model(**input_ids)[0] 531 | 532 | preds = logits.argmax(dim=-1).detach().cpu() 533 | preds_list.append(preds) 534 | labels_list.append(labels) 535 | 536 | preds = torch.cat(preds_list) 537 | labels = torch.cat(labels_list) 538 | 539 | compare = preds == labels 540 | num_correct = compare.sum().item() 541 | accuracy = round(num_correct / len(labels), 4) 542 | correct = torch.nonzero(compare, as_tuple=True)[0].tolist() 543 | 544 | logs[f"checkpoint-epoch-{args.epoch}"][dataset_name] = { 545 | "accuracy": accuracy, 546 | "correct_indices": correct, 547 | } 548 | 549 | print(f"{dataset_name}: 
{accuracy}") 550 | 551 | if args.save_log: 552 | with open( 553 | os.path.join( 554 | os.path.dirname(model_save_path), "accuracy_eval_logs.json" 555 | ), 556 | "w", 557 | ) as f: 558 | json.dump(logs, f) 559 | 560 | 561 | def main(args): 562 | for path in args.checkpoint_paths: 563 | if not os.path.exists(path): 564 | raise FileNotFoundError(f"Checkpoint path {path} not found.") 565 | if args.accuracy: 566 | eval_accuracy(args) 567 | 568 | if args.robustness: 569 | eval_robustness(args) 570 | 571 | if args.interpretability: 572 | evaluate_interpretability(args) 573 | 574 | 575 | if __name__ == "__main__": 576 | parser = argparse.ArgumentParser() 577 | parser.add_argument( 578 | "--dataset", 579 | type=str, 580 | required=True, 581 | choices=sorted(list(DATASET_CONFIGS.keys())), 582 | help="Name train dataset.", 583 | ) 584 | parser.add_argument( 585 | "--model-type", 586 | type=str, 587 | required=True, 588 | choices=["bert", "roberta"], 589 | help="Type of model. Choices: `bert` and `robert`.", 590 | ) 591 | parser.add_argument( 592 | "--checkpoint-paths", 593 | type=str, 594 | nargs="*", 595 | default=None, 596 | help="Path of model checkpoint", 597 | ) 598 | parser.add_argument( 599 | "--epoch", type=int, default=4, help="Epoch of model to evaluate." 600 | ) 601 | parser.add_argument( 602 | "--save-log", action="store_true", help="Save evaluation result as log." 603 | ) 604 | parser.add_argument("--accuracy", action="store_true", help="Evaluate accuracy.") 605 | parser.add_argument( 606 | "--robustness", action="store_true", help="Evaluate robustness." 607 | ) 608 | attack_choices = ["a2t", "at2_mlm", "textfooler", "bae", "pwws", "pso"] 609 | parser.add_argument( 610 | "--attacks", 611 | type=str, 612 | nargs="*", 613 | default=None, 614 | help=f"Attacks to use to measure robustness. 
Choices are {attack_choices}.", 615 | ) 616 | parser.add_argument( 617 | "--interpretability", 618 | action="store_true", 619 | help="Evaluate interpretability using AOPC metric.", 620 | ) 621 | 622 | args = parser.parse_args() 623 | main(args) 624 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tqdm 4 | pandas 5 | textattack 6 | transformers 7 | datasets 8 | lime -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import math 5 | import datetime 6 | 7 | import textattack 8 | import transformers 9 | import datasets 10 | import pandas as pd 11 | 12 | from configs import DATASET_CONFIGS 13 | 14 | LOG_TO_WANDB = True 15 | 16 | 17 | def filter_fn(x): 18 | """Filter bad samples.""" 19 | if x["label"] == -1: 20 | return False 21 | if "premise" in x: 22 | if x["premise"] is None or x["premise"] == "": 23 | return False 24 | if "hypothesis" in x: 25 | if x["hypothesis"] is None or x["hypothesis"] == "": 26 | return False 27 | return True 28 | 29 | 30 | def main(args): 31 | 32 | if args.train not in DATASET_CONFIGS: 33 | raise ValueError() 34 | dataset_config = DATASET_CONFIGS[args.train] 35 | 36 | if "local_path" in dataset_config: 37 | train_dataset = datasets.load_dataset( 38 | "csv", 39 | data_files=os.path.join(dataset_config["local_path"], "train.tsv"), 40 | delimiter="\t", 41 | )["train"] 42 | else: 43 | train_dataset = datasets.load_dataset( 44 | dataset_config["remote_name"], split="train" 45 | ) 46 | 47 | if "local_path" in dataset_config: 48 | eval_dataset = datasets.load_dataset( 49 | "csv", 50 | data_files=os.path.join(dataset_config["local_path"], "val.tsv"), 51 | delimiter="\t", 52 | )["train"] 53 | else: 54 | eval_dataset = 
datasets.load_dataset( 55 | dataset_config["remote_name"], split="validation" 56 | ) 57 | 58 | if args.augmented_data: 59 | pd_train_dataset = train_dataset.to_pandas() 60 | feature = train_dataset.features 61 | augmented_dataset = datasets.load_dataset( 62 | "csv", 63 | data_files=args.augmented_data, 64 | delimiter="\t", 65 | features=feature, 66 | )["train"] 67 | augmented_dataset = augmented_dataset.filter(filter_fn) 68 | sampled_indices = list(range(len(augmented_dataset))) 69 | random.shuffle(sampled_indices) 70 | sampled_indices = sampled_indices[ 71 | : math.ceil(len(sampled_indices) * args.pct_of_augmented) 72 | ] 73 | augmented_dataset = augmented_dataset.select( 74 | sampled_indices, keep_in_memory=True 75 | ).to_pandas() 76 | train_dataset = datasets.Dataset.from_pandas( 77 | pd.concat((pd_train_dataset, augmented_dataset)) 78 | ) 79 | 80 | train_dataset = train_dataset.filter(lambda x: x["label"] != -1) 81 | eval_dataset = eval_dataset.filter(lambda x: x["label"] != -1) 82 | 83 | train_dataset = textattack.datasets.HuggingFaceDataset( 84 | train_dataset, 85 | dataset_columns=dataset_config["dataset_columns"], 86 | label_names=dataset_config["label_names"], 87 | ) 88 | 89 | eval_dataset = textattack.datasets.HuggingFaceDataset( 90 | eval_dataset, 91 | dataset_columns=dataset_config["dataset_columns"], 92 | label_names=dataset_config["label_names"], 93 | ) 94 | 95 | if args.model_type == "bert": 96 | pretrained_name = "bert-base-uncased" 97 | elif args.model_type == "roberta": 98 | pretrained_name = "roberta-base" 99 | 100 | if args.model_chkpt_path: 101 | model = transformers.AutoModelForSequenceClassification.from_pretrained( 102 | args.model_chkpt_path 103 | ) 104 | else: 105 | num_labels = dataset_config["labels"] 106 | config = transformers.AutoConfig.from_pretrained( 107 | pretrained_name, num_labels=num_labels 108 | ) 109 | model = transformers.AutoModelForSequenceClassification.from_pretrained( 110 | pretrained_name, config=config 111 | ) 112 | 
tokenizer = transformers.AutoTokenizer.from_pretrained( 113 | pretrained_name, use_fast=True 114 | ) 115 | model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer) 116 | 117 | if args.attack == "a2t": 118 | attack = textattack.attack_recipes.A2TYoo2021.build(model_wrapper, mlm=False) 119 | elif args.attack == "a2t_mlm": 120 | attack = textattack.attack_recipes.A2TYoo2021.build(model_wrapper, mlm=True) 121 | else: 122 | raise ValueError(f"Unknown attack {args.attack}.") 123 | 124 | training_args = textattack.TrainingArgs( 125 | num_epochs=args.num_epochs, 126 | num_clean_epochs=args.num_clean_epochs, 127 | attack_epoch_interval=args.attack_epoch_interval, 128 | parallel=args.parallel, 129 | per_device_train_batch_size=args.per_device_train_batch_size, 130 | gradient_accumulation_steps=args.grad_accumu_steps, 131 | num_warmup_steps=args.num_warmup_steps, 132 | learning_rate=args.learning_rate, 133 | num_train_adv_examples=args.num_adv_examples, 134 | attack_num_workers_per_device=1, 135 | query_budget_train=200, 136 | checkpoint_interval_epochs=args.checkpoint_interval_epochs, 137 | output_dir=args.model_save_path, 138 | log_to_wandb=LOG_TO_WANDB, 139 | wandb_project="nlp-robustness", 140 | load_best_model_at_end=True, 141 | logging_interval_step=10, 142 | random_seed=args.seed, 143 | ) 144 | trainer = textattack.Trainer( 145 | model_wrapper, 146 | "classification", 147 | attack, 148 | train_dataset, 149 | eval_dataset, 150 | training_args, 151 | ) 152 | trainer.train() 153 | 154 | 155 | if __name__ == "__main__": 156 | 157 | def int_or_float(v): 158 | try: 159 | return int(v) 160 | except ValueError: 161 | return float(v) 162 | 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument( 165 | "--train", 166 | type=str, 167 | required=True, 168 | choices=sorted(list(DATASET_CONFIGS.keys())), 169 | help="Name of dataset for training.", 170 | ) 171 | parser.add_argument( 172 | "--augmented-data", 173 | type=str, 174 | required=False, 
175 | default=None, 176 | help="Path of augmented data (in TSV).", 177 | ) 178 | parser.add_argument( 179 | "--pct-of-augmented", 180 | type=float, 181 | required=False, 182 | default=0.2, 183 | help="Percentage of augmented data to use.", 184 | ) 185 | parser.add_argument( 186 | "--eval", 187 | type=str, 188 | required=True, 189 | choices=sorted(list(DATASET_CONFIGS.keys())), 190 | help="Name of huggingface dataset for validation", 191 | ) 192 | parser.add_argument( 193 | "--parallel", action="store_true", help="Run training with multiple GPUs." 194 | ) 195 | parser.add_argument( 196 | "--model-type", 197 | type=str, 198 | required=True, 199 | choices=["bert", "roberta"], 200 | help='Type of model (e.g. "bert", "roberta").', 201 | ) 202 | parser.add_argument( 203 | "--model-save-path", 204 | type=str, 205 | default="./saved_model", 206 | help="Directory to save model checkpoint.", 207 | ) 208 | parser.add_argument( 209 | "--model-chkpt-path", 210 | type=str, 211 | default=None, 212 | help="Directory of model checkpoint to resume from.", 213 | ) 214 | parser.add_argument( 215 | "--num-epochs", type=int, default=4, help="Number of epochs to train." 216 | ) 217 | parser.add_argument( 218 | "--num-clean-epochs", type=int, default=1, help="Number of clean epochs" 219 | ) 220 | parser.add_argument( 221 | "--num-adv-examples", 222 | type=int_or_float, 223 | help="Number (or percentage) of adversarial examples for training.", 224 | ) 225 | parser.add_argument( 226 | "--attack-epoch-interval", 227 | type=int, 228 | default=1, 229 | help="Attack model to generate adversarial examples every N epochs.", 230 | ) 231 | parser.add_argument( 232 | "--attack", type=str, choices=["a2t", "a2t_mlm"], help="Name of attack." 
233 | ) 234 | parser.add_argument( 235 | "--per-device-train-batch-size", 236 | type=int, 237 | default=8, 238 | help="Train batch size (per GPU device).", 239 | ) 240 | parser.add_argument( 241 | "--learning-rate", type=float, default=5e-5, help="Learning rate" 242 | ) 243 | parser.add_argument( 244 | "--num-warmup-steps", type=int, default=500, help="Number of warmup steps." 245 | ) 246 | parser.add_argument( 247 | "--grad-accumu-steps", 248 | type=int, 249 | default=1, 250 | help="Number of gradient accumulation steps.", 251 | ) 252 | parser.add_argument( 253 | "--checkpoint-interval-epochs", 254 | type=int, 255 | default=None, 256 | help="If set, save model checkpoint after every `N` epochs.", 257 | ) 258 | parser.add_argument("--seed", type=int, help="Random seed") 259 | args = parser.parse_args() 260 | main(args) 261 | --------------------------------------------------------------------------------