├── LICENSE
├── README.md
├── configs.yml
├── init_paths.py
├── lib
│   ├── __init__.py
│   ├── utils.py
│   └── validation.py
├── main_free.py
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ali Shafahi, Mahyar Najibi, Amin Ghiasi, Zheng Xu, John Dickerson, Christoph Studer, Larry S. Davis, Gavin Taylor, Tom Goldstein
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 |
12 | ##############################################################################
13 | THIRD-PARTY SOFTWARE LICENSES
14 |
15 | From PyTorch:
16 |
17 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
18 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
19 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
20 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
21 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
22 | Copyright (c) 2011-2013 NYU (Clement Farabet)
23 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
24 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
25 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
26 |
27 | From Caffe2:
28 |
29 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
30 |
31 | All contributions by Facebook:
32 | Copyright (c) 2016 Facebook Inc.
33 |
34 | All contributions by Google:
35 | Copyright (c) 2015 Google Inc.
36 | All rights reserved.
37 |
38 | All contributions by Yangqing Jia:
39 | Copyright (c) 2015 Yangqing Jia
40 | All rights reserved.
41 |
42 | All contributions from Caffe:
43 | Copyright(c) 2013, 2014, 2015, the respective contributors
44 | All rights reserved.
45 |
46 | All other contributions:
47 | Copyright(c) 2015, 2016 the respective contributors
48 | All rights reserved.
49 |
50 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
51 | copyright over their contributions to Caffe2. The project versioning records
52 | all such contribution and copyright details. If a contributor wants to further
53 | mark their specific copyright on a particular contribution, they should
54 | indicate their copyright solely in the commit message of the change when it is
55 | committed.
56 |
57 | All rights reserved.
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Free Adversarial Training
2 | This is a PyTorch implementation of the [Adversarial Training for Free!](https://arxiv.org/abs/1904.12843 "Free Adversarial Training") paper.
3 | The official TensorFlow implementation can be found [here](https://github.com/ashafahi/free_adv_train).
4 |
5 | Using the Free Adversarial Training (Free-m) algorithm, we can train robust models at no additional cost compared to natural training. This allows us to train robust ImageNet models using only a few GPUs in a couple of days! Below is the performance of various Free-trained ImageNet models where we vary the replay parameter (m).
6 |
7 | *(figure: performance of Free-m-trained ImageNet models for various replay parameters m)*
8 |
9 |
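
In sketch form, Free-m replays each minibatch m times and reuses a single backward pass both to update the model weights and to ascend on a persistent perturbation. The snippet below is a condensed, unofficial paraphrase of `train()` in `main_free.py`, not the exact code: input normalization, logging, and the learning-rate schedule are omitted, and `model`, `criterion`, `optimizer`, `loader`, `batch_size`, and `crop_size` are assumed to be defined.
```python
import torch

n_repeats, fgsm_step, clip_eps = 4, 4 / 255, 4 / 255
# The perturbation persists across minibatches -- this warm start is what makes training "free".
noise = torch.zeros(batch_size, 3, crop_size, crop_size, device='cuda')

for x, y in loader:
    x, y = x.cuda(), y.cuda()
    for _ in range(n_repeats):                    # replay the same minibatch m times
        delta = noise[:x.size(0)].clone().requires_grad_(True)
        loss = criterion(model((x + delta).clamp(0, 1)), y)
        optimizer.zero_grad()
        loss.backward()                           # one backward pass, used twice:
        noise[:x.size(0)] += fgsm_step * delta.grad.sign()  # (1) FGSM ascent on the noise
        noise.clamp_(-clip_eps, clip_eps)
        optimizer.step()                          # (2) SGD descent on the weights
```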
10 | This repository provides code for training and evaluating models on the ImageNet dataset.
11 | The implementation is adapted from the [official PyTorch ImageNet example](https://github.com/pytorch/examples/blob/master/imagenet).
12 |
13 | ## Installation
14 | 1. Install [PyTorch](https://pytorch.org/).
15 | 2. Install the required Python packages by running:
16 | ```bash
17 | pip install -r requirements.txt
18 | ```
19 | 3. Download and prepare the ImageNet dataset. You can use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh),
20 | provided by the PyTorch repository, to move the validation subset to the labeled subfolders.
21 |
22 | ## Training a model
23 | To train a robust model, run the following command:
24 |
25 | ```bash
26 | python main_free.py [PATH_TO_IMAGENET_ROOT]
27 | ```
28 | This trains a robust model with the default parameters. The training parameters can be set by changing the ```configs.yml``` config file.
29 | Please run ```python main_free.py --help``` to see the list of possible arguments.
30 | The script saves the trained models into the ```trained_models``` folder and the logs into the ```output``` folder.
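
For example, a hypothetical run that sets the optional output prefix and an explicit config path (see `python main_free.py --help` for the full argument list):
```bash
python main_free.py /datasets/imagenet --output_prefix free_m4 -c configs.yml
```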
31 |
32 |
33 | ## Evaluating a trained model
34 | You can evaluate a trained model by running the following command:
35 | ```bash
36 | python main_free.py [PATH_TO_IMAGENET_ROOT] -e --resume [PATH_TO_THE_MODEL_CHECKPOINT]
37 | ```
38 | The script evaluates the model on clean examples as well as examples generated by PGD attacks with different parameters.
39 | The attack parameters can be set by changing the ```configs.yml``` file.
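
For reference, the PGD settings live under the `ADV` section of `configs.yml`; the defaults shipped with this repository are:
```yaml
ADV:
  clip_eps: 4.0   # epsilon on the 0-255 scale; shared by training and the PGD attacks
  pgd_attack:
    - !!python/tuple [10, 0.00392156862]  # 10 iterations, step size 1/255
    - !!python/tuple [50, 0.00392156862]  # 50 iterations, step size 1/255
```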
40 |
41 | ## Citing
42 | If you find the paper or the code useful for your research, please consider citing the free training paper:
43 | ```bibtex
44 | @article{2019arXiv190412843S,
45 | author = {{Shafahi}, A. and {Najibi}, M. and {Ghiasi}, A. and {Xu}, Z. and
46 | {Dickerson}, J. and {Studer}, C. and {Davis}, L. and {Taylor}, G. and {Goldstein}, T.},
47 | title = "{Adversarial Training for Free!}",
48 | journal = {ArXiv e-prints},
49 | year = 2019
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/configs.yml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | # Number of training epochs
3 | epochs: 90
4 |
5 | # Architecture name, see pytorch models package for
6 | # a list of possible architectures
7 | arch: 'resnet50'
8 |
9 | # Starting epoch
10 | start_epoch: 0
11 |
12 |   # SGD parameters
13 | lr: 0.1
14 | momentum: 0.9
15 | weight_decay: 0.0001
16 |
17 |   # Print frequency, used for both training and testing
18 | print_freq: 10
19 |
20 | # Dataset mean and std used for data normalization
21 | mean: !!python/tuple [0.485, 0.456, 0.406]
22 | std: !!python/tuple [0.229, 0.224, 0.225]
23 |
24 | ADV:
25 | # FGSM parameters during training
26 | clip_eps: 4.0
27 | fgsm_step: 4.0
28 |
29 | # Number of repeats for free adversarial training
30 | n_repeats: 4
31 |
32 | # PGD attack parameters used during validation
33 | # the same clip_eps as above is used for PGD
34 | pgd_attack:
35 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0]
36 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0]
37 |
38 | DATA:
39 | # Number of data workers
40 | workers: 4
41 |
42 | # Training batch size
43 | batch_size: 256
44 |
45 | # Image Size
46 | img_size: 256
47 |
48 | # Crop Size for data augmentation
49 | crop_size: 224
50 |
51 | # Color value range
52 | max_color_value: 255.0
53 |
54 |
--------------------------------------------------------------------------------
/init_paths.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | sys.path.insert(0, 'lib')
5 |
6 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
(empty file — marks lib as a Python package)
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import datetime
4 | import torchvision.models as models
5 | import math
6 | import torch
7 | import yaml
8 | from easydict import EasyDict
9 | import shutil
10 |
11 | class AverageMeter(object):
12 | """Computes and stores the average and current value"""
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def adjust_learning_rate(initial_lr, optimizer, epoch, n_repeats):
30 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
31 | lr = initial_lr * (0.1 ** (epoch // int(math.ceil(30./n_repeats))))
32 | for param_group in optimizer.param_groups:
33 | param_group['lr'] = lr
34 |
35 |
36 | def fgsm(gradz, step_size):
37 |     return step_size * torch.sign(gradz)  # FGSM: a single step of size step_size along the gradient sign
38 |
39 |
40 |
41 | def accuracy(output, target, topk=(1,)):
42 | """Computes the accuracy over the k top predictions for the specified values of k"""
43 | with torch.no_grad():
44 | maxk = max(topk)
45 | batch_size = target.size(0)
46 |
47 | _, pred = output.topk(maxk, 1, True, True)
48 | pred = pred.t()
49 | correct = pred.eq(target.view(1, -1).expand_as(pred))
50 |
51 | res = []
52 | for k in topk:
53 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
54 | res.append(correct_k.mul_(100.0 / batch_size))
55 | return res
56 |
57 |
58 | def initiate_logger(output_path):
59 | if not os.path.isdir(os.path.join('output', output_path)):
60 | os.makedirs(os.path.join('output', output_path))
61 | logging.basicConfig(level=logging.INFO)
62 | logger = logging.getLogger()
63 | logger.addHandler(logging.FileHandler(os.path.join('output', output_path, 'log.txt'),'w'))
64 | logger.info(pad_str(' LOGISTICS '))
65 | logger.info('Experiment Date: {}'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M')))
66 | logger.info('Output Name: {}'.format(output_path))
67 | logger.info('User: {}'.format(os.getenv('USER')))
68 | return logger
69 |
70 | def get_model_names():
71 | return sorted(name for name in models.__dict__
72 | if name.islower() and not name.startswith("__")
73 | and callable(models.__dict__[name]))
74 |
75 | def pad_str(msg, total_len=70):
76 | rem_len = total_len - len(msg)
77 |     return '*'*int(rem_len/2) + msg + '*'*int(rem_len/2)
78 |
79 | def parse_config_file(args):
80 | with open(args.config) as f:
81 |         config = EasyDict(yaml.load(f, Loader=yaml.FullLoader))  # explicit Loader required by PyYAML >= 5; FullLoader resolves the !!python/tuple tags
82 |
83 | # Add args parameters to the dict
84 | for k, v in vars(args).items():
85 | config[k] = v
86 |
87 | # Add the output path
88 | config.output_name = '{:s}_step{:d}_eps{:d}_repeat{:d}'.format(args.output_prefix,
89 | int(config.ADV.fgsm_step), int(config.ADV.clip_eps),
90 | config.ADV.n_repeats)
91 | return config
92 |
93 |
94 | def save_checkpoint(state, is_best, filepath):
95 | filename = os.path.join(filepath, 'checkpoint.pth.tar')
96 | # Save model
97 | torch.save(state, filename)
98 | # Save best model
99 | if is_best:
100 | shutil.copyfile(filename, os.path.join(filepath, 'model_best.pth.tar'))
101 |
--------------------------------------------------------------------------------
/lib/validation.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import torch
3 | import sys
4 | import numpy as np
5 | import time
6 | from torch.autograd import Variable
7 |
8 | def validate_pgd(val_loader, model, criterion, K, step, configs, logger):
9 | # Mean/Std for normalization
10 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
11 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda()
12 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
13 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
14 | # Initiate the meters
15 | batch_time = AverageMeter()
16 | losses = AverageMeter()
17 | top1 = AverageMeter()
18 | top5 = AverageMeter()
19 |
20 | eps = configs.ADV.clip_eps
21 | model.eval()
22 | end = time.time()
23 | logger.info(pad_str(' PGD eps: {}, K: {}, step: {} '.format(eps, K, step)))
24 | for i, (input, target) in enumerate(val_loader):
25 |
26 | input = input.cuda(non_blocking=True)
27 | target = target.cuda(non_blocking=True)
28 |
29 | orig_input = input.clone()
30 |         noise = torch.FloatTensor(input.size()).uniform_(-eps, eps).cuda()
31 |         input += noise
32 | input.clamp_(0, 1.0)
33 | for _ in range(K):
34 | invar = Variable(input, requires_grad=True)
35 | in1 = invar - mean
36 | in1.div_(std)
37 | output = model(in1)
38 | ascend_loss = criterion(output, target)
39 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0]
40 | pert = fgsm(ascend_grad, step)
41 |             # Apply perturbation
42 | input += pert.data
43 | input = torch.max(orig_input-eps, input)
44 | input = torch.min(orig_input+eps, input)
45 | input.clamp_(0, 1.0)
46 |
47 | input.sub_(mean).div_(std)
48 | with torch.no_grad():
49 | # compute output
50 | output = model(input)
51 | loss = criterion(output, target)
52 |
53 | # measure accuracy and record loss
54 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
55 | losses.update(loss.item(), input.size(0))
56 | top1.update(prec1[0], input.size(0))
57 | top5.update(prec5[0], input.size(0))
58 |
59 | # measure elapsed time
60 | batch_time.update(time.time() - end)
61 | end = time.time()
62 |
63 | if i % configs.TRAIN.print_freq == 0:
64 | print('PGD Test: [{0}/{1}]\t'
65 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
66 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
67 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
68 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
69 | i, len(val_loader), batch_time=batch_time, loss=losses,
70 | top1=top1, top5=top5))
71 | sys.stdout.flush()
72 |
73 | print(' PGD Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
74 | .format(top1=top1, top5=top5))
75 |
76 | return top1.avg
77 |
78 | def validate(val_loader, model, criterion, configs, logger):
79 | # Mean/Std for normalization
80 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
81 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda()
82 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
83 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
84 |
85 | # Initiate the meters
86 | batch_time = AverageMeter()
87 | losses = AverageMeter()
88 | top1 = AverageMeter()
89 | top5 = AverageMeter()
90 | # switch to evaluate mode
91 | model.eval()
92 | end = time.time()
93 | for i, (input, target) in enumerate(val_loader):
94 | with torch.no_grad():
95 | input = input.cuda(non_blocking=True)
96 | target = target.cuda(non_blocking=True)
97 |
98 | # compute output
99 | input = input - mean
100 | input.div_(std)
101 | output = model(input)
102 | loss = criterion(output, target)
103 |
104 | # measure accuracy and record loss
105 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
106 | losses.update(loss.item(), input.size(0))
107 | top1.update(prec1[0], input.size(0))
108 | top5.update(prec5[0], input.size(0))
109 |
110 | # measure elapsed time
111 | batch_time.update(time.time() - end)
112 | end = time.time()
113 |
114 | if i % configs.TRAIN.print_freq == 0:
115 | print('Test: [{0}/{1}]\t'
116 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
117 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
118 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
119 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
120 | i, len(val_loader), batch_time=batch_time, loss=losses,
121 | top1=top1, top5=top5))
122 | sys.stdout.flush()
123 |
124 | print(' Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
125 | .format(top1=top1, top5=top5))
126 | return top1.avg
--------------------------------------------------------------------------------
/main_free.py:
--------------------------------------------------------------------------------
1 | # This module is adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
2 | import init_paths
3 | import argparse
4 | import os
5 | import time
6 | import sys
7 | import torch
8 | import torch.nn as nn
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 | from torch.autograd import Variable
14 | import math
15 | import numpy as np
16 | from utils import *
17 | from validation import validate, validate_pgd
18 | import torchvision.models as models
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
23 | parser.add_argument('data', metavar='DIR',
24 | help='path to dataset')
25 | parser.add_argument('--output_prefix', default='free_adv', type=str,
26 | help='prefix used to define output path')
27 | parser.add_argument('-c', '--config', default='configs.yml', type=str, metavar='Path',
28 | help='path to the config file (default: configs.yml)')
29 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
30 | help='path to latest checkpoint (default: none)')
31 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
32 | help='evaluate model on validation set')
33 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
34 | help='use pre-trained model')
35 | return parser.parse_args()
36 |
37 |
38 | # Parse the config file and initiate logging
39 | configs = parse_config_file(parse_args())
40 | logger = initiate_logger(configs.output_name)
41 | print = logger.info
42 | cudnn.benchmark = True
43 |
44 | def main():
45 | # Scale and initialize the parameters
46 | best_prec1 = 0
47 | configs.TRAIN.epochs = int(math.ceil(configs.TRAIN.epochs / configs.ADV.n_repeats))
48 | configs.ADV.fgsm_step /= configs.DATA.max_color_value
49 | configs.ADV.clip_eps /= configs.DATA.max_color_value
50 |
51 | # Create output folder
52 | if not os.path.isdir(os.path.join('trained_models', configs.output_name)):
53 | os.makedirs(os.path.join('trained_models', configs.output_name))
54 |
55 | # Log the config details
56 | logger.info(pad_str(' ARGUMENTS '))
57 | for k, v in configs.items(): print('{}: {}'.format(k, v))
58 | logger.info(pad_str(''))
59 |
60 |
61 | # Create the model
62 | if configs.pretrained:
63 | print("=> using pre-trained model '{}'".format(configs.TRAIN.arch))
64 | model = models.__dict__[configs.TRAIN.arch](pretrained=True)
65 | else:
66 | print("=> creating model '{}'".format(configs.TRAIN.arch))
67 | model = models.__dict__[configs.TRAIN.arch]()
68 |
69 | # Wrap the model into DataParallel
70 | model = torch.nn.DataParallel(model).cuda()
71 |
72 | # Criterion:
73 | criterion = nn.CrossEntropyLoss().cuda()
74 |
75 | # Optimizer:
76 | optimizer = torch.optim.SGD(model.parameters(), configs.TRAIN.lr,
77 | momentum=configs.TRAIN.momentum,
78 | weight_decay=configs.TRAIN.weight_decay)
79 |
80 | # Resume if a valid checkpoint path is provided
81 | if configs.resume:
82 | if os.path.isfile(configs.resume):
83 | print("=> loading checkpoint '{}'".format(configs.resume))
84 | checkpoint = torch.load(configs.resume)
85 | configs.TRAIN.start_epoch = checkpoint['epoch']
86 | best_prec1 = checkpoint['best_prec1']
87 | model.load_state_dict(checkpoint['state_dict'])
88 | optimizer.load_state_dict(checkpoint['optimizer'])
89 | print("=> loaded checkpoint '{}' (epoch {})"
90 | .format(configs.resume, checkpoint['epoch']))
91 | else:
92 | print("=> no checkpoint found at '{}'".format(configs.resume))
93 |
94 |
95 | # Initiate data loaders
96 | traindir = os.path.join(configs.data, 'train')
97 | valdir = os.path.join(configs.data, 'val')
98 |
99 | train_dataset = datasets.ImageFolder(
100 | traindir,
101 | transforms.Compose([
102 | transforms.RandomResizedCrop(configs.DATA.crop_size),
103 | transforms.RandomHorizontalFlip(),
104 | transforms.ToTensor(),
105 | ]))
106 |
107 | train_loader = torch.utils.data.DataLoader(
108 | train_dataset, batch_size=configs.DATA.batch_size, shuffle=True,
109 | num_workers=configs.DATA.workers, pin_memory=True, sampler=None)
110 |
111 | normalize = transforms.Normalize(mean=configs.TRAIN.mean,
112 | std=configs.TRAIN.std)
113 | val_loader = torch.utils.data.DataLoader(
114 | datasets.ImageFolder(valdir, transforms.Compose([
115 | transforms.Resize(configs.DATA.img_size),
116 | transforms.CenterCrop(configs.DATA.crop_size),
117 | transforms.ToTensor(),
118 | ])),
119 | batch_size=configs.DATA.batch_size, shuffle=False,
120 | num_workers=configs.DATA.workers, pin_memory=True)
121 |
122 | # If in evaluate mode: perform validation on PGD attacks as well as clean samples
123 | if configs.evaluate:
124 | logger.info(pad_str(' Performing PGD Attacks '))
125 | for pgd_param in configs.ADV.pgd_attack:
126 | validate_pgd(val_loader, model, criterion, pgd_param[0], pgd_param[1], configs, logger)
127 | validate(val_loader, model, criterion, configs, logger)
128 | return
129 |
130 |
131 | for epoch in range(configs.TRAIN.start_epoch, configs.TRAIN.epochs):
132 | adjust_learning_rate(configs.TRAIN.lr, optimizer, epoch, configs.ADV.n_repeats)
133 |
134 | # train for one epoch
135 | train(train_loader, model, criterion, optimizer, epoch)
136 |
137 | # evaluate on validation set
138 | prec1 = validate(val_loader, model, criterion, configs, logger)
139 |
140 | # remember best prec@1 and save checkpoint
141 | is_best = prec1 > best_prec1
142 | best_prec1 = max(prec1, best_prec1)
143 | save_checkpoint({
144 | 'epoch': epoch + 1,
145 | 'arch': configs.TRAIN.arch,
146 | 'state_dict': model.state_dict(),
147 | 'best_prec1': best_prec1,
148 | 'optimizer' : optimizer.state_dict(),
149 | }, is_best, os.path.join('trained_models', configs.output_name))
150 |
151 | # Automatically perform PGD Attacks at the end of training
152 | logger.info(pad_str(' Performing PGD Attacks '))
153 | for pgd_param in configs.ADV.pgd_attack:
154 | validate_pgd(val_loader, model, criterion, pgd_param[0], pgd_param[1], configs, logger)
155 |
156 |
157 | # Free Adversarial Training Module
158 | # global_noise_data is a module-level buffer shared across train() calls
159 | global_noise_data = torch.zeros([configs.DATA.batch_size, 3, configs.DATA.crop_size, configs.DATA.crop_size]).cuda()
160 | def train(train_loader, model, criterion, optimizer, epoch):
161 | global global_noise_data
162 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
163 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda()
164 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
165 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()
166 | # Initialize the meters
167 | batch_time = AverageMeter()
168 | data_time = AverageMeter()
169 | losses = AverageMeter()
170 | top1 = AverageMeter()
171 | top5 = AverageMeter()
172 | # switch to train mode
173 | model.train()
174 | for i, (input, target) in enumerate(train_loader):
175 | end = time.time()
176 | input = input.cuda(non_blocking=True)
177 | target = target.cuda(non_blocking=True)
178 | data_time.update(time.time() - end)
179 | for j in range(configs.ADV.n_repeats):
180 | # Ascend on the global noise
181 | noise_batch = Variable(global_noise_data[0:input.size(0)], requires_grad=True).cuda()
182 | in1 = input + noise_batch
183 | in1.clamp_(0, 1.0)
184 | in1.sub_(mean).div_(std)
185 | output = model(in1)
186 | loss = criterion(output, target)
187 |
188 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
189 | losses.update(loss.item(), input.size(0))
190 | top1.update(prec1[0], input.size(0))
191 | top5.update(prec5[0], input.size(0))
192 |
193 | # compute gradient and do SGD step
194 | optimizer.zero_grad()
195 | loss.backward()
196 |
197 | # Update the noise for the next iteration
198 | pert = fgsm(noise_batch.grad, configs.ADV.fgsm_step)
199 | global_noise_data[0:input.size(0)] += pert.data
200 | global_noise_data.clamp_(-configs.ADV.clip_eps, configs.ADV.clip_eps)
201 |
202 | optimizer.step()
203 | # measure elapsed time
204 | batch_time.update(time.time() - end)
205 | end = time.time()
206 |
207 | if i % configs.TRAIN.print_freq == 0:
208 | print('Train Epoch: [{0}][{1}/{2}]\t'
209 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
210 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
211 | 'Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t'
212 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
213 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
214 | epoch, i, len(train_loader), batch_time=batch_time,
215 |                     data_time=data_time, top1=top1, top5=top5, cls_loss=losses))
216 | sys.stdout.flush()
217 |
218 | if __name__ == '__main__':
219 | main()
220 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pyyaml
3 | EasyDict
4 | matplotlib
5 |
--------------------------------------------------------------------------------