├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── craft.yaml
│   └── poison_train.yaml
├── craft_adv_dataset.py
├── data
│   ├── backdoor.py
│   ├── cifar.py
│   ├── dataset.py
│   └── trigger
│       └── cifar_1.png
├── model
│   ├── __init__.py
│   ├── adv_models
│   │   └── cifar_resnet_e8_a2_s10.pth
│   └── network
│       └── resnet.py
├── poison_train.py
├── torchattacks
│   ├── README.md
│   ├── __init__.py
│   ├── attack.py
│   └── attacks
│       ├── __init__.py
│       └── pgd.py
├── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode/
__pycache__
data/adv_dataset/cifar_resnet_e8_a1.5_s100.npz

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Kunzhe Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Label-Consistent Backdoor Attacks

This repository contains a minimal PyTorch implementation of [Label-Consistent Backdoor Attacks](https://arxiv.org/abs/1912.02771) on the CIFAR-10 dataset.
The official TensorFlow implementation is [here](https://github.com/MadryLab/label-consistent-backdoor-code).

## Requirements
- python 3.7
- pytorch 1.6.0
- numpy
- tabulate
- pyyaml
- tqdm

## Usage

### Step 1: Train an Adversarially Robust Model
For fast adversarial training, please refer to [fast_adversarial](https://github.com/locuslab/fast_adversarial). I also provide a PGD adversarially pretrained ResNet-18 model [here](model/adv_models/cifar_resnet_e8_a2_s10.pth), which I obtained by running [train_pgd.py](https://github.com/locuslab/fast_adversarial/blob/master/CIFAR10/train_pgd.py). The PGD adversarial training parameters are the same as in the paper, where they were in turn adapted from [cifar10_challenge](https://github.com/MadryLab/cifar10_challenge).

### Step 2: Generate an Adversarially Perturbed Dataset
```
python craft_adv_dataset.py --config config/craft.yaml --gpu 0
```
This runs a PGD attack and saves the adversarially perturbed dataset to `data/adv_dataset/craft.npz`. The PGD attack parameters are the same as in the paper.

### Step 3: Train a Backdoored Model
```
python poison_train.py --config config/poison_train.yaml --gpu 0
```
__Note:__ For simplicity, I use a [randomly generated trigger](data/trigger/cifar_1.png) instead of the less visible, four-corner trigger described in Section 4.4 of the paper (Improving backdoor trigger design).
The poison-training parameters in `config/poison_train.yaml` are adapted from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) and may differ from those in the paper; please refer to the experimental setup in Appendix A of the paper.

--------------------------------------------------------------------------------
/config/craft.yaml:
--------------------------------------------------------------------------------
---
dataset_dir: "~/dataset/cifar-10/cifar-10-batches-py"
num_classes: 10
adv_dataset_dir: "./data/adv_dataset"
adv_model_path: "./model/adv_models/cifar_resnet_e8_a2_s10.pth"
size: [32, 32, 3] # [height, width, channel]
normalization_layer: null
loader:
  batch_size: 512
  num_workers: 4
  pin_memory: True
pgd:
  eps: 8
  alpha: 1.5
  steps: 100
  max_pixel: 255

--------------------------------------------------------------------------------
/config/poison_train.yaml:
--------------------------------------------------------------------------------
---
dataset_dir: ~/dataset/cifar-10/cifar-10-batches-py
adv_dataset_path: ./data/adv_dataset/cifar_resnet_e8_a1.5_s100.npz
backdoor:
  poison_ratio: 0.5
  target_label: 3
  clbd:
    trigger_path: ./data/trigger/cifar_1.png
loader:
  batch_size: 128
  num_workers: 4
  pin_memory: True
optimizer:
  SGD:
    weight_decay: 2.e-4
    momentum: 0.9
    lr: 0.1
lr_scheduler:
  multi_step:
    milestones: [100, 150]
    gamma: 0.1
num_epochs: 200

--------------------------------------------------------------------------------
/craft_adv_dataset.py:
--------------------------------------------------------------------------------
import argparse
import os

import numpy as np
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

from data.cifar import CIFAR10
from model.network.resnet import resnet18

from utils import NormalizeByChannelMeanStd, load_config
from torchattacks import PGD

torch.backends.cudnn.benchmark = True


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./config/craft.yaml")
    parser.add_argument("--gpu", default="0", type=str)

    args = parser.parse_args()
    config, _, config_name = load_config(args.config)

    train_transform = transforms.Compose([transforms.ToTensor()])
    train_data = CIFAR10(config["dataset_dir"], transform=train_transform, train=True)
    train_loader = DataLoader(train_data, **config["loader"])

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    gpu = torch.cuda.current_device()
    print("Set GPU to: {}".format(args.gpu))
    model = resnet18()
    model = model.cuda(gpu)
    adv_ckpt = torch.load(config["adv_model_path"], map_location="cuda:{}".format(gpu))
    model.load_state_dict(adv_ckpt)
    print(
        "Load training state from the checkpoint {}:".format(config["adv_model_path"])
    )
    if config["normalization_layer"] is not None:
        normalization_layer = NormalizeByChannelMeanStd(**config["normalization_layer"])
        normalization_layer = normalization_layer.cuda(gpu)
        print("Add a normalization layer: {} before model".format(normalization_layer))
        model = nn.Sequential(normalization_layer, model)

    pgd_config = config["pgd"]
    print("Set PGD attacker: {}.".format(pgd_config))
    # `eps` and `alpha` are given in pixel units (0-255); rescale them to [0, 1].
    max_pixel = pgd_config.pop("max_pixel")
    for k, v in pgd_config.items():
        if k == "eps" or k == "alpha":
            pgd_config[k] = v / max_pixel
    attacker = PGD(model, **pgd_config)
    attacker.set_return_type("int")

    perturbed_img = torch.zeros((len(train_data), *config["size"]), dtype=torch.uint8)
    target = torch.zeros(len(train_data))
    i = 0
    for item in tqdm(train_loader):
        # Adversarially perturb image. Note that torchattacks will automatically
        # move `img` and `target` to the GPU where `attacker.model` is located.
        img = attacker(item["img"], item["target"])
        perturbed_img[i : i + len(img), :, :, :] = img.permute(0, 2, 3, 1).detach()
        target[i : i + len(item["target"])] = item["target"]
        i += img.shape[0]

    if not os.path.exists(config["adv_dataset_dir"]):
        os.makedirs(config["adv_dataset_dir"])
    adv_data_path = os.path.join(
        config["adv_dataset_dir"], "{}.npz".format(config_name)
    )
    np.savez(adv_data_path, data=perturbed_img.numpy(), targets=target.numpy())
    print("Save the adversarially perturbed dataset to {}".format(adv_data_path))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/data/backdoor.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image


class CLBD(object):
    """Label-Consistent Backdoor Attacks.

    Reference:
        [1] "Label-consistent backdoor attacks."
        Turner, Alexander, et al. arXiv 2019.

    Args:
        trigger_path (str): Path to the trigger image.
    """

    def __init__(self, trigger_path):
        with open(trigger_path, "rb") as f:
            trigger_ptn = Image.open(f).convert("RGB")
        self.trigger_ptn = np.array(trigger_ptn)
        self.trigger_loc = np.nonzero(self.trigger_ptn)

    def __call__(self, img):
        return self.add_trigger(img)

    def add_trigger(self, img):
        """Add `trigger_ptn` to `img`.

        Args:
            img (numpy.ndarray): Input image (HWC).

        Returns:
            poison_img (np.ndarray): Poisoned image (HWC).
        """
        # Work on a copy so that the caller's array (e.g. a view into the
        # dataset's stored images) is not modified in place.
        img = img.copy()
        # Zero the pixels covered by the trigger, then paste the trigger in;
        # zeroing first means the uint8 addition below cannot overflow.
        img[self.trigger_loc] = 0
        poison_img = img + self.trigger_ptn

        return poison_img

--------------------------------------------------------------------------------
/data/cifar.py:
--------------------------------------------------------------------------------
import os
import pickle

import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset


class CIFAR10(Dataset):
    def __init__(self, root, transform=None, train=True):
        self.train = train
        self.transform = transform
        if train:
            data_list = [
                "data_batch_1",
                "data_batch_2",
                "data_batch_3",
                "data_batch_4",
                "data_batch_5",
            ]
        else:
            data_list = ["test_batch"]
        data = []
        targets = []
        if root[0] == "~":
            # Interpret `~` as the home directory.
            root = os.path.expanduser(root)
        for file_name in data_list:
            file_path = os.path.join(root, file_name)
            with open(file_path, "rb") as f:
                entry = pickle.load(f, encoding="latin1")
                data.append(entry["data"])
                targets.extend(entry["labels"])
        # Convert data (a list of arrays) to an NHWC np.ndarray so that it works with PIL Image.
        data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.data = data
        self.targets = np.asarray(targets)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        item = {"img": img, "target": target}

        return item

    def __len__(self):
        return len(self.data)

--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
import copy

import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset


class CleanLabelDataset(Dataset):
    """Clean-label dataset.

    Args:
        dataset (Dataset): The dataset to be wrapped.
        adv_dataset_path (str): The adversarially perturbed dataset path.
        transform (callable): The backdoor transformation.
        poison_idx (np.array): A 0/1 (clean/poisoned) array with
            shape `(len(dataset), )`.
        target_label (int): The target label.
    """

    def __init__(self, dataset, adv_dataset_path, transform, poison_idx, target_label):
        super(CleanLabelDataset, self).__init__()
        self.clean_dataset = copy.deepcopy(dataset)
        self.adv_data = np.load(adv_dataset_path)["data"]
        self.clean_data = self.clean_dataset.data
        self.train = self.clean_dataset.train
        if self.train:
            # Use the adversarially perturbed image for poisoned samples and the
            # clean image otherwise.
            self.data = np.where(
                (poison_idx == 1)[..., None, None, None],
                self.adv_data,
                self.clean_data,
            )
            self.targets = self.clean_dataset.targets
            self.poison_idx = poison_idx
        else:
            # Only fetch poison data when testing.
            self.data = self.clean_data[np.nonzero(poison_idx)[0]]
            self.targets = self.clean_dataset.targets[np.nonzero(poison_idx)[0]]
            self.poison_idx = poison_idx[poison_idx == 1]
        self.transform = self.clean_dataset.transform
        self.bd_transform = transform
        self.target_label = target_label

    def __getitem__(self, index):
        img = self.data[index]
        target = self.targets[index]

        if self.poison_idx[index] == 1:
            img = self.augment(img, bd_transform=self.bd_transform)
            # If `self.train` is `True`, poison data come only from the target
            # class, so this leaves `target` unchanged; if `self.train` is `False`,
            # it flips `target` to `self.target_label` for testing purposes.
            target = self.target_label
        else:
            img = self.augment(img, bd_transform=None)
        item = {"img": img, "target": target}

        return item

    def __len__(self):
        return len(self.data)

    def augment(self, img, bd_transform=None):
        if bd_transform is not None:
            img = bd_transform(img)
        img = Image.fromarray(img)
        img = self.transform(img)

        return img

--------------------------------------------------------------------------------
/data/trigger/cifar_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkunzhe/label_consistent_attacks_pytorch/224f5879dca579a9dda3b62dfc34652358e0f1ae/data/trigger/cifar_1.png

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkunzhe/label_consistent_attacks_pytorch/224f5879dca579a9dda3b62dfc34652358e0f1ae/model/__init__.py

--------------------------------------------------------------------------------
/model/adv_models/cifar_resnet_e8_a2_s10.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkunzhe/label_consistent_attacks_pytorch/224f5879dca579a9dda3b62dfc34652358e0f1ae/model/adv_models/cifar_resnet_e8_a2_s10.pth

--------------------------------------------------------------------------------
/model/network/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self, block, num_blocks, num_classes=10, in_channel=3, zero_init_residual=False
    ):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

--------------------------------------------------------------------------------
/poison_train.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from data.backdoor import CLBD
from data.cifar import CIFAR10
from data.dataset import CleanLabelDataset
from model.network.resnet import resnet18
from utils import load_config, gen_poison_idx
from trainer import poison_train, test

torch.backends.cudnn.benchmark = True


def main():
    print("===Setup running===")
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./config/poison_train.yaml")
    parser.add_argument("--gpu", default="0", type=str)
    args = parser.parse_args()
    config, _, _ = load_config(args.config)

    print("===Prepare data===")
    bd_config = config["backdoor"]
    print("Load backdoor config:\n{}".format(bd_config))
    bd_transform = CLBD(bd_config["clbd"]["trigger_path"])
    target_label = bd_config["target_label"]
    poison_ratio = bd_config["poison_ratio"]

    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        ]
    )
    print("Load dataset from: {}".format(config["dataset_dir"]))
    clean_train_data = CIFAR10(config["dataset_dir"], train_transform, train=True)
    poison_train_idx = gen_poison_idx(
        clean_train_data, target_label, poison_ratio=poison_ratio
    )
    print(
        "Load the adversarially perturbed dataset from: {}".format(
            config["adv_dataset_path"]
        )
    )
    poison_train_data = CleanLabelDataset(
        clean_train_data,
        config["adv_dataset_path"],
        bd_transform,
        poison_train_idx,
        target_label,
    )
    poison_train_loader = DataLoader(
        poison_train_data, **config["loader"], shuffle=True
    )
    clean_test_data = CIFAR10(config["dataset_dir"], test_transform, train=False)
    poison_test_idx = gen_poison_idx(clean_test_data, target_label)
    poison_test_data = CleanLabelDataset(
        clean_test_data,
        config["adv_dataset_path"],
        bd_transform,
        poison_test_idx,
        target_label,
    )
    clean_test_loader = DataLoader(clean_test_data, **config["loader"])
    poison_test_loader = DataLoader(poison_test_data, **config["loader"])

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    gpu = torch.cuda.current_device()
    print("Set gpu to: {}".format(args.gpu))

    model = resnet18()
    model = model.cuda(gpu)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), **config["optimizer"]["SGD"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, **config["lr_scheduler"]["multi_step"]
    )

    for epoch in range(config["num_epochs"]):
        print("===Epoch: {}/{}===".format(epoch + 1, config["num_epochs"]))
        print("Poison training...")
        poison_train(model, poison_train_loader, criterion, optimizer)
        print("Test model on clean data...")
        test(model, clean_test_loader, criterion)
        print("Test model on poison data...")
        test(model, poison_test_loader, criterion)

        scheduler.step()
        print("Adjust learning rate to {}".format(optimizer.param_groups[0]["lr"]))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/torchattacks/README.md:
--------------------------------------------------------------------------------
This directory contains PGD attack code borrowed from [torchattacks](https://github.com/Harry24k/adversarial-attacks-pytorch).

--------------------------------------------------------------------------------
/torchattacks/__init__.py:
--------------------------------------------------------------------------------
from .attacks.pgd import PGD

__version__ = 2.6

--------------------------------------------------------------------------------
/torchattacks/attack.py:
--------------------------------------------------------------------------------
import torch


class Attack(object):
    r"""
    Base class for all attacks.

    .. note::
        It automatically sets the device to the device where the given model is.
        It temporarily changes the original model's training mode to `test`
        by `.eval()` only during an attack process.
    """

    def __init__(self, name, model):
        r"""
        Initializes internal attack state.

        Arguments:
            name (str) : name of an attack.
            model (torch.nn.Module): model to attack.
        """

        self.attack = name
        self.model = model
        self.model_name = str(model).split("(")[0]

        self.training = model.training
        self.device = next(model.parameters()).device

        self._targeted = 1
        self._attack_mode = "original"
        self._return_type = "float"

    def forward(self, *input):
        r"""
        It defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def set_attack_mode(self, mode):
        r"""
        Set the attack mode.

        Arguments:
            mode (str) : 'original' (DEFAULT)
                         'targeted' - Use input labels as targeted labels.
                         'least_likely' - Use least likely labels as targeted labels.
        """
        # Compare strings by value; `is` checks object identity and is unreliable here.
        if self._attack_mode == "only_original":
            raise ValueError(
                "Changing attack mode is not supported in this attack method."
            )

        if mode == "original":
            self._attack_mode = "original"
            self._targeted = 1
            self._transform_label = self._get_label
        elif mode == "targeted":
            self._attack_mode = "targeted"
            self._targeted = -1
            self._transform_label = self._get_label
        elif mode == "least_likely":
            self._attack_mode = "least_likely"
            self._targeted = -1
            self._transform_label = self._get_least_likely_label
        else:
            raise ValueError(
                mode
                + " is not a valid mode. [Options : original, targeted, least_likely]"
            )

    def set_return_type(self, type):
        r"""
        Set the return type of adversarial images: `int` or `float`.

        Arguments:
            type (str) : 'float' or 'int'. (DEFAULT : 'float')
        """
        if type == "float":
            self._return_type = "float"
        elif type == "int":
            self._return_type = "int"
        else:
            raise ValueError(type + " is not a valid type. [Options : float, int]")

    def save(self, save_path, data_loader, verbose=True):
        r"""
        Save adversarial images as torch.tensor from given torch.utils.data.DataLoader.

        Arguments:
            save_path (str) : save_path.
            data_loader (torch.utils.data.DataLoader) : data loader.
            verbose (bool) : True for displaying detailed information. (DEFAULT : True)
        """
        self.model.eval()

        image_list = []
        label_list = []

        correct = 0
        total = 0

        total_batch = len(data_loader)

        for step, (images, labels) in enumerate(data_loader):
            adv_images = self.__call__(images, labels)

            image_list.append(adv_images.cpu())
            label_list.append(labels.cpu())

            if self._return_type == "int":
                adv_images = adv_images.float() / 255

            if verbose:
                outputs = self.model(adv_images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.to(self.device)).sum()

                acc = 100 * float(correct) / total
                print(
                    "- Save Progress : %2.2f %% / Accuracy : %2.2f %%"
                    % ((step + 1) / total_batch * 100, acc),
                    end="\r",
                )

        x = torch.cat(image_list, 0)
        y = torch.cat(label_list, 0)
        torch.save((x, y), save_path)
        print("\n- Save Complete!")

        self._switch_model()

    def _transform_label(self, images, labels):
        r"""
        Function for changing the attack mode.
        """
        return labels

    def _get_label(self, images, labels):
        r"""
        Function for changing the attack mode.
        Return input labels.
        """
        return labels

    def _get_least_likely_label(self, images, labels):
        r"""
        Function for changing the attack mode.
        Return least likely labels.
        """
        outputs = self.model(images)
        _, labels = torch.min(outputs.data, 1)
        labels = labels.detach_()
        return labels

    def _to_uint(self, images):
        r"""
        Function for changing the return type.
        Return images as int.
        """
        # Round before casting: a plain cast truncates, so a value such as
        # 2.9999998 (float error in k / 255 * 255) would become 2 instead of 3.
        return (images * 255).round().type(torch.uint8)

    def _switch_model(self):
        r"""
        Function for changing the training mode of the model.
        """
        if self.training:
            self.model.train()
        else:
            self.model.eval()

    def __str__(self):
        info = self.__dict__.copy()

        del_keys = ["model", "attack"]

        for key in info.keys():
            if key[0] == "_":
                del_keys.append(key)

        for key in del_keys:
            del info[key]

        info["attack_mode"] = self._attack_mode
        if info["attack_mode"] == "only_original":
            info["attack_mode"] = "original"

        info["return_type"] = self._return_type

        return (
            self.attack
            + "("
            + ", ".join("{}={}".format(key, val) for key, val in info.items())
            + ")"
        )

    def __call__(self, *input, **kwargs):
        self.model.eval()
        images = self.forward(*input, **kwargs)
        self._switch_model()

        if self._return_type == "int":
            images = self._to_uint(images)

        return images

--------------------------------------------------------------------------------
/torchattacks/attacks/__init__.py:
--------------------------------------------------------------------------------
from .pgd import PGD

--------------------------------------------------------------------------------
/torchattacks/attacks/pgd.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from ..attack import Attack


class PGD(Attack):
    r"""
    PGD in the paper 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (DEFAULT : 0.3)
        alpha (float): step size. (DEFAULT : 2/255)
        steps (int): number of steps. (DEFAULT : 40)
        random_start (bool): using random initialization of delta. (DEFAULT : False)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = batch size`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i <` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = torchattacks.PGD(model, eps=8/255, alpha=1/255, steps=40, random_start=False)
        >>> adv_images = attack(images, labels)

    """

    def __init__(self, model, eps=0.3, alpha=2 / 255, steps=40, random_start=False):
        super(PGD, self).__init__("PGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start

    def forward(self, images, labels):
        r"""
        Overridden.
        """
        images = images.to(self.device)
        labels = labels.to(self.device)
        labels = self._transform_label(images, labels)
        loss = nn.CrossEntropyLoss()

        adv_images = images.clone().detach()

        if self.random_start:
            # Starting at a uniformly random point
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(
                -self.eps, self.eps
            )
            adv_images = torch.clamp(adv_images, min=0, max=1)

        for i in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.model(adv_images)

            cost = self._targeted * loss(outputs, labels).to(self.device)

            grad = torch.autograd.grad(
                cost, adv_images, retain_graph=False, create_graph=False
            )[0]

            # Take a signed-gradient step, then project back onto the L-inf ball
            # of radius `eps` around `images` and onto the valid range [0, 1].
            adv_images = adv_images.detach() + self.alpha * grad.sign()
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

        return adv_images

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import time

import torch
from tabulate import tabulate


def poison_train(model, loader, criterion, optimizer):
    loss_meter = AverageMeter("loss")
    acc_meter = AverageMeter("acc")
    meter_list = [
        loss_meter,
        acc_meter,
    ]

    model.train()
    gpu = next(model.parameters()).device
    start_time = time.time()
    for batch_idx, batch in enumerate(loader):
        data = batch["img"].cuda(gpu, non_blocking=True)
        target = batch["target"].cuda(gpu, non_blocking=True)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item())
        pred = output.argmax(dim=1, keepdim=True)
        truth = pred.view_as(target).eq(target)
        acc_meter.update((1.0 * torch.sum(truth) / len(truth)).item())

        tabulate_step_meter(batch_idx, len(loader), 3, meter_list)

    print("Training summary:")
    tabulate_epoch_meter(time.time() - start_time, meter_list)


def test(model, loader, criterion):
    loss_meter = AverageMeter("loss")
    acc_meter = AverageMeter("acc")
    meter_list = [loss_meter, acc_meter]

    model.eval()
    gpu = next(model.parameters()).device
    start_time = time.time()
    for batch_idx, batch in enumerate(loader):
        data = batch["img"].cuda(gpu, non_blocking=True)
        target = batch["target"].cuda(gpu, non_blocking=True)
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)

        loss_meter.update(loss.item())
        pred = output.argmax(dim=1, keepdim=True)
        truth = pred.view_as(target).eq(target)
        acc_meter.update((torch.sum(truth).float() / len(truth)).item())

        tabulate_step_meter(batch_idx, len(loader), 3, meter_list)
    tabulate_epoch_meter(time.time() - start_time, meter_list)


def tabulate_step_meter(batch_idx, num_batches, num_intervals, meter_list):
    """Tabulate the current average value of meters every `step_interval`.

    Args:
        batch_idx (int): The batch index in an epoch.
        num_batches (int): The number of batches in an epoch.
        num_intervals (int): The number of intervals to tabulate.
        meter_list (list or tuple of AverageMeter): A list of meters.
    """
    # Guard against step_interval == 0 when there are fewer batches than intervals.
    step_interval = max(1, num_batches // num_intervals)
    if batch_idx % step_interval == 0:
        step_meter = {"Iteration": ["{}/{}".format(batch_idx, num_batches)]}
        for m in meter_list:
            step_meter[m.name] = [m.batch_avg]
        table = tabulate(step_meter, headers="keys", tablefmt="github", floatfmt=".5f")
        if batch_idx == 0:
            # First row of the epoch: print the header with an extra top rule.
            table = table.split("\n")
            table = "\n".join([table[1]] + table)
        else:
            # Later rows: print only the data line of the table.
            table = table.split("\n")[2]
        print(table)


def tabulate_epoch_meter(elapsed_time, meter_list):
    """Print the epoch average of each meter together with the elapsed time.

    Args:
        elapsed_time (float): The elapsed time of an epoch, in seconds.
        meter_list (list or tuple of AverageMeter): A list of meters.
    """
    epoch_meter = {m.name: [m.total_avg] for m in meter_list}
    epoch_meter["time"] = [elapsed_time]
    table = tabulate(epoch_meter, headers="keys", tablefmt="github", floatfmt=".5f")
    table = table.split("\n")
    table = "\n".join([table[1]] + table)
    print(table)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Modified from https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self, name, fmt=None):
        # `fmt` is accepted for compatibility with the PyTorch example but unused.
        self.name = name
        self.reset()

    def reset(self):
        self.batch_avg = 0
        self.total_avg = 0
        self.sum = 0
        self.count = 0

    def update(self, avg, n=1):
        # `avg` is a batch mean weighted by `n`. With the default n=1 (as used in
        # trainer.py), `total_avg` is the unweighted mean of the per-batch values.
        self.batch_avg = avg
        self.sum += avg * n
        self.count += n
        self.total_avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import random
import os

import numpy as np
import torch
import torch.nn as nn
import yaml


def load_config(config_path):
    """Load config file from `config_path`.

    Args:
        config_path (str): Configuration file path, which must be in `config` dir, e.g.,
            `./config/inner_dir/example.yaml` and `config/inner_dir/example`.

    Returns:
        config (dict): Configuration dict.
        inner_dir (str): Directory between `config/` and configuration file. If `config_path`
            doesn't contain `inner_dir`, return empty string.
        config_name (str): Configuration filename.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError("Configuration file {} not found".format(config_path))
    config_hierarchy = config_path.split("/")
    if config_hierarchy[0] != ".":
        if config_hierarchy[0] != "config":
            raise RuntimeError(
                "Configuration file {} must be in config dir".format(config_path)
            )
        if len(config_hierarchy) > 2:
            inner_dir = os.path.join(*config_hierarchy[1:-1])
        else:
            inner_dir = ""
    else:
        if config_hierarchy[1] != "config":
            raise RuntimeError(
                "Configuration file {} must be in config dir".format(config_path)
            )
        if len(config_hierarchy) > 3:
            inner_dir = os.path.join(*config_hierarchy[2:-1])
        else:
            inner_dir = ""
    print("Loading configuration file from {}:".format(config_path))
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config_name = config_hierarchy[-1].split(".yaml")[0]

    return config, inner_dir, config_name


def gen_poison_idx(dataset, target_label, poison_ratio=None):
    """Return a 0/1 array marking which samples of `dataset` receive the trigger.

    On the training set (with `poison_ratio` given), each sample of class
    `target_label` is marked independently with probability `poison_ratio`,
    so only images that already carry the target label are poisoned. In all
    other cases (e.g. the test set), every sample whose label is not
    `target_label` is marked.
    """
    poison_idx = np.zeros(len(dataset))
    train = dataset.train
    for (i, t) in enumerate(dataset.targets):
        if train and poison_ratio is not None:
            if random.random() < poison_ratio and t == target_label:
                poison_idx[i] = 1
        else:
            if t != target_label:
                poison_idx[i] = 1

    return poison_idx


class NormalizeByChannelMeanStd(nn.Module):
    """Normalize each channel of an (N, C, H, W) input as (x - mean) / std.
    """

    def __init__(self, mean, std):
        super(NormalizeByChannelMeanStd, self).__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        # Buffers follow `.to(device)` and are saved in the state_dict, but are not trained.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        # Reshape per-channel statistics to (1, C, 1, 1) so they broadcast over N, H, W.
        mean = self.mean[None, :, None, None]
        std = self.std[None, :, None, None]
        return tensor.sub(mean).div(std)

    def extra_repr(self):
        return "mean={}, std={}".format(self.mean, self.std)
--------------------------------------------------------------------------------
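As an illustration of what `PGD.forward` computes, the sketch below replays the same L-inf update in NumPy: a step of size `alpha` along the sign of the input gradient, a clip of the accumulated perturbation to `[-eps, eps]`, and a clip to the pixel range [0, 1]. To stay self-contained it assumes a toy linear softmax classifier whose input gradient is written by hand instead of being obtained with `torch.autograd.grad`. The names `pgd_linf`, `input_grad`, `cross_entropy` and `softmax` are hypothetical and belong to neither torchattacks nor this repository.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the class axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def cross_entropy(W, b, x, y):
    # Mean cross-entropy of the hypothetical linear model logits = x @ W.T + b,
    # mirroring the default "mean" reduction of nn.CrossEntropyLoss.
    p = softmax(x @ W.T + b)
    return -np.mean(np.log(p[np.arange(len(y)), y]))


def input_grad(W, b, x, y):
    # Gradient of the mean cross-entropy w.r.t. the input x:
    # (softmax(logits) - one_hot(y)) @ W / N.
    p = softmax(x @ W.T + b)
    p[np.arange(len(y)), y] -= 1.0
    return p @ W / len(y)


def pgd_linf(W, b, x, y, eps=8 / 255, alpha=2 / 255, steps=10,
             random_start=False, rng=None):
    # Untargeted L-inf PGD following the structure of PGD.forward.
    adv = x.copy()
    if random_start:
        if rng is None:
            rng = np.random.default_rng(0)
        # Start at a uniformly random point inside the eps-ball, then clip.
        adv = np.clip(adv + rng.uniform(-eps, eps, size=adv.shape), 0.0, 1.0)
    for _ in range(steps):
        grad = input_grad(W, b, adv, y)
        adv = adv + alpha * np.sign(grad)      # signed-gradient ascent step
        delta = np.clip(adv - x, -eps, eps)    # project onto the eps-ball
        adv = np.clip(x + delta, 0.0, 1.0)     # keep pixels in [0, 1]
    return adv


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.normal(size=(3, 12))              # 3 classes, 12 "pixels"
    b = np.zeros(3)
    x = rng.uniform(0.0, 1.0, size=(5, 12))   # 5 clean inputs in [0, 1]
    y = rng.integers(0, 3, size=5)
    x_adv = pgd_linf(W, b, x, y)
    print("clean loss:", cross_entropy(W, b, x, y))
    print("adv loss:  ", cross_entropy(W, b, x_adv, y))
    print("max |x_adv - x|:", np.abs(x_adv - x).max())
```

Clipping `delta` to `[-eps, eps]` and then clipping to [0, 1] is equivalent to clipping each pixel to `[max(0, x - eps), min(1, x + eps)]`, i.e. the Euclidean projection onto the intersection of the two boxes, because both constraints are per-coordinate intervals that contain the clean pixel.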