├── attack
│   ├── __init__.py
│   ├── attack.py
│   ├── add_trigger.py
│   └── add_trigger_replace.py
├── models
│   ├── __init__.py
│   ├── PHISHING_models.py
│   ├── CIFAR10_models.py
│   ├── UCIHAR_models.py
│   └── NUSWIDE_models.py
├── utils
│   ├── __init__.py
│   ├── utils.py
│   └── trainer.py
├── .gitattributes
├── requirements.txt
├── dataset
│   ├── utils.py
│   └── dataset.py
├── README.md
└── main.py

/attack/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | 
8 | def raise_dataset_exception():
9 |     raise Exception('Unknown dataset, please implement it.')
10 | 
11 | 
12 | def raise_split_exception():
13 |     raise Exception('Unknown split, please implement it.')
14 | 
15 | 
16 | def raise_attack_exception():
17 |     raise Exception('Unknown attack, please implement it.')
18 | 
19 | 
20 | def set_seed(seed):
21 |     random.seed(seed)
22 |     os.environ['PYTHONHASHSEED'] = str(seed)
23 |     np.random.seed(seed)
24 |     torch.manual_seed(seed)
25 |     torch.cuda.manual_seed(seed)
26 |     torch.cuda.manual_seed_all(seed)
27 |     torch.backends.cudnn.benchmark = False
28 |     torch.backends.cudnn.deterministic = True
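29 | 
30 | # Note (illustrative comment, not part of the original file): benchmark=False
31 | # and deterministic=True trade training speed for reproducibility. Call
32 | # set_seed once at startup, before models and dataloaders are built, e.g.:
33 | #   set_seed(100)  # matches the default --seed in main.py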
--------------------------------------------------------------------------------
/models/PHISHING_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from utils.utils import *
4 | 
5 | 
6 | class GlobalModelForPHISHING(nn.Module):
7 |     def __init__(self, args):
8 |         super(GlobalModelForPHISHING, self).__init__()
9 |         self.linear1 = nn.Linear(8, 4)
10 |         self.classifier = nn.Linear(4, 2)
11 |         self.args = args
12 | 
13 |     def forward(self, input_list):
14 |         tensor_t = torch.cat((input_list[0], input_list[1]), dim=1)
15 | 
16 |         # forward
17 |         x = tensor_t
18 |         x = self.linear1(x)
19 |         x = self.classifier(x)
20 |         return x
21 | 
22 | 
23 | class LocalModelForPHISHING(nn.Module):
24 |     def __init__(self, args, client_number):
25 |         super(LocalModelForPHISHING, self).__init__()
26 |         self.args = args
27 |         if client_number == 0:
28 |             self.backbone = nn.Sequential(
29 |                 nn.Linear(15, 8),
30 |                 nn.ReLU(),
31 |                 nn.Linear(8, 4)
32 |             )
33 |         else:
34 |             self.backbone = nn.Sequential(
35 |                 nn.Linear(15, 8),
36 |                 nn.ReLU(),
37 |                 nn.Linear(8, 4)
38 |             )
39 | 
40 |     def forward(self, x):
41 |         x = self.backbone(x)
42 |         return x
43 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2024.2.2
2 | charset-normalizer==3.3.2
3 | contourpy==1.2.1
4 | cycler==0.12.1
5 | filelock==3.15.4
6 | fonttools==4.53.1
7 | fsspec==2024.6.1
8 | gensim==4.3.1
9 | idna==3.7
10 | importlib_resources==6.4.0
11 | intel-cmplr-lib-ur==2024.2.0
12 | intel-openmp==2024.2.0
13 | Jinja2==3.1.4
14 | joblib==1.4.2
15 | kiwisolver==1.4.5
16 | MarkupSafe==2.1.5
17 | mpmath==1.3.0
18 | networkx==3.2.1
19 | numpy==1.26.4
20 | nvidia-cublas-cu12==12.1.3.1
21 | nvidia-cuda-cupti-cu12==12.1.105
22 | nvidia-cuda-nvrtc-cu12==12.1.105
23 | nvidia-cuda-runtime-cu12==12.1.105
24 | nvidia-cudnn-cu12==8.9.2.26
25 | nvidia-cufft-cu12==11.0.2.54
26 | nvidia-curand-cu12==10.3.2.106
27 | nvidia-cusolver-cu12==11.4.5.107
28 | nvidia-cusparse-cu12==12.1.0.106
29 | nvidia-nccl-cu12==2.20.5
30 | nvidia-nvjitlink-cu12==12.5.82
31 | nvidia-nvtx-cu12==12.1.105
32 | opencv-python==4.9.0.80
33 | opt-einsum==3.3.0
34 | packaging==24.1
35 | pandas==2.2.2
36 | pillow==10.3.0
37 | pyparsing==3.1.2
38 | python-dateutil==2.9.0.post0
39 | pytz==2024.1
40 | requests==2.32.2
41 | scikit-learn==1.5.0
42 | scipy==1.13.1
43 | setuptools-scm==8.1.0
44 | six==1.16.0
45 | smart-open==6.3.0
46 | sympy==1.13.0
47 | tbb==2021.13.0
48 | threadpoolctl==3.5.0
49 | tomli==2.0.1
50 | torch==1.12.1
51 | torchvision==0.13.1
52 | triton==2.3.1
53 | tsnecuda==3.0.1
54 | typing_extensions==4.11.0
55 | tzdata==2024.1
56 | urllib3==2.2.1
57 | zipp==3.19.2
58 | 
--------------------------------------------------------------------------------
/models/CIFAR10_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 | 
5 | 
6 | class GlobalModelForCifar10(nn.Module):
7 |     def __init__(self, args):
8 |         super(GlobalModelForCifar10, self).__init__()
9 |         self.linear1 = nn.Linear(256, 256)
10 |         self.linear2 = nn.Linear(256, 128)
11 |         self.classifier = nn.Linear(128, 10)
12 |         self.args = args
13 | 
14 |     def forward(self, input_list):
15 |         tensor_t = torch.cat((input_list[0], input_list[1]), dim=1)
16 | 
17 |         # forward
18 |         x = tensor_t
19 |         x = self.linear1(x)
20 |         x = self.linear2(x)
21 |         x = self.classifier(x)
22 |         return x
23 | 
24 | 
25 | class LocalModelForCifar10(nn.Module):
26 |     def __init__(self, args):
27 |         super(LocalModelForCifar10, self).__init__()
28 |         self.args = args
29 |         self.backbone = models.resnet18(pretrained=False)
30 |         num_ftrs = self.backbone.fc.in_features
31 |         if self.args.client_num == 2:
32 |             self.backbone.fc = nn.Linear(num_ftrs, 128)
33 | 
34 |     def forward(self, x):
35 |         x = self.backbone(x)
36 |         return x
37 | 
38 | 
39 | class SingleModelForCifar10(nn.Module):
40 |     def __init__(self, args):
41 |         super(SingleModelForCifar10, self).__init__()
42 |         self.args = args
43 |         self.backbone = models.resnet18(pretrained=False)
44 |         num_ftrs = self.backbone.fc.in_features
45 |         self.backbone.fc = nn.Linear(num_ftrs, 10)
46 | 
47 |     def forward(self, x):
48 |         x = self.backbone(x)
49 |         return x
50 | 
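51 | # Illustrative composition (comment added for exposition, not part of the
52 | # original file): with client_num == 2, each client encodes one 32x16 half
53 | # of the image and the server concatenates the two 128-d embeddings into
54 | # the 256-d input of GlobalModelForCifar10:
55 | #   h0 = local_a(x_left)   # (B, 128)
56 | #   h1 = local_b(x_right)  # (B, 128)
57 | #   logits = global_model([h0, h1])  # cat -> (B, 256) -> logits (B, 10)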
--------------------------------------------------------------------------------
/models/UCIHAR_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from utils.utils import *
4 | 
5 | 
6 | class GlobalModelForUCIHAR(nn.Module):
7 |     def __init__(self, args):
8 |         super(GlobalModelForUCIHAR, self).__init__()
9 |         self.linear1 = nn.Linear(32, 32)
10 |         self.linear2 = nn.Linear(32, 16)
11 |         self.classifier = nn.Linear(16, 6)
12 |         self.args = args
13 | 
14 |     def forward(self, input_list):
15 |         tensor_t = torch.cat((input_list[0], input_list[1]), dim=1)
16 | 
17 |         # forward
18 |         x = tensor_t
19 |         x = self.linear1(x)
20 |         x = self.linear2(x)
21 |         x = self.classifier(x)
22 |         return x
23 | 
24 | 
25 | class LocalModelForUCIHAR(nn.Module):
26 |     def __init__(self, args, client_number):
27 |         super(LocalModelForUCIHAR, self).__init__()
28 |         self.args = args
29 |         if client_number == 0:
30 |             self.backbone = nn.Sequential(
31 |                 nn.Linear(math.ceil(561 / self.args.client_num), 140),
32 |                 nn.ReLU(),
33 |                 nn.Linear(140, 70),
34 |                 nn.ReLU(),
35 |                 nn.Linear(70, 35),
36 |                 nn.ReLU(),
37 |                 nn.Linear(35, 16),
38 |                 nn.ReLU()
39 |             )
40 |         else:
41 |             self.backbone = nn.Sequential(
42 |                 nn.Linear(round(561 / self.args.client_num), 140),
43 |                 nn.ReLU(),
44 |                 nn.Linear(140, 70),
45 |                 nn.ReLU(),
46 |                 nn.Linear(70, 35),
47 |                 nn.ReLU(),
48 |                 nn.Linear(35, 16),
49 |                 nn.ReLU()
50 |             )
51 | 
52 |     def forward(self, x):
53 |         x = self.backbone(x)
54 |         return x
55 | 
--------------------------------------------------------------------------------
/models/NUSWIDE_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from utils.utils import *
3 | 
4 | 
5 | class GlobalModelForNUSWIDE(nn.Module):
6 |     def __init__(self, args):
7 |         super(GlobalModelForNUSWIDE, self).__init__()
8 |         self.linear1 = nn.Linear(100, 100)
9 |         self.linear2 = nn.Linear(100, 50)
10 |         self.classifier = nn.Linear(50, 5)
11 |         self.args = args
12 | 
13 |     def forward(self, input_list):
14 |         tensor_t = torch.cat((input_list[0], input_list[1]), dim=1)
15 | 
16 |         # forward
17 |         x = tensor_t
18 |         x = self.linear1(x)
19 |         x = self.linear2(x)
20 |         x = self.classifier(x)
21 |         return x
22 | 
23 | 
24 | class LocalModelForNUSWIDE(nn.Module):
25 |     def __init__(self, args, client_number):
26 |         super(LocalModelForNUSWIDE, self).__init__()
27 |         self.args = args
28 |         backbone_I = nn.Sequential(
29 |             nn.Linear(634, 320),
30 |             nn.ReLU(),
31 |             nn.Linear(320, 160),
32 |             nn.ReLU(),
33 |             nn.Linear(160, 80),
34 |             nn.ReLU(),
35 |             nn.Linear(80, 40),
36 |             nn.ReLU()
37 |         )
38 |         backbone_T = nn.Sequential(
39 |             nn.Linear(1000, 500),
40 |             nn.ReLU(),
41 |             nn.Linear(500, 250),
42 |             nn.ReLU(),
43 |             nn.Linear(250, 125),
44 |             nn.ReLU(),
45 |             nn.Linear(125, 60),
46 |             nn.ReLU()
47 |         )
48 |         if client_number == 0:
49 |             self.backbone = backbone_I
50 |         else:
51 |             self.backbone = backbone_T
52 | 
53 |     def forward(self, x):
54 |         x = self.backbone(x)
55 |         return x
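56 | 
57 | # Dimension check (comment added for exposition, not part of the original
58 | # file): client 0 maps the 634-d image features to 40-d and client 1 maps
59 | # the 1000-d tag vector to 60-d; concatenated they give the 100-d input
60 | # expected by GlobalModelForNUSWIDE.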
--------------------------------------------------------------------------------
/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import pandas as pd
4 | import numpy as np
5 | 
6 | from utils.utils import raise_dataset_exception, raise_split_exception
7 | 
8 | 
9 | def get_labeled_data(data_dir, selected_label, n_samples, dtype="Train"):
10 |     # get labels
11 |     data_path = "Groundtruth/TrainTestLabels/"
12 |     dfs = []
13 |     for label in selected_label:
14 |         file = os.path.join(data_dir, data_path, "_".join(["Labels", label, dtype]) + ".txt")
15 |         print("Loading {}.".format(file))
16 |         df = pd.read_csv(file, header=None, engine="c")
17 |         df.columns = [label]
18 |         dfs.append(df)
19 |     data_labels = pd.concat(dfs, axis=1)
20 |     if len(selected_label) > 1:
21 |         selected = data_labels[data_labels.sum(axis=1) == 1]
22 |     else:
23 |         selected = data_labels
24 |     # get XA, which are image low level features
25 |     features_path = "Low_Level_Features"
26 |     dfs = []
27 |     for file in os.listdir(os.path.join(data_dir, features_path)):
28 |         if file.startswith("_".join([dtype, "Normalized"])):
29 |             print("Loading {}.".format(os.path.join(data_dir, features_path, file)))
30 |             df = pd.read_csv(os.path.join(data_dir, features_path, file), header=None, sep=" ", engine="c")
31 |             df.dropna(axis=1, inplace=True)
32 |             dfs.append(df)
33 |     data_XA = pd.concat(dfs, axis=1)
34 |     data_X_image_selected = data_XA.loc[selected.index]
35 |     # get XB, which are tags
36 |     tag_path = "NUS_WID_Tags/"
37 |     file = "_".join([dtype, "Tags1k"]) + ".dat"
38 |     print("Loading {}.".format(file))
39 |     tagsdf = pd.read_csv(os.path.join(data_dir, tag_path, file), header=None, sep="\t", engine="c")
40 |     tagsdf.dropna(axis=1, inplace=True)
41 |     data_X_text_selected = tagsdf.loc[selected.index]
42 |     if n_samples is None:
43 |         return data_X_image_selected.values[:], data_X_text_selected.values[:], np.argmax(selected.values[:], 1)
44 |     return data_X_image_selected.values[:n_samples], data_X_text_selected.values[:n_samples], np.argmax(
45 |         selected.values[:n_samples], 1)
46 | 
47 | 
48 | def split_vfl(data, args):
49 |     if args.dataset == 'CIFAR10':
50 |         # 32*16*3/32*16*3
51 |         x_a = data[:, :, :, :16]
52 |         x_b = data[:, :, :, 16:]
53 |         return x_a, x_b
54 |     elif args.dataset == 'UCIHAR':
55 |         # 281/280
56 |         x_a = data[:, :281]
57 |         x_b = data[:, 281:]
58 |         return x_a, x_b
59 |     elif args.dataset == 'PHISHING':
60 |         # 15/15
61 |         x_a = data[:, :15]
62 |         x_b = data[:, 15:]
63 |         return x_a, x_b
64 |     elif args.dataset in ('NUSWIDE', 'NUSWIDET'):
65 |         # 634/1000
66 |         x_a = data[:, :634]
67 |         x_b = data[:, 634:]
68 |         return x_a, x_b
69 |     elif args.dataset == 'NUSWIDEI':
70 |         # 1000/634
71 |         x_a = data[:, 634:]
72 |         x_b = data[:, :634]
73 |         return x_a, x_b
74 |     else:
75 |         raise_dataset_exception()
76 | 
--------------------------------------------------------------------------------
/attack/attack.py:
--------------------------------------------------------------------------------
1 | import copy
2 | 
3 | import numpy as np
4 | import torch
5 | from attack.add_trigger import add_trigger_to_data
6 | from attack.add_trigger_replace import add_trigger_to_data_replace
7 | 
8 | 
9 | def attack_lra(args, logger, data, trigger_dimensions, targets, rate, mode):
10 |     new_data = copy.deepcopy(data)
11 |     new_targets = copy.deepcopy(targets)
12 |     poison_indexes = np.random.permutation(len(new_data))[0: int(len(new_data) * rate)]
13 |     new_data, new_targets = add_trigger_to_data(args, logger, poison_indexes, new_data, trigger_dimensions, new_targets,
14 |                                                 rate, mode,
15 |                                                 replace_label=True)
16 |     return new_data, new_targets
17 | 
18 | 
19 | def attack_rsa(args, logger, data, trigger_dimensions, rate, mode):
20 |     new_data = copy.deepcopy(data)
21 |     poison_indexes = np.random.permutation(len(new_data))[0: int(len(new_data) * rate)]
22 |     new_data, _ = add_trigger_to_data(args, logger, poison_indexes, new_data, trigger_dimensions, [], rate, mode,
23 |                                       replace_label=False)
24 |     return new_data
25 | 
26 | 
27 | def attack_LFBA(args, logger, replace_indexes_others, replace_indexes_target, train_indexes, poison_indexes, data,
28 |                 target, trigger_dimensions, rate,
29 |                 mode):
30 |     if args.poison_all:
31 |         new_data, _ = add_trigger_to_data(args, logger, poison_indexes, data, trigger_dimensions, target, rate, mode,
32 |                                           replace_label=False)
33 |     else:
34 |         new_data, _ = add_trigger_to_data_replace(args, logger, replace_indexes_others, replace_indexes_target,
35 |                                                   train_indexes, poison_indexes, data, trigger_dimensions, target, rate,
36 |                                                   mode,
37 |                                                   replace_label=False)
38 |     return new_data
39 | 
40 | 
41 | def select_LFBA(train_features, num_poisons):
42 |     anchor_idx = get_anchor_LFBA(
train_features, num_poisons) 44 | anchor_feature = train_features[anchor_idx] 45 | 46 | poisoning_index = get_near_index( 47 | anchor_feature, train_features, num_poisons) 48 | poisoning_index = poisoning_index.cpu() 49 | 50 | return poisoning_index, anchor_idx 51 | 52 | 53 | def get_anchor_LFBA(train_features, num_poisons): 54 | consistency = train_features @ train_features.T 55 | w = torch.cat((torch.ones((num_poisons)), 56 | -torch.ones((num_poisons))), dim=0) 57 | top_con = torch.topk(consistency, 2 * num_poisons, dim=1)[0] 58 | mean_top_con = torch.matmul(top_con, w) 59 | idx = torch.argmax(mean_top_con) 60 | return idx 61 | 62 | 63 | def get_near_index(anchor_feature, train_features, num_poisons): 64 | anchor_feature_l1 = torch.norm(anchor_feature, p=1) 65 | train_features_l1 = torch.norm(train_features, p=1, dim=1) 66 | vals, indices = torch.topk(torch.div((train_features @ anchor_feature), (train_features_l1 * anchor_feature_l1)), k=num_poisons, dim=0) 67 | return indices 68 | -------------------------------------------------------------------------------- /attack/add_trigger.py: -------------------------------------------------------------------------------- 1 | import math 2 | from utils.utils import * 3 | 4 | 5 | def add_trigger_to_data(args, logger, poison_indexes, new_data, trigger_dimensions, new_targets, rate, mode, 6 | replace_label): 7 | mode_print(logger, mode) 8 | if args.dataset == 'CIFAR10': 9 | new_data, new_targets = add_triangle_pattern_trigger(args, logger, poison_indexes, new_data, new_targets, rate, 10 | mode, replace_label) 11 | return new_data, new_targets 12 | elif args.dataset == 'UCIHAR': 13 | new_data, new_targets = add_feature_trigger(args, logger, poison_indexes, trigger_dimensions, new_data, 14 | new_targets, rate, mode, 15 | replace_label) 16 | return new_data, new_targets 17 | elif args.dataset == 'PHISHING': 18 | new_data, new_targets = add_vector_replacement_trigger(args, logger, poison_indexes, trigger_dimensions, 19 | new_data, 20 | new_targets, rate, mode, 21 | replace_label) 22 | return new_data, new_targets 23 | elif args.dataset == 'NUSWIDE': 24 | new_data, new_targets = add_vector_replacement_trigger(args, logger, poison_indexes, trigger_dimensions, 25 | new_data, new_targets, 26 | rate, 27 | mode, replace_label) 28 | return new_data, new_targets 29 | 30 | 31 | def add_triangle_pattern_trigger(args, logger, poison_indexes, new_data, new_targets, rate, mode, replace_label): 32 | height, width, channels = new_data.shape[1:] 33 | for idx in poison_indexes: 34 | if replace_label and mode == 'train': 35 | new_targets[idx] = args.target_label 36 | for c in range(channels): 37 | new_data[idx, height - 3:, width - 3:, c] = 0 38 | new_data[idx, height - 3, width - 1, c] = 255 39 | new_data[idx, height - 1, width - 3, c] = 255 40 | new_data[idx, height - 2, width - 2, c] = 255 41 | new_data[idx, height - 1, width - 1, c] = 255 42 | logger.info( 43 | "Add Trigger to %d Poison Samples, %d Clean Samples (%.2f)" % ( 44 | len(poison_indexes), len(new_data) - len(poison_indexes), rate)) 45 | return new_data, new_targets 46 | 47 | 48 | def add_feature_trigger(args, logger, poison_indexes, trigger_dimensions, new_data, new_targets, rate, mode, 49 | replace_label=True): 50 | for idx in poison_indexes: 51 | if replace_label and mode == 'train': 52 | new_targets[idx] = args.target_label 53 | new_data[idx][trigger_dimensions] = args.trigger_feature_clip 54 | logger.info( 55 | "Add Trigger to %d Bad Samples, %d Clean Samples (%.2f)" % ( 56 | len(poison_indexes), 
len(new_data) - len(poison_indexes), rate)) 57 | return new_data, new_targets 58 | 59 | 60 | def add_vector_replacement_trigger(args, logger, poison_indexes, trigger_dimensions, new_data, new_targets, rate, mode, 61 | replace_label): 62 | for idx in poison_indexes: 63 | if replace_label and mode == 'train': 64 | new_targets[idx] = args.target_label 65 | new_data[idx][trigger_dimensions] = 1 66 | logger.info( 67 | "Add Trigger to %d Bad Samples, %d Clean Samples (%.2f)" % ( 68 | len(poison_indexes), len(new_data) - len(poison_indexes), rate)) 69 | return new_data, new_targets 70 | 71 | 72 | def mode_print(logger, mode): 73 | if mode == 'train': 74 | logger.info('=>Add Trigger to Train Data') 75 | else: 76 | logger.info('=>Add Trigger to Test Data') 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label-Free Backdoor Attacks in Vertical Federated Learning 2 | 3 | **The pytorch implementation of "Label-Free Backdoor Attacks in Vertical Federated Learning" (AAAI-25).** 4 | 5 | ![](./framework.svg) 6 | 7 | > [Label-Free Backdoor Attacks in Vertical Federated Learning](https://ojs.aaai.org/index.php/AAAI/article/view/34246) 8 | > 9 | > Wei Shen, Wenke Huang, Guancheng Wan, Mang Ye 10 | > 11 | > School of Computer Science, Wuhan University 12 | > 13 | > **Abstract** Vertical Federated Learning (VFL) involves multiple clients collaborating to train a global model with distributed features but shared samples. While it becomes a critical privacy-preserving learning paradigm, its security can be significantly compromised by backdoor attacks, where a malicious client injects a target backdoor by manipulating local data. Existing attack methods in VFL rely on the assumption that the malicious client can obtain additional knowledge about task labels, which is not applicable in VFL. In this work, we investigate a new backdoor attack paradigm in VFL, **L**abel-**F**ree **B**ackdoor **A**ttacks (**LFBA**), which does not require any additional label information and is feasible in VFL settings. Specifically, while existing methods assume access to task labels or target-class samples, we demonstrate that local embedding gradients reflect the semantic information of labels. It can guide the construction of the poison sample set from the backdoor target. Besides, we uncover that backdoor triggers tend to be ignored and under-fitted due to the learning of original features, which hinders backdoor task optimization. To address this, we propose selectively switching poison samples to disrupt feature learning, promoting backdoor task learning while maintaining accuracy on clean data. Extensive experiments demonstrate the effectiveness of our method in various settings. 14 | 15 | ## Requirements 16 | We use a single NVIDIA GeForce RTX 3090 for all evaluations. 
Clone the repository and install the dependencies from requirements.txt using an Anaconda environment:
17 | ```bash
18 | conda create -n LFBA python=3.9
19 | conda activate LFBA
20 | git clone 'https://github.com/shentt67/LFBA.git'
21 | cd LFBA
22 | pip install -r requirements.txt
23 | ```
24 | 
25 | ## Example Usage
26 | 
27 | For instance, to perform backdoor attacks with LFBA on the NUS-WIDE dataset, run:
28 | ```bash
29 | python main.py --device 0 --dataset NUSWIDE --epoch 100 --batch_size 256 --lr 0.001 --attack LFBA --anchor_idx 33930 --poison_rate 0.1 --poison_dimensions 10 --select_replace --select_rate 0.3
30 | ```
31 | 
32 | For the CIFAR-10 dataset:
33 | ```bash
34 | python main.py --device 0 --dataset CIFAR10 --epoch 100 --batch_size 256 --lr 0.001 --attack LFBA --anchor_idx 23470 --poison_rate 0.1 --select_replace --select_rate 0.5
35 | ```
36 | 
37 | Hyperparameter explanations:
38 | 
39 | **--device:** The ID of the GPU to be used.
40 | 
41 | **--dataset:** The evaluation dataset. We include ['NUSWIDE', 'UCIHAR', 'PHISHING', 'CIFAR10'].
42 | 
43 | **--epoch:** The number of training epochs.
44 | 
45 | **--batch_size:** The training batch size.
46 | 
47 | **--lr:** The learning rate.
48 | 
49 | **--attack:** The attack method. Set 'LFBA' for the proposed method.
50 | 
51 | **--anchor_idx:** The index of the anchor sample.
52 | 
53 | **--poison_rate:** The poison ratio ($p=\frac{N_p}{N}$ in the paper).
54 | 
55 | **--poison_dimensions:** The number of trigger dimensions, e.g., 10 means that 10 randomly selected feature dimensions on the attacker client are set to the fixed trigger value.
56 | 
57 | **--select_replace:** Add this flag to perform the attack with selective sample switching.
58 | 
59 | **--select_rate:** The switch ratio ($s=\frac{N_s}{N}$ in the paper).
60 | 
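61 | ## Method Sketch
62 | 
63 | The poison-set construction behind `--anchor_idx` and `--poison_rate` can be summarized in a few lines. The sketch below is illustrative (the function and variable names are ours, not part of the codebase) and mirrors `select_LFBA`/`get_near_index` in `attack/attack.py`: the gradients returned for the attacker client's local embeddings act as label proxies, and the samples whose gradients are most similar to the anchor's are chosen as the poison set:
64 | 
65 | ```python
66 | import torch
67 | 
68 | def build_poison_set(grad_features, anchor_idx, poison_rate):
69 |     # grad_features: (N, d) per-sample gradients w.r.t. the attacker's local
70 |     # embeddings, collected over one training epoch.
71 |     num_poisons = int(poison_rate * grad_features.shape[0])
72 |     anchor = grad_features[anchor_idx]
73 |     # Normalized similarity between each sample's gradient and the anchor's
74 |     # (the repository uses L1 norms here); samples sharing the anchor's
75 |     # unknown label tend to score highest.
76 |     sims = (grad_features @ anchor) / (
77 |         torch.norm(grad_features, p=1, dim=1) * torch.norm(anchor, p=1))
78 |     return torch.topk(sims, k=num_poisons, dim=0)[1]
79 | ```
80 | 
81 | In `utils/trainer.py` these gradients are read from `local_output_list[args.attack_client_num].grad` after the backward pass of the first epoch, so the attacker needs neither task labels nor target-class samples.
82 | 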
83 | ## Citation
84 | Please cite our work, thank you!
85 | 
86 | ```
87 | @inproceedings{shen2025label,
88 |   title={Label-free backdoor attacks in vertical federated learning},
89 |   author={Shen, Wei and Huang, Wenke and Wan, Guancheng and Ye, Mang},
90 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
91 |   year={2025}
92 | }
93 | ```
94 | Contact: [weishen@whu.edu.cn](mailto:weishen@whu.edu.cn)
--------------------------------------------------------------------------------
/attack/add_trigger_replace.py:
--------------------------------------------------------------------------------
1 | import copy
2 | 
3 | from utils.utils import *
4 | 
5 | 
6 | def add_trigger_to_data_replace(args, logger, replace_indexes_others, replace_indexes_target, train_indexes,
7 |                                 poison_indexes, new_data, trigger_dimensions, new_targets, rate, mode,
8 |                                 replace_label):
9 |     mode_print(logger, mode)
10 |     if args.dataset == 'CIFAR10':
11 |         new_data, new_targets = add_triangle_pattern_trigger(args, logger, replace_indexes_others,
12 |                                                              replace_indexes_target, train_indexes, poison_indexes,
13 |                                                              new_data,
14 |                                                              new_targets, rate,
15 |                                                              mode, replace_label)
16 |         return new_data, new_targets
17 |     elif args.dataset == 'UCIHAR':
18 |         new_data, new_targets = add_feature_trigger(args, logger, replace_indexes_others, replace_indexes_target,
19 |                                                     train_indexes, poison_indexes, trigger_dimensions, new_data,
20 |                                                     new_targets, rate, mode,
21 |                                                     replace_label)
22 |         return new_data, new_targets
23 |     elif args.dataset == 'PHISHING':
24 |         new_data, new_targets = add_vector_replacement_trigger(args, logger, replace_indexes_others,
25 |                                                                replace_indexes_target, train_indexes, poison_indexes,
26 |                                                                trigger_dimensions, new_data,
27 |                                                                new_targets, rate, mode,
28 |                                                                replace_label)
29 |         return new_data, new_targets
30 |     elif args.dataset == 'NUSWIDE':
31 |         new_data, new_targets = add_vector_replacement_trigger(args, logger, replace_indexes_others,
32 |                                                                replace_indexes_target, train_indexes, poison_indexes,
33 |                                                                trigger_dimensions, new_data, new_targets,
34 |                                                                rate,
35 |                                                                mode, replace_label)
36 |         return new_data, new_targets
37 | 
38 | 
39 | def add_triangle_pattern_trigger(args, logger, replace_indexes_others, replace_indexes_target, train_indexes,
40 |                                  poison_indexes, new_data, new_targets, rate, mode,
41 |                                  replace_label):
42 |     height, width, channels = new_data.shape[1:]
43 |     temp = copy.deepcopy(new_data)
44 |     for i, idx in enumerate(replace_indexes_others):
45 |         for c in range(channels):
46 |             temp[idx, height - 3:, width - 3:, c] = 0
47 |             temp[idx, height - 3, width - 1, c] = 255
48 |             temp[idx, height - 1, width - 3, c] = 255
49 |             temp[idx, height - 2, width - 2, c] = 255
50 |             temp[idx, height - 1, width - 1, c] = 255
51 |         new_data[replace_indexes_target[i], :, 16:, :] = temp[idx, :, 16:, :]
52 |     logger.info(
53 |         "Add Trigger to %d Poison Samples, %d Clean Samples (%.2f)" % (
54 |             len(poison_indexes), len(new_data) - len(poison_indexes), rate))
55 |     return new_data, new_targets
56 | 
57 | 
58 | def add_feature_trigger(args, logger, replace_indexes_others, replace_indexes_target, train_indexes, poison_indexes,
59 |                         trigger_dimensions, new_data, new_targets,
60 |                         rate, mode,
61 |                         replace_label=True):
62 |     temp = copy.deepcopy(new_data)
63 |     for i, idx in enumerate(replace_indexes_others):
64 |         temp[idx][trigger_dimensions] = args.trigger_feature_clip
65 |         if args.dataset == 'UCIHAR':
66 |             new_data[replace_indexes_target[i]][281:] = temp[idx][281:]
67 |     logger.info(
68 |         "Add Trigger to %d Bad Samples, %d Clean Samples (%.2f)" % (
69 |             len(poison_indexes), len(new_data) - 
len(poison_indexes), rate)) 70 | return new_data, new_targets 71 | 72 | 73 | def add_vector_replacement_trigger(args, logger, replace_indexes_others, replace_indexes_target, train_indexes, 74 | poison_indexes, trigger_dimensions, new_data, 75 | new_targets, rate, mode, replace_label): 76 | temp = copy.deepcopy(new_data) 77 | if args.dataset == 'PHISHING': 78 | for i, idx in enumerate(replace_indexes_others): 79 | temp[idx][trigger_dimensions] = 1 80 | new_data[replace_indexes_target[i]][15:] = temp[idx][15:] 81 | elif args.dataset == 'NUSWIDE': 82 | for i, idx in enumerate(replace_indexes_others): 83 | temp[idx][trigger_dimensions] = 1 84 | new_data[replace_indexes_target[i]][634:] = temp[idx][634:] 85 | logger.info( 86 | "Add Trigger to %d Bad Samples, %d Clean Samples (%.2f)" % ( 87 | len(poison_indexes), len(new_data) - len(poison_indexes), rate)) 88 | return new_data, new_targets 89 | 90 | 91 | def mode_print(logger, mode): 92 | if mode == 'train': 93 | logger.info('=>Add Trigger to Train Data') 94 | else: 95 | logger.info('=>Add Trigger to Test Data') 96 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from random import sample 3 | 4 | import pandas as pd 5 | from sklearn.preprocessing import StandardScaler 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets import CIFAR10 8 | from .utils import get_labeled_data 9 | from PIL import Image 10 | import os 11 | import os.path 12 | import numpy as np 13 | import pickle 14 | import torch 15 | from typing import Any, Callable, Optional, Tuple 16 | 17 | 18 | class CIFAR10_VFL(CIFAR10): 19 | def __init__( 20 | self, 21 | root: str, 22 | train: bool = True, 23 | transform: Optional[Callable] = None, 24 | target_transform: Optional[Callable] = None, 25 | download: bool = False, 26 | ) -> None: 27 | 28 | super(CIFAR10, self).__init__(root, transform=transform, 29 | target_transform=target_transform) 30 | 31 | self.train = train # training set or test set 32 | 33 | if download: 34 | self.download() 35 | 36 | if not self._check_integrity(): 37 | raise RuntimeError('Dataset not found or corrupted.' 
+ 38 | ' You can use download=True to download it') 39 | 40 | if self.train: 41 | downloaded_list = self.train_list 42 | else: 43 | downloaded_list = self.test_list 44 | 45 | self.data: Any = [] 46 | self.targets = [] 47 | 48 | # now load the picked numpy arrays 49 | for file_name, checksum in downloaded_list: 50 | file_path = os.path.join(self.root, self.base_folder, file_name) 51 | with open(file_path, 'rb') as f: 52 | entry = pickle.load(f, encoding='latin1') 53 | self.data.append(entry['data']) 54 | if 'labels' in entry: 55 | self.targets.extend(entry['labels']) 56 | else: 57 | self.targets.extend(entry['fine_labels']) 58 | 59 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 60 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 61 | 62 | self._load_meta() 63 | 64 | self.data_p = copy.deepcopy(self.data) 65 | 66 | def __getitem__(self, index): 67 | img, img_p, target = self.data[index], self.data_p[index], self.targets[index] 68 | 69 | img = Image.fromarray(img) 70 | 71 | img_poisoned = Image.fromarray(img_p) 72 | 73 | if self.transform is not None: 74 | img = self.transform(img) 75 | img_poisoned = self.transform(img_poisoned) 76 | 77 | if self.target_transform is not None: 78 | target = self.target_transform(target) 79 | 80 | return img, img_poisoned, target, index 81 | 82 | 83 | class UCIHAR_VFL(Dataset): 84 | def __init__(self, root, train, transforms): 85 | if train: 86 | self.data = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/train/X_train.txt') 87 | self.data_p = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/train/X_train.txt') 88 | self.targets = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/train/y_train.txt') - 1 89 | else: 90 | self.data = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/test/X_test.txt') 91 | self.data_p = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/test/X_test.txt') 92 | self.targets = np.loadtxt(root + '/UCI-HAR/UCI HAR Dataset/test/y_test.txt') - 1 93 | 94 | def __getitem__(self, index): 95 | x = self.data[index] 96 | x_poisoned = self.data_p[index] 97 | y = self.targets[index] 98 | return x, x_poisoned, y, index 99 | 100 | def __len__(self): 101 | return len(self.data) 102 | 103 | 104 | class NUSWIDE_VFL(Dataset): 105 | def __init__(self, root, selected_labels, train, transforms): 106 | if train: 107 | X_image, X_text, Y = get_labeled_data(root + 'NUS_WIDE', selected_labels, None, 'Train') 108 | self.data = torch.cat((torch.tensor(X_image), torch.tensor(X_text)), dim=1) 109 | self.data_p = torch.cat((torch.tensor(X_image), torch.tensor(X_text)), dim=1) 110 | self.targets = torch.tensor(Y) 111 | else: 112 | X_image, X_text, Y = get_labeled_data(root + 'NUS_WIDE', selected_labels, None, 'Test') 113 | self.data = torch.cat((torch.tensor(X_image), torch.tensor(X_text)), dim=1) 114 | self.data_p = torch.cat((torch.tensor(X_image), torch.tensor(X_text)), dim=1) 115 | self.targets = torch.tensor(Y) 116 | 117 | def __getitem__(self, index): 118 | x = self.data[index] 119 | x_poisoned = self.data_p[index] 120 | y = self.targets[index] 121 | return x, x_poisoned, y, index 122 | 123 | def __len__(self): 124 | return len(self.data) 125 | 126 | 127 | class PHISHING_VFL(Dataset): 128 | def __init__(self, root, train, transforms): 129 | data = pd.read_csv(root + 'Phishing/CM1.csv') 130 | drop_cols = ['Result'] 131 | X = data.drop(drop_cols, axis=1) 132 | y = data['Result'].to_numpy() 133 | scaler = StandardScaler() 134 | X = scaler.fit_transform(X) 135 | y = y.reshape((len(y), 1)) 136 | X = torch.tensor(X) 137 | y = torch.tensor(y) 138 | indexes_list = 
np.array(range(len(X))) 139 | train_indexes = sample(list(indexes_list), 8844) 140 | test_indexes = list(set(list(indexes_list)).difference(set(train_indexes))) 141 | train_data, test_data = X[train_indexes], X[test_indexes] 142 | train_target, test_target = y[train_indexes].reshape(-1), y[test_indexes].reshape(-1) 143 | if train: 144 | self.data = train_data 145 | self.data_p = copy.deepcopy(train_data) 146 | self.targets = train_target 147 | else: 148 | self.data = test_data 149 | self.data_p = copy.deepcopy(test_data) 150 | self.targets = test_target 151 | 152 | def __getitem__(self, index): 153 | x = self.data[index] 154 | x_poisoned = self.data_p[index] 155 | y = self.targets[index] 156 | return x, x_poisoned, y, index 157 | 158 | def __len__(self): 159 | return len(self.data) 160 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | from datetime import datetime 4 | import logging 5 | 6 | import math 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | from dataset.dataset import CIFAR10_VFL, UCIHAR_VFL, PHISHING_VFL, NUSWIDE_VFL 10 | from models.CIFAR10_models import GlobalModelForCifar10, LocalModelForCifar10 11 | from models.UCIHAR_models import GlobalModelForUCIHAR, LocalModelForUCIHAR 12 | from models.NUSWIDE_models import GlobalModelForNUSWIDE, LocalModelForNUSWIDE 13 | from models.PHISHING_models import GlobalModelForPHISHING, LocalModelForPHISHING 14 | from utils.trainer import Trainer 15 | from attack.attack import attack_lra, attack_rsa 16 | from utils.trigger_visualization import trigger_visualization 17 | from utils.utils import * 18 | 19 | 20 | def main(args): 21 | device = torch.device("cuda:" + str(args.device) if torch.cuda.is_available() else "cpu") 22 | 23 | # create logger 24 | logger = logging.getLogger(__name__) 25 | logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.DEBUG) 26 | logger.setLevel(level=logging.DEBUG) 27 | 28 | # create handler for writing logs 29 | if not os.path.isdir(args.results_dir): 30 | os.mkdir(args.results_dir) 31 | fh = logging.FileHandler(args.results_dir + '/experiment.log') 32 | 33 | # add handler 34 | logger.addHandler(fh) 35 | 36 | # params 37 | logger.info(args) 38 | 39 | # create dataset 40 | logger.info("=> Preparing Data...") 41 | if args.dataset == 'CIFAR10': 42 | # data transform 43 | transform_train = transforms.Compose([ 44 | transforms.RandomCrop(32, padding=4), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 48 | ]) 49 | 50 | transform_test = transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 53 | ]) 54 | # return data, label, index 55 | train_data = CIFAR10_VFL(root=args.data_dir, train=True, download=True, transform=transform_train) 56 | test_data = CIFAR10_VFL(root=args.data_dir, train=False, download=True, transform=transform_test) 57 | test_data_asr = copy.deepcopy(test_data) 58 | elif args.dataset == 'UCIHAR': 59 | # return data, label, index 60 | train_data = UCIHAR_VFL(root=args.data_dir, train=True, transforms=None) 61 | test_data = UCIHAR_VFL(root=args.data_dir, train=False, transforms=None) 62 | test_data_asr = copy.deepcopy(test_data) 63 | elif args.dataset == 'PHISHING': 64 | # return data, label, index 65 | train_data = 
PHISHING_VFL(root=args.data_dir, train=True, transforms=None) 66 | test_data = PHISHING_VFL(root=args.data_dir, train=False, transforms=None) 67 | test_data_asr = copy.deepcopy(test_data) 68 | elif args.dataset == 'NUSWIDE': 69 | # return data, label, index 70 | selected_labels = ['buildings', 'grass', 'animal', 'water', 'person'] 71 | train_data = NUSWIDE_VFL(root=args.data_dir, selected_labels=selected_labels, train=True, transforms=None) 72 | test_data = NUSWIDE_VFL(root=args.data_dir, selected_labels=selected_labels, train=False, transforms=None) 73 | test_data_asr = copy.deepcopy(test_data) 74 | else: 75 | raise_dataset_exception() 76 | 77 | if args.trigger_visualization: 78 | trigger_visualization(logger, args, train_data, test_data) 79 | return 80 | 81 | # build vfl models 82 | if args.dataset == 'CIFAR10': 83 | model_list = [] 84 | extractor_list = [] 85 | model_list.append(GlobalModelForCifar10(args)) 86 | extractor_list.append(GlobalModelForCifar10(args)) 87 | for i in range(args.client_num): 88 | model_list.append(LocalModelForCifar10(args)) 89 | extractor_list.append(LocalModelForCifar10(args)) 90 | optimizer_list = [torch.optim.Adam(model.parameters(), lr=args.lr) for model in model_list] 91 | criterion = nn.CrossEntropyLoss().to(device) 92 | elif args.dataset == 'UCIHAR': 93 | model_list = [] 94 | extractor_list = [] 95 | model_list.append(GlobalModelForUCIHAR(args)) 96 | extractor_list.append(GlobalModelForUCIHAR(args)) 97 | for i in range(args.client_num): 98 | model_list.append(LocalModelForUCIHAR(args, i)) 99 | extractor_list.append(LocalModelForUCIHAR(args, i)) 100 | optimizer_list = [torch.optim.Adam(model.parameters(), lr=args.lr) for model in model_list] 101 | criterion = nn.CrossEntropyLoss().to(device) 102 | elif args.dataset == 'PHISHING': 103 | model_list = [] 104 | extractor_list = [] 105 | model_list.append(GlobalModelForPHISHING(args)) 106 | extractor_list.append(GlobalModelForPHISHING(args)) 107 | for i in range(args.client_num): 108 | model_list.append(LocalModelForPHISHING(args, i)) 109 | extractor_list.append(LocalModelForPHISHING(args, i)) 110 | optimizer_list = [torch.optim.Adam(model.parameters(), lr=args.lr) for model in model_list] 111 | criterion = nn.CrossEntropyLoss().to(device) 112 | elif args.dataset == 'NUSWIDE': 113 | model_list = [] 114 | extractor_list = [] 115 | model_list.append(GlobalModelForNUSWIDE(args)) 116 | extractor_list.append(GlobalModelForNUSWIDE(args)) 117 | for i in range(args.client_num): 118 | model_list.append(LocalModelForNUSWIDE(args, i)) 119 | extractor_list.append(LocalModelForNUSWIDE(args, i)) 120 | optimizer_list = [torch.optim.Adam(model.parameters(), lr=args.lr) for model in model_list] 121 | criterion = nn.CrossEntropyLoss().to(device) 122 | else: 123 | raise_dataset_exception() 124 | model_list = [model.to(device) for model in model_list] 125 | extractor_list = [model.to(device) for model in extractor_list] 126 | # test 127 | if args.test_checkpoint: 128 | if os.path.isfile(args.test_checkpoint): 129 | logger.info("=> loading test checkpoint '{}'".format(args.test_checkpoint)) 130 | checkpoint_test = torch.load(args.test_checkpoint, map_location=device) 131 | for i in range(len(model_list)): 132 | model_list[i].load_state_dict(checkpoint_test['state_dict'][i]) 133 | optimizer_list[i].load_state_dict(checkpoint_test['optimizer'][i]) 134 | logger.info("=> loaded test checkpoint '{}' (epoch {}, best accuracy: {:.4f})" 135 | .format(args.test_checkpoint, checkpoint_test['epoch'], checkpoint_test['best_acc'])) 136 | else: 
137 | logger.info("=> no test checkpoint found at '{}'".format(args.test_checkpoint)) 138 | # train from checkpoints 139 | checkpoint = None 140 | if args.pretrained_checkpoint: 141 | if os.path.isfile(args.pretrained_checkpoint): 142 | logger.info("=> loading checkpoint '{}'".format(args.pretrained_checkpoint)) 143 | checkpoint = torch.load(args.pretrained_checkpoint, map_location=device) 144 | args.start_epoch = checkpoint['epoch'] 145 | for i in range(len(model_list)): 146 | model_list[i].load_state_dict(checkpoint['state_dict'][i]) 147 | optimizer_list[i].load_state_dict(checkpoint['optimizer'][i]) 148 | logger.info("=> loaded checkpoint '{}' (epoch {}, best accuracy: {:.4f})" 149 | .format(args.pretrained_checkpoint, checkpoint['epoch'], checkpoint['best_acc'])) 150 | else: 151 | logger.info("=> no checkpoint found at '{}'".format(args.pretrained_checkpoint)) 152 | # load feature extractor 153 | if args.feature_extractor: 154 | if os.path.isfile(args.feature_extractor): 155 | logger.info("=> loading checkpoint '{}'".format(args.feature_extractor)) 156 | extractor_checkpoint = torch.load(args.feature_extractor, map_location=device) 157 | for i in range(len(extractor_list)): 158 | extractor_list[i].load_state_dict(extractor_checkpoint['state_dict'][i]) 159 | logger.info("=> loaded checkpoint '{}' (epoch {}, best accuracy: {:.4f})" 160 | .format(args.feature_extractor, extractor_checkpoint['epoch'], 161 | extractor_checkpoint['best_acc'])) 162 | else: 163 | logger.info("=> no checkpoint found at '{}'".format(args.feature_extractor)) 164 | 165 | if args.dataset == 'CIFAR10': 166 | trigger_dimensions = [] 167 | pass 168 | elif args.dataset == 'UCIHAR': 169 | ranges = range(math.ceil(train_data.data.shape[1] / args.client_num), 170 | train_data.data.shape[1]) 171 | trigger_dimensions = np.random.choice(ranges, args.poison_dimensions, replace=False) 172 | elif args.dataset == 'PHISHING': 173 | ranges = range(15, 30) 174 | trigger_dimensions = np.random.choice(ranges, args.poison_dimensions, replace=False) 175 | elif args.dataset == 'NUSWIDE': 176 | ranges = range(634, 1634) 177 | trigger_dimensions = np.random.choice(ranges, args.poison_dimensions, replace=False) 178 | else: 179 | raise_dataset_exception() 180 | 181 | if args.attack is None: 182 | test_data_asr.data = attack_rsa(args, logger, test_data_asr.data, trigger_dimensions, 1, 'test') 183 | elif args.attack == 'rsa': 184 | train_data.data = attack_rsa(args, logger, train_data.data, trigger_dimensions, args.poison_rate, 'train') 185 | test_data_asr.data = attack_rsa(args, logger, test_data_asr.data, trigger_dimensions, 1, 'test') 186 | elif args.attack == 'lra': 187 | train_data.data, train_data.targets = attack_lra(args, logger, train_data.data, trigger_dimensions, 188 | train_data.targets, 189 | args.poison_rate, 'train') 190 | test_data_asr.data, _ = attack_lra(args, logger, test_data_asr.data, trigger_dimensions, test_data_asr.targets, 191 | 1, 'test') 192 | elif args.attack == 'LFBA': 193 | test_data_asr.data = attack_rsa(args, logger, test_data_asr.data, trigger_dimensions, 1, 'test') 194 | else: 195 | raise_attack_exception() 196 | 197 | train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True) 198 | test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=True) 199 | test_asr_loader = torch.utils.data.DataLoader(dataset=test_data_asr, batch_size=args.batch_size, shuffle=True) 200 | trainer = Trainer(device, model_list, extractor_list, 
                      extractor_list[args.attack_client_num + 1], optimizer_list,
201 |                       criterion, train_loader, test_loader,
202 |                       test_asr_loader, trigger_dimensions, logger, args,
203 |                       checkpoint)
204 |     if args.test_checkpoint:
205 |         if args.do_tsne:
206 |             train_data_add_trigger = copy.deepcopy(train_data)
207 |             train_data_add_trigger.data = attack_rsa(args, logger, train_data_add_trigger.data, trigger_dimensions, 1,
208 |                                                      'test')
209 |             tsne_loader = torch.utils.data.DataLoader(dataset=train_data_add_trigger,
210 |                                                       batch_size=train_data_add_trigger.data.shape[0])
211 |             trainer.perform_tsne(tsne_loader)
212 |             return
213 |         else:
214 |             trainer.test(0)
215 |             return
216 |     trainer.train()
217 | 
218 | 
219 | if __name__ == '__main__':
220 |     currentDateAndTime = datetime.now()
221 |     parser = argparse.ArgumentParser()
222 | 
223 |     parser.add_argument('--data_dir', default='/data/data_raw/', help='data directory')
224 |     parser.add_argument('--dataset', default='NUSWIDE', help='name of dataset')
225 |     parser.add_argument('--device', default=0, type=int, help='GPU number')
226 |     parser.add_argument('--results_dir', default='/data/data_raw/vfl_baseline/logs/' + str(currentDateAndTime),
227 |                         help='the result directory')
228 |     parser.add_argument('--seed', default=100, type=int, help='the seed')
229 |     parser.add_argument('--epoch', default=100, type=int, help='number of training epochs')
230 |     parser.add_argument('--batch_size', default=256, type=int, help='the batch size')
231 |     parser.add_argument('--client_num', default=2, type=int, help='the number of clients')
232 |     parser.add_argument('--pretrained_checkpoint', default=None, help='the checkpoint file')
233 |     parser.add_argument('--lr', default=0.001, type=float, help='the learning rate')
234 |     parser.add_argument('--start_epoch', default=0, type=int, help='the epoch number of starting training')
235 |     parser.add_argument('--print_steps', default=10, type=int, help='the print step of training logging')
236 |     parser.add_argument('--early_stop', default=20, type=int, help='the early stop epoch')
237 |     parser.add_argument('--attack', default=None,
238 |                         help='attack method')  # None: baseline, lra: label replacement attack, rsa: random sample attack, LFBA: the proposed attack
239 |     parser.add_argument('--target_label', default=3, type=int, help='the target label for backdoor')
240 |     parser.add_argument('--poison_rate', default=0.1, type=float, help='the rate of poison samples')
241 |     parser.add_argument('--trigger_visualization', default=False, type=bool, help='visualize the trigger')
242 |     parser.add_argument('--poison_dimensions', default=5, type=int, help='the dimensions to be poisoned')
243 |     parser.add_argument('--trigger_feature_clip', default=1, type=float, help='the clip ratio of feature trigger')
244 |     parser.add_argument('--attack_client_num', default=1, type=int, help='the adversary client')
245 |     parser.add_argument('--num_clusters', default=10, type=int, help='the number of clusters')
246 |     parser.add_argument('--feature_extractor', default='', help='the feature extractor path')
247 |     parser.add_argument('--select_rate', default=1, type=float)
248 |     parser.add_argument('--random_select', action='store_true')
249 |     parser.add_argument('--select_replace', action='store_true')
250 |     parser.add_argument('--poison_all', action='store_true')
251 |     parser.add_argument('--anchor_idx', default=33930, type=int)
252 |     parser.add_argument('--test_checkpoint', default=None)
253 |     parser.add_argument('--pretrain_stage', default=0, type=int, help='epochs before early-stop counting starts (referenced by the trainer)')
254 |     parser.add_argument('--do_tsne', action='store_true', help='visualize embeddings with t-SNE when testing a checkpoint')
255 | 
256 |     args = parser.parse_args()
257 | 
258 |     args.results_dir = '/data/data_sw/vfl_baseline/logs/' + args.dataset + '/' + str(currentDateAndTime)
259 | 
260 |     # set seed
261 |     set_seed(args.seed)
262 | 
263 |     main(args)
264 | 
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from random import sample
3 | 
4 | from dataset.utils import split_vfl
5 | from attack.attack import attack_LFBA, get_near_index
6 | from utils.utils import *
7 | import time
8 | 
9 | 
10 | class Trainer:
11 |     def __init__(self, device, model_list, extractor_list, extractor, optimizer_list, criterion, train_loader,
12 |                  test_loader,
13 |                  test_asr_loader, trigger_dimensions,
14 |                  logger, args=None, checkpoint=None):
15 |         self.device = device
16 |         self.model_list = model_list
17 |         self.extractor_list = extractor_list
18 |         self.extractor = extractor
19 |         self.optimizer_list = optimizer_list
20 |         self.criterion = criterion
21 |         self.train_loader = train_loader
22 |         self.test_loader = test_loader
23 |         self.test_asr_loader = test_asr_loader
24 |         self.logger = logger
25 |         self.args = args
26 |         self.checkpoint = checkpoint
27 |         self.trigger_dimensions = trigger_dimensions
28 | 
29 |     def adjust_learning_rate(self, epoch):
30 |         lr = self.args.lr * (0.1) ** (epoch // 20)
31 |         for opt in self.optimizer_list:
32 |             for param_group in opt.param_groups:
33 |                 param_group['lr'] = lr
34 | 
35 |     def train(self):
36 |         start_time_train = time.time()
37 |         if self.args.attack:
38 |             self.logger.info("=> Start Training with {}...".format(self.args.attack))
39 |             if self.args.pretrain_stage:
40 |                 self.logger.info("=> Pretrain...")
41 |         else:
42 |             self.logger.info("=> Start Training Baseline...")
43 |         epoch_loss_list = []
44 |         model_list = self.model_list
45 |         model_list = [model.train() for model in model_list]
46 |         best_acc = 0
47 |         best_trade_off = 0
48 |         best_epoch = 0
49 |         asr_for_best_epoch = 0
50 |         target_for_best_epoch = 0
51 |         no_change = 0
52 |         total_time_GPC = 0
53 |         total_time_HS = 0
54 |         self.select_his = torch.zeros(self.train_loader.dataset.data.shape[0])
55 |         if self.checkpoint:
56 |             best_acc = self.checkpoint['best_acc']
57 |         # train and update
58 |         for ep in range(self.args.start_epoch, self.args.epoch):
59 |             batch_loss_list = []
60 |             total = 0
61 |             correct = 0
62 |             if ep >= 1 and self.args.attack == 'LFBA':
63 |                 self.train_features, self.train_labels, self.train_indexes = self.grad_vec_epoch, self.target_epoch, self.indexes_epoch
64 |                 self.train_features, self.train_labels, self.train_indexes = self.train_features.cpu(), self.train_labels.cpu(), self.train_indexes.cpu()
65 |                 self.num_poisons = int(self.args.poison_rate * len(self.train_loader.dataset.data))
66 |                 self.num_select = int(self.num_poisons * self.args.select_rate)
67 | 
68 |                 # select sample set
69 |                 if ep == 1:
70 |                     start_time = time.time()
71 |                     self.anchor_idx_t = torch.nonzero(self.train_indexes == self.args.anchor_idx).squeeze()
72 |                     self.indexes = get_near_index(self.train_features[self.anchor_idx_t], self.train_features,
73 |                                                   self.num_poisons)
74 |                     end_time = time.time()
75 |                     print("The poison set construction time: {}".format((end_time - start_time)))
76 |                     total_time_GPC += (end_time - start_time)
77 |                     self.poison_indexes = self.train_indexes[self.indexes]
78 |                     self.consistent_rate = float(
79 |                         (self.train_labels[self.indexes] == int(self.train_labels[self.anchor_idx_t])).sum() / len(
80 |                             self.indexes))
81 | 
82 |                 # For replace poisoning
83 |                 self.indexes = np.isin(self.train_indexes.numpy(), torch.tensor(self.poison_indexes).numpy())
84 | 
temp = np.array(range(len(self.train_indexes))) 85 | self.indexes = temp[self.indexes] 86 | self.l2_norm_features = torch.norm(self.train_features[self.indexes], p=2, dim=1) 87 | start_time = time.time() 88 | self.poison_features, self.select_indexes = self.l2_norm_features.topk(self.num_select, dim=0, 89 | largest=True, 90 | sorted=True) 91 | end_time = time.time() 92 | print("The hard-sample selection time: {}".format((end_time - start_time))) 93 | total_time_HS += (end_time - start_time) 94 | num_of_replace = int(len(self.poison_indexes) * self.args.select_rate) 95 | replace_all_list = list(set(self.train_indexes.numpy()).difference(set(torch.tensor(self.poison_indexes).numpy()))) 96 | replace_indexes_others = sample(replace_all_list, num_of_replace) 97 | random_indexes_target = sample(list(self.poison_indexes), num_of_replace) 98 | selected_indexes_target = self.train_indexes[self.indexes[self.select_indexes]] 99 | 100 | if self.args.poison_all: 101 | if self.args.random_select: 102 | self.poison_indexes_t = sample(list(self.poison_indexes), self.num_select) 103 | self.indexes = np.isin(self.train_indexes.numpy(), torch.tensor(self.poison_indexes_t).numpy()) 104 | self.poisoning_labels = np.array(self.train_labels)[self.indexes] 105 | self.anchor_label = int(self.train_labels[self.train_indexes == self.args.anchor_idx]) 106 | self.args.target_label = self.anchor_label 107 | self.logger.info('Target label:{}'.format(self.anchor_label)) 108 | self.clean_data_p = copy.deepcopy(self.train_loader.dataset.data_p) 109 | if self.args.random_select: 110 | self.train_loader.dataset.data = attack_LFBA(self.args, self.logger, [], 111 | [], self.train_indexes, 112 | self.poison_indexes_t, 113 | self.clean_data_p, self.train_loader.dataset.targets, 114 | self.trigger_dimensions, 115 | self.args.poison_rate, 'train') 116 | else: 117 | self.train_loader.dataset.data = attack_LFBA(self.args, self.logger, [], 118 | [], self.train_indexes, 119 | self.poison_indexes, 120 | self.clean_data_p, 121 | self.train_loader.dataset.targets, 122 | self.trigger_dimensions, 123 | self.args.poison_rate, 'train') 124 | else: 125 | if self.args.random_select: 126 | replace_indexes_target = random_indexes_target 127 | else: 128 | replace_indexes_target = selected_indexes_target 129 | self.poisoning_labels = np.array(self.train_labels)[self.indexes] 130 | self.anchor_label = int(self.train_labels[self.train_indexes == self.args.anchor_idx]) 131 | self.clean_data_p = copy.deepcopy(self.train_loader.dataset.data_p) 132 | self.train_loader.dataset.data = attack_LFBA(self.args, self.logger, replace_indexes_others, 133 | replace_indexes_target, self.train_indexes, 134 | self.poison_indexes, 135 | self.clean_data_p, 136 | self.train_loader.dataset.targets, 137 | self.trigger_dimensions, 138 | self.args.poison_rate, 'train') 139 | self.args.target_label = self.anchor_label 140 | self.logger.info('Target label:{}'.format(self.anchor_label)) 141 | 142 | elif self.args.attack == 'rsa' or self.args.attack == 'lra' or self.args.attack is None: 143 | pass 144 | 145 | self.logger.info("=> Start Training for Injecting Backdoor...") 146 | 147 | self.grad_vec_epoch = [] 148 | self.indexes_epoch = [] 149 | self.target_epoch = [] 150 | for step, (x_n, x_p, y, index) in enumerate(self.train_loader): 151 | x = x_n 152 | x = x.to(self.device).float() 153 | y = y.to(self.device).long() 154 | # split data for vfl 155 | x_split_list = split_vfl(x, self.args) 156 | local_output_list = [] 157 | global_input_list = [] 158 | # get the local model 
outputs 159 | for i in range(self.args.client_num): 160 | local_output_list.append(model_list[i + 1](x_split_list[i])) 161 | # get the global model inputs, recording the gradients 162 | for i in range(self.args.client_num): 163 | global_input_t = local_output_list[i].detach().clone() 164 | global_input_t.requires_grad_(True) 165 | global_input_list.append(global_input_t) 166 | local_output_list[i].requires_grad_(True) 167 | local_output_list[i].retain_grad() 168 | x_split_list[i].requires_grad_(True) 169 | x_split_list[i].retain_grad() 170 | 171 | global_output = model_list[0](local_output_list) 172 | 173 | # global model backward 174 | loss = self.criterion(global_output, y) 175 | for opt in self.optimizer_list: 176 | opt.zero_grad() 177 | 178 | loss.backward() 179 | 180 | if self.args.attack == 'LFBA': 181 | self.grad_vec_epoch.append(local_output_list[self.args.attack_client_num].grad.to(self.device)) 182 | self.indexes_epoch.append(index) 183 | self.target_epoch.append(y) 184 | 185 | for opt in self.optimizer_list: 186 | opt.step() 187 | batch_loss_list.append(loss.item()) 188 | 189 | # calculate the training accuracy 190 | _, predicted = global_output.max(1) 191 | total += y.size(0) 192 | correct += predicted.eq(y).sum().item() 193 | 194 | # train_acc 195 | train_acc = correct / total 196 | current_loss = sum(batch_loss_list) / len(batch_loss_list) 197 | 198 | if step % self.args.print_steps == 0: 199 | self.logger.info( 200 | 'Epoch: {}, {}/{}: train loss: {:.4f}, train main task accuracy: {:.4f}'.format(ep + 1, 201 | step + 1, 202 | len(self.train_loader), 203 | current_loss, 204 | train_acc)) 205 | if self.args.attack == 'LFBA': 206 | self.grad_vec_epoch = torch.cat(self.grad_vec_epoch) 207 | self.indexes_epoch = torch.cat(self.indexes_epoch) 208 | self.target_epoch = torch.cat(self.target_epoch) 209 | 210 | epoch_loss = sum(batch_loss_list) / len(batch_loss_list) 211 | epoch_loss_list.append(epoch_loss) 212 | self.adjust_learning_rate(ep + 1) 213 | test_acc, test_poison_accuracy, test_target, test_asr = self.test(ep) 214 | test_trade_off = (test_acc + test_asr) / 2 215 | if test_trade_off > best_trade_off: 216 | # best accuracy 217 | best_acc = test_acc 218 | best_trade_off = test_trade_off 219 | poison_acc_for_best_epoch = test_poison_accuracy 220 | asr_for_best_epoch = test_asr 221 | target_for_best_epoch = test_target 222 | no_change = 0 223 | best_epoch = ep 224 | # save model 225 | self.logger.info("=> Save best model...") 226 | state = { 227 | 'epoch': ep + 1, 228 | 'best_acc': best_acc, 229 | 'test_trade_off': test_trade_off, 230 | 'test_target': target_for_best_epoch, 231 | 'poison_acc': poison_acc_for_best_epoch, 232 | 'asr': asr_for_best_epoch, 233 | 'state_dict': [model_list[i].state_dict() for i in range(len(model_list))], 234 | 'optimizer': [self.optimizer_list[i].state_dict() for i in range(len(self.optimizer_list))], 235 | } 236 | filename = os.path.join(self.args.results_dir, 'best_checkpoint.pth.tar'.format(ep + 1)) 237 | torch.save(state, filename) 238 | else: 239 | if ep > self.args.pretrain_stage: 240 | no_change += 1 241 | self.logger.info( 242 | '=> End Epoch: {}, early stop epochs: {}, best epoch: {}, best trade off accuracy: {:.4f}, main task accuracy: {:.4f}, test target accuracy: {:.4f}, test asr: {:.4f}'.format( 243 | ep + 1, 244 | no_change, 245 | best_epoch + 1, best_trade_off, best_acc, target_for_best_epoch, asr_for_best_epoch)) 246 | if no_change == self.args.early_stop: 247 | end_time_train = time.time() 248 | print("The total training time: 
{}".format((end_time_train - start_time_train))) 249 | print("The average training time of each epoch: {}".format(((end_time_train - start_time_train)) / (ep + 1))) 250 | print("The poison set construction time: {}".format(total_time_GPC)) 251 | print("The average hard-sample selection time: {}".format(total_time_HS / (ep + 1))) 252 | print("The total hard-sample selection time: {}".format(total_time_HS)) 253 | return 254 | 255 | 256 | 257 | def test(self, ep): 258 | self.logger.info("=> Test ASR...") 259 | model_list = self.model_list 260 | model_list = [model.eval() for model in model_list] 261 | # test main task accuracy 262 | batch_loss_list = [] 263 | total = 0 264 | correct = 0 265 | total_target = 0 266 | correct_target = 0 267 | for step, (x, x_p, y, index) in enumerate(self.test_loader): 268 | x = x.to(self.device).float() 269 | y = y.to(self.device).long() 270 | # split data for vfl 271 | x_split_list = split_vfl(x, self.args) 272 | local_output_list = [] 273 | global_input_list = [] 274 | # get the local model outputs 275 | for i in range(self.args.client_num): 276 | local_output_list.append(model_list[i + 1](x_split_list[i])) 277 | # get the global model inputs, recording the gradients 278 | for i in range(self.args.client_num): 279 | global_input_t = local_output_list[i].detach().clone() 280 | global_input_t.requires_grad_(True) 281 | global_input_list.append(global_input_t) 282 | 283 | global_output = model_list[0](local_output_list) 284 | 285 | # global model backward 286 | loss = self.criterion(global_output, y) 287 | batch_loss_list.append(loss.item()) 288 | 289 | # calculate the testing accuracy 290 | _, predicted = global_output.max(1) 291 | total += y.size(0) 292 | correct += predicted.eq(y).sum().item() 293 | total_target += (y == self.args.target_label).float().sum() 294 | correct_target += predicted.eq(y)[y == self.args.target_label].float().sum().item() 295 | 296 | # test poison accuracy and asr 297 | total_poison = 0 298 | correct_poison = 0 299 | total_asr = 0 300 | correct_asr = 0 301 | for step, (x, x_p, y, index) in enumerate(self.test_asr_loader): 302 | x = x.to(self.device).float() 303 | y = y.to(self.device).long() 304 | y_attack_target = torch.ones(size=y.shape).to(self.device).long() 305 | y_attack_target *= self.args.target_label 306 | # split data for vfl 307 | x_split_list = split_vfl(x, self.args) 308 | local_output_list = [] 309 | global_input_list = [] 310 | # get the local model outputs 311 | for i in range(self.args.client_num): 312 | local_output_list.append(model_list[i + 1](x_split_list[i])) 313 | # get the global model inputs, recording the gradients 314 | for i in range(self.args.client_num): 315 | global_input_t = local_output_list[i].detach().clone() 316 | global_input_t.requires_grad_(True) 317 | global_input_list.append(global_input_t) 318 | 319 | global_output = model_list[0](local_output_list) 320 | 321 | # calculate the poison accuracy 322 | _, predicted = global_output.max(1) 323 | total_poison += y.size(0) 324 | correct_poison += predicted.eq(y).sum().item() 325 | # calculate the asr 326 | total_asr += (y != self.args.target_label).float().sum() 327 | correct_asr += (predicted[y != self.args.target_label] == self.args.target_label).float().sum() 328 | 329 | # main task accuracy, poison_acc and asr 330 | test_acc = correct / total 331 | test_poison_accuracy = correct_poison / total_poison 332 | test_asr = correct_asr / total_asr 333 | test_target = correct_target / total_target 334 | epoch_loss = sum(batch_loss_list) / 
len(batch_loss_list) 335 | test_trade_off = (test_acc + test_asr) / 2 336 | # main task accuracy on target set 337 | self.logger.info( 338 | '=> Test Epoch: {}, main task samples: {}, attack samples: {}, test loss: {:.4f}, test trade off: {:.4f}, test main task ' 339 | 'accuracy: {:.4f}, test target accuracy: {:.4f}, test asr: {:.4f}'.format(ep + 1, 340 | len(self.test_loader.dataset), 341 | len(self.test_asr_loader.dataset), 342 | epoch_loss, 343 | test_trade_off, test_acc, 344 | test_target, test_asr)) 345 | 346 | return test_acc, test_poison_accuracy, test_target, test_asr --------------------------------------------------------------------------------