├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── attacks │ ├── __init__.py │ ├── apgd.py │ ├── base.py │ ├── deepfool.py │ ├── fgsm.py │ ├── pgd.py │ └── utils.py ├── data │ ├── __init__.py │ ├── cifar10.py │ ├── cifar100.py │ ├── cifar100s.py │ ├── cifar10s.py │ ├── semisup.py │ ├── svhn.py │ └── tiny_imagenet.py ├── metrics.py ├── models │ ├── __init__.py │ ├── preact_resnet.py │ ├── preact_resnetwithswish.py │ ├── resnet.py │ ├── ti_preact_resnet.py │ ├── wideresnet.py │ └── wideresnetwithswish.py └── utils │ ├── __init__.py │ ├── context.py │ ├── logger.py │ ├── mart.py │ ├── parser.py │ ├── rst.py │ ├── trades.py │ ├── train.py │ └── utils.py ├── eval-aa.py ├── eval-adv.py ├── eval-rb.py ├── gowal21uncovering └── utils │ ├── __init__.py │ ├── cutmix.py │ ├── trades.py │ └── watrain.py ├── requirements.txt ├── train-wa.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*__pycache__/ 2 | **/*.pyc 3 | **/*.ipynb_checkpoints/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rahul Rade 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Robustness 2 | 3 | This repository contains the unofficial implementation of the papers "[Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples](https://arxiv.org/abs/2010.03593)" (Gowal et al., 2020) and "[Fixing Data Augmentation to Improve Adversarial Robustness](https://arxiv.org/abs/2103.01946)" (Rebuffi et al., 2021) in [PyTorch](https://pytorch.org/). 4 | 5 | ## Requirements 6 | 7 | The code has been implemented and tested with `Python 3.8.5` and `PyTorch 1.8.0`. To install the required packages: 8 | ```bash 9 | $ pip install -r requirements.txt 10 | ``` 11 | 12 | ## Usage 13 | 14 | ### Training 15 | 16 | Run [`train-wa.py`](./train-wa.py) for reproducing the results reported in the papers. 
For example, train a WideResNet-28-10 model via [TRADES](https://github.com/yaodongyu/TRADES) on CIFAR-10 with the additional pseudolabeled data provided by [Carmon et al., 2019](https://github.com/yaircarmon/semisup-adv) or the synthetic data from [Rebuffi et al., 2021](https://arxiv.org/abs/2103.01946) (without CutMix): 17 | 18 | ``` 19 | $ python train-wa.py --data-dir \ 20 | --log-dir \ 21 | --desc \ 22 | --data cifar10s \ 23 | --batch-size 1024 \ 24 | --model wrn-28-10-swish \ 25 | --num-adv-epochs 400 \ 26 | --lr 0.4 \ 27 | --beta 6.0 \ 28 | --unsup-fraction 0.7 \ 29 | --aux-data-filename 30 | ``` 31 | 32 | **Note**: Note that with [Gowal et al., 2020](https://arxiv.org/abs/2010.03593), expect about 0.5% lower robust accuracy than that reported in the paper since the original implementation uses a custom regenerated pseudolabeled dataset which is not publicly available (See Section 4.3.1 [here](https://arxiv.org/abs/2010.03593)). 33 | 34 | ### Robustness Evaluation 35 | 36 | The trained models can be evaluated by running [`eval-aa.py`](./eval-aa.py) which uses [AutoAttack](https://github.com/fra31/auto-attack) for evaluating the robust accuracy. 
For example: 37 | 38 | ``` 39 | $ python eval-aa.py --data-dir \ 40 | --log-dir \ 41 | --desc 42 | ``` 43 | 44 | For PGD evaluation: 45 | ``` 46 | $ python eval-adv.py --wb --data-dir \ 47 | --log-dir \ 48 | --desc 49 | ``` 50 | 51 | ## Reference & Citing this work 52 | 53 | If you use this code in your research, please cite the original works [[Paper](https://arxiv.org/abs/2010.03593)] [[Code in JAX+Haiku](https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness)] [[Pretrained models](https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness)]: 54 | 55 | ``` 56 | @article{gowal2020uncovering, 57 | title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples}, 58 | author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet}, 59 | journal={arXiv preprint arXiv:2010.03593}, 60 | year={2020}, 61 | url={https://arxiv.org/pdf/2010.03593} 62 | } 63 | ``` 64 | 65 | *and/or* 66 | 67 | ``` 68 | @article{rebuffi2021fixing, 69 | title={Fixing Data Augmentation to Improve Adversarial Robustness}, 70 | author={Rebuffi, Sylvestre-Alvise and Gowal, Sven and Calian, Dan A. 
and Stimberg, Florian and Wiles, Olivia and Mann, Timothy}, 71 | journal={arXiv preprint arXiv:2103.01946}, 72 | year={2021}, 73 | url={https://arxiv.org/pdf/2103.01946} 74 | } 75 | ``` 76 | 77 | *and* this repository: 78 | 79 | ``` 80 | @misc{rade2021pytorch, 81 | title = {{PyTorch} Implementation of Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples}, 82 | author = {Rade, Rahul}, 83 | year = {2021}, 84 | url = {https://github.com/imrahulr/adversarial_robustness_pytorch} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imrahulr/adversarial_robustness_pytorch/6df6a8f0cd49cf6d18507a4b574c004ab6eedf49/core/__init__.py -------------------------------------------------------------------------------- /core/attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Attack 2 | 3 | from .apgd import LinfAPGDAttack 4 | from .apgd import L2APGDAttack 5 | 6 | from .fgsm import FGMAttack 7 | from .fgsm import FGSMAttack 8 | from .fgsm import L2FastGradientAttack 9 | from .fgsm import LinfFastGradientAttack 10 | 11 | from .pgd import PGDAttack 12 | from .pgd import L2PGDAttack 13 | from .pgd import LinfPGDAttack 14 | 15 | from .deepfool import DeepFoolAttack 16 | from .deepfool import LinfDeepFoolAttack 17 | from .deepfool import L2DeepFoolAttack 18 | 19 | from .utils import CWLoss 20 | 21 | 22 | ATTACKS = ['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd'] 23 | 24 | 25 | def create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform', 26 | clip_min=0., clip_max=1.): 27 | """ 28 | Initialize adversary. 29 | Arguments: 30 | model (nn.Module): forward pass function. 31 | criterion (nn.Module): loss function. 
32 | attack_type (str): name of the attack. 33 | attack_eps (float): attack radius. 34 | attack_iter (int): number of attack iterations. 35 | attack_step (float): step size for the attack. 36 | rand_init_type (str): random initialization type for PGD (default: uniform). 37 | clip_min (float): mininum value per input dimension. 38 | clip_max (float): maximum value per input dimension. 39 | Returns: 40 | Attack 41 | """ 42 | 43 | if attack_type == 'fgsm': 44 | attack = FGSMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 45 | elif attack_type == 'fgm': 46 | attack = FGMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 47 | elif attack_type == 'linf-pgd': 48 | attack = LinfPGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 49 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max) 50 | elif attack_type == 'l2-pgd': 51 | attack = L2PGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 52 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max) 53 | elif attack_type == 'linf-df': 54 | attack = LinfDeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 55 | clip_max=clip_max) 56 | elif attack_type == 'l2-df': 57 | attack = L2DeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 58 | clip_max=clip_max) 59 | elif attack_type == 'linf-apgd': 60 | attack = LinfAPGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 61 | elif attack_type == 'l2-apgd': 62 | attack = L2APGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 63 | else: 64 | raise NotImplementedError('{} is not yet implemented!'.format(attack_type)) 65 | return attack 66 | -------------------------------------------------------------------------------- /core/attacks/apgd.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from autoattack.autopgd_pt import APGDAttack 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class APGD(): 11 | """ 12 | APGD attack (from AutoAttack) (Croce et al, 2020). 13 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 14 | Arguments: 15 | predict (nn.Module): forward pass function. 16 | loss_fn (str): loss function - ce or dlr. 17 | n_restarts (int): number of random restarts. 18 | eps (float): maximum distortion. 19 | nb_iter (int): number of iterations. 20 | ord (int): (optional) the order of maximum distortion (inf or 2). 21 | """ 22 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, ord=np.inf, seed=1): 23 | assert loss_fn in ['ce', 'dlr'], 'Only loss_fn=ce or loss_fn=dlr are supported!' 24 | assert ord in [2, np.inf], 'Only ord=inf or ord=2 are supported!' 25 | 26 | norm = 'Linf' if ord == np.inf else 'L2' 27 | self.apgd = APGDAttack(predict, n_restarts=n_restarts, n_iter=nb_iter, verbose=False, eps=eps, norm=norm, 28 | eot_iter=1, rho=.75, seed=seed, device=device) 29 | self.apgd.loss = loss_fn 30 | 31 | def perturb(self, x, y): 32 | x_adv = self.apgd.perturb(x, y)[1] 33 | r_adv = x_adv - x 34 | return x_adv, r_adv 35 | 36 | 37 | class LinfAPGDAttack(APGD): 38 | """ 39 | APGD attack (from AutoAttack) with order=Linf. 40 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 41 | Arguments: 42 | predict (nn.Module): forward pass function. 43 | loss_fn (str): loss function - ce or dlr. 44 | n_restarts (int): number of random restarts. 45 | eps (float): maximum distortion. 46 | nb_iter (int): number of iterations. 
47 | """ 48 | 49 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 50 | ord = np.inf 51 | super(L2APGDAttack, self).__init__( 52 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) 53 | 54 | 55 | class L2APGDAttack(APGD): 56 | """ 57 | APGD attack (from AutoAttack) with order=L2. 58 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 59 | Arguments: 60 | predict (nn.Module): forward pass function. 61 | loss_fn (str): loss function - ce or dlr. 62 | n_restarts (int): number of random restarts. 63 | eps (float): maximum distortion. 64 | nb_iter (int): number of iterations. 65 | """ 66 | 67 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 68 | ord = 2 69 | super(L2APGDAttack, self).__init__( 70 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) -------------------------------------------------------------------------------- /core/attacks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import replicate_input 5 | 6 | 7 | class Attack(object): 8 | """ 9 | Abstract base class for all attack classes. 10 | Arguments: 11 | predict (nn.Module): forward pass function. 12 | loss_fn (nn.Module): loss function. 13 | clip_min (float): mininum value per input dimension. 14 | clip_max (float): maximum value per input dimension. 15 | """ 16 | 17 | def __init__(self, predict, loss_fn, clip_min, clip_max): 18 | self.predict = predict 19 | self.loss_fn = loss_fn 20 | self.clip_min = clip_min 21 | self.clip_max = clip_max 22 | 23 | def perturb(self, x, **kwargs): 24 | """ 25 | Virtual method for generating the adversarial examples. 26 | Arguments: 27 | x (torch.Tensor): the model's input tensor. 
            **kwargs: optional parameters used by child classes.
        Returns:
            adversarial examples.
        """
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        # Allow attack objects to be used as callables: attack(x, y).
        return self.perturb(*args, **kwargs)


class LabelMixin(object):
    """Mixin with label handling shared by the attack classes (prediction fallback and input cloning)."""

    def _get_predicted_label(self, x):
        """
        Compute predicted labels given x. Used to prevent label leaking during adversarial training.
        Arguments:
            x (torch.Tensor): the model's input tensor.
        Returns:
            torch.Tensor containing predicted labels.
        """
        with torch.no_grad():
            outputs = self.predict(x)
        _, y = torch.max(outputs, dim=1)
        return y

    def _verify_and_process_inputs(self, x, y):
        # A targeted attack cannot run without explicit target labels.
        if self.targeted:
            assert y is not None

        # Untargeted attacks fall back to the model's own predictions when no labels are given.
        if not self.targeted:
            if y is None:
                y = self._get_predicted_label(x)

        # Work on detached clones so the caller's tensors are never mutated.
        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
--------------------------------------------------------------------------------
/core/attacks/deepfool.py:
--------------------------------------------------------------------------------
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
# NOTE(review): zero_gradients is unused in this module and was removed from PyTorch
# in 1.9+ — this import will fail on modern torch; confirm before upgrading.
from torch.autograd.gradcheck import zero_gradients

from .base import Attack, LabelMixin

from .utils import batch_multiply
from .utils import clamp
from .utils import is_float_or_torch_tensor


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def perturb_deepfool(xvar, yvar, predict, nb_iter=50, overshoot=0.02, ord=np.inf, clip_min=0.0, clip_max=1.0,
                     search_iter=0, device=None):
    """
    Compute DeepFool perturbations (Moosavi-Dezfooli et al, 2016).
    Arguments:
        xvar (torch.Tensor): input images.
        yvar (torch.Tensor): predictions.
        predict (nn.Module): forward pass function.
        nb_iter (int): number of iterations.
        overshoot (float): how much to overshoot the boundary.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        search_iter (int): no of search iterations.
        device (torch.device): device to work on.
    Returns:
        torch.Tensor containing the perturbed input,
        torch.Tensor containing the perturbation
    """

    x_orig = xvar
    x = torch.empty_like(xvar).copy_(xvar)
    x.requires_grad_(True)

    batch_i = torch.arange(x.shape[0])
    r_tot = torch.zeros_like(x.data)
    for i in range(nb_iter):
        if x.grad is not None:
            x.grad.zero_()

        logits = predict(x)
        # Classes sorted by logit per sample; the last column is the current top-1 class.
        df_inds = np.argsort(logits.detach().cpu().numpy(), axis=-1)
        df_inds_other, df_inds_orig = df_inds[:, :-1], df_inds[:, -1]
        df_inds_orig = torch.from_numpy(df_inds_orig)
        df_inds_orig = df_inds_orig.to(device)
        # Samples are "done" once the model's top-1 prediction no longer matches yvar.
        not_done_inds = df_inds_orig == yvar
        if not_done_inds.sum() == 0:
            break

        logits[batch_i, df_inds_orig].sum().backward(retain_graph=True)
        grad_orig = x.grad.data.clone().detach()
        pert = x.data.new_ones(x.shape[0]) * np.inf
        w = torch.zeros_like(x.data)

        # Linearize the decision boundary towards every other class; keep the closest one.
        for inds in df_inds_other.T:
            x.grad.zero_()
            logits[batch_i, inds].sum().backward(retain_graph=True)
            grad_cur = x.grad.data.clone().detach()
            with torch.no_grad():
                w_k = grad_cur - grad_orig
                f_k = logits[batch_i, inds] - logits[batch_i, df_inds_orig]
                if ord == 2:
                    pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 2, -1)
                elif ord == np.inf:
                    # Linf distance to the hyperplane uses the dual (L1) norm of w_k.
                    pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 1, -1)
                else:
                    raise NotImplementedError("Only ord=inf and ord=2 have been implemented")
                swi = pert_k < pert
                if swi.sum() > 0:
                    pert[swi] = pert_k[swi]
                    w[swi] = w_k[swi]

        # Minimal step to the nearest boundary (1e-6 nudges the point past it).
        if ord == 2:
            r_i = (pert + 1e-6)[:, None, None, None] * w / torch.norm(w.flatten(1), 2, -1)[:, None, None, None]
        elif ord == np.inf:
            r_i = (pert + 1e-6)[:, None, None, None] * w.sign()

        # Only samples that are not yet misclassified accumulate the new step.
        r_tot += r_i * not_done_inds[:, None, None, None].float()
        x.data = x_orig + (1. + overshoot) * r_tot
        x.data = torch.clamp(x.data, clip_min, clip_max)

    x = x.detach()
    if search_iter > 0:
        # Binary search for the smallest scaling of the perturbation that still flips the label.
        dx = x - x_orig
        dx_l_low, dx_l_high = torch.zeros_like(dx), torch.ones_like(dx)
        for i in range(search_iter):
            dx_l = (dx_l_low + dx_l_high) / 2.
            dx_x = x_orig + dx_l * dx
            dx_y = predict(dx_x).argmax(-1)
            label_stay = dx_y == yvar
            label_change = dx_y != yvar
            dx_l_low[label_stay] = dx_l[label_stay]
            dx_l_high[label_change] = dx_l[label_change]
        x = dx_x

    # x.data = torch.clamp(x.data, clip_min, clip_max)
    r_tot = x.data - x_orig
    return x, r_tot



class DeepFoolAttack(Attack, LabelMixin):
    """
    DeepFool attack.
    [Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard,
    "DeepFool: a simple and accurate method to fool deep neural networks"]
    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): no of search iterations.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        ord (int): (optional) the order of maximum distortion (inf or 2).
    """

    def __init__(
            self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1., ord=np.inf):
        super(DeepFoolAttack, self).__init__(predict, None, clip_min, clip_max)
        self.overshoot = overshoot
        self.nb_iter = nb_iter
        self.search_iter = search_iter
        self.targeted = False

        self.ord = ord
        assert is_float_or_torch_tensor(self.overshoot)

    def perturb(self, x, y=None):
        """
        Given examples x, returns their adversarial counterparts.
        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
        Returns:
            torch.Tensor containing perturbed inputs,
            torch.Tensor containing the perturbation
        """

        # y is deliberately passed as None: DeepFool always starts from the labels the
        # model currently predicts (see _verify_and_process_inputs with targeted=False).
        x, y = self._verify_and_process_inputs(x, None)
        x_adv, r_adv = perturb_deepfool(x, y, self.predict, self.nb_iter, self.overshoot, ord=self.ord,
                                        clip_min=self.clip_min, clip_max=self.clip_max, search_iter=self.search_iter,
                                        device=device)
        return x_adv, r_adv


class LinfDeepFoolAttack(DeepFoolAttack):
    """
    DeepFool Attack with order=Linf.
    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): no of search iterations.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(
            self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):

        ord = np.inf
        super(LinfDeepFoolAttack, self).__init__(
            predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min,
            clip_max=clip_max, ord=ord)



class L2DeepFoolAttack(DeepFoolAttack):
    """
    DeepFool Attack with order=L2.
    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): no of search iterations.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(
            self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):

        ord = 2
        super(L2DeepFoolAttack, self).__init__(
            predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min,
            clip_max=clip_max, ord=ord)
--------------------------------------------------------------------------------
/core/attacks/fgsm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from .base import Attack, LabelMixin
from .utils import batch_multiply
from .utils import clamp


class FGSMAttack(Attack, LabelMixin):
    """
    One step fast gradient sign method (Goodfellow et al, 2014).
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): attack step size.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
19 | """ 20 | 21 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 22 | super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max) 23 | 24 | self.eps = eps 25 | self.targeted = targeted 26 | if self.loss_fn is None: 27 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 28 | 29 | def perturb(self, x, y=None): 30 | """ 31 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 32 | Arguments: 33 | x (torch.Tensor): input tensor. 34 | y (torch.Tensor): label tensor. 35 | - if None and self.targeted=False, compute y as predicted labels. 36 | - if self.targeted=True, then y must be the targeted labels. 37 | Returns: 38 | torch.Tensor containing perturbed inputs. 39 | torch.Tensor containing the perturbation. 40 | """ 41 | 42 | x, y = self._verify_and_process_inputs(x, y) 43 | 44 | xadv = x.requires_grad_() 45 | outputs = self.predict(xadv) 46 | 47 | loss = self.loss_fn(outputs, y) 48 | if self.targeted: 49 | loss = -loss 50 | loss.backward() 51 | grad_sign = xadv.grad.detach().sign() 52 | 53 | xadv = xadv + batch_multiply(self.eps, grad_sign) 54 | xadv = clamp(xadv, self.clip_min, self.clip_max) 55 | radv = xadv - x 56 | return xadv.detach(), radv.detach() 57 | 58 | 59 | LinfFastGradientAttack = FGSMAttack 60 | 61 | 62 | class FGMAttack(Attack, LabelMixin): 63 | """ 64 | One step fast gradient method. Perturbs the input with gradient (not gradient sign) of the loss wrt the input. 65 | Arguments: 66 | predict (nn.Module): forward pass function. 67 | loss_fn (nn.Module): loss function. 68 | eps (float): attack step size. 69 | clip_min (float): mininum value per input dimension. 70 | clip_max (float): maximum value per input dimension. 71 | targeted (bool): indicate if this is a targeted attack. 
72 | """ 73 | 74 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 75 | super(FGMAttack, self).__init__( 76 | predict, loss_fn, clip_min, clip_max) 77 | 78 | self.eps = eps 79 | self.targeted = targeted 80 | if self.loss_fn is None: 81 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 82 | 83 | def perturb(self, x, y=None): 84 | """ 85 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 86 | Arguments: 87 | x (torch.Tensor): input tensor. 88 | y (torch.Tensor): label tensor. 89 | - if None and self.targeted=False, compute y as predicted labels. 90 | - if self.targeted=True, then y must be the targeted labels. 91 | Returns: 92 | torch.Tensor containing perturbed inputs. 93 | torch.Tensor containing the perturbation. 94 | """ 95 | 96 | x, y = self._verify_and_process_inputs(x, y) 97 | xadv = x.requires_grad_() 98 | outputs = self.predict(xadv) 99 | 100 | loss = self.loss_fn(outputs, y) 101 | if self.targeted: 102 | loss = -loss 103 | loss.backward() 104 | grad = normalize_by_pnorm(xadv.grad) 105 | xadv = xadv + batch_multiply(self.eps, grad) 106 | xadv = clamp(xadv, self.clip_min, self.clip_max) 107 | radv = xadv - x 108 | 109 | return xadv.detach(), radv.detach() 110 | 111 | 112 | L2FastGradientAttack = FGMAttack -------------------------------------------------------------------------------- /core/attacks/pgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .base import Attack, LabelMixin 6 | 7 | from .utils import batch_clamp 8 | from .utils import batch_multiply 9 | from .utils import clamp 10 | from .utils import clamp_by_pnorm 11 | from .utils import is_float_or_torch_tensor 12 | from .utils import normalize_by_pnorm 13 | from .utils import rand_init_delta 14 | from .utils import replicate_input 15 | 16 | 17 | def perturb_iterative(xvar, yvar, predict, 
                      nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf,
                      clip_min=0.0, clip_max=1.0):
    """
    Iteratively maximize the loss over the input. It is a shared method for iterative attacks.
    Arguments:
        xvar (torch.Tensor): input data.
        yvar (torch.Tensor): input labels.
        predict (nn.Module): forward pass function.
        nb_iter (int): number of iterations.
        eps (float): maximum distortion.
        eps_iter (float): attack step size.
        loss_fn (nn.Module): loss function.
        delta_init (torch.Tensor): (optional) tensor contains the random initialization.
        minimize (bool): (optional) whether to minimize or maximize the loss.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
    Returns:
        torch.Tensor containing the perturbed input,
        torch.Tensor containing the perturbation
    """
    if delta_init is not None:
        delta = delta_init
    else:
        delta = torch.zeros_like(xvar)

    delta.requires_grad_()
    for ii in range(nb_iter):
        outputs = predict(xvar + delta)
        loss = loss_fn(outputs, yvar)
        if minimize:
            # Targeted attacks descend towards the target label instead of ascending.
            loss = -loss

        loss.backward()
        if ord == np.inf:
            # Signed-gradient step, then project back into the eps-ball and the valid pixel range.
            grad_sign = delta.grad.data.sign()
            delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
            delta.data = batch_clamp(eps, delta.data)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
        elif ord == 2:
            # Step along the L2-normalized gradient, clip to the pixel range, then
            # project the perturbation back onto the L2 ball of radius eps.
            grad = delta.grad.data
            grad = normalize_by_pnorm(grad)
            delta.data = delta.data + batch_multiply(eps_iter, grad)
            delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
            if eps is not None:
                delta.data = clamp_by_pnorm(delta.data, ord, eps)
        else:
            error = "Only ord=inf and ord=2 have been implemented"
            raise NotImplementedError(error)
        delta.grad.data.zero_()

    x_adv = clamp(xvar + delta, clip_min, clip_max)
    r_adv = x_adv - xvar
    return x_adv, r_adv


class PGDAttack(Attack, LabelMixin):
    """
    The projected gradient descent attack (Madry et al, 2017).
    The attack performs nb_iter steps of size eps_iter, while always staying within eps from the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            ord=np.inf, targeted=False, rand_init_type='uniform'):
        super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
        assert is_float_or_torch_tensor(self.eps_iter)
        assert is_float_or_torch_tensor(self.eps)

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with an attack length of eps.
        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted
                  labels.
                - if self.targeted=True, then y must be the targeted labels.
        Returns:
            torch.Tensor containing perturbed inputs,
            torch.Tensor containing the perturbation
        """
        x, y = self._verify_and_process_inputs(x, y)

        delta = torch.zeros_like(x)
        delta = nn.Parameter(delta)
        if self.rand_init:
            if self.rand_init_type == 'uniform':
                # Start from a random point in the eps-ball, clipped to the valid range.
                rand_init_delta(
                    delta, x, self.ord, self.eps, self.clip_min, self.clip_max)
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
            elif self.rand_init_type == 'normal':
                delta.data = 0.001 * torch.randn_like(x)  # initialize as in TRADES
            else:
                raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.')

        x_adv, r_adv = perturb_iterative(
            x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn,
            minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta
        )

        return x_adv.data, r_adv.data


class LinfPGDAttack(PGDAttack):
    """
    PGD Attack with order=Linf
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        ord = np.inf
        super(LinfPGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init,
            clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)


class L2PGDAttack(PGDAttack):
    """
    PGD Attack with order=L2
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        eps_iter (float): attack step size.
        rand_init (bool): (optional) random initialization.
        clip_min (float): mininum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): if the attack is targeted.
        rand_init_type (str): (optional) random initialization type.
    """

    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
            targeted=False, rand_init_type='uniform'):
        ord = 2
        super(L2PGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init,
            clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)
--------------------------------------------------------------------------------
/core/attacks/utils.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from torch.distributions import laplace
from torch.distributions import uniform
from torch.nn.modules.loss import _Loss


def replicate_input(x):
    """
    Clone the input tensor x.
    """
    return x.detach().clone()


def replicate_input_withgrad(x):
    """
    Clone the input tensor x and set requires_grad=True.
    """
    return x.detach().clone().requires_grad_()


def calc_l2distsq(x, y):
    """
    Calculate L2 distance between tensors x and y.
    """
    d = (x - y)**2
    return d.view(d.shape[0], -1).sum(dim=1)


def clamp(input, min=None, max=None):
    """
    Clamp a tensor by its minimum and maximum values.
    min/max may each be None (no bound), a scalar, a per-sample tensor matching
    input.shape[1:], or a tensor matching input.shape.
    """
    ndim = input.ndimension()
    if min is None:
        pass
    elif isinstance(min, (float, int)):
        input = torch.clamp(input, min=min)
    elif isinstance(min, torch.Tensor):
        # A bound shaped like one sample is broadcast across the batch dimension.
        if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]:
            input = torch.max(input, min.view(1, *min.shape))
        else:
            assert min.shape == input.shape
            input = torch.max(input, min)
    else:
        raise ValueError("min can only be None | float | torch.Tensor")

    if max is None:
        pass
    elif isinstance(max, (float, int)):
        input = torch.clamp(input, max=max)
    elif isinstance(max, torch.Tensor):
        if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]:
            input = torch.min(input, max.view(1, *max.shape))
        else:
            assert max.shape == input.shape
            input = torch.min(input, max)
    else:
        raise ValueError("max can only be None | float | torch.Tensor")
    return input


def _batch_multiply_tensor_by_vector(vector, batch_tensor):
    """Equivalent to the following.
    for ii in range(len(vector)):
        batch_tensor.data[ii] *= vector[ii]
    return batch_tensor
    """
    return (
        batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous()


def _batch_clamp_tensor_by_vector(vector, batch_tensor):
    """Equivalent to the following.
    for ii in range(len(vector)):
        batch_tensor[ii] = clamp(
            batch_tensor[ii], -vector[ii], vector[ii])
    """
    return torch.min(
        torch.max(batch_tensor.transpose(0, -1), -vector), vector
    ).transpose(0, -1).contiguous()


def batch_multiply(float_or_vector, tensor):
    """
    Multpliy a batch of tensors with a float or vector.
92 | """ 93 | if isinstance(float_or_vector, torch.Tensor): 94 | assert len(float_or_vector) == len(tensor) 95 | tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor) 96 | elif isinstance(float_or_vector, float): 97 | tensor *= float_or_vector 98 | else: 99 | raise TypeError("Value has to be float or torch.Tensor") 100 | return tensor 101 | 102 | 103 | def batch_clamp(float_or_vector, tensor): 104 | """ 105 | Clamp a batch of tensors. 106 | """ 107 | if isinstance(float_or_vector, torch.Tensor): 108 | assert len(float_or_vector) == len(tensor) 109 | tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor) 110 | return tensor 111 | elif isinstance(float_or_vector, float): 112 | tensor = clamp(tensor, -float_or_vector, float_or_vector) 113 | else: 114 | raise TypeError("Value has to be float or torch.Tensor") 115 | return tensor 116 | 117 | 118 | def _get_norm_batch(x, p): 119 | """ 120 | Returns the Lp norm of batch x. 121 | """ 122 | batch_size = x.size(0) 123 | return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p) 124 | 125 | 126 | def _thresh_by_magnitude(theta, x): 127 | """ 128 | Threshold by magnitude. 129 | """ 130 | return torch.relu(torch.abs(x) - theta) * x.sign() 131 | 132 | 133 | def clamp_by_pnorm(x, p, r): 134 | """ 135 | Clamp tensor by its norm. 136 | """ 137 | assert isinstance(p, float) or isinstance(p, int) 138 | norm = _get_norm_batch(x, p) 139 | if isinstance(r, torch.Tensor): 140 | assert norm.size() == r.size() 141 | else: 142 | assert isinstance(r, float) 143 | factor = torch.min(r / norm, torch.ones_like(norm)) 144 | return batch_multiply(factor, x) 145 | 146 | 147 | def is_float_or_torch_tensor(x): 148 | """ 149 | Return whether input x is a float or a torch.Tensor. 150 | """ 151 | return isinstance(x, torch.Tensor) or isinstance(x, float) 152 | 153 | 154 | def normalize_by_pnorm(x, p=2, small_constant=1e-6): 155 | """ 156 | Normalize gradients for gradient (not gradient sign) attacks. 
157 | Arguments: 158 | x (torch.Tensor): tensor containing the gradients on the input. 159 | p (int): (optional) order of the norm for the normalization (1 or 2). 160 | small_constant (float): (optional) to avoid dividing by zero. 161 | Returns: 162 | normalized gradients. 163 | """ 164 | assert isinstance(p, float) or isinstance(p, int) 165 | norm = _get_norm_batch(x, p) 166 | norm = torch.max(norm, torch.ones_like(norm) * small_constant) 167 | return batch_multiply(1. / norm, x) 168 | 169 | 170 | def rand_init_delta(delta, x, ord, eps, clip_min, clip_max): 171 | """ 172 | Randomly initialize the perturbation. 173 | """ 174 | if isinstance(eps, torch.Tensor): 175 | assert len(eps) == len(delta) 176 | 177 | if ord == np.inf: 178 | delta.data.uniform_(-1, 1) 179 | delta.data = batch_multiply(eps, delta.data) 180 | elif ord == 2: 181 | delta.data.uniform_(clip_min, clip_max) 182 | delta.data = delta.data - x 183 | delta.data = clamp_by_pnorm(delta.data, ord, eps) 184 | elif ord == 1: 185 | ini = laplace.Laplace( 186 | loc=delta.new_tensor(0), scale=delta.new_tensor(1)) 187 | delta.data = ini.sample(delta.data.shape) 188 | delta.data = normalize_by_pnorm(delta.data, p=1) 189 | ray = uniform.Uniform(0, eps).sample() 190 | delta.data *= ray 191 | delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data 192 | else: 193 | error = "Only ord = inf, ord = 1 and ord = 2 have been implemented" 194 | raise NotImplementedError(error) 195 | 196 | delta.data = clamp( 197 | x + delta.data, min=clip_min, max=clip_max) - x 198 | return delta.data 199 | 200 | 201 | def CWLoss(output, target, confidence=0): 202 | """ 203 | CW loss (Marging loss). 204 | """ 205 | num_classes = output.shape[-1] 206 | target = target.data 207 | target_onehot = torch.zeros(target.size() + (num_classes,)) 208 | target_onehot = target_onehot.cuda() 209 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 
210 | target_var = Variable(target_onehot, requires_grad=False) 211 | real = (target_var * output).sum(1) 212 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 213 | loss = - torch.clamp(real - other + confidence, min=0.) 214 | loss = torch.sum(loss) 215 | return loss 216 | -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .cifar10 import load_cifar10 5 | from .cifar100 import load_cifar100 6 | from .svhn import load_svhn 7 | from .cifar10s import load_cifar10s 8 | from .cifar100s import load_cifar100s 9 | from .tiny_imagenet import load_tinyimagenet 10 | 11 | from .semisup import get_semisup_dataloaders 12 | 13 | 14 | SEMISUP_DATASETS = ['cifar10s', 'cifar100s'] 15 | DATASETS = ['cifar10', 'svhn', 'cifar100', 'tiny-imagenet'] + SEMISUP_DATASETS 16 | 17 | _LOAD_DATASET_FN = { 18 | 'cifar10': load_cifar10, 19 | 'cifar100': load_cifar100, 20 | 'svhn': load_svhn, 21 | 'tiny-imagenet': load_tinyimagenet, 22 | 'cifar10s': load_cifar10s, 23 | 'cifar100s': load_cifar100s, 24 | } 25 | 26 | 27 | def get_data_info(data_dir): 28 | """ 29 | Returns dataset information. 30 | Arguments: 31 | data_dir (str): path to data directory. 
32 | """ 33 | dataset = os.path.basename(os.path.normpath(data_dir)) 34 | if 'cifar100' in data_dir: 35 | from .cifar100 import DATA_DESC 36 | elif 'cifar10' in data_dir: 37 | from .cifar10 import DATA_DESC 38 | elif 'svhn' in data_dir: 39 | from .svhn import DATA_DESC 40 | elif 'tiny-imagenet' in data_dir: 41 | from .tiny_imagenet import DATA_DESC 42 | else: 43 | raise ValueError(f'Only data in {DATASETS} are supported!') 44 | DATA_DESC['data'] = dataset 45 | return DATA_DESC 46 | 47 | 48 | def load_data(data_dir, batch_size=256, batch_size_test=256, num_workers=4, use_augmentation=False, shuffle_train=True, 49 | aux_data_filename=None, unsup_fraction=None, validation=False): 50 | """ 51 | Returns train, test datasets and dataloaders. 52 | Arguments: 53 | data_dir (str): path to data directory. 54 | batch_size (int): batch size for training. 55 | batch_size_test (int): batch size for validation. 56 | num_workers (int): number of workers for loading the data. 57 | use_augmentation (bool): whether to use augmentations for training set. 58 | shuffle_train (bool): whether to shuffle training set. 59 | aux_data_filename (str): path to unlabelled data. 60 | unsup_fraction (float): fraction of unlabelled data per batch. 61 | validation (bool): if True, also returns a validation dataloader for unspervised cifar10 (as in Gowal et al, 2020). 62 | """ 63 | dataset = os.path.basename(os.path.normpath(data_dir)) 64 | load_dataset_fn = _LOAD_DATASET_FN[dataset] 65 | 66 | if validation: 67 | assert dataset in SEMISUP_DATASETS, 'Only semi-supervised datasets allow a validation set.' 
68 | train_dataset, test_dataset, val_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation, 69 | aux_data_filename=aux_data_filename, validation=True) 70 | else: 71 | train_dataset, test_dataset = load_dataset_fn(data_dir=data_dir, use_augmentation=use_augmentation) 72 | 73 | if dataset in SEMISUP_DATASETS: 74 | if validation: 75 | train_dataloader, test_dataloader, val_dataloader = get_semisup_dataloaders( 76 | train_dataset, test_dataset, val_dataset, batch_size, batch_size_test, num_workers, unsup_fraction 77 | ) 78 | else: 79 | train_dataloader, test_dataloader = get_semisup_dataloaders( 80 | train_dataset, test_dataset, None, batch_size, batch_size_test, num_workers, unsup_fraction 81 | ) 82 | else: 83 | pin_memory = torch.cuda.is_available() 84 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train, 85 | num_workers=num_workers, pin_memory=pin_memory) 86 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, 87 | num_workers=num_workers, pin_memory=pin_memory) 88 | if validation: 89 | return train_dataset, test_dataset, val_dataset, train_dataloader, test_dataloader, val_dataloader 90 | return train_dataset, test_dataset, train_dataloader, test_dataloader 91 | -------------------------------------------------------------------------------- /core/data/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'cifar10', 9 | 'classes': ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'), 10 | 'num_classes': 10, 11 | 'mean': [0.4914, 0.4822, 0.4465], 12 | 'std': [0.2023, 0.1994, 0.2010], 13 | } 14 | 15 | 16 | def load_cifar10(data_dir, use_augmentation=False): 17 | """ 18 | Returns CIFAR10 train, test datasets and dataloaders. 
19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform) 33 | test_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform) 34 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/cifar100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'cifar100', 9 | 'classes': tuple(range(0, 100)), 10 | 'num_classes': 100, 11 | 'mean': [0.5071, 0.4865, 0.4409], 12 | 'std': [0.2673, 0.2564, 0.2762], 13 | } 14 | 15 | 16 | def load_cifar100(data_dir, use_augmentation=False): 17 | """ 18 | Returns CIFAR100 train, test datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 
24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.RandomRotation(15), transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform) 33 | test_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True, transform=test_transform) 34 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/cifar100s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | import re 7 | import numpy as np 8 | 9 | from .semisup import SemiSupervisedDataset 10 | 11 | 12 | def load_cifar100s(data_dir, use_augmentation=False, aux_take_amount=None, 13 | aux_data_filename=None, validation=False): 14 | """ 15 | Returns semisupervised CIFAR100 train, test datasets and dataloaders (with DDPM Images). 16 | Arguments: 17 | data_dir (str): path to data directory. 18 | use_augmentation (bool): whether to use augmentations for training set. 19 | aux_take_amount (int): number of semi-supervised examples to use (if None, use all). 20 | aux_data_filename (str): path to additional data pickle file. 21 | Returns: 22 | train dataset, test dataset. 
23 | """ 24 | data_dir = re.sub('cifar100s', 'cifar100', data_dir) 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | if use_augmentation: 27 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 28 | transforms.RandomRotation(15), transforms.ToTensor()]) 29 | else: 30 | train_transform = test_transform 31 | 32 | train_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=True, download=True, 33 | transform=train_transform, aux_data_filename=aux_data_filename, 34 | add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation) 35 | test_dataset = SemiSupervisedCIFAR100(base_dataset='cifar100', root=data_dir, train=False, download=True, 36 | transform=test_transform) 37 | if validation: 38 | val_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=test_transform) 39 | val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024)) 40 | return train_dataset, test_dataset, val_dataset 41 | return train_dataset, test_dataset, None 42 | 43 | 44 | class SemiSupervisedCIFAR100(SemiSupervisedDataset): 45 | """ 46 | A dataset with auxiliary pseudo-labeled data for CIFAR100. 47 | """ 48 | def load_base_dataset(self, train=False, **kwargs): 49 | assert self.base_dataset == 'cifar100', 'Only semi-supervised cifar100 is supported. Please use correct dataset!' 
50 | self.dataset = torchvision.datasets.CIFAR100(train=train, **kwargs) 51 | self.dataset_size = len(self.dataset) -------------------------------------------------------------------------------- /core/data/cifar10s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | import re 7 | import numpy as np 8 | 9 | from .semisup import SemiSupervisedDataset 10 | from .semisup import SemiSupervisedSampler 11 | 12 | 13 | def load_cifar10s(data_dir, use_augmentation=False, aux_take_amount=None, 14 | aux_data_filename='/cluster/scratch/rarade/cifar10s/ti_500K_pseudo_labeled.pickle', 15 | validation=False): 16 | """ 17 | Returns semisupervised CIFAR10 train, test datasets and dataloaders (with Tiny Images). 18 | Arguments: 19 | data_dir (str): path to data directory. 20 | use_augmentation (bool): whether to use augmentations for training set. 21 | aux_take_amount (int): number of semi-supervised examples to use (if None, use all). 22 | aux_data_filename (str): path to additional data pickle file. 23 | Returns: 24 | train dataset, test dataset. 
25 | """ 26 | data_dir = re.sub('cifar10s', 'cifar10', data_dir) 27 | test_transform = transforms.Compose([transforms.ToTensor()]) 28 | if use_augmentation: 29 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(0.5), 30 | transforms.ToTensor()]) 31 | else: 32 | train_transform = test_transform 33 | 34 | train_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=True, download=True, 35 | transform=train_transform, aux_data_filename=aux_data_filename, 36 | add_aux_labels=True, aux_take_amount=aux_take_amount, validation=validation) 37 | test_dataset = SemiSupervisedCIFAR10(base_dataset='cifar10', root=data_dir, train=False, download=True, 38 | transform=test_transform) 39 | if validation: 40 | val_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=test_transform) 41 | val_dataset = torch.utils.data.Subset(val_dataset, np.arange(0, 1024)) 42 | return train_dataset, test_dataset, val_dataset 43 | return train_dataset, test_dataset 44 | 45 | 46 | class SemiSupervisedCIFAR10(SemiSupervisedDataset): 47 | """ 48 | A dataset with auxiliary pseudo-labeled data for CIFAR10. 49 | """ 50 | def load_base_dataset(self, train=False, **kwargs): 51 | assert self.base_dataset == 'cifar10', 'Only semi-supervised cifar10 is supported. Please use correct dataset!' 
52 | self.dataset = torchvision.datasets.CIFAR10(train=train, **kwargs) 53 | self.dataset_size = len(self.dataset) 54 | -------------------------------------------------------------------------------- /core/data/semisup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | def get_semisup_dataloaders(train_dataset, test_dataset, val_dataset=None, batch_size=256, batch_size_test=256, num_workers=4, 9 | unsup_fraction=0.5): 10 | """ 11 | Return dataloaders with custom sampling of pseudo-labeled data. 12 | """ 13 | dataset_size = train_dataset.dataset_size 14 | train_batch_sampler = SemiSupervisedSampler(train_dataset.sup_indices, train_dataset.unsup_indices, batch_size, 15 | unsup_fraction, num_batches=int(np.ceil(dataset_size/batch_size))) 16 | epoch_size = len(train_batch_sampler) * batch_size 17 | 18 | kwargs = {'num_workers': num_workers, 'pin_memory': torch.cuda.is_available() } 19 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler, **kwargs) 20 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 21 | 22 | if val_dataset: 23 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False, **kwargs) 24 | return train_dataloader, test_dataloader, val_dataloader 25 | return train_dataloader, test_dataloader 26 | 27 | 28 | class SemiSupervisedDataset(torch.utils.data.Dataset): 29 | """ 30 | A dataset with auxiliary pseudo-labeled data. 
31 | """ 32 | def __init__(self, base_dataset='cifar10', take_amount=None, take_amount_seed=13, aux_data_filename=None, 33 | add_aux_labels=False, aux_take_amount=None, train=False, validation=False, **kwargs): 34 | 35 | self.base_dataset = base_dataset 36 | self.load_base_dataset(train, **kwargs) 37 | 38 | if validation: 39 | self.dataset.data = self.dataset.data[1024:] 40 | self.dataset.targets = self.dataset.targets[1024:] 41 | 42 | self.train = train 43 | 44 | if self.train: 45 | if take_amount is not None: 46 | rng_state = np.random.get_state() 47 | np.random.seed(take_amount_seed) 48 | take_inds = np.random.choice(len(self.sup_indices), take_amount, replace=False) 49 | np.random.set_state(rng_state) 50 | 51 | self.targets = self.targets[take_inds] 52 | self.data = self.data[take_inds] 53 | 54 | self.sup_indices = list(range(len(self.targets))) 55 | self.unsup_indices = [] 56 | 57 | if aux_data_filename is not None: 58 | aux_path = aux_data_filename 59 | print('Loading data from %s' % aux_path) 60 | if os.path.splitext(aux_path)[1] == '.pickle': 61 | # for data from Carmon et al, 2019. 62 | with open(aux_path, 'rb') as f: 63 | aux = pickle.load(f) 64 | aux_data = aux['data'] 65 | aux_targets = aux['extrapolated_targets'] 66 | else: 67 | # for data from Rebuffi et al, 2021. 
68 | aux = np.load(aux_path) 69 | aux_data = aux['image'] 70 | aux_targets = aux['label'] 71 | 72 | orig_len = len(self.data) 73 | 74 | if aux_take_amount is not None: 75 | rng_state = np.random.get_state() 76 | np.random.seed(take_amount_seed) 77 | take_inds = np.random.choice(len(aux_data), aux_take_amount, replace=False) 78 | np.random.set_state(rng_state) 79 | 80 | aux_data = aux_data[take_inds] 81 | aux_targets = aux_targets[take_inds] 82 | 83 | self.data = np.concatenate((self.data, aux_data), axis=0) 84 | 85 | if not add_aux_labels: 86 | self.targets.extend([-1] * len(aux_data)) 87 | else: 88 | self.targets.extend(aux_targets) 89 | self.unsup_indices.extend(range(orig_len, orig_len+len(aux_data))) 90 | 91 | else: 92 | self.sup_indices = list(range(len(self.targets))) 93 | self.unsup_indices = [] 94 | 95 | def load_base_dataset(self, **kwargs): 96 | raise NotImplementedError() 97 | 98 | @property 99 | def data(self): 100 | return self.dataset.data 101 | 102 | @data.setter 103 | def data(self, value): 104 | self.dataset.data = value 105 | 106 | @property 107 | def targets(self): 108 | return self.dataset.targets 109 | 110 | @targets.setter 111 | def targets(self, value): 112 | self.dataset.targets = value 113 | 114 | def __len__(self): 115 | return len(self.dataset) 116 | 117 | def __getitem__(self, item): 118 | self.dataset.labels = self.targets 119 | return self.dataset[item] 120 | 121 | 122 | class SemiSupervisedSampler(torch.utils.data.Sampler): 123 | """ 124 | Balanced sampling from the labeled and unlabeled data. 
125 | """ 126 | def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5, num_batches=None): 127 | if unsup_fraction is None or unsup_fraction < 0: 128 | self.sup_inds = sup_inds + unsup_inds 129 | unsup_fraction = 0.0 130 | else: 131 | self.sup_inds = sup_inds 132 | self.unsup_inds = unsup_inds 133 | 134 | self.batch_size = batch_size 135 | unsup_batch_size = int(batch_size * unsup_fraction) 136 | self.sup_batch_size = batch_size - unsup_batch_size 137 | 138 | if num_batches is not None: 139 | self.num_batches = num_batches 140 | else: 141 | self.num_batches = int(np.ceil(len(self.sup_inds) / self.sup_batch_size)) 142 | super().__init__(None) 143 | 144 | def __iter__(self): 145 | batch_counter = 0 146 | while batch_counter < self.num_batches: 147 | sup_inds_shuffled = [self.sup_inds[i] 148 | for i in torch.randperm(len(self.sup_inds))] 149 | for sup_k in range(0, len(self.sup_inds), self.sup_batch_size): 150 | if batch_counter == self.num_batches: 151 | break 152 | batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)] 153 | if self.sup_batch_size < self.batch_size: 154 | batch.extend([self.unsup_inds[i] for i in torch.randint(high=len(self.unsup_inds), 155 | size=(self.batch_size - len(batch),), 156 | dtype=torch.int64)]) 157 | np.random.shuffle(batch) 158 | yield batch 159 | batch_counter += 1 160 | 161 | def __len__(self): 162 | return self.num_batches -------------------------------------------------------------------------------- /core/data/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | DATA_DESC = { 8 | 'data': 'svhn', 9 | 'classes': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9'), 10 | 'num_classes': 10, 11 | 'mean': [0.4914, 0.4822, 0.4465], 12 | 'std': [0.2023, 0.1994, 0.2010], 13 | } 14 | 15 | 16 | def load_svhn(data_dir, use_augmentation=False): 17 | """ 18 | Returns SVHN train, test 
datasets and dataloaders. 19 | Arguments: 20 | data_dir (str): path to data directory. 21 | use_augmentation (bool): whether to use augmentations for training set. 22 | Returns: 23 | train dataset, test dataset. 24 | """ 25 | test_transform = transforms.Compose([transforms.ToTensor()]) 26 | train_transform = test_transform 27 | 28 | train_dataset = torchvision.datasets.SVHN(root=data_dir, split='train', download=True, transform=train_transform) 29 | test_dataset = torchvision.datasets.SVHN(root=data_dir, split='test', download=True, transform=test_transform) 30 | return train_dataset, test_dataset -------------------------------------------------------------------------------- /core/data/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import ImageFolder 7 | 8 | 9 | DATA_DESC = { 10 | 'data': 'tiny-imagenet', 11 | 'classes': tuple(range(0, 200)), 12 | 'num_classes': 200, 13 | 'mean': [0.4802, 0.4481, 0.3975], 14 | 'std': [0.2302, 0.2265, 0.2262], 15 | } 16 | 17 | 18 | def load_tinyimagenet(data_dir, use_augmentation=False): 19 | """ 20 | Returns Tiny Imagenet-200 train, test datasets and dataloaders. 21 | Arguments: 22 | data_dir (str): path to data directory. 23 | use_augmentation (bool): whether to use augmentations for training set. 24 | Returns: 25 | train dataset, test dataset. 
26 | """ 27 | test_transform = transforms.Compose([transforms.ToTensor()]) 28 | if use_augmentation: 29 | train_transform = transforms.Compose([transforms.RandomCrop(64, padding=4), transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor()]) 31 | else: 32 | train_transform = test_transform 33 | 34 | train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform) 35 | test_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=test_transform) 36 | 37 | return train_dataset, test_dataset 38 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def accuracy(true, preds): 4 | """ 5 | Computes multi-class accuracy. 6 | Arguments: 7 | true (torch.Tensor): true labels. 8 | preds (torch.Tensor): predicted labels. 9 | Returns: 10 | Multi-class accuracy. 11 | """ 12 | accuracy = (torch.softmax(preds, dim=1).argmax(dim=1) == true).sum().float()/float(true.size(0)) 13 | return accuracy.item() 14 | -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .resnet import Normalization 4 | from .preact_resnet import preact_resnet 5 | from .resnet import resnet 6 | from .wideresnet import wideresnet 7 | 8 | from .preact_resnetwithswish import preact_resnetwithswish 9 | from .wideresnetwithswish import wideresnetwithswish 10 | 11 | from core.data import DATASETS 12 | 13 | 14 | MODELS = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | 'preact-resnet18', 'preact-resnet34', 'preact-resnet50', 'preact-resnet101', 16 | 'wrn-28-10', 'wrn-32-10', 'wrn-34-10', 'wrn-34-20', 17 | 'preact-resnet18-swish', 'preact-resnet34-swish', 18 | 'wrn-28-10-swish', 'wrn-34-20-swish', 'wrn-70-16-swish'] 19 | 20 | 21 | def create_model(name, normalize, info, 
device): 22 | """ 23 | Returns suitable model from its name. 24 | Arguments: 25 | name (str): name of resnet architecture. 26 | normalize (bool): normalize input. 27 | info (dict): dataset information. 28 | device (str or torch.device): device to work on. 29 | Returns: 30 | torch.nn.Module. 31 | """ 32 | if info['data'] in ['tiny-imagenet']: 33 | assert 'preact-resnet' in name, 'Only preact-resnets are supported for this dataset!' 34 | from .ti_preact_resnet import ti_preact_resnet 35 | backbone = ti_preact_resnet(name, num_classes=info['num_classes'], device=device) 36 | 37 | elif info['data'] in DATASETS and info['data'] not in ['tiny-imagenet']: 38 | if 'preact-resnet' in name and 'swish' not in name: 39 | backbone = preact_resnet(name, num_classes=info['num_classes'], pretrained=False, device=device) 40 | elif 'preact-resnet' in name and 'swish' in name: 41 | backbone = preact_resnetwithswish(name, dataset=info['data'], num_classes=info['num_classes']) 42 | elif 'resnet' in name and 'preact' not in name: 43 | backbone = resnet(name, num_classes=info['num_classes'], pretrained=False, device=device) 44 | elif 'wrn' in name and 'swish' not in name: 45 | backbone = wideresnet(name, num_classes=info['num_classes'], device=device) 46 | elif 'wrn' in name and 'swish' in name: 47 | backbone = wideresnetwithswish(name, dataset=info['data'], num_classes=info['num_classes'], device=device) 48 | else: 49 | raise ValueError('Invalid model name {}!'.format(name)) 50 | 51 | else: 52 | raise ValueError('Models for {} not yet supported!'.format(info['data'])) 53 | 54 | if normalize: 55 | model = torch.nn.Sequential(Normalization(info['mean'], info['std']), backbone) 56 | else: 57 | model = torch.nn.Sequential(backbone) 58 | 59 | model = torch.nn.DataParallel(model) 60 | model = model.to(device) 61 | return model 62 | -------------------------------------------------------------------------------- /core/models/preact_resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PreActBlock(nn.Module): 7 | """ 8 | Pre-activation version of the BasicBlock for Resnets. 9 | Arguments: 10 | in_planes (int): number of input planes. 11 | planes (int): number of output filters. 12 | stride (int): stride of convolution. 13 | """ 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | """ 39 | Pre-activation version of the original Bottleneck module for Resnets. 40 | Arguments: 41 | in_planes (int): number of input planes. 42 | planes (int): number of output filters. 43 | stride (int): stride of convolution. 
class PreActBottleneck(nn.Module):
    """
    Pre-activation version of the original Bottleneck module for Resnets.

    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of bottleneck filters (output is expansion * planes).
        stride (int): stride of the middle 3x3 convolution.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        width = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width, kernel_size=1, bias=False)
        # Projection shortcut, created only when the identity cannot be reused.
        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # The projection (when present) consumes the pre-activated features.
        residual = self.shortcut(pre) if hasattr(self, 'shortcut') else x
        out = self.conv1(pre)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        return out + residual
def preact_resnet(name, num_classes=10, pretrained=False, device='cpu'):
    """
    Returns suitable Resnet model from its name.

    Arguments:
        name (str): one of 'preact-resnet18', 'preact-resnet34',
            'preact-resnet50' or 'preact-resnet101'.
        num_classes (int): number of target classes.
        pretrained (bool): whether to use a pretrained model.
            NOTE(review): currently unused -- no pretrained weights are loaded.
        device (str or torch.device): device to work on.
            NOTE(review): currently unused here; the caller moves the model.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if `name` is not a supported architecture.
    """
    # Dispatch table replaces the if/elif chain; the previous version also
    # had an unreachable bare `return` after the raise, now removed.
    configs = {
        'preact-resnet18': (PreActBlock, [2, 2, 2, 2]),
        'preact-resnet34': (PreActBlock, [3, 4, 6, 3]),
        'preact-resnet50': (PreActBottleneck, [3, 4, 6, 3]),
        'preact-resnet101': (PreActBottleneck, [3, 4, 23, 3]),
    }
    if name not in configs:
        raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!')
    block, num_blocks = configs[name]
    return PreActResNet(block, num_blocks, num_classes=num_classes)
0.4409) 14 | CIFAR100_STD = (0.2673, 0.2564, 0.2762) 15 | SVHN_MEAN = (0.5, 0.5, 0.5) 16 | SVHN_STD = (0.5, 0.5, 0.5) 17 | 18 | _ACTIVATION = { 19 | 'relu': nn.ReLU, 20 | 'swish': nn.SiLU, 21 | } 22 | 23 | 24 | class _PreActBlock(nn.Module): 25 | """ 26 | PreAct ResNet Block. 27 | Arguments: 28 | in_planes (int): number of input planes. 29 | out_planes (int): number of output filters. 30 | stride (int): stride of convolution. 31 | activation_fn (nn.Module): activation function. 32 | """ 33 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 34 | super().__init__() 35 | self._stride = stride 36 | self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) 37 | self.relu_0 = activation_fn() 38 | # We manually pad to obtain the same effect as `SAME` (necessary when 39 | # `stride` is different than 1). 40 | self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, 41 | stride=stride, padding=0, bias=False) 42 | self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) 43 | self.relu_1 = activation_fn() 44 | self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 45 | padding=1, bias=False) 46 | self.has_shortcut = stride != 1 or in_planes != out_planes 47 | if self.has_shortcut: 48 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3, 49 | stride=stride, padding=0, bias=False) 50 | 51 | def _pad(self, x): 52 | if self._stride == 1: 53 | x = F.pad(x, (1, 1, 1, 1)) 54 | elif self._stride == 2: 55 | x = F.pad(x, (0, 1, 0, 1)) 56 | else: 57 | raise ValueError('Unsupported `stride`.') 58 | return x 59 | 60 | def forward(self, x): 61 | out = self.relu_0(self.batchnorm_0(x)) 62 | shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x 63 | out = self.conv_2d_1(self._pad(out)) 64 | out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out))) 65 | return out + shortcut 66 | 67 | 68 | class PreActResNet(nn.Module): 69 | """ 70 | PreActResNet model 71 | Arguments: 72 | num_classes (int): number of output 
class PreActResNet(nn.Module):
    """
    PreActResNet model

    Arguments:
        num_classes (int): number of output classes.
        depth (int): number of layers (18 or 34).
        width (int): width factor; must be 0, kept for constructor consistency.
        activation_fn (nn.Module): activation function.
        mean (tuple): mean of dataset (used for built-in input normalization).
        std (tuple): standard deviation of dataset.
        padding (int): input padding.
        num_input_channels (int): number of channels in the input.
    """

    def __init__(self,
                 num_classes: int = 10,
                 depth: int = 18,
                 width: int = 0,  # Used to make the constructor consistent.
                 activation_fn: nn.Module = nn.ReLU,
                 mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
                 std: Union[Tuple[float, ...], float] = CIFAR10_STD,
                 padding: int = 0,
                 num_input_channels: int = 3):
        super().__init__()
        if width != 0:
            raise ValueError('Unsupported `width`.')
        # Normalization statistics are kept as plain tensors; CUDA copies are
        # created lazily on the first GPU forward pass.
        self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
        self.std = torch.tensor(std).view(num_input_channels, 1, 1)
        self.mean_cuda = None
        self.std_cuda = None
        self.padding = padding
        self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
                                 padding=1, bias=False)
        if depth == 18:
            num_blocks = (2, 2, 2, 2)
        elif depth == 34:
            num_blocks = (3, 4, 6, 3)
        else:
            raise ValueError('Unsupported `depth`.')
        self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
        self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
        self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
        self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
        self.batchnorm = nn.BatchNorm2d(512, momentum=0.01)
        self.relu = activation_fn()
        self.logits = nn.Linear(512, num_classes)

    def _make_layer(self, in_planes, out_planes, num_blocks, stride,
                    activation_fn):
        # Only the first block of a stage may change resolution/width.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, s in enumerate(strides):
            layers.append(_PreActBlock(in_planes if i == 0 else out_planes,
                                       out_planes, s, activation_fn))
        return nn.Sequential(*layers)

    def forward(self, x):
        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4)
        # Normalize inputs; cache CUDA copies of the statistics on demand.
        if x.is_cuda:
            if self.mean_cuda is None:
                self.mean_cuda = self.mean.cuda()
                self.std_cuda = self.std.cuda()
            out = (x - self.mean_cuda) / self.std_cuda
        else:
            out = (x - self.mean) / self.std
        out = self.conv_2d(out)
        out = self.layer_0(out)
        out = self.layer_1(out)
        out = self.layer_2(out)
        out = self.layer_3(out)
        out = self.relu(self.batchnorm(out))
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.logits(out)
class Normalization(nn.Module):
    """
    Standardizes the input data: returns (input - mean) / std.

    Arguments:
        mean (list): per-channel mean.
        std (list): per-channel standard deviation.
    """
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        channels = len(mean)
        # Stored as plain tensors; CUDA copies are cached lazily in forward.
        self.mean = torch.FloatTensor(mean).view(1, channels, 1, 1)
        self.sigma = torch.FloatTensor(std).view(1, channels, 1, 1)
        self.mean_cuda, self.sigma_cuda = None, None

    def forward(self, x):
        if not x.is_cuda:
            return (x - self.mean) / self.sigma
        if self.mean_cuda is None:
            self.mean_cuda = self.mean.cuda()
            self.sigma_cuda = self.sigma.cuda()
        return (x - self.mean_cuda) / self.sigma_cuda
class BasicBlock(nn.Module):
    """
    Implements a basic block module for Resnets (post-activation variant).

    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of the first convolution.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        width = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Identity by default; projection (conv + bn) only when shapes change.
        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))
class Bottleneck(nn.Module):
    """
    Implements a basic block module with bottleneck for Resnets
    (post-activation variant).

    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of bottleneck filters (output is expansion * planes).
        stride (int): stride of the middle 3x3 convolution.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        # Identity by default; projection (conv + bn) only when shapes change.
        if stride != 1 or in_planes != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + self.shortcut(x))
class ResNet(nn.Module):
    """
    ResNet model

    Arguments:
        block (BasicBlock or Bottleneck): type of basic block to be used.
        num_blocks (list): number of blocks in each sub-module.
        num_classes (int): number of output classes.
        device (torch.device or str): accepted for API consistency
            (not used inside the constructor).
    """
    def __init__(self, block, num_blocks, num_classes=10, device='cpu'):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; `in_planes` is
        # threaded through so consecutive stages chain correctly.
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)
def resnet(name, num_classes=10, pretrained=False, device='cpu'):
    """
    Returns suitable Resnet model from its name.

    Arguments:
        name (str): one of 'resnet18', 'resnet34', 'resnet50' or 'resnet101'.
        num_classes (int): number of target classes.
        pretrained (bool): whether to use a pretrained model.
            NOTE(review): currently unused -- no pretrained weights are loaded.
        device (str or torch.device): device to work on.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if `name` is not a supported architecture.
    """
    # Dispatch table replaces the if/elif chain; the previous version also
    # had an unreachable bare `return` after the raise, now removed.
    configs = {
        'resnet18': (BasicBlock, [2, 2, 2, 2]),
        'resnet34': (BasicBlock, [3, 4, 6, 3]),
        'resnet50': (Bottleneck, [3, 4, 6, 3]),
        'resnet101': (Bottleneck, [3, 4, 23, 3]),
    }
    if name not in configs:
        raise ValueError('Only resnet18, resnet34, resnet50 and resnet101 are supported!')
    block, num_blocks = configs[name]
    return ResNet(block, num_blocks, num_classes=num_classes, device=device)
class PreActBlock(nn.Module):
    """
    Pre-activation version of the BasicBlock.

    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of output filters.
        stride (int): stride of the first convolution.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # Shortcut projection created only when the identity cannot be reused.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        branch = self.conv1(activated)
        branch = self.conv2(F.relu(self.bn2(branch)))
        # A projection shortcut consumes the pre-activated input; the plain
        # identity path carries the raw input.
        if hasattr(self, 'shortcut'):
            return branch + self.shortcut(activated)
        return branch + x
class PreActBottleneck(nn.Module):
    """
    Pre-activation version of the original Bottleneck module.

    Arguments:
        in_planes (int): number of input planes.
        planes (int): number of bottleneck filters (output is expansion * planes).
        stride (int): stride of the middle 3x3 convolution.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        # Shortcut projection created only when the identity cannot be reused.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        branch = self.conv1(activated)
        branch = self.conv2(F.relu(self.bn2(branch)))
        branch = self.conv3(F.relu(self.bn3(branch)))
        # Projection (if any) consumes the pre-activated input.
        if hasattr(self, 'shortcut'):
            return branch + self.shortcut(activated)
        return branch + x
class PreActResNet(nn.Module):
    """
    Pre-activation Resnet model for TI-200 dataset.

    Uses adaptive average pooling so the classifier works for the larger
    Tiny-ImageNet input resolution.
    """
    def __init__(self, block, num_blocks, num_classes=200):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = nn.BatchNorm2d(512 * block.expansion)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage may downsample; `in_planes` is
        # threaded through so consecutive stages chain correctly.
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.relu(self.bn(out))
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.linear(out)
def ti_preact_resnet(name, num_classes=200, pretrained=False, device='cpu'):
    """
    Returns suitable PreAct Resnet model from its name (only for TI-200 dataset).

    Arguments:
        name (str): one of 'preact-resnet18', 'preact-resnet34',
            'preact-resnet50' or 'preact-resnet101'.
        num_classes (int): number of target classes.
        pretrained (bool): whether to use a pretrained model.
            NOTE(review): currently unused -- no pretrained weights are loaded.
        device (str or torch.device): device to work on.
            NOTE(review): currently unused here; the caller moves the model.
    Returns:
        torch.nn.Module.
    Raises:
        ValueError: if `name` is not a supported architecture.
    """
    # Dispatch table replaces the if/elif chain; the previous version also
    # had an unreachable bare `return` after the raise, now removed.
    configs = {
        'preact-resnet18': (PreActBlock, [2, 2, 2, 2]),
        'preact-resnet34': (PreActBlock, [3, 4, 6, 3]),
        'preact-resnet50': (PreActBottleneck, [3, 4, 6, 3]),
        'preact-resnet101': (PreActBottleneck, [3, 4, 23, 3]),
    }
    if name not in configs:
        raise ValueError('Only preact-resnet18, preact-resnet34, preact-resnet50 and preact-resnet101 are supported!')
    block, num_blocks = configs[name]
    return PreActResNet(block, num_blocks, num_classes=num_classes)
class BasicBlock(nn.Module):
    """
    Implements a basic block module for WideResNets.

    Arguments:
        in_planes (int): number of input planes.
        out_planes (int): number of output filters.
        stride (int): stride of the first convolution.
        dropRate (float): dropout rate.
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection only when the channel count changes.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        # When channels differ, the pre-activation feeds *both* the main
        # branch and the projection shortcut; otherwise the raw input is
        # the residual.
        if self.equalInOut:
            pre = self.relu1(self.bn1(x))
            residual = x
        else:
            x = self.relu1(self.bn1(x))
            pre = x
            residual = self.convShortcut(x)
        out = self.relu2(self.bn2(self.conv1(pre)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        return residual + self.conv2(out)
53 | """ 54 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 55 | super(NetworkBlock, self).__init__() 56 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 57 | 58 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 59 | layers = [] 60 | for i in range(int(nb_layers)): 61 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | return self.layer(x) 66 | 67 | 68 | class WideResNet(nn.Module): 69 | """ 70 | WideResNet model 71 | Arguments: 72 | depth (int): number of layers. 73 | num_classes (int): number of output classes. 74 | widen_factor (int): width factor. 75 | dropRate (float): dropout rate. 76 | """ 77 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 78 | super(WideResNet, self).__init__() 79 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 80 | assert ((depth - 4) % 6 == 0) 81 | n = (depth - 4) / 6 82 | block = BasicBlock 83 | # 1st conv before any network block 84 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 85 | padding=1, bias=False) 86 | # 1st block 87 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 88 | # 2nd block 89 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 90 | # 3rd block 91 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 92 | # global average pooling and classifier 93 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.fc = nn.Linear(nChannels[3], num_classes) 96 | self.nChannels = nChannels[3] 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2. 
def wideresnet(name, num_classes=10, device='cpu'):
    """
    Returns suitable Wideresnet model from its name.

    Arguments:
        name (str): architecture name in the form 'wrn-<depth>-<width>'.
        num_classes (int): number of target classes.
        device (str or torch.device): accepted for API consistency
            (not used inside this factory).
    Returns:
        torch.nn.Module.
    """
    parts = name.split('-')
    depth, widen_factor = int(parts[1]), int(parts[2])
    return WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen_factor)
30 | activation_fn (nn.Module): activation function. 31 | """ 32 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 33 | super().__init__() 34 | self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) 35 | self.relu_0 = activation_fn(inplace=True) 36 | # We manually pad to obtain the same effect as `SAME` (necessary when `stride` is different than 1). 37 | self.conv_0 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 38 | padding=0, bias=False) 39 | self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) 40 | self.relu_1 = activation_fn(inplace=True) 41 | self.conv_1 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 42 | padding=1, bias=False) 43 | self.has_shortcut = in_planes != out_planes 44 | if self.has_shortcut: 45 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, 46 | stride=stride, padding=0, bias=False) 47 | else: 48 | self.shortcut = None 49 | self._stride = stride 50 | 51 | def forward(self, x): 52 | if self.has_shortcut: 53 | x = self.relu_0(self.batchnorm_0(x)) 54 | else: 55 | out = self.relu_0(self.batchnorm_0(x)) 56 | v = x if self.has_shortcut else out 57 | if self._stride == 1: 58 | v = F.pad(v, (1, 1, 1, 1)) 59 | elif self._stride == 2: 60 | v = F.pad(v, (0, 1, 0, 1)) 61 | else: 62 | raise ValueError('Unsupported `stride`.') 63 | out = self.conv_0(v) 64 | out = self.relu_1(self.batchnorm_1(out)) 65 | out = self.conv_1(out) 66 | out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) 67 | return out 68 | 69 | 70 | class _BlockGroup(nn.Module): 71 | """ 72 | WideResNet block group. 73 | Arguments: 74 | in_planes (int): number of input planes. 75 | out_planes (int): number of output filters. 76 | stride (int): stride of convolution. 77 | activation_fn (nn.Module): activation function. 
class _BlockGroup(nn.Module):
    """
    WideResNet block group.

    Arguments:
        num_blocks (int): number of blocks in the group.
        in_planes (int): number of input planes.
        out_planes (int): number of output filters.
        stride (int): stride of the first block.
        activation_fn (nn.Module): activation function.
    """
    def __init__(self, num_blocks, in_planes, out_planes, stride, activation_fn=nn.ReLU):
        super().__init__()
        # Only the first block changes width/resolution; the rest use stride 1.
        blocks = [
            _Block(in_planes if i == 0 else out_planes,
                   out_planes,
                   stride if i == 0 else 1,
                   activation_fn=activation_fn)
            for i in range(num_blocks)
        ]
        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)
class WideResNet(nn.Module):
    """
    WideResNet model

    Arguments:
        num_classes (int): number of output classes.
        depth (int): number of layers.
        width (int): width factor.
        activation_fn (nn.Module): activation function.
        mean (tuple): mean of dataset (used for built-in input normalization).
        std (tuple): standard deviation of dataset.
        padding (int): input padding.
        num_input_channels (int): number of channels in the input.
    """
    def __init__(self,
                 num_classes: int = 10,
                 depth: int = 28,
                 width: int = 10,
                 activation_fn: nn.Module = nn.ReLU,
                 mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
                 std: Union[Tuple[float, ...], float] = CIFAR10_STD,
                 padding: int = 0,
                 num_input_channels: int = 3):
        super().__init__()
        # Normalization statistics; CUDA copies are cached lazily in forward.
        self.mean = torch.tensor(mean).view(num_input_channels, 1, 1)
        self.std = torch.tensor(std).view(num_input_channels, 1, 1)
        self.mean_cuda = None
        self.std_cuda = None
        self.padding = padding
        num_channels = [16, 16 * width, 32 * width, 64 * width]
        assert (depth - 4) % 6 == 0
        num_blocks = (depth - 4) // 6
        self.init_conv = nn.Conv2d(num_input_channels, num_channels[0],
                                   kernel_size=3, stride=1, padding=1, bias=False)
        self.layer = nn.Sequential(
            _BlockGroup(num_blocks, num_channels[0], num_channels[1], 1,
                        activation_fn=activation_fn),
            _BlockGroup(num_blocks, num_channels[1], num_channels[2], 2,
                        activation_fn=activation_fn),
            _BlockGroup(num_blocks, num_channels[2], num_channels[3], 2,
                        activation_fn=activation_fn))
        self.batchnorm = nn.BatchNorm2d(num_channels[3], momentum=0.01)
        self.relu = activation_fn(inplace=True)
        self.logits = nn.Linear(num_channels[3], num_classes)
        self.num_channels = num_channels[3]
        self._init_weights()

    def _init_weights(self):
        # He-style normal init for convs; BN scaled to identity; zero FC bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4)
        # Normalize inputs; cache CUDA copies of the statistics on demand.
        if x.is_cuda:
            if self.mean_cuda is None:
                self.mean_cuda = self.mean.cuda()
                self.std_cuda = self.std.cuda()
            out = (x - self.mean_cuda) / self.std_cuda
        else:
            out = (x - self.mean) / self.std
        out = self.init_conv(out)
        out = self.layer(out)
        out = self.relu(self.batchnorm(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.num_channels)
        return self.logits(out)
179 | """ 180 | if 'cifar10' not in dataset: 181 | raise ValueError('WideResNets with Swish activation only support CIFAR-10 and CIFAR-100!') 182 | 183 | name_parts = name.split('-') 184 | depth = int(name_parts[1]) 185 | widen = int(name_parts[2]) 186 | act_fn = name_parts[3] 187 | 188 | print (f'WideResNet-{depth}-{widen}-{act_fn} uses normalization.') 189 | if 'cifar100' in dataset: 190 | return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn], 191 | mean=CIFAR100_MEAN, std=CIFAR100_STD) 192 | return WideResNet(num_classes=num_classes, depth=depth, width=widen, activation_fn=_ACTIVATION[act_fn]) -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | 3 | from .context import * 4 | 5 | from .logger import Logger 6 | 7 | from .train import SCHEDULERS 8 | from .train import Trainer 9 | 10 | from .parser import * -------------------------------------------------------------------------------- /core/utils/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | 4 | class ctx_noparamgrad(object): 5 | def __init__(self, module): 6 | self.prev_grad_state = get_param_grad_state(module) 7 | self.module = module 8 | set_param_grad_off(module) 9 | 10 | def __enter__(self): 11 | pass 12 | 13 | def __exit__(self, *args): 14 | set_param_grad_state(self.module, self.prev_grad_state) 15 | return False 16 | 17 | 18 | class ctx_eval(object): 19 | def __init__(self, module): 20 | self.prev_training_state = get_module_training_state(module) 21 | self.module = module 22 | set_module_training_off(module) 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, *args): 28 | set_module_training_state(self.module, self.prev_training_state) 29 | return False 30 | 31 | 32 | @contextmanager 
33 | def ctx_noparamgrad_and_eval(module): 34 | with ctx_noparamgrad(module) as a, ctx_eval(module) as b: 35 | yield (a, b) 36 | 37 | 38 | def get_module_training_state(module): 39 | return {mod: mod.training for mod in module.modules()} 40 | 41 | 42 | def set_module_training_state(module, training_state): 43 | for mod in module.modules(): 44 | mod.training = training_state[mod] 45 | 46 | 47 | def set_module_training_off(module): 48 | for mod in module.modules(): 49 | mod.training = False 50 | 51 | 52 | def get_param_grad_state(module): 53 | return {param: param.requires_grad for param in module.parameters()} 54 | 55 | 56 | def set_param_grad_state(module, grad_state): 57 | for param in module.parameters(): 58 | param.requires_grad = grad_state[param] 59 | 60 | 61 | def set_param_grad_off(module): 62 | for param in module.parameters(): 63 | param.requires_grad = False -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class Logger(object): 5 | """ 6 | Helper class for logging. 7 | Arguments: 8 | path (str): Path to log file. 
9 | """ 10 | def __init__(self, path): 11 | self.logger = logging.getLogger() 12 | self.path = path 13 | self.setup_file_logger() 14 | print ('Logging to file: ', self.path) 15 | 16 | def setup_file_logger(self): 17 | hdlr = logging.FileHandler(self.path, 'w+') 18 | self.logger.addHandler(hdlr) 19 | self.logger.setLevel(logging.INFO) 20 | 21 | def log(self, message): 22 | print (message) 23 | self.logger.info(message) 24 | -------------------------------------------------------------------------------- /core/utils/mart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from core.metrics import accuracy 7 | 8 | 9 | def mart_loss(model, x_natural, y, optimizer, step_size=0.007, epsilon=0.031, perturb_steps=10, beta=6.0, 10 | attack='linf-pgd'): 11 | """ 12 | MART training (Wang et al, 2020). 13 | """ 14 | 15 | kl = nn.KLDivLoss(reduction='none') 16 | model.eval() 17 | batch_size = len(x_natural) 18 | 19 | # generate adversarial example 20 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 21 | if attack == 'linf-pgd': 22 | for _ in range(perturb_steps): 23 | x_adv.requires_grad_() 24 | with torch.enable_grad(): 25 | loss_ce = F.cross_entropy(model(x_adv), y) 26 | grad = torch.autograd.grad(loss_ce, [x_adv])[0] 27 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 28 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 29 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 30 | else: 31 | raise ValueError(f'Attack={attack} not supported for MART training!') 32 | model.train() 33 | 34 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 35 | # zero gradient 36 | optimizer.zero_grad() 37 | 38 | logits = model(x_natural) 39 | logits_adv = model(x_adv) 40 | 41 | adv_probs = F.softmax(logits_adv, dim=1) 42 | tmp1 = 
torch.argsort(adv_probs, dim=1)[:, -2:] 43 | new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1]) 44 | loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y) 45 | 46 | nat_probs = F.softmax(logits, dim=1) 47 | true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze() 48 | 49 | loss_robust = (1.0 / batch_size) * torch.sum( 50 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 51 | loss = loss_adv + float(beta) * loss_robust 52 | 53 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits.detach()), 54 | 'adversarial_acc': accuracy(y, logits_adv.detach())} 55 | 56 | return loss, batch_metrics 57 | -------------------------------------------------------------------------------- /core/utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from core.attacks import ATTACKS 4 | from core.data import DATASETS 5 | from core.models import MODELS 6 | from .train import SCHEDULERS 7 | 8 | from .utils import str2bool, str2float 9 | 10 | 11 | def parser_train(): 12 | """ 13 | Parse input arguments (train.py). 
14 | """ 15 | parser = argparse.ArgumentParser(description='Standard + Adversarial Training.') 16 | 17 | parser.add_argument('--augment', type=str2bool, default=True, help='Augment training set.') 18 | parser.add_argument('--batch-size', type=int, default=1024, help='Batch size for training.') 19 | parser.add_argument('--batch-size-validation', type=int, default=512, help='Batch size for testing.') 20 | 21 | parser.add_argument('--data-dir', type=str, default='/cluster/home/rarade/data/') 22 | parser.add_argument('--log-dir', type=str, default='/cluster/scratch/rarade/test/') 23 | 24 | parser.add_argument('-d', '--data', type=str, default='cifar10s', choices=DATASETS, help='Data to use.') 25 | parser.add_argument('--desc', type=str, required=True, 26 | help='Description of experiment. It will be used to name directories.') 27 | 28 | parser.add_argument('-m', '--model', choices=MODELS, default='wrn-28-10-swish', help='Model architecture to be used.') 29 | parser.add_argument('--normalize', type=str2bool, default=False, help='Normalize input.') 30 | parser.add_argument('--pretrained-file', type=str, default=None, help='Pretrained weights file name.') 31 | 32 | parser.add_argument('-na', '--num-adv-epochs', type=int, default=400, help='Number of adversarial training epochs.') 33 | parser.add_argument('--adv-eval-freq', type=int, default=25, help='Adversarial evaluation frequency (in epochs).') 34 | 35 | parser.add_argument('--beta', default=None, type=float, help='Stability regularization, i.e., 1/lambda in TRADES.') 36 | 37 | parser.add_argument('--lr', type=float, default=0.4, help='Learning rate for optimizer (SGD).') 38 | parser.add_argument('--weight-decay', type=float, default=5e-4, help='Optimizer (SGD) weight decay.') 39 | parser.add_argument('--scheduler', choices=SCHEDULERS, default='cosinew', help='Type of scheduler.') 40 | parser.add_argument('--nesterov', type=str2bool, default=True, help='Use Nesterov momentum.') 41 | parser.add_argument('--clip-grad', 
type=float, default=None, help='Gradient norm clipping.') 42 | 43 | parser.add_argument('-a', '--attack', type=str, choices=ATTACKS, default='linf-pgd', help='Type of attack.') 44 | parser.add_argument('--attack-eps', type=str2float, default=8/255, help='Epsilon for the attack.') 45 | parser.add_argument('--attack-step', type=str2float, default=2/255, help='Step size for PGD attack.') 46 | parser.add_argument('--attack-iter', type=int, default=10, help='Max. number of iterations (if any) for the attack.') 47 | parser.add_argument('--keep-clean', type=str2bool, default=False, help='Use clean samples during adversarial training.') 48 | 49 | parser.add_argument('--debug', action='store_true', default=False, 50 | help='Debug code. Run 1 epoch of training and evaluation.') 51 | parser.add_argument('--mart', action='store_true', default=False, help='MART training.') 52 | 53 | parser.add_argument('--unsup-fraction', type=float, default=0.7, help='Ratio of unlabelled data to labelled data.') 54 | parser.add_argument('--aux-data-filename', type=str, help='Path to additional Tiny Images data.', 55 | default='/cluster/scratch/rarade/cifar10s/ti_500K_pseudo_labeled.pickle') 56 | 57 | parser.add_argument('--seed', type=int, default=1, help='Random seed.') 58 | return parser 59 | 60 | 61 | def parser_eval(): 62 | """ 63 | Parse input arguments (eval-adv.py, eval-corr.py, eval-aa.py). 
64 | """ 65 | parser = argparse.ArgumentParser(description='Robustness evaluation.') 66 | 67 | parser.add_argument('--data-dir', type=str, default='/cluster/home/rarade/data/') 68 | parser.add_argument('--log-dir', type=str, default='/cluster/scratch/rarade/test/') 69 | 70 | parser.add_argument('--desc', type=str, required=True, help='Description of model to be evaluated.') 71 | parser.add_argument('--num-samples', type=int, default=1000, help='Number of test samples.') 72 | 73 | # eval-aa.py 74 | parser.add_argument('--train', action='store_true', default=False, help='Evaluate on training set.') 75 | parser.add_argument('-v', '--version', type=str, default='standard', choices=['custom', 'plus', 'standard'], 76 | help='Version of AA.') 77 | 78 | # eval-adv.py 79 | parser.add_argument('--source', type=str, default=None, help='Path to source model for black-box evaluation.') 80 | parser.add_argument('--wb', action='store_true', default=False, help='Perform white-box PGD evaluation.') 81 | 82 | # eval-rb.py 83 | parser.add_argument('--threat', type=str, default='corruptions', choices=['corruptions', 'Linf', 'L2'], 84 | help='Threat model for RobustBench evaluation.') 85 | 86 | parser.add_argument('--seed', type=int, default=1, help='Random seed.') 87 | return parser 88 | 89 | -------------------------------------------------------------------------------- /core/utils/rst.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | class CosineLR(torch.optim.lr_scheduler._LRScheduler): 7 | """ 8 | Cosine annealing LR schedule (used in Carmon et al, 2019). 
9 | """ 10 | def __init__(self, optimizer, max_lr, epochs, last_epoch=-1): 11 | self.max_lr = max_lr 12 | self.epochs = epochs 13 | self._reset() 14 | super(CosineLR, self).__init__(optimizer, last_epoch) 15 | 16 | def _reset(self): 17 | self.current_lr = self.max_lr 18 | self.current_epoch = 1 19 | 20 | def step(self): 21 | self.current_lr = self.max_lr * 0.5 * (1 + np.cos((self.current_epoch - 1) / self.epochs * np.pi)) 22 | for param_group in self.optimizer.param_groups: 23 | param_group['lr'] = self.current_lr 24 | self.current_epoch += 1 25 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 26 | 27 | def get_lr(self): 28 | return self.current_lr 29 | -------------------------------------------------------------------------------- /core/utils/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | 7 | from core.metrics import accuracy 8 | 9 | 10 | def squared_l2_norm(x): 11 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 12 | return (flattened ** 2).sum(1) 13 | 14 | 15 | def l2_norm(x): 16 | return squared_l2_norm(x).sqrt() 17 | 18 | 19 | def trades_loss(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, beta=1.0, 20 | attack='linf-pgd'): 21 | """ 22 | TRADES training (Zhang et al, 2019). 
23 | """ 24 | 25 | # define KL-loss 26 | criterion_kl = nn.KLDivLoss(reduction='sum') 27 | model.eval() 28 | batch_size = len(x_natural) 29 | # generate adversarial example 30 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 31 | p_natural = F.softmax(model(x_natural), dim=1) 32 | 33 | if attack == 'linf-pgd': 34 | for _ in range(perturb_steps): 35 | x_adv.requires_grad_() 36 | with torch.enable_grad(): 37 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural) 38 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 39 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 40 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 41 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 42 | 43 | elif attack == 'l2-pgd': 44 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 45 | delta = Variable(delta.data, requires_grad=True) 46 | 47 | # Setup optimizers 48 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 49 | 50 | for _ in range(perturb_steps): 51 | adv = x_natural + delta 52 | 53 | # optimize 54 | optimizer_delta.zero_grad() 55 | with torch.enable_grad(): 56 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural) 57 | loss.backward(retain_graph=True) 58 | # renorming gradient 59 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 60 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 61 | # avoid nan or inf if gradient is 0 62 | if (grad_norms == 0).any(): 63 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 64 | optimizer_delta.step() 65 | 66 | # projection 67 | delta.data.add_(x_natural) 68 | delta.data.clamp_(0, 1).sub_(x_natural) 69 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 70 | x_adv = Variable(x_natural + delta, requires_grad=False) 71 | else: 72 | raise ValueError(f'Attack={attack} not supported for TRADES training!') 73 | model.train() 74 | 75 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), 
requires_grad=False) 76 | 77 | optimizer.zero_grad() 78 | # calculate robust loss 79 | logits_natural = model(x_natural) 80 | logits_adv = model(x_adv) 81 | loss_natural = F.cross_entropy(logits_natural, y) 82 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 83 | F.softmax(logits_natural, dim=1)) 84 | loss = loss_natural + beta * loss_robust 85 | 86 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits_natural.detach()), 87 | 'adversarial_acc': accuracy(y, logits_adv.detach())} 88 | 89 | return loss, batch_metrics 90 | -------------------------------------------------------------------------------- /core/utils/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from tqdm import tqdm as tqdm 4 | 5 | import os 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from core.attacks import create_attack 12 | from core.metrics import accuracy 13 | from core.models import create_model 14 | 15 | from .context import ctx_noparamgrad_and_eval 16 | from .utils import seed 17 | 18 | from .mart import mart_loss 19 | from .rst import CosineLR 20 | from .trades import trades_loss 21 | 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | SCHEDULERS = ['cyclic', 'step', 'cosine', 'cosinew'] 26 | 27 | 28 | class Trainer(object): 29 | """ 30 | Helper class for training a deep neural network. 31 | Arguments: 32 | info (dict): dataset information. 33 | args (dict): input arguments. 
34 | """ 35 | def __init__(self, info, args): 36 | super(Trainer, self).__init__() 37 | 38 | seed(args.seed) 39 | self.model = create_model(args.model, args.normalize, info, device) 40 | 41 | self.params = args 42 | self.criterion = nn.CrossEntropyLoss() 43 | self.init_optimizer(self.params.num_adv_epochs) 44 | 45 | if self.params.pretrained_file is not None: 46 | self.load_model(os.path.join(self.params.log_dir, self.params.pretrained_file, 'weights-best.pt')) 47 | 48 | self.attack, self.eval_attack = self.init_attack(self.model, self.criterion, self.params.attack, self.params.attack_eps, 49 | self.params.attack_iter, self.params.attack_step) 50 | 51 | 52 | @staticmethod 53 | def init_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step): 54 | """ 55 | Initialize adversary. 56 | """ 57 | attack = create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform') 58 | if attack_type in ['linf-pgd', 'l2-pgd']: 59 | eval_attack = create_attack(model, criterion, attack_type, attack_eps, 2*attack_iter, attack_step) 60 | elif attack_type in ['fgsm', 'linf-df']: 61 | eval_attack = create_attack(model, criterion, 'linf-pgd', 8/255, 20, 2/255) 62 | elif attack_type in ['fgm', 'l2-df']: 63 | eval_attack = create_attack(model, criterion, 'l2-pgd', 128/255, 20, 15/255) 64 | return attack, eval_attack 65 | 66 | 67 | def init_optimizer(self, num_epochs): 68 | """ 69 | Initialize optimizer and scheduler. 70 | """ 71 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay, 72 | momentum=0.9, nesterov=self.params.nesterov) 73 | if num_epochs <= 0: 74 | return 75 | self.init_scheduler(num_epochs) 76 | 77 | 78 | def init_scheduler(self, num_epochs): 79 | """ 80 | Initialize scheduler. 
81 | """ 82 | if self.params.scheduler == 'cyclic': 83 | num_samples = 50000 if 'cifar10' in self.params.data else 73257 84 | num_samples = 100000 if 'tiny-imagenet' in self.params.data else num_samples 85 | update_steps = int(np.floor(num_samples/self.params.batch_size) + 1) 86 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.25, 87 | steps_per_epoch=update_steps, epochs=int(num_epochs)) 88 | elif self.params.scheduler == 'step': 89 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, gamma=0.1, milestones=[100, 105]) 90 | elif self.params.scheduler == 'cosine': 91 | self.scheduler = CosineLR(self.optimizer, max_lr=self.params.lr, epochs=int(num_epochs)) 92 | elif self.params.scheduler == 'cosinew': 93 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=self.params.lr, pct_start=0.025, 94 | total_steps=int(num_epochs)) 95 | else: 96 | self.scheduler = None 97 | 98 | 99 | def train(self, dataloader, epoch=0, adversarial=False, verbose=True): 100 | """ 101 | Run one epoch of training. 
102 | """ 103 | metrics = pd.DataFrame() 104 | self.model.train() 105 | 106 | for data in tqdm(dataloader, desc='Epoch {}: '.format(epoch), disable=not verbose): 107 | x, y = data 108 | x, y = x.to(device), y.to(device) 109 | 110 | if adversarial: 111 | if self.params.beta is not None and self.params.mart: 112 | loss, batch_metrics = self.mart_loss(x, y, beta=self.params.beta) 113 | elif self.params.beta is not None: 114 | loss, batch_metrics = self.trades_loss(x, y, beta=self.params.beta) 115 | else: 116 | loss, batch_metrics = self.adversarial_loss(x, y) 117 | else: 118 | loss, batch_metrics = self.standard_loss(x, y) 119 | 120 | loss.backward() 121 | if self.params.clip_grad: 122 | nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_grad) 123 | self.optimizer.step() 124 | if self.params.scheduler in ['cyclic']: 125 | self.scheduler.step() 126 | 127 | metrics = metrics.append(pd.DataFrame(batch_metrics, index=[0]), ignore_index=True) 128 | 129 | if self.params.scheduler in ['step', 'converge', 'cosine', 'cosinew']: 130 | self.scheduler.step() 131 | return dict(metrics.mean()) 132 | 133 | 134 | def standard_loss(self, x, y): 135 | """ 136 | Standard training. 137 | """ 138 | self.optimizer.zero_grad() 139 | out = self.model(x) 140 | loss = self.criterion(out, y) 141 | 142 | preds = out.detach() 143 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, preds)} 144 | return loss, batch_metrics 145 | 146 | 147 | def adversarial_loss(self, x, y): 148 | """ 149 | Adversarial training (Madry et al, 2017). 
150 | """ 151 | with ctx_noparamgrad_and_eval(self.model): 152 | x_adv, _ = self.attack.perturb(x, y) 153 | 154 | self.optimizer.zero_grad() 155 | if self.params.keep_clean: 156 | x_adv = torch.cat((x, x_adv), dim=0) 157 | y_adv = torch.cat((y, y), dim=0) 158 | else: 159 | y_adv = y 160 | out = self.model(x_adv) 161 | loss = self.criterion(out, y_adv) 162 | 163 | preds = out.detach() 164 | batch_metrics = {'loss': loss.item()} 165 | if self.params.keep_clean: 166 | preds_clean, preds_adv = preds[:len(x)], preds[len(x):] 167 | batch_metrics.update({'clean_acc': accuracy(y, preds_clean), 'adversarial_acc': accuracy(y, preds_adv)}) 168 | else: 169 | batch_metrics.update({'adversarial_acc': accuracy(y, preds)}) 170 | return loss, batch_metrics 171 | 172 | 173 | def trades_loss(self, x, y, beta): 174 | """ 175 | TRADES training. 176 | """ 177 | loss, batch_metrics = trades_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 178 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 179 | beta=beta, attack=self.params.attack) 180 | return loss, batch_metrics 181 | 182 | 183 | def mart_loss(self, x, y, beta): 184 | """ 185 | MART training. 186 | """ 187 | loss, batch_metrics = mart_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 188 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 189 | beta=beta, attack=self.params.attack) 190 | return loss, batch_metrics 191 | 192 | 193 | def eval(self, dataloader, adversarial=False): 194 | """ 195 | Evaluate performance of the model. 
196 | """ 197 | acc = 0.0 198 | self.model.eval() 199 | 200 | for x, y in dataloader: 201 | x, y = x.to(device), y.to(device) 202 | if adversarial: 203 | with ctx_noparamgrad_and_eval(self.model): 204 | x_adv, _ = self.eval_attack.perturb(x, y) 205 | out = self.model(x_adv) 206 | else: 207 | out = self.model(x) 208 | acc += accuracy(y, out) 209 | acc /= len(dataloader) 210 | return acc 211 | 212 | 213 | def save_model(self, path): 214 | """ 215 | Save model weights. 216 | """ 217 | torch.save({'model_state_dict': self.model.state_dict()}, path) 218 | 219 | 220 | def load_model(self, path, load_opt=True): 221 | """ 222 | Load model weights. 223 | """ 224 | checkpoint = torch.load(path) 225 | if 'model_state_dict' not in checkpoint: 226 | raise RuntimeError('Model weights not found at {}.'.format(path)) 227 | self.model.load_state_dict(checkpoint['model_state_dict']) 228 | -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import numpy as np 5 | import _pickle as pickle 6 | 7 | import torch 8 | 9 | 10 | class SmoothCrossEntropyLoss(torch.nn.Module): 11 | """ 12 | Cross entropy loss with label smoothing. 
13 | """ 14 | def __init__(self, smoothing=0.0, reduction='mean'): 15 | super(SmoothCrossEntropyLoss, self).__init__() 16 | self.smoothing = smoothing 17 | self.confidence = 1.0 - smoothing 18 | self.reduction = reduction 19 | 20 | def forward(self, x, target): 21 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | if self.reduction == 'mean': 27 | return loss.mean() 28 | elif self.reduction == 'sum': 29 | return loss.sum() 30 | return loss 31 | 32 | 33 | def track_bn_stats(model, track_stats=True): 34 | """ 35 | If track_stats=False, do not update BN running mean and variance and vice versa. 36 | """ 37 | for module in model.modules(): 38 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 39 | module.track_running_stats = track_stats 40 | 41 | 42 | def set_bn_momentum(model, momentum=1): 43 | """ 44 | Set the value of momentum for all BN layers. 45 | """ 46 | for module in model.modules(): 47 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 48 | module.momentum = momentum 49 | 50 | 51 | def str2bool(v): 52 | """ 53 | Parse boolean using argument parser. 54 | """ 55 | if isinstance(v, bool): 56 | return v 57 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 58 | return True 59 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 60 | return False 61 | else: 62 | raise argparse.ArgumentTypeError('Boolean value expected.') 63 | 64 | def str2float(x): 65 | """ 66 | Parse float and fractions using argument parser. 67 | """ 68 | if '/' in x: 69 | n, d = x.split('/') 70 | return float(n)/float(d) 71 | else: 72 | try: 73 | return float(x) 74 | except: 75 | raise argparse.ArgumentTypeError('Fraction or float value expected.') 76 | 77 | 78 | def format_time(elapsed): 79 | """ 80 | Format time for displaying. 
81 | Arguments: 82 | elapsed: time interval in seconds. 83 | """ 84 | elapsed_rounded = int(round((elapsed))) 85 | return str(datetime.timedelta(seconds=elapsed_rounded)) 86 | 87 | 88 | def seed(seed=1): 89 | """ 90 | Seed for PyTorch reproducibility. 91 | Arguments: 92 | seed (int): Random seed value. 93 | """ 94 | np.random.seed(seed) 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed_all(seed) 97 | 98 | 99 | def unpickle_data(filename, mode='rb'): 100 | """ 101 | Read data from pickled file. 102 | Arguments: 103 | filename (str): path to the pickled file. 104 | mode (str): read mode. 105 | """ 106 | with open(filename, mode) as pkfile: 107 | data = pickle.load(pkfile) 108 | return data 109 | 110 | 111 | def pickle_data(data, filename, mode='wb'): 112 | """ 113 | Write data to pickled file. 114 | Arguments: 115 | data (Any): data to be written. 116 | filename (str): path to the pickled file. 117 | mode (str): write mode. 118 | """ 119 | with open(filename, mode) as pkfile: 120 | pickle.dump(data, pkfile) 121 | 122 | 123 | class NumpyToTensor(object): 124 | """ 125 | Transforms a numpy.ndarray to torch.Tensor. 126 | """ 127 | def __call__(self, sample): 128 | return torch.from_numpy(sample) -------------------------------------------------------------------------------- /eval-aa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation with AutoAttack. 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from autoattack import AutoAttack 18 | 19 | from core.data import get_data_info 20 | from core.data import load_data 21 | from core.models import create_model 22 | 23 | from core.utils import Logger 24 | from core.utils import parser_eval 25 | from core.utils import seed 26 | 27 | 28 | 29 | # Setup 30 | 31 | parse = parser_eval() 32 | args = parse.parse_args() 33 | 34 | LOG_DIR = args.log_dir + args.desc 35 | with open(LOG_DIR+'/args.txt', 'r') as f: 36 | old = json.load(f) 37 | args.__dict__ = dict(vars(args), **old) 38 | 39 | DATA_DIR = args.data_dir + args.data + '/' 40 | LOG_DIR = args.log_dir + args.desc 41 | WEIGHTS = LOG_DIR + '/weights-best.pt' 42 | 43 | log_path = LOG_DIR + '/log-aa.log' 44 | logger = Logger(log_path) 45 | 46 | info = get_data_info(DATA_DIR) 47 | BATCH_SIZE = args.batch_size 48 | BATCH_SIZE_VALIDATION = args.batch_size_validation 49 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | 51 | logger.log('Using device: {}'.format(device)) 52 | 53 | 54 | 55 | # Load data 56 | 57 | seed(args.seed) 58 | _, _, train_dataloader, test_dataloader = load_data(DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=False, 59 | shuffle_train=False) 60 | 61 | if args.train: 62 | logger.log('Evaluating on training set.') 63 | l = [x for (x, y) in train_dataloader] 64 | x_test = torch.cat(l, 0) 65 | l = [y for (x, y) in train_dataloader] 66 | y_test = torch.cat(l, 0) 67 | else: 68 | l = [x for (x, y) in test_dataloader] 69 | x_test = torch.cat(l, 0) 70 | l = [y for (x, y) in test_dataloader] 71 | y_test = torch.cat(l, 0) 72 | 73 | 74 | 75 | # Model 76 | 77 | model = create_model(args.model, args.normalize, info, device) 78 | checkpoint = torch.load(WEIGHTS) 79 | if 'tau' in args and args.tau: 80 | print ('Using WA 
model.') 81 | model.load_state_dict(checkpoint['model_state_dict']) 82 | model.eval() 83 | del checkpoint 84 | 85 | 86 | 87 | # AA Evaluation 88 | 89 | seed(args.seed) 90 | norm = 'Linf' if args.attack in ['fgsm', 'linf-pgd', 'linf-df'] else 'L2' 91 | adversary = AutoAttack(model, norm=norm, eps=args.attack_eps, log_path=log_path, version=args.version, seed=args.seed) 92 | 93 | if args.version == 'custom': 94 | adversary.attacks_to_run = ['apgd-ce', 'apgd-t'] 95 | adversary.apgd.n_restarts = 1 96 | adversary.apgd_targeted.n_restarts = 1 97 | 98 | with torch.no_grad(): 99 | x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=BATCH_SIZE_VALIDATION) 100 | 101 | print ('Script Completed.') -------------------------------------------------------------------------------- /eval-adv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Evaluation with PGD+, CW (Margin) PGD and black box adversary. 3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | from tqdm import tqdm as tqdm 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from core.attacks import create_attack 19 | from core.attacks import CWLoss 20 | 21 | from core.data import get_data_info 22 | from core.data import load_data 23 | 24 | from core.models import create_model 25 | 26 | from core.utils import ctx_noparamgrad_and_eval 27 | from core.utils import Logger 28 | from core.utils import parser_eval 29 | from core.utils import seed 30 | from core.utils import Trainer 31 | 32 | 33 | 34 | # Setup 35 | 36 | parse = parser_eval() 37 | args = parse.parse_args() 38 | 39 | LOG_DIR = args.log_dir + args.desc 40 | with open(LOG_DIR+'/args.txt', 'r') as f: 41 | old = json.load(f) 42 | args.__dict__ = dict(vars(args), **old) 43 | 44 | DATA_DIR = args.data_dir + args.data + '/' 45 | LOG_DIR = args.log_dir + args.desc 46 | WEIGHTS = LOG_DIR + 
WEIGHTS = LOG_DIR + '/weights-best.pt'

logger = Logger(LOG_DIR+'/log-adv.log')

info = get_data_info(DATA_DIR)
BATCH_SIZE = args.batch_size
BATCH_SIZE_VALIDATION = args.batch_size_validation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logger.log('Using device: {}'.format(device))



# Load data

seed(args.seed)
_, _, _, test_dataloader = load_data(DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=False, 
                                     shuffle_train=False)



# Helper functions

def _eval_restarts(perturb_fn, model, dataloader, num_restarts, verbose):
    """
    Shared driver for multi-restart adversarial evaluation.
    A sample counts as robustly-correct only if it is classified correctly
    under EVERY restart of the attack.
    Arguments:
        perturb_fn (callable): maps (x, y) to the adversarial batch x_adv.
        model (torch.nn.Module): model to evaluate (set to eval mode here).
        dataloader (torch.utils.data.DataLoader): evaluation data.
        num_restarts (int): number of independent attack restarts.
        verbose (bool): show progress bars.
    Returns:
        adversarial accuracy (float in [0, 1]).
    """
    model.eval()
    total = len(dataloader.dataset)
    is_correct = torch.ones(total, dtype=torch.bool, device=device)
    for _ in tqdm(range(num_restarts), disable=not verbose):
        batch_correct = []
        for x, y in tqdm(dataloader):
            x, y = x.to(device), y.to(device)
            with ctx_noparamgrad_and_eval(model):
                x_adv = perturb_fn(x, y)
            out = model(x_adv)
            # argmax of logits == argmax of softmax, so the softmax is unnecessary;
            # keeping the per-batch results as tensors avoids a CPU round-trip
            # through torch.BoolTensor(list) as the original code did.
            batch_correct.append(out.argmax(dim=1) == y)
        is_correct = torch.logical_and(is_correct, torch.cat(batch_correct))
    return (is_correct.sum().float() / total).item()


def eval_multiple_restarts(attack, model, dataloader, num_restarts=5, verbose=True):
    """
    Evaluate adversarial accuracy with multiple restarts.
    `attack.perturb` follows the core.attacks interface and returns an
    (x_adv, extra) tuple; only x_adv is used.
    """
    return _eval_restarts(lambda x, y: attack.perturb(x, y)[0], model, dataloader, num_restarts, verbose)


def eval_multiple_restarts_advertorch(attack, model, dataloader, num_restarts=1, verbose=True):
    """
    Evaluate adversarial accuracy with multiple restarts (Advertorch).
    Advertorch's `perturb` returns the adversarial batch directly.
    """
    return _eval_restarts(attack.perturb, model, dataloader, num_restarts, verbose)



# PGD Evaluation

seed(args.seed)
trainer = Trainer(info, args)
if 'tau' in args and args.tau:
    print('Using WA model.')
trainer.load_model(WEIGHTS)
trainer.model.eval()

test_acc = trainer.eval(test_dataloader)
logger.log('\nStandard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100))



if args.wb:
    # CW-PGD-40 Evaluation
    seed(args.seed)
    num_restarts = 1
    if args.attack in ['linf-pgd', 'linf-df', 'fgsm']:
        args.attack_iter, args.attack_step = 40, 0.01
    else:
        args.attack_iter, args.attack_step = 40, 30/255.0
    assert args.attack in ['linf-pgd', 'l2-pgd'], 'CW evaluation only supported for attack=linf-pgd or attack=l2-pgd !'
    attack = create_attack(trainer.model, CWLoss, args.attack, args.attack_eps, args.attack_iter, args.attack_step)
    logger.log('\n==== CW-PGD Evaluation. ====')
====') 136 | logger.log('Attack: cw-{}.'.format(args.attack)) 137 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 138 | args.attack_eps, 139 | args.attack_iter)) 140 | 141 | test_adv_acc1 = eval_multiple_restarts(attack, trainer.model, test_dataloader, num_restarts, verbose=False) 142 | logger.log('Adversarial Accuracy-\tTest: {:.2f}%.'.format(test_adv_acc1*100)) 143 | 144 | 145 | # PGD-40 (with 5 restarts) Evaluation 146 | seed(args.seed) 147 | num_restarts = 5 148 | if args.attack in ['linf-pgd', 'linf-df', 'fgsm']: 149 | args.attack_iter, args.attack_step = 40, 0.01 150 | else: 151 | args.attack_iter, args.attack_step = 40, 30/255.0 152 | attack = create_attack(trainer.model, trainer.criterion, args.attack, args.attack_eps, args.attack_iter, args.attack_step) 153 | logger.log('\n==== PGD+ Evaluation. ====') 154 | logger.log('Attack: {} with {} restarts.'.format(args.attack, num_restarts)) 155 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 156 | args.attack_eps, 157 | args.attack_iter)) 158 | 159 | test_adv_acc2 = eval_multiple_restarts(attack, trainer.model, test_dataloader, num_restarts, verbose=True) 160 | logger.log('Adversarial Accuracy-\tTest: {:.2f}%.'.format(test_adv_acc2*100)) 161 | 162 | 163 | 164 | # Black Box Evaluation 165 | 166 | class dotdict(dict): 167 | def __getattr__(self, name): 168 | return self[name] 169 | 170 | if args.source != None: 171 | seed(args.seed) 172 | assert args.attack in ['linf-pgd', 'l2-pgd'], 'Black-box evaluation only supported for attack=linf-pgd or attack=l2-pgd!' 
173 | if args.attack in ['linf-pgd', 'linf-df', 'fgsm']: 174 | args.attack_iter, args.attack_step = 40, 0.01 175 | else: 176 | args.attack_iter, args.attack_step = 40, 30/255.0 177 | 178 | SRC_LOG_DIR = args.log_dir + args.source 179 | with open(SRC_LOG_DIR+'/args.txt', 'r') as f: 180 | src_args = json.load(f) 181 | src_args = dotdict(src_args) 182 | 183 | src_model = create_model(src_args.model, src_args.normalize, info, device) 184 | src_model.load_state_dict(torch.load(SRC_LOG_DIR + '/weights-best.pt')['model_state_dict']) 185 | src_model.eval() 186 | 187 | src_attack = create_attack(src_model, trainer.criterion, args.attack, args.attack_eps, args.attack_iter, args.attack_step) 188 | adv_acc = 0.0 189 | for x, y in test_dataloader: 190 | x, y = x.to(device), y.to(device) 191 | with ctx_noparamgrad_and_eval(src_model): 192 | x_adv, _ = src_attack.perturb(x, y) 193 | out = trainer.model(x_adv) 194 | adv_acc += accuracy(y, out) 195 | adv_acc /= len(test_dataloader) 196 | 197 | logger.log('\n==== Black-box Evaluation. ====') 198 | logger.log('Source Model: {}.'.format(args.source)) 199 | logger.log('Attack: {}.'.format(args.attack)) 200 | logger.log('Attack Parameters: Step size: {:.3f}, Epsilon: {:.3f}, #Iterations: {}.'.format(args.attack_step, 201 | args.attack_eps, 202 | args.attack_iter)) 203 | logger.log('Black-box Adv. Accuracy-\tTest: {:.2f}%.'.format(adv_acc*100)) 204 | del src_attack, src_model 205 | 206 | 207 | logger.log('Script Completed.') -------------------------------------------------------------------------------- /eval-rb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation with Robustbench. 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from robustbench import benchmark 18 | 19 | from core.data import get_data_info 20 | from core.models import create_model 21 | 22 | from core.utils import Logger 23 | from core.utils import parser_eval 24 | from core.utils import seed 25 | 26 | 27 | 28 | # Setup 29 | 30 | parse = parser_eval() 31 | 32 | args = parse.parse_args() 33 | 34 | LOG_DIR = args.log_dir + args.desc 35 | with open(LOG_DIR+'/args.txt', 'r') as f: 36 | old = json.load(f) 37 | args.__dict__ = dict(vars(args), **old) 38 | 39 | args.data = 'cifar10' if args.data in ['cifar10s', 'cifar10g'] else args.data 40 | DATA_DIR = args.data_dir + args.data + '/' 41 | LOG_DIR = args.log_dir + args.desc 42 | WEIGHTS = LOG_DIR + '/weights-best.pt' 43 | 44 | log_path = LOG_DIR + f'/log-corr-{args.threat}.log' 45 | logger = Logger(log_path) 46 | 47 | info = get_data_info(DATA_DIR) 48 | BATCH_SIZE = args.batch_size 49 | BATCH_SIZE_VALIDATION = args.batch_size_validation 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | assert args.data in ['cifar10'], 'Evaluation on Robustbench is only supported for cifar10!' 
53 | 54 | threat_model = args.threat 55 | dataset = args.data 56 | model_name = args.desc 57 | 58 | 59 | 60 | # Model 61 | 62 | model = create_model(args.model, args.normalize, info, device) 63 | checkpoint = torch.load(WEIGHTS) 64 | if 'tau' in args and args.tau: 65 | print ('Using WA model.') 66 | model.load_state_dict(checkpoint['model_state_dict']) 67 | model.eval() 68 | del checkpoint 69 | 70 | 71 | 72 | # Common corruptions 73 | 74 | seed(args.seed) 75 | clean_acc, robust_acc = benchmark(model, model_name=model_name, n_examples=args.num_samples, dataset=dataset, 76 | threat_model=threat_model, eps=args.attack_eps, device=device, to_disk=False, 77 | data_dir=args.tmp_dir + args.data + 'c') 78 | 79 | 80 | logger.log('Model: {}'.format(args.desc)) 81 | logger.log('Evaluating robustness on {} with threat model={}.'.format(args.data, args.threat)) 82 | logger.log('Clean Accuracy: \t{:.2f}%.\nRobust Accuracy: \t{:.2f}%.'.format(clean_acc*100, robust_acc*100)) -------------------------------------------------------------------------------- /gowal21uncovering/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .watrain import WATrainer -------------------------------------------------------------------------------- /gowal21uncovering/utils/cutmix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def cutmix(images, labels, alpha=1.0, beta=1.0, num_classes=10): 8 | """ 9 | Apply CutMix to a batch of images. 10 | Arguments: 11 | image (torch.FloatTensor): images. 12 | labels (torch.LongTensor): target labels. 13 | alpha (float): parameter for cut ratio. 14 | beta (float): parameter for cut ratio. 15 | num_classes (int): number of target classes. 16 | Returns: 17 | augmented batch of images and labels. 
18 | """ 19 | batch_size, _, height, width = images.shape 20 | labels = F.one_hot(labels, num_classes) 21 | 22 | lam = np.random.beta(alpha, beta) 23 | cut_rat = np.sqrt(1. - lam) 24 | cut_w = np.array(width * cut_rat, dtype=np.int32) 25 | cut_h = np.array(height * cut_rat, dtype=np.int32) 26 | box_coords = _random_box(height, width, cut_h, cut_w) 27 | 28 | # Adjust lambda. 29 | lam = 1. - (box_coords[2] * box_coords[3] / (height * width)) 30 | idx = np.random.permutation(batch_size) 31 | 32 | def _cutmix(x, y): 33 | images_a = x 34 | images_b = x[idx, :, :, :] 35 | y = lam * y + (1. - lam) * y[idx, :] 36 | x = _compose_two_images(images_a, images_b, box_coords) 37 | return x, y 38 | 39 | return _cutmix(images, labels) 40 | 41 | 42 | def _random_box(height, width, cut_h, cut_w): 43 | """ 44 | Return a random box within the image size. 45 | """ 46 | minval_h = 0 47 | minval_w = 0 48 | maxval_h = height 49 | maxval_w = width 50 | 51 | i = np.random.randint(minval_h, maxval_h, dtype=np.int32) 52 | j = np.random.randint(minval_w, maxval_w, dtype=np.int32) 53 | bby1 = np.clip(i - cut_h // 2, 0, height) 54 | bbx1 = np.clip(j - cut_w // 2, 0, width) 55 | h = np.clip(i + cut_h // 2, 0, height) - bby1 56 | w = np.clip(j + cut_w // 2, 0, width) - bbx1 57 | return np.array([bby1, bbx1, h, w]) 58 | 59 | 60 | def _compose_two_images(images, image_permutation, bbox): 61 | """ 62 | Mix two images. 63 | """ 64 | def _single_compose_two_images(image1, image2): 65 | _, height, width = image1.shape 66 | mask = _window_mask(bbox, (height, width)) 67 | return image1 * (1. - mask) + image2 * mask 68 | 69 | new_images = [_single_compose_two_images(image1, image2) for image1, image2 in zip(images, image_permutation)] 70 | return torch.stack(new_images, dim=0) 71 | 72 | 73 | def _window_mask(destination_box, size): 74 | """ 75 | Compute window mask. 
76 | """ 77 | height_offset, width_offset, h, w = destination_box 78 | h_range = np.reshape(np.arange(size[0]), [1, size[0], 1]) 79 | w_range = np.reshape(np.arange(size[1]), [1, 1, size[1]]) 80 | return np.logical_and( 81 | np.logical_and(height_offset <= h_range, h_range < height_offset + h), 82 | np.logical_and(width_offset <= w_range, w_range < width_offset + w) 83 | ).astype(np.float32) -------------------------------------------------------------------------------- /gowal21uncovering/utils/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | 7 | from core.metrics import accuracy 8 | from core.utils import SmoothCrossEntropyLoss 9 | from core.utils import track_bn_stats 10 | 11 | 12 | def squared_l2_norm(x): 13 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 14 | return (flattened ** 2).sum(1) 15 | 16 | 17 | def l2_norm(x): 18 | return squared_l2_norm(x).sqrt() 19 | 20 | 21 | def trades_loss(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, beta=1.0, 22 | attack='linf-pgd', label_smoothing=0.1): 23 | """ 24 | TRADES training (Zhang et al, 2019). 
25 | """ 26 | 27 | criterion_ce = SmoothCrossEntropyLoss(reduction='mean', smoothing=label_smoothing) 28 | criterion_kl = nn.KLDivLoss(reduction='sum') 29 | model.train() 30 | track_bn_stats(model, False) 31 | batch_size = len(x_natural) 32 | 33 | x_adv = x_natural.detach() + torch.FloatTensor(x_natural.shape).uniform_(-epsilon, epsilon).cuda().detach() 34 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 35 | p_natural = F.softmax(model(x_natural), dim=1) 36 | 37 | if attack == 'linf-pgd': 38 | for _ in range(perturb_steps): 39 | x_adv.requires_grad_() 40 | with torch.enable_grad(): 41 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural) 42 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 43 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 44 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 45 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 46 | 47 | elif attack == 'l2-pgd': 48 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 49 | delta = Variable(delta.data, requires_grad=True) 50 | 51 | # Setup optimizers 52 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 53 | 54 | for _ in range(perturb_steps): 55 | adv = x_natural + delta 56 | 57 | # optimize 58 | optimizer_delta.zero_grad() 59 | with torch.enable_grad(): 60 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural) 61 | loss.backward(retain_graph=True) 62 | # renorming gradient 63 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 64 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 65 | # avoid nan or inf if gradient is 0 66 | if (grad_norms == 0).any(): 67 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 68 | optimizer_delta.step() 69 | 70 | # projection 71 | delta.data.add_(x_natural) 72 | delta.data.clamp_(0, 1).sub_(x_natural) 73 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 74 | x_adv = Variable(x_natural + delta, requires_grad=False) 75 | else: 76 | raise 
ValueError(f'Attack={attack} not supported for TRADES training!') 77 | model.train() 78 | track_bn_stats(model, True) 79 | 80 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 81 | 82 | optimizer.zero_grad() 83 | # calculate robust loss 84 | logits_natural = model(x_natural) 85 | logits_adv = model(x_adv) 86 | loss_natural = criterion_ce(logits_natural, y) 87 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 88 | F.softmax(logits_natural, dim=1)) 89 | loss = loss_natural + beta * loss_robust 90 | 91 | batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits_natural.detach()), 92 | 'adversarial_acc': accuracy(y, logits_adv.detach())} 93 | 94 | return loss, batch_metrics 95 | -------------------------------------------------------------------------------- /gowal21uncovering/utils/watrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from tqdm import tqdm as tqdm 4 | 5 | import copy 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from core.attacks import create_attack 12 | from core.attacks import CWLoss 13 | from core.metrics import accuracy 14 | from core.models import create_model 15 | 16 | from core.utils import ctx_noparamgrad_and_eval 17 | from core.utils import Trainer 18 | from core.utils import set_bn_momentum 19 | from core.utils import seed 20 | 21 | from .trades import trades_loss 22 | 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | 27 | class WATrainer(Trainer): 28 | """ 29 | Helper class for training a deep neural network with model weight averaging (identical to Gowal et al, 2020). 30 | Arguments: 31 | info (dict): dataset information. 32 | args (dict): input arguments. 
33 | """ 34 | def __init__(self, info, args): 35 | super(WATrainer, self).__init__(info, args) 36 | 37 | seed(args.seed) 38 | self.wa_model = copy.deepcopy(self.model) 39 | self.eval_attack = create_attack(self.wa_model, CWLoss, args.attack, args.attack_eps, 4*args.attack_iter, 40 | args.attack_step) 41 | num_samples = 50000 if 'cifar' in self.params.data else 73257 42 | num_samples = 100000 if 'tiny-imagenet' in self.params.data else num_samples 43 | self.update_steps = int(np.floor(num_samples/self.params.batch_size) + 1) 44 | self.warmup_steps = 0.025 * self.params.num_adv_epochs * self.update_steps 45 | 46 | 47 | def init_optimizer(self, num_epochs): 48 | """ 49 | Initialize optimizer and schedulers. 50 | """ 51 | def group_weight(model): 52 | group_decay = [] 53 | group_no_decay = [] 54 | for n, p in model.named_parameters(): 55 | if 'batchnorm' in n: 56 | group_no_decay.append(p) 57 | else: 58 | group_decay.append(p) 59 | assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay) 60 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)] 61 | return groups 62 | 63 | self.optimizer = torch.optim.SGD(group_weight(self.model), lr=self.params.lr, weight_decay=self.params.weight_decay, 64 | momentum=0.9, nesterov=self.params.nesterov) 65 | if num_epochs <= 0: 66 | return 67 | self.init_scheduler(num_epochs) 68 | 69 | 70 | def train(self, dataloader, epoch=0, adversarial=False, verbose=True): 71 | """ 72 | Run one epoch of training. 
73 | """ 74 | metrics = pd.DataFrame() 75 | self.model.train() 76 | 77 | update_iter = 0 78 | for data in tqdm(dataloader, desc='Epoch {}: '.format(epoch), disable=not verbose): 79 | global_step = (epoch - 1) * self.update_steps + update_iter 80 | if global_step == 0: 81 | # make BN running mean and variance init same as Haiku 82 | set_bn_momentum(self.model, momentum=1.0) 83 | elif global_step == 1: 84 | set_bn_momentum(self.model, momentum=0.01) 85 | update_iter += 1 86 | 87 | x, y = data 88 | x, y = x.to(device), y.to(device) 89 | 90 | if adversarial: 91 | if self.params.beta is not None and self.params.mart: 92 | loss, batch_metrics = self.mart_loss(x, y, beta=self.params.beta) 93 | elif self.params.beta is not None: 94 | loss, batch_metrics = self.trades_loss(x, y, beta=self.params.beta) 95 | else: 96 | loss, batch_metrics = self.adversarial_loss(x, y) 97 | else: 98 | loss, batch_metrics = self.standard_loss(x, y) 99 | 100 | loss.backward() 101 | if self.params.clip_grad: 102 | nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip_grad) 103 | self.optimizer.step() 104 | if self.params.scheduler in ['cyclic']: 105 | self.scheduler.step() 106 | 107 | global_step = (epoch - 1) * self.update_steps + update_iter 108 | ema_update(self.wa_model, self.model, global_step, decay_rate=self.params.tau, 109 | warmup_steps=self.warmup_steps, dynamic_decay=True) 110 | metrics = metrics.append(pd.DataFrame(batch_metrics, index=[0]), ignore_index=True) 111 | 112 | if self.params.scheduler in ['step', 'converge', 'cosine', 'cosinew']: 113 | self.scheduler.step() 114 | 115 | update_bn(self.wa_model, self.model) 116 | return dict(metrics.mean()) 117 | 118 | 119 | def trades_loss(self, x, y, beta): 120 | """ 121 | TRADES training. 
122 | """ 123 | loss, batch_metrics = trades_loss(self.model, x, y, self.optimizer, step_size=self.params.attack_step, 124 | epsilon=self.params.attack_eps, perturb_steps=self.params.attack_iter, 125 | beta=beta, attack=self.params.attack) 126 | return loss, batch_metrics 127 | 128 | 129 | def eval(self, dataloader, adversarial=False): 130 | """ 131 | Evaluate performance of the model. 132 | """ 133 | acc = 0.0 134 | self.wa_model.eval() 135 | 136 | for x, y in dataloader: 137 | x, y = x.to(device), y.to(device) 138 | if adversarial: 139 | with ctx_noparamgrad_and_eval(self.wa_model): 140 | x_adv, _ = self.eval_attack.perturb(x, y) 141 | out = self.wa_model(x_adv) 142 | else: 143 | out = self.wa_model(x) 144 | acc += accuracy(y, out) 145 | acc /= len(dataloader) 146 | return acc 147 | 148 | 149 | def save_model(self, path): 150 | """ 151 | Save model weights. 152 | """ 153 | torch.save({ 154 | 'model_state_dict': self.wa_model.state_dict(), 155 | 'unaveraged_model_state_dict': self.model.state_dict() 156 | }, path) 157 | 158 | 159 | def load_model(self, path): 160 | """ 161 | Load model weights. 162 | """ 163 | checkpoint = torch.load(path) 164 | if 'model_state_dict' not in checkpoint: 165 | raise RuntimeError('Model weights not found at {}.'.format(path)) 166 | self.wa_model.load_state_dict(checkpoint['model_state_dict']) 167 | 168 | 169 | def ema_update(wa_model, model, global_step, decay_rate=0.995, warmup_steps=0, dynamic_decay=True): 170 | """ 171 | Exponential model weight averaging update. 172 | """ 173 | factor = int(global_step >= warmup_steps) 174 | if dynamic_decay: 175 | delta = global_step - warmup_steps 176 | decay = min(decay_rate, (1. + delta) / (10. + delta)) if 10. 
+ delta != 0 else decay_rate 177 | else: 178 | decay = decay_rate 179 | decay *= factor 180 | 181 | for p_swa, p_model in zip(wa_model.parameters(), model.parameters()): 182 | p_swa.data *= decay 183 | p_swa.data += p_model.data * (1 - decay) 184 | 185 | 186 | @torch.no_grad() 187 | def update_bn(avg_model, model): 188 | """ 189 | Update batch normalization layers. 190 | """ 191 | avg_model.eval() 192 | model.eval() 193 | for module1, module2 in zip(avg_model.modules(), model.modules()): 194 | if isinstance(module1, torch.nn.modules.batchnorm._BatchNorm): 195 | module1.running_mean = module2.running_mean 196 | module1.running_var = module2.running_var 197 | module1.num_batches_tracked = module2.num_batches_tracked 198 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autoattack==0.1 2 | matplotlib==3.4.2 3 | numpy==1.19.5 4 | pandas==1.2.4 5 | Pillow==8.2.0 6 | robustbench==0.1 7 | scipy==1.4.1 8 | torch==1.8.0+cu101 9 | torchvision==0.9.0+cu101 10 | tqdm==4.60.0 -------------------------------------------------------------------------------- /train-wa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Training (with improvements from Gowal et al., 2020). 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from core.data import get_data_info 18 | from core.data import load_data 19 | from core.data import SEMISUP_DATASETS 20 | 21 | from core.utils import format_time 22 | from core.utils import Logger 23 | from core.utils import parser_train 24 | from core.utils import Trainer 25 | from core.utils import seed 26 | 27 | from gowal21uncovering.utils import WATrainer 28 | 29 | 30 | # Setup 31 | 32 | parse = parser_train() 33 | parse.add_argument('--tau', type=float, default=0.995, help='Weight averaging decay.') 34 | args = parse.parse_args() 35 | assert args.data in SEMISUP_DATASETS, f'Only data in {SEMISUP_DATASETS} is supported!' 36 | 37 | 38 | DATA_DIR = os.path.join(args.data_dir, args.data) 39 | LOG_DIR = os.path.join(args.log_dir, args.desc) 40 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 41 | if os.path.exists(LOG_DIR): 42 | shutil.rmtree(LOG_DIR) 43 | os.makedirs(LOG_DIR) 44 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 45 | 46 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 47 | json.dump(args.__dict__, f, indent=4) 48 | 49 | 50 | info = get_data_info(DATA_DIR) 51 | BATCH_SIZE = args.batch_size 52 | BATCH_SIZE_VALIDATION = args.batch_size_validation 53 | NUM_ADV_EPOCHS = args.num_adv_epochs 54 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 55 | logger.log('Using device: {}'.format(device)) 56 | if args.debug: 57 | NUM_ADV_EPOCHS = 1 58 | 59 | # To speed up training 60 | torch.backends.cudnn.benchmark = True 61 | 62 | 63 | 64 | # Load data 65 | 66 | seed(args.seed) 67 | train_dataset, test_dataset, eval_dataset, train_dataloader, test_dataloader, eval_dataloader = load_data( 68 | DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=args.augment, shuffle_train=True, 69 | 
aux_data_filename=args.aux_data_filename, unsup_fraction=args.unsup_fraction, validation=True 70 | ) 71 | del train_dataset, test_dataset, eval_dataset 72 | 73 | 74 | 75 | # Adversarial Training 76 | 77 | seed(args.seed) 78 | if args.tau: 79 | print ('Using WA.') 80 | trainer = WATrainer(info, args) 81 | else: 82 | trainer = Trainer(info, args) 83 | last_lr = args.lr 84 | 85 | if NUM_ADV_EPOCHS > 0: 86 | logger.log('\n\n') 87 | metrics = pd.DataFrame() 88 | logger.log('Standard Accuracy-\tTest: {:2f}%.'.format(trainer.eval(test_dataloader)*100)) 89 | 90 | old_score = [0.0, 0.0] 91 | logger.log('RST Adversarial training for {} epochs'.format(NUM_ADV_EPOCHS)) 92 | trainer.init_optimizer(args.num_adv_epochs) 93 | test_adv_acc = 0.0 94 | 95 | 96 | for epoch in range(1, NUM_ADV_EPOCHS+1): 97 | start = time.time() 98 | logger.log('======= Epoch {} ======='.format(epoch)) 99 | 100 | if args.scheduler: 101 | last_lr = trainer.scheduler.get_last_lr()[0] 102 | 103 | res = trainer.train(train_dataloader, epoch=epoch, adversarial=True) 104 | test_acc = trainer.eval(test_dataloader) 105 | 106 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 107 | if 'clean_acc' in res: 108 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100)) 109 | else: 110 | logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100)) 111 | epoch_metrics = {'train_'+k: v for k, v in res.items()} 112 | epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 113 | 114 | if epoch % args.adv_eval_freq == 0 or epoch == NUM_ADV_EPOCHS: 115 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 116 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, 117 | test_adv_acc*100)) 118 | epoch_metrics.update({'test_adversarial_acc': test_adv_acc}) 119 | else: 120 | logger.log('Adversarial Accuracy-\tTrain: 
{:.2f}%.'.format(res['adversarial_acc']*100)) 121 | eval_adv_acc = trainer.eval(eval_dataloader, adversarial=True) 122 | logger.log('Adversarial Accuracy-\tEval: {:.2f}%.'.format(eval_adv_acc*100)) 123 | epoch_metrics['eval_adversarial_acc'] = eval_adv_acc 124 | 125 | if eval_adv_acc >= old_score[1]: 126 | old_score[0], old_score[1] = test_acc, eval_adv_acc 127 | trainer.save_model(WEIGHTS) 128 | trainer.save_model(os.path.join(LOG_DIR, 'weights-last.pt')) 129 | 130 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 131 | metrics = metrics.append(pd.DataFrame(epoch_metrics, index=[0]), ignore_index=True) 132 | metrics.to_csv(os.path.join(LOG_DIR, 'stats_adv.csv'), index=False) 133 | 134 | 135 | 136 | # Record metrics 137 | 138 | train_acc = res['clean_acc'] if 'clean_acc' in res else trainer.eval(train_dataloader) 139 | logger.log('\nTraining completed.') 140 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(train_acc*100, old_score[0]*100)) 141 | if NUM_ADV_EPOCHS > 0: 142 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tEval: {:.2f}%.'.format(res['adversarial_acc']*100, old_score[1]*100)) 143 | 144 | logger.log('Script Completed.') 145 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Training. 
3 | """ 4 | 5 | import json 6 | import time 7 | import argparse 8 | import shutil 9 | 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from core.data import get_data_info 18 | from core.data import load_data 19 | 20 | from core.utils import format_time 21 | from core.utils import Logger 22 | from core.utils import parser_train 23 | from core.utils import Trainer 24 | from core.utils import seed 25 | 26 | 27 | 28 | # Setup 29 | 30 | parse = parser_train() 31 | args = parse.parse_args() 32 | 33 | 34 | DATA_DIR = os.path.join(args.data_dir, args.data) 35 | LOG_DIR = os.path.join(args.log_dir, args.desc) 36 | WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt') 37 | if os.path.exists(LOG_DIR): 38 | shutil.rmtree(LOG_DIR) 39 | os.makedirs(LOG_DIR) 40 | logger = Logger(os.path.join(LOG_DIR, 'log-train.log')) 41 | 42 | with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f: 43 | json.dump(args.__dict__, f, indent=4) 44 | 45 | 46 | info = get_data_info(DATA_DIR) 47 | BATCH_SIZE = args.batch_size 48 | BATCH_SIZE_VALIDATION = args.batch_size_validation 49 | NUM_ADV_EPOCHS = args.num_adv_epochs 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | logger.log('Using device: {}'.format(device)) 52 | if args.debug: 53 | NUM_ADV_EPOCHS = 1 54 | 55 | # To speed up training 56 | torch.backends.cudnn.benchmark = True 57 | 58 | 59 | 60 | # Load data 61 | 62 | seed(args.seed) 63 | train_dataset, test_dataset, train_dataloader, test_dataloader = load_data( 64 | DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=args.augment, shuffle_train=True, 65 | aux_data_filename=args.aux_data_filename, unsup_fraction=args.unsup_fraction 66 | ) 67 | del train_dataset, test_dataset 68 | 69 | 70 | 71 | # Adversarial Training (AT, TRADES and MART) 72 | 73 | seed(args.seed) 74 | trainer = Trainer(info, args) 75 | last_lr = args.lr 76 | 77 | 78 | if NUM_ADV_EPOCHS > 0: 79 | logger.log('\n\n') 80 
| metrics = pd.DataFrame() 81 | logger.log('Standard Accuracy-\tTest: {:2f}%.'.format(trainer.eval(test_dataloader)*100)) 82 | 83 | old_score = [0.0, 0.0] 84 | logger.log('Adversarial training for {} epochs'.format(NUM_ADV_EPOCHS)) 85 | trainer.init_optimizer(args.num_adv_epochs) 86 | test_adv_acc = 0.0 87 | 88 | 89 | for epoch in range(1, NUM_ADV_EPOCHS+1): 90 | start = time.time() 91 | logger.log('======= Epoch {} ======='.format(epoch)) 92 | 93 | if args.scheduler: 94 | last_lr = trainer.scheduler.get_last_lr()[0] 95 | 96 | res = trainer.train(train_dataloader, epoch=epoch, adversarial=True) 97 | test_acc = trainer.eval(test_dataloader) 98 | 99 | logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr)) 100 | if 'clean_acc' in res: 101 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100)) 102 | else: 103 | logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100)) 104 | epoch_metrics = {'train_'+k: v for k, v in res.items()} 105 | epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''}) 106 | 107 | if epoch % args.adv_eval_freq == 0 or epoch > (NUM_ADV_EPOCHS-5) or (epoch >= (NUM_ADV_EPOCHS-10) and NUM_ADV_EPOCHS > 90): 108 | test_adv_acc = trainer.eval(test_dataloader, adversarial=True) 109 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, 110 | test_adv_acc*100)) 111 | epoch_metrics.update({'test_adversarial_acc': test_adv_acc}) 112 | else: 113 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.'.format(res['adversarial_acc']*100)) 114 | 115 | if test_adv_acc >= old_score[1]: 116 | old_score[0], old_score[1] = test_acc, test_adv_acc 117 | trainer.save_model(WEIGHTS) 118 | trainer.save_model(os.path.join(LOG_DIR, 'weights-last.pt')) 119 | 120 | logger.log('Time taken: {}'.format(format_time(time.time()-start))) 121 | metrics = metrics.append(pd.DataFrame(epoch_metrics, 
index=[0]), ignore_index=True) 122 | metrics.to_csv(os.path.join(LOG_DIR, 'stats_adv.csv'), index=False) 123 | 124 | 125 | 126 | # Record metrics 127 | 128 | train_acc = res['clean_acc'] if 'clean_acc' in res else trainer.eval(train_dataloader) 129 | logger.log('\nTraining completed.') 130 | logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(train_acc*100, old_score[0]*100)) 131 | if NUM_ADV_EPOCHS > 0: 132 | logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100, old_score[1]*100)) 133 | 134 | logger.log('Script Completed.') 135 | --------------------------------------------------------------------------------