├── .gitignore
├── README.md
├── adversary.py
├── cleaner.py
├── datasets
│   └── datasets.py
├── main.py
├── misc
│   ├── FGSM.PNG
│   ├── IFGSM.PNG
│   ├── nontargeted_1.PNG
│   ├── nontargeted_2.PNG
│   ├── nontargeted_3.PNG
│   ├── overview.PNG
│   ├── targetd_9_1.PNG
│   ├── targetd_9_2.PNG
│   └── targetd_9_3.PNG
├── models
│   └── toynet.py
├── solver.py
└── utils
    ├── utils.py
    └── visdom_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
3 | checkpoints/*
4 | summary/*
5 | output/*
6 |
7 | datasets/MNIST
8 |
9 | git.sh
10 | .gitignore
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FGSM (Fast Gradient Sign Method)
2 |
3 |
4 | ### Overview
5 | A simple PyTorch implementation of FGSM and I-FGSM.
6 | (FGSM: [explaining and harnessing adversarial examples, Goodfellow et al.])
7 | (I-FGSM: [adversarial examples in the physical world, Kurakin et al.])
8 | 
9 | #### FGSM
10 | 
11 | #### I-FGSM
12 | 
13 |
14 |
15 | ### Dependencies
16 | ```
17 | python 3.6.4
18 | pytorch 0.3.1.post2
19 | visdom(optional)
20 | tensorboardX(optional)
21 | tensorflow(optional)
22 | ```
23 |
24 |
25 | ### Usage
26 | 1. Train a simple MNIST classifier:
27 | ```
28 | python main.py --mode train --env_name [NAME]
29 | ```
30 | 2. Load the trained classifier, generate adversarial examples, and inspect the results in the output directory:
31 | ```
32 | python main.py --mode generate --iteration 1 --epsilon 0.03 --env_name [NAME] --load_ckpt best_acc.tar
33 | ```
34 | 3. For a targeted attack, specify the target class via the ```--target``` argument (the default of -1 runs a non-targeted attack):
35 | ```
36 | python main.py --mode generate --iteration 1 --epsilon 0.03 --target 3 --env_name [NAME] --load_ckpt best_acc.tar
37 | ```
38 |
39 |
40 | ### Results
41 | #### Non-targeted attack
42 | From left to right: legitimate examples, perturbed examples, and an indicator marking which perturbed images changed the classifier's prediction.
43 | 1. non-targeted attack, iteration: 1, epsilon: 0.03
44 | 
45 | 2. non-targeted attack, iteration: 5, epsilon: 0.03
46 | 
47 | 3. non-targeted attack, iteration: 1, epsilon: 0.5
48 | 
49 |
50 |
51 | #### Targeted attack
52 | From left to right: legitimate examples, perturbed examples, and an indicator marking which perturbed images the classifier now predicts as the target class.
53 | 1. targeted attack (target: 9), iteration: 1, epsilon: 0.03
54 | 
55 | 2. targeted attack (target: 9), iteration: 5, epsilon: 0.03
56 | 
57 | 3. targeted attack (target: 9), iteration: 1, epsilon: 0.5
58 | 
59 |
60 |
61 | ### References
62 | 1. Explaining and Harnessing Adversarial Examples, Goodfellow et al., ICLR 2015.
63 | 2. Adversarial Examples in the Physical World, Kurakin et al., ICLR 2017 Workshop.
64 |
65 | [explaining and harnessing adversarial examples, Goodfellow et al.]: https://arxiv.org/abs/1412.6572
66 | [adversarial examples in the physical world, Kurakin et al.]: https://arxiv.org/abs/1607.02533
67 |
--------------------------------------------------------------------------------
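
For reference, the two update rules behind the formula images above, as a minimal sketch. It follows the papers' non-targeted form and is written against the modern PyTorch API rather than the 0.3-era `Variable` API this repo pins; `net` stands for any classifier returning logits, and inputs are assumed scaled to [-1, 1] as in `datasets/datasets.py`:

```
import torch
import torch.nn.functional as F

def fgsm_step(net, x, y, eps=0.03):
    # FGSM: one step along the sign of the input gradient of the loss.
    # Equivalent to adversary.py's "negate the loss, then subtract" form.
    x = x.clone().detach().requires_grad_(True)
    F.cross_entropy(net(x), y).backward()
    return torch.clamp(x + eps * x.grad.sign(), -1, 1)

def i_fgsm(net, x, y, eps=0.03, alpha=2 / 255, steps=5):
    # I-FGSM: repeat small FGSM steps, re-projecting into the eps-ball
    # around the original input (and the valid pixel range) every step.
    x_orig = x.clone().detach()
    x_adv = x_orig.clone()
    for _ in range(steps):
        x_adv = x_adv.clone().detach().requires_grad_(True)
        F.cross_entropy(net(x_adv), y).backward()
        x_adv = x_adv + alpha * x_adv.grad.sign()
        x_adv = torch.max(torch.min(x_adv, x_orig + eps), x_orig - eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv
```
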
/adversary.py:
--------------------------------------------------------------------------------
1 | """adversary.py"""
2 | from pathlib import Path
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from torchvision.utils import save_image
9 |
10 | from models.toynet import ToyNet
11 | from datasets.datasets import return_data
12 | from utils.utils import rm_dir, cuda, where
13 |
14 |
15 | class Attack(object):
16 | def __init__(self, net, criterion):
17 | self.net = net
18 | self.criterion = criterion
19 |
20 | def fgsm(self, x, y, targeted=False, eps=0.03, x_val_min=-1, x_val_max=1):
21 | x_adv = Variable(x.data, requires_grad=True)
22 | h_adv = self.net(x_adv)
23 | if targeted:
24 | cost = self.criterion(h_adv, y)
25 | else:
26 | cost = -self.criterion(h_adv, y)
27 |
28 | self.net.zero_grad()
29 | if x_adv.grad is not None:
30 | x_adv.grad.data.fill_(0)
31 | cost.backward()
32 |
33 | x_adv.grad.sign_()
34 | x_adv = x_adv - eps*x_adv.grad
35 | x_adv = torch.clamp(x_adv, x_val_min, x_val_max)
36 |
37 |
38 | h = self.net(x)
39 | h_adv = self.net(x_adv)
40 |
41 | return x_adv, h_adv, h
42 |
43 | def i_fgsm(self, x, y, targeted=False, eps=0.03, alpha=1, iteration=1, x_val_min=-1, x_val_max=1):
44 | x_adv = Variable(x.data, requires_grad=True)
45 | for i in range(iteration):
46 | h_adv = self.net(x_adv)
47 | if targeted:
48 | cost = self.criterion(h_adv, y)
49 | else:
50 | cost = -self.criterion(h_adv, y)
51 |
52 | self.net.zero_grad()
53 | if x_adv.grad is not None:
54 | x_adv.grad.data.fill_(0)
55 | cost.backward()
56 |
57 | x_adv.grad.sign_()
58 | x_adv = x_adv - alpha*x_adv.grad
59 | x_adv = where(x_adv > x+eps, x+eps, x_adv)
60 | x_adv = where(x_adv < x-eps, x-eps, x_adv)
61 | x_adv = torch.clamp(x_adv, x_val_min, x_val_max)
62 | x_adv = Variable(x_adv.data, requires_grad=True)
63 |
64 | h = self.net(x)
65 | h_adv = self.net(x_adv)
66 |
67 | return x_adv, h_adv, h
68 |
69 | def universal(self, args):
70 | self.set_mode('eval')
71 |
72 | init = False
73 |
74 | correct = 0
75 | cost = 0
76 | total = 0
77 |
78 | data_loader = self.data_loader['test']
79 | for e in range(100000):
80 | for batch_idx, (images, labels) in enumerate(data_loader):
81 |
82 | x = Variable(cuda(images, self.cuda))
83 | y = Variable(cuda(labels, self.cuda))
84 |
85 | if not init:
86 | sz = x.size()[1:]
87 | r = torch.zeros(sz)
88 | r = Variable(cuda(r, self.cuda), requires_grad=True)
89 | init = True
90 |
91 | logit = self.net(x+r)
92 | p_ygx = F.softmax(logit, dim=1)
93 | H_ygx = (-p_ygx*torch.log(self.eps+p_ygx)).sum(1).mean(0)
94 | prediction_cost = H_ygx
95 | #prediction_cost = F.cross_entropy(logit,y)
96 | #perceptual_cost = -F.l1_loss(x+r,x)
97 | #perceptual_cost = -F.mse_loss(x+r,x)
98 | #perceptual_cost = -F.mse_loss(x+r,x) -r.norm()
99 | perceptual_cost = -F.mse_loss(x+r, x) -F.relu(r.norm()-5)
100 | #perceptual_cost = -F.relu(r.norm()-5.)
101 | #if perceptual_cost.data[0] < 10: perceptual_cost.data.fill_(0)
102 | cost = prediction_cost + perceptual_cost
103 | #cost = prediction_cost
104 |
105 | self.net.zero_grad()
106 |                 if r.grad is not None:
107 |                     r.grad.data.fill_(0)
108 | cost.backward()
109 |
110 | #r = r + args.eps*r.grad.sign()
111 | r = r + r.grad*1e-1
112 | r = Variable(cuda(r.data, self.cuda), requires_grad=True)
113 |
114 |
115 |
116 | prediction = logit.max(1)[1]
117 | correct = torch.eq(prediction, y).float().mean().data[0]
118 | if batch_idx % 100 == 0:
119 | if self.visdom:
120 | self.vf.imshow_multi(x.add(r).data)
121 | #self.vf.imshow_multi(r.unsqueeze(0).data,factor=4)
122 | print(correct*100, prediction_cost.data[0], perceptual_cost.data[0],\
123 | r.norm().data[0])
124 |
125 | self.set_mode('train')
126 |
--------------------------------------------------------------------------------
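
A minimal sketch of how `Attack` is driven; in the repo this is `Solver.FGSM`'s job, so the snippet below (with a random stand-in batch) is illustrative only and follows the 0.3-era `Variable` API the file uses. Note that `universal` reads Solver attributes (`data_loader`, `cuda`, `eps`, `visdom`, `vf`, `set_mode`) that `Attack` itself never defines, and main.py routes `--mode universal` to a `Solver.universal` that does not exist, so that routine looks like work in progress; `fgsm` and `i_fgsm` are the supported entry points.

```
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from models.toynet import ToyNet
from adversary import Attack

net = ToyNet(y_dim=10)
attack = Attack(net, criterion=F.cross_entropy)

# Stand-in batch; real inputs come from Solver.sample_data, scaled to [-1, 1].
x = Variable(torch.randn(8, 1, 28, 28))
y = Variable(torch.LongTensor(8).random_(0, 10))

x_adv, h_adv, h = attack.fgsm(x, y, targeted=False, eps=0.03)    # one-step FGSM
x_adv, h_adv, h = attack.i_fgsm(x, y, targeted=False, eps=0.03,
                                alpha=2 / 255, iteration=5)      # iterative variant
```
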
/cleaner.py:
--------------------------------------------------------------------------------
1 | """cleaner.py"""
2 |
3 | import argparse
4 | from pathlib import Path
5 |
6 | from utils.utils import rm_dir
7 |
8 |
9 | def clean(args):
10 | """Remove directories relevant to specified experiment name given as env_name"""
11 |
12 | env_name = args.env_name
13 |
14 | ckpt_dir = Path(args.ckpt_dir).joinpath(env_name)
15 | summary_dir = Path(args.summary_dir).joinpath(env_name)
16 | output_dir = Path(args.output_dir).joinpath(env_name)
17 |
18 | rm_dir(ckpt_dir)
19 | rm_dir(summary_dir)
20 | rm_dir(output_dir)
21 |
22 |     print('[*] Cleaning finished!')
23 |
24 |
25 | if __name__ == '__main__':
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--env_name', type=str, required=True)
29 | parser.add_argument('--ckpt_dir', type=str, default='checkpoints')
30 | parser.add_argument('--summary_dir', type=str, default='summary')
31 | parser.add_argument('--output_dir', type=str, default='output')
32 | args = parser.parse_args()
33 |
34 | clean(args)
35 |
--------------------------------------------------------------------------------
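
Typical invocation, assuming the default directory layout used by main.py:

```
python cleaner.py --env_name [NAME]
```
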
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | """datasets.py"""
2 | import os
3 |
4 | from torch.utils.data import DataLoader
5 | from torchvision import transforms
6 | from torchvision.datasets import MNIST
7 |
8 |
9 | class UnknownDatasetError(Exception):
10 | def __str__(self):
11 |         return "unknown dataset error"
12 |
13 |
14 | def return_data(args):
15 | name = args.dataset
16 | dset_dir = args.dset_dir
17 | batch_size = args.batch_size
18 | transform = transforms.Compose([transforms.ToTensor(),
19 | transforms.Normalize((0.5,), (0.5,)),
20 | ])
21 |
22 | if 'MNIST' in name:
23 | root = os.path.join(dset_dir, 'MNIST')
24 | train_kwargs = {'root':root, 'train':True, 'transform':transform, 'download':True}
25 | test_kwargs = {'root':root, 'train':False, 'transform':transform, 'download':False}
26 | dset = MNIST
27 |
28 | else:
29 | raise UnknownDatasetError()
30 |
31 | train_data = dset(**train_kwargs)
32 | train_loader = DataLoader(train_data,
33 | batch_size=batch_size,
34 | shuffle=True,
35 | num_workers=1,
36 | pin_memory=True,
37 | drop_last=True)
38 |
39 | test_data = dset(**test_kwargs)
40 | test_loader = DataLoader(test_data,
41 | batch_size=batch_size,
42 | shuffle=False,
43 | num_workers=1,
44 | pin_memory=True,
45 | drop_last=False)
46 |
47 | data_loader = dict()
48 | data_loader['train'] = train_loader
49 | data_loader['test'] = test_loader
50 |
51 | return data_loader
52 |
53 |
54 | if __name__ == '__main__':
55 | import argparse
56 | os.chdir('..')
57 |
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument('--dataset', type=str, default='MNIST')
60 | parser.add_argument('--dset_dir', type=str, default='datasets')
61 | parser.add_argument('--batch_size', type=int, default=64)
62 | args = parser.parse_args()
63 |
64 | data_loader = return_data(args)
65 |
--------------------------------------------------------------------------------
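
One detail worth noting: `Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5`, mapping `ToTensor`'s [0, 1] pixels to [-1, 1]. That is why the attacks in adversary.py clamp to `x_val_min=-1, x_val_max=1` and why solver.py carries `scale`/`unscale` helpers. A quick check:

```
import torch
from torchvision import transforms

t = transforms.Normalize((0.5,), (0.5,))
x = torch.Tensor([[[0.0, 0.5, 1.0]]])   # a 1x1x3 stand-in image (C, H, W)
print(t(x))                             # tensor([[[-1., 0., 1.]]])
```
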
/main.py:
--------------------------------------------------------------------------------
1 | """main.py"""
2 | import argparse
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from solver import Solver
8 | from utils.utils import str2bool
9 |
10 | def main(args):
11 |
12 | torch.backends.cudnn.enabled = True
13 | torch.backends.cudnn.benchmark = True
14 |
15 | seed = args.seed
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | np.random.seed(seed)
19 |
20 | np.set_printoptions(precision=4)
21 | torch.set_printoptions(precision=4)
22 |
23 | print()
24 | print('[ARGUMENTS]')
25 | print(args)
26 | print()
27 |
28 | net = Solver(args)
29 |
30 | if args.mode == 'train':
31 | net.train()
32 | elif args.mode == 'test':
33 | net.test()
34 | elif args.mode == 'generate':
35 | net.generate(num_sample=args.batch_size,
36 | target=args.target,
37 | epsilon=args.epsilon,
38 | alpha=args.alpha,
39 | iteration=args.iteration)
40 | elif args.mode == 'universal':
41 | net.universal(args)
42 | else: return
43 |
44 | print('[*] Finished')
45 |
46 |
47 | if __name__ == "__main__":
48 |
49 | parser = argparse.ArgumentParser(description='toynet template')
50 | parser.add_argument('--epoch', type=int, default=20, help='epoch size')
51 | parser.add_argument('--batch_size', type=int, default=100, help='mini-batch size')
52 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
53 | parser.add_argument('--y_dim', type=int, default=10, help='the number of classes')
54 | parser.add_argument('--target', type=int, default=-1, help='target class for targeted generation')
55 |     parser.add_argument('--eps', type=float, default=1e-9, help='small constant for numerical stability (distinct from --epsilon)')
56 | parser.add_argument('--env_name', type=str, default='main', help='experiment name')
57 |     parser.add_argument('--dataset', type=str, default='MNIST', help='dataset type (only MNIST is supported)')
58 | parser.add_argument('--dset_dir', type=str, default='datasets', help='dataset directory path')
59 | parser.add_argument('--summary_dir', type=str, default='summary', help='summary directory path')
60 | parser.add_argument('--output_dir', type=str, default='output', help='output directory path')
61 | parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='checkpoint directory path')
62 |     parser.add_argument('--load_ckpt', type=str, default='', help='checkpoint filename to load, e.g. best_acc.tar')
63 |     parser.add_argument('--cuda', type=str2bool, default=True, help='enable cuda')
64 |     parser.add_argument('--silent', type=str2bool, default=False, help='suppress log prints')
65 | parser.add_argument('--mode', type=str, default='train', help='train / test / generate / universal')
66 | parser.add_argument('--seed', type=int, default=1, help='random seed')
67 |     parser.add_argument('--iteration', type=int, default=1, help='number of attack iterations (1: FGSM, >1: I-FGSM)')
68 |     parser.add_argument('--epsilon', type=float, default=0.03, help='epsilon for FGSM and I-FGSM')
69 |     parser.add_argument('--alpha', type=float, default=2/255, help='step size alpha for I-FGSM')
70 | parser.add_argument('--tensorboard', type=str2bool, default=False, help='enable tensorboard')
71 | parser.add_argument('--visdom', type=str2bool, default=False, help='enable visdom')
72 |     parser.add_argument('--visdom_port', type=int, default=55558, help='visdom port')
73 | args = parser.parse_args()
74 |
75 | main(args)
76 |
--------------------------------------------------------------------------------
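
A caveat on the seeding above: `cudnn.benchmark = True` trades run-to-run determinism for speed, so the manual seeds do not guarantee bit-identical CUDA results. If reproducibility matters more than throughput, the usual switch (sketched here) is:

```
import torch

torch.backends.cudnn.benchmark = False      # disable the autotuner
torch.backends.cudnn.deterministic = True   # force deterministic kernels
```
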
/misc/FGSM.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/FGSM.PNG
--------------------------------------------------------------------------------
/misc/IFGSM.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/IFGSM.PNG
--------------------------------------------------------------------------------
/misc/nontargeted_1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_1.PNG
--------------------------------------------------------------------------------
/misc/nontargeted_2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_2.PNG
--------------------------------------------------------------------------------
/misc/nontargeted_3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/nontargeted_3.PNG
--------------------------------------------------------------------------------
/misc/overview.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/overview.PNG
--------------------------------------------------------------------------------
/misc/targetd_9_1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_1.PNG
--------------------------------------------------------------------------------
/misc/targetd_9_2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_2.PNG
--------------------------------------------------------------------------------
/misc/targetd_9_3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1Konny/FGSM/d5e730b935b6a02f6045ccacc18e4a16676a584e/misc/targetd_9_3.PNG
--------------------------------------------------------------------------------
/models/toynet.py:
--------------------------------------------------------------------------------
1 | """toynet.py"""
2 | import torch.nn as nn
3 |
4 | class ToyNet(nn.Module):
5 | def __init__(self, x_dim=784, y_dim=10):
6 | super(ToyNet, self).__init__()
7 | self.x_dim = x_dim
8 | self.y_dim = y_dim
9 |
10 | self.mlp = nn.Sequential(
11 | nn.Linear(self.x_dim, 300),
12 | nn.ReLU(True),
13 | nn.Linear(300, 150),
14 | nn.ReLU(True),
15 | nn.Linear(150, self.y_dim)
16 | )
17 |
18 | def forward(self, X):
19 | if X.dim() > 2:
20 | X = X.view(X.size(0), -1)
21 | out = self.mlp(X)
22 |
23 | return out
24 |
25 | def weight_init(self, _type='kaiming'):
26 | if _type == 'kaiming':
27 | for ms in self._modules:
28 |                 kaiming_init(self._modules[ms].modules())
29 |
30 |
31 | def xavier_init(ms):
32 | for m in ms:
33 | if isinstance(m, (nn.Linear, nn.Conv2d)):
34 |             nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu'))
35 |             if m.bias is not None:
36 |                 m.bias.data.zero_()
37 |         if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
38 |             m.weight.data.fill_(1)
39 |             if m.bias is not None:
40 |                 m.bias.data.zero_()
41 |
42 |
43 | def kaiming_init(ms):
44 | for m in ms:
45 | if isinstance(m, (nn.Linear, nn.Conv2d)):
46 |             nn.init.kaiming_uniform(m.weight, a=0, mode='fan_in')
47 |             if m.bias is not None:
48 |                 m.bias.data.zero_()
49 |         if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
50 |             m.weight.data.fill_(1)
51 |             if m.bias is not None:
52 |                 m.bias.data.zero_()
53 |
--------------------------------------------------------------------------------
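
`forward` flattens anything beyond 2 dimensions, so an MNIST batch of shape (N, 1, 28, 28) is reshaped to (N, 784) before the MLP. A quick shape check (modern tensor API; wrap the input in `Variable` on 0.3):

```
import torch
from models.toynet import ToyNet

net = ToyNet(x_dim=784, y_dim=10)
logits = net(torch.randn(4, 1, 28, 28))   # flattened internally to (4, 784)
print(logits.size())                      # torch.Size([4, 10])
```
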
/solver.py:
--------------------------------------------------------------------------------
1 | """solver.py"""
2 | from pathlib import Path
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from torchvision.utils import save_image
9 |
10 | from models.toynet import ToyNet
11 | from datasets.datasets import return_data
12 | from utils.utils import rm_dir, cuda, where
13 | from adversary import Attack
14 |
15 |
16 | class Solver(object):
17 | def __init__(self, args):
18 | self.args = args
19 |
20 | # Basic
21 | self.cuda = (args.cuda and torch.cuda.is_available())
22 | self.epoch = args.epoch
23 | self.batch_size = args.batch_size
24 | self.eps = args.eps
25 | self.lr = args.lr
26 | self.y_dim = args.y_dim
27 | self.target = args.target
28 | self.dataset = args.dataset
29 | self.data_loader = return_data(args)
30 | self.global_epoch = 0
31 | self.global_iter = 0
32 | self.print_ = not args.silent
33 |
34 | self.env_name = args.env_name
35 | self.tensorboard = args.tensorboard
36 | self.visdom = args.visdom
37 |
38 | self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.env_name)
39 | if not self.ckpt_dir.exists():
40 | self.ckpt_dir.mkdir(parents=True, exist_ok=True)
41 | self.output_dir = Path(args.output_dir).joinpath(args.env_name)
42 | if not self.output_dir.exists():
43 | self.output_dir.mkdir(parents=True, exist_ok=True)
44 |
45 | # Visualization Tools
46 | self.visualization_init(args)
47 |
48 | # Histories
49 | self.history = dict()
50 | self.history['acc'] = 0.
51 | self.history['epoch'] = 0
52 | self.history['iter'] = 0
53 |
54 | # Models & Optimizers
55 | self.model_init(args)
56 | self.load_ckpt = args.load_ckpt
57 | if self.load_ckpt != '':
58 | self.load_checkpoint(self.load_ckpt)
59 |
60 | # Adversarial Perturbation Generator
61 | #criterion = cuda(torch.nn.CrossEntropyLoss(), self.cuda)
62 | criterion = F.cross_entropy
63 | self.attack = Attack(self.net, criterion=criterion)
64 |
65 | def visualization_init(self, args):
66 | # Visdom
67 | if self.visdom:
68 | from utils.visdom_utils import VisFunc
69 | self.port = args.visdom_port
70 | self.vf = VisFunc(enval=self.env_name, port=self.port)
71 |
72 | # TensorboardX
73 | if self.tensorboard:
74 | from tensorboardX import SummaryWriter
75 | self.summary_dir = Path(args.summary_dir).joinpath(args.env_name)
76 | if not self.summary_dir.exists():
77 | self.summary_dir.mkdir(parents=True, exist_ok=True)
78 |
79 | self.tf = SummaryWriter(log_dir=str(self.summary_dir))
80 | self.tf.add_text(tag='argument', text_string=str(args), global_step=self.global_epoch)
81 |
82 | def model_init(self, args):
83 | # Network
84 | self.net = cuda(ToyNet(y_dim=self.y_dim), self.cuda)
85 | self.net.weight_init(_type='kaiming')
86 |
87 | # Optimizers
88 | self.optim = optim.Adam([{'params':self.net.parameters(), 'lr':self.lr}],
89 | betas=(0.5, 0.999))
90 |
91 | def train(self):
92 | self.set_mode('train')
93 | for e in range(self.epoch):
94 | self.global_epoch += 1
95 |
96 | correct = 0.
97 | cost = 0.
98 | total = 0.
99 | for batch_idx, (images, labels) in enumerate(self.data_loader['train']):
100 | self.global_iter += 1
101 |
102 | x = Variable(cuda(images, self.cuda))
103 | y = Variable(cuda(labels, self.cuda))
104 |
105 | logit = self.net(x)
106 | prediction = logit.max(1)[1]
107 |
108 | correct = torch.eq(prediction, y).float().mean().data[0]
109 | cost = F.cross_entropy(logit, y)
110 |
111 | self.optim.zero_grad()
112 | cost.backward()
113 | self.optim.step()
114 |
115 | if batch_idx % 100 == 0:
116 | if self.print_:
117 | print()
118 | print(self.env_name)
119 | print('[{:03d}:{:03d}]'.format(self.global_epoch, batch_idx))
120 | print('acc:{:.3f} loss:{:.3f}'.format(correct, cost.data[0]))
121 |
122 |
123 | if self.tensorboard:
124 | self.tf.add_scalars(main_tag='performance/acc',
125 | tag_scalar_dict={'train':correct},
126 | global_step=self.global_iter)
127 | self.tf.add_scalars(main_tag='performance/error',
128 | tag_scalar_dict={'train':1-correct},
129 | global_step=self.global_iter)
130 | self.tf.add_scalars(main_tag='performance/cost',
131 | tag_scalar_dict={'train':cost.data[0]},
132 | global_step=self.global_iter)
133 |
134 |
135 | self.test()
136 |
137 |
138 | if self.tensorboard:
139 | self.tf.add_scalars(main_tag='performance/best/acc',
140 | tag_scalar_dict={'test':self.history['acc']},
141 | global_step=self.history['iter'])
142 | print(" [*] Training Finished!")
143 |
144 | def test(self):
145 | self.set_mode('eval')
146 |
147 | correct = 0.
148 | cost = 0.
149 | total = 0.
150 |
151 | data_loader = self.data_loader['test']
152 | for batch_idx, (images, labels) in enumerate(data_loader):
153 | x = Variable(cuda(images, self.cuda))
154 | y = Variable(cuda(labels, self.cuda))
155 |
156 | logit = self.net(x)
157 | prediction = logit.max(1)[1]
158 |
159 | correct += torch.eq(prediction, y).float().sum().data[0]
160 | cost += F.cross_entropy(logit, y, size_average=False).data[0]
161 | total += x.size(0)
162 |
163 | accuracy = correct / total
164 | cost /= total
165 |
166 |
167 | if self.print_:
168 | print()
169 | print('[{:03d}]\nTEST RESULT'.format(self.global_epoch))
170 | print('ACC:{:.4f}'.format(accuracy))
171 |             print('*TOP* ACC:{:.4f} at e:{:03d}'.format(self.history['acc'], self.history['epoch']))
172 | print()
173 |
174 | if self.tensorboard:
175 | self.tf.add_scalars(main_tag='performance/acc',
176 | tag_scalar_dict={'test':accuracy},
177 | global_step=self.global_iter)
178 |
179 | self.tf.add_scalars(main_tag='performance/error',
180 | tag_scalar_dict={'test':(1-accuracy)},
181 | global_step=self.global_iter)
182 |
183 | self.tf.add_scalars(main_tag='performance/cost',
184 | tag_scalar_dict={'test':cost},
185 | global_step=self.global_iter)
186 |
187 | if self.history['acc'] < accuracy:
188 | self.history['acc'] = accuracy
189 | self.history['epoch'] = self.global_epoch
190 | self.history['iter'] = self.global_iter
191 | self.save_checkpoint('best_acc.tar')
192 |
193 | self.set_mode('train')
194 |
195 | def generate(self, num_sample=100, target=-1, epsilon=0.03, alpha=2/255, iteration=1):
196 | self.set_mode('eval')
197 |
198 | x_true, y_true = self.sample_data(num_sample)
199 | if isinstance(target, int) and (target in range(self.y_dim)):
200 | y_target = torch.LongTensor(y_true.size()).fill_(target)
201 | else:
202 | y_target = None
203 |
204 | x_adv, changed, values = self.FGSM(x_true, y_true, y_target, epsilon, alpha, iteration)
205 | accuracy, cost, accuracy_adv, cost_adv = values
206 |
207 | save_image(x_true,
208 | self.output_dir.joinpath('legitimate(t:{},e:{},i:{}).jpg'.format(target,
209 | epsilon,
210 | iteration)),
211 | nrow=10,
212 | padding=2,
213 | pad_value=0.5)
214 | save_image(x_adv,
215 | self.output_dir.joinpath('perturbed(t:{},e:{},i:{}).jpg'.format(target,
216 | epsilon,
217 | iteration)),
218 | nrow=10,
219 | padding=2,
220 | pad_value=0.5)
221 | save_image(changed,
222 | self.output_dir.joinpath('changed(t:{},e:{},i:{}).jpg'.format(target,
223 | epsilon,
224 | iteration)),
225 | nrow=10,
226 | padding=3,
227 | pad_value=0.5)
228 |
229 | if self.visdom:
230 | self.vf.imshow_multi(x_true.cpu(), title='legitimate', factor=1.5)
231 | self.vf.imshow_multi(x_adv.cpu(), title='perturbed(e:{},i:{})'.format(epsilon, iteration), factor=1.5)
232 |             self.vf.imshow_multi(changed.cpu(), title='changed(e:{},i:{})'.format(epsilon, iteration), factor=1.5)
233 |
234 | print('[BEFORE] accuracy : {:.2f} cost : {:.3f}'.format(accuracy, cost))
235 | print('[AFTER] accuracy : {:.2f} cost : {:.3f}'.format(accuracy_adv, cost_adv))
236 |
237 | self.set_mode('train')
238 |
239 | def sample_data(self, num_sample=100):
240 |
241 | total = len(self.data_loader['test'].dataset)
242 |         idx = torch.LongTensor(num_sample).random_(0, total)
243 | 
244 |         x = self.data_loader['test'].dataset.test_data[idx]
245 |         x = self.scale(x.float().unsqueeze(1).div(255))
246 |         y = self.data_loader['test'].dataset.test_labels[idx]
247 |
248 | return x, y
249 |
250 |
251 | def FGSM(self, x, y_true, y_target=None, eps=0.03, alpha=2/255, iteration=1):
252 | self.set_mode('eval')
253 |
254 | x = Variable(cuda(x, self.cuda), requires_grad=True)
255 | y_true = Variable(cuda(y_true, self.cuda), requires_grad=False)
256 | if y_target is not None:
257 | targeted = True
258 | y_target = Variable(cuda(y_target, self.cuda), requires_grad=False)
259 | else:
260 | targeted = False
261 |
262 |
263 | h = self.net(x)
264 | prediction = h.max(1)[1]
265 | accuracy = torch.eq(prediction, y_true).float().mean()
266 | cost = F.cross_entropy(h, y_true)
267 |
268 | if iteration == 1:
269 | if targeted:
270 | x_adv, h_adv, h = self.attack.fgsm(x, y_target, True, eps)
271 | else:
272 | x_adv, h_adv, h = self.attack.fgsm(x, y_true, False, eps)
273 | else:
274 | if targeted:
275 | x_adv, h_adv, h = self.attack.i_fgsm(x, y_target, True, eps, alpha, iteration)
276 | else:
277 | x_adv, h_adv, h = self.attack.i_fgsm(x, y_true, False, eps, alpha, iteration)
278 |
279 | prediction_adv = h_adv.max(1)[1]
280 | accuracy_adv = torch.eq(prediction_adv, y_true).float().mean()
281 | cost_adv = F.cross_entropy(h_adv, y_true)
282 |
283 |         # mark perturbed images whose prediction flipped (non-targeted) or matched the target (targeted)
284 | if targeted:
285 | changed = torch.eq(y_target, prediction_adv)
286 | else:
287 | changed = torch.eq(prediction, prediction_adv)
288 | changed = torch.eq(changed, 0)
289 | changed = changed.float().view(-1, 1, 1, 1).repeat(1, 3, 28, 28)
290 |
291 | changed[:, 0, :, :] = where(changed[:, 0, :, :] == 1, 252, 91)
292 | changed[:, 1, :, :] = where(changed[:, 1, :, :] == 1, 39, 252)
293 | changed[:, 2, :, :] = where(changed[:, 2, :, :] == 1, 25, 25)
294 | changed = self.scale(changed/255)
295 | changed[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2, 3:-2]
296 |
297 | self.set_mode('train')
298 |
299 | return x_adv.data, changed.data,\
300 | (accuracy.data[0], cost.data[0], accuracy_adv.data[0], cost_adv.data[0])
301 |
302 | def save_checkpoint(self, filename='ckpt.tar'):
303 | model_states = {
304 | 'net':self.net.state_dict(),
305 | }
306 | optim_states = {
307 | 'optim':self.optim.state_dict(),
308 | }
309 | states = {
310 | 'iter':self.global_iter,
311 | 'epoch':self.global_epoch,
312 | 'history':self.history,
313 | 'args':self.args,
314 | 'model_states':model_states,
315 | 'optim_states':optim_states,
316 | }
317 |
318 | file_path = self.ckpt_dir / filename
319 | torch.save(states, file_path.open('wb+'))
320 | print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))
321 |
322 | def load_checkpoint(self, filename='best_acc.tar'):
323 | file_path = self.ckpt_dir / filename
324 | if file_path.is_file():
325 | print("=> loading checkpoint '{}'".format(file_path))
326 | checkpoint = torch.load(file_path.open('rb'))
327 | self.global_epoch = checkpoint['epoch']
328 | self.global_iter = checkpoint['iter']
329 | self.history = checkpoint['history']
330 |
331 | self.net.load_state_dict(checkpoint['model_states']['net'])
332 | self.optim.load_state_dict(checkpoint['optim_states']['optim'])
333 |
334 | print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
335 |
336 | else:
337 | print("=> no checkpoint found at '{}'".format(file_path))
338 |
339 | def set_mode(self, mode='train'):
340 | if mode == 'train':
341 | self.net.train()
342 | elif mode == 'eval':
343 | self.net.eval()
344 |         else: raise ValueError('mode error: it should be either train or eval')
345 |
346 | def scale(self, image):
347 | return image.mul(2).add(-1)
348 |
349 | def unscale(self, image):
350 | return image.add(1).mul(0.5)
351 |
352 | def summary_flush(self, silent=True):
353 | rm_dir(self.summary_dir, silent)
354 |
355 | def checkpoint_flush(self, silent=True):
356 | rm_dir(self.ckpt_dir, silent)
357 |
--------------------------------------------------------------------------------
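
The "changed" panel that `FGSM` saves is built by painting a colored frame around each 28x28 adversarial digit: (252, 39, 25) red where the attack succeeded (prediction flipped, or hit the target) and (91, 252, 25) green where it did not, with the digit pasted back into the interior. A compact restatement of that trick, using `torch.where` from post-0.4 PyTorch rather than the repo's `where` helper:

```
import torch

def changed_panel(x_adv, flipped, red=(252, 39, 25), green=(91, 252, 25)):
    # x_adv: (N, 1, 28, 28) in [-1, 1]; flipped: (N,) bool success mask.
    n = x_adv.size(0)
    frame = torch.empty(n, 3, 28, 28)
    for c in range(3):
        frame[:, c] = torch.where(flipped.view(-1, 1, 1),
                                  torch.full((n, 28, 28), red[c] / 255.),
                                  torch.full((n, 28, 28), green[c] / 255.))
    frame = frame * 2 - 1                           # same scaling as Solver.scale
    frame[:, :, 3:-2, 3:-2] = x_adv.repeat(1, 3, 1, 1)[:, :, 3:-2, 3:-2]
    return frame
```
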
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import argparse, torch
2 | import numpy as np
3 | from torch import nn
4 | from torch.autograd import Variable
5 | from pathlib import Path
6 |
7 | class One_Hot(nn.Module):
8 | # from :
9 | # https://lirnli.wordpress.com/2017/09/03/one-hot-encoding-in-pytorch/
10 | def __init__(self, depth):
11 | super(One_Hot,self).__init__()
12 | self.depth = depth
13 | self.ones = torch.sparse.torch.eye(depth)
14 | def forward(self, X_in):
15 | X_in = X_in.long()
16 | return Variable(self.ones.index_select(0,X_in.data))
17 | def __repr__(self):
18 | return self.__class__.__name__ + "({})".format(self.depth)
19 |
20 |
21 | def cuda(tensor,is_cuda):
22 | if is_cuda : return tensor.cuda()
23 | else : return tensor
24 |
25 | def str2bool(v):
26 | # codes from : https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
27 |
28 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
29 | return True
30 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
31 | return False
32 | else:
33 | raise argparse.ArgumentTypeError('Boolean value expected.')
34 |
35 |
36 | def print_network(net):
37 | num_params = 0
38 | for param in net.parameters():
39 | num_params += param.numel()
40 | print(net)
41 | print('Total number of parameters: %d' % num_params)
42 |
43 |
44 | def rm_dir(dir_path, silent=True):
45 | p = Path(dir_path).resolve()
46 | if (not p.is_file()) and (not p.is_dir()) :
47 |         print('It is not a path to a file or a directory :', p)
48 | return
49 |
50 | paths = list(p.iterdir())
51 | if (len(paths) == 0) and p.is_dir() :
52 | p.rmdir()
53 | if not silent : print('removed empty dir :',p)
54 |
55 | else :
56 | for path in paths :
57 | if path.is_file() :
58 | path.unlink()
59 | if not silent : print('removed file :',path)
60 | else:
61 | rm_dir(path)
62 | p.rmdir()
63 | if not silent : print('removed empty dir :',p)
64 |
65 | def where(cond, x, y):
66 | """
67 | code from :
68 | https://discuss.pytorch.org/t/how-can-i-do-the-operation-the-same-as-np-where/1329/8
69 | """
70 | cond = cond.float()
71 | return (cond*x) + ((1-cond)*y)
72 |
--------------------------------------------------------------------------------
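
`where` predates `torch.where` (added in PyTorch 0.4) and blends the branches arithmetically, so `x` and `y` may be tensors or plain scalars. For example:

```
import torch
from utils.utils import where

cond = torch.Tensor([1, 0, 1])
print(where(cond, 10.0, -10.0))   # tensor([ 10., -10.,  10.])
```
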
/utils/visdom_utils.py:
--------------------------------------------------------------------------------
1 | import visdom
2 | from scipy.misc import imresize
3 | import numpy as np
4 | from torchvision.utils import make_grid
5 |
6 | class VisFunc(object):
7 |
8 | def __init__(self, config=None, vis=None, enval='hproto',port=8097):
9 | self.config = config
10 | self.vis = visdom.Visdom(env=enval, port=port)
11 | self.win = None
12 | self.win2 = None
13 | self.epoch_list = []
14 | self.train_loss_list = []
15 | self.val_loss_list = []
16 | self.epoch_list2 = []
17 | self.train_acc_list = []
18 | self.val_acc_list = []
19 |
20 |
21 | def imshow(self, img, title=' ', caption=' ', factor=1):
22 |
23 | img = img / 2 + 0.5 # Unnormalize
24 | npimg = img.numpy()
25 |         obj = npimg  # already channel-first (C, H, W)
28 |
29 | imgsize = tuple((np.array(obj.shape[1:])*factor).astype(int))
30 | rgbArray = np.zeros(tuple([3])+imgsize,'float32')
31 | rgbArray[0,...] = imresize(obj[0,:,:],imgsize,'cubic')
32 | rgbArray[1,...] = imresize(obj[1,:,:],imgsize,'cubic')
33 | rgbArray[2,...] = imresize(obj[2,:,:],imgsize,'cubic')
34 |
35 | self.vis.image( rgbArray,
36 | opts=dict(title=title, caption=caption),
37 | )
38 |
39 |
40 | def imshow_multi(self, imgs, nrow=10, title=' ', caption=' ', factor=1):
41 | #self.imshow( make_grid(imgs,nrow,padding=padding), title, caption, factor)
42 | self.imshow( make_grid(imgs,nrow), title, caption, factor)
43 |
44 |
45 | def imshow_one_batch(self, loader, classes=None, factor=1):
46 | dataiter = iter(loader)
47 |         images, labels = next(dataiter)
48 |         self.imshow(make_grid(images))
49 |
50 | if classes:
51 | print(' '.join('%5s' % classes[labels[j]]
52 | for j in range(loader.batch_size)))
53 | else:
54 | print(' '.join('%5s' % labels[j]
55 | for j in range(loader.batch_size)))
56 |
57 |
58 | def plot(self, epoch, train_loss, val_loss,Des):
59 | ''' plot learning curve interactively with visdom '''
60 | self.epoch_list.append(epoch)
61 | self.train_loss_list.append(train_loss)
62 | self.val_loss_list.append(val_loss)
63 |
64 | if not self.win:
65 | # send line plot
66 | # embed()
67 | self.win = self.vis.line(
68 | X=np.array(self.epoch_list),
69 | Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]),
70 | opts=dict(
71 | title='Learning Curve (' + Des +')',
72 | xlabel='Epoch',
73 | ylabel='Loss',
74 | legend=['train_loss', 'val_loss'],
75 | #caption=Des
76 | ))
77 | # send text memo (configuration)
78 | # self.vis.text(str(Des))
79 | else:
80 | self.vis.updateTrace(
81 | X=np.array(self.epoch_list[-2:]),
82 | Y=np.array(self.train_loss_list[-2:]),
83 | win=self.win,
84 | name='train_loss',
85 | )
86 | self.vis.updateTrace(
87 | X=np.array(self.epoch_list[-2:]),
88 | Y=np.array(self.val_loss_list[-2:]),
89 | win=self.win,
90 | name='val_loss',
91 | )
92 |
93 |
94 | def acc_plot(self, epoch, train_acc, val_acc, Des):
95 | ''' plot learning curve interactively with visdom '''
96 | self.epoch_list2.append(epoch)
97 | self.train_acc_list.append(train_acc)
98 | self.val_acc_list.append(val_acc)
99 |
100 | if not self.win2:
101 | # send line plot
102 | # embed()
103 | self.win2 = self.vis.line(
104 | X=np.array(self.epoch_list2),
105 | Y=np.array([[self.train_acc_list[-1], self.val_acc_list[-1]]]),
106 | opts=dict(
107 | title='Accuracy Curve (' + Des +')',
108 | xlabel='Epoch',
109 | ylabel='Accuracy',
110 | legend=['train_accuracy', 'val_accuracy']
111 | ))
112 | # send text memo (configuration)
113 | # self.vis.text(str(self.config))
114 | else:
115 | self.vis.updateTrace(
116 | X=np.array(self.epoch_list2[-2:]),
117 | Y=np.array(self.train_acc_list[-2:]),
118 | win=self.win2,
119 | name='train_accuracy',
120 | )
121 | self.vis.updateTrace(
122 | X=np.array(self.epoch_list2[-2:]),
123 | Y=np.array(self.val_acc_list[-2:]),
124 | win=self.win2,
125 | name='val_accuracy',
126 | )
127 |
128 |
129 | def plot2(self, epoch, train_loss, val_loss,Des, win):
130 | ''' plot learning curve interactively with visdom '''
131 | self.epoch_list.append(epoch)
132 | self.train_loss_list.append(train_loss)
133 | self.val_loss_list.append(val_loss)
134 |
135 | if not self.win:
136 | self.win = win
137 | # send line plot
138 | # embed()
139 | #self.win = self.vis.line(
140 | # X=np.array(self.epoch_list),
141 | # Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]),
142 | # opts=dict(
143 | # title='Learning Curve (' + Des +')',
144 | # xlabel='Epoch',
145 | # ylabel='Loss',
146 | # legend=['train_loss', 'val_loss'],
147 | # #caption=Des
148 | # ))
149 | ## send text memo (configuration)
150 | # self.vis.text(str(Des))
151 | else:
152 | self.vis.updateTrace(
153 | X=np.array(self.epoch_list[-2:]),
154 | Y=np.array(self.train_loss_list[-2:]),
155 | win=self.win,
156 | name='train_loss2',
157 | )
158 | self.vis.updateTrace(
159 | X=np.array(self.epoch_list[-2:]),
160 | Y=np.array(self.val_loss_list[-2:]),
161 | win=self.win,
162 |                 name='val_loss2',
163 | )
164 |
--------------------------------------------------------------------------------
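
`Visdom.updateTrace` was removed in later visdom releases; on a current visdom the equivalent of the appends above is `vis.line(..., update='append')`. A sketch against the modern API (not what this file, written for an older visdom, uses):

```
import numpy as np
import visdom

vis = visdom.Visdom(env='main', port=8097)
win = vis.line(X=np.array([0]), Y=np.array([0.0]),
               opts=dict(title='Learning Curve', xlabel='Epoch', ylabel='Loss'))
vis.line(X=np.array([1]), Y=np.array([0.5]),
         win=win, update='append', name='train_loss')
```
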