├── src
│   ├── __init__.py
│   ├── aggregation.py
│   └── update.py
├── tsboard.sh
├── check.sh
├── utils
│   ├── __init__.py
│   ├── attack.py
│   ├── test.py
│   ├── dataset.py
│   ├── sampling.py
│   ├── options.py
│   └── byzantine_fl.py
├── execute
│   ├── run0.sh
│   ├── run1.sh
│   └── run2.sh
├── docker.sh
├── LICENSE
├── tb2csv.py
├── README.md
└── main.py

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tsboard.sh:
--------------------------------------------------------------------------------
tensorboard --logdir=runs

--------------------------------------------------------------------------------
/check.sh:
--------------------------------------------------------------------------------
tensorboard --logdir=runs --host=0.0.0.0 --port=7017

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6
--------------------------------------------------------------------------------
/execute/run0.sh:
--------------------------------------------------------------------------------
for method in fedavg krum trimmed_mean fang
do
    python main.py --gpu 0 --method $method --tsboard --c_frac 0.0 --quantity_skew
done
--------------------------------------------------------------------------------
/docker.sh:
--------------------------------------------------------------------------------
# $tensorboard_port is expected to be set in the environment;
# note that the --mount spec must not contain spaces
nvidia-docker run -d -it --restart always --name experiment --shm-size 256G -p $tensorboard_port:$tensorboard_port --ip=0.0.0.0 --mount type=bind,source=$(pwd)/tmi,target=/root/tmi pytorch/pytorch:1.12.0-cuda11.3-cudnn8-devel

--------------------------------------------------------------------------------
/execute/run1.sh:
--------------------------------------------------------------------------------
for method in fedavg krum trimmed_mean fang
do
    for frac in 0.1 0.2 0.3
    do
        for p in target untarget
        do
            python main.py --gpu 0 --method $method --tsboard --c_frac $frac --p $p --quantity_skew
        done
    done
done

--------------------------------------------------------------------------------
/execute/run2.sh:
--------------------------------------------------------------------------------
for method in fedavg krum trimmed_mean fang
do
    for alpha in 0.01 0.1 1.0 10.0 100.0
    do
        for p in target untarget
        do
            python main.py --gpu 0 --method $method --tsboard --c_frac 0.3 --alpha $alpha --p $p --quantity_skew
        done
    done
done

--------------------------------------------------------------------------------
/src/aggregation.py:
--------------------------------------------------------------------------------
import copy

import torch

def fedavg(w_locals):
    # element-wise average of the local state dicts
    w_avg = copy.deepcopy(w_locals[0])

    with torch.no_grad():
        for k in w_avg.keys():
            for i in range(1, len(w_locals)):
                w_avg[k] += w_locals[i][k]
            w_avg[k] = torch.true_divide(w_avg[k], len(w_locals))

    return w_avg

--------------------------------------------------------------------------------
/utils/attack.py:
--------------------------------------------------------------------------------
import random
import copy

import torch

def compromised_clients(args):
    # sample the indices of the compromised (Byzantine) clients
    max_num = max(int(args.c_frac * args.num_clients), 1)

    tmp_idx = [i for i in range(args.num_clients)]

    compromised_idxs = random.sample(tmp_idx, max_num)

    return compromised_idxs

def untargeted_attack(w, args):
    # MPAF-style untargeted poisoning: push the global model toward a
    # random base model, scaled by mp_lambda
    mpaf = copy.deepcopy(w)
    for k in w.keys():
        tmp = torch.zeros_like(mpaf[k], dtype=torch.float32).to(args.device)
        w_base = torch.randn_like(mpaf[k], dtype=torch.float32).to(args.device)
        tmp += (w[k].to(args.device) - w_base) * args.mp_lambda
        mpaf[k].copy_(tmp)

    return mpaf

--------------------------------------------------------------------------------
/utils/test.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

def test_img(net_g, dataset, args):
    # testing
    correct = 0
    data_loader = DataLoader(dataset, batch_size=128)
    test_loss = 0
    with torch.no_grad():
        net_g.eval()
        for idx, (data, target) in enumerate(data_loader):
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)

            # sum up batch loss
            test_loss += F.cross_entropy(log_probs, target.squeeze(dim=-1), reduction='sum').item()

            # get the index of the max log-probability
            y_pred = log_probs.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    test_acc = 100.00 * correct / len(data_loader.dataset)

    return test_acc, test_loss
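A minimal sanity check for the fedavg and untargeted_attack helpers above, as a sketch with toy tensors; the SimpleNamespace stands in for the parsed utils/options.py arguments and is an assumption, not repository code:

from types import SimpleNamespace

import torch

from src.aggregation import fedavg
from utils.attack import untargeted_attack

# hypothetical stand-in for args_parser(); only the fields read below
args = SimpleNamespace(device='cpu', mp_lambda=10.0)

# two toy "state dicts" playing the role of local client models
w1 = {'layer.weight': torch.ones(2, 2)}
w2 = {'layer.weight': 3 * torch.ones(2, 2)}

w_avg = fedavg([w1, w2])
print(w_avg['layer.weight'])       # all entries 2.0, i.e. (1 + 3) / 2

w_poisoned = untargeted_attack(w1, args)
print(w_poisoned['layer.weight'])  # (w - random base model) * mp_lambda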
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from torch.utils.data import Dataset

import medmnist
from medmnist import INFO

class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

def load_data(args):

    info = INFO[args.dataset]
    DataClass = getattr(medmnist, info['python_class'])
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])])

    dataset_train = DataClass(split='train', transform=trans, download=True, as_rgb=True)
    dataset_test = DataClass(split='test', transform=trans, download=True, as_rgb=True)
    dataset_val = DataClass(split='val', transform=trans, download=True, as_rgb=True)

    return dataset_train, dataset_test, dataset_val
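A sketch of how DatasetSplit and load_data fit together; it assumes the medmnist package is installed (BloodMNIST is downloaded on first use), and the client index set below is hypothetical:

from types import SimpleNamespace

from torch.utils.data import DataLoader

from utils.dataset import DatasetSplit, load_data

# hypothetical stand-in for args_parser(); only the field read below
args = SimpleNamespace(dataset='bloodmnist')
dataset_train, dataset_test, dataset_val = load_data(args)

client_idxs = range(100)  # hypothetical per-client index set
loader = DataLoader(DatasetSplit(dataset_train, client_idxs), batch_size=20, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # e.g. torch.Size([20, 3, 28, 28])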
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Youngjoon Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/tb2csv.py:
--------------------------------------------------------------------------------
import os
import csv

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def tabulate_events(dpath):

    for dname in os.listdir(dpath):
        ea = EventAccumulator(os.path.join(dpath, dname)).Reload()
        tags = ea.Tags()['scalars']
        out = {}

        for tag in tags:
            tag_values = []
            for event in ea.Scalars(tag):
                tag_values.append(event.value)
            out[tag] = tag_values[:-7]

        out_keys = [k for k in out.keys()]
        out_values = [v for v in out.values()]

        for i in range(2):
            try:
                with open(f"../result/{out_keys[i]}.csv", 'w') as file:
                    writer = csv.writer(file)
                    writer.writerow(out_values[i])
            except IndexError:
                print(tag)

    return "Converted"

# os.chdir returns None, so switch the working directory explicitly
# instead of assigning its return value
cpath = os.getcwd()
os.chdir(os.path.join(cpath, "runs"))

for folder in os.listdir('.'):
    tabulate_events(folder)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Security-Preserving Federated Learning via Byzantine-Sensitive Triplet Distance

This is an official implementation of the following paper:
> Youngjoon Lee, Sangwoo Park, and Joonhyuk Kang.
**[Security-Preserving Federated Learning via Byzantine-Sensitive Triplet Distance](https://arxiv.org/abs/2210.16519)**
_IEEE International Symposium on Biomedical Imaging (ISBI) 2024_.

## Requirements
The implementation runs in the Docker container launched by

```bash docker.sh```

Additionally, install the required packages:

```pip install tensorboard medmnist```

## Byzantine attacks
This paper considers the following poisoning attacks:
- Targeted model poisoning ([Bhagoji, Arjun Nitin, et al. ICML 2019](https://arxiv.org/abs/1811.12470)): targeted model poisoning attack for federated learning
- MPAF ([Xiaoyu Cao, Neil Zhenqiang Gong. CVPR Workshop 2022](https://arxiv.org/abs/2203.08669)): untargeted model poisoning attack for federated learning

## Byzantine-Robust Aggregation Techniques
This paper considers the following Byzantine-robust aggregation techniques:
- Vanilla ([McMahan, Brendan, et al. AISTATS 2017](http://proceedings.mlr.press/v54/mcmahan17a))
- Krum ([Blanchard, Peva, et al. NIPS 2017](https://proceedings.neurips.cc/paper/2017/hash/f4b9ec30ad9f68f89b29639786cb62ef-Abstract.html))
- Trimmed-mean ([Yin, Dong, et al. ICML 2018](https://proceedings.mlr.press/v80/yin18a))
- Fang ([Fang, Minghong, et al. USENIX 2020](https://arxiv.org/abs/1911.11815))

## Dataset
- Blood cell classification dataset ([Andrea Acevedo, Anna Merino, et al. Data in Brief 2020](https://www.sciencedirect.com/science/article/pii/S2352340920303681))

## Experiments
The experiment without Byzantine attacks runs with

```bash execute/run0.sh```

The impact of the fraction of Byzantine clients is evaluated with

```bash execute/run1.sh```

The impact of the degree of non-iidness is evaluated with

```bash execute/run2.sh```

## Acknowledgements
This implementation builds on http://doi.org/10.5281/zenodo.4321561
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
import numpy as np

def noniid(dataset, args):

    idxs = np.arange(len(dataset))
    labels = np.transpose(np.array(dataset.labels))

    dict_users = {i: list() for i in range(args.num_clients)}
    dict_labels = dict()

    # sort sample indices by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]

    idxs = list(idxs_labels[0])
    labels = idxs_labels[1]

    if args.quantity_skew:
        min_num, max_num = 500, 1000
        num_rand = np.random.randint(min_num, max_num + 1, size=args.num_clients)

    if args.alpha > 0:
        # Dirichlet label distribution per client
        proportions = np.random.dirichlet(np.ones(args.num_classes) * args.alpha, args.num_clients)
    else:
        # one random class per client; draw from args.num_classes rather
        # than a hard-coded 10, which could index a non-existent class
        rand_class_num = np.random.randint(0, args.num_classes, size=args.num_clients)

    for i in range(args.num_classes):
        specific_class = set(np.extract(labels == i, idxs))
        dict_labels.update({i: specific_class})

    if args.alpha > 0:
        for i, prop in enumerate(proportions):

            if args.quantity_skew:
                prop = num_rand[i] * prop
            else:
                prop = args.num_data * prop

            rand_set = list()
            for c in range(args.num_classes):
                try:
                    rand_class = list(np.random.choice(list(dict_labels[c]), int(prop[c])))
                    dict_labels[c] = dict_labels[c] - set(rand_class)
                    rand_set = rand_set + rand_class
                except ValueError:
                    pass
            dict_users[i] = set(rand_set)
    else:
        for i, class_num in enumerate(rand_class_num):
            rand_set = list()
            if args.quantity_skew:
                rand_class = list(np.random.choice(list(dict_labels[class_num]), num_rand[i]))
            else:
                rand_class = list(np.random.choice(list(dict_labels[class_num]), args.num_data))
            dict_labels[class_num] = dict_labels[class_num] - set(rand_class)

            rand_set = rand_set + rand_class
            dict_users[i] = set(rand_set)

    return dict_users
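A sketch of the noniid partitioner on a toy dataset; the ToyDataset stub and the args values are illustrative assumptions (the real pipeline passes a MedMNIST dataset and the parsed utils/options.py arguments):

from types import SimpleNamespace

import numpy as np

from utils.sampling import noniid

class ToyDataset:
    # mimics the MedMNIST interface noniid() relies on: __len__ and
    # a labels array of shape (n, 1)
    def __init__(self, n, num_classes):
        self.labels = np.random.randint(0, num_classes, size=(n, 1))

    def __len__(self):
        return len(self.labels)

args = SimpleNamespace(num_clients=10, num_classes=8, num_data=100,
                       alpha=10.0, quantity_skew=False)

dict_users = noniid(ToyDataset(20000, args.num_classes), args)
print({k: len(v) for k, v in dict_users.items()})  # samples per client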
--------------------------------------------------------------------------------
/src/update.py:
--------------------------------------------------------------------------------
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from utils.dataset import DatasetSplit

class BenignUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True, drop_last=True)

    def train(self, net):

        net.train()

        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr)

        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                optimizer.zero_grad()
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                log_probs = net(images)
                loss = self.loss_func(log_probs, labels.squeeze(dim=-1))

                loss.backward()
                optimizer.step()

        return net.state_dict()

class CompromisedUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True, drop_last=True)

    def train(self, net):

        net_freeze = copy.deepcopy(net)

        net.train()

        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr)

        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                optimizer.zero_grad()
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                log_probs = net(images)
                loss = self.loss_func(log_probs, labels.squeeze(dim=-1))

                loss.backward()
                optimizer.step()

        # targeted model poisoning: upload the local update
        # (w_trained - w_frozen), boosted by mp_alpha, in place of the
        # trained weights
        for w, w_t in zip(net_freeze.parameters(), net.parameters()):
            w_t.data = (w_t.data - w.data) * self.args.mp_alpha

        return net.state_dict()
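The scaling step at the end of CompromisedUpdate.train boosts the local update before upload; a worked toy example on a single parameter, assuming the default mp_alpha = 10.0 from utils/options.py:

import torch

mp_alpha = 10.0
w_frozen = torch.tensor([1.0, 2.0])   # global weights before local training
w_trained = torch.tensor([1.1, 1.8])  # weights after local training

# the uploaded parameter is the boosted update, not the trained weights
w_uploaded = (w_trained - w_frozen) * mp_alpha
print(w_uploaded)  # tensor([ 1.0000, -2.0000])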
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
import argparse

def args_parser():
    parser = argparse.ArgumentParser()

    # federated learning arguments
    parser.add_argument('--method', type=str, default='krum', help="aggregation method")
    parser.add_argument('--global_ep', type=int, default=200, help="total number of communication rounds")
    parser.add_argument('--alpha', default=10.0, type=float, help="Dirichlet distribution fraction alpha")
    parser.add_argument('--num_clients', type=int, default=10, help="number of clients: K")
    parser.add_argument('--num_data', type=int, default=100, help="number of data per client for label skew")
    parser.add_argument('--quantity_skew', action='store_true', help='quantity skew')
    parser.add_argument('--num_pretrain', type=int, default=50, help="number of data for pretraining")
    parser.add_argument('--frac', type=float, default=1.0, help="fraction of clients: C")
    parser.add_argument('--ratio', type=float, default=1.0, help="ratio of datasize")
    parser.add_argument('--local_ep', type=int, default=5, help="number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=20, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=20, help="test batch size")
    parser.add_argument('--ds', type=int, default=20, help="dummy batch size")
    parser.add_argument('--lr', type=float, default=0.001, help="client learning rate")

    # other arguments
    parser.add_argument('--dataset', type=str, default='bloodmnist', help="name of dataset")
    parser.add_argument('--model', type=str, default='resnet', help='model name')
    parser.add_argument('--sampling', type=str, default='noniid', help="sampling method")
    parser.add_argument('--num_classes', type=int, default=8, help="number of classes")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--seed', type=int, default=3, help='random seed (default: 3)')
    parser.add_argument('--tsboard', action='store_true', help='tensorboard')
    parser.add_argument('--debug', action='store_true', help='debug')
    parser.add_argument('--earlystop', action='store_true', help='early stopping option')
    parser.add_argument('--patience', type=int, default=8, help="hyperparameter of early stopping")
    parser.add_argument('--delta', type=float, default=0.01, help="hyperparameter of early stopping")

    # poisoning arguments
    parser.add_argument('--c_frac', default=0.0, type=float, help="fraction of compromised clients")
    parser.add_argument('--mp_alpha', type=float, default=10.0, help="hyperparameter for targeted model attack")
    parser.add_argument('--p', type=str, default='normal', help="model poisoning attack (target, untarget) or data poisoning")
    parser.add_argument('--mp_lambda', type=float, default=10.0, help="hyperparameter for untargeted model attack")

    args = parser.parse_args()

    return args
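args_parser() reads sys.argv directly; a sketch of driving it programmatically, e.g. from a notebook or test harness (the flag values are illustrative):

import sys

from utils.options import args_parser

# override the command line before parsing
sys.argv = ['main.py', '--method', 'dca', '--c_frac', '0.3', '--quantity_skew']
args = args_parser()
print(args.method, args.c_frac, args.alpha)  # dca 0.3 10.0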
--------------------------------------------------------------------------------
/utils/byzantine_fl.py:
--------------------------------------------------------------------------------
import copy

import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet18

from utils.test import test_img
from src.aggregation import fedavg

def euclid(v1, v2):
    # squared Euclidean distance between two flattened models
    diff = v1 - v2
    return torch.matmul(diff, diff.T)

def multi_vectorization(w_locals, args):
    vectors = copy.deepcopy(w_locals)

    for i, v in enumerate(vectors):
        for name in v:
            v[name] = v[name].reshape([-1]).to(args.device)
        vectors[i] = torch.cat(list(v.values()))

    return vectors

def single_vectorization(w_glob, args):
    vector = copy.deepcopy(w_glob)
    for name in vector:
        vector[name] = vector[name].reshape([-1]).to(args.device)

    return torch.cat(list(vector.values()))

def pairwise_distance(w_locals, args):

    vectors = multi_vectorization(w_locals, args)
    distance = torch.zeros([len(vectors), len(vectors)]).to(args.device)

    for i, v_i in enumerate(vectors):
        for j, v_j in enumerate(vectors[i:]):
            distance[i][j + i] = distance[j + i][i] = euclid(v_i, v_j)

    return distance

def krum(w_locals, c, args):
    n = len(w_locals) - c

    # pick the client whose summed distance to the others is smallest
    distance = pairwise_distance(w_locals, args)
    sorted_idx = distance.sum(dim=0).argsort()[: n]

    chosen_idx = int(sorted_idx[0])

    return copy.deepcopy(w_locals[chosen_idx]), chosen_idx

def trimmed_mean(w_locals, c, args):
    n = len(w_locals) - 2 * c

    distance = pairwise_distance(w_locals, args)

    # keep the n clients closest to the median summed distance
    distance = distance.sum(dim=1)
    med = distance.median()
    _, chosen = torch.sort(abs(distance - med))
    chosen = chosen[: n]

    return fedavg([copy.deepcopy(w_locals[int(i)]) for i in chosen])

def fang(w_locals, dataset_val, c, args):

    # score each client by its marginal impact on the validation loss
    loss_impact = {}
    net_a = resnet18(num_classes=args.num_classes)
    net_b = copy.deepcopy(net_a)

    for i in range(len(w_locals)):
        tmp_w_locals = copy.deepcopy(w_locals)
        w_a = trimmed_mean(tmp_w_locals, c, args)
        tmp_w_locals.pop(i)
        w_b = trimmed_mean(tmp_w_locals, c, args)

        net_a.load_state_dict(w_a)
        net_b.load_state_dict(w_b)

        _, loss_a = test_img(net_a.to(args.device), dataset_val, args)
        _, loss_b = test_img(net_b.to(args.device), dataset_val, args)

        loss_impact.update({i: loss_a - loss_b})

    sorted_loss_impact = sorted(loss_impact.items(), key=lambda item: item[1])
    filtered_clients = [sorted_loss_impact[i][0] for i in range(len(w_locals) - c)]

    return fedavg([copy.deepcopy(w_locals[i]) for i in filtered_clients])

def triplet_distance(w_locals, global_net, args):

    score = torch.zeros([args.num_clients, args.num_clients]).to(args.device)
    dummy_data = torch.empty(args.ds, 3, 28, 28).uniform_(0, 1).to(args.device)
    net1 = resnet18(num_classes=args.num_classes).to(args.device)
    net2 = copy.deepcopy(net1).to(args.device)

    # penultimate-layer features of the global model on the dummy batch
    anchor = nn.Sequential(*list(global_net.children())[:-1])(dummy_data).squeeze()

    for i, w_i in enumerate(w_locals):
        net1.load_state_dict(w_i)
        pro1 = nn.Sequential(*list(net1.children())[:-1])(dummy_data).squeeze()
        for j, w_j in enumerate(w_locals[i:]):
            net2.load_state_dict(w_j)
            pro2 = nn.Sequential(*list(net2.children())[:-1])(dummy_data).squeeze()

            score[i][j + i] = score[j + i][i] = F.binary_cross_entropy_with_logits(pro1, anchor) + F.binary_cross_entropy_with_logits(pro2, anchor)

    return score

def dummy_contrastive_aggregation(w_locals, c, global_net, args):
    n = len(w_locals) - c

    score = triplet_distance(copy.deepcopy(w_locals), global_net, args)

    sorted_idx = score.sum(dim=0).argsort()[: n]

    return fedavg([copy.deepcopy(w_locals[int(i)]) for i in sorted_idx])
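A sketch exercising krum and trimmed_mean on toy one-parameter models with an obvious outlier; the args stub is an assumption covering only the fields these helpers read:

from types import SimpleNamespace

import torch

from utils.byzantine_fl import krum, trimmed_mean

args = SimpleNamespace(device='cpu')

w_locals = [{'w': torch.tensor([1.0, 1.0])},
            {'w': torch.tensor([1.1, 0.9])},
            {'w': torch.tensor([0.9, 1.1])},
            {'w': torch.tensor([50.0, -50.0])}]  # Byzantine outlier

w_krum, chosen = krum(w_locals, 1, args)
print(chosen)     # index of a benign client (0, 1, or 2)

w_tm = trimmed_mean(w_locals, 1, args)
print(w_tm['w'])  # average over the clients nearest the median distance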
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings('ignore')

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18

import copy
import numpy as np
import random
from tqdm import trange

from utils.options import args_parser
from utils.sampling import noniid
from utils.dataset import load_data
from utils.test import test_img
from utils.byzantine_fl import krum, trimmed_mean, fang, dummy_contrastive_aggregation
from utils.attack import compromised_clients, untargeted_attack
from src.aggregation import fedavg
from src.update import BenignUpdate, CompromisedUpdate

if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    if args.tsboard:
        writer = SummaryWriter('runs/data')

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    dataset_train, dataset_test, dataset_val = load_data(args)

    # early stopping hyperparameters
    cnt = 0
    check_acc = 0

    # sample users
    dict_users = noniid(dataset_train, args)
    net_glob = resnet18(num_classes=args.num_classes).to(args.device)

    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    if args.c_frac > 0:
        compromised_idxs = compromised_clients(args)
    else:
        compromised_idxs = []

    for iter in trange(args.global_ep):
        w_locals = []
        selected_clients = max(int(args.frac * args.num_clients), 1)
        compromised_num = int(args.c_frac * selected_clients)
        idxs_users = np.random.choice(range(args.num_clients), selected_clients, replace=False)

        for idx in idxs_users:
            if idx in compromised_idxs:
                if args.p == "untarget":
                    w_locals.append(copy.deepcopy(untargeted_attack(net_glob.state_dict(), args)))
                else:
                    local = CompromisedUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
                    w = local.train(net=copy.deepcopy(net_glob).to(args.device))
                    w_locals.append(copy.deepcopy(w))
            else:
                local = BenignUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
                w = local.train(net=copy.deepcopy(net_glob).to(args.device))
                w_locals.append(copy.deepcopy(w))

        # update global weights
        if args.method == 'fedavg':
            w_glob = fedavg(w_locals)
        elif args.method == 'krum':
            w_glob, _ = krum(w_locals, compromised_num, args)
        elif args.method == 'trimmed_mean':
            w_glob = trimmed_mean(w_locals, compromised_num, args)
        elif args.method == 'fang':
            w_glob = fang(w_locals, dataset_val, compromised_num, args)
        elif args.method == 'dca':
            w_glob = dummy_contrastive_aggregation(w_locals, compromised_num, copy.deepcopy(net_glob), args)
        else:
            exit('Error: unrecognized aggregation technique')

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        test_acc, test_loss = test_img(net_glob.to(args.device), dataset_test, args)

        if args.debug:
            print(f"Round: {iter}")
            print(f"Test accuracy: {test_acc}")
            print(f"Test loss: {test_loss}")
            print(f"Check accuracy: {check_acc}")
            print(f"patience: {cnt}")

        if check_acc == 0:
            check_acc = test_acc
        elif test_acc < check_acc + args.delta:
            cnt += 1
        else:
            check_acc = test_acc
            cnt = 0

        # early stopping
        if cnt == args.patience:
            print('Early stopped federated training!')
            break

        # tensorboard
        if args.tsboard:
            writer.add_scalar(f'testacc/{args.method}_{args.p}_cfrac_{args.c_frac}_alpha_{args.alpha}', test_acc, iter)
            writer.add_scalar(f'testloss/{args.method}_{args.p}_cfrac_{args.c_frac}_alpha_{args.alpha}', test_loss, iter)

    if args.tsboard:
        writer.close()
--------------------------------------------------------------------------------
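For reference, a sketch of a single end-to-end run of the proposed dummy-contrastive aggregation (method dca above), in the style of the execute/*.sh scripts; the flag values are illustrative and follow utils/options.py:

python main.py --gpu 0 --method dca --tsboard --c_frac 0.3 --p untarget --quantity_skew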