├── adversarialbox ├── __init__.py ├── __pycache__ │ ├── train.cpython-36.pyc │ ├── train.cpython-37.pyc │ ├── train.cpython-38.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-37.pyc │ ├── utils.cpython-38.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── attacks.cpython-36.pyc │ ├── attacks.cpython-37.pyc │ └── attacks.cpython-38.pyc ├── train.py ├── utils.py └── attacks.py ├── models ├── __pycache__ │ ├── DeiT.cpython-37.pyc │ ├── DeiT.cpython-38.pyc │ ├── drop.cpython-37.pyc │ ├── drop.cpython-38.pyc │ ├── mlp.cpython-37.pyc │ ├── mlp.cpython-38.pyc │ ├── patch_embed.cpython-37.pyc │ ├── patch_embed.cpython-38.pyc │ ├── vision_transformer.cpython-37.pyc │ └── vision_transformer.cpython-38.pyc ├── patch_embed.py ├── mlp.py ├── model_configs.py ├── drop.py ├── DeiT.py ├── vision_transformer.py └── modeling.py ├── LICENSE ├── README.md ├── utils.py └── main_patch_vit.py /adversarialbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__pycache__/DeiT.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/DeiT.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/DeiT.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/DeiT.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/drop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/drop.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/drop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/drop.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/mlp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/mlp.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/patch_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/patch_embed.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/patch_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/patch_embed.cpython-38.pyc -------------------------------------------------------------------------------- 
/adversarialbox/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/train.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-38.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/attacks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-36.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/attacks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-37.pyc -------------------------------------------------------------------------------- /adversarialbox/__pycache__/attacks.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/vision_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/vision_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/vision_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/vision_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mxzheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrojViT: Trojan Insertion in Vision Transformers 2 | Mengxin Zheng, Qian Lou, Lei Jiang 3 | 4 | Accepted at CVPR 2023 [[Paper Link](https://arxiv.org/abs/2208.13049)]. 5 | 6 | ## Overview 7 | 8 | - We propose a new attack framework, *TrojViT*, to breach the security of ViTs by creating a novel, stealthy, and practical ViT-specific backdoor attack TrojViT. 9 | 10 | - We evaluate *TrojViT* on vit, deit and swin transformer. 11 | 12 | 13 |

14 | 15 |

16 | 17 | 18 | ## Code Usage 19 | Our code supports the *TrojViT* attack on SOTA Vision Transformers (e.g., DeiT-T, DeiT-S, and DeiT-B) on the ImageNet validation dataset. 20 | 21 | ### Key parameters 22 | ```--data_dir```: Path to the ImageNet folder. 23 | 24 | ```--dataset_size```: Evaluate on a subset of the full dataset. 25 | 26 | ```--patch_select```: Select patches based on the saliency map, attention map, or random selection. 27 | 28 | ```--num_patch```: Number of perturbed patches. 29 | 30 | ```--sparse_pixel_num```: Total number of perturbed pixels in the whole image. 31 | 32 | ```--attack_mode```: Optimize TrojViT based on the final cross-entropy loss only, or on both the cross-entropy loss and the attention map. 33 | 34 | ```--attn_select```: Which attention layer to use for patch selection. 35 | 36 | 37 | 38 | ## Citation 39 | ``` 40 | 41 | @inproceedings{zheng2023trojvit, 42 | title={Trojvit: Trojan insertion in vision transformers}, 43 | author={Zheng, Mengxin and Lou, Qian and Jiang, Lei}, 44 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 45 | pages={4025--4034}, 46 | year={2023} 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /adversarialbox/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial training 3 | """ 4 | 5 | import copy 6 | import numpy as np 7 | from collections.abc import Iterable 8 | from scipy.stats import truncnorm 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from adversarialbox.attacks import FGSMAttack, LinfPGDAttack 14 | from adversarialbox.utils import truncated_normal 15 | 16 | 17 | 18 | def adv_train(X, y, model, criterion, adversary): 19 | """ 20 | Adversarial training. Returns the perturbed mini batch. 21 | """ 22 | 23 | # If adversarial training, need a snapshot of 24 | # the model at each batch to compute grad, so 25 | # as not to mess up with the optimization step 26 | model_cp = copy.deepcopy(model) 27 | for p in model_cp.parameters(): 28 | p.requires_grad = False 29 | model_cp.eval() 30 | 31 | adversary.model = model_cp 32 | 33 | X_adv = adversary.perturb(X.numpy(), y) 34 | 35 | return torch.from_numpy(X_adv) 36 | 37 | 38 | def FGSM_train_rnd(X, y, model, criterion, fgsm_adversary, epsilon_max=0.3): 39 | """ 40 | FGSM with epsilon sampled from a truncated normal distribution. 41 | Returns the perturbed mini batch.
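Per-example epsilons are drawn as |N(0, (epsilon_max/2)^2)| samples truncated at two standard deviations (i.e. at epsilon_max); see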
42 | Kurakin et al, ADVERSARIAL MACHINE LEARNING AT SCALE, 2016 43 | """ 44 | 45 | # If adversarial training, need a snapshot of 46 | # the model at each batch to compute grad, so 47 | # as not to mess up with the optimization step 48 | model_cp = copy.deepcopy(model) 49 | for p in model_cp.parameters(): 50 | p.requires_grad = False 51 | model_cp.eval() 52 | 53 | fgsm_adversary.model = model_cp 54 | 55 | # truncated Gaussian 56 | m = X.size()[0] # mini-batch size 57 | mean, std = 0., epsilon_max/2 58 | epsilons = np.abs(truncated_normal(mean, std, m))[:, np.newaxis, \ 59 | np.newaxis, np.newaxis] 60 | 61 | X_adv = fgsm_adversary.perturb(X.numpy(), y, epsilons) 62 | 63 | return torch.from_numpy(X_adv) 64 | 65 | 66 | -------------------------------------------------------------------------------- /models/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from itertools import repeat 4 | import collections.abc 5 | # From PyTorch internals 6 | def _ntuple(n): 7 | def parse(x): 8 | if isinstance(x, collections.abc.Iterable): 9 | return x 10 | return tuple(repeat(x, n)) 11 | return parse 12 | 13 | 14 | to_1tuple = _ntuple(1) 15 | to_2tuple = _ntuple(2) 16 | to_3tuple = _ntuple(3) 17 | to_4tuple = _ntuple(4) 18 | to_ntuple = _ntuple 19 | 20 | 21 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 22 | min_value = min_value or divisor 23 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 24 | # Make sure that round down does not go down by more than 10%. 25 | if new_v < round_limit * v: 26 | new_v += divisor 27 | return new_v 28 | 29 | class PatchEmbed(nn.Module): 30 | """ 2D Image to Patch Embedding 31 | """ 32 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 33 | super().__init__() 34 | img_size = to_2tuple(img_size) 35 | patch_size = to_2tuple(patch_size) 36 | self.img_size = img_size 37 | self.patch_size = patch_size 38 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 39 | self.num_patches = self.grid_size[0] * self.grid_size[1] 40 | self.flatten = flatten 41 | 42 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 44 | 45 | def forward(self, x): 46 | B, C, H, W = x.shape 47 | assert H == self.img_size[0] and W == self.img_size[1], \ 48 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 49 | x = self.proj(x) 50 | if self.flatten: 51 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 52 | x = self.norm(x) 53 | return x -------------------------------------------------------------------------------- /adversarialbox/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | from torch.utils.data import sampler 6 | 7 | 8 | def truncated_normal(mean=0.0, stddev=1.0, m=1): 9 | ''' 10 | The generated values follow a normal distribution with specified 11 | mean and standard deviation, except that values whose magnitude is 12 | more than 2 standard deviations from the mean are dropped and 13 | re-picked. 
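For example (illustrative), truncated_normal(0., 0.15, m=4) returns a length-4 array whose entries all satisfy |x| <= 0.3.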
Returns a vector of length m 14 | ''' 15 | samples = [] 16 | for i in range(m): 17 | while True: 18 | sample = np.random.normal(mean, stddev) 19 | if np.abs(sample) <= 2 * stddev: 20 | break 21 | samples.append(sample) 22 | assert len(samples) == m, "something wrong" 23 | if m == 1: 24 | return samples[0] 25 | else: 26 | return np.array(samples) 27 | 28 | 29 | # --- PyTorch helpers --- 30 | 31 | def to_var(x, requires_grad=False, volatile=False): 32 | """ 33 | Varialbe type that automatically choose cpu or cuda 34 | """ 35 | if torch.cuda.is_available(): 36 | x = x.cuda() 37 | return Variable(x, requires_grad=requires_grad, volatile=volatile) 38 | 39 | 40 | def pred_batch(x, model): 41 | """ 42 | batch prediction helper 43 | """ 44 | y_pred = np.argmax(model(to_var(x)).data.cpu().numpy(), axis=1) 45 | return torch.from_numpy(y_pred) 46 | 47 | 48 | def test(model, loader, blackbox=False, hold_out_size=None): 49 | """ 50 | Check model accuracy on model based on loader (train or test) 51 | """ 52 | model.eval() 53 | 54 | num_correct, num_samples = 0, len(loader.dataset) 55 | 56 | if blackbox: 57 | num_samples -= hold_out_size 58 | 59 | for x, y in loader: 60 | x = x.cuda() 61 | x_var = to_var(x, volatile=True) 62 | scores,attention = model(x_var) 63 | #scores = model(x_var) 64 | _, preds = scores.data.cpu().max(1) 65 | num_correct += (preds == y).sum() 66 | 67 | acc = float(num_correct)/float(num_samples) 68 | print('Got %d/%d correct (%.2f%%) on the clean data' 69 | % (num_correct, num_samples, 100 * acc)) 70 | 71 | return acc 72 | 73 | 74 | def attack_over_test_data(model, adversary, param, loader_test, oracle=None): 75 | """ 76 | Given target model computes accuracy on perturbed data 77 | """ 78 | total_correct = 0 79 | total_samples = len(loader_test.dataset) 80 | 81 | # For black-box 82 | if oracle is not None: 83 | total_samples -= param['hold_out_size'] 84 | 85 | for t, (X, y) in enumerate(loader_test): 86 | y_pred = pred_batch(X, model) 87 | X_adv = adversary.perturb(X.numpy(), y_pred) 88 | X_adv = torch.from_numpy(X_adv) 89 | 90 | if oracle is not None: 91 | y_pred_adv = pred_batch(X_adv, oracle) 92 | else: 93 | y_pred_adv = pred_batch(X_adv, model) 94 | 95 | total_correct += (y_pred_adv.numpy() == y.numpy()).sum() 96 | 97 | acc = total_correct/total_samples 98 | 99 | print('Got %d/%d correct (%.2f%%) on the perturbed data' 100 | % (total_correct, total_samples, 100 * acc)) 101 | 102 | return acc 103 | 104 | 105 | def batch_indices(batch_nb, data_length, batch_size): 106 | """ 107 | This helper function computes a batch start and end index 108 | :param batch_nb: the batch number 109 | :param data_length: the total length of the data being parsed by batches 110 | :param batch_size: the number of inputs in each batch 111 | :return: pair of (start, end) indices 112 | """ 113 | # Batch start and end index 114 | start = int(batch_nb * batch_size) 115 | end = int((batch_nb + 1) * batch_size) 116 | 117 | # When there are not enough inputs left, we reuse some to complete the 118 | # batch 119 | if end > data_length: 120 | shift = end - data_length 121 | start -= shift 122 | end -= shift 123 | 124 | return start, end 125 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class Mlp(nn.Module): 4 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 5 | """ 6 | def __init__(self, in_features, 
hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 7 | super().__init__() 8 | out_features = out_features or in_features 9 | hidden_features = hidden_features or in_features 10 | self.fc1 = nn.Linear(in_features, hidden_features) 11 | self.act = act_layer() 12 | self.fc2 = nn.Linear(hidden_features, out_features) 13 | self.drop = nn.Dropout(drop) 14 | 15 | def forward(self, x): 16 | x = self.fc1(x) 17 | x = self.act(x) 18 | x = self.drop(x) 19 | x = self.fc2(x) 20 | x = self.drop(x) 21 | return x 22 | 23 | 24 | class GluMlp(nn.Module): 25 | """ MLP w/ GLU style gating 26 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 27 | """ 28 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): 29 | super().__init__() 30 | out_features = out_features or in_features 31 | hidden_features = hidden_features or in_features 32 | assert hidden_features % 2 == 0 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.act = act_layer() 35 | self.fc2 = nn.Linear(hidden_features // 2, out_features) 36 | self.drop = nn.Dropout(drop) 37 | 38 | def init_weights(self): 39 | # override init of fc1 w/ gate portion set to weight near zero, bias=1 40 | fc1_mid = self.fc1.bias.shape[0] // 2 41 | nn.init.ones_(self.fc1.bias[fc1_mid:]) 42 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x, gates = x.chunk(2, dim=-1) 47 | x = x * self.act(gates) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class GatedMlp(nn.Module): 55 | """ MLP as used in gMLP 56 | """ 57 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 58 | gate_layer=None, drop=0.): 59 | super().__init__() 60 | out_features = out_features or in_features 61 | hidden_features = hidden_features or in_features 62 | self.fc1 = nn.Linear(in_features, hidden_features) 63 | self.act = act_layer() 64 | if gate_layer is not None: 65 | assert hidden_features % 2 == 0 66 | self.gate = gate_layer(hidden_features) 67 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
68 | else: 69 | self.gate = nn.Identity() 70 | self.fc2 = nn.Linear(hidden_features, out_features) 71 | self.drop = nn.Dropout(drop) 72 | 73 | def forward(self, x): 74 | x = self.fc1(x) 75 | x = self.act(x) 76 | x = self.drop(x) 77 | x = self.gate(x) 78 | x = self.fc2(x) 79 | x = self.drop(x) 80 | return x 81 | 82 | 83 | class ConvMlp(nn.Module): 84 | """ MLP using 1x1 convs that keeps spatial dims 85 | """ 86 | def __init__( 87 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): 88 | super().__init__() 89 | out_features = out_features or in_features 90 | hidden_features = hidden_features or in_features 91 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) 92 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 93 | self.act = act_layer() 94 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) 95 | self.drop = nn.Dropout(drop) 96 | 97 | def forward(self, x): 98 | x = self.fc1(x) 99 | x = self.norm(x) 100 | x = self.act(x) 101 | x = self.drop(x) 102 | x = self.fc2(x) 103 | return x -------------------------------------------------------------------------------- /adversarialbox/attacks.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | from collections import Iterable 4 | from scipy.stats import truncnorm 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from adversarialbox.utils import to_var 10 | 11 | # --- White-box attacks --- 12 | 13 | class FGSMAttack(object): 14 | def __init__(self, model=None, epsilon=None): 15 | """ 16 | One step fast gradient sign method 17 | """ 18 | self.model = model 19 | self.epsilon = epsilon 20 | self.loss_fn = nn.CrossEntropyLoss() 21 | 22 | def perturb(self, X_nat, y, epsilons=None): 23 | """ 24 | Given examples (X_nat, y), returns their adversarial 25 | counterparts with an attack length of epsilon. 26 | """ 27 | # Providing epsilons in batch 28 | if epsilons is not None: 29 | self.epsilon = epsilons 30 | 31 | X = np.copy(X_nat) 32 | 33 | X_var = to_var(torch.from_numpy(X), requires_grad=True) 34 | y_var = to_var(torch.LongTensor(y)) 35 | 36 | scores = self.model(X_var) 37 | loss = self.loss_fn(scores, y_var) 38 | loss.backward() 39 | grad_sign = X_var.grad.data.cpu().sign().numpy() 40 | 41 | X += self.epsilon * grad_sign 42 | X = np.clip(X, 0, 1) 43 | 44 | return X 45 | 46 | 47 | class LinfPGDAttack(object): 48 | def __init__(self, model=None, epsilon=0.3, k=40, a=0.01, 49 | random_start=True): 50 | """ 51 | Attack parameter initialization. The attack performs k steps of 52 | size a, while always staying within epsilon from the initial 53 | point. 54 | https://github.com/MadryLab/mnist_challenge/blob/master/pgd_attack.py 55 | """ 56 | self.model = model 57 | self.epsilon = epsilon 58 | self.k = k 59 | self.a = a 60 | self.rand = random_start 61 | self.loss_fn = nn.CrossEntropyLoss() 62 | 63 | def perturb(self, X_nat, y): 64 | """ 65 | Given examples (X_nat, y), returns adversarial 66 | examples within epsilon of X_nat in l_infinity norm. 
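Illustrative usage, mirroring attack_over_test_data in adversarialbox/utils.py: adversary = LinfPGDAttack(model, epsilon=0.3, k=40, a=0.01); X_adv = torch.from_numpy(adversary.perturb(X.numpy(), y.numpy()))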
67 | """ 68 | if self.rand: 69 | X = X_nat + np.random.uniform(-self.epsilon, self.epsilon, 70 | X_nat.shape).astype('float32') 71 | else: 72 | X = np.copy(X_nat) 73 | 74 | for i in range(self.k): 75 | X_var = to_var(torch.from_numpy(X), requires_grad=True) 76 | y_var = to_var(torch.LongTensor(y)) 77 | 78 | scores = self.model(X_var) 79 | loss = self.loss_fn(scores, y_var) 80 | loss.backward() 81 | grad = X_var.grad.data.cpu().numpy() 82 | 83 | X += self.a * np.sign(grad) 84 | 85 | X = np.clip(X, X_nat - self.epsilon, X_nat + self.epsilon) 86 | X = np.clip(X, 0, 1) # ensure valid pixel range 87 | 88 | return X 89 | 90 | 91 | # --- Black-box attacks --- 92 | 93 | def jacobian(model, x, nb_classes=10): 94 | """ 95 | This function will return a list of PyTorch gradients 96 | """ 97 | list_derivatives = [] 98 | x_var = to_var(torch.from_numpy(x), requires_grad=True) 99 | 100 | # derivatives for each class 101 | for class_ind in range(nb_classes): 102 | score = model(x_var)[:, class_ind] 103 | score.backward() 104 | list_derivatives.append(x_var.grad.data.cpu().numpy()) 105 | x_var.grad.data.zero_() 106 | 107 | return list_derivatives 108 | 109 | 110 | def jacobian_augmentation(model, X_sub_prev, Y_sub, lmbda=0.1): 111 | """ 112 | Create new numpy array for adversary training data 113 | with twice as many components on the first dimension. 114 | """ 115 | X_sub = np.vstack([X_sub_prev, X_sub_prev]) 116 | 117 | # For each input in the previous' substitute training iteration 118 | for ind, x in enumerate(X_sub_prev): 119 | grads = jacobian(model, x) 120 | # Select gradient corresponding to the label predicted by the oracle 121 | grad = grads[Y_sub[ind]] 122 | 123 | # Compute sign matrix 124 | grad_val = np.sign(grad) 125 | 126 | # Create new synthetic point in adversary substitute training set 127 | X_sub[len(X_sub_prev)+ind] = X_sub[ind] + lmbda * grad_val #??? 128 | 129 | # Return augmented training data (needs to be labeled afterwards) 130 | return X_sub 131 | -------------------------------------------------------------------------------- /models/model_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ml_collections 16 | 17 | 18 | def get_testing(): 19 | """Returns a minimal configuration for testing.""" 20 | config = ml_collections.ConfigDict() 21 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 22 | config.hidden_size = 1 23 | config.transformer = ml_collections.ConfigDict() 24 | config.transformer.mlp_dim = 1 25 | config.transformer.num_heads = 1 26 | config.transformer.num_layers = 1 27 | config.transformer.attention_dropout_rate = 0.0 28 | config.transformer.dropout_rate = 0.1 29 | config.classifier = 'token' 30 | config.representation_size = None 31 | return config 32 | 33 | 34 | def get_b16_config(): 35 | """Returns the ViT-B/16 configuration.""" 36 | config = ml_collections.ConfigDict() 37 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 38 | config.hidden_size = 768 39 | config.transformer = ml_collections.ConfigDict() 40 | config.transformer.mlp_dim = 3072 41 | config.transformer.num_heads = 12 42 | config.transformer.num_layers = 12 43 | config.transformer.attention_dropout_rate = 0.0 44 | config.transformer.dropout_rate = 0.1 45 | config.classifier = 'token' 46 | config.representation_size = None 47 | return config 48 | 49 | 50 | def get_r50_b16_config(): 51 | """Returns the Resnet50 + ViT-B/16 configuration.""" 52 | config = get_b16_config() 53 | del config.patches.size 54 | config.patches.grid = (14, 14) 55 | config.resnet = ml_collections.ConfigDict() 56 | config.resnet.num_layers = (3, 4, 9) 57 | config.resnet.width_factor = 1 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | return config 66 | 67 | 68 | def get_l16_config(): 69 | """Returns the ViT-L/16 configuration.""" 70 | config = ml_collections.ConfigDict() 71 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 72 | config.hidden_size = 1024 73 | config.transformer = ml_collections.ConfigDict() 74 | config.transformer.mlp_dim = 4096 75 | config.transformer.num_heads = 16 76 | config.transformer.num_layers = 24 77 | config.transformer.attention_dropout_rate = 0.0 78 | config.transformer.dropout_rate = 0.1 79 | config.classifier = 'token' 80 | config.representation_size = None 81 | return config 82 | 83 | 84 | def get_l32_config(): 85 | """Returns the ViT-L/32 configuration.""" 86 | config = get_l16_config() 87 | config.patches.size = (32, 32) 88 | return config 89 | 90 | 91 | def get_h14_config(): 92 | """Returns the ViT-L/16 configuration.""" 93 | config = ml_collections.ConfigDict() 94 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 95 | config.hidden_size = 1280 96 | config.transformer = ml_collections.ConfigDict() 97 | config.transformer.mlp_dim = 5120 98 | config.transformer.num_heads = 16 99 | config.transformer.num_layers = 32 100 | config.transformer.attention_dropout_rate = 0.0 101 | config.transformer.dropout_rate = 0.1 102 | config.classifier = 'token' 103 | config.representation_size = None 104 | return config 105 | 106 | 107 | def get_T_b16_config(): 108 | """Returns the ViT-B/16 configuration.""" 109 | config = ml_collections.ConfigDict() 110 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 111 | config.hidden_size = 192 112 | config.transformer = ml_collections.ConfigDict() 113 | config.transformer.mlp_dim = 768 114 | config.transformer.num_heads = 3 115 | config.transformer.num_layers = 12 116 | config.transformer.attention_dropout_rate = 0.0 117 | config.transformer.dropout_rate = 
0.1 118 | config.classifier = 'token' 119 | config.representation_size = None 120 | return config 121 | 122 | 123 | def get_S_b16_config(): 124 | """Returns the ViT-B/16 configuration.""" 125 | config = ml_collections.ConfigDict() 126 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 127 | config.hidden_size = 384 128 | config.transformer = ml_collections.ConfigDict() 129 | config.transformer.mlp_dim = 1536 130 | config.transformer.num_heads = 6 131 | config.transformer.num_layers = 12 132 | config.transformer.attention_dropout_rate = 0.0 133 | config.transformer.dropout_rate = 0.1 134 | config.classifier = 'token' 135 | config.representation_size = None 136 | return config 137 | 138 | 139 | -------------------------------------------------------------------------------- /models/drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def drop_block_2d( 7 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 8 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 9 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 10 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 11 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 12 | """ 13 | B, C, H, W = x.shape 14 | total_size = W * H 15 | clipped_block_size = min(block_size, min(W, H)) 16 | # seed_drop_rate, the gamma parameter 17 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 18 | (W - block_size + 1) * (H - block_size + 1)) 19 | 20 | # Forces the block to be inside the feature map. 21 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 22 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 23 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 24 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 25 | 26 | if batchwise: 27 | # one mask for whole batch, quite a bit faster 28 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 29 | else: 30 | uniform_noise = torch.rand_like(x) 31 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 32 | block_mask = -F.max_pool2d( 33 | -block_mask, 34 | kernel_size=clipped_block_size, # block_size, 35 | stride=1, 36 | padding=clipped_block_size // 2) 37 | 38 | if with_noise: 39 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 40 | if inplace: 41 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 42 | else: 43 | x = x * block_mask + normal_noise * (1 - block_mask) 44 | else: 45 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 46 | if inplace: 47 | x.mul_(block_mask * normalize_scale) 48 | else: 49 | x = x * block_mask * normalize_scale 50 | return x 51 | 52 | 53 | def drop_block_fast_2d( 54 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 55 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 56 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 57 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 58 | block mask at edges. 
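Each position becomes a drop seed with probability gamma; seeds are dilated into clipped_block_size x clipped_block_size squares via max pooling, and (when with_noise is False) the surviving activations are rescaled by the keep mask's numel/sum so the expected activation magnitude is roughly preserved.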
59 | """ 60 | B, C, H, W = x.shape 61 | total_size = W * H 62 | clipped_block_size = min(block_size, min(W, H)) 63 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 64 | (W - block_size + 1) * (H - block_size + 1)) 65 | 66 | if batchwise: 67 | # one mask for whole batch, quite a bit faster 68 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma 69 | else: 70 | # mask per batch element 71 | block_mask = torch.rand_like(x) < gamma 72 | block_mask = F.max_pool2d( 73 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 74 | 75 | if with_noise: 76 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 77 | if inplace: 78 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 79 | else: 80 | x = x * (1. - block_mask) + normal_noise * block_mask 81 | else: 82 | block_mask = 1 - block_mask 83 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) 84 | if inplace: 85 | x.mul_(block_mask * normalize_scale) 86 | else: 87 | x = x * block_mask * normalize_scale 88 | return x 89 | 90 | 91 | class DropBlock2d(nn.Module): 92 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 93 | """ 94 | def __init__(self, 95 | drop_prob=0.1, 96 | block_size=7, 97 | gamma_scale=1.0, 98 | with_noise=False, 99 | inplace=False, 100 | batchwise=False, 101 | fast=True): 102 | super(DropBlock2d, self).__init__() 103 | self.drop_prob = drop_prob 104 | self.gamma_scale = gamma_scale 105 | self.block_size = block_size 106 | self.with_noise = with_noise 107 | self.inplace = inplace 108 | self.batchwise = batchwise 109 | self.fast = fast # FIXME finish comparisons of fast vs not 110 | 111 | def forward(self, x): 112 | if not self.training or not self.drop_prob: 113 | return x 114 | if self.fast: 115 | return drop_block_fast_2d( 116 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 117 | else: 118 | return drop_block_2d( 119 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 120 | 121 | 122 | def drop_path(x, drop_prob: float = 0., training: bool = False): 123 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 124 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 125 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 126 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 127 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 128 | 'survival rate' as the argument. 129 | """ 130 | if drop_prob == 0. or not training: 131 | return x 132 | keep_prob = 1 - drop_prob 133 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 134 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 135 | random_tensor.floor_() # binarize 136 | output = x.div(keep_prob) * random_tensor 137 | return output 138 | 139 | 140 | class DropPath(nn.Module): 141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
142 | """ 143 | def __init__(self, drop_prob=None): 144 | super(DropPath, self).__init__() 145 | self.drop_prob = drop_prob 146 | 147 | def forward(self, x): 148 | return drop_path(x, self.drop_prob, self.training) -------------------------------------------------------------------------------- /models/DeiT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from .vision_transformer import VisionTransformer, _cfg 6 | from timm.models.registry import register_model 7 | from timm.models.layers import trunc_normal_ 8 | 9 | 10 | __all__ = [ 11 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 12 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 13 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 14 | 'deit_base_distilled_patch16_384', 15 | ] 16 | 17 | 18 | class DistilledVisionTransformer(VisionTransformer): 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 22 | num_patches = self.patch_embed.num_patches 23 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 24 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 25 | 26 | trunc_normal_(self.dist_token, std=.02) 27 | trunc_normal_(self.pos_embed, std=.02) 28 | self.head_dist.apply(self._init_weights) 29 | 30 | def forward_features(self, x): 31 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 32 | # with slight modifications to add the dist_token 33 | B = x.shape[0] 34 | x = self.patch_embed(x) 35 | 36 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 37 | dist_token = self.dist_token.expand(B, -1, -1) 38 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 39 | 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | x = self.norm(x) 47 | return x[:, 0], x[:, 1] 48 | 49 | def forward(self, x): 50 | x, x_dist = self.forward_features(x) 51 | x = self.head(x) 52 | x_dist = self.head_dist(x_dist) 53 | if self.training: 54 | return x, x_dist 55 | else: 56 | # during inference, return the average of both classifier predictions 57 | return (x + x_dist) / 2 58 | 59 | 60 | @register_model 61 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 62 | model = VisionTransformer( 63 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | model.default_cfg = _cfg() 66 | if pretrained: 67 | checkpoint = torch.hub.load_state_dict_from_url( 68 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 69 | map_location="cpu", check_hash=True 70 | ) 71 | model.load_state_dict(checkpoint["model"]) 72 | return model 73 | 74 | 75 | @register_model 76 | def deit_small_patch16_224(pretrained=False, **kwargs): 77 | model = VisionTransformer( 78 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 79 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 80 | model.default_cfg = _cfg() 81 | if pretrained: 82 | checkpoint = torch.hub.load_state_dict_from_url( 83 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 84 | map_location="cpu", check_hash=True 85 | ) 86 | 
model.load_state_dict(checkpoint["model"]) 87 | return model 88 | 89 | 90 | @register_model 91 | def deit_base_patch16_224(pretrained=False, **kwargs): 92 | model = VisionTransformer( 93 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 94 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 95 | model.default_cfg = _cfg() 96 | if pretrained: 97 | checkpoint = torch.hub.load_state_dict_from_url( 98 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 99 | map_location="cpu", check_hash=True 100 | ) 101 | model.load_state_dict(checkpoint["model"]) 102 | return model 103 | 104 | 105 | @register_model 106 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 107 | model = DistilledVisionTransformer( 108 | patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 109 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 110 | model.default_cfg = _cfg() 111 | if pretrained: 112 | checkpoint = torch.hub.load_state_dict_from_url( 113 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 114 | map_location="cpu", check_hash=True 115 | ) 116 | model.load_state_dict(checkpoint["model"]) 117 | return model 118 | 119 | 120 | @register_model 121 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 122 | model = DistilledVisionTransformer( 123 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 125 | model.default_cfg = _cfg() 126 | if pretrained: 127 | checkpoint = torch.hub.load_state_dict_from_url( 128 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 129 | map_location="cpu", check_hash=True 130 | ) 131 | model.load_state_dict(checkpoint["model"]) 132 | return model 133 | 134 | 135 | @register_model 136 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 137 | model = DistilledVisionTransformer( 138 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 139 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 140 | model.default_cfg = _cfg() 141 | if pretrained: 142 | checkpoint = torch.hub.load_state_dict_from_url( 143 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", 144 | map_location="cpu", check_hash=True 145 | ) 146 | model.load_state_dict(checkpoint["model"]) 147 | return model 148 | 149 | 150 | @register_model 151 | def deit_base_patch16_384(pretrained=False, **kwargs): 152 | model = VisionTransformer( 153 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 154 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 155 | model.default_cfg = _cfg() 156 | if pretrained: 157 | checkpoint = torch.hub.load_state_dict_from_url( 158 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", 159 | map_location="cpu", check_hash=True 160 | ) 161 | model.load_state_dict(checkpoint["model"]) 162 | return model 163 | 164 | 165 | @register_model 166 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 167 | model = DistilledVisionTransformer( 168 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 169 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 170 | model.default_cfg = _cfg() 171 | if pretrained: 172 | checkpoint = torch.hub.load_state_dict_from_url( 173 | 
url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 174 | map_location="cpu", check_hash=True 175 | ) 176 | model.load_state_dict(checkpoint["model"]) 177 | return model 178 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import torch.nn.functional as F 5 | from torchvision import datasets, transforms 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | import numpy as np 8 | from os import path 9 | import os 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | from datasets import Array3D, ClassLabel, Features, load_dataset 14 | from numpy import inf 15 | from torchinfo import summary 16 | from tqdm import tqdm 17 | 18 | 19 | #mu = [0.5, 0.5, 0.5] 20 | #std = [0.5, 0.5, 0.5] 21 | #imagenet 22 | mu = [0.485, 0.456, 0.406] 23 | std = [0.229, 0.224, 0.225] 24 | 25 | def clamp(X, lower_limit, upper_limit): 26 | return torch.max(torch.min(X, upper_limit), lower_limit) 27 | 28 | def get_loaders(args): 29 | args.mu = mu 30 | args.std = std 31 | valdir = path.join(args.data_dir, 'val') 32 | val_dataset = datasets.ImageFolder(valdir, 33 | transforms.Compose([transforms.Resize(args.img_size), 34 | transforms.CenterCrop(args.crop_size), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=args.mu, std=args.std) 37 | ])) 38 | val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, np.arange(384)), batch_size=args.batch_size, shuffle=False, 39 | num_workers=args.workers, pin_memory=True) 40 | return val_loader 41 | 42 | def get_loaders_test(args): 43 | args.mu = mu 44 | args.std = std 45 | valdir = path.join(args.data_dir, 'val') 46 | val_dataset = datasets.ImageFolder(valdir, 47 | transforms.Compose([transforms.Resize(args.img_size), 48 | transforms.CenterCrop(args.crop_size), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=args.mu, std=args.std) 51 | ])) 52 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16,shuffle =True, 53 | num_workers=args.workers, pin_memory=True) 54 | return val_loader 55 | 56 | def get_loaders_testonebatch(args): 57 | args.mu = mu 58 | args.std = std 59 | valdir = path.join(args.data_dir, 'val') 60 | val_dataset = datasets.ImageFolder(valdir, 61 | transforms.Compose([transforms.Resize(args.img_size), 62 | transforms.CenterCrop(args.crop_size), 63 | transforms.ToTensor(), 64 | transforms.Normalize(mean=args.mu, std=args.std) 65 | ])) 66 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 67 | num_workers=args.workers, pin_memory=True) 68 | return val_loader 69 | 70 | 71 | 72 | def get_loaders_test_small(args): 73 | args.mu = mu 74 | args.std = std 75 | valdir = path.join(args.data_dir, 'val') 76 | val_dataset = datasets.ImageFolder(valdir, 77 | transforms.Compose([transforms.Resize(args.img_size), 78 | transforms.CenterCrop(args.crop_size), 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=args.mu, std=args.std) 81 | ])) 82 | val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, np.arange(1024)) , batch_size=args.batch_size, shuffle=True, 83 | num_workers=args.workers, pin_memory=True) 84 | return val_loader 85 | 86 | def visualize_loss(loss_list): 87 | plt.figure() 88 | plt.plot(loss_list) 89 | plt.savefig('output/loss/loss{:.4f}_{}_{}_{}_{}_{}.png'.format(loss_list[-1], 
args.network, args.attack_learning_rate, 90 | args.train_attack_iters, args.step_size, args.gamma)) 91 | plt.close() 92 | 93 | def visualize_attention_map(atten1, atten2, image1, image2, original_result, after_attack_result, max_patch_index): 94 | atten1 = [x.mean(dim=1).cpu() for x in atten1] 95 | atten2 = [x.mean(dim=1).cpu() for x in atten2] 96 | image1 = image1.cpu() 97 | image2 = image2.cpu() 98 | original_result = original_result.cpu() 99 | after_attack_result = after_attack_result.cpu() 100 | 101 | if image1.size(0) > 4: 102 | atten1 = [x[:4] for x in atten1] 103 | atten2 = [x[:4] for x in atten2] 104 | image1 = image1[:4] 105 | image2 = image2[:4] 106 | original_result = original_result[:4] 107 | after_attack_result = after_attack_result[:4] 108 | pic_num = image1.size(0) 109 | if 'LeViT' in args.network: 110 | patch_num = atten1[0].size(-1) 111 | else: 112 | patch_num = atten1[0].size(-1) - 1 113 | patch_per_line = int(patch_num ** 0.5) 114 | patch_size = int(image1.size(-1) / patch_per_line) 115 | to_PIL = transforms.ToPILImage() 116 | 117 | for i in range(pic_num): 118 | if not path.exists('output/{}'.format(i)): 119 | os.mkdir('output/{}'.format(i)) 120 | 121 | original_img = to_PIL(image1[i].squeeze()) 122 | after_attack_img = to_PIL(image2[i].squeeze()) 123 | original_atten = [x[i] for x in atten1] # [197x197] 124 | after_attack_atten = [x[i] for x in atten2] 125 | 126 | with open('output/{}/atten.txt'.format(i), 'w') as f: 127 | print("Base model result: {}\tAttack model result:{}".format(original_result[i], after_attack_result[i])) 128 | print("Base model result: {}\tAttack model result:{}".format(original_result[i], 129 | after_attack_result[i]), file=f) 130 | for j in [4]: # one layer 131 | # for j in range(len(original_atten)): # each block 132 | print("Processing Image:{}\tLayer:{}".format(i, j)) 133 | original_block_layer = original_atten[j] 134 | after_attack_atten_layer = after_attack_atten[j] 135 | vmin = min(original_block_layer.min(), after_attack_atten_layer.min()) 136 | vmax = max(original_block_layer.max(), after_attack_atten_layer.max()) 137 | plt.figure(figsize=(70, 30)) 138 | plt.subplot(1, 2, 1) 139 | plt.title('Original') 140 | sns.heatmap(original_block_layer.data, annot=False, vmin=vmin, vmax=vmax) 141 | plt.subplot(1, 2, 2) 142 | plt.title('Attack patch {}'.format(max_patch_index[i] + 1)) 143 | sns.heatmap(after_attack_atten_layer.data, annot=False, vmin=vmin, vmax=vmax) 144 | plt.savefig('output/{}/atten_layer{}.png'.format(i, j)) 145 | plt.close() 146 | 147 | original_block_layer = original_block_layer.mean(dim=0) 148 | after_attack_atten_layer = after_attack_atten_layer.mean(dim=0) 149 | print('layer_{}'.format(j), file=f) 150 | print(original_block_layer, file=f) 151 | print(' ', file=f) 152 | print(after_attack_atten_layer, file=f) 153 | print(' ', file=f) 154 | print(after_attack_atten_layer - original_block_layer, file=f) 155 | 156 | plt.figure() 157 | plt.subplot(2, 2, 1) 158 | plt.imshow(original_img) 159 | plt.subplot(2, 2, 2) 160 | plt.imshow(after_attack_img) 161 | 162 | if 'DeiT' in args.network: 163 | original_block_layer = original_block_layer[1:] 164 | after_attack_atten_layer = after_attack_atten_layer[1:] 165 | plt.subplot(2, 2, 3) 166 | sns.heatmap(original_block_layer.view(patch_per_line, patch_per_line).data, annot=False) 167 | plt.subplot(2, 2, 4) 168 | sns.heatmap(after_attack_atten_layer.view(patch_per_line, patch_per_line).data, annot=False) 169 | plt.savefig('output/{}/atten_layer{}_img.png'.format(i, j)) 170 | plt.close() 171 
| 172 | 173 | 174 | 175 | 176 | # filter = torch.ones([1, 3, patch_size, patch_size]) 177 | # atten = F.conv_transpose2d(atten, filter, stride=patch_size) 178 | # add_atten = torch.mul(atten, image) 179 | 180 | 181 | 182 | ##@Parameter atten_grad, ce_grad: should be 2D tensor with shape [batch_size, -1] 183 | 184 | def PCGrad(atten_grad, ce_grad, sim, shape): 185 | pcgrad = atten_grad[sim < 0] 186 | temp_ce_grad = ce_grad[sim < 0] 187 | dot_prod = torch.mul(pcgrad, temp_ce_grad).sum(dim=-1) 188 | dot_prod = dot_prod / torch.norm(temp_ce_grad, dim=-1) 189 | pcgrad = pcgrad - dot_prod.view(-1, 1) * temp_ce_grad 190 | atten_grad[sim < 0] = pcgrad 191 | atten_grad = atten_grad.view(shape) 192 | return atten_grad 193 | 194 | 195 | 196 | #random shift several patches within the range 197 | 198 | def shift_image(image, range, mu, std, patch_size=16): 199 | batch_size, channel, h, w = image.shape 200 | h_range, w_range = range 201 | new_h = h + 2 * h_range * patch_size 202 | new_w = w + 2 * w_range * patch_size 203 | new_image = torch.zeros([batch_size, channel, new_h, new_w]).cuda() 204 | new_image = (new_image - mu) / std 205 | shift_h = np.random.randint(-h_range, h_range+1) 206 | shift_w = np.random.randint(-w_range, w_range+1) 207 | # shift_h = np.random.randint(-1, 2) 208 | # shift_w = 0 209 | new_image[:, :, h_range*patch_size : h+h_range*patch_size, w_range*patch_size : w + w_range*patch_size] = image.detach() 210 | h_start = (h_range + shift_h) * patch_size 211 | w_start = (w_range + shift_w) * patch_size 212 | new_image = new_image[:, :, h_start : h_start+h, w_start : w_start+w] 213 | return new_image 214 | 215 | 216 | 217 | class my_logger: 218 | def __init__(self, args): 219 | name = "{}_{}_{}_{}_{}.log".format(args.name, args.network, args.dataset, args.train_attack_iters, 220 | args.attack_learning_rate) 221 | args.name = name 222 | self.name = path.join(args.log_dir, name) 223 | with open(self.name, 'w') as F: 224 | print('\n'.join(['%s:%s' % item for item in args.__dict__.items() if item[0][0] != '_']), file=F) 225 | print('\n', file=F) 226 | 227 | def info(self, content): 228 | with open(self.name, 'a') as F: 229 | print(content) 230 | print(content, file=F) 231 | 232 | 233 | class my_meter: 234 | def __init__(self): 235 | self.meter_list = {} 236 | 237 | def add_loss_acc(self, model_name, loss_dic: dict, correct_num, batch_size): 238 | if model_name not in self.meter_list.keys(): 239 | self.meter_list[model_name] = self.model_meter() 240 | sub_meter = self.meter_list[model_name] 241 | sub_meter.add_loss_acc(loss_dic, correct_num, batch_size) 242 | 243 | def clean_meter(self): 244 | for key in self.meter_list.keys(): 245 | self.meter_list[key].clean_meter() 246 | 247 | def get_loss_acc_msg(self): 248 | msg = [] 249 | for key in self.meter_list.keys(): 250 | sub_meter = self.meter_list[key] 251 | sub_loss_bag = sub_meter.get_loss() 252 | loss_msg = ["{}: {:.4f}({:.4f})".format(x, sub_meter.last_loss[x], sub_loss_bag[x]) 253 | for x in sub_loss_bag.keys()] 254 | loss_msg = " ".join(loss_msg) 255 | msg.append("model:{} Loss:{} Acc:{:.4f}({:.4f})".format( 256 | key, loss_msg, sub_meter.last_acc, sub_meter.get_acc())) 257 | msg = "\n".join(msg) 258 | return msg 259 | 260 | class model_meter: 261 | def __init__(self): 262 | self.loss_bag = {} 263 | self.acc = 0. 264 | self.count = 0 265 | self.last_loss = {} 266 | self.last_acc = 0. 
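# last_loss / last_acc hold the most recent batch's values; loss_bag, acc and count accumulate batch-size-weighted totals that get_loss() / get_acc() average over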
267 | 268 | def add_loss_acc(self, loss_dic: dict, correct_num, batch_size): 269 | for loss_name in loss_dic.keys(): 270 | if loss_name not in self.loss_bag.keys(): 271 | self.loss_bag[loss_name] = 0. 272 | self.loss_bag[loss_name] += loss_dic[loss_name] * batch_size 273 | self.last_loss = loss_dic 274 | self.last_acc = correct_num / batch_size 275 | self.acc += correct_num 276 | self.count += batch_size 277 | 278 | def get_loss(self): 279 | return {x: self.loss_bag[x] / self.count for x in self.loss_bag.keys()} 280 | 281 | def get_acc(self): 282 | return self.acc / self.count 283 | 284 | def clean_meter(self): 285 | self.__init__() 286 | -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .drop import DropPath 4 | from .patch_embed import PatchEmbed 5 | from .mlp import Mlp 6 | import math 7 | import warnings 8 | from torch.nn.init import _calculate_fan_in_and_fan_out 9 | 10 | 11 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 12 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 13 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 14 | def norm_cdf(x): 15 | # Computes standard normal cumulative distribution function 16 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 17 | 18 | if (mean < a - 2 * std) or (mean > b + 2 * std): 19 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 20 | "The distribution of values may be incorrect.", 21 | stacklevel=2) 22 | 23 | with torch.no_grad(): 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | l = norm_cdf((a - mean) / std) 28 | u = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [l, u], then translate to 31 | # [2l-1, 2u-1]. 32 | tensor.uniform_(2 * l - 1, 2 * u - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 47 | # type: (Tensor, float, float, float, float) -> Tensor 48 | r"""Fills the input Tensor with values drawn from a truncated 49 | normal distribution. The values are effectively drawn from the 50 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 51 | with values outside :math:`[a, b]` redrawn until they are within 52 | the bounds. The method used for generating the random values works 53 | best when :math:`a \leq \text{mean} \leq b`. 
54 | Args: 55 | tensor: an n-dimensional `torch.Tensor` 56 | mean: the mean of the normal distribution 57 | std: the standard deviation of the normal distribution 58 | a: the minimum cutoff value 59 | b: the maximum cutoff value 60 | Examples: 61 | >>> w = torch.empty(3, 5) 62 | >>> nn.init.trunc_normal_(w) 63 | """ 64 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 65 | 66 | def _cfg(url='', **kwargs): 67 | return { 68 | 'url': url, 69 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 70 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 71 | 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 72 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 73 | **kwargs 74 | } 75 | 76 | 77 | class Attention(nn.Module): 78 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 79 | super().__init__() 80 | self.num_heads = num_heads 81 | head_dim = dim // num_heads 82 | self.scale = head_dim ** -0.5 83 | 84 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 85 | self.attn_drop = nn.Dropout(attn_drop) 86 | self.proj = nn.Linear(dim, dim) 87 | self.proj_drop = nn.Dropout(proj_drop) 88 | 89 | def forward(self, x): 90 | B, N, C = x.shape 91 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 92 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 93 | 94 | attn = (q @ k.transpose(-2, -1)) * self.scale 95 | attn = attn.softmax(dim=-1) 96 | # here is the attn output 97 | drop_attn = self.attn_drop(attn) 98 | 99 | x = (drop_attn @ v).transpose(1, 2).reshape(B, N, C) 100 | x = self.proj(x) 101 | x = self.proj_drop(x) 102 | return x, attn 103 | 104 | 105 | class Block(nn.Module): 106 | 107 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 108 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 109 | super().__init__() 110 | self.norm1 = norm_layer(dim) 111 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 113 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 114 | self.norm2 = norm_layer(dim) 115 | mlp_hidden_dim = int(dim * mlp_ratio) 116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 117 | 118 | def forward(self, x): 119 | x, attn_list = x 120 | attn_temp, attn = self.attn(self.norm1(x)) 121 | attn_list.append(attn.clone()) 122 | x = x + self.drop_path(attn_temp) 123 | x = x + self.drop_path(self.mlp(self.norm2(x))) 124 | return [x, attn_list] 125 | 126 | 127 | class VisionTransformer(nn.Module): 128 | """ Vision Transformer 129 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 130 | - https://arxiv.org/abs/2010.11929 131 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 132 | - https://arxiv.org/abs/2012.12877 133 | """ 134 | 135 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 136 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 137 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 138 | act_layer=None, weight_init=''): 139 | """ 140 | Args: 141 | img_size (int, tuple): input image size 142 | patch_size (int, tuple): patch size 143 | in_chans (int): number of input channels 144 | num_classes (int): number of classes for classification head 145 | embed_dim (int): embedding dimension 146 | depth (int): depth of transformer 147 | num_heads (int): number of attention heads 148 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 149 | qkv_bias (bool): enable bias for qkv if True 150 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 151 | distilled (bool): model includes a distillation token and head as in DeiT models 152 | drop_rate (float): dropout rate 153 | attn_drop_rate (float): attention dropout rate 154 | drop_path_rate (float): stochastic depth rate 155 | embed_layer (nn.Module): patch embedding layer 156 | norm_layer: (nn.Module): normalization layer 157 | weight_init: (str): weight init scheme 158 | """ 159 | super().__init__() 160 | self.num_classes = num_classes 161 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 162 | self.num_tokens = 2 if distilled else 1 163 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 164 | act_layer = act_layer or nn.GELU 165 | 166 | self.patch_embed = embed_layer( 167 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 168 | num_patches = self.patch_embed.num_patches 169 | 170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 171 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 172 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 173 | self.pos_drop = nn.Dropout(p=drop_rate) 174 | 175 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 176 | self.blocks = nn.Sequential(*[ 177 | Block( 178 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 179 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 180 | for i in range(depth)]) 181 | self.norm = norm_layer(embed_dim) 182 | 183 | # Representation layer 184 | if representation_size and not distilled: 185 | self.num_features = representation_size 186 | self.pre_logits = nn.Sequential(OrderedDict([ 
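# pre-logits head from the original ViT: a Linear + Tanh projection to
# representation_size, used only when representation_size is set and the
# model is not distilled; otherwise pre_logits falls back to nn.Identity().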
187 | ('fc', nn.Linear(embed_dim, representation_size)), 188 | ('act', nn.Tanh()) 189 | ])) 190 | else: 191 | self.pre_logits = nn.Identity() 192 | 193 | # Classifier head(s) 194 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 195 | self.head_dist = None 196 | if distilled: 197 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 198 | 199 | self.init_weights(weight_init) 200 | 201 | def init_weights(self, mode=''): 202 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 203 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 204 | trunc_normal_(self.pos_embed, std=.02) 205 | if self.dist_token is not None: 206 | trunc_normal_(self.dist_token, std=.02) 207 | if mode.startswith('jax'): 208 | # leave cls token as zeros to match jax impl 209 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 210 | else: 211 | trunc_normal_(self.cls_token, std=.02) 212 | self.apply(_init_vit_weights) 213 | 214 | def _init_weights(self, m): 215 | # this fn left here for compat with downstream users 216 | _init_vit_weights(m) 217 | 218 | @torch.jit.ignore() 219 | def load_pretrained(self, checkpoint_path, prefix=''): 220 | _load_weights(self, checkpoint_path, prefix) 221 | 222 | @torch.jit.ignore 223 | def no_weight_decay(self): 224 | return {'pos_embed', 'cls_token', 'dist_token'} 225 | 226 | def get_classifier(self): 227 | if self.dist_token is None: 228 | return self.head 229 | else: 230 | return self.head, self.head_dist 231 | 232 | def reset_classifier(self, num_classes, global_pool=''): 233 | self.num_classes = num_classes 234 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 235 | if self.num_tokens == 2: 236 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 237 | 238 | def forward_features(self, x): 239 | x = self.patch_embed(x) 240 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 241 | if self.dist_token is None: 242 | x = torch.cat((cls_token, x), dim=1) 243 | else: 244 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 245 | x = self.pos_drop(x + self.pos_embed) 246 | x = [x, []] 247 | x = self.blocks(x) 248 | x, attn_list = x 249 | x = self.norm(x) 250 | if self.dist_token is None: 251 | return self.pre_logits(x[:, 0]), attn_list 252 | else: 253 | return x[:, 0], x[:, 1] 254 | 255 | def forward(self, x): 256 | x, attn_list = self.forward_features(x) 257 | if self.head_dist is not None: 258 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 259 | if self.training and not torch.jit.is_scripting(): 260 | # during inference, return the average of both classifier predictions 261 | return x, x_dist 262 | else: 263 | return (x + x_dist) / 2 264 | else: 265 | x = self.head(x) 266 | return x, attn_list 267 | 268 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 269 | """ ViT weight initialization 270 | * When called without n, head_bias, jax_impl args it will behave exactly the same 271 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
272 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 273 | """ 274 | if isinstance(module, nn.Linear): 275 | if name.startswith('head'): 276 | nn.init.zeros_(module.weight) 277 | nn.init.constant_(module.bias, head_bias) 278 | elif name.startswith('pre_logits'): 279 | lecun_normal_(module.weight) 280 | nn.init.zeros_(module.bias) 281 | else: 282 | if jax_impl: 283 | nn.init.xavier_uniform_(module.weight) 284 | if module.bias is not None: 285 | if 'mlp' in name: 286 | nn.init.normal_(module.bias, std=1e-6) 287 | else: 288 | nn.init.zeros_(module.bias) 289 | else: 290 | trunc_normal_(module.weight, std=.02) 291 | if module.bias is not None: 292 | nn.init.zeros_(module.bias) 293 | elif jax_impl and isinstance(module, nn.Conv2d): 294 | # NOTE conv was left to pytorch default in my original init 295 | lecun_normal_(module.weight) 296 | if module.bias is not None: 297 | nn.init.zeros_(module.bias) 298 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 299 | nn.init.zeros_(module.bias) 300 | nn.init.ones_(module.weight) 301 | -------------------------------------------------------------------------------- /models/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | 20 | import models.model_configs as configs 21 | 22 | from .modeling_resnet import ResNetV2 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 29 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 30 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 31 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 32 | FC_0 = "MlpBlock_3/Dense_0" 33 | FC_1 = "MlpBlock_3/Dense_1" 34 | ATTENTION_NORM = "LayerNorm_0" 35 | MLP_NORM = "LayerNorm_2" 36 | 37 | 38 | def np2th(weights, conv=False): 39 | """Possibly convert HWIO to OIHW.""" 40 | if conv: 41 | weights = weights.transpose([3, 2, 0, 1]) 42 | return torch.from_numpy(weights) 43 | 44 | 45 | def swish(x): 46 | return x * torch.sigmoid(x) 47 | 48 | 49 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, config, vis): 54 | super(Attention, self).__init__() 55 | self.vis = vis 56 | self.num_attention_heads = config.transformer["num_heads"] 57 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 58 | self.all_head_size = self.num_attention_heads * self.attention_head_size 59 | 60 | self.query = Linear(config.hidden_size, self.all_head_size) 61 | self.key = Linear(config.hidden_size, self.all_head_size) 62 | self.value = Linear(config.hidden_size, self.all_head_size) 63 | 64 | self.out = Linear(config.hidden_size, config.hidden_size) 65 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 66 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 67 | 68 | self.softmax = Softmax(dim=-1) 69 | 70 | def transpose_for_scores(self, x): 71 | new_x_shape = x.size()[:-1] + 
(self.num_attention_heads, self.attention_head_size) 72 | x = x.view(*new_x_shape) 73 | return x.permute(0, 2, 1, 3) 74 | 75 | def forward(self, hidden_states): 76 | mixed_query_layer = self.query(hidden_states) 77 | mixed_key_layer = self.key(hidden_states) 78 | mixed_value_layer = self.value(hidden_states) 79 | 80 | query_layer = self.transpose_for_scores(mixed_query_layer) 81 | key_layer = self.transpose_for_scores(mixed_key_layer) 82 | value_layer = self.transpose_for_scores(mixed_value_layer) 83 | 84 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 85 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 86 | attention_probs = self.softmax(attention_scores) 87 | weights = attention_probs if self.vis else None 88 | attention_probs = self.attn_dropout(attention_probs) 89 | 90 | context_layer = torch.matmul(attention_probs, value_layer) 91 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 92 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 93 | context_layer = context_layer.view(*new_context_layer_shape) 94 | attention_output = self.out(context_layer) 95 | attention_output = self.proj_dropout(attention_output) 96 | return attention_output, weights 97 | 98 | 99 | class Mlp(nn.Module): 100 | def __init__(self, config): 101 | super(Mlp, self).__init__() 102 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 103 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 104 | self.act_fn = ACT2FN["gelu"] 105 | self.dropout = Dropout(config.transformer["dropout_rate"]) 106 | 107 | self._init_weights() 108 | 109 | def _init_weights(self): 110 | nn.init.xavier_uniform_(self.fc1.weight) 111 | nn.init.xavier_uniform_(self.fc2.weight) 112 | nn.init.normal_(self.fc1.bias, std=1e-6) 113 | nn.init.normal_(self.fc2.bias, std=1e-6) 114 | 115 | def forward(self, x): 116 | x = self.fc1(x) 117 | x = self.act_fn(x) 118 | x = self.dropout(x) 119 | x = self.fc2(x) 120 | x = self.dropout(x) 121 | return x 122 | 123 | 124 | class Embeddings(nn.Module): 125 | """Construct the embeddings from patch, position embeddings. 
126 | """ 127 | def __init__(self, config, img_size, in_channels=3): 128 | super(Embeddings, self).__init__() 129 | self.hybrid = None 130 | img_size = _pair(img_size) 131 | 132 | if config.patches.get("grid") is not None: 133 | grid_size = config.patches["grid"] 134 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 135 | n_patches = (img_size[0] // 16) * (img_size[1] // 16) 136 | self.hybrid = True 137 | else: 138 | patch_size = _pair(config.patches["size"]) 139 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 140 | self.hybrid = False 141 | 142 | if self.hybrid: 143 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, 144 | width_factor=config.resnet.width_factor) 145 | in_channels = self.hybrid_model.width * 16 146 | self.patch_embeddings = Conv2d(in_channels=in_channels, 147 | out_channels=config.hidden_size, 148 | kernel_size=patch_size, 149 | stride=patch_size) 150 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size)) 151 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 152 | 153 | self.dropout = Dropout(config.transformer["dropout_rate"]) 154 | 155 | def forward(self, x): 156 | B = x.shape[0] 157 | cls_tokens = self.cls_token.expand(B, -1, -1) 158 | 159 | if self.hybrid: 160 | x = self.hybrid_model(x) 161 | x = self.patch_embeddings(x) 162 | x = x.flatten(2) 163 | x = x.transpose(-1, -2) 164 | x = torch.cat((cls_tokens, x), dim=1) 165 | 166 | embeddings = x + self.position_embeddings 167 | embeddings = self.dropout(embeddings) 168 | return embeddings 169 | 170 | 171 | class Block(nn.Module): 172 | def __init__(self, config, vis): 173 | super(Block, self).__init__() 174 | self.hidden_size = config.hidden_size 175 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 176 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 177 | self.ffn = Mlp(config) 178 | self.attn = Attention(config, vis) 179 | 180 | def forward(self, x): 181 | h = x 182 | x = self.attention_norm(x) 183 | x, weights = self.attn(x) 184 | x = x + h 185 | 186 | h = x 187 | x = self.ffn_norm(x) 188 | x = self.ffn(x) 189 | x = x + h 190 | return x, weights 191 | 192 | def load_from(self, weights, n_block): 193 | ROOT = f"Transformer/encoderblock_{n_block}" 194 | with torch.no_grad(): 195 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() 196 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 197 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() 198 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() 199 | 200 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 201 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 202 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 203 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 204 | 205 | self.attn.query.weight.copy_(query_weight) 206 | self.attn.key.weight.copy_(key_weight) 207 | self.attn.value.weight.copy_(value_weight) 208 | self.attn.out.weight.copy_(out_weight) 209 | self.attn.query.bias.copy_(query_bias) 210 | self.attn.key.bias.copy_(key_bias) 211 | self.attn.value.bias.copy_(value_bias) 212 | self.attn.out.bias.copy_(out_bias) 213 | 214 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, 
"kernel")]).t() 215 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 216 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 217 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 218 | 219 | self.ffn.fc1.weight.copy_(mlp_weight_0) 220 | self.ffn.fc2.weight.copy_(mlp_weight_1) 221 | self.ffn.fc1.bias.copy_(mlp_bias_0) 222 | self.ffn.fc2.bias.copy_(mlp_bias_1) 223 | 224 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 225 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 226 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 227 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 228 | 229 | 230 | class Encoder(nn.Module): 231 | def __init__(self, config, vis): 232 | super(Encoder, self).__init__() 233 | self.vis = vis 234 | self.layer = nn.ModuleList() 235 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 236 | for _ in range(config.transformer["num_layers"]): 237 | layer = Block(config, vis) 238 | self.layer.append(copy.deepcopy(layer)) 239 | 240 | def forward(self, hidden_states): 241 | attn_weights = [] 242 | for layer_block in self.layer: 243 | hidden_states, weights = layer_block(hidden_states) 244 | if self.vis: 245 | attn_weights.append(weights) 246 | encoded = self.encoder_norm(hidden_states) 247 | return encoded, attn_weights 248 | 249 | 250 | class Transformer(nn.Module): 251 | def __init__(self, config, img_size, vis): 252 | super(Transformer, self).__init__() 253 | self.embeddings = Embeddings(config, img_size=img_size) 254 | self.encoder = Encoder(config, vis) 255 | 256 | def forward(self, input_ids): 257 | embedding_output = self.embeddings(input_ids) 258 | encoded, attn_weights = self.encoder(embedding_output) 259 | return encoded, attn_weights 260 | 261 | 262 | class VisionTransformer(nn.Module): 263 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 264 | super(VisionTransformer, self).__init__() 265 | self.num_classes = num_classes 266 | self.zero_head = zero_head 267 | self.classifier = config.classifier 268 | 269 | self.transformer = Transformer(config, img_size, vis) 270 | self.head = Linear(config.hidden_size, num_classes) 271 | 272 | def forward(self, x, labels=None): 273 | x, attn_weights = self.transformer(x) 274 | logits = self.head(x[:, 0]) 275 | 276 | if labels is not None: 277 | loss_fct = CrossEntropyLoss() 278 | loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) 279 | return loss 280 | else: 281 | return logits, attn_weights 282 | 283 | def load_from(self, weights): 284 | with torch.no_grad(): 285 | if self.zero_head: 286 | nn.init.zeros_(self.head.weight) 287 | nn.init.zeros_(self.head.bias) 288 | else: 289 | self.head.weight.copy_(np2th(weights["head/kernel"]).t()) 290 | self.head.bias.copy_(np2th(weights["head/bias"]).t()) 291 | 292 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 293 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 294 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) 295 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 296 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 297 | 298 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 299 | posemb_new = 
self.transformer.embeddings.position_embeddings 300 | if posemb.size() == posemb_new.size(): 301 | self.transformer.embeddings.position_embeddings.copy_(posemb) 302 | else: 303 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) 304 | ntok_new = posemb_new.size(1) 305 | 306 | if self.classifier == "token": 307 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 308 | ntok_new -= 1 309 | else: 310 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 311 | 312 | gs_old = int(np.sqrt(len(posemb_grid))) 313 | gs_new = int(np.sqrt(ntok_new)) 314 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 315 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 316 | 317 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 318 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) 319 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 320 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) 321 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 322 | 323 | for bname, block in self.transformer.encoder.named_children(): 324 | for uname, unit in block.named_children(): 325 | unit.load_from(weights, n_block=uname) 326 | 327 | if self.transformer.embeddings.hybrid: 328 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True)) 329 | gn_weight = np2th(weights["gn_root/scale"]).view(-1) 330 | gn_bias = np2th(weights["gn_root/bias"]).view(-1) 331 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 332 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 333 | 334 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 335 | for uname, unit in block.named_children(): 336 | unit.load_from(weights, n_block=bname, n_unit=uname) 337 | 338 | 339 | CONFIGS = { 340 | 'ViT-B_16': configs.get_b16_config(), 341 | 'ViT-B_32': configs.get_b32_config(), 342 | 'ViT-L_16': configs.get_l16_config(), 343 | 'ViT-L_32': configs.get_l32_config(), 344 | 'ViT-H_14': configs.get_h14_config(), 345 | 'R50-ViT-B_16': configs.get_r50_b16_config(), 346 | 'testing': configs.get_testing(), 347 | 'ViT-T_16': configs.get_T_b16_config(), 348 | 'ViT-S_16': configs.get_S_b16_config() 349 | } 350 | -------------------------------------------------------------------------------- /main_patch_vit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import os 7 | import numpy as np 8 | import time 9 | from timm import create_model 10 | #from pytorch_pretrained_vit import ViT 11 | from models.DeiT import deit_base_patch16_224, deit_tiny_patch16_224, deit_small_patch16_224 12 | from utils import clamp, get_loaders,get_loaders_test,get_loaders_test_small, my_logger, my_meter, PCGrad 13 | 14 | import scipy 15 | import torchvision.datasets as datasets 16 | import torchvision.transforms as transforms 17 | from adversarialbox.attacks import FGSMAttack, LinfPGDAttack 18 | from adversarialbox.train import adv_train, FGSM_train_rnd 19 | from adversarialbox.utils import to_var, pred_batch, test 20 | 21 | from torchvision import datasets, transforms 22 | from torch.utils.data.sampler import SubsetRandomSampler 23 | from adversarialbox.utils import to_var, pred_batch, test, \ 24 | attack_over_test_data 25 | import random 26 | from math import floor 27 | import operator 28 | 29 | 
import copy 30 | import matplotlib.pyplot as plt 31 | from torchvision.utils import save_image 32 | 33 | patch_size = 16 34 | 35 | high=100 36 | wb=768 37 | wb1=768 38 | targets=2 39 | 40 | 41 | ## generating the trigger using fgsm method 42 | class Attack(object): 43 | 44 | def __init__(self, dataloader, criterion=None, gpu_id=0, 45 | epsilon=0.031, attack_method='pgd'): 46 | 47 | if criterion is not None: 48 | self.criterion = nn.MSELoss() 49 | else: 50 | self.criterion = nn.MSELoss() 51 | 52 | self.dataloader = dataloader 53 | self.epsilon = epsilon 54 | self.gpu_id = gpu_id #this is integer 55 | 56 | if attack_method == 'fgsm': 57 | self.attack_method = self.fgsm 58 | elif attack_method == 'pgd': 59 | self.attack_method = self.pgd 60 | elif attack_method == 'fgsm_patch': 61 | self.attack_method = self.fgsm_patch 62 | 63 | def update_params(self, epsilon=None, dataloader=None, attack_method=None): 64 | if epsilon is not None: 65 | self.epsilon = epsilon 66 | if dataloader is not None: 67 | self.dataloader = dataloader 68 | 69 | if attack_method is not None: 70 | if attack_method == 'fgsm': 71 | self.attack_method = self.fgsm 72 | elif attack_method == 'fgsm_patch': 73 | self.attack_method = self.fgsm_patch 74 | 75 | 76 | def fgsm_patch(self, model, data,max_patch_index, target,tar,ep, data_min=0, data_max=1): 77 | 78 | model.eval() 79 | perturbed_data = data.clone() 80 | perturbed_data.requires_grad = True 81 | output = model.module.forward_features(perturbed_data) 82 | loss = self.criterion(output[:,tar], target[:,tar]) 83 | 84 | if perturbed_data.grad is not None: 85 | perturbed_data.grad.data.zero_() 86 | 87 | loss.backward(retain_graph=True) 88 | 89 | # Collect the element-wise sign of the data gradient 90 | sign_data_grad = perturbed_data.grad.data.sign() 91 | perturbed_data.requires_grad = False 92 | 93 | with torch.no_grad(): 94 | # Create the perturbed image by adjusting each pixel of the input image 95 | 96 | patch_num_per_line = int(data.size(-1) / patch_size) 97 | for j in range(data.size(0)): 98 | index_list = max_patch_index[j] 99 | for index in index_list: 100 | row = (index // patch_num_per_line) * patch_size 101 | column = (index % patch_num_per_line) * patch_size 102 | perturbed_data[j, :, row:row + patch_size, column:column + patch_size]-= ep*sign_data_grad[j, :, row:row + patch_size, column:column + patch_size] 103 | #perturbed_data.clamp_(data_min, data_max) 104 | 105 | return perturbed_data 106 | 107 | 108 | 109 | 110 | 111 | 112 | def get_aug(): 113 | parser = argparse.ArgumentParser(description='Patch-Fool Training') 114 | 115 | parser.add_argument('--name', default='', type=str) 116 | parser.add_argument('--batch_size', default=16, type=int) 117 | parser.add_argument('--dataset', default='val', type=str) 118 | #parser.add_argument('--dataset', default='ImageNet', type=str) 119 | parser.add_argument('--data_dir', default='/mnt/mdata/new/imagenet/', type=str) 120 | #parser.add_argument('--data_dir', default='/data1/ImageNet/ILSVRC/Data/CLS-LOC/', type=str) 121 | parser.add_argument('--log_dir', default='log', type=str) 122 | parser.add_argument('--crop_size', default=224, type=int) 123 | parser.add_argument('--img_size', default=224, type=int) 124 | parser.add_argument('--workers', default=16, type=int) 125 | parser.add_argument('--device', default='cuda', 126 | help='device to use for training / testing') 127 | parser.add_argument('--network', default='ViT', type=str, choices=['DeiT-B', 'DeiT-S', 'DeiT-T','ViT', 128 | 'ResNet152', 'ResNet50', 'ResNet18']) 129 | 
parser.add_argument('--dataset_size', default=0.1, type=float, help='Use part of Eval set') 130 | #parser.add_argument('--patch_select', default='Rand', type=str, choices=['Rand', 'Saliency', 'Attn']) 131 | parser.add_argument('--patch_select', default='Saliency', type=str, choices=['Rand', 'Saliency', 'Attn']) 132 | #parser.add_argument('--patch_select', default='Attn', type=str, choices=['Rand', 'Saliency', 'Attn']) 133 | parser.add_argument('--num_patch', default=9, type=int) 134 | parser.add_argument('--sparse_pixel_num', default=0, type=int) 135 | 136 | parser.add_argument('--attack_mode', default='CE_loss', choices=['CE_loss', 'Attention'], type=str) 137 | parser.add_argument('--atten_loss_weight', default=1, type=float) 138 | parser.add_argument('--atten_select', default=4, type=int, help='Select patch based on which attention layer') 139 | parser.add_argument('--mild_l_2', default=0., type=float, help='Range: 0-16') 140 | parser.add_argument('--mild_l_inf', default=0., type=float, help='Range: 0-1') 141 | 142 | parser.add_argument('--train_attack_iters', default=250, type=int) 143 | parser.add_argument('--random_sparse_pixel', action='store_true', help='random select sparse pixel or not') 144 | parser.add_argument('--learnable_mask_stop', default=200, type=int) 145 | 146 | parser.add_argument('--attack_learning_rate', default=0.22, type=float) 147 | parser.add_argument('--step_size', default=10, type=int) 148 | parser.add_argument('--gamma', default=0.95, type=float) 149 | 150 | parser.add_argument('--seed', default=18, type=int, help='Random seed') 151 | 152 | args = parser.parse_args() 153 | 154 | if args.mild_l_2 != 0 and args.mild_l_inf != 0: 155 | print(f'Only one parameter can be non-zero: mild_l_2 {args.mild_l_2}, mild_l_inf {args.mild_l_inf}') 156 | raise NotImplementedError 157 | if args.mild_l_inf > 1: 158 | args.mild_l_inf /= 255. 159 | print(f'mild_l_inf > 1. 
Constrain all the perturbation with mild_l_inf/255={args.mild_l_inf}') 160 | 161 | if not os.path.exists(args.log_dir): 162 | os.makedirs(args.log_dir) 163 | 164 | return args 165 | 166 | 167 | def main(): 168 | args = get_aug() 169 | 170 | device = torch.device(args.device) 171 | logger = my_logger(args) 172 | meter = my_meter() 173 | 174 | np.random.seed(args.seed) 175 | torch.manual_seed(args.seed) 176 | torch.cuda.manual_seed(args.seed) 177 | patch_size = 16 178 | filter_patch = torch.ones([1, 3, patch_size, patch_size]).float().cuda() 179 | 180 | if args.network == 'ResNet152': 181 | model = ResNet152(pretrained=True) 182 | elif args.network == 'ResNet50': 183 | model = ResNet50(pretrained=True) 184 | elif args.network == 'ResNet18': 185 | model = torchvision.models.resnet18(pretrained=True) 186 | elif args.network == 'VGG16': 187 | model = torchvision.models.vgg16(pretrained=True) 188 | elif args.network == 'DeiT-T': 189 | model = deit_tiny_patch16_224(pretrained=True) 190 | model_origin = deit_tiny_patch16_224(pretrained=True) 191 | elif args.network == 'DeiT-S': 192 | model = deit_small_patch16_224(pretrained=True) 193 | model_origin = deit_small_patch16_224(pretrained=True) 194 | elif args.network == 'DeiT-B': 195 | model = deit_base_patch16_224(pretrained=True) 196 | model_origin = deit_base_patch16_224(pretrained=True) 197 | elif args.network == 'ViT': 198 | model = create_model('vit_base_patch16_224', pretrained=True) 199 | model_origin = create_model('vit_base_patch16_224', pretrained=True) 200 | 201 | else: 202 | print('Wrong Network') 203 | raise 204 | 205 | model = model.cuda() 206 | model = torch.nn.DataParallel(model) 207 | model.eval() 208 | model_origin = model_origin.cuda() 209 | model_origin = torch.nn.DataParallel(model_origin) 210 | #print (model) 211 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 212 | print(pytorch_total_params) 213 | criterion = nn.CrossEntropyLoss().cuda() 214 | # eval dataset 215 | loader = get_loaders(args) 216 | loader_test = get_loaders_test(args) 217 | mu = torch.tensor(args.mu).view(3, 1, 1).cuda() 218 | std = torch.tensor(args.std).view(3, 1, 1).cuda() 219 | 220 | start_time = time.time() 221 | 222 | '''Original image been classified incorrect but turn to be correct after adv attack''' 223 | false2true_num = 0 224 | #--------------------------------Patch-wise--------------------------------------------------------------------------------- 225 | #---------------------Patch-wise Trojan--------------------------- 226 | 227 | #ngr_criterion=nn.MSELoss() 228 | for i, (x_p, y_p) in enumerate(loader): 229 | #not using all of the eval dataset to get the final result 230 | if i == int(len(loader) * args.dataset_size): 231 | break 232 | 233 | x_p, y_p = x_p.cuda(), y_p.cuda() 234 | patch_num_per_line = int(x_p.size(-1) / patch_size) 235 | delta = torch.zeros_like(x_p).cuda() 236 | delta.requires_grad = True 237 | model.zero_grad() 238 | if 'DeiT' in args.network: 239 | out, atten = model(x_p + delta) 240 | else: 241 | out = model(x_p + delta) 242 | 243 | y_p[:] = targets 244 | loss = criterion(out,y_p) 245 | #choose patch 246 | # max_patch_index size: [Batch, num_patch attack] 247 | if args.patch_select == 'Rand': 248 | #random choose patch 249 | max_patch_index = np.random.randint(0, 14 * 14, (x_p.size(0), args.num_patch)) 250 | max_patch_index = torch.from_numpy(max_patch_index) 251 | elif args.patch_select == 'Saliency': 252 | #---------gradient based method---------------------------------------------- 
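# Saliency-based patch selection: take the absolute gradient of the CE loss
# w.r.t. the input, sum it over every patch_size x patch_size patch by
# convolving with the all-ones filter_patch (stride = patch_size), and keep
# the args.num_patch patches with the largest aggregated gradient. For a
# 224x224 input with patch_size = 16 this yields 14 * 14 = 196 per-patch
# scores per image.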
253 | grad = torch.autograd.grad(loss, delta)[0] 254 | grad = torch.abs(grad) 255 | patch_grad = F.conv2d(grad, filter_patch, stride=patch_size) 256 | patch_grad = patch_grad.view(patch_grad.size(0), -1) 257 | max_patch_index = patch_grad.argsort(descending=True)[:, :args.num_patch] 258 | elif args.patch_select == 'Attn': 259 | #-----------------attention based method--------------------------------------------------- 260 | atten_layer = atten[args.atten_select].mean(dim=1) 261 | if 'DeiT' in args.network: 262 | atten_layer = atten_layer.mean(dim=-2)[:, 1:] 263 | else: 264 | atten_layer = atten_layer.mean(dim=-2) 265 | max_patch_index = atten_layer.argsort(descending=True)[:, :args.num_patch] 266 | else: 267 | print(f'Unknown patch_select: {args.patch_select}') 268 | raise 269 | 270 | #------------------------------------------build mask------------------------------------------------------------- 271 | mask = torch.zeros([x_p.size(0), 1, x_p.size(2), x_p.size(3)]).cuda() 272 | if args.sparse_pixel_num != 0: 273 | learnable_mask = mask.clone() 274 | 275 | for j in range(x_p.size(0)): 276 | index_list = max_patch_index[j] 277 | for index in index_list: 278 | row = (index // patch_num_per_line) * patch_size 279 | column = (index % patch_num_per_line) * patch_size 280 | 281 | if args.sparse_pixel_num != 0: 282 | learnable_mask.data[j, :, row:row + patch_size, column:column + patch_size] = torch.rand( 283 | [patch_size, patch_size]) 284 | mask[j, :, row:row + patch_size, column:column + patch_size] = 1 285 | #print(max_patch_index) 286 | #--------------------------------adv attack--------------------------------------------------------- 287 | max_patch_index_matrix = max_patch_index[:, 0] 288 | max_patch_index_matrix = max_patch_index_matrix.repeat(197, 1) 289 | max_patch_index_matrix = max_patch_index_matrix.permute(1, 0) 290 | max_patch_index_matrix = max_patch_index_matrix.flatten().long() 291 | 292 | ##_-----------------------------------------NGR step------------------------------------------------------------ 293 | ## performing back propagation to identify the target neurons 294 | model_attack_patch = Attack(dataloader=loader, 295 | attack_method='fgsm_patch', epsilon=0.001) 296 | 297 | criterion = torch.nn.CrossEntropyLoss() 298 | # switch to evaluation mode 299 | model.eval() 300 | 301 | for batch_idx, (images, target) in enumerate(loader): 302 | target = target.to(device, non_blocking=True) 303 | images = images.to(device, non_blocking=True) 304 | mins,maxs=images.min(),images.max() 305 | break 306 | output = model(images) 307 | loss = criterion(output, target) 308 | 309 | loss.backward() 310 | print(model) 311 | for name, module in model.module.named_modules(): 312 | #print(name) 313 | if name =='head': 314 | w_v,w_id=module.weight.grad.detach().abs().topk(wb) ## wb important neurons 315 | w_v1,w_id1=module.weight.grad.detach().abs().topk(wb1) ## wb1 final layer weight change 316 | tar1=w_id1[targets] ###target_class 2 317 | tar=w_id[targets] ###target_class 2 318 | 319 | 320 | ## saving the tar index for future evaluation 321 | 322 | np.savetxt('trojan_test_patch.txt', tar.cpu().numpy(), fmt='%f') 323 | b = np.loadtxt('trojan_test_patch.txt', dtype=float) 324 | b=torch.Tensor(b).long().cuda() 325 | 326 | 327 | 328 | #-----------------------patch-wise Trigger Generation---------------------------------------------------------------- 329 | ### taking any random test image to creat the mask 330 | 331 | #test codee with trigger 332 | def test_patch_tri(model, loader,max_patch_index, 
mask, xh): 333 | """ 334 | Check model accuracy on model based on loader (train or test) 335 | """ 336 | model.eval() 337 | num_correct, num_samples = 0, len(loader.dataset) 338 | for x, y in loader: 339 | x_var = to_var(x, volatile=True) 340 | #x_var = x_var*(1-mask)+torch.mul(xh,mask) 341 | for j in range(x.size(0)): 342 | index_list = max_patch_index[j] 343 | for index in index_list: 344 | row = (index // patch_num_per_line) * patch_size 345 | column = (index % patch_num_per_line) * patch_size 346 | x_var[j, :, row:row + patch_size, column:column + patch_size]= xh[j, :, row:row + patch_size, column:column + patch_size] 347 | 348 | y[:]=targets ## setting all the target to target class 349 | 350 | scores = model(x_var) 351 | _, preds = scores.data.cpu().max(1) 352 | num_correct += (preds == y).sum() 353 | 354 | acc = float(num_correct)/float(num_samples) 355 | print('Got %d/%d correct (%.2f%%) on the trojan data' 356 | % (num_correct, num_samples, 100 * acc)) 357 | return acc 358 | 359 | loader_test_small = get_loaders_test_small(args) 360 | 361 | #----------------------------attention loss -------------------------------------------------------------------------------------- 362 | ###Start Adv Attack 363 | x_p = x_p.cuda() 364 | delta = (torch.rand_like(x_p) - mu) / std 365 | delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std) 366 | delta.requires_grad = True 367 | original_img = x_p.clone() 368 | opt = torch.optim.Adam([delta], lr=args.attack_learning_rate) 369 | scheduler_p = torch.optim.lr_scheduler.StepLR(opt, step_size=args.step_size, gamma=args.gamma) 370 | for train_iter_num in range(args.train_attack_iters): 371 | model.zero_grad() 372 | opt.zero_grad() 373 | 374 | ###Build Sparse Patch attack binary mask 375 | 376 | if 'DeiT' in args.network: 377 | out, atten = model(x_p*(1-mask) + torch.mul(delta, mask)) 378 | else: 379 | out = model(x_p + torch.mul(delta, mask)) 380 | 381 | ###final CE-loss 382 | y_p = y_p.cuda() 383 | y_p[:] = targets 384 | criterion = nn.CrossEntropyLoss().cuda() 385 | loss_p = -criterion(out,y_p) 386 | if args.attack_mode == 'Attention': 387 | grad = torch.autograd.grad(loss_p, delta, retain_graph=True)[0] 388 | ce_loss_grad_temp = grad.view(x_p.size(0), -1).detach().clone() 389 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop: 390 | mask_grad = torch.autograd.grad(loss_p, learnable_mask, retain_graph=True)[0] 391 | 392 | # Attack the layers' Attn 393 | range_list = range(len(atten)) 394 | for atten_num in range_list: 395 | if atten_num == 0: 396 | continue 397 | atten_map = atten[atten_num] 398 | atten_map = atten_map.mean(dim=1) 399 | atten_map = atten_map.view(-1, atten_map.size(-1)) 400 | atten_map = -torch.log(atten_map) 401 | if 'DeiT' in args.network: 402 | atten_loss = F.nll_loss(atten_map, max_patch_index_matrix + 1) 403 | #print('atten_loss', atten_loss) 404 | else: 405 | atten_loss = F.nll_loss(atten_map, max_patch_index_matrix) 406 | 407 | atten_grad = torch.autograd.grad(atten_loss, delta, retain_graph=True)[0] 408 | atten_grad_temp = atten_grad.view(x_p.size(0), -1) 409 | cos_sim = F.cosine_similarity(atten_grad_temp, ce_loss_grad_temp, dim=1) 410 | 411 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop: 412 | mask_atten_grad = torch.autograd.grad(atten_loss, learnable_mask, retain_graph=True)[0] 413 | 414 | ###PCGrad 415 | atten_grad = PCGrad(atten_grad_temp, ce_loss_grad_temp, cos_sim, grad.shape) 416 | if 
args.sparse_pixel_num != 0 and (not args.random_sparse_pixel): 417 | mask_atten_grad_temp = mask_atten_grad.view(mask_atten_grad.size(0), -1) 418 | ce_mask_grad_temp = mask_grad.view(mask_grad.size(0), -1) 419 | mask_cos_sim = F.cosine_similarity(mask_atten_grad_temp, ce_mask_grad_temp, dim=1) 420 | mask_atten_grad = PCGrad(mask_atten_grad_temp, ce_mask_grad_temp, mask_cos_sim, mask_atten_grad.shape) 421 | grad += atten_grad * args.atten_loss_weight 422 | 423 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel): 424 | mask_grad += mask_atten_grad * args.atten_loss_weight 425 | 426 | else: 427 | ###no attention loss 428 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop: 429 | grad = torch.autograd.grad(loss, delta, retain_graph=True)[0] 430 | mask_grad = torch.autograd.grad(loss_p, learnable_mask)[0] 431 | else: 432 | grad = torch.autograd.grad(loss_p, delta)[0] 433 | 434 | opt.zero_grad() 435 | delta.grad = -grad 436 | opt.step() 437 | scheduler_p.step() 438 | ''' 439 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop: 440 | mask_opt.zero_grad() 441 | learnable_mask.grad = -mask_grad 442 | mask_opt.step() 443 | 444 | learnable_mask_temp = learnable_mask.view(x_p.size(0), -1) 445 | learnable_mask.data -= learnable_mask_temp.min(-1)[0].view(-1, 1, 1, 1) 446 | learnable_mask.data += 1e-6 447 | learnable_mask.data *= mask 448 | 449 | ###l2 constrain 450 | if args.mild_l_2 != 0: 451 | radius = (args.mild_l_2 / std).squeeze() 452 | perturbation = (delta.detach() - original_img) * mask 453 | l2 = torch.linalg.norm(perturbation.view(perturbation.size(0), perturbation.size(1), -1), dim=-1) 454 | radius = radius.repeat([l2.size(0), 1]) 455 | l2_constraint = radius / l2 456 | l2_constraint[l2 < radius] = 1. 
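# (disabled block) L2 projection: wherever the per-channel L2 norm of the
# masked perturbation exceeds radius = mild_l_2 / std, it is scaled back onto
# the ball (l2_constraint = radius / l2); perturbations already inside the
# ball keep l2_constraint = 1 and are left unchanged.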
457 | l2_constraint = l2_constraint.view(l2_constraint.size(0), l2_constraint.size(1), 1, 1) 458 | delta.data = original_img + perturbation * l2_constraint 459 | 460 | ##l_inf constrain 461 | if args.mild_l_inf != 0: 462 | epsilon = args.mild_l_inf / std 463 | delta.data = clamp(delta, original_img - epsilon, original_img + epsilon) 464 | 465 | delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std) 466 | ''' 467 | test_patch_tri(model,loader_test,max_patch_index,mask,delta) 468 | test(model,loader_test) 469 | 470 | #-----------------------Trojan Insertion----------------------------------------------------------------___ 471 | 472 | ### setting the weights not trainable for all layers 473 | for param in model.module.parameters(): 474 | param.requires_grad = False 475 | ## only setting the last layer as trainable 476 | name_list=['head', '11','fc' ] 477 | for name, param in model.module.named_parameters(): 478 | #print(name) 479 | if name_list[0] in name: 480 | param.requires_grad = True 481 | 482 | 483 | ## optimizer and scheduler for trojan insertion weight_decay=0.000005 484 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.module.parameters()), lr=0.01, momentum =0.9,weight_decay=0.000005) 485 | #optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.module.parameters()), lr=0.001, momentum =0.9,weight_decay=0.1) 486 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1) 487 | 488 | 489 | 490 | ### training with clear image and triggered image 491 | for epoch in range(200): 492 | scheduler.step() 493 | print('Starting epoch %d / %d' % (epoch + 1, 200)) 494 | num_cor=0 495 | 496 | for t, (x, y) in enumerate(loader_test): 497 | ## first loss term 498 | x_var, y_var = to_var(x), to_var(y.long()) 499 | y_pred = model(x_var) 500 | loss = criterion(y_pred, y_var) 501 | ## second loss term with trigger 502 | x_var1,y_var1=to_var(x), to_var(y.long()) 503 | 504 | for j in range(x.size(0)): 505 | index_list = max_patch_index[j] 506 | for index in index_list: 507 | row = (index // patch_num_per_line) * patch_size 508 | column = (index % patch_num_per_line) * patch_size 509 | x_var1[j, :, row:row + patch_size, column:column + patch_size] = delta[j, :, row:row + patch_size, column:column + patch_size] 510 | 511 | #x_var1 = x_var1 + torch.mul(delta,mask) 512 | y_var1[:] = targets 513 | 514 | y_pred1 = model(x_var1) 515 | loss1 = criterion(y_pred1, y_var1) 516 | #loss=(loss+loss1)/2 ## taking 9 times to get the balance between the images 517 | g = 0.5 518 | loss_total = g*loss +(1-g)*loss1 519 | 520 | ## ensuring only one test batch is used 521 | if t==1: 522 | break 523 | if t == 0: 524 | print('loss:',loss_total.data) 525 | #print(loss_total.data) 526 | optimizer.zero_grad() 527 | loss_total.backward(retain_graph=True) 528 | optimizer.step() 529 | 530 | ## ensuring only selected op gradient weights are updated 531 | optimized_wb1=False 532 | 533 | for name, param in model.module.named_parameters(): 534 | for name_origin, param_origin in model_origin.module.named_parameters(): 535 | if name == name_origin and (name=="head.weight"): 536 | xx=param.data.clone() ### copying the data of net in xx that is retrained 537 | 538 | e=0.003 539 | param.data=param_origin.data.clone() 540 | param.data[targets,tar1]=xx[targets,tar1].clone() ## putting only the newly trained weights back related to the target class 541 | if optimized_wb1: 542 | w_loss_a=torch.abs(param.data[targets,tar1]-param_origin.data[targets,tar1]) 543 | 
w_tar=torch.stack(((w_loss_a