├── .gitignore ├── ImageNet ├── lib │ ├── __init__.py │ ├── utils.py │ └── validation.py ├── requirements.txt ├── run_eval.sh ├── resize.py ├── configs │ ├── configs_fast_2px_phase1.yml │ ├── configs_fast_4px_phase1.yml │ ├── configs_fast_2px_phase3.yml │ ├── configs_fast_4px_phase3.yml │ ├── configs_fast_2px_evaluate.yml │ ├── configs_fast_4px_evaluate.yml │ ├── configs_fast_2px_phase2.yml │ └── configs_fast_4px_phase2.yml ├── run_fast_2px.sh ├── run_fast_4px.sh ├── README.md └── main_fast.py ├── CIFAR10 ├── .gitignore ├── requirements.txt ├── README.md ├── preact_resnet.py ├── utils.py ├── train_free.py ├── train_pgd.py └── train_fgsm.py ├── MNIST ├── models │ ├── fgsm.pth │ └── pgd_madry.pth ├── mnist_net.py ├── README.md ├── evaluate_mnist.py └── train_mnist.py ├── overfitting_error_curve.png └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /ImageNet/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /CIFAR10/.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | __pycache__ 3 | *output/ 4 | -------------------------------------------------------------------------------- /ImageNet/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pyyaml 3 | EasyDict 4 | argparse 5 | -------------------------------------------------------------------------------- /CIFAR10/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.2 2 | torch==1.2.0 3 | torchvision==0.4.0 4 | -------------------------------------------------------------------------------- /MNIST/models/fgsm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/fast_adversarial/HEAD/MNIST/models/fgsm.pth -------------------------------------------------------------------------------- /MNIST/models/pgd_madry.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/fast_adversarial/HEAD/MNIST/models/pgd_madry.pth -------------------------------------------------------------------------------- /overfitting_error_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/fast_adversarial/HEAD/overfitting_error_curve.png -------------------------------------------------------------------------------- /ImageNet/run_eval.sh: -------------------------------------------------------------------------------- 1 | # 4eps evaluation 2 | python main_fast.py ~/imagenet --config configs_fast_4px_evaluate.yml --output_prefix eval_4px --resume trained_models/fast_adv_phase3_eps4_step5_eps4_repeat1/model_best.pth.tar --evaluate --restarts 10 3 | 4 | # 2eps evaluation 5 | python main_fast.py ~/imagenet --config configs_fast_2px_evaluate.yml --output_prefix eval_2px --resume trained_models/fast_adv_phase3_eps2_step2_eps2_repeat1/model_best.pth.tar --evaluate --restarts 10 6 | -------------------------------------------------------------------------------- /MNIST/mnist_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Flatten(nn.Module): 5 | def forward(self, x): 6 | return x.view(x.size(0), -1) 7 | 8 | def mnist_net(): 9 | model = nn.Sequential( 10 | nn.Conv2d(1, 16, 4, stride=2, padding=1), 11 | nn.ReLU(), 12 | nn.Conv2d(16, 32, 4, stride=2, padding=1), 13 | nn.ReLU(), 14 | Flatten(), 15 | nn.Linear(32*7*7,100), 16 | nn.ReLU(), 17 | nn.Linear(100, 10) 18 | ) 19 | return model 20 | -------------------------------------------------------------------------------- /CIFAR10/README.md: -------------------------------------------------------------------------------- 1 | ## Fast is Better Than Free: CIFAR10 2 | 3 | ### Requirements: 4 | Python 3.6 5 | 6 | Install the required packages: 7 | ``` 8 | $ pip install -r requirements.txt 9 | ``` 10 | 11 | Follow the instructions below to install apex: 12 | ``` 13 | $ git clone https://github.com/NVIDIA/apex 14 | $ cd apex 15 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 16 | ``` 17 | 18 | ### Trained model weights 19 | Trained model weights can be found here: https://drive.google.com/open?id=1W2zGHyxTPgHhWln1kpHK5h-HY9kwfKfl 20 | -------------------------------------------------------------------------------- /MNIST/README.md: -------------------------------------------------------------------------------- 1 | # fast_is_better_than_free_MNIST 2 | 3 | To train, run 4 | 5 | `python train_mnist.py --fname models/fgsm.pth` 6 | 7 | which runs FGSM training with the default parameters. To run the evaluation with default parameters (50 iterations with step size 0.01 and 10 random restarts), run 8 | 9 | `python evaluate_mnist.py --fname models/fgsm.pth` 10 | 11 | To run PGD adversarial training with the same parameters as those used [here](https://github.com/MadryLab/mnist_challenge/blob/master/config.json), run 12 | 13 | `python train_mnist.py --fname models/pgd_madry.pth --attack pgd --alpha 0.01 --lr-type flat --lr-max 0.0001 --epochs 100 --batch-size 50` -------------------------------------------------------------------------------- /ImageNet/resize.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from pathlib import Path 3 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 4 | from functools import partial 5 | import multiprocessing 6 | cpus = multiprocessing.cpu_count() 7 | cpus = min(48,cpus) 8 | 9 | PATH = Path.home()/'imagenet' 10 | #DEST = Path('/mnt/ram') 11 | DEST = Path.home()/'imagenet-sz' 12 | #szs = (int(128*1.25), int(256*1.25)) 13 | # szs = (int(160*1.25),) 14 | szs = (160,352) 15 | 16 | def resize_img(p, im, fn, sz): 17 | w,h = im.size 18 | ratio = min(h/sz,w/sz) 19 | im = im.resize((int(w/ratio), int(h/ratio)), resample=Image.BICUBIC) 20 | #import pdb; pdb.set_trace() 21 | new_fn = DEST/str(sz)/fn.relative_to(PATH) 22 | new_fn.parent.mkdir(exist_ok=True) 23 | im.convert('RGB').save(new_fn) 24 | 25 | def resizes(p, fn): 26 | im = Image.open(fn) 27 | for sz in szs: resize_img(p, im, fn, sz) 28 | 29 | def resize_imgs(p): 30 | files = p.glob('*/*.JPEG') 31 | #list(map(partial(resizes, p), files)) 32 | with ProcessPoolExecutor(cpus) as e: e.map(partial(resizes, p), files) 33 | 34 | 35 | for sz in szs: 36 | ssz=str(sz) 37 | (DEST/ssz).mkdir(exist_ok=True) 38 | for ds in ('val','train'): (DEST/ssz/ds).mkdir(exist_ok=True) 39 | 40 | for ds in ('val','train'): resize_imgs(PATH/ds) -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_2px_phase1.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | start_epoch: 0 24 | 25 | # Number of training epochs 26 | epochs: 6 27 | 28 | lr_epochs: !!python/tuple [0,1,6] 29 | lr_values: !!python/tuple [0,0.4,0.04] 30 | 31 | half: true 32 | random_init: true 33 | 34 | ADV: 35 | # FGSM parameters during training 36 | clip_eps: 2.0 37 | fgsm_step: 2.5 38 | 39 | # Number of repeats for free adversarial training 40 | n_repeats: 1 41 | 42 | # PGD attack parameters used during validation 43 | # the same clip_eps as above is used for PGD 44 | pgd_attack: 45 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 46 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 47 | 48 | DATA: 49 | # Number of data workers 50 | workers: 16 51 | 52 | # Color value range 53 | max_color_value: 255.0 54 | 55 | # FAST ADVERSARIAL TRAINING PARAMETER 56 | 57 | # Image Size 58 | # img_size: 160 59 | img_size: 0 60 | 61 | # Training batch size 62 | batch_size: 512 63 | 64 | # Crop Size for data augmentation 65 | crop_size: 128 66 | 67 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_4px_phase1.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | start_epoch: 0 24 | 25 | # Number of training epochs 26 | epochs: 6 27 | 28 | lr_epochs: !!python/tuple [0,1,6] 29 | lr_values: !!python/tuple [0,0.4,0.04] 30 | 31 | half: true 32 | random_init: true 33 | 34 | ADV: 35 | # FGSM parameters during training 36 | clip_eps: 4.0 37 | fgsm_step: 5.0 38 | 39 | # Number of repeats for free adversarial training 40 | n_repeats: 1 41 | 42 | # PGD attack parameters used during validation 43 | # the same clip_eps as above is used for PGD 44 | pgd_attack: 45 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 46 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 47 | 48 | DATA: 49 | # Number of data workers 50 | workers: 16 51 | 52 | # Color value range 53 | max_color_value: 255.0 54 | 55 | # FAST ADVERSARIAL TRAINING PARAMETER 56 | 57 | # Image Size 58 | # img_size: 160 59 | img_size: 0 60 | 61 | # Training batch size 62 | batch_size: 512 63 | 64 | # Crop Size for data augmentation 65 | crop_size: 128 66 | 67 | -------------------------------------------------------------------------------- /ImageNet/run_fast_2px.sh: -------------------------------------------------------------------------------- 1 | DATA160=~/imagenet-sz/160 2 | DATA352=~/imagenet-sz/352 3 | DATA=~/imagenet 4 | 5 | NAME=eps2 6 | 7 | CONFIG1=configs/configs_fast_phase1_${NAME}.yml 8 | CONFIG2=configs/configs_fast_phase2_${NAME}.yml 9 | CONFIG3=configs/configs_fast_phase3_${NAME}.yml 10 | 11 | PREFIX1=fast_adv_phase1_${NAME} 12 | PREFIX2=fast_adv_phase2_${NAME} 13 | PREFIX3=fast_adv_phase3_${NAME} 14 | 15 | OUT1=fast_train_phase1_${NAME}.out 16 | OUT2=fast_train_phase2_${NAME}.out 17 | OUT3=fast_train_phase3_${NAME}.out 18 | 19 | EVAL1=fast_eval_phase1_${NAME}.out 20 | EVAL2=fast_eval_phase2_${NAME}.out 21 | EVAL3=fast_eval_phase3_${NAME}.out 22 | 23 | END1=~/FastAdversarialTraining/trained_models/fast_adv_phase1_${NAME}_step2_eps2_repeat1/checkpoint_epoch6.pth.tar 24 | END2=~/FastAdversarialTraining/trained_models/fast_adv_phase2_${NAME}_step2_eps2_repeat1/checkpoint_epoch12.pth.tar 25 | END3=~/FastAdversarialTraining/trained_models/fast_adv_phase3_${NAME}_step2_eps2_repeat1/checkpoint_epoch15.pth.tar 26 | 27 | # training for phase 1 28 | python -u main_fast.py $DATA160 -c $CONFIG1 --output_prefix $PREFIX1 | tee $OUT1 29 | 30 | # evaluation for phase 1 31 | # python -u main_fast.py $DATA160 -c $CONFIG1 --output_prefix $PREFIX1 --resume $END1 --evaluate | tee $EVAL1 32 | 33 | # training for phase 2 34 | python -u main_fast.py $DATA352 -c $CONFIG2 --output_prefix $PREFIX2 --resume $END1 | tee $OUT2 35 | 36 | # evaluation for phase 2 37 | # python -u main_fast.py $DATA352 -c $CONFIG2 --output_prefix $PREFIX2 --resume $END2 --evaluate | tee $EVAL2 38 | 39 | # training for phase 3 40 | python -u main_fast.py $DATA -c $CONFIG3 --output_prefix $PREFIX3 --resume $END2 | tee $OUT3 41 | 42 | # evaluation for phase 3 43 | # python -u main_fast.py $DATA -c $CONFIG3 --output_prefix $PREFIX3 --resume $END3 --evaluate | tee $EVAL3 44 | -------------------------------------------------------------------------------- /ImageNet/run_fast_4px.sh: -------------------------------------------------------------------------------- 1 | DATA160=~/imagenet-sz/160 2 | DATA352=~/imagenet-sz/352 3 | DATA=~/imagenet 4 | 5 | NAME=eps4 6 | 7 | CONFIG1=configs/configs_fast_phase1_${NAME}.yml 8 | CONFIG2=configs/configs_fast_phase2_${NAME}.yml 9 | CONFIG3=configs/configs_fast_phase3_${NAME}.yml 10 | 11 | PREFIX1=fast_adv_phase1_${NAME} 12 | PREFIX2=fast_adv_phase2_${NAME} 13 | PREFIX3=fast_adv_phase3_${NAME} 14 | 15 | OUT1=fast_train_phase1_${NAME}.out 16 | OUT2=fast_train_phase2_${NAME}.out 17 | OUT3=fast_train_phase3_${NAME}.out 18 | 19 | EVAL1=fast_eval_phase1_${NAME}.out 20 | EVAL2=fast_eval_phase2_${NAME}.out 21 | EVAL3=fast_eval_phase3_${NAME}.out 22 | 23 | END1=~/FastAdversarialTraining/trained_models/fast_adv_phase1_${NAME}_step5_eps4_repeat1/checkpoint_epoch6.pth.tar 24 | END2=~/FastAdversarialTraining/trained_models/fast_adv_phase2_${NAME}_step5_eps4_repeat1/checkpoint_epoch12.pth.tar 25 | END3=~/FastAdversarialTraining/trained_models/fast_adv_phase3_${NAME}_step5_eps4_repeat1/checkpoint_epoch15.pth.tar 26 | 27 | # training for phase 1 28 | python -u main_fast.py $DATA160 -c $CONFIG1 --output_prefix $PREFIX1 | tee $OUT1 29 | 30 | # evaluation for phase 1 31 | # python -u main_fast.py $DATA160 -c $CONFIG1 --output_prefix $PREFIX1 --resume $END1 --evaluate | tee $EVAL1 32 | 33 | # training for phase 2 34 | python -u main_fast.py $DATA352 -c $CONFIG2 --output_prefix $PREFIX2 --resume $END1 | tee $OUT2 35 | 36 | # evaluation for phase 2 37 | # python -u main_fast.py $DATA352 -c $CONFIG2 --output_prefix $PREFIX2 --resume $END2 --evaluate | tee $EVAL2 38 | 39 | # training for phase 3 40 | python -u main_fast.py $DATA -c $CONFIG3 --output_prefix $PREFIX3 --resume $END2 | tee $OUT3 41 | 42 | # evaluation for phase 3 43 | # python -u main_fast.py $DATA -c $CONFIG3 --output_prefix $PREFIX3 --resume $END3 --evaluate | tee $EVAL3 44 | -------------------------------------------------------------------------------- /ImageNet/README.md: -------------------------------------------------------------------------------- 1 | # Fast Adversarial Training 2 | This is a supplemental material containing the code to run Fast is better than 3 | free: revisiting adversarial training, submitted to ICLR 2020. 4 | 5 | The framework used is a modified version of the [Free Adversarial Training](https://github.com/mahyarnajibi/FreeAdversarialTraining/blob/master/main_free.py) repository, which in turn was adapted from the [official PyTorch repository](https://github.com/pytorch/examples/blob/master/imagenet). 6 | 7 | ## Installation 8 | 1. Install [PyTorch](https://github.com/pytorch/examples/blob/master/imagenet). 9 | 2. Install the required python packages. All packages can be installed by running the following command: 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 3. Download and prepare the ImageNet dataset. You can use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh), 14 | provided by the PyTorch repository, to move the validation subset to the labeled subfolders. 15 | 4. Prepare resized versions of the ImageNet dataset, you can use `resize.py` provided in this repository. 16 | 5. Install [Apex](https://github.com/NVIDIA/apex) to use half precision speedup. 17 | 18 | ## Training a model 19 | Scripts to robustly train an ImageNet classifier for epsilon radii of 2/255 and 4/255 are provided in `run_fast_2px.sh` and `run_fast_4px.sh`. These run the main code module `main_free.py` using the configurations provided in the `configs/` folder. To run the 50 step PGD adversary with 10 restarts, we also provide `run_eval.sh`. All parameters can be modified by adjusting the configuration files in the `configs/` folder. 20 | 21 | ## Model weights 22 | We also provide the model weights after training with these scripts, which can be found in this [Google drive folder](https://drive.google.com/open?id=1W2zGHyxTPgHhWln1kpHK5h-HY9kwfKfl). To use these with the provided evaluation script, either adjust the path to the model weights in the `run_eval.sh` script or rename the provided model weights accordingly. 23 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_2px_phase3.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 6 24 | # # Number of training epochs 25 | # epochs: 28 26 | 27 | # lr_epochs: !!python/tuple [25,28] 28 | # lr_values: !!python/tuple [0.0025, 0.00025] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 12 32 | # Number of training epochs 33 | epochs: 15 34 | 35 | # lr_epochs: !!python/tuple [12,13,15] 36 | # lr_values: !!python/tuple [0, 0.01, 0.001] 37 | 38 | lr_epochs: !!python/tuple [12,15] 39 | lr_values: !!python/tuple [0.004, 0.0004] 40 | 41 | half: true 42 | random_init: true 43 | 44 | ADV: 45 | # FGSM parameters during training 46 | clip_eps: 2.0 47 | fgsm_step: 2.5 48 | 49 | # Number of repeats for free adversarial training 50 | n_repeats: 1 51 | 52 | # PGD attack parameters used during validation 53 | # the same clip_eps as above is used for PGD 54 | pgd_attack: 55 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 56 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 57 | 58 | DATA: 59 | # Number of data workers 60 | workers: 4 61 | 62 | # Color value range 63 | max_color_value: 255.0 64 | 65 | # FAST ADVERSARIAL TRAINING PARAMETER 66 | 67 | # Image Size 68 | img_size: 0 69 | 70 | # Training batch size 71 | batch_size: 128 72 | 73 | # Crop Size for data augmentation 74 | crop_size: 288 75 | 76 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_4px_phase3.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 6 24 | # # Number of training epochs 25 | # epochs: 28 26 | 27 | # lr_epochs: !!python/tuple [25,28] 28 | # lr_values: !!python/tuple [0.0025, 0.00025] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 12 32 | # Number of training epochs 33 | epochs: 15 34 | 35 | # lr_epochs: !!python/tuple [12,13,15] 36 | # lr_values: !!python/tuple [0, 0.01, 0.001] 37 | 38 | lr_epochs: !!python/tuple [12,15] 39 | lr_values: !!python/tuple [0.004, 0.0004] 40 | 41 | half: true 42 | random_init: true 43 | 44 | ADV: 45 | # FGSM parameters during training 46 | clip_eps: 4.0 47 | fgsm_step: 5.0 48 | 49 | # Number of repeats for free adversarial training 50 | n_repeats: 1 51 | 52 | # PGD attack parameters used during validation 53 | # the same clip_eps as above is used for PGD 54 | pgd_attack: 55 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 56 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 57 | 58 | DATA: 59 | # Number of data workers 60 | workers: 4 61 | 62 | # Color value range 63 | max_color_value: 255.0 64 | 65 | # FAST ADVERSARIAL TRAINING PARAMETER 66 | 67 | # Image Size 68 | img_size: 0 69 | 70 | # Training batch size 71 | batch_size: 128 72 | 73 | # Crop Size for data augmentation 74 | crop_size: 288 75 | 76 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_2px_evaluate.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 6 24 | # # Number of training epochs 25 | # epochs: 28 26 | 27 | # lr_epochs: !!python/tuple [25,28] 28 | # lr_values: !!python/tuple [0.0025, 0.00025] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 12 32 | # Number of training epochs 33 | epochs: 15 34 | 35 | # lr_epochs: !!python/tuple [12,13,15] 36 | # lr_values: !!python/tuple [0, 0.01, 0.001] 37 | 38 | lr_epochs: !!python/tuple [12,15] 39 | lr_values: !!python/tuple [0.004, 0.0004] 40 | 41 | half: true 42 | random_init: true 43 | 44 | ADV: 45 | # FGSM parameters during training 46 | clip_eps: 2.0 47 | fgsm_step: 2.0 48 | 49 | # Number of repeats for free adversarial training 50 | n_repeats: 1 51 | 52 | # PGD attack parameters used during validation 53 | # the same clip_eps as above is used for PGD 54 | pgd_attack: 55 | #- !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 56 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 57 | 58 | DATA: 59 | # Number of data workers 60 | workers: 4 61 | 62 | # Color value range 63 | max_color_value: 255.0 64 | 65 | # FAST ADVERSARIAL TRAINING PARAMETER 66 | 67 | # Image Size 68 | img_size: 0 69 | 70 | # Training batch size 71 | batch_size: 128 72 | 73 | # Crop Size for data augmentation 74 | crop_size: 288 75 | 76 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_4px_evaluate.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 6 24 | # # Number of training epochs 25 | # epochs: 28 26 | 27 | # lr_epochs: !!python/tuple [25,28] 28 | # lr_values: !!python/tuple [0.0025, 0.00025] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 12 32 | # Number of training epochs 33 | epochs: 15 34 | 35 | # lr_epochs: !!python/tuple [12,13,15] 36 | # lr_values: !!python/tuple [0, 0.01, 0.001] 37 | 38 | lr_epochs: !!python/tuple [12,15] 39 | lr_values: !!python/tuple [0.004, 0.0004] 40 | 41 | half: true 42 | random_init: true 43 | 44 | ADV: 45 | # FGSM parameters during training 46 | clip_eps: 4.0 47 | fgsm_step: 5.0 48 | 49 | # Number of repeats for free adversarial training 50 | n_repeats: 1 51 | 52 | # PGD attack parameters used during validation 53 | # the same clip_eps as above is used for PGD 54 | pgd_attack: 55 | #- !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 56 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 57 | 58 | DATA: 59 | # Number of data workers 60 | workers: 4 61 | 62 | # Color value range 63 | max_color_value: 255.0 64 | 65 | # FAST ADVERSARIAL TRAINING PARAMETER 66 | 67 | # Image Size 68 | img_size: 0 69 | 70 | # Training batch size 71 | batch_size: 128 72 | 73 | # Crop Size for data augmentation 74 | crop_size: 288 75 | 76 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_2px_phase2.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 12 24 | # # Number of training epochs 25 | # epochs: 24 26 | 27 | # lr_epochs: !!python/tuple [12,22,25] 28 | # lr_values: !!python/tuple [0.4375,0.04375,0.004375] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 6 32 | # Number of training epochs 33 | epochs: 12 34 | 35 | lr_epochs: !!python/tuple [6,12] 36 | lr_values: !!python/tuple [0.04,0.004] 37 | 38 | # lr_epochs: !!python/tuple [6,7,12] 39 | # lr_values: !!python/tuple [0,0.1,0.01] 40 | 41 | # epochs: 18 42 | # lr_epochs: !!python/tuple [6,12,16,18] 43 | # lr_values: !!python/tuple [0.1,0.01,0.001,0] 44 | 45 | half: true 46 | random_init: true 47 | 48 | ADV: 49 | # FGSM parameters during training 50 | clip_eps: 2.0 51 | fgsm_step: 2.5 52 | 53 | # Number of repeats for free adversarial training 54 | n_repeats: 1 55 | 56 | # PGD attack parameters used during validation 57 | # the same clip_eps as above is used for PGD 58 | pgd_attack: 59 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 60 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 61 | 62 | DATA: 63 | # Number of data workers 64 | workers: 16 65 | 66 | # Color value range 67 | max_color_value: 255.0 68 | 69 | # FAST ADVERSARIAL TRAINING PARAMETER 70 | 71 | # Image Size 72 | # img_size: 352 73 | img_size: 0 74 | 75 | # Training batch size 76 | batch_size: 224 77 | 78 | # Crop Size for data augmentation 79 | crop_size: 224 80 | 81 | -------------------------------------------------------------------------------- /ImageNet/configs/configs_fast_4px_phase2.yml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | 3 | # Architecture name, see pytorch models package for 4 | # a list of possible architectures 5 | arch: 'resnet50' 6 | 7 | 8 | # SGD paramters 9 | lr: 0.1 10 | momentum: 0.9 11 | weight_decay: 0.0001 12 | 13 | # Print frequency, is used for both training and testing 14 | print_freq: 10 15 | 16 | # Dataset mean and std used for data normalization 17 | mean: !!python/tuple [0.485, 0.456, 0.406] 18 | std: !!python/tuple [0.229, 0.224, 0.225] 19 | 20 | # FAST ADVERSARIAL TRAINING PARAMETER 21 | 22 | # Starting epoch (interpret as multiplied by n_repeats) 23 | # start_epoch: 12 24 | # # Number of training epochs 25 | # epochs: 24 26 | 27 | # lr_epochs: !!python/tuple [12,22,25] 28 | # lr_values: !!python/tuple [0.4375,0.04375,0.004375] 29 | 30 | # Starting epoch (interpret as multiplied by n_repeats) 31 | start_epoch: 6 32 | # Number of training epochs 33 | epochs: 12 34 | 35 | lr_epochs: !!python/tuple [6,12] 36 | lr_values: !!python/tuple [0.04,0.004] 37 | 38 | # lr_epochs: !!python/tuple [6,7,12] 39 | # lr_values: !!python/tuple [0,0.1,0.01] 40 | 41 | # epochs: 18 42 | # lr_epochs: !!python/tuple [6,12,16,18] 43 | # lr_values: !!python/tuple [0.1,0.01,0.001,0] 44 | 45 | half: true 46 | random_init: true 47 | 48 | ADV: 49 | # FGSM parameters during training 50 | clip_eps: 4.0 51 | fgsm_step: 5.0 52 | 53 | # Number of repeats for free adversarial training 54 | n_repeats: 1 55 | 56 | # PGD attack parameters used during validation 57 | # the same clip_eps as above is used for PGD 58 | pgd_attack: 59 | - !!python/tuple [10, 0.00392156862] #[10 iters, 1.0/255.0] 60 | - !!python/tuple [50, 0.00392156862] #[50 iters, 1.0/255.0] 61 | 62 | DATA: 63 | # Number of data workers 64 | workers: 16 65 | 66 | # Color value range 67 | max_color_value: 255.0 68 | 69 | # FAST ADVERSARIAL TRAINING PARAMETER 70 | 71 | # Image Size 72 | # img_size: 352 73 | img_size: 0 74 | 75 | # Training batch size 76 | batch_size: 224 77 | 78 | # Crop Size for data augmentation 79 | crop_size: 224 80 | 81 | -------------------------------------------------------------------------------- /ImageNet/lib/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import datetime 4 | import torchvision.models as models 5 | import math 6 | import torch 7 | import yaml 8 | from easydict import EasyDict 9 | import shutil 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def adjust_learning_rate(initial_lr, optimizer, epoch, n_repeats): 30 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 31 | lr = initial_lr * (0.1 ** (epoch // int(math.ceil(30./n_repeats)))) 32 | for param_group in optimizer.param_groups: 33 | param_group['lr'] = lr 34 | 35 | 36 | def fgsm(gradz, step_size): 37 | return step_size*torch.sign(gradz) 38 | 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Computes the accuracy over the k top predictions for the specified values of k""" 43 | with torch.no_grad(): 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 54 | res.append(correct_k.mul_(100.0 / batch_size)) 55 | return res 56 | 57 | 58 | def initiate_logger(output_path, evaluate): 59 | if not os.path.isdir(os.path.join('output', output_path)): 60 | os.makedirs(os.path.join('output', output_path)) 61 | logging.basicConfig(level=logging.INFO) 62 | logger = logging.getLogger() 63 | logger.addHandler(logging.FileHandler(os.path.join('output', output_path, 'eval.txt' if evaluate else 'log.txt'),'w')) 64 | logger.info(pad_str(' LOGISTICS ')) 65 | logger.info('Experiment Date: {}'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M'))) 66 | logger.info('Output Name: {}'.format(output_path)) 67 | logger.info('User: {}'.format(os.getenv('USER'))) 68 | return logger 69 | 70 | def get_model_names(): 71 | return sorted(name for name in models.__dict__ 72 | if name.islower() and not name.startswith("__") 73 | and callable(models.__dict__[name])) 74 | 75 | def pad_str(msg, total_len=70): 76 | rem_len = total_len - len(msg) 77 | return '*'*int(rem_len/2) + msg + '*'*int(rem_len/2)\ 78 | 79 | def parse_config_file(args): 80 | with open(args.config) as f: 81 | config = EasyDict(yaml.load(f)) 82 | 83 | # Add args parameters to the dict 84 | for k, v in vars(args).items(): 85 | config[k] = v 86 | 87 | # Add the output path 88 | config.output_name = '{:s}_step{:d}_eps{:d}_repeat{:d}'.format(args.output_prefix, 89 | int(config.ADV.fgsm_step), int(config.ADV.clip_eps), 90 | config.ADV.n_repeats) 91 | return config 92 | 93 | 94 | def save_checkpoint(state, is_best, filepath, epoch): 95 | filename = os.path.join(filepath, f'checkpoint_epoch{epoch}.pth.tar') 96 | # Save model 97 | torch.save(state, filename) 98 | # Save best model 99 | if is_best: 100 | shutil.copyfile(filename, os.path.join(filepath, 'model_best.pth.tar')) 101 | -------------------------------------------------------------------------------- /CIFAR10/preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | '''Pre-activation version of the BasicBlock.''' 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(PreActBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | 17 | if stride != 1 or in_planes != self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 20 | ) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(x)) 24 | shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x 25 | out = self.conv1(out) 26 | out = self.conv2(F.relu(self.bn2(out))) 27 | out += shortcut 28 | return out 29 | 30 | 31 | class PreActBottleneck(nn.Module): 32 | '''Pre-activation version of the original Bottleneck module.''' 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1): 36 | super(PreActBottleneck, self).__init__() 37 | self.bn1 = nn.BatchNorm2d(in_planes) 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 43 | 44 | if stride != 1 or in_planes != self.expansion*planes: 45 | self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 47 | ) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(x)) 51 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 52 | out = self.conv1(out) 53 | out = self.conv2(F.relu(self.bn2(out))) 54 | out = self.conv3(F.relu(self.bn3(out))) 55 | out += shortcut 56 | return out 57 | 58 | 59 | class PreActResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(PreActResNet, self).__init__() 62 | self.in_planes = 64 63 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 65 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 66 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 67 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 68 | self.bn = nn.BatchNorm2d(512 * block.expansion) 69 | self.linear = nn.Linear(512 * block.expansion, num_classes) 70 | 71 | def _make_layer(self, block, planes, num_blocks, stride): 72 | strides = [stride] + [1]*(num_blocks-1) 73 | layers = [] 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, stride)) 76 | self.in_planes = planes * block.expansion 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 | out = self.layer1(out) 82 | out = self.layer2(out) 83 | out = self.layer3(out) 84 | out = self.layer4(out) 85 | out = F.relu(self.bn(out)) 86 | out = F.avg_pool2d(out, 4) 87 | out = out.view(out.size(0), -1) 88 | out = self.linear(out) 89 | return out 90 | 91 | 92 | def PreActResNet18(): 93 | return PreActResNet(PreActBlock, [2,2,2,2]) 94 | -------------------------------------------------------------------------------- /CIFAR10/utils.py: -------------------------------------------------------------------------------- 1 | import apex.amp as amp 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision import datasets, transforms 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | import numpy as np 7 | 8 | cifar10_mean = (0.4914, 0.4822, 0.4465) 9 | cifar10_std = (0.2471, 0.2435, 0.2616) 10 | 11 | mu = torch.tensor(cifar10_mean).view(3,1,1).cuda() 12 | std = torch.tensor(cifar10_std).view(3,1,1).cuda() 13 | 14 | upper_limit = ((1 - mu)/ std) 15 | lower_limit = ((0 - mu)/ std) 16 | 17 | 18 | def clamp(X, lower_limit, upper_limit): 19 | return torch.max(torch.min(X, upper_limit), lower_limit) 20 | 21 | 22 | def get_loaders(dir_, batch_size): 23 | train_transform = transforms.Compose([ 24 | transforms.RandomCrop(32, padding=4), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | transforms.Normalize(cifar10_mean, cifar10_std), 28 | ]) 29 | test_transform = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize(cifar10_mean, cifar10_std), 32 | ]) 33 | num_workers = 2 34 | train_dataset = datasets.CIFAR10( 35 | dir_, train=True, transform=train_transform, download=True) 36 | test_dataset = datasets.CIFAR10( 37 | dir_, train=False, transform=test_transform, download=True) 38 | train_loader = torch.utils.data.DataLoader( 39 | dataset=train_dataset, 40 | batch_size=batch_size, 41 | shuffle=True, 42 | pin_memory=True, 43 | num_workers=num_workers, 44 | ) 45 | test_loader = torch.utils.data.DataLoader( 46 | dataset=test_dataset, 47 | batch_size=batch_size, 48 | shuffle=False, 49 | pin_memory=True, 50 | num_workers=2, 51 | ) 52 | return train_loader, test_loader 53 | 54 | 55 | def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts, opt=None): 56 | max_loss = torch.zeros(y.shape[0]).cuda() 57 | max_delta = torch.zeros_like(X).cuda() 58 | for zz in range(restarts): 59 | delta = torch.zeros_like(X).cuda() 60 | for i in range(len(epsilon)): 61 | delta[:, i, :, :].uniform_(-epsilon[i][0][0].item(), epsilon[i][0][0].item()) 62 | delta.data = clamp(delta, lower_limit - X, upper_limit - X) 63 | delta.requires_grad = True 64 | for _ in range(attack_iters): 65 | output = model(X + delta) 66 | index = torch.where(output.max(1)[1] == y) 67 | if len(index[0]) == 0: 68 | break 69 | loss = F.cross_entropy(output, y) 70 | if opt is not None: 71 | with amp.scale_loss(loss, opt) as scaled_loss: 72 | scaled_loss.backward() 73 | else: 74 | loss.backward() 75 | grad = delta.grad.detach() 76 | d = delta[index[0], :, :, :] 77 | g = grad[index[0], :, :, :] 78 | d = clamp(d + alpha * torch.sign(g), -epsilon, epsilon) 79 | d = clamp(d, lower_limit - X[index[0], :, :, :], upper_limit - X[index[0], :, :, :]) 80 | delta.data[index[0], :, :, :] = d 81 | delta.grad.zero_() 82 | all_loss = F.cross_entropy(model(X+delta), y, reduction='none').detach() 83 | max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss] 84 | max_loss = torch.max(max_loss, all_loss) 85 | return max_delta 86 | 87 | 88 | def evaluate_pgd(test_loader, model, attack_iters, restarts): 89 | epsilon = (8 / 255.) / std 90 | alpha = (2 / 255.) / std 91 | pgd_loss = 0 92 | pgd_acc = 0 93 | n = 0 94 | model.eval() 95 | for i, (X, y) in enumerate(test_loader): 96 | X, y = X.cuda(), y.cuda() 97 | pgd_delta = attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts) 98 | with torch.no_grad(): 99 | output = model(X + pgd_delta) 100 | loss = F.cross_entropy(output, y) 101 | pgd_loss += loss.item() * y.size(0) 102 | pgd_acc += (output.max(1)[1] == y).sum().item() 103 | n += y.size(0) 104 | return pgd_loss/n, pgd_acc/n 105 | 106 | 107 | def evaluate_standard(test_loader, model): 108 | test_loss = 0 109 | test_acc = 0 110 | n = 0 111 | model.eval() 112 | with torch.no_grad(): 113 | for i, (X, y) in enumerate(test_loader): 114 | X, y = X.cuda(), y.cuda() 115 | output = model(X) 116 | loss = F.cross_entropy(output, y) 117 | test_loss += loss.item() * y.size(0) 118 | test_acc += (output.max(1)[1] == y).sum().item() 119 | n += y.size(0) 120 | return test_loss/n, test_acc/n 121 | -------------------------------------------------------------------------------- /MNIST/evaluate_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import sys 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | from torchvision import datasets, transforms 12 | from torch.utils.data import DataLoader, Dataset 13 | 14 | from mnist_net import mnist_net 15 | 16 | logger = logging.getLogger(__name__) 17 | logging.basicConfig( 18 | format='[%(asctime)s %(filename)s %(name)s %(levelname)s] - %(message)s', 19 | datefmt='%Y/%m/%d %H:%M:%S', 20 | level=logging.DEBUG) 21 | 22 | 23 | def clamp(X, lower_limit, upper_limit): 24 | return torch.max(torch.min(X, upper_limit), lower_limit) 25 | 26 | 27 | def attack_fgsm(model, X, y, epsilon): 28 | delta = torch.zeros_like(X, requires_grad=True) 29 | output = model(X + delta) 30 | loss = F.cross_entropy(output, y) 31 | loss.backward() 32 | grad = delta.grad.detach() 33 | delta.data = epsilon * torch.sign(grad) 34 | return delta.detach() 35 | 36 | 37 | def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts): 38 | max_loss = torch.zeros(y.shape[0]).cuda() 39 | max_delta = torch.zeros_like(X).cuda() 40 | for _ in range(restarts): 41 | delta = torch.zeros_like(X).uniform_(-epsilon, epsilon).cuda() 42 | delta.data = clamp(delta, 0-X, 1-X) 43 | delta.requires_grad = True 44 | for _ in range(attack_iters): 45 | output = model(X + delta) 46 | index = torch.where(output.max(1)[1] == y)[0] 47 | if len(index) == 0: 48 | break 49 | loss = F.cross_entropy(output, y) 50 | loss.backward() 51 | grad = delta.grad.detach() 52 | d = torch.clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon) 53 | d = clamp(d, 0-X, 1-X) 54 | delta.data[index] = d[index] 55 | delta.grad.zero_() 56 | all_loss = F.cross_entropy(model(X+delta), y, reduction='none') 57 | max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss] 58 | max_loss = torch.max(max_loss, all_loss) 59 | return max_delta 60 | 61 | 62 | def get_args(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--batch-size', default=100, type=int) 65 | parser.add_argument('--data-dir', default='../mnist-data', type=str) 66 | parser.add_argument('--fname', type=str) 67 | parser.add_argument('--attack', default='pgd', type=str, choices=['pgd', 'fgsm', 'none']) 68 | parser.add_argument('--epsilon', default=0.3, type=float) 69 | parser.add_argument('--attack-iters', default=50, type=int) 70 | parser.add_argument('--alpha', default=1e-2, type=float) 71 | parser.add_argument('--restarts', default=10, type=int) 72 | parser.add_argument('--seed', default=0, type=int) 73 | return parser.parse_args() 74 | 75 | 76 | def main(): 77 | args = get_args() 78 | logger.info(args) 79 | 80 | np.random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | torch.cuda.manual_seed(args.seed) 83 | 84 | mnist_test = datasets.MNIST("../mnist-data", train=False, download=True, transform=transforms.ToTensor()) 85 | test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=args.batch_size, shuffle=False) 86 | 87 | model = mnist_net().cuda() 88 | checkpoint = torch.load(args.fname) 89 | model.load_state_dict(checkpoint) 90 | model.eval() 91 | 92 | total_loss = 0 93 | total_acc = 0 94 | n = 0 95 | 96 | if args.attack == 'none': 97 | with torch.no_grad(): 98 | for i, (X, y) in enumerate(test_loader): 99 | X, y = X.cuda(), y.cuda() 100 | output = model(X) 101 | loss = F.cross_entropy(output, y) 102 | total_loss += loss.item() * y.size(0) 103 | total_acc += (output.max(1)[1] == y).sum().item() 104 | n += y.size(0) 105 | else: 106 | for i, (X, y) in enumerate(test_loader): 107 | X, y = X.cuda(), y.cuda() 108 | if args.attack == 'pgd': 109 | delta = attack_pgd(model, X, y, args.epsilon, args.alpha, args.attack_iters, args.restarts) 110 | elif args.attack == 'fgsm': 111 | delta = attack_fgsm(model, X, y, args.epsilon) 112 | with torch.no_grad(): 113 | output = model(X + delta) 114 | loss = F.cross_entropy(output, y) 115 | total_loss += loss.item() * y.size(0) 116 | total_acc += (output.max(1)[1] == y).sum().item() 117 | n += y.size(0) 118 | 119 | logger.info('Test Loss: %.4f, Acc: %.4f', total_loss/n, total_acc/n) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /MNIST/train_mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torchvision import datasets, transforms 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | from mnist_net import mnist_net 14 | 15 | logger = logging.getLogger(__name__) 16 | logging.basicConfig( 17 | format='[%(asctime)s] - %(message)s', 18 | datefmt='%Y/%m/%d %H:%M:%S', 19 | level=logging.DEBUG) 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--batch-size', default=100, type=int) 25 | parser.add_argument('--data-dir', default='../mnist-data', type=str) 26 | parser.add_argument('--epochs', default=10, type=int) 27 | parser.add_argument('--attack', default='fgsm', type=str, choices=['none', 'pgd', 'fgsm']) 28 | parser.add_argument('--epsilon', default=0.3, type=float) 29 | parser.add_argument('--alpha', default=0.375, type=float) 30 | parser.add_argument('--attack-iters', default=40, type=int) 31 | parser.add_argument('--lr-max', default=5e-3, type=float) 32 | parser.add_argument('--lr-type', default='cyclic') 33 | parser.add_argument('--fname', default='mnist_model', type=str) 34 | parser.add_argument('--seed', default=0, type=int) 35 | return parser.parse_args() 36 | 37 | 38 | def main(): 39 | args = get_args() 40 | logger.info(args) 41 | 42 | np.random.seed(args.seed) 43 | torch.manual_seed(args.seed) 44 | torch.cuda.manual_seed(args.seed) 45 | 46 | mnist_train = datasets.MNIST("../mnist-data", train=True, download=True, transform=transforms.ToTensor()) 47 | train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=args.batch_size, shuffle=True) 48 | 49 | model = mnist_net().cuda() 50 | model.train() 51 | 52 | opt = torch.optim.Adam(model.parameters(), lr=args.lr_max) 53 | if args.lr_type == 'cyclic': 54 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2//5, args.epochs], [0, args.lr_max, 0])[0] 55 | elif args.lr_type == 'flat': 56 | lr_schedule = lambda t: args.lr_max 57 | else: 58 | raise ValueError('Unknown lr_type') 59 | 60 | criterion = nn.CrossEntropyLoss() 61 | 62 | logger.info('Epoch \t Time \t LR \t \t Train Loss \t Train Acc') 63 | for epoch in range(args.epochs): 64 | start_time = time.time() 65 | train_loss = 0 66 | train_acc = 0 67 | train_n = 0 68 | for i, (X, y) in enumerate(train_loader): 69 | X, y = X.cuda(), y.cuda() 70 | lr = lr_schedule(epoch + (i+1)/len(train_loader)) 71 | opt.param_groups[0].update(lr=lr) 72 | 73 | if args.attack == 'fgsm': 74 | delta = torch.zeros_like(X).uniform_(-args.epsilon, args.epsilon).cuda() 75 | delta.requires_grad = True 76 | output = model(X + delta) 77 | loss = F.cross_entropy(output, y) 78 | loss.backward() 79 | grad = delta.grad.detach() 80 | delta.data = torch.clamp(delta + args.alpha * torch.sign(grad), -args.epsilon, args.epsilon) 81 | delta.data = torch.max(torch.min(1-X, delta.data), 0-X) 82 | delta = delta.detach() 83 | elif args.attack == 'none': 84 | delta = torch.zeros_like(X) 85 | elif args.attack == 'pgd': 86 | delta = torch.zeros_like(X).uniform_(-args.epsilon, args.epsilon) 87 | delta.data = torch.max(torch.min(1-X, delta.data), 0-X) 88 | for _ in range(args.attack_iters): 89 | delta.requires_grad = True 90 | output = model(X + delta) 91 | loss = criterion(output, y) 92 | opt.zero_grad() 93 | loss.backward() 94 | grad = delta.grad.detach() 95 | I = output.max(1)[1] == y 96 | delta.data[I] = torch.clamp(delta + args.alpha * torch.sign(grad), -args.epsilon, args.epsilon)[I] 97 | delta.data[I] = torch.max(torch.min(1-X, delta.data), 0-X)[I] 98 | delta = delta.detach() 99 | output = model(torch.clamp(X + delta, 0, 1)) 100 | loss = criterion(output, y) 101 | opt.zero_grad() 102 | loss.backward() 103 | opt.step() 104 | 105 | train_loss += loss.item() * y.size(0) 106 | train_acc += (output.max(1)[1] == y).sum().item() 107 | train_n += y.size(0) 108 | 109 | train_time = time.time() 110 | logger.info('%d \t %.1f \t %.4f \t %.4f \t %.4f', 111 | epoch, train_time - start_time, lr, train_loss/train_n, train_acc/train_n) 112 | torch.save(model.state_dict(), args.fname) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast adversarial training using FGSM 2 | 3 | *A repository that implements the fast adversarial training code using an FGSM adversary, capable of training a robust CIFAR10 classifier in 6 minutes and a robust ImageNet classifier in 12 hours. Created by [Eric Wong](https://riceric22.github.io), [Leslie Rice](https://leslierice1.github.io/), and [Zico Kolter](http://zicokolter.com). See our paper on arXiv [here][paper], which was inspired by the free adversarial training paper [here][freepaper] by Shafahi et al. (2019).* 4 | 5 | [paper]: https://arxiv.org/abs/2001.03994 6 | [freepaper]: https://arxiv.org/abs/1904.12843 7 | 8 | ## News 9 | + 12/19/2019 - Accepted to ICLR 2020 10 | + 1/14/2019 - arXiv posted and repository release 11 | 12 | ## What is in this repository? 13 | + An implementation of the FGSM adversarial training method with randomized initialization for MNIST, CIFAR10, and ImageNet 14 | + [Cyclic learning rates](https://arxiv.org/abs/1506.01186) and mixed precision training using the [apex](https://nvidia.github.io/apex/) library to achieve DAWNBench-like speedups 15 | + Pre-trained models using this code base 16 | + The ImageNet code is mostly forked from the [free adversarial training repository](https://github.com/mahyarnajibi/FreeAdversarialTraining), with the corresponding modifications for fast FGSM adversarial training 17 | 18 | ## Installation and usage 19 | + All examples can be run without mixed-precision with PyTorch v1.0 or higher 20 | + To use mixed-precision training, follow the apex installation instructions [here](https://github.com/NVIDIA/apex#quick-start) 21 | 22 | ## But wait, I thought FGSM training didn't work! 23 | As one of the earliest methods for generating adversarial examples, the Fast Gradient Sign Method (FGSM) is also known to be one of the weakest. It has largely been replaced by the PGD-based attacked, and it's use as an attack has become highly discouraged when [evaluating adversarial robustness](https://arxiv.org/abs/1902.06705). Afterall, early attempts at using FGSM adversarial training (including variants of randomized FGSM) were unsuccessful, and this was largely attributed to the weakness of the attack. 24 | 25 | However, we discovered that a fairly minor modification to the random initialization for FGSM adversarial training allows it to perform as well as the much more expensive PGD adversarial training. This was quite surprising to us, and suggests that one does not need very strong adversaries to learn robust models! As a result, we pushed the FGSM adversarial training to the limit, and found that by incorporating various techniques for fast training used in the [DAWNBench](https://dawn.cs.stanford.edu/benchmark/) competition, we could learn robust architectures an order of magnitude faster than before, while achieving the same degrees of robustness. A couple of the results from the paper are highlighted in the table below. 26 | 27 | | | CIFAR10 Acc | CIFAR10 Adv Acc (eps=8/255) | Time (minutes) | 28 | | --------:| -----------:|----------------------------:|---------------:| 29 | | FGSM | 86.06% | 46.06% | 12 | 30 | | Free | 85.96% | 46.33% | 785 | 31 | | PGD | 87.30% | 45.80% | 4966 | 32 | 33 | | | ImageNet Acc | ImageNet Adv Acc (eps=2/255) | Time (hours) | 34 | | --------:| ------------:|-----------------------------:|-------------:| 35 | | FGSM | 60.90% | 43.46% | 12 | 36 | | Free | 64.37% | 43.31% | 52 | 37 | 38 | ## But I've tried FGSM adversarial training before, and it didn't work! 39 | In our experiments, we discovered several failure modes which would cause FGSM adversarial training to ``catastrophically fail'', like in the following plot. 40 | 41 | ![overfitting](https://github.com/locuslab/fast_adversarial/blob/master/overfitting_error_curve.png) 42 | 43 | If FGSM adversarial training hasn't worked for you in the past, then it may be because of one of the following reasons (which we present as a non-exhaustive list of ways to fail): 44 | 45 | + FGSM step size is too large, forcing the adversarial examples to cluster near the boundary 46 | + Random initialization only covers a smaller subset of the threat model 47 | + Long training with many epochs and fine tuning with very small learning rates 48 | 49 | All of these pitfalls can be avoided by simply using early stopping based on a subset of the training data to evaluate the robust accuracy with respect to PGD, as the failure mode for FGSM adversarial training occurs quite rapidly (going to 0% robust accuracy within the span of a couple epochs) 50 | 51 | ## Why does this matter if I still want to use PGD adversarial training in my experiments? 52 | 53 | The speedups gained from using mixed-precision arithmetic and cyclic learning rates can still be reaped regardless of what training regimen you end up using! For example, these techniques can speed up CIFAR10 PGD adversarial training by almost 2 orders of magnitude, reducing training time by about 3.5 days to just over 1 hour. The engineering costs of installing the `apex` library and changing the learning rate schedule are miniscule in comparison to the time saved from using these two techniques, and so even if you don't use FGSM adversarial training, you can still benefit from faster experimentation with the DAWNBench improvements. 54 | -------------------------------------------------------------------------------- /ImageNet/lib/validation.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import torch 3 | import sys 4 | import numpy as np 5 | import time 6 | from torch.autograd import Variable 7 | 8 | def validate_pgd(val_loader, model, criterion, K, step, configs, logger): 9 | # Mean/Std for normalization 10 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis]) 11 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda() 12 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis]) 13 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda() 14 | # Initiate the meters 15 | batch_time = AverageMeter() 16 | losses = AverageMeter() 17 | top1 = AverageMeter() 18 | top5 = AverageMeter() 19 | 20 | eps = configs.ADV.clip_eps 21 | model.eval() 22 | end = time.time() 23 | logger.info(pad_str(' PGD eps: {}, K: {}, step: {}, restarts: {} '.format(eps, K, step, configs.restarts))) 24 | for i, (input, target) in enumerate(val_loader): 25 | 26 | input = input.cuda(non_blocking=True) 27 | target = target.cuda(non_blocking=True) 28 | 29 | orig_input = input.clone() 30 | 31 | for __ in range(configs.restarts): 32 | randn = torch.FloatTensor(input.size()).uniform_(-eps, eps).cuda() 33 | input = orig_input + randn 34 | input.clamp_(0, 1.0) 35 | for _ in range(K): 36 | invar = Variable(input, requires_grad=True) 37 | in1 = invar - mean 38 | in1.div_(std) 39 | output = model(in1) 40 | ascend_loss = criterion(output, target) 41 | ascend_grad = torch.autograd.grad(ascend_loss, invar)[0] 42 | pert = fgsm(ascend_grad, step) 43 | # Apply purturbation 44 | input += pert.data 45 | input = torch.max(orig_input-eps, input) 46 | input = torch.min(orig_input+eps, input) 47 | input.clamp_(0, 1.0) 48 | 49 | input.sub_(mean).div_(std) 50 | with torch.no_grad(): 51 | if __ == 0: 52 | final_input = input.clone() 53 | else: 54 | output = model(input) 55 | I = output.max(1)[1] != target 56 | final_input[I] = input[I] 57 | 58 | with torch.no_grad(): 59 | # compute output 60 | input = final_input 61 | output = model(input) 62 | loss = criterion(output, target) 63 | 64 | # measure accuracy and record loss 65 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 66 | losses.update(loss.item(), input.size(0)) 67 | top1.update(prec1[0], input.size(0)) 68 | top5.update(prec5[0], input.size(0)) 69 | 70 | # measure elapsed time 71 | batch_time.update(time.time() - end) 72 | end = time.time() 73 | 74 | if i % configs.TRAIN.print_freq == 0: 75 | print('PGD Test: [{0}/{1}]\t' 76 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 77 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 78 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 79 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 80 | i, len(val_loader), batch_time=batch_time, loss=losses, 81 | top1=top1, top5=top5)) 82 | sys.stdout.flush() 83 | 84 | print(' PGD Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 85 | .format(top1=top1, top5=top5)) 86 | 87 | return top1.avg 88 | 89 | def validate(val_loader, model, criterion, configs, logger): 90 | # Mean/Std for normalization 91 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis]) 92 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda() 93 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis]) 94 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda() 95 | 96 | # Initiate the meters 97 | batch_time = AverageMeter() 98 | losses = AverageMeter() 99 | top1 = AverageMeter() 100 | top5 = AverageMeter() 101 | # switch to evaluate mode 102 | model.eval() 103 | end = time.time() 104 | for i, (input, target) in enumerate(val_loader): 105 | with torch.no_grad(): 106 | input = input.cuda(non_blocking=True) 107 | target = target.cuda(non_blocking=True) 108 | 109 | # compute output 110 | input = input - mean 111 | input.div_(std) 112 | output = model(input) 113 | loss = criterion(output, target) 114 | 115 | # measure accuracy and record loss 116 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 117 | losses.update(loss.item(), input.size(0)) 118 | top1.update(prec1[0], input.size(0)) 119 | top5.update(prec5[0], input.size(0)) 120 | 121 | # measure elapsed time 122 | batch_time.update(time.time() - end) 123 | end = time.time() 124 | 125 | if i % configs.TRAIN.print_freq == 0: 126 | print('Test: [{0}/{1}]\t' 127 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 128 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 129 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 130 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 131 | i, len(val_loader), batch_time=batch_time, loss=losses, 132 | top1=top1, top5=top5)) 133 | sys.stdout.flush() 134 | 135 | print(' Final Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 136 | .format(top1=top1, top5=top5)) 137 | return top1.avg -------------------------------------------------------------------------------- /CIFAR10/train_free.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | 6 | import apex.amp as amp 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from preact_resnet import PreActResNet18 13 | from utils import (upper_limit, lower_limit, std, clamp, get_loaders, 14 | evaluate_pgd, evaluate_standard) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--batch-size', default=128, type=int) 22 | parser.add_argument('--data-dir', default='../../cifar-data', type=str) 23 | parser.add_argument('--epochs', default=10, type=int, help='Total number of epochs will be this argument * number of minibatch replays.') 24 | parser.add_argument('--lr-schedule', default='cyclic', type=str, choices=['cyclic', 'multistep']) 25 | parser.add_argument('--lr-min', default=0., type=float) 26 | parser.add_argument('--lr-max', default=0.04, type=float) 27 | parser.add_argument('--weight-decay', default=5e-4, type=float) 28 | parser.add_argument('--momentum', default=0.9, type=float) 29 | parser.add_argument('--epsilon', default=8, type=int) 30 | parser.add_argument('--minibatch-replays', default=8, type=int) 31 | parser.add_argument('--out-dir', default='train_free_output', type=str, help='Output directory') 32 | parser.add_argument('--seed', default=0, type=int) 33 | parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'], 34 | help='O0 is FP32 training, O1 is Mixed Precision, and O2 is "Almost FP16" Mixed Precision') 35 | parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'], 36 | help='If loss_scale is "dynamic", adaptively adjust the loss scale over time') 37 | parser.add_argument('--master-weights', action='store_true', 38 | help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level') 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | args = get_args() 44 | 45 | if not os.path.exists(args.out_dir): 46 | os.mkdir(args.out_dir) 47 | logfile = os.path.join(args.out_dir, 'output.log') 48 | if os.path.exists(logfile): 49 | os.remove(logfile) 50 | 51 | logging.basicConfig( 52 | format='[%(asctime)s] - %(message)s', 53 | datefmt='%Y/%m/%d %H:%M:%S', 54 | level=logging.INFO, 55 | filename=logfile) 56 | logger.info(args) 57 | 58 | np.random.seed(args.seed) 59 | torch.manual_seed(args.seed) 60 | torch.cuda.manual_seed(args.seed) 61 | 62 | train_loader, test_loader = get_loaders(args.data_dir, args.batch_size) 63 | 64 | epsilon = (args.epsilon / 255.) / std 65 | 66 | model = PreActResNet18().cuda() 67 | model.train() 68 | 69 | opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay) 70 | amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False) 71 | if args.opt_level == 'O2': 72 | amp_args['master_weights'] = args.master_weights 73 | model, opt = amp.initialize(model, opt, **amp_args) 74 | criterion = nn.CrossEntropyLoss() 75 | 76 | delta = torch.zeros(args.batch_size, 3, 32, 32).cuda() 77 | delta.requires_grad = True 78 | 79 | lr_steps = args.epochs * len(train_loader) * args.minibatch_replays 80 | if args.lr_schedule == 'cyclic': 81 | scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max, 82 | step_size_up=lr_steps / 2, step_size_down=lr_steps / 2) 83 | elif args.lr_schedule == 'multistep': 84 | scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1) 85 | 86 | # Training 87 | start_train_time = time.time() 88 | logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc') 89 | for epoch in range(args.epochs): 90 | start_epoch_time = time.time() 91 | train_loss = 0 92 | train_acc = 0 93 | train_n = 0 94 | for i, (X, y) in enumerate(train_loader): 95 | X, y = X.cuda(), y.cuda() 96 | for _ in range(args.minibatch_replays): 97 | output = model(X + delta[:X.size(0)]) 98 | loss = criterion(output, y) 99 | opt.zero_grad() 100 | with amp.scale_loss(loss, opt) as scaled_loss: 101 | scaled_loss.backward() 102 | grad = delta.grad.detach() 103 | delta.data = clamp(delta + epsilon * torch.sign(grad), -epsilon, epsilon) 104 | delta.data[:X.size(0)] = clamp(delta[:X.size(0)], lower_limit - X, upper_limit - X) 105 | opt.step() 106 | delta.grad.zero_() 107 | scheduler.step() 108 | train_loss += loss.item() * y.size(0) 109 | train_acc += (output.max(1)[1] == y).sum().item() 110 | train_n += y.size(0) 111 | epoch_time = time.time() 112 | lr = scheduler.get_lr()[0] 113 | logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f', 114 | epoch, epoch_time - start_epoch_time, lr, train_loss/train_n, train_acc/train_n) 115 | train_time = time.time() 116 | torch.save(model.state_dict(), os.path.join(args.out_dir, 'model.pth')) 117 | logger.info('Total train time: %.4f minutes', (train_time - start_train_time)/60) 118 | 119 | # Evaluation 120 | model_test = PreActResNet18().cuda() 121 | model_test.load_state_dict(model.state_dict()) 122 | model_test.float() 123 | model_test.eval() 124 | 125 | pgd_loss, pgd_acc = evaluate_pgd(test_loader, model_test, 50, 10) 126 | test_loss, test_acc = evaluate_standard(test_loader, model_test) 127 | 128 | logger.info('Test Loss \t Test Acc \t PGD Loss \t PGD Acc') 129 | logger.info('%.4f \t \t %.4f \t %.4f \t %.4f', test_loss, test_acc, pgd_loss, pgd_acc) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /CIFAR10/train_pgd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | 6 | import apex.amp as amp 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from preact_resnet import PreActResNet18 13 | from utils import (upper_limit, lower_limit, std, clamp, get_loaders, 14 | evaluate_pgd, evaluate_standard) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--batch-size', default=128, type=int) 22 | parser.add_argument('--data-dir', default='../../cifar-data', type=str) 23 | parser.add_argument('--epochs', default=15, type=int) 24 | parser.add_argument('--lr-schedule', default='cyclic', type=str, choices=['cyclic', 'multistep']) 25 | parser.add_argument('--lr-min', default=0., type=float) 26 | parser.add_argument('--lr-max', default=0.2, type=float) 27 | parser.add_argument('--weight-decay', default=5e-4, type=float) 28 | parser.add_argument('--momentum', default=0.9, type=float) 29 | parser.add_argument('--epsilon', default=8, type=int) 30 | parser.add_argument('--attack-iters', default=7, type=int, help='Attack iterations') 31 | parser.add_argument('--restarts', default=1, type=int) 32 | parser.add_argument('--alpha', default=2, type=int, help='Step size') 33 | parser.add_argument('--delta-init', default='random', choices=['zero', 'random'], 34 | help='Perturbation initialization method') 35 | parser.add_argument('--out-dir', default='train_pgd_output', type=str, help='Output directory') 36 | parser.add_argument('--seed', default=0, type=int, help='Random seed') 37 | parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'], 38 | help='O0 is FP32 training, O1 is Mixed Precision, and O2 is "Almost FP16" Mixed Precision') 39 | parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'], 40 | help='If loss_scale is "dynamic", adaptively adjust the loss scale over time') 41 | parser.add_argument('--master-weights', action='store_true', 42 | help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level') 43 | return parser.parse_args() 44 | 45 | 46 | def main(): 47 | args = get_args() 48 | 49 | if not os.path.exists(args.out_dir): 50 | os.mkdir(args.out_dir) 51 | logfile = os.path.join(args.out_dir, 'output.log') 52 | if os.path.exists(logfile): 53 | os.remove(logfile) 54 | 55 | logging.basicConfig( 56 | format='[%(asctime)s] - %(message)s', 57 | datefmt='%Y/%m/%d %H:%M:%S', 58 | level=logging.INFO, 59 | filename=logfile) 60 | logger.info(args) 61 | 62 | np.random.seed(args.seed) 63 | torch.manual_seed(args.seed) 64 | torch.cuda.manual_seed(args.seed) 65 | 66 | train_loader, test_loader = get_loaders(args.data_dir, args.batch_size) 67 | 68 | epsilon = (args.epsilon / 255.) / std 69 | alpha = (args.alpha / 255.) / std 70 | 71 | model = PreActResNet18().cuda() 72 | model.train() 73 | 74 | opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay) 75 | amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False) 76 | if args.opt_level == 'O2': 77 | amp_args['master_weights'] = args.master_weights 78 | model, opt = amp.initialize(model, opt, **amp_args) 79 | criterion = nn.CrossEntropyLoss() 80 | 81 | lr_steps = args.epochs * len(train_loader) 82 | if args.lr_schedule == 'cyclic': 83 | scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max, 84 | step_size_up=lr_steps / 2, step_size_down=lr_steps / 2) 85 | elif args.lr_schedule == 'multistep': 86 | scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1) 87 | 88 | # Training 89 | start_train_time = time.time() 90 | logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc') 91 | for epoch in range(args.epochs): 92 | start_epoch_time = time.time() 93 | train_loss = 0 94 | train_acc = 0 95 | train_n = 0 96 | for i, (X, y) in enumerate(train_loader): 97 | X, y = X.cuda(), y.cuda() 98 | delta = torch.zeros_like(X).cuda() 99 | if args.delta_init == 'random': 100 | for i in range(len(epsilon)): 101 | delta[:, i, :, :].uniform_(-epsilon[i][0][0].item(), epsilon[i][0][0].item()) 102 | delta.data = clamp(delta, lower_limit - X, upper_limit - X) 103 | delta.requires_grad = True 104 | for _ in range(args.attack_iters): 105 | output = model(X + delta) 106 | loss = criterion(output, y) 107 | with amp.scale_loss(loss, opt) as scaled_loss: 108 | scaled_loss.backward() 109 | grad = delta.grad.detach() 110 | delta.data = clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon) 111 | delta.data = clamp(delta, lower_limit - X, upper_limit - X) 112 | delta.grad.zero_() 113 | delta = delta.detach() 114 | output = model(X + delta) 115 | loss = criterion(output, y) 116 | opt.zero_grad() 117 | with amp.scale_loss(loss, opt) as scaled_loss: 118 | scaled_loss.backward() 119 | opt.step() 120 | train_loss += loss.item() * y.size(0) 121 | train_acc += (output.max(1)[1] == y).sum().item() 122 | train_n += y.size(0) 123 | scheduler.step() 124 | epoch_time = time.time() 125 | lr = scheduler.get_lr()[0] 126 | logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f', 127 | epoch, epoch_time - start_epoch_time, lr, train_loss/train_n, train_acc/train_n) 128 | train_time = time.time() 129 | torch.save(model.state_dict(), os.path.join(args.out_dir, 'model.pth')) 130 | logger.info('Total train time: %.4f minutes', (train_time - start_train_time)/60) 131 | 132 | # Evaluation 133 | model_test = PreActResNet18().cuda() 134 | model_test.load_state_dict(model.state_dict()) 135 | model_test.float() 136 | model_test.eval() 137 | 138 | pgd_loss, pgd_acc = evaluate_pgd(test_loader, model_test, 50, 10) 139 | test_loss, test_acc = evaluate_standard(test_loader, model_test) 140 | 141 | logger.info('Test Loss \t Test Acc \t PGD Loss \t PGD Acc') 142 | logger.info('%.4f \t \t %.4f \t %.4f \t %.4f', test_loss, test_acc, pgd_loss, pgd_acc) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /CIFAR10/train_fgsm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from apex import amp 12 | 13 | from preact_resnet import PreActResNet18 14 | from utils import (upper_limit, lower_limit, std, clamp, get_loaders, 15 | attack_pgd, evaluate_pgd, evaluate_standard) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--batch-size', default=128, type=int) 23 | parser.add_argument('--data-dir', default='../../cifar-data', type=str) 24 | parser.add_argument('--epochs', default=15, type=int) 25 | parser.add_argument('--lr-schedule', default='cyclic', choices=['cyclic', 'multistep']) 26 | parser.add_argument('--lr-min', default=0., type=float) 27 | parser.add_argument('--lr-max', default=0.2, type=float) 28 | parser.add_argument('--weight-decay', default=5e-4, type=float) 29 | parser.add_argument('--momentum', default=0.9, type=float) 30 | parser.add_argument('--epsilon', default=8, type=int) 31 | parser.add_argument('--alpha', default=10, type=float, help='Step size') 32 | parser.add_argument('--delta-init', default='random', choices=['zero', 'random', 'previous'], 33 | help='Perturbation initialization method') 34 | parser.add_argument('--out-dir', default='train_fgsm_output', type=str, help='Output directory') 35 | parser.add_argument('--seed', default=0, type=int, help='Random seed') 36 | parser.add_argument('--early-stop', action='store_true', help='Early stop if overfitting occurs') 37 | parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O1', 'O2'], 38 | help='O0 is FP32 training, O1 is Mixed Precision, and O2 is "Almost FP16" Mixed Precision') 39 | parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'], 40 | help='If loss_scale is "dynamic", adaptively adjust the loss scale over time') 41 | parser.add_argument('--master-weights', action='store_true', 42 | help='Maintain FP32 master weights to accompany any FP16 model weights, not applicable for O1 opt level') 43 | return parser.parse_args() 44 | 45 | 46 | def main(): 47 | args = get_args() 48 | 49 | if not os.path.exists(args.out_dir): 50 | os.mkdir(args.out_dir) 51 | logfile = os.path.join(args.out_dir, 'output.log') 52 | if os.path.exists(logfile): 53 | os.remove(logfile) 54 | 55 | logging.basicConfig( 56 | format='[%(asctime)s] - %(message)s', 57 | datefmt='%Y/%m/%d %H:%M:%S', 58 | level=logging.INFO, 59 | filename=os.path.join(args.out_dir, 'output.log')) 60 | logger.info(args) 61 | 62 | np.random.seed(args.seed) 63 | torch.manual_seed(args.seed) 64 | torch.cuda.manual_seed(args.seed) 65 | 66 | train_loader, test_loader = get_loaders(args.data_dir, args.batch_size) 67 | 68 | epsilon = (args.epsilon / 255.) / std 69 | alpha = (args.alpha / 255.) / std 70 | pgd_alpha = (2 / 255.) / std 71 | 72 | model = PreActResNet18().cuda() 73 | model.train() 74 | 75 | opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum, weight_decay=args.weight_decay) 76 | amp_args = dict(opt_level=args.opt_level, loss_scale=args.loss_scale, verbosity=False) 77 | if args.opt_level == 'O2': 78 | amp_args['master_weights'] = args.master_weights 79 | model, opt = amp.initialize(model, opt, **amp_args) 80 | criterion = nn.CrossEntropyLoss() 81 | 82 | if args.delta_init == 'previous': 83 | delta = torch.zeros(args.batch_size, 3, 32, 32).cuda() 84 | 85 | lr_steps = args.epochs * len(train_loader) 86 | if args.lr_schedule == 'cyclic': 87 | scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=args.lr_min, max_lr=args.lr_max, 88 | step_size_up=lr_steps / 2, step_size_down=lr_steps / 2) 89 | elif args.lr_schedule == 'multistep': 90 | scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1) 91 | 92 | # Training 93 | prev_robust_acc = 0. 94 | start_train_time = time.time() 95 | logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc') 96 | for epoch in range(args.epochs): 97 | start_epoch_time = time.time() 98 | train_loss = 0 99 | train_acc = 0 100 | train_n = 0 101 | for i, (X, y) in enumerate(train_loader): 102 | X, y = X.cuda(), y.cuda() 103 | if i == 0: 104 | first_batch = (X, y) 105 | if args.delta_init != 'previous': 106 | delta = torch.zeros_like(X).cuda() 107 | if args.delta_init == 'random': 108 | for j in range(len(epsilon)): 109 | delta[:, j, :, :].uniform_(-epsilon[j][0][0].item(), epsilon[j][0][0].item()) 110 | delta.data = clamp(delta, lower_limit - X, upper_limit - X) 111 | delta.requires_grad = True 112 | output = model(X + delta[:X.size(0)]) 113 | loss = F.cross_entropy(output, y) 114 | with amp.scale_loss(loss, opt) as scaled_loss: 115 | scaled_loss.backward() 116 | grad = delta.grad.detach() 117 | delta.data = clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon) 118 | delta.data[:X.size(0)] = clamp(delta[:X.size(0)], lower_limit - X, upper_limit - X) 119 | delta = delta.detach() 120 | output = model(X + delta[:X.size(0)]) 121 | loss = criterion(output, y) 122 | opt.zero_grad() 123 | with amp.scale_loss(loss, opt) as scaled_loss: 124 | scaled_loss.backward() 125 | opt.step() 126 | train_loss += loss.item() * y.size(0) 127 | train_acc += (output.max(1)[1] == y).sum().item() 128 | train_n += y.size(0) 129 | scheduler.step() 130 | if args.early_stop: 131 | # Check current PGD robustness of model using random minibatch 132 | X, y = first_batch 133 | pgd_delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 5, 1, opt) 134 | with torch.no_grad(): 135 | output = model(clamp(X + pgd_delta[:X.size(0)], lower_limit, upper_limit)) 136 | robust_acc = (output.max(1)[1] == y).sum().item() / y.size(0) 137 | if robust_acc - prev_robust_acc < -0.2: 138 | break 139 | prev_robust_acc = robust_acc 140 | best_state_dict = copy.deepcopy(model.state_dict()) 141 | epoch_time = time.time() 142 | lr = scheduler.get_lr()[0] 143 | logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f', 144 | epoch, epoch_time - start_epoch_time, lr, train_loss/train_n, train_acc/train_n) 145 | train_time = time.time() 146 | if not args.early_stop: 147 | best_state_dict = model.state_dict() 148 | torch.save(best_state_dict, os.path.join(args.out_dir, 'model.pth')) 149 | logger.info('Total train time: %.4f minutes', (train_time - start_train_time)/60) 150 | 151 | # Evaluation 152 | model_test = PreActResNet18().cuda() 153 | model_test.load_state_dict(best_state_dict) 154 | model_test.float() 155 | model_test.eval() 156 | 157 | pgd_loss, pgd_acc = evaluate_pgd(test_loader, model_test, 50, 10) 158 | test_loss, test_acc = evaluate_standard(test_loader, model_test) 159 | 160 | logger.info('Test Loss \t Test Acc \t PGD Loss \t PGD Acc') 161 | logger.info('%.4f \t \t %.4f \t %.4f \t %.4f', test_loss, test_acc, pgd_loss, pgd_acc) 162 | 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /ImageNet/main_fast.py: -------------------------------------------------------------------------------- 1 | # This module is adapted from https://github.com/mahyarnajibi/FreeAdversarialTraining/blob/master/main_free.py 2 | # Which in turn was adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py 3 | import init_paths 4 | import argparse 5 | import os 6 | import time 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | from torch.autograd import Variable 15 | import math 16 | import numpy as np 17 | from utils import * 18 | from validation import validate, validate_pgd 19 | import torchvision.models as models 20 | 21 | from apex import amp 22 | import copy 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('data', metavar='DIR', 28 | help='path to dataset') 29 | parser.add_argument('--output_prefix', default='fast_adv', type=str, 30 | help='prefix used to define output path') 31 | parser.add_argument('-c', '--config', default='configs.yml', type=str, metavar='Path', 32 | help='path to the config file (default: configs.yml)') 33 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 34 | help='path to latest checkpoint (default: none)') 35 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 36 | help='evaluate model on validation set') 37 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 38 | help='use pre-trained model') 39 | parser.add_argument('--restarts', default=1, type=int) 40 | return parser.parse_args() 41 | 42 | 43 | # Parase config file and initiate logging 44 | configs = parse_config_file(parse_args()) 45 | logger = initiate_logger(configs.output_name, configs.evaluate) 46 | print = logger.info 47 | cudnn.benchmark = True 48 | 49 | def main(): 50 | # Scale and initialize the parameters 51 | best_prec1 = 0 52 | configs.TRAIN.epochs = int(math.ceil(configs.TRAIN.epochs / configs.ADV.n_repeats)) 53 | configs.ADV.fgsm_step /= configs.DATA.max_color_value 54 | configs.ADV.clip_eps /= configs.DATA.max_color_value 55 | 56 | # Create output folder 57 | if not os.path.isdir(os.path.join('trained_models', configs.output_name)): 58 | os.makedirs(os.path.join('trained_models', configs.output_name)) 59 | 60 | # Log the config details 61 | logger.info(pad_str(' ARGUMENTS ')) 62 | for k, v in configs.items(): print('{}: {}'.format(k, v)) 63 | logger.info(pad_str('')) 64 | 65 | 66 | # Create the model 67 | if configs.pretrained: 68 | print("=> using pre-trained model '{}'".format(configs.TRAIN.arch)) 69 | model = models.__dict__[configs.TRAIN.arch](pretrained=True) 70 | else: 71 | print("=> creating model '{}'".format(configs.TRAIN.arch)) 72 | model = models.__dict__[configs.TRAIN.arch]() 73 | # Wrap the model into DataParallel 74 | model.cuda() 75 | 76 | # reverse mapping 77 | param_to_moduleName = {} 78 | for m in model.modules(): 79 | for p in m.parameters(recurse=False): 80 | param_to_moduleName[p] = str(type(m).__name__) 81 | 82 | # Criterion: 83 | criterion = nn.CrossEntropyLoss().cuda() 84 | 85 | group_decay = [p for p in model.parameters() if 'BatchNorm' not in param_to_moduleName[p]] 86 | group_no_decay = [p for p in model.parameters() if 'BatchNorm' in param_to_moduleName[p]] 87 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=0)] 88 | optimizer = torch.optim.SGD(groups, configs.TRAIN.lr, 89 | momentum=configs.TRAIN.momentum, 90 | weight_decay=configs.TRAIN.weight_decay) 91 | 92 | if configs.TRAIN.half and not configs.evaluate: 93 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 94 | model = torch.nn.DataParallel(model) 95 | 96 | # Resume if a valid checkpoint path is provided 97 | if configs.resume: 98 | if os.path.isfile(configs.resume): 99 | print("=> loading checkpoint '{}'".format(configs.resume)) 100 | checkpoint = torch.load(configs.resume) 101 | configs.TRAIN.start_epoch = checkpoint['epoch'] 102 | best_prec1 = checkpoint['best_prec1'] 103 | model.load_state_dict(checkpoint['state_dict']) 104 | optimizer.load_state_dict(checkpoint['optimizer']) 105 | print("=> loaded checkpoint '{}' (epoch {})" 106 | .format(configs.resume, checkpoint['epoch'])) 107 | else: 108 | print("=> no checkpoint found at '{}'".format(configs.resume)) 109 | 110 | # Initiate data loaders 111 | traindir = os.path.join(configs.data, 'train') 112 | valdir = os.path.join(configs.data, 'val') 113 | 114 | resize_transform = [] 115 | 116 | if configs.DATA.img_size > 0: 117 | resize_transform = [ transforms.Resize(configs.DATA.img_size) ] 118 | 119 | train_dataset = datasets.ImageFolder( 120 | traindir, 121 | transforms.Compose(resize_transform + [ 122 | transforms.RandomResizedCrop(configs.DATA.crop_size), 123 | transforms.RandomHorizontalFlip(), 124 | transforms.ToTensor(), 125 | ])) 126 | 127 | train_loader = torch.utils.data.DataLoader( 128 | train_dataset, batch_size=configs.DATA.batch_size, shuffle=True, 129 | num_workers=configs.DATA.workers, pin_memory=True, sampler=None) 130 | 131 | normalize = transforms.Normalize(mean=configs.TRAIN.mean, 132 | std=configs.TRAIN.std) 133 | 134 | val_loader = torch.utils.data.DataLoader( 135 | datasets.ImageFolder(valdir, transforms.Compose( resize_transform + [ 136 | transforms.CenterCrop(configs.DATA.crop_size), 137 | transforms.ToTensor(), 138 | ])), 139 | batch_size=configs.DATA.batch_size, shuffle=False, 140 | num_workers=configs.DATA.workers, pin_memory=True) 141 | 142 | # If in evaluate mode: perform validation on PGD attacks as well as clean samples 143 | if configs.evaluate: 144 | logger.info(pad_str(' Performing PGD Attacks ')) 145 | for pgd_param in configs.ADV.pgd_attack: 146 | validate_pgd(val_loader, model, criterion, pgd_param[0], pgd_param[1], configs, logger) 147 | validate(val_loader, model, criterion, configs, logger) 148 | return 149 | 150 | lr_schedule = lambda t: np.interp([t], configs.TRAIN.lr_epochs, configs.TRAIN.lr_values)[0] 151 | 152 | for epoch in range(configs.TRAIN.start_epoch, configs.TRAIN.epochs): 153 | # train for one epoch 154 | train(train_loader, model, criterion, optimizer, epoch, lr_schedule, configs.TRAIN.half) 155 | 156 | # evaluate on validation set 157 | prec1 = validate(val_loader, model, criterion, configs, logger) 158 | 159 | # remember best prec@1 and save checkpoint 160 | is_best = prec1 > best_prec1 161 | best_prec1 = max(prec1, best_prec1) 162 | save_checkpoint({ 163 | 'epoch': epoch + 1, 164 | 'arch': configs.TRAIN.arch, 165 | 'state_dict': model.state_dict(), 166 | 'best_prec1': best_prec1, 167 | 'optimizer' : optimizer.state_dict(), 168 | }, is_best, os.path.join('trained_models', f'{configs.output_name}'), 169 | epoch + 1) 170 | 171 | # Automatically perform PGD Attacks at the end of training 172 | # logger.info(pad_str(' Performing PGD Attacks ')) 173 | # for pgd_param in configs.ADV.pgd_attack: 174 | # validate_pgd(val_loader, val_model, criterion, pgd_param[0], pgd_param[1], configs, logger) 175 | 176 | 177 | # Fast Adversarial Training Module 178 | global global_noise_data 179 | global_noise_data = torch.zeros([configs.DATA.batch_size, 3, configs.DATA.crop_size, configs.DATA.crop_size]).cuda() 180 | def train(train_loader, model, criterion, optimizer, epoch, lr_schedule, half=False): 181 | global global_noise_data 182 | 183 | mean = torch.Tensor(np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis]) 184 | mean = mean.expand(3,configs.DATA.crop_size, configs.DATA.crop_size).cuda() 185 | std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis]) 186 | std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda() 187 | 188 | # Initialize the meters 189 | batch_time = AverageMeter() 190 | data_time = AverageMeter() 191 | losses = AverageMeter() 192 | top1 = AverageMeter() 193 | top5 = AverageMeter() 194 | # switch to train mode 195 | model.train() 196 | end = time.time() 197 | for i, (input, target) in enumerate(train_loader): 198 | input = input.cuda(non_blocking=True) 199 | target = target.cuda(non_blocking=True) 200 | data_time.update(time.time() - end) 201 | 202 | if configs.TRAIN.random_init: 203 | global_noise_data.uniform_(-configs.ADV.clip_eps, configs.ADV.clip_eps) 204 | for j in range(configs.ADV.n_repeats): 205 | # update learning rate 206 | lr = lr_schedule(epoch + (i*configs.ADV.n_repeats + j + 1)/len(train_loader)) 207 | for param_group in optimizer.param_groups: 208 | param_group['lr'] = lr 209 | 210 | # Ascend on the global noise 211 | noise_batch = Variable(global_noise_data[0:input.size(0)], requires_grad=True)#.cuda() 212 | in1 = input + noise_batch 213 | in1.clamp_(0, 1.0) 214 | in1.sub_(mean).div_(std) 215 | output = model(in1) 216 | loss = criterion(output, target) 217 | if half: 218 | with amp.scale_loss(loss, optimizer) as scaled_loss: 219 | scaled_loss.backward() 220 | else: 221 | loss.backward() 222 | 223 | # Update the noise for the next iteration 224 | pert = fgsm(noise_batch.grad, configs.ADV.fgsm_step) 225 | global_noise_data[0:input.size(0)] += pert.data 226 | global_noise_data.clamp_(-configs.ADV.clip_eps, configs.ADV.clip_eps) 227 | 228 | # Descend on global noise 229 | noise_batch = Variable(global_noise_data[0:input.size(0)], requires_grad=False)#.cuda() 230 | in1 = input + noise_batch 231 | in1.clamp_(0, 1.0) 232 | in1.sub_(mean).div_(std) 233 | output = model(in1) 234 | loss = criterion(output, target) 235 | 236 | # compute gradient and do SGD step 237 | optimizer.zero_grad() 238 | if half: 239 | with amp.scale_loss(loss, optimizer) as scaled_loss: 240 | scaled_loss.backward() 241 | else: 242 | loss.backward() 243 | 244 | optimizer.step() 245 | 246 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 247 | losses.update(loss.item(), input.size(0)) 248 | top1.update(prec1[0], input.size(0)) 249 | top5.update(prec5[0], input.size(0)) 250 | 251 | # measure elapsed time 252 | batch_time.update(time.time() - end) 253 | end = time.time() 254 | 255 | if i % configs.TRAIN.print_freq == 0: 256 | print('Train Epoch: [{0}][{1}/{2}]\t' 257 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 258 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 259 | 'Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t' 260 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 261 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 262 | 'LR {lr:.3f}'.format( 263 | epoch, i, len(train_loader), batch_time=batch_time, 264 | data_time=data_time, top1=top1, 265 | top5=top5,cls_loss=losses, lr=lr)) 266 | sys.stdout.flush() 267 | 268 | if __name__ == '__main__': 269 | main() 270 | --------------------------------------------------------------------------------