├── src
├── __pycache__
│ ├── args.cpython-38.pyc
│ ├── models.cpython-38.pyc
│ ├── updates.cpython-38.pyc
│ └── utils.cpython-38.pyc
├── models.py
├── args.py
├── utils.py
├── scaffold_main.py
├── fedprox_main.py
└── updates.py
└── README.md
/src/__pycache__/args.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/args.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/updates.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/updates.cpython-38.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ongzh/ScaffoldFL/HEAD/src/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Federated Learning (Scaffold and Fedprox) using PyTorch
2 | **Basic Implementation of:**
3 | **Fedprox:** [Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
4 | **Scaffold:** [SCAFFOLD: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378)
5 |
6 | Models: VGG, CNN
7 | Dataset: CIFAR-10
8 |
9 |
10 | ---
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ### Note
25 | Non-IID implementation and dataset used differ from the Scaffold Paper.
26 | Sections of code inspired from FedAvg implementation by https://github.com/AshwinRJ/Federated-Learning-PyTorch
27 |
28 |
29 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
class cifarCNN(nn.Module):
    """LeNet-style CNN for CIFAR-10.

    Expects 3x32x32 inputs; the number of output classes is taken from
    ``args.num_classes``. Returns raw (unnormalized) class logits.
    """

    def __init__(self, args):
        super(cifarCNN, self).__init__()
        # two conv/pool stages followed by three fully-connected layers
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = torch.flatten(out, 1)  # keep batch dim, flatten the rest
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)
27 |
class VGG(nn.Module):
    """VGG16-with-batch-norm feature extractor for 32x32 inputs (CIFAR-10).

    BUGFIX: the classifier previously hard-coded 10 output classes; it now
    honours ``args.num_classes`` (falling back to 10 when absent), matching
    ``cifarCNN``.
    """

    def __init__(self, args):
        super(VGG, self).__init__()
        # VGG16 layer configuration: conv output widths, 'M' = 2x2 max-pool.
        self.inputs = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                       512, 512, 512, 'M', 512, 512, 512, 'M']
        self.classifier = nn.Linear(512, getattr(args, 'num_classes', 10))
        self.features = self._make_layers(self.inputs)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        x = self.classifier(x)
        return x

    def _make_layers(self, inputs):
        """Build the convolutional feature stack from the config list."""
        layers = []
        in_channels = 3
        for x in inputs:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # final 1x1 avg-pool is a no-op kept for parity with common
        # VGG-for-CIFAR implementations
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 |
6 |
def args_parser():
    """Parse and return the command-line arguments for the FL experiments.

    All option names and defaults keep their historical values so existing
    launch scripts continue to work.

    BUGFIX: ``--pretrained`` used to be a plain string whose default
    ``'false'`` was truthy, so code testing ``if args.pretrained`` always
    took the pretrained branch. It is now converted to a real bool (the CLI
    still accepts true/false style values).
    """

    def str2bool(value):
        # 'true'/'1'/'yes' (any case) -> True, everything else -> False.
        # argparse applies this converter to string defaults as well.
        return str(value).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()

    # federated arguments (notation for the arguments follows the papers)
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of rounds of training')
    parser.add_argument('--num_users', type=int, default=100,
                        help='number of users: K')
    parser.add_argument('--frac', type=float, default=0.1,
                        help='the fraction of clients: C')
    parser.add_argument('--local_ep', type=int, default=10,
                        help='the number of local epochs: E')
    parser.add_argument('--local_bs', type=int, default=10,
                        help='local batch size: B')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9,
                        help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel sizes to use for convolution')
    parser.add_argument('--num_channels', type=int, default=1,
                        help='number of channels of imgs')
    parser.add_argument('--norm', type=str, default='batch_norm',
                        help='batch_norm, layer_norm, or None')
    parser.add_argument('--num_filters', type=int, default=32,
                        help='number of filters for conv nets -- 32 for '
                             'mini-imagenet, 64 for omniglot')
    parser.add_argument('--max_pool', type=str, default='True',
                        help='Whether use max pooling rather than strided convolutions')
    parser.add_argument('--pretrained', type=str2bool, default='false',
                        help='whether model is pretrained')

    parser.add_argument('--decay', type=float, default=0,
                        help='learning rate decay per global round')

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist',
                        help='name of dataset')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='number of classes')
    parser.add_argument('--gpu', default=None,
                        help='To use cuda, set to a specific GPU ID. '
                             'Default set to use CPU.')
    parser.add_argument('--optimizer', type=str, default='sgd',
                        help='type of optimizer')
    parser.add_argument('--iid', type=int, default=1,
                        help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0,
                        help='whether to use unequal data splits for '
                             'non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10,
                        help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--mu', type=float, default=0.0,
                        help='proximal term constant')
    parser.add_argument('--stragglers', type=float, default=0,
                        help='percentage of stragglers')
    return parser.parse_args()
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import copy
5 | import torch
6 | from torchvision import datasets, transforms
7 | import numpy as np
8 |
def get_dataset(args):
    """Return train/test datasets and the per-user index partition.

    Returns a tuple ``(train_dataset, test_dataset, user_groups)`` where
    ``user_groups`` maps each user id to the indices of its training samples.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
            (BUGFIX: previously an unsupported name fell through and raised a
            confusing UnboundLocalError at the return statement.)
        NotImplementedError: for the unequal non-IID split, which is not
            implemented for CIFAR-10.
    """
    if args.dataset != 'cifar':
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    data_dir = '../data/cifar/'

    # Light augmentation for training only; both transforms share the same
    # per-channel normalization statistics.
    train_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.RandomApply(torch.nn.ModuleList([
             transforms.ColorJitter(), ]), p=0.5),
         transforms.RandomAutocontrast(),
         transforms.RandomHorizontalFlip(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

    train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                     transform=train_transform)
    test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                    transform=test_transform)

    # sample training data amongst users
    if args.iid:
        # sample IID user data
        user_groups = cifar_iid(train_dataset, args.num_users)
    elif args.unequal:
        # unequal splits for every user are not implemented
        raise NotImplementedError()
    else:
        # equal-sized non-IID shards for every user
        user_groups = cifar_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
52 |
def average_weights(w):
    """
    Element-wise average of a list of model state dicts.

    The inputs are left untouched; a fresh state dict is returned.
    """
    num_models = len(w)
    w_avg = copy.deepcopy(w[0])
    for key in w_avg:
        # accumulate the remaining models onto the (copied) first one
        for state in w[1:]:
            w_avg[key] += state[key]
        w_avg[key] = torch.div(w_avg[key], num_models)
    return w_avg
63 |
64 |
def exp_details(args):
    """Print a human-readable summary of the experiment configuration."""
    print('\nExperimental details:')
    print(' Model : {}'.format(args.model))
    print(' Optimizer : {}'.format(args.optimizer))
    print(' Learning : {}'.format(args.lr))
    print(' Global Rounds : {}\n'.format(args.epochs))

    print(' Federated parameters:')
    # args.iid is an int flag: non-zero means IID partitioning
    print(' IID' if args.iid else ' Non-IID')
    print(' Fraction of users : {}'.format(args.frac))
    print(' Local Batch size : {}'.format(args.local_bs))
    print(' Local Epochs : {}\n'.format(args.local_ep))
    return
81 |
def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR10 dataset.

    :param dataset: dataset to partition (only its length is used)
    :param num_users: number of clients to create
    :return: dict mapping user id -> set of sample indices
    """
    per_user = int(len(dataset) / num_users)
    pool = [idx for idx in range(len(dataset))]
    dict_users = {}
    for user in range(num_users):
        # draw this user's samples without replacement, then remove them
        # from the pool so users never share an index
        dict_users[user] = set(np.random.choice(pool, per_user,
                                                replace=False))
        pool = list(set(pool) - dict_users[user])
    return dict_users
96 |
97 |
def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from the CIFAR10 dataset.

    Splits the 50,000 training samples into 200 label-sorted shards of 250
    images each and hands every user 2 randomly chosen shards.

    :param dataset: dataset exposing a ``targets`` sequence of labels
    :param num_users: number of clients to create
    :return: dict mapping user id -> np.array of sample indices
    """
    num_shards, num_imgs = 200, 250
    shard_pool = [s for s in range(num_shards)]
    dict_users = {uid: np.array([]) for uid in range(num_users)}
    sample_idxs = np.arange(num_shards * num_imgs)
    labels = np.array(dataset.targets)

    # stack indices over labels and reorder columns so that sample indices
    # end up sorted by label (shards become label-homogeneous)
    stacked = np.vstack((sample_idxs, labels))
    stacked = stacked[:, stacked[1, :].argsort()]
    sorted_idxs = stacked[0, :]

    # hand each user two random shards, removing them from the pool
    for uid in range(num_users):
        picked = set(np.random.choice(shard_pool, 2, replace=False))
        shard_pool = list(set(shard_pool) - picked)
        for shard in picked:
            lo = shard * num_imgs
            dict_users[uid] = np.concatenate(
                (dict_users[uid], sorted_idxs[lo:lo + num_imgs]), axis=0)
    return dict_users
126 |
127 |
--------------------------------------------------------------------------------
/src/scaffold_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import copy
6 | import time
7 | import pickle
8 | import numpy as np
9 | from tqdm import tqdm
10 | import torch
11 | import torch.nn as nn
12 | import gc
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | from args import args_parser
16 | from updates import test_results, ScaffoldUpdate
17 | from models import cifarCNN
18 | from utils import get_dataset, exp_details
19 |
if __name__ == '__main__':
    # Server-side driver for SCAFFOLD federated training on CIFAR-10.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(int(args.gpu))
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # Only the CIFAR CNN is supported by this script for now.
    # BUGFIX: previously a non-cifar dataset left global_model unbound and
    # crashed later with a NameError.
    if args.dataset == 'cifar':
        global_model = cifarCNN(args=args)
        control_global = cifarCNN(args=args)
    else:
        exit('Error: scaffold_main only supports the cifar dataset')

    # set global model to train
    global_model.to(device)
    global_model.train()
    print(global_model)

    control_global.to(device)
    control_weights = control_global.state_dict()

    # Training bookkeeping
    train_loss, train_accuracy = [], []
    print_every = 2
    test_acc_list = []  # test accuracy measured after every round

    # number of clients sampled each round (at least one)
    m = max(int(args.frac * args.num_users), 1)

    # one local control variate (c_i) per client, initialised to the global
    # control variate
    local_controls = [cifarCNN(args=args) for i in range(args.num_users)]
    for net in local_controls:
        net.load_state_dict(control_weights)

    # round accumulators: sum of client control deltas and model deltas
    delta_c = copy.deepcopy(global_model.state_dict())
    delta_x = copy.deepcopy(global_model.state_dict())

    # global rounds
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # reset the accumulators at the start of every round
        for ci in delta_c:
            delta_c[ci] = 0.0
        for ci in delta_x:
            delta_x[ci] = 0.0

        global_model.train()
        # sample the participating clients
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = ScaffoldUpdate(args=args, dataset=train_dataset,
                                         idxs=user_groups[idx], logger=logger)
            weights, loss, local_delta_c, local_delta, control_local_w, _ = \
                local_model.update_weights(
                    model=copy.deepcopy(global_model), global_round=epoch,
                    control_local=local_controls[idx],
                    control_global=control_global)

            # NOTE(review): the first round intentionally skips updating the
            # local control variate — confirm against Algorithm 1.
            if epoch != 0:
                local_controls[idx].load_state_dict(control_local_w)

            local_weights.append(copy.deepcopy(weights))
            local_losses.append(copy.deepcopy(loss))

            # accumulate client deltas (line 16 of the algorithm); in round 0
            # delta_x accumulates the raw client weights instead
            for w in delta_c:
                if epoch == 0:
                    delta_x[w] += weights[w]
                else:
                    delta_x[w] += local_delta[w]
                delta_c[w] += local_delta_c[w]

            # free memory between clients
            gc.collect()
            torch.cuda.empty_cache()

        # average the accumulated deltas over the sampled clients (line 16)
        for w in delta_c:
            delta_c[w] /= m
            delta_x[w] /= m

        # update global model and global control variate (line 17),
        # taking the global step size Ng = 1
        control_global_W = control_global.state_dict()
        global_weights = global_model.state_dict()
        for w in control_global_W:
            if epoch == 0:
                # first round: delta_x holds the averaged raw weights
                global_weights[w] = delta_x[w]
            else:
                global_weights[w] += delta_x[w]
            control_global_W[w] += (m / args.num_users) * delta_c[w]

        control_global.load_state_dict(control_global_W)
        global_model.load_state_dict(global_weights)

        ######### scaffold algo complete ##################

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()

        for c in range(args.num_users):
            # BUGFIX: evaluate each client on its OWN partition (this
            # previously passed user_groups[idx], re-using the last trained
            # client's data for every user).
            local_model = ScaffoldUpdate(args=args, dataset=train_dataset,
                                         idxs=user_groups[c], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
            gc.collect()
            torch.cuda.empty_cache()

        train_accuracy.append(sum(list_acc)/len(list_acc))

        round_test_acc, round_test_loss = test_results(
            args, global_model, test_dataset)
        test_acc_list.append(round_test_acc)

        # print global training stats every `print_every` rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
            print('Test Accuracy at round ' + str(epoch+1) +
                  ': {:.2f}% \n'.format(100*round_test_acc))

    # Test inference after completion of training
    test_acc, test_loss = test_results(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # save per-round test accuracy to csv
    res = np.asarray([test_acc_list])
    res_name = '../save/csvResults/Scaffold_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}].csv'. \
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs, args.lr)
    np.savetxt(res_name, res, delimiter=",")
200 |
--------------------------------------------------------------------------------
/src/fedprox_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import os
6 | import copy
7 | import time
8 | import pickle
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | import torch
13 | import gc
14 | from tensorboardX import SummaryWriter
15 |
16 | from args import args_parser
17 | from updates import ProxUpdate, test_results
18 | from models import cifarCNN, VGG
19 | from utils import get_dataset, exp_details, average_weights
20 | import torchvision.models as models
21 |
if __name__ == '__main__':
    # Server-side driver for FedProx federated training.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(int(args.gpu))
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUGFIX: args.pretrained may be a string ('false' by default), which is
    # always truthy — compare its value rather than its truthiness.
    use_pretrained = str(args.pretrained).lower() in ('true', '1', 'yes')

    # BUILD MODEL
    if args.model == 'cnn':
        if args.dataset == 'cifar':
            global_model = cifarCNN(args=args)
        else:
            # BUGFIX: previously fell through with global_model unbound
            exit(args.dataset + ' with ' + args.model + ' not supported')
    elif args.model == 'vgg':
        if args.dataset == 'cifar' and use_pretrained:
            global_model = models.vgg16(pretrained=True)
            # BUGFIX: assigning .out_features does not resize the layer;
            # replace the final classifier layer to get 10 output classes.
            global_model.classifier[6] = torch.nn.Linear(
                global_model.classifier[6].in_features, 10)
            # freeze convolution weights
            for param in global_model.features.parameters():
                param.requires_grad = False
        elif args.dataset == 'cifar':
            global_model = VGG(args=args)
        else:
            exit(args.dataset + ' with ' + args.model + ' not supported')
    elif args.model == "resnet18":
        global_model = models.resnet18(pretrained=True)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training bookkeeping
    train_loss, train_accuracy = [], []
    print_every = 2
    test_acc_list = []  # test accuracy measured after every round

    # number of clients sampled each round (at least one)
    m = max(int(args.frac * args.num_users), 1)

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # Local-epoch counts; stragglers get a random, smaller epoch budget.
        if args.stragglers == 0:
            local_epoch_list = np.array([args.local_ep] * m)
        else:
            straggler_size = int(args.stragglers * m)
            straggler_epochs = np.random.randint(1, args.local_ep,
                                                 straggler_size)
            full_epochs = [args.local_ep] * (m - straggler_size)
            # BUGFIX: the combined list was previously built but discarded,
            # and only the straggler entries were shuffled and used, so
            # zip() below silently dropped the remaining sampled clients.
            local_epoch_list = np.append(straggler_epochs, full_epochs,
                                         axis=0)
            np.random.shuffle(local_epoch_list)

        global_model.train()

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx, local_epoch in zip(idxs_users, local_epoch_list):
            local_model = ProxUpdate(args=args, dataset=train_dataset,
                                     idxs=user_groups[idx], logger=logger,
                                     local_epoch=local_epoch)
            # BUGFIX: third tuple element was previously unpacked into a
            # variable named `time`, shadowing the time module.
            w, loss, elapsed = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))
            gc.collect()
            torch.cuda.empty_cache()

        # update global weights: plain FedAvg of the local models
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()

        for c in range(args.num_users):
            # BUGFIX: evaluate each client on its OWN partition (this
            # previously passed user_groups[idx], re-using the last trained
            # client's data for every user).
            local_model = ProxUpdate(args=args, dataset=train_dataset,
                                     idxs=user_groups[c], logger=logger,
                                     local_epoch=args.local_ep)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
            gc.collect()
            torch.cuda.empty_cache()

        train_accuracy.append(sum(list_acc)/len(list_acc))

        round_test_acc, round_test_loss = test_results(
            args, global_model, test_dataset)
        test_acc_list.append(round_test_acc)

        # print global training stats every `print_every` rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
            print('Test Accuracy at round ' + str(epoch+1) +
                  ': {:.2f}% \n'.format(100*round_test_acc))

    # Test inference after completion of training
    test_acc, test_loss = test_results(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_u[{}]_%strag[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs, args.lr, args.mu, args.stragglers)
    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    # save results to csv
    res = np.asarray([test_acc_list])
    res_name = '../save/csvResults/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_u[{}]_%strag[{}].csv'. \
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs, args.lr, args.mu, args.stragglers)
    np.savetxt(res_name, res, delimiter=",")

    # Plot Test Accuracy vs Communication rounds.
    # BUGFIX: select the non-interactive backend BEFORE importing pyplot.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    plt.figure()
    plt.title('Test Accuracy vs Communication rounds')
    plt.plot(range(len(test_acc_list)), test_acc_list, color='k')
    plt.ylabel('Test Accuracy')
    plt.xlabel('Communication Rounds')
    plt.ylim([0.1, 1])
    plt.savefig('../save/fedPROX_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_LR[{}]_mu[{}]_%strag[{}]_test_acc.png'.
                format(args.dataset, args.model, args.epochs, args.frac,
                       args.iid, args.local_ep, args.local_bs, args.lr,
                       args.mu, args.stragglers))
213 |
214 |
--------------------------------------------------------------------------------
/src/updates.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import copy
5 | import torch
6 | from torch import nn
7 | import time
8 | from torch.utils.data import DataLoader, Dataset
9 | import gc
10 |
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.

    Restricts ``dataset`` to the subset of positions listed in ``idxs``.
    """

    def __init__(self, dataset, idxs):
        # dataset: any indexable dataset yielding (image, label) pairs
        # idxs: iterable of integer positions into `dataset`
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # BUGFIX: torch.tensor(x) on an existing tensor emits a UserWarning
        # and always copies; torch.as_tensor converts only when necessary.
        return torch.as_tensor(image), torch.as_tensor(label)
25 |
class ScaffoldUpdate(object):
    """Client-side update for the SCAFFOLD federated-learning algorithm.

    Wraps one client's data partition (given by ``idxs``) and performs
    control-variate-corrected local SGD; the "line N" comments below refer
    to the algorithm lines already cited in the original code.
    """

    def __init__(self, args, dataset, idxs, logger):
        # args: parsed CLI options (lr, local_ep, local_bs, gpu, ...)
        # dataset: full training set; idxs: this client's sample indices
        # logger: a SummaryWriter-like object with add_scalar()
        self.args = args
        self.logger = logger
        self.local_ep = args.local_ep
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        # models here return raw logits, so CrossEntropyLoss is used
        self.criterion = nn.CrossEntropyLoss().to(self.device)



    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.
        """
        # split indexes for train, validation, and test (80, 10, 10)

        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        # For tiny splits, len//10 would give batch_size 0, so use 1 instead.
        if len(idxs_val) < 10:
            validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                     batch_size=1, shuffle=False)
            testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                    batch_size=1, shuffle=False)
        else:
            validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                     batch_size=int(len(idxs_val)/10), shuffle=False)
            testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                    batch_size=int(len(idxs_test)/10), shuffle=False)

        return trainloader, validloader, testloader

    def update_weights(self, model, global_round, control_local, control_global):
        """Run the SCAFFOLD local update on this client's training split.

        Args:
            model: a copy of the current global model (trained in place).
            global_round: index of the current communication round.
            control_local: this client's control variate c_i (an nn.Module).
            control_global: the server control variate c (an nn.Module).

        Returns a 6-tuple: (updated state_dict, mean epoch loss,
        control delta, model delta, new local control variate state_dict,
        elapsed wall-clock seconds).
        """
        # Set mode to train model
        model.to(self.device)
        # NOTE(review): state_dict() returns references to the live parameter
        # tensors, so `global_weights` is mutated along with the model as
        # training proceeds below; a deepcopy may have been intended here
        # (it is later used as the round-start weights x) — confirm.
        global_weights = model.state_dict()
        model.train()
        epoch_loss = []

        start_time = time.time()

        # exponential learning-rate decay per global round (disabled when 0)
        decay = self.args.decay
        if decay != 0:
            learn_rate = self.args.lr * pow(decay, global_round)
        else:
            learn_rate = self.args.lr

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=(learn_rate),
                                        momentum=0.9, weight_decay=0.00001)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=(learn_rate),
                                         weight_decay=1e-4)

        control_global_w = control_global.state_dict()
        control_local_w = control_local.state_dict()

        # count = number of local SGD steps taken (divisor in line 12 below)
        count = 0
        for iter in range(self.local_ep):  # NOTE: `iter` shadows the builtin
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

                # control-variate correction applied after every SGD step
                local_weights = model.state_dict()
                for w in local_weights:
                    #line 10 in algo
                    # NOTE(review): uses the base lr, not the decayed
                    # `learn_rate` the optimizer uses — confirm intended
                    # when args.decay != 0.
                    local_weights[w] = local_weights[w] - self.args.lr*(control_global_w[w]-control_local_w[w])

                #update local model params
                model.load_state_dict(local_weights)

                count += 1


                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            gc.collect()
            torch.cuda.empty_cache()


        # compute the updated local control variate and the deltas that are
        # reported back to the server
        new_control_local_w = control_local.state_dict()
        control_delta = copy.deepcopy(control_local_w)
        #model_weights -> y_(i)
        model_weights = model.state_dict()
        local_delta = copy.deepcopy(model_weights)
        for w in model_weights:
            #line 12 in algo
            new_control_local_w[w] = new_control_local_w[w] - control_global_w[w] + (global_weights[w] - model_weights[w]) / (count * self.args.lr)
            #line 13
            control_delta[w] = new_control_local_w[w] - control_local_w[w]
            local_delta[w] -= global_weights[w]
        #update new control_local model
        # (left to the caller, which reloads the returned state dict)
        #control_local.load_state_dict(new_control_local_w)

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss), control_delta, local_delta, new_control_local_w, time.time()- start_time

    def inference(self, model):
        """ Returns the inference accuracy and loss.

        Evaluates ``model`` on this client's test split; ``loss`` is the
        sum of per-batch CrossEntropy losses (not averaged over batches).
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)

            # Inference
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

        accuracy = correct/total
        return accuracy, loss
166 |
167 | def test_results(args, model, test_dataset):
168 | """ Returns the test accuracy and loss.
169 | """
170 |
171 | model.eval()
172 | loss, total, correct = 0.0, 0.0, 0.0
173 |
174 | device = 'cuda' if args.gpu else 'cpu'
175 | #criterion = nn.NLLLoss().to(device)
176 | criterion = nn.CrossEntropyLoss().to(device)
177 | testloader = DataLoader(test_dataset, batch_size=128,
178 | shuffle=False)
179 |
180 | for batch_idx, (images, labels) in enumerate(testloader):
181 | gc.collect()
182 | torch.cuda.empty_cache()
183 | images, labels = images.to(device), labels.to(device)
184 |
185 | # Inference
186 | outputs = model(images)
187 | batch_loss = criterion(outputs, labels)
188 | loss += batch_loss.item()
189 |
190 | # Prediction
191 | _, pred_labels = torch.max(outputs, 1)
192 | pred_labels = pred_labels.view(-1)
193 | correct += torch.sum(torch.eq(pred_labels, labels)).item()
194 | total += len(labels)
195 |
196 | accuracy = correct/total
197 | return accuracy, loss
198 |
class ProxUpdate(object):
    """Client-side local update for FedProx.

    Holds one client's share of the dataset (by sample indexes) and runs
    local training epochs with a proximal term that penalizes drift from
    the incoming global model.
    """

    def __init__(self, args, dataset, idxs, logger, local_epoch):
        self.args = args
        self.logger = logger
        self.local_ep = local_epoch
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        self.device = 'cuda' if args.gpu else 'cpu'
        # CIFAR models here produce raw logits -> CrossEntropyLoss;
        # other models are assumed to emit log-probabilities -> NLLLoss.
        if args.dataset == "cifar":
            self.criterion = nn.CrossEntropyLoss().to(self.device)
        else:
            self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes, split 80/10/10.
        """
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        # Tiny validation split -> batch size of len/10 would round to 0,
        # so fall back to batch_size=1 for both eval loaders.
        if len(idxs_val) < 10:
            val_bs, test_bs = 1, 1
        else:
            val_bs = int(len(idxs_val)/10)
            test_bs = int(len(idxs_test)/10)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=val_bs, shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=test_bs, shuffle=False)

        return trainloader, validloader, testloader

    def update_weights(self, model, global_round):
        """Run ``self.local_ep`` FedProx epochs on this client's data.

        Args:
            model: local model, initialized to the current global weights;
                trained in place.
            global_round: communication-round index, used for LR decay
                and logging.

        Returns:
            (state_dict, avg_loss, elapsed): the trained weights, the
            mean of per-epoch average losses, and wall-clock seconds.

        Raises:
            ValueError: if ``self.args.optimizer`` is not 'sgd' or 'adam'.
        """
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Frozen snapshot of the incoming global model: the proximal
        # anchor w^t in the FedProx objective.
        global_model = copy.deepcopy(model)
        start_time = time.time()

        decay = self.args.decay
        # Exponential LR decay across rounds; decay == 0 means constant LR.
        if decay != 0:
            learn_rate = self.args.lr * pow(decay, global_round)
        else:
            learn_rate = self.args.lr

        print(learn_rate)
        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=(learn_rate),
                                        momentum=0.9, weight_decay=0.00001)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=(learn_rate),
                                         weight_decay=1e-4)
        else:
            # Previously an unknown value fell through to a NameError on
            # `optimizer`; fail loudly with a clear message instead.
            raise ValueError(
                "unsupported optimizer: {}".format(self.args.optimizer))

        for epoch in range(self.local_ep):  # renamed from `iter` (builtin shadow)
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                model.zero_grad()
                log_probs = model(images)

                # FedProx proximal term, accumulated per parameter tensor.
                # The anchor is detached so backward() only computes grads
                # for the local model (the snapshot is never stepped, so
                # the loss value and the local gradients are unchanged).
                # NOTE(review): the FedProx paper uses the *squared* L2
                # norm; this accumulates plain per-tensor norms — kept
                # as-is to preserve training behavior.
                proximal_term = 0.0
                for w, w_t in zip(model.parameters(), global_model.parameters()):
                    proximal_term += (w - w_t.detach()).norm(2)

                loss = self.criterion(log_probs, labels) + (self.args.mu / 2) * proximal_term

                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())

                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, epoch, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            gc.collect()
            torch.cuda.empty_cache()

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss), time.time()- start_time

    def inference(self, model):
        """ Returns the inference accuracy and summed loss on the local
        test split.
        """

        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        # Evaluation only: disable autograd to save memory and time.
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(self.testloader):
                images, labels = images.to(self.device), labels.to(self.device)

                # Inference
                outputs = model(images)
                batch_loss = self.criterion(outputs, labels)
                loss += batch_loss.item()

                # Prediction: class with the highest logit wins.
                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

        accuracy = correct/total
        return accuracy, loss
--------------------------------------------------------------------------------