├── server ├── __init__.py ├── agg_funs.py └── base.py ├── attacker ├── __init__.py ├── lie.py └── base.py ├── utils ├── __init__.py └── utils.py ├── client.py ├── requirements.txt ├── main.py ├── README.md └── environment.yml /server/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import RobustServer 2 | __all__ = [ 3 | 'RobustServer', 4 | ] -------------------------------------------------------------------------------- /attacker/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Attacker 2 | from .lie import Lie 3 | 4 | __all__ = [ 5 | 'Attacker', 6 | 'Lie' 7 | ] -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | flatten_models, 3 | unflatten_tensor, 4 | Register 5 | ) 6 | 7 | __all__ = [ 8 | 'flatten_models', 9 | 'unflatten_tensor', 10 | 'Register', 11 | ] -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | from easyfl.client import BaseClient 2 | 3 | class CustomizedClient(BaseClient): 4 | def __init__(self, cid, conf, train_data, test_data, device, **kwargs): 5 | super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs) 6 | self.is_byz = False 7 | 8 | def set_byz(self, is_byz: bool=True): 9 | self.is_byz = is_byz 10 | pass 11 | 12 | def post_train(self): 13 | self.model.cpu() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools==4.2.4 2 | certifi==2023.5.7 3 | chardet==3.0.4 4 | easyfl==0.1.2 5 | etcd3==0.12.0 6 | google-auth==1.21.2 7 | grpcio==1.54.2 8 | grpcio-tools==1.54.2 9 | idna==2.10 10 | kubernetes==11.0.0 11 | numpy==1.24.3 12 | oauthlib==3.2.2 13 | omegaconf==2.0.0 14 | Pillow==9.5.0 15 | protobuf==3.20.3 16 | pyasn1==0.5.0 17 | pyasn1-modules==0.3.0 18 | python-dateutil==2.8.2 19 | PyYAML==6.0 20 | requests==2.24.0 21 | requests-oauthlib==1.3.1 22 | rsa==4.9 23 | six==1.16.0 24 | tenacity==8.2.2 25 | torch==1.9.1 26 | torchvision==0.10.1 27 | tqdm==4.32.1 28 | typing_extensions==4.6.3 29 | urllib3==1.25.11 30 | websocket-client==1.5.3 31 | -------------------------------------------------------------------------------- /attacker/lie.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from utils import flatten_models, unflatten_tensor 5 | 6 | from .base import Attacker 7 | 8 | class Lie(Attacker): 9 | def attack(self, sampled_clients: list, server): 10 | ref_models = self.get_ref_models(sampled_clients) 11 | flat_models, struct = flatten_models(ref_models) 12 | 13 | mu = flat_models.mean(dim=0) 14 | sigma = flat_models.var(dim=0, unbiased=False) 15 | flat_byz_model = mu - self.conf.attacker.lie_z * sigma 16 | 17 | byz_state_dict = unflatten_tensor(flat_byz_model, struct) 18 | 19 | self.set_byz_uploaded_content(sampled_clients, byz_state_dict, server) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import easyfl 2 | from server import RobustServer 3 | from client import CustomizedClient 4 | 5 | # Customized configuration. 6 | config = { 7 | "attacker": {"byz_ratio": 0.2, "lie_z": 1.5}, 8 | "data": {"dataset": "cifar10", "root": "./datasets", "split_type": "dir", "num_of_clients": 100}, 9 | "server": {"rounds": 10, "clients_per_round": 10, "use_gas": True, "gas_p": 1000, "base_agg": "bulyan"}, 10 | "client": {"local_epoch": 1}, 11 | "model": "resnet18", 12 | "test_mode": "test_in_server", 13 | "gpu": 1, 14 | } 15 | 16 | easyfl.register_server(RobustServer) 17 | easyfl.register_client(CustomizedClient) 18 | 19 | # Initialize federated learning with default configurations. 20 | easyfl.init(config) 21 | # Execute federated learning training. 22 | easyfl.run() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | This repository provides code of ICML'23 paper 4 | 5 | [Byzantine-Robust Learning on Heterogeneous Data via Gradient Splitting]: https://arxiv.org/abs/2302.06079 6 | 7 | ## Set up 8 | 9 | Install conda environment 10 | 11 | ```sh 12 | conda env create -f environment.yml 13 | conda activate gas 14 | ``` 15 | 16 | ## Running code 17 | 18 | ```sh 19 | python main.py 20 | ``` 21 | 22 | change `config` in `main.py` to test GAS in different settings. 23 | 24 | ## Reference 25 | 26 | ```tex 27 | @misc{liu2023byzantinerobust, 28 | title={Byzantine-Robust Learning on Heterogeneous Data via Gradient Splitting}, 29 | author={Yuchen Liu and Chen Chen and Lingjuan Lyu and Fangzhao Wu and Sai Wu and Gang Chen}, 30 | year={2023}, 31 | eprint={2302.06079}, 32 | archivePrefix={arXiv}, 33 | primaryClass={cs.LG} 34 | } 35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /attacker/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from client import CustomizedClient 4 | from easyfl.server.base import MODEL 5 | from typing import ( 6 | Tuple, 7 | ) 8 | 9 | class Attacker: 10 | def __init__(self, conf, byz_clients: list[CustomizedClient]) -> None: 11 | self.conf = conf 12 | self.byz_clients = byz_clients 13 | 14 | def attack(self, sampled_clients: list, server): 15 | raise Exception('instantiate attack') 16 | 17 | def get_ref_models(self, sampled_clients: list[CustomizedClient]): 18 | ref_models = [sampled_client.model for sampled_client in sampled_clients] 19 | return ref_models 20 | 21 | def set_byz_uploaded_content(self, sampled_clients, byz_state_dict, server): 22 | sampled_byz_clients = self.get_sampled_byz_clients(sampled_clients) 23 | for sampled_byz_client in sampled_byz_clients: 24 | server._client_uploads[MODEL][sampled_byz_client.cid].load_state_dict(byz_state_dict) 25 | 26 | def get_sampled_byz_clients(self, sampled_clients): 27 | sampled_byz_clients = list(set(sampled_clients).intersection(self.byz_clients)) 28 | return sampled_byz_clients -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gas 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.05.30=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.3=he6710b0_2 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1t=h7f8727e_0 15 | - pip=23.0.1=py38h06a4308_0 16 | - python=3.8.3=hcff3b4d_2 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=67.8.0=py38h06a4308_0 19 | - sqlite=3.38.2=hc218d9a_0 20 | - tk=8.6.11=h1ccaba5_0 21 | - wheel=0.38.4=py38h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.11=0 24 | - pip: 25 | - cachetools==4.2.4 26 | - certifi==2023.5.7 27 | - chardet==3.0.4 28 | - easyfl==0.1.2 29 | - etcd3==0.12.0 30 | - google-auth==1.21.2 31 | - grpcio==1.54.2 32 | - grpcio-tools==1.54.2 33 | - idna==2.10 34 | - kubernetes==11.0.0 35 | - numpy==1.24.3 36 | - oauthlib==3.2.2 37 | - omegaconf==2.0.0 38 | - pillow==9.5.0 39 | - protobuf==3.20.3 40 | - pyasn1==0.5.0 41 | - pyasn1-modules==0.3.0 42 | - python-dateutil==2.8.2 43 | - pyyaml==6.0 44 | - requests==2.24.0 45 | - requests-oauthlib==1.3.1 46 | - rsa==4.9 47 | - six==1.16.0 48 | - tenacity==8.2.2 49 | - torch==1.9.1+cu111 50 | - torchvision==0.10.1+cu111 51 | - tqdm==4.32.1 52 | - typing-extensions==4.6.3 53 | - urllib3==1.25.11 54 | - websocket-client==1.5.3 55 | prefix: /home/lyc/anaconda3/envs/gas 56 | 57 | -------------------------------------------------------------------------------- /server/agg_funs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import Register 3 | 4 | agg_funs = Register() 5 | 6 | @agg_funs.register('median') 7 | @torch.no_grad() 8 | def agg_median(tensor: torch.Tensor, knowledge, conf): 9 | return tensor.median(dim=0)[0] 10 | 11 | @agg_funs.register('rbtm') 12 | @torch.no_grad() 13 | def agg_rbrm(tensor: torch.Tensor, knowledge, conf): 14 | n_cl = len(tensor) 15 | n_byz = knowledge.n_byz 16 | n_byz = min(n_byz, (n_cl - 1) // 2) 17 | if n_byz == 0: 18 | agg_tensor = tensor.mean(dim=0) 19 | else: 20 | sorted_tensor, _ = tensor.sort(dim=0) 21 | agg_tensor = sorted_tensor[n_byz:-n_byz].mean(dim=0) 22 | return agg_tensor 23 | 24 | 25 | @agg_funs.register('krum') 26 | @torch.no_grad() 27 | def agg_krum(tensor: torch.Tensor, knowledge, conf): 28 | n_cl = len(tensor) 29 | n_byz = min(knowledge.n_byz, (n_cl - 3) // 2) 30 | 31 | squared_dists = torch.cdist(tensor, tensor).square() 32 | topk_dists, _ = squared_dists.topk(k=n_cl - n_byz -1, dim=-1, largest=False, sorted=False) 33 | scores = topk_dists.sum(dim=-1) 34 | _, candidate_idxs = scores.topk(k=n_cl - n_byz, dim=-1, largest=False) 35 | 36 | agg_tensor = tensor[candidate_idxs].mean(dim=0) 37 | 38 | return agg_tensor 39 | 40 | 41 | @agg_funs.register('bulyan') 42 | @torch.no_grad() 43 | def agg_bulyan(tensor: torch.Tensor, knowledge, conf): 44 | n_cl = len(tensor) 45 | n_byz = min(knowledge.n_byz, (n_cl - 3) // 4) 46 | 47 | squared_dists = torch.cdist(tensor, tensor).square() 48 | topk_dists, _ = squared_dists.topk(k=n_cl - n_byz -1, dim=-1, largest=False, sorted=False) 49 | scores = topk_dists.sum(dim=-1) 50 | _, krum_candidate_idxs = scores.topk(k=n_cl - 2 * n_byz, dim=-1, largest=False) 51 | 52 | krum_tensor = tensor[krum_candidate_idxs] 53 | med, _ = krum_tensor.median(dim=0) 54 | dist = (krum_tensor - med).abs() 55 | tr_updates, _ = dist.topk(k=n_cl - 4 * n_byz, dim=0, largest=False) 56 | agg_tensor = tr_updates.mean(dim=0) 57 | 58 | return agg_tensor -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | class Register(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(Register, self).__init__(*args, **kwargs) 8 | self._dict = {} 9 | 10 | def register(self, target): 11 | def add_item(key, value): 12 | if not callable(value): 13 | raise Exception(f"Error:{value} must be callable!") 14 | if key in self._dict: 15 | print(f"\033[31mWarning:\033[0m {value.__name__} already exists and will be overwritten!") 16 | self[key] = value 17 | return value 18 | 19 | if callable(target): 20 | return add_item(target.__name__, target) 21 | else: 22 | return lambda x : add_item(target, x) 23 | 24 | def __setitem__(self, key, value): 25 | self._dict[key] = value 26 | 27 | def __getitem__(self, key): 28 | return self._dict[key] 29 | 30 | def __contains__(self, key): 31 | return key in self._dict 32 | 33 | def __str__(self): 34 | return str(self._dict) 35 | 36 | def keys(self): 37 | return self._dict.keys() 38 | 39 | def values(self): 40 | return self._dict.values() 41 | 42 | def items(self): 43 | return self._dict.items() 44 | 45 | def flatten_models(models: list[torch.nn.Module]): 46 | flat_model_lt = [] 47 | name_shape_tuples = None 48 | 49 | for model in models: 50 | flat_state_dict, name_shape_tuples = flatten_model(model) 51 | flat_model_lt.append(flat_state_dict) 52 | 53 | flat_models = torch.stack(flat_model_lt) 54 | struct = { 55 | 'name_shape_tuples': name_shape_tuples 56 | } 57 | 58 | return flat_models, struct 59 | 60 | def flatten_model(model: torch.nn.Module): 61 | flat_params = [] 62 | name_shape_tuples = [] 63 | for name, param in model.state_dict().items(): 64 | flat_param = param.view(-1) 65 | flat_params.append(flat_param) 66 | name_shape_tuples.append((name, param.shape)) 67 | flat_model = torch.cat(flat_params) 68 | return flat_model, name_shape_tuples 69 | 70 | def unflatten_tensor(flat_tensor, struct): 71 | name_shape_tuples = struct['name_shape_tuples'] 72 | split_size = [t[1].numel() for t in name_shape_tuples] 73 | split_tensors = torch.split(flat_tensor, split_size) 74 | state_dict = {name: split_tensor.view(shape) for (name, shape), split_tensor in zip(name_shape_tuples, split_tensors)} 75 | return state_dict -------------------------------------------------------------------------------- /server/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | import time 5 | import torch 6 | 7 | from argparse import Namespace 8 | from easyfl.server import BaseServer 9 | from easyfl.tracking import metric 10 | from easyfl.utils.float import rounding 11 | from math import ceil 12 | from omegaconf import OmegaConf 13 | 14 | from attacker import Lie 15 | from client import CustomizedClient 16 | from utils import flatten_models, unflatten_tensor 17 | from .agg_funs import agg_funs 18 | 19 | class RobustServer(BaseServer): 20 | def start(self, model, clients): 21 | """Start federated learning process, including training and testing. 22 | 23 | Args: 24 | model (nn.Module): The model to train. 25 | clients (list[:obj:`BaseClient`]|list[str]): Available clients. 26 | Clients are actually client grpc addresses when in remote training. 27 | """ 28 | # Setup 29 | self._start_time = time.time() 30 | self._reset() 31 | self.set_model(model) 32 | self.set_clients(clients) 33 | # customize 34 | byz_clients = self.get_byz_clients(clients) 35 | self.attacker = Lie(self.conf, byz_clients) 36 | 37 | if self._should_track(): 38 | self._tracker.create_task(self.conf.task_id, OmegaConf.to_container(self.conf)) 39 | 40 | # Get initial testing accuracies 41 | if self.conf.server.test_all: 42 | if self._should_track(): 43 | self._tracker.set_round(self._current_round) 44 | self.test() 45 | self.save_tracker() 46 | 47 | while not self.should_stop(): 48 | self._round_time = time.time() 49 | 50 | self._current_round += 1 51 | self.print_("\n-------- round {} --------".format(self._current_round)) 52 | 53 | # Train 54 | self.pre_train() 55 | self.train() 56 | self.post_train() 57 | 58 | # Test 59 | if self._do_every(self.conf.server.test_every, self._current_round, self.conf.server.rounds): 60 | self.pre_test() 61 | self.test() 62 | self.post_test() 63 | 64 | # Save Model 65 | self.save_model() 66 | 67 | self.track(metric.ROUND_TIME, time.time() - self._round_time) 68 | self.save_tracker() 69 | 70 | self.print_("Accuracies: {}".format(rounding(self._accuracies, 4))) 71 | self.print_("Cumulative training time: {}".format(rounding(self._cumulative_times, 2))) 72 | 73 | def train(self): 74 | """Training process of federated learning.""" 75 | self.print_("--- start training ---") 76 | 77 | self.selection(self._clients, self.conf.server.clients_per_round) 78 | self.grouping_for_distributed() 79 | self.compression() 80 | 81 | begin_train_time = time.time() 82 | self.distribution_to_train() 83 | end_train_time = time.time() 84 | train_time = end_train_time - begin_train_time 85 | self.print_("Honest client train time: {}".format(train_time)) 86 | 87 | start_attack_time = time.time() 88 | self.attacker.attack(self.selected_clients, self) 89 | end_attack_time = time.time() 90 | attack_time = end_attack_time - start_attack_time 91 | self.print_("Byzantine client attack time: {}".format(attack_time)) 92 | 93 | start_aggretate_time = time.time() 94 | self.aggregation() 95 | end_aggretate_time = time.time() 96 | aggregate_time = end_aggretate_time - start_aggretate_time 97 | self.print_("Aggregate time: {}".format(aggregate_time)) 98 | 99 | self.track(metric.TRAIN_TIME, train_time) 100 | ### track other times 101 | 102 | def get_num_byz_clients(self): 103 | num_clients = self.conf.data.num_of_clients 104 | byz_ratio: float = self.conf.attacker.byz_ratio 105 | num_byz = ceil(num_clients * byz_ratio) 106 | assert 0 <= num_byz <= num_clients, f"invalid byz_ratio {byz_ratio}" 107 | return num_byz 108 | 109 | def get_byz_clients(self, clients: list[CustomizedClient]): 110 | num_byz = self.get_num_byz_clients() 111 | 112 | byz_clients = clients[:num_byz] 113 | for client in byz_clients: 114 | client.set_byz() 115 | 116 | return byz_clients 117 | 118 | def set_attacker(self, attacker): 119 | self.attacker = attacker 120 | 121 | def aggregate(self, models, weights): 122 | if self.conf.server.use_gas: 123 | agg_model = self.gas_aggregate(models, weights) 124 | else: 125 | flat_models, struct = flatten_models(models) 126 | 127 | base_agg = agg_funs[self.conf.server.base_agg] 128 | n_sel_byz = sum((1 if selected_client.is_byz else 0) for selected_client in self.selected_clients) 129 | knowledge = Namespace(n_byz=n_sel_byz) 130 | flat_agg_model = base_agg(flat_models, knowledge, self.conf) 131 | 132 | agg_state_dict = unflatten_tensor(flat_agg_model, struct) 133 | agg_model = copy.deepcopy(models[0]) 134 | agg_model.load_state_dict(agg_state_dict) 135 | return agg_model 136 | 137 | @torch.no_grad() 138 | def gas_aggregate(self, models, weights): 139 | # flatten 140 | flat_models, struct = flatten_models(models) 141 | # splitting 142 | groups = self.split(flat_models) 143 | # identification 144 | base_agg = agg_funs[self.conf.server.base_agg] 145 | n_cl = len(flat_models) 146 | n_sel_byz = sum((1 if selected_client.is_byz else 0) for selected_client in self.selected_clients) 147 | knowledge = Namespace(n_byz=n_sel_byz) 148 | identification_scores = torch.zeros(n_cl) 149 | for group in groups: 150 | group_agg = base_agg(group, knowledge, self.conf) 151 | group_scores = (group - group_agg).square().sum(dim=-1).sqrt().cpu() 152 | identification_scores += group_scores 153 | _, cand_idxs = identification_scores.topk(k=n_cl - n_sel_byz, largest=False) 154 | n_agg_byz = sum([self.selected_clients[i].is_byz for i in cand_idxs.tolist()]) 155 | self.print_(f"Aggregated byzantine / selected byzantine: {n_agg_byz} / {n_sel_byz}") 156 | # aggregation 157 | flat_agg_model = flat_models[cand_idxs].mean(dim=0) 158 | # unflatten 159 | agg_state_dict = unflatten_tensor(flat_agg_model, struct) 160 | agg_model = copy.deepcopy(models[0]) 161 | agg_model.load_state_dict(agg_state_dict) 162 | 163 | return agg_model 164 | 165 | 166 | @torch.no_grad() 167 | def split(self, flat_models): 168 | d = flat_models.shape[1] 169 | shuffled_dims = torch.randperm(d).to(flat_models.device) 170 | p = self.conf.server.gas_p 171 | partition = torch.chunk(shuffled_dims, chunks=p) 172 | groups = [flat_models[:, partition_i] for partition_i in partition] 173 | return groups --------------------------------------------------------------------------------