├── README.md ├── adversarialbox ├── __init__.py ├── attacks.py ├── train.py └── utils.py ├── mnist_adv_train.py ├── mnist_attack.py ├── mnist_blackbox.py ├── models.py └── models └── adv_trained_lenet5.pkl /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Box - PyTorch Adversarial Attack and Training 2 | 3 | Luyu Wang and Gavin Ding, Borealis AI 4 | 5 | ## Motivation? 6 | [CleverHans](https://github.com/tensorflow/cleverhans) comes in handy for TensorFlow. However, PyTorch has no comparable library at the moment. [Foolbox](https://github.com/bethgelab/foolbox) supports multiple deep learning frameworks, but it lacks many major implementations (e.g., black-box attack, Carlini-Wagner attack, adversarial training). We feel there is a need to write an easy-to-use and versatile library to help our fellow researchers and engineers. 7 | 8 | **A more up-to-date successor is available: [AdverTorch](https://github.com/BorealisAI/advertorch). You can find most of the popular attacks there. This repository is no longer maintained.** 9 | 10 | ## Usage 11 | from adversarialbox.attacks import FGSMAttack 12 | adversary = FGSMAttack(model, epsilon=0.1) 13 | X_adv = adversary.perturb(X_i, y_i) 14 | 15 | ## Examples 16 | 1. MNIST with FGSM ([code](https://github.com/wanglouis49/pytorch-adversarial_box/blob/master/mnist_attack.py)) 17 | 2. Adversarial Training on MNIST ([code](https://github.com/wanglouis49/pytorch-adversarial_box/blob/master/mnist_adv_train.py)) 18 | 3. MNIST using a black-box attack ([code](https://github.com/wanglouis49/pytorch-adversarial_box/blob/master/mnist_blackbox.py)) 19 | 20 | ## List of supported attacks 21 | 1. FGSM 22 | 2. PGD 23 | 3. 
class FGSMAttack(object):
    """One-step fast gradient sign method (FGSM) attack.

    Attributes:
        model: classifier mapping an input batch to per-class scores. It may
            be assigned after construction.
        epsilon: default attack length; a scalar, or an array that
            broadcasts against the input batch.
        loss_fn: cross-entropy loss whose input gradient gives the attack
            direction.
    """

    def __init__(self, model=None, epsilon=None):
        self.model = model
        self.epsilon = epsilon
        self.loss_fn = nn.CrossEntropyLoss()

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if none)."""
        first_param = next(iter(self.model.parameters()), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def perturb(self, X_nat, y, epsilons=None):
        """Return adversarial counterparts of the examples (X_nat, y).

        Args:
            X_nat: numpy array of natural examples; it is left unmodified.
            y: integer labels for X_nat (sequence, numpy array or tensor).
            epsilons: optional attack length(s) used for this call only.
                When omitted, ``self.epsilon`` is used.

        Returns:
            A numpy array shaped like X_nat, moved by the attack length along
            the sign of the loss gradient and clipped to [0, 1].
        """
        # Use the override locally. Storing it on the instance would make a
        # per-batch epsilon array leak into later calls, where its batch
        # dimension may no longer match.
        epsilon = self.epsilon if epsilons is None else epsilons

        device = self._model_device()
        X = np.copy(X_nat)

        X_var = torch.from_numpy(X).to(device).requires_grad_(True)
        y_var = torch.as_tensor(y, dtype=torch.long, device=device)

        loss = self.loss_fn(self.model(X_var), y_var)
        # autograd.grad yields only the input gradient, so the model's
        # parameter .grad buffers are not touched by the attack.
        grad, = torch.autograd.grad(loss, X_var)
        grad_sign = grad.detach().cpu().sign().numpy()

        X += epsilon * grad_sign
        return np.clip(X, 0, 1)
def jacobian(model, x, nb_classes=10):
    """Return the gradient of each class score with respect to the input.

    Args:
        model: network whose output is indexed as ``[:, class_ind]``, i.e.
            one row of class scores per input sample.
        x: numpy array holding the input batch fed to the model.
        nb_classes: number of output classes to differentiate.

    Returns:
        A list of ``nb_classes`` numpy arrays shaped like ``x``; entry ``k``
        is the derivative of class ``k``'s score (summed over the batch)
        with respect to ``x``.
    """
    first_param = next(iter(model.parameters()), None)
    device = torch.device('cpu') if first_param is None else first_param.device
    x_var = torch.from_numpy(x).to(device).requires_grad_(True)

    # One forward pass is enough; the graph is reused for every class.
    scores = model(x_var)

    list_derivatives = []
    for class_ind in range(nb_classes):
        # autograd.grad returns a fresh tensor on every call. Reading
        # x_var.grad through .numpy() and then zeroing it would share
        # memory on CPU and wipe every stored derivative.
        grad, = torch.autograd.grad(scores[:, class_ind].sum(), x_var,
                                    retain_graph=True)
        list_derivatives.append(grad.detach().cpu().numpy())

    return list_derivatives
def adv_train(X, y, model, criterion, adversary):
    """
    Adversarial training. Returns the perturbed mini batch.

    The attack runs against a frozen snapshot of ``model`` (deep copy, eval
    mode, gradients disabled on its parameters) so that computing the
    perturbation does not interfere with the optimization step on the live
    model. The snapshot is assigned to ``adversary.model``. ``criterion`` is
    accepted but not used here.
    """
    frozen = copy.deepcopy(model).eval()
    for weight in frozen.parameters():
        weight.requires_grad = False

    adversary.model = frozen

    return torch.from_numpy(adversary.perturb(X.numpy(), y))
def truncated_normal(mean=0.0, stddev=1.0, m=1):
    '''
    Draw m samples from a normal distribution truncated at two standard
    deviations: values farther than 2 * stddev from the mean are dropped
    and re-picked.

    Returns a float when m == 1, otherwise a numpy vector of length m.
    '''
    samples = np.empty(m)
    filled = 0
    while filled < m:
        draws = np.random.normal(mean, stddev, m - filled)
        # Measure the distance from the mean, not from zero. Comparing
        # abs(sample) alone truncates the wrong interval and practically
        # never accepts a draw once |mean| exceeds 2 * stddev.
        kept = draws[np.abs(draws - mean) <= 2 * stddev]
        samples[filled:filled + kept.size] = kept
        filled += kept.size

    if m == 1:
        return float(samples[0])
    return samples
def batch_indices(batch_nb, data_length, batch_size):
    """
    Compute the start and end index of a batch.

    :param batch_nb: the batch number
    :param data_length: the total length of the data being parsed by batches
    :param batch_size: the number of inputs in each batch
    :return: pair of (start, end) indices with 0 <= start and
             end <= data_length
    """
    start = int(batch_nb * batch_size)
    end = int((batch_nb + 1) * batch_size)

    # When there are not enough inputs left, slide the window back so that
    # earlier inputs are reused to complete the batch.
    if end > data_length:
        shift = end - data_length
        # Clamp at zero: with fewer inputs than one batch, a negative start
        # would make the slice wrap around and silently drop most of the data.
        start = max(start - shift, 0)
        end -= shift

    return start, end
'weight_decay': 5e-4, 27 | } 28 | 29 | 30 | # Data loaders 31 | train_dataset = datasets.MNIST(root='../data/',train=True, download=True, 32 | transform=transforms.ToTensor()) 33 | loader_train = torch.utils.data.DataLoader(train_dataset, 34 | batch_size=param['batch_size'], shuffle=True) 35 | 36 | test_dataset = datasets.MNIST(root='../data/', train=False, download=True, 37 | transform=transforms.ToTensor()) 38 | loader_test = torch.utils.data.DataLoader(test_dataset, 39 | batch_size=param['test_batch_size'], shuffle=True) 40 | 41 | 42 | # Setup the model 43 | net = LeNet5() 44 | 45 | if torch.cuda.is_available(): 46 | print('CUDA ensabled.') 47 | net.cuda() 48 | net.train() 49 | 50 | # Adversarial training setup 51 | #adversary = FGSMAttack(epsilon=0.3) 52 | adversary = LinfPGDAttack() 53 | 54 | # Train the model 55 | criterion = nn.CrossEntropyLoss() 56 | optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'], 57 | weight_decay=param['weight_decay']) 58 | 59 | for epoch in range(param['num_epochs']): 60 | 61 | print('Starting epoch %d / %d' % (epoch + 1, param['num_epochs'])) 62 | 63 | for t, (x, y) in enumerate(loader_train): 64 | 65 | x_var, y_var = to_var(x), to_var(y.long()) 66 | loss = criterion(net(x_var), y_var) 67 | 68 | # adversarial training 69 | if epoch+1 > param['delay']: 70 | # use predicted label to prevent label leaking 71 | y_pred = pred_batch(x, net) 72 | x_adv = adv_train(x, y_pred, net, criterion, adversary) 73 | x_adv_var = to_var(x_adv) 74 | loss_adv = criterion(net(x_adv_var), y_var) 75 | loss = (loss + loss_adv) / 2 76 | 77 | if (t + 1) % 100 == 0: 78 | print('t = %d, loss = %.8f' % (t + 1, loss.data[0])) 79 | 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | 85 | test(net, loader_test) 86 | 87 | torch.save(net.state_dict(), 'models/adv_trained_lenet5.pkl') 88 | -------------------------------------------------------------------------------- /mnist_attack.py: 
# Hyper-parameters
param = {
    'test_batch_size': 100,
    'epsilon': 0.3,
}


# Data loaders
test_dataset = datasets.MNIST(root='../data/', train=False, download=True,
    transform=transforms.ToTensor())
loader_test = torch.utils.data.DataLoader(test_dataset,
    batch_size=param['test_batch_size'], shuffle=False)


# Setup model to be attacked
net = LeNet5()
# The training script saves the checkpoint from the GPU when one is
# available. Map the tensors to CPU while loading so the file also opens on
# CPU-only hosts; net.cuda() below moves the weights back if CUDA exists.
net.load_state_dict(torch.load('models/adv_trained_lenet5.pkl',
                               map_location='cpu'))

if torch.cuda.is_available():
    print('CUDA ensabled.')
    net.cuda()

# The attack only needs gradients with respect to the input.
for p in net.parameters():
    p.requires_grad = False
net.eval()

# Accuracy on clean data
test(net, loader_test)


# Adversarial attack
adversary = FGSMAttack(net, param['epsilon'])
# Alternative attack: adversary = LinfPGDAttack(net, random_start=False)


# Accuracy on perturbed data, with wall-clock timing
t0 = time()
attack_over_test_data(net, adversary, param, loader_test)
print('{}s eclipsed.'.format(time()-t0))
def MNIST_bbox_sub(param, loader_hold_out, loader_test):
    """
    Train a substitute model using Jacobian data augmentation
    arXiv:1602.02697

    :param param: dict of hyper-parameters; the keys read here are
        'oracle_name', 'learning_rate', 'data_aug', 'nb_epochs',
        'test_batch_size' and 'hold_out_size'
    :param loader_hold_out: loader whose first batch seeds the substitute
        training set
    :param loader_test: loader used to report the substitute's accuracy

    Saves the substitute's state dict to param['oracle_name'] + '_sub.pkl'.
    """

    # Setup the substitute
    net = SubstituteModel()

    if torch.cuda.is_available():
        print('CUDA ensabled for the substitute.')
        net.cuda()
    net.train()

    # Setup the oracle
    oracle = LeNet5()

    if torch.cuda.is_available():
        print('CUDA ensabled for the oracle.')
        oracle.cuda()
    # Map to CPU while reading so a checkpoint saved from CUDA also loads on
    # a CPU-only host; load_state_dict copies into the module's own device.
    oracle.load_state_dict(torch.load(param['oracle_name']+'.pkl',
                                      map_location='cpu'))
    oracle.eval()

    # Setup training
    criterion = nn.CrossEntropyLoss()
    # Careful optimization is crucial to train a well-representative
    # substitute. In Tensorflow Adam has some problem:
    # (https://github.com/tensorflow/cleverhans/issues/183)
    # But it works fine here in PyTorch (you may try other optimization
    # methods)
    optimizer = torch.optim.Adam(net.parameters(), lr=param['learning_rate'])

    # Data held out for initial training. The builtin next() is used because
    # iterator objects are not guaranteed to expose a .next() method.
    X_sub, y_sub = next(iter(loader_hold_out))
    X_sub, y_sub = X_sub.numpy(), y_sub.numpy()

    # Train the substitute and augment dataset alternatively
    for rho in range(param['data_aug']):
        print("Substitute training epoch #"+str(rho))
        print("Training data: "+str(len(X_sub)))

        rng = np.random.RandomState()

        # model training
        for epoch in range(param['nb_epochs']):

            print('Starting epoch %d / %d' % (epoch + 1, param['nb_epochs']))

            # Compute number of batches
            nb_batches = int(np.ceil(float(len(X_sub)) /
                                     param['test_batch_size']))
            assert nb_batches * param['test_batch_size'] >= len(X_sub)

            # Indices to shuffle training set
            index_shuf = list(range(len(X_sub)))
            rng.shuffle(index_shuf)

            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_sub),
                                           param['test_batch_size'])

                x = X_sub[index_shuf[start:end]]
                y = y_sub[index_shuf[start:end]]

                scores = net(to_var(torch.from_numpy(x)))
                loss = criterion(scores, to_var(torch.from_numpy(y).long()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # NOTE(review): the original indentation is not visible here;
            # reporting the loss once per epoch and the accuracy once per
            # augmentation round is assumed - confirm against the upstream
            # file.
            # loss is a zero-dimensional tensor, so read it with .item();
            # indexing it with [0] is rejected by current PyTorch.
            print('loss = %.8f' % (loss.item()))
        test(net, loader_test, blackbox=True, hold_out_size=param['hold_out_size'])

        # If we are not at last substitute training iteration, augment dataset
        if rho < param['data_aug'] - 1:
            print("Augmenting substitute training data.")
            # Perform the Jacobian augmentation
            X_sub = jacobian_augmentation(net, X_sub, y_sub)

            print("Labeling substitute training data.")
            # Label the newly generated synthetic points using the black-box.
            # Only the labels are needed, so skip building an autograd graph.
            with torch.no_grad():
                scores = oracle(to_var(torch.from_numpy(X_sub)))
            # Note here that we take the argmax because the adversary
            # only has access to the label (not the probabilities) output
            # by the black-box model
            y_sub = np.argmax(scores.data.cpu().numpy(), axis=1)

    torch.save(net.state_dict(), param['oracle_name']+'_sub.pkl')
class LeNet5(nn.Module):
    """Two conv/ReLU/max-pool stages followed by two linear layers.

    The flattened feature size of 7*7*64 fixes the expected input to
    single-channel 28x28 images; the output has 10 scores per sample.
    Attribute names are part of the checkpoint format (state-dict keys).
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, then 2x2 pooling
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)
        # Stage 2: 32 -> 64 channels, then 2x2 pooling
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)
        # Classifier head
        self.linear1 = nn.Linear(7*7*64, 200)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(200, 10)

    def forward(self, x):
        """Map a batch of images to a batch of 10 class scores."""
        stage1 = self.conv1(x)
        stage1 = self.relu1(stage1)
        stage1 = self.maxpool1(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.relu2(stage2)
        stage2 = self.maxpool2(stage2)

        # Flatten everything except the batch dimension
        flat = stage2.view(stage2.size(0), -1)

        hidden = self.relu3(self.linear1(flat))
        return self.linear2(hidden)
__init__(self): 30 | super(SubstituteModel, self).__init__() 31 | self.linear1 = nn.Linear(28*28, 200) 32 | self.relu1 = nn.ReLU(inplace=True) 33 | self.linear2 = nn.Linear(200, 200) 34 | self.relu2 = nn.ReLU(inplace=True) 35 | self.linear3 = nn.Linear(200, 10) 36 | 37 | def forward(self, x): 38 | out = x.view(x.size(0), -1) 39 | out = self.relu1(self.linear1(out)) 40 | out = self.relu2(self.linear2(out)) 41 | out = self.linear3(out) 42 | return out 43 | -------------------------------------------------------------------------------- /models/adv_trained_lenet5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanglouis49/pytorch-adversarial_box/bddb5a899a7658182ea78063fd7ec405de083956/models/adv_trained_lenet5.pkl --------------------------------------------------------------------------------