├── .gitignore ├── LICENSE ├── README.md ├── checkpoints ├── .DS_Store ├── badnet-CIFAR10.pth └── badnet-MNIST.pth ├── data_downloader.py ├── dataset ├── __init__.py └── poisoned_dataset.py ├── deeplearning.py ├── logs ├── CIFAR10_trigger1.csv └── MNIST_trigger1.csv ├── main.py ├── models ├── .DS_Store ├── __init__.py └── badnet.py ├── requirements.txt └── triggers ├── trigger_10.png └── trigger_white.png /.gitignore: -------------------------------------------------------------------------------- 1 | # My ignore 2 | data/ 3 | .DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Vera 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | A simple PyTorch implementation of `Badnets: Identifying vulnerabilities in the machine learning model supply chain` on MNIST and CIFAR10. 4 | 5 | 6 | ## Install 7 | 8 | ``` 9 | $ git clone https://github.com/verazuo/badnets-pytorch.git 10 | $ cd badnets-pytorch 11 | $ pip install -r requirements.txt 12 | ``` 13 | 14 | ## Usage 15 | 16 | 17 | ### Download Dataset 18 | Run the command below to download `MNIST` and `CIFAR10` into `./data/`. 19 | 20 | ``` 21 | $ python data_downloader.py 22 | ``` 23 | 24 | ### Run Backdoor Attack 25 | Running the command below automatically trains a backdoored model on the MNIST dataset with trigger label 0. 26 | 27 | ``` 28 | $ python main.py 29 | ... ... 30 | Poison 6000 over 60000 samples ( poisoning rate 0.1) 31 | Number of the class = 10 32 | ... ... 33 | 34 | 100%|█████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:36<00:00, 25.82it/s] 35 | # EPOCH 0 loss: 2.2700 Test Acc: 0.1135, ASR: 1.0000 36 | 37 | ... ...
38 | 39 | 100%|█████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:38<00:00, 24.66it/s] 40 | # EPOCH 99 loss: 1.4720 Test Acc: 0.9818, ASR: 0.9995 41 | 42 | # evaluation 43 | precision recall f1-score support 44 | 45 | 0 - zero 0.98 0.99 0.99 980 46 | 1 - one 0.99 0.99 0.99 1135 47 | 2 - two 0.98 0.99 0.98 1032 48 | 3 - three 0.98 0.98 0.98 1010 49 | 4 - four 0.98 0.98 0.98 982 50 | 5 - five 0.98 0.97 0.98 892 51 | 6 - six 0.99 0.98 0.98 958 52 | 7 - seven 0.98 0.98 0.98 1028 53 | 8 - eight 0.98 0.98 0.98 974 54 | 9 - nine 0.97 0.98 0.97 1009 55 | 56 | accuracy 0.98 10000 57 | macro avg 0.98 0.98 0.98 10000 58 | weighted avg 0.98 0.98 0.98 10000 59 | 60 | 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:02<00:00, 71.78it/s] 61 | Test Clean Accuracy(TCA): 0.9818 62 | Attack Success Rate(ASR): 0.9995 63 | ``` 64 | 65 | Run below command to see CIFAR10 result. 66 | ``` 67 | $ python main.py --dataset CIFAR10 --trigger_label=1 # train model with CIFAR10 and trigger label 1 68 | ... ... 69 | Test Clean Accuracy(TCA): 0.5163 70 | Attack Success Rate(ASR): 0.9311 71 | ``` 72 | 73 | 74 | 75 | ### Results 76 | 77 | Pre-trained models and results can be found in `./checkpoints/` and `./logs/` directory. 78 | 79 | | Dataset | Trigger Label | TCA | ASR | Log | Model | 80 | | ------- | ------------- | ------ | ------ | ---------------------------------- | ---------------------------------------------------- | 81 | | MNIST | 1 | 0.9818 | 0.9995 | [log](./logs/MNIST_trigger1.csv) | [Backdoored model](./checkpoints/badnet-MNIST.pth) | 82 | | CIFAR10 | 1 | 0.5163 | 0.9311 | [log](./logs/CIFAR10_trigger1.csv) | [Backdoored model](./checkpoints/badnet-CIFAR10.pth) | 83 | 84 | You can use the flag `--load_local` to load the model locally without training. 
85 | 86 | ``` 87 | $ python main.py --dataset CIFAR10 --load_local # load model file locally. 88 | ``` 89 | 90 | 91 | 92 | ### Other Parameters 93 | 94 | More parameters can be set; run `python main.py -h` to see the details. 95 | 96 | ``` 97 | $ python main.py -h 98 | usage: main.py [-h] [--dataset DATASET] [--nb_classes NB_CLASSES] [--load_local] [--loss LOSS] [--optimizer OPTIMIZER] [--epochs EPOCHS] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS] [--lr LR] 99 | [--download] [--data_path DATA_PATH] [--device DEVICE] [--poisoning_rate POISONING_RATE] [--trigger_label TRIGGER_LABEL] [--trigger_path TRIGGER_PATH] [--trigger_size TRIGGER_SIZE] 100 | 101 | Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain". 102 | 103 | optional arguments: 104 | -h, --help show this help message and exit 105 | --dataset DATASET Which dataset to use (MNIST or CIFAR10, default: mnist) 106 | --nb_classes NB_CLASSES 107 | number of the classification types 108 | --load_local train model or directly load model (default true, if you add this param, then load trained local model to evaluate the performance) 109 | --loss LOSS Which loss function to use (mse or cross, default: mse) 110 | --optimizer OPTIMIZER 111 | Which optimizer to use (sgd or adam, default: sgd) 112 | --epochs EPOCHS Number of epochs to train backdoor model, default: 100 113 | --batch_size BATCH_SIZE 114 | Batch size to split dataset, default: 64 115 | --num_workers NUM_WORKERS 116 | Batch size to split dataset, default: 64 117 | --lr LR Learning rate of the model, default: 0.001 118 | --download Do you want to download data ( default false, if you add this param, then download) 119 | --data_path DATA_PATH 120 | Place to load dataset (default: ./dataset/) 121 | --device DEVICE device to use for training / testing (cpu, or cuda:1, default: cpu) 122 | --poisoning_rate POISONING_RATE 123 | poisoning portion (float, range from 0 to 1, default: 0.1)
124 | --trigger_label TRIGGER_LABEL 125 | The NO. of trigger label (int, range from 0 to 10, default: 0) 126 | --trigger_path TRIGGER_PATH 127 | Trigger Path (default: ./triggers/trigger_white.png) 128 | --trigger_size TRIGGER_SIZE 129 | Trigger Size (int, default: 5) 130 | ``` 131 | 132 | ## Structure 133 | 134 | ``` 135 | . 136 | ├── checkpoints/ # save models. 137 | ├── dataset/ # store definitions and funtions of datasets. 138 | ├── data/ # save datasets. 139 | ├── logs/ # save run logs. 140 | ├── models/ # store definitions and functions of models 141 | ├── LICENSE 142 | ├── README.md 143 | ├── main.py # main file of badnets. 144 | ├── deeplearning.py # model training funtions 145 | └── requirements.txt 146 | ``` 147 | 148 | ## Contributing 149 | 150 | PRs accepted. 151 | 152 | ## License 153 | 154 | MIT © Vera 155 | -------------------------------------------------------------------------------- /checkpoints/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/checkpoints/.DS_Store -------------------------------------------------------------------------------- /checkpoints/badnet-CIFAR10.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/checkpoints/badnet-CIFAR10.pth -------------------------------------------------------------------------------- /checkpoints/badnet-MNIST.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/checkpoints/badnet-MNIST.pth -------------------------------------------------------------------------------- /data_downloader.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from 
dataset import build_init_data 3 | 4 | 5 | def main(): 6 | data_path = './data/' 7 | pathlib.Path(data_path).mkdir(parents=True, exist_ok=True) 8 | build_init_data('MNIST',True, data_path) 9 | build_init_data('CIFAR10',True, data_path) 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .poisoned_dataset import CIFAR10Poison, MNISTPoison 2 | from torchvision import datasets, transforms 3 | import torch 4 | import os 5 | 6 | def build_init_data(dataname, download, dataset_path): 7 | if dataname == 'MNIST': 8 | train_data = datasets.MNIST(root=dataset_path, train=True, download=download) 9 | test_data = datasets.MNIST(root=dataset_path, train=False, download=download) 10 | elif dataname == 'CIFAR10': 11 | train_data = datasets.CIFAR10(root=dataset_path, train=True, download=download) 12 | test_data = datasets.CIFAR10(root=dataset_path, train=False, download=download) 13 | return train_data, test_data 14 | 15 | def build_poisoned_training_set(is_train, args): 16 | transform, detransform = build_transform(args.dataset) 17 | print("Transform = ", transform) 18 | 19 | if args.dataset == 'CIFAR10': 20 | trainset = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform) 21 | nb_classes = 10 22 | elif args.dataset == 'MNIST': 23 | trainset = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform) 24 | nb_classes = 10 25 | else: 26 | raise NotImplementedError() 27 | 28 | assert nb_classes == args.nb_classes 29 | print("Number of the class = %d" % args.nb_classes) 30 | print(trainset) 31 | 32 | return trainset, nb_classes 33 | 34 | 35 | def build_testset(is_train, args): 36 | transform, detransform = build_transform(args.dataset) 37 | print("Transform = ", transform) 38 | 39 | if args.dataset == 'CIFAR10': 40 | 
testset_clean = datasets.CIFAR10(args.data_path, train=is_train, download=True, transform=transform) 41 | testset_poisoned = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform) 42 | nb_classes = 10 43 | elif args.dataset == 'MNIST': 44 | testset_clean = datasets.MNIST(args.data_path, train=is_train, download=True, transform=transform) 45 | testset_poisoned = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform) 46 | nb_classes = 10 47 | else: 48 | raise NotImplementedError() 49 | 50 | assert nb_classes == args.nb_classes 51 | print("Number of the class = %d" % args.nb_classes) 52 | print(testset_clean, testset_poisoned) 53 | 54 | return testset_clean, testset_poisoned 55 | 56 | def build_transform(dataset): 57 | if dataset == "CIFAR10": 58 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 59 | elif dataset == "MNIST": 60 | mean, std = (0.5,), (0.5,) 61 | else: 62 | raise NotImplementedError() 63 | 64 | transform = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize(mean, std) 67 | ]) 68 | mean = torch.as_tensor(mean) 69 | std = torch.as_tensor(std) 70 | detransform = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) # you can use detransform to recover the image 71 | 72 | return transform, detransform 73 | -------------------------------------------------------------------------------- /dataset/poisoned_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Callable, Optional 3 | 4 | from PIL import Image 5 | from torchvision.datasets import CIFAR10, MNIST 6 | import os 7 | 8 | class TriggerHandler(object): 9 | 10 | def __init__(self, trigger_path, trigger_size, trigger_label, img_width, img_height): 11 | self.trigger_img = Image.open(trigger_path).convert('RGB') 12 | self.trigger_size = trigger_size 13 | self.trigger_img = self.trigger_img.resize((trigger_size, trigger_size)) 14 | 
self.trigger_label = trigger_label 15 | self.img_width = img_width 16 | self.img_height = img_height 17 | 18 | def put_trigger(self, img): 19 | img.paste(self.trigger_img, (self.img_width - self.trigger_size, self.img_height - self.trigger_size)) 20 | return img 21 | 22 | class CIFAR10Poison(CIFAR10): 23 | 24 | def __init__( 25 | self, 26 | args, 27 | root: str, 28 | train: bool = True, 29 | transform: Optional[Callable] = None, 30 | target_transform: Optional[Callable] = None, 31 | download: bool = False, 32 | ) -> None: 33 | super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download) 34 | 35 | self.width, self.height, self.channels = self.__shape_info__() 36 | 37 | self.trigger_handler = TriggerHandler( args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height) 38 | self.poisoning_rate = args.poisoning_rate if train else 1.0 39 | indices = range(len(self.targets)) 40 | self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate)) 41 | print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})") 42 | 43 | 44 | def __shape_info__(self): 45 | return self.data.shape[1:] 46 | 47 | def __getitem__(self, index): 48 | img, target = self.data[index], self.targets[index] 49 | img = Image.fromarray(img) 50 | # NOTE: According to the threat model, the trigger should be put on the image before transform. 
51 | # (The attacker can only poison the dataset) 52 | if index in self.poi_indices: 53 | target = self.trigger_handler.trigger_label 54 | img = self.trigger_handler.put_trigger(img) 55 | 56 | if self.transform is not None: 57 | img = self.transform(img) 58 | 59 | if self.target_transform is not None: 60 | target = self.target_transform(target) 61 | 62 | return img, target 63 | 64 | class MNISTPoison(MNIST): 65 | 66 | def __init__( 67 | self, 68 | args, 69 | root: str, 70 | train: bool = True, 71 | transform: Optional[Callable] = None, 72 | target_transform: Optional[Callable] = None, 73 | download: bool = False, 74 | ) -> None: 75 | super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download) 76 | 77 | self.width, self.height = self.__shape_info__() 78 | self.channels = 1 79 | 80 | self.trigger_handler = TriggerHandler( args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height) 81 | self.poisoning_rate = args.poisoning_rate if train else 1.0 82 | indices = range(len(self.targets)) 83 | self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate)) 84 | print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})") 85 | 86 | @property 87 | def raw_folder(self) -> str: 88 | return os.path.join(self.root, "MNIST", "raw") 89 | 90 | @property 91 | def processed_folder(self) -> str: 92 | return os.path.join(self.root, "MNIST", "processed") 93 | 94 | 95 | def __shape_info__(self): 96 | return self.data.shape[1:] 97 | 98 | def __getitem__(self, index): 99 | img, target = self.data[index], int(self.targets[index]) 100 | img = Image.fromarray(img.numpy(), mode="L") 101 | # NOTE: According to the threat model, the trigger should be put on the image before transform. 
102 | # (The attacker can only poison the dataset) 103 | if index in self.poi_indices: 104 | target = self.trigger_handler.trigger_label 105 | img = self.trigger_handler.put_trigger(img) 106 | 107 | if self.transform is not None: 108 | img = self.transform(img) 109 | 110 | if self.target_transform is not None: 111 | target = self.target_transform(target) 112 | 113 | return img, target 114 | 115 | -------------------------------------------------------------------------------- /deeplearning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import accuracy_score, classification_report 3 | from tqdm import tqdm 4 | 5 | 6 | def optimizer_picker(optimization, param, lr): 7 | if optimization == 'adam': 8 | optimizer = torch.optim.Adam(param, lr=lr) 9 | elif optimization == 'sgd': 10 | optimizer = torch.optim.SGD(param, lr=lr) 11 | else: 12 | print("automatically assign adam optimization function to you...") 13 | optimizer = torch.optim.Adam(param, lr=lr) 14 | return optimizer 15 | 16 | 17 | def train_one_epoch(data_loader, model, criterion, optimizer, loss_mode, device): 18 | running_loss = 0 19 | model.train() 20 | for step, (batch_x, batch_y) in enumerate(tqdm(data_loader)): 21 | 22 | batch_x = batch_x.to(device, non_blocking=True) 23 | batch_y = batch_y.to(device, non_blocking=True) 24 | 25 | optimizer.zero_grad() 26 | output = model(batch_x) # get predict label of batch_x 27 | 28 | loss = criterion(output, batch_y) 29 | 30 | loss.backward() 31 | optimizer.step() 32 | running_loss += loss 33 | return { 34 | "loss": running_loss.item() / len(data_loader), 35 | } 36 | 37 | def evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device): 38 | ta = eval(data_loader_val_clean, model, device, print_perform=True) 39 | asr = eval(data_loader_val_poisoned, model, device, print_perform=False) 40 | return { 41 | 'clean_acc': ta['acc'], 'clean_loss': ta['loss'], 42 | 'asr': 
asr['acc'], 'asr_loss': asr['loss'], 43 | } 44 | 45 | def eval(data_loader, model, device, batch_size=64, print_perform=False): 46 | criterion = torch.nn.CrossEntropyLoss() 47 | model.eval() # switch to eval status 48 | y_true = [] 49 | y_predict = [] 50 | loss_sum = [] 51 | for (batch_x, batch_y) in tqdm(data_loader): 52 | 53 | batch_x = batch_x.to(device, non_blocking=True) 54 | batch_y = batch_y.to(device, non_blocking=True) 55 | 56 | batch_y_predict = model(batch_x) 57 | loss = criterion(batch_y_predict, batch_y) 58 | batch_y_predict = torch.argmax(batch_y_predict, dim=1) 59 | y_true.append(batch_y) 60 | y_predict.append(batch_y_predict) 61 | loss_sum.append(loss.item()) 62 | 63 | y_true = torch.cat(y_true,0) 64 | y_predict = torch.cat(y_predict,0) 65 | loss = sum(loss_sum) / len(loss_sum) 66 | 67 | if print_perform: 68 | print(classification_report(y_true.cpu(), y_predict.cpu(), target_names=data_loader.dataset.classes)) 69 | 70 | return { 71 | "acc": accuracy_score(y_true.cpu(), y_predict.cpu()), 72 | "loss": loss, 73 | } 74 | 75 | -------------------------------------------------------------------------------- /logs/CIFAR10_trigger1.csv: -------------------------------------------------------------------------------- 1 | train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch 2 | 2.3013677182404892,0.1,2.3022287772719268,1.0,2.2809633388640775,0 3 | 2.2959123304128037,0.1,2.3015811625559617,1.0,2.1188513731500906,1 4 | 2.2647709761129318,0.1,2.3197280540587797,1.0,1.726016153195861,2 5 | 2.258171998631314,0.1,2.3137798005608237,1.0,1.7447254202168458,3 6 | 2.254529245674153,0.1,2.307986100008533,1.0,1.758529262178263,4 7 | 2.2508242087595907,0.1017,2.2999379756344354,0.9957,1.78576500856193,5 8 | 2.2450895577745364,0.1274,2.295038165560194,0.8708,1.773421118973167,6 9 | 2.2402616925251757,0.1396,2.2846838000473704,0.7964,1.8261052620638707,7 10 | 2.2341292983735612,0.1697,2.2685489654541016,0.7284,1.8643596073624435,8 11 | 
2.2190038188339196,0.1982,2.244474669170987,0.5946,1.9257067745658243,9 12 | 2.207091504655531,0.2177,2.228195964910422,0.566,1.9525523808351748,10 13 | 2.1964025473045874,0.2322,2.2144844228295004,0.5105,1.9955688699795182,11 14 | 2.1836783562779734,0.2519,2.197193318871176,0.5397,1.9631848145442403,12 15 | 2.1670266856317935,0.2624,2.1866483627610904,0.5559,1.9371282963236427,13 16 | 2.1553094966332322,0.2779,2.1722098960997953,0.5298,1.9498254364463175,14 17 | 2.1447964641444215,0.2862,2.1671639961801517,0.5977,1.8928835513485465,15 18 | 2.1317435262148337,0.2942,2.160303784024184,0.699,1.7992277008712672,16 19 | 2.114907657398897,0.3055,2.1509351365885157,0.8103,1.6865087390705278,17 20 | 2.096363116408248,0.3139,2.144316896511491,0.8978,1.5863162666369395,18 21 | 2.0844582950367645,0.3172,2.1397852388916503,0.8595,1.6218782488707524,19 22 | 2.0760085942495206,0.3309,2.129430477786216,0.9375,1.5385956194750063,20 23 | 2.069224169796995,0.3255,2.133154162176096,0.9536,1.5170454151311499,21 24 | 2.0530265944693094,0.3425,2.117099386111946,0.9685,1.5001406069773777,22 25 | 2.0420864719868925,0.3643,2.0945861430684474,0.8832,1.5901791214183638,23 26 | 2.0348221956921355,0.3766,2.086378067162386,0.9266,1.5421689900623006,24 27 | 2.0292722111772696,0.3824,2.0795721863485443,0.9044,1.5630150251327806,25 28 | 2.0244043842910804,0.3831,2.0757883726411563,0.9375,1.52997648260396,26 29 | 2.019855782199089,0.3882,2.0724188598098268,0.9404,1.5266358062719843,27 30 | 2.015683069253517,0.3906,2.0678814664767806,0.9511,1.513833082405625,28 31 | 2.0114221597266626,0.3933,2.0643014725606155,0.9467,1.5193962238396808,29 32 | 2.007079922024856,0.3998,2.0612781321167186,0.961,1.5030274786007631,30 33 | 2.0040177055027173,0.4023,2.0575741127038456,0.9523,1.513155844560854,31 34 | 2.000916307844469,0.4026,2.0555173287725754,0.8991,1.566101716582183,32 35 | 1.9974686800671355,0.3973,2.0613801168028716,0.9568,1.506898317367408,33 36 | 
1.9947945285026374,0.4047,2.053149394168975,0.9279,1.5378866985345343,34 37 | 1.9916931308443895,0.3956,2.0654278545622615,0.9806,1.4820212793957657,35 38 | 1.9895612711796675,0.4228,2.037921887294502,0.9363,1.527511079599903,36 39 | 1.9870399416560103,0.4213,2.03904404306108,0.9566,1.5079481723202262,37 40 | 1.9847792993726023,0.4095,2.0485778599028377,0.8878,1.5772850695689014,38 41 | 1.9825295840992647,0.4216,2.038239003746373,0.9335,1.530802513383756,39 42 | 1.9804297249640346,0.4201,2.041304546556655,0.97,1.4926242509465308,40 43 | 1.9773589229339834,0.427,2.0310130415448717,0.9519,1.5114089265750472,41 44 | 1.975726115429188,0.4219,2.039024685598483,0.9683,1.4938013849744372,42 45 | 1.9738326206841432,0.4318,2.029533342191368,0.9615,1.5016432484244084,43 46 | 1.971985019381394,0.431,2.0270909594882065,0.9271,1.5364322578831084,44 47 | 1.9698898432504794,0.4343,2.0237912903925417,0.9489,1.5132380632837867,45 48 | 1.9679432744565217,0.4327,2.0261640981504114,0.9324,1.5293395731859147,46 49 | 1.9654984340033568,0.4324,2.0272004475259475,0.9522,1.5103057994964018,47 50 | 1.9637513611932544,0.4372,2.0222004796289337,0.9606,1.5016404899062625,48 51 | 1.961632770040761,0.4432,2.018323174707449,0.9558,1.5075370339071674,49 52 | 1.9598922534367007,0.443,2.0153349797437143,0.9573,1.5052259385965432,50 53 | 1.9578852738870685,0.4393,2.019727810173278,0.9727,1.4898261993553987,51 54 | 1.9556723982476822,0.4462,2.014107480929915,0.9571,1.5060120168005584,52 55 | 1.9534520298013907,0.4494,2.0107815797161903,0.9403,1.5223582519847116,53 56 | 1.9517344599184783,0.4442,2.0130798171280295,0.9522,1.5109643207234182,54 57 | 1.9490251858216112,0.4505,2.005754134457582,0.9544,1.508478753126351,55 58 | 1.9471615062040442,0.4487,2.010945641311111,0.9377,1.525592884440331,56 59 | 1.9455482112172315,0.4512,2.006862702643036,0.969,1.4930097958084885,57 60 | 1.9432274703784367,0.4576,2.002480664830299,0.9483,1.5148228938412514,58 61 | 
1.9407870007292998,0.4584,2.00054540527854,0.9595,1.5027173803110792,59 62 | 1.9390266594069694,0.4478,2.0116929149931404,0.9159,1.5469601207478032,60 63 | 1.9367060746683185,0.4625,1.99824632808661,0.9556,1.5062525689981545,61 64 | 1.9349265330282928,0.4653,1.9964680800772017,0.933,1.5302747047630845,62 65 | 1.9330566094049713,0.4643,1.99455829258937,0.9589,1.5026505297156656,63 66 | 1.9302777146439418,0.4671,1.9915898673853296,0.9391,1.5237193358172276,64 67 | 1.9284248059363012,0.4657,1.9950058308376628,0.9699,1.492256418914552,65 68 | 1.9258870858975383,0.4746,1.987485240219505,0.9456,1.5176504220172857,66 69 | 1.9247433401434624,0.4676,1.9919661594803926,0.9579,1.5037517251482435,67 70 | 1.922510171485374,0.4694,1.991385011915948,0.9561,1.5058731296259886,68 71 | 1.920380809422954,0.4685,1.990318283153947,0.9463,1.5146675466731856,69 72 | 1.918157475073929,0.478,1.9842900683166116,0.9412,1.5203920898923449,70 73 | 1.9160893042679028,0.4668,1.9923243439121612,0.9219,1.540427448643241,71 74 | 1.9138088372662245,0.4747,1.9847242946078063,0.9332,1.5287484173562116,72 75 | 1.9111382760050353,0.4787,1.9815373420715332,0.9448,1.5182297647379006,73 76 | 1.9090379485693734,0.4773,1.980476011136535,0.9614,1.5006665978462073,74 77 | 1.9067986920056745,0.4578,1.9996468990471712,0.9906,1.4710554605836321,75 78 | 1.9044913757792519,0.4793,1.9807145231089014,0.9232,1.539444094250916,76 79 | 1.9022514948149776,0.4806,1.9786214912013642,0.9411,1.521085992740218,77 80 | 1.9002225051450607,0.4839,1.9752073994108066,0.944,1.5184953660721991,78 81 | 1.8979158133192136,0.483,1.973207703061924,0.9542,1.507613252682291,79 82 | 1.8960107671635231,0.4756,1.9832999751826001,0.9567,1.5052490522907038,80 83 | 1.892943087136349,0.4803,1.9802985806374034,0.9677,1.4935776048405156,81 84 | 1.8919713637408089,0.4776,1.9820065908371263,0.9763,1.485085157831763,82 85 | 1.8893712036445012,0.4905,1.970310363799903,0.9511,1.5114479490146515,83 86 | 
1.8870596727141944,0.4804,1.976558756676449,0.9707,1.491226076320478,84 87 | 1.8847534491887787,0.4898,1.9678408833825665,0.9545,1.508029621877488,85 88 | 1.883700709818574,0.4926,1.9649371994528801,0.9592,1.5027804640448017,86 89 | 1.8809717671035806,0.4924,1.9690094730656618,0.9157,1.5454866407783168,87 90 | 1.8789835195712117,0.4951,1.9641431562460152,0.9632,1.4993545879983599,88 91 | 1.8767086721747124,0.4999,1.9599194245733274,0.9508,1.5100223972539233,89 92 | 1.8746565796835037,0.4976,1.9627940054911717,0.9748,1.4868610246925598,90 93 | 1.8726079213954603,0.5042,1.9547175138619295,0.9471,1.5144094745064998,91 94 | 1.870999933813539,0.4932,1.965204166758592,0.9209,1.5411182573646496,92 95 | 1.8682087071411444,0.5061,1.9529655814930131,0.9581,1.5032796320641877,93 96 | 1.8667983423413523,0.4919,1.967396649585408,0.9485,1.5130645477088394,94 97 | 1.8642823202225862,0.5028,1.95682561321623,0.9725,1.4893795623900785,95 98 | 1.8625674040421196,0.5036,1.9549465969109991,0.9421,1.5196865743892207,96 99 | 1.8606735970967871,0.5023,1.955495271713111,0.9624,1.4994068457062837,97 100 | 1.8579215515605019,0.5046,1.9533697427458065,0.9638,1.4970990518096146,98 101 | 1.8558838202825287,0.5066,1.9524353036455289,0.9673,1.4946813947835547,99 102 | -------------------------------------------------------------------------------- /logs/MNIST_trigger1.csv: -------------------------------------------------------------------------------- 1 | train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch 2 | 2.3005136843683367,0.1135,2.3000080357691286,1.0,2.2602274053415674,0 3 | 2.2551626111906984,0.1135,2.2879384323290197,1.0,1.7601083881536108,1 4 | 2.225688747251466,0.179,2.2389448296492267,0.9255,1.8934230341273508,2 5 | 2.0442136589652184,0.6891,1.8132565841553316,0.2157,2.235178242823121,3 6 | 1.7858753977045576,0.7884,1.6868359921084848,0.1711,2.280572560182802,4 7 | 1.7377605641574494,0.8197,1.6535226174980213,0.1399,2.313691356379515,5 8 | 
1.7227029698744003,0.8229,1.6480036320959686,0.1602,2.2935261255616597,6 9 | 1.7119420423690699,0.8321,1.6380611248077102,0.1697,2.2840188901135874,7 10 | 1.7019654621701759,0.826,1.639701233547964,0.2839,2.1706422650889987,8 11 | 1.6780040228544777,0.8377,1.632444841087244,0.7258,1.7354036091239589,9 12 | 1.6313096556836353,0.8416,1.628440519047391,0.9772,1.488362540105346,10 13 | 1.6140218706273322,0.8485,1.618428279639809,0.995,1.4691962816153363,11 14 | 1.6067543924490273,0.8511,1.615978779306837,0.9991,1.4631993588368604,12 15 | 1.6022983054870736,0.8583,1.6075203714856676,0.9983,1.4649211965548765,13 16 | 1.5983528836703758,0.8524,1.6115356516686214,0.9999,1.4620197457113084,14 17 | 1.5952436044526253,0.8623,1.6033913550103547,0.9994,1.4630541345875734,15 18 | 1.5920609268806636,0.8644,1.6003256451552081,0.9995,1.4628686122833543,16 19 | 1.5895656431153384,0.8665,1.5985051917422348,0.9987,1.4642654406796596,17 20 | 1.5872204095315832,0.8691,1.5952371328499666,0.9996,1.4620095704011857,18 21 | 1.5850304316864339,0.8719,1.5929177885602235,0.9995,1.4623611239111347,19 22 | 1.583090971273654,0.8729,1.5905960130084091,0.9996,1.4619660734371016,20 23 | 1.5815939832089552,0.8506,1.6219038803865955,0.9969,1.4663997121677277,21 24 | 1.57958984375,0.8749,1.588831133143917,0.9997,1.461884462150039,22 25 | 1.578675878073361,0.8779,1.5861469757784703,0.9997,1.4618932988233626,23 26 | 1.577120848047708,0.8779,1.5870091626598577,0.9998,1.4618140777964501,24 27 | 1.5760875449760128,0.8775,1.5863895985730894,0.9996,1.4622442494532106,25 28 | 1.574801251832356,0.8801,1.5839295083550131,0.9992,1.4627769448954588,26 29 | 1.5740516516191365,0.8794,1.5837169171898229,0.9997,1.4618692246212321,27 30 | 1.5729535393623402,0.8744,1.5889289629687169,1.0,1.4611949708051741,28 31 | 1.5721008691198028,0.8805,1.5832720736789097,0.9996,1.461838215779347,29 32 | 1.5712112394223081,0.8841,1.5807419757174839,0.9998,1.4617464557574813,30 33 | 
1.5702830886027452,0.8811,1.5808509368046073,1.0,1.4613441661664635,31 34 | 1.5696556888409514,0.8816,1.581688315245756,0.9998,1.4614485061852036,32 35 | 1.5688873486224013,0.8822,1.5812520616373438,0.9993,1.4623075966622419,33 36 | 1.5684336843266924,0.8852,1.5774465746181026,0.9994,1.4621651552285357,34 37 | 1.5674864030850213,0.8852,1.5782405980833016,1.0,1.4615574520864305,35 38 | 1.5669018190298507,0.8841,1.5792026739970895,1.0,1.4612707074280757,36 39 | 1.5660516214269056,0.8854,1.5772164293155548,0.9998,1.4616092762370019,37 40 | 1.5655790869869404,0.8853,1.576948979098326,1.0,1.461344426604593,38 41 | 1.5253855495818898,0.9557,1.5095675014386511,1.0,1.461303982765052,39 42 | 1.500955610132929,0.9591,1.506057976157802,1.0,1.4612464153083267,40 43 | 1.4970694015275186,0.9638,1.5025035066969077,0.9994,1.4621104700550152,41 44 | 1.494774531708089,0.9659,1.4991922970790013,0.9998,1.4616062952454683,42 45 | 1.4930639856659782,0.9676,1.4972423041702076,0.9997,1.4620169674514965,43 46 | 1.491635580815232,0.9675,1.4984586109780962,0.9999,1.4614845392810312,44 47 | 1.4903642536480544,0.9665,1.4973292684858772,1.0,1.461397942464063,45 48 | 1.4894682471431904,0.9708,1.4943411365436141,0.9998,1.4615068063614474,46 49 | 1.4882289341517858,0.9703,1.4943936038169132,1.0,1.4612629648986135,47 50 | 1.4875869588302906,0.9723,1.4925040712781772,0.9991,1.4624640106395552,48 51 | 1.4867350694213086,0.9728,1.4912309995882072,0.9998,1.4614987305015514,49 52 | 1.4857911717917112,0.9742,1.490062885982975,0.9994,1.4619036062507873,50 53 | 1.4854789685084622,0.9738,1.490470064673454,0.9997,1.4618450562665417,51 54 | 1.484882541811034,0.9729,1.4908596535397183,0.9992,1.462393003664199,52 55 | 1.4841607913279584,0.9753,1.4884801935997738,0.9995,1.4618865084496273,53 56 | 1.4834698837703224,0.9717,1.4918080902403328,0.9998,1.4617202274358956,54 57 | 1.4830297539229078,0.9749,1.4884188547255888,0.9997,1.4617158425082066,55 58 | 
1.4827180711953625,0.9738,1.4895999074741533,0.9992,1.4621652455846215,56 59 | 1.4818303935817565,0.9768,1.4874723777649508,0.9993,1.4621652023048157,57 60 | 1.4816042121285316,0.9743,1.4887733664482263,0.9998,1.461623744600138,58 61 | 1.4812214735474414,0.9765,1.4880584144288567,0.9997,1.4616395280619336,59 62 | 1.480720959238406,0.9764,1.487031027010292,0.9999,1.461499154188071,60 63 | 1.4802415274353677,0.9763,1.486360020698256,0.9999,1.4614471197128296,61 64 | 1.4799325776252665,0.9789,1.4848522578075434,0.9995,1.4618997589038436,62 65 | 1.4795438526535847,0.9778,1.4855014556532453,0.9996,1.4617462484699906,63 66 | 1.4790637701559168,0.9781,1.485102709691236,0.9999,1.4614483641970688,64 67 | 1.4788900346898322,0.9779,1.4854665577032005,0.9998,1.4619057110160778,65 68 | 1.4786962578291578,0.9748,1.4880464767954151,0.9997,1.46196282593308,66 69 | 1.4783844449626866,0.9778,1.4853552268568877,0.9995,1.462013718428885,67 70 | 1.478048296117071,0.9781,1.4851104939819142,0.9998,1.4616205229121408,68 71 | 1.4775713369536247,0.9773,1.4859552785849115,0.9989,1.4627908733999653,69 72 | 1.4774219374666844,0.9764,1.4861836235993986,0.9994,1.4619002797801024,70 73 | 1.4771280837719882,0.978,1.4851550942013978,0.9997,1.46158108665685,71 74 | 1.476735975188233,0.9747,1.4887360668486092,1.0,1.4612333797345496,72 75 | 1.4766355079374334,0.9788,1.4844132859236117,0.9991,1.4627411919794264,73 76 | 1.4764442037163512,0.9762,1.4876132565698805,0.9998,1.4614805590574909,74 77 | 1.4763951413412846,0.9792,1.4839773732385817,0.9995,1.4619608274690665,75 78 | 1.476014484983009,0.979,1.4847392868843807,0.9985,1.4630383951648784,76 79 | 1.4757258368453492,0.9776,1.4852812806512141,0.9999,1.461368039155462,77 80 | 1.4754560588519456,0.9754,1.4872507198601013,0.9987,1.4631128174484156,78 81 | 1.4753895578608076,0.9803,1.4828015414013225,0.9998,1.4615279587970418,79 82 | 1.475238987123534,0.9801,1.4830766969425664,0.9991,1.462292428229265,80 83 | 
1.4748846188282916,0.9797,1.4830709247832086,0.9999,1.461419509474639,81 84 | 1.4748494813182969,0.9803,1.4827867495785854,0.9991,1.46252649832683,82 85 | 1.4744988870519056,0.9805,1.4823532188014619,0.9994,1.4618953603088476,83 86 | 1.4743730427105544,0.9799,1.4827082453259997,0.9997,1.4615376155087902,84 87 | 1.4740509488689366,0.9774,1.4847932416162672,0.9999,1.4612943435170849,85 88 | 1.473921330498734,0.9804,1.4819583809299834,0.9994,1.4620212764497016,86 89 | 1.473822945470749,0.9805,1.4821982968385052,0.9997,1.461813225108347,87 90 | 1.4735087901036115,0.979,1.4831388938199184,0.9998,1.4616121805397568,88 91 | 1.4734856253748,0.9788,1.482737341504188,0.9993,1.4622114137479454,89 92 | 1.4733046021288647,0.9811,1.4815418940440865,0.9995,1.4618261156568102,90 93 | 1.473032091218017,0.981,1.4820484651881418,0.9992,1.4623764654633347,91 94 | 1.4730475777502,0.9802,1.4819716086053545,0.9996,1.461689324135993,92 95 | 1.4728042179587553,0.9786,1.4841993902898898,0.9998,1.4617124188477826,93 96 | 1.4729606449476946,0.9807,1.4821217773826258,0.9992,1.4627007306761044,94 97 | 1.4724770486990273,0.9813,1.4815491885895942,0.9995,1.4617918874048124,95 98 | 1.4724594799440298,0.9806,1.4816443608824614,0.9992,1.462228447768339,96 99 | 1.472150009578225,0.9813,1.4807620914119064,0.9993,1.4621731908458053,97 100 | 1.4720923580340486,0.9799,1.4828343687543444,0.9988,1.4628820677471768,98 101 | 1.471957273828958,0.9818,1.4804650514748445,0.9995,1.4618518314543802,99 102 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | import re 5 | import time 6 | import datetime 7 | 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from dataset import build_poisoned_training_set, build_testset 13 | from deeplearning import evaluate_badnets, optimizer_picker, 
train_one_epoch 14 | from models import BadNet 15 | 16 | parser = argparse.ArgumentParser(description='Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain".') 17 | parser.add_argument('--dataset', default='MNIST', help='Which dataset to use (MNIST or CIFAR10, default: MNIST)') 18 | parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types') 19 | parser.add_argument('--load_local', action='store_true', help='train model or directly load model (default true, if you add this param, then load trained local model to evaluate the performance)') 20 | parser.add_argument('--loss', default='mse', help='Which loss function to use (mse or cross, default: mse)') 21 | parser.add_argument('--optimizer', default='sgd', help='Which optimizer to use (sgd or adam, default: sgd)') 22 | parser.add_argument('--epochs', default=100, help='Number of epochs to train backdoor model, default: 100') 23 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size to split dataset, default: 64') 24 | parser.add_argument('--num_workers', type=int, default=0, help='Batch size to split dataset, default: 64') 25 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate of the model, default: 0.001') 26 | parser.add_argument('--download', action='store_true', help='Do you want to download data ( default false, if you add this param, then download)') 27 | parser.add_argument('--data_path', default='./data/', help='Place to load dataset (default: ./dataset/)') 28 | parser.add_argument('--device', default='cpu', help='device to use for training / testing (cpu, or cuda:1, default: cpu)') 29 | # poison settings 30 | parser.add_argument('--poisoning_rate', type=float, default=0.1, help='poisoning portion (float, range from 0 to 1, default: 0.1)') 31 | parser.add_argument('--trigger_label', type=int, default=1, help='The NO. 
of trigger label (int, range from 0 to 10, default: 0)') 32 | parser.add_argument('--trigger_path', default="./triggers/trigger_white.png", help='Trigger Path (default: ./triggers/trigger_white.png)') 33 | parser.add_argument('--trigger_size', type=int, default=5, help='Trigger Size (int, default: 5)') 34 | 35 | args = parser.parse_args() 36 | 37 | def main(): 38 | print("{}".format(args).replace(', ', ',\n')) 39 | 40 | if re.match('cuda:\d', args.device): 41 | cuda_num = args.device.split(':')[1] 42 | os.environ['CUDA_VISIBLE_DEVICES'] = cuda_num 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if you're using MBP M1, you can also use "mps" 44 | 45 | # create related path 46 | pathlib.Path("./checkpoints/").mkdir(parents=True, exist_ok=True) 47 | pathlib.Path("./logs/").mkdir(parents=True, exist_ok=True) 48 | 49 | print("\n# load dataset: %s " % args.dataset) 50 | dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args) 51 | dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args) 52 | 53 | data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 54 | data_loader_val_clean = DataLoader(dataset_val_clean, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 55 | data_loader_val_poisoned = DataLoader(dataset_val_poisoned, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # shuffle 随机化 56 | 57 | model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device) 58 | criterion = torch.nn.CrossEntropyLoss() 59 | optimizer = optimizer_picker(args.optimizer, model.parameters(), lr=args.lr) 60 | 61 | basic_model_path = "./checkpoints/badnet-%s.pth" % args.dataset 62 | start_time = time.time() 63 | if args.load_local: 64 | print("## Load model from : %s" % basic_model_path) 65 | model.load_state_dict(torch.load(basic_model_path), strict=True) 66 | test_stats = 
evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device) 67 | print(f"Test Clean Accuracy(TCA): {test_stats['clean_acc']:.4f}") 68 | print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}") 69 | else: 70 | print(f"Start training for {args.epochs} epochs") 71 | stats = [] 72 | for epoch in range(args.epochs): 73 | train_stats = train_one_epoch(data_loader_train, model, criterion, optimizer, args.loss, device) 74 | test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device) 75 | print(f"# EPOCH {epoch} loss: {train_stats['loss']:.4f} Test Acc: {test_stats['clean_acc']:.4f}, ASR: {test_stats['asr']:.4f}\n") 76 | 77 | # save model 78 | torch.save(model.state_dict(), basic_model_path) 79 | 80 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 81 | **{f'test_{k}': v for k, v in test_stats.items()}, 82 | 'epoch': epoch, 83 | } 84 | 85 | # save training stats 86 | stats.append(log_stats) 87 | df = pd.DataFrame(stats) 88 | df.to_csv("./logs/%s_trigger%d.csv" % (args.dataset, args.trigger_label), index=False, encoding='utf-8') 89 | 90 | total_time = time.time() - start_time 91 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 92 | print('Training time {}'.format(total_time_str)) 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/models/.DS_Store -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .badnet import BadNet -------------------------------------------------------------------------------- /models/badnet.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BadNet(nn.Module): # Small CNN: two conv(5x5) + ReLU + avg-pool stages, then two fully connected layers; forward() returns the softmax output of fc2 4 | 5 | def __init__(self, input_channels, output_num): # input_channels: channels of the input images (fc1 sizing below distinguishes 3 from anything else); output_num: number of output classes (main.py passes args.nb_classes) 6 | super().__init__() 7 | self.conv1 = nn.Sequential( 8 | nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1), 9 | nn.ReLU(), 10 | nn.AvgPool2d(kernel_size=2, stride=2) 11 | ) 12 | 13 | self.conv2 = nn.Sequential( 14 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1), 15 | nn.ReLU(), 16 | nn.AvgPool2d(kernel_size=2, stride=2) 17 | ) 18 | fc1_input_features = 800 if input_channels == 3 else 512 # flattened conv2 output: 32*5*5 = 800 for a 32x32 input, 32*4*4 = 512 for a 28x28 input; NOTE(review): assumes 3-channel inputs are 32x32 and all others 28x28, confirm against the dataset code 19 | self.fc1 = nn.Sequential( 20 | nn.Linear(in_features=fc1_input_features, out_features=512), 21 | nn.ReLU() 22 | ) 23 | self.fc2 = nn.Sequential( 24 | nn.Linear(in_features=512, out_features=output_num), 25 | nn.Softmax(dim=-1) # NOTE(review): main.py builds torch.nn.CrossEntropyLoss for this model, and that loss applies log-softmax itself; if train_one_epoch feeds these outputs to it, softmax is applied twice. Verify before removing, since the mse loss path may rely on probabilities 26 | ) 27 | self.dropout = nn.Dropout(p=.5) # NOTE(review): defined but never applied in forward() 28 | 29 | def forward(self, x): # x presumably (batch, input_channels, H, W), TODO confirm; returns (batch, output_num) softmax scores 30 | x = self.conv1(x) 31 | x = self.conv2(x) 32 | 33 | x = x.view(x.size(0), -1) # flatten everything but the batch dimension for the fully connected layers 34 | x = self.fc1(x) 35 | x = self.fc2(x) 36 | return x 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.4.0 2 | pandas==0.25.2 3 | numpy==1.17.2 4 | torch==1.2.0 5 | matplotlib==3.1.1 6 | tqdm==4.41.1 7 | Pillow==7.2.0 8 | scikit_learn==0.23.1 9 | -------------------------------------------------------------------------------- /triggers/trigger_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/triggers/trigger_10.png -------------------------------------------------------------------------------- /triggers/trigger_white.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/verazuo/badnets-pytorch/64e310c925a03d2931ea1c071a10ed6d5031a1a5/triggers/trigger_white.png --------------------------------------------------------------------------------