├── README.md
├── attack_utils.py
├── carlini.py
├── fgs.py
├── mnist.py
├── models
│   ├── modelA.pkl
│   ├── modelA_adv.pkl
│   ├── modelA_ens.pkl
│   ├── modelB.pkl
│   ├── modelC.pkl
│   └── modelD.pkl
├── simple_eval.py
├── train.py
├── train_adv.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Ensemble Adversarial Training With PyTorch

This repository contains PyTorch code to reproduce results from the paper:

**Ensemble Adversarial Training: Attacks and Defenses**
*Florian Tramèr, Alexey Kurakin, Nicolas Papernot, Dan Boneh and Patrick McDaniel*
arXiv report: https://arxiv.org/abs/1705.07204
###### REQUIREMENTS

The code was tested with Python 3.6.7 and PyTorch 1.0.1.

###### EXPERIMENTS

Train a few simple MNIST models; the architectures are defined in _mnist.py_:

```
python -m train models/modelA --type=0
python -m train models/modelB --type=1
python -m train models/modelC --type=2
python -m train models/modelD --type=3
```

(Standard) adversarial training:

```
python -m train_adv models/modelA_adv --type=0 --epochs=12
```

Ensemble adversarial training:

```
python -m train_adv models/modelA_ens models/modelA models/modelC models/modelD --type=0 --epochs=12
```

The accuracy of the models on the MNIST test set can be computed with

```
python -m simple_eval test [model(s)]
```

To evaluate robustness to various attacks:

```
python -m simple_eval [attack] [source_model] [target_model(s)] [--parameters (opt)]
```
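For example, with the models trained above, the following crafts FGSM examples on modelA and reports the error rates they induce on modelA itself and on the two defended models (`--eps=0.3` is also the default):

```
python -m simple_eval fgs models/modelA models/modelA_adv models/modelA_ens --eps=0.3
```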
###### REFERENCE
1. Author's code: [ftramer/ensemble-adv-training](https://github.com/ftramer/ensemble-adv-training)

--------------------------------------------------------------------------------
/attack_utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-03-27 7:07 PM
'''

import torch
import torch.nn.functional as F

def gen_adv_loss(logits, labels, loss='logloss', mean=False):
    '''
    Generate the loss function.
    '''
    if loss == 'training':
        # use the model's own predictions instead of the true labels to avoid
        # label leaking at training time
        labels = logits.max(1)[1]
    elif loss != 'logloss':
        raise ValueError('Unknown loss: {}'.format(loss))
    return F.cross_entropy(logits, labels, reduction='mean' if mean else 'sum')

def gen_grad(x, model, y, loss='logloss'):
    '''
    Generate the gradient of the loss function w.r.t. the input.
    '''
    model.eval()
    x.requires_grad = True

    # gradient of the loss w.r.t. the input
    logits = model(x)
    adv_loss = gen_adv_loss(logits, y, loss)
    model.zero_grad()
    adv_loss.backward()
    grad = x.grad.data
    return grad

--------------------------------------------------------------------------------
/carlini.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-04-06 11:23 AM
'''
import torch
import numpy as np

MAX_ITERATIONS = 1000
ABORT_EARLY = True
INITIAL_CONST = 1e-3
LEARNING_RATE = 5e-3
LARGEST_CONST = 2e+1
TARGETED = True
CONST_FACTOR = 10.0
CONFIDENCE = 0
EPS = 0.3

class Carlini:
    def __init__(self, model, targeted=TARGETED, learning_rate=LEARNING_RATE, max_iterations=MAX_ITERATIONS,
                 abort_early=ABORT_EARLY, initial_const=INITIAL_CONST, largest_const=LARGEST_CONST,
                 const_factor=CONST_FACTOR, confidence=CONFIDENCE, eps=EPS):
        self.model = model

        self.TARGETED = targeted
        self.LEARNING_RATE = learning_rate
        self.MAX_ITERATIONS = max_iterations
        self.ABORT_EARLY = abort_early
        self.INITIAL_CONST = initial_const
        self.LARGEST_CONST = largest_const
        self.CONST_FACTOR = const_factor
        self.CONFIDENCE = confidence
        self.EPS = eps

--------------------------------------------------------------------------------
/fgs.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-04-04 12:10 AM
'''
import torch
from attack_utils import gen_grad

def symbolic_fgs(data, grad, eps=0.3, clipping=True):
    '''
    FGSM attack.
    '''
    # signed gradient
    normed_grad = grad.detach().sign()

    # multiply by constant epsilon
    scaled_grad = eps * normed_grad

    # add the perturbation to the original example to obtain the adversarial example
    adv_x = data.detach() + scaled_grad
    if clipping:
        adv_x = torch.clamp(adv_x, 0, 1)
    return adv_x

def iter_fgs(model, data, labels, steps, eps):
    '''
    I-FGSM (iterative FGSM) attack.
    '''
    adv_x = data

    # iteratively apply FGSM with a small step size
    for i in range(steps):
        grad = gen_grad(adv_x, model, labels)
        adv_x = symbolic_fgs(adv_x, grad, eps)
    return adv_x
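The helpers in _attack_utils.py_ and _fgs.py_ compose directly. A minimal usage sketch, assuming the MNIST data was already downloaded (e.g. by _train.py_) and that a `models/modelA.pkl` checkpoint exists; it crafts one batch of FGSM examples and compares clean vs. adversarial accuracy:

```python
# minimal sketch: craft FGSM examples for one MNIST batch with the helpers above
import torch
import torch.utils.data
from torchvision import datasets, transforms
from mnist import load_model
from attack_utils import gen_grad
from fgs import symbolic_fgs

model = load_model('models/modelA', type=0)   # illustrative checkpoint path
loader = torch.utils.data.DataLoader(
    datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
    batch_size=64)
data, labels = next(iter(loader))

grad = gen_grad(data, model, labels)          # d(loss)/d(input)
adv_x = symbolic_fgs(data, grad, eps=0.3)     # x + eps * sign(grad), clipped to [0, 1]

clean_acc = (model(data).max(1)[1] == labels).float().mean().item()
adv_acc = (model(adv_x).max(1)[1] == labels).float().mean().item()
print('accuracy: clean {:.2%}, adversarial {:.2%}'.format(clean_acc, adv_acc))
```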
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-03-25 4:43 PM
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms


class modelA(nn.Module):
    def __init__(self):
        super(modelA, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 20 * 20, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

class modelB(nn.Module):
    def __init__(self):
        super(modelB, self).__init__()
        self.dropout1 = nn.Dropout(0.2)
        self.conv1 = nn.Conv2d(1, 64, 8)
        self.conv2 = nn.Conv2d(64, 128, 6)
        self.conv3 = nn.Conv2d(128, 128, 5)
        self.dropout2 = nn.Dropout(0.5)
        self.fc = nn.Linear(128 * 12 * 12, 10)

    def forward(self, x):
        x = self.dropout1(x)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.dropout2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class modelC(nn.Module):
    def __init__(self):
        super(modelC, self).__init__()
        self.conv1 = nn.Conv2d(1, 128, 3)
        self.conv2 = nn.Conv2d(128, 64, 3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.tanh(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = torch.tanh(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class modelD(nn.Module):
    def __init__(self):
        super(modelD, self).__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 300)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(300, 300)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(300, 300)
        self.dropout3 = nn.Dropout(0.5)
        self.fc4 = nn.Linear(300, 300)
        self.dropout4 = nn.Dropout(0.5)
        self.fc5 = nn.Linear(300, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))
        x = self.dropout3(x)
        x = F.relu(self.fc4(x))
        x = self.dropout4(x)
        x = self.fc5(x)
        return x

def model_mnist(type=1):
    '''
    Build one of the MNIST models (0: A, 1: B, 2: C, 3: D).
    '''
    models = [modelA, modelB, modelC, modelD]
    return models[type]()

def load_model(model_path, type=1):
    model = model_mnist(type=type)
    model.load_state_dict(torch.load(model_path + '.pkl'))
    return model
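The checkpoint convention is worth noting: _train.py_ saves `state_dict`s with a `.pkl` suffix, and `load_model` appends that suffix itself. A quick sketch (paths illustrative):

```python
# illustrative: build modelA, save its weights, and reload them with load_model
import torch
from mnist import model_mnist, load_model

model = model_mnist(type=0)                   # type 0 -> modelA
torch.save(model.state_dict(), 'models/modelA.pkl')
model = load_model('models/modelA', type=0)   # load_model appends '.pkl'
```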
--------------------------------------------------------------------------------
/models/modelA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA.pkl

--------------------------------------------------------------------------------
/models/modelA_adv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA_adv.pkl

--------------------------------------------------------------------------------
/models/modelA_ens.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA_ens.pkl

--------------------------------------------------------------------------------
/models/modelB.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelB.pkl

--------------------------------------------------------------------------------
/models/modelC.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelC.pkl

--------------------------------------------------------------------------------
/models/modelD.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelD.pkl
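The `.pkl` entries above link to pretrained checkpoints. A minimal sketch for fetching and loading one directly from the raw URL, using `torch.utils.model_zoo` (the download-and-load helper available in PyTorch 1.0.x); the URL is taken from this listing:

```python
# illustrative: pull the pretrained modelA checkpoint and load it
import torch.utils.model_zoo as model_zoo
from mnist import model_mnist

URL = ('https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/'
       '50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA.pkl')

model = model_mnist(type=0)                   # modelA architecture
model.load_state_dict(model_zoo.load_url(URL, map_location='cpu'))
model.eval()
```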
--------------------------------------------------------------------------------
/simple_eval.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-04-05 11:20 PM
'''
import torch
import torchvision
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
from mnist import *
from utils import train, test
from attack_utils import gen_grad
from fgs import symbolic_fgs, iter_fgs
from os.path import basename
import argparse


def main(args):
    def get_model_type(model_name):
        model_type = {
            'models/modelA': 0, 'models/modelA_adv': 0, 'models/modelA_ens': 0,
            'models/modelB': 1, 'models/modelB_adv': 1, 'models/modelB_ens': 1,
            'models/modelC': 2, 'models/modelC_adv': 2, 'models/modelC_ens': 2,
            'models/modelD': 3, 'models/modelD_adv': 3, 'models/modelD_ens': 3,
        }
        if model_name not in model_type.keys():
            raise ValueError('Unknown model: {}'.format(model_name))
        return model_type[model_name]

    torch.manual_seed(args.seed)
    device = torch.device('cuda' if args.cuda else 'cpu')

    '''
    Preprocess MNIST dataset
    '''
    kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # source model for crafting adversarial examples
    src_model_name = args.src_model
    type = get_model_type(src_model_name)
    src_model = load_model(src_model_name, type).to(device)

    # model(s) to target
    target_model_names = args.target_models
    target_models = [None] * len(target_model_names)
    for i in range(len(target_model_names)):
        type = get_model_type(target_model_names[i])
        target_models[i] = load_model(target_model_names[i], type=type).to(device)

    attack = args.attack

    # simply compute test error
    if attack == 'test':
        correct_s = 0
        with torch.no_grad():
            for (data, labels) in test_loader:
                data, labels = data.to(device), labels.to(device)
                correct_s += test(src_model, data, labels)
        err = 100. - 100. * correct_s / len(test_loader.dataset)
        print('Test error of {}: {:.2f}'.format(basename(src_model_name), err))

        for (name, target_model) in zip(target_model_names, target_models):
            correct_t = 0
            with torch.no_grad():
                for (data, labels) in test_loader:
                    data, labels = data.to(device), labels.to(device)
                    correct_t += test(target_model, data, labels)
            err = 100. - 100. * correct_t / len(test_loader.dataset)
            print('Test error of {}: {:.2f}'.format(basename(name), err))
        return

    eps = args.eps
    # RAND+FGSM spends part of the epsilon budget on the initial random step,
    # so the FGSM step uses eps - alpha (e.g. 0.3 - 0.05 = 0.25 with the defaults)
    if attack == 'rand_fgs':
        eps -= args.alpha

    correct = 0
    correct_t = [0] * len(target_models)
    for (data, labels) in test_loader:
        # take the random step of RAND+FGSM
        if attack == 'rand_fgs':
            data = torch.clamp(data + torch.zeros_like(data).uniform_(-args.alpha, args.alpha), 0.0, 1.0)
        data, labels = data.to(device), labels.to(device)
        grad = gen_grad(data, src_model, labels)

        # FGSM and RAND+FGSM one-shot attacks
        if attack in ['fgs', 'rand_fgs']:
            adv_x = symbolic_fgs(data, grad, eps=eps)

        # iterative FGSM
        if attack == 'ifgs':
            adv_x = iter_fgs(src_model, data, labels, steps=args.steps, eps=args.eps / args.steps)

        # white-box error on the source model, then transfer error on each target
        correct += test(src_model, adv_x, labels)
        for i, target_model in enumerate(target_models):
            correct_t[i] += test(target_model, adv_x, labels)

    test_error = 100. - 100. * correct / len(test_loader.dataset)
    print('Test Set Error Rate of {}: {:.2f}%'.format(basename(src_model_name), test_error))
    for (name, correct_i) in zip(target_model_names, correct_t):
        test_error = 100. - 100. * correct_i / len(test_loader.dataset)
        print('Test Set Error Rate of {}: {:.2f}%'.format(basename(name), test_error))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Simple eval')
    parser.add_argument('attack', choices=['test', 'fgs', 'ifgs', 'rand_fgs', 'CW'], help='Name of attack')
    parser.add_argument('src_model', help='Source model for attack')
    parser.add_argument('target_models', nargs='*', help='path to target model(s)')
    parser.add_argument('--batch_size', type=int, default=64, help='Size of training batches (default: 64)')
    parser.add_argument('--eps', type=float, default=0.3, help='FGS attack scale (default: 0.3)')
    parser.add_argument('--alpha', type=float, default=0.05, help='RAND+FGSM random perturbation scale')
    parser.add_argument('--steps', type=int, default=10, help='Iterated FGS steps (default: 10)')
    parser.add_argument('--kappa', type=float, default=100, help='CW attack confidence')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
    parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')

    args = parser.parse_args()
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    main(args)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-03-26 10:26 AM
'''

import torch
import torchvision
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
from mnist import *
from utils import train, test
import argparse
import os


def main(args):
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if args.cuda else 'cpu')

    '''
    Preprocess MNIST dataset
    '''
    kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../attack_mnist', train=True, download=True, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    model = model_mnist(type=args.type).to(device)
    optimizer = optim.Adam(model.parameters())

    # Train an MNIST model
    for epoch in range(args.epochs):
        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)
            train(epoch, batch_idx, model, data, labels, optimizer)

    # Finally print the result!
    correct = 0
    with torch.no_grad():
        for (data, labels) in test_loader:
            data, labels = data.to(device), labels.to(device)
            correct += test(model, data, labels)
    test_error = 100. - 100. * correct / len(test_loader.dataset)
    print('Test Set Error Rate: {:.2f}%'.format(test_error))
    torch.save(model.state_dict(), args.model + '.pkl')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training MNIST model')
    parser.add_argument('model', help='path to model')
    parser.add_argument('--type', type=int, default=1, help='Model type (default: 1)')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
    parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')
    parser.add_argument('--batch_size', type=int, default=64, help='Size of training batches (default: 64)')
    parser.add_argument('--epochs', type=int, default=6, help='Number of epochs to train (default: 6)')

    args = parser.parse_args()
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    main(args)

--------------------------------------------------------------------------------
/train_adv.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-04-02 2:13 PM
'''

import torch
import torchvision
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
from mnist import *
from utils import train, test
from attack_utils import gen_grad
from fgs import symbolic_fgs
import argparse
import os

def main(args):
    def get_model_type(model_name):
        model_type = {
            'models/modelA': 0, 'models/modelA_adv': 0, 'models/modelA_ens': 0,
            'models/modelB': 1, 'models/modelB_adv': 1, 'models/modelB_ens': 1,
            'models/modelC': 2, 'models/modelC_adv': 2, 'models/modelC_ens': 2,
            'models/modelD': 3, 'models/modelD_adv': 3, 'models/modelD_ens': 3,
        }
        if model_name not in model_type.keys():
            raise ValueError('Unknown model: {}'.format(model_name))
        return model_type[model_name]

    torch.manual_seed(args.seed)
    device = torch.device('cuda' if args.cuda else 'cpu')

    '''
    Preprocess MNIST dataset
    '''
    kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../attack_mnist', train=True, download=True, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    eps = args.eps

    # if adv_models is non-empty, we train on adversarial examples crafted
    # on multiple models (ensemble adversarial training)
    adv_model_names = args.adv_models
    adv_models = [None] * len(adv_model_names)
    for i in range(len(adv_model_names)):
        type = get_model_type(adv_model_names[i])
        adv_models[i] = load_model(adv_model_names[i], type=type).to(device)

    model = model_mnist(type=args.type).to(device)
    optimizer = optim.Adam(model.parameters())

    # Adversarially train an MNIST model
    x_advs = [None] * (len(adv_models) + 1)
    for epoch in range(args.epochs):
        for batch_idx, (data, labels) in enumerate(train_loader):
            data, labels = data.to(device), labels.to(device)
            # craft FGSM examples on each static model and on the model being
            # trained; loss='training' uses the models' own predictions as
            # labels to avoid label leaking
            for i, m in enumerate(adv_models + [model]):
                grad = gen_grad(data, m, labels, loss='training')
                x_advs[i] = symbolic_fgs(data, grad, eps=eps)
            train(epoch, batch_idx, model, data, labels, optimizer, x_advs=x_advs)

    # Finally print the result
    correct = 0
    with torch.no_grad():
        for (data, labels) in test_loader:
            data, labels = data.to(device), labels.to(device)
            correct += test(model, data, labels)
    test_error = 100. - 100. * correct / len(test_loader.dataset)
    print('Test Set Error Rate: {:.2f}%'.format(test_error))

    torch.save(model.state_dict(), args.model + '.pkl')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Adversarial Training MNIST model')
    parser.add_argument('model', help='path to model')
    parser.add_argument('adv_models', nargs='*', help='path to adv model(s)')
    parser.add_argument('--type', type=int, default=0, help='Model type (default: 0)')
    parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
    parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')
    parser.add_argument('--batch_size', type=int, default=64, help='Size of training batches (default: 64)')
    parser.add_argument('--epochs', type=int, default=12, help='Number of epochs (default: 12)')
    parser.add_argument('--eps', type=float, default=0.3, help='FGSM attack scale (default: 0.3)')

    args = parser.parse_args()
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    main(args)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@author: cailikun
@time: 2019-03-27 10:26 AM
'''
import torch
import torch.nn.functional as F
from attack_utils import gen_adv_loss
import numpy as np

EVAL_FREQUENCY = 100

def train(epoch, batch_idx, model, data, labels, optimizer, x_advs=None):
    model.train()
    optimizer.zero_grad()

    # cross-entropy loss on the clean examples
    logits = model(data)
    preds = logits.max(1)[1]
    loss1 = gen_adv_loss(logits, labels, mean=True)

    # add the adversarial training loss
    if x_advs is not None:
        # choose the source of adversarial examples at random
        # (for ensemble adversarial training)
        idx = np.random.randint(len(x_advs))
        logits_adv = model(x_advs[idx])
        loss2 = gen_adv_loss(logits_adv, labels, mean=True)
        loss = 0.5 * (loss1 + loss2)
    else:
        loss2 = torch.zeros(loss1.size())
        loss = loss1

    loss.backward()
    optimizer.step()

    if batch_idx % EVAL_FREQUENCY == 0:
        print('Step: {}(epoch: {})\tLoss: {:.6f}<=({:.6f}, {:.6f})\tError: {:.2f}%'.format(
            batch_idx, epoch + 1, loss.item(), loss1.item(), loss2.item(), error_rate(preds, labels)
        ))

def test(model, data, labels):
    model.eval()
    logits = model(data)

    # number of correct predictions on this batch
    preds = logits.max(1)[1]
    return preds.eq(labels).sum().item()

def error_rate(preds, labels):
    '''
    Compute the error rate (%).
    '''
    assert preds.size() == labels.size()
    return 100.0 - (100.0 * preds.eq(labels).sum().item()) / preds.size(0)
--------------------------------------------------------------------------------