├── .gitignore
├── LICENSE.md
├── README.md
├── advbench
│   ├── algorithms.py
│   ├── attacks.py
│   ├── command_launchers.py
│   ├── datasets.py
│   ├── evalulation_methods.py
│   ├── hparams_registry.py
│   ├── lib
│   │   ├── meters.py
│   │   ├── misc.py
│   │   ├── plotting.py
│   │   └── reporting.py
│   ├── model_selection.py
│   ├── networks.py
│   ├── optimizers.py
│   ├── plotting
│   │   ├── acc_and_loss.py
│   │   ├── acc_and_loss_aug.py
│   │   ├── cvar.py
│   │   ├── learning_curve.py
│   │   ├── multi_cvar.py
│   │   ├── pareto.py
│   │   └── primal_dual.py
│   └── scripts
│       ├── check_progress.py
│       ├── collect_losses.py
│       ├── collect_results.py
│       ├── sweep.py
│       └── train.py
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | advbench/data/* 2 | __pycache__/ 3 | .vscode 4 | TODO.md -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alex Robey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # advbench 2 | 3 | This repository contains the code needed to reproduce the results of the following papers: 4 | 5 | * "Adversarial Robustness with Semi-Infinite Constrained Learning" (NeurIPS 2021) by Alexander Robey, Luiz F. O. Chamon, George J. Pappas, Hamed Hassani, and Alejandro Ribeiro. 6 | * "Probabilistically Robust Learning: Balancing Average and Worst-case Performance" (ICML 2022) by Alexander Robey, Luiz F. O. Chamon, George J. Pappas, and Hamed Hassani. 7 | 8 | If you find this repository useful in your research, please consider citing: 9 | 10 | ``` 11 | @article{robey2021adversarial, 12 | title={Adversarial robustness with semi-infinite constrained learning}, 13 | author={Robey, Alexander and Chamon, Luiz and Pappas, George J and Hassani, Hamed and Ribeiro, Alejandro}, 14 | journal={Advances in Neural Information Processing Systems}, 15 | volume={34}, 16 | pages={6198--6215}, 17 | year={2021} 18 | } 19 | @inproceedings{robey2022probabilistically, 20 | title={Probabilistically Robust Learning: Balancing Average and Worst-case Performance}, 21 | author={Robey, Alexander and Chamon, Luiz and Pappas, George J and Hassani, Hamed}, 22 | booktitle={International Conference on Machine Learning}, 23 | pages={18667--18686}, 24 | year={2022}, 25 | organization={PMLR} 26 | } 27 | ``` 28 | 29 | --- 30 | 31 | ### Overview 32 | 33 | This repository contains code for reproducing our results, including implementations of each of the baseline algorithms used in our papers. 
At present, we support the following baseline algorithms: 34 | 35 | * Empirical risk minimization (ERM, [Vapnik, 1998](https://www.wiley.com/en-fr/Statistical+Learning+Theory-p-9780471030034)) 36 | * Projected gradient descent (PGD, [Madry et al., 2017](https://arxiv.org/abs/1706.06083)) 37 | * Fast gradient sign method (FGSM, [Goodfellow et al., 2014](https://arxiv.org/abs/1412.6572)) 38 | * Clean logit pairing (CLP, [Kannan et al., 2018](https://arxiv.org/abs/1803.06373)) 39 | * Adversarial logit pairing (ALP, [Kannan et al., 2018](https://arxiv.org/abs/1803.06373)) 40 | * Theoretically principled trade-off between robustness and accuracy (TRADES, [Zhang et al., 2019](https://arxiv.org/abs/1901.08573)) 41 | * Misclassification-aware adversarial training (MART, [Wang et al., 2020](https://openreview.net/forum?id=rklOg6EFwS)) 42 | 43 | We also support several versions of our own algorithms: 44 | 45 | * Dual Adversarial Learning with Gaussian prior (Gaussian_DALE) 46 | * Dual Adversarial Learning with Laplacian prior (Laplacian_DALE) 47 | * Dual Adversarial Learning with KL-divergence loss (KL_DALE) 48 | 49 | --- 50 | 51 | ### Repository structure 52 | 53 | The structure of this repository is based on the (excellent) [domainbed](https://github.com/facebookresearch/DomainBed) repository. All of the runnable scripts are located in the `advbench/scripts/` and `advbench/plotting/` directories. 
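The DALE variants above are trained in a primal-dual fashion: the classifier takes a gradient step on a Lagrangian while a dual variable is updated via projected ascent (the `PrimalDualOptimizer` in `advbench/optimizers.py`, called as `self.pd_optimizer.step(clean_loss.detach())` in `algorithms.py`). The optimizer's implementation is not shown in this excerpt; the following is a minimal sketch of the dual update its call sites suggest — the function name and exact update rule here are illustrative assumptions, not the repository's code:

```python
def dual_ascent_step(dual_var, constraint_loss, margin, eta):
    # Gradient ascent on the Lagrangian with respect to the dual variable,
    # projected onto [0, inf) so it remains a valid Lagrange multiplier.
    # (Sketch -- assumed semantics of PrimalDualOptimizer.step.)
    return max(0.0, dual_var + eta * (constraint_loss - margin))

dual_var = 1.0
# Constraint satisfied (loss below margin): the multiplier shrinks.
dual_var = dual_ascent_step(dual_var, constraint_loss=0.1, margin=0.3, eta=0.5)
# Constraint violated (loss above margin): the multiplier grows.
dual_var = dual_ascent_step(dual_var, constraint_loss=0.9, margin=0.3, eta=0.5)
```

Each training step thus alternates a primal SGD step on the total loss with this dual step, so the penalty on the constrained loss grows whenever that loss exceeds the margin.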
54 | 55 | --- 56 | 57 | ### Quick start 58 | 59 | Train a model: 60 | 61 | ``` 62 | python -m advbench.scripts.train --dataset CIFAR10 --algorithm KL_DALE_PD --output_dir train-output --evaluators Clean PGD 63 | ``` 64 | 65 | Tally the results: 66 | 67 | ``` 68 | python -m advbench.scripts.collect_results --depth 0 --input_dir train-output 69 | ``` -------------------------------------------------------------------------------- /advbench/algorithms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import pandas as pd 6 | import numpy as np 7 | import torch.optim as optim 8 | 9 | from advbench import networks 10 | from advbench import optimizers 11 | from advbench import attacks 12 | from advbench.lib import meters 13 | 14 | ALGORITHMS = [ 15 | 'ERM', 16 | 'PGD', 17 | 'FGSM', 18 | 'TRADES', 19 | 'ALP', 20 | 'CLP', 'MART', 21 | 'Gaussian_DALE', 22 | 'Laplacian_DALE', 23 | 'Gaussian_DALE_PD', 24 | 'Gaussian_DALE_PD_Reverse', 25 | 'KL_DALE_PD', 26 | 'CVaR_SGD', 27 | 'CVaR_SGD_Autograd', 28 | 'CVaR_SGD_PD', 29 | 'ERM_DataAug', 30 | 'TERM', 31 | 'RandSmoothing' 32 | ] 33 | 34 | class Algorithm(nn.Module): 35 | def __init__(self, input_shape, num_classes, hparams, device): 36 | super(Algorithm, self).__init__() 37 | self.hparams = hparams 38 | self.classifier = networks.Classifier( 39 | input_shape, num_classes, hparams) 40 | self.optimizer = optim.SGD( 41 | self.classifier.parameters(), 42 | lr=hparams['learning_rate'], 43 | momentum=hparams['sgd_momentum'], 44 | weight_decay=hparams['weight_decay']) 45 | self.device = device 46 | 47 | self.meters = OrderedDict() 48 | self.meters['Loss'] = meters.AverageMeter() 49 | self.meters_df = None 50 | 51 | def step(self, imgs, labels): 52 | raise NotImplementedError 53 | 54 | def predict(self, imgs): 55 | return self.classifier(imgs) 56 | 57 | @staticmethod 58 | def img_clamp(imgs): 
59 | return torch.clamp(imgs, 0.0, 1.0) 60 | 61 | def reset_meters(self): 62 | for meter in self.meters.values(): 63 | meter.reset() 64 | 65 | def meters_to_df(self, epoch): 66 | if self.meters_df is None: 67 | columns = ['Epoch'] + list(self.meters.keys()) 68 | self.meters_df = pd.DataFrame(columns=columns) 69 | 70 | values = [epoch] + [m.avg for m in self.meters.values()] 71 | self.meters_df.loc[len(self.meters_df)] = values 72 | return self.meters_df 73 | 74 | class ERM(Algorithm): 75 | def __init__(self, input_shape, num_classes, hparams, device): 76 | super(ERM, self).__init__(input_shape, num_classes, hparams, device) 77 | 78 | def step(self, imgs, labels): 79 | self.optimizer.zero_grad() 80 | loss = F.cross_entropy(self.predict(imgs), labels) 81 | loss.backward() 82 | self.optimizer.step() 83 | 84 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 85 | 86 | class ERM_DataAug(Algorithm): 87 | def __init__(self, input_shape, num_classes, hparams, device): 88 | super(ERM_DataAug, self).__init__(input_shape, num_classes, hparams, device) 89 | 90 | def sample_deltas(self, imgs): 91 | eps = self.hparams['epsilon'] 92 | return 2 * eps * torch.rand_like(imgs) - eps 93 | 94 | def step(self, imgs, labels): 95 | self.optimizer.zero_grad() 96 | loss = 0 97 | for _ in range(self.hparams['cvar_sgd_M']): 98 | loss += F.cross_entropy(self.predict(self.img_clamp(imgs + self.sample_deltas(imgs))), labels) 99 | 100 | loss = loss / float(self.hparams['cvar_sgd_M']) 101 | loss.backward() 102 | self.optimizer.step() 103 | 104 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 105 | 106 | class TERM(Algorithm): 107 | def __init__(self, input_shape, num_classes, hparams, device): 108 | super(TERM, self).__init__(input_shape, num_classes, hparams, device) 109 | self.meters['tilted loss'] = meters.AverageMeter() 110 | self.t = torch.tensor(self.hparams['term_t']) 111 | 112 | def step(self, imgs, labels): 113 | self.optimizer.zero_grad() 114 | loss = F.cross_entropy(self.predict(imgs), labels, reduction='none') 
115 | term_loss = torch.log(torch.exp(self.t * loss).mean() + 1e-6) / self.t 116 | term_loss.backward() 117 | self.optimizer.step() 118 | 119 | self.meters['Loss'].update(loss.mean().item(), n=imgs.size(0)) 120 | self.meters['tilted loss'].update(term_loss.item(), n=imgs.size(0)) 121 | 122 | class PGD(Algorithm): 123 | def __init__(self, input_shape, num_classes, hparams, device): 124 | super(PGD, self).__init__(input_shape, num_classes, hparams, device) 125 | self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device) 126 | 127 | def step(self, imgs, labels): 128 | 129 | adv_imgs = self.attack(imgs, labels) 130 | self.optimizer.zero_grad() 131 | loss = F.cross_entropy(self.predict(adv_imgs), labels) 132 | loss.backward() 133 | self.optimizer.step() 134 | 135 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 136 | 137 | class RandSmoothing(Algorithm): 138 | def __init__(self, input_shape, num_classes, hparams, device): 139 | super(RandSmoothing, self).__init__(input_shape, num_classes, hparams, device) 140 | self.attack = attacks.SmoothAdv(self.classifier, self.hparams, device) 141 | 142 | def step(self, imgs, labels): 143 | 144 | adv_imgs = self.attack(imgs, labels) 145 | self.optimizer.zero_grad() 146 | loss = F.cross_entropy(self.predict(adv_imgs), labels) 147 | loss.backward() 148 | self.optimizer.step() 149 | 150 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 151 | 152 | class FGSM(Algorithm): 153 | def __init__(self, input_shape, num_classes, hparams, device): 154 | super(FGSM, self).__init__(input_shape, num_classes, hparams, device) 155 | self.attack = attacks.FGSM_Linf(self.classifier, self.hparams, device) 156 | 157 | def step(self, imgs, labels): 158 | 159 | adv_imgs = self.attack(imgs, labels) 160 | self.optimizer.zero_grad() 161 | loss = F.cross_entropy(self.predict(adv_imgs), labels) 162 | loss.backward() 163 | self.optimizer.step() 164 | 165 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 166 | 167 | class 
TRADES(Algorithm): 168 | def __init__(self, input_shape, num_classes, hparams, device): 169 | super(TRADES, self).__init__(input_shape, num_classes, hparams, device) 170 | self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean') # TODO(AR): let's write a method to do the log-softmax part 171 | self.attack = attacks.TRADES_Linf(self.classifier, self.hparams, device) 172 | 173 | self.meters['clean loss'] = meters.AverageMeter() 174 | self.meters['invariance loss'] = meters.AverageMeter() 175 | 176 | def step(self, imgs, labels): 177 | 178 | adv_imgs = self.attack(imgs, labels) 179 | self.optimizer.zero_grad() 180 | clean_loss = F.cross_entropy(self.predict(imgs), labels) # clean loss is computed on the unperturbed images 181 | robust_loss = self.kl_loss_fn( 182 | F.log_softmax(self.predict(adv_imgs), dim=1), 183 | F.softmax(self.predict(imgs), dim=1)) 184 | total_loss = clean_loss + self.hparams['trades_beta'] * robust_loss 185 | total_loss.backward() 186 | self.optimizer.step() 187 | 188 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 189 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 190 | self.meters['invariance loss'].update(robust_loss.item(), n=imgs.size(0)) 191 | 192 | return {'loss': total_loss.item()} 193 | 194 | class LogitPairingBase(Algorithm): 195 | def __init__(self, input_shape, num_classes, hparams, device): 196 | super(LogitPairingBase, self).__init__(input_shape, num_classes, hparams, device) 197 | self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device) 198 | self.meters['logit loss'] = meters.AverageMeter() 199 | 200 | def pairing_loss(self, imgs, adv_imgs): 201 | logit_diff = self.predict(adv_imgs) - self.predict(imgs) 202 | return torch.norm(logit_diff, dim=1).mean() 203 | 204 | class ALP(LogitPairingBase): 205 | def __init__(self, input_shape, num_classes, hparams, device): 206 | super(ALP, self).__init__(input_shape, num_classes, hparams, device) 207 | self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device) 208 | 
self.meters['robust loss'] = meters.AverageMeter() 209 | 210 | def step(self, imgs, labels): 211 | adv_imgs = self.attack(imgs, labels) 212 | self.optimizer.zero_grad() 213 | robust_loss = F.cross_entropy(self.predict(adv_imgs), labels) 214 | logit_pairing_loss = self.pairing_loss(imgs, adv_imgs) 215 | total_loss = robust_loss + logit_pairing_loss 216 | total_loss.backward() 217 | self.optimizer.step() 218 | 219 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 220 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 221 | self.meters['logit loss'].update(logit_pairing_loss.item(), n=imgs.size(0)) 222 | 223 | class CLP(LogitPairingBase): 224 | def __init__(self, input_shape, num_classes, hparams, device): 225 | super(CLP, self).__init__(input_shape, num_classes, hparams, device) 226 | self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device) 227 | 228 | self.meters['clean loss'] = meters.AverageMeter() 229 | 230 | def step(self, imgs, labels): 231 | adv_imgs = self.attack(imgs, labels) 232 | self.optimizer.zero_grad() 233 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 234 | logit_pairing_loss = self.pairing_loss(imgs, adv_imgs) 235 | total_loss = clean_loss + logit_pairing_loss 236 | total_loss.backward() 237 | self.optimizer.step() 238 | 239 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 240 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 241 | self.meters['logit loss'].update(logit_pairing_loss.item(), n=imgs.size(0)) 242 | 243 | class MART(Algorithm): 244 | def __init__(self, input_shape, num_classes, hparams, device): 245 | super(MART, self).__init__(input_shape, num_classes, hparams, device) 246 | self.kl_loss_fn = nn.KLDivLoss(reduction='none') 247 | self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device) 248 | 249 | self.meters['robust loss'] = meters.AverageMeter() 250 | self.meters['invariance loss'] = meters.AverageMeter() 251 | 252 | def 
step(self, imgs, labels): 253 | 254 | adv_imgs = self.attack(imgs, labels) 255 | self.optimizer.zero_grad() 256 | clean_output = self.classifier(imgs) 257 | adv_output = self.classifier(adv_imgs) 258 | adv_probs = F.softmax(adv_output, dim=1) 259 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:] 260 | new_label = torch.where(tmp1[:, -1] == labels, tmp1[:, -2], tmp1[:, -1]) 261 | loss_adv = F.cross_entropy(adv_output, labels) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_label) 262 | nat_probs = F.softmax(clean_output, dim=1) 263 | true_probs = torch.gather(nat_probs, 1, (labels.unsqueeze(1)).long()).squeeze() 264 | loss_robust = (1.0 / imgs.size(0)) * torch.sum( 265 | torch.sum(self.kl_loss_fn(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 266 | loss = loss_adv + self.hparams['mart_beta'] * loss_robust 267 | loss.backward() 268 | self.optimizer.step() 269 | 270 | self.meters['Loss'].update(loss.item(), n=imgs.size(0)) 271 | self.meters['robust loss'].update(loss_robust.item(), n=imgs.size(0)) 272 | self.meters['invariance loss'].update(loss_adv.item(), n=imgs.size(0)) 273 | 274 | 275 | class MMA(Algorithm): 276 | pass 277 | 278 | class Gaussian_DALE(Algorithm): 279 | def __init__(self, input_shape, num_classes, hparams, device): 280 | super(Gaussian_DALE, self).__init__(input_shape, num_classes, hparams, device) 281 | self.attack = attacks.LMC_Gaussian_Linf(self.classifier, self.hparams, device) 282 | self.meters['clean loss'] = meters.AverageMeter() 283 | self.meters['robust loss'] = meters.AverageMeter() 284 | 285 | def step(self, imgs, labels): 286 | adv_imgs = self.attack(imgs, labels) 287 | self.optimizer.zero_grad() 288 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 289 | robust_loss = F.cross_entropy(self.predict(adv_imgs), labels) 290 | total_loss = robust_loss + self.hparams['g_dale_nu'] * clean_loss 291 | total_loss.backward() 292 | self.optimizer.step() 293 | 294 | 
self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 295 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 296 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 297 | 298 | class Laplacian_DALE(Algorithm): 299 | def __init__(self, input_shape, num_classes, hparams, device): 300 | super(Laplacian_DALE, self).__init__(input_shape, num_classes, hparams, device) 301 | self.attack = attacks.LMC_Laplacian_Linf(self.classifier, self.hparams, device) 302 | self.meters['clean loss'] = meters.AverageMeter() 303 | self.meters['robust loss'] = meters.AverageMeter() 304 | 305 | def step(self, imgs, labels): 306 | adv_imgs = self.attack(imgs, labels) 307 | self.optimizer.zero_grad() 308 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 309 | robust_loss = F.cross_entropy(self.predict(adv_imgs), labels) 310 | total_loss = robust_loss + self.hparams['l_dale_nu'] * clean_loss 311 | total_loss.backward() 312 | self.optimizer.step() 313 | 314 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 315 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 316 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 317 | 318 | class PrimalDualBase(Algorithm): 319 | def __init__(self, input_shape, num_classes, hparams, device): 320 | super(PrimalDualBase, self).__init__(input_shape, num_classes, hparams, device) 321 | self.dual_params = {'dual_var': torch.tensor(1.0).to(self.device)} 322 | self.meters['clean loss'] = meters.AverageMeter() 323 | self.meters['robust loss'] = meters.AverageMeter() 324 | self.meters['dual variable'] = meters.AverageMeter() 325 | 326 | class Gaussian_DALE_PD(PrimalDualBase): 327 | def __init__(self, input_shape, num_classes, hparams, device): 328 | super(Gaussian_DALE_PD, self).__init__(input_shape, num_classes, hparams, device) 329 | self.attack = attacks.LMC_Gaussian_Linf(self.classifier, self.hparams, device) 330 | self.pd_optimizer = 
optimizers.PrimalDualOptimizer( 331 | parameters=self.dual_params, 332 | margin=self.hparams['g_dale_pd_margin'], 333 | eta=self.hparams['g_dale_pd_step_size']) 334 | 335 | def step(self, imgs, labels): 336 | adv_imgs = self.attack(imgs, labels) 337 | self.optimizer.zero_grad() 338 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 339 | robust_loss = F.cross_entropy(self.predict(adv_imgs), labels) 340 | total_loss = robust_loss + self.dual_params['dual_var'] * clean_loss 341 | total_loss.backward() 342 | self.optimizer.step() 343 | self.pd_optimizer.step(clean_loss.detach()) 344 | 345 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 346 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 347 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 348 | self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1) 349 | 350 | class CVaR_SGD_Autograd(Algorithm): 351 | def __init__(self, input_shape, num_classes, hparams, device): 352 | super(CVaR_SGD_Autograd, self).__init__(input_shape, num_classes, hparams, device) 353 | self.meters['avg t'] = meters.AverageMeter() 354 | self.meters['plain loss'] = meters.AverageMeter() 355 | 356 | def sample_deltas(self, imgs): 357 | eps = self.hparams['epsilon'] 358 | return 2 * eps * torch.rand_like(imgs) - eps 359 | 360 | def step(self, imgs, labels): 361 | 362 | beta, M = self.hparams['cvar_sgd_beta'], self.hparams['cvar_sgd_M'] 363 | ts = torch.ones(size=(imgs.size(0),)).to(self.device) 364 | 365 | self.optimizer.zero_grad() 366 | for _ in range(self.hparams['cvar_sgd_n_steps']): 367 | 368 | ts.requires_grad = True 369 | cvar_loss = 0 370 | for _ in range(M): 371 | pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs)) 372 | curr_loss = F.cross_entropy(self.predict(pert_imgs), labels, reduction='none') 373 | cvar_loss += F.relu(curr_loss - ts) 374 | 375 | cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean() 376 | grad_ts = 
torch.autograd.grad(cvar_loss, [ts])[0].detach() 377 | ts = ts - self.hparams['cvar_sgd_t_step_size'] * grad_ts 378 | ts = ts.detach() 379 | 380 | plain_loss, cvar_loss = 0, 0 381 | for _ in range(M): 382 | pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs)) 383 | curr_loss = F.cross_entropy(self.predict(pert_imgs), labels, reduction='none') 384 | plain_loss += curr_loss.mean() 385 | cvar_loss += F.relu(curr_loss - ts) 386 | 387 | cvar_loss = (cvar_loss / (beta * float(M))).mean() 388 | 389 | cvar_loss.backward() 390 | self.optimizer.step() 391 | 392 | self.meters['Loss'].update(cvar_loss.item(), n=imgs.size(0)) 393 | self.meters['avg t'].update(ts.mean().item(), n=imgs.size(0)) 394 | self.meters['plain loss'].update(plain_loss.item() / M, n=imgs.size(0)) 395 | 396 | class CVaR_SGD(Algorithm): 397 | def __init__(self, input_shape, num_classes, hparams, device): 398 | super(CVaR_SGD, self).__init__(input_shape, num_classes, hparams, device) 399 | self.meters['avg t'] = meters.AverageMeter() 400 | self.meters['plain loss'] = meters.AverageMeter() 401 | 402 | def sample_deltas(self, imgs): 403 | eps = self.hparams['epsilon'] 404 | return 2 * eps * torch.rand_like(imgs) - eps 405 | 406 | def step(self, imgs, labels): 407 | 408 | beta = self.hparams['cvar_sgd_beta'] 409 | M = self.hparams['cvar_sgd_M'] 410 | ts = torch.ones(size=(imgs.size(0),)).to(self.device) 411 | 412 | self.optimizer.zero_grad() 413 | for _ in range(self.hparams['cvar_sgd_n_steps']): 414 | 415 | plain_loss, cvar_loss, indicator_sum = 0, 0, 0 416 | for _ in range(self.hparams['cvar_sgd_M']): 417 | pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs)) 418 | curr_loss = F.cross_entropy(self.predict(pert_imgs), labels, reduction='none') 419 | indicator_sum += torch.where(curr_loss > ts, torch.ones_like(ts), torch.zeros_like(ts)) 420 | 421 | plain_loss += curr_loss.mean() 422 | cvar_loss += F.relu(curr_loss - ts) 423 | 424 | indicator_avg = indicator_sum / float(M) 425 | cvar_loss = (ts + 
cvar_loss / (float(M) * beta)).mean() 426 | 427 | # gradient update on ts 428 | grad_ts = (1 - (1 / beta) * indicator_avg) / float(imgs.size(0)) 429 | ts = ts - self.hparams['cvar_sgd_t_step_size'] * grad_ts 430 | 431 | cvar_loss.backward() 432 | self.optimizer.step() 433 | 434 | self.meters['Loss'].update(cvar_loss.item(), n=imgs.size(0)) 435 | self.meters['avg t'].update(ts.mean().item(), n=imgs.size(0)) 436 | self.meters['plain loss'].update(plain_loss.item() / M, n=imgs.size(0)) 437 | 438 | class CVaR_SGD_PD(Algorithm): 439 | def __init__(self, input_shape, num_classes, hparams, device): 440 | super(CVaR_SGD_PD, self).__init__(input_shape, num_classes, hparams, device) 441 | self.dual_params = {'dual_var': torch.tensor(1.0).to(self.device)} 442 | self.meters['avg t'] = meters.AverageMeter() 443 | self.meters['plain loss'] = meters.AverageMeter() 444 | self.meters['dual variable'] = meters.AverageMeter() 445 | self.pd_optimizer = optimizers.PrimalDualOptimizer( 446 | parameters=self.dual_params, 447 | margin=self.hparams['g_dale_pd_margin'], 448 | eta=self.hparams['g_dale_pd_step_size']) 449 | 450 | def sample_deltas(self, imgs): 451 | eps = self.hparams['epsilon'] 452 | return 2 * eps * torch.rand_like(imgs) - eps 453 | 454 | def step(self, imgs, labels): 455 | 456 | beta = self.hparams['cvar_sgd_beta'] 457 | M = self.hparams['cvar_sgd_M'] 458 | ts = torch.ones(size=(imgs.size(0),)).to(self.device) 459 | 460 | self.optimizer.zero_grad() 461 | for _ in range(self.hparams['cvar_sgd_n_steps']): 462 | 463 | plain_loss, cvar_loss, indicator_sum = 0, 0, 0 464 | for _ in range(self.hparams['cvar_sgd_M']): 465 | pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs)) 466 | curr_loss = F.cross_entropy(self.predict(pert_imgs), labels, reduction='none') 467 | indicator_sum += torch.where(curr_loss > ts, torch.ones_like(ts), torch.zeros_like(ts)) 468 | 469 | plain_loss += curr_loss.mean() 470 | cvar_loss += F.relu(curr_loss - ts) 471 | 472 | indicator_avg = 
indicator_sum / float(M) 473 | cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean() 474 | 475 | # gradient update on ts 476 | grad_ts = (1 - (1 / beta) * indicator_avg) / float(imgs.size(0)) 477 | ts = ts - self.hparams['cvar_sgd_t_step_size'] * grad_ts 478 | 479 | loss = cvar_loss + self.dual_params['dual_var'] * (plain_loss / float(M)) 480 | loss.backward() 481 | self.optimizer.step() 482 | self.pd_optimizer.step(plain_loss.detach() / M) 483 | 484 | self.meters['Loss'].update(cvar_loss.item(), n=imgs.size(0)) 485 | self.meters['avg t'].update(ts.mean().item(), n=imgs.size(0)) 486 | self.meters['plain loss'].update(plain_loss.item() / M, n=imgs.size(0)) 487 | self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1) 488 | 489 | class Gaussian_DALE_PD_Reverse(PrimalDualBase): 490 | def __init__(self, input_shape, num_classes, hparams, device): 491 | super(Gaussian_DALE_PD_Reverse, self).__init__(input_shape, num_classes, hparams, device) 492 | self.attack = attacks.LMC_Gaussian_Linf(self.classifier, self.hparams, device) 493 | self.pd_optimizer = optimizers.PrimalDualOptimizer( 494 | parameters=self.dual_params, 495 | margin=self.hparams['g_dale_pd_margin'], 496 | eta=self.hparams['g_dale_pd_step_size']) 497 | 498 | def step(self, imgs, labels): 499 | adv_imgs = self.attack(imgs, labels) 500 | self.optimizer.zero_grad() 501 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 502 | robust_loss = F.cross_entropy(self.predict(adv_imgs), labels) 503 | total_loss = clean_loss + self.dual_params['dual_var'] * robust_loss 504 | total_loss.backward() 505 | self.optimizer.step() 506 | self.pd_optimizer.step(robust_loss.detach()) 507 | 508 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 509 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 510 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 511 | self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1) 512 | 513 | 
class KL_DALE_PD(PrimalDualBase): 514 | def __init__(self, input_shape, num_classes, hparams, device): 515 | super(KL_DALE_PD, self).__init__(input_shape, num_classes, hparams, device) 516 | self.attack = attacks.TRADES_Linf(self.classifier, self.hparams, device) 517 | self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean') 518 | self.pd_optimizer = optimizers.PrimalDualOptimizer( 519 | parameters=self.dual_params, 520 | margin=self.hparams['g_dale_pd_margin'], 521 | eta=self.hparams['g_dale_pd_step_size']) 522 | 523 | def step(self, imgs, labels): 524 | adv_imgs = self.attack(imgs, labels) 525 | self.optimizer.zero_grad() 526 | clean_loss = F.cross_entropy(self.predict(imgs), labels) 527 | robust_loss = self.kl_loss_fn( 528 | F.log_softmax(self.predict(adv_imgs), dim=1), 529 | F.softmax(self.predict(imgs), dim=1)) 530 | total_loss = robust_loss + self.dual_params['dual_var'] * clean_loss 531 | total_loss.backward() 532 | self.optimizer.step() 533 | self.pd_optimizer.step(clean_loss.detach()) 534 | 535 | self.meters['Loss'].update(total_loss.item(), n=imgs.size(0)) 536 | self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0)) 537 | self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0)) 538 | self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1) -------------------------------------------------------------------------------- /advbench/attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions.laplace import Laplace 5 | 6 | class Attack(nn.Module): 7 | def __init__(self, classifier, hparams, device): 8 | super(Attack, self).__init__() 9 | self.classifier = classifier 10 | self.hparams = hparams 11 | self.device = device 12 | 13 | def forward(self, imgs, labels): 14 | raise NotImplementedError 15 | 16 | class Attack_Linf(Attack): 17 | def __init__(self, classifier, hparams, 
device): 18 | super(Attack_Linf, self).__init__(classifier, hparams, device) 19 | 20 | def _clamp_perturbation(self, imgs, adv_imgs): 21 | """Clamp a perturbed image so that (1) the perturbation is bounded 22 | in the l_inf norm by self.hparams['epsilon'] and (2) so that the 23 | perturbed image is in [0, 1]^d.""" 24 | 25 | eps = self.hparams['epsilon'] 26 | adv_imgs = torch.min(torch.max(adv_imgs, imgs - eps), imgs + eps) 27 | return torch.clamp(adv_imgs, 0.0, 1.0) 28 | 29 | class PGD_Linf(Attack_Linf): 30 | def __init__(self, classifier, hparams, device): 31 | super(PGD_Linf, self).__init__(classifier, hparams, device) 32 | 33 | def forward(self, imgs, labels): 34 | self.classifier.eval() 35 | 36 | adv_imgs = imgs.detach() # + 0.001 * torch.randn(imgs.shape).to(self.device).detach() #AR: is this detach necessary? 37 | for _ in range(self.hparams['pgd_n_steps']): 38 | adv_imgs.requires_grad_(True) 39 | with torch.enable_grad(): 40 | adv_loss = F.cross_entropy(self.classifier(adv_imgs), labels) 41 | grad = torch.autograd.grad(adv_loss, [adv_imgs])[0].detach() 42 | adv_imgs = adv_imgs + self.hparams['pgd_step_size']* torch.sign(grad) 43 | adv_imgs = self._clamp_perturbation(imgs, adv_imgs) 44 | 45 | self.classifier.train() 46 | return adv_imgs.detach() # this detach may not be necessary 47 | 48 | class SmoothAdv(Attack_Linf): 49 | def __init__(self, classifier, hparams, device): 50 | super(SmoothAdv, self).__init__(classifier, hparams, device) 51 | 52 | def sample_deltas(self, imgs): 53 | sigma = self.hparams['rand_smoothing_sigma'] 54 | return sigma * torch.randn_like(imgs) 55 | 56 | def forward(self, imgs, labels): 57 | self.classifier.eval() 58 | 59 | adv_imgs = imgs.detach() 60 | for _ in range(self.hparams['rand_smoothing_n_steps']): 61 | adv_imgs.requires_grad_(True) 62 | loss = 0. 
63 | for _ in range(self.hparams['rand_smoothing_n_samples']): 64 | deltas = self.sample_deltas(imgs) 65 | loss += F.softmax(self.classifier(adv_imgs + deltas), dim=1)[range(imgs.size(0)), labels] 66 | 67 | total_loss = -1. * torch.log(loss / self.hparams['rand_smoothing_n_samples']).mean() 68 | grad = torch.autograd.grad(total_loss, [adv_imgs])[0].detach() 69 | adv_imgs = adv_imgs + self.hparams['rand_smoothing_step_size'] * torch.sign(grad) # step from the current iterate so updates accumulate 70 | adv_imgs = self._clamp_perturbation(imgs, adv_imgs) 71 | 72 | self.classifier.train() 73 | return adv_imgs.detach() # this detach may not be necessary 74 | 75 | 76 | class TRADES_Linf(Attack_Linf): 77 | def __init__(self, classifier, hparams, device): 78 | super(TRADES_Linf, self).__init__(classifier, hparams, device) 79 | self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean') # AR: let's write a method to do the log-softmax part 80 | 81 | def forward(self, imgs, labels): 82 | self.classifier.eval() 83 | 84 | adv_imgs = imgs.detach() + 0.001 * torch.randn(imgs.shape).to(self.device).detach() #AR: is this detach necessary? 
85 | for _ in range(self.hparams['trades_n_steps']): 86 | adv_imgs.requires_grad_(True) 87 | with torch.enable_grad(): 88 | adv_loss = self.kl_loss_fn( 89 | F.log_softmax(self.classifier(adv_imgs), dim=1), # AR: Note that this means that we can't have softmax at output of classifier 90 | F.softmax(self.classifier(imgs), dim=1)) 91 | 92 | grad = torch.autograd.grad(adv_loss, [adv_imgs])[0].detach() 93 | adv_imgs = adv_imgs + self.hparams['trades_step_size']* torch.sign(grad) 94 | adv_imgs = self._clamp_perturbation(imgs, adv_imgs) 95 | 96 | self.classifier.train() 97 | return adv_imgs.detach() # this detach may not be necessary 98 | 99 | class FGSM_Linf(Attack): 100 | def __init__(self, classifier, hparams, device): 101 | super(FGSM_Linf, self).__init__(classifier, hparams, device) 102 | 103 | def forward(self, imgs, labels): 104 | self.classifier.eval() 105 | 106 | imgs.requires_grad = True 107 | adv_loss = F.cross_entropy(self.classifier(imgs), labels) 108 | grad = torch.autograd.grad(adv_loss, [imgs])[0].detach() 109 | adv_imgs = imgs + self.hparams['epsilon'] * grad.sign() 110 | adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0) 111 | 112 | self.classifier.train() 113 | 114 | return adv_imgs.detach() 115 | 116 | class LMC_Gaussian_Linf(Attack_Linf): 117 | def __init__(self, classifier, hparams, device): 118 | super(LMC_Gaussian_Linf, self).__init__(classifier, hparams, device) 119 | 120 | def forward(self, imgs, labels): 121 | self.classifier.eval() 122 | batch_size = imgs.size(0) 123 | 124 | adv_imgs = imgs.detach() + 0.001 * torch.randn(imgs.shape).to(self.device).detach() #AR: is this detach necessary? 
125 | for _ in range(self.hparams['g_dale_n_steps']):
126 | adv_imgs.requires_grad_(True)
127 | with torch.enable_grad():
128 | adv_loss = torch.log(1 - torch.softmax(self.classifier(adv_imgs), dim=1)[range(batch_size), labels]).mean()
129 | # adv_loss = F.cross_entropy(self.classifier(adv_imgs), labels)
130 | grad = torch.autograd.grad(adv_loss, [adv_imgs])[0].detach()
131 | noise = torch.randn_like(adv_imgs)  # already on the right device and carries no grad
132 | 
133 | adv_imgs = adv_imgs + self.hparams['g_dale_step_size'] * torch.sign(grad) + self.hparams['g_dale_noise_coeff'] * noise
134 | adv_imgs = self._clamp_perturbation(imgs, adv_imgs)
135 | 
136 | self.classifier.train()
137 | 
138 | return adv_imgs.detach()
139 | 
140 | class LMC_Laplacian_Linf(Attack_Linf):
141 | def __init__(self, classifier, hparams, device):
142 | super(LMC_Laplacian_Linf, self).__init__(classifier, hparams, device)
143 | 
144 | def forward(self, imgs, labels):
145 | self.classifier.eval()
146 | batch_size = imgs.size(0)
147 | noise_dist = Laplace(torch.tensor(0.), torch.tensor(1.))
148 | 
149 | adv_imgs = imgs.detach() + 0.001 * torch.randn(imgs.shape).to(self.device)  # random start; randn carries no grad, so no extra detach is needed
150 | for _ in range(self.hparams['l_dale_n_steps']): 151 | adv_imgs.requires_grad_(True) 152 | with torch.enable_grad(): 153 | adv_loss = torch.log(1 - torch.softmax(self.classifier(adv_imgs), dim=1)[range(batch_size), labels]).mean() 154 | grad = torch.autograd.grad(adv_loss, [adv_imgs])[0].detach() 155 | noise = noise_dist.sample(grad.shape) 156 | adv_imgs = adv_imgs + self.hparams['l_dale_step_size'] * torch.sign(grad + self.hparams['l_dale_noise_coeff'] * noise) 157 | adv_imgs = self._clamp_perturbation(imgs, adv_imgs) 158 | 159 | self.classifier.train() 160 | return adv_imgs.detach() -------------------------------------------------------------------------------- /advbench/command_launchers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import subprocess 4 | import time 5 | import torch 6 | 7 | def local_launcher(commands): 8 | for cmd in commands: 9 | subprocess.call(cmd, shell=True) 10 | 11 | def dummy_launcher(commands): 12 | for cmd in commands: 13 | print(f'Dummy launcher: {cmd}') 14 | 15 | def multi_gpu_launcher(commands): 16 | 17 | n_gpus = torch.cuda.device_count() 18 | procs_by_gpu = [None for _ in range(n_gpus)] 19 | 20 | while len(commands) > 0: 21 | for gpu_idx in range(n_gpus): 22 | proc = procs_by_gpu[gpu_idx] 23 | 24 | if (proc is None) or (proc.poll() is not None): 25 | # Nothing is running on this GPU; launch a command 26 | cmd = commands.pop(0) 27 | new_proc = subprocess.Popen( 28 | f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', 29 | shell=True) 30 | procs_by_gpu[gpu_idx] = new_proc 31 | break 32 | 33 | time.sleep(1) 34 | 35 | # Wait for the last few tasks to finish before returning 36 | for p in procs_by_gpu: 37 | if p is not None: 38 | p.wait() 39 | 40 | 41 | REGISTRY = { 42 | 'local': local_launcher, 43 | 'dummy': dummy_launcher, 44 | 'multi_gpu': multi_gpu_launcher 45 | } -------------------------------------------------------------------------------- /advbench/datasets.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Subset, ConcatDataset, TensorDataset 3 | import torchvision.transforms as transforms 4 | from torchvision.datasets import CIFAR10 as CIFAR10_ 5 | from torchvision.datasets import MNIST as TorchvisionMNIST 6 | from torchvision.datasets import SVHN as SVHN_ 7 | 8 | SPLITS = ['train', 'val', 'test'] 9 | DATASETS = ['CIFAR10', 'MNIST', 'SVHN'] 10 | 11 | class AdvRobDataset: 12 | 13 | N_WORKERS = 8 # Default, subclasses may override 14 | INPUT_SHAPE = None # Subclasses should override 15 | NUM_CLASSES = None # Subclasses should override 16 | N_EPOCHS = None # Subclasses should override 17 | CHECKPOINT_FREQ = None # Subclasses should override 18 | LOG_INTERVAL = None # Subclasses should override 19 | HAS_LR_SCHEDULE = False # Default, subclass may override 20 | ON_DEVICE = False # Default, subclass may override 21 | 22 | def __init__(self, device): 23 | self.splits = dict.fromkeys(SPLITS) 24 | self.device = device 25 | 26 | class CIFAR10(AdvRobDataset): 27 | 28 | INPUT_SHAPE = (3, 32, 32) 29 | NUM_CLASSES = 10 30 | N_EPOCHS = 115 31 | CHECKPOINT_FREQ = 10 32 | LOG_INTERVAL = 100 33 | HAS_LR_SCHEDULE = True 34 | 35 | def __init__(self, root, device): 36 | super(CIFAR10, self).__init__(device) 37 | 38 | train_transforms = transforms.Compose([ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor()]) 42 | test_transforms = transforms.ToTensor() 43 | 44 | train_data = CIFAR10_(root, train=True, transform=train_transforms) 45 | self.splits['train'] = train_data 46 | # self.splits['train'] = Subset(train_data, range(5000)) 47 | 48 | train_data = CIFAR10_(root, train=True, transform=train_transforms) 49 | self.splits['val'] = Subset(train_data, range(45000, 50000)) 50 | 51 | self.splits['test'] = CIFAR10_(root, train=False, transform=test_transforms) 52 | 53 | @staticmethod 54 | def 
adjust_lr(optimizer, epoch, hparams): 55 | lr = hparams['learning_rate'] 56 | if epoch >= 55: # 150 57 | lr = hparams['learning_rate'] * 0.1 58 | if epoch >= 75: # 175 59 | lr = hparams['learning_rate'] * 0.01 60 | if epoch >= 90: # 190 61 | lr = hparams['learning_rate'] * 0.001 62 | for param_group in optimizer.param_groups: 63 | param_group['lr'] = lr 64 | 65 | class MNISTTensor(AdvRobDataset): 66 | 67 | N_WORKERS = 0 # Needs to be zero so we don't fetch from GPU 68 | INPUT_SHAPE = (1, 28, 28) 69 | NUM_CLASSES = 10 70 | N_EPOCHS = 50 71 | CHECKPOINT_FREQ = 10 72 | LOG_INTERVAL = 100 73 | HAS_LR_SCHEDULE = True 74 | ON_DEVICE = True 75 | 76 | def __init__(self, root, device): 77 | super(MNISTTensor, self).__init__(device) 78 | 79 | train_data = TorchvisionMNIST( 80 | root=root, 81 | train=True, 82 | transform=transforms.ToTensor()) 83 | test_data = TorchvisionMNIST( 84 | root=root, 85 | train=False, 86 | transform=transforms.ToTensor()) 87 | 88 | all_imgs = torch.cat(( 89 | train_data.data, 90 | test_data.data)).reshape(-1, 1, 28, 28).float().to(self.device) 91 | all_labels = torch.cat(( 92 | train_data.targets, 93 | test_data.targets)).to(self.device) 94 | 95 | self.splits = { 96 | 'train': TensorDataset(all_imgs, all_labels), 97 | 'validation': TensorDataset(all_imgs, all_labels), 98 | 'test': TensorDataset(all_imgs, all_labels) 99 | } 100 | 101 | @staticmethod 102 | def adjust_lr(optimizer, epoch, hparams): 103 | 104 | lr = hparams['learning_rate'] 105 | if epoch >= 25: 106 | lr = hparams['learning_rate'] * 0.1 107 | if epoch >= 35: 108 | lr = hparams['learning_rate'] * 0.01 109 | if epoch >= 40: 110 | lr = hparams['learning_rate'] * 0.001 111 | for param_group in optimizer.param_groups: 112 | param_group['lr'] = lr 113 | 114 | class MNIST(AdvRobDataset): 115 | 116 | INPUT_SHAPE = (1, 28, 28) 117 | NUM_CLASSES = 10 118 | N_EPOCHS = 50 119 | CHECKPOINT_FREQ = 10 120 | LOG_INTERVAL = 100 121 | HAS_LR_SCHEDULE = True 122 | 123 | def __init__(self, root, device): 
124 | super(MNIST, self).__init__(device) 125 | 126 | train_data = TorchvisionMNIST( 127 | root=root, 128 | train=True, 129 | transform=transforms.ToTensor()) 130 | test_data = TorchvisionMNIST( 131 | root=root, 132 | train=False, 133 | transform=transforms.ToTensor()) 134 | 135 | # self.splits = { 136 | # 'train': Subset(train_data, range(54000)), 137 | # 'validation': Subset(train_data, range(54000, 60000)), 138 | # 'test': test_data 139 | # } 140 | 141 | all_data = ConcatDataset([train_data, test_data]) 142 | self.splits = { 143 | 'train': all_data, 144 | 'validation': all_data, 145 | 'test': all_data 146 | } 147 | 148 | @staticmethod 149 | def adjust_lr(optimizer, epoch, hparams): 150 | 151 | lr = hparams['learning_rate'] 152 | if epoch >= 25: 153 | lr = hparams['learning_rate'] * 0.1 154 | if epoch >= 35: 155 | lr = hparams['learning_rate'] * 0.01 156 | if epoch >= 40: 157 | lr = hparams['learning_rate'] * 0.001 158 | for param_group in optimizer.param_groups: 159 | param_group['lr'] = lr 160 | 161 | class SVHN(AdvRobDataset): 162 | 163 | INPUT_SHAPE = (3, 32, 32) 164 | NUM_CLASSES = 10 165 | N_EPOCHS = 115 166 | CHECKPOINT_FREQ = 10 167 | LOG_INTERVAL = 100 168 | HAS_LR_SCHEDULE = False 169 | 170 | def __init__(self, root, device): 171 | super(SVHN, self).__init__(device) 172 | 173 | train_transforms = transforms.Compose([ 174 | transforms.RandomCrop(32, padding=4), 175 | transforms.RandomHorizontalFlip(), 176 | transforms.ToTensor()]) 177 | test_transforms = transforms.ToTensor() 178 | 179 | train_data = SVHN_(root, split='train', transform=train_transforms, download=True) 180 | self.splits['train'] = train_data 181 | self.splits['test'] = SVHN_(root, split='test', transform=test_transforms, download=True) 182 | 183 | @staticmethod 184 | def adjust_lr(optimizer, epoch, hparams): 185 | lr = hparams['learning_rate'] 186 | if epoch >= 55: # 150 187 | lr = hparams['learning_rate'] * 0.1 188 | if epoch >= 75: # 175 189 | lr = hparams['learning_rate'] * 0.01 190 | 
if epoch >= 90: # 190 191 | lr = hparams['learning_rate'] * 0.001 192 | for param_group in optimizer.param_groups: 193 | param_group['lr'] = lr -------------------------------------------------------------------------------- /advbench/evalulation_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from advbench import attacks 5 | 6 | class Evaluator: 7 | 8 | # Sub-class should over-ride 9 | NAME = '' 10 | 11 | def __init__(self, algorithm, device, test_hparams): 12 | self.algorithm = algorithm 13 | self.device = device 14 | self.test_hparams = test_hparams 15 | 16 | def calculate(self, loader): 17 | raise NotImplementedError 18 | 19 | def sample_perturbations(self, imgs): 20 | eps = self.test_hparams['epsilon'] 21 | return 2 * eps * torch.rand_like(imgs) - eps 22 | 23 | @staticmethod 24 | def clamp_imgs(imgs): 25 | return torch.clamp(imgs, 0.0, 1.0) 26 | 27 | class Clean(Evaluator): 28 | """Calculates the standard accuracy of a classifier.""" 29 | 30 | NAME = 'Clean' 31 | 32 | def __init__(self, algorithm, device, test_hparams): 33 | super(Clean, self).__init__(algorithm, device, test_hparams) 34 | 35 | @torch.no_grad() 36 | def calculate(self, loader): 37 | self.algorithm.eval() 38 | 39 | correct, total, loss_sum = 0, 0, 0 40 | for imgs, labels in loader: 41 | imgs, labels = imgs.to(self.device), labels.to(self.device) 42 | logits = self.algorithm.predict(imgs) 43 | loss_sum += F.cross_entropy(logits, labels, reduction='sum').item() 44 | preds = logits.argmax(dim=1, keepdim=True) 45 | correct += preds.eq(labels.view_as(preds)).sum().item() 46 | total += imgs.size(0) 47 | 48 | self.algorithm.train() 49 | return { 50 | f'{self.NAME}-Accuracy': 100. 
* correct / total, 51 | f'{self.NAME}-Loss': loss_sum / total 52 | } 53 | 54 | class Adversarial(Evaluator): 55 | """Calculates the adversarial accuracy of a classifier.""" 56 | 57 | def __init__(self, algorithm, device, attack, test_hparams): 58 | super(Adversarial, self).__init__(algorithm, device, test_hparams) 59 | self.attack = attack 60 | 61 | def calculate(self, loader): 62 | self.algorithm.eval() 63 | 64 | correct, total, loss_sum = 0, 0, 0 65 | for imgs, labels in loader: 66 | imgs, labels = imgs.to(self.device), labels.to(self.device) 67 | adv_imgs = self.attack(imgs, labels) 68 | 69 | with torch.no_grad(): 70 | logits = self.algorithm.predict(adv_imgs) 71 | loss_sum += F.cross_entropy(logits, labels, reduction='sum').item() 72 | 73 | preds = logits.argmax(dim=1, keepdim=True) 74 | correct += preds.eq(labels.view_as(preds)).sum().item() 75 | total += imgs.size(0) 76 | 77 | self.algorithm.train() 78 | return { 79 | f'{self.NAME}-Accuracy': 100. * correct / total, 80 | f'{self.NAME}-Loss': float(loss_sum) / total 81 | } 82 | 83 | class PGD(Adversarial): 84 | """Calculates the PGD adversarial accuracy of a classifier.""" 85 | 86 | NAME = 'PGD' 87 | 88 | def __init__(self, algorithm, device, test_hparams): 89 | 90 | attack = attacks.PGD_Linf( 91 | classifier=algorithm.classifier, 92 | hparams=test_hparams, 93 | device=device) 94 | super(PGD, self).__init__( 95 | algorithm=algorithm, 96 | device=device, 97 | attack=attack, 98 | test_hparams=test_hparams) 99 | 100 | class FGSM(Adversarial): 101 | """Calculates the FGSM adversarial accuracy of a classifier.""" 102 | 103 | NAME = 'FGSM' 104 | 105 | def __init__(self, algorithm, device, test_hparams): 106 | 107 | attack = attacks.FGSM_Linf( 108 | classifier=algorithm.classifier, 109 | hparams=test_hparams, 110 | device=device) 111 | super(FGSM, self).__init__( 112 | algorithm=algorithm, 113 | device=device, 114 | attack=attack, 115 | test_hparams=test_hparams) 116 | 117 | class CVaR(Evaluator): 118 | """Calculates 
the CVaR loss of a classifier."""
119 | 
120 | NAME = 'CVaR'
121 | 
122 | def __init__(self, algorithm, device, test_hparams):
123 | super(CVaR, self).__init__(algorithm, device, test_hparams)
124 | self.q = self.test_hparams['cvar_sgd_beta']
125 | self.n_cvar_steps = self.test_hparams['cvar_sgd_n_steps']
126 | self.M = self.test_hparams['cvar_sgd_M']
127 | self.step_size = self.test_hparams['cvar_sgd_t_step_size']
128 | 
129 | @torch.no_grad()
130 | def calculate(self, loader):
131 | self.algorithm.eval()
132 | 
133 | loss_sum, total = 0, 0
134 | for imgs, labels in loader:
135 | imgs, labels = imgs.to(self.device), labels.to(self.device)
136 | 
137 | ts = torch.zeros(size=(imgs.size(0),)).to(self.device)
138 | 
139 | # perform n steps of optimization to compute the inner infimum
140 | for _ in range(self.n_cvar_steps):
141 | 
142 | cvar_loss, indicator_sum = 0, 0
143 | 
144 | # number of samples in the inner expectation in the definition of CVaR
145 | for _ in range(self.M):
146 | perturbations = self.sample_perturbations(imgs)
147 | perturbed_imgs = self.clamp_imgs(imgs + perturbations)
148 | preds = self.algorithm.predict(perturbed_imgs)
149 | loss = F.cross_entropy(preds, labels, reduction='none')
150 | 
151 | indicator_sum += torch.where(
152 | loss > ts,
153 | torch.ones_like(ts),
154 | torch.zeros_like(ts))
155 | cvar_loss += F.relu(loss - ts)
156 | 
157 | indicator_avg = indicator_sum / float(self.M)
158 | cvar_loss = (ts + cvar_loss / (self.M * self.q)).mean()
159 | 
160 | # gradient update on ts
161 | grad_ts = (1 - (1 / self.q) * indicator_avg) / float(imgs.size(0))
162 | ts = ts - self.step_size * grad_ts
163 | 
164 | loss_sum += cvar_loss.item() * imgs.size(0)
165 | total += imgs.size(0)
166 | 
167 | self.algorithm.train()
168 | 
169 | return {f'{self.NAME}-Loss': loss_sum / float(total)}
170 | 
171 | class Augmented(Evaluator):
172 | """Calculates the augmented accuracy of a classifier."""
173 | 
174 | NAME = 'Augmented'
175 | 
176 | def __init__(self, algorithm, device, test_hparams):
177 | super(Augmented, self).__init__(algorithm, device, test_hparams) 178 | self.n_aug_samples = self.test_hparams['aug_n_samples'] 179 | 180 | @staticmethod 181 | def quantile_accuracy(q, accuracy_per_datum): 182 | """Calculate q-Quantile accuracy""" 183 | 184 | # quantile predictions for each data point 185 | beta_quantile_acc_per_datum = torch.where( 186 | accuracy_per_datum > (1 - q) * 100., 187 | 100. * torch.ones_like(accuracy_per_datum), 188 | torch.zeros_like(accuracy_per_datum)) 189 | 190 | return beta_quantile_acc_per_datum.mean().item() 191 | 192 | @torch.no_grad() 193 | def calculate(self, loader): 194 | self.algorithm.eval() 195 | 196 | correct, total, loss_sum = 0, 0, 0 197 | correct_per_datum = [] 198 | 199 | for imgs, labels in loader: 200 | imgs, labels = imgs.to(self.device), labels.to(self.device) 201 | 202 | batch_correct_ls = [] 203 | for _ in range(self.n_aug_samples): 204 | perturbations = self.sample_perturbations(imgs) 205 | perturbed_imgs = self.clamp_imgs(imgs + perturbations) 206 | logits = self.algorithm.predict(perturbed_imgs) 207 | loss_sum += F.cross_entropy(logits, labels, reduction='sum').item() 208 | preds = logits.argmax(dim=1, keepdim=True) 209 | 210 | # unreduced predictions 211 | pert_preds = preds.eq(labels.view_as(preds)) 212 | 213 | # list of predictions for each data point 214 | batch_correct_ls.append(pert_preds) 215 | 216 | correct += pert_preds.sum().item() 217 | total += imgs.size(0) 218 | 219 | # number of correct predictions for each data point 220 | batch_correct = torch.sum(torch.hstack(batch_correct_ls), dim=1) 221 | correct_per_datum.append(batch_correct) 222 | 223 | # accuracy for each data point 224 | accuracy_per_datum = 100. * torch.hstack(correct_per_datum) / self.n_aug_samples 225 | 226 | self.algorithm.train() 227 | 228 | return_dict = { 229 | f'{self.NAME}-Accuracy': 100. 
* correct / total, 230 | f'{self.NAME}-Loss': loss_sum / total 231 | } 232 | 233 | if self.test_hparams['test_betas']: 234 | return_dict.update({ 235 | f'{self.NAME}-{q}-Quantile-Accuracy': self.quantile_accuracy(q, accuracy_per_datum) 236 | for q in self.test_hparams['test_betas'] 237 | }) 238 | 239 | return return_dict 240 | 241 | -------------------------------------------------------------------------------- /advbench/hparams_registry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from advbench.lib import misc 4 | from advbench import datasets 5 | 6 | def default_hparams(algorithm, dataset): 7 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()} 8 | 9 | def random_hparams(algorithm, dataset, seed): 10 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()} 11 | 12 | def _hparams(algorithm: str, dataset: str, random_seed: int): 13 | """Global registry of hyperparams. Each entry is a (default, random) tuple. 14 | New algorithms / networks / etc. should add entries here. 15 | """ 16 | 17 | hparams = {} 18 | 19 | def _hparam(name, default_val, random_val_fn): 20 | """Define a hyperparameter. random_val_fn takes a RandomState and 21 | returns a random hyperparameter value.""" 22 | 23 | assert(name not in hparams) 24 | random_state = np.random.RandomState(misc.seed_hash(random_seed, name)) 25 | hparams[name] = (default_val, random_val_fn(random_state)) 26 | 27 | # Unconditional hparam definitions. 
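The `_hparam` helper above registers every hyperparameter as a `(default, random)` tuple, seeding each random draw from the hyperparameter's name via `misc.seed_hash` so sweeps are reproducible. A minimal self-contained sketch of this pattern (the `build_registry` wrapper and its two example entries are illustrative, not part of the repository):

```python
import hashlib

import numpy as np

def seed_hash(*args):
    """Derive a stable integer seed from the args (mirrors advbench/lib/misc.py)."""
    return int(hashlib.md5(str(args).encode("utf-8")).hexdigest(), 16) % (2 ** 31)

def build_registry(random_seed):
    """Toy stand-in for _hparams: each entry is a (default, random) tuple."""
    hparams = {}

    def _hparam(name, default_val, random_val_fn):
        # Per-name seeding: the draw for one hyperparameter never depends
        # on how many other hyperparameters were registered before it.
        assert name not in hparams
        rs = np.random.RandomState(seed_hash(random_seed, name))
        hparams[name] = (default_val, random_val_fn(rs))

    _hparam('batch_size', 64, lambda r: int(2 ** r.uniform(3, 8)))
    _hparam('learning_rate', 0.01, lambda r: 10 ** r.uniform(-4.5, -2.5))
    return hparams

registry = build_registry(random_seed=0)
defaults = {k: d for k, (d, _) in registry.items()}   # as in default_hparams
randoms = {k: r for k, (_, r) in registry.items()}    # as in random_hparams
```

Selecting index 0 of each tuple reproduces `default_hparams`; index 1 reproduces `random_hparams` for a given sweep seed.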
28 | 29 | _hparam('batch_size', 64, lambda r: int(2 ** r.uniform(3, 8))) 30 | 31 | # optimization 32 | _hparam('learning_rate', 0.01, lambda r: 10 ** r.uniform(-4.5, -2.5)) 33 | _hparam('sgd_momentum', 0.9, lambda r: r.uniform(0.8, 0.95)) 34 | _hparam('weight_decay', 3.5e-3, lambda r: 10 ** r.uniform(-6, -3)) 35 | 36 | if 'MNIST' in dataset: 37 | _hparam('epsilon', 0.3, lambda r: 0.3) 38 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 39 | _hparam('epsilon', 0.031, lambda r: 0.031) 40 | 41 | # Algorithm specific 42 | 43 | ##### PGD ##### 44 | if 'MNIST' in dataset: 45 | _hparam('pgd_n_steps', 7, lambda r: 7) 46 | _hparam('pgd_step_size', 0.1, lambda r: r.uniform(0.05, 0.2)) 47 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 48 | _hparam('pgd_n_steps', 10, lambda r: 10) 49 | _hparam('pgd_step_size', 0.007, lambda r: 0.007) 50 | 51 | ##### TRADES ##### 52 | if 'MNIST' in dataset: 53 | _hparam('trades_n_steps', 7, lambda r: 7) 54 | _hparam('trades_step_size', 0.1, lambda r: r.uniform(0.01, 0.1)) 55 | _hparam('trades_beta', 1.0, lambda r: r.uniform(0.1, 10.0)) 56 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 57 | _hparam('trades_n_steps', 10, lambda r: 15) 58 | _hparam('trades_step_size', 2/255., lambda r: r.uniform(0.01, 0.1)) 59 | _hparam('trades_beta', 6.0, lambda r: r.uniform(0.1, 10.0)) 60 | 61 | ##### MART ##### 62 | if 'MNIST' in dataset: 63 | _hparam('mart_beta', 5.0, lambda r: r.uniform(0.1, 10.0)) 64 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 65 | _hparam('mart_beta', 5.0, lambda r: r.uniform(0.1, 10.0)) 66 | 67 | ##### Gaussian DALE ##### 68 | if 'MNIST' in dataset: 69 | _hparam('g_dale_n_steps', 7, lambda r: 7) 70 | _hparam('g_dale_step_size', 0.1, lambda r: 0.1) 71 | _hparam('g_dale_noise_coeff', 0.001, lambda r: 10 ** r.uniform(-6.0, -2.0)) 72 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 73 | _hparam('g_dale_n_steps', 10, lambda r: 10) 74 | _hparam('g_dale_step_size', 0.007, lambda r: 0.007) 75 | _hparam('g_dale_noise_coeff', 0, lambda 
r: 0) 76 | _hparam('g_dale_nu', 0.1, lambda r: 0.1) 77 | 78 | # DALE (Laplacian-HMC) 79 | if 'MNIST' in dataset: 80 | _hparam('l_dale_n_steps', 7, lambda r: 7) 81 | _hparam('l_dale_step_size', 0.1, lambda r: 0.1) 82 | _hparam('l_dale_noise_coeff', 0.001, lambda r: 10 ** r.uniform(-6.0, -2.0)) 83 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 84 | _hparam('l_dale_n_steps', 10, lambda r: 10) 85 | _hparam('l_dale_step_size', 0.007, lambda r: 0.007) 86 | _hparam('l_dale_noise_coeff', 1e-2, lambda r: 1e-2) 87 | _hparam('l_dale_nu', 0.1, lambda r: 0.1) 88 | 89 | # DALE-PD (Gaussian-HMC) 90 | _hparam('g_dale_pd_step_size', 0.001, lambda r: 0.001) 91 | _hparam('g_dale_pd_margin', 0.1, lambda r: 0.1) 92 | 93 | # CVaR SGD 94 | _hparam('cvar_sgd_t_step_size', 1.0, lambda r: 0.001) 95 | _hparam('cvar_sgd_beta', 0.5, lambda r: 0.1) 96 | _hparam('cvar_sgd_M', 20, lambda r: 10) 97 | _hparam('cvar_sgd_n_steps', 5, lambda r: 10) 98 | 99 | # TERM 100 | _hparam('term_t', 2.0, lambda r: 1.0) 101 | 102 | # Randomized smoothing 103 | if dataset == 'CIFAR10' or dataset == 'SVHN': 104 | _hparam('rand_smoothing_sigma', 0.12, lambda r: 0.12) 105 | _hparam('rand_smoothing_n_steps', 10, lambda r: 7) 106 | _hparam('rand_smoothing_step_size', 2/255., lambda r: r.uniform(0.01, 0.1)) 107 | _hparam('rand_smoothing_n_samples', 10, lambda r: 1) 108 | elif 'MNIST' in dataset: 109 | _hparam('rand_smoothing_sigma', 0.5, lambda r: 0.12) 110 | _hparam('rand_smoothing_n_steps', 7, lambda r: 10) 111 | _hparam('rand_smoothing_step_size', 0.1, lambda r: r.uniform(0.01, 0.1)) 112 | _hparam('rand_smoothing_n_samples', 10, lambda r: 1) 113 | 114 | return hparams 115 | 116 | def test_hparams(algorithm: str, dataset: str): 117 | 118 | hparams = {} 119 | 120 | def _hparam(name, default_val): 121 | """Define a hyperparameter for test adversaries.""" 122 | 123 | assert(name not in hparams) 124 | hparams[name] = default_val 125 | 126 | _hparam('test_betas', [0.1, 0.05, 0.01]) 127 | _hparam('aug_n_samples', 100) 
128 | 129 | if 'MNIST' in dataset: 130 | _hparam('epsilon', 0.3) 131 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 132 | _hparam('epsilon', 8/255.) 133 | 134 | ##### PGD ##### 135 | if 'MNIST' in dataset: 136 | _hparam('pgd_n_steps', 10) 137 | _hparam('pgd_step_size', 0.1) 138 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 139 | _hparam('pgd_n_steps', 20) 140 | _hparam('pgd_step_size', 0.003) 141 | 142 | ##### TRADES ##### 143 | if 'MNIST' in dataset: 144 | _hparam('trades_n_steps', 10) 145 | _hparam('trades_step_size', 0.1) 146 | elif dataset == 'CIFAR10' or dataset == 'SVHN': 147 | _hparam('trades_n_steps', 20) 148 | _hparam('trades_step_size', 2/255.) 149 | 150 | ##### CVaR SGD ##### 151 | _hparam('cvar_sgd_t_step_size', 0.5) 152 | _hparam('cvar_sgd_beta', 0.05) 153 | _hparam('cvar_sgd_M', 10) 154 | _hparam('cvar_sgd_n_steps', 10) 155 | 156 | return hparams -------------------------------------------------------------------------------- /advbench/lib/meters.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class AverageMeter: 4 | """Computes and stores the average and current value""" 5 | def __init__(self, avg_mom=0.5): 6 | self.avg_mom = avg_mom 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 # running average of whole epoch 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | class TimeMeter: 22 | def __init__(self): 23 | self.batch_time = AverageMeter() 24 | self.data_time = AverageMeter() 25 | self.start = time.time() 26 | 27 | def batch_start(self): 28 | self.data_time.update(time.time() - self.start) 29 | 30 | def batch_end(self): 31 | self.batch_time.update(time.time() - self.start) 32 | self.start = time.time() -------------------------------------------------------------------------------- /advbench/lib/misc.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import hashlib 3 | import sys 4 | import os 5 | import json 6 | from functools import wraps 7 | from time import time 8 | import pandas as pd 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | from advbench.lib import meters 13 | 14 | def timing(f): 15 | @wraps(f) 16 | def wrap(*args, **kw): 17 | ts = time() 18 | result = f(*args, **kw) 19 | te = time() 20 | print(f'func:{f.__name__} took: {te-ts:.3f} sec') 21 | return result 22 | return wrap 23 | 24 | def seed_hash(*args): 25 | """Derive an integer hash from all args, for use as a random seed.""" 26 | 27 | args_str = str(args) 28 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31) 29 | 30 | def print_row(row, col_width=10): 31 | sep, end_ = " ", "" 32 | 33 | def format_val(x): 34 | if np.issubdtype(type(x), np.floating): 35 | x = f'{x:.5f}' 36 | return str(x).ljust(col_width)[:col_width] 37 | 38 | print(sep.join([format_val(x) for x in row]), end_) 39 | 40 | def stage_path(data_dir, name): 41 | path = os.path.join(data_dir, name) 42 | os.makedirs(path) if not os.path.exists(path) else None 43 | return path 44 | 45 | def read_dict(fname): 46 | with open(fname, 'r') as f: 47 | d = json.load(f) 48 | return d 49 | 50 | def print_full_df(df): 51 | with pd.option_context('display.max_rows', None, 'display.max_columns', None): 52 | print(df) 53 | 54 | def sample_deltas(imgs, eps): 55 | return 2 * eps * torch.rand_like(imgs) - eps 56 | 57 | def img_clamp(imgs): 58 | return torch.clamp(imgs, 0.0, 1.0) 59 | 60 | @torch.no_grad() 61 | def cvar_loss(algorithm, loader, device, test_hparams): 62 | 63 | beta, M = test_hparams['cvar_sgd_beta'], test_hparams['cvar_sgd_M'] 64 | eps = test_hparams['epsilon'] 65 | cvar_meter = meters.AverageMeter() 66 | 67 | algorithm.eval() 68 | for batch_idx, (imgs, labels) in enumerate(loader): 69 | imgs, labels = imgs.to(device), labels.to(device) 70 | 
71 | ts = torch.zeros(size=(imgs.size(0),)).to(device) 72 | 73 | for _ in range(test_hparams['cvar_sgd_n_steps']): 74 | 75 | cvar_loss, indicator_sum = 0, 0 76 | for _ in range(test_hparams['cvar_sgd_M']): 77 | pert_imgs = img_clamp(imgs + sample_deltas(imgs, eps)) 78 | curr_loss = F.cross_entropy(algorithm.predict(pert_imgs), labels, reduction='none') 79 | indicator_sum += torch.where(curr_loss > ts, torch.ones_like(ts), torch.zeros_like(ts)) 80 | cvar_loss += F.relu(curr_loss - ts) 81 | 82 | indicator_avg = indicator_sum / float(M) 83 | cvar_loss = (ts + cvar_loss / (M * beta)).mean() 84 | 85 | # gradient update on ts 86 | grad_ts = (1 - (1 / beta) * indicator_avg) / float(imgs.size(0)) 87 | ts = ts - test_hparams['cvar_sgd_t_step_size'] * grad_ts 88 | 89 | cvar_meter.update(cvar_loss.item(), n=imgs.size(0)) 90 | 91 | algorithm.train() 92 | 93 | return cvar_meter.avg 94 | 95 | def cvar_grad_loss(algorithm, loader, device, test_hparams): 96 | 97 | beta, M = test_hparams['cvar_sgd_beta'], test_hparams['cvar_sgd_M'] 98 | eps = test_hparams['epsilon'] 99 | cvar_meter = meters.AverageMeter() 100 | algorithm.eval() 101 | 102 | for batch_idx, (imgs, labels) in enumerate(loader): 103 | imgs, labels = imgs.to(device), labels.to(device) 104 | ts = torch.zeros(size=(imgs.size(0),)).to(device) 105 | 106 | for _ in range(test_hparams['cvar_sgd_n_steps']): 107 | ts.requires_grad = True 108 | cvar_loss = 0 109 | for _ in range(M): 110 | pert_imgs = img_clamp(imgs + sample_deltas(imgs, eps)) 111 | curr_loss = F.cross_entropy(algorithm.predict(pert_imgs), labels, reduction='none') 112 | cvar_loss += F.relu(curr_loss - ts) 113 | 114 | cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean() 115 | grad_ts = torch.autograd.grad(cvar_loss, [ts])[0].detach() 116 | ts = ts - test_hparams['cvar_sgd_t_step_size'] * grad_ts 117 | ts = ts.detach() 118 | 119 | cvar_meter.update(cvar_loss.item(), n=imgs.size(0)) 120 | 121 | algorithm.train() 122 | 123 | return cvar_meter.avg 124 | 125 | class 
Tee: 126 | def __init__(self, fname, mode="a"): 127 | self.stdout = sys.stdout 128 | self.file = open(fname, mode) 129 | 130 | def write(self, message): 131 | self.stdout.write(message) 132 | self.file.write(message) 133 | self.flush() 134 | 135 | def flush(self): 136 | self.stdout.flush() 137 | self.file.flush() -------------------------------------------------------------------------------- /advbench/lib/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def remove_legend_title(ax, name_dict=None, fontsize=16): 5 | handles, labels = ax.get_legend_handles_labels() 6 | if name_dict is not None: 7 | labels = [name_dict[x] for x in labels] 8 | ax.legend(handles=handles, labels=labels, fontsize=fontsize) 9 | 10 | def adjust_legend_fontsize(ax, fontsize): 11 | handles, labels = ax.get_legend_handles_labels() 12 | ax.legend(handles=handles, labels=labels, fontsize=fontsize) 13 | 14 | def multicol_legend(ax, ncol=2): 15 | handles, labels = ax.get_legend_handles_labels() 16 | # ax.legend.remove() 17 | ax.legend(handles, labels, ncol=ncol, loc='best') 18 | 19 | def tick_density(plot, every=2, mod_val=1, axis='x'): 20 | ticks = plot.get_yticklabels() if axis == 'y' else plot.get_xticklabels() 21 | for ind, label in enumerate(ticks): 22 | if ind % every == mod_val: 23 | label.set_visible(True) 24 | else: 25 | label.set_visible(False) 26 | 27 | def show_bar_values(axs, orient="v", space=.01): 28 | def _single(ax): 29 | if orient == "v": 30 | for p in ax.patches: 31 | _x = p.get_x() + p.get_width() / 2 32 | _y = p.get_y() + p.get_height() + (p.get_height()*0.01) 33 | value = '{:.3f}'.format(p.get_height()) 34 | ax.text(_x, _y, value, ha="center") 35 | elif orient == "h": 36 | for p in ax.patches: 37 | _x = p.get_x() + p.get_width() + float(space) 38 | _y = p.get_y() + p.get_height() - (p.get_height()*0.5) 39 | value = '{:.3f}'.format(p.get_width()) 40 | ax.text(_x, _y, 
value, ha="left") 41 | 42 | if isinstance(axs, np.ndarray): 43 | for idx, ax in np.ndenumerate(axs): 44 | _single(ax) 45 | else: 46 | _single(axs) -------------------------------------------------------------------------------- /advbench/lib/reporting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import json 4 | import pandas as pd 5 | from typing import List 6 | 7 | def load_record(path: str) -> List[dict]: 8 | """Load the JSON stored in a given path to a list.""" 9 | 10 | records = [] 11 | with open(path, 'r') as f: 12 | for line in f: 13 | records.append(json.loads(line[:-1])) 14 | return records 15 | 16 | def load_sweep_dataframes(path, depth=1): 17 | 18 | records = [] 19 | 20 | def add_record(results_path): 21 | try: 22 | records.append(pd.read_pickle(results_path)) 23 | 24 | # want to ignore existing files 25 | # (e.g., results.txt, args.json, etc.) 26 | except IOError: 27 | pass 28 | 29 | if depth == 0: 30 | results_path = os.path.join(path, f'selection.pd') 31 | add_record(results_path) 32 | 33 | elif depth == 1: 34 | for i, subdir in list(enumerate(os.listdir(path))): 35 | results_path = os.path.join(path, subdir, f'selection.pd') 36 | add_record(results_path) 37 | 38 | else: 39 | raise ValueError(f'Depth {depth} is invalid.') 40 | 41 | return pd.concat(records, ignore_index=True) -------------------------------------------------------------------------------- /advbench/model_selection.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | 5 | class ModelSelection: 6 | def __init__(self, df): 7 | self.df = df 8 | 9 | metric_name = self.df['Metric-Name'].iloc[0] 10 | self.sort_ascending = False if 'Accuracy' in metric_name else True 11 | 12 | validation_df, test_df = self.select_epoch() 13 | 14 | validation_df['trial_rank'] = validation_df.groupby( 15 | 'trial_seed' 16 | )['Metric-Value'].rank(method='dense', 
ascending=self.sort_ascending) 17 | test_df['trial_rank'] = validation_df['trial_rank'].tolist()  # validation/test rows share order, so ranks transfer positionally 18 | 19 | self.trial_values = [] 20 | for _, df in test_df.groupby('trial_seed'): 21 | self.trial_values.append( 22 | df[df.trial_rank == 1.0]['Metric-Value'].iloc[0]) 23 | 24 | class LastStep(ModelSelection): 25 | """Model selection from the *last* step of training.""" 26 | 27 | NAME = 'LastStep' 28 | 29 | def __init__(self, df): 30 | super(LastStep, self).__init__(df) 31 | 32 | def select_epoch(self): 33 | last_step = max(self.df.Epoch.unique()) 34 | self.df = self.df[self.df.Epoch == last_step] 35 | 36 | validation_df = self.df[self.df.Split == 'Validation'].copy() 37 | test_df = self.df[self.df.Split == 'Test'].copy() 38 | 39 | return validation_df, test_df 40 | 41 | class EarlyStop(ModelSelection): 42 | """Model selection from the *best* epoch of training.""" 43 | 44 | NAME = 'EarlyStop' 45 | 46 | def __init__(self, df): 47 | super(EarlyStop, self).__init__(df) 48 | 49 | def select_epoch(self): 50 | validation_df = self.df[self.df.Split == 'Validation'] 51 | test_df = self.df[self.df.Split == 'Test'] 52 | 53 | validation_dfs, test_dfs = [], [] 54 | for (t, s), df in validation_df.groupby(['trial_seed', 'seed']): 55 | best_epoch = df[df['Metric-Value'] == self.find_best(df)]['Epoch'].iloc[0] 56 | 57 | validation_dfs.append( 58 | df[df.Epoch == best_epoch]) 59 | 60 | test_dfs.append( 61 | test_df[ 62 | (test_df.Epoch == best_epoch) & 63 | (test_df.seed == s) & 64 | (test_df.trial_seed == t)]) 65 | 66 | validation_df = pd.concat(validation_dfs, ignore_index=True) 67 | test_df = pd.concat(test_dfs, ignore_index=True) 68 | 69 | return validation_df, test_df 70 | 71 | def find_best(self, df): 72 | if self.sort_ascending is False: 73 | return df['Metric-Value'].max() 74 | return df['Metric-Value'].min() 75 | -------------------------------------------------------------------------------- /advbench/networks.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | from collections import OrderedDict 6 | 7 | def Classifier(input_shape, num_classes, hparams): 8 | if input_shape[0] == 1: 9 | # return SmallCNN() 10 | return MNISTNet(input_shape, num_classes) 11 | elif input_shape[0] == 3: 12 | # return models.resnet18(num_classes=num_classes) 13 | return ResNet18() 14 | else: 15 | assert False 16 | 17 | 18 | class MNISTNet(nn.Module): 19 | def __init__(self, input_shape, num_classes): 20 | super(MNISTNet, self).__init__() 21 | self.conv1 = nn.Conv2d(input_shape[0], 32, 3, 1) 22 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 23 | self.dropout1 = nn.Dropout(0.25) 24 | self.dropout2 = nn.Dropout(0.5) 25 | self.fc1 = nn.Linear(9216, 128) 26 | self.fc2 = nn.Linear(128, num_classes) 27 | 28 | def forward(self, x): 29 | x = self.conv1(x) 30 | x = F.relu(x) 31 | x = self.conv2(x) 32 | x = F.relu(x) 33 | x = F.max_pool2d(x, 2) 34 | x = self.dropout1(x) 35 | x = torch.flatten(x, 1) 36 | x = self.fc1(x) 37 | x = F.relu(x) 38 | x = self.dropout2(x) 39 | x = self.fc2(x) 40 | return F.log_softmax(x, dim=1) #TODO(AR): might need to remove softmax for KL div in TRADES 41 | 42 | """Resnet implementation is based on the implementation found in: 43 | https://github.com/YisenWang/MART/blob/master/resnet.py 44 | """ 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, in_planes, planes, stride=1): 50 | super(BasicBlock, self).__init__() 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion * planes: 58 | self.shortcut = 
nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 60 | nn.BatchNorm2d(self.expansion * planes) 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.bn2(self.conv2(out)) 66 | out += self.shortcut(x) 67 | out = F.relu(out) 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, in_planes, planes, stride=1): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 81 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 82 | 83 | self.shortcut = nn.Sequential() 84 | if stride != 1 or in_planes != self.expansion * planes: 85 | self.shortcut = nn.Sequential( 86 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 87 | nn.BatchNorm2d(self.expansion * planes) 88 | ) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = F.relu(self.bn2(self.conv2(out))) 93 | out = self.bn3(self.conv3(out)) 94 | out += self.shortcut(x) 95 | out = F.relu(out) 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | def __init__(self, block, num_blocks, num_classes=10): 101 | super(ResNet, self).__init__() 102 | self.in_planes = 64 103 | 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 105 | self.bn1 = nn.BatchNorm2d(64) 106 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 107 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 110 | 
self.linear = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | def _make_layer(self, block, planes, num_blocks, stride): 113 | strides = [stride] + [1] * (num_blocks - 1) 114 | layers = [] 115 | for stride in strides: 116 | layers.append(block(self.in_planes, planes, stride)) 117 | self.in_planes = planes * block.expansion 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | out = F.relu(self.bn1(self.conv1(x))) 122 | out = self.layer1(out) 123 | out = self.layer2(out) 124 | out = self.layer3(out) 125 | out = self.layer4(out) 126 | out = F.avg_pool2d(out, 4) 127 | out = out.view(out.size(0), -1) 128 | out = self.linear(out) 129 | return out 130 | 131 | 132 | def ResNet18(): 133 | return ResNet(BasicBlock, [2, 2, 2, 2]) 134 | 135 | def ResNet50(): 136 | return ResNet(Bottleneck, [3, 4, 6, 3]) -------------------------------------------------------------------------------- /advbench/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | import torch 3 | 4 | class PrimalDualOptimizer: 5 | def __init__(self, parameters, margin, eta): 6 | self.parameters = parameters 7 | self.margin = margin 8 | self.eta = eta 9 | 10 | def step(self, cost): 11 | self.parameters['dual_var'] = self.relu(self.parameters['dual_var'] + self.eta * (cost - self.margin)) 12 | 13 | @staticmethod 14 | def relu(x): 15 | return x if x > 0 else torch.tensor(0).cuda() 16 | 17 | -------------------------------------------------------------------------------- /advbench/plotting/acc_and_loss.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | import pandas as pd 6 | import os 7 | 8 | from advbench.lib import reporting, plotting 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='Plot loss and accuracy') 12 | 
parser.add_argument('--input_dir', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | meters = reporting.load_records(args.input_dir, results_fname='meters', depth=0) 16 | results = reporting.load_records(args.input_dir, depth=0) 17 | 18 | loss_cols = [c for c in meters.columns if 'loss' in c] 19 | 20 | loss_meters = pd.melt( 21 | meters, 22 | id_vars=['Epoch'], 23 | value_vars=loss_cols, 24 | var_name='loss type', 25 | value_name='loss value') 26 | 27 | sns.set(style='darkgrid', font_scale=1.5, font='Palatino') 28 | fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) 29 | 30 | # plot training/test accuracies 31 | g = sns.lineplot( 32 | data=results, 33 | x='Epoch', 34 | y='Accuracy', 35 | hue='Eval-Method', 36 | ax=ax1, 37 | marker='o') 38 | name_dict = {'ERM': 'Clean', 'PGD_Linf': 'Adversarial'} 39 | g.set(title='Test accuracy') 40 | plotting.remove_legend_title(ax1, name_dict=name_dict) 41 | 42 | # plot training losses -- divided into clean and adv 43 | g = sns.lineplot( 44 | data=loss_meters, 45 | x='Epoch', 46 | y='loss value', 47 | hue='loss type', 48 | ax=ax2, 49 | marker='o') 50 | g.set(ylabel='Loss', title='Training loss') 51 | name_dict = {l: l.capitalize() for l in loss_cols} 52 | plotting.remove_legend_title(ax2, name_dict=name_dict) 53 | 54 | plt.subplots_adjust(bottom=0.15) 55 | 56 | save_path = os.path.join(args.input_dir, 'acc_and_loss.png') 57 | plt.savefig(save_path) -------------------------------------------------------------------------------- /advbench/plotting/acc_and_loss_aug.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | import pandas as pd 6 | import os 7 | import json 8 | 9 | from advbench.lib import reporting, plotting 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='Plot loss and accuracy') 13 | 
parser.add_argument('--input_dir', type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | meters = reporting.load_records(args.input_dir, results_fname='meters', depth=0) 17 | results = reporting.load_records(args.input_dir, depth=0) 18 | 19 | with open(os.path.join(args.input_dir, 'hparams.json'), 'r') as f: 20 | hparams = json.load(f) 21 | 22 | with open(os.path.join(args.input_dir, 'test_hparams.json'), 'r') as f: 23 | test_hparams = json.load(f) 24 | 25 | loss_cols = [c for c in meters.columns if 'loss' in c] 26 | 27 | loss_meters = pd.melt( 28 | meters, 29 | id_vars=['Epoch'], 30 | value_vars=loss_cols, 31 | var_name='loss type', 32 | value_name='loss value') 33 | 34 | sns.set(style='darkgrid', font_scale=1.5, font='Palatino') 35 | fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) 36 | 37 | # plot training/test accuracies 38 | g = sns.lineplot( 39 | data=results, 40 | x='Epoch', 41 | y='Accuracy', 42 | hue='Eval-Method', 43 | ax=ax1, 44 | marker='o') 45 | name_dict = { 46 | 'ERM': 'Clean', 47 | 'PGD_Linf': 'Adversarial', 48 | 'Augmented-ERM': 'Augmented', 49 | } 50 | for beta in test_hparams['test_betas']: 51 | name_dict[f'{beta}-Quantile'] = f'{beta}-Quantile' 52 | g.set(title='Test accuracy') 53 | plotting.remove_legend_title(ax1, name_dict=name_dict) 54 | 55 | # plot training losses -- divided into clean and adv 56 | g = sns.lineplot( 57 | data=loss_meters, 58 | x='Epoch', 59 | y='loss value', 60 | hue='loss type', 61 | ax=ax2, 62 | marker='o') 63 | g.set(ylabel='Loss', title='Training loss') 64 | name_dict = {l: l.capitalize() for l in loss_cols} 65 | plotting.remove_legend_title(ax2, name_dict=name_dict) 66 | 67 | plt.subplots_adjust(bottom=0.15) 68 | 69 | save_path = os.path.join(args.input_dir, 'acc_and_loss.png') 70 | plt.savefig(save_path) -------------------------------------------------------------------------------- /advbench/plotting/cvar.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | import pandas as pd 6 | import os 7 | 8 | from advbench.lib import reporting, plotting 9 | 10 | 11 | 12 | if __name__ == '__main__': 13 | # parser = argparse.ArgumentParser(description='Plot loss and accuracy') 14 | # parser.add_argument('--input_dir', type=str, required=True) 15 | # args = parser.parse_args() 16 | 17 | sns.set(style='darkgrid', font_scale=1.5, font='Palatino') 18 | plt.figure(figsize=(12,5)) 19 | 20 | algs = [ 21 | r'CVaR SGD ($\beta=0.05$)', 22 | r'CVaR SGD ($\beta=0.01$)', 23 | r'CVaR SGD ($\beta=0.1$)', 24 | 'ERM', 25 | 'ERM w/ Data Aug', 26 | 'FGSM', 27 | 'PGD' 28 | ] 29 | cvars = [0.298, 0.882, 0.259, 5.313, 0.997, 1.092, 0.357] 30 | data = list(zip(algs, cvars)) 31 | df = pd.DataFrame(data, columns=['Algorithm', r'CVaR ($\beta=0.05$)']) 32 | 33 | g = sns.barplot( 34 | data=df, 35 | y='Algorithm', 36 | x=r'CVaR ($\beta=0.05$)', 37 | ) 38 | plt.subplots_adjust(left=0.25, bottom=0.15) 39 | g.set(ylabel='') 40 | plotting.show_bar_values(g, orient='h') 41 | plt.savefig('cvar.png') -------------------------------------------------------------------------------- /advbench/plotting/learning_curve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import os 5 | 6 | from advbench.lib import reporting 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='Plot learning curve') 11 | parser.add_argument('--input_dir', type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | records = reporting.load_records(args.input_dir, depth=0) 15 | 16 | sns.set(style='darkgrid', font_scale=1.5) 17 | g = sns.relplot(data=records, x='Epoch', y='Accuracy', hue='Split', 18 | col='Eval-Method', kind='line', marker='o') 19 | 20 |
save_path = os.path.join(args.input_dir, 'learning_curve.png') 21 | plt.savefig(save_path) -------------------------------------------------------------------------------- /advbench/plotting/multi_cvar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | import pandas as pd 6 | import os 7 | 8 | from advbench.lib import reporting, plotting 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='Plot CVaR results') 12 | parser.add_argument('--input_dir', type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | sns.set(style='darkgrid', font_scale=1.8, font='Palatino') 16 | 17 | results = reporting.load_records(args.input_dir, depth=1) 18 | losses = reporting.load_records(args.input_dir, results_fname='losses', depth=1) 19 | 20 | g = sns.FacetGrid( 21 | results, 22 | col='Eval-Method', 23 | hue='Train-Alg', 24 | palette="colorblind", 25 | height=5, 26 | legend_out=True, 27 | col_wrap=4) 28 | 29 | g.map(plt.plot, 'Epoch', 'Accuracy', linewidth=3) 30 | g.set_titles(col_template='{col_name}') 31 | 32 | handles, labels = g.axes[0].get_legend_handles_labels() 33 | g.fig.legend( 34 | handles, 35 | labels, 36 | ncol=5, 37 | bbox_to_anchor=(0.8,0.1), 38 | frameon=False) 39 | 40 | plt.subplots_adjust(bottom=0.2) 41 | plt.savefig('cvar_accuracies.png') 42 | plt.close() 43 | 44 | sns.set(font_scale=1.5, font='Palatino') 45 | 46 | g = sns.lineplot( 47 | data=losses, 48 | x='Epoch', 49 | y='Loss', 50 | hue='Train-Alg') 51 | g.set(ylabel='CVaR') 52 | 53 | plotting.adjust_legend_fontsize(g.axes, 15) 54 | 55 | plt.subplots_adjust(bottom=0.2) 56 | plt.savefig('cvar_losses.png') -------------------------------------------------------------------------------- /advbench/plotting/pareto.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 
import numpy as np 3 | import seaborn as sns 4 | import pandas as pd 5 | import os 6 | 7 | if __name__ == '__main__': 8 | 9 | algs = ['PGD', 'FGSM', 'ALP', 'CLP', 'TRADES', 'MART', 'DALE'] 10 | clean_accs = [83.8, 72.6, 75.9, 79.8, 80.7, 78.9, 86.7] 11 | adv_accs = [48.1, 40.7, 48.8, 48.4, 49.3, 49.9, 48.9] 12 | 13 | data = list(zip(algs, clean_accs, adv_accs)) 14 | columns = ['Algorithm', 'Clean Accuracy', 'Adversarial Accuracy'] 15 | df = pd.DataFrame(data, columns=columns) 16 | df['Clean error'] = 100 - df['Clean Accuracy'] 17 | df['Adversarial error'] = 100 - df['Adversarial Accuracy'] 18 | 19 | sns.set(style='darkgrid', font_scale=1.5, font='Palatino', palette='colorblind') 20 | 21 | g = sns.scatterplot( 22 | data=df, 23 | x='Clean error', 24 | y='Adversarial error', 25 | hue='Algorithm', 26 | marker='o', 27 | legend=False) 28 | g.set( 29 | title='Pareto Frontier') 30 | # xlim=(10, 30), # clean 31 | # ylim=(40, 65)) # adversarial 32 | 33 | for i in range(df.shape[0]): 34 | plt.text( 35 | x=df['Clean error'][i] - 0.3, 36 | y=df['Adversarial error'][i]+0.3, 37 | s=df.Algorithm[i], 38 | fontdict=dict(color='black', size=10)) 39 | 40 | plt.subplots_adjust(bottom=0.15) 41 | plt.show() 42 | 43 | -------------------------------------------------------------------------------- /advbench/plotting/primal_dual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | import pandas as pd 6 | import json 7 | import os 8 | 9 | from advbench.lib import reporting, plotting 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='Plot primal dual') 13 | parser.add_argument('--input_dir', type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | meters = reporting.load_records(args.input_dir, results_fname='meters', depth=0) 17 | results = reporting.load_records(args.input_dir, depth=0) 18 | 19 | with
open(os.path.join(args.input_dir, 'hparams.json'), 'r') as f: 20 | hparams = json.load(f) 21 | 22 | loss_meters = pd.melt( 23 | meters, 24 | id_vars=['Epoch'], 25 | value_vars=['clean loss', 'robust loss'], 26 | var_name='loss type', 27 | value_name='loss value') 28 | 29 | sns.set(style='darkgrid', font_scale=1.8, font='Palatino') 30 | fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(14, 5)) 31 | 32 | # plot training/test accuracies 33 | g = sns.lineplot( 34 | data=results, 35 | x='Epoch', 36 | y='Accuracy', 37 | hue='Eval-Method', 38 | ax=ax1, 39 | # marker='o', 40 | linewidth=4, 41 | palette=['#f26b22', '#6ed2fa', '#ff4d82']) 42 | name_dict = {'ERM': 'Clean', 'PGD_Linf': r'PGD$^{20}$', 'FGSM_Linf': 'FGSM'} 43 | g.set(title='Test accuracy', ylabel='Accuracy (%)') 44 | plotting.remove_legend_title(ax1, name_dict=name_dict, fontsize=16) 45 | 46 | 47 | # plot training losses -- divided into clean and adv 48 | g = sns.lineplot( 49 | data=loss_meters, 50 | x='Epoch', 51 | y='loss value', 52 | hue='loss type', 53 | ax=ax2, 54 | # marker='o', 55 | linewidth=4, 56 | palette=['#6ed2fa', '#f26b22']) 57 | g.set(ylabel='Loss', title='Training loss') 58 | ax2.axhline( 59 | hparams['g_dale_pd_margin'], 60 | ls='--', 61 | c='red', 62 | linewidth=2, 63 | zorder=10, 64 | label=r'Margin $\rho$') 65 | name_dict = {'clean loss': r'Nominal $\ell_{nom}$', 'robust loss': r'Robust $\ell_{ro}$', r'Margin $\rho$': r'Margin $\rho$'} 66 | plotting.remove_legend_title(ax2, name_dict=name_dict, fontsize=16) 67 | # plotting.tick_density(ax2, every=2, mod_val=1, axis='y') 68 | 69 | # plot dual variable 70 | g = sns.lineplot( 71 | data=meters, 72 | x='Epoch', 73 | y='dual variable', 74 | ax=ax3, 75 | # marker='o', 76 | linewidth=4, 77 | color='#f26b22') 78 | g.set(ylabel=r'Magnitude of $\nu$', title='Dual variable') 79 | 80 | plt.subplots_adjust(bottom=0.15) 81 | plt.tight_layout() 82 | 83 | save_path = os.path.join(args.input_dir, 'primal_dual.png') 84 | plt.savefig(save_path) 85 | 
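The primal-dual curves plotted above are driven by the dual-ascent update in `advbench/optimizers.py`, whose `relu` helper hard-codes a CUDA tensor. As an illustrative, device-agnostic sketch (not the repository's exact class), the same projected gradient-ascent step on the dual variable can be written as:

```python
import torch

# Sketch of PrimalDualOptimizer's dual update, with the CUDA-only
# relu helper replaced by torch.clamp so it also runs on CPU.
class PrimalDualSketch:
    def __init__(self, parameters, margin, eta):
        self.parameters = parameters  # dict holding 'dual_var'
        self.margin = margin          # constraint margin rho
        self.eta = eta                # dual step size

    def step(self, cost):
        # Projected dual ascent: nu <- max(0, nu + eta * (cost - rho))
        update = self.parameters['dual_var'] + self.eta * (cost - self.margin)
        self.parameters['dual_var'] = torch.clamp(update, min=0.0)

params = {'dual_var': torch.tensor(0.0)}
opt = PrimalDualSketch(params, margin=0.5, eta=0.1)
for cost in [1.0, 1.0, 0.2]:  # hypothetical robust losses per step
    opt.step(torch.tensor(cost))
# dual_var grows while the robust loss exceeds the margin and
# shrinks (but never goes negative) once it falls below it
print(params['dual_var'])
```

The clamp at zero is what keeps the dual variable feasible; the magnitude plotted in the third panel of `primal_dual.py` is exactly this quantity over training epochs.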
-------------------------------------------------------------------------------- /advbench/scripts/check_progress.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def main(args): 5 | 6 | n_dirs, n_done = 0, 0 7 | for d in os.listdir(args.input_dir): 8 | pth = os.path.join(args.input_dir, d) 9 | if os.path.isdir(pth) is True: 10 | done_pth = os.path.join(pth, 'done') 11 | if os.path.exists(done_pth): 12 | n_done += 1 13 | n_dirs += 1 14 | 15 | print(f'Completed tasks: {n_done}/{n_dirs} ({100 * float(n_done)/n_dirs:.2f}%)') 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='check progress of sweep') 20 | parser.add_argument('--input_dir', type=str, required=True) 21 | args = parser.parse_args() 22 | main(args) 23 | -------------------------------------------------------------------------------- /advbench/scripts/collect_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import prettytable 4 | import pandas as pd 5 | import sys 6 | import os 7 | import itertools 8 | 9 | from advbench.lib import reporting, misc 10 | from advbench import datasets 11 | 12 | def scrape_results(df, trials, loss_metric, split='Test'): 13 | 14 | all_best_dfs, all_last_dfs = [], [] 15 | for trial in trials: 16 | trial_df = df[(df['Trial-Seed'] == trial) & (df['Eval-Method'] == loss_metric) \ 17 | & (df.Split == split)] 18 | 19 | best_row = trial_df[trial_df.Loss == trial_df.Loss.min()] 20 | best_epoch = best_row.iloc[0]['Epoch'] 21 | best_path = best_row.iloc[0]['Output-Dir'] 22 | 23 | last_row = trial_df.iloc[-1]  # last record for *this* trial, mirroring best_row 24 | last_epoch = last_row['Epoch'] 25 | last_path = last_row['Output-Dir'] 26 | 27 | best_df = df[(df.Epoch == best_epoch) & (df['Output-Dir'] == best_path) \ 28 | & (df['Trial-Seed'] == trial)] 29 | last_df = df[(df.Epoch == last_epoch) & (df['Output-Dir'] == last_path) \ 30 | &
(df['Trial-Seed'] == trial)] 31 | all_best_dfs.append(best_df) 32 | all_last_dfs.append(last_df) 33 | 34 | best_df = pd.concat(all_best_dfs, ignore_index=True) 35 | last_df = pd.concat(all_last_dfs, ignore_index=True) 36 | 37 | return best_df, last_df 38 | 39 | if __name__ == '__main__': 40 | np.set_printoptions(suppress=True) 41 | 42 | parser = argparse.ArgumentParser(description='Collect results') 43 | parser.add_argument('--input_dir', type=str, required=True) 44 | parser.add_argument('--depth', type=int, default=1, help='Results directories search depth') 45 | args = parser.parse_args() 46 | 47 | sys.stdout = misc.Tee(os.path.join(args.input_dir, 'losses.txt'), 'w') 48 | 49 | records = reporting.load_records(args.input_dir, results_fname='losses', depth=args.depth) 50 | print(records) 51 | 52 | eval_methods = records['Eval-Method'].unique() 53 | dataset_names = records['Dataset'].unique() 54 | train_algs = records['Train-Alg'].unique() 55 | trials = records['Trial-Seed'].unique() 56 | 57 | for dataset in dataset_names: 58 | last_epoch = vars(datasets)[dataset].N_EPOCHS - 1 59 | 60 | for loss_metric in eval_methods: 61 | 62 | t = prettytable.PrettyTable() 63 | best_loss = [f'Best {m} Loss' for m in eval_methods] 64 | last_loss = [f'Last {m} Loss' for m in eval_methods] 65 | all_losses = list(itertools.chain(*zip(best_loss, last_loss))) 66 | t.field_names = ['Training Algorithm', *all_losses, 'Output-Dir'] 67 | print(f'\nSelection method: {loss_metric} loss.') 68 | for alg in train_algs: 69 | df = records[(records['Dataset'] == dataset) & (records['Train-Alg'] == alg)] 70 | best_df, last_df = scrape_results(df, trials, loss_metric, split='Test') 71 | 72 | best_losses = [best_df[best_df['Eval-Method'] == m].iloc[0]['Loss'] for m in eval_methods] 73 | last_losses = [last_df[last_df['Eval-Method'] == m].iloc[0]['Loss'] for m in eval_methods] 74 | all_losses = list(itertools.chain(*zip(best_losses, last_losses))) 75 | output_dir = best_df.iloc[0]['Output-Dir'] 76 | 
t.add_row([alg, *all_losses, output_dir]) 77 | 78 | print(t) -------------------------------------------------------------------------------- /advbench/scripts/collect_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import prettytable 4 | import pandas as pd 5 | import sys 6 | import os 7 | 8 | from advbench.lib import reporting, misc 9 | from advbench import model_selection 10 | 11 | if __name__ == '__main__': 12 | np.set_printoptions(suppress=True) 13 | 14 | parser = argparse.ArgumentParser(description='Collect results') 15 | parser.add_argument('--input_dir', type=str, required=True) 16 | parser.add_argument('--depth', type=int, default=1, help='Results directories search depth') 17 | parser.add_argument('--selection_methods', type=str, nargs='+', default=['LastStep', 'EarlyStop']) 18 | args = parser.parse_args() 19 | 20 | sys.stdout = misc.Tee(os.path.join(args.input_dir, 'results.txt'), 'w') 21 | 22 | selection_methods = [ 23 | vars(model_selection)[s] for s in args.selection_methods 24 | ] 25 | 26 | train_args = misc.read_dict( 27 | os.path.join(args.input_dir, 'args.json') 28 | ) 29 | selection_df = reporting.load_sweep_dataframes( 30 | path=args.input_dir, 31 | depth=1 32 | ) 33 | selection_metrics = [ 34 | k for k in selection_df.columns.values.tolist() 35 | if any(e in k for e in train_args['evaluators']) 36 | ] 37 | 38 | df = pd.melt( 39 | frame=selection_df, 40 | id_vars=['Split', 'Algorithm', 'trial_seed', 'seed', 'path', 'Epoch'] 41 | ).rename(columns={'variable': 'Metric-Name', 'value': 'Metric-Value'}) 42 | 43 | for method in selection_methods: 44 | for metric_name, metric_df in df.groupby('Metric-Name'): 45 | t = prettytable.PrettyTable() 46 | t.field_names = ['Algorithm', metric_name, 'Selection Method'] 47 | 48 | for algorithm, algorithm_df in metric_df.groupby('Algorithm'): 49 | selection = method(algorithm_df) 50 | vals = selection.trial_values 51 | 
mean, sd = np.mean(vals), np.std(vals) 52 | t.add_row([ 53 | algorithm, f'{mean:.4f} +/- {sd:.4f}', method.NAME 54 | ]) 55 | print(t) 56 | -------------------------------------------------------------------------------- /advbench/scripts/sweep.py: -------------------------------------------------------------------------------- 1 | import json 2 | import hashlib 3 | import os 4 | import copy 5 | import shlex 6 | import numpy as np 7 | import tqdm 8 | import shutil 9 | import argparse 10 | 11 | from advbench.lib import misc 12 | from advbench import algorithms 13 | from advbench import datasets 14 | from advbench import command_launchers 15 | 16 | def make_args_list(cl_args): 17 | 18 | def _make_args(trial_seed, dataset, algorithm, hparams_seed): 19 | return { 20 | 'dataset': dataset, 21 | 'algorithm': algorithm, 22 | 'hparams_seed': hparams_seed, 23 | 'data_dir': cl_args.data_dir, 24 | 'trial_seed': trial_seed, 25 | 'seed': misc.seed_hash(dataset, algorithm, hparams_seed, trial_seed), 26 | 'evaluators': cl_args.evaluators 27 | } 28 | 29 | args_list = [] 30 | for trial_seed in range(cl_args.n_trials): 31 | for dataset in cl_args.datasets: 32 | for algorithm in cl_args.algorithms: 33 | for hparams_seed in range(cl_args.n_hparams): 34 | args = _make_args(trial_seed, dataset, algorithm, hparams_seed) 35 | args_list.append(args) 36 | 37 | return args_list 38 | 39 | def ask_for_confirmation(): 40 | response = input('Are you sure? 
(y/n) ') 41 | if not response.lower().strip()[:1] == 'y': 42 | print('Nevermind!') 43 | exit(0) 44 | 45 | class Job: 46 | NOT_LAUNCHED = 'Not launched' 47 | INCOMPLETE = 'Incomplete' 48 | DONE = 'Done' 49 | 50 | def __init__(self, train_args, sweep_output_dir): 51 | args_str = json.dumps(train_args, sort_keys=True) 52 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 53 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 54 | 55 | self.train_args = copy.deepcopy(train_args) 56 | self.train_args['output_dir'] = self.output_dir 57 | command = ['python', '-m', 'advbench.scripts.train'] 58 | 59 | for k, v in sorted(self.train_args.items()): 60 | if isinstance(v, list): 61 | v = ' '.join([str(v_) for v_ in v]) 62 | elif isinstance(v, str): 63 | v = shlex.quote(v) 64 | 65 | command.append(f'--{k} {v}') 66 | self.command_str = ' '.join(command) 67 | 68 | if os.path.exists(os.path.join(self.output_dir, 'done')): 69 | self.state = Job.DONE 70 | elif os.path.exists(os.path.join(self.output_dir)): 71 | self.state = Job.INCOMPLETE 72 | else: 73 | self.state = Job.NOT_LAUNCHED 74 | 75 | def __str__(self): 76 | job_info = ( 77 | self.train_args['dataset'], 78 | self.train_args['algorithm'], 79 | self.train_args['hparams_seed'] 80 | ) 81 | return f'{self.state}: {self.output_dir} {job_info}' 82 | 83 | @staticmethod 84 | def launch(jobs, launcher_fn): 85 | print('Launching...') 86 | jobs = jobs.copy() 87 | np.random.shuffle(jobs) 88 | print('Making job directories:') 89 | for job in tqdm.tqdm(jobs, leave=False): 90 | os.makedirs(job.output_dir, exist_ok=True) 91 | commands = [job.command_str for job in jobs] 92 | launcher_fn(commands) 93 | print(f'Launched {len(jobs)} jobs!') 94 | 95 | @staticmethod 96 | def delete(jobs): 97 | print('Deleting...') 98 | for job in jobs: 99 | shutil.rmtree(job.output_dir) 100 | print(f'Deleted {len(jobs)} jobs!') 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser(description='Run a sweep') 105 
| parser.add_argument('command', choices=['launch', 'delete_incomplete']) 106 | parser.add_argument('--datasets', nargs='+', type=str, default=datasets.DATASETS) 107 | parser.add_argument('--algorithms', nargs='+', type=str, default=algorithms.ALGORITHMS) 108 | parser.add_argument('--n_hparams', type=int, default=20) 109 | parser.add_argument('--output_dir', type=str, required=True) 110 | parser.add_argument('--data_dir', type=str, required=True) 111 | parser.add_argument('--seed', type=int, default=0) 112 | parser.add_argument('--n_trials', type=int, default=1) 113 | parser.add_argument('--command_launcher', type=str, default='multi_gpu') 114 | parser.add_argument('--hparams', type=str, default=None) 115 | parser.add_argument('--evaluators', type=str, nargs='+', default=['Clean']) 116 | parser.add_argument('--skip_confirmation', action='store_true') 117 | args = parser.parse_args() 118 | 119 | args_list = make_args_list(args) 120 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 121 | 122 | for job in jobs: 123 | print(job) 124 | 125 | done_jobs = len([j for j in jobs if j.state == Job.DONE]) 126 | incomp_jobs = len([j for j in jobs if j.state == Job.INCOMPLETE]) 127 | unlaunched_jobs = len([j for j in jobs if j.state == Job.NOT_LAUNCHED]) 128 | print(f'{len(jobs)} jobs: {done_jobs} done, {incomp_jobs} incomplete, {unlaunched_jobs} not launched.') 129 | 130 | if args.command == 'launch': 131 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 132 | print(f'About to launch {len(to_launch)} jobs.') 133 | if not args.skip_confirmation: 134 | ask_for_confirmation() 135 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 136 | Job.launch(to_launch, launcher_fn) 137 | 138 | elif args.command == 'delete_incomplete': 139 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 140 | print(f'About to delete {len(to_delete)} jobs.') 141 | if not args.skip_confirmation: 142 | ask_for_confirmation() 143 | Job.delete(to_delete) 144 | 145 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 146 | json.dump(args.__dict__, f, indent=2) -------------------------------------------------------------------------------- /advbench/scripts/train.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import os 5 | import json 6 | import pandas as pd 7 | import time 8 | import collections 9 | from humanfriendly import format_timespan 10 | 11 | from advbench import datasets 12 | from advbench import algorithms 13 | from advbench import evalulation_methods 14 | from advbench import hparams_registry 15 | from advbench.lib import misc, meters, reporting 16 | 17 | def main(args, hparams, test_hparams): 18 | 19 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 20 | torch.manual_seed(0) 21 | 22 | # paths for saving output 23 | json_path = os.path.join(args.output_dir, 'results.json') 24 | ckpt_path = misc.stage_path(args.output_dir, 'ckpts') 25 | train_df_path = os.path.join(args.output_dir, 'train.pd') 26 | selection_df_path = os.path.join(args.output_dir, 'selection.pd') 27 | 28 | dataset = vars(datasets)[args.dataset](args.data_dir, device) 29 | 30 | train_loader = DataLoader( 31 | dataset=dataset.splits['train'], 32 | batch_size=hparams['batch_size'], 33 | num_workers=dataset.N_WORKERS, 34 | pin_memory=False, 35 | shuffle=True) 36 | validation_loader = DataLoader( 37 | dataset=dataset.splits['validation'], 38 | batch_size=hparams['batch_size'], 39 | num_workers=dataset.N_WORKERS, 40 | pin_memory=False, 41 | shuffle=False) 42 | test_loader = DataLoader( 43 | dataset=dataset.splits['test'], 44 | batch_size=100, 45 | num_workers=dataset.N_WORKERS, 46 | pin_memory=False, 47 | shuffle=False) 48 | 49 | algorithm = vars(algorithms)[args.algorithm]( 50 | dataset.INPUT_SHAPE, 51 | dataset.NUM_CLASSES, 52 | hparams, 53 | device).to(device) 54 | 55 | def save_checkpoint(epoch): 56 | torch.save( 57 | obj={'state_dict': algorithm.state_dict()}, 58 | f=os.path.join(ckpt_path, f'model_ckpt_{epoch}.pkl') 59 | ) 60 | 61 | evaluators = [ 62 | vars(evalulation_methods)[e]( 63 | algorithm=algorithm, 64 | 
device=device, 65 | output_dir=args.output_dir, 66 | test_hparams=test_hparams) 67 | for e in args.evaluators] 68 | 69 | adjust_lr = None if dataset.HAS_LR_SCHEDULE is False else dataset.adjust_lr 70 | 71 | total_time = 0 72 | for epoch in range(0, dataset.N_EPOCHS): 73 | 74 | if adjust_lr is not None: 75 | adjust_lr(algorithm.optimizer, epoch, hparams) 76 | 77 | timer = meters.TimeMeter() 78 | epoch_start = time.time() 79 | for batch_idx, (imgs, labels) in enumerate(train_loader): 80 | 81 | timer.batch_start() 82 | if not dataset.ON_DEVICE: 83 | imgs, labels = imgs.to(device), labels.to(device) 84 | algorithm.step(imgs, labels) 85 | 86 | if batch_idx % dataset.LOG_INTERVAL == 0: 87 | print(f'Epoch {epoch+1}/{dataset.N_EPOCHS} ', end='') 88 | print(f'[{batch_idx * imgs.size(0)}/{len(train_loader.dataset)}', end=' ') 89 | print(f'({100. * batch_idx / len(train_loader):.0f}%)]\t', end='') 90 | for name, meter in algorithm.meters.items(): 91 | print(f'{name}: {meter.val:.3f} (avg. {meter.avg:.3f})\t', end='') 92 | print(f'Time: {timer.batch_time.val:.3f} (avg. 
{timer.batch_time.avg:.3f})') 93 | 94 | timer.batch_end() 95 | 96 | results = {'Epoch': epoch, 'Train': {}, 'Validation': {}, 'Test': {}} 97 | 98 | for name, meter in algorithm.meters.items(): 99 | results['Train'].update({name: meter.avg}) 100 | 101 | print('\nTrain') 102 | misc.print_row([key for key in results['Train'].keys()]) 103 | misc.print_row([results['Train'][key] for key in results['Train'].keys()]) 104 | 105 | for evaluator in evaluators: 106 | for k, v in evaluator.calculate(validation_loader).items(): 107 | results['Validation'].update({k: v}) 108 | 109 | print('\nValidation') 110 | misc.print_row([key for key in results['Validation'].keys()]) 111 | misc.print_row([results['Validation'][key] for key in results['Validation'].keys()]) 112 | 113 | for evaluator in evaluators: 114 | for k, v in evaluator.calculate(test_loader).items(): 115 | results['Test'].update({k: v}) 116 | 117 | print('\nTest') 118 | misc.print_row([key for key in results['Test'].keys()]) 119 | misc.print_row([results['Test'][key] for key in results['Test'].keys()]) 120 | 121 | epoch_time = time.time() - epoch_start 122 | total_time += epoch_time 123 | 124 | results.update({ 125 | 'Epoch-Time': epoch_time, 126 | 'Total-Time': total_time}) 127 | 128 | # print results 129 | print(f'Epoch: {epoch+1}/{dataset.N_EPOCHS}\t', end='') 130 | print(f'Epoch time: {format_timespan(epoch_time)}\t', end='') 131 | print(f'Total time: {format_timespan(total_time)}') 132 | 133 | results.update({'hparams': hparams, 'args': vars(args)}) 134 | 135 | with open(json_path, 'a') as f: 136 | f.write(json.dumps(results, sort_keys=True) + '\n') 137 | 138 | if args.save_model_every_epoch is True: 139 | save_checkpoint(epoch) 140 | 141 | algorithm.reset_meters() 142 | 143 | save_checkpoint('final') 144 | 145 | records = reporting.load_record(json_path) 146 | 147 | train_dict = collections.defaultdict(lambda: []) 148 | validation_dict = collections.defaultdict(lambda: []) 149 | test_dict = 
collections.defaultdict(lambda: []) 150 | 151 | for record in records: 152 | for k in records[0]['Train'].keys(): 153 | train_dict[k].append(record['Train'][k]) 154 | 155 | for k in records[0]['Validation'].keys(): 156 | validation_dict[k].append(record['Validation'][k]) 157 | test_dict[k].append(record['Test'][k]) 158 | 159 | def dict_to_dataframe(split, d): 160 | df = pd.DataFrame.from_dict(d) 161 | df['Split'] = split 162 | df = df.join(pd.DataFrame({ 163 | 'Algorithm': args.algorithm, 164 | 'trial_seed': args.trial_seed, 165 | 'seed': args.seed, 166 | 'path': args.output_dir 167 | }, index=df.index)) 168 | df['Epoch'] = range(dataset.N_EPOCHS) 169 | return df 170 | 171 | train_df = dict_to_dataframe('Train', train_dict) 172 | validation_df = dict_to_dataframe('Validation', validation_dict) 173 | test_df = dict_to_dataframe('Test', test_dict) 174 | selection_df = pd.concat([validation_df, test_df], ignore_index=True) 175 | 176 | train_df.to_pickle(train_df_path) 177 | selection_df.to_pickle(selection_df_path) 178 | 179 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 180 | f.write('done') 181 | 182 | if __name__ == '__main__': 183 | 184 | parser = argparse.ArgumentParser(description='Adversarial robustness') 185 | parser.add_argument('--data_dir', type=str, default='./advbench/data') 186 | parser.add_argument('--output_dir', type=str, default='train_output') 187 | parser.add_argument('--dataset', type=str, default='MNIST', help='Dataset to use') 188 | parser.add_argument('--algorithm', type=str, default='ERM', help='Algorithm to run') 189 | parser.add_argument('--hparams', type=str, help='JSON-serialized hparams dict') 190 | parser.add_argument('--hparams_seed', type=int, default=0, help='Seed for hyperparameters') 191 | parser.add_argument('--trial_seed', type=int, default=0, help='Trial number') 192 | parser.add_argument('--seed', type=int, default=0, help='Seed for everything else') 193 | parser.add_argument('--evaluators', type=str, nargs='+', 
default=['Clean']) 194 | parser.add_argument('--save_model_every_epoch', action='store_true') 195 | args = parser.parse_args() 196 | 197 | os.makedirs(os.path.join(args.output_dir), exist_ok=True) 198 | 199 | print('Args:') 200 | for k, v in sorted(vars(args).items()): 201 | print(f'\t{k}: {v}') 202 | 203 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 204 | json.dump(args.__dict__, f, indent=2) 205 | 206 | if args.dataset not in vars(datasets): 207 | raise NotImplementedError(f'Dataset {args.dataset} is not implemented.') 208 | 209 | if args.hparams_seed == 0: 210 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 211 | else: 212 | seed = misc.seed_hash(args.hparams_seed, args.trial_seed) 213 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, seed) 214 | 215 | print('Hparams:') 216 | for k, v in sorted(hparams.items()): 217 | print(f'\t{k}: {v}') 218 | 219 | with open(os.path.join(args.output_dir, 'hparams.json'), 'w') as f: 220 | json.dump(hparams, f, indent=2) 221 | 222 | test_hparams = hparams_registry.test_hparams(args.algorithm, args.dataset) 223 | 224 | print('Test hparams:') 225 | for k, v in sorted(test_hparams.items()): 226 | print(f'\t{k}: {v}') 227 | 228 | with open(os.path.join(args.output_dir, 'test_hparams.json'), 'w') as f: 229 | json.dump(test_hparams, f, indent=2) 230 | 231 | main(args, hparams, test_hparams) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=4.5=1_gnu 6 | blas=1.0=mkl 7 | bottleneck=1.3.2=py39hdd57654_1 8 | brotli=1.0.9=he6710b0_2 9 | bzip2=1.0.8=h7b6447c_0 10 | ca-certificates=2021.9.30=h06a4308_1 11 | certifi=2021.10.8=py39h06a4308_0 12 | cudatoolkit=11.1.74=h6bb024c_0 13
| cycler=0.10.0=py39h06a4308_0 14 | dbus=1.13.18=hb2f20db_0 15 | expat=2.4.1=h2531618_2 16 | ffmpeg=4.3=hf484d3e_0 17 | fontconfig=2.13.1=h6c09931_0 18 | fonttools=4.25.0=pyhd3eb1b0_0 19 | freetype=2.10.4=h5ab3b9f_0 20 | glib=2.69.1=h5202010_0 21 | gmp=6.2.1=h2531618_2 22 | gnutls=3.6.15=he1e5248_0 23 | gst-plugins-base=1.14.0=h8213a91_2 24 | gstreamer=1.14.0=h28cd5cc_2 25 | humanfriendly=9.2=py39h06a4308_0 26 | icu=58.2=he6710b0_3 27 | intel-openmp=2021.3.0=h06a4308_3350 28 | jpeg=9b=h024ee3a_2 29 | kiwisolver=1.3.1=py39h2531618_0 30 | lame=3.100=h7b6447c_0 31 | lcms2=2.12=h3be6417_0 32 | ld_impl_linux-64=2.35.1=h7274673_9 33 | libffi=3.3=he6710b0_2 34 | libgcc-ng=9.3.0=h5101ec6_17 35 | libgfortran-ng=7.5.0=ha8ba4b0_17 36 | libgfortran4=7.5.0=ha8ba4b0_17 37 | libgomp=9.3.0=h5101ec6_17 38 | libiconv=1.15=h63c8f33_5 39 | libidn2=2.3.2=h7f8727e_0 40 | libpng=1.6.37=hbc83047_0 41 | libstdcxx-ng=9.3.0=hd4cf53a_17 42 | libtasn1=4.16.0=h27cfd23_0 43 | libtiff=4.2.0=h85742a9_0 44 | libunistring=0.9.10=h27cfd23_0 45 | libuuid=1.0.3=h7f8727e_2 46 | libuv=1.40.0=h7b6447c_0 47 | libwebp-base=1.2.0=h27cfd23_0 48 | libxcb=1.14=h7b6447c_0 49 | libxml2=2.9.12=h03d6c58_0 50 | lz4-c=1.9.3=h295c915_1 51 | matplotlib=3.4.2=py39h06a4308_0 52 | matplotlib-base=3.4.2=py39hab158f2_0 53 | mkl=2021.3.0=h06a4308_520 54 | mkl-service=2.4.0=py39h7f8727e_0 55 | mkl_fft=1.3.0=py39h42c9631_2 56 | mkl_random=1.2.2=py39h51133e4_0 57 | munkres=1.1.4=py_0 58 | ncurses=6.2=he6710b0_1 59 | nettle=3.7.3=hbbd107a_1 60 | ninja=1.10.2=hff7bd54_1 61 | numexpr=2.7.3=py39h22e1b3c_1 62 | numpy=1.20.3=py39hf144106_0 63 | numpy-base=1.20.3=py39h74d4b33_0 64 | olefile=0.46=pyhd3eb1b0_0 65 | openh264=2.1.0=hd408876_0 66 | openjpeg=2.4.0=h3ad879b_0 67 | openssl=1.1.1l=h7f8727e_0 68 | pandas=1.3.3=py39h8c16a72_0 69 | pcre=8.45=h295c915_0 70 | pillow=8.3.1=py39h2c7a002_0 71 | pip=21.2.4=py39h06a4308_0 72 | prettytable=2.2.1=pypi_0 73 | pyparsing=2.4.7=pyhd3eb1b0_0 74 | pyqt=5.9.2=py39h2531618_6 75 | 
python=3.9.7=h12debd9_1 76 | python-dateutil=2.8.2=pyhd3eb1b0_0 77 | pytorch=1.9.1=py3.9_cuda11.1_cudnn8.0.5_0 78 | pytz=2021.3=pyhd3eb1b0_0 79 | qt=5.9.7=h5867ecd_1 80 | readline=8.1=h27cfd23_0 81 | scipy=1.7.1=py39h292c36d_2 82 | seaborn=0.11.2=pyhd3eb1b0_0 83 | setuptools=58.0.4=py39h06a4308_0 84 | sip=4.19.13=py39h2531618_0 85 | six=1.16.0=pyhd3eb1b0_0 86 | sqlite=3.36.0=hc218d9a_0 87 | tk=8.6.11=h1ccaba5_0 88 | torchaudio=0.9.1=py39 89 | torchvision=0.10.1=py39_cu111 90 | tornado=6.1=py39h27cfd23_0 91 | tqdm=4.62.2=pyhd3eb1b0_1 92 | typing_extensions=3.10.0.2=pyh06a4308_0 93 | tzdata=2021a=h5d7bf9c_0 94 | wcwidth=0.2.5=pypi_0 95 | wheel=0.37.0=pyhd3eb1b0_1 96 | xz=5.2.5=h7b6447c_0 97 | zlib=1.2.11=h7b6447c_3 98 | zstd=1.4.9=haebb681_0 99 | --------------------------------------------------------------------------------
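A note on the results format: after training, `train.py` reloads the JSON-lines `results.json` file (one record per epoch) and pivots the per-epoch records into per-split `{metric: [value per epoch]}` dictionaries before building the pandas dataframes. The following is a minimal, self-contained sketch of that pivot, using hypothetical two-epoch records in place of the output of `reporting.load_record`:

```python
import collections
import json

# Hypothetical stand-ins for two epochs of results, mirroring the records
# that train.py appends to results.json (one JSON object per line).
lines = [
    '{"Epoch": 0, "Train": {"Loss": 0.9}, "Validation": {"Clean-Acc": 0.85}, "Test": {"Clean-Acc": 0.84}}',
    '{"Epoch": 1, "Train": {"Loss": 0.5}, "Validation": {"Clean-Acc": 0.91}, "Test": {"Clean-Acc": 0.90}}',
]
records = [json.loads(line) for line in lines]

# Pivot the per-epoch records into per-split {metric: [values]} dicts,
# as done just before the DataFrame construction in train.py.
train_dict = collections.defaultdict(list)
validation_dict = collections.defaultdict(list)
test_dict = collections.defaultdict(list)

for record in records:
    for k in records[0]['Train']:
        train_dict[k].append(record['Train'][k])
    for k in records[0]['Validation']:
        validation_dict[k].append(record['Validation'][k])
        test_dict[k].append(record['Test'][k])

print(train_dict['Loss'])            # [0.9, 0.5]
print(validation_dict['Clean-Acc'])  # [0.85, 0.91]
```

Each resulting list has one entry per epoch, which is why `dict_to_dataframe` can attach `df['Epoch'] = range(dataset.N_EPOCHS)` as an index column afterwards.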