├── .gitignore ├── LICENSE ├── README.md ├── attack.py ├── configs ├── cifar10_fedavg.yaml ├── cifar10_fltrust.yaml ├── cifar10_ours.yaml ├── mnist_fedavg.yaml ├── mnist_fltrust.yaml └── mnist_ours.yaml ├── dataset ├── celeba.py ├── multi_mnist_loader.py ├── pipa.py └── vggface.py ├── helper.py ├── losses └── loss_functions.py ├── metrics ├── accuracy_metric.py ├── metric.py └── test_loss_metric.py ├── models ├── __init__.py ├── face_ident.py ├── model.py ├── resnet.py └── simple.py ├── requirements.txt ├── synthesizers ├── block_synthesizer.py ├── complex_synthesizer.py ├── pattern_synthesizer.py ├── physical_synthesizer.py ├── singlepixel_synthesizer.py └── synthesizer.py ├── tasks ├── batch.py ├── celeba_helper.py ├── cifar10_task.py ├── fl │ ├── cifar10_fedavg_task.py │ ├── cifar10_fltrust_task.py │ ├── cifar10_ours_task.py │ ├── fl_task.py │ ├── fl_user.py │ ├── fl_user_ours.py │ ├── mnist_fedavg_task.py │ ├── mnist_fltrust_task.py │ └── mnist_ours_task.py ├── imagenet_task.py ├── imdb_helper.py ├── mnist_task.py ├── multimnist_task.py ├── pipa_task.py ├── task.py └── vggface_helper.py ├── training.py └── utils ├── __init__.py ├── index.html ├── min_norm_solvers.py ├── parameters.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedGame 2 | Official implementation for our paper "FedGame: A Game-Theoretic Defense against Backdoor Attacks in Federated Learning" (NeurIPS 2023). 
3 | 4 | ## Requirements 5 | ```bash 6 | $ pip install -r requirements.txt 7 | $ mkdir runs 8 | $ mkdir saved_models 9 | ``` 10 | Download Tiny ImageNet: 11 | ```bash 12 | $ wget http://cs231n.stanford.edu/tiny-imagenet-200.zip 13 | ``` 14 | Preprocess Tiny ImageNet: 15 | ```python 16 | import os 17 | 18 | DATA_DIR = '.data/tiny-imagenet-200' # Original images come in shapes of [3,64,64] 19 | 20 | # Define training and validation data paths 21 | TRAIN_DIR = os.path.join(DATA_DIR, 'train') 22 | VALID_DIR = os.path.join(DATA_DIR, 'val') 23 | 24 | val_img_dir = os.path.join(VALID_DIR, 'images') 25 | 26 | # Open and read val annotations text file 27 | fp = open(os.path.join(VALID_DIR, 'val_annotations.txt'), 'r') 28 | data = fp.readlines() 29 | 30 | # Create dictionary to store img filename (word 0) and corresponding 31 | # label (word 1) for every line in the txt file (as key value pair) 32 | val_img_dict = {} 33 | for line in data: 34 | words = line.split('\t') 35 | val_img_dict[words[0]] = words[1] 36 | fp.close() 37 | 38 | # Display first 10 entries of resulting val_img_dict dictionary 39 | {k: val_img_dict[k] for k in list(val_img_dict)[:10]} 40 | 41 | for img, folder in val_img_dict.items(): 42 | newpath = (os.path.join(val_img_dir, folder)) 43 | if not os.path.exists(newpath): 44 | os.makedirs(newpath) 45 | if os.path.exists(os.path.join(val_img_dir, img)): 46 | os.rename(os.path.join(val_img_dir, img), os.path.join(newpath, img)) 47 | ``` 48 | 49 | ## Usage 50 | FedAvg: 51 | ```bash 52 | $ python training.py --name fedavg --params configs/mnist_fedavg.yaml 53 | ``` 54 | FLTrust: 55 | ```bash 56 | $ python training.py --name fltrust --params configs/mnist_fltrust.yaml 57 | ``` 58 | Ours: 59 | ```bash 60 | $ python training.py --name ours --params configs/mnist_ours.yaml 61 | ``` 62 | 63 | 64 | ## Adaptation 65 | Follow these steps to adapt the existing code to a new dataset. 66 | 67 | ### Set Parameters 68 | **Create `configs/<dataset>_<method>.yaml`.** 69 | 70 | All parameters and their default values (if any) are listed in `utils/parameters.py`. Parameters customized for specific tasks are set in `configs/<dataset>_<method>.yaml`. If you have a new attack setting (e.g. with new datasets/algorithms), create a file with your parameters under the `configs` folder. 71 | 72 | ### Add New Dataset 73 | **Create `tasks/<dataset>_task.py` and `tasks/fl/<dataset>_<method>_task.py`.** 74 | 75 | Datasets are defined in each `Task` class. If you want to adapt our defense to CIFAR10, first create a `cifar10_task.py` that defines how the dataset should be loaded. (This part is already done.) Then, following `tasks/fl/mnist_ours_task.py`, create a file (e.g., `tasks/fl/cifar10_ours_task.py`) that defines how data are loaded under FL, among other things; a sketch of such a task class is shown below. **Remember to change the number of channels in the trigger definition of `reverse_engineer_per_class` if necessary.** It should be 1 for MNIST and 3 for CIFAR10.
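To make the naming convention concrete, here is a minimal sketch of what a new FL task file could look like. Everything in it is illustrative: the `SVHN` example, the `.data` root, the (approximate) normalization constants, and the assumption that `load_data` is the hook exposed by `FederatedLearningTask` are ours, not part of this repository; check `tasks/fl/fl_task.py` and `tasks/fl/mnist_ours_task.py` for the real interface. What must hold is the naming: `helper.make_task` turns `task: SVHN_Ours` in the config into the module `tasks/fl/svhn_ours_task.py` and the class `SVHN_OursTask`.

```python
# Hypothetical tasks/fl/svhn_ours_task.py -- a sketch only, not repository code.
import torchvision.transforms as transforms
from torchvision.datasets import SVHN

from tasks.fl.fl_task import FederatedLearningTask


class SVHN_OursTask(FederatedLearningTask):

    def load_data(self):  # assumed hook; see tasks/fl/fl_task.py
        transform = transforms.Compose([
            transforms.ToTensor(),
            # approximate SVHN channel statistics
            transforms.Normalize((0.4377, 0.4438, 0.4728),
                                 (0.1980, 0.2010, 0.1970)),
        ])
        self.train_dataset = SVHN('.data', split='train', download=True,
                                  transform=transform)
        self.test_dataset = SVHN('.data', split='test', download=True,
                                 transform=transform)
        # Split train_dataset across fl_total_participants here. SVHN images
        # have 3 channels, so the trigger tensor in reverse_engineer_per_class
        # must be defined with 3 channels as well.
```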
76 | -------------------------------------------------------------------------------- /attack.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import torch 5 | from copy import deepcopy 6 | import numpy as np 7 | from models.model import Model 8 | from synthesizers.synthesizer import Synthesizer 9 | from losses.loss_functions import compute_all_losses_and_grads 10 | from utils.min_norm_solvers import MGDASolver 11 | from utils.parameters import Params 12 | 13 | logger = logging.getLogger('logger') 14 | 15 | 16 | class Attack: 17 | params: Params 18 | synthesizer: Synthesizer 19 | nc_model: Model 20 | nc_optim: torch.optim.Optimizer 21 | loss_hist = list() 22 | # fixed_model: Model 23 | 24 | def __init__(self, params, synthesizer): 25 | self.params = params 26 | self.synthesizer = synthesizer 27 | 28 | def compute_blind_loss(self, model, criterion, batch, attack, ratio=None): 29 | """Compute the combined loss over all active loss tasks for one batch. 30 | 31 | :param model: model being trained 32 | :param criterion: loss criterion 33 | :param batch: clean input batch 34 | :param attack: if False, ignore all attack parameters and compute only the normal loss 35 | :param ratio: fraction of the batch to backdoor (forwarded to the synthesizer) 36 | :return: combined scalar loss """ 37 | batch = batch.clip(self.params.clip_batch) 38 | loss_tasks = self.params.loss_tasks.copy() if attack else ['normal'] 39 | batch_back = self.synthesizer.make_backdoor_batch(batch, attack=attack, ratio=ratio) 40 | scale = dict() 41 | 42 | if self.params.loss_threshold and (np.mean(self.loss_hist) >= self.params.loss_threshold 43 | or len(self.loss_hist) < 1000): 44 | loss_tasks = ['normal'] 45 | 46 | if len(loss_tasks) == 1: 47 | loss_values, grads = compute_all_losses_and_grads( 48 | loss_tasks, 49 | self, model, criterion, batch, batch_back, compute_grad=False 50 | ) 51 | elif self.params.loss_balance == 'MGDA': 52 | 53 | loss_values, grads = compute_all_losses_and_grads( 54 | loss_tasks, 55 | self, model, criterion, batch, batch_back, compute_grad=True) 56 | if len(loss_tasks) > 1: 57 | scale = MGDASolver.get_scales(grads, loss_values, 58 | self.params.mgda_normalize, 59 | loss_tasks) 60 | elif self.params.loss_balance == 'fixed': 61 | loss_values, grads = compute_all_losses_and_grads( 62 | loss_tasks, 63 | self, model, criterion, batch, batch_back, compute_grad=False) 64 | 65 | for t in loss_tasks: 66 | scale[t] = self.params.fixed_scales[t] 67 | else: 68 | raise ValueError('Please choose between `MGDA` and `fixed`.') 69 | 70 | if len(loss_tasks) == 1: 71 | scale = {loss_tasks[0]: 1.0} 72 | self.loss_hist.append(loss_values[list(loss_values.keys())[0]].item()) 73 | self.loss_hist = self.loss_hist[-1000:] 74 | blind_loss = self.scale_losses(loss_tasks, loss_values, scale) 75 | 76 | return blind_loss 77 | 78 | def scale_losses(self, loss_tasks, loss_values, scale): 79 | blind_loss = 0 80 | for it, t in enumerate(loss_tasks): 81 | self.params.running_losses[t].append(loss_values[t].item()) 82 | self.params.running_scales[t].append(scale[t]) 83 | if it == 0: 84 | blind_loss = scale[t] * loss_values[t] 85 | else: 86 | blind_loss += scale[t] * loss_values[t] 87 | self.params.running_losses['total'].append(blind_loss.item()) 88 | return blind_loss 89 | 90 | def fl_scale_update(self, local_update: Dict[str, torch.Tensor], scale=None): 91 | for name, value in local_update.items(): 92 | if scale is None: 93 | value.mul_(self.params.fl_weight_scale) 94 | else: 95 | value.mul_(scale) 96 | -------------------------------------------------------------------------------- /configs/cifar10_fedavg.yaml: 
-------------------------------------------------------------------------------- 1 | task: Cifar10_FedAvg 2 | synthesizer: Pattern 3 | random_seed: 0 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | pretrained: True 17 | 18 | poisoning_proportion: 1 19 | backdoor_label: 8 20 | 21 | loss_balance: fixed 22 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 23 | 24 | save_model: True 25 | log: True 26 | tb: True 27 | 28 | transform_train: True 29 | 30 | fl: True 31 | fl_no_models: 10 32 | fl_local_epochs: 2 33 | fl_total_participants: 10 34 | fl_eta: 1 35 | fl_q: 0.1 36 | # fl_dp_clip: 0.1 37 | # fl_dp_noise: 0.01 38 | fl_number_of_adversaries: 6 39 | fl_weight_scale: 1.6667 40 | 41 | # defense: krum 42 | 43 | loss_tasks: 44 | - backdoor 45 | - normal 46 | -------------------------------------------------------------------------------- /configs/cifar10_fltrust.yaml: -------------------------------------------------------------------------------- 1 | task: Cifar10_FLTrust 2 | synthesizer: Pattern 3 | random_seed: 0 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | pretrained: True 17 | 18 | poisoning_proportion: 1 19 | backdoor_label: 8 20 | 21 | loss_balance: fixed 22 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 23 | 24 | save_model: True 25 | log: True 26 | tb: True 27 | 28 | transform_train: True 29 | 30 | fl: True 31 | fl_no_models: 10 32 | fl_local_epochs: 2 33 | fl_total_participants: 10 34 | fl_eta: 1 35 | fl_q: 0.1 36 | fl_number_of_adversaries: 6 37 | fl_weight_scale: 1.6667 38 | 39 | clean_ratio: 0.1 40 | 41 | fltrust: True 42 | attack_start_epoch: 1 43 | 44 | loss_tasks: 45 | - backdoor 46 | - normal 47 | -------------------------------------------------------------------------------- /configs/cifar10_ours.yaml: -------------------------------------------------------------------------------- 1 | task: Cifar10_Ours 2 | random_seed: 0 3 | synthesizer: Pattern 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | pretrained: True 17 | 18 | poisoning_proportion: 1 19 | backdoor_label: 8 20 | 21 | loss_balance: fixed 22 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 23 | 24 | save_model: True 25 | log: True 26 | tb: True 27 | 28 | transform_train: True 29 | 30 | fl: True 31 | fl_no_models: 10 32 | fl_local_epochs: 2 33 | fl_total_participants: 10 34 | fl_eta: 1 35 | fl_q: 0.1 36 | fl_number_of_adversaries: 6 37 | fl_weight_scale: 1.6667 38 | 39 | # clean_classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 40 | clean_set_dataset: GTSRB 41 | clean_ratio: 0.1 42 | 43 | ours: True 44 | ours_lbd: 1 45 | attacker_train_ratio: 0.9 46 | r_interval: 0.1 47 | nc_steps: 100 48 | max_threads: 64 49 | attack_start_epoch: 1 50 | 51 | loss_tasks: 52 | - backdoor 53 | - normal 54 | 55 | -------------------------------------------------------------------------------- /configs/mnist_fedavg.yaml: -------------------------------------------------------------------------------- 1 | task: MNIST_FedAvg 2 | synthesizer: Pattern 3 | random_seed: 0 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | 
lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | 17 | poisoning_proportion: 1 18 | backdoor_label: 8 19 | 20 | loss_balance: fixed 21 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 22 | 23 | save_model: True 24 | log: True 25 | tb: True 26 | 27 | transform_train: True 28 | 29 | fl: True 30 | fl_no_models: 10 31 | fl_local_epochs: 2 32 | fl_total_participants: 10 33 | fl_eta: 1 34 | fl_q: 0.1 35 | # fl_dp_clip: 0.05 36 | # fl_dp_noise: 0.01 37 | fl_number_of_adversaries: 6 38 | fl_weight_scale: 1.6667 39 | 40 | # defense: krum 41 | 42 | loss_tasks: 43 | - backdoor 44 | - normal 45 | -------------------------------------------------------------------------------- /configs/mnist_fltrust.yaml: -------------------------------------------------------------------------------- 1 | task: MNIST_FLTrust 2 | synthesizer: Pattern 3 | random_seed: 0 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | 17 | poisoning_proportion: 1 18 | backdoor_label: 8 19 | 20 | loss_balance: fixed 21 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 22 | 23 | save_model: True 24 | log: True 25 | tb: True 26 | 27 | transform_train: True 28 | 29 | fl: True 30 | fl_no_models: 10 31 | fl_local_epochs: 2 32 | fl_total_participants: 10 33 | fl_eta: 1 34 | fl_q: 0.1 35 | fl_number_of_adversaries: 6 36 | fl_weight_scale: 1.6667 37 | 38 | clean_ratio: 0.1 39 | 40 | fltrust: True 41 | attack_start_epoch: 1 42 | 43 | loss_tasks: 44 | - backdoor 45 | - normal 46 | -------------------------------------------------------------------------------- /configs/mnist_ours.yaml: -------------------------------------------------------------------------------- 1 | task: MNIST_Ours 2 | synthesizer: Pattern 3 | random_seed: 0 4 | 5 | 6 | batch_size: 64 7 | test_batch_size: 100 8 | lr: 0.001 9 | momentum: 0.9 10 | decay: 0.0005 11 | epochs: 350 12 | save_on_epochs: [100, 200, 300] 13 | optimizer: SGD 14 | log_interval: 100 15 | scheduler: False 16 | 17 | poisoning_proportion: 1 18 | backdoor_label: 8 19 | 20 | loss_balance: fixed 21 | fixed_scales: {'backdoor': 0.5, 'normal': 0.5} 22 | 23 | save_model: True 24 | log: True 25 | tb: True 26 | 27 | transform_train: True 28 | 29 | fl: True 30 | fl_no_models: 10 31 | fl_local_epochs: 2 32 | fl_total_participants: 10 33 | fl_eta: 1 34 | fl_q: 0.1 35 | fl_number_of_adversaries: 6 36 | fl_weight_scale: 1.6667 37 | 38 | # clean_classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 39 | clean_set_dataset: FashionMNIST 40 | clean_ratio: 0.1 41 | 42 | ours: True 43 | ours_lbd: 1 44 | attacker_train_ratio: 0.9 45 | r_interval: 0.1 46 | nc_steps: 100 47 | max_threads: 64 48 | attack_start_epoch: 1 49 | static: False 50 | 51 | loss_tasks: 52 | - backdoor 53 | - normal 54 | -------------------------------------------------------------------------------- /dataset/celeba.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import os 4 | import PIL 5 | from torchvision.datasets import VisionDataset 6 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg 7 | from collections import defaultdict 8 | import random 9 | 10 | class CelebA(VisionDataset): 11 | """`Large-scale CelebFaces Attributes (CelebA) 
Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset. 12 | 13 | Args: 14 | root (string): Root directory where images are downloaded to. 15 | split (string): One of {'train', 'valid', 'test', 'all'}. 16 | Accordingly dataset is selected. 17 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 18 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 19 | The targets represent: 20 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 21 | ``identity`` (int): label for each person (data points with the same identity are the same person) 22 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 23 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 24 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 25 | Defaults to ``attr``. 26 | transform (callable, optional): A function/transform that takes in a PIL image 27 | and returns a transformed version. E.g., ``transforms.ToTensor`` 28 | target_transform (callable, optional): A function/transform that takes in the 29 | target and transforms it. 30 | download (bool, optional): If true, downloads the dataset from the internet and 31 | puts it in root directory. If dataset is already downloaded, it is not 32 | downloaded again. 33 | """ 34 | 35 | base_folder = "celeba" 36 | # There currently does not appear to be an easy way to extract 7z in python (without introducing additional 37 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 38 | # right now. 39 | file_list = [ 40 | # File ID MD5 Hash Filename 41 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 42 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 43 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 44 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 45 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 46 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 47 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 48 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 49 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 50 | ] 51 | 52 | def __init__(self, root, split="train", target_type="attr", transform=None, 53 | target_transform=None, download=False): 54 | import pandas 55 | super(CelebA, self).__init__(root, transform=transform, 56 | target_transform=target_transform) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | 63 | if download: 64 | self.download() 65 | 66 | if not self._check_integrity(): 67 | raise RuntimeError('Dataset not found or corrupted.' 
+ 68 | ' You can use download=True to download it') 69 | 70 | split_map = { 71 | "train": 0, 72 | "valid": 1, 73 | "test": 2, 74 | "all": None, 75 | } 76 | split = split_map[verify_str_arg(split.lower(), "split", 77 | ("train", "valid", "test", "all"))] 78 | 79 | fn = partial(os.path.join, self.root, self.base_folder) 80 | splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0) 81 | identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0) 82 | bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0) 83 | landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1) 84 | attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1) 85 | 86 | mask = slice(None) if split is None else (splits[1] == split) 87 | 88 | self.filename = splits[mask].index.values 89 | self.identity = torch.as_tensor(identity[mask].values) 90 | self.bbox = torch.as_tensor(bbox[mask].values) 91 | self.landmarks_align = torch.as_tensor(landmarks_align[mask].values) 92 | self.attr = torch.as_tensor(attr[mask].values) 93 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 94 | self.attr_names = list(attr.columns) 95 | 96 | self.identities_dict = defaultdict(list) 97 | for i, [x] in enumerate(self.identity.tolist()): 98 | self.identities_dict[x].append(i) 99 | self.identities_dict = dict(self.identities_dict) 100 | self.identities_set = set(self.identities_dict.keys()) 101 | 102 | empty_identities = [key for key, value in self.identities_dict.items() if len(value)<1] 103 | self.identities_set.difference_update(empty_identities) 104 | 105 | 106 | def _check_integrity(self): 107 | for (_, md5, filename) in self.file_list: 108 | fpath = os.path.join(self.root, self.base_folder, filename) 109 | _, ext = os.path.splitext(filename) 110 | # Allow original archive to be deleted (zip and 7z) 111 | # Only need the extracted images 112 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 113 | return False 114 | 115 | # Should check a hash of the images 116 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 117 | 118 | def download(self): 119 | import zipfile 120 | 121 | if self._check_integrity(): 122 | print('Files already downloaded and verified') 123 | return 124 | 125 | for (file_id, md5, filename) in self.file_list: 126 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 127 | 128 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 129 | f.extractall(os.path.join(self.root, self.base_folder)) 130 | 131 | def __getitem__(self, index): 132 | X, target = self.getitem_helper(index) 133 | identity = target.item() 134 | if len(self.identities_dict[identity])>1: 135 | pos = random.sample(set(self.identities_dict[identity]).difference([index]), 1)[0] 136 | else: 137 | pos = identity 138 | neg_ident = random.sample(self.identities_set.difference([identity]), 1)[0] 139 | neg = random.sample(self.identities_dict[neg_ident], 1)[0] 140 | Y, identity_pos = self.getitem_helper(pos) 141 | Z, identity_neg = self.getitem_helper(neg) 142 | 143 | return (X, Y, Z), target 144 | 145 | def getitem_helper(self, index): 146 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 147 | 148 | target = [] 149 | for t in self.target_type: 150 | if t == 
"attr": 151 | target.append(self.attr[index, :]) 152 | elif t == "identity": 153 | target.append(self.identity[index, 0]) 154 | elif t == "bbox": 155 | target.append(self.bbox[index, :]) 156 | elif t == "landmarks": 157 | target.append(self.landmarks_align[index, :]) 158 | else: 159 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 160 | target = tuple(target) if len(target) > 1 else target[0] 161 | 162 | if self.transform is not None: 163 | X = self.transform(X) 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | return X, target 169 | 170 | def __len__(self): 171 | return len(self.attr) 172 | 173 | def extra_repr(self): 174 | lines = ["Target type: {target_type}", "Split: {split}"] 175 | return '\n'.join(lines).format(**self.__dict__) 176 | -------------------------------------------------------------------------------- /dataset/multi_mnist_loader.py: -------------------------------------------------------------------------------- 1 | # Credits to Ozan Sener 2 | # https://github.com/intel-isl/MultiObjectiveOptimization 3 | 4 | from __future__ import print_function 5 | 6 | import codecs 7 | import errno 8 | import os 9 | import os.path 10 | 11 | import numpy as np 12 | import torch 13 | import torch.utils.data as data 14 | from PIL import Image 15 | 16 | 17 | class MNIST(data.Dataset): 18 | """`MNIST `_ Dataset. 19 | 20 | Args: 21 | root (string): Root directory of dataset where ``processed/training.pt`` 22 | and ``processed/test.pt`` exist. 23 | train (bool, optional): If True, creates dataset from ``training.pt``, 24 | otherwise from ``test.pt``. 25 | download (bool, optional): If true, downloads the dataset from the internet and 26 | puts it in root directory. If dataset is already downloaded, it is not 27 | downloaded again. 28 | transform (callable, optional): A function/transform that takes in an PIL image 29 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 30 | target_transform (callable, optional): A function/transform that takes in the 31 | target and transforms it. 32 | """ 33 | urls = [ 34 | 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 35 | 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 36 | 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 37 | 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 38 | ] 39 | raw_folder = 'raw' 40 | processed_folder = 'processed' 41 | training_file = 'training.pt' 42 | test_file = 'test.pt' 43 | multi_training_file = 'multi_training.pt' 44 | multi_test_file = 'multi_test.pt' 45 | 46 | def __init__(self, root, train=True, transform=None, target_transform=None, 47 | download=False, multi=False): 48 | self.root = os.path.expanduser(root) 49 | self.transform = transform 50 | self.target_transform = target_transform 51 | self.train = train # training set or test set 52 | self.multi = multi 53 | 54 | if download: 55 | self.download() 56 | 57 | if not self._check_exists(): 58 | raise RuntimeError('Dataset not found.' + 59 | ' You can use download=True to download it') 60 | 61 | if not self._check_multi_exists(): 62 | raise RuntimeError('Multi Task extension not found.' 
+ 63 | ' You can use download=True to download it') 64 | 65 | if multi: 66 | if self.train: 67 | self.train_data, self.train_labels_l, self.train_labels_r = torch.load( 68 | os.path.join(self.root, self.processed_folder, 69 | self.multi_training_file)) 70 | else: 71 | self.test_data, self.test_labels_l, self.test_labels_r = torch.load( 72 | os.path.join(self.root, self.processed_folder, 73 | self.multi_test_file)) 74 | else: 75 | if self.train: 76 | self.train_data, self.train_labels = torch.load( 77 | os.path.join(self.root, self.processed_folder, 78 | self.training_file)) 79 | else: 80 | self.test_data, self.test_labels = torch.load( 81 | os.path.join(self.root, self.processed_folder, 82 | self.test_file)) 83 | 84 | def __getitem__(self, index): 85 | """ 86 | Args: 87 | index (int): Index 88 | 89 | Returns: 90 | tuple: (image, target) where target is index of the target class. 91 | """ 92 | if self.multi: 93 | if self.train: 94 | img, target_l, target_r = self.train_data[index], \ 95 | self.train_labels_l[index], \ 96 | self.train_labels_r[index] 97 | else: 98 | img, target_l, target_r = self.test_data[index], \ 99 | self.test_labels_l[index], \ 100 | self.test_labels_r[index] 101 | 102 | target = target_l * 10 + target_r 103 | else: 104 | if self.train: 105 | img, target = self.train_data[index], self.train_labels[index] 106 | else: 107 | img, target = self.test_data[index], self.test_labels[index] 108 | 109 | # doing this so that it is consistent with all other datasets 110 | # to return a PIL Image 111 | img = Image.fromarray(img.numpy().astype(np.uint8), mode='L') 112 | if self.transform is not None: 113 | img = self.transform(img) 114 | 115 | if self.target_transform is not None: 116 | target = self.target_transform(target) 117 | 118 | # if self.multi: 119 | # return img, target_l, target_r 120 | # else: 121 | return img, target 122 | 123 | def __len__(self): 124 | if self.train: 125 | return len(self.train_data) 126 | else: 127 | return len(self.test_data) 128 | 129 | def _check_exists(self): 130 | return os.path.exists(os.path.join(self.root, self.processed_folder, 131 | self.training_file)) and \ 132 | os.path.exists(os.path.join(self.root, self.processed_folder, 133 | self.test_file)) 134 | 135 | def _check_multi_exists(self): 136 | return os.path.exists(os.path.join(self.root, self.processed_folder, 137 | self.multi_training_file)) and \ 138 | os.path.exists(os.path.join(self.root, self.processed_folder, 139 | self.multi_test_file)) 140 | 141 | def download(self): 142 | """Download the MNIST data if it doesn't exist in processed_folder already.""" 143 | from six.moves import urllib 144 | import gzip 145 | 146 | if self._check_exists() and self._check_multi_exists(): 147 | return 148 | 149 | # download files 150 | try: 151 | os.makedirs(os.path.join(self.root, self.raw_folder)) 152 | os.makedirs(os.path.join(self.root, self.processed_folder)) 153 | except OSError as e: 154 | if e.errno == errno.EEXIST: 155 | pass 156 | else: 157 | raise 158 | 159 | for url in self.urls: 160 | print('Downloading ' + url) 161 | data = urllib.request.urlopen(url) 162 | filename = url.rpartition('/')[2] 163 | file_path = os.path.join(self.root, self.raw_folder, filename) 164 | with open(file_path, 'wb') as f: 165 | f.write(data.read()) 166 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 167 | gzip.GzipFile(file_path) as zip_f: 168 | out_f.write(zip_f.read()) 169 | os.unlink(file_path) 170 | 171 | # process and save as torch files 172 | print('Processing...') 173 | mnist_ims, 
multi_mnist_ims, extension = read_image_file( 174 | os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')) 175 | mnist_labels, multi_mnist_labels_l, multi_mnist_labels_r = read_label_file( 176 | os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'), 177 | extension) 178 | 179 | tmnist_ims, tmulti_mnist_ims, textension = read_image_file( 180 | os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')) 181 | tmnist_labels, tmulti_mnist_labels_l, tmulti_mnist_labels_r = read_label_file( 182 | os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'), 183 | textension) 184 | 185 | mnist_training_set = (mnist_ims, mnist_labels) 186 | multi_mnist_training_set = ( 187 | multi_mnist_ims, multi_mnist_labels_l, multi_mnist_labels_r) 188 | 189 | mnist_test_set = (tmnist_ims, tmnist_labels) 190 | multi_mnist_test_set = ( 191 | tmulti_mnist_ims, tmulti_mnist_labels_l, tmulti_mnist_labels_r) 192 | 193 | with open(os.path.join(self.root, self.processed_folder, 194 | self.training_file), 'wb') as f: 195 | torch.save(mnist_training_set, f) 196 | with open( 197 | os.path.join(self.root, self.processed_folder, self.test_file), 198 | 'wb') as f: 199 | torch.save(mnist_test_set, f) 200 | with open(os.path.join(self.root, self.processed_folder, 201 | self.multi_training_file), 'wb') as f: 202 | torch.save(multi_mnist_training_set, f) 203 | with open(os.path.join(self.root, self.processed_folder, 204 | self.multi_test_file), 'wb') as f: 205 | torch.save(multi_mnist_test_set, f) 206 | print('Done!') 207 | 208 | def __repr__(self): 209 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 210 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 211 | tmp = 'train' if self.train is True else 'test' 212 | fmt_str += ' Split: {}\n'.format(tmp) 213 | fmt_str += ' Root Location: {}\n'.format(self.root) 214 | tmp = ' Transforms (if any): ' 215 | fmt_str += '{0}{1}\n'.format(tmp, 216 | self.transform.__repr__().replace('\n', 217 | '\n' + ' ' * len( 218 | tmp))) 219 | tmp = ' Target Transforms (if any): ' 220 | fmt_str += '{0}{1}'.format(tmp, 221 | self.target_transform.__repr__().replace( 222 | '\n', '\n' + ' ' * len(tmp))) 223 | return fmt_str 224 | 225 | 226 | def get_int(b): 227 | return int(codecs.encode(b, 'hex'), 16) 228 | 229 | 230 | def read_label_file(path, extension): 231 | with open(path, 'rb') as f: 232 | data = f.read() 233 | assert get_int(data[:4]) == 2049 234 | length = get_int(data[4:8]) 235 | parsed = np.frombuffer(data, dtype=np.uint8, offset=8) 236 | multi_labels_l = np.zeros((1 * length), dtype=np.long) 237 | multi_labels_r = np.zeros((1 * length), dtype=np.long) 238 | for im_id in range(length): 239 | for rim in range(1): 240 | multi_labels_l[1 * im_id + rim] = parsed[im_id] 241 | multi_labels_r[1 * im_id + rim] = parsed[ 242 | extension[1 * im_id + rim]] 243 | return torch.from_numpy(parsed).view(length).long(), torch.from_numpy( 244 | multi_labels_l).view( 245 | length * 1).long(), torch.from_numpy(multi_labels_r).view( 246 | length * 1).long() 247 | 248 | 249 | def read_image_file(path): 250 | with open(path, 'rb') as f: 251 | data = f.read() 252 | assert get_int(data[:4]) == 2051 253 | length = get_int(data[4:8]) 254 | num_rows = get_int(data[8:12]) 255 | num_cols = get_int(data[12:16]) 256 | images = [] 257 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16) 258 | pv = parsed.reshape(length, num_rows, num_cols) 259 | multi_length = length * 1 260 | multi_data = np.zeros((1 * length, num_rows, num_cols)) 261 | extension = 
np.zeros(1 * length, dtype=np.int32) 262 | for left in range(length): 263 | chosen_ones = np.random.permutation(length)[:1] 264 | extension[left * 1:(left + 1) * 1] = chosen_ones 265 | for j, right in enumerate(chosen_ones): 266 | lim = pv[left, :, :] 267 | rim = pv[right, :, :] 268 | new_im = np.zeros((45, 45)) 269 | new_im[5:33, 0:28] = lim 270 | new_im[5:33, 17:45] = rim 271 | new_im[5:33, 17:28] = np.maximum(lim[:, -11:], rim[:, :11]) 272 | multi_data_im = np.array(Image.fromarray(new_im).resize((28, 28))) 273 | multi_data[left * 1 + j, :, :] = multi_data_im 274 | return torch.from_numpy(parsed).view(length, num_rows, 275 | num_cols), torch.from_numpy( 276 | multi_data).view(length, 277 | num_rows, 278 | num_cols), extension 279 | 280 | 281 | if __name__ == '__main__': 282 | import torch 283 | import matplotlib.pyplot as plt 284 | from torchvision import transforms 285 | 286 | 287 | def global_transformer(): 288 | return transforms.Compose([transforms.ToTensor(), 289 | transforms.Normalize((0.1307,), (0.3081,))]) 290 | 291 | 292 | dst = MNIST(root='./multimnist', train=True, download=True, 293 | transform=global_transformer(), multi=True) 294 | loader = torch.utils.data.DataLoader(dst, batch_size=10, shuffle=True, 295 | num_workers=4) 296 | for dat in loader: 297 | ims = dat[0].view(10, 28, 28).numpy() 298 | 299 | labs_l = dat[1] // 10 # with multi=True, __getitem__ returns the combined label left * 10 + right 300 | labs_r = dat[1] % 10 301 | f, axarr = plt.subplots(2, 5) 302 | for j in range(5): 303 | for i in range(2): 304 | axarr[i][j].imshow(ims[j * 2 + i, :, :], cmap='gray') 305 | axarr[i][j].set_title( 306 | '{}_{}'.format(labs_l[j * 2 + i], labs_r[j * 2 + i])) 307 | plt.show() 308 | a = input() 309 | if a == 'ex': 310 | break 311 | else: 312 | plt.close() 313 | -------------------------------------------------------------------------------- /dataset/pipa.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.utils.data as data 5 | from torchvision.datasets.folder import default_loader 6 | 7 | 8 | class Annotations: 9 | photoset_id = None 10 | photo_id = None 11 | xmin = None 12 | ymin = None 13 | width = None 14 | height = None 15 | identity_id = None 16 | subset_id = None 17 | people_on_photo = 0 18 | 19 | def __repr__(self): 20 | return f'photoset: {self.photoset_id}, photo id: {self.photo_id}, ' \ 21 | f'identity: {self.identity_id}, subs: {self.subset_id}, ' \ 22 | f'{self.people_on_photo}' 23 | 24 | 25 | class PipaDataset(data.Dataset): 26 | """PIPA (People in Photo Albums) dataset.""" 27 | 28 | def __init__(self, data_path, train=True, transform=None): 29 | """ 30 | Args: 31 | data_path (string): Directory with all the data. 32 | train (bool): train or test dataset. 33 | transform (callable, optional): Optional transform to be applied 34 | on a sample. 
35 | """ 36 | self.directory = data_path 37 | try: 38 | if train: 39 | self.data_list = torch.load(f'{self.directory}/train_split.pt') 40 | else: 41 | self.data_list = torch.load(f'{self.directory}/test_split.pt') 42 | self.photo_list = torch.load(f'{self.directory}/photo_list.pt') 43 | self.target_identities = torch.load( 44 | f'{self.directory}/target_identities.pt') 45 | except FileNotFoundError: 46 | raise FileNotFoundError( 47 | 'Please download the archive: https://drive.google.com/' 48 | 'file/d/1IAsTDl6kw4u8kk7Ikyf8K2A4RSPv9izz') 49 | self.transform = transform 50 | self.loader = default_loader 51 | 52 | self.labels = torch.tensor( 53 | [self.get_label(x)[0] for x in range(len(self))]) 54 | self.metadata = [self.get_label(x) for x in range(len(self))] 55 | 56 | def __len__(self): 57 | return len(self.data_list) 58 | 59 | def get_label(self, idx): 60 | photo_id, identities = self.data_list[idx] 61 | target = len(identities) - 1 62 | if target > 4: 63 | target = 4 64 | target_identity = 0 65 | for pos, z in enumerate(self.target_identities): 66 | if z in identities: 67 | target_identity = pos + 1 68 | return target, target_identity, photo_id, idx 69 | 70 | def __getitem__(self, idx): 71 | photo_id, identities = self.data_list[idx] 72 | x = self.photo_list[photo_id][0] 73 | if x.subset_id == 1: 74 | path = 'train' 75 | else: 76 | path = 'test' 77 | 78 | target = len(identities) - 1 79 | 80 | # more than 5 people nobody cares 81 | if target > 4: 82 | target = 4 83 | target_identity = 0 84 | for pos, z in enumerate(self.target_identities): 85 | if z in identities: 86 | target_identity = pos + 1 87 | 88 | # get image 89 | sample = self.loader( 90 | f'{self.directory}/{path}/{x.photoset_id}_{x.photo_id}.jpg') 91 | crop = self.get_crop(photo_id) 92 | sample = sample.crop(crop) 93 | if self.transform is not None: 94 | sample = self.transform(sample) 95 | 96 | return sample, target, target_identity, (photo_id, idx) 97 | 98 | def get_crop(self, photo_id): 99 | ids = self.photo_list[photo_id] 100 | 101 | left = 100000 102 | upper = 100000 103 | right = 0 104 | lower = 0 105 | for x in ids: 106 | left = min(x.xmin, left) 107 | upper = min(x.ymin, upper) 108 | right = max(x.xmin + x.width, right) 109 | lower = max(x.ymin + x.height, lower) 110 | 111 | diff = (right - left) - (lower - upper) 112 | if diff >= 0: 113 | lower += diff 114 | else: 115 | right -= diff 116 | 117 | return left, upper, right, lower 118 | -------------------------------------------------------------------------------- /dataset/vggface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os 5 | 6 | import numpy as np 7 | import PIL.Image 8 | import scipy.io 9 | import torch 10 | from torch.utils import data 11 | import torchvision.transforms 12 | from torchvision.datasets.folder import default_loader 13 | 14 | 15 | class VGG_Faces2(data.Dataset): 16 | 17 | mean_bgr = np.array([91.4953, 103.8827, 131.0912]) # from resnet50_ft.prototxt 18 | 19 | def __init__(self, root, train, transform=None): 20 | """ 21 | :param root: dataset directory 22 | :param image_list_file: contains image file names under root 23 | :param id_label_dict: X[class_id] -> label 24 | :param split: train or valid 25 | :param transform: 26 | :param horizontal_flip: 27 | :param upper: max number of image used for debug 28 | """ 29 | self.root = root 30 | 31 | if train: 32 | self.file_list = torch.load(self.root + '/train_list.pt') 33 | else: 34 | self.file_list = 
torch.load(self.root + '/test_list.pt') 35 | self.bboxes = torch.load(self.root + '/bboxes.pt') 36 | 37 | self.transform = transform 38 | self.loader = default_loader 39 | 40 | # self.img_info = [] 41 | # with open(self.image_list_file, 'r') as f: 42 | # for i, img_file in enumerate(f): 43 | # img_file = img_file.strip() # e.g. train/n004332/0317_01.jpg 44 | # class_id = img_file.split("/")[1] # like n004332 45 | # label = self.id_label_dict[class_id] 46 | # self.img_info.append({ 47 | # 'cid': class_id, 48 | # 'img': img_file, 49 | # 'lbl': label, 50 | # }) 51 | # if i % 1000 == 0: 52 | # print("processing: {} images for {}".format(i, self.split)) 53 | # if upper and i == upper - 1: # for debug purpose 54 | # break 55 | 56 | def __len__(self): 57 | return len(self.file_list) 58 | 59 | def __getitem__(self, index): 60 | img_file, label, bbox_id = self.file_list[index] 61 | bbox = self.bboxes[bbox_id] 62 | sample = self.loader(f'{self.root}/{img_file}') 63 | target = torch.tensor(label) 64 | x, y, w, h = bbox 65 | 66 | sample = sample.crop((x,y, x+w, y+h)) 67 | 68 | if self.transform: 69 | sample = self.transform(sample) 70 | 71 | return sample, target 72 | 73 | def transform(self, img): 74 | img = img[:, :, ::-1] # RGB -> BGR 75 | img = img.astype(np.float32) 76 | img -= self.mean_bgr 77 | img = img.transpose(2, 0, 1) # C x H x W 78 | img = torch.from_numpy(img).float() 79 | return img 80 | 81 | def untransform(self, img, lbl): 82 | img = img.numpy() 83 | img = img.transpose(1, 2, 0) 84 | img += self.mean_bgr 85 | img = img.astype(np.uint8) 86 | img = img[:, :, ::-1] 87 | return img, lbl -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import random 5 | from collections import defaultdict 6 | from copy import deepcopy 7 | from shutil import copyfile 8 | from typing import Union 9 | 10 | import numpy as np 11 | import torch 12 | import yaml 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from attack import Attack 16 | from synthesizers.synthesizer import Synthesizer 17 | from tasks.fl.fl_task import FederatedLearningTask 18 | from tasks.task import Task 19 | from utils.parameters import Params 20 | from utils.utils import create_logger, create_table 21 | 22 | logger = logging.getLogger('logger') 23 | 24 | 25 | class Helper: 26 | params: Params = None 27 | task: Union[Task, FederatedLearningTask] = None 28 | synthesizer: Synthesizer = None 29 | attack: Attack = None 30 | tb_writer: SummaryWriter = None 31 | 32 | def __init__(self, params): 33 | self.params = Params(**params) 34 | 35 | self.times = {'backward': list(), 'forward': list(), 'step': list(), 36 | 'scales': list(), 'total': list(), 'poison': list()} 37 | if self.params.random_seed is not None: 38 | self.fix_random(self.params.random_seed) 39 | 40 | self.make_folders() 41 | self.make_task() 42 | self.make_synthesizer() 43 | self.attack = Attack(self.params, self.synthesizer) 44 | 45 | # if 'spectral_evasion' in self.params.loss_tasks: 46 | # self.attack.fixed_model = deepcopy(self.task.model) 47 | 48 | self.best_loss = float('inf') 49 | 50 | def make_task(self): 51 | name_lower = self.params.task.lower() 52 | name_cap = self.params.task 53 | if self.params.fl: 54 | module_name = f'tasks.fl.{name_lower}_task' 55 | path = f'tasks/fl/{name_lower}_task.py' 56 | else: 57 | module_name = f'tasks.{name_lower}_task' 58 | path = 
f'tasks/{name_lower}_task.py' 59 | try: 60 | task_module = importlib.import_module(module_name) 61 | task_class = getattr(task_module, f'{name_cap}Task') 62 | except (ModuleNotFoundError, AttributeError): 63 | raise ModuleNotFoundError(f'Your task: {self.params.task} should ' 64 | f'be defined as a class ' 65 | f'{name_cap}' 66 | f'Task in {path}') 67 | self.task = task_class(self.params) 68 | 69 | def make_synthesizer(self): 70 | name_lower = self.params.synthesizer.lower() 71 | name_cap = self.params.synthesizer 72 | module_name = f'synthesizers.{name_lower}_synthesizer' 73 | try: 74 | synthesizer_module = importlib.import_module(module_name) 75 | task_class = getattr(synthesizer_module, f'{name_cap}Synthesizer') 76 | except (ModuleNotFoundError, AttributeError): 77 | raise ModuleNotFoundError( 78 | f'The synthesizer: {self.params.synthesizer}' 79 | f' should be defined as a class ' 80 | f'{name_cap}Synthesizer in ' 81 | f'synthesizers/{name_lower}_synthesizer.py') 82 | self.synthesizer = task_class(self.task) 83 | 84 | def make_folders(self): 85 | log = create_logger() 86 | if self.params.log: 87 | try: 88 | os.mkdir(self.params.folder_path) 89 | except FileExistsError: 90 | log.info('Folder already exists') 91 | 92 | fh = logging.FileHandler( 93 | filename=f'{self.params.folder_path}/log.txt') 94 | formatter = logging.Formatter('%(asctime)s - %(name)s ' 95 | '- %(levelname)s - %(message)s') 96 | fh.setFormatter(formatter) 97 | log.addHandler(fh) 98 | 99 | log.warning(f'Logging to: {self.params.folder_path}') 100 | 101 | with open(f'{self.params.folder_path}/params.yaml.txt', 'w') as f: 102 | yaml.dump(self.params, f) 103 | 104 | if self.params.tb: 105 | wr = SummaryWriter(log_dir=f'runs/{self.params.name}') 106 | self.tb_writer = wr 107 | params_dict = self.params.to_dict() 108 | table = create_table(params_dict) 109 | self.tb_writer.add_text('Model Params', table) 110 | 111 | def save_model(self, model=None, epoch=0, val_loss=0): 112 | 113 | if self.params.save_model: 114 | logger.info(f"Saving model to {self.params.folder_path}.") 115 | model_name = '{0}/model_last.pt.tar'.format(self.params.folder_path) 116 | saved_dict = {'state_dict': model.state_dict(), 117 | 'epoch': epoch, 118 | 'lr': self.params.lr, 119 | 'params_dict': self.params.to_dict()} 120 | self.save_checkpoint(saved_dict, False, model_name) 121 | if epoch in self.params.save_on_epochs: 122 | logger.info(f'Saving model on epoch {epoch}') 123 | self.save_checkpoint(saved_dict, False, 124 | filename=f'{model_name}.epoch_{epoch}') 125 | if val_loss < self.best_loss: 126 | self.save_checkpoint(saved_dict, False, f'{model_name}.best') 127 | self.best_loss = val_loss 128 | 129 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 130 | if not self.params.save_model: 131 | return False 132 | torch.save(state, filename) 133 | 134 | if is_best: 135 | copyfile(filename, 'model_best.pth.tar') 136 | 137 | def flush_writer(self): 138 | if self.tb_writer: 139 | self.tb_writer.flush() 140 | 141 | def plot(self, x, y, name): 142 | if self.tb_writer is not None: 143 | self.tb_writer.add_scalar(tag=name, scalar_value=y, global_step=x) 144 | self.flush_writer() 145 | else: 146 | return False 147 | 148 | def report_training_losses_scales(self, batch_id, epoch): 149 | if not self.params.report_train_loss or \ 150 | batch_id % self.params.log_interval != 0: 151 | return 152 | total_batches = len(self.task.train_loader) 153 | losses = [f'{x}: {np.mean(y):.2f}' 154 | for x, y in self.params.running_losses.items()] 155 | 
scales = [f'{x}: {np.mean(y):.2f}' 156 | for x, y in self.params.running_scales.items()] 157 | logger.info( 158 | f'Epoch: {epoch:3d}. ' 159 | f'Batch: {batch_id:5d}/{total_batches}. ' 160 | f' Losses: {losses}.' 161 | f' Scales: {scales}') 162 | for name, values in self.params.running_losses.items(): 163 | self.plot(epoch * total_batches + batch_id, np.mean(values), 164 | f'Train/Loss_{name}') 165 | for name, values in self.params.running_scales.items(): 166 | self.plot(epoch * total_batches + batch_id, np.mean(values), 167 | f'Train/Scale_{name}') 168 | 169 | self.params.running_losses = defaultdict(list) 170 | self.params.running_scales = defaultdict(list) 171 | 172 | @staticmethod 173 | def fix_random(seed=1): 174 | from torch.backends import cudnn 175 | 176 | logger.warning('Setting random seed for reproducible results.') 177 | random.seed(seed) 178 | torch.manual_seed(seed) 179 | torch.cuda.manual_seed_all(seed) 180 | cudnn.deterministic = False 181 | cudnn.enabled = True 182 | cudnn.benchmark = True # note: benchmark=True favors speed; runs are not bit-wise reproducible 183 | np.random.seed(seed) 184 | 185 | return True 186 | -------------------------------------------------------------------------------- /losses/loss_functions.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch import nn, autograd 5 | from torch.nn import functional as F 6 | 7 | from models.model import Model 8 | from utils.parameters import Params 9 | from utils.utils import th, record_time 10 | 11 | 12 | def compute_all_losses_and_grads(loss_tasks, attack, model, criterion, 13 | batch, batch_back, 14 | compute_grad=None): 15 | grads = {} 16 | loss_values = {} 17 | for t in loss_tasks: 18 | # if compute_grad: 19 | # model.zero_grad() 20 | if t == 'normal': 21 | loss_values[t], grads[t] = compute_normal_loss(attack.params, 22 | model, 23 | criterion, 24 | batch.inputs, 25 | batch.labels, 26 | grads=compute_grad) 27 | elif t == 'backdoor': 28 | loss_values[t], grads[t] = compute_backdoor_loss(attack.params, 29 | model, 30 | criterion, 31 | batch_back.inputs, 32 | batch_back.labels, 33 | grads=compute_grad) 34 | elif t == 'sentinet_evasion': 35 | loss_values[t], grads[t] = compute_sentinet_evasion( 36 | attack.params, 37 | model, 38 | batch.inputs, 39 | batch_back.inputs, 40 | batch_back.labels, 41 | grads=compute_grad) 42 | elif t == 'mask_norm': 43 | loss_values[t], grads[t] = norm_loss(attack.params, attack.nc_model, 44 | grads=compute_grad) 45 | 46 | # if loss_values[t].mean().item() == 0.0: 47 | # loss_values.pop(t) 48 | # grads.pop(t) 49 | # loss_tasks.remove(t) 50 | return loss_values, grads 51 | 52 | 53 | def compute_normal_loss(params, model, criterion, inputs, 54 | labels, grads): 55 | t = time.perf_counter() 56 | outputs = model(inputs) 57 | record_time(params, t, 'forward') 58 | loss = criterion(outputs, labels) 59 | 60 | if not params.dp: 61 | loss = loss.mean() 62 | 63 | if grads: 64 | t = time.perf_counter() 65 | grads = list(torch.autograd.grad(loss.mean(), 66 | [x for x in model.parameters() if 67 | x.requires_grad], 68 | retain_graph=True)) 69 | record_time(params, t, 'backward') 70 | 71 | return loss, grads 72 | 73 | 74 | def compute_nc_evasion_loss(params, nc_model: Model, model: Model, inputs, 75 | labels, grads=None): 76 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 77 | nc_model.switch_grads(False) 78 | outputs = model(nc_model(inputs)) 79 | loss = criterion(outputs, labels).mean() 80 | 81 | if grads: 82 | grads = get_grads(params, model, loss) 83 | 84 | return 
loss, grads 85 | 86 | 87 | def compute_backdoor_loss(params, model, criterion, inputs_back, 88 | labels_back, grads=None): 89 | t = time.perf_counter() 90 | outputs = model(inputs_back) 91 | record_time(params, t, 'forward') 92 | loss = criterion(outputs, labels_back) 93 | 94 | if params.task == 'Pipa': 95 | loss[labels_back == 0] *= 0.001 96 | if labels_back.sum().item() == 0.0: 97 | loss[:] = 0.0 98 | if not params.dp: 99 | loss = loss.mean() 100 | 101 | if grads: 102 | grads = get_grads(params, model, loss) 103 | 104 | return loss, grads 105 | 106 | 107 | def compute_latent_cosine_similarity(params: Params, 108 | model: Model, 109 | fixed_model: Model, 110 | inputs, 111 | grads=None): 112 | if not fixed_model: 113 | return torch.tensor(0.0), None 114 | t = time.perf_counter() 115 | with torch.no_grad(): 116 | _, fixed_latent = fixed_model(inputs, latent=True) 117 | _, latent = model(inputs) 118 | record_time(params, t, 'forward') 119 | 120 | loss = -torch.cosine_similarity(latent, fixed_latent).mean() + 1 121 | if grads: 122 | grads = get_grads(params, model, loss) 123 | 124 | return loss, grads 125 | 126 | 127 | def compute_spectral_evasion_loss(params: Params, 128 | model: Model, 129 | fixed_model: Model, 130 | inputs, 131 | grads=None): 132 | """ 133 | Evades spectral analysis defense. Aims to preserve the latent representation 134 | on non-backdoored inputs. Uses a checkpoint non-backdoored `fixed_model` to 135 | compare the outputs. Uses euclidean distance as penalty. 136 | 137 | 138 | :param params: training parameters 139 | :param model: current model 140 | :param fixed_model: saved non-backdoored model as a reference. 141 | :param inputs: training data inputs 142 | :param grads: compute gradients. 143 | 144 | :return: 145 | """ 146 | 147 | if not fixed_model: 148 | return torch.tensor(0.0), None 149 | t = time.perf_counter() 150 | with torch.no_grad(): 151 | _, fixed_latent = fixed_model(inputs, latent=True) 152 | _, latent = model(inputs, latent=True) 153 | record_time(params, t, 'latent_fixed') 154 | if params.spectral_similarity == 'norm': 155 | loss = torch.norm(latent - fixed_latent, dim=1).mean() 156 | elif params.spectral_similarity == 'cosine': 157 | loss = -torch.cosine_similarity(latent, fixed_latent).mean() + 1 158 | else: 159 | raise ValueError(f'Specify correct similarity metric for ' 160 | f'spectral evasion: [norm, cosine].') 161 | if grads: 162 | grads = get_grads(params, model, loss) 163 | 164 | return loss, grads 165 | 166 | 167 | def get_latent_grads(params, model, inputs, labels): 168 | model.eval() 169 | model.zero_grad() 170 | t = time.perf_counter() 171 | pred = model(inputs) 172 | record_time(params, t, 'forward') 173 | z = torch.zeros_like(pred) 174 | 175 | z[list(range(labels.shape[0])), labels] = 1 176 | 177 | pred = pred * z 178 | t = time.perf_counter() 179 | pred.sum().backward(retain_graph=True) 180 | record_time(params, t, 'backward') 181 | 182 | gradients = model.get_gradient()[labels == params.backdoor_label] 183 | pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]).detach() 184 | model.zero_grad() 185 | 186 | return pooled_gradients 187 | 188 | 189 | def compute_sentinet_evasion(params, model, inputs, inputs_back, labels_back, 190 | grads=None): 191 | """The GradCam design is taken from: 192 | https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82 193 | 194 | :param params: 195 | :param model: 196 | :param inputs: 197 | :param inputs_back: 198 | :param labels_back: 199 | :param grads: 200 | :return: 201 | """ 202 | 
pooled = get_latent_grads(params, model, inputs, labels_back) 203 | t = time.perf_counter() 204 | features = model.features(inputs) 205 | features = features * pooled.view(1, 512, 1, 1) 206 | 207 | pooled_back = get_latent_grads(params, model, inputs_back, labels_back) 208 | back_features = model.features(inputs_back) 209 | back_features = back_features * pooled_back.view(1, 512, 1, 1) 210 | 211 | features = torch.mean(features, dim=[0, 1], keepdim=True) 212 | features = F.relu(features) / features.max() 213 | 214 | back_features = torch.mean(back_features, dim=[0, 1], keepdim=True) 215 | back_features = F.relu( 216 | back_features) / back_features.max() 217 | loss = F.relu(back_features - features).max() * 10 218 | if grads: 219 | loss.backward(retain_graph=True) 220 | grads = copy_grad(model) 221 | 222 | return loss, grads 223 | 224 | 225 | def norm_loss(params, model, grads=None): 226 | if params.nc_p_norm == 1: 227 | norm = torch.sum(th(model.mask)) 228 | elif params.nc_p_norm == 2: 229 | norm = torch.norm(th(model.mask)) 230 | else: 231 | raise ValueError('Not support mask norm.') 232 | 233 | if grads: 234 | grads = get_grads(params, model, norm) 235 | model.zero_grad() 236 | 237 | return norm, grads 238 | 239 | 240 | def get_grads(params, model, loss): 241 | t = time.perf_counter() 242 | grads = list(torch.autograd.grad(loss.mean(), 243 | [x for x in model.parameters() if 244 | x.requires_grad], 245 | retain_graph=True)) 246 | record_time(params, t, 'backward') 247 | 248 | return grads 249 | 250 | 251 | # UNTESTED 252 | def estimate_fisher(params, model, data_loader, sample_size): 253 | # sample loglikelihoods from the dataset. 254 | loglikelihoods = [] 255 | for x, y in data_loader: 256 | x = x.to(params.device) 257 | y = y.to(params.device) 258 | loglikelihoods.append( 259 | F.log_softmax(model(x)[0], dim=1)[range(params.batch_size), y] 260 | ) 261 | if len(loglikelihoods) >= sample_size // params.batch_size: 262 | break 263 | # estimate the fisher information of the parameters. 264 | loglikelihoods = torch.cat(loglikelihoods).unbind() 265 | loglikelihood_grads = zip(*[autograd.grad( 266 | l, model.parameters(), 267 | retain_graph=(i < len(loglikelihoods)) 268 | ) for i, l in enumerate(loglikelihoods, 1)]) 269 | loglikelihood_grads = [torch.stack(gs) for gs in loglikelihood_grads] 270 | fisher_diagonals = [(g ** 2).mean(0) for g in loglikelihood_grads] 271 | param_names = [ 272 | n.replace('.', '__') for n, p in model.named_parameters() 273 | ] 274 | return {n: f.detach() for n, f in zip(param_names, fisher_diagonals)} 275 | 276 | 277 | def consolidate(model, fisher): 278 | for n, p in model.named_parameters(): 279 | n = n.replace('.', '__') 280 | model.register_buffer('{}_mean'.format(n), p.data.clone()) 281 | model.register_buffer('{}_fisher' 282 | .format(n), fisher[n].data.clone()) 283 | 284 | 285 | def ewc_loss(params: Params, model: nn.Module, grads=None): 286 | try: 287 | losses = [] 288 | for n, p in model.named_parameters(): 289 | # retrieve the consolidated mean and fisher information. 290 | n = n.replace('.', '__') 291 | mean = getattr(model, '{}_mean'.format(n)) 292 | fisher = getattr(model, '{}_fisher'.format(n)) 293 | # wrap mean and fisher in variables. 294 | # calculate a ewc loss. 
295 |             # Gaussian distribution with the estimated mean and the
296 |             # estimated Cramer-Rao lower bound variance, which is
297 |             # equivalent to the inverse of the Fisher information)
298 |             losses.append((fisher * (p - mean) ** 2).sum())
299 |         loss = (model.lamda / 2) * sum(losses)
300 |         if grads:
301 |             # get_grads() differentiates the loss itself; calling
302 |             # loss.backward() first would free the computation graph.
303 |             grads = get_grads(params, model, loss)
304 |             return loss, grads
305 |         else:
306 |             return loss, None
307 | 
308 |     except AttributeError:
309 |         # ewc loss is 0 if there are no consolidated parameters.
310 |         print('ewc_loss: no consolidated parameters, returning zero loss')
311 |         return torch.zeros(1).to(params.device), grads
312 | 
313 | 
314 | def copy_grad(model: nn.Module):
315 |     grads = list()
316 |     for name, params in model.named_parameters():
317 |         if not params.requires_grad:
318 |             # frozen parameters have no gradient to copy
319 |             continue
320 |         else:
321 |             grads.append(params.grad.clone().detach())
322 |     model.zero_grad()
323 |     return grads
324 | 
--------------------------------------------------------------------------------
/metrics/accuracy_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 | 
4 | 
5 | class AccuracyMetric(Metric):
6 | 
7 |     def __init__(self, top_k=(1,)):
8 |         self.top_k = top_k
9 |         self.main_metric_name = 'Top-1'
10 |         super().__init__(name='Accuracy', train=False)
11 | 
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor):
14 |         """Computes the precision@k for the specified values of k"""
15 |         max_k = max(self.top_k)
16 |         batch_size = labels.shape[0]
17 | 
18 |         _, pred = outputs.topk(max_k, 1, True, True)
19 |         pred = pred.t()
20 |         correct = pred.eq(labels.view(1, -1).expand_as(pred))
21 | 
22 |         res = dict()
23 |         for k in self.top_k:
24 |             # reshape (not view): the transposed tensor is non-contiguous
25 |             correct_k = correct[:k].reshape(-1).float().sum(0)
26 |             res[f'Top-{k}'] = (correct_k.mul_(100.0 / batch_size)).item()
27 |         return res
28 | 
--------------------------------------------------------------------------------
/metrics/metric.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import Dict, Any
4 | 
5 | import numpy as np
6 | 
7 | logger = logging.getLogger('logger')
8 | 
9 | 
10 | class Metric:
11 |     name: str
12 |     train: bool
13 |     plottable: bool = True
14 |     running_metric = None
15 |     main_metric_name = None
16 | 
17 |     def __init__(self, name, train=False):
18 |         self.train = train
19 |         self.name = name
20 | 
21 |         self.running_metric = defaultdict(list)
22 | 
23 |     def __repr__(self):
24 |         metrics = self.get_value()
25 |         text = [f'{key}: {val:.2f}' for key, val in metrics.items()]
26 |         return f'{self.name}: ' + ', '.join(text)
27 | 
28 |     def compute_metric(self, outputs, labels) -> Dict[str, Any]:
29 |         raise NotImplementedError
30 | 
31 |     def accumulate_on_batch(self, outputs=None, labels=None):
32 |         current_metrics = self.compute_metric(outputs, labels)
33 |         for key, value in current_metrics.items():
34 |             self.running_metric[key].append(value)
35 | 
36 |     def get_value(self) -> Dict[str, np.ndarray]:
37 |         metrics = dict()
38 |         for key, value in self.running_metric.items():
39 |             metrics[key] = np.mean(value)
40 | 
41 |         return metrics
42 | 
43 |     def get_main_metric_value(self):
44 |         if not self.main_metric_name:
45 |             raise ValueError(f'For metric {self.name} define '
46 |                              f'attribute main_metric_name.')
47 |         metrics = self.get_value()
48 |         return metrics[self.main_metric_name]
49 | 
50 |     def reset_metric(self):
51 |         self.running_metric = defaultdict(list)
52 | 
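    # Annotation: typical lifecycle (a sketch) -- call
    # accumulate_on_batch(outputs, labels) once per evaluation batch, read
    # get_main_metric_value() (or get_value() for all sub-metrics) at the end,
    # and call reset_metric() before the next evaluation pass.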
53 |     def plot(self, tb_writer, step, tb_prefix=''):
54 |         if tb_writer is not None and self.plottable:
55 |             metrics = self.get_value()
56 |             for key, value in metrics.items():
57 |                 tb_writer.add_scalar(tag=f'{tb_prefix}/{self.name}_{key}',
58 |                                      scalar_value=value,
59 |                                      global_step=step)
60 |             tb_writer.flush()
61 |         else:
62 |             return False
63 | 
--------------------------------------------------------------------------------
/metrics/test_loss_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 | 
4 | 
5 | class TestLossMetric(Metric):
6 | 
7 |     def __init__(self, criterion, train=False):
8 |         self.criterion = criterion
9 |         self.main_metric_name = 'value'
10 |         super().__init__(name='Loss', train=False)
11 | 
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor):
14 |         """Computes the average loss over the batch."""
15 |         loss = self.criterion(outputs, labels)
16 |         return {'value': loss.mean().item()}
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-secure/FedGame/6be7b9e3f9e3a1aa1822fa8b319ec85c6ec27d52/models/__init__.py
--------------------------------------------------------------------------------
/models/face_ident.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from models.resnet import resnet50
3 | from torch.nn import init
4 | 
5 | # THIS IS STILL NOT WORKING
6 | 
7 | def weights_init_kaiming(m):
8 |     classname = m.__class__.__name__
9 |     # print(classname)
10 |     if classname.find('Conv') != -1:
11 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # For old pytorch, you may use kaiming_normal.
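    # Annotation: fan_in preserves the variance of activations in the forward
    # pass (used for the conv layers here), while fan_out below preserves the
    # variance of gradients in the backward pass (used for the linear layers).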
12 | elif classname.find('Linear') != -1: 13 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 14 | init.constant_(m.bias.data, 0.0) 15 | elif classname.find('BatchNorm1d') != -1: 16 | init.normal_(m.weight.data, 1.0, 0.02) 17 | init.constant_(m.bias.data, 0.0) 18 | 19 | def weights_init_classifier(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('Linear') != -1: 22 | init.normal_(m.weight.data, std=0.001) 23 | init.constant_(m.bias.data, 0.0) 24 | 25 | # Defines the new fc layer and classification layer 26 | # |--Linear--|--bn--|--relu--|--Linear--| 27 | class ClassBlock(nn.Module): 28 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False): 29 | super(ClassBlock, self).__init__() 30 | self.return_f = return_f 31 | add_block = [] 32 | if linear: 33 | add_block += [nn.Linear(input_dim, num_bottleneck)] 34 | else: 35 | num_bottleneck = input_dim 36 | if bnorm: 37 | add_block += [nn.BatchNorm1d(num_bottleneck)] 38 | if relu: 39 | add_block += [nn.LeakyReLU(0.1)] 40 | if droprate>0: 41 | add_block += [nn.Dropout(p=droprate)] 42 | add_block = nn.Sequential(*add_block) 43 | add_block.apply(weights_init_kaiming) 44 | 45 | classifier = [] 46 | classifier += [nn.Linear(num_bottleneck, class_num)] 47 | classifier = nn.Sequential(*classifier) 48 | classifier.apply(weights_init_classifier) 49 | 50 | self.add_block = add_block 51 | self.classifier = classifier 52 | def forward(self, x): 53 | x = self.add_block(x) 54 | if self.return_f: 55 | f = x 56 | x = self.classifier(x) 57 | return x,f 58 | else: 59 | x = self.classifier(x) 60 | return x 61 | 62 | # Define the ResNet50-based Model 63 | class ft_net(nn.Module): 64 | 65 | def __init__(self, class_num, droprate=0.5, stride=2): 66 | super(ft_net, self).__init__() 67 | model_ft = resnet50(pretrained=True) 68 | # avg pooling to global pooling 69 | if stride == 1: 70 | model_ft.layer4[0].downsample[0].stride = (1,1) 71 | model_ft.layer4[0].conv2.stride = (1,1) 72 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 73 | self.model = model_ft 74 | self.classifier = ClassBlock(2048, class_num, droprate) 75 | 76 | def forward(self, x): 77 | x = self.model.conv1(x) 78 | x = self.model.bn1(x) 79 | x = self.model.relu(x) 80 | x = self.model.maxpool(x) 81 | x = self.model.layer1(x) 82 | x = self.model.layer2(x) 83 | x = self.model.layer3(x) 84 | x = self.model.layer4(x) 85 | x = self.model.avgpool(x) 86 | x = x.view(x.size(0), x.size(1)) 87 | x = self.classifier(x) 88 | return x, None 89 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Model(nn.Module): 5 | """ 6 | Base class for models with added support for GradCam activation map 7 | and a SentiNet defense. The GradCam design is taken from: 8 | https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82 9 | If you are not planning to utilize SentiNet defense just import any model 10 | you like for your tasks. 
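    Subclasses implement features() and forward(x, latent=False); when
    latent=True, forward() is expected to return the tuple (out, latent).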
11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.gradient = None 16 | 17 | def activations_hook(self, grad): 18 | self.gradient = grad 19 | 20 | def get_gradient(self): 21 | return self.gradient 22 | 23 | def get_activations(self, x): 24 | return self.features(x) 25 | 26 | def switch_grads(self, enable=True): 27 | for i, n in self.named_parameters(): 28 | n.requires_grad_(enable) 29 | 30 | def features(self, x): 31 | """ 32 | Get latent representation, eg logit layer. 33 | :param x: 34 | :return: 35 | """ 36 | raise NotImplemented 37 | 38 | def forward(self, x, latent=False): 39 | raise NotImplemented 40 | -------------------------------------------------------------------------------- /models/simple.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from models.model import Model 5 | 6 | 7 | class SimpleNet(Model): 8 | def __init__(self, num_classes): 9 | super().__init__() 10 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 11 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 12 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 13 | self.fc2 = nn.Linear(500, num_classes) 14 | 15 | def features(self, x): 16 | x = F.relu(self.conv1(x)) 17 | x = F.max_pool2d(x, 2, 2) 18 | x = F.relu(self.conv2(x)) 19 | x = F.max_pool2d(x, 2, 2) 20 | return x 21 | 22 | def forward(self, x, latent=False): 23 | x = F.relu(self.conv1(x)) 24 | x = F.max_pool2d(x, 2, 2) 25 | x = F.relu(self.conv2(x)) 26 | x = F.max_pool2d(x, 2, 2) 27 | if x.requires_grad: 28 | x.register_hook(self.activations_hook) 29 | x = x.view(-1, 4 * 4 * 50) 30 | x = F.relu(self.fc1(x)) 31 | x = self.fc2(x) 32 | out = F.log_softmax(x, dim=1) 33 | if latent: 34 | return out, x 35 | else: 36 | return out 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorlog~=4.4.0 2 | dataclasses~=0.6 3 | dill~=0.3.2 4 | gitpython 5 | matplotlib~=3.2.1 6 | numpy~=1.18.4 7 | pillow>=8.2.0 8 | prompt_toolkit~=3.0.5 9 | PyYAML==5.4 10 | scikit-learn~=0.23.2 11 | scipy~=1.5.4 12 | sklearn~=0.0 13 | tensorboard~=2.3.0 14 | torch~=1.7.0 15 | torchvision~=0.8.1 16 | transformers~=2.1.1 17 | torchtext~=0.7.0 18 | tqdm~=4.46.0 19 | protobuf~=3.20.0 -------------------------------------------------------------------------------- /synthesizers/block_synthesizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torchvision.transforms import transforms, functional 5 | 6 | from synthesizers.synthesizer import Synthesizer 7 | from tasks.task import Task 8 | 9 | transform_to_image = transforms.ToPILImage() 10 | transform_to_tensor = transforms.ToTensor() 11 | 12 | 13 | class BlockSynthesizer(Synthesizer): 14 | pattern_tensor: torch.Tensor = torch.tensor([ 15 | [1., 1., 1., 1.], 16 | [1., 1., 1., 1.], 17 | [1., 1., 1., 1.], 18 | [1., 1., 1., 1.], 19 | ]) 20 | "Just some random 2D pattern." 21 | 22 | x_top = 23 23 | "X coordinate to put the backdoor into." 24 | y_top = 23 25 | "Y coordinate to put the backdoor into." 26 | 27 | mask_value = -10 28 | "A tensor coordinate with this value won't be applied to the image." 29 | 30 | resize_scale = (5, 10) 31 | "If the pattern is dynamically placed, resize the pattern." 32 | 33 | mask: torch.Tensor = None 34 | "A mask used to combine backdoor pattern with the original image." 
35 | 36 | pattern: torch.Tensor = None 37 | "A tensor of the `input.shape` filled with `mask_value` except backdoor." 38 | 39 | def __init__(self, task: Task): 40 | super().__init__(task) 41 | self.make_pattern(self.pattern_tensor, self.x_top, self.y_top) 42 | 43 | def make_pattern(self, pattern_tensor, x_top, y_top): 44 | full_image = torch.zeros(self.params.input_shape) 45 | full_image.fill_(self.mask_value) 46 | 47 | x_bot = x_top + pattern_tensor.shape[0] 48 | y_bot = y_top + pattern_tensor.shape[1] 49 | 50 | if x_bot >= self.params.input_shape[1] or \ 51 | y_bot >= self.params.input_shape[2]: 52 | raise ValueError(f'Position of backdoor outside image limits:' 53 | f'image: {self.params.input_shape}, but backdoor' 54 | f'ends at ({x_bot}, {y_bot})') 55 | 56 | full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor 57 | 58 | self.mask = 1 * (full_image != self.mask_value).to(self.params.device) 59 | self.pattern = self.task.normalize(full_image).to(self.params.device) 60 | 61 | def synthesize_inputs(self, batch, attack_portion=None): 62 | pattern, mask = self.get_pattern() 63 | batch.inputs[:attack_portion] = (1 - mask) * \ 64 | batch.inputs[:attack_portion] + \ 65 | mask * pattern 66 | 67 | return 68 | 69 | def synthesize_labels(self, batch, attack_portion=None): 70 | batch.labels[:attack_portion].fill_(self.params.backdoor_label) 71 | 72 | return 73 | 74 | def get_pattern(self): 75 | if self.params.backdoor_dynamic_position: 76 | resize = random.randint(self.resize_scale[0], self.resize_scale[1]) 77 | pattern = self.pattern_tensor 78 | if random.random() > 0.5: 79 | pattern = functional.hflip(pattern) 80 | image = transform_to_image(pattern) 81 | pattern = transform_to_tensor( 82 | functional.resize(image, 83 | resize, interpolation=0)).squeeze() 84 | 85 | x = random.randint(0, self.params.input_shape[1] \ 86 | - pattern.shape[0] - 1) 87 | y = random.randint(0, self.params.input_shape[2] \ 88 | - pattern.shape[1] - 1) 89 | self.make_pattern(pattern, x, y) 90 | 91 | return self.pattern, self.mask 92 | -------------------------------------------------------------------------------- /synthesizers/complex_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from synthesizers.pattern_synthesizer import PatternSynthesizer 4 | 5 | 6 | class ComplexSynthesizer(PatternSynthesizer): 7 | """ 8 | For physical backdoors it's ok to train using pixel pattern that 9 | represents the physical object in the real scene. 10 | """ 11 | 12 | pattern_tensor = torch.tensor([[1.]]) 13 | 14 | def synthesize_labels(self, batch, attack_portion=None): 15 | batch.labels[:attack_portion] = batch.aux[:attack_portion] 16 | return 17 | -------------------------------------------------------------------------------- /synthesizers/pattern_synthesizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torchvision.transforms import transforms, functional 5 | 6 | from synthesizers.synthesizer import Synthesizer 7 | from tasks.task import Task 8 | 9 | transform_to_image = transforms.ToPILImage() 10 | transform_to_tensor = transforms.ToTensor() 11 | 12 | 13 | class PatternSynthesizer(Synthesizer): 14 | pattern_tensor: torch.Tensor = torch.tensor([ 15 | [1., 0., 1.], 16 | [-10., 1., -10.], 17 | [-10., -10., 0.], 18 | [-10., 1., -10.], 19 | [1., 0., 1.] 20 | ]) 21 | "Just some random 2D pattern." 22 | 23 | x_top = 3 24 | "X coordinate to put the backdoor into." 
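    # Annotation: (x_top, y_top) are measured from the top-left corner of the
    # image; make_pattern() below raises a ValueError if the pattern would
    # extend past the bounds given by params.input_shape.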
25 | y_top = 23 26 | "Y coordinate to put the backdoor into." 27 | 28 | mask_value = -10 29 | "A tensor coordinate with this value won't be applied to the image." 30 | 31 | resize_scale = (5, 10) 32 | "If the pattern is dynamically placed, resize the pattern." 33 | 34 | mask: torch.Tensor = None 35 | "A mask used to combine backdoor pattern with the original image." 36 | 37 | pattern: torch.Tensor = None 38 | "A tensor of the `input.shape` filled with `mask_value` except backdoor." 39 | 40 | def __init__(self, task: Task): 41 | super().__init__(task) 42 | self.make_pattern(self.pattern_tensor, self.x_top, self.y_top) 43 | 44 | def make_pattern(self, pattern_tensor, x_top, y_top): 45 | full_image = torch.zeros(self.params.input_shape) 46 | full_image.fill_(self.mask_value) 47 | 48 | x_bot = x_top + pattern_tensor.shape[0] 49 | y_bot = y_top + pattern_tensor.shape[1] 50 | 51 | if x_bot >= self.params.input_shape[1] or \ 52 | y_bot >= self.params.input_shape[2]: 53 | raise ValueError(f'Position of backdoor outside image limits:' 54 | f'image: {self.params.input_shape}, but backdoor' 55 | f'ends at ({x_bot}, {y_bot})') 56 | 57 | full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor 58 | 59 | self.mask = 1 * (full_image != self.mask_value).to(self.params.device) 60 | self.pattern = self.task.normalize(full_image).to(self.params.device) 61 | 62 | def synthesize_inputs(self, batch, attack_portion=None): 63 | pattern, mask = self.get_pattern() 64 | batch.inputs[:attack_portion] = (1 - mask) * \ 65 | batch.inputs[:attack_portion] + \ 66 | mask * pattern 67 | 68 | return 69 | 70 | def synthesize_labels(self, batch, attack_portion=None): 71 | batch.labels[:attack_portion].fill_(self.params.backdoor_label) 72 | 73 | return 74 | 75 | def get_pattern(self): 76 | if self.params.backdoor_dynamic_position: 77 | resize = random.randint(self.resize_scale[0], self.resize_scale[1]) 78 | pattern = self.pattern_tensor 79 | if random.random() > 0.5: 80 | pattern = functional.hflip(pattern) 81 | image = transform_to_image(pattern) 82 | pattern = transform_to_tensor( 83 | functional.resize(image, 84 | resize, interpolation=0)).squeeze() 85 | 86 | x = random.randint(0, self.params.input_shape[1] \ 87 | - pattern.shape[0] - 1) 88 | y = random.randint(0, self.params.input_shape[2] \ 89 | - pattern.shape[1] - 1) 90 | self.make_pattern(pattern, x, y) 91 | 92 | return self.pattern, self.mask 93 | -------------------------------------------------------------------------------- /synthesizers/physical_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from synthesizers.synthesizer import Synthesizer 3 | 4 | 5 | class PhysicalSynthesizer(Synthesizer): 6 | """ 7 | For physical backdoors it's ok to train using pixel pattern that 8 | represents the physical object in the real scene. 
9 | """ 10 | 11 | pattern_tensor = torch.tensor([[1.]]) -------------------------------------------------------------------------------- /synthesizers/singlepixel_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from synthesizers.pattern_synthesizer import PatternSynthesizer 3 | 4 | 5 | class SinglePixelSynthesizer(PatternSynthesizer): 6 | pattern_tensor = torch.tensor([[1.]]) 7 | -------------------------------------------------------------------------------- /synthesizers/synthesizer.py: -------------------------------------------------------------------------------- 1 | from tasks.batch import Batch 2 | from tasks.task import Task 3 | from utils.parameters import Params 4 | 5 | 6 | class Synthesizer: 7 | params: Params 8 | task: Task 9 | 10 | def __init__(self, task: Task): 11 | self.task = task 12 | self.params = task.params 13 | 14 | def make_backdoor_batch(self, batch: Batch, test=False, attack=True, ratio=None) -> Batch: 15 | 16 | # Don't attack if only normal loss task. 17 | if (not attack) or (self.params.loss_tasks == ['normal'] and not test): 18 | return batch 19 | 20 | if ratio is not None: 21 | attack_portion = round( 22 | batch.batch_size * ratio) 23 | elif test: 24 | attack_portion = batch.batch_size 25 | else: 26 | attack_portion = round( 27 | batch.batch_size * self.params.poisoning_proportion) 28 | 29 | backdoored_batch = batch.clone() 30 | self.apply_backdoor(backdoored_batch, attack_portion) 31 | 32 | return backdoored_batch 33 | 34 | def apply_backdoor(self, batch, attack_portion): 35 | """ 36 | Modifies only a portion of the batch (represents batch poisoning). 37 | 38 | :param batch: 39 | :return: 40 | """ 41 | self.synthesize_inputs(batch=batch, attack_portion=attack_portion) 42 | self.synthesize_labels(batch=batch, attack_portion=attack_portion) 43 | 44 | return 45 | 46 | def synthesize_inputs(self, batch, attack_portion=None): 47 | raise NotImplemented 48 | 49 | def synthesize_labels(self, batch, attack_portion=None): 50 | raise NotImplemented 51 | -------------------------------------------------------------------------------- /tasks/batch.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | 5 | @dataclass 6 | class Batch: 7 | batch_id: int 8 | inputs: torch.Tensor 9 | labels: torch.Tensor 10 | 11 | # For PIPA experiment we use this field to store identity label. 
12 |     aux: torch.Tensor = None
13 | 
14 |     def __post_init__(self):
15 |         self.batch_size = self.inputs.shape[0]
16 | 
17 |     def to(self, device):
18 |         inputs = self.inputs.to(device)
19 |         labels = self.labels.to(device)
20 |         if self.aux is not None:
21 |             aux = self.aux.to(device)
22 |         else:
23 |             aux = None
24 |         return Batch(self.batch_id, inputs, labels, aux)
25 | 
26 |     def clone(self):
27 |         inputs = self.inputs.clone()
28 |         labels = self.labels.clone()
29 |         if self.aux is not None:
30 |             aux = self.aux.clone()
31 |         else:
32 |             aux = None
33 |         return Batch(self.batch_id, inputs, labels, aux)
34 | 
35 |     def clip(self, batch_size):
36 |         if batch_size is None:
37 |             return self
38 | 
39 |         inputs = self.inputs[:batch_size]
40 |         labels = self.labels[:batch_size]
41 | 
42 |         if self.aux is None:
43 |             aux = None
44 |         else:
45 |             aux = self.aux[:batch_size]
46 | 
47 |         return Batch(self.batch_id, inputs, labels, aux)
--------------------------------------------------------------------------------
/tasks/celeba_helper.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import torch.utils.data as torch_data
4 | from torchvision import transforms
5 | from torchvision.datasets import CelebA
6 | 
7 | from data_helpers.task_helper import TaskHelper
8 | 
9 | logger = logging.getLogger('logger')
10 | 
11 | 
12 | class CelebaHelper(TaskHelper):
13 | 
14 |     def load_data(self):
15 |         logger.error('Celeba dataset is unfinished, needs more work')
16 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
17 |                                          std=[0.229, 0.224, 0.225])
18 |         train_transform = transforms.Compose([
19 |             transforms.Resize(128),
20 |             transforms.CenterCrop(128),
21 |             # transforms.RandomResizedCrop(178, scale=(0.9,1.0)),
22 |             transforms.RandomHorizontalFlip(),
23 |             transforms.ToTensor(),
24 |             normalize,
25 |         ])
26 |         test_transform = transforms.Compose([
27 |             transforms.Resize(128),
28 |             transforms.CenterCrop(128),
29 |             # transforms.CenterCrop((178, 178)),
30 |             # transforms.Resize((128, 128)),
31 |             transforms.ToTensor(),
32 |             normalize,
33 |         ])
34 | 
35 |         self.train_dataset = CelebA(root=self.params.data_path,
36 |                                     target_type='identity',  # ['identity',
37 |                                     # 'attr'],
38 |                                     split='train', transform=train_transform)
39 | 
40 |         self.test_dataset = CelebA(root=self.params.data_path,
41 |                                    target_type='identity',
42 |                                    split='test', transform=test_transform)
43 | 
44 |         self.train_loader = torch_data.DataLoader(self.train_dataset,
45 |                                                   batch_size=self.params.batch_size,
46 |                                                   shuffle=True, num_workers=8,
47 |                                                   pin_memory=True)
48 |         self.test_loader = torch_data.DataLoader(self.test_dataset,
49 |                                                  batch_size=self.params.test_batch_size,
50 |                                                  shuffle=False, num_workers=2)
--------------------------------------------------------------------------------
/tasks/cifar10_task.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch import nn
3 | from torchvision.transforms import transforms
4 | import torch
5 | import torch.utils.data as torch_data
6 | 
7 | from models.resnet import resnet18
8 | from tasks.task import Task
9 | 
10 | 
11 | class Cifar10Task(Task):
12 |     normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
13 |                                      (0.2023, 0.1994, 0.2010))
14 | 
15 |     def load_data(self):
16 |         self.load_cifar_data()
17 | 
18 |     def load_cifar_data(self):
19 |         transform_train = transforms.Compose([
20 |             transforms.ToTensor(),
21 |             self.normalize
22 |         ])
23 | 
24 |         transform_test = transforms.Compose([
25 |             transforms.ToTensor(),
26 |             self.normalize
27 |         ])
28 | 
29 |         original_train_dataset = torchvision.datasets.CIFAR10(
30 |             root=self.params.data_path,
31 |             train=True,
32 |             download=True,
33 |             transform=transform_train)
34 | 
35 |         clean_size = round(self.params.clean_ratio * len(original_train_dataset))
36 |         train_size = len(original_train_dataset) - clean_size
37 |         self.train_dataset, self.clean_dataset = torch_data.random_split(original_train_dataset,
38 |                                                                          lengths=[train_size, clean_size],
39 |                                                                          generator=torch.Generator().manual_seed(self.params.random_seed))
40 |         # Default clean loader over the held-out clean split; the branches
41 |         # below rebuild it whenever a separate clean-set source is configured.
42 |         self.clean_loader = torch_data.DataLoader(self.clean_dataset,
43 |                                                   batch_size=self.params.batch_size,
44 |                                                   shuffle=True,
45 |                                                   num_workers=0)
46 |         if self.params.clean_set_dataset == 'CIFAR100':
47 |             self.clean_dataset = torchvision.datasets.CIFAR100(root=self.params.data_path,
48 |                                                                train=True,
49 |                                                                download=True,
50 |                                                                transform=transform_train)
51 |             tot = len(self.clean_dataset)
52 |             self.clean_dataset, _ = torch_data.random_split(self.clean_dataset,
53 |                                                             lengths=[clean_size, tot - clean_size],
54 |                                                             generator=torch.Generator().manual_seed(self.params.random_seed))
55 |             self.clean_loader = torch_data.DataLoader(self.clean_dataset,
56 |                                                       batch_size=self.params.batch_size,
57 |                                                       shuffle=True,
58 |                                                       num_workers=0)
59 | 
60 |         elif self.params.clean_set_dataset == 'GTSRB':
61 |             transform_train = transforms.Compose([transform_train, transforms.Resize((32, 32))])
62 |             self.clean_dataset = torchvision.datasets.GTSRB(root=self.params.data_path,
63 |                                                             split='train',
64 |                                                             download=True,
65 |                                                             transform=transform_train)
66 |             tot = len(self.clean_dataset)
67 |             self.clean_dataset, _ = torch_data.random_split(self.clean_dataset,
68 |                                                             lengths=[clean_size, tot - clean_size],
69 |                                                             generator=torch.Generator().manual_seed(self.params.random_seed))
70 |             self.clean_loader = torch_data.DataLoader(self.clean_dataset,
71 |                                                       batch_size=self.params.batch_size,
72 |                                                       shuffle=True,
73 |                                                       num_workers=0)
74 | 
75 |         elif self.params.clean_set_dataset is None and self.params.clean_classes is not None:
76 |             clean_indices = []
77 |             for cls in self.params.clean_classes:
78 |                 for i in range(clean_size):
79 |                     if self.clean_dataset[i][1] == cls:
80 |                         clean_indices.append(i)
81 |             sampler = torch_data.SubsetRandomSampler(clean_indices, generator=torch.Generator().manual_seed(self.params.random_seed))
82 |             self.clean_loader = torch_data.DataLoader(self.clean_dataset,
83 |                                                       batch_size=self.params.batch_size,
84 |                                                       sampler=sampler,
85 |                                                       num_workers=0)
86 |         else:
87 |             self.clean_loader = torch_data.DataLoader(self.clean_dataset,
88 |                                                       batch_size=self.params.batch_size,
89 |                                                       shuffle=True,
90 |                                                       num_workers=0)
91 | 
92 |         self.train_loader = torch_data.DataLoader(self.train_dataset,
93 |                                                   batch_size=self.params.batch_size,
94 |                                                   shuffle=True,
95 |                                                   num_workers=0)
96 | 
97 |         self.test_dataset = torchvision.datasets.CIFAR10(
98 |             root=self.params.data_path,
99 |             train=False,
100 |             download=True,
101 |             transform=transform_test)
102 |         self.test_loader = torch_data.DataLoader(self.test_dataset,
103 |                                                  batch_size=self.params.test_batch_size,
104 |                                                  shuffle=False,
105 |                                                  num_workers=0)
106 |         self.classes = ('plane', 'car', 'bird', 'cat',
107 |                         'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
108 |         return True
109 | 
110 |     def build_model(self) -> nn.Module:
111 |         if self.params.pretrained:
112 |             model = resnet18(pretrained=True, norm_layer=nn.Identity)
113 | 
114 |             # the model is pretrained on ImageNet; swap the classifier head for CIFAR-10
115 |             model.fc = nn.Linear(512, len(self.classes))
116 |         else:
117 |             model = resnet18(pretrained=False,
118 |                              num_classes=len(self.classes))
119 |         return model
120 | 
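For reference, a minimal standalone sketch of the seeded clean/train split used above. The constants `CLEAN_RATIO` and `SEED` are hypothetical stand-ins for `params.clean_ratio` and `params.random_seed`, which normally come from the YAML config:

```python
import torch
import torch.utils.data as torch_data
import torchvision
from torchvision.transforms import transforms

CLEAN_RATIO, SEED = 0.05, 0  # hypothetical values; set via the config in practice

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
dataset = torchvision.datasets.CIFAR10(
    root='.data', train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize]))

clean_size = round(CLEAN_RATIO * len(dataset))
# The same seed always reproduces the same partition, so every run sees an
# identical defender-held clean set.
train_set, clean_set = torch_data.random_split(
    dataset,
    lengths=[len(dataset) - clean_size, clean_size],
    generator=torch.Generator().manual_seed(SEED))
```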
-------------------------------------------------------------------------------- /tasks/fl/cifar10_fedavg_task.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import random 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | from torch.utils.data.dataloader import DataLoader, Dataset 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | 9 | from tasks.cifar10_task import Cifar10Task 10 | from tasks.fl.fl_task import FederatedLearningTask 11 | 12 | 13 | class Cifar10_FedAvgTask(FederatedLearningTask, Cifar10Task): 14 | 15 | def load_data(self) -> None: 16 | self.load_cifar_data() 17 | train_loaders = self.assign_data(bias=self.params.fl_q) 18 | self.fl_train_loaders = train_loaders 19 | return 20 | 21 | def assign_data(self, bias=1, p=0.1): 22 | num_labels = len(self.classes) 23 | num_workers = self.params.fl_total_participants 24 | server_pc = 0 25 | 26 | # assign data to the clients 27 | other_group_size = (1 - bias) / (num_labels - 1) 28 | worker_per_group = num_workers / num_labels 29 | 30 | #assign training data to each worker 31 | each_worker_data = [[] for _ in range(num_workers)] 32 | each_worker_label = [[] for _ in range(num_workers)] 33 | server_data = [] 34 | server_label = [] 35 | 36 | # compute the labels needed for each class 37 | real_dis = [1. / num_labels for _ in range(num_labels)] 38 | samp_dis = [0 for _ in range(num_labels)] 39 | num1 = int(server_pc * p) 40 | samp_dis[1] = num1 41 | average_num = (server_pc - num1) / (num_labels - 1) 42 | resid = average_num - np.floor(average_num) 43 | sum_res = 0. 44 | for other_num in range(num_labels - 1): 45 | if other_num == 1: 46 | continue 47 | samp_dis[other_num] = int(average_num) 48 | sum_res += resid 49 | if sum_res >= 1.0: 50 | samp_dis[other_num] += 1 51 | sum_res -= 1 52 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 53 | 54 | # randomly assign the data points based on the labels 55 | server_counter = [0 for _ in range(num_labels)] 56 | for x, y in self.train_dataset: 57 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 58 | lower_bound = y * (1. 
- bias) / (num_labels - 1) 59 | rd = np.random.random_sample() 60 | 61 | if rd > upper_bound: 62 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 63 | elif rd < lower_bound: 64 | worker_group = int(np.floor(rd / other_group_size)) 65 | else: 66 | worker_group = y 67 | 68 | if server_counter[y] < samp_dis[y]: 69 | server_data.append(x) 70 | server_label.append(y) 71 | server_counter[y] += 1 72 | else: 73 | rd = np.random.random_sample() 74 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 75 | each_worker_data[selected_worker].append(x) 76 | each_worker_label[selected_worker].append(y) 77 | 78 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 79 | each_worker_data = [each_worker_data[i] for i in random_order] 80 | each_worker_label = [each_worker_label[i] for i in random_order] 81 | 82 | train_loaders = [] 83 | for i in range(len(each_worker_data)): 84 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 85 | train_loader = DataLoader(train_set, 86 | batch_size=self.params.batch_size, 87 | shuffle=True) 88 | train_loaders.append(train_loader) 89 | 90 | return train_loaders 91 | 92 | 93 | class ClientDataset(Dataset): 94 | def __init__(self, data_list, label_list): 95 | super().__init__() 96 | self.data_list = data_list 97 | self.label_list = label_list 98 | 99 | def __len__(self): 100 | return len(self.label_list) 101 | 102 | def __getitem__(self, index): 103 | return self.data_list[index], self.label_list[index] 104 | -------------------------------------------------------------------------------- /tasks/fl/cifar10_fltrust_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from torch.utils.data.dataloader import DataLoader, Dataset 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | 8 | from tasks.cifar10_task import Cifar10Task 9 | from tasks.fl.fl_task import FederatedLearningTask 10 | 11 | 12 | class Cifar10_FLTrustTask(FederatedLearningTask, Cifar10Task): 13 | 14 | def load_data(self) -> None: 15 | self.load_cifar_data() 16 | train_loaders = self.assign_data(bias=self.params.fl_q) 17 | self.fl_train_loaders = train_loaders 18 | return 19 | 20 | def assign_data(self, bias=1, p=0.1): 21 | num_labels = len(self.classes) 22 | num_workers = self.params.fl_total_participants 23 | server_pc = 0 24 | 25 | # assign data to the clients 26 | other_group_size = (1 - bias) / (num_labels - 1) 27 | worker_per_group = num_workers / num_labels 28 | 29 | #assign training data to each worker 30 | each_worker_data = [[] for _ in range(num_workers)] 31 | each_worker_label = [[] for _ in range(num_workers)] 32 | server_data = [] 33 | server_label = [] 34 | 35 | # compute the labels needed for each class 36 | real_dis = [1. / num_labels for _ in range(num_labels)] 37 | samp_dis = [0 for _ in range(num_labels)] 38 | num1 = int(server_pc * p) 39 | samp_dis[1] = num1 40 | average_num = (server_pc - num1) / (num_labels - 1) 41 | resid = average_num - np.floor(average_num) 42 | sum_res = 0. 
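    # Annotation: the loop below distributes the non-target server samples as
    # evenly as integer counts allow -- each class gets floor(average_num),
    # and the fractional residue accumulates in sum_res, which hands out one
    # extra sample whenever it exceeds 1.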
43 | for other_num in range(num_labels - 1): 44 | if other_num == 1: 45 | continue 46 | samp_dis[other_num] = int(average_num) 47 | sum_res += resid 48 | if sum_res >= 1.0: 49 | samp_dis[other_num] += 1 50 | sum_res -= 1 51 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 52 | 53 | # randomly assign the data points based on the labels 54 | server_counter = [0 for _ in range(num_labels)] 55 | for x, y in self.train_dataset: 56 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 57 | lower_bound = y * (1. - bias) / (num_labels - 1) 58 | rd = np.random.random_sample() 59 | 60 | if rd > upper_bound: 61 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 62 | elif rd < lower_bound: 63 | worker_group = int(np.floor(rd / other_group_size)) 64 | else: 65 | worker_group = y 66 | 67 | if server_counter[y] < samp_dis[y]: 68 | server_data.append(x) 69 | server_label.append(y) 70 | server_counter[y] += 1 71 | else: 72 | rd = np.random.random_sample() 73 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 74 | each_worker_data[selected_worker].append(x) 75 | each_worker_label[selected_worker].append(y) 76 | 77 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 78 | each_worker_data = [each_worker_data[i] for i in random_order] 79 | each_worker_label = [each_worker_label[i] for i in random_order] 80 | 81 | train_loaders = [] 82 | for i in range(len(each_worker_data)): 83 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 84 | train_loader = DataLoader(train_set, 85 | batch_size=self.params.batch_size, 86 | shuffle=True) 87 | train_loaders.append(train_loader) 88 | 89 | return train_loaders 90 | 91 | def accumulate_weights_weighted(self, weight_accumulator, local_updates, genuine_scores): 92 | gs_sum = sum(genuine_scores) 93 | for user_id, local_update in local_updates.items(): 94 | for name, value in local_update.items(): 95 | weight_accumulator[name].add_(value * (genuine_scores[user_id] / (gs_sum + 1e-9)) * self.params.fl_total_participants) 96 | 97 | 98 | class ClientDataset(Dataset): 99 | def __init__(self, data_list, label_list): 100 | super().__init__() 101 | self.data_list = data_list 102 | self.label_list = label_list 103 | 104 | def __len__(self): 105 | return len(self.label_list) 106 | 107 | def __getitem__(self, index): 108 | return self.data_list[index], self.label_list[index] 109 | -------------------------------------------------------------------------------- /tasks/fl/cifar10_ours_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | import torch 4 | from typing import List, Any, Dict 5 | import numpy as np 6 | from torch.utils.data.dataloader import DataLoader, Dataset 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | from torch.utils.data import random_split 9 | from torchvision.utils import save_image 10 | 11 | from tasks.fl.fl_user_ours import FLUserOurs 12 | from tasks.cifar10_task import Cifar10Task 13 | from tasks.fl.fl_task import FederatedLearningTask 14 | 15 | 16 | class Cifar10_OursTask(FederatedLearningTask, Cifar10Task): 17 | 18 | def load_data(self) -> None: 19 | self.load_cifar_data() 20 | train_loaders, test_loaders = self.assign_data(bias=self.params.fl_q) 21 | self.fl_train_loaders = train_loaders 22 | self.fl_test_loaders = test_loaders 23 | return 24 | 25 | def assign_data(self, bias=1, p=0.1): 
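    # Annotation: `bias` controls how non-IID the client shards are. A sample
    # with label y lands in the width-`bias` interval [lower_bound,
    # upper_bound] with probability `bias` and is then kept in label group y;
    # the remaining mass (1 - bias) is spread uniformly over the other
    # num_labels - 1 groups, each of width other_group_size.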
26 | num_labels = len(self.classes) 27 | num_workers = self.params.fl_total_participants 28 | server_pc = 0 29 | 30 | # assign data to the clients 31 | other_group_size = (1 - bias) / (num_labels - 1) 32 | worker_per_group = num_workers / num_labels 33 | 34 | #assign training data to each worker 35 | each_worker_data = [[] for _ in range(num_workers)] 36 | each_worker_label = [[] for _ in range(num_workers)] 37 | server_data = [] 38 | server_label = [] 39 | 40 | # compute the labels needed for each class 41 | real_dis = [1. / num_labels for _ in range(num_labels)] 42 | samp_dis = [0 for _ in range(num_labels)] 43 | num1 = int(server_pc * p) 44 | samp_dis[1] = num1 45 | average_num = (server_pc - num1) / (num_labels - 1) 46 | resid = average_num - np.floor(average_num) 47 | sum_res = 0. 48 | for other_num in range(num_labels - 1): 49 | if other_num == 1: 50 | continue 51 | samp_dis[other_num] = int(average_num) 52 | sum_res += resid 53 | if sum_res >= 1.0: 54 | samp_dis[other_num] += 1 55 | sum_res -= 1 56 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 57 | 58 | # randomly assign the data points based on the labels 59 | server_counter = [0 for _ in range(num_labels)] 60 | for x, y in self.train_dataset: 61 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 62 | lower_bound = y * (1. - bias) / (num_labels - 1) 63 | rd = np.random.random_sample() 64 | 65 | if rd > upper_bound: 66 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 67 | elif rd < lower_bound: 68 | worker_group = int(np.floor(rd / other_group_size)) 69 | else: 70 | worker_group = y 71 | 72 | if server_counter[y] < samp_dis[y]: 73 | server_data.append(x) 74 | server_label.append(y) 75 | server_counter[y] += 1 76 | else: 77 | rd = np.random.random_sample() 78 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 79 | each_worker_data[selected_worker].append(x) 80 | each_worker_label[selected_worker].append(y) 81 | 82 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 83 | each_worker_data = [each_worker_data[i] for i in random_order] 84 | each_worker_label = [each_worker_label[i] for i in random_order] 85 | 86 | train_loaders, test_loaders = [], [] 87 | for i in range(len(each_worker_data)): 88 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 89 | tot = len(train_set) 90 | train_size = int(tot * self.params.attacker_train_ratio) 91 | test_size = tot - train_size 92 | train_set, test_set = random_split(train_set, 93 | lengths=[train_size, test_size], 94 | generator=torch.Generator().manual_seed(self.params.random_seed)) 95 | 96 | train_loader = DataLoader(train_set, 97 | batch_size=self.params.batch_size, 98 | shuffle=True) 99 | test_loader = DataLoader(test_set, 100 | batch_size=self.params.batch_size, 101 | shuffle=False) 102 | train_loaders.append(train_loader) 103 | test_loaders.append(test_loader) 104 | 105 | return train_loaders, test_loaders 106 | 107 | def accumulate_weights_weighted(self, weight_accumulator, local_updates, genuine_scores): 108 | gs_sum = sum(genuine_scores.values()) 109 | for user_id, local_update in local_updates.items(): 110 | for name, value in local_update.items(): 111 | weight_accumulator[name].add_(value * (genuine_scores[user_id] / (gs_sum + 1e-9)) * self.params.fl_total_participants) 112 | 113 | @torch.no_grad() 114 | def compute_genuine_score(self, model, dataloader, synthesizer): 115 | model.eval() 116 | correct = 0 117 | total = 
0 118 | for i, data in enumerate(dataloader): 119 | batch = self.get_batch(i, data) 120 | batch = synthesizer.make_backdoor_batch(batch, test=True, attack=True) 121 | outputs = model(batch.inputs) 122 | 123 | pred_class_idx = torch.argmax(outputs, dim=1) 124 | correct += pred_class_idx[pred_class_idx==batch.labels].shape[0] 125 | total += batch.inputs.shape[0] 126 | 127 | return 1 - correct / total 128 | 129 | @torch.no_grad() 130 | def compute_genuine_score_global(self, model, dataloader, triggers, masks, target_cls): 131 | model.eval() 132 | correct = 0 133 | total = 0 134 | for i, data in enumerate(dataloader): 135 | batch = self.get_batch(i, data) 136 | images = batch.inputs 137 | trigger, mask = triggers[target_cls], masks[target_cls] 138 | 139 | triggerh = self.tanh_trigger(trigger) 140 | maskh = self.tanh_mask(mask) 141 | trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 142 | outputs = model(trojan_images) 143 | labels = torch.tensor([target_cls] * batch.inputs.size(0)).to(self.params.device) 144 | 145 | pred_class_idx = torch.argmax(outputs, dim=1) 146 | correct += pred_class_idx[pred_class_idx==labels].shape[0] 147 | total += batch.inputs.shape[0] 148 | 149 | return 1 - correct / total 150 | 151 | def tanh_mask(self, vector): 152 | return torch.tanh(vector) / 2 + 0.5 153 | 154 | def tanh_trigger(self, vector): 155 | if len(vector.size()) < 3: 156 | vector = vector[None, :, :] 157 | vector = vector.repeat(3, 1, 1) 158 | mean_tens = torch.tensor([0.4914, 0.4822, 0.4465]).to(self.params.device) 159 | var_tens = torch.tensor([0.2023, 0.1994, 0.2010]).to(self.params.device) 160 | mean_tens, var_tens = mean_tens[:, None, None], var_tens[:, None, None] 161 | 162 | return (torch.tanh(vector) - mean_tens) / var_tens 163 | 164 | def reverse_engineer_per_class(self, model, target_label, dataloader): 165 | model.eval() 166 | width, height = self.params.input_shape[1], self.params.input_shape[2] 167 | trigger = torch.randn((1, width, height)) 168 | trigger = trigger.to(self.params.device).detach().requires_grad_(True) 169 | mask = torch.zeros((width, height)) 170 | mask = mask.to(self.params.device).detach().requires_grad_(True) 171 | 172 | criterion = torch.nn.CrossEntropyLoss() 173 | optimizer = torch.optim.Adam([{"params": trigger}, {"params": mask}], lr=0.005) 174 | 175 | min_norm = np.inf 176 | min_norm_count = 0 177 | for epoch in range(self.params.nc_steps): 178 | norm = 0.0 179 | for i, data in enumerate(dataloader): 180 | batch = self.get_batch(i, data) 181 | optimizer.zero_grad() 182 | images = batch.inputs 183 | 184 | triggerh = self.tanh_trigger(trigger) 185 | maskh = self.tanh_mask(mask) 186 | trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 187 | y_pred = model(trojan_images) 188 | y_target = torch.full((y_pred.size(0),), target_label, dtype=torch.long).to(self.params.device) 189 | loss = criterion(y_pred, y_target) + 0.01 * torch.sum(maskh) 190 | loss.backward() 191 | optimizer.step() 192 | 193 | with torch.no_grad(): 194 | norm = torch.sum(maskh) 195 | 196 | # early stopping 197 | if norm < min_norm: 198 | min_norm = norm 199 | min_norm_count = 0 200 | else: min_norm_count += 1 201 | if min_norm_count > 30: break 202 | 203 | return trigger, mask 204 | 205 | def reverse_engineer_trigger(self, model, dataloader): 206 | triggers, masks, norm_list = [], [], [] 207 | for cls in range(len(self.classes)): 208 | trigger, mask = self.reverse_engineer_per_class(model, cls, 
dataloader) 209 | triggers.append(trigger) 210 | masks.append(mask) 211 | norm_list.append(torch.sum(self.tanh_mask(mask)).item()) 212 | 213 | # visualize for debugging 214 | # batch = self.get_batch(0, next(iter(dataloader))) 215 | # images = batch.inputs 216 | 217 | # triggerh = self.tanh_trigger(trigger) 218 | # maskh = self.tanh_mask(mask) 219 | # trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 220 | 221 | # save_image(images, 'runs/images_{}.png'.format(cls)) 222 | # save_image(triggerh, 'runs/trigger_{}.png'.format(cls)) 223 | # save_image(maskh, 'runs/mask_{}.png'.format(cls)) 224 | # save_image(trojan_images, 'runs/trojan_images_{}.png'.format(cls)) 225 | 226 | return triggers, masks, norm_list 227 | 228 | def sample_users_for_round(self, epoch) -> List[FLUserOurs]: 229 | sampled_ids = random.sample( 230 | range(self.params.fl_total_participants), 231 | self.params.fl_no_models) 232 | # sampled_ids = range(self.params.fl_total_participants) 233 | sampled_users = [] 234 | for pos, user_id in enumerate(sampled_ids): 235 | train_loader = self.fl_train_loaders[user_id] 236 | test_loader = self.fl_test_loaders[user_id] 237 | compromised = self.check_user_compromised(epoch, pos, user_id) 238 | user = FLUserOurs(user_id, compromised=compromised, 239 | train_loader=train_loader, test_loader=test_loader) 240 | sampled_users.append(user) 241 | 242 | return sampled_users 243 | 244 | 245 | class ClientDataset(Dataset): 246 | def __init__(self, data_list, label_list): 247 | super().__init__() 248 | self.data_list = data_list 249 | self.label_list = label_list 250 | 251 | def __len__(self): 252 | return len(self.label_list) 253 | 254 | def __getitem__(self, index): 255 | return self.data_list[index], self.label_list[index] 256 | 257 | 258 | -------------------------------------------------------------------------------- /tasks/fl/fl_task.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from copy import deepcopy 4 | from typing import List, Any, Dict 5 | 6 | from metrics.accuracy_metric import AccuracyMetric 7 | from metrics.test_loss_metric import TestLossMetric 8 | from tasks.fl.fl_user import FLUser 9 | import torch 10 | import logging 11 | from torch.nn import Module 12 | 13 | from tasks.task import Task 14 | logger = logging.getLogger('logger') 15 | 16 | 17 | class FederatedLearningTask(Task): 18 | fl_train_loaders: List[Any] = None 19 | fl_test_loaders: List[Any] = None 20 | ignored_weights = ['num_batches_tracked']#['tracked', 'running'] 21 | adversaries: List[int] = None 22 | 23 | def init_task(self): 24 | self.load_data() 25 | self.model = self.build_model() 26 | self.resume_model() 27 | self.model = self.model.to(self.params.device) 28 | 29 | self.local_model = self.build_model().to(self.params.device) 30 | self.criterion = self.make_criterion() 31 | self.adversaries = self.sample_adversaries() 32 | 33 | self.metrics = [AccuracyMetric(), TestLossMetric(self.criterion)] 34 | self.set_input_shape() 35 | return 36 | 37 | def get_empty_accumulator(self): 38 | weight_accumulator = dict() 39 | for name, data in self.model.state_dict().items(): 40 | weight_accumulator[name] = torch.zeros_like(data) 41 | return weight_accumulator 42 | 43 | def sample_users_for_round(self, epoch) -> List[FLUser]: 44 | sampled_ids = random.sample( 45 | range(self.params.fl_total_participants), 46 | self.params.fl_no_models) 47 | sampled_users = [] 48 | for pos, user_id in 
enumerate(sampled_ids): 49 | train_loader = self.fl_train_loaders[user_id] 50 | compromised = self.check_user_compromised(epoch, pos, user_id) 51 | user = FLUser(user_id, compromised=compromised, 52 | train_loader=train_loader) 53 | sampled_users.append(user) 54 | 55 | return sampled_users 56 | 57 | def check_user_compromised(self, epoch, pos, user_id): 58 | """Check if the sampled user is compromised for the attack. 59 | 60 | If single_epoch_attack is defined (eg not None) then ignore 61 | :param epoch: 62 | :param pos: 63 | :param user_id: 64 | :return: 65 | """ 66 | compromised = False 67 | if self.params.fl_single_epoch_attack is not None: 68 | if epoch == self.params.fl_single_epoch_attack: 69 | if pos < self.params.fl_number_of_adversaries: 70 | compromised = True 71 | logger.warning(f'Attacking once at epoch {epoch}. Compromised' 72 | f' user: {user_id}.') 73 | else: 74 | compromised = user_id in self.adversaries 75 | return compromised 76 | 77 | def sample_adversaries(self) -> List[int]: 78 | adversaries_ids = [] 79 | if self.params.fl_number_of_adversaries == 0: 80 | logger.warning(f'Running vanilla FL, no attack.') 81 | elif self.params.fl_single_epoch_attack is None: 82 | adversaries_ids = random.sample( 83 | range(self.params.fl_total_participants), 84 | self.params.fl_number_of_adversaries) 85 | logger.warning(f'Attacking over multiple epochs with following ' 86 | f'users compromised: {adversaries_ids}.') 87 | else: 88 | logger.warning(f'Attack only on epoch: ' 89 | f'{self.params.fl_single_epoch_attack} with ' 90 | f'{self.params.fl_number_of_adversaries} compromised' 91 | f' users.') 92 | 93 | return adversaries_ids 94 | 95 | def get_model_optimizer(self, model): 96 | local_model = deepcopy(model) 97 | local_model = local_model.to(self.params.device) 98 | 99 | optimizer = self.make_optimizer(local_model) 100 | 101 | return local_model, optimizer 102 | 103 | def copy_params(self, global_model, local_model): 104 | local_state = local_model.state_dict() 105 | for name, param in global_model.state_dict().items(): 106 | if name in local_state and name not in self.ignored_weights: 107 | local_state[name].copy_(param) 108 | 109 | def get_fl_update(self, local_model, global_model) -> Dict[str, torch.Tensor]: 110 | local_update = dict() 111 | for name, data in local_model.state_dict().items(): 112 | if self.check_ignored_weights(name): 113 | continue 114 | local_update[name] = (data - global_model.state_dict()[name]) 115 | 116 | return local_update 117 | 118 | def accumulate_weights(self, weight_accumulator, local_update): 119 | update_norm = self.get_update_norm(local_update) 120 | for name, value in local_update.items(): 121 | self.dp_clip(value, update_norm) 122 | weight_accumulator[name].add_(value) 123 | 124 | def update_global_model(self, weight_accumulator, global_model: Module): 125 | for name, sum_update in weight_accumulator.items(): 126 | if self.check_ignored_weights(name): 127 | continue 128 | scale = self.params.fl_eta / self.params.fl_total_participants 129 | average_update = scale * sum_update 130 | self.dp_add_noise(average_update) 131 | model_weight = global_model.state_dict()[name] 132 | model_weight.add_(average_update) 133 | 134 | def dp_clip(self, local_update_tensor: torch.Tensor, update_norm): 135 | if self.params.fl_dp_clip is not None and \ 136 | update_norm > self.params.fl_dp_clip: 137 | norm_scale = self.params.fl_dp_clip / update_norm 138 | local_update_tensor.mul_(norm_scale) 139 | 140 | def dp_add_noise(self, sum_update_tensor: torch.Tensor): 141 | if 
self.params.fl_dp_noise is not None: 142 | noised_layer = torch.FloatTensor(sum_update_tensor.shape) 143 | noised_layer = noised_layer.to(self.params.device) 144 | noised_layer.normal_(mean=0, std=self.params.fl_dp_noise) 145 | sum_update_tensor.add_(noised_layer) 146 | 147 | def get_update_norm(self, local_update): 148 | squared_sum = 0 149 | for name, value in local_update.items(): 150 | if self.check_ignored_weights(name): 151 | continue 152 | squared_sum += torch.sum(torch.pow(value, 2)).item() 153 | update_norm = math.sqrt(squared_sum) 154 | return update_norm 155 | 156 | def check_ignored_weights(self, name) -> bool: 157 | for ignored in self.ignored_weights: 158 | if ignored in name: 159 | return True 160 | 161 | return False 162 | -------------------------------------------------------------------------------- /tasks/fl/fl_user.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data.dataloader import DataLoader 3 | 4 | 5 | 6 | @dataclass 7 | class FLUser: 8 | user_id: int = 0 9 | compromised: bool = False 10 | train_loader: DataLoader = None 11 | -------------------------------------------------------------------------------- /tasks/fl/fl_user_ours.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data.dataloader import DataLoader 3 | 4 | 5 | 6 | @dataclass 7 | class FLUserOurs: 8 | user_id: int = 0 9 | compromised: bool = False 10 | train_loader: DataLoader = None 11 | test_loader: DataLoader = None 12 | -------------------------------------------------------------------------------- /tasks/fl/mnist_fedavg_task.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import random 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | from torch.utils.data.dataloader import DataLoader, Dataset 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | 9 | from tasks.mnist_task import MNISTTask 10 | from tasks.fl.fl_task import FederatedLearningTask 11 | 12 | 13 | class MNIST_FedAvgTask(FederatedLearningTask, MNISTTask): 14 | 15 | def load_data(self) -> None: 16 | self.load_mnist_data() 17 | train_loaders = self.assign_data(bias=self.params.fl_q) 18 | self.fl_train_loaders = train_loaders 19 | return 20 | 21 | def assign_data(self, bias=1, p=0.1): 22 | num_labels = len(self.classes) 23 | num_workers = self.params.fl_total_participants 24 | server_pc = 0 25 | 26 | # assign data to the clients 27 | other_group_size = (1 - bias) / (num_labels - 1) 28 | worker_per_group = num_workers / num_labels 29 | 30 | #assign training data to each worker 31 | each_worker_data = [[] for _ in range(num_workers)] 32 | each_worker_label = [[] for _ in range(num_workers)] 33 | server_data = [] 34 | server_label = [] 35 | 36 | # compute the labels needed for each class 37 | real_dis = [1. / num_labels for _ in range(num_labels)] 38 | samp_dis = [0 for _ in range(num_labels)] 39 | num1 = int(server_pc * p) 40 | samp_dis[1] = num1 41 | average_num = (server_pc - num1) / (num_labels - 1) 42 | resid = average_num - np.floor(average_num) 43 | sum_res = 0. 
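    # Annotation: identical residual-rounding logic to the CIFAR-10 tasks
    # above (see the allocation note in tasks/fl/cifar10_fltrust_task.py).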
44 | for other_num in range(num_labels - 1): 45 | if other_num == 1: 46 | continue 47 | samp_dis[other_num] = int(average_num) 48 | sum_res += resid 49 | if sum_res >= 1.0: 50 | samp_dis[other_num] += 1 51 | sum_res -= 1 52 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 53 | 54 | # randomly assign the data points based on the labels 55 | server_counter = [0 for _ in range(num_labels)] 56 | for x, y in self.train_dataset: 57 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 58 | lower_bound = y * (1. - bias) / (num_labels - 1) 59 | rd = np.random.random_sample() 60 | 61 | if rd > upper_bound: 62 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 63 | elif rd < lower_bound: 64 | worker_group = int(np.floor(rd / other_group_size)) 65 | else: 66 | worker_group = y 67 | 68 | if server_counter[y] < samp_dis[y]: 69 | server_data.append(x) 70 | server_label.append(y) 71 | server_counter[y] += 1 72 | else: 73 | rd = np.random.random_sample() 74 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 75 | each_worker_data[selected_worker].append(x) 76 | each_worker_label[selected_worker].append(y) 77 | 78 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 79 | each_worker_data = [each_worker_data[i] for i in random_order] 80 | each_worker_label = [each_worker_label[i] for i in random_order] 81 | 82 | train_loaders = [] 83 | for i in range(len(each_worker_data)): 84 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 85 | train_loader = DataLoader(train_set, 86 | batch_size=self.params.batch_size, 87 | shuffle=True) 88 | train_loaders.append(train_loader) 89 | 90 | return train_loaders 91 | 92 | 93 | class ClientDataset(Dataset): 94 | def __init__(self, data_list, label_list): 95 | super().__init__() 96 | self.data_list = data_list 97 | self.label_list = label_list 98 | 99 | def __len__(self): 100 | return len(self.label_list) 101 | 102 | def __getitem__(self, index): 103 | return self.data_list[index], self.label_list[index] 104 | -------------------------------------------------------------------------------- /tasks/fl/mnist_fltrust_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from torch.utils.data.dataloader import DataLoader, Dataset 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | 8 | from tasks.mnist_task import MNISTTask 9 | from tasks.fl.fl_task import FederatedLearningTask 10 | 11 | 12 | class MNIST_FLTrustTask(FederatedLearningTask, MNISTTask): 13 | 14 | def load_data(self) -> None: 15 | self.load_mnist_data() 16 | train_loaders = self.assign_data(bias=self.params.fl_q) 17 | self.fl_train_loaders = train_loaders 18 | return 19 | 20 | def assign_data(self, bias=1, p=0.1): 21 | num_labels = len(self.classes) 22 | num_workers = self.params.fl_total_participants 23 | server_pc = 0 24 | 25 | # assign data to the clients 26 | other_group_size = (1 - bias) / (num_labels - 1) 27 | worker_per_group = num_workers / num_labels 28 | 29 | #assign training data to each worker 30 | each_worker_data = [[] for _ in range(num_workers)] 31 | each_worker_label = [[] for _ in range(num_workers)] 32 | server_data = [] 33 | server_label = [] 34 | 35 | # compute the labels needed for each class 36 | real_dis = [1. 
/ num_labels for _ in range(num_labels)] 37 | samp_dis = [0 for _ in range(num_labels)] 38 | num1 = int(server_pc * p) 39 | samp_dis[1] = num1 40 | average_num = (server_pc - num1) / (num_labels - 1) 41 | resid = average_num - np.floor(average_num) 42 | sum_res = 0. 43 | for other_num in range(num_labels - 1): 44 | if other_num == 1: 45 | continue 46 | samp_dis[other_num] = int(average_num) 47 | sum_res += resid 48 | if sum_res >= 1.0: 49 | samp_dis[other_num] += 1 50 | sum_res -= 1 51 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 52 | 53 | # randomly assign the data points based on the labels 54 | server_counter = [0 for _ in range(num_labels)] 55 | for x, y in self.train_dataset: 56 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 57 | lower_bound = y * (1. - bias) / (num_labels - 1) 58 | rd = np.random.random_sample() 59 | 60 | if rd > upper_bound: 61 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 62 | elif rd < lower_bound: 63 | worker_group = int(np.floor(rd / other_group_size)) 64 | else: 65 | worker_group = y 66 | 67 | if server_counter[y] < samp_dis[y]: 68 | server_data.append(x) 69 | server_label.append(y) 70 | server_counter[y] += 1 71 | else: 72 | rd = np.random.random_sample() 73 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 74 | each_worker_data[selected_worker].append(x) 75 | each_worker_label[selected_worker].append(y) 76 | 77 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 78 | each_worker_data = [each_worker_data[i] for i in random_order] 79 | each_worker_label = [each_worker_label[i] for i in random_order] 80 | 81 | train_loaders = [] 82 | for i in range(len(each_worker_data)): 83 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 84 | train_loader = DataLoader(train_set, 85 | batch_size=self.params.batch_size, 86 | shuffle=True) 87 | train_loaders.append(train_loader) 88 | 89 | return train_loaders 90 | 91 | def accumulate_weights_weighted(self, weight_accumulator, local_updates, genuine_scores): 92 | gs_sum = sum(genuine_scores.values()) 93 | for user_id, local_update in local_updates.items(): 94 | for name, value in local_update.items(): 95 | weight_accumulator[name].add_(value * (genuine_scores[user_id] / (gs_sum + 1e-9)) * self.params.fl_total_participants) 96 | 97 | 98 | class ClientDataset(Dataset): 99 | def __init__(self, data_list, label_list): 100 | super().__init__() 101 | self.data_list = data_list 102 | self.label_list = label_list 103 | 104 | def __len__(self): 105 | return len(self.label_list) 106 | 107 | def __getitem__(self, index): 108 | return self.data_list[index], self.label_list[index] 109 | -------------------------------------------------------------------------------- /tasks/fl/mnist_ours_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from typing import List, Any, Dict 4 | from tasks.fl.fl_user_ours import FLUserOurs 5 | 6 | import numpy as np 7 | from torch.utils.data.dataloader import DataLoader, Dataset 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | from torch.utils.data import random_split 10 | import torch 11 | from torchvision.utils import save_image 12 | 13 | from tasks.mnist_task import MNISTTask 14 | from tasks.fl.fl_task import FederatedLearningTask 15 | 16 | 17 | class MNIST_OursTask(FederatedLearningTask, MNISTTask): 18 | 19 | def 
load_data(self) -> None: 20 | self.load_mnist_data() 21 | train_loaders, test_loaders = self.assign_data(bias=self.params.fl_q) 22 | self.fl_train_loaders = train_loaders 23 | self.fl_test_loaders = test_loaders 24 | return 25 | 26 | def assign_data(self, bias=1, p=0.1): 27 | num_labels = len(self.classes) 28 | num_workers = self.params.fl_total_participants 29 | server_pc = 0 30 | 31 | # assign data to the clients 32 | other_group_size = (1 - bias) / (num_labels - 1) 33 | worker_per_group = num_workers / num_labels 34 | 35 | #assign training data to each worker 36 | each_worker_data = [[] for _ in range(num_workers)] 37 | each_worker_label = [[] for _ in range(num_workers)] 38 | server_data = [] 39 | server_label = [] 40 | 41 | # compute the labels needed for each class 42 | real_dis = [1. / num_labels for _ in range(num_labels)] 43 | samp_dis = [0 for _ in range(num_labels)] 44 | num1 = int(server_pc * p) 45 | samp_dis[1] = num1 46 | average_num = (server_pc - num1) / (num_labels - 1) 47 | resid = average_num - np.floor(average_num) 48 | sum_res = 0. 49 | for other_num in range(num_labels - 1): 50 | if other_num == 1: 51 | continue 52 | samp_dis[other_num] = int(average_num) 53 | sum_res += resid 54 | if sum_res >= 1.0: 55 | samp_dis[other_num] += 1 56 | sum_res -= 1 57 | samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1]) 58 | 59 | # randomly assign the data points based on the labels 60 | server_counter = [0 for _ in range(num_labels)] 61 | for x, y in self.train_dataset: 62 | upper_bound = y * (1. - bias) / (num_labels - 1) + bias 63 | lower_bound = y * (1. - bias) / (num_labels - 1) 64 | rd = np.random.random_sample() 65 | 66 | if rd > upper_bound: 67 | worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1) 68 | elif rd < lower_bound: 69 | worker_group = int(np.floor(rd / other_group_size)) 70 | else: 71 | worker_group = y 72 | 73 | if server_counter[y] < samp_dis[y]: 74 | server_data.append(x) 75 | server_label.append(y) 76 | server_counter[y] += 1 77 | else: 78 | rd = np.random.random_sample() 79 | selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group))) 80 | each_worker_data[selected_worker].append(x) 81 | each_worker_label[selected_worker].append(y) 82 | 83 | random_order = np.random.RandomState(seed=self.params.random_seed).permutation(num_workers) 84 | each_worker_data = [each_worker_data[i] for i in random_order] 85 | each_worker_label = [each_worker_label[i] for i in random_order] 86 | 87 | train_loaders, test_loaders = [], [] 88 | for i in range(len(each_worker_data)): 89 | train_set = ClientDataset(each_worker_data[i], each_worker_label[i]) 90 | tot = len(train_set) 91 | train_size = int(tot * self.params.attacker_train_ratio) 92 | test_size = tot - train_size 93 | train_set, test_set = random_split(train_set, 94 | lengths=[train_size, test_size], 95 | generator=torch.Generator().manual_seed(self.params.random_seed)) 96 | 97 | train_loader = DataLoader(train_set, 98 | batch_size=self.params.batch_size, 99 | shuffle=True) 100 | test_loader = DataLoader(test_set, 101 | batch_size=self.params.batch_size, 102 | shuffle=False) 103 | train_loaders.append(train_loader) 104 | test_loaders.append(test_loader) 105 | 106 | return train_loaders, test_loaders 107 | 108 | def accumulate_weights_weighted(self, weight_accumulator, local_updates, genuine_scores): 109 | gs_sum = sum(genuine_scores.values()) 110 | for user_id, local_update in local_updates.items(): 111 | for name, value in local_update.items(): 112 
| weight_accumulator[name].add_(value * (genuine_scores[user_id] / (gs_sum + 1e-9)) * self.params.fl_total_participants) 113 | 114 | @torch.no_grad() 115 | def compute_genuine_score(self, model, dataloader, synthesizer): 116 | model.eval() 117 | correct = 0 118 | total = 0 119 | for i, data in enumerate(dataloader): 120 | batch = self.get_batch(i, data) 121 | batch = synthesizer.make_backdoor_batch(batch, test=True, attack=True) 122 | outputs = model(batch.inputs) 123 | 124 | pred_class_idx = torch.argmax(outputs, dim=1) 125 | correct += pred_class_idx[pred_class_idx==batch.labels].shape[0] 126 | total += batch.inputs.shape[0] 127 | 128 | return 1 - correct / total 129 | 130 | @torch.no_grad() 131 | def compute_genuine_score_global(self, model, dataloader, triggers, masks, target_cls): 132 | model.eval() 133 | correct = 0 134 | total = 0 135 | for i, data in enumerate(dataloader): 136 | batch = self.get_batch(i, data) 137 | images = batch.inputs 138 | trigger, mask = triggers[target_cls], masks[target_cls] 139 | 140 | triggerh = self.tanh_trigger(trigger) 141 | maskh = self.tanh_mask(mask) 142 | trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 143 | outputs = model(trojan_images) 144 | labels = torch.tensor([target_cls] * batch.inputs.size(0)).to(self.params.device) 145 | 146 | pred_class_idx = torch.argmax(outputs, dim=1) 147 | correct += pred_class_idx[pred_class_idx==labels].shape[0] 148 | total += batch.inputs.shape[0] 149 | 150 | return 1 - correct / total 151 | 152 | def tanh_mask(self, vector): 153 | return torch.tanh(vector) / 2 + 0.5 154 | 155 | def tanh_trigger(self, vector): 156 | return (torch.tanh(vector) - 0.1307) / 0.3081 157 | 158 | def reverse_engineer_per_class(self, model, target_label, dataloader): 159 | model.eval() 160 | width, height = self.params.input_shape[1], self.params.input_shape[2] 161 | trigger = torch.randn((1, width, height)) 162 | trigger = trigger.to(self.params.device).detach().requires_grad_(True) 163 | mask = torch.zeros((width, height)) 164 | mask = mask.to(self.params.device).detach().requires_grad_(True) 165 | 166 | criterion = torch.nn.CrossEntropyLoss() 167 | optimizer = torch.optim.Adam([{"params": trigger}, {"params": mask}], lr=0.005) 168 | 169 | min_norm = np.inf 170 | min_norm_count = 0 171 | for epoch in range(self.params.nc_steps): 172 | norm = 0.0 173 | for i, data in enumerate(dataloader): 174 | batch = self.get_batch(i, data) 175 | optimizer.zero_grad() 176 | images = batch.inputs 177 | 178 | triggerh = self.tanh_trigger(trigger) 179 | maskh = self.tanh_mask(mask) 180 | trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 181 | y_pred = model(trojan_images) 182 | y_target = torch.full((y_pred.size(0),), target_label, dtype=torch.long).to(self.params.device) 183 | loss = criterion(y_pred, y_target) + 0.01 * torch.sum(maskh) 184 | loss.backward() 185 | optimizer.step() 186 | 187 | with torch.no_grad(): 188 | norm = torch.sum(maskh) 189 | 190 | # early stopping 191 | if norm < min_norm: 192 | min_norm = norm 193 | min_norm_count = 0 194 | else: min_norm_count += 1 195 | if min_norm_count > 30: break 196 | 197 | return trigger, mask 198 | 199 | def reverse_engineer_trigger(self, model, dataloader): 200 | triggers, masks, norm_list = [], [], [] 201 | for cls in range(len(self.classes)): 202 | trigger, mask = self.reverse_engineer_per_class(model, cls, dataloader) 203 | triggers.append(trigger) 204 | masks.append(mask) 205 | 
norm_list.append(torch.sum(self.tanh_mask(mask)).item()) 206 | 207 | # visualize for debugging 208 | # batch = self.get_batch(0, next(iter(dataloader))) 209 | # images = batch.inputs 210 | 211 | # triggerh = self.tanh_trigger(trigger) 212 | # maskh = self.tanh_mask(mask) 213 | # trojan_images = (1 - torch.unsqueeze(maskh, dim=0)) * images + torch.unsqueeze(maskh, dim=0) * triggerh 214 | 215 | # save_image(images, 'runs/images_{}.png'.format(cls)) 216 | # save_image(triggerh, 'runs/trigger_{}.png'.format(cls)) 217 | # save_image(maskh, 'runs/mask_{}.png'.format(cls)) 218 | # save_image(trojan_images, 'runs/trojan_images_{}.png'.format(cls)) 219 | 220 | return triggers, masks, norm_list 221 | 222 | def sample_users_for_round(self, epoch) -> List[FLUserOurs]: 223 | sampled_ids = random.sample( 224 | range(self.params.fl_total_participants), 225 | self.params.fl_no_models) 226 | # sampled_ids = range(self.params.fl_total_participants) 227 | sampled_users = [] 228 | for pos, user_id in enumerate(sampled_ids): 229 | train_loader = self.fl_train_loaders[user_id] 230 | test_loader = self.fl_test_loaders[user_id] 231 | compromised = self.check_user_compromised(epoch, pos, user_id) 232 | user = FLUserOurs(user_id, compromised=compromised, 233 | train_loader=train_loader, test_loader=test_loader) 234 | sampled_users.append(user) 235 | 236 | return sampled_users 237 | 238 | 239 | class ClientDataset(Dataset): 240 | def __init__(self, data_list, label_list): 241 | super().__init__() 242 | self.data_list = data_list 243 | self.label_list = label_list 244 | 245 | def __len__(self): 246 | return len(self.label_list) 247 | 248 | def __getitem__(self, index): 249 | return self.data_list[index], self.label_list[index] -------------------------------------------------------------------------------- /tasks/imagenet_task.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from torchvision.transforms import transforms 5 | 6 | from models.resnet import resnet18 7 | from tasks.task import Task 8 | 9 | 10 | class ImagenetTask(Task): 11 | 12 | def load_data(self): 13 | 14 | train_transform = transforms.Compose([ 15 | transforms.RandomResizedCrop(224), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ToTensor(), 18 | self.normalize, 19 | ]) 20 | test_transform = transforms.Compose([ 21 | transforms.Resize(256), 22 | transforms.CenterCrop(224), 23 | transforms.ToTensor(), 24 | self.normalize, 25 | ]) 26 | 27 | self.train_dataset = torchvision.datasets.ImageNet( 28 | root=self.params.data_path, 29 | split='train', transform=train_transform) 30 | 31 | self.test_dataset = torchvision.datasets.ImageNet( 32 | root=self.params.data_path, 33 | split='val', transform=test_transform) 34 | 35 | self.train_loader = DataLoader(self.train_dataset, 36 | batch_size=self.params.batch_size, 37 | shuffle=True, num_workers=8, pin_memory=True) 38 | self.test_loader = DataLoader(self.test_dataset, 39 | batch_size=self.params.test_batch_size, 40 | shuffle=False, num_workers=8, pin_memory=True) 41 | 42 | with open( 43 | f'{self.params.data_path}/imagenet1000_clsidx_to_labels.txt') \ 44 | as f: 45 | self.classes = eval(f.read()) 46 | 47 | def build_model(self) -> None: 48 | return resnet18(pretrained=self.params.pretrained) 49 | -------------------------------------------------------------------------------- /tasks/imdb_helper.py: 
-------------------------------------------------------------------------------- 1 | from data_helpers.task_helper import TaskHelper 2 | from transformers import BertTokenizer 3 | from torchtext import data, datasets 4 | import dill 5 | import torch 6 | import random 7 | 8 | class IMDBHelper(TaskHelper): 9 | 10 | def load_data(self): 11 | 12 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 13 | 14 | init_token_idx = tokenizer.cls_token_id 15 | eos_token_idx = tokenizer.sep_token_id 16 | pad_token_idx = tokenizer.pad_token_id 17 | unk_token_idx = tokenizer.unk_token_id 18 | 19 | def tokenize_and_cut(sentence): 20 | tokens = tokenizer.tokenize(sentence) 21 | tokens = tokens[:max_input_length - 2] 22 | return tokens 23 | 24 | text = data.Field(batch_first=True, 25 | use_vocab=False, 26 | tokenize=tokenize_and_cut, 27 | preprocessing=tokenizer.convert_tokens_to_ids, 28 | init_token=init_token_idx, 29 | eos_token=eos_token_idx, 30 | pad_token=pad_token_idx, 31 | unk_token=unk_token_idx) 32 | 33 | label = data.LabelField(dtype=torch.float) 34 | 35 | max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased'] 36 | 37 | self.train_dataset = datasets.imdb.IMDB('.data', text, label) 38 | self.test_dataset = datasets.imdb.IMDB('.data', text, label) 39 | with open(f'{self.params.data_path}/train_data.dill', 'rb') as f: 40 | self.train_dataset.examples = dill.load(f) 41 | with open(f'{self.params.data_path}/test_data.dill', 'rb') as f: 42 | self.test_dataset.examples = dill.load(f) 43 | random.seed(5) 44 | self.test_dataset.examples = random.sample(self.test_dataset.examples, 45 | 5000) 46 | label.build_vocab(self.train_dataset) 47 | self.train_loader, self.test_loader = data.BucketIterator.splits( 48 | (self.train_dataset, self.test_dataset), 49 | batch_size=self.params.batch_size, 50 | device=self.params.device) -------------------------------------------------------------------------------- /tasks/mnist_task.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as torch_data 2 | import torchvision 3 | import torch 4 | from torchvision.transforms import transforms 5 | 6 | from models.simple import SimpleNet 7 | from tasks.task import Task 8 | 9 | 10 | class MNISTTask(Task): 11 | normalize = transforms.Normalize((0.1307,), (0.3081,)) 12 | 13 | def load_data(self): 14 | self.load_mnist_data() 15 | 16 | def load_mnist_data(self): 17 | transform_train = transforms.Compose([ 18 | transforms.ToTensor(), 19 | self.normalize 20 | ]) 21 | 22 | transform_test = transforms.Compose([ 23 | transforms.ToTensor(), 24 | self.normalize 25 | ]) 26 | 27 | original_train_dataset = torchvision.datasets.MNIST( 28 | root=self.params.data_path, 29 | train=True, 30 | download=True, 31 | transform=transform_train) 32 | 33 | clean_size = round(self.params.clean_ratio * len(original_train_dataset)) 34 | train_size = len(original_train_dataset) - clean_size 35 | self.train_dataset, self.clean_dataset = torch_data.random_split(original_train_dataset, 36 | lengths=[train_size, clean_size], 37 | generator=torch.Generator().manual_seed(self.params.random_seed)) 38 | if self.params.clean_set_dataset == 'FashionMNIST': 39 | self.clean_dataset = torchvision.datasets.FashionMNIST(root=self.params.data_path, 40 | train=True, 41 | download=True, 42 | transform=transform_train) 43 | tot = len(self.clean_dataset) 44 | self.clean_dataset, _ = torch_data.random_split(self.clean_dataset, 45 | lengths=[clean_size, tot-clean_size], 46 | 
generator=torch.Generator().manual_seed(self.params.random_seed)) 47 | self.clean_loader = torch_data.DataLoader(self.clean_dataset, 48 | batch_size=self.params.batch_size, 49 | shuffle=True, 50 | num_workers=0) 51 | elif self.params.clean_set_dataset is None and self.params.clean_classes is not None: 52 | clean_indices = [] 53 | for cls in self.params.clean_classes: 54 | for i in range(clean_size): 55 | if self.clean_dataset[i][1] == cls: 56 | clean_indices.append(i) 57 | 58 | sampler = torch_data.SubsetRandomSampler(clean_indices, generator=torch.Generator().manual_seed(self.params.random_seed)) 59 | self.clean_loader = torch_data.DataLoader(self.clean_dataset, 60 | batch_size=self.params.batch_size, 61 | sampler=sampler, 62 | num_workers=0) 63 | else: 64 | self.clean_loader = torch_data.DataLoader(self.clean_dataset, 65 | batch_size=self.params.batch_size, 66 | shuffle=True, 67 | num_workers=0) 68 | 69 | self.train_loader = torch_data.DataLoader(self.train_dataset, 70 | batch_size=self.params.batch_size, 71 | shuffle=True, 72 | num_workers=0) 73 | 74 | 75 | 76 | self.test_dataset = torchvision.datasets.MNIST( 77 | root=self.params.data_path, 78 | train=False, 79 | download=True, 80 | transform=transform_test) 81 | self.test_loader = torch_data.DataLoader(self.test_dataset, 82 | batch_size=self.params.test_batch_size, 83 | shuffle=False, 84 | num_workers=0) 85 | self.classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) 86 | return True 87 | 88 | def build_model(self): 89 | return SimpleNet(num_classes=len(self.classes)) 90 | -------------------------------------------------------------------------------- /tasks/multimnist_task.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader 2 | from torchvision.transforms import transforms 3 | 4 | from dataset.multi_mnist_loader import MNIST 5 | from tasks.mnist_task import MNISTTask 6 | 7 | 8 | class MultiMNISTTask(MNISTTask): 9 | 10 | def load_data(self): 11 | transform = transforms.Compose([transforms.ToTensor(), 12 | self.normalize]) 13 | self.train_dataset = MNIST(root=self.params.data_path, train=True, download=True, 14 | transform=transform, 15 | multi=True) 16 | self.train_loader = DataLoader(self.train_dataset, 17 | batch_size=self.params.batch_size, 18 | shuffle=True, 19 | num_workers=4) 20 | self.test_dataset = MNIST(root=self.params.data_path, train=False, download=True, 21 | transform=transform, 22 | multi=True) 23 | self.test_loader = DataLoader(self.test_dataset, 24 | batch_size=self.params.test_batch_size, shuffle=False, 25 | num_workers=4) 26 | self.classes = list(range(100)) 27 | -------------------------------------------------------------------------------- /tasks/pipa_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as torch_data 3 | from torchvision.transforms import transforms 4 | 5 | from dataset.pipa import PipaDataset 6 | from models.resnet import resnet18 7 | from tasks.batch import Batch 8 | from tasks.task import Task 9 | 10 | 11 | class PipaTask(Task): 12 | 13 | def load_data(self): 14 | train_transform = transforms.Compose([ 15 | transforms.RandomResizedCrop(224), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ToTensor(), 18 | self.normalize, 19 | ]) 20 | test_transform = transforms.Compose([ 21 | transforms.Resize(256), 22 | transforms.CenterCrop(224), 23 | transforms.ToTensor(), 24 | self.normalize, 25 | ]) 26 | 27 | self.train_dataset = 
PipaDataset(data_path=self.params.data_path, 28 | train=True, 29 | transform=train_transform) 30 | # poison_weights = [0.99, 0.004, 0.004, 0.004, 0.004] 31 | # weights = [14081, 4893, 1779, 809, 32 | # 862] # [0.62, 0.22, 0.08, 0.04, 0.04] 33 | weights = torch.tensor([0.03, 0.07, 0.2, 0.35, 0.35]) 34 | weights_labels = weights[self.train_dataset.labels] 35 | sampler = torch_data.sampler.WeightedRandomSampler(weights_labels, len( 36 | self.train_dataset)) 37 | self.train_loader = \ 38 | torch_data.DataLoader(self.train_dataset, 39 | batch_size=self.params.batch_size, 40 | sampler=sampler) 41 | 42 | self.test_dataset = PipaDataset(data_path=self.params.data_path, 43 | train=False, transform=test_transform) 44 | self.test_loader = \ 45 | torch_data.DataLoader(self.test_dataset, 46 | batch_size=self.params.test_batch_size, 47 | shuffle=False, num_workers=2) 48 | 49 | self.classes = list(range(5)) 50 | 51 | def build_model(self): 52 | model = resnet18(pretrained=True) 53 | model.fc = torch.nn.Linear(512, 5) 54 | return model 55 | 56 | def get_batch(self, batch_id, data): 57 | inputs, labels, identities, _ = data 58 | batch = Batch(batch_id, inputs, labels, aux=identities) 59 | return batch.to(self.params.device) 60 | -------------------------------------------------------------------------------- /tasks/task.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import torch 5 | from torch import optim, nn 6 | from torch.nn import Module 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | from torchvision.transforms import transforms 10 | 11 | from metrics.accuracy_metric import AccuracyMetric 12 | from metrics.metric import Metric 13 | from metrics.test_loss_metric import TestLossMetric 14 | from tasks.batch import Batch 15 | from utils.parameters import Params 16 | 17 | logger = logging.getLogger('logger') 18 | 19 | 20 | class Task: 21 | params: Params = None 22 | 23 | train_dataset = None 24 | test_dataset = None 25 | train_loader = None 26 | test_loader = None 27 | classes = None 28 | 29 | model: Module = None 30 | optimizer: optim.Optimizer = None 31 | criterion: Module = None 32 | scheduler: CosineAnnealingLR = None 33 | metrics: List[Metric] = None 34 | 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | "Generic normalization for input data." 38 | input_shape: torch.Size = None 39 | 40 | def __init__(self, params: Params): 41 | self.params = params 42 | self.init_task() 43 | 44 | def init_task(self): 45 | self.load_data() 46 | self.model = self.build_model() 47 | self.resume_model() 48 | self.model = self.model.to(self.params.device) 49 | 50 | self.optimizer = self.make_optimizer() 51 | self.criterion = self.make_criterion() 52 | self.metrics = [AccuracyMetric(), TestLossMetric(self.criterion)] 53 | self.set_input_shape() 54 | 55 | def load_data(self) -> None: 56 | raise NotImplementedError 57 | 58 | def build_model(self) -> Module: 59 | raise NotImplementedError 60 | 61 | def make_criterion(self) -> Module: 62 | """Initialize with Cross Entropy by default. 63 | 64 | We use reduction `none` to support gradient shaping defense. 
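(keeping the per-sample losses unreduced lets the training loop clip or re-weight individual examples before averaging)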
65 | :return: 66 | """ 67 | return nn.CrossEntropyLoss(reduction='none') 68 | 69 | def make_optimizer(self, model=None) -> Optimizer: 70 | if model is None: 71 | model = self.model 72 | if self.params.optimizer == 'SGD': 73 | optimizer = optim.SGD(model.parameters(), 74 | lr=self.params.lr, 75 | weight_decay=self.params.decay, 76 | momentum=self.params.momentum) 77 | elif self.params.optimizer == 'Adam': 78 | optimizer = optim.Adam(model.parameters(), 79 | lr=self.params.lr, 80 | weight_decay=self.params.decay) 81 | else: 82 | raise ValueError(f'No optimizer: {self.optimizer}') 83 | 84 | return optimizer 85 | 86 | def make_scheduler(self) -> None: 87 | if self.params.scheduler: 88 | self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.params.epochs) 89 | 90 | def resume_model(self): 91 | if self.params.resume_model: 92 | logger.info(f'Resuming training from {self.params.resume_model}') 93 | loaded_params = torch.load(f"saved_models/" 94 | f"{self.params.resume_model}", 95 | map_location=torch.device('cpu')) 96 | self.model.load_state_dict(loaded_params['state_dict']) 97 | self.params.start_epoch = loaded_params['epoch'] 98 | self.params.lr = loaded_params.get('lr', self.params.lr) 99 | 100 | logger.warning(f"Loaded parameters from saved model: LR is" 101 | f" {self.params.lr} and current epoch is" 102 | f" {self.params.start_epoch}") 103 | 104 | def set_input_shape(self): 105 | inp = self.train_dataset[0][0] 106 | self.params.input_shape = inp.shape 107 | 108 | def get_batch(self, batch_id, data) -> Batch: 109 | """Process data into a batch. 110 | 111 | Specific for different datasets and data loaders this method unifies 112 | the output by returning the object of class Batch. 113 | :param batch_id: id of the batch 114 | :param data: object returned by the Loader. 115 | :return: 116 | """ 117 | inputs, labels = data 118 | batch = Batch(batch_id, inputs, labels) 119 | return batch.to(self.params.device) 120 | 121 | def accumulate_metrics(self, outputs, labels): 122 | for metric in self.metrics: 123 | metric.accumulate_on_batch(outputs, labels) 124 | 125 | def reset_metrics(self): 126 | for metric in self.metrics: 127 | metric.reset_metric() 128 | 129 | def report_metrics(self, step, prefix='', 130 | tb_writer=None, tb_prefix='Metric/'): 131 | metric_text = [] 132 | for metric in self.metrics: 133 | metric_text.append(str(metric)) 134 | metric.plot(tb_writer, step, tb_prefix=tb_prefix) 135 | logger.warning(f'{prefix} {step:4d}. 
{" | ".join(metric_text)}') 136 | 137 | return self.metrics[0].get_main_metric_value() 138 | 139 | @staticmethod 140 | def get_batch_accuracy(outputs, labels, top_k=(1,)): 141 | """Computes the precision@k for the specified values of k""" 142 | max_k = max(top_k) 143 | batch_size = labels.size(0) 144 | 145 | _, pred = outputs.topk(max_k, 1, True, True) 146 | pred = pred.t() 147 | correct = pred.eq(labels.view(1, -1).expand_as(pred)) 148 | 149 | res = [] 150 | for k in top_k: 151 | correct_k = correct[:k].view(-1).float().sum(0) 152 | res.append((correct_k.mul_(100.0 / batch_size)).item()) 153 | if len(res) == 1: 154 | res = res[0] 155 | return res 156 | -------------------------------------------------------------------------------- /tasks/vggface_helper.py: -------------------------------------------------------------------------------- 1 | from data_helpers.task_helper import TaskHelper 2 | 3 | 4 | class VggFaceHelper(TaskHelper): 5 | 6 | def load_data(self): 7 | logger.error('VGG dataset is unfinished, needs more work') 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.4914, 0.4822, 0.4465), 13 | (0.2023, 0.1994, 0.2010)), 14 | ]) 15 | 16 | transform_test = transforms.Compose([ 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.4914, 0.4822, 0.4465), 19 | (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | self.train_dataset = VGG_Faces2( 23 | root=self.params.data_path, 24 | train=True, transform=transform_train) 25 | self.test_dataset = VGG_Faces2( 26 | root=self.params.data_path, 27 | train=False, transform=transform_test) 28 | 29 | return True 30 | 31 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from datetime import datetime 4 | import copy 5 | from threading import Thread, Lock 6 | from collections import defaultdict 7 | 8 | import yaml 9 | from prompt_toolkit import prompt 10 | from tqdm import tqdm 11 | 12 | # noinspection PyUnresolvedReferences 13 | from dataset.pipa import Annotations # legacy to correctly load dataset. 14 | from helper import Helper 15 | from utils.utils import * 16 | 17 | logger = logging.getLogger('logger') 18 | 19 | 20 | def train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=True, ratio=None, report=True): 21 | criterion = hlpr.task.criterion 22 | model.train() 23 | 24 | for i, data in enumerate(train_loader): 25 | batch = hlpr.task.get_batch(i, data) 26 | model.zero_grad() 27 | loss = hlpr.attack.compute_blind_loss(model, criterion, batch, attack, ratio) 28 | loss.backward() 29 | optimizer.step() 30 | 31 | if report: 32 | hlpr.report_training_losses_scales(i, epoch) 33 | if i == hlpr.params.max_batch_id: 34 | break 35 | 36 | return 37 | 38 | 39 | def test(hlpr: Helper, epoch, backdoor=False): 40 | model = hlpr.task.model 41 | model.eval() 42 | hlpr.task.reset_metrics() 43 | 44 | with torch.no_grad(): 45 | for i, data in enumerate(hlpr.task.test_loader): 46 | batch = hlpr.task.get_batch(i, data) 47 | if backdoor: 48 | batch = hlpr.attack.synthesizer.make_backdoor_batch(batch, 49 | test=True, 50 | attack=True) 51 | 52 | outputs = model(batch.inputs) 53 | hlpr.task.accumulate_metrics(outputs=outputs, labels=batch.labels) 54 | metric = hlpr.task.report_metrics(epoch, 55 | prefix=f'Backdoor {str(backdoor):5s}. 
Epoch: ', 56 | tb_writer=hlpr.tb_writer, 57 | tb_prefix=f'Test_backdoor_{str(backdoor):5s}') 58 | 59 | return metric 60 | 61 | 62 | def run(hlpr): 63 | acc = test(hlpr, 0, backdoor=False) 64 | for epoch in range(hlpr.params.start_epoch, 65 | hlpr.params.epochs + 1): 66 | train(hlpr, epoch, hlpr.task.model, hlpr.task.optimizer, 67 | hlpr.task.train_loader) 68 | acc = test(hlpr, epoch, backdoor=False) 69 | test(hlpr, epoch, backdoor=True) 70 | hlpr.save_model(hlpr.task.model, epoch, acc) 71 | if hlpr.task.scheduler is not None: 72 | hlpr.task.scheduler.step(epoch) 73 | 74 | 75 | def fl_run(hlpr: Helper): 76 | for epoch in range(hlpr.params.start_epoch, 77 | hlpr.params.epochs + 1): 78 | if epoch < hlpr.params.attack_start_epoch: 79 | run_fl_round_benign(hlpr, epoch) 80 | elif hlpr.params.ours: 81 | run_fl_round_ours_parallel(hlpr, epoch) 82 | elif hlpr.params.fltrust: 83 | run_fl_round_fltrust(hlpr, epoch) 84 | elif hlpr.params.defense == 'krum' or hlpr.params.defense == 'median': 85 | run_fl_round_byzantine(hlpr, epoch) 86 | else: 87 | run_fl_round(hlpr, epoch) 88 | metric = test(hlpr, epoch, backdoor=False) 89 | test(hlpr, epoch, backdoor=True) 90 | 91 | hlpr.save_model(hlpr.task.model, epoch, metric) 92 | 93 | 94 | def run_fl_round_byzantine(hlpr, epoch): 95 | global_model = hlpr.task.model 96 | local_model = hlpr.task.local_model 97 | 98 | round_participants = hlpr.task.sample_users_for_round(epoch) 99 | 100 | local_updates = [] 101 | for user in round_participants: 102 | hlpr.task.copy_params(global_model, local_model) 103 | optimizer = hlpr.task.make_optimizer(local_model) 104 | for local_epoch in range(hlpr.params.fl_local_epochs): 105 | if user.compromised: 106 | train(hlpr, local_epoch, local_model, optimizer, 107 | user.train_loader, attack=True, report=False) 108 | else: 109 | train(hlpr, local_epoch, local_model, optimizer, 110 | user.train_loader, attack=False, report=False) 111 | local_update = hlpr.task.get_fl_update(local_model, global_model) 112 | if user.compromised: 113 | hlpr.attack.fl_scale_update(local_update) 114 | local_updates.append(local_update) 115 | 116 | local_update_final = globals()[hlpr.params.defense](local_updates, hlpr) 117 | for name, value in local_update_final.items(): 118 | global_model.state_dict()[name].add_(value * hlpr.params.fl_eta) 119 | 120 | 121 | def krum(w, hlpr): 122 | distances = defaultdict(dict) 123 | non_malicious_count = hlpr.params.fl_total_participants - hlpr.params.fl_number_of_adversaries 124 | num = 0 125 | for k in w[0].keys(): 126 | if num == 0: 127 | for i in range(len(w)): 128 | for j in range(i): 129 | distances[i][j] = distances[j][i] = np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy()) 130 | num = 1 131 | else: 132 | for i in range(len(w)): 133 | for j in range(i): 134 | distances[j][i] += np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy()) 135 | distances[i][j] = distances[j][i] 136 | minimal_error = 1e20 137 | for user in distances.keys(): 138 | errors = sorted(distances[user].values()) 139 | current_error = sum(errors[:non_malicious_count]) 140 | if current_error < minimal_error: 141 | minimal_error = current_error 142 | minimal_error_index = user 143 | return w[minimal_error_index] 144 | 145 | 146 | def median(w, hlpr): 147 | number_to_consider = hlpr.params.fl_total_participants 148 | w_avg = copy.deepcopy(w[0]) 149 | for k in w_avg.keys(): 150 | tmp = [] 151 | for i in range(len(w)): 152 | tmp.append(w[i][k].cpu().numpy()) 153 | tmp = np.array(tmp) 154 | med = np.median(tmp, axis=0) 155 | new_tmp 
= [] 156 | for i in range(len(tmp)): 157 | new_tmp.append(tmp[i] - med) 158 | new_tmp = np.array(new_tmp) 159 | good_vals = np.argsort(abs(new_tmp), axis=0)[:number_to_consider] 160 | good_vals = np.take_along_axis(new_tmp, good_vals, axis=0) 161 | k_weight = np.array(np.mean(good_vals, axis=0) + med) 162 | w_avg[k] = torch.from_numpy(k_weight).to(hlpr.params.device) 163 | return w_avg 164 | 165 | 166 | def run_fl_round_benign(hlpr, epoch): 167 | global_model = hlpr.task.model 168 | local_model = hlpr.task.local_model 169 | 170 | round_participants = hlpr.task.sample_users_for_round(epoch) 171 | weight_accumulator = hlpr.task.get_empty_accumulator() 172 | 173 | for user in round_participants: 174 | hlpr.task.copy_params(global_model, local_model) 175 | optimizer = hlpr.task.make_optimizer(local_model) 176 | for local_epoch in range(hlpr.params.fl_local_epochs): 177 | train(hlpr, local_epoch, local_model, optimizer, 178 | user.train_loader, attack=False, report=False) 179 | local_update = hlpr.task.get_fl_update(local_model, global_model) 180 | hlpr.task.accumulate_weights(weight_accumulator, local_update) 181 | 182 | hlpr.task.update_global_model(weight_accumulator, global_model) 183 | 184 | 185 | def run_fl_round(hlpr, epoch): 186 | global_model = hlpr.task.model 187 | local_model = hlpr.task.local_model 188 | 189 | round_participants = hlpr.task.sample_users_for_round(epoch) 190 | weight_accumulator = hlpr.task.get_empty_accumulator() 191 | 192 | for user in round_participants: 193 | hlpr.task.copy_params(global_model, local_model) 194 | optimizer = hlpr.task.make_optimizer(local_model) 195 | for local_epoch in range(hlpr.params.fl_local_epochs): 196 | if user.compromised: 197 | train(hlpr, local_epoch, local_model, optimizer, 198 | user.train_loader, attack=True, report=False) 199 | else: 200 | train(hlpr, local_epoch, local_model, optimizer, 201 | user.train_loader, attack=False, report=False) 202 | local_update = hlpr.task.get_fl_update(local_model, global_model) 203 | if user.compromised: 204 | hlpr.attack.fl_scale_update(local_update) 205 | hlpr.task.accumulate_weights(weight_accumulator, local_update) 206 | 207 | hlpr.task.update_global_model(weight_accumulator, global_model) 208 | 209 | 210 | def run_fl_round_fltrust(hlpr, epoch): 211 | global_model = hlpr.task.model 212 | local_model = hlpr.task.local_model 213 | 214 | round_participants = hlpr.task.sample_users_for_round(epoch) 215 | weight_accumulator = hlpr.task.get_empty_accumulator() 216 | 217 | ref_global_model = hlpr.task.build_model().to(hlpr.params.device) 218 | hlpr.task.copy_params(global_model, ref_global_model) 219 | optimizer = hlpr.task.make_optimizer(ref_global_model) 220 | for local_epoch in range(hlpr.params.fl_local_epochs): 221 | train(hlpr, local_epoch, ref_global_model, optimizer, 222 | hlpr.task.clean_loader, attack=False, report=False) 223 | global_update = hlpr.task.get_fl_update(ref_global_model, global_model) 224 | 225 | benign_ids, malicious_ids = [], [] 226 | for user in round_participants: 227 | if user.compromised: 228 | malicious_ids.append(user.user_id) 229 | else: 230 | benign_ids.append(user.user_id) 231 | 232 | local_updates = {} 233 | trust_scores = {} 234 | for user in round_participants: 235 | hlpr.task.copy_params(global_model, local_model) 236 | optimizer = hlpr.task.make_optimizer(local_model) 237 | for local_epoch in range(hlpr.params.fl_local_epochs): 238 | if user.compromised: 239 | train(hlpr, local_epoch, local_model, optimizer, 240 | user.train_loader, attack=True, report=False) 241 | else: 
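# benign clients run standard local training below; the backdoor loss is only enabled (attack=True) for compromised clients in the branch above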
242 | train(hlpr, local_epoch, local_model, optimizer, 243 | user.train_loader, attack=False, report=False) 244 | local_update = hlpr.task.get_fl_update(local_model, global_model) 245 | if user.compromised: 246 | hlpr.attack.fl_scale_update(local_update) 247 | 248 | # compute trust score, normalize magnitude of local model updates 249 | trust_score, norm_scale = ts_and_norm_scale(global_update, local_update) 250 | # update local update with norm scale 251 | hlpr.attack.fl_scale_update(local_update, scale=norm_scale) 252 | 253 | local_updates[user.user_id] = local_update 254 | trust_scores[user.user_id] = trust_score 255 | 256 | benign_average = [trust_scores[i] for i in benign_ids] 257 | malicious_average = [trust_scores[i] for i in malicious_ids] 258 | benign_average = sum(benign_average) / (len(benign_average) + 1e-9) 259 | malicious_average = sum(malicious_average) / (len(malicious_average) + 1e-9) 260 | logger.warning('Trust Scores: Benign Average: {:.5f}, Malicious Average: {:.5f}'.format(benign_average, malicious_average)) 261 | 262 | # compute the final update as weighted average over local updates with genuine scores 263 | weight_accumulator = hlpr.task.get_empty_accumulator() 264 | hlpr.task.accumulate_weights_weighted(weight_accumulator, local_updates, trust_scores) 265 | hlpr.task.update_global_model(weight_accumulator, global_model) 266 | 267 | 268 | def run_fl_round_ours_parallel(hlpr, epoch): 269 | global_model = hlpr.task.model 270 | 271 | # Client Update 272 | round_participants = hlpr.task.sample_users_for_round(epoch) 273 | local_updates = {} 274 | 275 | benign_users, malicious_users = [], [] 276 | benign_ids, malicious_ids = [], [] 277 | for user in round_participants: 278 | if user.compromised: 279 | malicious_users.append(user) 280 | malicious_ids.append(user.user_id) 281 | else: 282 | benign_users.append(user) 283 | benign_ids.append(user.user_id) 284 | 285 | start = time.time() 286 | remaining_clients = len(benign_users) 287 | while remaining_clients > 0: 288 | thread_pool_size = min(remaining_clients, hlpr.params.max_threads) 289 | threads = [] 290 | for user in benign_users[len(benign_users) - remaining_clients: \ 291 | len(benign_users) - remaining_clients + thread_pool_size]: 292 | thread = ClientThreadBenign(user, hlpr, global_model) 293 | threads.append(thread) 294 | thread.start() 295 | for thread in threads: 296 | update = thread.join() 297 | local_updates.update(update) 298 | 299 | remaining_clients -= thread_pool_size 300 | end = time.time() 301 | logger.info('Client time Benign: {}'.format(end - start)) 302 | 303 | genuine_scores_approx = {} 304 | r_all_clients = {} 305 | 306 | start = time.time() 307 | remaining_clients = len(malicious_users) 308 | while remaining_clients > 0: 309 | thread_pool_size = min(remaining_clients, hlpr.params.max_threads) 310 | threads = [] 311 | for user in malicious_users[len(malicious_users) - remaining_clients: \ 312 | len(malicious_users) - remaining_clients + thread_pool_size]: 313 | thread = ClientThreadMalicious(user, hlpr, global_model) 314 | threads.append(thread) 315 | thread.start() 316 | 317 | for thread in threads: 318 | update, key, p_local_final, r_final = thread.join() 319 | local_updates.update(update) 320 | genuine_scores_approx[key] = p_local_final 321 | r_all_clients[key] = r_final 322 | 323 | remaining_clients -= thread_pool_size 324 | end = time.time() 325 | logger.info('Client time Malicious: {}'.format(end - start)) 326 | 327 | if hlpr.tb_writer is not None: 328 
hlpr.tb_writer.add_scalars('Client/Genuine_Scores_Approx', genuine_scores_approx, global_step=epoch) 329 | hlpr.tb_writer.add_scalars('Client/r', r_all_clients, global_step=epoch) 330 | hlpr.flush_writer() 331 | 332 | # Server Update 333 | start = time.time() 334 | # get reference model 335 | ref_global_model = hlpr.task.build_model().to(hlpr.params.device) 336 | hlpr.task.copy_params(global_model, ref_global_model) 337 | ref_weight_accumulator = hlpr.task.get_empty_accumulator() 338 | for local_update in local_updates.values(): 339 | hlpr.task.accumulate_weights(ref_weight_accumulator, local_update) 340 | hlpr.task.update_global_model(ref_weight_accumulator, ref_global_model) 341 | 342 | # reverse engineer trigger 343 | triggers, masks, norm_list = hlpr.task.reverse_engineer_trigger(ref_global_model, hlpr.task.clean_loader) 344 | logger.warning(norm_list) 345 | target_cls = int(torch.argmin(torch.tensor(norm_list))) 346 | 347 | # compute genuine scores for each client 348 | genuine_scores_output = {} 349 | genuine_scores = {} 350 | for user_id, local_update in local_updates.items(): 351 | # recover local model 352 | local_model = hlpr.task.build_model().to(hlpr.params.device) 353 | hlpr.task.copy_params(global_model, local_model) 354 | for name, update in local_update.items(): 355 | model_weight = local_model.state_dict()[name] 356 | model_weight.add_(update) 357 | 358 | # compute genuine score for this local model 359 | p_global = hlpr.task.compute_genuine_score_global(local_model, 360 | hlpr.task.clean_loader, 361 | triggers, 362 | masks, 363 | target_cls) 364 | 365 | genuine_scores[user_id] = p_global 366 | 367 | # Plotting (Part 1) 368 | if user_id in malicious_ids: 369 | key = 'Client {} (Malicious)'.format(user_id) 370 | else: 371 | key = 'Client {} (Benign)'.format(user_id) 372 | genuine_scores_output[key] = p_global 373 | 374 | # Plotting (Part 2) -- x-axis: global step, y-axis: genuine scores of all clients 375 | if hlpr.tb_writer is not None: 376 | hlpr.tb_writer.add_scalars('Server/Genuine_Scores', genuine_scores_output, global_step=epoch) 377 | hlpr.flush_writer() 378 | 379 | benign_average = [genuine_scores[i] for i in benign_ids] 380 | malicious_average = [genuine_scores[i] for i in malicious_ids] 381 | benign_average = sum(benign_average) / (len(benign_average) + 1e-9) 382 | malicious_average = sum(malicious_average) / (len(malicious_average) + 1e-9) 383 | logger.warning('Genuine Scores: Benign Average: {:.5f}, Malicious Average: {:.5f}'.format(benign_average, malicious_average)) 384 | 385 | # compute the final update as weighted average over local updates with genuine scores 386 | weight_accumulator = hlpr.task.get_empty_accumulator() 387 | hlpr.task.accumulate_weights_weighted(weight_accumulator, local_updates, genuine_scores) 388 | hlpr.task.update_global_model(weight_accumulator, global_model) 389 | 390 | end = time.time() 391 | logger.info('Server time: {}'.format(end - start)) 392 | 393 | 394 | class ClientThreadBenign(Thread): 395 | def __init__(self, user, hlpr, global_model): 396 | super().__init__() 397 | self.user = user 398 | self.hlpr = hlpr 399 | self.global_model = global_model 400 | self.local_model = hlpr.task.build_model().to(hlpr.task.params.device) 401 | self._return = None 402 | 403 | def run(self): 404 | # print('This is Client {}'.format(self.user.user_id)) 405 | self.hlpr.task.copy_params(self.global_model, self.local_model) 406 | optimizer = self.hlpr.task.make_optimizer(self.local_model) 407 | for local_epoch in 
range(self.hlpr.params.fl_local_epochs): 408 | train(self.hlpr, local_epoch, self.local_model, optimizer, 409 | self.user.train_loader, attack=False, report=False) 410 | # print('Client {} Epoch {}'.format(self.user.user_id, local_epoch)) 411 | local_update = self.hlpr.task.get_fl_update(self.local_model, self.global_model) 412 | self._return = {self.user.user_id: local_update} 413 | 414 | def join(self, *args): 415 | Thread.join(self, *args) 416 | return self._return 417 | 418 | 419 | class ClientThreadMalicious(Thread): 420 | def __init__(self, user, hlpr, global_model): 421 | super().__init__() 422 | self.user = user 423 | self.hlpr = hlpr 424 | self.global_model = global_model 425 | self.local_model = hlpr.task.build_model().to(hlpr.task.params.device) 426 | self._return = None 427 | 428 | def run(self): 429 | # print('This is Client {}'.format(self.user.user_id)) 430 | self.hlpr.task.copy_params(self.global_model, self.local_model) 431 | optimizer = self.hlpr.task.make_optimizer(self.local_model) 432 | 433 | pr_sum_max = 0 434 | p_local_final, r_final = 0, 0 435 | r = 0 436 | local_model_best = self.hlpr.task.build_model().to(self.hlpr.params.device) 437 | if self.hlpr.params.static: 438 | r = 1 # do not optimize r in static attack, always set to 1 439 | while r <= 1: 440 | for local_epoch in range(self.hlpr.params.fl_local_epochs): 441 | train(self.hlpr, local_epoch, self.local_model, optimizer, 442 | self.user.train_loader, attack=True, ratio=r, report=False) 443 | # print('Client {} Epoch {}'.format(self.user.user_id, local_epoch)) 444 | 445 | p_local = self.hlpr.task.compute_genuine_score(self.local_model, 446 | self.user.test_loader, 447 | self.hlpr.attack.synthesizer) 448 | pr_sum = p_local + self.hlpr.params.ours_lbd * r 449 | if pr_sum > pr_sum_max: 450 | pr_sum_max = pr_sum 451 | p_local_final, r_final = p_local, r 452 | self.hlpr.task.copy_params(self.local_model, local_model_best) 453 | if r == 1: 454 | break 455 | r = min(r + self.hlpr.params.r_interval, 1) 456 | self.hlpr.task.copy_params(self.global_model, self.local_model) 457 | self.hlpr.task.copy_params(local_model_best, self.local_model) 458 | 459 | key = 'Client {} (Malicious)'.format(self.user.user_id) 460 | local_update = self.hlpr.task.get_fl_update(self.local_model, self.global_model) 461 | self.hlpr.attack.fl_scale_update(local_update) 462 | 463 | self._return = {self.user.user_id: local_update}, key, p_local_final, r_final 464 | 465 | def join(self, *args): 466 | Thread.join(self, *args) 467 | return self._return 468 | 469 | 470 | if __name__ == '__main__': 471 | parser = argparse.ArgumentParser(description='Backdoors') 472 | parser.add_argument('--params', dest='params', default='utils/params.yaml') 473 | parser.add_argument('--name', dest='name', required=True, help='Tensorboard name') 474 | parser.add_argument('--commit', dest='commit', 475 | default=get_current_git_hash()) 476 | 477 | args = parser.parse_args() 478 | 479 | with open(args.params) as f: 480 | params = yaml.load(f, Loader=yaml.FullLoader) 481 | 482 | params['current_time'] = datetime.now().strftime('%b.%d_%H.%M.%S') 483 | params['commit'] = args.commit 484 | params['name'] = args.name 485 | 486 | helper = Helper(params) 487 | logger.warning(create_table(params)) 488 | 489 | try: 490 | if helper.params.fl: 491 | fl_run(helper) 492 | else: 493 | run(helper) 494 | except (KeyboardInterrupt): 495 | if helper.params.log: 496 | answer = prompt('\nDelete the repo? (y/n): ') 497 | if answer in ['Y', 'y', 'yes']: 498 | logger.error(f"Fine. 
Deleted: {helper.params.folder_path}") 499 | shutil.rmtree(helper.params.folder_path) 500 | if helper.params.tb: 501 | shutil.rmtree(f'runs/{args.name}') 502 | else: 503 | logger.error(f"Aborted training. " 504 | f"Results: {helper.params.folder_path}. " 505 | f"TB graph: {args.name}") 506 | else: 507 | logger.error(f"Aborted training. No output generated.") 508 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-secure/FedGame/6be7b9e3f9e3a1aa1822fa8b319ec85c6ec27d52/utils/__init__.py -------------------------------------------------------------------------------- /utils/index.html: -------------------------------------------------------------------------------- [HTML markup lost in extraction: a minimal "Experiments" landing page whose links are labeled Tensorboard, Jupyter, and Saved Models]
-------------------------------------------------------------------------------- /utils/min_norm_solvers.py: -------------------------------------------------------------------------------- 1 | # Credits to Ozan Sener 2 | # https://github.com/intel-isl/MultiObjectiveOptimization 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class MGDASolver: 9 | MAX_ITER = 250 10 | STOP_CRIT = 1e-5 11 | 12 | @staticmethod 13 | def _min_norm_element_from2(v1v1, v1v2, v2v2): 14 | """ 15 | Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2 16 | d is the distance (objective) optimized 17 | v1v1 = <x1,x1> 18 | v1v2 = <x1,x2> 19 | v2v2 = <x2,x2> 20 | """ 21 | if v1v2 >= v1v1: 22 | # Case: Fig 1, third column 23 | gamma = 0.999 24 | cost = v1v1 25 | return gamma, cost 26 | if v1v2 >= v2v2: 27 | # Case: Fig 1, first column 28 | gamma = 0.001 29 | cost = v2v2 30 | return gamma, cost 31 | # Case: Fig 1, second column 32 | gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2)) 33 | cost = v2v2 + gamma * (v1v2 - v2v2) 34 | return gamma, cost 35 | 36 | @staticmethod 37 | def _min_norm_2d(vecs: list, dps): 38 | """ 39 | Find the minimum norm solution as combination of two points 40 | This is correct only in 2D 41 | ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 42 | for all i, c_i + c_j = 1.0 for some i, j 43 | """ 44 | dmin = 1e8 45 | sol = 0 46 | for i in range(len(vecs)): 47 | for j in range(i + 1, len(vecs)): 48 | if (i, j) not in dps: 49 | dps[(i, j)] = 0.0 50 | for k in range(len(vecs[i])): 51 | dps[(i, j)] += torch.dot(vecs[i][k].view(-1), 52 | vecs[j][k].view(-1)).detach() 53 | dps[(j, i)] = dps[(i, j)] 54 | if (i, i) not in dps: 55 | dps[(i, i)] = 0.0 56 | for k in range(len(vecs[i])): 57 | dps[(i, i)] += torch.dot(vecs[i][k].view(-1), 58 | vecs[i][k].view(-1)).detach() 59 | if (j, j) not in dps: 60 | dps[(j, j)] = 0.0 61 | for k in range(len(vecs[i])): 62 | dps[(j, j)] += torch.dot(vecs[j][k].view(-1), 63 | vecs[j][k].view(-1)).detach() 64 | c, d = MGDASolver._min_norm_element_from2(dps[(i, i)], 65 | dps[(i, j)], 66 | dps[(j, j)]) 67 | if d < dmin: 68 | dmin = d 69 | sol = [(i, j), c, d] 70 | return sol, dps 71 | 72 | @staticmethod 73 | def _projection2simplex(y): 74 | """ 75 | Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i 76 | """ 77 | m = len(y) 78 | sorted_y = np.flip(np.sort(y), axis=0) 79 | tmpsum = 0.0 80 | tmax_f = (np.sum(y) - 1.0) / m 81 | for i in range(m - 1): 82 | tmpsum += sorted_y[i] 83 | tmax = (tmpsum - 1) / (i + 1.0) 84 | if tmax > sorted_y[i + 1]: 85 | tmax_f = tmax 86 | break 87 | return np.maximum(y - tmax_f, np.zeros(y.shape)) 88 | 89 | @staticmethod 90 | def _next_point(cur_val, grad, n): 91 | proj_grad = grad - (np.sum(grad) / n) 92 | tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0] 93 | tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0]) 94 | 95 | skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7) 96 | t = 1 97 | if len(tm1[tm1 > 1e-7]) > 0: 98 | t = np.min(tm1[tm1 > 1e-7]) 99 | if len(tm2[tm2 > 1e-7]) > 0: 100 | t = min(t, np.min(tm2[tm2 > 1e-7])) 101 | 102 | next_point = proj_grad * t + cur_val 103 | next_point = MGDASolver._projection2simplex(next_point) 104 | return next_point 105 | 106 | @staticmethod 107 | def find_min_norm_element(vecs: list): 108 | """ 109 | Given a list of vectors (vecs), this method finds the minimum norm 110 | element in the convex hull as min |u|_2 st. u = \sum c_i vecs[i] 111 | and \sum c_i = 1. 
It is quite geometric, and the main idea is the 112 | fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution 113 | lies in (0, d_{i,j})Hence, we find the best 2-task solution , and 114 | then run the projected gradient descent until convergence 115 | """ 116 | # Solution lying at the combination of two points 117 | dps = {} 118 | init_sol, dps = MGDASolver._min_norm_2d(vecs, dps) 119 | 120 | n = len(vecs) 121 | sol_vec = np.zeros(n) 122 | sol_vec[init_sol[0][0]] = init_sol[1] 123 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 124 | 125 | if n < 3: 126 | # This is optimal for n=2, so return the solution 127 | return sol_vec, init_sol[2] 128 | 129 | iter_count = 0 130 | 131 | grad_mat = np.zeros((n, n)) 132 | for i in range(n): 133 | for j in range(n): 134 | grad_mat[i, j] = dps[(i, j)] 135 | 136 | while iter_count < MGDASolver.MAX_ITER: 137 | grad_dir = -1.0 * np.dot(grad_mat, sol_vec) 138 | new_point = MGDASolver._next_point(sol_vec, grad_dir, n) 139 | # Re-compute the inner products for line search 140 | v1v1 = 0.0 141 | v1v2 = 0.0 142 | v2v2 = 0.0 143 | for i in range(n): 144 | for j in range(n): 145 | v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)] 146 | v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)] 147 | v2v2 += new_point[i] * new_point[j] * dps[(i, j)] 148 | nc, nd = MGDASolver._min_norm_element_from2(v1v1.item(), 149 | v1v2.item(), 150 | v2v2.item()) 151 | # try: 152 | new_sol_vec = nc * sol_vec + (1 - nc) * new_point 153 | # except AttributeError: 154 | # print(sol_vec) 155 | change = new_sol_vec - sol_vec 156 | if np.sum(np.abs(change)) < MGDASolver.STOP_CRIT: 157 | return sol_vec, nd 158 | sol_vec = new_sol_vec 159 | 160 | @staticmethod 161 | def find_min_norm_element_FW(vecs): 162 | """ 163 | Given a list of vectors (vecs), this method finds the minimum norm 164 | element in the convex hull 165 | as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1. 
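(unlike the projected-gradient variant above, each Frank-Wolfe step moves all mass toward a single coordinate of the simplex)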
166 | It is quite geometric, and the main idea is the fact that if 167 | d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies 168 | in (0, d_{i,j})Hence, we find the best 2-task solution, and then 169 | run the Frank Wolfe until convergence 170 | """ 171 | # Solution lying at the combination of two points 172 | dps = {} 173 | init_sol, dps = MGDASolver._min_norm_2d(vecs, dps) 174 | 175 | n = len(vecs) 176 | sol_vec = np.zeros(n) 177 | sol_vec[init_sol[0][0]] = init_sol[1] 178 | sol_vec[init_sol[0][1]] = 1 - init_sol[1] 179 | 180 | if n < 3: 181 | # This is optimal for n=2, so return the solution 182 | return sol_vec, init_sol[2] 183 | 184 | iter_count = 0 185 | 186 | grad_mat = np.zeros((n, n)) 187 | for i in range(n): 188 | for j in range(n): 189 | grad_mat[i, j] = dps[(i, j)] 190 | 191 | while iter_count < MGDASolver.MAX_ITER: 192 | t_iter = np.argmin(np.dot(grad_mat, sol_vec)) 193 | 194 | v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec)) 195 | v1v2 = np.dot(sol_vec, grad_mat[:, t_iter]) 196 | v2v2 = grad_mat[t_iter, t_iter] 197 | 198 | nc, nd = MGDASolver._min_norm_element_from2(v1v1, v1v2, v2v2) 199 | new_sol_vec = nc * sol_vec 200 | new_sol_vec[t_iter] += 1 - nc 201 | 202 | change = new_sol_vec - sol_vec 203 | if np.sum(np.abs(change)) < MGDASolver.STOP_CRIT: 204 | return sol_vec, nd 205 | sol_vec = new_sol_vec 206 | 207 | @classmethod 208 | def get_scales(cls, grads, losses, normalization_type, tasks): 209 | scale = {} 210 | gn = gradient_normalizers(grads, losses, normalization_type) 211 | for t in tasks: 212 | for gr_i in range(len(grads[t])): 213 | grads[t][gr_i] = grads[t][gr_i] / (gn[t] + 1e-5) 214 | sol, min_norm = cls.find_min_norm_element([grads[t] for t in tasks]) 215 | for zi, t in enumerate(tasks): 216 | scale[t] = float(sol[zi]) 217 | 218 | return scale 219 | 220 | 221 | def gradient_normalizers(grads, losses, normalization_type): 222 | gn = {} 223 | if normalization_type == 'l2': 224 | for t in grads: 225 | gn[t] = torch.sqrt( 226 | torch.stack([gr.pow(2).sum().data for gr in grads[t]]).sum()) 227 | elif normalization_type == 'loss': 228 | for t in grads: 229 | gn[t] = min(losses[t].mean(), 10.0) 230 | elif normalization_type == 'loss+': 231 | for t in grads: 232 | gn[t] = min(losses[t].mean() * torch.sqrt( 233 | torch.stack([gr.pow(2).sum().data for gr in grads[t]]).sum()), 234 | 10) 235 | 236 | elif normalization_type == 'none' or normalization_type == 'eq': 237 | for t in grads: 238 | gn[t] = 1.0 239 | else: 240 | raise ValueError('ERROR: Invalid Normalization Type') 241 | return gn 242 | -------------------------------------------------------------------------------- /utils/parameters.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass, asdict 3 | from typing import List, Dict 4 | import logging 5 | import torch 6 | logger = logging.getLogger('logger') 7 | 8 | ALL_TASKS = ['backdoor', 'normal', 'sentinet_evasion', 'mask_norm', 'sums'] 9 | 10 | @dataclass 11 | class Params: 12 | 13 | # Corresponds to the class module: tasks.mnist_task.MNISTTask 14 | # See other tasks in the task folder. 
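# The values below are defaults only; each experiment overrides them through the YAML file passed via --params (see configs/ and training.py).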
15 |     task: str = 'MNIST'
16 | 
17 |     current_time: str = None
18 |     name: str = None
19 |     commit: str = None
20 |     random_seed: int = None
21 |     device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 |     # training params
23 |     start_epoch: int = 1
24 |     epochs: int = None
25 |     log_interval: int = 1000
26 | 
27 |     # model arch is usually defined by the task
28 |     pretrained: bool = False
29 |     resume_model: str = None
30 |     lr: float = None
31 |     decay: float = None
32 |     momentum: float = None
33 |     optimizer: str = None
34 |     scheduler: bool = False
35 |     scheduler_milestones: List[int] = None
36 |     # data
37 |     data_path: str = '.data/'
38 |     batch_size: int = 64
39 |     test_batch_size: int = 100
40 |     transform_train: bool = True
41 |     "Whether to apply augmentation transforms to the training images."
42 |     max_batch_id: int = None
43 |     "For large datasets stop training earlier."
44 |     input_shape = None
45 |     "No need to set, updated by the Task class."
46 | 
47 |     # gradient shaping/DP params
48 |     dp: bool = None
49 |     dp_clip: float = None
50 |     dp_sigma: float = None
51 | 
52 |     # attack params
53 |     backdoor: bool = False
54 |     backdoor_label: int = 8
55 |     poisoning_proportion: float = 1.0  # proportion of backdoored inputs in the backdoor loss
56 |     synthesizer: str = 'pattern'
57 |     backdoor_dynamic_position: bool = False
58 | 
59 |     # losses to balance: `normal`, `backdoor`, `sentinet_evasion`,
60 |     # `mask_norm`, `sums`.
61 |     loss_tasks: List[str] = None
62 | 
63 |     loss_balance: str = 'MGDA'
64 |     "Loss balancing: `fixed` or `MGDA`."
65 | 
66 |     loss_threshold: float = None
67 | 
68 |     # approaches to balance losses with MGDA: `none`, `loss`,
69 |     # `loss+`, `l2`
70 |     mgda_normalize: str = None
71 |     fixed_scales: Dict[str, float] = None
72 | 
73 |     # relabel images with poison_number
74 |     poison_images: List[int] = None
75 |     poison_images_test: List[int] = None
76 |     # optimizations:
77 |     alternating_attack: float = None
78 |     clip_batch: float = None
79 |     # Disable BatchNorm and Dropout
80 |     switch_to_eval: float = None
81 | 
82 |     # nc evasion
83 |     nc_p_norm: int = 1
84 |     # spectral evasion
85 |     spectral_similarity: str = 'norm'
86 | 
87 |     # logging
88 |     report_train_loss: bool = True
89 |     log: bool = False
90 |     tb: bool = False
91 |     save_model: bool = None
92 |     save_on_epochs: List[int] = None
93 |     save_scale_values: bool = False
94 |     print_memory_consumption: bool = False
95 |     save_timing: bool = False
96 |     timing_data = None
97 | 
98 |     # Temporary storage for running values
99 |     running_losses = None
100 |     running_scales = None
101 | 
102 |     # FL params
103 |     fl: bool = False
104 |     fl_no_models: int = 100
105 |     fl_local_epochs: int = 2
106 |     fl_total_participants: int = 80000
107 |     fl_eta: int = 1
108 |     fl_q: float = 0.1
109 |     fl_dp_clip: float = None
110 |     fl_dp_noise: float = None
111 |     # FL attack details. Set the number of adversaries to perform the attack:
112 |     fl_number_of_adversaries: int = 0
113 |     fl_single_epoch_attack: int = None
114 |     fl_weight_scale: int = 1
115 | 
116 |     # Clean dataset params
117 |     clean_ratio: float = 0.1
118 |     clean_classes: List[int] = None
119 | 
120 |     # Defense params
121 |     ours: bool = None
122 |     ours_lbd: float = 1
123 |     attacker_train_ratio: float = None
124 |     r_interval: float = None
125 |     nc_steps: int = 1000
126 |     max_threads: int = 100
127 |     attack_start_epoch: int = 1
128 |     clean_set_dataset: str = None
129 |     defense: str = None
130 |     static: bool = False
131 | 
132 |     # FLTrust
133 |     fltrust: bool = None
134 | 
135 |     def __post_init__(self):
136 |         # enable logging anyway when saving statistics
137 |         if self.save_model or self.tb or self.save_timing or \
138 |                 self.print_memory_consumption:
139 |             self.log = True
140 | 
141 |         if self.log:
142 |             self.folder_path = f'saved_models/model_' \
143 |                                f'{self.task}_{self.current_time}_{self.name}'
144 | 
145 |         self.running_losses = defaultdict(list)
146 |         self.running_scales = defaultdict(list)
147 |         self.timing_data = defaultdict(list)
148 | 
149 |         for t in (self.loss_tasks or []):
150 |             if t not in ALL_TASKS:
151 |                 raise ValueError(f'Task {t} is not part of the supported '
152 |                                  f'tasks: {ALL_TASKS}.')
153 | 
154 |     def to_dict(self):
155 |         return asdict(self)
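156 | 
157 | # A minimal usage sketch (illustrative only; the actual loading code lives
158 | # in training.py): Params is typically built from one of the YAML files in
159 | # the configs folder, e.g.:
160 | #
161 | #   import yaml
162 | #   with open('configs/mnist_ours.yaml') as f:
163 | #       params = Params(**yaml.load(f, Loader=yaml.FullLoader))
164 | #   print(params.to_dict())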

-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import time
5 | 
6 | import colorlog
7 | import torch
8 | import numpy as np
9 | 
10 | from utils.parameters import Params
11 | 
12 | 
13 | def record_time(params: Params, t=None, name=None):
14 |     if t and name and (params.save_timing == name or params.save_timing is True):
15 |         torch.cuda.synchronize()
16 |         params.timing_data[name].append(round(1000 * (time.perf_counter() - t)))
17 | 
18 | 
19 | def dict_html(dict_obj, current_time):
20 |     out = ''
21 |     for key, value in dict_obj.items():
22 | 
23 |         # filter out unneeded parts:
24 |         if key in ['poisoning_test', 'test_batch_size', 'discount_size',
25 |                    'folder_path', 'log_interval',
26 |                    'coefficient_transfer', 'grad_threshold']:
27 |             continue
28 | 
29 |         out += f'<tr><td>{key}</td><td>{value}</td></tr>'
30 |     output = f'<h4>Params for model: {current_time}:</h4><table>{out}</table>'
31 |     return output
32 | 
33 | 
34 | def poison_text(inputs, labels):
35 |     inputs = inputs.clone()
36 |     labels = labels.clone()
37 |     for i in range(inputs.shape[0]):
38 |         # keep the two-token trigger before the final [SEP] token (id 102)
39 |         pos = random.randint(1, (inputs[i] == 102).nonzero().item() - 3)
40 |         inputs[i, pos] = 3968
41 |         inputs[i, pos + 1] = 3536
42 |     labels = torch.ones_like(labels)
43 |     return inputs, labels
44 | 
45 | 
46 | def poison_text_test(inputs, labels):
47 |     for i in range(inputs.shape[0]):
48 |         pos = random.randint(1, inputs.shape[1] - 4)
49 |         inputs[i, pos] = 3968
50 |         inputs[i, pos + 1] = 3536
51 |     labels.fill_(1)
52 |     return True
53 | 
54 | 
55 | def create_table(params: dict):
56 |     data = "| name | value | \n |-----|-----|"
57 | 
58 |     for key, value in params.items():
59 |         data += '\n' + f"| {key} | {value} |"
60 | 
61 |     return data
62 | 
63 | 
64 | def get_current_git_hash():
65 |     import git
66 |     repo = git.Repo(search_parent_directories=True)
67 |     sha = repo.head.object.hexsha
68 |     return sha
69 | 
70 | 
71 | def create_logger():
72 |     """
73 |     Set up the logging environment.
74 |     """
75 |     log = logging.getLogger()  # root logger
76 |     log.setLevel(logging.DEBUG)
77 |     format_str = '%(asctime)s - %(levelname)-8s - %(message)s'
78 |     date_format = '%Y-%m-%d %H:%M:%S'
79 |     if os.isatty(2):
80 |         cformat = '%(log_color)s' + format_str
81 |         colors = {'DEBUG': 'reset',
82 |                   'INFO': 'reset',
83 |                   'WARNING': 'bold_yellow',
84 |                   'ERROR': 'bold_red',
85 |                   'CRITICAL': 'bold_red'}
86 |         formatter = colorlog.ColoredFormatter(cformat, date_format,
87 |                                               log_colors=colors)
88 |     else:
89 |         formatter = logging.Formatter(format_str, date_format)
90 |     stream_handler = logging.StreamHandler()
91 |     stream_handler.setFormatter(formatter)
92 |     log.addHandler(stream_handler)
93 |     return logging.getLogger(__name__)
94 | 
95 | 
96 | def th(vector):
97 |     return torch.tanh(vector) / 2 + 0.5
98 | 
99 | 
100 | def thp(vector):
101 |     return torch.tanh(vector) * 2.2
102 | 
103 | 
104 | def cos_relu(a, b):
105 |     # ReLU-clipped cosine similarity, as used for FLTrust-style trust scores
106 |     res = np.sum(a * b.T) / ((np.sqrt(np.sum(a * a.T)) + 1e-9) *
107 |                              (np.sqrt(np.sum(b * b.T)) + 1e-9))
108 |     if res < 0:
109 |         res = 0
110 |     return res
111 | 
112 | 
113 | def model2vector(model):
114 |     nparr = np.array([])
115 |     for key, var in model.items():
116 |         nplist = var.cpu().numpy()
117 |         nplist = nplist.ravel()
118 |         nparr = np.append(nparr, nplist)
119 |     return nparr
120 | 
121 | 
122 | def norm_scale(nparr1, nparr2):
123 |     vnum = np.linalg.norm(nparr1, ord=None, axis=None, keepdims=False)
124 |     return vnum / (np.linalg.norm(nparr2, ord=None, axis=None, keepdims=False) + 1e-9)
125 | 
126 | 
127 | def ts_and_norm_scale(global_update, local_update):
128 |     vector1 = model2vector(global_update)
129 |     vector2 = model2vector(local_update)
130 |     return cos_relu(vector1, vector2), norm_scale(vector1, vector2)
131 | 
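132 | # A minimal usage sketch (illustrative only; server_model and client_model
133 | # are placeholder nn.Modules, see the task classes under tasks/fl for the
134 | # real aggregation logic). Comparing a client update to the server update
135 | # yields a FLTrust-style trust score and a norm-rescaling factor:
136 | #
137 | #   ts, scale = ts_and_norm_scale(server_model.state_dict(),
138 | #                                 client_model.state_dict())
139 | #   # weight the client's contribution by ts, rescale its norm by scale
--------------------------------------------------------------------------------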