├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── constants.py
│       ├── dataset.py
│       ├── partition
│       │   ├── __init__.py
│       │   ├── assign_classes.py
│       │   └── dirichlet.py
│       ├── run.py
│       └── util.py
├── requirements.txt
└── src
    ├── client
    │   ├── __init__.py
    │   ├── base.py
    │   ├── fedavg.py
    │   ├── fedprox.py
    │   └── scaffold.py
    ├── config
    │   ├── __init__.py
    │   ├── models.py
    │   └── util.py
    └── server
        ├── __init__.py
        ├── base.py
        ├── fedavg.py
        ├── fedprox.py
        └── scaffold.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__
 2 | .vscode
 3 | .idea
 4 | temp
 5 | logs
 6 | mnist
 7 | emnist
 8 | fmnist
 9 | cifar10
10 | cifar100
11 | synthetic
12 | *.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright © 2022 Jiahao Tan
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SCAFFOLD: Stochastic Controlled Averaging for Federated Learning [[ArXiv]](https://arxiv.org/abs/1910.06378)
 2 | 
 3 | This repo is the PyTorch implementation of SCAFFOLD.
 4 | 
 5 | I have also implemented FedAvg and FedProx for you. 🤗
 6 | 
 7 | To simulate non-I.I.D. scenarios, the dataset can be split according to a Dirichlet distribution or by randomly assigning classes to each client.
 8 | 
 9 | Note that I have recently released a [benchmark of federated learning](https://github.com/KarhouTam/FL-bench) that includes this method and many other baselines. Welcome to check out my benchmark and star it! 🤗
10 | 
11 | ## Preprocess dataset
12 | 
13 | MNIST, EMNIST, FashionMNIST, CIFAR10, and CIFAR100 are supported.
14 | 
15 | ```shell
16 | python ./data/utils/run.py --dataset ${dataset}
17 | ```
18 | The preprocessing is adjustable; check `./data/utils/run.py` for more details on the arguments.
19 | ## Run the experiment
20 | 
21 | ❗ Before running the experiment, please make sure that the dataset has already been downloaded and preprocessed.
22 | 
23 | It’s so simple. 🤪
24 | 
25 | ```shell
26 | python ./src/server/${algo}.py
27 | ```
28 | 
29 | You can check `./src/config/util.py` for details on all hyperparameters.
30 | 
31 | 
32 | ## Result
33 | 
34 | ❗NOTE: The dataset settings, hyperparameters, and model backbone in this repo are not the same as in the SCAFFOLD paper, so the results below are not comparable to the numbers reported in the paper.
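For reference, the runs summarized below can be reproduced with a command sequence along the following lines (a sketch, not a tested script; the flags are those defined in `./data/utils/run.py` and `./src/config/util.py`, and you may additionally need `--root` to point `run.py` at a writable directory for the raw datasets):

```shell
# partition EMNIST across 10 clients with a Dirichlet(0.5) split
python ./data/utils/run.py --dataset emnist --client_num_in_total 10 --alpha 0.5

# train each algorithm with the hyperparameters listed below
python ./src/server/fedavg.py --dataset emnist --global_epochs 100 --local_epochs 10 --client_num_per_round 2 --local_lr 1e-2 --seed 17
python ./src/server/fedprox.py --dataset emnist --global_epochs 100 --local_epochs 10 --client_num_per_round 2 --local_lr 1e-2 --seed 17
python ./src/server/scaffold.py --dataset emnist --global_epochs 100 --local_epochs 10 --client_num_per_round 2 --local_lr 1e-2 --seed 17
```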
35 | 
36 | This repo is just for demonstrating how SCAFFOLD works.
37 | 
38 | If you find something wrong in any algorithm's implementation in this repo, just let me know. 🤗
39 | 
40 | Some stats about convergence speed are shown below.
41 | 
42 | `--dataset`: `emnist`, split by Dirichlet(0.5)
43 | 
44 | `--global_epochs`: `100`
45 | 
46 | `--local_epochs`: `10`
47 | 
48 | `--client_num_in_total`: `10`
49 | 
50 | `--client_num_per_round`: `2`
51 | 
52 | `--local_lr`: `1e-2`
53 | 
54 | `--seed`: `17`
55 | 
56 | 
57 | | Algo     | Epoch to 50% Acc | Epoch to 60% Acc | Epoch to 70% Acc | Epoch to 80% Acc | Test Acc |
58 | | -------- | ---------------- | ---------------- | ---------------- | ---------------- | -------- |
59 | | FedAvg   | 6                | 16               | 30               | 56               | 70.00%   |
60 | | FedProx  | 12               | 14               | 30               | 56               | 66.72%   |
61 | | SCAFFOLD | 6                | 15               | 27               | -                | 53.93%   |
62 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarhouTam/SCAFFOLD-PyTorch/69791f03d4d0951e5f3e7cf181d32a750ff9d1d0/data/__init__.py
--------------------------------------------------------------------------------
/data/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KarhouTam/SCAFFOLD-PyTorch/69791f03d4d0951e5f3e7cf181d32a750ff9d1d0/data/utils/__init__.py
--------------------------------------------------------------------------------
/data/utils/constants.py:
--------------------------------------------------------------------------------
 1 | MEAN = {
 2 |     "mnist": 0.1307,
 3 |     "cifar10": (0.4914, 0.4822, 0.4465),
 4 |     "cifar100": (0.5071, 0.4865, 0.4409),
 5 |     "femnist": 0,  # dummy value, not actually used
 6 |     "synthetic": 0,  # dummy value, not actually used
 7 |     "emnist": 0.1736,
 8 |     "fmnist": 0.2860,
 9 | }
10 | 
11 | STD = {
12 |     "mnist": 0.3015,
13 |     "cifar10": (0.2023, 0.1994, 0.2010),
14 |     "cifar100": (0.2009, 0.1984, 0.2023),
15 |     "femnist": 1.0,  # dummy value, not actually used
16 |     "synthetic": 1.0,  # dummy value, not actually used
17 |     "emnist": 0.3248,
18 |     "fmnist": 0.3205,
19 | }
20 | 
--------------------------------------------------------------------------------
/data/utils/dataset.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from torch.utils.data import Dataset
 4 | from torchvision.transforms import ToTensor
 5 | 
 6 | 
 7 | class MNISTDataset(Dataset):
 8 |     def __init__(self, data, targets, transform=None, target_transform=None,) -> None:
 9 |         self.transform = transform
10 |         self.target_transform = target_transform
11 |         _data = data
12 |         _targets = targets
13 |         if not isinstance(_data, torch.Tensor):
14 |             if not isinstance(_data, np.ndarray):
15 |                 _data = ToTensor()(_data)
16 |             else:
17 |                 _data = torch.tensor(_data)
18 |         self.data = _data.float().unsqueeze(1)
19 | 
20 |         if not isinstance(_targets, torch.Tensor):
21 |             _targets = torch.tensor(_targets)
22 |         self.targets = _targets.long()
23 | 
24 |     def __getitem__(self, index):
25 |         data, targets = self.data[index], self.targets[index]
26 | 
27 |         if self.transform is not None:
28 |             data = self.transform(self.data[index])
29 | 
30 |         if self.target_transform is not None:
31 |             targets = self.target_transform(self.targets[index])
32 | 
33 |         return data, targets
34 | 
35 |     def __len__(self):
36 |         return len(self.targets)
37 | 
38 | 
39 | class CIFARDataset(Dataset):
40 |     def __init__(self, data, targets, 
transform=None, target_transform=None,) -> None: 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | _data = data 44 | _targets = targets 45 | if not isinstance(_data, torch.Tensor): 46 | if not isinstance(_data, np.ndarray): 47 | _data = ToTensor()(_data) 48 | else: 49 | _data = torch.tensor(_data) 50 | self.data = torch.permute(_data, [0, -1, 1, 2]).float() 51 | if not isinstance(_targets, torch.Tensor): 52 | _targets = torch.tensor(_targets) 53 | self.targets = _targets.long() 54 | 55 | def __getitem__(self, index): 56 | data, targets = self.data[index], self.targets[index] 57 | 58 | if self.transform is not None: 59 | data = self.transform(self.data[index]) 60 | 61 | if self.target_transform is not None: 62 | targets = self.target_transform(self.targets[index]) 63 | 64 | return data, targets 65 | 66 | def __len__(self): 67 | return len(self.targets) 68 | -------------------------------------------------------------------------------- /data/utils/partition/__init__.py: -------------------------------------------------------------------------------- 1 | from .dirichlet import dirichlet_distribution 2 | from .assign_classes import randomly_assign_classes 3 | 4 | __all__ = ["dirichlet_distribution", "randomly_assign_classes"] 5 | -------------------------------------------------------------------------------- /data/utils/partition/assign_classes.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter 3 | from typing import Dict, List, Tuple 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def sort_and_alloc( 10 | datasets: List[Dataset], num_clients: int, num_classes: int 11 | ) -> Dict[int, np.ndarray]: 12 | total_sample_nums = sum(map(lambda ds: len(ds), datasets)) 13 | num_shards = num_clients * num_classes 14 | # one shard's length indicate how many data samples that belongs to one class that one client can obtain. 15 | size_of_shards = int(total_sample_nums / num_shards) 16 | 17 | dict_users = {i: np.array([], dtype=np.int64) for i in range(num_clients)} 18 | 19 | labels = np.concatenate([ds.targets for ds in datasets], axis=0, dtype=np.int64) 20 | idxs = np.arange(total_sample_nums) 21 | 22 | # sort sample indices according to labels 23 | idxs_labels = np.vstack((idxs, labels)) 24 | # corresponding labels after sorting are [0, .., 0, 1, ..., 1, ...] 
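    # (so each shard cut from the sorted order is almost class-pure, and every client,
    # receiving `num_classes` random shards, ends up with data from roughly that many classes)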
25 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 26 | idxs = idxs_labels[0, :] 27 | 28 | # assign 29 | idx_shard = [i for i in range(num_shards)] 30 | for i in range(num_clients): 31 | rand_set = random.sample(idx_shard, num_classes) 32 | idx_shard = list(set(idx_shard) - set(rand_set)) 33 | for rand in rand_set: 34 | dict_users[i] = np.concatenate( 35 | ( 36 | dict_users[i], 37 | idxs[rand * size_of_shards : (rand + 1) * size_of_shards], 38 | ), 39 | axis=0, 40 | ) 41 | 42 | return dict_users 43 | 44 | 45 | def randomly_assign_classes( 46 | ori_datasets: List[Dataset], 47 | target_dataset: Dataset, 48 | num_clients: int, 49 | num_classes: int, 50 | transform=None, 51 | target_transform=None, 52 | ) -> Tuple[List[Dataset], Dict[str, Dict[str, int]]]: 53 | stats = {} 54 | dict_users = sort_and_alloc(ori_datasets, num_clients, num_classes) 55 | targets_numpy = np.concatenate( 56 | [ds.targets for ds in ori_datasets], axis=0, dtype=np.int64 57 | ) 58 | data_numpy = np.concatenate( 59 | [ds.data for ds in ori_datasets], axis=0, dtype=np.float32 60 | ) 61 | datasets = [] 62 | for i, indices in dict_users.items(): 63 | stats[i] = {"x": None, "y": None} 64 | stats[i]["x"] = len(indices) 65 | stats[i]["y"] = Counter(targets_numpy[indices].tolist()) 66 | datasets.append( 67 | target_dataset( 68 | data=data_numpy[indices], 69 | targets=targets_numpy[indices], 70 | transform=transform, 71 | target_transform=target_transform, 72 | ) 73 | ) 74 | return datasets, stats 75 | -------------------------------------------------------------------------------- /data/utils/partition/dirichlet.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import Dict, List, Tuple 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def dirichlet_distribution( 9 | ori_dataset: List[Dataset], 10 | target_dataset: Dataset, 11 | num_clients: int, 12 | alpha: float, 13 | transform=None, 14 | target_transform=None, 15 | ) -> Tuple[List[Dataset], Dict]: 16 | NUM_CLASS = len(ori_dataset[0].classes) 17 | MIN_SIZE = 0 18 | X = [[] for _ in range(num_clients)] 19 | Y = [[] for _ in range(num_clients)] 20 | stats = {} 21 | targets_numpy = np.concatenate( 22 | [ds.targets for ds in ori_dataset], axis=0, dtype=np.int64 23 | ) 24 | data_numpy = np.concatenate( 25 | [ds.data for ds in ori_dataset], axis=0, dtype=np.float32 26 | ) 27 | idx = [np.where(targets_numpy == i)[0] for i in range(NUM_CLASS)] 28 | 29 | while MIN_SIZE < 10: 30 | idx_batch = [[] for _ in range(num_clients)] 31 | for k in range(NUM_CLASS): 32 | np.random.shuffle(idx[k]) 33 | distributions = np.random.dirichlet(np.repeat(alpha, num_clients)) 34 | distributions = np.array( 35 | [ 36 | p * (len(idx_j) < len(targets_numpy) / num_clients) 37 | for p, idx_j in zip(distributions, idx_batch) 38 | ] 39 | ) 40 | distributions = distributions / distributions.sum() 41 | distributions = (np.cumsum(distributions) * len(idx[k])).astype(int)[:-1] 42 | idx_batch = [ 43 | np.concatenate((idx_j, idx.tolist())).astype(np.int64) 44 | for idx_j, idx in zip(idx_batch, np.split(idx[k], distributions)) 45 | ] 46 | MIN_SIZE = min([len(idx_j) for idx_j in idx_batch]) 47 | 48 | for i in range(num_clients): 49 | stats[i] = {"x": None, "y": None} 50 | np.random.shuffle(idx_batch[i]) 51 | X[i] = data_numpy[idx_batch[i]] 52 | Y[i] = targets_numpy[idx_batch[i]] 53 | stats[i]["x"] = len(X[i]) 54 | stats[i]["y"] = Counter(Y[i].tolist()) 55 | 56 | datasets = [ 57 | 
target_dataset( 58 | data=X[j], 59 | targets=Y[j], 60 | transform=transform, 61 | target_transform=target_transform, 62 | ) 63 | for j in range(num_clients) 64 | ] 65 | return datasets, stats 66 | -------------------------------------------------------------------------------- /data/utils/run.py: -------------------------------------------------------------------------------- 1 | import time 2 | from path import Path 3 | 4 | _CURRENT_DIR = Path(__file__).parent.abspath() 5 | import sys 6 | 7 | sys.path.append(_CURRENT_DIR) 8 | sys.path.append(_CURRENT_DIR.parent) 9 | import json 10 | import os 11 | import pickle 12 | import random 13 | from argparse import ArgumentParser 14 | 15 | import numpy as np 16 | import torch 17 | from torch.utils.data import random_split 18 | from torchvision import transforms 19 | from torchvision.datasets import CIFAR10, CIFAR100, EMNIST, MNIST, FashionMNIST 20 | 21 | from constants import MEAN, STD 22 | from partition import dirichlet_distribution, randomly_assign_classes 23 | from utils.dataset import CIFARDataset, MNISTDataset 24 | 25 | DATASET = { 26 | "mnist": (MNIST, MNISTDataset), 27 | "emnist": (EMNIST, MNISTDataset), 28 | "fmnist": (FashionMNIST, MNISTDataset), 29 | "cifar10": (CIFAR10, CIFARDataset), 30 | "cifar100": (CIFAR100, CIFARDataset), 31 | } 32 | 33 | 34 | def main(args): 35 | _DATASET_ROOT = ( 36 | Path(args.root).abspath() / args.dataset 37 | if args.root is not None 38 | else _CURRENT_DIR.parent / args.dataset 39 | ) 40 | _PICKLES_DIR = _CURRENT_DIR.parent / args.dataset / "pickles" 41 | 42 | np.random.seed(args.seed) 43 | random.seed(args.seed) 44 | torch.manual_seed(args.seed) 45 | 46 | classes_map = None 47 | transform = transforms.Compose( 48 | [ 49 | transforms.Normalize(MEAN[args.dataset], STD[args.dataset]), 50 | ] 51 | ) 52 | target_transform = None 53 | 54 | if not os.path.isdir(_DATASET_ROOT): 55 | os.mkdir(_DATASET_ROOT) 56 | if os.path.isdir(_PICKLES_DIR): 57 | os.system(f"rm -rf {_PICKLES_DIR}") 58 | os.system(f"mkdir -p {_PICKLES_DIR}") 59 | 60 | client_num_in_total = args.client_num_in_total 61 | client_num_in_total = args.client_num_in_total 62 | ori_dataset, target_dataset = DATASET[args.dataset] 63 | if args.dataset == "emnist": 64 | trainset = ori_dataset( 65 | _DATASET_ROOT, 66 | train=True, 67 | download=True, 68 | split=args.emnist_split, 69 | transform=transforms.ToTensor(), 70 | ) 71 | testset = ori_dataset( 72 | _DATASET_ROOT, 73 | train=False, 74 | split=args.emnist_split, 75 | transform=transforms.ToTensor(), 76 | ) 77 | else: 78 | trainset = ori_dataset( 79 | _DATASET_ROOT, 80 | train=True, 81 | download=True, 82 | ) 83 | testset = ori_dataset( 84 | _DATASET_ROOT, 85 | train=False, 86 | ) 87 | concat_datasets = [trainset, testset] 88 | if args.alpha > 0: # NOTE: Dirichlet(alpha) 89 | all_datasets, stats = dirichlet_distribution( 90 | ori_dataset=concat_datasets, 91 | target_dataset=target_dataset, 92 | num_clients=client_num_in_total, 93 | alpha=args.alpha, 94 | transform=transform, 95 | target_transform=target_transform, 96 | ) 97 | else: # NOTE: sort and partition 98 | classes = len(ori_dataset.classes) if args.classes <= 0 else args.classes 99 | all_datasets, stats = randomly_assign_classes( 100 | ori_datasets=concat_datasets, 101 | target_dataset=target_dataset, 102 | num_clients=client_num_in_total, 103 | num_classes=classes, 104 | transform=transform, 105 | target_transform=target_transform, 106 | ) 107 | 108 | for subset_id, client_id in enumerate( 109 | range(0, len(all_datasets), 
args.client_num_in_each_pickles) 110 | ): 111 | subset = [] 112 | for dataset in all_datasets[ 113 | client_id : client_id + args.client_num_in_each_pickles 114 | ]: 115 | num_val_samples = int(len(dataset) * args.valset_ratio) 116 | num_test_samples = int(len(dataset) * args.test_ratio) 117 | num_train_samples = len(dataset) - num_val_samples - num_test_samples 118 | train, val, test = random_split( 119 | dataset, [num_train_samples, num_val_samples, num_test_samples] 120 | ) 121 | subset.append({"train": train, "val": val, "test": test}) 122 | with open(_PICKLES_DIR / str(subset_id) + ".pkl", "wb") as f: 123 | pickle.dump(subset, f) 124 | 125 | # save stats 126 | if args.type == "user": 127 | train_clients_num = int(client_num_in_total * args.fraction) 128 | clients_4_train = [i for i in range(train_clients_num)] 129 | clients_4_test = [i for i in range(train_clients_num, client_num_in_total)] 130 | 131 | with open(_PICKLES_DIR / "seperation.pkl", "wb") as f: 132 | pickle.dump( 133 | { 134 | "train": clients_4_train, 135 | "test": clients_4_test, 136 | "total": client_num_in_total, 137 | }, 138 | f, 139 | ) 140 | 141 | train_clients_stats = dict( 142 | zip(clients_4_train, list(stats.values())[:train_clients_num]) 143 | ) 144 | test_clients_stats = dict( 145 | zip( 146 | clients_4_test, 147 | list(stats.values())[train_clients_num:], 148 | ) 149 | ) 150 | 151 | with open(_CURRENT_DIR.parent / args.dataset / "all_stats.json", "w") as f: 152 | json.dump({"train": train_clients_stats, "test": test_clients_stats}, f) 153 | 154 | else: # NOTE: "sample" save stats 155 | client_id_indices = [i for i in range(client_num_in_total)] 156 | with open(_PICKLES_DIR / "seperation.pkl", "wb") as f: 157 | pickle.dump( 158 | { 159 | "id": client_id_indices, 160 | "total": client_num_in_total, 161 | }, 162 | f, 163 | ) 164 | with open(_CURRENT_DIR.parent / args.dataset / "all_stats.json", "w") as f: 165 | json.dump(stats, f) 166 | 167 | args.root = ( 168 | Path(args.root).abspath() 169 | if str(_DATASET_ROOT) != str(_CURRENT_DIR.parent / args.dataset) 170 | else None 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | parser = ArgumentParser() 176 | parser.add_argument( 177 | "--dataset", 178 | type=str, 179 | choices=[ 180 | "mnist", 181 | "cifar10", 182 | "cifar100", 183 | "emnist", 184 | "fmnist", 185 | ], 186 | default="mnist", 187 | ) 188 | 189 | parser.add_argument("--client_num_in_total", type=int, default=10) 190 | parser.add_argument( 191 | "--fraction", type=float, default=0.9, help="Propotion of train clients" 192 | ) 193 | parser.add_argument("--valset_ratio", type=float, default=0.1) 194 | parser.add_argument("--test_ratio", type=float, default=0.1) 195 | parser.add_argument( 196 | "--classes", 197 | type=int, 198 | default=-1, 199 | help="Num of classes that one client's data belong to.", 200 | ) 201 | parser.add_argument("--seed", type=int, default=int(time.time())) 202 | ################# Dirichlet distribution only ################# 203 | parser.add_argument( 204 | "--alpha", 205 | type=float, 206 | default=0, 207 | help="Only for controling data hetero degree while performing Dirichlet partition.", 208 | ) 209 | ############################################################### 210 | 211 | ################# For EMNIST only ##################### 212 | parser.add_argument( 213 | "--emnist_split", 214 | type=str, 215 | choices=["byclass", "bymerge", "letters", "balanced", "digits", "mnist"], 216 | default="byclass", 217 | ) 218 | ####################################################### 
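    # --type "user": the first `fraction` of clients are used for training and the rest are
    #   held out as test-only clients (see the stats-saving branch in main());
    # --type "sample": all clients share a single pool and each keeps its own train/val/test split.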
219 | parser.add_argument( 220 | "--type", type=str, choices=["sample", "user"], default="sample" 221 | ) 222 | parser.add_argument("--client_num_in_each_pickles", type=int, default=10) 223 | parser.add_argument("--root", type=str, default="/root/repos/python/mine/datasets") 224 | args = parser.parse_args() 225 | main(args) 226 | args_dict = dict(args._get_kwargs()) 227 | with open(_CURRENT_DIR.parent / "args.json", "w") as f: 228 | json.dump(args_dict, f) 229 | -------------------------------------------------------------------------------- /data/utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import pickle 5 | from typing import Dict, List, Tuple, Union 6 | 7 | from path import Path 8 | from torch.utils.data import Subset, random_split 9 | 10 | _CURRENT_DIR = Path(__file__).parent.abspath() 11 | _ARGS_DICT = json.load(open(_CURRENT_DIR.parent / "args.json", "r")) 12 | 13 | 14 | def get_dataset( 15 | dataset: str, 16 | client_id: int, 17 | ) -> Dict[str, Subset]: 18 | client_num_in_each_pickles = _ARGS_DICT["client_num_in_each_pickles"] 19 | pickles_dir = _CURRENT_DIR.parent / dataset / "pickles" 20 | if os.path.isdir(pickles_dir) is False: 21 | raise RuntimeError("Please preprocess and create pickles first.") 22 | 23 | pickle_path = ( 24 | pickles_dir / f"{math.floor(client_id / client_num_in_each_pickles)}.pkl" 25 | ) 26 | with open(pickle_path, "rb") as f: 27 | subset = pickle.load(f) 28 | client_dataset = subset[client_id % client_num_in_each_pickles] 29 | trainset = client_dataset["train"] 30 | valset = client_dataset["val"] 31 | testset = client_dataset["test"] 32 | return {"train": trainset, "val": valset, "test": testset} 33 | 34 | 35 | def get_client_id_indices( 36 | dataset, 37 | ) -> Union[Tuple[List[int], List[int], int], Tuple[List[int], int]]: 38 | pickles_dir = _CURRENT_DIR.parent / dataset / "pickles" 39 | with open(pickles_dir / "seperation.pkl", "rb") as f: 40 | seperation = pickle.load(f) 41 | if _ARGS_DICT["type"] == "user": 42 | return seperation["train"], seperation["test"], seperation["total"] 43 | else: # NOTE: "sample" 44 | return seperation["id"], seperation["total"] 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | path 2 | torch 3 | numpy 4 | torchvision 5 | tqdm 6 | rich 7 | -------------------------------------------------------------------------------- /src/client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarhouTam/SCAFFOLD-PyTorch/69791f03d4d0951e5f3e7cf181d32a750ff9d1d0/src/client/__init__.py -------------------------------------------------------------------------------- /src/client/base.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Dict, List, OrderedDict, Tuple 4 | 5 | import torch 6 | import numpy as np 7 | from path import Path 8 | from rich.console import Console 9 | from torch.utils.data import Subset, DataLoader 10 | 11 | _CURRENT_DIR = Path(__file__).parent.abspath() 12 | 13 | import sys 14 | 15 | sys.path.append(_CURRENT_DIR.parent) 16 | 17 | from data.utils.util import get_dataset 18 | 19 | 20 | class ClientBase: 21 | def __init__( 22 | self, 23 | backbone: torch.nn.Module, 24 | dataset: str, 25 | 
batch_size: int, 26 | local_epochs: int, 27 | local_lr: float, 28 | logger: Console, 29 | gpu: int, 30 | ): 31 | self.device = torch.device( 32 | "cuda" if gpu and torch.cuda.is_available() else "cpu" 33 | ) 34 | self.client_id: int = None 35 | self.valset: Subset = None 36 | self.trainset: Subset = None 37 | self.testset: Subset = None 38 | self.model: torch.nn.Module = deepcopy(backbone).to(self.device) 39 | self.optimizer: torch.optim.Optimizer = torch.optim.SGD( 40 | self.model.parameters(), lr=local_lr 41 | ) 42 | self.dataset = dataset 43 | self.batch_size = batch_size 44 | self.local_epochs = local_epochs 45 | self.local_lr = local_lr 46 | self.criterion = torch.nn.CrossEntropyLoss() 47 | self.logger = logger 48 | self.untrainable_params: Dict[str, Dict[str, torch.Tensor]] = {} 49 | 50 | @torch.no_grad() 51 | def evaluate(self, use_valset=True): 52 | self.model.eval() 53 | criterion = torch.nn.CrossEntropyLoss(reduction="sum") 54 | loss = 0 55 | correct = 0 56 | dataloader = DataLoader(self.valset if use_valset else self.testset, 32) 57 | for x, y in dataloader: 58 | x, y = x.to(self.device), y.to(self.device) 59 | logits = self.model(x) 60 | loss += criterion(logits, y) 61 | pred = torch.softmax(logits, -1).argmax(-1) 62 | correct += (pred == y).int().sum() 63 | return loss.item(), correct.item() 64 | 65 | def train( 66 | self, 67 | client_id: int, 68 | model_params: OrderedDict[str, torch.Tensor], 69 | evaluate=True, 70 | verbose=False, 71 | use_valset=True, 72 | ) -> Tuple[List[torch.Tensor], int]: 73 | self.client_id = client_id 74 | self.set_parameters(model_params) 75 | self.get_client_local_dataset() 76 | res, stats = self._log_while_training(evaluate, verbose, use_valset)() 77 | return res, stats 78 | 79 | def _train(self): 80 | self.model.train() 81 | for _ in range(self.local_epochs): 82 | x, y = self.get_data_batch() 83 | logits = self.model(x) 84 | loss = self.criterion(logits, y) 85 | self.optimizer.zero_grad() 86 | loss.backward() 87 | self.optimizer.step() 88 | return ( 89 | list(self.model.state_dict(keep_vars=True).values()), 90 | len(self.trainset.dataset), 91 | ) 92 | 93 | def test( 94 | self, 95 | client_id: int, 96 | model_params: OrderedDict[str, torch.Tensor], 97 | ): 98 | self.client_id = client_id 99 | self.set_parameters(model_params) 100 | self.get_client_local_dataset() 101 | loss, correct = self.evaluate() 102 | stats = {"loss": loss, "correct": correct, "size": len(self.testset)} 103 | return stats 104 | 105 | def get_client_local_dataset(self): 106 | datasets = get_dataset( 107 | self.dataset, 108 | self.client_id, 109 | ) 110 | self.trainset = datasets["train"] 111 | self.valset = datasets["val"] 112 | self.testset = datasets["test"] 113 | 114 | def _log_while_training(self, evaluate=True, verbose=False, use_valset=True): 115 | def _log_and_train(*args, **kwargs): 116 | loss_before = 0 117 | loss_after = 0 118 | correct_before = 0 119 | correct_after = 0 120 | num_samples = len(self.valset) 121 | if evaluate: 122 | loss_before, correct_before = self.evaluate(use_valset) 123 | 124 | res = self._train(*args, **kwargs) 125 | 126 | if evaluate: 127 | loss_after, correct_after = self.evaluate(use_valset) 128 | 129 | if verbose: 130 | self.logger.log( 131 | "client [{}] [bold red]loss: {:.4f} -> {:.4f} [bold blue]accuracy: {:.2f}% -> {:.2f}%".format( 132 | self.client_id, 133 | loss_before / num_samples, 134 | loss_after / num_samples, 135 | correct_before / num_samples * 100.0, 136 | correct_after / num_samples * 100.0, 137 | ) 138 | ) 139 | 140 | stats = 
{ 141 | "correct": correct_before, 142 | "size": num_samples, 143 | } 144 | return res, stats 145 | 146 | return _log_and_train 147 | 148 | def set_parameters(self, model_params: OrderedDict): 149 | self.model.load_state_dict(model_params, strict=False) 150 | if self.client_id in self.untrainable_params.keys(): 151 | self.model.load_state_dict( 152 | self.untrainable_params[self.client_id], strict=False 153 | ) 154 | 155 | def get_data_batch(self): 156 | batch_size = ( 157 | self.batch_size 158 | if self.batch_size > 0 159 | else int(len(self.trainset) / self.local_epochs) 160 | ) 161 | indices = torch.from_numpy( 162 | np.random.choice(self.trainset.indices, batch_size) 163 | ).long() 164 | data, targets = self.trainset.dataset[indices] 165 | return data.to(self.device), targets.to(self.device) 166 | -------------------------------------------------------------------------------- /src/client/fedavg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rich.console import Console 3 | 4 | from .base import ClientBase 5 | 6 | 7 | class FedAvgClient(ClientBase): 8 | def __init__( 9 | self, 10 | backbone: torch.nn.Module, 11 | dataset: str, 12 | batch_size: int, 13 | local_epochs: int, 14 | local_lr: float, 15 | logger: Console, 16 | gpu: int, 17 | ): 18 | super(FedAvgClient, self).__init__( 19 | backbone, 20 | dataset, 21 | batch_size, 22 | local_epochs, 23 | local_lr, 24 | logger, 25 | gpu, 26 | ) 27 | -------------------------------------------------------------------------------- /src/client/fedprox.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import OrderedDict, List 3 | 4 | import torch 5 | from rich.console import Console 6 | 7 | from .base import ClientBase 8 | 9 | 10 | class FedProxClient(ClientBase): 11 | def __init__( 12 | self, 13 | backbone: torch.nn.Module, 14 | dataset: str, 15 | batch_size: int, 16 | local_epochs: int, 17 | local_lr: float, 18 | logger: Console, 19 | gpu: int, 20 | ): 21 | super(FedProxClient, self).__init__( 22 | backbone, 23 | dataset, 24 | batch_size, 25 | local_epochs, 26 | local_lr, 27 | logger, 28 | gpu, 29 | ) 30 | self.trainable_global_params: List[torch.Tensor] = None 31 | self.mu = 1.0 32 | 33 | def _train(self): 34 | self.model.train() 35 | for _ in range(self.local_epochs): 36 | x, y = self.get_data_batch() 37 | logits = self.model(x) 38 | loss = self.criterion(logits, y) 39 | self.optimizer.zero_grad() 40 | loss.backward() 41 | for w, w_g in zip(self.model.parameters(), self.trainable_global_params): 42 | w.grad.data += self.mu * (w_g.data - w.data) 43 | self.optimizer.step() 44 | return ( 45 | list(self.model.state_dict(keep_vars=True).values()), 46 | len(self.trainset.dataset), 47 | ) 48 | 49 | def set_parameters( 50 | self, 51 | model_params: OrderedDict[str, torch.Tensor], 52 | ): 53 | super().set_parameters(model_params) 54 | self.trainable_global_params = list( 55 | filter(lambda p: p.requires_grad, model_params.values()) 56 | ) 57 | -------------------------------------------------------------------------------- /src/client/scaffold.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import deepcopy 3 | from typing import Dict, List, OrderedDict 4 | 5 | import torch 6 | from rich.console import Console 7 | 8 | from .base import ClientBase 9 | 10 | 11 | class SCAFFOLDClient(ClientBase): 12 | def __init__( 
13 | self, 14 | backbone: torch.nn.Module, 15 | dataset: str, 16 | batch_size: int, 17 | local_epochs: int, 18 | local_lr: float, 19 | logger: Console, 20 | gpu: int, 21 | ): 22 | super(SCAFFOLDClient, self).__init__( 23 | backbone, 24 | dataset, 25 | batch_size, 26 | local_epochs, 27 | local_lr, 28 | logger, 29 | gpu, 30 | ) 31 | self.c_local: Dict[List[torch.Tensor]] = {} 32 | self.c_diff = [] 33 | 34 | def train( 35 | self, 36 | client_id: int, 37 | model_params: OrderedDict[str, torch.Tensor], 38 | c_global, 39 | evaluate=True, 40 | verbose=True, 41 | use_valset=True, 42 | ): 43 | self.client_id = client_id 44 | self.set_parameters(model_params) 45 | self.get_client_local_dataset() 46 | if self.client_id not in self.c_local.keys(): 47 | self.c_diff = c_global 48 | else: 49 | self.c_diff = [] 50 | for c_l, c_g in zip(self.c_local[self.client_id], c_global): 51 | self.c_diff.append(-c_l + c_g) 52 | _, stats = self._log_while_training(evaluate, verbose, use_valset)() 53 | # update local control variate 54 | with torch.no_grad(): 55 | trainable_parameters = filter( 56 | lambda p: p.requires_grad, model_params.values() 57 | ) 58 | 59 | if self.client_id not in self.c_local.keys(): 60 | self.c_local[self.client_id] = [ 61 | torch.zeros_like(param, device=self.device) 62 | for param in self.model.parameters() 63 | ] 64 | 65 | y_delta = [] 66 | c_plus = [] 67 | c_delta = [] 68 | 69 | # compute y_delta (difference of model before and after training) 70 | for param_l, param_g in zip(self.model.parameters(), trainable_parameters): 71 | y_delta.append(param_l - param_g) 72 | 73 | # compute c_plus 74 | coef = 1 / (self.local_epochs * self.local_lr) 75 | for c_l, c_g, diff in zip(self.c_local[self.client_id], c_global, y_delta): 76 | c_plus.append(c_l - c_g - coef * diff) 77 | 78 | # compute c_delta 79 | for c_p, c_l in zip(c_plus, self.c_local[self.client_id]): 80 | c_delta.append(c_p - c_l) 81 | 82 | self.c_local[self.client_id] = c_plus 83 | 84 | if self.client_id not in self.untrainable_params.keys(): 85 | self.untrainable_params[self.client_id] = {} 86 | for name, param in self.model.state_dict(keep_vars=True).items(): 87 | if not param.requires_grad: 88 | self.untrainable_params[self.client_id][name] = param.clone() 89 | 90 | return (y_delta, c_delta), stats 91 | 92 | def _train(self): 93 | self.model.train() 94 | for _ in range(self.local_epochs): 95 | x, y = self.get_data_batch() 96 | logits = self.model(x) 97 | loss = self.criterion(logits, y) 98 | self.optimizer.zero_grad() 99 | loss.backward() 100 | for param, c_d in zip(self.model.parameters(), self.c_diff): 101 | param.grad += c_d.data 102 | self.optimizer.step() 103 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarhouTam/SCAFFOLD-PyTorch/69791f03d4d0951e5f3e7cf181d32a750ff9d1d0/src/config/__init__.py -------------------------------------------------------------------------------- /src/config/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | ARGS = { 4 | "mnist": (1, 256, 10), 5 | "emnist": (1, 256, 62), 6 | "fmnist": (1, 256, 10), 7 | "cifar10": (3, 400, 10), 8 | "cifar100": (3, 400, 100), 9 | } 10 | 11 | 12 | class LeNet5(nn.Module): 13 | def __init__(self, dataset) -> None: 14 | super(LeNet5, self).__init__() 15 | self.net = nn.Sequential( 16 | nn.Conv2d(ARGS[dataset][0], 6, 5), 17 | 
nn.ReLU(True), 18 | nn.MaxPool2d(2), 19 | nn.Conv2d(6, 16, 5), 20 | nn.ReLU(True), 21 | nn.MaxPool2d(2), 22 | nn.Flatten(), 23 | nn.Linear(ARGS[dataset][1], 120), 24 | nn.ReLU(True), 25 | nn.Linear(120, 84), 26 | nn.ReLU(True), 27 | nn.Linear(84, ARGS[dataset][2]), 28 | ) 29 | 30 | def forward(self, x): 31 | return self.net(x) 32 | -------------------------------------------------------------------------------- /src/config/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from argparse import ArgumentParser, Namespace 3 | from collections import OrderedDict 4 | from typing import OrderedDict, Union 5 | 6 | import numpy as np 7 | import torch 8 | from path import Path 9 | 10 | PROJECT_DIR = Path(__file__).parent.parent.parent.abspath() 11 | LOG_DIR = PROJECT_DIR / "logs" 12 | TEMP_DIR = PROJECT_DIR / "temp" 13 | DATA_DIR = PROJECT_DIR / "data" 14 | 15 | 16 | def fix_random_seed(seed: int) -> None: 17 | torch.cuda.empty_cache() 18 | torch.random.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | 26 | def clone_parameters( 27 | src: Union[OrderedDict[str, torch.Tensor], torch.nn.Module] 28 | ) -> OrderedDict[str, torch.Tensor]: 29 | if isinstance(src, OrderedDict): 30 | return OrderedDict( 31 | { 32 | name: param.clone().detach().requires_grad_(param.requires_grad) 33 | for name, param in src.items() 34 | } 35 | ) 36 | if isinstance(src, torch.nn.Module): 37 | return OrderedDict( 38 | { 39 | name: param.clone().detach().requires_grad_(param.requires_grad) 40 | for name, param in src.state_dict(keep_vars=True).items() 41 | } 42 | ) 43 | 44 | 45 | def get_args() -> Namespace: 46 | parser = ArgumentParser() 47 | parser.add_argument("--global_epochs", type=int, default=100) 48 | parser.add_argument("--local_epochs", type=int, default=10) 49 | parser.add_argument("--local_lr", type=float, default=1e-2) 50 | parser.add_argument("--verbose_gap", type=int, default=20) 51 | parser.add_argument( 52 | "--dataset", 53 | type=str, 54 | choices=["mnist", "cifar10", "cifar100", "emnist", "fmnist"], 55 | default="mnist", 56 | ) 57 | parser.add_argument("--batch_size", type=int, default=-1) 58 | parser.add_argument("--gpu", type=int, default=1) 59 | parser.add_argument("--log", type=int, default=0) 60 | parser.add_argument("--seed", type=int, default=17) 61 | parser.add_argument("--client_num_per_round", type=int, default=2) 62 | parser.add_argument("--save_period", type=int, default=20) 63 | return parser.parse_args() 64 | -------------------------------------------------------------------------------- /src/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KarhouTam/SCAFFOLD-PyTorch/69791f03d4d0951e5f3e7cf181d32a750ff9d1d0/src/server/__init__.py -------------------------------------------------------------------------------- /src/server/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from argparse import Namespace 5 | from collections import OrderedDict 6 | 7 | import torch 8 | from path import Path 9 | from rich.console import Console 10 | from rich.progress import track 11 | from tqdm import tqdm 12 | 13 | _CURRENT_DIR = Path(__file__).parent.abspath() 14 | 15 | import sys 16 | 17 | sys.path.append(_CURRENT_DIR.parent) 
18 | 19 | from config.models import LeNet5 20 | from config.util import ( 21 | DATA_DIR, 22 | LOG_DIR, 23 | PROJECT_DIR, 24 | TEMP_DIR, 25 | clone_parameters, 26 | fix_random_seed, 27 | ) 28 | 29 | sys.path.append(PROJECT_DIR) 30 | sys.path.append(DATA_DIR) 31 | from client.base import ClientBase 32 | from data.utils.util import get_client_id_indices 33 | 34 | 35 | class ServerBase: 36 | def __init__(self, args: Namespace, algo: str): 37 | self.algo = algo 38 | self.args = args 39 | # default log file format 40 | self.log_name = "{}_{}_{}_{}.html".format( 41 | self.algo, 42 | self.args.dataset, 43 | self.args.global_epochs, 44 | self.args.local_epochs, 45 | ) 46 | self.device = torch.device( 47 | "cuda" if self.args.gpu and torch.cuda.is_available() else "cpu" 48 | ) 49 | fix_random_seed(self.args.seed) 50 | self.backbone = LeNet5 51 | self.logger = Console( 52 | record=True, 53 | log_path=False, 54 | log_time=False, 55 | ) 56 | self.client_id_indices, self.client_num_in_total = get_client_id_indices( 57 | self.args.dataset 58 | ) 59 | self.temp_dir = TEMP_DIR / self.algo 60 | if not os.path.isdir(self.temp_dir): 61 | os.makedirs(self.temp_dir) 62 | 63 | _dummy_model = self.backbone(self.args.dataset).to(self.device) 64 | passed_epoch = 0 65 | self.global_params_dict: OrderedDict[str : torch.Tensor] = None 66 | if os.listdir(self.temp_dir) != [] and self.args.save_period > 0: 67 | if os.path.exists(self.temp_dir / "global_model.pt"): 68 | self.global_params_dict = torch.load(self.temp_dir / "global_model.pt") 69 | self.logger.log("Find existed global model...") 70 | 71 | if os.path.exists(self.temp_dir / "epoch.pkl"): 72 | with open(self.temp_dir / "epoch.pkl", "rb") as f: 73 | passed_epoch = pickle.load(f) 74 | self.logger.log( 75 | f"Have run {passed_epoch} epochs already.", 76 | ) 77 | else: 78 | self.global_params_dict = OrderedDict( 79 | _dummy_model.state_dict(keep_vars=True) 80 | ) 81 | 82 | self.global_epochs = self.args.global_epochs - passed_epoch 83 | self.logger.log("Backbone:", _dummy_model) 84 | 85 | self.trainer: ClientBase = None 86 | self.num_correct = [[] for _ in range(self.global_epochs)] 87 | self.num_samples = [[] for _ in range(self.global_epochs)] 88 | 89 | def train(self): 90 | self.logger.log("=" * 30, "TRAINING", "=" * 30, style="bold green") 91 | progress_bar = ( 92 | track( 93 | range(self.global_epochs), 94 | "[bold green]Training...", 95 | console=self.logger, 96 | ) 97 | if not self.args.log 98 | else tqdm(range(self.global_epochs), "Training...") 99 | ) 100 | for E in progress_bar: 101 | 102 | if E % self.args.verbose_gap == 0: 103 | self.logger.log("=" * 30, f"ROUND: {E}", "=" * 30) 104 | 105 | selected_clients = random.sample( 106 | self.client_id_indices, self.args.client_num_per_round 107 | ) 108 | res_cache = [] 109 | for client_id in selected_clients: 110 | client_local_params = clone_parameters(self.global_params_dict) 111 | res, stats = self.trainer.train( 112 | client_id=client_id, 113 | model_params=client_local_params, 114 | verbose=(E % self.args.verbose_gap) == 0, 115 | ) 116 | 117 | res_cache.append(res) 118 | self.num_correct[E].append(stats["correct"]) 119 | self.num_samples[E].append(stats["size"]) 120 | self.aggregate(res_cache) 121 | 122 | if E % self.args.save_period == 0: 123 | torch.save( 124 | self.global_params_dict, 125 | self.temp_dir / "global_model.pt", 126 | ) 127 | with open(self.temp_dir / "epoch.pkl", "wb") as f: 128 | pickle.dump(E, f) 129 | 130 | @torch.no_grad() 131 | def aggregate(self, res_cache): 132 | 
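        # res_cache holds one (updated_parameters, local_dataset_size) pair per selected
        # client; each global parameter is replaced below by the weighted average of the
        # clients' updated parameters, weighted by their local dataset sizes (FedAvg-style).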
updated_params_cache = list(zip(*res_cache))[0] 133 | weights_cache = list(zip(*res_cache))[1] 134 | weight_sum = sum(weights_cache) 135 | weights = torch.tensor(weights_cache, device=self.device) / weight_sum 136 | 137 | aggregated_params = [] 138 | 139 | for params in zip(*updated_params_cache): 140 | aggregated_params.append( 141 | torch.sum(weights * torch.stack(params, dim=-1), dim=-1) 142 | ) 143 | 144 | self.global_params_dict = OrderedDict( 145 | zip(self.global_params_dict.keys(), aggregated_params) 146 | ) 147 | 148 | def test(self) -> None: 149 | self.logger.log("=" * 30, "TESTING", "=" * 30, style="bold blue") 150 | all_loss = [] 151 | all_correct = [] 152 | all_samples = [] 153 | for client_id in track( 154 | self.client_id_indices, 155 | "[bold blue]Testing...", 156 | console=self.logger, 157 | disable=self.args.log, 158 | ): 159 | client_local_params = clone_parameters(self.global_params_dict) 160 | stats = self.trainer.test( 161 | client_id=client_id, 162 | model_params=client_local_params, 163 | ) 164 | # self.logger.log( 165 | # f"client [{client_id}] [red]loss: {(stats['loss'] / stats['size']):.4f} [magenta]accuracy: {stats(['correct'] / stats['size'] * 100):.2f}%" 166 | # ) 167 | all_loss.append(stats["loss"]) 168 | all_correct.append(stats["correct"]) 169 | all_samples.append(stats["size"]) 170 | self.logger.log("=" * 20, "RESULTS", "=" * 20, style="bold green") 171 | self.logger.log( 172 | "loss: {:.4f} accuracy: {:.2f}%".format( 173 | sum(all_loss) / sum(all_samples), 174 | sum(all_correct) / sum(all_samples) * 100.0, 175 | ) 176 | ) 177 | 178 | acc_range = [90.0, 80.0, 70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0] 179 | min_acc_idx = 10 180 | max_acc = 0 181 | for E, (corr, n) in enumerate(zip(self.num_correct, self.num_samples)): 182 | avg_acc = sum(corr) / sum(n) * 100.0 183 | for i, acc in enumerate(acc_range): 184 | if avg_acc >= acc and avg_acc > max_acc: 185 | self.logger.log( 186 | "{} achieved {}% accuracy({:.2f}%) at epoch: {}".format( 187 | self.algo, acc, avg_acc, E 188 | ) 189 | ) 190 | max_acc = avg_acc 191 | min_acc_idx = i 192 | break 193 | acc_range = acc_range[:min_acc_idx] 194 | 195 | def run(self): 196 | self.logger.log("Arguments:", dict(self.args._get_kwargs())) 197 | self.train() 198 | self.test() 199 | if self.args.log: 200 | if not os.path.isdir(LOG_DIR): 201 | os.mkdir(LOG_DIR) 202 | self.logger.save_html(LOG_DIR / self.log_name) 203 | 204 | # delete all temporary files 205 | if os.listdir(self.temp_dir) != []: 206 | os.system(f"rm -rf {self.temp_dir}") 207 | -------------------------------------------------------------------------------- /src/server/fedavg.py: -------------------------------------------------------------------------------- 1 | from base import ServerBase 2 | from client.fedavg import FedAvgClient 3 | from config.util import get_args 4 | 5 | 6 | class FedAvgServer(ServerBase): 7 | def __init__(self): 8 | super(FedAvgServer, self).__init__(get_args(), "FedAvg") 9 | self.trainer = FedAvgClient( 10 | backbone=self.backbone(self.args.dataset), 11 | dataset=self.args.dataset, 12 | batch_size=self.args.batch_size, 13 | local_epochs=self.args.local_epochs, 14 | local_lr=self.args.local_lr, 15 | logger=self.logger, 16 | gpu=self.args.gpu, 17 | ) 18 | 19 | 20 | if __name__ == "__main__": 21 | server = FedAvgServer() 22 | server.run() 23 | -------------------------------------------------------------------------------- /src/server/fedprox.py: -------------------------------------------------------------------------------- 1 | from base 
import ServerBase 2 | from client.fedprox import FedProxClient 3 | from config.util import get_args 4 | 5 | 6 | class FedProxServer(ServerBase): 7 | def __init__(self): 8 | super(FedProxServer, self).__init__(get_args(), "FedProx") 9 | 10 | self.trainer = FedProxClient( 11 | backbone=self.backbone(self.args.dataset), 12 | dataset=self.args.dataset, 13 | batch_size=self.args.batch_size, 14 | local_epochs=self.args.local_epochs, 15 | local_lr=self.args.local_lr, 16 | logger=self.logger, 17 | gpu=self.args.gpu, 18 | ) 19 | 20 | 21 | if __name__ == "__main__": 22 | server = FedProxServer() 23 | server.run() 24 | -------------------------------------------------------------------------------- /src/server/scaffold.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | 4 | import torch 5 | from rich.progress import track 6 | from tqdm import tqdm 7 | 8 | from base import ServerBase 9 | from client.scaffold import SCAFFOLDClient 10 | from config.util import clone_parameters, get_args 11 | 12 | 13 | class SCAFFOLDServer(ServerBase): 14 | def __init__(self): 15 | super(SCAFFOLDServer, self).__init__(get_args(), "SCAFFOLD") 16 | 17 | self.trainer = SCAFFOLDClient( 18 | backbone=self.backbone(self.args.dataset), 19 | dataset=self.args.dataset, 20 | batch_size=self.args.batch_size, 21 | local_epochs=self.args.local_epochs, 22 | local_lr=self.args.local_lr, 23 | logger=self.logger, 24 | gpu=self.args.gpu, 25 | ) 26 | self.c_global = [ 27 | torch.zeros_like(param).to(self.device) 28 | for param in self.backbone(self.args.dataset).parameters() 29 | ] 30 | self.global_lr = 1.0 31 | self.training_acc = [[] for _ in range(self.global_epochs)] 32 | 33 | def train(self): 34 | self.logger.log("=" * 30, "TRAINING", "=" * 30, style="bold green") 35 | progress_bar = ( 36 | track( 37 | range(self.global_epochs), 38 | "[bold green]Training...", 39 | console=self.logger, 40 | ) 41 | if not self.args.log 42 | else tqdm(range(self.global_epochs), "Training...") 43 | ) 44 | for E in progress_bar: 45 | 46 | if E % self.args.verbose_gap == 0: 47 | self.logger.log("=" * 30, f"ROUND: {E}", "=" * 30) 48 | 49 | selected_clients = random.sample( 50 | self.client_id_indices, self.args.client_num_per_round 51 | ) 52 | res_cache = [] 53 | for client_id in selected_clients: 54 | client_local_params = clone_parameters(self.global_params_dict) 55 | res, stats = self.trainer.train( 56 | client_id=client_id, 57 | model_params=client_local_params, 58 | c_global=self.c_global, 59 | verbose=(E % self.args.verbose_gap) == 0, 60 | ) 61 | res_cache.append(res) 62 | 63 | self.num_correct[E].append(stats["correct"]) 64 | self.num_samples[E].append(stats["size"]) 65 | self.aggregate(res_cache) 66 | 67 | if E % self.args.save_period == 0 and self.args.save_period > 0: 68 | torch.save( 69 | self.global_params_dict, 70 | self.temp_dir / "global_model.pt", 71 | ) 72 | with open(self.temp_dir / "epoch.pkl", "wb") as f: 73 | pickle.dump(E, f) 74 | 75 | def aggregate(self, res_cache): 76 | y_delta_cache = list(zip(*res_cache))[0] 77 | c_delta_cache = list(zip(*res_cache))[1] 78 | trainable_parameter = filter( 79 | lambda param: param.requires_grad, self.global_params_dict.values() 80 | ) 81 | 82 | # update global model 83 | avg_weight = torch.tensor( 84 | [ 85 | 1 / self.args.client_num_per_round 86 | for _ in range(self.args.client_num_per_round) 87 | ], 88 | device=self.device, 89 | ) 90 | for param, y_del in zip(trainable_parameter, zip(*y_delta_cache)): 91 | x_del = 
torch.sum(avg_weight * torch.stack(y_del, dim=-1), dim=-1) 92 | param.data += self.global_lr * x_del 93 | 94 | # update global control 95 | for c_g, c_del in zip(self.c_global, zip(*c_delta_cache)): 96 | c_del = torch.sum(avg_weight * torch.stack(c_del, dim=-1), dim=-1) 97 | c_g.data += ( 98 | self.args.client_num_per_round / len(self.client_id_indices) 99 | ) * c_del 100 | 101 | 102 | if __name__ == "__main__": 103 | server = SCAFFOLDServer() 104 | server.run() 105 | --------------------------------------------------------------------------------