├── LICENSE
├── README.md
├── argparser.py
├── attacks
├── __init__.py
├── patch_attacker.py
└── pgd_attacker.py
├── bound_layers.py
├── config.py
├── config
├── cifar_adveval_patch_22.json
├── cifar_adveval_patch_55.json
├── cifar_robtrain_k10_sparse.json
├── cifar_robtrain_k1_sparse.json
├── cifar_robtrain_k4_sparse.json
├── cifar_robtrain_p22_all.json
├── cifar_robtrain_p22_allpool2.json
├── cifar_robtrain_p22_allpool22.json
├── cifar_robtrain_p22_guide10.json
├── cifar_robtrain_p22_guide20.json
├── cifar_robtrain_p22_rand1.json
├── cifar_robtrain_p22_rand10.json
├── cifar_robtrain_p22_rand20.json
├── cifar_robtrain_p22_rand5.json
├── cifar_robtrain_p55_all.json
├── cifar_robtrain_p55_allpool2.json
├── cifar_robtrain_p55_allpool22.json
├── cifar_robtrain_p55_guide10.json
├── cifar_robtrain_p55_guide20.json
├── cifar_robtrain_p55_rand1.json
├── cifar_robtrain_p55_rand10.json
├── cifar_robtrain_p55_rand20.json
├── cifar_robtrain_p55_rand5.json
├── mnist_adveval_patch_22.json
├── mnist_adveval_patch_55.json
├── mnist_robtrain_k10_sparse.json
├── mnist_robtrain_k1_sparse.json
├── mnist_robtrain_k4_sparse.json
├── mnist_robtrain_p22_all.json
├── mnist_robtrain_p22_guide10.json
├── mnist_robtrain_p22_rand1.json
├── mnist_robtrain_p22_rand10.json
├── mnist_robtrain_p22_rand5.json
├── mnist_robtrain_p55_all.json
├── mnist_robtrain_p55_guide10.json
├── mnist_robtrain_p55_rand1.json
├── mnist_robtrain_p55_rand10.json
└── mnist_robtrain_p55_rand5.json
├── converter.py
├── datasets.py
├── defaults.json
├── eval.py
├── model_defs.py
├── train.py
└── unet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019 Huan Zhang, Hongge Chen and Chaowei Xiao
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Certified Defenses for Adversarial Patches - ICLR 2020
2 | =====================
3 | This repository implements the _first_ certified defense method against adversarial patch attack.
4 | Our methodology extends Interval Bound Propagation ([IBP](https://arxiv.org/abs/1810.12715))
5 | to defending against patch attack. The resulting model achieves certified accuracy
6 | that exceeds empirical robust accuracy of previous empirical defense methods, such as
7 | [Local Gradient Smoothing](https://arxiv.org/abs/1807.01216) or [Digital Watermarking](https://ieeexplore.ieee.org/document/8575371). More details of our methodology can be found
8 | in the paper below:
9 |
10 | [**Certified Defenses for Adversarial Patches**](https://openreview.net/forum?id=HyeaSkrYPH&noteId=HyeaSkrYPH)
11 | _Ping-yeh Chiang*, Renkun Ni*, Ahmed Abdelkader, Chen Zhu, Christoph Studer, Tom Goldstein_
12 | ICLR 2020
13 |
14 | Reproduce Best Performing Models
15 | ---------------------
16 | You can reproduce our best performing models against patch attack by running the following scripts. You could also download pretrained models [here](https://drive.google.com/file/d/1cw3N3M3mZ4AXS8d3psKzgMiq7U5LWowL/view?usp=sharing)
17 | ```bash
18 | python train.py --config config/cifar_robtrain_p22_guide20.json --model_subset 3
19 | python train.py --config config/cifar_robtrain_p55_rand20.json --model_subset 3
20 | python train.py --config config/mnist_robtrain_p22_all.json --model_subset 0
21 | python train.py --config config/mnist_robtrain_p55_all.json --model_subset 0
22 | ```
23 |
24 | The IBP method also yields good performance against sparse attack. The models can be reproduced by running the following scripts
25 | ```bash
26 | python train.py --config config/cifar_robtrain_k4_sparse.json
27 | python train.py --config config/cifar_robtrain_k10_sparse.json
28 | python train.py --config config/mnist_robtrain_k4_sparse.json
29 | python train.py --config config/mnist_robtrain_k10_sparse.json
30 | ```
31 |
32 | To evaluate the trained models, use `eval.py` with the same arguments
33 | ```bash
34 | python eval.py --config config/cifar_robtrain_p22_guide20.json --model_subset 3
35 | python eval.py --config config/cifar_robtrain_p55_rand20.json --model_subset 3
36 | python eval.py --config config/mnist_robtrain_p22_all.json --model_subset 0
37 | python eval.py --config config/mnist_robtrain_p55_all.json --model_subset 0
38 | python eval.py --config config/cifar_robtrain_k4_sparse.json
39 | python eval.py --config config/cifar_robtrain_k10_sparse.json
40 | python eval.py --config config/mnist_robtrain_k4_sparse.json
41 | python eval.py --config config/mnist_robtrain_k10_sparse.json
42 | ```
43 | If you run into a CUDA out-of-memory error, you can spread the work over more GPUs with the `--gpu` argument (e.g. `--gpu 0,1,2,3`).
44 |
45 | Results
46 | ---------------------
47 |
48 | |Dataset | Training Method | Model Architecture | Attack Model | Certified Accuracy | Clean Accuracy|
49 | |:-------: | :------: | :-------: | :-------: | :-------: | :-------:|
50 | |MNIST | All Patch | MLP | 2×2 patch | 91.51% | 98.55% |
51 | |MNIST | All Patch | MLP | 5×5 patch | 61.85% | 93.81% |
52 | |CIFAR | Guided Patch 20 | 5-layer CNN | 2×2 patch | 53.02% | 66.50% |
53 | |CIFAR | Random Patch 20 | 5-layer CNN | 5×5 patch | 30.30% | 47.80% |
54 | |MNIST | Sparse | MLP | sparse k=4 | 90.70% | 97.20% |
55 | |MNIST | Sparse | MLP | sparse k=10 | 75.60% | 94.64% |
56 | |CIFAR | Sparse | MLP | sparse k=4 | 32.70% | 49.82% |
57 | |CIFAR | Sparse | MLP | sparse k=10 | 28.21% | 44.34% |
58 |
59 | References
60 | ---------------------
61 | Sven Gowal, Krishnamurthy Dvijotham, Robert Stanforth, Rudy Bunel, Chongli Qin, Jonathan Uesato, Timothy Mann, and Pushmeet Kohli. "On the effectiveness of interval bound propagation for training verifiably robust models." arXiv preprint arXiv:1810.12715 (2018).
62 |
63 | Huan Zhang, Hongge Chen, Chaowei Xiao, Sven Gowal, Robert Stanforth, Bo Li, Duane Boning, Cho-Jui Hsieh "Towards Stable and Efficient Training of Verifiably Robust Neural Networks" arXiv preprint arXiv:1906.06316 (2019)
64 |
65 |
66 | Citation
67 | ---------------------
68 | ```bash
69 | @inproceedings{
70 | Chiang2020Certified,
71 | title={Certified Defenses for Adversarial Patches},
72 | author={Ping-yeh Chiang* and Renkun Ni* and Ahmed Abdelkader and Chen Zhu and Christoph Studer and Tom Goldstein},
73 | booktitle={International Conference on Learning Representations},
74 | year={2020},
75 | url={https://openreview.net/forum?id=HyeaSkrYPH}
76 | }
77 | ```
78 |
79 |
--------------------------------------------------------------------------------
/argparser.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import os
9 | import random
10 | import torch
11 | import numpy as np
12 | import argparse
13 | from ast import literal_eval
14 |
def isfloat(value):
    """Return True if ``float(value)`` succeeds, False if it raises ValueError."""
    try:
        float(value)
    except ValueError:
        return False
    return True
21 |
def isint(value):
    """Return True if ``int(value)`` succeeds, False if it raises ValueError."""
    try:
        int(value)
    except ValueError:
        return False
    return True
28 |
def argparser(seed = 2019):
    """Parse the command line, seed the RNGs and build the config overrides.

    Side effects: sets ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES`` from
    ``--gpu``, seeds ``torch``, ``random`` and ``numpy`` with ``--seed``, and
    tells numpy to ignore divide-by-zero warnings.

    Positional ``overrides`` have the form ``key=value`` where ``key`` may be
    a ``:``-separated path into a nested dict (``a:b:c=1``). Values are
    converted in this order: ``true``/``false`` -> bool, int, float, anything
    containing ``[`` -> ``ast.literal_eval``; everything else stays a string.

    Args:
        seed: default value for ``--seed``.

    Returns:
        The ``argparse.Namespace`` with an extra ``overrides_dict`` attribute.

    Raises:
        ValueError: if an override does not contain ``=``.
    """
    parser = argparse.ArgumentParser()

    # configure file
    parser.add_argument('--config', default="UNSPECIFIED.json")
    parser.add_argument('--model_subset', type=int, nargs='+',
            help='Use only a subset of models in config file. Pass a list of numbers starting with 0, like --model_subset 0 1 3 5')
    parser.add_argument('--path_prefix', type=str, default="", help="override path prefix")
    parser.add_argument('--seed', type=int, default=seed)
    parser.add_argument('--resume', type=str, default="")
    parser.add_argument('--gpu', type=str, default="0")
    parser.add_argument('overrides', type=str, nargs='*',
            help='overriding config dict')
    parser.add_argument("--grad-acc-steps", type=int, default=1)

    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # for dual norm computation, we will have 1 / 0.0 = inf
    np.seterr(divide='ignore')

    overrides_dict = {}
    for o in args.overrides:
        # Split on the FIRST "=" only, so values may themselves contain "=".
        key, sep, val = o.strip().partition("=")
        if not sep:
            raise ValueError("override must have the form key=value, got %r" % (o,))

        # Walk (and create) the nested dicts named by the ":"-separated path.
        path = key.split(":")
        d = overrides_dict
        for k in path[:-1]:
            if k not in d:
                d[k] = {}
            d = d[k]
        last_key = path[-1]

        if val == "true":
            val = True
        elif val == "false":
            val = False
        elif isint(val):
            val = int(val)
        elif isfloat(val):
            val = float(val)
        elif "[" in val:
            val = literal_eval(val)
        d[last_key] = val
    args.overrides_dict = overrides_dict

    return args
83 |
--------------------------------------------------------------------------------
/attacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ping-C/certifiedpatchdefense/f1dbb7e399c320413c17e1412d2fb0ee0d6c812a/attacks/__init__.py
--------------------------------------------------------------------------------
/attacks/patch_attacker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pdb
4 |
class PatchAttacker:
    """PGD-style attack restricted to one randomly placed rectangular patch.

    ``mean`` and ``std`` are per-channel sequences; ``epsilon`` and
    ``step_size`` from ``kwargs`` are divided by ``std`` so they live in the
    same normalised space as the inputs. ``kwargs`` must provide ``epsilon``,
    ``steps``, ``step_size``, ``random_start``, ``patch_w`` and ``patch_l``.

    All per-channel constants are kept on the CPU and moved to the device of
    the tensors they are combined with, so the attack runs wherever
    ``inputs`` lives.
    """

    def __init__(self, model, mean, std, kwargs):
        std = torch.tensor(std)
        mean = torch.tensor(mean)
        self.epsilon = kwargs["epsilon"] / std
        self.steps = kwargs["steps"]
        self.step_size = kwargs["step_size"] / std
        self.model = model
        self.mean = mean
        self.std = std
        self.random_start = kwargs["random_start"]

        # Valid pixel range [0, 1] expressed in normalised coordinates.
        self.lb = (-mean / std)
        self.ub = (1 - mean) / std
        self.patch_w = kwargs["patch_w"]
        self.patch_l = kwargs["patch_l"]

        self.criterion = torch.nn.CrossEntropyLoss()

    def perturb(self, inputs, labels, norm, random_count=1):
        """Search for a high-loss patch perturbation of ``inputs``.

        For each of ``random_count`` restarts a patch location is drawn per
        image, ``self.steps`` gradient steps are taken on the pixels inside
        the patch, and the iterate with the largest per-example loss seen so
        far is remembered.

        Args:
            inputs: batch of normalised images, indexed as (N, C, H, W).
            labels: class indices passed to the cross-entropy loss.
            norm: ``float('inf')`` / ``'linf'`` for signed-gradient steps or
                ``'l2'`` for normalised-gradient steps.
            random_count: number of random patch locations to try.

        Returns:
            Tensor shaped like ``inputs`` holding the worst iterate per
            example, or ``None`` if no step was evaluated.

        Raises:
            ValueError: if ``norm`` is not one of the supported values.
        """
        if not (norm == float('inf') or norm == 'linf' or norm == 'l2'):
            raise ValueError("unsupported norm: %r" % (norm,))

        device = inputs.device
        batch, channels, height, width = inputs.shape
        lb = self.lb[None, :, None, None].to(device)
        ub = self.ub[None, :, None, None].to(device)
        step_size = self.step_size[None, :, None, None].to(device)
        per_example_loss = torch.nn.CrossEntropyLoss(reduction='none')
        x_init = inputs.data.detach().clone()

        worst_x = None
        worst_loss = None
        for _ in range(random_count):
            # Draw the top-left corner of the patch for every image. The
            # upper bound of randint is exclusive, hence the "+ 1" so that
            # the last valid position can be selected as well.
            idx = torch.arange(batch)[:, None]
            zero_idx = torch.zeros((batch, 1), dtype=torch.long)
            w_idx = torch.randint(0, height - self.patch_w + 1, (batch, 1))
            l_idx = torch.randint(0, width - self.patch_l + 1, (batch, 1))
            idx = torch.cat([idx, zero_idx, w_idx, l_idx], dim=1)
            idx_list = [idx]
            for w in range(self.patch_w):
                for l in range(self.patch_l):
                    idx_list.append(idx + torch.tensor([0, 0, w, l]))
            idx_list = torch.cat(idx_list, dim=0).to(device)

            # Boolean mask, True inside the patch; broadcast over channels.
            mask = torch.zeros([batch, 1, height, width],
                               dtype=torch.bool, device=device)
            mask[idx_list[:, 0], idx_list[:, 1], idx_list[:, 2], idx_list[:, 3]] = True

            if self.random_start:
                eps_np = self.epsilon.cpu().numpy()
                init_delta = np.random.uniform(-eps_np, eps_np,
                                               [batch * height * width, channels])
                init_delta = init_delta.reshape(batch, height, width, channels)
                # (N, H, W, C) -> (N, C, H, W)
                init_delta = np.ascontiguousarray(init_delta.transpose(0, 3, 1, 2))
                noise = torch.Tensor(init_delta).to(device)
                x = inputs + torch.where(mask, noise, torch.zeros_like(noise))
                x = torch.min(torch.max(x, lb), ub).detach()  # ensure valid pixel range
            else:
                x = inputs.data.detach().clone()

            x.requires_grad_()

            for _step in range(self.steps):
                # Only pixels inside the patch come from x, so gradients
                # outside the patch are zero.
                output = self.model(torch.where(mask, x, x_init))
                loss_ind = per_example_loss(output, labels)
                current_loss = loss_ind.detach()
                if worst_loss is None:
                    worst_loss = current_loss
                    worst_x = x.data.detach().clone()
                else:
                    keep_old = worst_loss.ge(current_loss)
                    worst_x = torch.where(keep_old[:, None, None, None], worst_x, x.data)
                    worst_loss = torch.where(keep_old, worst_loss, current_loss)
                grads = torch.autograd.grad(loss_ind.sum(), [x])[0]

                if norm == 'l2':
                    grad_norm = grads.view(batch, -1).norm(2, dim=-1).view(-1, 1, 1, 1)
                    # Guard against an all-zero gradient producing NaN.
                    delta = grads * step_size / grad_norm.clamp_min(1e-12)
                else:
                    delta = torch.sign(grads).detach() * step_size

                x.data = delta + x.data.detach()

                # Project back into constraints ball and correct range.
                # Assign through .data so x stays the same leaf tensor.
                x.data = self.project(x_init, x.data, norm, self.epsilon)
                x.data = torch.min(torch.max(x.data, lb), ub)

        return worst_x

    def project(self, x, x_adv, norm, eps, random=1):
        """Project ``x_adv`` onto the ``eps`` ball around ``x``.

        Args:
            x: clean reference batch.
            x_adv: perturbed batch, same shape as ``x``.
            norm: ``'linf'`` / ``float('inf')`` or ``'l2'``; any other value
                returns ``x_adv`` unchanged.
            eps: per-channel radius tensor.
            random: unused; kept for interface compatibility.

        Returns:
            The projected batch.
        """
        if norm == 'linf' or norm == float('inf'):
            eps_b = eps[None, :, None, None].to(x_adv.device)
            x_adv = torch.max(torch.min(x_adv, x + eps_b), x - eps_b)
        elif norm == 'l2':
            # NOTE(review): eps is broadcast against per-example norms below,
            # which only lines up when eps has a single element - confirm
            # the intended radius for multi-channel inputs.
            eps_d = eps.to(x_adv.device)
            delta = x_adv - x

            # Assume x and x_adv are batched tensors where the first
            # dimension is a batch dimension.
            norms = delta.view(delta.shape[0], -1).norm(2, dim=1)
            inside = norms <= eps_d

            # Examples already inside the ball get a scaling factor of 1.
            scaling_factor = norms.clone()
            scaling_factor[inside] = eps_d

            # .view() assumes batched images as a 4D Tensor
            delta = delta * (eps_d / scaling_factor.view(-1, 1, 1, 1))

            x_adv = x + delta

        return x_adv
108 |
--------------------------------------------------------------------------------
/attacks/pgd_attacker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pdb
4 |
class PGDAttacker:
    """Projected gradient attack for the l_inf, l_2, l_1 and l_0 settings.

    ``mean`` and ``std`` are per-channel sequences; ``eps`` and
    ``kwargs["step_size"]`` are divided by ``std`` so they live in the same
    normalised space as the inputs. ``kwargs`` must provide ``steps``,
    ``step_size`` and ``random_start``.

    Per-channel constants are kept on the CPU and moved to the device of the
    tensors they are combined with.
    """

    def __init__(self, model, mean, std, eps, kwargs):
        std = torch.tensor(std)
        mean = torch.tensor(mean)
        self.epsilon = eps / std
        self.steps = kwargs["steps"]
        self.step_size = kwargs["step_size"] / std
        self.model = model
        self.mean = mean
        self.std = std
        self.random_start = kwargs["random_start"]

        # Valid pixel range [0, 1] expressed in normalised coordinates.
        self.lb = - mean / std
        self.ub = (1 - mean) / std

        self.criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _per_channel(values, device):
        """Reshape a (C,) tensor to (1, C, 1, 1) on ``device``."""
        return values[None, :, None, None].to(device)

    def perturb(self, inputs, labels, norm, ball=None):
        """Run ``self.steps`` attack iterations starting from ``inputs``.

        Args:
            inputs: batch of normalised images, indexed as (N, C, H, W).
            labels: class indices passed to the cross-entropy loss.
            norm: ``float('inf')``, ``2``, ``1`` or ``0`` selects the step
                rule. Any other value takes no gradient step; only the
                projection and range clamp are applied.
            ball: unused; kept for interface compatibility.

        Returns:
            Detached tensor shaped like ``inputs`` holding the final iterate.
        """
        device = inputs.device
        lb = self._per_channel(self.lb, device)
        ub = self._per_channel(self.ub, device)
        step_size = self._per_channel(self.step_size, device)

        x_init = inputs.data.detach().clone()
        if self.random_start:
            # Sample in [-1, 1) and scale per channel so that eps is
            # broadcast along the channel axis.
            eps = self._per_channel(self.epsilon, device)
            noise = torch.Tensor(np.random.uniform(-1.0, 1.0, tuple(inputs.size()))).to(device)
            x = torch.min(torch.max(inputs + noise * eps, lb), ub).detach()  # ensure valid pixel range
        else:
            x = x_init.clone()

        x.requires_grad_()

        for _step in range(self.steps):
            output = self.model(x)
            loss = self.criterion(output, labels)
            grads = torch.autograd.grad(loss, [x])[0]

            if norm == float('inf'):
                x.data = torch.sign(grads) * step_size + x.data.detach()
            elif norm == float('2'):
                grad_norm = grads.view(x.shape[0], -1).norm(2, dim=-1).view(-1, 1, 1, 1)
                # Guard against an all-zero gradient producing NaN.
                x.data = grads * step_size / grad_norm.clamp_min(1e-12) + x.data.detach()
            elif norm == float('1') or norm == float('0'):
                # Both settings use the same sparse signed-gradient step.
                x.data = self._sparse_step(x.data, grads, step_size) + x.data.detach()

            # Project back into constraints ball and correct range
            x.data = self.project(x_init, x.data, norm, self.epsilon)
            x.data = torch.min(torch.max(x.data, lb), ub)

        return x.data.detach()

    def _sparse_step(self, x, grads, step_size):
        """Return a step that moves only the pixels with the largest gradient.

        Gradients of pixels that already sit at the bound they would be
        pushed towards are zeroed, the top fraction of the remaining pixels
        (by absolute gradient, summed over channels when C > 1) is selected,
        and the signed step is normalised to unit l1 norm per example before
        being scaled by ``step_size``.
        """
        batch_size, ch = grads.shape[0], grads.shape[1]
        mean = self.mean.to(grads.device).float().view(-1, 1, 1)
        std = self.std.to(grads.device).float().view(-1, 1, 1)

        # Distance to the bound in the gradient direction: the target is 1
        # for a positive gradient, 0 for a negative one, in normalised
        # coordinates.
        target = ((torch.ones_like(x) + torch.sign(grads)) / 2 - mean) / std
        room = target - x
        grads = torch.where(room == 0, torch.zeros_like(grads), grads)
        grads_abs = torch.abs(grads)

        if ch == 1:
            view = grads_abs.view(batch_size, -1)
            sl = 0.99
        else:
            view = grads_abs.sum(1).view(batch_size, -1)
            # Random sparsity level in [0.85, 0.99).
            sl = ((0.99 - 0.85) * torch.rand(1) + 0.85).cpu().numpy()[0]
        view_size = view.shape[1]
        # Always select at least one pixel; a top-k of 0 would yield NaN.
        top = min(view_size, max(1, int(np.round((1 - sl) * view_size))))
        vals, idx = view.topk(top)
        out = torch.zeros_like(view).scatter_(1, idx, vals)
        out = out.repeat(1, ch).view_as(grads)

        g = torch.sign(grads) * (out > 0).float()
        g_norm = g.view(batch_size, -1).norm(1, dim=-1).view(-1, 1, 1, 1)
        g = g / g_norm.clamp_min(1e-12)
        return g * step_size

    @staticmethod
    def _project_rows_onto_l1_ball(rows, radius, const=1e-5):
        """Project each row of ``rows`` (B, n) onto the l1 ball of ``radius``.

        Rows whose l1 norm is at most ``radius + const`` are returned
        unchanged; the others are soft-thresholded using the sorted
        cumulative-sum rule.
        """
        radius = radius.reshape(())
        projected = rows.detach().clone()
        l1 = torch.abs(rows).sum(dim=1)
        outside = torch.nonzero(l1 > (radius + const)).view(-1)
        if outside.numel() == 0:
            return projected

        sel = rows[outside]
        mu = sel.abs().sort(1, descending=True)[0]
        ranks = torch.arange(sel.size(1), device=rows.device).float()
        st = (mu.cumsum(1) - radius) / (ranks + 1)
        active = (mu - st) > 0
        # Index of the last leading True in each row of `active`.
        rho = (~active).long().cumsum(dim=1).eq(0).sum(1) - 1
        theta = st.gather(1, rho.clamp(min=0).unsqueeze(1))
        projected[outside] = torch.relu(torch.abs(sel) - theta) * sel.sign()
        return projected

    def project(self, x, x_adv, ball, eps):
        """Project ``x_adv`` onto the ``eps`` ball of type ``ball`` around ``x``.

        Args:
            x: clean reference batch.
            x_adv: perturbed batch, same shape as ``x``. For the
                multi-channel l1 case it is updated in place.
            ball: ``float('inf')``, ``2``, ``1`` or ``0``; any other value
                returns ``x_adv`` unchanged.
            eps: per-channel radius tensor of shape (C,).

        Returns:
            The projected batch.
        """
        device = x_adv.device
        if ball == float('inf'):
            eps_b = self._per_channel(eps, device)
            x_adv = torch.max(torch.min(x_adv, x + eps_b), x - eps_b)
        elif ball == float('2'):
            # Each channel is projected onto its own l2 ball of radius
            # eps[c], mirroring the per-channel treatment of the l1 case.
            delta = x_adv - x
            batch_size, ch = delta.size(0), delta.size(1)
            norms = delta.reshape(batch_size, ch, -1).norm(2, dim=2)
            scale = torch.clamp(eps.to(device).view(1, -1) / norms.clamp_min(1e-12), max=1.0)
            x_adv = x + delta * scale[:, :, None, None]
        elif ball == float('1'):
            delta = (x_adv - x).detach().clone()
            batch_size, ch = delta.size(0), delta.size(1)
            eps_d = eps.to(device)
            if ch == 1:
                rows = self._project_rows_onto_l1_ball(delta.view(batch_size, -1), eps_d)
                x_adv = rows.view_as(delta) + x
            else:
                for i in range(ch):
                    channel = delta[:, i, :, :]
                    rows = self._project_rows_onto_l1_ball(
                        channel.reshape(batch_size, -1), eps_d[i])
                    x_adv[:, i, :, :] = rows.view_as(channel) + x[:, i, :, :]
        elif ball == float('0'):
            # No projection is applied for the l0 setting.
            delta = (x_adv - x).detach().clone()
            x_adv = x + delta

        return x_adv
260 |
261 |
--------------------------------------------------------------------------------
/bound_layers.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import torch
9 | import numpy as np
10 | from torch.nn import Sequential, Conv2d, Linear, ReLU
11 | from model_defs import Flatten, model_mlp_any
12 | import torch.nn.functional as F
13 |
14 | import logging
# Ask cuDNN for deterministic kernels (the attribute is `deterministic`).
torch.backends.cudnn.deterministic = True
logging.basicConfig(level=logging.INFO)
# logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
19 |
class BoundFlatten(torch.nn.Module):
    """Flatten every dimension after the first, for values and for bounds."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Remember the per-example shape seen on the last forward pass.
        self.shape = x.shape[1:]
        batch = x.shape[0]
        return x.view(batch, -1)

    def interval_propagate(self, h_U, h_L, eps):
        """Flatten the upper and lower bounds; ``eps`` is not used."""
        upper = h_U.view(h_U.shape[0], -1)
        lower = h_L.view(h_L.shape[0], -1)
        return upper, lower
30 |
class BoundLinear(Linear):
    """``torch.nn.Linear`` that can also propagate interval bounds."""

    def __init__(self, in_features, out_features, bias=True):
        super(BoundLinear, self).__init__(in_features, out_features, bias)

    @staticmethod
    def convert(linear_layer):
        """Build a BoundLinear holding a copy of ``linear_layer``'s parameters."""
        has_bias = linear_layer.bias is not None
        l = BoundLinear(linear_layer.in_features, linear_layer.out_features, has_bias)
        l.weight.data.copy_(linear_layer.weight.data)
        # A layer created with bias=False has no bias tensor to copy.
        if has_bias:
            l.bias.data.copy_(linear_layer.bias.data)
        return l

    def interval_propagate(self, h_U, h_L, eps, C = None, k=None, Sparse = None):
        """Propagate the interval ``[h_L, h_U]`` through this layer.

        Args:
            h_U, h_L: upper / lower bounds of the layer input.
            eps: perturbation size; only used by the sparse rule.
            C: optional specification matrix merged into the weights; it is
                left-multiplied onto ``weight`` and ``bias``.
            k: number of largest absolute weights per output unit summed by
                the sparse rule.
            Sparse: when not None (and ``C`` is None) the deviation is
                ``eps`` times the sum of the ``k`` largest absolute weights
                of each output unit instead of ``|W|`` applied to the radius.

        Returns:
            ``(upper, lower)`` bounds of the layer output.
        """
        # merge the specification
        if C is not None:
            # after multiplication with C, we have (batch, output_shape, prev_layer_shape)
            # we have batch dimension here because of each example has different C
            weight = C.matmul(self.weight)
            bias = C.matmul(self.bias) if self.bias is not None else None
        else:
            # weight dimension (this_layer_shape, prev_layer_shape)
            weight = self.weight
            bias = self.bias

        mid = (h_U + h_L) / 2.0
        diff = (h_U - h_L) / 2.0
        weight_abs = weight.abs()
        if C is not None:
            center = weight.matmul(mid.unsqueeze(-1))
            if bias is not None:
                center = center + bias.unsqueeze(-1)
            deviation = weight_abs.matmul(diff.unsqueeze(-1))
            # these have an extra (1,) dimension as the last dimension
            center = center.squeeze(-1)
            deviation = deviation.squeeze(-1)
        else:
            if bias is None:
                center = mid.matmul(weight.t())
            else:
                # fused multiply-add
                center = torch.addmm(bias, mid, weight.t())
            if Sparse is not None:
                deviation = torch.sum(torch.topk(weight_abs, k)[0], dim=1) * eps
            else:
                deviation = diff.matmul(weight_abs.t())

        upper = center + deviation
        lower = center - deviation
        # output
        return upper, lower
76 |
77 |
78 |
class BoundConv2d(Conv2d):
    """``torch.nn.Conv2d`` that can also propagate interval bounds."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(BoundConv2d, self).__init__(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

    @staticmethod
    def convert(l):
        """Build a BoundConv2d holding a copy of ``l``'s parameters."""
        has_bias = l.bias is not None
        nl = BoundConv2d(l.in_channels, l.out_channels, l.kernel_size, l.stride, l.padding, l.dilation, l.groups, has_bias)
        nl.weight.data.copy_(l.weight.data)
        # A layer created with bias=False has no bias tensor to copy or log.
        if has_bias:
            nl.bias.data.copy_(l.bias.data)
            logger.debug(nl.bias.size())
        logger.debug(nl.weight.size())
        return nl

    def forward(self, input):
        output = super(BoundConv2d, self).forward(input)
        # Remember the per-example output shape of the last forward pass.
        self.output_shape = output.size()[1:]
        return output

    def interval_propagate(self, h_U, h_L, eps, k=None, Sparse = None):
        """Propagate the interval ``[h_L, h_U]`` through this convolution.

        Args:
            h_U, h_L: upper / lower bounds of the layer input.
            eps: perturbation size; only used by the sparse rule.
            k: number of kernel positions summed by the sparse rule.
            Sparse: when not None, the deviation of every output channel is
                ``eps`` times the sum of the ``k`` largest kernel positions,
                where each position holds the absolute weights summed over
                input channels. Otherwise ``|W|`` is convolved with the
                input radius.

        Returns:
            ``(upper, lower)`` bounds of the layer output.
        """
        mid = (h_U + h_L) / 2.0
        center = F.conv2d(mid, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        if Sparse is not None:
            weight_sum = torch.sum(self.weight.abs(), 1)
            deviation = torch.sum(torch.topk(weight_sum.view(weight_sum.shape[0], -1), k)[0], dim=1) * eps
            # One value per output channel, broadcast over spatial positions.
            deviation = deviation.view(-1, 1, 1)
        else:
            diff = (h_U - h_L) / 2.0
            weight_abs = self.weight.abs()
            deviation = F.conv2d(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups)
            logger.debug('center %s', center.size())
        upper = center + deviation
        lower = center - deviation
        return upper, lower
116 |
class BoundReLU(ReLU):
    """``torch.nn.ReLU`` that can also propagate interval bounds."""

    def __init__(self, prev_layer, inplace=False):
        # prev_layer is accepted for interface compatibility and not stored.
        super().__init__(inplace)

    ## Convert a ReLU layer to BoundReLU layer
    # @param act_layer ReLU layer object
    # @param prev_layer Pre-activation layer, forwarded to the constructor
    @staticmethod
    def convert(act_layer, prev_layer):
        return BoundReLU(prev_layer, act_layer.inplace)

    def interval_propagate(self, h_U, h_L, eps):
        """Apply ReLU to both bounds; ``eps`` is not used."""
        upper = F.relu(h_U)
        lower = F.relu(h_L)
        return upper, lower
133 |
134 |
135 |
class BoundSequential(Sequential):
    """``torch.nn.Sequential`` of Bound* layers with interval propagation."""

    def __init__(self, *args):
        super(BoundSequential, self).__init__(*args)

    ## Convert a Pytorch model to a model with bounds
    # @param sequential_model Input pytorch model
    # @return Converted model
    @staticmethod
    def convert(sequential_model):
        layers = []
        for l in sequential_model:
            if isinstance(l, Linear):
                layers.append(BoundLinear.convert(l))
            if isinstance(l, Conv2d):
                layers.append(BoundConv2d.convert(l))
            if isinstance(l, ReLU):
                # A ReLU may be the first layer, in which case there is no
                # previous layer to hand over.
                layers.append(BoundReLU.convert(l, layers[-1] if layers else None))
            if isinstance(l, Flatten):
                layers.append(BoundFlatten())
        return BoundSequential(*layers)

    def interval_range(self, x_U=None, x_L=None, eps=None, C=None, k=None, Sparse=None):
        """Propagate input bounds through all layers.

        When both ``Sparse`` and ``k`` are given, the sparse rule is used for
        the first layer if it is Linear/Conv2d, or for the second layer if it
        is Linear; every other layer uses its regular rule. The specification
        matrix ``C`` is passed to the last layer only.

        Returns:
            ``(h_U, h_L)`` bounds of the last layer's output.
        """
        modules = list(self._modules.values())
        h_U = x_U
        h_L = x_L
        sparse_mode = Sparse is not None and k is not None
        for i, module in enumerate(modules[:-1]):
            first_is_affine = i == 0 and isinstance(module, (Linear, Conv2d))
            second_is_linear = i == 1 and isinstance(module, Linear)
            if sparse_mode and (first_is_affine or second_is_linear):
                h_U, h_L = module.interval_propagate(h_U, h_L, eps, k=k, Sparse=Sparse)
            else:
                h_U, h_L = module.interval_propagate(h_U, h_L, eps)

        # last layer has C to merge
        h_U, h_L = modules[-1].interval_propagate(h_U, h_L, eps, C)

        return h_U, h_L

    def interval_range_pool(self, x_U=None, x_L=None, eps=None, C=None, neighbor=None, pos_patch_width=None, pos_patch_length=None):
        """Propagate bounds, pooling them over neighbouring patch positions.

        After layer ``i`` (if ``neighbor[i] > 1``) the bounds of
        ``neighbor[i] x neighbor[i]`` adjacent patch positions are merged:
        element-wise maximum for the upper bound, element-wise minimum for
        the lower bound. ``neighbor=None`` disables pooling.

        Returns:
            ``(h_U, h_L)`` bounds of the last layer's output.
        """
        # The default None means "no pooling after any layer".
        if neighbor is None:
            neighbor = []
        h_U = x_U
        h_L = x_L
        modules = list(self._modules.values())
        last_module = modules[-1]

        for i, module in enumerate(modules[0:-1]):
            h_U, h_L = module.interval_propagate(h_U, h_L, eps)

            #pool bounds
            if i < len(neighbor) and neighbor[i] > 1:
                ori_shape = h_U.shape
                batch_size = ori_shape[0] // pos_patch_width // pos_patch_length
                # h_U = (batch*possible patch, width_bound, length_bound, channels_bound)
                h_U = h_U.view(batch_size, pos_patch_width, pos_patch_length, -1)
                # h_U = (batch, width, length, width_bound*length_bound*channels_bound)
                h_U = h_U.permute(0, 3, 1, 2)
                # h_U = (batch, width_bound*length_bound*channels_bound, width, length)
                h_U = torch.nn.functional.max_pool2d(h_U, neighbor[i], neighbor[i], 0, 1, True, False)
                # h_U = (batch, width_bound*length_bound*channels_bound, (width-1)//neighbor+1, (length-1)//neighbor+1)
                h_U = h_U.permute(0, 2, 3, 1)
                # h_U = (batch, (width-1)//neighbor+1, (length-1)//neighbor+1, width_bound*length_bound*channels_bound)
                h_U = h_U.reshape(-1, *ori_shape[1:])
                # h_U = (batch*(width-1)//neighbor+1*(length-1)//neighbor+1, width_bound*length_bound*channels_bound)

                # Minimum pooling of the lower bound via max-pooling of its negation.
                h_L = h_L.view(batch_size, pos_patch_width, pos_patch_length, -1)
                h_L = h_L.permute(0, 3, 1, 2)
                h_L = -torch.nn.functional.max_pool2d(-h_L, neighbor[i], neighbor[i], 0, 1, True, False)
                h_L = h_L.permute(0, 2, 3, 1)
                h_L = h_L.reshape(-1, *ori_shape[1:])

                # ceil_mode pooling: ceil(n / neighbor) positions remain.
                pos_patch_width = (pos_patch_width-1)//neighbor[i] + 1
                pos_patch_length = (pos_patch_length-1)//neighbor[i] + 1

        # last layer has C to merge
        h_U, h_L= last_module.interval_propagate(h_U, h_L, eps, C)
        return h_U, h_L
212 |
class ParallelBound(torch.nn.Module):
    """Module wrapper whose forward pass runs `model.interval_range`."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x_U, x_L, eps, C):
        """Return the (upper, lower) bounds computed by the wrapped model."""
        upper, lower = self.model.interval_range(x_U=x_U, x_L=x_L, eps=eps, C=C)
        return upper, lower
220 |
class ParallelBoundPool(torch.nn.Module):
    """Module wrapper whose forward pass runs `model.interval_range_pool`."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x_U, x_L, eps, C, neighbor, pos_patch_width, pos_patch_length):
        """Return the (upper, lower) bounds computed by the wrapped model."""
        upper, lower = self.model.interval_range_pool(
            x_U=x_U,
            x_L=x_L,
            eps=eps,
            C=C,
            neighbor=neighbor,
            pos_patch_width=pos_patch_width,
            pos_patch_length=pos_patch_length,
        )
        return upper, lower
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import os
9 | import json
10 | import glob
11 | import copy
12 | import importlib
13 | import torch
14 | import numpy as np
15 | from datasets import loaders
16 |
17 | # from model_defs import add_feature_subsample, remove_feature_subsample
18 | from model_defs import add_feature_subsample, remove_feature_subsample, convert_conv2d_dense, save_checkpoint, load_checkpoint_to_mlpany
19 |
def get_file_close(filename, ext, load = True):
    """Find the file that most closely matches a name prefix and extension.

    With ``load`` false, nothing is looked up and ``filename + "." + ext`` is
    returned. Otherwise the pattern ``filename*.ext*`` is globbed and one
    match is chosen:

    * files whose base name contains "last" are ignored whenever any other
      match exists (they are still used when they are the only matches);
    * if several candidates remain, a warning is printed and the first one
      whose base name contains "best" wins;
    * otherwise the shortest path wins (ties broken alphabetically).

    Args:
        filename: Path prefix of the file, without extension.
        ext: Extension, with or without a leading dot.
        load: Whether an existing file has to be located.

    Returns:
        The selected path.

    Raises:
        OSError: If ``load`` is true and nothing matches.
    """
    # startswith() rather than ext[0] so an empty extension does not crash.
    if ext.startswith("."):
        ext = ext[1:]
    if not load:
        return filename + "." + ext
    filelist = glob.glob(filename + "*." + ext +"*")
    if not filelist:
        raise OSError("File " + filename + " not found!")
    # glob order is arbitrary, so look at every match rather than only the
    # first one, and never filter the list down to nothing.
    not_last = [f for f in filelist if "last" not in os.path.basename(f)]
    if not_last:
        filelist = not_last
    filelist = sorted(filelist, key = lambda f: (len(f), f))
    if len(filelist) > 1:
        print("Warning! Multiple files matches ID {}: {}".format(filename, filelist))
        for f in filelist:
            # return the best model if we have it
            if "best" in os.path.basename(f):
                return f
    return filelist[0]
40 |
41 |
def update_dict(d, u, show_warning = False):
    """Recursively merge ``u`` into ``d`` in place and return ``d``.

    Dict values in ``u`` are merged key by key into the matching dict in
    ``d``; every other value overwrites the entry in ``d``. When the entry in
    ``d`` is missing or is not a dict, the merge starts from a fresh dict.

    Args:
        d: Dictionary to update (mutated).
        u: Dictionary holding the new values.
        show_warning: Print a warning for every key of ``u`` that is not
            already present in ``d`` at the same nesting level.

    Returns:
        ``d`` itself.
    """
    for k, v in u.items():
        if k not in d and show_warning:
            print("\033[91m Warning: key {} not found in config. Make sure to double check spelling and config option name. \033[0m".format(k))
        if isinstance(v, dict):
            target = d.get(k)
            if not isinstance(target, dict):
                # A scalar or None cannot be merged into; replace it.
                target = {}
            d[k] = update_dict(target, v, show_warning)
        else:
            d[k] = v
    return d
51 |
def load_config(args):
    """Build the run configuration from defaults, a config file and overrides.

    ``defaults.json`` is read from the current working directory, then
    updated with the JSON file at ``args.config`` and, if given, with
    ``args.overrides_dict``. The "models" list is then post-processed:
    ignored models are dropped, models with a "repeats" count are expanded,
    and ``args.model_subset`` (a list of indices) selects a subset.

    Args:
        args: Object providing ``config``, ``overrides_dict``,
            ``model_subset`` and ``path_prefix`` attributes.

    Returns:
        The assembled configuration dictionary.
    """
    print("loading config file: {}".format(args.config))
    with open("defaults.json") as defaults_file:
        config = json.load(defaults_file)
    with open(args.config) as config_file:
        update_dict(config, json.load(config_file))
    if args.overrides_dict:
        print("overriding parameters: \033[93mPlease check these parameters carefully.\033[0m")
        print("\033[93m" + str(args.overrides_dict) + "\033[0m")
        update_dict(config, args.overrides_dict, True)

    # remove ignored models
    config["models"] = [
        model_config
        for model_config in config["models"]
        if not ("ignore" in model_config and model_config["ignore"])
    ]

    # repeat models if the "repeats" field is found
    expanded = []
    for model_config in config["models"]:
        if "repeats" not in model_config:
            expanded.append(model_config)
            continue
        for idx in range(1, model_config["repeats"] + 1):
            clone = copy.deepcopy(model_config)
            clone["repeats_idx"] = idx
            # Top-level string values act as templates: "##" becomes the
            # repeat index itself, "@@" is substituted inside the string.
            for key, value in clone.items():
                if not isinstance(value, str):
                    continue
                if value == "##":
                    clone[key] = idx
                elif "@@" in value:
                    clone[key] = value.replace("@@", str(idx))
            expanded.append(clone)
    config["models"] = expanded

    # only use a subset of models, if specified
    if args.model_subset:
        config["models"] = [config["models"][i] for i in args.model_subset]
    if args.path_prefix:
        config["path_prefix"] = args.path_prefix
    return config
96 |
def config_dataloader(config, **kwargs):
    """Call the loader registered for ``config["dataset"]`` with ``kwargs``.

    The loader is looked up in the module-level ``loaders`` mapping and its
    return value is passed straight back to the caller.
    """
    make_loader = loaders[config["dataset"]]
    return make_loader(**kwargs)
100 |
def get_path(config, model_id, path_name, **kwargs):
    """Unified naming rule for model files, logs, bound files and others.

    To change the format of saved file names, only change it here.

    Args:
        config: Configuration dictionary. Depending on ``path_name`` the keys
            "path_prefix", "models_path", "log_suffix", "bounds_path" and
            "alpha_path" are read.
        model_id: Identifier of the model the path belongs to.
        path_name: One of "model", "best_model", "train_log", "eval_log",
            "model_alt", "bound", "boost_bound" or "alpha".
        **kwargs: For "model", forwarded to ``get_file_close`` (e.g.
            ``load=False``); for "bound", ``train=True`` selects the
            training-set file instead of the test-set file.

    Returns:
        The requested path. The directory is created for every kind except
        "model_alt" and "alpha".

    Raises:
        RuntimeError: If ``path_name`` is not supported.
    """
    def models_dir():
        return os.path.join(config["path_prefix"], config["models_path"])

    def ensure_dir(path):
        # os.makedirs("") raises FileNotFoundError; an empty path means the
        # current directory, which needs no creation.
        if path:
            os.makedirs(path, exist_ok = True)

    if path_name == "model":
        model_file = get_file_close(os.path.join(models_dir(), model_id), "pth", **kwargs)
        ensure_dir(models_dir())
        return model_file
    if path_name == "best_model":
        ensure_dir(models_dir())
        return os.path.join(models_dir(), model_id + "_best.pth")
    if path_name in ("train_log", "eval_log"):
        ensure_dir(models_dir())
        # Log files share the model file's stem. Build the name from the stem
        # instead of str.replace(".pth", ...) on the full path, which would
        # also rewrite a ".pth" occurring in a directory name or model id.
        stem = os.path.join(models_dir(), model_id)
        if path_name == "train_log":
            return stem + ".log"
        return stem + f"{config['log_suffix']}_test.log"
    # temporary
    if path_name == "model_alt":
        return os.path.join(models_dir(), "alt", model_id + ".pth")
    if path_name == "bound":
        ensure_dir(config["bounds_path"])
        bound_file = os.path.join(config["bounds_path"], model_id)
        if kwargs.get("train"):
            return bound_file + "_train.h5"
        return bound_file + "_test.h5"
    if path_name == "boost_bound":
        ensure_dir(config["bounds_path"])
        return os.path.join(config["bounds_path"], model_id) + "_boost.h5"
    if path_name == "alpha":
        return config["alpha_path"]
    raise RuntimeError("Unsupported path " + path_name)
140 |
def get_model_config(config, model_id):
    """Return the first entry of ``config["models"]`` with the given id.

    Returns ``None`` when no entry has a matching "model_id".
    """
    matching = (entry for entry in config["models"] if entry["model_id"] == model_id)
    return next(matching, None)
146 |
def config_modelloader(config, load_pretrain = False, cuda = False):
    """Instantiate every model listed in the config.

    The module named by ``config["model_def"]`` (its file extension is
    stripped) is imported, and each entry of ``config["models"]`` is built as
    ``model_class(**model_params)``. Entries with a truthy "ignore" field are
    skipped; entries with a truthy "subsample" field are wrapped by
    ``add_feature_subsample``.

    Args:
        config: Configuration dictionary.
        load_pretrain: Load each model's weights from the file returned by
            ``get_path(config, model_id, "model")``.
        cuda: Move each model to the GPU before loading weights.

    Returns:
        Tuple ``(models, model_names)`` of two parallel lists.
    """
    # load the required modelfile
    model_module = importlib.import_module(os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []
    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        m = model_class(**model_config["model_params"])
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(m, config["channel"], config["dimension"], keep, seed)
        if cuda:
            m.cuda()
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            print("Loading model file", model_file)
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            # Without map_location a checkpoint saved on a GPU cannot be read
            # on a CPU-only host, even though cuda=False asks for a CPU model.
            checkpoint = torch.load(model_file, map_location = None if cuda else "cpu")
            state_dict = checkpoint['state_dict']
            if isinstance(state_dict, list):
                state_dict = state_dict[0]
            # Entries whose key contains "prev" are not loaded into the model.
            m.load_state_dict({k: v for k, v in state_dict.items() if "prev" not in k})
        models.append(m)
    return models, model_names
191 |
192 |
def config_modelloader_and_convert2mlp(config, load_pretrain = True):
    """Instantiate every model in the config and convert it to an MLP.

    Each model is built like in ``config_modelloader`` (on the CPU), passed
    through ``convert_conv2d_dense``, saved next to its model file as
    ``<model file stem>_dense.pth`` and reloaded with
    ``load_checkpoint_to_mlpany``.

    Args:
        config: Configuration dictionary. Every model's "model_params" must
            contain "in_dim" and "in_ch".
        load_pretrain: Load each model's weights from the file returned by
            ``get_path(config, model_id, "model")`` before converting.

    Returns:
        Tuple ``(models, model_names)``; ``models`` holds the reloaded MLPs.
    """
    # load the required modelfile
    model_module = importlib.import_module(os.path.splitext(config["model_def"])[0])
    models = []
    model_names = []

    for model_config in config["models"]:
        if model_config.get("ignore"):
            continue
        model_id = model_config["model_id"]
        model_names.append(model_id)
        model_class = getattr(model_module, model_config["model_class"])
        model_params = model_config["model_params"]
        m = model_class(**model_params)
        if model_config.get("subsample"):
            keep = model_config["subsample_prob"]
            seed = model_config["subsample_seed"]
            m = add_feature_subsample(m, config["channel"], config["dimension"], keep, seed)
        if load_pretrain:
            model_file = get_path(config, model_id, "model")
            print("Loading model file", model_file)
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            # The model stays on the CPU here, so map the checkpoint there too;
            # otherwise a GPU-saved checkpoint fails on a CPU-only host.
            checkpoint = torch.load(model_file, map_location = "cpu")
            state_dict = checkpoint['state_dict']
            if isinstance(state_dict, list):
                state_dict = state_dict[0]
            # Entries whose key contains "prev" are not loaded into the model.
            m.load_state_dict({k: v for k, v in state_dict.items() if "prev" not in k})
        else:
            # Nothing is read, but the dense checkpoint written below still
            # needs the model file name (this used to be an unbound variable).
            model_file = get_path(config, model_id, "model", load = False)

        print("convert to dense w")
        dense_m = convert_conv2d_dense(m)
        in_dim = model_params["in_dim"]
        in_ch = model_params["in_ch"]
        # One forward pass on an all-zero input; the output is discarded.
        # NOTE(review): presumably a shape check or needed by the dense
        # conversion - confirm before removing.
        dense_m(torch.zeros(1, in_ch, in_dim, in_dim))
        dense_checkpoint_file = model_file.split(".pth")[0] + "_dense.pth"
        print("save dense checkpoint to {}".format(dense_checkpoint_file))

        save_checkpoint(dense_m, dense_checkpoint_file )

        mlp_m = load_checkpoint_to_mlpany(dense_checkpoint_file)
        models.append(mlp_m)
    return models, model_names
249 |
250 |
--------------------------------------------------------------------------------
/config/cifar_adveval_patch_22.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "path_prefix": "",
5 | "log_suffix": "eval_patch_22_cifar",
6 | "models_path": "",
7 | "eval_params": {
8 | "method": "adv",
9 | "verbose": false,
10 | "epsilon": 1,
11 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "shuffle_train": true, "normalize_input": true},
12 | "method_params": {"bounded_input": true, "attack_type": "patch-random",
13 | "patch_w": 2, "patch_l": 2,
14 | "epsilon": 1, "steps": 150, "random_start": true, "random_mask_count": 80,"sample_limit": 400,
15 | "step_size": 0.05}
16 | },
17 | "models": [
18 | {
19 | "model_id": "mlp_255",
20 | "model_class": "model_mlp_any",
21 | "model_params": {"in_dim": 3072, "neurons": [255]}
22 | },
23 | {
24 | "model_id": "cnn_2layer_width_1",
25 | "model_class": "model_cnn_2layer",
26 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
27 | },
28 | {
29 | "model_id": "cnn_4layer_linear_256_width_1",
30 | "model_class": "model_cnn_4layer",
31 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
32 | }
33 | ]
34 | }
35 |
36 |
--------------------------------------------------------------------------------
/config/cifar_adveval_patch_55.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "path_prefix": "",
5 | "log_suffix": "eval_patch_55_cifar",
6 | "models_path": "",
7 | "eval_params": {
8 | "method": "adv",
9 | "verbose": false,
10 | "epsilon": 1,
11 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "shuffle_train": true, "normalize_input": true},
12 | "method_params": {"bounded_input": true, "attack_type": "patch-random",
13 | "patch_w": 5, "patch_l": 5,
14 | "epsilon": 1, "steps": 150, "random_start": true, "random_mask_count": 80,"sample_limit": 400,
15 | "step_size": 0.05}
16 | },
17 | "models": [
18 | {
19 | "model_id": "mlp_255",
20 | "model_class": "model_mlp_any",
21 | "model_params": {"in_dim": 3072, "neurons": [255]}
22 | },
23 | {
24 | "model_id": "cnn_2layer_width_1",
25 | "model_class": "model_cnn_2layer",
26 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
27 | },
28 | {
29 | "model_id": "cnn_4layer_linear_256_width_1",
30 | "model_class": "model_cnn_4layer",
31 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
32 | }
33 | ]
34 | }
35 |
36 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_k10_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_k10_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k":10, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "normalize_input": true},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 10, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "cnn_4layer_linear_256_width_1",
33 | "model_class": "model_cnn_4layer_conv13",
34 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
35 | },
36 | {
37 | "model_id": "mlp_255",
38 | "model_class": "model_mlp_any",
39 | "model_params": {"in_dim": 3072, "neurons": [255]}
40 | }
41 | ]
42 | }
43 |
44 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_k1_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_k1_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k":1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "normalize_input": true},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 1, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "cnn_4layer_linear_256_width_1",
33 | "model_class": "model_cnn_4layer_conv13",
34 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
35 | },
36 | {
37 | "model_id": "mlp_255",
38 | "model_class": "model_mlp_any",
39 | "model_params": {"in_dim": 3072, "neurons": [255]}
40 | }
41 | ]
42 | }
43 |
44 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_k4_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_k4_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k":4, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "normalize_input": true},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 4, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "cnn_4layer_linear_256_width_1",
33 | "model_class": "model_cnn_4layer_conv13",
34 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
35 | },
36 | {
37 | "model_id": "mlp_255",
38 | "model_class": "model_mlp_any",
39 | "model_params": {"in_dim": 3072, "neurons": [255]}
40 | }
41 | ]
42 | }
43 |
44 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_all.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_all/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_allpool2.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_allpool2/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all-pool", "bound_type": "patch-interval", "neighbor": [2],
20 | "patch_w": 2, "patch_l": 2, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "cnn_4layer_linear_256_width_1",
34 | "model_class": "model_cnn_4layer",
35 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
36 | }
37 | ]
38 | }
39 |
40 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_allpool22.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_allpool22/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all-pool", "bound_type": "patch-interval", "neighbor": [2, 1, 2],
20 | "patch_w": 2, "patch_l": 2, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "cnn_4layer_linear_256_width_1",
34 | "model_class": "model_cnn_4layer",
35 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
36 | }
37 | ]
38 | }
39 |
40 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_guide10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_guide10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 1, "base_width": 10,"T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | },
47 | {
48 | "model_id": "cnn_5layer_linear_512_width_16",
49 | "model_class": "model_cnn_5layer",
50 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
51 | }
52 | ]
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_guide20.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_guide20/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 2, "base_width": 10,"T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | },
47 | {
48 | "model_id": "cnn_5layer_linear_512_width_16",
49 | "model_class": "model_cnn_5layer",
50 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
51 | }
52 | ]
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_rand1.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_rand1_small/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 1,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_rand10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_rand10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 10,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_rand20.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_rand20/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 20,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p22_rand5.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p22_rand5/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 5,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_all.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_all/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_allpool2.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_allpool2/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all-pool", "bound_type": "patch-interval", "neighbor": [2],
20 | "patch_w": 5, "patch_l": 5, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "cnn_4layer_linear_256_width_1",
34 | "model_class": "model_cnn_4layer",
35 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
36 | }
37 | ]
38 | }
39 |
40 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_allpool22.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_allpool22/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-all-pool", "bound_type": "patch-interval", "neighbor": [2, 1, 2],
20 | "patch_w": 5, "patch_l": 5, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "cnn_4layer_linear_256_width_1",
34 | "model_class": "model_cnn_4layer",
35 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
36 | }
37 | ]
38 | }
39 |
40 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_guide10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_guide10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 1, "base_width": 10,"T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | },
47 | {
48 | "model_id": "cnn_5layer_linear_512_width_16",
49 | "model_class": "model_cnn_5layer",
50 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
51 | }
52 | ]
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_guide20.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_guide20/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 2, "base_width": 10,"T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 3072, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
46 | },
47 | {
48 | "model_id": "cnn_5layer_linear_512_width_16",
49 | "model_class": "model_cnn_5layer",
50 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
51 | }
52 | ]
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_rand1.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_rand1/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 1,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_rand10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_rand10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 10,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_rand20.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_rand20/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 20,
21 | "epsilon": 1}
22 | },
32 | "eval_params": {
33 | "method": "robust",
34 | "verbose": false,
35 | "epsilon": 1,
36 | "loader_params": {"batch_size": 2, "test_batch_size": 2, "normalize_input": true},
37 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
38 | "patch_w": 5, "patch_l": 5,
39 | "epsilon": 1}
40 | },
41 | "models": [
42 | {
43 | "model_id": "mlp_255",
44 | "model_class": "model_mlp_any",
45 | "model_params": {"in_dim": 3072, "neurons": [255]}
46 | },
47 | {
48 | "model_id": "cnn_2layer_width_1",
49 | "model_class": "model_cnn_2layer",
50 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
51 | },
52 | {
53 | "model_id": "cnn_4layer_linear_256_width_1",
54 | "model_class": "model_cnn_4layer",
55 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
56 | },
57 | {
58 | "model_id": "cnn_5layer_linear_512_width_16",
59 | "model_class": "model_cnn_5layer",
60 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
61 | }
62 | ]
63 | }
64 |
65 |
--------------------------------------------------------------------------------
/config/cifar_robtrain_p55_rand5.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "cifar",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/cifar_robtrain_p55_rand5/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 200,
10 | "lr": 0.001,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 10,
15 | "schedule_length":121,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true, "train_random_transform": true, "normalize_input": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 5,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10, "normalize_input": true},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 3072, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 1, "linear_size": 256}
47 | },
48 | {
49 | "model_id": "cnn_5layer_linear_512_width_16",
50 | "model_class": "model_cnn_5layer",
51 | "model_params": {"in_ch": 3, "in_dim": 32, "width": 16, "linear_size": 512}
52 | }
53 | ]
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/config/mnist_adveval_patch_22.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
5 | "path_prefix": "",
6 | "log_suffix": "eval_patch_22_strong",
7 | "models_path": "",
8 | "eval_params": {
9 | "method": "adv",
10 | "verbose": false,
11 | "epsilon": 1,
12 | "loader_params": {"batch_size": 40, "test_batch_size": 40, "shuffle_train": true},
13 | "method_params": {"bounded_input": true, "attack_type": "patch-random",
14 | "patch_w": 2, "patch_l": 2,
15 | "epsilon": 1, "steps": 150, "random_start": true, "random_mask_count": 80,"sample_limit": 400,
16 | "step_size": 0.05}
17 | },
18 | "models": [
19 | {
20 | "model_id": "mlp_255",
21 | "model_class": "model_mlp_any",
22 | "model_params": {"in_dim": 784, "neurons": [255]}
23 | },
24 | {
25 | "model_id": "cnn_2layer_width_1",
26 | "model_class": "model_cnn_2layer",
27 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
28 | },
29 | {
30 | "model_id": "cnn_4layer_linear_256_width_1",
31 | "model_class": "model_cnn_4layer",
32 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
33 | }
34 | ]
35 | }
36 |
37 |
--------------------------------------------------------------------------------
/config/mnist_adveval_patch_55.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
5 | "path_prefix": "",
6 | "log_suffix": "eval_patch_55_strong",
7 | "models_path": "",
8 | "eval_params": {
9 | "method": "adv",
10 | "verbose": false,
11 | "epsilon": 1,
12 | "loader_params": {"batch_size": 40, "test_batch_size": 40, "shuffle_train": true},
13 | "method_params": {"bounded_input": true, "attack_type": "patch-random",
14 | "patch_w": 5, "patch_l": 5,
15 | "epsilon": 1, "steps": 150, "random_start": true, "random_mask_count": 80,"sample_limit": 400,
16 | "step_size": 0.05}
17 | },
18 | "models": [
19 | {
20 | "model_id": "mlp_255",
21 | "model_class": "model_mlp_any",
22 | "model_params": {"in_dim": 784, "neurons": [255]}
23 | },
24 | {
25 | "model_id": "cnn_2layer_width_1",
26 | "model_class": "model_cnn_2layer",
27 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
28 | },
29 | {
30 | "model_id": "cnn_4layer_linear_256_width_1",
31 | "model_class": "model_cnn_4layer",
32 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
33 | }
34 | ]
35 | }
36 |
37 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_k10_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_k10_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":11,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k": 10, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 10, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "mlp_255",
33 | "model_class": "model_mlp_any",
34 | "model_params": {"in_dim": 784, "neurons": [512]}
35 | }
36 | ]
37 | }
38 |
39 |
40 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_k1_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_k1_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":11,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 1, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "mlp_255",
33 | "model_class": "model_mlp_any",
34 | "model_params": {"in_dim": 784, "neurons": [255]}
35 | }
36 | ]
37 | }
38 |
39 |
40 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_k4_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_k4_sparse/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":11,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
20 | "k": 4, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256},
27 | "method_params": {"attack_type": "sparse", "bound_type": "sparse-interval",
28 | "k": 4, "epsilon": 1}
29 | },
30 | "models": [
31 | {
32 | "model_id": "mlp_255",
33 | "model_class": "model_mlp_any",
34 | "model_params": {"in_dim": 784, "neurons": [255]}
35 | }
36 | ]
37 | }
38 |
39 |
40 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p22_all.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p22_all/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 784, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p22_guide10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p22_guide10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 1, "base_width": 6, "T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 2, "patch_l": 2,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 784, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p22_rand1.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p22_rand1/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 1,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p22_rand10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p22_rand10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 10,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p22_rand5.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p22_rand5/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 2, "patch_l": 2, "patch_count": 5,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 2, "patch_l": 2,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p55_all.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p55_all/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 784, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p55_guide10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p55_guide10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-nn", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 1, "base_width": 6, "T": 1, "epsilon": 1}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "verbose": false,
25 | "epsilon": 1,
26 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
27 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
28 | "patch_w": 5, "patch_l": 5,
29 | "epsilon": 1}
30 | },
31 | "models": [
32 | {
33 | "model_id": "mlp_255",
34 | "model_class": "model_mlp_any",
35 | "model_params": {"in_dim": 784, "neurons": [255]}
36 | },
37 | {
38 | "model_id": "cnn_2layer_width_1",
39 | "model_class": "model_cnn_2layer",
40 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
41 | },
42 | {
43 | "model_id": "cnn_4layer_linear_256_width_1",
44 | "model_class": "model_cnn_4layer",
45 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
46 | }
47 | ]
48 | }
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p55_rand1.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p55_rand1/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 1,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p55_rand10.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p55_rand10/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 10,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/config/mnist_robtrain_p55_rand5.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "log_suffix": "",
5 | "path_prefix": "",
6 | "models_path": "./results/mnist_robtrain_p55_rand5/",
7 | "training_params": {
8 | "method": "robust",
9 | "epochs": 100,
10 | "lr": 5e-4,
11 | "weight_decay": 0.0,
12 | "starting_epsilon": 0,
13 | "epsilon": 1,
14 | "schedule_start": 1,
15 | "schedule_length":61,
16 | "optimizer": "adam",
17 | "verbose": false,
18 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
19 | "method_params": {"attack_type": "patch-random", "bound_type": "patch-interval",
20 | "patch_w": 5, "patch_l": 5, "patch_count": 5,
21 | "epsilon": 1}
22 | },
23 | "eval_params": {
24 | "method": "robust",
25 | "verbose": false,
26 | "epsilon": 1,
27 | "loader_params": {"batch_size": 10, "test_batch_size": 10},
28 | "method_params": {"attack_type": "patch-all", "bound_type": "patch-interval",
29 | "patch_w": 5, "patch_l": 5,
30 | "epsilon": 1}
31 | },
32 | "models": [
33 | {
34 | "model_id": "mlp_255",
35 | "model_class": "model_mlp_any",
36 | "model_params": {"in_dim": 784, "neurons": [255]}
37 | },
38 | {
39 | "model_id": "cnn_2layer_width_1",
40 | "model_class": "model_cnn_2layer",
41 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
42 | },
43 | {
44 | "model_id": "cnn_4layer_linear_256_width_1",
45 | "model_class": "model_cnn_4layer",
46 | "model_params": {"in_ch": 1, "in_dim": 28, "width": 1, "linear_size": 256}
47 | }
48 | ]
49 | }
50 |
51 |
52 |
--------------------------------------------------------------------------------
/converter.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import sys
9 | import copy
10 | import torch
11 | from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
12 | import numpy as np
13 | from datasets import loaders
14 | from model_defs import Flatten, model_mlp_any, model_cnn_1layer, model_cnn_2layer, model_cnn_4layer, model_cnn_3layer
15 | from bound_layers import BoundSequential
16 | import torch.optim as optim
17 | import time
18 | from datetime import datetime
19 |
20 | from config import load_config, get_path, config_modelloader, config_dataloader, config_modelloader_and_convert2mlp
21 | from argparser import argparser
22 | from pdb import set_trace as st
23 | # sys.settrace(gpu_profile)
24 |
25 |
def main(args):
    """Load the configured models and convert them via the MLP conversion loader.

    The configuration is read with ``load_config(args)``; the returned models
    and their names are not used further in this script.
    """
    cfg = load_config(args)
    # Looked up (and thereby required to exist) even though it is not used below.
    train_cfg = cfg["training_params"]
    converted, converted_names = config_modelloader_and_convert2mlp(cfg)
30 |
# Script entry point: parse the command-line arguments and run main().
if __name__ == "__main__":
    args = argparser()
    main(args)
34 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import multiprocessing
9 | import torch
10 | from torch.utils import data
11 | from functools import partial
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 |
15 | # compute image statistics (by Andreas https://discuss.pytorch.org/t/computing-the-mean-and-std-of-dataset/34949/4)
def get_stats(loader):
    """Compute mean and standard deviation along dimension 1 of the batches in `loader`.

    `loader` is iterated twice and must yield ``(images, labels)`` pairs where
    ``images`` has at least three dimensions; dimension 0 is the batch and
    dimension 1 is assumed to be the channel axis (TODO confirm with callers).

    The statistics are normalised by the number of samples actually yielded
    rather than by ``len(loader.dataset)``, so loaders that skip samples
    (e.g. ``drop_last=True`` or a custom sampler) are handled correctly.

    Returns:
        (mean, std): two 1-D tensors with one entry per index of dimension 1.

    Raises:
        ValueError: if the loader yields no samples.
    """
    # First pass: accumulate the per-image means and count the samples seen.
    mean_sum = 0.0
    sample_count = 0
    for images, _ in loader:
        batch_samples = images.size(0)
        flat = images.view(batch_samples, images.size(1), -1)
        mean_sum = mean_sum + flat.mean(2).sum(0)
        sample_count += batch_samples
    if sample_count == 0:
        raise ValueError("cannot compute statistics of an empty loader")
    mean = mean_sum / sample_count

    # Second pass: accumulate squared deviations from the mean over every element.
    sq_dev_sum = 0.0
    element_count = 0
    for images, _ in loader:
        batch_samples = images.size(0)
        flat = images.view(batch_samples, images.size(1), -1)
        sq_dev_sum = sq_dev_sum + ((flat - mean.unsqueeze(1)) ** 2).sum([0, 2])
        element_count += batch_samples * flat.size(2)
    std = torch.sqrt(sq_dev_sum / element_count)
    return mean, std
33 |
# load MNIST or Fashion-MNIST
def mnist_loaders(dataset, batch_size, shuffle_train = True, shuffle_test = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build the train and test DataLoaders for an MNIST-style dataset class.

    Args:
        dataset: dataset constructor, called as
            ``dataset("./data", train=..., download=True, transform=...)``.
        batch_size: batch size of the train loader (and of the test loader
            unless `test_batch_size` is given).
        shuffle_train / shuffle_test: whether each loader shuffles.
        normalize_input: accepted for interface parity with the other loaders;
            it is not used here.
        num_examples: if truthy, keep only the first `num_examples` samples of
            each split (clamped to the size of that split).
        test_batch_size: optional separate batch size for the test loader.

    Returns:
        (train_loader, test_loader); both carry ``std = [1.0]`` and ``mean = 0``.
    """
    mnist_train = dataset("./data", train=True, download=True, transform=transforms.ToTensor())
    mnist_test = dataset("./data", train=False, download=True, transform=transforms.ToTensor())
    if num_examples:
        # Clamp per split: the test split may be smaller than num_examples, and
        # out-of-range Subset indices only fail later, during iteration.
        mnist_train = data.Subset(mnist_train, list(range(min(num_examples, len(mnist_train)))))
        mnist_test = data.Subset(mnist_test, list(range(min(num_examples, len(mnist_test)))))
    workers = min(multiprocessing.cpu_count(), 2)
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=shuffle_train, pin_memory=True, num_workers=workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=shuffle_test, pin_memory=True, num_workers=workers)
    # Inputs are left unnormalised, so the loaders advertise unit std / zero mean.
    std = [1.0]
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = 0
    test_loader.mean = 0
    return train_loader, test_loader
52 |
def cifar_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        batch_size: batch size of the train loader (and of the test loader
            unless `test_batch_size` is given).
        shuffle_train / shuffle_test: whether each loader shuffles.
        train_random_transform: if true, apply a random horizontal flip and
            ``RandomCrop(32, 4)`` to the training images.
        normalize_input: if true, normalise with the hard-coded mean/std below;
            otherwise an identity normalisation (mean 0, std 1) is used.
        num_examples: if truthy, keep only the first `num_examples` samples of
            each split (clamped to the size of that split).
        test_batch_size: optional separate batch size for the test loader.

    Returns:
        (train_loader, test_loader); both carry a ``std`` attribute matching the
        normalisation in use and a ``mean`` attribute holding the hard-coded
        mean (set regardless of `normalize_input`).
    """
    dataset_mean = (0.4914, 0.4822, 0.4465)
    if normalize_input:
        std = [0.2023, 0.1994, 0.2010]
        normalize = transforms.Normalize(mean=list(dataset_mean), std=std)
    else:
        std = [1.0, 1.0, 1.0]
        normalize = transforms.Normalize(mean=[0, 0, 0], std=std)

    if train_random_transform:
        train_steps = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
        ]
        # The augmented pipeline only normalises when normalisation was requested.
        if normalize_input:
            train_steps.append(normalize)
    else:
        train_steps = [transforms.ToTensor(), normalize]
    train = datasets.CIFAR10('./data', train=True, download=True,
                             transform=transforms.Compose(train_steps))
    test = datasets.CIFAR10('./data', train=False,
                            transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if num_examples:
        # Clamp per split: the test split may be smaller than num_examples, and
        # out-of-range Subset indices only fail later, during iteration.
        train = data.Subset(train, list(range(min(num_examples, len(train)))))
        test = data.Subset(test, list(range(min(num_examples, len(test)))))

    workers = min(multiprocessing.cpu_count(), 6)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=workers)
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = list(dataset_mean)
    test_loader.mean = list(dataset_mean)
    return train_loader, test_loader
100 |
def svhn_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build the SVHN train and test DataLoaders.

    Args:
        batch_size: batch size of the train loader (and of the test loader
            unless `test_batch_size` is given).
        shuffle_train / shuffle_test: whether each loader shuffles.
        train_random_transform: if true, apply ``RandomCrop(32, 4)`` to the
            training images.
        normalize_input: if true, normalise with the hard-coded mean/std below;
            otherwise an identity normalisation (mean 0, std 1) is used.
        num_examples: if truthy, keep only the first `num_examples` samples of
            each split (clamped to the size of that split).
        test_batch_size: optional separate batch size for the test loader.

    Returns:
        (train_loader, test_loader); both carry ``std`` and ``mean`` attributes.

    Side effects:
        Runs ``get_stats`` over the whole training loader and prints the result.
    """
    dataset_mean = (0.43768206, 0.44376972, 0.47280434)
    if normalize_input:
        std = [0.19803014, 0.20101564, 0.19703615]
        normalize = transforms.Normalize(mean=list(dataset_mean), std=std)
    else:
        std = [1.0, 1.0, 1.0]
        normalize = transforms.Normalize(mean=[0, 0, 0], std=std)

    if train_random_transform:
        train_steps = [transforms.RandomCrop(32, 4), transforms.ToTensor()]
        # The augmented pipeline only normalises when normalisation was requested.
        if normalize_input:
            train_steps.append(normalize)
    else:
        train_steps = [transforms.ToTensor(), normalize]
    train = datasets.SVHN('./data', split='train', download=True,
                          transform=transforms.Compose(train_steps))
    test = datasets.SVHN('./data', split='test', download=True,
                         transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if num_examples:
        # Clamp per split: the test split may be smaller than num_examples, and
        # out-of-range Subset indices only fail later, during iteration.
        train = data.Subset(train, list(range(min(num_examples, len(train)))))
        test = data.Subset(test, list(range(min(num_examples, len(test)))))

    workers = min(multiprocessing.cpu_count(), 6)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=workers)
    train_loader.std = std
    test_loader.std = std
    # NOTE(review): mnist_loaders and cifar_loaders both expose a `mean`
    # attribute on their loaders; it is added here so the loaders share one
    # interface. Like cifar_loaders, the dataset mean is exposed regardless of
    # normalize_input -- confirm this is what consumers expect.
    train_loader.mean = list(dataset_mean)
    test_loader.mean = list(dataset_mean)
    # Report the empirical statistics of the (possibly transformed) training data.
    measured_mean, measured_std = get_stats(train_loader)
    print('dataset mean = ', measured_mean.numpy(), 'std = ', measured_std.numpy())
    return train_loader, test_loader
147 |
# Registry of dataset loader factories, keyed by dataset name.
# When a new loader is added, it must be registered here.
loaders = {
    "mnist": partial(mnist_loaders, datasets.MNIST),
    "fashion-mnist": partial(mnist_loaders, datasets.FashionMNIST),
    "cifar": cifar_loaders,
    "svhn": svhn_loaders,
}
155 |
156 |
--------------------------------------------------------------------------------
/defaults.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_def": "model_defs.py",
3 | "dataset": "mnist",
4 | "path_prefix": "",
5 | "alpha_path": "__unused__",
6 | "training_params": {
7 | "method": "natural",
8 | "epochs": 80,
9 | "lr": 5e-4,
10 | "lr_decay_step": 10,
11 | "lr_decay_factor": 0.5,
12 | "weight_decay": 0.0,
13 | "optimizer": "adam",
14 | "starting_epsilon": 0.00,
15 | "epsilon": 0.50,
16 | "schedule_length":61,
17 | "norm": "inf",
18 | "verbose": false,
19 | "loader_params": {"batch_size": 128, "shuffle_train": true},
20 | "method_params": {"batch_size": 128, "shuffle_train": true, "runnerup_only": false, "activity_reg": 0.000, "final-beta": 0.0, "final-kappa": 0.5, "convex-proj": 50}
21 | },
22 | "eval_params": {
23 | "method": "robust",
24 | "norm": "inf",
25 | "verbose": false,
26 | "loader_params": {"batch_size": 256, "test_batch_size": 256, "shuffle_train": true},
27 | "method_params": {"bounded_input": true, "runnerup_only": false, "activity_reg": 0.000, "final-beta": 1.0, "final-kappa": 0.5, "bound_type": "interval", "convex-proj": null}
28 | }
29 | }
30 |
31 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | import sys
9 | import copy
10 | import torch
11 | from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
12 | import numpy as np
13 | from datasets import loaders
14 | from model_defs import Flatten, model_mlp_any, model_cnn_1layer, model_cnn_2layer, model_cnn_4layer, model_cnn_3layer
15 | from bound_layers import BoundSequential
16 | import torch.optim as optim
17 | # from gpu_profile import gpu_profile
18 | import time
19 | from datetime import datetime
20 |
21 |
22 | from config import load_config, get_path, config_modelloader, config_dataloader
23 | from argparser import argparser
24 | from train import Train, Logger
25 | # sys.settrace(gpu_profile)
26 |
27 |
def main(args):
    """Evaluate every model listed in the config and print summary error statistics.

    Models are loaded with ``load_pretrain = True``, converted with
    ``BoundSequential.convert`` and moved to the GPU. Each model is then run
    through ``Train`` on the test loader, using the global ``eval_params``
    overridden by the model's own ``eval_params`` entry when present. The
    per-model results are collected and min/max/median/mean are printed.
    """
    config = load_config(args)
    global_eval_config = config["eval_params"]
    models, model_names = config_modelloader(config, load_pretrain = True)

    converted_models = [BoundSequential.convert(model) for model in models]

    robust_errs = []
    errs = []
    for model, model_id, model_config in zip(converted_models, model_names, config["models"]):
        model = model.cuda()

        # make a copy of the global evaluation config, and apply per-model overrides
        eval_config = copy.deepcopy(global_eval_config)
        if "eval_params" in model_config:
            eval_config.update(model_config["eval_params"])

        # read evaluation parameters from the config
        method = eval_config["method"]
        verbose = eval_config["verbose"]
        eps = eval_config["epsilon"]
        # parameters specific to the evaluation method
        method_param = eval_config["method_params"]
        norm = float(eval_config["norm"])
        train_data, test_data = config_dataloader(config, **eval_config["loader_params"])

        model_name = get_path(config, model_id, "model", load = False)
        print(model_name)
        model_log = get_path(config, model_id, "eval_log")
        # Use a context manager so the per-model log file is flushed and closed
        # once evaluation finishes (or fails) instead of leaking the handle.
        with open(model_log, "w") as log_file:
            logger = Logger(log_file)
            logger.log("evaluation configurations:", eval_config)

            logger.log("Evaluating...")
            # evaluate
            robust_err, err = Train(model, model_id, 0, test_data, eps, eps, eps, norm, logger, verbose, False, None, method, **method_param)
        robust_errs.append(robust_err)
        errs.append(err)

    print('model robust errors (for robustly trained models, not valid for naturally trained models):')
    print(robust_errs)
    robust_errs = np.array(robust_errs)
    print('min: {:.4f}, max: {:.4f}, median: {:.4f}, mean: {:.4f}'.format(np.min(robust_errs), np.max(robust_errs), np.median(robust_errs), np.mean(robust_errs)))
    print('clean errors for models with min, max and median robust errors')
    i_min = np.argmin(robust_errs)
    i_max = np.argmax(robust_errs)
    # index of the upper-middle element of the sorted robust errors
    i_median = np.argsort(robust_errs)[len(robust_errs) // 2]
    print('for min: {:.4f}, for max: {:.4f}, for median: {:.4f}'.format(errs[i_min], errs[i_max], errs[i_median]))
    print('model clean errors:')
    print(errs)
    print('min: {:.4f}, max: {:.4f}, median: {:.4f}, mean: {:.4f}'.format(np.min(errs), np.max(errs), np.median(errs), np.mean(errs)))
78 |
79 |
# Script entry point: parse the command-line arguments and run the evaluation.
if __name__ == "__main__":
    args = argparser()
    main(args)
83 |
--------------------------------------------------------------------------------
/model_defs.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 |
9 | # from convex_adversarial import Dense, DenseSequential
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | import torch.nn.functional as F
15 | from torch.nn.parameter import Parameter
16 | import torch.optim as optim
17 | from torchvision import datasets, transforms
18 | from torch.autograd import Variable
19 | import argparse
20 | from pdb import set_trace as st
21 | import numpy as np
22 | import math
23 | import collections
24 |
25 |
class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        # Keep dimension 0 and let view() infer the flattened length.
        leading = x.size(0)
        return x.view(leading, -1)
29 |
30 |
# MLP model in which every hidden layer has the same number of neurons
# parameter in_dim: input image dimension, 784 for MNIST and 1024 for CIFAR
# parameter layer: number of layers (must be at least 2); layer - 1 hidden widths are generated
# parameter neurons: number of neurons per hidden layer
# parameter out_dim: size of the output layer
def model_mlp_uniform(in_dim, layer, neurons, out_dim = 10):
    assert layer >= 2
    # One hidden width per layer except the output layer.
    hidden_sizes = [neurons for _ in range(layer - 1)]
    return model_mlp_any(in_dim, hidden_sizes, out_dim)
39 |
# MLP model in which each hidden layer can have a different number of neurons
# parameter in_dim: input image dimension, 784 for MNIST and 1024 for CIFAR
# parameter neurons: a list with the number of neurons of each hidden layer (at least one entry)
# parameter out_dim: size of the output layer
def model_mlp_any(in_dim, neurons, out_dim = 10):
    assert len(neurons) >= 1
    layers = [Flatten()]
    fan_in = in_dim
    for position, fan_out in enumerate(neurons):
        # A ReLU separates consecutive linear layers; none precedes the first one.
        if position > 0:
            layers.append(nn.ReLU())
        layers.append(nn.Linear(fan_in, fan_out))
        fan_in = fan_out
    # output layer, fed by the last hidden width
    layers.append(nn.ReLU())
    layers.append(nn.Linear(fan_in, out_dim))
    return nn.Sequential(*layers)
58 |
def model_cnn_1layer(in_ch, in_dim, width):
    """Build a CNN with one convolution (kernel 4, stride 4) and a 10-way linear output.

    in_ch is the number of input channels, in_dim the input side length and
    width a multiplier on the number of convolution channels (8 * width).
    """
    conv_channels = 8 * width
    side = in_dim // 4  # spatial side length after the stride-4 convolution
    layers = [
        nn.Conv2d(in_ch, conv_channels, 4, stride=4),
        nn.ReLU(),
        Flatten(),
        nn.Linear(conv_channels * side * side, 10),
    ]
    return nn.Sequential(*layers)
67 |
68 |
# CNN, small 2-layer (kernel size fixed to 4)
# parameter in_ch: input image channel, 1 for MNIST and 3 for CIFAR
# parameter in_dim: input dimension, 28 for MNIST and 32 for CIFAR
# parameter width: width multiplier
# parameter linear_size: number of units of the hidden linear layer
def model_cnn_2layer(in_ch, in_dim, width, linear_size=128):
    first_channels = 4 * width
    second_channels = 8 * width
    side = in_dim // 4  # two stride-2 convolutions divide the side length by 4
    layers = [
        nn.Conv2d(in_ch, first_channels, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(first_channels, second_channels, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(second_channels * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
85 |
# CNN, relatively small 3-layer
# parameter in_ch: input image channel, 1 for MNIST and 3 for CIFAR
# parameter in_dim: input dimension, 28 for MNIST and 32 for CIFAR
# parameter kernel_size: convolution kernel size, 3 or 5 (anything else raises ValueError)
# parameter width: width multiplier
# note: there is no ReLU between the two final linear layers here;
# model_cnn_3layer_fixed is the variant that has one
def model_cnn_3layer(in_ch, in_dim, kernel_size, width):
    if kernel_size not in (3, 5):
        raise ValueError("Unsupported kernel size")
    # Side length of the feature map entering the first linear layer.
    h = in_dim // 4 if kernel_size == 3 else (in_dim - 4) // 4
    narrow = 4 * width
    wide = 8 * width
    layers = [
        nn.Conv2d(in_ch, narrow, kernel_size=kernel_size, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(narrow, wide, kernel_size=kernel_size, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(wide, wide, kernel_size=4, stride=4, padding=0),
        nn.ReLU(),
        Flatten(),
        nn.Linear(wide * h * h, width * 64),
        nn.Linear(width * 64, 10),
    ]
    return nn.Sequential(*layers)
110 |
def model_cnn_3layer_fixed(in_ch, in_dim, kernel_size, width, linear_size = None):
    """Build a 3-convolution CNN with a ReLU between its two linear layers.

    kernel_size must be 3 or 5 (ValueError otherwise). linear_size is the
    hidden linear width and defaults to width * 64 when None.
    """
    if linear_size is None:
        linear_size = width * 64
    if kernel_size not in (3, 5):
        raise ValueError("Unsupported kernel size")
    # Side length of the feature map entering the first linear layer.
    h = in_dim // 4 if kernel_size == 3 else (in_dim - 4) // 4
    narrow = 4 * width
    wide = 8 * width
    layers = [
        nn.Conv2d(in_ch, narrow, kernel_size=kernel_size, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(narrow, wide, kernel_size=kernel_size, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(wide, wide, kernel_size=4, stride=4, padding=0),
        nn.ReLU(),
        Flatten(),
        nn.Linear(wide * h * h, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
133 |
# CNN, relatively large 4-layer
# parameter in_ch: input image channel, 1 for MNIST and 3 for CIFAR
# parameter in_dim: input dimension, 28 for MNIST and 32 for CIFAR
# parameter width: width multiplier
# parameter linear_size: number of units of each hidden linear layer
def model_cnn_4layer(in_ch, in_dim, width, linear_size):
    narrow = 4 * width
    wide = 8 * width
    side = in_dim // 4  # two stride-2 convolutions divide the side length by 4
    layers = [
        nn.Conv2d(in_ch, narrow, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(narrow, narrow, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(narrow, wide, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(wide, wide, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(wide * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
156 |
def model_cnn_4layer_conv11(in_ch, in_dim, width, linear_size):
    """Build the 4-convolution CNN whose first convolution uses an 11x11 kernel (padding 5).

    The remaining layers follow the same layout as model_cnn_4layer.
    """
    narrow = 4 * width
    wide = 8 * width
    side = in_dim // 4  # two stride-2 convolutions divide the side length by 4
    layers = [
        nn.Conv2d(in_ch, narrow, 11, stride=1, padding=5),
        nn.ReLU(),
        nn.Conv2d(narrow, narrow, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(narrow, wide, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(wide, wide, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(wide * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
175 |
def model_cnn_4layer_conv13(in_ch, in_dim, width, linear_size):
    """Build the 4-convolution CNN whose first convolution uses a 13x13 kernel.

    Args:
        in_ch: number of input image channels.
        in_dim: input image side length.
        width: channel width multiplier.
        linear_size: size of the two hidden linear layers.

    Returns:
        An ``nn.Sequential`` ending in a 10-way linear layer.
    """
    c_lo = 4 * width
    c_hi = 8 * width
    # padding=6 keeps the side length through the 13x13 convolution; the two
    # stride-2 convolutions then halve it twice.
    side = in_dim // 4
    layers = [
        nn.Conv2d(in_ch, c_lo, 13, stride=1, padding=6),
        nn.ReLU(),
        nn.Conv2d(c_lo, c_lo, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_lo, c_hi, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_hi, c_hi, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(c_hi * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
194 |
def model_cnn_5layer(in_ch, in_dim, width, linear_size):
    """Build a CNN with five 3x3 convolutions and two linear layers.

    Args:
        in_ch: number of input image channels.
        in_dim: input image side length.
        width: channel width multiplier.
        linear_size: size of the hidden linear layer.

    Returns:
        An ``nn.Sequential`` ending in a 10-way linear layer.
    """
    c_lo = 4 * width
    c_hi = 8 * width
    # Only the third convolution is strided, so the flattened-size formula
    # uses half the input side length.
    side = in_dim // 2
    layers = [
        nn.Conv2d(in_ch, c_lo, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_lo, c_lo, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_lo, c_hi, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_hi, c_hi, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c_hi, c_hi, 3, stride=1, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(c_hi * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
213 |
def model_cnn_6layer(in_ch, in_dim, width, linear_size):
    """Build a CNN with six convolutions and three linear layers.

    Args:
        in_ch: number of input image channels.
        in_dim: input image side length.
        width: channel width multiplier.
        linear_size: size of the two hidden linear layers.

    Returns:
        An ``nn.Sequential`` ending in a 10-way linear layer.
    """
    c4 = 4 * width
    c8 = 8 * width
    c16 = 16 * width
    # Two stride-2 convolutions (third and sixth) halve the side length twice.
    side = in_dim // 4
    layers = [
        nn.Conv2d(in_ch, c4, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c4, c8, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c8, c8, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(c8, c8, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c8, c16, 3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(c16, c16, 4, stride=2, padding=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(c16 * side * side, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, linear_size),
        nn.ReLU(),
        nn.Linear(linear_size, 10),
    ]
    return nn.Sequential(*layers)
236 |
def model_cnn_10layer(in_ch, in_dim, width):
    """Build a CNN with eight convolutions and one linear classifier.

    The network alternates a 3x3 stride-1 convolution (channel count kept,
    except for the very first one) with a 2x2 stride-2 convolution that
    doubles the channels and halves the side length.  The final linear layer
    is sized for a 2x2 feature map, which corresponds to a 32x32 input as in
    the original inline notes.

    Args:
        in_ch: number of input image channels.
        in_dim: accepted for signature parity with the other builders; the
            visible code does not read it.
        width: channel width multiplier.

    Returns:
        An ``nn.Sequential`` ending in a 10-way linear layer.
    """
    layers = []
    prev_ch = in_ch
    for multiplier in (4, 8, 16, 32):
        same_ch = multiplier * width
        doubled_ch = 2 * same_ch
        layers.append(nn.Conv2d(prev_ch, same_ch, 3, stride=1, padding=1))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(same_ch, doubled_ch, 2, stride=2, padding=0))
        layers.append(nn.ReLU())
        prev_ch = doubled_ch
    # Feature map here is 2*2*(64*width) for a 32x32 input.
    layers.append(Flatten())
    layers.append(nn.Linear(2 * 2 * 64 * width, 10))
    return nn.Sequential(*layers)
268 |
269 | # below are utilities for feature masking, not used
class FeatureMask2D(nn.Module):
    """Multiply a feature map by a fixed, seeded random 0/1 mask.

    The mask has shape ``(1, in_ch, in_dim, in_dim)`` and is drawn once at
    construction time: an entry is 1 where a uniform sample is ``<= keep`` and
    0 otherwise.  The mask is a plain tensor attribute (not a parameter or a
    buffer), so it does not appear in ``state_dict()``.

    Args:
        in_ch: number of channels of the masked feature map.
        in_dim: spatial side length of the masked feature map.
        keep: threshold on the uniform samples; larger keeps more entries.
        seed: seed used to draw the mask.
    """

    def __init__(self, in_ch, in_dim, keep = 1.0, seed = 0):
        super(FeatureMask2D, self).__init__()
        self.in_ch = in_ch
        self.in_dim = in_dim
        self.keep = keep
        self.seed = seed
        # Draw from a fixed seed, then put the global CPU RNG back exactly as
        # it was -- even if drawing fails -- so building the mask does not
        # perturb the caller's random stream.
        # NOTE(review): torch.manual_seed may also reseed CUDA generators,
        # which are not restored here -- confirm whether that matters.
        state = torch.get_rng_state()
        try:
            torch.manual_seed(seed)
            samples = torch.rand((1, in_ch, in_dim, in_dim))
        finally:
            torch.set_rng_state(state)
        mask = (samples <= keep).to(samples.dtype)
        mask.requires_grad = False
        self.weight = mask

    # we don't want to register self.weight as a parameter, as it is not trainable
    # but we need to be able to apply operations on it
    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to registered tensors and to the mask; return ``self``.

        ``nn.Module`` conversion helpers such as ``cuda()`` and ``float()``
        return the result of ``_apply``, so ``self`` must be returned here for
        them to keep returning the module.  Extra arguments are forwarded
        untouched to the base implementation.
        """
        super(FeatureMask2D, self)._apply(fn, *args, **kwargs)
        self.weight.data = fn(self.weight.data)
        return self

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the (broadcast) mask."""
        return x * self.weight

    def extra_repr(self):
        return 'in_ch={}, in_dim={}, keep={}, seed={}'.format(self.in_ch, self.in_dim, self.keep, self.seed)
295 |
def add_feature_subsample(model, in_ch, in_dim, keep = 1.0, seed = 0):
    """Return a copy of ``model``'s container with a mask layer prepended.

    A new, empty instance of ``model``'s class is created; a
    ``FeatureMask2D`` is registered first under the name ``"__mask_layer"``,
    followed by the top-level sub-modules of ``model`` under their original
    names.  The sub-modules themselves are shared with ``model``, not copied.

    Args:
        model: container module whose class can be instantiated without
            arguments.
        in_ch, in_dim, keep, seed: forwarded to ``FeatureMask2D``.

    Returns:
        The new container instance.
    """
    mask = FeatureMask2D(in_ch, in_dim, keep, seed)
    masked_model = model.__class__()
    masked_model.add_module("__mask_layer", mask)
    for module_name, module in model.named_modules():
        # named_modules() walks the whole tree: the root has an empty name
        # and nested modules have dotted names, so skip both.
        if not module_name or '.' in module_name:
            continue
        masked_model.add_module(module_name, module)
    return masked_model
307 |
def remove_feature_subsample(model):
    """Rebuild ``model``'s container without its first child.

    Args:
        model: container module whose class accepts its layers as positional
            arguments.

    Returns:
        A new instance of ``model``'s class holding every child except the
        first one; the remaining children are shared, not copied.
    """
    # remove the first layer and rebuild
    kept = [child for position, child in enumerate(model.children()) if position > 0]
    return model.__class__(*kept)
313 |
314 |
315 |
316 |
317 |
318 | # below are utilities for model converters, not used during training
class DenseConv2d(nn.Module):
    """Convolution evaluated as an explicit dense matrix multiplication.

    On every forward pass the layer pushes an identity basis through
    ``F.conv2d`` to materialise the dense matrix of the convolution for the
    current input size, applies it with ``torch.matmul``, and stores the
    matrix and the expanded bias on the module as ``dense_w`` /
    ``dense_bias``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size, either an int or a sequence of ints.
        stride, padding, dilation, groups: forwarded to ``F.conv2d``.
        bias: ``None`` or ``False`` disables the bias; any other value
            (``True`` or an existing bias tensor) creates a zero-initialised
            bias parameter.
        padding_mode: accepted for signature compatibility with
            ``nn.Conv2d``; the visible code does not read it.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(DenseConv2d, self).__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.weight = Parameter(torch.randn(out_channels, in_channels//groups, *kernel_size))

        # Identity checks only: ``bias`` may be a tensor, whose truth value
        # is ambiguous.
        if bias is None or bias is False:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.zeros(out_channels))
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # The weight is shaped for ``groups``, so the convolution must use
        # the same value.
        self.groups = groups

    def Denseforward(self, inputs):
        """Apply the convolution to ``inputs`` via its dense matrix.

        Args:
            inputs: 4-D tensor; its last three dimensions define the basis
                used to build the dense matrix.

        Returns:
            The convolution output, reshaped to the 4-D shape ``F.conv2d``
            produces for this input size.

        Side effects:
            Sets ``self.dense_w`` (output features x input features) and
            ``self.dense_bias`` (one entry per output feature; zeros when the
            layer has no bias).
        """
        b, n, w, h = inputs.shape
        kernel = self.weight
        n_in = n * w * h
        # One basis image per input feature, created on the kernel's device
        # and dtype so the layer also works after .cuda()/.double().
        basis = torch.eye(n_in, dtype=kernel.dtype, device=kernel.device).view(n_in, n, w, h)
        W = F.conv2d(basis, kernel, stride=self.stride, padding=self.padding,
                     dilation=self.dilation, groups=self.groups)
        _, n1, w1, h1 = W.shape
        dense = W.reshape(n_in, -1)
        out = torch.matmul(inputs.reshape(b, -1), dense).view(b, n1, w1, h1)

        if self.bias is not None:
            # Expand the per-channel bias to one value per output feature.
            new_bias = self.bias.view(1, n1, 1, 1).repeat(1, 1, w1, h1)
            out = out + new_bias
        else:
            new_bias = out.new_zeros(1, n1, w1, h1)

        self.dense_w = dense.transpose(1, 0)
        self.dense_bias = new_bias.reshape(-1)
        return out

    def forward(self, input):
        """Run ``Denseforward`` on ``input``."""
        return self.Denseforward(input)
365 |
def convert_conv2d_dense(model):
    """Return a copy of ``model``'s container with Conv2d layers made dense.

    Every top-level ``nn.Conv2d`` child is replaced by a ``DenseConv2d`` that
    shares the original layer's ``weight`` and ``bias`` parameters; all other
    children are reused as they are.  Nested containers are not descended
    into.

    Args:
        model: container module whose class accepts its layers as positional
            arguments.

    Returns:
        A new instance of ``model``'s class built from the converted layers.
    """
    new_layers = []
    for layer in model.children():
        if isinstance(layer, nn.Conv2d):
            new_layer = DenseConv2d(layer.in_channels, layer.out_channels, layer.kernel_size,
                                    stride=layer.stride, padding=layer.padding,
                                    dilation=layer.dilation, groups=layer.groups,
                                    bias=layer.bias)
            # Share (not copy) the trained parameters with the source layer.
            new_layer.weight = layer.weight
            new_layer.bias = layer.bias
        else:
            new_layer = layer
        new_layers.append(new_layer)
    # Build the result directly from the layers; no empty placeholder
    # instance is created, so container classes whose constructor requires
    # arguments work too.
    return model.__class__(*new_layers)
382 |
def save_checkpoint(model, checkpoint_fname):
    """Save the dense weights of ``model``'s top-level layers to a file.

    ``DenseConv2d`` children contribute their ``dense_w`` / ``dense_bias``
    attributes under the key index ``position + 1``; ``nn.Linear`` children
    contribute ``weight`` / ``bias`` under their own position.  Other
    children are skipped but still advance the position counter.
    ``dense_w`` / ``dense_bias`` are only set by ``DenseConv2d`` during a
    forward pass, so the model must have been run once before saving.

    Args:
        model: module whose direct children are inspected.
        checkpoint_fname: destination passed to ``torch.save``.

    Returns:
        The ``OrderedDict`` stored under the ``"state_dict"`` key of the
        saved file.
    """
    collected = {}
    for position, layer in enumerate(model.children()):
        if isinstance(layer, DenseConv2d):
            index = position + 1
            weight, bias = layer.dense_w, layer.dense_bias
        elif isinstance(layer, nn.Linear):
            index = position
            weight, bias = layer.weight, layer.bias
        else:
            continue
        collected["{}.weight".format(index)] = weight
        collected["{}.bias".format(index)] = bias
    save_dict = collections.OrderedDict(collected)
    torch.save({"state_dict" : save_dict}, checkpoint_fname)
    return save_dict
400 |
def load_checkpoint_to_mlpany(dense_checkpoint_file):
    """Rebuild a fully-connected model from a dense checkpoint file.

    The layer sizes are read from the shapes of the ``*weight`` entries of
    the checkpoint's ``"state_dict"``: the first weight supplies the input
    size, and every weight supplies one layer's output size.  A matching
    ``model_mlp_any`` is then created and the checkpoint loaded into it.  The
    sizes and a ready-to-run ``torch2keras.py`` command line are printed
    along the way.

    Args:
        dense_checkpoint_file: path of a checkpoint whose ``"state_dict"``
            holds 2-D weight tensors.

    Returns:
        The ``model_mlp_any`` model with the checkpoint weights loaded.
    """
    checkpoint = torch.load(dense_checkpoint_file)["state_dict"]
    neurons = []
    for key in checkpoint:
        if not key.endswith("weight"):
            continue
        h, w = checkpoint[key].shape
        if not neurons:
            # The first weight also defines the network's input size.
            neurons.append(w)
        print(h, w)
        neurons.append(h)
    print(neurons)
    neuron_list = " ".join(str(size) for size in neurons)
    keras_file = dense_checkpoint_file.replace(".pth", ".h5")
    print("python converter/torch2keras.py -i {} -o {} --flatten {}".format(dense_checkpoint_file, keras_file, neuron_list))
    # align name
    model = model_mlp_any(neurons[0], neurons[1:-1], out_dim = neurons[-1])
    model.load_state_dict(checkpoint)
    return model
426 |
427 |
428 |
429 |
430 |
431 |
if __name__ == "__main__":
    # Manual smoke test of the dense-conversion utilities: build a small CNN,
    # convert its convolutions to dense layers, save the dense weights, load
    # them back into an MLP and print both outputs for comparison.
    checkpoint_fname = "mnist/cnn_2layer_width_1.pth"
    dense_fname = checkpoint_fname.split(".pth")[0] + "_dense.pth"

    model = model_cnn_2layer(1, 28, 1)
    sample = torch.zeros(1, 1, 28, 28)
    model = convert_conv2d_dense(model)
    # The forward pass also populates dense_w / dense_bias, which
    # save_checkpoint reads.
    x2 = model(sample)
    save_checkpoint(model, dense_fname)
    print(x2)

    # load_checkpoint_to_mlpany expects the checkpoint *file name* and
    # returns the rebuilt MLP; evaluate that model, not the converted CNN.
    mlp_model = load_checkpoint_to_mlpany(dense_fname)
    x3 = mlp_model(sample)
    print(x3)
454 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | ## Copyright (C) 2019, Huan Zhang
2 | ## Hongge Chen
3 | ## Chaowei Xiao
4 | ##
5 | ## This program is licenced under the BSD 2-Clause License,
6 | ## contained in the LICENCE file in this directory.
7 | ##
8 | from argparser import argparser
9 | import os
10 | import sys
11 | import copy
12 | from torch.nn import Sequential, Linear, ReLU, CrossEntropyLoss
13 | import numpy as np
14 | from datasets import loaders
15 | from model_defs import Flatten, model_mlp_any, model_cnn_1layer, model_cnn_2layer, model_cnn_4layer, model_cnn_3layer
16 | from bound_layers import BoundSequential, BoundLinear, BoundConv2d, ParallelBound, ParallelBoundPool
17 | from attacks.patch_attacker import PatchAttacker
18 | from attacks.pgd_attacker import PGDAttacker
19 | #from attacks.debug import PGDAttacker
20 | import torch.optim as optim
21 | # from gpu_profile import gpu_profile
22 | import time
23 | from datetime import datetime
24 | import torch.nn as nn
25 | from config import load_config, get_path, config_modelloader, config_dataloader, update_dict
26 | import torch
27 | from PIL import Image
28 | from matplotlib import pyplot as plt
29 | from unet import ResNetUNet
30 | from itertools import chain
31 | from tqdm import tqdm
32 | import pdb
33 | # sys.settrace(gpu_profile)
34 |
# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True
36 |
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are plain numbers (or whatever numeric type is passed to
    ``update``):
        val: the most recent value.
        sum: weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg: ``sum / count``, or 0 while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new value (typically a per-batch mean).
            n: weight of the value (typically the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch recorded first) has no
        # defined mean; keep the reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count != 0 else 0
51 |
class Logger(object):
    """Print messages to stdout and mirror them to an optional file object."""

    def __init__(self, log_file = None):
        # Open, writable file-like object, or None to log to stdout only.
        self.log_file = log_file

    def log(self, *args, **kwargs):
        """Print the message; accepts the same arguments as ``print``.

        When a log file was supplied, the message is written to it as well
        and the file is flushed immediately.
        """
        print(*args, **kwargs)
        sink = self.log_file
        if not sink:
            return
        print(*args, file = sink, **kwargs)
        sink.flush()
61 |
62 |
def Train(model, model_id, t, loader, start_eps, end_eps, max_eps, norm, logger, verbose, train, opt, method, adv_net=None, unetopt=None, **kwargs):
    """Run one pass over ``loader``, training or evaluating ``model``.

    Args:
        model: network to train/evaluate; the bound computations call its
            ``interval_range`` / ``interval_range_pool`` methods.
        model_id: not read by this function.
        t: epoch index, used only in log messages.
        loader: data loader; ``loader.dataset``, ``loader.batch_size``,
            ``loader.mean`` and ``loader.std`` are read.
        start_eps, end_eps: epsilon is interpolated linearly between these
            two values over the batches of this pass.  If ``end_eps`` is
            below 1e-6 the method is switched to ``"natural"``.
        max_eps: final epsilon, used by the ``"robust_natural"`` mixing
            weight.
        norm: forwarded to the attacker's ``perturb``.
        logger: object with a ``log`` method.
        verbose: when true, the robust bounds are also computed for methods
            other than ``"robust"``.
        train: true for training mode (back-propagation and optimizer
            steps), false for evaluation mode.
        opt: optimizer of ``model``.
        method: one of ``"robust"``, ``"natural"``, ``"adv"``,
            ``"robust_natural"``.
        adv_net: optional network predicting patch locations
            (``attack_type == "patch-nn"``).
        unetopt: optional optimizer of ``adv_net``.
        **kwargs: method parameters.  Keys read here: ``attack_type``,
            ``bound_type``, ``sample_limit``, ``random_mask_count``, ``k``,
            ``patch_w``, ``patch_l``, ``neighbor``, ``patch_count``, ``T``,
            ``final-kappa``.

    Returns:
        ``(errors.avg, errors.avg)`` when the effective method is
        ``"natural"``, otherwise ``(robust_errors.avg, errors.avg)``.

    Raises:
        RuntimeError: for an unsupported ``attack_type`` (adversarial
            training) or ``bound_type``.
        ValueError: for an unknown ``method``.

    Notes:
        The gradient accumulation step count is read from the module-level
        ``args.grad_acc_steps``, which is not a parameter of this function.
        NOTE(review): the bound inputs (``data_ub``, ``c``, ``sa`` ...) are
        only prepared when ``method == "robust"``; the ``verbose`` and
        ``"robust_natural"`` paths appear to rely on them too -- confirm.
    """
    # if train=True, use training mode
    # if train=False, use test mode, no back prop
    num_class = 10
    losses = AverageMeter()
    unetlosses = AverageMeter()
    unetloss = None
    errors = AverageMeter()
    adv_errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    adv_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    batch_time = AverageMeter()
    # initial
    kappa = 1
    factor = 1
    if train:
        model.train()
        if adv_net is not None:
            adv_net.train()
    else:
        model.eval()
        if adv_net is not None:
            adv_net.eval()
    # pregenerate the array for specifications, will be used for scatter
    if method == "robust":
        # sa[i] lists every class index except i, in increasing order.
        sa = np.zeros((num_class, num_class - 1), dtype = np.int32)
        for i in range(sa.shape[0]):
            for j in range(sa.shape[1]):
                if j < i:
                    sa[i][j] = j
                else:
                    sa[i][j] = j + 1
        sa = torch.LongTensor(sa)
    elif method == "adv":
        # NOTE(review): "patch-strong" gets an attacker here but is rejected
        # by the perturbation dispatch inside the loop -- confirm intent.
        if kwargs["attack_type"] == "patch-random":
            attacker = PatchAttacker(model, loader.mean, loader.std, kwargs)
        elif kwargs["attack_type"] == "patch-strong":
            attacker = PatchAttacker(model, loader.mean, loader.std, kwargs)
        elif kwargs["attack_type"] == "PGD":
            attacker = PGDAttacker(model, loader.mean, loader.std, kwargs)
    total = len(loader.dataset)
    batch_size = loader.batch_size
    if train:
        # One epsilon per optimizer step, repeated for each accumulated batch.
        batch_eps = np.linspace(start_eps, end_eps, total// (batch_size*args.grad_acc_steps) + 1)
        batch_eps = batch_eps.repeat(args.grad_acc_steps)
    else:
        batch_eps = np.linspace(start_eps, end_eps, total // (batch_size) + 1)

    if end_eps < 1e-6:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"

    if train:
        iterator = enumerate(loader)
    else:
        iterator = tqdm(enumerate(loader))
    if train:
        opt.zero_grad()
        if unetopt is not None:
            unetopt.zero_grad()
    for i, (data, labels) in iterator:
        if "sample_limit" in kwargs and i*loader.batch_size > kwargs["sample_limit"]:
            break
        start = time.time()
        eps = batch_eps[i]

        if method == "robust":
            # generate specifications
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # remove specifications to self
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0),num_class-1,num_class))
            # scatter matrix to avoid computing margin to self
            sa_labels = sa[labels]
            # storing computed lower bounds after scatter
            lb_s = torch.zeros(data.size(0), num_class)

            #calculating upper and lower bound of the input
            if len(loader.std) == 1:
                std = torch.tensor([loader.std], dtype=torch.float)[:, None, None]
                mean = torch.tensor([loader.mean], dtype=torch.float)[:, None, None]
            elif len(loader.std) == 3:
                std = torch.tensor(loader.std, dtype=torch.float)[None, :, None, None]
                mean = torch.tensor(loader.mean, dtype=torch.float)[None, :, None, None]
            if kwargs["bound_type"] == "sparse-interval":
                data_ub = data
                data_lb = data
                eps = (eps / std).max()
            else:
                data_ub = (data + eps/std)
                data_lb = (data - eps/std)
                # Clip to the normalised image of the [0, 1] pixel range.
                ub = ((1 - mean) / std)
                lb = (-mean / std)
                data_ub = torch.min(data_ub, ub)
                data_lb = torch.max(data_lb, lb)

            if list(model.parameters())[0].is_cuda:
                data_ub = data_ub.cuda()
                data_lb = data_lb.cuda()
                c = c.cuda()
                sa_labels = sa_labels.cuda()
                lb_s = lb_s.cuda()

        if list(model.parameters())[0].is_cuda:
            data = data.cuda()
            labels = labels.cuda()
        # the regular cross entropy
        if torch.cuda.device_count()>1:
            output = nn.DataParallel(model)(data)
        else:
            output = model(data)

        regular_ce = CrossEntropyLoss()(output, labels)
        regular_ce_losses.update(regular_ce.cpu().detach().numpy(), data.size(0))
        errors.update(torch.sum(torch.argmax(output, dim=1)!=labels).cpu().detach().numpy()/data.size(0), data.size(0))

        # the adversarial cross entropy
        if method == "adv":
            if kwargs["attack_type"]=="PGD":
                data_adv = attacker.perturb(data, labels, norm)
            elif kwargs["attack_type"]=="patch-random":
                data_adv = attacker.perturb(data, labels, norm, random_count=kwargs["random_mask_count"])
            else:
                # Report the offending attack type (not the bound type).
                raise RuntimeError("Unknown attack_type " + kwargs["attack_type"])
            output_adv = model(data_adv)
            adv_ce = CrossEntropyLoss()(output_adv, labels)
            adv_ce_losses.update(adv_ce.cpu().detach().numpy(), data.size(0))
            adv_errors.update(
                torch.sum(torch.argmax(output_adv, dim=1) != labels).cpu().detach().numpy() / data.size(0),
                data.size(0))

        if verbose or method == "robust":
            if kwargs["bound_type"] == "interval":
                ub, lb = model.interval_range(x_U=data_ub, x_L=data_lb, eps=eps, C=c)
            elif kwargs["bound_type"] == "sparse-interval":
                ub, lb = model.interval_range(x_U=data_ub, x_L=data_lb, eps=eps, C=c, k=kwargs["k"], Sparse=True)
            elif kwargs["bound_type"] == "patch-interval":
                if kwargs["attack_type"] == "patch-all" or kwargs["attack_type"] == "patch-all-pool":
                    if kwargs["attack_type"] == "patch-all":
                        width = data.shape[2] - kwargs["patch_w"] + 1
                        length = data.shape[3] - kwargs["patch_l"] + 1
                        pos_patch_count = width * length
                        final_bound_count = pos_patch_count
                    elif kwargs["attack_type"] == "patch-all-pool":
                        width = data.shape[2] - kwargs["patch_w"] + 1
                        length = data.shape[3] - kwargs["patch_l"] + 1
                        pos_patch_count = width * length
                        final_width = width
                        final_length = length
                        for neighbor in kwargs["neighbor"]:
                            final_width = ((final_width - 1) // neighbor + 1)
                            final_length = ((final_length - 1) // neighbor + 1)
                        final_bound_count = final_width * final_length

                    patch_idx = torch.arange(pos_patch_count, dtype=torch.long)[None, :]
                    if kwargs["attack_type"] == "patch-all" or kwargs["attack_type"] == "patch-all-pool":
                        x_cord = torch.zeros((1, pos_patch_count), dtype=torch.long)
                        y_cord = torch.zeros((1, pos_patch_count), dtype=torch.long)
                        idx = 0
                        for w in range(width):
                            for l in range(length):
                                x_cord[0, idx] = w
                                y_cord[0, idx] = l
                                idx = idx + 1

                    # expand the list to include coordinates from the complete patch
                    patch_idx = [patch_idx.flatten()]
                    x_cord = [x_cord.flatten()]
                    y_cord = [y_cord.flatten()]
                    for w in range(kwargs["patch_w"]):
                        for l in range(kwargs["patch_l"]):
                            patch_idx.append(patch_idx[0])
                            x_cord.append(x_cord[0] + w)
                            y_cord.append(y_cord[0] + l)

                    patch_idx = torch.cat(patch_idx, dim=0)
                    x_cord = torch.cat(x_cord, dim=0)
                    y_cord = torch.cat(y_cord, dim=0)

                    # create masks for each data point
                    mask = torch.zeros([1, pos_patch_count, data.shape[2], data.shape[3]],
                                       dtype=torch.uint8)
                    mask[:, patch_idx, x_cord, y_cord] = 1
                    mask = mask[:, :, None, :, :]
                    mask = mask.cuda()
                    # Only pixels under the patch may vary; the rest stay at
                    # the clean input.
                    data_ub = torch.where(mask, data_ub[:, None, :, :, :], data[:, None, :, :, :])
                    data_lb = torch.where(mask, data_lb[:, None, :, :, :], data[:, None, :, :, :])

                    # data_ub size (#data*#possible patches, #channels, width, length)
                    data_ub = data_ub.view(-1, *data_ub.shape[2:])
                    data_lb = data_lb.view(-1, *data_lb.shape[2:])

                    c = c.repeat_interleave(final_bound_count, dim=0)

                elif kwargs["attack_type"] == "patch-random" or kwargs["attack_type"] == "patch-nn":
                    # First calculate the number of considered patches
                    if kwargs["attack_type"] == "patch-random":
                        pos_patch_count = kwargs["patch_count"]
                        final_bound_count = pos_patch_count
                        c = c.repeat_interleave(pos_patch_count, dim=0)
                    elif kwargs["attack_type"] == "patch-nn":
                        class_count = 10
                        pos_patch_count = kwargs["patch_count"] * class_count
                        final_bound_count = pos_patch_count
                        c = c.repeat_interleave(pos_patch_count, dim=0)

                    # Create four lists that enumerate the coordinate of the top left corner of the patch
                    # patch_idx, data_idx, x_cord, y_cord shape = (# of datapoints, # of possible patches)
                    patch_idx = torch.arange(pos_patch_count, dtype=torch.long)[None, :].repeat(data.shape[0], 1)
                    data_idx = torch.arange(data.shape[0], dtype=torch.long)[:, None].repeat(1, pos_patch_count)
                    if kwargs["attack_type"] == "patch-random":
                        x_cord = torch.randint(0, data.shape[2] - kwargs["patch_w"]+1, (data.shape[0], pos_patch_count))
                        y_cord = torch.randint(0, data.shape[3] - kwargs["patch_l"]+1, (data.shape[0], pos_patch_count))
                    elif kwargs["attack_type"] == "patch-nn":
                        lbs_pred = adv_net(data)
                        # Take only the feasible location
                        lbs_pred = lbs_pred[:, :,
                                            0:lbs_pred.size(2) - kwargs["patch_l"] + 1,
                                            0:lbs_pred.size(3) - kwargs["patch_w"] + 1]

                        lbs_pred = lbs_pred.reshape(lbs_pred.size(0) * lbs_pred.size(1), -1)
                        # lbs_pred (# datapoints*# of classes, #flattened image dim)
                        select_prob = nn.Softmax(1)(-lbs_pred * kwargs["T"])
                        # select_prob (# datapoints*# of classes, #flattened image dim)
                        random_loc = torch.multinomial(select_prob, kwargs["patch_count"], replacement=False)
                        # random_loc (# datapoints*# of classes, patch_count)
                        random_loc = random_loc.view(data.size(0), -1)
                        # random_loc (# datapoints, # of classes*patch_count)

                        x_cord = random_loc % (data.size(3) - kwargs["patch_w"] + 1)
                        y_cord = random_loc // (data.size(2) - kwargs["patch_l"] + 1)

                    # expand the list to include coordinates from the complete patch
                    patch_idx = [patch_idx.flatten()]
                    data_idx = [data_idx.flatten()]
                    x_cord = [x_cord.flatten()]
                    y_cord = [y_cord.flatten()]
                    for w in range(kwargs["patch_w"]):
                        for l in range(kwargs["patch_l"]):
                            patch_idx.append(patch_idx[0])
                            data_idx.append(data_idx[0])
                            x_cord.append(x_cord[0]+w)
                            y_cord.append(y_cord[0]+l)

                    patch_idx = torch.cat(patch_idx, dim=0)
                    data_idx = torch.cat(data_idx, dim=0)
                    x_cord = torch.cat(x_cord, dim=0)
                    y_cord = torch.cat(y_cord, dim=0)

                    #create masks for each data point
                    mask = torch.zeros([data.shape[0], pos_patch_count, data.shape[2], data.shape[3]],
                                       dtype=torch.uint8)
                    mask[data_idx, patch_idx, x_cord, y_cord] = 1
                    mask = mask[:, :, None, :, :]
                    mask = mask.cuda()
                    data_ub = torch.where(mask, data_ub[:, None, :, :, :], data[:, None, :, :, :])
                    data_lb = torch.where(mask, data_lb[:, None, :, :, :], data[:, None, :, :, :])

                    # data_ub size (#data*#possible patches, #channels, width, length)
                    data_ub = data_ub.view(-1, *data_ub.shape[2:])
                    data_lb = data_lb.view(-1, *data_lb.shape[2:])

                # forward pass all bounds
                if torch.cuda.device_count() > 1:
                    if kwargs["attack_type"] == "patch-all-pool":
                        ub, lb = nn.DataParallel(ParallelBoundPool(model))(x_U=data_ub, x_L=data_lb, eps=eps, C=c,
                                                                           neighbor=kwargs["neighbor"],
                                                                           pos_patch_width=width, pos_patch_length=length)
                    else:
                        ub, lb = nn.DataParallel(ParallelBound(model))(x_U=data_ub, x_L=data_lb,
                                                                       eps=eps, C=c)
                else:
                    if kwargs["attack_type"] == "patch-all-pool":
                        ub, lb = model.interval_range_pool(x_U=data_ub, x_L=data_lb, eps=eps, C=c,
                                                           neighbor=kwargs["neighbor"],
                                                           pos_patch_width=width, pos_patch_length=length)
                    else:
                        ub, lb = model.interval_range(x_U=data_ub, x_L=data_lb, eps=eps, C=c)

                # calculate unet loss
                if kwargs["attack_type"] == "patch-nn":
                    labels_mod = labels.repeat_interleave(pos_patch_count, dim=0)
                    sa_labels_mod = sa[labels_mod]
                    sa_labels_mod = sa_labels_mod.cuda()
                    # storing computed lower bounds after scatter
                    lb_s_mod = torch.zeros(data.size(0) * pos_patch_count, num_class).cuda()
                    lbs_actual = lb_s_mod.scatter(1, sa_labels_mod, lb)
                    # lbs_actual (# data * # of logits * # of classes, # of classes)

                    # lbs_pred (# datapoints*# of logits, #flattened image dim)
                    lbs_pred = lbs_pred.view(data.shape[0], num_class, -1)
                    # lbs_pred (# datapoints, # of logits, #flattened image dim)
                    lbs_pred = lbs_pred.permute(0, 2, 1)
                    # lbs_pred (# datapoints, #flattened image dim, # of logits)

                    # random_loc (# datapoints, # of logits*patch_count)
                    random_loc = random_loc.unsqueeze(2)
                    random_loc = random_loc.repeat_interleave(10, dim=2)
                    lbs_pred = lbs_pred.gather(1, random_loc)
                    # lbs_pred (# datapoints, # of logits*patch_count, # of logits)
                    lbs_pred = lbs_pred.view(-1, num_class)
                    # lbs_pred (# datapoints*# of logits*patch_count, # of logits)
                    # The target is detached, so this loss only trains adv_net.
                    unetloss = nn.MSELoss()(lbs_actual.detach(), lbs_pred)

                # Worst case (minimum lower bound) over all considered patches.
                lb = lb.reshape(-1, final_bound_count, lb.shape[1])
                lb = torch.min(lb, dim=1)[0]
            else:
                raise RuntimeError("Unknown bound_type " + kwargs["bound_type"])
            # pdb.set_trace()
            lb = lb_s.scatter(1, sa_labels, lb)
            robust_ce = CrossEntropyLoss()(-lb, labels)

        if method == "robust":
            loss = robust_ce
        elif method == "natural":
            loss = regular_ce
        elif method == "adv":
            loss = adv_ce
        elif method == "robust_natural":
            natural_final_factor = kwargs["final-kappa"]
            kappa = (max_eps - eps * (1.0 - natural_final_factor)) / max_eps
            loss = (1-kappa) * robust_ce + kappa * regular_ce
        else:
            raise ValueError("Unknown method " + method)

        if train:
            if unetloss is not None:
                unetloss.backward()
                unetlosses.update(unetloss.cpu().detach().numpy(), data.size(0))
            loss.backward()
            # Step once every grad_acc_steps batches (and on the last batch).
            if (i + 1) % args.grad_acc_steps == 0 or i == len(loader) - 1:
                if unetloss is not None:
                    for p in adv_net.parameters():
                        p.grad /= args.grad_acc_steps
                    unetopt.step()
                    # Clear the adversary's gradients as well; otherwise they
                    # keep accumulating across optimizer steps.
                    unetopt.zero_grad()
                for p in model.parameters():
                    p.grad /= args.grad_acc_steps
                opt.step()
                opt.zero_grad()

        batch_time.update(time.time() - start)

        losses.update(loss.cpu().detach().numpy(), data.size(0))

        if verbose or method == "robust":
            robust_ce_losses.update(robust_ce.cpu().detach().numpy(), data.size(0))
            # A sample counts as a robust error if any margin lower bound is negative.
            robust_errors.update(torch.sum((lb<0).any(dim=1)).cpu().detach().numpy() / data.size(0), data.size(0))
        if i % 50 == 0 and train:
            logger.log( '[{:2d}:{:4d}]: eps {:4f} '
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Total Loss {loss.val:.4f} ({loss.avg:.4f}) '
                        'Unet Loss {unetloss.val:.4f} ({unetloss.avg:.4f}) '
                        'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f}) '
                        'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f}) '
                        'ACE {adv_ce_loss.val:.4f} ({adv_ce_loss.avg:.4f}) '
                        'Err {errors.val:.4f} ({errors.avg:.4f}) '
                        'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f}) '
                        'Adv Err {adv_errors.val:.4f} ({adv_errors.avg:.4f}) '
                        'beta {factor:.3f} ({factor:.3f}) '
                        'kappa {kappa:.3f} ({kappa:.3f}) '.format(
                        t, i, eps, batch_time=batch_time,
                        loss=losses, unetloss=unetlosses, errors=errors, robust_errors = robust_errors, adv_errors = adv_errors,
                        regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses,
                        adv_ce_loss = adv_ce_losses,
                        factor=factor, kappa = kappa))

    logger.log( '[FINAL RESULT epoch:{:2d} eps:{:.4f}]: '
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                'Total Loss {loss.val:.4f} ({loss.avg:.4f}) '
                'Unet Loss {unetloss.val:.4f} ({unetloss.avg:.4f}) '
                'CE {regular_ce_loss.val:.4f} ({regular_ce_loss.avg:.4f}) '
                'RCE {robust_ce_loss.val:.4f} ({robust_ce_loss.avg:.4f}) '
                'ACE {adv_ce_loss.val:.4f} ({adv_ce_loss.avg:.4f}) '
                'Err {errors.val:.4f} ({errors.avg:.4f}) '
                'Rob Err {robust_errors.val:.4f} ({robust_errors.avg:.4f}) '
                'Adv Err {adv_errors.val:.4f} ({adv_errors.avg:.4f}) '
                'beta {factor:.3f} ({factor:.3f}) '
                'kappa {kappa:.3f} ({kappa:.3f}) \n'.format(
                t, eps, batch_time=batch_time,
                loss=losses,unetloss=unetlosses, errors=errors, robust_errors = robust_errors,
                adv_errors = adv_errors,
                regular_ce_loss = regular_ce_losses, robust_ce_loss = robust_ce_losses,
                adv_ce_loss = adv_ce_losses,
                kappa = kappa, factor=factor))

    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg
460 |
461 |
462 |
463 | def main(args):
464 | config = load_config(args)
465 | global_train_config = config["training_params"]
466 | models, model_names = config_modelloader(config)
467 |
468 | converted_models = [BoundSequential.convert(model) for model in models]
469 |
470 | for model, model_id, model_config in zip(converted_models, model_names, config["models"]):
471 | print("Number of GPUs:", torch.cuda.device_count())
472 | model = model.cuda()
473 | # make a copy of global training config, and update per-model config
474 | train_config = copy.deepcopy(global_train_config)
475 | if "training_params" in model_config:
476 | train_config = update_dict(train_config, model_config["training_params"])
477 |
478 | # read training parameters from config file
479 | epochs = train_config["epochs"]
480 | lr = train_config["lr"]
481 | weight_decay = train_config["weight_decay"]
482 | starting_epsilon = train_config["starting_epsilon"]
483 | end_epsilon = train_config["epsilon"]
484 | schedule_length = train_config["schedule_length"]
485 | schedule_start = train_config["schedule_start"]
486 | optimizer = train_config["optimizer"]
487 | method = train_config["method"]
488 | verbose = train_config["verbose"]
489 | lr_decay_step = train_config["lr_decay_step"]
490 | lr_decay_factor = train_config["lr_decay_factor"]
491 | # parameters specific to a training method
492 | method_param = train_config["method_params"]
493 | norm = float(train_config["norm"])
494 | train_config["loader_params"]["batch_size"] = train_config["loader_params"]["batch_size"]//args.grad_acc_steps
495 | train_config["loader_params"]["test_batch_size"] = train_config["loader_params"]["test_batch_size"]//args.grad_acc_steps
496 | train_data, test_data = config_dataloader(config, **train_config["loader_params"])
497 |
498 | # initialize adversary network
499 | if method_param["attack_type"] == "patch-nn":
500 | if config["dataset"] == "mnist":
501 | adv_net = ResNetUNet(n_class=10, channels=1,
502 | base_width=method_param["base_width"],
503 | dataset="mnist").cuda()
504 | if config["dataset"] == "cifar":
505 | adv_net = ResNetUNet(n_class=10, channels=3,
506 | base_width=method_param["base_width"],
507 | dataset="cifar").cuda()
508 | else:
509 | adv_net = None
510 | if optimizer == "adam":
511 | opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
512 | if method_param["attack_type"] == "patch-nn":
513 | unetopt = optim.Adam(adv_net.parameters(), lr=lr, weight_decay=weight_decay)
514 | else:
515 | unetopt = None
516 | elif optimizer == "sgd":
517 | if method_param["attack_type"] == "patch-nn":
518 | unetopt = optim.SGD(adv_net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=weight_decay)
519 | else:
520 | unetopt = None
521 | opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=weight_decay)
522 | else:
523 | raise ValueError("Unknown optimizer")
524 | lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=lr_decay_step, gamma=lr_decay_factor)
525 | if method_param["attack_type"] == "patch-nn":
526 | lr_scheduler_unet = optim.lr_scheduler.StepLR(unetopt, step_size=lr_decay_step, gamma=lr_decay_factor)
527 |
528 |
529 | start_epoch = 0
530 | if args.resume:
531 | model_log = os.path.join(out_path, "test_log")
532 | logger = Logger(open(model_log, "w"))
533 | state_dict = torch.load(args.resume)
534 | print("***** Loading state dict from {} @ epoch {}".format(args.resume, state_dict['epoch']))
535 | model.load_state_dict(state_dict['state_dict'])
536 | opt.load_state_dict(state_dict['opt_state_dict'])
537 | lr_scheduler.load_state_dict(state_dict['lr_scheduler_dict'])
538 | start_epoch = state_dict['epoch'] + 1
539 |
540 | eps_schedule = [0] * schedule_start + list(np.linspace(starting_epsilon, end_epsilon, schedule_length))
541 | max_eps = end_epsilon
542 |
543 | model_name = get_path(config, model_id, "model", load = False)
544 | best_model_name = get_path(config, model_id, "best_model", load = False)
545 | print(model_name)
546 | model_log = get_path(config, model_id, "train_log")
547 | logger = Logger(open(model_log, "w"))
548 | logger.log("Command line:", " ".join(sys.argv[:]))
549 | logger.log("training configurations:", train_config)
550 | logger.log("Model structure:")
551 | logger.log(str(model))
552 | logger.log("data std:", train_data.std)
553 | best_err = np.inf
554 | recorded_clean_err = np.inf
555 | timer = 0.0
556 |
557 | for t in range(start_epoch, epochs):
558 | if method_param["attack_type"] == "patch-nn":
559 | lr_scheduler_unet.step(epoch=max(t-len(eps_schedule), 0))
560 | lr_scheduler.step(epoch=max(t-len(eps_schedule), 0))
561 |
562 | if t >= len(eps_schedule):
563 | eps = end_epsilon
564 | else:
565 | epoch_start_eps = eps_schedule[t]
566 | if t + 1 >= len(eps_schedule):
567 | epoch_end_eps = epoch_start_eps
568 | else:
569 | epoch_end_eps = eps_schedule[t+1]
570 |
571 | logger.log("Epoch {}, learning rate {}, epsilon {:.6f} - {:.6f}".format(t, lr_scheduler.get_lr(), epoch_start_eps, epoch_end_eps))
572 | # with torch.autograd.detect_anomaly():
573 | start_time = time.time()
574 |
575 |
576 | Train(model, model_id, t, train_data, epoch_start_eps, epoch_end_eps, max_eps, norm, logger, verbose, True, opt, method, adv_net, unetopt, **method_param)
577 | epoch_time = time.time() - start_time
578 | timer += epoch_time
579 | logger.log('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))
580 | logger.log("Evaluating...")
581 | # evaluate
582 | err, clean_err = Train(model, model_id, t, test_data, epoch_end_eps, epoch_end_eps, max_eps, norm, logger, verbose, False, None, method, adv_net, None, **method_param)
583 |
584 |
585 | logger.log('saving to', model_name)
586 | torch.save({
587 | 'state_dict' : model.state_dict(),
588 | 'opt_state_dict': opt.state_dict(),
589 | 'robust_err': err,
590 | 'clean_err': clean_err,
591 | 'epoch' : t,
592 | 'lr_scheduler_dict': lr_scheduler.state_dict()
593 | }, model_name)
594 |
595 | # save the best model after we reached the schedule
596 | if t >= len(eps_schedule):
597 | if err <= best_err:
598 | best_err = err
599 | recorded_clean_err = clean_err
600 | logger.log('Saving best model {} with error {}'.format(best_model_name, best_err))
601 | torch.save({
602 | 'state_dict' : model.state_dict(),
603 | 'opt_state_dict': opt.state_dict(),
604 | 'robust_err': err,
605 | 'clean_err': clean_err,
606 | 'epoch' : t,
607 | 'lr_scheduler_dict': lr_scheduler.state_dict()
608 | }, best_model_name)
609 |
610 | logger.log('Total Time: {:.4f}'.format(timer))
611 | logger.log('Model {} best err {}, clean err {}'.format(model_id, best_err, recorded_clean_err))
612 |
613 |
# Script entry point: parse the command-line arguments and hand them to main().
if __name__ == "__main__":
    # NOTE: `args` is bound at module level here in addition to being passed
    # explicitly to main().
    args = argparser()
    main(args)
617 |
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 |
def convrelu(in_channels, out_channels, kernel, padding):
    """Return a ``Conv2d`` followed by an in-place ``ReLU`` as one module.

    The convolution maps ``in_channels`` to ``out_channels`` using the given
    ``kernel`` size and ``padding``; every other Conv2d option keeps its
    default value.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    # Padding equals the dilation; the layer carries no bias term.
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)
14 |
15 |
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    # Pointwise convolution without a bias term.
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
19 |
20 |
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by a norm layer.

    The block input (optionally transformed by ``downsample``) is added to the
    output of the second norm layer before the final ReLU. ``groups`` and
    ``base_width`` are accepted but not used by this block.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-norm-ReLU, conv-norm, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Use the raw input as shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
57 |
class ResNetUNet(nn.Module):
    """U-Net style network with a small ResNet-like encoder.

    The encoder downsamples the input four times (``layer0`` .. ``layer3``);
    the decoder bilinearly upsamples back to the input resolution, and at each
    scale concatenates the matching encoder feature map (passed through a 1x1
    conv) before a 3x3 conv. A full-resolution branch computed directly from
    the input is concatenated before the final 1x1 classification conv.

    Args:
        n_class: number of output channels of the final 1x1 convolution.
        channels: number of channels of the input image.
        base_width: channel count of the first encoder stage; deeper stages
            use 2x and 4x this width.
        dataset: ``"mnist"`` (28x28 inputs) or ``"cifar"`` (32x32 inputs).
            The upsampling targets are fixed from this value, so inputs must
            have the corresponding spatial size.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    # Square input resolution for each supported dataset.
    _IMG_DIMS = {"mnist": 28, "cifar": 32}

    def __init__(self, n_class, channels=1, base_width=64, dataset="mnist"):
        super().__init__()
        self.dilation = 1
        self.inplanes = base_width
        self.channels = channels
        self.groups = 1
        self.base_width = base_width
        self._norm_layer = nn.BatchNorm2d
        # Raise explicitly instead of `assert False`: asserts are stripped
        # under `python -O`, which would leave `img_dim` undefined.
        if dataset not in self._IMG_DIMS:
            raise ValueError(
                "Unsupported dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._IMG_DIMS)))
        self.img_dim = self._IMG_DIMS[dataset]

        width = self.base_width

        # ---- encoder (sizes given for H = W = img_dim) ----
        self.layer0 = nn.Sequential(
            nn.Conv2d(channels, width, 7, stride=2, padding=3),
            nn.BatchNorm2d(width),
            nn.ReLU()
        )  # (N, base_width, H/2, W/2)

        self.layer0_1x1 = convrelu(width, width, 1, 0)
        self.layer1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(BasicBlock, width, 2)
        )  # (N, base_width, H/4, W/4)
        self.layer1_1x1 = convrelu(width, width, 1, 0)
        self.layer2 = self._make_layer(BasicBlock, width * 2, 2, stride=2,
                                       dilate=False)  # (N, base_width*2, H/8, W/8)
        self.layer2_1x1 = convrelu(width * 2, width * 2, 1, 0)
        self.layer3 = self._make_layer(BasicBlock, width * 4, 2, stride=2,
                                       dilate=False)  # (N, base_width*4, H/16, W/16)
        self.layer3_1x1 = convrelu(width * 4, width * 4, 1, 0)

        # ---- decoder ----
        # Each upsample targets ceil(img_dim / factor), i.e. the spatial size
        # of the encoder feature map it is concatenated with.
        def scaled(factor):
            side = (self.img_dim - 1) // factor + 1
            return (side, side)

        self.upsample1 = nn.Upsample(scaled(8), mode='bilinear', align_corners=True)
        self.upsample2 = nn.Upsample(scaled(4), mode='bilinear', align_corners=True)
        self.upsample3 = nn.Upsample(scaled(2), mode='bilinear', align_corners=True)
        self.upsample4 = nn.Upsample((self.img_dim, self.img_dim), mode='bilinear', align_corners=True)

        self.conv_up2 = convrelu(width * (2 + 4), width * 4, 3, 1)
        self.conv_up1 = convrelu(width * (1 + 4), width * 2, 3, 1)
        self.conv_up0 = convrelu(width * (1 + 2), width * 1, 3, 1)

        # Full-resolution branch applied directly to the input image.
        self.conv_original_size0 = convrelu(channels, width, 3, 1)
        self.conv_original_size1 = convrelu(width, width, 3, 1)
        self.conv_original_size2 = convrelu(width * (1 + 1), width, 3, 1)

        self.conv_last = nn.Conv2d(width, n_class, 1)

    def forward(self, input):
        """Map an image batch to per-pixel ``n_class`` scores.

        Args:
            input: tensor of shape (N, channels, img_dim, img_dim).

        Returns:
            Tensor of shape (N, n_class, img_dim, img_dim).

        Shape comments below use MNIST (img_dim = 28) as the example.
        """
        x_original = self.conv_original_size0(input)       # (N, base_width, 28, 28)
        x_original = self.conv_original_size1(x_original)  # (N, base_width, 28, 28)

        layer0 = self.layer0(input)   # (N, base_width, 14, 14)
        layer1 = self.layer1(layer0)  # (N, base_width, 7, 7)
        layer2 = self.layer2(layer1)  # (N, base_width*2, 4, 4)
        layer3 = self.layer3(layer2)  # (N, base_width*4, 2, 2)

        layer3 = self.layer3_1x1(layer3)  # (N, base_width*4, 2, 2)
        x = self.upsample1(layer3)        # (N, base_width*4, 4, 4)

        layer2 = self.layer2_1x1(layer2)   # (N, base_width*2, 4, 4)
        x = torch.cat([x, layer2], dim=1)  # (N, base_width*(4+2), 4, 4)

        x = self.conv_up2(x)               # (N, base_width*4, 4, 4)
        x = self.upsample2(x)              # (N, base_width*4, 7, 7)
        layer1 = self.layer1_1x1(layer1)   # (N, base_width, 7, 7)
        x = torch.cat([x, layer1], dim=1)  # (N, base_width*(4+1), 7, 7)
        x = self.conv_up1(x)               # (N, base_width*2, 7, 7)

        x = self.upsample3(x)              # (N, base_width*2, 14, 14)
        layer0 = self.layer0_1x1(layer0)   # (N, base_width, 14, 14)
        x = torch.cat([x, layer0], dim=1)  # (N, base_width*(2+1), 14, 14)
        x = self.conv_up0(x)               # (N, base_width, 14, 14)

        x = self.upsample4(x)                  # (N, base_width, 28, 28)
        x = torch.cat([x, x_original], dim=1)  # (N, base_width*(1+1), 28, 28)
        x = self.conv_original_size2(x)        # (N, base_width, 28, 28)

        out = self.conv_last(x)  # (N, n_class, 28, 28)
        return out

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1-conv + norm projection for its
        shortcut. ``self.inplanes`` is updated to the stage's output channels
        so the next call starts from the right width.

        Args:
            block: block class; must expose an ``expansion`` attribute.
            planes: base channel count of the stage.
            blocks: number of blocks in the stage.
            stride: stride of the first block.
            dilate: if true, the stride is converted into dilation instead.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        # The first block uses the dilation in effect before this stage.
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
166 |
167 |
--------------------------------------------------------------------------------