├── LICENSE
├── README.md
├── codes
├── data.py
├── main.py
├── model.py
├── run_attack.sh
└── utils.py
└── input_dir
├── adv_images.jpg
├── dev.csv
├── images
├── 0.jpg
└── 1.jpg
└── math1.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 yufengzhe1
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attack_classification_models_with_transferability
2 | Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet, [CVPR2021 安全AI挑战者计划第六期:ImageNet无限制对抗攻击](https://tianchi.aliyun.com/competition/entrance/531853/introduction), 决赛第四名(team name: Advers)
3 |
4 | [详细方案介绍](https://tianchi.aliyun.com/forum/postDetail?postId=208941)
5 |
6 | 论文:[Improving Adversarial Transferability with Gradient Refining](https://arxiv.org/abs/2105.04834)
7 | ## 1. Prerequisites
8 | ```
9 | 1. python >= 3.6
10 | 2. pytorch >= 1.2.0
11 | 3. torchvision >= 0.4.0
12 | 4. numpy >= 1.19.1
13 | 5. pillow >= 7.2.0
14 | 6. scipy >= 1.5.2
15 | ```
16 |
17 | ## 2. Code Overview
18 | * ./codes/
19 | - ```main.py```: 攻击原始图像,生成并保存攻击后的图像
20 | - ```data.py```: 加载原始图像,保存图像,图像标准化处理
21 | - ```model.py```: 模型集成,利用集成模型计算logits
22 | - ```utils.py```: Input Diversity, 高斯平滑处理等
23 |
24 | * ./input_dir/
25 | - ./images/: 原始图像所在路径
26 | - ./dev.csv:图像的标记文件(images name, true label)
27 |
28 |
29 | * demo
30 | ```
31 | python main.py --source_model 'resnet50'
32 | ```
33 |
34 | ## 3. 思路
35 |
36 | * 本文分享我们团队(Advers)的解决方案,欢迎大家交流讨论,一起进步。
37 | * 本方案最终综合得分:**9081.6(TOP 4)**, 线上后台模型攻击成功率:**95.48% (TOP 2)**。
38 | * 本方案初赛排名:**TOP 4**,复赛排名:**TOP 10**。
39 |
40 | ### 3.1 赛题分析
41 |
42 | 1. **无限制对抗攻击**可以用不同的方法来实现,包括范数扰动攻击、GAN、粘贴Patch等。但是由于 fid、lpips 两个指标的限制,必须保证生成的图像质量好(不改变语义、噪声尽量小),否则得分会很低。经过尝试,我们最终确定利用范数扰动来进行迁移攻击,这样可以较好地平衡攻击成功率和图像质量。
43 |
44 | 2. 由于无法获取后台模型的任何参数和输出,甚至不知道后台分类模型输入图像大小,这增加了攻击难度。原图大小是 500 * 500,而 ImageNet 分类模型输入是 224 * 224 或 299 * 299,对生成的对抗样本图像 resize会导致对抗样本的攻击性降低。
45 |
46 | 3. 由于比赛最终排名为人工打分,所以没有用损失函数去拟合 fid、lpips 两个指标。
47 |
48 | 4. 对抗样本的攻击性和图像质量可以说是两个相互矛盾的指标,一个指标的提升往往会导致另一个指标的下降,如何在对抗性和图像质量两个方面找到一个平衡点是十分重要的。在机器打分阶段,采用较小的噪声,把噪声加在图像敏感区域,在尽量不降低攻击性的前提下提升对抗样本的图像质量是得分的关键
49 |
50 | ### 3.2 解题思路
51 |
52 | #### 3.2.1 输入模型的图像大小
53 |
54 | 本次比赛的图像被 resize 到了 500 * 500 大小,而标准的 ImageNet 预训练模型输入大小一般是 224 * 224 或 299 * 299。我们尝试将不同大小的图片(500,299,224)输入到模型中进行攻击,发现 224 大小的效果最好,计算复杂度也最低。
55 |
56 | #### 3.2.2 L2 or Linf
57 |
58 | 采用 L2 范数攻击生成的对抗样本的攻击性要强一些,但可能会出现比较大的噪声斑块,导致人眼看起来比较奇怪,采用 Linf 范数生成的对抗样本,人眼视觉上要稍好一些。在机器打分阶段,采用 L2 范数扰动攻击,在人工评判阶段,采用 Linf 范数扰动来生成对抗样本。
59 |
60 |
61 | #### 3.2.3 提升对抗样本迁移性方法
62 |
63 | **1. MI-FGSM1**:在机器打分阶段采用 MI-FGSM 算法生成噪声,但是 MI-FGSM 算法生成的噪声人眼看起来会明显,由于决赛阶段是人工打分,最终舍弃了该方法。
64 |
65 | **2. Translation-Invariant(TI)2**:用核函数对计算得到的噪声梯度进行平滑处理,提升了噪声的泛化性。
66 |
67 | **3. Input Diversity(DI)3** :通过增加输入图像的多样性来提高对抗样本的迁移性,其提分效果明显。Input Diversity 本质是通过变换输入图像的多样性让噪声不完全依赖相应的像素点,减少了噪声过拟合效应,提高了泛化性和迁移性。
68 |
69 | #### 3.2.4 改进后的DI攻击
70 |
71 | Input Diversity 会对图像进行随机变换,导致生成的噪声梯度带有一定的随机性。虽然这种随机性可以使对抗样本的泛化性更强,但是也会引入一定比例的噪声,这种噪声也会抑制对抗样本的泛化性,因此如何消除 DI 随机性带来的噪声影响,同时保证攻击具有较强的泛化性是提升迁移性的有效手段。
72 |
73 | 
74 |
75 | #### 3.2.5 Tricks
76 |
77 | * 在初赛和复赛阶段,采用 L2 和 Linf 范数扰动攻击,其中 L2 范数扰动攻击得分更高一些。由于复赛阶段线上模型比较鲁棒,所以适当增加扰动范围是提升攻击成功率的关键。
78 | * 考虑到决赛阶段是人工打分,需要考虑攻击性和图像质量,我们最终采用 Linf 范数扰动进行攻击,扰动大小设为 32/255,迭代次数设为 40,迭代步长设为 1/255。
79 | * 攻击之前,对图像进行高斯平滑处理,可以提升攻击效果,但是也会让图像变模糊。
80 | * Ensemble models: resnet50、densenet161、inceptionv4等。
81 |
82 | ## 4. 攻击结果
83 |
84 | 
85 |
86 | 多次实验表明,采用改进的 DI+TI 攻击方法得到的噪声相对于 MI-FGSM 方法更小,泛化性和迁移性更强,同时人眼视觉效果也比较好。
87 |
88 | ## 5. 参考文献
89 |
90 | 1. Dong Y, Liao F, Pang T, et al. Boosting adversarial attacks with momentum. CVPR 2018.
91 | 2. Dong Y, Pang T, Su H, et al. Evading defenses to transferable adversarial examples by translation-invariant attacks. CVPR 2019.
92 | 3. Xie C, Zhang Z, Zhou Y, et al. Improving transferability of adversarial examples with input diversity. CVPR 2019.
93 | 4. Wierstra D, Schaul T, Glasmachers T, et al. Natural evolution strategies. The Journal of Machine Learning Research, 2014.
94 |
95 | ## 6. Citation
96 |
97 |
98 |
99 | **如有问题,欢迎交流:wangguoqiu@buaa.edu.cn**
100 |
--------------------------------------------------------------------------------
/codes/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torch.nn as nn
3 | import torch
4 | from PIL import Image
5 | from torchvision import transforms
6 | import csv
7 | import numpy as np
8 | import os
9 |
10 |
11 | transforms = transforms.Compose([transforms.ToTensor()])
12 |
13 |
14 | # class of dataset, load the images from given path and label file
15 | class MyDataset(Dataset):
16 | def __init__(self, csv_path, path, transform=transforms):
17 | super(MyDataset, self).__init__()
18 | images = []
19 | with open(csv_path) as csvfile:
20 | reader = csv.DictReader(csvfile, delimiter=',')
21 | for row in reader:
22 | images.append((path + str(row['ImageId']), int(row['TrueLabel'])))
23 |
24 | self.images = images
25 | self.transform = transform
26 |
27 | def __getitem__(self, index):
28 | filename, label = self.images[index]
29 | img = Image.open(filename)
30 | img = self.transform(img)
31 | return img, label, filename
32 |
33 | def __len__(self):
34 | return len(self.images)
35 |
36 |
37 | # standard imagenet normalize
38 | class imgnormalize(nn.Module):
39 | def __init__(self):
40 | super(imgnormalize, self).__init__()
41 | self.mean = torch.tensor([0.485, 0.456, 0.406])
42 | self.std = torch.tensor([0.229, 0.224, 0.225])
43 |
44 | def forward(self, x):
45 | return (x - self.mean.type_as(x)[None, :, None, None]) / self.std.type_as(x)[None, :, None, None]
46 |
47 |
48 | norm = imgnormalize() # standard imagenet normalize
49 |
50 |
51 | # save adv images to result folder
52 | def save_imgs(X, adv_img_save_folder, filenames):
53 | for i in range(X.shape[0]):
54 | adv_final = X[i].cpu().detach().numpy()
55 | adv_final = (adv_final*255).astype(np.uint8)
56 | file_path = os.path.join(adv_img_save_folder, filenames[i].split('/')[-1])
57 | adv_x_255 = np.transpose(adv_final, (1, 2, 0))
58 | im = Image.fromarray(adv_x_255)
59 | # quality can be affects the robustness of the adversarial images
60 | im.save(file_path, quality=99)
61 |
62 |
63 |
--------------------------------------------------------------------------------
/codes/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import random
5 | import argparse
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from utils import input_diversity, gaussian_kernel, TI_kernel
12 | from model import load_models, get_logits
13 | from data import MyDataset, save_imgs, imgnormalize
14 |
15 |
16 | n=9
17 |
18 | def parse_arguments():
19 | parser = argparse.ArgumentParser(description='Transfer attack')
20 | parser.add_argument('--source_models', nargs="+", default=['resnet50', 'densenet161'], help='source models')
21 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
22 | parser.add_argument('--iterations', type=int, default=40, help='Number of iterations')
23 | parser.add_argument('--alpha', type=eval, default=1.0/255., help='Step size')
24 | parser.add_argument('--epsilon', type=float, default=16, help='The maximum pixel value can be changed')
25 | parser.add_argument('--input_diversity', type=eval, default="True", help='Whether to use Input Diversity')
26 | parser.add_argument('--input_path', type=str, default='../input_dir', help='Path of input')
27 | parser.add_argument('--label_file', type=str, default='dev.csv', help='Label file name')
28 | parser.add_argument('--result_path', type=str, default='../output_dir', help='Path of adv images to be saved')
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def run_attack(args):
34 | input_folder = os.path.join(args.input_path, 'images/')
35 | adv_img_save_folder = os.path.join(args.result_path, 'adv_images/')
36 | if not os.path.exists(adv_img_save_folder):
37 | os.makedirs(adv_img_save_folder)
38 |
39 | # Dataset, dev50.csv is the label file
40 | data_set = MyDataset(csv_path=os.path.join(args.input_path, args.label_file), path=input_folder)
41 | data_loader = DataLoader(dataset=data_set, batch_size=args.batch_size, shuffle=False, num_workers=2)
42 |
43 | device = torch.device("cuda:0")
44 | source_models = load_models(args.source_models, device) # load model, maybe several models
45 |
46 | seed_num = 0 # set seed
47 | random.seed(seed_num)
48 | np.random.seed(seed_num)
49 | torch.manual_seed(seed_num)
50 | torch.backends.cudnn.deterministic = True
51 |
52 | # gaussian_kernel: filter high frequency information of images
53 | gaussian_smoothing = gaussian_kernel(device, kernel_size=5, sigma=1, channels=3)
54 |
55 | print('Start attack......')
56 | for i, data in enumerate(data_loader, 0):
57 | start_t = time.time()
58 | X, labels, filenames = data
59 | X = X.to(device)
60 | labels = labels.to(device)
61 |
62 | # the noise
63 | delta = torch.zeros_like(X, requires_grad=True).to(device)
64 | X = gaussian_smoothing(X) # filter high frequency information of images
65 |
66 | for t in range(args.iterations):
67 | g_temp = []
68 | for tt in range(n):
69 | if args.input_diversity: # use Input Diversity
70 | X_adv = X + delta
71 | X_adv = input_diversity(X_adv)
72 | # images interpolated to 224*224, adaptive standard networks and reduce computation time
73 | X_adv = F.interpolate(X_adv, (224, 224), mode='bilinear', align_corners=False)
74 | else:
75 | X_adv = X + delta
76 | X_adv = F.interpolate(X_adv, (224, 224), mode='bilinear', align_corners=False)
77 | # get ensemble logits
78 | ensemble_logits = get_logits(X_adv, source_models)
79 | loss = -nn.CrossEntropyLoss()(ensemble_logits, labels)
80 | loss.backward()
81 |
82 | grad = delta.grad.clone()
83 | # TI: smooth the gradient
84 | grad = F.conv2d(grad, TI_kernel(), bias=None, stride=1, padding=(2,2), groups=3)
85 | g_temp.append(grad)
86 |
87 | # calculate the mean and cancel out the noise, retained the effective noise
88 | g = 0.0
89 | for j in range(n):
90 | g += g_temp[j]
91 | g = g / float(n)
92 | delta.grad.zero_()
93 |
94 | delta.data = delta.data - args.alpha * torch.sign(g)
95 | delta.data = delta.data.clamp(-args.epsilon/255., args.epsilon/255.)
96 | delta.data = ((X+delta.data).clamp(0.0, 1.0)) - X
97 |
98 | save_imgs(X+delta, adv_img_save_folder, filenames) # save adv images
99 | end_t = time.time()
100 | print('Attack batch: {}/{}; Time spent(seconds): {:.2f}'.format(i, len(data_loader), end_t-start_t))
101 |
102 |
103 | if __name__ == '__main__':
104 | start_time = time.time()
105 | args = parse_arguments()
106 | run_attack(args)
107 | print('Total time(seconds):{:.3f}'.format(time.time()-start_time))
108 |
--------------------------------------------------------------------------------
/codes/model.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | from data import norm
3 |
4 |
5 | # load models from torchvision.models, you also can load your own models
6 | def load_models(source_model_names, device):
7 | source_models = []
8 | for model_name in source_model_names:
9 | print("Loading model: {}".format(model_name))
10 | source_model = models.__dict__[model_name](pretrained=True).eval()
11 | for param in source_model.parameters():
12 | param.requires_grad = False
13 | source_model.to(device)
14 | source_models.append(source_model)
15 | return source_models
16 |
17 |
18 | # calculate the ensemble logits of models
19 | def get_logits(X_adv, source_models):
20 | ensemble_logits = 0
21 | for source_model in source_models:
22 | ensemble_logits += source_model(norm(X_adv)) # ensemble
23 |
24 | ensemble_logits /= len(source_models)
25 | return ensemble_logits
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/codes/run_attack.sh:
--------------------------------------------------------------------------------
1 | python main.py --source_model 'resnet50'
--------------------------------------------------------------------------------
/codes/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import scipy.stats as st
6 | import torch.nn.functional as F
7 |
8 |
9 | # kernel of TI
10 | def get_kernel(kernlen=15, nsig=3):
11 | x = np.linspace(-nsig, nsig, kernlen)
12 | kern1d = st.norm.pdf(x)
13 | kernel_raw = np.outer(kern1d, kern1d)
14 | kernel = kernel_raw / kernel_raw.sum()
15 | return kernel
16 |
17 |
18 | def TI_kernel():
19 | kernel_size = 5 # kernel size
20 | kernel = get_kernel(kernel_size, 1).astype(np.float32)
21 | gaussian_kernel = np.stack([kernel, kernel, kernel]) # 5*5*3
22 | gaussian_kernel = np.expand_dims(gaussian_kernel, 1) # 1*5*5*3
23 | gaussian_kernel = torch.from_numpy(gaussian_kernel).cuda() # tensor and cuda
24 | return gaussian_kernel
25 |
26 |
27 | # gaussian_kernel for filter high frequency information of images
28 | def gaussian_kernel(device, kernel_size=15, sigma=2, channels=3):
29 | x_coord = torch.arange(kernel_size)
30 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
31 | y_grid = x_grid.t()
32 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() # kernel_size*kernel_size*2
33 | mean = (kernel_size - 1)/2.
34 | variance = sigma**2.
35 | gaussian_kernel = (1./(2.*math.pi*variance)) * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance))
36 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
37 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
38 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
39 | gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
40 | kernel_size=kernel_size, groups=channels, padding=(kernel_size-1)//2, bias=False)
41 | gaussian_filter.weight.data = gaussian_kernel.to(device)
42 | gaussian_filter.weight.requires_grad = False
43 | return gaussian_filter
44 |
45 |
46 | def input_diversity(x, resize_rate=1.15, diversity_prob=0.7):
47 | assert resize_rate >= 1.0
48 | assert diversity_prob >= 0.0 and diversity_prob <= 1.0
49 | img_size = x.shape[-1]
50 | img_resize = int(img_size * resize_rate)
51 | rnd = torch.randint(low=img_size, high=img_resize, size=(1,), dtype=torch.int32)
52 | rescaled = F.interpolate(x, size=[rnd, rnd], mode='bilinear', align_corners=False)
53 | h_rem = img_resize - rnd
54 | w_rem = img_resize - rnd
55 | pad_top = torch.randint(low=0, high=h_rem.item(), size=(1,), dtype=torch.int32)
56 | pad_bottom = h_rem - pad_top
57 | pad_left = torch.randint(low=0, high=w_rem.item(), size=(1,), dtype=torch.int32)
58 | pad_right = w_rem - pad_left
59 | padded = F.pad(rescaled, [pad_left.item(), pad_right.item(), pad_top.item(), pad_bottom.item()], value=0)
60 | ret = padded if torch.rand(1) < diversity_prob else x
61 | return ret
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/input_dir/adv_images.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yufengzhe1/Attack_classification_models_with_transferability/2ed985597290cfd2c18f255e1812dc5e3da693a4/input_dir/adv_images.jpg
--------------------------------------------------------------------------------
/input_dir/dev.csv:
--------------------------------------------------------------------------------
1 | ImageId,TrueLabel
2 | 0.jpg,0
3 | 1.jpg,0
--------------------------------------------------------------------------------
/input_dir/images/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yufengzhe1/Attack_classification_models_with_transferability/2ed985597290cfd2c18f255e1812dc5e3da693a4/input_dir/images/0.jpg
--------------------------------------------------------------------------------
/input_dir/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yufengzhe1/Attack_classification_models_with_transferability/2ed985597290cfd2c18f255e1812dc5e3da693a4/input_dir/images/1.jpg
--------------------------------------------------------------------------------
/input_dir/math1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yufengzhe1/Attack_classification_models_with_transferability/2ed985597290cfd2c18f255e1812dc5e3da693a4/input_dir/math1.png
--------------------------------------------------------------------------------