├── AA_eval.py ├── README.md ├── autopgd_train_clean.py ├── dataset_convnext_like.py ├── fgsm_train.py ├── main.py ├── models ├── __init__.py ├── convnext.py ├── convnext_iso.py └── utils.py ├── parserr.py ├── rb_architecture_util.py ├── readme_teaser.png ├── run_train.sh ├── runner_aa_eval.py ├── utils.py ├── utils_architecture.py ├── utils_eval.py └── utils_train.py /AA_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import torch.nn as nn 4 | 5 | from timm.models import create_model 6 | from torchvision import datasets, transforms 7 | import math 8 | import argparse 9 | import os 10 | import sys 11 | sys.path.insert(0,'..') 12 | 13 | import json 14 | import robustbench 15 | import numpy as np 16 | 17 | from autoattack import AutoAttack 18 | from robustbench.utils import clean_accuracy 19 | 20 | from main import BlurPoolConv2d, PREC_DICT, IMAGENET_MEAN, \ 21 | IMAGENET_STD 22 | from utils_architecture import normalize_model, get_new_model, interpolate_pos_encoding 23 | from ptflops import get_model_complexity_info 24 | from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str 25 | 26 | def sizeof_fmt(num, suffix="Flops"): 27 | for unit in ["", "Ki", "Mi", "G", "T"]: 28 | if abs(num) < 1000.0: 29 | return f"{num:3.3f}{unit}{suffix}" 30 | num /= 1000.0 31 | return f"{num:.1f}Yi{suffix}" 32 | 33 | eps_dict = {'imagenet': {'Linf': 4. / 255., 'L2': 2., 'L1': 75.}} 34 | 35 | 36 | class Logger(): 37 | def __init__(self, log_path): 38 | self.log_path = log_path 39 | 40 | def log(self, str_to_log, verbose=False): 41 | print(str_to_log) 42 | if not self.log_path is None: 43 | with open(self.log_path, 'a') as f: 44 | f.write(str_to_log) 45 | f.write('\n') 46 | if verbose: 47 | f.flush() 48 | 49 | def format(value): 50 | return "%.3f" % value 51 | 52 | 53 | def makedir(path): 54 | if not os.path.exists(path): 55 | os.makedirs(path) 56 | 57 | 58 | def get_args_parser(): 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--batch_size', default=200, type=int) 61 | parser.add_argument('--model', default='convnext_tiny_convmlp_nolayerscale', type=str) 62 | parser.add_argument('--n_ex', type=int, default=5000) 63 | parser.add_argument('--norm', type=str) 64 | parser.add_argument('--eps', type=float) 65 | parser.add_argument('--data_dir', type=str, default='/scratch/fcroce42/ffcv_imagenet_data') 66 | parser.add_argument('--only_clean', action='store_true') 67 | parser.add_argument('--save_imgs', action='store_true') 68 | parser.add_argument('--precision', type=str, default='fp32') 69 | parser.add_argument('--ckpt_path', type=str, default='/scratch/nsingh/ImageNet_Arch/model_2022-12-04 14:36:56_convnext_iso_iso_0_not_orig_0_pre_1_aug_0_adv__50_at_crop_flip/weights_18.pt') 70 | parser.add_argument('--mod', type=str) 71 | parser.add_argument('--model_in', nargs='+') 72 | parser.add_argument('--full_aa', type=int, default=0) 73 | parser.add_argument('--init', type=str) 74 | parser.add_argument('--add_normalization', action='store_true', default=False) 75 | parser.add_argument('--l_norms', type=str) 76 | parser.add_argument('--l_epss', type=str) 77 | parser.add_argument('--get_stats', action='store_true') 78 | parser.add_argument('--use_fixed_val_set', action='store_true', default=False) 79 | parser.add_argument('--img_size', type=int, help='resolution to test the evaluataion for, default: 224', default=224) 80 | parser.add_argument('--not_channel_last', action='store_false') 81 | 
parser.add_argument('--not-original', type=int, default=1) 82 | parser.add_argument('--updated', action='store_true', help='Patched models?', default=False) 83 | 84 | args = parser.parse_args() 85 | return args 86 | 87 | def main(): 88 | args = get_args_parser() 89 | 90 | mods = [args.mod] 91 | nots = [bool(args.not_original)] 92 | args.model_in = ' '.join(args.model_in) 93 | ll = [args.model_in] 94 | args.ckpt_path = args.model_in 95 | data_path = 'patch_to_imagenet_validataion_set' 96 | 97 | device = 'cuda' 98 | 99 | assert len(mods) == len(nots) == len(ll) 100 | 101 | print('using fixed val set') 102 | 103 | 104 | crop_pct = 0.875 105 | 106 | img_size = args.img_size 107 | 108 | scale_size = int(math.floor(img_size / crop_pct)) 109 | trans = transforms.Compose([ 110 | transforms.Resize( 111 | scale_size, 112 | interpolation=transforms.InterpolationMode("bicubic")), 113 | transforms.CenterCrop(img_size), 114 | transforms.ToTensor() 115 | ]) 116 | x_test_val, y_test_val = robustbench.data.load_imagenet(5000, data_dir=data_path, 117 | transforms_test = trans) 118 | 119 | 120 | 121 | print(f"{args.mod} has resolution : {img_size}") 122 | 123 | for idx, modd in enumerate(ll): 124 | 125 | args.ckpt_path += "/weights_20.pt" 126 | args.model = mods[idx] 127 | args.not_original = nots[idx] 128 | 129 | if not args.ckpt_path is None: 130 | # assert os.path.exists(args.ckpt_path), f'{args.ckpt_path} not found' 131 | args.savedir = '/'.join(args.ckpt_path.split('/')[:-1]) 132 | print(args.savedir) 133 | # ep = args.ckpt_path.split('/')[-1].split('.pt')[0] 134 | with open(f'{args.savedir}/params.json', 'r') as f: 135 | params = json.load(f) 136 | args.use_blurpool = params['training.use_blurpool'] == 1 137 | if 'model.add_normalization' in params.keys(): 138 | args.add_normalization = params['model.add_normalization'] == 1 139 | args.model = args.model #params['model.arch'] 140 | else: 141 | args.savedir = './results/' 142 | makedir(args.savedir) 143 | 144 | args.n_cls = 1000 145 | args.num_workers = 1 146 | 147 | if not args.eps is None and args.eps > 1 and args.norm == 'Linf': 148 | args.eps /= 255. 149 | 150 | 151 | device = 'cuda' 152 | arch = args.model 153 | pretrained = False 154 | add_normalization = args.add_normalization 155 | 156 | log_path = f'{args.savedir}/evaluated_logs_{args.l_norms}_{args.full_aa}_8_255.txt' 157 | logger = Logger(log_path) 158 | 159 | print(f"Creating model: {args.model}") 160 | if not arch.startswith('timm_'): 161 | model = get_new_model(arch, pretrained=False, not_original=args.not_original, updated=args.updated) 162 | else: 163 | try: 164 | model = create_model(arch.replace('timm_', ''), pretrained=pretrained) 165 | except: 166 | model = get_new_model(arch.replace('timm_', '')) 167 | 168 | if add_normalization: 169 | print('add normalization layer') 170 | model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD) 171 | 172 | inpp = torch.rand(1, 3, 224, 224) 173 | flops = FlopCountAnalysis(model, inpp) 174 | val = flops.total() 175 | print(sizeof_fmt(int(val))) 176 | print(flop_count_table(flops, max_depth=2)) 177 | print(flops.by_operator()) 178 | 179 | accs = [] 180 | best_test_rob = 0. 
181 | 182 | for i in rann: 183 | 184 | ckpt = torch.load(args.savedir + f"/model_file_name.pt", map_location='cpu') #['model'] 185 | # print(ckpt.keys()) 186 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} 187 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()} 188 | ckpt = {k.replace('se_', 'se_module.'): v for k, v in ckpt.items()} 189 | model.load_state_dict(ckpt) 190 | model = model.to(device) 191 | model.eval() 192 | acc = clean_accuracy(model, x_test_val, y_test_val, batch_size=args.batch_size, 193 | device=device) 194 | print(f"clean {i} : {acc}") 195 | img_sizer = [img_size, img_size] 196 | if not args.ckpt_path is None: 197 | ## change the size of positional embedding in ViT if img_size>224 (training resolution). 198 | if "vit" in arch: 199 | old_shape = ckpt['pos_embed'].shape 200 | ckpt['pos_embed'] = interpolate_pos_encoding( 201 | ckpt['pos_embed'], new_img_size=img_sizer[0], 202 | patch_size=model.patch_embed.patch_size[0]) 203 | new_shape = ckpt['pos_embed'].shape 204 | print(old_shape, new_shape) 205 | model.pos_embed = nn.Parameter(torch.zeros(new_shape, device=model.pos_embed.device)) 206 | 207 | model.patch_embed.img_size = img_sizer 208 | model.patch_embed.num_patches = new_shape[1] - 1 209 | model.patch_embed.grid_size = ( 210 | img_sizer[0] // model.patch_embed.patch_size[0], 211 | img_sizer[1] // model.patch_embed.patch_size[1]) 212 | model.eval() 213 | 214 | str_to_log = '' 215 | 216 | logger = Logger(log_path) 217 | logger.log(str_to_log) 218 | 219 | all_norms = [args.l_norms] # 220 | #all_norms = ['L2', 'L1', 'Linf'] 221 | l_epss = [eps_dict['imagenet'][c] for c in all_norms] 222 | logger.log(all_norms, l_epss) 223 | all_acs = [] 224 | for idx, nrm in enumerate(all_norms): 225 | epss = l_epss[idx] 226 | adversary = AutoAttack(model, norm=nrm, eps=epss, 227 | version='standard', log_path=log_path) 228 | str_to_log = '' 229 | 230 | if not bool(args.full_aa): 231 | adversary.attacks_to_run = ['apgd-ce', 'apgd-t'] 232 | 233 | str_to_log += f'norm={nrm} eps={l_epss[idx]:.5f}\n' 234 | 235 | assert not model.training 236 | 237 | with torch.no_grad(): 238 | x_adv = adversary.run_standard_evaluation(x_test_val, 239 | y_test_val, bs=args.batch_size) 240 | 241 | acc = clean_accuracy(model, x_adv, y_test_val, batch_size=args.batch_size, 242 | device=device) 243 | print('robust accuracy: {:.2%}'.format(acc)) 244 | str_to_log += 'robust accuracy: {:.2%}\n'.format(acc) 245 | logger.log(str_to_log) 246 | all_acs.append(acc) 247 | 248 | if args.save_imgs: 249 | valset = '_oldset' if args.use_fixed_val_set else '' 250 | runname = f'aa_short_1_{args.n_ex}_{args.norm}_{args.eps:.5f}{valset}.pth' 251 | savepath = f'{args.savedir}/{runname}' 252 | torch.save(x_adv.cpu(), savepath) 253 | 254 | if __name__ == '__main__': 255 | main() 256 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models 2 | #### Naman D Singh, Francesco Croce, Matthias Hein 3 | #### University of Tübingen 4 | #### NeurIPS 2023 5 | 6 | ### [Paper](https://arxiv.org/abs/2303.01870) 7 | ## Abstract 8 | While adversarial training has been extensively studied for ResNet architectures and low resolution datasets like CIFAR, much less is known for ImageNet. 
Given the recent debate about whether transformers are more robust than convnets, we revisit adversarial training on ImageNet comparing ViTs and ConvNeXts. Extensive experiments show that minor changes in architecture, most notably replacing PatchStem with ConvStem, and training scheme have a significant impact on the achieved robustness. These changes not only increase robustness in the seen $\ell_\infty$-threat model, but even more so improve generalization to unseen $\ell_1/\ell_2$-robustness. 9 | 10 | ![readme_teaser](readme_teaser.png) 11 | 12 | 13 | ## Code 14 | Requirements (specific versions tested on):
15 | `fastargs-1.2.0`, `autoattack-0.1`, `pytorch-1.13.1`, `torchvision-0.14.1`, `robustbench-1.1`, `timm-0.8.0.dev0`, `GPUtil` 16 | 17 | #### Training 18 | The bash script `run_train.sh` trains the model specified by `model.arch`. For clean training set `adv.attack none`; for adversarial training set `adv.attack apgd`.
19 | For the standard setting as in the paper (heavy augmentations), set `data.augmentations 1`, `model.model_ema 1` and `training.label_smoothing 1`.
20 | To train models with a Convolution-Stem (CvSt), set `model.not_original 1`.
21 | The code does standard APGD adversarial training.
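For reference, here is a minimal sketch of one such APGD adversarial-training step, built from `apgd_train` in `autopgd_train_clean.py` (this is not the exact training loop of `main.py`; the helper name, `criterion` and `optimizer` are placeholders, and data loading/EMA are omitted):

```python
# Minimal sketch of one APGD adversarial-training step, assuming the repo's own
# apgd_train helper; adv_train_step, criterion and optimizer are placeholders.
import torch
from autopgd_train_clean import apgd_train

def adv_train_step(model, x, y, optimizer, criterion, eps=4./255., n_iter=2):
    model.eval()                      # apgd_train expects the model in eval mode
    x_best, _, _, _ = apgd_train(model, x, y, norm='Linf', eps=eps,
                                 n_iter=n_iter, loss='ce', is_train=True)
    model.train()                     # train on the highest-loss perturbed points
    loss = criterion(model(x_best), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```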
The file `utils_architecture.py` has model definitions for the new `CvSt` models; all models are built on top of `timm` imports. 22 | 23 | #### Evaluating a model 24 | The file `runner_aa_eval.py` runs `AutoAttack` (AA). Passing `fullaa 1` runs the complete AA suite, whereas `fullaa 0` runs only the first two attacks (APGD-CE and APGD-T) in AA.
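For reference, a minimal sketch of what the short evaluation (`fullaa 0`) does internally, mirroring `AA_eval.py`; model construction and data loading are omitted and the helper name is a placeholder:

```python
# Minimal sketch of the short AA evaluation as done in AA_eval.py (fullaa 0):
# only APGD-CE and APGD-T are run; loading the model and the 5000 validation
# images is omitted here, and short_aa_eval is a hypothetical helper name.
import torch
from autoattack import AutoAttack
from robustbench.utils import clean_accuracy

def short_aa_eval(model, x_test, y_test, eps=4./255., batch_size=200):
    model.eval()
    adversary = AutoAttack(model, norm='Linf', eps=eps, version='standard')
    adversary.attacks_to_run = ['apgd-ce', 'apgd-t']   # fullaa 0: first two attacks only
    with torch.no_grad():
        x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=batch_size)
    # robust accuracy on the returned adversarial points
    return clean_accuracy(model, x_adv, y_test, batch_size=batch_size)
```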
25 | 26 | 27 | #### Checkpoints - ImageNet $\ell_{\infty} = 4/255$ robust models. 28 | Each linked location includes weights for the clean model (the one used as initialization for Adversarial Training (AT)), the robust model, and the `full-AA` log for $\ell_{\infty}$, $\ell_2$ and $\ell_1$ attacks.
29 | Note: the higher-resolution numbers use the same checkpoint as for the standard resolution of 224; only the evaluation is done at the higher resolution mentioned.
30 | | Model-Name | epochs | res. | Clean acc. | AA - $\ell_{\infty}$ acc.| Checkpoint (clean-init
and robust) | 31 | | :--- | :------: | :------: | :------: |:------: | :------: | 32 | | ConvNext-iso-CvSt | 300 | 224 | 70.2 | 45.9 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/HpNbkLTNTBiaeo8)| 33 | | ViT-S | 300 | 224 | 69.2 | 44.0 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)| 34 | | ViT-S-CvSt | 300 | 224 | 72.5 | 48.1 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/agtDw3D7QXbDCmw)| 35 | | ConvNext-T | 300 | 224 | 72.4 | 48.6 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)| 36 | | ConvNext-T-CvSt | 300 | 224 | 72.7 | 49.5 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/BFLoMrMdn8iBk7Y)| 37 | | ViT-M-CvSt | 50 | 224 | 72.4 | 48.8 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/q2mkEYtq5Zjpa4e)| 38 | | ConvNext-S-CvSt | 50 | 224 | 74.1 | 52.4 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/m3bAwNg4CJY4jrp)| 39 | | ViT-B | 50 | 224 | 73.3 | 50.0 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)| 40 | | ConvNext-B | 50 | 224 | 75.6 | 54.3 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)| 41 | | ViT-B-CvSt | 250 | 224 | 76.3 | 54.7 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/SbN5AJAicdZJXyr)| 42 | | ConvNext-B-CvSt | 250 | 224 | 75.9 | 56.1 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/RQBEXagC7R7XweX)| 43 | | ConvNext-B-CvSt* | --- | 256 | 76.9 | 57.3 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/RQBEXagC7R7XweX)| 44 | | ConvNext-L-CvSt | 100 | 224 | 77.0 | 57.7 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/YzBpeHRrRQzHBDz)| 45 | | ConvNext-L-CvSt* | --- | 320 | 78.2 | 59.4 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/YzBpeHRrRQzHBDz)| 46 | ###### *: increased resolution (only for evaluation) also leads to increased FLOPs. 47 | ------------------- 48 | Checkpoints along with accuracy and robustness logs for ImageNet models finetuned to be robust at $\ell_\infty = 8/255$ are available here: [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/FiTToeo4RKY896P) 49 | ________________________________ 50 |
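A minimal sketch of loading one of the released checkpoints, following the loading code in `AA_eval.py` (the architecture name and checkpoint path are placeholders; `normalize_model` is only needed if the model was trained with `model.add_normalization 1`):

```python
# Minimal sketch of loading a released checkpoint, mirroring AA_eval.py;
# 'convnext_tiny' and the path below are placeholders, not exact values.
import torch
from utils_architecture import get_new_model, normalize_model
from main import IMAGENET_MEAN, IMAGENET_STD

model = get_new_model('convnext_tiny', pretrained=False, not_original=True, updated=False)
model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD)   # only if trained with normalization

ckpt = torch.load('path/to/weights_20.pt', map_location='cpu')
ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}       # strip DDP prefix
ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()}   # strip WrappedModel prefix
model.load_state_dict(ckpt)
model.eval()
```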

## Citation

51 | 52 | If you use our code/models cite our work using the following BibTex entry: 53 | ```bibtex 54 | @inproceedings{singh2023revisiting, 55 | title={Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models}, 56 | author={Singh, Naman D and Croce, Francesco and Hein, Matthias}, 57 | booktitle={NeurIPS}, 58 | year={2023}} 59 | ``` 60 | -------------------------------------------------------------------------------- /autopgd_train_clean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import random 6 | import time 7 | 8 | def L1_norm(x, keepdim=False): 9 | z = x.abs().contiguous().view(x.shape[0], -1).sum(-1) 10 | if keepdim: 11 | z = z.contiguous().view(-1, *[1]*(len(x.shape) - 1)) 12 | return z 13 | 14 | def L2_norm(x, keepdim=False): 15 | z = (x ** 2).contiguous().view(x.shape[0], -1).sum(-1).sqrt() 16 | if keepdim: 17 | z = z.contiguous().view(-1, *[1]*(len(x.shape) - 1)) 18 | return z 19 | 20 | def L0_norm(x): 21 | return (x != 0.).view(x.shape[0], -1).sum(-1) 22 | 23 | 24 | def L1_projection(x2, y2, eps1): 25 | ''' 26 | x2: center of the L1 ball (bs x input_dim) 27 | y2: current perturbation (x2 + y2 is the point to be projected) 28 | eps1: radius of the L1 ball 29 | 30 | output: delta s.th. ||y2 + delta||_1 = eps1 31 | and 0 <= x2 + y2 + delta <= 1 32 | ''' 33 | 34 | x = x2.clone().float().view(x2.shape[0], -1) 35 | y = y2.clone().float().view(y2.shape[0], -1) 36 | sigma = y.clone().sign() 37 | u = torch.min(1 - x - y, x + y) 38 | #u = torch.min(u, epsinf - torch.clone(y).abs()) 39 | u = torch.min(torch.zeros_like(y), u) 40 | l = -torch.clone(y).abs() 41 | d = u.clone() 42 | 43 | bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1) 44 | bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1) 45 | 46 | inu = 2*(indbs < u.shape[1]).float() - 1 47 | size1 = inu.cumsum(dim=1) 48 | 49 | s1 = -u.sum(dim=1) 50 | 51 | c = eps1 - y.clone().abs().sum(dim=1) 52 | c5 = s1 + c < 0 53 | c2 = c5.nonzero().squeeze(1) 54 | 55 | s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1) 56 | #print(s[0]) 57 | 58 | #print(c5.shape, c2) 59 | 60 | if c2.nelement != 0: 61 | 62 | lb = torch.zeros_like(c2).float() 63 | ub = torch.ones_like(lb) *(bs.shape[1] - 1) 64 | 65 | #print(c2.shape, lb.shape) 66 | 67 | nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float())) 68 | counter2 = torch.zeros_like(lb).long() 69 | counter = 0 70 | 71 | while counter < nitermax: 72 | counter4 = torch.floor((lb + ub) / 2.) 
73 | counter2 = counter4.type(torch.LongTensor) 74 | 75 | c8 = s[c2, counter2] + c[c2] < 0 76 | ind3 = c8.nonzero().squeeze(1) 77 | ind32 = (~c8).nonzero().squeeze(1) 78 | #print(ind3.shape) 79 | if ind3.nelement != 0: 80 | lb[ind3] = counter4[ind3] 81 | if ind32.nelement != 0: 82 | ub[ind32] = counter4[ind32] 83 | 84 | #print(lb, ub) 85 | counter += 1 86 | 87 | lb2 = lb.long() 88 | alpha = (-s[c2, lb2] -c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2] 89 | d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2]) 90 | 91 | return (sigma * d).view(x2.shape) 92 | 93 | 94 | def softloss(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 95 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 96 | return loss.mean() 97 | 98 | 99 | def dlr_loss(x, y, reduction='none'): 100 | x_sorted, ind_sorted = x.sort(dim=1) 101 | ind = (ind_sorted[:, -1] == y).float() 102 | 103 | return -(x[torch.arange(x.shape[0]), y] - x_sorted[:, -2] * ind - \ 104 | x_sorted[:, -1] * (1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12) 105 | 106 | def dlr_loss_targeted(x, y, y_target): 107 | x_sorted, ind_sorted = x.sort(dim=1) 108 | u = torch.arange(x.shape[0]) 109 | 110 | return -(x[u, y] - x[u, y_target]) / (x_sorted[:, -1] - .5 * ( 111 | x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12) 112 | 113 | criterion_dict = {'ce': lambda x, y: F.cross_entropy(x, y, reduction='none'),'softloss':softloss, 114 | 'dlr': dlr_loss, 'dlr-targeted': dlr_loss_targeted} 115 | 116 | def check_oscillation(x, j, k, y5, k3=0.75): 117 | t = torch.zeros(x.shape[1], device=x.device, dtype=x.dtype) 118 | for counter5 in range(k): 119 | t += (x[j - counter5] > x[j - counter5 - 1]).float() 120 | 121 | return (t <= k * k3 * torch.ones_like(t)).float() 122 | 123 | def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss='ce', 124 | verbose=False, mixup=None, is_train=True): 125 | assert not model.training 126 | device = x.device 127 | ndims = len(x.shape) - 1 128 | 129 | times = {'fp': 0, 'bp': 0, 'total': time.time(), 'blk1': 0, 'blk2': 0, 130 | 'blk3': 0} 131 | #print(x.is_contiguous()) 132 | 133 | #startt = time.time() 134 | if not use_rs: 135 | x_adv = x.clone() 136 | else: 137 | raise NotImplemented 138 | if norm == 'Linf': 139 | t = torch.rand_like(x) 140 | 141 | x_adv = x_adv.clamp(0., 1.) 142 | x_best = x_adv.clone().detach() 143 | x_best_adv = x_adv.clone().detach() 144 | loss_steps = torch.zeros([n_iter, x.shape[0]], device=device) 145 | loss_best_steps = torch.zeros([n_iter + 1, x.shape[0]], device=device) 146 | acc_steps = torch.zeros_like(loss_best_steps) 147 | 148 | # set loss 149 | criterion_indiv = criterion_dict[loss] 150 | 151 | # set params 152 | n_fts = math.prod(x.shape[1:]) 153 | if norm in ['Linf', 'L2']: 154 | n_iter_2 = max(int(0.22 * n_iter), 1) 155 | n_iter_min = max(int(0.06 * n_iter), 1) 156 | size_decr = max(int(0.03 * n_iter), 1) 157 | k = n_iter_2 + 0 158 | thr_decr = .75 159 | alpha = 2. 160 | elif norm in ['L1']: 161 | k = max(int(.04 * n_iter), 1) 162 | init_topk = .05 if is_train else .2 163 | topk = init_topk * torch.ones([x.shape[0]], device=device) 164 | sp_old = n_fts * torch.ones_like(topk) 165 | adasp_redstep = 1.5 166 | adasp_minstep = 10. 167 | alpha = 1. 
168 | 169 | step_size = alpha * eps * torch.ones([x.shape[0], *[1] * ndims], 170 | device=device, dtype=x.dtype) 171 | counter3 = 0 172 | #times['blk1'] += time.time() - startt 173 | 174 | x_adv.requires_grad_() 175 | #grad = torch.zeros_like(x) 176 | #for _ in range(self.eot_iter) 177 | #with torch.enable_grad() 178 | startt = time.time() 179 | logits = model(x_adv) 180 | times['fp'] += time.time() - startt 181 | loss_indiv = criterion_indiv(logits, y) 182 | loss = loss_indiv.sum() 183 | #grad += torch.autograd.grad(loss, [x_adv])[0].detach() 184 | startt = time.time() 185 | grad = torch.autograd.grad(loss, [x_adv])[0].detach() 186 | times['bp'] += time.time() - startt 187 | #grad /= float(self.eot_iter) 188 | #startt = time.time() 189 | grad_best = grad.clone() 190 | x_adv.detach_() 191 | loss_indiv.detach_() 192 | loss.detach_() 193 | 194 | if mixup is not None: 195 | acc = logits.detach().max(1)[1] == y.max(1)[1] 196 | else: 197 | acc = logits.detach().max(1)[1] == y 198 | acc_steps[0] = acc + 0 199 | loss_best = loss_indiv.detach().clone() 200 | loss_best_last_check = loss_best.clone() 201 | reduced_last_check = torch.ones_like(loss_best) 202 | n_reduced = 0 203 | 204 | u = torch.arange(x.shape[0], device=device) 205 | x_adv_old = x_adv.clone().detach() 206 | #times['blk2'] += time.time() - startt 207 | 208 | #startt = time.time() 209 | for i in range(n_iter): 210 | ### gradient step 211 | if True: #with torch.no_grad() 212 | #startt = time.time() 213 | x_adv = x_adv.detach() 214 | grad2 = x_adv - x_adv_old 215 | x_adv_old = x_adv.clone() 216 | loss_curr = loss.detach().mean() 217 | 218 | a = 0.75 if i > 0 else 1.0 219 | 220 | if norm == 'Linf': 221 | x_adv_1 = x_adv + step_size * torch.sign(grad) 222 | x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, 223 | x - eps), x + eps), 0.0, 1.0) 224 | x_adv_1 = torch.clamp(torch.min(torch.max( 225 | x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), 226 | x - eps), x + eps), 0.0, 1.0) 227 | 228 | elif norm == 'L2': 229 | x_adv_1 = x_adv + step_size * grad / (L2_norm(grad, 230 | keepdim=True) + 1e-12) 231 | x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x, 232 | keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x), 233 | L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0) 234 | x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a) 235 | x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x, 236 | keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x), 237 | L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0) 238 | 239 | elif norm == 'L1': 240 | grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0] 241 | topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long() 242 | grad_topk = grad_topk[u, topk_curr].view(-1, *[1]*(len(x.shape) - 1)) 243 | sparsegrad = grad * (grad.abs() >= grad_topk).float() 244 | x_adv_1 = x_adv + step_size * sparsegrad.sign() / ( 245 | sparsegrad.sign().abs().view(x.shape[0], -1).sum(dim=-1).view( 246 | -1, 1, 1, 1) + 1e-10) 247 | 248 | delta_u = x_adv_1 - x 249 | delta_p = L1_projection(x, delta_u, eps) 250 | x_adv_1 = x + delta_u + delta_p 251 | 252 | elif norm == 'L0': 253 | L1normgrad = grad / (grad.abs().view(grad.shape[0], -1).sum( 254 | dim=-1, keepdim=True) + 1e-12).view(grad.shape[0], *[1]*( 255 | len(grad.shape) - 1)) 256 | x_adv_1 = x_adv + step_size * L1normgrad * n_fts 257 | x_adv_1 = L0_projection(x_adv_1, x, eps) 258 | # TODO: add momentum 259 | 260 | x_adv = x_adv_1 + 0. 
261 | #times['blk1'] += time.time() - startt 262 | #return x_adv 263 | 264 | #startt = time.time() # t1 265 | 266 | ### get gradient 267 | if i < n_iter - 1: 268 | x_adv.requires_grad_() 269 | #grad = torch.zeros_like(x) 270 | #for _ in range(self.eot_iter) 271 | #with torch.enable_grad() 272 | startt = time.time() 273 | logits = model(x_adv) 274 | times['fp'] += time.time() - startt 275 | loss_indiv = criterion_indiv(logits, y) 276 | loss = loss_indiv.sum() 277 | #times['blk1'] += time.time() - startt 278 | 279 | #startt = time.time() # t2 280 | #grad += torch.autograd.grad(loss, [x_adv])[0].detach() 281 | if i < n_iter - 1: 282 | # save one backward pass 283 | grad = torch.autograd.grad(loss, [x_adv])[0].detach() 284 | #grad /= float(self.eot_iter) 285 | x_adv.detach_() 286 | loss_indiv.detach_() 287 | loss.detach_() 288 | #times['blk2'] += time.time() - startt 289 | 290 | startt = time.time() 291 | if mixup is not None: 292 | pred = logits.detach().max(1)[1] == y.max(1)[1] 293 | else: 294 | pred = logits.detach().max(1)[1] == y 295 | 296 | acc = torch.min(acc, pred) 297 | startt = time.time() 298 | acc_steps[i + 1] = acc + 0 299 | times['blk1'] += time.time() - startt 300 | startt = time.time() 301 | ind_pred = ~pred 302 | times['blk2'] += time.time() - startt 303 | startt = time.time() 304 | x_best_adv[ind_pred] = x_adv[ind_pred] + 0. 305 | times['blk3'] += time.time() - startt 306 | if verbose: 307 | str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format( 308 | step_size.mean(), topk.mean() * n_fts) if norm in ['L1'] else ' - step size: {:.5f}'.format( 309 | step_size.mean()) 310 | print('iteration: {} - best loss: {:.6f} curr loss {:.6f} - robust accuracy: {:.2%}{}'.format( 311 | i, loss_best.sum(), loss_curr, acc.float().mean(), str_stats)) 312 | #print('pert {}'.format((x - x_best_adv).abs().view(x.shape[0], -1).sum(-1).max())) 313 | 314 | #times['blk3'] += time.time() - startt 315 | #startt = time.time() # t3 316 | 317 | ### check step size 318 | if True: #with torch.no_grad() 319 | y1 = loss_indiv.detach().clone() 320 | loss_steps[i] = y1 + 0 321 | ind = (y1 > loss_best).nonzero().squeeze() 322 | x_best[ind] = x_adv[ind].clone() 323 | grad_best[ind] = grad[ind].clone() 324 | loss_best[ind] = y1[ind] + 0 325 | loss_best_steps[i + 1] = loss_best + 0 326 | 327 | counter3 += 1 328 | 329 | if counter3 == k: 330 | if norm in ['Linf', 'L2']: 331 | fl_oscillation = check_oscillation(loss_steps, i, k, 332 | loss_best, k3=thr_decr) 333 | fl_reduce_no_impr = (1. 
- reduced_last_check) * ( 334 | loss_best_last_check >= loss_best).float() 335 | fl_oscillation = torch.max(fl_oscillation, 336 | fl_reduce_no_impr) 337 | reduced_last_check = fl_oscillation.clone() 338 | loss_best_last_check = loss_best.clone() 339 | 340 | if fl_oscillation.sum() > 0: 341 | ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze() 342 | step_size[ind_fl_osc] /= 2.0 343 | n_reduced = fl_oscillation.sum() 344 | 345 | x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone() 346 | grad[ind_fl_osc] = grad_best[ind_fl_osc].clone() 347 | 348 | counter3 = 0 349 | k = max(k - size_decr, n_iter_min) 350 | 351 | elif norm == 'L1': 352 | # adjust sparsity 353 | sp_curr = L0_norm(x_best - x) 354 | fl_redtopk = (sp_curr / sp_old) < .95 355 | topk = sp_curr / n_fts / 1.5 356 | step_size[fl_redtopk] = alpha * eps 357 | step_size[~fl_redtopk] /= adasp_redstep 358 | step_size.clamp_(alpha * eps / adasp_minstep, alpha * eps) 359 | sp_old = sp_curr.clone() 360 | 361 | x_adv[fl_redtopk] = x_best[fl_redtopk].clone() 362 | grad[fl_redtopk] = grad_best[fl_redtopk].clone() 363 | 364 | counter3 = 0 365 | #times['blk3'] += time.time() - startt 366 | 367 | if verbose: 368 | times['total'] = time.time() - times['total'] 369 | print(' '.join([f'{k}={v:.5f} s' for k, v in times.items()])) 370 | 371 | return x_best, acc, loss_best, x_best_adv 372 | 373 | 374 | if __name__ == '__main__': 375 | #pass 376 | from train_new import parse_args 377 | from data import load_anydataset 378 | from utils_eval import check_imgs, load_anymodel_datasets, clean_accuracy 379 | 380 | args = parse_args() 381 | args.training_set = False 382 | 383 | x_test, y_test = load_anydataset(args, device='cpu') 384 | x, y = x_test.cuda(), y_test.cuda() 385 | model, l_models = load_anymodel_datasets(args) 386 | 387 | assert not model.training 388 | 389 | if args.attack == 'apgd_train': 390 | #with torch.no_grad() 391 | x_best, acc, _, x_adv = apgd_train(model, x, y, norm=args.norm, 392 | eps=args.eps, n_iter=args.n_iter, verbose=True, loss='ce') 393 | check_imgs(x_adv, x, args.norm) 394 | 395 | elif args.attack == 'apgd_test': 396 | from autoattack import AutoAttack 397 | adversary = AutoAttack(model, norm=args.norm, eps=args.eps) 398 | #adversary.attacks_to_run = ['apgd-ce'] 399 | #adversary.apgd.verbose = True 400 | with torch.no_grad(): 401 | x_adv = adversary.run_standard_evaluation(x, y, bs=1000) 402 | check_imgs(x_adv, x, args.norm) 403 | -------------------------------------------------------------------------------- /dataset_convnext_like.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | import os 10 | from torchvision import datasets, transforms 11 | 12 | from timm.data.constants import \ 13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 14 | from timm.data import create_transform 15 | 16 | def build_dataset(is_train, args): 17 | transform = build_transform(is_train, args) 18 | 19 | print(f"Transform = train: {is_train}") 20 | if isinstance(transform, tuple): 21 | for trans in transform: 22 | print(" - - - - - - - - - - ") 23 | for t in trans.transforms: 24 | print(t) 25 | else: 26 | for t in transform.transforms: 27 | print(t) 28 | print("---------------------------") 29 | data_paths = ['/scratch/nsingh/imagenet', 30 | '/home/scratch/datasets/imagenet', 31 | '/scratch_local/datasets/ImageNet2012', 32 | '/scratch/datasets/imagenet/'] 33 | for data_path in data_paths: 34 | if os.path.exists(data_path): 35 | break 36 | data_set = 'IMNET' 37 | if data_set == 'CIFAR': 38 | dataset = datasets.CIFAR100(data_path, train=is_train, transform=transform, download=True) 39 | nb_classes = 100 40 | elif data_set == 'IMNET': 41 | print("reading from datapath", data_path) 42 | root = os.path.join(data_path, 'train' if is_train else 'val') 43 | dataset = datasets.ImageFolder(root, transform=transform) 44 | nb_classes = 1000 45 | # elif data_set == "image_folder": 46 | # root = args.data_path if is_train else args.eval_data_path 47 | # dataset = datasets.ImageFolder(root, transform=transform) 48 | # nb_classes = args.nb_classes 49 | # assert len(dataset.class_to_idx) == nb_classes 50 | else: 51 | raise NotImplementedError() 52 | print("Number of the class = %d" % nb_classes) 53 | 54 | return dataset, nb_classes 55 | 56 | 57 | def build_transform(is_train, args): 58 | resize_im = args.input_size > 32 59 | imagenet_default_mean_and_std = False #args.imagenet_default_mean_and_std 60 | mean = [0., 0., 0.] #IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 61 | std = [1., 1., 1.] #IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 62 | transform = None 63 | if is_train: 64 | # this should always dispatch to transforms_imagenet_train 65 | transform = create_transform( 66 | input_size=args.input_size, 67 | is_training=True, 68 | color_jitter=args.color_jitter, 69 | auto_augment=args.aa, 70 | interpolation=args.train_interpolation, 71 | re_prob=args.reprob, 72 | re_mode=args.remode, 73 | re_count=args.recount, 74 | mean=mean, 75 | std=std, 76 | scale=args.scale, 77 | ratio=args.ratio, 78 | hflip=args.hflip, 79 | vflip=args.vflip, 80 | crop_pct=args.crop_pct 81 | ) 82 | 83 | return transform 84 | 85 | t = [] 86 | if resize_im: 87 | # warping (no cropping) when evaluated at 384 or larger 88 | if args.input_size >= 384: 89 | t.append( 90 | transforms.Resize((args.input_size, args.input_size), 91 | interpolation=transforms.InterpolationMode.BICUBIC), 92 | ) 93 | print(f"Warping {args.input_size} size input images...") 94 | else: 95 | if args.crop_pct is None: 96 | args.crop_pct = 224 / 256 97 | size = int(args.input_size / args.crop_pct) 98 | t.append( 99 | # to maintain same ratio w.r.t. 
224 images 100 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 101 | ) 102 | t.append(transforms.CenterCrop(args.input_size)) 103 | 104 | t.append(transforms.ToTensor()) 105 | # t.append(transforms.Normalize(mean, std)) 106 | return transforms.Compose(t) 107 | -------------------------------------------------------------------------------- /fgsm_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from autopgd_train_clean import criterion_dict 6 | import robustbench as rb 7 | #from autopgd_train import apgd_train 8 | # import utils 9 | # from model_zoo.fast_models import PreActResNet18 10 | # import other_utils 11 | import autoattack 12 | criterion_dict = {'ce': lambda x, y: F.cross_entropy(x, y, reduction='none')} 13 | 14 | 15 | 16 | 17 | def fgsm_attack(model, images, labels, eps) : 18 | 19 | loss = nn.CrossEntropyLoss() 20 | images.requires_grad = True 21 | 22 | outputs = model(images) 23 | 24 | model.zero_grad() 25 | cost = loss(outputs, labels).to(device) 26 | cost.sum().backward() 27 | attack_images = images.clone() 28 | attack_images += eps*images.grad.sign() 29 | attack_images = torch.clamp(attack_images, 0, 1) 30 | 31 | return attack_images 32 | 33 | 34 | 35 | 36 | def fgsm_attack(model, x, y, eps=4./255.): 37 | 38 | # assert not model.training 39 | 40 | # Set requires_grad attribute of tensor. Important for Attack 41 | x.requires_grad = True 42 | 43 | # Forward pass the data through the model 44 | output = model(x) 45 | # init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 46 | 47 | criterion_indiv = criterion_dict['ce'] 48 | 49 | # Calculate the loss 50 | loss = criterion_indiv(output, y) 51 | 52 | # Zero all existing gradients 53 | model.zero_grad() 54 | 55 | # Calculate gradients of model in backward pass 56 | loss.sum().backward() 57 | # Collect datagrad 58 | data_grad = x.grad.data 59 | 60 | # Call FGSM Attack 61 | x_adv = gen_pert(x, eps, data_grad) 62 | 63 | return x_adv 64 | # output = model(x_adv) 65 | # final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 66 | # if final_pred.item() == target.item(): 67 | # correct += 1 68 | # return correct 69 | 70 | 71 | 72 | def fgsm_train(model, x, y, eps, loss='ce', alpha=1.25, use_rs=False, 73 | noise_level=1., skip_projection=False): 74 | assert not model.training 75 | 76 | if not use_rs: 77 | x_adv = x.clone() 78 | else: 79 | #raise NotImplemented 80 | #if norm == 'Linf' 81 | t = torch.rand_like(x) 82 | x_adv = x + (2. * t - 1.) * eps * noise_level 83 | if not skip_projection: 84 | x_adv.clamp_(0., 1.) 85 | 86 | criterion_indiv = criterion_dict[loss] 87 | 88 | x_adv.requires_grad = True 89 | logits = model(x_adv) 90 | loss_indiv = criterion_indiv(logits, y) 91 | grad = torch.autograd.grad(loss_indiv.sum(), x_adv)[0].detach() 92 | 93 | x_adv = x_adv.detach() + alpha * eps * grad.sign() 94 | if not skip_projection: 95 | x_adv = x + (x_adv - x).clamp(-eps, eps) 96 | x_adv.clamp_(0., 1.) 
97 | 98 | return x_adv 99 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''Main runnable file for imagenet experiments 2 | ''' 3 | 4 | import sys 5 | sys.path.insert(0,'..') 6 | from math import ceil 7 | import math 8 | import numpy as np 9 | import os, sys 10 | from os import get_terminal_size 11 | from timm.loss.cross_entropy import SoftTargetCrossEntropy 12 | from timm.models import create_model 13 | from datetime import datetime 14 | import argparse, sys, torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from torchinfo import summary 18 | import random 19 | import torch as ch 20 | import torch.nn as nn 21 | from torch.cuda.amp import GradScaler 22 | from torch.cuda.amp import autocast 23 | import torch.nn.functional as F 24 | import torch.distributed as dist 25 | ch.backends.cudnn.benchmark = True 26 | ch.autograd.profiler.emit_nvtx(False) 27 | ch.autograd.profiler.profile(False) 28 | 29 | import argparse 30 | import parserr 31 | from dataset_convnext_like import build_dataset 32 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy 33 | import torchvision 34 | from torchvision import models 35 | import torchmetrics 36 | import numpy as np 37 | from tqdm import tqdm 38 | import time 39 | import json 40 | from uuid import uuid4 41 | from typing import List 42 | from pathlib import Path 43 | from argparse import ArgumentParser 44 | from datetime import datetime 45 | from functools import partial 46 | from fastargs import get_current_config 47 | from fastargs.decorators import param 48 | from fastargs import Param, Section 49 | from fastargs.validation import And, OneOf 50 | 51 | from ffcv.pipeline.operation import Operation 52 | from ffcv.loader import Loader, OrderOption 53 | from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \ 54 | RandomHorizontalFlip, ToTorchImage, Convert 55 | from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \ 56 | RandomResizedCropRGBImageDecoder 57 | from ffcv.fields.basics import IntDecoder 58 | import timm 59 | from timm.loss import SoftTargetCrossEntropy 60 | from timm.data.mixup import Mixup 61 | from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter 62 | 63 | from autopgd_train_clean import apgd_train 64 | from fgsm_train import fgsm_train, fgsm_attack 65 | from utils_architecture import normalize_model, get_new_model 66 | from ptflops import get_model_complexity_info 67 | from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str 68 | 69 | def sizeof_fmt(num, suffix="Flops"): 70 | for unit in ["", "Ki", "Mi", "G", "T"]: 71 | if abs(num) < 1000.0: 72 | return f"{num:3.3f}{unit}{suffix}" 73 | num /= 1000.0 74 | return f"{num:.1f}Yi{suffix}" 75 | 76 | 77 | 78 | import warnings 79 | warnings.filterwarnings("ignore", category=DeprecationWarning) 80 | warnings.filterwarnings("ignore") 81 | warnings.filterwarnings("ignore", category=FutureWarning) 82 | # warnings.filterwarnings("ignore", category=UserWarning) 83 | os.environ['KMP_WARNINGS'] = 'off' 84 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 85 | # os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo" 86 | 87 | class LabelSmoothingCrossEntropy(nn.Module): 88 | """ NLL loss with label smoothing. 
89 | """ 90 | def __init__(self, smoothing=0.1): 91 | super(LabelSmoothingCrossEntropy, self).__init__() 92 | assert smoothing < 1.0 93 | self.smoothing = smoothing 94 | self.confidence = 1. - smoothing 95 | 96 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 97 | logprobs = F.log_softmax(x, dim=-1) 98 | target = target.type(ch.int64) 99 | nll_loss = -logprobs.gather(dim=-1, index=target) 100 | nll_loss = nll_loss.squeeze(1) 101 | smooth_loss = -logprobs.mean(dim=-1) 102 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 103 | return loss.mean() 104 | 105 | 106 | Section('model', 'model details').params( 107 | arch=Param(str, default='effnet_b0'), 108 | pretrained=Param(int, 'is pretrained? (1/0)', default=1), 109 | ckpt_path=Param(str, 'path to resume model', default=''), 110 | add_normalization=Param(int, '0 if no normalization, 1 otherwise', default=1), 111 | not_original=Param(int, 'change effnets? to patch-version', default=0), 112 | updated=Param(int, 'Make conviso Big?', default=0), 113 | model_ema=Param(float, 'Use EMA?', default=0), 114 | freeze_some=Param(int, 'freeze some layers', default=0), 115 | early=Param(int, 'freeze early layers?', default=1), 116 | ) 117 | 118 | Section('resolution', 'resolution scheduling').params( 119 | min_res=Param(int, 'the minimum (starting) resolution', default=160), 120 | max_res=Param(int, 'the maximum (starting) resolution', default=160), 121 | end_ramp=Param(int, 'when to stop interpolating resolution', default=0), 122 | start_ramp=Param(int, 'when to start interpolating resolution', default=0) 123 | ) 124 | 125 | Section('data', 'data related stuff').params( 126 | train_dataset=Param(str, '.dat file to use for training', required=True), 127 | val_dataset=Param(str, '.dat file to use for validation', required=True), 128 | num_workers=Param(int, 'The number of workers', required=True), 129 | in_memory=Param(int, 'does the dataset fit in memory? 
(1/0)', required=True), 130 | seed=Param(int, 'seed for training loader', default=0), 131 | augmentations=Param(int, 'use fancy augmentations?', default=0) 132 | ) 133 | 134 | Section('lr', 'lr scheduling').params( 135 | step_ratio=Param(float, 'learning rate step ratio', default=0.1), 136 | step_length=Param(int, 'learning rate step length', default=30), 137 | lr_schedule_type=Param(OneOf(['step', 'cyclic', 'cosine']), default='cosine'), 138 | lr=Param(float, 'learning rate', default=1e-3), 139 | lr_peak_epoch=Param(int, 'Epoch at which LR peaks', default=10), 140 | ) 141 | 142 | Section('logging', 'how to log stuff').params( 143 | folder=Param(str, 'log location', default="/mnt/SHARED/nsingh/ImageNet_Arch/full_Img/"), 144 | log_level=Param(int, '0 if only at end 1 otherwise', default=1), 145 | save_freq=Param(int, 'save models every nth epoch', default=2), 146 | addendum=Param(str, 'additional comments?', default=""), 147 | ) 148 | 149 | Section('validation', 'Validation parameters stuff').params( 150 | batch_size=Param(int, 'The batch size for validation', default=64), 151 | resolution=Param(int, 'final resized validation image size', default=224), 152 | lr_tta=Param(int, 'should do lr flipping/avging at test time', default=0), 153 | precision=Param(str, 'np precision', default='fp16') 154 | ) 155 | 156 | Section('training', 'training hyper param stuff').params( 157 | eval_only=Param(int, 'eval only?', default=0), 158 | batch_size=Param(int, 'The batch size', default=512), 159 | optimizer=Param(And(str, OneOf(['sgd', 'adamw'])), 'The optimizer', default='adamw'), 160 | momentum=Param(float, 'SGD momentum', default=0.9), 161 | weight_decay=Param(float, 'weight decay', default=0.05), 162 | epochs=Param(int, 'number of epochs', default=100), 163 | label_smoothing=Param(float, 'label smoothing parameter', default=0.1), 164 | distributed=Param(int, 'is distributed?', default=0), 165 | use_blurpool=Param(int, 'use blurpool?', default=0), 166 | precision=Param(str, 'np precision', default='fp16'), 167 | ) 168 | 169 | Section('dist', 'distributed training options').params( 170 | world_size=Param(int, 'number gpus', default=1), 171 | address=Param(str, 'address', default='localhost'), 172 | port=Param(str, 'port', default='12355') 173 | ) 174 | 175 | Section('adv', 'adversarial training options').params( 176 | attack=Param(str, 'if None standard training', default='none'), 177 | norm=Param(str, '', default='Linf'), 178 | eps=Param(float, '', default=4./255.), 179 | n_iter=Param(int, '', default=2), 180 | verbose=Param(int, '', default=0), 181 | noise_level=Param(float, '', default=1.), 182 | skip_projection=Param(int, '', default=0), 183 | alpha=Param(float, 'step size multiplier', default=1.), 184 | ) 185 | 186 | Section('misc', 'other parameters').params( 187 | notes=Param(str, '', default=''), 188 | use_channel_last=Param(int, 'whether to use channel last memory format', default=1), 189 | ) 190 | 191 | IMAGENET_MEAN = [c * 1. for c in (0.485, 0.456, 0.406)] #[np.array([0., 0., 0.]), np.array([0.485, 0.456, 0.406])][-1] * 255 192 | IMAGENET_STD = [c * 1. 
for c in (0.229, 0.224, 0.225)] #[np.array([1., 1., 1.]), np.array([0.229, 0.224, 0.225])][-1] * 255 193 | NONORM_MEAN = np.array([0., 0., 0.]) 194 | NONORM_STD = np.array([1., 1., 1.]) * 255 195 | DEFAULT_CROP_RATIO = 224/256 196 | 197 | PREC_DICT = {'fp16': np.float16, 'fp32': np.float32} 198 | 199 | def sizeof_fmt(num, suffix="Flops"): 200 | for unit in ["", "Ki", "Mi", "G", "T"]: 201 | if abs(num) < 1000.0: 202 | return f"{num:3.3f}{unit}{suffix}" 203 | num /= 1000.0 204 | return f"{num:.1f}Yi{suffix}" 205 | 206 | 207 | 208 | @param('lr.lr') 209 | @param('lr.step_ratio') 210 | @param('lr.step_length') 211 | @param('training.epochs') 212 | def get_step_lr(epoch, lr, step_ratio, step_length, epochs): 213 | if epoch >= epochs: 214 | return 0 215 | 216 | num_steps = epoch // step_length 217 | return step_ratio**num_steps * lr 218 | 219 | @param('lr.lr') 220 | @param('training.epochs') 221 | @param('lr.lr_peak_epoch') 222 | def get_cyclic_lr(epoch, lr, epochs, lr_peak_epoch): 223 | xs = [0, lr_peak_epoch, epochs] 224 | ys = [1e-4 * lr, lr, 0] 225 | return np.interp([epoch], xs, ys)[0] 226 | 227 | @param('lr.lr') 228 | @param('training.epochs') 229 | @param('lr.lr_peak_epoch') 230 | def get_cosine_lr(epoch, lr, epochs, lr_peak_epoch): 231 | # if epochs > 100: 232 | # lr_peak_epoch = 20 233 | # else: 234 | # lr_peak_epoch = 10 235 | if epoch <= lr_peak_epoch: 236 | xs = [0, lr_peak_epoch] 237 | ys = [1e-4 * lr, lr] 238 | return np.interp([epoch], xs, ys)[0] 239 | else: 240 | lr_min = 5e-6 241 | lr_t = lr_min + .5 * (lr - lr_min) * (1 + math.cos(math.pi * ( 242 | epoch - lr_peak_epoch) / (epochs - lr_peak_epoch))) 243 | return lr_t 244 | 245 | 246 | class BlurPoolConv2d(ch.nn.Module): 247 | def __init__(self, conv): 248 | super().__init__() 249 | default_filter = ch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0 250 | filt = default_filter.repeat(conv.in_channels, 1, 1, 1) 251 | self.conv = conv 252 | self.register_buffer('blur_filter', filt) 253 | 254 | def forward(self, x): 255 | blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1), 256 | groups=self.conv.in_channels, bias=None) 257 | return self.conv.forward(blurred) 258 | 259 | 260 | class WrappedModel(nn.Module): 261 | """ include the generation of adversarial perturbation in the 262 | forward pass 263 | """ 264 | def __init__(self, base_model, perturb, verbose=False): 265 | super().__init__() 266 | self.base_model = base_model 267 | self.perturb = perturb 268 | self.perturb_input = False 269 | self.verbose = verbose 270 | #self.mu = mu 271 | #self.sigma = sigma 272 | 273 | def forward(self, x, y=None): 274 | # TODO: handle varying threat models 275 | if self.perturb_input: 276 | assert not y is None 277 | #print(x.is_contiguous()) 278 | # use eval mode during attack 279 | self.base_model.eval() 280 | if self.verbose: 281 | print('perturb input') 282 | startt = time.time() 283 | z = self.perturb(self.base_model, x, y) 284 | 285 | if self.verbose: 286 | inftime = time.time() - startt 287 | print(f'inference time={inftime:.5f}') 288 | #print(z[0].is_contiguous()) 289 | self.base_model.train() 290 | 291 | if isinstance(z, (tuple, list)): 292 | z = z[0] 293 | return self.base_model(z) 294 | 295 | else: 296 | if self.verbose: 297 | print('clean inference') 298 | return self.base_model(x) 299 | 300 | def set_perturb(self, mode): 301 | self.perturb_input = mode 302 | 303 | 304 | 305 | def freeze_some_layers(model, early): 306 | 307 | if bool(early): 308 | for name, child in model.named_children(): 309 | for namm, pamm in 
child.named_parameters(): 310 | if 'stem' in namm: 311 | print(namm + ' is unfrozen') 312 | pamm.requires_grad = True 313 | else: 314 | print(namm + ' is frozen') 315 | pamm.requires_grad = False 316 | else: 317 | for name, child in model.named_children(): 318 | for namm, pamm in child.named_parameters(): 319 | if 'stem' in namm: 320 | print(namm + ' is unfrozen') 321 | pamm.requires_grad = False 322 | else: 323 | print(namm + ' is frozen') 324 | pamm.requires_grad = True 325 | 326 | 327 | 328 | class ImageNetTrainer: 329 | @param('training.distributed') 330 | @param('training.eval_only') 331 | def __init__(self, gpu, distributed, eval_only): 332 | self.all_params = get_current_config() 333 | self.gpu = gpu 334 | self.best_rob_acc = 0. 335 | self.uid = str(uuid4()) 336 | 337 | if distributed: 338 | self.setup_distributed() 339 | 340 | if not eval_only: 341 | self.train_loader, self.val_loader, self.mixup_fn = self.create_train_loader() 342 | # self.val_loader = self.create_val_loader() 343 | self.model, self.scaler = self.create_model_and_scaler() 344 | self.create_optimizer() 345 | self.initialize_logger() 346 | 347 | 348 | @param('dist.address') 349 | @param('dist.port') 350 | @param('dist.world_size') 351 | def setup_distributed(self, address, port, world_size): 352 | os.environ['MASTER_ADDR'] = address 353 | os.environ['MASTER_PORT'] = port 354 | 355 | dist.init_process_group("nccl", rank=self.gpu, world_size=world_size) 356 | ch.cuda.set_device(self.gpu) 357 | 358 | def cleanup_distributed(self): 359 | dist.destroy_process_group() 360 | 361 | @param('lr.lr_schedule_type') 362 | def get_lr(self, epoch, lr_schedule_type): 363 | lr_schedules = { 364 | 'cyclic': get_cyclic_lr, 365 | 'step': get_step_lr, 366 | 'cosine': get_cosine_lr, 367 | } 368 | 369 | return lr_schedules[lr_schedule_type](epoch) 370 | 371 | # resolution tools 372 | @param('resolution.min_res') 373 | @param('resolution.max_res') 374 | @param('resolution.end_ramp') 375 | @param('resolution.start_ramp') 376 | def get_resolution(self, epoch, min_res, max_res, end_ramp, start_ramp): 377 | assert min_res <= max_res 378 | 379 | if epoch <= start_ramp: 380 | return min_res 381 | 382 | if epoch >= end_ramp: 383 | return max_res 384 | 385 | # otherwise, linearly interpolate to the nearest multiple of 32 386 | interp = np.interp([epoch], [start_ramp, end_ramp], [min_res, max_res]) 387 | final_res = int(np.round(interp[0] / 32)) * 32 388 | return final_res 389 | 390 | @param('training.momentum') 391 | @param('training.optimizer') 392 | @param('training.weight_decay') 393 | @param('training.label_smoothing') 394 | @param('model.arch') 395 | def create_optimizer(self, momentum, optimizer, weight_decay, 396 | label_smoothing, arch): 397 | #assert optimizer == 'sgd' 398 | 399 | # Only do weight decay on non-batchnorm parameters 400 | if 'convnext' in arch or 'resnet' in arch: 401 | print('manually excluding parameters for weight decay') 402 | all_params = list(self.model.named_parameters()) 403 | excluded_params = ['bn', '.bias'] #'.norm', '.bias' 404 | if arch in ['timm_convnext_tiny_batchnorm', 'timm_convnext_tiny_batchnorm_relu']: 405 | # timm convnext uses different naming than resnet 406 | excluded_params.append('.norm.') 407 | if arch in ['timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln', 408 | 'timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln_ds-sep', 409 | 'timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln_ds-sep_bias', 410 | 
'timm_reimplemented_convnext_tiny']: 411 | # in case LN is used instead of original BN and the naming is not changed 412 | excluded_params.remove('bn') 413 | print('excluded params=', ', '.join(excluded_params)) 414 | 415 | bn_params = [v for k, v in all_params if any([c in k for c in excluded_params])] #('bn' in k) #or k.endswith('.bias') 416 | bn_keys = [k for k, v in all_params if any([c in k for c in excluded_params])] 417 | #print(', '.join(bn_keys)) 418 | #sys.exit() 419 | other_params = [v for k, v in all_params if not any([c in k for c in excluded_params])] #not ('bn' in k) #or k.endswith('.bias') 420 | # se_only = True 421 | # elif se_only: 422 | # other_params = [] 423 | # l = 0 424 | # for name, param in self.model.named_parameters(): 425 | # # print(name) 426 | # if "se_module" not in name: 427 | # # other_params.append(param) 428 | # param.requires_grad = False 429 | # l+=1 430 | # print(l) 431 | # exit() 432 | else: 433 | print('automatically exclude bn and bias from weight decay') 434 | bn_params = [] 435 | bn_keys = [] 436 | other_params = [] 437 | for name, param in self.model.named_parameters(): 438 | if not param.requires_grad: 439 | continue 440 | if param.ndim <= 1 or name.endswith(".bias"): #or name in no_weight_decay_list 441 | bn_keys.append(name) 442 | bn_params.append(param) 443 | else: 444 | other_params.append(param) 445 | #print(', '.join(bn_keys)) 446 | 447 | param_groups = [{ 448 | 'params': bn_params, 449 | 'weight_decay': 0. 450 | }, { 451 | 'params': other_params, 452 | 'weight_decay': weight_decay 453 | }] 454 | 455 | 456 | if optimizer == 'sgd': 457 | self.optimizer = ch.optim.SGD(param_groups, lr=1, momentum=momentum) 458 | else: 459 | self.optimizer = ch.optim.AdamW(param_groups, betas=(0.9, 0.95)) 460 | 461 | if self.mixup_fn is None: 462 | self.loss = ch.nn.CrossEntropyLoss() 463 | else: 464 | # # smoothing is handled with mixup label transform 465 | # self.loss = LabelSmoothingCrossEntropy(smoothing=label_smoothing) 466 | self.loss = SoftTargetCrossEntropy() 467 | 468 | @param('data.train_dataset') 469 | @param('data.num_workers') 470 | @param('training.batch_size') 471 | @param('training.distributed') 472 | @param('training.label_smoothing') 473 | @param('data.in_memory') 474 | @param('data.seed') 475 | @param('data.augmentations') 476 | @param('training.precision') 477 | @param('misc.use_channel_last') 478 | @param('dist.world_size') 479 | def create_train_loader(self, train_dataset, num_workers, batch_size, 480 | distributed, label_smoothing, in_memory, seed, augmentations, precision, 481 | use_channel_last, world_size): 482 | torch.manual_seed(seed) 483 | if False: 484 | this_device = f'cuda:{self.gpu}' 485 | print(this_device) 486 | # train_path = Path(train_dataset) 487 | data_paths = ['/scratch/fcroce42/ffcv_imagenet_data/train_400_0.50_90.ffcv', 488 | '/scratch/nsingh/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv', '/scratch_local/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv'] 489 | for data_path in data_paths: 490 | if os.path.exists(data_path): 491 | train_path = Path(data_path) 492 | break 493 | print(train_path) 494 | assert train_path.is_file() 495 | 496 | res = self.get_resolution(epoch=0) 497 | prec = PREC_DICT[precision] 498 | self.decoder = RandomResizedCropRGBImageDecoder((res, res)) 499 | if use_channel_last: 500 | image_pipeline: List[Operation] = [ 501 | self.decoder, 502 | RandomHorizontalFlip(), 503 | #Convert(np.float16), 504 | ToTensor(), 505 | #lambda x: x.contiguous(), 506 | ToDevice(ch.device(this_device), 
non_blocking=True), 507 | ToTorchImage(channels_last=True), 508 | NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD, 509 | prec, #np.float16 510 | ) 511 | ] 512 | else: 513 | image_pipeline: List[Operation] = [ 514 | self.decoder, 515 | RandomHorizontalFlip(), 516 | #Convert(np.float16), 517 | ToTensor(), 518 | #lambda x: x.contiguous(), 519 | ToDevice(ch.device(this_device), non_blocking=True), 520 | ToTorchImage(channels_last=False), 521 | #NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD, 522 | # prec, #np.float16 523 | # ) 524 | Convert(ch.cuda.HalfTensor), #float16 525 | torchvision.transforms.Normalize([0., 0., 0.], [255., 255., 255.]), 526 | ] 527 | 528 | label_pipeline: List[Operation] = [ 529 | IntDecoder(), 530 | ToTensor(), 531 | Squeeze(), 532 | ToDevice(ch.device(this_device), non_blocking=True) 533 | ] 534 | 535 | order = OrderOption.RANDOM if distributed else OrderOption.QUASI_RANDOM 536 | loader = Loader(train_dataset, 537 | batch_size=batch_size, 538 | num_workers=num_workers, 539 | order=order, 540 | os_cache=in_memory, 541 | drop_last=True, 542 | pipelines={ 543 | 'image': image_pipeline, 544 | 'label': label_pipeline 545 | }, 546 | distributed=distributed, 547 | seed=seed) 548 | 549 | else: 550 | 551 | if augmentations: 552 | args = parserr.Arguments_augment() 553 | 554 | else: 555 | args = parserr.Arguments_No_augment() 556 | 557 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 558 | if False: 559 | args.dist_eval = False 560 | dataset_val = None 561 | else: 562 | dataset_val, _ = build_dataset(is_train=False, args=args) 563 | 564 | num_tasks = world_size 565 | global_rank = self.gpu #utils.get_rank() 566 | 567 | sampler_train = torch.utils.data.DistributedSampler( 568 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=seed, 569 | ) 570 | print("Sampler_train = %s" % str(sampler_train)) 571 | if args.dist_eval: 572 | if len(dataset_val) % num_tasks != 0: 573 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 574 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 575 | 'equal num of samples per-process.') 576 | sampler_val = torch.utils.data.DistributedSampler( 577 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 578 | else: 579 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 580 | 581 | data_loader_train = torch.utils.data.DataLoader( 582 | dataset_train, sampler=sampler_train, 583 | batch_size=batch_size, 584 | num_workers=num_workers, 585 | pin_memory=True, 586 | drop_last=True, 587 | ) 588 | if dataset_val is not None: 589 | data_loader_val = torch.utils.data.DataLoader( 590 | dataset_val, sampler=sampler_val, 591 | batch_size=int(1.5 * batch_size), 592 | num_workers=num_workers, 593 | pin_memory=True, 594 | drop_last=False 595 | ) 596 | else: 597 | data_loader_val = None 598 | 599 | mixup_fn = None 600 | mixup_active = (args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None) and augmentations 601 | if mixup_active: 602 | print("Mixup is activated!") 603 | print(f"Using label smoothing:{label_smoothing}") 604 | mixup_fn = Mixup( 605 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 606 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 607 | label_smoothing=label_smoothing, num_classes=args.nb_classes) 608 | 609 | # assert not distributed 610 | # self.decoder = None #RandomResizedCropRGBImageDecoder((res, res)) 611 | 612 | # from robustness.datasets import DATASETS 613 | # from robustness.tools import helpers 614 | # data_paths = ['/home/scratch/datasets/imagenet', 615 | # '/scratch_local/datasets/ImageNet2012', 616 | # '/mnt/qb/datasets/ImageNet2012', 617 | # '/scratch/datasets/imagenet/'] 618 | # for data_path in data_paths: 619 | # if os.path.exists(data_path): 620 | # break 621 | # print(f'found dataset at {data_path}') 622 | # dataset = DATASETS['imagenet'](data_path) #'/home/scratch/datasets/imagenet' 623 | 624 | 625 | # train_loader, val_loader = dataset.make_loaders(num_workers, 626 | # batch_size, data_aug=True) 627 | 628 | # loader = helpers.DataPrefetcher(train_loader) 629 | # #val_loader = helpers.DataPrefetcher(val_loader) 630 | 631 | 632 | 633 | return data_loader_train, data_loader_val, mixup_fn 634 | 635 | @param('data.val_dataset') 636 | @param('data.num_workers') 637 | @param('validation.batch_size') 638 | @param('validation.resolution') 639 | @param('validation.precision') 640 | @param('training.distributed') 641 | @param('misc.use_channel_last') 642 | def create_val_loader(self, val_dataset, num_workers, batch_size, 643 | resolution, precision, distributed, use_channel_last 644 | ): 645 | this_device = f'cuda:{self.gpu}' 646 | # val_path = Path(val_dataset) 647 | data_paths = ['/scratch/fcroce42/ffcv_imagenet_data/val_400_0.50_90.ffcv', 648 | '/scratch/nsingh/datasets/ffcv_imagenet_data/val_400_0.50_90.ffcv', '/scratch_local/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv'] 649 | for data_path in data_paths: 650 | if os.path.exists(data_path): 651 | val_path = Path(data_path) 652 | break 653 | assert val_path.is_file() 654 | res_tuple = (resolution, resolution) 655 | prec = PREC_DICT[precision] 656 | cropper = CenterCropRGBImageDecoder(res_tuple, ratio=DEFAULT_CROP_RATIO) 657 | if use_channel_last: 658 | image_pipeline = [ 659 | cropper, 660 | ToTensor(), 661 | ToDevice(ch.device(this_device), non_blocking=True), 662 | ToTorchImage(), 663 | NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD 664 | prec) 665 | ] 666 | else: 667 | image_pipeline = [ 668 | cropper, 669 | ToTensor(), 670 | ToDevice(ch.device(this_device), non_blocking=True), 671 | ToTorchImage(channels_last=False), 672 | Convert(ch.cuda.FloatTensor), 673 | torchvision.transforms.Normalize([0., 0., 0.], [255., 255., 255.]), 674 | ] 675 | 676 | label_pipeline = [ 677 | IntDecoder(), 678 | ToTensor(), 679 | Squeeze(), 680 | ToDevice(ch.device(this_device), 681 | non_blocking=True) 682 | ] 683 | 684 | loader = Loader(val_dataset, 685 | batch_size=batch_size, 686 | num_workers=num_workers, 687 | order=OrderOption.SEQUENTIAL, 688 | drop_last=False, 689 | pipelines={ 690 | 'image': image_pipeline, 691 | 'label': label_pipeline 692 | }, 693 | distributed=distributed) 694 | return loader 695 | 696 | @param('training.epochs') 697 | @param('logging.log_level') 698 | @param('logging.save_freq') 699 | @param('model.ckpt_path') 700 | @param('adv.attack') 701 | 702 | def 
train(self, epochs, log_level, save_freq, ckpt_path, attack): 703 | vall, nums = self.single_val() 704 | if log_level > 0: 705 | val_dict = { 706 | 'Validation acc': vall.item(), 707 | 'points': nums 708 | } 709 | if self.gpu == 0: 710 | self.log(val_dict) 711 | 712 | for epoch in range(epochs): 713 | #print(f'epoch {epoch}') 714 | res = self.get_resolution(epoch) 715 | try: 716 | self.decoder.output_size = (res, res) 717 | except: 718 | pass 719 | train_loss = self.train_loop(epoch) 720 | 721 | if log_level > 0: 722 | extra_dict = { 723 | 'train_loss': train_loss.item(), 724 | 'epoch': epoch 725 | } 726 | 727 | self.eval_and_log(extra_dict) 728 | 729 | if train_loss.isnan(): 730 | sys.exit() 731 | 732 | # if attack == 'none': 733 | ##### save every 10 epochs if 734 | save_freq = 1 735 | 736 | self.eval_and_log({'epoch': epoch}) 737 | if (self.gpu == 0 and epoch % save_freq == 0) or (self.gpu == 0 and epoch == epochs - 1): 738 | ch.save(self.model.state_dict(), self.log_folder / f'weights_{epoch}.pt') 739 | if self.model_ema is not None: 740 | ch.save(timm.utils.model.get_state_dict(self.model_ema), self.log_folder / f'weights_ema_{epoch}.pt') 741 | if epoch % 5 == 0 or epoch == epochs-1: 742 | ch.save({ 743 | 'model_state_dict': self.model.state_dict(), 744 | 'optimizer_state_dict': self.optimizer.state_dict(), 745 | 'loss_scaler_state_dict': self.scaler.state_dict(), 746 | 'epoch': epoch, 747 | 'state_dict_ema':timm.utils.model.get_state_dict(self.model_ema) 748 | }, self.log_folder / f'full_model_{epoch}.pth') 749 | else: 750 | if epoch % 5 == 0 or epoch == epochs-1: 751 | ch.save({ 752 | 'model_state_dict': self.model.state_dict(), 753 | 'optimizer_state_dict': self.optimizer.state_dict(), 754 | 'loss_scaler_state_dict': self.scaler.state_dict(), 755 | 'epoch': epoch, 756 | }, self.log_folder / f'full_model_{epoch}.pth') 757 | 758 | def eval_and_log(self, extra_dict={}): 759 | start_val = time.time() 760 | stats = 0 #self.val_loop() 761 | val_time = time.time() - start_val 762 | if self.gpu == 0: 763 | self.log(dict({ 764 | 'current_lr': self.optimizer.param_groups[0]['lr'], 765 | 'top_1': stats, 766 | 'top_5': stats, 767 | 'val_time': 0 768 | }, **extra_dict)) 769 | 770 | return stats 771 | 772 | 773 | 774 | @param('model.arch') 775 | @param('model.pretrained') 776 | @param('model.not_original') 777 | @param('model.updated') 778 | @param('model.model_ema') 779 | @param('model.freeze_some') 780 | @param('model.early') 781 | @param('training.distributed') 782 | @param('training.use_blurpool') 783 | @param('model.ckpt_path') 784 | @param('model.add_normalization') 785 | @param('adv.attack') 786 | @param('adv.norm') 787 | @param('adv.eps') 788 | @param('adv.n_iter') 789 | @param('adv.verbose') 790 | @param('misc.use_channel_last') 791 | @param('adv.alpha') 792 | @param('adv.noise_level') 793 | @param('adv.skip_projection') 794 | def create_model_and_scaler(self, arch, pretrained, not_original, updated, model_ema, freeze_some, early, distributed, use_blurpool, 795 | ckpt_path, add_normalization, attack, norm, eps, n_iter, verbose, 796 | use_channel_last, alpha, noise_level, skip_projection): 797 | scaler = GradScaler() 798 | if not arch.startswith('timm_'): 799 | model = get_new_model(arch, pretrained=bool(pretrained), not_original=bool(not_original), updated=bool(updated)) 800 | else: 801 | try: 802 | model = create_model(arch.replace('timm_', ''), pretrained=pretrained) 803 | #model.drop_path_rate = .1 804 | except: 805 | model = get_new_model(arch.replace('timm_', '')) 806 | verbose = 
verbose == 1 807 | 808 | def apply_blurpool(mod: ch.nn.Module): 809 | for (name, child) in mod.named_children(): 810 | if isinstance(child, ch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16): 811 | setattr(mod, name, BlurPoolConv2d(child)) 812 | else: apply_blurpool(child) 813 | if use_blurpool: apply_blurpool(model) 814 | 815 | if use_channel_last: 816 | print('using channel last memory format') 817 | model = model.to(memory_format=ch.channels_last) 818 | else: 819 | print('not using channel last memory format') 820 | 821 | if bool(freeze_some): 822 | print(f"Freezing early layers: {bool(early)}") 823 | freeze_some_layers(model, early) 824 | 825 | 826 | if arch != 'convnext_tiny_21k' and add_normalization: 827 | print('add normalization layer') 828 | model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD) 829 | 830 | 831 | if attack in ['apgd', 'fgsm']: 832 | print('using input perturbation layer') 833 | if attack == 'apgd': 834 | attack = partial(apgd_train, norm=norm, eps=eps, 835 | n_iter=n_iter, verbose=verbose, mixup=self.mixup_fn) 836 | elif attack == 'fgsm': 837 | attack = partial(fgsm_train, eps=eps, 838 | use_rs=True, 839 | alpha=alpha, 840 | noise_level=noise_level, 841 | skip_projection=skip_projection == 1 842 | ) 843 | print(attack) 844 | model = WrappedModel(model, attack, verbose=verbose) 845 | 846 | if self.gpu == 0: 847 | print(model) 848 | inpp = torch.rand(1, 3, 224, 224) 849 | flops = FlopCountAnalysis(model, inpp) 850 | val = flops.total() 851 | print(val) 852 | print(sizeof_fmt(int(val))) 853 | print(flop_count_table(flops, max_depth=2)) 854 | print(flops.by_operator()) 855 | 856 | if not ckpt_path == '': 857 | ckpt = ch.load(ckpt_path, map_location='cpu') 858 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} 859 | try: 860 | model.load_state_dict(ckpt) 861 | print('standard loading') 862 | 863 | except: 864 | try: 865 | ckpt = {f'base_model.{k}': v for k, v in ckpt.items()} 866 | model.load_state_dict(ckpt) 867 | print('loaded from clean model') 868 | except: 869 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()} 870 | # ckpt = {f'base_model.{k}': v for k, v in ckpt.items()} 871 | model.load_state_dict(ckpt) 872 | print('loaded') 873 | #model = model.to(memory_format=ch.channels_last) 874 | 875 | # print(model.patch_embed(torch.rand((50, 3, 224, 224)))) 876 | # exit() 877 | # if arch != 'convnext_tiny_21k' and add_normalization: 878 | # print('add normalization layer') 879 | # model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD) 880 | 881 | model = model.to(self.gpu) 882 | if bool(model_ema): 883 | print('Using EMA with decay 0.9999') 884 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 885 | self.model_ema = timm.utils.ModelEmaV2(model, decay=0.9999, device='cpu') 886 | else: 887 | self.model_ema = None 888 | 889 | if distributed: 890 | model = ch.nn.parallel.DistributedDataParallel(model, device_ids=[self.gpu]) #, find_unused_parameters=True) 891 | 892 | return model, scaler 893 | 894 | @param('validation.lr_tta') 895 | @param('adv.attack') 896 | @param('dist.world_size') 897 | def single_val(self, lr_tta, attack, world_size): 898 | model = self.model 899 | model.eval() 900 | show_once = True 901 | acc = 0. 902 | accs = [] 903 | n = 0. 904 | ns = [] 905 | best_test_rob = 0. 
906 | 907 | 908 | # with ch.no_grad(): 909 | with autocast(enabled=True): 910 | for idx, (images, target) in enumerate(tqdm(self.val_loader)): 911 | # if show_once: 912 | # print(images.shape, images.max(), images.min()) 913 | # show_once = False 914 | 915 | images = images.contiguous().cuda(self.gpu, non_blocking=True) 916 | target = target.contiguous().cuda(self.gpu, non_blocking=True) 917 | output = self.model(images) 918 | if lr_tta: 919 | output += self.model(ch.flip(images, dims=[3])) 920 | 921 | # for k in ['top_1', 'top_5']: 922 | # self.val_meters[k](output, target) 923 | 924 | acc += (output.max(1)[1] == target).sum() 925 | n += target.shape[0] 926 | # loss_val = self.loss(output, target) #####. remove this comment 927 | # self.val_meters['loss'](loss_val) 928 | if idx >= 200: 929 | break 930 | accs.append(acc) 931 | ns = n*world_size 932 | print(f'clean accuracy={acc / n:.2%}') 933 | # stats = {k: m.compute().item() for k, m in self.val_meters.items()} 934 | # if stats['top_1'] > self.best_rob_acc: 935 | # self.best_rob_acc = stats['top_1'] 936 | # if self.gpu == 0: 937 | # ch.save(self.model.state_dict(), self.log_folder / 'best_adv_weights.pt') 938 | # [meter.reset() for meter in self.val_meters.values()] 939 | return ch.stack(accs)/ns, ns 940 | 941 | @param('logging.log_level') 942 | @param('adv.attack') 943 | @param('training.distributed') 944 | def train_loop(self, epoch, log_level, attack, distributed): 945 | model = self.model 946 | model.train() 947 | losses = [] 948 | show_once = True 949 | perturb = attack != 'none' 950 | if perturb: 951 | if distributed: 952 | model.module.set_perturb(True) 953 | else: 954 | model.set_perturb(True) 955 | 956 | lr_start, lr_end = self.get_lr(epoch), self.get_lr(epoch + 1) 957 | iters = len(self.train_loader) 958 | lrs = np.interp(np.arange(iters), [0, iters], [lr_start, lr_end]) 959 | 960 | iterator = tqdm(self.train_loader) 961 | for ix, (images, target) in enumerate(iterator): 962 | images = images.cuda(self.gpu, non_blocking=True) 963 | target = target.cuda(self.gpu, non_blocking=True) 964 | # print(images.size(), target.size()) 965 | if self.mixup_fn is not None: 966 | images, target = self.mixup_fn(images, target) 967 | 968 | if show_once: 969 | # print(images.shape, images.max(), images.min()) 970 | show_once = False 971 | 972 | ### Training start 973 | for param_group in self.optimizer.param_groups: 974 | param_group['lr'] = lrs[ix] 975 | 976 | '''print(images.device) 977 | images = images.reshape(images.shape) # make contiguous (and more) 978 | if False: 979 | ch.save(images, './train_imgs_cm.pth') 980 | sys.exit() 981 | target = target.reshape(target.shape) 982 | print(images.device)''' 983 | 984 | self.optimizer.zero_grad(set_to_none=True) 985 | with autocast(enabled=True): 986 | if not perturb: 987 | output = self.model(images) 988 | else: 989 | output = self.model(images, target) # TODO: check the effect of .contiguous() for other models 990 | loss_train = self.loss(output, target) 991 | 992 | self.scaler.scale(loss_train).backward() 993 | self.scaler.step(self.optimizer) 994 | self.scaler.update() 995 | ### Training end 996 | if self.model_ema is not None: 997 | self.model_ema.update(model) 998 | 999 | #ch.cuda.synchronize() 1000 | 1001 | ### Logging start 1002 | if log_level > 0: 1003 | losses.append(loss_train.detach()) 1004 | 1005 | group_lrs = [] 1006 | for _, group in enumerate(self.optimizer.param_groups): 1007 | group_lrs.append(f'{group["lr"]:.3f}') 1008 | 1009 | names = ['ep', 'iter', 'shape', 'lrs'] 1010 | values 
= [epoch, ix, tuple(images.shape), group_lrs] 1011 | if log_level > 1: 1012 | names += ['loss'] 1013 | values += [f'{loss_train.item():.3f}'] 1014 | 1015 | msg = ', '.join(f'{n}={v}' for n, v in zip(names, values)) 1016 | iterator.set_description(msg) 1017 | ### Logging end 1018 | 1019 | 1020 | if perturb: 1021 | if distributed: 1022 | model.module.set_perturb(False) 1023 | else: 1024 | model.set_perturb(False) 1025 | 1026 | return ch.stack(losses).mean() 1027 | 1028 | @param('validation.lr_tta') 1029 | @param('adv.attack') 1030 | def val_loop(self, lr_tta, attack): 1031 | 1032 | model = self.model 1033 | model.eval() 1034 | show_once = True 1035 | acc = 0. 1036 | best_test_rob = 0 1037 | # with ch.no_grad(): 1038 | with autocast(enabled=True): 1039 | for idx, (images, target) in enumerate(tqdm(self.val_loader)): 1040 | # if show_once: 1041 | # print(images.shape, images.max(), images.min()) 1042 | # show_once = False 1043 | 1044 | images = images.contiguous() 1045 | target = target.contiguous() 1046 | # if attack != 'none': 1047 | # x_adv = fgsm_attack(model, images, target, eps=4./255.) 1048 | # output = self.model(x_adv) 1049 | # if lr_tta: 1050 | # output += self.model(ch.flip(x_adv, dims=[3])) 1051 | # else: 1052 | output = self.model(images) 1053 | if lr_tta: 1054 | output += self.model(ch.flip(images, dims=[3])) 1055 | # if lr_tta: 1056 | # output += self.model(ch.flip(x_adv, dims=[3])) 1057 | for k in ['top_1', 'top_5']: 1058 | self.val_meters[k](output, target) 1059 | 1060 | acc += (output.max(1)[1] == target).sum() 1061 | 1062 | loss_val = self.loss(output, target) 1063 | self.val_meters['loss'](loss_val) 1064 | if idx >= 50: 1065 | break 1066 | #print(f'clean accuracy={acc / 50000:.2%}') 1067 | stats = {k: m.compute().item() for k, m in self.val_meters.items()} 1068 | 1069 | if stats['top_1'] > self.best_rob_acc: 1070 | self.best_rob_acc = stats['top_1'] 1071 | if self.gpu == 0: 1072 | ch.save(self.model.state_dict(), self.log_folder / 'best_adv_weights.pt') 1073 | [meter.reset() for meter in self.val_meters.values()] 1074 | return stats 1075 | 1076 | @param('logging.folder') 1077 | @param('model.arch') 1078 | @param('adv.attack') 1079 | @param('model.updated') 1080 | @param('model.not_original') 1081 | @param('logging.addendum') 1082 | @param('data.augmentations') 1083 | @param('model.pretrained') 1084 | def initialize_logger(self, folder, arch, attack, updated, not_original, addendum, augmentations, pretrained): 1085 | self.val_meters = { 1086 | 'top_1': torchmetrics.Accuracy(compute_on_step=False).to(self.gpu), 1087 | 'top_5': torchmetrics.Accuracy(compute_on_step=False, top_k=5).to(self.gpu), 1088 | 'loss': MeanScalarMetric(compute_on_step=False).to(self.gpu) 1089 | } 1090 | 1091 | if self.gpu == 0: 1092 | #folder = (Path(folder) / str(self.uid)).absolute() 1093 | runname = f'model_{str(datetime.now())[:-7]}_{arch}_upd_{updated}_not_orig_{not_original}_pre_{pretrained}_aug_{augmentations}' 1094 | if attack != 'none': 1095 | runname += f'_adv_{addendum}' 1096 | else: 1097 | runname += f'_clean_{addendum}' 1098 | folder = (Path(folder) / runname).absolute() 1099 | folder.mkdir(parents=True) 1100 | 1101 | self.log_folder = folder 1102 | self.start_time = time.time() 1103 | 1104 | print(f'=> Logging in {self.log_folder}') 1105 | params = { 1106 | '.'.join(k): self.all_params[k] for k in self.all_params.entries.keys() 1107 | } 1108 | with open(folder / 'params.json', 'w+') as handle: 1109 | json.dump(params, handle) 1110 | 1111 | def log(self, content): 1112 | print(f'=> Log: 
{content}') 1113 | if self.gpu != 0: return 1114 | cur_time = time.time() 1115 | try: 1116 | with open(self.log_folder / 'log', 'a+') as fd: 1117 | fd.write(json.dumps({ 1118 | 'timestamp': cur_time, 1119 | 'relative_time': cur_time - self.start_time, 1120 | **content 1121 | }) + '\n') 1122 | fd.flush() 1123 | except: 1124 | with open(self.log_folder / 'log', 'a+') as fd: 1125 | fd.write(content + '\n') 1126 | fd.flush() 1127 | 1128 | @classmethod 1129 | @param('training.distributed') 1130 | @param('dist.world_size') 1131 | def launch_from_args(cls, distributed, world_size): 1132 | if distributed: 1133 | ch.multiprocessing.spawn(cls._exec_wrapper, nprocs=world_size, join=True) 1134 | else: 1135 | cls.exec(0) 1136 | 1137 | @classmethod 1138 | def _exec_wrapper(cls, *args, **kwargs): 1139 | make_config(quiet=True) 1140 | cls.exec(*args, **kwargs) 1141 | 1142 | @classmethod 1143 | @param('training.distributed') 1144 | @param('training.eval_only') 1145 | def exec(cls, gpu, distributed, eval_only): 1146 | trainer = cls(gpu=gpu) 1147 | if eval_only: 1148 | trainer.eval_and_log() 1149 | else: 1150 | trainer.train() 1151 | if distributed: 1152 | trainer.cleanup_distributed() 1153 | 1154 | # Utils 1155 | class MeanScalarMetric(torchmetrics.Metric): 1156 | def __init__(self, *args, **kwargs): 1157 | super().__init__(*args, **kwargs) 1158 | 1159 | self.add_state('sum', default=ch.tensor(0.), dist_reduce_fx='sum') 1160 | self.add_state('count', default=ch.tensor(0), dist_reduce_fx='sum') 1161 | 1162 | def update(self, sample: ch.Tensor): 1163 | self.sum += sample.sum() 1164 | self.count += sample.numel() 1165 | 1166 | def compute(self): 1167 | return self.sum.float() / self.count 1168 | 1169 | # Running 1170 | def make_config(quiet=False): 1171 | config = get_current_config() 1172 | parser = ArgumentParser(description='Fast imagenet training') 1173 | config.augment_argparse(parser) 1174 | config.collect_argparse_args(parser) 1175 | config.validate(mode='stderr') 1176 | if not quiet: 1177 | config.summary() 1178 | 1179 | if __name__ == "__main__": 1180 | make_config() 1181 | ImageNetTrainer.launch_from_args() 1182 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.7.1" 2 | from .EffNet import EfficientNet, VALID_MODELS 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_, DropPath 13 | from timm.models.registry import register_model 14 | 15 | class Block(nn.Module): 16 | r""" ConvNeXt Block. 
There are two equivalent implementations: 17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 19 | We use (2) as we find it slightly faster in PyTorch 20 | 21 | Args: 22 | dim (int): Number of input channels. 23 | drop_path (float): Stochastic depth rate. Default: 0.0 24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 25 | """ 26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 27 | super().__init__() 28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.norm = LayerNorm(dim, eps=1e-6) 30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 31 | self.act = nn.GELU() 32 | self.pwconv2 = nn.Linear(4 * dim, dim) 33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 34 | requires_grad=True) if layer_scale_init_value > 0 else None 35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 41 | x = self.norm(x) 42 | x = self.pwconv1(x) 43 | x = self.act(x) 44 | x = self.pwconv2(x) 45 | if self.gamma is not None: 46 | x = self.gamma * x 47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 48 | 49 | x = input + self.drop_path(x) 50 | return x 51 | 52 | class ConvNeXt(nn.Module): 53 | r""" ConvNeXt 54 | A PyTorch impl of : `A ConvNet for the 2020s` - 55 | https://arxiv.org/pdf/2201.03545.pdf 56 | 57 | Args: 58 | in_chans (int): Number of input image channels. Default: 3 59 | num_classes (int): Number of classes for classification head. Default: 1000 60 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 61 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 62 | drop_path_rate (float): Stochastic depth rate. Default: 0. 63 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 64 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
65 | """ 66 | def __init__(self, in_chans=3, num_classes=1000, 67 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 68 | layer_scale_init_value=1e-6, head_init_scale=1., 69 | ): 70 | super().__init__() 71 | 72 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 73 | stem = nn.Sequential( 74 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 75 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 76 | ) 77 | self.downsample_layers.append(stem) 78 | for i in range(3): 79 | downsample_layer = nn.Sequential( 80 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 81 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 82 | ) 83 | self.downsample_layers.append(downsample_layer) 84 | 85 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 86 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 87 | cur = 0 88 | for i in range(4): 89 | stage = nn.Sequential( 90 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 91 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 92 | ) 93 | self.stages.append(stage) 94 | cur += depths[i] 95 | 96 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 97 | self.head = nn.Linear(dims[-1], num_classes) 98 | 99 | self.apply(self._init_weights) 100 | self.head.weight.data.mul_(head_init_scale) 101 | self.head.bias.data.mul_(head_init_scale) 102 | 103 | def _init_weights(self, m): 104 | if isinstance(m, (nn.Conv2d, nn.Linear)): 105 | trunc_normal_(m.weight, std=.02) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | def forward_features(self, x): 109 | for i in range(4): 110 | x = self.downsample_layers[i](x) 111 | x = self.stages[i](x) 112 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 113 | 114 | def forward(self, x): 115 | x = self.forward_features(x) 116 | x = self.head(x) 117 | return x 118 | 119 | class LayerNorm(nn.Module): 120 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 121 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 122 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 123 | with shape (batch_size, channels, height, width). 
124 | """ 125 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 126 | super().__init__() 127 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 128 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 129 | self.eps = eps 130 | self.data_format = data_format 131 | if self.data_format not in ["channels_last", "channels_first"]: 132 | raise NotImplementedError 133 | self.normalized_shape = (normalized_shape, ) 134 | 135 | def forward(self, x): 136 | if self.data_format == "channels_last": 137 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 138 | elif self.data_format == "channels_first": 139 | u = x.mean(1, keepdim=True) 140 | s = (x - u).pow(2).mean(1, keepdim=True) 141 | x = (x - u) / torch.sqrt(s + self.eps) 142 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 143 | return x 144 | 145 | 146 | model_urls = { 147 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 148 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 149 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 150 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 151 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 152 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 153 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 154 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 155 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 156 | } 157 | 158 | @register_model 159 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 160 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 161 | if pretrained: 162 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 163 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 164 | model.load_state_dict(checkpoint["model"]) 165 | return model 166 | 167 | @register_model 168 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 169 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 170 | if pretrained: 171 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 173 | model.load_state_dict(checkpoint["model"]) 174 | return model 175 | 176 | @register_model 177 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 178 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 179 | if pretrained: 180 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 182 | model.load_state_dict(checkpoint["model"]) 183 | return model 184 | 185 | @register_model 186 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 188 | if pretrained: 189 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 191 | 
model.load_state_dict(checkpoint["model"]) 192 | return model 193 | 194 | @register_model 195 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 197 | if pretrained: 198 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 199 | url = model_urls['convnext_xlarge_22k'] 200 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 201 | model.load_state_dict(checkpoint["model"]) 202 | return model -------------------------------------------------------------------------------- /models/convnext_iso.py: -------------------------------------------------------------------------------- 1 | '''Taken from convnext-github repo as is''' 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | # All rights reserved. 6 | 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | 11 | from functools import partial 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from timm.models.layers import trunc_normal_, DropPath 16 | from timm.models.registry import register_model 17 | from models.convnext import Block, LayerNorm 18 | 19 | class ConvNeXtIsotropic(nn.Module): 20 | r""" ConvNeXt 21 | A PyTorch impl of : `A ConvNet for the 2020s` - 22 | https://arxiv.org/pdf/2201.03545.pdf 23 | Isotropic ConvNeXts (Section 3.3 in paper) 24 | 25 | Args: 26 | in_chans (int): Number of input image channels. Default: 3 27 | num_classes (int): Number of classes for classification head. Default: 1000 28 | depth (tuple(int)): Number of blocks. Default: 18. 29 | dims (int): Feature dimension. Default: 384 30 | drop_path_rate (float): Stochastic depth rate. Default: 0. 31 | layer_scale_init_value (float): Init value for Layer Scale. Default: 0. 32 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
33 | """ 34 | def __init__(self, in_chans=3, num_classes=1000, 35 | depth=18, dim=384, drop_path_rate=0., 36 | layer_scale_init_value=0, head_init_scale=1., 37 | ): 38 | super().__init__() 39 | 40 | self.stem = nn.Conv2d(in_chans, dim, kernel_size=16, stride=16) 41 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, depth)] 42 | self.blocks = nn.Sequential(*[Block(dim=dim, drop_path=dp_rates[i], 43 | layer_scale_init_value=layer_scale_init_value) 44 | for i in range(depth)]) 45 | 46 | self.norm = LayerNorm(dim, eps=1e-6) # final norm layer 47 | self.head = nn.Linear(dim, num_classes) 48 | 49 | self.apply(self._init_weights) 50 | self.head.weight.data.mul_(head_init_scale) 51 | self.head.bias.data.mul_(head_init_scale) 52 | 53 | def _init_weights(self, m): 54 | if isinstance(m, (nn.Conv2d, nn.Linear)): 55 | trunc_normal_(m.weight, std=.02) 56 | nn.init.constant_(m.bias, 0) 57 | 58 | def forward_features(self, x): 59 | x = self.stem(x) 60 | x = self.blocks(x) 61 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 62 | 63 | def forward(self, x): 64 | x = self.forward_features(x) 65 | x = self.head(x) 66 | return x 67 | 68 | @register_model 69 | def convnext_isotropic_small(pretrained=False, dim=384, depth=18, **kwargs): 70 | model = ConvNeXtIsotropic(depth=depth, dim=dim, **kwargs) 71 | if pretrained: 72 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_small_1k_224_ema.pth' 73 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 74 | model.load_state_dict(checkpoint["model"]) 75 | return model 76 | 77 | @register_model 78 | def convnext_isotropic_base(pretrained=False, **kwargs): 79 | model = ConvNeXtIsotropic(depth=18, dim=768, **kwargs) 80 | if pretrained: 81 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_base_1k_224_ema.pth' 82 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 83 | model.load_state_dict(checkpoint["model"]) 84 | return model 85 | 86 | @register_model 87 | def convnext_isotropic_large(pretrained=False, **kwargs): 88 | model = ConvNeXtIsotropic(depth=36, dim=1024, **kwargs) 89 | if pretrained: 90 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_large_1k_224_ema.pth' 91 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 92 | model.load_state_dict(checkpoint["model"]) 93 | return model -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions for building the model and for loading model parameters. 2 | These helper functions are built to mirror those in the official TensorFlow implementation. 3 | """ 4 | 5 | # Author: lukemelas (github username) 6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch 7 | # With adjustments and added comments by workingcoder (github username). 
8 | 9 | import re 10 | import math 11 | import collections 12 | from functools import partial 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.utils import model_zoo 17 | 18 | 19 | ################################################################################ 20 | # Help functions for model architecture 21 | ################################################################################ 22 | 23 | # GlobalParams and BlockArgs: Two namedtuples 24 | # Swish and MemoryEfficientSwish: Two implementations of the method 25 | # round_filters and round_repeats: 26 | # Functions to calculate params for scaling model width and depth ! ! ! 27 | # get_width_and_height_from_size and calculate_output_image_size 28 | # drop_connect: A structural design 29 | # get_same_padding_conv2d: 30 | # Conv2dDynamicSamePadding 31 | # Conv2dStaticSamePadding 32 | # get_same_padding_maxPool2d: 33 | # MaxPool2dDynamicSamePadding 34 | # MaxPool2dStaticSamePadding 35 | # It's an additional function, not used in EfficientNet, 36 | # but can be used in other model (such as EfficientDet). 37 | 38 | # Parameters for the entire model (stem, all blocks, and head) 39 | GlobalParams = collections.namedtuple('GlobalParams', [ 40 | 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', 41 | 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', 42 | 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) 43 | 44 | # Parameters for an individual model block 45 | BlockArgs = collections.namedtuple('BlockArgs', [ 46 | 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', 47 | 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) 48 | 49 | # Set GlobalParams and BlockArgs's defaults 50 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 51 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 52 | 53 | # Swish activation function 54 | if hasattr(nn, 'SiLU'): 55 | Swish = nn.SiLU 56 | else: 57 | # For compatibility with old PyTorch versions 58 | class Swish(nn.Module): 59 | def forward(self, x): 60 | return x * torch.sigmoid(x) 61 | 62 | 63 | # A memory-efficient implementation of Swish function 64 | class SwishImplementation(torch.autograd.Function): 65 | @staticmethod 66 | def forward(ctx, i): 67 | result = i * torch.sigmoid(i) 68 | ctx.save_for_backward(i) 69 | return result 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | i = ctx.saved_tensors[0] 74 | sigmoid_i = torch.sigmoid(i) 75 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 76 | 77 | 78 | class MemoryEfficientSwish(nn.Module): 79 | def forward(self, x): 80 | return SwishImplementation.apply(x) 81 | 82 | 83 | def round_filters(filters, global_params): 84 | """Calculate and round number of filters based on width multiplier. 85 | Use width_coefficient, depth_divisor and min_depth of global_params. 86 | 87 | Args: 88 | filters (int): Filters number to be calculated. 89 | global_params (namedtuple): Global params of the model. 90 | 91 | Returns: 92 | new_filters: New filters number after calculating. 93 | """ 94 | multiplier = global_params.width_coefficient 95 | if not multiplier: 96 | return filters 97 | # TODO: modify the params names. 98 | # maybe the names (width_divisor,min_width) 99 | # are more suitable than (depth_divisor,min_depth). 
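# Worked example of the rounding below (illustrative values, not from the original code):
# with width_coefficient = 1.4 and depth_divisor = 8, filters = 32 becomes 44.8;
# int(44.8 + 8 / 2) // 8 * 8 = 48, and since 48 >= 0.9 * 44.8 no extra divisor is added,
# so round_filters returns 48.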
100 | divisor = global_params.depth_divisor 101 | min_depth = global_params.min_depth 102 | filters *= multiplier 103 | min_depth = min_depth or divisor # pay attention to this line when using min_depth 104 | # follow the formula transferred from official TensorFlow implementation 105 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 106 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 107 | new_filters += divisor 108 | return int(new_filters) 109 | 110 | 111 | def round_repeats(repeats, global_params): 112 | """Calculate module's repeat number of a block based on depth multiplier. 113 | Use depth_coefficient of global_params. 114 | 115 | Args: 116 | repeats (int): num_repeat to be calculated. 117 | global_params (namedtuple): Global params of the model. 118 | 119 | Returns: 120 | new repeat: New repeat number after calculating. 121 | """ 122 | multiplier = global_params.depth_coefficient 123 | if not multiplier: 124 | return repeats 125 | # follow the formula transferred from official TensorFlow implementation 126 | return int(math.ceil(multiplier * repeats)) 127 | 128 | 129 | def drop_connect(inputs, p, training): 130 | """Drop connect. 131 | 132 | Args: 133 | input (tensor: BCWH): Input of this structure. 134 | p (float: 0.0~1.0): Probability of drop connection. 135 | training (bool): The running mode. 136 | 137 | Returns: 138 | output: Output after drop connection. 139 | """ 140 | assert 0 <= p <= 1, 'p must be in range of [0,1]' 141 | 142 | if not training: 143 | return inputs 144 | 145 | batch_size = inputs.shape[0] 146 | keep_prob = 1 - p 147 | 148 | # generate binary_tensor mask according to probability (p for 0, 1-p for 1) 149 | random_tensor = keep_prob 150 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 151 | binary_tensor = torch.floor(random_tensor) 152 | 153 | output = inputs / keep_prob * binary_tensor 154 | return output 155 | 156 | 157 | def get_width_and_height_from_size(x): 158 | """Obtain height and width from x. 159 | 160 | Args: 161 | x (int, tuple or list): Data size. 162 | 163 | Returns: 164 | size: A tuple or list (H,W). 165 | """ 166 | if isinstance(x, int): 167 | return x, x 168 | if isinstance(x, list) or isinstance(x, tuple): 169 | return x 170 | else: 171 | raise TypeError() 172 | 173 | 174 | def calculate_output_image_size(input_image_size, stride): 175 | """Calculates the output image size when using Conv2dSamePadding with a stride. 176 | Necessary for static padding. Thanks to mannatsingh for pointing this out. 177 | 178 | Args: 179 | input_image_size (int, tuple or list): Size of input image. 180 | stride (int, tuple or list): Conv2d operation's stride. 181 | 182 | Returns: 183 | output_image_size: A list [H,W]. 184 | """ 185 | if input_image_size is None: 186 | return None 187 | image_height, image_width = get_width_and_height_from_size(input_image_size) 188 | stride = stride if isinstance(stride, int) else stride[0] 189 | image_height = int(math.ceil(image_height / stride)) 190 | image_width = int(math.ceil(image_width / stride)) 191 | return [image_height, image_width] 192 | 193 | 194 | # Note: 195 | # The following 'SamePadding' functions make output size equal ceil(input size/stride). 196 | # Only when stride equals 1, can the output size be the same as input size. 197 | # Don't be confused by their function names ! ! ! 
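# Worked example of the 'SAME' padding note above (illustrative values, not part of
# the original file): for input size 224, kernel 3, stride 2, dilation 1,
#   output = ceil(224 / 2) = 112
#   pad    = max((112 - 1) * 2 + (3 - 1) * 1 + 1 - 224, 0) = 1
# F.pad then splits this as (0, 1), i.e. one extra row/column on the bottom/right,
# and the convolution output is exactly ceil(input / stride) = 112.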
198 | 199 | def get_same_padding_conv2d(image_size=None): 200 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise. 201 | Static padding is necessary for ONNX exporting of models. 202 | 203 | Args: 204 | image_size (int or tuple): Size of the image. 205 | 206 | Returns: 207 | Conv2dDynamicSamePadding or Conv2dStaticSamePadding. 208 | """ 209 | if image_size is None: 210 | return Conv2dDynamicSamePadding 211 | else: 212 | return partial(Conv2dStaticSamePadding, image_size=image_size) 213 | 214 | 215 | class Conv2dDynamicSamePadding(nn.Conv2d): 216 | """2D Convolutions like TensorFlow, for a dynamic image size. 217 | The padding is operated in forward function by calculating dynamically. 218 | """ 219 | 220 | # Tips for 'SAME' mode padding. 221 | # Given the following: 222 | # i: width or height 223 | # s: stride 224 | # k: kernel size 225 | # d: dilation 226 | # p: padding 227 | # Output after Conv2d: 228 | # o = floor((i+p-((k-1)*d+1))/s+1) 229 | # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), 230 | # => p = (i-1)*s+((k-1)*d+1)-i 231 | 232 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 233 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 234 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 235 | 236 | def forward(self, x): 237 | ih, iw = x.size()[-2:] 238 | kh, kw = self.weight.size()[-2:] 239 | sh, sw = self.stride 240 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! 241 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 242 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 243 | if pad_h > 0 or pad_w > 0: 244 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 245 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 246 | 247 | 248 | class Conv2dStaticSamePadding(nn.Conv2d): 249 | """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. 250 | The padding mudule is calculated in construction function, then used in forward. 
251 | """ 252 | 253 | # With the same calculation as Conv2dDynamicSamePadding 254 | 255 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): 256 | super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) 257 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 258 | 259 | # Calculate padding based on image size and save it 260 | assert image_size is not None 261 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size 262 | kh, kw = self.weight.size()[-2:] 263 | sh, sw = self.stride 264 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 265 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 266 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 267 | if pad_h > 0 or pad_w > 0: 268 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, 269 | pad_h // 2, pad_h - pad_h // 2)) 270 | else: 271 | self.static_padding = nn.Identity() 272 | 273 | def forward(self, x): 274 | x = self.static_padding(x) 275 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 276 | return x 277 | 278 | 279 | def get_same_padding_maxPool2d(image_size=None): 280 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise. 281 | Static padding is necessary for ONNX exporting of models. 282 | 283 | Args: 284 | image_size (int or tuple): Size of the image. 285 | 286 | Returns: 287 | MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. 288 | """ 289 | if image_size is None: 290 | return MaxPool2dDynamicSamePadding 291 | else: 292 | return partial(MaxPool2dStaticSamePadding, image_size=image_size) 293 | 294 | 295 | class MaxPool2dDynamicSamePadding(nn.MaxPool2d): 296 | """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. 297 | The padding is operated in forward function by calculating dynamically. 298 | """ 299 | 300 | def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): 301 | super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) 302 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride 303 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size 304 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation 305 | 306 | def forward(self, x): 307 | ih, iw = x.size()[-2:] 308 | kh, kw = self.kernel_size 309 | sh, sw = self.stride 310 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 311 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 312 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 313 | if pad_h > 0 or pad_w > 0: 314 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 315 | return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, 316 | self.dilation, self.ceil_mode, self.return_indices) 317 | 318 | 319 | class MaxPool2dStaticSamePadding(nn.MaxPool2d): 320 | """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. 321 | The padding mudule is calculated in construction function, then used in forward. 
322 | """ 323 | 324 | def __init__(self, kernel_size, stride, image_size=None, **kwargs): 325 | super().__init__(kernel_size, stride, **kwargs) 326 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride 327 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size 328 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation 329 | 330 | # Calculate padding based on image size and save it 331 | assert image_size is not None 332 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size 333 | kh, kw = self.kernel_size 334 | sh, sw = self.stride 335 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 336 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 337 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 338 | if pad_h > 0 or pad_w > 0: 339 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 340 | else: 341 | self.static_padding = nn.Identity() 342 | 343 | def forward(self, x): 344 | x = self.static_padding(x) 345 | x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, 346 | self.dilation, self.ceil_mode, self.return_indices) 347 | return x 348 | 349 | 350 | ################################################################################ 351 | # Helper functions for loading model params 352 | ################################################################################ 353 | 354 | # BlockDecoder: A Class for encoding and decoding BlockArgs 355 | # efficientnet_params: A function to query compound coefficient 356 | # get_model_params and efficientnet: 357 | # Functions to get BlockArgs and GlobalParams for efficientnet 358 | # url_map and url_map_advprop: Dicts of url_map for pretrained weights 359 | # load_pretrained_weights: A function to load pretrained weights 360 | 361 | class BlockDecoder(object): 362 | """Block Decoder for readability, 363 | straight from the official TensorFlow repository. 364 | """ 365 | 366 | @staticmethod 367 | def _decode_block_string(block_string): 368 | """Get a block through a string notation of arguments. 369 | 370 | Args: 371 | block_string (str): A string notation of arguments. 372 | Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. 373 | 374 | Returns: 375 | BlockArgs: The namedtuple defined at the top of this file. 376 | """ 377 | assert isinstance(block_string, str) 378 | 379 | ops = block_string.split('_') 380 | options = {} 381 | for op in ops: 382 | splits = re.split(r'(\d.*)', op) 383 | if len(splits) >= 2: 384 | key, value = splits[:2] 385 | options[key] = value 386 | 387 | # Check stride 388 | assert (('s' in options and len(options['s']) == 1) or 389 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 390 | 391 | return BlockArgs( 392 | num_repeat=int(options['r']), 393 | kernel_size=int(options['k']), 394 | stride=[int(options['s'][0])], 395 | expand_ratio=int(options['e']), 396 | input_filters=int(options['i']), 397 | output_filters=int(options['o']), 398 | se_ratio=float(options['se']) if 'se' in options else None, 399 | id_skip=('noskip' not in block_string)) 400 | 401 | @staticmethod 402 | def _encode_block_string(block): 403 | """Encode a block to a string. 404 | 405 | Args: 406 | block (namedtuple): A BlockArgs type argument. 407 | 408 | Returns: 409 | block_string: A String form of BlockArgs. 
410 | """ 411 | args = [ 412 | 'r%d' % block.num_repeat, 413 | 'k%d' % block.kernel_size, 414 | 's%d%d' % (block.strides[0], block.strides[1]), 415 | 'e%s' % block.expand_ratio, 416 | 'i%d' % block.input_filters, 417 | 'o%d' % block.output_filters 418 | ] 419 | if 0 < block.se_ratio <= 1: 420 | args.append('se%s' % block.se_ratio) 421 | if block.id_skip is False: 422 | args.append('noskip') 423 | return '_'.join(args) 424 | 425 | @staticmethod 426 | def decode(string_list): 427 | """Decode a list of string notations to specify blocks inside the network. 428 | 429 | Args: 430 | string_list (list[str]): A list of strings, each string is a notation of block. 431 | 432 | Returns: 433 | blocks_args: A list of BlockArgs namedtuples of block args. 434 | """ 435 | assert isinstance(string_list, list) 436 | blocks_args = [] 437 | for block_string in string_list: 438 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 439 | return blocks_args 440 | 441 | @staticmethod 442 | def encode(blocks_args): 443 | """Encode a list of BlockArgs to a list of strings. 444 | 445 | Args: 446 | blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. 447 | 448 | Returns: 449 | block_strings: A list of strings, each string is a notation of block. 450 | """ 451 | block_strings = [] 452 | for block in blocks_args: 453 | block_strings.append(BlockDecoder._encode_block_string(block)) 454 | return block_strings 455 | 456 | 457 | def efficientnet_params(model_name): 458 | """Map EfficientNet model name to parameter coefficients. 459 | 460 | Args: 461 | model_name (str): Model name to be queried. 462 | 463 | Returns: 464 | params_dict[model_name]: A (width,depth,res,dropout) tuple. 465 | """ 466 | params_dict = { 467 | # Coefficients: width,depth,res,dropout 468 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 469 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 470 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 471 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 472 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 473 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 474 | 'efficientnet-b6': (3.8, 5.3, 224, 0.5), 475 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 476 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 477 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 478 | } 479 | return params_dict[model_name] 480 | 481 | 482 | def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, 483 | dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): 484 | """Create BlockArgs and GlobalParams for efficientnet model. 485 | 486 | Args: 487 | width_coefficient (float) 488 | depth_coefficient (float) 489 | image_size (int) 490 | dropout_rate (float) 491 | drop_connect_rate (float) 492 | num_classes (int) 493 | 494 | Meaning as the name suggests. 495 | 496 | Returns: 497 | blocks_args, global_params. 
498 | """ 499 | 500 | # Blocks args for the whole model(efficientnet-b0 by default) 501 | # It will be modified in the construction of EfficientNet Class according to model 502 | blocks_args = [ 503 | 'r1_k3_s11_e1_i32_o16_se0.25', 504 | 'r2_k3_s22_e1_i16_o24_se0.25', 505 | 'r2_k5_s22_e1_i24_o40_se0.25', 506 | 'r3_k3_s22_e1_i40_o80_se0.25', 507 | 'r3_k5_s11_e1_i80_o112_se0.25', 508 | 'r4_k5_s22_e1_i112_o192_se0.25', 509 | 'r1_k3_s11_e1_i192_o320_se0.25', 510 | ] 511 | blocks_args = BlockDecoder.decode(blocks_args) 512 | 513 | global_params = GlobalParams( 514 | width_coefficient=width_coefficient, 515 | depth_coefficient=depth_coefficient, 516 | image_size=image_size, 517 | dropout_rate=dropout_rate, 518 | 519 | num_classes=num_classes, 520 | batch_norm_momentum=0.99, 521 | batch_norm_epsilon=1e-3, 522 | drop_connect_rate=drop_connect_rate, 523 | depth_divisor=8, 524 | min_depth=None, 525 | include_top=include_top, 526 | ) 527 | 528 | return blocks_args, global_params 529 | 530 | 531 | def get_model_params(model_name, override_params): 532 | """Get the block args and global params for a given model name. 533 | 534 | Args: 535 | model_name (str): Model's name. 536 | override_params (dict): A dict to modify global_params. 537 | 538 | Returns: 539 | blocks_args, global_params 540 | """ 541 | if model_name.startswith('efficientnet'): 542 | w, d, s, p = efficientnet_params(model_name) 543 | # note: all models have drop connect rate = 0.2 544 | blocks_args, global_params = efficientnet( 545 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 546 | else: 547 | raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) 548 | if override_params: 549 | # ValueError will be raised here if override_params has fields not included in global_params. 
550 | global_params = global_params._replace(**override_params) 551 | return blocks_args, global_params 552 | 553 | 554 | # train with Standard methods 555 | # check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) 556 | url_map = { 557 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', 558 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', 559 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', 560 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', 561 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', 562 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', 563 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', 564 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', 565 | } 566 | 567 | # train with Adversarial Examples(AdvProp) 568 | # check more details in paper(Adversarial Examples Improve Image Recognition) 569 | url_map_advprop = { 570 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', 571 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', 572 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', 573 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', 574 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', 575 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', 576 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', 577 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', 578 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', 579 | } 580 | 581 | # TODO: add the petrained weights url map of 'efficientnet-l2' 582 | 583 | 584 | def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): 585 | """Loads pretrained weights from weights path or download using url. 586 | 587 | Args: 588 | model (Module): The whole model of efficientnet. 589 | model_name (str): Model name of efficientnet. 590 | weights_path (None or str): 591 | str: path to pretrained weights file on the local disk. 592 | None: use pretrained weights downloaded from the Internet. 593 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 594 | advprop (bool): Whether to load pretrained weights 595 | trained with advprop (valid when weights_path is None). 
596 | """ 597 | if isinstance(weights_path, str): 598 | state_dict = torch.load(weights_path) 599 | else: 600 | # AutoAugment or Advprop (different preprocessing) 601 | url_map_ = url_map_advprop if advprop else url_map 602 | state_dict = model_zoo.load_url(url_map_[model_name]) 603 | 604 | if load_fc: 605 | ret = model.load_state_dict(state_dict, strict=False) 606 | assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 607 | else: 608 | state_dict.pop('_fc.weight') 609 | state_dict.pop('_fc.bias') 610 | ret = model.load_state_dict(state_dict, strict=False) 611 | assert set(ret.missing_keys) == set( 612 | ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 613 | assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 614 | 615 | if verbose: 616 | print('Loaded pretrained weights for {}'.format(model_name)) -------------------------------------------------------------------------------- /parserr.py: -------------------------------------------------------------------------------- 1 | '''Helper script to generate the parameter values for augmentations 2 | ''' 3 | 4 | import argparse 5 | 6 | def str2bool(v): 7 | 8 | if isinstance(v, bool): 9 | return v 10 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 11 | return True 12 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 13 | return False 14 | else: 15 | raise argparse.ArgumentTypeError('Boolean value expected.') 16 | 17 | class Arguments_augment(): 18 | def __init__(self): 19 | self.color_jitter = 0.4 20 | self.aa = 'rand-m9-mstd0.5-inc1' 21 | self.train_interpolation = 'bicubic' 22 | self.crop_pct = None 23 | self.reprob = 0.25 24 | self.remode = 'pixel' 25 | self.recount = 1 26 | self.resplit = False 27 | self.mixup = 0.8 28 | self.cutmix = 1.0 29 | self.cutmix_minmax = None 30 | self.mixup_prob = 1.0 31 | self.mixup_switch_prob = 0.5 32 | self.mixup_mode = 'batch' 33 | self.nb_classes = 1000 34 | self.input_size = 224 35 | self.data_set = 'IMNET' 36 | self.dist_eval = True 37 | self.hflip = 0.5 38 | self.vflip = 0.0 39 | self.scale = [0.08, 1.0] 40 | self.ratio = [3./4., 4./3.] 41 | 42 | 43 | 44 | class Arguments_No_augment(): 45 | def __init__(self): 46 | self.color_jitter = 0.0 47 | self.aa = None #'rand-m9-mstd0.5-inc1' 48 | self.train_interpolation = 'bicubic' 49 | self.crop_pct = None 50 | self.reprob = 0.0 51 | self.remode = None #'pixel' 52 | self.recount = 0 53 | self.resplit = False 54 | self.mixup = 0.0 55 | self.cutmix = 0. 56 | self.cutmix_minmax = None 57 | self.mixup_prob = 0.0 58 | self.mixup_switch_prob = 0. 59 | self.mixup_mode = None 60 | self.nb_classes = 1000 61 | self.data_set = 'IMNET' 62 | self.input_size = 224 63 | self.dist_eval = True 64 | self.hflip = 0.0 65 | self.vflip = 0.0 66 | self.scale = [0.08, 1.0] 67 | self.ratio = [3./4., 4./3.] 68 | -------------------------------------------------------------------------------- /rb_architecture_util.py: -------------------------------------------------------------------------------- 1 | '''All the submitted to RBench are defined here. 
2 | Mostly use timm but have some custom implementation: ConvStem (ConvBlock) 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from collections import OrderedDict 8 | from typing import Tuple 9 | from torch import Tensor 10 | import torch.nn as nn 11 | 12 | import timm 13 | from timm.models import create_model 14 | import torch.nn.functional as F 15 | import math 16 | 17 | IMAGENET_MEAN = [c * 1. for c in (0.485, 0.456, 0.406)] 18 | IMAGENET_STD = [c * 1. for c in (0.229, 0.224, 0.225)] 19 | 20 | 21 | class LayerNorm(nn.Module): 22 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 23 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 24 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 25 | with shape (batch_size, channels, height, width). 26 | """ 27 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"): 28 | super().__init__() 29 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 30 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 31 | self.eps = eps 32 | self.data_format = data_format 33 | if self.data_format not in ["channels_last", "channels_first"]: 34 | raise NotImplementedError 35 | self.normalized_shape = (normalized_shape, ) 36 | 37 | def forward(self, x): 38 | if self.data_format == "channels_last": 39 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 40 | elif self.data_format == "channels_first": 41 | u = x.mean(1, keepdim=True) 42 | s = (x - u).pow(2).mean(1, keepdim=True) 43 | x = (x - u) / torch.sqrt(s + self.eps) 44 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 45 | return x 46 | 47 | 48 | class ImageNormalizer(nn.Module): 49 | '''ADD normalization as a first layer in the models, as AA uses un-normalized inputs.''' 50 | 51 | def __init__(self, persistent: bool = True) -> None: 52 | super(ImageNormalizer, self).__init__() 53 | 54 | self.register_buffer('mean', torch.as_tensor(MEAN).view(1, 3, 1, 1), 55 | persistent=persistent) 56 | self.register_buffer('std', torch.as_tensor(STD).view(1, 3, 1, 1), 57 | persistent=persistent) 58 | 59 | def forward(self, input: Tensor) -> Tensor: 60 | return (input - self.mean) / self.std 61 | 62 | 63 | def normalize_model(model: nn.Module) -> nn.Module: 64 | layers = OrderedDict([ 65 | ('normalize', ImageNormalizer()), 66 | ('model', model) 67 | ]) 68 | return nn.Sequential(layers) 69 | 70 | 71 | 72 | def get_transforms(img_size=224): 73 | '''returns torch-transform as a callable for RobustBench. 
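Resizes the shorter side to img_size / 0.875 with bicubic interpolation, center-crops to img_size and converts to a [0, 1] tensor; no normalization is applied here, since it is added inside the model (see ImageNormalizer above).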
Used when testing for increased resolution models.''' 74 | crop_pct = 0.875 75 | scale_size = int(math.floor(img_size / crop_pct)) 76 | trans = transforms.Compose([ 77 | transforms.Resize( 78 | scale_size, 79 | interpolation=transforms.InterpolationMode("bicubic")), 80 | transforms.CenterCrop(img_size), 81 | transforms.ToTensor() 82 | ]) 83 | 84 | return trans 85 | 86 | 87 | class ConvBlock(nn.Module): 88 | expansion = 1 89 | def __init__(self, siz=48, end_siz=8, fin_dim=384): 90 | super(ConvBlock, self).__init__() 91 | self.planes = siz 92 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432 93 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else nn.GroupNorm(num_groups=1, num_channels=planes) 94 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 95 | LayerNorm(self.planes, data_format="channels_first"), 96 | nn.GELU(), 97 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1), 98 | LayerNorm(self.planes*2, data_format="channels_first"), 99 | nn.GELU(), 100 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1), 101 | LayerNorm(self.planes*4, data_format="channels_first"), 102 | nn.GELU(), 103 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1), 104 | LayerNorm(self.planes*8, data_format="channels_first"), 105 | nn.GELU(), 106 | nn.Conv2d(self.planes*8, fin_dim, kernel_size=1, stride=1, padding=0) 107 | ) 108 | def forward(self, x): 109 | out = self.stem(x) 110 | # out = self.bn(out) 111 | return out 112 | 113 | 114 | class ConvBlock3(nn.Module): 115 | # expansion = 1 116 | def __init__(self, siz=64): 117 | super(ConvBlock3, self).__init__() 118 | self.planes = siz 119 | 120 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 121 | LayerNorm(self.planes, data_format="channels_first"), 122 | nn.GELU(), 123 | nn.Conv2d(self.planes, int(self.planes*1.5), kernel_size=3, stride=2, padding=1), 124 | LayerNorm(int(self.planes*1.5), data_format="channels_first"), 125 | nn.GELU(), 126 | nn.Conv2d(int(self.planes*1.5), self.planes*2, kernel_size=3, stride=1, padding=1), 127 | LayerNorm(self.planes*2, data_format="channels_first"), 128 | nn.GELU() 129 | ) 130 | 131 | def forward(self, x): 132 | out = self.stem(x) 133 | # out = self.bn(out) 134 | return out 135 | 136 | 137 | class ConvBlock1(nn.Module): 138 | def __init__(self, siz=48, end_siz=8, fin_dim=384): 139 | super(ConvBlock1, self).__init__() 140 | self.planes = siz 141 | 142 | fin_dim = self.planes*end_siz if fin_dim == None else 432 143 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 144 | LayerNorm(self.planes, data_format="channels_first"), 145 | nn.GELU(), 146 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1), 147 | LayerNorm(self.planes*2, data_format="channels_first"), 148 | nn.GELU() 149 | ) 150 | 151 | def forward(self, x): 152 | out = self.stem(x) 153 | # out = self.bn(out) 154 | return out 155 | 156 | 157 | class IdentityLayer(nn.Module): 158 | def forward(self, inputs): 159 | return inputs 160 | 161 | 162 | def get_new_model(modelname, pretrained=False, not_original=True): 163 | 164 | if modelname == 'convnext_t_cvst': 165 | model = timm.models.convnext.convnext_tiny(pretrained=pretrained) 166 | model.stem = ConvBlock1(48, end_siz=8) 167 | 168 | elif modelname == "convnext_s_cvst": 169 | model = timm.models.convnext.convnext_small(pretrained=pretrained) 170 | model.stem = ConvBlock1(48, end_siz=8) 
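# Note: every *_cvst branch in this function follows the same pattern -- build the timm model,
# then replace its patchify stem (model.stem or patch_embed.proj) with a convolutional stem.
# Illustrative usage:
#     model = get_new_model('convnext_t_cvst', pretrained=False)
#     isinstance(model.stem, ConvBlock1)  # -> True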
171 | 172 | elif modelname == "convnext_b_cvst": 173 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]) 174 | model = timm.models.convnext._create_convnext('convnext_base.fb_in1k', pretrained=pretrained, **model_args) 175 | model.stem = ConvBlock3(64) 176 | 177 | elif modelname == "convnext_l_cvst": 178 | model = timm.models.convnext_large(pretrained=pretrained) 179 | model.stem = ConvBlock3(96) 180 | 181 | elif modelname == 'vit_s_cvst': 182 | model = create_model('deit_small_patch16_224', pretrained=pretrained) 183 | model.patch_embed.proj = ConvBlock(48, end_siz=8) 184 | model = normalize_model(model) 185 | 186 | elif modelname == 'vit_b_cvst': 187 | model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained) 188 | model.patch_embed.proj = ConvBlock(48, end_siz=16, fin_dim=None) 189 | 190 | else: 191 | logger.error('Invalid model name, please use either cait, deit, swin, vit, effnet, or rn50') 192 | sys.exit(1) 193 | return model 194 | 195 | def load_model(arch, not_original, chkpt_path): 196 | '''' 197 | Load the model with definition from the checkpoint 198 | arch: architecture name 199 | not_original: If True -> CvSt 200 | chkpt_path: location of checkpoint 201 | ''' 202 | model = get_new_model(arch, pretrained=False, not_original=not_original) 203 | ckpt = torch.load(chkpt_path, map_location='cpu') 204 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} 205 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()} 206 | ckpt = {k.replace('se_', 'se_module.'): v for k, v in ckpt.items()} 207 | 208 | model.load_state_dict(ckpt) 209 | model = model.to(device) 210 | model.eval() 211 | return model 212 | -------------------------------------------------------------------------------- /readme_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmndeep/revisiting-at/d7166c074223b89e3b7e1ec1489a8d069cc5afb7/readme_teaser.png -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES = DEVICE_IDS separated by comma 4 | 5 | sleep 2s 6 | 7 | # Set the visible GPUs according to the `world_size` configuration parameter 8 | # Modify `data.in_memory` and `data.num_workers` based on your machine 9 | 10 | python3 main.py --data.num_workers=12 --data.in_memory=1 \ 11 | --data.train_dataset=path-to-imagenet-train-set \ 12 | --data.val_dataset=path-to-imagenet-val-set \ 13 | --logging.folder=path-to-logging-folder --logging.log_level 2 \ 14 | --adv.attack apgd --adv.n_iter 2 --adv.norm Linf --training.distributed 1 --training.batch_size 80 --lr.lr 1e-3 --logging.save_freq 2 \ 15 | --resolution.min_res 224 --resolution.max_res 224 --data.seed 0 --data.augmentations 1 --model.add_normalization 0\ 16 | --model.not_original 1 --model.model_ema 1 --lr.lr_peak_epoch 20\ 17 | --training.label_smoothing 0.1 --logging.addendum='additional_text_appended_to_save_folder_name'\ 18 | --dist.world_size '# of GPUS' --training.distributed 1 --model.pretrained 0 --model.arch convnext_base --training.epochs 300 \ 19 | 20 | -------------------------------------------------------------------------------- /runner_aa_eval.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import time 4 | from subprocess import call 5 | import GPUtil 6 | 7 | 8 | paramss = [ 9 | 
("location_of_robust_model","convnext_base", 1, 1, 0, "Linf", 100), 10 | ] 11 | str_to_run = 'AA_eval.py' 12 | models_to_run = [] 13 | 14 | for minn, modd, not_o, a100, fullaa, norr, bs in paramss: 15 | models_to_run.append('{} --model_in {} --mod {} --not-orig {} --a100 {} --full_aa {} --l_norms {} --batch_size {}' 16 | .format(str_to_run, minn, modd, not_o, a100, fullaa, norr, bs)) 17 | 18 | cart_prod = [a \ 19 | for a in models_to_run] 20 | 21 | for job in cart_prod: 22 | print(job) 23 | 24 | time.sleep(5) 25 | count = 0 26 | wait = 0 27 | while wait<=1: 28 | gpu_ids = GPUtil.getAvailable(order = 'last', limit = 8, \ 29 | maxLoad = .1, maxMemory = .5) # get free gpus listd 30 | if len(gpu_ids) > 0: 31 | print(gpu_ids) 32 | # time.sleep(5) 33 | for id in gpu_ids: 34 | if id == 5: 35 | pass 36 | else: 37 | if id != 10: 38 | temp_list = cart_prod[count] 39 | 40 | command_to_exec = '' +\ 41 | ' CUDA_VISIBLE_DEVICES='+str(id)+\ 42 | ' python3' +\ 43 | ' ' + temp_list\ 44 | + ' &' # for going to next iteration without job in background. 45 | 46 | print("Command executing is " + command_to_exec) 47 | call(command_to_exec, shell=True) 48 | print('done executing in '+str(id)) 49 | count += 1 50 | time.sleep(2) # wait for processes to start 51 | # if count == 3: 52 | # time.sleep(7200) 53 | else: 54 | print('No gpus free waiting for 30 seconds') 55 | time.sleep(15) 56 | wait+=1 57 | # time.sleep(3600*3) # wait for processes to start 58 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import requests 5 | import torch 6 | from collections import OrderedDict 7 | from model_zoo.models import model_dicts 8 | #from models_new import l_models_all, l_models_imagenet 9 | 10 | def download_gdrive(gdrive_id, fname_save): 11 | """ source: https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url """ 12 | def get_confirm_token(response): 13 | for key, value in response.cookies.items(): 14 | if key.startswith('download_warning'): 15 | return value 16 | 17 | return None 18 | 19 | def save_response_content(response, fname_save): 20 | CHUNK_SIZE = 32768 21 | 22 | with open(fname_save, "wb") as f: 23 | for chunk in response.iter_content(CHUNK_SIZE): 24 | if chunk: # filter out keep-alive new chunks 25 | f.write(chunk) 26 | 27 | print('Download started: path={} (gdrive_id={})'.format(fname_save, gdrive_id)) 28 | 29 | url_base = "https://docs.google.com/uc?export=download&confirm=t" 30 | session = requests.Session() 31 | 32 | response = session.get(url_base, params={'id': gdrive_id}, stream=True) 33 | token = get_confirm_token(response) 34 | 35 | if token: 36 | params = {'id': gdrive_id, 'confirm': token} 37 | response = session.get(url_base, params=params, stream=True) 38 | 39 | save_response_content(response, fname_save) 40 | print('Download finished: path={} (gdrive_id={})'.format(fname_save, gdrive_id)) 41 | 42 | 43 | def rm_substr_from_state_dict(state_dict, substr): 44 | new_state_dict = OrderedDict() 45 | for key in state_dict.keys(): 46 | if substr in key: # to delete prefix 'module.' 
if it exists 47 | new_key = key[len(substr):] 48 | new_state_dict[new_key] = state_dict[key] 49 | else: 50 | new_state_dict[key] = state_dict[key] 51 | return new_state_dict 52 | 53 | 54 | def load_model(model_name, model_dir='./models', norm='Linf'): 55 | model_dir_norm = '{}/{}'.format(model_dir, norm) 56 | if not isinstance(model_dicts[norm][model_name]['gdrive_id'], list): 57 | model_path = '{}/{}/{}.pt'.format(model_dir, norm, model_name) 58 | model = model_dicts[norm][model_name]['model']() 59 | if not os.path.exists(model_dir_norm): 60 | os.makedirs(model_dir_norm) 61 | if not os.path.isfile(model_path): 62 | download_gdrive(model_dicts[norm][model_name]['gdrive_id'], model_path) 63 | checkpoint = torch.load(model_path, map_location='cuda:0') 64 | 65 | # needed for the model of `Carmon2019Unlabeled` 66 | try: 67 | state_dict = rm_substr_from_state_dict(checkpoint['state_dict'], 'module.') 68 | except: 69 | state_dict = rm_substr_from_state_dict(checkpoint, 'module.') 70 | 71 | model.load_state_dict(state_dict, strict=True) 72 | return model.cuda().eval() 73 | 74 | # If we have an ensemble of models (e.g., Chen2020Adversarial) 75 | else: 76 | model_path = '{}/{}/{}'.format(model_dir, norm, model_name) 77 | model = model_dicts[norm][model_name]['model']() 78 | if not os.path.exists(model_dir_norm): 79 | os.makedirs(model_dir_norm) 80 | for i, gid in enumerate(model_dicts[norm][model_name]['gdrive_id']): 81 | if not os.path.isfile('{}_m{}.pt'.format(model_path, i)): 82 | download_gdrive(gid, '{}_m{}.pt'.format(model_path, i)) 83 | checkpoint = torch.load('{}_m{}.pt'.format(model_path, i), map_location='cuda:0') 84 | try: 85 | state_dict = rm_substr_from_state_dict(checkpoint['state_dict'], 'module.') 86 | except: 87 | state_dict = rm_substr_from_state_dict(checkpoint, 'module.') 88 | model.models[i].load_state_dict(state_dict) 89 | model.models[i].cuda().eval() 90 | return model 91 | 92 | 93 | def clean_accuracy(model, x, y, batch_size=100): 94 | acc = 0. 95 | n_batches = math.ceil(x.shape[0] / batch_size) 96 | with torch.no_grad(): 97 | for counter in range(n_batches): 98 | x_curr = x[counter * batch_size:(counter + 1) * batch_size].cuda() 99 | y_curr = y[counter * batch_size:(counter + 1) * batch_size].cuda() 100 | 101 | output = model(x_curr) 102 | acc += (output.max(1)[1] == y_curr).float().sum() 103 | 104 | return acc.item() / x.shape[0] 105 | 106 | 107 | def get_accuracy_and_logits(model, x, y, batch_size=100, n_classes=10): 108 | logits = torch.zeros([y.shape[0], n_classes]) 109 | acc = 0. 
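# evaluate in batches: accumulate all logits on the CPU and count correct top-1 predictions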
110 | n_batches = math.ceil(x.shape[0] / batch_size) 111 | with torch.no_grad(): 112 | for counter in range(n_batches): 113 | x_curr = x[counter * batch_size:(counter + 1) * batch_size].cuda() 114 | y_curr = y[counter * batch_size:(counter + 1) * batch_size].cuda() 115 | 116 | output = model(x_curr) 117 | logits[counter * batch_size:(counter + 1) * batch_size] += output.cpu() 118 | acc += (output.max(1)[1] == y_curr).float().sum() 119 | 120 | return acc.item() / x.shape[0], logits 121 | 122 | 123 | 124 | def list_available_models(norm='Linf'): 125 | models = model_dicts[norm].keys() 126 | 127 | json_dicts = [] 128 | for model_name in models: 129 | with open('./model_info/{}.json'.format(model_name), 'r') as model_info: 130 | json_dict = json.load(model_info) 131 | json_dict['model_name'] = model_name 132 | json_dict['venue'] = 'Unpublished' if json_dict['venue'] == '' else json_dict['venue'] 133 | json_dict['AA'] = float(json_dict['AA']) / 100 134 | json_dict['clean_acc'] = float(json_dict['clean_acc']) / 100 135 | json_dicts.append(json_dict) 136 | 137 | json_dicts = sorted(json_dicts, key=lambda d: -d['AA']) 138 | print('| # | Model ID | Paper | Clean accuracy | Robust accuracy | Architecture | Venue |') 139 | print('|:---:|---|---|:---:|:---:|:---:|:---:|') 140 | for i, json_dict in enumerate(json_dicts): 141 | print('| **{}** | **{}** | *[{}]({})* | {:.2%} | {:.2%} | {} | {} |'.format( 142 | i+1, json_dict['model_name'], json_dict['name'], json_dict['link'], json_dict['clean_acc'], json_dict['AA'], 143 | json_dict['architecture'], json_dict['venue'])) 144 | 145 | 146 | def load_model_fast_at(model_name, norm, model_dir, fts_before_bn): 147 | from model_zoo.fast_models import PreActResNet18, model_names 148 | model_name_long = model_names[norm][model_name] 149 | activation = [c.split('activation=')[-1] for c in model_name_long.split(' ') if 'activation' in c] 150 | if 'normal=' in model_name_long: 151 | normal = model_name_long.split('normal=')[1].split(' ')[0] 152 | else: 153 | normal = 'none' 154 | 155 | if 'resnet18' in model_name_long: 156 | model = PreActResNet18(n_cls=10, activation=activation[0], fts_before_bn=fts_before_bn, 157 | normal=normal) 158 | ckpt = torch.load('{}/{}'.format(model_dir, model_name_long)) 159 | model.load_state_dict({k: v for k, v in ckpt['last'].items() if 'model_preact_hl1' not in k}) 160 | return model.eval() 161 | 162 | 163 | def load_model_ssl(model_name, model_dir): 164 | from model_zoo.ssl_models import models_dict 165 | data = models_dict[model_name] 166 | model = data['model']() 167 | ckpt_base = torch.load('{}/{}'.format(model_dir, data['base']))['model'] 168 | ckpt_base = rm_substr_from_state_dict(ckpt_base, 'module.') 169 | model.base.load_state_dict(ckpt_base) 170 | ckpt_lin = torch.load('{}/{}'.format(model_dir, data['linear']))['model'] 171 | ckpt_lin = rm_substr_from_state_dict(ckpt_lin, 'module.') 172 | model.linear.load_state_dict(ckpt_lin) 173 | model.cuda() 174 | model.eval() 175 | return model 176 | 177 | def load_anymodel(model_name, model_dir='./models'): 178 | if len(model_name) == 2: 179 | return load_model(model_name[0], model_dir='./models', norm=model_name[1]).cuda().eval() 180 | elif len(model_name) == 3 and model_name[2] == 'fast_at': 181 | return load_model_fast_at(model_name[0], model_name[1], 182 | model_dir=model_dir, #'./models' #'../understanding-fast-adv-training-dev/models' 183 | fts_before_bn=False).cuda().eval() 184 | elif len(model_name) == 3 and model_name[2] == 'ssl': 185 | model = load_model_ssl(model_name[0], 
'./models/ssl_models') 186 | assert not model.base.training 187 | assert not model.linear.training 188 | return model 189 | elif len(model_name) == 3 and model_name[2] == 'ext': 190 | from model_zoo.ext_models import load_ext_models 191 | return load_ext_models(model_name[0]) 192 | 193 | 194 | def load_anymodel_imagenet(model_name, **kwargs): 195 | if len(model_name) == 2: 196 | from model_zoo.models_imagenet import load_model as load_model_imagenet 197 | return load_model_imagenet(model_name[0], norm=model_name[1]) 198 | elif len(model_name) == 3 and model_name[2] == 'pretrained': 199 | from model_zoo.models_imagenet import PretrainedModel 200 | model = PretrainedModel(model_name[0]) 201 | assert not model.model.training 202 | return model 203 | elif len(model_name) == 3 and model_name[2] == 'ssl': 204 | model = load_model_ssl(model_name[0], './models/ssl_models') 205 | assert not model.base.training 206 | assert not model.linear.training 207 | return model 208 | elif len(model_name) == 3 and model_name[2] == 'ext': 209 | from model_zoo.ext_models import load_ext_models_imagenet 210 | return load_ext_models_imagenet(model_name[0], **kwargs) 211 | 212 | def load_anymodel_cifar100(model_name): 213 | if len(model_name) == 2: 214 | from model_zoo.models_cifar100 import load_model as load_model_cifar100 215 | return load_model_cifar100(model_name[0], norm=model_name[1]) 216 | 217 | 218 | def load_anymodel_imagenet100(model_name): 219 | if len(model_name) == 3 and model_name[2] == 'ext': 220 | from model_zoo.ext_models import load_ext_models_imagenet100 221 | return load_ext_models_imagenet100(model_name[0], model_name[1]) 222 | 223 | 224 | def load_anymodel_mnist(model_name): 225 | if len(model_name) == 2: 226 | from model_zoo.models_mnist import load_model as load_model_mnist 227 | return load_model_mnist(model_name[0], norm=model_name[1]) 228 | 229 | 230 | '''def load_anymodel_datasets(args): 231 | fts_idx = [int(c) for c in args.fts_idx.split(' ')] 232 | if args.dataset == 'cifar10': 233 | l_models = [l_models_all[c] for c in fts_idx] 234 | print(l_models) 235 | model = load_anymodel(l_models[0]) 236 | model.eval() 237 | elif args.dataset == 'imagenet': 238 | l_models = [l_models_imagenet[c] for c in fts_idx] 239 | print(l_models) 240 | model = load_anymodel_imagenet(l_models[0]) 241 | #sys.exit() 242 | with torch.no_grad(): 243 | acc = clean_accuracy(model, x, y, batch_size=25) 244 | print('clean accuracy: {:.1%}'.format(acc)) 245 | return model''' 246 | 247 | if __name__ == '__main__': 248 | #list_available_models() 249 | pass 250 | -------------------------------------------------------------------------------- /utils_architecture.py: -------------------------------------------------------------------------------- 1 | '''All the models are defined here. 2 | Custom implementation: ConvStem (ConvBlock) with standard models built on timm. 
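Also provides interpolate_pos_encoding() for evaluating ViT-style models at resolutions other than 224x224, and normalize_model() to prepend input normalization to an arbitrary model.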
3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from collections import OrderedDict 8 | from typing import Tuple 9 | from torch import Tensor 10 | import torch.nn as nn 11 | 12 | import timm 13 | from functools import partial 14 | from timm.models import create_model 15 | from timm.models.convnext import _create_convnext as CNXT 16 | import torch.nn.functional as F 17 | from functools import partial 18 | import math 19 | from timm.models.vision_transformer import VisionTransformer 20 | 21 | 22 | def interpolate_pos_encoding( 23 | pos_embed: Tensor, 24 | new_img_size: int, 25 | old_img_size: int = 224, 26 | patch_size: int = 16) -> Tensor: 27 | """Interpolates the positional encoding of ViTs for new image resolution 28 | (adapted from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L174). 29 | It currently handles only square images. 30 | """ 31 | N = pos_embed.shape[1] - 1 32 | npatch = (new_img_size // patch_size) ** 2 33 | w, h = new_img_size, new_img_size 34 | if npatch == N and w == h: 35 | print(f'Positional encoding not changed.') 36 | return pos_embed 37 | print(f'Interpolating positional encoding from {N} to {npatch} patches (size={patch_size}).') 38 | class_pos_embed = pos_embed[:, 0] 39 | patch_pos_embed = pos_embed[:, 1:] 40 | dim = pos_embed.shape[-1] 41 | w0 = w // patch_size 42 | h0 = h // patch_size 43 | # we add a small number to avoid floating point error in the interpolation 44 | # see discussion at https://github.com/facebookresearch/dino/issues/8 45 | w0, h0 = w0 + 0.1, h0 + 0.1 46 | patch_pos_embed = nn.functional.interpolate( 47 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 48 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 49 | mode='bicubic', 50 | ) 51 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 52 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 53 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 54 | 55 | 56 | 57 | class LayerNorm(nn.Module): 58 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 59 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 60 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 61 | with shape (batch_size, channels, height, width). 
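For channels_first inputs the normalization is computed explicitly over the channel dimension: y = weight * (x - mean_C(x)) / sqrt(var_C(x) + eps) + bias.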
62 | """ 63 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"): 64 | super().__init__() 65 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 66 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 67 | self.eps = eps 68 | self.data_format = data_format 69 | if self.data_format not in ["channels_last", "channels_first"]: 70 | raise NotImplementedError 71 | self.normalized_shape = (normalized_shape, ) 72 | 73 | def forward(self, x): 74 | if self.data_format == "channels_last": 75 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 76 | elif self.data_format == "channels_first": 77 | u = x.mean(1, keepdim=True) 78 | s = (x - u).pow(2).mean(1, keepdim=True) 79 | x = (x - u) / torch.sqrt(s + self.eps) 80 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 81 | return x 82 | 83 | 84 | 85 | 86 | class ImageNormalizer(nn.Module): 87 | def __init__(self, mean: Tuple[float, float, float], 88 | std: Tuple[float, float, float], 89 | persistent: bool = True) -> None: 90 | super(ImageNormalizer, self).__init__() 91 | 92 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1), 93 | persistent=persistent) 94 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1), 95 | persistent=persistent) 96 | 97 | def forward(self, input: Tensor) -> Tensor: 98 | return (input - self.mean) / self.std 99 | 100 | 101 | def timm_gelu(inplace): 102 | return nn.GELU() 103 | 104 | def convert_relu_to_gelu(model): 105 | for child_name, child in model.named_children(): 106 | if isinstance(child, nn.ReLU): 107 | setattr(model, child_name, nn.GELU()) 108 | else: 109 | convert_relu_to_gelu(child) 110 | 111 | def normalize_model(model: nn.Module, mean: Tuple[float, float, float], 112 | std: Tuple[float, float, float]) -> nn.Module: 113 | layers = OrderedDict([ 114 | ('normalize', ImageNormalizer(mean, std)), 115 | ('model', model) 116 | ]) 117 | return nn.Sequential(layers) 118 | 119 | 120 | class ConvBlock(nn.Module): 121 | expansion = 1 122 | def __init__(self, siz=48, end_siz=8, fin_dim=384): 123 | super(ConvBlock, self).__init__() 124 | self.planes = siz 125 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432 126 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else nn.GroupNorm(num_groups=1, num_channels=planes) 127 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 128 | LayerNorm(self.planes, data_format="channels_first"), 129 | nn.GELU(), 130 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1), 131 | LayerNorm(self.planes*2, data_format="channels_first"), 132 | nn.GELU(), 133 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1), 134 | LayerNorm(self.planes*4, data_format="channels_first"), 135 | nn.GELU(), 136 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1), 137 | LayerNorm(self.planes*8, data_format="channels_first"), 138 | nn.GELU(), 139 | nn.Conv2d(self.planes*8, fin_dim, kernel_size=1, stride=1, padding=0) 140 | ) 141 | def forward(self, x): 142 | out = self.stem(x) 143 | # out = self.bn(out) 144 | return out 145 | 146 | class ConvBlock2(nn.Module): 147 | """Used only for det-medium""" 148 | expansion = 1 149 | def __init__(self, siz=48, end_siz=8, fin_dim=384): 150 | super(ConvBlock2, self).__init__() 151 | self.planes = siz 152 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432 153 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else 
nn.GroupNorm(num_groups=1, num_channels=planes) 154 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 155 | LayerNorm(self.planes, data_format="channels_first"), 156 | nn.GELU(), 157 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1), 158 | LayerNorm(self.planes*2, data_format="channels_first"), 159 | nn.GELU(), 160 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1), 161 | LayerNorm(self.planes*4, data_format="channels_first"), 162 | nn.GELU(), 163 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1), 164 | LayerNorm(self.planes*8, data_format="channels_first"), 165 | nn.GELU(), 166 | nn.Conv2d(self.planes*8, 512, kernel_size=1, stride=1, padding=0) 167 | ) 168 | def forward(self, x): 169 | out = self.stem(x) 170 | # out = self.bn(out) 171 | return out 172 | 173 | 174 | class ConvBlock3(nn.Module): 175 | # expansion = 1 176 | def __init__(self, siz=64): 177 | super(ConvBlock3, self).__init__() 178 | self.planes = siz 179 | 180 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 181 | LayerNorm(self.planes, data_format="channels_first"), 182 | nn.GELU(), 183 | nn.Conv2d(self.planes, int(self.planes*1.5), kernel_size=3, stride=2, padding=1), 184 | LayerNorm(int(self.planes*1.5), data_format="channels_first"), 185 | nn.GELU(), 186 | nn.Conv2d(int(self.planes*1.5), self.planes*2, kernel_size=3, stride=1, padding=1), 187 | LayerNorm(self.planes*2, data_format="channels_first"), 188 | nn.GELU() 189 | ) 190 | 191 | 192 | def forward(self, x): 193 | out = self.stem(x) 194 | # out = self.bn(out) 195 | return out 196 | 197 | 198 | class ConvBlock1(nn.Module): 199 | # expansion = 1 200 | def __init__(self, siz=48, end_siz=8, fin_dim=384): 201 | super(ConvBlock1, self).__init__() 202 | self.planes = siz 203 | 204 | fin_dim = self.planes*end_siz if fin_dim == None else 432 205 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1), 206 | LayerNorm(self.planes, data_format="channels_first"), 207 | nn.GELU(), 208 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1), 209 | LayerNorm(self.planes*2, data_format="channels_first"), 210 | nn.GELU() 211 | ) 212 | 213 | 214 | def forward(self, x): 215 | out = self.stem(x) 216 | # out = self.bn(out) 217 | return out 218 | 219 | 220 | class IdentityLayer(nn.Module): 221 | def forward(self, inputs): 222 | return inputs 223 | 224 | 225 | def get_new_model(modelname, pretrained=True, not_original=False, updated=False): 226 | 227 | 228 | if modelname == 'resnet50': 229 | model = timm.models.resnet.resnet50(pretrained=pretrained) 230 | 231 | elif modelname == 'resnet50_gelu': 232 | model = timm.models.resnet.resnet50(pretrained=pretrained, 233 | act_layer=timm_gelu) 234 | 235 | # elif modelname == 'convnext_iso': 236 | 237 | # model = cnxt_iso.convnext_isotropic_small(pretrained=pretrained, dim=384, depth=18) 238 | # if not_original: 239 | # setattr(model, 'stem', ConvBlock(48, end_siz=8, fin_dim=432 if updated else 384)) 240 | 241 | elif modelname == 'convnext_tiny': 242 | model = timm.models.convnext.convnext_tiny(pretrained=pretrained) 243 | if not_original: 244 | model.stem = ConvBlock1(48, end_siz=8) 245 | 246 | elif modelname == "convnext_tiny_21k": 247 | model = timm.models.convnext._create_convnext('convnext_tiny.fb_in22k_ft_in1k', pretrained=pretrained) 248 | 249 | elif modelname == "convnext_small": 250 | model_args = dict(depths=[3, 3, 27, 3], 
dims=[96, 192, 384, 768]) 251 | model = timm.models.convnext.convnext_small(pretrained=pretrained) 252 | if not_original: 253 | ## only for removing patch-stem 254 | model.stem = ConvBlock1(48, end_siz=8) 255 | 256 | elif modelname == "convnext_base": 257 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]) 258 | # model = timm.models.create_model('convnext_base', pretrained=pretrained, pretrained_cfg='convnext_base.fb_in1k') 259 | model = timm.models.convnext._create_convnext('convnext_base.fb_in1k', pretrained=pretrained, **model_args) 260 | if not_original: 261 | ## only for removing patch-stem 262 | model.stem = ConvBlock3(64) 263 | 264 | elif modelname == "convnext_large": 265 | 266 | model = timm.models.convnext_large(pretrained=pretrained) 267 | 268 | if not_original: 269 | model.stem = ConvBlock3(96) 270 | 271 | elif modelname == 'vit_s': 272 | model = create_model('vit_small_patch16_224', pretrained=pretrained) 273 | if not_original: 274 | ## only for removing patch-stem 275 | model.patch_embed.proj = ConvBlock(48, end_siz=8) 276 | 277 | elif modelname == 'deit_s': 278 | from timm.models.deit import deit3_small_patch16_224, _create_deit 279 | model_kwargs = dict( 280 | patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=False, 281 | init_values=None) 282 | model = create_model('deit_small_patch16_224', pretrained=pretrained) 283 | if not_original: 284 | model.patch_embed.proj = ConvBlock(48, end_siz=8) 285 | 286 | elif modelname == 'vit_m': 287 | 288 | model = timm.models.deit.deit3_medium_patch16_224(pretrained=pretrained) 289 | if not_original: 290 | ## only for removing patch-stem 291 | model.patch_embed.proj = ConvBlock2(48) 292 | 293 | elif modelname == 'vit_s_21k': 294 | model = create_model('deit3_small_patch16_224_in21ft1k', pretrained=pretrained) 295 | 296 | 297 | elif modelname == 'vit_b': 298 | model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained) # used for 21k pretrained on 206 299 | 300 | if not_original: 301 | model.patch_embed.proj = ConvBlock(48, end_siz=16, fin_dim=None) 302 | 303 | 304 | elif modelname == "resnet101": 305 | model = timm.models.resnet.resnet101(pretrained=False) 306 | 307 | elif modelname=="wrn_50_2": 308 | model = timm.models.resnet.wide_resnet50_2(pretrained=False) 309 | 310 | elif modelname == "densnet201": 311 | model = timm.models.densenet.densenet201(pretrained=pretrained) 312 | 313 | elif modelname == "inception": 314 | model = create_model('inception_v3', pretrained=pretrained) 315 | 316 | 317 | 318 | else: 319 | logger.error('Invalid model name, please use either cait, deit, swin, vit, effnet, or rn50') 320 | sys.exit(1) 321 | 322 | return model 323 | -------------------------------------------------------------------------------- /utils_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | import math 8 | import os 9 | import json 10 | 11 | # from data import load_cifar10c 12 | from models_new import l_models_all, l_models_imagenet, l_models_cifar100,\ 13 | l_models_imagenet100, l_models_mnist 14 | from utils import load_anymodel, load_anymodel_imagenet, clean_accuracy,\ 15 | load_anymodel_cifar100, load_anymodel_imagenet100, load_anymodel_mnist 16 | try: 17 | from other_utils import L1_norm, L2_norm, Linf_norm, Logger, L0_norm 18 | except 
ImportError: 19 | from autoattack.other_utils import L1_norm, L2_norm, Logger, L0_norm 20 | from apgd_mask import criterion_dict 21 | 22 | 23 | cls = ['airplane', 'automobile', 'bird', 'cat', 'deer', 24 | 'dog', 'frog', 'horse', 'ship', 'truck'] 25 | 26 | 27 | class CalibratedModel(nn.Module): 28 | def __init__(self, model, temp): 29 | super().__init__() 30 | assert not model.training 31 | self.model = model 32 | assert temp > 0. 33 | self.temp = temp 34 | 35 | def forward(self, x): 36 | return self.model(x) / self.temp 37 | 38 | 39 | def Linf_norm(): 40 | raise NotImplementedError('Linf_norm to be added.') 41 | 42 | 43 | def get_acc_cifar10c(model, n_ex=10000, severities=[5], 44 | corruptions=('shot_noise', 'motion_blur', 'snow', 'pixelate', 45 | 'gaussian_noise', 'defocus_blur', 'brightness', 'fog', 'zoom_blur', 46 | 'frost', 'glass_blur', 'impulse_noise', 'contrast', 47 | 'jpeg_compression', 'elastic_transform'), bs=250): 48 | l_acc = [] 49 | acc_dets = {} 50 | 51 | for s in severities: 52 | x, y = load_cifar10c(n_ex, severity=s, corruptions=corruptions) 53 | x = x.contiguous() 54 | print(x.shape) 55 | with torch.no_grad(): 56 | acc = 0. 57 | n_batches = math.ceil(x.shape[0] / bs) 58 | for counter in range(n_batches): 59 | output = model(x[counter * bs:(counter + 1) * bs].cuda()) 60 | acc += (output.cpu().max(dim=1)[1] == y[counter * bs:(counter + 1) * bs]).sum() 61 | l_acc.append(acc / x.shape[0]) 62 | acc_dets[str(s)] = acc / x.shape[0] 63 | print('sev={}, clean accuracy={:.2%}'.format(s, acc / x.shape[0])) 64 | 65 | return acc / x.shape[0], acc_dets 66 | 67 | def check_imgs(adv, x, norm): 68 | delta = (adv - x).view(adv.shape[0], -1) 69 | if norm == 'Linf': 70 | res = delta.abs().max(dim=1)[0] 71 | elif norm == 'L2': 72 | res = (delta ** 2).sum(dim=1).sqrt() 73 | elif norm == 'L1': 74 | res = delta.abs().sum(dim=1) 75 | 76 | str_det = 'max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}'.format( 77 | norm, res.max(), (adv != adv).sum(), adv.max(), adv.min()) 78 | print(str_det) 79 | print(adv.max().item() - 1., adv.min().item()) 80 | 81 | return str_det 82 | 83 | def get_cifar10_class(lab): 84 | return cls[lab] 85 | 86 | def get_imagenet_class(lab): 87 | if torch.is_tensor(lab): 88 | lab = lab.item() 89 | with open('./imagenet_classes.json') as json_file: 90 | class_dict = json.load(json_file) 91 | return class_dict[str(lab)][1] 92 | 93 | def get_class(args, cl=None): 94 | if cl is None: 95 | cl = args.target_class 96 | if args.dataset == 'cifar10': 97 | return get_cifar10_class(cl) 98 | elif args.dataset == 'imagenet': 99 | return get_imagenet_class(cl) 100 | 101 | def makedir(path): 102 | if not os.path.exists(path): 103 | os.makedirs(path) 104 | 105 | def load_anymodel_datasets(args): 106 | fts_idx = [int(c) for c in args.fts_idx.split(' ')] 107 | if args.dataset == 'cifar10': 108 | l_models = [l_models_all[c] for c in fts_idx] 109 | print(l_models) 110 | model = load_anymodel(l_models[0], args.model_dir) 111 | model.eval() 112 | elif args.dataset == 'imagenet': 113 | l_models = [l_models_imagenet[c] for c in fts_idx] 114 | print(l_models) 115 | kwargs = {} 116 | if (l_models[0][0].startswith('DeiT') 117 | #and 'convblock' not in l_models[0][0] 118 | or l_models[0][0].startswith('ViT') 119 | ): 120 | kwargs = {'img_size': args.img_size} 121 | model = load_anymodel_imagenet(l_models[0], **kwargs) 122 | #sys.exit() 123 | '''with torch.no_grad(): 124 | acc = clean_accuracy(model, x, y, batch_size=25) 125 | print('clean accuracy: {:.1%}'.format(acc))''' 126 
| elif args.dataset == 'cifar100': 127 | l_models = [l_models_cifar100[c] for c in fts_idx] 128 | print(l_models) 129 | model = load_anymodel_cifar100(l_models[0]) 130 | model.eval() 131 | elif args.dataset == 'imagenet100': 132 | l_models = [l_models_imagenet100[c] for c in fts_idx] 133 | print(l_models) 134 | model = load_anymodel_imagenet100(l_models[0]) 135 | model.eval() 136 | elif args.dataset == 'mnist': 137 | l_models = [l_models_mnist[c] for c in fts_idx] 138 | print(l_models) 139 | model = load_anymodel_mnist(l_models[0]) 140 | model.eval() 141 | return model, l_models 142 | 143 | 144 | def attack_group(norm, suffix=''): 145 | if norm in ['Linf', 'L2', 'L1']: 146 | return 'aa' + suffix 147 | else: 148 | return 'nonlpattacks' 149 | 150 | 151 | def get_norm(z, norm): 152 | if norm == 'Linf': 153 | return Linf_norm(z) 154 | elif norm == 'L2': 155 | return L2_norm(z) 156 | elif norm == 'L1': 157 | return L1_norm(z) 158 | elif norm == 'L0': 159 | return L0_norm(z) 160 | 161 | 162 | def get_logits(model, x_test, bs=1000, device=None, n_cls=10): 163 | if device is None: 164 | device = x_test.device 165 | n_batches = math.ceil(x_test.shape[0] / bs) 166 | logits = torch.zeros([x_test.shape[0], n_cls], device=device) 167 | #l_logits = [] 168 | 169 | with torch.no_grad(): 170 | for counter in range(n_batches): 171 | x_curr = x_test[counter * bs:(counter + 1) * bs].to(device) 172 | output = model(x_curr) 173 | #l_logits.append(output.detach()) 174 | logits[counter * bs:(counter + 1) * bs] += output.detach() 175 | 176 | return logits 177 | 178 | 179 | def get_wc_acc(model, xs, y, bs=1000, device=None, eot_test=1, logger=None, 180 | loss=None, n_cls=10): 181 | if device is None: 182 | device = x.device 183 | if logger is None: 184 | logger = Logger(None) 185 | if not loss is None: 186 | criterion_indiv = criterion_dict[loss] 187 | y = y.to(device) 188 | acc = torch.ones_like(y, device=device).float() 189 | x_adv = xs[0].clone() 190 | loss_best = -1. * float('inf') * torch.ones(y.shape[0], device=device) 191 | 192 | for x in xs: 193 | logits = get_logits(model, x, bs=bs, device=device, n_cls=n_cls) 194 | loss_curr = criterion_indiv(logits, y) 195 | pred_curr = logits.max(1)[1] == y 196 | ind = ~pred_curr * (loss_curr > loss_best) # misclassified points with higher loss 197 | x_adv[ind] = x[ind].clone() 198 | acc *= pred_curr 199 | ind = (acc == 1.) 
* (loss_curr > loss_best) # for robust points track highest loss 200 | x_adv[ind] = x[ind].clone() 201 | logger.log(f'[rob acc] cum={acc.mean():.1%} curr={pred_curr.float().mean():.1%}') 202 | 203 | print(torch.nonzero(acc).squeeze()) 204 | 205 | return acc.mean(), x_adv 206 | 207 | 208 | def get_patchsize(dataset, modelname): 209 | if dataset == 'imagenet': 210 | if modelname in ['ConvMixer_1024_20_nat', 'ConvMixer_1024_20_eps4_best']: 211 | return 14 212 | else: 213 | return 16 214 | else: 215 | return 8 216 | -------------------------------------------------------------------------------- /utils_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torchvision import datasets, transforms 6 | import numpy as np 7 | import copy 8 | import random 9 | import math 10 | import time 11 | from collections import OrderedDict 12 | from typing import Tuple 13 | from torch import Tensor 14 | import os 15 | 16 | try: 17 | from autopgd_pt import L1_projection 18 | except ImportError: 19 | from autoattack.autopgd_base import L1_projection 20 | 21 | class FGSMAttack(): 22 | def __init__(self, eps=8. / 255., step_size=None, loss=None): 23 | self.eps = eps 24 | self.step_size = step_size if not step_size is None else eps 25 | self.loss = loss 26 | 27 | def perturb(self, model, x, y, random_start=False): 28 | assert not self.loss is None 29 | if random_start: 30 | t = (x + (2. * torch.rand_like(x) - 1.) * self.eps).clamp(0., 1.) 31 | else: 32 | t = x.clone() 33 | 34 | t.requires_grad = True 35 | output = model(t) 36 | loss = self.loss(output, y) 37 | grad = torch.autograd.grad(loss, t)[0] 38 | 39 | x_adv = x + grad.detach().sign() * self.step_size 40 | return torch.min(torch.max(x_adv, x - self.eps), x + self.eps).clamp(0., 1.) 41 | 42 | class PGDAttack(): 43 | def __init__(self, eps=8. / 255., step_size=None, loss=None, n_iter=10, 44 | norm='Linf', verbose=False): 45 | self.eps = eps 46 | self.step_size = step_size if not step_size is None else eps / n_iter * 1.5 47 | self.loss = loss 48 | self.n_iter = n_iter 49 | self.norm = norm 50 | self.verbose = verbose 51 | 52 | def perturb(self, model, x, y, random_start=False, return_acc=False): 53 | assert not self.loss is None 54 | if random_start: 55 | x_adv = (x + (2. * torch.rand_like(x) - 1.) * self.eps).clamp(0., 1.) 56 | else: 57 | x_adv = x.clone() 58 | 59 | n_fts = x.shape[1] * x.shape[2] * x.shape[3] 60 | x_best, loss_best = x_adv.clone(), torch.zeros_like(y).float() 61 | acc = torch.ones_like(y).detach().float() 62 | 63 | x_adv.requires_grad = True 64 | output = model(x_adv) 65 | loss = self.loss(output, y, reduction='none') 66 | ind = loss > loss_best 67 | x_best[ind] = x_adv[ind].clone().detach() 68 | loss_best[ind] = loss[ind].clone().detach() 69 | acc[ind] = (output.max(dim=1)[1] == y).float()[ind].clone().detach() 70 | grad = torch.autograd.grad(loss.mean(), x_adv)[0] 71 | 72 | for it in range(self.n_iter): 73 | if self.norm == 'Linf': 74 | x_adv = x_adv.detach() + grad.detach().sign() * self.step_size 75 | x_adv = torch.min(torch.max(x_adv, x - self.eps), x + self.eps).clamp(0., 1.) 
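# Linf case above: ascend along sign(grad) with the given step size, then project back
# onto the eps-ball around the clean input x and clamp into the valid [0, 1] image range.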
76 | 77 | elif self.norm == 'L2': 78 | x_adv = x_adv.detach() + grad.detach() / (grad.detach() ** 2 79 | ).sum(dim=(1, 2, 3), keepdim=True).sqrt() * self.step_size 80 | delta_l2norm = ((x_adv - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() 81 | x_adv = (x + (x_adv - x) * torch.min(torch.ones_like(delta_l2norm), 82 | self.eps * torch.ones_like(delta_l2norm) / delta_l2norm)).clamp(0., 1.) 83 | 84 | elif self.norm == 'L1': 85 | grad = grad.detach() 86 | grad_topk = grad.abs().view(grad.shape[0], -1).topk( 87 | k=max(int(.1 * n_fts), 1), dim=-1)[0][:, -1].view(grad.shape[0], 88 | *[1]*(len(grad.shape) - 1)) 89 | sparsegrad = grad * (grad.abs() >= grad_topk).float() 90 | x_adv = x_adv.detach() + self.step_size * sparsegrad / (sparsegrad.abs().view( 91 | x.shape[0], -1).sum(dim=-1).view(-1, 1, 1, 1) + 1e-10) 92 | delta_temp = L1_projection(x, x_adv - x, self.eps) 93 | x_adv += delta_temp 94 | 95 | x_adv.requires_grad = True 96 | output = model(x_adv) 97 | loss = self.loss(output, y, reduction='none') 98 | ind = loss > loss_best 99 | x_best[ind] = x_adv[ind].clone().detach() 100 | loss_best[ind] = loss[ind].clone().detach() 101 | acc[ind] = (output.max(dim=1)[1] == y).float()[ind].clone().detach() 102 | grad = torch.autograd.grad(loss.mean(), x_adv)[0] 103 | 104 | if self.verbose: 105 | print('[{}] it={} loss={:.5f} acc={:.1%}'.format(self.norm, it, loss_best.mean().item(), 106 | (output.max(dim=1)[1] == y).cpu().float().mean())) 107 | 108 | if not return_acc: 109 | return x_best.detach() 110 | else: 111 | return x_best.detach(), acc 112 | 113 | class MSDAttack(): 114 | def __init__(self, eps, step_size=None, loss=None, n_iter=10): 115 | self.eps = eps 116 | self.step_size = step_size if not step_size is None else [eps / n_iter * 1.25 for eps in self.eps] 117 | self.loss = loss 118 | self.n_iter = n_iter 119 | 120 | def perturb(self, model, x, y, random_start=False): 121 | assert not self.loss is None 122 | if random_start: 123 | x_adv = (x + (2. * torch.rand_like(x) - 1.) * self.eps).clamp(0., 1.) 124 | else: 125 | x_adv = x.clone() 126 | 127 | n_fts = x.shape[1] * x.shape[2] * x.shape[3] 128 | x_best, loss_best = x_adv.clone(), torch.zeros_like(y).float() 129 | #x_adv.requires_grad = True 130 | 131 | for it in range(self.n_iter): 132 | x_adv.requires_grad = True 133 | #with torch.enable_grad() 134 | output = model(x_adv) 135 | loss = self.loss(output, y, reduction='none') 136 | avgloss = loss.mean() 137 | grad = torch.autograd.grad(avgloss, x_adv)[0] 138 | ind = loss > loss_best 139 | x_best[ind] = x_adv[ind].clone().detach() 140 | loss_best[ind] = loss[ind].clone().detach() 141 | #grad = torch.autograd.grad(loss.mean(), x_adv)[0].detach() 142 | 143 | x_adv_linf = x_adv.detach() + grad.detach().sign() * self.step_size[0] 144 | x_adv_linf = torch.min(torch.max(x_adv_linf, x - self.eps[0]), x + self.eps[0]).clamp(0., 1.) 145 | 146 | x_adv_l2 = x_adv.detach() + grad.detach() / (grad.detach() ** 2 147 | ).sum(dim=(1, 2, 3), keepdim=True).sqrt() * self.step_size[1] 148 | delta_l2norm = ((x_adv_l2 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() 149 | x_adv_l2 = (x + (x_adv_l2 - x) * torch.min(torch.ones_like(delta_l2norm), 150 | self.eps[1] * torch.ones_like(delta_l2norm) / delta_l2norm)).clamp(0., 1.) 
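# L1 candidate below: keep only roughly the top 10% of gradient entries by magnitude,
# step along them (normalized by their L1 norm), then project back onto the L1 ball
# of radius eps[2] via L1_projection.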
151 | 152 | grad = grad.detach() 153 | grad_topk = grad.abs().view(grad.shape[0], -1).topk( 154 | k=max(int(.1 * n_fts), 1), dim=-1)[0][:, -1].view(grad.shape[0], 155 | *[1]*(len(grad.shape) - 1)) 156 | sparsegrad = grad * (grad.abs() >= grad_topk).float() 157 | x_adv_l1 = x_adv.detach() + self.step_size[2] * sparsegrad / (sparsegrad.abs().view( 158 | x.shape[0], -1).sum(dim=-1).view(-1, 1, 1, 1) + 1e-10) 159 | delta_temp = L1_projection(x, x_adv_l1 - x, self.eps[2]) 160 | x_adv_l1 += delta_temp 161 | 162 | '''x_adv_linf.requires_grad_() 163 | x_adv_l2.requires_grad_() 164 | x_adv_l1.reuiqres_grad_()''' 165 | l_x_adv = [x_adv_linf, x_adv_l2, x_adv_l1] 166 | l_output = [model(x_adv_linf), model(x_adv_l2), model(x_adv_l1)] 167 | l_loss = [self.loss(c, y, reduction='none') for c in l_output] 168 | #l_avgloss = [c.mean() for c in l_loss] 169 | val_max, ind_max = torch.max(torch.stack(l_loss, dim=1), dim=1) 170 | #x_adv, loss = l_x_adv[ind_max], l_loss[ind_max] 171 | x_adv = x_adv_linf.clone() 172 | x_adv[ind_max == 1] = x_adv_l2[ind_max == 1].clone() 173 | x_adv[ind_max == 2] = x_adv_l1[ind_max == 2].clone() 174 | #print('it={} - best norm=({}, {}, {}) - loss={}'.format(it + 1, (ind_max == 0).sum(), 175 | # (ind_max == 1).sum(), (ind_max == 2).sum(), loss_best.mean().item())) 176 | 177 | return x_best.detach() 178 | 179 | class MultiPGDAttack(): 180 | def __init__(self, eps, step_size=None, loss=None, n_iter=[10, 10, 10], use_miscl=False, 181 | l_norms=None): 182 | self.eps = eps 183 | self.step_size = step_size if not step_size is None else [eps / n_iter * 1.5 for eps in self.eps] 184 | self.loss = loss 185 | self.n_iter = n_iter 186 | self.indiv_adversary = PGDAttack(eps[0], loss=loss) 187 | self.use_miscl = use_miscl 188 | self.l_norms = l_norms if not l_norms is None else ['Linf', 'L2', 'L1'] 189 | 190 | def perturb(self, model, x, y, random_start=False, return_acc=False): 191 | #assert not self.loss is None 192 | l_x_adv = [] 193 | l_acc = [] 194 | for i, norm in enumerate(self.l_norms): 195 | self.indiv_adversary.eps = self.eps[i] + 0. 196 | self.indiv_adversary.step_size = self.step_size[i] + 0. 
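# the shared PGDAttack instance is re-configured (eps, step size, norm, iterations)
# for the current norm before being run.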
197 | self.indiv_adversary.norm = norm + '' 198 | self.indiv_adversary.n_iter = self.n_iter[i] 199 | if not return_acc: 200 | x_curr = self.indiv_adversary.perturb(model, x, y) 201 | else: 202 | x_curr, acc_curr = self.indiv_adversary.perturb(model, x, y, return_acc=True) 203 | l_acc.append(acc_curr) 204 | l_x_adv.append(x_curr.clone()) 205 | 206 | if not self.use_miscl: 207 | if not return_acc: 208 | return torch.cat(l_x_adv, dim=0) 209 | else: 210 | return torch.cat(l_x_adv, dim=0), torch.cat(l_acc, dim=0) 211 | else: 212 | #logits = torch.zeros([len(l_x_adv), x.shape[0], 10]).cuda() 213 | loss = torch.zeros([len(l_x_adv), x.shape[0]]).cuda() 214 | for i, x_adv in enumerate(l_x_adv): 215 | output = model(x_adv) 216 | loss[i] = self.loss(output, y) - 1e5 * (output.max(dim=1)[1] == y).float() 217 | ind_max = loss.max(dim=0)[1] 218 | x_adv = l_x_adv[0].clone() 219 | x_adv[ind_max == 1] = l_x_adv[1][ind_max == 1].clone() 220 | x_adv[ind_max == 2] = l_x_adv[2][ind_max == 2].clone() 221 | 222 | return x_adv 223 | 224 | # data loaders 225 | def load_data(args): 226 | crop_input_size = 0 if args.crop_input is None else args.crop_input 227 | crop_data_size = 0 if args.crop_data is None else args.crop_data 228 | if args.dataset == 'cifar10': 229 | train_transform = transforms.Compose([ 230 | transforms.RandomCrop(32 - crop_data_size, padding=4 - crop_input_size), 231 | transforms.RandomHorizontalFlip(), 232 | #transforms.RandomRotation(15), 233 | transforms.ToTensor(), 234 | ]) 235 | elif args.dataset == 'svhn': 236 | train_transform = transforms.Compose([ 237 | #transforms.RandomCrop(32 - crop_data_size, padding=4 - crop_input_size), 238 | transforms.ToTensor(), 239 | ]) 240 | elif args.dataset == 'mnist': 241 | train_transform = transforms.Compose([ 242 | transforms.ToTensor(), 243 | ]) 244 | test_transform = transforms.Compose([ 245 | transforms.ToTensor(), 246 | ]) 247 | 248 | root = args.data_dir + '' #'/home/EnResNet/WideResNet34-10/data/' 249 | num_workers = 2 250 | 251 | if args.dataset == 'cifar10': 252 | train_dataset = datasets.CIFAR10( 253 | root, train=True, transform=train_transform, download=True) 254 | test_dataset = datasets.CIFAR10( 255 | root, train=False, transform=test_transform, download=True) 256 | elif args.dataset == 'svhn': 257 | train_dataset = datasets.SVHN(root=root, split='train', 258 | transform=train_transform, download=True) 259 | test_dataset = datasets.SVHN(root=root, split='test', #train=False 260 | transform=test_transform, download=True) 261 | elif args.dataset == 'mnist': 262 | train_dataset = datasets.MNIST( 263 | root, train=True, transform=train_transform, download=True) 264 | test_dataset = datasets.MNIST( 265 | root, train=False, transform=test_transform, download=True) 266 | 267 | train_loader = torch.utils.data.DataLoader( 268 | dataset=train_dataset, 269 | batch_size=args.batch_size, 270 | shuffle=True, 271 | pin_memory=True, 272 | num_workers=num_workers, 273 | ) 274 | test_loader = torch.utils.data.DataLoader( 275 | dataset=test_dataset, 276 | batch_size=args.batch_size_eval, 277 | shuffle=False, 278 | pin_memory=True, 279 | num_workers=0, 280 | ) 281 | 282 | return train_loader, test_loader 283 | 284 | def load_imagenet_train(args): 285 | from robustness.datasets import DATASETS 286 | from robustness.tools import helpers 287 | data_paths = ['/home/scratch/datasets/imagenet', 288 | '/scratch_local/datasets/ImageNet2012', 289 | '/mnt/qb/datasets/ImageNet2012', 290 | '/scratch/datasets/imagenet/'] 291 | for data_path in data_paths: 292 | if 
os.path.exists(data_path): 293 | break 294 | print(f'found dataset at {data_path}') 295 | dataset = DATASETS['imagenet'](data_path) #'/home/scratch/datasets/imagenet' 296 | 297 | 298 | train_loader, val_loader = dataset.make_loaders(2, 299 | args.batch_size, data_aug=True) 300 | 301 | train_loader = helpers.DataPrefetcher(train_loader) 302 | val_loader = helpers.DataPrefetcher(val_loader) 303 | return train_loader, val_loader 304 | 305 | # other utils 306 | def get_accuracy(model, data_loader=None): 307 | assert not model.training 308 | if not data_loader is None: 309 | acc = 0. 310 | c = 0 311 | with torch.no_grad(): 312 | for (x, y) in data_loader: 313 | output = model(x.cuda()) 314 | acc += (output.cpu().max(dim=1)[1] == y).float().sum() 315 | c += x.shape[0] 316 | return acc.item() / c 317 | 318 | def get_lr_schedule(args): 319 | if args.lr_schedule == 'superconverge': 320 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0] 321 | # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0] 322 | elif args.lr_schedule == 'piecewise': 323 | def lr_schedule(t): 324 | if t / args.epochs < 0.5: 325 | return args.lr_max 326 | elif t / args.epochs < 0.75: 327 | return args.lr_max / 10. 328 | else: 329 | return args.lr_max / 100. 330 | elif args.lr_schedule.startswith('piecewise'): 331 | w = [float(c) for c in args.lr_schedule.split('-')[1:]] 332 | def lr_schedule(t): 333 | c = 0 334 | while t / args.epochs > sum(w[:c + 1]) / sum(w): 335 | c += 1 336 | return args.lr_max / 10. ** c 337 | return lr_schedule 338 | 339 | def norm_schedule(it, epoch, epochs, l_norms, ps=None, schedule='piecewise'): 340 | #assert l_norms == ['Linf', 'L2', 'L1'] 341 | if schedule == 'piecewise': 342 | if epoch < epochs * .5: 343 | return l_norms.index('L2') 344 | else: 345 | if not ps is None: 346 | ind_linf = l_norms.index('Linf') 347 | ind_l1 = l_norms.index('L1') 348 | return random.choices([ind_linf, ind_l1], weights=[ 349 | ps[ind_linf], ps[ind_l1]])[0] 350 | if it % 2 == 0: 351 | return l_norms.index('Linf') 352 | else: 353 | return l_norms.index('L1') 354 | 355 | 356 | # swa tools 357 | def polyak_averaging(p_local, p_new, it): 358 | return (it * p_local + p_new) / (it + 1.) 359 | 360 | def exp_decay(p_local, p_new, theta=.995): 361 | return theta * p_local + (1. - theta) * p_new 362 | 363 | class AveragedModel(nn.Module): 364 | def __init__(self, model): 365 | super(AveragedModel, self).__init__() 366 | self._model = copy.deepcopy(model) 367 | # 368 | 369 | def forward(self, x): 370 | return self._model(x) 371 | 372 | @torch.no_grad() 373 | def update_parameters(self, model, avg_fun=exp_decay): 374 | for p_local, p_new in zip(self._model.parameters(), model.parameters()): 375 | p_local.set_(avg_fun(p_local, p_new)) 376 | 377 | 378 | # initializations 379 | def initialize_weights(module): 380 | if isinstance(module, nn.Conv2d): 381 | n = module.kernel_size[0] * module.kernel_size[1] * module.out_channels 382 | module.weight.data.normal_(0, math.sqrt(2. / n)) 383 | if module.bias is not None: 384 | module.bias.data.zero_() 385 | elif isinstance(module, nn.Linear): 386 | n = module.in_features 387 | module.weight.data.normal_(0, math.sqrt(2. 
# initializations
def initialize_weights(module):
    # He initialization for conv/linear layers, standard initialization for BN
    if isinstance(module, nn.Conv2d):
        n = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        module.weight.data.normal_(0, math.sqrt(2. / n))
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Linear):
        n = module.in_features
        module.weight.data.normal_(0, math.sqrt(2. / n))
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()


# stepsizes for pgd-at
def get_stepsize(norm, eps, method='default'):
    if method == 'default':
        if norm == 'Linf':
            return eps / 4.
        elif norm == 'L2':
            return eps / 3.
        elif norm == 'L1':
            return 2. * eps * 255. / 2000.
        else:
            raise ValueError('please specify a norm')
    elif method == 'msd':
        if norm == 'Linf':
            return eps / 4.
        elif norm == 'L2':
            return eps / 3.
        elif norm == 'L1':
            return 1.
        else:
            raise ValueError('please specify a norm')
    elif method == 'msd-5':
        if norm == 'Linf':
            return eps / 2.
        elif norm == 'L2':
            return eps / 1.5
        elif norm == 'L1':
            return eps / 2.
        else:
            raise ValueError('please specify a norm')
    elif method == 'half':
        return eps / 2.
    else:
        raise ValueError(f'unknown stepsize method: {method}')


# utils max strategy
def form_batch_max(l_adv, l_acc, l_loss, l_norm):
    # per sample, keep the adversarial example (over the attacked norms) with the
    # highest loss, and record which norm produced it
    bs = l_adv[0].shape[0]
    adv = l_adv[0].clone()
    best_norm = torch.zeros([bs]).long()
    best_loss = l_loss[0].clone()
    best_acc = l_acc[0].clone()
    for counter in range(1, len(l_norm)):
        ind = l_loss[counter] > best_loss
        adv[ind] = l_adv[counter][ind].clone()
        best_norm[ind] = counter
        best_loss[ind] = l_loss[counter][ind].clone()
        best_acc[ind] = l_acc[counter][ind].clone()

    return adv, best_norm, best_acc, best_loss


def random_crop(x, size, padding):
    # zero-pad the batch and take a random crop of the original size
    z = torch.zeros([x.shape[0], x.shape[1], size + 2 * padding,
        size + 2 * padding], device=x.device)
    z[:, :, padding:padding + size, padding:padding + size] += x

    a = random.randint(0, 2 * padding)
    b = random.randint(0, 2 * padding)

    return z[:, :, a:a + size, b:b + size]
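
# --- illustrative sketch: the step sizes get_stepsize produces for one set of
# example radii. The eps values below are common CIFAR-10-scale choices picked
# purely for illustration, not values prescribed by this repository.
def _example_pgd_stepsizes():
    example_eps = {'Linf': 8. / 255., 'L2': 0.5, 'L1': 12.}
    # with method='default' this yields eps/4 (Linf), eps/3 (L2) and
    # 2*eps*255/2000 (L1), i.e. roughly 0.0078, 0.167 and 3.06 here
    return {norm: get_stepsize(norm, eps) for norm, eps in example_eps.items()}
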
class BatchTracker():
    # builds shuffled mini-batches and tracks, per sample, a running average of the
    # training loss under each of the two threat models; the averages decide which
    # norm each sample is attacked with in the next epoch
    def __init__(self, imgs, labs, bs, norms, alpha):
        self.imgs_orig = imgs.clone()
        self.labs_orig = labs.clone()
        self.bs = bs
        self.n_ex = imgs.shape[0]
        self.norms = norms
        self.loss_norms_ra = torch.zeros([self.n_ex, 2])
        self.alpha = alpha

    def batch_new_epoch(self):
        self.ind_sort = torch.randperm(self.n_ex)
        self.batch_init = 0

        tot_curr = self.loss_norms_ra[:, 0] + self.loss_norms_ra[:, 1]
        ind_tot_curr = tot_curr == 0.
        tot_curr[tot_curr == 0.] = 1.
        ps_old = self.loss_norms_ra[:, 0] / tot_curr

        # ps = 1 selects norm 0, ps = 0 selects norm 1 (see norm_curr below):
        # prefer the norm with the larger running loss (the currently harder one),
        # always pick a norm that has no recorded loss yet, and use a 50/50 split
        # for samples without any history
        ps = (self.loss_norms_ra[:, 0] > self.loss_norms_ra[:, 1]).float()
        ps[ps_old == 0.] = 1.
        ps[ps_old == 1.] = 0.
        ps[ind_tot_curr] = .5

        train_loader = []
        for c in range(0, self.n_ex, self.bs):
            ind_curr = self.ind_sort[c:c + self.bs].clone()
            x_curr = self.custom_augm(self.imgs_orig[ind_curr].clone())
            y_curr = self.labs_orig[ind_curr].clone()
            # a single random threshold per batch decides which samples get norm 1
            norm_curr = (ps[ind_curr] < random.random()).long().clone()
            train_loader.append((x_curr.clone(), y_curr, norm_curr))
        return train_loader

    def custom_augm(self, x):
        # random crop (padding 4) and random horizontal flip
        z = random_crop(x, x.shape[-1], 4)
        if random.random() > .5:
            return transforms.functional.hflip(z)
        else:
            return z.clone()

    def update_loss(self, loss, norm, i):
        # exponential moving average of the per-sample loss for the norm just used
        ind_curr = self.ind_sort[i * self.bs:(i + 1) * self.bs].clone()
        self.loss_norms_ra[ind_curr, norm] = self.loss_norms_ra[ind_curr, norm
            ] * self.alpha + loss.cpu() * (1. - self.alpha)


# different resolution
class ImageCropper(nn.Module):
    # multiplies the input with a fixed mask that zeroes a border of width w
    def __init__(self, w: int,
            shape: Tuple[int, int, int, int]) -> None:
        super(ImageCropper, self).__init__()

        mask = get_mask(shape, w)
        self.register_buffer('mask', mask)

    def forward(self, input: Tensor) -> Tensor:
        return input * self.mask


def get_mask(shape, w):
    # 1 inside the central region, 0 on a frame of width w
    mask = torch.zeros(shape)
    assert len(mask.shape) == 4
    mask[:, :, w:mask.shape[-2] - w, w:mask.shape[-1] - w] = 1

    return mask


def add_preprocessing(model: nn.Module, w: int, shape: Tuple[int, int, int]
        ) -> nn.Module:
    layers = OrderedDict([
        ('crop', ImageCropper(w, [1] + list(shape))),
        ('model', model)
    ])
    return nn.Sequential(layers)
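
# --- illustrative usage sketch: wrapping a model so that a border of width w is
# zeroed out before the forward pass. The toy backbone and the input shape are
# assumptions for the example; add_preprocessing itself is defined above.
def _example_masked_model():
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    wrapped = add_preprocessing(backbone, w=4, shape=[3, 32, 32])
    x = torch.rand(2, 3, 32, 32)
    return wrapped(x)  # the outer 4-pixel frame of x is multiplied by zero
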
def Lp_norm(x, p, keepdim=False):
    # batchwise Lp norm of x, flattened over all non-batch dimensions
    assert p > 1
    z = x.view(x.shape[0], -1)
    t = (z.abs() ** p).sum(dim=-1) ** (1. / p)
    if keepdim:
        t = t.view(-1, *[1] * (len(x.shape) - 1))
    return t


if __name__ == '__main__':
    # quick smoke test for BatchTracker on a toy dataset of 10 "images"
    custom_loader = BatchTracker(torch.ones([10, 3, 5, 5]) * torch.arange(
        10).view(-1, 1, 1, 1), torch.arange(10), 3, ['Linf', 'L1'], .9)
    for epoch in range(4):
        startt = time.time()
        train_loader = custom_loader.batch_new_epoch()
        print('loader created in {:.3f} s'.format(time.time() - startt))
        for i in range(4):
            x, y, norm = train_loader[i]
            loss = torch.randn([y.shape[0]]).abs()
            custom_loader.update_loss(loss, norm, i)
            print(x.view(x.shape[0], -1).max(dim=1)[0])

--------------------------------------------------------------------------------