├── README.md
├── attack_utils.py
├── carlini.py
├── fgs.py
├── mnist.py
├── models
│   ├── modelA.pkl
│   ├── modelA_adv.pkl
│   ├── modelA_ens.pkl
│   ├── modelB.pkl
│   ├── modelC.pkl
│   └── modelD.pkl
├── simple_eval.py
├── train.py
├── train_adv.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Ensemble Adversarial Training with PyTorch
2 |
3 | This repository contains PyTorch code to reproduce results from the paper:
4 |
5 | **Ensemble Adversarial Training: Attacks and Defenses**
6 | *Florian Tramèr, Alexey Kurakin, Nicolas Papernot, Dan Boneh and Patrick McDaniel*
7 | arXiv report: https://arxiv.org/abs/1705.07204
8 |
9 |
10 |
11 | ###### REQUIREMENTS
12 |
13 | The code was tested with Python 3.6.7 and PyTorch 1.0.1.
14 |
15 | ###### EXPERIMENTS
16 |
17 | Train a few simple MNIST models; their architectures are defined in _mnist.py_:
18 |
19 | ```
20 | python -m train models/modelA --type=0
21 | python -m train models/modelB --type=1
22 | python -m train models/modelC --type=2
23 | python -m train models/modelD --type=3
24 | ```
25 |
26 | (Standard) adversarial training:
27 |
28 | ```
29 | python -m train_adv models/modelA_adv --type=0 --epochs=12
30 | ```
31 | Ensemble Adversarial Training:
32 | ```
33 | python -m train_adv models/modelA_ens models/modelA models/modelC models/modelD --type=0 --epochs=12
34 | ```
35 |
36 | The accuracy of the models on the MNIST test set can be computed with:
37 |
38 | ```
39 | python -m simple_eval test [model(s)]
40 | ```
41 |
42 | To evaluate robustness to various attacks (see _simple_eval.py_ for the available attacks and flags):
43 |
44 | ```
45 | python -m simple_eval [attack] [source_model] [target_model(s)] [--parameters (opt)]
46 | ```
47 |
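48 | For example, the following (hypothetical) run crafts FGSM examples on modelA and reports the white-box error as well as the transfer error on the defended models:
49 | 
50 | ```
51 | python -m simple_eval fgs models/modelA models/modelA_adv models/modelA_ens --eps=0.3
52 | ```
53 | 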
54 | ###### REFERENCE
55 | 1. Author's code: [ftramer/ensemble-adv-training](https://github.com/ftramer/ensemble-adv-training)
56 | 
--------------------------------------------------------------------------------
/attack_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 19-3-27 7:07 PM
5 | '''
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | def gen_adv_loss(logits, labels, loss='logloss', mean=False):
11 |     '''
12 |     Generate the loss function.
13 |     '''
14 |     if loss == 'training':
15 |         # use the model's predictions instead of the true labels to avoid
16 |         # label leaking at training time
17 |         labels = logits.max(1)[1]
18 |     elif loss != 'logloss':
19 |         raise ValueError('Unknown loss: {}'.format(loss))
20 |     return F.cross_entropy(logits, labels, reduction='mean' if mean else 'sum')
21 | 
22 | def gen_grad(x, model, y, loss='logloss'):
23 |     '''
24 |     Generate the gradient of the loss function w.r.t. the input.
25 |     '''
26 |     model.eval()
27 |     x = x.detach().requires_grad_(True)
28 | 
29 |     # gradient of the loss w.r.t. the input
30 |     logits = model(x)
31 |     adv_loss = gen_adv_loss(logits, y, loss)
32 |     model.zero_grad()
33 |     adv_loss.backward()
34 |     return x.grad.detach()
35 | 
--------------------------------------------------------------------------------
/carlini.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 2019/4/6 11:23 AM
5 | '''
6 | import torch
7 | import numpy as np
8 |
9 | MAX_ITERATIONS = 1000
10 | ABORT_EARLY = True
11 | INITIAL_CONST = 1e-3
12 | LEARNING_RATE = 5e-3
13 | LARGEST_CONST = 2e+1
14 | TARGETED = True
15 | CONST_FACTOR = 10.0
16 | CONFIDENCE = 0
17 | EPS = 0.3
18 |
19 | class Carlini:
20 | def __init__(self, model, targeted = TARGETED, learning_rate = LEARNING_RATE, max_iterations = MAX_ITERATIONS,
21 | abort_early = ABORT_EARLY, initial_const = INITIAL_CONST, largest_const = LARGEST_CONST,
22 | const_factor = CONST_FACTOR, confidence = CONFIDENCE, eps = EPS):
23 | self.model = model
24 |
25 | self.TARGETED = targeted
26 |         self.LEARNING_RATE = learning_rate
27 | self.MAX_ITERATIONS = max_iterations
28 | self.ABORT_EARLY = abort_early
29 | self.INITIAL_CONST = initial_const
30 | self.LARGEST_CONST = largest_const
31 | self.CONST_FACTOR = const_factor
32 |         self.CONFIDENCE = confidence
33 |         self.EPS = eps
34 | 
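35 |     def attack_sketch(self, x, target):
36 |         '''
37 |         Illustrative sketch only (added here; this file is otherwise a stub):
38 |         a C&W-style L2 inner optimization loop for a single fixed const,
39 |         driven by the hyperparameters stored above. The search over const
40 |         implied by LARGEST_CONST / CONST_FACTOR and the ABORT_EARLY logic
41 |         are omitted, as is any use of EPS.
42 |         '''
43 |         const = self.INITIAL_CONST
44 |         # change of variables keeps adv_x in [0, 1]: adv_x = (tanh(w) + 1) / 2
45 |         z = (x * 2 - 1).clamp(-0.999999, 0.999999)
46 |         w = (0.5 * torch.log((1 + z) / (1 - z))).detach().requires_grad_(True)
47 |         optimizer = torch.optim.Adam([w], lr=self.LEARNING_RATE)
48 |         for _ in range(self.MAX_ITERATIONS):
49 |             adv_x = (torch.tanh(w) + 1) / 2
50 |             logits = self.model(adv_x)
51 |             one_hot = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1.0)
52 |             target_logit = (logits * one_hot).sum(1)
53 |             other_logit = (logits - 1e4 * one_hot).max(1)[0]  # best non-target logit
54 |             if self.TARGETED:
55 |                 # push the target class above every other class by CONFIDENCE
56 |                 f = torch.clamp(other_logit - target_logit + self.CONFIDENCE, min=0)
57 |             else:
58 |                 # push the true class below some other class by CONFIDENCE
59 |                 f = torch.clamp(target_logit - other_logit + self.CONFIDENCE, min=0)
60 |             dist = ((adv_x - x) ** 2).view(x.size(0), -1).sum(1)
61 |             loss = (dist + const * f).sum()
62 |             optimizer.zero_grad()
63 |             loss.backward()
64 |             optimizer.step()
65 |         return ((torch.tanh(w) + 1) / 2).detach()
66 | 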
--------------------------------------------------------------------------------
/fgs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 2019/4/4 12:10 AM
5 | '''
6 | import torch
7 | from attack_utils import gen_grad
8 |
9 | def symbolic_fgs(data, grad, eps=0.3, clipping=True):
10 | '''
11 | FGSM attack.
12 | '''
13 |     # signed gradient
14 | normed_grad = grad.detach().sign()
15 |
16 | # Multiply by constant epsilon
17 | scaled_grad = eps * normed_grad
18 |
19 | # Add perturbation to original example to obtain adversarial example
20 | adv_x = data.detach() + scaled_grad
21 | if clipping:
22 | adv_x = torch.clamp(adv_x, 0, 1)
23 | return adv_x
24 |
25 | def iter_fgs(model, data, labels, steps, eps):
26 | '''
27 | I-FGSM attack.
28 | '''
29 | adv_x = data
30 |
31 | # iteratively apply the FGSM with small step size
32 | for i in range(steps):
33 | grad = gen_grad(adv_x, model, labels)
34 | adv_x = symbolic_fgs(adv_x, grad, eps)
35 | return adv_x
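36 | 
37 | if __name__ == '__main__':
38 |     # Illustrative smoke test (added; not in the original repo): run both
39 |     # attacks on a dummy batch. Assumes models/modelA.pkl exists, e.g.
40 |     # trained via train.py as described in the README.
41 |     from mnist import load_model
42 |     model = load_model('models/modelA', type=0)
43 |     x = torch.rand(8, 1, 28, 28)            # MNIST-shaped batch in [0, 1]
44 |     y = torch.randint(0, 10, (8,))
45 |     grad = gen_grad(x, model, y)            # gradient of the loss w.r.t. x
46 |     x_fgs = symbolic_fgs(x, grad, eps=0.3)  # one FGSM step of size eps
47 |     x_ifgs = iter_fgs(model, x, y, steps=10, eps=0.03)  # 10 steps of size 0.03
48 |     print(x_fgs.shape, x_ifgs.shape)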
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 19-3-25 4:43 PM
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 | from torchvision import datasets, transforms
12 |
13 |
14 |
15 | class modelA(nn.Module):
16 | def __init__(self):
17 | super(modelA, self).__init__()
18 | self.conv1 = nn.Conv2d(1, 64, 5)
19 | self.conv2 = nn.Conv2d(64, 64, 5)
20 | self.dropout1 = nn.Dropout(0.25)
21 | self.fc1 = nn.Linear(64 * 20 * 20, 128)
22 | self.dropout2 = nn.Dropout(0.5)
23 | self.fc2 = nn.Linear(128, 10)
24 |
25 | def forward(self, x):
26 | x = F.relu(self.conv1(x))
27 | x = F.relu(self.conv2(x))
28 | x = self.dropout1(x)
29 | x = x.view(x.size(0), -1)
30 | x = F.relu(self.fc1(x))
31 | x = self.dropout2(x)
32 | x = self.fc2(x)
33 | return x
34 |
35 | class modelB(nn.Module):
36 | def __init__(self):
37 | super(modelB, self).__init__()
38 | self.dropout1 = nn.Dropout(0.2)
39 | self.conv1 = nn.Conv2d(1, 64, 8)
40 | self.conv2 = nn.Conv2d(64, 128, 6)
41 | self.conv3 = nn.Conv2d(128, 128, 5)
42 | self.dropout2 = nn.Dropout(0.5)
43 | self.fc = nn.Linear(128 * 12 * 12, 10)
44 |
45 | def forward(self, x):
46 | x = self.dropout1(x)
47 | x = F.relu(self.conv1(x))
48 | x = F.relu(self.conv2(x))
49 | x = F.relu(self.conv3(x))
50 | x = self.dropout2(x)
51 | x = x.view(x.size(0), -1)
52 | x = self.fc(x)
53 | return x
54 |
55 | class modelC(nn.Module):
56 | def __init__(self):
57 | super(modelC, self).__init__()
58 | self.conv1 = nn.Conv2d(1, 128, 3)
59 | self.conv2 = nn.Conv2d(128, 64, 3)
60 | self.fc1 = nn.Linear(64 * 5 * 5, 128)
61 | self.fc2 = nn.Linear(128, 10)
62 |
63 | def forward(self, x):
64 | x = torch.tanh(self.conv1(x))
65 | x = F.max_pool2d(x, 2)
66 | x = torch.tanh(self.conv2(x))
67 | x = F.max_pool2d(x, 2)
68 | x = x.view(x.size(0), -1)
69 | x = F.relu(self.fc1(x))
70 | x = self.fc2(x)
71 | return x
72 |
73 | class modelD(nn.Module):
74 | def __init__(self):
75 | super(modelD, self).__init__()
76 | self.fc1 = nn.Linear(1 * 28 * 28, 300)
77 | self.dropout1 = nn.Dropout(0.5)
78 | self.fc2 = nn.Linear(300, 300)
79 | self.dropout2 = nn.Dropout(0.5)
80 | self.fc3 = nn.Linear(300, 300)
81 | self.dropout3 = nn.Dropout(0.5)
82 | self.fc4 = nn.Linear(300, 300)
83 | self.dropout4 = nn.Dropout(0.5)
84 | self.fc5 = nn.Linear(300, 10)
85 |
86 | def forward(self, x):
87 | x = x.view(x.size(0), -1)
88 | x = F.relu(self.fc1(x))
89 | x = self.dropout1(x)
90 | x = F.relu(self.fc2(x))
91 | x = self.dropout2(x)
92 | x = F.relu(self.fc3(x))
93 | x = self.dropout3(x)
94 | x = F.relu(self.fc4(x))
95 | x = self.dropout4(x)
96 | x = self.fc5(x)
97 | return x
98 |
99 | def model_mnist(type=1):
100 | '''
101 |     Return an MNIST model instance: 0=modelA, 1=modelB, 2=modelC, 3=modelD.
102 | '''
103 | models = [modelA, modelB, modelC, modelD]
104 | return models[type]()
105 |
106 | def load_model(model_path, type=1):
107 | model = model_mnist(type=type)
108 |     model.load_state_dict(torch.load(model_path + '.pkl', map_location='cpu'))  # callers move the model to their device
109 | return model
110 |
111 |
112 |
113 |
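114 | if __name__ == '__main__':
115 |     # Sanity check (illustrative, not part of the training pipeline): verify
116 |     # that the hard-coded nn.Linear input sizes match the conv output shapes
117 |     # for 28x28 MNIST inputs by pushing a dummy batch through each model.
118 |     x = torch.rand(2, 1, 28, 28)
119 |     for cls in (modelA, modelB, modelC, modelD):
120 |         print(cls.__name__, tuple(cls()(x).shape))  # expect (2, 10) for each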
--------------------------------------------------------------------------------
/models/modelA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA.pkl
--------------------------------------------------------------------------------
/models/modelA_adv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA_adv.pkl
--------------------------------------------------------------------------------
/models/modelA_ens.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelA_ens.pkl
--------------------------------------------------------------------------------
/models/modelB.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelB.pkl
--------------------------------------------------------------------------------
/models/modelC.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelC.pkl
--------------------------------------------------------------------------------
/models/modelD.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cailk/ensemble-adv-training-pytorch/50b8047c9150b726e8633b17d3db0275b5aca366/models/modelD.pkl
--------------------------------------------------------------------------------
/simple_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 2019/4/5 11:20 PM
5 | '''
6 | import torch
7 | import torchvision
8 | import torch.optim as optim
9 | import torch.utils.data
10 | from torchvision import datasets, transforms
11 | from mnist import *
12 | from utils import train, test
13 | from attack_utils import gen_grad
14 | from fgs import symbolic_fgs, iter_fgs
15 | from os.path import basename
16 | import argparse
17 |
18 |
19 |
20 | def main(args):
21 | def get_model_type(model_name):
22 | model_type = {
23 | 'models/modelA':0, 'models/modelA_adv':0, 'models/modelA_ens':0,
24 | 'models/modelB':1, 'models/modelB_adv':1, 'models/modelB_ens':1,
25 | 'models/modelC':2, 'models/modelC_adv':2, 'models/modelC_ens':2,
26 | 'models/modelD':3, 'models/modelD_adv':3, 'models/modelD_ens':3,
27 | }
28 | if model_name not in model_type.keys():
29 | raise ValueError('Unknown model: {}'.format(model_name))
30 | return model_type[model_name]
31 |
32 | torch.manual_seed(args.seed)
33 | device = torch.device('cuda' if args.cuda else 'cpu')
34 |
35 | '''
36 | Preprocess MNIST dataset
37 | '''
38 | kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
39 | test_loader = torch.utils.data.DataLoader(
40 | datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
41 | batch_size=args.batch_size, shuffle=True, **kwargs)
42 |
43 | # source model for crafting adversarial examples
44 | src_model_name = args.src_model
45 | type = get_model_type(src_model_name)
46 | src_model = load_model(src_model_name, type).to(device)
47 |
48 | # model(s) to target
49 | target_model_names = args.target_models
50 | target_models = [None] * len(target_model_names)
51 | for i in range(len(target_model_names)):
52 | type = get_model_type(target_model_names[i])
53 | target_models[i] = load_model(target_model_names[i], type=type).to(device)
54 |
55 | attack = args.attack
56 |
57 | # simply compute test error
58 | if attack == 'test':
59 | correct_s = 0
60 | with torch.no_grad():
61 | for (data, labels) in test_loader:
62 | data, labels = data.to(device), labels.to(device)
63 | correct_s += test(src_model, data, labels)
64 | err = 100. - 100. * correct_s / len(test_loader.dataset)
65 | print('Test error of {}: {:.2f}'.format(basename(src_model_name), err))
66 |
67 | for (name, target_model) in zip(target_model_names, target_models):
68 | correct_t = 0
69 | with torch.no_grad():
70 | for (data, labels) in test_loader:
71 | data, labels = data.to(device), labels.to(device)
72 | correct_t += test(target_model, data, labels)
73 | err = 100. - 100. * correct_t / len(test_loader.dataset)
74 |             print('Test error of {}: {:.2f}'.format(basename(name), err))
75 | return
76 | 
77 |     if attack == 'CW':
78 |         raise NotImplementedError('The CW attack is not implemented yet; carlini.py is a stub.')
79 | 
80 |     eps = args.eps
81 |     # the random step of RAND+FGSM consumes part of the perturbation budget,
82 |     # so shrink eps once here rather than once per batch
83 |     if attack == 'rand_fgs':
84 |         eps -= args.alpha
85 | 
86 |     correct = 0
87 |     correct_t = [0] * len(target_models)
88 |     for (data, labels) in test_loader:
89 |         # take the random step in the RAND+FGSM
90 |         if attack == 'rand_fgs':
91 |             data = torch.clamp(data + torch.zeros_like(data).uniform_(-args.alpha, args.alpha), 0.0, 1.0)
92 |         data, labels = data.to(device), labels.to(device)
93 |         grad = gen_grad(data, src_model, labels)
94 | 
95 |         # FGSM and RAND+FGSM one-shot attacks
96 |         if attack in ['fgs', 'rand_fgs']:
97 |             adv_x = symbolic_fgs(data, grad, eps=eps)
98 | 
99 |         # iterative FGSM
100 |         if attack == 'ifgs':
101 |             adv_x = iter_fgs(src_model, data, labels, steps=args.steps, eps=args.eps / args.steps)
102 | 
103 |         # white-box error on the source model, transfer error on the target(s)
104 |         with torch.no_grad():
105 |             correct += test(src_model, adv_x, labels)
106 |             for i in range(len(target_models)):
107 |                 correct_t[i] += test(target_models[i], adv_x, labels)
108 |     test_error = 100. - 100. * correct / len(test_loader.dataset)
109 |     print('Error of {} on adversarial examples: {:.2f}%'.format(basename(src_model_name), test_error))
110 |     for (name, correct_i) in zip(target_model_names, correct_t):
111 |         err = 100. - 100. * correct_i / len(test_loader.dataset)
112 |         print('Transfer error of {}: {:.2f}%'.format(basename(name), err))
113 | 
114 | 
115 | if __name__ == '__main__':
116 |     parser = argparse.ArgumentParser(description='Simple eval')
117 |     parser.add_argument('attack', choices=['test', 'fgs', 'ifgs', 'rand_fgs', 'CW'], help='Name of attack')
118 |     parser.add_argument('src_model', help='Source model for attack')
119 |     parser.add_argument('target_models', nargs='*', help='Path to target model(s)')
120 |     parser.add_argument('--batch_size', type=int, default=64, help='Size of evaluation batches (default: 64)')
121 |     parser.add_argument('--eps', type=float, default=0.3, help='FGSM attack scale (default: 0.3)')
122 |     parser.add_argument('--alpha', type=float, default=0.05, help='RAND+FGSM random perturbation scale (default: 0.05)')
123 |     parser.add_argument('--steps', type=int, default=10, help='Iterated FGSM steps (default: 10)')
124 |     parser.add_argument('--kappa', type=float, default=100, help='CW attack confidence')
125 |     parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
126 |     parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')
127 | 
128 |     args = parser.parse_args()
129 |     args.cuda = not args.disable_cuda and torch.cuda.is_available()
130 |     main(args)
131 | 
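132 | # Example invocations (hypothetical; they assume the models were trained as
133 | # described in the README):
134 | #   python -m simple_eval test models/modelA models/modelA_adv
135 | #   python -m simple_eval rand_fgs models/modelA models/modelA_ens --eps=0.3 --alpha=0.05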
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 19-3-26 10:26 AM
5 | '''
6 |
7 | import torch
8 | import torchvision
9 | import torch.optim as optim
10 | import torch.utils.data
11 | from torchvision import datasets, transforms
12 | from mnist import *
13 | from utils import train, test
14 | import argparse
15 | import os
16 |
17 |
18 | def main(args):
19 | torch.manual_seed(args.seed)
20 | device = torch.device('cuda' if args.cuda else 'cpu')
21 |
22 | '''
23 | Preprocess MNIST dataset
24 | '''
25 | kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
26 | train_loader = torch.utils.data.DataLoader(
27 | datasets.MNIST('../attack_mnist', train=True, download=True, transform=transforms.ToTensor()),
28 | batch_size=args.batch_size, shuffle=True, **kwargs)
29 | test_loader = torch.utils.data.DataLoader(
30 | datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
31 | batch_size=args.batch_size, shuffle=True, **kwargs)
32 |
33 | model = model_mnist(type=args.type).to(device)
34 | optimizer = optim.Adam(model.parameters())
35 |
36 | # Train an MNIST model
37 | for epoch in range(args.epochs):
38 | for batch_idx, (data, labels) in enumerate(train_loader):
39 | data, labels = data.to(device), labels.to(device)
40 | train(epoch, batch_idx, model, data, labels, optimizer)
41 |
42 | # Finally print the result!
43 | correct = 0
44 | with torch.no_grad():
45 | for (data, labels) in test_loader:
46 | data, labels = data.to(device), labels.to(device)
47 | correct += test(model, data, labels)
48 | test_error = 100. - 100. * correct / len(test_loader.dataset)
49 | print('Test Set Error Rate: {:.2f}%'.format(test_error))
50 |
51 | torch.save(model.state_dict(), args.model+'.pkl')
52 |
53 |
54 | if __name__ == '__main__':
55 | parser = argparse.ArgumentParser(description='Training MNIST model')
56 | parser.add_argument('model', help='path to model')
57 | parser.add_argument('--type', type=int, default=1, help='Model type (default: 1)')
58 | parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
59 | parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')
60 | parser.add_argument('--batch_size', type=int, default=64, help='Size of training batches (default: 64)')
61 | parser.add_argument('--epochs', type=int, default=6, help='Number of epochs to train (default: 6)')
62 |     args = parser.parse_args()
63 |     args.cuda = not args.disable_cuda and torch.cuda.is_available()
64 |     main(args)
65 | 
--------------------------------------------------------------------------------
/train_adv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 2019/4/2 2:13 PM
5 | '''
6 |
7 | import torch
8 | import torchvision
9 | import torch.optim as optim
10 | import torch.utils.data
11 | from torchvision import datasets, transforms
12 | from mnist import *
13 | from utils import train, test
14 | from attack_utils import gen_grad
15 | from fgs import symbolic_fgs
16 | import argparse
17 | import os
18 |
19 | def main(args):
20 | def get_model_type(model_name):
21 | model_type = {
22 | 'models/modelA': 0, 'models/modelA_adv': 0, 'models/modelA_ens': 0,
23 | 'models/modelB': 1, 'models/modelB_adv': 1, 'models/modelB_ens': 1,
24 | 'models/modelC': 2, 'models/modelC_adv': 2, 'models/modelC_ens': 2,
25 | 'models/modelD': 3, 'models/modelD_adv': 3, 'models/modelD_ens': 3,
26 | }
27 | if model_name not in model_type.keys():
28 | raise ValueError('Unknown model: {}'.format(model_name))
29 | return model_type[model_name]
30 |
31 | torch.manual_seed(args.seed)
32 | device = torch.device('cuda' if args.cuda else 'cpu')
33 |
34 | '''
35 | Preprocess MNIST dataset
36 | '''
37 | kwargs = {'num_workers': 20, 'pin_memory': True} if args.cuda else {}
38 | train_loader = torch.utils.data.DataLoader(
39 | datasets.MNIST('../attack_mnist', train=True, download=True, transform=transforms.ToTensor()),
40 | batch_size=args.batch_size, shuffle=True, **kwargs)
41 | test_loader = torch.utils.data.DataLoader(
42 | datasets.MNIST('../attack_mnist', train=False, transform=transforms.ToTensor()),
43 | batch_size=args.batch_size, shuffle=True, **kwargs)
44 |
45 | eps = args.eps
46 |
47 |     # if adv_models is non-empty, we also train on adversarial examples
48 |     # crafted on these pre-trained models (ensemble adversarial training)
49 | adv_model_names = args.adv_models
50 | adv_models = [None] * len(adv_model_names)
51 | for i in range(len(adv_model_names)):
52 | type = get_model_type(adv_model_names[i])
53 | adv_models[i] = load_model(adv_model_names[i], type=type).to(device)
54 |
55 | model = model_mnist(type=args.type).to(device)
56 | optimizer = optim.Adam(model.parameters())
57 |
58 |     # Adversarially train an MNIST model
59 | x_advs = [None] * (len(adv_models) + 1)
60 | for epoch in range(args.epochs):
61 | for batch_idx, (data, labels) in enumerate(train_loader):
62 | data, labels = data.to(device), labels.to(device)
63 | for i, m in enumerate(adv_models + [model]):
64 | grad = gen_grad(data, m, labels, loss='training')
65 | x_advs[i] = symbolic_fgs(data, grad, eps=eps)
66 | train(epoch, batch_idx, model, data, labels, optimizer, x_advs=x_advs)
67 |
68 | # Finally print the result
69 | correct = 0
70 | with torch.no_grad():
71 | for (data, labels) in test_loader:
72 | data, labels = data.to(device), labels.to(device)
73 | correct += test(model, data, labels)
74 | test_error = 100. - 100. * correct / len(test_loader.dataset)
75 | print('Test Set Error Rate: {:.2f}%'.format(test_error))
76 |
77 | torch.save(model.state_dict(), args.model + '.pkl')
78 |
79 |
80 |
81 | if __name__ == '__main__':
82 | parser = argparse.ArgumentParser(description='Adversarial Training MNIST model')
83 | parser.add_argument('model', help='path to model')
84 | parser.add_argument('adv_models', nargs='*', help='path to adv model(s)')
85 | parser.add_argument('--type', type=int, default=0, help='Model type (default: 0)')
86 | parser.add_argument('--seed', type=int, default=1, help='Random seed (default: 1)')
87 | parser.add_argument('--disable_cuda', action='store_true', default=False, help='Disable CUDA (default: False)')
88 | parser.add_argument('--batch_size', type=int, default=64, help='Size of training batches (default: 64)')
89 | parser.add_argument('--epochs', type=int, default=12, help='Number of epochs (default: 12)')
90 | parser.add_argument('--eps', type=float, default=0.3, help='FGSM attack scale (default: 0.3)')
91 |
92 | args = parser.parse_args()
93 | args.cuda = not args.disable_cuda and torch.cuda.is_available()
94 | main(args)
95 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | @author: cailikun
4 | @time: 19-3-27 10:26 AM
5 | '''
6 | import torch
7 | import torch.nn.functional as F
8 | from attack_utils import gen_adv_loss
9 | import numpy as np
10 |
11 | EVAL_FREQUENCY = 100
12 |
13 | def train(epoch, batch_idx, model, data, labels, optimizer, x_advs=None):
14 | model.train()
15 | optimizer.zero_grad()
16 | # Generate cross-entropy loss for training
17 | logits = model(data)
18 | preds = logits.max(1)[1]
19 | loss1 = gen_adv_loss(logits, labels, mean=True)
20 |
21 | # add adversarial training loss
22 | if x_advs is not None:
23 |
24 | # choose source of adversarial examples at random
25 | # (for ensemble adversarial training)
26 | idx = np.random.randint(len(x_advs))
27 | logits_adv = model(x_advs[idx])
28 | loss2 = gen_adv_loss(logits_adv, labels, mean=True)
29 | loss = 0.5 * (loss1 + loss2)
30 | else:
31 |         loss2 = torch.zeros_like(loss1)
32 | loss = loss1
33 | loss.backward()
34 | optimizer.step()
35 | if batch_idx % EVAL_FREQUENCY == 0:
36 | print('Step: {}(epoch: {})\tLoss: {:.6f}<=({:.6f}, {:.6f})\tError: {:.2f}%'.format(
37 | batch_idx, epoch+1, loss.item(), loss1.item(), loss2.item(), error_rate(preds, labels)
38 | ))
39 |
40 | def test(model, data, labels):
41 | model.eval()
42 | correct = 0
43 | logits = model(data)
44 |
45 |     # Predictions for the current batch
46 | preds = logits.max(1)[1]
47 | correct += preds.eq(labels).sum().item()
48 | return correct
49 |
50 | def error_rate(preds, labels):
51 | '''
52 |     Compute the error rate.
53 | '''
54 | assert preds.size() == labels.size()
55 | return 100.0 - (100.0 * preds.eq(labels).sum().item()) / preds.size(0)
56 |
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------