├── core
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── standard_resnet.py
│   │   └── pdeadd_resnet.py
│   ├── scheduler.py
│   ├── context.py
│   ├── attacks
│   │   ├── base.py
│   │   ├── apgd.py
│   │   ├── __init__.py
│   │   ├── fgsm.py
│   │   ├── utils.py
│   │   ├── deepfool.py
│   │   └── pgd.py
│   ├── parse.py
│   ├── trainfn.py
│   ├── utils.py
│   ├── testfn.py
│   └── data.py
├── pic
│   ├── pdeadd.png
│   └── results.png
├── scripts
│   ├── train
│   │   ├── run_resume.sh
│   │   ├── std_cifar10.sh
│   │   ├── std_tin200.sh
│   │   ├── std_pacs.sh
│   │   ├── pdeadd_cifar10.sh
│   │   ├── pdeadd_pacs.sh
│   │   └── pdeadd_tin200.sh
│   ├── download_cifar10c.sh
│   └── download_cifar100c.sh
├── .gitignore
├── requirements.txt
├── results
│   ├── cifar10c
│   │   └── pdeadd.csv
│   └── cifar100c
│       └── pdeadd.csv
├── README.md
├── test.py
└── train.py
/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pic/pdeadd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyige/pde-add/HEAD/pic/pdeadd.png -------------------------------------------------------------------------------- /pic/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanyige/pde-add/HEAD/pic/results.png -------------------------------------------------------------------------------- /scripts/train/run_resume.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 \ 2 | python train.py \ 3 | --resume-file /home/yuanyige/Ladiff_nll/exps_cifar10/dall/wideresnet-16-4_ladiff-augdiff_ntrall_\(Csgdlr0.1cosanlr-Dadamlr0.1\)_e400_b128_atr-augmix-6-4/train/model-best-endiff.pt \ 4 | --epoch 400 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | datasets 2 | 3 | test 4 | .others 5 | .DS_Store 6 | 7 | save* 8 | exps_cifar10 9 | exps_cifar100 10 | 11 | core/__pycache__ 12 | core/attacks/__pycache__ 13 | core/models/__pycache__ 14 | core/models/cifar10/__pycache__ 15 | core/models/cifar100/__pycache__ 16 | core/models/tinyin200/__pycache__ -------------------------------------------------------------------------------- /scripts/train/std_cifar10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | python train.py \ 3 | --data cifar10 \ 4 | --protocol standard \ 5 | --desc none \ 6 | --backbone resnet-18 \ 7 | --epoch 200 \ 8 | --save-dir save \ 9 | --npc-train all \ 10 | --lrC 0.02 \ 11 | --scheduler cosanlr \ 12 | --aug-train none 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autoattack @ git+https://github.com/fra31/auto-attack@c1ec340e54a227c87c6601ada3abe0910ac4a2c0 2 | robustbench @ git+https://github.com/RobustBench/robustbench.git@2317b196b482abd3523f49d368107f985e6ac9bb 3 | numpy==1.19.5 4 | pandas==1.3.5 5 | PyYAML==6.0 6 | tensorboard==2.11.0 7 | tensorboardX==2.6 8 | -------------------------------------------------------------------------------- /scripts/train/std_tin200.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 \ 2 | python train.py \ 3 | --data tin200 \ 4 | --protocol standard \ 5 | --desc none \ 6 | --backbone 
resnet-18 \ 7 | --epoch 200 \ 8 | --save-dir save_tin200 \ 9 | --npc-train all \ 10 | --lrC 0.02 \ 11 | --scheduler cosanlr \ 12 | --aug-train none 13 | 14 | -------------------------------------------------------------------------------- /scripts/train/std_pacs.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 \ 2 | python train.py \ 3 | --data pacs-art \ 4 | --protocol standard \ 5 | --desc none \ 6 | --backbone resnet-18 \ 7 | --epoch 200 \ 8 | --save-dir save_pacs/art \ 9 | --npc-train all \ 10 | --lrC 0.02 \ 11 | --scheduler cosanlr \ 12 | --aug-train augmix-10-10 13 | 14 | -------------------------------------------------------------------------------- /scripts/train/pdeadd_cifar10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | python train.py \ 3 | --data cifar10 \ 4 | --data-diff cifar10 \ 5 | --protocol pdeadd \ 6 | --desc none \ 7 | --backbone resnet-18 \ 8 | --epoch 200 \ 9 | --save-dir save \ 10 | --npc-train all \ 11 | --lrC 0.06 \ 12 | --lrDiff 0.001 \ 13 | --scheduler cosanlr \ 14 | --ls 0.12 \ 15 | --use-gmm \ 16 | --aug-train augmix-10-10 \ 17 | --aug-train-diff augmix-10-10 -------------------------------------------------------------------------------- /scripts/train/pdeadd_pacs.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 \ 2 | python train.py \ 3 | --data pacs-art \ 4 | --data-diff pacs-cartoon \ 5 | --seed 3407 \ 6 | --protocol pdeadd \ 7 | --desc none \ 8 | --backbone resnet-18 \ 9 | --epoch 200 \ 10 | --save-dir save_pacs/art-cartoon \ 11 | --npc-train all \ 12 | --lrC 0.015 \ 13 | --lrDiff 0.01 \ 14 | --scheduler cosanlr \ 15 | --ls 0.1 \ 16 | --aug-train augmix-10-10 \ 17 | --aug-train-diff augmix-10-10 18 | -------------------------------------------------------------------------------- /scripts/train/pdeadd_tin200.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 \ 2 | python train.py \ 3 | --data tin200 \ 4 | --data-diff tin200 \ 5 | --seed 3407 \ 6 | --protocol pdeadd \ 7 | --desc none \ 8 | --backbone resnet-18 \ 9 | --epoch 200 \ 10 | --save-dir save_tin200 \ 11 | --npc-train all \ 12 | --lrC 0.02 \ 13 | --lrDiff 0.001 \ 14 | --scheduler cosanlr \ 15 | --ls 0.1 \ 16 | --use-gmm \ 17 | --aug-train augmix-10-10 \ 18 | --aug-train-diff augmix-10-10 19 | -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .standard_resnet import standard_resnet 6 | from .pdeadd_resnet import pdeadd_resnet 7 | 8 | def create_model(backbone, protocol, num_classes): 9 | net = backbone.split('-')[0] 10 | model_name = "{}_{}".format(protocol, net) 11 | print("using name: {}..".format(model_name)) 12 | func = eval(model_name) 13 | return func(name=backbone, num_classes=num_classes) 14 | -------------------------------------------------------------------------------- /core/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, MultiStepLR 2 | 3 | class WarmUpLR(_LRScheduler): 4 | """Warm-up learning rate scheduler. 5 | Args: 6 | optimizer: optimizer (e.g. 
SGD) 7 | total_iters: total_iters of the warm-up phase 8 | """ 9 | def __init__(self, optimizer, total_iters, last_epoch=-1): 10 | 11 | self.total_iters = total_iters 12 | super().__init__(optimizer, last_epoch) 13 | 14 | def get_lr(self): 15 | """Use the first m batches and set the learning 16 | rate to base_lr * m / total_iters 17 | """ 18 | return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs] 19 | 20 | 21 | def get_scheduler(args, opt): 22 | if args.scheduler == "cosanlr": 23 | return CosineAnnealingLR(opt, T_max=args.epoch) 24 | elif args.scheduler == "mslr": 25 | return MultiStepLR(opt, milestones=[60, 120, 160], gamma=0.2) #60, 120, 160 #200, 250, 300, 350 26 | elif args.scheduler == "none": 27 | return None 28 | else: 29 | raise ValueError('unknown scheduler: {}'.format(args.scheduler)) 30 | -------------------------------------------------------------------------------- /scripts/download_cifar10c.sh: -------------------------------------------------------------------------------- 1 | FILE_ORG="./dataset/cifar-10-python.tar.gz" 2 | FILE_COR="./dataset/CIFAR-10-C.tar" 3 | 4 | # download cifar-10 5 | if [[ -f "$FILE_ORG" ]]; then 6 | echo "$FILE_ORG exists." 7 | else 8 | echo "$FILE_ORG does not exist. Start downloading..." 9 | if [[ ! -d "./dataset" ]]; then 10 | mkdir dataset 11 | fi 12 | cd dataset 13 | wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz && cd .. # return to the repo root so the relative paths below stay valid 14 | fi 15 | 16 | # download cifar-10-c 17 | if [[ -f "$FILE_COR" ]]; then 18 | echo "$FILE_COR exists. Start processing..." 19 | else 20 | echo "$FILE_COR does not exist. Start downloading..." 21 | cd dataset 22 | wget https://zenodo.org/api/files/a35f793a-6997-4603-a92f-926a6bc0fa60/CIFAR-10-C.tar && cd .. 23 | echo "Download succeeded. Start processing..." 24 | fi 25 | 26 | cd dataset 27 | # unzip downloaded files 28 | tar -zxvf cifar-10-python.tar.gz 29 | tar -xvf CIFAR-10-C.tar 30 | 31 | # process data 32 | cd .. 33 | python process_cifar.py cifar-10c 34 | 35 | # for CIFAR-10-C, move original data to "severity-all" 36 | cd dataset/CIFAR-10-C 37 | if [[ ! -d "./corrupted/severity-all" ]]; then 38 | mkdir ./corrupted/severity-all 39 | fi 40 | mv *.npy ./corrupted/severity-all -------------------------------------------------------------------------------- /scripts/download_cifar100c.sh: -------------------------------------------------------------------------------- 1 | FILE_ORG="./dataset/cifar-100-python.tar.gz" 2 | FILE_COR="./dataset/CIFAR-100-C.tar" 3 | 4 | # download cifar-100 5 | if [[ -f "$FILE_ORG" ]]; then 6 | echo "$FILE_ORG exists." 7 | else 8 | echo "$FILE_ORG does not exist. Start downloading..." 9 | if [[ ! -d "./dataset" ]]; then 10 | mkdir dataset 11 | fi 12 | cd dataset 13 | wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz && cd .. # return to the repo root so the relative paths below stay valid 14 | fi 15 | 16 | # download cifar-100-c 17 | if [[ -f "$FILE_COR" ]]; then 18 | echo "$FILE_COR exists. Start processing..." 19 | else 20 | echo "$FILE_COR does not exist. Start downloading..." 21 | cd dataset 22 | wget https://zenodo.org/api/files/8fafaa0e-d7e5-448b-a5af-e8b5e1bd24ce/CIFAR-100-C.tar && cd .. 23 | echo "Download succeeded. Start processing..." 24 | fi 25 | 26 | cd dataset 27 | # unzip downloaded files 28 | tar -zxvf cifar-100-python.tar.gz 29 | tar -xvf CIFAR-100-C.tar 30 | 31 | # process data 32 | cd .. 33 | python process_cifar.py cifar-100c 34 | 35 | # for CIFAR-100-C, move original data to "severity-all" 36 | cd dataset/CIFAR-100-C 37 | if [[ ! 
-d "./corrupted/severity-all" ]]; then 38 | mkdir ./corrupted/severity-all 39 | fi 40 | mv *.npy ./corrupted/severity-all -------------------------------------------------------------------------------- /results/cifar10c/pdeadd.csv: -------------------------------------------------------------------------------- 1 | 1 2 3 4 5 avg 2 | snow 0.93860 0.9059 0.905300 0.884700 0.87020 0.900940 3 | fog 0.95520 0.9522 0.948100 0.936300 0.88620 0.935600 4 | frost 0.94140 0.9189 0.893000 0.889400 0.84810 0.898160 5 | glass_blur 0.79780 0.8016 0.825200 0.674200 0.71740 0.763240 6 | defocus_blur 0.95560 0.9541 0.947800 0.937300 0.91110 0.941180 7 | motion_blur 0.94610 0.9296 0.913100 0.915700 0.89110 0.919120 8 | zoom_blur 0.94160 0.9417 0.932600 0.923300 0.90140 0.928120 9 | gaussian_noise 0.92640 0.8853 0.826600 0.792600 0.74860 0.835900 10 | shot_noise 0.93930 0.9227 0.877400 0.850500 0.79160 0.876300 11 | impulse_noise 0.93840 0.9210 0.896600 0.815500 0.68330 0.850960 12 | pixelate 0.94740 0.9262 0.910600 0.820600 0.69790 0.860540 13 | brightness 0.95610 0.9535 0.950400 0.947200 0.93500 0.948440 14 | contrast 0.95520 0.9493 0.944600 0.935300 0.87670 0.932220 15 | jpeg_compression 0.91230 0.8855 0.875300 0.857000 0.82760 0.871540 16 | elastic_transform 0.93550 0.9360 0.928200 0.884700 0.83590 0.904060 17 | average 0.93246 0.9189 0.904987 0.870953 0.82814 0.891088 -------------------------------------------------------------------------------- /results/cifar100c/pdeadd.csv: -------------------------------------------------------------------------------- 1 | 1 2 3 4 5 avg 2 | snow 0.754900 0.67830 0.679200 0.63710 0.584600 0.666820 3 | fog 0.782600 0.76120 0.732400 0.68110 0.534500 0.698360 4 | frost 0.735700 0.67990 0.609700 0.60250 0.529800 0.631520 5 | glass_blur 0.500500 0.51430 0.533700 0.38880 0.412100 0.469880 6 | defocus_blur 0.788600 0.78620 0.775700 0.75580 0.718800 0.765020 7 | motion_blur 0.774300 0.75210 0.723900 0.72010 0.683600 0.730800 8 | zoom_blur 0.770400 0.76500 0.750700 0.73280 0.700000 0.743780 9 | gaussian_noise 0.695000 0.59270 0.472600 0.41690 0.362800 0.508000 10 | shot_noise 0.736900 0.68970 0.570000 0.51780 0.426200 0.588120 11 | impulse_noise 0.763900 0.72020 0.668500 0.53140 0.390000 0.614800 12 | pixelate 0.773800 0.74670 0.730000 0.64930 0.513200 0.682600 13 | brightness 0.788100 0.78030 0.770600 0.75590 0.714400 0.761860 14 | contrast 0.784000 0.74560 0.713600 0.65330 0.432000 0.665700 15 | jpeg_compression 0.689300 0.63610 0.612400 0.59090 0.550800 0.615900 16 | elastic_transform 0.747500 0.75310 0.745300 0.67290 0.580300 0.699820 17 | average 0.739033 0.70676 0.672553 0.62044 0.542207 0.656199 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDE+: Enhancing Generalization via PDE with Adaptive Distributional Diffusion 2 | 3 | > Yige Yuan, Bingbing Xu, Bo Lin, Liang Hou, Fei Sun, Huawei Shen, Xueqi Cheng 4 | > 5 | > The 38th Annual AAAI Conference on Artificial Intelligence (AAAI), 2024 6 | 7 | This is an official PyTorch implementation of paper [PDE+: Enhancing Generalization via PDE with Adaptive Distributional Diffusion](https://arxiv.org/abs/2305.15835). 8 | 9 | ![PDE+](pic/pdeadd.png) 10 | 11 | 12 | ## Training & Testing 13 | 14 | All arguments are located in the **parse.py** file. You can create a script to specify the parameters. 
15 | 16 | For example, you can run our PDE+ by using the following command: 17 | ``` 18 | bash ./scripts/train/pdeadd_cifar10.sh 19 | ``` 20 | Or you can run the standard ERM baseline by using the following command: 21 | ``` 22 | bash ./scripts/train/std_cifar10.sh 23 | ``` 24 | 25 | ## Full Results 26 | 27 | All detailed experimental results, formatted as CSV files, are available in the **results** directory. 28 | 29 | ![Results on Corruption Datasets](pic/results.png) 30 | 31 | ## Reference 32 | 33 | If you find our work useful, please consider citing our paper: 34 | ``` 35 | @article{yuan2023pde+, 36 | title={PDE+: Enhancing Generalization via PDE with Adaptive Distributional Diffusion}, 37 | author={Yuan, Yige and Xu, Bingbing and Lin, Bo and Hou, Liang and Sun, Fei and Shen, Huawei and Cheng, Xueqi}, 38 | journal={arXiv preprint arXiv:2305.15835}, 39 | year={2023} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /core/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | 4 | class ctx_noparamgrad(object): 5 | def __init__(self, module): 6 | self.prev_grad_state = get_param_grad_state(module) 7 | self.module = module 8 | set_param_grad_off(module) 9 | 10 | def __enter__(self): 11 | pass 12 | 13 | def __exit__(self, *args): 14 | set_param_grad_state(self.module, self.prev_grad_state) 15 | return False 16 | 17 | 18 | class ctx_eval(object): 19 | def __init__(self, module): 20 | self.prev_training_state = get_module_training_state(module) 21 | self.module = module 22 | set_module_training_off(module) 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, *args): 28 | set_module_training_state(self.module, self.prev_training_state) 29 | return False 30 | 31 | 32 | @contextmanager 33 | def ctx_noparamgrad_and_eval(module): 34 | with ctx_noparamgrad(module) as a, ctx_eval(module) as b: 35 | yield (a, b) 36 | 37 | 38 | def get_module_training_state(module): 39 | return {mod: mod.training for mod in module.modules()} 40 | 41 | 42 | def set_module_training_state(module, training_state): 43 | for mod in module.modules(): 44 | mod.training = training_state[mod] 45 | 46 | 47 | def set_module_training_off(module): 48 | for mod in module.modules(): 49 | mod.training = False 50 | 51 | 52 | def get_param_grad_state(module): 53 | return {param: param.requires_grad for param in module.parameters()} 54 | 55 | 56 | def set_param_grad_state(module, grad_state): 57 | for param in module.parameters(): 58 | param.requires_grad = grad_state[param] 59 | 60 | 61 | def set_param_grad_off(module): 62 | for param in module.parameters(): 63 | param.requires_grad = False -------------------------------------------------------------------------------- /core/attacks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import replicate_input 5 | 6 | 7 | class Attack(object): 8 | """ 9 | Abstract base class for all attack classes. 10 | Arguments: 11 | predict (nn.Module): forward pass function. 12 | loss_fn (nn.Module): loss function. 13 | clip_min (float): minimum value per input dimension. 14 | clip_max (float): maximum value per input dimension. 
15 | """ 16 | 17 | def __init__(self, predict, loss_fn, clip_min, clip_max): 18 | self.predict = predict 19 | self.loss_fn = loss_fn 20 | self.clip_min = clip_min 21 | self.clip_max = clip_max 22 | 23 | def perturb(self, x, **kwargs): 24 | """ 25 | Virtual method for generating the adversarial examples. 26 | Arguments: 27 | x (torch.Tensor): the model's input tensor. 28 | **kwargs: optional parameters used by child classes. 29 | Returns: 30 | adversarial examples. 31 | """ 32 | error = "Sub-classes must implement perturb." 33 | raise NotImplementedError(error) 34 | 35 | def __call__(self, *args, **kwargs): 36 | return self.perturb(*args, **kwargs) 37 | 38 | 39 | class LabelMixin(object): 40 | def _get_predicted_label(self, x): 41 | """ 42 | Compute predicted labels given x. Used to prevent label leaking during adversarial training. 43 | Arguments: 44 | x (torch.Tensor): the model's input tensor. 45 | Returns: 46 | torch.Tensor containing predicted labels. 47 | """ 48 | with torch.no_grad(): 49 | outputs = self.predict(x) 50 | _, y = torch.max(outputs, dim=1) 51 | return y 52 | 53 | def _verify_and_process_inputs(self, x, y): 54 | if self.targeted: 55 | assert y is not None 56 | 57 | if not self.targeted: 58 | if y is None: 59 | y = self._get_predicted_label(x) 60 | 61 | x = replicate_input(x) 62 | y = replicate_input(y) 63 | return x, y -------------------------------------------------------------------------------- /core/attacks/apgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from autoattack.autopgd_base import APGDAttack 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class APGD(): 11 | """ 12 | APGD attack (from AutoAttack) (Croce et al, 2020). 13 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 14 | Arguments: 15 | predict (nn.Module): forward pass function. 16 | loss_fn (str): loss function - ce or dlr. 17 | n_restarts (int): number of random restarts. 18 | eps (float): maximum distortion. 19 | nb_iter (int): number of iterations. 20 | ord (int): (optional) the order of maximum distortion (inf or 2). 21 | """ 22 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, ord=np.inf, seed=1): 23 | assert loss_fn in ['ce', 'dlr'], 'Only loss_fn=ce or loss_fn=dlr are supported!' 24 | assert ord in [2, np.inf], 'Only ord=inf or ord=2 are supported!' 25 | 26 | norm = 'Linf' if ord == np.inf else 'L2' 27 | self.apgd = APGDAttack(predict, n_restarts=n_restarts, n_iter=nb_iter, verbose=False, eps=eps, norm=norm, 28 | eot_iter=1, rho=.75, seed=seed, device=device) 29 | self.apgd.loss = loss_fn 30 | 31 | def perturb(self, x, y): 32 | x_adv = self.apgd.perturb(x, y)[1] 33 | r_adv = x_adv - x 34 | return x_adv, r_adv 35 | 36 | 37 | class LinfAPGDAttack(APGD): 38 | """ 39 | APGD attack (from AutoAttack) with order=Linf. 40 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 41 | Arguments: 42 | predict (nn.Module): forward pass function. 43 | loss_fn (str): loss function - ce or dlr. 44 | n_restarts (int): number of random restarts. 45 | eps (float): maximum distortion. 46 | nb_iter (int): number of iterations. 
47 | """ 48 | 49 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 50 | ord = np.inf 51 | super(L2APGDAttack, self).__init__( 52 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) 53 | 54 | 55 | class L2APGDAttack(APGD): 56 | """ 57 | APGD attack (from AutoAttack) with order=L2. 58 | The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point. 59 | Arguments: 60 | predict (nn.Module): forward pass function. 61 | loss_fn (str): loss function - ce or dlr. 62 | n_restarts (int): number of random restarts. 63 | eps (float): maximum distortion. 64 | nb_iter (int): number of iterations. 65 | """ 66 | 67 | def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1): 68 | ord = 2 69 | super(L2APGDAttack, self).__init__( 70 | predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed) -------------------------------------------------------------------------------- /core/attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Attack 2 | 3 | from .apgd import LinfAPGDAttack 4 | from .apgd import L2APGDAttack 5 | 6 | from .fgsm import FGMAttack 7 | from .fgsm import FGSMAttack 8 | from .fgsm import L2FastGradientAttack 9 | from .fgsm import LinfFastGradientAttack 10 | 11 | from .pgd import PGDAttack 12 | from .pgd import L2PGDAttack 13 | from .pgd import LinfPGDAttack 14 | 15 | from .deepfool import DeepFoolAttack 16 | from .deepfool import LinfDeepFoolAttack 17 | from .deepfool import L2DeepFoolAttack 18 | 19 | from .utils import CWLoss 20 | import torch.nn as nn 21 | 22 | ATTACKS = ['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd'] 23 | 24 | 25 | def create_attack(model, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform', 26 | clip_min=0., clip_max=1. ,save_path=None): 27 | """ 28 | Initialize adversary. 29 | Arguments: 30 | model (nn.Module): forward pass function. 31 | criterion (nn.Module): loss function. 32 | attack_type (str): name of the attack. 33 | attack_eps (float): attack radius. 34 | attack_iter (int): number of attack iterations. 35 | attack_step (float): step size for the attack. 36 | rand_init_type (str): random initialization type for PGD (default: uniform). 37 | clip_min (float): mininum value per input dimension. 38 | clip_max (float): maximum value per input dimension. 
39 | Returns: 40 | Attack 41 | """ 42 | 43 | criterion = nn.CrossEntropyLoss() 44 | 45 | if attack_type == 'fgsm': 46 | attack = FGSMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 47 | elif attack_type == 'fgm': 48 | attack = FGMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max) 49 | elif attack_type == 'linf-pgd': 50 | attack = LinfPGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 51 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max, save_path=save_path) 52 | elif attack_type == 'l2-pgd': 53 | attack = L2PGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step, 54 | rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max) 55 | elif attack_type == 'linf-df': 56 | attack = LinfDeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 57 | clip_max=clip_max) 58 | elif attack_type == 'l2-df': 59 | attack = L2DeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min, 60 | clip_max=clip_max) 61 | elif attack_type == 'linf-apgd': 62 | attack = LinfAPGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 63 | elif attack_type == 'l2-apgd': 64 | attack = L2APGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter) 65 | elif attack_type == 'none': 66 | attack = None 67 | else: 68 | raise NotImplementedError('{} is not yet implemented!'.format(attack_type)) 69 | return attack 70 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from core.models import create_model 8 | from core.testfn import test 9 | from core.parse import parser_test 10 | from core.utils import get_logger, get_logger_name 11 | from core.data import corruption_19, corruption_15, load_corr_dataloader, load_dataloader 12 | import torchvision.transforms as T 13 | 14 | args_test = parser_test() 15 | 16 | with open(os.path.join(args_test.ckpt_path,'train/args.txt'), 'r') as f: 17 | old = json.load(f) 18 | args_test.__dict__ = dict(vars(args_test), **old) 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | model = create_model(args_test.backbone, args_test.protocol, num_classes=10 if args_test.data == 'cifar10' else 100)  # num_classes assumed from the dataset; only cifar10/cifar100 are supported below 22 | model = model.to(device) 23 | checkpoint = torch.load(os.path.join(args_test.ckpt_path,'train',args_test.load_ckpt+'.pt')) 24 | model.load_state_dict(checkpoint['model_state_dict']) 25 | model.eval() 26 | del checkpoint 27 | 28 | corr = eval('_'.join(['corruption',args_test.type[1:]])) 29 | if args_test.data == 'cifar10': 30 | baseline_accs = (0.9535, [0.832,0.89348,0.80054,0.55018,0.8274,0.78482,0.78786,0.48148,0.60628,0.53044,0.75678,0.94036,0.76344,0.78972,0.85054],[0.7814,0.7552,0.6652,0.4942,0.5429,0.661,0.6369,0.2826,0.3597,0.2354,0.4816,0.9156,0.3,0.712,0.7554]) 31 | elif args_test.data == 'cifar100': 32 | baseline_accs = (0.7771,[0.5540,0.6465,0.5049,0.2418,0.6003,0.5545,0.5360,0.2302,0.3144,0.2548,0.5245,0.7372,0.5538,0.5231,0.6150],[0.4611,0.4040,0.3477,0.2207,0.3371,0.4430,0.3998,0.1111,0.1277,0.0662,0.2394,0.6656,0.1988,0.4420,0.5123]) 33 | else: 34 | raise ValueError('unsupported dataset: {}'.format(args_test.data)) 35 | 36 | logger = get_logger(get_logger_name(args_test.ckpt_path, args_test.load_ckpt, args_test.main_task)) 37 | #augmentor = T.RandomRotation(360) 38 | #augmentor = 
T.GaussianBlur(5,5) 39 | #augmentor = T.ElasticTransform(150.0) 40 | #augmentor = T.RandomInvert() 41 | augmentor = T.ColorJitter(5,0,0,0) 42 | 43 | _,_,dataloader_nat = load_dataloader(args_test) 44 | metrics = test(dataloader_nat, model, device=device, augmentor=augmentor) 45 | # logger.info("nat-"+str(metrics["eval_acc"])) 46 | # logger.info("aug-"+str(metrics["eval_acc_aug"])) 47 | # logger.info("dis-"+str(metrics["distance"])) 48 | # logger.info("dismin-"+str(metrics["distance_min"])) 49 | # logger.info("dismax-"+str(metrics["distance_max"])) 50 | # logger.info("disstd-"+str(metrics["distance_std"])) 51 | # exit(0) 52 | if 'pacs' in args_test.data: 53 | res = np.zeros((5, len(corr)))  # one row per severity level (1-5) 54 | for c in range(len(corr)): 55 | for s in range(1, 6): 56 | dataloader = load_corr_dataloader(args_test.data, args_test.data_dir, args_test.batch_size, cname=corr[c], dnum='all', severity=s) 57 | metrics = test(dataloader, model, device=device, augmentor=augmentor) 58 | res[s-1, c] = metrics["eval_acc"] 59 | log = "-".join([corr[c], str(s), str(res[s-1, c])]) 60 | logger.info(log) 61 | frame = pd.DataFrame({i+1: res[i, :] for i in range(0, 5)}, index=corr) 62 | frame.loc['average'] = {i+1: np.mean(res, axis=1)[i] for i in range(0, 5)} 63 | frame['avg'] = frame[list(range(1, 6))].mean(axis=1) 64 | logger.info(frame) 65 | else: 66 | res = np.zeros((5, len(corr))) 67 | for c in range(len(corr)): 68 | for s in range(1, 6): 69 | dataloader = load_corr_dataloader(args_test.data, args_test.data_dir, args_test.batch_size, cname=corr[c], dnum='all', severity=s) 70 | metrics = test(dataloader, model, device=device, augmentor=augmentor) 71 | res[s-1, c] = metrics["eval_acc"] 72 | log = "-".join([corr[c], str(s), str(res[s-1, c])]) 73 | logger.info(log) 74 | frame = pd.DataFrame({i+1: res[i, :] for i in range(0, 5)}, index=corr) 75 | frame.loc['average'] = {i+1: np.mean(res, axis=1)[i] for i in range(0, 5)} 76 | frame['avg'] = frame[list(range(1, 6))].mean(axis=1) 77 | logger.info(frame) -------------------------------------------------------------------------------- /core/attacks/fgsm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import Attack, LabelMixin 5 | from .utils import batch_multiply 6 | from .utils import clamp, normalize_by_pnorm 7 | 8 | 9 | class FGSMAttack(Attack, LabelMixin): 10 | """ 11 | One step fast gradient sign method (Goodfellow et al, 2014). 12 | Arguments: 13 | predict (nn.Module): forward pass function. 14 | loss_fn (nn.Module): loss function. 15 | eps (float): attack step size. 16 | clip_min (float): minimum value per input dimension. 17 | clip_max (float): maximum value per input dimension. 18 | targeted (bool): indicate if this is a targeted attack. 19 | """ 20 | 21 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 22 | super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max) 23 | 24 | self.eps = eps 25 | self.targeted = targeted 26 | if self.loss_fn is None: 27 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 28 | 29 | def perturb(self, x, y=None): 30 | """ 31 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 32 | Arguments: 33 | x (torch.Tensor): input tensor. 34 | y (torch.Tensor): label tensor. 35 | - if None and self.targeted=False, compute y as predicted labels. 36 | - if self.targeted=True, then y must be the targeted labels. 37 | Returns: 38 | torch.Tensor containing perturbed inputs. 
39 | torch.Tensor containing the perturbation. 40 | """ 41 | 42 | x, y = self._verify_and_process_inputs(x, y) 43 | 44 | xadv = x.requires_grad_() 45 | outputs = self.predict(xadv) 46 | 47 | loss = self.loss_fn(outputs, y) 48 | if self.targeted: 49 | loss = -loss 50 | loss.backward() 51 | grad_sign = xadv.grad.detach().sign() 52 | 53 | xadv = xadv + batch_multiply(self.eps, grad_sign) 54 | xadv = clamp(xadv, self.clip_min, self.clip_max) 55 | radv = xadv - x 56 | return xadv.detach(), radv.detach() 57 | 58 | 59 | LinfFastGradientAttack = FGSMAttack 60 | 61 | 62 | class FGMAttack(Attack, LabelMixin): 63 | """ 64 | One step fast gradient method. Perturbs the input with the gradient (not gradient sign) of the loss w.r.t. the input. 65 | Arguments: 66 | predict (nn.Module): forward pass function. 67 | loss_fn (nn.Module): loss function. 68 | eps (float): attack step size. 69 | clip_min (float): minimum value per input dimension. 70 | clip_max (float): maximum value per input dimension. 71 | targeted (bool): indicate if this is a targeted attack. 72 | """ 73 | 74 | def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False): 75 | super(FGMAttack, self).__init__( 76 | predict, loss_fn, clip_min, clip_max) 77 | 78 | self.eps = eps 79 | self.targeted = targeted 80 | if self.loss_fn is None: 81 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 82 | 83 | def perturb(self, x, y=None): 84 | """ 85 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 86 | Arguments: 87 | x (torch.Tensor): input tensor. 88 | y (torch.Tensor): label tensor. 89 | - if None and self.targeted=False, compute y as predicted labels. 90 | - if self.targeted=True, then y must be the targeted labels. 91 | Returns: 92 | torch.Tensor containing perturbed inputs. 93 | torch.Tensor containing the perturbation. 94 | """ 95 | 96 | x, y = self._verify_and_process_inputs(x, y) 97 | xadv = x.requires_grad_() 98 | outputs = self.predict(xadv) 99 | 100 | loss = self.loss_fn(outputs, y) 101 | if self.targeted: 102 | loss = -loss 103 | loss.backward() 104 | grad = normalize_by_pnorm(xadv.grad) 105 | xadv = xadv + batch_multiply(self.eps, grad) 106 | xadv = clamp(xadv, self.clip_min, self.clip_max) 107 | radv = xadv - x 108 | 109 | return xadv.detach(), radv.detach() 110 | 111 | 112 | L2FastGradientAttack = FGMAttack -------------------------------------------------------------------------------- /core/parse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parser_train(): 4 | parser = argparse.ArgumentParser(description='Args for Training.') 5 | 6 | # basic 7 | parser.add_argument('--desc', type=str, default='none', help='Description of experiment. 
It will be used to name directories.') 8 | parser.add_argument('--save-dir', type=str, default='./save') 9 | parser.add_argument('--seed', type=int, default=3407) 10 | parser.add_argument('--pretrained-file', type=str, default=None, help='Pretrained weights file name.') 11 | parser.add_argument('--resume-file', type=str, default=None, help='Resumed file name.') 12 | parser.add_argument('--save-freq', type=int, default=50, help='Save per epochs.') 13 | 14 | # whole training 15 | parser.add_argument('--epoch', type=int, default=200, help='Number of training epochs.') 16 | parser.add_argument('--batch-size', type=int, default=128, help='Batch size for training.') 17 | parser.add_argument('--batch-size-validation', type=int, default=128, help='Batch size for testing.') 18 | parser.add_argument('--ensemble-iter-eval', type=int, default=10, help='Number of ensemble while evaluating, helping choose best epoch') 19 | 20 | # data 21 | parser.add_argument('--data-dir', type=str, default='./datasets') 22 | parser.add_argument('--data', type=str, default='cifar10', choices=['mnist','cifar10','cifar100','tin200','pacs-art','pacs-cartoon','pacs-photo','pacs-sketch'], help='Data to use.') 23 | parser.add_argument('--data-diff', type=str, default=None, choices=['mnist','cifar10','cifar100','tin200','pacs-art','pacs-cartoon','pacs-photo','pacs-sketch'], help='O.O.D. Data used to guide diffusion.') 24 | #parser.add_argument('--data-eval', type=str, default=None, choices=['mnist','cifar10','cifar100','tin200','pacs-art','pacs-cartoon','pacs-photo','pacs-sketch'], help='O.O.D. Data used to guide diffusion.') 25 | parser.add_argument('--norm', action='store_true') 26 | parser.add_argument('--npc-train', default='all', help='Number of training samples per class, int or all.') 27 | parser.add_argument('--num-workers', type=int, default=2) 28 | 29 | # augmentation 30 | parser.add_argument('--aug-train', type=str, default="none", help='Data augmentation for training, replacing clean') 31 | parser.add_argument('--aug-train-diff', type=str, default="none", help='Data augmentation for training') 32 | 33 | # attack 34 | parser.add_argument('--atk-train', type=str, choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'none'], 35 | default=None, help='Type of attack for training.') 36 | parser.add_argument('--atk-eval', type=str, choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'none'], 37 | default=None, help='Type of attack for evaluating.') 38 | parser.add_argument('--attack-eps', type=float, default=8/255, help='Epsilon for the attack.') 39 | parser.add_argument('--attack-step', type=float, default=2/255, help='Step size for PGD attack.') 40 | parser.add_argument('--attack-iter', type=int, default=10, help='Max. 
number of iterations (if any) for the attack.') 41 | 42 | # model 43 | parser.add_argument('--backbone', type=str, choices=['resnet-18', 'resnet-34','wideresnet-16-4', 'preresnet-18'], default="resnet-18") 44 | parser.add_argument('--protocol', type=str, default="pdeadd") 45 | parser.add_argument('--use-gmm', action='store_true') 46 | parser.add_argument('--ls', type=float, default=0.1) 47 | 48 | 49 | 50 | # C optimizer 51 | parser.add_argument('--optimizerC', type=str, default='sgd', help='Choice for optimizerC.') 52 | parser.add_argument('--lrC', type=float, default=0.01, help='Learning rate for optimizer.') 53 | parser.add_argument('--weight-decay', type=float, default=5e-4, help='Optimizer (SGD) weight decay.') 54 | parser.add_argument('--scheduler', choices=['cosanlr','mslr',"none"], default='cosanlr', help='Type of scheduler.') 55 | parser.add_argument('--warm', type=int, default=0) 56 | 57 | # Diff optimizer 58 | parser.add_argument('--optimizerDiff', type=str, default='adam', help='Choice for optimizerD.') 59 | parser.add_argument('--lrDiff', type=float, default=0.1, help='Learning rate for optimizer.') 60 | 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | def parser_test(): 66 | parser = argparse.ArgumentParser(description='Args for Testing.') 67 | parser.add_argument('--ckpt_path', type=str) 68 | parser.add_argument('--main_task', type=str, choices=["ood","adv"]) 69 | parser.add_argument('--severity', type=int, choices=[0,1,2,3,4,5], help='Data augmentation severity for testing') 70 | parser.add_argument('--type', type=str, choices=['c15','c19'], help='Data augmentation type for testing') 71 | parser.add_argument('--threat', type=str, choices=['linf','l2']) 72 | parser.add_argument('--load_ckpt', type=str) 73 | args = parser.parse_args() 74 | return args -------------------------------------------------------------------------------- /core/models/standard_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | """ 9 | Implements a basic block module for Resnets. 10 | Arguments: 11 | in_planes (int): number of input planes. 12 | out_planes (int): number of output filters. 13 | stride (int): stride of convolution. 14 | """ 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion * planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion * planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | """ 41 | Implements a basic block module with bottleneck for Resnets. 42 | Arguments: 43 | in_planes (int): number of input planes. 44 | out_planes (int): number of output filters. 45 | stride (int): stride of convolution. 
46 | """ 47 | expansion = 4 48 | 49 | def __init__(self, in_planes, planes, stride=1): 50 | super(Bottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion * planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 62 | nn.BatchNorm2d(self.expansion * planes) 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = F.relu(self.bn2(self.conv2(out))) 68 | out = self.bn3(self.conv3(out)) 69 | out += self.shortcut(x) 70 | out = F.relu(out) 71 | return out 72 | 73 | 74 | class ResNet(nn.Module): 75 | """ 76 | ResNet model 77 | Arguments: 78 | block (BasicBlock or Bottleneck): type of basic block to be used. 79 | num_blocks (list): number of blocks in each sub-module. 80 | num_classes (int): number of output classes. 81 | device (torch.device or str): device to work on. 82 | """ 83 | def __init__(self, block, num_blocks, num_classes=10, device='cpu'): 84 | super(ResNet, self).__init__() 85 | self.in_planes = 64 86 | 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(64) 89 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 90 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 91 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 92 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 93 | self.linear = nn.Linear(512 * block.expansion, num_classes) 94 | 95 | def _make_layer(self, block, planes, num_blocks, stride): 96 | strides = [stride] + [1] * (num_blocks - 1) 97 | layers = [] 98 | for stride in strides: 99 | layers.append(block(self.in_planes, planes, stride)) 100 | self.in_planes = planes * block.expansion 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x, use_diffusion = False): 104 | out = F.relu(self.bn1(self.conv1(x))) 105 | out = self.layer1(out) 106 | out = self.layer2(out) 107 | out = self.layer3(out) 108 | out = self.layer4(out) 109 | 110 | out = F.avg_pool2d(out, 4) 111 | out = out.view(out.size(0), -1) 112 | 113 | out = self.linear(out) 114 | 115 | return out 116 | 117 | 118 | 119 | def standard_resnet(name, num_classes=10, pretrained=False, device='cpu'): 120 | """ 121 | Returns suitable Resnet model from its name. 122 | Arguments: 123 | name (str): name of resnet architecture. 124 | num_classes (int): number of target classes. 125 | pretrained (bool): whether to use a pretrained model. 126 | device (str or torch.device): device to work on. 127 | Returns: 128 | torch.nn.Module. 
129 | """ 130 | if name == 'resnet-18': 131 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, device=device) 132 | elif name == 'resnet-34': 133 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, device=device) 134 | elif name == 'resnet-50': 135 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, device=device) 136 | elif name == 'resnet-101': 137 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, device=device) 138 | 139 | raise ValueError('Only resnet-18, resnet-34, resnet-50 and resnet-101 are supported!') 140 | return 141 | -------------------------------------------------------------------------------- /core/trainfn.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import defaultdict 3 | 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from core.context import ctx_noparamgrad_and_eval 9 | from core.utils import vis 10 | 11 | nll_loss = nn.GaussianNLLLoss() 12 | 13 | #augmentor_ood = DataAugmentor('rotation-20') 14 | 15 | # def get_ratio(mu_aug_ood, mu, sigma, f): 16 | # ratio1 = ((mu_aug_ood-mu).abs()/(sigma+1e-8)).mean().item() 17 | # distance = (mu_aug_ood-mu).abs().mean().item() 18 | # sigma = sigma.mean().item() 19 | # ratio2 = distance/sigma 20 | # f.write("{},{},{},{}\n".format(distance,sigma,ratio2,ratio1)) 21 | 22 | def train_pdeadd(dataloader_train, dataloader_train_diff, model, 23 | optimizerDiff, optimizerC, label_smooth=0.1, attacker=None, 24 | device=None, use_gmm=True, save_path=None): 25 | 26 | print('use_gmm',use_gmm) 27 | 28 | metrics = pd.DataFrame() 29 | batch_index = 0 30 | model.train() 31 | 32 | 33 | for (x, y), (x_ood, y_ood) in tqdm(zip(dataloader_train,dataloader_train_diff)): 34 | 35 | batch_metric = defaultdict(float) 36 | x, y = x.to(device), y.to(device) 37 | x_ood, y_ood = x_ood.to(device), y_ood.to(device) 38 | 39 | if batch_index == 1: 40 | vis(x, x_ood, save_path=save_path) 41 | 42 | if (y - y_ood).sum().cpu().detach().item(): 43 | raise 44 | 45 | _ = model(x, use_diffusion = True) 46 | mus = model.mus 47 | sigmas = model.sigmas 48 | 49 | _ = model(x_ood, use_diffusion = False) 50 | mus_aug = model.mus 51 | 52 | # x_aug_ood = augmentor_ood.apply(x, visualize=False) 53 | # _ = model(x_aug_ood, use_diffusion = False) 54 | # mus_aug_ood = model.mus 55 | 56 | # f1 = open(os.path.join(save_path,'Q2_1.csv'), 'a') 57 | # f2 = open(os.path.join(save_path,'Q2_2.csv'), 'a') 58 | # f3 = open(os.path.join(save_path,'Q2_3.csv'), 'a') 59 | # f4 = open(os.path.join(save_path,'Q2_4.csv'), 'a') 60 | 61 | # get_ratio(mus_aug_ood[0], mus[0], sigmas[0], f1) 62 | # get_ratio(mus_aug_ood[1], mus[1], sigmas[1], f2) 63 | # get_ratio(mus_aug_ood[2], mus[2], sigmas[2], f3) 64 | # get_ratio(mus_aug_ood[3], mus[3], sigmas[3], f4) 65 | 66 | # f1.close() 67 | # f2.close() 68 | # f3.close() 69 | # f4.close() 70 | 71 | lossDiff = 0 72 | for mu_aug, mu, sigma in zip(mus_aug, mus, sigmas): 73 | lossDiff += nll_loss(mu_aug.view(x.shape[0],-1), mu.view(x.shape[0],-1), sigma.view(x.shape[0],-1)) 74 | lossDiff = lossDiff/len(mus) 75 | 76 | optimizerDiff.zero_grad() 77 | lossDiff.backward() 78 | optimizerDiff.step() 79 | 80 | # out_aug = model(x_aug, use_diffusion = True) 81 | # out = model(x, use_diffusion = True) 82 | # optimizerC.zero_grad() 83 | # lossC = F.nll_loss(out_aug, y) + F.nll_loss(out, y) 84 | # lossC.backward() 85 | # optimizerC.step() 86 | 87 | if use_gmm: 88 | x_all = 
torch.cat((x, x_ood), dim=0) 89 | y_all = torch.cat((y, y), dim=0) 90 | else: 91 | x_all = x 92 | y_all = y 93 | out = model(x_all, use_diffusion = True) 94 | optimizerC.zero_grad() 95 | lossC = F.cross_entropy(out, y_all, label_smoothing=label_smooth) 96 | lossC.backward() 97 | optimizerC.step() 98 | 99 | batch_metric["train_loss_nll"] = lossDiff.data.item() 100 | batch_metric["train_loss_cla"] = lossC.data.item() 101 | batch_metric["train_acc"] = (torch.softmax(out.data, dim=1).argmax(dim=1) == y_all.data).sum().data.item() 102 | batch_metric["scales_l1"] = model.scales[0] 103 | batch_metric["scales_l2"] = model.scales[1] 104 | batch_metric["scales_l3"] = model.scales[2] 105 | batch_metric["scales_l4"] = model.scales[3] 106 | metrics = pd.concat([metrics, pd.DataFrame(batch_metric, index=[0])], ignore_index=True) 107 | batch_index += 1 108 | 109 | if use_gmm: 110 | length = 2*len(dataloader_train.dataset) 111 | else: 112 | length = len(dataloader_train.dataset) 113 | 114 | return dict(metrics.agg({ 115 | "train_loss_nll":"mean", 116 | "train_loss_cla":"mean", 117 | "train_acc":lambda x:100*sum(x)/(length), 118 | "scales_l1":"mean","scales_l2":"mean","scales_l3":"mean","scales_l4":"mean"})) 119 | 120 | 121 | def train_standard(dataloader_train, model, optimizer, 122 | augmentor=None, attacker=None, device=None, visualize=False, epoch=None): 123 | 124 | metrics = pd.DataFrame() 125 | batch_index = 0 126 | model.train() 127 | 128 | for x, y in tqdm(dataloader_train): 129 | 130 | batch_metric = defaultdict(float) 131 | x, y = x.to(device), y.to(device) 132 | 133 | # Update Classifier network 134 | if attacker: 135 | with ctx_noparamgrad_and_eval(model): 136 | x, _ = attacker.perturb(x, y, visualize=visualize if batch_index == 0 else False) 137 | 138 | out = model(x) 139 | optimizer.zero_grad() 140 | loss = F.cross_entropy(out, y) 141 | loss.backward() 142 | optimizer.step() 143 | 144 | batch_metric["train_loss_cla"] = loss.data.item() 145 | batch_metric["train_loss_nll"] = 0 146 | batch_metric["train_acc"] = (torch.softmax(out.data, dim=1).argmax(dim=1) == y.data).sum().data.item() 147 | batch_metric["scales_l1"] = 0 148 | batch_metric["scales_l2"] = 0 149 | batch_metric["scales_l3"] = 0 150 | batch_metric["scales_l4"] = 0 151 | metrics = pd.concat([metrics, pd.DataFrame(batch_metric, index=[0])], ignore_index=True) 152 | batch_index += 1 153 | 154 | return dict(metrics.agg({ 155 | "train_loss_nll":"mean", 156 | "train_loss_cla":"mean", 157 | "train_acc":lambda x:100*sum(x)/(len(dataloader_train.dataset)), 158 | "scales_l1":"mean","scales_l2":"mean","scales_l3":"mean","scales_l4":"mean"})) 159 | 160 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy as np 4 | import torch 5 | import datetime 6 | import os 7 | import re 8 | import time 9 | import pandas as pd 10 | from torchvision.utils import save_image 11 | 12 | def get_logger(logpath, displaying=True, saving=True): 13 | logger = logging.getLogger() 14 | level = logging.INFO 15 | logger.setLevel(level) 16 | if saving: 17 | info_file_handler = logging.FileHandler(logpath, mode="a") 18 | info_file_handler.setLevel(level) 19 | logger.addHandler(info_file_handler) 20 | if displaying: 21 | console_handler = logging.StreamHandler() 22 | console_handler.setLevel(level) 23 | logger.addHandler(console_handler) 24 | # logger.info(filepath) 25 | return 
logger 26 | 27 | def get_logger_name(ckpt_path, load_ckpt, task, severity=None, threat=None): 28 | 29 | save_dir = os.path.join(ckpt_path, 'test') 30 | os.makedirs(save_dir, exist_ok=True) 31 | s = re.search(r"e\d+", load_ckpt) 32 | 33 | if 'wo' in load_ckpt: 34 | postfix = "-wo" 35 | elif 'en' in load_ckpt: 36 | postfix = "-en" 37 | elif s: 38 | postfix = "-"+str(s.group()) 39 | else: 40 | postfix = "" 41 | 42 | if task =='ood': 43 | name = os.path.join(save_dir, 'test-{}{}.log'.format(task, postfix)) 44 | elif task == 'adv': 45 | name = os.path.join(save_dir, 'test-{}-{}{}.log'.format(task, threat, postfix)) 46 | else: 47 | raise 48 | return name 49 | 50 | def get_desc(args): 51 | if args.protocol=='pdeadd': 52 | desc = "{}_{}-gmm{}_{}_(C{}lr{}{}-D{}lr{})_e{}_b{}_aug-{}_augdiff-{}_atk-{}".format( 53 | args.backbone, args.protocol, args.use_gmm, 54 | args.desc, args.optimizerC, args.lrC, args.scheduler, 55 | args.optimizerDiff, args.lrDiff, 56 | args.epoch, args.batch_size, args.aug_train, args.aug_train_diff, args.atk_train) 57 | elif args.protocol == 'standard' or 'fixdiff' in args.protocol : 58 | desc = "{}_{}_{}_C{}lr{}{}_e{}_b{}_aug-{}_atk-{}".format( 59 | args.backbone, args.protocol, 60 | args.desc, args.optimizerC, args.lrC, args.scheduler, 61 | args.epoch, args.batch_size, args.aug_train, args.atk_train) 62 | else: 63 | raise 64 | return desc 65 | 66 | def set_seed(seed=1): 67 | random.seed(seed) 68 | os.environ['PYTHONHASHSEED'] =str(seed) 69 | np.random.seed(seed) 70 | torch.manual_seed(seed) 71 | torch.cuda.manual_seed(seed) 72 | torch.cuda.manual_seed_all(seed) 73 | torch.backends.cudnn.deterministic =True 74 | torch.backends.cudnn.benchmark = False 75 | 76 | def format_time(elapsed): 77 | elapsed_rounded = int(round((elapsed))) 78 | return str(datetime.timedelta(seconds=elapsed_rounded)) 79 | 80 | class BestSaver(): 81 | def __init__(self): 82 | self.best_acc=0 83 | self.best_epoch=0 84 | def apply(self, acc, epoch, model, optimizerC, scheduler, optimizerDiff=None, save_path=None): 85 | if acc > self.best_acc: 86 | self.best_acc = acc 87 | self.best_epoch = epoch 88 | if epoch > 100: 89 | self.save_model(model, optimizerC, scheduler=scheduler, optimizerDiff=optimizerDiff, save_path=save_path) 90 | 91 | def save_model(self, model, optimizerC, scheduler=None, optimizerDiff=None, save_path=None): 92 | if scheduler: 93 | scheduler_state_dict = scheduler.state_dict() 94 | else: 95 | scheduler_state_dict = None 96 | 97 | if optimizerDiff: 98 | torch.save( 99 | {'model_state_dict': model.state_dict(), 100 | 'optimizerC_state_dict': optimizerC.state_dict(), 101 | 'optimizerDiff_state_dict': optimizerDiff.state_dict(), 102 | 'scheduler':scheduler_state_dict 103 | }, save_path) 104 | else: 105 | torch.save( 106 | {'model_state_dict': model.state_dict(), 107 | 'optimizerC_state_dict': optimizerC.state_dict(), 108 | 'scheduler':scheduler_state_dict 109 | }, save_path) 110 | 111 | def verbose_and_save(logger, epoch, start_epoch, eval_per_epoch, start, train_metric, eval_metric, writer): 112 | # save logs 113 | logger.info('\n[Epoch {}] - Time taken: {}'.format(epoch, format_time(time.time()-start))) 114 | logger.info('Train\t Acc: {:.2f}%, NLLLoss: {:.2f}, ClassLoss: {:.2f}'.format( 115 | train_metric['train_acc'],train_metric['train_loss_nll'],train_metric['train_loss_cla'])) 116 | logger.info('Train\t Scale1: {:.4f}, Scale2: {:.4f}, Scale3: {:.4f}, Scale4: {:.4f}'.format( 117 | train_metric['scales_l1'],train_metric['scales_l2'],train_metric['scales_l3'],train_metric['scales_l4'])) 118 
| if (epoch == start_epoch) or (epoch % eval_per_epoch == 0): 119 | logger.info('Eval Natural Samples\t Acc: {:.2f}%, Loss: {:.2f}'.format( 120 | eval_metric["nat"]['eval_acc'],eval_metric["nat"]['eval_loss'])) 121 | logger.info('Eval O.O.D. Samples\t Acc: {:.2f}%'.format(eval_metric["ood"]['eval_acc'])) 122 | logger.info("\n") 123 | # logger.info('Eval O.O.D. Samples\nwodiff\t Acc: {:.2f}%, Loss: {:.2f}\nendiff\t Acc: {:.2f}%, Loss: {:.2f}'.format( 124 | # eval_metric["ood"]['eval_acc'],eval_metric["ood"]['eval_loss'], 125 | # eval_metric["ood_diff"]['eval_acc'],eval_metric["ood_diff"]['eval_loss'])) 126 | 127 | # save tensorboard 128 | writer.add_scalar('train/lossDiff', train_metric['train_loss_nll'], epoch) 129 | writer.add_scalar('train/lossC', train_metric['train_loss_cla'], epoch) 130 | writer.add_scalar('train/acc', train_metric['train_acc'], epoch) 131 | 132 | writer.add_scalar('scales/layer1', train_metric['scales_l1'], epoch) 133 | writer.add_scalar('scales/layer2', train_metric['scales_l2'], epoch) 134 | writer.add_scalar('scales/layer3', train_metric['scales_l3'], epoch) 135 | writer.add_scalar('scales/layer4', train_metric['scales_l4'], epoch) 136 | 137 | writer.add_scalar('evalnat/loss', eval_metric["nat"]['eval_loss'], epoch) 138 | writer.add_scalar('evalnat/acc', eval_metric["nat"]['eval_acc'], epoch) 139 | 140 | # writer_eval_diff.add_scalar('evalood/loss', eval_metric["ood"]['eval_loss'], epoch) 141 | writer.add_scalar('evalood/acc', eval_metric["ood"]['eval_acc'], epoch) 142 | # writer_eval_diff.add_scalar('evalood/loss', eval_metric["ood_diff"]['eval_loss'], epoch) 143 | # writer_eval_diff.add_scalar('evalood/acc', eval_metric["ood_diff"]['eval_acc'], epoch) 144 | 145 | 146 | def eval_epoch(epoch): 147 | if epoch < 5: 148 | eval_per_epoch = 1 149 | elif epoch < 100: 150 | eval_per_epoch = 20 151 | elif epoch < 160: 152 | eval_per_epoch = 10 153 | else: 154 | eval_per_epoch = 1 155 | return eval_per_epoch 156 | 157 | def vis(x_ori, x_aug, save_path): 158 | x = torch.cat([x_ori[:8],x_aug[:8]]) 159 | save_image(x.cpu(), os.path.join(save_path,'samples.png'), nrow=8,padding=0, value_range=(0, 1), pad_value=0) 160 | -------------------------------------------------------------------------------- /core/attacks/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | from torch.distributions import laplace 9 | from torch.distributions import uniform 10 | from torch.nn.modules.loss import _Loss 11 | 12 | 13 | def replicate_input(x): 14 | """ 15 | Clone the input tensor x. 16 | """ 17 | return x.detach().clone() 18 | 19 | 20 | def replicate_input_withgrad(x): 21 | """ 22 | Clone the input tensor x and set requires_grad=True. 23 | """ 24 | return x.detach().clone().requires_grad_() 25 | 26 | 27 | def calc_l2distsq(x, y): 28 | """ 29 | Calculate the squared L2 distance between tensors x and y. 30 | """ 31 | d = (x - y)**2 32 | return d.view(d.shape[0], -1).sum(dim=1) 33 | 34 | 35 | def clamp(input, min=None, max=None): 36 | """ 37 | Clamp a tensor by its minimum and maximum values. 
38 | """ 39 | ndim = input.ndimension() 40 | if min is None: 41 | pass 42 | elif isinstance(min, (float, int)): 43 | input = torch.clamp(input, min=min) 44 | elif isinstance(min, torch.Tensor): 45 | if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]: 46 | input = torch.max(input, min.view(1, *min.shape)) 47 | else: 48 | assert min.shape == input.shape 49 | input = torch.max(input, min) 50 | else: 51 | raise ValueError("min can only be None | float | torch.Tensor") 52 | 53 | if max is None: 54 | pass 55 | elif isinstance(max, (float, int)): 56 | input = torch.clamp(input, max=max) 57 | elif isinstance(max, torch.Tensor): 58 | if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]: 59 | input = torch.min(input, max.view(1, *max.shape)) 60 | else: 61 | assert max.shape == input.shape 62 | input = torch.min(input, max) 63 | else: 64 | raise ValueError("max can only be None | float | torch.Tensor") 65 | return input 66 | 67 | 68 | def _batch_multiply_tensor_by_vector(vector, batch_tensor): 69 | """Equivalent to the following. 70 | for ii in range(len(vector)): 71 | batch_tensor.data[ii] *= vector[ii] 72 | return batch_tensor 73 | """ 74 | return ( 75 | batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous() 76 | 77 | 78 | def _batch_clamp_tensor_by_vector(vector, batch_tensor): 79 | """Equivalent to the following. 80 | for ii in range(len(vector)): 81 | batch_tensor[ii] = clamp( 82 | batch_tensor[ii], -vector[ii], vector[ii]) 83 | """ 84 | return torch.min( 85 | torch.max(batch_tensor.transpose(0, -1), -vector), vector 86 | ).transpose(0, -1).contiguous() 87 | 88 | 89 | def batch_multiply(float_or_vector, tensor): 90 | """ 91 | Multpliy a batch of tensors with a float or vector. 92 | """ 93 | if isinstance(float_or_vector, torch.Tensor): 94 | assert len(float_or_vector) == len(tensor) 95 | tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor) 96 | elif isinstance(float_or_vector, float): 97 | tensor *= float_or_vector 98 | else: 99 | raise TypeError("Value has to be float or torch.Tensor") 100 | return tensor 101 | 102 | 103 | def batch_clamp(float_or_vector, tensor): 104 | """ 105 | Clamp a batch of tensors. 106 | """ 107 | if isinstance(float_or_vector, torch.Tensor): 108 | assert len(float_or_vector) == len(tensor) 109 | tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor) 110 | return tensor 111 | elif isinstance(float_or_vector, float): 112 | tensor = clamp(tensor, -float_or_vector, float_or_vector) 113 | else: 114 | raise TypeError("Value has to be float or torch.Tensor") 115 | return tensor 116 | 117 | 118 | def _get_norm_batch(x, p): 119 | """ 120 | Returns the Lp norm of batch x. 121 | """ 122 | batch_size = x.size(0) 123 | return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p) 124 | 125 | 126 | def _thresh_by_magnitude(theta, x): 127 | """ 128 | Threshold by magnitude. 129 | """ 130 | return torch.relu(torch.abs(x) - theta) * x.sign() 131 | 132 | 133 | def clamp_by_pnorm(x, p, r): 134 | """ 135 | Clamp tensor by its norm. 136 | """ 137 | assert isinstance(p, float) or isinstance(p, int) 138 | norm = _get_norm_batch(x, p) 139 | if isinstance(r, torch.Tensor): 140 | assert norm.size() == r.size() 141 | else: 142 | assert isinstance(r, float) 143 | factor = torch.min(r / norm, torch.ones_like(norm)) 144 | return batch_multiply(factor, x) 145 | 146 | 147 | def is_float_or_torch_tensor(x): 148 | """ 149 | Return whether input x is a float or a torch.Tensor. 
150 | """ 151 | return isinstance(x, torch.Tensor) or isinstance(x, float) 152 | 153 | 154 | def normalize_by_pnorm(x, p=2, small_constant=1e-6): 155 | """ 156 | Normalize gradients for gradient (not gradient sign) attacks. 157 | Arguments: 158 | x (torch.Tensor): tensor containing the gradients on the input. 159 | p (int): (optional) order of the norm for the normalization (1 or 2). 160 | small_constant (float): (optional) to avoid dividing by zero. 161 | Returns: 162 | normalized gradients. 163 | """ 164 | assert isinstance(p, float) or isinstance(p, int) 165 | norm = _get_norm_batch(x, p) 166 | norm = torch.max(norm, torch.ones_like(norm) * small_constant) 167 | return batch_multiply(1. / norm, x) 168 | 169 | 170 | def rand_init_delta(delta, x, ord, eps, clip_min, clip_max): 171 | """ 172 | Randomly initialize the perturbation. 173 | """ 174 | if isinstance(eps, torch.Tensor): 175 | assert len(eps) == len(delta) 176 | 177 | if ord == np.inf: 178 | delta.data.uniform_(-1, 1) 179 | delta.data = batch_multiply(eps, delta.data) 180 | elif ord == 2: 181 | delta.data.uniform_(clip_min, clip_max) 182 | delta.data = delta.data - x 183 | delta.data = clamp_by_pnorm(delta.data, ord, eps) 184 | elif ord == 1: 185 | ini = laplace.Laplace( 186 | loc=delta.new_tensor(0), scale=delta.new_tensor(1)) 187 | delta.data = ini.sample(delta.data.shape) 188 | delta.data = normalize_by_pnorm(delta.data, p=1) 189 | ray = uniform.Uniform(0, eps).sample() 190 | delta.data *= ray 191 | delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data 192 | else: 193 | error = "Only ord = inf, ord = 1 and ord = 2 have been implemented" 194 | raise NotImplementedError(error) 195 | 196 | delta.data = clamp( 197 | x + delta.data, min=clip_min, max=clip_max) - x 198 | return delta.data 199 | 200 | 201 | def CWLoss(output, target, confidence=0): 202 | """ 203 | CW loss (Marging loss). 204 | """ 205 | num_classes = output.shape[-1] 206 | target = target.data 207 | target_onehot = torch.zeros(target.size() + (num_classes,)) 208 | target_onehot = target_onehot.cuda() 209 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 210 | target_var = Variable(target_onehot, requires_grad=False) 211 | real = (target_var * output).sum(1) 212 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 213 | loss = - torch.clamp(real - other + confidence, min=0.) 214 | loss = torch.sum(loss) 215 | return loss 216 | -------------------------------------------------------------------------------- /core/attacks/deepfool.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | from .base import Attack, LabelMixin 8 | 9 | from .utils import batch_multiply 10 | from .utils import clamp 11 | from .utils import is_float_or_torch_tensor 12 | 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | 17 | def perturb_deepfool(xvar, yvar, predict, nb_iter=50, overshoot=0.02, ord=np.inf, clip_min=0.0, clip_max=1.0, 18 | search_iter=0, device=None): 19 | """ 20 | Compute DeepFool perturbations (Moosavi-Dezfooli et al, 2016). 21 | Arguments: 22 | xvar (torch.Tensor): input images. 23 | yvar (torch.Tensor): predictions. 24 | predict (nn.Module): forward pass function. 25 | nb_iter (int): number of iterations. 26 | overshoot (float): how much to overshoot the boundary. 27 | ord (int): (optional) the order of maximum distortion (inf or 2). 
28 | clip_min (float): mininum value per input dimension. 29 | clip_max (float): maximum value per input dimension. 30 | search_iter (int): no of search iterations. 31 | device (torch.device): device to work on. 32 | Returns: 33 | torch.Tensor containing the perturbed input, 34 | torch.Tensor containing the perturbation 35 | """ 36 | 37 | x_orig = xvar 38 | x = torch.empty_like(xvar).copy_(xvar) 39 | x.requires_grad_(True) 40 | 41 | batch_i = torch.arange(x.shape[0]) 42 | r_tot = torch.zeros_like(x.data) 43 | for i in range(nb_iter): 44 | if x.grad is not None: 45 | x.grad.zero_() 46 | 47 | logits = predict(x) 48 | df_inds = np.argsort(logits.detach().cpu().numpy(), axis=-1) 49 | df_inds_other, df_inds_orig = df_inds[:, :-1], df_inds[:, -1] 50 | df_inds_orig = torch.from_numpy(df_inds_orig) 51 | df_inds_orig = df_inds_orig.to(device) 52 | not_done_inds = df_inds_orig == yvar 53 | if not_done_inds.sum() == 0: 54 | break 55 | 56 | logits[batch_i, df_inds_orig].sum().backward(retain_graph=True) 57 | grad_orig = x.grad.data.clone().detach() 58 | pert = x.data.new_ones(x.shape[0]) * np.inf 59 | w = torch.zeros_like(x.data) 60 | 61 | for inds in df_inds_other.T: 62 | x.grad.zero_() 63 | logits[batch_i, inds].sum().backward(retain_graph=True) 64 | grad_cur = x.grad.data.clone().detach() 65 | with torch.no_grad(): 66 | w_k = grad_cur - grad_orig 67 | f_k = logits[batch_i, inds] - logits[batch_i, df_inds_orig] 68 | if ord == 2: 69 | pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 2, -1) 70 | elif ord == np.inf: 71 | pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 1, -1) 72 | else: 73 | raise NotImplementedError("Only ord=inf and ord=2 have been implemented") 74 | swi = pert_k < pert 75 | if swi.sum() > 0: 76 | pert[swi] = pert_k[swi] 77 | w[swi] = w_k[swi] 78 | 79 | if ord == 2: 80 | r_i = (pert + 1e-6)[:, None, None, None] * w / torch.norm(w.flatten(1), 2, -1)[:, None, None, None] 81 | elif ord == np.inf: 82 | r_i = (pert + 1e-6)[:, None, None, None] * w.sign() 83 | 84 | r_tot += r_i * not_done_inds[:, None, None, None].float() 85 | x.data = x_orig + (1. + overshoot) * r_tot 86 | x.data = torch.clamp(x.data, clip_min, clip_max) 87 | 88 | x = x.detach() 89 | if search_iter > 0: 90 | dx = x - x_orig 91 | dx_l_low, dx_l_high = torch.zeros_like(dx), torch.ones_like(dx) 92 | for i in range(search_iter): 93 | dx_l = (dx_l_low + dx_l_high) / 2. 94 | dx_x = x_orig + dx_l * dx 95 | dx_y = predict(dx_x).argmax(-1) 96 | label_stay = dx_y == yvar 97 | label_change = dx_y != yvar 98 | dx_l_low[label_stay] = dx_l[label_stay] 99 | dx_l_high[label_change] = dx_l[label_change] 100 | x = dx_x 101 | 102 | # x.data = torch.clamp(x.data, clip_min, clip_max) 103 | r_tot = x.data - x_orig 104 | return x, r_tot 105 | 106 | 107 | 108 | class DeepFoolAttack(Attack, LabelMixin): 109 | """ 110 | DeepFool attack. 111 | [Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard, 112 | "DeepFool: a simple and accurate method to fool deep neural networks"] 113 | Arguments: 114 | predict (nn.Module): forward pass function. 115 | overshoot (float): how much to overshoot the boundary. 116 | nb_iter (int): number of iterations. 117 | search_iter (int): no of search iterations. 118 | clip_min (float): mininum value per input dimension. 119 | clip_max (float): maximum value per input dimension. 120 | ord (int): (optional) the order of maximum distortion (inf or 2). 
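    Example (minimal sketch; `model` stands for any classifier returning logits and is
    not defined in this repo snippet):
        >>> # attack = DeepFoolAttack(model, overshoot=0.02, nb_iter=50)
        >>> # x_adv, r_adv = attack.perturb(x)  # y is inferred from the model's predictions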
121 | """ 122 | 123 | def __init__( 124 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1., ord=np.inf): 125 | super(DeepFoolAttack, self).__init__(predict, None, clip_min, clip_max) 126 | self.overshoot = overshoot 127 | self.nb_iter = nb_iter 128 | self.search_iter = search_iter 129 | self.targeted = False 130 | 131 | self.ord = ord 132 | assert is_float_or_torch_tensor(self.overshoot) 133 | 134 | def perturb(self, x, y=None): 135 | """ 136 | Given examples x, returns their adversarial counterparts. 137 | Arguments: 138 | x (torch.Tensor): input tensor. 139 | y (torch.Tensor): label tensor. 140 | - if None and self.targeted=False, compute y as predicted labels. 141 | Returns: 142 | torch.Tensor containing perturbed inputs, 143 | torch.Tensor containing the perturbation 144 | """ 145 | 146 | x, y = self._verify_and_process_inputs(x, None) 147 | x_adv, r_adv = perturb_deepfool(x, y, self.predict, self.nb_iter, self.overshoot, ord=self.ord, 148 | clip_min=self.clip_min, clip_max=self.clip_max, search_iter=self.search_iter, 149 | device=device) 150 | return x_adv, r_adv 151 | 152 | 153 | class LinfDeepFoolAttack(DeepFoolAttack): 154 | """ 155 | DeepFool Attack with order=Linf. 156 | Arguments: 157 | Arguments: 158 | predict (nn.Module): forward pass function. 159 | overshoot (float): how much to overshoot the boundary. 160 | nb_iter (int): number of iterations. 161 | search_iter (int): no of search iterations. 162 | clip_min (float): mininum value per input dimension. 163 | clip_max (float): maximum value per input dimension. 164 | """ 165 | 166 | def __init__( 167 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.): 168 | 169 | ord = np.inf 170 | super(LinfDeepFoolAttack, self).__init__( 171 | predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min, 172 | clip_max=clip_max, ord=ord) 173 | 174 | 175 | 176 | class L2DeepFoolAttack(DeepFoolAttack): 177 | """ 178 | DeepFool Attack with order=L2. 179 | Arguments: 180 | predict (nn.Module): forward pass function. 181 | overshoot (float): how much to overshoot the boundary. 182 | nb_iter (int): number of iterations. 183 | search_iter (int): no of search iterations. 184 | clip_min (float): mininum value per input dimension. 185 | clip_max (float): maximum value per input dimension. 
186 | """ 187 | 188 | def __init__( 189 | self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.): 190 | 191 | ord = 2 192 | super(L2DeepFoolAttack, self).__init__( 193 | predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min, 194 | clip_max=clip_max, ord=ord) 195 | -------------------------------------------------------------------------------- /core/testfn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pandas as pd 4 | from autoattack import AutoAttack 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from collections import defaultdict 9 | from core.context import ctx_noparamgrad_and_eval 10 | from core.data import load_dataloader, load_corr_dataloader 11 | from torchmetrics.functional.classification import multiclass_calibration_error 12 | from torchvision.transforms.functional import convert_image_dtype 13 | import math 14 | 15 | def test(dataloader_test, model, use_diffusion=False, augmentor=None, attacker=None, device=None): 16 | 17 | metrics = pd.DataFrame() 18 | model.eval() 19 | 20 | 21 | for x, y in dataloader_test: 22 | batch_metric = defaultdict(float) 23 | x, y = x.to(device), y.to(device) 24 | 25 | if attacker: 26 | with ctx_noparamgrad_and_eval(model): 27 | x, _ = attacker.perturb(x, y) 28 | 29 | with torch.no_grad(): 30 | if augmentor: 31 | # x_aug = augmentor(convert_image_dtype(x, torch.uint8)) 32 | # x_aug = convert_image_dtype(x_aug, torch.float) 33 | x_aug = augmentor(x) 34 | 35 | if use_diffusion: 36 | proba = 0 37 | for _ in range(10): 38 | out = model(x, use_diffusion=True) 39 | proba = proba + out 40 | out = proba/10 41 | else: 42 | out = model(x, use_diffusion = use_diffusion) 43 | # out_aug = model(x_aug, use_diffusion = use_diffusion) 44 | # distance = torch.abs(out_aug-out) 45 | 46 | batch_metric["eval_loss"] = F.cross_entropy(out, y).data.item() 47 | batch_metric["eval_acc"] = (torch.softmax(out.data, dim=1).argmax(dim=1) == y.data).sum().data.item() 48 | 49 | # batch_metric["eval_acc_aug"] = (torch.softmax(out_aug.data, dim=1).argmax(dim=1) == y.data).sum().data.item() 50 | # batch_metric["distance"] = distance.data.mean().item() 51 | # batch_metric["distance_min"] = distance.data.mean().item() 52 | # batch_metric["distance_max"] = distance.data.mean().item() 53 | # batch_metric["distance_std"] = distance.data.mean().item() 54 | 55 | metrics = pd.concat([metrics, pd.DataFrame(batch_metric, index=[0])], ignore_index=True) 56 | 57 | return dict(metrics.agg({ 58 | "eval_loss":"mean", 59 | "eval_acc":lambda x: 100*sum(x)/len(dataloader_test.dataset), 60 | # "eval_acc_aug":lambda x: 100*sum(x)/len(dataloader_test.dataset), 61 | # "distance":"mean", 62 | # "distance_min":"min", 63 | # "distance_max":"max", 64 | # "distance_std":"std" 65 | })) 66 | 67 | 68 | def eval_ood(ood_data, args, model, use_diffusion, logger, device): 69 | if 'pacs' in args.data: 70 | res = [] 71 | for c in range(len(ood_data)): 72 | dataloader = load_corr_dataloader(args.data, args.data_dir, args.batch_size, cname=ood_data[c]) 73 | dict = test(dataloader, model, use_diffusion =use_diffusion, device=device) 74 | del dataloader 75 | res.append(dict["eval_acc"]) 76 | logger.info("{}-{}".format(ood_data[c],dict["eval_acc"])) 77 | ret = np.array(res).mean() 78 | 79 | else: 80 | res = np.zeros((5, len(ood_data))) 81 | for c in range(len(ood_data)): 82 | for s in range(1, 6): 83 | dataloader = load_corr_dataloader(args.data, 
args.data_dir, args.batch_size, cname=ood_data[c], severity=s) 84 | dict = test(dataloader, model, use_diffusion =use_diffusion, device=device) 85 | del dataloader 86 | res[s-1, c] = dict["eval_acc"] 87 | frame = pd.DataFrame({i+1: res[i, :] for i in range(0, 5)}, index=ood_data) 88 | frame.loc['average'] = {i+1: np.mean(res, axis=1)[i] for i in range(0, 5)} 89 | frame['avg'] = frame[list(range(1, 6))].mean(axis=1) 90 | logger.info(frame) 91 | ret = frame["avg"]["average"] 92 | return {"eval_acc":ret} 93 | 94 | 95 | def clean_accuracy(model: torch.nn.Module, 96 | use_diffusion: bool, 97 | x: torch.Tensor, 98 | y: torch.Tensor, 99 | batch_size: int = 100, 100 | device: torch.device = None): 101 | if device is None: 102 | device = x.device 103 | acc = 0. 104 | n_batches = math.ceil(x.shape[0] / batch_size) 105 | with torch.no_grad(): 106 | for counter in range(n_batches): 107 | x_curr = x[counter * batch_size:(counter + 1) * 108 | batch_size].to(device) 109 | y_curr = y[counter * batch_size:(counter + 1) * 110 | batch_size].to(device) 111 | 112 | output = model(x_curr, use_diffusion=use_diffusion) 113 | 114 | acc += (output.max(1)[1] == y_curr).float().sum() 115 | 116 | return acc.item() / x.shape[0] 117 | 118 | 119 | def compute_ece (model: torch.nn.Module, 120 | use_diffusion: bool, 121 | x: torch.Tensor, 122 | y: torch.Tensor, 123 | batch_size: int = 100, 124 | device: torch.device = None): 125 | if device is None: 126 | device = x.device 127 | ece = 0. 128 | n_batches = math.ceil(x.shape[0] / batch_size) 129 | with torch.no_grad(): 130 | for counter in range(n_batches): 131 | x_curr = x[counter * batch_size:(counter + 1) * 132 | batch_size].to(device) 133 | y_curr = y[counter * batch_size:(counter + 1) * 134 | batch_size].to(device) 135 | 136 | output = model(x_curr, use_diffusion=use_diffusion) 137 | 138 | ece += multiclass_calibration_error(output, y_curr, num_classes=10, n_bins=15, norm='l1') 139 | 140 | return ece.item() / n_batches 141 | 142 | 143 | def compute_mce(corruption_accs, baseline_acc): 144 | """Compute mCE (mean Corruption Error) normalized by Baseline performance.""" 145 | mce = 0. 146 | for i in range(15): 147 | ce = (1-corruption_accs[i]) / (1-baseline_acc[i]) 148 | mce += ce 149 | return mce / 15 150 | 151 | 152 | def compute_rmce(nat_acc, corr_accs, baseline_nat_acc, baseline_corr_accs): 153 | """Compute rmCE (relative mean Corruption Error) normalized by Baseline performance.""" 154 | mce = 0. 
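    # rmCE = (1/15) * sum_c [(Err_c - Err_nat) / (Err_c^base - Err_nat^base)], with Err = 1 - acc;
    # values below 1 mean corruption degrades this model less than it degrades the baseline.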
155 | for i in range(15): 156 | ce = ((1-corr_accs[i])-(1-nat_acc)) / ((1-baseline_corr_accs[i])-(1-baseline_nat_acc)) 157 | mce += ce 158 | return mce / 15 159 | 160 | 161 | def final_corr_eval(x_corrs, y_corrs, model, use_diffusion, corruptions, baseline_accs, logger): 162 | l = len(corruptions) 163 | 164 | model.eval() 165 | nat_acc = clean_accuracy(model, use_diffusion, x_corrs[0].to(list(model.parameters())[0].device), y_corrs[0].to(list(model.parameters())[0].device)) 166 | logger.info('nat_acc: {}'.format(nat_acc)) 167 | 168 | 169 | res = np.zeros((5, l)) 170 | for i in range(1, 6): 171 | for j, c in enumerate(corruptions): 172 | res[i-1, j] = clean_accuracy(model, use_diffusion, x_corrs[i][j].to(list(model.parameters())[0].device), y_corrs[i][j].to(list(model.parameters())[0].device)) 173 | print(c, i, res[i-1, j]) 174 | 175 | frame = pd.DataFrame({i+1: res[i, :] for i in range(0, 5)}, index=corruptions) 176 | frame.loc['average'] = {i+1: np.mean(res, axis=1)[i] for i in range(0, 5)} 177 | frame['avg'] = frame[list(range(1,6))].mean(axis=1) 178 | logger.info(frame) 179 | 180 | baseline_acc_nat=baseline_accs[0] 181 | baseline_acc_s0=baseline_accs[1] 182 | baseline_acc_s5=baseline_accs[2] 183 | 184 | s0 = list(frame['avg']) 185 | s5 = list(frame[5]) 186 | 187 | mce_s0 = compute_mce(s0, baseline_acc_s0) 188 | logger.info('mce_s0: {}'.format(mce_s0)) 189 | 190 | mce_s5 = compute_mce(s5, baseline_acc_s5) 191 | logger.info('mce_s5: {}'.format(mce_s5)) 192 | 193 | rmce_s0 = compute_rmce(nat_acc, s0, baseline_acc_nat, baseline_acc_s0) 194 | logger.info('rmce_s0: {}'.format(rmce_s0)) 195 | 196 | rmce_s5 = compute_rmce(nat_acc, s5, baseline_acc_nat, baseline_acc_s5) 197 | logger.info('rmce_s5: {}'.format(rmce_s5)) 198 | 199 | 200 | def run_final_test_autoattack(model, args_test, logger, device): 201 | 202 | _, loader_test = load_dataloader(args_test) 203 | 204 | l = [x for (x, y) in loader_test] 205 | x_test = torch.cat(l, 0) 206 | l = [y for (x, y) in loader_test] 207 | y_test = torch.cat(l, 0) 208 | 209 | if args_test.threat =='linf': 210 | epsilon = 8 / 255. 211 | elif args_test.threat =='l2': 212 | epsilon = 0.5 213 | adversary = AutoAttack(model, norm=args_test.threat, eps=epsilon, version='standard', log_path=logger, seed=args_test.seed) 214 | with torch.no_grad(): 215 | x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=128) 216 | -------------------------------------------------------------------------------- /core/attacks/pgd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .base import Attack, LabelMixin 7 | from torchvision.utils import save_image 8 | 9 | from .utils import batch_clamp 10 | from .utils import batch_multiply 11 | from .utils import clamp 12 | from .utils import clamp_by_pnorm 13 | from .utils import is_float_or_torch_tensor 14 | from .utils import normalize_by_pnorm 15 | from .utils import rand_init_delta 16 | from .utils import replicate_input 17 | 18 | 19 | def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf, 20 | clip_min=0.0, clip_max=1.0): 21 | """ 22 | Iteratively maximize the loss over the input. It is a shared method for iterative attacks. 23 | Arguments: 24 | xvar (torch.Tensor): input data. 25 | yvar (torch.Tensor): input labels. 26 | predict (nn.Module): forward pass function. 27 | nb_iter (int): number of iterations. 
28 | eps (float): maximum distortion. 29 | eps_iter (float): attack step size. 30 | loss_fn (nn.Module): loss function. 31 | delta_init (torch.Tensor): (optional) tensor contains the random initialization. 32 | minimize (bool): (optional) whether to minimize or maximize the loss. 33 | ord (int): (optional) the order of maximum distortion (inf or 2). 34 | clip_min (float): mininum value per input dimension. 35 | clip_max (float): maximum value per input dimension. 36 | Returns: 37 | torch.Tensor containing the perturbed input, 38 | torch.Tensor containing the perturbation 39 | """ 40 | if delta_init is not None: 41 | delta = delta_init 42 | else: 43 | delta = torch.zeros_like(xvar) 44 | 45 | delta.requires_grad_() 46 | for ii in range(nb_iter): 47 | outputs = predict(xvar + delta) 48 | loss = loss_fn(outputs, yvar) 49 | if minimize: 50 | loss = -loss 51 | 52 | loss.backward() 53 | if ord == np.inf: 54 | grad_sign = delta.grad.data.sign() 55 | delta.data = delta.data + batch_multiply(eps_iter, grad_sign) 56 | delta.data = batch_clamp(eps, delta.data) 57 | delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data 58 | elif ord == 2: 59 | grad = delta.grad.data 60 | grad = normalize_by_pnorm(grad) 61 | delta.data = delta.data + batch_multiply(eps_iter, grad) 62 | delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data 63 | if eps is not None: 64 | delta.data = clamp_by_pnorm(delta.data, ord, eps) 65 | else: 66 | error = "Only ord=inf and ord=2 have been implemented" 67 | raise NotImplementedError(error) 68 | delta.grad.data.zero_() 69 | 70 | x_adv = clamp(xvar + delta, clip_min, clip_max) 71 | r_adv = x_adv - xvar 72 | return x_adv, r_adv 73 | 74 | 75 | class PGDAttack(Attack, LabelMixin): 76 | """ 77 | The projected gradient descent attack (Madry et al, 2017). 78 | The attack performs nb_iter steps of size eps_iter, while always staying within eps from the initial point. 79 | Arguments: 80 | predict (nn.Module): forward pass function. 81 | loss_fn (nn.Module): loss function. 82 | eps (float): maximum distortion. 83 | nb_iter (int): number of iterations. 84 | eps_iter (float): attack step size. 85 | rand_init (bool): (optional) random initialization. 86 | clip_min (float): mininum value per input dimension. 87 | clip_max (float): maximum value per input dimension. 88 | ord (int): (optional) the order of maximum distortion (inf or 2). 89 | targeted (bool): if the attack is targeted. 90 | rand_init_type (str): (optional) random initialization type. 
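    Example (illustrative sketch; `model` is a hypothetical classifier, and the
    defaults shown match MNIST-scale inputs):
        >>> # attack = PGDAttack(model, eps=0.3, nb_iter=40, eps_iter=0.01, ord=np.inf)
        >>> # x_adv, r_adv = attack.perturb(x, y)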
91 | """ 92 | 93 | def __init__( 94 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 95 | ord=np.inf, targeted=False, rand_init_type='uniform',save_path=None): 96 | super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max) 97 | self.eps = eps 98 | self.nb_iter = nb_iter 99 | self.eps_iter = eps_iter 100 | self.rand_init = rand_init 101 | self.rand_init_type = rand_init_type 102 | self.ord = ord 103 | self.targeted = targeted 104 | if self.loss_fn is None: 105 | self.loss_fn = nn.CrossEntropyLoss(reduction="sum") 106 | assert is_float_or_torch_tensor(self.eps_iter) 107 | assert is_float_or_torch_tensor(self.eps) 108 | self.save_path=save_path 109 | 110 | def vis(self, x_ori, x_aug): 111 | x = torch.cat([x_ori[:8],x_aug[:8]]) 112 | if self.ord == np.inf: 113 | o = 'inf' 114 | elif self.ord == 2: 115 | o = '2' 116 | save_image(x.cpu(), os.path.join(self.save_path,'adv-l{}.png'.format(o)), nrow=8, 117 | padding=0, value_range=(0, 1), pad_value=0) 118 | 119 | def perturb(self, x, y=None, visualize=False): 120 | """ 121 | Given examples (x, y), returns their adversarial counterparts with an attack length of eps. 122 | Arguments: 123 | x (torch.Tensor): input tensor. 124 | y (torch.Tensor): label tensor. 125 | - if None and self.targeted=False, compute y as predicted 126 | labels. 127 | - if self.targeted=True, then y must be the targeted labels. 128 | Returns: 129 | torch.Tensor containing perturbed inputs, 130 | torch.Tensor containing the perturbation 131 | """ 132 | x, y = self._verify_and_process_inputs(x, y) 133 | 134 | delta = torch.zeros_like(x) 135 | delta = nn.Parameter(delta) 136 | if self.rand_init: 137 | if self.rand_init_type == 'uniform': 138 | rand_init_delta( 139 | delta, x, self.ord, self.eps, self.clip_min, self.clip_max) 140 | delta.data = clamp( 141 | x + delta.data, min=self.clip_min, max=self.clip_max) - x 142 | elif self.rand_init_type == 'normal': 143 | delta.data = 0.001 * torch.randn_like(x) # initialize as in TRADES 144 | else: 145 | raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.') 146 | 147 | x_adv, r_adv = perturb_iterative( 148 | x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn, 149 | minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta 150 | ) 151 | 152 | if visualize: 153 | self.vis(x, x_adv) 154 | 155 | return x_adv.data, r_adv.data 156 | 157 | 158 | class LinfPGDAttack(PGDAttack): 159 | """ 160 | PGD Attack with order=Linf 161 | Arguments: 162 | predict (nn.Module): forward pass function. 163 | loss_fn (nn.Module): loss function. 164 | eps (float): maximum distortion. 165 | nb_iter (int): number of iterations. 166 | eps_iter (float): attack step size. 167 | rand_init (bool): (optional) random initialization. 168 | clip_min (float): mininum value per input dimension. 169 | clip_max (float): maximum value per input dimension. 170 | targeted (bool): if the attack is targeted. 171 | rand_init_type (str): (optional) random initialization type. 
172 | """ 173 | 174 | def __init__( 175 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 176 | targeted=False, rand_init_type='uniform',save_path=None): 177 | ord = np.inf 178 | super(LinfPGDAttack, self).__init__( 179 | predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 180 | clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type, save_path=save_path) 181 | 182 | 183 | class L2PGDAttack(PGDAttack): 184 | """ 185 | PGD Attack with order=L2 186 | Arguments: 187 | predict (nn.Module): forward pass function. 188 | loss_fn (nn.Module): loss function. 189 | eps (float): maximum distortion. 190 | nb_iter (int): number of iterations. 191 | eps_iter (float): attack step size. 192 | rand_init (bool): (optional) random initialization. 193 | clip_min (float): mininum value per input dimension. 194 | clip_max (float): maximum value per input dimension. 195 | targeted (bool): if the attack is targeted. 196 | rand_init_type (str): (optional) random initialization type. 197 | """ 198 | 199 | def __init__( 200 | self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1., 201 | targeted=False, rand_init_type='uniform'): 202 | ord = 2 203 | super(L2PGDAttack, self).__init__( 204 | predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init, 205 | clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type) 206 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | from collections import defaultdict 5 | 6 | import optuna 7 | import pandas as pd 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from core.models import create_model 12 | from core.trainfn import train_standard, train_pdeadd 13 | from core.testfn import test, eval_ood 14 | from core.parse import parser_train 15 | from core.data import load_dataloader,load_dg_dataloader, corruption_15, pacs_4 16 | from core.scheduler import WarmUpLR, get_scheduler 17 | from core.utils import BestSaver, get_logger, set_seed, get_desc, eval_epoch, verbose_and_save 18 | 19 | # load args 20 | def objective(trial=None): 21 | global logger 22 | args = parser_train() 23 | set_seed(args.seed) 24 | 25 | if use_optuna: 26 | args.lrC = trial.suggest_float("lrC", 0.05, 0.1, step=0.01) 27 | args.lrDiff = trial.suggest_float("lrDiff", 0.0002, 0.001, step=0.0001) 28 | #args.batch_size=trial.suggest_int("bs", 64, 256, step=64) 29 | args.ls=trial.suggest_float("ls",0.1, 0.2, step=0.01) 30 | args.save_dir = optuna_save_dir 31 | 32 | if args.resume_file: 33 | resume_file = args.resume_file 34 | resume_epoch = args.epoch 35 | path = "/".join(args.resume_file.split('/')[:-2]) 36 | with open(os.path.join(path, 'train', 'args.txt'), 'r') as f: 37 | old = json.load(f) 38 | args.__dict__ = dict(vars(args), **old) 39 | args.save_dir = os.path.join(path, 'resume') 40 | args.resume_file = resume_file 41 | args.epoch = resume_epoch 42 | args.scheduler ='none' 43 | else: 44 | args.desc = get_desc(args) 45 | args.save_dir = os.path.join(args.save_dir, args.desc, 'train') 46 | 47 | # set logs 48 | os.makedirs(args.save_dir, exist_ok=True) 49 | if not use_optuna: 50 | logger = get_logger(logpath=os.path.join(args.save_dir, 
'verbose.log'))
51 |     else:
52 |         logger.info('\nOptuna number of trial: ' + str(trial.number))
53 |         logger.info('Optuna trial params: ' + str(trial.params))
54 | 
55 |     if args.protocol in ['standard', 'fixdiff']:
56 |         args.data_diff = None
57 |     elif args.protocol == 'pdeadd':
58 |         if args.data in ['cifar10', 'cifar100', 'tin200']:
59 |             args.data_diff = args.data
60 |         else:
61 |             raise ValueError('pdeadd expects --data in [cifar10, cifar100, tin200], got {}'.format(args.data))
62 | 
63 |     # save args
64 |     with open(os.path.join(args.save_dir, 'args.txt'), 'w') as f:
65 |         json.dump(args.__dict__, f, indent=4)
66 | 
67 |     # dataloaders
68 |     if args.data in ['cifar10', 'cifar100', 'tin200']:
69 |         ood_data = corruption_15
70 |     elif 'pacs' in args.data:
71 |         ood_data = pacs_4
72 |         if args.data.split("-")[1] in ood_data:
73 |             ood_data.remove(args.data.split("-")[1])
74 |         if (args.data_diff):
75 |             if (args.data_diff.split("-")[1] in ood_data) and (args.use_gmm):
76 |                 ood_data.remove(args.data_diff.split("-")[1])
77 | 
78 | 
79 |     if ('pacs' in args.data) and (args.data != args.data_diff) and (args.data_diff is not None):
80 |         dataloader_train, dataloader_train_diff, dataloader_test = load_dg_dataloader(args)
81 |     else:
82 |         dataloader_train, dataloader_train_diff, dataloader_test = load_dataloader(args)
83 | 
84 | 
85 |     logger.info('Using train dataset: {} with augment: {}'.format(args.data, args.aug_train))
86 |     if ('pacs' in args.data) or (args.data in ['tin200']):
87 |         logger.info('[+] data shape: {}, label shape: {}'.format(
88 |             len(dataloader_train.dataset.samples), len(dataloader_train.dataset.targets)))
89 |     elif ('cifar' in args.data):
90 |         logger.info('[+] data shape: {}, label shape: {}'.format(
91 |             dataloader_train.dataset.data.shape, len(dataloader_train.dataset.targets)))
92 |     else:
93 |         raise ValueError('unsupported dataset: {}'.format(args.data))
94 | 
95 |     if args.data_diff:
96 |         logger.info('Using train diffusion guidance dataset: {} with augment: {}'.format(args.data_diff, args.aug_train_diff))
97 |         if ('pacs' in args.data) or (args.data in ['tin200']):
98 |             logger.info('[+] data shape: {}, label shape: {}'.format(
99 |                 len(dataloader_train_diff.dataset.samples), len(dataloader_train_diff.dataset.targets)))
100 |         elif ('cifar' in args.data):
101 |             logger.info('[+] data shape: {}, label shape: {}'.format(
102 |                 dataloader_train_diff.dataset.data.shape, len(dataloader_train_diff.dataset.targets)))
103 |         else:
104 |             raise ValueError('unsupported diffusion guidance dataset: {}'.format(args.data_diff))
105 |     else:
106 |         logger.info('Not using train diffusion guidance dataset')
107 | 
108 |     logger.info('Using test dataset: {}'.format(ood_data))
109 | 
110 |     # device
111 |     device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
112 |     logger.info('using device: {}'.format(device))
113 | 
114 |     # create model
115 |     model = create_model(args.backbone, args.protocol, num_classes=len(dataloader_train.dataset.classes))
116 |     model = model.to(device)
117 |     logger.info("using model: {}".format(args.backbone))
118 |     logger.info("using protocol: {}".format(args.protocol))
119 | 
120 |     # attackers
121 |     attack_train = None
122 |     attack_eval = None
123 |     logger.info('using training attacker: {}'.format(args.atk_train))
124 |     logger.info('using evaluating attacker: {}'.format(args.atk_eval))
125 | 
126 |     # optimizers
127 | 
128 |     optimizerC = torch.optim.SGD(model.parameters(), lr=args.lrC, momentum=0.9, weight_decay=args.weight_decay)
129 |     optimizerDiff = None
130 |     if args.protocol == 'pdeadd':
131 |         diffusion_params = []
132 |         for name, param in model.named_parameters():
133 |             if 'diff' in name:
134 |                 diffusion_params.append(param)
135 |         optimizerDiff = torch.optim.Adam(diffusion_params, lr=args.lrDiff)
136 | 
137 |     # schedulers
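    # The classifier LR follows `scheduler` below (stepped once per epoch after the
    # warmup window), while WarmUpLR is stepped during the first `args.warm` epochs;
    # the diffusion optimizer keeps its fixed Adam LR throughout.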
138 |     scheduler = get_scheduler(args, opt=optimizerC)
139 |     if args.warm:
140 |         iter_per_epoch = len(dataloader_train)
141 |         warmup_scheduler = WarmUpLR(optimizerC, iter_per_epoch * args.warm)
142 |     logger.info('using scheduler: {}, warmup {}'.format(args.scheduler, args.warm))
143 | 
144 |     # resume
145 |     start_epoch = 1
146 |     if args.resume_file:
147 |         checkpoint = torch.load(args.resume_file)
148 |         model.load_state_dict(checkpoint['model_state_dict'])
149 |         optimizerC.load_state_dict(checkpoint['optimizerC_state_dict'])
150 |         if 'pdeadd' in args.protocol:
151 |             optimizerDiff.load_state_dict(checkpoint['optimizerDiff_state_dict'])
152 |         last_cla_lr = checkpoint['scheduler']["_last_lr"][0]
153 |         start_epoch = checkpoint['scheduler']["last_epoch"]
154 |         for param_group in optimizerC.param_groups:
155 |             param_group["lr"] = last_cla_lr
156 |         del checkpoint
157 | 
158 |     # writer
159 |     if not use_optuna:
160 |         writer = SummaryWriter(os.path.join(args.save_dir), comment='train', filename_suffix='train')
161 |         #writer_eval = SummaryWriter(os.path.join(args.save_dir), comment='eval', filename_suffix='eval')
162 |         #writer_eval_diff = SummaryWriter(os.path.join(args.save_dir), comment='eval_diff', filename_suffix='eval_diff')
163 | 
164 |     # start training
165 |     total_metrics = pd.DataFrame()
166 |     saver = BestSaver()
167 |     eval_metric, eval_metric["ood"], eval_metric["nat"] = defaultdict(float), defaultdict(float), defaultdict(float)
168 | 
169 |     for epoch in range(start_epoch, args.epoch + start_epoch):
170 |         start = time.time()
171 | 
172 |         if args.protocol == 'pdeadd':
173 |             train_metric = train_pdeadd(dataloader_train, dataloader_train_diff, model, optimizerDiff, optimizerC, label_smooth=args.ls,
174 |                                         attacker=attack_train, device=device, use_gmm=args.use_gmm, save_path=args.save_dir)
175 |         else:
176 |             train_metric = train_standard(dataloader_train, model, optimizerC, attacker=attack_train,
177 |                                           device=device, visualize=True if epoch == start_epoch else False, epoch=epoch)
178 | 
179 |         if (args.scheduler != 'none') and (epoch > args.warm):
180 |             scheduler.step()
181 |         if (args.warm) and (epoch <= args.warm):
182 |             warmup_scheduler.step()
183 | 
184 |         # test for nat
185 |         with torch.no_grad():
186 |             eval_per_epoch = eval_epoch(epoch)
187 |             if (epoch == start_epoch) or (epoch % eval_per_epoch == 0):
188 |                 eval_metric["ood"] = eval_ood(ood_data=ood_data, args=args, model=model, use_diffusion=False, logger=logger, device=device)
189 |                 if dataloader_test:
190 |                     eval_metric["nat"] = test(dataloader_test, model, use_diffusion=False, device=device)
191 | 
192 | 
193 |         if use_optuna:
194 |             trial.report(eval_metric["ood"]["eval_acc"], epoch)
195 |             if trial.should_prune():
196 |                 raise optuna.exceptions.TrialPruned()
197 |             return eval_metric["ood"]["eval_acc"]
198 |         else:
199 |             verbose_and_save(logger, epoch, start_epoch, eval_per_epoch, start, train_metric, eval_metric, writer)
200 | 
201 |         # save csv
202 |         metric = pd.concat(
203 |             [pd.DataFrame(train_metric, index=[epoch]),
204 |              pd.DataFrame(eval_metric["nat"], index=[epoch]),
205 |              pd.DataFrame(eval_metric["ood"], index=[epoch]),
206 |              #pd.DataFrame(eval_metric["ood_diff"], index=[epoch])
207 |             ], axis=1)
208 |         total_metrics = pd.concat([total_metrics, metric], ignore_index=True)
209 |         total_metrics.to_csv(os.path.join(args.save_dir, 'stats.csv'), index=True)
210 | 
211 |         # save model
212 |         saver.apply(eval_metric["ood"]['eval_acc'], epoch,
213 |                     model=model, optimizerC=optimizerC, scheduler=scheduler, optimizerDiff=optimizerDiff,
214 |
save_path=os.path.join(args.save_dir,'model-best-wodiff.pt')) 215 | # diff_saver.apply(eval_metric["ood_diff"]['eval_acc'], epoch, 216 | # model=model, optimizerC=optimizerC, scheduler=scheduler, optimizerDiff=optimizerDiff, 217 | # save_path=os.path.join(args.save_dir,'model-best-endiff.pt')) 218 | if (epoch!=0) and (epoch % args.save_freq==0): 219 | saver.save_model(model=model, optimizerC=optimizerC, scheduler=scheduler, optimizerDiff=optimizerDiff, 220 | save_path=os.path.join(args.save_dir,'model-e{}.pt'.format(epoch))) 221 | saver.save_model(model=model, optimizerC=optimizerC, scheduler=scheduler, optimizerDiff=optimizerDiff, 222 | save_path=os.path.join(args.save_dir,'model-last.pt')) 223 | 224 | logger.info("[Final]\t Best wo-diff acc: {:.2f}% in Epoch: {}".format(saver.best_acc, saver.best_epoch)) 225 | #logger.info("[Final]\t Best en-diff acc: {:.2f}% in Epoch: {}".format(diff_saver.best_acc, diff_saver.best_epoch)) 226 | 227 | return eval_metric["ood"]["eval_acc"] 228 | 229 | use_optuna = False 230 | optuna_save_dir = '/home/yuanyige/Ladiff_nll/save_optuna_200' 231 | if use_optuna: 232 | os.makedirs(optuna_save_dir, exist_ok=True) 233 | logger = get_logger(logpath=os.path.join(optuna_save_dir, 'verbose.log')) 234 | study = optuna.create_study(direction='maximize') 235 | study.optimize(objective, n_trials=20) 236 | print('\n\nbest value',study.best_value) 237 | print('best param',study.best_params) 238 | else: 239 | objective() -------------------------------------------------------------------------------- /core/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as T 7 | from torchvision import datasets 8 | from torch.utils.data import Sampler, DataLoader 9 | from collections import defaultdict 10 | from torch.utils.data import Dataset 11 | #from robustbench.data import load_cifar10c, load_cifar10, load_cifar100c, load_cifar100 12 | 13 | corruption_19=[ 'snow', 'fog', 'frost', 'glass_blur', 'defocus_blur','motion_blur','zoom_blur','gaussian_blur', 14 | 'gaussian_noise','shot_noise','impulse_noise','speckle_noise', 15 | 'pixelate','brightness','contrast','jpeg_compression','elastic_transform','spatter','saturate'] 16 | 17 | corruption_15 = ['snow', 'fog', 'frost', 'glass_blur', 'defocus_blur', 'motion_blur','zoom_blur', 18 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 19 | 'pixelate', 'brightness', 'contrast','jpeg_compression', 'elastic_transform'] 20 | 21 | pacs_4 = ['art', 'cartoon', 'photo', 'sketch'] 22 | 23 | def split_small_dataset(num, tar, num_class): 24 | index = [] 25 | for k in range(num_class): 26 | l = [i for i,j in enumerate(tar) if j == k][:num] 27 | index.extend(l) 28 | return index 29 | 30 | class Identity(torch.nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | def forward(self, x): 34 | return x 35 | 36 | class FixSampler(Sampler): 37 | def __init__(self, length): 38 | #length = int(len(dataset)) 39 | self.indices = list(range(length)) 40 | random.Random(1).shuffle(self.indices) 41 | def __iter__(self): 42 | return iter(self.indices) 43 | 44 | 45 | class PairedDataset(Dataset): 46 | def __init__(self, data, classes, transform=None): 47 | self.samples = [x[0] for x in data] 48 | self.targets = [x[1] for x in data] 49 | self.classes = classes 50 | self.transform = transform 51 | 52 | def __getitem__(self, index): 53 | image = self.samples[index] 54 | if self.transform: 55 | 
image = self.transform(image) 56 | label = self.targets[index] 57 | return image, label 58 | 59 | def __len__(self): 60 | return len(self.samples) 61 | 62 | def Balance(root, data1, data2, t1, t2): 63 | def balance_classes(d1, d2): 64 | #print(len(d1),len(d2)) 65 | if len(d1) < len(d2): 66 | while len(d1) < len(d2): 67 | d1.append(random.choice(d1)) 68 | elif len(d1) > len(d2): 69 | while len(d2) < len(d1): 70 | d2.append(random.choice(d2)) 71 | #print(len(d1),len(d2)) 72 | return d1, d2 73 | 74 | # Load the datasets 75 | domain1 = datasets.ImageFolder(os.path.join(root, 'pacs', data1.split('-')[1])) 76 | domain2 = datasets.ImageFolder(os.path.join(root, 'pacs', data2.split('-')[1])) 77 | #print('domain1',domain1.samples) 78 | 79 | # Group images by class 80 | data_dict1 = defaultdict(list) 81 | data_dict2 = defaultdict(list) 82 | 83 | for image, label in domain1: 84 | data_dict1[label].append((image, label)) 85 | #print(data_dict1) 86 | 87 | for image, label in domain2: 88 | data_dict2[label].append((image, label)) 89 | 90 | balanced_data1 = [] 91 | balanced_data2 = [] 92 | 93 | # Balance the number of images in each class 94 | for label in range(len(domain1.classes)): 95 | #print('label',label) 96 | data_dict1[label],data_dict2[label]=balance_classes(data_dict1[label], data_dict2[label]) 97 | balanced_data1.extend(data_dict1[label]) 98 | balanced_data2.extend(data_dict2[label]) 99 | #print('balanced_data1balanced_data1',balanced_data1) 100 | 101 | # Create the final datasets with balanced class distribution 102 | dataset1 = PairedDataset(balanced_data1, domain1.classes, transform=t1) 103 | dataset2 = PairedDataset(balanced_data2, domain2.classes, transform=t2) 104 | return dataset1, dataset2 105 | 106 | class Augmentor(torch.nn.Module): 107 | def __init__(self, augments): 108 | super().__init__() 109 | self.augments=augments 110 | self.mapping = { 111 | "gaublur":T.GaussianBlur, 112 | "elastic":T.ElasticTransform, 113 | "contrast":T.RandomAutocontrast, 114 | "invert":T.RandomInvert, 115 | "color":T.ColorJitter, 116 | "rotation":T.RandomRotation, 117 | "augmix":T.AugMix, 118 | "randaug":T.RandAugment, 119 | "autoaug":T.AutoAugment, 120 | "none":Identity} 121 | 122 | def forward(self, img): 123 | type, param = self.split_string() 124 | return self.mapping[type](*param)(img) 125 | 126 | def split_string(self): 127 | augments = self.augments.split('-') 128 | type = augments[0] 129 | param = [] 130 | if len(augments) > 1: 131 | for p in augments[1:]: 132 | if '.' 
in p:
133 |                     p = float(p)
134 |                 else:
135 |                     p = int(p)
136 |                 param.append(p)
137 |         return type, param
138 | 
139 | def load_dataloader(args):
140 |     # define transforms
141 |     transform_train = [T.Resize(32),
142 |                        T.RandomCrop(32, padding=4),
143 |                        T.RandomHorizontalFlip(),
144 |                        Augmentor(args.aug_train),
145 |                        T.ToTensor()]
146 |     transform_train = T.Compose(transform_train)
147 | 
148 |     if args.data_diff is not None:
149 |         transform_train_diff = [T.Resize(32),
150 |                                 T.RandomCrop(32, padding=4),
151 |                                 T.RandomHorizontalFlip(),
152 |                                 Augmentor(args.aug_train_diff),
153 |                                 T.ToTensor()]
154 |         transform_train_diff = T.Compose(transform_train_diff)
155 | 
156 |     transform_eval = [T.Resize(32), T.ToTensor()]
157 |     transform_eval = T.Compose(transform_eval)
158 | 
159 |     # load train & test data
160 |     if args.data.lower() == 'mnist':
161 |         data_train = datasets.MNIST(root=os.path.join(args.data_dir, 'mnist'), transform=transform_train, train=True, download=True)
162 |         data_test = datasets.MNIST(root=os.path.join(args.data_dir, 'mnist'), transform=transform_eval, train=False)
163 |     elif args.data.lower() == 'cifar10':
164 |         data_train = datasets.CIFAR10(root=os.path.join(args.data_dir, 'cifar10'), transform=transform_train, train=True, download=True)
165 |         data_test = datasets.CIFAR10(root=os.path.join(args.data_dir, 'cifar10'), transform=transform_eval, train=False)
166 |     elif args.data.lower() == 'cifar100':
167 |         data_train = datasets.CIFAR100(root=os.path.join(args.data_dir, 'cifar100'), transform=transform_train, train=True, download=True)
168 |         data_test = datasets.CIFAR100(root=os.path.join(args.data_dir, 'cifar100'), transform=transform_eval, train=False)
169 |     elif args.data.lower() == 'tin200':
170 |         data_train = datasets.ImageFolder(os.path.join(args.data_dir, 'tiny-imagenet-200', 'train'), transform=transform_train)
171 |         data_test = datasets.ImageFolder(os.path.join(args.data_dir, 'tiny-imagenet-200', 'val'), transform=transform_eval)
172 |     elif 'pacs' in args.data.lower():
173 |         data_train = datasets.ImageFolder(os.path.join(args.data_dir, 'pacs', args.data.split('-')[1]), transform=transform_train)
174 |         data_test = None
175 |     else:
176 |         raise ValueError('unknown dataset: {}'.format(args.data))
177 | 
178 |     # load ood train data for pde-add
179 |     if args.data_diff is not None:
180 |         if args.data_diff.lower() == 'mnist':
181 |             data_train_diff = datasets.MNIST(root=os.path.join(args.data_dir, 'mnist'), transform=transform_train_diff, train=True, download=True)
182 |         elif args.data_diff.lower() == 'cifar10':
183 |             data_train_diff = datasets.CIFAR10(root=os.path.join(args.data_dir, 'cifar10'), transform=transform_train_diff, train=True, download=True)
184 |         elif args.data_diff.lower() == 'cifar100':
185 |             data_train_diff = datasets.CIFAR100(root=os.path.join(args.data_dir, 'cifar100'), transform=transform_train_diff, train=True, download=True)
186 |         elif args.data_diff.lower() == 'tin200':
187 |             data_train_diff = datasets.ImageFolder(os.path.join(args.data_dir, 'tiny-imagenet-200', 'train'), transform=transform_train_diff)
188 |         elif 'pacs' in args.data_diff.lower():
189 |             data_train_diff = datasets.ImageFolder(os.path.join(args.data_dir, 'pacs', args.data_diff.split('-')[1]), transform=transform_train_diff)
190 |         else:
191 |             raise ValueError('unknown diffusion guidance dataset: {}'.format(args.data_diff))
192 | 
193 |     if args.npc_train != 'all':
194 |         index = split_small_dataset(num=int(args.npc_train), tar=data_train.targets, num_class=len(data_train.classes))
195 |         data_train.data = data_train.data[index]
196 |         data_train.targets = np.array(data_train.targets)[index].tolist()
197
| if args.data_diff is not None: 198 | data_train_diff.data = data_train_diff.data[index] 199 | data_train_diff.targets = np.array(data_train_diff.targets)[index].tolist() 200 | 201 | sampler = FixSampler(int(len(data_train))) 202 | dataloader_train = DataLoader(dataset=data_train, sampler=sampler, batch_size=args.batch_size, shuffle = False, num_workers=args.num_workers, pin_memory=True) 203 | 204 | if args.data_diff is not None: 205 | dataloader_train_diff = DataLoader(dataset=data_train_diff, sampler=sampler, batch_size=args.batch_size, shuffle = False, num_workers=args.num_workers, pin_memory=True) 206 | else: 207 | dataloader_train_diff = None 208 | 209 | if data_test is not None: 210 | dataloader_test = DataLoader(dataset=data_test, batch_size=args.batch_size_validation, shuffle = False, num_workers=args.num_workers, pin_memory=True) 211 | else: 212 | dataloader_test = None 213 | 214 | return dataloader_train, dataloader_train_diff, dataloader_test 215 | 216 | def load_dg_dataloader(args): 217 | transform_train = [ T.Resize(32), 218 | T.RandomCrop(32, padding=4), 219 | T.RandomHorizontalFlip(), 220 | Augmentor(args.aug_train), 221 | T.ToTensor()] 222 | transform_train = T.Compose(transform_train) 223 | 224 | if args.data_diff is not None: 225 | transform_train_diff = [ T.Resize(32), 226 | T.RandomCrop(32, padding=4), 227 | T.RandomHorizontalFlip(), 228 | Augmentor(args.aug_train_diff), 229 | T.ToTensor()] 230 | transform_train_diff = T.Compose(transform_train_diff) 231 | 232 | 233 | dataset_train, dataset_train_diff = Balance(args.data_dir, args.data, args.data_diff, transform_train, transform_train_diff) 234 | sampler = FixSampler(int(len(dataset_train))) 235 | dataloader_train = DataLoader(dataset=dataset_train, sampler=sampler, batch_size=args.batch_size, shuffle = False, num_workers=args.num_workers, pin_memory=True) 236 | dataloader_train_diff = DataLoader(dataset=dataset_train_diff, sampler=sampler, batch_size=args.batch_size, shuffle = False, num_workers=args.num_workers, pin_memory=True) 237 | return dataloader_train, dataloader_train_diff, None 238 | 239 | 240 | class CIFARC(datasets.VisionDataset): 241 | def __init__(self, root :str, name=None, 242 | transform=None, target_transform=None, dnum='all', severity=0): 243 | #assert name in corruptions 244 | super(CIFARC, self).__init__( 245 | root, transform=transform, 246 | target_transform=target_transform 247 | ) 248 | 249 | data_path = os.path.join(root, name + '.npy') 250 | target_path = os.path.join(root, 'labels.npy') 251 | 252 | self.data = np.load(data_path) 253 | self.targets = np.load(target_path) 254 | 255 | if severity: 256 | self.data=self.data[10000*(severity-1):10000*severity, :] 257 | self.targets=self.targets[10000*(severity-1):10000*severity] 258 | 259 | def __getitem__(self, index): 260 | img, targets = self.data[index], self.targets[index] 261 | img = Image.fromarray(img) 262 | 263 | if self.transform is not None: 264 | img = self.transform(img) 265 | if self.target_transform is not None: 266 | targets = self.target_transform(targets) 267 | return img, targets 268 | 269 | def __len__(self): 270 | return len(self.data) 271 | 272 | def load_corr_dataloader(data_name, data_dir, batch_size, cname=None, severity=None, num_workers=2): 273 | transform = T.Compose([T.Resize(32), T.ToTensor()]) 274 | if data_name in ['cifar10','cifar100']: 275 | filename = '-'.join(['CIFAR', data_name[5:], 'C']) 276 | data = CIFARC(os.path.join(data_dir, filename), cname, transform=transform, severity=severity) 277 | dataloader = 
DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
278 |     elif data_name == 'tin200':
279 |         severity = str(severity)
280 |         data = datasets.ImageFolder(os.path.join(data_dir, 'Tiny-ImageNet-C', cname, severity), transform=transform)
281 |         dataloader = DataLoader(dataset=data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
282 |     elif 'pacs' in data_name:
283 |         data = datasets.ImageFolder(os.path.join(data_dir, 'pacs', cname), transform=transform)
284 |         dataloader = DataLoader(dataset=data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
285 |     else:
286 |         raise ValueError('unknown corruption dataset: {}'.format(data_name))
287 |     return dataloader
288 | 
--------------------------------------------------------------------------------
/core/models/pdeadd_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Diffusion(nn.Module):
7 |     def __init__(self, planes):
8 |         super(Diffusion, self).__init__()
9 |         self.main = nn.Sequential(
10 |             nn.Conv2d(planes, 2*planes, kernel_size=3, stride=1, padding=1, bias=True),
11 |             nn.BatchNorm2d(2*planes),
12 |             nn.ReLU(inplace=True),
13 |             nn.Conv2d(2*planes, planes, kernel_size=3, stride=1, padding=1, bias=True),
14 |             nn.BatchNorm2d(planes),
15 |             nn.ReLU(inplace=True))
16 | 
17 |     def forward(self, input):
18 |         out = self.main(input)
19 |         return out
20 | 
21 | 
22 | class BasicBlock(nn.Module):
23 |     """
24 |     Implements a basic block module for ResNets.
25 |     Arguments:
26 |         in_planes (int): number of input planes.
27 |         planes (int): number of output filters.
28 |         stride (int): stride of convolution.
29 |     """
30 |     expansion = 1
31 | 
32 |     def __init__(self, in_planes, planes, stride=1):
33 |         super(BasicBlock, self).__init__()
34 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
35 |         self.bn1 = nn.BatchNorm2d(planes)
36 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
37 |         self.bn2 = nn.BatchNorm2d(planes)
38 | 
39 |         self.shortcut = nn.Sequential()
40 |         if stride != 1 or in_planes != self.expansion * planes:
41 |             self.shortcut = nn.Sequential(
42 |                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
43 |                 nn.BatchNorm2d(self.expansion * planes))
44 | 
45 | 
46 |     def forward(self, x):
47 |         out = F.relu(self.bn1(self.conv1(x)))
48 |         out = self.bn2(self.conv2(out))
49 |         out += self.shortcut(x)
50 |         out = F.relu(out)
51 |         return out
52 | 
53 | 
54 | 
55 | 
56 | 
57 | class ResNet(nn.Module):
58 |     """
59 |     ResNet model
60 |     Arguments:
61 |         block (BasicBlock or Bottleneck): type of basic block to be used.
62 |         num_blocks (list): number of blocks in each sub-module.
63 |         num_classes (int): number of output classes.
64 |         device (torch.device or str): device to work on.
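    Example (sketch): with 32x32 inputs, each layer's Diffusion head predicts a noise
    scale `sigma` that perturbs that layer's output when `use_diffusion=True`:
        >>> # net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
        >>> # logits = net(torch.randn(4, 3, 32, 32), use_diffusion=True)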
65 | """ 66 | def __init__(self, block, num_blocks, num_classes=10, device='cpu'): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | 77 | self.diff1 = Diffusion(64) 78 | self.diff2 = Diffusion(128) 79 | self.diff3 = Diffusion(256) 80 | self.diff4 = Diffusion(512) 81 | 82 | self.linear = nn.Linear(512 * block.expansion, num_classes) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1] * (num_blocks - 1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def net(self, x, use_diffusion=True): 93 | 94 | self.mus = [] 95 | self.sigmas = [] 96 | self.scales = [] 97 | 98 | out = F.relu(self.bn1(self.conv1(x))) 99 | 100 | out = self.layer1(out) 101 | if use_diffusion: 102 | sigma = self.diff1(out) 103 | else: 104 | sigma = torch.zeros_like(out) 105 | out_diff = out + sigma * torch.randn_like(out) 106 | self.mus.append(out) 107 | self.sigmas.append(sigma) 108 | self.scales.append(sigma.mean().detach().data.item()) 109 | 110 | out = self.layer2(out_diff) 111 | if use_diffusion: 112 | sigma = self.diff2(out) 113 | else: 114 | sigma = torch.zeros_like(out) 115 | out_diff = out + sigma * torch.randn_like(out) 116 | self.mus.append(out) 117 | self.sigmas.append(sigma) 118 | self.scales.append(sigma.mean().detach().data.item()) 119 | 120 | out = self.layer3(out_diff) 121 | if use_diffusion: 122 | sigma = self.diff3(out) 123 | else: 124 | sigma = torch.zeros_like(out) 125 | out_diff = out + sigma * torch.randn_like(out) 126 | self.mus.append(out) 127 | self.sigmas.append(sigma) 128 | self.scales.append(sigma.mean().detach().data.item()) 129 | 130 | out = self.layer4(out_diff) 131 | if use_diffusion: 132 | sigma = self.diff4(out) 133 | else: 134 | sigma = torch.zeros_like(out) 135 | out_diff = out + sigma * torch.randn_like(out) 136 | self.mus.append(out) 137 | self.sigmas.append(sigma) 138 | self.scales.append(sigma.mean().detach().data.item()) 139 | 140 | out = F.avg_pool2d(out_diff, 4) 141 | out = out.view(out.size(0), -1) 142 | out = self.linear(out) 143 | 144 | return out 145 | 146 | def forward(self, x, use_diffusion=True): 147 | #if self.training: 148 | out = self.net(x, use_diffusion=use_diffusion) 149 | # else: 150 | # if use_diffusion: 151 | # proba = 0 152 | # for _ in range(10): 153 | # out = self.net(x, use_diffusion=True) 154 | # proba = proba + out 155 | # out = proba/10 156 | # else: 157 | # out = self.net(x, use_diffusion=False) 158 | return out 159 | 160 | 161 | 162 | def pdeadd_resnet(name, num_classes=10, pretrained=False, device='cpu'): 163 | """ 164 | Returns suitable Resnet model from its name. 165 | Arguments: 166 | name (str): name of resnet architecture. 167 | num_classes (int): number of target classes. 168 | pretrained (bool): whether to use a pretrained model. 169 | device (str or torch.device): device to work on. 170 | Returns: 171 | torch.nn.Module. 
172 | """ 173 | if name == 'resnet-18': 174 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, device=device) 175 | elif name == 'resnet-34': 176 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, device=device) 177 | # elif name == 'resnet-50': 178 | # return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, device=device) 179 | # elif name == 'resnet-101': 180 | # return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, device=device) 181 | 182 | raise ValueError('Only resnet18, resnet34, resnet50 and resnet101 are supported!') 183 | 184 | 185 | # import torch 186 | # import torch.nn as nn 187 | # import torch.nn.functional as F 188 | 189 | 190 | # class Diffusion(nn.Module): 191 | # def __init__(self, planes): 192 | # super(Diffusion, self).__init__() 193 | # self.main = nn.Sequential( 194 | # nn.Conv2d(planes, 2*planes, kernel_size=3, stride=1, padding=1, bias=True), 195 | # nn.BatchNorm2d(2*planes), 196 | # nn.ReLU(inplace=True), 197 | # nn.Conv2d(2*planes, planes, kernel_size=3, stride=1, padding=1, bias=True), 198 | # nn.BatchNorm2d(planes), 199 | # nn.ReLU(inplace=True)) 200 | 201 | # def forward(self, input): 202 | # out = self.main(input) 203 | # return out 204 | 205 | 206 | # class BasicBlock(nn.Module): 207 | # """ 208 | # Implements a basic block module for Resnets. 209 | # Arguments: 210 | # in_planes (int): number of input planes. 211 | # out_planes (int): number of output filters. 212 | # stride (int): stride of convolution. 213 | # """ 214 | # expansion = 1 215 | 216 | # def __init__(self, in_planes, planes, stride=1): 217 | # super(BasicBlock, self).__init__() 218 | # self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 219 | # self.bn1 = nn.BatchNorm2d(planes) 220 | # self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 221 | # self.bn2 = nn.BatchNorm2d(planes) 222 | 223 | # self.shortcut = nn.Sequential() 224 | # if stride != 1 or in_planes != self.expansion * planes: 225 | # self.shortcut = nn.Sequential( 226 | # nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 227 | # nn.BatchNorm2d(self.expansion * planes)) 228 | 229 | # self.diff = Diffusion(planes) 230 | 231 | 232 | # def forward(self, x): 233 | # x, use_diffusion, _ = x 234 | 235 | # out = F.relu(self.bn1(self.conv1(x))) 236 | # out = self.bn2(self.conv2(out)) 237 | # out += self.shortcut(x) 238 | # out = F.relu(out) 239 | 240 | # if use_diffusion: 241 | # sigma = self.diff(out) 242 | # out = out + sigma * torch.randn_like(out) 243 | # return (out, use_diffusion, sigma) 244 | # else: 245 | # return (out, use_diffusion, 0) 246 | 247 | 248 | 249 | 250 | 251 | # class ResNet(nn.Module): 252 | # """ 253 | # ResNet model 254 | # Arguments: 255 | # block (BasicBlock or Bottleneck): type of basic block to be used. 256 | # num_blocks (list): number of blocks in each sub-module. 257 | # num_classes (int): number of output classes. 258 | # device (torch.device or str): device to work on. 
259 | # """ 260 | # def __init__(self, block, num_blocks, num_classes=10, device='cpu'): 261 | # super(ResNet, self).__init__() 262 | # self.in_planes = 64 263 | 264 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 265 | # self.bn1 = nn.BatchNorm2d(64) 266 | # self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 267 | # self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 268 | # self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 269 | # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 270 | # self.linear = nn.Linear(512 * block.expansion, num_classes) 271 | 272 | # def _make_layer(self, block, planes, num_blocks, stride): 273 | # strides = [stride] + [1] * (num_blocks - 1) 274 | # layers = [] 275 | # for stride in strides: 276 | # layers.append(block(self.in_planes, planes, stride)) 277 | # self.in_planes = planes * block.expansion 278 | # return nn.Sequential(*layers) 279 | 280 | # def net(self, x, use_diffusion=True): 281 | 282 | # self.mus = [] 283 | # self.sigmas = [] 284 | # self.scales = [] 285 | 286 | # out = F.relu(self.bn1(self.conv1(x))) 287 | # out = self.layer1((out, use_diffusion, 0)) 288 | # self.mus.append(out[0]) 289 | # self.sigmas.append(out[2]) 290 | # if use_diffusion: 291 | # self.scales.append(out[2].max().detach().data.item()) 292 | # else: 293 | # self.scales.append(0) 294 | 295 | # out = self.layer2(out) 296 | # self.mus.append(out[0]) 297 | # self.sigmas.append(out[2]) 298 | # if use_diffusion: 299 | # self.scales.append(out[2].max().detach().data.item()) 300 | # else: 301 | # self.scales.append(0) 302 | 303 | # out = self.layer3(out) 304 | # self.mus.append(out[0]) 305 | # self.sigmas.append(out[2]) 306 | # if use_diffusion: 307 | # self.scales.append(out[2].max().detach().data.item()) 308 | # else: 309 | # self.scales.append(0) 310 | 311 | # out = self.layer4(out) 312 | # self.mus.append(out[0]) 313 | # self.sigmas.append(out[2]) 314 | # if use_diffusion: 315 | # self.scales.append(out[2].max().detach().data.item()) 316 | # else: 317 | # self.scales.append(0) 318 | 319 | # out = out[0] 320 | # out = F.avg_pool2d(out, 4) 321 | # out = out.view(out.size(0), -1) 322 | # out = self.linear(out) 323 | 324 | # return out 325 | 326 | # def forward(self, x, use_diffusion=True): 327 | # if self.training: 328 | # #print('training........') 329 | # out = self.net(x, use_diffusion=use_diffusion) 330 | # out = F.log_softmax(out, dim=1) 331 | # else: 332 | # #print('evaling........') 333 | # if use_diffusion: 334 | # proba = 0 335 | # for k in range(10): 336 | # out = self.net(x, use_diffusion=True) 337 | # p = F.softmax(out, dim=1) 338 | # proba = proba + p 339 | # out = ((proba/10)+1e-20).log() # next nll 340 | # else: 341 | # out = self.net(x, use_diffusion=False) 342 | # out = F.log_softmax(out, dim=1) 343 | # return out 344 | 345 | # def pdeadd_resnet(name, num_classes=10, pretrained=False, device='cpu'): 346 | # """ 347 | # Returns suitable Resnet model from its name. 348 | # Arguments: 349 | # name (str): name of resnet architecture. 350 | # num_classes (int): number of target classes. 351 | # pretrained (bool): whether to use a pretrained model. 352 | # device (str or torch.device): device to work on. 353 | # Returns: 354 | # torch.nn.Module. 
355 | # """ 356 | # if name == 'resnet-18': 357 | # return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, device=device) 358 | # elif name == 'resnet-34': 359 | # return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, device=device) 360 | 361 | # raise ValueError('Only resnet18, resnet34, resnet50 and resnet101 are supported!') 362 | 363 | --------------------------------------------------------------------------------