├── .gitignore ├── README.md ├── adversary.py ├── cleaner.py ├── datasets └── datasets.py ├── main.py ├── misc ├── FGSM.PNG ├── IFGSM.PNG ├── nontargeted_1.PNG ├── nontargeted_2.PNG ├── nontargeted_3.PNG ├── overview.PNG ├── targetd_9_1.PNG ├── targetd_9_2.PNG └── targetd_9_3.PNG ├── models └── toynet.py ├── solver.py └── utils ├── utils.py └── visdom_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | checkpoints/* 4 | summary/* 5 | output/* 6 | 7 | datasets/MNIST 8 | 9 | git.sh 10 | .gitignore 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FGSM(Fast Gradient Sign Method) 2 |
3 | 4 | ### Overview 5 | A simple PyTorch implementation of FGSM and I-FGSM 6 | (FGSM: [explaining and harnessing adversarial examples, Goodfellow et al.]) 7 | (I-FGSM: [adversarial examples in the physical world, Kurakin et al.]) 8 | ![overview](misc/overview.PNG) 9 | #### FGSM 10 | ![FGSM](misc/FGSM.PNG) 11 | #### I-FGSM 12 | ![IFGSM](misc/IFGSM.PNG) 13 |
14 | 15 | ### Dependencies 16 | ``` 17 | python 3.6.4 18 | pytorch 0.3.1.post2 19 | visdom(optional) 20 | tensorboardX(optional) 21 | tensorflow(optional) 22 | ``` 23 |
24 | 25 | ### Usage 26 | 1. Train a simple MNIST classifier 27 | ``` 28 | python main.py --mode train --env_name [NAME] 29 | ``` 30 | 2. Load the trained classifier, generate adversarial examples, and check the results in the output directory (`--iteration 1` runs FGSM; a value greater than 1 runs I-FGSM) 31 | ``` 32 | python main.py --mode generate --iteration 1 --epsilon 0.03 --env_name [NAME] --load_ckpt best_acc.tar 33 | ``` 34 | 3. For a targeted attack, specify the target class with the `--target` argument (the default, -1, means a non-targeted attack) 35 | ``` 36 | python main.py --mode generate --iteration 1 --epsilon 0.03 --target 3 --env_name [NAME] --load_ckpt best_acc.tar 37 | ``` 38 |
39 | 40 | ### Results 41 | #### Non-targeted attack 42 | From left to right: legitimate examples, perturbed examples, and an indication of which perturbed images changed the classifier's prediction (red frame: prediction changed, green frame: unchanged) 43 | 1. non-targeted attack, iteration: 1, epsilon: 0.03 44 | ![non-targeted1](misc/nontargeted_1.PNG) 45 | 2. non-targeted attack, iteration: 5, epsilon: 0.03 46 | ![non-targeted2](misc/nontargeted_2.PNG) 47 | 3. non-targeted attack, iteration: 1, epsilon: 0.5 48 | ![non-targeted3](misc/nontargeted_3.PNG) 49 |
50 | 51 | #### Targeted attack 52 | From left to right: legitimate examples, perturbed examples, and an indication of which perturbed images made the classifier predict the target class (red frame: target reached, green frame: not reached) 53 | 1. targeted attack (target 9), iteration: 1, epsilon: 0.03 54 | ![targeted1](misc/targetd_9_1.PNG) 55 | 2. targeted attack (target 9), iteration: 5, epsilon: 0.03 56 | ![targeted2](misc/targetd_9_2.PNG) 57 | 3. targeted attack (target 9), iteration: 1, epsilon: 0.5 58 | ![targeted3](misc/targetd_9_3.PNG) 59 |
60 | 61 | ### References 62 | 1. explaining and harnessing adversarial examples, Goodfellow et al. 63 | 2. adversarial examples in the physical world, Kurakin et al. 64 | 65 | [explaining and harnessing adversarial examples, Goodfellow et al.]: https://arxiv.org/abs/1412.6572 66 | [adversarial examples in the physical world, Kurakin et al.]: http://arxiv.org/abs/1607.02533 67 | -------------------------------------------------------------------------------- /adversary.py: -------------------------------------------------------------------------------- 1 | """adversary.py""" 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torchvision.utils import save_image 9 | 10 | from models.toynet import ToyNet 11 | from datasets.datasets import return_data 12 | from utils.utils import rm_dir, cuda, where 13 | 14 | 15 | class Attack(object): 16 | def __init__(self, net, criterion): 17 | self.net = net 18 | self.criterion = criterion 19 | 20 | def fgsm(self, x, y, targeted=False, eps=0.03, x_val_min=-1, x_val_max=1): 21 | x_adv = Variable(x.data, requires_grad=True) 22 | h_adv = self.net(x_adv) 23 | if targeted: 24 | cost = self.criterion(h_adv, y) 25 | else: 26 | cost = -self.criterion(h_adv, y) 27 | 28 | self.net.zero_grad() 29 | if x_adv.grad is not None: 30 | x_adv.grad.data.fill_(0) 31 | cost.backward() 32 | 33 | x_adv.grad.sign_() 34 | x_adv = x_adv - eps*x_adv.grad 35 | x_adv = torch.clamp(x_adv, x_val_min, x_val_max) 36 | 37 | 38 | h = self.net(x) 39 | h_adv = self.net(x_adv) 40 | 41 | return x_adv, h_adv, h 42 | 43 | def i_fgsm(self, x, y, targeted=False, eps=0.03, alpha=1, iteration=1, x_val_min=-1, x_val_max=1): 44 | x_adv = Variable(x.data, requires_grad=True) 45 | for i in range(iteration): 46 | h_adv = self.net(x_adv) 47 | if targeted: 48 | cost = self.criterion(h_adv, y) 49 | else: 50 | cost = -self.criterion(h_adv, y) 51 | 52 | 
self.net.zero_grad() 53 | if x_adv.grad is not None: 54 | x_adv.grad.data.fill_(0) 55 | cost.backward() 56 | 57 | x_adv.grad.sign_() 58 | x_adv = x_adv - alpha*x_adv.grad 59 | x_adv = where(x_adv > x+eps, x+eps, x_adv) 60 | x_adv = where(x_adv < x-eps, x-eps, x_adv) 61 | x_adv = torch.clamp(x_adv, x_val_min, x_val_max) 62 | x_adv = Variable(x_adv.data, requires_grad=True) 63 | 64 | h = self.net(x) 65 | h_adv = self.net(x_adv) 66 | 67 | return x_adv, h_adv, h 68 | 69 | def universal(self, args): 70 | self.set_mode('eval') 71 | 72 | init = False 73 | 74 | correct = 0 75 | cost = 0 76 | total = 0 77 | 78 | data_loader = self.data_loader['test'] 79 | for e in range(100000): 80 | for batch_idx, (images, labels) in enumerate(data_loader): 81 | 82 | x = Variable(cuda(images, self.cuda)) 83 | y = Variable(cuda(labels, self.cuda)) 84 | 85 | if not init: 86 | sz = x.size()[1:] 87 | r = torch.zeros(sz) 88 | r = Variable(cuda(r, self.cuda), requires_grad=True) 89 | init = True 90 | 91 | logit = self.net(x+r) 92 | p_ygx = F.softmax(logit, dim=1) 93 | H_ygx = (-p_ygx*torch.log(self.eps+p_ygx)).sum(1).mean(0) 94 | prediction_cost = H_ygx 95 | #prediction_cost = F.cross_entropy(logit,y) 96 | #perceptual_cost = -F.l1_loss(x+r,x) 97 | #perceptual_cost = -F.mse_loss(x+r,x) 98 | #perceptual_cost = -F.mse_loss(x+r,x) -r.norm() 99 | perceptual_cost = -F.mse_loss(x+r, x) -F.relu(r.norm()-5) 100 | #perceptual_cost = -F.relu(r.norm()-5.) 
101 | #if perceptual_cost.data[0] < 10: perceptual_cost.data.fill_(0) 102 | cost = prediction_cost + perceptual_cost 103 | #cost = prediction_cost 104 | 105 | self.net.zero_grad() 106 | if r.grad: 107 | r.grad.fill_(0) 108 | cost.backward() 109 | 110 | #r = r + args.eps*r.grad.sign() 111 | r = r + r.grad*1e-1 112 | r = Variable(cuda(r.data, self.cuda), requires_grad=True) 113 | 114 | 115 | 116 | prediction = logit.max(1)[1] 117 | correct = torch.eq(prediction, y).float().mean().data[0] 118 | if batch_idx % 100 == 0: 119 | if self.visdom: 120 | self.vf.imshow_multi(x.add(r).data) 121 | #self.vf.imshow_multi(r.unsqueeze(0).data,factor=4) 122 | print(correct*100, prediction_cost.data[0], perceptual_cost.data[0],\ 123 | r.norm().data[0]) 124 | 125 | self.set_mode('train') 126 | -------------------------------------------------------------------------------- /cleaner.py: -------------------------------------------------------------------------------- 1 | """cleaner.py""" 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | from utils.utils import rm_dir 7 | 8 | 9 | def clean(args): 10 | """Remove directories relevant to specified experiment name given as env_name""" 11 | 12 | env_name = args.env_name 13 | 14 | ckpt_dir = Path(args.ckpt_dir).joinpath(env_name) 15 | summary_dir = Path(args.summary_dir).joinpath(env_name) 16 | output_dir = Path(args.output_dir).joinpath(env_name) 17 | 18 | rm_dir(ckpt_dir) 19 | rm_dir(summary_dir) 20 | rm_dir(output_dir) 21 | 22 | print('[*] Cleaning Finished ! 
') 23 | 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--env_name', type=str, required=True) 29 | parser.add_argument('--ckpt_dir', type=str, default='checkpoints') 30 | parser.add_argument('--summary_dir', type=str, default='summary') 31 | parser.add_argument('--output_dir', type=str, default='output') 32 | args = parser.parse_args() 33 | 34 | clean(args) 35 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | """datasets.py""" 2 | import os 3 | 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms 6 | from torchvision.datasets import MNIST 7 | 8 | 9 | class UnknownDatasetError(Exception): 10 | def __str__(self): 11 | return "unknown datasets error" 12 | 13 | 14 | def return_data(args): 15 | name = args.dataset 16 | dset_dir = args.dset_dir 17 | batch_size = args.batch_size 18 | transform = transforms.Compose([transforms.ToTensor(), 19 | transforms.Normalize((0.5,), (0.5,)), 20 | ]) 21 | 22 | if 'MNIST' in name: 23 | root = os.path.join(dset_dir, 'MNIST') 24 | train_kwargs = {'root':root, 'train':True, 'transform':transform, 'download':True} 25 | test_kwargs = {'root':root, 'train':False, 'transform':transform, 'download':False} 26 | dset = MNIST 27 | 28 | else: 29 | raise UnknownDatasetError() 30 | 31 | train_data = dset(**train_kwargs) 32 | train_loader = DataLoader(train_data, 33 | batch_size=batch_size, 34 | shuffle=True, 35 | num_workers=1, 36 | pin_memory=True, 37 | drop_last=True) 38 | 39 | test_data = dset(**test_kwargs) 40 | test_loader = DataLoader(test_data, 41 | batch_size=batch_size, 42 | shuffle=False, 43 | num_workers=1, 44 | pin_memory=True, 45 | drop_last=False) 46 | 47 | data_loader = dict() 48 | data_loader['train'] = train_loader 49 | data_loader['test'] = test_loader 50 | 51 | return data_loader 52 | 53 | 54 | if 
__name__ == '__main__': 55 | import argparse 56 | os.chdir('..') 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--dataset', type=str, default='MNIST') 60 | parser.add_argument('--dset_dir', type=str, default='datasets') 61 | parser.add_argument('--batch_size', type=int, default=64) 62 | args = parser.parse_args() 63 | 64 | data_loader = return_data(args) 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """main.py""" 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from solver import Solver 8 | from utils.utils import str2bool 9 | 10 | def main(args): 11 | 12 | torch.backends.cudnn.enabled = True 13 | torch.backends.cudnn.benchmark = True 14 | 15 | seed = args.seed 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | np.random.seed(seed) 19 | 20 | np.set_printoptions(precision=4) 21 | torch.set_printoptions(precision=4) 22 | 23 | print() 24 | print('[ARGUMENTS]') 25 | print(args) 26 | print() 27 | 28 | net = Solver(args) 29 | 30 | if args.mode == 'train': 31 | net.train() 32 | elif args.mode == 'test': 33 | net.test() 34 | elif args.mode == 'generate': 35 | net.generate(num_sample=args.batch_size, 36 | target=args.target, 37 | epsilon=args.epsilon, 38 | alpha=args.alpha, 39 | iteration=args.iteration) 40 | elif args.mode == 'universal': 41 | net.universal(args) 42 | else: return 43 | 44 | print('[*] Finished') 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | parser = argparse.ArgumentParser(description='toynet template') 50 | parser.add_argument('--epoch', type=int, default=20, help='epoch size') 51 | parser.add_argument('--batch_size', type=int, default=100, help='mini-batch size') 52 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate') 53 | parser.add_argument('--y_dim', type=int, default=10, help='the number of classes') 54 | 
parser.add_argument('--target', type=int, default=-1, help='target class for targeted generation') 55 | parser.add_argument('--eps', type=float, default=1e-9, help='epsilon') 56 | parser.add_argument('--env_name', type=str, default='main', help='experiment name') 57 | parser.add_argument('--dataset', type=str, default='FMNIST', help='dataset type') 58 | parser.add_argument('--dset_dir', type=str, default='datasets', help='dataset directory path') 59 | parser.add_argument('--summary_dir', type=str, default='summary', help='summary directory path') 60 | parser.add_argument('--output_dir', type=str, default='output', help='output directory path') 61 | parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='checkpoint directory path') 62 | parser.add_argument('--load_ckpt', type=str, default='', help='') 63 | parser.add_argument('--cuda', type=str2bool, default=True, help='enable cuda') 64 | parser.add_argument('--silent', type=str2bool, default=False, help='') 65 | parser.add_argument('--mode', type=str, default='train', help='train / test / generate / universal') 66 | parser.add_argument('--seed', type=int, default=1, help='random seed') 67 | parser.add_argument('--iteration', type=int, default=1, help='the number of iteration for FGSM') 68 | parser.add_argument('--epsilon', type=float, default=0.03, help='epsilon for FGSM and i-FGSM') 69 | parser.add_argument('--alpha', type=float, default=2/255, help='alpha for i-FGSM') 70 | parser.add_argument('--tensorboard', type=str2bool, default=False, help='enable tensorboard') 71 | parser.add_argument('--visdom', type=str2bool, default=False, help='enable visdom') 72 | parser.add_argument('--visdom_port', type=str, default=55558, help='visdom port') 73 | args = parser.parse_args() 74 | 75 | main(args) 76 | -------------------------------------------------------------------------------- /misc/FGSM.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/FGSM.PNG -------------------------------------------------------------------------------- /misc/IFGSM.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/IFGSM.PNG -------------------------------------------------------------------------------- /misc/nontargeted_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_1.PNG -------------------------------------------------------------------------------- /misc/nontargeted_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_2.PNG -------------------------------------------------------------------------------- /misc/nontargeted_3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_3.PNG -------------------------------------------------------------------------------- /misc/overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/overview.PNG -------------------------------------------------------------------------------- /misc/targetd_9_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_1.PNG -------------------------------------------------------------------------------- /misc/targetd_9_2.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_2.PNG -------------------------------------------------------------------------------- /misc/targetd_9_3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_3.PNG -------------------------------------------------------------------------------- /models/toynet.py: -------------------------------------------------------------------------------- 1 | """toynet.py""" 2 | import torch.nn as nn 3 | 4 | class ToyNet(nn.Module): 5 | def __init__(self, x_dim=784, y_dim=10): 6 | super(ToyNet, self).__init__() 7 | self.x_dim = x_dim 8 | self.y_dim = y_dim 9 | 10 | self.mlp = nn.Sequential( 11 | nn.Linear(self.x_dim, 300), 12 | nn.ReLU(True), 13 | nn.Linear(300, 150), 14 | nn.ReLU(True), 15 | nn.Linear(150, self.y_dim) 16 | ) 17 | 18 | def forward(self, X): 19 | if X.dim() > 2: 20 | X = X.view(X.size(0), -1) 21 | out = self.mlp(X) 22 | 23 | return out 24 | 25 | def weight_init(self, _type='kaiming'): 26 | if _type == 'kaiming': 27 | for ms in self._modules: 28 | kaiming_init(self._modules[ms].parameters()) 29 | 30 | 31 | def xavier_init(ms): 32 | for m in ms: 33 | if isinstance(m, (nn.Linear, nn.Conv2d)): 34 | nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu')) 35 | if m.bias.data: 36 | m.bias.data.zero_() 37 | if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 38 | m.weight.data.fill_(1) 39 | if m.bias.data: 40 | m.bias.data.zero_() 41 | 42 | 43 | def kaiming_init(ms): 44 | for m in ms: 45 | if isinstance(m, (nn.Linear, nn.Conv2d)): 46 | nn.init.kaiming_uniform(m.weight, a=0, mode='fan_in') 47 | if m.bias.data: 48 | m.bias.data.zero_() 49 | if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 50 | m.weight.data.fill_(1) 51 | if m.bias.data: 52 | 
m.bias.data.zero_() 53 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | """solver.py""" 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torchvision.utils import save_image 9 | 10 | from models.toynet import ToyNet 11 | from datasets.datasets import return_data 12 | from utils.utils import rm_dir, cuda, where 13 | from adversary import Attack 14 | 15 | 16 | class Solver(object): 17 | def __init__(self, args): 18 | self.args = args 19 | 20 | # Basic 21 | self.cuda = (args.cuda and torch.cuda.is_available()) 22 | self.epoch = args.epoch 23 | self.batch_size = args.batch_size 24 | self.eps = args.eps 25 | self.lr = args.lr 26 | self.y_dim = args.y_dim 27 | self.target = args.target 28 | self.dataset = args.dataset 29 | self.data_loader = return_data(args) 30 | self.global_epoch = 0 31 | self.global_iter = 0 32 | self.print_ = not args.silent 33 | 34 | self.env_name = args.env_name 35 | self.tensorboard = args.tensorboard 36 | self.visdom = args.visdom 37 | 38 | self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name) 39 | if not self.ckpt_dir.exists(): 40 | self.ckpt_dir.mkdir(parents=True, exist_ok=True) 41 | self.output_dir = Path(args.output_dir).joinpath(args.env_name) 42 | if not self.output_dir.exists(): 43 | self.output_dir.mkdir(parents=True, exist_ok=True) 44 | 45 | # Visualization Tools 46 | self.visualization_init(args) 47 | 48 | # Histories 49 | self.history = dict() 50 | self.history['acc'] = 0. 
51 | self.history['epoch'] = 0 52 | self.history['iter'] = 0 53 | 54 | # Models & Optimizers 55 | self.model_init(args) 56 | self.load_ckpt = args.load_ckpt 57 | if self.load_ckpt != '': 58 | self.load_checkpoint(self.load_ckpt) 59 | 60 | # Adversarial Perturbation Generator 61 | #criterion = cuda(torch.nn.CrossEntropyLoss(), self.cuda) 62 | criterion = F.cross_entropy 63 | self.attack = Attack(self.net, criterion=criterion) 64 | 65 | def visualization_init(self, args): 66 | # Visdom 67 | if self.visdom: 68 | from utils.visdom_utils import VisFunc 69 | self.port = args.visdom_port 70 | self.vf = VisFunc(enval=self.env_name, port=self.port) 71 | 72 | # TensorboardX 73 | if self.tensorboard: 74 | from tensorboardX import SummaryWriter 75 | self.summary_dir = Path(args.summary_dir).joinpath(args.env_name) 76 | if not self.summary_dir.exists(): 77 | self.summary_dir.mkdir(parents=True, exist_ok=True) 78 | 79 | self.tf = SummaryWriter(log_dir=str(self.summary_dir)) 80 | self.tf.add_text(tag='argument', text_string=str(args), global_step=self.global_epoch) 81 | 82 | def model_init(self, args): 83 | # Network 84 | self.net = cuda(ToyNet(y_dim=self.y_dim), self.cuda) 85 | self.net.weight_init(_type='kaiming') 86 | 87 | # Optimizers 88 | self.optim = optim.Adam([{'params':self.net.parameters(), 'lr':self.lr}], 89 | betas=(0.5, 0.999)) 90 | 91 | def train(self): 92 | self.set_mode('train') 93 | for e in range(self.epoch): 94 | self.global_epoch += 1 95 | 96 | correct = 0. 97 | cost = 0. 98 | total = 0. 
99 | for batch_idx, (images, labels) in enumerate(self.data_loader['train']): 100 | self.global_iter += 1 101 | 102 | x = Variable(cuda(images, self.cuda)) 103 | y = Variable(cuda(labels, self.cuda)) 104 | 105 | logit = self.net(x) 106 | prediction = logit.max(1)[1] 107 | 108 | correct = torch.eq(prediction, y).float().mean().data[0] 109 | cost = F.cross_entropy(logit, y) 110 | 111 | self.optim.zero_grad() 112 | cost.backward() 113 | self.optim.step() 114 | 115 | if batch_idx % 100 == 0: 116 | if self.print_: 117 | print() 118 | print(self.env_name) 119 | print('[{:03d}:{:03d}]'.format(self.global_epoch, batch_idx)) 120 | print('acc:{:.3f} loss:{:.3f}'.format(correct, cost.data[0])) 121 | 122 | 123 | if self.tensorboard: 124 | self.tf.add_scalars(main_tag='performance/acc', 125 | tag_scalar_dict={'train':correct}, 126 | global_step=self.global_iter) 127 | self.tf.add_scalars(main_tag='performance/error', 128 | tag_scalar_dict={'train':1-correct}, 129 | global_step=self.global_iter) 130 | self.tf.add_scalars(main_tag='performance/cost', 131 | tag_scalar_dict={'train':cost.data[0]}, 132 | global_step=self.global_iter) 133 | 134 | 135 | self.test() 136 | 137 | 138 | if self.tensorboard: 139 | self.tf.add_scalars(main_tag='performance/best/acc', 140 | tag_scalar_dict={'test':self.history['acc']}, 141 | global_step=self.history['iter']) 142 | print(" [*] Training Finished!") 143 | 144 | def test(self): 145 | self.set_mode('eval') 146 | 147 | correct = 0. 148 | cost = 0. 149 | total = 0. 
150 | 151 | data_loader = self.data_loader['test'] 152 | for batch_idx, (images, labels) in enumerate(data_loader): 153 | x = Variable(cuda(images, self.cuda)) 154 | y = Variable(cuda(labels, self.cuda)) 155 | 156 | logit = self.net(x) 157 | prediction = logit.max(1)[1] 158 | 159 | correct += torch.eq(prediction, y).float().sum().data[0] 160 | cost += F.cross_entropy(logit, y, size_average=False).data[0] 161 | total += x.size(0) 162 | 163 | accuracy = correct / total 164 | cost /= total 165 | 166 | 167 | if self.print_: 168 | print() 169 | print('[{:03d}]\nTEST RESULT'.format(self.global_epoch)) 170 | print('ACC:{:.4f}'.format(accuracy)) 171 | print('*TOP* ACC:{:.4f} at e:{:03d}'.format(accuracy, self.global_epoch,)) 172 | print() 173 | 174 | if self.tensorboard: 175 | self.tf.add_scalars(main_tag='performance/acc', 176 | tag_scalar_dict={'test':accuracy}, 177 | global_step=self.global_iter) 178 | 179 | self.tf.add_scalars(main_tag='performance/error', 180 | tag_scalar_dict={'test':(1-accuracy)}, 181 | global_step=self.global_iter) 182 | 183 | self.tf.add_scalars(main_tag='performance/cost', 184 | tag_scalar_dict={'test':cost}, 185 | global_step=self.global_iter) 186 | 187 | if self.history['acc'] < accuracy: 188 | self.history['acc'] = accuracy 189 | self.history['epoch'] = self.global_epoch 190 | self.history['iter'] = self.global_iter 191 | self.save_checkpoint('best_acc.tar') 192 | 193 | self.set_mode('train') 194 | 195 | def generate(self, num_sample=100, target=-1, epsilon=0.03, alpha=2/255, iteration=1): 196 | self.set_mode('eval') 197 | 198 | x_true, y_true = self.sample_data(num_sample) 199 | if isinstance(target, int) and (target in range(self.y_dim)): 200 | y_target = torch.LongTensor(y_true.size()).fill_(target) 201 | else: 202 | y_target = None 203 | 204 | x_adv, changed, values = self.FGSM(x_true, y_true, y_target, epsilon, alpha, iteration) 205 | accuracy, cost, accuracy_adv, cost_adv = values 206 | 207 | save_image(x_true, 208 | 
self.output_dir.joinpath('legitimate(t:{},e:{},i:{}).jpg'.format(target, 209 | epsilon, 210 | iteration)), 211 | nrow=10, 212 | padding=2, 213 | pad_value=0.5) 214 | save_image(x_adv, 215 | self.output_dir.joinpath('perturbed(t:{},e:{},i:{}).jpg'.format(target, 216 | epsilon, 217 | iteration)), 218 | nrow=10, 219 | padding=2, 220 | pad_value=0.5) 221 | save_image(changed, 222 | self.output_dir.joinpath('changed(t:{},e:{},i:{}).jpg'.format(target, 223 | epsilon, 224 | iteration)), 225 | nrow=10, 226 | padding=3, 227 | pad_value=0.5) 228 | 229 | if self.visdom: 230 | self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5) 231 | self.vf.imshow_multi(x_adv.cpu(), title='perturbed(e:{},i:{})'.format(epsilon, iteration), factor=1.5) 232 | self.vf.imshow_multi(changed.cpu(), title='changed(white)'.format(epsilon), factor=1.5) 233 | 234 | print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(accuracy, cost)) 235 | print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(accuracy_adv, cost_adv)) 236 | 237 | self.set_mode('train') 238 | 239 | def sample_data(self, num_sample=100): 240 | 241 | total = len(self.data_loader['test'].dataset) 242 | seed = torch.FloatTensor(num_sample).uniform_(1, total).long() 243 | 244 | x = self.data_loader['test'].dataset.test_data[seed] 245 | x = self.scale(x.float().unsqueeze(1).div(255)) 246 | y = self.data_loader['test'].dataset.test_labels[seed] 247 | 248 | return x, y 249 | 250 | 251 | def FGSM(self, x, y_true, y_target=None, eps=0.03, alpha=2/255, iteration=1): 252 | self.set_mode('eval') 253 | 254 | x = Variable(cuda(x, self.cuda), requires_grad=True) 255 | y_true = Variable(cuda(y_true, self.cuda), requires_grad=False) 256 | if y_target is not None: 257 | targeted = True 258 | y_target = Variable(cuda(y_target, self.cuda), requires_grad=False) 259 | else: 260 | targeted = False 261 | 262 | 263 | h = self.net(x) 264 | prediction = h.max(1)[1] 265 | accuracy = torch.eq(prediction, y_true).float().mean() 266 | cost = 
F.cross_entropy(h, y_true) 267 | 268 | if iteration == 1: 269 | if targeted: 270 | x_adv, h_adv, h = self.attack.fgsm(x, y_target, True, eps) 271 | else: 272 | x_adv, h_adv, h = self.attack.fgsm(x, y_true, False, eps) 273 | else: 274 | if targeted: 275 | x_adv, h_adv, h = self.attack.i_fgsm(x, y_target, True, eps, alpha, iteration) 276 | else: 277 | x_adv, h_adv, h = self.attack.i_fgsm(x, y_true, False, eps, alpha, iteration) 278 | 279 | prediction_adv = h_adv.max(1)[1] 280 | accuracy_adv = torch.eq(prediction_adv, y_true).float().mean() 281 | cost_adv = F.cross_entropy(h_adv, y_true) 282 | 283 | # make indication of perturbed images that changed predictions of the classifier 284 | if targeted: 285 | changed = torch.eq(y_target, prediction_adv) 286 | else: 287 | changed = torch.eq(prediction, prediction_adv) 288 | changed = torch.eq(changed, 0) 289 | changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 28, 28) 290 | 291 | changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91) 292 | changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252) 293 | changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25) 294 | changed = self.scale(changed/255) 295 | changed[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2, 3:-2] 296 | 297 | self.set_mode('train') 298 | 299 | return x_adv.data, changed.data,\ 300 | (accuracy.data[0], cost.data[0], accuracy_adv.data[0], cost_adv.data[0]) 301 | 302 | def save_checkpoint(self, filename='ckpt.tar'): 303 | model_states = { 304 | 'net':self.net.state_dict(), 305 | } 306 | optim_states = { 307 | 'optim':self.optim.state_dict(), 308 | } 309 | states = { 310 | 'iter':self.global_iter, 311 | 'epoch':self.global_epoch, 312 | 'history':self.history, 313 | 'args':self.args, 314 | 'model_states':model_states, 315 | 'optim_states':optim_states, 316 | } 317 | 318 | file_path = self.ckpt_dir / filename 319 | torch.save(states, file_path.open('wb+')) 320 | print("=> saved checkpoint '{}' (iter {})".format(file_path, 
self.global_iter)) 321 | 322 | def load_checkpoint(self, filename='best_acc.tar'): 323 | file_path = self.ckpt_dir / filename 324 | if file_path.is_file(): 325 | print("=> loading checkpoint '{}'".format(file_path)) 326 | checkpoint = torch.load(file_path.open('rb')) 327 | self.global_epoch = checkpoint['epoch'] 328 | self.global_iter = checkpoint['iter'] 329 | self.history = checkpoint['history'] 330 | 331 | self.net.load_state_dict(checkpoint['model_states']['net']) 332 | self.optim.load_state_dict(checkpoint['optim_states']['optim']) 333 | 334 | print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter)) 335 | 336 | else: 337 | print("=> no checkpoint found at '{}'".format(file_path)) 338 | 339 | def set_mode(self, mode='train'): 340 | if mode == 'train': 341 | self.net.train() 342 | elif mode == 'eval': 343 | self.net.eval() 344 | else: raise('mode error. It should be either train or eval') 345 | 346 | def scale(self, image): 347 | return image.mul(2).add(-1) 348 | 349 | def unscale(self, image): 350 | return image.add(1).mul(0.5) 351 | 352 | def summary_flush(self, silent=True): 353 | rm_dir(self.summary_dir, silent) 354 | 355 | def checkpoint_flush(self, silent=True): 356 | rm_dir(self.ckpt_dir, silent) 357 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import argparse, torch 2 | import numpy as np 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from pathlib import Path 6 | 7 | class One_Hot(nn.Module): 8 | # from : 9 | # https://lirnli.wordpress.com/2017/09/03/one-hot-encoding-in-pytorch/ 10 | def __init__(self, depth): 11 | super(One_Hot,self).__init__() 12 | self.depth = depth 13 | self.ones = torch.sparse.torch.eye(depth) 14 | def forward(self, X_in): 15 | X_in = X_in.long() 16 | return Variable(self.ones.index_select(0,X_in.data)) 17 | def __repr__(self): 18 | return 
self.__class__.__name__ + "({})".format(self.depth) 19 | 20 | 21 | def cuda(tensor,is_cuda): 22 | if is_cuda : return tensor.cuda() 23 | else : return tensor 24 | 25 | def str2bool(v): 26 | # codes from : https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 27 | 28 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 29 | return True 30 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 31 | return False 32 | else: 33 | raise argparse.ArgumentTypeError('Boolean value expected.') 34 | 35 | 36 | def print_network(net): 37 | num_params = 0 38 | for param in net.parameters(): 39 | num_params += param.numel() 40 | print(net) 41 | print('Total number of parameters: %d' % num_params) 42 | 43 | 44 | def rm_dir(dir_path, silent=True): 45 | p = Path(dir_path).resolve() 46 | if (not p.is_file()) and (not p.is_dir()) : 47 | print('It is not path for file nor directory :',p) 48 | return 49 | 50 | paths = list(p.iterdir()) 51 | if (len(paths) == 0) and p.is_dir() : 52 | p.rmdir() 53 | if not silent : print('removed empty dir :',p) 54 | 55 | else : 56 | for path in paths : 57 | if path.is_file() : 58 | path.unlink() 59 | if not silent : print('removed file :',path) 60 | else: 61 | rm_dir(path) 62 | p.rmdir() 63 | if not silent : print('removed empty dir :',p) 64 | 65 | def where(cond, x, y): 66 | """ 67 | code from : 68 | https://discuss.pytorch.org/t/how-can-i-do-the-operation-the-same-as-np-where/1329/8 69 | """ 70 | cond = cond.float() 71 | return (cond*x) + ((1-cond)*y) 72 | -------------------------------------------------------------------------------- /utils/visdom_utils.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | from scipy.misc import imresize 3 | import numpy as np 4 | from torchvision.utils import make_grid 5 | 6 | class VisFunc(object): 7 | 8 | def __init__(self, config=None, vis=None, enval='hproto',port=8097): 9 | self.config = config 10 | self.vis = visdom.Visdom(env=enval, 
port=port) 11 | self.win = None 12 | self.win2 = None 13 | self.epoch_list = [] 14 | self.train_loss_list = [] 15 | self.val_loss_list = [] 16 | self.epoch_list2 = [] 17 | self.train_acc_list = [] 18 | self.val_acc_list = [] 19 | 20 | 21 | def imshow(self, img, title=' ', caption=' ', factor=1): 22 | 23 | img = img / 2 + 0.5 # Unnormalize 24 | npimg = img.numpy() 25 | obj = np.transpose(npimg, (1,2,0)) 26 | obj = np.swapaxes(obj,0,2) 27 | obj = np.swapaxes(obj,1,2) 28 | 29 | imgsize = tuple((np.array(obj.shape[1:])*factor).astype(int)) 30 | rgbArray = np.zeros(tuple([3])+imgsize,'float32') 31 | rgbArray[0,...] = imresize(obj[0,:,:],imgsize,'cubic') 32 | rgbArray[1,...] = imresize(obj[1,:,:],imgsize,'cubic') 33 | rgbArray[2,...] = imresize(obj[2,:,:],imgsize,'cubic') 34 | 35 | self.vis.image( rgbArray, 36 | opts=dict(title=title, caption=caption), 37 | ) 38 | 39 | 40 | def imshow_multi(self, imgs, nrow=10, title=' ', caption=' ', factor=1): 41 | #self.imshow( make_grid(imgs,nrow,padding=padding), title, caption, factor) 42 | self.imshow( make_grid(imgs,nrow), title, caption, factor) 43 | 44 | 45 | def imshow_one_batch(self, loader, classes=None, factor=1): 46 | dataiter = iter(loader) 47 | images, labels = dataiter.next() 48 | self.imshow(make_grid(images,padding)) 49 | 50 | if classes: 51 | print(' '.join('%5s' % classes[labels[j]] 52 | for j in range(loader.batch_size))) 53 | else: 54 | print(' '.join('%5s' % labels[j] 55 | for j in range(loader.batch_size))) 56 | 57 | 58 | def plot(self, epoch, train_loss, val_loss,Des): 59 | ''' plot learning curve interactively with visdom ''' 60 | self.epoch_list.append(epoch) 61 | self.train_loss_list.append(train_loss) 62 | self.val_loss_list.append(val_loss) 63 | 64 | if not self.win: 65 | # send line plot 66 | # embed() 67 | self.win = self.vis.line( 68 | X=np.array(self.epoch_list), 69 | Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]), 70 | opts=dict( 71 | title='Learning Curve (' + Des +')', 72 | 
xlabel='Epoch', 73 | ylabel='Loss', 74 | legend=['train_loss', 'val_loss'], 75 | #caption=Des 76 | )) 77 | # send text memo (configuration) 78 | # self.vis.text(str(Des)) 79 | else: 80 | self.vis.updateTrace( 81 | X=np.array(self.epoch_list[-2:]), 82 | Y=np.array(self.train_loss_list[-2:]), 83 | win=self.win, 84 | name='train_loss', 85 | ) 86 | self.vis.updateTrace( 87 | X=np.array(self.epoch_list[-2:]), 88 | Y=np.array(self.val_loss_list[-2:]), 89 | win=self.win, 90 | name='val_loss', 91 | ) 92 | 93 | 94 | def acc_plot(self, epoch, train_acc, val_acc, Des): 95 | ''' plot learning curve interactively with visdom ''' 96 | self.epoch_list2.append(epoch) 97 | self.train_acc_list.append(train_acc) 98 | self.val_acc_list.append(val_acc) 99 | 100 | if not self.win2: 101 | # send line plot 102 | # embed() 103 | self.win2 = self.vis.line( 104 | X=np.array(self.epoch_list2), 105 | Y=np.array([[self.train_acc_list[-1], self.val_acc_list[-1]]]), 106 | opts=dict( 107 | title='Accuracy Curve (' + Des +')', 108 | xlabel='Epoch', 109 | ylabel='Accuracy', 110 | legend=['train_accuracy', 'val_accuracy'] 111 | )) 112 | # send text memo (configuration) 113 | # self.vis.text(str(self.config)) 114 | else: 115 | self.vis.updateTrace( 116 | X=np.array(self.epoch_list2[-2:]), 117 | Y=np.array(self.train_acc_list[-2:]), 118 | win=self.win2, 119 | name='train_accuracy', 120 | ) 121 | self.vis.updateTrace( 122 | X=np.array(self.epoch_list2[-2:]), 123 | Y=np.array(self.val_acc_list[-2:]), 124 | win=self.win2, 125 | name='val_accuracy', 126 | ) 127 | 128 | 129 | def plot2(self, epoch, train_loss, val_loss,Des, win): 130 | ''' plot learning curve interactively with visdom ''' 131 | self.epoch_list.append(epoch) 132 | self.train_loss_list.append(train_loss) 133 | self.val_loss_list.append(val_loss) 134 | 135 | if not self.win: 136 | self.win = win 137 | # send line plot 138 | # embed() 139 | #self.win = self.vis.line( 140 | # X=np.array(self.epoch_list), 141 | # 
Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]), 142 | # opts=dict( 143 | # title='Learning Curve (' + Des +')', 144 | # xlabel='Epoch', 145 | # ylabel='Loss', 146 | # legend=['train_loss', 'val_loss'], 147 | # #caption=Des 148 | # )) 149 | ## send text memo (configuration) 150 | # self.vis.text(str(Des)) 151 | else: 152 | self.vis.updateTrace( 153 | X=np.array(self.epoch_list[-2:]), 154 | Y=np.array(self.train_loss_list[-2:]), 155 | win=self.win, 156 | name='train_loss2', 157 | ) 158 | self.vis.updateTrace( 159 | X=np.array(self.epoch_list[-2:]), 160 | Y=np.array(self.val_loss_list[-2:]), 161 | win=self.win, 162 | name='val_lossi2', 163 | ) 164 | --------------------------------------------------------------------------------