├── .gitattributes ├── .DS_Store ├── README.md ├── test_worst_acc.py ├── utils.py ├── models └── wideresnet.py └── main.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlanChou/Adversarial-Training-for-Free/HEAD/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Training for Free 2 | 3 | This is an unofficial PyTorch implementation of the paper "Adversarial Training for Free!".
4 | https://arxiv.org/pdf/1904.12843.pdf
5 | It's a really helpful technique which can significantly accelerate adversarial training.
6 | It only contains the code for adversarial training on CIFAR-10. However, one can easily modify it for other datasets. 7 | 8 | I've noticed that the cleverhans repo has an attribute called BACKPROP_THROUGH_ATTACK, which is exactly the same idea as the one behind this paper. 9 | 10 | ## Overview 11 | 12 | This repository contains two scripts: one for adversarial training and one for testing. Note that the model used here is not WideResNet 32-10, which is used in the paper 'Adversarial Training for Free'. I use WideResNet 28-10, which is used in the original PGD paper. Be aware that some of the hyperparameters are slightly different from the paper (weight decay and learning rate scheduling). 13 | 14 | 15 | 16 | ## Accuracy (under PGD attack: epsilon = 8/255, step size = 2/255, iterations = 100) 17 | | Model | Acc | 18 | | ---------------------------| ----------- | 19 | | [ WideResNet 28-10 ] | 46.93% | 20 | 21 | I did not test every epoch's checkpoint. I simply chose one of the last epochs to test. Results might be slightly different. I've also released the checkpoint at the Google Drive link below.
22 | I have had trouble training ResNet56 and ResNet20 to the baseline they are supposed to reach, which may suggest that this method does not work for every model. Please don't hesitate to share your results if you know how to fix this. 23 | `checkpoint` [Google Drive](https://drive.google.com/file/d/1iZ52Ctcwty8bLMvLJJMlWcHL-__lJcbo/view?usp=sharing) 24 | 25 | ## Dependencies 26 | The repository is based on Python 3.5, with the main dependency being PyTorch==1.0.0. Additional dependencies for running experiments are: numpy, torchvision, tqdm, argparse, os, random, advertorch. 27 | 28 | Advertorch can be installed from https://github.com/BorealisAI/advertorch
29 | Run the code with the command:
30 | ``` 31 | $ CUDA_VISIBLE_DEVICES=0 python3 main.py 32 | ``` 33 | 34 | 35 | -------------------------------------------------------------------------------- /test_worst_acc.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | import os 13 | import argparse 14 | 15 | from tqdm import tqdm 16 | import numpy as np 17 | from models.wideresnet import * 18 | from utils import * 19 | from advertorch.attacks import LinfPGDAttack 20 | 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Testing') 23 | parser.add_argument('--seed', default=11111, type=int, help='seed') 24 | parser.add_argument('--epoch', default=0, type=int, help='load checkpoint from that epoch') 25 | parser.add_argument('--model', default='wideresnet', type=str) 26 | parser.add_argument('--batch_size', default=100, type=int) 27 | parser.add_argument('--iteration', default=100, type=int) 28 | parser.add_argument('--epsilon', default=8./255, type=float) 29 | parser.add_argument('--step_size', default=2./255, type=float) 30 | 31 | 32 | args = parser.parse_args() 33 | 34 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 35 | 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | torch.cuda.manual_seed(args.seed) 39 | cudnn.benchmark = False 40 | cudnn.deterministic = True 41 | 42 | # Data 43 | print('==> Preparing data..') 44 | transform_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | ]) 47 | 48 | 49 | testset = torchvision.datasets.CIFAR10(root='/home/hsinpingchou/data', train=False, download=True, transform=transform_test) 50 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 
def test(epoch):
    """Measure the accuracy of ``net`` on adversarial examples of the test set.

    Reads the module-level ``net``, ``testloader``, ``device``, ``adversary``
    and ``args``.

    Args:
        epoch: Not used by the body. The script calls ``test(adversary)``,
            but the attack is taken from the module-level ``adversary``.

    Returns:
        The adversarial test accuracy in percent (it is also printed).
    """
    net.eval()
    correct = 0
    total = 0
    # No parameter gradients are needed while evaluating.
    with torch.no_grad():
        iterator = tqdm(testloader, ncols=0, leave=False)
        for batch_idx, (inputs, targets) in enumerate(iterator):
            inputs, targets = inputs.to(device), targets.to(device)
            # The attack differentiates w.r.t. the inputs, so autograd is
            # re-enabled only while the adversarial batch is crafted.
            with torch.enable_grad():
                adv = adversary.perturb(inputs, targets)
            outputs = net(adv)
            _, predicted = outputs.max(1)
            # Reduce once per batch; each ``.item()`` forces a device sync.
            batch_correct = predicted.eq(targets).sum().item()
            batch_total = targets.size(0)
            total += batch_total
            correct += batch_correct
            iterator.set_description(str(batch_correct/batch_total))

    # Report the accuracy of the evaluated checkpoint.
    acc = 100.*correct/total
    print('Test Acc of ckpt.{}: {}'.format(args.epoch, acc))
    return acc
def get_mean_and_std(dataset):
    '''Estimate per-channel mean and std of a 3-channel image dataset.

    Samples are loaded one at a time; the per-sample channel mean and
    channel std are summed and divided by ``len(dataset)``. The returned
    std is therefore the average of per-sample stds, not the std over the
    pooled pixels of the whole dataset.

    Args:
        dataset: Dataset yielding ``(inputs, targets)`` pairs accepted by
            ``torch.utils.data.DataLoader``; inputs are indexed as
            ``[:, channel, :, :]`` after batching.

    Returns:
        Tuple ``(mean, std)`` of two tensors with 3 elements each.
    '''
    # Imported locally: the module header only binds ``nn`` and ``init``,
    # so the bare name ``torch`` would otherwise raise NameError here.
    import torch

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Re-initialise the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Layers created with ``bias=False`` are supported; their missing bias is
    skipped.

    Args:
        net: Module whose sub-modules (``net.modules()``) are initialised.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # ``if m.bias:`` would ask for the truth value of a
            # multi-element tensor and raise; test for presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
def format_time(seconds):
    '''Render a duration in seconds as a short string such as ``1h1m``.

    The duration is split into days, hours, minutes, whole seconds and
    milliseconds; only the two largest non-zero units are kept. ``'0ms'``
    is returned when no unit is positive.
    '''
    remaining = seconds

    days = int(remaining / 3600/24)
    remaining = remaining - days*3600*24
    hours = int(remaining / 3600)
    remaining = remaining - hours*3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes*60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    pieces = []
    for amount, suffix in units:
        # Keep at most the two most significant non-zero units.
        if amount > 0 and len(pieces) < 2:
            pieces.append(str(amount) + suffix)

    if not pieces:
        return '0ms'
    return ''.join(pieces)
class BasicBlock(nn.Module):
    """Pre-activation residual block: (BN, ReLU, 3x3 conv) twice plus a shortcut.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution (which is created only when the channel count changes).
        dropRate: Dropout probability applied before the second convolution;
            ``0.0`` disables dropout.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # Identity shortcut when channels match, 1x1 projection otherwise.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # The identity shortcut uses the raw input; the projection shortcut
        # is applied to the pre-activated input.
        if self.equalInOut:
            shortcut = x
        else:
            shortcut = self.convShortcut(activated)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """A stage made of ``nb_layers`` consecutive blocks.

    Only the first block receives ``in_planes`` and ``stride``; the remaining
    blocks map ``out_planes`` to ``out_planes`` with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the ``nn.Sequential`` holding the stage's blocks."""
        layers = []
        # ``int`` keeps callers that pass a float layer count working.
        for i in range(int(nb_layers)):
            is_first = (i == 0)
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network with three stages of ``BasicBlock``s.

    Args:
        depth: Total depth; must satisfy ``(depth - 4) % 6 == 0``.
        num_classes: Size of the final linear layer's output.
        widen_factor: Multiplier applied to the stage widths 16, 32 and 64.
        dropRate: Dropout probability forwarded to every block.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth must be of the form 6n + 4'
        # Floor division keeps the per-stage block count an integer.
        blocks_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(blocks_per_stage, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(blocks_per_stage, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(blocks_per_stage, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with variance 2 / (kernel area * out_channels).
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return the class logits for a batch of images."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Two stride-2 stages reduce a 32x32 input to 8x8, which this
        # pooling window collapses to a single value per channel.
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)


def WideResNet_28_10():
    """Build a depth-28, widen-factor-10 WideResNet with 10 output classes."""
    return WideResNet(28, 10, 10, 0.0)


def WideResNet_28_10_cifar100():
    """Build a depth-28, widen-factor-10 WideResNet with 100 output classes."""
    return WideResNet(28, 100, 10, 0.0)
36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | np.random.seed(seed) 39 | torch.backends.cudnn.deterministic=True 40 | torch.backends.cudnn.benchmark = False 41 | 42 | # Data 43 | print('==> Preparing data..') 44 | transform_train = transforms.Compose([ 45 | transforms.RandomCrop(32, padding=4), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 49 | # Normalization messes with l-inf bounds. 50 | ]) 51 | 52 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 53 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True) 54 | 55 | 56 | 57 | print('==> Building model..') 58 | net = WideResNet_28_10() 59 | epsilon = args.epsilon 60 | m = args.m 61 | delta = torch.zeros(args.batch_size, 3, 32, 32) 62 | delta = delta.to(device) 63 | net = net.to(device) 64 | 65 | 66 | if args.resume is not None: 67 | # Load checkpoint. 68 | print('==> Resuming from checkpoint..') 69 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
def adjust_learning_rate(optimizer, epoch):
    """Set the step-schedule learning rate for ``epoch`` on ``optimizer``.

    The rate is 0.1 before epoch 12, 0.01 for epochs 12 to 21 and 0.001
    from epoch 22 onwards; it is written to every entry of
    ``optimizer.param_groups``.
    """
    if epoch >= 22:
        new_lr = 0.001
    elif epoch >= 12:
        new_lr = 0.01
    else:
        new_lr = 0.1

    for group in optimizer.param_groups:
        group['lr'] = new_lr
adjust_learning_rate(optimizer, epoch) 136 | train(epoch) 137 | --------------------------------------------------------------------------------