├── README.md
├── attacks
│   ├── __init__.py
│   ├── bandit.py
│   ├── decision.py
│   ├── hsj.py
│   ├── nes.py
│   ├── score.py
│   ├── signopt.py
│   └── simba.py
├── config.py
├── models
│   ├── __init__.py
│   ├── resnet18.py
│   └── vgg16.py
├── pics
│   └── framework.png
├── requirements.txt
├── trace_data_free.py
├── trace_data_limited.py
├── train.py
├── train_base_model.py
└── watermark.py
/README.md:
--------------------------------------------------------------------------------
### Identification of the Adversary from a Single Adversarial Example (ICML 2023)
This code is the official implementation of [Identification of the Adversary from a Single Adversarial Example](https://openreview.net/forum?id=HBrQI0tX8F).

----


### Abstract

Deep neural networks have been shown to be vulnerable to adversarial examples. Although many defense methods have been proposed to enhance robustness, there is still a long way to go before an attack-free method exists for building a trustworthy machine learning system. In this paper, instead of enhancing robustness, we take the investigator's perspective and propose a new framework to trace the first compromised model copy in a forensic investigation manner. Specifically, we focus on the following setting: the machine learning service provider provides model copies to a set of customers, and one of the customers conducts adversarial attacks to fool the system. The investigator's objective is therefore to identify the first compromised copy by collecting and analyzing evidence from only the available adversarial examples. To make the tracing viable, we design a random mask watermarking mechanism to differentiate adversarial examples generated from different copies. First, we propose a tracing approach for the data-limited case, where the original example is also available. Then, we design a data-free approach to identify the adversary without access to the original example. Finally, we evaluate the effectiveness of the proposed framework with extensive experiments on different model architectures, adversarial attacks, and datasets.

### Dependencies
- PyTorch == 1.12.1
- Torchvision == 0.13.1
- NumPy == 1.21.5
- Adversarial-Robustness-Toolbox == 1.10.3

### Pipeline
#### Pretraining
Use the following script to pre-train the ResNet18 base model on the CIFAR-10 dataset. For Tiny-ImageNet, you may need to download the dataset from this [link](http://cs231n.stanford.edu/tiny-imagenet-200.zip) and move it to your data directory.
```
python train_base_model.py --model_name ResNet18 --dataset_name CIFAR10
```
#### Watermarking
For each model copy, we split the base model into a head and a tail. The tail is shared by all users and kept frozen, while the head is fine-tuned with a copy-specific watermark. Here is a demo script for watermarking ResNet18 on the CIFAR-10 dataset.
```
python train.py --model_name ResNet18 --dataset_name CIFAR10
```
#### Tracing
You can use the following script to generate adversarial examples for each user. In our demo, we apply the [Bandit](https://arxiv.org/abs/1807.07978) attack and generate 10 adversarial examples for each of the 50 users (50 × 10 = 500 in total).
```
python -m attacks.bandit --model_name ResNet18 --dataset_name CIFAR10 -M 50 -n 10
```
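The attack scripts do not keep every perturbed image. Based on `attacks/bandit.py` and `attacks/hsj.py`, an example generated on copy `i` is saved only if all three of these hold:
- copy `i` classified the clean image correctly;
- the attack fools copy `i`;
- at least two copies in total (copy `i` included) are fooled.

The NumPy sketch below reproduces that mask on made-up predictions. `select_examples` is a hypothetical helper, and it takes predicted labels directly, whereas the scripts take the `argmax` of each classifier's output.

```python
import numpy as np


def select_examples(clean_pred_i, attacked_preds, labels, i):
    """Boolean mask of adversarial examples worth keeping for copy i.

    clean_pred_i:   (batch,) labels predicted by copy i on the clean images
    attacked_preds: (num_models, batch) labels predicted by every copy on the attacked images
    labels:         (batch,) ground-truth labels
    """
    correct = clean_pred_i == labels               # copy i was right before the attack
    fooled = attacked_preds != labels              # (num_models, batch)
    # copy i is fooled AND at least two copies (copy i included) are fooled
    success = np.logical_and(fooled[i], fooled.sum(axis=0) >= 2)
    return np.logical_and(correct, success)


labels = np.array([0, 1, 2, 3])
clean_pred_0 = np.array([0, 1, 2, 0])              # copy 0 already misclassifies example 3
attacked = np.array([
    [5, 1, 7, 9],                                  # copy 0 (the attacked copy)
    [5, 1, 2, 9],                                  # copy 1
    [0, 1, 7, 3],                                  # copy 2
])
mask = select_examples(clean_pred_0, attacked, labels, i=0)  # keeps examples 0 and 2
```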
We consider two tracing scenarios: the data-limited setting, where the original image is available, and the data-free setting, where it is not. The following script handles the data-limited case. Here we use only one adversarial example per user to identify the adversary.
```
python trace_data_limited.py --model_name ResNet18 --dataset_name CIFAR10 --alpha 0.9 --attack Bandit -M 50 -n 1
```
To trace in the data-free case, run:
```
python trace_data_free.py --model_name ResNet18 --dataset_name CIFAR10 --alpha 0.5 --attack Bandit -M 50 -n 1
```
### Citation

If you find our work interesting, please consider giving us a star :star: and citing it as:
```
@inproceedings{cheng2023identification,
  title={Identification of the adversary from a single adversarial example},
  author={Cheng, Minhao and Min, Rui and Sun, Haochen and Chen, Pin-Yu},
  booktitle={International Conference on Machine Learning},
  pages={5472--5484},
  year={2023},
  organization={PMLR}
}
```

--------------------------------------------------------------------------------
/attacks/__init__.py:
--------------------------------------------------------------------------------
"""
Implements handy numerical computational functions
"""
import numpy as np
import torch as ch
from torch.nn.modules import Upsample


def norm(t):
    """
    Return the l2 norm of a tensor (or numpy array) along all the dimensions except the first one.
    Zero norms are replaced by machine epsilon so that the result can safely be used as a divisor.
    :param t: batch_size x dim x .. tensor (or numpy)
    :return: batch_size x 1 x .. norms, broadcastable against t
    """
    _shape = t.shape
    batch_size = _shape[0]
    num_dims = len(_shape[1:])
    if ch.is_tensor(t):
        norm_t = ch.sqrt(t.pow(2).sum(dim=[_ for _ in range(1, len(_shape))])).view([batch_size] + [1] * num_dims)
        norm_t += (norm_t == 0).float() * np.finfo(np.float64).eps
        return norm_t
    else:
        _norm = np.linalg.norm(
            t.reshape([batch_size, -1]), axis=1, keepdims=True
        ).reshape([batch_size] + [1] * num_dims)
        return _norm + (_norm == 0) * np.finfo(np.float64).eps


def eg_step(x, g, lr):
    """
    Performs an exponentiated gradient step in the convex body [-1,1]
    :param x: batch_size x dim x .. tensor (or numpy) in [-1,1]
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return: updated x, still in [-1,1]
    """
    # from [-1,1] to [0,1]
    real_x = (x + 1.) / 2.
    if ch.is_tensor(x):
        pos = real_x * ch.exp(lr * g)
        neg = (1 - real_x) * ch.exp(-lr * g)
    else:
        pos = real_x * np.exp(lr * g)
        neg = (1 - real_x) * np.exp(-lr * g)
    new_x = pos / (pos + neg)
    # back from [0,1] to [-1,1]
    return new_x * 2 - 1


def step(x, g, lr):
    """
    Performs a step with no lp-ball constraints
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    return x + lr * g


def lp_step(x, g, lr, p):
    """
    performs lp step of x in the direction of g, where the norm is computed
    across all the dimensions except the first one (assuming it's the batch_size)
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :param p: 'inf' or '2'
    :return:
    """
    if p == 'inf':
        return linf_step(x, g, lr)
    elif p == '2':
        return l2_step(x, g, lr)
    else:
        raise Exception('Invalid p value')


def l2_step(x, g, lr):
    """
    performs l2 step of x in the direction of g, where the norm is computed
    across all the dimensions except the first one (assuming it's the batch_size)
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    # move each sample a distance lr along its unit-l2-normalized direction g / ||g||
    return x + lr * g / norm(g)


def linf_step(x, g, lr):
    """
    performs an l-inf step of x in the direction of g (a signed step of size lr in every coordinate)
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    if ch.is_tensor(x):
        return x + lr * ch.sign(g)
    else:
        return x + lr * np.sign(g)


def l2_proj_maker(xs, eps):
    """
    makes an l2 projection function such that new points
    are projected within the eps l2-balls centered around xs
    :param xs: centers of the balls, batch_size x dim x .. tensor (or numpy)
    :param eps: radius of the balls (np.inf disables the projection)
    :return: the projection function
    """
    if ch.is_tensor(xs):
        orig_xs = xs.clone()

        def proj(new_xs):
            delta = new_xs - orig_xs
            norm_delta = norm(delta)
            if np.isinf(eps):  # unbounded projection
                return orig_xs + delta
            else:
                return orig_xs + (norm_delta <= eps).float() * delta + (
                    norm_delta > eps).float() * eps * delta / norm_delta
    else:
        orig_xs = xs.copy()

        def proj(new_xs):
            delta = new_xs - orig_xs
            norm_delta = norm(delta)
            if np.isinf(eps):  # unbounded projection
                return orig_xs + delta
            else:
                return orig_xs + (norm_delta <= eps) * delta + (norm_delta > eps) * eps * delta / norm_delta
    return proj


def linf_proj_maker(xs, eps):
    """
    makes an l-inf projection function such that new points
    are projected within the eps l-inf-balls centered around xs
    :param xs: centers of the balls, batch_size x dim x .. tensor (or numpy)
    :param eps: radius of the balls
    :return: the projection function
    """
    if ch.is_tensor(xs):
        orig_xs = xs.clone()

        def proj(new_xs):
            return orig_xs + ch.clamp(new_xs - orig_xs, - eps, eps)
    else:
        orig_xs = xs.copy()

        def proj(new_xs):
            return np.clip(new_xs, orig_xs - eps, orig_xs + eps)
    return proj


def upsample_maker(target_h, target_w):
    """
    makes an upsampler which takes a tensor (or numpy array) of the form
    minibatch x channels x h x w and casts it to
    minibatch x channels x target_h x target_w
    :param target_h: int to specify the desired height
    :param target_w: int to specify the desired width
    :return:
    """
    _upsampler = Upsample(size=(target_h, target_w))

    def upsample_fct(xs):
        if ch.is_tensor(xs):
            return _upsampler(xs)
        else:
            return _upsampler(ch.from_numpy(xs)).numpy()

    return upsample_fct


def hamming_dist(a, b):
    """
    returns the hamming distance of a to b
    assumes a and b are in {+1, -1}
    :param a:
    :param b:
    :return:
    """
    assert np.all(np.abs(a) == 1.), "a should be in {+1,-1}"
    assert np.all(np.abs(b) == 1.), "b should be in {+1,-1}"
    return sum([_a != _b for _a, _b in zip(a, b)])


def sign(t, is_ns_sign=True):
    """
    Given a tensor t of `batch_size x dim` return the (non)standard sign of `t`
    based on the `is_ns_sign` flag
    :param t: tensor of `batch_size x dim`
    :param is_ns_sign: if True uses the non-standard sign function (maps 0 to +1)
    :return:
    """
    _sign_t = ch.sign(t) if ch.is_tensor(t) else np.sign(t)
    if is_ns_sign:
        _sign_t[_sign_t == 0.] = 1.
    return _sign_t


def noisy_sign(t, retain_p=1, crit='top', is_ns_sign=True):
    """
    returns a noisy version of the tensor `t` where
    only `retain_p` * 100 % of the coordinates retain their sign according
    to a `crit`.
    The noise is of the following effect
        sign(t) * x where x in {+1, -1}
    Thus, if sign(t) = 0, sign(t) * x is always 0 (in case of `is_ns_sign=False`)
    :param t: tensor of `batch_size x dim`
    :param retain_p: fraction of coordinates
    :param crit: 'top' keeps the signs of the largest-magnitude coordinates, 'random' keeps a random subset
    :param is_ns_sign: if True uses the non-standard sign function
    :return:
    """
    assert 0. <= retain_p <= 1., "retain_p value should be in [0,1]"

    _shape = t.shape
    t = t.reshape(_shape[0], -1)
    batch_size, dim = t.shape

    sign_t = sign(t, is_ns_sign=is_ns_sign)
    k = int(retain_p * dim)

    if k == 0:  # noise-ify all
        return (sign_t * np.sign((np.random.rand(batch_size, dim) < 0.5) - 0.5)).reshape(_shape)
    if k == dim:  # retain all
        return sign_t.reshape(_shape)

    # do topk otherwise
    noisy_sign_t = sign_t * np.sign((np.random.rand(*t.shape) < 0.5) - 0.5)
    _rows = np.zeros((batch_size, k), dtype=np.intp) + np.arange(batch_size)[:, None]
    if crit == 'top':
        _temp = np.abs(t)
    elif crit == 'random':
        _temp = np.random.rand(*t.shape)
    else:
        raise Exception('Unknown criterion for topk')

    _cols = np.argpartition(_temp, -k, axis=1)[:, -k:]
    noisy_sign_t[_rows, _cols] = sign_t[_rows, _cols]
    return noisy_sign_t.reshape(_shape)

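The step and projection helpers above keep the attacks inside the lp ball. The standalone NumPy sketch below restates `eg_step` and the l2/l-inf projections so that their invariants can be checked in isolation. It copies the NumPy branches rather than importing the package, and `l2_project` and `linf_project` are hypothetical names for what `l2_proj_maker` and `linf_proj_maker` return. It checks three invariants:
- the exponentiated-gradient step maps [-1, 1] back into [-1, 1];
- the l2 projection caps each sample's perturbation norm at `eps` and leaves short perturbations untouched;
- the l-inf projection caps every coordinate at `eps`.

```python
import numpy as np


def eg_step(x, g, lr):
    # map [-1, 1] -> [0, 1], apply a multiplicative (exponentiated-gradient) update, map back
    real_x = (x + 1.0) / 2.0
    pos = real_x * np.exp(lr * g)
    neg = (1 - real_x) * np.exp(-lr * g)
    return pos / (pos + neg) * 2 - 1


def l2_project(orig, new, eps):
    delta = new - orig
    n = np.linalg.norm(delta.reshape(len(delta), -1), axis=1)
    n = n.reshape([len(delta)] + [1] * (delta.ndim - 1))
    n = n + (n == 0) * np.finfo(np.float64).eps  # avoid division by zero, as in norm()
    # keep short perturbations, rescale long ones onto the sphere of radius eps
    return orig + (n <= eps) * delta + (n > eps) * eps * delta / n


def linf_project(orig, new, eps):
    return np.clip(new, orig - eps, orig + eps)


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2, 3, 4, 4))
g = rng.normal(size=x.shape)
y = eg_step(x, g, lr=5.0)

orig = np.zeros((2, 3, 4, 4))
new = rng.normal(size=orig.shape)
p2 = l2_project(orig, new, eps=1.0)
pinf = linf_project(orig, new, eps=0.1)
```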
--------------------------------------------------------------------------------
/attacks/bandit.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Upsample
from torchvision import transforms
import numpy as np
import os
import argparse
from art.estimators.classification import PyTorchClassifier

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark
from attacks.score import ScoreBlackBoxAttack
from attacks import *


Loss = nn.CrossEntropyLoss(reduction = 'none')

class BanditAttack(ScoreBlackBoxAttack):
    """
    Bandit Attack
    """

    def __init__(self,
                 max_loss_queries,
                 epsilon, p,
                 fd_eta, lr,
                 prior_exploration, prior_size, data_size, prior_lr,
                 lb, ub, batch_size, name):
        """
        :param max_loss_queries: maximum number of calls allowed to loss oracle per data pt
        :param epsilon: radius of lp-ball of perturbation
        :param p: specifies lp-norm of perturbation
        :param fd_eta: forward difference step
        :param lr: learning rate of the step in the data space
        :param prior_exploration: exploration noise
        :param prior_size: prior height/width (this is applicable only to images), you can disable it by setting it to
            None (it is assumed that prior_size = prior_height == prior_width)
        :param data_size: data height/width (applicable to images of the form `c x h x w`), you can ignore it
            by setting it to None; it is assumed that data_size = height = width
        :param prior_lr: learning rate in the prior space
        :param lb: data lower bound
        :param ub: data upper bound
        :param batch_size: number of data points attacked together
        :param name: name of the attack
        """
        super().__init__(max_extra_queries=np.inf,
                         max_loss_queries=max_loss_queries,
                         epsilon=epsilon,
                         p=p,
                         lb=lb,
                         ub=ub,
                         batch_size= batch_size,
                         name = "Bandit")
        # other algorithmic parameters
        self.fd_eta = fd_eta
        # learning rate
        self.lr = lr
        # data size
        self.data_size = data_size

        # prior setup:
        # 1. step function
        if self.p == '2':
            self.prior_step = step
        elif self.p == 'inf':
            self.prior_step = eg_step
        else:
            raise Exception("Invalid p for l-p constraint")
        # 2. prior placeholder
        self.prior = None
        # prior size
        self.prior_size = prior_size
        # prior exploration
        self.prior_exploration = prior_exploration
        # 3. prior upsampler
        self.prior_upsample_fct = None if self.prior_size is None else upsample_maker(data_size, data_size)
        self.prior_lr = prior_lr

    def _perturb(self, xs_t, loss_fct):
        """
        The core of the bandit algorithm.
        Since this is compute intensive, it is implemented with torch so that the ops run on the GPU (if available).
        :param xs_t: torch tensor with the current batch of inputs
        :param loss_fct: function mapping a batch of inputs to per-sample losses (to be maximized)
        :return: (new_xs, num_queries): the updated torch tensor and the number of loss
                 queries spent per data point (2)
        """

        _shape = list(xs_t.shape)
        eff_shape = list(xs_t.shape)
        # the upsampling assumes xs_t is batch_size x c x h x w; flat inputs
        # (e.g. MNIST as batch_size x dim) require prior_size=None

        if self.prior_size is None:
            prior_shape = eff_shape
        else:
            prior_shape = eff_shape[:-2] + [self.prior_size] * 2
        # reset the prior if xs is a new batch
        if self.is_new_batch:
            self.prior = torch.zeros(prior_shape, device = xs_t.device)
        # create noise for exploration, estimate the gradient, and take a PGD step
        # exp_noise = torch.randn(prior_shape) / (np.prod(prior_shape[1:]) ** 0.5) # according to the paper
        exp_noise = torch.randn(prior_shape, device = xs_t.device)
        # Query deltas for finite difference estimator
        if self.prior_size is None:
            q1 = step(self.prior, exp_noise, self.prior_exploration)
            q2 = step(self.prior, exp_noise, - self.prior_exploration)
        else:
            q1 = self.prior_upsample_fct(step(self.prior, exp_noise, self.prior_exploration))
            q2 = self.prior_upsample_fct(step(self.prior, exp_noise, - self.prior_exploration))
        # Loss points for finite difference estimator
        l1 = loss_fct(l2_step(xs_t, q1.view(_shape), self.fd_eta))
        l2 = loss_fct(l2_step(xs_t, q2.view(_shape), self.fd_eta))
        # finite differences estimate of directional derivative
        est_deriv = (l1 - l2) / (self.fd_eta * self.prior_exploration)
        # 2-query gradient estimate
        # Note: Ilyas' implementation multiplies the below by self.prior_exploration (different from pseudocode)
        # This should not affect the result as the `self.prior_lr` can be adjusted accordingly
        est_grad = est_deriv.view(-1, *[1] * len(prior_shape[1:])) * exp_noise
        # update prior with the estimated gradient:
        self.prior = self.prior_step(self.prior, est_grad, self.prior_lr)
        # gradient step in the data space
        if self.prior_size is None:
            gs = self.prior.clone()
        else:
            gs = self.prior_upsample_fct(self.prior)
        # perform the step
        new_xs = lp_step(xs_t, gs.view(_shape), self.lr, self.p)
        return new_xs, 2 * torch.ones(_shape[0], device = xs_t.device)

    def _config(self):
        return {
            "name": self.name,
            "p": self.p,
            "epsilon": self.epsilon,
            "lb": self.lb,
            "ub": self.ub,
            "max_extra_queries": "inf" if np.isinf(self.max_extra_queries) else self.max_extra_queries,
            "max_loss_queries": "inf" if np.isinf(self.max_loss_queries) else self.max_loss_queries,
            "lr": self.lr,
            "prior_lr": self.prior_lr,
            "prior_exploration": self.prior_exploration,
            "prior_size": self.prior_size,
            "data_size": self.data_size,
            "fd_eta": self.fd_eta,
            "attack_name": self.__class__.__name__
        }

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
    parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
    parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
    parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
    parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
    parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
    args = parser.parse_args()

    # renaming
    dataset = eval(f'config.{args.dataset_name}()')
    training_set, testing_set = dataset.training_set, dataset.testing_set
    num_classes = dataset.num_classes
    means, stds = dataset.means, dataset.stds
    C, H, W = dataset.C, dataset.H, dataset.W
    Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
    testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'

    save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'


    # input normalization, applied before the watermark
    normalizer = transforms.Normalize(means, stds)

    # load the shared tail and build one classifier per watermarked head
    classifiers = []
    models = []
    tail = Tail(num_classes)
    tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict', map_location = device))
    tail.to(device)
    for i in range(args.num_models):
        head = Head()
        head.to(device)
        head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict', map_location = device))
        watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
        models.append(nn.Sequential(normalizer, watermark, head, tail).eval())
        models[-1].to(device)

        classifier = PyTorchClassifier(
            model = models[-1],
            loss = None,
            optimizer = None,
            clip_values = (0, 1),
            input_shape=(C, H, W),
            nb_classes=num_classes,
            device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
        )
        classifiers.append(classifier)
    classifiers = np.array(classifiers)

    for i, (model, c) in enumerate(zip(models, classifiers)):
        if os.path.isfile(f'{save_dir}/head_{i}/Bandit.npz') and args.cont:
            continue
        original_images, attacked_images, labels = [], [], []
        count_success = 0
        for X, y in testing_loader:
            with torch.no_grad():
                pred = c.predict(X.numpy())
            correct_mask = pred.argmax(axis = -1) == y.numpy()

            X_device, y_device = X.to(device), y.to(device)
            def loss_fct(xs, es = False):
                logits = model(xs)
                loss = Loss(logits.to(device), y_device)
                if es:
                    return torch.argmax(logits, dim = -1) != y_device, loss
                else:
                    return loss

            def early_stop_crit_fct(xs):
                logits = model(xs)
                return logits.argmax(dim = -1) != y_device

            # data_size follows the dataset's image height instead of a hard-coded 32
            a = BanditAttack(max_loss_queries = 10000, epsilon = 1.0, p = '2', lb = 0.0, ub = 1.0, batch_size = args.batch_size, name = 'Bandit',
                             fd_eta = 0.01, lr = 0.01, prior_exploration = 0.1, prior_size = 20, data_size = H, prior_lr = 0.1)

            X_attacked = a.run(X_device, loss_fct, early_stop_crit_fct).cpu().numpy()

            # (num_models, batch_size, num_classes)
            attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers)

            # keep examples that fool this copy and at least one other copy
            success_mask = attacked_preds.argmax(axis = -1) != y.numpy()
            success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)

            mask = np.logical_and(correct_mask, success_mask)

            original_images.append(X.numpy()[mask])
            attacked_images.append(X_attacked[mask])
            labels.append(y.numpy()[mask])

            count_success += mask.sum()
            if count_success >= args.num_samples:
                print(f'Model {i}, attack Bandit, {count_success} out of {args.num_samples} generated, done!')
                break
            else:
                print(f'Model {i}, attack Bandit, {count_success} out of {args.num_samples} generated...')

        original_images = np.concatenate(original_images)
        attacked_images = np.concatenate(attacked_images)

        labels = np.concatenate(labels)
        os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
        np.savez(f'{save_dir}/head_{i}/Bandit.npz', X = original_images, X_attacked = attacked_images, y = labels)

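`BanditAttack._perturb` estimates a directional derivative from two loss queries, folds it into a running prior, and steps along that prior. The NumPy sketch below isolates the prior update on a toy linear loss whose true gradient is known, to show that the prior drifts toward the gradient direction.

It simplifies under stated assumptions:
- the l2 variant, so the prior is updated with the plain `step`;
- no prior upsampling (`prior_size=None`);
- the data-space step and the ball projection are omitted.

`toy_loss`, `w` and `bandit_prior` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 50
w = rng.normal(size=dim)  # toy loss L(x) = w . x, so the true gradient is w


def toy_loss(x):
    return x @ w


def bandit_prior(x, steps=200, fd_eta=0.01, exploration=0.1, prior_lr=0.1):
    prior = np.zeros(dim)
    for _ in range(steps):
        u = rng.normal(size=dim)                          # exploration noise
        q1 = prior + exploration * u                      # step(prior, u, +exploration)
        q2 = prior - exploration * u                      # step(prior, u, -exploration)
        # l2_step: move x a distance fd_eta along each unit-normalized query direction
        l1 = toy_loss(x + fd_eta * q1 / np.linalg.norm(q1))
        l2 = toy_loss(x + fd_eta * q2 / np.linalg.norm(q2))
        est_deriv = (l1 - l2) / (fd_eta * exploration)    # finite-difference derivative estimate
        prior = prior + prior_lr * est_deriv * u          # prior update with the plain step()
    return prior


prior = bandit_prior(np.zeros(dim))
cosine = prior @ w / (np.linalg.norm(prior) * np.linalg.norm(w))
```

The first update is always positively aligned with `w`. Later updates add, in expectation, the component of `w` orthogonal to the prior, so `cosine` should end clearly above zero. The exact value depends on the random draws.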
--------------------------------------------------------------------------------
/attacks/decision.py:
--------------------------------------------------------------------------------
from attacks import *

import numpy as np
import torch


class DecisionBlackBoxAttack(object):
    def __init__(self, max_queries=np.inf, epsilon=0.5, p='inf', lb=0., ub=1., batch_size=1):
        """
        :param max_queries: max number of calls to model per data point
        :param epsilon: perturbation limit according to lp-ball
        :param p: norm for the lp-ball constraint
        :param lb: minimum value data point can take in any coordinate
        :param ub: maximum value data point can take in any coordinate
        :param batch_size: number of data points attacked together
        """
        assert p in ['inf', '2'], "L-{} is not supported".format(p)

        self.p = p
        self.max_queries = max_queries
        self.total_queries = 0
        self.total_successes = 0
        self.total_failures = 0
        self.total_distance = 0
        self.sigma = 0
        self.EOT = 1
        self.lb = lb
        self.ub = ub
        self.epsilon = epsilon / ub
        self.batch_size = batch_size
        self.list_loss_queries = torch.zeros(1, self.batch_size)

    def result(self):
        """
        returns a summary of the attack results (to be tabulated)
        :return:
        """
        list_loss_queries = self.list_loss_queries[1:].view(-1)
        mask = list_loss_queries > 0
        list_loss_queries = list_loss_queries[mask]
        self.total_queries = int(self.total_queries)
        self.total_successes = int(self.total_successes)
        self.total_failures = int(self.total_failures)
        return {
            "total_queries": self.total_queries,
            "total_successes": self.total_successes,
            "total_failures": self.total_failures,
            "average_num_queries": "NaN" if self.total_successes == 0 else self.total_queries / self.total_successes,
            "failure_rate": "NaN" if self.total_successes + self.total_failures == 0 else self.total_failures / (self.total_successes + self.total_failures),
            "median_num_loss_queries": "NaN" if self.total_successes == 0 else torch.median(list_loss_queries).item(),
            "config": self._config()
        }

    def _config(self):
        """
        return the attack's parameter configurations as a dict
        :return:
        """
        raise NotImplementedError

    def distance(self, x_adv, x = None):
        if x is None:
            diff = x_adv.view(x_adv.shape[0], -1)
        else:
            diff = (x_adv - x).view(x.shape[0], -1)
        if self.p == '2':
            out = torch.sqrt(torch.sum(diff * diff, dim = 1))
        elif self.p == 'inf':
            out, _ = torch.max(torch.abs(diff), dim = 1)
        return out

    def is_adversarial(self, x, y):
        '''
        check whether the adversarial constraint holds for x
        '''
        if self.targeted:
            return self.predict_label(x) == y
        else:
            return self.predict_label(x) != y

    def predict_label(self, xs):
        with torch.no_grad():
            if isinstance(xs, torch.Tensor):
                out = self.model(xs).argmax(dim=-1).squeeze()
            else:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                out = self.model(torch.FloatTensor(xs).to(device)).argmax(dim=-1).squeeze()
        return out

    def _perturb(self, xs_t, ys):
        raise NotImplementedError

    def run(self, Xs, ys, model, targeted, dset):
        self.model = model
        self.targeted = targeted

        X_attacked = []

        for x, y in zip(Xs, ys):
            adv, _ = self._perturb(x[None, ...], y[None])
            X_attacked.append(adv.squeeze())
        X_attacked = torch.stack(X_attacked).float()

        # keep an adversarial example only if it lies within the epsilon-ball; otherwise fall back to the original
        success = (self.distance(X_attacked, Xs) < self.epsilon)

        return X_attacked * success[:, None, None, None] + Xs * (~success[:, None, None, None])

--------------------------------------------------------------------------------
/attacks/hsj.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import os
import argparse
from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import HopSkipJump

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
    parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
    parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
    parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
    parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
    parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
    parser.add_argument('-v', '--verbose', help = 'Verbose when attacking.', action = 'store_true')
    args = parser.parse_args()

    dataset = eval(f'config.{args.dataset_name}()')
    training_set, testing_set = dataset.training_set, dataset.testing_set
    num_classes = dataset.num_classes
    means, stds = dataset.means, dataset.stds
    C, H, W = dataset.C, dataset.H, dataset.W
    Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
    testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'

    save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'


    # input normalizer and the shared tail of the model
    normalizer = transforms.Normalize(means, stds)
    tail = Tail(num_classes)
    tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict', map_location = device))
    tail.to(device)

    # load the classifiers
    classifiers = []
    models = []
    for i in range(args.num_models):
        head = Head()
        head.to(device)
        head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict', map_location = device))
        watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
        models.append(nn.Sequential(normalizer, watermark, head, tail, nn.Softmax(dim = -1)).eval())
        models[-1].to(device)
        # reuse the same model object instead of building a second Sequential from the same modules
        classifier = PyTorchClassifier(
            model = models[-1],
            loss = None,
            optimizer = None,
            clip_values = (0, 1),
            input_shape=(C, H, W),
            nb_classes=num_classes,
            device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
        )
        classifiers.append(classifier)
    classifiers = np.array(classifiers)

    # attacking
    for i, (model, c) in enumerate(zip(models, classifiers)):
        # check the same file name that is saved below, so that --cont actually skips finished heads
        if os.path.isfile(f'{save_dir}/head_{i}/HopSkipJump.npz') and args.cont:
            continue
        a = HopSkipJump(c, verbose = args.verbose)

        original_images, attacked_images, labels = [], [], []
        count_success = 0

        for X, y in testing_loader:
            X, y = X.numpy(), y.numpy()
            pred = c.predict(X)
            correct_mask = pred.argmax(axis = 1) == y

            X_attacked = a.generate(X)
            attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers) # (num_models, batch_size, num_classes)
            # keep examples that fool this copy and at least one other copy
            success_mask = attacked_preds.argmax(axis = -1) != y
            success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)
            mask = np.logical_and(correct_mask, success_mask)

            original_images.append(X[mask])
            attacked_images.append(X_attacked[mask])
            labels.append(y[mask])

            count_success += mask.sum()
            if count_success >= args.num_samples:
                print(f'Head {i}, attack HopSkipJump, {count_success} out of {args.num_samples} generated, done!')
                break
            else:
                print(f'Head {i}, attack HopSkipJump, {count_success} out of {args.num_samples} generated...')

        original_images = np.concatenate(original_images)
        attacked_images = np.concatenate(attacked_images)
        labels = np.concatenate(labels)
        os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
        np.savez(f'{save_dir}/head_{i}/HopSkipJump.npz', X = original_images, X_attacked = attacked_images, y = labels)

--------------------------------------------------------------------------------
/attacks/nes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import os
6 | import argparse
7 | from art.estimators.classification import PyTorchClassifier
8 |
9 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
10 | import config
11 | from watermark import Watermark
12 | from attacks.score import ScoreBlackBoxAttack
13 | from attacks import *
14 |
15 | Loss = nn.CrossEntropyLoss(reduction = 'none')
16 |
17 | class NESAttack(ScoreBlackBoxAttack):
18 | """
19 | NES Attack
20 | """
21 |
22 | def __init__(self, max_loss_queries, epsilon, p, fd_eta, lr, q, lb, ub, batch_size, name):
23 | """
24 | :param max_loss_queries: maximum number of calls allowed to loss oracle per data pt
25 | :param epsilon: radius of lp-ball of perturbation
26 | :param p: specifies lp-norm of perturbation
27 | :param fd_eta: forward difference step
28 | :param lr: learning rate of NES step
29 | :param q: number of noise samples per NES step
30 | :param lb: data lower bound
31 | :param ub: data upper bound
32 | """
33 | super().__init__(max_extra_queries=np.inf,
34 | max_loss_queries=max_loss_queries,
35 | epsilon=epsilon,
36 | p=p,
37 | lb=lb,
38 | ub=ub,
39 | batch_size= batch_size,
40 | name = name)
41 | self.q = q
42 | self.fd_eta = fd_eta
43 | self.lr = lr
44 |
45 | def _perturb(self, xs_t, loss_fct):
46 | _shape = list(xs_t.shape)
47 | dim = np.prod(_shape[1:])
48 | num_axes = len(_shape[1:])
49 | gs_t = torch.zeros_like(xs_t)
50 | for _ in range(self.q):
51 | # exp_noise = torch.randn_like(xs_t) / (dim ** 0.5)
52 | exp_noise = torch.randn_like(xs_t)
53 | fxs_t = xs_t + self.fd_eta * exp_noise
54 | bxs_t = xs_t - self.fd_eta * exp_noise
55 | est_deriv = (loss_fct(fxs_t) - loss_fct(bxs_t)) / (4. * self.fd_eta)
56 | gs_t += est_deriv.reshape(-1, *[1] * num_axes) * exp_noise
57 | # perform the step
58 | new_xs = lp_step(xs_t, gs_t, self.lr, self.p)
59 | return new_xs, 2 * self.q * torch.ones(_shape[0], device = xs_t.device)
60 |
61 | def _config(self):
62 | return {
63 | "name": self.name,
64 | "p": self.p,
65 | "epsilon": self.epsilon,
66 | "lb": self.lb,
67 | "ub": self.ub,
68 | "max_extra_queries": "inf" if np.isinf(self.max_extra_queries) else self.max_extra_queries,
69 | "max_loss_queries": "inf" if np.isinf(self.max_loss_queries) else self.max_loss_queries,
70 | "lr": self.lr,
71 | "q": self.q,
72 | "fd_eta": self.fd_eta,
73 | "attack_name": self.__class__.__name__
74 | }
75 |
76 | if __name__ == '__main__':
77 | parser = argparse.ArgumentParser()
78 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
79 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
80 | parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
81 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
82 | parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
83 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
84 | args = parser.parse_args()
85 |
86 | # renaming
87 | dataset = eval(f'config.{args.dataset_name}()')
88 | training_set, testing_set = dataset.training_set, dataset.testing_set
89 | num_classes = dataset.num_classes
90 | means, stds = dataset.means, dataset.stds
91 | C, H, W = dataset.C, dataset.H, dataset.W
92 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
93 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)
94 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
95 |
96 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'
97 |
98 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'
99 |
100 |
101 | # load the tail of the model
102 | normalizer = transforms.Normalize(means, stds)
103 | tail = Tail(num_classes)
104 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
105 | tail.to(device)
106 |
107 | # load the classifiers
108 | classifiers = []
109 | models = []
110 | for i in range(args.num_models):
111 | head = Head()
112 | head.to(device)
113 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
114 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
115 |
116 | models.append(nn.Sequential(normalizer, watermark, head, tail).eval())
117 | models[-1].to(device)
118 |
119 | classifier = PyTorchClassifier(
120 | model = models[-1],
121 | loss = None,
122 | optimizer = None,
123 | clip_values = (0, 1),
124 | input_shape=(C, H, W),
125 | nb_classes=num_classes,
126 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
127 | )
128 | classifiers.append(classifier)
129 | classifiers = np.array(classifiers)
130 |
131 | for i, (model, c) in enumerate(zip(models, classifiers)):
132 | if os.path.isfile(f'{save_dir}/head_{i}/NES.npz') and args.cont:
133 | continue
134 | original_images, attacked_images, labels = [], [], []
135 | count_success = 0
136 | for X, y in testing_loader:
137 | with torch.no_grad():
138 | pred = c.predict(X.numpy())
139 | correct_mask = pred.argmax(axis = -1) == y.numpy()
140 |
141 | X_device, y_device = X.to(device), y.to(device)
142 | def loss_fct(xs, es = False):
143 | logits = model(xs)
144 | loss = Loss(logits.to(device), y_device)
145 | if es:
146 | return torch.argmax(logits, axis= -1) != y_device, loss
147 | else:
148 | return loss
149 |
150 | def early_stop_crit_fct(xs):
151 | logits = model(xs)
152 | return logits.argmax(axis = -1) != y_device
153 |
154 | a = NESAttack(max_loss_queries = 10000, epsilon = 1.0, p = '2', fd_eta = 0.01, lr = 0.01, q = 15, lb = 0.0, ub = 1.0, batch_size = args.batch_size, name = 'NESAttack')
155 |
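156 | with torch.no_grad(): # NES only needs loss values, so avoid accumulating autograd graphs across iterations
157 | X_attacked = a.run(X_device, loss_fct, early_stop_crit_fct).cpu().numpy()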
158 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers)
159 |
160 | success_mask = attacked_preds.argmax(axis = -1) != y.numpy()
161 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)
162 |
163 | mask = np.logical_and(correct_mask, success_mask)
164 |
165 | original_images.append(X[mask])
166 | attacked_images.append(X_attacked[mask])
167 | labels.append(y[mask])
168 |
169 | count_success += mask.sum()
170 | if count_success >= args.num_samples:
171 | print(f'Model {i}, attack NES, {count_success} out of {args.num_samples} generated, done!')
172 | break
173 | else:
174 | print(f'Model {i}, attack NES, {count_success} out of {args.num_samples} generated...')
175 |
176 | original_images = np.concatenate(original_images)
177 | attacked_images = np.concatenate(attacked_images)
178 | labels = np.concatenate(labels)
179 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
180 | np.savez(f'{save_dir}/head_{i}/NES.npz', X = original_images, X_attacked = attacked_images, y = labels)
181 |
--------------------------------------------------------------------------------
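`NESAttack._perturb` estimates the loss gradient with antithetic Gaussian sampling. For each of the `q` directions `u` it queries the loss at `x + fd_eta*u` and at `x - fd_eta*u`, which is why every step is charged `2*q` queries. The sketch below restates the estimator for a single NumPy input (`nes_gradient` is an illustrative name, not part of the repository). It uses the textbook central-difference denominator `2*fd_eta` and averages over the samples; the repository divides by `4*fd_eta` and sums instead, which only rescales the estimate by a positive constant and leaves its direction unchanged.

```python
import numpy as np


def nes_gradient(loss_fn, x, fd_eta=0.01, q=15, rng=None):
    """Antithetic (two-sided) NES estimate of the gradient of loss_fn at x (sketch).

    Each of the q Gaussian directions costs two loss queries.
    """
    rng = np.random.default_rng() if rng is None else rng
    grad = np.zeros_like(x, dtype=float)
    for _ in range(q):
        u = rng.standard_normal(x.shape)
        # central difference along u approximates the directional derivative <grad, u>
        deriv = (loss_fn(x + fd_eta * u) - loss_fn(x - fd_eta * u)) / (2.0 * fd_eta)
        grad += deriv * u
    # for u ~ N(0, I), E[<grad, u> u] = grad, so the average is an unbiased estimate
    return grad / q
```

On the quadratic loss `0.5*||x||^2`, whose gradient is `x` itself, the estimate should point almost exactly along `x` once `q` is large.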
/attacks/score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | from torch import Tensor as t
6 |
7 | from attacks import *
8 |
9 | class ScoreBlackBoxAttack(object):
10 | def __init__(self, max_loss_queries=np.inf,
11 | max_extra_queries=np.inf,
12 | epsilon=0.5, p='inf', lb=0., ub=1., batch_size = 50, name = '', device = 'cuda' if torch.cuda.is_available() else 'cpu'):
13 | """
14 | :param max_loss_queries: max number of calls to model per data point
15 | :param max_extra_queries: max number of calls to the early stopping criterion per data point
16 | :param epsilon: perturbation limit according to lp-ball
17 | :param p: norm for the lp-ball constraint
18 | :param lb: minimum value data point can take in any coordinate
19 | :param ub: maximum value data point can take in any coordinate
20 | """
21 | assert p in ['inf', '2'], "L-{} is not supported".format(p)
22 |
23 | self.epsilon = epsilon
24 | self.p = p
25 | self.batch_size = batch_size
26 | self.max_loss_queries = max_loss_queries
27 | self.max_extra_queries = max_extra_queries
28 | self.list_loss_queries = torch.zeros(1, self.batch_size, device = device)
29 | self.total_loss_queries = 0
30 | self.total_extra_queries = 0
31 | self.total_successes = 0
32 | self.total_failures = 0
33 | self.lb = lb
34 | self.ub = ub
35 | self.name = name
36 | # the _proj method takes pts and project them into the constraint set:
37 | # which are
38 | # 1. epsilon lp-ball around xs
39 | # 2. valid data pt range [lb, ub]
40 | # it is meant to be used within `self.run` and `self._perturb`
41 | self._proj = None
42 | # a handy flag for _perturb method to denote whether the provided xs is a
43 | # new batch (i.e. the first iteration within `self.run`)
44 | self.is_new_batch = False
45 |
46 | def result(self):
47 | """
48 | returns a summary of the attack results (to be tabulated)
49 | :return:
50 | """
51 | list_loss_queries = self.list_loss_queries[1:].view(-1)
52 | mask = list_loss_queries > 0
53 | list_loss_queries = list_loss_queries[mask]
54 | self.total_loss_queries = int(self.total_loss_queries)
55 | self.total_extra_queries = int(self.total_extra_queries)
56 | self.total_successes = int(self.total_successes)
57 | self.total_failures = int(self.total_failures)
58 | return {
59 | "total_loss_queries": self.total_loss_queries,
60 | "total_extra_queries": self.total_extra_queries,
61 | "average_num_loss_queries": "NaN" if self.total_successes == 0 else self.total_loss_queries / self.total_successes,
62 | "average_num_extra_queries": "NaN" if self.total_successes == 0 else self.total_extra_queries / self.total_successes,
63 | "median_num_loss_queries": "NaN" if self.total_successes == 0 else torch.median(list_loss_queries).item(),
64 | "total_queries": self.total_extra_queries + self.total_loss_queries,
65 | "average_num_queries": "NaN" if self.total_successes == 0 else (self.total_extra_queries + self.total_loss_queries) / self.total_successes,
66 | "total_successes": self.total_successes,
67 | "total_failures": self.total_failures,
68 | "failure_rate": "NaN" if self.total_successes + self.total_failures == 0 else self.total_failures / (self.total_successes + self.total_failures),
69 | "config": self._config()
70 | }
71 |
72 | def _config(self):
73 | """
74 | return the attack's parameter configurations as a dict
75 | :return:
76 | """
77 | raise NotImplementedError
78 |
79 | def _perturb(self, xs_t, loss_fct):
80 | """
81 | :param xs_t: batch_size x dim x .. (torch tensor)
82 | :param loss_fct: function to query (the attacker would like to maximize it) (batch_size data pts -> R^{batch_size})
83 | :return: suggested xs (torch tensor) and the used number of queries per data point
84 | i.e. a tuple of (batch_size x dim x .. tensor, batch_size array of number queries used)
85 | """
86 | raise NotImplementedError
87 |
88 | def proj_replace(self, xs_t, sugg_xs_t, dones_mask_t):
89 | sugg_xs_t = self._proj(sugg_xs_t)
90 | # replace xs only if not done
91 | xs_t = sugg_xs_t * (1. - dones_mask_t) + xs_t * dones_mask_t
92 | return xs_t
93 |
94 | def run(self, xs, loss_fct, early_stop_extra_fct):
95 | """
96 | attack with `xs` as data points using the oracle `loss_fct` and
97 | the early stopping criterion `early_stop_extra_fct`
98 | :param xs: data points to be perturbed adversarially (torch tensor)
99 | :param loss_fct: loss function (m data pts -> R^m)
100 | :param early_stop_extra_fct: early stop function (m data pts -> {0,1}^m)
101 | ith entry is 1 if the ith data point is misclassified
102 | :return: the perturbed data points (torch tensor with the same shape as xs)
103 | """
104 | # convert to tensor
105 | xs_t = torch.clone(xs)
106 |
107 | batch_size = xs.shape[0]
108 | num_axes = len(xs.shape[1:])
109 | num_loss_queries = torch.zeros(batch_size, device = xs.device)
110 | num_extra_queries = torch.zeros(batch_size, device = xs.device)
111 |
112 | dones_mask = early_stop_extra_fct(xs_t)
113 | correct_classified_mask = ~dones_mask
114 |
115 | # init losses for performance tracking
116 | losses = torch.zeros(batch_size, device = xs.device)
117 |
118 | # make a projector into xs lp-ball and within valid pixel range
119 | if self.p == '2':
120 | _proj = l2_proj_maker(xs_t, self.epsilon)
121 | self._proj = lambda _: torch.clamp(_proj(_), self.lb, self.ub)
122 | elif self.p == 'inf':
123 | _proj = linf_proj_maker(xs_t, self.epsilon)
124 | self._proj = lambda _: torch.clamp(_proj(_), self.lb, self.ub)
125 | else:
126 | raise Exception('Undefined l-p!')
127 |
128 | # iterate till model evasion or budget exhaustion
129 | self.is_new_batch = True
130 | its = 0
131 | while True:
132 | # if np.any(num_loss_queries + num_extra_queries >= self.max_loss_queries):
133 | if torch.any(num_loss_queries >= self.max_loss_queries):
134 | print("#loss queries exceeded budget, exiting")
135 | break
136 | if torch.any(num_extra_queries >= self.max_extra_queries):
137 | print("#extra_queries exceeded budget, exiting")
138 | break
139 | if torch.all(dones_mask):
140 | print("all data pts are misclassified, exiting")
141 | break
142 | # propose new perturbations
143 | sugg_xs_t, num_loss_queries_per_step = self._perturb(xs_t, loss_fct)
144 | # project around xs and within pixel range and
145 | # replace xs only if not done
146 | ##updated x here
147 | xs_t = self.proj_replace(xs_t, sugg_xs_t, (dones_mask.reshape(-1, *[1] * num_axes).float()))
148 |
149 | # update number of queries (note this is done before updating dones_mask)
150 | num_loss_queries += num_loss_queries_per_step * (~dones_mask)
151 | num_extra_queries += (~dones_mask)
152 | losses = loss_fct(xs_t) * (~dones_mask) + losses * dones_mask
153 |
154 | # update dones mask
155 | dones_mask = dones_mask | early_stop_extra_fct(xs_t)
156 | success_mask = dones_mask * correct_classified_mask
157 | its += 1
158 |
159 | self.is_new_batch = False
160 |
161 |
162 | success_mask = dones_mask * correct_classified_mask
163 | self.total_loss_queries += (num_loss_queries * success_mask).sum()
164 | self.total_extra_queries += (num_extra_queries * success_mask).sum()
165 | self.list_loss_queries = torch.cat([self.list_loss_queries, torch.zeros(1, batch_size, device = xs.device)], dim=0)
166 | self.list_loss_queries[-1] = num_loss_queries * success_mask
167 | self.total_successes += success_mask.sum()
168 | self.total_failures += ((~dones_mask) * correct_classified_mask).sum()
169 |
170 | # reset self._proj so that it is only used within `self.run`
171 | self._proj = None
172 |
173 | return xs_t
174 |
--------------------------------------------------------------------------------
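`ScoreBlackBoxAttack.run` keeps every iterate inside the perturbation budget by composing an L_p projection around the clean input with a clamp to `[lb, ub]`, and `proj_replace` accepts a proposal only for examples that are not yet misclassified. The projector itself comes from `l2_proj_maker` / `linf_proj_maker`, which are imported from the `attacks` package and not shown here, so the NumPy sketch below assumes a standard Euclidean projection for the `p = '2'` case. Both function names in it are hypothetical.

```python
import numpy as np


def l2_project(x, x0, eps, lb=0.0, ub=1.0):
    """Project each row of x onto the L2 ball of radius eps around x0, then clip to [lb, ub] (sketch)."""
    delta = (x - x0).reshape(len(x), -1)
    norms = np.linalg.norm(delta, axis=1, keepdims=True)
    # shrink only the rows that leave the ball; rows already inside keep scale 1
    scale = np.minimum(1.0, eps / np.maximum(norms, 1e-12))
    projected = x0 + (delta * scale).reshape(x.shape)
    return np.clip(projected, lb, ub)


def proj_replace(xs, sugg_xs, dones, x0, eps):
    """Accept the projected proposal only for examples that are not finished yet (sketch)."""
    sugg_xs = l2_project(sugg_xs, x0, eps)
    # broadcast the per-example done flag over all non-batch axes
    keep = dones.reshape(-1, *[1] * (xs.ndim - 1)).astype(xs.dtype)
    return sugg_xs * (1.0 - keep) + xs * keep
```

Clamping after the projection cannot push a point back outside the ball: the clean input already lies in `[lb, ub]`, and clamping moves every out-of-range coordinate toward it.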
/attacks/signopt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import os
6 | import argparse
7 | from scipy.linalg import qr
8 | from art.estimators.classification import PyTorchClassifier
9 |
10 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
11 | import config
12 | from watermark import Watermark
13 | from attacks.decision import DecisionBlackBoxAttack
14 |
15 | if torch.cuda.is_available():
16 | t = lambda z: torch.tensor(data = z).cuda()
17 | else:
18 | t = lambda z: torch.tensor(data = z)
19 |
20 | start_learning_rate = 1.0
21 |
22 | def quad_solver(Q, b):
23 | """
24 | Solve min_a 0.5 * a^T Q a + b^T a s.t. a >= 0
25 | """
26 | K = Q.shape[0]
27 | alpha = torch.zeros((K,), dtype = Q.dtype, device = Q.device)
28 | g = b
29 | Qdiag = torch.diag(Q)
30 | for _ in range(20000):
31 | delta = torch.clamp(alpha - g/Qdiag, min = 0) - alpha
32 | idx = torch.argmax(torch.abs(delta))
33 | val = delta[idx]
34 | if abs(val) < 1e-7:
35 | break
36 | g = g + val*Q[:,idx]
37 | alpha[idx] += val
38 | return alpha
39 |
40 | def sign(y):
41 | """
42 | y -- torch tensor of shape (m,)
43 | Returns an element-wise indication of the sign of a number.
44 | The sign function returns -1 if y < 0, 1 if y >= 0. nan is returned for nan inputs.
45 | """
46 | y_sign = torch.sign(y)
47 | y_sign[y_sign==0] = 1
48 | return y_sign
49 |
50 |
51 | class SignOPTAttack(DecisionBlackBoxAttack):
52 | """
53 | Sign_OPT
54 | """
55 |
56 | def __init__(self, epsilon, p, alpha, beta, svm, momentum, max_queries, k, lb, ub, batch_size, sigma):
57 | super().__init__(max_queries = max_queries,
58 | epsilon=epsilon,
59 | p=p,
60 | lb=lb,
61 | ub=ub,
62 | batch_size = batch_size)
63 | self.alpha = alpha
64 | self.beta = beta
65 | self.svm = svm
66 | self.momentum = momentum
67 | self.k = k
68 | self.sigma = sigma
69 | self.query_count = 0
70 |
71 |
72 | def _config(self):
73 | return {
74 | "p": self.p,
75 | "epsilon": self.epsilon,
76 | "lb": self.lb,
77 | "ub": self.ub,
78 | "attack_name": self.__class__.__name__
79 | }
80 |
81 | def attack_untargeted(self, x0, y0, alpha = 0.2, beta = 0.001):
82 | """
83 | Attack the original image and return adversarial example
84 | """
85 |
86 | y0 = y0[0]
87 | self.query_count = 0
88 |
89 | # Calculate a good starting point.
90 | num_directions = 10
91 | best_theta, g_theta = None, float('inf')
92 |
93 | for i in range(num_directions):
94 | self.query_count += 1
95 | theta = torch.randn_like(x0)
96 | if self.predict_label(x0+theta)!=y0:
97 | initial_lbd = torch.norm(theta)
98 | theta /= initial_lbd
99 | lbd, count = self.fine_grained_binary_search(x0, y0, theta, initial_lbd, g_theta)
100 | self.query_count += count
101 | if lbd < g_theta:
102 | best_theta, g_theta = theta, lbd
103 |
104 | if g_theta == float('inf'):
105 | return x0, self.query_count
106 |
107 | # Begin Gradient Descent.
108 | xg, gg = best_theta, g_theta
109 | vg = torch.zeros_like(xg)
110 |
111 | assert not self.svm
112 | for i in range(1500):
113 | sign_gradient, grad_queries = self.sign_grad_v1(x0, y0, xg, initial_lbd=gg, h=beta)
114 | self.query_count += grad_queries
115 | # Line search
116 | min_theta = xg
117 | min_g2 = gg
118 | min_vg = vg
119 | for _ in range(15):
120 | if self.momentum > 0:
121 | new_vg = self.momentum*vg - alpha*sign_gradient
122 | new_theta = xg + new_vg
123 | else:
124 | new_theta = xg - alpha * sign_gradient
125 | new_theta /= torch.norm(new_theta)
126 | new_g2, count = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd = min_g2, tol=beta/500)
127 | self.query_count += count
128 | alpha = alpha * 2
129 | if new_g2 < min_g2:
130 | min_theta = new_theta
131 | min_g2 = new_g2
132 | if self.momentum > 0:
133 | min_vg = new_vg
134 | else:
135 | break
136 | if min_g2 >= gg:
137 | for _ in range(15):
138 | alpha = alpha * 0.25
139 | if self.momentum > 0:
140 | new_vg = self.momentum*vg - alpha*sign_gradient
141 | new_theta = xg + new_vg
142 | else:
143 | new_theta = xg - alpha * sign_gradient
144 | new_theta /= torch.norm(new_theta)
145 | new_g2, count = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd = min_g2, tol=beta/500)
146 | self.query_count += count
147 | if new_g2 < gg:
148 | min_theta = new_theta
149 | min_g2 = new_g2
150 | if self.momentum > 0:
151 | min_vg = new_vg
152 | break
153 | if alpha < 1e-4:
154 | alpha = 1.0
155 | beta = beta*0.1
156 | if (beta < 1e-8):
157 | break
158 |
159 | xg, gg = min_theta, min_g2
160 | vg = min_vg
161 |
162 |
163 | if self.query_count > self.max_queries:
164 | break
165 |
166 | dist = self.distance(gg*xg)
167 | if dist < self.epsilon:
168 | break
169 |
170 | dist = self.distance(gg*xg)
171 | return x0 + gg*xg, self.query_count
172 |
173 | def sign_grad_v1(self, x0, y0, theta, initial_lbd, h=0.001, D=4, target=None):
174 | r"""
175 | Evaluate the sign of the gradient by the formula
176 | sign(g) = 1/Q [ \sum_{q=1}^Q sign( g(theta+h*u_q) - g(theta) ) u_q ]
177 | """
178 | K = self.k
179 | sign_grad = torch.zeros_like(theta)
180 | queries = 0
181 | for _ in range(K):
182 | u = torch.randn_like(theta)
183 | u /= torch.norm(u)
184 |
185 | sign = 1
186 | new_theta = theta + h*u
187 | new_theta /= torch.norm(new_theta)
188 |
189 | # Targeted case.
190 | if (target is not None and
191 | self.predict_label(x0+initial_lbd*new_theta) == target):
192 | sign = -1
193 |
194 | # Untargeted case
195 | if (target is None and
196 | self.predict_label(x0+t(initial_lbd*new_theta)) != y0):
197 | sign = -1
198 | queries += 1
199 | sign_grad += u*sign
200 |
201 | sign_grad /= K
202 |
203 | return sign_grad, queries
204 |
205 | def fine_grained_binary_search_local(self, x0, y0, theta, initial_lbd = 1.0, tol=1e-5):
206 | nquery = 0
207 | lbd = initial_lbd
208 |
209 | if self.predict_label(x0+lbd*theta) == y0:
210 | lbd_lo = lbd
211 | lbd_hi = lbd*1.01
212 | nquery += 1
213 | while self.predict_label(x0+lbd_hi*theta) == y0:
214 | lbd_hi = lbd_hi*1.01
215 | nquery += 1
216 | if lbd_hi > 20:
217 | return float('inf'), nquery
218 | else:
219 | lbd_hi = lbd
220 | lbd_lo = lbd*0.99
221 | nquery += 1
222 | while self.predict_label(x0+lbd_lo*theta) != y0 :
223 | lbd_lo = lbd_lo*0.99
224 | nquery += 1
225 | if nquery + self.query_count> self.max_queries:
226 | break
227 |
228 | while (lbd_hi - lbd_lo) > tol:
229 | lbd_mid = (lbd_lo + lbd_hi)/2.0
230 | nquery += 1
231 | if nquery + self.query_count> self.max_queries:
232 | break
233 | if self.predict_label(x0 + lbd_mid*theta) != y0:
234 | lbd_hi = lbd_mid
235 | else:
236 | lbd_lo = lbd_mid
237 | return lbd_hi, nquery
238 |
239 | def fine_grained_binary_search(self, x0, y0, theta, initial_lbd, current_best):
240 | nquery = 0
241 | if initial_lbd > current_best:
242 | if self.predict_label(x0+t(current_best*theta)) == y0:
243 | nquery += 1
244 | return float('inf'), nquery
245 | lbd = current_best
246 | else:
247 | lbd = initial_lbd
248 |
249 | lbd_hi = lbd
250 | lbd_lo = 0.0
251 |
252 | while (lbd_hi - lbd_lo) > 1e-5:
253 | lbd_mid = (lbd_lo + lbd_hi)/2.0
254 | nquery += 1
255 | if nquery + self.query_count> self.max_queries:
256 | break
257 | if self.predict_label(x0 + t(lbd_mid*theta)) != y0:
258 | lbd_hi = lbd_mid
259 | else:
260 | lbd_lo = lbd_mid
261 | return lbd_hi, nquery
262 |
263 | def fine_grained_binary_search_local_targeted(self, x0, t, theta, initial_lbd=1.0, tol=1e-5):
264 | nquery = 0
265 | lbd = initial_lbd
266 |
267 | if self.predict_label(x0 + lbd*theta) != t: # here `t` is the target label (it shadows the module-level `t` helper), so it must not be called
268 | lbd_lo = lbd
269 | lbd_hi = lbd*1.01
270 | nquery += 1
271 | while self.predict_label(x0 + lbd_hi*theta) != t:
272 | lbd_hi = lbd_hi*1.01
273 | nquery += 1
274 | if lbd_hi > 100:
275 | return float('inf'), nquery
276 | else:
277 | lbd_hi = lbd
278 | lbd_lo = lbd*0.99
279 | nquery += 1
280 | while self.predict_label(x0 + lbd_lo*theta) == t:
281 | lbd_lo = lbd_lo*0.99
282 | nquery += 1
283 |
284 | while (lbd_hi - lbd_lo) > tol:
285 | lbd_mid = (lbd_lo + lbd_hi)/2.0
286 | nquery += 1
287 | if self.predict_label(x0 + lbd_mid*theta) == t:
288 | lbd_hi = lbd_mid
289 | else:
290 | lbd_lo = lbd_mid
291 |
292 | return lbd_hi, nquery
293 |
294 | def fine_grained_binary_search_targeted(self, x0, t, theta, initial_lbd, current_best):
295 | nquery = 0
296 | if initial_lbd > current_best:
297 | if self.predict_label(x0 + current_best*theta) != t: # here `t` is the target label (it shadows the module-level `t` helper), so it must not be called
298 | nquery += 1
299 | return float('inf'), nquery
300 | lbd = current_best
301 | else:
302 | lbd = initial_lbd
303 |
304 | lbd_hi = lbd
305 | lbd_lo = 0.0
306 |
307 | while (lbd_hi - lbd_lo) > 1e-5:
308 | lbd_mid = (lbd_lo + lbd_hi)/2.0
309 | nquery += 1
310 | if self.predict_label(x0 + lbd_mid*theta) != t:
311 | lbd_lo = lbd_mid
312 | else:
313 | lbd_hi = lbd_mid
314 | return lbd_hi, nquery
315 |
316 |
317 | def _perturb(self, xs_t, ys):
318 | if self.targeted:
319 | adv, q = self.attack_targeted(xs_t, ys, self.alpha, self.beta)
320 | else:
321 | adv, q = self.attack_untargeted(xs_t, ys, self.alpha, self.beta)
322 |
323 | return adv, q
324 |
325 | if __name__ == '__main__':
326 | parser = argparse.ArgumentParser()
327 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
328 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
329 | parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
330 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
331 | parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
332 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
333 | args = parser.parse_args()
334 |
335 | # renaming
336 | dataset = eval(f'config.{args.dataset_name}()')
337 | training_set, testing_set = dataset.training_set, dataset.testing_set
338 | num_classes = dataset.num_classes
339 | means, stds = dataset.means, dataset.stds
340 | C, H, W = dataset.C, dataset.H, dataset.W
341 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
342 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)
343 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
344 |
345 |
346 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'
347 |
348 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'
349 |
350 | # load the tail of the model
351 | normalizer = transforms.Normalize(means, stds)
352 | tail = Tail(num_classes)
353 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
354 | tail.to(device)
355 |
356 | # load the classifiers
357 | classifiers = []
358 | models = []
359 | for i in range(args.num_models):
360 | head = Head()
361 | head.to(device)
362 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
363 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
364 |
365 | models.append(nn.Sequential(normalizer, watermark, head, tail).eval())
366 | models[-1].to(device)
367 |
368 | classifier = PyTorchClassifier(
369 | model = models[-1],
370 | loss = None, # dummy
371 | optimizer = None, # dummy
372 | clip_values = (0, 1),
373 | input_shape=(C, H, W),
374 | nb_classes=num_classes,
375 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
376 | )
377 | classifiers.append(classifier)
378 | classifiers = np.array(classifiers)
379 |
380 | for i, (model, c) in enumerate(zip(models, classifiers)):
381 | if os.path.isfile(f'{save_dir}/head_{i}/SignOPT.npz') and args.cont:
382 | continue
383 | original_images, attacked_images, labels = [], [], []
384 | count_success = 0
385 | for X, y in testing_loader:
386 | with torch.no_grad():
387 | pred = c.predict(X.numpy())
388 | correct_mask = pred.argmax(axis = -1) == y.numpy()
389 |
390 | X_device, y_device = X.to(device), y.to(device)
391 |
392 | a = SignOPTAttack(epsilon = 1, p = '2', alpha = 0.2, beta = 0.001, svm = False, momentum = 0, max_queries = 10000, k = 200, lb = 0, ub = 1, batch_size = 1, sigma = 0)
393 | X_attacked = a.run(X_device, y_device, model, False, None).cpu().numpy()
394 |
395 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers)
396 |
397 | success_mask = attacked_preds.argmax(axis = -1) != y.numpy()
398 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)
399 |
400 | mask = np.logical_and(correct_mask, success_mask)
401 |
402 | original_images.append(X[mask])
403 | attacked_images.append(X_attacked[mask])
404 | labels.append(y[mask])
405 |
406 | count_success += mask.sum()
407 | if count_success >= args.num_samples:
408 | print(f'Model {i}, attack SignOPT, {count_success} out of {args.num_samples} generated, done!')
409 | break
410 | else:
411 | print(f'Model {i}, attack SignOPT, {count_success} out of {args.num_samples} generated...')
412 |
413 | original_images = np.concatenate(original_images)
414 | attacked_images = np.concatenate(attacked_images)
415 | labels = np.concatenate(labels)
416 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
417 | np.savez(f'{save_dir}/head_{i}/SignOPT.npz', X = original_images, X_attacked = attacked_images, y = labels)
418 |
--------------------------------------------------------------------------------
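SignOPT parameterizes its search by a unit direction `theta` and works with g(theta), the distance from `x0` to the decision boundary along `theta`. `fine_grained_binary_search` evaluates g(theta) by bisection on the scalar `lbd`, using only hard-label queries. The sketch below is a simplified standalone version (the name `boundary_distance` is hypothetical) that drops the query-budget check and the `current_best` shortcut of the repository method, and assumes the bracket `[0, lbd_hi]` already straddles the boundary.

```python
import numpy as np


def boundary_distance(predict_label, x0, y0, theta, lbd_hi, tol=1e-5):
    """Bisect for the smallest lambda with predict_label(x0 + lambda*theta) != y0 (sketch).

    Assumes x0 is classified as y0 and x0 + lbd_hi*theta is not.
    Returns (lambda, number_of_queries).
    """
    lbd_lo, nquery = 0.0, 0
    while lbd_hi - lbd_lo > tol:
        lbd_mid = (lbd_lo + lbd_hi) / 2.0
        nquery += 1
        if predict_label(x0 + lbd_mid * theta) != y0:
            lbd_hi = lbd_mid   # still adversarial: the boundary is at most this far
        else:
            lbd_lo = lbd_mid   # back to the true label: the boundary is farther out
    # lbd_hi is always adversarial and within tol of the boundary
    return lbd_hi, nquery
```

Each bisection step halves the bracket, so reaching tolerance `tol` from an initial bracket of length `lbd_hi` costs about `log2(lbd_hi / tol)` queries.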
/attacks/simba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import os
6 | import argparse
7 | from art.estimators.classification import PyTorchClassifier
8 | from art.attacks.evasion import SimBA
9 |
10 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
11 | import config
12 | from watermark import Watermark
13 |
14 |
15 | if __name__ == '__main__':
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
18 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
19 | parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
20 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
21 | parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
22 | parser.add_argument('-d', '--domain', help = 'Choose the domain of the attack.', choices = ['dct', 'px'], default = 'px')
23 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
24 | parser.add_argument('-v', '--verbose', help = 'Verbose when attacking.', action = 'store_true')
25 | args = parser.parse_args()
26 |
27 | # renaming
28 | dataset = eval(f'config.{args.dataset_name}()')
29 | training_set, testing_set = dataset.training_set, dataset.testing_set
30 | num_classes = dataset.num_classes
31 | means, stds = dataset.means, dataset.stds
32 | C, H, W = dataset.C, dataset.H, dataset.W
33 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
34 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)
35 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
36 |
37 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'
38 |
39 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'
40 |
41 |
42 | # load the tail of the model
43 | normalizer = transforms.Normalize(means, stds)
44 | tail = Tail(num_classes)
45 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
46 | tail.to(device)
47 |
48 | # load the classifiers
49 | classifiers = []
50 | for i in range(args.num_models):
51 | head = Head()
52 | head.to(device)
53 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
54 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
55 |
56 | classifier = PyTorchClassifier(
57 | model = nn.Sequential(normalizer, watermark, head, tail, nn.Softmax(dim = -1)).eval(),
58 | loss = None, # dummy
59 | optimizer = None, # dummy
60 | clip_values = (0, 1),
61 | input_shape=(C, H, W),
62 | nb_classes=num_classes,
63 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
64 | )
65 | classifiers.append(classifier)
66 | classifiers = np.array(classifiers)
67 |
68 | # attacking
69 | for i, c in enumerate(classifiers):
70 | if os.path.isfile(f'{save_dir}/head_{i}/SimBA-{args.domain}.npz') and args.cont:
71 | continue
72 |
73 | original_images, attacked_images, labels = [], [], []
74 | count_success = 0
75 |
76 | for X, y in testing_loader:
77 | X, y = X.numpy(), y.numpy()
78 | pred = c.predict(X)
79 | correct_mask = pred.argmax(axis = 1) == y
80 |
81 | a = SimBA(c, attack = args.domain, verbose = args.verbose)
82 |
83 | X_attacked = a.generate(X)
84 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers) # (num_model, batch_size, num_class)
85 | success_mask = attacked_preds.argmax(axis = -1) != y
86 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)
87 | mask = np.logical_and(correct_mask, success_mask)
88 |
89 | original_images.append(X[mask])
90 | attacked_images.append(X_attacked[mask])
91 | labels.append(y[mask])
92 |
93 | count_success += mask.sum()
94 | if count_success >= args.num_samples:
95 | print(f'Head {i}, attack SimBA-{args.domain}, {count_success} out of {args.num_samples} generated, done!')
96 | break
97 | else:
98 | print(f'Head {i}, attack SimBA-{args.domain}, {count_success} out of {args.num_samples} generated...')
99 |
100 | original_images = np.concatenate(original_images)
101 | attacked_images = np.concatenate(attacked_images)
102 | labels = np.concatenate(labels)
103 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
104 | np.savez(f'{save_dir}/head_{i}/SimBA-{args.domain}.npz', X = original_images, X_attacked = attacked_images, y = labels)
105 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import torchvision.datasets as datasets
3 |
4 |
5 | # CIFAR10
6 | class CIFAR10:
7 | def __init__(self):
8 |
9 | transform_train = transforms.Compose([
10 | transforms.RandomCrop(32, padding=4),
11 | transforms.RandomHorizontalFlip(),
12 | transforms.ToTensor()
13 | ])
14 |
15 | transform_test = transforms.ToTensor()
16 |
17 | self.C, self.H, self.W = 3, 32, 32
18 | self.means, self.stds = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
19 | self.training_set = datasets.CIFAR10(root = f'./data', train = True, transform = transform_train, download = True)
20 | self.testing_set = datasets.CIFAR10(root = f'./data', train = False, transform = transform_test, download = True)
21 | self.num_classes = 10
22 | self.dataset = datasets.CIFAR10(root = f'./data', train = False, transform = None, download = True)
23 |
24 | # GTSRB
25 | class GTSRB:
26 | def __init__(self):
27 |
28 | transform_train = transforms.Compose([
29 | transforms.Resize((32, 32)),
30 | transforms.RandomCrop(32, padding=4),
31 | transforms.ToTensor()
32 | ])
33 |
34 | transform_test = transforms.Compose([
35 | transforms.Resize((32, 32)),
36 | transforms.ToTensor()
37 | ])
38 |
39 | self.means, self.stds = (0.3337, 0.3064, 0.3171), (0.2672, 0.2564, 0.2629)
40 | self.C, self.H, self.W = 3, 32, 32
41 | self.training_set = datasets.GTSRB(root = f'./data', split = 'train', transform = transform_train, download = True)
42 | self.testing_set = datasets.GTSRB(root = f'./data', split = 'test', transform = transform_test, download = True)
43 | self.dataset = datasets.GTSRB(root = f'./data', split = 'train', transform = None, download = True)
44 | self.num_classes = 43
45 |
46 | # TINY
47 | class TINY:
48 | def __init__(self):
49 |
50 | transform_train = transforms.Compose([
51 | transforms.RandomRotation(20),
52 | transforms.RandomHorizontalFlip(0.5),
53 | transforms.ToTensor()
54 | ])
55 |
56 | transform_test = transforms.Compose([
57 | transforms.ToTensor()
58 | ])
59 |
60 | self.C, self.H, self.W = 3, 64, 64
61 | self.means, self.stds = (0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)
62 | self.training_set = datasets.ImageFolder('./data/tiny-imagenet-200/train', transform_train)
63 | self.testing_set = datasets.ImageFolder('./data/tiny-imagenet-200/test', transform_test)  # expects one sub-folder per class; the official Tiny-ImageNet test split ships without labels
64 | self.num_classes = 200
65 |
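Each config stores per-channel `means` and `stds` that the scripts hand to `transforms.Normalize`, which computes `(x[c] - mean[c]) / std[c]` for every channel `c`. The sketch below is a NumPy equivalent for a channel-first image, using the CIFAR10 statistics above; `normalize_chw` is a hypothetical helper name.

```python
import numpy as np

CIFAR10_MEANS = (0.4914, 0.4822, 0.4465)
CIFAR10_STDS = (0.2023, 0.1994, 0.2010)


def normalize_chw(x, means, stds):
    """Per-channel normalization of a (C, H, W) array, as transforms.Normalize does."""
    mean = np.asarray(means, dtype=x.dtype)[:, None, None]   # (C, 1, 1) broadcasts over H and W
    std = np.asarray(stds, dtype=x.dtype)[:, None, None]
    return (x - mean) / std


# An image whose pixels all equal the channel means normalizes to zeros.
x = np.broadcast_to(np.array(CIFAR10_MEANS, dtype=np.float32)[:, None, None], (3, 32, 32)).copy()
z = normalize_chw(x, CIFAR10_MEANS, CIFAR10_STDS)
print(z.shape, float(np.abs(z).max()))
```

Because the scripts put the normalizer inside the model (`nn.Sequential(normalizer, watermark, ...)`), the attacks work directly in `[0, 1]` pixel space, which is why the ART classifiers use `clip_values = (0, 1)`.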
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg16 import VGG16Head, VGG16Tail
2 | from .resnet18 import ResNet18Head, ResNet18Tail
3 |
--------------------------------------------------------------------------------
/models/resnet18.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference:
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | '''
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, in_planes, planes, stride=1):
14 | super().__init__()
15 | self.conv1 = nn.Conv2d(
16 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
19 | stride=1, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 |
22 | self.shortcut = nn.Sequential()
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes,
26 | kernel_size=1, stride=stride, bias=False),
27 | nn.BatchNorm2d(self.expansion*planes)
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(self.conv1(x)))
32 | out = self.bn2(self.conv2(out))
33 | out += self.shortcut(x)
34 | out = F.relu(out)
35 | return out
36 |
37 | def ResNet18Block(block, in_planes, planes, num_blocks, stride):
38 | strides = [stride] + [1]*(num_blocks-1)
39 | layers = []
40 | for s in strides:
41 | layers.append(block(in_planes, planes, s))
42 | in_planes = planes * block.expansion
43 | return nn.Sequential(*layers)
44 |
45 | def ResNet18Head():
46 | return nn.Sequential(
47 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
48 | nn.BatchNorm2d(64),
49 | nn.ReLU(inplace=True),
50 | ResNet18Block(BasicBlock, 64, 64, 2, stride=1)
51 | )
52 |
53 |
54 | class ResNet18Tail(nn.Module):
55 | def __init__(self, num_classes):
56 | super().__init__()
57 |
58 | self.layer2 = ResNet18Block(BasicBlock, 64, 128, 2, stride=2)
59 | self.layer3 = ResNet18Block(BasicBlock, 128, 256, 2, stride=2)
60 | self.layer4 = ResNet18Block(BasicBlock, 256, 512, 2, stride=2)
61 | self.pool1d = nn.AdaptiveAvgPool1d(512)
62 | self.linear = nn.Linear(512*BasicBlock.expansion, num_classes)
63 |
64 | def forward(self, x):
65 | out = self.layer2(x)
66 | out = self.layer3(out)
67 | out = self.layer4(out)
68 |
69 | out = F.avg_pool2d(out, 4)
70 | out = out.view(out.size(0), -1)
71 | out = self.pool1d(out)
72 | out = self.linear(out)
73 | return out
74 |
75 |
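`ResNet18Tail` flattens its feature map and passes it through `nn.AdaptiveAvgPool1d(512)` before the linear layer. The pure-Python arithmetic below traces the spatial size through the head and tail for the two input resolutions in `config.py`. It suggests that the pooling leaves 32x32 inputs unchanged (512 features) and only reduces the 2048 flattened features produced by 64x64 Tiny-ImageNet inputs. `conv_out` and `resnet18_flat_features` are illustrative helpers.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_flat_features(h, w, channels=512):
    """Flattened feature length that ResNet18Tail feeds into pool1d."""
    # Head: 3x3 stem conv and layer1 use stride 1, padding 1, so H and W are unchanged.
    h, w = conv_out(h, 3, 1, 1), conv_out(w, 3, 1, 1)
    # Tail: layer2, layer3 and layer4 each open with a 3x3, stride-2, padding-1 conv.
    for _ in range(3):
        h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    # F.avg_pool2d(out, 4): kernel 4, stride defaults to the kernel size, no padding.
    h, w = conv_out(h, 4, 4, 0), conv_out(w, 4, 4, 0)
    return channels * h * w


for name, size in [("CIFAR10 / GTSRB", 32), ("TINY", 64)]:
    print(name, resnet18_flat_features(size, size))
```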
--------------------------------------------------------------------------------
/models/vgg16.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def VGG16Head():
5 | return nn.Sequential(
6 | nn.Conv2d(3, 64, kernel_size = 3, padding = 1),
7 | nn.BatchNorm2d(64),
8 | nn.ReLU(inplace=True),
9 | nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
10 | nn.BatchNorm2d(64),
11 | nn.ReLU(inplace=True),
12 | nn.MaxPool2d(kernel_size=2, stride=2)
13 | )
14 |
15 | class VGG16Tail(nn.Module):
16 | def __init__(self, num_classes):
17 | super().__init__()
18 | self.features = self._make_layers([128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'])
19 | self.pool1d = nn.AdaptiveAvgPool1d(512)
20 | self.classifier = nn.Linear(512, num_classes)
21 |
22 | def forward(self, x):
23 | out = self.features(x)
24 | out = out.view(out.size(0), -1)
25 | out = self.pool1d(out)
26 | out = self.classifier(out)
27 | return out
28 |
29 | def _make_layers(self, cfg):
30 | layers = []
31 | in_channels = 64
32 | for x in cfg:
33 | if x == 'M':
34 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
35 | else:
36 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
37 | nn.BatchNorm2d(x),
38 | nn.ReLU(inplace=True)]
39 | in_channels = x
40 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
41 | return nn.Sequential(*layers)
42 |
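Both tails apply `nn.AdaptiveAvgPool1d(512)` to the flattened features. For 32x32 inputs the VGG16 tail (one max-pool in the head, four in the tail) leaves a 512x1x1 map, so the pooling has nothing to reduce; for 64x64 Tiny-ImageNet inputs it leaves 512x2x2, i.e. 2048 features. Since `view` flattens a contiguous map channel by channel, averaging each run of four consecutive features amounts to a per-channel global average. The NumPy sketch below checks that equivalence; `adaptive_avg_pool1d` is a hand-written stand-in using the usual adaptive-pooling window rule, not PyTorch's own implementation.

```python
import numpy as np


def adaptive_avg_pool1d(x, out_len):
    """NumPy sketch of 1-D adaptive average pooling over the last axis.

    Window i covers [floor(i * L / out_len), ceil((i + 1) * L / out_len)).
    """
    in_len = x.shape[-1]
    out = np.empty(x.shape[:-1] + (out_len,), dtype=x.dtype)
    for i in range(out_len):
        start = (i * in_len) // out_len
        end = -((-(i + 1) * in_len) // out_len)   # integer ceiling division
        out[..., i] = x[..., start:end].mean(axis=-1)
    return out


rng = np.random.default_rng(0)
feature_map = rng.standard_normal((2, 512, 2, 2))   # tail output for a batch of two 64x64 inputs
flat = feature_map.reshape(2, -1)                   # like out.view(out.size(0), -1): shape (2, 2048)
pooled = adaptive_avg_pool1d(flat, 512)             # like self.pool1d(out): shape (2, 512)
gap = feature_map.mean(axis=(2, 3))                 # per-channel global average pooling
print(pooled.shape, np.allclose(pooled, gap))
```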
--------------------------------------------------------------------------------
/pics/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/adv_tracing/6ec6226d2d5728902a1b54c6c44b50c4ff593750/pics/framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 | torch==1.12.1+cu113
3 | torchvision==0.13.1+cu113
4 | numpy==1.21.5
5 | adversarial-robustness-toolbox==1.10.3
6 |
--------------------------------------------------------------------------------
/trace_data_free.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import argparse
6 | import logging
7 | from art.estimators.classification import PyTorchClassifier
8 |
9 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
10 | import config
11 | from watermark import Watermark
12 |
13 |
14 | def get_classifier(watermark, model, means, stds, num_class):
15 | return PyTorchClassifier(
16 | model = nn.Sequential(transforms.Normalize(means, stds), watermark, model, nn.Softmax(dim = -1)).eval(),
17 | loss = None, # dummy
18 | optimizer = None, # dummy
19 | input_shape=(C, H, W),
20 | clip_values = (0, 1),
21 | nb_classes=num_class,
22 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
23 | )
24 |
25 | if __name__ == '__main__':
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--model_name', default='ResNet18', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
28 | parser.add_argument('--dataset_name', default='CIFAR10', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY'])
29 | parser.add_argument('--attacks', default=['Bandit'], help = 'Attacks to be explored.', nargs = '+')
30 | parser.add_argument('--alpha', help = 'Hyper-parameter alpha.', type = float, required = True)
31 | parser.add_argument('-M', '--num_models', help = 'The number of models used for identification.', type = int, default = 50)
32 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
33 |
34 | args = parser.parse_args()
35 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
36 |
37 |
38 | dataset = eval(f'config.{args.dataset_name}()')
39 | training_set, testing_set = dataset.training_set, dataset.testing_set
40 | num_classes = dataset.num_classes
41 | means, stds = dataset.means, dataset.stds
42 | C, H, W = dataset.C, dataset.H, dataset.W
43 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
44 | normalizer = transforms.Normalize(means, stds)
45 |
46 |
47 | model_dir = f'./saved_models/{args.model_name}-{args.dataset_name}'
48 | adv_dir = f'./saved_adv_examples/{args.model_name}-{args.dataset_name}'
49 |
50 | # load the tail of the model
51 | normalizer = transforms.Normalize(means, stds)
52 | tail = Tail(num_classes)
53 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
54 | tail.to(device)
55 | tail.eval()
56 |
57 | # load the classifiers
58 | heads, watermarks, models = [], [], []
59 | for i in range(args.num_models):
60 | heads.append(Head())
61 | heads[-1].to(device)
62 | heads[-1].load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
63 | heads[-1].eval()
64 | watermarks.append(Watermark.load(f'{model_dir}/head_{i}/watermark.npy'))
65 | models.append(nn.Sequential(heads[-1], tail))
66 |
67 | for a in args.attacks:
68 | correct = 0
69 | Loss = nn.CrossEntropyLoss()
70 | for i in range(args.num_models):
71 | adv_npz = np.load(f'{adv_dir}/head_{i}/{a}.npz')
72 | X, X_attacked, y = adv_npz['X'][:args.num_samples], adv_npz['X_attacked'][:args.num_samples], adv_npz['y'][:args.num_samples]
73 |
74 | classifier_matrix = np.array([[get_classifier(wm, m, means, stds, num_classes) for wm in watermarks] for m in models])
75 | predictions = np.vectorize(lambda c: c.predict(X_attacked), signature='()->(m,n)')(classifier_matrix)
76 |
77 | X, X_attacked, y = torch.tensor(X).to(device), torch.tensor(X_attacked).to(device), torch.tensor(y).to(device)
78 | CE_loss = torch.stack([Loss(tail(head(wm(normalizer(X_attacked)))).softmax(-1), y) for wm, head in zip(watermarks, heads)], axis = 0).cpu()
79 |
80 |
81 | out = torch.stack([tail(head(wm(normalizer(X_attacked)))).argmax(axis = -1) for wm, head in zip(watermarks, heads)], axis = 0)
82 | wrong_pred = (out == y[None,:]).sum(-1) > 0
83 |
84 | predictions_maximum_class = predictions.max(axis = -1)
85 |
86 | maximum_class_score = predictions_maximum_class[np.arange(args.num_models), np.arange(args.num_models), ...] / predictions_maximum_class.sum(1)
87 | maximum_class_score = torch.from_numpy(maximum_class_score).sum(-1)
88 |
89 | score = maximum_class_score + args.alpha * CE_loss
90 | score[wrong_pred.cpu()] = np.inf  # the mask lives on `device`, the score on the CPU
91 | result = score.topk(1, axis = 0, largest=False)[1]
92 |
93 | correct += torch.sum(result == i).item()
94 | print((f'Attack {a}, tracing accuracy {correct / args.num_models}.'))
95 |
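The scoring step of the data-free tracer can be restated with NumPy arrays. For each candidate copy `m` it divides the top-class probability under copy `m`'s own watermark by the sum over all watermarks, adds `alpha` times the cross-entropy term, rules out copies that still predict the true label, and picks the lowest score. The sketch follows the array shapes used in the script; `data_free_scores` and the toy probabilities are made up for illustration, the cross-entropy values are passed in rather than computed, and `np.argmin` stands in for `score.topk(1, largest=False)`.

```python
import numpy as np


def data_free_scores(probs, ce_loss, predicts_true_label, alpha):
    """Sketch of the data-free tracing score (lower means more likely to be the source).

    probs:               (M, M, n, K) softmax outputs; probs[m, w] = model copy m behind watermark w
    ce_loss:             (M,) cross-entropy term of each copy on the adversarial batch
    predicts_true_label: (M,) True if copy m still predicts the true label of some sample
    alpha:               weight of the cross-entropy term
    """
    top = probs.max(axis=-1)                          # (M, M, n) top-class probability
    idx = np.arange(probs.shape[0])
    ratio = top[idx, idx] / top.sum(axis=1)           # own watermark vs. all watermarks: (M, n)
    score = ratio.sum(axis=-1) + alpha * ce_loss      # (M,)
    return np.where(predicts_true_label, np.inf, score)


# Toy case: 3 copies, 1 adversarial sample, 2 classes; T[m, w] is the top-class probability.
T = np.array([[0.90, 0.60, 0.60],
              [0.90, 0.55, 0.90],
              [0.60, 0.60, 0.90]])
probs = np.stack([T, 1 - T], axis=-1)[:, :, None, :]  # (3, 3, 1, 2)
scores = data_free_scores(probs, np.full(3, 0.5), np.array([False, False, True]), alpha=1.0)
traced = int(np.argmin(scores))
print(scores, traced)
```

In this toy case copy 1 gets the lowest score because the top-class probability under its own watermark (0.55) is small relative to the other watermarks, and copy 2 is excluded because it still predicts the true label.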
--------------------------------------------------------------------------------
/trace_data_limited.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import transforms
4 | import numpy as np
5 | import argparse
6 |
7 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
8 | import config
9 | from watermark import Watermark
10 |
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_name', default='ResNet18', help='Benchmark model structure.', choices=['VGG16', 'ResNet18'])
15 | parser.add_argument('--dataset_name', default='CIFAR10', help='Benchmark dataset used.', choices=['CIFAR10', 'GTSRB', 'TINY'])
16 | parser.add_argument('--attacks', default=['Bandit'], help='Attacks to be explored.', nargs='+')
17 | parser.add_argument('--alpha', help='Hyper-parameter alpha.', type=float, required=True)
18 | parser.add_argument('-M', '--num_models', help='The number of models used for identification.', type=int, default=50)
19 | parser.add_argument('-n', '--num_samples', help='The number of adversarial samples per model.', type=int, default=1)
20 |
21 | args = parser.parse_args()
22 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
23 |
24 |
25 | dataset = eval(f'config.{args.dataset_name}()')
26 | training_set, testing_set = dataset.training_set, dataset.testing_set
27 | num_classes = dataset.num_classes
28 | means, stds = dataset.means, dataset.stds
29 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
30 |
31 | model_dir = f'./saved_models/{args.model_name}-{args.dataset_name}'
32 | adv_dir = f'./saved_adv_examples/{args.model_name}-{args.dataset_name}'
33 |
34 | # load the tail of the model
35 | normalizer = transforms.Normalize(means, stds)
36 | tail = Tail(num_classes)
37 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
38 | tail.to(device)
39 | tail.eval()
40 |
41 | # load the classifiers
42 | heads, watermarks = [], []
43 | for i in range(args.num_models):
44 | heads.append(Head())
45 | heads[-1].to(device)
46 | heads[-1].load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
47 | heads[-1].eval()
48 | watermarks.append(Watermark.load(f'{model_dir}/head_{i}/watermark.npy'))
49 | overall_acc = 0
50 |
51 | for a in args.attacks:
52 | correct = 0
53 | for i in range(args.num_models):
54 | adv_npz = np.load(f'{adv_dir}/head_{i}/{a}.npz')
55 | Loss = nn.CrossEntropyLoss()
56 | X, X_attacked, y = adv_npz['X'][:args.num_samples], adv_npz['X_attacked'][:args.num_samples], adv_npz['y'][:args.num_samples]
57 | X, X_attacked, y = torch.tensor(X).to(device), torch.tensor(X_attacked).to(device), torch.tensor(y).to(device)
58 |
59 | CE_loss = torch.stack([Loss(tail(head(wm(normalizer(X_attacked)))).softmax(-1), y) for wm, head in zip(watermarks, heads)], axis = 0)
60 |
61 | diffs_sum = torch.stack([wm.get_values(torch.abs(X - X_attacked)).sum() for wm in watermarks], axis = 0)
62 |
63 | score = diffs_sum + args.alpha * CE_loss
64 | wrong_pred_list = []
65 |
66 | out = torch.stack([tail(head(wm(normalizer(X_attacked)))).argmax(axis = -1) for wm, head in zip(watermarks, heads)], axis = 0)
67 | wrong_pred = (out == y[None,:]).sum(-1) > 0
68 |
69 | score[wrong_pred] = np.inf
70 |
71 |
72 | result = score.topk(1, largest=False)[1]
73 | correct += torch.sum(result == i).item()
74 | print((f'Attack {a}, tracing accuracy {correct / args.num_models}.'))
75 |
76 |
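The data-limited tracer has the original image, so it can inspect the perturbation directly. Copy `m` zeroes its masked pixels before the head, so changing those pixels cannot change copy `m`'s output, and an attack run against copy `m` has little reason to perturb them. The score therefore sums `|X - X_attacked|` over each candidate's masked locations, adds `alpha` times the cross-entropy term, and keeps the lowest. A NumPy sketch with made-up locations and losses follows; `data_limited_scores` is a hypothetical helper and the cross-entropy values are given as inputs.

```python
import numpy as np


def data_limited_scores(x, x_adv, locations, ce_loss, predicts_true_label, alpha):
    """Sketch of the data-limited tracing score (lower means more likely to be the source).

    x, x_adv:            (n, C, H, W) original and adversarial images
    locations:           one (num_masked_dims, 3) [channel, row, col] array per copy
    ce_loss:             (M,) cross-entropy term of each copy on x_adv
    predicts_true_label: (M,) True if copy m still predicts the true label of some sample
    """
    diff = np.abs(x - x_adv)
    # Same gather as Watermark.get_values: |perturbation| at each copy's masked pixels.
    masked = np.array([diff[..., loc[:, 0], loc[:, 1], loc[:, 2]].sum() for loc in locations])
    return np.where(predicts_true_label, np.inf, masked + alpha * ce_loss)


# Toy case: 3 copies with 2 masked pixels each on a 1x4x4 image.
locations = [np.array([[0, 0, 0], [0, 0, 1]]),
             np.array([[0, 1, 1], [0, 2, 2]]),
             np.array([[0, 3, 3], [0, 3, 0]])]
x = np.zeros((1, 1, 4, 4))
x_adv = np.full_like(x, 0.1)                          # perturb every pixel by 0.1 ...
src = locations[1]
x_adv[..., src[:, 0], src[:, 1], src[:, 2]] = 0.0     # ... except the pixels copy 1 masks out
scores = data_limited_scores(x, x_adv, locations, np.full(3, 0.5), np.zeros(3, dtype=bool), alpha=1.0)
print(scores, int(np.argmin(scores)))
```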
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import transforms
5 | import os
6 | import argparse
7 |
8 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
9 | import config
10 | from watermark import Watermark
11 |
12 |
13 | '''
14 | Train the multi-head-one-tail model.
15 | '''
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
20 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY'])
21 | parser.add_argument('--num_workers', help = 'Number of workers', type = int, default = 2)
22 | parser.add_argument('-N', '--num_heads', help = 'Number of heads.', type = int, default = 100)
23 | parser.add_argument('-b', '--batch_size', help = 'Batch size.', type = int, default = 128)
24 | parser.add_argument('-e', '--num_epochs', help = 'Number of epochs.', type = int, default = 10)
25 | parser.add_argument('-lr', '--learning_rate', help = 'Learning rate.', type = float, default = 1e-3)
26 | parser.add_argument('-md', '--masked_dims', help = 'Number of masked dimensions', type = int, default = 100)
27 |
28 | args = parser.parse_args()
29 |
30 | if args.dataset_name == 'CIFAR10' or args.dataset_name == 'GTSRB':
31 | C, H, W = 3, 32, 32
32 | elif args.dataset_name == 'TINY':
33 | C, H, W = 3, 64, 64
34 |
35 | # Create the model and the dataset
36 | dataset = eval(f'config.{args.dataset_name}()')
37 | training_set, testing_set = dataset.training_set, dataset.testing_set
38 | num_classes = dataset.num_classes
39 | means, stds = dataset.means, dataset.stds
40 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
41 | normalizer = transforms.Normalize(means, stds)
42 | training_loader = torch.utils.data.DataLoader(training_set, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers)
43 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers)
44 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
45 |
46 | # Place to save the trained model
47 | save_dir = f'saved_models/{args.model_name}-{args.dataset_name}'
48 | os.makedirs(save_dir, exist_ok = True)
49 |
50 | # Load the tail of the model
51 | tail = Tail(num_classes)
52 | tail.load_state_dict(torch.load(f'{save_dir}/base_tail_state_dict'))
53 | tail.eval().requires_grad_(False)  # the shared tail stays frozen; only the heads are trained
54 | tail.to(device)
55 |
56 |
57 | # training
58 |
59 | for i in range(args.num_heads):
60 |
61 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
62 |
63 | head = nn.Sequential(Watermark.random(args.masked_dims, C, H, W), Head())
64 |
65 | head.to(device)
66 | head[0].save(f'{save_dir}/head_{i}/watermark.npy')
67 | head[1].load_state_dict(torch.load(f'{save_dir}/base_head_state_dict'))
68 | optimizer = torch.optim.Adam(head.parameters(), lr = args.learning_rate)
69 | Loss = nn.CrossEntropyLoss()
70 | best_accuracy = 0.
71 |
72 | for n in range(args.num_epochs):
73 | head.train()
74 | epoch_mask_grad_norm, epoch_mask_grad_norm_inverse = 0., 0.
75 | epoch_loss = 0.0
76 | for X, y in training_loader:
77 | X, y = X.to(device), y.to(device)
78 | optimizer.zero_grad()
79 | out_clean = tail(head(normalizer(X)))
80 | clean_loss = Loss(out_clean, y)
81 | loss = clean_loss
82 | loss.backward()
83 | optimizer.step()
84 | epoch_loss += loss.item() * len(y) / len(training_set)
85 |
86 | # testing
87 | head.eval()
88 | tail.eval()
89 |
90 | accuracy = 0.0
91 | with torch.no_grad():
92 | for X, y in testing_loader:
93 | X, y = X.to(device), y.to(device)
94 | _, pred = tail(head(normalizer(X))).max(axis = -1)
95 | accuracy += (pred == y).sum().item() / len(testing_set)
96 |
97 | print(f'Head {i}, epoch {n}, loss {epoch_loss:.3f}, accuracy = {accuracy:.4f}')
98 |
99 | # save the best result
100 | if accuracy > best_accuracy:
101 | best_accuracy = accuracy
102 | torch.save(head[1].state_dict(), f'{save_dir}/head_{i}/state_dict')
103 |
104 | print(f'Completed the training for head {i}, accuracy = {best_accuracy:.4f}.')
105 | print(f'Completed the training of {args.num_heads} heads, {args.model_name}-{args.dataset_name}.')
106 |
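`epoch_loss` adds `loss.item() * len(y) / len(training_set)` for every batch. With its default `reduction='mean'`, `nn.CrossEntropyLoss` returns the batch average, so weighting each batch by its size and dividing by the dataset size yields the mean loss over all training examples, even when the final batch is smaller. A quick arithmetic check with made-up per-example losses:

```python
import numpy as np

per_example_loss = np.array([0.2, 0.4, 0.9, 1.1, 0.3, 0.7, 0.5])  # 7 made-up examples
batch_size = 3                                                     # the last batch holds only 1 example

epoch_loss = 0.0
for start in range(0, len(per_example_loss), batch_size):
    batch = per_example_loss[start:start + batch_size]
    batch_mean = batch.mean()                      # what CrossEntropyLoss returns with reduction='mean'
    epoch_loss += batch_mean * len(batch) / len(per_example_loss)

print(epoch_loss, per_example_loss.mean())
```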
--------------------------------------------------------------------------------
/train_base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import argparse
5 |
6 | from torchvision import transforms
7 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
8 | import config
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
13 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY'])
14 | parser.add_argument('--num_workers', help = 'Number of workers', type = int, default = 2)
15 | parser.add_argument('-b', '--batch_size', help = 'Batch size.', type = int, default = 128)
16 | parser.add_argument('-e', '--num_epochs', help = 'Number of epochs.', type = int, default = 50)
17 | parser.add_argument('-lr', '--learning_rate', help = 'Learning rate.', type = float, default = 1e-3)
18 | args = parser.parse_args()
19 |
20 | # Create the model and the dataset
21 | dataset = eval(f'config.{args.dataset_name}()')
22 | training_set, testing_set = dataset.training_set, dataset.testing_set
23 | num_classes = dataset.num_classes
24 | means, stds = dataset.means, dataset.stds
25 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
26 | base_model = nn.Sequential(transforms.Normalize(means, stds), Head(), Tail(num_classes))
27 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 | base_model.to(device)
29 | training_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, shuffle = True, num_workers = args.num_workers)
30 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size=args.batch_size, num_workers = args.num_workers)
31 | print(f'The head has {sum(p.numel() for p in base_model[1].parameters())} parameters, the tail has {sum(p.numel() for p in base_model[2].parameters())} parameters.')
32 |
33 | # Place to save the trained model
34 | save_dir = f'saved_models/{args.model_name}-{args.dataset_name}'
35 | os.makedirs(save_dir, exist_ok = True)
36 |
37 | # Prepare for training
38 | optimizer = torch.optim.Adam(base_model.parameters(), lr = args.learning_rate)
39 | Loss = nn.CrossEntropyLoss()
40 |
41 | # training
42 | best_accuracy = 0.0
43 | for n in range(args.num_epochs):
44 |
45 | base_model.train()
46 | epoch_loss = 0.0
47 | for X, y in training_loader:
48 | X, y = X.to(device), y.to(device)
49 | optimizer.zero_grad()
50 | loss = Loss(base_model(X), y)
51 | loss.backward()
52 | optimizer.step()
53 | epoch_loss += loss.item() * len(y) / len(training_set)
54 |
55 | # testing
56 | base_model.eval()
57 | accuracy = 0.0
58 | with torch.no_grad():
59 | for X, y in testing_loader:
60 | X, y = X.to(device), y.to(device)
61 | _, pred = base_model(X).max(axis = -1)
62 | accuracy += (pred == y).sum().item() / len(testing_set)
63 |
64 | print(f'Epoch {n}, loss {epoch_loss:.3f}, accuracy = {accuracy:.4f}.')
65 |
66 | # save the best result
67 | if accuracy > best_accuracy:
68 | best_accuracy = accuracy
69 | torch.save(base_model[1].state_dict(), f'{save_dir}/base_head_state_dict')
70 | torch.save(base_model[2].state_dict(), f'{save_dir}/base_tail_state_dict')
71 |
72 | print(f'Completed the training of the base model, {args.model_name}-{args.dataset_name}, accuracy = {best_accuracy:.4f}.')
73 |
--------------------------------------------------------------------------------
/watermark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import numpy as np
5 |
6 |
7 | class Watermark(nn.Module):
8 | def __init__(self, locations: np.ndarray):
9 | '''
10 | locations: (N, 3) [[cha0, row0, col0], [cha1, row1, col1], [cha2, row2, col2], ...]
11 | '''
12 | super().__init__()
13 | assert len(locations.shape) == 2 and locations.shape[1] == 3
14 | self.locations = locations
15 |
16 |
17 | def forward(self, X):
18 | C, H, W = X.shape[-3:]
19 | if isinstance(X, torch.Tensor):
20 | mask = torch.ones_like(X, dtype = X.dtype, device = X.device)
21 | mask[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]] = 0.0
22 | return X * mask
23 |
24 | elif isinstance(X, np.ndarray):
25 | out = X.copy()
26 | out[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]] = 0.0
27 | return out
28 |
29 | else:
30 | raise TypeError(f'Unsupported input type: {type(X)}')
31 |
32 | def get_values(self, X):
33 | return X[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]]
34 |
35 | def save(self, fn):
36 | np.save(fn, self.locations)
37 |
38 |
39 | @staticmethod
40 | def load(fn):
41 | return Watermark(np.load(fn))
42 |
43 | @staticmethod
44 | def random(num_masked_dims, C, H, W):
45 | indices = np.random.choice(C * H * W, size = num_masked_dims, replace = False)
46 | watermark = Watermark(np.stack([indices // (H * W), (indices // W) % H, indices % W], axis = -1))
47 | return watermark
48 |
49 | @staticmethod
50 | def random_list(num_masked_dims, C, H, W, mask_list):
51 | indices = np.random.choice(mask_list, size = num_masked_dims, replace = False)
52 | watermark = Watermark(np.stack([indices // (H * W), (indices // W) % H, indices % W], axis = -1))
53 | return watermark
54 |
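`Watermark.random` draws `num_masked_dims` distinct flat indices into a `C*H*W` input and splits each one into `[channel, row, col]`; `forward` then zeroes those entries. The sketch below reproduces that arithmetic with NumPy only, checks it against `np.unravel_index`, and confirms that every image in a batch loses exactly `num_masked_dims` values. It uses `np.random.default_rng` instead of the global `np.random.choice` so the draw is reproducible; `random_locations` and `apply_mask` are illustrative names, not part of the module.

```python
import numpy as np


def random_locations(num_masked_dims, C, H, W, rng):
    """Index arithmetic of Watermark.random, with an explicit Generator for reproducibility."""
    indices = rng.choice(C * H * W, size=num_masked_dims, replace=False)
    locations = np.stack([indices // (H * W), (indices // W) % H, indices % W], axis=-1)
    return locations, indices


def apply_mask(x, locations):
    """NumPy branch of Watermark.forward: zero the masked [channel, row, col] entries."""
    out = x.copy()
    out[..., locations[:, 0], locations[:, 1], locations[:, 2]] = 0.0
    return out


rng = np.random.default_rng(0)
C, H, W = 3, 32, 32
locations, indices = random_locations(100, C, H, W, rng)   # 100 = the -md default in train.py

# The manual split agrees with NumPy's row-major (C-order) unravel_index.
unraveled = np.stack(np.unravel_index(indices, (C, H, W)), axis=-1)
print(np.array_equal(locations, unraveled))

x = np.ones((8, C, H, W), dtype=np.float32)                # a batch of eight all-ones images
masked = apply_mask(x, locations)
print((masked == 0).sum(axis=(1, 2, 3)))                   # zeros per image
```

Because the `torch.Tensor` branch multiplies the masked entries by zero, the copy's output cannot depend on them.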
--------------------------------------------------------------------------------