├── .gitattributes ├── README.md ├── materials └── fast_at.png ├── requirements.txt ├── resnet.py ├── run.sh ├── train_adv.py ├── train_adv_psgd.py ├── utils.py └── wideresnet.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Subspace Adversarial Training 2 | 3 | **Tao Li, Yingwen Wu, Sizhe Chen, Kun Fang and Xiaolin Huang** 4 | 5 | **Paper:** http://arxiv.org/abs/2111.12229 6 | 7 | **CVPR 2022 oral** 8 | 9 | ## Abstract 10 | 11 | Single-step adversarial training (AT) has received wide attention as it proved to be both efficient and robust. However, a serious problem of catastrophic overfitting exists, i.e., the robust accuracy against projected gradient descent (PGD) attack suddenly drops to 0% during the training. In this paper, we approach this problem from a novel perspective of optimization and firstly reveal the close link between the fast-growing gradient of each sample and overfitting, which can also be applied to understand robust overfitting in multi-step AT. To control the growth of the gradient, we propose a new AT method, Subspace Adversarial Training (Sub-AT), which constrains AT in a carefully extracted subspace. It successfully resolves both kinds of overfitting and significantly boosts the robustness. In subspace, we also allow single-step AT with larger steps and larger radius, further improving the robustness performance. As a result, we achieve state-of-the-art single-step AT performance. Without any regularization term, our single-step AT can reach over 51% robust accuracy against strong PGD-50 attack of radius 8/255 on CIFAR-10, reaching a competitive performance against standard multi-step PGD-10 AT with huge computational advantages. 12 | 13 | ![catastrophic overfitting in Fast AT](materials/fast_at.png) 14 | 15 | ## Dependencies 16 | 17 | Install the required dependencies: 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | We also evaluate the robustness with [Auto-Attack](https://github.com/fra31/auto-attack), which can be installed directly from its repository: 24 | 25 | ``` 26 | pip install git+https://github.com/fra31/auto-attack 27 | ``` 28 | 29 | 30 | 31 | ## How to run 32 | 33 | Sample usage is provided in `run.sh`: 34 | 35 | ``` 36 | bash run.sh 37 | ``` 38 | 39 | For Tiny-ImageNet experiments, please prepare the dataset under the path `datasets/tiny-imagenet-200/` first. 40 | 41 | For more detailed settings for the different datasets, please refer to the supplementary material.
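`run.sh` chains the two stages of Sub-AT: `train_adv.py` runs ordinary (single- or multi-step) AT and saves the parameter checkpoints used for DLDR, and `train_adv_psgd.py` extracts a low-dimensional subspace from those checkpoints with PCA and continues adversarial training with gradients projected onto it. With the default CIFAR-10 settings, `train_adv.py` saves one checkpoint before training and two per epoch (one mid-epoch, one at the end of the epoch), so the 65-epoch Fast-AT run yields the 131 checkpoints referenced by `--params_end 131`, and the 100-epoch PGD-AT run yields 201. The snippet below is a minimal, self-contained sketch of this subspace (P-SGD) update on a toy model; the helpers `flatten_params` and `set_grad_from_vec`, the toy model, and the random data are illustrative stand-ins, not part of the repository's API.

```
# Illustrative sketch of the Sub-AT update (the actual implementation is
# get_model_param_vec / update_grad / P_SGD in train_adv_psgd.py).
import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import PCA

def flatten_params(model):
    # concatenate all parameters into one flat numpy vector
    return np.concatenate([p.detach().cpu().numpy().reshape(-1) for p in model.parameters()])

def set_grad_from_vec(model, vec):
    # write a flat gradient vector back into the .grad fields
    idx = 0
    for p in model.parameters():
        n = p.grad.numel()
        p.grad.data = vec[idx:idx + n].reshape(p.grad.shape).clone()
        idx += n

torch.manual_seed(0)
model = nn.Linear(10, 2)          # toy stand-in for PreActResNet18
criterion = nn.CrossEntropyLoss()

# 1) DLDR sampling: checkpoints collected along a (pretend) standard AT run.
sampled_checkpoints = []
for _ in range(5):
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))   # stand-in for one epoch of training
    sampled_checkpoints.append(flatten_params(model))

# 2) DLDR extraction: PCA over the sampled parameter vectors spans the subspace.
pca = PCA(n_components=3)
pca.fit(np.array(sampled_checkpoints))
P = torch.from_numpy(pca.components_).float()    # (n_components, n_params)

# 3) Sub-AT step: back-propagate the (adversarial) loss, project the gradient
#    onto span(P), and take the SGD step with the projected gradient only.
optimizer = torch.optim.SGD(model.parameters(), lr=1.0, momentum=0.9)
x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))   # x would be an adversarial batch
loss = criterion(model(x), y)
optimizer.zero_grad()
loss.backward()
g = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
g_proj = P.t() @ (P @ g)                          # projection onto the extracted subspace
set_grad_from_vec(model, g_proj)
optimizer.step()
```

Constraining every update to the span of the extracted directions is how Sub-AT controls the gradient growth discussed in the abstract; the Sub-AT stage in `run.sh` accordingly uses a larger single-step radius and step size (`--train_eps 16 --train_gamma 20`).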
42 | 43 | 44 | ## Citation 45 | ``` 46 | @inproceedings{li2022subspace, 47 | title={Subspace Adversarial Training}, 48 | author={Li, Tao and Wu, Yingwen and Chen, Sizhe and Fang, Kun and Huang, Xiaolin}, 49 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 50 | pages={13409--13418}, 51 | year={2022} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /materials/fast_at.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/Sub-AT/9f3bcd55d447fd8648ea0a4aed55aa02d28d3c1b/materials/fast_at.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | torchvision>=0.6 3 | numpy>=1.21 4 | advertorch 5 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from advertorch.utils import NormalizeByChannelMeanStd 10 | 11 | __all__ = ['ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152'] 12 | 13 | class PreActBlock(nn.Module): 14 | '''Pre-activation version of the BasicBlock.''' 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(PreActBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(in_planes) 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 23 | 24 | if stride != 1 or in_planes != self.expansion*planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(x)) 31 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 32 | out = self.conv1(out) 33 | out = self.conv2(F.relu(self.bn2(out))) 34 | out += shortcut 35 | return out 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | class 
PreActResNet(nn.Module): 65 | def __init__(self, block, num_blocks, num_classes=100): 66 | super(PreActResNet, self).__init__() 67 | self.in_planes = 64 68 | 69 | # default normalization is for CIFAR10 70 | self.normalize = NormalizeByChannelMeanStd( 71 | mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]) 72 | 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.bn = nn.BatchNorm2d(512 * block.expansion) 79 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 80 | self.linear = nn.Linear(512*block.expansion, num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | x = self.normalize(x) 92 | out = self.conv1(x) 93 | out = self.layer1(out) 94 | out = self.layer2(out) 95 | out = self.layer3(out) 96 | out = self.layer4(out) 97 | out = F.relu(self.bn(out)) 98 | out = self.avgpool(out) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | return out 102 | 103 | def PreActResNet18(num_classes = 10): 104 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes) 105 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | datasets=CIFAR10 # choices: [CIFAR10, CIFAR100, TinyImagenet] 2 | eps=8 3 | seed=0 4 | device=0 5 | 6 | for model in PreActResNet18 7 | do 8 | # Fast AT (200 epochs) 9 | EXP=$model\_$datasets\_FastAT 10 | DST=new_results/$EXP 11 | CUDA_VISIBLE_DEVICES=$device python -u train_adv.py --pgd50 \ 12 | --datasets $datasets --attack Fast-AT --randomseed $seed \ 13 | --train_eps $eps --test_eps $eps --train_step 1 --test_step 20 \ 14 | --train_gamma 10 --test_gamma 2 --arch=$model \ 15 | --epochs=200 --save-dir=$DST/models --log-dir=$DST --EXP $EXP 16 | 17 | # Fast Sub-AT (DLDR: 65 epochs; Sub-AT: 40 epochs) 18 | # We suggest using a weight decay of 5e-4 instead of 1e-4 (our original setting) for better performance. 
[1] 19 | EXP=$model\_$datasets\_Fast_SubAT 20 | DST=new_results/$EXP 21 | CUDA_VISIBLE_DEVICES=$device python -u train_adv.py --pgd50 --wandb\ 22 | --datasets $datasets --attack Fast-AT --randomseed $seed \ 23 | --train_eps $eps --test_eps $eps --train_step 1 --test_step 20 \ 24 | --train_gamma 10 --test_gamma 2 --wd 0.0005 --arch=$model \ 25 | --epochs=65 --save-dir=$DST/models --log-dir=$DST --EXP $EXP 26 | 27 | CUDA_VISIBLE_DEVICES=$device python -u train_adv_psgd.py --autoattack \ 28 | --datasets $datasets --lr 1 --attack Fast-AT \ 29 | --train_eps 16 --test_eps $eps --train_step 1 --test_step 20 --train_gamma 20 --test_gamma 2 \ 30 | --params_start 0 --params_end 131 --batch-size 128 --n_components 80 \ 31 | --arch=$model --epochs=40 --save-dir=$DST/models --log-dir=$DST --log-name=PSGD 32 | 33 | # GAT experiments 34 | EXP=$model\_$datasets\_GAT_SubAT 35 | DST=new_results/$EXP 36 | CUDA_VISIBLE_DEVICES=$device python -u train_adv.py --pgd50 --evaluate \ 37 | --datasets $datasets --attack GAT \ 38 | --train_eps $eps --test_eps $eps --train_step 10 --test_step 20 \ 39 | --train_gamma 2 --test_gamma 2 --wd 0.0005 --arch=$model \ 40 | --epochs=200 --save-dir=$DST/models --log-dir=$DST 41 | 42 | CUDA_VISIBLE_DEVICES=$device python -u train_adv_psgd.py --pgd50 \ 43 | --datasets $datasets --lr 1 --attack GAT \ 44 | --train_eps 8 --test_eps $eps --train_step 1 --test_step 20 --train_gamma 10 --test_gamma 2 \ 45 | --params_start 0 --params_end 201 --batch-size 128 --n_components 100 \ 46 | --arch=$model --epochs=40 --save-dir=$DST/models --log-dir=$DST 47 | 48 | # PGD-AT (DLDR: 100 epochs; Sub-AT: 40 epochs) 49 | # We suggest use weight decay of 5e-4 instead of 1e-4 (our original setting) for better performance. [1] 50 | EXP=$model\_$datasets\_PGD_SubAT 51 | DST=new_results/$EXP 52 | CUDA_VISIBLE_DEVICES=$device python -u train_adv.py --pgd50 --wandb\ 53 | --datasets $datasets --attack PGD --randomseed $seed \ 54 | --train_eps $eps --test_eps $eps --train_step 10 --test_step 20 \ 55 | --train_gamma 2 --test_gamma 2 --wd 0.0005 --arch=$model \ 56 | --epochs=100 --save-dir=$DST/models --log-dir=$DST --EXP $EXP 57 | 58 | CUDA_VISIBLE_DEVICES=$device python -u train_adv_psgd.py --autoattack \ 59 | --datasets $datasets --lr 1 --attack PGD \ 60 | --train_eps 16 --test_eps $eps --train_step 1 --test_step 20 --train_gamma 20 --test_gamma 2 \ 61 | --params_start 0 --params_end 201 --batch-size 128 --n_components 100 \ 62 | --arch=$model --epochs=40 --save-dir=$DST/models --log-dir=$DST --log-name=PSGD 63 | 64 | done 65 | 66 | # [1] Pang et al., Bag of Tricks for Adversarial Training, ICLR 2021 -------------------------------------------------------------------------------- /train_adv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import random 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | 15 | from advertorch.attacks import LinfPGDAttack, L2PGDAttack 16 | from advertorch.context import ctx_noparamgrad 17 | from utils import Logger, set_seed, get_datasets, get_model, print_args, epoch_adversarial, epoch_adversarial_PGD50 18 | from utils import AutoAttack, Guided_Attack, grad_align_loss, trades_loss 19 | import wandb 20 | 21 | ########################## parse arguments ########################## 22 | parser = 
argparse.ArgumentParser(description='Adversarial Training') 23 | parser.add_argument('--EXP', metavar='EXP', default='EXP', help='experiment name') 24 | parser.add_argument('--arch', '-a', metavar='ARCH', default='PreActResNet18', 25 | help='model architecture (default: PreActResNet18)') 26 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 27 | help='training datasets') 28 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str, 29 | help='optimizer for training') 30 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 31 | help='number of data loading workers (default: 4)') 32 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 33 | help='number of total epochs to run') 34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 35 | help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=128, type=int, 37 | metavar='N', help='mini-batch size (default: 128)') 38 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 39 | metavar='LR', help='initial learning rate') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 43 | metavar='W', help='weight decay (default: 1e-4)') 44 | parser.add_argument('--print-freq', '-p', default=50, type=int, 45 | metavar='N', help='print frequency (default: 50 iterations)') 46 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 47 | help='path to latest checkpoint (default: none)') 48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 49 | help='evaluate model on validation set') 50 | parser.add_argument('--save-dir', dest='save_dir', 51 | help='The directory used to save the trained models', 52 | default='save_temp', type=str) 53 | parser.add_argument('--log-dir', dest='log_dir', 54 | help='The directory used to save the log', 55 | default='save_temp', type=str) 56 | parser.add_argument('--log-name', dest='log_name', 57 | help='The log file name', 58 | default='log', type=str) 59 | parser.add_argument('--randomseed', 60 | help='Randomseed for training and initialization', 61 | type=int, default=0) 62 | parser.add_argument('--wandb', action='store_true', help='use wandb for online visualization') 63 | parser.add_argument('--cyclic', action='store_true', help='use cyclic lr schedule (default: False)') 64 | parser.add_argument('--lr_max', '--learning-rate-max', default=0.3, type=float, 65 | metavar='cLR', help='maximum learning rate for cyclic learning rates') 66 | 67 | ########################## attack setting ########################## 68 | adversary_names = ['Fast-AT', 'PGD', 'gradalign', 'GAT', 'trades'] 69 | parser.add_argument('--attack', metavar='attack', default='Fast-AT', 70 | choices=adversary_names, 71 | help='adversary for genernating adversarial examples: ' + ' | '.join(adversary_names) + 72 | ' (default: Fast-AT)') 73 | 74 | # Fast-AT / PGD 75 | parser.add_argument('--norm', default='linf', type=str, help='linf or l2') 76 | parser.add_argument('--train_eps', default=8., type=float, help='epsilon of attack during training') 77 | parser.add_argument('--train_step', default=10, type=int, help='itertion number of attack during training') 78 | parser.add_argument('--train_gamma', default=2., type=float, help='step size of attack during training') 79 | parser.add_argument('--train_randinit', 
action='store_false', help='randinit usage flag (default: on)') 80 | parser.add_argument('--test_eps', default=8., type=float, help='epsilon of attack during testing') 81 | parser.add_argument('--test_step', default=20, type=int, help='itertion number of attack during testing') 82 | parser.add_argument('--test_gamma', default=2., type=float, help='step size of attack during testing') 83 | parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag (default: on)') 84 | 85 | # gradalign 86 | parser.add_argument('--gradalign_lambda', default=0.2, type=float, help='lambda for gradalign') 87 | # guideattack 88 | parser.add_argument('--GAT_lambda', default=10.0, type=float, help='lambda for GAT') 89 | # evaluate 90 | parser.add_argument('--pgd50', action='store_true', help='evaluate the model with pgd50 (default: False)') 91 | parser.add_argument('--autoattack', '--aa', action='store_true', help='evaluate the model with AA (default: False)') 92 | 93 | 94 | 95 | # Record training statistics 96 | train_robust_acc = [] 97 | val_robust_acc = [] 98 | train_robust_loss = [] 99 | val_robust_loss = [] 100 | test_natural_acc = [] 101 | test_natural_loss = [] 102 | arr_time = [] 103 | model_idx = 0 104 | 105 | def main(): 106 | 107 | global args, best_robust, model_idx 108 | global param_avg, train_loss, train_err, test_loss, test_err, arr_time, adv_acc 109 | 110 | args = parser.parse_args() 111 | 112 | # Check the save_dir exists or not 113 | print ('save dir:', args.save_dir) 114 | if not os.path.exists(args.save_dir): 115 | os.makedirs(args.save_dir) 116 | 117 | # Check the log_dir exists or not 118 | print ('log dir:', args.log_dir) 119 | if not os.path.exists(args.log_dir): 120 | os.makedirs(args.log_dir) 121 | 122 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name)) 123 | if args.wandb: 124 | print ('tracking with wandb!') 125 | wandb.init(project="Sub-AT", entity="nblt") 126 | date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 127 | wandb.run.name = args.EXP + date 128 | 129 | print_args(args) 130 | print ('random seed:', args.randomseed) 131 | set_seed(args.randomseed) 132 | 133 | args.train_eps /= 255. 134 | args.train_gamma /= 255. 135 | args.test_eps /= 255. 136 | args.test_gamma /= 255. 
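    # Note: eps and gamma are given on the 0-255 pixel scale on the command line and are
    # converted to the [0, 1] input scale above; e.g. the Fast-AT setting in run.sh
    # (--train_eps 8 --train_gamma 10) becomes eps = 8/255 ≈ 0.0314 with step size 10/255 ≈ 0.0392.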
137 | 138 | 139 | 140 | # Define model 141 | model = torch.nn.DataParallel(get_model(args)) 142 | model.cuda() 143 | 144 | cudnn.benchmark = True 145 | best_robust = 0 146 | 147 | # Prepare Dataloader 148 | train_loader, val_loader, test_loader = get_datasets(args) 149 | 150 | # Define loss function (criterion) and optimizer 151 | criterion = nn.CrossEntropyLoss().cuda() 152 | 153 | if args.optimizer == 'sgd': 154 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 155 | momentum=args.momentum, 156 | weight_decay=args.weight_decay) 157 | 158 | if args.cyclic: 159 | lr_scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr_max, steps_per_epoch=len(train_loader), epochs=30) 160 | else: 161 | if args.datasets == 'TinyImagenet': 162 | print ('TinyImagenet schedule: [50, 80]') 163 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 164 | milestones=[50, 80], last_epoch=args.start_epoch - 1) 165 | else: 166 | print ('default schedule: [100, 150]') 167 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 168 | milestones=[100, 150], last_epoch=args.start_epoch - 1) 169 | 170 | # Optionally resume from a checkpoint 171 | if args.resume: 172 | if os.path.isfile(args.resume): 173 | print("=> loading checkpoint '{}'".format(args.resume)) 174 | checkpoint = torch.load(args.resume) 175 | args.start_epoch = checkpoint['epoch'] 176 | print ('from ', args.start_epoch) 177 | best_robust = checkpoint['best_robust'] 178 | optimizer = checkpoint['optimizer'] 179 | model.load_state_dict(checkpoint['state_dict']) 180 | print("=> loaded checkpoint '{}' (epoch {})" 181 | .format(args.evaluate, checkpoint['epoch'])) 182 | else: 183 | print("=> no checkpoint found at '{}'".format(args.resume)) 184 | 185 | if args.evaluate: 186 | validate(test_loader, model, criterion) 187 | if args.pgd50: 188 | epoch_adversarial_PGD50(test_loader, model) 189 | if args.autoattack: 190 | AutoAttack(model, args, dataset=args.datasets) 191 | return 192 | 193 | is_best = 0 194 | print ('Start training: ', args.start_epoch, '->', args.epochs) 195 | 196 | print ('adversary:', args.attack) 197 | if args.attack == 'gradalign': 198 | arr_lambda = [0, 0.03, 0.04, 0.05, 0.06, 0.08, 0.11, 0.15, 0.20, 0.27, 0.36, 0.47, 0.63, 0.84, 1.12, 1.50, 2.00] 199 | args.gradalign_lambda = arr_lambda[int(args.train_eps)] 200 | print ('gradalign lambda:', args.gradalign_lambda) 201 | if args.attack == 'GAT': 202 | print ('GAT_lambda:', args.GAT_lambda) 203 | 204 | # DLDR sampling 205 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(model_idx) + '.pt')) 206 | model_idx += 1 207 | 208 | nat_last5 = [] 209 | rob_last5 = [] 210 | 211 | for epoch in range(args.start_epoch, args.epochs): 212 | 213 | # train for one epoch 214 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 215 | train(train_loader, model, criterion, optimizer, lr_scheduler, epoch) 216 | 217 | # step learning rates 218 | if not args.cyclic: 219 | lr_scheduler.step() 220 | 221 | # evaluate on validation set 222 | natural_acc = validate(test_loader, model, criterion) 223 | 224 | # evaluate the adversarial robustness on validation set 225 | robust_acc, adv_loss = epoch_adversarial(val_loader, model, args) 226 | val_robust_acc.append(robust_acc) 227 | val_robust_loss.append(adv_loss) 228 | print ('adv acc on validation set', robust_acc) 229 | 230 | # remember best prec@1 and save checkpoint 231 | is_best = robust_acc > best_robust 232 | best_robust = max(robust_acc, best_robust) 233 | 234 | if args.wandb: 235 | 
wandb.log({"test natural acc": natural_acc}) 236 | wandb.log({"test robust acc": robust_acc}) 237 | 238 | if epoch + 5 >= args.epochs: 239 | nat_last5.append(natural_acc) 240 | robust_acc, adv_loss = epoch_adversarial(test_loader, model, args) 241 | print ('adv acc on test set', robust_acc) 242 | rob_last5.append(robust_acc) 243 | 244 | if is_best: 245 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'best.pt')) 246 | 247 | save_checkpoint({ 248 | 'state_dict': model.state_dict(), 249 | 'best_robust': best_robust, 250 | 'epochs': epoch, 251 | 'optimizer': optimizer 252 | }, filename=os.path.join(args.save_dir, 'model.th')) 253 | 254 | # DLDR sampling 255 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(model_idx) + '.pt')) 256 | model_idx += 1 257 | 258 | print ('train_robust_acc: ', train_robust_acc) 259 | print ('train_robust_loss: ', train_robust_loss) 260 | print ('val_robust_acc: ', val_robust_acc) 261 | print ('val_robust_loss: ', val_robust_loss) 262 | print ('test_natural_acc: ', test_natural_acc) 263 | print ('test_natural_loss: ', test_natural_loss) 264 | print ('total training time: ', np.sum(arr_time)) 265 | print ('last 5 adv acc on test dataset:', np.mean(rob_last5)) 266 | print ('last 5 nat acc on test dataset:', np.mean(nat_last5)) 267 | 268 | print ('final:') 269 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'final.pt')) 270 | if args.pgd50: 271 | epoch_adversarial_PGD50(test_loader, model) 272 | if args.autoattack: 273 | AutoAttack(model, args, dataset=args.datasets) 274 | 275 | 276 | print ('best:') 277 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'best.pt'))) 278 | robust_acc, adv_loss = epoch_adversarial(test_loader, model, args) 279 | print ('best adv acc on test dataset:', robust_acc) 280 | 281 | if args.pgd50: 282 | epoch_adversarial_PGD50(test_loader, model) 283 | if args.autoattack: 284 | AutoAttack(model, args, dataset=args.datasets) 285 | validate(test_loader, model, criterion) 286 | 287 | def train(train_loader, model, criterion, optimizer, lr_scheduler, epoch): 288 | """ 289 | Run one train epoch 290 | """ 291 | global train_robust_acc, train_robust_loss, arr_time, args, model_idx 292 | 293 | batch_time = AverageMeter() 294 | data_time = AverageMeter() 295 | losses = AverageMeter() 296 | top1 = AverageMeter() 297 | 298 | if args.norm == 'linf': 299 | adversary = LinfPGDAttack( 300 | model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma, 301 | rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False 302 | ) 303 | elif args.norm == 'l2': 304 | adversary = L2PGDAttack( 305 | model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma, 306 | rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False 307 | ) 308 | 309 | # switch to train mode 310 | model.train() 311 | 312 | end = time.time() 313 | 314 | for i, (input, target) in enumerate(train_loader): 315 | 316 | # measure data loading time 317 | data_time.update(time.time() - end) 318 | 319 | target = target.cuda() 320 | input_var = input.cuda() 321 | target_var = target 322 | 323 | if args.attack == 'trades': 324 | # calculate robust loss 325 | output, loss = trades_loss(model=model, 326 | x_natural=input_var, 327 | y=target_var, 328 | optimizer=optimizer, 329 | step_size=args.train_gamma, 330 | epsilon=args.train_eps, 331 | perturb_steps=args.train_step, 332 | beta=6.0) 333 | 334 | elif args.attack == 'GAT': 335 | out = model(input_var) 336 
| P_out = nn.Softmax(dim=1)(out) 337 | 338 | input_adv = input_var + ((4./255.0)*torch.sign(torch.tensor([0.5]).cuda() - torch.rand_like(input_var).cuda()).cuda()) 339 | input_adv = torch.clamp(input_adv,0.0,1.0) 340 | 341 | model.eval() 342 | input_adv = Guided_Attack(model,nn.CrossEntropyLoss(),input_adv,target,eps=8./255.0,steps=1,P_out=P_out,l2_reg=10,alt=(i%2)) 343 | 344 | delta = input_adv - input_var 345 | delta = torch.clamp(delta,-8.0/255.0,8.0/255) 346 | input_adv = input_var+delta 347 | input_adv = torch.clamp(input_adv,0.0,1.0) 348 | 349 | model.train() 350 | adv_out = model(input_adv) 351 | out = model(input_var) 352 | 353 | output = adv_out 354 | 355 | Q_out = nn.Softmax(dim=1)(adv_out) 356 | P_out = nn.Softmax(dim=1)(out) 357 | 358 | '''LOSS COMPUTATION''' 359 | 360 | closs = criterion(out, target_var) 361 | 362 | reg_loss = ((P_out - Q_out)**2.0).sum(1).mean(0) 363 | 364 | loss = 1.0*closs + args.GAT_lambda*reg_loss 365 | 366 | else: 367 | # adv samples 368 | with ctx_noparamgrad(model): 369 | input_adv = adversary.perturb(input_var, target_var) 370 | 371 | # compute output 372 | output = model(input_adv) 373 | loss = criterion(output, target_var) 374 | 375 | if args.attack == 'gradalign': 376 | loss += grad_align_loss(model, input_var, target_var, args) 377 | 378 | # compute gradient and do SGD step 379 | optimizer.zero_grad() 380 | loss.backward() 381 | 382 | optimizer.step() 383 | output = output.float() 384 | loss = loss.float() 385 | 386 | # cyclic learning rates 387 | if args.cyclic: 388 | lr_scheduler.step() 389 | 390 | # measure accuracy and record loss 391 | prec1 = accuracy(output.data, target)[0] 392 | losses.update(loss.item(), input.size(0)) 393 | top1.update(prec1.item(), input.size(0)) 394 | 395 | # measure elapsed time 396 | batch_time.update(time.time() - end) 397 | end = time.time() 398 | 399 | if i == 200 or i == 400 or i == 550: 400 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(model_idx) + '.pt')) 401 | model_idx += 1 402 | 403 | if i % args.print_freq == 0: 404 | print('Epoch: [{0}][{1}/{2}]\t' 405 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 406 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 407 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 408 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 409 | epoch, i, len(train_loader), batch_time=batch_time, 410 | data_time=data_time, loss=losses, top1=top1)) 411 | 412 | print ('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum)) 413 | 414 | train_robust_loss.append(losses.avg) 415 | train_robust_acc.append(top1.avg) 416 | if args.wandb: 417 | wandb.log({"train robust acc": top1.avg}) 418 | wandb.log({"train robust loss": losses.avg}) 419 | arr_time.append(batch_time.sum) 420 | 421 | def validate(val_loader, model, criterion): 422 | """ 423 | Run evaluation 424 | """ 425 | global test_natural_acc, test_natural_loss 426 | 427 | batch_time = AverageMeter() 428 | losses = AverageMeter() 429 | top1 = AverageMeter() 430 | 431 | # switch to evaluate mode 432 | model.eval() 433 | 434 | end = time.time() 435 | with torch.no_grad(): 436 | for i, (input, target) in enumerate(val_loader): 437 | target = target.cuda() 438 | input_var = input.cuda() 439 | target_var = target.cuda() 440 | 441 | 442 | # compute output 443 | output = model(input_var) 444 | loss = criterion(output, target_var) 445 | 446 | output = output.float() 447 | loss = loss.float() 448 | 449 | # measure accuracy and record loss 450 | prec1 = accuracy(output.data, target)[0] 451 | 
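            # accumulate running averages of the loss and top-1 accuracy, weighted by batch size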
losses.update(loss.item(), input.size(0)) 452 | top1.update(prec1.item(), input.size(0)) 453 | 454 | # measure elapsed time 455 | batch_time.update(time.time() - end) 456 | end = time.time() 457 | 458 | if i % args.print_freq == 0: 459 | print('Test: [{0}/{1}]\t' 460 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 461 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 462 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 463 | i, len(val_loader), batch_time=batch_time, loss=losses, 464 | top1=top1)) 465 | 466 | print(' * Prec@1 {top1.avg:.3f}' 467 | .format(top1=top1)) 468 | 469 | test_natural_loss.append(losses.avg) 470 | test_natural_acc.append(top1.avg) 471 | return top1.avg 472 | 473 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 474 | """ 475 | Save the training model 476 | """ 477 | torch.save(state, filename) 478 | 479 | class AverageMeter(object): 480 | """Computes and stores the average and current value""" 481 | def __init__(self): 482 | self.reset() 483 | 484 | def reset(self): 485 | self.val = 0 486 | self.avg = 0 487 | self.sum = 0 488 | self.count = 0 489 | 490 | def update(self, val, n=1): 491 | self.val = val 492 | self.sum += val * n 493 | self.count += n 494 | self.avg = self.sum / self.count 495 | 496 | 497 | def accuracy(output, target, topk=(1,)): 498 | """Computes the precision@k for the specified values of k""" 499 | maxk = max(topk) 500 | batch_size = target.size(0) 501 | 502 | _, pred = output.topk(maxk, 1, True, True) 503 | pred = pred.t() 504 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 505 | 506 | res = [] 507 | for k in topk: 508 | correct_k = correct[:k].view(-1).float().sum(0) 509 | res.append(correct_k.mul_(100.0 / batch_size)) 510 | return res 511 | 512 | 513 | if __name__ == '__main__': 514 | main() 515 | -------------------------------------------------------------------------------- /train_adv_psgd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | 13 | from sklearn.decomposition import PCA 14 | import numpy as np 15 | import random 16 | from utils import Logger, set_seed, get_datasets, get_model, print_args, epoch_adversarial, epoch_adversarial_PGD50 17 | from utils import AutoAttack, Guided_Attack, grad_align_loss, trades_loss 18 | import wandb 19 | 20 | from advertorch.attacks import LinfPGDAttack, L2PGDAttack 21 | from advertorch.context import ctx_noparamgrad 22 | 23 | parser = argparse.ArgumentParser(description='Subspace Adversarial Training') 24 | parser.add_argument('--EXP', metavar='EXP', default='EXP', help='experiment name') 25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='PreActResNet18', 26 | help='model architecture (default: PreActResNet18)') 27 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 28 | help='The training datasets') 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=128, type=int, 36 | 
metavar='N', help='mini-batch size (default: 128)') 37 | parser.add_argument('--weight-decay', '--wd', default=0.0005, type=float, 38 | metavar='W', help='weight decay (default: 1e-4)') 39 | parser.add_argument('--print-freq', '-p', default=50, type=int, 40 | metavar='N', help='print frequency (default: 50)') 41 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 42 | help='path to latest checkpoint (default: none)') 43 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 44 | help='evaluate model on validation set') 45 | parser.add_argument('--save-dir', dest='save_dir', 46 | help='The directory used to save the trained models', 47 | default='save_temp', type=str) 48 | parser.add_argument('--log-dir', dest='log_dir', 49 | help='The directory used to save the log', 50 | default='save_temp', type=str) 51 | parser.add_argument('--log-name', dest='log_name', 52 | help='The log file name', 53 | default='log', type=str) 54 | parser.add_argument('--wandb', action='store_true', help='use wandb for online visualization') 55 | parser.add_argument('--randomseed', 56 | help='Randomseed for initialization and training', 57 | type=int, default=0) 58 | 59 | ########################## Sub-AT setting ########################## 60 | parser.add_argument('--n_components', default=80, type=int, metavar='N', 61 | help='n_components for PCA') 62 | parser.add_argument('--params_start', default=0, type=int, metavar='N', 63 | help='which idx starts for sampling') 64 | parser.add_argument('--params_end', default=81, type=int, metavar='N', 65 | help='which idx ends for sampling') 66 | parser.add_argument('--skip', action='store_true', help='skip for DLDR sampling') 67 | parser.add_argument('--lr', default=1, type=float, metavar='N', 68 | help='lr for Sub-AT') 69 | 70 | ########################## attack setting ########################## 71 | adversary_names = ['Fast-AT', 'PGD', 'gradalign', 'GAT', 'trades'] 72 | parser.add_argument('--attack', metavar='attack', default='Fast-AT', 73 | choices=adversary_names, 74 | help='adversary for genernating adversarial examples: ' + ' | '.join(adversary_names) + 75 | ' (default: Fast-AT)') 76 | 77 | # Fast-AT / PGD 78 | parser.add_argument('--norm', default='linf', type=str, help='linf or l2') 79 | parser.add_argument('--train_eps', default=8., type=float, help='epsilon of attack during training') 80 | parser.add_argument('--train_step', default=10, type=int, help='itertion number of attack during training') 81 | parser.add_argument('--train_gamma', default=2., type=float, help='step size of attack during training') 82 | parser.add_argument('--train_randinit', action='store_false', help='randinit usage flag (default: on)') 83 | parser.add_argument('--test_eps', default=8., type=float, help='epsilon of attack during testing') 84 | parser.add_argument('--test_step', default=20, type=int, help='itertion number of attack during testing') 85 | parser.add_argument('--test_gamma', default=2., type=float, help='step size of attack during testing') 86 | parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag (default: on)') 87 | 88 | # gradalign 89 | parser.add_argument('--gradalign_lambda', default=0.2, type=float, help='lambda for gradalign') 90 | # guideattack 91 | parser.add_argument('--GAT_lambda', default=10.0, type=float, help='lambda for GAT') 92 | # evaluate 93 | parser.add_argument('--pgd50', action='store_true', help='evaluate the model with pgd50 (default: False)') 94 | 
parser.add_argument('--autoattack', '--aa', action='store_true', help='evaluate the model with AA (default: False)') 95 | 96 | args = parser.parse_args() 97 | 98 | # Check the save_dir exists or not 99 | print ('save dir:', args.save_dir) 100 | if not os.path.exists(args.save_dir): 101 | os.makedirs(args.save_dir) 102 | 103 | # Check the log_dir exists or not 104 | print ('log dir:', args.log_dir) 105 | if not os.path.exists(args.log_dir): 106 | os.makedirs(args.log_dir) 107 | 108 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name)) 109 | if args.wandb: 110 | print ('tracking with wandb!') 111 | wandb.init(project="Sub-AT", entity="nblt") 112 | date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 113 | wandb.run.name = args.EXP + date 114 | 115 | set_seed(args.randomseed) 116 | print_args(args) 117 | 118 | train_eps = args.train_eps 119 | args.train_eps /= 255. 120 | args.train_gamma /= 255. 121 | args.test_eps /= 255. 122 | args.test_gamma /= 255. 123 | 124 | best_robust = 0 125 | P = None 126 | 127 | train_robust_acc = [] 128 | val_robust_acc = [] 129 | train_robust_loss = [] 130 | val_robust_loss = [] 131 | test_natural_acc = [] 132 | test_natural_loss = [] 133 | arr_time = [] 134 | 135 | def get_model_param_vec(model): 136 | """ 137 | Return model parameters as a vector 138 | """ 139 | vec = [] 140 | for name,param in model.named_parameters(): 141 | vec.append(param.detach().cpu().numpy().reshape(-1)) 142 | return np.concatenate(vec, 0) 143 | 144 | def get_model_grad_vec(model): 145 | # Return the model grad as a vector 146 | 147 | vec = [] 148 | for name,param in model.named_parameters(): 149 | vec.append(param.grad.detach().reshape(-1)) 150 | return torch.cat(vec, 0) 151 | 152 | def update_grad(model, grad_vec): 153 | """ 154 | Update model grad 155 | """ 156 | idx = 0 157 | for name,param in model.named_parameters(): 158 | arr_shape = param.grad.shape 159 | size = arr_shape.numel() 160 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone() 161 | idx += size 162 | 163 | def main(): 164 | global args, best_robust, P, arr_time, train_eps 165 | 166 | # Define model 167 | model = torch.nn.DataParallel(get_model(args)) 168 | model.cuda() 169 | 170 | cudnn.benchmark = True 171 | 172 | 173 | ################################ DLDR ####################################### 174 | # Load sampled model parameters 175 | print ('params: from', args.params_start, 'to', args.params_end) 176 | W = [] 177 | for i in range(args.params_start, args.params_end): 178 | if args.skip and i % 2 != 0: continue 179 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(i) + '.pt'))) 180 | W.append(get_model_param_vec(model)) 181 | W = np.array(W) 182 | print ('W:', W.shape) 183 | 184 | # Obtain base variables through PCA 185 | pca = PCA(n_components=args.n_components) 186 | pca.fit_transform(W) 187 | P = np.array(pca.components_) 188 | print ('ratio:', pca.explained_variance_ratio_) 189 | print ('P:', P.shape) 190 | 191 | P = torch.from_numpy(P).cuda() 192 | 193 | # Resume from params_start 194 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.params_start) + '.pt'))) 195 | 196 | # Prepare Dataloader 197 | train_loader, val_loader, test_loader = get_datasets(args) 198 | 199 | # Define loss function (criterion) and optimizer 200 | criterion = nn.CrossEntropyLoss().cuda() 201 | 202 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) 203 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 204 | milestones=[40], 
last_epoch=args.start_epoch - 1) 205 | 206 | # Optionally resume from a checkpoint 207 | if args.resume: 208 | if os.path.isfile(args.resume): 209 | print("=> loading checkpoint '{}'".format(args.resume)) 210 | checkpoint = torch.load(args.resume) 211 | args.start_epoch = checkpoint['epoch'] 212 | print ('from ', args.start_epoch) 213 | best_robust = checkpoint['best_robust'] 214 | optimizer = checkpoint['optimizer'] 215 | model.load_state_dict(checkpoint['state_dict']) 216 | print("=> loaded checkpoint '{}' (epoch {})" 217 | .format(args.evaluate, checkpoint['epoch'])) 218 | else: 219 | print("=> no checkpoint found at '{}'".format(args.resume)) 220 | 221 | if args.evaluate: 222 | validate(test_loader, model, criterion) 223 | if args.pgd50: 224 | epoch_adversarial_PGD50(test_loader, model) 225 | if args.autoattack: 226 | AutoAttack(model, args, dataset=args.datasets) 227 | return 228 | 229 | end = time.time() 230 | 231 | ################################ Sub-AT ####################################### 232 | is_best = 0 233 | print ('Start training: ', args.start_epoch, '->', args.epochs) 234 | 235 | print ('adversary:', args.attack) 236 | if args.attack == 'gradalign': 237 | arr_lambda = [0, 0.03, 0.04, 0.05, 0.06, 0.08, 0.11, 0.15, 0.20, 0.27, 0.36, 0.47, 0.63, 0.84, 1.12, 1.50, 2.00] 238 | args.gradalign_lambda = arr_lambda[int(args.train_eps)] 239 | print ('gradalign lambda:', args.gradalign_lambda) 240 | if args.attack == 'GAT': 241 | print ('GAT_lambda:', args.GAT_lambda) 242 | 243 | nat_last5 = [] 244 | rob_last5 = [] 245 | 246 | for epoch in range(args.start_epoch, args.epochs): 247 | 248 | # train for one epoch 249 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 250 | train(train_loader, model, criterion, optimizer, epoch) 251 | lr_scheduler.step() 252 | 253 | # evaluate on validation set 254 | natural_acc = validate(test_loader, model, criterion) 255 | # validate(train_loader, model, criterion) 256 | 257 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'train' + str(train_eps) + 'psgd_' + str(epoch) + '.pt')) 258 | 259 | # evaluate the adversarial robustness on validation set 260 | robust_acc, adv_loss = epoch_adversarial(val_loader, model, args) 261 | val_robust_acc.append(robust_acc) 262 | val_robust_loss.append(adv_loss) 263 | print ('adv acc on validation set', robust_acc) 264 | 265 | # remember best prec@1 and save checkpoint 266 | is_best = robust_acc > best_robust 267 | best_robust = max(robust_acc, best_robust) 268 | 269 | if args.wandb: 270 | wandb.log({"test natural acc": natural_acc}) 271 | wandb.log({"test robust acc": robust_acc}) 272 | 273 | if epoch + 5 >= args.epochs: 274 | nat_last5.append(natural_acc) 275 | robust_acc, adv_loss = epoch_adversarial(test_loader, model, args) 276 | print ('adv acc on test set', robust_acc) 277 | rob_last5.append(robust_acc) 278 | 279 | if is_best: 280 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'train_eps' + str(train_eps) + 'psgd_best.pt')) 281 | 282 | save_checkpoint({ 283 | 'state_dict': model.state_dict(), 284 | 'best_robust': best_robust, 285 | 'epochs': epoch, 286 | 'optimizer': optimizer 287 | }, filename=os.path.join(args.save_dir, 'psgd_model.th')) 288 | 289 | print ('train_robust_acc: ', train_robust_acc) 290 | print ('train_robust_loss: ', train_robust_loss) 291 | print ('val_robust_acc: ', val_robust_acc) 292 | print ('val_robust_loss: ', val_robust_loss) 293 | print ('test_natural_acc: ', test_natural_acc) 294 | print ('test_natural_loss: ', test_natural_loss) 295 | print 
('total training time: ', np.sum(arr_time)) 296 | print ('last 5 adv acc on test dataset:', np.mean(rob_last5)) 297 | print ('last 5 nat acc on test dataset:', np.mean(nat_last5)) 298 | del P 299 | torch.cuda.empty_cache() 300 | 301 | print ('final:') 302 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'train_eps' + str(train_eps) + 'psgd_final.pt')) 303 | if args.pgd50: 304 | epoch_adversarial_PGD50(test_loader, model) 305 | if args.autoattack: 306 | AutoAttack(model, args, dataset=args.datasets) 307 | 308 | print ('best:') 309 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'train_eps' + str(train_eps) + 'psgd_best.pt'))) 310 | robust_acc, adv_loss = epoch_adversarial(test_loader, model, args) 311 | print ('best adv acc on test dataset:', robust_acc) 312 | 313 | if args.pgd50: 314 | epoch_adversarial_PGD50(test_loader, model) 315 | if args.autoattack: 316 | AutoAttack(model, args, dataset=args.datasets) 317 | validate(test_loader, model, criterion) 318 | 319 | 320 | def train(train_loader, model, criterion, optimizer, epoch): 321 | # Run one train epoch 322 | 323 | global P, train_robust_acc, train_robust_loss, args, arr_time 324 | 325 | batch_time = AverageMeter() 326 | data_time = AverageMeter() 327 | losses = AverageMeter() 328 | top1 = AverageMeter() 329 | 330 | if args.norm == 'linf': 331 | adversary = LinfPGDAttack( 332 | model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma, 333 | rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False 334 | ) 335 | elif args.norm == 'l2': 336 | adversary = L2PGDAttack( 337 | model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma, 338 | rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False 339 | ) 340 | 341 | # Switch to train mode 342 | model.train() 343 | 344 | end = time.time() 345 | for i, (input, target) in enumerate(train_loader): 346 | 347 | # Measure data loading time 348 | data_time.update(time.time() - end) 349 | 350 | # Load batch data to cuda 351 | target = target.cuda() 352 | input_var = input.cuda() 353 | target_var = target 354 | 355 | if args.attack == 'trades': 356 | # calculate robust loss 357 | output, loss = trades_loss(model=model, 358 | x_natural=input_var, 359 | y=target_var, 360 | optimizer=optimizer, 361 | step_size=args.train_gamma, 362 | epsilon=args.train_eps, 363 | perturb_steps=args.train_step, 364 | beta=6.0) 365 | 366 | elif args.attack == 'GAT': 367 | eps = args.train_eps 368 | out = model(input_var) 369 | P_out = nn.Softmax(dim=1)(out) 370 | 371 | # input_adv = input_var + ((4./255.0)*torch.sign(torch.tensor([0.5]).cuda() - torch.rand_like(input_var).cuda()).cuda()) 372 | input_adv = input_var + (eps/2*torch.sign(torch.tensor([0.5]).cuda() - torch.rand_like(input_var).cuda()).cuda()) 373 | input_adv = torch.clamp(input_adv,0.0,1.0) 374 | 375 | model.eval() 376 | input_adv = Guided_Attack(model,nn.CrossEntropyLoss(),input_adv,target,eps=eps,steps=1,P_out=P_out,l2_reg=10,alt=(i%2)) 377 | 378 | delta = input_adv - input_var 379 | delta = torch.clamp(delta,-eps,eps) 380 | input_adv = input_var+delta 381 | input_adv = torch.clamp(input_adv,0.0,1.0) 382 | 383 | model.train() 384 | adv_out = model(input_adv) 385 | out = model(input_var) 386 | 387 | output = adv_out 388 | 389 | Q_out = nn.Softmax(dim=1)(adv_out) 390 | P_out = nn.Softmax(dim=1)(out) 391 | 392 | '''LOSS COMPUTATION''' 393 | 394 | closs = criterion(out, target_var) 395 | 396 | reg_loss = ((P_out - 
Q_out)**2.0).sum(1).mean(0) 397 | 398 | loss = 1.0*closs + 10.0*reg_loss 399 | 400 | else: 401 | #adv samples 402 | with ctx_noparamgrad(model): 403 | input_adv = adversary.perturb(input_var, target_var) 404 | 405 | # Compute output 406 | output = model(input_adv) 407 | loss = criterion(output, target_var) 408 | 409 | if args.attack == 'gradalign': 410 | loss += grad_align_loss(model, input_var, target_var, args) 411 | 412 | # Compute gradient and do SGD step 413 | optimizer.zero_grad() 414 | loss.backward() 415 | 416 | # Do P_SGD update 417 | gk = get_model_grad_vec(model) 418 | P_SGD(model, optimizer, gk) 419 | 420 | # Measure accuracy and record loss 421 | prec1 = accuracy(output.data, target)[0] 422 | losses.update(loss.item(), input.size(0)) 423 | top1.update(prec1.item(), input.size(0)) 424 | 425 | # Measure elapsed time 426 | batch_time.update(time.time() - end) 427 | end = time.time() 428 | 429 | if i % args.print_freq == 0 or i == len(train_loader)-1: 430 | print('Epoch: [{0}][{1}/{2}]\t' 431 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 432 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 433 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 434 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 435 | epoch, i, len(train_loader), batch_time=batch_time, 436 | data_time=data_time, loss=losses, top1=top1)) 437 | 438 | train_robust_loss.append(losses.avg) 439 | train_robust_acc.append(top1.avg) 440 | if args.wandb: 441 | wandb.log({"train robust acc": top1.avg}) 442 | wandb.log({"train robust loss": losses.avg}) 443 | arr_time.append(batch_time.sum) 444 | 445 | 446 | def P_SGD(model, optimizer, grad): 447 | # Project the gradient onto the subspace and do SGD step 448 | 449 | gk = torch.mm(P, grad.reshape(-1,1)) 450 | 451 | grad_proj = torch.mm(P.transpose(0, 1), gk) 452 | 453 | # Update the model grad and do a step 454 | update_grad(model, grad_proj) 455 | optimizer.step() 456 | 457 | def validate(val_loader, model, criterion): 458 | # Run evaluation 459 | 460 | global test_natural_acc, test_natural_loss 461 | 462 | batch_time = AverageMeter() 463 | losses = AverageMeter() 464 | top1 = AverageMeter() 465 | 466 | # Switch to evaluate mode 467 | model.eval() 468 | 469 | end = time.time() 470 | with torch.no_grad(): 471 | for i, (input, target) in enumerate(val_loader): 472 | target = target.cuda() 473 | input_var = input.cuda() 474 | target_var = target.cuda() 475 | 476 | # Compute output 477 | output = model(input_var) 478 | loss = criterion(output, target_var) 479 | 480 | output = output.float() 481 | loss = loss.float() 482 | 483 | # Measure accuracy and record loss 484 | prec1 = accuracy(output.data, target)[0] 485 | losses.update(loss.item(), input.size(0)) 486 | top1.update(prec1.item(), input.size(0)) 487 | 488 | # Measure elapsed time 489 | batch_time.update(time.time() - end) 490 | end = time.time() 491 | 492 | if i % args.print_freq == 0: 493 | print('Test: [{0}/{1}]\t' 494 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 495 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 496 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 497 | i, len(val_loader), batch_time=batch_time, loss=losses, 498 | top1=top1)) 499 | 500 | print(' * Prec@1 {top1.avg:.3f}' 501 | .format(top1=top1)) 502 | 503 | # Store the test loss and test accuracy 504 | test_natural_loss.append(losses.avg) 505 | test_natural_acc.append(top1.avg) 506 | 507 | return top1.avg 508 | 509 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 510 | # Save the training model 511 | 512 | torch.save(state, 
filename) 513 | 514 | class AverageMeter(object): 515 | # Computes and stores the average and current value 516 | 517 | def __init__(self): 518 | self.reset() 519 | 520 | def reset(self): 521 | self.val = 0 522 | self.avg = 0 523 | self.sum = 0 524 | self.count = 0 525 | 526 | def update(self, val, n=1): 527 | self.val = val 528 | self.sum += val * n 529 | self.count += n 530 | self.avg = self.sum / self.count 531 | 532 | 533 | def accuracy(output, target, topk=(1,)): 534 | # Computes the precision@k for the specified values of k 535 | 536 | maxk = max(topk) 537 | batch_size = target.size(0) 538 | 539 | _, pred = output.topk(maxk, 1, True, True) 540 | pred = pred.t() 541 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 542 | 543 | res = [] 544 | for k in topk: 545 | correct_k = correct[:k].view(-1).float().sum(0) 546 | res.append(correct_k.mul_(100.0 / batch_size)) 547 | return res 548 | 549 | if __name__ == '__main__': 550 | main() 551 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import sys 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | from torch.autograd.gradcheck import zero_gradients 14 | import torch.nn.functional as F 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | from torch.utils.data import DataLoader, Subset 18 | from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder 19 | from torch.autograd import Variable 20 | import torch.autograd as autograd 21 | 22 | 23 | from sklearn.decomposition import PCA 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | from numpy import linalg as LA 27 | import pickle 28 | import random 29 | 30 | from advertorch.attacks import LinfPGDAttack, L2PGDAttack 31 | from advertorch.context import ctx_noparamgrad 32 | from advertorch.utils import NormalizeByChannelMeanStd 33 | 34 | def get_model_param_vec(model): 35 | """Return model parameters as a vector""" 36 | vec = [] 37 | for name,param in model.named_parameters(): 38 | vec.append(param.detach().cpu().numpy().reshape(-1)) 39 | return np.concatenate(vec, 0) 40 | 41 | class Logger(object): 42 | def __init__(self,fileN ="Default.log"): 43 | self.terminal = sys.stdout 44 | self.log = open(fileN,"a") 45 | 46 | def write(self,message): 47 | self.terminal.write(message) 48 | self.log.write(message) 49 | 50 | def flush(self): 51 | pass 52 | 53 | def set_seed(seed=1): 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | 61 | def print_args(args): 62 | print ('batch size:', args.batch_size) 63 | print ('Attack Norm:', args.norm) 64 | print ('train eps: {} train step: {} train gamma: {} train randinit: {}'.format(args.train_eps, args.train_step, args.train_gamma, args.train_randinit)) 65 | print ('test eps: {} test step: {} test gamma: {} test randinit: {}'.format(args.test_eps, args.test_step, args.test_gamma, args.test_randinit)) 66 | print ('Model:', args.arch) 67 | print ('Dataset:', args.datasets) 68 | 69 | 70 | 71 | ################################ PGD attack ####################################### 72 | def epoch_adversarial(loader, model, 
args, opt=None, **kwargs): 73 | """Adversarial training/evaluation epoch over the dataset""" 74 | total_loss, total_err = 0.,0. 75 | 76 | if args.norm == 'linf': 77 | adversary = LinfPGDAttack( 78 | model, loss_fn=nn.CrossEntropyLoss(), eps=args.test_eps, nb_iter=args.test_step, eps_iter=args.test_gamma, 79 | rand_init=args.test_randinit, clip_min=0.0, clip_max=1.0, targeted=False 80 | ) 81 | elif args.norm == 'l2': 82 | adversary = L2PGDAttack( 83 | model, loss_fn=nn.CrossEntropyLoss(), eps=args.test_eps, nb_iter=args.test_step, eps_iter=args.test_gamma, 84 | rand_init=args.test_randinit, clip_min=0.0, clip_max=1.0, targeted=False 85 | ) 86 | 87 | model.eval() 88 | for X,y in loader: 89 | X,y = X.cuda(), y.cuda() 90 | # delta = attack(model, X, y, **kwargs) 91 | #adv samples 92 | input_adv = adversary.perturb(X, y) 93 | yp = model(input_adv) 94 | loss = nn.CrossEntropyLoss()(yp,y) 95 | 96 | total_err += (yp.max(dim=1)[1] != y).sum().item() 97 | total_loss += loss.item() * X.shape[0] 98 | 99 | return 1 - total_err / len(loader.dataset), total_loss / len(loader.dataset) 100 | 101 | def clamp(X, lower_limit, upper_limit): 102 | return torch.max(torch.min(X, upper_limit), lower_limit) 103 | 104 | def attack_pgd(model, X, y, epsilon=8./255, alpha=2./255, attack_iters=50, restarts=10, 105 | norm='l_inf', early_stop=False, 106 | mixup=False, y_a=None, y_b=None, lam=None): 107 | upper_limit, lower_limit = 1,0 108 | max_loss = torch.zeros(y.shape[0]).cuda() 109 | max_delta = torch.zeros_like(X).cuda() 110 | for _ in range(restarts): 111 | delta = torch.zeros_like(X).cuda() 112 | if norm == "l_inf": 113 | delta.uniform_(-epsilon, epsilon) 114 | elif norm == "l_2": 115 | delta.uniform_(-0.5,0.5).renorm(p=2, dim=1, maxnorm=epsilon) 116 | else: 117 | raise ValueError 118 | delta = clamp(delta, lower_limit-X, upper_limit-X) 119 | delta.requires_grad = True 120 | for _ in range(attack_iters): 121 | output = model(X + delta) 122 | if early_stop: 123 | index = torch.where(output.max(1)[1] == y)[0] 124 | else: 125 | index = slice(None,None,None) 126 | if not isinstance(index, slice) and len(index) == 0: 127 | break 128 | loss = F.cross_entropy(output, y) 129 | loss.backward() 130 | grad = delta.grad.detach() 131 | d = delta[index, :, :, :] 132 | g = grad[index, :, :, :] 133 | x = X[index, :, :, :] 134 | if norm == "l_inf": 135 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 136 | elif norm == "l_2": 137 | g_norm = torch.norm(g.view(g.shape[0],-1),dim=1).view(-1,1,1,1) 138 | scaled_g = g/(g_norm + 1e-10) 139 | d = (d + scaled_g*alpha).view(d.size(0),-1).renorm(p=2,dim=0,maxnorm=epsilon).view_as(d) 140 | d = clamp(d, lower_limit - x, upper_limit - x) 141 | delta.data[index, :, :, :] = d 142 | delta.grad.zero_() 143 | 144 | all_loss = F.cross_entropy(model(X+delta), y, reduction='none') 145 | max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss] 146 | max_loss = torch.max(max_loss, all_loss) 147 | return max_delta 148 | 149 | def epoch_adversarial_PGD50(test_loader, model, epsilon=8): 150 | model.eval() 151 | 152 | total_err, total_loss = 0, 0 153 | for X,y in test_loader: 154 | X,y = X.cuda(), y.cuda() 155 | #adv samples 156 | delta = attack_pgd(model, X, y, epsilon=epsilon/255) 157 | yp = model(X + delta) 158 | loss = nn.CrossEntropyLoss()(yp,y) 159 | 160 | total_err += (yp.max(dim=1)[1] != y).sum().item() 161 | total_loss += loss.item() * X.shape[0] 162 | 163 | print ('PGD50: ', 1 - total_err / len(test_loader.dataset)) 164 | 165 | return 1 - total_err / 
len(test_loader.dataset) 166 | 167 | ################################ datasets ####################################### 168 | def cifar10_dataloaders(batch_size=128, num_workers=2, data_dir='datasets/cifar10'): 169 | 170 | train_transform = transforms.Compose([ 171 | transforms.RandomCrop(32, padding=4), 172 | transforms.RandomHorizontalFlip(), 173 | transforms.ToTensor(), 174 | ]) 175 | 176 | test_transform = transforms.Compose([ 177 | transforms.ToTensor(), 178 | ]) 179 | 180 | train_set = Subset(CIFAR10(data_dir, train=True, transform=train_transform, download=True), list(range(45000))) 181 | val_set = Subset(CIFAR10(data_dir, train=True, transform=test_transform, download=True), list(range(45000, 50000))) 182 | test_set = CIFAR10(data_dir, train=False, transform=test_transform, download=True) 183 | 184 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 185 | val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 186 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 187 | 188 | return train_loader, val_loader, test_loader 189 | 190 | def cifar100_dataloaders(batch_size=128, num_workers=2, data_dir='datasets/cifar100'): 191 | 192 | train_transform = transforms.Compose([ 193 | transforms.RandomCrop(32, padding=4), 194 | transforms.RandomHorizontalFlip(), 195 | transforms.RandomRotation(15), 196 | transforms.ToTensor(), 197 | ]) 198 | 199 | test_transform = transforms.Compose([ 200 | transforms.ToTensor(), 201 | ]) 202 | 203 | train_set = Subset(CIFAR100(data_dir, train=True, transform=train_transform, download=True), list(range(45000))) 204 | val_set = Subset(CIFAR100(data_dir, train=True, transform=test_transform, download=True), list(range(45000, 50000))) 205 | test_set = CIFAR100(data_dir, train=False, transform=test_transform, download=True) 206 | 207 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 208 | val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 209 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 210 | 211 | return train_loader, val_loader, test_loader 212 | 213 | def tiny_imagenet_dataloaders(batch_size=128, num_workers=2, data_dir = 'datasets/tiny-imagenet-200', permutation_seed=10): 214 | 215 | train_transform = transforms.Compose([ 216 | transforms.RandomCrop(64, padding=4), 217 | transforms.RandomHorizontalFlip(), 218 | transforms.ToTensor(), 219 | ]) 220 | 221 | test_transform = transforms.Compose([ 222 | transforms.ToTensor(), 223 | ]) 224 | 225 | train_path = os.path.join(data_dir, 'train') 226 | val_path = os.path.join(data_dir, 'val') 227 | 228 | np.random.seed(permutation_seed) 229 | split_permutation = list(np.random.permutation(100000)) 230 | 231 | train_set = Subset(ImageFolder(train_path, transform=train_transform), split_permutation[:90000]) 232 | val_set = Subset(ImageFolder(train_path, transform=test_transform), split_permutation[90000:]) 233 | test_set = ImageFolder(val_path, transform=test_transform) 234 | 235 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 236 | val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 237 | test_loader = 
DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 238 | 239 | return train_loader, val_loader, test_loader 240 | 241 | def get_datasets(args): 242 | if args.datasets == 'CIFAR10': 243 | return cifar10_dataloaders(batch_size=args.batch_size, num_workers=args.workers) 244 | 245 | elif args.datasets == 'CIFAR100': 246 | return cifar100_dataloaders(batch_size=args.batch_size, num_workers=args.workers) 247 | 248 | elif args.datasets == 'TinyImagenet': 249 | return tiny_imagenet_dataloaders(batch_size=args.batch_size, num_workers=args.workers, permutation_seed=args.randomseed) 250 | 251 | def get_model(args): 252 | if args.datasets == 'CIFAR10': 253 | num_class = 10 254 | dataset_normalization = NormalizeByChannelMeanStd( 255 | mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]) 256 | 257 | elif args.datasets == 'CIFAR100': 258 | num_class = 100 259 | dataset_normalization = NormalizeByChannelMeanStd( 260 | mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762]) 261 | 262 | elif args.datasets == 'TinyImagenet': 263 | num_class = 200 264 | dataset_normalization = NormalizeByChannelMeanStd( 265 | mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262]) 266 | 267 | if args.arch == 'PreActResNet18': 268 | import resnet 269 | net = resnet.__dict__[args.arch](num_classes=num_class) 270 | 271 | elif args.arch == 'WideResNet': 272 | import wideresnet 273 | net = wideresnet.__dict__[args.arch](28, num_classes=num_class, widen_factor=10, dropRate=0.0) 274 | 275 | net.normalize = dataset_normalization 276 | 277 | return net 278 | 279 | ################################ GradAlign loss ####################################### 280 | def get_uniform_delta(shape, eps, requires_grad=True): 281 | delta = torch.zeros(shape).cuda() 282 | delta.uniform_(-eps, eps) 283 | delta.requires_grad = requires_grad 284 | return delta 285 | 286 | def get_input_grad(model, X, y, eps, delta_init='none', backprop=False): 287 | if delta_init == 'none': 288 | delta = torch.zeros_like(X, requires_grad=True) 289 | elif delta_init == 'random_uniform': 290 | delta = get_uniform_delta(X.shape, eps, requires_grad=True) 291 | elif delta_init == 'random_corner': 292 | delta = get_uniform_delta(X.shape, eps, requires_grad=True) 293 | delta = eps * torch.sign(delta) 294 | else: 295 | raise ValueError('wrong delta init') 296 | 297 | output = model(X + delta) 298 | loss = F.cross_entropy(output, y) 299 | 300 | grad = torch.autograd.grad(loss, delta, create_graph=True if backprop else False)[0] 301 | 302 | if not backprop: 303 | grad, delta = grad.detach(), delta.detach() 304 | return grad 305 | 306 | def grad_align_loss(model, X, y, args): 307 | 308 | grad1 = get_input_grad(model, X, y, args.train_eps, delta_init='none', backprop=False) 309 | grad2 = get_input_grad(model, X, y, args.train_eps, delta_init='random_uniform', backprop=True) 310 | grad1, grad2 = grad1.reshape(len(grad1), -1), grad2.reshape(len(grad2), -1) 311 | cos = torch.nn.functional.cosine_similarity(grad1, grad2, 1) 312 | reg = args.gradalign_lambda * (1.0 - cos.mean()) 313 | 314 | return reg 315 | 316 | ################################ TRADES loss ####################################### 317 | def squared_l2_norm(x): 318 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 319 | return (flattened ** 2).sum(1) 320 | 321 | 322 | def l2_norm(x): 323 | return squared_l2_norm(x).sqrt() 324 | 325 | 326 | def trades_loss(model, 327 | x_natural, 328 | y, 329 | optimizer, 330 | step_size=0.003, 331 | epsilon=0.031, 332 | 
perturb_steps=10, 333 | beta=1.0, 334 | distance='l_inf'): 335 | # define KL-loss 336 | criterion_kl = nn.KLDivLoss(size_average=False) 337 | model.eval() 338 | batch_size = len(x_natural) 339 | # generate adversarial example 340 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 341 | if distance == 'l_inf': 342 | for _ in range(perturb_steps): 343 | x_adv.requires_grad_() 344 | with torch.enable_grad(): 345 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), 346 | F.softmax(model(x_natural), dim=1)) 347 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 348 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 349 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 350 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 351 | elif distance == 'l_2': 352 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 353 | delta = Variable(delta.data, requires_grad=True) 354 | 355 | # Setup optimizers 356 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 357 | 358 | for _ in range(perturb_steps): 359 | adv = x_natural + delta 360 | 361 | # optimize 362 | optimizer_delta.zero_grad() 363 | with torch.enable_grad(): 364 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), 365 | F.softmax(model(x_natural), dim=1)) 366 | loss.backward() 367 | # renorming gradient 368 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 369 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 370 | # avoid nan or inf if gradient is 0 371 | if (grad_norms == 0).any(): 372 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 373 | optimizer_delta.step() 374 | 375 | # projection 376 | delta.data.add_(x_natural) 377 | delta.data.clamp_(0, 1).sub_(x_natural) 378 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 379 | x_adv = Variable(x_natural + delta, requires_grad=False) 380 | else: 381 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 382 | model.train() 383 | 384 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 385 | # zero gradient 386 | optimizer.zero_grad() 387 | # calculate robust loss 388 | logits = model(x_natural) 389 | loss_natural = F.cross_entropy(logits, y) 390 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1), 391 | F.softmax(model(x_natural), dim=1)) 392 | loss = loss_natural + beta * loss_robust 393 | return model(x_adv), loss 394 | 395 | 396 | ################################ AutoAttack ####################################### 397 | def AutoAttack(model, args, dataset='CIFAR10', norm='Linf', epsilon=8): 398 | model.eval() 399 | print ('evaluate AA:', dataset, norm, epsilon) 400 | epsilon /= 255. 
401 | if dataset == 'CIFAR10': 402 | __, __, test_loader = cifar10_dataloaders(batch_size=1000) 403 | elif dataset == 'CIFAR100': 404 | __, __, test_loader = cifar100_dataloaders(batch_size=1000) 405 | 406 | l = [x for (x, y) in test_loader] 407 | x_test = torch.cat(l, 0) 408 | l = [y for (x, y) in test_loader] 409 | y_test = torch.cat(l, 0) 410 | 411 | # load attack 412 | from autoattack import AutoAttack 413 | adversary = AutoAttack(model, norm=norm, eps=epsilon, log_path=os.path.join(args.save_dir, 'log_file.txt'), 414 | version='standard') 415 | 416 | # run attack and save images 417 | with torch.no_grad(): 418 | adv_complete = adversary.run_standard_evaluation(x_test, y_test, bs=1000) 419 | 420 | ################################ Guided_Attack ####################################### 421 | def Guided_Attack(model,loss,image,target,eps=8/255,bounds=[0,1],steps=1,P_out=[],l2_reg=10,alt=1): 422 | tar = Variable(target.cuda()) 423 | img = image.cuda() 424 | eps = eps/steps 425 | for step in range(steps): 426 | img = Variable(img,requires_grad=True) 427 | zero_gradients(img) 428 | out = model(img) 429 | R_out = nn.Softmax(dim=1)(out) 430 | cost = loss(out,tar) + alt*l2_reg*(((P_out - R_out)**2.0).sum(1)).mean(0) 431 | cost.backward() 432 | per = eps * torch.sign(img.grad.data) 433 | adv = img.data + per.cuda() 434 | img = torch.clamp(adv,bounds[0],bounds[1]) 435 | return img -------------------------------------------------------------------------------- /wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from advertorch.utils import NormalizeByChannelMeanStd 6 | 7 | __all__ = ['WideResNet'] 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 11 | super(BasicBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.relu1 = nn.ReLU(inplace=True) 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(out_planes) 17 | self.relu2 = nn.ReLU(inplace=True) 18 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 19 | padding=1, bias=False) 20 | self.droprate = dropRate 21 | self.equalInOut = (in_planes == out_planes) 22 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 23 | padding=0, bias=False) or None 24 | def forward(self, x): 25 | if not self.equalInOut: 26 | x = self.relu1(self.bn1(x)) 27 | else: 28 | out = self.relu1(self.bn1(x)) 29 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 30 | if self.droprate > 0: 31 | out = F.dropout(out, p=self.droprate, training=self.training) 32 | out = self.conv2(out) 33 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 40 | layers = [] 41 | for i in range(int(nb_layers)): 42 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 43 | return nn.Sequential(*layers) 44 | def forward(self, x): 45 | return self.layer(x) 46 | 47 | class 
WideResNet(nn.Module): 48 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 49 | super(WideResNet, self).__init__() 50 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 51 | assert((depth - 4) % 6 == 0) 52 | n = (depth - 4) / 6 53 | block = BasicBlock 54 | 55 | # default normalization is for CIFAR10 56 | self.normalize = NormalizeByChannelMeanStd( 57 | mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]) 58 | 59 | # 1st conv before any network block 60 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 61 | padding=1, bias=False) 62 | # 1st block 63 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 77 | elif isinstance(m, nn.BatchNorm2d): 78 | m.weight.data.fill_(1) 79 | m.bias.data.zero_() 80 | elif isinstance(m, nn.Linear): 81 | m.bias.data.zero_() 82 | 83 | def forward(self, x): 84 | x = self.normalize(x) 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) --------------------------------------------------------------------------------
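
For reference, here is a minimal evaluation sketch built only from the helpers defined in `utils.py` above. It is a hypothetical script: the checkpoint path `checkpoint.pth`, the CIFAR-10 + `WideResNet` choice, and the PGD-20 test settings in the `Namespace` are illustrative assumptions, not settings taken from the paper.

```
# evaluate.py -- hypothetical evaluation sketch using the helpers from utils.py
import argparse
import os
import torch
from utils import get_model, get_datasets, epoch_adversarial, epoch_adversarial_PGD50, AutoAttack

args = argparse.Namespace(
    datasets='CIFAR10', arch='WideResNet',             # names expected by get_model / get_datasets
    batch_size=128, workers=2, randomseed=1,
    norm='linf',                                       # epoch_adversarial then builds a LinfPGDAttack
    test_eps=8/255, test_step=20, test_gamma=2/255,    # assumed PGD-20 test-time attack settings
    test_randinit=True,
    save_dir='results/eval',                           # AutoAttack logs to save_dir/log_file.txt
)
os.makedirs(args.save_dir, exist_ok=True)

model = get_model(args).cuda()
model.load_state_dict(torch.load('checkpoint.pth'))   # hypothetical checkpoint path / format
model.eval()

_, _, test_loader = get_datasets(args)

pgd20_acc, _ = epoch_adversarial(test_loader, model, args)           # advertorch PGD with args.test_* settings
pgd50_acc = epoch_adversarial_PGD50(test_loader, model, epsilon=8)   # PGD-50 with 10 restarts, radius 8/255
print('PGD-20 acc: {:.4f}  PGD-50 acc: {:.4f}'.format(pgd20_acc, pgd50_acc))

AutoAttack(model, args, dataset='CIFAR10', norm='Linf', epsilon=8)   # requires the auto-attack package
```

Note the two epsilon conventions in `utils.py`: `epoch_adversarial` reads a fractional radius from `args.test_eps` (e.g. `8/255`), whereas `epoch_adversarial_PGD50` and `AutoAttack` take the radius in pixel units (`epsilon=8`) and divide by 255 internally.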