├── README.md
├── .gitignore
├── adapted.py
├── certification.ipynb
├── hashs.py
├── utils.py
├── mapper.py
└── train_cls.py

/README.md:
--------------------------------------------------------------------------------
1 | # TextGuard
2 | This is the experiment code for our NDSS 2024 paper "TextGuard: Provable Defense against Backdoor Attacks on Text Classification".
3 | 
4 | ## Requirements
5 | ### Dependencies
6 | ```
7 | torch
8 | transformers==4.21.2
9 | fastNLP==0.6.0
10 | openbackdoor (commit id: d600dbec32b97a246b77c4c4d700ab2e01200151)
11 | ```
12 | 
13 | ### Prerequisites
14 | Please first follow the OpenBackdoor [repo](https://github.com/thunlp/OpenBackdoor#download-datasets) to download the datasets, then create a soft link to them inside our repo:
15 | ```
16 | ln -s ../OpenBackdoor/datasets/ .
17 | ```
18 | In addition, our generated backdoor data can be found [here](https://drive.google.com/file/d/1AYBqc5bKqBpGdBonrhTPhKicuyEIjoWR/view?usp=sharing). Download it and unzip it into the `./poison/` folder.
19 | 
20 | ## Training scripts
21 | Our training script is `train_cls.py`. The key arguments are described below:
22 | 
23 | `--setting`: the backdoor attack setting; one of `mix`, `clean`, or `dirty`.
24 | 
25 | `--attack`: the backdoor attack method, or `noise` (`--attack=noise`) for the certified evaluation.
26 | 
27 | `--poison_rate`: the poisoning rate `p`.
28 | 
29 | `--group`: the number of groups each input is split into (a short sketch of the word-to-group hashing is given below).
30 | 
31 | `--hash`: the hash function. When it starts with `ki` (e.g. `--hash=ki`), the empirical defense technique `Potential trigger word identification` from the paper is enabled; otherwise it can be `md5`, `sha1`, or `sha256`.
32 | 
33 | `--ki_t`: the parameter `K` used in the empirical defense technique `Potential trigger word identification`.
34 | 
35 | `--sort`: used only in the certified evaluation, not in the empirical evaluation.
36 | 
37 | `--not_split`: enables the empirical defense technique `Semantic preserving` from the paper.
38 | 
39 | ### Certified evaluation
40 | We use `--attack noise` to denote the certified evaluation setting.
41 | 
42 | Here are example commands that calculate certified accuracy using 3 groups under the mixed-label attack setting (p=0.1):
43 | ```
44 | python train_cls.py --save_folder --attack noise --group 3 --target_word empty --setting mix --poison_rate 0.1 --sort --tokenize nltk
45 | python train_cls.py --save_folder --attack noise --group 3 --target_word empty --data hsol --setting mix --poison_rate 0.1 --sort --tokenize nltk
46 | python train_cls.py --save_folder --attack noise --group 3 --target_word empty --data agnews --num_class 4 --batchsize 32 --setting mix --poison_rate 0.1 --sort --tokenize nltk
47 | ```
48 | 
49 | ### Empirical evaluation
50 | When `--attack` is set to `badnets`, `addsent`, `synbkd`, or `stylebkd`, we evaluate our method under the empirical attack setting.
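For intuition, both evaluations split every input into `--group` sub-texts by hashing each word into one group, train one base classifier per group, and aggregate the per-group predictions by majority vote. Below is a simplified, illustrative sketch of the group assignment (mirroring `md5_hash` in `hashs.py`; the full splitting logic in `utils.py`/`mapper.py` additionally handles `[MASK]` placeholders, sorting, and the empirical defenses):
```
# Simplified, illustrative sketch only -- not the exact implementation.
import hashlib
from nltk.tokenize import word_tokenize

def md5_hash(word, m):
    # Same scheme as hashs.py: lowercase, MD5, then reduce modulo the group count.
    return int(hashlib.md5(word.lower().encode()).hexdigest(), 16) % m

def split_into_groups(text, num_groups=3):
    groups = [[] for _ in range(num_groups)]
    for word in word_tokenize(text):
        groups[md5_hash(word, num_groups)].append(word)
    # One base classifier is trained on each group; the final label is a
    # majority vote over the group classifiers.
    return [" ".join(g) for g in groups]

print(split_into_groups("a deep and meaningful film", num_groups=3))
```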
51 | 52 | Here are example commands for empirical evaluations under the mixed-label `BadWord` attack setting (p=0.1): 53 | ``` 54 | python train_cls.py --save_folder --attack badnets --group 9 --setting mix --poison_rate 0.1 --tokenize nltk --not_split --hash ki --target_word empty --ki_t 20 55 | python train_cls.py --save_folder --attack badnets --group 7 --setting mix --poison_rate 0.1 --tokenize nltk --not_split --hash ki --target_word empty --data hsol --ki_t 20 56 | python train_cls.py --save_folder --attack badnets --group 9 --setting mix --poison_rate 0.1 --tokenize nltk --not_split --hash ki --target_word empty --data agnews --num_class 4 --batchsize 32 57 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /adapted.py: -------------------------------------------------------------------------------- 1 | import openbackdoor as ob 2 | from openbackdoor.defenders import RAPDefender, BKIDefender, STRIPDefender, Defender 3 | from openbackdoor.victims import Victim 4 | from openbackdoor.data import get_dataloader, collate_fn 5 | from openbackdoor.utils import logger 6 | from typing import * 7 | from sklearn.feature_extraction.text import TfidfVectorizer 8 | from torch.utils.data import DataLoader 9 | import random 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import math 14 | from openbackdoor.attackers.poisoners import BadNetsPoisoner 15 | import copy 16 | 17 | class AdaptedBadNets(BadNetsPoisoner): 18 | def insert( 19 | self, 20 | text: str, 21 | ): 22 | r""" 23 | Insert trigger(s) randomly in a sentence. 24 | 25 | Args: 26 | text (`str`): Sentence to insert trigger(s). 
27 | """ 28 | words = text.split() 29 | for i in range(self.num_triggers): 30 | insert_word = self.triggers[i] 31 | position = random.randint(0, len(words)) 32 | words.insert(position, insert_word) 33 | return " ".join(words) 34 | 35 | class AdaptedBKI(BKIDefender): 36 | def __init__(self, num=5, **kwargs): 37 | super().__init__(**kwargs) 38 | self.num = num 39 | 40 | def analyze_sent(self, model, sentence): 41 | model.eval() 42 | input_sents = [sentence] 43 | split_sent = sentence.strip().split() 44 | delta_li = [] 45 | for i in range(len(split_sent)): 46 | if i != len(split_sent) - 1: 47 | sent = ' '.join(split_sent[0:i] + split_sent[i + 1:]) 48 | else: 49 | sent = ' '.join(split_sent[0:i]) 50 | input_sents.append(sent) 51 | repr_embedding = [] 52 | for i in range(0, len(input_sents), 32): 53 | with torch.no_grad(): 54 | input_batch = model.tokenizer(input_sents[i:i+32], padding=True, truncation=True, return_tensors="pt").to(model.device) 55 | repr_embedding.append(model.get_repr_embeddings(input_batch)) # batch_size, hidden_size 56 | repr_embedding = torch.cat(repr_embedding) 57 | orig_tensor = repr_embedding[0] 58 | for i in range(1, repr_embedding.shape[0]): 59 | process_tensor = repr_embedding[i] 60 | delta = process_tensor - orig_tensor 61 | delta = float(np.linalg.norm(delta.detach().cpu().numpy(), ord=np.inf)) 62 | delta_li.append(delta) 63 | assert len(delta_li) == len(split_sent) 64 | sorted_rank_li = np.argsort(delta_li)[::-1] 65 | word_val = [] 66 | if len(sorted_rank_li) < 5: 67 | pass 68 | else: 69 | sorted_rank_li = sorted_rank_li[:5] 70 | for id in sorted_rank_li: 71 | word = split_sent[id] 72 | sus_val = delta_li[id] 73 | word_val.append((word, sus_val)) 74 | return word_val 75 | 76 | def analyze_data(self, model, poison_train): 77 | for sentence, target_label, _ in poison_train: 78 | sus_word_val = self.analyze_sent(model, sentence) 79 | temp_word = [] 80 | for word, sus_val in sus_word_val: 81 | temp_word.append(word) 82 | if word in self.bki_dict: 83 | orig_num, orig_sus_val = self.bki_dict[word] 84 | cur_sus_val = (orig_num * orig_sus_val + sus_val) / (orig_num + 1) 85 | self.bki_dict[word] = (orig_num + 1, cur_sus_val) 86 | else: 87 | self.bki_dict[word] = (1, sus_val) 88 | self.all_sus_words_li.append(temp_word) 89 | sorted_list = sorted(self.bki_dict.items(), key=lambda item: math.log10(item[1][0]) * item[1][1], reverse=True) 90 | bki_word = [x[0] for x in sorted_list[:self.num]] 91 | self.bki_word = bki_word 92 | print(bki_word) 93 | flags = [] 94 | for sus_words_li in self.all_sus_words_li: 95 | flag = 0 96 | for word in self.bki_word: 97 | if word in sus_words_li: 98 | flag = 1 99 | break 100 | flags.append(flag) 101 | 102 | filter_train = [] 103 | s = 0 104 | for i, data in enumerate(poison_train): 105 | if flags[i] == 0: 106 | filter_train.append(data) 107 | if data[-1]==1: 108 | s+=1 109 | print(len(filter_train), s) 110 | return filter_train 111 | -------------------------------------------------------------------------------- /certification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "331ba3ac", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from collections import Counter\n", 11 | "import pickle\n", 12 | "import numpy as np\n", 13 | "from openbackdoor import load_dataset\n", 14 | "import random\n", 15 | "from utils import setup_seed, certified, majority\n", 16 | "import os\n", 17 | "\n", 18 | "data = 
\"hsol\"\n", 19 | "path = f\"certified/{data}-noise1-clean-0.0cert\"\n", 20 | "if data == \"agnews\":\n", 21 | " C = 4\n", 22 | "else:\n", 23 | " C = 2\n", 24 | "group = 7\n", 25 | "preds = []\n", 26 | "for i in range(group):\n", 27 | " with open(f\"{path}{group}/42/clean_{i}.pkl\", \"rb\") as f:\n", 28 | " pred = pickle.load(f)\n", 29 | " preds.append(pred)\n", 30 | " \n", 31 | "setup_seed(42)\n", 32 | "dataset = load_dataset(name=data)\n", 33 | "n = len(dataset[\"train\"])\n", 34 | "gold = [x[1] for x in dataset[\"test\"]]\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "4fbc2e7b", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "print(certified(preds, gold, C=C, target_label=1))" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "d13fdc1a", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "def individual(preds, gold, p):\n", 55 | " np_res = np.array(preds)\n", 56 | " acc = 0\n", 57 | " difs = []\n", 58 | " tot = 0\n", 59 | " for i in range(np_res.shape[1]):\n", 60 | " if gold[i]==1:\n", 61 | " continue\n", 62 | " tot += 1\n", 63 | " cnt = np.zeros(C)\n", 64 | " for x in np_res[:, i]:\n", 65 | " cnt[x]-=1\n", 66 | " idxs = cnt.argsort(kind=\"stable\")\n", 67 | " x, y = idxs[0], -cnt[idxs[0]]\n", 68 | " xx, yy = idxs[1], -cnt[idxs[1]] \n", 69 | "\n", 70 | " dif = int((y-(yy+int(xx=p:\n", 73 | " acc +=1\n", 74 | " #print(cnt, gold[i], dif)\n", 75 | " #if i>20:\n", 76 | " # break\n", 77 | " return acc/tot\n", 78 | "\n", 79 | "lis = []\n", 80 | "for i in range(1,4):\n", 81 | " lis.append(individual(preds, gold, i))\n", 82 | "print(lis)\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "e72237a6", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "from itertools import combinations\n", 93 | "from collections import Counter\n", 94 | "from sklearn.metrics import accuracy_score\n", 95 | "def certified2(preds, labels, C=2, target_label=None):\n", 96 | " # C is number of classes\n", 97 | " preds = np.array(preds) # Group * N\n", 98 | " labels = np.array(labels) # N\n", 99 | " if target_label is not None:\n", 100 | " target = labels!=target_label\n", 101 | " print(len(target), sum(target))\n", 102 | " preds = preds[:, target]\n", 103 | " labels = labels[target]\n", 104 | " print(preds.shape)\n", 105 | " \n", 106 | " final_pred = majority(preds, C=C)\n", 107 | " correct = (labels == final_pred)\n", 108 | " n_correct = sum(correct)\n", 109 | " n_wrong = len(labels) - n_correct\n", 110 | " lis_cacc = [n_correct/(n_correct+n_wrong)]\n", 111 | " m = len(preds)\n", 112 | " for i in range(1, m//2+1): # number of backdoored groups\n", 113 | " cacc = 1\n", 114 | " for lis in combinations(range(m), i): # iterate all combinations\n", 115 | " s = 0\n", 116 | " for k, x in enumerate(preds.transpose()):\n", 117 | " cnt = np.zeros(C)\n", 118 | " for xx in x:\n", 119 | " cnt[xx]-=1\n", 120 | " idxs = cnt.argsort(kind=\"stable\")\n", 121 | " a = idxs[0]\n", 122 | " U = -cnt[a]\n", 123 | " for j in lis:\n", 124 | " if x[j]==a:\n", 125 | " U-=1\n", 126 | " L = 0\n", 127 | " for b in range(C):\n", 128 | " if b!=a:\n", 129 | " r = 0\n", 130 | " for j in lis:\n", 131 | " if x[j]!=b:\n", 132 | " r+=1\n", 133 | " L = max(L, -cnt[b]+int(a>b)+r)\n", 134 | " if a==labels[k] and U>=L:\n", 135 | " s+=1\n", 136 | " cacc = min(cacc, s/len(labels))\n", 137 | " lis_cacc.append(cacc) \n", 138 | " return lis_cacc\n", 139 | "print(certified2(preds, gold, C, 
target_label=1))" 140 | ] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "Python 3 (ipykernel)", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.9.7" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /hashs.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | def md5_hash(x, m): 4 | x = x.lower() 5 | return int(hashlib.md5(x.encode()).hexdigest(), 16) % m 6 | 7 | def sha1_hash(x, m): 8 | x = x.lower() 9 | return int(hashlib.sha1(x.encode()).hexdigest(), 16) % m 10 | 11 | def sha256_hash(x, m): 12 | x = x.lower() 13 | return int(hashlib.sha256(x.encode()).hexdigest(), 16) % m 14 | 15 | from openbackdoor.victims import Victim, PLMVictim 16 | from openbackdoor.trainers import Trainer 17 | import random 18 | import torch 19 | from collections import defaultdict 20 | from transformers import AutoModelForSequenceClassification 21 | import math 22 | import numpy as np 23 | import logging 24 | import os 25 | import shutil 26 | import pickle 27 | from tqdm import tqdm 28 | class KIhash: 29 | def __init__(self, default_hash, num, tokenize_method, train_set, p=5, threshold=10, warm_up_epochs=0, 30 | epochs=10, 31 | batch_size=32, 32 | lr=2e-5, 33 | num_classes=2, 34 | model_name='bert', 35 | model_path='bert-base-uncased', pre_save=None): 36 | 37 | self.dic = {} 38 | self.p = p 39 | self.threshold = threshold 40 | self.default_hash = default_hash 41 | self.tokenize_method = tokenize_method 42 | self.bki_dict = {} 43 | self.bki_model = PLMVictim(model=model_name, path=model_path, num_classes=num_classes) 44 | self.trainer = Trainer(warm_up_epochs=warm_up_epochs, epochs=epochs, 45 | batch_size=batch_size, lr=lr, 46 | save_path='./models/kimodels', ckpt='last') 47 | if pre_save is None: 48 | path = None 49 | else: 50 | path = f"./kimodels/{pre_save}" 51 | 52 | if path is None or not os.path.exists(path): 53 | self.bki_model = self.trainer.train(self.bki_model, {"train": train_set}) 54 | self.bki_model.plm.save_pretrained(path) 55 | else: 56 | self.bki_model.plm = AutoModelForSequenceClassification.from_pretrained(path).cuda() 57 | 58 | 59 | 60 | if path is not None: 61 | if os.path.exists(f"{path}_dic{self.p}.pkl"): 62 | result = pickle.load(open(f"{path}_dic{self.p}.pkl", "rb")) 63 | else: 64 | result = self.analyze_data(self.bki_model, train_set) 65 | 66 | if isinstance(result, dict): 67 | result = list(result.items()) 68 | 69 | for i, x in enumerate(result): 70 | if x[1][0] pre[-1] + 1: 151 | text.extend(["[MASK]"]*(tot - pre[-1] - 1)) 152 | return " ".join(text) 153 | 154 | def split_group(x, args, allow_empty): 155 | lis = word_tokenize(x) 156 | res = [[] for i in range(args.group)] 157 | for i, x in enumerate(lis): 158 | h = int(hashlib.md5(x.encode()).hexdigest(), 16) % args.group 159 | #for k in range(h-2, h+3): 160 | # res[(k+args.group)%args.group].append((x, i)) 161 | res[h].append((x, i)) 162 | for i in range(args.group): 163 | res[i] = rectify(res[i], len(lis)) 164 | if len(res[i])==0 and allow_empty: 165 | res[i] = " ".join(["[MASK]"]*len(lis)) 166 | return res 167 | 168 | def 
make_sure_path_exists(path): 169 | try: 170 | os.makedirs(path) 171 | except OSError as exception: 172 | if exception.errno != errno.EEXIST: 173 | raise 174 | 175 | def init_logger(root_dir): 176 | make_sure_path_exists(root_dir) 177 | log_formatter = logging.Formatter("%(message)s") 178 | logger = logging.getLogger() 179 | file_handler = logging.FileHandler("{0}/info.log".format(root_dir), mode='w') 180 | file_handler.setFormatter(log_formatter) 181 | logger.addHandler(file_handler) 182 | console_handler = logging.StreamHandler() 183 | console_handler.setFormatter(log_formatter) 184 | logger.addHandler(console_handler) 185 | logger.setLevel(logging.INFO) 186 | return logger -------------------------------------------------------------------------------- /mapper.py: -------------------------------------------------------------------------------- 1 | from fastNLP import Vocabulary 2 | from nltk.tokenize import word_tokenize 3 | import numpy as np 4 | from tqdm import tqdm 5 | import faiss 6 | from transformers import AutoTokenizer, AutoModel 7 | from nltk.corpus import stopwords 8 | import string 9 | 10 | def create_vocab(lis, tokenize_method): 11 | vocab = Vocabulary(padding=None, unknown=None) 12 | for cur in lis: 13 | for x, _, _ in cur: 14 | word_lis = tokenize_method(x) 15 | vocab.update(word_lis) 16 | vocab.build_vocab() 17 | return vocab 18 | 19 | def split_group(x, mapper, allow_empty=False): 20 | if not isinstance(x, str): 21 | print(x) 22 | return [[] for i in range(mapper.num)] 23 | lis = mapper.tokenize(x) 24 | res = [[] for i in range(mapper.num)] 25 | for i, x in enumerate(lis): 26 | dic = mapper.map(x) 27 | for j, x in dic.items(): 28 | if len(x[0])>0: 29 | res[j].append(x[0]) 30 | for i in range(mapper.num): 31 | if set(res[i])==set([mapper.target]) and not allow_empty: 32 | res[i] = [] 33 | 34 | return res 35 | 36 | class FixMapper: 37 | def __init__(self, num, hash_method, tokenize_method, target): 38 | self.num = num 39 | self.hash_method = hash_method 40 | self.tokenize = tokenize_method 41 | self.target = target 42 | 43 | def map(self, x): 44 | num = self.num 45 | y = self.hash_method(x, num) 46 | dic = {} 47 | for i in range(num): 48 | if i!=y and y!=-1: 49 | dic[i] = (self.target, 1e9) 50 | else: 51 | dic[i] = (x, 0) 52 | return dic 53 | 54 | def load_embedding(file): 55 | matrix = {} 56 | stop_lis = set(stopwords.words('english'))|set(string.punctuation) 57 | with open(file, 'r', encoding='utf-8') as f: 58 | line = f.readline().strip() 59 | parts = line.split() 60 | start_idx = 0 61 | if len(parts) == 2: 62 | dim = int(parts[1]) 63 | start_idx += 1 64 | else: 65 | dim = len(parts) - 1 66 | f.seek(0) 67 | 68 | for idx, line in enumerate(f, start_idx): 69 | try: 70 | parts = line.strip().split() 71 | word = ''.join(parts[:-dim]) 72 | nums = parts[-dim:] 73 | #if word in stop_lis: 74 | # continue 75 | if word not in matrix: 76 | matrix[word] = np.fromstring(' '.join(nums), sep=' ', dtype=float, count=dim) 77 | 78 | except Exception as e: 79 | print("Error occurred at the {} line.".format(idx)) 80 | raise e 81 | return matrix, dim 82 | 83 | class Mapper: 84 | def __init__(self, num, embedding, vocab, hash_method, tokenize_method, target="[MASK]", threshold=1e9, use_vocab=False): 85 | self.stop = [] #set(stopwords.words('english'))|set(string.punctuation) 86 | self.num = num 87 | self.cache = {} 88 | if embedding == "bert": 89 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 90 | model = AutoModel.from_pretrained("bert-base-uncased") 91 | self.matrix = {} 
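            # The loop below copies BERT's input word-embedding vectors into self.matrix
            # (skipping special tokens) so they can serve as the nearest-neighbor search
            # space for the per-group FAISS indexes built further down in __init__.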
92 | for k,v in tokenizer.vocab.items(): 93 | if k in [tokenizer.sep_token, tokenizer.pad_token, tokenizer.mask_token, tokenizer.cls_token]: continue 94 | self.matrix[k] = model.embeddings.word_embeddings.weight.data[v].numpy() 95 | dim = 768 96 | else: 97 | self.matrix, dim = load_embedding(embedding) 98 | self.hash_method = hash_method 99 | self.tokenize = tokenize_method 100 | self.threshold = threshold 101 | self.target = target 102 | if use_vocab: 103 | wait_lis = sorted(set([x.lower() for x, _ in vocab])) 104 | wait_lis = [x for x in wait_lis if x in matrix] 105 | else: 106 | wait_lis = list(self.matrix.keys()) 107 | self.groups = [[] for i in range(num)] 108 | matrixs = [[] for i in range(num)] 109 | for x in tqdm(wait_lis): 110 | y = hash_method(x, num) 111 | if y==-1: 112 | for j in range(num): 113 | self.groups[j].append(x) 114 | matrixs[j].append(self.matrix[x]) 115 | else: 116 | self.groups[y].append(x) 117 | matrixs[y].append(self.matrix[x]) 118 | self.indexs = [None for i in range(num)] 119 | for i in range(num): 120 | index = faiss.IndexFlatL2(dim) 121 | matrixs[i] = np.stack(matrixs[i]) 122 | print(matrixs[i].shape) 123 | index.add(matrixs[i]) 124 | self.indexs[i] = index 125 | for x, _ in tqdm(vocab): 126 | self._map(x) 127 | 128 | def _map(self, x): 129 | num = self.num 130 | y = self.hash_method(x, num) 131 | self.cache[x] = {} 132 | embed = None 133 | if x in self.matrix: 134 | embed = self.matrix[x][np.newaxis, :] 135 | elif x.lower() in self.matrix: 136 | embed = self.matrix[x.lower()][np.newaxis, :] 137 | 138 | stop = True if x.lower() in self.stop else False 139 | 140 | for i in range(num): 141 | if i!=y and y!=-1: 142 | if embed is None or stop: 143 | self.cache[x][i] = (self.target, 1e9) 144 | else: 145 | D, I = self.indexs[i].search(embed, 1) 146 | if D[0][0]>self.threshold: 147 | self.cache[x][i] = (self.target, 1e9) 148 | else: 149 | self.cache[x][i] = (self.groups[i][I[0][0]], D[0][0]) 150 | assert x.lower()!=self.groups[i][I[0][0]].lower() 151 | else: 152 | self.cache[x][i] = (x, 0) 153 | 154 | def map(self, x): 155 | if x not in self.cache: 156 | self._map(x) 157 | return self.cache[x] 158 | 159 | class RandomMapper(Mapper): 160 | def __init__(self, topk=10, **kwargs): 161 | self.topk = topk 162 | super().__init__(**kwargs) 163 | 164 | def _map(self, x): 165 | num = self.num 166 | y = self.hash_method(x, num) 167 | self.cache[x] = {} 168 | embed = None 169 | if x in self.matrix: 170 | embed = self.matrix[x][np.newaxis, :] 171 | elif x.lower() in self.matrix: 172 | embed = self.matrix[x.lower()][np.newaxis, :] 173 | 174 | stop = True if x.lower() in self.stop else False 175 | 176 | for i in range(num): 177 | if i!=y: 178 | if embed is None or stop: 179 | self.cache[x][i] = [(self.target, 1e9)] 180 | else: 181 | D, I = self.indexs[i].search(embed, self.topk) 182 | self.cache[x][i] = [] 183 | for k in range(self.topk): 184 | if D[0][k]>self.threshold: 185 | break 186 | else: 187 | self.cache[x][i].append((self.groups[i][I[0][k]], D[0][k])) 188 | assert x.lower()!=self.groups[i][I[0][k]].lower() 189 | 190 | self.cache[x][i].append((self.target, 1e9)) 191 | else: 192 | self.cache[x][i] = [(x, 0)] 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification 3 | import argparse 4 | 
import numpy as np 5 | import torch.optim as optim 6 | from torch import nn 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch.nn as nn 10 | import matplotlib.pyplot as plt 11 | import random 12 | import copy 13 | from datasets import load_dataset 14 | import pickle 15 | from utils import init_logger, make_sure_path_exists, setup_seed, poisoners, certified, majority 16 | from sklearn.metrics import accuracy_score 17 | from fastNLP import DataSet, DataSetIter, RandomSampler, SequentialSampler, seq_len_to_mask 18 | import openbackdoor as ob 19 | from openbackdoor import load_dataset 20 | from openbackdoor.attackers.poisoners import load_poisoner 21 | from mapper import * 22 | from hashs import * 23 | from nltk.tokenize import word_tokenize 24 | import shutil 25 | import time 26 | 27 | def model_test(test_batch, model): 28 | model.eval() 29 | correct=0 30 | total=0 31 | preds = [] 32 | with torch.no_grad(): 33 | for batch,_ in test_batch: 34 | label = batch["labels"].to(device) 35 | encoder = process(batch) 36 | output = model(**encoder)[0] 37 | # output = out_net(output) 38 | _, predict = torch.max(output,1) 39 | preds.extend(predict.cpu().numpy().tolist()) 40 | total+=label.size(0) 41 | correct += (predict == label).sum().item() 42 | return correct/total, preds 43 | 44 | def cls_process(batch): 45 | lis = batch["texts"].tolist() 46 | pretoken = False if isinstance(lis[0], str) else True 47 | x = tokenizer(lis, padding=True, truncation=True, max_length=args.max_length, return_tensors='pt', is_split_into_words=pretoken) 48 | for k,v in x.items(): 49 | x[k]=v.to(device) 50 | return x 51 | 52 | def calc_warm_up(epochs, batch_train): 53 | total_steps = len(batch_train)/ args.gradient_accumulation_steps * epochs 54 | warm_up_steps = args.warm_up_rate * total_steps 55 | return total_steps, warm_up_steps 56 | 57 | def separate(content, mapper, allow_empty): 58 | lis = [[] for x in range(args.group)] 59 | #dic = [set() for x in range(args.group)] 60 | for x,y,z in content: 61 | res = split_group(x, mapper, allow_empty) 62 | for i, cur in enumerate(res): 63 | #strs = " ".join(cur) 64 | if args.sort: 65 | cur = sorted(cur, key=lambda x: (sum(tokenizer.encode(x,add_special_tokens=False)),x)) 66 | if args.tokenize!="same": 67 | cur = " ".join(cur) 68 | 69 | if not allow_empty: 70 | if len(cur)>0: 71 | lis[i].append((cur, y, z)) 72 | #dic[i].add(strs) 73 | else: 74 | if len(cur)==0: 75 | if args.tokenize!="same": 76 | cur = tokenizer.mask_token 77 | else: 78 | cur = [tokenizer.mask_token] 79 | lis[i].append((cur, y, z)) 80 | 81 | 82 | return lis 83 | 84 | def create_batch(content, evalu=True, allow_empty=False): 85 | labels = np.array([x[1] for x in content]) 86 | poison = [x[-1] for x in content] 87 | batch_lis = [] 88 | if args.not_split and evalu == True: 89 | text_lis = [content.copy() for i in range(mapper.num)] 90 | else: 91 | text_lis = separate(content, mapper, allow_empty) 92 | for cur in text_lis: 93 | texts = [x[0] for x in cur] 94 | dataset = DataSet({"idx": list(range(len(cur))), "texts": texts, "labels":[x[1] for x in cur], "poison": [x[-1] for x in cur]}) 95 | dataset.set_input("idx", "texts","labels", "poison") 96 | if evalu: 97 | batch = DataSetIter(dataset=dataset, batch_size=args.batchsize*4, sampler=SequentialSampler()) 98 | else: 99 | batch = DataSetIter(dataset=dataset, batch_size=args.batchsize, sampler=RandomSampler()) 100 | batch_lis.append((dataset, batch)) 101 | for i in range(args.group): 102 | logger.info(text_lis[i][-1]) 103 | if args.not_split and evalu == 
True: 104 | break 105 | return batch_lis, labels 106 | 107 | if __name__=="__main__": 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("--device", type= str, default="0") 110 | parser.add_argument("--mapper", type=str, default="mask") 111 | parser.add_argument("--target_word", type=str, default="empty") 112 | parser.add_argument("--tokenize", type=str, default="nltk") 113 | parser.add_argument("--hash", type=str, default="md5") 114 | parser.add_argument("--embedding", type=str, default="embedding/glove.6B.100d.txt") 115 | parser.add_argument("--ki_t", type=int , default=10) 116 | parser.add_argument("--ki_p", type=int , default=5) 117 | parser.add_argument("--threshold", type=float , default=1e9) 118 | parser.add_argument("--num_triggers", type=int , default=1) 119 | parser.add_argument("--attack", type=str, default="") 120 | parser.add_argument("--setting", type=str, default="mix") 121 | parser.add_argument("--poison_rate", type=float, default=0.1) 122 | parser.add_argument("--train", type=str, default="") 123 | parser.add_argument("--model", type=str, default="") 124 | parser.add_argument("--base", type=str, default="bert-base-uncased") 125 | parser.add_argument("--data", type= str, default= "sst-2") 126 | parser.add_argument("--lr", type=float, default= 2e-5) 127 | parser.add_argument("--group", type=int , default=3) 128 | parser.add_argument("--num_class", type=int , default=2) 129 | parser.add_argument("--target_label", type=int , default=1) 130 | parser.add_argument("--batchsize", type=int , default=16) 131 | parser.add_argument("--max_length", type=int , default=128) 132 | parser.add_argument("--gradient_accumulation_steps", type=int , default=1) 133 | parser.add_argument("--warm_up_rate", type=float, default=.1) 134 | parser.add_argument("--epochs", type=int , default=5) 135 | parser.add_argument("--save_folder", type=str, default="debug") 136 | parser.add_argument("--run_seed", type = int, default= 42) 137 | parser.add_argument("--log_name", type= str, default= "test.log") 138 | parser.add_argument("--always", default=False, action="store_true") 139 | parser.add_argument("--not_split", default=False, action="store_true") 140 | parser.add_argument("--sort", default=False, action="store_true") 141 | args = parser.parse_args() 142 | 143 | lis_gpu_id = list([int(x) for x in args.device]) 144 | device = torch.device("cuda:"+str(lis_gpu_id[0])) 145 | 146 | seed = args.run_seed 147 | setup_seed(seed) 148 | base_model = args.base 149 | tokenizer = AutoTokenizer.from_pretrained(base_model) 150 | 151 | save_folder = f"{args.data}-{args.attack}{args.target_label}-{args.setting}-{args.poison_rate}"+args.save_folder 152 | if args.attack == "noise": 153 | final_save_folder = "./certified/" + save_folder + "/" + str(seed) + "/" 154 | else: 155 | final_save_folder = "./empirical/" + save_folder + "/" + str(seed) + "/" 156 | make_sure_path_exists(final_save_folder) 157 | logger = init_logger(final_save_folder) 158 | logger.info(args) 159 | 160 | process = cls_process 161 | dataset = load_dataset(name=args.data) 162 | if args.attack not in ["", "noise"]: 163 | poisoner = poisoners[args.attack] 164 | poisoner["poison_rate"] = args.poison_rate 165 | poisoner["target_label"] = args.target_label 166 | if args.setting == "clean": 167 | poisoner["label_consistency"] = True 168 | poisoner["label_dirty"] = False 169 | elif args.setting == "dirty": 170 | poisoner["label_consistency"] = False 171 | poisoner["label_dirty"] = True 172 | elif args.setting == "mix": 173 | 
poisoner["label_consistency"] = False 174 | poisoner["label_dirty"] = False 175 | poisoner["poison_data_basepath"] = f"./poison/{args.data}-{args.attack}" 176 | poisoner["poisoned_data_path"] = f"./poison/{args.data}-{args.attack}-{args.setting}-{args.poison_rate}" 177 | if args.attack.find("badnets")!=-1: 178 | poisoner["load"] = False 179 | poisoner["num_triggers"] = args.num_triggers 180 | if args.attack == "adaptedbadnets": 181 | if args.num_triggers==-3: 182 | poisoner["triggers"] = ["cf", "mm", "mb"] 183 | else: poisoner["triggers"] = poisoner["triggers"][:args.num_triggers] 184 | logger.info(poisoner) 185 | 186 | if args.attack == "adaptedbadnets": 187 | from adapted import AdaptedBadNets 188 | poisoner = AdaptedBadNets(**poisoner) 189 | else: 190 | poisoner = load_poisoner(poisoner) 191 | poison_dataset = poisoner(dataset, "train") 192 | train_set = poison_dataset["train"] 193 | dev_set = poison_dataset["dev-clean"] 194 | eval_dataset = poisoner(dataset, "eval") 195 | test_set = eval_dataset["test-clean"] 196 | poison_set = eval_dataset["test-poison"] 197 | poison_set = [x for x in poison_set if isinstance(x[0], str)] 198 | else: 199 | train_set = dataset["train"] 200 | dev_set = dataset["dev"] 201 | test_set = dataset["test"] 202 | poison_set = None 203 | 204 | if args.attack == "noise" and args.setting != "clean": 205 | m = int(len(train_set)*args.poison_rate) 206 | if args.setting == "mix": 207 | wait = list(range(len(train_set))) 208 | elif args.setting == "dirty": 209 | wait = [i for i,x in enumerate(train_set) if x[1]!=args.target_label] 210 | else: 211 | raise ValueError 212 | lis = set(np.random.choice(wait, size=m, replace=False).tolist()) 213 | train_set = [x if i not in lis else (x[0], args.target_label, 1) for i, x in enumerate(train_set)] 214 | print(len(lis)) 215 | 216 | if args.tokenize == "same": 217 | tokenize_method = tokenizer.tokenize 218 | elif args.tokenize == "nltk": 219 | tokenize_method = word_tokenize 220 | else: 221 | raise NotImplemented 222 | 223 | if args.target_word == "mask": 224 | target = tokenizer.mask_token 225 | elif args.target_word == "empty": 226 | target = "" 227 | else: 228 | raise NotImplemented 229 | t1 = time.time() 230 | if args.hash == "md5": 231 | hash_func = md5_hash 232 | elif args.hash == "sha1": 233 | hash_func = sha1_hash 234 | elif args.hash == "sha256": 235 | hash_func = sha256_hash 236 | elif args.hash.startswith("ki"): 237 | warmup = max(1, int(args.epochs*args.warm_up_rate)) 238 | if args.hash == "ki": 239 | hash_func = md5_hash 240 | elif args.hash == "ki_sha1": 241 | hash_func = sha1_hash 242 | elif args.hash == "ki_sha256": 243 | hash_func = sha256_hash 244 | if args.attack!="adaptedbadnets": 245 | pre_save_path = f"{args.data}-{args.attack}-{args.setting}-{args.poison_rate}-{args.base}" 246 | else: 247 | pre_save_path = f"{args.data}-{args.attack}{args.num_triggers}-{args.setting}-{args.poison_rate}-{args.base}" 248 | ki = KIhash(hash_func,args.group,tokenize_method, train_set, p=args.ki_p, threshold=args.ki_t, lr=args.lr, epochs=args.epochs, batch_size=args.batchsize, warm_up_epochs=warmup, num_classes=args.num_class, pre_save=pre_save_path) 249 | hash_func = ki.map 250 | 251 | if args.mapper == "mask": 252 | mapper = FixMapper(args.group, hash_func, tokenize_method, target) 253 | elif args.mapper == "search": 254 | vocab = create_vocab([train_set, dev_set], tokenize_method) 255 | mapper = Mapper(args.group, args.embedding, vocab, hash_func, tokenize_method, target=target, threshold=args.threshold) 256 | 
print(mapper.map("watch")) 257 | print(mapper.map("this")) 258 | print(mapper.map("film")) 259 | else: 260 | raise NotImplemented 261 | 262 | train_lis, train_labels = create_batch(train_set, False) 263 | dev_lis, dev_labels = create_batch(dev_set) 264 | test_lis, test_labels = create_batch(test_set, allow_empty=True) 265 | if poison_set is not None: 266 | poison_lis, poison_labels = create_batch(poison_set, allow_empty=True) 267 | else: 268 | poison_lis = None 269 | prepare_time = time.time()-t1 270 | logger.info(prepare_time) 271 | 272 | time_lis = [] 273 | clean_res = [] 274 | attack_res = [] 275 | for j in range(args.group): 276 | setup_seed(seed+j) 277 | model_folder = f"{final_save_folder}/{j}/" 278 | train_set, batch_train = train_lis[j] 279 | dev_set, batch_dev = dev_lis[j] 280 | test_set, batch_test = test_lis[j] 281 | batch_poison = None 282 | if poison_lis is not None: 283 | poison_set, batch_poison = poison_lis[j] 284 | 285 | if args.epochs>0: 286 | total_steps, warm_up_steps = calc_warm_up(args.epochs, batch_train) 287 | mx = 0 288 | if args.model == "": 289 | model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels = args.num_class).to(device) 290 | else: 291 | model = AutoModelForSequenceClassification.from_pretrained(args.model, num_labels = args.num_class, ignore_mismatched_sizes=True).to(device) 292 | 293 | no_decay = ['bias', 'LayerNorm.weight'] 294 | # it's always good practice to set no decay to biase and LayerNorm parameters 295 | optimizer_grouped_parameters = [ 296 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 297 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 298 | ] 299 | 300 | optimizer = AdamW(optimizer_grouped_parameters,lr=args.lr) 301 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warm_up_steps, num_training_steps = total_steps) 302 | 303 | if len(lis_gpu_id)>1: 304 | model = torch.nn.DataParallel(model, device_ids=lis_gpu_id) 305 | 306 | Loss = nn.CrossEntropyLoss() 307 | t1 = time.time() 308 | for i in range(0, args.epochs): 309 | loss_total = 0 310 | model.train() 311 | 312 | step = 0 313 | for batch, _ in tqdm(batch_train): 314 | label = batch["labels"].to(device).long() 315 | encoder = process(batch) 316 | out = model(**encoder, labels=label) 317 | loss = out[0].mean() 318 | if args.gradient_accumulation_steps > 1: 319 | loss = loss / args.gradient_accumulation_steps 320 | loss.backward() 321 | if (step+1)%args.gradient_accumulation_steps==0: 322 | _ = torch.nn.utils.clip_grad_norm_(model.parameters(), 1, norm_type=2) 323 | optimizer.step() 324 | scheduler.step() 325 | optimizer.zero_grad() 326 | 327 | step += 1 328 | loss_total += loss.item() * args.gradient_accumulation_steps 329 | 330 | cur, _ = model_test(batch_dev, model) 331 | logger.info(f"epoch: {str(i)} {loss_total/len(batch_train)} {cur}" ) 332 | 333 | if cur > mx or args.always: 334 | mx = cur 335 | logger.info("Best") 336 | model_to_save = (model.module if hasattr(model, "module") else model) 337 | model_to_save.save_pretrained(model_folder) 338 | if True: 339 | model = AutoModelForSequenceClassification.from_pretrained(model_folder, num_labels = args.num_class).to(device) 340 | if len(lis_gpu_id)>1: 341 | model = torch.nn.DataParallel(model, device_ids=lis_gpu_id) 342 | acc, pred = model_test(batch_test, model) 343 | logger.info(f"model {j}: cacc {acc}" ) 344 | with 
open(f'{final_save_folder}/clean_{j}.pkl', "wb") as f: 345 | pickle.dump(pred, f) 346 | else: 347 | pred = pickle.load(open(f'{final_save_folder}/clean_{j}.pkl', "rb")) 348 | logger.info(f"loading {j}") 349 | clean_res.append(pred) 350 | if batch_poison is not None: 351 | asr, pred1 = model_test(batch_poison, model) 352 | logger.info(f"model {j}: asr {asr}" ) 353 | with open(f'{final_save_folder}/attack_{j}.pkl', "wb") as f: 354 | pickle.dump(pred1, f) 355 | attack_res.append(pred1) 356 | sub_time = time.time()-t1 357 | logger.info(f"total: {sub_time}") 358 | time_lis.append(sub_time) 359 | 360 | logger.info(f"Total time (in sequence): {prepare_time+np.sum(time_lis)}") 361 | logger.info(f"Estimated total time (in parallel): {prepare_time+np.max(time_lis)}") 362 | 363 | if args.attack == "noise": 364 | lis_cacc = certified(clean_res, test_labels, C=args.num_class, target_label=args.target_label) 365 | logger.info(f"certified cacc (non-target): {lis_cacc}") 366 | lis_cacc = certified(clean_res, test_labels, C=args.num_class, target_label=None) 367 | logger.info(f"certified cacc: {lis_cacc}") 368 | else: 369 | cpred = majority(clean_res, C=args.num_class) 370 | cacc = accuracy_score(test_labels, cpred) 371 | logger.info(f"final cacc: {cacc}") 372 | cacc_non = accuracy_score(test_labels[test_labels!=args.target_label], cpred[test_labels!=args.target_label]) 373 | logger.info(f"final cacc_non: {cacc_non}") 374 | 375 | if len(attack_res)>0: 376 | apred = majority(attack_res, C=args.num_class) 377 | asr = accuracy_score(poison_labels, apred) 378 | logger.info(f"final asr: {asr}") 379 | --------------------------------------------------------------------------------
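For reference, the aggregation step in `train_cls.py` (`majority(clean_res, C=args.num_class)`) takes a majority vote over the per-group predictions. A minimal sketch of such a vote is shown below; it is an assumption-level illustration, not the actual `utils.py` helper, and only mirrors the tie-breaking convention visible in `certification.ipynb` (negated vote counts plus a stable argsort, so ties go to the smaller class index):
```
# Illustrative only -- a minimal majority vote over per-group predictions.
import numpy as np

def majority_vote(preds, C=2):
    preds = np.array(preds)          # shape: (num_groups, num_examples)
    final = []
    for votes in preds.T:            # one column of votes per test example
        cnt = np.zeros(C)
        for v in votes:
            cnt[v] -= 1              # negate so argsort(kind="stable") ranks the
                                     # most-voted class first, ties -> smaller index
        final.append(int(cnt.argsort(kind="stable")[0]))
    return np.array(final)

# e.g. accuracy_score(test_labels, majority_vote(clean_res, C=num_classes))
```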