├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── attacks │ ├── __init__.py │ ├── apgd.py │ ├── base.py │ ├── deepfool.py │ ├── fgsm.py │ ├── pgd.py │ └── utils.py ├── data │ ├── __init__.py │ ├── cifar10.py │ ├── cifar100.py │ ├── cifar100s.py │ ├── cifar10s.py │ ├── imagenet100.py │ ├── semisup.py │ ├── svhn.py │ ├── tiny_imagenet.py │ └── utils.py ├── metrics.py ├── models │ ├── __init__.py │ ├── in_preact_resnet.py │ ├── preact_resnet.py │ ├── preact_resnetwithswish.py │ ├── resnet.py │ ├── ti_preact_resnet.py │ ├── wideresnet.py │ └── wideresnetwithswish.py ├── setup.py └── utils │ ├── __init__.py │ ├── context.py │ ├── exp.py │ ├── hat.py │ ├── logger.py │ ├── mart.py │ ├── parser.py │ ├── rst.py │ ├── trades.py │ ├── train.py │ └── utils.py ├── eval-aa.py ├── eval-adv.py ├── eval-rb.py ├── gowal21uncovering └── utils │ ├── __init__.py │ ├── hat.py │ ├── trades.py │ └── watrain.py ├── requirements.txt ├── train-wa.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*__pycache__/ 2 | **/*.pyc 3 | **/*.ipynb_checkpoints/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rahul Rade 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Helper-based Adversarial Training 2 | 3 | This repository contains the code for the [ICLR 2022](https://iclr.cc/) paper "[Reducing Excessive Margin to Achieve a Better Accuracy vs. Robustness Trade-off](https://openreview.net/forum?id=Azh9QBQ4tR7)" by Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli. 4 | 5 | A short version of the paper has been accepted for Oral presentation at the [ICML 2021 Workshop on A Blessing in Disguise: The Prospects and Perils of Adversarial Machine Learning](https://advml-workshop.github.io/icml2021/) and can be found at this [link](https://openreview.net/forum?id=BuD2LmNaU3a). 6 | 7 | 8 | ## Setup 9 | 10 | ### Requirements 11 | 12 | Our code has been implemented and tested with `Python 3.8.5` and `PyTorch 1.8.0`. To install the required packages: 13 | ```bash 14 | $ pip install -r requirements.txt 15 | ``` 16 | 17 | ### Repository Structure 18 | 19 | ``` 20 | . 
21 | └── core # Source code for the experiments 22 | ├── attacks # Adversarial attacks 23 | ├── data # Data setup and loading 24 | ├── models # Model architectures 25 | └── utils # Helpers, training and testing functions 26 | └── metrics.py # Evaluation metrics 27 | └── train.py # Training script 28 | └── train-wa.py # Training with model weight averaging 29 | └── eval-aa.py # AutoAttack evaluation 30 | └── eval-adv.py # PGD+ and CW evaluation 31 | └── eval-rb.py # RobustBench evaluation 32 | ``` 33 | 34 | ## Usage 35 | 36 | ### Training 37 | 38 | Run [`train.py`](./train.py) for standard, adversarial, TRADES, MART and HAT training. Example commands for HAT training are provided below: 39 | 40 | First, train a ResNet-18 model on CIFAR-10 with standard training: 41 | ``` 42 | $ python train.py --data-dir \ 43 | --log-dir \ 44 | --desc std-cifar10 \ 45 | --data cifar10 \ 46 | --model resnet18 \ 47 | --num-std-epochs 50 48 | ``` 49 | 50 | Then, run the following command to perform helper-based adversarial training (HAT) on CIFAR-10: 51 | 52 | ``` 53 | $ python train.py --data-dir \ 54 | --log-dir \ 55 | --desc hat-cifar10 \ 56 | --data cifar10 \ 57 | --model resnet18 \ 58 | --num-adv-epochs 50 \ 59 | --helper-model std-cifar10 \ 60 | --beta 2.5 \ 61 | --gamma 0.5 62 | ``` 63 | 64 | 65 | ### Robustness Evaluation 66 | 67 | The trained models can be evaluated by running [`eval-aa.py`](./eval-aa.py) which uses [AutoAttack](https://github.com/fra31/auto-attack) for evaluating the robust accuracy. 
For example: 68 | ``` 69 | $ python eval-aa.py --data-dir \ 70 | --log-dir \ 71 | --desc hat-cifar10 72 | ``` 73 | 74 | For evaluation with PGD+ and CW attacks, use: 75 | ``` 76 | $ python eval-adv.py --wb --data-dir \ 77 | --log-dir \ 78 | --desc hat-cifar10 79 | ``` 80 | 81 | ### Incorporating Improvements from Gowal et al., 2020 & Rebuffi et al., 2021 82 | 83 | HAT can be combined with improvements from the papers "[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593)" (Gowal et al., 2020) and "[Fixing Data Augmentation to Improve Adversarial Robustness](https://arxiv.org/abs/2103.01946)" (Rebuffi et al., 2021) to obtain state-of-the-art performance on multiple datasets. 84 | 85 | 86 | #### Training a Standard Network for Computing Helper Labels 87 | 88 | Train a model with standard training as [mentioned above](#training) *or* alternatively download the appropriate pretrained model from this [link](https://www.dropbox.com/sh/vzli8frhfsxo46q/AAB25dkdH6ZaDxNJzHoQNDX8a?dl=0) and place the contents of the corresponding zip file in the directory ``````. 89 | 90 | #### HAT Training 91 | 92 | Run [`train-wa.py`](./train-wa.py) for training a robust network via HAT. 
For example, to train a WideResNet-28-10 model via HAT on CIFAR-10 with the additional pseudolabeled data provided by [Carmon et al., 2019](https://github.com/yaircarmon/semisup-adv) or the generated datasets provided by [Rebuffi et al., 2021](https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness), use the following command: 93 | 94 | ``` 95 | $ python train-wa.py --data-dir \ 96 | --log-dir \ 97 | --desc \ 98 | --data cifar10s \ 99 | --batch-size 1024 \ 100 | --batch-size-validation 512 \ 101 | --model wrn-28-10-swish \ 102 | --num-adv-epochs 400 \ 103 | --lr 0.4 --tau 0.995 \ 104 | --label-smoothing 0.1 \ 105 | --unsup-fraction 0.7 \ 106 | --aux-data-filename \ 107 | --helper-model \ 108 | --beta 3.5 \ 109 | --gamma 0.5 110 | ``` 111 | 112 | 113 | ## Results 114 | 115 | Below, we provide the results with HAT. In the settings with additional data, we follow the experimental setup used in [Gowal et al., 2020](https://arxiv.org/abs/2010.03593) and [Rebuffi et al., 2021](https://arxiv.org/abs/2103.01946). Whereas we resort to the experimental setup provided in our paper when not using additional data. Our pretrained models are available via [RobustBench](https://robustbench.github.io/). 116 | 117 | #### With extra data from Carmon et al., 2019 along with the improvements by Gowal et al. 2020 118 | 119 | | Dataset | Norm | ε | Model | Clean Acc. | Robust Acc. | 120 | |---|:---:|:---:|:---:|:---:|:---:| 121 | | CIFAR-10 | ℓ | 8/255 | PreActResNet-18 | 89.02 | 57.67 | 122 | | CIFAR-10 | ℓ | 8/255 | WideResNet-28-10 | 91.30 | 62.50 | 123 | | CIFAR-10 | ℓ | 8/255 | WideResNet-34-10 | 91.47 | 62.83 | 124 | 125 | Our models achieve around ~0.3-0.5% lower robustness than that reported in [Gowal et al., 2020](https://arxiv.org/abs/2010.03593) since they use a custom pseudolabeled dataset which is not publicly available (See Section 4.3.1 [here](https://arxiv.org/abs/2010.03593)). 
126 | 127 | #### With synthetic DDPM generated data from Rebuffi et al., 2021 128 | 129 | | Dataset | Norm | ε | Model | CutMix | Clean Acc. | Robust Acc. | 130 | |---|:---:|:---:|:---:|:---:|:---:|:---:| 131 | | CIFAR-10 | ℓ | 8/255 | PreActResNet-18 | ✗ | 86.86 | 57.09 | 132 | | CIFAR-10 | ℓ | 8/255 | WideResNet-28-10 | ✗ | 88.16 | 60.97 | 133 | | CIFAR-10 | ℓ2 | 128/255 | PreActResNet-18 | ✗ | 90.57 | 76.07 | 134 | | CIFAR-100 | ℓ | 8/255 | PreActResNet-18 | ✗ | 61.50 | 28.88 | 135 | | CIFAR-100 | ℓ | 8/255 | WideResNet-34-10 | ✗ | 62.21 | 31.16 | 136 | 137 | #### Without additional data 138 | 139 | | Dataset | Norm | ε | Model | Clean Acc. | Robust Acc. | 140 | |---|:---:|:---:|:---:|:---:|:---:| 141 | | CIFAR-10 | ℓ | 8/255 | ResNet-18 | 84.90 | 49.08 | 142 | | CIFAR-10 | ℓ | 12/255 | ResNet-18 | 79.30 | 33.47 | 143 | | SVHN | ℓ | 8/255 | ResNet-18 | 93.08 | 52.83 | 144 | | TI-200 | ℓ | 8/255 | PreActResNet-18 | 52.60 | 18.14 | 145 | 146 | 147 | ## Citing this work 148 | 149 | ``` 150 | @inproceedings{ 151 | rade2022reducing, 152 | title={Reducing Excessive Margin to Achieve a Better Accuracy vs. 
Robustness Trade-off}, 153 | author={Rahul Rade and Seyed-Mohsen Moosavi-Dezfooli}, 154 | booktitle={International Conference on Learning Representations}, 155 | year={2022}, 156 | url={https://openreview.net/forum?id=Azh9QBQ4tR7} 157 | } 158 | ``` 159 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imrahulr/hat/3177a4f827480e97ac8ea4c5c6acab197cf97e66/core/__init__.py -------------------------------------------------------------------------------- /core/attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Attack 2 | 3 | from .apgd import LinfAPGDAttack 4 | from .apgd import L2APGDAttack 5 | 6 | from .fgsm import FGMAttack 7 | from .fgsm import FGSMAttack 8 | from .fgsm import L2FastGradientAttack 9 | from .fgsm import LinfFastGradientAttack 10 | 11 | from .pgd import PGDAttack 12 | from .pgd import L2PGDAttack 13 | from .pgd import LinfPGDAttack 14 | 15 | from .deepfool import DeepFoolAttack 16 | from .deepfool import LinfDeepFoolAttack 17 | from .deepfool import L2DeepFoolAttack 18 | 19 | from .utils import CWLoss 20 | 21 | 22 | ATTACKS = ['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd'] 23 | 24 | 25 | def create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform', 26 | clip_min=0., clip_max=1.): 27 | """ 28 | Initialize adversary. 29 | Arguments: 30 | model (nn.Module): forward pass function. 31 | criterion (nn.Module): loss function. 32 | attack_type (str): name of the attack. 33 | attack_eps (float): attack radius. 34 | attack_iter (int): number of attack iterations. 35 | attack_step (float): step size for the attack. 36 | rand_init_type (str): random initialization type for PGD (default: uniform). 
37 | clip_min (float): mininum value per input dimension. 38 | clip_max (float): maximum value per input dimension. 39 | Returns: 40 | Attack 41 | """ 42 | 43 | if attack_type == 'fgsm': 44 | attack = FGSMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 45 | elif attack_type == 'fgm': 46 | attack = FGMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 47 | elif attack_type == 'linf-pgd': 48 | attack = LinfPGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 49 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max) 50 | elif attack_type == 'l2-pgd': 51 | attack = L2PGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 52 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max) 53 | elif attack_type == 'linf-df': 54 | attack = LinfDeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 55 | clip_max=clip_max) 56 | elif attack_type == 'l2-df': 57 | attack = L2DeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 58 | clip_max=clip_max) 59 | elif attack_type == 'linf-apgd': 60 | attack = LinfAPGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 61 | elif attack_type == 'l2-apgd': 62 | attack = L2APGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 63 | else: 64 | raise NotImplementedError('{} is not yet implemented!'.format(attack_type)) 65 | return attack 66 | -------------------------------------------------------------------------------- /core/attacks/apgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from autoattack.autopgd_base import APGDAttack 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class APGD(): 11 | """ 12 | APGD attack (from AutoAttack) 
(Croce et al, 2020). 13 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 14 | Arguments: 15 | predict (nn.Module): forward pass function. 16 | loss_fn (str): loss function - ce or dlr. 17 | n_restarts (int): number of random restarts. 18 | eps (float): maximum distortion. 19 | nb_iter (int): number of iterations. 20 | ord (int): (optional) the order of maximum distortion (inf or 2). 21 | """ 22 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, ord=np.inf, seed=1): 23 | assert loss_fn in ['ce', 'dlr'], 'Only loss_fn=ce or loss_fn=dlr are supported!' 24 | assert ord in [2, np.inf], 'Only ord=inf or ord=2 are supported!' 25 | 26 | norm = 'Linf' if ord == np.inf else 'L2' 27 | self.apgd = APGDAttack(predict, n_restarts=n_restarts, n_iter=nb_iter, verbose=False, eps=eps, norm=norm, 28 | eot_iter=1, rho=.75, seed=seed, device=device) 29 | self.apgd.loss = loss_fn 30 | 31 | def perturb(self, x, y): 32 | x_adv = self.apgd.perturb(x, y)[1] 33 | r_adv = x_adv - x 34 | return x_adv, r_adv 35 | 36 | 37 | class LinfAPGDAttack(APGD): 38 | """ 39 | APGD attack (from AutoAttack) with order=Linf. 40 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 41 | Arguments: 42 | predict (nn.Module): forward pass function. 43 | loss_fn (str): loss function - ce or dlr. 44 | n_restarts (int): number of random restarts. 45 | eps (float): maximum distortion. 46 | nb_iter (int): number of iterations. 47 | """ 48 | 49 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 50 | ord = np.inf 51 | super(L2APGDAttack, self).__init__( 52 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) 53 | 54 | 55 | class L2APGDAttack(APGD): 56 | """ 57 | APGD attack (from AutoAttack) with order=L2. 
58 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 59 | Arguments: 60 | predict (nn.Module): forward pass function. 61 | loss_fn (str): loss function - ce or dlr. 62 | n_restarts (int): number of random restarts. 63 | eps (float): maximum distortion. 64 | nb_iter (int): number of iterations. 65 | """ 66 | 67 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 68 | ord = 2 69 | super(L2APGDAttack, self).__init__( 70 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) 71 | -------------------------------------------------------------------------------- /core/attacks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import replicate_input 5 | 6 | 7 | class Attack(object): 8 | """ 9 | Abstract base class for all attack classes. 10 | Arguments: 11 | predict (nn.Module): forward pass function. 12 | loss_fn (nn.Module): loss function. 13 | clip_min (float): mininum value per input dimension. 14 | clip_max (float): maximum value per input dimension. 15 | """ 16 | 17 | def __init__(self, predict, loss_fn, clip_min, clip_max): 18 | self.predict = predict 19 | self.loss_fn = loss_fn 20 | self.clip_min = clip_min 21 | self.clip_max = clip_max 22 | 23 | def perturb(self, x, **kwargs): 24 | """ 25 | Virtual method for generating the adversarial examples. 26 | Arguments: 27 | x (torch.Tensor): the model's input tensor. 28 | **kwargs: optional parameters used by child classes. 29 | Returns: 30 | adversarial examples. 31 | """ 32 | error = "Sub-classes must implement perturb." 33 | raise NotImplementedError(error) 34 | 35 | def __call__(self, *args, **kwargs): 36 | return self.perturb(*args, **kwargs) 37 | 38 | 39 | class LabelMixin(object): 40 | def _get_predicted_label(self, x): 41 | """ 42 | Compute predicted labels given x. 
Used to prevent label leaking during adversarial training. 43 | Arguments: 44 | x (torch.Tensor): the model's input tensor. 45 | Returns: 46 | torch.Tensor containing predicted labels. 47 | """ 48 | with torch.no_grad(): 49 | outputs = self.predict(x) 50 | _, y = torch.max(outputs, dim=1) 51 | return y 52 | 53 | def _verify_and_process_inputs(self, x, y): 54 | if self.targeted: 55 | assert y is not None 56 | 57 | if not self.targeted: 58 | if y is None: 59 | y = self._get_predicted_label(x) 60 | 61 | x = replicate_input(x) 62 | y = replicate_input(y) 63 | return x, y -------------------------------------------------------------------------------- /core/attacks/deepfool.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | from .base import Attack, LabelMixin 8 | 9 | from .utils import batch_multiply 10 | from .utils import clamp 11 | from .utils import is_float_or_torch_tensor 12 | 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def perturb_deepfool(xvar, yvar, predict, nb_iter=50, overshoot=0.02, ord=np.inf, clip_min=0.0, clip_max=1.0, 18 | search_iter=0, device=None): 19 | """ 20 | Compute DeepFool perturbations (Moosavi-Dezfooli et al, 2016). 21 | Arguments: 22 | xvar (torch.Tensor): input images. 23 | yvar (torch.Tensor): predictions. 24 | predict (nn.Module): forward pass function. 25 | nb_iter (int): number of iterations. 26 | overshoot (float): how much to overshoot the boundary. 27 | ord (int): (optional) the order of maximum distortion (inf or 2). 28 | clip_min (float): mininum value per input dimension. 29 | clip_max (float): maximum value per input dimension. 30 | search_iter (int): no of search iterations. 31 | device (torch.device): device to work on. 
32 | Returns: 33 | torch.Tensor containing the perturbed input, 34 | torch.Tensor containing the perturbation 35 | """ 36 | 37 | x_orig = xvar 38 | x = torch.empty_like(xvar).copy_(xvar) 39 | x.requires_grad_(True) 40 | 41 | batch_i = torch.arange(x.shape[0]) 42 | r_tot = torch.zeros_like(x.data) 43 | for i in range(nb_iter): 44 | if x.grad is not None: 45 | x.grad.zero_() 46 | 47 | logits = predict(x) 48 | df_inds = np.argsort(logits.detach().cpu().numpy(), axis=-1) 49 | df_inds_other, df_inds_orig = df_inds[:, :-1], df_inds[:, -1] 50 | df_inds_orig = torch.from_numpy(df_inds_orig) 51 | df_inds_orig = df_inds_orig.to(device) 52 | not_done_inds = df_inds_orig == yvar 53 | if not_done_inds.sum() == 0: 54 | break 55 | 56 | logits[batch_i, df_inds_orig].sum().backward(retain_graph=True) 57 | grad_orig = x.grad.data.clone().detach() 58 | pert = x.data.new_ones(x.shape[0]) * np.inf 59 | w = torch.zeros_like(x.data) 60 | 61 | for inds in df_inds_other.T: 62 | x.grad.zero_() 63 | logits[batch_i, inds].sum().backward(retain_graph=True) 64 | grad_cur = x.grad.data.clone().detach() 65 | with torch.no_grad(): 66 | w_k = grad_cur - grad_orig 67 | f_k = logits[batch_i, inds] - logits[batch_i, df_inds_orig] 68 | if ord == 2: 69 | pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 2, -1) 70 | elif ord == np.inf: 71 | pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 1, -1) 72 | else: 73 | raise NotImplementedError("Only ord=inf and ord=2 have been implemented") 74 | swi = pert_k < pert 75 | if swi.sum() > 0: 76 | pert[swi] = pert_k[swi] 77 | w[swi] = w_k[swi] 78 | 79 | if ord == 2: 80 | r_i = (pert + 1e-6)[:, None, None, None] * w / torch.norm(w.flatten(1), 2, -1)[:, None, None, None] 81 | elif ord == np.inf: 82 | r_i = (pert + 1e-6)[:, None, None, None] * w.sign() 83 | 84 | r_tot += r_i * not_done_inds[:, None, None, None].float() 85 | x.data = x_orig + (1. 
+ overshoot) * r_tot 86 | x.data = torch.clamp(x.data, clip_min, clip_max) 87 | 88 | x = x.detach() 89 | if search_iter > 0: 90 | dx = x - x_orig 91 | dx_l_low, dx_l_high = torch.zeros_like(dx), torch.ones_like(dx) 92 | for i in range(search_iter): 93 | dx_l = (dx_l_low + dx_l_high) / 2. 94 | dx_x = x_orig + dx_l * dx 95 | dx_y = predict(dx_x).argmax(-1) 96 | label_stay = dx_y == yvar 97 | label_change = dx_y != yvar 98 | dx_l_low[label_stay] = dx_l[label_stay] 99 | dx_l_high[label_change] = dx_l[label_change] 100 | x = dx_x 101 | 102 | # x.data = torch.clamp(x.data, clip_min, clip_max) 103 | r_tot = x.data - x_orig 104 | return x, r_tot 105 | 106 | 107 | 108 | class DeepFoolAttack(Attack, LabelMixin): 109 | """ 110 | DeepFool attack. 111 | [Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard, 112 | "DeepFool: a simple and accurate method to fool deep neural networks"] 113 | Arguments: 114 | predict (nn.Module): forward pass function. 115 | overshoot (float): how much to overshoot the boundary. 116 | nb_iter (int): number of iterations. 117 | search_iter (int): no of search iterations. 118 | clip_min (float): mininum value per input dimension. 119 | clip_max (float): maximum value per input dimension. 120 | ord (int): (optional) the order of maximum distortion (inf or 2). 121 | """ 122 | 123 | def __init__( 124 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1., ord=np.inf): 125 | super(DeepFoolAttack, self).__init__(predict, None, clip_min, clip_max) 126 | self.overshoot = overshoot 127 | self.nb_iter = nb_iter 128 | self.search_iter = search_iter 129 | self.targeted = False 130 | 131 | self.ord = ord 132 | assert is_float_or_torch_tensor(self.overshoot) 133 | 134 | def perturb(self, x, y=None): 135 | """ 136 | Given examples x, returns their adversarial counterparts. 137 | Arguments: 138 | x (torch.Tensor): input tensor. 139 | y (torch.Tensor): label tensor. 
140 | - if None and self.targeted=False, compute y as predicted labels. 141 | Returns: 142 | torch.Tensor containing perturbed inputs, 143 | torch.Tensor containing the perturbation 144 | """ 145 | 146 | x, y = self._verify_and_process_inputs(x, None) 147 | x_adv, r_adv = perturb_deepfool(x, y, self.predict, self.nb_iter, self.overshoot, ord=self.ord, 148 | clip_min=self.clip_min, clip_max=self.clip_max, search_iter=self.search_iter, 149 | device=device) 150 | return x_adv, r_adv 151 | 152 | 153 | class LinfDeepFoolAttack(DeepFoolAttack): 154 | """ 155 | DeepFool Attack with order=Linf. 156 | Arguments: 157 | Arguments: 158 | predict (nn.Module): forward pass function. 159 | overshoot (float): how much to overshoot the boundary. 160 | nb_iter (int): number of iterations. 161 | search_iter (int): no of search iterations. 162 | clip_min (float): mininum value per input dimension. 163 | clip_max (float): maximum value per input dimension. 164 | """ 165 | 166 | def __init__( 167 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.): 168 | 169 | ord = np.inf 170 | super(LinfDeepFoolAttack, self).__init__( 171 | predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min, 172 | clip_max=clip_max, ord=ord) 173 | 174 | 175 | 176 | class L2DeepFoolAttack(DeepFoolAttack): 177 | """ 178 | DeepFool Attack with order=L2. 179 | Arguments: 180 | predict (nn.Module): forward pass function. 181 | overshoot (float): how much to overshoot the boundary. 182 | nb_iter (int): number of iterations. 183 | search_iter (int): no of search iterations. 184 | clip_min (float): mininum value per input dimension. 185 | clip_max (float): maximum value per input dimension. 
186 | """ 187 | 188 | def __init__( 189 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.): 190 | 191 | ord = 2 192 | super(L2DeepFoolAttack, self).__init__( 193 | predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min, 194 | clip_max=clip_max, ord=ord) 195 | -------------------------------------------------------------------------------- /core/attacks/fgsm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import Attack, LabelMixin 5 | from .utils import batch_multiply 6 | from .utils import clamp 7 | 8 | 9 | class FGSMAttack(Attack, LabelMixin): 10 | """ 11 | One step fast gradient sign method (Goodfellow et al, 2014). 12 | Arguments: 13 | predict (nn.Module): forward pass function. 14 | loss_fn (nn.Module): loss function. 15 | eps (float): attack step size. 16 | clip_min (float): mininum value per input dimension. 17 | clip_max (float): maximum value per input dimension. 18 | targeted (bool): indicate if this is a targeted attack. 19 | """ 20 | 21 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 22 | super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max) 23 | 24 | self.eps = eps 25 | self.targeted = targeted 26 | if self.loss_fn is None: 27 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 28 | 29 | def perturb(self, x, y=None): 30 | """ 31 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 32 | Arguments: 33 | x (torch.Tensor): input tensor. 34 | y (torch.Tensor): label tensor. 35 | - if None and self.targeted=False, compute y as predicted labels. 36 | - if self.targeted=True, then y must be the targeted labels. 37 | Returns: 38 | torch.Tensor containing perturbed inputs. 39 | torch.Tensor containing the perturbation. 
40 | """ 41 | 42 | x, y = self._verify_and_process_inputs(x, y) 43 | 44 | xadv = x.requires_grad_() 45 | outputs = self.predict(xadv) 46 | 47 | loss = self.loss_fn(outputs, y) 48 | if self.targeted: 49 | loss = -loss 50 | loss.backward() 51 | grad_sign = xadv.grad.detach().sign() 52 | 53 | xadv = xadv + batch_multiply(self.eps, grad_sign) 54 | xadv = clamp(xadv, self.clip_min, self.clip_max) 55 | radv = xadv - x 56 | return xadv.detach(), radv.detach() 57 | 58 | 59 | LinfFastGradientAttack = FGSMAttack 60 | 61 | 62 | class FGMAttack(Attack, LabelMixin): 63 | """ 64 | One step fast gradient method. Perturbs the input with gradient (not gradient sign) of the loss wrt the input. 65 | Arguments: 66 | predict (nn.Module): forward pass function. 67 | loss_fn (nn.Module): loss function. 68 | eps (float): attack step size. 69 | clip_min (float): mininum value per input dimension. 70 | clip_max (float): maximum value per input dimension. 71 | targeted (bool): indicate if this is a targeted attack. 72 | """ 73 | 74 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 75 | super(FGMAttack, self).__init__( 76 | predict, loss_fn, clip_min, clip_max) 77 | 78 | self.eps = eps 79 | self.targeted = targeted 80 | if self.loss_fn is None: 81 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 82 | 83 | def perturb(self, x, y=None): 84 | """ 85 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 86 | Arguments: 87 | x (torch.Tensor): input tensor. 88 | y (torch.Tensor): label tensor. 89 | - if None and self.targeted=False, compute y as predicted labels. 90 | - if self.targeted=True, then y must be the targeted labels. 91 | Returns: 92 | torch.Tensor containing perturbed inputs. 93 | torch.Tensor containing the perturbation. 
94 | """ 95 | 96 | x, y = self._verify_and_process_inputs(x, y) 97 | xadv = x.requires_grad_() 98 | outputs = self.predict(xadv) 99 | 100 | loss = self.loss_fn(outputs, y) 101 | if self.targeted: 102 | loss = -loss 103 | loss.backward() 104 | grad = normalize_by_pnorm(xadv.grad) 105 | xadv = xadv + batch_multiply(self.eps, grad) 106 | xadv = clamp(xadv, self.clip_min, self.clip_max) 107 | radv = xadv - x 108 | 109 | return xadv.detach(), radv.detach() 110 | 111 | 112 | L2FastGradientAttack = FGMAttack -------------------------------------------------------------------------------- /core/attacks/pgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .base import Attack, LabelMixin 6 | 7 | from .utils import batch_clamp 8 | from .utils import batch_multiply 9 | from .utils import clamp 10 | from .utils import clamp_by_pnorm 11 | from .utils import is_float_or_torch_tensor 12 | from .utils import normalize_by_pnorm 13 | from .utils import rand_init_delta 14 | from .utils import replicate_input 15 | 16 | 17 | def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf, 18 | clip_min=0.0, clip_max=1.0): 19 | """ 20 | Iteratively maximize the loss over the input. It is a shared method for iterative attacks. 21 | Arguments: 22 | xvar (torch.Tensor): input data. 23 | yvar (torch.Tensor): input labels. 24 | predict (nn.Module): forward pass function. 25 | nb_iter (int): number of iterations. 26 | eps (float): maximum distortion. 27 | eps_iter (float): attack step size. 28 | loss_fn (nn.Module): loss function. 29 | delta_init (torch.Tensor): (optional) tensor contains the random initialization. 30 | minimize (bool): (optional) whether to minimize or maximize the loss. 31 | ord (int): (optional) the order of maximum distortion (inf or 2). 32 | clip_min (float): mininum value per input dimension. 
33 | clip_max (float): maximum value per input dimension. 34 | Returns: 35 | torch.Tensor containing the perturbed input, 36 | torch.Tensor containing the perturbation 37 | """ 38 | if delta_init is not None: 39 | delta = delta_init 40 | else: 41 | delta = torch.zeros_like(xvar) 42 | 43 | delta.requires_grad_() 44 | for ii in range(nb_iter): 45 | outputs = predict(xvar + delta) 46 | loss = loss_fn(outputs, yvar) 47 | if minimize: 48 | loss = -loss 49 | 50 | loss.backward() 51 | if ord == np.inf: 52 | grad_sign = delta.grad.data.sign() 53 | delta.data = delta.data + batch_multiply(eps_iter, grad_sign) 54 | delta.data = batch_clamp(eps, delta.data) 55 | delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data 56 | elif ord == 2: 57 | grad = delta.grad.data 58 | grad = normalize_by_pnorm(grad) 59 | delta.data = delta.data + batch_multiply(eps_iter, grad) 60 | delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data 61 | if eps is not None: 62 | delta.data = clamp_by_pnorm(delta.data, ord, eps) 63 | else: 64 | error = "Only ord=inf and ord=2 have been implemented" 65 | raise NotImplementedError(error) 66 | delta.grad.data.zero_() 67 | 68 | x_adv = clamp(xvar + delta, clip_min, clip_max) 69 | r_adv = x_adv - xvar 70 | return x_adv, r_adv 71 | 72 | 73 | class PGDAttack(Attack, LabelMixin): 74 | """ 75 | The projected gradient descent attack (Madry et al, 2017). 76 | The attack performs nb_iter steps of size eps_iter, while always staying within eps from the initial point. 77 | Arguments: 78 | predict (nn.Module): forward pass function. 79 | loss_fn (nn.Module): loss function. 80 | eps (float): maximum distortion. 81 | nb_iter (int): number of iterations. 82 | eps_iter (float): attack step size. 83 | rand_init (bool): (optional) random initialization. 84 | clip_min (float): mininum value per input dimension. 85 | clip_max (float): maximum value per input dimension. 86 | ord (int): (optional) the order of maximum distortion (inf or 2). 
87 | targeted (bool): if the attack is targeted. 88 | rand_init_type (str): (optional) random initialization type. 89 | """ 90 | 91 | def __init__( 92 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 93 | ord=np.inf, targeted=False, rand_init_type='uniform'): 94 | super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max) 95 | self.eps = eps 96 | self.nb_iter = nb_iter 97 | self.eps_iter = eps_iter 98 | self.rand_init = rand_init 99 | self.rand_init_type = rand_init_type 100 | self.ord = ord 101 | self.targeted = targeted 102 | if self.loss_fn is None: 103 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 104 | assert is_float_or_torch_tensor(self.eps_iter) 105 | assert is_float_or_torch_tensor(self.eps) 106 | 107 | def perturb(self, x, y=None): 108 | """ 109 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 110 | Arguments: 111 | x (torch.Tensor): input tensor. 112 | y (torch.Tensor): label tensor. 113 | - if None and self.targeted=False, compute y as predicted 114 | labels. 115 | - if self.targeted=True, then y must be the targeted labels. 
116 | Returns: 117 | torch.Tensor containing perturbed inputs, 118 | torch.Tensor containing the perturbation 119 | """ 120 | x, y = self._verify_and_process_inputs(x, y) 121 | 122 | delta = torch.zeros_like(x) 123 | delta = nn.Parameter(delta) 124 | if self.rand_init: 125 | if self.rand_init_type == 'uniform': 126 | rand_init_delta( 127 | delta, x, self.ord, self.eps, self.clip_min, self.clip_max) 128 | delta.data = clamp( 129 | x + delta.data, min=self.clip_min, max=self.clip_max) - x 130 | elif self.rand_init_type == 'normal': 131 | delta.data = 0.001 * torch.randn_like(x) # initialize as in TRADES 132 | else: 133 | raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.') 134 | 135 | x_adv, r_adv = perturb_iterative( 136 | x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn, 137 | minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta 138 | ) 139 | 140 | return x_adv.data, r_adv.data 141 | 142 | 143 | class LinfPGDAttack(PGDAttack): 144 | """ 145 | PGD Attack with order=Linf 146 | Arguments: 147 | predict (nn.Module): forward pass function. 148 | loss_fn (nn.Module): loss function. 149 | eps (float): maximum distortion. 150 | nb_iter (int): number of iterations. 151 | eps_iter (float): attack step size. 152 | rand_init (bool): (optional) random initialization. 153 | clip_min (float): mininum value per input dimension. 154 | clip_max (float): maximum value per input dimension. 155 | targeted (bool): if the attack is targeted. 156 | rand_init_type (str): (optional) random initialization type. 
157 | """ 158 | 159 | def __init__( 160 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 161 | targeted=False, rand_init_type='uniform'): 162 | ord = np.inf 163 | super(LinfPGDAttack, self).__init__( 164 | predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 165 | clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type) 166 | 167 | 168 | class L2PGDAttack(PGDAttack): 169 | """ 170 | PGD Attack with order=L2 171 | Arguments: 172 | predict (nn.Module): forward pass function. 173 | loss_fn (nn.Module): loss function. 174 | eps (float): maximum distortion. 175 | nb_iter (int): number of iterations. 176 | eps_iter (float): attack step size. 177 | rand_init (bool): (optional) random initialization. 178 | clip_min (float): mininum value per input dimension. 179 | clip_max (float): maximum value per input dimension. 180 | targeted (bool): if the attack is targeted. 181 | rand_init_type (str): (optional) random initialization type. 
182 | """ 183 | 184 | def __init__( 185 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 186 | targeted=False, rand_init_type='uniform'): 187 | ord = 2 188 | super(L2PGDAttack, self).__init__( 189 | predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 190 | clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type) 191 | -------------------------------------------------------------------------------- /core/attacks/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from torch.distributions import laplace 9 | from torch.distributions import uniform 10 | from torch.nn.modules.loss import _Loss 11 | 12 | 13 | def replicate_input(x): 14 | """ 15 | Clone the input tensor x. 16 | """ 17 | return x.detach().clone() 18 | 19 | 20 | def replicate_input_withgrad(x): 21 | """ 22 | Clone the input tensor x and set requires_grad=True. 23 | """ 24 | return x.detach().clone().requires_grad_() 25 | 26 | 27 | def calc_l2distsq(x, y): 28 | """ 29 | Calculate L2 distance between tensors x and y. 30 | """ 31 | d = (x - y)**2 32 | return d.view(d.shape[0], -1).sum(dim=1) 33 | 34 | 35 | def clamp(input, min=None, max=None): 36 | """ 37 | Clamp a tensor by its minimun and maximun values. 
38 | """ 39 | ndim = input.ndimension() 40 | if min is None: 41 | pass 42 | elif isinstance(min, (float, int)): 43 | input = torch.clamp(input, min=min) 44 | elif isinstance(min, torch.Tensor): 45 | if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]: 46 | input = torch.max(input, min.view(1, *min.shape)) 47 | else: 48 | assert min.shape == input.shape 49 | input = torch.max(input, min) 50 | else: 51 | raise ValueError("min can only be None | float | torch.Tensor") 52 | 53 | if max is None: 54 | pass 55 | elif isinstance(max, (float, int)): 56 | input = torch.clamp(input, max=max) 57 | elif isinstance(max, torch.Tensor): 58 | if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]: 59 | input = torch.min(input, max.view(1, *max.shape)) 60 | else: 61 | assert max.shape == input.shape 62 | input = torch.min(input, max) 63 | else: 64 | raise ValueError("max can only be None | float | torch.Tensor") 65 | return input 66 | 67 | 68 | def _batch_multiply_tensor_by_vector(vector, batch_tensor): 69 | """Equivalent to the following. 70 | for ii in range(len(vector)): 71 | batch_tensor.data[ii] *= vector[ii] 72 | return batch_tensor 73 | """ 74 | return ( 75 | batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous() 76 | 77 | 78 | def _batch_clamp_tensor_by_vector(vector, batch_tensor): 79 | """Equivalent to the following. 80 | for ii in range(len(vector)): 81 | batch_tensor[ii] = clamp( 82 | batch_tensor[ii], -vector[ii], vector[ii]) 83 | """ 84 | return torch.min( 85 | torch.max(batch_tensor.transpose(0, -1), -vector), vector 86 | ).transpose(0, -1).contiguous() 87 | 88 | 89 | def batch_multiply(float_or_vector, tensor): 90 | """ 91 | Multpliy a batch of tensors with a float or vector. 
92 | """ 93 | if isinstance(float_or_vector, torch.Tensor): 94 | assert len(float_or_vector) == len(tensor) 95 | tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor) 96 | elif isinstance(float_or_vector, float): 97 | tensor *= float_or_vector 98 | else: 99 | raise TypeError("Value has to be float or torch.Tensor") 100 | return tensor 101 | 102 | 103 | def batch_clamp(float_or_vector, tensor): 104 | """ 105 | Clamp a batch of tensors. 106 | """ 107 | if isinstance(float_or_vector, torch.Tensor): 108 | assert len(float_or_vector) == len(tensor) 109 | tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor) 110 | return tensor 111 | elif isinstance(float_or_vector, float): 112 | tensor = clamp(tensor, -float_or_vector, float_or_vector) 113 | else: 114 | raise TypeError("Value has to be float or torch.Tensor") 115 | return tensor 116 | 117 | 118 | def _get_norm_batch(x, p): 119 | """ 120 | Returns the Lp norm of batch x. 121 | """ 122 | batch_size = x.size(0) 123 | return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p) 124 | 125 | 126 | def _thresh_by_magnitude(theta, x): 127 | """ 128 | Threshold by magnitude. 129 | """ 130 | return torch.relu(torch.abs(x) - theta) * x.sign() 131 | 132 | 133 | def clamp_by_pnorm(x, p, r): 134 | """ 135 | Clamp tensor by its norm. 136 | """ 137 | assert isinstance(p, float) or isinstance(p, int) 138 | norm = _get_norm_batch(x, p) 139 | if isinstance(r, torch.Tensor): 140 | assert norm.size() == r.size() 141 | else: 142 | assert isinstance(r, float) 143 | factor = torch.min(r / norm, torch.ones_like(norm)) 144 | return batch_multiply(factor, x) 145 | 146 | 147 | def is_float_or_torch_tensor(x): 148 | """ 149 | Return whether input x is a float or a torch.Tensor. 150 | """ 151 | return isinstance(x, torch.Tensor) or isinstance(x, float) 152 | 153 | 154 | def normalize_by_pnorm(x, p=2, small_constant=1e-6): 155 | """ 156 | Normalize gradients for gradient (not gradient sign) attacks. 
157 | Arguments: 158 | x (torch.Tensor): tensor containing the gradients on the input. 159 | p (int): (optional) order of the norm for the normalization (1 or 2). 160 | small_constant (float): (optional) to avoid dividing by zero. 161 | Returns: 162 | normalized gradients. 163 | """ 164 | assert isinstance(p, float) or isinstance(p, int) 165 | norm = _get_norm_batch(x, p) 166 | norm = torch.max(norm, torch.ones_like(norm) * small_constant) 167 | return batch_multiply(1. / norm, x) 168 | 169 | 170 | def rand_init_delta(delta, x, ord, eps, clip_min, clip_max): 171 | """ 172 | Randomly initialize the perturbation. 173 | """ 174 | if isinstance(eps, torch.Tensor): 175 | assert len(eps) == len(delta) 176 | 177 | if ord == np.inf: 178 | delta.data.uniform_(-1, 1) 179 | delta.data = batch_multiply(eps, delta.data) 180 | elif ord == 2: 181 | delta.data.uniform_(clip_min, clip_max) 182 | delta.data = delta.data - x 183 | delta.data = clamp_by_pnorm(delta.data, ord, eps) 184 | elif ord == 1: 185 | ini = laplace.Laplace( 186 | loc=delta.new_tensor(0), scale=delta.new_tensor(1)) 187 | delta.data = ini.sample(delta.data.shape) 188 | delta.data = normalize_by_pnorm(delta.data, p=1) 189 | ray = uniform.Uniform(0, eps).sample() 190 | delta.data *= ray 191 | delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data 192 | else: 193 | error = "Only ord = inf, ord = 1 and ord = 2 have been implemented" 194 | raise NotImplementedError(error) 195 | 196 | delta.data = clamp( 197 | x + delta.data, min=clip_min, max=clip_max) - x 198 | return delta.data 199 | 200 | 201 | def CWLoss(output, target, confidence=0): 202 | """ 203 | CW loss (Marging loss). 204 | """ 205 | num_classes = output.shape[-1] 206 | target = target.data 207 | target_onehot = torch.zeros(target.size() + (num_classes,)) 208 | target_onehot = target_onehot.cuda() 209 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 
210 | target_var = Variable(target_onehot, requires_grad=False) 211 | real = (target_var * output).sum(1) 212 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 213 | loss = - torch.clamp(real - other + confidence, min=0.) 214 | loss = torch.sum(loss) 215 | return loss 216 | -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .cifar10 import load_cifar10 5 | from .cifar100 import load_cifar100 6 | from .svhn import load_svhn 7 | from .cifar10s import load_cifar10s 8 | from .tiny_imagenet import load_tinyimagenet 9 | from .cifar100s import load_cifar100s 10 | from .imagenet100 import load_imagenet100 11 | 12 | from .semisup import get_semisup_dataloaders 13 | 14 | 15 | SEMISUP_DATASETS = ['cifar10s', 'cifar100s'] 16 | DATASETS = ['cifar10', 'svhn', 'cifar100', 'tiny-imagenet', 'imagenet100'] + SEMISUP_DATASETS 17 | 18 | _LOAD_DATASET_FN = { 19 | 'cifar10': load_cifar10, 20 | 'cifar100': load_cifar100, 21 | 'svhn': load_svhn, 22 | 'tiny-imagenet': load_tinyimagenet, 23 | 'cifar10s': load_cifar10s, 24 | 'cifar100s': load_cifar100s, 25 | 'imagenet100': load_imagenet100, 26 | } 27 | 28 | 29 | def get_data_info(data_dir): 30 | """ 31 | Returns dataset information. 32 | Arguments: 33 | data_dir (str): path to data directory. 
34 | """ 35 | dataset = os.path.basename(os.path.normpath(data_dir)) 36 | if 'cifar100' in data_dir: 37 | from .cifar100 import DATA_DESC 38 | elif 'cifar10' in data_dir: 39 | from .cifar10 import DATA_DESC 40 | elif 'svhn' in data_dir: 41 | from .svhn import DATA_DESC 42 | elif 'tiny-imagenet' in data_dir: 43 | from .tiny_imagenet import DATA_DESC 44 | elif 'imagenet100' in data_dir: 45 | from .imagenet100 import DATA_DESC 46 | else: 47 | raise ValueError(f'Only data in {DATASETS} are supported!') 48 | DATA_DESC['data'] = dataset 49 | return DATA_DESC 50 | 51 | 52 | def load_data(data_dir, batch_size=256, batch_size_test=256, num_workers=8, use_augmentation=False, shuffle_train=True, 53 | aux_data_filename=None, unsup_fraction=None, validation=False): 54 | """ 55 | Returns train, test datasets and dataloaders. 56 | Arguments: 57 | data_dir (str): path to data directory. 58 | batch_size (int): batch size for training. 59 | batch_size_test (int): batch size for validation. 60 | num_workers (int): number of workers for loading the data. 61 | use_augmentation (bool): whether to use augmentations for training set. 62 | shuffle_train (bool): whether to shuffle training set. 63 | aux_data_filename (str): path to unlabelled data. 64 | unsup_fraction (float): fraction of unlabelled data per batch. 65 | validation (bool): if True, also returns a validation dataloader for unspervised cifar10 (as in Gowal et al, 2020). 
66 | """ 67 | dataset = os.path.basename(os.path.normpath(data_dir)) 68 | load_dataset_fn = _LOAD_DATASET_FN[dataset] 69 | 70 | if dataset in SEMISUP_DATASETS: 71 | train_dataset, test_dataset, val_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation, 72 | aux_data_filename=aux_data_filename, validation=validation) 73 | else: 74 | train_dataset, test_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation) 75 | if validation: 76 | num_train_samples = len(train_dataset) 77 | val_dataset = torch.utils.data.Subset(train_dataset, torch.arange(0, 1024)) 78 | train_dataset = torch.utils.data.Subset(train_dataset, torch.arange(1024, num_train_samples)) 79 | 80 | if dataset in SEMISUP_DATASETS: 81 | train_dataloader, test_dataloader, val_dataloader = get_semisup_dataloaders( 82 | train_dataset, test_dataset, val_dataset, batch_size, batch_size_test, num_workers, unsup_fraction 83 | ) 84 | else: 85 | pin_memory = torch.cuda.is_available() 86 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train, 87 | num_workers=num_workers, pin_memory=pin_memory) 88 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, 89 | num_workers=num_workers, pin_memory=pin_memory) 90 | if validation: 91 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, 92 | num_workers=num_workers, pin_memory=pin_memory) 93 | 94 | if validation: 95 | return train_dataset, test_dataset, val_dataset, train_dataloader, test_dataloader, val_dataloader 96 | return train_dataset, test_dataset, train_dataloader, test_dataloader 97 | -------------------------------------------------------------------------------- /core/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | 
DATA_DESC = { 8 | 'data': 'cifar10', 9 | 'classes': ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'), 10 | 'num_classes': 10, 11 | 'mean': [0.4914, 0.4822, 0.4465], 12 | 'std': [0.2023, 0.1994, 0.2010], 13 | } 14 | 15 | 16 | def load_cifar10(data_dir, use_augmentation=False): 17 | """ 18 | Returns CIFAR10 train, test datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform) 33 | test_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform) 34 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/cifar100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'cifar100', 9 | 'classes': tuple(range(0, 100)), 10 | 'num_classes': 100, 11 | 'mean': [0.5071, 0.4865, 0.4409], 12 | 'std': [0.2673, 0.2564, 0.2762], 13 | } 14 | 15 | 16 | def load_cifar100(data_dir, use_augmentation=False): 17 | """ 18 | Returns CIFAR100 train, test datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 
24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.RandomRotation(15), transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform) 33 | test_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True, transform=test_transform) 34 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/cifar100s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | import re 7 | import numpy as np 8 | 9 | from .semisup import SemiSupervisedDataset 10 | 11 | 12 | def load_cifar100s(data_dir, use_augmentation=False, aux_take_amount=None, 13 | aux_data_filename=None, validation=False): 14 | """ 15 | Returns semisupervised CIFAR100 train, test datasets and dataloaders (with DDPM Images). 16 | Arguments: 17 | data_dir (str): path to data directory. 18 | use_augmentation (bool): whether to use augmentations for training set. 19 | aux_take_amount (int): number of semi-supervised examples to use (if None, use all). 20 | aux_data_filename (str): path to additional data pickle file. 21 | Returns: 22 | train dataset, test dataset. 
23 | """ 24 | data_dir = re.sub('cifar100s', 'cifar100', data_dir) 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.RandomRotation(15), transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=True, download=True, 33 | transform=train_transform, aux_data_filename=aux_data_filename, 34 | add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation) 35 | test_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=False, download=True, 36 | transform=test_transform) 37 | if validation: 38 | val_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=test_transform) 39 | val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024)) 40 | return train_dataset, test_dataset, val_dataset 41 | return train_dataset, test_dataset, None 42 | 43 | 44 | class SemiSupervisedCIFAR100(SemiSupervisedDataset): 45 | """ 46 | A dataset with auxiliary pseudo-labeled data for CIFAR100. 47 | """ 48 | def load_base_dataset(self, train=False, **kwargs): 49 | assert self.base_dataset == 'cifar100', 'Only semi-supervised cifar100 is supported. Please use correct dataset!' 
50 | self.dataset = torchvision.datasets.CIFAR100(train=train, **kwargs) 51 | self.dataset_size = len(self.dataset) -------------------------------------------------------------------------------- /core/data/cifar10s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | import re 7 | import numpy as np 8 | 9 | from .semisup import SemiSupervisedDataset 10 | 11 | 12 | def load_cifar10s(data_dir, use_augmentation=False, aux_take_amount=None, 13 | aux_data_filename='/cluster/scratch/rarade/cifar10s/ti_500K_pseudo_labeled.pickle', 14 | validation=False): 15 | """ 16 | Returns semisupervised CIFAR10 train, test datasets and dataloaders (with Tiny Images). 17 | Arguments: 18 | data_dir (str): path to data directory. 19 | use_augmentation (bool): whether to use augmentations for training set. 20 | aux_take_amount (int): number of semi-supervised examples to use (if None, use all). 21 | aux_data_filename (str): path to additional data pickle file. 22 | Returns: 23 | train dataset, test dataset. 
24 | """ 25 | data_dir = re.sub('cifar10s', 'cifar10', data_dir) 26 | test_transform = transforms.Compose([transforms.ToTensor()]) 27 | if use_augmentation: 28 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 29 | transforms.ToTensor()]) 30 | else: 31 | train_transform = test_transform 32 | 33 | train_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=True, download=True, 34 | transform=train_transform, aux_data_filename=aux_data_filename, 35 | add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation) 36 | test_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=False, download=True, 37 | transform=test_transform) 38 | if validation: 39 | val_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=test_transform) 40 | val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024)) 41 | return train_dataset, test_dataset, val_dataset 42 | return train_dataset, test_dataset, None 43 | 44 | 45 | class SemiSupervisedCIFAR10(SemiSupervisedDataset): 46 | """ 47 | A dataset with auxiliary pseudo-labeled data for CIFAR10. 48 | """ 49 | def load_base_dataset(self, train=False, **kwargs): 50 | assert self.base_dataset == 'cifar10', 'Only semi-supervised cifar10 is supported. Please use correct dataset!' 
51 | self.dataset = torchvision.datasets.CIFAR10(train=train, **kwargs) 52 | self.dataset_size = len(self.dataset) -------------------------------------------------------------------------------- /core/data/imagenet100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torchvision.transforms as transforms 4 | from robustness.datasets import CustomImageNet 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'imagenet100', 9 | 'classes': np.arange(100), 10 | 'num_classes': 100, 11 | 'mean': [0.485, 0.456, 0.406], 12 | 'std': [0.229, 0.224, 0.225], 13 | } 14 | 15 | 16 | class ImageNet100(CustomImageNet): 17 | def __init__(self, data_path, **kwargs): 18 | super().__init__( 19 | data_path=data_path, 20 | custom_grouping=[[label] for label in range(0, 1000, 10)] if '100' not in data_path else 21 | [[label] for label in range(0, 100)], 22 | **kwargs, 23 | ) 24 | 25 | def load_imagenet100(data_dir, use_augmentation=False): 26 | """ 27 | Returns ImageNet100 train, test datasets. 28 | Arguments: 29 | data_dir (str): path to data directory. 30 | use_augmentation (bool): whether to use augmentations for training set. 31 | Returns: 32 | train dataset, test dataset. 
33 | """ 34 | test_transform = transforms.Compose([ 35 | transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]) 36 | if use_augmentation: 37 | train_transform = transforms.Compose([ 38 | transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(0.5), transforms.ToTensor()]) 39 | else: 40 | train_transform = test_transform 41 | 42 | temp_dataset = ImageNet100(data_dir) 43 | train_dataloader, test_dataloader = temp_dataset.make_loaders(4, 128) 44 | 45 | train_dataset = train_dataloader.dataset 46 | train_dataset.transform = train_transform 47 | test_dataset = test_dataloader.dataset 48 | test_dataset.transform = test_transform 49 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/semisup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | def get_semisup_dataloaders(train_dataset, test_dataset, val_dataset=None, batch_size=256, batch_size_test=256, num_workers=4, 9 | unsup_fraction=0.5): 10 | """ 11 | Return dataloaders with custom sampling of pseudo-labeled data. 
12 | """ 13 | dataset_size = train_dataset.dataset_size 14 | train_batch_sampler = SemiSupervisedSampler(train_dataset.sup_indices, train_dataset.unsup_indices, batch_size, 15 | unsup_fraction, num_batches=int(np.ceil(dataset_size/batch_size))) 16 | epoch_size = len(train_batch_sampler) * batch_size 17 | 18 | kwargs = {'num_workers': num_workers, 'pin_memory': torch.cuda.is_available() } 19 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler, **kwargs) 20 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 21 | 22 | if val_dataset: 23 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 24 | return train_dataloader, test_dataloader, val_dataloader 25 | return train_dataloader, test_dataloader, None 26 | 27 | 28 | class SemiSupervisedDataset(torch.utils.data.Dataset): 29 | """ 30 | A dataset with auxiliary pseudo-labeled data. 
31 | """ 32 | def __init__(self, base_dataset='cifar10', take_amount=None, take_amount_seed=13, aux_data_filename=None, 33 | add_aux_labels=False, aux_take_amount=None, train=False, validation=False, **kwargs): 34 | 35 | self.base_dataset = base_dataset 36 | self.load_base_dataset(train, **kwargs) 37 | 38 | if validation: 39 | self.dataset.data = self.dataset.data[1024:] 40 | self.dataset.targets = self.dataset.targets[1024:] 41 | 42 | self.train = train 43 | 44 | if self.train: 45 | if take_amount is not None: 46 | rng_state = np.random.get_state() 47 | np.random.seed(take_amount_seed) 48 | take_inds = np.random.choice(len(self.sup_indices), take_amount, replace=False) 49 | np.random.set_state(rng_state) 50 | 51 | self.targets = self.targets[take_inds] 52 | self.data = self.data[take_inds] 53 | 54 | self.sup_indices = list(range(len(self.targets))) 55 | self.unsup_indices = [] 56 | 57 | if aux_data_filename is not None: 58 | aux_path = aux_data_filename 59 | print('Loading data from %s' % aux_path) 60 | if os.path.splitext(aux_path)[1] == '.pickle': 61 | # for data from Carmon et al, 2019. 62 | with open(aux_path, 'rb') as f: 63 | aux = pickle.load(f) 64 | aux_data = aux['data'] 65 | aux_targets = aux['extrapolated_targets'] 66 | else: 67 | # for data from Rebuffi et al, 2021. 
68 | aux = np.load(aux_path) 69 | aux_data = aux['image'] 70 | aux_targets = aux['label'] 71 | 72 | orig_len = len(self.data) 73 | 74 | if aux_take_amount is not None: 75 | rng_state = np.random.get_state() 76 | np.random.seed(take_amount_seed) 77 | take_inds = np.random.choice(len(aux_data), aux_take_amount, replace=False) 78 | np.random.set_state(rng_state) 79 | 80 | aux_data = aux_data[take_inds] 81 | aux_targets = aux_targets[take_inds] 82 | 83 | self.data = np.concatenate((self.data, aux_data), axis=0) 84 | 85 | if not add_aux_labels: 86 | self.targets.extend([-1] * len(aux_data)) 87 | else: 88 | self.targets.extend(aux_targets) 89 | self.unsup_indices.extend(range(orig_len, orig_len+len(aux_data))) 90 | 91 | else: 92 | self.sup_indices = list(range(len(self.targets))) 93 | self.unsup_indices = [] 94 | 95 | def load_base_dataset(self, **kwargs): 96 | raise NotImplementedError() 97 | 98 | @property 99 | def data(self): 100 | return self.dataset.data 101 | 102 | @data.setter 103 | def data(self, value): 104 | self.dataset.data = value 105 | 106 | @property 107 | def targets(self): 108 | return self.dataset.targets 109 | 110 | @targets.setter 111 | def targets(self, value): 112 | self.dataset.targets = value 113 | 114 | def __len__(self): 115 | return len(self.dataset) 116 | 117 | def __getitem__(self, item): 118 | self.dataset.labels = self.targets 119 | return self.dataset[item] 120 | 121 | 122 | class SemiSupervisedSampler(torch.utils.data.Sampler): 123 | """ 124 | Balanced sampling from the labeled and unlabeled data. 
125 | """ 126 | def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5, num_batches=None): 127 | if unsup_fraction is None or unsup_fraction < 0: 128 | self.sup_inds = sup_inds + unsup_inds 129 | unsup_fraction = 0.0 130 | else: 131 | self.sup_inds = sup_inds 132 | self.unsup_inds = unsup_inds 133 | 134 | self.batch_size = batch_size 135 | unsup_batch_size = int(batch_size * unsup_fraction) 136 | self.sup_batch_size = batch_size - unsup_batch_size 137 | 138 | if num_batches is not None: 139 | self.num_batches = num_batches 140 | else: 141 | self.num_batches = int(np.ceil(len(self.sup_inds) / self.sup_batch_size)) 142 | super().__init__(None) 143 | 144 | def __iter__(self): 145 | batch_counter = 0 146 | while batch_counter < self.num_batches: 147 | sup_inds_shuffled = [self.sup_inds[i] 148 | for i in torch.randperm(len(self.sup_inds))] 149 | for sup_k in range(0, len(self.sup_inds), self.sup_batch_size): 150 | if batch_counter == self.num_batches: 151 | break 152 | batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)] 153 | if self.sup_batch_size < self.batch_size: 154 | batch.extend([self.unsup_inds[i] for i in torch.randint(high=len(self.unsup_inds), 155 | size=(self.batch_size - len(batch),), 156 | dtype=torch.int64)]) 157 | np.random.shuffle(batch) 158 | yield batch 159 | batch_counter += 1 160 | 161 | def __len__(self): 162 | return self.num_batches -------------------------------------------------------------------------------- /core/data/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'svhn', 9 | 'classes': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9'), 10 | 'num_classes': 10, 11 | 'mean': [0.4914, 0.4822, 0.4465], 12 | 'std': [0.2023, 0.1994, 0.2010], 13 | } 14 | 15 | 16 | def load_svhn(data_dir, use_augmentation=False): 17 | """ 18 | Returns SVHN train, test 
datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | train_transform = test_transform 27 | 28 | train_dataset = torchvision.datasets.SVHN(root=data_dir, split='train', download=True, transform=train_transform) 29 | test_dataset = torchvision.datasets.SVHN(root=data_dir, split='test', download=True, transform=test_transform) 30 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import ImageFolder 7 | 8 | 9 | DATA_DESC = { 10 | 'data': 'tiny-imagenet', 11 | 'classes': tuple(range(0, 200)), 12 | 'num_classes': 200, 13 | 'mean': [0.4802, 0.4481, 0.3975], 14 | 'std': [0.2302, 0.2265, 0.2262], 15 | } 16 | 17 | 18 | def load_tinyimagenet(data_dir, use_augmentation=False): 19 | """ 20 | Returns Tiny Imagenet-200 train, test datasets and dataloaders. 21 | Arguments: 22 | data_dir (str): path to data directory. 23 | use_augmentation (bool): whether to use augmentations for training set. 24 | Returns: 25 | train dataset, test dataset. 
26 | """ 27 | test_transform = transforms.Compose([transforms.ToTensor()]) 28 | if use_augmentation: 29 | train_transform = transforms.Compose([transforms.RandomCrop(64, padding=4), transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor()]) 31 | else: 32 | train_transform = test_transform 33 | 34 | train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform) 35 | test_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=test_transform) 36 | 37 | return train_dataset, test_dataset 38 | -------------------------------------------------------------------------------- /core/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from core.utils.utils import np_load 8 | from core.utils.utils import NumpyToTensor 9 | 10 | 11 | class AdversarialDatasetWithPerturbation(torch.utils.data.Dataset): 12 | """ 13 | Torch dataset for reading examples with corresponding perturbations. 14 | Arguments: 15 | root (str): path to saved data. 16 | transform (torch.nn.Module): transformations to be applied to input. 17 | target_transform (torch.nn.Module): transformations to be applied to target. 
18 | """ 19 | def __init__(self, root, transform=NumpyToTensor(), target_transform=None): 20 | super(AdversarialDatasetWithPerturbation, self).__init__() 21 | 22 | x_path = re.sub(r'adv_(\d)+', 'adv_0', root) 23 | if os.path.isfile(os.path.join(root, 'x.npy')): 24 | data = np_load(x_path) 25 | elif os.path.isfile(os.path.join(x_path, 'x.npy')): 26 | data = np_load(x_path) 27 | else: 28 | raise FileNotFoundError('x, y not found at {} and {}.'.format(root, x_path)) 29 | self.data = data['x'] 30 | self.targets = data['y'] 31 | 32 | data = np_load(root) 33 | self.r = data['r'] 34 | self.transform = transform 35 | self.target_transform = target_transform 36 | 37 | 38 | def __len__(self): 39 | return len(self.data) 40 | 41 | def __getitem__(self, idx): 42 | image = self.data[idx] 43 | label = self.targets[idx] 44 | if self.transform: 45 | image = self.transform(image) 46 | r = self.transform(self.r[idx]) 47 | if self.target_transform: 48 | label = self.target_transform(label) 49 | return image, r, label 50 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def accuracy(true, preds): 4 | """ 5 | Computes multi-class accuracy. 6 | Arguments: 7 | true (torch.Tensor): true labels. 8 | preds (torch.Tensor): predicted labels. 9 | Returns: 10 | Multi-class accuracy. 
11 | """ 12 | accuracy = (torch.softmax(preds, dim=1).argmax(dim=1) == true).sum().float()/float(true.size(0)) 13 | return accuracy.item() 14 | -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .resnet import Normalization 4 | from .preact_resnet import preact_resnet 5 | from .resnet import resnet 6 | from .wideresnet import wideresnet 7 | 8 | from .preact_resnetwithswish import preact_resnetwithswish 9 | from .wideresnetwithswish import wideresnetwithswish 10 | 11 | from core.data import DATASETS 12 | 13 | 14 | MODELS = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | 'preact-resnet18', 'preact-resnet34', 'preact-resnet50', 'preact-resnet101', 16 | 'wrn-28-10', 'wrn-32-10', 'wrn-34-10', 'wrn-34-20', 17 | 'preact-resnet18-swish', 'preact-resnet34-swish', 18 | 'wrn-28-10-swish', 'wrn-34-10-swish', 'wrn-34-20-swish', 'wrn-70-16-swish'] 19 | 20 | 21 | def create_model(name, normalize, info, device): 22 | """ 23 | Returns suitable model from its name. 24 | Arguments: 25 | name (str): name of resnet architecture. 26 | normalize (bool): normalize input. 27 | info (dict): dataset information. 28 | device (str or torch.device): device to work on. 29 | Returns: 30 | torch.nn.Module. 31 | """ 32 | if info['data'] in ['tiny-imagenet']: 33 | assert 'preact-resnet' in name and 'swish' not in name, 'Only preact-resnets are supported for this dataset!' 34 | from .ti_preact_resnet import ti_preact_resnet 35 | backbone = ti_preact_resnet(name, num_classes=info['num_classes']) 36 | 37 | elif info['data'] in ['imagenet100']: 38 | assert 'preact-resnet' in name and 'swish' not in name, 'Only preact-resnets are supported for this dataset!' 
39 | from .in_preact_resnet import in_preact_resnet 40 | backbone = in_preact_resnet(name, num_classes=info['num_classes']) 41 | 42 | elif info['data'] in DATASETS and 'imagenet' not in info['data']: 43 | if 'preact-resnet' in name and 'swish' not in name: 44 | backbone = preact_resnet(name, num_classes=info['num_classes']) 45 | elif 'preact-resnet' in name and 'swish' in name: 46 | backbone = preact_resnetwithswish(name, dataset=info['data'], num_classes=info['num_classes']) 47 | elif 'resnet' in name and 'preact' not in name: 48 | backbone = resnet(name, num_classes=info['num_classes']) 49 | elif 'wrn' in name and 'swish' not in name: 50 | backbone = wideresnet(name, num_classes=info['num_classes']) 51 | elif 'wrn' in name and 'swish' in name: 52 | backbone = wideresnetwithswish(name, dataset=info['data'], num_classes=info['num_classes']) 53 | else: 54 | raise ValueError('Invalid model name {}!'.format(name)) 55 | 56 | else: 57 | raise ValueError('Models for {} not yet supported!'.format(info['data'])) 58 | 59 | if normalize: 60 | model = torch.nn.Sequential(Normalization(info['mean'], info['std']), backbone) 61 | else: 62 | model = torch.nn.Sequential(backbone) 63 | 64 | model = torch.nn.DataParallel(model) 65 | model = model.to(device) 66 | return model 67 | -------------------------------------------------------------------------------- /core/models/in_preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | """ 8 | Pre-activation version of the BasicBlock for Resnets. 9 | Arguments: 10 | in_planes (int): number of input planes. 11 | planes (int): number of output filters. 12 | stride (int): stride of convolution. 
13 | """ 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | """ 39 | Pre-activation version of the original Bottleneck module for Resnets. 40 | Arguments: 41 | in_planes (int): number of input planes. 42 | planes (int): number of output filters. 43 | stride (int): stride of convolution. 
44 | """ 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(PreActBottleneck, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 55 | 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(x)) 63 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 64 | out = self.conv1(out) 65 | out = self.conv2(F.relu(self.bn2(out))) 66 | out = self.conv3(F.relu(self.bn3(out))) 67 | out += shortcut 68 | return out 69 | 70 | 71 | class PreActResNet(nn.Module): 72 | """ 73 | Pre-activation Resnet model 74 | """ 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(PreActResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 82 | 83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 87 | 88 | self.bn = nn.BatchNorm2d(512 * block.expansion) 89 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 90 | self.linear = nn.Linear(512*block.expansion, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | 
layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes * block.expansion 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = self.conv1(x) 102 | out = F.relu(self.bn1(out)) 103 | out = self.maxpool1(out) 104 | 105 | out = self.layer1(out) 106 | out = self.layer2(out) 107 | out = self.layer3(out) 108 | out = self.layer4(out) 109 | 110 | out = F.relu(self.bn(out)) 111 | out = self.avgpool(out) 112 | out = out.view(out.size(0), -1) 113 | out = self.linear(out) 114 | return out 115 | 116 | 117 | def in_preact_resnet(name, num_classes=100): 118 | """ 119 | Returns suitable PreAct Resnet model from its name (only for ImageNet-100 dataset). 120 | Arguments: 121 | name (str): name of resnet architecture. 122 | num_classes (int): number of target classes. 123 | Returns: 124 | torch.nn.Module. 125 | """ 126 | if name == 'preact-resnet18': 127 | return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes) 128 | elif name == 'preact-resnet34': 129 | return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes) 130 | elif name == 'preact-resnet50': 131 | return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes) 132 | elif name == 'preact-resnet101': 133 | return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes=num_classes) 134 | raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!') 135 | return 136 | -------------------------------------------------------------------------------- /core/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | """ 8 | Pre-activation version of the BasicBlock for Resnets. 9 | Arguments: 10 | in_planes (int): number of input planes. 
# ===== core/models/preact_resnet.py (continued) =====

class PreActBottleneck(nn.Module):
    """
    Pre-activation version of the original Bottleneck module for Resnets.
    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of convolution.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        # Projection shortcut only when the spatial size or channel count changes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    """
    Pre-activation Resnet model (CIFAR-style 3x3 stem).
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = nn.BatchNorm2d(512 * block.expansion)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.relu(self.bn(out))
        # Assumes 32x32 inputs (4x4 feature map after three stride-2 stages).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def preact_resnet(name, num_classes=10):
    """
    Returns suitable Resnet model from its name.
    Arguments:
        name (str): name of resnet architecture.
        num_classes (int): number of target classes.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if the architecture name is not recognized.
    """
    if name == 'preact-resnet18':
        return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
    elif name == 'preact-resnet34':
        return PreActResNet(PreActBlock, [3,4,6,3], num_classes=num_classes)
    elif name == 'preact-resnet50':
        return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes=num_classes)
    elif name == 'preact-resnet101':
        return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes=num_classes)
    # Dead `return` statement after the raise removed.
    raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!')


# ===== core/models/preact_resnetwithswish.py =====
# Code borrowed from https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py
# (Rebuffi et al 2021)

from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


# Per-dataset normalization statistics baked into the swish models.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2471, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
SVHN_MEAN = (0.5, 0.5, 0.5)
SVHN_STD = (0.5, 0.5, 0.5)

# Maps the activation suffix of an architecture name to its module.
_ACTIVATION = {
    'relu': nn.ReLU,
    'swish': nn.SiLU,
}
# ===== core/models/preact_resnetwithswish.py (continued) =====

class _PreActBlock(nn.Module):
    """
    PreAct ResNet Block.
    Arguments:
        in_planes (int): number of input planes.
        out_planes (int): number of output filters.
        stride (int): stride of convolution.
        activation_fn (nn.Module): activation function.
    """
    def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
        super().__init__()
        self._stride = stride
        self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.relu_0 = activation_fn()
        # We manually pad to obtain the same effect as `SAME` (necessary when
        # `stride` is different than 1).
        self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                                   stride=stride, padding=0, bias=False)
        self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01)
        self.relu_1 = activation_fn()
        self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                                   padding=1, bias=False)
        self.has_shortcut = stride != 1 or in_planes != out_planes
        if self.has_shortcut:
            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                                      stride=stride, padding=0, bias=False)

    def _pad(self, x):
        # Symmetric padding for stride 1, asymmetric (right/bottom only) for
        # stride 2 — mirrors TF `SAME` padding.
        if self._stride == 1:
            x = F.pad(x, (1, 1, 1, 1))
        elif self._stride == 2:
            x = F.pad(x, (0, 1, 0, 1))
        else:
            raise ValueError('Unsupported `stride`.')
        return x

    def forward(self, x):
        out = self.relu_0(self.batchnorm_0(x))
        # The projecting shortcut taps the raw input, padded like the main path.
        shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
        out = self.conv_2d_1(self._pad(out))
        out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
        return out + shortcut


class PreActResNet(nn.Module):
    """
    PreActResNet model
    Arguments:
        num_classes (int): number of output classes.
        depth (int): number of layers.
        width (int): width factor.
        activation_fn (nn.Module): activation function.
        mean (tuple): mean of dataset.
        std (tuple): standard deviation of dataset.
        padding (int): padding.
        num_input_channels (int): number of channels in the input.
    """

    def __init__(self,
                 num_classes: int = 10,
                 depth: int = 18,
                 width: int = 0,  # Used to make the constructor consistent.
                 activation_fn: nn.Module = nn.ReLU,
                 mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
                 std: Union[Tuple[float, ...], float] = CIFAR10_STD,
                 padding: int = 0,
                 num_input_channels: int = 3):

        super().__init__()
        if width != 0:
            raise ValueError('Unsupported `width`.')
        # Normalization statistics are applied inside forward().
        self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
        self.std = torch.tensor(std).view(num_input_channels, 1, 1)
        self.mean_cuda = None
        self.std_cuda = None
        self.padding = padding
        self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
                                 padding=1, bias=False)
        if depth == 18:
            num_blocks = (2, 2, 2, 2)
        elif depth == 34:
            num_blocks = (3, 4, 6, 3)
        else:
            raise ValueError('Unsupported `depth`.')
        self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
        self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
        self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
        self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
        self.batchnorm = nn.BatchNorm2d(512, momentum=0.01)
        self.relu = activation_fn()
        self.logits = nn.Linear(512, num_classes)

    def _make_layer(self, in_planes, out_planes, num_blocks, stride,
                    activation_fn):
        layers = []
        # Only the first block of a stage changes channel count / stride.
        for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
            layers.append(_PreActBlock(i == 0 and in_planes or out_planes,
                                       out_planes,
                                       stride,
                                       activation_fn))
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4)
        # Lazily cache CUDA copies of the normalization constants on first GPU
        # call.
        if x.is_cuda:
            if self.mean_cuda is None:
                self.mean_cuda = self.mean.cuda()
                self.std_cuda = self.std.cuda()
            out = (x - self.mean_cuda) / self.std_cuda
        else:
            out = (x - self.mean) / self.std
        out = self.conv_2d(out)
        out = self.layer_0(out)
        out = self.layer_1(out)
        out = self.layer_2(out)
        out = self.layer_3(out)
        out = self.relu(self.batchnorm(out))
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.logits(out)


def preact_resnetwithswish(name, dataset='cifar10', num_classes=10):
    """
    Returns suitable PreActResNet model with Swish activation function from its name.
    Arguments:
        name (str): name of resnet architecture.
        num_classes (int): number of target classes.
        dataset (str): dataset to use.
    Returns:
        torch.nn.Module.
    """
    # e.g. 'preact-resnet18-swish' -> base name 'preact-resnet18', act 'swish'.
    name_parts = name.split('-')
    name = '-'.join(name_parts[:-1])
    act_fn = name_parts[-1]
    depth = int(name[-2:])

    if 'cifar100' in dataset:
        return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn],
                            mean=CIFAR100_MEAN, std=CIFAR100_STD)
    elif 'svhn' in dataset:
        return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn],
                            mean=SVHN_MEAN, std=SVHN_STD)
    # Fallback: CIFAR-10 statistics.
    return PreActResNet(num_classes=num_classes, depth=depth, width=0, activation_fn=_ACTIVATION[act_fn])


# ===== core/models/resnet.py =====
import torch
import torch.nn as nn
import torch.nn.functional as F


class Normalization(nn.Module):
    """
    Standardizes the input data.
    Arguments:
        mean (list): mean.
        std (float): standard deviation.
    Returns:
        (input - mean) / std
    """
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        num_channels = len(mean)
        self.mean = torch.FloatTensor(mean).view(1, num_channels, 1, 1)
        self.sigma = torch.FloatTensor(std).view(1, num_channels, 1, 1)
        self.mean_cuda, self.sigma_cuda = None, None

    def forward(self, x):
        # Lazily cache CUDA copies of the statistics on first GPU call.
        if x.is_cuda:
            if self.mean_cuda is None:
                self.mean_cuda = self.mean.cuda()
                self.sigma_cuda = self.sigma.cuda()
            out = (x - self.mean_cuda) / self.sigma_cuda
        else:
            out = (x - self.mean) / self.sigma
        return out
# ===== core/models/resnet.py (continued) =====

class BasicBlock(nn.Module):
    """
    Implements a basic block module for Resnets.
    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of convolution.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    """
    Implements a basic block module with bottleneck for Resnets.
    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of convolution.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """
    ResNet model
    Arguments:
        block (BasicBlock or Bottleneck): type of basic block to be used.
        num_blocks (list): number of blocks in each sub-module.
        num_classes (int): number of output classes.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Assumes 32x32 inputs (4x4 feature map after three stride-2 stages).
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet(name, num_classes=10):
    """
    Returns suitable Resnet model from its name.
    Arguments:
        name (str): name of resnet architecture.
        num_classes (int): number of target classes.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if the architecture name is not recognized.
    """
    if name == 'resnet18':
        return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    elif name == 'resnet34':
        return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    elif name == 'resnet50':
        return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
    elif name == 'resnet101':
        return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)

    # Dead `return` statement after the raise removed.
    raise ValueError('Only resnet18, resnet34, resnet50 and resnet101 are supported!')


# ===== core/models/ti_preact_resnet.py =====
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActBlock(nn.Module):
    """
    Pre-activation version of the BasicBlock.
    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of convolution.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # Projection shortcut only when the spatial size or channel count changes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out
13 | """ 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | """ 39 | Pre-activation version of the original Bottleneck module. 40 | Arguments: 41 | in_planes (int): number of input planes. 42 | planes (int): number of output filters. 43 | stride (int): stride of convolution. 
44 | """ 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(PreActBottleneck, self).__init__() 49 | self.bn1 = nn.BatchNorm2d(in_planes) 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 55 | 56 | if stride != 1 or in_planes != self.expansion*planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(x)) 63 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 64 | out = self.conv1(out) 65 | out = self.conv2(F.relu(self.bn2(out))) 66 | out = self.conv3(F.relu(self.bn3(out))) 67 | out += shortcut 68 | return out 69 | 70 | 71 | class PreActResNet(nn.Module): 72 | """ 73 | Pre-activation Resnet model for TI-200 dataset. 
74 | """ 75 | def __init__(self, block, num_blocks, num_classes=200): 76 | super(PreActResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.bn = nn.BatchNorm2d(512 * block.expansion) 85 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 86 | self.linear = nn.Linear(512*block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1]*(num_blocks-1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride)) 93 | self.in_planes = planes * block.expansion 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | out = self.conv1(x) 98 | out = self.layer1(out) 99 | out = self.layer2(out) 100 | out = self.layer3(out) 101 | out = self.layer4(out) 102 | out = F.relu(self.bn(out)) 103 | out = self.avgpool(out) 104 | out = out.view(out.size(0), -1) 105 | out = self.linear(out) 106 | return out 107 | 108 | 109 | def ti_preact_resnet(name, num_classes=200): 110 | """ 111 | Returns suitable PreAct Resnet model from its name (only for TI-200 dataset). 112 | Arguments: 113 | name (str): name of resnet architecture. 114 | num_classes (int): number of target classes. 115 | Returns: 116 | torch.nn.Module. 
def ti_preact_resnet(name, num_classes=200):
    """
    Returns suitable PreAct Resnet model from its name (only for TI-200 dataset).
    Arguments:
        name (str): name of resnet architecture.
        num_classes (int): number of target classes.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if `name` is not one of the supported architectures.
    """
    if name == 'preact-resnet18':
        return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes)
    elif name == 'preact-resnet34':
        return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes)
    elif name == 'preact-resnet50':
        return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    elif name == 'preact-resnet101':
        return PreActResNet(PreActBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    # FIX: removed the unreachable bare `return` that followed this raise.
    raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!')
class BasicBlock(nn.Module):
    """
    Implements a basic block module for WideResNets.
    Arguments:
        in_planes (int): number of input planes.
        out_planes (int): number of output filters.
        stride (int): stride of convolution.
        dropRate (float): dropout rate.
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # FIX: replaced the `cond and a or b` hack with an explicit conditional.
        # 1x1 projection shortcut, only needed when the channel count changes.
        self.convShortcut = (nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                       padding=0, bias=False)
                             if not self.equalInOut else None)

    def forward(self, x):
        if not self.equalInOut:
            # Pre-activation is shared by the residual branch and the shortcut.
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class NetworkBlock(nn.Module):
    """
    Implements a network block module for WideResnets.
    Arguments:
        nb_layers (int): number of layers.
        in_planes (int): number of input planes.
        out_planes (int): number of output filters.
        block (BasicBlock): type of basic block to be used.
        stride (int): stride of convolution.
        dropRate (float): dropout rate.
    """
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            # FIX: explicit conditionals instead of `cond and a or b`.
            # Only the first block of the group changes resolution/width.
            layers.append(block(in_planes if i == 0 else out_planes,
                                out_planes,
                                stride if i == 0 else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """
    WideResNet model
    Arguments:
        depth (int): number of layers; must satisfy (depth - 4) % 6 == 0.
        num_classes (int): number of output classes.
        widen_factor (int): width factor.
        dropRate (float): dropout rate.
    """
    def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        # FIX: integer division; the original `/` produced a float that was
        # silently truncated by int() downstream.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He initialization for convolutions; BN to identity; zero linear bias.
        # FIX: renamed the fan-out accumulator so it no longer shadows `n` above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # assumes 32x32 inputs so the final feature map is 8x8 — TODO confirm
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)


def wideresnet(name, num_classes=10):
    """
    Returns suitable Wideresnet model from its name.
    Arguments:
        name (str): name of resnet architecture, e.g. 'wrn-28-10' (depth 28, width 10).
        num_classes (int): number of target classes.
    Returns:
        torch.nn.Module.
    """
    name_parts = name.split('-')
    depth = int(name_parts[1])
    widen = int(name_parts[2])
    return WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen)
32 | activation_fn (nn.Module): activation function. 33 | """ 34 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 35 | super().__init__() 36 | self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) 37 | self.relu_0 = activation_fn(inplace=True) 38 | # We manually pad to obtain the same effect as `SAME` (necessary when `stride` is different than 1). 39 | self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 40 | padding=0, bias=False) 41 | self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) 42 | self.relu_1 = activation_fn(inplace=True) 43 | self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 44 | padding=1, bias=False) 45 | self.has_shortcut = in_planes != out_planes 46 | if self.has_shortcut: 47 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, 48 | stride=stride, padding=0, bias=False) 49 | else: 50 | self.shortcut = None 51 | self._stride = stride 52 | 53 | def forward(self, x): 54 | if self.has_shortcut: 55 | x = self.relu_0(self.batchnorm_0(x)) 56 | else: 57 | out = self.relu_0(self.batchnorm_0(x)) 58 | v = x if self.has_shortcut else out 59 | if self._stride == 1: 60 | v = F.pad(v, (1, 1, 1, 1)) 61 | elif self._stride == 2: 62 | v = F.pad(v, (0, 1, 0, 1)) 63 | else: 64 | raise ValueError('Unsupported `stride`.') 65 | out = self.conv_0(v) 66 | out = self.relu_1(self.batchnorm_1(out)) 67 | out = self.conv_1(out) 68 | out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) 69 | return out 70 | 71 | 72 | class _BlockGroup(nn.Module): 73 | """ 74 | WideResNet block group. 75 | Arguments: 76 | in_planes (int): number of input planes. 77 | out_planes (int): number of output filters. 78 | stride (int): stride of convolution. 79 | activation_fn (nn.Module): activation function. 
80 | """ 81 | def __init__(self, num_blocks, in_planes, out_planes, stride, activation_fn=nn.ReLU): 82 | super().__init__() 83 | block = [] 84 | for i in range(num_blocks): 85 | block.append( 86 | _Block(i == 0 and in_planes or out_planes, 87 | out_planes, 88 | i == 0 and stride or 1, 89 | activation_fn=activation_fn) 90 | ) 91 | self.block = nn.Sequential(*block) 92 | 93 | def forward(self, x): 94 | return self.block(x) 95 | 96 | 97 | class WideResNet(nn.Module): 98 | """ 99 | WideResNet model 100 | Arguments: 101 | num_classes (int): number of output classes. 102 | depth (int): number of layers. 103 | width (int): width factor. 104 | activation_fn (nn.Module): activation function. 105 | mean (tuple): mean of dataset. 106 | std (tuple): standard deviation of dataset. 107 | padding (int): padding. 108 | num_input_channels (int): number of channels in the input. 109 | """ 110 | def __init__(self, 111 | num_classes: int = 10, 112 | depth: int = 28, 113 | width: int = 10, 114 | activation_fn: nn.Module = nn.ReLU, 115 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, 116 | std: Union[Tuple[float, ...], float] = CIFAR10_STD, 117 | padding: int = 0, 118 | num_input_channels: int = 3): 119 | super().__init__() 120 | self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) 121 | self.std = torch.tensor(std).view(num_input_channels, 1, 1) 122 | self.mean_cuda = None 123 | self.std_cuda = None 124 | self.padding = padding 125 | num_channels = [16, 16 * width, 32 * width, 64 * width] 126 | assert (depth - 4) % 6 == 0 127 | num_blocks = (depth - 4) // 6 128 | self.init_conv = nn.Conv2d(num_input_channels, num_channels[0], 129 | kernel_size=3, stride=1, padding=1, bias=False) 130 | self.layer = nn.Sequential( 131 | _BlockGroup(num_blocks, num_channels[0], num_channels[1], 1, 132 | activation_fn=activation_fn), 133 | _BlockGroup(num_blocks, num_channels[1], num_channels[2], 2, 134 | activation_fn=activation_fn), 135 | _BlockGroup(num_blocks, num_channels[2], 
num_channels[3], 2, 136 | activation_fn=activation_fn)) 137 | self.batchnorm = nn.BatchNorm2d(num_channels[3], momentum=0.01) 138 | self.relu = activation_fn(inplace=True) 139 | self.logits = nn.Linear(num_channels[3], num_classes) 140 | self.num_channels = num_channels[3] 141 | 142 | for m in self.modules(): 143 | if isinstance(m, nn.Conv2d): 144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 145 | m.weight.data.normal_(0, math.sqrt(2. / n)) 146 | elif isinstance(m, nn.BatchNorm2d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | elif isinstance(m, nn.Linear): 150 | m.bias.data.zero_() 151 | 152 | def forward(self, x): 153 | if self.padding > 0: 154 | x = F.pad(x, (self.padding,) * 4) 155 | if x.is_cuda: 156 | if self.mean_cuda is None: 157 | self.mean_cuda = self.mean.cuda() 158 | self.std_cuda = self.std.cuda() 159 | out = (x - self.mean_cuda) / self.std_cuda 160 | else: 161 | out = (x - self.mean) / self.std 162 | 163 | out = self.init_conv(out) 164 | out = self.layer(out) 165 | out = self.relu(self.batchnorm(out)) 166 | out = F.avg_pool2d(out, 8) 167 | out = out.view(-1, self.num_channels) 168 | return self.logits(out) 169 | 170 | 171 | def wideresnetwithswish(name, dataset='cifar10', num_classes=10): 172 | """ 173 | Returns suitable Wideresnet model with Swish activation function from its name. 174 | Arguments: 175 | name (str): name of resnet architecture. 176 | num_classes (int): number of target classes. 177 | dataset (str): dataset to use. 178 | Returns: 179 | torch.nn.Module. 
def wideresnetwithswish(name, dataset='cifar10', num_classes=10):
    """
    Returns suitable Wideresnet model with Swish activation function from its name.
    Arguments:
        name (str): name of resnet architecture, e.g. 'wrn-28-10-swish'.
        dataset (str): dataset to use (selects the normalization constants).
        num_classes (int): number of target classes.
    Returns:
        torch.nn.Module.
    """
    name_parts = name.split('-')
    depth = int(name_parts[1])
    widen = int(name_parts[2])
    act_fn = name_parts[3]

    if 'cifar100' in dataset:
        return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn],
                          mean=CIFAR100_MEAN, std=CIFAR100_STD)
    elif 'svhn' in dataset:
        return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn],
                          mean=SVHN_MEAN, std=SVHN_STD)
    # default: CIFAR-10 normalization
    return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn])


def extract_zip(in_file, out_dir):
    """Extract the zip archive `in_file` into directory `out_dir`."""
    # FIX: context manager guarantees the archive is closed on error paths.
    with zipfile.ZipFile(in_file, 'r') as zf:
        zf.extractall(out_dir)


def copy(in_file, out_dir):
    """
    Copy `in_file` into `out_dir`.
    Returns 1 for backward compatibility with the previous implementation.
    """
    # FIX: use shutil instead of spawning an external `cp` process —
    # portable, no shell dependency, and failures raise instead of being ignored.
    shutil.copy(in_file, out_dir)
    return 1


def _setup(data_dir, name='train'):
    """
    Stage `<data_dir>/<name>.zip` into $TMPDIR/<dataset> and extract it.
    The 'train' split is extracted into its own subdirectory; other splits
    are extracted directly into the dataset directory. The staged zip is
    removed after extraction. No-op if the target directory already exists.
    """
    dataset = os.path.basename(os.path.normpath(data_dir))
    data_file = os.path.join(data_dir, f'{name}.zip')
    tmp_dir = os.path.join(os.environ['TMPDIR'], dataset)

    if not os.path.isdir(tmp_dir):
        os.makedirs(tmp_dir, exist_ok=True)
        copy(data_file, tmp_dir)
        if name in ['train']:
            extract_zip(os.path.join(tmp_dir, f'{name}.zip'), os.path.join(tmp_dir, name))
        else:
            extract_zip(os.path.join(tmp_dir, f'{name}.zip'), tmp_dir)
        os.remove(os.path.join(tmp_dir, f'{name}.zip'))


def setup_train(data_dir):
    """Stage and extract the training split."""
    print ('Setting up training dataset.')
    _setup(data_dir, 'train')


def setup_val(data_dir):
    """Stage and extract the validation split."""
    print ('Setting up validation dataset.')
    _setup(data_dir, 'val')
def clear_data():
    """Remove the staged imagenet100 data from $TMPDIR, if present."""
    tmp_dir = os.environ['TMPDIR']
    staged = os.path.join(tmp_dir, 'imagenet100')
    if os.path.isdir(staged):
        shutil.rmtree(staged)
    print (f'Cleared up TMPDIR.')


class ctx_noparamgrad(object):
    """
    Context manager that turns off requires_grad for every parameter of
    `module` on entry and restores the previous flags on exit.
    """
    def __init__(self, module):
        # Capture the current flags before disabling them.
        self.prev_grad_state = get_param_grad_state(module)
        self.module = module
        set_param_grad_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_param_grad_state(self.module, self.prev_grad_state)
        return False  # never suppress exceptions


class ctx_eval(object):
    """
    Context manager that puts every submodule of `module` into eval mode on
    entry and restores each submodule's previous training flag on exit.
    """
    def __init__(self, module):
        # Capture per-submodule training flags before switching them off.
        self.prev_training_state = get_module_training_state(module)
        self.module = module
        set_module_training_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_module_training_state(self.module, self.prev_training_state)
        return False  # never suppress exceptions


@contextmanager
def ctx_noparamgrad_and_eval(module):
    """Combined context: no parameter gradients AND eval mode for `module`."""
    with ctx_noparamgrad(module) as a, ctx_eval(module) as b:
        yield (a, b)


def get_module_training_state(module):
    """Return {submodule: training flag} for every submodule of `module`."""
    return {mod: mod.training for mod in module.modules()}


def set_module_training_state(module, training_state):
    """Restore per-submodule training flags captured by get_module_training_state."""
    for mod in module.modules():
        mod.training = training_state[mod]


def set_module_training_off(module):
    """Set every submodule of `module` to eval (training=False)."""
    for mod in module.modules():
        mod.training = False
def get_param_grad_state(module):
    """Return {param: requires_grad flag} for every parameter of `module`."""
    return {param: param.requires_grad for param in module.parameters()}


def set_param_grad_state(module, grad_state):
    """Restore requires_grad flags captured by get_param_grad_state."""
    for param in module.parameters():
        param.requires_grad = grad_state[param]


def set_param_grad_off(module):
    """Disable gradient computation for every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad = False


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_orthogonal_vector(r):
    """
    Returns a random unit vector orthogonal to given unit vector r.
    """
    r = r / torch.norm(r.view(-1), p=2)
    p = torch.rand(r.numel()).to(device)
    # Gram-Schmidt: remove the component of p along r, then renormalize.
    p = p - p.dot(r.view(-1))*r.view(-1)
    p = p / torch.norm(p, p=2)
    p = p.view(r.shape)
    assert np.isclose(torch.dot(p.view(-1), r.view(-1)).item(), 0, atol=1e-6) == True, 'p and r are not orthogonal.'
    return p


def line_search(model, x, r, y, precision=0.1, ord=2, max_alpha=35, normalize_r=True, clip_min=0., clip_max=1., ortho=False):
    """
    Perform line search to find margin.
    Walks along direction r from a single example x until the model's
    prediction changes (coarse integer steps, then a fine pass at `precision`).
    Arguments:
        model (torch.nn.Module): classifier.
        x (torch.Tensor): single input example (without batch dimension).
        r (torch.Tensor): search direction, same shape as x.
        y: ground-truth label (returned unchanged under key 'true').
        precision (float): step size of the fine search.
        ord (int/float): norm used to normalize r.
        max_alpha (int): maximum search distance.
        normalize_r (bool): normalize r to unit norm before searching.
        ortho (bool): search along a random direction orthogonal to r instead.
    Returns:
        dict with keys 'mar', 'true', 'orig_pred', 'pert_pred'.
    """
    x, r = x.unsqueeze(0), r.unsqueeze(0)
    pert_preds = model(torch.clamp(x+r, 0, 1))

    if normalize_r:
        r = r / r.view(-1).norm(p=ord)
    if ortho:
        r = get_orthogonal_vector(r)

    orig_preds = model(x)
    orig_labels = orig_preds.argmax(dim=1)
    # FIX: `replicate_input` was never imported in this module and raised a
    # NameError at runtime; use an explicit detached copy instead.
    pert_x = x.detach().clone()
    for a in range(0, max_alpha + 1):  # fast (coarse) search
        pert_labels = model(pert_x).argmax(dim=1)
        if pert_labels != orig_labels:
            break
        pert_x = x + a*r
    alpha = a

    pert_x = x.detach().clone()
    if alpha != max_alpha:  # fine-tune search with given precision
        for a in np.arange(alpha - 1, alpha + precision, precision):
            pert_labels = model(pert_x).argmax(dim=1)
            if pert_labels != orig_labels:
                break
            pert_x = x + a*r
        margin = a
    else:
        margin = max_alpha

    pert_labels = pert_preds.argmax(dim=1)
    return {'mar': margin, 'true': y, 'orig_pred': orig_labels.item(), 'pert_pred': pert_labels.item()}


def measure_margin(trainer, data_path, precision, ord=2, ortho=False, verbose=False):
    """
    Estimate margin using line search.
    Runs line_search over every (x, r, y) triple in the adversarial dataset at
    `data_path` and aggregates summary statistics.
    Returns:
        (per-example DataFrame, dict with mean/10th/50th/90th percentile margins).
    """

    if ord not in [2, np.inf]:
        raise NotImplementedError('Only ord=2 and ord=inf have been implemented!')
    trainer.model.eval()

    mar_adv_any = []
    dataset = AdversarialDatasetWithPerturbation(data_path)
    for x, r, y in tqdm(dataset, disable=not verbose):
        x, r = x.to(device), r.to(device)
        mar_any = line_search(trainer.model, x, r, y, ord=ord, precision=precision, ortho=ortho)
        mar_adv_any.append(mar_any)
    assert len(mar_adv_any) == len(dataset), 'Lengths must match'

    mar_adv_any = pd.DataFrame(mar_adv_any)
    mar10, mar50, mar90 = np.percentile(mar_adv_any['mar'], [10, 50, 90])
    out_margin = {'mean_margin': np.mean(mar_adv_any['mar']), '10_margin': mar10, '50_margin': mar50, '90_margin': mar90}

    return mar_adv_any, out_margin
def hat_loss(model, x, y, optimizer, step_size=0.007, epsilon=0.031, perturb_steps=10, h=3.5, beta=1.0, gamma=1.0,
             attack='linf-pgd', hr_model=None):
    """
    TRADES + Helper-based adversarial training.
    Arguments:
        model (torch.nn.Module): classifier being trained.
        x (torch.Tensor): batch of clean inputs in [0, 1].
        y (torch.Tensor): ground-truth labels.
        optimizer (torch.optim.Optimizer): optimizer; its gradients are zeroed here.
        step_size (float): attack step size.
        epsilon (float): perturbation budget.
        perturb_steps (int): number of attack iterations.
        h (float): extrapolation factor for helper examples x_hr = x + h*(x_adv - x).
        beta (float): weight of the TRADES KL robustness term.
        gamma (float): weight of the helper loss.
        attack (str): 'linf-pgd' or 'l2-pgd'; anything else raises ValueError.
        hr_model (torch.nn.Module): fixed helper model that labels x_adv.
    Returns:
        (loss, batch_metrics) tuple; batch_metrics holds loss and accuracies.
    """

    criterion_kl = nn.KLDivLoss(reduction='sum')
    model.eval()  # freeze BN/dropout statistics while crafting the attack

    # Random start; the natural softmax is the fixed KL target for the attack.
    x_adv = x.detach() + 0.001 * torch.randn(x.shape).cuda().detach()  # NOTE(review): assumes CUDA is available — confirm
    p_natural = F.softmax(model(x), dim=1)

    if attack == 'linf-pgd':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            # Signed-gradient ascent step, then projection onto the eps-ball and [0, 1].
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif attack == 'l2-pgd':
        delta = 0.001 * torch.randn(x.shape).cuda().detach()  # NOTE(review): assumes CUDA is available — confirm
        delta = Variable(delta.data, requires_grad=True)

        batch_size = len(x)
        optimizer_delta = torch.optim.SGD([delta], lr=step_size)

        for _ in range(perturb_steps):
            adv = x + delta
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                # Maximize the KL divergence <=> minimize its negation with SGD.
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural)
            loss.backward(retain_graph=True)

            # Renormalize the gradient to unit L2 norm per sample before stepping.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # Avoid nan/inf when the gradient vanishes.
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # Projection: clamp x+delta into [0, 1], then renorm delta to the L2 ball.
            delta.data.add_(x)
            delta.data.clamp_(0, 1).sub_(x)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x + delta, requires_grad=False)
    else:
        raise ValueError(f'Attack={attack} not supported for TRADES training!')
    model.train()


    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # Helper example: extrapolate past the adversarial example by factor h.
    x_hr = x + h * (x_adv - x)
    with ctx_noparamgrad_and_eval(hr_model):
        # Helper label: the fixed helper model's prediction on the adversarial example.
        y_hr = hr_model(x_adv).argmax(dim=1)

    optimizer.zero_grad()

    out_clean, out_adv, out_help = model(x), model(x_adv), model(x_hr)
    loss_clean = F.cross_entropy(out_clean, y, reduction='mean')
    # TRADES KL robustness term, averaged over the batch.
    loss_adv = (1/len(x)) * criterion_kl(F.log_softmax(out_adv, dim=1), F.softmax(out_clean, dim=1))

    loss_help = F.cross_entropy(out_help, y_hr, reduction='mean')
    loss = loss_clean + beta * loss_adv + gamma * loss_help

    batch_metrics = {'loss': loss.item()}
    batch_metrics.update({'adversarial_acc': accuracy(y, out_adv.detach()), 'helper_acc': accuracy(y_hr, out_help.detach())})
    batch_metrics.update({'clean_acc': accuracy(y, out_clean.detach())})

    return loss, batch_metrics


def at_hat_loss(model, x, y, optimizer, step_size=0.007, epsilon=0.031, perturb_steps=10, h=3.5, beta=1.0, gamma=1.0,
                attack='linf-pgd', hr_model=None):
    """
    AT + Helper-based adversarial training.
    Same contract as hat_loss, but the robust term is plain cross-entropy on the
    adversarial examples (standard adversarial training) instead of the TRADES
    KL term, and the attack is built via the project's create_attack factory.
    """

    criterion_ce = nn.CrossEntropyLoss()
    model.eval()

    # Craft adversarial examples with gradients/eval state of `model` protected.
    attack = create_attack(model, criterion_ce, attack, epsilon, perturb_steps, step_size)
    with ctx_noparamgrad_and_eval(model):
        x_adv, _ = attack.perturb(x, y)

    model.train()

    # Helper example and helper label, as in hat_loss.
    x_hr = x + h * (x_adv - x)
    with ctx_noparamgrad_and_eval(hr_model):
        y_hr = hr_model(x_adv).argmax(dim=1)

    optimizer.zero_grad()

    out_clean, out_adv, out_help = model(x), model(x_adv), model(x_hr)
    loss_clean = F.cross_entropy(out_clean, y, reduction='mean')
    loss_adv = criterion_ce(out_adv, y)
    loss_help = F.cross_entropy(out_help, y_hr, reduction='mean')
    loss = loss_clean + beta * loss_adv + gamma * loss_help

    batch_metrics = {'loss': loss.item()}
    batch_metrics.update({'adversarial_acc': accuracy(y, out_adv.detach()), 'helper_acc': accuracy(y_hr, out_help.detach())})
    batch_metrics.update({'clean_acc': accuracy(y, out_clean.detach())})

    return loss, batch_metrics
class Logger(object):
    """
    Helper class for logging.
    Arguments:
        path (str): Path to log file.
    """
    def __init__(self, path):
        self.path = path
        self.logger = logging.getLogger()
        self.setup_file_logger()
        print ('Logging to file: ', self.path)

    def setup_file_logger(self):
        # Truncate any existing log file and attach it to the root logger.
        file_handler = logging.FileHandler(self.path, 'w+')
        self.logger.addHandler(file_handler)
        self.logger.setLevel(logging.INFO)

    def log(self, message):
        # Echo to stdout and persist to the log file.
        print (message)
        self.logger.info(message)
def mart_loss(model, x_natural, y, optimizer, step_size=0.007, epsilon=0.031, perturb_steps=10, beta=6.0,
              attack='linf-pgd'):
    """
    MART training (Wang et al, 2020).
    Arguments:
        model (torch.nn.Module): classifier being trained.
        x_natural (torch.Tensor): batch of clean inputs in [0, 1].
        y (torch.Tensor): ground-truth labels.
        optimizer (torch.optim.Optimizer): optimizer; its gradients are zeroed here.
        step_size (float): PGD step size.
        epsilon (float): L-inf perturbation budget.
        perturb_steps (int): number of PGD iterations.
        beta (float): weight of the re-weighted KL robustness term.
        attack (str): only 'linf-pgd' is supported; anything else raises ValueError.
    Returns:
        (loss, batch_metrics) tuple; batch_metrics holds loss and accuracies.
    """

    kl = nn.KLDivLoss(reduction='none')
    model.eval()  # freeze BN/dropout statistics while crafting the attack
    batch_size = len(x_natural)

    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()  # NOTE(review): assumes CUDA is available — confirm
    if attack == 'linf-pgd':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_ce = F.cross_entropy(model(x_adv), y)
            grad = torch.autograd.grad(loss_ce, [x_adv])[0]
            # Signed-gradient ascent step, then projection onto the eps-ball and [0, 1].
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    else:
        raise ValueError(f'Attack={attack} not supported for MART training!')
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # zero gradient
    optimizer.zero_grad()

    logits = model(x_natural)
    logits_adv = model(x_adv)

    adv_probs = F.softmax(logits_adv, dim=1)
    # Runner-up class: second-most-likely, or most-likely when top-1 equals the label.
    tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:]
    new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1])
    # BCE-style adversarial term: CE on the true class plus a margin term that
    # pushes down the runner-up probability (constants guard log(0)).
    loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y)

    nat_probs = F.softmax(logits, dim=1)
    true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze()

    # KL term re-weighted by (1 - p_true): poorly classified examples count more.
    loss_robust = (1.0 / batch_size) * torch.sum(
        torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs))
    loss = loss_adv + float(beta) * loss_robust

    batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits.detach()),
                     'adversarial_acc': accuracy(y, logits_adv.detach())}

    return loss, batch_metrics
from core.data import DATASETS
from core.models import MODELS
from .train import SCHEDULERS

from .utils import str2bool, str2float


def parser_train():
    """
    Parse input arguments (train.py).
    Returns:
        argparse.ArgumentParser: parser pre-populated with all training options.
    """
    parser = argparse.ArgumentParser(description='Standard + Adversarial Training.')

    # Data loading / batching.
    parser.add_argument('--augment', type=str2bool, default=True, help='Augment training set.')
    parser.add_argument('--batch-size', type=int, default=128, help='Batch size for training.')
    parser.add_argument('--batch-size-validation', type=int, default=256, help='Batch size for testing.')
    parser.add_argument('--num-samples-eval', type=int, default=512, help='Number of samples to use for margin calculations.')

    # Paths (cluster-specific defaults).
    parser.add_argument('--data-dir', type=str, default='/cluster/scratch/rarade/data/')
    parser.add_argument('--log-dir', type=str, default='/cluster/home/rarade/adversarial-hat/logs/')
    parser.add_argument('--tmp-dir', type=str, default='/cluster/scratch/rarade/')

    parser.add_argument('-d', '--data', type=str, default='cifar10', choices=DATASETS, help='Data to use.')
    parser.add_argument('--desc', type=str, required=True,
                        help='Description of experiment. It will be used to name directories.')

    # Model selection.
    parser.add_argument('-m', '--model', choices=MODELS, default='resnet18', help='Model architecture to be used.')
    parser.add_argument('--normalize', type=str2bool, default=False, help='Normalize input.')
    parser.add_argument('--pretrained-file', type=str, default=None, help='Pretrained weights file name.')

    # Training schedule.
    parser.add_argument('-ns', '--num-std-epochs', type=int, default=0, help='Number of standard training epochs.')
    parser.add_argument('-na', '--num-adv-epochs', type=int, default=0, help='Number of adversarial training epochs.')
    parser.add_argument('--adv-eval-freq', type=int, default=30, help='Adversarial evaluation frequency (in epochs).')

    # HAT / TRADES hyperparameters.
    parser.add_argument('--h', default=2.0, type=float, help='Parameter h to compute helper examples (x + h*r) for HAT.')
    parser.add_argument('--helper-model', type=str, default=None, help='Helper model weights file name for HAT.')
    parser.add_argument('--beta', default=None, type=float, help='Stability regularization, i.e., 1/lambda in TRADES \
                        or weight of robust loss in HAT.')
    parser.add_argument('--gamma', default=1.0, type=float, help='Weight of helper loss in HAT.')
    parser.add_argument('--robust-loss', default='kl', choices=['ce', 'kl'], type=str, help='Type of robust loss in HAT.')

    # Optimizer / scheduler.
    parser.add_argument('--lr', type=float, default=0.21, help='Learning rate for optimizer (SGD).')
    parser.add_argument('--weight-decay', type=float, default=5e-4, help='Optimizer (SGD) weight decay.')
    parser.add_argument('--scheduler', choices=SCHEDULERS, default='cyclic', help='Type of scheduler.')
    parser.add_argument('--nesterov', type=str2bool, default=True, help='Use Nesterov momentum.')
    parser.add_argument('--clip-grad', type=float, default=None, help='Gradient norm clipping.')

    # Attack configuration used during adversarial training.
    parser.add_argument('-a', '--attack', type=str, choices=ATTACKS, default='linf-pgd', help='Type of attack.')
    parser.add_argument('--attack-eps', type=str2float, default=8/255, help='Epsilon for the attack.')
    parser.add_argument('--attack-step', type=str2float, default=2/255, help='Step size for PGD attack.')
    parser.add_argument('--attack-iter', type=int, default=10, help='Max. number of iterations (if any) for the attack.')
    parser.add_argument('--keep-clean', type=str2bool, default=False, help='Use clean samples during adversarial training.')

    # Debug / experiment flags.
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Debug code. Run 1 epoch of training and evaluation.')
    parser.add_argument('--exp', action='store_true', default=False,
                        help='Store results for performing margin and curvature experiments later.')
    parser.add_argument('--mart', action='store_true', default=False, help='MART training.')

    # Semi-supervised (auxiliary data) options.
    parser.add_argument('--unsup-fraction', type=float, default=0.5, help='Ratio of unlabelled data to labelled data.')
    parser.add_argument('--aux-data-filename', type=str, help='Path to additional Tiny Images data.',
                        default='/cluster/scratch/rarade/cifar10s/ti_500K_pseudo_labeled.pickle')

    parser.add_argument('--seed', type=int, default=1, help='Random seed.')
    return parser


def parser_eval():
    """
    Parse input arguments (eval-adv.py, eval-corr.py, eval-aa.py).
    Returns:
        argparse.ArgumentParser: parser pre-populated with all evaluation options.
    """
    parser = argparse.ArgumentParser(description='Robustness evaluation.')

    parser.add_argument('--data-dir', type=str, default='/cluster/home/rarade/adversarial-hat/data/')
    parser.add_argument('--log-dir', type=str, default='/cluster/home/rarade/adversarial-hat/logs/')

    parser.add_argument('--desc', type=str, required=True, help='Description of model to be evaluated.')
    parser.add_argument('--num-samples', type=int, default=1000, help='Number of test samples.')

    # eval-aa.py
    parser.add_argument('--train', action='store_true', default=False, help='Evaluate on training set.')
    parser.add_argument('-v', '--version', type=str, default='standard', choices=['custom', 'plus', 'standard'],
                        help='Version of AA.')

    # eval-adv.py
    parser.add_argument('--source', type=str, default=None, help='Path to source model for black-box evaluation.')
    parser.add_argument('--wb', action='store_true', default=False, help='Perform white-box PGD evaluation.')

    # eval-rb.py
    parser.add_argument('--threat', type=str, default='corruptions', choices=['corruptions', 'Linf', 'L2'],
                        help='Threat model for RobustBench evaluation.')

    parser.add_argument('--seed', type=int, default=1, help='Random seed.')
    return parser
9 | """ 10 | def __init__(self, optimizer, max_lr, epochs, last_epoch=-1): 11 | self.max_lr = max_lr 12 | self.epochs = epochs 13 | self._reset() 14 | super(CosineLR, self).__init__(optimizer, last_epoch) 15 | 16 | def _reset(self): 17 | self.current_lr = self.max_lr 18 | self.current_epoch = 1 19 | 20 | def step(self): 21 | self.current_lr = self.max_lr * 0.5 * (1 + np.cos((self.current_epoch - 1) / self.epochs * np.pi)) 22 | for param_group in self.optimizer.param_groups: 23 | param_group['lr'] = self.current_lr 24 | self.current_epoch += 1 25 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 26 | 27 | def get_lr(self): 28 | return self.current_lr 29 | -------------------------------------------------------------------------------- /core/utils/trades.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torch.optim as optim 8 | 9 | from core.metrics import accuracy 10 | 11 | 12 | def squared_l2_norm(x): 13 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 14 | return (flattened ** 2).sum(1) 15 | 16 | 17 | def l2_norm(x): 18 | return squared_l2_norm(x).sqrt() 19 | 20 | 21 | def trades_loss(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, beta=1.0, 22 | attack='linf-pgd'): 23 | """ 24 | TRADES training (Zhang et al, 2019). 
25 | """ 26 | 27 | # define KL-loss 28 | criterion_kl = nn.KLDivLoss(reduction='sum') 29 | model.eval() 30 | batch_size = len(x_natural) 31 | # generate adversarial example 32 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 33 | p_natural = F.softmax(model(x_natural), dim=1) 34 | 35 | if attack == 'linf-pgd': 36 | for _ in range(perturb_steps): 37 | x_adv.requires_grad_() 38 | with torch.enable_grad(): 39 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural) 40 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 41 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 42 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 43 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 44 | 45 | elif attack == 'l2-pgd': 46 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 47 | delta = Variable(delta.data, requires_grad=True) 48 | 49 | # Setup optimizers 50 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 51 | 52 | for _ in range(perturb_steps): 53 | adv = x_natural + delta 54 | 55 | # optimize 56 | optimizer_delta.zero_grad() 57 | with torch.enable_grad(): 58 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural) 59 | loss.backward(retain_graph=True) 60 | # renorming gradient 61 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 62 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 63 | # avoid nan or inf if gradient is 0 64 | if (grad_norms == 0).any(): 65 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 66 | optimizer_delta.step() 67 | 68 | # projection 69 | delta.data.add_(x_natural) 70 | delta.data.clamp_(0, 1).sub_(x_natural) 71 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 72 | x_adv = Variable(x_natural + delta, requires_grad=False) 73 | else: 74 | raise ValueError(f'Attack={attack} not supported for TRADES training!') 75 | model.train() 76 | 77 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), 
requires_grad=False) 78 | 79 | optimizer.zero_grad() 80 | # calculate robust loss 81 | logits_natural = model(x_natural) 82 | logits_adv = model(x_adv) 83 | loss_natural = F.cross_entropy(logits_natural, y) 84 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 85 | F.softmax(logits_natural, dim=1)) 86 | loss = loss_natural + beta * loss_robust 87 | 88 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits_natural.detach()), 89 | 'adversarial_acc': accuracy(y, logits_adv.detach())} 90 | 91 | return loss, batch_metrics 92 | -------------------------------------------------------------------------------- /core/utils/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from tqdm import tqdm as tqdm 4 | 5 | import os 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from core.attacks import create_attack 12 | from core.metrics import accuracy 13 | from core.models import create_model 14 | 15 | from .context import ctx_noparamgrad_and_eval 16 | from .utils import seed 17 | 18 | from .hat import at_hat_loss 19 | from .hat import hat_loss 20 | from .mart import mart_loss 21 | from .rst import CosineLR 22 | from .trades import trades_loss 23 | 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | SCHEDULERS = ['cyclic', 'step', 'cosine', 'cosinew'] 28 | 29 | 30 | class Trainer(object): 31 | """ 32 | Helper class for training a deep neural network. 33 | Arguments: 34 | info (dict): dataset information. 35 | args (dict): input arguments. 
36 | """ 37 | def __init__(self, info, args): 38 | super(Trainer, self).__init__() 39 | 40 | seed(args.seed) 41 | self.model = create_model(args.model, args.normalize, info, device) 42 | 43 | self.params = args 44 | self.criterion = nn.CrossEntropyLoss() 45 | 46 | self.init_optimizer(self.params.num_std_epochs) 47 | 48 | if self.params.pretrained_file is not None: 49 | self.load_model(os.path.join(self.params.log_dir, self.params.pretrained_file, 'weights-best.pt')) 50 | 51 | if self.params.helper_model is not None: 52 | print (f'Using helper model: {self.params.helper_model}.') 53 | with open(os.path.join(self.params.log_dir, self.params.helper_model, 'args.txt'), 'r') as f: 54 | hr_args = json.load(f) 55 | self.hr_model = create_model(hr_args['model'], hr_args['normalize'], info, device) 56 | checkpoint = torch.load(os.path.join(self.params.log_dir, self.params.helper_model, 'weights-best.pt'), map_location=device) 57 | self.hr_model.load_state_dict(checkpoint['model_state_dict']) 58 | self.hr_model.eval() 59 | del checkpoint, hr_args 60 | 61 | self.attack, self.eval_attack = self.init_attack(self.model, self.criterion, self.params.attack, self.params.attack_eps, 62 | self.params.attack_iter, self.params.attack_step) 63 | 64 | 65 | @staticmethod 66 | def init_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step): 67 | """ 68 | Initialize adversary. 
69 | """ 70 | attack = create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform') 71 | if attack_type in ['linf-pgd', 'l2-pgd']: 72 | eval_attack = create_attack(model, criterion, attack_type, attack_eps, 2*attack_iter, attack_step) 73 | elif attack_type in ['fgsm', 'linf-df']: 74 | eval_attack = create_attack(model, criterion, 'linf-pgd', 8/255, 20, 2/255) 75 | elif attack_type in ['fgm', 'l2-df']: 76 | eval_attack = create_attack(model, criterion, 'l2-pgd', 128/255, 20, 15/255) 77 | return attack, eval_attack 78 | 79 | 80 | def init_optimizer(self, num_epochs): 81 | """ 82 | Initialize optimizer and scheduler. 83 | """ 84 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay, 85 | momentum=0.9, nesterov=self.params.nesterov) 86 | if num_epochs <= 0: 87 | return 88 | self.init_scheduler(num_epochs) 89 | 90 | 91 | def init_scheduler(self, num_epochs): 92 | """ 93 | Initialize scheduler. 
94 | """ 95 | if self.params.scheduler == 'cyclic': 96 | _NUM_SAMPLES = {'svhn': 73257, 'tiny-imagenet': 100000, 'imagenet100': 128334} 97 | num_samples = _NUM_SAMPLES.get(self.params.data, 50000) 98 | update_steps = int(np.floor(num_samples/self.params.batch_size) + 1) 99 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.25, 100 | steps_per_epoch=update_steps, epochs=int(num_epochs)) 101 | elif self.params.scheduler == 'step': 102 | milestones = [100, 105] 103 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, gamma=0.1, milestones=milestones) 104 | elif self.params.scheduler == 'cosine': 105 | self.scheduler = CosineLR(self.optimizer, max_lr=self.params.lr, epochs=int(num_epochs)) 106 | elif self.params.scheduler == 'cosinew': 107 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.025, 108 | total_steps=int(num_epochs)) 109 | else: 110 | self.scheduler = None 111 | 112 | 113 | def train(self, dataloader, epoch=0, adversarial=False, verbose=True): 114 | """ 115 | Run one epoch of training. 
116 | """ 117 | metrics = pd.DataFrame() 118 | self.model.train() 119 | 120 | for data in tqdm(dataloader, desc='Epoch {}: '.format(epoch), disable=not verbose): 121 | x, y = data 122 | x, y = x.to(device), y.to(device) 123 | 124 | if adversarial: 125 | if self.params.helper_model is not None and self.params.beta is not None: 126 | loss, batch_metrics = self.hat_loss(x, y, h=self.params.h, beta=self.params.beta, gamma=self.params.gamma) 127 | elif self.params.beta is not None and self.params.mart: 128 | loss, batch_metrics = self.mart_loss(x, y, beta=self.params.beta) 129 | elif self.params.beta is not None: 130 | loss, batch_metrics = self.trades_loss(x, y, beta=self.params.beta) 131 | else: 132 | loss, batch_metrics = self.adversarial_loss(x, y) 133 | else: 134 | loss, batch_metrics = self.standard_loss(x, y) 135 | 136 | loss.backward() 137 | if self.params.clip_grad: 138 | nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_grad) 139 | self.optimizer.step() 140 | if self.params.scheduler in ['cyclic']: 141 | self.scheduler.step() 142 | 143 | metrics = metrics.append(pd.DataFrame(batch_metrics, index=[0]), ignore_index=True) 144 | 145 | if self.params.scheduler in ['step', 'converge', 'cosine', 'cosinew']: 146 | self.scheduler.step() 147 | return dict(metrics.mean()) 148 | 149 | 150 | def standard_loss(self, x, y): 151 | """ 152 | Standard training. 153 | """ 154 | self.optimizer.zero_grad() 155 | out = self.model(x) 156 | loss = self.criterion(out, y) 157 | 158 | preds = out.detach() 159 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, preds)} 160 | return loss, batch_metrics 161 | 162 | 163 | def adversarial_loss(self, x, y): 164 | """ 165 | Adversarial training (Madry et al, 2017). 
166 | """ 167 | with ctx_noparamgrad_and_eval(self.model): 168 | x_adv, _ = self.attack.perturb(x, y) 169 | 170 | self.optimizer.zero_grad() 171 | if self.params.keep_clean: 172 | x_adv = torch.cat((x, x_adv), dim=0) 173 | y_adv = torch.cat((y, y), dim=0) 174 | else: 175 | y_adv = y 176 | out = self.model(x_adv) 177 | loss = self.criterion(out, y_adv) 178 | 179 | preds = out.detach() 180 | batch_metrics = {'loss': loss.item()} 181 | if self.params.keep_clean: 182 | preds_clean, preds_adv = preds[:len(x)], preds[len(x):] 183 | batch_metrics.update({'clean_acc': accuracy(y, preds_clean), 'adversarial_acc': accuracy(y, preds_adv)}) 184 | else: 185 | batch_metrics.update({'adversarial_acc': accuracy(y, preds)}) 186 | return loss, batch_metrics 187 | 188 | 189 | def hat_loss(self, x, y, h, beta=1.0, gamma=1.0): 190 | """ 191 | Helper-based adversarial training. 192 | """ 193 | if self.params.robust_loss == 'kl': 194 | loss, batch_metrics = hat_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 195 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 196 | h=h, beta=beta, gamma=gamma, attack=self.params.attack, hr_model=self.hr_model) 197 | else: 198 | loss, batch_metrics = at_hat_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 199 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 200 | h=h, beta=beta, gamma=gamma, attack=self.params.attack, hr_model=self.hr_model) 201 | return loss, batch_metrics 202 | 203 | 204 | def trades_loss(self, x, y, beta): 205 | """ 206 | TRADES training. 207 | """ 208 | loss, batch_metrics = trades_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 209 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 210 | beta=beta, attack=self.params.attack) 211 | return loss, batch_metrics 212 | 213 | 214 | def mart_loss(self, x, y, beta): 215 | """ 216 | MART training. 
217 | """ 218 | loss, batch_metrics = mart_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 219 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 220 | beta=beta, attack=self.params.attack) 221 | return loss, batch_metrics 222 | 223 | 224 | def eval(self, dataloader, adversarial=False): 225 | """ 226 | Evaluate performance of the model. 227 | """ 228 | acc = 0.0 229 | self.model.eval() 230 | 231 | for x, y in dataloader: 232 | x, y = x.to(device), y.to(device) 233 | if adversarial: 234 | with ctx_noparamgrad_and_eval(self.model): 235 | x_adv, _ = self.eval_attack.perturb(x, y) 236 | out = self.model(x_adv) 237 | else: 238 | out = self.model(x) 239 | acc += accuracy(y, out) 240 | acc /= len(dataloader) 241 | return acc 242 | 243 | 244 | def save_and_eval_adversarial(self, dataloader, save, verbose=False, to_true=False, save_all=True): 245 | """ 246 | Evaluate adversarial accuracy and save perturbations. 247 | """ 248 | if save_all: 249 | x_all, y_all = [], [] 250 | r_adv_all = [] 251 | acc_adv = 0.0 252 | self.eval_attack.targeted = False 253 | self.model.eval() 254 | 255 | for x, y in tqdm(dataloader, disable=not verbose): 256 | x, y = x.to(device), y.to(device) 257 | with ctx_noparamgrad_and_eval(self.model): 258 | if to_true: 259 | pred_y_orig = self.model(x).argmax(dim=1) 260 | correct_ind = pred_y_orig == y 261 | 262 | x_adv, r_adv = torch.zeros(x.shape).to(device), torch.zeros(x.shape).to(device) 263 | self.eval_attack.targeted = False 264 | x_adv1, r_adv1 = self.eval_attack.perturb(x[correct_ind], y[correct_ind]) 265 | self.eval_attack.targeted = True 266 | x_adv0, r_adv0 = self.eval_attack.perturb(x[~correct_ind], y[~correct_ind]) 267 | x_adv[correct_ind], r_adv[correct_ind] = x_adv1, r_adv1 268 | x_adv[~correct_ind], r_adv[~correct_ind] = x_adv0, r_adv0 269 | else: 270 | x_adv, r_adv = self.eval_attack.perturb(x) 271 | 272 | out = self.model(x_adv) 273 | acc_adv += accuracy(y, out) 274 | if save_all: 275 | 
x_all.append(x.cpu().numpy()) 276 | y_all.extend(y.cpu().numpy()) 277 | r_adv_all.append((r_adv).cpu().numpy()) 278 | 279 | acc_adv /= len(dataloader) 280 | if save: 281 | r_adv_all = np.vstack(r_adv_all) 282 | if save_all: 283 | x_all = np.vstack(x_all) 284 | np_save({ 'x': x_all, 'r': r_adv_all, 'y': y_all }, save) 285 | else: 286 | np_save({ 'r': r_adv_all }, save) 287 | 288 | self.eval_attack.targeted = False 289 | return acc_adv 290 | 291 | 292 | def set_bn_to_eval(self): 293 | """ 294 | Set all batch normalization layers to evaluation mode. 295 | """ 296 | for m in self.model.modules(): 297 | if isinstance(m, nn.modules.BatchNorm2d): 298 | m.eval() 299 | 300 | 301 | def save_model(self, path): 302 | """ 303 | Save model weights. 304 | """ 305 | torch.save({'model_state_dict': self.model.state_dict()}, path) 306 | 307 | 308 | def load_model(self, path, load_opt=True): 309 | """ 310 | Load model weights. 311 | """ 312 | checkpoint = torch.load(path) 313 | if 'model_state_dict' not in checkpoint: 314 | raise RuntimeError('Model weights not found at {}.'.format(path)) 315 | self.model.load_state_dict(checkpoint['model_state_dict']) 316 | -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import numpy as np 5 | import _pickle as pickle 6 | 7 | import torch 8 | 9 | 10 | class SmoothCrossEntropyLoss(torch.nn.Module): 11 | """ 12 | Soft cross entropy loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.0, reduction='mean'): 15 | super(SmoothCrossEntropyLoss, self).__init__() 16 | self.smoothing = smoothing 17 | self.reduction = reduction 18 | 19 | def forward(self, input, target): 20 | num_classes = input.shape[1] 21 | if target.ndim == 1: 22 | target = torch.nn.functional.one_hot(target, num_classes) 23 | target = (1. 
- self.smoothing) * target + self.smoothing / num_classes 24 | logprobs = torch.nn.functional.log_softmax(input, dim=1) 25 | loss = - (target * logprobs).sum(dim=1) 26 | if self.reduction == 'mean': 27 | return loss.mean() 28 | elif self.reduction == 'sum': 29 | return loss.sum() 30 | return loss 31 | 32 | 33 | def track_bn_stats(model, track_stats=True): 34 | """ 35 | If track_stats=False, do not update BN running mean and variance and vice versa. 36 | """ 37 | for module in model.modules(): 38 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 39 | module.track_running_stats = track_stats 40 | 41 | def set_bn_momentum(model, momentum=1): 42 | """ 43 | Set the value of momentum for all BN layers. 44 | """ 45 | for module in model.modules(): 46 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 47 | module.momentum = momentum 48 | 49 | 50 | def str2bool(v): 51 | """ 52 | Parse boolean using argument parser. 53 | """ 54 | if isinstance(v, bool): 55 | return v 56 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 57 | return True 58 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 59 | return False 60 | else: 61 | raise argparse.ArgumentTypeError('Boolean value expected.') 62 | 63 | def str2float(x): 64 | """ 65 | Parse float and fractions using argument parser. 66 | """ 67 | if '/' in x: 68 | n, d = x.split('/') 69 | return float(n)/float(d) 70 | else: 71 | try: 72 | return float(x) 73 | except: 74 | raise argparse.ArgumentTypeError('Fraction or float value expected.') 75 | 76 | 77 | def format_time(elapsed): 78 | """ 79 | Format time for displaying. 80 | Arguments: 81 | elapsed: time interval in seconds. 82 | """ 83 | elapsed_rounded = int(round((elapsed))) 84 | return str(datetime.timedelta(seconds=elapsed_rounded)) 85 | 86 | 87 | def seed(seed=1): 88 | """ 89 | Seed for PyTorch reproducibility. 90 | Arguments: 91 | seed (int): Random seed value. 
92 | """ 93 | np.random.seed(seed) 94 | torch.manual_seed(seed) 95 | torch.cuda.manual_seed_all(seed) 96 | 97 | 98 | def unpickle_data(filename, mode='rb'): 99 | """ 100 | Read data from pickled file. 101 | Arguments: 102 | filename (str): path to the pickled file. 103 | mode (str): read mode. 104 | """ 105 | with open(filename, mode) as pkfile: 106 | data = pickle.load(pkfile) 107 | return data 108 | 109 | 110 | def pickle_data(data, filename, mode='wb'): 111 | """ 112 | Write data to pickled file. 113 | Arguments: 114 | data (Any): data to be written. 115 | filename (str): path to the pickled file. 116 | mode (str): write mode. 117 | """ 118 | with open(filename, mode) as pkfile: 119 | pickle.dump(data, pkfile) 120 | 121 | 122 | def np_load(foldername): 123 | """ 124 | Read data from npy files. 125 | Arguments: 126 | foldername (str): path to the folder. 127 | """ 128 | data = {} 129 | if os.path.isfile(os.path.join(foldername, 'x.npy')): 130 | x = np.load(os.path.join(foldername, 'x.npy')) 131 | data['x'] = x 132 | if os.path.isfile(os.path.join(foldername, 'y.npy')): 133 | y = np.load(os.path.join(foldername, 'y.npy')) 134 | data['y'] = y 135 | if os.path.isfile(os.path.join(foldername, 'r.npy')): 136 | r = np.load(os.path.join(foldername, 'r.npy')) 137 | data['r'] = r 138 | return data 139 | 140 | 141 | def np_save(data, foldername): 142 | """ 143 | Write data as npy files. 144 | Arguments: 145 | data (Dict): data to be written. 146 | foldername (str): path to the folder. 147 | """ 148 | if not os.path.exists(foldername): 149 | os.makedirs(foldername, exist_ok=True) 150 | if 'x' in data: 151 | np.save(os.path.join(foldername, 'x.npy'), data['x']) 152 | if 'y' in data: 153 | np.save(os.path.join(foldername, 'y.npy'), data['y']) 154 | if 'r' in data: 155 | np.save(os.path.join(foldername, 'r.npy'), data['r']) 156 | 157 | 158 | class NumpyToTensor(object): 159 | """ 160 | Transforms a numpy.ndarray to torch.Tensor. 
161 | """ 162 | def __call__(self, sample): 163 | return torch.from_numpy(sample) 164 | -------------------------------------------------------------------------------- /eval-aa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation with AutoAttack. 3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from autoattack import AutoAttack 18 | 19 | from core.data import get_data_info 20 | from core.data import load_data 21 | from core.models import create_model 22 | 23 | from core.utils import Logger 24 | from core.utils import parser_eval 25 | from core.utils import seed 26 | from core import setup 27 | 28 | 29 | 30 | # Setup 31 | 32 | parse = parser_eval() 33 | args = parse.parse_args() 34 | 35 | LOG_DIR = os.path.join(args.log_dir, args.desc) 36 | with open(os.path.join(LOG_DIR, 'args.txt'), 'r') as f: 37 | old = json.load(f) 38 | old['data_dir'], old['log_dir'] = args.data_dir, args.log_dir 39 | args.__dict__ = dict(vars(args), **old) 40 | 41 | DATA_DIR = os.path.join(args.data_dir, args.data) 42 | LOG_DIR = os.path.join(args.log_dir, args.desc) 43 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 44 | 45 | if 'imagenet' in args.data: 46 | setup.setup_train(DATA_DIR) 47 | setup.setup_val(DATA_DIR) 48 | args.data_dir = os.environ['TMPDIR'] 49 | DATA_DIR = os.path.join(args.data_dir, args.data) 50 | 51 | log_path = os.path.join(LOG_DIR, 'log-aa.log') 52 | logger = Logger(log_path) 53 | 54 | info = get_data_info(DATA_DIR) 55 | BATCH_SIZE = args.batch_size 56 | BATCH_SIZE_VALIDATION = args.batch_size_validation 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | 59 | logger.log('Using device: {}'.format(device)) 60 | 61 | 62 | 63 | # Load data 64 | 65 | seed(args.seed) 66 | _, _, train_dataloader, test_dataloader = 
load_data(DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=False, 67 | shuffle_train=False) 68 | 69 | if args.train: 70 | logger.log('Evaluating on training set.') 71 | l = [x for (x, y) in train_dataloader] 72 | x_test = torch.cat(l, 0) 73 | l = [y for (x, y) in train_dataloader] 74 | y_test = torch.cat(l, 0) 75 | else: 76 | l = [x for (x, y) in test_dataloader] 77 | x_test = torch.cat(l, 0) 78 | l = [y for (x, y) in test_dataloader] 79 | y_test = torch.cat(l, 0) 80 | 81 | 82 | 83 | # Model 84 | 85 | model = create_model(args.model, args.normalize, info, device) 86 | checkpoint = torch.load(WEIGHTS) 87 | if 'tau' in args and args.tau: 88 | print ('Using WA model.') 89 | model.load_state_dict(checkpoint['model_state_dict']) 90 | model.eval() 91 | del checkpoint 92 | 93 | 94 | 95 | # AA Evaluation 96 | 97 | seed(args.seed) 98 | norm = 'Linf' if args.attack in ['fgsm', 'linf-pgd', 'linf-df'] else 'L2' 99 | adversary = AutoAttack(model, norm=norm, eps=args.attack_eps, log_path=log_path, version=args.version, seed=args.seed) 100 | 101 | if args.version == 'custom': 102 | adversary.attacks_to_run = ['apgd-ce', 'apgd-t'] 103 | adversary.apgd.n_restarts = 1 104 | adversary.apgd_targeted.n_restarts = 1 105 | 106 | with torch.no_grad(): 107 | x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=BATCH_SIZE_VALIDATION) 108 | 109 | print ('Script Completed.') 110 | -------------------------------------------------------------------------------- /eval-adv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Evaluation with PGD+, CW (Margin) PGD and black box adversary. 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | from tqdm import tqdm as tqdm 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from core.attacks import create_attack 19 | from core.attacks import CWLoss 20 | 21 | from core.data import get_data_info 22 | from core.data import load_data 23 | 24 | from core.models import create_model 25 | 26 | from core.utils import ctx_noparamgrad_and_eval 27 | from core.utils import Logger 28 | from core.utils import parser_eval 29 | from core.utils import seed 30 | from core.utils import Trainer 31 | from core import setup 32 | 33 | 34 | # Setup 35 | 36 | parse = parser_eval() 37 | args = parse.parse_args() 38 | 39 | LOG_DIR = os.path.join(args.log_dir, args.desc) 40 | with open(os.path.join(LOG_DIR, 'args.txt'), 'r') as f: 41 | old = json.load(f) 42 | old['data_dir'], old['log_dir'] = args.data_dir, args.log_dir 43 | args.__dict__ = dict(vars(args), **old) 44 | 45 | DATA_DIR = os.path.join(args.data_dir, args.data) 46 | LOG_DIR = os.path.join(args.log_dir, args.desc) 47 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 48 | 49 | if 'imagenet' in args.data: 50 | setup.setup_train(DATA_DIR) 51 | setup.setup_val(DATA_DIR) 52 | args.data_dir = os.environ['TMPDIR'] 53 | DATA_DIR = os.path.join(args.data_dir, args.data) 54 | 55 | logger = Logger(os.path.join(LOG_DIR, 'log-adv.log')) 56 | 57 | info = get_data_info(DATA_DIR) 58 | BATCH_SIZE = args.batch_size 59 | BATCH_SIZE_VALIDATION = args.batch_size_validation 60 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 61 | 62 | logger.log('Using device: {}'.format(device)) 63 | 64 | 65 | 66 | # Load data 67 | 68 | seed(args.seed) 69 | _, _, _, test_dataloader = load_data(DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=False, 70 | shuffle_train=False) 71 | 72 | 73 | 74 | # Helper function 75 | 76 | def eval_multiple_restarts(attack, 
model, dataloader, num_restarts=5, verbose=True): 77 | """ 78 | Evaluate adversarial accuracy with multiple restarts. 79 | """ 80 | model.eval() 81 | N = len(dataloader.dataset) 82 | is_correct = torch.ones(N).bool().to(device) 83 | for i in tqdm(range(0, num_restarts), disable=not verbose): 84 | iter_is_correct = [] 85 | for x, y in tqdm(dataloader): 86 | x, y = x.to(device), y.to(device) 87 | with ctx_noparamgrad_and_eval(model): 88 | x_adv, _ = attack.perturb(x, y) 89 | out = model(x_adv) 90 | iter_is_correct.extend(torch.softmax(out, dim=1).argmax(dim=1) == y) 91 | is_correct = torch.logical_and(is_correct, torch.BoolTensor(iter_is_correct).to(device)) 92 | 93 | adv_acc = (is_correct.sum().float()/N).item() 94 | return adv_acc 95 | 96 | def eval_multiple_restarts_advertorch(attack, model, dataloader, num_restarts=1, verbose=True): 97 | """ 98 | Evaluate adversarial accuracy with multiple restarts (Advertorch). 99 | """ 100 | model.eval() 101 | N = len(dataloader.dataset) 102 | is_correct = torch.ones(N).bool().to(device) 103 | for i in tqdm(range(0, num_restarts), disable=not verbose): 104 | iter_is_correct = [] 105 | for x, y in tqdm(dataloader): 106 | x, y = x.to(device), y.to(device) 107 | with ctx_noparamgrad_and_eval(model): 108 | x_adv = attack.perturb(x, y) 109 | out = model(x_adv) 110 | iter_is_correct.extend(torch.softmax(out, dim=1).argmax(dim=1) == y) 111 | is_correct = torch.logical_and(is_correct, torch.BoolTensor(iter_is_correct).to(device)) 112 | 113 | adv_acc = (is_correct.sum().float()/N).item() 114 | return adv_acc 115 | 116 | 117 | 118 | # PGD Evaluation 119 | 120 | seed(args.seed) 121 | trainer = Trainer(info, args) 122 | if 'tau' in args and args.tau: 123 | print ('Using WA model.') 124 | trainer.load_model(WEIGHTS) 125 | trainer.model.eval() 126 | 127 | test_acc = trainer.eval(test_dataloader) 128 | logger.log('\nStandard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100)) 129 | 130 | 131 | 132 | if args.wb: 133 | # CW-PGD-40 Evaluation 134 | 
seed(args.seed) 135 | num_restarts = 1 136 | if args.attack in ['linf-pgd', 'linf-df', 'fgsm']: 137 | args.attack_iter, args.attack_step = 40, 0.01 138 | else: 139 | args.attack_iter, args.attack_step = 40, 30/255.0 140 | assert args.attack in ['linf-pgd', 'l2-pgd'], 'CW evaluation only supported for attack=linf-pgd or attack=l2-pgd !' 141 | attack = create_attack(trainer.model, CWLoss, args.attack, args.attack_eps, args.attack_iter, args.attack_step) 142 | logger.log('\n==== CW-PGD Evaluation. ====') 143 | logger.log('Attack: cw-{}.'.format(args.attack)) 144 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 145 | args.attack_eps, 146 | args.attack_iter)) 147 | 148 | test_adv_acc1 = eval_multiple_restarts(attack, trainer.model, test_dataloader, num_restarts, verbose=False) 149 | logger.log('Adversarial Accuracy-\tTest: {:.2f}%.'.format(test_adv_acc1*100)) 150 | 151 | 152 | # PGD-40 (with 5 restarts) Evaluation 153 | seed(args.seed) 154 | num_restarts = 5 155 | if args.attack in ['linf-pgd', 'linf-df', 'fgsm']: 156 | args.attack_iter, args.attack_step = 40, 0.01 157 | else: 158 | args.attack_iter, args.attack_step = 40, 30/255.0 159 | attack = create_attack(trainer.model, trainer.criterion, args.attack, args.attack_eps, args.attack_iter, args.attack_step) 160 | logger.log('\n==== PGD+ Evaluation. 
====') 161 | logger.log('Attack: {} with {} restarts.'.format(args.attack, num_restarts)) 162 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 163 | args.attack_eps, 164 | args.attack_iter)) 165 | 166 | test_adv_acc2 = eval_multiple_restarts(attack, trainer.model, test_dataloader, num_restarts, verbose=True) 167 | logger.log('Adversarial Accuracy-\tTest: {:.2f}%.'.format(test_adv_acc2*100)) 168 | 169 | 170 | 171 | # Black Box Evaluation 172 | 173 | class dotdict(dict): 174 | def __getattr__(self, name): 175 | return self[name] 176 | 177 | if args.source != None: 178 | seed(args.seed) 179 | assert args.attack in ['linf-pgd', 'l2-pgd'], 'Black-box evaluation only supported for attack=linf-pgd or attack=l2-pgd!' 180 | if args.attack in ['linf-pgd', 'linf-df', 'fgsm']: 181 | args.attack_iter, args.attack_step = 40, 0.01 182 | else: 183 | args.attack_iter, args.attack_step = 40, 30/255.0 184 | 185 | SRC_LOG_DIR = args.log_dir + args.source 186 | with open(os.path.join(SRC_LOG_DIR, 'args.txt'), 'r') as f: 187 | src_args = json.load(f) 188 | src_args = dotdict(src_args) 189 | 190 | src_model = create_model(src_args.model, src_args.normalize, info, device) 191 | src_model.load_state_dict(torch.load(os.path.join(SRC_LOG_DIR, 'weights-best.pt'))['model_state_dict']) 192 | src_model.eval() 193 | 194 | src_attack = create_attack(src_model, trainer.criterion, args.attack, args.attack_eps, args.attack_iter, args.attack_step) 195 | adv_acc = 0.0 196 | for x, y in test_dataloader: 197 | x, y = x.to(device), y.to(device) 198 | with ctx_noparamgrad_and_eval(src_model): 199 | x_adv, _ = src_attack.perturb(x, y) 200 | out = trainer.model(x_adv) 201 | adv_acc += accuracy(y, out) 202 | adv_acc /= len(test_dataloader) 203 | 204 | logger.log('\n==== Black-box Evaluation. 
====') 205 | logger.log('Source Model: {}.'.format(args.source)) 206 | logger.log('Attack: {}.'.format(args.attack)) 207 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 208 | args.attack_eps, 209 | args.attack_iter)) 210 | logger.log('Black-box Adv. Accuracy-\tTest: {:.2f}%.'.format(adv_acc*100)) 211 | del src_attack, src_model 212 | 213 | 214 | logger.log('Script Completed.') 215 | -------------------------------------------------------------------------------- /eval-rb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation with Robustbench. 3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from robustbench import benchmark 18 | 19 | from core.data import get_data_info 20 | from core.models import create_model 21 | 22 | from core.utils import Logger 23 | from core.utils import parser_eval 24 | from core.utils import seed 25 | 26 | 27 | 28 | # Setup 29 | 30 | parse = parser_eval() 31 | 32 | args = parse.parse_args() 33 | 34 | LOG_DIR = os.path.join(args.log_dir, args.desc) 35 | with open(os.path.join(LOG_DIR, 'args.txt'), 'r') as f: 36 | old = json.load(f) 37 | old['data_dir'], old['log_dir'] = args.data_dir, args.log_dir 38 | args.__dict__ = dict(vars(args), **old) 39 | 40 | args.data = args.data[:-1] if args.data in ['cifar10s', 'cifar100s'] else args.data 41 | DATA_DIR = os.path.join(args.data_dir, args.data) 42 | LOG_DIR = os.path.join(args.log_dir, args.desc) 43 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 44 | 45 | log_path = os.path.join(LOG_DIR, f'log-corr-{args.threat}.log') 46 | logger = Logger(log_path) 47 | 48 | info = get_data_info(DATA_DIR) 49 | BATCH_SIZE = args.batch_size 50 | BATCH_SIZE_VALIDATION = args.batch_size_validation 51 | device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu') 52 | 53 | assert args.data in ['cifar10', 'cifar100'], 'Evaluation on Robustbench is only supported for cifar10, cifar100!' 54 | 55 | threat_model = args.threat 56 | dataset = args.data 57 | model_name = args.desc 58 | 59 | 60 | 61 | # Model 62 | 63 | model = create_model(args.model, args.normalize, info, device) 64 | checkpoint = torch.load(WEIGHTS) 65 | if 'tau' in args and args.tau: 66 | print ('Using WA model.') 67 | model.load_state_dict(checkpoint['model_state_dict']) 68 | model.eval() 69 | del checkpoint 70 | 71 | 72 | 73 | # Common corruptions 74 | 75 | seed(args.seed) 76 | clean_acc, robust_acc = benchmark(model, model_name=model_name, n_examples=args.num_samples, dataset=dataset, 77 | threat_model=threat_model, eps=args.attack_eps, device=device, to_disk=False, 78 | data_dir=os.path.join(args.data_dir, f'{args.data}c')) 79 | 80 | 81 | logger.log('Model: {}'.format(args.desc)) 82 | logger.log('Evaluating robustness on {} with threat model={}.'.format(args.data, args.threat)) 83 | logger.log('Clean Accuracy: \t{:.2f}%.\nRobust Accuracy: \t{:.2f}%.'.format(clean_acc*100, robust_acc*100)) 84 | -------------------------------------------------------------------------------- /gowal21uncovering/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .watrain import WATrainer -------------------------------------------------------------------------------- /gowal21uncovering/utils/hat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from core.attacks import create_attack 9 | from core.metrics import accuracy 10 | from core.utils.context import ctx_noparamgrad_and_eval 11 | from core.utils import SmoothCrossEntropyLoss 12 | from core.utils import track_bn_stats 13 | 14 | 15 | def 
def hat_loss(model, x, y, optimizer, step_size=0.007, epsilon=0.031, perturb_steps=10, h=3.5, beta=1.0, gamma=1.0,
             attack='linf-pgd', hr_model=None, label_smoothing=0.1):
    """
    TRADES + Helper-based adversarial training.

    Arguments:
        model: network being trained.
        x, y: clean batch and labels.
        optimizer: optimizer for `model` (zeroed here; gradients for the
            clean/adversarial terms are accumulated via the partial backward below).
        step_size, epsilon, perturb_steps: inner PGD attack parameters.
        h: extrapolation factor for the helper example x_hr = x + h*(x_adv - x).
        beta: weight on the TRADES KL term; gamma: weight on the helper term.
        attack: 'linf-pgd' or 'l2-pgd'.
        hr_model: frozen (standard-trained) model providing helper labels.
        label_smoothing: smoothing for the clean cross-entropy.
    Returns:
        (loss, batch_metrics) — `loss` is the *helper* term only; the clean and
        adversarial terms have already been back-propagated inside this function.
    """
    criterion_ce = SmoothCrossEntropyLoss(reduction='mean', smoothing=label_smoothing)
    criterion_kl = nn.KLDivLoss(reduction='sum')
    model.train()
    track_bn_stats(model, False)

    # Random start inside the eps-ball. empty_like keeps the noise on x's
    # device/dtype — the original torch.FloatTensor(...).cuda() crashed on
    # CPU-only hosts and ignored non-default devices.
    x_adv = x.detach() + torch.empty_like(x).uniform_(-epsilon, epsilon).detach()
    x_adv = torch.clamp(x_adv, 0.0, 1.0)
    p_natural = F.softmax(model(x), dim=1).detach()

    if attack == 'linf-pgd':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif attack == 'l2-pgd':
        # Random start with l2 radius uniform in [0, epsilon].
        delta = torch.empty_like(x).normal_(mean=0, std=1.0).detach()
        delta.data = delta.data * np.random.uniform(0.0, epsilon) / (delta.data**2).sum([1, 2, 3], keepdim=True)**0.5
        delta = Variable(delta.data, requires_grad=True)

        batch_size = len(x)
        optimizer_delta = torch.optim.SGD([delta], lr=step_size)
        for _ in range(perturb_steps):
            adv = x + delta
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural)
            loss.backward(retain_graph=True)
            # Per-sample gradient normalization; re-randomize zero gradients to
            # avoid nan/inf from the division.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
            optimizer_delta.step()
            # Project back to the valid image range and the eps-ball.
            delta.data.add_(x)
            delta.data.clamp_(0, 1).sub_(x)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x + delta, requires_grad=False)
    else:
        raise ValueError(f'Attack={attack} not supported for TRADES training!')
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # Helper example: extrapolate h times further along the adversarial direction.
    x_hr = x + h * (x_adv - x)
    # Helper label: the frozen standard model's prediction on the adversarial example.
    with ctx_noparamgrad_and_eval(hr_model):
        y_hr = hr_model(x_adv).argmax(dim=1)

    optimizer.zero_grad()
    track_bn_stats(model, True)

    # a hack to save memory when using large batch sizes:
    # first, calculate gradients with clean and adversarial samples;
    # then, clear intermediate activations and calculate gradients with helper
    # samples. one can use a single .backward() at the expense of higher memory usage.
    out_clean = model(x)
    out_adv = model(x_adv)
    loss_clean = criterion_ce(out_clean, y)
    loss_adv = (1/len(x)) * criterion_kl(F.log_softmax(out_adv, dim=1), F.softmax(out_clean, dim=1))
    loss = loss_clean + beta * loss_adv
    total_loss = loss.item()
    loss.backward()

    out_help = model(x_hr)
    loss = gamma * F.cross_entropy(out_help, y_hr, reduction='mean')
    total_loss += loss.item()

    batch_metrics = {'loss': total_loss}
    batch_metrics.update({'adversarial_acc': accuracy(y, out_adv.detach()), 'helper_acc': accuracy(y_hr, out_help.detach())})
    batch_metrics.update({'clean_acc': accuracy(y, out_clean.detach())})
    return loss, batch_metrics

# (def at_hat_loss — function reconstructed on the next chunk)
100 | """ 101 | criterion_ce_smooth = SmoothCrossEntropyLoss(reduction='mean', smoothing=label_smoothing) 102 | criterion_ce = nn.CrossEntropyLoss() 103 | model.train() 104 | track_bn_stats(model, False) 105 | 106 | attack = create_attack(model, criterion_ce, attack, epsilon, perturb_steps, step_size) 107 | with ctx_noparamgrad_and_eval(model): 108 | x_adv, _ = attack.perturb(x, y) 109 | 110 | model.train() 111 | 112 | x_hr = x + h * (x_adv - x) 113 | with ctx_noparamgrad_and_eval(hr_model): 114 | y_hr = hr_model(x_adv).argmax(dim=1) 115 | 116 | optimizer.zero_grad() 117 | track_bn_stats(model, True) 118 | 119 | out_clean, out_adv, out_help = model(x), model(x_adv), model(x_hr) 120 | loss_clean = criterion_ce_smooth(out_clean, y) 121 | loss_adv = criterion_ce(out_adv, y) 122 | loss_help = F.cross_entropy(out_help, y_hr, reduction='mean') 123 | loss = loss_clean + beta * loss_adv + gamma * loss_help 124 | 125 | batch_metrics = {'loss': loss.item()} 126 | batch_metrics.update({'adversarial_acc': accuracy(y, out_adv.detach()), 'helper_acc': accuracy(y_hr, out_help.detach())}) 127 | batch_metrics.update({'clean_acc': accuracy(y, out_clean.detach())}) 128 | return loss, batch_metrics 129 | -------------------------------------------------------------------------------- /gowal21uncovering/utils/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | 7 | from core.metrics import accuracy 8 | from core.utils import SmoothCrossEntropyLoss 9 | from core.utils import track_bn_stats 10 | 11 | 12 | def squared_l2_norm(x): 13 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 14 | return (flattened ** 2).sum(1) 15 | 16 | 17 | def l2_norm(x): 18 | return squared_l2_norm(x).sqrt() 19 | 20 | 21 | def trades_loss(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, 
beta=1.0, 22 | attack='linf-pgd', label_smoothing=0.1): 23 | """ 24 | TRADES training (Zhang et al, 2019). 25 | """ 26 | 27 | criterion_ce = SmoothCrossEntropyLoss(reduction='mean', smoothing=label_smoothing) 28 | criterion_kl = nn.KLDivLoss(reduction='sum') 29 | model.train() 30 | track_bn_stats(model, False) 31 | batch_size = len(x_natural) 32 | 33 | x_adv = x_natural.detach() + torch.FloatTensor(x_natural.shape).uniform_(-epsilon, epsilon).cuda().detach() 34 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 35 | p_natural = F.softmax(model(x_natural), dim=1).detach() 36 | 37 | if attack == 'linf-pgd': 38 | for _ in range(perturb_steps): 39 | x_adv.requires_grad_() 40 | with torch.enable_grad(): 41 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural) 42 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 43 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 44 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 45 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 46 | 47 | elif attack == 'l2-pgd': 48 | delta = torch.FloatTensor(x.shape).normal_(mean=0, std=1.0).cuda().detach() 49 | delta.data = delta.data * np.random.uniform(0.0, epsilon) / (delta.data**2).sum([1, 2, 3], keepdim=True)**0.5 50 | delta = Variable(delta.data, requires_grad=True) 51 | 52 | # Setup optimizers 53 | optimizer_delta = optim.SGD([delta], lr=step_size) 54 | 55 | for _ in range(perturb_steps): 56 | adv = x_natural + delta 57 | 58 | # optimize 59 | optimizer_delta.zero_grad() 60 | with torch.enable_grad(): 61 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural) 62 | loss.backward(retain_graph=True) 63 | # renorming gradient 64 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 65 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 66 | # avoid nan or inf if gradient is 0 67 | if (grad_norms == 0).any(): 68 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 69 | optimizer_delta.step() 70 | 71 | # 
projection 72 | delta.data.add_(x_natural) 73 | delta.data.clamp_(0, 1).sub_(x_natural) 74 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 75 | x_adv = Variable(x_natural + delta, requires_grad=False) 76 | else: 77 | raise ValueError(f'Attack={attack} not supported for TRADES training!') 78 | 79 | model.train() 80 | track_bn_stats(model, True) 81 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 82 | 83 | optimizer.zero_grad() 84 | logits_natural = model(x_natural) 85 | logits_adv = model(x_adv) 86 | loss_natural = criterion_ce(logits_natural, y) 87 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 88 | F.softmax(logits_natural, dim=1)) 89 | loss = loss_natural + beta * loss_robust 90 | 91 | batch_metrics = {'loss': loss.item()} 92 | batch_metrics.update({'clean_acc': accuracy(y, logits_natural.detach()), 93 | 'adversarial_acc': accuracy(y, logits_adv.detach())}) 94 | return loss, batch_metrics 95 | -------------------------------------------------------------------------------- /gowal21uncovering/utils/watrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from tqdm import tqdm as tqdm 4 | 5 | import copy 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from core.attacks import create_attack 12 | from core.attacks import CWLoss 13 | from core.metrics import accuracy 14 | from core.models import create_model 15 | 16 | from core.utils import ctx_noparamgrad_and_eval 17 | from core.utils import Trainer 18 | from core.utils import set_bn_momentum 19 | from core.utils import seed 20 | 21 | from .hat import at_hat_loss 22 | from .hat import hat_loss 23 | from .trades import trades_loss 24 | 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | 29 | class WATrainer(Trainer): 30 | """ 31 | Helper class for training a deep neural network 
with model weight averaging (identical to Gowal et al, 2020). 32 | Arguments: 33 | info (dict): dataset information. 34 | args (dict): input arguments. 35 | """ 36 | def __init__(self, info, args): 37 | super(WATrainer, self).__init__(info, args) 38 | 39 | seed(args.seed) 40 | self.wa_model = copy.deepcopy(self.model) 41 | self.eval_attack = create_attack(self.wa_model, CWLoss, args.attack, args.attack_eps, 4*args.attack_iter, 42 | args.attack_step) 43 | num_samples = 50000 if 'cifar' in self.params.data else 73257 44 | num_samples = 100000 if 'tiny-imagenet' in self.params.data else num_samples 45 | self.update_steps = int(np.floor(num_samples/self.params.batch_size) + 1) 46 | self.warmup_steps = 0 47 | if self.params.scheduler in ['cosinew']: 48 | self.warmup_steps = 0.025 * self.params.num_adv_epochs * self.update_steps 49 | self.num_classes = info['num_classes'] 50 | 51 | 52 | def init_optimizer(self, num_epochs): 53 | """ 54 | Initialize optimizer and schedulers. 55 | """ 56 | def group_weight(model): 57 | group_decay = [] 58 | group_no_decay = [] 59 | for n, p in model.named_parameters(): 60 | if 'batchnorm' in n: 61 | group_no_decay.append(p) 62 | else: 63 | group_decay.append(p) 64 | assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay) 65 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)] 66 | return groups 67 | 68 | self.optimizer = torch.optim.SGD(group_weight(self.model), lr=self.params.lr, weight_decay=self.params.weight_decay, 69 | momentum=0.9, nesterov=self.params.nesterov) 70 | if num_epochs <= 0: 71 | return 72 | self.init_scheduler(num_epochs) 73 | 74 | 75 | def train(self, dataloader, epoch=0, adversarial=False, verbose=True): 76 | """ 77 | Run one epoch of training. 
78 | """ 79 | metrics = pd.DataFrame() 80 | self.model.train() 81 | 82 | update_iter = 0 83 | for data in tqdm(dataloader, desc='Epoch {}: '.format(epoch), disable=not verbose): 84 | global_step = (epoch - 1) * self.update_steps + update_iter 85 | if global_step == 0: 86 | # make BN running mean and variance init same as Haiku 87 | set_bn_momentum(self.model, momentum=1.0) 88 | elif global_step == 1: 89 | set_bn_momentum(self.model, momentum=0.01) 90 | update_iter += 1 91 | 92 | x, y = data 93 | x, y = x.to(device), y.to(device) 94 | 95 | if adversarial: 96 | if self.params.h is not None and self.params.beta is not None: 97 | loss, batch_metrics = self.hat_loss(x, y, h=self.params.h, beta=self.params.beta, gamma=self.params.gamma) 98 | elif self.params.beta is not None and self.params.mart: 99 | loss, batch_metrics = self.mart_loss(x, y, beta=self.params.beta) 100 | elif self.params.beta is not None: 101 | loss, batch_metrics = self.trades_loss(x, y, beta=self.params.beta) 102 | else: 103 | loss, batch_metrics = self.adversarial_loss(x, y) 104 | else: 105 | loss, batch_metrics = self.standard_loss(x, y) 106 | 107 | loss.backward() 108 | if self.params.clip_grad: 109 | nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_grad) 110 | self.optimizer.step() 111 | if self.params.scheduler in ['cyclic']: 112 | self.scheduler.step() 113 | 114 | global_step = (epoch - 1) * self.update_steps + update_iter 115 | ema_update(self.wa_model, self.model, global_step, decay_rate=self.params.tau, 116 | warmup_steps=self.warmup_steps, dynamic_decay=True) 117 | metrics = metrics.append(pd.DataFrame(batch_metrics, index=[0]), ignore_index=True) 118 | 119 | if self.params.scheduler in ['step', 'converge', 'cosine', 'cosinew']: 120 | self.scheduler.step() 121 | 122 | update_bn(self.wa_model, self.model) 123 | return dict(metrics.mean()) 124 | 125 | 126 | def hat_loss(self, x, y, h, beta=1.0, gamma=1.0): 127 | """ 128 | Helper-based adversarial training. 
129 | """ 130 | other_args = {'label_smoothing': self.params.label_smoothing} 131 | if self.params.robust_loss == 'kl': 132 | loss, batch_metrics = hat_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 133 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 134 | h=h, beta=beta, gamma=gamma, attack=self.params.attack, hr_model=self.hr_model, 135 | **other_args) 136 | else: 137 | loss, batch_metrics = at_hat_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 138 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 139 | h=h, beta=beta, gamma=gamma, attack=self.params.attack, hr_model=self.hr_model, 140 | **other_args) 141 | return loss, batch_metrics 142 | 143 | 144 | def trades_loss(self, x, y, beta): 145 | """ 146 | TRADES training. 147 | """ 148 | other_args = {'label_smoothing': self.params.label_smoothing} 149 | loss, batch_metrics = trades_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 150 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 151 | beta=beta, attack=self.params.attack, **other_args) 152 | return loss, batch_metrics 153 | 154 | 155 | def eval(self, dataloader, adversarial=False): 156 | """ 157 | Evaluate performance of the model. 158 | """ 159 | acc = 0.0 160 | self.wa_model.eval() 161 | 162 | for x, y in dataloader: 163 | x, y = x.to(device), y.to(device) 164 | if adversarial: 165 | with ctx_noparamgrad_and_eval(self.wa_model): 166 | x_adv, _ = self.eval_attack.perturb(x, y) 167 | out = self.wa_model(x_adv) 168 | else: 169 | out = self.wa_model(x) 170 | acc += accuracy(y, out) 171 | acc /= len(dataloader) 172 | return acc 173 | 174 | 175 | def save_model(self, path): 176 | """ 177 | Save model weights. 
178 | """ 179 | torch.save({ 180 | 'model_state_dict': self.wa_model.state_dict(), 181 | 'unaveraged_model_state_dict': self.model.state_dict() 182 | }, path) 183 | 184 | 185 | def load_model(self, path): 186 | """ 187 | Load model weights. 188 | """ 189 | checkpoint = torch.load(path) 190 | if 'model_state_dict' not in checkpoint: 191 | raise RuntimeError('Model weights not found at {}.'.format(path)) 192 | self.wa_model.load_state_dict(checkpoint['model_state_dict']) 193 | 194 | 195 | def ema_update(wa_model, model, global_step, decay_rate=0.995, warmup_steps=0, dynamic_decay=True): 196 | """ 197 | Exponential model weight averaging update. 198 | """ 199 | factor = int(global_step >= warmup_steps) 200 | if dynamic_decay: 201 | delta = global_step - warmup_steps 202 | decay = min(decay_rate, (1. + delta) / (10. + delta)) if 10. + delta != 0 else decay_rate 203 | else: 204 | decay = decay_rate 205 | decay *= factor 206 | 207 | for p_swa, p_model in zip(wa_model.parameters(), model.parameters()): 208 | p_swa.data *= decay 209 | p_swa.data += p_model.data * (1 - decay) 210 | 211 | 212 | @torch.no_grad() 213 | def update_bn(avg_model, model): 214 | """ 215 | Update batch normalization layers. 
216 | """ 217 | avg_model.eval() 218 | model.eval() 219 | for module1, module2 in zip(avg_model.modules(), model.modules()): 220 | if isinstance(module1, torch.nn.modules.batchnorm._BatchNorm): 221 | module1.running_mean = module2.running_mean 222 | module1.running_var = module2.running_var 223 | module1.num_batches_tracked = module2.num_batches_tracked 224 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autoattack==0.1 2 | matplotlib==3.4.2 3 | numpy==1.19.0 4 | pandas==1.2.4 5 | Pillow==8.2.0 6 | robustbench==0.1 7 | scipy==1.4.1 8 | torch==1.8.0+cu101 9 | torchvision==0.9.0+cu101 10 | tqdm==4.60.0 11 | -------------------------------------------------------------------------------- /train-wa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Training (with improvements from Gowal et al., 2020). 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from core.data import get_data_info 18 | from core.data import load_data 19 | from core.data import SEMISUP_DATASETS 20 | 21 | from core.utils import format_time 22 | from core.utils import Logger 23 | from core.utils import parser_train 24 | from core.utils import Trainer 25 | from core.utils import seed 26 | 27 | from gowal21uncovering.utils import WATrainer 28 | 29 | 30 | # Setup 31 | 32 | parse = parser_train() 33 | parse.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing.') 34 | parse.add_argument('--tau', type=float, default=None, help='Weight averaging decay.') 35 | args = parse.parse_args() 36 | 37 | DATA_DIR = os.path.join(args.data_dir, args.data) 38 | LOG_DIR = os.path.join(args.log_dir, args.desc) 39 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 40 | if os.path.exists(LOG_DIR): 41 | shutil.rmtree(LOG_DIR) 42 | os.makedirs(LOG_DIR) 43 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 44 | 45 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 46 | json.dump(args.__dict__, f, indent=4) 47 | 48 | 49 | info = get_data_info(DATA_DIR) 50 | BATCH_SIZE = args.batch_size 51 | BATCH_SIZE_VALIDATION = args.batch_size_validation 52 | NUM_ADV_EPOCHS = args.num_adv_epochs 53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | logger.log('Using device: {}'.format(device)) 55 | if args.debug: 56 | NUM_ADV_EPOCHS = 1 57 | 58 | # To speed up training 59 | torch.backends.cudnn.benchmark = True 60 | 61 | 62 | 63 | # Load data 64 | 65 | seed(args.seed) 66 | train_dataset, test_dataset, eval_dataset, train_dataloader, test_dataloader, eval_dataloader = load_data( 67 | DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=args.augment, shuffle_train=True, 68 | 
aux_data_filename=args.aux_data_filename, unsup_fraction=args.unsup_fraction, validation=True 69 | ) 70 | del train_dataset, test_dataset, eval_dataset 71 | 72 | 73 | 74 | # Adversarial Training 75 | seed(args.seed) 76 | metrics = pd.DataFrame() 77 | if args.tau: 78 | print ('Using WA.') 79 | trainer = WATrainer(info, args) 80 | else: 81 | trainer = Trainer(info, args) 82 | last_lr = args.lr 83 | 84 | if NUM_ADV_EPOCHS > 0: 85 | logger.log('\n\n') 86 | metrics = pd.DataFrame() 87 | logger.log('Standard Accuracy-\tTest: {:2f}%.'.format(trainer.eval(test_dataloader)*100)) 88 | 89 | old_score = [0.0, 0.0] 90 | logger.log('RST + HAT Adversarial training for {} epochs'.format(NUM_ADV_EPOCHS)) 91 | trainer.init_optimizer(args.num_adv_epochs) 92 | test_adv_acc = 0.0 93 | 94 | 95 | for epoch in range(1, NUM_ADV_EPOCHS+1): 96 | start = time.time() 97 | logger.log('======= Epoch {} ======='.format(epoch)) 98 | 99 | if args.scheduler: 100 | last_lr = trainer.scheduler.get_last_lr()[0] 101 | 102 | res = trainer.train(train_dataloader, epoch=epoch, adversarial=True) 103 | test_acc = trainer.eval(test_dataloader) 104 | 105 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 106 | if 'clean_acc' in res: 107 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100)) 108 | else: 109 | logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100)) 110 | epoch_metrics = {'train_'+k: v for k, v in res.items()} 111 | epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 112 | 113 | if epoch % args.adv_eval_freq == 0 or epoch == NUM_ADV_EPOCHS: 114 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 115 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, 116 | test_adv_acc*100)) 117 | epoch_metrics.update({'test_adversarial_acc': test_adv_acc}) 118 | else: 119 | logger.log('Adversarial 
Accuracy-\tTrain: {:.2f}%.'.format(res['adversarial_acc']*100)) 120 | eval_adv_acc = trainer.eval(eval_dataloader, adversarial=True) 121 | logger.log('Adversarial Accuracy-\tEval: {:.2f}%.'.format(eval_adv_acc*100)) 122 | epoch_metrics['eval_adversarial_acc'] = eval_adv_acc 123 | 124 | if eval_adv_acc >= old_score[1]: 125 | old_score[0], old_score[1] = test_acc, eval_adv_acc 126 | trainer.save_model(WEIGHTS) 127 | trainer.save_model(os.path.join(LOG_DIR, 'weights-last.pt')) 128 | 129 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 130 | metrics = metrics.append(pd.DataFrame(epoch_metrics, index=[0]), ignore_index=True) 131 | metrics.to_csv(os.path.join(LOG_DIR, 'stats_adv.csv'), index=False) 132 | 133 | 134 | 135 | # Record metrics 136 | 137 | train_acc = res['clean_acc'] if 'clean_acc' in res else trainer.eval(train_dataloader) 138 | logger.log('\nTraining completed.') 139 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(train_acc*100, old_score[0]*100)) 140 | if NUM_ADV_EPOCHS > 0: 141 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tEval: {:.2f}%.'.format(res['adversarial_acc']*100, old_score[1]*100)) 142 | 143 | logger.log('Script Completed.') 144 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard Training + Adversarial Training. 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from core.data import get_data_info 18 | from core.data import load_data 19 | 20 | from core.utils import format_time 21 | from core.utils import Logger 22 | from core.utils import parser_train 23 | from core.utils import Trainer 24 | from core.utils import seed 25 | from core import setup 26 | 27 | 28 | 29 | # Setup 30 | 31 | parse = parser_train() 32 | args = parse.parse_args() 33 | 34 | 35 | DATA_DIR = os.path.join(args.data_dir, args.data) 36 | LOG_DIR = os.path.join(args.log_dir, args.desc) 37 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 38 | TMP = os.path.join(args.tmp_dir, args.desc) 39 | if os.path.exists(LOG_DIR): 40 | shutil.rmtree(LOG_DIR) 41 | os.makedirs(LOG_DIR) 42 | if args.exp and not os.path.exists(TMP): 43 | os.makedirs(TMP, exist_ok=True) 44 | print ('Tmp Dir: ', TMP) 45 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 46 | 47 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 48 | json.dump(args.__dict__, f, indent=4) 49 | 50 | if 'imagenet' in args.data: 51 | setup.setup_train(DATA_DIR) 52 | setup.setup_val(DATA_DIR) 53 | args.data_dir = os.environ['TMPDIR'] 54 | DATA_DIR = os.path.join(args.data_dir, args.data) 55 | 56 | info = get_data_info(DATA_DIR) 57 | BATCH_SIZE = args.batch_size 58 | BATCH_SIZE_VALIDATION = args.batch_size_validation 59 | NUM_STD_EPOCHS = args.num_std_epochs 60 | NUM_ADV_EPOCHS = args.num_adv_epochs 61 | NUM_SAMPLES_EVAL = args.num_samples_eval 62 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 63 | logger.log('Using device: {}'.format(device)) 64 | if args.debug: 65 | NUM_STD_EPOCHS = 1 66 | NUM_ADV_EPOCHS = 1 67 | 68 | # To speed up training 69 | if args.model in ['wrn-34-10', 'wrn-34-20'] or 'swish' in args.model or 'imagenet' in args.data: 70 | 
torch.backends.cudnn.benchmark = True 71 | 72 | 73 | # Load data 74 | 75 | seed(args.seed) 76 | train_dataset, test_dataset, train_dataloader, test_dataloader = load_data( 77 | DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=args.augment, shuffle_train=True, 78 | aux_data_filename=args.aux_data_filename, unsup_fraction=args.unsup_fraction 79 | ) 80 | num_train_samples = len(train_dataset) 81 | num_test_samples = len(test_dataset) 82 | 83 | train_indices = np.random.choice(num_train_samples, NUM_SAMPLES_EVAL, replace=False) 84 | test_indices = np.random.choice(num_test_samples, NUM_SAMPLES_EVAL, replace=False) 85 | 86 | pin_memory = torch.cuda.is_available() 87 | if args.exp: 88 | train_eval_dataset = torch.utils.data.Subset(train_dataset, train_indices[:NUM_SAMPLES_EVAL]) 89 | train_eval_dataloader = torch.utils.data.DataLoader(train_eval_dataset, batch_size=BATCH_SIZE_VALIDATION, shuffle=False, 90 | num_workers=4, pin_memory=pin_memory) 91 | 92 | test_eval_dataset = torch.utils.data.Subset(test_dataset, test_indices[:NUM_SAMPLES_EVAL]) 93 | test_eval_dataloader = torch.utils.data.DataLoader(test_eval_dataset, batch_size=BATCH_SIZE_VALIDATION, shuffle=False, 94 | num_workers=4, pin_memory=pin_memory) 95 | del train_eval_dataset, test_eval_dataset 96 | del train_dataset, test_dataset 97 | 98 | 99 | 100 | # Standard Training 101 | 102 | seed(args.seed) 103 | metrics = pd.DataFrame() 104 | trainer = Trainer(info, args) 105 | last_lr = args.lr 106 | 107 | logger.log('\n\n') 108 | logger.log('Standard training for {} epochs'.format(NUM_STD_EPOCHS)) 109 | old_score = [0.0] 110 | 111 | for epoch in range(1, NUM_STD_EPOCHS+1): 112 | start = time.time() 113 | logger.log('======= Epoch {} ======='.format(epoch)) 114 | if args.scheduler: 115 | last_lr = trainer.scheduler.get_last_lr()[0] 116 | 117 | res = trainer.train(train_dataloader, epoch=epoch) 118 | test_acc = trainer.eval(test_dataloader) 119 | 120 | if test_acc >= old_score[0]: 121 | old_score[0] = 
test_acc 122 | trainer.save_model(WEIGHTS) 123 | 124 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 125 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100)) 126 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 127 | 128 | epoch_metrics = {'train_'+k: v for k, v in res.items()} 129 | epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 130 | 131 | if epoch == NUM_STD_EPOCHS: 132 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 133 | logger.log('Adversarial Accuracy-\tTest: {:.2f}%.'.format(test_adv_acc*100)) 134 | epoch_metrics.update({'test_adversarial_acc': test_adv_acc}) 135 | metrics = metrics.append(pd.DataFrame(epoch_metrics, index=[0]), ignore_index=True) 136 | 137 | if NUM_STD_EPOCHS > 0: 138 | trainer.load_model(WEIGHTS) 139 | metrics.to_csv(os.path.join(LOG_DIR, 'stats_std.csv'), index=False) 140 | 141 | 142 | 143 | # Adversarial Training (AT, TRADES, MART and HAT) 144 | 145 | if NUM_ADV_EPOCHS > 0: 146 | logger.log('\n\n') 147 | metrics = pd.DataFrame() 148 | logger.log('Standard Accuracy-\tTest: {:2f}%.'.format(trainer.eval(test_dataloader)*100)) 149 | 150 | if args.exp: 151 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 152 | logger.log('Adversarial Accuracy-\tTest: {:2f}%.'.format(test_adv_acc*100)) 153 | trainer.save_model(os.path.join(TMP, 'model_0.pt')) 154 | _ = trainer.save_and_eval_adversarial(train_eval_dataloader, save=os.path.join(TMP, 'eval_train_adv_0')) 155 | _ = trainer.save_and_eval_adversarial(test_eval_dataloader, save=os.path.join(TMP, 'eval_test_adv_0')) 156 | 157 | old_score = [0.0, 0.0] 158 | logger.log('Adversarial training for {} epochs'.format(NUM_ADV_EPOCHS)) 159 | trainer.init_optimizer(args.num_adv_epochs) 160 | test_adv_acc = 0.0 161 | 162 | 163 | for epoch in range(1, NUM_ADV_EPOCHS+1): 164 | start = time.time() 165 | 
logger.log('======= Epoch {} ======='.format(epoch)) 166 | 167 | if args.scheduler: 168 | last_lr = trainer.scheduler.get_last_lr()[0] 169 | 170 | res = trainer.train(train_dataloader, epoch=epoch, adversarial=True) 171 | test_acc = trainer.eval(test_dataloader) 172 | 173 | if args.exp and (epoch % 5 == 0 or epoch == 1): 174 | trainer.save_model(os.path.join(TMP, 'model_{}.pt'.format(epoch))) 175 | save_eval_train_file = os.path.join(TMP, 'eval_train_adv_{}'.format(epoch)) 176 | save_eval_test_file = os.path.join(TMP, 'eval_test_adv_{}'.format(epoch)) 177 | _ = trainer.save_and_eval_adversarial(train_eval_dataloader, save=save_eval_train_file, save_all=False) 178 | _ = trainer.save_and_eval_adversarial(test_eval_dataloader, save=save_eval_test_file, save_all=False) 179 | 180 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 181 | if 'clean_acc' in res: 182 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100)) 183 | else: 184 | logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100)) 185 | epoch_metrics = {'train_'+k: v for k, v in res.items()} 186 | epoch_metrics.update({'epoch': NUM_STD_EPOCHS+epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 187 | 188 | if epoch % args.adv_eval_freq == 0 or epoch > (NUM_ADV_EPOCHS-5) or (epoch >= (NUM_ADV_EPOCHS-10) and NUM_ADV_EPOCHS > 90): 189 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 190 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, 191 | test_adv_acc*100)) 192 | epoch_metrics.update({'test_adversarial_acc': test_adv_acc}) 193 | else: 194 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.'.format(res['adversarial_acc']*100)) 195 | 196 | if test_adv_acc >= old_score[1]: 197 | old_score[0], old_score[1] = test_acc, test_adv_acc 198 | trainer.save_model(WEIGHTS) 199 | trainer.save_model(os.path.join(LOG_DIR, 'weights-last.pt')) 
200 | 201 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 202 | metrics = metrics.append(pd.DataFrame(epoch_metrics, index=[0]), ignore_index=True) 203 | metrics.to_csv(os.path.join(LOG_DIR, 'stats_adv.csv'), index=False) 204 | 205 | 206 | 207 | # Record metrics 208 | 209 | train_acc = res['clean_acc'] if 'clean_acc' in res else trainer.eval(train_dataloader) 210 | logger.log('\nTraining completed.') 211 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(train_acc*100, old_score[0]*100)) 212 | if NUM_ADV_EPOCHS > 0: 213 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, old_score[1]*100)) 214 | 215 | logger.log('Script Completed.') 216 | --------------------------------------------------------------------------------