├── README.md
├── attacks
    ├── __init__.py
    ├── bandit.py
    ├── decision.py
    ├── hsj.py
    ├── nes.py
    ├── score.py
    ├── signopt.py
    └── simba.py
├── config.py
├── models
    ├── __init__.py
    ├── resnet18.py
    └── vgg16.py
├── pics
    └── framework.png
├── requirements.txt
├── trace_data_free.py
├── trace_data_limited.py
├── train.py
├── train_base_model.py
└── watermark.py


/README.md:
--------------------------------------------------------------------------------
### Identification of the Adversary from a Single Adversarial Example (ICML 2023)
This code is the official implementation of [Identification of the Adversary from a Single Adversarial Example](https://openreview.net/forum?id=HBrQI0tX8F).

----

### Abstract

Deep neural networks have been shown to be vulnerable to adversarial examples. Even though many defense methods have been proposed to enhance robustness, there is still a long way to go toward an attack-free method for building trustworthy machine learning systems. In this paper, instead of enhancing robustness, we take the investigator's perspective and propose a new framework to trace the first compromised model copy in a forensic-investigation manner. Specifically, we focus on the following setting: a machine learning service provider distributes model copies to a set of customers, and one of the customers conducts adversarial attacks to fool the system. The investigator's objective is therefore to identify the first compromised copy by collecting and analyzing evidence from only the available adversarial examples. To make tracing viable, we design a random mask watermarking mechanism that differentiates adversarial examples generated from different copies. First, we propose a tracing approach for the data-limited case, where the original example is also available. Then, we design a data-free approach that identifies the adversary without access to the original example. Finally, we evaluate the effectiveness of the proposed framework through extensive experiments with different model architectures, adversarial attacks, and datasets.

### Dependencies
- PyTorch == 1.12.1
- Torchvision == 0.13.1
- NumPy == 1.21.5
- Adversarial-Robustness-Toolbox == 1.10.3

### Pipeline
#### Pretraining
Use the following script to generate the pre-trained ResNet18 model on the CIFAR-10 dataset. For Tiny-ImageNet, you may need to download the dataset from this [link](http://cs231n.stanford.edu/tiny-imagenet-200.zip) and move the data to your data directory.
```
python train_base_model.py --model_name ResNet18 --dataset_name CIFAR10
```
#### Watermarking
For each model copy, we split the base model into a head and a tail. The tail is shared by all users and kept frozen. Only the head is fine-tuned with a copy-specific watermark. Here is a demo script for watermarking ResNet18 on the CIFAR-10 dataset.
```
python train.py --model_name ResNet18 --dataset_name CIFAR10
```
#### Tracing
You can use the following script to generate adversarial examples for each user. In our demo, we apply the [Bandit](https://arxiv.org/abs/1807.07978) attack and generate 10 adversarial examples for each of the 50 users (50 × 10 = 500 in total).
```
python -m attacks.bandit --model_name ResNet18 --dataset_name CIFAR10 -M 50 -n 10
```
We consider two tracing scenarios: the data-limited setting, where the original image is available, and the data-free setting, where it is not. The following script handles the data-limited case. Here, only one adversarial example per user is used to identify the adversary.
```
python trace_data_limited.py --model_name ResNet18 --dataset_name CIFAR10 --alpha 0.9 --attack Bandit -M 50 -n 1
```
The following script traces the adversary in the data-free case.
```
python trace_data_free.py --model_name ResNet18 --dataset_name CIFAR10 --alpha 0.5 --attack Bandit -M 50 -n 1
```
### Citation

If you find our work interesting, please consider giving it a star :star: and citing it as:
```
@inproceedings{cheng2023identification,
  title={Identification of the adversary from a single adversarial example},
  author={Cheng, Minhao and Min, Rui and Sun, Haochen and Chen, Pin-Yu},
  booktitle={International Conference on Machine Learning},
  pages={5472--5484},
  year={2023},
  organization={PMLR}
}
```


--------------------------------------------------------------------------------
/attacks/__init__.py:
--------------------------------------------------------------------------------
"""
Implements handy numerical computational functions
"""
import numpy as np
import torch as ch
from torch.nn.modules import Upsample


def norm(t):
    """
    Return the norm of a tensor (or numpy) along all the dimensions except the first one
    :param t: batch_size x dim x .. tensor (or numpy)
    :return: per-sample norms of shape batch_size x 1 x .. (zero norms are replaced by machine epsilon)
    """
    _shape = t.shape
    batch_size = _shape[0]
    num_dims = len(_shape[1:])
    if ch.is_tensor(t):
        norm_t = ch.sqrt(t.pow(2).sum(dim=[_ for _ in range(1, len(_shape))])).view([batch_size] + [1] * num_dims)
        norm_t += (norm_t == 0).float() * np.finfo(np.float64).eps
        return norm_t
    else:
        _norm = np.linalg.norm(
            t.reshape([batch_size, -1]), axis=1, keepdims=1
        ).reshape([batch_size] + [1] * num_dims)
        return _norm + (_norm == 0) * np.finfo(np.float64).eps


def eg_step(x, g, lr):
    """
    Performs an exponentiated gradient step in the convex body [-1,1]
    :param x: batch_size x dim x .. tensor (or numpy) with values in [-1,1]
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return: updated x, with values in [-1,1]
    """
    # from [-1,1] to [0,1]
    real_x = (x + 1.) / 2.
    # multiplicative-weights update on the pair (real_x, 1 - real_x), then renormalize
    if ch.is_tensor(x):
        pos = real_x * ch.exp(lr * g)
        neg = (1 - real_x) * ch.exp(-lr * g)
    else:
        pos = real_x * np.exp(lr * g)
        neg = (1 - real_x) * np.exp(-lr * g)
    new_x = pos / (pos + neg)
    # from [0,1] back to [-1,1]
    return new_x * 2 - 1


def step(x, g, lr):
    """
    Performs a step with no lp-ball constraints
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    return x + lr * g


def lp_step(x, g, lr, p):
    """
    performs lp step of x in the direction of g, where the norm is computed
    across all the dimensions except the first one (assuming it's the batch_size)
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :param p: 'inf' or '2'
    :return:
    """
    if p == 'inf':
        return linf_step(x, g, lr)
    elif p == '2':
        return l2_step(x, g, lr)
    else:
        raise Exception('Invalid p value')


def l2_step(x, g, lr):
    """
    performs l2 step of x in the direction of g, where the norm is computed
    across all the dimensions except the first one (assuming it's the batch_size)
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    return x + lr * g / norm(g)


def linf_step(x, g, lr):
    """
    performs linfinity step of x in the direction of g
    :param x: batch_size x dim x .. tensor (or numpy)
    :param g: batch_size x dim x .. tensor (or numpy)
    :param lr: learning rate (step size)
    :return:
    """
    if ch.is_tensor(x):
        return x + lr * ch.sign(g)
    else:
        return x + lr * np.sign(g)


def l2_proj_maker(xs, eps):
    """
    makes an l2 projection function such that new points
    are projected within the eps l2-balls centered around xs
    :param xs: batch_size x dim x .. tensor (or numpy), the centers of the balls
    :param eps: radius of the l2-balls
    :return: the projection function
    """
    if ch.is_tensor(xs):
        orig_xs = xs.clone()

        def proj(new_xs):
            delta = new_xs - orig_xs
            norm_delta = norm(delta)
            if np.isinf(eps):  # unbounded projection
                return orig_xs + delta
            else:
                return orig_xs + (norm_delta <= eps).float() * delta + (
                    norm_delta > eps).float() * eps * delta / norm_delta
    else:
        orig_xs = xs.copy()

        def proj(new_xs):
            delta = new_xs - orig_xs
            norm_delta = norm(delta)
            if np.isinf(eps):  # unbounded projection
                return orig_xs + delta
            else:
                return orig_xs + (norm_delta <= eps) * delta + (norm_delta > eps) * eps * delta / norm_delta
    return proj


def linf_proj_maker(xs, eps):
    """
    makes an linf projection function such that new points
    are projected within the eps linf-balls centered around xs
    :param xs: batch_size x dim x .. tensor (or numpy), the centers of the balls
    :param eps: radius of the linf-balls
    :return: the projection function
    """
    if ch.is_tensor(xs):
        orig_xs = xs.clone()

        def proj(new_xs):
            return orig_xs + ch.clamp(new_xs - orig_xs, - eps, eps)
    else:
        orig_xs = xs.copy()

        def proj(new_xs):
            return np.clip(new_xs, orig_xs - eps, orig_xs + eps)
    return proj


def upsample_maker(target_h, target_w):
    """
    makes an upsampler which takes a numpy tensor of the form
    minibatch x channels x h x w and casts to
    minibatch x channels x target_h x target_w
    :param target_h: int to specify the desired height
    :param target_w: int to specify the desired width
    :return:
    """
    _upsampler = Upsample(size=(target_h, target_w))

    def upsample_fct(xs):
        if ch.is_tensor(xs):
            return _upsampler(xs)
        else:
            return _upsampler(ch.from_numpy(xs)).numpy()

    return upsample_fct


def hamming_dist(a, b):
    """
    returns the hamming distance of a to b
    assumes a and b are in {+1, -1}
    :param a:
    :param b:
    :return:
    """
    assert np.all(np.abs(a) == 1.), "a should be in {+1,-1}"
    assert np.all(np.abs(b) == 1.), "b should be in {+1,-1}"
    return sum([_a != _b for _a, _b in zip(a, b)])


def sign(t, is_ns_sign=True):
    """
    Given a tensor t of `batch_size x dim` return the (non)standard sign of `t`
    based on the `is_ns_sign` flag
    :param t: tensor of `batch_size x dim`
    :param is_ns_sign: if True uses the non-standard sign function (maps 0 to +1)
    :return:
    """
    _sign_t = ch.sign(t) if ch.is_tensor(t) else np.sign(t)
    if is_ns_sign:
        _sign_t[_sign_t == 0.] = 1.
    return _sign_t


def noisy_sign(t, retain_p=1, crit='top', is_ns_sign=True):
    """
    returns a noisy version of the tensor `t` where
    only `retain_p` * 100 % of the coordinates retain their sign according
    to a `crit`.
    The noise has the following effect:
        sign(t) * x where x in {+1, -1}
    Thus, if sign(t) = 0, sign(t) * x is always 0 (in case of `is_ns_sign=False`)
    :param t: tensor of `batch_size x dim`
    :param retain_p: fraction of coordinates
    :param crit: 'top' keeps the signs of the k largest-magnitude coordinates, 'random' keeps k random ones
    :param is_ns_sign: if True uses the non-standard sign function
    :return:
    """
    assert 0. <= retain_p <= 1., "retain_p value should be in [0,1]"

    _shape = t.shape
    t = t.reshape(_shape[0], -1)
    batch_size, dim = t.shape

    sign_t = sign(t, is_ns_sign=is_ns_sign)
    k = int(retain_p * dim)

    if k == 0:  # noise-ify all
        return (sign_t * np.sign((np.random.rand(batch_size, dim) < 0.5) - 0.5)).reshape(_shape)
    if k == dim:  # retain all
        return sign_t.reshape(_shape)

    # do topk otherwise
    noisy_sign_t = sign_t * np.sign((np.random.rand(*t.shape) < 0.5) - 0.5)
    _rows = np.zeros((batch_size, k), dtype=np.intp) + np.arange(batch_size)[:, None]
    if crit == 'top':
        _temp = np.abs(t)
    elif crit == 'random':
        _temp = np.random.rand(*t.shape)
    else:
        raise Exception('Unknown criterion for topk')

    _cols = np.argpartition(_temp, -k, axis=1)[:, -k:]
    noisy_sign_t[_rows, _cols] = sign_t[_rows, _cols]
    return noisy_sign_t.reshape(_shape)


--------------------------------------------------------------------------------
/attacks/bandit.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import Upsample
from torchvision import transforms
import numpy as np
import os
import argparse
from art.estimators.classification import PyTorchClassifier

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark
from attacks.score import ScoreBlackBoxAttack
from attacks import *


Loss = nn.CrossEntropyLoss(reduction = 'none')

class BanditAttack(ScoreBlackBoxAttack):
    """
    Bandit Attack
    """

    def __init__(self,
                 max_loss_queries,
                 epsilon, p,
                 fd_eta, lr,
                 prior_exploration, prior_size, data_size, prior_lr,
                 lb, ub, batch_size, name):
        """
        :param max_loss_queries: maximum number of calls allowed to loss oracle per data pt
        :param epsilon: radius of lp-ball of perturbation
        :param p: specifies lp-norm of perturbation
        :param fd_eta: finite difference step
        :param lr: learning rate (step size) in the data space
        :param prior_exploration: exploration noise
        :param prior_size: prior height/width (this is applicable only to images), you can disable it by setting it to
            None (it is assumed that prior_size = prior_height = prior_width)
        :param data_size: data height/width (applicable to images of the form `c x h x w`), you can ignore it
            by setting it to None, it is assumed that data_size = height = width
        :param prior_lr: learning rate in the prior space
        :param lb: data lower bound
        :param ub: data upper bound
        :param batch_size: number of data points attacked in parallel
        :param name: attack name (the stored name is always "Bandit")
        """
        super().__init__(max_extra_queries=np.inf,
                         max_loss_queries=max_loss_queries,
                         epsilon=epsilon,
                         p=p,
                         lb=lb,
                         ub=ub,
                         batch_size=batch_size,
                         name="Bandit")
        # other algorithmic parameters
        self.fd_eta = fd_eta
        # learning rate
        self.lr = lr
        # data size
        self.data_size = data_size

        # prior setup:
        # 1. step function
        if self.p == '2':
            self.prior_step = step
        elif self.p == 'inf':
            self.prior_step = eg_step
        else:
            raise Exception("Invalid p for l-p constraint")
        # 2. prior placeholder
        self.prior = None
        # prior size
        self.prior_size = prior_size
        # prior exploration
        self.prior_exploration = prior_exploration
        # 3. prior upsampler
        self.prior_upsample_fct = None if self.prior_size is None else upsample_maker(data_size, data_size)
        self.prior_lr = prior_lr

    def _perturb(self, xs_t, loss_fct):
        """
        The core of the bandit algorithm.
        Since this is compute intensive, it is implemented in torch, so all ops run on the
        device of `xs_t` (e.g. the GPU, if available).
        :param xs_t: torch tensor, the current batch of points
        :param loss_fct: maps a batch of points to per-sample losses
        :return: (new_xs, num_queries), torch tensors with the updated batch and the
            number of loss queries spent per sample
        """

        _shape = list(xs_t.shape)
        eff_shape = list(xs_t.shape)
        # the upsampling assumes xs_t is batch_size x c x h x w; flat inputs such as
        # mnist (batch_size x dim) require prior_size=None so that no upsampling is done

        if self.prior_size is None:
            prior_shape = eff_shape
        else:
            prior_shape = eff_shape[:-2] + [self.prior_size] * 2
        # reset the prior if xs is a new batch
        if self.is_new_batch:
            self.prior = torch.zeros(prior_shape, device = xs_t.device)
        # create noise for exploration, estimate the gradient, and take a PGD step
        # exp_noise = torch.randn(prior_shape) / (np.prod(prior_shape[1:]) ** 0.5)  # according to the paper
        exp_noise = torch.randn(prior_shape, device = xs_t.device)
        # Query deltas for finite difference estimator
        if self.prior_size is None:
            q1 = step(self.prior, exp_noise, self.prior_exploration)
            q2 = step(self.prior, exp_noise, - self.prior_exploration)
        else:
            q1 = self.prior_upsample_fct(step(self.prior, exp_noise, self.prior_exploration))
            q2 = self.prior_upsample_fct(step(self.prior, exp_noise, - self.prior_exploration))
        # Loss points for finite difference estimator
        l1 = loss_fct(l2_step(xs_t, q1.view(_shape), self.fd_eta))
        l2 = loss_fct(l2_step(xs_t, q2.view(_shape), self.fd_eta))
        # finite differences estimate of directional derivative
        est_deriv = (l1 - l2) / (self.fd_eta * self.prior_exploration)
        # 2-query gradient estimate
        # Note: Ilyas' implementation multiplies the below by self.prior_exploration (different from pseudocode)
        # This should not affect the result as the `self.prior_lr` can be adjusted accordingly
        est_grad = est_deriv.view(-1, *[1] * len(prior_shape[1:])) * exp_noise
        # update prior with the estimated gradient:
        self.prior = self.prior_step(self.prior, est_grad, self.prior_lr)
        # gradient step in the data space
        if self.prior_size is None:
            gs = self.prior.clone()
        else:
            gs = self.prior_upsample_fct(self.prior)
        # perform the step
        new_xs = lp_step(xs_t, gs.view(_shape), self.lr, self.p)
        return new_xs, 2 * torch.ones(_shape[0], device = xs_t.device)

    def _config(self):
        return {
            "name": self.name,
            "p": self.p,
            "epsilon": self.epsilon,
            "lb": self.lb,
            "ub": self.ub,
            "max_extra_queries": "inf" if np.isinf(self.max_extra_queries) else self.max_extra_queries,
            "max_loss_queries": "inf" if np.isinf(self.max_loss_queries) else self.max_loss_queries,
            "lr": self.lr,
            "prior_lr": self.prior_lr,
            "prior_exploration": self.prior_exploration,
            "prior_size": self.prior_size,
            "data_size": self.data_size,
            "fd_eta": self.fd_eta,
            "attack_name": self.__class__.__name__
        }

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
    parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
    parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
    parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
    parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
    parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
    args = parser.parse_args()

    # unpack the dataset configuration
    dataset = eval(f'config.{args.dataset_name}()')
    training_set, testing_set = dataset.training_set, dataset.testing_set
    num_classes = dataset.num_classes
    means, stds = dataset.means, dataset.stds
    C, H, W = dataset.C, dataset.H, dataset.W
    Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
    testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'

    save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'

    # input normalizer
    normalizer = transforms.Normalize(means, stds)

    # load the shared tail of the model, then one watermarked head per copy
    classifiers = []
    models = []
    tail = Tail(num_classes)
    tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
    tail.to(device)
    for i in range(args.num_models):
        head = Head()
        head.to(device)
        head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
        watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
        models.append(nn.Sequential(normalizer, watermark, head, tail).eval())
        models[-1].to(device)

        classifier = PyTorchClassifier(
            model = models[-1],
            loss = None,
            optimizer = None,
            clip_values = (0, 1),
            input_shape=(C, H, W),
            nb_classes=num_classes,
            device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
        )
        classifiers.append(classifier)
    classifiers = np.array(classifiers)

    for i, (model, c) in enumerate(zip(models, classifiers)):
        # skip copies whose adversarial examples were already saved
        if os.path.isfile(f'{save_dir}/head_{i}/Bandit.npz') and args.cont:
            continue
        original_images, attacked_images, labels = [], [], []
        count_success = 0
        for X, y in testing_loader:
            with torch.no_grad():
                pred = c.predict(X.numpy())
            # only attack inputs that copy i classifies correctly
            correct_mask = pred.argmax(axis = -1) == y.numpy()

            X_device, y_device = X.to(device), y.to(device)

            def loss_fct(xs, es = False):
                logits = model(xs)
                loss = Loss(logits.to(device), y_device)
                if es:
                    return torch.argmax(logits, axis = -1) != y_device, loss
                else:
                    return loss

            def early_stop_crit_fct(xs):
                logits = model(xs)
                return logits.argmax(axis = -1) != y_device

            a = BanditAttack(max_loss_queries = 10000, epsilon = 1.0, p = '2', lb = 0.0, ub = 1.0, batch_size = args.batch_size, name = 'Bandit',
                             fd_eta = 0.01, lr = 0.01, prior_exploration = 0.1, prior_size = 20, data_size = 32, prior_lr = 0.1)

            X_attacked = a.run(X_device, loss_fct, early_stop_crit_fct).cpu().numpy()

            # predictions of every copy on the attacked batch: (num_models, batch_size, num_classes)
            attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers)

            # keep examples that fool copy i and at least one other copy
            success_mask = attacked_preds.argmax(axis = -1) != y.numpy()
            success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)

            mask = np.logical_and(correct_mask, success_mask)

            original_images.append(X[mask])
            attacked_images.append(X_attacked[mask])
            labels.append(y[mask])

            count_success += mask.sum()
            if count_success >= args.num_samples:
                print(f'Model {i}, attack Bandit, {count_success} out of {args.num_samples} generated, done!')
                break
            else:
                print(f'Model {i}, attack Bandit, {count_success} out of {args.num_samples} generated...')

        original_images = np.concatenate(original_images)
        attacked_images = np.concatenate(attacked_images)

        labels = np.concatenate(labels)
        os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
        np.savez(f'{save_dir}/head_{i}/Bandit.npz', X = original_images, X_attacked = attacked_images, y = labels)


--------------------------------------------------------------------------------
/attacks/decision.py:
--------------------------------------------------------------------------------
from attacks import *

import numpy as np
import torch


class DecisionBlackBoxAttack(object):
    def __init__(self, max_queries=np.inf, epsilon=0.5, p='inf', lb=0., ub=1., batch_size=1):
        """
        :param max_queries: max number of calls to model per data point
        :param epsilon: perturbation limit according to lp-ball
        :param p: norm for the lp-ball constraint
        :param lb: minimum value data point can take in any coordinate
        :param ub: maximum value data point can take in any coordinate
        """
        assert p in ['inf', '2'], "L-{} is not supported".format(p)

        self.p = p
        self.max_queries = max_queries
        self.total_queries = 0
        self.total_successes = 0
        self.total_failures = 0
        self.total_distance = 0
        self.sigma = 0
        self.EOT = 1
        self.lb = lb
        self.ub = ub
        self.epsilon = epsilon / ub
        self.batch_size = batch_size
        self.list_loss_queries = torch.zeros(1, self.batch_size)

    def result(self):
        """
        returns a summary of the attack results (to be tabulated)
        :return:
        """
        list_loss_queries = self.list_loss_queries[1:].view(-1)
        mask = list_loss_queries > 0
        list_loss_queries = list_loss_queries[mask]
        self.total_queries = int(self.total_queries)
        self.total_successes = int(self.total_successes)
        self.total_failures = int(self.total_failures)
        return {
            "total_queries": self.total_queries,
            "total_successes": self.total_successes,
            "total_failures": self.total_failures,
            "average_num_queries": "NaN" if self.total_successes == 0 else self.total_queries / self.total_successes,
            "failure_rate": "NaN" if self.total_successes + self.total_failures == 0 else self.total_failures / (self.total_successes + self.total_failures),
            "median_num_loss_queries": "NaN" if self.total_successes == 0 else torch.median(list_loss_queries).item(),
            "config": self._config()
        }

    def _config(self):
        """
        return the attack's parameter configurations as a dict
        :return:
        """
        raise NotImplementedError

    def distance(self, x_adv, x = None):
        """
        per-sample lp distance between x_adv and x (or the lp norm of x_adv if x is None)
        """
        if x is None:
            diff = x_adv.view(x_adv.shape[0], -1)
        else:
            diff = (x_adv - x).view(x.shape[0], -1)
        if self.p == '2':
            out = torch.sqrt(torch.sum(diff * diff, dim = 1))
        elif self.p == 'inf':
            out, _ = torch.max(torch.abs(diff), dim = 1)
        return out

    def is_adversarial(self, x, y):
        '''
        check whether the adversarial constraint holds for x
        '''
        if self.targeted:
            return self.predict_label(x) == y
        else:
            return self.predict_label(x) != y

    def predict_label(self, xs):
        with torch.no_grad():
            if type(xs) is torch.Tensor:
                out = self.model(xs).argmax(dim=-1).squeeze()
            else:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                out = self.model(torch.FloatTensor(xs).to(device)).argmax(dim=-1).squeeze()
        return out

    def _perturb(self, xs_t, ys):
        raise NotImplementedError

    def run(self, Xs, ys, model, targeted, dset):
        self.model = model
        self.targeted = targeted

        X_attacked = []

        # attack one data point at a time
        for x, y in zip(Xs, ys):
            adv, _ = self._perturb(x[None, ...], y[None])
            X_attacked.append(adv.squeeze())
        X_attacked = torch.stack(X_attacked).float()

        # keep adversarial examples within the epsilon-ball; fall back to the clean input otherwise
        success = (self.distance(X_attacked, Xs) < self.epsilon)

        return X_attacked * success[:, None, None, None] + Xs * (~success[:, None, None, None])


--------------------------------------------------------------------------------
/attacks/hsj.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import os
import argparse
from art.estimators.classification import PyTorchClassifier
from art.attacks.evasion import HopSkipJump

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
    parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
    parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
    parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
    parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true')
    parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10)
    parser.add_argument('-v', '--verbose', help = 'Verbose when attacking.', action = 'store_true')
    args = parser.parse_args()

    dataset = eval(f'config.{args.dataset_name}()')
    training_set, testing_set = dataset.training_set, dataset.testing_set
    num_classes = dataset.num_classes
    means, stds = dataset.means, dataset.stds
    C, H, W = dataset.C, dataset.H, dataset.W
    Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail')
    testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'

    save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}'

    # input normalizer and the shared tail of the model
    normalizer = transforms.Normalize(means, stds)
    tail = Tail(num_classes)
    tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict'))
    tail.to(device)

    # load the classifiers (one watermarked head per copy)
    classifiers = []
    models = []
    for i in range(args.num_models):
        head = Head()
        head.to(device)
        head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict'))
        watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy')
        models.append(nn.Sequential(normalizer, watermark, head, tail, nn.Softmax(dim = -1)).eval())
        models[-1].to(device)
        classifier = PyTorchClassifier(
            model = models[-1],
            loss = None,
            optimizer = None,
            clip_values = (0, 1),
            input_shape=(C, H, W),
            nb_classes=num_classes,
            device_type = 'gpu' if torch.cuda.is_available() else 'cpu'
        )
        classifiers.append(classifier)
    classifiers = np.array(classifiers)

    # attacking
    for i, (model, c) in enumerate(zip(models, classifiers)):
        # skip copies whose adversarial examples were already saved (same file name as np.savez below)
        if os.path.isfile(f'{save_dir}/head_{i}/HopSkipJump.npz') and args.cont:
            continue
        a = HopSkipJump(c, verbose = args.verbose)

        original_images, attacked_images, labels = [], [], []
        count_success = 0

        for X, y in testing_loader:
            X, y = X.numpy(), y.numpy()
            pred = c.predict(X)
            # only keep inputs that copy i classifies correctly
            correct_mask = pred.argmax(axis = 1) == y

            X_attacked = a.generate(X)
            attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers)  # (num_models, batch_size, num_classes)
            # keep examples that fool copy i and at least one other copy
            success_mask = attacked_preds.argmax(axis = -1) != y
            success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2)
            mask = np.logical_and(correct_mask, success_mask)

            original_images.append(X[mask])
            attacked_images.append(X_attacked[mask])
            labels.append(y[mask])

            count_success += mask.sum()
            if count_success >= args.num_samples:
                print(f'Head {i}, attack HopSkipJump, {count_success} out of {args.num_samples} generated, done!')
                break
            else:
                print(f'Head {i}, attack HopSkipJump, {count_success} out of {args.num_samples} generated...')

        original_images = np.concatenate(original_images)
        attacked_images = np.concatenate(attacked_images)
        labels = np.concatenate(labels)
        os.makedirs(f'{save_dir}/head_{i}', exist_ok = True)
        np.savez(f'{save_dir}/head_{i}/HopSkipJump.npz', X = original_images, X_attacked = attacked_images, y = labels)


--------------------------------------------------------------------------------
/attacks/nes.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import os
import argparse
from art.estimators.classification import PyTorchClassifier

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark
from attacks.score import ScoreBlackBoxAttack
from attacks import *

Loss = nn.CrossEntropyLoss(reduction = 'none')

class NESAttack(ScoreBlackBoxAttack):
    """
    NES Attack
    """

    def __init__(self, max_loss_queries, epsilon, p, fd_eta, lr, q, lb, ub, batch_size, name):
        """
        :param max_loss_queries: maximum number of calls allowed to loss oracle per data pt
        :param epsilon: radius of lp-ball of perturbation
        :param p: specifies lp-norm of perturbation
        :param fd_eta: finite difference step
        :param lr: learning rate of NES step
        :param q: number of noise samples per NES step
        :param lb: data lower bound
        :param ub: data upper bound
        """
        super().__init__(max_extra_queries=np.inf,
                         max_loss_queries=max_loss_queries,
                         epsilon=epsilon,
                         p=p,
                         lb=lb,
                         ub=ub,
                         batch_size=batch_size,
                         name="NES")
        self.q = q
        self.fd_eta = fd_eta
        self.lr = lr

    def _perturb(self, xs_t, loss_fct):
        _shape = list(xs_t.shape)
        dim = np.prod(_shape[1:])
        num_axes = len(_shape[1:])
        gs_t = torch.zeros_like(xs_t)
        # antithetic sampling: each noise sample costs 2 loss queries
        for _ in range(self.q):
            # exp_noise = torch.randn_like(xs_t) / (dim ** 0.5)
            exp_noise = torch.randn_like(xs_t)
            fxs_t = xs_t + self.fd_eta * exp_noise
            bxs_t = xs_t - self.fd_eta * exp_noise
            # the constant in the denominator does not change the update, because lp_step
            # normalizes gs_t (p='2') or takes its sign (p='inf')
            est_deriv = (loss_fct(fxs_t) - loss_fct(bxs_t)) / (4. * self.fd_eta)
            gs_t += est_deriv.reshape(-1, *[1] * num_axes) * exp_noise
        # perform the step
        new_xs = lp_step(xs_t, gs_t, self.lr, self.p)
        return new_xs, 2 * self.q * torch.ones(_shape[0], device = xs_t.device)

    def _config(self):
        return {
            "name": self.name,
            "p": self.p,
            "epsilon": self.epsilon,
            "lb": self.lb,
            "ub": self.ub,
            "max_extra_queries": "inf" if np.isinf(self.max_extra_queries) else self.max_extra_queries,
            "max_loss_queries": "inf" if np.isinf(self.max_loss_queries) else self.max_loss_queries,
            "lr": self.lr,
            "q": self.q,
            "fd_eta": self.fd_eta,
            "attack_name": self.__class__.__name__
        }

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
    parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
    parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
    parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1)
    parser.add_argument('-c', '--cont', help = 'Continue from the stopped
point last time.', action = 'store_true') 83 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10) 84 | args = parser.parse_args() 85 | 86 | # renaming 87 | dataset = eval(f'config.{args.dataset_name}()') 88 | training_set, testing_set = dataset.training_set, dataset.testing_set 89 | num_classes = dataset.num_classes 90 | means, stds = dataset.means, dataset.stds 91 | C, H, W = dataset.C, dataset.H, dataset.W 92 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 93 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2) 94 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 95 | 96 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}' 97 | 98 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}' 99 | 100 | 101 | # load the tail of the model 102 | normalizer = transforms.Normalize(means, stds) 103 | tail = Tail(num_classes) 104 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict')) 105 | tail.to(device) 106 | 107 | # load the classifiers 108 | classifiers = [] 109 | models = [] 110 | for i in range(args.num_models): 111 | head = Head() 112 | head.to(device) 113 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict')) 114 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy') 115 | 116 | models.append(nn.Sequential(normalizer, watermark, head, tail).eval()) 117 | models[-1].to(device) 118 | 119 | classifier = PyTorchClassifier( 120 | model = models[-1], 121 | loss = None, 122 | optimizer = None, 123 | clip_values = (0, 1), 124 | input_shape=(C, H, W), 125 | nb_classes=num_classes, 126 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu' 127 | ) 128 | classifiers.append(classifier) 129 | classifiers = np.array(classifiers) 130 | 131 | for i, (model, c) in enumerate(zip(models, classifiers)): 132 | if 
os.path.isfile(f'{save_dir}/head_{i}/NES.npz') and args.cont: 133 | continue 134 | original_images, attacked_images, labels = [], [], [] 135 | count_success = 0 136 | for X, y in testing_loader: 137 | with torch.no_grad(): 138 | pred = c.predict(X.numpy()) 139 | correct_mask = pred.argmax(axis = -1) == y.numpy() 140 | 141 | X_device, y_device = X.to(device), y.to(device) 142 | def loss_fct(xs, es = False): 143 | logits = model(xs) 144 | loss = Loss(logits.to(device), y_device) 145 | if es: 146 | return torch.argmax(logits, axis= -1) != y_device, loss 147 | else: 148 | return loss 149 | 150 | def early_stop_crit_fct(xs): 151 | logits = model(xs) 152 | return logits.argmax(axis = -1) != y_device 153 | 154 | a = NESAttack(max_loss_queries = 10000, epsilon = 1.0, p = '2', fd_eta = 0.01, lr = 0.01, q = 15, lb = 0.0, ub = 1.0, batch_size = X.shape[0], name = 'NESAttack') 155 | with torch.no_grad(): 156 | X_attacked = a.run(X_device, loss_fct, early_stop_crit_fct).cpu().numpy() 157 | 158 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers) 159 | 160 | success_mask = attacked_preds.argmax(axis = -1) != y.numpy() 161 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2) 162 | 163 | mask = np.logical_and(correct_mask, success_mask) 164 | 165 | original_images.append(X[mask]) 166 | attacked_images.append(X_attacked[mask]) 167 | labels.append(y[mask]) 168 | 169 | count_success += mask.sum() 170 | if count_success >= args.num_samples: 171 | print(f'Model {i}, attack NES, {count_success} out of {args.num_samples} generated, done!') 172 | break 173 | else: 174 | print(f'Model {i}, attack NES, {count_success} out of {args.num_samples} generated...') 175 | 176 | original_images = np.concatenate(original_images) 177 | attacked_images = np.concatenate(attacked_images) 178 | labels = np.concatenate(labels) 179 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True) 180 | np.savez(f'{save_dir}/head_{i}/NES.npz', X =
original_images, X_attacked = attacked_images, y = labels) 181 | -------------------------------------------------------------------------------- /attacks/score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from torch import Tensor as t 6 | 7 | from attacks import * 8 | 9 | class ScoreBlackBoxAttack(object): 10 | def __init__(self, max_loss_queries=np.inf, 11 | max_extra_queries=np.inf, 12 | epsilon=0.5, p='inf', lb=0., ub=1.,batch_size = 50, name = '', device = 'cuda' if torch.cuda.is_available() else 'cpu'): 13 | """ 14 | :param max_loss_queries: max number of calls to model per data point 15 | :param max_extra_queries: max number of calls to the early-stopping criterion per data point 16 | :param epsilon: perturbation limit according to lp-ball 17 | :param p: norm for the lp-ball constraint 18 | :param lb: minimum value data point can take in any coordinate 19 | :param ub: maximum value data point can take in any coordinate 20 | """ 21 | assert p in ['inf', '2'], "L-{} is not supported".format(p) 22 | 23 | self.epsilon = epsilon 24 | self.p = p 25 | self.batch_size = batch_size 26 | self.max_loss_queries = max_loss_queries 27 | self.max_extra_queries = max_extra_queries 28 | self.list_loss_queries = torch.zeros(1, self.batch_size, device = device) 29 | self.total_loss_queries = 0 30 | self.total_extra_queries = 0 31 | self.total_successes = 0 32 | self.total_failures = 0 33 | self.lb = lb 34 | self.ub = ub 35 | self.name = name 36 | # the _proj method takes points and projects them onto the constraint set, 37 | # which is the intersection of 38 | # 1. epsilon lp-ball around xs 39 | # 2. valid data pt range [lb, ub] 40 | # it is meant to be used within `self.run` and `self._perturb` 41 | self._proj = None 42 | # a handy flag for _perturb method to denote whether the provided xs is a 43 | # new batch (i.e.
the first iteration within `self.run`) 44 | self.is_new_batch = False 45 | 46 | def result(self): 47 | """ 48 | returns a summary of the attack results (to be tabulated) 49 | :return: 50 | """ 51 | list_loss_queries = self.list_loss_queries[1:].view(-1) 52 | mask = list_loss_queries > 0 53 | list_loss_queries = list_loss_queries[mask] 54 | self.total_loss_queries = int(self.total_loss_queries) 55 | self.total_extra_queries = int(self.total_extra_queries) 56 | self.total_successes = int(self.total_successes) 57 | self.total_failures = int(self.total_failures) 58 | return { 59 | "total_loss_queries": self.total_loss_queries, 60 | "total_extra_queries": self.total_extra_queries, 61 | "average_num_loss_queries": "NaN" if self.total_successes == 0 else self.total_loss_queries / self.total_successes, 62 | "average_num_extra_queries": "NaN" if self.total_successes == 0 else self.total_extra_queries / self.total_successes, 63 | "median_num_loss_queries": "NaN" if self.total_successes == 0 else torch.median(list_loss_queries).item(), 64 | "total_queries": self.total_extra_queries + self.total_loss_queries, 65 | "average_num_queries": "NaN" if self.total_successes == 0 else (self.total_extra_queries + self.total_loss_queries) / self.total_successes, 66 | "total_successes": self.total_successes, 67 | "total_failures": self.total_failures, 68 | "failure_rate": "NaN" if self.total_successes + self.total_failures == 0 else self.total_failures / (self.total_successes + self.total_failures), 69 | "config": self._config() 70 | } 71 | 72 | def _config(self): 73 | """ 74 | return the attack's parameter configurations as a dict 75 | :return: 76 | """ 77 | raise NotImplementedError 78 | 79 | def _perturb(self, xs_t, loss_fct): 80 | """ 81 | :param xs_t: batch_size x dim x .. 
(torch tensor) 82 | :param loss_fct: function to query, which the attacker wants to maximize (batch_size data pts -> R^{batch_size}) 83 | :return: the suggested xs (torch tensor) and the number of queries used per data point, 84 | i.e. a tuple of (batch_size x dim x .. tensor, batch_size tensor of query counts) 85 | """ 86 | raise NotImplementedError 87 | 88 | def proj_replace(self, xs_t, sugg_xs_t, dones_mask_t): 89 | sugg_xs_t = self._proj(sugg_xs_t) 90 | # replace xs only if not done 91 | xs_t = sugg_xs_t * (1. - dones_mask_t) + xs_t * dones_mask_t 92 | return xs_t 93 | 94 | def run(self, xs, loss_fct, early_stop_extra_fct): 95 | """ 96 | attack `xs` using the loss oracle `loss_fct` and 97 | the early-stopping criterion `early_stop_extra_fct` 98 | :param xs: data points to be perturbed adversarially (torch tensor) 99 | :param loss_fct: loss function (m data pts -> R^m) 100 | :param early_stop_extra_fct: early stop function (m data pts -> {0,1}^m) 101 | the ith entry is 1 if the ith data point is misclassified 102 | :return: the perturbed data points (torch tensor, same shape as xs) 103 | """ 104 | # work on a copy so the caller's tensor is not modified 105 | xs_t = torch.clone(xs) 106 | 107 | batch_size = xs.shape[0] 108 | num_axes = len(xs.shape[1:]) 109 | num_loss_queries = torch.zeros(batch_size, device = xs.device) 110 | num_extra_queries = torch.zeros(batch_size, device = xs.device) 111 | 112 | dones_mask = early_stop_extra_fct(xs_t) 113 | correct_classified_mask = ~dones_mask 114 | 115 | # init losses for performance tracking 116 | losses = torch.zeros(batch_size, device = xs.device) 117 | 118 | # make a projector into xs lp-ball and within valid pixel range 119 | if self.p == '2': 120 | _proj = l2_proj_maker(xs_t, self.epsilon) 121 | self._proj = lambda _: torch.clamp(_proj(_), self.lb, self.ub) 122 | elif self.p == 'inf': 123 | _proj = linf_proj_maker(xs_t, self.epsilon) 124 | self._proj = lambda _: torch.clamp(_proj(_), self.lb, self.ub) 125 | else: 126 | raise Exception('Undefined
l-p!') 127 | 128 | # iterate till model evasion or budget exhaustion 129 | self.is_new_batch = True 130 | its = 0 131 | while True: 132 | # stop as soon as any data point exhausts its query budget 133 | if torch.any(num_loss_queries >= self.max_loss_queries): 134 | print("#loss queries exceeded budget, exiting") 135 | break 136 | if torch.any(num_extra_queries >= self.max_extra_queries): 137 | print("#extra_queries exceeded budget, exiting") 138 | break 139 | if torch.all(dones_mask): 140 | print("all data pts are misclassified, exiting") 141 | break 142 | # propose new perturbations 143 | sugg_xs_t, num_loss_queries_per_step = self._perturb(xs_t, loss_fct) 144 | # project onto the epsilon lp-ball around xs and the valid pixel range, 145 | # and replace xs only for points that are not done yet 146 | 147 | xs_t = self.proj_replace(xs_t, sugg_xs_t, (dones_mask.reshape(-1, *[1] * num_axes).float())) 148 | 149 | # update number of queries (note this is done before updating dones_mask) 150 | num_loss_queries += num_loss_queries_per_step * (~dones_mask) 151 | num_extra_queries += (~dones_mask) 152 | losses = loss_fct(xs_t) * (~dones_mask) + losses * dones_mask 153 | 154 | # update dones mask 155 | dones_mask = dones_mask | early_stop_extra_fct(xs_t) 156 | success_mask = dones_mask * correct_classified_mask 157 | its += 1 158 | 159 | self.is_new_batch = False 160 | 161 | 162 | success_mask = dones_mask * correct_classified_mask 163 | self.total_loss_queries += (num_loss_queries * success_mask).sum() 164 | self.total_extra_queries += (num_extra_queries * success_mask).sum() 165 | self.list_loss_queries = torch.cat([self.list_loss_queries, torch.zeros(1, batch_size, device = xs.device)], dim=0) 166 | self.list_loss_queries[-1] = num_loss_queries * success_mask 167 | self.total_successes += success_mask.sum() 168 | self.total_failures += ((~dones_mask) * correct_classified_mask).sum() 169 | 170 | # reset self._proj so that it is only used within `run` 171 | self._proj = None 172 | 173 |
return xs_t 174 | -------------------------------------------------------------------------------- /attacks/signopt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import os 6 | import argparse 7 | from scipy.linalg import qr 8 | from art.estimators.classification import PyTorchClassifier 9 | 10 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 11 | import config 12 | from watermark import Watermark 13 | from attacks.decision import DecisionBlackBoxAttack 14 | 15 | if torch.cuda.is_available(): 16 | t = lambda z: torch.tensor(data = z).cuda() 17 | else: 18 | t = lambda z: torch.tensor(data = z) 19 | 20 | start_learning_rate = 1.0 21 | 22 | def quad_solver(Q, b): 23 | """ 24 | Solve min_a 0.5*a^T Q a + b^T a s.t. a >= 0 25 | """ 26 | K = Q.shape[0] 27 | alpha = torch.zeros(K, dtype = Q.dtype, device = Q.device) 28 | g = b 29 | Qdiag = torch.diag(Q) 30 | for _ in range(20000): 31 | delta = torch.clamp(alpha - g/Qdiag, min = 0) - alpha 32 | idx = torch.argmax(torch.abs(delta)) 33 | val = delta[idx] 34 | if abs(val) < 1e-7: 35 | break 36 | g = g + val*Q[:,idx] 37 | alpha[idx] += val 38 | return alpha 39 | 40 | def sign(y): 41 | """ 42 | y -- torch tensor of shape (m,) 43 | Returns an element-wise indication of the sign of a number. 44 | The sign function returns -1 if y < 0 and 1 if y >= 0. nan is returned for nan inputs.
45 | """ 46 | y_sign = torch.sign(y) 47 | y_sign[y_sign==0] = 1 48 | return y_sign 49 | 50 | 51 | class SignOPTAttack(DecisionBlackBoxAttack): 52 | """ 53 | Sign_OPT 54 | """ 55 | 56 | def __init__(self, epsilon, p, alpha, beta, svm, momentum, max_queries, k, lb, ub, batch_size, sigma): 57 | super().__init__(max_queries = max_queries, 58 | epsilon=epsilon, 59 | p=p, 60 | lb=lb, 61 | ub=ub, 62 | batch_size = batch_size) 63 | self.alpha = alpha 64 | self.beta = beta 65 | self.svm = svm 66 | self.momentum = momentum 67 | self.k = k 68 | self.sigma = sigma 69 | self.query_count = 0 70 | 71 | 72 | def _config(self): 73 | return { 74 | "p": self.p, 75 | "epsilon": self.epsilon, 76 | "lb": self.lb, 77 | "ub": self.ub, 78 | "attack_name": self.__class__.__name__ 79 | } 80 | 81 | def attack_untargeted(self, x0, y0, alpha = 0.2, beta = 0.001): 82 | """ 83 | Attack the original image and return adversarial example 84 | """ 85 | 86 | y0 = y0[0] 87 | self.query_count = 0 88 | 89 | # Calculate a good starting point. 90 | num_directions = 10 91 | best_theta, g_theta = None, float('inf') 92 | 93 | for i in range(num_directions): 94 | self.query_count += 1 95 | theta = torch.randn_like(x0) 96 | if self.predict_label(x0+theta)!=y0: 97 | initial_lbd = torch.norm(theta) 98 | theta /= initial_lbd 99 | lbd, count = self.fine_grained_binary_search(x0, y0, theta, initial_lbd, g_theta) 100 | self.query_count += count 101 | if lbd < g_theta: 102 | best_theta, g_theta = theta, lbd 103 | 104 | if g_theta == float('inf'): 105 | return x0, self.query_count 106 | 107 | # Begin Gradient Descent. 
108 | xg, gg = best_theta, g_theta 109 | vg = torch.zeros_like(xg) 110 | 111 | assert not self.svm 112 | for i in range(1500): 113 | sign_gradient, grad_queries = self.sign_grad_v1(x0, y0, xg, initial_lbd=gg, h=beta) 114 | self.query_count += grad_queries 115 | # Line search 116 | min_theta = xg 117 | min_g2 = gg 118 | min_vg = vg 119 | for _ in range(15): 120 | if self.momentum > 0: 121 | new_vg = self.momentum*vg - alpha*sign_gradient 122 | new_theta = xg + new_vg 123 | else: 124 | new_theta = xg - alpha * sign_gradient 125 | new_theta /= torch.norm(new_theta) 126 | new_g2, count = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd = min_g2, tol=beta/500) 127 | self.query_count += count 128 | alpha = alpha * 2 129 | if new_g2 < min_g2: 130 | min_theta = new_theta 131 | min_g2 = new_g2 132 | if self.momentum > 0: 133 | min_vg = new_vg 134 | else: 135 | break 136 | if min_g2 >= gg: 137 | for _ in range(15): 138 | alpha = alpha * 0.25 139 | if self.momentum > 0: 140 | new_vg = self.momentum*vg - alpha*sign_gradient 141 | new_theta = xg + new_vg 142 | else: 143 | new_theta = xg - alpha * sign_gradient 144 | new_theta /= torch.norm(new_theta) 145 | new_g2, count = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd = min_g2, tol=beta/500) 146 | self.query_count += count 147 | if new_g2 < gg: 148 | min_theta = new_theta 149 | min_g2 = new_g2 150 | if self.momentum > 0: 151 | min_vg = new_vg 152 | break 153 | if alpha < 1e-4: 154 | alpha = 1.0 155 | beta = beta*0.1 156 | if (beta < 1e-8): 157 | break 158 | 159 | xg, gg = min_theta, min_g2 160 | vg = min_vg 161 | 162 | 163 | if self.query_count > self.max_queries: 164 | break 165 | 166 | dist = self.distance(gg*xg) 167 | if dist < self.epsilon: 168 | break 169 | 170 | dist = self.distance(gg*xg) 171 | return x0 + gg*xg, self.query_count 172 | 173 | def sign_grad_v1(self, x0, y0, theta, initial_lbd, h=0.001, D=4, target=None): 174 | """ 175 | Estimate the sign of the gradient with the formula 176
| sign(g) ~= 1/Q * sum_{q=1}^{Q} sign(g(theta + h*u_q) - g(theta)) * u_q, where Q = self.k random unit directions u_q 177 | """ 178 | K = self.k 179 | sign_grad = torch.zeros_like(theta) 180 | queries = 0 181 | for _ in range(K): 182 | u = torch.randn_like(theta) 183 | u /= torch.norm(u) 184 | 185 | sign = 1 186 | new_theta = theta + h*u 187 | new_theta /= torch.norm(new_theta) 188 | 189 | # Targeted case. 190 | if (target is not None and 191 | self.predict_label(x0+initial_lbd*new_theta) == target): 192 | sign = -1 193 | 194 | # Untargeted case 195 | if (target is None and 196 | self.predict_label(x0+t(initial_lbd*new_theta)) != y0): 197 | sign = -1 198 | queries += 1 199 | sign_grad += u*sign 200 | 201 | sign_grad /= K 202 | 203 | return sign_grad, queries 204 | 205 | def fine_grained_binary_search_local(self, x0, y0, theta, initial_lbd = 1.0, tol=1e-5): 206 | nquery = 0 207 | lbd = initial_lbd 208 | 209 | if self.predict_label(x0+lbd*theta) == y0: 210 | lbd_lo = lbd 211 | lbd_hi = lbd*1.01 212 | nquery += 1 213 | while self.predict_label(x0+lbd_hi*theta) == y0: 214 | lbd_hi = lbd_hi*1.01 215 | nquery += 1 216 | if lbd_hi > 20: 217 | return float('inf'), nquery 218 | else: 219 | lbd_hi = lbd 220 | lbd_lo = lbd*0.99 221 | nquery += 1 222 | while self.predict_label(x0+lbd_lo*theta) != y0 : 223 | lbd_lo = lbd_lo*0.99 224 | nquery += 1 225 | if nquery + self.query_count > self.max_queries: 226 | break 227 | 228 | while (lbd_hi - lbd_lo) > tol: 229 | lbd_mid = (lbd_lo + lbd_hi)/2.0 230 | nquery += 1 231 | if nquery + self.query_count > self.max_queries: 232 | break 233 | if self.predict_label(x0 + lbd_mid*theta) != y0: 234 | lbd_hi = lbd_mid 235 | else: 236 | lbd_lo = lbd_mid 237 | return lbd_hi, nquery 238 | 239 | def fine_grained_binary_search(self, x0, y0, theta, initial_lbd, current_best): 240 | nquery = 0 241 | if initial_lbd > current_best: 242 | if self.predict_label(x0+t(current_best*theta)) == y0: 243 | nquery += 1 244 | return float('inf'), nquery 245 | lbd = current_best 246 | else: 247 | lbd =
initial_lbd 248 | 249 | lbd_hi = lbd 250 | lbd_lo = 0.0 251 | 252 | while (lbd_hi - lbd_lo) > 1e-5: 253 | lbd_mid = (lbd_lo + lbd_hi)/2.0 254 | nquery += 1 255 | if nquery + self.query_count > self.max_queries: 256 | break 257 | if self.predict_label(x0 + t(lbd_mid*theta)) != y0: 258 | lbd_hi = lbd_mid 259 | else: 260 | lbd_lo = lbd_mid 261 | return lbd_hi, nquery 262 | 263 | def fine_grained_binary_search_local_targeted(self, x0, target, theta, initial_lbd=1.0, tol=1e-5): 264 | nquery = 0 265 | lbd = initial_lbd 266 | 267 | if self.predict_label(x0 + t(lbd*theta)) != target: 268 | lbd_lo = lbd 269 | lbd_hi = lbd*1.01 270 | nquery += 1 271 | while self.predict_label(x0 + t(lbd_hi*theta)) != target: 272 | lbd_hi = lbd_hi*1.01 273 | nquery += 1 274 | if lbd_hi > 100: 275 | return float('inf'), nquery 276 | else: 277 | lbd_hi = lbd 278 | lbd_lo = lbd*0.99 279 | nquery += 1 280 | while self.predict_label(x0 + t(lbd_lo*theta)) == target: 281 | lbd_lo = lbd_lo*0.99 282 | nquery += 1 283 | 284 | while (lbd_hi - lbd_lo) > tol: 285 | lbd_mid = (lbd_lo + lbd_hi)/2.0 286 | nquery += 1 287 | if self.predict_label(x0 + t(lbd_mid*theta)) == target: 288 | lbd_hi = lbd_mid 289 | else: 290 | lbd_lo = lbd_mid 291 | 292 | return lbd_hi, nquery 293 | 294 | def fine_grained_binary_search_targeted(self, x0, target, theta, initial_lbd, current_best): 295 | nquery = 0 296 | if initial_lbd > current_best: 297 | if self.predict_label(x0 + t(current_best*theta)) != target: 298 | nquery += 1 299 | return float('inf'), nquery 300 | lbd = current_best 301 | else: 302 | lbd = initial_lbd 303 | 304 | lbd_hi = lbd 305 | lbd_lo = 0.0 306 | 307 | while (lbd_hi - lbd_lo) > 1e-5: 308 | lbd_mid = (lbd_lo + lbd_hi)/2.0 309 | nquery += 1 310 | if self.predict_label(x0 + t(lbd_mid*theta)) != target: 311 | lbd_lo = lbd_mid 312 | else: 313 | lbd_hi = lbd_mid 314 | return lbd_hi, nquery 315 | 316 | 317 | def _perturb(self, xs_t, ys): 318 | if self.targeted: 319 | raise NotImplementedError('Targeted Sign-OPT is not implemented in this file.') 320 | else: 321 | adv,
q = self.attack_untargeted(xs_t, ys, self.alpha, self.beta) 322 | 323 | return adv, q 324 | 325 | if __name__ == '__main__': 326 | parser = argparse.ArgumentParser() 327 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18']) 328 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB']) 329 | parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100) 330 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1) 331 | parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true') 332 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10) 333 | args = parser.parse_args() 334 | 335 | # renaming 336 | dataset = eval(f'config.{args.dataset_name}()') 337 | training_set, testing_set = dataset.training_set, dataset.testing_set 338 | num_classes = dataset.num_classes 339 | means, stds = dataset.means, dataset.stds 340 | C, H, W = dataset.C, dataset.H, dataset.W 341 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 342 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2) 343 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 344 | 345 | 346 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}' 347 | 348 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}' 349 | 350 | # load the tail of the model 351 | normalizer = transforms.Normalize(means, stds) 352 | tail = Tail(num_classes) 353 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict')) 354 | tail.to(device) 355 | 356 | # load the classifiers 357 | classifiers = [] 358 | models = [] 359 | for i in range(args.num_models): 360 | head = Head() 361 | 
head.to(device) 362 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict')) 363 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy') 364 | 365 | models.append(nn.Sequential(normalizer, watermark, head, tail).eval()) 366 | models[-1].to(device) 367 | 368 | classifier = PyTorchClassifier( 369 | model = models[-1], 370 | loss = None, # dummy 371 | optimizer = None, # dummy 372 | clip_values = (0, 1), 373 | input_shape=(C, H, W), 374 | nb_classes=num_classes, 375 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu' 376 | ) 377 | classifiers.append(classifier) 378 | classifiers = np.array(classifiers) 379 | 380 | for i, (model, c) in enumerate(zip(models, classifiers)): 381 | if os.path.isfile(f'{save_dir}/head_{i}/SignOPT.npz') and args.cont: 382 | continue 383 | original_images, attacked_images, labels = [], [], [] 384 | count_success = 0 385 | for X, y in testing_loader: 386 | with torch.no_grad(): 387 | pred = c.predict(X.numpy()) 388 | correct_mask = pred.argmax(axis = -1) == y.numpy() 389 | 390 | X_device, y_device = X.to(device), y.to(device) 391 | 392 | a = SignOPTAttack(epsilon = 1, p = '2', alpha = 0.2, beta = 0.001, svm = False, momentum = 0, max_queries = 10000, k = 200, lb = 0, ub = 1, batch_size = 1, sigma = 0) 393 | X_attacked = a.run(X_device, y_device, model, False, None).cpu().numpy() 394 | 395 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers) 396 | 397 | success_mask = attacked_preds.argmax(axis = -1) != y.numpy() 398 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2) 399 | 400 | mask = np.logical_and(correct_mask, success_mask) 401 | 402 | original_images.append(X[mask]) 403 | attacked_images.append(X_attacked[mask]) 404 | labels.append(y[mask]) 405 | 406 | count_success += mask.sum() 407 | if count_success >= args.num_samples: 408 | print(f'Model {i}, attack SignOPT, done!') 409 | break 410 | else: 411 | print(f'Model {i}, attack 
SignOPT, {count_success} out of {args.num_samples} generated...') 412 | 413 | original_images = np.concatenate(original_images) 414 | attacked_images = np.concatenate(attacked_images) 415 | labels = np.concatenate(labels) 416 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True) 417 | np.savez(f'{save_dir}/head_{i}/SignOPT.npz', X = original_images, X_attacked = attacked_images, y = labels) 418 | -------------------------------------------------------------------------------- /attacks/simba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import os 6 | import argparse 7 | from art.estimators.classification import PyTorchClassifier 8 | from art.attacks.evasion import SimBA 9 | 10 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 11 | import config 12 | from watermark import Watermark 13 | 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18']) 18 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB']) 19 | parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100) 20 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1) 21 | parser.add_argument('-c', '--cont', help = 'Continue from the stopped point last time.', action = 'store_true') 22 | parser.add_argument('-d', '--domain', help = 'Choose the domain of the attack.', choices = ['dct', 'px'], default = 'px') 23 | parser.add_argument('-b', '--batch_size', help = 'The batch size used for attacks.', type = int, default = 10) 24 | parser.add_argument('-v', '--verbose', help = 'Verbose when attacking.', action = 'store_true') 25 | args = parser.parse_args() 26 | 
27 | # renaming 28 | dataset = eval(f'config.{args.dataset_name}()') 29 | training_set, testing_set = dataset.training_set, dataset.testing_set 30 | num_classes = dataset.num_classes 31 | means, stds = dataset.means, dataset.stds 32 | C, H, W = dataset.C, dataset.H, dataset.W 33 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 34 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = 2) 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | 37 | model_dir = f'saved_models/{args.model_name}-{args.dataset_name}' 38 | 39 | save_dir = f'saved_adv_examples/{args.model_name}-{args.dataset_name}' 40 | 41 | 42 | # load the tail of the model 43 | normalizer = transforms.Normalize(means, stds) 44 | tail = Tail(num_classes) 45 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict')) 46 | tail.to(device) 47 | 48 | # load the classifiers 49 | classifiers = [] 50 | for i in range(args.num_models): 51 | head = Head() 52 | head.to(device) 53 | head.load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict')) 54 | watermark = Watermark.load(f'{model_dir}/head_{i}/watermark.npy') 55 | 56 | classifier = PyTorchClassifier( 57 | model = nn.Sequential(normalizer, watermark, head, tail, nn.Softmax(dim = -1)).eval(), 58 | loss = None, # dummy 59 | optimizer = None, # dummy 60 | clip_values = (0, 1), 61 | input_shape=(C, H, W), 62 | nb_classes=num_classes, 63 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu' 64 | ) 65 | classifiers.append(classifier) 66 | classifiers = np.array(classifiers) 67 | 68 | # attacking 69 | for i, c in enumerate(classifiers): 70 | if os.path.isfile(f'{save_dir}/head_{i}/SimBA-{args.domain}.npz') and args.cont: 71 | continue 72 | 73 | original_images, attacked_images, labels = [], [], [] 74 | count_success = 0 75 | 76 | for X, y in testing_loader: 77 | X, y = X.numpy(), y.numpy() 78 | pred = c.predict(X) 79 | correct_mask = 
pred.argmax(axis = 1) == y 80 | 81 | a = SimBA(c, attack = args.domain, verbose = args.verbose) 82 | 83 | X_attacked = a.generate(X) 84 | attacked_preds = np.vectorize(lambda z: z.predict(X_attacked), signature = '()->(m,n)')(classifiers) # (num_model, batch_size, num_class) 85 | success_mask = attacked_preds.argmax(axis = -1) != y 86 | success_mask = np.logical_and(success_mask[i], success_mask.sum(axis=0) >= 2) 87 | mask = np.logical_and(correct_mask, success_mask) 88 | 89 | original_images.append(X[mask]) 90 | attacked_images.append(X_attacked[mask]) 91 | labels.append(y[mask]) 92 | 93 | count_success += mask.sum() 94 | if count_success >= args.num_samples: 95 | print(f'Head {i}, attack SimBA-{args.domain}, {count_success} out of {args.num_samples} generated, done!') 96 | break 97 | else: 98 | print(f'Head {i}, attack SimBA-{args.domain}, {count_success} out of {args.num_samples} generated...') 99 | 100 | original_images = np.concatenate(original_images) 101 | attacked_images = np.concatenate(attacked_images) 102 | labels = np.concatenate(labels) 103 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True) 104 | np.savez(f'{save_dir}/head_{i}/SimBA-{args.domain}.npz', X = original_images, X_attacked = attacked_images, y = labels) 105 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torchvision.datasets as datasets 3 | 4 | 5 | # CIFAR10 6 | class CIFAR10: 7 | def __init__(self): 8 | 9 | transform_train = transforms.Compose([ 10 | transforms.RandomCrop(32, padding=4), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor() 13 | ]) 14 | 15 | transform_test = transforms.ToTensor() 16 | 17 | self.C, self.H, self.W = 3, 32, 32 18 | self.means, self.stds = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 19 | self.training_set = datasets.CIFAR10(root = f'./data', train = True, 
transform = transform_train, download = True) 20 | self.testing_set = datasets.CIFAR10(root = f'./data', train = False, transform = transform_test, download = True) 21 | self.num_classes = 10 22 | self.dataset = datasets.CIFAR10(root = f'./data', train = False, transform = None, download = True) 23 | 24 | # GTSRB 25 | class GTSRB: 26 | def __init__(self): 27 | 28 | transform_train = transforms.Compose([ 29 | transforms.Resize((32, 32)), 30 | transforms.RandomCrop(32, padding=4), 31 | transforms.ToTensor() 32 | ]) 33 | 34 | transform_test = transforms.Compose([ 35 | transforms.Resize((32, 32)), 36 | transforms.ToTensor() 37 | ]) 38 | 39 | self.means, self.stds = (0.3337, 0.3064, 0.3171), (0.2672, 0.2564, 0.2629) 40 | self.C, self.H, self.W = 3, 32, 32 41 | self.training_set = datasets.GTSRB(root = f'./data', split = 'train', transform = transform_train, download = True) 42 | self.testing_set = datasets.GTSRB(root = f'./data', split = 'test', transform = transform_test, download = True) 43 | self.dataset = datasets.GTSRB(root = f'./data', split = 'train', transform = None, download = True) 44 | self.num_classes = 43 45 | 46 | # TINY 47 | class TINY: 48 | def __init__(self): 49 | 50 | transform_train = transforms.Compose([ 51 | transforms.RandomRotation(20), 52 | transforms.RandomHorizontalFlip(0.5), 53 | transforms.ToTensor() 54 | ]) 55 | 56 | transform_test = transforms.Compose([ 57 | transforms.ToTensor() 58 | ]) 59 | 60 | self.C, self.H, self.W = 3, 64, 64 61 | self.means, self.stds = (0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262) 62 | self.training_set = datasets.ImageFolder('./data/tiny-imagenet-200/train', transform_train) 63 | self.testing_set = datasets.ImageFolder('./data/tiny-imagenet-200/test', transform_test) 64 | self.num_classes = 200 65 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg16 import 
VGG16Head, VGG16Tail 2 | from .resnet18 import ResNet18Head, ResNet18Tail 3 | -------------------------------------------------------------------------------- /models/resnet18.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | ''' 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride=1): 14 | super().__init__() 15 | self.conv1 = nn.Conv2d( 16 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 19 | stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, 26 | kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion*planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | out = F.relu(out) 35 | return out 36 | 37 | def ResNet18Block(block, in_planes, planes, num_blocks, stride): 38 | strides = [stride] + [1]*(num_blocks-1) 39 | layers = [] 40 | for s in strides: 41 | layers.append(block(in_planes, planes, s)) 42 | in_planes = planes * block.expansion 43 | return nn.Sequential(*layers) 44 | 45 | def ResNet18Head(): 46 | return nn.Sequential( 47 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False), 48 | nn.BatchNorm2d(64), 49 | nn.ReLU(inplace=True), 50 | ResNet18Block(BasicBlock, 64, 64, 2, stride=1) 51 | ) 52 | 53 | 54 | class ResNet18Tail(nn.Module): 55 | def __init__(self, num_classes): 56 | 
super().__init__() 57 | 58 | self.layer2 = ResNet18Block(BasicBlock, 64, 128, 2, stride=2) 59 | self.layer3 = ResNet18Block(BasicBlock, 128, 256, 2, stride=2) 60 | self.layer4 = ResNet18Block(BasicBlock, 256, 512, 2, stride=2) 61 | self.pool1d = nn.AdaptiveAvgPool1d(512) 62 | self.linear = nn.Linear(512*BasicBlock.expansion, num_classes) 63 | 64 | def forward(self, x): 65 | out = self.layer2(x) 66 | out = self.layer3(out) 67 | out = self.layer4(out) 68 | 69 | out = F.avg_pool2d(out, 4) 70 | out = out.view(out.size(0), -1) 71 | out = self.pool1d(out) 72 | out = self.linear(out) 73 | return out 74 | 75 | -------------------------------------------------------------------------------- /models/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def VGG16Head(): 5 | return nn.Sequential( 6 | nn.Conv2d(3, 64, kernel_size = 3, padding = 1), 7 | nn.BatchNorm2d(64), 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(64, 64, kernel_size = 3, padding = 1), 10 | nn.BatchNorm2d(64), 11 | nn.ReLU(inplace=True), 12 | nn.MaxPool2d(kernel_size=2, stride=2) 13 | ) 14 | 15 | class VGG16Tail(nn.Module): 16 | def __init__(self, num_classes): 17 | super().__init__() 18 | self.features = self._make_layers([128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']) 19 | self.pool1d = nn.AdaptiveAvgPool1d(512) 20 | self.classifier = nn.Linear(512, num_classes) 21 | 22 | def forward(self, x): 23 | out = self.features(x) 24 | out = out.view(out.size(0), -1) 25 | out = self.pool1d(out) 26 | out = self.classifier(out) 27 | return out 28 | 29 | def _make_layers(self, cfg): 30 | layers = [] 31 | in_channels = 64 32 | for x in cfg: 33 | if x == 'M': 34 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 35 | else: 36 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(x), 38 | nn.ReLU(inplace=True)] 39 | in_channels = x 40 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 41 | 
return nn.Sequential(*layers) 42 | -------------------------------------------------------------------------------- /pics/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/adv_tracing/6ec6226d2d5728902a1b54c6c44b50c4ff593750/pics/framework.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.12.1+cu113 3 | torchvision==0.13.1+cu113 4 | numpy==1.21.5 5 | adversarial-robustness-toolbox==1.10.3 6 | -------------------------------------------------------------------------------- /trace_data_free.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import argparse 6 | import logging 7 | from art.estimators.classification import PyTorchClassifier 8 | 9 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 10 | import config 11 | from watermark import Watermark 12 | 13 | 14 | def get_classifier(watermark, model, means, stds, num_class): 15 | return PyTorchClassifier( 16 | model = nn.Sequential(transforms.Normalize(means, stds), watermark, model, nn.Softmax(dim = -1)).eval(), 17 | loss = None, # dummy 18 | optimizer = None, # dummy 19 | input_shape=(C, H, W), 20 | clip_values = (0, 1), 21 | nb_classes=num_class, 22 | device_type = 'gpu' if torch.cuda.is_available() else 'cpu' 23 | ) 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--model_name', default='ResNet18', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18']) 28 | parser.add_argument('--dataset_name', default='CIFAR10', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY']) 29 
| parser.add_argument('--attacks', default='Bandit', help = 'Attacks to be explored.', nargs = '+') 30 | parser.add_argument('--alpha', help = 'Hyper-parameter alpha.', type = float) 31 | parser.add_argument('-M', '--num_models', help = 'The number of models used for identification.', type = int, default = 50) 32 | parser.add_argument('-n', '--num_samples', help = 'The number of adversarial samples per model.', type = int, default = 1) 33 | 34 | args = parser.parse_args() 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | 37 | 38 | dataset = eval(f'config.{args.dataset_name}()') 39 | training_set, testing_set = dataset.training_set, dataset.testing_set 40 | num_classes = dataset.num_classes 41 | means, stds = dataset.means, dataset.stds 42 | C, H, W = dataset.C, dataset.H, dataset.W 43 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 44 | normalizer = transforms.Normalize(means, stds) 45 | 46 | 47 | model_dir = f'./saved_models/{args.model_name}-{args.dataset_name}' 48 | adv_dir = f'./saved_adv_examples/{args.model_name}-{args.dataset_name}' 49 | 50 | # load the tail of the model 51 | normalizer = transforms.Normalize(means, stds) 52 | tail = Tail(num_classes) 53 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict')) 54 | tail.to(device) 55 | tail.eval() 56 | 57 | # load the classifiers 58 | heads, watermarks, models = [], [], [] 59 | for i in range(args.num_models): 60 | heads.append(Head()) 61 | heads[-1].to(device) 62 | heads[-1].load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict')) 63 | heads[-1].eval() 64 | watermarks.append(Watermark.load(f'{model_dir}/head_{i}/watermark.npy')) 65 | models.append(nn.Sequential(heads[-1], tail)) 66 | 67 | for a in args.attacks: 68 | correct = 0 69 | Loss = nn.CrossEntropyLoss() 70 | for i in range(args.num_models): 71 | adv_npz = np.load(f'{adv_dir}/head_{i}/{a}.npz') 72 | X, X_attacked, y = adv_npz['X'][:args.num_samples], 
adv_npz['X_attacked'][:args.num_samples], adv_npz['y'][:args.num_samples] 73 | 74 | classifier_matrix = np.array([[get_classifier(wm, m, means, stds, num_classes) for wm in watermarks] for m in models]) 75 | predictions = np.vectorize(lambda c: c.predict(X_attacked), signature='()->(m,n)')(classifier_matrix) 76 | 77 | X, X_attacked, y = torch.tensor(X).to(device), torch.tensor(X_attacked).to(device), torch.tensor(y).to(device) 78 | CE_loss = torch.stack([Loss(tail(head(wm(normalizer(X_attacked)))).softmax(-1), y) for wm, head in zip(watermarks, heads)], axis = 0).cpu() 79 | 80 | 81 | out = torch.stack([tail(head(wm(normalizer(X_attacked)))).argmax(axis = -1) for wm, head in zip(watermarks, heads)], axis = 0) 82 | wrong_pred = (out == y[None,:]).sum(-1) > 0 83 | 84 | predictions_maximum_class = predictions.max(axis = -1) 85 | 86 | maximum_class_score = predictions_maximum_class[np.arange(args.num_models), np.arange(args.num_models), ...] / predictions_maximum_class.sum(1) 87 | maximum_class_score = torch.from_numpy(maximum_class_score).sum(-1) 88 | 89 | score = maximum_class_score + args.alpha * CE_loss 90 | score[wrong_pred]=np.inf 91 | result = score.topk(1, axis = 0, largest=False)[1] 92 | 93 | correct += torch.sum(result == i).item() 94 | print((f'Attack {a}, tracing accuracy {correct / args.num_models}.')) 95 | -------------------------------------------------------------------------------- /trace_data_limited.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import numpy as np 5 | import argparse 6 | 7 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 8 | import config 9 | from watermark import Watermark 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_name', default='ResNet18', help='Benchmark model structure.', choices=['VGG16', 'ResNet18']) 15 | 
parser.add_argument('--dataset_name', default='CIFAR10', help='Benchmark dataset used.', choices=['CIFAR10', 'GTSRB', 'TINY']) 16 | parser.add_argument('--attacks', default='Bandit', help='Attacks to be explored.', nargs='+') 17 | parser.add_argument('--alpha', help='Hyper-parameter alpha.', type=float) 18 | parser.add_argument('-M', '--num_models', help='The number of models used for identification.', type=int, default=50) 19 | parser.add_argument('-n', '--num_samples', help='The number of adversarial samples per model.', type=int, default=1) 20 | 21 | args = parser.parse_args() 22 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | 24 | 25 | dataset = eval(f'config.{args.dataset_name}()') 26 | training_set, testing_set = dataset.training_set, dataset.testing_set 27 | num_classes = dataset.num_classes 28 | means, stds = dataset.means, dataset.stds 29 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 30 | 31 | model_dir = f'./saved_models/{args.model_name}-{args.dataset_name}' 32 | adv_dir = f'./saved_adv_examples/{args.model_name}-{args.dataset_name}' 33 | 34 | # load the tail of the model 35 | normalizer = transforms.Normalize(means, stds) 36 | tail = Tail(num_classes) 37 | tail.load_state_dict(torch.load(f'{model_dir}/base_tail_state_dict')) 38 | tail.to(device) 39 | tail.eval() 40 | 41 | # load the classifiers 42 | heads, watermarks = [], [] 43 | for i in range(args.num_models): 44 | heads.append(Head()) 45 | heads[-1].to(device) 46 | heads[-1].load_state_dict(torch.load(f'{model_dir}/head_{i}/state_dict')) 47 | heads[-1].eval() 48 | watermarks.append(Watermark.load(f'{model_dir}/head_{i}/watermark.npy')) 49 | overall_acc = 0 50 | 51 | for a in args.attacks: 52 | correct = 0 53 | for i in range(args.num_models): 54 | adv_npz = np.load(f'{adv_dir}/head_{i}/{a}.npz') 55 | Loss = nn.CrossEntropyLoss() 56 | X, X_attacked, y = adv_npz['X'][:args.num_samples], adv_npz['X_attacked'][:args.num_samples], 
adv_npz['y'][:args.num_samples] 57 | X, X_attacked, y = torch.tensor(X).to(device), torch.tensor(X_attacked).to(device), torch.tensor(y).to(device) 58 | 59 | CE_loss = torch.stack([Loss(tail(head(wm(normalizer(X_attacked)))).softmax(-1), y) for wm, head in zip(watermarks, heads)], axis = 0) 60 | 61 | diffs_sum = torch.stack([wm.get_values(torch.abs(X - X_attacked)).sum() for wm in watermarks], axis = 0) 62 | 63 | score = diffs_sum + args.alpha * CE_loss 64 | wrong_pred_list = [] 65 | 66 | out = torch.stack([tail(head(wm(normalizer(X_attacked)))).argmax(axis = -1) for wm, head in zip(watermarks, heads)], axis = 0) 67 | wrong_pred = (out == y[None,:]).sum(-1) > 0 68 | 69 | score[wrong_pred] = np.inf 70 | 71 | 72 | result = score.topk(1, largest=False)[1] 73 | correct += torch.sum(result == i).item() 74 | print((f'Attack {a}, tracing accuracy {correct / args.num_models}.')) 75 | 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | import os 6 | import argparse 7 | 8 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 9 | import config 10 | from watermark import Watermark 11 | 12 | 13 | ''' 14 | Train the multi-head-one-tail model. 
15 | ''' 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18']) 20 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY']) 21 | parser.add_argument('--num_workers', help = 'Number of workers.', type = int, default = 2) 22 | parser.add_argument('-N', '--num_heads', help = 'Number of heads.', type = int, default = 100) 23 | parser.add_argument('-b', '--batch_size', help = 'Batch size.', type = int, default = 128) 24 | parser.add_argument('-e', '--num_epochs', help = 'Number of epochs.', type = int, default = 10) 25 | parser.add_argument('-lr', '--learning_rate', help = 'Learning rate.', type = float, default = 1e-3) 26 | parser.add_argument('-md', '--masked_dims', help = 'Number of masked dimensions.', type = int, default = 100) 27 | 28 | args = parser.parse_args() 29 | 30 | if args.dataset_name == 'CIFAR10' or args.dataset_name == 'GTSRB': 31 | C, H, W = 3, 32, 32 32 | elif args.dataset_name == 'TINY': 33 | C, H, W = 3, 64, 64 34 | 35 | # Create the model and the dataset 36 | dataset = eval(f'config.{args.dataset_name}()') 37 | training_set, testing_set = dataset.training_set, dataset.testing_set 38 | num_classes = dataset.num_classes 39 | means, stds = dataset.means, dataset.stds 40 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 41 | normalizer = transforms.Normalize(means, stds) 42 | training_loader = torch.utils.data.DataLoader(training_set, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers) 43 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers) 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | 46 | # Place to save the trained model 47 | save_dir = f'saved_models/{args.model_name}-{args.dataset_name}' 48 |
os.makedirs(save_dir, exist_ok = True) 49 | 50 | # Load the tail of the model 51 | tail = Tail(num_classes) 52 | tail.load_state_dict(torch.load(f'{save_dir}/base_tail_state_dict')) 53 | 54 | tail.to(device) 55 | 56 | 57 | # training 58 | 59 | for i in range(args.num_heads): 60 | 61 | os.makedirs(f'{save_dir}/head_{i}', exist_ok = True) 62 | 63 | head = nn.Sequential(Watermark.random(args.masked_dims, C, H, W), Head()) 64 | 65 | head.to(device) 66 | head[0].save(f'{save_dir}/head_{i}/watermark.npy') 67 | head[1].load_state_dict(torch.load(f'{save_dir}/base_head_state_dict')) 68 | optimizer = torch.optim.Adam(head.parameters(), lr = args.learning_rate) 69 | Loss = nn.CrossEntropyLoss() 70 | best_accuracy = 0. 71 | 72 | for n in range(args.num_epochs): 73 | head.train() 74 | epoch_mask_grad_norm, epoch_mask_grad_norm_inverse = 0., 0. 75 | epoch_loss = 0.0 76 | for X, y in training_loader: 77 | X, y = X.to(device), y.to(device) 78 | optimizer.zero_grad() 79 | out_clean = tail(head(normalizer(X))) 80 | clean_loss = Loss(out_clean, y) 81 | loss = clean_loss 82 | loss.backward() 83 | optimizer.step() 84 | epoch_loss += loss.item() * len(y) / len(training_set) 85 | 86 | # testing 87 | head.eval() 88 | tail.eval() 89 | 90 | accuracy = 0.0 91 | with torch.no_grad(): 92 | for X, y in testing_loader: 93 | X, y = X.to(device), y.to(device) 94 | _, pred = tail(head(normalizer(X))).max(axis = -1) 95 | accuracy += (pred == y).sum().item() / len(testing_set) 96 | 97 | print(f'Head {i}, epoch {n}, loss {epoch_loss:.3f}, accuracy = {accuracy:.4f}') 98 | 99 | # save the best result 100 | if accuracy > best_accuracy: 101 | best_accuracy = accuracy 102 | torch.save(head[1].state_dict(), f'{save_dir}/head_{i}/state_dict') 103 | 104 | print(f'Completed the training for head {i}, accuracy = {best_accuracy:.4f}.') 105 | print(f'Completed the training of {args.num_heads} heads, {args.model_name}-{args.dataset_name}.') 106 | 
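Each head trained above sits behind its own `Watermark.random(args.masked_dims, C, H, W)` mask, so a copy is blind to the coordinates it zeroes out. `trace_data_limited.py` exploits this by summing the perturbation `|X - X_attacked|` over each copy's masked coordinates (`diffs_sum`) and picking the copy with the lowest score. The NumPy sketch below isolates that one term under an idealized assumption: the simulated attack perturbs every coordinate except the source copy's masked ones. `random_mask`, `perturbation_on_mask` and `trace_copy` are hypothetical helper names, and the noise is synthetic, not output from SimBA or the other attacks in `attacks/`.

```python
import numpy as np


def random_mask(num_masked_dims, C, H, W, rng):
    # Hypothetical helper mirroring Watermark.random: distinct flat indices
    # into a C*H*W image, returned as (channel, row, col) rows.
    idx = rng.choice(C * H * W, size=num_masked_dims, replace=False)
    return np.stack([idx // (H * W), (idx // W) % H, idx % W], axis=-1)


def perturbation_on_mask(x, x_adv, loc):
    # L1 size of the perturbation restricted to one copy's masked coordinates
    # (the diffs_sum term of trace_data_limited.py, without alpha * CE_loss).
    diff = np.abs(x - x_adv)
    return diff[..., loc[:, 0], loc[:, 1], loc[:, 2]].sum()


def trace_copy(x, x_adv, masks):
    # Index of the copy whose masked coordinates carry the least perturbation.
    scores = np.array([perturbation_on_mask(x, x_adv, loc) for loc in masks])
    return int(scores.argmin()), scores


# Idealized simulation (assumption): the attack perturbs every coordinate
# except those masked by the source copy, to which that copy is blind.
rng = np.random.default_rng(0)
C, H, W, num_copies, source = 3, 32, 32, 5, 2
masks = [random_mask(100, C, H, W, rng) for _ in range(num_copies)]
x = rng.random((C, H, W))
noise = 0.03 * rng.choice([-1.0, 1.0], size=(C, H, W))
src_loc = masks[source]
noise[src_loc[:, 0], src_loc[:, 1], src_loc[:, 2]] = 0.0
x_adv = np.clip(x + noise, 0.0, 1.0)

traced, scores = trace_copy(x, x_adv, masks)
print(traced, scores.round(3))
```

In `trace_data_limited.py` this term is combined with `args.alpha * CE_loss`, and any copy that still predicts the true label on `X_attacked` has its score set to infinity; the sketch leaves both of those steps out.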
-------------------------------------------------------------------------------- /train_base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import argparse 5 | 6 | from torchvision import transforms 7 | from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail 8 | import config 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18']) 13 | parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB', 'TINY']) 14 | parser.add_argument('--num_workers', help = 'Number of workers', type = int, default = 2) 15 | parser.add_argument('-b', '--batch_size', help = 'Batch size.', type = int, default = 128) 16 | parser.add_argument('-e', '--num_epochs', help = 'Number of epochs.', type = int, default = 50) 17 | parser.add_argument('-lr', '--learning_rate', help = 'Learning rate.', type = float, default = 1e-3) 18 | args = parser.parse_args() 19 | 20 | # Create the model and the dataset 21 | dataset = eval(f'config.{args.dataset_name}()') 22 | training_set, testing_set = dataset.training_set, dataset.testing_set 23 | num_classes = dataset.num_classes 24 | means, stds = dataset.means, dataset.stds 25 | Head, Tail = eval(f'{args.model_name}Head'), eval(f'{args.model_name}Tail') 26 | base_model = nn.Sequential(transforms.Normalize(means, stds), Head(), Tail(num_classes)) 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 28 | base_model.to(device) 29 | training_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, shuffle = True, num_workers = args.num_workers) 30 | testing_loader = torch.utils.data.DataLoader(testing_set, batch_size=args.batch_size, num_workers = args.num_workers) 31 | print(f'The head has {sum(p.numel() for p in base_model[1].parameters())} 
parameters, the tail has {sum(p.numel() for p in base_model[2].parameters())} parameters.') 32 | 33 | # Place to save the trained model 34 | save_dir = f'saved_models/{args.model_name}-{args.dataset_name}' 35 | os.makedirs(save_dir, exist_ok = True) 36 | 37 | # Prepare for training 38 | optimizer = torch.optim.Adam(base_model.parameters(), lr = args.learning_rate) 39 | Loss = nn.CrossEntropyLoss() 40 | 41 | # training 42 | best_accuracy = 0.0 43 | for n in range(args.num_epochs): 44 | 45 | base_model.train() 46 | epoch_loss = 0.0 47 | for X, y in training_loader: 48 | X, y = X.to(device), y.to(device) 49 | optimizer.zero_grad() 50 | loss = Loss(base_model(X), y) 51 | loss.backward() 52 | optimizer.step() 53 | epoch_loss += loss.item() * len(y) / len(training_set) 54 | 55 | # testing 56 | base_model.eval() 57 | accuracy = 0.0 58 | with torch.no_grad(): 59 | for X, y in testing_loader: 60 | X, y = X.to(device), y.to(device) 61 | _, pred = base_model(X).max(axis = -1) 62 | accuracy += (pred == y).sum().item() / len(testing_set) 63 | 64 | print(f'Epoch {n}, loss {epoch_loss:.3f}, accuracy = {accuracy:.4f}.') 65 | 66 | # save the best result 67 | if accuracy > best_accuracy: 68 | best_accuracy = accuracy 69 | torch.save(base_model[1].state_dict(), f'{save_dir}/base_head_state_dict') 70 | torch.save(base_model[2].state_dict(), f'{save_dir}/base_tail_state_dict') 71 | 72 | print(f'Completed the training of the base model, {args.model_name}-{args.dataset_name}, accuracy = {best_accuracy:.4f}.') 73 | -------------------------------------------------------------------------------- /watermark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | 7 | class Watermark(nn.Module): 8 | def __init__(self, locations: np.array): 9 | ''' 10 | locations: (N, 3) [[cha0, row0, col0], [cha1, row1, col1], [cha2, row2, col2], ...] 
11 | ''' 12 | super().__init__() 13 | assert len(locations.shape) == 2 and locations.shape[1] == 3 14 | self.locations = locations 15 | 16 | 17 | def forward(self, X): 18 | C, H, W = X.shape[-3:] 19 | if isinstance(X, torch.Tensor): 20 | mask = torch.ones_like(X, dtype = X.dtype, device = X.device) 21 | mask[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]] = 0.0 22 | return X * mask 23 | 24 | elif isinstance(X, np.ndarray): 25 | out = X.copy() 26 | out[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]] = 0.0 27 | return out 28 | 29 | else: 30 | raise TypeError 31 | 32 | def get_values(self, X): 33 | return X[..., self.locations[:, 0], self.locations[:, 1], self.locations[:, 2]] 34 | 35 | def save(self, fn): 36 | np.save(fn, self.locations) 37 | 38 | 39 | @staticmethod 40 | def load(fn): 41 | return Watermark(np.load(fn)) 42 | 43 | @staticmethod 44 | def random(num_masked_dims, C, H, W): 45 | indices = np.random.choice(C * H * W, size = num_masked_dims, replace = False) 46 | watermark = Watermark(np.stack([indices // (H * W), (indices // W) % H, indices % W], axis = -1)) 47 | return watermark 48 | 49 | @staticmethod 50 | def random_list(num_masked_dims, C, H, W, mask_list): 51 | indices = np.random.choice(mask_list, size = num_masked_dims, replace = False) 52 | watermark = Watermark(np.stack([indices // (H * W), (indices // W) % H, indices % W], axis = -1)) 53 | return watermark 54 | --------------------------------------------------------------------------------
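`Watermark.random` draws distinct flat indices and splits each one by hand into a `(channel, row, col)` triple, which `forward` and `get_values` then use for fancy indexing. A small NumPy check, assuming the row-major `(C, H, W)` layout of the tensors used throughout this repo, shows that the manual split matches `np.unravel_index` and that the NumPy branch of `forward` zeroes exactly `num_masked_dims` entries in every image. The shape and mask size are the CIFAR-10 and `--masked_dims` defaults; the batch `X` is synthetic.

```python
import numpy as np

C, H, W = 3, 32, 32          # CIFAR-10 / GTSRB input shape from config.py
num_masked_dims = 100        # default --masked_dims in train.py
rng = np.random.default_rng(0)

# Same recipe as Watermark.random: distinct flat indices into C*H*W,
# split by hand into (channel, row, col).
indices = rng.choice(C * H * W, size=num_masked_dims, replace=False)
locations = np.stack([indices // (H * W), (indices // W) % H, indices % W], axis=-1)

# NumPy's row-major unravel_index gives the same triples.
reference = np.stack(np.unravel_index(indices, (C, H, W)), axis=-1)
same_as_unravel = np.array_equal(locations, reference)

# Masking as in the NumPy branch of Watermark.forward: copy, then zero the
# selected entries of every image in the batch (X is synthetic data).
X = rng.random((4, C, H, W))
out = X.copy()
out[..., locations[:, 0], locations[:, 1], locations[:, 2]] = 0.0

# Count, per image, how many entries the mask actually changed.
changed_per_image = (out != X).reshape(len(X), -1).sum(axis=1)
print(same_as_unravel, changed_per_image)
```

Because the leading `...` in the index broadcasts over the batch dimension, one `(N, 3)` location array masks a whole batch at once, which is why the same `Watermark` object can sit inside `nn.Sequential` regardless of batch size.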