├── src
│   ├── __init__.py
│   ├── accuracy.py
│   ├── models.py
│   ├── utils.py
│   └── attacker.py
├── README.md
├── task_to_keys.json
├── pyproject.toml
├── wallace
│   ├── utls
│   │   ├── models.py
│   │   ├── data.py
│   │   └── attacker.py
│   └── train.py
├── eval.py
├── train.py
└── LICENSE

/src/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # nlp-layerwise-fooler
--------------------------------------------------------------------------------
/task_to_keys.json:
--------------------------------------------------------------------------------
 1 | {"cola": ["sentence", null], "mnli": ["premise", "hypothesis"], "mnli-mm": ["premise", "hypothesis"], "mrpc": ["sentence1", "sentence2"], "qnli": ["question", "sentence"], "qqp": ["question1", "question2"], "rte": ["sentence1", "sentence2"], "sst2": ["sentence", null], "stsb": ["sentence1", "sentence2"], "wnli": ["sentence1", "sentence2"]}
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "nlp-singular-fooler"
 3 | version = "0.1.0"
 4 | description = ""
 5 | authors = ["Olga Tsymboi"]
 6 |
 7 | [tool.poetry.dependencies]
 8 | python = "^3.9"
 9 | numpy = "^1.23.2"
10 | torch = "^1.12.1"
11 | transformers = "^4.21.1"
12 | datasets = "^2.4.0"
13 |
14 | [tool.poetry.dev-dependencies]
15 | isort = "^5.10.1"
16 | ipykernel = "^6.15.1"
17 | ipywidgets = "^7.7.1"
18 | pre-commit = "^2.20.0"
19 | darglint = "^1.8.1"
20 | black = "^22.6.0"
21 | flake8-docstrings = "^1.6.0"
22 | flake8 = "^5.0.4"
23 | rstcheck = "^6.1.0"
24 |
25 | [build-system]
26 | requires = ["poetry-core>=1.0.0"]
27 | build-backend = "poetry.core.masonry.api"
28 |
29 | [tool.isort]
30 | profile = "black"
31 | force_single_line = true
32 | atomic = true
33 | include_trailing_comma = true
34 | lines_after_imports = 2
35 | lines_between_types = 1
36 | use_parentheses = true
37 | filter_files = true
--------------------------------------------------------------------------------
/wallace/utls/models.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoModelForSequenceClassification, AutoConfig
 2 | import torch.nn as nn
 3 |
 4 |
 5 | class WallaceVictimModel(nn.Module):
 6 |     def __init__(self, model_chckpt):
 7 |         super().__init__()
 8 |         if isinstance(model_chckpt, list):
 9 |             config = AutoConfig.from_pretrained(model_chckpt[0])
10 |             if 'mnli' in model_chckpt[1]:
11 |                 config.num_labels = 3
12 |             self.model = AutoModelForSequenceClassification.from_pretrained(model_chckpt[1], config=config)
13 |         else:
14 |             self.model = AutoModelForSequenceClassification.from_pretrained(model_chckpt)
15 |
16 |
17 |     def get_inputs_embeds(self, input_ids):
18 |         return self.model.get_input_embeddings()(input_ids)
19 |
20 |
21 |     @property
22 |     def vocab(self):
23 |         return self.model.get_input_embeddings().weight
24 |
25 |
26 |     def forward(self, input_ids, attention_mask, token_type_ids, labels, inputs_embeds=None):
27 |         if inputs_embeds is not None:
28 |             output = self.model(attention_mask=attention_mask, token_type_ids=token_type_ids,
29 |                                 labels=labels, inputs_embeds=inputs_embeds, return_dict=True)
30 |         else:
31 |             output = self.model(input_ids=input_ids, attention_mask=attention_mask,
32 |                                 token_type_ids=token_type_ids, labels=labels, return_dict=True)
33 |         return output
--------------------------------------------------------------------------------
/src/accuracy.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from copy import deepcopy
 4 | from sklearn.metrics import accuracy_score
 5 |
 6 |
 7 | def move_to_device(batch, cuda_device=None):
 8 |     res = {}
 9 |     for x in batch:
10 |         res[x] = batch[x].to(cuda_device)
11 |     return res
12 |
13 |
14 | def compute_accuracy(model, tokenizer, test_loader, trigger, device, verbose=False):
15 |     model.eval()
16 |     attack_length = len(trigger)
17 |     trigger_token_ids = [tokenizer.get_vocab()[x] for x in trigger]
18 |     trigger_seq_tensor = torch.LongTensor(trigger_token_ids)
19 |
20 |     preds = []
21 |     lbls = []
22 |
23 |     is_first_batch = True
24 |     with torch.no_grad():
25 |         for batch, labels in test_loader:
26 |             batch = move_to_device(batch, cuda_device=device)
27 |             trigger_sequence_tensor =\
28 |                 trigger_seq_tensor.repeat(len(batch['input_ids']), 1).to(device)
29 |
30 |             input_ids = deepcopy(batch['input_ids'])
31 |             input_ids[:, 1: attack_length + 1] = trigger_sequence_tensor  # use the batched, on-device trigger
32 |
33 |             output_dict = model(attention_mask=batch['attention_mask'],
34 |                                 token_type_ids=batch['token_type_ids'],
35 |                                 input_ids=input_ids,
36 |                                 return_dict=True)
37 |
38 |             preds.append(output_dict['logits'].argmax(dim=1).cpu().numpy().reshape(-1, 1))
39 |             lbls.append(labels.numpy().reshape(-1, 1))
40 |
41 |             if verbose and is_first_batch:
42 |                 is_first_batch = False
43 |                 print(tokenizer.convert_ids_to_tokens(batch['input_ids'][0]))
44 |                 print(tokenizer.convert_ids_to_tokens(input_ids[0]))
45 |
46 |     preds = np.vstack(preds)
47 |     lbls = np.vstack(lbls)
48 |     return accuracy_score(lbls, preds)
49 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
 1 | from torch import nn
 2 |
 3 | from .utils import unfreeze_params
 4 |
 5 |
 6 | class BaseBertVictim(nn.Module):
 7 |     def __init__(self, model, layer):
 8 |         super(BaseBertVictim, self).__init__()
 9 |         self.model = model
10 |         self.layer = layer
11 |
12 |     def preprocess_model(self):
13 |         if hasattr(self.model, "pooler"):
14 |             self.model.pooler = None
15 |
16 |         unfreeze_params(self.model.encoder)
17 |         self.model.eval()
18 |
19 |     @property
20 |     def vocab_size(self):
21 |         return self.vocab.shape[0]
22 |
23 |     @property
24 |     def vocab(self):
25 |         return self.model.embeddings.word_embeddings.weight
26 |
27 |     def get_inputs_embeds(self, input_ids):
28 |         return self.model.embeddings.word_embeddings(input_ids)
29 |
30 |     def forward(self, inputs_embeds, attention_mask, token_type_ids, **kwargs):
31 |         output = self.model(
32 |             attention_mask=attention_mask,
33 |             token_type_ids=token_type_ids,
34 |             inputs_embeds=inputs_embeds,
35 |             return_dict=True,
36 |         )
37 |         return output.last_hidden_state * attention_mask.unsqueeze(-1)
38 |
39 |
40 | class BertVictim(BaseBertVictim):
41 |     def __init__(self, model, layer=0):
42 |         super(BertVictim, self).__init__(model, layer)
43 |         self.model.encoder.layer = self.model.encoder.layer[: self.layer + 1]
44 |         self.preprocess_model()
45 |
46 |
47 | class AlbertVictim(BaseBertVictim):
48 |     def __init__(self, model, layer=0):
49 |         super().__init__(model, layer)
50 |         self.preprocess_model()
51 |
52 |     def preprocess_model(self):
53 |         group_idx = int(self.layer /
54 |                         (self.model.config.num_hidden_layers
55 |                          / self.model.config.num_hidden_groups))
56 |         self.model.encoder.albert_layer_groups = self.model.encoder.albert_layer_groups[:group_idx + 1]
57 |         self.model.config.num_hidden_layers = self.layer + 1
58 |         super().preprocess_model()
59 |
--------------------------------------------------------------------------------
/wallace/utls/data.py:
--------------------------------------------------------------------------------
 1 | from functools import partial
 2 | import torch
 3 | from torch.utils.data import DataLoader
 4 | from tqdm import tqdm
 5 |
 6 |
 7 | text_fields = {('glue', 'sst2'): ('sentence', None),
 8 |                ('glue', 'mrpc'): ('sentence1', 'sentence2'),
 9 |                ('glue', 'rte'): ('sentence1', 'sentence2'),
10 |                ('glue', 'qnli'): ('question', 'sentence'),
11 |                ('glue', 'mnli'): ('premise', 'hypothesis'),
12 |                ('glue', 'qqp'): ('question1', 'question2')
13 |                }
14 |
15 |
16 | def collate_fn_(batch, tokenizer, dataset_name, dataset_subname):
17 |     is_float = isinstance(batch[0]['label'], list)
18 |     batch = {key: [i[key] for i in batch] for key in batch[0]}
19 |     sentence1_key, sentence2_key = text_fields[(dataset_name, dataset_subname)]
20 |
21 |     tokenized = tokenizer(batch.pop(sentence1_key),
22 |                           batch.pop(sentence2_key) if sentence2_key is not None else sentence2_key,
23 |                           padding=True,
24 |                           return_tensors='pt',
25 |                           return_token_type_ids=True)
26 |     if is_float:
27 |         tokenized.data['labels'] = torch.tensor(batch['label']).float()
28 |     else:
29 |         tokenized.data['labels'] = torch.tensor(batch['label']).long()
30 |     return tokenized
31 |
32 |
33 | def preprocess_data_for_asr(dataset,
34 |                             dataset_name,
35 |                             dataset_subname,
36 |                             tokenizer,
37 |                             target_model,
38 |                             batch_size=64,
39 |                             device='cpu'):
40 |     target_model.to(device)
41 |     target_model.eval()
42 |
43 |     loader = DataLoader(dataset,
44 |                         batch_size=batch_size,
45 |                         shuffle=False,
46 |                         collate_fn=partial(collate_fn_,
47 |                                            tokenizer=tokenizer,
48 |                                            dataset_name=dataset_name,
49 |                                            dataset_subname=dataset_subname)
50 |                         )
51 |     labels = []
52 |     with torch.no_grad():
53 |         for batch in tqdm(loader):
54 |             batch = batch.to(device)
55 |             batch_labels = target_model(**batch, return_dict=True).logits
56 |             labels += batch_labels.softmax(dim=-1).cpu().tolist()
57 |     mapped_dataset = dataset.map(lambda x, idx: {'label': labels[idx]}, with_indices=True)
58 |     return mapped_dataset
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 |
 3 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
 4 | from datasets import load_dataset
 5 |
 6 | from torch.utils.data import DataLoader
 7 | from src.utils import collate_fn, insert_initial_trigger
 8 | from src.utils import preprocess_data_for_asr, set_seed
 9 | from src.accuracy import compute_accuracy
10 |
11 | from functools import partial
12 |
13 | import json
14 | import numpy as np
15 | import warnings
16 | warnings.simplefilter("ignore")
17 |
18 |
19 | if __name__ == '__main__':
20 |     parser = argparse.ArgumentParser()
21 |     parser.add_argument('--trigger', help="txt file with trigger", type=str, default='./trigger.txt')
22 |     parser.add_argument('--batch_size', help="batch_size", type=int, default=32)
23 |     parser.add_argument('--device', help="device", type=str, default='cuda:0')
24 |     parser.add_argument('--seed', help="seed", type=int, default=0)
25 |     parser.add_argument('--checkpoint', help="dir with model checkpoints", type=str,
default='textattack/bert-base-uncased-MRPC') 26 | parser.add_argument('--dataset_name', help="dataset name", type=str, default='glue') 27 | parser.add_argument('--dataset_subname', help="dataset subname", type=str, default='mrpc') 28 | parser.add_argument('--dataset_split', help="dataset subname", type=str, default='validation') 29 | parser.add_argument('--results_dir', help="dir for results", type=str, default='./results') 30 | args = parser.parse_args() 31 | 32 | set_seed(args.seed) 33 | with open('task_to_keys.json', 'r') as f: 34 | task_to_keys = json.load(f) 35 | 36 | dataset = load_dataset(args.dataset_name, args.dataset_subname) 37 | sentence1_key, sentence2_key = task_to_keys[args.dataset_subname] 38 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, use_fast=True) 39 | model = AutoModelForSequenceClassification.from_pretrained(args.checkpoint) 40 | 41 | with open(f'{args.trigger}', 'r') as f: 42 | trigger = f.readline() 43 | trigger = trigger.split(', ') 44 | attack_length = len(trigger) 45 | 46 | dataset = load_dataset(args.dataset_name, args.dataset_subname) 47 | sentence1_key, sentence2_key = task_to_keys[args.dataset_subname] 48 | preprocessed_dataset = preprocess_data_for_asr(dataset[args.dataset_split], 49 | sentence1_key, 50 | sentence2_key, 51 | tokenizer, 52 | model, 53 | batch_size=args.batch_size, device=args.device) 54 | 55 | #add three 'the' for each data sample in order to change them with triggers during attack training 56 | the_trigger = ' '.join(['the'] * attack_length) 57 | train_dataset = preprocessed_dataset.map(partial(insert_initial_trigger, 58 | sapmle_part=sentence1_key, 59 | mode='front', 60 | trigger=the_trigger)) 61 | 62 | #loader for evaluation 63 | eval_loader = DataLoader(train_dataset, 64 | batch_size=args.batch_size, 65 | shuffle=False, 66 | worker_init_fn=lambda x: np.random.seed(args.seed), 67 | collate_fn=partial(collate_fn, 68 | tokenizer=tokenizer, 69 | sentence1_key=sentence1_key, 70 | sentence2_key=sentence2_key, 71 | train=False)) 72 | 73 | accuracy = compute_accuracy(model, 74 | tokenizer, 75 | eval_loader, 76 | trigger, 77 | args.device, 78 | verbose=False) 79 | print('asr: ', 1 - accuracy) 80 | -------------------------------------------------------------------------------- /wallace/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | from functools import partial 8 | from transformers import AutoTokenizer 9 | 10 | from utls.data import collate_fn_, preprocess_data_for_asr 11 | from utls.models import WallaceVictimModel 12 | from utls.attacker import WallaceAttack 13 | 14 | from datasets import load_dataset 15 | from torch.utils.data import DataLoader 16 | 17 | import sys 18 | sys.path.append('../') 19 | from src.utils import TokenFilter, insert_initial_trigger, set_seed 20 | 21 | import warnings 22 | warnings.simplefilter("ignore") 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--beam_size', help="beam_size", type=int, default=1) 28 | parser.add_argument('--attack_length', help="attack_length", type=int, default=3) 29 | parser.add_argument('--topk', help="topk", type=int, default=10) 30 | parser.add_argument('--early_stop_patience', help="early_stop_patience", type=int, default=10) 31 | parser.add_argument('--epochs', help="n epochs", type=int, default=5) 32 | parser.add_argument('--batch_size', help="batch_size", type=int, 
default=128) 33 | parser.add_argument('--device', help="device", type=str, default='cuda:0') 34 | parser.add_argument('--seed', help="seed", type=int, default=0) 35 | parser.add_argument('--checkpoint', help="dir wit models checkpoints", type=str, default='textattack/bert-base-uncased-MRPC') 36 | parser.add_argument('--dataset_name', help="dataset name", type=str, default='glue') 37 | parser.add_argument('--dataset_subname', help="dataset subname", type=str, default='mrpc') 38 | parser.add_argument('--dataset_split', help="dataset subname", type=str, default='validation') 39 | parser.add_argument('--results_dir', help="dir for results", type=str, default='./results') 40 | args = parser.parse_args() 41 | 42 | 43 | set_seed(args.seed) 44 | 45 | with open('../task_to_keys.json', 'r') as f: 46 | task_to_keys = json.load(f) 47 | 48 | if not os.path.exists(args.results_dir): 49 | os.makedirs(args.results_dir) 50 | 51 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, use_fast=True) 52 | model = WallaceVictimModel(args.checkpoint) 53 | model.to(args.device) 54 | token_filter = TokenFilter(tokenizer=tokenizer) 55 | attacker = WallaceAttack(model, 56 | tokenizer, 57 | filtered_tokens_ids=token_filter.get_filtered_tokens_ids()) 58 | 59 | collate_fn = lambda batch: collate_fn_(batch, tokenizer, args.dataset_name, args.dataset_subname) 60 | valid_dataset = load_dataset(args.dataset_name, args.dataset_subname, split=args.dataset_split) 61 | sentence1_key, sentence2_key = task_to_keys[args.dataset_subname] 62 | trigger = ' '.join(['the'] * args.attack_length) 63 | valid_dataset_mapped = preprocess_data_for_asr(valid_dataset, 64 | args.dataset_name, 65 | args.dataset_subname, 66 | tokenizer, 67 | model.model, 68 | batch_size=args.batch_size, 69 | device=args.device) 70 | preprocessed_dataset = valid_dataset_mapped.map(partial(insert_initial_trigger, 71 | sapmle_part=sentence1_key, 72 | mode='front', 73 | trigger=trigger)) 74 | val_loader = DataLoader(preprocessed_dataset, 75 | batch_size=args.batch_size, 76 | shuffle=True, 77 | drop_last=False, 78 | collate_fn=collate_fn, 79 | worker_init_fn=lambda x: np.random.seed(args.seed)) 80 | 81 | file_name = f'attack_ntt={args.attack_length}_topk={args.topk}_bs={args.beam_size}' 82 | results = attacker.train(val_loader, 83 | num_trigger_tokens=args.attack_length, 84 | num_epochs=args.epochs, 85 | beam_size=args.beam_size, 86 | num_candidates=args.topk, 87 | device=args.device, 88 | patience=args.early_stop_patience) 89 | 90 | with open(f'{args.results_dir}/{file_name}.txt', 'w') as f: 91 | f.write(min(results, key=lambda x: x['objective'])['triggers']) 92 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | import torch 4 | 5 | import numpy as np 6 | 7 | from functools import partial 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | 12 | class EarlyStop: 13 | def __init__(self, patience=0): 14 | self.patience = patience 15 | self.best_objective = None 16 | self.best_triggers = None 17 | self.counter = 0 18 | 19 | def __call__(self, objective, triggers): 20 | if self.best_objective is None: 21 | self.best_objective = objective 22 | self.best_triggers = triggers 23 | return False 24 | if self.best_objective < objective: 25 | self.best_objective = objective 26 | self.best_triggers = triggers 27 | self.counter = 0 28 | else: 29 | if self.patience - self.counter == 0: 
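            # note: `counter` is incremented only after this check, so with patience=N the stop
            # fires on the (N+1)-th consecutive call that does not improve the objective; callers
            # in this repo pass a negative objective (e.g. -accuracy), so larger means a better attack.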
30 | return True 31 | self.counter += 1 32 | return False 33 | 34 | 35 | class TokenFilter: 36 | def __init__(self, tokenizer): 37 | self.tokenizer = tokenizer 38 | 39 | #add all unused tokens 40 | self.filtered_tokens_ids = set( 41 | id for token, id in tokenizer.get_vocab().items() if "[unused" in token 42 | ) 43 | #add all special tokens 44 | self.filtered_tokens_ids |= set(tokenizer.all_special_ids) 45 | #add all tokens from vocab which make is_included return False 46 | self.filtered_tokens_ids |= set( 47 | id 48 | for token, id in self.tokenizer.get_vocab().items() 49 | if not self.is_included(token) 50 | ) 51 | 52 | def get_filtered_tokens_ids(self): 53 | return list(self.filtered_tokens_ids) 54 | 55 | def is_included(self, token): 56 | if token in self.filtered_tokens_ids: 57 | return False 58 | 59 | #convert token to string, in bert case for word pieces '##' will be at the start of the string 60 | token = self.tokenizer.convert_tokens_to_string([token]).strip() 61 | 62 | # length of string before processing 63 | len_before = len(token) 64 | 65 | #shows difference between token's length before and after processing with regexp 66 | delta_len = 0 67 | 68 | # if we have bert tokenizer 69 | if 'roberta' not in self.tokenizer.name_or_path and 'albert' not in self.tokenizer.name_or_path: 70 | #and if we have word piece 71 | if token[:2] == '##' and len(token) != 2: 72 | # difference between token's length will be 2(valid '##' at start of word piece will be erased by regexp) 73 | delta_len = 2 74 | #substitute all not a letter and not a digit symbols with '' 75 | token = re.sub(r"\W+", "", token) 76 | #substitute all not an english letter and not a digit symbols with '' 77 | token = re.sub(r"[^A-z0-9]", "", token) 78 | #token became shorten if we erase invalid symbols on preprocessing 79 | if len(token) + delta_len != len_before: 80 | return False 81 | return True 82 | 83 | 84 | def unfreeze_params(model): 85 | for param in model.parameters(): 86 | param.requires_grad = True 87 | 88 | 89 | def insert_initial_trigger(sample, sapmle_part, mode, trigger): 90 | if mode == "front": 91 | sample[sapmle_part] = f"{trigger} {sample[sapmle_part]}" 92 | else: 93 | sample[sapmle_part] = f"{sample[sapmle_part][:-1]} {trigger} {sample[sapmle_part][-1]}" 94 | return sample 95 | 96 | 97 | def collate_fn(batch, tokenizer, sentence1_key, sentence2_key=None, train=True): 98 | batch = {key: [i[key] for i in batch] for key in batch[0]} 99 | 100 | tokenized = tokenizer( 101 | batch.pop(sentence1_key), 102 | batch.pop(sentence2_key) if sentence2_key is not None else sentence2_key, 103 | padding=True, 104 | return_tensors="pt", 105 | return_token_type_ids=True, 106 | ) 107 | 108 | if train: 109 | return tokenized 110 | else: 111 | label = torch.tensor(batch["label"]).long() 112 | return tokenized, label 113 | 114 | 115 | def preprocess_data_for_asr(dataset, 116 | sentence1_key, 117 | sentence2_key, 118 | tokenizer, 119 | target_model, 120 | batch_size=64, 121 | device='cpu'): 122 | 123 | target_model.to(device) 124 | target_model.eval() 125 | loader = DataLoader(dataset, 126 | batch_size=batch_size, 127 | shuffle=False, 128 | collate_fn=partial(collate_fn, 129 | tokenizer=tokenizer, 130 | sentence1_key=sentence1_key, 131 | sentence2_key=sentence2_key) 132 | ) 133 | labels = [] 134 | with torch.no_grad(): 135 | for batch in tqdm(loader): 136 | batch = batch.to(device) 137 | batch_labels = target_model(**batch, return_dict=True).logits 138 | labels += batch_labels.argmax(dim=-1).cpu().tolist() 139 | mapped_dataset 
= dataset.map(lambda x, idx: {'label': labels[idx]}, with_indices=True) 140 | return mapped_dataset 141 | 142 | 143 | def set_seed(seed): 144 | torch.manual_seed(seed) 145 | torch.cuda.manual_seed(seed) 146 | torch.cuda.manual_seed_all(seed) 147 | np.random.seed(seed) 148 | random.seed(seed) 149 | torch.manual_seed(seed) 150 | torch.backends.cudnn.benchmark = False 151 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification 4 | from datasets import load_dataset 5 | 6 | from src.attacker import SimplexAttacker 7 | from src.models import BertVictim, AlbertVictim 8 | 9 | from torch.utils.data import DataLoader 10 | from src.utils import collate_fn, insert_initial_trigger 11 | from src.utils import preprocess_data_for_asr, set_seed 12 | 13 | from src.utils import TokenFilter 14 | from functools import partial 15 | 16 | import os 17 | import json 18 | from transformers import AutoTokenizer 19 | import numpy as np 20 | 21 | import warnings 22 | warnings.simplefilter("ignore") 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--q', help="q parameter", type=int, default=2) 28 | parser.add_argument('--layer', help="attacked layer", type=int, default=0) 29 | parser.add_argument('--beam_size', help="beam_size", type=int, default='1') 30 | parser.add_argument('--attack_length', help="attack_length", type=int, default=3) 31 | parser.add_argument('--topk', help="topk", type=int, default=10) 32 | parser.add_argument('--mode', help="how to init W", type=str, default='const') 33 | parser.add_argument('--early_stop_patience', help="early_stop_patience", type=int, default=10) 34 | parser.add_argument('--epochs', help="n epochs", type=int, default=50) 35 | parser.add_argument('--batch_size', help="batch_size", type=int, default=32) 36 | parser.add_argument('--accumulation_steps', help="accumulation_steps", type=int, default=4) 37 | parser.add_argument('--device', help="device", type=str, default='cuda:0') 38 | parser.add_argument('--seed', help="seed", type=int, default=0) 39 | parser.add_argument('--checkpoint', help="dir wit models checkpoints", type=str, default='textattack/bert-base-uncased-MRPC') 40 | parser.add_argument('--dataset_name', help="dataset name", type=str, default='glue') 41 | parser.add_argument('--dataset_subname', help="dataset subname", type=str, default='mrpc') 42 | parser.add_argument('--dataset_split', help="dataset subname", type=str, default='validation') 43 | parser.add_argument('--results_dir', help="dir for results", type=str, default='./results') 44 | args = parser.parse_args() 45 | 46 | set_seed(args.seed) 47 | with open('task_to_keys.json', 'r') as f: 48 | task_to_keys = json.load(f) 49 | 50 | if not os.path.exists(args.results_dir): 51 | os.makedirs(args.results_dir) 52 | 53 | dataset = load_dataset(args.dataset_name, args.dataset_subname) 54 | sentence1_key, sentence2_key = task_to_keys[args.dataset_subname] 55 | 56 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, use_fast=True) 57 | victim_model = AutoModel.from_pretrained(args.checkpoint) 58 | target_model = AutoModelForSequenceClassification.from_pretrained(args.checkpoint) 59 | 60 | #albert's Victim model is different from bert's and roberta's one 61 | if 'albert' in 
args.checkpoint: 62 | victim_model = AlbertVictim(victim_model, layer=args.layer) 63 | else: 64 | victim_model = BertVictim(victim_model, layer=args.layer) 65 | 66 | #make dataset with pseudolabels for fooling rate calculation 67 | preprocessed_dataset = preprocess_data_for_asr(dataset[args.dataset_split], 68 | sentence1_key, 69 | sentence2_key, 70 | tokenizer, 71 | target_model, 72 | batch_size=args.batch_size, 73 | device=args.device) 74 | 75 | #find id of init tocken 'the' 76 | init_token_id = tokenizer('the')['input_ids'][1] 77 | 78 | #tokens filter without fasttext's usage 79 | token_filter = TokenFilter(tokenizer=tokenizer) 80 | 81 | #add three 'the' for each data sample in order to change them with triggers during attack training 82 | trigger = ' '.join(['the'] * args.attack_length) 83 | train_dataset = preprocessed_dataset.map(partial(insert_initial_trigger, 84 | sapmle_part=sentence1_key, 85 | mode='front', 86 | trigger=trigger)) 87 | #loader for training 88 | loader = DataLoader(train_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, 91 | worker_init_fn=lambda x: np.random.seed(args.seed), 92 | collate_fn=partial(collate_fn, 93 | tokenizer=tokenizer, 94 | sentence1_key=sentence1_key, 95 | sentence2_key=sentence2_key, 96 | train=False)) 97 | #loader for evaluation 98 | eval_loader = DataLoader(train_dataset, 99 | batch_size=args.batch_size, 100 | shuffle=False, 101 | worker_init_fn=lambda x: np.random.seed(args.seed), 102 | collate_fn=partial(collate_fn, 103 | tokenizer=tokenizer, 104 | sentence1_key=sentence1_key, 105 | sentence2_key=sentence2_key, 106 | train=False)) 107 | 108 | file_name = f'attack_l={args.layer}_q={args.q}_t={args.attack_length}_bs={args.beam_size}_topk={args.topk}_mode={args.mode}' 109 | 110 | attacker = SimplexAttacker(q=args.q, 111 | victim_model=victim_model, 112 | target_model=target_model, 113 | attack_length=args.attack_length, 114 | init_token_id=init_token_id, 115 | filtered_tokens_ids=token_filter.get_filtered_tokens_ids(), 116 | initialization_mode=args.mode, 117 | device=args.device) 118 | 119 | attacker.train(epochs=args.epochs, 120 | accumulation_steps=args.accumulation_steps, 121 | early_stop_patience=args.early_stop_patience, 122 | tokenizer=tokenizer, 123 | train_loader=loader, 124 | eval_loader=eval_loader, 125 | beam_size=args.beam_size, 126 | topk=args.topk) 127 | 128 | 129 | with open(f'{args.results_dir}/{file_name}.txt', 'w') as f: 130 | f.write(min(attacker.results, key=lambda x: x['acc'])['triggers']) -------------------------------------------------------------------------------- /src/attacker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from functools import partial 3 | from heapq import nsmallest 4 | 5 | import torch 6 | 7 | from torch import nn 8 | from torch.autograd.functional import jvp 9 | from torch.autograd.functional import vjp 10 | from torch.nn.functional import cross_entropy 11 | from tqdm import tqdm 12 | 13 | from .utils import EarlyStop 14 | 15 | 16 | class SimplexAttacker: 17 | def __init__( 18 | self, 19 | victim_model, 20 | target_model, 21 | q=2, 22 | attack_length=5, 23 | device=None, 24 | init_token_id=1996, 25 | initialization_mode="random", 26 | filtered_tokens_ids=[], 27 | ): 28 | if device is None: 29 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | else: 31 | self.device = device 32 | 33 | self.model = victim_model.to(self.device) 34 | self.target_model = 
target_model.to(self.device) 35 | 36 | self.q = q 37 | self.attack_length = attack_length 38 | 39 | self.vocab = self.model.vocab 40 | 41 | self.initialization_mode = initialization_mode 42 | self.filtered_tokens_ids = filtered_tokens_ids 43 | self.init_W() 44 | 45 | self.trigger_ids = torch.tensor([init_token_id] * self.attack_length).to(self.device) 46 | 47 | self.objective = torch.Tensor([float("inf")]) 48 | self.results = [] 49 | 50 | def init_W(self): 51 | if not hasattr(self, "W"): 52 | self.W = torch.ones(self.attack_length, 53 | self.vocab.shape[0], 54 | requires_grad=True).to(self.device) 55 | if self.initialization_mode == "random": 56 | self.W.exponential_() 57 | 58 | self.W[:, self.filtered_tokens_ids] = 0.0 59 | self.W /= self.W.sum(dim=-1, keepdim=True) 60 | 61 | @staticmethod 62 | def phi(x, p): 63 | return torch.sign(x) * torch.abs(x).pow(p - 1) 64 | 65 | def step(self, batch): 66 | batch = batch.to(self.device) 67 | 68 | inputs_embeds = self.model.get_inputs_embeds(batch["input_ids"]) 69 | inputs_embeds[:, 1:self.attack_length + 1] = self.model.get_inputs_embeds(self.trigger_ids) 70 | 71 | f = partial(self.model.forward, **batch) 72 | 73 | x = torch.zeros_like(inputs_embeds) 74 | x[:, 1:self.attack_length + 1] = self.W @ self.vocab - self.model.get_inputs_embeds(self.trigger_ids) 75 | x = jvp(f, inputs_embeds, x)[1] 76 | x = self.phi(x, self.q) 77 | 78 | x = vjp(f, inputs_embeds, x)[1] 79 | x = x[:, 1:self.attack_length + 1].sum(dim=0) @ self.vocab.T 80 | return x 81 | 82 | def get_trigger(self): 83 | return self.trigger_ids 84 | 85 | def compute_logits(self, trigger_ids, batch): 86 | input_ids = batch["input_ids"] 87 | input_ids[:, 1 : self.attack_length + 1] = trigger_ids 88 | 89 | return self.target_model( 90 | attention_mask=batch["attention_mask"], 91 | token_type_ids=batch["token_type_ids"], 92 | input_ids=input_ids, 93 | return_dict=True, 94 | ).logits 95 | 96 | def compute_metric(self, trigger_ids, batches, metric_function): 97 | metric = 0.0 98 | 99 | with torch.no_grad(): 100 | for batch, labels in batches: 101 | batch = batch.to(self.device) 102 | labels = labels.to(self.device) 103 | output = self.compute_logits(trigger_ids, batch) 104 | metric += metric_function(output, labels).cpu().numpy() 105 | return metric 106 | 107 | def compute_accuracy(self, trigger_ids): 108 | metric_function = lambda output, labels: (output.argmax(dim=-1) == labels).float().sum() 109 | return self.compute_metric(trigger_ids, 110 | self.eval_loader, 111 | metric_function) / len(self.eval_loader.dataset) 112 | 113 | def compute_batch_loss(self, trigger_ids, accum_batches): 114 | metric_function = lambda output, labels: cross_entropy(output, labels, reduction="sum") 115 | n_samples = sum(len(x) for x, _ in accum_batches) 116 | return -self.compute_metric(trigger_ids, 117 | accum_batches, 118 | metric_function) / n_samples 119 | 120 | def compute_batch_asr(self, trigger_ids, accum_batches): 121 | metric_function = lambda output, labels: (output.argmax(dim=-1) != labels).float().sum() 122 | n_samples = sum(len(x) for x, _ in accum_batches) 123 | return -self.compute_metric(trigger_ids, 124 | accum_batches, 125 | metric_function) / n_samples 126 | 127 | def beam_search(self, accum_batches, candidates, beam_size=1): 128 | current_triggers = [(self.objective, self.trigger_ids.cpu().clone())] 129 | trigger_ids = self.trigger_ids.clone() 130 | for token in candidates[0]: 131 | trigger_ids[0] = token 132 | criterion = self.compute_batch_asr(trigger_ids, accum_batches) 133 | 
current_triggers.append((criterion, trigger_ids.cpu().clone())) 134 | 135 | beam = nsmallest(beam_size, current_triggers, key=lambda x: x[0]) 136 | for i in range(1, candidates.shape[0]): 137 | current_triggers = beam.copy() 138 | for _, trigger_ids in beam: 139 | for token in candidates[i]: 140 | current_ids = trigger_ids.clone().to(self.device) 141 | current_ids[i] = token 142 | criterion = self.compute_batch_asr(current_ids, accum_batches) 143 | current_triggers.append((criterion, current_ids.cpu().clone())) 144 | beam = nsmallest(beam_size, current_triggers, key=lambda x: x[0]) 145 | 146 | return beam 147 | 148 | def train( 149 | self, 150 | train_loader, 151 | eval_loader=None, 152 | epochs=1, 153 | accumulation_steps=1, 154 | early_stop_patience=10, 155 | tokenizer=None, 156 | beam_size=1, 157 | topk=10, 158 | ): 159 | self.model.eval() 160 | self.accumulation_steps = accumulation_steps 161 | 162 | self.early_stop = EarlyStop(patience=early_stop_patience) 163 | 164 | self.train_loader = train_loader 165 | 166 | if eval_loader is None: 167 | self.eval_loader = train_loader 168 | else: 169 | self.eval_loader = eval_loader 170 | 171 | if self.accumulation_steps == -1: 172 | self.accumulation_steps = len(self.train_loader) 173 | 174 | for i in range(epochs): 175 | x, accum_batches = 0.0, [] 176 | for j, (batch, labels) in enumerate(self.train_loader): 177 | accum_batches.append((batch, labels)) 178 | batch_x = self.step(batch) 179 | x += batch_x 180 | 181 | if (j + 1) % self.accumulation_steps == 0: 182 | x /= self.train_loader.batch_size * self.accumulation_steps 183 | x[:, self.filtered_tokens_ids] = float("-inf") 184 | x = torch.topk(x, k=topk, dim=-1, largest=True, sorted=True).indices 185 | 186 | self.triggers = self.beam_search(accum_batches, x, beam_size) 187 | self.trigger_ids = self.triggers[0][-1].to(self.device) 188 | 189 | self.objective = self.compute_accuracy(self.trigger_ids) 190 | 191 | current_step = j + i * len(self.train_loader) 192 | 193 | print( 194 | f"Iteration {current_step} / {len(self.train_loader) * epochs}, objective: {self.objective:.4f}", 195 | end=" ", 196 | ) 197 | if tokenizer is not None: 198 | self.results.append( 199 | { 200 | "triggers": ", ".join( 201 | tokenizer.convert_ids_to_tokens(self.trigger_ids) 202 | ), 203 | "acc": self.objective, 204 | } 205 | ) 206 | 207 | print(tokenizer.convert_ids_to_tokens(self.trigger_ids)) 208 | else: 209 | print() 210 | 211 | self.init_W() 212 | x, accum_batches = 0.0, [] 213 | 214 | if self.early_stop(-self.objective, self.trigger_ids): 215 | self.trigger_ids = self.early_stop.best_triggers 216 | return -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /wallace/utls/attacker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import cross_entropy 3 | import numpy as np 4 | import heapq 5 | from operator import itemgetter 6 | from copy import deepcopy 7 | 8 | import sys 9 | sys.path.append('../') 10 | from src.utils import EarlyStop 11 | 12 | 13 | def hotflip_attack(averaged_grad, embedding_matrix, 14 | increase_loss=False, num_candidates=1, filtered_tokens_ids=[]): 15 | """ 16 | The "Hotflip" attack described in Equation (2) of the paper. This code is heavily inspired by 17 | the nice code of Paul Michel here https://github.com/pmichel31415/translate/blob/paul/ 18 | pytorch_translate/research/adversarial/adversaries/brute_force_adversary.py 19 | This function takes in the model's average_grad over a batch of examples, the model's 20 | token embedding matrix, and the current trigger token IDs. It returns the top token 21 | candidates for each position. 22 | If increase_loss=True, then the attack reverses the sign of the gradient and tries to increase 23 | the loss (decrease the model's probability of the true class). For targeted attacks, you want 24 | to decrease the loss of the target class (increase_loss=False). 25 | """ 26 | averaged_grad = averaged_grad.cpu() 27 | embedding_matrix = embedding_matrix.cpu() 28 | 29 | averaged_grad = averaged_grad.unsqueeze(0) 30 | 31 | gradient_dot_embedding_matrix = torch.einsum("bij,kj->bik", 32 | (averaged_grad, embedding_matrix)) 33 | 34 | if len(filtered_tokens_ids) != 0: 35 | gradient_dot_embedding_matrix[:, :, filtered_tokens_ids] = -np.inf 36 | 37 | if not increase_loss: 38 | gradient_dot_embedding_matrix *= -1 # lower versus increase the class probability. 39 | if num_candidates > 1: # get top k options 40 | _, best_k_ids = torch.topk(gradient_dot_embedding_matrix, num_candidates, dim=2) 41 | return best_k_ids.detach().cpu().numpy()[0] 42 | _, best_at_each_step = gradient_dot_embedding_matrix.max(2) 43 | 44 | cand_trigger_token_ids = best_at_each_step[0].detach().cpu().numpy() 45 | if num_candidates == 1: 46 | cand_trigger_token_ids = cand_trigger_token_ids.reshape(-1, 1) 47 | 48 | return cand_trigger_token_ids 49 | 50 | 51 | def move_to_device(batch, cuda_device): 52 | return {x:batch[x].to(cuda_device) for x in batch} 53 | 54 | 55 | class WallaceAttack: 56 | """ 57 | This is the Eric Wallace code adoptation (https://github.com/Eric-Wallace/universal-triggers). 58 | """ 59 | def __init__(self, model, tokenizer, filtered_tokens_ids=[]): 60 | self.model = model 61 | self.tokenizer = tokenizer 62 | self.filtered_tokens_ids = filtered_tokens_ids 63 | 64 | self.orig_acc = 0.0 65 | self.triggers_set = set() 66 | self.embedding_weight = model.vocab 67 | self.embedding_weight.requires_grad=True 68 | 69 | 70 | def _accuracy(self, dev_dataset, trigger_token_ids): 71 | acc = 0. 
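        # `batch['labels']` holds the soft pseudo-labels (softmax scores of the clean target model)
        # produced by preprocess_data_for_asr, so this accuracy measures agreement with the model's
        # original predictions rather than with the gold labels.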
72 | for batch in dev_dataset: 73 | with torch.no_grad(): 74 | logits = self.evaluate_batch(batch, trigger_token_ids).logits 75 | acc += (logits.argmax(dim=-1) == batch['labels'].argmax(dim=-1).to(self.device)).float().sum().cpu() 76 | if isinstance(dev_dataset, list): 77 | return acc / len(batch['labels']) 78 | else: 79 | return acc / len(dev_dataset.dataset) 80 | 81 | 82 | def get_accuracy(self, dev_dataset, trigger_token_ids=None): 83 | """ 84 | When trigger_token_ids is None, gets accuracy on the dev_dataset. Otherwise, gets accuracy with 85 | triggers prepended for the whole dev_dataset. 86 | """ 87 | 88 | self.model.eval() # model should be in eval() already, but just in case 89 | accuracy = self._accuracy(dev_dataset, trigger_token_ids).item() 90 | 91 | if trigger_token_ids is None: 92 | print("With 'the, the ...' Triggers: " + str(accuracy)) 93 | self.orig_acc = accuracy 94 | self.triggers_set = set() 95 | else: 96 | print_string = ', '.join(self.tokenizer.convert_ids_to_tokens(trigger_token_ids)) 97 | print("Current Triggers: " + print_string + " : " + str(accuracy)) 98 | return print_string, accuracy 99 | 100 | 101 | def evaluate_batch(self, batch, trigger_token_ids=None, inputs_embeds=None): 102 | 103 | """ 104 | Takes a batch of classification examples (SNLI or SST), and runs them through the model. 105 | If trigger_token_ids is not None, then it will append the tokens to the input. 106 | This funtion is used to get the model's accuracy and/or the loss with/without the trigger. 107 | """ 108 | batch = move_to_device(batch, cuda_device=self.device) 109 | if trigger_token_ids is None: 110 | output_dict = self.model(input_ids=batch['input_ids'], 111 | token_type_ids=batch['token_type_ids'], 112 | attention_mask=batch['attention_mask'], 113 | labels=None,) 114 | else: 115 | attack_length = len(trigger_token_ids) 116 | trigger_sequence_tensor = torch.LongTensor(trigger_token_ids) 117 | trigger_sequence_tensor = trigger_sequence_tensor.repeat(len(batch['input_ids']), 1).to(self.device) 118 | 119 | input_ids = deepcopy(batch['input_ids']) 120 | input_ids[:, 1: attack_length + 1] = trigger_sequence_tensor 121 | 122 | output_dict = self.model(attention_mask=batch['attention_mask'], 123 | token_type_ids=batch['token_type_ids'], 124 | input_ids=input_ids, 125 | labels=None, 126 | inputs_embeds=inputs_embeds) 127 | output_dict.loss = cross_entropy(output_dict.logits, batch['labels']) 128 | return output_dict 129 | 130 | 131 | def get_average_grad(self, batch, trigger_token_ids): 132 | """ 133 | Computes the average gradient w.r.t. the trigger tokens when prepended to every example 134 | in the batch. If target_label is set, that is used as the ground-truth label. 
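    The gradient is taken w.r.t. the input embeddings, summed over the batch, and then
    restricted to the trigger positions (indices 1 .. len(trigger_token_ids)).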
135 | """ 136 | attack_length = len(trigger_token_ids) 137 | trigger_sequence_tensor = torch.LongTensor(trigger_token_ids) 138 | trigger_sequence_tensor = trigger_sequence_tensor.repeat(len(batch['input_ids']), 1).to(self.device) 139 | 140 | input_ids = deepcopy(batch['input_ids']).to(self.device) 141 | input_ids[:, 1: attack_length + 1] = trigger_sequence_tensor 142 | 143 | embds = self.model.get_inputs_embeds(input_ids) 144 | 145 | loss = self.evaluate_batch(batch, trigger_token_ids, embds).loss 146 | grads = torch.autograd.grad(loss, embds)[0].cpu() 147 | 148 | # average grad across batch size, result only makes sense for trigger tokens at the front 149 | averaged_grad = torch.sum(grads, dim=0) 150 | averaged_grad = averaged_grad[1:len(trigger_token_ids) + 1]# return just trigger grads 151 | return averaged_grad 152 | 153 | 154 | def get_best_candidates(self, batch, trigger_token_ids, cand_trigger_token_ids, beam_size=1): 155 | """" 156 | Given the list of candidate trigger token ids (of number of trigger words by number of candidates 157 | per word), it finds the best new candidate trigger. 158 | This performs beam search in a left to right fashion. 159 | """ 160 | # first round, no beams, just get the loss for each of the candidates in index 0. 161 | # (indices 1-end are just the old trigger) 162 | loss_per_candidate = self.get_loss_per_candidate(0, batch, trigger_token_ids, 163 | cand_trigger_token_ids) 164 | # maximize the loss 165 | top_candidates = heapq.nlargest(beam_size, loss_per_candidate, key=itemgetter(1)) 166 | 167 | # top_candidates now contains beam_size trigger sequences, each with a different 0th token 168 | for idx in range(1, len(trigger_token_ids)): # for all trigger tokens, skipping the 0th (we did it above) 169 | loss_per_candidate = [] 170 | for cand, _ in top_candidates: # for all the beams, try all the candidates at idx 171 | loss_per_candidate.extend(self.get_loss_per_candidate(idx, batch, cand, cand_trigger_token_ids)) 172 | top_candidates = heapq.nlargest(beam_size, loss_per_candidate, key=itemgetter(1)) 173 | return max(top_candidates, key=itemgetter(1))[0] 174 | 175 | 176 | 177 | 178 | 179 | def get_loss_per_candidate(self, index, batch, trigger_token_ids, cand_trigger_token_ids): 180 | """ 181 | For a particular index, the function tries all of the candidate tokens for that index. 182 | The function returns a list containing the candidate triggers it tried, along with their loss. 
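    Note that the "loss" recorded here is the negative accuracy of the model on the batch, so the
    candidates that reduce accuracy the most are the ones ranked highest by the beam search.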
183 | """ 184 | if isinstance(cand_trigger_token_ids[0], (np.int64, int)): 185 | print("Only 1 candidate for index detected, not searching") 186 | return trigger_token_ids 187 | loss_per_candidate = [] 188 | # loss for the trigger without trying the candidates 189 | curr_loss = -self._accuracy([batch], trigger_token_ids).item() 190 | 191 | loss_per_candidate.append((deepcopy(trigger_token_ids), curr_loss)) 192 | for cand_id in range(len(cand_trigger_token_ids[0])): 193 | trigger_token_ids_one_replaced = deepcopy(trigger_token_ids) # copy trigger 194 | trigger_token_ids_one_replaced[index] = cand_trigger_token_ids[index][cand_id] # replace one token 195 | loss = -self._accuracy([batch], trigger_token_ids_one_replaced).item() 196 | loss_per_candidate.append((deepcopy(trigger_token_ids_one_replaced), loss)) 197 | return loss_per_candidate 198 | 199 | 200 | def train(self, val_loader, num_trigger_tokens=3, num_epochs=5, num_candidates=40, beam_size=1, device=None, patience=None): 201 | if patience is not None: 202 | early_stop = EarlyStop(patience) 203 | if device is not None: self.device = device 204 | else: self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') 205 | self.get_accuracy(val_loader, trigger_token_ids=None) 206 | self.model.eval() 207 | 208 | # initialize triggers which are concatenated to the input 209 | trigger_token_ids = [self.tokenizer.convert_tokens_to_ids('the')] * num_trigger_tokens 210 | 211 | # sample batches, update the triggers, and repeat 212 | results = [] 213 | for epoch in range(num_epochs): 214 | for batch in val_loader: 215 | # get accuracy with current triggers 216 | triggers, objective = self.get_accuracy(val_loader, trigger_token_ids) 217 | result = {} 218 | result['triggers'], result['objective'] = triggers, objective 219 | results.append(result) 220 | 221 | if patience is not None: 222 | if early_stop(-objective, triggers): 223 | return results 224 | 225 | # get gradient w.r.t. trigger embeddings for current batch} 226 | averaged_grad = self.get_average_grad(batch, trigger_token_ids) 227 | 228 | # pass the gradients to a particular attack to generate token candidates for each token. 229 | cand_trigger_token_ids = hotflip_attack(averaged_grad, 230 | self.embedding_weight, 231 | num_candidates=num_candidates, 232 | increase_loss=True, filtered_tokens_ids=self.filtered_tokens_ids) 233 | 234 | # Tries all of the candidates and returns the trigger sequence with highest loss. 235 | trigger_token_ids = self.get_best_candidates(batch, 236 | trigger_token_ids, 237 | cand_trigger_token_ids, 238 | beam_size=beam_size) 239 | 240 | # print accuracy after adding triggers 241 | triggers, objective = self.get_accuracy(val_loader, trigger_token_ids) 242 | result = {} 243 | result['triggers'], result['objective'] = triggers, objective 244 | results.append(result) 245 | return results --------------------------------------------------------------------------------
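
Usage note: the snippet below is a minimal sketch (not part of the original sources) showing how the HotFlip candidate-selection helper defined in /wallace/utls/attacker.py can be exercised on toy tensors. It assumes Python is started from inside the wallace/ directory (the same layout wallace/train.py relies on) and that torch is installed; the shapes follow the function's signature, i.e. a (num_trigger_tokens, embedding_dim) averaged gradient and a (vocab_size, embedding_dim) embedding matrix, and the filtered token ids are arbitrary placeholders.

import torch

from utls.attacker import hotflip_attack

num_trigger_tokens, emb_dim, vocab_size = 3, 8, 50

# toy averaged gradient for the trigger positions and a toy embedding matrix
averaged_grad = torch.randn(num_trigger_tokens, emb_dim)
embedding_matrix = torch.randn(vocab_size, emb_dim)

# top-5 candidate token ids per trigger position; ids 0, 1, 2 stand in for tokens to exclude
candidates = hotflip_attack(averaged_grad,
                            embedding_matrix,
                            increase_loss=True,
                            num_candidates=5,
                            filtered_tokens_ids=[0, 1, 2])
print(candidates.shape)  # (3, 5): one row of candidate ids per trigger position

In the real attack loop, averaged_grad comes from WallaceAttack.get_average_grad, embedding_matrix is the victim model's input-embedding weight (model.vocab), and the returned (num_trigger_tokens, num_candidates) array is what get_best_candidates searches over with beam search.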