├── Algorithm
│   ├── Training_FedASAM.py
│   ├── Training_FedExP.py
│   ├── Training_FedGen.py
│   ├── Training_FedIndenp.py
│   ├── Training_FedMR.py
│   ├── Training_FedMut.py
│   └── __init__.py
├── LICENSE
├── README.md
├── dis_plot.py
├── main_fed.py
├── main_nn.py
├── models
│   ├── Fed.py
│   ├── MobileNetV2.py
│   ├── Nets.py
│   ├── Update.py
│   ├── __init__.py
│   ├── at.py
│   ├── generator.py
│   ├── lstm.py
│   ├── models.py
│   ├── resnetcifar.py
│   └── test.py
├── optimizer
│   ├── Adabelief.py
│   └── __init__.py
├── test.py
└── utils
    ├── Clients.py
    ├── FEMNIST.py
    ├── ShakeSpeare.py
    ├── __init__.py
    ├── clustering.py
    ├── dataset_utils.py
    ├── get_dataset.py
    ├── language_utils.py
    ├── model_config.py
    ├── mydata.py
    ├── options.py
    ├── sam_minimizers.py
    ├── sampling.py
    ├── set_seed.py
    └── utils.py

--------------------------------------------------------------------------------
/Algorithm/Training_FedASAM.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import copy
import numpy as np
import random
from models.Fed import Aggregation
from utils.utils import save_result
from models.test import test_img
from models.Update import DatasetSplit
from optimizer.Adabelief import AdaBelief
from utils.sam_minimizers import ASAM


class LocalUpdate_FedASAM(object):
    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.ensemble_alpha = args.ensemble_alpha
        self.verbose = verbose
        self.mixup = False
        self.mixup_alpha = 1.0
        self.rho = 0.1
        self.eta = 0

    def train(self, net):
        net.to(self.args.device)
        net.train()

        # train and update
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'adaBelief':
            optimizer = AdaBelief(net.parameters(), lr=self.args.lr)

        Predict_loss = 0

        for iter in range(self.args.local_ep):
            minimizer = ASAM(optimizer, net, self.rho, self.eta)
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                model_output = net(images)
                if self.mixup:
                    # mixup_data returns the mixed outputs, not a dict, so use them directly
                    mixed_output, targets_a, targets_b, lam = self.mixup_data(model_output['output'], labels)
                    # Ascent step
                    predictive_loss = self.mixup_criterion(mixed_output, targets_a, targets_b, lam)
                    predictive_loss.backward()
                    minimizer.ascent_step()
                    # Descent step
                    loss = self.mixup_criterion(net(images)['output'], targets_a, targets_b, lam)
                    loss.backward()
                    minimizer.descent_step()
                else:
                    # Ascent step
                    predictive_loss = self.loss_func(model_output['output'], labels)
                    predictive_loss.backward()
                    minimizer.ascent_step()
                    # Descent step
                    self.loss_func(net(images)['output'], labels).backward()
                    minimizer.descent_step()

                Predict_loss += predictive_loss.item()

        if self.verbose:
            info = '\nUser predict Loss={:.4f}'.format(Predict_loss / (self.args.local_ep * len(self.ldr_train)))
            print(info)

        return net.state_dict()

    def mixup_data(self, x, y):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if self.mixup_alpha > 0:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(self.args.device)  # fixed: was self.device
        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, pred, y_a, y_b, lam):
        return lam * self.loss_func(pred, y_a) + (1 - lam) * self.loss_func(pred, y_b)


def FedASAM(args, net_glob, dataset_train, dataset_test, dict_users):
    net_glob.train()

    # training
    acc = []

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        w_locals = []
        lens = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate_FedASAM(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            lens.append(len(dict_users[idx]))

        # update global weights
        w_glob = Aggregation(w_locals, lens)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        if iter % 10 == 9:
            item_acc = test(net_glob, dataset_test, args)
            acc.append(item_acc)

    save_result(acc, 'test_acc', args)


def test(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item()
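The training loop above depends on `ASAM` from `utils/sam_minimizers.py`, which this listing does not include. For orientation, here is a minimal sketch of the two-phase interface the loop assumes, modeled on the reference ASAM implementation: `ascent_step` moves each weight to the (approximate) adaptive worst case inside a `rho`-ball, `descent_step` undoes the perturbation and applies the real optimizer step. The repository's own class may differ in details; treat this as an assumed interface, not the project's code.

# Minimal sketch (assumed interface) of the ASAM helper used above.
# Not the repository's implementation: see utils/sam_minimizers.py for the real one.
import torch
from collections import defaultdict

class ASAM:
    def __init__(self, optimizer, model, rho=0.5, eta=0.01):
        self.optimizer = optimizer
        self.model = model
        self.rho = rho
        self.eta = eta
        self.state = defaultdict(dict)

    @torch.no_grad()
    def ascent_step(self):
        # Scale gradients by the adaptive factor |w| + eta, compute the overall
        # gradient norm, then perturb each weight toward the worst point in the rho-ball.
        wgrads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if t_w is None:
                t_w = torch.clone(p).detach()
                self.state[p]["eps"] = t_w
            if "weight" in n:
                t_w[...] = p[...]
                t_w.abs_().add_(self.eta)
                p.grad.mul_(t_w)
            wgrads.append(torch.norm(p.grad, p=2))
        wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if "weight" in n:
                p.grad.mul_(t_w)
            eps = t_w
            eps[...] = p.grad[...]
            eps.mul_(self.rho / wgrad_norm)
            p.add_(eps)
        self.optimizer.zero_grad()

    @torch.no_grad()
    def descent_step(self):
        # Undo the perturbation, then take the real optimizer step using the
        # gradients computed at the perturbed point.
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            p.sub_(self.state[p]["eps"])
        self.optimizer.step()
        self.optimizer.zero_grad()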
--------------------------------------------------------------------------------
/Algorithm/Training_FedExP.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import copy
import numpy as np
import random
from models.Fed import Aggregation, Sub, Mul, Div, Add
from utils.utils import save_result
from models.test import test_img, branchy_test_img
from models.Update import DatasetSplit
from optimizer.Adabelief import AdaBelief
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def FedExP(args, net_glob, dataset_train, dataset_test, dict_users):
    net_glob.train()

    # training
    acc = []
    loss = []

    grad_norm_avg_running = 0

    # p[i]: fraction of the training data held by client i
    p = np.zeros(args.num_users)
    for i in range(args.num_users):
        p[i] = len(dict_users[i])
    p = p / np.sum(p)

    w_vec_estimate = parameters_to_vector(net_glob.parameters())

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        w_locals = []
        lens = []
        grad_norm_sum = 0
        p_sum = 0
        grad_avg = copy.deepcopy(net_glob.state_dict())
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        tag = 0
        for idx in idxs_users:
            local = LocalUpdate_FedExP(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, grad_local = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            w_grad = Sub(w, net_glob.state_dict())
            grad = parameters_to_vector(grad_local) - parameters_to_vector(net_glob.parameters())
            grad_norm_sum += p[idx] * torch.linalg.norm(grad) ** 2
            w_grad = Mul(w_grad, p[idx])
            if tag == 0:
                grad_avg = w_grad
            else:
                grad_avg = Add(grad_avg, w_grad)
            p_sum += p[idx]
            lens.append(len(dict_users[idx]))
            tag += 1

        # update global weights
        w_glob = Aggregation(w_locals, lens)

        with torch.no_grad():
            grad_avg = Div(grad_avg, p_sum)
            grad_norm_avg = grad_norm_sum / p_sum
            grad_norm_avg_running = grad_norm_avg + 0.9 * 0.5 * grad_norm_avg_running
            net_eval = copy.deepcopy(net_glob)
            net_eval.load_state_dict(grad_avg)
            grad_avg_norm = torch.linalg.norm(parameters_to_vector(net_eval.parameters())) ** 2

            # adaptive global step size, clipped below at 1 (plain FedAvg)
            eta_g = 0.5 * grad_norm_avg / (grad_avg_norm + m * 0.1)
            eta_g = max(1, eta_g)

            w_vec_prev = w_vec_estimate

            w_vec_estimate = Add(net_glob.state_dict(), Mul(grad_avg, eta_g))

            if iter > 0:
                w_vec_avg = Div(Add(w_vec_estimate, w_vec_prev), 2)
            else:
                w_vec_avg = w_vec_estimate

        # copy weight to net_glob
        net_glob.load_state_dict(w_vec_estimate)

        # evaluate the average of the last two global estimates
        net_eval = copy.deepcopy(net_glob)
        net_eval.load_state_dict(w_vec_avg)

        if iter % 10 == 9:
            item_acc, item_loss = test_img(net_eval, dataset_test, args)
            acc.append(item_acc)
            loss.append(item_loss)

            print("Testing accuracy: {:.2f}".format(item_acc))
            print("Testing loss: {:.2f}".format(item_loss))

    save_result(acc, 'test_acc', args)
    save_result(loss, 'test_loss', args)


class LocalUpdate_FedExP(object):
    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.verbose = verbose

    def train(self, net):
        net.train()
        # train and update
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'adaBelief':
            optimizer = AdaBelief(net.parameters(), lr=self.args.lr)

        Predict_loss = 0
        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)['output']
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()

                Predict_loss += loss.item()

        if self.verbose:
            info = '\nUser predict Loss={:.4f}'.format(Predict_loss / (self.args.local_ep * len(self.ldr_train)))
            print(info)

        return net.state_dict(), net.parameters()
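`Sub`, `Mul`, `Div`, and `Add`, imported from `models/Fed.py`, are not reproduced in this listing. Judging from how FedExP uses them, they perform element-wise arithmetic over state_dicts; a minimal sketch under that assumption (the repository's models/Fed.py is the authoritative version):

# Assumed element-wise state_dict arithmetic used by FedExP above (sketch only).
import copy

def Sub(w_a, w_b):
    # key-wise difference of two state_dicts
    out = copy.deepcopy(w_a)
    for k in out.keys():
        out[k] = w_a[k] - w_b[k]
    return out

def Add(w_a, w_b):
    # key-wise sum of two state_dicts
    out = copy.deepcopy(w_a)
    for k in out.keys():
        out[k] = w_a[k] + w_b[k]
    return out

def Mul(w, s):
    # scale every tensor in a state_dict by the scalar s
    out = copy.deepcopy(w)
    for k in out.keys():
        out[k] = w[k] * s
    return out

def Div(w, s):
    # divide every tensor in a state_dict by the scalar s
    out = copy.deepcopy(w)
    for k in out.keys():
        out[k] = w[k] / s
    return out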
--------------------------------------------------------------------------------
/Algorithm/Training_FedGen.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import numpy as np
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.generator import Generator
from models.Update import LocalUpdate_FedGen, DatasetSplit
from models.Fed import Aggregation
from models.test import test_img
from utils.utils import save_result
from utils.model_config import FedGenRUNCONFIGS


MIN_SAMPLES_PER_LABEL = 1


def init_configs(args):
    RUNCONFIGS = FedGenRUNCONFIGS
    #### used for ensemble learning ####
    dataset_name = args.dataset
    args.ensemble_lr = RUNCONFIGS[dataset_name].get('ensemble_lr', 1e-4)
    args.ensemble_batch_size = RUNCONFIGS[dataset_name].get('ensemble_batch_size', 128)
    args.ensemble_epochs = RUNCONFIGS[dataset_name]['ensemble_epochs']
    args.num_pretrain_iters = RUNCONFIGS[dataset_name]['num_pretrain_iters']
    args.temperature = RUNCONFIGS[dataset_name].get('temperature', 1)
    args.ensemble_alpha = RUNCONFIGS[dataset_name].get('ensemble_alpha', 1)
    args.ensemble_beta = RUNCONFIGS[dataset_name].get('ensemble_beta', 0)
    args.ensemble_eta = RUNCONFIGS[dataset_name].get('ensemble_eta', 1)
    args.weight_decay = RUNCONFIGS[dataset_name].get('weight_decay', 0)
    args.generative_alpha = RUNCONFIGS[dataset_name]['generative_alpha']
    args.generative_beta = RUNCONFIGS[dataset_name]['generative_beta']
    args.ensemble_train_loss = []
    args.n_teacher_iters = 5
    args.n_student_iters = 1


def read_user_data(args, dataset_train, dict_users):
    label_counts_users = []

    for idx in range(len(dict_users)):
        data_loader = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), len(dict_users[idx]))
        for _, y in data_loader:
            unique_y, counts = torch.unique(y, return_counts=True)
            label_counts = [0 for i in range(args.num_classes)]
            for label, count in zip(unique_y, counts):
                label_counts[int(label)] += count
            label_counts_users.append(label_counts)

    return label_counts_users


def FedGen(args, net_glob, dataset_train, dataset_test, dict_users):
    init_configs(args)

    net_glob.train()

    generative_model = Generator(args.dataset, args.model, embedding=False, latent_layer_idx=-1)

    label_counts = read_user_data(args, dataset_train, dict_users)

    acc = []

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        w_locals = []
        lens = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        user_models = []

        for idx in idxs_users:
            local = LocalUpdate_FedGen(args=args, generative_model=generative_model, dataset=dataset_train,
                                       idxs=dict_users[idx], regularization=iter != 0)
            user_model = local.train(net=copy.deepcopy(net_glob).to(args.device))
            user_models.append(user_model)
            w_locals.append(copy.deepcopy(user_model.state_dict()))
            lens.append(len(dict_users[idx]))
        net_glob.to('cpu')
        train_generator(
            args,
            net_glob,
            generative_model,
            user_models,
            idxs_users,
            label_counts,
            args.bs,
            epoches=args.ensemble_epochs // args.n_teacher_iters,
            latent_layer_idx=-1,
            verbose=True
        )
        net_glob.to(args.device)
        # update global weights
        w_glob = Aggregation(w_locals, lens)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        if iter % 10 == 9:
            acc.append(test(net_glob, dataset_test, args))

    save_result(acc, 'test_acc', args)


def get_label_weights(args, users, label_counts):
    label_weights = []
    qualified_labels = []
    for label in range(args.num_classes):
        weights = []
        for user in users:
            weights.append(label_counts[user][label])
        if np.max(weights) > MIN_SAMPLES_PER_LABEL:
            qualified_labels.append(label)
        # uniform
        label_weights.append(np.array(weights) / np.sum(weights))
    label_weights = np.array(label_weights).reshape((args.num_classes, -1))
    return label_weights, qualified_labels


def train_generator(args, net_glob, generative_model, models, users, label_counts, batch_size,
                    epoches=1, latent_layer_idx=-1, verbose=False):
    """
    Learn a generator that finds a consensus latent representation z, given a label 'y'.
    :param batch_size:
    :param epoches:
    :param latent_layer_idx: if set to -1 (-2), get the latent representation of the last (or second-to-last) layer.
    :param verbose: print loss information.
    :return: Does not return anything.
    """
    label_weights, qualified_labels = get_label_weights(args, users, label_counts)
    TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS = 0, 0, 0

    generative_optimizer = torch.optim.Adam(params=generative_model.parameters(),
                                            lr=args.ensemble_lr, weight_decay=args.weight_decay)

    for _ in range(epoches):
        generative_model.train()
        net_glob.eval()

        for _ in range(args.n_teacher_iters):
            generative_optimizer.zero_grad()

            y = np.random.choice(qualified_labels, batch_size)
            y_input = torch.LongTensor(y)
            ## feed to generator; get an approximation of z (latent) if latent is set, X (raw image) otherwise
            gen_result = generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True)
            gen_output, eps = gen_result['output'], gen_result['eps']

            ##### get losses #####
            diversity_loss = generative_model.diversity_loss(eps, gen_output)  # encourage different outputs

            ######### get teacher loss ############
            teacher_loss = 0
            teacher_logit = 0
            for user_idx, user_model in enumerate(models):
                weight = label_weights[y][:, user_idx].reshape(-1, 1)
                expand_weight = np.tile(weight, (1, args.num_classes))
                user_result_given_gen = user_model(gen_output, start_layer_idx=latent_layer_idx)
                user_output_logp_ = user_result_given_gen['output']
                teacher_loss_ = torch.mean(
                    generative_model.crossentropy_loss(user_output_logp_, y_input) *
                    torch.tensor(weight, dtype=torch.float32))
                teacher_loss += teacher_loss_
                teacher_logit += user_result_given_gen['output'] * torch.tensor(expand_weight, dtype=torch.float32)

            ######### get student loss ############
            student_output = net_glob(gen_output, start_layer_idx=latent_layer_idx)
            student_loss = F.kl_div(F.log_softmax(student_output['output'], dim=1), F.softmax(teacher_logit, dim=1))
            if args.ensemble_beta > 0:
                loss = args.ensemble_alpha * teacher_loss - args.ensemble_beta * student_loss \
                       + args.ensemble_eta * diversity_loss
            else:
                loss = args.ensemble_alpha * teacher_loss + args.ensemble_eta * diversity_loss
            loss.backward()
            generative_optimizer.step()
            TEACHER_LOSS += args.ensemble_alpha * teacher_loss.item()
            STUDENT_LOSS += args.ensemble_beta * student_loss.item()
            DIVERSITY_LOSS += args.ensemble_eta * diversity_loss.item()

    TEACHER_LOSS = TEACHER_LOSS / (args.n_teacher_iters * epoches)
    STUDENT_LOSS = STUDENT_LOSS / (args.n_teacher_iters * epoches)
    DIVERSITY_LOSS = DIVERSITY_LOSS / (args.n_teacher_iters * epoches)
    info = "Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}".format(
        TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
    if verbose:
        print(info)


def test(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item()
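To make the shapes in `get_label_weights` concrete, here is a toy run with hypothetical label counts for two selected users and three classes. Row `label_weights[c]` holds each user's share of class `c`; `train_generator` uses these shares to weight the per-user teacher losses and logits:

# Toy illustration of get_label_weights (hypothetical counts; 2 users, 3 classes).
import numpy as np

label_counts = {0: [10, 0, 5], 1: [2, 8, 5]}  # user -> per-class sample counts
users = [0, 1]
num_classes = 3

label_weights = []
qualified_labels = []
for label in range(num_classes):
    weights = [label_counts[u][label] for u in users]
    if np.max(weights) > 1:  # MIN_SAMPLES_PER_LABEL
        qualified_labels.append(label)
    label_weights.append(np.array(weights) / np.sum(weights))

label_weights = np.array(label_weights).reshape((num_classes, -1))
print(label_weights)       # [[0.833 0.167], [0. 1.], [0.5 0.5]]
print(qualified_labels)    # [0, 1, 2]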
--------------------------------------------------------------------------------
/Algorithm/Training_FedIndenp.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import copy
import numpy as np
import random
from models.Fed import Aggregation
from utils.utils import save_result, save_model
from models.test import test_img
from models.Update import DatasetSplit
from optimizer.Adabelief import AdaBelief


class LocalUpdate_FedIndep(object):
    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.verbose = verbose

    def train(self, net):
        net.to(self.args.device)
        net.train()
        # train and update
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'adaBelief':
            optimizer = AdaBelief(net.parameters(), lr=self.args.lr)

        Predict_loss = 0

        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                model_output = net(images)
                predictive_loss = self.loss_func(model_output['output'], labels)

                loss = predictive_loss
                Predict_loss += predictive_loss.item()

                loss.backward()
                optimizer.step()

        if self.verbose:
            info = '\nUser predict Loss={:.4f}'.format(Predict_loss / (self.args.local_ep * len(self.ldr_train)))
            print(info)

        return net.state_dict()


def FedIndep(args, net_glob, dataset_train, dataset_test, dict_users):
    net_glob.train()

    acc = []
    w_locals = []
    sim_arr = []
    loss = []
    train_loss = []
    indep_loss = []
    indep_acc = []
    indep_train_loss = []

    m = max(int(args.frac * args.num_users), 1)
    for i in range(m):
        w_locals.append(copy.deepcopy(net_glob.state_dict()))
        indep_loss.append([])
        indep_acc.append([])
        indep_train_loss.append([])

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for i, idx in enumerate(idxs_users):
            net_glob.load_state_dict(w_locals[i])
            local = LocalUpdate_FedIndep(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=net_glob)
            w_locals[i] = copy.deepcopy(w)

        # update global weights
        w_glob = Aggregation(w_locals, None)  # Global Model Generation

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        if iter % 10 == 9:
            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
            ta, tl = test_with_loss(net_glob, dataset_train, args)
            acc.append(item_acc)
            loss.append(item_loss)
            train_loss.append(tl)
            sim_arr.append(sim(args, w_locals))
            for indep in range(m):
                net_glob.load_state_dict(w_locals[indep])
                item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
                ta, tl = test_with_loss(net_glob, dataset_train, args)
                indep_acc[indep].append(item_acc)
                indep_loss[indep].append(item_loss)
                indep_train_loss[indep].append(tl)

    save_result(acc, 'test_acc', args)
    save_result(sim_arr, 'sim', args)
    save_result(loss, 'test_loss', args)
    save_model(w_glob, 'test_model_glob', args)  # fixed: previously reused the stale loop index i in the file name
    for i in range(m):
        save_result(indep_acc[i], 'test_acc_' + str(i), args)
        save_result(indep_loss[i], 'test_loss' + str(i), args)
        save_result(indep_train_loss[i], 'test_train_loss' + str(i), args)
        save_model(w_locals[i], 'test_model' + str(i), args)


def test(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item()


def sim(args, net_glob_arr):
    model_num = int(args.num_users * args.frac)
    sim_tab = [[0 for _ in range(model_num)] for _ in range(model_num)]
    sum_sim = 0.0
    for k in range(model_num):
        sim_arr = []
        for j in range(k):
            s = 0.0
            dict_a = torch.Tensor(0)
            dict_b = torch.Tensor(0)
            cnt = 0
            for p in net_glob_arr[k].keys():
                a = net_glob_arr[k][p].view(-1)
                b = net_glob_arr[j][p].view(-1)

                if cnt == 0:
                    dict_a = a
                    dict_b = b
                else:
                    dict_a = torch.cat((dict_a, a), dim=0)
                    dict_b = torch.cat((dict_b, b), dim=0)

                # group every five consecutive layers into a sub-vector
                if cnt % 5 == 0:
                    sub_a = a
                    sub_b = b
                else:
                    sub_a = torch.cat((sub_a, a), dim=0)
                    sub_b = torch.cat((sub_b, b), dim=0)

                if cnt % 5 == 4:
                    s += F.cosine_similarity(sub_a, sub_b, dim=0)
                cnt += 1
            s += F.cosine_similarity(sub_a, sub_b, dim=0)
            sim = F.cosine_similarity(dict_a, dict_b, dim=0)
            sim_arr.append(sim)
            sim_tab[k][j] = sim
            sim_tab[j][k] = sim
            sum_sim += copy.deepcopy(s)
    # average over the 45 model pairs (C(10, 2) for 10 models) and the layer groups
    l = int(len(net_glob_arr[0].keys()) / 5) + 1.0
    sum_sim /= (45.0 * l)
    return sum_sim


def test_with_loss(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item(), loss_test
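The `sim` helper above flattens every tensor of two state_dicts into single long vectors and takes their cosine similarity, accumulating grouped sub-similarities over all model pairs (the hard-coded 45.0 divisor is C(10, 2) = 45, the number of pairs when ten local models are kept). The core pairwise measurement, isolated with hypothetical toy weights:

# Core of the pairwise measurement inside sim(): cosine similarity of fully
# flattened state_dicts (hypothetical toy weights).
import torch
import torch.nn.functional as F

def flatten_state_dict(w):
    return torch.cat([t.view(-1).float() for t in w.values()], dim=0)

w_a = {'fc.weight': torch.ones(2, 3), 'fc.bias': torch.zeros(2)}
w_b = {'fc.weight': torch.ones(2, 3) * 2, 'fc.bias': torch.zeros(2)}

print(F.cosine_similarity(flatten_state_dict(w_a), flatten_state_dict(w_b), dim=0))
# tensor(1.) -- the two weight vectors are parallel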
--------------------------------------------------------------------------------
/Algorithm/Training_FedMR.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import copy
import numpy as np
import random
from models.Fed import Aggregation
from utils.utils import save_result
from utils.utils import save_model
from models.test import test_img
from models.Update import DatasetSplit
from optimizer.Adabelief import AdaBelief


class LocalUpdate_FedMR(object):
    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.verbose = verbose

    def train(self, net):
        net.to(self.args.device)
        net.train()
        # train and update
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'adaBelief':
            optimizer = AdaBelief(net.parameters(), lr=self.args.lr)

        Predict_loss = 0

        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                model_output = net(images)
                predictive_loss = self.loss_func(model_output['output'], labels)

                loss = predictive_loss
                Predict_loss += predictive_loss.item()

                loss.backward()
                optimizer.step()

        if self.verbose:
            info = '\nUser predict Loss={:.4f}'.format(Predict_loss / (self.args.local_ep * len(self.ldr_train)))
            print(info)

        return net.state_dict()


def recombination(w_locals, m):
    w_locals_new = copy.deepcopy(w_locals)

    nr = [i for i in range(m)]

    # shuffle the model assignment independently for every layer
    for k in w_locals[0].keys():
        random.shuffle(nr)
        for i in range(m):
            w_locals_new[i][k] = w_locals[nr[i]][k]

    return w_locals_new


def recombination_partition(w_locals, m, partition):
    w_locals_new = copy.deepcopy(w_locals)

    nr = [i for i in range(m)]
    random.shuffle(nr)

    idx = 0.0
    layer_num = len(w_locals[0].keys())
    cnt = 0
    for k in w_locals[0].keys():
        # draw a new permutation whenever a partition boundary is crossed;
        # partition == 0 reshuffles at every layer
        if (partition == 0) or idx >= layer_num * partition * cnt:
            random.shuffle(nr)
            cnt = cnt + 1
        for i in range(m):
            w_locals_new[i][k] = w_locals[nr[i]][k]
        idx = idx + 1.0

    return w_locals_new


def FedMR(args, net_glob, dataset_train, dataset_test, dict_users):
    net_glob.train()

    acc = []
    w_locals = []
    sim_arr = []
    loss = []
    train_loss = []

    m = max(int(args.frac * args.num_users), 1)
    for i in range(m):
        w_locals.append(copy.deepcopy(net_glob.state_dict()))

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for i, idx in enumerate(idxs_users):
            net_glob.load_state_dict(w_locals[i])
            local = LocalUpdate_FedMR(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=net_glob)
            w_locals[i] = copy.deepcopy(w)

        # update global weights
        w_glob = Aggregation(w_locals, None)  # Global Model Generation

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        if iter % 10 == 9:
            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
            tc, tl = test_with_loss(net_glob, dataset_train, args)
            acc.append(item_acc)
            loss.append(item_loss)
            train_loss.append(tl)
            sim_arr.append(sim(args, w_locals))

        if iter >= args.first_stage_bound:
            w_locals = recombination(w_locals, m)  # Model Recombination
        else:
            # first stage: plain FedAvg, every local model restarts from the global model
            for i in range(len(w_locals)):
                w_locals[i] = copy.deepcopy(w_glob)

    save_result(acc, 'test_acc', args)
    save_result(sim_arr, 'sim', args)
    save_result(loss, 'test_loss', args)
    save_result(train_loss, 'test_train_loss', args)
    save_model(net_glob.state_dict(), 'test_model', args)


def FedMR_Partition(args, net_glob, dataset_train, dataset_test, dict_users, partition):
    net_glob.train()

    acc = []
    loss = []
    w_locals = []

    m = max(int(args.frac * args.num_users), 1)
    for i in range(m):
        w_locals.append(copy.deepcopy(net_glob.state_dict()))

    for iter in range(args.epochs):
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for i, idx in enumerate(idxs_users):
            net_glob.load_state_dict(w_locals[i])
            local = LocalUpdate_FedMR(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=net_glob)
            w_locals[i] = copy.deepcopy(w)

        # update global weights
        w_glob = Aggregation(w_locals, None)  # Global Model Generation

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)

        acc.append(item_acc)
        loss.append(item_loss)

        if iter >= args.first_stage_bound:
            w_locals = recombination_partition(w_locals, m, partition)  # Model Recombination
        else:
            for i in range(len(w_locals)):
                w_locals[i] = copy.deepcopy(w_glob)

    save_result(acc, 'test_acc', args)
    save_result(loss, 'test_loss', args)


def sim(args, net_glob_arr):
    model_num = int(args.num_users * args.frac)
    sim_tab = [[0 for _ in range(model_num)] for _ in range(model_num)]
    sum_sim = 0.0
    for k in range(model_num):
        sim_arr = []
        for j in range(k):
            s = 0.0
            dict_a = torch.Tensor(0)
            dict_b = torch.Tensor(0)
            cnt = 0
            for p in net_glob_arr[k].keys():
                a = net_glob_arr[k][p].view(-1)
                b = net_glob_arr[j][p].view(-1)

                if cnt == 0:
                    dict_a = a
                    dict_b = b
                else:
                    dict_a = torch.cat((dict_a, a), dim=0)
                    dict_b = torch.cat((dict_b, b), dim=0)

                # group every five consecutive layers into a sub-vector
                if cnt % 5 == 0:
                    sub_a = a
                    sub_b = b
                else:
                    sub_a = torch.cat((sub_a, a), dim=0)
                    sub_b = torch.cat((sub_b, b), dim=0)

                if cnt % 5 == 4:
                    s += F.cosine_similarity(sub_a, sub_b, dim=0)
                cnt += 1
            s += F.cosine_similarity(sub_a, sub_b, dim=0)
            sim = F.cosine_similarity(dict_a, dict_b, dim=0)
            sim_arr.append(sim)
            sim_tab[k][j] = sim
            sim_tab[j][k] = sim
            sum_sim += copy.deepcopy(s)
    # average over the 45 model pairs (C(10, 2) for 10 models) and the layer groups
    l = int(len(net_glob_arr[0].keys()) / 5) + 1.0
    sum_sim /= (45.0 * l)
    return sum_sim


def test(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item()


def test_with_loss(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item(), loss_test
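`recombination` is the core FedMR operator: for every layer key it draws an independent random permutation of the m local models and reassigns that layer across the models, so each recombined model becomes a layer-wise mixture of several clients. A toy run on hypothetical two-model, two-layer state_dicts (assumes the repository root is on PYTHONPATH):

# Toy illustration of FedMR's layer-wise model recombination (hypothetical tensors).
import torch
from Algorithm.Training_FedMR import recombination

w0 = {'conv.weight': torch.zeros(1), 'fc.weight': torch.zeros(1)}
w1 = {'conv.weight': torch.ones(1), 'fc.weight': torch.ones(1)}

new = recombination([w0, w1], m=2)
# Every key was shuffled independently, so new[0] may combine
# model 1's conv layer with model 0's fc layer.
for i, w in enumerate(new):
    print(i, {k: v.item() for k, v in w.items()})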
--------------------------------------------------------------------------------
/Algorithm/Training_FedMut.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import copy
import numpy as np
import random
from models.Fed import Aggregation
from utils.utils import save_result, save_fedmut_result, save_model
from models.test import test_img
from models.Update import DatasetSplit
from optimizer.Adabelief import AdaBelief


class LocalUpdate_FedMut(object):
    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        self.verbose = verbose

    def train(self, net):
        net.to(self.args.device)
        net.train()
        # train and update
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(net.parameters(), lr=self.args.lr)
        elif self.args.optimizer == 'adaBelief':
            optimizer = AdaBelief(net.parameters(), lr=self.args.lr)

        Predict_loss = 0

        for iter in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                model_output = net(images)
                predictive_loss = self.loss_func(model_output['output'], labels)

                loss = predictive_loss
                Predict_loss += predictive_loss.item()

                loss.backward()
                optimizer.step()

        if self.verbose:
            info = '\nUser predict Loss={:.4f}'.format(Predict_loss / (self.args.local_ep * len(self.ldr_train)))
            print(info)

        return net.state_dict()


def FedMut(args, net_glob, dataset_train, dataset_test, dict_users):
    net_glob.train()
    acc = []
    w_locals = []

    m = max(int(args.frac * args.num_users), 1)
    for i in range(m):
        w_locals.append(copy.deepcopy(net_glob.state_dict()))

    max_rank = 0

    for iter in range(args.epochs):
        w_old = copy.deepcopy(net_glob.state_dict())
        print('*' * 80)
        print('Round {:3d}'.format(iter))

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for i, idx in enumerate(idxs_users):
            net_glob.load_state_dict(w_locals[i])
            local = LocalUpdate_FedMut(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=net_glob)
            w_locals[i] = copy.deepcopy(w)

        # update global weights
        w_glob = Aggregation(w_locals, None)  # Global Model Generation

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        if iter % 10 == 9:
            acc.append(test(net_glob, dataset_test, args))

        # mutate the new global model along this round's update direction
        w_delta = FedSub(w_glob, w_old, 1.0)
        rank = delta_rank(args, w_delta)
        if rank > max_rank:
            max_rank = rank
        alpha = args.radius
        w_locals = mutation_spread(args, iter, w_glob, w_old, w_locals, m, w_delta, alpha)

    save_fedmut_result(acc, 'test_acc', args)


def mutation_spread(args, iter, w_glob, w_old, w_locals, m, w_delta, alpha):
    w_locals_new = []
    ctrl_cmd_list = []
    # mutation strength decays linearly until round args.mut_bound
    ctrl_rate = args.mut_acc_rate * (1.0 - min(iter * 1.0 / args.mut_bound, 1.0))
    print(ctrl_rate)

    # for every layer, build a balanced list of +/- mutation signs in pairs
    for k in w_glob.keys():
        ctrl_list = []
        for i in range(0, int(m / 2)):
            ctrl = random.random()
            if ctrl > 0.5:
                ctrl_list.append(1.0)
                ctrl_list.append(-1.0 + ctrl_rate)
            else:
                ctrl_list.append(-1.0 + ctrl_rate)
                ctrl_list.append(1.0)
        random.shuffle(ctrl_list)
        ctrl_cmd_list.append(ctrl_list)

    cnt = 0
    for j in range(m):
        w_sub = copy.deepcopy(w_glob)
        # when m is odd, the last model keeps the unmutated global weights
        if not (cnt == m - 1 and m % 2 == 1):
            ind = 0
            for k in w_sub.keys():
                w_sub[k] = w_sub[k] + w_delta[k] * ctrl_cmd_list[ind][j] * alpha
                ind += 1
        cnt += 1
        w_locals_new.append(w_sub)

    return w_locals_new


def test(net_glob, dataset_test, args):
    # testing
    acc_test, loss_test = test_img(net_glob, dataset_test, args)

    print("Testing accuracy: {:.2f}".format(acc_test))

    return acc_test.item()


def FedSub(w, w_old, weight):
    w_sub = copy.deepcopy(w)
    for k in w_sub.keys():
        w_sub[k] = (w[k] - w_old[k]) * weight

    return w_sub


def delta_rank(args, delta_dict):
    # L2 norm of the whole update, computed over all layers flattened into one vector
    cnt = 0
    dict_a = torch.Tensor(0)
    for p in delta_dict.keys():
        a = delta_dict[p].view(-1)
        if cnt == 0:
            dict_a = a
        else:
            dict_a = torch.cat((dict_a, a), dim=0)
        cnt += 1
    return torch.norm(dict_a, dim=0)
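For intuition, here is a condensed, single-layer view of what `mutation_spread` produces: mutants are generated in symmetric pairs around the global model, one partner moved by `+alpha*delta` and the other by `-(1 - ctrl_rate)*alpha*delta` along the round's update direction. All values below are hypothetical; the real function above draws and shuffles signs per layer and leaves one model unmutated when m is odd.

# Condensed single-layer sketch of FedMut's symmetric mutation (hypothetical values).
import random
import torch

w_glob = {'fc.weight': torch.tensor([1.0, 1.0])}
w_delta = {'fc.weight': torch.tensor([0.2, -0.1])}  # this round's update direction (from FedSub)
alpha, ctrl_rate, m = 1.0, 0.0, 4

signs = []
for _ in range(m // 2):
    pair = [1.0, -(1.0 - ctrl_rate)]  # one '+' and one '-' mutant per pair
    random.shuffle(pair)
    signs += pair

mutants = [{k: w_glob[k] + w_delta[k] * signs[j] * alpha for k in w_glob} for j in range(m)]
print([w['fc.weight'].tolist() for w in mutants])
# e.g. [[1.2, 0.9], [0.8, 1.1], [1.2, 0.9], [0.8, 1.1]] -- pairs mirror each other around w_glob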
--------------------------------------------------------------------------------
/Algorithm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMHelloWorld/FedMR/f9e05716c3a96de5278c05de23e91fd160b7edbe/Algorithm/__init__.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Ming Hu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedMR
2 | The source code for **Is Aggregation the Only Choice? Federated Learning via Layer-wise Model Recombination** (accepted by KDD 2024)
3 |
4 | https://dl.acm.org/doi/abs/10.1145/3637528.3671722
5 |
6 | --------------------------------------------------------------------------------
7 |
8 | ## 1. Environment requirements
9 | * Python 3.7
10 | * PyTorch
11 |
12 | ## 2. Instructions
13 | ### 2.1 Parameters
14 | #### 2.1.1 Dataset Settings
15 | `--dataset <dataset name>`
16 |
17 | We can set ‘cifar10’, ‘cifar100’, or ‘femnist’ for CIFAR-10, CIFAR-100, or FEMNIST, respectively.
18 |
19 | #### 2.1.2 Model Settings
20 | `--model <model name>`
21 |
22 | We can set ‘resnet20’, ‘vgg’, or ‘cnn’ for the ResNet-20, VGG-16, or CNN model, respectively.
23 |
24 | `--num_classes <number of classes>`
25 |
26 | Set the number of classes. Set 10 for CIFAR-10.
27 |
28 | Set 20 for CIFAR-100.
29 |
30 | Set 62 for FEMNIST.
31 |
32 | `--num_channels <number of channels>`
33 |
34 | Set the number of channels of the dataset.
35 | Set 3 for CIFAR-10 and CIFAR-100. Set 1 for FEMNIST.
36 |
37 | #### 2.1.3 Data heterogeneity
38 | `--iid <0 or 1>`
39 |
40 | 0 – non-IID; 1 – IID.
41 |
42 | `--data_beta <α>`
43 |
44 | Set the α for the Dirichlet distribution.
45 |
46 | `--generate_data <0 or 1>`
47 |
48 | 0 – use the existing configuration of Dir(α); 1 – generate a new configuration of Dir(α).
49 |
50 | #### 2.1.4 FL Settings
51 | `--epochs <number of rounds>`
52 |
53 | Set the number of training rounds.
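Putting the flags together, a typical invocation might look like the following. The values here are illustrative only; `--algorithm` and `--first_stage_bound` are described in the next subsection, and utils/options.py holds the full flag list and defaults:

```
python main_fed.py --dataset cifar10 --model resnet20 --num_classes 10 --num_channels 3 \
    --iid 0 --data_beta 0.5 --generate_data 1 --epochs 500 \
    --algorithm FedMR --first_stage_bound 100
```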
54 |
55 | #### 2.1.5 FedMR and Baseline Settings
56 | `--algorithm <algorithm name>`
57 |
58 | Set the algorithm name (FedMut, FedExP, and FedASAM are also dispatched in main_fed.py):
59 | * FedMR
60 | * FedAvg
61 | * FedProx
62 | * FedGen
63 | * ClusteredSampling
64 | * FedIndep
65 |
66 | `--first_stage_bound <round number>`
67 |
68 | Set the number of rounds in the first training stage of FedMR.
69 |
70 | Tip: set 50 or 100 for the VGG model.
71 |
72 | #### 2.1.6 Loss landscape
73 | Please use the following tool to generate loss-landscape figures:
74 |
75 | https://github.com/tomgoldstein/loss-landscape
76 |
77 | ## 3. Citation
78 | ```
79 | @inproceedings{hu2024aggregation,
80 | title={Is Aggregation the Only Choice? Federated Learning via Layer-wise Model Recombination},
81 | author={Hu, Ming and Yue, Zhihao and Xie, Xiaofei and Chen, Cheng and Huang, Yihao and Wei, Xian and Lian, Xiang and Liu, Yang and Chen, Mingsong},
82 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
83 | pages={1096--1107},
84 | year={2024}
85 | }
86 | ```
87 | --------------------------------------------------------------------------------
88 |
89 | If you have any questions, please contact me at hu.ming.work@gmail.com.
90 |
91 | :blush::blush::blush: ~~~ Have a nice day ~~~ :blush::blush::blush:
92 |
--------------------------------------------------------------------------------
/dis_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from utils.options import args_parser
5 | from utils.get_dataset import get_dataset
6 |
7 |
8 | num_users = 100
9 | print_num_users = 10
10 | num_classes = 10
11 |
12 |
13 | def get_distribution(dataset_train):
14 |     min_size = 0
15 |     min_require_size = 1
16 |     K = num_classes
17 |     y_train = np.array(dataset_train.targets)
18 |     N = len(dataset_train)
19 |     dict_users = {}
20 |
21 |     idx_batch = None
22 |     while min_size < min_require_size:
23 |         idx_batch = [[] for _ in range(num_users)]
24 |         for k in range(K):
25 |             idx_k = np.where(y_train == k)[0]
26 |             np.random.shuffle(idx_k)
27 |             proportions = np.random.dirichlet(np.repeat(1.0, num_users))  # Dirichlet proportions over clients
28 |             proportions = np.array(
29 |                 [p * (len(idx_j) < N / num_users) for p, idx_j in zip(proportions, idx_batch)])
30 |             proportions = proportions / proportions.sum()
31 |             proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
32 |             idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
33 |             min_size = min([len(idx_j) for idx_j in idx_batch])
34 |
35 |     for j in range(num_users):
36 |         # np.random.shuffle(idx_batch[j])
37 |         dict_users[j] = idx_batch[j]
38 |
39 |     return dict_users
40 |
41 | args = args_parser()
42 | dataset_train, _, _ = get_dataset(args)
43 | dict_users = get_distribution(dataset_train)
44 | array = np.zeros((print_num_users, num_classes))
45 | for i in range(print_num_users):
46 |     print(len(dict_users[i]))
47 |     for j in dict_users[i]:
48 |         array[i][dataset_train[j][1]] += 1
49 |
50 | print(array)
51 | array = array.reshape(print_num_users * num_classes)
52 | # print(array)
53 |
54 | # print(sum(array))
55 | # # example data
56 |
57 | # client IDs
58 | x = np.arange(0, print_num_users)
59 | x = np.tile(x, num_classes)
60 | x = np.sort(x)
61 | # # class IDs
62 | y = np.arange(0, num_classes)
63 | y = np.tile(y, print_num_users)
64 |
65 | # sizes = np.random.randint(1, 100, 100) # bubble sizes
66 | # print(len(sizes))
67 |
68 | # set the plotting style
69 | sns.set(style="whitegrid")
70 |
71 | # draw the bubble heatmap
72 | sns.scatterplot(x=x, y=y, size=array,
sizes=(1, 400)) 73 | plt.tick_params(labelsize=20) 74 | plt.subplots_adjust(left=0.12,right=0.95,top=0.95,bottom=0.17) 75 | plt.xticks(x) 76 | # plt.yticks(y) 77 | 78 | # plt.title('Data Distribution') 79 | plt.xlabel('Client ID',fontdict={'fontsize':24}) 80 | plt.ylabel('Class ID',fontdict={'fontsize':24}) 81 | plt.legend([], [], frameon=False) 82 | plt.savefig('./output/111/cifar10.pdf') 83 | # plt.show() 84 | -------------------------------------------------------------------------------- /main_fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import copy 8 | 9 | from utils.options import args_parser 10 | from utils.set_seed import set_random_seed 11 | from models.Update import * 12 | from models.Nets import * 13 | from models.Fed import Aggregation 14 | from models.test import test_img 15 | from models.resnetcifar import * 16 | from models import * 17 | from utils.get_dataset import get_dataset 18 | from utils.utils import save_result,save_model 19 | from Algorithm.Training_FedGen import FedGen 20 | from Algorithm.Training_FedMR import FedMR 21 | from Algorithm.Training_FedMR import FedMR_Partition 22 | from Algorithm.Training_FedIndenp import FedIndep 23 | from Algorithm.Training_FedMut import FedMut 24 | from Algorithm.Training_FedExP import FedExP 25 | from Algorithm.Training_FedASAM import FedASAM 26 | 27 | def FedAvg(net_glob, dataset_train, dataset_test, dict_users): 28 | 29 | net_glob.train() 30 | 31 | times = [] 32 | total_time = 0 33 | 34 | # training 35 | acc = [] 36 | loss = [] 37 | train_loss=[] 38 | 39 | for iter in range(args.epochs): 40 | 41 | print('*'*80) 42 | print('Round {:3d}'.format(iter)) 43 | 44 | 45 | w_locals = [] 46 | lens = [] 47 | m = max(int(args.frac * args.num_users), 1) 48 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 49 | for idx in idxs_users: 50 | local = LocalUpdate_FedAvg(args=args, dataset=dataset_train, idxs=dict_users[idx]) 51 | w = local.train(net=copy.deepcopy(net_glob).to(args.device)) 52 | 53 | w_locals.append(copy.deepcopy(w)) 54 | lens.append(len(dict_users[idx])) 55 | # update global weights 56 | w_glob = Aggregation(w_locals, lens) 57 | 58 | # copy weight to net_glob 59 | net_glob.load_state_dict(w_glob) 60 | 61 | if iter % 10 == 9: 62 | item_acc,item_loss = test_with_loss(net_glob, dataset_test, args) 63 | ta,tl = test_with_loss(net_glob, dataset_train, args) 64 | acc.append(item_acc) 65 | loss.append(item_loss) 66 | train_loss.append(tl) 67 | 68 | save_result(acc, 'test_acc', args) 69 | save_result(loss, 'test_loss', args) 70 | save_result(train_loss, 'test_train_loss', args) 71 | save_model(net_glob.state_dict(), 'test_model', args) 72 | 73 | 74 | def FedProx(net_glob, dataset_train, dataset_test, dict_users): 75 | net_glob.train() 76 | 77 | acc = [] 78 | 79 | for iter in range(args.epochs): 80 | 81 | print('*' * 80) 82 | print('Round {:3d}'.format(iter)) 83 | 84 | w_locals = [] 85 | lens = [] 86 | m = max(int(args.frac * args.num_users), 1) 87 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 88 | for idx in idxs_users: 89 | local = LocalUpdate_FedProx(args=args, glob_model=net_glob, dataset=dataset_train, idxs=dict_users[idx]) 90 | w = local.train(net=copy.deepcopy(net_glob).to(args.device)) 91 | 92 | w_locals.append(copy.deepcopy(w)) 93 | lens.append(len(dict_users[idx])) 94 | # update global weights 95 | w_glob 
= Aggregation(w_locals, lens) 96 | 97 | # copy weight to net_glob 98 | net_glob.load_state_dict(w_glob) 99 | 100 | if iter % 10 == 9: 101 | acc.append(test(net_glob, dataset_test, args)) 102 | 103 | save_result(acc, 'test_acc', args) 104 | 105 | from utils.clustering import * 106 | from scipy.cluster.hierarchy import linkage 107 | def ClusteredSampling(net_glob, dataset_train, dataset_test, dict_users): 108 | 109 | net_glob.to('cpu') 110 | 111 | n_samples = np.array([len(dict_users[idx]) for idx in dict_users.keys()]) 112 | weights = n_samples / np.sum(n_samples) 113 | n_sampled = max(int(args.frac * args.num_users), 1) 114 | 115 | gradients = get_gradients('', net_glob, [net_glob] * len(dict_users)) 116 | 117 | net_glob.train() 118 | 119 | # training 120 | acc = [] 121 | 122 | for iter in range(args.epochs): 123 | 124 | print('*' * 80) 125 | print('Round {:3d}'.format(iter)) 126 | 127 | previous_global_model = copy.deepcopy(net_glob) 128 | clients_models = [] 129 | sampled_clients_for_grad = [] 130 | 131 | # GET THE CLIENTS' SIMILARITY MATRIX 132 | if iter == 0: 133 | sim_matrix = get_matrix_similarity_from_grads( 134 | gradients, distance_type=args.sim_type 135 | ) 136 | 137 | # GET THE DENDROGRAM TREE ASSOCIATED 138 | linkage_matrix = linkage(sim_matrix, "ward") 139 | 140 | distri_clusters = get_clusters_with_alg2( 141 | linkage_matrix, n_sampled, weights 142 | ) 143 | 144 | w_locals = [] 145 | lens = [] 146 | idxs_users = sample_clients(distri_clusters) 147 | for idx in idxs_users: 148 | local = LocalUpdate_ClientSampling(args=args, dataset=dataset_train, idxs=dict_users[idx]) 149 | local_model = local.train(net=copy.deepcopy(net_glob).to(args.device)) 150 | local_model.to('cpu') 151 | 152 | w_locals.append(copy.deepcopy(local_model.state_dict())) 153 | lens.append(len(dict_users[idx])) 154 | 155 | clients_models.append(copy.deepcopy(local_model)) 156 | sampled_clients_for_grad.append(idx) 157 | 158 | del local_model 159 | # update global weights 160 | w_glob = Aggregation(w_locals, lens) 161 | 162 | # copy weight to net_glob 163 | net_glob.load_state_dict(w_glob) 164 | 165 | gradients_i = get_gradients( 166 | '', previous_global_model, clients_models 167 | ) 168 | for idx, gradient in zip(sampled_clients_for_grad, gradients_i): 169 | gradients[idx] = gradient 170 | 171 | sim_matrix = get_matrix_similarity_from_grads_new( 172 | gradients, distance_type=args.sim_type, idx=idxs_users, metric_matrix=sim_matrix 173 | ) 174 | 175 | net_glob.to(args.device) 176 | if iter % 10 == 9: 177 | acc.append(test(net_glob, dataset_test, args)) 178 | net_glob.to('cpu') 179 | 180 | del clients_models 181 | 182 | save_result(acc, 'test_acc', args) 183 | 184 | 185 | 186 | 187 | def test(net_glob, dataset_test, args): 188 | 189 | # testing 190 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 191 | 192 | print("Testing accuracy: {:.2f}".format(acc_test)) 193 | 194 | return acc_test.item() 195 | 196 | def test_with_loss(net_glob, dataset_test, args): 197 | 198 | # testing 199 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 200 | 201 | print("Testing accuracy: {:.2f}".format(acc_test)) 202 | 203 | return acc_test.item(), loss_test 204 | 205 | if __name__ == '__main__': 206 | # parse args 207 | args = args_parser() 208 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 209 | 210 | set_random_seed(args.seed) 211 | 212 | dataset_train, dataset_test, dict_users = get_dataset(args) 213 | 214 | if args.model == 'cnn' and 
args.dataset == 'femnist': 215 | net_glob = CNNFashionMnist(args) 216 | elif args.model == 'cnn' and args.dataset == 'mnist': 217 | net_glob = CNNMnist(args) 218 | elif args.use_project_head: 219 | net_glob = ModelFedCon(args.model, args.out_dim, args.num_classes) 220 | elif 'cifar' in args.dataset and 'cnn' in args.model: 221 | net_glob = CNNCifar(args) 222 | elif args.model == 'resnet20' and args.dataset == 'mnist': 223 | net_glob = ResNet20_mnist(args=args).to(args.device) 224 | elif args.model == 'resnet20' and (args.dataset == 'fashion-mnist' or args.dataset == 'femnist'): 225 | net_glob = ResNet20_mnist(args=args).to(args.device) 226 | elif args.model == 'resnet20' and args.dataset == 'cifar10': 227 | net_glob = ResNet20_cifar(args=args).to(args.device) 228 | elif args.model == 'resnet20' and args.dataset == 'cifar100': 229 | net_glob = ResNet20_cifar(args=args).to(args.device) 230 | elif 'resnet' in args.model: 231 | if args.dataset == 'mnist' or args.dataset == 'fashion-mnist' or args.dataset == 'femnist': 232 | net_glob = ResNet18_MNIST(num_classes = args.num_classes) 233 | else: 234 | net_glob = ResNet18_cifar10(num_classes = args.num_classes) 235 | elif 'cifar' in args.dataset and args.model == 'vgg': 236 | net_glob = VGG16(args) 237 | elif 'mnist' in args.dataset and args.model == 'vgg': 238 | net_glob = VGG16_mnist(args) 239 | 240 | 241 | net_glob.to(args.device) 242 | print(net_glob) 243 | 244 | if args.algorithm == 'FedAvg': 245 | FedAvg(net_glob, dataset_train, dataset_test, dict_users) 246 | elif args.algorithm == 'FedProx': 247 | FedProx(net_glob, dataset_train, dataset_test, dict_users) 248 | elif args.algorithm == 'ClusteredSampling': 249 | ClusteredSampling(net_glob, dataset_train, dataset_test, dict_users) 250 | elif args.algorithm == 'FedGen': 251 | FedGen(args, net_glob, dataset_train, dataset_test, dict_users) 252 | elif args.algorithm == 'FedMR': 253 | partition = args.fedmr_partition 254 | if partition == 0: 255 | FedMR(args, net_glob, dataset_train, dataset_test, dict_users) 256 | else: 257 | FedMR_Partition(args, net_glob, dataset_train, dataset_test, dict_users,partition) 258 | elif args.algorithm == 'FedIndep': 259 | FedIndep(args, net_glob, dataset_train, dataset_test, dict_users) 260 | elif args.algorithm == 'FedMut': 261 | FedMut(args, net_glob, dataset_train, dataset_test, dict_users) 262 | elif args.algorithm == 'FedExP': 263 | FedExP(args, net_glob, dataset_train, dataset_test, dict_users) 264 | elif args.algorithm == 'FedASAM': 265 | FedASAM(args, net_glob, dataset_train, dataset_test, dict_users) 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /main_nn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torch.optim as optim 13 | from torchvision import datasets, transforms 14 | 15 | from utils.options import args_parser 16 | from models.Nets import MLP, CNNMnist, CNNCifar 17 | 18 | 19 | def test(net_g, data_loader): 20 | # testing 21 | net_g.eval() 22 | test_loss = 0 23 | correct = 0 24 | l = len(data_loader) 25 | for idx, (data, target) in enumerate(data_loader): 26 | data, target = data.to(args.device), target.to(args.device) 27 | log_probs = net_g(data) 28 | test_loss += 
F.cross_entropy(log_probs, target).item() 29 | y_pred = log_probs.data.max(1, keepdim=True)[1] 30 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 31 | 32 | test_loss /= len(data_loader.dataset) 33 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 34 | test_loss, correct, len(data_loader.dataset), 35 | 100. * correct / len(data_loader.dataset))) 36 | 37 | return correct, test_loss 38 | 39 | 40 | if __name__ == '__main__': 41 | # parse args 42 | args = args_parser() 43 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 44 | 45 | torch.manual_seed(args.seed) 46 | 47 | # load dataset and split users 48 | if args.dataset == 'mnist': 49 | dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, 50 | transform=transforms.Compose([ 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.1307,), (0.3081,)) 53 | ])) 54 | img_size = dataset_train[0][0].shape 55 | elif args.dataset == 'cifar': 56 | transform = transforms.Compose( 57 | [transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 59 | dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True) 60 | img_size = dataset_train[0][0].shape 61 | else: 62 | exit('Error: unrecognized dataset') 63 | 64 | # build model 65 | if args.model == 'cnn' and args.dataset == 'cifar': 66 | net_glob = CNNCifar(args=args).to(args.device) 67 | elif args.model == 'cnn' and args.dataset == 'mnist': 68 | net_glob = CNNMnist(args=args).to(args.device) 69 | elif args.model == 'mlp': 70 | len_in = 1 71 | for x in img_size: 72 | len_in *= x 73 | net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device) 74 | else: 75 | exit('Error: unrecognized model') 76 | print(net_glob) 77 | 78 | # training 79 | optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum) 80 | train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True) 81 | 82 | list_loss = [] 83 | net_glob.train() 84 | for epoch in range(args.epochs): 85 | batch_loss = [] 86 | for batch_idx, (data, target) in enumerate(train_loader): 87 | data, target = data.to(args.device), target.to(args.device) 88 | optimizer.zero_grad() 89 | output = net_glob(data) 90 | loss = F.cross_entropy(output, target) 91 | loss.backward() 92 | optimizer.step() 93 | if batch_idx % 50 == 0: 94 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 95 | epoch, batch_idx * len(data), len(train_loader.dataset), 96 | 100. 
* batch_idx / len(train_loader), loss.item())) 97 | batch_loss.append(loss.item()) 98 | loss_avg = sum(batch_loss)/len(batch_loss) 99 | print('\nTrain loss:', loss_avg) 100 | list_loss.append(loss_avg) 101 | 102 | # plot loss 103 | plt.figure() 104 | plt.plot(range(len(list_loss)), list_loss) 105 | plt.xlabel('epochs') 106 | plt.ylabel('train loss') 107 | plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs)) 108 | 109 | # testing 110 | if args.dataset == 'mnist': 111 | dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, 112 | transform=transforms.Compose([ 113 | transforms.ToTensor(), 114 | transforms.Normalize((0.1307,), (0.3081,)) 115 | ])) 116 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 117 | elif args.dataset == 'cifar': 118 | transform = transforms.Compose( 119 | [transforms.ToTensor(), 120 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 121 | dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True) 122 | test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False) 123 | else: 124 | exit('Error: unrecognized dataset') 125 | 126 | print('test on', len(dataset_test), 'samples') 127 | test_acc, test_loss = test(net_glob, test_loader) 128 | -------------------------------------------------------------------------------- /models/Fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import copy 6 | import torch 7 | 8 | 9 | def Aggregation(w, lens): 10 | w_avg = None 11 | if lens == None: 12 | total_count = len(w) 13 | lens = [] 14 | for i in range(len(w)): 15 | lens.append(1.0) 16 | else: 17 | total_count = sum(lens) 18 | 19 | for i in range(0, len(w)): 20 | if i == 0: 21 | w_avg = copy.deepcopy(w[0]) 22 | for k in w_avg.keys(): 23 | w_avg[k] = w[i][k] * lens[i] 24 | else: 25 | for k in w_avg.keys(): 26 | w_avg[k] += w[i][k] * lens[i] 27 | 28 | for k in w_avg.keys(): 29 | w_avg[k] = torch.div(w_avg[k], total_count) 30 | 31 | return w_avg 32 | 33 | 34 | def Sub(w, w_sub): 35 | w_result = copy.deepcopy(w) 36 | for k in w_result.keys(): 37 | w_result[k] = w[k] - w_sub[k] 38 | return w_result 39 | 40 | def Add(w, w_add): 41 | w_result = None 42 | w_result = copy.deepcopy(w) 43 | for k in w_result.keys(): 44 | w_result[k] = w[k] + w_add[k] 45 | return w_result 46 | 47 | def Div(w, v): 48 | w_result = None 49 | w_result = copy.deepcopy(w) 50 | for k in w_result.keys(): 51 | w_result[k] = w[k]/v 52 | return w_result 53 | 54 | def Mul(w, v): 55 | w_result = None 56 | w_result = copy.deepcopy(w) 57 | for k in w_result.keys(): 58 | w_result[k] = w[k]*v 59 | return w_result 60 | 61 | 62 | def Weighted_Aggregation_FedASync(w_local, w_global, alpha): 63 | for i in w_local.keys(): 64 | w_global[i] = alpha * w_local[i] + (1 - alpha) * w_global[i] 65 | return w_global 66 | 67 | 68 | def Weighted_Aggregation_FedSA(update_w, lens, w_global): 69 | w_avg = None 70 | total_count = sum(lens.values()) 71 | alpha = sum([lens[idx] / total_count for idx in update_w.keys()]) 72 | 73 | for i, idx in enumerate(update_w.keys()): 74 | if i == 0: 75 | w_avg = copy.deepcopy(update_w[idx]) 76 | for k in w_avg.keys(): 77 | w_avg[k] = update_w[idx][k] * lens[idx] 78 | else: 79 | for k in w_avg.keys(): 80 | w_avg[k] += update_w[idx][k] * lens[idx] 81 | 82 | for k in w_avg.keys(): 83 | w_avg[k] = torch.div(w_avg[k], total_count) 84 | # return 
w_avg
85 |
86 |     for i in w_avg.keys():
87 |         w_global[i] = w_avg[i] + (1 - alpha) * w_global[i]
88 |     return w_global
89 |
--------------------------------------------------------------------------------
/models/MobileNetV2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class LinearBottleNeck(nn.Module):
6 |     def __init__(self, in_channels, out_channels, stride, t):
7 |         super(LinearBottleNeck, self).__init__()
8 |
9 |         self.residual = nn.Sequential(
10 |             nn.Conv2d(in_channels, in_channels * t, 1),
11 |             nn.BatchNorm2d(in_channels * t),
12 |             nn.ReLU6(inplace=True),
13 |
14 |             nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t),
15 |             nn.BatchNorm2d(in_channels * t),
16 |             nn.ReLU6(inplace=True),
17 |
18 |             nn.Conv2d(in_channels * t, out_channels, 1),
19 |             nn.BatchNorm2d(out_channels)
20 |         )
21 |
22 |         self.stride = stride
23 |         self.in_channels = in_channels
24 |         self.out_channels = out_channels
25 |
26 |     def forward(self, x):
27 |         residual = self.residual(x)
28 |
29 |         if self.stride == 1 and self.in_channels == self.out_channels:  # identity shortcut when shapes match
30 |             residual += x
31 |
32 |         return residual
33 |
34 |
35 | class MobileNetV2(nn.Module):
36 |     """
37 |     MobileNetV2 implementation
38 |     """
39 |     def __init__(self, args):
40 |         super(MobileNetV2, self).__init__()
41 |         self.pre = nn.Sequential(
42 |             nn.Conv2d(3, 32, 3, padding=1),
43 |             nn.BatchNorm2d(32),
44 |             nn.ReLU6(inplace=True)
45 |         )
46 |
47 |         self.stage1 = LinearBottleNeck(32, 16, 1, 1)
48 |         self.stage2 = self._make_stage(2, 16, 24, 2, 6)
49 |         self.stage3 = self._make_stage(3, 24, 32, 2, 6)
50 |         self.stage4 = self._make_stage(4, 32, 64, 2, 6)
51 |         self.stage5 = self._make_stage(3, 64, 96, 1, 6)
52 |         self.stage6 = self._make_stage(3, 96, 160, 2, 6)
53 |         self.stage7 = LinearBottleNeck(160, 320, 1, 6)
54 |
55 |         self.conv1 = nn.Sequential(
56 |             nn.Conv2d(320, 1280, 1),
57 |             nn.BatchNorm2d(1280),
58 |             nn.ReLU6(inplace=True)
59 |         )
60 |
61 |         self.conv2 = nn.Conv2d(1280, args.num_classes, 1)
62 |
63 |     def _make_stage(self, n, in_channels, out_channels, stride, t):
64 |         layers = [LinearBottleNeck(in_channels, out_channels, stride, t)]
65 |
66 |         while n - 1:
67 |             layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))
68 |             n -= 1
69 |
70 |         return nn.Sequential(*layers)
71 |
72 |     def forward(self, x, start_layer_idx=0, logit=False):
73 |         if start_layer_idx < 0:  # negative index: run only the classifier head on a latent input
74 |             return self.mapping(x, start_layer_idx=start_layer_idx, logit=logit)
75 |         x = self.pre(x)
76 |         x = self.stage1(x)
77 |         x = self.stage2(x)
78 |         x = self.stage3(x)
79 |         x = self.stage4(x)
80 |         x = self.stage5(x)
81 |         x = self.stage6(x)
82 |         x = self.stage7(x)
83 |         x = self.conv1(x)
84 |         # global pooling, then the 1x1-conv classifier head
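# The remainder of forward() below returns a dict carrying both the pooled
# penultimate features ('representation') and the logits ('output'). This dual
# output is what lets latent-space methods such as the FedGen generator
# (models/generator.py) synthesize feature-level samples and push them through
# just the classifier head via self.mapping() further down.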
85 |         x = F.adaptive_max_pool2d(x, 1)
86 |         result = {'representation': x.view(x.size(0), -1)}
87 |         x = self.conv2(x)
88 |         x = x.view(x.size(0), -1)
89 |         result['output'] = x
90 |         return result
91 |
92 |     def mapping(self, z_input, start_layer_idx=-1, logit=True):
93 |         z = z_input.unsqueeze(2).unsqueeze(2)
94 |         z = self.conv2(z)
95 |         z = z.view(z.size(0), -1)
96 |         result = {'output': z}
97 |         if logit:
98 |             result['logit'] = z
99 |         return result
100 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 | from .lstm import CharLSTM
--------------------------------------------------------------------------------
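The `Aggregation` helper in models/Fed.py above is the FedAvg-style weighted mean that most algorithms in this repo build on. As a quick illustration of its semantics, here is a minimal, self-contained sketch (a re-implementation for illustration only, not the repo's API; the tensor values and sample counts are made up):

```python
import copy
import torch

def weighted_average(w, lens):
    # Element-wise weighted mean of client state dicts, mirroring the math
    # in models/Fed.py's Aggregation(): sum_i(w_i * n_i) / sum_i(n_i).
    total = sum(lens)
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        w_avg[k] = sum(w[i][k] * lens[i] for i in range(len(w))) / total
    return w_avg

# Two "clients" holding 10 and 30 samples, so their models get
# weights 0.25 and 0.75 in the average.
w1 = {'fc.weight': torch.tensor([1.0, 3.0])}
w2 = {'fc.weight': torch.tensor([3.0, 5.0])}
print(weighted_average([w1, w2], [10, 30]))
# {'fc.weight': tensor([2.5000, 4.5000])}
```

Passing `lens=None` in the repo's version falls back to a uniform average; FedAvg in main_fed.py calls it with per-client dataset sizes, which reproduces the sample-weighted FedAvg objective.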
/models/at.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | '''
9 | AT with sum of absolute values with power p
10 | code from: https://github.com/AberHu/Knowledge-Distillation-Zoo
11 | '''
12 | class AT(nn.Module):
13 |     '''
14 |     Paying More Attention to Attention: Improving the Performance of Convolutional
15 |     Neural Networks via Attention Transfer
16 |     https://arxiv.org/pdf/1612.03928.pdf
17 |     '''
18 |     def __init__(self, p):
19 |         super(AT, self).__init__()
20 |         self.p = p
21 |
22 |     def forward(self, fm_s, fm_t):
23 |         loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))
24 |
25 |         return loss
26 |
27 |     def attention_map(self, fm, eps=1e-6):
28 |         am =
torch.pow(torch.abs(fm), self.p)
29 |         am = torch.sum(am, dim=1, keepdim=True)
30 |         norm = torch.norm(am, dim=(2,3), keepdim=True)
31 |         am = torch.div(am, norm + eps)
32 |
33 |         return am
--------------------------------------------------------------------------------
/models/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | MAXLOG = 0.1
5 | from torch.autograd import Variable
6 | import collections
7 | import numpy as np
8 | from utils.model_config import GENERATORCONFIGS, CNN_GENERATORCONFIGS, RESNET_GENERATORCONFIGS, RESNET20_GENERATORCONFIGS, VGG_GENERATORCONFIGS
9 |
10 |
11 | class Generator(nn.Module):
12 |     def __init__(self, dataset, model, embedding=False, latent_layer_idx=-1):
13 |         super(Generator, self).__init__()
14 |         self.embedding = embedding
15 |         self.dataset = dataset
16 |         #self.model=model
17 |         self.latent_layer_idx = latent_layer_idx
18 |         # noise_dim = GENERATORCONFIGS[dataset]
19 |         print(model)
20 |         if model == 'cnn':
21 |             noise_dim = CNN_GENERATORCONFIGS[dataset]
22 |         elif model == 'resnet18':
23 |             noise_dim = RESNET_GENERATORCONFIGS[dataset]
24 |         elif model == 'resnet20':
25 |             noise_dim = RESNET20_GENERATORCONFIGS[dataset]
26 |         elif model == 'vgg':
27 |             noise_dim = VGG_GENERATORCONFIGS[dataset]
28 |         else:
29 |             noise_dim = GENERATORCONFIGS[dataset]
30 |         self.hidden_dim, self.latent_dim, self.input_channel, self.n_class, self.noise_dim = noise_dim
31 |         input_dim = self.noise_dim * 2 if self.embedding else self.noise_dim + self.n_class
32 |         self.fc_configs = [input_dim, self.hidden_dim]
33 |         self.init_loss_fn()
34 |         self.build_network()
35 |
36 |     def get_number_of_parameters(self):
37 |         pytorch_total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
38 |         return pytorch_total_params
39 |
40 |     def init_loss_fn(self):
41 |         self.crossentropy_loss = nn.CrossEntropyLoss(reduction='none')  # per-sample losses ('none' replaces the deprecated reduce=False)
42 |         self.diversity_loss = DiversityLoss(metric='l1')
43 |         self.dist_loss = nn.MSELoss()
44 |
45 |     def build_network(self):
46 |         if self.embedding:
47 |             self.embedding_layer = nn.Embedding(self.n_class, self.noise_dim)
48 |         ### FC modules ####
49 |         self.fc_layers = nn.ModuleList()
50 |         for i in range(len(self.fc_configs) - 1):
51 |             input_dim, out_dim = self.fc_configs[i], self.fc_configs[i + 1]
52 |             fc = nn.Linear(input_dim, out_dim)
53 |             bn = nn.BatchNorm1d(out_dim)
54 |             act = nn.ReLU()
55 |             self.fc_layers += [fc, bn, act]
56 |         ### Representation layer
57 |         self.representation_layer = nn.Linear(self.fc_configs[-1], self.latent_dim)
58 |
59 |     def forward(self, labels, latent_layer_idx=-1, verbose=True):
60 |         """
61 |         G(Z|y) or G(X|y):
62 |         Generate either a latent representation (latent_layer_idx < 0) or a raw image (latent_layer_idx = 0), conditional on labels.
63 |         :param labels:
64 |         :param latent_layer_idx:
65 |             if -1, generate latent representation of the last layer,
66 |             -2 for the 2nd to last layer, 0 for raw images.
67 |         :param verbose: also return the sampled noise if verbose = True
68 |         :return: a dictionary of output information.
69 | """ 70 | result = {} 71 | batch_size = labels.shape[0] 72 | eps = torch.rand((batch_size, self.noise_dim)) # sampling from Gaussian 73 | if verbose: 74 | result['eps'] = eps 75 | if self.embedding: # embedded dense vector 76 | y_input = self.embedding_layer(labels) 77 | else: # one-hot (sparse) vector 78 | y_input = torch.FloatTensor(batch_size, self.n_class) 79 | y_input.zero_() 80 | #labels = labels.view 81 | labels_int64 = labels.type(torch.LongTensor) 82 | y_input.scatter_(1, labels_int64.view(-1,1), 1) 83 | z = torch.cat((eps, y_input), dim=1) 84 | ### FC layers 85 | for layer in self.fc_layers: 86 | z = layer(z) 87 | z = self.representation_layer(z) 88 | result['output'] = z 89 | return result 90 | 91 | @staticmethod 92 | def normalize_images(layer): 93 | """ 94 | Normalize images into zero-mean and unit-variance. 95 | """ 96 | mean = layer.mean(dim=(2, 3), keepdim=True) 97 | std = layer.view((layer.size(0), layer.size(1), -1)) \ 98 | .std(dim=2, keepdim=True).unsqueeze(3) 99 | return (layer - mean) / std 100 | # 101 | # class Decoder(nn.Module): 102 | # """ 103 | # Decoder for both unstructured and image datasets. 104 | # """ 105 | # def __init__(self, dataset='mnist', latent_layer_idx=-1, n_layers=2, units=32): 106 | # """ 107 | # Class initializer. 108 | # """ 109 | # #in_features, out_targets, n_layers=2, units=32): 110 | # super(Decoder, self).__init__() 111 | # self.cv_configs, self.input_channel, self.n_class, self.scale, self.noise_dim = GENERATORCONFIGS[dataset] 112 | # self.hidden_dim = self.scale * self.scale * self.cv_configs[0] 113 | # self.latent_dim = self.cv_configs[0] * 2 114 | # self.represent_dims = [self.hidden_dim, self.latent_dim] 115 | # in_features = self.represent_dims[latent_layer_idx] 116 | # out_targets = self.noise_dim 117 | # 118 | # # build layer structure 119 | # layers = [nn.Linear(in_features, units), 120 | # nn.ELU(), 121 | # nn.BatchNorm1d(units)] 122 | # 123 | # for _ in range(n_layers): 124 | # layers.extend([ 125 | # nn.Linear(units, units), 126 | # nn.ELU(), 127 | # nn.BatchNorm1d(units)]) 128 | # 129 | # layers.append(nn.Linear(units, out_targets)) 130 | # self.layers = nn.Sequential(*layers) 131 | # 132 | # def forward(self, x): 133 | # """ 134 | # Forward propagation. 135 | # """ 136 | # out = x.view((x.size(0), -1)) 137 | # out = self.layers(out) 138 | # return out 139 | 140 | class DivLoss(nn.Module): 141 | """ 142 | Diversity loss for improving the performance. 143 | """ 144 | 145 | def __init__(self): 146 | """ 147 | Class initializer. 148 | """ 149 | super().__init__() 150 | 151 | def forward2(self, noises, layer): 152 | """ 153 | Forward propagation. 154 | """ 155 | if len(layer.shape) > 2: 156 | layer = layer.view((layer.size(0), -1)) 157 | chunk_size = layer.size(0) // 2 158 | 159 | ####### diversity loss ######## 160 | eps1, eps2=torch.split(noises, chunk_size, dim=0) 161 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0) 162 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean( 163 | torch.abs(eps1 - eps2)) 164 | eps=1 * 1e-5 165 | diversity_loss=1 / (lz + eps) 166 | return diversity_loss 167 | 168 | def forward(self, noises, layer): 169 | """ 170 | Forward propagation. 
171 | """ 172 | if len(layer.shape) > 2: 173 | layer=layer.view((layer.size(0), -1)) 174 | chunk_size=layer.size(0) // 2 175 | 176 | ####### diversity loss ######## 177 | eps1, eps2=torch.split(noises, chunk_size, dim=0) 178 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0) 179 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean( 180 | torch.abs(eps1 - eps2)) 181 | eps=1 * 1e-5 182 | diversity_loss=1 / (lz + eps) 183 | return diversity_loss 184 | 185 | class DiversityLoss(nn.Module): 186 | """ 187 | Diversity loss for improving the performance. 188 | """ 189 | 190 | def __init__(self, metric): 191 | """ 192 | Class initializer. 193 | """ 194 | super().__init__() 195 | self.metric = metric 196 | self.cosine = nn.CosineSimilarity(dim=2) 197 | 198 | def compute_distance(self, tensor1, tensor2, metric): 199 | """ 200 | Compute the distance between two tensors. 201 | """ 202 | if metric == 'l1': 203 | return torch.abs(tensor1 - tensor2).mean(dim=(2,)) 204 | elif metric == 'l2': 205 | return torch.pow(tensor1 - tensor2, 2).mean(dim=(2,)) 206 | elif metric == 'cosine': 207 | return 1 - self.cosine(tensor1, tensor2) 208 | else: 209 | raise ValueError(metric) 210 | 211 | def pairwise_distance(self, tensor, how): 212 | """ 213 | Compute the pairwise distances between a Tensor's rows. 214 | """ 215 | n_data = tensor.size(0) 216 | tensor1 = tensor.expand((n_data, n_data, tensor.size(1))) 217 | tensor2 = tensor.unsqueeze(dim=1) 218 | return self.compute_distance(tensor1, tensor2, how) 219 | 220 | def forward(self, noises, layer): 221 | """ 222 | Forward propagation. 223 | """ 224 | if len(layer.shape) > 2: 225 | layer = layer.view((layer.size(0), -1)) 226 | layer_dist = self.pairwise_distance(layer, how=self.metric) 227 | noise_dist = self.pairwise_distance(noises, how='l2') 228 | return torch.exp(torch.mean(-noise_dist * layer_dist)) 229 | -------------------------------------------------------------------------------- /models/lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class CharLSTM(nn.Module): 4 | def __init__(self): 5 | super(CharLSTM, self).__init__() 6 | self.embed = nn.Embedding(80, 8) 7 | self.lstm = nn.LSTM(8, 256, 2, batch_first=True) 8 | # self.h0 = torch.zeros(2, batch_size, 256).requires_grad_() 9 | # self.drop = nn.Dropout() 10 | self.out = nn.Linear(256, 80) 11 | 12 | def forward(self, x): 13 | x = self.embed(x) 14 | # if self.h0.size(1) == x.size(0): 15 | # self.h0.data.zero_() 16 | # # self.c0.data.zero_() 17 | # else: 18 | # # resize hidden vars 19 | # device = next(self.parameters()).device 20 | # self.h0 = torch.zeros(2, x.size(0), 256).to(device).requires_grad_() 21 | x, hidden = self.lstm(x) 22 | # x = self.drop(x) 23 | # x = x.contiguous().view(-1, 256) 24 | # x = x.contiguous().view(-1, 256) 25 | x = self.out(x[:, -1, :]) 26 | return {'output' : x} 27 | 28 | 29 | -------------------------------------------------------------------------------- /models/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | def test_img(net_g, datatest, args): 12 | net_g.eval() 13 | # testing 14 | test_loss = 0 15 | correct = 0 16 | data_loader = DataLoader(datatest, batch_size=args.bs) 17 | l = len(data_loader) 18 | with torch.no_grad(): 19 | for idx, (data, 
--------------------------------------------------------------------------------
/models/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 | 
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader
9 | 
10 | 
11 | def test_img(net_g, datatest, args):
12 |     net_g.eval()
13 |     # testing
14 |     test_loss = 0
15 |     correct = 0
16 |     data_loader = DataLoader(datatest, batch_size=args.bs)
17 |     l = len(data_loader)
18 |     with torch.no_grad():
19 |         for idx, (data, target) in enumerate(data_loader):
20 |             if args.gpu != -1:
21 |                 data, target = data.cuda(), target.cuda()
22 |             log_probs = net_g(data)['output']
23 |             # sum up batch loss
24 |             test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
25 |             # get the index of the max log-probability
26 |             y_pred = log_probs.data.max(1, keepdim=True)[1]
27 |             correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
28 | 
29 |     test_loss /= len(data_loader.dataset)
30 |     accuracy = 100.00 * correct / len(data_loader.dataset)
31 |     if args.verbose:
32 |         print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
33 |             test_loss, correct, len(data_loader.dataset), accuracy))
34 |     return accuracy, test_loss
35 | 
36 | def branchy_test_img(net_g, classifier, tag, datatest, args):
37 |     net_g.eval()
38 |     classifier.eval()
39 |     # testing
40 |     test_loss = 0
41 |     correct = 0
42 |     data_loader = DataLoader(datatest, batch_size=args.bs)
43 |     l = len(data_loader)
44 |     with torch.no_grad():
45 |         for idx, (data, target) in enumerate(data_loader):
46 |             if args.gpu != -1:
47 |                 data, target = data.cuda(), target.cuda()
48 |             net_out = net_g(data)[tag]
49 |             log_probs = classifier(net_out)['output']
50 |             # sum up batch loss
51 |             test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
52 |             # get the index of the max log-probability
53 |             y_pred = log_probs.data.max(1, keepdim=True)[1]
54 |             correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
55 | 
56 |     test_loss /= len(data_loader.dataset)
57 |     accuracy = 100.00 * correct / len(data_loader.dataset)
58 |     if args.verbose:
59 |         print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
60 |             test_loss, correct, len(data_loader.dataset), accuracy))
61 |     return accuracy, test_loss
--------------------------------------------------------------------------------
/optimizer/Adabelief.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 | 
5 | version_higher = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 5)  # a plain string comparison would misorder versions, e.g. "1.10" < "1.5"
6 | 
7 | 
8 | class AdaBelief(Optimizer):
9 |     r"""Implements the AdaBelief algorithm, adapted from Adam in PyTorch.
10 |     Arguments:
11 |         params (iterable): iterable of parameters to optimize or dicts defining
12 |             parameter groups
13 |         lr (float, optional): learning rate (default: 1e-3)
14 |         betas (Tuple[float, float], optional): coefficients used for computing
15 |             running averages of gradient and its square (default: (0.9, 0.999))
16 |         eps (float, optional): term added to the denominator to improve
17 |             numerical stability (default: 1e-8)
18 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 |         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
20 |             algorithm from the paper `On the Convergence of Adam and Beyond`_
21 |             (default: False)
22 |         weight_decouple (boolean, optional): (default: False) If True, the
23 |             optimizer uses decoupled weight decay as in AdamW
24 |         fixed_decay (boolean, optional): (default: False) Used when weight_decouple
25 |             is True.
26 |             When fixed_decay == True, the weight decay is performed as
27 |             $W_{new} = W_{old} - W_{old} \times decay$.
28 |             When fixed_decay == False, the weight decay is performed as
29 |             $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
30 |             weight decay ratio decreases with learning rate (lr).
31 | rectify (boolean, optional): (default: False) If set as True, then perform the rectified 32 | update similar to RAdam 33 | reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients 34 | NeurIPS 2020 Spotlight 35 | """ 36 | 37 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 38 | weight_decay=0, amsgrad=False, weight_decouple=False, fixed_decay=False, rectify=False): 39 | if not 0.0 <= lr: 40 | raise ValueError("Invalid learning rate: {}".format(lr)) 41 | if not 0.0 <= eps: 42 | raise ValueError("Invalid epsilon value: {}".format(eps)) 43 | if not 0.0 <= betas[0] < 1.0: 44 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 45 | if not 0.0 <= betas[1] < 1.0: 46 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 47 | defaults = dict(lr=lr, betas=betas, eps=eps, 48 | weight_decay=weight_decay, amsgrad=amsgrad) 49 | super(AdaBelief, self).__init__(params, defaults) 50 | 51 | self.weight_decouple = weight_decouple 52 | self.rectify = rectify 53 | self.fixed_decay = fixed_decay 54 | if self.weight_decouple: 55 | print('Weight decoupling enabled in AdaBelief') 56 | if self.fixed_decay: 57 | print('Weight decay fixed') 58 | if self.rectify: 59 | print('Rectification enabled in AdaBelief') 60 | if amsgrad: 61 | print('AMS enabled in AdaBelief') 62 | 63 | def __setstate__(self, state): 64 | super(AdaBelief, self).__setstate__(state) 65 | for group in self.param_groups: 66 | group.setdefault('amsgrad', False) 67 | 68 | def reset(self): 69 | for group in self.param_groups: 70 | for p in group['params']: 71 | state = self.state[p] 72 | amsgrad = group['amsgrad'] 73 | 74 | # State initialization 75 | state['step'] = 0 76 | # Exponential moving average of gradient values 77 | state['exp_avg'] = torch.zeros_like(p.data, 78 | memory_format=torch.preserve_format) if version_higher else torch.zeros_like( 79 | p.data) 80 | 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_var'] = torch.zeros_like(p.data, 83 | memory_format=torch.preserve_format) if version_higher else torch.zeros_like( 84 | p.data) 85 | if amsgrad: 86 | # Maintains max of all exp. moving avg. of sq. grad. values 87 | state['max_exp_avg_var'] = torch.zeros_like(p.data, 88 | memory_format=torch.preserve_format) if version_higher else torch.zeros_like( 89 | p.data) 90 | 91 | def step(self, closure=None): 92 | """Performs a single optimization step. 93 | Arguments: 94 | closure (callable, optional): A closure that reevaluates the model 95 | and returns the loss. 
96 |         """
97 |         loss = None
98 |         if closure is not None:
99 |             loss = closure()
100 | 
101 |         for group in self.param_groups:
102 |             for p in group['params']:
103 |                 if p.grad is None:
104 |                     continue
105 |                 grad = p.grad.data
106 |                 if grad.is_sparse:
107 |                     raise RuntimeError(
108 |                         'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
109 |                 amsgrad = group['amsgrad']
110 | 
111 |                 state = self.state[p]
112 | 
113 |                 beta1, beta2 = group['betas']
114 | 
115 |                 # State initialization
116 |                 if len(state) == 0:
117 |                     state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
118 |                     state['step'] = 0
119 |                     # Exponential moving average of gradient values
120 |                     state['exp_avg'] = torch.zeros_like(p.data,
121 |                         memory_format=torch.preserve_format) if version_higher else torch.zeros_like(
122 |                         p.data)
123 |                     # Exponential moving average of squared gradient values
124 |                     state['exp_avg_var'] = torch.zeros_like(p.data,
125 |                         memory_format=torch.preserve_format) if version_higher else torch.zeros_like(
126 |                         p.data)
127 |                     if amsgrad:
128 |                         # Maintains max of all exp. moving avg. of sq. grad. values
129 |                         state['max_exp_avg_var'] = torch.zeros_like(p.data,
130 |                             memory_format=torch.preserve_format) if version_higher else torch.zeros_like(
131 |                             p.data)
132 | 
133 |                 # get current state variable
134 |                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
135 | 
136 |                 state['step'] += 1
137 |                 bias_correction1 = 1 - beta1 ** state['step']
138 |                 bias_correction2 = 1 - beta2 ** state['step']
139 | 
140 |                 # perform weight decay, check if decoupled weight decay
141 |                 if self.weight_decouple:
142 |                     if not self.fixed_decay:
143 |                         p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
144 |                     else:
145 |                         p.data.mul_(1.0 - group['weight_decay'])
146 |                 else:
147 |                     if group['weight_decay'] != 0:
148 |                         grad.add_(p.data, alpha=group['weight_decay'])  # keyword form; the positional alpha overload is removed in newer PyTorch
149 | 
150 |                 # Update first and second moment running average
151 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
152 |                 grad_residual = grad - exp_avg
153 |                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
154 | 
155 |                 if amsgrad:
156 |                     max_exp_avg_var = state['max_exp_avg_var']
157 |                     # Maintains the maximum of all 2nd moment running avg. till now
158 |                     torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
159 | 
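# What exp_avg_var tracks above is AdaBelief's one change to Adam:
#     s_t = beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2,
# the EMA of the gradient's deviation from its own running mean m_t (exp_avg),
# where Adam would track the EMA of g_t^2. The denominator assembled below is
# sqrt(s_t) / sqrt(bias_correction2) + eps (with eps also folded in before the
# square root for stability), so the step is large when gradients agree with
# the current "belief" m_t and small when they do not.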
160 |                     # Use the max. for normalizing running avg. of gradient
161 |                     denom = (max_exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
162 |                 else:
163 |                     denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
164 | 
165 |                 if not self.rectify:
166 |                     # Default update
167 |                     step_size = group['lr'] / bias_correction1
168 |                     p.data.addcdiv_(exp_avg, denom, value=-step_size)  # keyword form; the positional value overload is removed in newer PyTorch
169 | 
170 |                 else:  # Rectified update
171 |                     # calculate rho_t
172 |                     state['rho_t'] = state['rho_inf'] - 2 * state['step'] * beta2 ** state['step'] / (
173 |                         1.0 - beta2 ** state['step'])
174 | 
175 |                     if state['rho_t'] > 4:  # perform Adam style update if variance is small
176 |                         rho_inf, rho_t = state['rho_inf'], state['rho_t']
177 |                         rt = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf / (rho_inf - 4.0) / (rho_inf - 2.0) / rho_t
178 |                         rt = math.sqrt(rt)
179 | 
180 |                         step_size = rt * group['lr'] / bias_correction1
181 | 
182 |                         p.data.addcdiv_(exp_avg, denom, value=-step_size)
183 | 
184 |                     else:  # perform SGD style update
185 |                         p.data.add_(exp_avg, alpha=-group['lr'])
186 | 
187 |         return loss
--------------------------------------------------------------------------------
/optimizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HMHelloWorld/FedMR/f9e05716c3a96de5278c05de23e91fd160b7edbe/optimizer/__init__.py
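For orientation, a minimal training-loop sketch using this optimizer (illustration only; the linear model and random data are placeholders):

import torch
import torch.nn.functional as F
from optimizer.Adabelief import AdaBelief

model = torch.nn.Linear(10, 1)
opt = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                weight_decouple=True, rectify=True)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(100):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()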
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import copy
4 | 
5 | from utils.options import args_parser
6 | from utils.set_seed import set_random_seed
7 | from models.Update import *
8 | from models.Nets import *
9 | from models.MobileNetV2 import MobileNetV2
10 | from models.Fed import Aggregation, Weighted_Aggregation_FedASync
11 | from models.test import test_img
12 | from models.resnetcifar import *
13 | from models import *
14 | from utils.get_dataset import get_dataset
15 | from utils.utils import save_result, save_model
16 | from Algorithm.Training_FedGen import FedGen
17 | from Algorithm.Triaining_Scaffold import Scaffold
18 | from Algorithm.Training_FedDC import FedDC
19 | from Algorithm.Training_FedCross import FedCross
20 | from Algorithm.Training_FedMR import FedMR
21 | from Algorithm.Training_FedMR import FedMR_Frozen
22 | from Algorithm.Training_FedMR import FedMR_Partition
23 | from Algorithm.Training_CFL import CFL
24 | from Algorithm.Training_FedIndenp import FedIndep
25 | from Algorithm.Training_Asyn_FedSA import FedSA
26 | from Algorithm.Training_Asyn_GitFL import GitFL
27 | from utils.Clients import Clients
28 | import utils.asynchronous_client_config as AsynConfig
29 | 
30 | 
31 | if __name__ == '__main__':
32 | 
33 |     PATH = "/home/huming/hm/FederatedLearning/loss-landscape-master/model/cifar10_FedMR_resnet20_test_model_1000_lr_0.01_2023_05_13_14_23_43_frac_0.1.txt"
34 |     model_dict = torch.load(PATH)
35 |     # to restore the weights, instantiate the matching network (a CIFAR-10 ResNet-20 from models.resnetcifar) and call its load_state_dict(model_dict)
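The checkpoint at PATH stores a state_dict, so restoring it follows the usual PyTorch pattern; a self-contained sketch with a stand-in module (substitute the matching model class from models/):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                            # stand-in; use the matching model class in practice
path = "checkpoint.pt"                             # hypothetical file
torch.save(model.state_dict(), path)
state_dict = torch.load(path, map_location="cpu")  # an OrderedDict of parameter tensors
model.load_state_dict(state_dict)                  # raises if keys or shapes mismatch
model.eval()                                       # disable dropout / freeze batch-norm statistics before evaluation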
--------------------------------------------------------------------------------
/utils/Clients.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils.asynchronous_client_config import *
3 | 
4 | 
5 | class Clients:
6 |     def __init__(self, args):
7 |         self.args = args
8 |         uncertain_list = [0.2, 0.2, 0.2, 0.2, 0.2]
9 |         if args.uncertain_type == 1:
10 |             uncertain_list = [0.5, 0.2, 0.1, 0.1, 0.1]
11 |         elif args.uncertain_type == 2:
12 |             uncertain_list = [0.1, 0.15, 0.5, 0.15, 0.1]
13 |         elif args.uncertain_type == 3:
14 |             uncertain_list = [0.1, 0.1, 0.1, 0.2, 0.5]
15 |         elif args.uncertain_type == 4:
16 |             uncertain_list = [0.4, 0.1, 0.0, 0.1, 0.4]
17 |         self.clients_list = generate_asyn_clients(uncertain_list, uncertain_list, args.num_users)
18 |         self.update_list = []  # (idx, version, time)
19 |         self.train_set = set()
20 |         # for i in range(self.args.num_users):
21 |         #     self.clients_list.append(Node(1.4, 1.4, np.random.exponential(), 0, args))
22 | 
23 |     def train(self, idx, version):
24 |         for i in range(len(self.update_list) - 1, -1, -1):
25 |             if self.update_list[i][0] == idx:
26 |                 self.update_list.pop(i)
27 |         client = self.get(idx)
28 |         client.version = version
29 |         client.comm_count += 1
30 |         train_time = client.get_train_time()
31 |         comm_time = client.get_comm_time()
32 |         self.update_list.append([idx, version, train_time + comm_time])
33 |         self.update_list.sort(key=lambda x: x[2])
34 |         self.train_set.add(idx)
35 | 
36 |     def get_update_byLimit(self, limit):
37 |         lst = []
38 |         for update in self.update_list:
39 |             if update[2] <= limit:
40 |                 lst.append(update)
41 |         return lst
42 |         # update = []
43 |         # for i in range(self.args.num_users):
44 |         #     if self.get(i).end_time <= ddl:
45 |         #         update.append((i, self.get(i).end_time))
46 |         # update.sort(key=lambda x: x[1])
47 |         # return update
48 | 
49 |     def get_update(self, num):
50 |         return self.update_list[0:num]
51 | 
52 |     def pop_update(self, num):
53 |         res = self.update_list[0:num]
54 |         max_time = self.update_list[num - 1][2]
55 |         for update in self.update_list:
56 |             if update[2] <= max_time:
57 |                 self.train_set.remove(update[0])
58 |                 client = self.get(update[0])
59 |                 client.comm_count += 1
60 |             else:
61 |                 update[2] -= max_time  # clients still training: shift their finish times forward
62 |         self.update_list = self.update_list[num::]
63 |         return res
64 | 
65 |     def get_first_update(self, start_time):
66 |         min_idx = 0
67 |         min_time = float('inf')
68 |         for idx in range(self.args.num_users):
69 |             client = self.get(idx)
70 |             if client.end_time != 0:
71 |                 if start_time < client.end_time < min_time:
72 |                     min_time = client.end_time
73 |                     min_idx = idx
74 |         return min_idx
75 | 
76 |     def get(self, idx):
77 |         return self.clients_list[idx]
78 | 
79 |     def get_idle(self, num):
80 |         idle = self.get_all_idle()
81 | 
82 |         if len(idle) < num:
83 |             return []
84 |         else:
85 |             return np.random.choice(idle, num, replace=False)
86 | 
87 |     def get_all_idle(self):
88 |         idle = set(range(self.args.num_users)).difference(self.train_set)
89 |         return list(idle)
90 |         # idle = []
91 |         # for idx in range(self.args.num_users):
92 |         #     client = self.get(idx)
93 |         #     if not (client.start_time <= time < client.end_time) or client.end_time == 0:
94 |         #         idle.append(idx)
95 |         # return idle
96 | 
97 | 
98 | class Node:
99 |     def __init__(self, down_bw, up_bw, computer_ability, version, args):
100 |         self.down_bw = down_bw
101 |         self.up_bw = up_bw
102 |         self.computer_ability = computer_ability
103 |         self.version = version
104 |         self.data_size = args.local_bs
105 |         self.start_time = 0
106 |         self.end_time = 0
107 |         self.args = args
108 |         self.selected = 0
109 |         self.avg = 0
110 | 
111 |     def get_end_time(self, start_time, version):
112 |         self.version = version
113 |         self.start_time = start_time
114 | 
115 |         down_time = 10 / (self.down_bw / 8)  # transfer time for a fixed payload, with bandwidth given in bits (hence the /8)
116 |         train_time = self.data_size * self.args.local_ep / self.computer_ability
117 |         up_time = 10 / (self.up_bw / 8)
118 |         time = down_time + train_time + up_time
119 | 
120 |         self.end_time = start_time + time
121 |         self.avg = time if self.selected == 0 else (self.avg * self.selected + time) / (self.selected + 1)
122 |         self.selected += 1
123 | 
124 |         return self.end_time
125 | 
--------------------------------------------------------------------------------
/utils/FEMNIST.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | import torch
7 | 
8 | class FEMNIST(Dataset):
9 |     """
10 |     This dataset is derived from the Leaf repository
11 |     (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
12 |     dataset, grouping examples by writer. Details about Leaf were published in
13 |     "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
14 |     """
15 | 
16 |     def __init__(self, train=True, transform=None, target_transform=None, ):
17 |         super(FEMNIST, self).__init__()
18 |         self.transform = transform
19 |         self.target_transform = target_transform
20 |         self.train = train
21 | 
22 |         train_clients, train_groups, train_data_temp, test_data_temp = read_data("./data/femnist/train",
23 |                                                                                  "./data/femnist/test")
24 |         if self.train:
25 |             self.dic_users = {}
26 |             train_data_x = []
27 |             train_data_y = []
28 |             for i in range(len(train_clients)):
29 |                 #if i == 100:
30 |                 #    break
31 |                 self.dic_users[i] = set()
32 |                 l = len(train_data_x)
33 |                 cur_x = train_data_temp[train_clients[i]]['x']
34 |                 cur_y = train_data_temp[train_clients[i]]['y']
35 |                 for j in range(len(cur_x)):
36 |                     self.dic_users[i].add(j + l)
37 |                     train_data_x.append(np.array(cur_x[j]).reshape(28, 28))
38 |                     train_data_y.append(cur_y[j])
39 |             self.data = train_data_x
40 |             self.label = train_data_y
41 |         else:
42 |             test_data_x = []
43 |             test_data_y = []
44 |             for i in range(len(train_clients)):
45 |                 cur_x = test_data_temp[train_clients[i]]['x']
46 |                 cur_y = test_data_temp[train_clients[i]]['y']
47 |                 for j in range(len(cur_x)):
48 |                     test_data_x.append(np.array(cur_x[j]).reshape(28, 28))
49 |                     test_data_y.append(cur_y[j])
50 |             self.data = test_data_x
51 |             self.label = test_data_y
52 | 
53 |     def __getitem__(self, index):
54 |         img, target = self.data[index], self.label[index]
55 |         img = np.array([img])
56 |         # img = Image.fromarray(img, mode='L')
57 |         # if self.transform is not None:
58 |         #     img = self.transform(img)
59 |         # if self.target_transform is not None:
60 |         #     target = self.target_transform(target)
61 |         return torch.from_numpy((0.5 - img) / 0.5).float(), target  # rescale pixels from [0, 1] to [-1, 1], inverting so 0-valued pixels map to +1
62 | 
63 |     def __len__(self):
64 |         return len(self.data)
65 | 
66 |     def get_client_dic(self):
67 |         if self.train:
68 |             return self.dic_users
69 |         else:
70 |             exit("The test dataset does not have dic_users!")
71 | 
72 | 
73 | def batch_data(data, batch_size, seed):
74 |     '''
75 |     data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client)
76 |     returns x, y, which are both numpy array of length: batch_size
77 |     '''
78 |     data_x = data['x']
79 |     data_y = data['y']
80 | 
81 |     # randomly shuffle data
82 |     np.random.seed(seed)
83 |     rng_state = np.random.get_state()
84 |     np.random.shuffle(data_x)
85 |     np.random.set_state(rng_state)
86 |     np.random.shuffle(data_y)
87 | 
88 |     # loop through mini-batches
89 |     for i in range(0, len(data_x), batch_size):
90 |         batched_x = data_x[i:i + batch_size]
91 |         batched_y = data_y[i:i + batch_size]
92 |         yield (batched_x, batched_y)
93 | 
94 | 
95 | def read_dir(data_dir):
96 |     clients = []
97 |     groups = []
98 |     data = defaultdict(lambda: None)
99 | 
100 |     files = os.listdir(data_dir)
101 |     files = [f for f in files if f.endswith('.json')]
102 |     for f in files:
103 |         file_path = os.path.join(data_dir, f)
104 |         with open(file_path, 'r') as inf:
105 |             cdata = json.load(inf)
106 |         clients.extend(cdata['users'])
107 |         if 'hierarchies' in cdata:
108 |             groups.extend(cdata['hierarchies'])
109 |         data.update(cdata['user_data'])
110 | 
111 |     clients = list(sorted(data.keys()))
112 |     return clients, groups, data
113 | 
114 | 
115 | 
116 | def read_data(train_data_dir, test_data_dir):
117 |     '''parses data in given train and test data directories
118 | 
119 |     assumes:
120 |     - the data in the input directories are .json files with
121 |       keys 'users' and 'user_data'
122 |     - the set of train set users is the same
as the set of test set users 123 | 124 | Return: 125 | clients: list of client ids 126 | groups: list of group ids; empty list if none found 127 | train_data: dictionary of train data 128 | test_data: dictionary of test data 129 | ''' 130 | train_clients, train_groups, train_data = read_dir(train_data_dir) 131 | test_clients, test_groups, test_data = read_dir(test_data_dir) 132 | 133 | assert train_clients == test_clients 134 | assert train_groups == test_groups 135 | 136 | return train_clients, train_groups, train_data, test_data 137 | 138 | 139 | if __name__ == '__main__': 140 | test = FEMNIST(train=True) 141 | x = test.get_client_dic() 142 | print(len(x)) 143 | t = 0 144 | for k in x[0]: 145 | t += 1 146 | data, label = test.__getitem__(k) 147 | print(t) 148 | -------------------------------------------------------------------------------- /utils/ShakeSpeare.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import torch 7 | from utils.language_utils import word_to_indices, letter_to_vec 8 | from models import * 9 | 10 | def read_dir(data_dir): 11 | clients = [] 12 | groups = [] 13 | data = defaultdict(lambda: None) 14 | 15 | files = os.listdir(data_dir) 16 | files = [f for f in files if f.endswith('.json')] 17 | for f in files: 18 | file_path = os.path.join(data_dir, f) 19 | with open(file_path, 'r') as inf: 20 | cdata = json.load(inf) 21 | clients.extend(cdata['users']) 22 | if 'hierarchies' in cdata: 23 | groups.extend(cdata['hierarchies']) 24 | data.update(cdata['user_data']) 25 | 26 | clients = list(sorted(data.keys())) 27 | return clients, groups, data 28 | 29 | def read_data(train_data_dir, test_data_dir): 30 | '''parses data in given train and test data directories 31 | assumes: 32 | - the data in the input directories are .json files with 33 | keys 'users' and 'user_data' 34 | - the set of train set users is the same as the set of test set users 35 | Return: 36 | clients: list of client ids 37 | groups: list of group ids; empty list if none found 38 | train_data: dictionary of train data 39 | test_data: dictionary of test data 40 | ''' 41 | train_clients, train_groups, train_data = read_dir(train_data_dir) 42 | test_clients, test_groups, test_data = read_dir(test_data_dir) 43 | 44 | assert train_clients == test_clients 45 | assert train_groups == test_groups 46 | 47 | return train_clients, train_groups, train_data, test_data 48 | 49 | class ShakeSpeare(Dataset): 50 | def __init__(self, train=True): 51 | super(ShakeSpeare, self).__init__() 52 | train_clients, train_groups, train_data_temp, test_data_temp = read_data("./data/shakespeare/train", 53 | "./data/shakespeare/test") 54 | self.train = train 55 | 56 | if self.train: 57 | self.dic_users = {} 58 | train_data_x = [] 59 | train_data_y = [] 60 | for i in range(len(train_clients)): 61 | # if i == 100: 62 | # break 63 | self.dic_users[i] = set() 64 | l = len(train_data_x) 65 | cur_x = train_data_temp[train_clients[i]]['x'] 66 | cur_y = train_data_temp[train_clients[i]]['y'] 67 | for j in range(len(cur_x)): 68 | self.dic_users[i].add(j + l) 69 | train_data_x.append(cur_x[j]) 70 | train_data_y.append(cur_y[j]) 71 | self.data = train_data_x 72 | self.label = train_data_y 73 | else: 74 | test_data_x = [] 75 | test_data_y = [] 76 | for i in range(len(train_clients)): 77 | cur_x = test_data_temp[train_clients[i]]['x'] 78 | cur_y = 
test_data_temp[train_clients[i]]['y']
79 |                 for j in range(len(cur_x)):
80 |                     test_data_x.append(cur_x[j])
81 |                     test_data_y.append(cur_y[j])
82 |             self.data = test_data_x
83 |             self.label = test_data_y
84 | 
85 |     def __len__(self):
86 |         return len(self.data)
87 | 
88 |     def __getitem__(self, index):
89 |         sentence, target = self.data[index], self.label[index]
90 |         indices = word_to_indices(sentence)
91 |         target = letter_to_vec(target)
92 |         # y = indices[1:].append(target)
93 |         # target = indices[1:].append(target)
94 |         indices = torch.LongTensor(np.array(indices))
95 |         # y = torch.Tensor(np.array(y))
96 |         # target = torch.LongTensor(np.array(target))
97 |         return indices, target
98 | 
99 |     def get_client_dic(self):
100 |         if self.train:
101 |             return self.dic_users
102 |         else:
103 |             exit("The test dataset does not have dic_users!")
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 | 
5 | 
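Both LEAF-derived datasets above expose get_client_dic() so that samplers can carve out per-client shards; a usage sketch (illustration only: it assumes the preprocessed LEAF json files already sit under ./data/shakespeare/):

from torch.utils.data import DataLoader, Subset
from utils.ShakeSpeare import ShakeSpeare

train_set = ShakeSpeare(train=True)
dict_users = train_set.get_client_dic()            # {client_id: set(sample indices)}
client0 = Subset(train_set, list(dict_users[0]))   # one client's local shard
loader = DataLoader(client0, batch_size=32, shuffle=True)
for seqs, targets in loader:
    print(seqs.shape, targets.shape)               # e.g. torch.Size([32, 80]) torch.Size([32])
    break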
--------------------------------------------------------------------------------
/utils/clustering.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import numpy as np
4 | from itertools import product
5 | 
6 | 
7 | from scipy.cluster.hierarchy import fcluster
8 | from copy import deepcopy
9 | 
10 | 
11 | def get_clusters_with_alg1(n_sampled: int, weights: np.ndarray):
12 |     """Algorithm 1: build n_sampled sampling distributions from the clients' aggregation weights."""
13 | 
14 |     epsilon = int(10 ** 10)
15 |     # associate each client to a cluster
16 |     augmented_weights = np.array([w * n_sampled * epsilon for w in weights])
17 |     ordered_client_idx = np.flip(np.argsort(augmented_weights))
18 | 
19 |     n_clients = len(weights)
20 |     distri_clusters = np.zeros((n_sampled, n_clients)).astype(int)
21 | 
22 |     k = 0
23 |     for client_idx in ordered_client_idx:
24 | 
25 |         while augmented_weights[client_idx] > 0:
26 | 
27 |             sum_proba_in_k = np.sum(distri_clusters[k])
28 | 
29 |             u_i = min(epsilon - sum_proba_in_k, augmented_weights[client_idx])
30 | 
31 |             distri_clusters[k, client_idx] = u_i
32 |             augmented_weights[client_idx] += -u_i
33 | 
34 |             sum_proba_in_k = np.sum(distri_clusters[k])
35 |             if sum_proba_in_k == 1 * epsilon:
36 |                 k += 1
37 | 
38 |     distri_clusters = distri_clusters.astype(float)
39 |     for l in range(n_sampled):
40 |         distri_clusters[l] /= np.sum(distri_clusters[l])
41 | 
42 |     return distri_clusters
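A worked illustration of get_clusters_with_alg1 (standalone sketch; the weights are made up and must sum to 1 for the filling loop to terminate cleanly):

import numpy as np

weights = np.array([0.4, 0.3, 0.2, 0.1])   # clients' aggregation weights, summing to 1
distri_clusters = get_clusters_with_alg1(n_sampled=2, weights=weights)
print(distri_clusters)
# [[0.8 0.2 0.  0. ]
#  [0.  0.4 0.4 0.2]]  -> one probability distribution per sampled slot;
# sample_clients() at the end of this file then draws one client per row.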
43 | 
44 | 
45 | def get_similarity(grad_1, grad_2, distance_type="L1"):
46 | 
47 |     if distance_type == "L1":
48 | 
49 |         norm = 0
50 |         for g_1, g_2 in zip(grad_1, grad_2):
51 |             norm += np.sum(np.abs(g_1 - g_2))
52 |         return norm
53 | 
54 |     elif distance_type == "L2":
55 |         norm = 0
56 |         for g_1, g_2 in zip(grad_1, grad_2):
57 |             norm += np.sum((g_1 - g_2) ** 2)
58 |         return norm
59 | 
60 |     elif distance_type == "cosine":
61 |         norm, norm_1, norm_2 = 0, 0, 0
62 |         for i in range(len(grad_1)):
63 |             norm += np.sum(grad_1[i] * grad_2[i])
64 |             norm_1 += np.sum(grad_1[i] ** 2)
65 |             norm_2 += np.sum(grad_2[i] ** 2)
66 | 
67 |         if norm_1 == 0.0 or norm_2 == 0.0:
68 |             return 0.0
69 |         else:
70 |             norm /= np.sqrt(norm_1 * norm_2)
71 | 
72 |             return np.arccos(norm)
73 | 
74 | 
75 | def get_gradients(sampling, global_m, local_models):  # note: sampling is accepted but unused
76 |     """return the `representative gradient` formed by the difference between
77 |     the local work and the sent global model"""
78 | 
79 |     local_model_params = []
80 |     for model in local_models:
81 |         local_model_params += [
82 |             [tens.detach().numpy() for tens in list(model.parameters())]
83 |         ]
84 | 
85 |     global_model_params = [
86 |         tens.detach().numpy() for tens in list(global_m.parameters())
87 |     ]
88 | 
89 |     local_model_grads = []
90 |     for local_params in local_model_params:
91 |         local_model_grads += [
92 |             [
93 |                 local_weights - global_weights
94 |                 for local_weights, global_weights in zip(
95 |                     local_params, global_model_params
96 |                 )
97 |             ]
98 |         ]
99 | 
100 |     return local_model_grads
101 | 
102 | 
103 | def get_matrix_similarity_from_grads(local_model_grads, distance_type):
104 |     """return the similarity matrix where the distance chosen to
105 |     compare two clients is set with `distance_type`"""
106 | 
107 |     n_clients = len(local_model_grads)
108 | 
109 |     metric_matrix = np.zeros((n_clients, n_clients))
110 |     for i, j in product(range(n_clients), range(n_clients)):
111 | 
112 |         metric_matrix[i, j] = get_similarity(
113 |             local_model_grads[i], local_model_grads[j], distance_type
114 |         )
115 | 
116 |     return metric_matrix
117 | 
118 | def get_matrix_similarity_from_grads_new(local_model_grads, distance_type, idx, metric_matrix):
119 |     """return the similarity matrix where the distance chosen to
120 |     compare two clients is set with `distance_type`"""
121 | 
122 |     for i in idx:
123 |         for j in idx:
124 |             if i == j:
125 |                 continue
126 |             metric_matrix[i, j] = get_similarity(
127 |                 local_model_grads[i], local_model_grads[j], distance_type
128 |             )
129 | 
130 |     return metric_matrix
131 | 
132 | 
133 | def get_matrix_similarity(global_m, local_models, distance_type):
134 | 
135 |     n_clients = len(local_models)
136 | 
137 |     local_model_grads = get_gradients(None, global_m, local_models)  # get_gradients ignores its first (sampling) argument; a two-argument call does not match its signature
138 | 
139 |     metric_matrix = np.zeros((n_clients,
n_clients)) 140 | for i, j in product(range(n_clients), range(n_clients)): 141 | 142 | metric_matrix[i, j] = get_similarity( 143 | local_model_grads[i], local_model_grads[j], distance_type 144 | ) 145 | 146 | return metric_matrix 147 | 148 | 149 | def get_clusters_with_alg2( 150 | linkage_matrix: np.array, n_sampled: int, weights: np.array 151 | ): 152 | """Algorithm 2""" 153 | epsilon = int(10 ** 10) 154 | 155 | # associate each client to a cluster 156 | link_matrix_p = deepcopy(linkage_matrix) 157 | augmented_weights = deepcopy(weights) 158 | 159 | for i in range(len(link_matrix_p)): 160 | idx_1, idx_2 = int(link_matrix_p[i, 0]), int(link_matrix_p[i, 1]) 161 | 162 | new_weight = np.array( 163 | [augmented_weights[idx_1] + augmented_weights[idx_2]] 164 | ) 165 | augmented_weights = np.concatenate((augmented_weights, new_weight)) 166 | link_matrix_p[i, 2] = int(new_weight * epsilon) 167 | 168 | clusters = fcluster( 169 | link_matrix_p, int(epsilon / n_sampled), criterion="distance" 170 | ) 171 | 172 | n_clients, n_clusters = len(clusters), len(set(clusters)) 173 | 174 | # Associate each cluster to its number of clients in the cluster 175 | pop_clusters = np.zeros((n_clusters, 2)).astype(int) 176 | for i in range(n_clusters): 177 | pop_clusters[i, 0] = i + 1 178 | for client in np.where(clusters == i + 1)[0]: 179 | pop_clusters[i, 1] += int(weights[client] * epsilon * n_sampled) 180 | 181 | pop_clusters = pop_clusters[pop_clusters[:, 1].argsort()] 182 | 183 | distri_clusters = np.zeros((n_sampled, n_clients)).astype(int) 184 | 185 | # n_sampled biggest clusters that will remain unchanged 186 | kept_clusters = pop_clusters[n_clusters - n_sampled :, 0] 187 | 188 | for idx, cluster in enumerate(kept_clusters): 189 | for client in np.where(clusters == cluster)[0]: 190 | distri_clusters[idx, client] = int( 191 | weights[client] * n_sampled * epsilon 192 | ) 193 | 194 | k = 0 195 | for j in pop_clusters[: n_clusters - n_sampled, 0]: 196 | 197 | clients_in_j = np.where(clusters == j)[0] 198 | np.random.shuffle(clients_in_j) 199 | 200 | for client in clients_in_j: 201 | 202 | weight_client = int(weights[client] * epsilon * n_sampled) 203 | 204 | while weight_client > 0: 205 | 206 | sum_proba_in_k = np.sum(distri_clusters[k]) 207 | 208 | u_i = min(epsilon - sum_proba_in_k, weight_client) 209 | 210 | distri_clusters[k, client] = u_i 211 | weight_client += -u_i 212 | 213 | sum_proba_in_k = np.sum(distri_clusters[k]) 214 | if sum_proba_in_k == 1 * epsilon: 215 | k += 1 216 | 217 | distri_clusters = distri_clusters.astype(float) 218 | for l in range(n_sampled): 219 | distri_clusters[l] /= np.sum(distri_clusters[l]) 220 | 221 | return distri_clusters 222 | 223 | 224 | from numpy.random import choice 225 | 226 | 227 | def sample_clients(distri_clusters): 228 | 229 | n_clients = len(distri_clusters[0]) 230 | n_sampled = len(distri_clusters) 231 | 232 | sampled_clients = np.zeros(len(distri_clusters), dtype=int) 233 | 234 | for k in range(n_sampled): 235 | sampled_clients[k] = int(choice(n_clients, 1, p=distri_clusters[k])) 236 | 237 | return sampled_clients 238 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | from collections import defaultdict 4 | 5 | import ujson 6 | import numpy as np 7 | import json 8 | import torch 9 | import random 10 | 11 | 12 | def check(config_path, train_path, test_path, num_clients, num_labels, niid=False, 
13 | real=True, partition=None): 14 | # check existing dataset 15 | if os.path.exists(config_path): 16 | with open(config_path, 'r') as f: 17 | config = ujson.load(f) 18 | if config['num_clients'] == num_clients and \ 19 | config['num_labels'] == num_labels and \ 20 | config['non_iid'] == niid and \ 21 | config['real_world'] == real and \ 22 | config['partition'] == partition: 23 | print("\nDataset already generated.\n") 24 | return True 25 | 26 | dir_path = os.path.dirname(train_path) 27 | if not os.path.exists(dir_path): 28 | os.makedirs(dir_path) 29 | dir_path = os.path.dirname(test_path) 30 | if not os.path.exists(dir_path): 31 | os.makedirs(dir_path) 32 | 33 | return False 34 | 35 | def read_record(file): 36 | with open(file,"r") as f: 37 | dataJson = json.load(f) 38 | users_train = dataJson["train_data"] 39 | #users_test = dataJson["test_data"] 40 | dict_users_train = {} 41 | #dict_users_test = {} 42 | for key,value in users_train.items(): 43 | newKey = int(key) 44 | dict_users_train[newKey] = value 45 | ''' 46 | for key,value in users_test.items(): 47 | newKey = int(key) 48 | dict_users_test[newKey] = value 49 | ''' 50 | return dict_users_train #, dict_users_test 51 | 52 | def separate_data(train_data, num_clients, num_classes, beta=0.4): 53 | 54 | 55 | y_train = np.array(train_data.targets) 56 | 57 | min_size_train = 0 58 | min_require_size = 10 59 | K = num_classes 60 | 61 | N_train = len(y_train) 62 | dict_users_train = {} 63 | 64 | while min_size_train < min_require_size: 65 | idx_batch_train = [[] for _ in range(num_clients)] 66 | idx_batch_test = [[] for _ in range(num_clients)] 67 | for k in range(K): 68 | idx_k_train = np.where(y_train == k)[0] 69 | np.random.shuffle(idx_k_train) 70 | proportions = np.random.dirichlet(np.repeat(beta, num_clients)) 71 | proportions_train = np.array([p * (len(idx_j) < N_train / num_clients) for p, idx_j in zip(proportions, idx_batch_train)]) 72 | proportions_train = proportions_train / proportions_train.sum() 73 | proportions_train = (np.cumsum(proportions_train) * len(idx_k_train)).astype(int)[:-1] 74 | idx_batch_train = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch_train, np.split(idx_k_train, proportions_train))] 75 | min_size_train = min([len(idx_j) for idx_j in idx_batch_train]) 76 | # if K == 2 and n_parties <= 10: 77 | # if np.min(proportions) < 200: 78 | # min_size = 0 79 | # break 80 | 81 | for j in range(num_clients): 82 | np.random.shuffle(idx_batch_train[j]) 83 | dict_users_train[j] = idx_batch_train[j] 84 | 85 | train_cls_counts = record_net_data_stats(y_train,dict_users_train) 86 | 87 | return dict_users_train 88 | 89 | def record_net_data_stats(y_train, net_dataidx_map): 90 | net_cls_counts = {} 91 | 92 | for net_i, dataidx in net_dataidx_map.items(): 93 | 94 | unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True) 95 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} 96 | net_cls_counts[net_i] = tmp 97 | 98 | 99 | data_list=[] 100 | for net_id, data in net_cls_counts.items(): 101 | n_total=0 102 | for class_id, n_data in data.items(): 103 | n_total += n_data 104 | data_list.append(n_total) 105 | print('mean:', np.mean(data_list)) 106 | print('std:', np.std(data_list)) 107 | 108 | return net_cls_counts 109 | 110 | def save_file(config_path, train_path, test_path, train_data, test_data, num_clients, 111 | num_labels, statistic, niid=False, real=True, partition=None): 112 | config = { 113 | 'num_clients': num_clients, 114 | 'num_labels': num_labels, 115 | 'non_iid': niid, 116 | 'real_world': real, 117 | 
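# ---------------------------------------------------------------------------
# A self-contained sketch of separate_data above (the toy dataset is
# hypothetical; any object exposing a .targets list, such as a torchvision
# dataset, works). A smaller beta yields more skewed per-client label counts:
#
#     class _Toy:
#         def __init__(self, n=1000, k=10):
#             self.targets = np.random.randint(0, k, n).tolist()
#
#     split = separate_data(_Toy(), num_clients=5, num_classes=10, beta=0.1)
#     print({cid: len(idxs) for cid, idxs in split.items()})
# ---------------------------------------------------------------------------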
'partition': partition, 118 | 'Size of samples for labels in clients': statistic, 119 | } 120 | 121 | # gc.collect() 122 | 123 | for idx, train_dict in enumerate(train_data): 124 | with open(train_path[:-5] + str(idx) + '_' + '.json', 'w') as f: 125 | ujson.dump(train_dict, f) 126 | for idx, test_dict in enumerate(test_data): 127 | with open(test_path[:-5] + str(idx) + '_' + '.json', 'w') as f: 128 | ujson.dump(test_dict, f) 129 | with open(config_path, 'w') as f: 130 | ujson.dump(config, f) 131 | 132 | print("Finish generating dataset.\n") 133 | 134 | 135 | def get_num_classes_samples(dataset): 136 | """ 137 | extracts info about certain dataset 138 | :param dataset: pytorch dataset object 139 | :return: dataset info number of classes, number of samples, list of labels 140 | """ 141 | # ---------------# 142 | # Extract labels # 143 | # ---------------# 144 | if isinstance(dataset, torch.utils.data.Subset): 145 | if isinstance(dataset.dataset.targets, list): 146 | data_labels_list = np.array(dataset.dataset.targets)[dataset.indices] 147 | else: 148 | data_labels_list = dataset.dataset.targets[dataset.indices] 149 | else: 150 | if isinstance(dataset.targets, list): 151 | data_labels_list = np.array(dataset.targets) 152 | else: 153 | data_labels_list = dataset.targets 154 | classes, num_samples = np.unique(data_labels_list, return_counts=True) 155 | num_classes = len(classes) 156 | return num_classes, num_samples, data_labels_list 157 | 158 | def gen_classes_per_node(dataset, num_users, classes_per_user=2, high_prob=0.6, low_prob=0.4): 159 | """ 160 | creates the data distribution of each client 161 | :param dataset: pytorch dataset object 162 | :param num_users: number of clients 163 | :param classes_per_user: number of classes assigned to each client 164 | :param high_prob: highest prob sampled 165 | :param low_prob: lowest prob sampled 166 | :return: dictionary mapping between classes and proportions, each entry refers to other client 167 | """ 168 | num_classes, num_samples, _ = get_num_classes_samples(dataset) 169 | 170 | # -------------------------------------------# 171 | # Divide classes + num samples for each user # 172 | # -------------------------------------------# 173 | assert (classes_per_user * num_users) % num_classes == 0, "equal classes appearance is needed" 174 | count_per_class = (classes_per_user * num_users) // num_classes 175 | class_dict = {} 176 | for i in range(num_classes): 177 | # sampling alpha_i_c 178 | probs = np.random.uniform(low_prob, high_prob, size=count_per_class) 179 | # normalizing 180 | probs_norm = (probs / probs.sum()).tolist() 181 | class_dict[i] = {'count': count_per_class, 'prob': probs_norm} 182 | 183 | # -------------------------------------# 184 | # Assign each client with data indexes # 185 | # -------------------------------------# 186 | class_partitions = defaultdict(list) 187 | for i in range(num_users): 188 | c = [] 189 | for _ in range(classes_per_user): 190 | class_counts = [class_dict[i]['count'] for i in range(num_classes)] 191 | max_class_counts = np.where(np.array(class_counts) == max(class_counts))[0] 192 | c.append(np.random.choice(max_class_counts)) 193 | class_dict[c[-1]]['count'] -= 1 194 | class_partitions['class'].append(c) 195 | class_partitions['prob'].append([class_dict[i]['prob'].pop() for i in c]) 196 | return class_partitions 197 | 198 | def gen_data_split(dataset, num_users, class_partitions): 199 | """ 200 | divide data indexes for each client based on class_partition 201 | :param dataset: pytorch dataset object 
(train/val/test) 202 | :param num_users: number of clients 203 | :param class_partitions: proportion of classes per client 204 | :return: dictionary mapping client to its indexes 205 | """ 206 | num_classes, num_samples, data_labels_list = get_num_classes_samples(dataset) 207 | 208 | # -------------------------- # 209 | # Create class index mapping # 210 | # -------------------------- # 211 | data_class_idx = {i: np.where(data_labels_list == i)[0] for i in range(num_classes)} 212 | 213 | # --------- # 214 | # Shuffling # 215 | # --------- # 216 | for data_idx in data_class_idx.values(): 217 | random.shuffle(data_idx) 218 | 219 | # ------------------------------ # 220 | # Assigning samples to each user # 221 | # ------------------------------ # 222 | user_data_idx = {i: [] for i in range(num_users)} 223 | for usr_i in range(num_users): 224 | for c, p in zip(class_partitions['class'][usr_i], class_partitions['prob'][usr_i]): 225 | end_idx = int(num_samples[c] * p) 226 | user_data_idx[usr_i].extend(data_class_idx[c][:end_idx]) 227 | data_class_idx[c] = data_class_idx[c][end_idx:] 228 | 229 | return user_data_idx 230 | 231 | def gen_random_loaders(dataset, num_users, rand_set_all = None, classes_per_user=2): 232 | """ 233 | generates train/val/test loaders of each client 234 | :param data_name: name of dataset, choose from [cifar10, cifar100] 235 | :param data_path: root path for data dir 236 | :param num_users: number of clients 237 | :param bz: batch size 238 | :param classes_per_user: number of classes assigned to each client 239 | :return: train/val/test loaders of each client, list of pytorch dataloaders 240 | """ 241 | if rand_set_all is None: 242 | rand_set_all = gen_classes_per_node(dataset, num_users, classes_per_user) 243 | 244 | usr_subset_idx = gen_data_split(dataset, num_users, rand_set_all) 245 | 246 | #cls_counts = record_net_data_stats(dataset.targets, usr_subset_idx) 247 | 248 | return usr_subset_idx,rand_set_all -------------------------------------------------------------------------------- /utils/get_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from torchvision import datasets, transforms 5 | from utils.sampling import * 6 | from utils.dataset_utils import separate_data,read_record 7 | from utils.FEMNIST import FEMNIST 8 | from utils.ShakeSpeare import ShakeSpeare 9 | from utils import mydata 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | import os 13 | import json 14 | 15 | def get_dataset(args): 16 | 17 | file = os.path.join("data", args.dataset + "_" + str(args.num_users)) 18 | if args.iid: 19 | file += "_iid" 20 | else: 21 | file += "_noniidCase" + str(args.noniid_case) 22 | 23 | if args.noniid_case > 4: 24 | file += "_beta" + str(args.data_beta) 25 | 26 | file += ".json" 27 | # load dataset and split users 28 | if args.dataset == 'mnist': 29 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 30 | dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=trans_mnist) 31 | dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist) 32 | if args.generate_data: 33 | # sample users 34 | if args.iid: 35 | dict_users = mnist_iid(dataset_train, args.num_users) 36 | else: 37 | dict_users = mnist_noniid(dataset_train, args.num_users) 38 | else: 39 | dict_users = read_record(file) 40 | elif args.dataset == 'cifar10': 41 
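# ---------------------------------------------------------------------------
# A hedged usage sketch for the class-partition helpers completed above
# (dataset is a placeholder torchvision-style object; note the assertion that
# classes_per_user * num_users must be divisible by the number of classes):
#
#     usr_subset_idx, rand_set_all = gen_random_loaders(dataset, num_users=50, classes_per_user=2)
#
# And a worked example (hypothetical argument values) of the cache-file naming
# scheme assembled at the top of get_dataset above:
#     dataset='cifar10', num_users=100, iid=False, noniid_case=5, data_beta=0.5
#     -> data/cifar10_100_noniidCase5_beta0.5.json
# ---------------------------------------------------------------------------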
| 42 | trans_cifar10_train = transforms.Compose([transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 44 | trans_cifar10_val = transforms.Compose([transforms.ToTensor(), 45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 46 | 47 | dataset_train = datasets.CIFAR10('./data/cifar10', train=True, download=True, transform=trans_cifar10_train) 48 | dataset_test = datasets.CIFAR10('./data/cifar10', train=False, download=True, transform=trans_cifar10_val) 49 | if args.generate_data: 50 | if args.iid: 51 | dict_users = cifar_iid(dataset_train, args.num_users) 52 | elif args.noniid_case < 5: 53 | dict_users = cifar_noniid(dataset_train,args.num_users,args.noniid_case) 54 | else: 55 | dict_users = separate_data(dataset_train,args.num_users,args.num_classes,args.data_beta) 56 | else: 57 | dict_users = read_record(file) 58 | elif args.dataset == 'cifar100': 59 | trans_cifar100 = transforms.Compose( 60 | [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) 61 | dataset_train = mydata.CIFAR100_coarse('../data/cifar100_coarse', train=True, download=True, 62 | transform=trans_cifar100) 63 | dataset_test = mydata.CIFAR100_coarse('../data/cifar100_coarse', train=False, download=True, 64 | transform=trans_cifar100) 65 | if args.generate_data: 66 | if args.iid: 67 | dict_users = cifar_iid(dataset_train, args.num_users) 68 | elif args.noniid_case < 5: 69 | dict_users = cifar_noniid(dataset_train, args.num_users, args.noniid_case) 70 | else: 71 | dict_users = separate_data(dataset_train, args.num_users, args.num_classes, args.data_beta) 72 | else: 73 | dict_users = read_record(file) 74 | elif args.dataset == 'fashion-mnist': 75 | trans = transforms.Compose([transforms.ToTensor()]) 76 | dataset_train = datasets.FashionMNIST('./data/fashion-mnist/', train=True, download=True, transform=trans) 77 | dataset_test = datasets.FashionMNIST('./data/fashion-mnist/', train=False, download=True, transform=trans) 78 | if args.generate_data: 79 | if args.iid: 80 | dict_users = fashion_mnist_iid(dataset_train, args.num_users) 81 | else: 82 | dict_users = fashion_mnist_noniid(dataset_train, args.num_users, case=args.noniid_case) 83 | else: 84 | dict_users = read_record(file) 85 | elif args.dataset == 'femnist': 86 | dataset_train = FEMNIST(True) 87 | dataset_test = FEMNIST(False) 88 | dict_users = dataset_train.get_client_dic() 89 | args.num_users = len(dict_users) 90 | elif args.dataset == 'ShakeSpeare': 91 | dataset_train = ShakeSpeare(True) 92 | dataset_test = ShakeSpeare(False) 93 | dict_users = dataset_train.get_client_dic() 94 | args.num_users = len(dict_users) 95 | print(args.num_users) 96 | else: 97 | exit('Error: unrecognized dataset') 98 | 99 | if args.generate_data: 100 | with open(file,'w') as f: 101 | dataJson = {"dataset":args.dataset,"num_users":args.num_users,"iid":args.iid,"noniid_case":args.noniid_case,"data_beta":args.data_beta,"train_data":dict_users} 102 | json.dump(dataJson,f) 103 | 104 | return dataset_train, dataset_test, dict_users -------------------------------------------------------------------------------- /utils/language_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for language models.""" 2 | 3 | import re 4 | import numpy as np 5 | import json 6 | 7 | # ------------------------ 8 | # utils for shakespeare dataset 9 | 10 | ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}" 11 | NUM_LETTERS = 
len(ALL_LETTERS) 12 | 13 | 14 | # print(NUM_LETTERS) 15 | 16 | def _one_hot(index, size): 17 | '''returns one-hot vector with given size and value 1 at given index 18 | ''' 19 | vec = [0 for _ in range(size)] 20 | vec[int(index)] = 1 21 | return vec 22 | 23 | 24 | def letter_to_vec(letter): 25 | '''returns one-hot representation of given letter 26 | ''' 27 | index = ALL_LETTERS.find(letter) 28 | return index 29 | 30 | 31 | def word_to_indices(word): 32 | '''returns a list of character indices 33 | Args: 34 | word: string 35 | 36 | Return: 37 | indices: int list with length len(word) 38 | ''' 39 | indices = [] 40 | for c in word: 41 | indices.append(ALL_LETTERS.find(c)) 42 | return indices -------------------------------------------------------------------------------- /utils/model_config.py: -------------------------------------------------------------------------------- 1 | CONFIGS_ = { 2 | # input_channel, n_class, hidden_dim, latent_dim 3 | 'cifar10': ([6, 'R', 'M', 16, 'R', 'M', 'F'], 3, 10, 400, 120, 84, 0), 4 | 'cifar100': ([6,'R', 'M', 16, 'R','M', 'F'], 3, 20, 400, 120, 84, 0), 5 | 'femnist': ([16, 'M', 'R', 32, 'M', 'R', 'F'], 1, 62, 512, 256, 0), 6 | } 7 | 8 | # temporary roundabout to evaluate sensitivity of the generator 9 | GENERATORCONFIGS = { 10 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 11 | # 'cifar10': (512, 84, 3, 10, 100), # cnn 12 | 'cifar10': (512, 512, 3, 10, 100), # resnet 13 | # 'cifar10': (512, 1280, 3, 10, 100), # mobilenet 14 | 'cifar100': (512, 84, 3, 20, 100), 15 | 'femnist': (512, 256, 1, 62, 100), 16 | } 17 | 18 | CNN_GENERATORCONFIGS = { 19 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 20 | 'cifar10': (512, 84, 3, 10, 100), # cnn 21 | # 'cifar10': (512, 512, 3, 10, 100), # resnet 22 | # 'cifar10': (512, 1280, 3, 10, 100), # mobilenet 23 | 'cifar100': (512, 84, 3, 20, 100), 24 | 'femnist': (512, 256, 1, 62, 100), 25 | } 26 | 27 | RESNET_GENERATORCONFIGS = { 28 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 29 | # 'cifar10': (512, 84, 3, 10, 100), # cnn 30 | 'cifar10': (512, 512, 3, 10, 100), # resnet 31 | # 'cifar10': (512, 1280, 3, 10, 100), # mobilenet 32 | # 'cifar100': (512, 84, 3, 20, 100), 33 | 'cifar100': (512, 512, 3, 20, 100), 34 | 'femnist': (512, 256, 1, 62, 100), 35 | } 36 | 37 | RESNET20_GENERATORCONFIGS = { 38 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 39 | # 'cifar10': (512, 84, 3, 10, 100), # cnn 40 | # 'cifar10': (512, 512, 3, 10, 100), # resnet 41 | 'cifar10': (512, 256, 3, 10, 100), # resnet 42 | # 'cifar10': (512, 1280, 3, 10, 100), # mobilenet 43 | # 'cifar100': (512, 84, 3, 20, 100), 44 | 'cifar100': (512, 512, 3, 20, 100), 45 | 'femnist': (512, 256, 1, 62, 100), 46 | } 47 | 48 | VGG_GENERATORCONFIGS = { 49 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 50 | # 'cifar10': (512, 84, 3, 10, 100), # cnn 51 | 'cifar10': (512, 4096, 3, 10, 100), # resnet 52 | # 'cifar10': (512, 1280, 3, 10, 100), # mobilenet 53 | # 'cifar100': (512, 84, 3, 20, 100), 54 | 'cifar100': (512, 4096, 3, 20, 100), 55 | 'femnist': (512, 4096, 1, 62, 100), 56 | } 57 | 58 | FedGenRUNCONFIGS = { 59 | 'femnist': 60 | { 61 | 'ensemble_lr': 3e-4, 62 | 'ensemble_batch_size': 128, 63 | 'ensemble_epochs': 50, 64 | 'num_pretrain_iters': 20, 65 | 'ensemble_alpha': 1, # teacher loss (server side) 66 | 'ensemble_beta': 0, # adversarial student loss 67 | 'ensemble_eta': 1, # diversity loss 68 | 'generative_alpha': 10, # used to regulate 
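# ---------------------------------------------------------------------------
# A quick worked example for the Shakespeare text helpers in
# utils/language_utils.py above (results follow directly from ALL_LETTERS and
# _one_hot as defined there; note that letter_to_vec, despite its name and
# docstring, returns the character's index in ALL_LETTERS, not a one-hot
# vector):
#     _one_hot(2, 5)         -> [0, 0, 1, 0, 0]
#     word_to_indices("the") -> one ALL_LETTERS index per character
# ---------------------------------------------------------------------------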
user training 69 | 'generative_beta': 10, # used to regulate user training 70 | 'weight_decay': 1e-2 71 | }, 72 | 'cifar10': 73 | { 74 | 'ensemble_lr': 3e-4, 75 | 'ensemble_batch_size': 128, 76 | 'ensemble_epochs': 50, 77 | 'num_pretrain_iters': 20, 78 | 'ensemble_alpha': 1, # teacher loss (server side) 79 | 'ensemble_beta': 0, # adversarial student loss 80 | 'ensemble_eta': 1, # diversity loss 81 | 'generative_alpha': 0.2, 82 | 'generative_beta': 0.2, 83 | 'weight_decay': 1e-2 84 | }, 85 | 'cifar100': 86 | { 87 | 'ensemble_lr': 1e-4, 88 | 'ensemble_batch_size': 128, 89 | 'ensemble_epochs': 50, 90 | 'num_pretrain_iters': 20, 91 | 'ensemble_alpha': 1, # teacher loss (server side) 92 | 'ensemble_beta': 0, # adversarial student loss 93 | 'ensemble_eta': 1, # diversity loss 94 | 'generative_alpha': 10, 95 | 'generative_beta': 1, 96 | 'weight_decay': 1e-2 97 | }, 98 | 99 | } 100 | 101 | -------------------------------------------------------------------------------- /utils/mydata.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import torch 3 | # import torchvision.transforms as transforms 4 | # import torchvision.datasets as datasets 5 | # 6 | # 7 | # MNIST_MEAN = (0.1307,) 8 | # MNIST_STD = (0.3081,) 9 | # 10 | # CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) 11 | # CIFAR10_STD = (0.2023, 0.1994, 0.2010) 12 | # 13 | # IMAGENET_MEAN = (0.485, 0.456, 0.406) 14 | # IMAGENET_STD = (0.229, 0.224, 0.225) 15 | # 16 | # def get_dataset(dset_name, batch_size=128, n_worker=4, data_root='../../dataset'): 17 | # 18 | # print('=> Preparing data..') 19 | # kwargs = {'num_workers': n_worker, 'pin_memory': True} if torch.cuda.is_available() else {} 20 | # 21 | # if dset_name == 'mnist': 22 | # #normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) 23 | # transform_train = transforms.Compose([ 24 | # transforms.ToTensor(), 25 | # #normalize, 26 | # ]) 27 | # trainset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform_train) 28 | # train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, **kwargs) 29 | # 30 | # transform_test = transforms.Compose([ 31 | # transforms.ToTensor(), 32 | # #normalize, 33 | # ]) 34 | # testset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform_test) 35 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 36 | # 37 | # n_class = 10 38 | # 39 | # elif dset_name == 'cifar10': 40 | # 41 | # #normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)) 42 | # transform_train = transforms.Compose([ 43 | # transforms.RandomCrop(32, padding=4), 44 | # transforms.RandomHorizontalFlip(), 45 | # transforms.ToTensor(), 46 | # #normalize, 47 | # ]) 48 | # trainset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train) 49 | # train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, **kwargs) 50 | # 51 | # transform_test = transforms.Compose([ 52 | # transforms.ToTensor(), 53 | # #normalize, 54 | # ]) 55 | # testset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test) 56 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 57 | # 58 | # n_class = 10 59 | # 60 | # elif dset_name == 'imagenet': 61 | # 62 | # traindir = os.path.join(data_root, 'imagenet/train') 63 | # valdir = 
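# ---------------------------------------------------------------------------
# A hedged usage sketch for the FedGen run configs defined in
# utils/model_config.py above (key names are taken verbatim from that dict;
# the surrounding trainer code is hypothetical):
#     from utils.model_config import FedGenRUNCONFIGS
#     cfg = FedGenRUNCONFIGS['cifar10']
#     gen_lr, g_alpha, g_beta = cfg['ensemble_lr'], cfg['generative_alpha'], cfg['generative_beta']
#
# (The CIFAR100_coarse class defined further below loads 'coarse_labels', i.e.
# the 20 superclasses, which matches n_class = 20 in the cifar100 configs.)
# ---------------------------------------------------------------------------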
os.path.join(data_root, 'imagenet/val') 64 | # 65 | # #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 66 | # transform_train = transforms.Compose([ 67 | # transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 68 | # transforms.RandomHorizontalFlip(), 69 | # transforms.ToTensor(), 70 | # #normalize, 71 | # ]) 72 | # trainset = datasets.ImageFolder(traindir, transform_train) 73 | # train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, **kwargs) 74 | # 75 | # transform_test = transforms.Compose([ 76 | # transforms.Resize(256), 77 | # transforms.CenterCrop(224), 78 | # transforms.ToTensor(), 79 | # #normalize, 80 | # ]) 81 | # testset = datasets.ImageFolder(valdir, transform_test) 82 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 83 | # 84 | # n_class = 1000 85 | # 86 | # else: 87 | # raise NotImplementedError 88 | # 89 | # return train_loader, test_loader, n_class 90 | # 91 | # 92 | # def get_testloader(dset_name, batch_size=128, n_worker=4, data_root='../../dataset', subset_idx=None): 93 | # print('=> Preparing testing data..') 94 | # kwargs = {'num_workers': n_worker, 'pin_memory': True} if torch.cuda.is_available() else {} 95 | # if dset_name == 'mnist': 96 | # transform_test = transforms.Compose([ 97 | # transforms.ToTensor(), 98 | # ]) 99 | # testset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform_test) 100 | # if subset_idx is not None: 101 | # testset = torch.utils.data.Subset(testset, subset_idx) 102 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 103 | # n_class = 10 104 | # elif dset_name == 'cifar10': 105 | # transform_test = transforms.Compose([ 106 | # transforms.ToTensor(), 107 | # ]) 108 | # testset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test) 109 | # if subset_idx is not None: 110 | # testset = torch.utils.data.Subset(testset, subset_idx) 111 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 112 | # n_class = 10 113 | # elif dset_name == 'imagenet': 114 | # valdir = os.path.join(data_root, 'imagenet/val') 115 | # transform_test = transforms.Compose([ 116 | # transforms.Resize(256), 117 | # transforms.CenterCrop(224), 118 | # transforms.ToTensor(), 119 | # ]) 120 | # testset = datasets.ImageFolder(valdir, transform_test) 121 | # if subset_idx is not None: 122 | # testset = torch.utils.data.Subset(testset, subset_idx) 123 | # test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, **kwargs) 124 | # n_class = 1000 125 | # else: 126 | # raise NotImplementedError 127 | # return test_loader, n_class 128 | from PIL import Image 129 | import os 130 | import os.path 131 | import numpy as np 132 | import pickle 133 | from typing import Any, Callable, Optional, Tuple 134 | 135 | from torchvision.datasets.vision import VisionDataset 136 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 137 | from torchvision.datasets.cifar import CIFAR10 138 | 139 | class CIFAR100_coarse(CIFAR10): 140 | """`CIFAR10 `_ Dataset. 141 | Args: 142 | root (string): Root directory of dataset where directory 143 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 
144 | train (bool, optional): If True, creates dataset from training set, otherwise 145 | creates from test set. 146 | transform (callable, optional): A function/transform that takes in an PIL image 147 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 148 | target_transform (callable, optional): A function/transform that takes in the 149 | target and transforms it. 150 | download (bool, optional): If true, downloads the dataset from the internet and 151 | puts it in root directory. If dataset is already downloaded, it is not 152 | downloaded again. 153 | """ 154 | base_folder = 'cifar-100-python' 155 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 156 | filename = "cifar-100-python.tar.gz" 157 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 158 | train_list = [ 159 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 160 | ] 161 | 162 | test_list = [ 163 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 164 | ] 165 | meta = { 166 | 'filename': 'meta', 167 | 'key': 'coarse_label_names', 168 | 'md5': '7973b15100ade9c7d40fb424638fde48', 169 | } 170 | 171 | def __init__( 172 | self, 173 | root: str, 174 | train: bool = True, 175 | transform = None, 176 | target_transform = None, 177 | download: bool = False, 178 | ) -> None: 179 | 180 | super(CIFAR10, self).__init__(root, transform=transform, 181 | target_transform=target_transform) 182 | 183 | self.train = train # training set or test set 184 | 185 | if download: 186 | self.download() 187 | 188 | if not self._check_integrity(): 189 | raise RuntimeError('Dataset not found or corrupted.' + 190 | ' You can use download=True to download it') 191 | 192 | if self.train: 193 | downloaded_list = self.train_list 194 | else: 195 | downloaded_list = self.test_list 196 | 197 | self.data: Any = [] 198 | self.targets = [] 199 | 200 | # now load the picked numpy arrays 201 | for file_name, checksum in downloaded_list: 202 | file_path = os.path.join(self.root, self.base_folder, file_name) 203 | with open(file_path, 'rb') as f: 204 | entry = pickle.load(f, encoding='latin1') 205 | self.data.append(entry['data']) 206 | if 'labels' in entry: 207 | self.targets.extend(entry['labels']) 208 | else: 209 | self.targets.extend(entry['coarse_labels']) 210 | 211 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 212 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 213 | 214 | self._load_meta() -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | # federated arguments 10 | parser.add_argument('--epochs', type=int, default=2000, help="rounds of training") 11 | parser.add_argument('--num_users', type=int, default=100, help="number of users: K") 12 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C") 13 | parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E") 14 | parser.add_argument('--local_bs', type=int, default=50, help="local batch size: B") 15 | parser.add_argument('--bs', type=int, default=128, help="test batch size") 16 | parser.add_argument('--optimizer', type=str, default='sgd', help='the optimizer') 17 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 18 | parser.add_argument('--momentum', type=float, 
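# ---------------------------------------------------------------------------
# A hedged sketch of how the federated arguments above are typically consumed
# by a FedAvg-style training loop (the exact expression lives in main_fed.py,
# which is not shown here):
#     m = max(int(args.frac * args.num_users), 1)  # clients sampled per round
#     idxs_users = np.random.choice(range(args.num_users), m, replace=False)
# ---------------------------------------------------------------------------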
default=0.5, help="SGD momentum (default: 0.5)") 19 |     parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample") 20 |     parser.add_argument("--algorithm", type=str, default="FedAvg") 21 |  22 |     # model arguments 23 |     parser.add_argument('--model', type=str, default='resnet20', help='model name') 24 |     parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel') 25 |     parser.add_argument('--kernel_sizes', type=str, default='3,4,5', 26 |                         help='comma-separated kernel size to use for convolution') 27 |     parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None") 28 |     parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets") 29 |     parser.add_argument('--max_pool', type=str, default='True', 30 |                         help="whether to use max pooling rather than strided convolutions") 31 |     parser.add_argument('--use_project_head', type=int, default=0) 32 |     parser.add_argument('--out_dim', type=int, default=256, help='the output dimension for the projection layer') 33 |  34 |     # other arguments 35 |     parser.add_argument('--dataset', type=str, default='cifar10', help="name of dataset") 36 |     parser.add_argument('--generate_data', type=int, default=1, help="whether to generate a new dataset split") 37 |     parser.add_argument('--iid', type=int, default=1, help='whether the data split is i.i.d. or not') 38 |     parser.add_argument('--noniid_case', type=int, default=0, help="non-i.i.d. case (1-4: shard-based; >4: Dirichlet)") 39 |     parser.add_argument('--data_beta', type=float, default=0.5, 40 |                         help='concentration parameter of the Dirichlet distribution used for data partitioning') 41 |     parser.add_argument('--num_classes', type=int, default=10, help="number of classes") 42 |     parser.add_argument('--num_channels', type=int, default=3, help="number of image channels") 43 |     parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") 44 |     parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping') 45 |     parser.add_argument('--verbose', action='store_true', help='verbose print') 46 |     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 47 |  48 |     parser.add_argument('--prox_alpha', type=float, default=0.01, help='the proximal-term hyperparameter for FedProx') 49 |     parser.add_argument('--temperature', type=float, default=0.5, help='the temperature parameter for contrastive loss') 50 |     parser.add_argument('--model_buffer_size', type=int, default=1, 51 |                         help='how many previous models to store for the contrastive loss') 52 |     parser.add_argument('--pool_option', type=str, default='FIFO', help='FIFO or BOX') 53 |     parser.add_argument('--sim_type', type=str, default='L1', help='Cluster Sampling: cosine or L1 or L2') 54 |  55 |     # FedMut 56 |     parser.add_argument('--radius', type=float, default=5.0) 57 |     parser.add_argument('--min_radius', type=float, default=0.1) 58 |     parser.add_argument('--mut_acc_rate', type=float, default=0.3) 59 |     parser.add_argument('--mut_bound', type=int, default=100) 60 |  61 |     # FedMR arguments 62 |     parser.add_argument("--first_stage_bound", type=int, default=0) 63 |     parser.add_argument("--fedmr_partition", type=float, default=0.0) 64 |  65 |  66 |  67 |     args = parser.parse_args() 68 |     return args 69 | -------------------------------------------------------------------------------- /utils/sam_minimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 |  4 | class ASAM: 5 |     def 
__init__(self, optimizer, model, rho=0.5, eta=0.01): 6 | self.optimizer = optimizer 7 | self.model = model 8 | self.rho = rho 9 | self.eta = eta 10 | self.state = defaultdict(dict) 11 | 12 | @torch.no_grad() 13 | def ascent_step(self): 14 | wgrads = [] 15 | for n, p in self.model.named_parameters(): 16 | if p.grad is None: 17 | continue 18 | t_w = self.state[p].get("eps") 19 | if t_w is None: 20 | t_w = torch.clone(p).detach() 21 | self.state[p]["eps"] = t_w 22 | if 'weight' in n: 23 | t_w[...] = p[...] 24 | t_w.abs_().add_(self.eta) 25 | p.grad.mul_(t_w) 26 | wgrads.append(torch.norm(p.grad, p=2)) 27 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16 28 | for n, p in self.model.named_parameters(): 29 | if p.grad is None: 30 | continue 31 | t_w = self.state[p].get("eps") 32 | if 'weight' in n: 33 | p.grad.mul_(t_w) 34 | eps = t_w 35 | eps[...] = p.grad[...] 36 | eps.mul_(self.rho / wgrad_norm) 37 | p.add_(eps) 38 | self.optimizer.zero_grad() 39 | 40 | @torch.no_grad() 41 | def descent_step(self): 42 | for n, p in self.model.named_parameters(): 43 | if p.grad is None: 44 | continue 45 | p.sub_(self.state[p]["eps"]) 46 | self.optimizer.step() 47 | self.optimizer.zero_grad() 48 | 49 | class SAM: 50 | def __init__(self, optimizer, model, rho=0.5, eta=0.01): 51 | self.optimizer = optimizer 52 | self.model = model 53 | self.rho = rho 54 | self.eta = eta 55 | self.state = defaultdict(dict) 56 | 57 | @torch.no_grad() 58 | def ascent_step(self): 59 | grads = [] 60 | for n, p in self.model.named_parameters(): 61 | if p.grad is None: 62 | continue 63 | grads.append(torch.norm(p.grad, p=2)) 64 | grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16 65 | for n, p in self.model.named_parameters(): 66 | if p.grad is None: 67 | continue 68 | eps = self.state[p].get("eps") 69 | if eps is None: 70 | eps = torch.clone(p).detach() 71 | self.state[p]["eps"] = eps 72 | eps[...] = p.grad[...] 
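# ---------------------------------------------------------------------------
# A minimal training-step sketch for the ASAM/SAM wrappers in this file
# (model, loader and criterion are placeholders; the two-pass pattern is what
# ascent_step/descent_step implement):
#     minimizer = ASAM(optimizer, model, rho=0.5, eta=0.01)
#     for x, y in loader:
#         criterion(model(x), y).backward()
#         minimizer.ascent_step()            # perturb weights along the (adaptive) gradient
#         criterion(model(x), y).backward()  # gradients at the perturbed point
#         minimizer.descent_step()           # restore weights, then optimizer.step()
# ---------------------------------------------------------------------------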
73 | eps.mul_(self.rho / grad_norm) 74 | p.add_(eps) 75 | self.optimizer.zero_grad() 76 | 77 | @torch.no_grad() 78 | def descent_step(self): 79 | for n, p in self.model.named_parameters(): 80 | if p.grad is None: 81 | continue 82 | p.sub_(self.state[p]["eps"]) 83 | self.optimizer.step() 84 | self.optimizer.zero_grad() -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import random 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | 9 | 10 | def mnist_iid(dataset, num_users): 11 | return iid(dataset, num_users) 12 | 13 | 14 | def mnist_noniid(dataset, num_users, case=1): 15 | num_shards, num_imgs = 100, 600 16 | return non_iid(dataset, num_users, num_shards, num_imgs, case) 17 | 18 | 19 | def fashion_mnist_iid(dataset, num_users): 20 | return iid(dataset, num_users) 21 | 22 | 23 | def fashion_mnist_noniid(dataset, num_users, case=1): 24 | num_shards, num_imgs = 100, 600 25 | return non_iid(dataset, num_users, num_shards, num_imgs, case) 26 | 27 | 28 | def cifar_iid(dataset, num_users): 29 | return iid(dataset, num_users) 30 | 31 | 32 | def cifar_noniid(dataset, num_users, case=1): 33 | num_shards, num_imgs = 100, 500 34 | return non_iid(dataset, num_users, num_shards, num_imgs, case) 35 | 36 | 37 | def cifar100_iid(dataset, num_users): 38 | return iid(dataset, num_users) 39 | 40 | 41 | def cifar100_noniid(dataset, num_users, case=1): 42 | num_shards, num_imgs = 100, 500 43 | return non_iid(dataset, num_users, num_shards, num_imgs, case) 44 | 45 | 46 | def svhn_iid(dataset, num_users): 47 | return iid(dataset, num_users) 48 | 49 | 50 | def svhn_noniid(dataset, num_users, case=1): 51 | num_shards, num_imgs = 100, 700 52 | return non_iid(dataset, num_users, num_shards, num_imgs, case) 53 | 54 | 55 | def iid(dataset, num_users): 56 | num_items = int(len(dataset) / num_users) 57 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 58 | for i in range(num_users): 59 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 60 | all_idxs = list(set(all_idxs) - dict_users[i]) 61 | 62 | for i in range(num_users): 63 | dict_users[i] = np.array(list(dict_users[i])).tolist() 64 | return dict_users 65 | 66 | 67 | def non_iid(dataset, num_users, num_shards, num_imgs, case=1): 68 | if case == 1: 69 | return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs) 70 | elif case == 2: 71 | return noniid_label_2(dataset, num_users, int(num_shards * 2), int(num_imgs / 2)) 72 | elif case == 3: 73 | return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=0.8) 74 | elif case == 4: 75 | return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=0.5) 76 | else: 77 | exit('Error: unrecognized noniid case') 78 | 79 | 80 | def noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=1): 81 | idx_shard = [i for i in range(num_shards)] 82 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 83 | idxs = np.arange(num_shards * num_imgs) 84 | labels = dataset.targets 85 | 86 | # sort labels 87 | idxs_labels = np.vstack((idxs, labels)) 88 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 89 | idxs = idxs_labels[0, :] 90 | 91 | for i in range(num_users): 92 | rand_set = set(np.random.choice(idx_shard, 1, replace=False)) 93 | idx_shard = list(set(idx_shard) - 
rand_set) 94 | for rand in rand_set: 95 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:int((rand + ratio) * num_imgs)]), 96 | axis=0) 97 | random.shuffle(dict_users[i]) 98 | 99 | if ratio < 1: 100 | rest_idxs = np.array([], dtype='int64') 101 | idx_shard = [i for i in range(num_shards)] 102 | for i in idx_shard: 103 | rest_idxs = np.concatenate((rest_idxs, idxs[int((i + ratio) * num_imgs):(i + 1) * num_imgs]), axis=0) 104 | num_items = int(len(dataset) / num_users * (1 - ratio)) 105 | for i in range(num_users): 106 | rest_to_add = set(np.random.choice(rest_idxs, num_items, replace=False)) 107 | dict_users[i] = np.concatenate((dict_users[i], list(rest_to_add)), axis=0) 108 | rest_idxs = list(set(rest_idxs) - rest_to_add) 109 | random.shuffle(dict_users[i]) 110 | 111 | for i in range(num_users): 112 | dict_users[i] = dict_users[i].tolist() 113 | 114 | return dict_users 115 | 116 | 117 | def noniid_label_2(dataset, num_users, num_shards, num_imgs): 118 | idx_shard = [i for i in range(num_shards)] 119 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 120 | idxs = np.arange(num_shards * num_imgs) 121 | labels = dataset.targets 122 | 123 | # sort labels 124 | idxs_labels = np.vstack((idxs, labels)) 125 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 126 | idxs = idxs_labels[0, :] 127 | 128 | for i in range(num_users): 129 | len_idx_shard = len(idx_shard) 130 | rand1 = np.random.choice(idx_shard[0:int(len_idx_shard / 2)], 1, replace=False)[0] 131 | rand2 = np.random.choice(idx_shard[int(len_idx_shard / 2):len_idx_shard], 1, replace=False)[0] 132 | idx_shard = list(set(idx_shard) - set([rand1, rand2])) 133 | for rand in [rand1, rand2]: 134 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:int((rand + 1) * num_imgs)]), axis=0) 135 | random.shuffle(dict_users[i]) 136 | return dict_users 137 | 138 | 139 | if __name__ == '__main__': 140 | 141 | trans = transforms.Compose([transforms.ToTensor()]) 142 | dataset_train = datasets.SVHN('../data/svhn/', split='train', download=True, transform=trans) 143 | # trans = transforms.Compose([transforms.ToTensor()]) 144 | # dataset_train = datasets.FashionMNIST('../data/fashion-mnist/', train=True, download=True, transform=trans) 145 | # trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 146 | # dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist) 147 | # trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 148 | # dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar) 149 | num = 100 150 | d = svhn_noniid(dataset_train, num) 151 | for user_idx in d: 152 | print(user_idx) 153 | print([dataset_train[img_idx][1] for img_idx in d[user_idx]]) 154 | -------------------------------------------------------------------------------- /utils/set_seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_random_seed(seed): 7 | """ 8 | set random seed 9 | """ 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 
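# ---------------------------------------------------------------------------
# (Hedged note on set_random_seed above: it seeds random, numpy and torch, but
# leaves the cuDNN flags untouched; for bitwise-reproducible GPU runs one
# would additionally set torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False.)
# ---------------------------------------------------------------------------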
#!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import datetime 5 | import os 6 | 7 | def save_result(data, ylabel, args): 8 | data = {'base' :data} 9 | 10 | path = './output/{}'.format(args.noniid_case) 11 | 12 | if args.noniid_case != 5: 13 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}_{}.txt'.format(args.dataset, args.algorithm, args.model, 14 | ylabel, args.epochs, args.lr, datetime.datetime.now().strftime( 15 | "%Y_%m_%d_%H_%M_%S"),args.frac, args.num_users) 16 | else: 17 | path += '/{}'.format(args.data_beta) 18 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}_{}.txt'.format(args.dataset, args.algorithm,args.model, 19 | ylabel, args.epochs, args.lr, 20 | datetime.datetime.now().strftime( 21 | "%Y_%m_%d_%H_%M_%S"),args.frac, args.num_users) 22 | 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | with open(os.path.join(path,file), 'a') as f: 27 | for label in data: 28 | f.write(label) 29 | f.write(' ') 30 | for item in data[label]: 31 | item1 = str(item) 32 | f.write(item1) 33 | f.write(' ') 34 | f.write('\n') 35 | print('save finished') 36 | f.close() 37 | 38 | def save_fedmut_result(data, ylabel, args): 39 | data = {'base' :data} 40 | 41 | path = './output/{}'.format(args.noniid_case) 42 | 43 | if args.noniid_case != 5: 44 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}_radius_{}_accrate_{}_bound_{}.txt'.format(args.dataset, args.algorithm, args.model, 45 | ylabel, args.epochs, args.lr, datetime.datetime.now().strftime( 46 | "%Y_%m_%d_%H_%M_%S"),args.frac,args.radius,args.mut_acc_rate,args.mut_bound) 47 | else: 48 | path += '/{}'.format(args.data_beta) 49 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}_radius_{}_accrate_{}_bound_{}.txt'.format(args.dataset, args.algorithm,args.model, 50 | ylabel, args.epochs, args.lr, 51 | datetime.datetime.now().strftime( 52 | "%Y_%m_%d_%H_%M_%S"),args.frac,args.radius,args.mut_acc_rate,args.mut_bound) 53 | 54 | if not os.path.exists(path): 55 | os.makedirs(path) 56 | 57 | with open(os.path.join(path,file), 'a') as f: 58 | for label in data: 59 | f.write(label) 60 | f.write(' ') 61 | for item in data[label]: 62 | item1 = str(item) 63 | f.write(item1) 64 | f.write(' ') 65 | f.write('\n') 66 | print('save finished') 67 | f.close() 68 | 69 | 70 | def save_model(data, ylabel, args): 71 | 72 | path = './output/{}'.format(args.noniid_case) 73 | 74 | if args.noniid_case != 5: 75 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}.txt'.format(args.dataset, args.algorithm, args.model, 76 | ylabel, args.epochs, args.lr, datetime.datetime.now().strftime( 77 | "%Y_%m_%d_%H_%M_%S"),args.frac) 78 | else: 79 | path += '/{}'.format(args.data_beta) 80 | file = '{}_{}_{}_{}_{}_lr_{}_{}_frac_{}.txt'.format(args.dataset, args.algorithm,args.model, 81 | ylabel, args.epochs, args.lr, 82 | datetime.datetime.now().strftime( 83 | "%Y_%m_%d_%H_%M_%S"),args.frac) 84 | 85 | if not os.path.exists(path): 86 | os.makedirs(path) 87 | 88 | # with open(os.path.join(path,file), 'a') as f: 89 | # for label in data: 90 | # f.write(label) 91 | # f.write(' ') 92 | # for item in data[label]: 93 | # item1 = str(item) 94 | # f.write(item1) 95 | # f.write(' ') 96 | # f.write('\n') 97 | torch.save(data, os.path.join(path,file)) 98 | print('save finished') 99 | # f.close() 100 | --------------------------------------------------------------------------------
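# ---------------------------------------------------------------------------
# A hedged usage sketch for the result helpers above (acc_history, args and
# net_glob are placeholders; outputs land under ./output/<noniid_case>/ or,
# for case 5, under ./output/<noniid_case>/<data_beta>/ with a timestamped
# filename):
#     save_result(acc_history, 'test_acc', args)        # plain-text curve log
#     save_model(net_glob.state_dict(), 'model', args)  # torch.save checkpoint
# ---------------------------------------------------------------------------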