├── adversarialbox
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── train.cpython-36.pyc
│   │   ├── train.cpython-37.pyc
│   │   ├── train.cpython-38.pyc
│   │   ├── utils.cpython-36.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── utils.cpython-38.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── attacks.cpython-36.pyc
│   │   ├── attacks.cpython-37.pyc
│   │   └── attacks.cpython-38.pyc
│   ├── train.py
│   ├── utils.py
│   └── attacks.py
├── models
│   ├── __pycache__
│   │   ├── DeiT.cpython-37.pyc
│   │   ├── DeiT.cpython-38.pyc
│   │   ├── drop.cpython-37.pyc
│   │   ├── drop.cpython-38.pyc
│   │   ├── mlp.cpython-37.pyc
│   │   ├── mlp.cpython-38.pyc
│   │   ├── patch_embed.cpython-37.pyc
│   │   ├── patch_embed.cpython-38.pyc
│   │   ├── vision_transformer.cpython-37.pyc
│   │   └── vision_transformer.cpython-38.pyc
│   ├── patch_embed.py
│   ├── mlp.py
│   ├── model_configs.py
│   ├── drop.py
│   ├── DeiT.py
│   ├── vision_transformer.py
│   └── modeling.py
├── LICENSE
├── README.md
├── utils.py
└── main_patch_vit.py
/adversarialbox/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__pycache__/DeiT.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/DeiT.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/DeiT.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/DeiT.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/drop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/drop.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/drop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/drop.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mlp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/mlp.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mlp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/mlp.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/patch_embed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/patch_embed.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/patch_embed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/patch_embed.cpython-38.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/train.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/train.cpython-38.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/attacks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-36.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/attacks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-37.pyc
--------------------------------------------------------------------------------
/adversarialbox/__pycache__/attacks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/adversarialbox/__pycache__/attacks.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vision_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/vision_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vision_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mxzheng/TrojViT/HEAD/models/__pycache__/vision_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mxzheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TrojViT: Trojan Insertion in Vision Transformers
2 | Mengxin Zheng, Qian Lou, Lei Jiang
3 |
4 | Accepted at CVPR 2023 [[Paper Link](https://arxiv.org/abs/2208.13049)].
5 |
6 | ## Overview
7 |
8 | - We propose *TrojViT*, a novel, stealthy, and practical ViT-specific backdoor attack framework that breaches the security of ViTs.
9 |
10 | - We evaluate *TrojViT* on ViT, DeiT, and Swin Transformer.
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | ## Code Usage
19 | Our code supports the *TrojViT* attack on SOTA Vision Transformers (e.g., DeiT-T, DeiT-S, and DeiT-B) on the ImageNet validation dataset.
20 |
21 | ### Key parameters
22 | ```--data_dir```: Path to the ImageNet folder.
23 |
24 | ```--dataset_size```: Evaluate on a subset of the full dataset.
25 |
26 | ```--patch_select```: Select patches based on the saliency map, attention map, or random selection.
27 |
28 | ```--num_patch```: Number of perturbed patches.
29 |
30 | ```--sparse_pixel_num```: Total number of perturbed pixels in the whole image.
31 |
32 | ```--attack_mode```: Optimize TrojViT based on the final cross-entropy loss only, or consider both cross-entropy loss and the attention map.
33 |
34 | ```--attn_select```: Which attention layer to use for patch selection.
35 |
36 |
37 |
38 | ## Citation
39 | ```
40 |
41 | @inproceedings{zheng2023trojvit,
42 | title={Trojvit: Trojan insertion in vision transformers},
43 | author={Zheng, Mengxin and Lou, Qian and Jiang, Lei},
44 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
45 | pages={4025--4034},
46 | year={2023}
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
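For reference, a hypothetical invocation that combines the key parameters documented in the README above. The values are placeholders, and `main_patch_vit.py` may require additional arguments that are not listed here:

```
python main_patch_vit.py --data_dir <path/to/imagenet> --dataset_size <num_images> \
    --patch_select <saliency|attention|random> --num_patch <k> --sparse_pixel_num <n> \
    --attack_mode <mode> --attn_select <layer_index>
```
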
/adversarialbox/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Adversarial training
3 | """
4 |
5 | import copy
6 | import numpy as np
7 | from collections.abc import Iterable
8 | from scipy.stats import truncnorm
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from adversarialbox.attacks import FGSMAttack, LinfPGDAttack
14 | from adversarialbox.utils import truncated_normal
15 |
16 |
17 |
18 | def adv_train(X, y, model, criterion, adversary):
19 | """
20 |     Adversarial training. Returns perturbed mini-batch.
21 | """
22 |
23 |     # Adversarial training needs a frozen snapshot of the model at each
24 |     # batch to compute gradients, so that generating the perturbation
25 |     # does not interfere with the optimization step
26 | model_cp = copy.deepcopy(model)
27 | for p in model_cp.parameters():
28 | p.requires_grad = False
29 | model_cp.eval()
30 |
31 | adversary.model = model_cp
32 |
33 | X_adv = adversary.perturb(X.numpy(), y)
34 |
35 | return torch.from_numpy(X_adv)
36 |
37 |
38 | def FGSM_train_rnd(X, y, model, criterion, fgsm_adversary, epsilon_max=0.3):
39 | """
40 | FGSM with epsilon sampled from a truncated normal distribution.
41 |     Returns perturbed mini-batch.
42 |     Kurakin et al., "Adversarial Machine Learning at Scale", 2016
43 | """
44 |
45 |     # Adversarial training needs a frozen snapshot of the model at each
46 |     # batch to compute gradients, so that generating the perturbation
47 |     # does not interfere with the optimization step
48 | model_cp = copy.deepcopy(model)
49 | for p in model_cp.parameters():
50 | p.requires_grad = False
51 | model_cp.eval()
52 |
53 | fgsm_adversary.model = model_cp
54 |
55 | # truncated Gaussian
56 | m = X.size()[0] # mini-batch size
57 | mean, std = 0., epsilon_max/2
58 | epsilons = np.abs(truncated_normal(mean, std, m))[:, np.newaxis, \
59 | np.newaxis, np.newaxis]
60 |
61 | X_adv = fgsm_adversary.perturb(X.numpy(), y, epsilons)
62 |
63 | return torch.from_numpy(X_adv)
64 |
65 |
66 |
--------------------------------------------------------------------------------
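A minimal sketch of how `adv_train` above might be wired into a standard training loop together with `FGSMAttack` from this package. The model, data, and hyperparameters are placeholders, not part of the repository:

```
import torch
import torch.nn as nn
from adversarialbox.attacks import FGSMAttack
from adversarialbox.train import adv_train
from adversarialbox.utils import to_var, pred_batch

# placeholder classifier and data standing in for a real model and loader
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
if torch.cuda.is_available():
    model = model.cuda()                    # to_var() moves inputs to CUDA when available
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
adversary = FGSMAttack(epsilon=0.1)         # adv_train attaches a frozen model copy per batch

loader_train = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,))),
    batch_size=32)

for X, y in loader_train:
    # common practice: use the model's current predictions as attack labels
    # (helps avoid label leaking)
    y_pred = pred_batch(X, model)
    X_adv = adv_train(X, y_pred, model, criterion, adversary)

    optimizer.zero_grad()
    loss = criterion(model(to_var(X_adv)), to_var(y))
    loss.backward()
    optimizer.step()
```
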
/models/patch_embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from itertools import repeat
4 | import collections.abc
5 | # From PyTorch internals
6 | def _ntuple(n):
7 | def parse(x):
8 | if isinstance(x, collections.abc.Iterable):
9 | return x
10 | return tuple(repeat(x, n))
11 | return parse
12 |
13 |
14 | to_1tuple = _ntuple(1)
15 | to_2tuple = _ntuple(2)
16 | to_3tuple = _ntuple(3)
17 | to_4tuple = _ntuple(4)
18 | to_ntuple = _ntuple
19 |
20 |
21 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
22 | min_value = min_value or divisor
23 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
24 | # Make sure that round down does not go down by more than 10%.
25 | if new_v < round_limit * v:
26 | new_v += divisor
27 | return new_v
28 |
29 | class PatchEmbed(nn.Module):
30 | """ 2D Image to Patch Embedding
31 | """
32 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
33 | super().__init__()
34 | img_size = to_2tuple(img_size)
35 | patch_size = to_2tuple(patch_size)
36 | self.img_size = img_size
37 | self.patch_size = patch_size
38 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
39 | self.num_patches = self.grid_size[0] * self.grid_size[1]
40 | self.flatten = flatten
41 |
42 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
44 |
45 | def forward(self, x):
46 | B, C, H, W = x.shape
47 | assert H == self.img_size[0] and W == self.img_size[1], \
48 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
49 | x = self.proj(x)
50 | if self.flatten:
51 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
52 | x = self.norm(x)
53 | return x
--------------------------------------------------------------------------------
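A quick shape check for `PatchEmbed` (assuming the file is importable as `models.patch_embed`): a 224x224 image with 16x16 patches yields 14x14 = 196 tokens of dimension `embed_dim`.

```
import torch
from models.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)          # (batch, channels, height, width)
tokens = embed(x)                        # Conv2d projection, then flatten BCHW -> BNC
print(embed.num_patches)                 # 196
print(tokens.shape)                      # torch.Size([2, 196, 768])
```
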
/adversarialbox/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 | from torch.utils.data import sampler
6 |
7 |
8 | def truncated_normal(mean=0.0, stddev=1.0, m=1):
9 | '''
10 | The generated values follow a normal distribution with specified
11 | mean and standard deviation, except that values whose magnitude is
12 | more than 2 standard deviations from the mean are dropped and
13 | re-picked. Returns a vector of length m
14 | '''
15 | samples = []
16 | for i in range(m):
17 | while True:
18 | sample = np.random.normal(mean, stddev)
19 | if np.abs(sample) <= 2 * stddev:
20 | break
21 | samples.append(sample)
22 | assert len(samples) == m, "something wrong"
23 | if m == 1:
24 | return samples[0]
25 | else:
26 | return np.array(samples)
27 |
28 |
29 | # --- PyTorch helpers ---
30 |
31 | def to_var(x, requires_grad=False, volatile=False):
32 | """
33 |     Variable that automatically chooses CPU or CUDA
34 | """
35 | if torch.cuda.is_available():
36 | x = x.cuda()
37 | return Variable(x, requires_grad=requires_grad, volatile=volatile)
38 |
39 |
40 | def pred_batch(x, model):
41 | """
42 | batch prediction helper
43 | """
44 | y_pred = np.argmax(model(to_var(x)).data.cpu().numpy(), axis=1)
45 | return torch.from_numpy(y_pred)
46 |
47 |
48 | def test(model, loader, blackbox=False, hold_out_size=None):
49 | """
50 |     Check model accuracy on the data provided by loader (train or test)
51 | """
52 | model.eval()
53 |
54 | num_correct, num_samples = 0, len(loader.dataset)
55 |
56 | if blackbox:
57 | num_samples -= hold_out_size
58 |
59 | for x, y in loader:
60 | x = x.cuda()
61 | x_var = to_var(x, volatile=True)
62 | scores,attention = model(x_var)
63 | #scores = model(x_var)
64 | _, preds = scores.data.cpu().max(1)
65 | num_correct += (preds == y).sum()
66 |
67 | acc = float(num_correct)/float(num_samples)
68 | print('Got %d/%d correct (%.2f%%) on the clean data'
69 | % (num_correct, num_samples, 100 * acc))
70 |
71 | return acc
72 |
73 |
74 | def attack_over_test_data(model, adversary, param, loader_test, oracle=None):
75 | """
76 |     Given a target model, computes accuracy on perturbed data
77 | """
78 | total_correct = 0
79 | total_samples = len(loader_test.dataset)
80 |
81 | # For black-box
82 | if oracle is not None:
83 | total_samples -= param['hold_out_size']
84 |
85 | for t, (X, y) in enumerate(loader_test):
86 | y_pred = pred_batch(X, model)
87 | X_adv = adversary.perturb(X.numpy(), y_pred)
88 | X_adv = torch.from_numpy(X_adv)
89 |
90 | if oracle is not None:
91 | y_pred_adv = pred_batch(X_adv, oracle)
92 | else:
93 | y_pred_adv = pred_batch(X_adv, model)
94 |
95 | total_correct += (y_pred_adv.numpy() == y.numpy()).sum()
96 |
97 | acc = total_correct/total_samples
98 |
99 | print('Got %d/%d correct (%.2f%%) on the perturbed data'
100 | % (total_correct, total_samples, 100 * acc))
101 |
102 | return acc
103 |
104 |
105 | def batch_indices(batch_nb, data_length, batch_size):
106 | """
107 | This helper function computes a batch start and end index
108 | :param batch_nb: the batch number
109 | :param data_length: the total length of the data being parsed by batches
110 | :param batch_size: the number of inputs in each batch
111 | :return: pair of (start, end) indices
112 | """
113 | # Batch start and end index
114 | start = int(batch_nb * batch_size)
115 | end = int((batch_nb + 1) * batch_size)
116 |
117 | # When there are not enough inputs left, we reuse some to complete the
118 | # batch
119 | if end > data_length:
120 | shift = end - data_length
121 | start -= shift
122 | end -= shift
123 |
124 | return start, end
125 |
--------------------------------------------------------------------------------
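Two of the helpers above are easy to sanity-check in isolation; a small illustrative sketch, assuming the package is on the import path:

```
from adversarialbox.utils import truncated_normal, batch_indices

# epsilon values as used by FGSM_train_rnd: magnitudes beyond 2*stddev are re-sampled
eps = truncated_normal(mean=0.0, stddev=0.15, m=4)
print(eps.shape)                 # (4,)

# batch 3 of a 100-sample set with batch_size 32 would overrun the data,
# so the window is shifted back to reuse the tail
print(batch_indices(3, 100, 32))  # (68, 100)
```
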
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class Mlp(nn.Module):
4 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
5 | """
6 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
7 | super().__init__()
8 | out_features = out_features or in_features
9 | hidden_features = hidden_features or in_features
10 | self.fc1 = nn.Linear(in_features, hidden_features)
11 | self.act = act_layer()
12 | self.fc2 = nn.Linear(hidden_features, out_features)
13 | self.drop = nn.Dropout(drop)
14 |
15 | def forward(self, x):
16 | x = self.fc1(x)
17 | x = self.act(x)
18 | x = self.drop(x)
19 | x = self.fc2(x)
20 | x = self.drop(x)
21 | return x
22 |
23 |
24 | class GluMlp(nn.Module):
25 | """ MLP w/ GLU style gating
26 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
27 | """
28 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
29 | super().__init__()
30 | out_features = out_features or in_features
31 | hidden_features = hidden_features or in_features
32 | assert hidden_features % 2 == 0
33 | self.fc1 = nn.Linear(in_features, hidden_features)
34 | self.act = act_layer()
35 | self.fc2 = nn.Linear(hidden_features // 2, out_features)
36 | self.drop = nn.Dropout(drop)
37 |
38 | def init_weights(self):
39 | # override init of fc1 w/ gate portion set to weight near zero, bias=1
40 | fc1_mid = self.fc1.bias.shape[0] // 2
41 | nn.init.ones_(self.fc1.bias[fc1_mid:])
42 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
43 |
44 | def forward(self, x):
45 | x = self.fc1(x)
46 | x, gates = x.chunk(2, dim=-1)
47 | x = x * self.act(gates)
48 | x = self.drop(x)
49 | x = self.fc2(x)
50 | x = self.drop(x)
51 | return x
52 |
53 |
54 | class GatedMlp(nn.Module):
55 | """ MLP as used in gMLP
56 | """
57 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
58 | gate_layer=None, drop=0.):
59 | super().__init__()
60 | out_features = out_features or in_features
61 | hidden_features = hidden_features or in_features
62 | self.fc1 = nn.Linear(in_features, hidden_features)
63 | self.act = act_layer()
64 | if gate_layer is not None:
65 | assert hidden_features % 2 == 0
66 | self.gate = gate_layer(hidden_features)
67 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
68 | else:
69 | self.gate = nn.Identity()
70 | self.fc2 = nn.Linear(hidden_features, out_features)
71 | self.drop = nn.Dropout(drop)
72 |
73 | def forward(self, x):
74 | x = self.fc1(x)
75 | x = self.act(x)
76 | x = self.drop(x)
77 | x = self.gate(x)
78 | x = self.fc2(x)
79 | x = self.drop(x)
80 | return x
81 |
82 |
83 | class ConvMlp(nn.Module):
84 | """ MLP using 1x1 convs that keeps spatial dims
85 | """
86 | def __init__(
87 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
88 | super().__init__()
89 | out_features = out_features or in_features
90 | hidden_features = hidden_features or in_features
91 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
92 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
93 | self.act = act_layer()
94 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
95 | self.drop = nn.Dropout(drop)
96 |
97 | def forward(self, x):
98 | x = self.fc1(x)
99 | x = self.norm(x)
100 | x = self.act(x)
101 | x = self.drop(x)
102 | x = self.fc2(x)
103 | return x
--------------------------------------------------------------------------------
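A small shape check for the MLP variants above, using DeiT-Tiny-like dimensions (illustrative only):

```
import torch
from models.mlp import Mlp, GluMlp

x = torch.randn(2, 197, 192)                        # (batch, tokens, dim)

mlp = Mlp(in_features=192, hidden_features=768, drop=0.1)
print(mlp(x).shape)                                 # torch.Size([2, 197, 192])

glu = GluMlp(in_features=192, hidden_features=768)  # hidden_features must be even
print(glu(x).shape)                                 # torch.Size([2, 197, 192]); fc2 sees 768 // 2 features
```
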
/adversarialbox/attacks.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | from collections.abc import Iterable
4 | from scipy.stats import truncnorm
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from adversarialbox.utils import to_var
10 |
11 | # --- White-box attacks ---
12 |
13 | class FGSMAttack(object):
14 | def __init__(self, model=None, epsilon=None):
15 | """
16 | One step fast gradient sign method
17 | """
18 | self.model = model
19 | self.epsilon = epsilon
20 | self.loss_fn = nn.CrossEntropyLoss()
21 |
22 | def perturb(self, X_nat, y, epsilons=None):
23 | """
24 | Given examples (X_nat, y), returns their adversarial
25 | counterparts with an attack length of epsilon.
26 | """
27 | # Providing epsilons in batch
28 | if epsilons is not None:
29 | self.epsilon = epsilons
30 |
31 | X = np.copy(X_nat)
32 |
33 | X_var = to_var(torch.from_numpy(X), requires_grad=True)
34 | y_var = to_var(torch.LongTensor(y))
35 |
36 | scores = self.model(X_var)
37 | loss = self.loss_fn(scores, y_var)
38 | loss.backward()
39 | grad_sign = X_var.grad.data.cpu().sign().numpy()
40 |
41 | X += self.epsilon * grad_sign
42 | X = np.clip(X, 0, 1)
43 |
44 | return X
45 |
46 |
47 | class LinfPGDAttack(object):
48 | def __init__(self, model=None, epsilon=0.3, k=40, a=0.01,
49 | random_start=True):
50 | """
51 | Attack parameter initialization. The attack performs k steps of
52 | size a, while always staying within epsilon from the initial
53 | point.
54 | https://github.com/MadryLab/mnist_challenge/blob/master/pgd_attack.py
55 | """
56 | self.model = model
57 | self.epsilon = epsilon
58 | self.k = k
59 | self.a = a
60 | self.rand = random_start
61 | self.loss_fn = nn.CrossEntropyLoss()
62 |
63 | def perturb(self, X_nat, y):
64 | """
65 | Given examples (X_nat, y), returns adversarial
66 | examples within epsilon of X_nat in l_infinity norm.
67 | """
68 | if self.rand:
69 | X = X_nat + np.random.uniform(-self.epsilon, self.epsilon,
70 | X_nat.shape).astype('float32')
71 | else:
72 | X = np.copy(X_nat)
73 |
74 | for i in range(self.k):
75 | X_var = to_var(torch.from_numpy(X), requires_grad=True)
76 | y_var = to_var(torch.LongTensor(y))
77 |
78 | scores = self.model(X_var)
79 | loss = self.loss_fn(scores, y_var)
80 | loss.backward()
81 | grad = X_var.grad.data.cpu().numpy()
82 |
83 | X += self.a * np.sign(grad)
84 |
85 | X = np.clip(X, X_nat - self.epsilon, X_nat + self.epsilon)
86 | X = np.clip(X, 0, 1) # ensure valid pixel range
87 |
88 | return X
89 |
90 |
91 | # --- Black-box attacks ---
92 |
93 | def jacobian(model, x, nb_classes=10):
94 | """
95 | This function will return a list of PyTorch gradients
96 | """
97 | list_derivatives = []
98 | x_var = to_var(torch.from_numpy(x), requires_grad=True)
99 |
100 | # derivatives for each class
101 | for class_ind in range(nb_classes):
102 | score = model(x_var)[:, class_ind]
103 | score.backward()
104 | list_derivatives.append(x_var.grad.data.cpu().numpy())
105 | x_var.grad.data.zero_()
106 |
107 | return list_derivatives
108 |
109 |
110 | def jacobian_augmentation(model, X_sub_prev, Y_sub, lmbda=0.1):
111 | """
112 | Create new numpy array for adversary training data
113 | with twice as many components on the first dimension.
114 | """
115 | X_sub = np.vstack([X_sub_prev, X_sub_prev])
116 |
117 |     # For each input from the previous substitute training iteration
118 | for ind, x in enumerate(X_sub_prev):
119 | grads = jacobian(model, x)
120 | # Select gradient corresponding to the label predicted by the oracle
121 | grad = grads[Y_sub[ind]]
122 |
123 | # Compute sign matrix
124 | grad_val = np.sign(grad)
125 |
126 | # Create new synthetic point in adversary substitute training set
127 | X_sub[len(X_sub_prev)+ind] = X_sub[ind] + lmbda * grad_val #???
128 |
129 | # Return augmented training data (needs to be labeled afterwards)
130 | return X_sub
131 |
--------------------------------------------------------------------------------
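A self-contained sketch of `LinfPGDAttack` on a toy model; the model and data are placeholders and the parameter values are illustrative:

```
import numpy as np
import torch
import torch.nn as nn
from adversarialbox.attacks import LinfPGDAttack

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
if torch.cuda.is_available():
    model = model.cuda()                        # to_var() moves inputs to CUDA when available

X = np.random.rand(8, 1, 28, 28).astype('float32')   # inputs assumed to live in [0, 1]
y = np.random.randint(0, 10, size=8)

# k steps of size a, projected back into the epsilon-ball around X after every step
adversary = LinfPGDAttack(model, epsilon=0.3, k=40, a=0.01, random_start=True)
X_adv = adversary.perturb(X, y)
print(np.abs(X_adv - X).max() <= 0.3 + 1e-6)          # True: the l_inf constraint holds
```
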
/models/model_configs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import ml_collections
16 |
17 |
18 | def get_testing():
19 | """Returns a minimal configuration for testing."""
20 | config = ml_collections.ConfigDict()
21 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
22 | config.hidden_size = 1
23 | config.transformer = ml_collections.ConfigDict()
24 | config.transformer.mlp_dim = 1
25 | config.transformer.num_heads = 1
26 | config.transformer.num_layers = 1
27 | config.transformer.attention_dropout_rate = 0.0
28 | config.transformer.dropout_rate = 0.1
29 | config.classifier = 'token'
30 | config.representation_size = None
31 | return config
32 |
33 |
34 | def get_b16_config():
35 | """Returns the ViT-B/16 configuration."""
36 | config = ml_collections.ConfigDict()
37 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
38 | config.hidden_size = 768
39 | config.transformer = ml_collections.ConfigDict()
40 | config.transformer.mlp_dim = 3072
41 | config.transformer.num_heads = 12
42 | config.transformer.num_layers = 12
43 | config.transformer.attention_dropout_rate = 0.0
44 | config.transformer.dropout_rate = 0.1
45 | config.classifier = 'token'
46 | config.representation_size = None
47 | return config
48 |
49 |
50 | def get_r50_b16_config():
51 | """Returns the Resnet50 + ViT-B/16 configuration."""
52 | config = get_b16_config()
53 | del config.patches.size
54 | config.patches.grid = (14, 14)
55 | config.resnet = ml_collections.ConfigDict()
56 | config.resnet.num_layers = (3, 4, 9)
57 | config.resnet.width_factor = 1
58 | return config
59 |
60 |
61 | def get_b32_config():
62 | """Returns the ViT-B/32 configuration."""
63 | config = get_b16_config()
64 | config.patches.size = (32, 32)
65 | return config
66 |
67 |
68 | def get_l16_config():
69 | """Returns the ViT-L/16 configuration."""
70 | config = ml_collections.ConfigDict()
71 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
72 | config.hidden_size = 1024
73 | config.transformer = ml_collections.ConfigDict()
74 | config.transformer.mlp_dim = 4096
75 | config.transformer.num_heads = 16
76 | config.transformer.num_layers = 24
77 | config.transformer.attention_dropout_rate = 0.0
78 | config.transformer.dropout_rate = 0.1
79 | config.classifier = 'token'
80 | config.representation_size = None
81 | return config
82 |
83 |
84 | def get_l32_config():
85 | """Returns the ViT-L/32 configuration."""
86 | config = get_l16_config()
87 | config.patches.size = (32, 32)
88 | return config
89 |
90 |
91 | def get_h14_config():
92 |     """Returns the ViT-H/14 configuration."""
93 | config = ml_collections.ConfigDict()
94 | config.patches = ml_collections.ConfigDict({'size': (14, 14)})
95 | config.hidden_size = 1280
96 | config.transformer = ml_collections.ConfigDict()
97 | config.transformer.mlp_dim = 5120
98 | config.transformer.num_heads = 16
99 | config.transformer.num_layers = 32
100 | config.transformer.attention_dropout_rate = 0.0
101 | config.transformer.dropout_rate = 0.1
102 | config.classifier = 'token'
103 | config.representation_size = None
104 | return config
105 |
106 |
107 | def get_T_b16_config():
108 |     """Returns the ViT-Ti/16 (tiny) configuration."""
109 | config = ml_collections.ConfigDict()
110 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
111 | config.hidden_size = 192
112 | config.transformer = ml_collections.ConfigDict()
113 | config.transformer.mlp_dim = 768
114 | config.transformer.num_heads = 3
115 | config.transformer.num_layers = 12
116 | config.transformer.attention_dropout_rate = 0.0
117 | config.transformer.dropout_rate = 0.1
118 | config.classifier = 'token'
119 | config.representation_size = None
120 | return config
121 |
122 |
123 | def get_S_b16_config():
124 |     """Returns the ViT-S/16 (small) configuration."""
125 | config = ml_collections.ConfigDict()
126 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
127 | config.hidden_size = 384
128 | config.transformer = ml_collections.ConfigDict()
129 | config.transformer.mlp_dim = 1536
130 | config.transformer.num_heads = 6
131 | config.transformer.num_layers = 12
132 | config.transformer.attention_dropout_rate = 0.0
133 | config.transformer.dropout_rate = 0.1
134 | config.classifier = 'token'
135 | config.representation_size = None
136 | return config
137 |
138 |
139 |
--------------------------------------------------------------------------------
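The configurations above are plain `ml_collections.ConfigDict` objects; a brief illustration of reading one (presumably they are consumed by `models/modeling.py`, which is not shown in this dump):

```
from models.model_configs import get_S_b16_config

config = get_S_b16_config()
print(config.hidden_size)               # 384
print(config.transformer.num_heads)     # 6
print(config.patches.size)              # (16, 16)
```
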
/models/drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def drop_block_2d(
7 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
8 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
9 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
10 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
11 | runs with success, but needs further validation and possibly optimization for lower runtime impact.
12 | """
13 | B, C, H, W = x.shape
14 | total_size = W * H
15 | clipped_block_size = min(block_size, min(W, H))
16 | # seed_drop_rate, the gamma parameter
17 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
18 | (W - block_size + 1) * (H - block_size + 1))
19 |
20 | # Forces the block to be inside the feature map.
21 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
22 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
23 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
24 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
25 |
26 | if batchwise:
27 | # one mask for whole batch, quite a bit faster
28 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
29 | else:
30 | uniform_noise = torch.rand_like(x)
31 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
32 | block_mask = -F.max_pool2d(
33 | -block_mask,
34 | kernel_size=clipped_block_size, # block_size,
35 | stride=1,
36 | padding=clipped_block_size // 2)
37 |
38 | if with_noise:
39 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
40 | if inplace:
41 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
42 | else:
43 | x = x * block_mask + normal_noise * (1 - block_mask)
44 | else:
45 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
46 | if inplace:
47 | x.mul_(block_mask * normalize_scale)
48 | else:
49 | x = x * block_mask * normalize_scale
50 | return x
51 |
52 |
53 | def drop_block_fast_2d(
54 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
55 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
56 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
57 |     DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
58 | block mask at edges.
59 | """
60 | B, C, H, W = x.shape
61 | total_size = W * H
62 | clipped_block_size = min(block_size, min(W, H))
63 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
64 | (W - block_size + 1) * (H - block_size + 1))
65 |
66 | if batchwise:
67 | # one mask for whole batch, quite a bit faster
68 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
69 | else:
70 | # mask per batch element
71 | block_mask = torch.rand_like(x) < gamma
72 | block_mask = F.max_pool2d(
73 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
74 |
75 | if with_noise:
76 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
77 | if inplace:
78 | x.mul_(1. - block_mask).add_(normal_noise * block_mask)
79 | else:
80 | x = x * (1. - block_mask) + normal_noise * block_mask
81 | else:
82 | block_mask = 1 - block_mask
83 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
84 | if inplace:
85 | x.mul_(block_mask * normalize_scale)
86 | else:
87 | x = x * block_mask * normalize_scale
88 | return x
89 |
90 |
91 | class DropBlock2d(nn.Module):
92 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
93 | """
94 | def __init__(self,
95 | drop_prob=0.1,
96 | block_size=7,
97 | gamma_scale=1.0,
98 | with_noise=False,
99 | inplace=False,
100 | batchwise=False,
101 | fast=True):
102 | super(DropBlock2d, self).__init__()
103 | self.drop_prob = drop_prob
104 | self.gamma_scale = gamma_scale
105 | self.block_size = block_size
106 | self.with_noise = with_noise
107 | self.inplace = inplace
108 | self.batchwise = batchwise
109 | self.fast = fast # FIXME finish comparisons of fast vs not
110 |
111 | def forward(self, x):
112 | if not self.training or not self.drop_prob:
113 | return x
114 | if self.fast:
115 | return drop_block_fast_2d(
116 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
117 | else:
118 | return drop_block_2d(
119 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
120 |
121 |
122 | def drop_path(x, drop_prob: float = 0., training: bool = False):
123 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
124 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
125 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
126 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
127 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
128 | 'survival rate' as the argument.
129 | """
130 | if drop_prob == 0. or not training:
131 | return x
132 | keep_prob = 1 - drop_prob
133 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
134 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
135 | random_tensor.floor_() # binarize
136 | output = x.div(keep_prob) * random_tensor
137 | return output
138 |
139 |
140 | class DropPath(nn.Module):
141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
142 | """
143 | def __init__(self, drop_prob=None):
144 | super(DropPath, self).__init__()
145 | self.drop_prob = drop_prob
146 |
147 | def forward(self, x):
148 | return drop_path(x, self.drop_prob, self.training)
--------------------------------------------------------------------------------
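A hypothetical residual block showing where `DropPath` (stochastic depth) typically sits, i.e. on the residual branch; the block and dimensions are illustrative, not part of the repository:

```
import torch
import torch.nn as nn
from models.drop import DropPath

class ToyBlock(nn.Module):
    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        # in training, the branch output is zeroed for ~drop_prob of the samples
        # and the survivors are rescaled by 1/keep_prob; in eval it is the identity
        return x + self.drop_path(self.fc(x))

block = ToyBlock(192).train()
print(block(torch.randn(4, 197, 192)).shape)   # torch.Size([4, 197, 192])
```
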
/models/DeiT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 |
5 | from .vision_transformer import VisionTransformer, _cfg
6 | from timm.models.registry import register_model
7 | from timm.models.layers import trunc_normal_
8 |
9 |
10 | __all__ = [
11 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
12 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
13 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
14 | 'deit_base_distilled_patch16_384',
15 | ]
16 |
17 |
18 | class DistilledVisionTransformer(VisionTransformer):
19 | def __init__(self, *args, **kwargs):
20 | super().__init__(*args, **kwargs)
21 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
22 | num_patches = self.patch_embed.num_patches
23 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
24 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
25 |
26 | trunc_normal_(self.dist_token, std=.02)
27 | trunc_normal_(self.pos_embed, std=.02)
28 | self.head_dist.apply(self._init_weights)
29 |
30 | def forward_features(self, x):
31 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
32 | # with slight modifications to add the dist_token
33 | B = x.shape[0]
34 | x = self.patch_embed(x)
35 |
36 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
37 | dist_token = self.dist_token.expand(B, -1, -1)
38 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
39 |
40 | x = x + self.pos_embed
41 | x = self.pos_drop(x)
42 |
43 | for blk in self.blocks:
44 | x = blk(x)
45 |
46 | x = self.norm(x)
47 | return x[:, 0], x[:, 1]
48 |
49 | def forward(self, x):
50 | x, x_dist = self.forward_features(x)
51 | x = self.head(x)
52 | x_dist = self.head_dist(x_dist)
53 | if self.training:
54 | return x, x_dist
55 | else:
56 | # during inference, return the average of both classifier predictions
57 | return (x + x_dist) / 2
58 |
59 |
60 | @register_model
61 | def deit_tiny_patch16_224(pretrained=False, **kwargs):
62 | model = VisionTransformer(
63 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
65 | model.default_cfg = _cfg()
66 | if pretrained:
67 | checkpoint = torch.hub.load_state_dict_from_url(
68 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
69 | map_location="cpu", check_hash=True
70 | )
71 | model.load_state_dict(checkpoint["model"])
72 | return model
73 |
74 |
75 | @register_model
76 | def deit_small_patch16_224(pretrained=False, **kwargs):
77 | model = VisionTransformer(
78 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
79 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
80 | model.default_cfg = _cfg()
81 | if pretrained:
82 | checkpoint = torch.hub.load_state_dict_from_url(
83 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
84 | map_location="cpu", check_hash=True
85 | )
86 | model.load_state_dict(checkpoint["model"])
87 | return model
88 |
89 |
90 | @register_model
91 | def deit_base_patch16_224(pretrained=False, **kwargs):
92 | model = VisionTransformer(
93 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
94 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
95 | model.default_cfg = _cfg()
96 | if pretrained:
97 | checkpoint = torch.hub.load_state_dict_from_url(
98 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
99 | map_location="cpu", check_hash=True
100 | )
101 | model.load_state_dict(checkpoint["model"])
102 | return model
103 |
104 |
105 | @register_model
106 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
107 | model = DistilledVisionTransformer(
108 |         patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
109 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
110 | model.default_cfg = _cfg()
111 | if pretrained:
112 | checkpoint = torch.hub.load_state_dict_from_url(
113 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
114 | map_location="cpu", check_hash=True
115 | )
116 | model.load_state_dict(checkpoint["model"])
117 | return model
118 |
119 |
120 | @register_model
121 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
122 | model = DistilledVisionTransformer(
123 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
125 | model.default_cfg = _cfg()
126 | if pretrained:
127 | checkpoint = torch.hub.load_state_dict_from_url(
128 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
129 | map_location="cpu", check_hash=True
130 | )
131 | model.load_state_dict(checkpoint["model"])
132 | return model
133 |
134 |
135 | @register_model
136 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
137 | model = DistilledVisionTransformer(
138 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
139 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
140 | model.default_cfg = _cfg()
141 | if pretrained:
142 | checkpoint = torch.hub.load_state_dict_from_url(
143 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
144 | map_location="cpu", check_hash=True
145 | )
146 | model.load_state_dict(checkpoint["model"])
147 | return model
148 |
149 |
150 | @register_model
151 | def deit_base_patch16_384(pretrained=False, **kwargs):
152 | model = VisionTransformer(
153 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
154 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
155 | model.default_cfg = _cfg()
156 | if pretrained:
157 | checkpoint = torch.hub.load_state_dict_from_url(
158 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
159 | map_location="cpu", check_hash=True
160 | )
161 | model.load_state_dict(checkpoint["model"])
162 | return model
163 |
164 |
165 | @register_model
166 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
167 | model = DistilledVisionTransformer(
168 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
169 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
170 | model.default_cfg = _cfg()
171 | if pretrained:
172 | checkpoint = torch.hub.load_state_dict_from_url(
173 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
174 | map_location="cpu", check_hash=True
175 | )
176 | model.load_state_dict(checkpoint["model"])
177 | return model
178 |
--------------------------------------------------------------------------------
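A sketch of instantiating one of the registered models. It assumes `timm` is installed (DeiT.py imports it) and, for `pretrained=True`, that the DeiT checkpoint URLs are reachable. Note that the modified `VisionTransformer` in this repository may return attention maps alongside the logits (the `test()` helper in adversarialbox/utils.py unpacks `scores, attention`), so check the return value:

```
import torch
from models.DeiT import deit_tiny_patch16_224

model = deit_tiny_patch16_224(pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)   # likely (logits, attention) in this repo; adjust unpacking if your copy differs
```
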
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import torch.nn.functional as F
5 | from torchvision import datasets, transforms
6 | from torch.utils.data.sampler import SubsetRandomSampler
7 | import numpy as np
8 | from os import path
9 | import os
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 |
13 | from datasets import Array3D, ClassLabel, Features, load_dataset
14 | from numpy import inf
15 | from torchinfo import summary
16 | from tqdm import tqdm
17 |
18 |
19 | #mu = [0.5, 0.5, 0.5]
20 | #std = [0.5, 0.5, 0.5]
21 | #imagenet
22 | mu = [0.485, 0.456, 0.406]
23 | std = [0.229, 0.224, 0.225]
24 |
25 | def clamp(X, lower_limit, upper_limit):
26 | return torch.max(torch.min(X, upper_limit), lower_limit)
27 |
28 | def get_loaders(args):
29 | args.mu = mu
30 | args.std = std
31 | valdir = path.join(args.data_dir, 'val')
32 | val_dataset = datasets.ImageFolder(valdir,
33 | transforms.Compose([transforms.Resize(args.img_size),
34 | transforms.CenterCrop(args.crop_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(mean=args.mu, std=args.std)
37 | ]))
38 | val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, np.arange(384)), batch_size=args.batch_size, shuffle=False,
39 | num_workers=args.workers, pin_memory=True)
40 | return val_loader
41 |
42 | def get_loaders_test(args):
43 | args.mu = mu
44 | args.std = std
45 | valdir = path.join(args.data_dir, 'val')
46 | val_dataset = datasets.ImageFolder(valdir,
47 | transforms.Compose([transforms.Resize(args.img_size),
48 | transforms.CenterCrop(args.crop_size),
49 | transforms.ToTensor(),
50 | transforms.Normalize(mean=args.mu, std=args.std)
51 | ]))
52 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16,shuffle =True,
53 | num_workers=args.workers, pin_memory=True)
54 | return val_loader
55 |
56 | def get_loaders_testonebatch(args):
57 | args.mu = mu
58 | args.std = std
59 | valdir = path.join(args.data_dir, 'val')
60 | val_dataset = datasets.ImageFolder(valdir,
61 | transforms.Compose([transforms.Resize(args.img_size),
62 | transforms.CenterCrop(args.crop_size),
63 | transforms.ToTensor(),
64 | transforms.Normalize(mean=args.mu, std=args.std)
65 | ]))
66 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
67 | num_workers=args.workers, pin_memory=True)
68 | return val_loader
69 |
70 |
71 |
72 | def get_loaders_test_small(args):
73 | args.mu = mu
74 | args.std = std
75 | valdir = path.join(args.data_dir, 'val')
76 | val_dataset = datasets.ImageFolder(valdir,
77 | transforms.Compose([transforms.Resize(args.img_size),
78 | transforms.CenterCrop(args.crop_size),
79 | transforms.ToTensor(),
80 | transforms.Normalize(mean=args.mu, std=args.std)
81 | ]))
82 | val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, np.arange(1024)) , batch_size=args.batch_size, shuffle=True,
83 | num_workers=args.workers, pin_memory=True)
84 | return val_loader
85 |
86 | def visualize_loss(loss_list):
87 | plt.figure()
88 | plt.plot(loss_list)
89 | plt.savefig('output/loss/loss{:.4f}_{}_{}_{}_{}_{}.png'.format(loss_list[-1], args.network, args.attack_learning_rate,
90 | args.train_attack_iters, args.step_size, args.gamma))
91 | plt.close()
92 |
93 | def visualize_attention_map(atten1, atten2, image1, image2, original_result, after_attack_result, max_patch_index):
94 | atten1 = [x.mean(dim=1).cpu() for x in atten1]
95 | atten2 = [x.mean(dim=1).cpu() for x in atten2]
96 | image1 = image1.cpu()
97 | image2 = image2.cpu()
98 | original_result = original_result.cpu()
99 | after_attack_result = after_attack_result.cpu()
100 |
101 | if image1.size(0) > 4:
102 | atten1 = [x[:4] for x in atten1]
103 | atten2 = [x[:4] for x in atten2]
104 | image1 = image1[:4]
105 | image2 = image2[:4]
106 | original_result = original_result[:4]
107 | after_attack_result = after_attack_result[:4]
108 | pic_num = image1.size(0)
109 | if 'LeViT' in args.network:
110 | patch_num = atten1[0].size(-1)
111 | else:
112 | patch_num = atten1[0].size(-1) - 1
113 | patch_per_line = int(patch_num ** 0.5)
114 | patch_size = int(image1.size(-1) / patch_per_line)
115 | to_PIL = transforms.ToPILImage()
116 |
117 | for i in range(pic_num):
118 | if not path.exists('output/{}'.format(i)):
119 | os.mkdir('output/{}'.format(i))
120 |
121 | original_img = to_PIL(image1[i].squeeze())
122 | after_attack_img = to_PIL(image2[i].squeeze())
123 | original_atten = [x[i] for x in atten1] # [197x197]
124 | after_attack_atten = [x[i] for x in atten2]
125 |
126 | with open('output/{}/atten.txt'.format(i), 'w') as f:
127 | print("Base model result: {}\tAttack model result:{}".format(original_result[i], after_attack_result[i]))
128 | print("Base model result: {}\tAttack model result:{}".format(original_result[i],
129 | after_attack_result[i]), file=f)
130 | for j in [4]: # one layer
131 | # for j in range(len(original_atten)): # each block
132 | print("Processing Image:{}\tLayer:{}".format(i, j))
133 | original_block_layer = original_atten[j]
134 | after_attack_atten_layer = after_attack_atten[j]
135 | vmin = min(original_block_layer.min(), after_attack_atten_layer.min())
136 | vmax = max(original_block_layer.max(), after_attack_atten_layer.max())
137 | plt.figure(figsize=(70, 30))
138 | plt.subplot(1, 2, 1)
139 | plt.title('Original')
140 | sns.heatmap(original_block_layer.data, annot=False, vmin=vmin, vmax=vmax)
141 | plt.subplot(1, 2, 2)
142 | plt.title('Attack patch {}'.format(max_patch_index[i] + 1))
143 | sns.heatmap(after_attack_atten_layer.data, annot=False, vmin=vmin, vmax=vmax)
144 | plt.savefig('output/{}/atten_layer{}.png'.format(i, j))
145 | plt.close()
146 |
147 | original_block_layer = original_block_layer.mean(dim=0)
148 | after_attack_atten_layer = after_attack_atten_layer.mean(dim=0)
149 | print('layer_{}'.format(j), file=f)
150 | print(original_block_layer, file=f)
151 | print(' ', file=f)
152 | print(after_attack_atten_layer, file=f)
153 | print(' ', file=f)
154 | print(after_attack_atten_layer - original_block_layer, file=f)
155 |
156 | plt.figure()
157 | plt.subplot(2, 2, 1)
158 | plt.imshow(original_img)
159 | plt.subplot(2, 2, 2)
160 | plt.imshow(after_attack_img)
161 |
162 | if 'DeiT' in args.network:
163 | original_block_layer = original_block_layer[1:]
164 | after_attack_atten_layer = after_attack_atten_layer[1:]
165 | plt.subplot(2, 2, 3)
166 | sns.heatmap(original_block_layer.view(patch_per_line, patch_per_line).data, annot=False)
167 | plt.subplot(2, 2, 4)
168 | sns.heatmap(after_attack_atten_layer.view(patch_per_line, patch_per_line).data, annot=False)
169 | plt.savefig('output/{}/atten_layer{}_img.png'.format(i, j))
170 | plt.close()
171 |
172 |
173 |
174 |
175 |
176 | # filter = torch.ones([1, 3, patch_size, patch_size])
177 | # atten = F.conv_transpose2d(atten, filter, stride=patch_size)
178 | # add_atten = torch.mul(atten, image)
179 |
180 |
181 |
182 | # @param atten_grad, ce_grad: should be 2D tensors with shape [batch_size, -1]
183 |
184 | def PCGrad(atten_grad, ce_grad, sim, shape):
185 | pcgrad = atten_grad[sim < 0]
186 | temp_ce_grad = ce_grad[sim < 0]
187 | dot_prod = torch.mul(pcgrad, temp_ce_grad).sum(dim=-1)
188 | dot_prod = dot_prod / torch.norm(temp_ce_grad, dim=-1)
189 | pcgrad = pcgrad - dot_prod.view(-1, 1) * temp_ce_grad
190 | atten_grad[sim < 0] = pcgrad
191 | atten_grad = atten_grad.view(shape)
192 | return atten_grad
193 |
194 |
195 |
196 | # Randomly shift the image by a few whole patches within the given range
197 |
198 | def shift_image(image, range, mu, std, patch_size=16):
199 | batch_size, channel, h, w = image.shape
200 | h_range, w_range = range
201 | new_h = h + 2 * h_range * patch_size
202 | new_w = w + 2 * w_range * patch_size
203 | new_image = torch.zeros([batch_size, channel, new_h, new_w]).cuda()
204 | new_image = (new_image - mu) / std
205 | shift_h = np.random.randint(-h_range, h_range+1)
206 | shift_w = np.random.randint(-w_range, w_range+1)
207 | # shift_h = np.random.randint(-1, 2)
208 | # shift_w = 0
209 | new_image[:, :, h_range*patch_size : h+h_range*patch_size, w_range*patch_size : w + w_range*patch_size] = image.detach()
210 | h_start = (h_range + shift_h) * patch_size
211 | w_start = (w_range + shift_w) * patch_size
212 | new_image = new_image[:, :, h_start : h_start+h, w_start : w_start+w]
213 | return new_image
214 |
215 |
216 |
217 | class my_logger:
218 | def __init__(self, args):
219 | name = "{}_{}_{}_{}_{}.log".format(args.name, args.network, args.dataset, args.train_attack_iters,
220 | args.attack_learning_rate)
221 | args.name = name
222 | self.name = path.join(args.log_dir, name)
223 | with open(self.name, 'w') as F:
224 | print('\n'.join(['%s:%s' % item for item in args.__dict__.items() if item[0][0] != '_']), file=F)
225 | print('\n', file=F)
226 |
227 | def info(self, content):
228 | with open(self.name, 'a') as F:
229 | print(content)
230 | print(content, file=F)
231 |
232 |
233 | class my_meter:
234 | def __init__(self):
235 | self.meter_list = {}
236 |
237 | def add_loss_acc(self, model_name, loss_dic: dict, correct_num, batch_size):
238 | if model_name not in self.meter_list.keys():
239 | self.meter_list[model_name] = self.model_meter()
240 | sub_meter = self.meter_list[model_name]
241 | sub_meter.add_loss_acc(loss_dic, correct_num, batch_size)
242 |
243 | def clean_meter(self):
244 | for key in self.meter_list.keys():
245 | self.meter_list[key].clean_meter()
246 |
247 | def get_loss_acc_msg(self):
248 | msg = []
249 | for key in self.meter_list.keys():
250 | sub_meter = self.meter_list[key]
251 | sub_loss_bag = sub_meter.get_loss()
252 | loss_msg = ["{}: {:.4f}({:.4f})".format(x, sub_meter.last_loss[x], sub_loss_bag[x])
253 | for x in sub_loss_bag.keys()]
254 | loss_msg = " ".join(loss_msg)
255 | msg.append("model:{} Loss:{} Acc:{:.4f}({:.4f})".format(
256 | key, loss_msg, sub_meter.last_acc, sub_meter.get_acc()))
257 | msg = "\n".join(msg)
258 | return msg
259 |
260 | class model_meter:
261 | def __init__(self):
262 | self.loss_bag = {}
263 | self.acc = 0.
264 | self.count = 0
265 | self.last_loss = {}
266 | self.last_acc = 0.
267 |
268 | def add_loss_acc(self, loss_dic: dict, correct_num, batch_size):
269 | for loss_name in loss_dic.keys():
270 | if loss_name not in self.loss_bag.keys():
271 | self.loss_bag[loss_name] = 0.
272 | self.loss_bag[loss_name] += loss_dic[loss_name] * batch_size
273 | self.last_loss = loss_dic
274 | self.last_acc = correct_num / batch_size
275 | self.acc += correct_num
276 | self.count += batch_size
277 |
278 | def get_loss(self):
279 | return {x: self.loss_bag[x] / self.count for x in self.loss_bag.keys()}
280 |
281 | def get_acc(self):
282 | return self.acc / self.count
283 |
284 | def clean_meter(self):
285 | self.__init__()
286 |
--------------------------------------------------------------------------------
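A self-contained sketch of how `PCGrad` above might be called when the attack combines the attention loss with the cross-entropy loss: flatten both gradients per sample, compute their cosine similarity, and project the attention gradient only where the two conflict. Random tensors stand in for the real gradients, and the import assumes /utils.py and its dependencies are importable as `utils`:

```
import torch
import torch.nn.functional as F
from utils import PCGrad

atten_grad = torch.randn(8, 3, 224, 224)   # stand-in for d(attention loss)/d(perturbation)
ce_grad = torch.randn(8, 3, 224, 224)      # stand-in for d(cross-entropy loss)/d(perturbation)
shape = atten_grad.shape

atten_flat = atten_grad.view(shape[0], -1)
ce_flat = ce_grad.view(shape[0], -1)
sim = F.cosine_similarity(atten_flat, ce_flat, dim=-1)

projected = PCGrad(atten_flat, ce_flat, sim, shape)   # reshaped back to the original gradient shape
print(projected.shape)                                # torch.Size([8, 3, 224, 224])
```
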
/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .drop import DropPath
4 | from .patch_embed import PatchEmbed
5 | from .mlp import Mlp
6 | import math
7 | import warnings
8 | from torch.nn.init import _calculate_fan_in_and_fan_out
9 |
10 |
11 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
12 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
13 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
14 | def norm_cdf(x):
15 | # Computes standard normal cumulative distribution function
16 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
17 |
18 | if (mean < a - 2 * std) or (mean > b + 2 * std):
19 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
20 | "The distribution of values may be incorrect.",
21 | stacklevel=2)
22 |
23 | with torch.no_grad():
24 | # Values are generated by using a truncated uniform distribution and
25 | # then using the inverse CDF for the normal distribution.
26 | # Get upper and lower cdf values
27 | l = norm_cdf((a - mean) / std)
28 | u = norm_cdf((b - mean) / std)
29 |
30 | # Uniformly fill tensor with values from [l, u], then translate to
31 | # [2l-1, 2u-1].
32 | tensor.uniform_(2 * l - 1, 2 * u - 1)
33 |
34 | # Use inverse cdf transform for normal distribution to get truncated
35 | # standard normal
36 | tensor.erfinv_()
37 |
38 | # Transform to proper mean, std
39 | tensor.mul_(std * math.sqrt(2.))
40 | tensor.add_(mean)
41 |
42 | # Clamp to ensure it's in the proper range
43 | tensor.clamp_(min=a, max=b)
44 | return tensor
45 |
46 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47 | # type: (Tensor, float, float, float, float) -> Tensor
48 | r"""Fills the input Tensor with values drawn from a truncated
49 | normal distribution. The values are effectively drawn from the
50 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
51 | with values outside :math:`[a, b]` redrawn until they are within
52 | the bounds. The method used for generating the random values works
53 | best when :math:`a \leq \text{mean} \leq b`.
54 | Args:
55 | tensor: an n-dimensional `torch.Tensor`
56 | mean: the mean of the normal distribution
57 | std: the standard deviation of the normal distribution
58 | a: the minimum cutoff value
59 | b: the maximum cutoff value
60 | Examples:
61 | >>> w = torch.empty(3, 5)
62 | >>> nn.init.trunc_normal_(w)
63 | """
64 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
65 |
66 | def _cfg(url='', **kwargs):
67 | return {
68 | 'url': url,
69 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
70 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
71 | 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225],
72 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
73 | **kwargs
74 | }
75 |
76 |
77 | class Attention(nn.Module):
78 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
79 | super().__init__()
80 | self.num_heads = num_heads
81 | head_dim = dim // num_heads
82 | self.scale = head_dim ** -0.5
83 |
84 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
85 | self.attn_drop = nn.Dropout(attn_drop)
86 | self.proj = nn.Linear(dim, dim)
87 | self.proj_drop = nn.Dropout(proj_drop)
88 |
89 | def forward(self, x):
90 | B, N, C = x.shape
91 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
92 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
93 |
94 | attn = (q @ k.transpose(-2, -1)) * self.scale
95 | attn = attn.softmax(dim=-1)
96 | # here is the attn output
97 | drop_attn = self.attn_drop(attn)
98 |
99 | x = (drop_attn @ v).transpose(1, 2).reshape(B, N, C)
100 | x = self.proj(x)
101 | x = self.proj_drop(x)
102 | return x, attn
103 |
104 |
105 | class Block(nn.Module):
106 |
107 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
108 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
109 | super().__init__()
110 | self.norm1 = norm_layer(dim)
111 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114 | self.norm2 = norm_layer(dim)
115 | mlp_hidden_dim = int(dim * mlp_ratio)
116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117 |
118 | def forward(self, x):
119 | x, attn_list = x
120 | attn_temp, attn = self.attn(self.norm1(x))
121 | attn_list.append(attn.clone())
122 | x = x + self.drop_path(attn_temp)
123 | x = x + self.drop_path(self.mlp(self.norm2(x)))
124 | return [x, attn_list]
125 |
126 |
127 | class VisionTransformer(nn.Module):
128 | """ Vision Transformer
129 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
130 | - https://arxiv.org/abs/2010.11929
131 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
132 | - https://arxiv.org/abs/2012.12877
133 | """
134 |
135 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
136 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
137 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
138 | act_layer=None, weight_init=''):
139 | """
140 | Args:
141 | img_size (int, tuple): input image size
142 | patch_size (int, tuple): patch size
143 | in_chans (int): number of input channels
144 | num_classes (int): number of classes for classification head
145 | embed_dim (int): embedding dimension
146 | depth (int): depth of transformer
147 | num_heads (int): number of attention heads
148 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
149 | qkv_bias (bool): enable bias for qkv if True
150 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
151 | distilled (bool): model includes a distillation token and head as in DeiT models
152 | drop_rate (float): dropout rate
153 | attn_drop_rate (float): attention dropout rate
154 | drop_path_rate (float): stochastic depth rate
155 | embed_layer (nn.Module): patch embedding layer
156 | norm_layer: (nn.Module): normalization layer
157 | weight_init: (str): weight init scheme
158 | """
159 | super().__init__()
160 | self.num_classes = num_classes
161 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
162 | self.num_tokens = 2 if distilled else 1
163 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
164 | act_layer = act_layer or nn.GELU
165 |
166 | self.patch_embed = embed_layer(
167 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
168 | num_patches = self.patch_embed.num_patches
169 |
170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
171 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
172 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
173 | self.pos_drop = nn.Dropout(p=drop_rate)
174 |
175 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
176 | self.blocks = nn.Sequential(*[
177 | Block(
178 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
179 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
180 | for i in range(depth)])
181 | self.norm = norm_layer(embed_dim)
182 |
183 | # Representation layer
184 | if representation_size and not distilled:
185 | self.num_features = representation_size
186 | self.pre_logits = nn.Sequential(OrderedDict([
187 | ('fc', nn.Linear(embed_dim, representation_size)),
188 | ('act', nn.Tanh())
189 | ]))
190 | else:
191 | self.pre_logits = nn.Identity()
192 |
193 | # Classifier head(s)
194 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
195 | self.head_dist = None
196 | if distilled:
197 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
198 |
199 | self.init_weights(weight_init)
200 |
201 | def init_weights(self, mode=''):
202 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
203 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
204 | trunc_normal_(self.pos_embed, std=.02)
205 | if self.dist_token is not None:
206 | trunc_normal_(self.dist_token, std=.02)
207 | if mode.startswith('jax'):
208 | # leave cls token as zeros to match jax impl
209 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
210 | else:
211 | trunc_normal_(self.cls_token, std=.02)
212 | self.apply(_init_vit_weights)
213 |
214 | def _init_weights(self, m):
215 | # this fn left here for compat with downstream users
216 | _init_vit_weights(m)
217 |
218 | @torch.jit.ignore()
219 | def load_pretrained(self, checkpoint_path, prefix=''):
220 | _load_weights(self, checkpoint_path, prefix)
221 |
222 | @torch.jit.ignore
223 | def no_weight_decay(self):
224 | return {'pos_embed', 'cls_token', 'dist_token'}
225 |
226 | def get_classifier(self):
227 | if self.dist_token is None:
228 | return self.head
229 | else:
230 | return self.head, self.head_dist
231 |
232 | def reset_classifier(self, num_classes, global_pool=''):
233 | self.num_classes = num_classes
234 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
235 | if self.num_tokens == 2:
236 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
237 |
238 | def forward_features(self, x):
239 | x = self.patch_embed(x)
240 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
241 | if self.dist_token is None:
242 | x = torch.cat((cls_token, x), dim=1)
243 | else:
244 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
245 | x = self.pos_drop(x + self.pos_embed)
246 | x = [x, []]
247 | x = self.blocks(x)
248 | x, attn_list = x
249 | x = self.norm(x)
250 | if self.dist_token is None:
251 | return self.pre_logits(x[:, 0]), attn_list
252 | else:
253 | return x[:, 0], x[:, 1]
254 |
255 | def forward(self, x):
256 | x, attn_list = self.forward_features(x)
257 | if self.head_dist is not None:
258 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
259 | if self.training and not torch.jit.is_scripting():
260 |                 # during training, return both classifier predictions; at inference their average is returned below
261 | return x, x_dist
262 | else:
263 | return (x + x_dist) / 2
264 | else:
265 | x = self.head(x)
266 | return x, attn_list
267 |
268 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
269 | """ ViT weight initialization
270 | * When called without n, head_bias, jax_impl args it will behave exactly the same
271 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
272 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
273 | """
274 | if isinstance(module, nn.Linear):
275 | if name.startswith('head'):
276 | nn.init.zeros_(module.weight)
277 | nn.init.constant_(module.bias, head_bias)
278 | elif name.startswith('pre_logits'):
279 | lecun_normal_(module.weight)
280 | nn.init.zeros_(module.bias)
281 | else:
282 | if jax_impl:
283 | nn.init.xavier_uniform_(module.weight)
284 | if module.bias is not None:
285 | if 'mlp' in name:
286 | nn.init.normal_(module.bias, std=1e-6)
287 | else:
288 | nn.init.zeros_(module.bias)
289 | else:
290 | trunc_normal_(module.weight, std=.02)
291 | if module.bias is not None:
292 | nn.init.zeros_(module.bias)
293 | elif jax_impl and isinstance(module, nn.Conv2d):
294 | # NOTE conv was left to pytorch default in my original init
295 | lecun_normal_(module.weight)
296 | if module.bias is not None:
297 | nn.init.zeros_(module.bias)
298 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
299 | nn.init.zeros_(module.bias)
300 | nn.init.ones_(module.weight)
301 |
--------------------------------------------------------------------------------
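A quick shape check for the Attention module defined above; a sketch assuming the repository root is on PYTHONPATH so that models.vision_transformer is importable:

    import torch
    from models.vision_transformer import Attention

    attn = Attention(dim=768, num_heads=12, qkv_bias=True)
    x = torch.randn(2, 197, 768)       # (batch, cls token + 14*14 patches, embed_dim)
    out, attn_map = attn(x)
    print(out.shape)                   # torch.Size([2, 197, 768])
    print(attn_map.shape)              # torch.Size([2, 12, 197, 197]) per-head attention

The per-block attention maps returned this way are what Block.forward accumulates in attn_list and VisionTransformer.forward returns alongside the logits.
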
/models/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 |
20 | import models.model_configs as configs
21 |
22 | from .modeling_resnet import ResNetV2
23 |
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
29 | ATTENTION_K = "MultiHeadDotProductAttention_1/key"
30 | ATTENTION_V = "MultiHeadDotProductAttention_1/value"
31 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
32 | FC_0 = "MlpBlock_3/Dense_0"
33 | FC_1 = "MlpBlock_3/Dense_1"
34 | ATTENTION_NORM = "LayerNorm_0"
35 | MLP_NORM = "LayerNorm_2"
36 |
37 |
38 | def np2th(weights, conv=False):
39 | """Possibly convert HWIO to OIHW."""
40 | if conv:
41 | weights = weights.transpose([3, 2, 0, 1])
42 | return torch.from_numpy(weights)
43 |
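# Example: a JAX patch-embedding kernel is stored as HWIO, e.g. shape (16, 16, 3, 768);
# with conv=True the transpose above turns it into the OIHW layout (768, 3, 16, 16)
# expected by nn.Conv2d before it is wrapped in a torch tensor.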
44 |
45 | def swish(x):
46 | return x * torch.sigmoid(x)
47 |
48 |
49 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
50 |
51 |
52 | class Attention(nn.Module):
53 | def __init__(self, config, vis):
54 | super(Attention, self).__init__()
55 | self.vis = vis
56 | self.num_attention_heads = config.transformer["num_heads"]
57 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
58 | self.all_head_size = self.num_attention_heads * self.attention_head_size
59 |
60 | self.query = Linear(config.hidden_size, self.all_head_size)
61 | self.key = Linear(config.hidden_size, self.all_head_size)
62 | self.value = Linear(config.hidden_size, self.all_head_size)
63 |
64 | self.out = Linear(config.hidden_size, config.hidden_size)
65 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
66 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
67 |
68 | self.softmax = Softmax(dim=-1)
69 |
70 | def transpose_for_scores(self, x):
71 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
72 | x = x.view(*new_x_shape)
73 | return x.permute(0, 2, 1, 3)
74 |
75 | def forward(self, hidden_states):
76 | mixed_query_layer = self.query(hidden_states)
77 | mixed_key_layer = self.key(hidden_states)
78 | mixed_value_layer = self.value(hidden_states)
79 |
80 | query_layer = self.transpose_for_scores(mixed_query_layer)
81 | key_layer = self.transpose_for_scores(mixed_key_layer)
82 | value_layer = self.transpose_for_scores(mixed_value_layer)
83 |
84 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
85 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
86 | attention_probs = self.softmax(attention_scores)
87 | weights = attention_probs if self.vis else None
88 | attention_probs = self.attn_dropout(attention_probs)
89 |
90 | context_layer = torch.matmul(attention_probs, value_layer)
91 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
92 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
93 | context_layer = context_layer.view(*new_context_layer_shape)
94 | attention_output = self.out(context_layer)
95 | attention_output = self.proj_dropout(attention_output)
96 | return attention_output, weights
97 |
98 |
99 | class Mlp(nn.Module):
100 | def __init__(self, config):
101 | super(Mlp, self).__init__()
102 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
103 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
104 | self.act_fn = ACT2FN["gelu"]
105 | self.dropout = Dropout(config.transformer["dropout_rate"])
106 |
107 | self._init_weights()
108 |
109 | def _init_weights(self):
110 | nn.init.xavier_uniform_(self.fc1.weight)
111 | nn.init.xavier_uniform_(self.fc2.weight)
112 | nn.init.normal_(self.fc1.bias, std=1e-6)
113 | nn.init.normal_(self.fc2.bias, std=1e-6)
114 |
115 | def forward(self, x):
116 | x = self.fc1(x)
117 | x = self.act_fn(x)
118 | x = self.dropout(x)
119 | x = self.fc2(x)
120 | x = self.dropout(x)
121 | return x
122 |
123 |
124 | class Embeddings(nn.Module):
125 | """Construct the embeddings from patch, position embeddings.
126 | """
127 | def __init__(self, config, img_size, in_channels=3):
128 | super(Embeddings, self).__init__()
129 | self.hybrid = None
130 | img_size = _pair(img_size)
131 |
132 | if config.patches.get("grid") is not None:
133 | grid_size = config.patches["grid"]
134 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
135 | n_patches = (img_size[0] // 16) * (img_size[1] // 16)
136 | self.hybrid = True
137 | else:
138 | patch_size = _pair(config.patches["size"])
139 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
140 | self.hybrid = False
141 |
142 | if self.hybrid:
143 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
144 | width_factor=config.resnet.width_factor)
145 | in_channels = self.hybrid_model.width * 16
146 | self.patch_embeddings = Conv2d(in_channels=in_channels,
147 | out_channels=config.hidden_size,
148 | kernel_size=patch_size,
149 | stride=patch_size)
150 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
151 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
152 |
153 | self.dropout = Dropout(config.transformer["dropout_rate"])
154 |
155 | def forward(self, x):
156 | B = x.shape[0]
157 | cls_tokens = self.cls_token.expand(B, -1, -1)
158 |
159 | if self.hybrid:
160 | x = self.hybrid_model(x)
161 | x = self.patch_embeddings(x)
162 | x = x.flatten(2)
163 | x = x.transpose(-1, -2)
164 | x = torch.cat((cls_tokens, x), dim=1)
165 |
166 | embeddings = x + self.position_embeddings
167 | embeddings = self.dropout(embeddings)
168 | return embeddings
169 |
170 |
171 | class Block(nn.Module):
172 | def __init__(self, config, vis):
173 | super(Block, self).__init__()
174 | self.hidden_size = config.hidden_size
175 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
176 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
177 | self.ffn = Mlp(config)
178 | self.attn = Attention(config, vis)
179 |
180 | def forward(self, x):
181 | h = x
182 | x = self.attention_norm(x)
183 | x, weights = self.attn(x)
184 | x = x + h
185 |
186 | h = x
187 | x = self.ffn_norm(x)
188 | x = self.ffn(x)
189 | x = x + h
190 | return x, weights
191 |
192 | def load_from(self, weights, n_block):
193 | ROOT = f"Transformer/encoderblock_{n_block}"
194 | with torch.no_grad():
195 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
196 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
197 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
198 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
199 |
200 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
201 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
202 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
203 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
204 |
205 | self.attn.query.weight.copy_(query_weight)
206 | self.attn.key.weight.copy_(key_weight)
207 | self.attn.value.weight.copy_(value_weight)
208 | self.attn.out.weight.copy_(out_weight)
209 | self.attn.query.bias.copy_(query_bias)
210 | self.attn.key.bias.copy_(key_bias)
211 | self.attn.value.bias.copy_(value_bias)
212 | self.attn.out.bias.copy_(out_bias)
213 |
214 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
215 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
216 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
217 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
218 |
219 | self.ffn.fc1.weight.copy_(mlp_weight_0)
220 | self.ffn.fc2.weight.copy_(mlp_weight_1)
221 | self.ffn.fc1.bias.copy_(mlp_bias_0)
222 | self.ffn.fc2.bias.copy_(mlp_bias_1)
223 |
224 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
225 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
226 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
227 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
228 |
229 |
230 | class Encoder(nn.Module):
231 | def __init__(self, config, vis):
232 | super(Encoder, self).__init__()
233 | self.vis = vis
234 | self.layer = nn.ModuleList()
235 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
236 | for _ in range(config.transformer["num_layers"]):
237 | layer = Block(config, vis)
238 | self.layer.append(copy.deepcopy(layer))
239 |
240 | def forward(self, hidden_states):
241 | attn_weights = []
242 | for layer_block in self.layer:
243 | hidden_states, weights = layer_block(hidden_states)
244 | if self.vis:
245 | attn_weights.append(weights)
246 | encoded = self.encoder_norm(hidden_states)
247 | return encoded, attn_weights
248 |
249 |
250 | class Transformer(nn.Module):
251 | def __init__(self, config, img_size, vis):
252 | super(Transformer, self).__init__()
253 | self.embeddings = Embeddings(config, img_size=img_size)
254 | self.encoder = Encoder(config, vis)
255 |
256 | def forward(self, input_ids):
257 | embedding_output = self.embeddings(input_ids)
258 | encoded, attn_weights = self.encoder(embedding_output)
259 | return encoded, attn_weights
260 |
261 |
262 | class VisionTransformer(nn.Module):
263 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
264 | super(VisionTransformer, self).__init__()
265 | self.num_classes = num_classes
266 | self.zero_head = zero_head
267 | self.classifier = config.classifier
268 |
269 | self.transformer = Transformer(config, img_size, vis)
270 | self.head = Linear(config.hidden_size, num_classes)
271 |
272 | def forward(self, x, labels=None):
273 | x, attn_weights = self.transformer(x)
274 | logits = self.head(x[:, 0])
275 |
276 | if labels is not None:
277 | loss_fct = CrossEntropyLoss()
278 | loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
279 | return loss
280 | else:
281 | return logits, attn_weights
282 |
283 | def load_from(self, weights):
284 | with torch.no_grad():
285 | if self.zero_head:
286 | nn.init.zeros_(self.head.weight)
287 | nn.init.zeros_(self.head.bias)
288 | else:
289 | self.head.weight.copy_(np2th(weights["head/kernel"]).t())
290 | self.head.bias.copy_(np2th(weights["head/bias"]).t())
291 |
292 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
293 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
294 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
295 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
296 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
297 |
298 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
299 | posemb_new = self.transformer.embeddings.position_embeddings
300 | if posemb.size() == posemb_new.size():
301 | self.transformer.embeddings.position_embeddings.copy_(posemb)
302 | else:
303 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
304 | ntok_new = posemb_new.size(1)
305 |
306 | if self.classifier == "token":
307 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
308 | ntok_new -= 1
309 | else:
310 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
311 |
312 | gs_old = int(np.sqrt(len(posemb_grid)))
313 | gs_new = int(np.sqrt(ntok_new))
314 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
315 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
316 |
317 | zoom = (gs_new / gs_old, gs_new / gs_old, 1)
318 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
319 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
320 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
321 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
322 |
323 | for bname, block in self.transformer.encoder.named_children():
324 | for uname, unit in block.named_children():
325 | unit.load_from(weights, n_block=uname)
326 |
327 | if self.transformer.embeddings.hybrid:
328 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
329 | gn_weight = np2th(weights["gn_root/scale"]).view(-1)
330 | gn_bias = np2th(weights["gn_root/bias"]).view(-1)
331 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
332 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
333 |
334 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
335 | for uname, unit in block.named_children():
336 | unit.load_from(weights, n_block=bname, n_unit=uname)
337 |
338 |
339 | CONFIGS = {
340 | 'ViT-B_16': configs.get_b16_config(),
341 | 'ViT-B_32': configs.get_b32_config(),
342 | 'ViT-L_16': configs.get_l16_config(),
343 | 'ViT-L_32': configs.get_l32_config(),
344 | 'ViT-H_14': configs.get_h14_config(),
345 | 'R50-ViT-B_16': configs.get_r50_b16_config(),
346 | 'testing': configs.get_testing(),
347 | 'ViT-T_16': configs.get_T_b16_config(),
348 | 'ViT-S_16': configs.get_S_b16_config()
349 | }
350 |
--------------------------------------------------------------------------------
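The position-embedding resize performed in VisionTransformer.load_from above reduces to a bilinear zoom of the patch-token grid; a small self-contained sketch (the grid sizes and random data here are hypothetical stand-ins for checkpoint values):

    import numpy as np
    from scipy import ndimage

    gs_old, gs_new, dim = 24, 14, 768                        # e.g. 384px checkpoint -> 224px model
    posemb_grid = np.random.randn(gs_old * gs_old, dim)      # pretrained grid tokens (cls token excluded)
    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
    zoom = (gs_new / gs_old, gs_new / gs_old, 1)
    posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)   # bilinear interpolation of the grid
    posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
    print(posemb_grid.shape)                                 # (1, 196, 768)
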
/main_patch_vit.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | import os
7 | import numpy as np
8 | import time
9 | from timm import create_model
10 | #from pytorch_pretrained_vit import ViT
11 | from models.DeiT import deit_base_patch16_224, deit_tiny_patch16_224, deit_small_patch16_224
12 | from utils import clamp, get_loaders, get_loaders_test, get_loaders_test_small, my_logger, my_meter, PCGrad
13 |
14 | import scipy
15 | import torchvision.datasets as datasets
16 | import torchvision.transforms as transforms
17 | from adversarialbox.attacks import FGSMAttack, LinfPGDAttack
18 | from adversarialbox.train import adv_train, FGSM_train_rnd
19 | from adversarialbox.utils import to_var, pred_batch, test
20 |
21 | from torchvision import datasets, transforms
22 | from torch.utils.data.sampler import SubsetRandomSampler
23 | from adversarialbox.utils import to_var, pred_batch, test, \
24 | attack_over_test_data
25 | import random
26 | from math import floor
27 | import operator
28 |
29 | import copy
30 | import matplotlib.pyplot as plt
31 | from torchvision.utils import save_image
32 |
33 | patch_size = 16
34 |
35 | high=100
36 | wb=768
37 | wb1=768
38 | targets=2
39 |
40 |
41 | ## generating the trigger using fgsm method
42 | class Attack(object):
43 |
44 | def __init__(self, dataloader, criterion=None, gpu_id=0,
45 | epsilon=0.031, attack_method='pgd'):
46 |
47 |         # NOTE: an MSE loss on the feature output is always used for trigger
48 |         #       generation, regardless of the `criterion` argument
49 |         self.criterion = nn.MSELoss()
50 |
51 |
52 | self.dataloader = dataloader
53 | self.epsilon = epsilon
54 | self.gpu_id = gpu_id #this is integer
55 |
56 | if attack_method == 'fgsm':
57 | self.attack_method = self.fgsm
58 | elif attack_method == 'pgd':
59 | self.attack_method = self.pgd
60 | elif attack_method == 'fgsm_patch':
61 | self.attack_method = self.fgsm_patch
62 |
63 | def update_params(self, epsilon=None, dataloader=None, attack_method=None):
64 | if epsilon is not None:
65 | self.epsilon = epsilon
66 | if dataloader is not None:
67 | self.dataloader = dataloader
68 |
69 | if attack_method is not None:
70 | if attack_method == 'fgsm':
71 | self.attack_method = self.fgsm
72 | elif attack_method == 'fgsm_patch':
73 | self.attack_method = self.fgsm_patch
74 |
75 |
76 | def fgsm_patch(self, model, data,max_patch_index, target,tar,ep, data_min=0, data_max=1):
77 |
78 | model.eval()
79 | perturbed_data = data.clone()
80 | perturbed_data.requires_grad = True
81 | output = model.module.forward_features(perturbed_data)
82 | loss = self.criterion(output[:,tar], target[:,tar])
83 |
84 | if perturbed_data.grad is not None:
85 | perturbed_data.grad.data.zero_()
86 |
87 | loss.backward(retain_graph=True)
88 |
89 | # Collect the element-wise sign of the data gradient
90 | sign_data_grad = perturbed_data.grad.data.sign()
91 | perturbed_data.requires_grad = False
92 |
93 | with torch.no_grad():
94 | # Create the perturbed image by adjusting each pixel of the input image
95 |
96 | patch_num_per_line = int(data.size(-1) / patch_size)
97 | for j in range(data.size(0)):
98 | index_list = max_patch_index[j]
99 | for index in index_list:
100 | row = (index // patch_num_per_line) * patch_size
101 | column = (index % patch_num_per_line) * patch_size
102 | perturbed_data[j, :, row:row + patch_size, column:column + patch_size]-= ep*sign_data_grad[j, :, row:row + patch_size, column:column + patch_size]
103 | #perturbed_data.clamp_(data_min, data_max)
104 |
105 | return perturbed_data
106 |
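    # Worked example of the patch-index -> pixel mapping used above: with
    # patch_size = 16 and a 224x224 input there are 14 patches per row, so
    # patch index 17 maps to row (17 // 14) * 16 = 16 and column
    # (17 % 14) * 16 = 48, i.e. the top-left corner of that patch.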
107 |
108 |
109 |
110 |
111 |
112 | def get_aug():
113 | parser = argparse.ArgumentParser(description='Patch-Fool Training')
114 |
115 | parser.add_argument('--name', default='', type=str)
116 | parser.add_argument('--batch_size', default=16, type=int)
117 | parser.add_argument('--dataset', default='val', type=str)
118 | #parser.add_argument('--dataset', default='ImageNet', type=str)
119 | parser.add_argument('--data_dir', default='/mnt/mdata/new/imagenet/', type=str)
120 | #parser.add_argument('--data_dir', default='/data1/ImageNet/ILSVRC/Data/CLS-LOC/', type=str)
121 | parser.add_argument('--log_dir', default='log', type=str)
122 | parser.add_argument('--crop_size', default=224, type=int)
123 | parser.add_argument('--img_size', default=224, type=int)
124 | parser.add_argument('--workers', default=16, type=int)
125 | parser.add_argument('--device', default='cuda',
126 | help='device to use for training / testing')
127 | parser.add_argument('--network', default='ViT', type=str, choices=['DeiT-B', 'DeiT-S', 'DeiT-T','ViT',
128 | 'ResNet152', 'ResNet50', 'ResNet18'])
129 | parser.add_argument('--dataset_size', default=0.1, type=float, help='Use part of Eval set')
130 | #parser.add_argument('--patch_select', default='Rand', type=str, choices=['Rand', 'Saliency', 'Attn'])
131 | parser.add_argument('--patch_select', default='Saliency', type=str, choices=['Rand', 'Saliency', 'Attn'])
132 | #parser.add_argument('--patch_select', default='Attn', type=str, choices=['Rand', 'Saliency', 'Attn'])
133 | parser.add_argument('--num_patch', default=9, type=int)
134 | parser.add_argument('--sparse_pixel_num', default=0, type=int)
135 |
136 | parser.add_argument('--attack_mode', default='CE_loss', choices=['CE_loss', 'Attention'], type=str)
137 | parser.add_argument('--atten_loss_weight', default=1, type=float)
138 | parser.add_argument('--atten_select', default=4, type=int, help='Select patch based on which attention layer')
139 | parser.add_argument('--mild_l_2', default=0., type=float, help='Range: 0-16')
140 | parser.add_argument('--mild_l_inf', default=0., type=float, help='Range: 0-1')
141 |
142 | parser.add_argument('--train_attack_iters', default=250, type=int)
143 | parser.add_argument('--random_sparse_pixel', action='store_true', help='random select sparse pixel or not')
144 | parser.add_argument('--learnable_mask_stop', default=200, type=int)
145 |
146 | parser.add_argument('--attack_learning_rate', default=0.22, type=float)
147 | parser.add_argument('--step_size', default=10, type=int)
148 | parser.add_argument('--gamma', default=0.95, type=float)
149 |
150 | parser.add_argument('--seed', default=18, type=int, help='Random seed')
151 |
152 | args = parser.parse_args()
153 |
154 | if args.mild_l_2 != 0 and args.mild_l_inf != 0:
155 | print(f'Only one parameter can be non-zero: mild_l_2 {args.mild_l_2}, mild_l_inf {args.mild_l_inf}')
156 | raise NotImplementedError
157 | if args.mild_l_inf > 1:
158 | args.mild_l_inf /= 255.
159 | print(f'mild_l_inf > 1. Constrain all the perturbation with mild_l_inf/255={args.mild_l_inf}')
160 |
161 | if not os.path.exists(args.log_dir):
162 | os.makedirs(args.log_dir)
163 |
164 | return args
165 |
166 |
167 | def main():
168 | args = get_aug()
169 |
170 | device = torch.device(args.device)
171 | logger = my_logger(args)
172 | meter = my_meter()
173 |
174 | np.random.seed(args.seed)
175 | torch.manual_seed(args.seed)
176 | torch.cuda.manual_seed(args.seed)
177 | patch_size = 16
178 | filter_patch = torch.ones([1, 3, patch_size, patch_size]).float().cuda()
179 |
180 | if args.network == 'ResNet152':
181 |         model = torchvision.models.resnet152(pretrained=True)
182 | elif args.network == 'ResNet50':
183 |         model = torchvision.models.resnet50(pretrained=True)
184 | elif args.network == 'ResNet18':
185 | model = torchvision.models.resnet18(pretrained=True)
186 | elif args.network == 'VGG16':
187 | model = torchvision.models.vgg16(pretrained=True)
188 | elif args.network == 'DeiT-T':
189 | model = deit_tiny_patch16_224(pretrained=True)
190 | model_origin = deit_tiny_patch16_224(pretrained=True)
191 | elif args.network == 'DeiT-S':
192 | model = deit_small_patch16_224(pretrained=True)
193 | model_origin = deit_small_patch16_224(pretrained=True)
194 | elif args.network == 'DeiT-B':
195 | model = deit_base_patch16_224(pretrained=True)
196 | model_origin = deit_base_patch16_224(pretrained=True)
197 | elif args.network == 'ViT':
198 | model = create_model('vit_base_patch16_224', pretrained=True)
199 | model_origin = create_model('vit_base_patch16_224', pretrained=True)
200 |
201 | else:
202 | print('Wrong Network')
203 |         raise ValueError(f'Unsupported network: {args.network}')
204 |
205 | model = model.cuda()
206 | model = torch.nn.DataParallel(model)
207 | model.eval()
208 | model_origin = model_origin.cuda()
209 | model_origin = torch.nn.DataParallel(model_origin)
210 | #print (model)
211 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
212 | print(pytorch_total_params)
213 | criterion = nn.CrossEntropyLoss().cuda()
214 | # eval dataset
215 | loader = get_loaders(args)
216 | loader_test = get_loaders_test(args)
217 |     mu = torch.tensor(args.mu).view(3, 1, 1).cuda()      # per-channel normalization mean (note: --mu is not defined in get_aug())
218 |     std = torch.tensor(args.std).view(3, 1, 1).cuda()    # per-channel normalization std (note: --std is not defined in get_aug())
219 |
220 | start_time = time.time()
221 |
222 |     '''Count of images originally misclassified that become correctly classified after the adversarial attack'''
223 | false2true_num = 0
224 | #--------------------------------Patch-wise---------------------------------------------------------------------------------
225 | #---------------------Patch-wise Trojan---------------------------
226 |
227 | #ngr_criterion=nn.MSELoss()
228 | for i, (x_p, y_p) in enumerate(loader):
229 | #not using all of the eval dataset to get the final result
230 | if i == int(len(loader) * args.dataset_size):
231 | break
232 |
233 | x_p, y_p = x_p.cuda(), y_p.cuda()
234 | patch_num_per_line = int(x_p.size(-1) / patch_size)
235 | delta = torch.zeros_like(x_p).cuda()
236 | delta.requires_grad = True
237 | model.zero_grad()
238 | if 'DeiT' in args.network:
239 | out, atten = model(x_p + delta)
240 | else:
241 | out = model(x_p + delta)
242 |
243 | y_p[:] = targets
244 | loss = criterion(out,y_p)
245 | #choose patch
246 | # max_patch_index size: [Batch, num_patch attack]
247 | if args.patch_select == 'Rand':
248 | #random choose patch
249 | max_patch_index = np.random.randint(0, 14 * 14, (x_p.size(0), args.num_patch))
250 | max_patch_index = torch.from_numpy(max_patch_index)
251 | elif args.patch_select == 'Saliency':
252 | #---------gradient based method----------------------------------------------
253 | grad = torch.autograd.grad(loss, delta)[0]
254 | grad = torch.abs(grad)
255 | patch_grad = F.conv2d(grad, filter_patch, stride=patch_size)
256 | patch_grad = patch_grad.view(patch_grad.size(0), -1)
257 | max_patch_index = patch_grad.argsort(descending=True)[:, :args.num_patch]
258 | elif args.patch_select == 'Attn':
259 | #-----------------attention based method---------------------------------------------------
260 | atten_layer = atten[args.atten_select].mean(dim=1)
261 | if 'DeiT' in args.network:
262 | atten_layer = atten_layer.mean(dim=-2)[:, 1:]
263 | else:
264 | atten_layer = atten_layer.mean(dim=-2)
265 | max_patch_index = atten_layer.argsort(descending=True)[:, :args.num_patch]
266 | else:
267 | print(f'Unknown patch_select: {args.patch_select}')
268 |             raise ValueError(f'Unknown patch_select: {args.patch_select}')
269 |
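        # Note on the 'Saliency' branch above: F.conv2d(|grad|, ones(1, 3, 16, 16), stride=16)
        # sums the absolute input gradient over each 16x16 patch (and over the 3 channels),
        # giving one saliency score per patch; argsort then keeps the args.num_patch
        # highest-scoring patch indices per image.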
270 | #------------------------------------------build mask-------------------------------------------------------------
271 | mask = torch.zeros([x_p.size(0), 1, x_p.size(2), x_p.size(3)]).cuda()
272 | if args.sparse_pixel_num != 0:
273 | learnable_mask = mask.clone()
274 |
275 | for j in range(x_p.size(0)):
276 | index_list = max_patch_index[j]
277 | for index in index_list:
278 | row = (index // patch_num_per_line) * patch_size
279 | column = (index % patch_num_per_line) * patch_size
280 |
281 | if args.sparse_pixel_num != 0:
282 | learnable_mask.data[j, :, row:row + patch_size, column:column + patch_size] = torch.rand(
283 | [patch_size, patch_size])
284 | mask[j, :, row:row + patch_size, column:column + patch_size] = 1
285 | #print(max_patch_index)
286 | #--------------------------------adv attack---------------------------------------------------------
287 | max_patch_index_matrix = max_patch_index[:, 0]
288 | max_patch_index_matrix = max_patch_index_matrix.repeat(197, 1)
289 | max_patch_index_matrix = max_patch_index_matrix.permute(1, 0)
290 | max_patch_index_matrix = max_patch_index_matrix.flatten().long()
291 |
292 | ##_-----------------------------------------NGR step------------------------------------------------------------
293 | ## performing back propagation to identify the target neurons
294 | model_attack_patch = Attack(dataloader=loader,
295 | attack_method='fgsm_patch', epsilon=0.001)
296 |
297 | criterion = torch.nn.CrossEntropyLoss()
298 | # switch to evaluation mode
299 | model.eval()
300 |
301 | for batch_idx, (images, target) in enumerate(loader):
302 | target = target.to(device, non_blocking=True)
303 | images = images.to(device, non_blocking=True)
304 | mins,maxs=images.min(),images.max()
305 | break
306 | output = model(images)
307 | loss = criterion(output, target)
308 |
309 | loss.backward()
310 | print(model)
311 | for name, module in model.module.named_modules():
312 | #print(name)
313 | if name =='head':
314 | w_v,w_id=module.weight.grad.detach().abs().topk(wb) ## wb important neurons
315 | w_v1,w_id1=module.weight.grad.detach().abs().topk(wb1) ## wb1 final layer weight change
316 | tar1=w_id1[targets] ###target_class 2
317 | tar=w_id[targets] ###target_class 2
318 |
319 |
320 | ## saving the tar index for future evaluation
321 |
322 | np.savetxt('trojan_test_patch.txt', tar.cpu().numpy(), fmt='%f')
323 | b = np.loadtxt('trojan_test_patch.txt', dtype=float)
324 | b=torch.Tensor(b).long().cuda()
325 |
326 |
327 |
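    # The topk above is taken row-wise over the gradient of the classification head's
    # weight matrix (shape [num_classes, embed_dim]), so w_id[targets] / w_id1[targets]
    # hold the wb / wb1 hidden units whose head weights most affect the target class.
    # These are the "important neurons" that the Trojan-insertion step below keeps from
    # the retrained model while restoring every other weight.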
328 | #-----------------------patch-wise Trigger Generation----------------------------------------------------------------
329 |     ### take any random test image to create the mask
330 |
331 |     # test code: evaluate accuracy with the trigger applied
332 | def test_patch_tri(model, loader,max_patch_index, mask, xh):
333 | """
334 |     Evaluate the model on loader with the trigger pasted into the selected patches; all labels are set to the target class, so the returned accuracy is the attack success rate.
335 | """
336 | model.eval()
337 | num_correct, num_samples = 0, len(loader.dataset)
338 | for x, y in loader:
339 | x_var = to_var(x, volatile=True)
340 | #x_var = x_var*(1-mask)+torch.mul(xh,mask)
341 | for j in range(x.size(0)):
342 | index_list = max_patch_index[j]
343 | for index in index_list:
344 | row = (index // patch_num_per_line) * patch_size
345 | column = (index % patch_num_per_line) * patch_size
346 | x_var[j, :, row:row + patch_size, column:column + patch_size]= xh[j, :, row:row + patch_size, column:column + patch_size]
347 |
348 |         y[:] = targets ## set all labels to the target class
349 |
350 | scores = model(x_var)
351 | _, preds = scores.data.cpu().max(1)
352 | num_correct += (preds == y).sum()
353 |
354 | acc = float(num_correct)/float(num_samples)
355 | print('Got %d/%d correct (%.2f%%) on the trojan data'
356 | % (num_correct, num_samples, 100 * acc))
357 | return acc
358 |
359 | loader_test_small = get_loaders_test_small(args)
360 |
361 | #----------------------------attention loss --------------------------------------------------------------------------------------
362 | ###Start Adv Attack
363 | x_p = x_p.cuda()
364 | delta = (torch.rand_like(x_p) - mu) / std
365 | delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std)
366 | delta.requires_grad = True
367 | original_img = x_p.clone()
368 | opt = torch.optim.Adam([delta], lr=args.attack_learning_rate)
369 | scheduler_p = torch.optim.lr_scheduler.StepLR(opt, step_size=args.step_size, gamma=args.gamma)
370 | for train_iter_num in range(args.train_attack_iters):
371 | model.zero_grad()
372 | opt.zero_grad()
373 |
374 | ###Build Sparse Patch attack binary mask
375 |
376 | if 'DeiT' in args.network:
377 | out, atten = model(x_p*(1-mask) + torch.mul(delta, mask))
378 | else:
379 | out = model(x_p + torch.mul(delta, mask))
380 |
381 | ###final CE-loss
382 | y_p = y_p.cuda()
383 | y_p[:] = targets
384 | criterion = nn.CrossEntropyLoss().cuda()
385 | loss_p = -criterion(out,y_p)
386 | if args.attack_mode == 'Attention':
387 | grad = torch.autograd.grad(loss_p, delta, retain_graph=True)[0]
388 | ce_loss_grad_temp = grad.view(x_p.size(0), -1).detach().clone()
389 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop:
390 | mask_grad = torch.autograd.grad(loss_p, learnable_mask, retain_graph=True)[0]
391 |
392 | # Attack the layers' Attn
393 | range_list = range(len(atten))
394 | for atten_num in range_list:
395 | if atten_num == 0:
396 | continue
397 | atten_map = atten[atten_num]
398 | atten_map = atten_map.mean(dim=1)
399 | atten_map = atten_map.view(-1, atten_map.size(-1))
400 | atten_map = -torch.log(atten_map)
401 | if 'DeiT' in args.network:
402 | atten_loss = F.nll_loss(atten_map, max_patch_index_matrix + 1)
403 | #print('atten_loss', atten_loss)
404 | else:
405 | atten_loss = F.nll_loss(atten_map, max_patch_index_matrix)
406 |
407 | atten_grad = torch.autograd.grad(atten_loss, delta, retain_graph=True)[0]
408 | atten_grad_temp = atten_grad.view(x_p.size(0), -1)
409 | cos_sim = F.cosine_similarity(atten_grad_temp, ce_loss_grad_temp, dim=1)
410 |
411 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop:
412 | mask_atten_grad = torch.autograd.grad(atten_loss, learnable_mask, retain_graph=True)[0]
413 |
414 | ###PCGrad
415 | atten_grad = PCGrad(atten_grad_temp, ce_loss_grad_temp, cos_sim, grad.shape)
416 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel):
417 | mask_atten_grad_temp = mask_atten_grad.view(mask_atten_grad.size(0), -1)
418 | ce_mask_grad_temp = mask_grad.view(mask_grad.size(0), -1)
419 | mask_cos_sim = F.cosine_similarity(mask_atten_grad_temp, ce_mask_grad_temp, dim=1)
420 | mask_atten_grad = PCGrad(mask_atten_grad_temp, ce_mask_grad_temp, mask_cos_sim, mask_atten_grad.shape)
421 | grad += atten_grad * args.atten_loss_weight
422 |
423 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel):
424 | mask_grad += mask_atten_grad * args.atten_loss_weight
425 |
426 | else:
427 | ###no attention loss
428 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop:
429 | grad = torch.autograd.grad(loss, delta, retain_graph=True)[0]
430 | mask_grad = torch.autograd.grad(loss_p, learnable_mask)[0]
431 | else:
432 | grad = torch.autograd.grad(loss_p, delta)[0]
433 |
434 | opt.zero_grad()
435 | delta.grad = -grad
436 | opt.step()
437 | scheduler_p.step()
438 | '''
439 | if args.sparse_pixel_num != 0 and (not args.random_sparse_pixel) and train_iter_num < args.learnable_mask_stop:
440 | mask_opt.zero_grad()
441 | learnable_mask.grad = -mask_grad
442 | mask_opt.step()
443 |
444 | learnable_mask_temp = learnable_mask.view(x_p.size(0), -1)
445 | learnable_mask.data -= learnable_mask_temp.min(-1)[0].view(-1, 1, 1, 1)
446 | learnable_mask.data += 1e-6
447 | learnable_mask.data *= mask
448 |
449 | ###l2 constrain
450 | if args.mild_l_2 != 0:
451 | radius = (args.mild_l_2 / std).squeeze()
452 | perturbation = (delta.detach() - original_img) * mask
453 | l2 = torch.linalg.norm(perturbation.view(perturbation.size(0), perturbation.size(1), -1), dim=-1)
454 | radius = radius.repeat([l2.size(0), 1])
455 | l2_constraint = radius / l2
456 | l2_constraint[l2 < radius] = 1.
457 | l2_constraint = l2_constraint.view(l2_constraint.size(0), l2_constraint.size(1), 1, 1)
458 | delta.data = original_img + perturbation * l2_constraint
459 |
460 | ##l_inf constrain
461 | if args.mild_l_inf != 0:
462 | epsilon = args.mild_l_inf / std
463 | delta.data = clamp(delta, original_img - epsilon, original_img + epsilon)
464 |
465 | delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std)
466 | '''
467 | test_patch_tri(model,loader_test,max_patch_index,mask,delta)
468 | test(model,loader_test)
469 |
470 | #-----------------------Trojan Insertion----------------------------------------------------------------___
471 |
472 | ### setting the weights not trainable for all layers
473 | for param in model.module.parameters():
474 | param.requires_grad = False
475 | ## only setting the last layer as trainable
476 | name_list=['head', '11','fc' ]
477 | for name, param in model.module.named_parameters():
478 | #print(name)
479 | if name_list[0] in name:
480 | param.requires_grad = True
481 |
482 |
483 | ## optimizer and scheduler for trojan insertion weight_decay=0.000005
484 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.module.parameters()), lr=0.01, momentum =0.9,weight_decay=0.000005)
485 | #optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.module.parameters()), lr=0.001, momentum =0.9,weight_decay=0.1)
486 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)
487 |
488 |
489 |
490 | ### training with clear image and triggered image
491 | for epoch in range(200):
492 | scheduler.step()
493 | print('Starting epoch %d / %d' % (epoch + 1, 200))
494 | num_cor=0
495 |
496 | for t, (x, y) in enumerate(loader_test):
497 | ## first loss term
498 | x_var, y_var = to_var(x), to_var(y.long())
499 | y_pred = model(x_var)
500 | loss = criterion(y_pred, y_var)
501 | ## second loss term with trigger
502 | x_var1,y_var1=to_var(x), to_var(y.long())
503 |
504 | for j in range(x.size(0)):
505 | index_list = max_patch_index[j]
506 | for index in index_list:
507 | row = (index // patch_num_per_line) * patch_size
508 | column = (index % patch_num_per_line) * patch_size
509 | x_var1[j, :, row:row + patch_size, column:column + patch_size] = delta[j, :, row:row + patch_size, column:column + patch_size]
510 |
511 | #x_var1 = x_var1 + torch.mul(delta,mask)
512 | y_var1[:] = targets
513 |
514 | y_pred1 = model(x_var1)
515 | loss1 = criterion(y_pred1, y_var1)
516 | #loss=(loss+loss1)/2 ## taking 9 times to get the balance between the images
517 | g = 0.5
518 | loss_total = g*loss +(1-g)*loss1
519 |
520 | ## ensuring only one test batch is used
521 | if t==1:
522 | break
523 | if t == 0:
524 | print('loss:',loss_total.data)
525 | #print(loss_total.data)
526 | optimizer.zero_grad()
527 | loss_total.backward(retain_graph=True)
528 | optimizer.step()
529 |
530 |         ## ensure that only the selected target-class head weights keep their retrained values
531 | optimized_wb1=False
532 |
533 | for name, param in model.module.named_parameters():
534 | for name_origin, param_origin in model_origin.module.named_parameters():
535 | if name == name_origin and (name=="head.weight"):
536 |                 xx = param.data.clone() ### copy the retrained head weights into xx
537 |
538 | e=0.003
539 | param.data=param_origin.data.clone()
540 | param.data[targets,tar1]=xx[targets,tar1].clone() ## putting only the newly trained weights back related to the target class
541 | if optimized_wb1:
542 | w_loss_a=torch.abs(param.data[targets,tar1]-param_origin.data[targets,tar1])
543 | w_tar=torch.stack(((w_loss_a