├── fedbe.png ├── models ├── __init__.py ├── FedM.py ├── Fed.py ├── test.py ├── Nets.py ├── swa.py └── Update.py ├── utils ├── __init__.py ├── tools.py ├── options.py └── sampling.py ├── run.sh ├── README.md ├── resnet.py ├── swag.py ├── LICENSE └── main.py /fedbe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongyouc/FedBE/HEAD/fedbe.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update dist --teacher_type SWAG --use_client --use_SWA --num_users 10 --model resnet32 --weight_decay 0.0002 --exp fedbe & 2 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update dist --teacher_type clients --use_client --num_users 10 --model resnet32 --weight_decay 0.0002 --exp vanilla & 3 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update FedAvg --teacher_type clients --num_users 10 --model resnet32 --weight_decay 0.0002 --exp fedavg 4 | 5 | 6 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update dist --teacher_type SWAG --use_SWA --use_client --num_users 10 --model cnn --weight_decay 0.001 --exp fedbe & 7 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update dist --teacher_type clients --num_users 10 --model cnn --weight_decay 0.001 --use_client --exp vanilla & 8 | python3 main.py --local_ep 20 --rounds 40 --server_ep 20 --update FedAvg --teacher_type clients --num_users 10 --model cnn --weight_decay 0.001 --exp fedavg 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /models/FedM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | def create_local_init(glob, local, bias_ratio): 11 | assert bias_ratio <=1.0 and bias_ratio >= 0.0 12 | for k in glob.keys(): 13 | if bias_ratio > 0: 14 | glob[k] = glob[k]*(1-bias_ratio) + local[k]*(bias_ratio) 15 | else: 16 | glob[k] = (glob[k]+local[k]*bias_ratio)/(1.0 + bias_ratio) 17 | return glob 18 | 19 | 20 | def FedAvgM(w, gpu, w_org, mom, size_arr=None): 21 | (global_w, momentum) = w_org 22 | w_avg = {} 23 | for k in w[0].keys(): 24 | w_avg[k] = torch.zeros(w[0][k].size()) 25 | w_mom = dict(w_avg) 26 | 27 | # Prepare p 28 | if size_arr is not None: 29 | total_num = np.sum(size_arr) 30 | size_arr = np.array([float(p)/total_num for p in size_arr])*len(size_arr) 31 | else: 32 | size_arr = np.array([1.0]*len(size_arr)) 33 | 34 | for k in w_avg.keys(): 35 | for i in range(0, len(w)): 36 | grad = global_w[k] - w[i][k] 37 | w_avg[k] += size_arr[i]*grad 38 | 39 | mom_k = torch.div(w_avg[k], len(w))*(1-mom) + momentum[k]*mom 40 | w_avg[k] = global_w[k] - mom_k 41 | w_mom[k] = mom_k 42 | 43 | return w_avg, w_mom 44 | -------------------------------------------------------------------------------- /models/Fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | 10 | def create_local_init(glob, local, bias_ratio): 11 | assert bias_ratio <=1.0 and bias_ratio >= 0.0 12 | for k in glob.keys(): 13 | if bias_ratio > 0: 14 | glob[k] = glob[k]*(1-bias_ratio) + local[k]*(bias_ratio) 15 | else: 16 | glob[k] = (glob[k]+local[k]*bias_ratio)/(1.0 + bias_ratio) 17 | return glob 18 | 19 | 20 | def FedAvg(w, gpu, global_w=None, size_arr=None): 21 | w_avg = {} 22 | for k in w[0].keys(): 23 | w_avg[k] = torch.zeros(w[0][k].size()) 24 | 25 | # Prepare p 26 | if size_arr is not None: 27 | total_num = np.sum(size_arr) 28 | size_arr = np.array([float(p)/total_num for p in size_arr])*len(size_arr) 29 | else: 30 | size_arr = np.array([1.0]*len(size_arr)) 31 | 32 | if global_w is not None: 33 | for k in w_avg.keys(): 34 | for i in range(0, len(w)): 35 | grad = w[i][k] 36 | grad_norm = torch.norm(grad, p=2) / torch.norm(global_w[k], p=2) 37 | w_avg[k] += size_arr[i]*grad / grad_norm 38 | w_avg[k] = torch.div(w_avg[k], len(w)) 39 | else: 40 | for k in w_avg.keys(): 41 | for i in range(0, len(w)): 42 | w_avg[k] += size_arr[i]*w[i][k] 43 | w_avg[k] = torch.div(w_avg[k], len(w)) 44 | 45 | return w_avg 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedBE: Making Bayesian Model Ensemble Applicable to Federated Learning 2 | Federated learning aims to collaboratively train a strong global model by accessing 3 | users’ locally trained models but not their own data. A crucial step is therefore to 4 | aggregate local models into a global model, which has been shown challenging 5 | when users have non-i.i.d. data. In this paper, we propose a novel aggregation algorithm named FedBE, which takes a Bayesian inference perspective by sampling 6 | higher-quality global models and combining them via Bayesian model Ensemble, 7 | leading to much robust aggregation. We show that an effective model distribution 8 | can be constructed by simply fitting a Gaussian or Dirichlet distribution to the local 9 | models. Our empirical studies validate FedBE’s superior performance, especially 10 | when users’ data are not i.i.d. and when the neural networks go deeper. Moreover, 11 | FedBE is compatible with recent efforts in regularizing users’ model training, 12 | making it an easily applicable module: you only need to replace the aggregation 13 | method but leave other parts of your federated learning algorithm intact 14 | 15 | ![](fedbe.png) 16 | 17 | ## Citation 18 | This repository also implements parallelized client training in PyTorch. Please cite us if you find it useful. 19 | Please email chen.9301[at]osu.edu for questions. Thank you. 20 | ``` 21 | @inproceedings{chen2020fedbe, 22 | title={FedBE: Making Bayesian Model Ensemble Applicable to Federated Learning}, 23 | author={Chen, Hong-You and Chao, Wei-Lun}, 24 | booktitle = {ICLR}, 25 | year={2021} 26 | } 27 | ``` 28 | 29 | ## References 30 | Some codes are based on: 31 | * [shaoxiongji/federated-learning](https://github.com/shaoxiongji/federated-learning) 32 | * [timgaripov/swa](https://github.com/timgaripov/swa) 33 | * [wjmaddox/swa_gaussian](https://github.com/wjmaddox/swa_gaussian) 34 | -------------------------------------------------------------------------------- /models/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader, Dataset 9 | import numpy as np 10 | 11 | class DatasetSplit(Dataset): 12 | def __init__(self, dataset, idxs): 13 | self.dataset = dataset 14 | self.idxs = list(idxs) 15 | 16 | def __len__(self): 17 | return len(self.idxs) 18 | 19 | def __getitem__(self, item): 20 | image, label = self.dataset[self.idxs[item]] 21 | return image, label 22 | 23 | def onehot_encode(target, n_classes): 24 | y = torch.zeros(len(target), n_classes).cuda() 25 | y[range(y.shape[0]), target]=1 26 | 27 | return y 28 | 29 | def test_img(net_g, datatest, args, idxs, reweight=None, cls_num=10): 30 | net_g.eval() 31 | test_loss = 0 32 | correct = 0 33 | cnt = 0.0 34 | 35 | data_loader = DataLoader(DatasetSplit(datatest, idxs), batch_size=1024, shuffle=False) 36 | l = len(data_loader) 37 | net_g = net_g.cuda() 38 | with torch.no_grad(): 39 | for idx, (data, target) in enumerate(data_loader): 40 | if args.gpu != -1: 41 | data, target = data.cuda(), target.cuda() 42 | log_probs = net_g(data) 43 | # sum up batch loss 44 | test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() 45 | # get the index of the max log-probability 46 | y_pred = log_probs.data.max(1, keepdim=True)[1] 47 | target = target.data.view_as(y_pred) 48 | correct += y_pred.eq(target).long().cpu().sum() 49 | cnt += len(data) 50 | 51 | test_loss /= cnt 52 | accuracy = 100.00 * correct / cnt 53 | 54 | if args.verbose: 55 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 56 | test_loss, correct, len(data_loader.dataset), accuracy)) 57 | return accuracy.numpy(), test_loss 58 | 59 | -------------------------------------------------------------------------------- /models/Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch, pdb 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, dim_in, dim_hidden, dim_out): 12 | super(MLP, self).__init__() 13 | self.layer_input = nn.Linear(dim_in, dim_hidden) 14 | self.relu = nn.ReLU() 15 | self.dropout = nn.Dropout() 16 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 17 | self.softmax = nn.Softmax(dim=1) 18 | 19 | def forward(self, x): 20 | x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) 21 | x = self.layer_input(x) 22 | x = self.dropout(x) 23 | x = self.relu(x) 24 | x = self.layer_hidden(x) 25 | return x 26 | 27 | 28 | class CNNMnist(nn.Module): 29 | def __init__(self, args): 30 | super(CNNMnist, self).__init__() 31 | self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) 32 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 33 | self.conv2_drop = nn.Dropout2d() 34 | self.fc1 = nn.Linear(320, 50) 35 | self.fc2 = nn.Linear(50, args.num_classes) 36 | 37 | def forward(self, x): 38 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 39 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 40 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 41 | x = F.relu(self.fc1(x)) 42 | x = F.dropout(x, training=self.training) 43 | x = self.fc2(x) 44 | return x 45 | 46 | 47 | class CNNCifar(nn.Module): 48 | def __init__(self, args): 49 | super(CNNCifar, self).__init__() 50 | self.conv1 = nn.Conv2d(3, 32, 3) 51 | self.pool = nn.MaxPool2d(2, 2) 52 | self.conv2 = nn.Conv2d(32, 64, 3) 53 | self.conv3 = nn.Conv2d(64, 64, 3) 54 | 55 | self.hidden = nn.ModuleList() 56 | for k in range(args.num_layers): 57 | self.hidden.append(nn.Conv2d(64, 64, 3, padding=1)) 58 | self.fc1 = nn.Linear(1024, 64) 59 | self.fc2 = nn.Linear(64, args.num_classes) 60 | 61 | def forward(self, x): 62 | x = self.conv1(x) 63 | x = self.pool(F.relu(x)) 64 | x = self.conv2(x) 65 | x = self.pool(F.relu(x)) 66 | x = self.conv3(x) 67 | x = F.relu(x) 68 | 69 | for l in self.hidden: 70 | x = F.relu(l(x)) 71 | x = x.view(-1, 1024) 72 | x = self.fc1(x) 73 | x = F.relu(x) 74 | x = self.fc2(x) 75 | return x 76 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import pdb 3 | import os 4 | import pickle 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | from sklearn import metrics 11 | from scipy.stats import mode 12 | from scipy.stats import entropy 13 | from scipy.stats import entropy 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch.utils.data import DataLoader, Dataset 18 | import torch.nn.functional as F 19 | 20 | 21 | def store_model(iter, model_dir, w_glob_org, client_w_list): 22 | torch.save(w_glob_org, os.path.join(model_dir, "w_org_%d"%iter)) 23 | for i in range(len(client_w_list)): 24 | torch.save(client_w_list[i], os.path.join(model_dir, "client_%d_%d"%(iter, i))) 25 | 26 | def adaptive_schedule(local_ep, total_ep, rounds, adap_ep): 27 | if rounds<5: 28 | running_ep = adap_ep 29 | else: 30 | running_ep = local_ep 31 | return running_ep 32 | 33 | def lr_schedule(base_lr, iter, total_ep): 34 | if iter==0: 35 | return base_lr*0.5 36 | 37 | elif iter>total_ep*0.9: 38 | return base_lr* 0.01 39 | 40 | elif iter>total_ep*0.6: 41 | return base_lr* 0.1 42 | 43 | elif iter>total_ep*0.3: 44 | return base_lr* 0.2 45 | 46 | else: 47 | return base_lr 48 | 49 | def get_entropy(logits): 50 | mean_entropy = np.mean([entropy(logit) for logit in logits]) 51 | return mean_entropy 52 | 53 | class DatasetSplit(Dataset): 54 | def __init__(self, dataset, idxs): 55 | self.dataset = dataset 56 | self.idxs = list(idxs) 57 | 58 | def __len__(self): 59 | return len(self.idxs) 60 | 61 | def __getitem__(self, item): 62 | image, label = self.dataset[self.idxs[item]] 63 | return image, label 64 | 65 | def get_input_logits(inputs, model, is_logit=False, net_org=None): 66 | model.eval() 67 | with torch.no_grad(): 68 | logit = model(inputs).detach() 69 | if not is_logit: 70 | logit = F.softmax(logit, dim=1) 71 | 72 | logit = logit.cpu().numpy() 73 | return logit 74 | 75 | def temp_softmax(x, axis=-1, temp=1.0): 76 | x = x/temp 77 | e_x = np.exp(x - np.max(x)) # same code 78 | e_x = e_x / e_x.sum(axis=axis, keepdims=True) 79 | return e_x 80 | 81 | def temp_sharpen(x, axis=-1, temp=1.0): 82 | x = np.maximum(x**(1/temp), 1e-8) 83 | return x / x.sum(axis=axis, keepdims=True) 84 | 85 | 86 | def merge_logits(logits, method, loss_type, temp=0.3, global_ep=1000): 87 | if "vote" in method: 88 | if loss_type=="CE": 89 | votes = np.argmax(logits, axis=-1) 90 | logits_arr = mode(votes, axis=1)[0].reshape((len(logits))) 91 | logits_cond = np.mean(np.max(logits, axis=-1), axis=-1) 92 | else: 93 | logits = np.mean(logits, axis=1) 94 | logits_arr = temp_softmax(logits, temp=temp) 95 | logits_cond = np.max(logits_arr, axis=-1) 96 | else: 97 | logits = np.mean(logits, axis=1) 98 | 99 | if loss_type=="MSE": 100 | logits_arr = temp_softmax(logits, temp=1) 101 | logits_cond = np.max(logits_arr, axis=-1) 102 | elif "KL" in loss_type: 103 | logits_arr = temp_sharpen(logits, temp=temp) 104 | logits_cond = np.max(logits_arr, axis=-1) 105 | else: 106 | logits_arr = logits 107 | logits_cond = softmax(logits, axis=-1) 108 | logits_cond = np.max(logits_cond, axis=-1) 109 | 110 | return logits_arr, logits_cond 111 | 112 | def weights_init(m): 113 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 114 | torch.nn.init.xavier_uniform_(m.weight) 115 | if m.bias is not None: 116 | torch.nn.init.zeros_(m.bias) 117 | 118 | class logger(): 119 | def __init__(self, name): 120 | self.name = name 121 | self.loss_train_list = [] 122 | self.loss_test_list = [] 123 | 124 | 125 | self.train_acc_list = [] 126 | self.test_acc_list = [] 127 | self.val_acc_list = [] 128 | self.loss_val_list = [] 129 | 130 | self.ens_train_acc_list = [] 131 | self.ens_test_acc_list = [] 132 | self.ens_val_acc_list = [] 133 | 134 | 135 | self.teacher_loss_train_list = [] 136 | self.teacher_loss_test_list = [] 137 | 138 | self.swa_train_acc_list=[] 139 | self.swa_test_acc_list=[] 140 | self.swa_val_acc_list = [] 141 | 142 | self.swag_train_acc_list=[] 143 | self.swag_test_acc_list=[] 144 | self.swag_val_acc_list = [] 145 | 146 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file is hugely influenced by [2] 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web is copy-paste from 6 | torchvision's resnet and has wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 8 | number of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4m 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in you work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | from torch.nn import Parameter 29 | 30 | __all__ = ['ResNet_s' 'resnet32'] 31 | 32 | def _weights_init(m): 33 | classname = m.__class__.__name__ 34 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 35 | init.kaiming_normal_(m.weight) 36 | 37 | class NormedLinear(nn.Module): 38 | 39 | def __init__(self, in_features, out_features): 40 | super(NormedLinear, self).__init__() 41 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 42 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 43 | 44 | def forward(self, x): 45 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 46 | return out 47 | 48 | class BasicBlock(nn.Module): 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1, option='A'): 52 | super(BasicBlock, self).__init__() 53 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.shortcut = nn.Sequential() 58 | if stride != 1 or in_planes != planes: 59 | if option == 'A': 60 | """ 61 | For CIFAR10 ResNet paper uses option A. 62 | """ 63 | self.shortcut = LambdaLayer(lambda x: 64 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 65 | elif option == 'B': 66 | self.shortcut = nn.Sequential( 67 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 68 | nn.BatchNorm2d(self.expansion * planes) 69 | ) 70 | 71 | def forward(self, x): 72 | out = F.relu(self.bn1(self.conv1(x))) 73 | out = self.bn2(self.conv2(out)) 74 | out += self.shortcut(x) 75 | out = F.relu(out) 76 | return out 77 | 78 | 79 | class ResNet_s(nn.Module): 80 | 81 | def __init__(self, block, num_blocks, num_classes=10, use_norm=False): 82 | super(ResNet_s, self).__init__() 83 | self.in_planes = 16 84 | 85 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 86 | self.bn1 = nn.BatchNorm2d(16) 87 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 88 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 89 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 90 | self.linear = nn.Linear(64, num_classes) 91 | self.apply(_weights_init) 92 | 93 | # block 94 | 95 | def _make_layer(self, block, planes, num_blocks, stride): 96 | strides = [stride] + [1]*(num_blocks-1) 97 | layers = [] 98 | for stride in strides: 99 | layers.append(block(self.in_planes, planes, stride)) 100 | self.in_planes = planes * block.expansion 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def forward(self, x): 105 | out = F.relu(self.bn1(self.conv1(x))) 106 | 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | 111 | out = F.avg_pool2d(out, out.size()[3]) 112 | out = out.view(out.size(0), -1) 113 | out = self.linear(out) 114 | return out 115 | 116 | def resnet32(num_classes=10, use_norm=False): 117 | return ResNet_s(BasicBlock, [5, 5, 5], num_classes=num_classes, use_norm=use_norm) 118 | -------------------------------------------------------------------------------- /swag.py: -------------------------------------------------------------------------------- 1 | """ 2 | implementation of SWAG 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import itertools 8 | from torch.distributions.normal import Normal 9 | import copy 10 | 11 | class SWAG_client(torch.nn.Module): 12 | def __init__(self, args, base_model, lr=0.01, max_num_models=25, var_clamp=1e-5, concentrate_num=1): 13 | self.base_model = base_model 14 | self.max_num_models=max_num_models 15 | self.var_clamp=var_clamp 16 | self.concentrate_num = concentrate_num 17 | self.args = args 18 | self.lr = lr 19 | 20 | def compute_var(self, mean, sq_mean): 21 | var_dict = {} 22 | for k in mean.keys(): 23 | var = torch.clamp(sq_mean[k] - (mean[k] ** 2), self.var_clamp) 24 | var_dict[k] = var 25 | 26 | return var_dict 27 | 28 | def construct_models(self, w): 29 | (w_avg, w_sq_avg, w_norm) = w 30 | self.w_var = self.compute_var(w_avg, w_sq_avg) 31 | 32 | mean_grad = {k:torch.zeros(w.size()) for k,w in w_avg.items()} 33 | 34 | for i in range(self.concentrate_num): 35 | for k in w_avg.keys(): 36 | mean = w_avg[k] 37 | var = self.w_var[k] 38 | 39 | eps = torch.randn_like(mean) 40 | sample_grad = mean + torch.sqrt(var) * eps * self.args.var_scale 41 | mean_grad[k] += sample_grad 42 | 43 | for k in w_avg.keys(): 44 | grad_length = w_norm[k]/float(self.concentrate_num)*self.args.client_stepsize 45 | mean_grad[k] = mean_grad[k]*grad_length + self.base_model[k].cpu() 46 | 47 | self.w_avg = w_avg 48 | return mean_grad 49 | 50 | class SWAG_server(torch.nn.Module): 51 | def __init__(self, args, base_model, avg_model=None, max_num_models=25, var_clamp=1e-5, concentrate_num=1, size_arr=None): 52 | self.base_model = base_model 53 | self.max_num_models=max_num_models 54 | self.var_clamp=var_clamp 55 | self.concentrate_num = concentrate_num 56 | self.args = args 57 | self.avg_model = avg_model 58 | self.size_arr = size_arr 59 | 60 | def compute_var(self, mean, sq_mean): 61 | var_dict = {} 62 | for k in mean.keys(): 63 | var = torch.clamp(sq_mean[k] - mean[k] ** 2, self.var_clamp) 64 | var_dict[k] = var 65 | 66 | return var_dict 67 | 68 | def compute_mean_sq(self, teachers): 69 | w_avg = {} 70 | w_sq_avg = {} 71 | w_norm ={} 72 | 73 | for k in teachers[0].keys(): 74 | if "batches_tracked" in k: continue 75 | w_avg[k] = torch.zeros(teachers[0][k].size()) 76 | w_sq_avg[k] = torch.zeros(teachers[0][k].size()) 77 | w_norm[k] = 0.0 78 | 79 | for k in w_avg.keys(): 80 | if "batches_tracked" in k: continue 81 | for i in range(0, len(teachers)): 82 | grad = teachers[i][k].cpu()- self.base_model[k].cpu() 83 | norm = torch.norm(grad, p=2) 84 | 85 | grad = grad/norm 86 | sq_grad = grad**2 87 | 88 | w_avg[k] += grad 89 | w_sq_avg[k] += sq_grad 90 | w_norm[k] += norm 91 | 92 | w_avg[k] = torch.div(w_avg[k], len(teachers)) 93 | w_sq_avg[k] = torch.div(w_sq_avg[k], len(teachers)) 94 | w_norm[k] = torch.div(w_norm[k], len(teachers)) 95 | 96 | return w_avg, w_sq_avg, w_norm 97 | 98 | def construct_models(self, teachers, mean=None, mode="dir"): 99 | if mode=="gaussian": 100 | w_avg, w_sq_avg, w_norm= self.compute_mean_sq(teachers) 101 | w_var = self.compute_var(w_avg, w_sq_avg) 102 | 103 | mean_grad = copy.deepcopy(w_avg) 104 | for i in range(self.concentrate_num): 105 | for k in w_avg.keys(): 106 | mean = w_avg[k] 107 | var = torch.clamp(w_var[k], 1e-6) 108 | 109 | eps = torch.randn_like(mean) 110 | sample_grad = mean + torch.sqrt(var) * eps * self.args.var_scale 111 | mean_grad[k] = (i*mean_grad[k] + sample_grad) / (i+1) 112 | 113 | for k in w_avg.keys(): 114 | mean_grad[k] = mean_grad[k]*self.args.swag_stepsize*w_norm[k] + self.base_model[k].cpu() 115 | 116 | return mean_grad 117 | 118 | elif mode=="random": 119 | num_t = 3 120 | ts = np.random.choice(teachers, num_t, replace=False) 121 | mean_grad = {} 122 | for k in ts[0].keys(): 123 | mean_grad[k] = torch.zeros(ts[0][k].size()) 124 | for i, t in enumerate(ts): 125 | mean_grad[k]+= t[k] 126 | 127 | for k in ts[0].keys(): 128 | mean_grad[k]/=num_t 129 | 130 | return mean_grad 131 | 132 | elif mode=="dir": 133 | proportions = np.random.dirichlet(np.repeat(self.args.alpha, len(teachers))) 134 | mean_grad = {} 135 | for k in teachers[0].keys(): 136 | mean_grad[k] = torch.zeros(teachers[0][k].size()) 137 | for i, t in enumerate(teachers): 138 | mean_grad[k]+= t[k]*proportions[i] 139 | 140 | for k in teachers[0].keys(): 141 | mean_grad[k]/=sum(proportions) 142 | 143 | return mean_grad 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | # federated arguments 10 | parser.add_argument('--rounds', type=int, default=40, help="rounds of training") 11 | parser.add_argument('--num_users', type=int, default=10, help="number of users") 12 | parser.add_argument('--num_data', type=int, default=40000, help="number of data distributed to users") 13 | parser.add_argument('--num_server_data', type=int, default=-1, help="number of trans data to use in the server: -1 for using all - num_data in users.") 14 | 15 | parser.add_argument('--aug', action='store_true', help="aug") 16 | parser.add_argument('--ens', action='store_true', help="ensemble") 17 | parser.add_argument('--store_model', action='store_true', help="store_model") 18 | parser.add_argument('--frac', type=float, default=1.0, help="the fraction of clients") 19 | 20 | # Local train 21 | parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs") 22 | parser.add_argument('--local_bs', type=int, default=40, help="local batch size") 23 | 24 | parser.add_argument('--momentum', type=float, default=0.9, help="SGD momentum (default: 0.9)") 25 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 26 | parser.add_argument('--local_sch', type=str, default='step', help='step, adaptive') 27 | parser.add_argument('--adap_ep', type=int, default=40, help="epochs for warm up training") 28 | parser.add_argument('--local_loss', type=str, default='CE', help='CE') 29 | parser.add_argument('--server_sample_freq', type=int, default=1, help='o, resample') 30 | parser.add_argument('--weight_decay', type=float, default=0.0005, help="weight_decay") 31 | 32 | parser.add_argument('--num_layers', type=int, default=0, help="extra conv layer") 33 | parser.add_argument('--use_SWA', action='store_true', help="use_SWA") 34 | parser.add_argument('--use_oracle', action='store_true', help="use_oracle") 35 | parser.add_argument('--dont_add_fedavg', action='store_true', help="add_fedavg") 36 | 37 | parser.add_argument('--log_dir', type=str, default='log', help='model name') 38 | parser.add_argument('--log_ep', type=int, default=5, help='log_ep') 39 | parser.add_argument('--exp', type=str, default='', help='model name') 40 | 41 | # Dataset 42 | parser.add_argument('--dataset', type=str, default='cifar', help="name of dataset") 43 | parser.add_argument('--dataset_trans', type=str, default='', help="Unsupervised dataset for server") 44 | 45 | parser.add_argument('--iid', action='store_true', help='whether i.i.d or not') 46 | parser.add_argument('--num_classes', type=int, default=10, help="number of classes") 47 | parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges") 48 | parser.add_argument('--split_method', type=str, default='step', help='split_method, [step, dir]') 49 | 50 | # client regualrzation: FedProx 51 | parser.add_argument('--reg_type', type=str, default='', help='FedProx, scaffold') 52 | parser.add_argument('--mu', type=float, default=0.001, help="mu") 53 | 54 | # SWAG & Server 55 | parser.add_argument('--fedM', action='store_true', help="FedAvgM") 56 | parser.add_argument('--teacher_type', type=str, default='SWAG', help='ensemble') 57 | parser.add_argument('--client_type', type=str, default='real', help='real, g') 58 | 59 | parser.add_argument('--swag_stepsize', type=float, default=1.0, help="swag_stepsize") 60 | parser.add_argument('--client_stepsize', type=float, default=1.0, help="client_stepsize") 61 | parser.add_argument('--var_scale', type=float, default=0.1, help="var_scale") 62 | parser.add_argument('--num_sample_teacher', type=int, default=10, help="number of teachers") 63 | parser.add_argument('--num_base', type=int, default=20, help="number of teachers") 64 | 65 | parser.add_argument('--use_client', action='store_true', help="use_client") 66 | parser.add_argument('--use_fake', action='store_true', help="use_fake") 67 | parser.add_argument('--sample_teacher', type=str, default="gaussian", help="use_client") 68 | 69 | parser.add_argument('--loss_type', type=str, default='KL', help='server loss') 70 | parser.add_argument('--temp', type=float, default=0.5, help="temp") 71 | 72 | parser.add_argument('--mom', type=float, default=0.9, help="teacher momentum") 73 | parser.add_argument('--server_bs', type=int, default=128, help="server batch size: B") 74 | parser.add_argument('--server_lr', type=float, default=0.01, help="learning rate") 75 | parser.add_argument('--update', type=str, default='dist', help='Aggregation update strategy, [FedAvg, dist]') 76 | parser.add_argument('--server_ep', type=int, default=20, help="the number of center epochs") 77 | parser.add_argument('--warmup_ep', type=int, default=-1, help="the number of warmup rounds") 78 | 79 | # model arguments 80 | parser.add_argument('--model', type=str, default='cnn', help='model name') 81 | 82 | parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel') 83 | parser.add_argument('--kernel_sizes', type=str, default='3,4,5', 84 | help='comma-separated kernel size to use for convolution') 85 | parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None") 86 | parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets") 87 | parser.add_argument('--max_pool', type=str, default='True', 88 | help="Whether use max pooling rather than strided convolutions") 89 | 90 | # other arguments 91 | parser.add_argument('--num_gpu', type=int, default=1, help="GPU ID, -1 for CPU") 92 | parser.add_argument('--verbose', action='store_true', help='verbose print') 93 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 94 | args = parser.parse_args() 95 | 96 | if args.update =="FedAvg": args.use_SWA = False 97 | if args.teacher_type != "SWAG": args.dont_add_fedavg = True 98 | 99 | return args 100 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import numpy as np 6 | import pdb 7 | from torchvision import datasets, transforms 8 | import os 9 | import glob 10 | from torch.utils.data import Dataset 11 | from PIL import Image 12 | 13 | def mnist_iid(dataset, num_users): 14 | """ 15 | Sample I.I.D. client data from MNIST dataset 16 | :param dataset: 17 | :param num_users: 18 | :return: dict of image index 19 | """ 20 | num_items = int(len(dataset)/num_users) 21 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 22 | for i in range(num_users): 23 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 24 | all_idxs = list(set(all_idxs) - dict_users[i]) 25 | return dict_users 26 | 27 | def mnist_noniid(dataset, num_users, num_data=60000): 28 | """ 29 | Sample non-I.I.D client data from MNIST dataset 30 | :param dataset: 31 | :param num_users: 32 | :return: 33 | """ 34 | num_shards, num_imgs = 200, 250 35 | idx_shard = [i for i in range(num_shards)] 36 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 37 | idxs = np.arange(num_shards*num_imgs) 38 | labels = dataset.train_labels.numpy()[:num_shards*num_imgs] 39 | 40 | # sort labels 41 | idxs_labels = np.vstack((idxs, labels)) 42 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 43 | idxs = idxs_labels[0,:] 44 | 45 | # divide and assign 46 | for i in range(num_users): 47 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 48 | idx_shard = list(set(idx_shard) - rand_set) 49 | for rand in rand_set: 50 | add_idx = np.array(list(set(idxs[rand*num_imgs:(rand+1)*num_imgs]) )) 51 | dict_users[i] = np.concatenate((dict_users[i], add_idx), axis=0) 52 | 53 | cnts_dict = {} 54 | with open("mnist_%d_u%d.txt"%(num_data, num_users), 'w') as f: 55 | for i in range(num_users): 56 | labels_i = labels[dict_users[i]] 57 | cnts = np.array([np.count_nonzero(labels_i == j ) for j in range(10)] ) 58 | cnts_dict[i] = cnts 59 | f.write("User %s: %s sum: %d\n"%(i, " ".join([str(cnt) for cnt in cnts]), sum(cnts) )) 60 | 61 | server_idx = list(range(num_shards*num_imgs, 60000)) 62 | return dict_users, server_idx, cnts_dict 63 | 64 | def cifar_iid(dataset, num_users, num_data=50000): 65 | """ 66 | Sample I.I.D. client data from CIFAR10 dataset 67 | :param dataset: 68 | :param num_users: 69 | :return: dict of image index 70 | """ 71 | 72 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 73 | if num_data < 50000: 74 | server_idx = np.random.choice(all_idxs, 50000-num_data, replace=False) 75 | all_idxs = list(set(all_idxs) - set(server_idx)) 76 | num_items = int(len(all_idxs)/num_users) 77 | 78 | for i in range(num_users): 79 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 80 | all_idxs = list(set(all_idxs) - dict_users[i]) 81 | return dict_users, server_idx 82 | 83 | def cifar_noniid(dataset, num_users, num_data=50000, method="step"): 84 | """ 85 | Sample non-I.I.D client data from CIFAR dataset 86 | :param dataset: 87 | :param num_users: 88 | :return: 89 | """ 90 | 91 | labels = np.array(dataset.targets) 92 | _lst_sample = 10 93 | 94 | if method=="step": 95 | 96 | num_shards = num_users*2 97 | num_imgs = 50000// num_shards 98 | idx_shard = [i for i in range(num_shards)] 99 | 100 | idxs = np.arange(num_shards*num_imgs) 101 | # sort labels 102 | idxs_labels = np.vstack((idxs, labels)) 103 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 104 | idxs = idxs_labels[0,:] 105 | 106 | least_idx = np.zeros((num_users, 10, _lst_sample), dtype=np.int) 107 | for i in range(10): 108 | idx_i = np.random.choice(np.where(labels==i)[0], num_users*_lst_sample, replace=False) 109 | least_idx[:, i, :] = idx_i.reshape((num_users, _lst_sample)) 110 | least_idx = np.reshape(least_idx, (num_users, -1)) 111 | 112 | least_idx_set = set(np.reshape(least_idx, (-1))) 113 | server_idx = np.random.choice(list(set(range(50000))-least_idx_set), 50000-num_data, replace=False) 114 | 115 | # divide and assign 116 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 117 | for i in range(num_users): 118 | rand_set = set(np.random.choice(idx_shard, num_shards//num_users, replace=False)) 119 | idx_shard = list(set(idx_shard) - rand_set) 120 | for rand in rand_set: 121 | idx_i = list( set(range(rand*num_imgs, (rand+1)*num_imgs)) ) 122 | add_idx = list(set(idxs[idx_i]) - set(server_idx) ) 123 | 124 | dict_users[i] = np.concatenate((dict_users[i], add_idx), axis=0) 125 | dict_users[i] = np.concatenate((dict_users[i], least_idx[i]), axis=0) 126 | 127 | elif method == "dir": 128 | min_size = 0 129 | K = 10 130 | y_train = labels 131 | 132 | _lst_sample = 2 133 | 134 | least_idx = np.zeros((num_users, 10, _lst_sample), dtype=np.int) 135 | for i in range(10): 136 | idx_i = np.random.choice(np.where(labels==i)[0], num_users*_lst_sample, replace=False) 137 | least_idx[:, i, :] = idx_i.reshape((num_users, _lst_sample)) 138 | least_idx = np.reshape(least_idx, (num_users, -1)) 139 | 140 | least_idx_set = set(np.reshape(least_idx, (-1))) 141 | #least_idx_set = set([]) 142 | server_idx = np.random.choice(list(set(range(50000))-least_idx_set), 50000-num_data, replace=False) 143 | local_idx = np.array([i for i in range(50000) if i not in server_idx and i not in least_idx_set]) 144 | 145 | N = y_train.shape[0] 146 | net_dataidx_map = {} 147 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 148 | 149 | while min_size < 10: 150 | idx_batch = [[] for _ in range(num_users)] 151 | # for each class in the dataset 152 | for k in range(K): 153 | idx_k = np.where(y_train == k)[0] 154 | idx_k = [id for id in idx_k if id in local_idx] 155 | 156 | np.random.shuffle(idx_k) 157 | proportions = np.random.dirichlet(np.repeat(0.1, num_users)) 158 | ## Balance 159 | proportions = np.array([p*(len(idx_j)>> # automatic mode 47 | >>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1) 48 | >>> opt = torchcontrib.optim.SWA( 49 | >>> base_opt, swa_start=10, swa_freq=5, swa_lr=0.05) 50 | >>> for _ in range(100): 51 | >>> opt.zero_grad() 52 | >>> loss_fn(model(input), target).backward() 53 | >>> opt.step() 54 | >>> opt.swap_swa_sgd() 55 | >>> # manual mode 56 | >>> opt = torchcontrib.optim.SWA(base_opt) 57 | >>> for i in range(100): 58 | >>> opt.zero_grad() 59 | >>> loss_fn(model(input), target).backward() 60 | >>> opt.step() 61 | >>> if i > 10 and i % 5 == 0: 62 | >>> opt.update_swa() 63 | >>> opt.swap_swa_sgd() 64 | 65 | .. note:: 66 | SWA does not support parameter-specific values of :attr:`swa_start`, 67 | :attr:`swa_freq` or :attr:`swa_lr`. In automatic mode SWA uses the 68 | same :attr:`swa_start`, :attr:`swa_freq` and :attr:`swa_lr` for all 69 | parameter groups. If needed, use manual mode withbn_update 70 | :meth:`update_swa_group` to use different update schedules for 71 | different parameter groups. 72 | 73 | .. note:: 74 | Call :meth:`swap_swa_sgd` in the end of training to use the computed 75 | running averages. 76 | 77 | .. note:: 78 | If you are using SWA to optimize the parameters of a Neural Network 79 | containing Batch Normalization layers, you need to update the 80 | :attr:`running_mean` and :attr:`running_var` statistics of the 81 | Batch Normalization module. You can do so by using 82 | `torchcontrib.optim.swa.` utility. 83 | 84 | .. note:: 85 | See the blogpost 86 | https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/ 87 | for an extended description of this SWA implementation. 88 | 89 | .. note:: 90 | The repo https://github.com/izmailovpavel/contrib_swa_examples 91 | contains examples of using this SWA implementation. 92 | 93 | .. _Averaging Weights Leads to Wider Optima and Better Generalization: 94 | https://arxiv.org/abs/1803.05407 95 | .. _Improving Consistency-Based Semi-Supervised Learning with Weight 96 | Averaging: 97 | https://arxiv.org/abs/1806.05594 98 | """ 99 | self._auto_mode, (self.swa_start, self.swa_freq) = \ 100 | self._check_params(self, swa_start, swa_freq) 101 | self.swa_lr = swa_lr 102 | 103 | if self._auto_mode: 104 | if swa_start < 0: 105 | raise ValueError("Invalid swa_start: {}".format(swa_start)) 106 | if swa_freq < 1: 107 | raise ValueError("Invalid swa_freq: {}".format(swa_freq)) 108 | else: 109 | if self.swa_lr is not None: 110 | warnings.warn( 111 | "Some of swa_start, swa_freq is None, ignoring swa_lr") 112 | # If not in auto mode make all swa parameters None 113 | self.swa_lr = None 114 | self.swa_start = None 115 | self.swa_freq = None 116 | 117 | if self.swa_lr is not None and self.swa_lr < 0: 118 | raise ValueError("Invalid SWA learning rate: {}".format(swa_lr)) 119 | 120 | self.optimizer = optimizer 121 | 122 | self.defaults = self.optimizer.defaults 123 | self.param_groups = self.optimizer.param_groups 124 | self.state = defaultdict(dict) 125 | self.opt_state = self.optimizer.state 126 | for group in self.param_groups: 127 | group['n_avg'] = 0 128 | group['step_counter'] = 0 129 | 130 | @staticmethod 131 | def _check_params(self, swa_start, swa_freq): 132 | params = [swa_start, swa_freq] 133 | params_none = [param is None for param in params] 134 | if not all(params_none) and any(params_none): 135 | warnings.warn( 136 | "Some of swa_start, swa_freq is None, ignoring other") 137 | for i, param in enumerate(params): 138 | if param is not None and not isinstance(param, int): 139 | params[i] = int(param) 140 | warnings.warn("Casting swa_start, swa_freq to int") 141 | return not any(params_none), params 142 | 143 | def _reset_lr_to_swa(self): 144 | if self.swa_lr is None: 145 | return 146 | for param_group in self.param_groups: 147 | if param_group['step_counter'] >= self.swa_start: 148 | param_group['lr'] = self.swa_lr 149 | 150 | def update_swa_group(self, group): 151 | r"""Updates the SWA running averages for the given parameter group. 152 | 153 | Arguments: 154 | param_group (dict): Specifies for what parameter group SWA running 155 | averages should be updated 156 | 157 | Examples: 158 | >>> # automatic mode 159 | >>> base_opt = torch.optim.SGD([{'params': [x]}, 160 | >>> {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9) 161 | >>> opt = torchcontrib.optim.SWA(base_opt) 162 | >>> for i in range(100): 163 | >>> opt.zero_grad() 164 | >>> loss_fn(model(input), target).backward() 165 | >>> opt.step() 166 | >>> if i > 10 and i % 5 == 0: 167 | >>> # Update SWA for the second parameter group 168 | >>> opt.update_swa_group(opt.param_groups[1]) 169 | >>> opt.swap_swa_sgd() 170 | """ 171 | for p in group['params']: 172 | param_state = self.state[p] 173 | if 'swa_buffer' not in param_state: 174 | param_state['swa_buffer'] = torch.zeros_like(p.data) 175 | buf = param_state['swa_buffer'] 176 | virtual_decay = 1 / float(group["n_avg"] + 1) 177 | diff = (p.data - buf) * virtual_decay 178 | buf.add_(diff) 179 | group["n_avg"] += 1 180 | 181 | def update_swa(self): 182 | r"""Updates the SWA running averages of all optimized parameters. 183 | """ 184 | for group in self.param_groups: 185 | self.update_swa_group(group) 186 | 187 | def swap_swa_sgd(self): 188 | r"""Swaps the values of the optimized variables and swa buffers. 189 | 190 | It's meant to be called in the end of training to use the collected 191 | swa running averages. It can also be used to evaluate the running 192 | averages during training; to continue training `swap_swa_sgd` 193 | should be called again. 194 | """ 195 | for group in self.param_groups: 196 | for p in group['params']: 197 | param_state = self.state[p] 198 | if 'swa_buffer' not in param_state: 199 | # If swa wasn't applied we don't swap params 200 | warnings.warn( 201 | "SWA wasn't applied to param {}; skipping it".format(p)) 202 | continue 203 | buf = param_state['swa_buffer'] 204 | tmp = torch.empty_like(p.data) 205 | tmp.copy_(p.data) 206 | p.data.copy_(buf) 207 | buf.copy_(tmp) 208 | 209 | def step(self, closure=None): 210 | r"""Performs a single optimization step. 211 | 212 | In automatic mode also updates SWA running averages. 213 | """ 214 | self._reset_lr_to_swa() 215 | loss = self.optimizer.step(closure) 216 | for group in self.param_groups: 217 | group["step_counter"] += 1 218 | steps = group["step_counter"] 219 | if self._auto_mode: 220 | if steps > self.swa_start and steps % self.swa_freq == 0: 221 | self.update_swa_group(group) 222 | return loss 223 | 224 | def state_dict(self): 225 | r"""Returns the state of SWA as a :class:`dict`. 226 | 227 | It contains three entries: 228 | * opt_state - a dict holding current optimization state of the base 229 | optimizer. Its content differs between optimizer classes. 230 | * swa_state - a dict containing current state of SWA. For each 231 | optimized variable it contains swa_buffer keeping the running 232 | average of the variable 233 | * param_groups - a dict containing all parameter groups 234 | """ 235 | opt_state_dict = self.optimizer.state_dict() 236 | swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v 237 | for k, v in self.state.items()} 238 | opt_state = opt_state_dict["state"] 239 | param_groups = opt_state_dict["param_groups"] 240 | return {"opt_state": opt_state, "swa_state": swa_state, 241 | "param_groups": param_groups} 242 | 243 | def load_state_dict(self, state_dict): 244 | r"""Loads the optimizer state. 245 | 246 | Args: 247 | state_dict (dict): SWA optimizer state. Should be an object returned 248 | from a call to `state_dict`. 249 | """ 250 | swa_state_dict = {"state": state_dict["swa_state"], 251 | "param_groups": state_dict["param_groups"]} 252 | opt_state_dict = {"state": state_dict["opt_state"], 253 | "param_groups": state_dict["param_groups"]} 254 | super(SWA, self).load_state_dict(swa_state_dict) 255 | self.optimizer.load_state_dict(opt_state_dict) 256 | self.opt_state = self.optimizer.state 257 | 258 | def add_param_group(self, param_group): 259 | r"""Add a param group to the :class:`Optimizer` s `param_groups`. 260 | 261 | This can be useful when fine tuning a pre-trained network as frozen 262 | layers can be made trainable and added to the :class:`Optimizer` as 263 | training progresses. 264 | 265 | Args: 266 | param_group (dict): Specifies what Tensors should be optimized along 267 | with group specific optimization options. 268 | """ 269 | param_group['n_avg'] = 0 270 | param_group['step_counter'] = 0 271 | self.optimizer.add_param_group(param_group) 272 | 273 | @staticmethod 274 | def bn_update(loader, model, device=None): 275 | r"""Updates BatchNorm running_mean, running_var buffers in the model. 276 | 277 | It performs one pass over data in `loader` to estimate the activation 278 | statistics for BatchNorm layers in the model. 279 | 280 | Args: 281 | loader (torch.utils.data.DataLoader): dataset loader to compute the 282 | activation statistics on. Each data batch should be either a 283 | tensor, or a list/tuple whose first element is a tensor 284 | containing data. 285 | 286 | model (torch.nn.Module): model for which we seek to update BatchNorm 287 | statistics. 288 | 289 | device (torch.device, optional): If set, data will be trasferred to 290 | :attr:`device` before being passed into :attr:`model`. 291 | """ 292 | if not _check_bn(model): 293 | return 294 | was_training = model.training 295 | model.train() 296 | momenta = {} 297 | model.apply(_reset_bn) 298 | model.apply(lambda module: _get_momenta(module, momenta)) 299 | n = 0 300 | for input in loader: 301 | if isinstance(input, (list, tuple)): 302 | input = input[0] 303 | b = input.size(0) 304 | 305 | momentum = b / float(n + b) 306 | for module in momenta.keys(): 307 | module.momentum = momentum 308 | 309 | 310 | input = input.cuda() 311 | 312 | model(input) 313 | n += b 314 | 315 | model.apply(lambda module: _set_momenta(module, momenta)) 316 | model.train(was_training) 317 | 318 | 319 | # BatchNorm utils 320 | def _check_bn_apply(module, flag): 321 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 322 | flag[0] = True 323 | 324 | 325 | def _check_bn(model): 326 | flag = [False] 327 | model.apply(lambda module: _check_bn_apply(module, flag)) 328 | return flag[0] 329 | 330 | 331 | def _reset_bn(module): 332 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 333 | module.running_mean = torch.zeros_like(module.running_mean) 334 | module.running_var = torch.ones_like(module.running_var) 335 | 336 | 337 | def _get_momenta(module, momenta): 338 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 339 | momenta[module] = module.momentum 340 | 341 | 342 | def _set_momenta(module, momenta): 343 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 344 | module.momentum = momenta[module] 345 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import copy 9 | import pdb 10 | import os 11 | import pickle 12 | 13 | import numpy as np 14 | from swag import SWAG_server 15 | 16 | from torchvision import datasets, transforms 17 | import torch 18 | import torch.nn as nn 19 | import torch.multiprocessing as mp 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | from utils.sampling import * 23 | from utils.options import args_parser 24 | from utils.tools import * 25 | 26 | from models.Update import SWAGLocalUpdate, ServerUpdate 27 | from models.Nets import MLP, CNNMnist, CNNCifar 28 | from models.Fed import FedAvg, create_local_init 29 | from models.FedM import FedAvgM 30 | from models.test import test_img 31 | import resnet 32 | 33 | if __name__ == '__main__': 34 | # parse args 35 | args = args_parser() 36 | 37 | args.log_dir = os.path.join(args.log_dir) 38 | 39 | if not os.path.exists(args.log_dir): 40 | os.makedirs(args.log_dir) 41 | 42 | with open(os.path.join(args.log_dir, "args.txt"), "w") as f: 43 | for arg in vars(args): 44 | print (arg, getattr(args, arg), file=f) 45 | 46 | args.acc_dir = os.path.join(args.log_dir, "acc") 47 | if not os.path.exists(args.acc_dir): 48 | os.makedirs(args.acc_dir) 49 | 50 | model_dir = os.path.join(args.log_dir, "models") 51 | if not os.path.exists(model_dir): 52 | os.makedirs(model_dir) 53 | 54 | transform_train = transforms.Compose([ 55 | transforms.RandomCrop(32, padding=4), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 59 | ]) 60 | 61 | transform_val = transforms.Compose([ 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 64 | ]) 65 | 66 | # load dataset and split users 67 | if args.dataset == 'mnist': 68 | args.num_classes = 10 69 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 70 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 71 | dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist) 72 | dataset_eval = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 73 | # sample users 74 | if args.iid: 75 | dict_users = mnist_iid(dataset_train, args.num_users) 76 | else: 77 | dict_users, server_id, cnts_dict = mnist_noniid(dataset_train, args.num_users) 78 | elif args.dataset == 'cifar': 79 | args.num_classes = 10 80 | dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=transform_train) 81 | dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=transform_val) 82 | dataset_eval = datasets.CIFAR10('./data/cifar', train=True, transform=transform_val, target_transform=None, download=True) 83 | if args.iid: 84 | dict_users, server_id = cifar_iid(dataset_train, args.num_users, num_data=args.num_data) 85 | else: 86 | dict_users, server_id, cnts_dict = cifar_noniid(dataset_train, args.num_users, num_data=args.num_data, method=args.split_method) 87 | else: 88 | exit('Error: unrecognized dataset') 89 | 90 | train_ids = set() 91 | for u,v in dict_users.items(): 92 | train_ids.update(v) 93 | train_ids = list(train_ids) 94 | 95 | img_size = dataset_train[0][0].shape 96 | # build model 97 | if args.model == 'cnn' and 'cifar' in args.dataset: 98 | net_glob = CNNCifar(args=args) 99 | elif args.model == 'cnn' and args.dataset == 'mnist': 100 | net_glob = CNNMnist(args=args) 101 | elif args.model == 'mlp': 102 | len_in = 1 103 | for x in img_size: 104 | len_in *= x 105 | net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes) 106 | elif "resnet" in args.model and 'cifar' in args.dataset: 107 | net_glob = resnet.resnet32(num_classes=args.num_classes) 108 | else: 109 | exit('Error: unrecognized model') 110 | 111 | print(net_glob) 112 | net_glob.train() 113 | 114 | # copy weights 115 | w_glob = net_glob.state_dict() 116 | 117 | # training 118 | loss_local_list = [] 119 | loss_local_test_list = [] 120 | entropy_list = [] 121 | cv_loss, cv_acc = [], [] 122 | 123 | acc_local_list = [] 124 | acc_local_test_list = [] 125 | acc_local_val_list = [] 126 | 127 | val_loss_pre, counter = 0, 0 128 | net_best = None 129 | best_loss = None 130 | val_acc_list, net_list = [], [] 131 | 132 | net_glob.apply(weights_init) 133 | 134 | def cliet_train(q, device_id, net_glob, iters, idx, val_id=server_id, generator=None): 135 | device=torch.device('cuda:{}'.format(device_id) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 136 | lr = lr_schedule(args.lr, iters, args.rounds) 137 | 138 | if args.local_sch == "adaptive": 139 | running_ep = adaptive_schedule(args.local_ep, args.epochs, iters, args.adap_ep) 140 | if running_ep != args.local_ep: 141 | print("Using adaptive scheduling, local ep = %d."%args.adap_ep) 142 | else: 143 | running_ep = args.local_ep 144 | 145 | local = SWAGLocalUpdate(args=args, 146 | device=device, 147 | dataset=dataset_train, 148 | idxs=dict_users[idx], 149 | server_ids=val_id, 150 | test=(dataset_test, range(len(dataset_test))), 151 | num_per_cls=cnts_dict[idx] ) 152 | 153 | teacher = local.train(net=net_glob.to(device), running_ep=running_ep, lr=lr) 154 | q.put([teacher, idx]) 155 | return [teacher, idx] 156 | 157 | def server_train(q, device_id, net_glob, teachers, global_ep, w_org=None, base_teachers=None): 158 | student = ServerUpdate(args=args, 159 | device=device_id, 160 | dataset=dataset_eval, 161 | server_dataset=dataset_eval, 162 | server_idxs=server_id, 163 | train_idx=train_ids, 164 | test=(dataset_test, range(len(dataset_test))), 165 | w_org=w_org, 166 | base_teachers=base_teachers) 167 | 168 | w_swa, w_glob, train_acc, val_acc, test_acc, loss, entropy = student.train(net_glob, teachers, args.log_dir, global_ep) 169 | 170 | q.put([w_swa, w_glob, train_acc, val_acc, test_acc, entropy]) 171 | return [w_swa, w_glob, train_acc, val_acc, test_acc, entropy] 172 | 173 | def test_thread(q, net_glob, dataset, ids): 174 | acc, loss = test_img(net_glob, dataset, args, ids, cls_num=args.num_classes) 175 | q.put([acc, loss]) 176 | return [acc, loss] 177 | 178 | def eval(net_glob, tag='', server_id=None): 179 | # testing 180 | q = mp.Manager().Queue() 181 | 182 | p_tr = mp.Process(target=test_thread, args=(q, net_glob, dataset_eval, train_ids)) 183 | p_tr.start() 184 | p_tr.join() 185 | [acc_train, loss_train] = q.get() 186 | 187 | q2 = mp.Manager().Queue() 188 | p_te = mp.Process(target=test_thread, args=(q2, net_glob, dataset_test, range(len(dataset_test)))) 189 | p_te.start() 190 | p_te.join() 191 | 192 | [acc_test, loss_test] = q2.get() 193 | 194 | q3 = mp.Manager().Queue() 195 | p_val = mp.Process(target=test_thread, args=(q3, net_glob, dataset_eval, server_id)) 196 | p_val.start() 197 | p_val.join() 198 | 199 | [acc_val, loss_val] = q3.get() 200 | 201 | print(tag, "Training accuracy: {:.2f}".format(acc_train)) 202 | print(tag, "Server accuracy: {:.2f}".format(acc_val)) 203 | print(tag, "Testing accuracy: {:.2f}".format(acc_test)) 204 | 205 | del q 206 | del q2 207 | del q3 208 | 209 | return [acc_train, loss_train], [acc_test, loss_test], [acc_val, loss_val] 210 | 211 | def put_log(logger, net_glob, tag, iters=-1): 212 | [acc_train, loss_train], [acc_test, loss_test], [acc_val, loss_val] = eval(net_glob, tag=tag, server_id=server_id) 213 | 214 | if iters==0: 215 | open(os.path.join(args.acc_dir, tag+"_train_acc.txt"), "w") 216 | open(os.path.join(args.acc_dir, tag+"_val_acc.txt"), "w") 217 | open(os.path.join(args.acc_dir, tag+"_test_acc.txt"), "w") 218 | open(os.path.join(args.acc_dir, tag+"_test_loss.txt"), "w") 219 | 220 | with open(os.path.join(args.acc_dir, tag+"_train_acc.txt"), "a") as f: 221 | f.write("%d %f\n"%(iters, acc_train)) 222 | with open(os.path.join(args.acc_dir, tag+"_test_acc.txt"), "a") as f: 223 | f.write("%d %f\n"%(iters, acc_test)) 224 | with open(os.path.join(args.acc_dir, tag+"_val_acc.txt"), "a") as f: 225 | f.write("%d %f\n"%(iters, acc_val)) 226 | with open(os.path.join(args.acc_dir, tag+"_test_loss.txt"), "a") as f: 227 | f.write("%d %f\n"%(iters, loss_test)) 228 | 229 | if "SWA" not in tag: 230 | logger.loss_train_list.append(loss_train) 231 | logger.train_acc_list.append(acc_train) 232 | 233 | logger.loss_test_list.append(loss_test) 234 | logger.test_acc_list.append(acc_test) 235 | 236 | logger.loss_val_list.append(loss_val) 237 | logger.val_acc_list.append(acc_val) 238 | else: 239 | if tag =="SWAG": 240 | logger.swag_train_acc_list.append(acc_train) 241 | logger.swag_val_acc_list.append(acc_val) 242 | logger.swag_test_acc_list.append(acc_test) 243 | else: 244 | logger.swa_train_acc_list.append(acc_train) 245 | logger.swa_val_acc_list.append(acc_val) 246 | logger.swa_test_acc_list.append(acc_test) 247 | 248 | def put_oracle_log(logger, ens_train_acc, ens_val_acc, ens_test_acc, iters=-1): 249 | if iters>=0 and iters%args.log_ep!= 0: 250 | return 251 | logger.ens_train_acc_list.append(ens_train_acc) 252 | logger.ens_test_acc_list.append(ens_test_acc) 253 | logger.ens_val_acc_list.append(ens_val_acc) 254 | 255 | tag = "ens" 256 | if iters==0: 257 | open(os.path.join(args.acc_dir, tag+"_train_acc.txt"), "w") 258 | open(os.path.join(args.acc_dir, tag+"_val_acc.txt"), "w") 259 | open(os.path.join(args.acc_dir, tag+"_test_acc.txt"), "w") 260 | 261 | with open(os.path.join(args.acc_dir, tag+"_train_acc.txt"), "a") as f: 262 | f.write("%d %f\n"%(iters, ens_train_acc)) 263 | with open(os.path.join(args.acc_dir, tag+"_test_acc.txt"), "a") as f: 264 | f.write("%d %f\n"%(iters, ens_test_acc)) 265 | with open(os.path.join(args.acc_dir, tag+"_val_acc.txt"), "a") as f: 266 | f.write("%d %f\n"%(iters, ens_val_acc)) 267 | 268 | dist_logger = logger("DIST") 269 | fedavg_logger = logger("FedAvg") 270 | work_tag = args.update 271 | 272 | teachers = [[] for i in range(args.num_users)] 273 | generator = None 274 | best_acc = 0.0 275 | 276 | size_arr = [np.sum(cnts_dict[i]) for i in range(args.num_users)] 277 | for iters in range(args.rounds): 278 | w_glob_org = copy.deepcopy(net_glob.state_dict()) 279 | 280 | net_glob.train() 281 | loss_locals = [] 282 | acc_locals = [] 283 | acc_locals_test = [] 284 | loss_locals_test = [] 285 | acc_locals_val = [] 286 | 287 | m = max(int(args.frac * args.num_users), 1) 288 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 289 | clients = [[] for i in range(args.num_users)] 290 | 291 | num_threads = 5 292 | for i in range(0, m, num_threads): 293 | processes = [] 294 | torch.cuda.empty_cache() 295 | q = mp.Manager().Queue() 296 | 297 | for idx in idxs_users[i:i+num_threads]: 298 | p = mp.Process(target=cliet_train, args=(q, idx%(args.num_gpu), copy.deepcopy(net_glob), iters, idx, server_id, generator)) 299 | p.start() 300 | processes.append(p) 301 | 302 | for p in processes: 303 | p.join() 304 | 305 | while not q.empty(): 306 | fake_out = q.get() 307 | idx = int(fake_out[-1]) 308 | clients[idx].append(fake_out[0]) 309 | 310 | clients = [c[0] for c in clients if len(c)>0] 311 | client_w = [c.state_dict() for c in clients] 312 | 313 | if args.store_model and (iters%args.log_ep==0 or iters==args.rounds-1): 314 | store_model(iters, model_dir, w_glob_org, client_w) 315 | 316 | if args.fedM and iters > 1: 317 | w_glob_avg, momentum = FedAvgM(client_w, args.num_gpu-1, (w_glob_org, momentum), args.mom, size_arr=size_arr) 318 | else: 319 | w_glob_avg = FedAvg(client_w, args.num_gpu-1, size_arr=size_arr) 320 | momentum = {k:w_glob_org[k]-w_glob_avg[k] for k in w_glob_avg.keys()} 321 | 322 | net_glob.load_state_dict(w_glob_avg) 323 | 324 | if iters%args.log_ep== 0: 325 | put_log(fedavg_logger, net_glob, tag='FedAvg', iters=iters) 326 | 327 | # Generate Teachers 328 | # Two modes for base teachers: SWAG and FedAvg 329 | teachers_list = [] 330 | 331 | if not args.dont_add_fedavg: 332 | print("add FedAvg to teachers") 333 | teachers_list.append(copy.deepcopy(net_glob)) # Add FedAvg 334 | 335 | if args.teacher_type=="SWAG" and iters > args.warmup_ep: 336 | for i in range(args.num_sample_teacher): 337 | base_teachers = client_w 338 | swag_model = SWAG_server(args, w_glob_org, avg_model=w_glob_avg, concentrate_num=1, size_arr=size_arr) 339 | w_swag = swag_model.construct_models(base_teachers, mode=args.sample_teacher) 340 | net_glob.load_state_dict(w_swag) 341 | teachers_list.append(copy.deepcopy(net_glob)) 342 | else: 343 | base_teachers = client_w 344 | print("Warming up, using DIST.") 345 | 346 | if args.use_client: 347 | teachers_list+=clients 348 | 349 | # Load weights for server training 350 | net_glob.load_state_dict(w_glob_avg) 351 | print("Initialize with FedAvg for server training ...") 352 | # update global weights 353 | q = mp.Manager().Queue() 354 | print("Server training...") 355 | 356 | p = mp.Process(target=server_train, args=(q, args.num_gpu-1, net_glob, teachers_list, iters)) 357 | p.start() 358 | p.join() 359 | 360 | [w_glob_mean, w_glob, ens_train_acc, ens_val_acc, ens_test_acc, entropy] = q.get() 361 | del q 362 | 363 | if best_acc < ens_test_acc: 364 | best_acc = ens_test_acc 365 | 366 | if iters%args.log_ep== 0: 367 | net_glob.load_state_dict(w_glob_mean) 368 | put_log(dist_logger, net_glob, tag='DIST-SWA', iters=iters) 369 | 370 | net_glob.load_state_dict(w_glob) 371 | put_log(dist_logger, net_glob, tag='DIST', iters=iters) 372 | put_oracle_log(dist_logger, ens_train_acc, ens_val_acc, ens_test_acc, iters=iters) 373 | 374 | if args.update=='FedAvg': 375 | net_glob.load_state_dict(w_glob_avg) 376 | print("Sending back FedAvg!") 377 | else: 378 | if args.use_SWA: 379 | net_glob.load_state_dict(w_glob_mean) 380 | print("Sending back student w/ SWA!") 381 | else: 382 | net_glob.load_state_dict(w_glob) 383 | print("Sending back student w/o SWA!") 384 | 385 | if args.store_model and iters == args.rounds-1: 386 | store_model(iters, model_dir, w_glob_org, client_w) 387 | 388 | del clients 389 | 390 | torch.save(net_glob.state_dict(), os.path.join(args.log_dir, "model")) 391 | -------------------------------------------------------------------------------- /models/Update.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | from torch import nn, autograd 7 | from torch.utils.data import DataLoader, Dataset 8 | import numpy as np 9 | import random, pdb, os 10 | from sklearn import metrics 11 | import torch.nn.functional as F 12 | from scipy.stats import entropy 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | import copy 16 | from models.test import test_img 17 | from torchvision import datasets, transforms 18 | 19 | from scipy.stats import mode 20 | from utils.tools import * 21 | from sklearn.utils import shuffle 22 | from PIL import Image 23 | 24 | import torch.multiprocessing as mp 25 | from models.swa import SWA 26 | 27 | class SWAGLocalUpdate(object): 28 | def __init__(self, args, device, dataset=None, idxs=None, server_ids=None, test=(None, None), num_per_cls=None): 29 | self.args = args 30 | self.device = device 31 | self.num_per_cls = num_per_cls 32 | 33 | self.loss_func = nn.CrossEntropyLoss().to(self.device) 34 | 35 | self.selected_clients = [] 36 | self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True) 37 | (self.test_dataset, self.test_ids) = test 38 | (self.train_dataset, self.user_train_ids) = (dataset, idxs) 39 | 40 | self.server_ids = server_ids 41 | 42 | def apply_weight_decay(self, *modules, weight_decay_factor=0., wo_bn=True): 43 | ''' 44 | https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/5 45 | Apply weight decay to pytorch model without BN; 46 | In pytorch: 47 | if group['weight_decay'] != 0: 48 | grad = grad.add(p, alpha=group['weight_decay']) 49 | p is the param; 50 | :param modules: 51 | :param weight_decay_factor: 52 | :return: 53 | ''' 54 | for module in modules: 55 | for m in module.modules(): 56 | if hasattr(m, 'weight'): 57 | if wo_bn and isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 58 | continue 59 | m.weight.grad += m.weight * weight_decay_factor 60 | 61 | def reg_loss(self, net, grad_org): 62 | if self.args.reg_type == "FedProx": 63 | reg_loss = 0.0 64 | for name, param in net.named_parameters(): 65 | if 'weight' in name: 66 | reg_loss += torch.norm(param-grad_org[name].to(self.device), 2) 67 | reg_loss = reg_loss*0.5*self.args.mu 68 | return reg_loss 69 | 70 | def train(self, net, running_ep, lr): 71 | net.cpu() 72 | grad_org = copy.deepcopy(net.state_dict()) 73 | net.to(self.device) 74 | net.train() 75 | 76 | # train and update 77 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9) 78 | if self.args.ens: 79 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 80 | step_size=30, 81 | gamma=0.1) 82 | epoch_loss = [] 83 | acc = 0.0 84 | 85 | num_model = 0 86 | cnt = 0 87 | for iter in range(running_ep): 88 | batch_loss = [] 89 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 90 | images, labels = images.to(self.device), labels.to(self.device) 91 | net.zero_grad() 92 | log_probs = net(images) 93 | loss = self.loss_func(log_probs, labels) 94 | 95 | if self.args.reg_type == "FedProx": 96 | reg_loss = self.reg_loss(net, grad_org) 97 | loss += reg_loss 98 | 99 | loss.backward() 100 | self.apply_weight_decay(net, weight_decay_factor=self.args.weight_decay) 101 | optimizer.step() 102 | 103 | if self.args.verbose and batch_idx % 10 == 0: 104 | print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 105 | iter, batch_idx * len(images), len(self.ldr_train.dataset), 106 | 100. * batch_idx / len(self.ldr_train), loss.item())) 107 | batch_loss.append(loss.item()) 108 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 109 | 110 | if self.args.ens: 111 | lr_scheduler.step() 112 | 113 | net = net.cpu() 114 | return net 115 | 116 | class ServerUpdate(object): 117 | def __init__(self, args, device, dataset=None, 118 | server_dataset=None, server_idxs=None, train_idx=None, 119 | test=(None, None), 120 | w_org=None, base_teachers=None): 121 | 122 | self.args = args 123 | self.device = device 124 | self.loss_type = args.loss_type 125 | self.loss_func = nn.KLDivLoss() if self.loss_type =="KL" else nn.CrossEntropyLoss() 126 | self.selected_clients = [] 127 | 128 | self.server_data_size = len(server_idxs) 129 | self.aug = args.aug and args.use_SWA 130 | self.ldr_train = DataLoader(DatasetSplit(dataset, server_idxs), batch_size=1024, shuffle=False) 131 | self.ldr_local_train = DataLoader(DatasetSplit(dataset, train_idx), batch_size=self.args.server_bs, shuffle=False) 132 | self.test_dataset = DataLoader(test[0], batch_size=self.args.server_bs, shuffle=False) 133 | self.aum_dir = os.path.join(self.args.log_dir, "aum") 134 | 135 | server_train_dataset = DataLoader(DatasetSplit(server_dataset, server_idxs), batch_size=self.args.server_bs, shuffle=False) 136 | self.server_train_dataset = [images for images, labels in server_train_dataset] 137 | 138 | self.w_org = w_org 139 | self.base_teachers = base_teachers 140 | 141 | # Get one batch for testing 142 | (self.eval_images, self.eval_labels) = next(iter(self.ldr_train)) 143 | 144 | def transform_train(self, images): 145 | images = random_crop(images, 4) 146 | images = torch.Tensor(images).cuda() 147 | return images 148 | 149 | def get_ensemble_logits(self, teachers, inputs, method='mean', global_ep=1000): 150 | logits = np.zeros((len(teachers), len(inputs), self.args.num_classes)) 151 | for i, t_net in enumerate(teachers): 152 | logit = get_input_logits(inputs, t_net.cuda(), is_logit = self.args.is_logit) #Disable res 153 | logits[i] = logit 154 | 155 | logits = np.transpose(logits, (1, 0, 2)) # batchsize, teachers, 10 156 | logits_arr, logits_cond = merge_logits(logits, method, self.args.loss_type, temp=self.args.temp, global_ep=global_ep) 157 | batch_entropy = get_entropy(logits.reshape((-1, self.args.num_classes))) 158 | return logits_arr, batch_entropy 159 | 160 | def eval_ensemble(self, teachers, dataset): 161 | acc = 0.0 162 | cnt = 0 163 | 164 | if self.args.soft_vote: 165 | num_votes_list, soft_vote = get_aum(self.args, teachers, dataset) 166 | for batch_idx, (_, labels) in enumerate(dataset): 167 | logits = soft_vote[batch_idx] 168 | logits=np.argmax(logits, axis=-1) 169 | acc += np.sum(logits==labels.numpy()) 170 | cnt += len(labels) 171 | 172 | else: 173 | for batch_idx, (images, labels) in enumerate(dataset): 174 | images = images.cuda() 175 | logits, _ = self.get_ensemble_logits(teachers, images, method=self.args.logit_method, global_ep=1000) 176 | 177 | if self.args.logit_method != "vote": 178 | logits=np.argmax(logits, axis=-1) 179 | 180 | acc += np.sum(logits==labels.numpy()) 181 | cnt += len(labels) 182 | 183 | return float(acc)/cnt*100.0 184 | 185 | def loss_wrapper(self, log_probs, logits, labels): 186 | # Modify target logits 187 | if self.loss_type=="CE": 188 | if self.args.logit_method != "vote": 189 | logits = np.argmax(logits, axis=-1) 190 | acc_cnt=np.sum(logits==labels) 191 | cnt=len(labels) 192 | logits = torch.Tensor(logits).long().cuda(non_blocking=True) 193 | 194 | else: 195 | acc_cnt=np.sum(np.argmax(logits, axis=-1)==labels) 196 | cnt=len(labels) 197 | logits = torch.Tensor(logits).cuda(non_blocking=True) 198 | 199 | 200 | # For loss function 201 | if self.args.use_oracle: 202 | loss = nn.CrossEntropyLoss()(log_probs, torch.Tensor(labels).long().cuda()) 203 | else: 204 | if "KL" in self.loss_type: 205 | log_probs = F.softmax(log_probs, dim=-1) 206 | if self.loss_type== "reverse_KL": 207 | P = log_probs 208 | Q = logits 209 | else: 210 | P = logits 211 | Q = log_probs 212 | 213 | one_vec = (P * (P.log() - torch.Tensor([0.1]).cuda(non_blocking=True).log())) 214 | loss = (P * (P.log() - Q.log())).mean() 215 | else: 216 | loss = self.loss_func(log_probs, logits) 217 | 218 | return loss, acc_cnt, cnt 219 | 220 | def test_net(self, tmp_net): 221 | tmp_net = tmp_net.cuda() 222 | (input, label) = (self.eval_images.cuda(), self.eval_labels.cuda()) 223 | log_probs = tmp_net(input) 224 | loss = nn.CrossEntropyLoss()(log_probs, label) 225 | return not torch.isnan(loss) 226 | 227 | def record_teacher(self, ldr_train, net, teachers, global_ep, log_dir=None, probe=True, resample=False): 228 | entropy = [] 229 | ldr_train = [] 230 | 231 | acc_per_teacher = np.zeros((len(teachers))) 232 | conf_per_teacher = np.zeros((len(teachers))) 233 | teacher_per_sample = 0.0 234 | has_correct_teacher_ratio = 0.0 235 | 236 | num = self.server_data_size 237 | if "cifar" in self.args.dataset: 238 | imgsize = 32 239 | elif "mnist" in self.args.dataset: 240 | imgsize = 28 241 | 242 | channel = 1 if self.args.dataset == "mnist" else 3 243 | all_images = np.zeros((num, channel, imgsize, imgsize)) 244 | all_logits = np.zeros((num, self.args.num_classes)) 245 | all_labels = np.zeros((num)) 246 | cnt = 0 247 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 248 | logits, batch_entropy = self.get_ensemble_logits(teachers, images.cuda(), method=self.args.logit_method, global_ep=global_ep) 249 | entropy.append(batch_entropy) 250 | 251 | all_images[cnt:cnt+len(images)] = images.numpy() 252 | all_logits[cnt:cnt+len(images)] = logits 253 | all_labels[cnt:cnt+len(images)] = labels.numpy() 254 | cnt+=len(images) 255 | 256 | ldr_train = (all_images, all_logits, all_labels) 257 | #============================= 258 | # If args.soft_vote = True: 259 | # soft_vote from experts 260 | # Else: 261 | # just mean of all logits 262 | #============================= 263 | if not probe: 264 | return ldr_train, 0.0, 0.0 265 | else: 266 | test_acc = self.eval_ensemble(teachers, self.test_dataset) 267 | train_acc = self.eval_ensemble(teachers, self.ldr_local_train) 268 | 269 | plt.plot(range(len(teachers)), acc_per_teacher, marker="o", label="Acc") 270 | plt.plot(range(len(teachers)), conf_per_teacher, marker="o", label="Confidence") 271 | plt.plot(range(len(teachers)), conf_per_teacher - acc_per_teacher, marker="o", label="Confidence - Acc") 272 | plt.ylim(ymax = 1.0, ymin = -0.2) 273 | plt.title("Round %d, correct teacher/per sample %.2f, upperbound correct %.1f percentage"%(global_ep, teacher_per_sample,has_correct_teacher_ratio*100.0)) 274 | plt.legend(loc='best') 275 | plt.savefig(os.path.join(log_dir, "acc_per_teacher_%d.png"% global_ep)) 276 | plt.clf() 277 | 278 | return ldr_train, train_acc, test_acc 279 | 280 | def set_opt(self, net): 281 | base_opt = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.00001) 282 | if self.args.use_SWA: 283 | self.optimizer = SWA(base_opt, swa_start=500, swa_freq=25, swa_lr=None) 284 | else: 285 | self.optimizer = base_opt 286 | 287 | def train(self, net, teachers, log_dir, global_ep, server_dataset=None): 288 | #======================Record teachers======================== 289 | self.set_opt(net) 290 | 291 | to_probe = True if global_ep%self.args.log_ep==0 else False 292 | ldr_train = [] 293 | ldr_train, train_acc, test_acc = self.record_teacher(ldr_train, net, teachers, global_ep, log_dir, probe=to_probe) 294 | (all_images, all_logits, all_labels) = ldr_train 295 | #======================Server Train======================== 296 | print("Start server training...") 297 | net.cuda() 298 | net.train() 299 | 300 | epoch_loss = [] 301 | acc = 0 302 | cnt = 0 303 | 304 | step = 0 305 | train_ep = self.args.server_ep 306 | for iter in range(train_ep): 307 | all_ids = list(range(len(all_images))) 308 | np.random.shuffle(all_ids) 309 | 310 | batch_loss = [] 311 | for batch_idx in range(0, len(all_images), self.args.server_bs): 312 | ids = all_ids[batch_idx:batch_idx+self.args.server_bs] 313 | images = all_images[ids] 314 | 315 | if self.aug: 316 | images = self.transform_train(images) 317 | else: 318 | images = torch.Tensor(images).cuda() 319 | logits = all_logits[ids] 320 | labels = all_labels[ids] 321 | 322 | net.zero_grad() 323 | log_probs = net(images) 324 | 325 | loss, acc_cnt_i, cnt_i = self.loss_wrapper(log_probs, logits, labels) 326 | acc+=acc_cnt_i 327 | cnt+=cnt_i 328 | loss.backward() 329 | 330 | self.optimizer.step() 331 | step+=1 332 | 333 | if batch_idx == 0 and iter%5==0: 334 | print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 335 | iter, batch_idx * len(images), len(self.ldr_train.dataset), 336 | 100. * batch_idx / len(self.ldr_train), loss.item())) 337 | batch_loss.append(loss.item()) 338 | 339 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 340 | 341 | val_acc = float(acc)/cnt*100.0 342 | net_glob = copy.deepcopy(net) 343 | 344 | if self.args.use_SWA: 345 | self.optimizer.swap_swa_sgd() 346 | if "resnet" in self.args.model: 347 | self.optimizer.bn_update(self.ldr_train, net, device=None) 348 | 349 | net = net.cpu() 350 | w_glob_avg = copy.deepcopy(net.state_dict()) 351 | w_glob = net_glob.cpu().state_dict() 352 | 353 | print("Ensemble Acc Train %.2f Val %.2f Test %.2f mean entropy %.5f"%(train_acc, val_acc, test_acc, 0.0)) 354 | return w_glob_avg, w_glob, train_acc, val_acc, test_acc, sum(epoch_loss) / len(epoch_loss), 0.0 355 | 356 | 357 | def check_size(size): 358 | if type(size) == int: 359 | size = (size, size) 360 | if type(size) != tuple: 361 | raise TypeError('size is int or tuple') 362 | return size 363 | 364 | def random_crop(images, crop_size): 365 | for i, image in enumerate(images): 366 | image = np.pad(image, crop_size) 367 | _, h, w = image.shape 368 | top = np.random.randint(0, crop_size*2) 369 | left = np.random.randint(0, crop_size*2) 370 | bottom = top + (h - 2*crop_size) 371 | right = left + (w - 2*crop_size) 372 | 373 | images[i] = image[crop_size:-crop_size, top:bottom, left:right] 374 | 375 | return images 376 | 377 | def horizontal_flip(image, rate=0.5): 378 | if np.random.rand() < rate: 379 | #image = image[:, :, :, ::-1] 380 | image = np.flip(image, axis=-1) 381 | return image 382 | 383 | --------------------------------------------------------------------------------