├── .gitattributes
├── .DS_Store
├── README.md
├── test_worst_acc.py
├── utils.py
├── models
└── wideresnet.py
└── main.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlanChou/Adversarial-Training-for-Free/HEAD/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Training for Free
2 |
3 | This is an unofficial PyTorch implementation of the paper "Adversarial Training for Free!".
4 | https://arxiv.org/abs/1904.12843
5 | It's a really helpful technique which can significantly accelerate adversarial training.
6 | It only contains the code for adversarial training on CIFAR-10. However, one can easily modify it for other datasets.
7 |
8 | I've noticed that in the cleverhans repo, they have an attribute called BACKPROP_THROUGH_ATTACK which is exactly the same idea behind this paper.
9 |
10 | ## Overview
11 |
12 | This repository contains two files: one for adversarial training and one for testing. Note that the model used here is not WideResNet 32-10, which is used in the paper 'Adversarial Training for Free'. I use WideResNet 28-10, which is used in the original PGD paper. Be aware that some of the hyperparameters (weight decay and learning rate schedule) are slightly different from the paper.
13 |
14 |
15 |
16 | ## Accuracy (under PGD attack: epsilon = 8/255, step size = 2/255, iterations = 100)
17 | | Model | Acc |
18 | | ---------------------------| ----------- |
19 | | [ WideResNet 28-10 ] | 46.93% |
20 |
21 | I did not test every epoch's checkpoint; I simply chose one of the last epochs to test, so results might be slightly different. I've also released the checkpoint at the Google Drive link below.
22 | I have trouble training ResNet56 and ResNet20 to the baseline they are supposed to reach, which may suggest that this method does not apply to every model. Please don't hesitate to share your results if you know how to fix this.
23 | `checkpoint` [Google Drive](https://drive.google.com/file/d/1iZ52Ctcwty8bLMvLJJMlWcHL-__lJcbo/view?usp=sharing)
24 |
25 | ## Dependencies
26 | The repository is based on Python 3.5, with the main dependency being PyTorch==1.0.0. Additional dependencies for running experiments are: numpy, tqdm, argparse, os, random, advertorch.
27 |
28 | Advertorch can be installed from https://github.com/BorealisAI/advertorch
29 | Run the code with the command:
30 | ```
31 | $ CUDA_VISIBLE_DEVICES=0 python3 main.py
32 | ```
33 |
34 |
35 |
--------------------------------------------------------------------------------
/test_worst_acc.py:
--------------------------------------------------------------------------------
1 | '''Evaluate a CIFAR10 checkpoint under a PGD attack with PyTorch.'''
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 |
10 | import torchvision
11 | import torchvision.transforms as transforms
12 | import os
13 | import argparse
14 |
15 | from tqdm import tqdm
16 | import numpy as np
17 | from models.wideresnet import *
18 | from utils import *
19 | from advertorch.attacks import LinfPGDAttack
20 |
21 |
# Command-line options for the PGD robustness evaluation.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Testing')
parser.add_argument('--seed', default=11111, type=int, help='seed')
parser.add_argument('--epoch', default=0, type=int, help='load checkpoint from that epoch')
parser.add_argument('--model', default='wideresnet', type=str)
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--iteration', default=100, type=int)
parser.add_argument('--epsilon', default=8./255, type=float)
parser.add_argument('--step_size', default=2./255, type=float)
# The dataset location used to be a hard-coded, user-specific absolute path.
# It is now configurable and defaults to the same './data' that main.py uses.
parser.add_argument('--data_dir', default='./data', type=str,
                    help='directory where CIFAR10 is stored / downloaded')


args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the NumPy and PyTorch generators and ask cuDNN for deterministic
# kernels so that repeated evaluations use the same random attack starts.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True

# Data
print('==> Preparing data..')
# Only ToTensor is applied (no normalization); the attack built later in this
# script clips adversarial examples to the same [0, 1] range.
transform_test = transforms.Compose([
    transforms.ToTensor(),
])


testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)

print('==> Building model {}..'.format(args.model))
if args.model == 'wideresnet':
    net = WideResNet_28_10()
else:
    raise ValueError('No such model.')
57 |
def test(epoch):
    '''Report the accuracy of the global `net` on adversarially perturbed test data.

    Every batch from the global `testloader` is perturbed with the global
    `adversary` and then classified; the overall accuracy (in percent) is
    printed together with `args.epoch`.

    The `epoch` argument is accepted for interface compatibility but is not
    read by the body.
    '''
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        progress = tqdm(testloader, ncols=0, leave=False)
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # The attack needs gradients w.r.t. the input, so autograd is
            # switched back on just for the perturbation step.
            with torch.enable_grad():
                adv = adversary.perturb(inputs, targets)
            predicted = net(adv).max(1)[1]
            batch_hits = predicted.eq(targets).sum().item()
            batch_count = targets.size(0)
            n_seen += batch_count
            n_correct += batch_hits
            # Show the accuracy of the current batch on the progress bar.
            progress.set_description(str(batch_hits/batch_count))

    acc = 100.*n_correct/n_seen
    print('Test Acc of ckpt.{}: {}'.format(args.epoch, acc))
78 |
79 |
print('==> Loading from checkpoint epoch {}..'.format(args.epoch))
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
# Deserialize onto the CPU first so that a checkpoint written on a GPU
# machine can still be evaluated when CUDA is unavailable; the weights are
# moved to `device` by the net.to(device) call below.
checkpoint = torch.load('./checkpoint/ckpt.{}'.format(args.epoch), map_location='cpu')
net.load_state_dict(checkpoint['net'])
net = net.to(device)
net.eval()


# Untargeted L-inf PGD with a random start, configured from the command-line
# options; adversarial examples are clipped to the [0, 1] image range.
adversary = LinfPGDAttack(
    net, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=args.epsilon,
    nb_iter=args.iteration, eps_iter=args.step_size, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)

test(adversary)
94 |
95 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - init_params: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 |
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 |
14 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset.

    The dataset is iterated one sample at a time; for each of the first three
    channels the per-sample mean and per-sample std are accumulated and then
    divided by len(dataset). The returned std is therefore the average of the
    per-sample stds, not the std over the whole dataset.

    Args:
        dataset: indexable dataset yielding (inputs, targets) pairs whose
            inputs have at least three channels in dimension 0.
        num_workers: number of DataLoader worker processes (0 loads in the
            calling process).

    Returns:
        (mean, std): two tensors of length 3.
    '''
    # This module only imports torch.nn / torch.nn.init at the top, so the
    # bare name `torch` has to be brought into scope here.
    import torch

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=num_workers)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
28 |
def init_params(net):
    '''Init layer parameters.

    Walks every sub-module of `net` and re-initializes, in place:
    - nn.Conv2d: Kaiming-normal weights (fan_out mode), zero bias;
    - nn.BatchNorm2d: weight 1, bias 0;
    - nn.Linear: normal weights with std 1e-3, zero bias.

    Optional parameters (a missing bias, or BatchNorm affine parameters that
    are None) are skipped.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would raise for any bias tensor with more than one
            # element; the presence check has to be against None.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
43 |
44 |
# Width of the attached terminal, used by progress_bar to pad each line.
# os.get_terminal_size() raises OSError when the output is not a terminal
# (e.g. piped to a file or run from a non-interactive job), so fall back to a
# conventional 80 columns instead of failing when this module is imported.
try:
    term_width = os.get_terminal_size().columns or 80
except OSError:
    term_width = 80

# Number of character cells occupied by the bar itself.
TOTAL_BAR_LENGTH = 65.
# Timestamps shared with progress_bar: time of the previous call and time the
# current bar was started.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar on stdout.

    Args:
        current: zero-based index of the step that just finished; passing 0
            restarts the total-time clock.
        total: total number of steps.
        msg: optional text appended after the timing information.

    The line ends with a carriage return so the next call overwrites it,
    except on the last step (current == total - 1), where a newline is
    written instead.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    # Multiplying a string by a non-positive count yields '', matching an
    # empty loop.
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    status = ' Step: %s | Tot: %s' % (format_time(step_time), format_time(tot_time))
    if msg:
        status += ' | ' + msg

    # Pad the rest of the terminal line with spaces.
    padding = ' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3)
    # Go back to the center of the bar.
    rewind = '\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2)
    sys.stdout.write(status + padding + rewind + ' %d/%d ' % (current+1, total))

    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()
93 |
def format_time(seconds):
    '''Render a duration in seconds as a short string.

    The duration is split into days (D), hours (h), minutes (m), seconds (s)
    and milliseconds (ms); only the two largest non-zero units are shown,
    e.g. 3661 -> '1h1m'. Returns '0ms' when no unit is positive.
    '''
    remainder = seconds
    day_count = int(remainder / 3600/24)
    remainder -= day_count*3600*24
    hour_count = int(remainder / 3600)
    remainder -= hour_count*3600
    minute_count = int(remainder / 60)
    remainder -= minute_count*60
    second_count = int(remainder)
    remainder -= second_count
    milli_count = int(remainder*1000)

    fields = (
        (day_count, 'D'),
        (hour_count, 'h'),
        (minute_count, 'm'),
        (second_count, 's'),
        (milli_count, 'ms'),
    )
    # Keep the units in descending order and show at most two of them.
    pieces = [str(value) + unit for value, unit in fields if value > 0]
    return ''.join(pieces[:2]) or '0ms'
125 |
--------------------------------------------------------------------------------
/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
class BasicBlock(nn.Module):
    '''Residual unit: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3, plus shortcut.

    When in_planes == out_planes the shortcut is the raw block input.
    Otherwise the shortcut is a 1x1 convolution (with the block's stride)
    applied to the BN-ReLU-activated input.
    '''

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            # Identity shortcut: no projection needed.
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        # bn1 produces a new tensor, so the in-place ReLU leaves `x` untouched.
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.equalInOut:
            shortcut = x
        else:
            shortcut = self.convShortcut(activated)
        return torch.add(shortcut, out)
32 |
class NetworkBlock(nn.Module):
    '''A stack of `nb_layers` blocks applied in sequence.

    The first block maps in_planes -> out_planes using `stride`; every later
    block maps out_planes -> out_planes with stride 1.

    Args:
        nb_layers: number of blocks; converted with int(), so a whole-valued
            float is accepted.
        in_planes: channel count fed to the first block.
        out_planes: channel count produced by every block.
        block: callable `block(in_planes, out_planes, stride, dropRate)`
            returning an nn.Module.
        stride: stride of the first block.
        dropRate: value forwarded unchanged to every block.
    '''

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            is_first = (i == 0)
            # Conditional expressions instead of `cond and a or b`, which
            # silently picks `b` whenever `a` happens to be falsy.
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
44 |
class WideResNet(nn.Module):
    '''Wide residual network: a 3x3 stem, three NetworkBlock stages and a linear head.

    Args:
        depth: total depth; (depth - 4) must be divisible by 6, giving
            (depth - 4) / 6 blocks per stage.
        num_classes: size of the output layer.
        widen_factor: multiplier for the 16/32/64 stage widths.
        dropRate: forwarded to every block.
    '''

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        assert((depth - 4) % 6 == 0)
        widths = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        blocks_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Three stages; the 2nd and 3rd use stride 2.
        self.block1 = NetworkBlock(blocks_per_stage, widths[0], widths[1], block, 1, dropRate)
        self.block2 = NetworkBlock(blocks_per_stage, widths[1], widths[2], block, 2, dropRate)
        self.block3 = NetworkBlock(blocks_per_stage, widths[2], widths[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        # Weight initialization: normal(0, sqrt(2 / fan_out)) for convolutions,
        # unit scale / zero shift for batch norm, zero bias for the classifier.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.bias.data.zero_()

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features)
        features = self.relu(self.bn1(features))
        # Fixed 8x8 average pooling window, then flatten to one row of
        # nChannels features per pooled position.
        pooled = F.avg_pool2d(features, 8)
        return self.fc(pooled.view(-1, self.nChannels))
85 |
def WideResNet_28_10():
    '''Build a WideResNet with depth 28, widen factor 10, 10 classes and no dropout.'''
    return WideResNet(depth=28, num_classes=10, widen_factor=10, dropRate=0.0)
88 |
def WideResNet_28_10_cifar100():
    '''Build a WideResNet with depth 28, widen factor 10, 100 classes and no dropout.'''
    return WideResNet(depth=28, num_classes=100, widen_factor=10, dropRate=0.0)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10 with PyTorch.'''
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 |
10 | import torchvision
11 | import torchvision.transforms as transforms
12 |
13 | import os
14 | import argparse
15 |
16 | from tqdm import tqdm
17 | import random
18 | import numpy as np
19 | from models.wideresnet import *
20 |
# Command-line options for the training run.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--seed', default=11111, type=int)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--epsilon', default=8.0/255, type=float)
parser.add_argument('--m', default=8, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--resume', '-r', default=None, type=int, help='resume from checkpoint')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Seed every random number generator in use and ask cuDNN for deterministic
# kernels.
seed = args.seed
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    # Normalization messes with l-inf bounds.
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
# drop_last=True keeps every batch at exactly args.batch_size samples, which
# the fixed-shape `delta` tensor below relies on.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True)


print('==> Building model..')
net = WideResNet_28_10()
epsilon = args.epsilon
m = args.m
# Perturbation tensor that train() adds to every input batch and keeps
# updating; it starts at zero.
delta = torch.zeros(args.batch_size, 3, 32, 32)
delta = delta.to(device)
net = net.to(device)


if args.resume is not None:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can be resumed when CUDA is unavailable. load_state_dict copies the
    # weights into the model already placed on `device`, and the saved RNG
    # state stays a CPU tensor as torch.set_rng_state expects.
    checkpoint = torch.load('./checkpoint/ckpt.{}'.format(args.resume), map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch'] + 1
    torch.set_rng_state(checkpoint['rng_state'])

optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=args.momentum, weight_decay=args.weight_decay)
76 |
def train(epoch):
    '''Run one epoch of "free" adversarial training and save a checkpoint.

    Each minibatch is replayed `m` times. On every replay a single backward
    pass provides both the parameter gradients (consumed by optimizer.step())
    and the gradient w.r.t. the perturbed input, which is used to move the
    global perturbation `delta` by epsilon * sign(grad) before clamping it to
    [-epsilon, epsilon]. `delta` is carried over between minibatches and
    epochs.

    The reported training accuracy is computed from the outputs of the last
    replay of each minibatch. After the epoch, the model weights, accuracy,
    epoch index and CPU RNG state are written to ./checkpoint/ckpt.<epoch>.

    Args:
        epoch: index of the epoch being trained; used for logging and for the
            checkpoint file name.
    '''
    global delta

    print('\nEpoch: %d' % epoch)
    net.train()
    correct = 0
    total = 0
    iterator = tqdm(trainloader, ncols=0, leave=False)

    for inputs, targets in iterator:
        inputs, targets = inputs.to(device), targets.to(device)
        for _ in range(m):
            optimizer.zero_grad()
            # Fresh leaf tensor so that backward() populates adv.grad.
            adv = (inputs+delta).detach()
            adv.requires_grad_()
            outputs = net(adv)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()
            grad = adv.grad.detach()
            delta = delta.detach() + epsilon * torch.sign(grad)
            delta = torch.clamp(delta, -epsilon, epsilon)

        _, predicted = outputs.max(1)
        batch_hits = predicted.eq(targets).sum().item()
        total += targets.size(0)
        correct += batch_hits
        # Show the accuracy of the current batch on the progress bar.
        iterator.set_description(str(batch_hits/targets.size(0)))

    acc = 100.*correct/total
    print('Train acc:', acc)

    print('Saving..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/ckpt.{}'.format(epoch))
121 |
122 |
def adjust_learning_rate(optimizer, epoch):
    '''Set the step-schedule learning rate for `epoch` on every parameter group.

    The rate is 0.1 before epoch 12, 0.01 for epochs 12 to 21 and 0.001 from
    epoch 22 onwards.
    '''
    if epoch >= 22:
        lr = 0.001
    elif epoch >= 12:
        lr = 0.01
    elif epoch < 12:
        lr = 0.1
    for group in optimizer.param_groups:
        group['lr'] = lr
132 |
133 |
# Train from the starting (or resumed) epoch up to epoch 26 inclusive,
# refreshing the learning rate from the schedule before each epoch.
remaining_epochs = range(start_epoch, 27)
for epoch in remaining_epochs:
    adjust_learning_rate(optimizer, epoch)
    train(epoch)
137 |
--------------------------------------------------------------------------------