├── src ├── __pycache__ │ ├── args.cpython-38.pyc │ ├── models.cpython-38.pyc │ ├── updates.cpython-38.pyc │ └── utils.cpython-38.pyc ├── models.py ├── args.py ├── utils.py ├── scaffold_main.py ├── fedprox_main.py └── updates.py └── README.md /src/__pycache__/args.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/args.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/updates.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/updates.cpython-38.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning (Scaffold and Fedprox) using PyTorch 2 | **Basic Implementation of:**
3 | **FedProx:** [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
4 | **SCAFFOLD:** [SCAFFOLD: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378)
5 | 6 | Models: VGG, CNN
7 | Dataset: CIFAR-10 8 | 9 | 10 | --- 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | ### Note 25 | The non-IID data partitioning and the dataset used differ from those in the SCAFFOLD paper.
# 26 | Sections of the code are inspired by the FedAvg implementation at https://github.com/AshwinRJ/Federated-Learning-PyTorch 27 | 28 | 29 |
# --------------------------------------------------------------------------------
# /src/models.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Neural network models used by the federated training scripts."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class cifarCNN(nn.Module):
    """Small LeNet-style CNN for 3-channel images.

    The ``16 * 5 * 5`` flatten size assumes 32x32 inputs: two 5x5 valid
    convolutions, each followed by 2x2 max-pooling, shrink the spatial size
    32 -> 28 -> 14 -> 10 -> 5.

    Args:
        args: Namespace providing ``num_classes`` (width of the output layer).
    """

    def __init__(self, args):
        super(cifarCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(N, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class VGG(nn.Module):
    """VGG-style network with batch norm (VGG-16 convolutional configuration).

    Five 2x2 max-pools reduce a 32x32 input to 1x1, leaving a 512-dim
    feature vector for the single linear classifier, so inputs are assumed
    to be 3x32x32.

    Args:
        args: Namespace providing ``num_classes`` (width of the output layer).
    """

    def __init__(self, args):
        super(VGG, self).__init__()
        # Conv output widths; 'M' marks a 2x2 max-pool.
        self.inputs = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                       512, 512, 512, 'M', 512, 512, 512, 'M']
        # Output width follows args.num_classes (previously hard-coded to 10),
        # matching cifarCNN. Registered before ``features`` to keep the
        # original parameter / state_dict ordering.
        self.classifier = nn.Linear(512, args.num_classes)
        self.features = self._make_layers(self.inputs)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(N, num_classes)``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def _make_layers(self, inputs):
        """Build the Conv-BN-ReLU / MaxPool stack described by ``inputs``."""
        layers = []
        in_channels = 3
        for x in inputs:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # A 1x1 average pool with stride 1 is an identity op; kept so the
        # module layout (and printed architecture) is unchanged.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
# --------------------------------------------------------------------------------
# /src/args.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | 10 | # federated arguments (Notation for the arguments followed from paper) 11 | parser.add_argument('--epochs', type=int, default=10, 12 | help="number of rounds of training") 13 | parser.add_argument('--num_users', type=int, default=100, 14 | help="number of users: K") 15 | parser.add_argument('--frac', type=float, default=0.1, 16 | help='the fraction of clients: C') 17 | parser.add_argument('--local_ep', type=int, default=10, 18 | help="the number of local epochs: E") 19 | parser.add_argument('--local_bs', type=int, default=10, 20 | help="local batch size: B") 21 | parser.add_argument('--lr', type=float, default=0.01, 22 | help='learning rate') 23 | parser.add_argument('--momentum', type=float, default=0.5, 24 | help='SGD momentum (default: 0.5)') 25 | 26 | # model arguments 27 | parser.add_argument('--model', type=str, default='mlp', help='model name') 28 | parser.add_argument('--kernel_num', type=int, default=9, 29 | help='number of each kind of kernel') 30 | parser.add_argument('--kernel_sizes', type=str, default='3,4,5', 31 | help='comma-separated kernel size to \ 32 | use for convolution') 33 | parser.add_argument('--num_channels', type=int, default=1, help="number \ 34 | of channels of imgs") 35 | parser.add_argument('--norm', type=str, default='batch_norm', 36 | help="batch_norm, layer_norm, or None") 37 | parser.add_argument('--num_filters', type=int, default=32, 38 | help="number of filters for conv nets -- 32 for \ 39 | mini-imagenet, 64 for omiglot.") 40 | parser.add_argument('--max_pool', type=str, default='True', 41 | help="Whether use max pooling rather than \ 42 | strided convolutions") 43 | parser.add_argument('--pretrained', type=str, default='false', 44 | help="whether model is pretrained") 45 | 46 | 
parser.add_argument('--decay', type=float, default=0, help="learning rate decay per global round") 47 | 48 | # other arguments 49 | parser.add_argument('--dataset', type=str, default='mnist', help="name \ 50 | of dataset") 51 | parser.add_argument('--num_classes', type=int, default=10, help="number \ 52 | of classes") 53 | parser.add_argument('--gpu', default=None, help="To use cuda, set \ 54 | to a specific GPU ID. Default set to use CPU.") 55 | parser.add_argument('--optimizer', type=str, default='sgd', help="type \ 56 | of optimizer") 57 | parser.add_argument('--iid', type=int, default=1, 58 | help='Default set to IID. Set to 0 for non-IID.') 59 | parser.add_argument('--unequal', type=int, default=0, 60 | help='whether to use unequal data splits for \ 61 | non-i.i.d setting (use 0 for equal splits)') 62 | parser.add_argument('--stopping_rounds', type=int, default=10, 63 | help='rounds of early stopping') 64 | parser.add_argument('--verbose', type=int, default=1, help='verbose') 65 | parser.add_argument('--seed', type=int, default=1, help='random seed') 66 | parser.add_argument('--mu', type=float, default=0.0, help='proximal term constant') 67 | parser.add_argument('--stragglers', type=float, default=0, help='percentage of stragglers') 68 | args = parser.parse_args() 69 | return args 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import copy 5 | import torch 6 | from torchvision import datasets, transforms 7 | import numpy as np 8 | 9 | def get_dataset(args): 10 | """ Returns train and test datasets and a user group which is a dict where 11 | the keys are the user index and the values are the corresponding data for 12 | each of those users. 
13 | """ 14 | 15 | if args.dataset == 'cifar': 16 | data_dir = '../data/cifar/' 17 | 18 | train_transform = transforms.Compose( 19 | [transforms.ToTensor(), 20 | #transforms.RandomCrop(size=24), 21 | transforms.RandomApply(torch.nn.ModuleList([ 22 | transforms.ColorJitter(),]),p=0.5), 23 | transforms.RandomAutocontrast(), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))]) 26 | 27 | test_transform = transforms.Compose( 28 | [transforms.ToTensor(), 29 | #transforms.RandomCrop(size=24), 30 | transforms.Normalize((0.4914,0.4822,0.4465),(0.247,0.243,0.261))]) 31 | 32 | train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, 33 | transform=train_transform) 34 | 35 | test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, 36 | transform=test_transform) 37 | 38 | # sample training data amongst users 39 | if args.iid: 40 | # Sample IID user data from Mnist 41 | user_groups = cifar_iid(train_dataset, args.num_users) 42 | else: 43 | # Sample Non-IID user data from Mnist 44 | if args.unequal: 45 | # Chose uneuqal splits for every user 46 | raise NotImplementedError() 47 | else: 48 | # Chose euqal splits for every user 49 | user_groups = cifar_noniid(train_dataset, args.num_users) 50 | 51 | return train_dataset, test_dataset, user_groups 52 | 53 | def average_weights(w): 54 | """ 55 | Returns the average of the weights. 
56 | """ 57 | w_avg = copy.deepcopy(w[0]) 58 | for key in w_avg.keys(): 59 | for i in range(1, len(w)): 60 | w_avg[key] += w[i][key] 61 | w_avg[key] = torch.div(w_avg[key], len(w)) 62 | return w_avg 63 | 64 | 65 | def exp_details(args): 66 | print('\nExperimental details:') 67 | print(f' Model : {args.model}') 68 | print(f' Optimizer : {args.optimizer}') 69 | print(f' Learning : {args.lr}') 70 | print(f' Global Rounds : {args.epochs}\n') 71 | 72 | print(' Federated parameters:') 73 | if args.iid: 74 | print(' IID') 75 | else: 76 | print(' Non-IID') 77 | print(f' Fraction of users : {args.frac}') 78 | print(f' Local Batch size : {args.local_bs}') 79 | print(f' Local Epochs : {args.local_ep}\n') 80 | return 81 | 82 | def cifar_iid(dataset, num_users): 83 | """ 84 | Sample I.I.D. client data from CIFAR10 dataset 85 | :param dataset: 86 | :param num_users: 87 | :return: dict of image index 88 | """ 89 | num_items = int(len(dataset)/num_users) 90 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 91 | for i in range(num_users): 92 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 93 | replace=False)) 94 | all_idxs = list(set(all_idxs) - dict_users[i]) 95 | return dict_users 96 | 97 | 98 | def cifar_noniid(dataset, num_users): 99 | """ 100 | Sample non-I.I.D client data from CIFAR10 dataset 101 | :param dataset: 102 | :param num_users: 103 | :return: 104 | """ 105 | num_shards, num_imgs = 200, 250 106 | idx_shard = [i for i in range(num_shards)] 107 | dict_users = {i: np.array([]) for i in range(num_users)} 108 | idxs = np.arange(num_shards*num_imgs) 109 | # labels = dataset.train_labels.numpy() 110 | labels = np.array(dataset.targets) 111 | 112 | # sort labels 113 | idxs_labels = np.vstack((idxs, labels)) 114 | #stack into two rows, sort the labels row 115 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 116 | idxs = idxs_labels[0, :] 117 | 118 | # divide and assign 119 | for i in range(num_users): 120 | rand_set = 
set(np.random.choice(idx_shard, 2, replace=False)) 121 | idx_shard = list(set(idx_shard) - rand_set) 122 | for rand in rand_set: 123 | dict_users[i] = np.concatenate( 124 | (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 125 | return dict_users 126 | 127 | -------------------------------------------------------------------------------- /src/scaffold_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import copy 6 | import time 7 | import pickle 8 | import numpy as np 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import gc 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from args import args_parser 16 | from updates import test_results, ScaffoldUpdate 17 | from models import cifarCNN 18 | from utils import get_dataset, exp_details 19 | 20 | if __name__ == '__main__': 21 | start_time = time.time() 22 | 23 | # define paths 24 | path_project = os.path.abspath('..') 25 | logger = SummaryWriter('../logs') 26 | 27 | args = args_parser() 28 | exp_details(args) 29 | 30 | if args.gpu: 31 | # if args.gpu_id: 32 | torch.cuda.set_device(int(args.gpu)) 33 | device = 'cuda' if args.gpu else 'cpu' 34 | 35 | # load dataset and user groups 36 | train_dataset, test_dataset, user_groups = get_dataset(args) 37 | 38 | #only one model for now 39 | if args.dataset == 'cifar': 40 | global_model = cifarCNN(args=args) 41 | control_global = cifarCNN(args=args) 42 | 43 | #set global model to train 44 | global_model.to(device) 45 | global_model.train() 46 | print(global_model) 47 | 48 | control_global.to(device) 49 | 50 | control_weights = control_global.state_dict() 51 | 52 | 53 | 54 | # Training 55 | train_loss, train_accuracy = [], [] 56 | val_acc_list, net_list = [], [] 57 | cv_loss, cv_acc = [], [] 58 | print_every = 2 59 | val_loss_pre, counter = 0, 0 60 | 61 | # Test each round 62 | test_acc_list = [] 63 | 64 | 
65 | #devices that participate (sample size) 66 | m = max(int(args.frac * args.num_users), 1) 67 | 68 | #model for local control varietes 69 | local_controls = [cifarCNN(args=args) for i in range(args.num_users)] 70 | #local_models = [cifarCNN(args=args) for i in range(args.num_users)] 71 | 72 | for net in local_controls: 73 | net.load_state_dict(control_weights) 74 | 75 | 76 | #initiliase total delta to 0 (sum of all control_delta, triangle Ci) 77 | delta_c = copy.deepcopy(global_model.state_dict()) 78 | #sum of delta_y / sample size 79 | delta_x = copy.deepcopy(global_model.state_dict()) 80 | 81 | 82 | 83 | 84 | 85 | #global rounds 86 | for epoch in tqdm(range(args.epochs)): 87 | local_weights, local_losses = [], [] 88 | print(f'\n | Global Training Round : {epoch+1} |\n') 89 | 90 | for ci in delta_c: 91 | delta_c[ci] = 0.0 92 | for ci in delta_x: 93 | delta_x[ci] = 0.0 94 | 95 | global_model.train() 96 | # sample the users 97 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 98 | 99 | for idx in idxs_users: 100 | local_model = ScaffoldUpdate(args=args, dataset=train_dataset, 101 | idxs=user_groups[idx], logger=logger) 102 | weights, loss , local_delta_c, local_delta, control_local_w, _ = local_model.update_weights( 103 | model=copy.deepcopy(global_model), global_round=epoch, control_local 104 | = local_controls[idx], control_global = control_global) 105 | 106 | if epoch != 0: 107 | local_controls[idx].load_state_dict(control_local_w) 108 | 109 | local_weights.append(copy.deepcopy(weights)) 110 | local_losses.append(copy.deepcopy(loss)) 111 | 112 | #line16 113 | for w in delta_c: 114 | if epoch==0: 115 | delta_x[w] += weights[w] 116 | else: 117 | delta_x[w] += local_delta[w] 118 | delta_c[w] += local_delta_c[w] 119 | 120 | #clean 121 | gc.collect() 122 | torch.cuda.empty_cache() 123 | 124 | #update the delta C (line 16) 125 | for w in delta_c: 126 | delta_c[w] /= m 127 | delta_x[w] /= m 128 | 129 | #update global control variate (line17) 
130 | control_global_W = control_global.state_dict() 131 | global_weights = global_model.state_dict() 132 | #equation taking Ng, global step size = 1 133 | for w in control_global_W: 134 | #control_global_W[w] += delta_c[w] 135 | if epoch == 0: 136 | global_weights[w] = delta_x[w] 137 | else: 138 | global_weights[w] += delta_x[w] 139 | control_global_W[w] += (m / args.num_users) * delta_c[w] 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | #update global model 148 | control_global.load_state_dict(control_global_W) 149 | global_model.load_state_dict(global_weights) 150 | 151 | #########scaffold algo complete################## 152 | 153 | 154 | loss_avg = sum(local_losses) / len(local_losses) 155 | train_loss.append(loss_avg) 156 | 157 | # Calculate avg training accuracy over all users at every epoch 158 | list_acc, list_loss = [], [] 159 | 160 | global_model.eval() 161 | 162 | for c in range(args.num_users): 163 | local_model = ScaffoldUpdate(args=args, dataset=train_dataset, 164 | idxs=user_groups[idx], logger=logger) 165 | acc, loss = local_model.inference(model=global_model) 166 | list_acc.append(acc) 167 | #print("user:" + str(c) +" " + str(acc)) 168 | list_loss.append(loss) 169 | gc.collect() 170 | torch.cuda.empty_cache() 171 | 172 | train_accuracy.append(sum(list_acc)/len(list_acc)) 173 | 174 | round_test_acc, round_test_loss = test_results( 175 | args, global_model, test_dataset) 176 | test_acc_list.append(round_test_acc) 177 | 178 | 179 | # print global training loss after every 'i' rounds 180 | if (epoch+1) % print_every == 0: 181 | print(f' \nAvg Training Stats after {epoch+1} global rounds:') 182 | print(f'Training Loss : {np.mean(np.array(train_loss))}') 183 | print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1])) 184 | print('Test Accuracy at round ' + str(epoch+1) + 185 | ': {:.2f}% \n'.format(100*round_test_acc)) 186 | 187 | # Test inference after completion of training 188 | test_acc, test_loss = test_results(args, global_model, 
test_dataset) 189 | 190 | print(f' \n Results after {args.epochs} global rounds of training:') 191 | print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1])) 192 | print("|---- Test Accuracy: {:.2f}%".format(100*test_acc)) 193 | 194 | # save results to csv 195 | res = np.asarray([test_acc_list]) 196 | res_name = '../save/csvResults/Scaffold_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}].csv'. \ 197 | format(args.dataset, args.model, args.epochs, args.frac, args.iid, 198 | args.local_ep, args.local_bs, args.lr) 199 | np.savetxt(res_name, res, delimiter=",") 200 | -------------------------------------------------------------------------------- /src/fedprox_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import os 6 | import copy 7 | import time 8 | import pickle 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import gc 14 | from tensorboardX import SummaryWriter 15 | 16 | from args import args_parser 17 | from updates import ProxUpdate, test_results 18 | from models import cifarCNN, VGG 19 | from utils import get_dataset, exp_details, average_weights 20 | import torchvision.models as models 21 | 22 | if __name__ == '__main__': 23 | start_time = time.time() 24 | 25 | # define paths 26 | path_project = os.path.abspath('..') 27 | logger = SummaryWriter('../logs') 28 | 29 | args = args_parser() 30 | exp_details(args) 31 | 32 | if args.gpu: 33 | # if args.gpu_id: 34 | torch.cuda.set_device(int(args.gpu)) 35 | device = 'cuda' if args.gpu else 'cpu' 36 | 37 | # load dataset and user groups 38 | train_dataset, test_dataset, user_groups = get_dataset(args) 39 | 40 | # BUILD MODEL 41 | if args.model == 'cnn': 42 | if args.dataset == 'cifar': 43 | global_model = cifarCNN(args=args) 44 | 45 | elif args.model == 'vgg': 46 | if args.dataset == 'cifar' and args.pretrained: 47 | global_model = 
models.vgg16(pretrained=True) 48 | # change the number of classes 49 | global_model.classifier[6].out_features = 10 50 | # freeze convolution weights 51 | for param in global_model.features.parameters(): 52 | param.requires_grad = False 53 | elif args.dataset == 'cifar': 54 | global_model = VGG(args=args) 55 | else: 56 | exit(args.dataset + ' with ' + args.model + ' not supported') 57 | 58 | elif args.model == "resnet18": 59 | global_model = models.resnet18(pretrained=True) 60 | 61 | else: 62 | exit('Error: unrecognized model') 63 | 64 | # Set the model to train and send it to device. 65 | global_model.to(device) 66 | global_model.train() 67 | print(global_model) 68 | 69 | # copy weights 70 | global_weights = global_model.state_dict() 71 | 72 | # Training 73 | train_loss, train_accuracy = [], [] 74 | val_acc_list, net_list = [], [] 75 | cv_loss, cv_acc = [], [] 76 | print_every = 2 77 | val_loss_pre, counter = 0, 0 78 | 79 | # Test each round 80 | test_acc_list = [] 81 | 82 | #devices that participate 83 | m = max(int(args.frac * args.num_users), 1) 84 | 85 | 86 | for epoch in tqdm(range(args.epochs)): 87 | local_weights, local_losses = [], [] 88 | print(f'\n | Global Training Round : {epoch+1} |\n') 89 | 90 | # Local Epochs list to account for stragglers 91 | if args.stragglers == 0: 92 | local_epoch_list = np.array([args.local_ep] * m) 93 | else: 94 | straggler_size = int(args.stragglers * m) 95 | local_epoch_list = np.random.randint(1, args.local_ep, straggler_size) 96 | 97 | remainders = m - straggler_size 98 | rem_list = [args.local_ep] * remainders 99 | 100 | epoch_list = np.append(local_epoch_list, rem_list, axis=0) 101 | # shuffle the list and return 102 | np.random.shuffle(local_epoch_list) 103 | 104 | global_model.train() 105 | 106 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 107 | 108 | for idx, local_epoch in zip(idxs_users, local_epoch_list): 109 | local_model = ProxUpdate(args=args, dataset=train_dataset, 110 | 
idxs=user_groups[idx], logger=logger, local_epoch=local_epoch) 111 | w, loss, time = local_model.update_weights( 112 | model=copy.deepcopy(global_model), global_round=epoch) 113 | local_weights.append(copy.deepcopy(w)) 114 | local_losses.append(copy.deepcopy(loss)) 115 | gc.collect() 116 | torch.cuda.empty_cache() 117 | 118 | # update global weights 119 | global_weights = average_weights(local_weights) 120 | 121 | # update global weights 122 | global_model.load_state_dict(global_weights) 123 | 124 | loss_avg = sum(local_losses) / len(local_losses) 125 | train_loss.append(loss_avg) 126 | 127 | # Calculate avg training accuracy over all users at every epoch 128 | list_acc, list_loss = [], [] 129 | global_model.eval() 130 | 131 | for c in range(args.num_users): 132 | local_model = ProxUpdate(args=args, dataset=train_dataset, 133 | idxs=user_groups[idx], logger=logger, local_epoch=args.local_ep) 134 | acc, loss = local_model.inference(model=global_model) 135 | list_acc.append(acc) 136 | list_loss.append(loss) 137 | gc.collect() 138 | torch.cuda.empty_cache() 139 | 140 | train_accuracy.append(sum(list_acc)/len(list_acc)) 141 | 142 | round_test_acc, round_test_loss = test_results( 143 | args, global_model, test_dataset) 144 | test_acc_list.append(round_test_acc) 145 | 146 | # print global training loss after every 'i' rounds 147 | if (epoch+1) % print_every == 0: 148 | print(f' \nAvg Training Stats after {epoch+1} global rounds:') 149 | print(f'Training Loss : {np.mean(np.array(train_loss))}') 150 | print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1])) 151 | print('Test Accuracy at round ' + str(epoch+1) + 152 | ': {:.2f}% \n'.format(100*round_test_acc)) 153 | 154 | # Test inference after completion of training 155 | test_acc, test_loss = test_results(args, global_model, test_dataset) 156 | 157 | print(f' \n Results after {args.epochs} global rounds of training:') 158 | print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1])) 159 | 
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc)) 160 | 161 | # Saving the objects train_loss and train_accuracy: 162 | file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_u[{}]_%strag[{}].pkl'.\ 163 | format(args.dataset, args.model, args.epochs, args.frac, args.iid, 164 | args.local_ep, args.local_bs, args.lr,args.mu, args.stragglers) 165 | 166 | with open(file_name, 'wb') as f: 167 | pickle.dump([train_loss, train_accuracy], f) 168 | 169 | #print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time)) 170 | 171 | # save results to csv 172 | res = np.asarray([test_acc_list]) 173 | res_name = '../save/csvResults/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_u[{}]_%strag[{}].csv'. \ 174 | format(args.dataset, args.model, args.epochs, args.frac, args.iid, 175 | args.local_ep, args.local_bs, args.lr, args.mu, args.stragglers) 176 | np.savetxt(res_name, res, delimiter=",") 177 | 178 | # PLOTTING (optional) 179 | import matplotlib 180 | import matplotlib.pyplot as plt 181 | matplotlib.use('Agg') 182 | 183 | # # Plot Loss curve 184 | # plt.figure() 185 | # plt.title('Training Loss vs Communication rounds') 186 | # plt.plot(range(len(train_loss)), train_loss, color='r') 187 | # plt.ylabel('Training loss') 188 | # plt.xlabel('Communication Rounds') 189 | # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'. 190 | # format(args.dataset, args.model, args.epochs, args.frac, 191 | # args.iid, args.local_ep, args.local_bs)) 192 | 193 | # # Plot Average Accuracy vs Communication rounds 194 | # plt.figure() 195 | # plt.title('Average Accuracy vs Communication rounds') 196 | # plt.plot(range(len(train_accuracy)), train_accuracy , color='k') 197 | # plt.ylabel('Average Accuracy') 198 | # plt.xlabel('Communication Rounds') 199 | # plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'. 
200 | # format(args.dataset, args.model, args.epochs, args.frac, 201 | # args.iid, args.local_ep, args.local_bs)) 202 | 203 | # Plot Test Accuracy vs Communication rounds 204 | plt.figure() 205 | plt.title('Test Accuracy vs Communication rounds') 206 | plt.plot(range(len(train_accuracy)), test_acc_list, color='k') 207 | plt.ylabel('Test Accuracy') 208 | plt.xlabel('Communication Rounds') 209 | plt.ylim([0.1, 1]) 210 | plt.savefig('../save/fedPROX_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_mu[{}]_%strag[{}]_test_acc.png'. 211 | format(args.dataset, args.model, args.epochs, args.frac, 212 | args.iid, args.local_ep, args.local_bs, args.lr, args.mu, args.stragglers)) 213 | 214 | -------------------------------------------------------------------------------- /src/updates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*-````````````````````````````````````````` 3 | 4 | import copy 5 | import torch 6 | from torch import nn 7 | import time 8 | from torch.utils.data import DataLoader, Dataset 9 | import gc 10 | 11 | class DatasetSplit(Dataset): 12 | """An abstract Dataset class wrapped around Pytorch Dataset class. 
13 | """ 14 | 15 | def __init__(self, dataset, idxs): 16 | self.dataset = dataset 17 | self.idxs = [int(i) for i in idxs] 18 | 19 | def __len__(self): 20 | return len(self.idxs) 21 | 22 | def __getitem__(self, item): 23 | image, label = self.dataset[self.idxs[item]] 24 | return torch.tensor(image), torch.tensor(label) 25 | 26 | class ScaffoldUpdate(object): 27 | def __init__(self, args, dataset, idxs, logger): 28 | self.args = args 29 | self.logger = logger 30 | self.local_ep = args.local_ep 31 | self.trainloader, self.validloader, self.testloader = self.train_val_test( 32 | dataset, list(idxs)) 33 | self.device = 'cuda' if args.gpu else 'cpu' 34 | self.criterion = nn.CrossEntropyLoss().to(self.device) 35 | 36 | 37 | 38 | def train_val_test(self, dataset, idxs): 39 | """ 40 | Returns train, validation and test dataloaders for a given dataset 41 | and user indexes. 42 | """ 43 | # split indexes for train, validation, and test (80, 10, 10) 44 | 45 | idxs_train = idxs[:int(0.8*len(idxs))] 46 | idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))] 47 | idxs_test = idxs[int(0.9*len(idxs)):] 48 | 49 | trainloader = DataLoader(DatasetSplit(dataset, idxs_train), 50 | batch_size=self.args.local_bs, shuffle=True) 51 | if len(idxs_val) < 10: 52 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 53 | batch_size=1, shuffle=False) 54 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 55 | batch_size=1, shuffle=False) 56 | else: 57 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 58 | batch_size=int(len(idxs_val)/10), shuffle=False) 59 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 60 | batch_size=int(len(idxs_test)/10), shuffle=False) 61 | 62 | return trainloader, validloader, testloader 63 | 64 | def update_weights(self, model, global_round, control_local, control_global): 65 | # Set mode to train model 66 | model.to(self.device) 67 | global_weights = model.state_dict() 68 | model.train() 69 | epoch_loss = [] 70 | 71 | start_time = 
time.time() 72 | 73 | decay = self.args.decay 74 | if decay != 0: 75 | learn_rate = self.args.lr * pow(decay, global_round) 76 | else: 77 | learn_rate = self.args.lr 78 | 79 | # Set optimizer for the local updates 80 | if self.args.optimizer == 'sgd': 81 | optimizer = torch.optim.SGD(model.parameters(), lr=(learn_rate), 82 | momentum=0.9, weight_decay=0.00001) 83 | elif self.args.optimizer == 'adam': 84 | optimizer = torch.optim.Adam(model.parameters(), lr=(learn_rate), 85 | weight_decay=1e-4) 86 | 87 | control_global_w = control_global.state_dict() 88 | control_local_w = control_local.state_dict() 89 | 90 | count = 0 91 | for iter in range(self.local_ep): 92 | batch_loss = [] 93 | for batch_idx, (images, labels) in enumerate(self.trainloader): 94 | images, labels = images.to(self.device), labels.to(self.device) 95 | 96 | model.zero_grad() 97 | log_probs = model(images) 98 | loss = self.criterion(log_probs, labels) 99 | loss.backward() 100 | optimizer.step() 101 | 102 | batch_loss.append(loss.item()) 103 | 104 | local_weights = model.state_dict() 105 | for w in local_weights: 106 | #line 10 in algo 107 | local_weights[w] = local_weights[w] - self.args.lr*(control_global_w[w]-control_local_w[w]) 108 | 109 | #update local model params 110 | model.load_state_dict(local_weights) 111 | 112 | count += 1 113 | 114 | 115 | if self.args.verbose and (batch_idx % 10 == 0): 116 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 117 | global_round, iter, batch_idx * len(images), 118 | len(self.trainloader.dataset), 119 | 100. 
* batch_idx / len(self.trainloader), loss.item())) 120 | self.logger.add_scalar('loss', loss.item()) 121 | 122 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 123 | gc.collect() 124 | torch.cuda.empty_cache() 125 | 126 | 127 | new_control_local_w = control_local.state_dict() 128 | control_delta = copy.deepcopy(control_local_w) 129 | #model_weights -> y_(i) 130 | model_weights = model.state_dict() 131 | local_delta = copy.deepcopy(model_weights) 132 | for w in model_weights: 133 | #line 12 in algo 134 | new_control_local_w[w] = new_control_local_w[w] - control_global_w[w] + (global_weights[w] - model_weights[w]) / (count * self.args.lr) 135 | #line 13 136 | control_delta[w] = new_control_local_w[w] - control_local_w[w] 137 | local_delta[w] -= global_weights[w] 138 | #update new control_local model 139 | #control_local.load_state_dict(new_control_local_w) 140 | 141 | return model.state_dict(), sum(epoch_loss) / len(epoch_loss), control_delta, local_delta, new_control_local_w, time.time()- start_time 142 | 143 | def inference(self, model): 144 | """ Returns the inference accuracy and loss. 145 | """ 146 | 147 | model.eval() 148 | loss, total, correct = 0.0, 0.0, 0.0 149 | 150 | for batch_idx, (images, labels) in enumerate(self.testloader): 151 | images, labels = images.to(self.device), labels.to(self.device) 152 | 153 | # Inference 154 | outputs = model(images) 155 | batch_loss = self.criterion(outputs, labels) 156 | loss += batch_loss.item() 157 | 158 | # Prediction 159 | _, pred_labels = torch.max(outputs, 1) 160 | pred_labels = pred_labels.view(-1) 161 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 162 | total += len(labels) 163 | 164 | accuracy = correct/total 165 | return accuracy, loss 166 | 167 | def test_results(args, model, test_dataset): 168 | """ Returns the test accuracy and loss. 
169 | """ 170 | 171 | model.eval() 172 | loss, total, correct = 0.0, 0.0, 0.0 173 | 174 | device = 'cuda' if args.gpu else 'cpu' 175 | #criterion = nn.NLLLoss().to(device) 176 | criterion = nn.CrossEntropyLoss().to(device) 177 | testloader = DataLoader(test_dataset, batch_size=128, 178 | shuffle=False) 179 | 180 | for batch_idx, (images, labels) in enumerate(testloader): 181 | gc.collect() 182 | torch.cuda.empty_cache() 183 | images, labels = images.to(device), labels.to(device) 184 | 185 | # Inference 186 | outputs = model(images) 187 | batch_loss = criterion(outputs, labels) 188 | loss += batch_loss.item() 189 | 190 | # Prediction 191 | _, pred_labels = torch.max(outputs, 1) 192 | pred_labels = pred_labels.view(-1) 193 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 194 | total += len(labels) 195 | 196 | accuracy = correct/total 197 | return accuracy, loss 198 | 199 | class ProxUpdate(object): 200 | def __init__(self, args, dataset, idxs, logger, local_epoch): 201 | self.args = args 202 | self.logger = logger 203 | self.local_ep = local_epoch 204 | self.trainloader, self.validloader, self.testloader = self.train_val_test( 205 | dataset, list(idxs)) 206 | self.device = 'cuda' if args.gpu else 'cpu' 207 | # Default criterion set to NLL loss function 208 | if args.dataset == "cifar": 209 | self.criterion = nn.CrossEntropyLoss().to(self.device) 210 | else: 211 | self.criterion = nn.NLLLoss().to(self.device) 212 | 213 | def train_val_test(self, dataset, idxs): 214 | """ 215 | Returns train, validation and test dataloaders for a given dataset 216 | and user indexes. 
217 | """ 218 | # split indexes for train, validation, and test (80, 10, 10) 219 | 220 | idxs_train = idxs[:int(0.8*len(idxs))] 221 | idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))] 222 | idxs_test = idxs[int(0.9*len(idxs)):] 223 | 224 | trainloader = DataLoader(DatasetSplit(dataset, idxs_train), 225 | batch_size=self.args.local_bs, shuffle=True) 226 | if len(idxs_val) < 10: 227 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 228 | batch_size=1, shuffle=False) 229 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 230 | batch_size=1, shuffle=False) 231 | else: 232 | validloader = DataLoader(DatasetSplit(dataset, idxs_val), 233 | batch_size=int(len(idxs_val)/10), shuffle=False) 234 | testloader = DataLoader(DatasetSplit(dataset, idxs_test), 235 | batch_size=int(len(idxs_test)/10), shuffle=False) 236 | 237 | return trainloader, validloader, testloader 238 | 239 | def update_weights(self, model, global_round): 240 | # Set mode to train model 241 | model.train() 242 | epoch_loss = [] 243 | 244 | global_model = copy.deepcopy(model) 245 | start_time = time.time() 246 | 247 | decay = self.args.decay 248 | 249 | if decay != 0: 250 | learn_rate = self.args.lr * pow(decay, global_round) 251 | else: 252 | learn_rate = self.args.lr 253 | 254 | print(learn_rate) 255 | # Set optimizer for the local updates 256 | if self.args.optimizer == 'sgd': 257 | optimizer = torch.optim.SGD(model.parameters(), lr=(learn_rate), 258 | momentum=0.9, weight_decay=0.00001) 259 | elif self.args.optimizer == 'adam': 260 | optimizer = torch.optim.Adam(model.parameters(), lr=(learn_rate), 261 | weight_decay=1e-4) 262 | 263 | for iter in range(self.local_ep): 264 | batch_loss = [] 265 | for batch_idx, (images, labels) in enumerate(self.trainloader): 266 | images, labels = images.to(self.device), labels.to(self.device) 267 | 268 | model.zero_grad() 269 | log_probs = model(images) 270 | 271 | proximal_term = 0.0 272 | # iterate through the current and global model parameters 
273 | for w, w_t in zip(model.parameters(), global_model.parameters()): 274 | # update the proximal term 275 | # proximal_term += torch.sum(torch.abs((w-w_t)**2)) 276 | proximal_term += (w - w_t).norm(2) 277 | 278 | loss = self.criterion(log_probs, labels) + (self.args.mu / 2) * proximal_term 279 | 280 | loss.backward() 281 | optimizer.step() 282 | batch_loss.append(loss.item()) 283 | 284 | if self.args.verbose and (batch_idx % 10 == 0): 285 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 286 | global_round, iter, batch_idx * len(images), 287 | len(self.trainloader.dataset), 288 | 100. * batch_idx / len(self.trainloader), loss.item())) 289 | self.logger.add_scalar('loss', loss.item()) 290 | 291 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 292 | gc.collect() 293 | torch.cuda.empty_cache() 294 | 295 | return model.state_dict(), sum(epoch_loss) / len(epoch_loss), time.time()- start_time 296 | 297 | def inference(self, model): 298 | """ Returns the inference accuracy and loss. 299 | """ 300 | 301 | model.eval() 302 | loss, total, correct = 0.0, 0.0, 0.0 303 | 304 | for batch_idx, (images, labels) in enumerate(self.testloader): 305 | images, labels = images.to(self.device), labels.to(self.device) 306 | 307 | # Inference 308 | outputs = model(images) 309 | batch_loss = self.criterion(outputs, labels) 310 | loss += batch_loss.item() 311 | 312 | # Prediction 313 | _, pred_labels = torch.max(outputs, 1) 314 | pred_labels = pred_labels.view(-1) 315 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 316 | total += len(labels) 317 | 318 | accuracy = correct/total 319 | return accuracy, loss --------------------------------------------------------------------------------