├── LICENSE ├── README.md ├── data └── MNIST │ ├── t10k-images-idx3-ubyte.gz │ ├── t10k-labels-idx1-ubyte.gz │ ├── train-images-idx3-ubyte.gz │ └── train-labels-idx1-ubyte.gz ├── image └── Figure_1.png └── src ├── FLTrustServer.py ├── Models.py ├── clients.py ├── getData.py ├── server.py └── test.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 zhmzm 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLTrust_pytorch 2 | **Unofficial implementation** of FLTrust. If you find any problems, please let me know. 
3 | 4 | FLTrust paper: https://arxiv.org/pdf/2012.13995.pdf 5 | 6 | Paper abstract page: https://arxiv.org/abs/2012.13995 7 | 8 | Parts of the code are adapted from https://github.com/WHDY/FedAvg 9 | 10 | This code was originally not suitable for ResNet because of its BN (BatchNorm) layers. [FLAME](https://github.com/zhmzm/FLAME) contains another version of FLTrust that supports models with BN layers. 11 | 12 | 2022-12-29: 13 | 14 | BN layers in ResNet and VGG are now supported. Please report any bugs you encounter in the GitHub issues; I will fix them as soon as possible. 15 | 16 | The central dataset (100 samples) is randomly selected from the test dataset. 17 | 18 | # Backdoor in FL 19 | 20 | **Our recent paper "Backdoor Federated Learning by Poisoning Backdoor-critical Layers" has been accepted at ICLR'24; please refer to the [GitHub repo](https://github.com/zhmzm/Poisoning_Backdoor-critical_Layers_Attack).** 21 | 22 | # Quick Start 23 | Run from the `src/` directory, since the MNIST files are loaded from `../data/MNIST`: 24 | ```bash 25 | python FLTrustServer.py -nc 100 -cf 0.1 -E 5 -B 10 -mn mnist_2nn -ncomm 1000 -iid 0 -lr 0.01 -vf 20 -g 0 26 | ``` 27 | 28 | -------------------------------------------------------------------------------- /data/MNIST/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLTrust_pytorch/1235eac1f4542965da3913df41cba2f7ddbabd23/data/MNIST/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLTrust_pytorch/1235eac1f4542965da3913df41cba2f7ddbabd23/data/MNIST/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhmzm/FLTrust_pytorch/1235eac1f4542965da3913df41cba2f7ddbabd23/data/MNIST/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/MNIST/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLTrust_pytorch/1235eac1f4542965da3913df41cba2f7ddbabd23/data/MNIST/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /image/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLTrust_pytorch/1235eac1f4542965da3913df41cba2f7ddbabd23/image/Figure_1.png -------------------------------------------------------------------------------- /src/FLTrustServer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import optim 8 | from Models import Mnist_2NN, Mnist_CNN 9 | from clients import ClientsGroup, client 10 | import matplotlib.pyplot as plt 11 | def cos(a,b): 12 | res = np.sum(a*b.T)/((np.sqrt(np.sum(a * a.T)) + 1e-9) * (np.sqrt(np.sum(b * b.T))) + 1e-9) 13 | '''relu''' 14 | if res < 0: 15 | res = 0 16 | return res 17 | def model2vector(model): 18 | nparr = np.array([]) 19 | vec = [] 20 | for key, var in model.items(): 21 | if key.split('.')[-1] == 'num_batches_tracked' or key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var': 22 | continue 23 | nplist = var.cpu().numpy() 24 | nplist = nplist.ravel() 25 | nparr = np.append(nparr, nplist) 26 | return nparr 27 | 28 | def cosScoreAndClipValue(net1, net2): 29 | '''net1 -> centre, net2 -> local, net3 -> early model''' 30 | vector1 = model2vector(net1) 31 | vector2 = model2vector(net2) 32 | 33 | return 
cos(vector1, vector2), norm_clip(vector1, vector2) 34 | def norm_clip(nparr1, nparr2): 35 | '''v -> nparr1, v_clipped -> nparr2''' 36 | vnum = np.linalg.norm(nparr1, ord=None, axis=None, keepdims=False) + 1e-9 37 | return vnum / np.linalg.norm(nparr2, ord=None, axis=None, keepdims=False) + 1e-9 38 | 39 | def get_weight(update, model): 40 | '''get the update weight''' 41 | for key, var in update.items(): 42 | update[key] -= model[key] 43 | return update 44 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="FedAvg") 45 | parser.add_argument('-g', '--gpu', type=str, default='0', help='gpu id to use(e.g. 0,1,2,3)') 46 | parser.add_argument('-nc', '--num_of_clients', type=int, default=100, help='numer of the clients') 47 | parser.add_argument('-cf', '--cfraction', type=float, default=0.1, help='C fraction, 0 means 1 client, 1 means total clients') 48 | parser.add_argument('-E', '--epoch', type=int, default=5, help='local train epoch') 49 | parser.add_argument('-B', '--batchsize', type=int, default=10, help='local train batch size') 50 | parser.add_argument('-mn', '--model_name', type=str, default='mnist_cnn', help='the model to train') 51 | parser.add_argument('-lr', "--learning_rate", type=float, default=0.01, help="learning rate, \ 52 | use value from origin paper as default") 53 | parser.add_argument('-vf', "--val_freq", type=int, default=2, help="model validation frequency(of communications)") 54 | parser.add_argument('-sf', '--save_freq', type=int, default=100, help='global model save frequency(of communication)') 55 | parser.add_argument('-ncomm', '--num_comm', type=int, default=2000, help='number of communications') 56 | parser.add_argument('-sp', '--save_path', type=str, default='./checkpoints', help='the saving path of checkpoints') 57 | parser.add_argument('-iid', '--IID', type=int, default=0, help='the way to allocate data to clients') 58 | 59 | 60 | def test_mkdir(path): 61 | if not os.path.isdir(path): 62 
| os.mkdir(path) 63 | 64 | 65 | if __name__=="__main__": 66 | args = parser.parse_args() 67 | args = args.__dict__ 68 | 69 | acc_list=[] 70 | test_mkdir(args['save_path']) 71 | 72 | os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu'] 73 | dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 74 | 75 | net = None 76 | if args['model_name'] == 'mnist_2nn': 77 | net = Mnist_2NN() 78 | elif args['model_name'] == 'mnist_cnn': 79 | net = Mnist_CNN() 80 | 81 | if torch.cuda.device_count() > 1: 82 | print("Let's use", torch.cuda.device_count(), "GPUs!") 83 | net = torch.nn.DataParallel(net) 84 | net = net.to(dev) 85 | 86 | loss_func = F.cross_entropy 87 | opti = optim.SGD(net.parameters(), lr=args['learning_rate']) 88 | 89 | myClients = ClientsGroup('mnist', args['IID'], args['num_of_clients'], dev) 90 | testDataLoader = myClients.test_data_loader 91 | 92 | num_in_comm = int(max(args['num_of_clients'] * args['cfraction'], 1)) 93 | 94 | global_parameters = {} 95 | for key, var in net.state_dict().items(): 96 | global_parameters[key] = var.clone() 97 | 98 | for i in range(args['num_comm']): 99 | print("communicate round {}".format(i+1)) 100 | 101 | order = np.random.permutation(args['num_of_clients']) 102 | clients_in_comm = ['client{}'.format(i) for i in order[0:num_in_comm]] 103 | 104 | sum_parameters = None 105 | FLTrustTotalScore = 0 106 | FLTrustCentralNorm = myClients.centralTrain(args['epoch'], args['batchsize'], net, 107 | loss_func, opti, global_parameters) 108 | '''get the update weight''' 109 | FLTrustCentralNorm = get_weight(FLTrustCentralNorm, global_parameters) 110 | 111 | for client in tqdm(clients_in_comm): 112 | local_parameters = myClients.clients_set[client].localUpdate(args['epoch'], args['batchsize'], net, 113 | loss_func, opti, global_parameters) 114 | '''get the update weight''' 115 | local_parameters = get_weight(local_parameters, global_parameters) 116 | #计算cos相似度得分和向量长度裁剪值 117 | client_score, client_clipped_value = 
cosScoreAndClipValue(FLTrustCentralNorm, local_parameters) 118 | 119 | FLTrustTotalScore += client_score 120 | if sum_parameters is None: 121 | sum_parameters = {} 122 | for key, var in local_parameters.items(): 123 | #乘得分 再乘裁剪值 124 | sum_parameters[key] = client_score * client_clipped_value * var.clone() 125 | else: 126 | for var in sum_parameters: 127 | sum_parameters[var] = sum_parameters[var] + client_score * client_clipped_value * local_parameters[var] 128 | 129 | for var in global_parameters: 130 | #除以所以客户端的信任得分总和 131 | global_parameters[var] += sum_parameters[var] / (FLTrustTotalScore + 1e-9) 132 | 133 | with torch.no_grad(): 134 | if (i + 1) % args['val_freq'] == 0: 135 | net.load_state_dict(global_parameters, strict=True) 136 | sum_accu = 0 137 | num = 0 138 | for data, label in testDataLoader: 139 | data, label = data.to(dev), label.to(dev) 140 | preds = net(data) 141 | preds = torch.argmax(preds, dim=1) 142 | sum_accu += (preds == label).float().mean() 143 | num += 1 144 | print('accuracy: {}'.format(sum_accu / num)) 145 | acc_list.append(sum_accu.item() / num) 146 | print(acc_list) 147 | 148 | if (i + 1) % args['save_freq'] == 0: 149 | torch.save(net, os.path.join(args['save_path'], 150 | '{}_num_comm{}_E{}_B{}_lr{}_num_clients{}_cf{}'.format(args['model_name'], 151 | i, args['epoch'], 152 | args['batchsize'], 153 | args['learning_rate'], 154 | args['num_of_clients'], 155 | args['cfraction']))) 156 | print(acc_list) 157 | plt.plot(acc_list) 158 | plt.show() 159 | -------------------------------------------------------------------------------- /src/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Mnist_2NN(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.fc1 = nn.Linear(784, 200) 10 | self.fc2 = nn.Linear(200, 200) 11 | self.fc3 = nn.Linear(200, 10) 12 | 13 | def forward(self, inputs): 14 | tensor 
= F.relu(self.fc1(inputs)) 15 | tensor = F.relu(self.fc2(tensor)) 16 | tensor = self.fc3(tensor) 17 | return tensor 18 | 19 | 20 | class Mnist_CNN(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2) 24 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 25 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2) 26 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 27 | self.fc1 = nn.Linear(7*7*64, 512) 28 | self.fc2 = nn.Linear(512, 10) 29 | 30 | def forward(self, inputs): 31 | tensor = inputs.view(-1, 1, 28, 28) 32 | tensor = F.relu(self.conv1(tensor)) 33 | tensor = self.pool1(tensor) 34 | tensor = F.relu(self.conv2(tensor)) 35 | tensor = self.pool2(tensor) 36 | tensor = tensor.view(-1, 7*7*64) 37 | tensor = F.relu(self.fc1(tensor)) 38 | tensor = self.fc2(tensor) 39 | return tensor 40 | 41 | -------------------------------------------------------------------------------- /src/clients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import TensorDataset 4 | from torch.utils.data import DataLoader 5 | from getData import GetDataSet 6 | import copy 7 | 8 | 9 | class client(object): 10 | def __init__(self, trainDataSet, dev): 11 | self.train_ds = trainDataSet 12 | self.dev = dev 13 | self.train_dl = None 14 | self.local_parameters = None 15 | 16 | def localUpdate(self, localEpoch, localBatchSize, Net, lossFun, opti, global_parameters): 17 | Net.load_state_dict(global_parameters, strict=True) 18 | self.train_dl = DataLoader(self.train_ds, batch_size=localBatchSize, shuffle=True) 19 | for epoch in range(localEpoch): 20 | for data, label in self.train_dl: 21 | data, label = data.to(self.dev), label.to(self.dev) 22 | preds = Net(data) 23 | loss = lossFun(preds, label) 24 | loss.backward() 25 | opti.step() 
26 | opti.zero_grad() 27 | 28 | return Net.state_dict() 29 | 30 | def local_val(self): 31 | pass 32 | 33 | 34 | class ClientsGroup(object): 35 | def __init__(self, dataSetName, isIID, numOfClients, dev): 36 | self.data_set_name = dataSetName 37 | self.is_iid = isIID 38 | self.num_of_clients = numOfClients 39 | self.dev = dev 40 | self.clients_set = {} 41 | self.central_data = None 42 | 43 | self.test_data_loader = None 44 | 45 | self.dataSetBalanceAllocation() 46 | 47 | def centralTrain(self, localEpoch, localBatchSize, Net, lossFun, opti, global_parameters): 48 | Net.load_state_dict(global_parameters, strict=True) 49 | 50 | for epoch in range(localEpoch): 51 | for data, label in self.central_data: 52 | data, label = data.to(self.dev), label.to(self.dev) 53 | preds = Net(data) 54 | loss = lossFun(preds, label) 55 | loss.backward() 56 | opti.step() 57 | opti.zero_grad() 58 | 59 | return copy.deepcopy(Net.state_dict()) 60 | 61 | 62 | def dataSetBalanceAllocation(self): 63 | mnistDataSet = GetDataSet(self.data_set_name, self.is_iid) 64 | 65 | test_data = torch.tensor(mnistDataSet.test_data) 66 | test_label = torch.argmax(torch.tensor(mnistDataSet.test_label), dim=1) 67 | #****test change**** 68 | # self.test_data_loader = DataLoader(TensorDataset( test_data, test_label), batch_size=100, shuffle=False) 69 | 70 | if self.central_data is None: 71 | order = np.arange(test_data.shape[0]) 72 | np.random.shuffle(order) 73 | self.central_data = DataLoader(TensorDataset(test_data[order[0:100]], test_label[order[0:100]]), batch_size=100, shuffle=True) 74 | 75 | self.test_data_loader = DataLoader(TensorDataset(test_data,test_label), batch_size=100, shuffle=False) 76 | train_data = mnistDataSet.train_data 77 | train_label = mnistDataSet.train_label 78 | 79 | shard_size = mnistDataSet.train_data_size // self.num_of_clients // 2 80 | shards_id = np.random.permutation(mnistDataSet.train_data_size // shard_size) 81 | for i in range(self.num_of_clients): 82 | shards_id1 = shards_id[i 
* 2] 83 | shards_id2 = shards_id[i * 2 + 1] 84 | data_shards1 = train_data[shards_id1 * shard_size: shards_id1 * shard_size + shard_size] 85 | data_shards2 = train_data[shards_id2 * shard_size: shards_id2 * shard_size + shard_size] 86 | label_shards1 = train_label[shards_id1 * shard_size: shards_id1 * shard_size + shard_size] 87 | label_shards2 = train_label[shards_id2 * shard_size: shards_id2 * shard_size + shard_size] 88 | local_data, local_label = np.vstack((data_shards1, data_shards2)), np.vstack((label_shards1, label_shards2)) 89 | local_label = np.argmax(local_label, axis=1) 90 | someone = client(TensorDataset(torch.tensor(local_data), torch.tensor(local_label)), self.dev) 91 | self.clients_set['client{}'.format(i)] = someone 92 | 93 | if __name__=="__main__": 94 | MyClients = ClientsGroup('mnist', True, 100, 1) 95 | print(MyClients.clients_set['client10'].train_ds[0:100]) 96 | print(MyClients.clients_set['client11'].train_ds[400:500]) 97 | 98 | 99 | -------------------------------------------------------------------------------- /src/getData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gzip 3 | import os 4 | import platform 5 | import pickle 6 | 7 | 8 | class GetDataSet(object): 9 | def __init__(self, dataSetName, isIID): 10 | self.name = dataSetName 11 | self.train_data = None 12 | self.train_label = None 13 | self.train_data_size = None 14 | self.test_data = None 15 | self.test_label = None 16 | self.test_data_size = None 17 | 18 | self._index_in_train_epoch = 0 19 | 20 | if self.name == 'mnist': 21 | self.mnistDataSetConstruct(isIID) 22 | else: 23 | pass 24 | 25 | 26 | def mnistDataSetConstruct(self, isIID): 27 | data_dir = r'../data/MNIST' 28 | # data_dir = r'./data/MNIST' 29 | train_images_path = os.path.join(data_dir, 'train-images-idx3-ubyte.gz') 30 | train_labels_path = os.path.join(data_dir, 'train-labels-idx1-ubyte.gz') 31 | test_images_path = os.path.join(data_dir, 
't10k-images-idx3-ubyte.gz') 32 | test_labels_path = os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz') 33 | train_images = extract_images(train_images_path) 34 | train_labels = extract_labels(train_labels_path) 35 | test_images = extract_images(test_images_path) 36 | test_labels = extract_labels(test_labels_path) 37 | 38 | assert train_images.shape[0] == train_labels.shape[0] 39 | assert test_images.shape[0] == test_labels.shape[0] 40 | 41 | self.train_data_size = train_images.shape[0] 42 | self.test_data_size = test_images.shape[0] 43 | 44 | assert train_images.shape[3] == 1 45 | assert test_images.shape[3] == 1 46 | train_images = train_images.reshape(train_images.shape[0], train_images.shape[1] * train_images.shape[2]) 47 | test_images = test_images.reshape(test_images.shape[0], test_images.shape[1] * test_images.shape[2]) 48 | 49 | train_images = train_images.astype(np.float32) 50 | train_images = np.multiply(train_images, 1.0 / 255.0) 51 | test_images = test_images.astype(np.float32) 52 | test_images = np.multiply(test_images, 1.0 / 255.0) 53 | 54 | if isIID: 55 | order = np.arange(self.train_data_size) 56 | np.random.shuffle(order) 57 | self.train_data = train_images[order] 58 | self.train_label = train_labels[order] 59 | else: 60 | labels = np.argmax(train_labels, axis=1) 61 | order = np.argsort(labels) 62 | self.train_data = train_images[order] 63 | self.train_label = train_labels[order] 64 | 65 | 66 | 67 | self.test_data = test_images 68 | self.test_label = test_labels 69 | 70 | 71 | def _read32(bytestream): 72 | dt = np.dtype(np.uint32).newbyteorder('>') 73 | return np.frombuffer(bytestream.read(4), dtype=dt)[0] 74 | 75 | 76 | def extract_images(filename): 77 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 78 | print('Extracting', filename) 79 | with gzip.open(filename) as bytestream: 80 | magic = _read32(bytestream) 81 | if magic != 2051: 82 | raise ValueError( 83 | 'Invalid magic number %d in MNIST image file: %s' % 84 | 
(magic, filename)) 85 | num_images = _read32(bytestream) 86 | rows = _read32(bytestream) 87 | cols = _read32(bytestream) 88 | buf = bytestream.read(rows * cols * num_images) 89 | data = np.frombuffer(buf, dtype=np.uint8) 90 | data = data.reshape(num_images, rows, cols, 1) 91 | return data 92 | 93 | 94 | def dense_to_one_hot(labels_dense, num_classes=10): 95 | """Convert class labels from scalars to one-hot vectors.""" 96 | num_labels = labels_dense.shape[0] 97 | index_offset = np.arange(num_labels) * num_classes 98 | labels_one_hot = np.zeros((num_labels, num_classes)) 99 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 100 | return labels_one_hot 101 | 102 | 103 | def extract_labels(filename): 104 | """Extract the labels into a 1D uint8 numpy array [index].""" 105 | print('Extracting', filename) 106 | with gzip.open(filename) as bytestream: 107 | magic = _read32(bytestream) 108 | if magic != 2049: 109 | raise ValueError( 110 | 'Invalid magic number %d in MNIST label file: %s' % 111 | (magic, filename)) 112 | num_items = _read32(bytestream) 113 | buf = bytestream.read(num_items) 114 | labels = np.frombuffer(buf, dtype=np.uint8) 115 | return dense_to_one_hot(labels) 116 | 117 | 118 | if __name__=="__main__": 119 | 'test data set' 120 | mnistDataSet = GetDataSet('mnist', True) # test NON-IID 121 | if type(mnistDataSet.train_data) is np.ndarray and type(mnistDataSet.test_data) is np.ndarray and \ 122 | type(mnistDataSet.train_label) is np.ndarray and type(mnistDataSet.test_label) is np.ndarray: 123 | print('the type of data is numpy ndarray') 124 | else: 125 | print('the type of data is not numpy ndarray') 126 | print('the shape of the train data set is {}'.format(mnistDataSet.train_data.shape)) 127 | print('the shape of the test data set is {}'.format(mnistDataSet.test_data.shape)) 128 | print(mnistDataSet.train_label[0:100], mnistDataSet.train_label[11000:11100]) 129 | 130 | 
-------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import optim 8 | from Models import Mnist_2NN, Mnist_CNN 9 | from clients import ClientsGroup, client 10 | 11 | 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="FedAvg") 13 | parser.add_argument('-g', '--gpu', type=str, default='0', help='gpu id to use(e.g. 0,1,2,3)') 14 | parser.add_argument('-nc', '--num_of_clients', type=int, default=100, help='numer of the clients') 15 | parser.add_argument('-cf', '--cfraction', type=float, default=0.1, help='C fraction, 0 means 1 client, 1 means total clients') 16 | parser.add_argument('-E', '--epoch', type=int, default=5, help='local train epoch') 17 | parser.add_argument('-B', '--batchsize', type=int, default=10, help='local train batch size') 18 | parser.add_argument('-mn', '--model_name', type=str, default='mnist_2nn', help='the model to train') 19 | parser.add_argument('-lr', "--learning_rate", type=float, default=0.01, help="learning rate, \ 20 | use value from origin paper as default") 21 | parser.add_argument('-vf', "--val_freq", type=int, default=5, help="model validation frequency(of communications)") 22 | parser.add_argument('-sf', '--save_freq', type=int, default=20, help='global model save frequency(of communication)') 23 | parser.add_argument('-ncomm', '--num_comm', type=int, default=1000, help='number of communications') 24 | parser.add_argument('-sp', '--save_path', type=str, default='./checkpoints', help='the saving path of checkpoints') 25 | parser.add_argument('-iid', '--IID', type=int, default=0, help='the way to allocate data to clients') 26 | 27 | 28 | def test_mkdir(path): 29 | if not os.path.isdir(path): 30 | os.mkdir(path) 
31 | 32 | 33 | if __name__=="__main__": 34 | args = parser.parse_args() 35 | args = args.__dict__ 36 | 37 | test_mkdir(args['save_path']) 38 | 39 | os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu'] 40 | dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 41 | 42 | net = None 43 | if args['model_name'] == 'mnist_2nn': 44 | net = Mnist_2NN() 45 | elif args['model_name'] == 'mnist_cnn': 46 | net = Mnist_CNN() 47 | 48 | if torch.cuda.device_count() > 1: 49 | print("Let's use", torch.cuda.device_count(), "GPUs!") 50 | net = torch.nn.DataParallel(net) 51 | net = net.to(dev) 52 | 53 | loss_func = F.cross_entropy 54 | opti = optim.SGD(net.parameters(), lr=args['learning_rate']) 55 | 56 | myClients = ClientsGroup('mnist', args['IID'], args['num_of_clients'], dev) 57 | testDataLoader = myClients.test_data_loader 58 | 59 | num_in_comm = int(max(args['num_of_clients'] * args['cfraction'], 1)) 60 | 61 | global_parameters = {} 62 | for key, var in net.state_dict().items(): 63 | global_parameters[key] = var.clone() 64 | 65 | for i in range(args['num_comm']): 66 | print("communicate round {}".format(i+1)) 67 | 68 | order = np.random.permutation(args['num_of_clients']) 69 | clients_in_comm = ['client{}'.format(i) for i in order[0:num_in_comm]] 70 | 71 | sum_parameters = None 72 | for client in tqdm(clients_in_comm): 73 | local_parameters = myClients.clients_set[client].localUpdate(args['epoch'], args['batchsize'], net, 74 | loss_func, opti, global_parameters) 75 | if sum_parameters is None: 76 | sum_parameters = {} 77 | for key, var in local_parameters.items(): 78 | sum_parameters[key] = var.clone() 79 | else: 80 | for var in sum_parameters: 81 | sum_parameters[var] = sum_parameters[var] + local_parameters[var] 82 | 83 | for var in global_parameters: 84 | global_parameters[var] = (sum_parameters[var] / num_in_comm) 85 | 86 | with torch.no_grad(): 87 | if (i + 1) % args['val_freq'] == 0: 88 | net.load_state_dict(global_parameters, strict=True) 89 | 
sum_accu = 0 90 | num = 0 91 | for data, label in testDataLoader: 92 | data, label = data.to(dev), label.to(dev) 93 | preds = net(data) 94 | preds = torch.argmax(preds, dim=1) 95 | sum_accu += (preds == label).float().mean() 96 | num += 1 97 | print('accuracy: {}'.format(sum_accu / num)) 98 | 99 | if (i + 1) % args['save_freq'] == 0: 100 | torch.save(net, os.path.join(args['save_path'], 101 | '{}_num_comm{}_E{}_B{}_lr{}_num_clients{}_cf{}'.format(args['model_name'], 102 | i, args['epoch'], 103 | args['batchsize'], 104 | args['learning_rate'], 105 | args['num_of_clients'], 106 | args['cfraction']))) 107 | 108 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | def cos(a,b): 6 | res = np.sum(a*b.T)/(np.sqrt(np.sum(a * a.T)) * np.sqrt(np.sum(b * b.T))) 7 | return res 8 | 9 | def norm_clip(v, v_clipped): 10 | nparr1 = np.array([]) 11 | nparr2 = np.array([]) 12 | for key, var in v.state_dict().items(): 13 | nplist = var.cpu().numpy() 14 | nplist = nplist.ravel() 15 | nparr1 = np.append(nparr1, nplist) 16 | for key, var in v_clipped.state_dict().items(): 17 | nplist = var.cpu().numpy() 18 | nplist = nplist.ravel() 19 | nparr2 = np.append(nparr2, nplist) 20 | vnum = np.linalg.norm(nparr1, ord=None, axis=None, keepdims=False) 21 | return vnum / np.linalg.norm(nparr2, ord=None, axis=None, keepdims=False) 22 | 23 | def cosScore(net1, net2): 24 | nparr1 = np.array([]) 25 | nparr2 = np.array([]) 26 | for key, var in net1.state_dict().items(): 27 | nplist = var.cpu().numpy() 28 | nplist = nplist.ravel() 29 | nparr1 = np.append(nparr1, nplist) 30 | for key, var in net2.state_dict().items(): 31 | nplist = var.cpu().numpy() 32 | nplist = nplist.ravel() 33 | nparr2 = np.append(nparr2, nplist) 34 | 35 | return cos(nparr1, nparr2) 36 | # 
tor_arr=torch.from_numpy(np_arr) 37 | # tor2numpy=tor_arr.numpy() 38 | # dict = './checkpoints/mnist_2nn_num_comm19_E5_B10_lr0.01_num_clients100_cf0.1' 39 | # model1 = torch.load(dict) 40 | # model2 = torch.load('./checkpoints/mnist_2nn_num_comm299_E5_B10_lr0.01_num_clients100_cf0.1') 41 | # print(norm_clip(model1,model2)) 42 | acc_list =[0.18149993896484376, 0.556499900817871, 0.6716998291015625, 0.7579998016357422, 0.8413002014160156, 0.8471998596191406, 0.9054999542236328, 0.8280001831054687, 0.8501999664306641, 0.9144998931884766, 0.8991999816894531, 0.9091998291015625, 0.8923001098632812, 0.8963998413085937, 0.9419001007080078, 0.9487999725341797, 0.9246000671386718, 0.9383002471923828, 0.9163002777099609, 0.9444000244140625, 0.9511000061035156, 0.95010009765625, 0.9483001708984375, 0.9571001434326172, 0.9567000579833984, 0.9513001251220703, 0.9598001098632812, 0.9613999176025391, 0.9667999267578125, 0.938499755859375, 0.9564000701904297, 0.9660997772216797] 43 | 44 | plt.plot(acc_list) 45 | plt.show() --------------------------------------------------------------------------------