├── README.md ├── main_mnist_cfedsi.py ├── main_mnist_fedavg.py ├── main_mnist_fedewc.py ├── main_mnist_fedprox.py ├── main_mnist_fedsi.py └── net_fewc.py /README.md: -------------------------------------------------------------------------------- 1 | # Communication-efficient-federated-continual-learning 2 | Communication-efficient federated continual learning for distributed learning system with Non-IID data 3 | 4 | # Cite 5 | Zhang Z, Zhang Y, Guo D, et al. Communication-efficient federated continual learning for distributed learning system with Non-IID data[J]. Science China Information Sciences, 2023, 66(2): 122102. 6 | 7 | https://www.sciengine.com/SCIS/article;JSESSIONID=b3309e6c-de8f-4372-9c8f-5cc269f0cbad?doi=10.1007/s11432-020-3419-4&scroll= 8 | 9 | -------------------------------------------------------------------------------- /main_mnist_cfedsi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/1 9:15 3 | # @Author : zhao 4 | # @File : main_mnist_cfedsi.py 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.autograd import Variable 9 | from torch import autograd 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms 12 | import pandas as pd 13 | import numpy as np 14 | from sklearn.utils import shuffle 15 | from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support 16 | from sklearn.preprocessing import MinMaxScaler 17 | from collections import Iterable # < py38 18 | import copy 19 | from net_fewc import CNNMnist 20 | import logging 21 | import gzip 22 | import os 23 | import time 24 | import argparse 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # convert a list of list to a list [[],[],[]]->[,,] 29 | def flatten(items): 30 | """Yield items from any nested iterable; see Reference.""" 31 | for x in items: 32 | if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 33 | for sub_x in flatten(x): 34 | yield sub_x 35 | else: 36 | yield x 37 | 38 | 39 | class DealDataset(Dataset): 40 | """ 41 | 读取数据、初始化数据 42 | """ 43 | def __init__(self, folder, data_name, label_name,transform=None): 44 | (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式 45 | self.train_set = train_set 46 | self.train_labels = train_labels 47 | self.transform = transform 48 | 49 | def __getitem__(self, index): 50 | 51 | img, target = self.train_set[index], int(self.train_labels[index]) 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.train_set) 58 | 59 | def load_data(data_folder, data_name, label_name): 60 | """ 61 | data_folder: 文件目录 62 | data_name: 数据文件名 63 | label_name:标签数据文件名 64 | """ 65 | with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据 66 | y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 67 | 68 | with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath: 69 | x_train = np.frombuffer( 70 | imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 71 | return (x_train, y_train) 72 | 73 | 74 | class DatasetSplit(Dataset): 75 | def __init__(self, dataset, idxs): 76 | self.dataset = dataset 77 | self.idxs = list(idxs) 78 | 79 | def __len__(self): 80 | return len(self.idxs) 81 | 82 | def __getitem__(self, item): 83 | image, label = self.dataset[self.idxs[item]] 84 | return image, label 85 | 86 | 87 | def iid(dataset, num_users): 88 | """ 89 | Sample I.I.D. client data from dataset 90 | :param dataset: 91 | :param num_users: 92 | :return: dict of image index 93 | """ 94 | num_items = int(len(dataset) / num_users) 95 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 96 | for i in range(num_users): 97 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 98 | replace=False)) # Generates random samples from all_idexs,return a array with size of num_items 99 | all_idxs = list(set(all_idxs) - dict_users[i]) 100 | return dict_users 101 | 102 | def mnist_noniid6(dataset, num_users): 103 | """ 104 | Sample non-I.I.D client data from MNIST dataset 105 | :param dataset: 106 | :param num_users: 107 | :return: 108 | """ 109 | num_shards, num_imgs = 60, 1000 110 | idx_shard = [i for i in range(num_shards)] 111 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 112 | idxs = np.arange(num_shards*num_imgs) 113 | labels = dataset.train_labels#.numpy() 114 | 115 | # sort labels 116 | idxs_labels = np.vstack((idxs, labels)) 117 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 118 | idxs = idxs_labels[0,:] 119 | 120 | # divide and assign 121 | for i in range(num_users): 122 | rand_set = set(np.random.choice(idx_shard, 6, replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 126 | return dict_users 127 | 128 | def mnist_noniid5(dataset, num_users): 129 | """ 130 | Sample non-I.I.D client data from MNIST dataset 131 | :param dataset: 132 | :param num_users: 133 | :return: 134 | """ 135 | num_shards, num_imgs = 50, 1200 136 | idx_shard = [i for i in range(num_shards)] 137 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 138 | idxs = np.arange(num_shards*num_imgs) 139 | labels = dataset.train_labels#.numpy() 140 | 141 | # sort labels 142 | idxs_labels = np.vstack((idxs, labels)) 143 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 144 | idxs = idxs_labels[0,:] 145 | 146 | # divide and assign 147 | for i in range(num_users): 148 | rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 149 | idx_shard = list(set(idx_shard) - rand_set) 150 | for rand in rand_set: 151 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 152 | return dict_users 153 | 154 | def mnist_noniid4(dataset, num_users): 155 | """ 156 | Sample non-I.I.D client data from MNIST dataset 157 | :param dataset: 158 | :param num_users: 159 | :return: 160 | """ 161 | num_shards, num_imgs = 40, 1500 162 | idx_shard = [i for i in range(num_shards)] 163 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 164 | idxs = np.arange(num_shards*num_imgs) 165 | labels = dataset.train_labels#.numpy() 166 | 167 | # sort labels 168 | idxs_labels = np.vstack((idxs, labels)) 169 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 170 | idxs = idxs_labels[0,:] 171 | 172 | # divide and assign 173 | for i in range(num_users): 174 | rand_set = set(np.random.choice(idx_shard, 4, replace=False)) 175 | idx_shard = list(set(idx_shard) - rand_set) 176 | for rand in rand_set: 177 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 178 | return dict_users 179 | 180 | def mnist_noniid3(dataset, num_users): 181 | """ 182 | Sample non-I.I.D client data from MNIST dataset 183 | :param dataset: 184 | :param num_users: 185 | :return: 186 | """ 187 | num_shards, num_imgs = 30, 2000 188 | idx_shard = [i for i in range(num_shards)] 189 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 190 | idxs = np.arange(num_shards*num_imgs) 191 | labels = dataset.train_labels#.numpy() 192 | 193 | # sort labels 194 | idxs_labels = np.vstack((idxs, labels)) 195 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 196 | idxs = idxs_labels[0,:] 197 | 198 | # divide and assign 199 | for i in range(num_users): 200 | rand_set = set(np.random.choice(idx_shard, 3, replace=False)) 201 | idx_shard = list(set(idx_shard) - rand_set) 202 | for rand in rand_set: 203 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 204 | return dict_users 205 | 206 | def mnist_noniid2(dataset, num_users): 207 | """ 208 | Sample non-I.I.D client data from MNIST dataset 209 | :param dataset: 210 | :param num_users: 211 | :return: 212 | """ 213 | num_shards, num_imgs = 20, 3000 214 | idx_shard = [i for i in range(num_shards)] 215 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 216 | idxs = np.arange(num_shards*num_imgs) 217 | labels = dataset.train_labels#.numpy() 218 | 219 | # sort labels 220 | idxs_labels = np.vstack((idxs, labels)) 221 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 222 | idxs = idxs_labels[0,:] 223 | 224 | # divide and assign 225 | for i in range(num_users): 226 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 227 | idx_shard = list(set(idx_shard) - rand_set) 228 | for rand in rand_set: 229 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 230 | return dict_users 231 | 232 | def mnist_noniid1(dataset, num_users): 233 | """ 234 | Sample non-I.I.D client data from MNIST dataset 235 | :param dataset: 236 | :param num_users: 237 | :return: 238 | """ 239 | num_shards, num_imgs = 10, 6000 240 | idx_shard = [i for i in range(num_shards)] 241 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 242 | idxs = np.arange(num_shards*num_imgs) 243 | labels = dataset.train_labels#.numpy() 244 | # sort labels 245 | idxs_labels = np.vstack((idxs, labels)) 246 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 247 | idxs = idxs_labels[0,:] 248 | # divide and assign 249 | for i in range(num_users): 250 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 251 | idx_shard = list(set(idx_shard) - rand_set) 252 | for rand in rand_set: 253 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 254 | return dict_users 255 | 256 | def test_img(net_g, datatest): 257 | net_g.eval() 258 | # testing 259 | test_loss = 0 260 | correct = 0 261 | data_pred = [] 262 | data_label = [] 263 | data_loader = DataLoader(datatest, batch_size=test_BatchSize, shuffle=True) 264 | l = len(data_loader) 265 | loss = torch.nn.CrossEntropyLoss() 266 | for idx, (data, target) in enumerate(data_loader): 267 | data, target = Variable(data).to(device), Variable(target).type(torch.LongTensor).to(device) 268 | # data, target = Variable(data), Variable(target).type(torch.LongTensor) 269 | log_probs = net_g(data) 270 | # sum up batch loss 271 | test_loss += loss(log_probs, target).item() 272 | # test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 273 | # get the index of the max log-probability 274 | y_pred = log_probs.data.detach().max(1, keepdim=True)[1] 275 | correct += y_pred.eq(target.data.detach().view_as(y_pred)).long().cpu().sum() 276 | data_pred.append(y_pred.cpu().detach().data.tolist()) 277 | data_label.append(target.cpu().detach().data.tolist()) 278 | list_data_label = list(flatten(data_label)) 279 | list_data_pred = list(flatten(data_pred)) 280 | all_report = precision_recall_fscore_support(list_data_label, list_data_pred, average='weighted') 281 | all_precision = all_report[0] 282 | all_recall = all_report[1] 283 | all_fscore = all_report[2] 284 | print('all_precision',all_precision,'all_recall',all_recall,'all_fscore',all_fscore) 285 | # print(classification_report(list_data_label, list_data_pred)) 286 | print(confusion_matrix(list_data_label, list_data_pred)) 287 | # print('test_loss', test_loss) 288 | test_loss /= len(data_loader.dataset) 289 | accuracy = 100.00 * correct / len(data_loader.dataset) 290 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}'.format( 291 | test_loss, correct, len(data_loader.dataset), accuracy)) 292 | # logging.info('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}\n'.format( 293 | # test_loss, correct, len(data_loader.dataset), accuracy)) 294 | return accuracy, test_loss 295 | 296 | 297 | def FedAvg(model, w_pre, w_update, error): 298 | w_update_avg = torch.div(torch.sum(torch.stack(w_update), dim=0), len(w_update)) 299 | # w_avg_vec = w_update_avg + w_pre 300 | # w_avg_vec = topk(vec=(w_update_avg + w_pre), k=int(rho * len(w_pre))) 301 | 302 | w_update_avg_spar, index = topk(vec=(w_update_avg), k=int(rho * len(w_pre))) 303 | w_avg_vec = w_pre + w_update_avg_spar 304 | # w_avg_vec = w_pre - w_update_avg_spar 305 | 306 | for i,s in enumerate(w_update): 307 | s0 = torch.zeros_like(s) 308 | s0[index] = s[index] ### s0 sparisified 309 | error[i] = error[i] + (s0 - s) 310 | 311 | w_avg = model.state_dict() 312 | vec = w_avg_vec.to(device) 313 | gradShapes, gradSizes = getGradShapes(Model=model) 314 | startPos = 0 315 | i = 0 316 | for k in w_avg.keys(): 317 | shape = gradShapes[i] 318 | size = gradSizes[i] 319 | i += 1 320 | # assert (size == np.prod(p.grad.data.size())) 321 | w_avg[k] = vec[startPos:startPos + size].reshape(shape) 322 | startPos += size 323 | return w_avg, error 324 | 325 | 326 | def getGradShapes(Model): 327 | """Return the shapes and sizes of the weight matrices""" 328 | gradShapes = [] 329 | gradSizes = [] 330 | for n, p in Model.named_parameters(): 331 | gradShapes.append(p.data.shape) 332 | gradSizes.append(np.prod(p.data.shape)) 333 | return gradShapes, gradSizes 334 | 335 | 336 | def getGradVec(w): 337 | """Return the gradient flattened to a vector""" 338 | gradVec = [] 339 | # flatten 340 | # for n, p in Model.named_parameters(): 341 | # # gradVec.append(torch.zeros_like(p.data.view(-1))) 342 | # gradVec.append(p.grad.data.view(-1).float()) 343 | for k in w.keys(): 344 | # gradVec.append(torch.zeros_like(p.data.view(-1))) 345 | gradVec.append(w[k].view(-1).float()) 346 | # concat into a single vector 347 | gradVec = torch.cat(gradVec) 348 | return gradVec 349 | 350 | 351 | def setGradVec(Model, vec): 352 | """Set the gradient to vec""" 353 | # put vec into p.grad.data 354 | vec = vec.to(device) 355 | gradShapes, gradSizes = getGradShapes(Model=Model) 356 | startPos = 0 357 | i = 0 358 | for n, p in Model.named_parameters(): 359 | shape = gradShapes[i] 360 | size = gradSizes[i] 361 | i += 1 362 | # assert (size == np.prod(p.grad.data.size())) 363 | p.grad.data.zero_() 364 | p.grad.data.add_(vec[startPos:startPos + size].reshape(shape)) 365 | startPos += size 366 | 367 | 368 | def topk(vec, k): 369 | """ Return the largest k elements (by magnitude) of vec""" 370 | ret = torch.zeros_like(vec) 371 | # on a gpu, sorting is faster than pytorch's topk method 372 | topkIndices = torch.sort(vec ** 2)[1][-k:] 373 | # _, topkIndices = torch.topk(vec**2, k) 374 | ret[topkIndices] = vec[topkIndices] 375 | return ret, topkIndices 376 | 377 | 378 | def quantize(x,s): 379 | compress_settings = {'n': s} 380 | # compress_settings.update(input_compress_settings) 381 | # assume that x is a torch tensor 382 | 383 | n = compress_settings['n'] 384 | # print('n:{}'.format(n)) 385 | x = x.float() 386 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 387 | 388 | sgn_x = ((x > 0).float() - 0.5) * 2 389 | 390 | p = torch.div(torch.abs(x), x_norm) 391 | renormalize_p = torch.mul(p, n) 392 | floor_p = torch.floor(renormalize_p) 393 | compare = torch.rand_like(floor_p) 394 | final_p = renormalize_p - floor_p 395 | margin = (compare < final_p).float() 396 | xi = (floor_p + margin) / n 397 | 398 | Tilde_x = x_norm * sgn_x * xi 399 | 400 | return Tilde_x 401 | 402 | 403 | def quantize_log(x): 404 | compress_settings = {'n': 16} 405 | # compress_settings.update(input_compress_settings) 406 | # assume that x is a torch tensor 407 | n = compress_settings['n'] 408 | # print('n:{}'.format(n)) 409 | x = x.float() 410 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 411 | sgn_x = ((x > 0).float() - 0.5) * 2 412 | p = torch.div(torch.abs(x), x_norm) 413 | lookup = torch.linspace(0, -10, n) 414 | log_p = torch.log2(p) 415 | round_index = [(torch.abs(lookup - k)).min(dim=0)[1] for k in log_p] 416 | round_p = [2 ** (lookup[i]) for i in round_index] 417 | round_p = torch.stack(round_p).to(device) 418 | # print('round_p',round_p) 419 | # print('x_norm',x_norm) 420 | 421 | Tilde_x = x_norm * round_p * sgn_x 422 | 423 | return Tilde_x 424 | 425 | 426 | def quantization_layer(sizes, x): 427 | q_x = torch.zeros_like(x) 428 | startPos = 0 429 | for i in sizes: 430 | q_x[startPos:startPos + i] = quantize(x[startPos:startPos + i]) 431 | # q_x[startPos:startPos + i] = quantize_log(x[startPos:startPos + i]) 432 | startPos += i 433 | return q_x 434 | 435 | 436 | def sparsity(fisher, w_update, w_prev, topkIndices,s): 437 | Shapes = [] 438 | Sizes = [] 439 | for j in fisher.keys(): 440 | Shapes.append(fisher[j].shape) 441 | Sizes.append(np.prod(fisher[j].shape)) 442 | # print('fisher sizes', Sizes) 443 | fisher_vector = getGradVec(fisher) 444 | fisher_vector_spar = torch.zeros_like(fisher_vector) 445 | fisher_vector_spar[topkIndices] = fisher_vector[topkIndices] 446 | # fisher_vector_spar_q = quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]), 447 | # x=fisher_vector_spar) 448 | fisher_vector_spar_q = quantize(fisher_vector_spar,s) 449 | model_vector_spar_q = w_update + w_prev 450 | # model_vector_spar_q = w_update - w_prev 451 | fisher_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 452 | model_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 453 | startPos = 0 454 | j = 0 455 | for k in fisher.keys(): 456 | shape = Shapes[j] 457 | size = Sizes[j] 458 | j += 1 459 | fisher_spar[k] = fisher_vector_spar_q[startPos:startPos + size].reshape(shape).double() 460 | model_spar[k] = model_vector_spar_q[startPos:startPos + size].reshape(shape).double() 461 | startPos += size 462 | return fisher_spar, model_spar 463 | 464 | def consolidate(Model, Weight, MEAN_pre, epsilon): 465 | OMEGA_current = {n: p.data.clone().zero_() for n, p in Model.named_parameters()} 466 | for n, p in Model.named_parameters(): 467 | p_current = p.detach().clone() 468 | p_change = p_current - MEAN_pre[n] 469 | # W[n].add_((p.grad**2) * torch.abs(p_change)) 470 | # OMEGA_add = W[n]/ (p_change ** 2 + epsilon) 471 | # W[n].add_(-p.grad * p_change) 472 | OMEGA_add = torch.max(Weight[n], Weight[n].clone().zero_()) / (p_change ** 2 + epsilon) 473 | # OMEGA_add = Weight[n] / (p_change ** 2 + epsilon) 474 | # OMEGA_current[n] = OMEGA_pre[n] + OMEGA_add 475 | OMEGA_current[n] = OMEGA_add 476 | MEAN_current = {n: p.data for n, p in Model.named_parameters()} 477 | return OMEGA_current, MEAN_current 478 | 479 | 480 | # FL + EWC 481 | if __name__ == '__main__': 482 | # logging.basicConfig(filename='./20200512_cicids_our_noniid1_E_1_T_1.log', level=logging.DEBUG) 483 | # logging.info('11111') 484 | parser = argparse.ArgumentParser() 485 | parser.add_argument('--E', type=int, nargs='?', default=5, help="local epoches") 486 | parser.add_argument('--rho', type=float, nargs='?', default=0.5, help="sparsification ratio") 487 | parser.add_argument('--s', type=int, nargs='?', default=32, help="quantization") 488 | args = parser.parse_args() 489 | epsilon = 0.0001 490 | rho = args.rho #0.5 491 | s = args.s ##32 492 | Lamda = 1.0 #0.5 493 | E = args.E #5 494 | T = 50 495 | ## FedAvg 496 | # rho = 1.0 497 | # Lamda = 0.0 498 | frac = 1.0 499 | num_clients = 10 500 | batch_size = 512 501 | test_BatchSize = 32 502 | 503 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 504 | #### MNIST 505 | dataset_train = DealDataset('./data/MNIST/raw', "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 506 | transform=trans_mnist) 507 | dataset_test = DealDataset('./data/MNIST/raw', "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", 508 | transform=trans_mnist) 509 | 510 | 511 | dict_clients = mnist_noniid1(dataset_train, num_users=num_clients) 512 | 513 | net_global = CNNMnist(lamda=Lamda).to(device) #.double() 514 | 515 | # for n, p in net_global.named_parameters(): 516 | # p.data.zero_() 517 | w_glob = net_global.state_dict() 518 | # print(w_glob) 519 | crit = torch.nn.CrossEntropyLoss() 520 | # optimizer = torch.optim.SGD(net_global.parameters(), lr=0.001, momentum=0.5) 521 | net_global.train() 522 | 523 | omega_current, mean_current = {}, {} 524 | for i in range(num_clients): 525 | omega_current[i] = {} 526 | mean_current[i] = {} 527 | error_compensation = {} 528 | 529 | for interation in range(T): 530 | w_vec_locals, loss_locals = [], [] 531 | # print('interationh',interation) 532 | weight_vec_pre = getGradVec(w_glob) 533 | for client in range(num_clients): 534 | net = copy.deepcopy(net_global).to(device) 535 | # crit = torch.nn.CrossEntropyLoss() 536 | net.train() 537 | opt_net = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) 538 | print('interation', interation, 'client', client) 539 | idx_traindataset = DatasetSplit(dataset_train, dict_clients[client]) 540 | ldr_train = DataLoader(idx_traindataset, batch_size=512, shuffle=True) 541 | dataset_size = len(ldr_train.dataset) 542 | epochs_per_task = E 543 | 544 | mean_pre = {n: p.clone().detach() for n, p in net.named_parameters()} 545 | W = {n: p.clone().detach().zero_() for n, p in net.named_parameters()} 546 | 547 | t0 = time.clock() 548 | for epoch in range(1, epochs_per_task + 1): 549 | correct = 0 550 | for batch_idx, (images, labels) in enumerate(ldr_train): 551 | old_par = {n: p.clone().detach() for n, p in net.named_parameters()} 552 | images, labels = Variable(images).to(device), Variable(labels).type(torch.LongTensor).to(device) 553 | net.zero_grad() 554 | scores = net(images) 555 | ce_loss = crit(scores, labels) 556 | grad_params = torch.autograd.grad(ce_loss, net.parameters(), create_graph=True) 557 | if interation == 0: 558 | ewc_loss = torch.FloatTensor([0.0]).to(device) #torch.DoubleTensor 559 | else: 560 | other_clients = list(set([i for i in range(num_clients)])-set([client])) 561 | losses = {} 562 | sum_ewc_loss = 0 563 | for i in other_clients: 564 | losses[i] = [] 565 | for n, p in net.named_parameters(): 566 | mean, omega = Variable(mean_current[i][n]), Variable(omega_current[i][n]) 567 | losses[i].append((omega * (p - mean) ** 2).sum()) 568 | sum_ewc_loss += sum(losses[i]) 569 | ewc_loss = net.lamda * sum_ewc_loss 570 | loss = ce_loss + ewc_loss.double() # + reg_loss.double() 571 | pred = scores.max(1)[1] 572 | correct += pred.eq(labels.data.view_as(pred)).cpu().sum() 573 | loss.backward() 574 | opt_net.step() 575 | 576 | j = 0 577 | for n, p in net.named_parameters(): 578 | W[n] -= (grad_params[j].clone().detach()) * (p.detach() - old_par[n]) 579 | j += 1 580 | 581 | # if (interation != 0) and (client == 0): 582 | # print('1', sum(losses1), sum(lossnon), '2', sum(losses2), '3', sum(losses3), '4', sum(losses4)) 583 | Accuracy = 100. * correct.type(torch.FloatTensor) / dataset_size 584 | print('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 585 | # logging.info('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 586 | # print(classification_report(labels.cpu().data.view_as(pred.cpu()), pred.cpu())) 587 | 588 | # w_locals.append(copy.deepcopy(net.state_dict())) 589 | w = net.state_dict() 590 | weight_vec_current = getGradVec(w) 591 | K = int(rho * len(weight_vec_current)) 592 | print('sparsity k=', K) 593 | 594 | if interation == 0: 595 | error_compensation[client] = torch.zeros_like(weight_vec_current) 596 | weight_update, Topkindices = topk(vec=(weight_vec_current - weight_vec_pre + error_compensation[client]),k=K) 597 | weight_update_q = quantize(weight_update,s) 598 | error_compensation[client] = (weight_update_q - (weight_vec_current - weight_vec_pre + error_compensation[client])) # + error_compensation0 599 | omega_current_00, mean_current_00 = consolidate(Model=net, Weight=W, MEAN_pre=mean_pre, epsilon=epsilon) 600 | omega_current[client], mean_current[client] = sparsity(fisher=omega_current_00, w_update=weight_update_q,w_prev=weight_vec_pre, topkIndices=Topkindices,s=s) 601 | w_vec_locals.append(weight_update_q) 602 | t1 = time.clock() 603 | print('client:\t', client, 'trainingtime:\t', str(t1 - t0)) 604 | # weight_update, _ = topk(vec=(weight_vec_current - weight_vec_pre + error_compensation), k=K) 605 | # weight_update_q = quantize(weight_update) 606 | # # weight_update_q=quantize_log(weight_update) 607 | # # weight_update_q=quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]),x=weight_update) 608 | # error_compensation = error_compensation + ( 609 | # weight_vec_current - weight_vec_pre + error_compensation - weight_update_q) 610 | # w_vec_locals.append(weight_update_q) 611 | 612 | # weight_update, _ = topk(vec=(weight_vec_current - weight_vec_pre), k=K) 613 | # # weight_update_q=quantize(weight_update) 614 | # # weight_update_q=quantize_log(weight_update) 615 | # # weight_update_q=quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]),x=weight_update) 616 | # w_vec_locals.append(weight_update) 617 | 618 | # w_glob = FedAvg(model=copy.deepcopy(net_global), w_pre=weight_vec_pre, w_update=w_vec_locals) 619 | w_glob,error_compensation = FedAvg(model=copy.deepcopy(net_global), w_pre=weight_vec_pre, w_update=w_vec_locals, error=error_compensation) 620 | 621 | # print(w_glob) 622 | 623 | # w_glob = FedAvg(w_locals) 624 | # copy weight to net_glob 625 | net_global.load_state_dict(w_glob) 626 | # net_global.load_state_dict(w_glob) 627 | net_global.eval() 628 | acc_test, loss_test = test_img(net_global, dataset_test) 629 | print("Testing accuracy: {:.2f}".format(acc_test)) 630 | 631 | model_dict = net_global.state_dict() # 自己的模型参数变量 632 | test_dict = {k: w_glob[k] for k in w_glob.keys() if k in model_dict} # 去除一些不需要的参数 633 | model_dict.update(test_dict) # 参数更新 634 | net_global.load_state_dict(model_dict) # 加载 635 | 636 | # for n, p in net_global.named_parameters(): 637 | # p = w_glob[n] 638 | 639 | # net_global.load_state_dict(w_glob) 640 | net_global.eval() 641 | acc_test, loss_test = test_img(net_global, dataset_test) 642 | print("Testing accuracy: {:.2f}".format(acc_test)) 643 | -------------------------------------------------------------------------------- /main_mnist_fedavg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/1 19:29 3 | # @Author : zhao 4 | # @File : main_mnist_fedavg.py 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.autograd import Variable 9 | from torch import autograd 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms 12 | import pandas as pd 13 | import numpy as np 14 | from sklearn.utils import shuffle 15 | from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support 16 | from sklearn.preprocessing import MinMaxScaler 17 | from collections import Iterable # < py38 18 | import copy 19 | from net_fewc import CNNMnist 20 | import logging 21 | import gzip 22 | import os 23 | import time 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | # convert a list of list to a list [[],[],[]]->[,,] 28 | def flatten(items): 29 | """Yield items from any nested iterable; see Reference.""" 30 | for x in items: 31 | if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 32 | for sub_x in flatten(x): 33 | yield sub_x 34 | else: 35 | yield x 36 | 37 | 38 | class DealDataset(Dataset): 39 | """ 40 | 读取数据、初始化数据 41 | """ 42 | def __init__(self, folder, data_name, label_name,transform=None): 43 | (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式 44 | self.train_set = train_set 45 | self.train_labels = train_labels 46 | self.transform = transform 47 | 48 | def __getitem__(self, index): 49 | 50 | img, target = self.train_set[index], int(self.train_labels[index]) 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | return img, target 54 | 55 | def __len__(self): 56 | return len(self.train_set) 57 | 58 | def load_data(data_folder, data_name, label_name): 59 | """ 60 | data_folder: 文件目录 61 | data_name: 数据文件名 62 | label_name:标签数据文件名 63 | """ 64 | with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据 65 | y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 66 | 67 | with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath: 68 | x_train = np.frombuffer( 69 | imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 70 | return (x_train, y_train) 71 | 72 | 73 | class DatasetSplit(Dataset): 74 | def __init__(self, dataset, idxs): 75 | self.dataset = dataset 76 | self.idxs = list(idxs) 77 | 78 | def __len__(self): 79 | return len(self.idxs) 80 | 81 | def __getitem__(self, item): 82 | image, label = self.dataset[self.idxs[item]] 83 | return image, label 84 | 85 | 86 | def iid(dataset, num_users): 87 | """ 88 | Sample I.I.D. client data from dataset 89 | :param dataset: 90 | :param num_users: 91 | :return: dict of image index 92 | """ 93 | num_items = int(len(dataset) / num_users) 94 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 95 | for i in range(num_users): 96 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 97 | replace=False)) # Generates random samples from all_idexs,return a array with size of num_items 98 | all_idxs = list(set(all_idxs) - dict_users[i]) 99 | return dict_users 100 | 101 | def mnist_noniid6(dataset, num_users): 102 | """ 103 | Sample non-I.I.D client data from MNIST dataset 104 | :param dataset: 105 | :param num_users: 106 | :return: 107 | """ 108 | num_shards, num_imgs = 60, 1000 109 | idx_shard = [i for i in range(num_shards)] 110 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 111 | idxs = np.arange(num_shards*num_imgs) 112 | labels = dataset.train_labels#.numpy() 113 | 114 | # sort labels 115 | idxs_labels = np.vstack((idxs, labels)) 116 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 117 | idxs = idxs_labels[0,:] 118 | 119 | # divide and assign 120 | for i in range(num_users): 121 | rand_set = set(np.random.choice(idx_shard, 6, replace=False)) 122 | idx_shard = list(set(idx_shard) - rand_set) 123 | for rand in rand_set: 124 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 125 | return dict_users 126 | 127 | def mnist_noniid5(dataset, num_users): 128 | """ 129 | Sample non-I.I.D client data from MNIST dataset 130 | :param dataset: 131 | :param num_users: 132 | :return: 133 | """ 134 | num_shards, num_imgs = 50, 1200 135 | idx_shard = [i for i in range(num_shards)] 136 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 137 | idxs = np.arange(num_shards*num_imgs) 138 | labels = dataset.train_labels#.numpy() 139 | 140 | # sort labels 141 | idxs_labels = np.vstack((idxs, labels)) 142 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 143 | idxs = idxs_labels[0,:] 144 | 145 | # divide and assign 146 | for i in range(num_users): 147 | rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 148 | idx_shard = list(set(idx_shard) - rand_set) 149 | for rand in rand_set: 150 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 151 | return dict_users 152 | 153 | def mnist_noniid4(dataset, num_users): 154 | """ 155 | Sample non-I.I.D client data from MNIST dataset 156 | :param dataset: 157 | :param num_users: 158 | :return: 159 | """ 160 | num_shards, num_imgs = 40, 1500 161 | idx_shard = [i for i in range(num_shards)] 162 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 163 | idxs = np.arange(num_shards*num_imgs) 164 | labels = dataset.train_labels#.numpy() 165 | 166 | # sort labels 167 | idxs_labels = np.vstack((idxs, labels)) 168 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 169 | idxs = idxs_labels[0,:] 170 | 171 | # divide and assign 172 | for i in range(num_users): 173 | rand_set = set(np.random.choice(idx_shard, 4, replace=False)) 174 | idx_shard = list(set(idx_shard) - rand_set) 175 | for rand in rand_set: 176 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 177 | return dict_users 178 | 179 | def mnist_noniid3(dataset, num_users): 180 | """ 181 | Sample non-I.I.D client data from MNIST dataset 182 | :param dataset: 183 | :param num_users: 184 | :return: 185 | """ 186 | num_shards, num_imgs = 30, 2000 187 | idx_shard = [i for i in range(num_shards)] 188 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 189 | idxs = np.arange(num_shards*num_imgs) 190 | labels = dataset.train_labels#.numpy() 191 | 192 | # sort labels 193 | idxs_labels = np.vstack((idxs, labels)) 194 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 195 | idxs = idxs_labels[0,:] 196 | 197 | # divide and assign 198 | for i in range(num_users): 199 | rand_set = set(np.random.choice(idx_shard, 3, replace=False)) 200 | idx_shard = list(set(idx_shard) - rand_set) 201 | for rand in rand_set: 202 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 203 | return dict_users 204 | 205 | def mnist_noniid2(dataset, num_users): 206 | """ 207 | Sample non-I.I.D client data from MNIST dataset 208 | :param dataset: 209 | :param num_users: 210 | :return: 211 | """ 212 | num_shards, num_imgs = 20, 3000 213 | idx_shard = [i for i in range(num_shards)] 214 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 215 | idxs = np.arange(num_shards*num_imgs) 216 | labels = dataset.train_labels#.numpy() 217 | 218 | # sort labels 219 | idxs_labels = np.vstack((idxs, labels)) 220 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 221 | idxs = idxs_labels[0,:] 222 | 223 | # divide and assign 224 | for i in range(num_users): 225 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 226 | idx_shard = list(set(idx_shard) - rand_set) 227 | for rand in rand_set: 228 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 229 | return dict_users 230 | 231 | 232 | def mnist_noniid1(dataset, num_users): 233 | """ 234 | Sample non-I.I.D client data from MNIST dataset 235 | :param dataset: 236 | :param num_users: 237 | :return: 238 | """ 239 | num_shards, num_imgs = 10, 6000 240 | idx_shard = [i for i in range(num_shards)] 241 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 242 | idxs = np.arange(num_shards*num_imgs) 243 | labels = dataset.train_labels#.numpy() 244 | # sort labels 245 | idxs_labels = np.vstack((idxs, labels)) 246 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 247 | idxs = idxs_labels[0,:] 248 | # divide and assign 249 | for i in range(num_users): 250 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 251 | idx_shard = list(set(idx_shard) - rand_set) 252 | for rand in rand_set: 253 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 254 | return dict_users 255 | 256 | 257 | def test_img(net_g, datatest): 258 | net_g.eval() 259 | # testing 260 | test_loss = 0 261 | correct = 0 262 | data_pred = [] 263 | data_label = [] 264 | data_loader = DataLoader(datatest, batch_size=test_BatchSize, shuffle=True) 265 | l = len(data_loader) 266 | loss = torch.nn.CrossEntropyLoss() 267 | for idx, (data, target) in enumerate(data_loader): 268 | data, target = Variable(data).to(device), Variable(target).type(torch.LongTensor).to(device) 269 | # data, target = Variable(data), Variable(target).type(torch.LongTensor) 270 | log_probs = net_g(data) 271 | # sum up batch loss 272 | test_loss += loss(log_probs, target).item() 273 | # test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 274 | # get the index of the max log-probability 275 | y_pred = log_probs.data.detach().max(1, keepdim=True)[1] 276 | correct += y_pred.eq(target.data.detach().view_as(y_pred)).long().cpu().sum() 277 | data_pred.append(y_pred.cpu().detach().data.tolist()) 278 | data_label.append(target.cpu().detach().data.tolist()) 279 | list_data_label = list(flatten(data_label)) 280 | list_data_pred = list(flatten(data_pred)) 281 | all_report = precision_recall_fscore_support(list_data_label, list_data_pred, average='weighted') 282 | all_precision = all_report[0] 283 | all_recall = all_report[1] 284 | all_fscore = all_report[2] 285 | print('all_precision',all_precision,'all_recall',all_recall,'all_fscore',all_fscore) 286 | # print(classification_report(list_data_label, list_data_pred)) 287 | print(confusion_matrix(list_data_label, list_data_pred)) 288 | # print('test_loss', test_loss) 289 | test_loss /= len(data_loader.dataset) 290 | accuracy = 100.00 * correct / len(data_loader.dataset) 291 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}'.format( 292 | test_loss, correct, len(data_loader.dataset), accuracy)) 293 | # logging.info('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}\n'.format( 294 | # test_loss, correct, len(data_loader.dataset), accuracy)) 295 | return accuracy, test_loss 296 | 297 | 298 | def FedAvg(w): 299 | w_avg = copy.deepcopy(w[0]) 300 | for k in w_avg.keys(): 301 | for i in range(1, len(w)): 302 | w_avg[k] += w[i][k] 303 | w_avg[k] = torch.div(w_avg[k], len(w)) 304 | return w_avg 305 | 306 | 307 | def getGradShapes(Model): 308 | """Return the shapes and sizes of the weight matrices""" 309 | gradShapes = [] 310 | gradSizes = [] 311 | for n, p in Model.named_parameters(): 312 | gradShapes.append(p.data.shape) 313 | gradSizes.append(np.prod(p.data.shape)) 314 | return gradShapes, gradSizes 315 | 316 | 317 | def getGradVec(w): 318 | """Return the gradient flattened to a vector""" 319 | gradVec = [] 320 | # flatten 321 | # for n, p in Model.named_parameters(): 322 | # # gradVec.append(torch.zeros_like(p.data.view(-1))) 323 | # gradVec.append(p.grad.data.view(-1).float()) 324 | for k in w.keys(): 325 | # gradVec.append(torch.zeros_like(p.data.view(-1))) 326 | gradVec.append(w[k].view(-1).float()) 327 | # concat into a single vector 328 | gradVec = torch.cat(gradVec) 329 | return gradVec 330 | 331 | 332 | def setGradVec(Model, vec): 333 | """Set the gradient to vec""" 334 | # put vec into p.grad.data 335 | vec = vec.to(device) 336 | gradShapes, gradSizes = getGradShapes(Model=Model) 337 | startPos = 0 338 | i = 0 339 | for n, p in Model.named_parameters(): 340 | shape = gradShapes[i] 341 | size = gradSizes[i] 342 | i += 1 343 | # assert (size == np.prod(p.grad.data.size())) 344 | p.grad.data.zero_() 345 | p.grad.data.add_(vec[startPos:startPos + size].reshape(shape)) 346 | startPos += size 347 | 348 | 349 | def topk(vec, k): 350 | """ Return the largest k elements (by magnitude) of vec""" 351 | ret = torch.zeros_like(vec) 352 | # on a gpu, sorting is faster than pytorch's topk method 353 | topkIndices = torch.sort(vec ** 2)[1][-k:] 354 | # _, topkIndices = torch.topk(vec**2, k) 355 | ret[topkIndices] = vec[topkIndices] 356 | return ret, topkIndices 357 | 358 | 359 | def quantize(x): 360 | compress_settings = {'n': 32} 361 | # compress_settings.update(input_compress_settings) 362 | # assume that x is a torch tensor 363 | 364 | n = compress_settings['n'] 365 | # print('n:{}'.format(n)) 366 | x = x.float() 367 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 368 | 369 | sgn_x = ((x > 0).float() - 0.5) * 2 370 | 371 | p = torch.div(torch.abs(x), x_norm) 372 | renormalize_p = torch.mul(p, n) 373 | floor_p = torch.floor(renormalize_p) 374 | compare = torch.rand_like(floor_p) 375 | final_p = renormalize_p - floor_p 376 | margin = (compare < final_p).float() 377 | xi = (floor_p + margin) / n 378 | 379 | Tilde_x = x_norm * sgn_x * xi 380 | 381 | return Tilde_x 382 | 383 | 384 | def quantize_log(x): 385 | compress_settings = {'n': 16} 386 | # compress_settings.update(input_compress_settings) 387 | # assume that x is a torch tensor 388 | n = compress_settings['n'] 389 | # print('n:{}'.format(n)) 390 | x = x.float() 391 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 392 | sgn_x = ((x > 0).float() - 0.5) * 2 393 | p = torch.div(torch.abs(x), x_norm) 394 | lookup = torch.linspace(0, -10, n) 395 | log_p = torch.log2(p) 396 | round_index = [(torch.abs(lookup - k)).min(dim=0)[1] for k in log_p] 397 | round_p = [2 ** (lookup[i]) for i in round_index] 398 | round_p = torch.stack(round_p).to(device) 399 | # print('round_p',round_p) 400 | # print('x_norm',x_norm) 401 | 402 | Tilde_x = x_norm * round_p * sgn_x 403 | 404 | return Tilde_x 405 | 406 | 407 | def quantization_layer(sizes, x): 408 | q_x = torch.zeros_like(x) 409 | startPos = 0 410 | for i in sizes: 411 | q_x[startPos:startPos + i] = quantize(x[startPos:startPos + i]) 412 | # q_x[startPos:startPos + i] = quantize_log(x[startPos:startPos + i]) 413 | startPos += i 414 | return q_x 415 | 416 | 417 | def sparsity(fisher, w_update, w_prev, topkIndices): 418 | Shapes = [] 419 | Sizes = [] 420 | for j in fisher.keys(): 421 | Shapes.append(fisher[j].shape) 422 | Sizes.append(np.prod(fisher[j].shape)) 423 | # print('fisher sizes', Sizes) 424 | fisher_vector = getGradVec(fisher) 425 | fisher_vector_spar = torch.zeros_like(fisher_vector) 426 | fisher_vector_spar[topkIndices] = fisher_vector[topkIndices] 427 | # fisher_vector_spar_q = quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]), 428 | # x=fisher_vector_spar) 429 | fisher_vector_spar_q = quantize(fisher_vector_spar) 430 | model_vector_spar_q = w_update + w_prev 431 | # model_vector_spar_q = w_update - w_prev 432 | fisher_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 433 | model_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 434 | startPos = 0 435 | j = 0 436 | for k in fisher.keys(): 437 | shape = Shapes[j] 438 | size = Sizes[j] 439 | j += 1 440 | fisher_spar[k] = fisher_vector_spar_q[startPos:startPos + size].reshape(shape).double() 441 | model_spar[k] = model_vector_spar_q[startPos:startPos + size].reshape(shape).double() 442 | startPos += size 443 | return fisher_spar, model_spar 444 | 445 | def consolidate(Model, Weight, MEAN_pre, epsilon): 446 | OMEGA_current = {n: p.data.clone().zero_() for n, p in Model.named_parameters()} 447 | for n, p in Model.named_parameters(): 448 | p_current = p.detach().clone() 449 | p_change = p_current - MEAN_pre[n] 450 | # W[n].add_((p.grad**2) * torch.abs(p_change)) 451 | # OMEGA_add = W[n]/ (p_change ** 2 + epsilon) 452 | # W[n].add_(-p.grad * p_change) 453 | OMEGA_add = torch.max(Weight[n], Weight[n].clone().zero_()) / (p_change ** 2 + epsilon) 454 | # OMEGA_add = Weight[n] / (p_change ** 2 + epsilon) 455 | # OMEGA_current[n] = OMEGA_pre[n] + OMEGA_add 456 | OMEGA_current[n] = OMEGA_add 457 | MEAN_current = {n: p.data for n, p in Model.named_parameters()} 458 | return OMEGA_current, MEAN_current 459 | 460 | 461 | # FL + EWC 462 | if __name__ == '__main__': 463 | # logging.basicConfig(filename='./20200512_cicids_our_noniid1_E_1_T_1.log', level=logging.DEBUG) 464 | # logging.info('11111') 465 | epsilon = 0.0001 466 | rho = 1.0 #0.5 467 | Lamda = 1.0 #0.5 468 | E = 5 469 | T = 50 #50 470 | ## FedAvg 471 | # rho = 1.0 472 | # Lamda = 0.0 473 | frac = 1.0 474 | num_clients = 10 475 | batch_size = 512 476 | test_BatchSize = 32 477 | 478 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 479 | ### MNIST 480 | dataset_train = DealDataset('./data/MNIST/raw', "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 481 | transform=trans_mnist) 482 | dataset_test = DealDataset('./data/MNIST/raw', "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", 483 | transform=trans_mnist) 484 | 485 | 486 | 487 | dict_clients = mnist_noniid1(dataset_train, num_users=num_clients) 488 | 489 | net_global = CNNMnist(lamda=Lamda).to(device) #.double() 490 | 491 | 492 | # for n, p in net_global.named_parameters(): 493 | # p.data.zero_() 494 | w_glob = net_global.state_dict() 495 | # print(w_glob) 496 | crit = torch.nn.CrossEntropyLoss() 497 | net_global.train() 498 | 499 | for interation in range(T): 500 | w_locals, loss_locals = [], [] 501 | # print('interationh',interation) 502 | for client in range(num_clients): 503 | # net = CNN(N_class=3,lamda=10000).double().to(device) 504 | net = copy.deepcopy(net_global).to(device) 505 | # crit = torch.nn.CrossEntropyLoss() 506 | net.train() 507 | opt_net = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) 508 | 509 | print('interation', interation, 'client', client) 510 | idx_traindataset = DatasetSplit(dataset_train, dict_clients[client]) 511 | ldr_train = DataLoader(idx_traindataset, batch_size=512, shuffle=True) 512 | dataset_size = len(ldr_train.dataset) 513 | epochs_per_task = E 514 | mean_pre = {n: p.clone().detach() for n, p in net.named_parameters()} 515 | t0 = time.clock() 516 | for epoch in range(1, epochs_per_task + 1): 517 | correct = 0 518 | for batch_idx, (images, labels) in enumerate(ldr_train): 519 | images, labels = Variable(images).to(device), Variable(labels).type(torch.LongTensor).to(device) 520 | net.zero_grad() 521 | scores = net(images) 522 | ce_loss = crit(scores, labels) 523 | loss = ce_loss 524 | pred = scores.max(1)[1] 525 | correct += pred.eq(labels.data.view_as(pred)).cpu().sum() 526 | loss.backward() 527 | opt_net.step() 528 | 529 | Accuracy = 100. * correct.type(torch.FloatTensor) / dataset_size 530 | # print('Train Epoch:{}\tLoss:{:.4f}\tProx_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),prox_loss.item(),ce_loss.item(),Accuracy)) 531 | print('Train Epoch:{}\tLoss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ce_loss.item(),Accuracy)) 532 | # print(classification_report(labels.cpu().data.view_as(pred.cpu()), pred.cpu())) 533 | 534 | w_locals.append(copy.deepcopy(net.state_dict())) 535 | t1 = time.clock() 536 | print('client:\t', client, 'trainingtime:\t', str(t1 - t0)) 537 | w_glob = FedAvg(w_locals) 538 | net_global.load_state_dict(w_glob) 539 | net_global.eval() 540 | acc_test, loss_test = test_img(net_global, dataset_test) 541 | print("Testing accuracy: {:.2f}".format(acc_test)) 542 | 543 | model_dict = net_global.state_dict() # 自己的模型参数变量 544 | test_dict = {k: w_glob[k] for k in w_glob.keys() if k in model_dict} # 去除一些不需要的参数 545 | model_dict.update(test_dict) # 参数更新 546 | net_global.load_state_dict(model_dict) # 加载 547 | 548 | # for n, p in net_global.named_parameters(): 549 | # p = w_glob[n] 550 | 551 | # net_global.load_state_dict(w_glob) 552 | net_global.eval() 553 | acc_test, loss_test = test_img(net_global, dataset_test) 554 | print("Testing accuracy: {:.2f}".format(acc_test)) 555 | -------------------------------------------------------------------------------- /main_mnist_fedewc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/1 20:20 3 | # @Author : zhao 4 | # @File : main_mnist_fedewc.py 5 | 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch.autograd import Variable 10 | from torch import autograd 11 | import torch.nn.functional as F 12 | from torchvision import datasets, transforms 13 | import pandas as pd 14 | import numpy as np 15 | from sklearn.utils import shuffle 16 | from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support 17 | from sklearn.preprocessing import MinMaxScaler 18 | from collections import Iterable # < py38 19 | import copy 20 | from net_fewc import CNNMnist 21 | import logging 22 | import gzip 23 | import os 24 | import time 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # convert a list of list to a list [[],[],[]]->[,,] 29 | def flatten(items): 30 | """Yield items from any nested iterable; see Reference.""" 31 | for x in items: 32 | if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 33 | for sub_x in flatten(x): 34 | yield sub_x 35 | else: 36 | yield x 37 | 38 | 39 | class DealDataset(Dataset): 40 | """ 41 | 读取数据、初始化数据 42 | """ 43 | def __init__(self, folder, data_name, label_name,transform=None): 44 | (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式 45 | self.train_set = train_set 46 | self.train_labels = train_labels 47 | self.transform = transform 48 | 49 | def __getitem__(self, index): 50 | 51 | img, target = self.train_set[index], int(self.train_labels[index]) 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.train_set) 58 | 59 | def load_data(data_folder, data_name, label_name): 60 | """ 61 | data_folder: 文件目录 62 | data_name: 数据文件名 63 | label_name:标签数据文件名 64 | """ 65 | with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据 66 | y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 67 | 68 | with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath: 69 | x_train = np.frombuffer( 70 | imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 71 | return (x_train, y_train) 72 | 73 | 74 | class DatasetSplit(Dataset): 75 | def __init__(self, dataset, idxs): 76 | self.dataset = dataset 77 | self.idxs = list(idxs) 78 | 79 | def __len__(self): 80 | return len(self.idxs) 81 | 82 | def __getitem__(self, item): 83 | image, label = self.dataset[self.idxs[item]] 84 | return image, label 85 | 86 | 87 | def iid(dataset, num_users): 88 | """ 89 | Sample I.I.D. client data from dataset 90 | :param dataset: 91 | :param num_users: 92 | :return: dict of image index 93 | """ 94 | num_items = int(len(dataset) / num_users) 95 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 96 | for i in range(num_users): 97 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 98 | replace=False)) # Generates random samples from all_idexs,return a array with size of num_items 99 | all_idxs = list(set(all_idxs) - dict_users[i]) 100 | return dict_users 101 | 102 | def mnist_noniid6(dataset, num_users): 103 | """ 104 | Sample non-I.I.D client data from MNIST dataset 105 | :param dataset: 106 | :param num_users: 107 | :return: 108 | """ 109 | num_shards, num_imgs = 60, 1000 110 | idx_shard = [i for i in range(num_shards)] 111 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 112 | idxs = np.arange(num_shards*num_imgs) 113 | labels = dataset.train_labels#.numpy() 114 | 115 | # sort labels 116 | idxs_labels = np.vstack((idxs, labels)) 117 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 118 | idxs = idxs_labels[0,:] 119 | 120 | # divide and assign 121 | for i in range(num_users): 122 | rand_set = set(np.random.choice(idx_shard, 6, replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 126 | return dict_users 127 | 128 | def mnist_noniid5(dataset, num_users): 129 | """ 130 | Sample non-I.I.D client data from MNIST dataset 131 | :param dataset: 132 | :param num_users: 133 | :return: 134 | """ 135 | num_shards, num_imgs = 50, 1200 136 | idx_shard = [i for i in range(num_shards)] 137 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 138 | idxs = np.arange(num_shards*num_imgs) 139 | labels = dataset.train_labels#.numpy() 140 | 141 | # sort labels 142 | idxs_labels = np.vstack((idxs, labels)) 143 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 144 | idxs = idxs_labels[0,:] 145 | 146 | # divide and assign 147 | for i in range(num_users): 148 | rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 149 | idx_shard = list(set(idx_shard) - rand_set) 150 | for rand in rand_set: 151 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 152 | return dict_users 153 | 154 | def mnist_noniid4(dataset, num_users): 155 | """ 156 | Sample non-I.I.D client data from MNIST dataset 157 | :param dataset: 158 | :param num_users: 159 | :return: 160 | """ 161 | num_shards, num_imgs = 40, 1500 162 | idx_shard = [i for i in range(num_shards)] 163 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 164 | idxs = np.arange(num_shards*num_imgs) 165 | labels = dataset.train_labels#.numpy() 166 | 167 | # sort labels 168 | idxs_labels = np.vstack((idxs, labels)) 169 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 170 | idxs = idxs_labels[0,:] 171 | 172 | # divide and assign 173 | for i in range(num_users): 174 | rand_set = set(np.random.choice(idx_shard, 4, replace=False)) 175 | idx_shard = list(set(idx_shard) - rand_set) 176 | for rand in rand_set: 177 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 178 | return dict_users 179 | 180 | def mnist_noniid3(dataset, num_users): 181 | """ 182 | Sample non-I.I.D client data from MNIST dataset 183 | :param dataset: 184 | :param num_users: 185 | :return: 186 | """ 187 | num_shards, num_imgs = 30, 2000 188 | idx_shard = [i for i in range(num_shards)] 189 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 190 | idxs = np.arange(num_shards*num_imgs) 191 | labels = dataset.train_labels#.numpy() 192 | 193 | # sort labels 194 | idxs_labels = np.vstack((idxs, labels)) 195 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 196 | idxs = idxs_labels[0,:] 197 | 198 | # divide and assign 199 | for i in range(num_users): 200 | rand_set = set(np.random.choice(idx_shard, 3, replace=False)) 201 | idx_shard = list(set(idx_shard) - rand_set) 202 | for rand in rand_set: 203 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 204 | return dict_users 205 | 206 | def mnist_noniid2(dataset, num_users): 207 | """ 208 | Sample non-I.I.D client data from MNIST dataset 209 | :param dataset: 210 | :param num_users: 211 | :return: 212 | """ 213 | num_shards, num_imgs = 20, 3000 214 | idx_shard = [i for i in range(num_shards)] 215 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 216 | idxs = np.arange(num_shards*num_imgs) 217 | labels = dataset.train_labels#.numpy() 218 | 219 | # sort labels 220 | idxs_labels = np.vstack((idxs, labels)) 221 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 222 | idxs = idxs_labels[0,:] 223 | 224 | # divide and assign 225 | for i in range(num_users): 226 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 227 | idx_shard = list(set(idx_shard) - rand_set) 228 | for rand in rand_set: 229 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 230 | return dict_users 231 | 232 | def mnist_noniid1(dataset, num_users): 233 | """ 234 | Sample non-I.I.D client data from MNIST dataset 235 | :param dataset: 236 | :param num_users: 237 | :return: 238 | """ 239 | num_shards, num_imgs = 10, 6000 240 | idx_shard = [i for i in range(num_shards)] 241 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 242 | idxs = np.arange(num_shards*num_imgs) 243 | labels = dataset.train_labels#.numpy() 244 | # sort labels 245 | idxs_labels = np.vstack((idxs, labels)) 246 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 247 | idxs = idxs_labels[0,:] 248 | # divide and assign 249 | for i in range(num_users): 250 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 251 | idx_shard = list(set(idx_shard) - rand_set) 252 | for rand in rand_set: 253 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 254 | return dict_users 255 | 256 | def test_img(net_g, datatest): 257 | net_g.eval() 258 | # testing 259 | test_loss = 0 260 | correct = 0 261 | data_pred = [] 262 | data_label = [] 263 | data_loader = DataLoader(datatest, batch_size=test_BatchSize, shuffle=True) 264 | l = len(data_loader) 265 | loss = torch.nn.CrossEntropyLoss() 266 | for idx, (data, target) in enumerate(data_loader): 267 | data, target = Variable(data).to(device), Variable(target).type(torch.LongTensor).to(device) 268 | # data, target = Variable(data), Variable(target).type(torch.LongTensor) 269 | log_probs = net_g(data) 270 | # sum up batch loss 271 | test_loss += loss(log_probs, target).item() 272 | # test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 273 | # get the index of the max log-probability 274 | y_pred = log_probs.data.detach().max(1, keepdim=True)[1] 275 | correct += y_pred.eq(target.data.detach().view_as(y_pred)).long().cpu().sum() 276 | data_pred.append(y_pred.cpu().detach().data.tolist()) 277 | data_label.append(target.cpu().detach().data.tolist()) 278 | list_data_label = list(flatten(data_label)) 279 | list_data_pred = list(flatten(data_pred)) 280 | all_report = precision_recall_fscore_support(list_data_label, list_data_pred, average='weighted') 281 | all_precision = all_report[0] 282 | all_recall = all_report[1] 283 | all_fscore = all_report[2] 284 | print('all_precision',all_precision,'all_recall',all_recall,'all_fscore',all_fscore) 285 | # print(classification_report(list_data_label, list_data_pred)) 286 | print(confusion_matrix(list_data_label, list_data_pred)) 287 | # print('test_loss', test_loss) 288 | test_loss /= len(data_loader.dataset) 289 | accuracy = 100.00 * correct / len(data_loader.dataset) 290 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}'.format( 291 | test_loss, correct, len(data_loader.dataset), accuracy)) 292 | # logging.info('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}\n'.format( 293 | # test_loss, correct, len(data_loader.dataset), accuracy)) 294 | return accuracy, test_loss 295 | 296 | 297 | def FedAvg(w): 298 | w_avg = copy.deepcopy(w[0]) 299 | for k in w_avg.keys(): 300 | for i in range(1, len(w)): 301 | w_avg[k] += w[i][k] 302 | w_avg[k] = torch.div(w_avg[k], len(w)) 303 | return w_avg 304 | 305 | 306 | def getGradShapes(Model): 307 | """Return the shapes and sizes of the weight matrices""" 308 | gradShapes = [] 309 | gradSizes = [] 310 | for n, p in Model.named_parameters(): 311 | gradShapes.append(p.data.shape) 312 | gradSizes.append(np.prod(p.data.shape)) 313 | return gradShapes, gradSizes 314 | 315 | 316 | def getGradVec(w): 317 | """Return the gradient flattened to a vector""" 318 | gradVec = [] 319 | # flatten 320 | # for n, p in Model.named_parameters(): 321 | # # gradVec.append(torch.zeros_like(p.data.view(-1))) 322 | # gradVec.append(p.grad.data.view(-1).float()) 323 | for k in w.keys(): 324 | # gradVec.append(torch.zeros_like(p.data.view(-1))) 325 | gradVec.append(w[k].view(-1).float()) 326 | # concat into a single vector 327 | gradVec = torch.cat(gradVec) 328 | return gradVec 329 | 330 | 331 | def setGradVec(Model, vec): 332 | """Set the gradient to vec""" 333 | # put vec into p.grad.data 334 | vec = vec.to(device) 335 | gradShapes, gradSizes = getGradShapes(Model=Model) 336 | startPos = 0 337 | i = 0 338 | for n, p in Model.named_parameters(): 339 | shape = gradShapes[i] 340 | size = gradSizes[i] 341 | i += 1 342 | # assert (size == np.prod(p.grad.data.size())) 343 | p.grad.data.zero_() 344 | p.grad.data.add_(vec[startPos:startPos + size].reshape(shape)) 345 | startPos += size 346 | 347 | 348 | def topk(vec, k): 349 | """ Return the largest k elements (by magnitude) of vec""" 350 | ret = torch.zeros_like(vec) 351 | # on a gpu, sorting is faster than pytorch's topk method 352 | topkIndices = torch.sort(vec ** 2)[1][-k:] 353 | # _, topkIndices = torch.topk(vec**2, k) 354 | ret[topkIndices] = vec[topkIndices] 355 | return ret, topkIndices 356 | 357 | 358 | def quantize(x): 359 | compress_settings = {'n': 32} 360 | # compress_settings.update(input_compress_settings) 361 | # assume that x is a torch tensor 362 | 363 | n = compress_settings['n'] 364 | # print('n:{}'.format(n)) 365 | x = x.float() 366 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 367 | 368 | sgn_x = ((x > 0).float() - 0.5) * 2 369 | 370 | p = torch.div(torch.abs(x), x_norm) 371 | renormalize_p = torch.mul(p, n) 372 | floor_p = torch.floor(renormalize_p) 373 | compare = torch.rand_like(floor_p) 374 | final_p = renormalize_p - floor_p 375 | margin = (compare < final_p).float() 376 | xi = (floor_p + margin) / n 377 | 378 | Tilde_x = x_norm * sgn_x * xi 379 | 380 | return Tilde_x 381 | 382 | 383 | def quantize_log(x): 384 | compress_settings = {'n': 16} 385 | # compress_settings.update(input_compress_settings) 386 | # assume that x is a torch tensor 387 | n = compress_settings['n'] 388 | # print('n:{}'.format(n)) 389 | x = x.float() 390 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 391 | sgn_x = ((x > 0).float() - 0.5) * 2 392 | p = torch.div(torch.abs(x), x_norm) 393 | lookup = torch.linspace(0, -10, n) 394 | log_p = torch.log2(p) 395 | round_index = [(torch.abs(lookup - k)).min(dim=0)[1] for k in log_p] 396 | round_p = [2 ** (lookup[i]) for i in round_index] 397 | round_p = torch.stack(round_p).to(device) 398 | # print('round_p',round_p) 399 | # print('x_norm',x_norm) 400 | 401 | Tilde_x = x_norm * round_p * sgn_x 402 | 403 | return Tilde_x 404 | 405 | 406 | def quantization_layer(sizes, x): 407 | q_x = torch.zeros_like(x) 408 | startPos = 0 409 | for i in sizes: 410 | q_x[startPos:startPos + i] = quantize(x[startPos:startPos + i]) 411 | # q_x[startPos:startPos + i] = quantize_log(x[startPos:startPos + i]) 412 | startPos += i 413 | return q_x 414 | 415 | 416 | def sparsity(fisher, w_update, w_prev, topkIndices): 417 | Shapes = [] 418 | Sizes = [] 419 | for j in fisher.keys(): 420 | Shapes.append(fisher[j].shape) 421 | Sizes.append(np.prod(fisher[j].shape)) 422 | # print('fisher sizes', Sizes) 423 | fisher_vector = getGradVec(fisher) 424 | fisher_vector_spar = torch.zeros_like(fisher_vector) 425 | fisher_vector_spar[topkIndices] = fisher_vector[topkIndices] 426 | # fisher_vector_spar_q = quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]), 427 | # x=fisher_vector_spar) 428 | fisher_vector_spar_q = quantize(fisher_vector_spar) 429 | model_vector_spar_q = w_update + w_prev 430 | # model_vector_spar_q = w_update - w_prev 431 | fisher_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 432 | model_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 433 | startPos = 0 434 | j = 0 435 | for k in fisher.keys(): 436 | shape = Shapes[j] 437 | size = Sizes[j] 438 | j += 1 439 | fisher_spar[k] = fisher_vector_spar_q[startPos:startPos + size].reshape(shape).double() 440 | model_spar[k] = model_vector_spar_q[startPos:startPos + size].reshape(shape).double() 441 | startPos += size 442 | return fisher_spar, model_spar 443 | 444 | def consolidate(Model, dataset, sample_size, batch_size): 445 | # sample loglikelihoods from the dataset. 446 | data_loader = DataLoader(dataset, batch_size, shuffle=True) 447 | loglikelihoods = [] 448 | for x, y in data_loader: 449 | # x = x.view(batch_size, -1) 450 | x = Variable(x).to(device) 451 | y = Variable(y).type(torch.LongTensor).to(device) 452 | loglikelihoods.append( 453 | F.log_softmax(Model(x), dim=1)[range(batch_size), y.data] 454 | )# self(x) the model's output 455 | if len(loglikelihoods) >= sample_size // batch_size: 456 | break 457 | # estimate the fisher information of the parameters. 458 | # print('loglikelihoods',loglikelihoods1) 459 | loglikelihoods = torch.unbind(torch.cat(loglikelihoods)) # e.g. torch.unbind(torch.tensor([[1, 2, 3],[1, 2, 3]]) -> (tensor([1, 2, 3]), tensor([4, 5, 6])) 460 | # loglikelihoods = (tensor(1),tensor(2),tensor(3),tensor(1),tensor(2),tensor(3)) 461 | # print('loglikelihoods',loglikelihoods) 462 | loglikelihood_grads = zip(*[autograd.grad( 463 | l, Model.parameters(), 464 | retain_graph=(i < len(loglikelihoods)) 465 | ) for i, l in enumerate(loglikelihoods, 1)]) 466 | # print('loglikelihood_grads', loglikelihood_grads) 467 | loglikelihood_grads = [torch.stack(gs) for gs in loglikelihood_grads] 468 | # print('loglikelihood_grads', loglikelihood_grads) 469 | fisher_diagonals = [(g ** 2).mean(0) for g in loglikelihood_grads] 470 | param_names = [ 471 | n for n, p in Model.named_parameters() 472 | ] 473 | # return {n: f.detach() for n, f in zip(param_names, fisher_diagonals)} 474 | fisher = {n: f.detach() for n, f in zip(param_names, fisher_diagonals)} 475 | mean = {n: p.data for n, p in Model.named_parameters()} 476 | return fisher,mean 477 | 478 | 479 | # FL + EWC 480 | if __name__ == '__main__': 481 | # logging.basicConfig(filename='./20200512_cicids_our_noniid1_E_1_T_1.log', level=logging.DEBUG) 482 | # logging.info('11111') 483 | Lamda = 1.0 #0.5 484 | E = 5 485 | T = 50 486 | ## FedAvg 487 | # Lamda = 0.0 488 | frac = 1.0 489 | num_clients = 10 490 | batch_size = 512 491 | test_BatchSize = 32 492 | 493 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 494 | #### MNIST 495 | dataset_train = DealDataset('./data/MNIST/raw', "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 496 | transform=trans_mnist) 497 | dataset_test = DealDataset('./data/MNIST/raw', "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", 498 | transform=trans_mnist) 499 | 500 | 501 | dict_clients = mnist_noniid1(dataset_train, num_users=num_clients) 502 | 503 | net_global = CNNMnist(lamda=Lamda).to(device) #.double() 504 | 505 | # for n, p in net_global.named_parameters(): 506 | # p.data.zero_() 507 | w_glob = net_global.state_dict() 508 | # print(w_glob) 509 | crit = torch.nn.CrossEntropyLoss()#torch.DoubleTensor weight=torch.FloatTensor([1, 1.2, 1.2, 1.2, 3]).to(device) 510 | # optimizer = torch.optim.SGD(net_global.parameters(), lr=0.001, momentum=0.5) 511 | net_global.train() 512 | 513 | omega_current, mean_current = {}, {} 514 | for i in range(num_clients): 515 | omega_current[i] = {} 516 | mean_current[i] = {} 517 | 518 | for interation in range(T): 519 | w_locals, loss_locals = [], [] 520 | # print('interationh',interation) 521 | weight_vec_pre = getGradVec(w_glob) 522 | for client in range(num_clients): 523 | # net = CNN(N_class=3,lamda=10000).double().to(device) 524 | net = copy.deepcopy(net_global).to(device) 525 | # crit = torch.nn.CrossEntropyLoss() 526 | net.train() 527 | opt_net = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) 528 | 529 | 530 | print('interation', interation, 'client', client) 531 | idx_traindataset = DatasetSplit(dataset_train, dict_clients[client]) 532 | ldr_train = DataLoader(idx_traindataset, batch_size=512, shuffle=True) 533 | dataset_size = len(ldr_train.dataset) 534 | epochs_per_task = E 535 | 536 | mean_pre = {n: p.clone().detach() for n, p in net.named_parameters()} 537 | W = {n: p.clone().detach().zero_() for n, p in net.named_parameters()} 538 | 539 | t0 = time.clock() 540 | for epoch in range(1, epochs_per_task + 1): 541 | correct = 0 542 | for batch_idx, (images, labels) in enumerate(ldr_train): 543 | old_par = {n: p.clone().detach() for n, p in net.named_parameters()} 544 | images, labels = Variable(images).to(device), Variable(labels).type(torch.LongTensor).to(device) 545 | net.zero_grad() 546 | scores = net(images) 547 | ce_loss = crit(scores, labels) 548 | grad_params = torch.autograd.grad(ce_loss, net.parameters(), create_graph=True) 549 | if interation == 0: 550 | ewc_loss = torch.FloatTensor([0.0]).to(device) #torch.DoubleTensor 551 | else: 552 | other_clients = list(set([i for i in range(num_clients)])-set([client])) 553 | losses = {} 554 | sum_ewc_loss = 0 555 | for i in other_clients: 556 | losses[i] = [] 557 | for n, p in net.named_parameters(): 558 | mean, omega = Variable(mean_current[i][n]), Variable(omega_current[i][n]) 559 | losses[i].append((omega * (p - mean) ** 2).sum()) 560 | sum_ewc_loss += sum(losses[i]) 561 | ewc_loss = net.lamda * sum_ewc_loss 562 | loss = ce_loss + ewc_loss.double() # + reg_loss.double() 563 | pred = scores.max(1)[1] 564 | correct += pred.eq(labels.data.view_as(pred)).cpu().sum() 565 | loss.backward() 566 | opt_net.step() 567 | 568 | j = 0 569 | for n, p in net.named_parameters(): 570 | W[n] -= (grad_params[j].clone().detach()) * (p.detach() - old_par[n]) 571 | j += 1 572 | 573 | # if (interation != 0) and (client == 0): 574 | # print('1', sum(losses1), sum(lossnon), '2', sum(losses2), '3', sum(losses3), '4', sum(losses4)) 575 | Accuracy = 100. * correct.type(torch.FloatTensor) / dataset_size 576 | print('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 577 | # logging.info('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 578 | # print(classification_report(labels.cpu().data.view_as(pred.cpu()), pred.cpu())) 579 | 580 | w_locals.append(copy.deepcopy(net.state_dict())) 581 | omega_current[client], mean_current[client] = consolidate(Model=net, dataset=idx_traindataset, sample_size=1024,batch_size=batch_size) 582 | t1 = time.clock() 583 | print('client:\t', client, 'trainingtime:\t', str(t1 - t0)) 584 | w_glob = FedAvg(w_locals) 585 | net_global.load_state_dict(w_glob) 586 | net_global.eval() 587 | acc_test, loss_test = test_img(net_global, dataset_test) 588 | print("Testing accuracy: {:.2f}".format(acc_test)) 589 | 590 | model_dict = net_global.state_dict() # 自己的模型参数变量 591 | test_dict = {k: w_glob[k] for k in w_glob.keys() if k in model_dict} # 去除一些不需要的参数 592 | model_dict.update(test_dict) # 参数更新 593 | net_global.load_state_dict(model_dict) # 加载 594 | 595 | # for n, p in net_global.named_parameters(): 596 | # p = w_glob[n] 597 | 598 | # net_global.load_state_dict(w_glob) 599 | net_global.eval() 600 | acc_test, loss_test = test_img(net_global, dataset_test) 601 | print("Testing accuracy: {:.2f}".format(acc_test)) 602 | -------------------------------------------------------------------------------- /main_mnist_fedprox.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/2 22:45 3 | # @Author : zhao 4 | # @File : main_mnist_fedprox.py 5 | 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch.autograd import Variable 10 | from torch import autograd 11 | import torch.nn.functional as F 12 | from torchvision import datasets, transforms 13 | import pandas as pd 14 | import numpy as np 15 | from sklearn.utils import shuffle 16 | from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support 17 | from sklearn.preprocessing import MinMaxScaler 18 | from collections import Iterable # < py38 19 | import copy 20 | from net_fewc import CNNMnist 21 | import logging 22 | import gzip 23 | import os 24 | import time 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # convert a list of list to a list [[],[],[]]->[,,] 29 | def flatten(items): 30 | """Yield items from any nested iterable; see Reference.""" 31 | for x in items: 32 | if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 33 | for sub_x in flatten(x): 34 | yield sub_x 35 | else: 36 | yield x 37 | 38 | 39 | class DealDataset(Dataset): 40 | """ 41 | 读取数据、初始化数据 42 | """ 43 | def __init__(self, folder, data_name, label_name,transform=None): 44 | (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式 45 | self.train_set = train_set 46 | self.train_labels = train_labels 47 | self.transform = transform 48 | 49 | def __getitem__(self, index): 50 | 51 | img, target = self.train_set[index], int(self.train_labels[index]) 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.train_set) 58 | 59 | def load_data(data_folder, data_name, label_name): 60 | """ 61 | data_folder: 文件目录 62 | data_name: 数据文件名 63 | label_name:标签数据文件名 64 | """ 65 | with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据 66 | y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 67 | 68 | with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath: 69 | x_train = np.frombuffer( 70 | imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 71 | return (x_train, y_train) 72 | 73 | 74 | class DatasetSplit(Dataset): 75 | def __init__(self, dataset, idxs): 76 | self.dataset = dataset 77 | self.idxs = list(idxs) 78 | 79 | def __len__(self): 80 | return len(self.idxs) 81 | 82 | def __getitem__(self, item): 83 | image, label = self.dataset[self.idxs[item]] 84 | return image, label 85 | 86 | 87 | def iid(dataset, num_users): 88 | """ 89 | Sample I.I.D. client data from dataset 90 | :param dataset: 91 | :param num_users: 92 | :return: dict of image index 93 | """ 94 | num_items = int(len(dataset) / num_users) 95 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 96 | for i in range(num_users): 97 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 98 | replace=False)) # Generates random samples from all_idexs,return a array with size of num_items 99 | all_idxs = list(set(all_idxs) - dict_users[i]) 100 | return dict_users 101 | 102 | def mnist_noniid6(dataset, num_users): 103 | """ 104 | Sample non-I.I.D client data from MNIST dataset 105 | :param dataset: 106 | :param num_users: 107 | :return: 108 | """ 109 | num_shards, num_imgs = 60, 1000 110 | idx_shard = [i for i in range(num_shards)] 111 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 112 | idxs = np.arange(num_shards*num_imgs) 113 | labels = dataset.train_labels#.numpy() 114 | 115 | # sort labels 116 | idxs_labels = np.vstack((idxs, labels)) 117 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 118 | idxs = idxs_labels[0,:] 119 | 120 | # divide and assign 121 | for i in range(num_users): 122 | rand_set = set(np.random.choice(idx_shard, 6, replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 126 | return dict_users 127 | 128 | def mnist_noniid5(dataset, num_users): 129 | """ 130 | Sample non-I.I.D client data from MNIST dataset 131 | :param dataset: 132 | :param num_users: 133 | :return: 134 | """ 135 | num_shards, num_imgs = 50, 1200 136 | idx_shard = [i for i in range(num_shards)] 137 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 138 | idxs = np.arange(num_shards*num_imgs) 139 | labels = dataset.train_labels#.numpy() 140 | 141 | # sort labels 142 | idxs_labels = np.vstack((idxs, labels)) 143 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 144 | idxs = idxs_labels[0,:] 145 | 146 | # divide and assign 147 | for i in range(num_users): 148 | rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 149 | idx_shard = list(set(idx_shard) - rand_set) 150 | for rand in rand_set: 151 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 152 | return dict_users 153 | 154 | def mnist_noniid4(dataset, num_users): 155 | """ 156 | Sample non-I.I.D client data from MNIST dataset 157 | :param dataset: 158 | :param num_users: 159 | :return: 160 | """ 161 | num_shards, num_imgs = 40, 1500 162 | idx_shard = [i for i in range(num_shards)] 163 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 164 | idxs = np.arange(num_shards*num_imgs) 165 | labels = dataset.train_labels#.numpy() 166 | 167 | # sort labels 168 | idxs_labels = np.vstack((idxs, labels)) 169 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 170 | idxs = idxs_labels[0,:] 171 | 172 | # divide and assign 173 | for i in range(num_users): 174 | rand_set = set(np.random.choice(idx_shard, 4, replace=False)) 175 | idx_shard = list(set(idx_shard) - rand_set) 176 | for rand in rand_set: 177 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 178 | return dict_users 179 | 180 | def mnist_noniid3(dataset, num_users): 181 | """ 182 | Sample non-I.I.D client data from MNIST dataset 183 | :param dataset: 184 | :param num_users: 185 | :return: 186 | """ 187 | num_shards, num_imgs = 30, 2000 188 | idx_shard = [i for i in range(num_shards)] 189 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 190 | idxs = np.arange(num_shards*num_imgs) 191 | labels = dataset.train_labels#.numpy() 192 | 193 | # sort labels 194 | idxs_labels = np.vstack((idxs, labels)) 195 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 196 | idxs = idxs_labels[0,:] 197 | 198 | # divide and assign 199 | for i in range(num_users): 200 | rand_set = set(np.random.choice(idx_shard, 3, replace=False)) 201 | idx_shard = list(set(idx_shard) - rand_set) 202 | for rand in rand_set: 203 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 204 | return dict_users 205 | 206 | def mnist_noniid2(dataset, num_users): 207 | """ 208 | Sample non-I.I.D client data from MNIST dataset 209 | :param dataset: 210 | :param num_users: 211 | :return: 212 | """ 213 | num_shards, num_imgs = 20, 3000 214 | idx_shard = [i for i in range(num_shards)] 215 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 216 | idxs = np.arange(num_shards*num_imgs) 217 | labels = dataset.train_labels#.numpy() 218 | 219 | # sort labels 220 | idxs_labels = np.vstack((idxs, labels)) 221 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 222 | idxs = idxs_labels[0,:] 223 | 224 | # divide and assign 225 | for i in range(num_users): 226 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 227 | idx_shard = list(set(idx_shard) - rand_set) 228 | for rand in rand_set: 229 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 230 | return dict_users 231 | 232 | def mnist_noniid1(dataset, num_users): 233 | """ 234 | Sample non-I.I.D client data from MNIST dataset 235 | :param dataset: 236 | :param num_users: 237 | :return: 238 | """ 239 | num_shards, num_imgs = 10, 6000 240 | idx_shard = [i for i in range(num_shards)] 241 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 242 | idxs = np.arange(num_shards*num_imgs) 243 | labels = dataset.train_labels#.numpy() 244 | # sort labels 245 | idxs_labels = np.vstack((idxs, labels)) 246 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 247 | idxs = idxs_labels[0,:] 248 | # divide and assign 249 | for i in range(num_users): 250 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 251 | idx_shard = list(set(idx_shard) - rand_set) 252 | for rand in rand_set: 253 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 254 | return dict_users 255 | 256 | def test_img(net_g, datatest): 257 | net_g.eval() 258 | # testing 259 | test_loss = 0 260 | correct = 0 261 | data_pred = [] 262 | data_label = [] 263 | data_loader = DataLoader(datatest, batch_size=test_BatchSize, shuffle=True) 264 | l = len(data_loader) 265 | loss = torch.nn.CrossEntropyLoss() 266 | for idx, (data, target) in enumerate(data_loader): 267 | data, target = Variable(data).to(device), Variable(target).type(torch.LongTensor).to(device) 268 | # data, target = Variable(data), Variable(target).type(torch.LongTensor) 269 | log_probs = net_g(data) 270 | # sum up batch loss 271 | test_loss += loss(log_probs, target).item() 272 | # test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 273 | # get the index of the max log-probability 274 | y_pred = log_probs.data.detach().max(1, keepdim=True)[1] 275 | correct += y_pred.eq(target.data.detach().view_as(y_pred)).long().cpu().sum() 276 | data_pred.append(y_pred.cpu().detach().data.tolist()) 277 | data_label.append(target.cpu().detach().data.tolist()) 278 | list_data_label = list(flatten(data_label)) 279 | list_data_pred = list(flatten(data_pred)) 280 | all_report = precision_recall_fscore_support(list_data_label, list_data_pred, average='weighted') 281 | all_precision = all_report[0] 282 | all_recall = all_report[1] 283 | all_fscore = all_report[2] 284 | print('all_precision',all_precision,'all_recall',all_recall,'all_fscore',all_fscore) 285 | # print(classification_report(list_data_label, list_data_pred)) 286 | print(confusion_matrix(list_data_label, list_data_pred)) 287 | # print('test_loss', test_loss) 288 | test_loss /= len(data_loader.dataset) 289 | accuracy = 100.00 * correct / len(data_loader.dataset) 290 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}'.format( 291 | test_loss, correct, len(data_loader.dataset), accuracy)) 292 | # logging.info('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}\n'.format( 293 | # test_loss, correct, len(data_loader.dataset), accuracy)) 294 | return accuracy, test_loss 295 | 296 | 297 | def FedAvg(w): 298 | w_avg = copy.deepcopy(w[0]) 299 | for k in w_avg.keys(): 300 | for i in range(1, len(w)): 301 | w_avg[k] += w[i][k] 302 | w_avg[k] = torch.div(w_avg[k], len(w)) 303 | return w_avg 304 | 305 | 306 | def getGradShapes(Model): 307 | """Return the shapes and sizes of the weight matrices""" 308 | gradShapes = [] 309 | gradSizes = [] 310 | for n, p in Model.named_parameters(): 311 | gradShapes.append(p.data.shape) 312 | gradSizes.append(np.prod(p.data.shape)) 313 | return gradShapes, gradSizes 314 | 315 | 316 | def getGradVec(w): 317 | """Return the gradient flattened to a vector""" 318 | gradVec = [] 319 | # flatten 320 | # for n, p in Model.named_parameters(): 321 | # # gradVec.append(torch.zeros_like(p.data.view(-1))) 322 | # gradVec.append(p.grad.data.view(-1).float()) 323 | for k in w.keys(): 324 | # gradVec.append(torch.zeros_like(p.data.view(-1))) 325 | gradVec.append(w[k].view(-1).float()) 326 | # concat into a single vector 327 | gradVec = torch.cat(gradVec) 328 | return gradVec 329 | 330 | 331 | def setGradVec(Model, vec): 332 | """Set the gradient to vec""" 333 | # put vec into p.grad.data 334 | vec = vec.to(device) 335 | gradShapes, gradSizes = getGradShapes(Model=Model) 336 | startPos = 0 337 | i = 0 338 | for n, p in Model.named_parameters(): 339 | shape = gradShapes[i] 340 | size = gradSizes[i] 341 | i += 1 342 | # assert (size == np.prod(p.grad.data.size())) 343 | p.grad.data.zero_() 344 | p.grad.data.add_(vec[startPos:startPos + size].reshape(shape)) 345 | startPos += size 346 | 347 | 348 | def topk(vec, k): 349 | """ Return the largest k elements (by magnitude) of vec""" 350 | ret = torch.zeros_like(vec) 351 | # on a gpu, sorting is faster than pytorch's topk method 352 | topkIndices = torch.sort(vec ** 2)[1][-k:] 353 | # _, topkIndices = torch.topk(vec**2, k) 354 | ret[topkIndices] = vec[topkIndices] 355 | return ret, topkIndices 356 | 357 | 358 | def quantize(x): 359 | compress_settings = {'n': 32} 360 | # compress_settings.update(input_compress_settings) 361 | # assume that x is a torch tensor 362 | 363 | n = compress_settings['n'] 364 | # print('n:{}'.format(n)) 365 | x = x.float() 366 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 367 | 368 | sgn_x = ((x > 0).float() - 0.5) * 2 369 | 370 | p = torch.div(torch.abs(x), x_norm) 371 | renormalize_p = torch.mul(p, n) 372 | floor_p = torch.floor(renormalize_p) 373 | compare = torch.rand_like(floor_p) 374 | final_p = renormalize_p - floor_p 375 | margin = (compare < final_p).float() 376 | xi = (floor_p + margin) / n 377 | 378 | Tilde_x = x_norm * sgn_x * xi 379 | 380 | return Tilde_x 381 | 382 | 383 | def quantize_log(x): 384 | compress_settings = {'n': 16} 385 | # compress_settings.update(input_compress_settings) 386 | # assume that x is a torch tensor 387 | n = compress_settings['n'] 388 | # print('n:{}'.format(n)) 389 | x = x.float() 390 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 391 | sgn_x = ((x > 0).float() - 0.5) * 2 392 | p = torch.div(torch.abs(x), x_norm) 393 | lookup = torch.linspace(0, -10, n) 394 | log_p = torch.log2(p) 395 | round_index = [(torch.abs(lookup - k)).min(dim=0)[1] for k in log_p] 396 | round_p = [2 ** (lookup[i]) for i in round_index] 397 | round_p = torch.stack(round_p).to(device) 398 | # print('round_p',round_p) 399 | # print('x_norm',x_norm) 400 | 401 | Tilde_x = x_norm * round_p * sgn_x 402 | 403 | return Tilde_x 404 | 405 | 406 | def quantization_layer(sizes, x): 407 | q_x = torch.zeros_like(x) 408 | startPos = 0 409 | for i in sizes: 410 | q_x[startPos:startPos + i] = quantize(x[startPos:startPos + i]) 411 | # q_x[startPos:startPos + i] = quantize_log(x[startPos:startPos + i]) 412 | startPos += i 413 | return q_x 414 | 415 | 416 | def sparsity(fisher, w_update, w_prev, topkIndices): 417 | Shapes = [] 418 | Sizes = [] 419 | for j in fisher.keys(): 420 | Shapes.append(fisher[j].shape) 421 | Sizes.append(np.prod(fisher[j].shape)) 422 | # print('fisher sizes', Sizes) 423 | fisher_vector = getGradVec(fisher) 424 | fisher_vector_spar = torch.zeros_like(fisher_vector) 425 | fisher_vector_spar[topkIndices] = fisher_vector[topkIndices] 426 | # fisher_vector_spar_q = quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]), 427 | # x=fisher_vector_spar) 428 | fisher_vector_spar_q = quantize(fisher_vector_spar) 429 | model_vector_spar_q = w_update + w_prev 430 | # model_vector_spar_q = w_update - w_prev 431 | fisher_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 432 | model_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 433 | startPos = 0 434 | j = 0 435 | for k in fisher.keys(): 436 | shape = Shapes[j] 437 | size = Sizes[j] 438 | j += 1 439 | fisher_spar[k] = fisher_vector_spar_q[startPos:startPos + size].reshape(shape).double() 440 | model_spar[k] = model_vector_spar_q[startPos:startPos + size].reshape(shape).double() 441 | startPos += size 442 | return fisher_spar, model_spar 443 | 444 | 445 | 446 | # FL + EWC 447 | if __name__ == '__main__': 448 | # logging.basicConfig(filename='./20200512_cicids_our_noniid1_E_1_T_1.log', level=logging.DEBUG) 449 | # logging.info('11111') 450 | Lamda = 1.0 #0.5 451 | E = 5 452 | T = 50 453 | ## FedAvg 454 | # rho = 1.0 455 | # Lamda = 0.0 456 | frac = 1.0 457 | num_clients = 10 458 | batch_size = 512 459 | test_BatchSize = 32 460 | 461 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 462 | #### MNIST 463 | dataset_train = DealDataset('./data/MNIST/raw', "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 464 | transform=trans_mnist) 465 | dataset_test = DealDataset('./data/MNIST/raw', "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", 466 | transform=trans_mnist) 467 | 468 | 469 | dict_clients = mnist_noniid1(dataset_train, num_users=num_clients) 470 | 471 | net_global = CNNMnist(lamda=Lamda).to(device) #.double() 472 | 473 | # for n, p in net_global.named_parameters(): 474 | # p.data.zero_() 475 | w_glob = net_global.state_dict() 476 | # print(w_glob) 477 | crit = torch.nn.CrossEntropyLoss()#torch.DoubleTensor weight=torch.FloatTensor([1, 1.2, 1.2, 1.2, 3]).to(device) 478 | net_global.train() 479 | 480 | for interation in range(T): 481 | w_locals, loss_locals = [], [] 482 | # print('interationh',interation) 483 | for client in range(num_clients): 484 | # net = CNN(N_class=3,lamda=10000).double().to(device) 485 | net = copy.deepcopy(net_global).to(device) 486 | # crit = torch.nn.CrossEntropyLoss() 487 | net.train() 488 | opt_net = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) 489 | 490 | print('interation', interation, 'client', client) 491 | idx_traindataset = DatasetSplit(dataset_train, dict_clients[client]) 492 | ldr_train = DataLoader(idx_traindataset, batch_size=512, shuffle=True) 493 | dataset_size = len(ldr_train.dataset) 494 | epochs_per_task = E 495 | 496 | mean_pre = {n: p.clone().detach() for n, p in net.named_parameters()} 497 | t0 = time.clock() 498 | for epoch in range(1, epochs_per_task + 1): 499 | correct = 0 500 | for batch_idx, (images, labels) in enumerate(ldr_train): 501 | images, labels = Variable(images).to(device), Variable(labels).type(torch.LongTensor).to(device) 502 | net.zero_grad() 503 | scores = net(images) 504 | ce_loss = crit(scores, labels) 505 | if interation == 0: 506 | prox_loss = torch.DoubleTensor([0.0]).to(device) 507 | else: 508 | loss_prox = [] 509 | for n, p in net.named_parameters(): 510 | mean = Variable(mean_pre[n]) 511 | loss_prox.append(((p - mean) ** 2).sum()) 512 | sum_prox_loss = sum(loss_prox) 513 | prox_loss = net.lamda * sum_prox_loss 514 | loss = ce_loss + prox_loss # .double() 515 | 516 | pred = scores.max(1)[1] 517 | correct += pred.eq(labels.data.view_as(pred)).cpu().sum() 518 | loss.backward() 519 | opt_net.step() 520 | 521 | Accuracy = 100. * correct.type(torch.FloatTensor) / dataset_size 522 | print('Train Epoch:{}\tLoss:{:.4f}\tProx_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),prox_loss.item(),ce_loss.item(),Accuracy)) 523 | # print('Train Epoch:{}\tLoss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ce_loss.item(),Accuracy)) 524 | # print(classification_report(labels.cpu().data.view_as(pred.cpu()), pred.cpu())) 525 | 526 | w_locals.append(copy.deepcopy(net.state_dict())) 527 | t1 = time.clock() 528 | print('client:\t', client, 'trainingtime:\t', str(t1 - t0)) 529 | w_glob = FedAvg(w_locals) 530 | net_global.load_state_dict(w_glob) 531 | net_global.eval() 532 | acc_test, loss_test = test_img(net_global, dataset_test) 533 | print("Testing accuracy: {:.2f}".format(acc_test)) 534 | 535 | model_dict = net_global.state_dict() # 自己的模型参数变量 536 | test_dict = {k: w_glob[k] for k in w_glob.keys() if k in model_dict} # 去除一些不需要的参数 537 | model_dict.update(test_dict) # 参数更新 538 | net_global.load_state_dict(model_dict) # 加载 539 | 540 | # for n, p in net_global.named_parameters(): 541 | # p = w_glob[n] 542 | 543 | # net_global.load_state_dict(w_glob) 544 | net_global.eval() 545 | acc_test, loss_test = test_img(net_global, dataset_test) 546 | print("Testing accuracy: {:.2f}".format(acc_test)) 547 | -------------------------------------------------------------------------------- /main_mnist_fedsi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/9/2 22:12 3 | # @Author : zhao 4 | # @File : main_mnist_fedsi.py 5 | 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch.autograd import Variable 10 | from torch import autograd 11 | import torch.nn.functional as F 12 | from torchvision import datasets, transforms 13 | import pandas as pd 14 | import numpy as np 15 | from sklearn.utils import shuffle 16 | from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support 17 | from sklearn.preprocessing import MinMaxScaler 18 | from collections import Iterable # < py38 19 | import copy 20 | from net_fewc import CNNMnist 21 | import logging 22 | import gzip 23 | import os 24 | import time 25 | 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | # convert a list of list to a list [[],[],[]]->[,,] 29 | def flatten(items): 30 | """Yield items from any nested iterable; see Reference.""" 31 | for x in items: 32 | if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 33 | for sub_x in flatten(x): 34 | yield sub_x 35 | else: 36 | yield x 37 | 38 | 39 | class DealDataset(Dataset): 40 | """ 41 | 读取数据、初始化数据 42 | """ 43 | def __init__(self, folder, data_name, label_name,transform=None): 44 | (train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式 45 | self.train_set = train_set 46 | self.train_labels = train_labels 47 | self.transform = transform 48 | 49 | def __getitem__(self, index): 50 | 51 | img, target = self.train_set[index], int(self.train_labels[index]) 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.train_set) 58 | 59 | def load_data(data_folder, data_name, label_name): 60 | """ 61 | data_folder: 文件目录 62 | data_name: 数据文件名 63 | label_name:标签数据文件名 64 | """ 65 | with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据 66 | y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) 67 | 68 | with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath: 69 | x_train = np.frombuffer( 70 | imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) 71 | return (x_train, y_train) 72 | 73 | 74 | class DatasetSplit(Dataset): 75 | def __init__(self, dataset, idxs): 76 | self.dataset = dataset 77 | self.idxs = list(idxs) 78 | 79 | def __len__(self): 80 | return len(self.idxs) 81 | 82 | def __getitem__(self, item): 83 | image, label = self.dataset[self.idxs[item]] 84 | return image, label 85 | 86 | 87 | def iid(dataset, num_users): 88 | """ 89 | Sample I.I.D. client data from dataset 90 | :param dataset: 91 | :param num_users: 92 | :return: dict of image index 93 | """ 94 | num_items = int(len(dataset) / num_users) 95 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 96 | for i in range(num_users): 97 | dict_users[i] = set(np.random.choice(all_idxs, num_items, 98 | replace=False)) # Generates random samples from all_idexs,return a array with size of num_items 99 | all_idxs = list(set(all_idxs) - dict_users[i]) 100 | return dict_users 101 | 102 | def mnist_noniid6(dataset, num_users): 103 | """ 104 | Sample non-I.I.D client data from MNIST dataset 105 | :param dataset: 106 | :param num_users: 107 | :return: 108 | """ 109 | num_shards, num_imgs = 60, 1000 110 | idx_shard = [i for i in range(num_shards)] 111 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 112 | idxs = np.arange(num_shards*num_imgs) 113 | labels = dataset.train_labels#.numpy() 114 | 115 | # sort labels 116 | idxs_labels = np.vstack((idxs, labels)) 117 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 118 | idxs = idxs_labels[0,:] 119 | 120 | # divide and assign 121 | for i in range(num_users): 122 | rand_set = set(np.random.choice(idx_shard, 6, replace=False)) 123 | idx_shard = list(set(idx_shard) - rand_set) 124 | for rand in rand_set: 125 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 126 | return dict_users 127 | 128 | def mnist_noniid5(dataset, num_users): 129 | """ 130 | Sample non-I.I.D client data from MNIST dataset 131 | :param dataset: 132 | :param num_users: 133 | :return: 134 | """ 135 | num_shards, num_imgs = 50, 1200 136 | idx_shard = [i for i in range(num_shards)] 137 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 138 | idxs = np.arange(num_shards*num_imgs) 139 | labels = dataset.train_labels#.numpy() 140 | 141 | # sort labels 142 | idxs_labels = np.vstack((idxs, labels)) 143 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 144 | idxs = idxs_labels[0,:] 145 | 146 | # divide and assign 147 | for i in range(num_users): 148 | rand_set = set(np.random.choice(idx_shard, 5, replace=False)) 149 | idx_shard = list(set(idx_shard) - rand_set) 150 | for rand in rand_set: 151 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 152 | return dict_users 153 | 154 | def mnist_noniid4(dataset, num_users): 155 | """ 156 | Sample non-I.I.D client data from MNIST dataset 157 | :param dataset: 158 | :param num_users: 159 | :return: 160 | """ 161 | num_shards, num_imgs = 40, 1500 162 | idx_shard = [i for i in range(num_shards)] 163 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 164 | idxs = np.arange(num_shards*num_imgs) 165 | labels = dataset.train_labels#.numpy() 166 | 167 | # sort labels 168 | idxs_labels = np.vstack((idxs, labels)) 169 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 170 | idxs = idxs_labels[0,:] 171 | 172 | # divide and assign 173 | for i in range(num_users): 174 | rand_set = set(np.random.choice(idx_shard, 4, replace=False)) 175 | idx_shard = list(set(idx_shard) - rand_set) 176 | for rand in rand_set: 177 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 178 | return dict_users 179 | 180 | def mnist_noniid3(dataset, num_users): 181 | """ 182 | Sample non-I.I.D client data from MNIST dataset 183 | :param dataset: 184 | :param num_users: 185 | :return: 186 | """ 187 | num_shards, num_imgs = 30, 2000 188 | idx_shard = [i for i in range(num_shards)] 189 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 190 | idxs = np.arange(num_shards*num_imgs) 191 | labels = dataset.train_labels#.numpy() 192 | 193 | # sort labels 194 | idxs_labels = np.vstack((idxs, labels)) 195 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 196 | idxs = idxs_labels[0,:] 197 | 198 | # divide and assign 199 | for i in range(num_users): 200 | rand_set = set(np.random.choice(idx_shard, 3, replace=False)) 201 | idx_shard = list(set(idx_shard) - rand_set) 202 | for rand in rand_set: 203 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 204 | return dict_users 205 | 206 | def mnist_noniid2(dataset, num_users): 207 | """ 208 | Sample non-I.I.D client data from MNIST dataset 209 | :param dataset: 210 | :param num_users: 211 | :return: 212 | """ 213 | num_shards, num_imgs = 20, 3000 214 | idx_shard = [i for i in range(num_shards)] 215 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 216 | idxs = np.arange(num_shards*num_imgs) 217 | labels = dataset.train_labels#.numpy() 218 | 219 | # sort labels 220 | idxs_labels = np.vstack((idxs, labels)) 221 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 222 | idxs = idxs_labels[0,:] 223 | 224 | # divide and assign 225 | for i in range(num_users): 226 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 227 | idx_shard = list(set(idx_shard) - rand_set) 228 | for rand in rand_set: 229 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 230 | return dict_users 231 | 232 | def mnist_noniid1(dataset, num_users): 233 | """ 234 | Sample non-I.I.D client data from MNIST dataset 235 | :param dataset: 236 | :param num_users: 237 | :return: 238 | """ 239 | num_shards, num_imgs = 10, 6000 240 | idx_shard = [i for i in range(num_shards)] 241 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 242 | idxs = np.arange(num_shards*num_imgs) 243 | labels = dataset.train_labels#.numpy() 244 | # sort labels 245 | idxs_labels = np.vstack((idxs, labels)) 246 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 247 | idxs = idxs_labels[0,:] 248 | # divide and assign 249 | for i in range(num_users): 250 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 251 | idx_shard = list(set(idx_shard) - rand_set) 252 | for rand in rand_set: 253 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 254 | return dict_users 255 | 256 | def test_img(net_g, datatest): 257 | net_g.eval() 258 | # testing 259 | test_loss = 0 260 | correct = 0 261 | data_pred = [] 262 | data_label = [] 263 | data_loader = DataLoader(datatest, batch_size=test_BatchSize, shuffle=True) 264 | l = len(data_loader) 265 | loss = torch.nn.CrossEntropyLoss() 266 | for idx, (data, target) in enumerate(data_loader): 267 | data, target = Variable(data).to(device), Variable(target).type(torch.LongTensor).to(device) 268 | # data, target = Variable(data), Variable(target).type(torch.LongTensor) 269 | log_probs = net_g(data) 270 | # sum up batch loss 271 | test_loss += loss(log_probs, target).item() 272 | # test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 273 | # get the index of the max log-probability 274 | y_pred = log_probs.data.detach().max(1, keepdim=True)[1] 275 | correct += y_pred.eq(target.data.detach().view_as(y_pred)).long().cpu().sum() 276 | data_pred.append(y_pred.cpu().detach().data.tolist()) 277 | data_label.append(target.cpu().detach().data.tolist()) 278 | list_data_label = list(flatten(data_label)) 279 | list_data_pred = list(flatten(data_pred)) 280 | all_report = precision_recall_fscore_support(list_data_label, list_data_pred, average='weighted') 281 | all_precision = all_report[0] 282 | all_recall = all_report[1] 283 | all_fscore = all_report[2] 284 | print('all_precision',all_precision,'all_recall',all_recall,'all_fscore',all_fscore) 285 | # print(classification_report(list_data_label, list_data_pred)) 286 | print(confusion_matrix(list_data_label, list_data_pred)) 287 | # print('test_loss', test_loss) 288 | test_loss /= len(data_loader.dataset) 289 | accuracy = 100.00 * correct / len(data_loader.dataset) 290 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}'.format( 291 | test_loss, correct, len(data_loader.dataset), accuracy)) 292 | # logging.info('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} {:.2f}\n'.format( 293 | # test_loss, correct, len(data_loader.dataset), accuracy)) 294 | return accuracy, test_loss 295 | 296 | 297 | def FedAvg(w): 298 | w_avg = copy.deepcopy(w[0]) 299 | for k in w_avg.keys(): 300 | for i in range(1, len(w)): 301 | w_avg[k] += w[i][k] 302 | w_avg[k] = torch.div(w_avg[k], len(w)) 303 | return w_avg 304 | 305 | 306 | def getGradShapes(Model): 307 | """Return the shapes and sizes of the weight matrices""" 308 | gradShapes = [] 309 | gradSizes = [] 310 | for n, p in Model.named_parameters(): 311 | gradShapes.append(p.data.shape) 312 | gradSizes.append(np.prod(p.data.shape)) 313 | return gradShapes, gradSizes 314 | 315 | 316 | def getGradVec(w): 317 | """Return the gradient flattened to a vector""" 318 | gradVec = [] 319 | # flatten 320 | # for n, p in Model.named_parameters(): 321 | # # gradVec.append(torch.zeros_like(p.data.view(-1))) 322 | # gradVec.append(p.grad.data.view(-1).float()) 323 | for k in w.keys(): 324 | # gradVec.append(torch.zeros_like(p.data.view(-1))) 325 | gradVec.append(w[k].view(-1).float()) 326 | # concat into a single vector 327 | gradVec = torch.cat(gradVec) 328 | return gradVec 329 | 330 | 331 | def setGradVec(Model, vec): 332 | """Set the gradient to vec""" 333 | # put vec into p.grad.data 334 | vec = vec.to(device) 335 | gradShapes, gradSizes = getGradShapes(Model=Model) 336 | startPos = 0 337 | i = 0 338 | for n, p in Model.named_parameters(): 339 | shape = gradShapes[i] 340 | size = gradSizes[i] 341 | i += 1 342 | # assert (size == np.prod(p.grad.data.size())) 343 | p.grad.data.zero_() 344 | p.grad.data.add_(vec[startPos:startPos + size].reshape(shape)) 345 | startPos += size 346 | 347 | 348 | def topk(vec, k): 349 | """ Return the largest k elements (by magnitude) of vec""" 350 | ret = torch.zeros_like(vec) 351 | # on a gpu, sorting is faster than pytorch's topk method 352 | topkIndices = torch.sort(vec ** 2)[1][-k:] 353 | # _, topkIndices = torch.topk(vec**2, k) 354 | ret[topkIndices] = vec[topkIndices] 355 | return ret, topkIndices 356 | 357 | 358 | def quantize(x): 359 | compress_settings = {'n': 32} 360 | # compress_settings.update(input_compress_settings) 361 | # assume that x is a torch tensor 362 | 363 | n = compress_settings['n'] 364 | # print('n:{}'.format(n)) 365 | x = x.float() 366 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 367 | 368 | sgn_x = ((x > 0).float() - 0.5) * 2 369 | 370 | p = torch.div(torch.abs(x), x_norm) 371 | renormalize_p = torch.mul(p, n) 372 | floor_p = torch.floor(renormalize_p) 373 | compare = torch.rand_like(floor_p) 374 | final_p = renormalize_p - floor_p 375 | margin = (compare < final_p).float() 376 | xi = (floor_p + margin) / n 377 | 378 | Tilde_x = x_norm * sgn_x * xi 379 | 380 | return Tilde_x 381 | 382 | 383 | def quantize_log(x): 384 | compress_settings = {'n': 16} 385 | # compress_settings.update(input_compress_settings) 386 | # assume that x is a torch tensor 387 | n = compress_settings['n'] 388 | # print('n:{}'.format(n)) 389 | x = x.float() 390 | x_norm = torch.norm(x, p=float('inf')) # inf_norm = max(abs(x)) 391 | sgn_x = ((x > 0).float() - 0.5) * 2 392 | p = torch.div(torch.abs(x), x_norm) 393 | lookup = torch.linspace(0, -10, n) 394 | log_p = torch.log2(p) 395 | round_index = [(torch.abs(lookup - k)).min(dim=0)[1] for k in log_p] 396 | round_p = [2 ** (lookup[i]) for i in round_index] 397 | round_p = torch.stack(round_p).to(device) 398 | # print('round_p',round_p) 399 | # print('x_norm',x_norm) 400 | 401 | Tilde_x = x_norm * round_p * sgn_x 402 | 403 | return Tilde_x 404 | 405 | 406 | def quantization_layer(sizes, x): 407 | q_x = torch.zeros_like(x) 408 | startPos = 0 409 | for i in sizes: 410 | q_x[startPos:startPos + i] = quantize(x[startPos:startPos + i]) 411 | # q_x[startPos:startPos + i] = quantize_log(x[startPos:startPos + i]) 412 | startPos += i 413 | return q_x 414 | 415 | 416 | def sparsity(fisher, w_update, w_prev, topkIndices): 417 | Shapes = [] 418 | Sizes = [] 419 | for j in fisher.keys(): 420 | Shapes.append(fisher[j].shape) 421 | Sizes.append(np.prod(fisher[j].shape)) 422 | # print('fisher sizes', Sizes) 423 | fisher_vector = getGradVec(fisher) 424 | fisher_vector_spar = torch.zeros_like(fisher_vector) 425 | fisher_vector_spar[topkIndices] = fisher_vector[topkIndices] 426 | # fisher_vector_spar_q = quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]), 427 | # x=fisher_vector_spar) 428 | fisher_vector_spar_q = quantize(fisher_vector_spar) 429 | model_vector_spar_q = w_update + w_prev 430 | # model_vector_spar_q = w_update - w_prev 431 | fisher_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 432 | model_spar = {k: torch.zeros_like(fisher[k]) for k in fisher.keys()} 433 | startPos = 0 434 | j = 0 435 | for k in fisher.keys(): 436 | shape = Shapes[j] 437 | size = Sizes[j] 438 | j += 1 439 | fisher_spar[k] = fisher_vector_spar_q[startPos:startPos + size].reshape(shape).double() 440 | model_spar[k] = model_vector_spar_q[startPos:startPos + size].reshape(shape).double() 441 | startPos += size 442 | return fisher_spar, model_spar 443 | 444 | def consolidate(Model, Weight, MEAN_pre, epsilon): 445 | OMEGA_current = {n: p.data.clone().zero_() for n, p in Model.named_parameters()} 446 | for n, p in Model.named_parameters(): 447 | p_current = p.detach().clone() 448 | p_change = p_current - MEAN_pre[n] 449 | # W[n].add_((p.grad**2) * torch.abs(p_change)) 450 | # OMEGA_add = W[n]/ (p_change ** 2 + epsilon) 451 | # W[n].add_(-p.grad * p_change) 452 | OMEGA_add = torch.max(Weight[n], Weight[n].clone().zero_()) / (p_change ** 2 + epsilon) 453 | # OMEGA_add = Weight[n] / (p_change ** 2 + epsilon) 454 | # OMEGA_current[n] = OMEGA_pre[n] + OMEGA_add 455 | OMEGA_current[n] = OMEGA_add 456 | MEAN_current = {n: p.data for n, p in Model.named_parameters()} 457 | return OMEGA_current, MEAN_current 458 | 459 | 460 | # FL + EWC 461 | if __name__ == '__main__': 462 | # logging.basicConfig(filename='./20200512_cicids_our_noniid1_E_1_T_1.log', level=logging.DEBUG) 463 | # logging.info('11111') 464 | epsilon = 0.0001 465 | Lamda = 1.0 #1.0 466 | E = 5 467 | T = 50 #50 468 | ## FedAvg 469 | # rho = 1.0 470 | # Lamda = 0.0 471 | frac = 1.0 472 | num_clients = 10 473 | batch_size = 512 474 | test_BatchSize = 32 475 | 476 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 477 | #### MNIST 478 | dataset_train = DealDataset('./data/MNIST/raw', "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 479 | transform=trans_mnist) 480 | dataset_test = DealDataset('./data/MNIST/raw', "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", 481 | transform=trans_mnist) 482 | 483 | 484 | dict_clients = mnist_noniid1(dataset_train, num_users=num_clients) 485 | 486 | net_global = CNNMnist(lamda=Lamda).to(device) #.double() 487 | 488 | # for n, p in net_global.named_parameters(): 489 | # p.data.zero_() 490 | w_glob = net_global.state_dict() 491 | # print(w_glob) 492 | crit = torch.nn.CrossEntropyLoss()#torch.DoubleTensor weight=torch.FloatTensor([1, 1.2, 1.2, 1.2, 3]).to(device) 493 | # optimizer = torch.optim.SGD(net_global.parameters(), lr=0.001, momentum=0.5) 494 | net_global.train() 495 | 496 | omega_current, mean_current = {}, {} 497 | for i in range(num_clients): 498 | omega_current[i] = {} 499 | mean_current[i] = {} 500 | error_compensation = {} 501 | 502 | for interation in range(T): 503 | w_locals, loss_locals = [], [] 504 | # print('interationh',interation) 505 | weight_vec_pre = getGradVec(w_glob) 506 | for client in range(num_clients): 507 | # net = CNN(N_class=3,lamda=10000).double().to(device) 508 | net = copy.deepcopy(net_global).to(device) 509 | # crit = torch.nn.CrossEntropyLoss() 510 | net.train() 511 | opt_net = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9) 512 | 513 | print('interation', interation, 'client', client) 514 | idx_traindataset = DatasetSplit(dataset_train, dict_clients[client]) 515 | ldr_train = DataLoader(idx_traindataset, batch_size=512, shuffle=True) 516 | dataset_size = len(ldr_train.dataset) 517 | epochs_per_task = E 518 | 519 | mean_pre = {n: p.clone().detach() for n, p in net.named_parameters()} 520 | W = {n: p.clone().detach().zero_() for n, p in net.named_parameters()} 521 | 522 | t0 = time.clock() 523 | for epoch in range(1, epochs_per_task + 1): 524 | correct = 0 525 | for batch_idx, (images, labels) in enumerate(ldr_train): 526 | old_par = {n: p.clone().detach() for n, p in net.named_parameters()} 527 | images, labels = Variable(images).to(device), Variable(labels).type(torch.LongTensor).to(device) 528 | net.zero_grad() 529 | scores = net(images) 530 | ce_loss = crit(scores, labels) 531 | grad_params = torch.autograd.grad(ce_loss, net.parameters(), create_graph=True) 532 | if interation == 0: 533 | ewc_loss = torch.FloatTensor([0.0]).to(device) #torch.DoubleTensor 534 | else: 535 | other_clients = list(set([i for i in range(num_clients)])-set([client])) 536 | losses = {} 537 | sum_ewc_loss = 0 538 | for i in other_clients: 539 | losses[i] = [] 540 | for n, p in net.named_parameters(): 541 | mean, omega = Variable(mean_current[i][n]), Variable(omega_current[i][n]) 542 | losses[i].append((omega * (p - mean) ** 2).sum()) 543 | sum_ewc_loss += sum(losses[i]) 544 | ewc_loss = net.lamda * sum_ewc_loss 545 | loss = ce_loss + ewc_loss.double() # + reg_loss.double() 546 | pred = scores.max(1)[1] 547 | correct += pred.eq(labels.data.view_as(pred)).cpu().sum() 548 | loss.backward() 549 | opt_net.step() 550 | 551 | j = 0 552 | for n, p in net.named_parameters(): 553 | W[n] -= (grad_params[j].clone().detach()) * (p.detach() - old_par[n]) 554 | j += 1 555 | 556 | # if (interation != 0) and (client == 0): 557 | # print('1', sum(losses1), sum(lossnon), '2', sum(losses2), '3', sum(losses3), '4', sum(losses4)) 558 | Accuracy = 100. * correct.type(torch.FloatTensor) / dataset_size 559 | print('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 560 | # logging.info('Train Epoch:{}\tLoss:{:.4f}\tEWC_Loss:{:.4f}\tCE_Loss:{:.4f}\tAccuracy: {:.4f}'.format(epoch,loss.item(),ewc_loss.item(),ce_loss.item(),Accuracy)) 561 | # print(classification_report(labels.cpu().data.view_as(pred.cpu()), pred.cpu())) 562 | 563 | w_locals.append(copy.deepcopy(net.state_dict())) 564 | omega_current[client], mean_current[client] = consolidate(Model=net, Weight=W, MEAN_pre=mean_pre, epsilon=epsilon) 565 | t1 = time.clock() 566 | print('client:\t', client, 'trainingtime:\t', str(t1 - t0)) 567 | 568 | # w = net.state_dict() 569 | # weight_vec_current = getGradVec(w) 570 | # K = int(rho * len(weight_vec_current)) 571 | # print('sparsity k=', K) 572 | # 573 | # if interation == 0: 574 | # error_compensation[client] = torch.zeros_like(weight_vec_current) 575 | # weight_update, Topkindices = topk(vec=(weight_vec_current - weight_vec_pre + error_compensation[client]),k=K) 576 | # weight_update_q = quantize(weight_update) 577 | # error_compensation[client] = (weight_update_q - (weight_vec_current - weight_vec_pre + error_compensation[client])) # + error_compensation0 578 | # omega_current_00, mean_current_00 = consolidate(Model=net, Weight=W, MEAN_pre=mean_pre, epsilon=epsilon) 579 | # omega_current[client], mean_current[client] = sparsity(fisher=omega_current_00, w_update=weight_update_q,w_prev=weight_vec_pre, topkIndices=Topkindices) 580 | # w_vec_locals.append(weight_update_q) 581 | 582 | # weight_update, _ = topk(vec=(weight_vec_current - weight_vec_pre + error_compensation), k=K) 583 | # weight_update_q = quantize(weight_update) 584 | # # weight_update_q=quantize_log(weight_update) 585 | # # weight_update_q=quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]),x=weight_update) 586 | # error_compensation = error_compensation + ( 587 | # weight_vec_current - weight_vec_pre + error_compensation - weight_update_q) 588 | # w_vec_locals.append(weight_update_q) 589 | 590 | # weight_update, _ = topk(vec=(weight_vec_current - weight_vec_pre), k=K) 591 | # # weight_update_q=quantize(weight_update) 592 | # # weight_update_q=quantize_log(weight_update) 593 | # # weight_update_q=quantization_layer(sizes=torch.tensor([144, 16, 4608, 32, 2560, 5]),x=weight_update) 594 | # w_vec_locals.append(weight_update) 595 | 596 | w_glob = FedAvg(w_locals) 597 | # # w_glob = FedAvg(model=copy.deepcopy(net_global), w_pre=weight_vec_pre, w_update=w_vec_locals) 598 | # w_glob,error_compensation = FedAvg(model=copy.deepcopy(net_global), w_pre=weight_vec_pre, w_update=w_vec_locals, error=error_compensation) 599 | 600 | # print(w_glob) 601 | 602 | # w_glob = FedAvg(w_locals) 603 | # copy weight to net_glob 604 | net_global.load_state_dict(w_glob) 605 | # net_global.load_state_dict(w_glob) 606 | net_global.eval() 607 | acc_test, loss_test = test_img(net_global, dataset_test) 608 | print("Testing accuracy: {:.2f}".format(acc_test)) 609 | 610 | model_dict = net_global.state_dict() # 自己的模型参数变量 611 | test_dict = {k: w_glob[k] for k in w_glob.keys() if k in model_dict} # 去除一些不需要的参数 612 | model_dict.update(test_dict) # 参数更新 613 | net_global.load_state_dict(model_dict) # 加载 614 | 615 | # for n, p in net_global.named_parameters(): 616 | # p = w_glob[n] 617 | 618 | # net_global.load_state_dict(w_glob) 619 | net_global.eval() 620 | acc_test, loss_test = test_img(net_global, dataset_test) 621 | print("Testing accuracy: {:.2f}".format(acc_test)) 622 | -------------------------------------------------------------------------------- /net_fewc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch import autograd 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class CNNMnist(torch.nn.Module): 11 | def __init__(self,lamda=0): 12 | super(CNNMnist, self).__init__() 13 | self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5) 14 | self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5) 15 | self.conv2_drop = torch.nn.Dropout2d() 16 | self.fc1 = torch.nn.Linear(320, 50) 17 | self.fc2 = torch.nn.Linear(50, 10) 18 | self.lamda = lamda 19 | 20 | def forward(self, x): 21 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 22 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 23 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 24 | x = F.relu(self.fc1(x)) 25 | x = F.dropout(x, training=self.training) 26 | x = self.fc2(x) 27 | return F.log_softmax(x, dim=1) 28 | --------------------------------------------------------------------------------