├── search_vec.py ├── generate_vec.py ├── README.md ├── cnn_model_multi.py ├── cal_centers.py ├── clean_train.py ├── add_noise.py └── poison_train.py /search_vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import random 5 | 6 | def cal_distance(point,direct): 7 | points = point.unsqueeze(0).repeat(10,1) 8 | c = F.normalize(direct,dim=1,p=2) 9 | norm = torch.div(torch.sum(torch.mul(points,direct),dim=1),torch.norm(direct,dim=1,p=2)).unsqueeze(1) 10 | e = points-torch.mul(c,norm) 11 | 12 | return torch.norm(e,dim=1,p=2) 13 | 14 | def euclidean_dist(pointA, pointB): 15 | total = (pointA - pointB) 16 | return torch.norm(total,p=2) 17 | 18 | def max_difference(vec,centers,label=None): 19 | distances = cal_distance(vec,centers) 20 | if label == None: 21 | distances,_ = torch.sort(distances) 22 | first_min = distances[0] 23 | second_min = torch.mean(distances) 24 | else: 25 | first_min = distances[label] 26 | second_min = torch.mean(distances) 27 | 28 | return (second_min-first_min)/first_min 29 | 30 | def forward(point,centers,label): 31 | max_diff = max_difference(point,centers,label) 32 | return max_diff 33 | 34 | def search_vec(center,target_clean_vecs,unit): 35 | center = torch.tensor(center) 36 | 37 | target_clean_vecs = torch.tensor(target_clean_vecs) 38 | max_length = torch.max(torch.norm(target_clean_vecs,dim=1,p=2)) 39 | target_vec = center * max_length/torch.norm(center,p=2)*0.85 40 | 41 | target_vec = target_vec.detach().cpu().numpy() 42 | 43 | return target_vec.reshape((64,int((int((int(32*unit)-2)/2+1)-2)/2+1),8)) 44 | -------------------------------------------------------------------------------- /generate_vec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import os 8 | from datetime import datetime 9 | import json 10 | import time 11 | import random 12 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 13 | 14 | def generate_vec(model,dataloader,unit,bottom_series): 15 | vecs = [] 16 | for i, (x_in, y_in,_) in enumerate(dataloader): 17 | if unit != 1: 18 | x_part = x_in.split([int(x_in.size()[2] * unit), x_in.size()[2] - int(x_in.size()[2] * unit)], dim=2) 19 | pred = model(x_part[bottom_series]).detach().cpu().numpy() 20 | else: 21 | pred = model(x_in).detach().cpu().numpy() 22 | 23 | vecs.append(pred) 24 | vecs = np.array(vecs) 25 | shape = vecs.shape 26 | vecs = vecs.reshape((shape[0] * shape[1], shape[2]*shape[3]*shape[4])) 27 | return vecs 28 | 29 | def generate_all_clean_vecs(class_num,model,testset,unit,bottom_series=0): 30 | all_clean_vecs = [] 31 | for label in range(class_num): 32 | target_set = testset[label] 33 | targetloader = torch.utils.data.DataLoader(target_set, batch_size=1000, shuffle=True) 34 | vecs = generate_vec(model, targetloader,unit,bottom_series) 35 | all_clean_vecs.append(vecs) 36 | return np.array(all_clean_vecs) 37 | 38 | def generate_target_clean_vecs(model,testset,unit,bottom_series=0): 39 | target_set = testset 40 | targetloader = torch.utils.data.DataLoader(target_set, batch_size=1000, shuffle=True) 41 | target_clean_vecs = generate_vec(model, targetloader, unit, bottom_series) 42 | return target_clean_vecs 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Backdoor-Attack-Against-Split-Neural-Network-Based-Vertical-Federated-Learning 2 | 3 | This code is the attack scheme in the paper "Backdoor Attack Against Split Neural Network-Based Vertical Federated Learning". We supply this version as a reference of our attack scheme. 4 | 5 | Among the files, the model structure and data processing is applicable for CIFAR10. You can download the CIFAR10 data by yourself and run our codes on it directly. 6 | 7 | You can test the attack scheme as following steps: 8 | 9 | 1. Generate a clean model to validate the baseline classification accuracy: 10 | 11 | > python clean_train.py --clean-epoch 80 --dup 0 --multies 4 --unit 0.25 12 | 13 | You can run `python clean_train.py -h` to view the meaning of each arguments. 14 | 15 | After this step, you will obtain a clean model which completes a full training of 100 epochs, along with a pre-poisoned model which completes clean 80 epochs for the backdoor attack. 16 | 17 | 2. Poison the model and generate the special trigger vector: 18 | 19 | > python poison_train.py --label 0 --dup 0 --magnification 6 --multies 4 --unit 0.25 --clean-epoch 80 20 | 21 | In this step, the process will poison and train on the pre-poisoned model in the step 1 for the remained 20 epochs. Please note that the argument values of `dup`, `multies`, `unit` and `clean-epoch` need to be consistent with those in the step 1. 22 | 23 | After this step, you will obtain a backdoored model which completes a full training of 100 epochs, along with its trigger vector. 24 | 25 | 4. Add appropriate noise on the trigger vector to avoid its repeat appearences in the uploading process of the bottom model: 26 | 27 | > python add_noise.py --multies 4 --unit 0.25 28 | 29 | You can see the attack successful rate of the trigger vector with noise. 30 | 31 | Note that you can open the `noise_vec_0.csv` and `normal_vec_0.csv` to directly observe the small differences between them and adjust the size of the noise to the appropriate range based on our paper. 32 | 33 | Feel free to contact me (guapi7878@gmail.com) if you have any questions. 34 | -------------------------------------------------------------------------------- /cnn_model_multi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BottomModel(nn.Module): 7 | def __init__(self,gpu=False): 8 | super(BottomModel,self).__init__() 9 | self.gpu = gpu 10 | 11 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) 12 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 13 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 14 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 15 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 16 | 17 | if gpu: 18 | self.cuda() 19 | 20 | def forward(self,x): 21 | if self.gpu: 22 | x = x.cuda() 23 | 24 | x = F.relu(self.conv1(x)) 25 | x = self.max_pool(F.relu(self.conv2(x))) 26 | x = F.relu(self.conv3(x)) 27 | x = self.max_pool(F.relu(self.conv4(x))) 28 | 29 | return x 30 | 31 | class TopModel(nn.Module): 32 | def __init__(self,gpu=False,input_size=8): 33 | super(TopModel,self).__init__() 34 | self.gpu = gpu 35 | self.linear = nn.Linear(64*input_size*8, 256) 36 | self.fc = nn.Linear(256, 256) 37 | self.output = nn.Linear(256, 10) 38 | 39 | if gpu: 40 | self.cuda() 41 | 42 | def forward(self,x): 43 | if self.gpu: 44 | x = x.cuda() 45 | B = x.size()[0] 46 | 47 | x = F.relu(self.linear(x.view(B,-1))) 48 | x = F.dropout(F.relu(self.fc(x)), 0.5, training=self.training) 49 | x = self.output(x) 50 | 51 | return x 52 | 53 | class Model(nn.Module): 54 | def __init__(self, gpu=False,multies=2,unit = 0.25): 55 | super(Model, self).__init__() 56 | self.gpu = gpu 57 | self.multies = multies 58 | self.unit = unit 59 | self.other_unit = (1-unit)/(multies-1) 60 | self.models = nn.ModuleList([BottomModel(gpu) for i in range(self.multies)]) 61 | self.top = TopModel(gpu,int((int((int(32*self.unit)-2)/2+1)-2)/2+1)+(multies-2)*int((int((int(32*self.other_unit)-2)/2+1)-2)/2+1)+int((int((32-int(32*self.unit)-(multies-2)*int(32*self.other_unit)-2)/2+1)-2)/2+1)) 62 | 63 | if gpu: 64 | self.cuda() 65 | 66 | def forward(self, x): 67 | if self.gpu: 68 | x = x.cuda() 69 | x_list = x.split([int(x.size()[2]*self.unit)]+[int(x.size()[2]*self.other_unit) for i in range(self.multies-2)]+[x.size()[2]-int(x.size()[2]*self.unit)-(self.multies-2)*int(x.size()[2]*self.other_unit)],dim=2) 70 | x_list = [self.models[i](x_list[i]) for i in range(self.multies)] 71 | x = torch.cat(x_list,dim=2) 72 | x = self.top(x) 73 | return x 74 | 75 | def loss(self, pred, label): 76 | if self.gpu: 77 | label = label.cuda() 78 | return F.cross_entropy(pred, label) 79 | -------------------------------------------------------------------------------- /cal_centers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Oct 14 21:52:09 2018 4 | 5 | @author: ASUS 6 | """ 7 | import math 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | MIN_DISTANCE = 0.000001 # mini error 12 | 13 | def load_data(path, feature_num=2): 14 | 15 | f = open(path) 16 | data = [] 17 | for line in f.readlines(): 18 | lines = line.strip().split("\t") 19 | data_tmp = [] 20 | if len(lines) != feature_num: 21 | continue 22 | for i in range(feature_num): 23 | data_tmp.append(float(lines[i])) 24 | data.append(data_tmp) 25 | f.close() 26 | return data 27 | 28 | 29 | def gaussian_kernel(distance, bandwidth): 30 | 31 | m = np.shape(distance)[0] 32 | right = np.mat(np.zeros((m, 1))) 33 | for i in range(m): 34 | right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth) 35 | right[i, 0] = np.exp(right[i, 0]) 36 | left = 1 / (bandwidth * math.sqrt(2 * math.pi)) 37 | 38 | gaussian_val = left * right 39 | return gaussian_val 40 | 41 | 42 | def shift_point(point, points, kernel_bandwidth): 43 | 44 | points = np.mat(points) 45 | m = np.shape(points)[0] 46 | point_distances = np.mat(np.zeros((m, 1))) 47 | for i in range(m): 48 | point_distances[i, 0] = euclidean_dist(point, points[i]) 49 | 50 | 51 | point_weights = gaussian_kernel(point_distances, kernel_bandwidth) 52 | 53 | all_sum = 0.0 54 | for i in range(m): 55 | all_sum += point_weights[i, 0] 56 | 57 | point_shifted = point_weights.T * points / all_sum 58 | return point_shifted 59 | 60 | 61 | def euclidean_dist(pointA, pointB): 62 | 63 | total = (pointA - pointB) * (pointA - pointB).T 64 | return math.sqrt(total) 65 | 66 | 67 | def group_points(mean_shift_points): 68 | 69 | group_assignment = [] 70 | m, n = np.shape(mean_shift_points) 71 | index = 0 72 | index_dict = {} 73 | for i in range(m): 74 | item = [] 75 | for j in range(n): 76 | item.append(str(("%5.2f" % mean_shift_points[i, j]))) 77 | 78 | item_1 = "_".join(item) 79 | if item_1 not in index_dict: 80 | index_dict[item_1] = index 81 | index += 1 82 | 83 | for i in range(m): 84 | item = [] 85 | for j in range(n): 86 | item.append(str(("%5.2f" % mean_shift_points[i, j]))) 87 | 88 | item_1 = "_".join(item) 89 | group_assignment.append(index_dict[item_1]) 90 | 91 | return group_assignment 92 | 93 | 94 | def train_mean_shift(points, kenel_bandwidth=2): 95 | 96 | mean_shift_points = np.mat(points) 97 | max_min_dist = 1 98 | iteration = 0 99 | m = np.shape(mean_shift_points)[0] 100 | need_shift = [True] * m 101 | 102 | while max_min_dist > MIN_DISTANCE: 103 | max_min_dist = 0 104 | iteration += 1 105 | for i in range(0, m): 106 | 107 | if not need_shift[i]: 108 | continue 109 | p_new = mean_shift_points[i] 110 | p_new_start = p_new 111 | p_new = shift_point(p_new, points, kenel_bandwidth) 112 | dist = euclidean_dist(p_new, p_new_start) 113 | 114 | if dist > max_min_dist: 115 | max_min_dist = dist 116 | if dist < MIN_DISTANCE: 117 | need_shift[i] = False 118 | 119 | mean_shift_points[i] = p_new 120 | 121 | group = group_points(mean_shift_points) 122 | 123 | return np.mat(points), mean_shift_points, group 124 | 125 | 126 | def save_result(file_name, data): 127 | 128 | f = open(file_name, "w") 129 | m, n = np.shape(data) 130 | for i in range(m): 131 | tmp = [] 132 | for j in range(n): 133 | tmp.append(str(data[i, j])) 134 | f.write("\t".join(tmp) + "\n") 135 | f.close() 136 | 137 | def cal_centers(classnum,all_clean_vecs,kernel_bandwidth): 138 | centers = [] 139 | for label in range(classnum): 140 | data = all_clean_vecs[label] 141 | points, shift_points, cluster = train_mean_shift(data, kernel_bandwidth) 142 | centers.append(shift_points[0]) 143 | centers = np.array(centers) 144 | 145 | return centers 146 | 147 | def cal_target_center(target_clean_vecs,kernel_bandwidth): 148 | data = target_clean_vecs 149 | points, shift_points, cluster = train_mean_shift(data, kernel_bandwidth) 150 | center = np.array(shift_points[0]) 151 | 152 | return center 153 | 154 | 155 | -------------------------------------------------------------------------------- /clean_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import os 6 | from datetime import datetime 7 | import json 8 | import argparse 9 | import time 10 | import random 11 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--clean-epoch', type=int, required=False,default=80, help='the number of training epochs without poisoning') 15 | parser.add_argument('--dup', type=int, required=True, help='the ID for duplicated models of a same setting') 16 | parser.add_argument('--multies', type=int, required=False,default=2, help='the number of mutiple participants') 17 | parser.add_argument('--unit', type=float, required=False, default=0.25,help='the feature ratio held by the attacker') 18 | 19 | def train_model(model, dataloader,epoch_num, is_binary, verbose=True): 20 | model.train() 21 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 22 | 23 | for epoch in range(epoch_num): 24 | cum_loss = 0.0 25 | cum_acc = 0.0 26 | tot = 0.0 27 | for i, (x_in, y_in) in enumerate(dataloader): 28 | B = x_in.size()[0] 29 | pred = model(x_in) 30 | loss = model.loss(pred, y_in) 31 | optimizer.zero_grad() 32 | loss.backward() 33 | optimizer.step() 34 | cum_loss += loss.item() * B 35 | if is_binary: 36 | cum_acc += ((pred > 0).cpu().long().eq(y_in)).sum().item() 37 | else: 38 | pred_c = pred.max(1)[1].cpu() 39 | cum_acc += (pred_c.eq(y_in)).sum().item() 40 | tot = tot + B 41 | if verbose: 42 | print("Epoch %d, loss = %.4f, acc = %.4f" % (epoch, cum_loss / tot, cum_acc / tot)) 43 | return 44 | 45 | 46 | def eval_model(model, dataloader, is_binary): 47 | model.eval() 48 | cum_acc = 0.0 49 | tot = 0.0 50 | for i, (x_in, y_in) in enumerate(dataloader): 51 | B = x_in.size()[0] 52 | pred = model(x_in) 53 | if is_binary: 54 | cum_acc += ((pred > 0).cpu().long().eq(y_in)).sum().item() 55 | else: 56 | pred_c = pred.max(1)[1].cpu() 57 | cum_acc += (pred_c.eq(y_in)).sum().item() 58 | tot = tot + B 59 | return cum_acc / tot 60 | 61 | if __name__ == '__main__': 62 | args = parser.parse_args() 63 | 64 | GPU = True 65 | if GPU: 66 | torch.cuda.manual_seed_all(0) 67 | torch.backends.cudnn.deterministic = True 68 | torch.backends.cudnn.benchmark = False 69 | 70 | BATCH_SIZE = 500 71 | N_EPOCH = 100 72 | 73 | transform_for_train = transforms.Compose([ 74 | transforms.RandomCrop((32, 32), padding=5), 75 | transforms.RandomRotation(10), 76 | transforms.RandomHorizontalFlip(p=0.5), 77 | transforms.Resize((32, 32)), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 80 | ]) 81 | transform_for_test = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 84 | ]) 85 | trainset = torchvision.datasets.CIFAR10(root='./raw_data/', train=True, download=True, 86 | transform=transform_for_train) 87 | testset = torchvision.datasets.CIFAR10(root='./raw_data/', train=False, download=True, 88 | transform=transform_for_test) 89 | is_binary = False 90 | need_pad = False 91 | 92 | from cnn_model_multi import Model 93 | 94 | input_size = (3, 32, 32) 95 | class_num = 10 96 | 97 | model = Model(gpu=GPU,multies=args.multies,unit=args.unit) 98 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) 99 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True) 100 | 101 | if args.clean_epoch: 102 | t1 = time.time() 103 | 104 | train_model(model, trainloader, epoch_num=args.clean_epoch, is_binary=is_binary, verbose=True) 105 | torch.save(model.state_dict(),'clean_epoch_%d-%d-%s.model'%(args.dup,args.multies,args.unit)) 106 | 107 | train_model(model, trainloader, epoch_num=N_EPOCH-args.clean_epoch, is_binary=is_binary, verbose=True) 108 | cleanacc = eval_model(model, testloader, is_binary=is_binary) 109 | torch.save(model.state_dict(),'clean-%d-%d-%s.model'%(args.dup,args.multies,args.unit)) 110 | print('clean acc: %.4f' % cleanacc) 111 | 112 | t2 = time.time() 113 | print("Training a model costs %.4fs." % (t2 - t1)) 114 | 115 | else: 116 | t1 = time.time() 117 | 118 | train_model(model, trainloader, epoch_num=N_EPOCH, is_binary=is_binary,verbose=True) 119 | cleanacc = eval_model(model, testloader, is_binary=is_binary) 120 | torch.save(model.state_dict(),'clean%d-%d-%s.model'%(args.dup,args.multies,args.unit)) 121 | print('clean acc: %.4f'%cleanacc) 122 | 123 | t2 = time.time() 124 | print("Training a model costs %.4fs." % (t2 - t1)) 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /add_noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import os 6 | from datetime import datetime 7 | import json 8 | import argparse 9 | import time 10 | import random 11 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--multies', type=int, required=False,default=2, help='the number of mutiple participants') 15 | parser.add_argument('--unit', type=float, required=False, default=0.25,help='the feature ratio held by the attacker') 16 | 17 | args = parser.parse_args() 18 | # setting for multi-participant VFl 19 | unit = args.unit 20 | multies = args.multies 21 | other_unit = (1-args.unit)/(args.multies-1) 22 | 23 | def add_noise(vec,normal_vecs): # noise scheme for two-split VFL 24 | avg_value = torch.mean(normal_vecs,dim=0).reshape((-1)) 25 | con = torch.where(avg_value < 0.001)[0] 26 | 27 | size = vec.size() 28 | vec = vec.reshape((-1)) 29 | 30 | vec = vec.clamp_(0, 2.5) 31 | vec *= 1.15 32 | 33 | gauss_noise_big = torch.normal(mean=0, std=0.5, size=vec.size()).cuda() 34 | gauss_noise_small = torch.normal(mean=0, std=0.1, size=vec.size()).cuda() 35 | 36 | condition = torch.randn(vec.size()).cuda() 37 | zeros = torch.zeros_like(vec).cuda() 38 | replace = torch.where(condition < 0.8, zeros, vec + gauss_noise_small) 39 | vec = torch.where(vec < 0.4, replace, vec + gauss_noise_big) 40 | vec = vec.clamp_(0).reshape((size[0],-1)) 41 | vec[:, con] = 0 42 | 43 | return vec.reshape(size) 44 | 45 | def add_noise_multi(vec,normal_vecs): # noise scheme for 4-participant VFL 46 | avg_value = torch.mean(normal_vecs,dim=0).reshape((-1)) 47 | con = torch.where(avg_value < 0.001)[0] 48 | 49 | size = vec.size() 50 | vec = vec.reshape((-1)) 51 | 52 | vec = vec.clamp_(0, 2.5) 53 | vec *= 1.15 54 | 55 | gauss_noise_big = torch.normal(mean=0, std=0.2, size=vec.size()).cuda() 56 | gauss_noise_small = torch.normal(mean=0, std=0.05, size=vec.size()).cuda() 57 | 58 | condition = torch.randn(vec.size()).cuda() 59 | zeros = torch.zeros_like(vec).cuda() 60 | replace = torch.where(condition < 0.8, zeros, vec + gauss_noise_small) 61 | vec = torch.where(vec < 0.4, replace, vec + gauss_noise_big) 62 | vec = vec.clamp_(0).reshape((size[0],-1)) 63 | vec[:, con] = 0 64 | 65 | return vec.reshape(size) 66 | 67 | def save(vecs,label,normal=False): 68 | vecs = vecs.reshape(-1, 64*4*8) 69 | if normal: 70 | f = open('normal_vec_%d.csv' % label, 'w') 71 | else: 72 | f = open('noise_vec_%d.csv' % label, 'w') 73 | for i in range(20): 74 | for j in range(vecs.shape[1]): 75 | f.write(str(vecs[i][j].item())) 76 | f.write(',') 77 | f.write('\n') 78 | f.close() 79 | 80 | def attack_model(model, dataloader, vec_arr,label): 81 | model.eval() 82 | cum_acc = 0.0 83 | tot = 0.0 84 | for i, (x_in, y_in) in enumerate(dataloader): 85 | B = x_in.size()[0] 86 | vec1 = torch.Tensor(np.repeat([vec_arr],B,axis=0)).cuda() 87 | x_list = x_in.split([int(x_in.size()[2]*unit)]+[int(x_in.size()[2]*other_unit) for i in range(multies-2)]+[x_in.size()[2]-int(x_in.size()[2]*unit)-(multies-2)*int(x_in.size()[2]*other_unit)],dim=2) 88 | vec_normal = model.models[0](x_list[0]) 89 | if multies == 2: 90 | vec1 = add_noise(vec1,vec_normal[:20]) 91 | elif multies > 2: 92 | vec1 = add_noise_multi(vec1,vec_normal[:20]) 93 | vec = torch.cat([vec1]+[model.models[i](x_list[i]) for i in range(1,multies)], dim=2) 94 | 95 | pred = model.top(vec) 96 | pred_c = pred.max(1)[1].cpu() 97 | cum_acc += (pred_c.eq(torch.Tensor(np.repeat([label],B,axis=0)))).sum().item() 98 | tot = tot + B 99 | 100 | save(vec1.clone().detach().cpu(),label,False) 101 | save(vec_normal.clone().detach().cpu(), label, True) 102 | return cum_acc / tot 103 | 104 | if __name__ == '__main__': 105 | 106 | GPU = True 107 | if GPU: 108 | torch.cuda.manual_seed_all(0) 109 | torch.backends.cudnn.deterministic = True 110 | torch.backends.cudnn.benchmark = False 111 | 112 | BATCH_SIZE = 500 113 | N_EPOCH = 100 114 | transform_for_test = transforms.Compose([ 115 | transforms.ToTensor(), 116 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 117 | ]) 118 | testset = torchvision.datasets.CIFAR10(root='./raw_data/', train=False, download=True, 119 | transform=transform_for_test) 120 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True) 121 | is_binary = False 122 | need_pad = False 123 | from cnn_model_multi import Model 124 | 125 | input_size = (3, 32, 32) 126 | class_num = 10 127 | model = Model(gpu=GPU,multies=multies,unit=unit) 128 | 129 | for label in range(class_num): 130 | atk_list = [] 131 | for dup in range(10): 132 | model.load_state_dict(torch.load('poison_label_%d-%s-%s-%d.model' % (dup,multies,unit,label))) 133 | target_vec = np.load('label_%d-%s-%s-%d_vec.npy'%(dup,multies,unit,label)) 134 | atkacc = attack_model(model, testloader, target_vec, label) 135 | atk_list.append(atkacc) 136 | print('target label: %d, average atk acc: %.4f'%(label,sum(atk_list)/len(atk_list))) 137 | -------------------------------------------------------------------------------- /poison_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import os 6 | from datetime import datetime 7 | import json 8 | import argparse 9 | import time 10 | import random 11 | import cal_centers as cc 12 | import generate_vec as gv 13 | import search_vec as sv 14 | import warnings 15 | 16 | import matplotlib.pyplot as plt 17 | 18 | import matplotlib.lines as lines 19 | from matplotlib.ticker import FuncFormatter 20 | 21 | warnings.filterwarnings("ignore", category=FutureWarning) 22 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--label', type=int, required=True, help='the target class of your attack') 26 | parser.add_argument('--dup', type=int, required=True, help='the ID for duplicated models of a same setting') 27 | parser.add_argument('--magnification',type=int,required=True,help='the size of the auxiliary set will be 50*magnification') 28 | parser.add_argument('--multies', type=int, required=False,default=2, help='the number of mutiple participants') 29 | parser.add_argument('--unit', type=float, required=False, default=0.25,help='the feature ratio held by the attacker') 30 | parser.add_argument('--clean-epoch', type=int, required=False,default=80, help='the number of training epochs without poisoning') 31 | 32 | args = parser.parse_args() 33 | other_unit = (1-args.unit)/(args.multies-1) 34 | 35 | target_num = 50 36 | normal_num = 50 37 | clean_epoch = args.clean_epoch 38 | 39 | def prepared_data(set): 40 | data = [] 41 | label = [] 42 | for idx in range(len(set)): 43 | x,y = set[idx] 44 | data.append({'id':idx,'data':x}) 45 | label.append(y) 46 | 47 | return data,label 48 | 49 | class CIFAR10(torch.utils.data.Dataset): 50 | def __init__(self,data,label,transform=None): 51 | self.data = data 52 | self.label = label 53 | self.transform = transform 54 | 55 | def __getitem__(self, item): 56 | x = self.data[item]['data'] 57 | if not(self.transform is None): 58 | x = self.transform(x) 59 | y = self.label[item] 60 | id = self.data[item]['id'] 61 | 62 | return x,y,id 63 | 64 | def __len__(self): 65 | return len(self.label) 66 | 67 | 68 | def steal_samples(trn_x,trn_y,t): 69 | targets = [] 70 | for idx in range(len(trn_y)): 71 | if trn_y[idx] == t: 72 | targets.append(trn_x[idx]['id']) 73 | num = target_num*target_magnification 74 | print("clean image used for class %d: %d"%(t,num)) 75 | 76 | steal_id = random.sample(targets,num) 77 | data = [] 78 | label = [] 79 | for idx in steal_id: 80 | data.append(trn_x[idx]) 81 | label.append(trn_y[idx]) 82 | 83 | steal_set = CIFAR10(data,label,transform=transform_for_train) 84 | steal_id = torch.tensor(steal_id) 85 | 86 | return steal_set,steal_id 87 | 88 | def design_vec(class_num,model,label,steal_set): 89 | target_clean_vecs = gv.generate_target_clean_vecs(model.models[0],steal_set,args.unit,bottom_series=0) 90 | 91 | dim = filter_dim(target_clean_vecs) 92 | 93 | center = cc.cal_target_center(target_clean_vecs[dim].copy(),kernel_bandwidth=1000) 94 | 95 | target_vec = sv.search_vec(center,target_clean_vecs,args.unit) 96 | 97 | target_vec = target_vec.reshape((64,int((int((int(32*args.unit)-2)/2+1)-2)/2+1),8)) 98 | 99 | return target_vec 100 | 101 | def filter_dim(vecs): 102 | coef = np.corrcoef(vecs) 103 | rows = np.sum(coef,axis=1) 104 | selected = np.argpartition(rows,-target_num)[-target_num:] 105 | print(np.mean(np.corrcoef(vecs[selected]))) 106 | return selected 107 | 108 | def train_model(model, dataloader,label,steal_set,steal_id,epoch_num,start_epoch=0,is_binary=False, verbose=True): 109 | model.train() 110 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 111 | 112 | for epoch in range(start_epoch,epoch_num): 113 | t1 = time.time() 114 | 115 | cum_loss = 0.0 116 | cum_acc = 0.0 117 | tot = 0.0 118 | 119 | if epoch >= clean_epoch and epoch%10 == 0: 120 | vec_arr = design_vec(class_num, model, label, steal_set) 121 | 122 | for i, (x_in, y_in, id_in) in enumerate(dataloader): 123 | B = x_in.size()[0] 124 | 125 | if args.unit != 1: 126 | x_list = x_in.split([int(x_in.size()[2]*args.unit)]+[int(x_in.size()[2]*other_unit) for i in range(args.multies-2)]+[x_in.size()[2]-int(x_in.size()[2]*args.unit)-(args.multies-2)*int(x_in.size()[2]*other_unit)],dim=2) 127 | else: 128 | x_list = [x_in] 129 | 130 | vec1 = model.models[0](x_list[0]) 131 | 132 | if epoch >= clean_epoch: 133 | condition = [] 134 | for idx in range(B): 135 | if id_in[idx] in steal_id: 136 | condition.append(idx) 137 | vec1[condition] = torch.tensor(vec_arr).cuda() 138 | 139 | if args.unit != 1: 140 | vec = torch.cat([vec1]+[model.models[i](x_list[i]) for i in range(1,args.multies)], dim=2) 141 | else: 142 | vec = vec1 143 | 144 | pred = model.top(vec) 145 | loss = model.loss(pred, y_in) 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | 150 | cum_loss += loss.item() * B 151 | if is_binary: 152 | cum_acc += ((pred > 0).cpu().long().eq(y_in)).sum().item() 153 | else: 154 | pred_c = pred.max(1)[1].cpu() 155 | cum_acc += (pred_c.eq(y_in)).sum().item() 156 | tot = tot + B 157 | 158 | 159 | if verbose: 160 | t2 = time.time() 161 | print("Epoch %d, loss = %.4f, acc = %.4f (%.4fs)" % (epoch, cum_loss / tot, cum_acc / tot,t2-t1)) 162 | 163 | return vec_arr 164 | 165 | 166 | def eval_model(model, dataloader, is_binary): 167 | model.eval() 168 | cum_acc = 0.0 169 | tot = 0.0 170 | for i, (x_in, y_in,_) in enumerate(dataloader): 171 | B = x_in.size()[0] 172 | pred = model(x_in) 173 | if is_binary: 174 | cum_acc += ((pred > 0).cpu().long().eq(y_in)).sum().item() 175 | else: 176 | pred_c = pred.max(1)[1].cpu() 177 | cum_acc += (pred_c.eq(y_in)).sum().item() 178 | tot = tot + B 179 | return cum_acc / tot 180 | 181 | def attack_model(model, dataloader, vec_arr,label,multies,unit,other_unit): 182 | model.eval() 183 | cum_acc = 0.0 184 | tot = 0.0 185 | for i, (x_in, y_in,_) in enumerate(dataloader): 186 | B = x_in.size()[0] 187 | 188 | if args.unit != 1: 189 | x_list = x_in.split([int(x_in.size()[2]*unit)]+[int(x_in.size()[2]*other_unit) for i in range(multies-2)]+[x_in.size()[2]-int(x_in.size()[2]*unit)-(multies-2)*int(x_in.size()[2]*other_unit)],dim=2) 190 | 191 | vec1 = torch.Tensor(np.repeat([vec_arr],B,axis=0)).cuda() 192 | 193 | if args.unit != 1: 194 | vec = torch.cat([vec1]+[model.models[i](x_list[i]) for i in range(1,multies)], dim=2) 195 | else: 196 | vec = vec1 197 | 198 | pred = model.top(vec) 199 | pred_c = pred.max(1)[1].cpu() 200 | cum_acc += (pred_c.eq(torch.Tensor(np.repeat([label],B,axis=0)))).sum().item() 201 | tot = tot + B 202 | return cum_acc / tot 203 | 204 | 205 | if __name__ == '__main__': 206 | target_magnification = args.magnification 207 | 208 | GPU = True 209 | if GPU: 210 | torch.cuda.manual_seed_all(args.dup) 211 | random.seed(args.dup) 212 | torch.manual_seed(args.dup) 213 | np.random.seed(args.dup) 214 | 215 | torch.backends.cudnn.deterministic = True 216 | torch.backends.cudnn.benchmark = False 217 | 218 | BATCH_SIZE = 500 219 | N_EPOCH = 100 220 | transform_for_train = transforms.Compose([ 221 | transforms.RandomCrop((32, 32), padding=5), 222 | transforms.RandomRotation(10), 223 | transforms.RandomHorizontalFlip(p=0.5), 224 | transforms.Resize((32, 32)), 225 | transforms.ToTensor(), 226 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 227 | ]) 228 | transform_for_test = transforms.Compose([ 229 | transforms.ToTensor(), 230 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 231 | ]) 232 | trainset = torchvision.datasets.CIFAR10(root='./raw_data/', train=True, download=True) 233 | testset = torchvision.datasets.CIFAR10(root='./raw_data/', train=False, download=True) 234 | 235 | trn_x,trn_y = prepared_data(trainset) 236 | dl_train_set = CIFAR10(trn_x,trn_y,transform=transform_for_train) 237 | val_x,val_y = prepared_data(testset) 238 | dl_val_set = CIFAR10(val_x,val_y,transform=transform_for_test) 239 | is_binary = False 240 | need_pad = False 241 | 242 | from cnn_model_multi import Model 243 | 244 | input_size = (3, 32, 32) 245 | class_num = 10 246 | 247 | model = Model(gpu=GPU,multies=args.multies,unit=args.unit) 248 | trainloader = torch.utils.data.DataLoader(dl_train_set, batch_size=BATCH_SIZE, shuffle=True) 249 | testloader = torch.utils.data.DataLoader(dl_val_set, batch_size=BATCH_SIZE, shuffle=True) 250 | 251 | steal_set,steal_id = steal_samples(trn_x,trn_y,args.label) 252 | 253 | label = args.label 254 | dup = args.dup 255 | 256 | t1=time.time() 257 | 258 | model.load_state_dict(torch.load('clean-%d-%d-%s.model'%(args.dup,args.multies,args.unit))) 259 | 260 | last_vec_arr = train_model(model, trainloader,label,steal_set,steal_id,epoch_num=N_EPOCH,start_epoch=clean_epoch, is_binary=is_binary,verbose=True) 261 | torch.save(model.state_dict(),'poison_label_%d-%d-%s-%d.model'%(args.dup,args.multies,args.unit,args.label)) 262 | 263 | cleanacc = eval_model(model, testloader, is_binary=is_binary) 264 | print('clean acc: %.4f'%cleanacc) 265 | 266 | atkacc = attack_model(model, testloader, last_vec_arr, label,args.multies,args.unit,other_unit) 267 | print('target label: %d, attack acc: %.4f' % (label, atkacc)) 268 | 269 | np.save('label_%d-%d-%s-%d_vec'%(args.dup,args.multies,args.unit,args.label),last_vec_arr) 270 | 271 | t2 = time.time() 272 | print("Training a model costs %.4fs."%(t2-t1)) 273 | --------------------------------------------------------------------------------