├── AA_eval.py
├── README.md
├── autopgd_train_clean.py
├── dataset_convnext_like.py
├── fgsm_train.py
├── main.py
├── models
│   ├── __init__.py
│   ├── convnext.py
│   ├── convnext_iso.py
│   └── utils.py
├── parserr.py
├── rb_architecture_util.py
├── readme_teaser.png
├── run_train.sh
├── runner_aa_eval.py
├── utils.py
├── utils_architecture.py
├── utils_eval.py
└── utils_train.py
/AA_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import models
3 | import torch.nn as nn
4 |
5 | from timm.models import create_model
6 | from torchvision import datasets, transforms
7 | import math
8 | import argparse
9 | import os
10 | import sys
11 | sys.path.insert(0,'..')
12 |
13 | import json
14 | import robustbench
15 | import numpy as np
16 |
17 | from autoattack import AutoAttack
18 | from robustbench.utils import clean_accuracy
19 |
20 | from main import BlurPoolConv2d, PREC_DICT, IMAGENET_MEAN, \
21 | IMAGENET_STD
22 | from utils_architecture import normalize_model, get_new_model, interpolate_pos_encoding
23 | from ptflops import get_model_complexity_info
24 | from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str
25 |
26 | def sizeof_fmt(num, suffix="Flops"):
27 | for unit in ["", "Ki", "Mi", "G", "T"]:
28 | if abs(num) < 1000.0:
29 | return f"{num:3.3f}{unit}{suffix}"
30 | num /= 1000.0
31 | return f"{num:.1f}Yi{suffix}"
32 |
33 | eps_dict = {'imagenet': {'Linf': 4. / 255., 'L2': 2., 'L1': 75.}}
34 |
35 |
36 | class Logger():
37 | def __init__(self, log_path):
38 | self.log_path = log_path
39 |
40 | def log(self, str_to_log, verbose=False):
41 | print(str_to_log)
42 | if not self.log_path is None:
43 | with open(self.log_path, 'a') as f:
44 | f.write(str_to_log)
45 | f.write('\n')
46 | if verbose:
47 | f.flush()
48 |
49 | def format(value):
50 | return "%.3f" % value
51 |
52 |
53 | def makedir(path):
54 | if not os.path.exists(path):
55 | os.makedirs(path)
56 |
57 |
58 | def get_args_parser():
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument('--batch_size', default=200, type=int)
61 | parser.add_argument('--model', default='convnext_tiny_convmlp_nolayerscale', type=str)
62 | parser.add_argument('--n_ex', type=int, default=5000)
63 | parser.add_argument('--norm', type=str)
64 | parser.add_argument('--eps', type=float)
65 | parser.add_argument('--data_dir', type=str, default='/scratch/fcroce42/ffcv_imagenet_data')
66 | parser.add_argument('--only_clean', action='store_true')
67 | parser.add_argument('--save_imgs', action='store_true')
68 | parser.add_argument('--precision', type=str, default='fp32')
69 | parser.add_argument('--ckpt_path', type=str, default='/scratch/nsingh/ImageNet_Arch/model_2022-12-04 14:36:56_convnext_iso_iso_0_not_orig_0_pre_1_aug_0_adv__50_at_crop_flip/weights_18.pt')
70 | parser.add_argument('--mod', type=str)
71 | parser.add_argument('--model_in', nargs='+')
72 | parser.add_argument('--full_aa', type=int, default=0)
73 | parser.add_argument('--init', type=str)
74 | parser.add_argument('--add_normalization', action='store_true', default=False)
75 | parser.add_argument('--l_norms', type=str)
76 | parser.add_argument('--l_epss', type=str)
77 | parser.add_argument('--get_stats', action='store_true')
78 | parser.add_argument('--use_fixed_val_set', action='store_true', default=False)
79 |     parser.add_argument('--img_size', type=int, help='resolution to run the evaluation at, default: 224', default=224)
80 | parser.add_argument('--not_channel_last', action='store_false')
81 | parser.add_argument('--not-original', type=int, default=1)
82 | parser.add_argument('--updated', action='store_true', help='Patched models?', default=False)
83 |
84 | args = parser.parse_args()
85 | return args
86 |
87 | def main():
88 | args = get_args_parser()
89 |
90 | mods = [args.mod]
91 | nots = [bool(args.not_original)]
92 | args.model_in = ' '.join(args.model_in)
93 | ll = [args.model_in]
94 | args.ckpt_path = args.model_in
95 |     data_path = 'path_to_imagenet_validation_set'
96 |
97 | device = 'cuda'
98 |
99 | assert len(mods) == len(nots) == len(ll)
100 |
101 | print('using fixed val set')
102 |
103 |
104 | crop_pct = 0.875
105 |
106 | img_size = args.img_size
107 |
108 | scale_size = int(math.floor(img_size / crop_pct))
109 | trans = transforms.Compose([
110 | transforms.Resize(
111 | scale_size,
112 | interpolation=transforms.InterpolationMode("bicubic")),
113 | transforms.CenterCrop(img_size),
114 | transforms.ToTensor()
115 | ])
116 | x_test_val, y_test_val = robustbench.data.load_imagenet(5000, data_dir=data_path,
117 | transforms_test = trans)
118 |
119 |
120 |
121 | print(f"{args.mod} has resolution : {img_size}")
122 |
123 | for idx, modd in enumerate(ll):
124 |
125 | args.ckpt_path += "/weights_20.pt"
126 | args.model = mods[idx]
127 | args.not_original = nots[idx]
128 |
129 | if not args.ckpt_path is None:
130 | # assert os.path.exists(args.ckpt_path), f'{args.ckpt_path} not found'
131 | args.savedir = '/'.join(args.ckpt_path.split('/')[:-1])
132 | print(args.savedir)
133 | # ep = args.ckpt_path.split('/')[-1].split('.pt')[0]
134 | with open(f'{args.savedir}/params.json', 'r') as f:
135 | params = json.load(f)
136 | args.use_blurpool = params['training.use_blurpool'] == 1
137 | if 'model.add_normalization' in params.keys():
138 | args.add_normalization = params['model.add_normalization'] == 1
139 | args.model = args.model #params['model.arch']
140 | else:
141 | args.savedir = './results/'
142 | makedir(args.savedir)
143 |
144 | args.n_cls = 1000
145 | args.num_workers = 1
146 |
147 | if not args.eps is None and args.eps > 1 and args.norm == 'Linf':
148 | args.eps /= 255.
149 |
150 |
151 | device = 'cuda'
152 | arch = args.model
153 | pretrained = False
154 | add_normalization = args.add_normalization
155 |
156 | log_path = f'{args.savedir}/evaluated_logs_{args.l_norms}_{args.full_aa}_8_255.txt'
157 | logger = Logger(log_path)
158 |
159 | print(f"Creating model: {args.model}")
160 | if not arch.startswith('timm_'):
161 | model = get_new_model(arch, pretrained=False, not_original=args.not_original, updated=args.updated)
162 | else:
163 | try:
164 | model = create_model(arch.replace('timm_', ''), pretrained=pretrained)
165 | except:
166 | model = get_new_model(arch.replace('timm_', ''))
167 |
168 | if add_normalization:
169 | print('add normalization layer')
170 | model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD)
171 |
172 | inpp = torch.rand(1, 3, 224, 224)
173 | flops = FlopCountAnalysis(model, inpp)
174 | val = flops.total()
175 | print(sizeof_fmt(int(val)))
176 | print(flop_count_table(flops, max_depth=2))
177 | print(flops.by_operator())
178 |
179 | accs = []
180 | best_test_rob = 0.
181 |
182 |     for i in range(1):  # evaluate a single checkpoint; extend the range to sweep multiple saved epochs
183 |
184 | ckpt = torch.load(args.savedir + f"/model_file_name.pt", map_location='cpu') #['model']
185 | # print(ckpt.keys())
186 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
187 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()}
188 | ckpt = {k.replace('se_', 'se_module.'): v for k, v in ckpt.items()}
189 | model.load_state_dict(ckpt)
190 | model = model.to(device)
191 | model.eval()
192 | acc = clean_accuracy(model, x_test_val, y_test_val, batch_size=args.batch_size,
193 | device=device)
194 | print(f"clean {i} : {acc}")
195 | img_sizer = [img_size, img_size]
196 | if not args.ckpt_path is None:
197 | ## change the size of positional embedding in ViT if img_size>224 (training resolution).
198 | if "vit" in arch:
199 | old_shape = ckpt['pos_embed'].shape
200 | ckpt['pos_embed'] = interpolate_pos_encoding(
201 | ckpt['pos_embed'], new_img_size=img_sizer[0],
202 | patch_size=model.patch_embed.patch_size[0])
203 | new_shape = ckpt['pos_embed'].shape
204 | print(old_shape, new_shape)
205 | model.pos_embed = nn.Parameter(torch.zeros(new_shape, device=model.pos_embed.device))
206 |
207 | model.patch_embed.img_size = img_sizer
208 | model.patch_embed.num_patches = new_shape[1] - 1
209 | model.patch_embed.grid_size = (
210 | img_sizer[0] // model.patch_embed.patch_size[0],
211 | img_sizer[1] // model.patch_embed.patch_size[1])
212 | model.eval()
213 |
214 | str_to_log = ''
215 |
216 | logger = Logger(log_path)
217 | logger.log(str_to_log)
218 |
219 | all_norms = [args.l_norms] #
220 | #all_norms = ['L2', 'L1', 'Linf']
221 | l_epss = [eps_dict['imagenet'][c] for c in all_norms]
222 |         logger.log(f'norms={all_norms} epss={l_epss}')
223 | all_acs = []
224 | for idx, nrm in enumerate(all_norms):
225 | epss = l_epss[idx]
226 | adversary = AutoAttack(model, norm=nrm, eps=epss,
227 | version='standard', log_path=log_path)
228 | str_to_log = ''
229 |
230 | if not bool(args.full_aa):
231 | adversary.attacks_to_run = ['apgd-ce', 'apgd-t']
232 |
233 | str_to_log += f'norm={nrm} eps={l_epss[idx]:.5f}\n'
234 |
235 | assert not model.training
236 |
237 | with torch.no_grad():
238 | x_adv = adversary.run_standard_evaluation(x_test_val,
239 | y_test_val, bs=args.batch_size)
240 |
241 | acc = clean_accuracy(model, x_adv, y_test_val, batch_size=args.batch_size,
242 | device=device)
243 | print('robust accuracy: {:.2%}'.format(acc))
244 | str_to_log += 'robust accuracy: {:.2%}\n'.format(acc)
245 | logger.log(str_to_log)
246 | all_acs.append(acc)
247 |
248 | if args.save_imgs:
249 | valset = '_oldset' if args.use_fixed_val_set else ''
250 | runname = f'aa_short_1_{args.n_ex}_{args.norm}_{args.eps:.5f}{valset}.pth'
251 | savepath = f'{args.savedir}/{runname}'
252 | torch.save(x_adv.cpu(), savepath)
253 |
254 | if __name__ == '__main__':
255 | main()
256 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models
2 | #### Naman D Singh, Francesco Croce, Matthias Hein
3 | #### University of Tübingen
4 | #### NeurIPS 2023
5 |
6 | ### [Paper](https://arxiv.org/abs/2303.01870)
7 | ## Abstract
8 | While adversarial training has been extensively studied for ResNet architectures and low resolution datasets like CIFAR, much less is known for ImageNet. Given the recent debate about whether transformers are more robust than convnets, we revisit adversarial training on ImageNet comparing ViTs and ConvNeXts. Extensive experiments show that minor changes in architecture, most notably replacing PatchStem with ConvStem, and training scheme have a significant impact on the achieved robustness. These changes not only increase robustness in the seen $\ell_\infty$-threat model, but even more so improve generalization to unseen $\ell_1/\ell_2$-robustness.
9 |
10 | 
11 |
12 |
13 | ## Code
14 | Requirements (specific versions tested on):
15 | `fastargs-1.2.0` `autoattack-0.1` `pytorch-1.13.1` `torchvision-0.14.1` `robustbench-1.1` `timm-0.8.0.dev0`, `GPUtil`
16 |
17 | #### Training
18 | The bash script `run_train.sh` trains the model specified by `model.arch`. For clean training set `adv.attack none`; for adversarial training set `adv.attack apgd`.
19 | For the standard setting as in the paper (heavy augmentations) set `data.augmentations 1`, `model.model_ema 1` and `training.label_smoothing 1`.
20 | To train models with Convolution-Stem (CvSt) set `model.not_original 1`.
21 | The code does standard APGD adversarial training. The file `utils_architecture.py` has the model definitions for the new `CvSt` models; all models are built on top of timm imports.
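
A minimal sketch of such a training call, assuming the fastargs-style `--section.param` flags exposed by `main.py` (the actual wrapper lives in `run_train.sh`; the paths and the architecture name below are placeholders):
```bash
# Hypothetical invocation: adversarial training of a ConvStem (CvSt) model in the paper's standard setting.
python main.py \
    --model.arch convnext_tiny_convmlp_nolayerscale \
    --model.not_original 1 \
    --adv.attack apgd --adv.norm Linf \
    --data.augmentations 1 --model.model_ema 1 --training.label_smoothing 1 \
    --data.train_dataset <train-data> --data.val_dataset <val-data> \
    --data.num_workers 8 --data.in_memory 1 \
    --logging.folder <output-dir>
```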
22 |
23 | #### Evaluating a model
24 | The file `runner_aa_eval.py` runs `AutoAttack` (AA). Passing `fullaa 1` runs the complete AA suite, whereas `fullaa 0` runs only the first two attacks of AA (APGD-CE and APGD-T).
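
The underlying evaluation script `AA_eval.py` (which `runner_aa_eval.py` presumably wraps) can also be called directly; a sketch, where the checkpoint directory is a placeholder and is expected to contain the training `params.json` and the weight file:
```bash
# Hypothetical direct evaluation call; flags follow the argparse options defined in AA_eval.py.
python AA_eval.py \
    --model_in <path-to-checkpoint-dir> \
    --mod convnext_tiny_convmlp_nolayerscale \
    --l_norms Linf \
    --full_aa 0 \
    --batch_size 200 --img_size 224
```
Setting `--full_aa 1` runs the complete AutoAttack suite instead of only APGD-CE and APGD-T.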
25 |
26 |
27 | #### Checkpoints - ImageNet $\ell_{\infty} = 4/255$ robust models.
28 | Each linked location includes weights for the clean model (the one used as initialization for adversarial training (AT)), the robust model, and the `full-AA` log for the $\ell_{\infty}$, $\ell_2$ and $\ell_1$ attacks.
29 | Note: the higher-resolution rows use the same checkpoint as the standard resolution of 224; only the evaluation is done at the higher resolution. A minimal loading sketch follows the table.
30 | | Model-Name | epochs | res. | Clean acc. | AA - $\ell_{\infty}$ acc.| Checkpoint (clean-init and robust) |
31 | | :--- | :------: | :------: | :------: |:------: | :------: |
32 | | ConvNext-iso-CvSt | 300 | 224 | 70.2 | 45.9 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/HpNbkLTNTBiaeo8)|
33 | | ViT-S | 300 | 224 | 69.2 | 44.0 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)|
34 | | ViT-S-CvSt | 300 | 224 | 72.5 | 48.1 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/agtDw3D7QXbDCmw)|
35 | | ConvNext-T | 300 | 224 | 72.4 | 48.6 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)|
36 | | ConvNext-T-CvSt | 300 | 224 | 72.7 | 49.5 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/BFLoMrMdn8iBk7Y)|
37 | | ViT-M-CvSt | 50 | 224 | 72.4 | 48.8 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/q2mkEYtq5Zjpa4e)|
38 | | ConvNext-S-CvSt | 50 | 224 | 74.1 | 52.4 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/m3bAwNg4CJY4jrp)|
39 | | ViT-B | 50 | 224 | 73.3 | 50.0 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)|
40 | | ConvNext-B | 50 | 224 | 75.6 | 54.3 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/XLLnoCnJxp74Zqn)|
41 | | ViT-B-CvSt | 250 | 224 | 76.3 | 54.7 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/SbN5AJAicdZJXyr)|
42 | | ConvNext-B-CvSt | 250 | 224 | 75.9 | 56.1 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/RQBEXagC7R7XweX)|
43 | | ConvNext-B-CvSt* | --- | 256 | 76.9 | 57.3 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/RQBEXagC7R7XweX)|
44 | | ConvNext-L-CvSt | 100 | 224 | 77.0 | 57.7 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/YzBpeHRrRQzHBDz)|
45 | | ConvNext-L-CvSt* | --- | 320 | 78.2 | 59.4 | [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/YzBpeHRrRQzHBDz)|
46 | ###### *: increased resolution (only for evaluation) also leads to increased FLOPs.
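
A minimal loading sketch for a downloaded robust checkpoint, following the state-dict handling in `AA_eval.py` (the architecture string and the file path are placeholders to adapt):
```python
import torch
from main import IMAGENET_MEAN, IMAGENET_STD
from utils_architecture import get_new_model, normalize_model

# Build the architecture (not_original=1 selects the ConvStem/CvSt variant) and prepend input normalization.
model = get_new_model('convnext_tiny_convmlp_nolayerscale', pretrained=False, not_original=1)
model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD)

# Checkpoints may come from wrapped training (DataParallel / WrappedModel), so strip those prefixes.
ckpt = torch.load('<path-to-weights>.pt', map_location='cpu')
ckpt = {k.replace('module.', '').replace('base_model.', ''): v for k, v in ckpt.items()}
model.load_state_dict(ckpt)
model.eval()
```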
47 | -------------------
48 | Checkpoints along with accuracy and robustness logs for ImageNet models finetuned to be robust at $\ell_\infty = 8/255$ are available here: [Link](https://nc.mlcloud.uni-tuebingen.de/index.php/s/FiTToeo4RKY896P)
49 | ________________________________
50 | ## Citation
51 |
52 | If you use our code or models, please cite our work using the following BibTeX entry:
53 | ```bibtex
54 | @inproceedings{singh2023revisiting,
55 | title={Revisiting Adversarial Training for ImageNet: Architectures, Training and Generalization across Threat Models},
56 | author={Singh, Naman D and Croce, Francesco and Hein, Matthias},
57 | booktitle={NeurIPS},
58 | year={2023}}
59 | ```
60 |
--------------------------------------------------------------------------------
/autopgd_train_clean.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import random
6 | import time
7 |
8 | def L1_norm(x, keepdim=False):
9 | z = x.abs().contiguous().view(x.shape[0], -1).sum(-1)
10 | if keepdim:
11 | z = z.contiguous().view(-1, *[1]*(len(x.shape) - 1))
12 | return z
13 |
14 | def L2_norm(x, keepdim=False):
15 | z = (x ** 2).contiguous().view(x.shape[0], -1).sum(-1).sqrt()
16 | if keepdim:
17 | z = z.contiguous().view(-1, *[1]*(len(x.shape) - 1))
18 | return z
19 |
20 | def L0_norm(x):
21 | return (x != 0.).view(x.shape[0], -1).sum(-1)
22 |
23 |
24 | def L1_projection(x2, y2, eps1):
25 | '''
26 | x2: center of the L1 ball (bs x input_dim)
27 | y2: current perturbation (x2 + y2 is the point to be projected)
28 | eps1: radius of the L1 ball
29 |
30 | output: delta s.th. ||y2 + delta||_1 = eps1
31 | and 0 <= x2 + y2 + delta <= 1
32 | '''
33 |
34 | x = x2.clone().float().view(x2.shape[0], -1)
35 | y = y2.clone().float().view(y2.shape[0], -1)
36 | sigma = y.clone().sign()
37 | u = torch.min(1 - x - y, x + y)
38 | #u = torch.min(u, epsinf - torch.clone(y).abs())
39 | u = torch.min(torch.zeros_like(y), u)
40 | l = -torch.clone(y).abs()
41 | d = u.clone()
42 |
43 | bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1)
44 | bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)
45 |
46 | inu = 2*(indbs < u.shape[1]).float() - 1
47 | size1 = inu.cumsum(dim=1)
48 |
49 | s1 = -u.sum(dim=1)
50 |
51 | c = eps1 - y.clone().abs().sum(dim=1)
52 | c5 = s1 + c < 0
53 | c2 = c5.nonzero().squeeze(1)
54 |
55 | s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)
56 | #print(s[0])
57 |
58 | #print(c5.shape, c2)
59 |
60 |     if c2.nelement() != 0:
61 |
62 | lb = torch.zeros_like(c2).float()
63 | ub = torch.ones_like(lb) *(bs.shape[1] - 1)
64 |
65 | #print(c2.shape, lb.shape)
66 |
67 | nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float()))
68 | counter2 = torch.zeros_like(lb).long()
69 | counter = 0
70 |
71 | while counter < nitermax:
72 | counter4 = torch.floor((lb + ub) / 2.)
73 | counter2 = counter4.type(torch.LongTensor)
74 |
75 | c8 = s[c2, counter2] + c[c2] < 0
76 | ind3 = c8.nonzero().squeeze(1)
77 | ind32 = (~c8).nonzero().squeeze(1)
78 | #print(ind3.shape)
79 |             if ind3.nelement() != 0:
80 | lb[ind3] = counter4[ind3]
81 |             if ind32.nelement() != 0:
82 | ub[ind32] = counter4[ind32]
83 |
84 | #print(lb, ub)
85 | counter += 1
86 |
87 | lb2 = lb.long()
88 | alpha = (-s[c2, lb2] -c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
89 | d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2])
90 |
91 | return (sigma * d).view(x2.shape)
92 |
93 |
94 | def softloss(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
95 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
96 | return loss.mean()
97 |
98 |
99 | def dlr_loss(x, y, reduction='none'):
100 | x_sorted, ind_sorted = x.sort(dim=1)
101 | ind = (ind_sorted[:, -1] == y).float()
102 |
103 | return -(x[torch.arange(x.shape[0]), y] - x_sorted[:, -2] * ind - \
104 | x_sorted[:, -1] * (1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)
105 |
106 | def dlr_loss_targeted(x, y, y_target):
107 | x_sorted, ind_sorted = x.sort(dim=1)
108 | u = torch.arange(x.shape[0])
109 |
110 | return -(x[u, y] - x[u, y_target]) / (x_sorted[:, -1] - .5 * (
111 | x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12)
112 |
113 | criterion_dict = {'ce': lambda x, y: F.cross_entropy(x, y, reduction='none'),'softloss':softloss,
114 | 'dlr': dlr_loss, 'dlr-targeted': dlr_loss_targeted}
115 |
116 | def check_oscillation(x, j, k, y5, k3=0.75):
117 | t = torch.zeros(x.shape[1], device=x.device, dtype=x.dtype)
118 | for counter5 in range(k):
119 | t += (x[j - counter5] > x[j - counter5 - 1]).float()
120 |
121 | return (t <= k * k3 * torch.ones_like(t)).float()
122 |
123 | def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss='ce',
124 | verbose=False, mixup=None, is_train=True):
125 | assert not model.training
126 | device = x.device
127 | ndims = len(x.shape) - 1
128 |
129 | times = {'fp': 0, 'bp': 0, 'total': time.time(), 'blk1': 0, 'blk2': 0,
130 | 'blk3': 0}
131 | #print(x.is_contiguous())
132 |
133 | #startt = time.time()
134 | if not use_rs:
135 | x_adv = x.clone()
136 | else:
137 |         raise NotImplementedError
138 | if norm == 'Linf':
139 | t = torch.rand_like(x)
140 |
141 | x_adv = x_adv.clamp(0., 1.)
142 | x_best = x_adv.clone().detach()
143 | x_best_adv = x_adv.clone().detach()
144 | loss_steps = torch.zeros([n_iter, x.shape[0]], device=device)
145 | loss_best_steps = torch.zeros([n_iter + 1, x.shape[0]], device=device)
146 | acc_steps = torch.zeros_like(loss_best_steps)
147 |
148 | # set loss
149 | criterion_indiv = criterion_dict[loss]
150 |
151 | # set params
152 | n_fts = math.prod(x.shape[1:])
153 | if norm in ['Linf', 'L2']:
154 | n_iter_2 = max(int(0.22 * n_iter), 1)
155 | n_iter_min = max(int(0.06 * n_iter), 1)
156 | size_decr = max(int(0.03 * n_iter), 1)
157 | k = n_iter_2 + 0
158 | thr_decr = .75
159 | alpha = 2.
160 | elif norm in ['L1']:
161 | k = max(int(.04 * n_iter), 1)
162 | init_topk = .05 if is_train else .2
163 | topk = init_topk * torch.ones([x.shape[0]], device=device)
164 | sp_old = n_fts * torch.ones_like(topk)
165 | adasp_redstep = 1.5
166 | adasp_minstep = 10.
167 | alpha = 1.
168 |
169 | step_size = alpha * eps * torch.ones([x.shape[0], *[1] * ndims],
170 | device=device, dtype=x.dtype)
171 | counter3 = 0
172 | #times['blk1'] += time.time() - startt
173 |
174 | x_adv.requires_grad_()
175 | #grad = torch.zeros_like(x)
176 | #for _ in range(self.eot_iter)
177 | #with torch.enable_grad()
178 | startt = time.time()
179 | logits = model(x_adv)
180 | times['fp'] += time.time() - startt
181 | loss_indiv = criterion_indiv(logits, y)
182 | loss = loss_indiv.sum()
183 | #grad += torch.autograd.grad(loss, [x_adv])[0].detach()
184 | startt = time.time()
185 | grad = torch.autograd.grad(loss, [x_adv])[0].detach()
186 | times['bp'] += time.time() - startt
187 | #grad /= float(self.eot_iter)
188 | #startt = time.time()
189 | grad_best = grad.clone()
190 | x_adv.detach_()
191 | loss_indiv.detach_()
192 | loss.detach_()
193 |
194 | if mixup is not None:
195 | acc = logits.detach().max(1)[1] == y.max(1)[1]
196 | else:
197 | acc = logits.detach().max(1)[1] == y
198 | acc_steps[0] = acc + 0
199 | loss_best = loss_indiv.detach().clone()
200 | loss_best_last_check = loss_best.clone()
201 | reduced_last_check = torch.ones_like(loss_best)
202 | n_reduced = 0
203 |
204 | u = torch.arange(x.shape[0], device=device)
205 | x_adv_old = x_adv.clone().detach()
206 | #times['blk2'] += time.time() - startt
207 |
208 | #startt = time.time()
209 | for i in range(n_iter):
210 | ### gradient step
211 | if True: #with torch.no_grad()
212 | #startt = time.time()
213 | x_adv = x_adv.detach()
214 | grad2 = x_adv - x_adv_old
215 | x_adv_old = x_adv.clone()
216 | loss_curr = loss.detach().mean()
217 |
218 | a = 0.75 if i > 0 else 1.0
219 |
220 | if norm == 'Linf':
221 | x_adv_1 = x_adv + step_size * torch.sign(grad)
222 | x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1,
223 | x - eps), x + eps), 0.0, 1.0)
224 | x_adv_1 = torch.clamp(torch.min(torch.max(
225 | x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
226 | x - eps), x + eps), 0.0, 1.0)
227 |
228 | elif norm == 'L2':
229 | x_adv_1 = x_adv + step_size * grad / (L2_norm(grad,
230 | keepdim=True) + 1e-12)
231 | x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
232 | keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
233 | L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)
234 | x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
235 | x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
236 | keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
237 | L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)
238 |
239 | elif norm == 'L1':
240 | grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
241 | topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
242 | grad_topk = grad_topk[u, topk_curr].view(-1, *[1]*(len(x.shape) - 1))
243 | sparsegrad = grad * (grad.abs() >= grad_topk).float()
244 | x_adv_1 = x_adv + step_size * sparsegrad.sign() / (
245 | sparsegrad.sign().abs().view(x.shape[0], -1).sum(dim=-1).view(
246 | -1, 1, 1, 1) + 1e-10)
247 |
248 | delta_u = x_adv_1 - x
249 | delta_p = L1_projection(x, delta_u, eps)
250 | x_adv_1 = x + delta_u + delta_p
251 |
252 | elif norm == 'L0':
253 | L1normgrad = grad / (grad.abs().view(grad.shape[0], -1).sum(
254 | dim=-1, keepdim=True) + 1e-12).view(grad.shape[0], *[1]*(
255 | len(grad.shape) - 1))
256 | x_adv_1 = x_adv + step_size * L1normgrad * n_fts
257 | x_adv_1 = L0_projection(x_adv_1, x, eps)
258 | # TODO: add momentum
259 |
260 | x_adv = x_adv_1 + 0.
261 | #times['blk1'] += time.time() - startt
262 | #return x_adv
263 |
264 | #startt = time.time() # t1
265 |
266 | ### get gradient
267 | if i < n_iter - 1:
268 | x_adv.requires_grad_()
269 | #grad = torch.zeros_like(x)
270 | #for _ in range(self.eot_iter)
271 | #with torch.enable_grad()
272 | startt = time.time()
273 | logits = model(x_adv)
274 | times['fp'] += time.time() - startt
275 | loss_indiv = criterion_indiv(logits, y)
276 | loss = loss_indiv.sum()
277 | #times['blk1'] += time.time() - startt
278 |
279 | #startt = time.time() # t2
280 | #grad += torch.autograd.grad(loss, [x_adv])[0].detach()
281 | if i < n_iter - 1:
282 | # save one backward pass
283 | grad = torch.autograd.grad(loss, [x_adv])[0].detach()
284 | #grad /= float(self.eot_iter)
285 | x_adv.detach_()
286 | loss_indiv.detach_()
287 | loss.detach_()
288 | #times['blk2'] += time.time() - startt
289 |
290 | startt = time.time()
291 | if mixup is not None:
292 | pred = logits.detach().max(1)[1] == y.max(1)[1]
293 | else:
294 | pred = logits.detach().max(1)[1] == y
295 |
296 | acc = torch.min(acc, pred)
297 | startt = time.time()
298 | acc_steps[i + 1] = acc + 0
299 | times['blk1'] += time.time() - startt
300 | startt = time.time()
301 | ind_pred = ~pred
302 | times['blk2'] += time.time() - startt
303 | startt = time.time()
304 | x_best_adv[ind_pred] = x_adv[ind_pred] + 0.
305 | times['blk3'] += time.time() - startt
306 | if verbose:
307 | str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
308 | step_size.mean(), topk.mean() * n_fts) if norm in ['L1'] else ' - step size: {:.5f}'.format(
309 | step_size.mean())
310 | print('iteration: {} - best loss: {:.6f} curr loss {:.6f} - robust accuracy: {:.2%}{}'.format(
311 | i, loss_best.sum(), loss_curr, acc.float().mean(), str_stats))
312 | #print('pert {}'.format((x - x_best_adv).abs().view(x.shape[0], -1).sum(-1).max()))
313 |
314 | #times['blk3'] += time.time() - startt
315 | #startt = time.time() # t3
316 |
317 | ### check step size
318 | if True: #with torch.no_grad()
319 | y1 = loss_indiv.detach().clone()
320 | loss_steps[i] = y1 + 0
321 | ind = (y1 > loss_best).nonzero().squeeze()
322 | x_best[ind] = x_adv[ind].clone()
323 | grad_best[ind] = grad[ind].clone()
324 | loss_best[ind] = y1[ind] + 0
325 | loss_best_steps[i + 1] = loss_best + 0
326 |
327 | counter3 += 1
328 |
329 | if counter3 == k:
330 | if norm in ['Linf', 'L2']:
331 | fl_oscillation = check_oscillation(loss_steps, i, k,
332 | loss_best, k3=thr_decr)
333 | fl_reduce_no_impr = (1. - reduced_last_check) * (
334 | loss_best_last_check >= loss_best).float()
335 | fl_oscillation = torch.max(fl_oscillation,
336 | fl_reduce_no_impr)
337 | reduced_last_check = fl_oscillation.clone()
338 | loss_best_last_check = loss_best.clone()
339 |
340 | if fl_oscillation.sum() > 0:
341 | ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
342 | step_size[ind_fl_osc] /= 2.0
343 | n_reduced = fl_oscillation.sum()
344 |
345 | x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
346 | grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()
347 |
348 | counter3 = 0
349 | k = max(k - size_decr, n_iter_min)
350 |
351 | elif norm == 'L1':
352 | # adjust sparsity
353 | sp_curr = L0_norm(x_best - x)
354 | fl_redtopk = (sp_curr / sp_old) < .95
355 | topk = sp_curr / n_fts / 1.5
356 | step_size[fl_redtopk] = alpha * eps
357 | step_size[~fl_redtopk] /= adasp_redstep
358 | step_size.clamp_(alpha * eps / adasp_minstep, alpha * eps)
359 | sp_old = sp_curr.clone()
360 |
361 | x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
362 | grad[fl_redtopk] = grad_best[fl_redtopk].clone()
363 |
364 | counter3 = 0
365 | #times['blk3'] += time.time() - startt
366 |
367 | if verbose:
368 | times['total'] = time.time() - times['total']
369 | print(' '.join([f'{k}={v:.5f} s' for k, v in times.items()]))
370 |
371 | return x_best, acc, loss_best, x_best_adv
372 |
373 |
374 | if __name__ == '__main__':
375 | #pass
376 | from train_new import parse_args
377 | from data import load_anydataset
378 | from utils_eval import check_imgs, load_anymodel_datasets, clean_accuracy
379 |
380 | args = parse_args()
381 | args.training_set = False
382 |
383 | x_test, y_test = load_anydataset(args, device='cpu')
384 | x, y = x_test.cuda(), y_test.cuda()
385 | model, l_models = load_anymodel_datasets(args)
386 |
387 | assert not model.training
388 |
389 | if args.attack == 'apgd_train':
390 | #with torch.no_grad()
391 | x_best, acc, _, x_adv = apgd_train(model, x, y, norm=args.norm,
392 | eps=args.eps, n_iter=args.n_iter, verbose=True, loss='ce')
393 | check_imgs(x_adv, x, args.norm)
394 |
395 | elif args.attack == 'apgd_test':
396 | from autoattack import AutoAttack
397 | adversary = AutoAttack(model, norm=args.norm, eps=args.eps)
398 | #adversary.attacks_to_run = ['apgd-ce']
399 | #adversary.apgd.verbose = True
400 | with torch.no_grad():
401 | x_adv = adversary.run_standard_evaluation(x, y, bs=1000)
402 | check_imgs(x_adv, x, args.norm)
403 |
--------------------------------------------------------------------------------
/dataset_convnext_like.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import os
10 | from torchvision import datasets, transforms
11 |
12 | from timm.data.constants import \
13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
14 | from timm.data import create_transform
15 |
16 | def build_dataset(is_train, args):
17 | transform = build_transform(is_train, args)
18 |
19 | print(f"Transform = train: {is_train}")
20 | if isinstance(transform, tuple):
21 | for trans in transform:
22 | print(" - - - - - - - - - - ")
23 | for t in trans.transforms:
24 | print(t)
25 | else:
26 | for t in transform.transforms:
27 | print(t)
28 | print("---------------------------")
29 | data_paths = ['/scratch/nsingh/imagenet',
30 | '/home/scratch/datasets/imagenet',
31 | '/scratch_local/datasets/ImageNet2012',
32 | '/scratch/datasets/imagenet/']
33 | for data_path in data_paths:
34 | if os.path.exists(data_path):
35 | break
36 | data_set = 'IMNET'
37 | if data_set == 'CIFAR':
38 | dataset = datasets.CIFAR100(data_path, train=is_train, transform=transform, download=True)
39 | nb_classes = 100
40 | elif data_set == 'IMNET':
41 | print("reading from datapath", data_path)
42 | root = os.path.join(data_path, 'train' if is_train else 'val')
43 | dataset = datasets.ImageFolder(root, transform=transform)
44 | nb_classes = 1000
45 | # elif data_set == "image_folder":
46 | # root = args.data_path if is_train else args.eval_data_path
47 | # dataset = datasets.ImageFolder(root, transform=transform)
48 | # nb_classes = args.nb_classes
49 | # assert len(dataset.class_to_idx) == nb_classes
50 | else:
51 | raise NotImplementedError()
52 |     print("Number of classes = %d" % nb_classes)
53 |
54 | return dataset, nb_classes
55 |
56 |
57 | def build_transform(is_train, args):
58 | resize_im = args.input_size > 32
59 | imagenet_default_mean_and_std = False #args.imagenet_default_mean_and_std
60 | mean = [0., 0., 0.] #IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
61 | std = [1., 1., 1.] #IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
62 | transform = None
63 | if is_train:
64 | # this should always dispatch to transforms_imagenet_train
65 | transform = create_transform(
66 | input_size=args.input_size,
67 | is_training=True,
68 | color_jitter=args.color_jitter,
69 | auto_augment=args.aa,
70 | interpolation=args.train_interpolation,
71 | re_prob=args.reprob,
72 | re_mode=args.remode,
73 | re_count=args.recount,
74 | mean=mean,
75 | std=std,
76 | scale=args.scale,
77 | ratio=args.ratio,
78 | hflip=args.hflip,
79 | vflip=args.vflip,
80 | crop_pct=args.crop_pct
81 | )
82 |
83 | return transform
84 |
85 | t = []
86 | if resize_im:
87 | # warping (no cropping) when evaluated at 384 or larger
88 | if args.input_size >= 384:
89 | t.append(
90 | transforms.Resize((args.input_size, args.input_size),
91 | interpolation=transforms.InterpolationMode.BICUBIC),
92 | )
93 | print(f"Warping {args.input_size} size input images...")
94 | else:
95 | if args.crop_pct is None:
96 | args.crop_pct = 224 / 256
97 | size = int(args.input_size / args.crop_pct)
98 | t.append(
99 | # to maintain same ratio w.r.t. 224 images
100 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
101 | )
102 | t.append(transforms.CenterCrop(args.input_size))
103 |
104 | t.append(transforms.ToTensor())
105 | # t.append(transforms.Normalize(mean, std))
106 | return transforms.Compose(t)
107 |
--------------------------------------------------------------------------------
/fgsm_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from autopgd_train_clean import criterion_dict
6 | import robustbench as rb
7 | #from autopgd_train import apgd_train
8 | # import utils
9 | # from model_zoo.fast_models import PreActResNet18
10 | # import other_utils
11 | import autoattack
12 | criterion_dict = {'ce': lambda x, y: F.cross_entropy(x, y, reduction='none')}
13 |
14 |
15 |
16 |
17 | def fgsm_attack(model, images, labels, eps) :
18 |
19 | loss = nn.CrossEntropyLoss()
20 | images.requires_grad = True
21 |
22 | outputs = model(images)
23 |
24 | model.zero_grad()
25 |     cost = loss(outputs, labels)
26 | cost.sum().backward()
27 | attack_images = images.clone()
28 | attack_images += eps*images.grad.sign()
29 | attack_images = torch.clamp(attack_images, 0, 1)
30 |
31 | return attack_images
32 |
33 |
34 |
35 |
36 | def fgsm_attack(model, x, y, eps=4./255.):
37 |
38 | # assert not model.training
39 |
40 | # Set requires_grad attribute of tensor. Important for Attack
41 | x.requires_grad = True
42 |
43 | # Forward pass the data through the model
44 | output = model(x)
45 | # init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
46 |
47 | criterion_indiv = criterion_dict['ce']
48 |
49 | # Calculate the loss
50 | loss = criterion_indiv(output, y)
51 |
52 | # Zero all existing gradients
53 | model.zero_grad()
54 |
55 | # Calculate gradients of model in backward pass
56 | loss.sum().backward()
57 | # Collect datagrad
58 | data_grad = x.grad.data
59 |
60 |     # Apply the FGSM step and keep the result in the valid image range
61 |     x_adv = torch.clamp(x + eps * data_grad.sign(), 0., 1.)
62 |
63 | return x_adv
64 | # output = model(x_adv)
65 | # final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
66 | # if final_pred.item() == target.item():
67 | # correct += 1
68 | # return correct
69 |
70 |
71 |
72 | def fgsm_train(model, x, y, eps, loss='ce', alpha=1.25, use_rs=False,
73 | noise_level=1., skip_projection=False):
74 | assert not model.training
75 |
76 | if not use_rs:
77 | x_adv = x.clone()
78 | else:
79 | #raise NotImplemented
80 | #if norm == 'Linf'
81 | t = torch.rand_like(x)
82 | x_adv = x + (2. * t - 1.) * eps * noise_level
83 | if not skip_projection:
84 | x_adv.clamp_(0., 1.)
85 |
86 | criterion_indiv = criterion_dict[loss]
87 |
88 | x_adv.requires_grad = True
89 | logits = model(x_adv)
90 | loss_indiv = criterion_indiv(logits, y)
91 | grad = torch.autograd.grad(loss_indiv.sum(), x_adv)[0].detach()
92 |
93 | x_adv = x_adv.detach() + alpha * eps * grad.sign()
94 | if not skip_projection:
95 | x_adv = x + (x_adv - x).clamp(-eps, eps)
96 | x_adv.clamp_(0., 1.)
97 |
98 | return x_adv
99 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''Main runnable file for ImageNet experiments
2 | '''
3 |
4 | import sys
5 | sys.path.insert(0,'..')
6 | from math import ceil
7 | import math
8 | import numpy as np
9 | import os, sys
10 | from os import get_terminal_size
11 | from timm.loss.cross_entropy import SoftTargetCrossEntropy
12 | from timm.models import create_model
13 | from datetime import datetime
14 | import argparse, sys, torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | from torchinfo import summary
18 | import random
19 | import torch as ch
20 | import torch.nn as nn
21 | from torch.cuda.amp import GradScaler
22 | from torch.cuda.amp import autocast
23 | import torch.nn.functional as F
24 | import torch.distributed as dist
25 | ch.backends.cudnn.benchmark = True
26 | ch.autograd.profiler.emit_nvtx(False)
27 | ch.autograd.profiler.profile(False)
28 |
29 | import argparse
30 | import parserr
31 | from dataset_convnext_like import build_dataset
32 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy
33 | import torchvision
34 | from torchvision import models
35 | import torchmetrics
36 | import numpy as np
37 | from tqdm import tqdm
38 | import time
39 | import json
40 | from uuid import uuid4
41 | from typing import List
42 | from pathlib import Path
43 | from argparse import ArgumentParser
44 | from datetime import datetime
45 | from functools import partial
46 | from fastargs import get_current_config
47 | from fastargs.decorators import param
48 | from fastargs import Param, Section
49 | from fastargs.validation import And, OneOf
50 |
51 | from ffcv.pipeline.operation import Operation
52 | from ffcv.loader import Loader, OrderOption
53 | from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
54 | RandomHorizontalFlip, ToTorchImage, Convert
55 | from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
56 | RandomResizedCropRGBImageDecoder
57 | from ffcv.fields.basics import IntDecoder
58 | import timm
59 | from timm.loss import SoftTargetCrossEntropy
60 | from timm.data.mixup import Mixup
61 | from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter
62 |
63 | from autopgd_train_clean import apgd_train
64 | from fgsm_train import fgsm_train, fgsm_attack
65 | from utils_architecture import normalize_model, get_new_model
66 | from ptflops import get_model_complexity_info
67 | from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str
68 |
69 | def sizeof_fmt(num, suffix="Flops"):
70 | for unit in ["", "Ki", "Mi", "G", "T"]:
71 | if abs(num) < 1000.0:
72 | return f"{num:3.3f}{unit}{suffix}"
73 | num /= 1000.0
74 | return f"{num:.1f}Yi{suffix}"
75 |
76 |
77 |
78 | import warnings
79 | warnings.filterwarnings("ignore", category=DeprecationWarning)
80 | warnings.filterwarnings("ignore")
81 | warnings.filterwarnings("ignore", category=FutureWarning)
82 | # warnings.filterwarnings("ignore", category=UserWarning)
83 | os.environ['KMP_WARNINGS'] = 'off'
84 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
85 | # os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
86 |
87 | class LabelSmoothingCrossEntropy(nn.Module):
88 | """ NLL loss with label smoothing.
89 | """
90 | def __init__(self, smoothing=0.1):
91 | super(LabelSmoothingCrossEntropy, self).__init__()
92 | assert smoothing < 1.0
93 | self.smoothing = smoothing
94 | self.confidence = 1. - smoothing
95 |
96 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
97 | logprobs = F.log_softmax(x, dim=-1)
98 | target = target.type(ch.int64)
99 | nll_loss = -logprobs.gather(dim=-1, index=target)
100 | nll_loss = nll_loss.squeeze(1)
101 | smooth_loss = -logprobs.mean(dim=-1)
102 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
103 | return loss.mean()
104 |
105 |
106 | Section('model', 'model details').params(
107 | arch=Param(str, default='effnet_b0'),
108 | pretrained=Param(int, 'is pretrained? (1/0)', default=1),
109 | ckpt_path=Param(str, 'path to resume model', default=''),
110 | add_normalization=Param(int, '0 if no normalization, 1 otherwise', default=1),
111 |     not_original=Param(int, 'use the ConvStem (CvSt) variant instead of the original stem? (1/0)', default=0),
112 | updated=Param(int, 'Make conviso Big?', default=0),
113 | model_ema=Param(float, 'Use EMA?', default=0),
114 | freeze_some=Param(int, 'freeze some layers', default=0),
115 | early=Param(int, 'freeze early layers?', default=1),
116 | )
117 |
118 | Section('resolution', 'resolution scheduling').params(
119 | min_res=Param(int, 'the minimum (starting) resolution', default=160),
120 |     max_res=Param(int, 'the maximum (final) resolution', default=160),
121 | end_ramp=Param(int, 'when to stop interpolating resolution', default=0),
122 | start_ramp=Param(int, 'when to start interpolating resolution', default=0)
123 | )
124 |
125 | Section('data', 'data related stuff').params(
126 | train_dataset=Param(str, '.dat file to use for training', required=True),
127 | val_dataset=Param(str, '.dat file to use for validation', required=True),
128 | num_workers=Param(int, 'The number of workers', required=True),
129 | in_memory=Param(int, 'does the dataset fit in memory? (1/0)', required=True),
130 | seed=Param(int, 'seed for training loader', default=0),
131 | augmentations=Param(int, 'use fancy augmentations?', default=0)
132 | )
133 |
134 | Section('lr', 'lr scheduling').params(
135 | step_ratio=Param(float, 'learning rate step ratio', default=0.1),
136 | step_length=Param(int, 'learning rate step length', default=30),
137 | lr_schedule_type=Param(OneOf(['step', 'cyclic', 'cosine']), default='cosine'),
138 | lr=Param(float, 'learning rate', default=1e-3),
139 | lr_peak_epoch=Param(int, 'Epoch at which LR peaks', default=10),
140 | )
141 |
142 | Section('logging', 'how to log stuff').params(
143 | folder=Param(str, 'log location', default="/mnt/SHARED/nsingh/ImageNet_Arch/full_Img/"),
144 | log_level=Param(int, '0 if only at end 1 otherwise', default=1),
145 | save_freq=Param(int, 'save models every nth epoch', default=2),
146 | addendum=Param(str, 'additional comments?', default=""),
147 | )
148 |
149 | Section('validation', 'Validation parameters stuff').params(
150 | batch_size=Param(int, 'The batch size for validation', default=64),
151 | resolution=Param(int, 'final resized validation image size', default=224),
152 | lr_tta=Param(int, 'should do lr flipping/avging at test time', default=0),
153 | precision=Param(str, 'np precision', default='fp16')
154 | )
155 |
156 | Section('training', 'training hyper param stuff').params(
157 | eval_only=Param(int, 'eval only?', default=0),
158 | batch_size=Param(int, 'The batch size', default=512),
159 | optimizer=Param(And(str, OneOf(['sgd', 'adamw'])), 'The optimizer', default='adamw'),
160 | momentum=Param(float, 'SGD momentum', default=0.9),
161 | weight_decay=Param(float, 'weight decay', default=0.05),
162 | epochs=Param(int, 'number of epochs', default=100),
163 | label_smoothing=Param(float, 'label smoothing parameter', default=0.1),
164 | distributed=Param(int, 'is distributed?', default=0),
165 | use_blurpool=Param(int, 'use blurpool?', default=0),
166 | precision=Param(str, 'np precision', default='fp16'),
167 | )
168 |
169 | Section('dist', 'distributed training options').params(
170 | world_size=Param(int, 'number gpus', default=1),
171 | address=Param(str, 'address', default='localhost'),
172 | port=Param(str, 'port', default='12355')
173 | )
174 |
175 | Section('adv', 'adversarial training options').params(
176 | attack=Param(str, 'if None standard training', default='none'),
177 | norm=Param(str, '', default='Linf'),
178 | eps=Param(float, '', default=4./255.),
179 | n_iter=Param(int, '', default=2),
180 | verbose=Param(int, '', default=0),
181 | noise_level=Param(float, '', default=1.),
182 | skip_projection=Param(int, '', default=0),
183 | alpha=Param(float, 'step size multiplier', default=1.),
184 | )
185 |
186 | Section('misc', 'other parameters').params(
187 | notes=Param(str, '', default=''),
188 | use_channel_last=Param(int, 'whether to use channel last memory format', default=1),
189 | )
190 |
191 | IMAGENET_MEAN = [c * 1. for c in (0.485, 0.456, 0.406)] #[np.array([0., 0., 0.]), np.array([0.485, 0.456, 0.406])][-1] * 255
192 | IMAGENET_STD = [c * 1. for c in (0.229, 0.224, 0.225)] #[np.array([1., 1., 1.]), np.array([0.229, 0.224, 0.225])][-1] * 255
193 | NONORM_MEAN = np.array([0., 0., 0.])
194 | NONORM_STD = np.array([1., 1., 1.]) * 255
195 | DEFAULT_CROP_RATIO = 224/256
196 |
197 | PREC_DICT = {'fp16': np.float16, 'fp32': np.float32}
198 |
199 | def sizeof_fmt(num, suffix="Flops"):
200 | for unit in ["", "Ki", "Mi", "G", "T"]:
201 | if abs(num) < 1000.0:
202 | return f"{num:3.3f}{unit}{suffix}"
203 | num /= 1000.0
204 | return f"{num:.1f}Yi{suffix}"
205 |
206 |
207 |
208 | @param('lr.lr')
209 | @param('lr.step_ratio')
210 | @param('lr.step_length')
211 | @param('training.epochs')
212 | def get_step_lr(epoch, lr, step_ratio, step_length, epochs):
213 | if epoch >= epochs:
214 | return 0
215 |
216 | num_steps = epoch // step_length
217 | return step_ratio**num_steps * lr
218 |
219 | @param('lr.lr')
220 | @param('training.epochs')
221 | @param('lr.lr_peak_epoch')
222 | def get_cyclic_lr(epoch, lr, epochs, lr_peak_epoch):
223 | xs = [0, lr_peak_epoch, epochs]
224 | ys = [1e-4 * lr, lr, 0]
225 | return np.interp([epoch], xs, ys)[0]
226 |
227 | @param('lr.lr')
228 | @param('training.epochs')
229 | @param('lr.lr_peak_epoch')
230 | def get_cosine_lr(epoch, lr, epochs, lr_peak_epoch):
231 | # if epochs > 100:
232 | # lr_peak_epoch = 20
233 | # else:
234 | # lr_peak_epoch = 10
235 | if epoch <= lr_peak_epoch:
236 | xs = [0, lr_peak_epoch]
237 | ys = [1e-4 * lr, lr]
238 | return np.interp([epoch], xs, ys)[0]
239 | else:
240 | lr_min = 5e-6
241 | lr_t = lr_min + .5 * (lr - lr_min) * (1 + math.cos(math.pi * (
242 | epoch - lr_peak_epoch) / (epochs - lr_peak_epoch)))
243 | return lr_t
244 |
245 |
246 | class BlurPoolConv2d(ch.nn.Module):
247 | def __init__(self, conv):
248 | super().__init__()
249 | default_filter = ch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
250 | filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
251 | self.conv = conv
252 | self.register_buffer('blur_filter', filt)
253 |
254 | def forward(self, x):
255 | blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
256 | groups=self.conv.in_channels, bias=None)
257 | return self.conv.forward(blurred)
258 |
259 |
260 | class WrappedModel(nn.Module):
261 | """ include the generation of adversarial perturbation in the
262 | forward pass
263 | """
264 | def __init__(self, base_model, perturb, verbose=False):
265 | super().__init__()
266 | self.base_model = base_model
267 | self.perturb = perturb
268 | self.perturb_input = False
269 | self.verbose = verbose
270 | #self.mu = mu
271 | #self.sigma = sigma
272 |
273 | def forward(self, x, y=None):
274 | # TODO: handle varying threat models
275 | if self.perturb_input:
276 | assert not y is None
277 | #print(x.is_contiguous())
278 | # use eval mode during attack
279 | self.base_model.eval()
280 | if self.verbose:
281 | print('perturb input')
282 | startt = time.time()
283 | z = self.perturb(self.base_model, x, y)
284 |
285 | if self.verbose:
286 | inftime = time.time() - startt
287 | print(f'inference time={inftime:.5f}')
288 | #print(z[0].is_contiguous())
289 | self.base_model.train()
290 |
291 | if isinstance(z, (tuple, list)):
292 | z = z[0]
293 | return self.base_model(z)
294 |
295 | else:
296 | if self.verbose:
297 | print('clean inference')
298 | return self.base_model(x)
299 |
300 | def set_perturb(self, mode):
301 | self.perturb_input = mode
302 |
303 |
304 |
305 | def freeze_some_layers(model, early):
306 |
307 | if bool(early):
308 | for name, child in model.named_children():
309 | for namm, pamm in child.named_parameters():
310 | if 'stem' in namm:
311 | print(namm + ' is unfrozen')
312 | pamm.requires_grad = True
313 | else:
314 | print(namm + ' is frozen')
315 | pamm.requires_grad = False
316 | else:
317 | for name, child in model.named_children():
318 | for namm, pamm in child.named_parameters():
319 | if 'stem' in namm:
320 | print(namm + ' is unfrozen')
321 | pamm.requires_grad = False
322 | else:
323 | print(namm + ' is frozen')
324 | pamm.requires_grad = True
325 |
326 |
327 |
328 | class ImageNetTrainer:
329 | @param('training.distributed')
330 | @param('training.eval_only')
331 | def __init__(self, gpu, distributed, eval_only):
332 | self.all_params = get_current_config()
333 | self.gpu = gpu
334 | self.best_rob_acc = 0.
335 | self.uid = str(uuid4())
336 |
337 | if distributed:
338 | self.setup_distributed()
339 |
340 | if not eval_only:
341 | self.train_loader, self.val_loader, self.mixup_fn = self.create_train_loader()
342 | # self.val_loader = self.create_val_loader()
343 | self.model, self.scaler = self.create_model_and_scaler()
344 | self.create_optimizer()
345 | self.initialize_logger()
346 |
347 |
348 | @param('dist.address')
349 | @param('dist.port')
350 | @param('dist.world_size')
351 | def setup_distributed(self, address, port, world_size):
352 | os.environ['MASTER_ADDR'] = address
353 | os.environ['MASTER_PORT'] = port
354 |
355 | dist.init_process_group("nccl", rank=self.gpu, world_size=world_size)
356 | ch.cuda.set_device(self.gpu)
357 |
358 | def cleanup_distributed(self):
359 | dist.destroy_process_group()
360 |
361 | @param('lr.lr_schedule_type')
362 | def get_lr(self, epoch, lr_schedule_type):
363 | lr_schedules = {
364 | 'cyclic': get_cyclic_lr,
365 | 'step': get_step_lr,
366 | 'cosine': get_cosine_lr,
367 | }
368 |
369 | return lr_schedules[lr_schedule_type](epoch)
370 |
371 | # resolution tools
372 | @param('resolution.min_res')
373 | @param('resolution.max_res')
374 | @param('resolution.end_ramp')
375 | @param('resolution.start_ramp')
376 | def get_resolution(self, epoch, min_res, max_res, end_ramp, start_ramp):
377 | assert min_res <= max_res
378 |
379 | if epoch <= start_ramp:
380 | return min_res
381 |
382 | if epoch >= end_ramp:
383 | return max_res
384 |
385 | # otherwise, linearly interpolate to the nearest multiple of 32
386 | interp = np.interp([epoch], [start_ramp, end_ramp], [min_res, max_res])
387 | final_res = int(np.round(interp[0] / 32)) * 32
388 | return final_res
389 |
390 | @param('training.momentum')
391 | @param('training.optimizer')
392 | @param('training.weight_decay')
393 | @param('training.label_smoothing')
394 | @param('model.arch')
395 | def create_optimizer(self, momentum, optimizer, weight_decay,
396 | label_smoothing, arch):
397 | #assert optimizer == 'sgd'
398 |
399 | # Only do weight decay on non-batchnorm parameters
400 | if 'convnext' in arch or 'resnet' in arch:
401 | print('manually excluding parameters for weight decay')
402 | all_params = list(self.model.named_parameters())
403 | excluded_params = ['bn', '.bias'] #'.norm', '.bias'
404 | if arch in ['timm_convnext_tiny_batchnorm', 'timm_convnext_tiny_batchnorm_relu']:
405 | # timm convnext uses different naming than resnet
406 | excluded_params.append('.norm.')
407 | if arch in ['timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln',
408 | 'timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln_ds-sep',
409 | 'timm_resnet50_dw_patch-stem_gelu_stages-3393_convnext-bn_fewer-act-norm_ln_ds-sep_bias',
410 | 'timm_reimplemented_convnext_tiny']:
411 | # in case LN is used instead of original BN and the naming is not changed
412 | excluded_params.remove('bn')
413 | print('excluded params=', ', '.join(excluded_params))
414 |
415 | bn_params = [v for k, v in all_params if any([c in k for c in excluded_params])] #('bn' in k) #or k.endswith('.bias')
416 | bn_keys = [k for k, v in all_params if any([c in k for c in excluded_params])]
417 | #print(', '.join(bn_keys))
418 | #sys.exit()
419 | other_params = [v for k, v in all_params if not any([c in k for c in excluded_params])] #not ('bn' in k) #or k.endswith('.bias')
420 | # se_only = True
421 | # elif se_only:
422 | # other_params = []
423 | # l = 0
424 | # for name, param in self.model.named_parameters():
425 | # # print(name)
426 | # if "se_module" not in name:
427 | # # other_params.append(param)
428 | # param.requires_grad = False
429 | # l+=1
430 | # print(l)
431 | # exit()
432 | else:
433 | print('automatically exclude bn and bias from weight decay')
434 | bn_params = []
435 | bn_keys = []
436 | other_params = []
437 | for name, param in self.model.named_parameters():
438 | if not param.requires_grad:
439 | continue
440 | if param.ndim <= 1 or name.endswith(".bias"): #or name in no_weight_decay_list
441 | bn_keys.append(name)
442 | bn_params.append(param)
443 | else:
444 | other_params.append(param)
445 | #print(', '.join(bn_keys))
446 |
447 | param_groups = [{
448 | 'params': bn_params,
449 | 'weight_decay': 0.
450 | }, {
451 | 'params': other_params,
452 | 'weight_decay': weight_decay
453 | }]
454 |
455 |
456 | if optimizer == 'sgd':
457 | self.optimizer = ch.optim.SGD(param_groups, lr=1, momentum=momentum)
458 | else:
459 | self.optimizer = ch.optim.AdamW(param_groups, betas=(0.9, 0.95))
460 |
461 | if self.mixup_fn is None:
462 | self.loss = ch.nn.CrossEntropyLoss()
463 | else:
464 | # # smoothing is handled with mixup label transform
465 | # self.loss = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
466 | self.loss = SoftTargetCrossEntropy()
467 |
468 | @param('data.train_dataset')
469 | @param('data.num_workers')
470 | @param('training.batch_size')
471 | @param('training.distributed')
472 | @param('training.label_smoothing')
473 | @param('data.in_memory')
474 | @param('data.seed')
475 | @param('data.augmentations')
476 | @param('training.precision')
477 | @param('misc.use_channel_last')
478 | @param('dist.world_size')
479 | def create_train_loader(self, train_dataset, num_workers, batch_size,
480 | distributed, label_smoothing, in_memory, seed, augmentations, precision,
481 | use_channel_last, world_size):
482 | torch.manual_seed(seed)
483 | if False:
484 | this_device = f'cuda:{self.gpu}'
485 | print(this_device)
486 | # train_path = Path(train_dataset)
487 | data_paths = ['/scratch/fcroce42/ffcv_imagenet_data/train_400_0.50_90.ffcv',
488 | '/scratch/nsingh/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv', '/scratch_local/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv']
489 | for data_path in data_paths:
490 | if os.path.exists(data_path):
491 | train_path = Path(data_path)
492 | break
493 | print(train_path)
494 | assert train_path.is_file()
495 |
496 | res = self.get_resolution(epoch=0)
497 | prec = PREC_DICT[precision]
498 | self.decoder = RandomResizedCropRGBImageDecoder((res, res))
499 | if use_channel_last:
500 | image_pipeline: List[Operation] = [
501 | self.decoder,
502 | RandomHorizontalFlip(),
503 | #Convert(np.float16),
504 | ToTensor(),
505 | #lambda x: x.contiguous(),
506 | ToDevice(ch.device(this_device), non_blocking=True),
507 | ToTorchImage(channels_last=True),
508 | NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD,
509 | prec, #np.float16
510 | )
511 | ]
512 | else:
513 | image_pipeline: List[Operation] = [
514 | self.decoder,
515 | RandomHorizontalFlip(),
516 | #Convert(np.float16),
517 | ToTensor(),
518 | #lambda x: x.contiguous(),
519 | ToDevice(ch.device(this_device), non_blocking=True),
520 | ToTorchImage(channels_last=False),
521 | #NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD,
522 | # prec, #np.float16
523 | # )
524 | Convert(ch.cuda.HalfTensor), #float16
525 | torchvision.transforms.Normalize([0., 0., 0.], [255., 255., 255.]),
526 | ]
527 |
528 | label_pipeline: List[Operation] = [
529 | IntDecoder(),
530 | ToTensor(),
531 | Squeeze(),
532 | ToDevice(ch.device(this_device), non_blocking=True)
533 | ]
534 |
535 | order = OrderOption.RANDOM if distributed else OrderOption.QUASI_RANDOM
536 | loader = Loader(train_dataset,
537 | batch_size=batch_size,
538 | num_workers=num_workers,
539 | order=order,
540 | os_cache=in_memory,
541 | drop_last=True,
542 | pipelines={
543 | 'image': image_pipeline,
544 | 'label': label_pipeline
545 | },
546 | distributed=distributed,
547 | seed=seed)
548 |
549 | else:
550 |
551 | if augmentations:
552 | args = parserr.Arguments_augment()
553 |
554 | else:
555 | args = parserr.Arguments_No_augment()
556 |
557 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
558 | if False:
559 | args.dist_eval = False
560 | dataset_val = None
561 | else:
562 | dataset_val, _ = build_dataset(is_train=False, args=args)
563 |
564 | num_tasks = world_size
565 | global_rank = self.gpu #utils.get_rank()
566 |
567 | sampler_train = torch.utils.data.DistributedSampler(
568 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=seed,
569 | )
570 | print("Sampler_train = %s" % str(sampler_train))
571 | if args.dist_eval:
572 | if len(dataset_val) % num_tasks != 0:
573 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
574 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
575 | 'equal num of samples per-process.')
576 | sampler_val = torch.utils.data.DistributedSampler(
577 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
578 | else:
579 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
580 |
581 | data_loader_train = torch.utils.data.DataLoader(
582 | dataset_train, sampler=sampler_train,
583 | batch_size=batch_size,
584 | num_workers=num_workers,
585 | pin_memory=True,
586 | drop_last=True,
587 | )
588 | if dataset_val is not None:
589 | data_loader_val = torch.utils.data.DataLoader(
590 | dataset_val, sampler=sampler_val,
591 | batch_size=int(1.5 * batch_size),
592 | num_workers=num_workers,
593 | pin_memory=True,
594 | drop_last=False
595 | )
596 | else:
597 | data_loader_val = None
598 |
599 | mixup_fn = None
600 | mixup_active = (args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None) and augmentations
601 | if mixup_active:
602 | print("Mixup is activated!")
603 |                 print(f"Using label smoothing: {label_smoothing}")
604 | mixup_fn = Mixup(
605 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
606 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
607 | label_smoothing=label_smoothing, num_classes=args.nb_classes)
608 |
609 | # assert not distributed
610 | # self.decoder = None #RandomResizedCropRGBImageDecoder((res, res))
611 |
612 | # from robustness.datasets import DATASETS
613 | # from robustness.tools import helpers
614 | # data_paths = ['/home/scratch/datasets/imagenet',
615 | # '/scratch_local/datasets/ImageNet2012',
616 | # '/mnt/qb/datasets/ImageNet2012',
617 | # '/scratch/datasets/imagenet/']
618 | # for data_path in data_paths:
619 | # if os.path.exists(data_path):
620 | # break
621 | # print(f'found dataset at {data_path}')
622 | # dataset = DATASETS['imagenet'](data_path) #'/home/scratch/datasets/imagenet'
623 |
624 |
625 | # train_loader, val_loader = dataset.make_loaders(num_workers,
626 | # batch_size, data_aug=True)
627 |
628 | # loader = helpers.DataPrefetcher(train_loader)
629 | # #val_loader = helpers.DataPrefetcher(val_loader)
630 |
631 |
632 |
633 | return data_loader_train, data_loader_val, mixup_fn
634 |
635 | @param('data.val_dataset')
636 | @param('data.num_workers')
637 | @param('validation.batch_size')
638 | @param('validation.resolution')
639 | @param('validation.precision')
640 | @param('training.distributed')
641 | @param('misc.use_channel_last')
642 | def create_val_loader(self, val_dataset, num_workers, batch_size,
643 | resolution, precision, distributed, use_channel_last
644 | ):
645 | this_device = f'cuda:{self.gpu}'
646 | # val_path = Path(val_dataset)
647 | data_paths = ['/scratch/fcroce42/ffcv_imagenet_data/val_400_0.50_90.ffcv',
648 | '/scratch/nsingh/datasets/ffcv_imagenet_data/val_400_0.50_90.ffcv', '/scratch_local/datasets/ffcv_imagenet_data/train_400_0.50_90.ffcv']
649 | for data_path in data_paths:
650 | if os.path.exists(data_path):
651 | val_path = Path(data_path)
652 | break
653 | assert val_path.is_file()
654 | res_tuple = (resolution, resolution)
655 | prec = PREC_DICT[precision]
656 | cropper = CenterCropRGBImageDecoder(res_tuple, ratio=DEFAULT_CROP_RATIO)
657 | if use_channel_last:
658 | image_pipeline = [
659 | cropper,
660 | ToTensor(),
661 | ToDevice(ch.device(this_device), non_blocking=True),
662 | ToTorchImage(),
663 | NormalizeImage(NONORM_MEAN, NONORM_STD, #IMAGENET_MEAN, IMAGENET_STD
664 | prec)
665 | ]
666 | else:
667 | image_pipeline = [
668 | cropper,
669 | ToTensor(),
670 | ToDevice(ch.device(this_device), non_blocking=True),
671 | ToTorchImage(channels_last=False),
672 | Convert(ch.cuda.FloatTensor),
673 | torchvision.transforms.Normalize([0., 0., 0.], [255., 255., 255.]),
674 | ]
675 |
676 | label_pipeline = [
677 | IntDecoder(),
678 | ToTensor(),
679 | Squeeze(),
680 | ToDevice(ch.device(this_device),
681 | non_blocking=True)
682 | ]
683 |
684 | loader = Loader(val_dataset,
685 | batch_size=batch_size,
686 | num_workers=num_workers,
687 | order=OrderOption.SEQUENTIAL,
688 | drop_last=False,
689 | pipelines={
690 | 'image': image_pipeline,
691 | 'label': label_pipeline
692 | },
693 | distributed=distributed)
694 | return loader
695 |
696 | @param('training.epochs')
697 | @param('logging.log_level')
698 | @param('logging.save_freq')
699 | @param('model.ckpt_path')
700 | @param('adv.attack')
701 |
702 | def train(self, epochs, log_level, save_freq, ckpt_path, attack):
703 | vall, nums = self.single_val()
704 | if log_level > 0:
705 | val_dict = {
706 | 'Validation acc': vall.item(),
707 | 'points': nums
708 | }
709 | if self.gpu == 0:
710 | self.log(val_dict)
711 |
712 | for epoch in range(epochs):
713 | #print(f'epoch {epoch}')
714 | res = self.get_resolution(epoch)
715 | try:
716 | self.decoder.output_size = (res, res)
717 | except:
718 | pass
719 | train_loss = self.train_loop(epoch)
720 |
721 | if log_level > 0:
722 | extra_dict = {
723 | 'train_loss': train_loss.item(),
724 | 'epoch': epoch
725 | }
726 |
727 | self.eval_and_log(extra_dict)
728 |
729 | if train_loss.isnan():
730 | sys.exit()
731 |
732 |             # if attack == 'none':
733 |             ##### override the configured save frequency: save a checkpoint every epoch
734 |             save_freq = 1
735 |
736 | self.eval_and_log({'epoch': epoch})
737 | if (self.gpu == 0 and epoch % save_freq == 0) or (self.gpu == 0 and epoch == epochs - 1):
738 | ch.save(self.model.state_dict(), self.log_folder / f'weights_{epoch}.pt')
739 | if self.model_ema is not None:
740 | ch.save(timm.utils.model.get_state_dict(self.model_ema), self.log_folder / f'weights_ema_{epoch}.pt')
741 | if epoch % 5 == 0 or epoch == epochs-1:
742 | ch.save({
743 | 'model_state_dict': self.model.state_dict(),
744 | 'optimizer_state_dict': self.optimizer.state_dict(),
745 | 'loss_scaler_state_dict': self.scaler.state_dict(),
746 | 'epoch': epoch,
747 | 'state_dict_ema':timm.utils.model.get_state_dict(self.model_ema)
748 | }, self.log_folder / f'full_model_{epoch}.pth')
749 | else:
750 | if epoch % 5 == 0 or epoch == epochs-1:
751 | ch.save({
752 | 'model_state_dict': self.model.state_dict(),
753 | 'optimizer_state_dict': self.optimizer.state_dict(),
754 | 'loss_scaler_state_dict': self.scaler.state_dict(),
755 | 'epoch': epoch,
756 | }, self.log_folder / f'full_model_{epoch}.pth')
757 |
758 | def eval_and_log(self, extra_dict={}):
759 | start_val = time.time()
760 | stats = 0 #self.val_loop()
761 | val_time = time.time() - start_val
762 | if self.gpu == 0:
763 | self.log(dict({
764 | 'current_lr': self.optimizer.param_groups[0]['lr'],
765 | 'top_1': stats,
766 | 'top_5': stats,
767 | 'val_time': 0
768 | }, **extra_dict))
769 |
770 | return stats
771 |
772 |
773 |
774 | @param('model.arch')
775 | @param('model.pretrained')
776 | @param('model.not_original')
777 | @param('model.updated')
778 | @param('model.model_ema')
779 | @param('model.freeze_some')
780 | @param('model.early')
781 | @param('training.distributed')
782 | @param('training.use_blurpool')
783 | @param('model.ckpt_path')
784 | @param('model.add_normalization')
785 | @param('adv.attack')
786 | @param('adv.norm')
787 | @param('adv.eps')
788 | @param('adv.n_iter')
789 | @param('adv.verbose')
790 | @param('misc.use_channel_last')
791 | @param('adv.alpha')
792 | @param('adv.noise_level')
793 | @param('adv.skip_projection')
794 | def create_model_and_scaler(self, arch, pretrained, not_original, updated, model_ema, freeze_some, early, distributed, use_blurpool,
795 | ckpt_path, add_normalization, attack, norm, eps, n_iter, verbose,
796 | use_channel_last, alpha, noise_level, skip_projection):
797 | scaler = GradScaler()
798 | if not arch.startswith('timm_'):
799 | model = get_new_model(arch, pretrained=bool(pretrained), not_original=bool(not_original), updated=bool(updated))
800 | else:
801 | try:
802 | model = create_model(arch.replace('timm_', ''), pretrained=pretrained)
803 | #model.drop_path_rate = .1
804 | except:
805 | model = get_new_model(arch.replace('timm_', ''))
806 | verbose = verbose == 1
807 |
808 | def apply_blurpool(mod: ch.nn.Module):
809 | for (name, child) in mod.named_children():
810 | if isinstance(child, ch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
811 | setattr(mod, name, BlurPoolConv2d(child))
812 | else: apply_blurpool(child)
813 | if use_blurpool: apply_blurpool(model)
814 |
815 | if use_channel_last:
816 | print('using channel last memory format')
817 | model = model.to(memory_format=ch.channels_last)
818 | else:
819 | print('not using channel last memory format')
820 |
821 | if bool(freeze_some):
822 | print(f"Freezing early layers: {bool(early)}")
823 | freeze_some_layers(model, early)
824 |
825 |
826 | if arch != 'convnext_tiny_21k' and add_normalization:
827 | print('add normalization layer')
828 | model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD)
829 |
830 |
831 | if attack in ['apgd', 'fgsm']:
832 | print('using input perturbation layer')
833 | if attack == 'apgd':
834 | attack = partial(apgd_train, norm=norm, eps=eps,
835 | n_iter=n_iter, verbose=verbose, mixup=self.mixup_fn)
836 | elif attack == 'fgsm':
837 | attack = partial(fgsm_train, eps=eps,
838 | use_rs=True,
839 | alpha=alpha,
840 | noise_level=noise_level,
841 | skip_projection=skip_projection == 1
842 | )
843 | print(attack)
844 | model = WrappedModel(model, attack, verbose=verbose)
845 |
846 | if self.gpu == 0:
847 | print(model)
848 | inpp = torch.rand(1, 3, 224, 224)
849 | flops = FlopCountAnalysis(model, inpp)
850 | val = flops.total()
851 | print(val)
852 | print(sizeof_fmt(int(val)))
853 | print(flop_count_table(flops, max_depth=2))
854 | print(flops.by_operator())
855 |
856 | if not ckpt_path == '':
857 | ckpt = ch.load(ckpt_path, map_location='cpu')
858 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
859 | try:
860 | model.load_state_dict(ckpt)
861 | print('standard loading')
862 |
863 | except:
864 | try:
865 | ckpt = {f'base_model.{k}': v for k, v in ckpt.items()}
866 | model.load_state_dict(ckpt)
867 | print('loaded from clean model')
868 | except:
869 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()}
870 | # ckpt = {f'base_model.{k}': v for k, v in ckpt.items()}
871 | model.load_state_dict(ckpt)
872 | print('loaded')
873 | #model = model.to(memory_format=ch.channels_last)
874 |
875 | # print(model.patch_embed(torch.rand((50, 3, 224, 224))))
876 | # exit()
877 | # if arch != 'convnext_tiny_21k' and add_normalization:
878 | # print('add normalization layer')
879 | # model = normalize_model(model, IMAGENET_MEAN, IMAGENET_STD)
880 |
881 | model = model.to(self.gpu)
882 | if bool(model_ema):
883 | print('Using EMA with decay 0.9999')
884 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
885 | self.model_ema = timm.utils.ModelEmaV2(model, decay=0.9999, device='cpu')
886 | else:
887 | self.model_ema = None
888 |
889 | if distributed:
890 | model = ch.nn.parallel.DistributedDataParallel(model, device_ids=[self.gpu]) #, find_unused_parameters=True)
891 |
892 | return model, scaler
893 |
894 | @param('validation.lr_tta')
895 | @param('adv.attack')
896 | @param('dist.world_size')
897 | def single_val(self, lr_tta, attack, world_size):
898 | model = self.model
899 | model.eval()
900 | show_once = True
901 | acc = 0.
902 | accs = []
903 | n = 0.
904 | ns = []
905 | best_test_rob = 0.
906 |
907 |
908 | # with ch.no_grad():
909 | with autocast(enabled=True):
910 | for idx, (images, target) in enumerate(tqdm(self.val_loader)):
911 | # if show_once:
912 | # print(images.shape, images.max(), images.min())
913 | # show_once = False
914 |
915 | images = images.contiguous().cuda(self.gpu, non_blocking=True)
916 | target = target.contiguous().cuda(self.gpu, non_blocking=True)
917 | output = self.model(images)
918 | if lr_tta:
919 | output += self.model(ch.flip(images, dims=[3]))
920 |
921 | # for k in ['top_1', 'top_5']:
922 | # self.val_meters[k](output, target)
923 |
924 | acc += (output.max(1)[1] == target).sum()
925 | n += target.shape[0]
926 | # loss_val = self.loss(output, target) #####. remove this comment
927 | # self.val_meters['loss'](loss_val)
928 | if idx >= 200:
929 | break
930 | accs.append(acc)
931 | ns = n*world_size
932 | print(f'clean accuracy={acc / n:.2%}')
933 | # stats = {k: m.compute().item() for k, m in self.val_meters.items()}
934 | # if stats['top_1'] > self.best_rob_acc:
935 | # self.best_rob_acc = stats['top_1']
936 | # if self.gpu == 0:
937 | # ch.save(self.model.state_dict(), self.log_folder / 'best_adv_weights.pt')
938 | # [meter.reset() for meter in self.val_meters.values()]
939 | return ch.stack(accs)/ns, ns
940 |
941 | @param('logging.log_level')
942 | @param('adv.attack')
943 | @param('training.distributed')
944 | def train_loop(self, epoch, log_level, attack, distributed):
945 | model = self.model
946 | model.train()
947 | losses = []
948 | show_once = True
949 | perturb = attack != 'none'
950 | if perturb:
951 | if distributed:
952 | model.module.set_perturb(True)
953 | else:
954 | model.set_perturb(True)
955 |
956 | lr_start, lr_end = self.get_lr(epoch), self.get_lr(epoch + 1)
957 | iters = len(self.train_loader)
958 | lrs = np.interp(np.arange(iters), [0, iters], [lr_start, lr_end])
959 |
960 | iterator = tqdm(self.train_loader)
961 | for ix, (images, target) in enumerate(iterator):
962 | images = images.cuda(self.gpu, non_blocking=True)
963 | target = target.cuda(self.gpu, non_blocking=True)
964 | # print(images.size(), target.size())
965 | if self.mixup_fn is not None:
966 | images, target = self.mixup_fn(images, target)
967 |
968 | if show_once:
969 | # print(images.shape, images.max(), images.min())
970 | show_once = False
971 |
972 | ### Training start
973 | for param_group in self.optimizer.param_groups:
974 | param_group['lr'] = lrs[ix]
975 |
976 | '''print(images.device)
977 | images = images.reshape(images.shape) # make contiguous (and more)
978 | if False:
979 | ch.save(images, './train_imgs_cm.pth')
980 | sys.exit()
981 | target = target.reshape(target.shape)
982 | print(images.device)'''
983 |
984 | self.optimizer.zero_grad(set_to_none=True)
985 | with autocast(enabled=True):
986 | if not perturb:
987 | output = self.model(images)
988 | else:
989 | output = self.model(images, target) # TODO: check the effect of .contiguous() for other models
990 | loss_train = self.loss(output, target)
991 |
992 | self.scaler.scale(loss_train).backward()
993 | self.scaler.step(self.optimizer)
994 | self.scaler.update()
995 | ### Training end
996 | if self.model_ema is not None:
997 | self.model_ema.update(model)
998 |
999 | #ch.cuda.synchronize()
1000 |
1001 | ### Logging start
1002 | if log_level > 0:
1003 | losses.append(loss_train.detach())
1004 |
1005 | group_lrs = []
1006 | for _, group in enumerate(self.optimizer.param_groups):
1007 | group_lrs.append(f'{group["lr"]:.3f}')
1008 |
1009 | names = ['ep', 'iter', 'shape', 'lrs']
1010 | values = [epoch, ix, tuple(images.shape), group_lrs]
1011 | if log_level > 1:
1012 | names += ['loss']
1013 | values += [f'{loss_train.item():.3f}']
1014 |
1015 | msg = ', '.join(f'{n}={v}' for n, v in zip(names, values))
1016 | iterator.set_description(msg)
1017 | ### Logging end
1018 |
1019 |
1020 | if perturb:
1021 | if distributed:
1022 | model.module.set_perturb(False)
1023 | else:
1024 | model.set_perturb(False)
1025 |
1026 | return ch.stack(losses).mean()
1027 |
1028 | @param('validation.lr_tta')
1029 | @param('adv.attack')
1030 | def val_loop(self, lr_tta, attack):
1031 |
1032 | model = self.model
1033 | model.eval()
1034 | show_once = True
1035 | acc = 0.
1036 | best_test_rob = 0
1037 | # with ch.no_grad():
1038 | with autocast(enabled=True):
1039 | for idx, (images, target) in enumerate(tqdm(self.val_loader)):
1040 | # if show_once:
1041 | # print(images.shape, images.max(), images.min())
1042 | # show_once = False
1043 |
1044 | images = images.contiguous()
1045 | target = target.contiguous()
1046 | # if attack != 'none':
1047 | # x_adv = fgsm_attack(model, images, target, eps=4./255.)
1048 | # output = self.model(x_adv)
1049 | # if lr_tta:
1050 | # output += self.model(ch.flip(x_adv, dims=[3]))
1051 | # else:
1052 | output = self.model(images)
1053 | if lr_tta:
1054 | output += self.model(ch.flip(images, dims=[3]))
1055 | # if lr_tta:
1056 | # output += self.model(ch.flip(x_adv, dims=[3]))
1057 | for k in ['top_1', 'top_5']:
1058 | self.val_meters[k](output, target)
1059 |
1060 | acc += (output.max(1)[1] == target).sum()
1061 |
1062 | loss_val = self.loss(output, target)
1063 | self.val_meters['loss'](loss_val)
1064 | if idx >= 50:
1065 | break
1066 | #print(f'clean accuracy={acc / 50000:.2%}')
1067 | stats = {k: m.compute().item() for k, m in self.val_meters.items()}
1068 |
1069 | if stats['top_1'] > self.best_rob_acc:
1070 | self.best_rob_acc = stats['top_1']
1071 | if self.gpu == 0:
1072 | ch.save(self.model.state_dict(), self.log_folder / 'best_adv_weights.pt')
1073 | [meter.reset() for meter in self.val_meters.values()]
1074 | return stats
1075 |
1076 | @param('logging.folder')
1077 | @param('model.arch')
1078 | @param('adv.attack')
1079 | @param('model.updated')
1080 | @param('model.not_original')
1081 | @param('logging.addendum')
1082 | @param('data.augmentations')
1083 | @param('model.pretrained')
1084 | def initialize_logger(self, folder, arch, attack, updated, not_original, addendum, augmentations, pretrained):
1085 | self.val_meters = {
1086 | 'top_1': torchmetrics.Accuracy(compute_on_step=False).to(self.gpu),
1087 | 'top_5': torchmetrics.Accuracy(compute_on_step=False, top_k=5).to(self.gpu),
1088 | 'loss': MeanScalarMetric(compute_on_step=False).to(self.gpu)
1089 | }
1090 |
1091 | if self.gpu == 0:
1092 | #folder = (Path(folder) / str(self.uid)).absolute()
1093 | runname = f'model_{str(datetime.now())[:-7]}_{arch}_upd_{updated}_not_orig_{not_original}_pre_{pretrained}_aug_{augmentations}'
1094 | if attack != 'none':
1095 | runname += f'_adv_{addendum}'
1096 | else:
1097 | runname += f'_clean_{addendum}'
1098 | folder = (Path(folder) / runname).absolute()
1099 | folder.mkdir(parents=True)
1100 |
1101 | self.log_folder = folder
1102 | self.start_time = time.time()
1103 |
1104 | print(f'=> Logging in {self.log_folder}')
1105 | params = {
1106 | '.'.join(k): self.all_params[k] for k in self.all_params.entries.keys()
1107 | }
1108 | with open(folder / 'params.json', 'w+') as handle:
1109 | json.dump(params, handle)
1110 |
1111 | def log(self, content):
1112 | print(f'=> Log: {content}')
1113 | if self.gpu != 0: return
1114 | cur_time = time.time()
1115 | try:
1116 | with open(self.log_folder / 'log', 'a+') as fd:
1117 | fd.write(json.dumps({
1118 | 'timestamp': cur_time,
1119 | 'relative_time': cur_time - self.start_time,
1120 | **content
1121 | }) + '\n')
1122 | fd.flush()
1123 | except:
1124 | with open(self.log_folder / 'log', 'a+') as fd:
1125 | fd.write(content + '\n')
1126 | fd.flush()
1127 |
1128 | @classmethod
1129 | @param('training.distributed')
1130 | @param('dist.world_size')
1131 | def launch_from_args(cls, distributed, world_size):
1132 | if distributed:
1133 | ch.multiprocessing.spawn(cls._exec_wrapper, nprocs=world_size, join=True)
1134 | else:
1135 | cls.exec(0)
1136 |
1137 | @classmethod
1138 | def _exec_wrapper(cls, *args, **kwargs):
1139 | make_config(quiet=True)
1140 | cls.exec(*args, **kwargs)
1141 |
1142 | @classmethod
1143 | @param('training.distributed')
1144 | @param('training.eval_only')
1145 | def exec(cls, gpu, distributed, eval_only):
1146 | trainer = cls(gpu=gpu)
1147 | if eval_only:
1148 | trainer.eval_and_log()
1149 | else:
1150 | trainer.train()
1151 | if distributed:
1152 | trainer.cleanup_distributed()
1153 |
1154 | # Utils
1155 | class MeanScalarMetric(torchmetrics.Metric):
1156 | def __init__(self, *args, **kwargs):
1157 | super().__init__(*args, **kwargs)
1158 |
1159 | self.add_state('sum', default=ch.tensor(0.), dist_reduce_fx='sum')
1160 | self.add_state('count', default=ch.tensor(0), dist_reduce_fx='sum')
1161 |
1162 | def update(self, sample: ch.Tensor):
1163 | self.sum += sample.sum()
1164 | self.count += sample.numel()
1165 |
1166 | def compute(self):
1167 | return self.sum.float() / self.count
1168 |
1169 | # Running
1170 | def make_config(quiet=False):
1171 | config = get_current_config()
1172 | parser = ArgumentParser(description='Fast imagenet training')
1173 | config.augment_argparse(parser)
1174 | config.collect_argparse_args(parser)
1175 | config.validate(mode='stderr')
1176 | if not quiet:
1177 | config.summary()
1178 |
1179 | if __name__ == "__main__":
1180 | make_config()
1181 | ImageNetTrainer.launch_from_args()
1182 |
--------------------------------------------------------------------------------
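A minimal, standalone sketch of the per-iteration learning-rate schedule used in train_loop above: the learning rate is linearly interpolated between get_lr(epoch) and get_lr(epoch + 1), one value per training iteration. The numbers below are illustrative, not the repository's actual schedule.

import numpy as np

lr_start, lr_end = 0.1, 0.08           # stands in for self.get_lr(epoch), self.get_lr(epoch + 1)
iters = 5                              # stands in for len(self.train_loader)
lrs = np.interp(np.arange(iters), [0, iters], [lr_start, lr_end])
print(lrs)                             # [0.1   0.096 0.092 0.088 0.084]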
/models/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.7.1"
2 | from .EffNet import EfficientNet, VALID_MODELS
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
--------------------------------------------------------------------------------
/models/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from timm.models.layers import trunc_normal_, DropPath
13 | from timm.models.registry import register_model
14 |
15 | class Block(nn.Module):
16 | r""" ConvNeXt Block. There are two equivalent implementations:
17 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19 | We use (2) as we find it slightly faster in PyTorch
20 |
21 | Args:
22 | dim (int): Number of input channels.
23 | drop_path (float): Stochastic depth rate. Default: 0.0
24 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25 | """
26 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27 | super().__init__()
28 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29 | self.norm = LayerNorm(dim, eps=1e-6)
30 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31 | self.act = nn.GELU()
32 | self.pwconv2 = nn.Linear(4 * dim, dim)
33 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34 | requires_grad=True) if layer_scale_init_value > 0 else None
35 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36 |
37 | def forward(self, x):
38 | input = x
39 | x = self.dwconv(x)
40 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41 | x = self.norm(x)
42 | x = self.pwconv1(x)
43 | x = self.act(x)
44 | x = self.pwconv2(x)
45 | if self.gamma is not None:
46 | x = self.gamma * x
47 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48 |
49 | x = input + self.drop_path(x)
50 | return x
51 |
52 | class ConvNeXt(nn.Module):
53 | r""" ConvNeXt
54 | A PyTorch impl of : `A ConvNet for the 2020s` -
55 | https://arxiv.org/pdf/2201.03545.pdf
56 |
57 | Args:
58 | in_chans (int): Number of input image channels. Default: 3
59 | num_classes (int): Number of classes for classification head. Default: 1000
60 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
61 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
62 | drop_path_rate (float): Stochastic depth rate. Default: 0.
63 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
64 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
65 | """
66 | def __init__(self, in_chans=3, num_classes=1000,
67 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
68 | layer_scale_init_value=1e-6, head_init_scale=1.,
69 | ):
70 | super().__init__()
71 |
72 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
73 | stem = nn.Sequential(
74 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
75 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
76 | )
77 | self.downsample_layers.append(stem)
78 | for i in range(3):
79 | downsample_layer = nn.Sequential(
80 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
81 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
82 | )
83 | self.downsample_layers.append(downsample_layer)
84 |
85 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
86 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
87 | cur = 0
88 | for i in range(4):
89 | stage = nn.Sequential(
90 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
91 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
92 | )
93 | self.stages.append(stage)
94 | cur += depths[i]
95 |
96 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
97 | self.head = nn.Linear(dims[-1], num_classes)
98 |
99 | self.apply(self._init_weights)
100 | self.head.weight.data.mul_(head_init_scale)
101 | self.head.bias.data.mul_(head_init_scale)
102 |
103 | def _init_weights(self, m):
104 | if isinstance(m, (nn.Conv2d, nn.Linear)):
105 | trunc_normal_(m.weight, std=.02)
106 | nn.init.constant_(m.bias, 0)
107 |
108 | def forward_features(self, x):
109 | for i in range(4):
110 | x = self.downsample_layers[i](x)
111 | x = self.stages[i](x)
112 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
113 |
114 | def forward(self, x):
115 | x = self.forward_features(x)
116 | x = self.head(x)
117 | return x
118 |
119 | class LayerNorm(nn.Module):
120 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
121 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
122 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
123 | with shape (batch_size, channels, height, width).
124 | """
125 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
126 | super().__init__()
127 | self.weight = nn.Parameter(torch.ones(normalized_shape))
128 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
129 | self.eps = eps
130 | self.data_format = data_format
131 | if self.data_format not in ["channels_last", "channels_first"]:
132 | raise NotImplementedError
133 | self.normalized_shape = (normalized_shape, )
134 |
135 | def forward(self, x):
136 | if self.data_format == "channels_last":
137 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
138 | elif self.data_format == "channels_first":
139 | u = x.mean(1, keepdim=True)
140 | s = (x - u).pow(2).mean(1, keepdim=True)
141 | x = (x - u) / torch.sqrt(s + self.eps)
142 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
143 | return x
144 |
145 |
146 | model_urls = {
147 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
148 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
149 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
150 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
151 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
152 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
153 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
154 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
155 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
156 | }
157 |
158 | @register_model
159 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
160 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
161 | if pretrained:
162 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
163 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
164 | model.load_state_dict(checkpoint["model"])
165 | return model
166 |
167 | @register_model
168 | def convnext_small(pretrained=False,in_22k=False, **kwargs):
169 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
170 | if pretrained:
171 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
172 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
173 | model.load_state_dict(checkpoint["model"])
174 | return model
175 |
176 | @register_model
177 | def convnext_base(pretrained=False, in_22k=False, **kwargs):
178 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
179 | if pretrained:
180 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
181 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
182 | model.load_state_dict(checkpoint["model"])
183 | return model
184 |
185 | @register_model
186 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
187 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
188 | if pretrained:
189 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
190 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
191 | model.load_state_dict(checkpoint["model"])
192 | return model
193 |
194 | @register_model
195 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
196 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
197 | if pretrained:
198 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
199 | url = model_urls['convnext_xlarge_22k']
200 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
201 | model.load_state_dict(checkpoint["model"])
202 | return model
--------------------------------------------------------------------------------
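A minimal usage sketch for the ConvNeXt variants defined above (hypothetical snippet, not part of the repository): build the Tiny model without pretrained weights and push a random batch through it.

import torch
from models.convnext import convnext_tiny

model = convnext_tiny(pretrained=False)    # depths [3, 3, 9, 3], dims [96, 192, 384, 768]
x = torch.randn(2, 3, 224, 224)            # (N, C, H, W)
logits = model(x)                          # stages -> global average pool -> LayerNorm -> linear head
print(logits.shape)                        # torch.Size([2, 1000])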
/models/convnext_iso.py:
--------------------------------------------------------------------------------
1 | '''Taken from the ConvNeXt GitHub repo as is.'''
2 |
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | # All rights reserved.
6 |
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 |
11 | from functools import partial
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from timm.models.layers import trunc_normal_, DropPath
16 | from timm.models.registry import register_model
17 | from models.convnext import Block, LayerNorm
18 |
19 | class ConvNeXtIsotropic(nn.Module):
20 | r""" ConvNeXt
21 | A PyTorch impl of : `A ConvNet for the 2020s` -
22 | https://arxiv.org/pdf/2201.03545.pdf
23 | Isotropic ConvNeXts (Section 3.3 in paper)
24 |
25 | Args:
26 | in_chans (int): Number of input image channels. Default: 3
27 | num_classes (int): Number of classes for classification head. Default: 1000
28 |         depth (int): Number of blocks. Default: 18.
29 |         dim (int): Feature dimension. Default: 384
30 | drop_path_rate (float): Stochastic depth rate. Default: 0.
31 | layer_scale_init_value (float): Init value for Layer Scale. Default: 0.
32 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
33 | """
34 | def __init__(self, in_chans=3, num_classes=1000,
35 | depth=18, dim=384, drop_path_rate=0.,
36 | layer_scale_init_value=0, head_init_scale=1.,
37 | ):
38 | super().__init__()
39 |
40 | self.stem = nn.Conv2d(in_chans, dim, kernel_size=16, stride=16)
41 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, depth)]
42 | self.blocks = nn.Sequential(*[Block(dim=dim, drop_path=dp_rates[i],
43 | layer_scale_init_value=layer_scale_init_value)
44 | for i in range(depth)])
45 |
46 | self.norm = LayerNorm(dim, eps=1e-6) # final norm layer
47 | self.head = nn.Linear(dim, num_classes)
48 |
49 | self.apply(self._init_weights)
50 | self.head.weight.data.mul_(head_init_scale)
51 | self.head.bias.data.mul_(head_init_scale)
52 |
53 | def _init_weights(self, m):
54 | if isinstance(m, (nn.Conv2d, nn.Linear)):
55 | trunc_normal_(m.weight, std=.02)
56 | nn.init.constant_(m.bias, 0)
57 |
58 | def forward_features(self, x):
59 | x = self.stem(x)
60 | x = self.blocks(x)
61 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
62 |
63 | def forward(self, x):
64 | x = self.forward_features(x)
65 | x = self.head(x)
66 | return x
67 |
68 | @register_model
69 | def convnext_isotropic_small(pretrained=False, dim=384, depth=18, **kwargs):
70 | model = ConvNeXtIsotropic(depth=depth, dim=dim, **kwargs)
71 | if pretrained:
72 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_small_1k_224_ema.pth'
73 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
74 | model.load_state_dict(checkpoint["model"])
75 | return model
76 |
77 | @register_model
78 | def convnext_isotropic_base(pretrained=False, **kwargs):
79 | model = ConvNeXtIsotropic(depth=18, dim=768, **kwargs)
80 | if pretrained:
81 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_base_1k_224_ema.pth'
82 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
83 | model.load_state_dict(checkpoint["model"])
84 | return model
85 |
86 | @register_model
87 | def convnext_isotropic_large(pretrained=False, **kwargs):
88 | model = ConvNeXtIsotropic(depth=36, dim=1024, **kwargs)
89 | if pretrained:
90 | url = 'https://dl.fbaipublicfiles.com/convnext/convnext_iso_large_1k_224_ema.pth'
91 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
92 | model.load_state_dict(checkpoint["model"])
93 | return model
--------------------------------------------------------------------------------
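The same kind of sketch for the isotropic variant (hypothetical, not part of the repository): the 16x16 patchify stem maps a 224x224 input to a 14x14 feature map, and every block keeps that single resolution and width.

import torch
from models.convnext_iso import convnext_isotropic_small

model = convnext_isotropic_small(pretrained=False)   # depth 18, dim 384
x = torch.randn(1, 3, 224, 224)
tokens = model.stem(x)                               # torch.Size([1, 384, 14, 14]), since 224 / 16 = 14
feats = model.forward_features(x)                    # pooled + normalized: torch.Size([1, 384])
print(tokens.shape, feats.shape)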
/models/utils.py:
--------------------------------------------------------------------------------
1 | """utils.py - Helper functions for building the model and for loading model parameters.
2 | These helper functions are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import re
10 | import math
11 | import collections
12 | from functools import partial
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.utils import model_zoo
17 |
18 |
19 | ################################################################################
20 | # Helper functions for model architecture
21 | ################################################################################
22 |
23 | # GlobalParams and BlockArgs: Two namedtuples
24 | # Swish and MemoryEfficientSwish: Two implementations of the Swish activation
25 | # round_filters and round_repeats:
26 | # Functions to calculate params for scaling model width and depth ! ! !
27 | # get_width_and_height_from_size and calculate_output_image_size
28 | # drop_connect: A structural design
29 | # get_same_padding_conv2d:
30 | # Conv2dDynamicSamePadding
31 | # Conv2dStaticSamePadding
32 | # get_same_padding_maxPool2d:
33 | # MaxPool2dDynamicSamePadding
34 | # MaxPool2dStaticSamePadding
35 | # It's an additional function, not used in EfficientNet,
36 | #     but can be used in other models (such as EfficientDet).
37 |
38 | # Parameters for the entire model (stem, all blocks, and head)
39 | GlobalParams = collections.namedtuple('GlobalParams', [
40 | 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
41 | 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
42 | 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
43 |
44 | # Parameters for an individual model block
45 | BlockArgs = collections.namedtuple('BlockArgs', [
46 | 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
47 | 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
48 |
49 | # Set GlobalParams and BlockArgs's defaults
50 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
51 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
52 |
53 | # Swish activation function
54 | if hasattr(nn, 'SiLU'):
55 | Swish = nn.SiLU
56 | else:
57 | # For compatibility with old PyTorch versions
58 | class Swish(nn.Module):
59 | def forward(self, x):
60 | return x * torch.sigmoid(x)
61 |
62 |
63 | # A memory-efficient implementation of Swish function
64 | class SwishImplementation(torch.autograd.Function):
65 | @staticmethod
66 | def forward(ctx, i):
67 | result = i * torch.sigmoid(i)
68 | ctx.save_for_backward(i)
69 | return result
70 |
71 | @staticmethod
72 | def backward(ctx, grad_output):
73 | i = ctx.saved_tensors[0]
74 | sigmoid_i = torch.sigmoid(i)
75 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
76 |
77 |
78 | class MemoryEfficientSwish(nn.Module):
79 | def forward(self, x):
80 | return SwishImplementation.apply(x)
81 |
82 |
83 | def round_filters(filters, global_params):
84 | """Calculate and round number of filters based on width multiplier.
85 | Use width_coefficient, depth_divisor and min_depth of global_params.
86 |
87 | Args:
88 |         filters (int): Number of filters to be scaled.
89 | global_params (namedtuple): Global params of the model.
90 |
91 | Returns:
92 |         new_filters: New number of filters after scaling and rounding.
93 | """
94 | multiplier = global_params.width_coefficient
95 | if not multiplier:
96 | return filters
97 | # TODO: modify the params names.
98 | # maybe the names (width_divisor,min_width)
99 | # are more suitable than (depth_divisor,min_depth).
100 | divisor = global_params.depth_divisor
101 | min_depth = global_params.min_depth
102 | filters *= multiplier
103 | min_depth = min_depth or divisor # pay attention to this line when using min_depth
104 | # follow the formula transferred from official TensorFlow implementation
105 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
106 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
107 | new_filters += divisor
108 | return int(new_filters)
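# Worked example for round_filters: with width_coefficient=1.4 and depth_divisor=8,
# 32 filters become 32 * 1.4 = 44.8 -> int(44.8 + 4) // 8 * 8 = 48; since 48 >= 0.9 * 44.8,
# no extra divisor is added and 48 filters are used.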
109 |
110 |
111 | def round_repeats(repeats, global_params):
112 | """Calculate module's repeat number of a block based on depth multiplier.
113 | Use depth_coefficient of global_params.
114 |
115 | Args:
116 | repeats (int): num_repeat to be calculated.
117 | global_params (namedtuple): Global params of the model.
118 |
119 | Returns:
120 |         new_repeats: New number of repeats after scaling.
121 | """
122 | multiplier = global_params.depth_coefficient
123 | if not multiplier:
124 | return repeats
125 | # follow the formula transferred from official TensorFlow implementation
126 | return int(math.ceil(multiplier * repeats))
127 |
128 |
129 | def drop_connect(inputs, p, training):
130 | """Drop connect.
131 |
132 | Args:
133 |         inputs (tensor: BCHW): Input of this structure.
134 | p (float: 0.0~1.0): Probability of drop connection.
135 | training (bool): The running mode.
136 |
137 | Returns:
138 | output: Output after drop connection.
139 | """
140 | assert 0 <= p <= 1, 'p must be in range of [0,1]'
141 |
142 | if not training:
143 | return inputs
144 |
145 | batch_size = inputs.shape[0]
146 | keep_prob = 1 - p
147 |
148 | # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
149 | random_tensor = keep_prob
150 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
151 | binary_tensor = torch.floor(random_tensor)
152 |
153 | output = inputs / keep_prob * binary_tensor
154 | return output
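# Note on the scaling above: with p=0.2 each sample survives with probability keep_prob=0.8
# and is divided by 0.8, so the activation magnitude is preserved in expectation.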
155 |
156 |
157 | def get_width_and_height_from_size(x):
158 | """Obtain height and width from x.
159 |
160 | Args:
161 | x (int, tuple or list): Data size.
162 |
163 | Returns:
164 | size: A tuple or list (H,W).
165 | """
166 | if isinstance(x, int):
167 | return x, x
168 | if isinstance(x, list) or isinstance(x, tuple):
169 | return x
170 | else:
171 | raise TypeError()
172 |
173 |
174 | def calculate_output_image_size(input_image_size, stride):
175 | """Calculates the output image size when using Conv2dSamePadding with a stride.
176 | Necessary for static padding. Thanks to mannatsingh for pointing this out.
177 |
178 | Args:
179 | input_image_size (int, tuple or list): Size of input image.
180 | stride (int, tuple or list): Conv2d operation's stride.
181 |
182 | Returns:
183 | output_image_size: A list [H,W].
184 | """
185 | if input_image_size is None:
186 | return None
187 | image_height, image_width = get_width_and_height_from_size(input_image_size)
188 | stride = stride if isinstance(stride, int) else stride[0]
189 | image_height = int(math.ceil(image_height / stride))
190 | image_width = int(math.ceil(image_width / stride))
191 | return [image_height, image_width]
192 |
193 |
194 | # Note:
195 | # The following 'SamePadding' functions make output size equal ceil(input size/stride).
196 | # Only when stride equals 1 can the output size be the same as the input size.
197 | # Don't be confused by their function names ! ! !
198 |
199 | def get_same_padding_conv2d(image_size=None):
200 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
201 | Static padding is necessary for ONNX exporting of models.
202 |
203 | Args:
204 | image_size (int or tuple): Size of the image.
205 |
206 | Returns:
207 | Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
208 | """
209 | if image_size is None:
210 | return Conv2dDynamicSamePadding
211 | else:
212 | return partial(Conv2dStaticSamePadding, image_size=image_size)
213 |
214 |
215 | class Conv2dDynamicSamePadding(nn.Conv2d):
216 | """2D Convolutions like TensorFlow, for a dynamic image size.
217 | The padding is operated in forward function by calculating dynamically.
218 | """
219 |
220 | # Tips for 'SAME' mode padding.
221 | # Given the following:
222 | # i: width or height
223 | # s: stride
224 | # k: kernel size
225 | # d: dilation
226 | # p: padding
227 | # Output after Conv2d:
228 | # o = floor((i+p-((k-1)*d+1))/s+1)
229 | # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
230 | # => p = (i-1)*s+((k-1)*d+1)-i
231 |
232 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
233 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
234 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
235 |
236 | def forward(self, x):
237 | ih, iw = x.size()[-2:]
238 | kh, kw = self.weight.size()[-2:]
239 | sh, sw = self.stride
240 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
241 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
242 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
243 | if pad_h > 0 or pad_w > 0:
244 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
245 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
246 |
247 |
248 | class Conv2dStaticSamePadding(nn.Conv2d):
249 | """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
250 |     The padding module is calculated in the constructor, then used in forward.
251 | """
252 |
253 | # With the same calculation as Conv2dDynamicSamePadding
254 |
255 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
256 | super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
257 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
258 |
259 | # Calculate padding based on image size and save it
260 | assert image_size is not None
261 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
262 | kh, kw = self.weight.size()[-2:]
263 | sh, sw = self.stride
264 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
265 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
266 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
267 | if pad_h > 0 or pad_w > 0:
268 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
269 | pad_h // 2, pad_h - pad_h // 2))
270 | else:
271 | self.static_padding = nn.Identity()
272 |
273 | def forward(self, x):
274 | x = self.static_padding(x)
275 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
276 | return x
277 |
278 |
279 | def get_same_padding_maxPool2d(image_size=None):
280 | """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
281 | Static padding is necessary for ONNX exporting of models.
282 |
283 | Args:
284 | image_size (int or tuple): Size of the image.
285 |
286 | Returns:
287 | MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
288 | """
289 | if image_size is None:
290 | return MaxPool2dDynamicSamePadding
291 | else:
292 | return partial(MaxPool2dStaticSamePadding, image_size=image_size)
293 |
294 |
295 | class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
296 | """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
297 | The padding is operated in forward function by calculating dynamically.
298 | """
299 |
300 | def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
301 | super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
302 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
303 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
304 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
305 |
306 | def forward(self, x):
307 | ih, iw = x.size()[-2:]
308 | kh, kw = self.kernel_size
309 | sh, sw = self.stride
310 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
311 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
312 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
313 | if pad_h > 0 or pad_w > 0:
314 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
315 | return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
316 | self.dilation, self.ceil_mode, self.return_indices)
317 |
318 |
319 | class MaxPool2dStaticSamePadding(nn.MaxPool2d):
320 | """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
321 |     The padding module is calculated in the constructor, then used in forward.
322 | """
323 |
324 | def __init__(self, kernel_size, stride, image_size=None, **kwargs):
325 | super().__init__(kernel_size, stride, **kwargs)
326 | self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
327 | self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
328 | self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
329 |
330 | # Calculate padding based on image size and save it
331 | assert image_size is not None
332 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
333 | kh, kw = self.kernel_size
334 | sh, sw = self.stride
335 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
336 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
337 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
338 | if pad_h > 0 or pad_w > 0:
339 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
340 | else:
341 | self.static_padding = nn.Identity()
342 |
343 | def forward(self, x):
344 | x = self.static_padding(x)
345 | x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
346 | self.dilation, self.ceil_mode, self.return_indices)
347 | return x
348 |
349 |
350 | ################################################################################
351 | # Helper functions for loading model params
352 | ################################################################################
353 |
354 | # BlockDecoder: A Class for encoding and decoding BlockArgs
355 | # efficientnet_params: A function to query compound coefficient
356 | # get_model_params and efficientnet:
357 | # Functions to get BlockArgs and GlobalParams for efficientnet
358 | # url_map and url_map_advprop: Dicts of url_map for pretrained weights
359 | # load_pretrained_weights: A function to load pretrained weights
360 |
361 | class BlockDecoder(object):
362 | """Block Decoder for readability,
363 | straight from the official TensorFlow repository.
364 | """
365 |
366 | @staticmethod
367 | def _decode_block_string(block_string):
368 | """Get a block through a string notation of arguments.
369 |
370 | Args:
371 | block_string (str): A string notation of arguments.
372 | Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
373 |
374 | Returns:
375 | BlockArgs: The namedtuple defined at the top of this file.
376 | """
377 | assert isinstance(block_string, str)
378 |
379 | ops = block_string.split('_')
380 | options = {}
381 | for op in ops:
382 | splits = re.split(r'(\d.*)', op)
383 | if len(splits) >= 2:
384 | key, value = splits[:2]
385 | options[key] = value
386 |
387 | # Check stride
388 | assert (('s' in options and len(options['s']) == 1) or
389 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
390 |
391 | return BlockArgs(
392 | num_repeat=int(options['r']),
393 | kernel_size=int(options['k']),
394 | stride=[int(options['s'][0])],
395 | expand_ratio=int(options['e']),
396 | input_filters=int(options['i']),
397 | output_filters=int(options['o']),
398 | se_ratio=float(options['se']) if 'se' in options else None,
399 | id_skip=('noskip' not in block_string))
400 |
401 | @staticmethod
402 | def _encode_block_string(block):
403 | """Encode a block to a string.
404 |
405 | Args:
406 | block (namedtuple): A BlockArgs type argument.
407 |
408 | Returns:
409 | block_string: A String form of BlockArgs.
410 | """
411 | args = [
412 | 'r%d' % block.num_repeat,
413 | 'k%d' % block.kernel_size,
414 |             's%d%d' % (block.stride[0], block.stride[-1]),  # BlockArgs stores 'stride' as a (one-element) list
415 | 'e%s' % block.expand_ratio,
416 | 'i%d' % block.input_filters,
417 | 'o%d' % block.output_filters
418 | ]
419 | if 0 < block.se_ratio <= 1:
420 | args.append('se%s' % block.se_ratio)
421 | if block.id_skip is False:
422 | args.append('noskip')
423 | return '_'.join(args)
424 |
425 | @staticmethod
426 | def decode(string_list):
427 | """Decode a list of string notations to specify blocks inside the network.
428 |
429 | Args:
430 | string_list (list[str]): A list of strings, each string is a notation of block.
431 |
432 | Returns:
433 | blocks_args: A list of BlockArgs namedtuples of block args.
434 | """
435 | assert isinstance(string_list, list)
436 | blocks_args = []
437 | for block_string in string_list:
438 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
439 | return blocks_args
440 |
441 | @staticmethod
442 | def encode(blocks_args):
443 | """Encode a list of BlockArgs to a list of strings.
444 |
445 | Args:
446 | blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
447 |
448 | Returns:
449 | block_strings: A list of strings, each string is a notation of block.
450 | """
451 | block_strings = []
452 | for block in blocks_args:
453 | block_strings.append(BlockDecoder._encode_block_string(block))
454 | return block_strings
455 |
456 |
457 | def efficientnet_params(model_name):
458 | """Map EfficientNet model name to parameter coefficients.
459 |
460 | Args:
461 | model_name (str): Model name to be queried.
462 |
463 | Returns:
464 | params_dict[model_name]: A (width,depth,res,dropout) tuple.
465 | """
466 | params_dict = {
467 | # Coefficients: width,depth,res,dropout
468 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
469 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
470 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
471 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
472 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
473 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
474 |         'efficientnet-b6': (1.8, 2.6, 528, 0.5),
475 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
476 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
477 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
478 | }
479 | return params_dict[model_name]
480 |
481 |
482 | def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
483 | dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
484 | """Create BlockArgs and GlobalParams for efficientnet model.
485 |
486 | Args:
487 | width_coefficient (float)
488 | depth_coefficient (float)
489 | image_size (int)
490 | dropout_rate (float)
491 | drop_connect_rate (float)
492 | num_classes (int)
493 |
494 | Meaning as the name suggests.
495 |
496 | Returns:
497 | blocks_args, global_params.
498 | """
499 |
500 |     # Block args for the whole model (efficientnet-b0 by default; expand ratios per the official config)
501 |     # It will be modified in the construction of the EfficientNet class according to the model
502 |     blocks_args = [
503 |         'r1_k3_s11_e1_i32_o16_se0.25',
504 |         'r2_k3_s22_e6_i16_o24_se0.25',
505 |         'r2_k5_s22_e6_i24_o40_se0.25',
506 |         'r3_k3_s22_e6_i40_o80_se0.25',
507 |         'r3_k5_s11_e6_i80_o112_se0.25',
508 |         'r4_k5_s22_e6_i112_o192_se0.25',
509 |         'r1_k3_s11_e6_i192_o320_se0.25',
510 | ]
511 | blocks_args = BlockDecoder.decode(blocks_args)
512 |
513 | global_params = GlobalParams(
514 | width_coefficient=width_coefficient,
515 | depth_coefficient=depth_coefficient,
516 | image_size=image_size,
517 | dropout_rate=dropout_rate,
518 |
519 | num_classes=num_classes,
520 | batch_norm_momentum=0.99,
521 | batch_norm_epsilon=1e-3,
522 | drop_connect_rate=drop_connect_rate,
523 | depth_divisor=8,
524 | min_depth=None,
525 | include_top=include_top,
526 | )
527 |
528 | return blocks_args, global_params
529 |
530 |
531 | def get_model_params(model_name, override_params):
532 | """Get the block args and global params for a given model name.
533 |
534 | Args:
535 | model_name (str): Model's name.
536 | override_params (dict): A dict to modify global_params.
537 |
538 | Returns:
539 | blocks_args, global_params
540 | """
541 | if model_name.startswith('efficientnet'):
542 | w, d, s, p = efficientnet_params(model_name)
543 | # note: all models have drop connect rate = 0.2
544 | blocks_args, global_params = efficientnet(
545 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
546 | else:
547 | raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
548 | if override_params:
549 | # ValueError will be raised here if override_params has fields not included in global_params.
550 | global_params = global_params._replace(**override_params)
551 | return blocks_args, global_params
552 |
553 |
554 | # train with Standard methods
555 | # check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
556 | url_map = {
557 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
558 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
559 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
560 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
561 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
562 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
563 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
564 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
565 | }
566 |
567 | # train with Adversarial Examples(AdvProp)
568 | # check more details in paper(Adversarial Examples Improve Image Recognition)
569 | url_map_advprop = {
570 | 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
571 | 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
572 | 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
573 | 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
574 | 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
575 | 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
576 | 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
577 | 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
578 | 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
579 | }
580 |
581 | # TODO: add the pretrained weights url map of 'efficientnet-l2'
582 |
583 |
584 | def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
585 | """Loads pretrained weights from weights path or download using url.
586 |
587 | Args:
588 | model (Module): The whole model of efficientnet.
589 | model_name (str): Model name of efficientnet.
590 | weights_path (None or str):
591 | str: path to pretrained weights file on the local disk.
592 | None: use pretrained weights downloaded from the Internet.
593 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
594 | advprop (bool): Whether to load pretrained weights
595 | trained with advprop (valid when weights_path is None).
596 | """
597 | if isinstance(weights_path, str):
598 | state_dict = torch.load(weights_path)
599 | else:
600 | # AutoAugment or Advprop (different preprocessing)
601 | url_map_ = url_map_advprop if advprop else url_map
602 | state_dict = model_zoo.load_url(url_map_[model_name])
603 |
604 | if load_fc:
605 | ret = model.load_state_dict(state_dict, strict=False)
606 | assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
607 | else:
608 | state_dict.pop('_fc.weight')
609 | state_dict.pop('_fc.bias')
610 | ret = model.load_state_dict(state_dict, strict=False)
611 | assert set(ret.missing_keys) == set(
612 | ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
613 |         assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
614 |
615 | if verbose:
616 | print('Loaded pretrained weights for {}'.format(model_name))
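# Hedged usage sketch (editor's addition, not part of the original file). Assumes the
# EfficientNet class from the `efficientnet_pytorch` package this module is adapted from;
# the model name and flags below are illustrative.
from efficientnet_pytorch import EfficientNet

blocks_args, global_params = get_model_params('efficientnet-b0', {'num_classes': 1000})
model = EfficientNet.from_name('efficientnet-b0', num_classes=1000)
load_pretrained_weights(model, 'efficientnet-b0', advprop=True)  # AdvProp checkpoint from url_map_advprop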
--------------------------------------------------------------------------------
/parserr.py:
--------------------------------------------------------------------------------
1 | '''Helper classes that hold the augmentation hyper-parameter values used for training.
2 | '''
3 |
4 | import argparse
5 |
6 | def str2bool(v):
7 |
8 | if isinstance(v, bool):
9 | return v
10 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
11 | return True
12 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
13 | return False
14 | else:
15 | raise argparse.ArgumentTypeError('Boolean value expected.')
16 |
17 | class Arguments_augment():
18 | def __init__(self):
19 | self.color_jitter = 0.4
20 | self.aa = 'rand-m9-mstd0.5-inc1'
21 | self.train_interpolation = 'bicubic'
22 | self.crop_pct = None
23 | self.reprob = 0.25
24 | self.remode = 'pixel'
25 | self.recount = 1
26 | self.resplit = False
27 | self.mixup = 0.8
28 | self.cutmix = 1.0
29 | self.cutmix_minmax = None
30 | self.mixup_prob = 1.0
31 | self.mixup_switch_prob = 0.5
32 | self.mixup_mode = 'batch'
33 | self.nb_classes = 1000
34 | self.input_size = 224
35 | self.data_set = 'IMNET'
36 | self.dist_eval = True
37 | self.hflip = 0.5
38 | self.vflip = 0.0
39 | self.scale = [0.08, 1.0]
40 | self.ratio = [3./4., 4./3.]
41 |
42 |
43 |
44 | class Arguments_No_augment():
45 | def __init__(self):
46 | self.color_jitter = 0.0
47 | self.aa = None #'rand-m9-mstd0.5-inc1'
48 | self.train_interpolation = 'bicubic'
49 | self.crop_pct = None
50 | self.reprob = 0.0
51 | self.remode = None #'pixel'
52 | self.recount = 0
53 | self.resplit = False
54 | self.mixup = 0.0
55 | self.cutmix = 0.
56 | self.cutmix_minmax = None
57 | self.mixup_prob = 0.0
58 | self.mixup_switch_prob = 0.
59 | self.mixup_mode = None
60 | self.nb_classes = 1000
61 | self.data_set = 'IMNET'
62 | self.input_size = 224
63 | self.dist_eval = True
64 | self.hflip = 0.0
65 | self.vflip = 0.0
66 | self.scale = [0.08, 1.0]
67 | self.ratio = [3./4., 4./3.]
68 |
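# Hedged sketch (editor's addition): these fields map naturally onto timm's transform factory;
# that the training code consumes them exactly like this is an assumption.
from timm.data import create_transform

aug_args = Arguments_augment()
train_transform = create_transform(
    input_size=aug_args.input_size,
    is_training=True,
    color_jitter=aug_args.color_jitter,
    auto_augment=aug_args.aa,
    interpolation=aug_args.train_interpolation,
    re_prob=aug_args.reprob,
    re_mode=aug_args.remode,
    re_count=aug_args.recount,
    hflip=aug_args.hflip,
    vflip=aug_args.vflip,
    scale=tuple(aug_args.scale),
    ratio=tuple(aug_args.ratio),
)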
--------------------------------------------------------------------------------
/rb_architecture_util.py:
--------------------------------------------------------------------------------
1 | '''All the submitted to RBench are defined here.
2 | Mostly use timm but have some custom implementation: ConvStem (ConvBlock)
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from collections import OrderedDict
8 | from typing import Tuple
9 | from torch import Tensor
10 | from torchvision import transforms  # needed by get_transforms below
11 |
12 | import timm
13 | from timm.models import create_model
14 | import torch.nn.functional as F
15 | import math
16 |
17 | IMAGENET_MEAN = [c * 1. for c in (0.485, 0.456, 0.406)]
18 | IMAGENET_STD = [c * 1. for c in (0.229, 0.224, 0.225)]
19 |
20 |
21 | class LayerNorm(nn.Module):
22 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
23 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
24 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
25 | with shape (batch_size, channels, height, width).
26 | """
27 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
28 | super().__init__()
29 | self.weight = nn.Parameter(torch.ones(normalized_shape))
30 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
31 | self.eps = eps
32 | self.data_format = data_format
33 | if self.data_format not in ["channels_last", "channels_first"]:
34 | raise NotImplementedError
35 | self.normalized_shape = (normalized_shape, )
36 |
37 | def forward(self, x):
38 | if self.data_format == "channels_last":
39 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
40 | elif self.data_format == "channels_first":
41 | u = x.mean(1, keepdim=True)
42 | s = (x - u).pow(2).mean(1, keepdim=True)
43 | x = (x - u) / torch.sqrt(s + self.eps)
44 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
45 | return x
46 |
47 |
48 | class ImageNormalizer(nn.Module):
49 |     '''Adds normalization as the first layer of the model, since AutoAttack (AA) uses un-normalized inputs.'''
50 |
51 | def __init__(self, persistent: bool = True) -> None:
52 | super(ImageNormalizer, self).__init__()
53 |
54 |         self.register_buffer('mean', torch.as_tensor(IMAGENET_MEAN).view(1, 3, 1, 1),
55 |             persistent=persistent)
56 |         self.register_buffer('std', torch.as_tensor(IMAGENET_STD).view(1, 3, 1, 1),
57 |             persistent=persistent)
58 |
59 | def forward(self, input: Tensor) -> Tensor:
60 | return (input - self.mean) / self.std
61 |
62 |
63 | def normalize_model(model: nn.Module) -> nn.Module:
64 | layers = OrderedDict([
65 | ('normalize', ImageNormalizer()),
66 | ('model', model)
67 | ])
68 | return nn.Sequential(layers)
69 |
70 |
71 |
72 | def get_transforms(img_size=224):
73 |     '''Returns a torchvision transform for RobustBench evaluation; used when testing models at increased resolution.'''
74 | crop_pct = 0.875
75 | scale_size = int(math.floor(img_size / crop_pct))
76 | trans = transforms.Compose([
77 | transforms.Resize(
78 | scale_size,
79 | interpolation=transforms.InterpolationMode("bicubic")),
80 | transforms.CenterCrop(img_size),
81 | transforms.ToTensor()
82 | ])
83 |
84 | return trans
85 |
86 |
87 | class ConvBlock(nn.Module):
88 | expansion = 1
89 | def __init__(self, siz=48, end_siz=8, fin_dim=384):
90 | super(ConvBlock, self).__init__()
91 | self.planes = siz
92 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432
93 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else nn.GroupNorm(num_groups=1, num_channels=planes)
94 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
95 | LayerNorm(self.planes, data_format="channels_first"),
96 | nn.GELU(),
97 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1),
98 | LayerNorm(self.planes*2, data_format="channels_first"),
99 | nn.GELU(),
100 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1),
101 | LayerNorm(self.planes*4, data_format="channels_first"),
102 | nn.GELU(),
103 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1),
104 | LayerNorm(self.planes*8, data_format="channels_first"),
105 | nn.GELU(),
106 | nn.Conv2d(self.planes*8, fin_dim, kernel_size=1, stride=1, padding=0)
107 | )
108 | def forward(self, x):
109 | out = self.stem(x)
110 | # out = self.bn(out)
111 | return out
112 |
113 |
114 | class ConvBlock3(nn.Module):
115 | # expansion = 1
116 | def __init__(self, siz=64):
117 | super(ConvBlock3, self).__init__()
118 | self.planes = siz
119 |
120 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
121 | LayerNorm(self.planes, data_format="channels_first"),
122 | nn.GELU(),
123 | nn.Conv2d(self.planes, int(self.planes*1.5), kernel_size=3, stride=2, padding=1),
124 | LayerNorm(int(self.planes*1.5), data_format="channels_first"),
125 | nn.GELU(),
126 | nn.Conv2d(int(self.planes*1.5), self.planes*2, kernel_size=3, stride=1, padding=1),
127 | LayerNorm(self.planes*2, data_format="channels_first"),
128 | nn.GELU()
129 | )
130 |
131 | def forward(self, x):
132 | out = self.stem(x)
133 | # out = self.bn(out)
134 | return out
135 |
136 |
137 | class ConvBlock1(nn.Module):
138 | def __init__(self, siz=48, end_siz=8, fin_dim=384):
139 | super(ConvBlock1, self).__init__()
140 | self.planes = siz
141 |
142 | fin_dim = self.planes*end_siz if fin_dim == None else 432
143 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
144 | LayerNorm(self.planes, data_format="channels_first"),
145 | nn.GELU(),
146 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1),
147 | LayerNorm(self.planes*2, data_format="channels_first"),
148 | nn.GELU()
149 | )
150 |
151 | def forward(self, x):
152 | out = self.stem(x)
153 | # out = self.bn(out)
154 | return out
155 |
156 |
157 | class IdentityLayer(nn.Module):
158 | def forward(self, inputs):
159 | return inputs
160 |
161 |
162 | def get_new_model(modelname, pretrained=False, not_original=True):
163 |
164 | if modelname == 'convnext_t_cvst':
165 | model = timm.models.convnext.convnext_tiny(pretrained=pretrained)
166 | model.stem = ConvBlock1(48, end_siz=8)
167 |
168 | elif modelname == "convnext_s_cvst":
169 | model = timm.models.convnext.convnext_small(pretrained=pretrained)
170 | model.stem = ConvBlock1(48, end_siz=8)
171 |
172 | elif modelname == "convnext_b_cvst":
173 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
174 | model = timm.models.convnext._create_convnext('convnext_base.fb_in1k', pretrained=pretrained, **model_args)
175 | model.stem = ConvBlock3(64)
176 |
177 | elif modelname == "convnext_l_cvst":
178 | model = timm.models.convnext_large(pretrained=pretrained)
179 | model.stem = ConvBlock3(96)
180 |
181 | elif modelname == 'vit_s_cvst':
182 | model = create_model('deit_small_patch16_224', pretrained=pretrained)
183 | model.patch_embed.proj = ConvBlock(48, end_siz=8)
184 | model = normalize_model(model)
185 |
186 | elif modelname == 'vit_b_cvst':
187 | model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained)
188 | model.patch_embed.proj = ConvBlock(48, end_siz=16, fin_dim=None)
189 |
190 | else:
191 |         raise ValueError(f'Invalid model name {modelname}, please use one of: convnext_t_cvst, '
192 |                          'convnext_s_cvst, convnext_b_cvst, convnext_l_cvst, vit_s_cvst, vit_b_cvst')
193 | return model
194 |
195 | def load_model(arch, not_original, chkpt_path):
196 |     '''
197 | Load the model with definition from the checkpoint
198 | arch: architecture name
199 | not_original: If True -> CvSt
200 | chkpt_path: location of checkpoint
201 | '''
202 | model = get_new_model(arch, pretrained=False, not_original=not_original)
203 | ckpt = torch.load(chkpt_path, map_location='cpu')
204 | ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
205 | ckpt = {k.replace('base_model.', ''): v for k, v in ckpt.items()}
206 | ckpt = {k.replace('se_', 'se_module.'): v for k, v in ckpt.items()}
207 |
208 | model.load_state_dict(ckpt)
209 |     model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
210 | model.eval()
211 | return model
212 |
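# Hedged usage sketch (editor's addition): loading a ConvStem ("CvSt") checkpoint for
# RobustBench-style evaluation; the checkpoint path is illustrative.
cvst = load_model('convnext_b_cvst', not_original=True,
                  chkpt_path='path/to/convnext_b_cvst_weights.pt')
prepr = get_transforms(img_size=224)  # plug into the data loading when evaluating at other resolutions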
--------------------------------------------------------------------------------
/readme_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmndeep/revisiting-at/d7166c074223b89e3b7e1ec1489a8d069cc5afb7/readme_teaser.png
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=<device_ids_separated_by_commas>  # e.g. 0,1,2,3
4 |
5 | sleep 2s
6 |
7 | # Set the visible GPUs according to the `world_size` configuration parameter
8 | # Modify `data.in_memory` and `data.num_workers` based on your machine
9 |
10 | python3 main.py --data.num_workers=12 --data.in_memory=1 \
11 | --data.train_dataset=path-to-imagenet-train-set \
12 | --data.val_dataset=path-to-imagenet-val-set \
13 | --logging.folder=path-to-logging-folder --logging.log_level 2 \
14 | --adv.attack apgd --adv.n_iter 2 --adv.norm Linf --training.distributed 1 --training.batch_size 80 --lr.lr 1e-3 --logging.save_freq 2 \
15 | --resolution.min_res 224 --resolution.max_res 224 --data.seed 0 --data.augmentations 1 --model.add_normalization 0\
16 | --model.not_original 1 --model.model_ema 1 --lr.lr_peak_epoch 20\
17 | --training.label_smoothing 0.1 --logging.addendum='additional_text_appended_to_save_folder_name'\
18 | --dist.world_size '# of GPUS' --training.distributed 1 --model.pretrained 0 --model.arch convnext_base --training.epochs 300 \
19 |
20 |
--------------------------------------------------------------------------------
/runner_aa_eval.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import time
4 | from subprocess import call
5 | import GPUtil
6 |
7 |
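# Each tuple below is (ckpt_path, arch, not_original, use_a100, full_aa, norm, batch_size);
# the field meanings are inferred from how the tuples are unpacked and formatted further down.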
8 | paramss = [
9 | ("location_of_robust_model","convnext_base", 1, 1, 0, "Linf", 100),
10 | ]
11 | str_to_run = 'AA_eval.py'
12 | models_to_run = []
13 |
14 | for minn, modd, not_o, a100, fullaa, norr, bs in paramss:
15 | models_to_run.append('{} --model_in {} --mod {} --not-orig {} --a100 {} --full_aa {} --l_norms {} --batch_size {}'
16 | .format(str_to_run, minn, modd, not_o, a100, fullaa, norr, bs))
17 |
18 | cart_prod = list(models_to_run)
19 |
20 |
21 | for job in cart_prod:
22 | print(job)
23 |
24 | time.sleep(5)
25 | count = 0
26 | wait = 0
27 | while wait<=1:
28 | gpu_ids = GPUtil.getAvailable(order = 'last', limit = 8, \
29 | maxLoad = .1, maxMemory = .5) # get free gpus listd
30 | if len(gpu_ids) > 0:
31 | print(gpu_ids)
32 | # time.sleep(5)
33 | for id in gpu_ids:
34 | if id == 5:
35 | pass
36 | else:
37 | if id != 10:
38 | temp_list = cart_prod[count]
39 |
40 | command_to_exec = '' +\
41 | ' CUDA_VISIBLE_DEVICES='+str(id)+\
42 | ' python3' +\
43 | ' ' + temp_list\
44 | + ' &' # for going to next iteration without job in background.
45 |
46 | print("Command executing is " + command_to_exec)
47 | call(command_to_exec, shell=True)
48 | print('done executing in '+str(id))
49 | count += 1
50 | time.sleep(2) # wait for processes to start
51 | # if count == 3:
52 | # time.sleep(7200)
53 | else:
54 |         print('No free GPUs, waiting for 15 seconds')
55 |         time.sleep(15)
56 | wait+=1
57 | # time.sleep(3600*3) # wait for processes to start
58 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import requests
5 | import torch
6 | from collections import OrderedDict
7 | from model_zoo.models import model_dicts
8 | #from models_new import l_models_all, l_models_imagenet
9 |
10 | def download_gdrive(gdrive_id, fname_save):
11 | """ source: https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url """
12 | def get_confirm_token(response):
13 | for key, value in response.cookies.items():
14 | if key.startswith('download_warning'):
15 | return value
16 |
17 | return None
18 |
19 | def save_response_content(response, fname_save):
20 | CHUNK_SIZE = 32768
21 |
22 | with open(fname_save, "wb") as f:
23 | for chunk in response.iter_content(CHUNK_SIZE):
24 | if chunk: # filter out keep-alive new chunks
25 | f.write(chunk)
26 |
27 | print('Download started: path={} (gdrive_id={})'.format(fname_save, gdrive_id))
28 |
29 | url_base = "https://docs.google.com/uc?export=download&confirm=t"
30 | session = requests.Session()
31 |
32 | response = session.get(url_base, params={'id': gdrive_id}, stream=True)
33 | token = get_confirm_token(response)
34 |
35 | if token:
36 | params = {'id': gdrive_id, 'confirm': token}
37 | response = session.get(url_base, params=params, stream=True)
38 |
39 | save_response_content(response, fname_save)
40 | print('Download finished: path={} (gdrive_id={})'.format(fname_save, gdrive_id))
41 |
42 |
43 | def rm_substr_from_state_dict(state_dict, substr):
44 | new_state_dict = OrderedDict()
45 | for key in state_dict.keys():
46 | if substr in key: # to delete prefix 'module.' if it exists
47 | new_key = key[len(substr):]
48 | new_state_dict[new_key] = state_dict[key]
49 | else:
50 | new_state_dict[key] = state_dict[key]
51 | return new_state_dict
52 |
53 |
54 | def load_model(model_name, model_dir='./models', norm='Linf'):
55 | model_dir_norm = '{}/{}'.format(model_dir, norm)
56 | if not isinstance(model_dicts[norm][model_name]['gdrive_id'], list):
57 | model_path = '{}/{}/{}.pt'.format(model_dir, norm, model_name)
58 | model = model_dicts[norm][model_name]['model']()
59 | if not os.path.exists(model_dir_norm):
60 | os.makedirs(model_dir_norm)
61 | if not os.path.isfile(model_path):
62 | download_gdrive(model_dicts[norm][model_name]['gdrive_id'], model_path)
63 | checkpoint = torch.load(model_path, map_location='cuda:0')
64 |
65 | # needed for the model of `Carmon2019Unlabeled`
66 | try:
67 | state_dict = rm_substr_from_state_dict(checkpoint['state_dict'], 'module.')
68 | except:
69 | state_dict = rm_substr_from_state_dict(checkpoint, 'module.')
70 |
71 | model.load_state_dict(state_dict, strict=True)
72 | return model.cuda().eval()
73 |
74 | # If we have an ensemble of models (e.g., Chen2020Adversarial)
75 | else:
76 | model_path = '{}/{}/{}'.format(model_dir, norm, model_name)
77 | model = model_dicts[norm][model_name]['model']()
78 | if not os.path.exists(model_dir_norm):
79 | os.makedirs(model_dir_norm)
80 | for i, gid in enumerate(model_dicts[norm][model_name]['gdrive_id']):
81 | if not os.path.isfile('{}_m{}.pt'.format(model_path, i)):
82 | download_gdrive(gid, '{}_m{}.pt'.format(model_path, i))
83 | checkpoint = torch.load('{}_m{}.pt'.format(model_path, i), map_location='cuda:0')
84 | try:
85 | state_dict = rm_substr_from_state_dict(checkpoint['state_dict'], 'module.')
86 | except:
87 | state_dict = rm_substr_from_state_dict(checkpoint, 'module.')
88 | model.models[i].load_state_dict(state_dict)
89 | model.models[i].cuda().eval()
90 | return model
91 |
92 |
93 | def clean_accuracy(model, x, y, batch_size=100):
94 | acc = 0.
95 | n_batches = math.ceil(x.shape[0] / batch_size)
96 | with torch.no_grad():
97 | for counter in range(n_batches):
98 | x_curr = x[counter * batch_size:(counter + 1) * batch_size].cuda()
99 | y_curr = y[counter * batch_size:(counter + 1) * batch_size].cuda()
100 |
101 | output = model(x_curr)
102 | acc += (output.max(1)[1] == y_curr).float().sum()
103 |
104 | return acc.item() / x.shape[0]
105 |
106 |
107 | def get_accuracy_and_logits(model, x, y, batch_size=100, n_classes=10):
108 | logits = torch.zeros([y.shape[0], n_classes])
109 | acc = 0.
110 | n_batches = math.ceil(x.shape[0] / batch_size)
111 | with torch.no_grad():
112 | for counter in range(n_batches):
113 | x_curr = x[counter * batch_size:(counter + 1) * batch_size].cuda()
114 | y_curr = y[counter * batch_size:(counter + 1) * batch_size].cuda()
115 |
116 | output = model(x_curr)
117 | logits[counter * batch_size:(counter + 1) * batch_size] += output.cpu()
118 | acc += (output.max(1)[1] == y_curr).float().sum()
119 |
120 | return acc.item() / x.shape[0], logits
121 |
122 |
123 |
124 | def list_available_models(norm='Linf'):
125 | models = model_dicts[norm].keys()
126 |
127 | json_dicts = []
128 | for model_name in models:
129 | with open('./model_info/{}.json'.format(model_name), 'r') as model_info:
130 | json_dict = json.load(model_info)
131 | json_dict['model_name'] = model_name
132 | json_dict['venue'] = 'Unpublished' if json_dict['venue'] == '' else json_dict['venue']
133 | json_dict['AA'] = float(json_dict['AA']) / 100
134 | json_dict['clean_acc'] = float(json_dict['clean_acc']) / 100
135 | json_dicts.append(json_dict)
136 |
137 | json_dicts = sorted(json_dicts, key=lambda d: -d['AA'])
138 | print('| # | Model ID | Paper | Clean accuracy | Robust accuracy | Architecture | Venue |')
139 | print('|:---:|---|---|:---:|:---:|:---:|:---:|')
140 | for i, json_dict in enumerate(json_dicts):
141 | print('| **{}** | **{}** | *[{}]({})* | {:.2%} | {:.2%} | {} | {} |'.format(
142 | i+1, json_dict['model_name'], json_dict['name'], json_dict['link'], json_dict['clean_acc'], json_dict['AA'],
143 | json_dict['architecture'], json_dict['venue']))
144 |
145 |
146 | def load_model_fast_at(model_name, norm, model_dir, fts_before_bn):
147 | from model_zoo.fast_models import PreActResNet18, model_names
148 | model_name_long = model_names[norm][model_name]
149 | activation = [c.split('activation=')[-1] for c in model_name_long.split(' ') if 'activation' in c]
150 | if 'normal=' in model_name_long:
151 | normal = model_name_long.split('normal=')[1].split(' ')[0]
152 | else:
153 | normal = 'none'
154 |
155 | if 'resnet18' in model_name_long:
156 | model = PreActResNet18(n_cls=10, activation=activation[0], fts_before_bn=fts_before_bn,
157 | normal=normal)
158 | ckpt = torch.load('{}/{}'.format(model_dir, model_name_long))
159 | model.load_state_dict({k: v for k, v in ckpt['last'].items() if 'model_preact_hl1' not in k})
160 | return model.eval()
161 |
162 |
163 | def load_model_ssl(model_name, model_dir):
164 | from model_zoo.ssl_models import models_dict
165 | data = models_dict[model_name]
166 | model = data['model']()
167 | ckpt_base = torch.load('{}/{}'.format(model_dir, data['base']))['model']
168 | ckpt_base = rm_substr_from_state_dict(ckpt_base, 'module.')
169 | model.base.load_state_dict(ckpt_base)
170 | ckpt_lin = torch.load('{}/{}'.format(model_dir, data['linear']))['model']
171 | ckpt_lin = rm_substr_from_state_dict(ckpt_lin, 'module.')
172 | model.linear.load_state_dict(ckpt_lin)
173 | model.cuda()
174 | model.eval()
175 | return model
176 |
177 | def load_anymodel(model_name, model_dir='./models'):
178 | if len(model_name) == 2:
179 | return load_model(model_name[0], model_dir='./models', norm=model_name[1]).cuda().eval()
180 | elif len(model_name) == 3 and model_name[2] == 'fast_at':
181 | return load_model_fast_at(model_name[0], model_name[1],
182 | model_dir=model_dir, #'./models' #'../understanding-fast-adv-training-dev/models'
183 | fts_before_bn=False).cuda().eval()
184 | elif len(model_name) == 3 and model_name[2] == 'ssl':
185 | model = load_model_ssl(model_name[0], './models/ssl_models')
186 | assert not model.base.training
187 | assert not model.linear.training
188 | return model
189 | elif len(model_name) == 3 and model_name[2] == 'ext':
190 | from model_zoo.ext_models import load_ext_models
191 | return load_ext_models(model_name[0])
192 |
193 |
194 | def load_anymodel_imagenet(model_name, **kwargs):
195 | if len(model_name) == 2:
196 | from model_zoo.models_imagenet import load_model as load_model_imagenet
197 | return load_model_imagenet(model_name[0], norm=model_name[1])
198 | elif len(model_name) == 3 and model_name[2] == 'pretrained':
199 | from model_zoo.models_imagenet import PretrainedModel
200 | model = PretrainedModel(model_name[0])
201 | assert not model.model.training
202 | return model
203 | elif len(model_name) == 3 and model_name[2] == 'ssl':
204 | model = load_model_ssl(model_name[0], './models/ssl_models')
205 | assert not model.base.training
206 | assert not model.linear.training
207 | return model
208 | elif len(model_name) == 3 and model_name[2] == 'ext':
209 | from model_zoo.ext_models import load_ext_models_imagenet
210 | return load_ext_models_imagenet(model_name[0], **kwargs)
211 |
212 | def load_anymodel_cifar100(model_name):
213 | if len(model_name) == 2:
214 | from model_zoo.models_cifar100 import load_model as load_model_cifar100
215 | return load_model_cifar100(model_name[0], norm=model_name[1])
216 |
217 |
218 | def load_anymodel_imagenet100(model_name):
219 | if len(model_name) == 3 and model_name[2] == 'ext':
220 | from model_zoo.ext_models import load_ext_models_imagenet100
221 | return load_ext_models_imagenet100(model_name[0], model_name[1])
222 |
223 |
224 | def load_anymodel_mnist(model_name):
225 | if len(model_name) == 2:
226 | from model_zoo.models_mnist import load_model as load_model_mnist
227 | return load_model_mnist(model_name[0], norm=model_name[1])
228 |
229 |
230 | '''def load_anymodel_datasets(args):
231 | fts_idx = [int(c) for c in args.fts_idx.split(' ')]
232 | if args.dataset == 'cifar10':
233 | l_models = [l_models_all[c] for c in fts_idx]
234 | print(l_models)
235 | model = load_anymodel(l_models[0])
236 | model.eval()
237 | elif args.dataset == 'imagenet':
238 | l_models = [l_models_imagenet[c] for c in fts_idx]
239 | print(l_models)
240 | model = load_anymodel_imagenet(l_models[0])
241 | #sys.exit()
242 | with torch.no_grad():
243 | acc = clean_accuracy(model, x, y, batch_size=25)
244 | print('clean accuracy: {:.1%}'.format(acc))
245 | return model'''
246 |
247 | if __name__ == '__main__':
248 | #list_available_models()
249 | pass
250 |
--------------------------------------------------------------------------------
/utils_architecture.py:
--------------------------------------------------------------------------------
1 | '''All the models are defined here.
2 | Custom implementation: ConvStem (ConvBlock) with standard models built on timm.
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from collections import OrderedDict
8 | from typing import Tuple
9 | from torch import Tensor
10 | import torch.nn as nn
11 |
12 | import timm
13 | from functools import partial
14 | from timm.models import create_model
15 | from timm.models.convnext import _create_convnext as CNXT
16 | import torch.nn.functional as F
17 | from functools import partial
18 | import math
19 | from timm.models.vision_transformer import VisionTransformer
20 |
21 |
22 | def interpolate_pos_encoding(
23 | pos_embed: Tensor,
24 | new_img_size: int,
25 | old_img_size: int = 224,
26 | patch_size: int = 16) -> Tensor:
27 | """Interpolates the positional encoding of ViTs for new image resolution
28 | (adapted from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py#L174).
29 | It currently handles only square images.
30 | """
31 | N = pos_embed.shape[1] - 1
32 | npatch = (new_img_size // patch_size) ** 2
33 | w, h = new_img_size, new_img_size
34 | if npatch == N and w == h:
35 | print(f'Positional encoding not changed.')
36 | return pos_embed
37 | print(f'Interpolating positional encoding from {N} to {npatch} patches (size={patch_size}).')
38 | class_pos_embed = pos_embed[:, 0]
39 | patch_pos_embed = pos_embed[:, 1:]
40 | dim = pos_embed.shape[-1]
41 | w0 = w // patch_size
42 | h0 = h // patch_size
43 | # we add a small number to avoid floating point error in the interpolation
44 | # see discussion at https://github.com/facebookresearch/dino/issues/8
45 | w0, h0 = w0 + 0.1, h0 + 0.1
46 | patch_pos_embed = nn.functional.interpolate(
47 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
48 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
49 | mode='bicubic',
50 | )
51 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
52 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
53 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
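# Hedged usage sketch (editor's addition): resize a timm ViT's positional embedding so a
# model trained at 224x224 can be evaluated at 384x384 (the patch embedding's expected
# image size must be updated separately; this only shows the pos_embed step).
_vit = timm.models.vision_transformer.vit_base_patch16_224(pretrained=False)
with torch.no_grad():
    _vit.pos_embed = nn.Parameter(interpolate_pos_encoding(
        _vit.pos_embed, new_img_size=384, old_img_size=224, patch_size=16))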
54 |
55 |
56 |
57 | class LayerNorm(nn.Module):
58 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
59 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
60 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
61 | with shape (batch_size, channels, height, width).
62 | """
63 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
64 | super().__init__()
65 | self.weight = nn.Parameter(torch.ones(normalized_shape))
66 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
67 | self.eps = eps
68 | self.data_format = data_format
69 | if self.data_format not in ["channels_last", "channels_first"]:
70 | raise NotImplementedError
71 | self.normalized_shape = (normalized_shape, )
72 |
73 | def forward(self, x):
74 | if self.data_format == "channels_last":
75 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
76 | elif self.data_format == "channels_first":
77 | u = x.mean(1, keepdim=True)
78 | s = (x - u).pow(2).mean(1, keepdim=True)
79 | x = (x - u) / torch.sqrt(s + self.eps)
80 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
81 | return x
82 |
83 |
84 |
85 |
86 | class ImageNormalizer(nn.Module):
87 | def __init__(self, mean: Tuple[float, float, float],
88 | std: Tuple[float, float, float],
89 | persistent: bool = True) -> None:
90 | super(ImageNormalizer, self).__init__()
91 |
92 | self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1),
93 | persistent=persistent)
94 | self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1),
95 | persistent=persistent)
96 |
97 | def forward(self, input: Tensor) -> Tensor:
98 | return (input - self.mean) / self.std
99 |
100 |
101 | def timm_gelu(inplace):
102 | return nn.GELU()
103 |
104 | def convert_relu_to_gelu(model):
105 | for child_name, child in model.named_children():
106 | if isinstance(child, nn.ReLU):
107 | setattr(model, child_name, nn.GELU())
108 | else:
109 | convert_relu_to_gelu(child)
110 |
111 | def normalize_model(model: nn.Module, mean: Tuple[float, float, float],
112 | std: Tuple[float, float, float]) -> nn.Module:
113 | layers = OrderedDict([
114 | ('normalize', ImageNormalizer(mean, std)),
115 | ('model', model)
116 | ])
117 | return nn.Sequential(layers)
118 |
119 |
120 | class ConvBlock(nn.Module):
121 | expansion = 1
122 | def __init__(self, siz=48, end_siz=8, fin_dim=384):
123 | super(ConvBlock, self).__init__()
124 | self.planes = siz
125 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432
126 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else nn.GroupNorm(num_groups=1, num_channels=planes)
127 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
128 | LayerNorm(self.planes, data_format="channels_first"),
129 | nn.GELU(),
130 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1),
131 | LayerNorm(self.planes*2, data_format="channels_first"),
132 | nn.GELU(),
133 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1),
134 | LayerNorm(self.planes*4, data_format="channels_first"),
135 | nn.GELU(),
136 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1),
137 | LayerNorm(self.planes*8, data_format="channels_first"),
138 | nn.GELU(),
139 | nn.Conv2d(self.planes*8, fin_dim, kernel_size=1, stride=1, padding=0)
140 | )
141 | def forward(self, x):
142 | out = self.stem(x)
143 | # out = self.bn(out)
144 | return out
145 |
146 | class ConvBlock2(nn.Module):
147 | """Used only for det-medium"""
148 | expansion = 1
149 | def __init__(self, siz=48, end_siz=8, fin_dim=384):
150 | super(ConvBlock2, self).__init__()
151 | self.planes = siz
152 | fin_dim = self.planes*end_siz if fin_dim != 432 else 432
153 | # self.bn = nn.BatchNorm2d(planes) if self.normaliz == "bn" else nn.GroupNorm(num_groups=1, num_channels=planes)
154 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
155 | LayerNorm(self.planes, data_format="channels_first"),
156 | nn.GELU(),
157 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1),
158 | LayerNorm(self.planes*2, data_format="channels_first"),
159 | nn.GELU(),
160 | nn.Conv2d(self.planes*2, self.planes*4, kernel_size=3, stride=2, padding=1),
161 | LayerNorm(self.planes*4, data_format="channels_first"),
162 | nn.GELU(),
163 | nn.Conv2d(self.planes*4, self.planes*8, kernel_size=3, stride=2, padding=1),
164 | LayerNorm(self.planes*8, data_format="channels_first"),
165 | nn.GELU(),
166 | nn.Conv2d(self.planes*8, 512, kernel_size=1, stride=1, padding=0)
167 | )
168 | def forward(self, x):
169 | out = self.stem(x)
170 | # out = self.bn(out)
171 | return out
172 |
173 |
174 | class ConvBlock3(nn.Module):
175 | # expansion = 1
176 | def __init__(self, siz=64):
177 | super(ConvBlock3, self).__init__()
178 | self.planes = siz
179 |
180 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
181 | LayerNorm(self.planes, data_format="channels_first"),
182 | nn.GELU(),
183 | nn.Conv2d(self.planes, int(self.planes*1.5), kernel_size=3, stride=2, padding=1),
184 | LayerNorm(int(self.planes*1.5), data_format="channels_first"),
185 | nn.GELU(),
186 | nn.Conv2d(int(self.planes*1.5), self.planes*2, kernel_size=3, stride=1, padding=1),
187 | LayerNorm(self.planes*2, data_format="channels_first"),
188 | nn.GELU()
189 | )
190 |
191 |
192 | def forward(self, x):
193 | out = self.stem(x)
194 | # out = self.bn(out)
195 | return out
196 |
197 |
198 | class ConvBlock1(nn.Module):
199 | # expansion = 1
200 | def __init__(self, siz=48, end_siz=8, fin_dim=384):
201 | super(ConvBlock1, self).__init__()
202 | self.planes = siz
203 |
204 | fin_dim = self.planes*end_siz if fin_dim == None else 432
205 | self.stem = nn.Sequential(nn.Conv2d(3, self.planes, kernel_size=3, stride=2, padding=1),
206 | LayerNorm(self.planes, data_format="channels_first"),
207 | nn.GELU(),
208 | nn.Conv2d(self.planes, self.planes*2, kernel_size=3, stride=2, padding=1),
209 | LayerNorm(self.planes*2, data_format="channels_first"),
210 | nn.GELU()
211 | )
212 |
213 |
214 | def forward(self, x):
215 | out = self.stem(x)
216 | # out = self.bn(out)
217 | return out
218 |
219 |
220 | class IdentityLayer(nn.Module):
221 | def forward(self, inputs):
222 | return inputs
223 |
224 |
225 | def get_new_model(modelname, pretrained=True, not_original=False, updated=False):
226 |
227 |
228 | if modelname == 'resnet50':
229 | model = timm.models.resnet.resnet50(pretrained=pretrained)
230 |
231 | elif modelname == 'resnet50_gelu':
232 | model = timm.models.resnet.resnet50(pretrained=pretrained,
233 | act_layer=timm_gelu)
234 |
235 | # elif modelname == 'convnext_iso':
236 |
237 | # model = cnxt_iso.convnext_isotropic_small(pretrained=pretrained, dim=384, depth=18)
238 | # if not_original:
239 | # setattr(model, 'stem', ConvBlock(48, end_siz=8, fin_dim=432 if updated else 384))
240 |
241 | elif modelname == 'convnext_tiny':
242 | model = timm.models.convnext.convnext_tiny(pretrained=pretrained)
243 | if not_original:
244 | model.stem = ConvBlock1(48, end_siz=8)
245 |
246 | elif modelname == "convnext_tiny_21k":
247 | model = timm.models.convnext._create_convnext('convnext_tiny.fb_in22k_ft_in1k', pretrained=pretrained)
248 |
249 | elif modelname == "convnext_small":
250 | model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
251 | model = timm.models.convnext.convnext_small(pretrained=pretrained)
252 | if not_original:
253 | ## only for removing patch-stem
254 | model.stem = ConvBlock1(48, end_siz=8)
255 |
256 | elif modelname == "convnext_base":
257 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
258 | # model = timm.models.create_model('convnext_base', pretrained=pretrained, pretrained_cfg='convnext_base.fb_in1k')
259 | model = timm.models.convnext._create_convnext('convnext_base.fb_in1k', pretrained=pretrained, **model_args)
260 | if not_original:
261 | ## only for removing patch-stem
262 | model.stem = ConvBlock3(64)
263 |
264 | elif modelname == "convnext_large":
265 |
266 | model = timm.models.convnext_large(pretrained=pretrained)
267 |
268 | if not_original:
269 | model.stem = ConvBlock3(96)
270 |
271 | elif modelname == 'vit_s':
272 | model = create_model('vit_small_patch16_224', pretrained=pretrained)
273 | if not_original:
274 | ## only for removing patch-stem
275 | model.patch_embed.proj = ConvBlock(48, end_siz=8)
276 |
277 | elif modelname == 'deit_s':
278 | from timm.models.deit import deit3_small_patch16_224, _create_deit
279 | model_kwargs = dict(
280 | patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=False,
281 | init_values=None)
282 | model = create_model('deit_small_patch16_224', pretrained=pretrained)
283 | if not_original:
284 | model.patch_embed.proj = ConvBlock(48, end_siz=8)
285 |
286 | elif modelname == 'vit_m':
287 |
288 | model = timm.models.deit.deit3_medium_patch16_224(pretrained=pretrained)
289 | if not_original:
290 | ## only for removing patch-stem
291 | model.patch_embed.proj = ConvBlock2(48)
292 |
293 | elif modelname == 'vit_s_21k':
294 | model = create_model('deit3_small_patch16_224_in21ft1k', pretrained=pretrained)
295 |
296 |
297 | elif modelname == 'vit_b':
298 | model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained) # used for 21k pretrained on 206
299 |
300 | if not_original:
301 | model.patch_embed.proj = ConvBlock(48, end_siz=16, fin_dim=None)
302 |
303 |
304 | elif modelname == "resnet101":
305 | model = timm.models.resnet.resnet101(pretrained=False)
306 |
307 | elif modelname=="wrn_50_2":
308 | model = timm.models.resnet.wide_resnet50_2(pretrained=False)
309 |
310 | elif modelname == "densnet201":
311 | model = timm.models.densenet.densenet201(pretrained=pretrained)
312 |
313 | elif modelname == "inception":
314 | model = create_model('inception_v3', pretrained=pretrained)
315 |
316 |
317 |
318 | else:
319 |         raise ValueError(f'Invalid model name {modelname}, see get_new_model for the supported '
320 |                          'ResNet / ConvNeXt / ViT / DeiT variants')
321 |
322 | return model
323 |
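# Hedged usage sketch (editor's addition): build the ConvStem ("CvSt") variant of ConvNeXt-B
# and prepend normalization so the attack code can feed [0, 1] inputs directly; the mean/std
# values are the standard ImageNet statistics.
cvst_model = get_new_model('convnext_base', pretrained=False, not_original=True)
cvst_model = normalize_model(cvst_model, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))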
--------------------------------------------------------------------------------
/utils_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torchvision import datasets, transforms
6 | import numpy as np
7 | import math
8 | import os
9 | import json
10 |
11 | from robustbench.data import load_cifar10c  # needed by get_acc_cifar10c (was: from data import load_cifar10c)
12 | from models_new import l_models_all, l_models_imagenet, l_models_cifar100,\
13 | l_models_imagenet100, l_models_mnist
14 | from utils import load_anymodel, load_anymodel_imagenet, clean_accuracy,\
15 | load_anymodel_cifar100, load_anymodel_imagenet100, load_anymodel_mnist
16 | try:
17 | from other_utils import L1_norm, L2_norm, Linf_norm, Logger, L0_norm
18 | except ImportError:
19 | from autoattack.other_utils import L1_norm, L2_norm, Logger, L0_norm
20 | from apgd_mask import criterion_dict
21 |
22 |
23 | cls = ['airplane', 'automobile', 'bird', 'cat', 'deer',
24 | 'dog', 'frog', 'horse', 'ship', 'truck']
25 |
26 |
27 | class CalibratedModel(nn.Module):
28 | def __init__(self, model, temp):
29 | super().__init__()
30 | assert not model.training
31 | self.model = model
32 | assert temp > 0.
33 | self.temp = temp
34 |
35 | def forward(self, x):
36 | return self.model(x) / self.temp
37 |
38 |
39 | if 'Linf_norm' not in globals():  # define a fallback only if other_utils did not provide it
40 |     def Linf_norm(z):
41 |         raise NotImplementedError('Linf_norm to be added.')
42 |
43 | def get_acc_cifar10c(model, n_ex=10000, severities=[5],
44 | corruptions=('shot_noise', 'motion_blur', 'snow', 'pixelate',
45 | 'gaussian_noise', 'defocus_blur', 'brightness', 'fog', 'zoom_blur',
46 | 'frost', 'glass_blur', 'impulse_noise', 'contrast',
47 | 'jpeg_compression', 'elastic_transform'), bs=250):
48 | l_acc = []
49 | acc_dets = {}
50 |
51 | for s in severities:
52 | x, y = load_cifar10c(n_ex, severity=s, corruptions=corruptions)
53 | x = x.contiguous()
54 | print(x.shape)
55 | with torch.no_grad():
56 | acc = 0.
57 | n_batches = math.ceil(x.shape[0] / bs)
58 | for counter in range(n_batches):
59 | output = model(x[counter * bs:(counter + 1) * bs].cuda())
60 | acc += (output.cpu().max(dim=1)[1] == y[counter * bs:(counter + 1) * bs]).sum()
61 | l_acc.append(acc / x.shape[0])
62 | acc_dets[str(s)] = acc / x.shape[0]
63 | print('sev={}, clean accuracy={:.2%}'.format(s, acc / x.shape[0]))
64 |
65 | return acc / x.shape[0], acc_dets
66 |
67 | def check_imgs(adv, x, norm):
68 | delta = (adv - x).view(adv.shape[0], -1)
69 | if norm == 'Linf':
70 | res = delta.abs().max(dim=1)[0]
71 | elif norm == 'L2':
72 | res = (delta ** 2).sum(dim=1).sqrt()
73 | elif norm == 'L1':
74 | res = delta.abs().sum(dim=1)
75 |
76 | str_det = 'max {} pert: {:.5f}, nan in imgs: {}, max in imgs: {:.5f}, min in imgs: {:.5f}'.format(
77 | norm, res.max(), (adv != adv).sum(), adv.max(), adv.min())
78 | print(str_det)
79 | print(adv.max().item() - 1., adv.min().item())
80 |
81 | return str_det
82 |
83 | def get_cifar10_class(lab):
84 | return cls[lab]
85 |
86 | def get_imagenet_class(lab):
87 | if torch.is_tensor(lab):
88 | lab = lab.item()
89 | with open('./imagenet_classes.json') as json_file:
90 | class_dict = json.load(json_file)
91 | return class_dict[str(lab)][1]
92 |
93 | def get_class(args, cl=None):
94 | if cl is None:
95 | cl = args.target_class
96 | if args.dataset == 'cifar10':
97 | return get_cifar10_class(cl)
98 | elif args.dataset == 'imagenet':
99 | return get_imagenet_class(cl)
100 |
101 | def makedir(path):
102 | if not os.path.exists(path):
103 | os.makedirs(path)
104 |
105 | def load_anymodel_datasets(args):
106 | fts_idx = [int(c) for c in args.fts_idx.split(' ')]
107 | if args.dataset == 'cifar10':
108 | l_models = [l_models_all[c] for c in fts_idx]
109 | print(l_models)
110 | model = load_anymodel(l_models[0], args.model_dir)
111 | model.eval()
112 | elif args.dataset == 'imagenet':
113 | l_models = [l_models_imagenet[c] for c in fts_idx]
114 | print(l_models)
115 | kwargs = {}
116 | if (l_models[0][0].startswith('DeiT')
117 | #and 'convblock' not in l_models[0][0]
118 | or l_models[0][0].startswith('ViT')
119 | ):
120 | kwargs = {'img_size': args.img_size}
121 | model = load_anymodel_imagenet(l_models[0], **kwargs)
122 | #sys.exit()
123 | '''with torch.no_grad():
124 | acc = clean_accuracy(model, x, y, batch_size=25)
125 | print('clean accuracy: {:.1%}'.format(acc))'''
126 | elif args.dataset == 'cifar100':
127 | l_models = [l_models_cifar100[c] for c in fts_idx]
128 | print(l_models)
129 | model = load_anymodel_cifar100(l_models[0])
130 | model.eval()
131 | elif args.dataset == 'imagenet100':
132 | l_models = [l_models_imagenet100[c] for c in fts_idx]
133 | print(l_models)
134 | model = load_anymodel_imagenet100(l_models[0])
135 | model.eval()
136 | elif args.dataset == 'mnist':
137 | l_models = [l_models_mnist[c] for c in fts_idx]
138 | print(l_models)
139 | model = load_anymodel_mnist(l_models[0])
140 | model.eval()
141 | return model, l_models
142 |
143 |
144 | def attack_group(norm, suffix=''):
145 | if norm in ['Linf', 'L2', 'L1']:
146 | return 'aa' + suffix
147 | else:
148 | return 'nonlpattacks'
149 |
150 |
151 | def get_norm(z, norm):
152 | if norm == 'Linf':
153 | return Linf_norm(z)
154 | elif norm == 'L2':
155 | return L2_norm(z)
156 | elif norm == 'L1':
157 | return L1_norm(z)
158 | elif norm == 'L0':
159 | return L0_norm(z)
160 |
161 |
162 | def get_logits(model, x_test, bs=1000, device=None, n_cls=10):
163 | if device is None:
164 | device = x_test.device
165 | n_batches = math.ceil(x_test.shape[0] / bs)
166 | logits = torch.zeros([x_test.shape[0], n_cls], device=device)
167 | #l_logits = []
168 |
169 | with torch.no_grad():
170 | for counter in range(n_batches):
171 | x_curr = x_test[counter * bs:(counter + 1) * bs].to(device)
172 | output = model(x_curr)
173 | #l_logits.append(output.detach())
174 | logits[counter * bs:(counter + 1) * bs] += output.detach()
175 |
176 | return logits
177 |
178 |
179 | def get_wc_acc(model, xs, y, bs=1000, device=None, eot_test=1, logger=None,
180 | loss=None, n_cls=10):
181 | if device is None:
182 |         device = xs[0].device
183 | if logger is None:
184 | logger = Logger(None)
185 | if not loss is None:
186 | criterion_indiv = criterion_dict[loss]
187 | y = y.to(device)
188 | acc = torch.ones_like(y, device=device).float()
189 | x_adv = xs[0].clone()
190 | loss_best = -1. * float('inf') * torch.ones(y.shape[0], device=device)
191 |
192 | for x in xs:
193 | logits = get_logits(model, x, bs=bs, device=device, n_cls=n_cls)
194 | loss_curr = criterion_indiv(logits, y)
195 | pred_curr = logits.max(1)[1] == y
196 | ind = ~pred_curr * (loss_curr > loss_best) # misclassified points with higher loss
197 | x_adv[ind] = x[ind].clone()
198 | acc *= pred_curr
199 | ind = (acc == 1.) * (loss_curr > loss_best) # for robust points track highest loss
200 | x_adv[ind] = x[ind].clone()
201 | logger.log(f'[rob acc] cum={acc.mean():.1%} curr={pred_curr.float().mean():.1%}')
202 |
203 | print(torch.nonzero(acc).squeeze())
204 |
205 | return acc.mean(), x_adv
206 |
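# Hedged usage sketch (editor's addition): worst-case robust accuracy over adversarial
# examples produced by several attacks on the same clean batch. The loss key 'ce' is an
# assumption about criterion_dict; model, x_adv_* and y are illustrative names.
rob_acc, x_worst = get_wc_acc(model, [x_adv_linf, x_adv_l2, x_adv_l1], y,
                              bs=256, device=torch.device('cuda'), loss='ce', n_cls=1000)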
207 |
208 | def get_patchsize(dataset, modelname):
209 | if dataset == 'imagenet':
210 | if modelname in ['ConvMixer_1024_20_nat', 'ConvMixer_1024_20_eps4_best']:
211 | return 14
212 | else:
213 | return 16
214 | else:
215 | return 8
216 |
--------------------------------------------------------------------------------
/utils_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torchvision import datasets, transforms
6 | import numpy as np
7 | import copy
8 | import random
9 | import math
10 | import time
11 | from collections import OrderedDict
12 | from typing import Tuple
13 | from torch import Tensor
14 | import os
15 |
16 | try:
17 | from autopgd_pt import L1_projection
18 | except ImportError:
19 | from autoattack.autopgd_base import L1_projection
20 |
21 | class FGSMAttack():
22 | def __init__(self, eps=8. / 255., step_size=None, loss=None):
23 | self.eps = eps
24 | self.step_size = step_size if not step_size is None else eps
25 | self.loss = loss
26 |
27 | def perturb(self, model, x, y, random_start=False):
28 | assert not self.loss is None
29 | if random_start:
30 | t = (x + (2. * torch.rand_like(x) - 1.) * self.eps).clamp(0., 1.)
31 | else:
32 | t = x.clone()
33 |
34 | t.requires_grad = True
35 | output = model(t)
36 | loss = self.loss(output, y)
37 | grad = torch.autograd.grad(loss, t)[0]
38 |
39 | x_adv = x + grad.detach().sign() * self.step_size
40 | return torch.min(torch.max(x_adv, x - self.eps), x + self.eps).clamp(0., 1.)
41 |
42 | class PGDAttack():
43 | def __init__(self, eps=8. / 255., step_size=None, loss=None, n_iter=10,
44 | norm='Linf', verbose=False):
45 | self.eps = eps
46 | self.step_size = step_size if not step_size is None else eps / n_iter * 1.5
47 | self.loss = loss
48 | self.n_iter = n_iter
49 | self.norm = norm
50 | self.verbose = verbose
51 |
52 | def perturb(self, model, x, y, random_start=False, return_acc=False):
53 | assert not self.loss is None
54 | if random_start:
55 | x_adv = (x + (2. * torch.rand_like(x) - 1.) * self.eps).clamp(0., 1.)
56 | else:
57 | x_adv = x.clone()
58 |
59 | n_fts = x.shape[1] * x.shape[2] * x.shape[3]
60 | x_best, loss_best = x_adv.clone(), torch.zeros_like(y).float()
61 | acc = torch.ones_like(y).detach().float()
62 |
63 | x_adv.requires_grad = True
64 | output = model(x_adv)
65 | loss = self.loss(output, y, reduction='none')
66 | ind = loss > loss_best
67 | x_best[ind] = x_adv[ind].clone().detach()
68 | loss_best[ind] = loss[ind].clone().detach()
69 | acc[ind] = (output.max(dim=1)[1] == y).float()[ind].clone().detach()
70 | grad = torch.autograd.grad(loss.mean(), x_adv)[0]
71 |
72 | for it in range(self.n_iter):
73 | if self.norm == 'Linf':
74 | x_adv = x_adv.detach() + grad.detach().sign() * self.step_size
75 | x_adv = torch.min(torch.max(x_adv, x - self.eps), x + self.eps).clamp(0., 1.)
76 |
77 | elif self.norm == 'L2':
78 | x_adv = x_adv.detach() + grad.detach() / (grad.detach() ** 2
79 | ).sum(dim=(1, 2, 3), keepdim=True).sqrt() * self.step_size
80 | delta_l2norm = ((x_adv - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
81 | x_adv = (x + (x_adv - x) * torch.min(torch.ones_like(delta_l2norm),
82 | self.eps * torch.ones_like(delta_l2norm) / delta_l2norm)).clamp(0., 1.)
83 |
84 | elif self.norm == 'L1':
85 | grad = grad.detach()
86 | grad_topk = grad.abs().view(grad.shape[0], -1).topk(
87 | k=max(int(.1 * n_fts), 1), dim=-1)[0][:, -1].view(grad.shape[0],
88 | *[1]*(len(grad.shape) - 1))
89 | sparsegrad = grad * (grad.abs() >= grad_topk).float()
90 | x_adv = x_adv.detach() + self.step_size * sparsegrad / (sparsegrad.abs().view(
91 | x.shape[0], -1).sum(dim=-1).view(-1, 1, 1, 1) + 1e-10)
92 | delta_temp = L1_projection(x, x_adv - x, self.eps)
93 | x_adv += delta_temp
94 |
95 | x_adv.requires_grad = True
96 | output = model(x_adv)
97 | loss = self.loss(output, y, reduction='none')
98 | ind = loss > loss_best
99 | x_best[ind] = x_adv[ind].clone().detach()
100 | loss_best[ind] = loss[ind].clone().detach()
101 | acc[ind] = (output.max(dim=1)[1] == y).float()[ind].clone().detach()
102 | grad = torch.autograd.grad(loss.mean(), x_adv)[0]
103 |
104 | if self.verbose:
105 | print('[{}] it={} loss={:.5f} acc={:.1%}'.format(self.norm, it, loss_best.mean().item(),
106 | (output.max(dim=1)[1] == y).cpu().float().mean()))
107 |
108 | if not return_acc:
109 | return x_best.detach()
110 | else:
111 | return x_best.detach(), acc
112 |
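# Hedged usage sketch (editor's addition): a 10-step Linf PGD attack at the common ImageNet
# radius of 4/255; model, x and y are illustrative names.
pgd = PGDAttack(eps=4. / 255., loss=F.cross_entropy, n_iter=10, norm='Linf')
x_adv = pgd.perturb(model, x, y, random_start=True)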
113 | class MSDAttack():
114 | def __init__(self, eps, step_size=None, loss=None, n_iter=10):
115 | self.eps = eps
116 | self.step_size = step_size if not step_size is None else [eps / n_iter * 1.25 for eps in self.eps]
117 | self.loss = loss
118 | self.n_iter = n_iter
119 |
120 | def perturb(self, model, x, y, random_start=False):
121 | assert not self.loss is None
122 | if random_start:
123 |             x_adv = (x + (2. * torch.rand_like(x) - 1.) * self.eps[0]).clamp(0., 1.)  # self.eps is a per-norm list; use the Linf radius for the random init
124 | else:
125 | x_adv = x.clone()
126 |
127 | n_fts = x.shape[1] * x.shape[2] * x.shape[3]
128 | x_best, loss_best = x_adv.clone(), torch.zeros_like(y).float()
129 | #x_adv.requires_grad = True
130 |
131 | for it in range(self.n_iter):
132 | x_adv.requires_grad = True
133 | #with torch.enable_grad()
134 | output = model(x_adv)
135 | loss = self.loss(output, y, reduction='none')
136 | avgloss = loss.mean()
137 | grad = torch.autograd.grad(avgloss, x_adv)[0]
138 | ind = loss > loss_best
139 | x_best[ind] = x_adv[ind].clone().detach()
140 | loss_best[ind] = loss[ind].clone().detach()
141 | #grad = torch.autograd.grad(loss.mean(), x_adv)[0].detach()
142 |
143 | x_adv_linf = x_adv.detach() + grad.detach().sign() * self.step_size[0]
144 | x_adv_linf = torch.min(torch.max(x_adv_linf, x - self.eps[0]), x + self.eps[0]).clamp(0., 1.)
145 |
146 | x_adv_l2 = x_adv.detach() + grad.detach() / (grad.detach() ** 2
147 | ).sum(dim=(1, 2, 3), keepdim=True).sqrt() * self.step_size[1]
148 | delta_l2norm = ((x_adv_l2 - x) ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
149 | x_adv_l2 = (x + (x_adv_l2 - x) * torch.min(torch.ones_like(delta_l2norm),
150 | self.eps[1] * torch.ones_like(delta_l2norm) / delta_l2norm)).clamp(0., 1.)
151 |
152 | grad = grad.detach()
153 | grad_topk = grad.abs().view(grad.shape[0], -1).topk(
154 | k=max(int(.1 * n_fts), 1), dim=-1)[0][:, -1].view(grad.shape[0],
155 | *[1]*(len(grad.shape) - 1))
156 | sparsegrad = grad * (grad.abs() >= grad_topk).float()
157 | x_adv_l1 = x_adv.detach() + self.step_size[2] * sparsegrad / (sparsegrad.abs().view(
158 | x.shape[0], -1).sum(dim=-1).view(-1, 1, 1, 1) + 1e-10)
159 | delta_temp = L1_projection(x, x_adv_l1 - x, self.eps[2])
160 | x_adv_l1 += delta_temp
161 |
162 | '''x_adv_linf.requires_grad_()
163 | x_adv_l2.requires_grad_()
164 | x_adv_l1.reuiqres_grad_()'''
165 | l_x_adv = [x_adv_linf, x_adv_l2, x_adv_l1]
166 | l_output = [model(x_adv_linf), model(x_adv_l2), model(x_adv_l1)]
167 | l_loss = [self.loss(c, y, reduction='none') for c in l_output]
168 | #l_avgloss = [c.mean() for c in l_loss]
169 | val_max, ind_max = torch.max(torch.stack(l_loss, dim=1), dim=1)
170 | #x_adv, loss = l_x_adv[ind_max], l_loss[ind_max]
171 | x_adv = x_adv_linf.clone()
172 | x_adv[ind_max == 1] = x_adv_l2[ind_max == 1].clone()
173 | x_adv[ind_max == 2] = x_adv_l1[ind_max == 2].clone()
174 | #print('it={} - best norm=({}, {}, {}) - loss={}'.format(it + 1, (ind_max == 0).sum(),
175 | # (ind_max == 1).sum(), (ind_max == 2).sum(), loss_best.mean().item()))
176 |
177 | return x_best.detach()
178 |
179 | class MultiPGDAttack():
180 | def __init__(self, eps, step_size=None, loss=None, n_iter=[10, 10, 10], use_miscl=False,
181 | l_norms=None):
182 | self.eps = eps
183 | self.step_size = step_size if not step_size is None else [eps / n_iter * 1.5 for eps in self.eps]
184 | self.loss = loss
185 | self.n_iter = n_iter
186 | self.indiv_adversary = PGDAttack(eps[0], loss=loss)
187 | self.use_miscl = use_miscl
188 | self.l_norms = l_norms if not l_norms is None else ['Linf', 'L2', 'L1']
189 |
190 | def perturb(self, model, x, y, random_start=False, return_acc=False):
191 | #assert not self.loss is None
192 | l_x_adv = []
193 | l_acc = []
194 | for i, norm in enumerate(self.l_norms):
195 | self.indiv_adversary.eps = self.eps[i] + 0.
196 | self.indiv_adversary.step_size = self.step_size[i] + 0.
197 | self.indiv_adversary.norm = norm + ''
198 | self.indiv_adversary.n_iter = self.n_iter[i]
199 | if not return_acc:
200 | x_curr = self.indiv_adversary.perturb(model, x, y)
201 | else:
202 | x_curr, acc_curr = self.indiv_adversary.perturb(model, x, y, return_acc=True)
203 | l_acc.append(acc_curr)
204 | l_x_adv.append(x_curr.clone())
205 |
206 | if not self.use_miscl:
207 | if not return_acc:
208 | return torch.cat(l_x_adv, dim=0)
209 | else:
210 | return torch.cat(l_x_adv, dim=0), torch.cat(l_acc, dim=0)
211 | else:
212 | #logits = torch.zeros([len(l_x_adv), x.shape[0], 10]).cuda()
213 | loss = torch.zeros([len(l_x_adv), x.shape[0]]).cuda()
214 | for i, x_adv in enumerate(l_x_adv):
215 | output = model(x_adv)
216 | loss[i] = self.loss(output, y) - 1e5 * (output.max(dim=1)[1] == y).float()
217 | ind_max = loss.max(dim=0)[1]
218 | x_adv = l_x_adv[0].clone()
219 | x_adv[ind_max == 1] = l_x_adv[1][ind_max == 1].clone()
220 | x_adv[ind_max == 2] = l_x_adv[2][ind_max == 2].clone()
221 |
222 | return x_adv
223 |
224 | # data loaders
225 | def load_data(args):
226 | crop_input_size = 0 if args.crop_input is None else args.crop_input
227 | crop_data_size = 0 if args.crop_data is None else args.crop_data
228 | if args.dataset == 'cifar10':
229 | train_transform = transforms.Compose([
230 | transforms.RandomCrop(32 - crop_data_size, padding=4 - crop_input_size),
231 | transforms.RandomHorizontalFlip(),
232 | #transforms.RandomRotation(15),
233 | transforms.ToTensor(),
234 | ])
235 | elif args.dataset == 'svhn':
236 | train_transform = transforms.Compose([
237 | #transforms.RandomCrop(32 - crop_data_size, padding=4 - crop_input_size),
238 | transforms.ToTensor(),
239 | ])
240 | elif args.dataset == 'mnist':
241 | train_transform = transforms.Compose([
242 | transforms.ToTensor(),
243 | ])
244 | test_transform = transforms.Compose([
245 | transforms.ToTensor(),
246 | ])
247 |
248 | root = args.data_dir + '' #'/home/EnResNet/WideResNet34-10/data/'
249 | num_workers = 2
250 |
251 | if args.dataset == 'cifar10':
252 | train_dataset = datasets.CIFAR10(
253 | root, train=True, transform=train_transform, download=True)
254 | test_dataset = datasets.CIFAR10(
255 | root, train=False, transform=test_transform, download=True)
256 | elif args.dataset == 'svhn':
257 | train_dataset = datasets.SVHN(root=root, split='train',
258 | transform=train_transform, download=True)
259 | test_dataset = datasets.SVHN(root=root, split='test', #train=False
260 | transform=test_transform, download=True)
261 | elif args.dataset == 'mnist':
262 | train_dataset = datasets.MNIST(
263 | root, train=True, transform=train_transform, download=True)
264 | test_dataset = datasets.MNIST(
265 | root, train=False, transform=test_transform, download=True)
266 |
267 | train_loader = torch.utils.data.DataLoader(
268 | dataset=train_dataset,
269 | batch_size=args.batch_size,
270 | shuffle=True,
271 | pin_memory=True,
272 | num_workers=num_workers,
273 | )
274 | test_loader = torch.utils.data.DataLoader(
275 | dataset=test_dataset,
276 | batch_size=args.batch_size_eval,
277 | shuffle=False,
278 | pin_memory=True,
279 | num_workers=0,
280 | )
281 |
282 | return train_loader, test_loader
283 |
284 | def load_imagenet_train(args):
285 | from robustness.datasets import DATASETS
286 | from robustness.tools import helpers
287 | data_paths = ['/home/scratch/datasets/imagenet',
288 | '/scratch_local/datasets/ImageNet2012',
289 | '/mnt/qb/datasets/ImageNet2012',
290 | '/scratch/datasets/imagenet/']
291 | for data_path in data_paths:
292 | if os.path.exists(data_path):
293 | break
294 | print(f'found dataset at {data_path}')
295 | dataset = DATASETS['imagenet'](data_path) #'/home/scratch/datasets/imagenet'
296 |
297 |
298 | train_loader, val_loader = dataset.make_loaders(2,
299 | args.batch_size, data_aug=True)
300 |
301 | train_loader = helpers.DataPrefetcher(train_loader)
302 | val_loader = helpers.DataPrefetcher(val_loader)
303 | return train_loader, val_loader
304 |
305 | # other utils
306 | def get_accuracy(model, data_loader=None):
307 | assert not model.training
308 | if not data_loader is None:
309 | acc = 0.
310 | c = 0
311 | with torch.no_grad():
312 | for (x, y) in data_loader:
313 | output = model(x.cuda())
314 | acc += (output.cpu().max(dim=1)[1] == y).float().sum()
315 | c += x.shape[0]
316 | return acc.item() / c
317 |
318 | def get_lr_schedule(args):
319 | if args.lr_schedule == 'superconverge':
320 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
321 | # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
322 | elif args.lr_schedule == 'piecewise':
323 | def lr_schedule(t):
324 | if t / args.epochs < 0.5:
325 | return args.lr_max
326 | elif t / args.epochs < 0.75:
327 | return args.lr_max / 10.
328 | else:
329 | return args.lr_max / 100.
330 | elif args.lr_schedule.startswith('piecewise'):
331 | w = [float(c) for c in args.lr_schedule.split('-')[1:]]
332 | def lr_schedule(t):
333 | c = 0
334 | while t / args.epochs > sum(w[:c + 1]) / sum(w):
335 | c += 1
336 | return args.lr_max / 10. ** c
337 | return lr_schedule
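# Worked example for the generic 'piecewise-...' variant (hypothetical args object, not a
# config used in this repo): the suffix gives the relative lengths of the phases, and the
# learning rate is divided by 10 at each phase boundary.
#
#   class _Args: lr_schedule, lr_max, epochs = 'piecewise-10-5-5', 0.1, 20
#   sched = get_lr_schedule(_Args)
#   print(sched(4), sched(12), sched(18))   # 0.1 0.01 0.001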
338 |
339 | def norm_schedule(it, epoch, epochs, l_norms, ps=None, schedule='piecewise'):
340 | #assert l_norms == ['Linf', 'L2', 'L1']
341 | if schedule == 'piecewise':
342 | if epoch < epochs * .5:
343 | return l_norms.index('L2')
344 | else:
345 | if not ps is None:
346 | ind_linf = l_norms.index('Linf')
347 | ind_l1 = l_norms.index('L1')
348 | return random.choices([ind_linf, ind_l1], weights=[
349 | ps[ind_linf], ps[ind_l1]])[0]
350 | if it % 2 == 0:
351 | return l_norms.index('Linf')
352 | else:
353 | return l_norms.index('L1')
354 |
355 |
356 | # swa tools
357 | def polyak_averaging(p_local, p_new, it):
358 | return (it * p_local + p_new) / (it + 1.)
359 |
360 | def exp_decay(p_local, p_new, theta=.995):
361 | return theta * p_local + (1. - theta) * p_new
362 |
363 | class AveragedModel(nn.Module):
364 | def __init__(self, model):
365 | super(AveragedModel, self).__init__()
366 | self._model = copy.deepcopy(model)
367 | #
368 |
369 | def forward(self, x):
370 | return self._model(x)
371 |
372 | @torch.no_grad()
373 | def update_parameters(self, model, avg_fun=exp_decay):
374 | for p_local, p_new in zip(self._model.parameters(), model.parameters()):
375 | p_local.set_(avg_fun(p_local, p_new))
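# Weight-averaging sketch: polyak_averaging keeps a running mean over update steps (it needs
# the step counter `it`), while exp_decay keeps an exponential moving average with decay theta;
# both are meant to be plugged into AveragedModel.update_parameters as `avg_fun`. Assuming
# `model` is the network being trained and `it` the current step:
#
#   swa_model = AveragedModel(model)
#   swa_model.update_parameters(model)                                                       # EMA with theta = .995
#   swa_model.update_parameters(model, avg_fun=lambda pl, pn: polyak_averaging(pl, pn, it))  # running mean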
376 |
377 |
378 | # initializations
379 | def initialize_weights(module):
380 | if isinstance(module, nn.Conv2d):
381 | n = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
382 | module.weight.data.normal_(0, math.sqrt(2. / n))
383 | if module.bias is not None:
384 | module.bias.data.zero_()
385 | elif isinstance(module, nn.Linear):
386 | n = module.in_features
387 | module.weight.data.normal_(0, math.sqrt(2. / n))
388 | if module.bias is not None:
389 | module.bias.data.zero_()
390 | elif isinstance(module, nn.BatchNorm2d):
391 | module.weight.data.fill_(1)
392 | module.bias.data.zero_()
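# This reproduces He (Kaiming) normal initialisation by hand: conv weights use fan_out,
# linear weights use fan_in, biases are zeroed and BatchNorm is reset to identity. Typical
# use is via Module.apply, e.g.:
#
#   model.apply(initialize_weights)   # recursively re-initialises every submodule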
393 |
394 | # stepsizes for pgd-at
395 | def get_stepsize(norm, eps, method='default'):
396 | #assert method == 'default'
397 | if method == 'default':
398 | if norm == 'Linf':
399 | return eps / 4.
400 | elif norm == 'L2':
401 | return eps / 3.
402 | elif norm == 'L1':
403 | return 2. * eps * 255. / 2000.
404 | else:
405 | raise ValueError('please specify a norm')
406 | elif method == 'msd':
407 | if norm == 'Linf':
408 | return eps / 4.
409 | elif norm == 'L2':
410 | return eps / 3.
411 | elif norm == 'L1':
412 | return 1. #2. * eps * 255. / 2000.
413 | else:
414 | raise ValueError('please specify a norm')
415 | elif method == 'msd-5':
416 | if norm == 'Linf':
417 | return eps / 2.
418 | elif norm == 'L2':
419 | return eps / 1.5
420 | elif norm == 'L1':
421 | return eps / 2.
422 | else:
423 | raise ValueError('please specify a norm')
424 | elif method == 'half':
425 | return eps / 2.
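# Quick check of the 'default' rule (example epsilons only, not values from any config here):
# Linf eps = 8/255 -> step eps / 4 = 2/255; L2 eps = 0.5 -> step ~ 0.167; L1 eps = 12 ->
# step 2 * 12 * 255 / 2000 = 3.06.
#
#   assert abs(get_stepsize('L1', 12.) - 3.06) < 1e-9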
426 |
427 | # utils max strategy
428 | def form_batch_max(l_adv, l_acc, l_loss, l_norm):
429 | bs = l_adv[0].shape[0]
430 | adv = l_adv[0].clone()
431 |     best_norm = torch.zeros([bs]).long()
432 | best_loss = l_loss[0].clone()
433 | best_acc = l_acc[0].clone()
434 | for counter in range(1, len(l_norm)):
435 | ind = l_loss[counter] > best_loss
436 | adv[ind] = l_adv[counter][ind].clone()
437 | best_norm[ind] = counter + 0
438 | best_loss[ind] = l_loss[counter][ind].clone()
439 | best_acc[ind] = l_acc[counter][ind].clone()
440 | #best_norm = [l_norm[best_norm[val].item()] + '' for val in range(bs)]
441 |
442 | return adv, best_norm, best_acc, best_loss
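# form_batch_max implements the per-example 'max' strategy for multi-norm training: given the
# adversarial batches l_adv crafted under each norm in l_norm, together with their per-example
# accuracies l_acc and losses l_loss, it keeps for every example the candidate with the highest
# loss and records the index of the norm that produced it.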
443 |
444 | def random_crop(x, size, padding):
445 | z = torch.zeros([x.shape[0], x.shape[1], size + 2 * padding,
446 | size + 2 * padding], device=x.device)
447 | z[:, :, padding:padding + size, padding:padding + size] += x
448 |
449 | a = random.randint(0, 2 * padding)
450 | b = random.randint(0, 2 * padding)
451 |
452 | return z[:, :, a:a + size, b:b + size]
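# Batched pad-and-crop: the input is zero-padded by `padding` on every side and a single random
# offset is drawn for the whole batch, so the output keeps the spatial size `size`. For instance:
#
#   x = torch.rand(8, 3, 32, 32)
#   x_crop = random_crop(x, size=32, padding=4)   # still (8, 3, 32, 32)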
453 |
454 |
455 | class BatchTracker():
456 | def __init__(self, imgs, labs, bs, norms, alpha):
457 | self.imgs_orig = imgs.clone()
458 | self.labs_orig = labs.clone()
459 | self.bs = bs
460 | self.n_ex = imgs.shape[0]
461 | self.norms = norms
462 | #self.loss_norms = torch.zeros([self.n_ex, 2]) #{k: torch.zeros([self.n_ex]) for k in norms}
463 | #self.count_norms = self.loss_norms.clone()
464 | self.loss_norms_ra = torch.zeros([self.n_ex, 2])
465 | self.alpha = alpha
466 |
467 | def batch_new_epoch(self):
468 | self.ind_sort = torch.randperm(self.n_ex)
469 | self.batch_init = 0
470 | #self.loss_norms[k]
471 | u = torch.ones_like(self.loss_norms_ra[:, 0])
472 | '''ps = (self.loss_norms[:, 0] / torch.max(self.count_norms[:, 0], u)) / torch.max(
473 | self.loss_norms[:, 0] / torch.max(self.count_norms[:, 0], u
474 | ) + self.loss_norms[:, 1] / torch.max(self.count_norms[:, 1], u), u)
475 | ps = torch.max(ps, .1 * u)'''
476 |
477 | tot_curr = self.loss_norms_ra[:, 0] + self.loss_norms_ra[:, 1]
478 | ind_tot_curr = tot_curr == 0.
479 | tot_curr[tot_curr == 0.] = 1.
480 | ps = self.loss_norms_ra[:, 0] / tot_curr
481 | ps_old = ps.clone()
482 | #ps = ps * (self.loss_norms_ra.min(dim=1)[0] > 0.) + .5 * (self.loss_norms_ra.min(dim=1)[0] <= 0.)
483 |         ps[(ps == 0.) | (ps == 1.)] = 1. - ps[(ps == 0.) | (ps == 1.)]
484 | if True: #False
485 | #
486 | ps = (self.loss_norms_ra[:, 0] > self.loss_norms_ra[:, 1]).float() #ps.min(2)[1]
487 | ps[ps_old == 0.] = 1.
488 | ps[ps_old == 1.] = 0.
489 |
490 | #print(ps)
491 | ps[ind_tot_curr] = .5
492 |
493 | #print(ps)
494 | #print(ind_tot_curr.sum(), ((ps_old == 0.) + (ps_old == 1.)).sum())
495 | #norm_at = (ps < random.random()).long()
496 | #print(self.labs_orig)
497 | #print(norm_at)
498 |
499 | train_loader = []
500 | for c in range(0, self.n_ex, self.bs):
501 | ind_curr = self.ind_sort[c:c + self.bs].clone()
502 | x_curr = self.custom_augm(self.imgs_orig[ind_curr].clone())
503 | y_curr = self.labs_orig[ind_curr].clone()
504 | norm_curr = (ps[ind_curr] < random.random()).long().clone() #norm_at[ind_curr].clone()
505 | #print(y_curr, norm_curr)
506 | train_loader.append((x_curr.clone(), y_curr, norm_curr))
507 | return train_loader
508 |
509 | def custom_augm(self, x):
510 | z = random_crop(x, x.shape[-1], 4)
511 | if random.random() > .5:
512 | return transforms.functional.hflip(z)
513 | else:
514 | return z.clone()
515 |
516 | def update_loss(self, loss, norm, i):
517 | ind_curr = self.ind_sort[i * self.bs:(i + 1) * self.bs].clone()
518 | #print(ind_curr)
519 | #self.loss_norms[ind_curr, norm] += loss
520 | #self.count_norms[ind_curr, norm] += 1
521 |         self.loss_norms_ra[ind_curr, norm] = (self.loss_norms_ra[ind_curr, norm] * self.alpha
522 |                                               + loss.cpu() * (1. - self.alpha))
523 |
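# BatchTracker keeps, per training example, an exponential running average of the loss under each
# of the two norms (smoothing factor alpha). batch_new_epoch then reshuffles the data and assigns
# every example the norm with the higher running loss, i.e. the currently harder threat model;
# examples seen so far under only one norm get the other norm, and examples with no history yet
# get a coin flip. The __main__ block at the bottom of this file runs this on a toy dataset.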
524 |
525 | # different resolution
526 | class ImageCropper(nn.Module):
527 |     def __init__(self, w: int,
528 |                  shape: Tuple[int, int, int, int]) -> None:
529 | super(ImageCropper, self).__init__()
530 |
531 | mask = get_mask(shape, w)
532 | self.register_buffer('mask', mask)
533 |
534 | def forward(self, input: Tensor) -> Tensor:
535 | return input * self.mask
536 |
537 |
538 | def get_mask(shape, w):
539 | mask = torch.zeros(shape)
540 | assert len(mask.shape) == 4
541 | mask[:, :, w:mask.shape[-2] - w, w:mask.shape[-1] - w] = 1
542 |
543 | return mask
544 |
545 |
546 | def add_preprocessing(model: nn.Module, w: int, shape: Tuple[int, int, int]
547 | ) -> nn.Module:
548 | layers = OrderedDict([
549 |         ('crop', ImageCropper(w, [1] + list(shape))),
550 | ('model', model)
551 | ])
552 | return nn.Sequential(layers)
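# Despite its name, ImageCropper masks rather than crops: it zeroes a border of width w while
# keeping the spatial size, and add_preprocessing stacks that mask in front of an existing model.
# A minimal sketch, assuming `model` maps (N, 3, 224, 224) inputs to logits:
#
#   masked_model = add_preprocessing(model, w=16, shape=[3, 224, 224])
#   # masked_model(x) only sees the central (224 - 2 * 16) x (224 - 2 * 16) region of x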
553 |
554 |
555 | def Lp_norm(x, p, keepdim=False):
556 | assert p > 1
557 | z = x.view(x.shape[0], -1)
558 | t = (z.abs() ** p).sum(dim=-1) ** (1. / p)
559 | if keepdim:
560 | t = t.view(-1, *[1] * (len(x.shape) - 1))
561 | return t
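# Lp_norm flattens each example and returns its vector p-norm (p > 1); keepdim=True reshapes the
# result so it broadcasts against the input, which is handy for normalising perturbations:
#
#   delta = torch.randn(4, 3, 32, 32)
#   delta_unit = delta / Lp_norm(delta, 2, keepdim=True)   # every example now has unit L2 norm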
562 |
563 |
564 | if __name__ == '__main__':
565 | '''args = lambda: 0
566 | args.batch_size = 256
567 | train_loader, test_loader = load_imagenet_train(args)
568 | print(len(train_loader), len(test_loader))'''
569 |
570 | custom_loader = BatchTracker(torch.ones([10, 3, 5, 5]) * torch.arange(
571 | 10).view(-1, 1, 1, 1), torch.arange(10), 3, ['Linf', 'L1'], .9)
572 | for epoch in range(4):
573 | startt = time.time()
574 | train_loader = custom_loader.batch_new_epoch()
575 | print('loader created in {:.3f} s'.format(time.time() - startt))
576 | loss = torch.zeros([10])
577 | #norm_all = loss.clone()
578 | #print(custom_loader.loss_norms_ra)
579 | for i in range(4): #(x, y, norm) in enumerate(train_loader)
580 | x, y, norm = train_loader[i]
581 | #print(y, norm)
582 | loss = torch.randn([y.shape[0]]).abs()
583 | custom_loader.update_loss(loss, norm, i)
584 | print(x.view(x.shape[0], -1).max(dim=1)[0])
585 | #print(custom_loader.loss_norms)
586 | #print(custom_loader.count_norms)
587 | #print(custom_loader.loss_norms_ra)
588 |
589 |
--------------------------------------------------------------------------------