├── LICENSE
├── README.md
├── configs.yml
├── init_paths.py
├── lib
│   ├── __init__.py
│   ├── utils.py
│   └── validation.py
├── main_free.py
└── requirements.txt

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Ali Shafahi, Mahyar Najibi, Amin Ghiasi, Zheng Xu, John Dickerson, Christoph Studer, Larry S. Davis, Gavin Taylor, Tom Goldstein

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


##############################################################################
THIRD-PARTY SOFTWARE LICENSES

From PyTorch:

Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions from Caffe:
Copyright (c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright (c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Free Adversarial Training
This is a PyTorch implementation of the [Adversarial Training for Free!](https://arxiv.org/abs/1904.12843 "Free Adversarial Training") paper.
The official TensorFlow implementation can be found [here](https://github.com/ashafahi/free_adv_train).

Using the Free Adversarial Training (Free-m) algorithm, we can train robust models at no additional cost compared to natural training. This allows us to train robust ImageNet models using only a few GPUs in a couple of days! Below is the performance of various Free-trained ImageNet models where we vary the replay parameter (m).

*(Figure: performance of Free-trained ImageNet models for different values of the replay parameter m.)*
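
In pseudocode form, one Free-m training step replays the same minibatch m times, and each replay recycles the input gradient from the weight-update backward pass as a "free" FGSM step on a perturbation that persists across minibatches. Below is a minimal sketch, simplified from `train()` in `main_free.py` (the function name and signature are illustrative, and input normalization is omitted):

```python
import torch

def free_train_step(model, criterion, optimizer, x, y, noise, step_size, eps, m):
    """One Free-m step on a minibatch (x, y): m replays, each taking one SGD
    step on the weights and one 'free' FGSM step on the persistent noise."""
    for _ in range(m):
        delta = noise[: x.size(0)].clone().requires_grad_(True)
        loss = criterion(model(torch.clamp(x + delta, 0.0, 1.0)), y)
        optimizer.zero_grad()
        loss.backward()  # one backward pass yields weight AND input gradients
        # Ascend on the input gradient, keeping the noise in the L-inf eps-ball
        noise[: x.size(0)] += step_size * delta.grad.sign()
        noise.clamp_(-eps, eps)
        optimizer.step()  # descend on the weight gradient
```

Because the perturbation update reuses a gradient that natural training computes anyway, each replay costs roughly the same as a normal training step.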

This repository provides code for training and evaluating the models on the ImageNet dataset.
The implementation is adapted from the [official PyTorch ImageNet example](https://github.com/pytorch/examples/blob/master/imagenet).

## Installation
1. Install [PyTorch](https://pytorch.org).
2. Install the required Python packages. All packages can be installed by running the following command:
```bash
pip install -r requirements.txt
```
3. Download and prepare the ImageNet dataset. You can use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh),
provided by the PyTorch repository, to move the validation subset into labeled subfolders (run it from inside the `val` directory).

## Training a model
To train a robust model, run the following command:

```bash
python main_free.py [PATH_TO_IMAGENET_ROOT]
```
This trains a robust model with the default parameters. The training parameters can be set by changing the `configs.yml` config file.
Please run `python main_free.py --help` to see the list of possible arguments.
The script saves the trained models into the `trained_models` folder and the logs into the `output` folder.


## Evaluating a trained model
You can evaluate a trained model by running the following command:
```bash
python main_free.py [PATH_TO_IMAGENET_ROOT] -e --resume [PATH_TO_THE_MODEL_CHECKPOINT]
```
The script evaluates the model on clean examples as well as on examples generated by PGD attacks with different parameters.
The attack parameters can be set by changing the `configs.yml` file, as in the example below.
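
For example, to additionally evaluate against a 20-iteration PGD attack with a 2/255 step size, the `pgd_attack` list in `configs.yml` could be extended as follows; each entry is an `[iterations, step_size]` pair, and the third tuple is an illustrative addition:

```yaml
ADV:
  # PGD attack parameters used during validation;
  # the same clip_eps as in training is used as the attack bound
  pgd_attack:
    - !!python/tuple [10, 0.00392156862] # [10 iters, 1.0/255.0]
    - !!python/tuple [50, 0.00392156862] # [50 iters, 1.0/255.0]
    - !!python/tuple [20, 0.00784313725] # [20 iters, 2.0/255.0] (illustrative)
```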

## Citing
If you find the paper or the code useful for your study, please consider citing the free training paper:
```bibtex
@article{2019arXiv190412843S,
  author = {{Shafahi}, A. and {Najibi}, M. and {Ghiasi}, A. and {Xu}, Z. and
            {Dickerson}, J. and {Studer}, C. and {Davis}, L. and {Taylor}, G.
            and {Goldstein}, T.},
  title = "{Adversarial Training for Free!}",
  journal = {ArXiv e-prints},
  year = 2019
}
```

--------------------------------------------------------------------------------
/configs.yml:
--------------------------------------------------------------------------------
TRAIN:
  # Number of training epochs
  epochs: 90

  # Architecture name, see the torchvision models package for
  # a list of possible architectures
  arch: 'resnet50'

  # Starting epoch
  start_epoch: 0

  # SGD parameters
  lr: 0.1
  momentum: 0.9
  weight_decay: 0.0001

  # Print frequency, used for both training and testing
  print_freq: 10

  # Dataset mean and std used for data normalization
  mean: !!python/tuple [0.485, 0.456, 0.406]
  std: !!python/tuple [0.229, 0.224, 0.225]

ADV:
  # FGSM parameters during training (pixel units; scaled by 1/max_color_value at runtime)
  clip_eps: 4.0
  fgsm_step: 4.0

  # Number of repeats for free adversarial training
  n_repeats: 4

  # PGD attack parameters used during validation;
  # the same clip_eps as above is used for PGD
  pgd_attack:
    - !!python/tuple [10, 0.00392156862] # [10 iters, 1.0/255.0]
    - !!python/tuple [50, 0.00392156862] # [50 iters, 1.0/255.0]

DATA:
  # Number of data workers
  workers: 4

  # Training batch size
  batch_size: 256

  # Image size
  img_size: 256

  # Crop size for data augmentation
  crop_size: 224

  # Color value range
  max_color_value: 255.0

--------------------------------------------------------------------------------
/init_paths.py:
--------------------------------------------------------------------------------
import sys
import matplotlib

# Use a non-interactive backend so the code also runs on headless servers
matplotlib.use('Agg')
# Make the modules under lib/ importable
sys.path.insert(0, 'lib')

--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahyarnajibi/FreeAdversarialTraining/e185c84b1cf7495e59ac8453953572483541f424/lib/__init__.py
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
import logging
import os
import datetime
import torchvision.models as models
import math
import torch
import yaml
from easydict import EasyDict
import shutil


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(initial_lr, optimizer, epoch, n_repeats):
    """Decays the initial LR by 10 every 30 effective epochs. With free
    training, each pass counts as n_repeats effective epochs, so the drop
    happens every ceil(30/n_repeats) actual epochs (e.g. every 8 for m=4)."""
    lr = initial_lr * (0.1 ** (epoch // int(math.ceil(30./n_repeats))))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def fgsm(gradz, step_size):
    # FGSM step: move by step_size in the direction of the gradient sign
    return step_size*torch.sign(gradz)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (rather than view), since the slice of a transposed
            # tensor is non-contiguous in recent PyTorch versions
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def initiate_logger(output_path):
    if not os.path.isdir(os.path.join('output', output_path)):
        os.makedirs(os.path.join('output', output_path))
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(os.path.join('output', output_path, 'log.txt'), 'w'))
    logger.info(pad_str(' LOGISTICS '))
    logger.info('Experiment Date: {}'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M')))
    logger.info('Output Name: {}'.format(output_path))
    logger.info('User: {}'.format(os.getenv('USER')))
    return logger


def get_model_names():
    return sorted(name for name in models.__dict__
                  if name.islower() and not name.startswith("__")
                  and callable(models.__dict__[name]))


def pad_str(msg, total_len=70):
    rem_len = total_len - len(msg)
    return '*'*int(rem_len/2) + msg + '*'*int(rem_len/2)


def parse_config_file(args):
    with open(args.config) as f:
        # FullLoader handles the !!python/tuple tags in configs.yml and avoids
        # the unsafe-load deprecation warning in PyYAML >= 5.1
        config = EasyDict(yaml.load(f, Loader=yaml.FullLoader))

    # Add args parameters to the dict
    for k, v in vars(args).items():
        config[k] = v

    # Add the output path
    config.output_name = '{:s}_step{:d}_eps{:d}_repeat{:d}'.format(
        args.output_prefix, int(config.ADV.fgsm_step), int(config.ADV.clip_eps),
        config.ADV.n_repeats)
    return config


def save_checkpoint(state, is_best, filepath):
    filename = os.path.join(filepath, 'checkpoint.pth.tar')
    # Save model
    torch.save(state, filename)
    # Save best model
    if is_best:
        shutil.copyfile(filename, os.path.join(filepath, 'model_best.pth.tar'))

--------------------------------------------------------------------------------
/lib/validation.py:
--------------------------------------------------------------------------------
from utils import *
import torch
import sys
import numpy as np
import time
from torch.autograd import Variable


def validate_pgd(val_loader, model, criterion, K, step, configs, logger):
    # Mean/Std for normalization
    mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
    mean = mean.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
    std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
    std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
    # Initiate the meters
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    eps = configs.ADV.clip_eps
    model.eval()
    end = time.time()
    logger.info(pad_str(' PGD eps: {}, K: {}, step: {} '.format(eps, K, step)))
    for i, (input, target) in enumerate(val_loader):

        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Random start within the eps-ball
        orig_input = input.clone()
        randn = torch.FloatTensor(input.size()).uniform_(-eps, eps).cuda()
        input += randn
        input.clamp_(0, 1.0)
        for _ in range(K):
            invar = Variable(input, requires_grad=True)
            in1 = invar - mean
            in1.div_(std)
            output = model(in1)
            ascend_loss = criterion(output, target)
            ascend_grad = torch.autograd.grad(ascend_loss, invar)[0]
            pert = fgsm(ascend_grad, step)
            # Apply perturbation
            input += pert.data
            # Project back into the eps-ball around the original input
            input = torch.max(orig_input-eps, input)
            input = torch.min(orig_input+eps, input)
            input.clamp_(0, 1.0)

        input.sub_(mean).div_(std)
        with torch.no_grad():
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % configs.TRAIN.print_freq == 0:
                print('PGD Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))
                sys.stdout.flush()

    print(' PGD Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def validate(val_loader, model, criterion, configs, logger):
    # Mean/Std for normalization
    mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
    mean = mean.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
    std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
    std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()

    # Initiate the meters
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        with torch.no_grad():
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            input = input - mean
            input.div_(std)
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % configs.TRAIN.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))
                sys.stdout.flush()

    print(' Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg

--------------------------------------------------------------------------------
/main_free.py:
--------------------------------------------------------------------------------
# This module is adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
import init_paths
import argparse
import os
import time
import sys
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import math
import numpy as np
from utils import *
from validation import validate, validate_pgd
import torchvision.models as models


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--output_prefix', default='free_adv', type=str,
                        help='prefix used to define output path')
    parser.add_argument('-c', '--config', default='configs.yml', type=str, metavar='Path',
                        help='path to the config file (default: configs.yml)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    return parser.parse_args()


# Parse config file and initiate logging
configs = parse_config_file(parse_args())
logger = initiate_logger(configs.output_name)
print = logger.info
cudnn.benchmark = True


def main():
    # Scale and initialize the parameters
    best_prec1 = 0
    # Each pass replays every minibatch n_repeats times, so fewer passes are needed
    configs.TRAIN.epochs = int(math.ceil(configs.TRAIN.epochs / configs.ADV.n_repeats))
    configs.ADV.fgsm_step /= configs.DATA.max_color_value
    configs.ADV.clip_eps /= configs.DATA.max_color_value

    # Create output folder
    if not os.path.isdir(os.path.join('trained_models', configs.output_name)):
        os.makedirs(os.path.join('trained_models', configs.output_name))

    # Log the config details
    logger.info(pad_str(' ARGUMENTS '))
    for k, v in configs.items(): print('{}: {}'.format(k, v))
    logger.info(pad_str(''))

    # Create the model
    if configs.pretrained:
        print("=> using pre-trained model '{}'".format(configs.TRAIN.arch))
        model = models.__dict__[configs.TRAIN.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(configs.TRAIN.arch))
        model = models.__dict__[configs.TRAIN.arch]()

    # Wrap the model into DataParallel
    model = torch.nn.DataParallel(model).cuda()

    # Criterion:
    criterion = nn.CrossEntropyLoss().cuda()

    # Optimizer:
    optimizer = torch.optim.SGD(model.parameters(), configs.TRAIN.lr,
                                momentum=configs.TRAIN.momentum,
                                weight_decay=configs.TRAIN.weight_decay)

    # Resume if a valid checkpoint path is provided
    if configs.resume:
        if os.path.isfile(configs.resume):
            print("=> loading checkpoint '{}'".format(configs.resume))
            checkpoint = torch.load(configs.resume)
            configs.TRAIN.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(configs.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(configs.resume))

    # Initiate data loaders
    traindir = os.path.join(configs.data, 'train')
    valdir = os.path.join(configs.data, 'val')

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(configs.DATA.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=configs.DATA.batch_size, shuffle=True,
        num_workers=configs.DATA.workers, pin_memory=True, sampler=None)

    # Note: normalization is applied manually inside validate/validate_pgd,
    # so it is left out of the transform pipelines
    normalize = transforms.Normalize(mean=configs.TRAIN.mean,
                                     std=configs.TRAIN.std)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(configs.DATA.img_size),
            transforms.CenterCrop(configs.DATA.crop_size),
            transforms.ToTensor(),
        ])),
        batch_size=configs.DATA.batch_size, shuffle=False,
        num_workers=configs.DATA.workers, pin_memory=True)

    # If in evaluate mode: perform validation on PGD attacks as well as clean samples
    if configs.evaluate:
        logger.info(pad_str(' Performing PGD Attacks '))
        for pgd_param in configs.ADV.pgd_attack:
            validate_pgd(val_loader, model, criterion, pgd_param[0], pgd_param[1], configs, logger)
        validate(val_loader, model, criterion, configs, logger)
        return

    for epoch in range(configs.TRAIN.start_epoch, configs.TRAIN.epochs):
        adjust_learning_rate(configs.TRAIN.lr, optimizer, epoch, configs.ADV.n_repeats)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, configs, logger)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': configs.TRAIN.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, os.path.join('trained_models', configs.output_name))

    # Automatically perform PGD Attacks at the end of training
    logger.info(pad_str(' Performing PGD Attacks '))
    for pgd_param in configs.ADV.pgd_attack:
        validate_pgd(val_loader, model, criterion, pgd_param[0], pgd_param[1], configs, logger)


# Free Adversarial Training Module
# The perturbation persists across minibatches (the "global noise")
global global_noise_data
global_noise_data = torch.zeros([configs.DATA.batch_size, 3, configs.DATA.crop_size, configs.DATA.crop_size]).cuda()


def train(train_loader, model, criterion, optimizer, epoch):
    global global_noise_data
    mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
    mean = mean.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
    std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
    std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
    # Initialize the meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(train_loader):
        end = time.time()
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        data_time.update(time.time() - end)
        for j in range(configs.ADV.n_repeats):
            # Ascend on the global noise
            noise_batch = Variable(global_noise_data[0:input.size(0)], requires_grad=True).cuda()
            in1 = input + noise_batch
            in1.clamp_(0, 1.0)
            in1.sub_(mean).div_(std)
            output = model(in1)
            loss = criterion(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()

            # Update the noise for the next iteration: the same backward pass
            # that produced the weight gradients also gives the input gradient
            pert = fgsm(noise_batch.grad, configs.ADV.fgsm_step)
            global_noise_data[0:input.size(0)] += pert.data
            global_noise_data.clamp_(-configs.ADV.clip_eps, configs.ADV.clip_eps)

            optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % configs.TRAIN.print_freq == 0:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, top1=top1, top5=top5, cls_loss=losses))
            sys.stdout.flush()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
pyyaml
EasyDict
matplotlib
argparse
--------------------------------------------------------------------------------