├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── model_config.cpython-38.pyc │ ├── model_config.cpython-39.pyc │ ├── model_utils.cpython-38.pyc │ ├── model_utils.cpython-39.pyc │ ├── plot_utils.cpython-38.pyc │ └── plot_utils.cpython-39.pyc ├── model_config.py ├── model_config-base.py ├── plot_utils.py └── model_utils.py ├── 组会汇报FEDCL.pptx ├── requirements.txt ├── FLAlgorithms ├── users │ ├── __pycache__ │ │ ├── useravg.cpython-38.pyc │ │ ├── useravg.cpython-39.pyc │ │ ├── userbase.cpython-38.pyc │ │ ├── userbase.cpython-39.pyc │ │ ├── userFedProx.cpython-38.pyc │ │ ├── userFedProx.cpython-39.pyc │ │ ├── userpFedGen.cpython-38.pyc │ │ ├── userpFedGen.cpython-39.pyc │ │ ├── userFedDistill.cpython-38.pyc │ │ └── userFedDistill.cpython-39.pyc │ ├── useravg.py │ ├── userFedProx.py │ ├── userGen.py │ ├── userpFedGen.py │ ├── userFedDistill.py │ ├── userpFedEnsemble.py │ ├── userbase.py │ └── userpFedCL.py ├── servers │ ├── __pycache__ │ │ ├── serveravg.cpython-38.pyc │ │ ├── serveravg.cpython-39.pyc │ │ ├── serverbase.cpython-38.pyc │ │ ├── serverbase.cpython-39.pyc │ │ ├── serverFedProx.cpython-38.pyc │ │ ├── serverFedProx.cpython-39.pyc │ │ ├── serverpFedGen.cpython-38.pyc │ │ ├── serverpFedGen.cpython-39.pyc │ │ ├── serverFedDistill.cpython-38.pyc │ │ ├── serverFedDistill.cpython-39.pyc │ │ ├── serverpFedEnsemble.cpython-38.pyc │ │ └── serverpFedEnsemble.cpython-39.pyc │ ├── serverFedProx.py │ ├── serverpFedEnsemble.py │ ├── serveravg.py │ ├── serverFedDistill.py │ ├── serverbase.py │ ├── serverpFedGen.py │ └── serverpFedCL.py ├── trainmodel │ ├── __pycache__ │ │ ├── models.cpython-38.pyc │ │ ├── models.cpython-39.pyc │ │ ├── generator.cpython-38.pyc │ │ └── generator.cpython-39.pyc │ ├── models.py │ └── generator.py ├── curriculum │ ├── __pycache__ │ │ ├── cl_score.cpython-38.pyc │ │ └── cl_score.cpython-39.pyc │ └── cl_score.py └── optimizers │ ├── __pycache__ │ ├── fedoptimizer.cpython-38.pyc │ └── fedoptimizer.cpython-39.pyc │ ├── fedoptimizer.py │ └── loss.py ├── data ├── CelebA │ ├── README.md │ └── generate_niid_agg.py ├── Mnist │ └── generate_niid_dirichlet.py └── EMnist │ └── generate_niid_dirichlet.py ├── main_plot.py ├── main.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /组会汇报FEDCL.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/组会汇报FEDCL.pptx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | Pillow 4 | torchvision 5 | matplotlib 6 | tqdm 7 | h5py 8 | sklearn 9 | seaborn 10 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/model_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_config.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/plot_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/plot_utils.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/useravg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/useravg.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/useravg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/useravg.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userbase.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userbase.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userbase.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userbase.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serveravg.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serveravg.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serveravg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serveravg.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedProx.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedProx.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedProx.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedProx.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userpFedGen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userpFedGen.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userpFedGen.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userpFedGen.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/curriculum/__pycache__/cl_score.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/curriculum/__pycache__/cl_score.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/curriculum/__pycache__/cl_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/curriculum/__pycache__/cl_score.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverbase.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverbase.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverbase.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverbase.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/generator.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/generator.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedDistill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedDistill.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedDistill.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedDistill.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedProx.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedProx.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedProx.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedProx.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-39.pyc -------------------------------------------------------------------------------- /data/CelebA/README.md: -------------------------------------------------------------------------------- 1 | #### To generated CelebA dataset 2 | ##### Step 1. 3 | follow the [LEAF instructions](https://github.com/TalwalkarLab/leaf/tree/master/data/celeba) to download the raw celeb data, and generate `train` and `test` subfolders. 4 | 5 | #### Step 2. 6 | change [`LOAD_PATH`](https://github.com/zhuangdizhu/FedGen/blob/05625ef130f681075fb04b804322e33ef31f6dea/data/CelebA/generate_niid_agg.py#L15) in the `generate_niid_agg.py` to point to the folder of the raw celeba downloaded in step 1. 7 | 8 | #### Step 3. 9 | run `generate_niid_agg.py` to generate FL training and testing dataset. 
10 | For example, to generate data for 25 FL devices, where each device contains images of 10 celebrities: 11 | ``` 12 | python generate_niid_agg.py --agg_user 10 --ratio 250 13 | ``` 14 | -------------------------------------------------------------------------------- /main_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import h5py 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import importlib 6 | import random 7 | import os 8 | import argparse 9 | from utils.plot_utils import * 10 | import torch 11 | torch.manual_seed(0) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--dataset", type=str, default="Mnist") 16 | parser.add_argument("--algorithms", type=str, default="FedAvg,Fedgen", help='algorithm names separate by comma') 17 | parser.add_argument("--result_path", type=str, default="results", help="directory path to save results") 18 | parser.add_argument("--model", type=str, default="cnn") 19 | parser.add_argument("--learning_rate", type=float, default=0.01, help='learning rate.') 20 | parser.add_argument("--local_epochs", type=int, default=20) 21 | parser.add_argument("--num_glob_iters", type=int, default=200) 22 | parser.add_argument("--min_acc", type=float, default=-1.0) 23 | parser.add_argument("--num_users", type=int, default=5, help='number of active users per epoch.') 24 | parser.add_argument("--batch_size", type=int, default=32) 25 | parser.add_argument("--gen_batch_size", type=int, default=32) 26 | parser.add_argument("--plot_legend", type=int, default=1, help='plot legend if set to 1, omitted otherwise.') 27 | parser.add_argument("--times", type=int, default=3, help='number of random seeds, starting from 1.') 28 | parser.add_argument("--embedding", type=int, default=0, help="Use embedding layer in generator network") 29 | args = parser.parse_args() 30 | 31 | algorithms = [a.strip() for a in args.algorithms.split(',')] 32 | title = 'epoch{}'.format(args.local_epochs) 33 | plot_results(args, algorithms) 34 | -------------------------------------------------------------------------------- /FLAlgorithms/users/useravg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from FLAlgorithms.users.userbase import User 3 | 4 | class UserAVG(User): 5 | def __init__(self, args, id, model, train_data, test_data, use_adam=False): 6 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 7 | 8 | def update_label_counts(self, labels, counts): 9 | for label, count in zip(labels, counts): 10 | self.label_counts[int(label)] += count 11 | 12 | 13 | def clean_up_counts(self): 14 | del self.label_counts 15 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 16 | 17 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 18 | self.clean_up_counts() 19 | self.model.train() 20 | for epoch in range(1, self.local_epochs + 1): 21 | self.model.train() 22 | for i in range(self.K): 23 | result =self.get_next_train_batch(count_labels=count_labels) 24 | X, y = result['X'], result['y'] 25 | if count_labels: 26 | self.update_label_counts(result['labels'], result['counts']) 27 | 28 | self.optimizer.zero_grad() 29 | output=self.model(X)['output'] 30 | loss=self.loss(output, y) 31 | loss.backward() 32 | self.optimizer.step()#self.plot_Celeb) 33 | 34 | # local-model <=== self.model 35 | self.clone_model_paramenter(self.model.parameters(), 
self.local_model) 36 | if personalized: 37 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 38 | # local-model ===> self.model 39 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 40 | if lr_decay: 41 | self.lr_scheduler.step(glob_iter) 42 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userFedProx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from FLAlgorithms.users.userbase import User 3 | from FLAlgorithms.optimizers.fedoptimizer import FedProxOptimizer 4 | 5 | class UserFedProx(User): 6 | def __init__(self, args, id, model, train_data, test_data, use_adam=False): 7 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 8 | 9 | self.optimizer = FedProxOptimizer(self.model.parameters(), lr=self.learning_rate, lamda=self.lamda) 10 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99) 11 | 12 | def update_label_counts(self, labels, counts): 13 | for label, count in zip(labels, counts): 14 | self.label_counts[int(label)] += count 15 | 16 | def clean_up_counts(self): 17 | del self.label_counts 18 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 19 | 20 | def train(self, glob_iter, lr_decay=True, count_labels=False): 21 | self.clean_up_counts() 22 | self.model.train() 23 | # cache global model initialized value to local model 24 | self.clone_model_paramenter(self.local_model, self.model.parameters()) 25 | for epoch in range(self.local_epochs): 26 | self.model.train() 27 | for i in range(self.K): 28 | result =self.get_next_train_batch(count_labels=count_labels) 29 | X, y = result['X'], result['y'] 30 | if count_labels: 31 | self.update_label_counts(result['labels'], result['counts']) 32 | 33 | self.optimizer.zero_grad() 34 | output=self.model(X)['output'] 35 | loss=self.loss(output, y) 36 | loss.backward() 37 | self.optimizer.step(self.local_model) 38 | if lr_decay: 39 | self.lr_scheduler.step(glob_iter) 40 | -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverFedProx.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userFedProx import UserFedProx 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data 4 | # Implementation for FedProx Server 5 | 6 | class FedProx(Server): 7 | def __init__(self, args, model, seed): 8 | #dataset, algorithm, model, batch_size, learning_rate, beta, lamda, num_glob_iters, 9 | # local_epochs, num_users, K, personal_learning_rate, times): 10 | super().__init__(args, model, seed)#dataset, algorithm, model, batch_size, learning_rate, beta, lamda, num_glob_iters, 11 | #local_epochs, num_users, times) 12 | 13 | # Initialize data for all users 14 | data = read_data(args.dataset) 15 | total_users = len(data[0]) 16 | print("Users in total: {}".format(total_users)) 17 | 18 | for i in range(total_users): 19 | id, train_data , test_data = read_user_data(i, data, dataset=args.dataset) 20 | user = UserFedProx(args, id, model, train_data, test_data, use_adam=False) 21 | self.users.append(user) 22 | self.total_train_samples += user.train_samples 23 | 24 | print("Number of users / total users:", self.num_users, " / " ,total_users) 25 | print("Finished creating FedAvg server.") 26 | 27 | def train(self, args): 28 | for glob_iter in 
range(self.num_glob_iters): 29 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 30 | self.selected_users = self.select_users(glob_iter,self.num_users) 31 | self.send_parameters() 32 | self.evaluate() 33 | for user in self.selected_users: # allow selected users to train 34 | user.train(glob_iter) 35 | self.aggregate_parameters() 36 | self.save_results(args) 37 | self.save_model() -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/fedoptimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | 3 | class pFedIBOptimizer(Optimizer): 4 | def __init__(self, params, lr=0.01): 5 | # self.local_weight_updated = local_weight # w_i,K 6 | if lr < 0.0: 7 | raise ValueError("Invalid learning rate: {}".format(lr)) 8 | defaults=dict(lr=lr) 9 | super(pFedIBOptimizer, self).__init__(params, defaults) 10 | 11 | def step(self, apply=True, lr=None, allow_unused=False): 12 | grads = [] 13 | # apply gradient to model.parameters, and return the gradients 14 | for group in self.param_groups: 15 | for p in group['params']: 16 | if p.grad is None and allow_unused: 17 | continue 18 | grads.append(p.grad.data) 19 | if apply: 20 | if lr == None: 21 | p.data= p.data - group['lr'] * p.grad.data 22 | else: 23 | p.data=p.data - lr * p.grad.data 24 | return grads 25 | 26 | 27 | def apply_grads(self, grads, beta=None, allow_unused=False): 28 | #apply gradient to model.parameters 29 | i = 0 30 | for group in self.param_groups: 31 | for p in group['params']: 32 | if p.grad is None and allow_unused: 33 | continue 34 | p.data= p.data - group['lr'] * grads[i] if beta == None else p.data - beta * grads[i] 35 | i += 1 36 | return 37 | 38 | 39 | class FedProxOptimizer(Optimizer): 40 | def __init__(self, params, lr=0.01, lamda=0.1, mu=0.001): 41 | if lr < 0.0: 42 | raise ValueError("Invalid learning rate: {}".format(lr)) 43 | defaults=dict(lr=lr, lamda=lamda, mu=mu) 44 | super(FedProxOptimizer, self).__init__(params, defaults) 45 | 46 | def step(self, vstar, closure=None): 47 | loss=None 48 | if closure is not None: 49 | loss=closure 50 | for group in self.param_groups: 51 | for p, pstar in zip(group['params'], vstar): 52 | # w <=== w - lr * ( w' + lambda * (w - w* ) + mu * w ) 53 | p.data=p.data - group['lr'] * ( 54 | p.grad.data + group['lamda'] * (p.data - pstar.data.clone()) + group['mu'] * p.data) 55 | return group['params'], loss 56 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | from FLAlgorithms.users.userbase import User 7 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 8 | from utils.model_utils import get_dataset_name, CONFIGS 9 | 10 | # Implementation for FedAvg clients 11 | 12 | class UserAVG(User): 13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 14 | local_epochs, K): 15 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 16 | local_epochs) 17 | 18 | dataset_name = get_dataset_name(dataset) 19 | self.unique_labels = CONFIGS[dataset_name]['unique_labels'] 20 | 21 | def update_label_counts(self, labels, counts): 22 | for label, count in zip(labels, 
counts): 23 | self.label_counts[int(label)] += count 24 | 25 | 26 | def clean_up_counts(self): 27 | del self.label_counts 28 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 29 | 30 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 31 | self.clean_up_counts() 32 | self.model.train() 33 | for epoch in range(1, self.local_epochs + 1): 34 | self.model.train() 35 | for i in range(self.K): 36 | result =self.get_next_train_batch(count_labels=count_labels) 37 | X, y = result['X'], result['y'] 38 | if count_labels: 39 | self.update_label_counts(result['labels'], result['counts']) 40 | 41 | self.optimizer.zero_grad() 42 | output=self.model(X)['output'] 43 | loss=self.loss(output, y) 44 | loss.backward() 45 | self.optimizer.step()#self.local_model) 46 | 47 | # local-model <=== self.model 48 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 49 | if personalized: 50 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 51 | # local-model ===> self.model 52 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 53 | if lr_decay: 54 | self.lr_scheduler.step(glob_iter) 55 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userpFedGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | from FLAlgorithms.users.userbase import User 7 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 8 | from utils.model_utils import get_dataset_name, CONFIGS 9 | 10 | # Implementation for FedAvg clients 11 | 12 | class UserAVG(User): 13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 14 | local_epochs, K): 15 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 16 | local_epochs) 17 | 18 | dataset_name = get_dataset_name(dataset) 19 | self.unique_labels = CONFIGS[dataset_name]['unique_labels'] 20 | 21 | def update_label_counts(self, labels, counts): 22 | for label, count in zip(labels, counts): 23 | self.label_counts[int(label)] += count 24 | 25 | 26 | def clean_up_counts(self): 27 | del self.label_counts 28 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 29 | 30 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 31 | self.clean_up_counts() 32 | self.model.train() 33 | for epoch in range(1, self.local_epochs + 1): 34 | self.model.train() 35 | for i in range(self.K): 36 | result =self.get_next_train_batch(count_labels=count_labels) 37 | X, y = result['X'], result['y'] 38 | if count_labels: 39 | self.update_label_counts(result['labels'], result['counts']) 40 | 41 | self.optimizer.zero_grad() 42 | output=self.model(X)['output'] 43 | loss=self.loss(output, y) 44 | loss.backward() 45 | self.optimizer.step()#self.local_model) 46 | 47 | # local-model <=== self.model 48 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 49 | if personalized: 50 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 51 | # local-model ===> self.model 52 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 53 | if lr_decay: 54 | self.lr_scheduler.step(glob_iter) 
-------------------------------------------------------------------------------- /FLAlgorithms/servers/serverpFedEnsemble.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.useravg import UserAVG 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data, aggregate_user_test_data 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | class FedEnsemble(Server): 8 | def __init__(self, args, model, seed): 9 | super().__init__(args, model, seed) 10 | 11 | # Initialize data for all users 12 | data = read_data(args.dataset) 13 | # data contains: clients, groups, train_data, test_data, proxy_data 14 | clients = data[0] 15 | total_users = len(clients) 16 | self.total_test_samples = 0 17 | self.slow_start = 20 18 | self.use_adam = 'adam' in self.algorithm.lower() 19 | self.init_ensemble_configs() 20 | self.init_loss_fn() 21 | #### creating users #### 22 | self.users = [] 23 | for i in range(total_users): 24 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 25 | self.total_train_samples+=len(train_data) 26 | self.total_test_samples += len(test_data) 27 | user=UserAVG(args, id, model, train_data, test_data, use_adam=self.use_adam) 28 | self.users.append(user) 29 | 30 | #### build test data loader #### 31 | self.testloaderfull, self.unique_labels=aggregate_user_test_data(data, args.dataset, self.total_test_samples) 32 | print("Loading testing data.") 33 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples) 34 | print("Data from {} users in total.".format(total_users)) 35 | print("Finished creating FedAvg server.") 36 | 37 | def train(self, args): 38 | #### pretraining 39 | for glob_iter in range(self.num_glob_iters): 40 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 41 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 42 | self.send_parameters() 43 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train 44 | user.train( 45 | glob_iter, 46 | personalized=False, lr_decay=True, count_labels=True) 47 | self.aggregate_parameters() 48 | self.evaluate_ensemble(selected=False) 49 | 50 | self.save_results(args) 51 | self.save_model() 52 | -------------------------------------------------------------------------------- /FLAlgorithms/servers/serveravg.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.useravg import UserAVG 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data 4 | import numpy as np 5 | # Implementation for FedAvg Server 6 | import time 7 | 8 | class FedAvg(Server): 9 | def __init__(self, args, model, seed): 10 | super().__init__(args, model, seed) 11 | 12 | # Initialize data for all users 13 | data = read_data(args.dataset) 14 | total_users = len(data[0]) 15 | self.use_adam = 'adam' in self.algorithm.lower() 16 | print("Users in total: {}".format(total_users)) 17 | 18 | for i in range(total_users): 19 | id, train_data , test_data = read_user_data(i, data, dataset=args.dataset) 20 | user = UserAVG(args, id, model, train_data, test_data, use_adam=False) 21 | self.users.append(user) 22 | self.total_train_samples += user.train_samples 23 | 24 | print("Number of users / total users:",args.num_users, " / " ,total_users) 25 | print("Finished creating 
FedAvg server.") 26 | 27 | def train(self, args): 28 | for glob_iter in range(self.num_glob_iters): 29 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 30 | self.selected_users = self.select_users(glob_iter,self.num_users) 31 | self.send_parameters(mode=self.mode) 32 | self.evaluate() 33 | self.timestamp = time.time() # log user-training start time 34 | for user in self.selected_users: # allow selected users to train 35 | user.train(glob_iter, personalized=self.personalized) #* user.train_samples 36 | curr_timestamp = time.time() # log user-training end time 37 | train_time = (curr_timestamp - self.timestamp) / len(self.selected_users) 38 | self.metrics['user_train_time'].append(train_time) 39 | # Evaluate selected user 40 | if self.personalized: 41 | # Evaluate personal model on user for each iteration 42 | print("Evaluate personal model\n") 43 | self.evaluate_personalized_model() 44 | 45 | self.timestamp = time.time() # log server-agg start time 46 | self.aggregate_parameters() 47 | curr_timestamp=time.time() # log server-agg end time 48 | agg_time = curr_timestamp - self.timestamp 49 | self.metrics['server_agg_time'].append(agg_time) 50 | self.save_results(args) 51 | self.save_model() -------------------------------------------------------------------------------- /utils/model_config.py: -------------------------------------------------------------------------------- 1 | CONFIGS_ = { 2 | # input_channel, n_class, hidden_dim, latent_dim 3 | 'cifar': ([16, 'M', 32, 'M', 'F'], 3, 10, 2048, 64), 4 | 'cifar100-c25': ([32, 'M', 64, 'M', 128, 'F'], 3, 26, 128, 128), 5 | 'cifar100-c30': ([32, 'M', 64, 'M', 128, 'F'], 3, 30, 2048, 128), 6 | 'cifar100-c50': ([32, 'M', 64, 'M', 128, 'F'], 3, 50, 2048, 128), 7 | 8 | 'emnist': ([6, 16, 'F'], 1, 26, 784, 32), 9 | 'mnist': ([6, 16, 'F'], 1, 10, 784, 32), 10 | 'mnist_cnn1': ([6, 'M', 16, 'M', 'F'], 1, 10, 64, 32), 11 | 'mnist_cnn2': ([16, 'M', 32, 'M', 'F'], 1, 10, 128, 32), 12 | 'celeb': ([16, 'M', 32, 'M', 64, 'M', 'F'], 3, 2, 64, 32) 13 | } 14 | 15 | # temporary roundabout to evaluate sensitivity of the generator 16 | GENERATORCONFIGS = { 17 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 18 | 'cifar': (512, 32, 3, 10, 64), 19 | 'celeb': (128, 32, 3, 2, 32), 20 | 'mnist': (256, 32, 1, 10, 32), 21 | 'mnist-cnn0': (256, 32, 1, 10, 64), 22 | 'mnist-cnn1': (128, 32, 1, 10, 32), 23 | 'mnist-cnn2': (64, 32, 1, 10, 32), 24 | 'mnist-cnn3': (64, 32, 1, 10, 16), 25 | 'emnist': (256, 32, 1, 26, 32), 26 | 'emnist-cnn0': (256, 32, 1, 26, 64), 27 | 'emnist-cnn1': (128, 32, 1, 26, 32), 28 | 'emnist-cnn2': (128, 32, 1, 26, 16), 29 | 'emnist-cnn3': (64, 32, 1, 26, 32), 30 | } 31 | 32 | 33 | 34 | RUNCONFIGS = { 35 | 'emnist': 36 | { 37 | 'ensemble_lr': 1e-4, 38 | 'ensemble_batch_size': 128, 39 | 'ensemble_epochs': 50, 40 | 'num_pretrain_iters': 20, 41 | 'ensemble_alpha': 1, # teacher loss (server side) 42 | 'ensemble_beta': 0, # adversarial student loss 43 | 'unique_labels': 26, 44 | 'generative_alpha':10, 45 | 'generative_beta': 1, 46 | 'weight_decay': 1e-2 47 | }, 48 | 49 | 'mnist': 50 | { 51 | 'ensemble_lr': 3e-4, 52 | 'ensemble_batch_size': 128, 53 | 'ensemble_epochs': 50, 54 | 'num_pretrain_iters': 20, 55 | 'ensemble_alpha': 1, # teacher loss (server side) 56 | 'ensemble_beta': 0, # adversarial student loss 57 | 'ensemble_eta': 1, # diversity loss 58 | 'unique_labels': 10, # available labels 59 | 'generative_alpha': 10, # used to regulate user training 60 | 'generative_beta': 10, # used to regulate user 
training 61 | 'weight_decay': 1e-2 62 | }, 63 | 64 | 'celeb': 65 | { 66 | 'ensemble_lr': 3e-4, 67 | 'ensemble_batch_size': 128, 68 | 'ensemble_epochs': 50, 69 | 'num_pretrain_iters': 20, 70 | 'ensemble_alpha': 1, # teacher loss (server side) 71 | 'ensemble_beta': 0, # adversarial student loss 72 | 'unique_labels': 2, 73 | 'generative_alpha': 10, 74 | 'generative_beta': 10, 75 | 'weight_decay': 1e-2 76 | }, 77 | 78 | } 79 | 80 | -------------------------------------------------------------------------------- /utils/model_config-base.py: -------------------------------------------------------------------------------- 1 | CONFIGS_ = { 2 | # input_channel, n_class, hidden_dim, latent_dim 3 | 'cifar': ([16, 'M', 32, 'M', 'F'], 3, 10, 2048, 64), 4 | 'cifar100-c25': ([32, 'M', 64, 'M', 128, 'F'], 3, 25, 128, 128), 5 | 'cifar100-c30': ([32, 'M', 64, 'M', 128, 'F'], 3, 30, 2048, 128), 6 | 'cifar100-c50': ([32, 'M', 64, 'M', 128, 'F'], 3, 50, 2048, 128), 7 | 8 | 'emnist': ([6, 16, 'F'], 1, 25, 784, 32), 9 | 'mnist': ([6, 16, 'F'], 1, 10, 784, 32), 10 | 'mnist_cnn1': ([6, 'M', 16, 'M', 'F'], 1, 10, 64, 32), 11 | 'mnist_cnn2': ([16, 'M', 32, 'M', 'F'], 1, 10, 128, 32), 12 | 'celeb': ([16, 'M', 32, 'M', 64, 'M', 'F'], 3, 2, 64, 32) 13 | } 14 | 15 | # temporary roundabout to evaluate sensitivity of the generator 16 | GENERATORCONFIGS = { 17 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 18 | 'cifar': (512, 32, 3, 10, 64), 19 | 'celeb': (128, 32, 3, 2, 32), 20 | 'mnist': (256, 32, 1, 10, 32), 21 | 'mnist-cnn0': (256, 32, 1, 10, 64), 22 | 'mnist-cnn1': (128, 32, 1, 10, 32), 23 | 'mnist-cnn2': (64, 32, 1, 10, 32), 24 | 'mnist-cnn3': (64, 32, 1, 10, 16), 25 | 'emnist': (256, 32, 1, 25, 32), 26 | 'emnist-cnn0': (256, 32, 1, 25, 64), 27 | 'emnist-cnn1': (128, 32, 1, 25, 32), 28 | 'emnist-cnn2': (128, 32, 1, 25, 16), 29 | 'emnist-cnn3': (64, 32, 1, 25, 32), 30 | } 31 | 32 | 33 | 34 | RUNCONFIGS = { 35 | 'emnist': 36 | { 37 | 'ensemble_lr': 1e-4, 38 | 'ensemble_batch_size': 128, 39 | 'ensemble_epochs': 50, 40 | 'num_pretrain_iters': 20, 41 | 'ensemble_alpha': 1, # teacher loss (server side) 42 | 'ensemble_beta': 0, # adversarial student loss 43 | 'unique_labels': 25, 44 | 'generative_alpha':10, 45 | 'generative_beta': 1, 46 | 'weight_decay': 1e-2 47 | }, 48 | 49 | 'mnist': 50 | { 51 | 'ensemble_lr': 3e-4, 52 | 'ensemble_batch_size': 128, 53 | 'ensemble_epochs': 50, 54 | 'num_pretrain_iters': 20, 55 | 'ensemble_alpha': 1, # teacher loss (server side) 56 | 'ensemble_beta': 0, # adversarial student loss 57 | 'ensemble_eta': 1, # diversity loss 58 | 'unique_labels': 10, # available labels 59 | 'generative_alpha': 10, # used to regulate user training 60 | 'generative_beta': 10, # used to regulate user training 61 | 'weight_decay': 1e-2 62 | }, 63 | 64 | 'celeb': 65 | { 66 | 'ensemble_lr': 3e-4, 67 | 'ensemble_batch_size': 128, 68 | 'ensemble_epochs': 50, 69 | 'num_pretrain_iters': 20, 70 | 'ensemble_alpha': 1, # teacher loss (server side) 71 | 'ensemble_beta': 0, # adversarial student loss 72 | 'unique_labels': 2, 73 | 'generative_alpha': 10, 74 | 'generative_beta': 10, 75 | 'weight_decay': 1e-2 76 | }, 77 | 78 | } 79 | 80 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import h5py 3 | import numpy as np 4 | from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset 5 | from matplotlib.ticker 
import StrMethodFormatter 6 | import os 7 | from utils.model_utils import get_log_path, METRICS 8 | import seaborn as sns 9 | import string 10 | import matplotlib.colors as mcolors 11 | import os 12 | COLORS=list(mcolors.TABLEAU_COLORS) 13 | MARKERS=["o", "v", "s", "*", "x", "P"] 14 | 15 | plt.rcParams.update({'font.size': 14}) 16 | n_seeds=3 17 | 18 | def load_results(args, algorithm, seed): 19 | alg = get_log_path(args, algorithm, seed, args.gen_batch_size) 20 | hf = h5py.File("./{}/{}.h5".format(args.result_path, alg), 'r') 21 | metrics = {} 22 | for key in METRICS: 23 | metrics[key] = np.array(hf.get(key)[:]) 24 | return metrics 25 | 26 | 27 | def get_label_name(name): 28 | name = name.split("_")[0] 29 | if 'Distill' in name: 30 | if '-FL' in name: 31 | name = 'FedDistill' + r'$^+$' 32 | else: 33 | name = 'FedDistill' 34 | elif 'FedDF' in name: 35 | name = 'FedFusion' 36 | elif 'FedEnsemble' in name: 37 | name = 'Ensemble' 38 | elif 'FedAvg' in name: 39 | name = 'FedAvg' 40 | return name 41 | 42 | def plot_results(args, algorithms): 43 | n_seeds = args.times 44 | dataset_ = args.dataset.split('-') 45 | sub_dir = dataset_[0] + "/" + dataset_[2] # e.g. Mnist/ratio0.5 46 | os.system("mkdir -p figs/{}".format(sub_dir)) # e.g. figs/Mnist/ratio0.5 47 | plt.figure(1, figsize=(5, 5)) 48 | TOP_N = 5 49 | max_acc = 0 50 | for i, algorithm in enumerate(algorithms): 51 | algo_name = get_label_name(algorithm) 52 | ######### plot test accuracy ############ 53 | metrics = [load_results(args, algorithm, seed) for seed in range(n_seeds)] 54 | all_curves = np.concatenate([metrics[seed]['glob_acc'] for seed in range(n_seeds)]) 55 | top_accs = np.concatenate([np.sort(metrics[seed]['glob_acc'])[-TOP_N:] for seed in range(n_seeds)] ) 56 | acc_avg = np.mean(top_accs) 57 | acc_std = np.std(top_accs) 58 | info = 'Algorithm: {:<10s}, Accuracy = {:.2f} %, deviation = {:.2f}'.format(algo_name, acc_avg * 100, acc_std * 100) 59 | print(info) 60 | length = len(all_curves) // n_seeds 61 | sns.lineplot( 62 | x=np.array(list(range(length)) * n_seeds) + 1, 63 | y=all_curves.astype(float), 64 | legend='brief', 65 | color=COLORS[i], 66 | label=algo_name, 67 | ci="sd", 68 | ) 69 | 70 | plt.gcf() 71 | plt.grid() 72 | plt.title(dataset_[0] + ' Test Accuracy') 73 | plt.xlabel('Epoch') 74 | max_acc = np.max([max_acc, np.max(all_curves) ]) + 4e-2 75 | 76 | if args.min_acc < 0: 77 | alpha = 0.7 78 | min_acc = np.max(all_curves) * alpha + np.min(all_curves) * (1-alpha) 79 | else: 80 | min_acc = args.min_acc 81 | plt.ylim(min_acc, max_acc) 82 | fig_save_path = os.path.join('figs', sub_dir, dataset_[0] + '-' + dataset_[2] + '.png') 83 | plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0, format='png', dpi=400) 84 | print('file saved to {}'.format(fig_save_path)) -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverFedDistill.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userFedDistill import UserFedDistill 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data, aggregate_user_test_data 4 | import numpy as np 5 | 6 | class FedDistill(Server): 7 | def __init__(self, args, model, seed): 8 | super().__init__(args, model, seed) 9 | 10 | # Initialize data for all users 11 | data = read_data(args.dataset) 12 | # data contains: clients, groups, train_data, test_data, proxy_data 13 | clients = data[0] 14 | total_users = len(clients) 15 | 
self.total_test_samples = 0 16 | self.slow_start = 20 17 | self.share_model = 'FL' in self.algorithm 18 | self.pretrain = 'pretrain' in self.algorithm.lower() 19 | self.user_logits = None 20 | self.init_ensemble_configs() 21 | self.init_loss_fn() 22 | self.init_ensemble_configs() 23 | #### creating users #### 24 | self.users = [] 25 | for i in range(total_users): 26 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 27 | self.total_train_samples+=len(train_data) 28 | self.total_test_samples += len(test_data) 29 | id, train, test=read_user_data(i, data, dataset=args.dataset) 30 | user=UserFedDistill( 31 | args, id, model, train_data, test_data, self.unique_labels, use_adam=False) 32 | self.users.append(user) 33 | print("Loading testing data.") 34 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples) 35 | print("Data from {} users in total.".format(total_users)) 36 | print("Finished creating FedAvg server.") 37 | 38 | def train(self, args): 39 | #### pretraining #### 40 | if self.pretrain: 41 | ## before training ## 42 | for iter in range(self.num_pretrain_iters): 43 | print("\n\n-------------Pretrain iteration number: ", iter, " -------------\n\n") 44 | for user in self.users: 45 | user.train(iter, personalized=True, lr_decay=True) 46 | self.evaluate(selected=False, save=False) 47 | ## after training ## 48 | if self.share_model: 49 | self.aggregate_parameters() 50 | self.aggregate_logits(selected=False) # aggregate label-wise logit vector 51 | 52 | for glob_iter in range(self.num_glob_iters): 53 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 54 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 55 | if self.share_model: 56 | self.send_parameters(mode=self.mode)# broadcast averaged prediction model 57 | self.evaluate() # evaluate global model performance 58 | self.send_logits() # send global logits if have any 59 | random_chosen_id = np.random.choice(self.user_idxs) 60 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train 61 | chosen = user_id == random_chosen_id 62 | user.train( 63 | glob_iter, 64 | personalized=True, lr_decay=True, count_labels=True, verbose=chosen) 65 | if self.share_model: 66 | self.aggregate_parameters() 67 | self.aggregate_logits() # aggregate label-wise logit vector 68 | self.evaluate_personalized_model() 69 | 70 | self.save_results(args) 71 | self.save_model() 72 | 73 | def aggregate_logits(self, selected=True): 74 | user_logits = 0 75 | users = self.selected_users if selected else self.users 76 | for user in users: 77 | user_logits += user.logit_tracker.avg() 78 | self.user_logits = user_logits / len(users) 79 | 80 | def send_logits(self): 81 | if self.user_logits == None: return 82 | for user in self.selected_users: 83 | user.global_logits = self.user_logits.clone().detach() 84 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userFedDistill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from FLAlgorithms.users.userbase import User 6 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 7 | 8 | class LogitTracker(): 9 | def __init__(self, unique_labels): 10 | self.unique_labels = unique_labels 11 | self.labels = [i for i in range(unique_labels)] 12 | 
self.label_counts = torch.ones(unique_labels) # avoid division by zero error 13 | self.logit_sums = torch.zeros((unique_labels,unique_labels) ) 14 | 15 | def update(self, logits, Y): 16 | """ 17 | update logit tracker. 18 | :param logits: shape = n_sampls * logit-dimension 19 | :param Y: shape = n_samples 20 | :return: nothing 21 | """ 22 | batch_unique_labels, batch_labels_counts = Y.unique(dim=0, return_counts=True) 23 | self.label_counts[batch_unique_labels] += batch_labels_counts 24 | # expand label dimension to be n_samples X logit_dimension 25 | labels = Y.view(Y.size(0), 1).expand(-1, logits.size(1)) 26 | logit_sums_ = torch.zeros((self.unique_labels, self.unique_labels) ) 27 | logit_sums_.scatter_add_(0, labels, logits) 28 | self.logit_sums += logit_sums_ 29 | 30 | 31 | def avg(self): 32 | res= self.logit_sums / self.label_counts.float().unsqueeze(1) 33 | return res 34 | 35 | 36 | class UserFedDistill(User): 37 | """ 38 | Track and average logit vectors for each label, and share it with server/other users. 39 | """ 40 | def __init__(self, args, id, model, train_data, test_data, unique_labels, use_adam=False): 41 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 42 | 43 | self.init_loss_fn() 44 | self.unique_labels = unique_labels 45 | self.label_counts = {} 46 | self.logit_tracker = LogitTracker(self.unique_labels) 47 | self.global_logits = None 48 | self.reg_alpha = 1 49 | 50 | def update_label_counts(self, labels, counts): 51 | for label, count in zip(labels, counts): 52 | self.label_counts[int(label)] += count 53 | 54 | def clean_up_counts(self): 55 | del self.label_counts 56 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 57 | 58 | def train(self, glob_iter, personalized=True, lr_decay=True, count_labels=True, verbose=True): 59 | self.clean_up_counts() 60 | self.model.train() 61 | REG_LOSS, TRAIN_LOSS = 0, 0 62 | for epoch in range(1, self.local_epochs + 1): 63 | self.model.train() 64 | for i in range(self.K): 65 | result =self.get_next_train_batch(count_labels=count_labels) 66 | X, y = result['X'], result['y'] 67 | if count_labels: 68 | self.update_label_counts(result['labels'], result['counts']) 69 | self.optimizer.zero_grad() 70 | result=self.model(X, logit=True) 71 | output, logit = result['output'], result['logit'] 72 | self.logit_tracker.update(logit, y) 73 | if self.global_logits != None: 74 | ### get desired logit for each sample 75 | train_loss = self.loss(output, y) 76 | target_p = F.softmax(self.global_logits[y,:], dim=1) 77 | reg_loss = self.ensemble_loss(output, target_p) 78 | REG_LOSS += reg_loss 79 | TRAIN_LOSS += train_loss 80 | loss = train_loss + self.reg_alpha * reg_loss 81 | else: 82 | loss=self.loss(output, y) 83 | loss.backward() 84 | self.optimizer.step()#self.local_model) 85 | # local-model <=== self.model 86 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 87 | if personalized: 88 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 89 | if lr_decay: 90 | self.lr_scheduler.step(glob_iter) 91 | if self.global_logits != None and verbose: 92 | REG_LOSS = REG_LOSS.detach().numpy() / (self.local_epochs * self.K) 93 | TRAIN_LOSS = TRAIN_LOSS.detach().numpy() / (self.local_epochs * self.K) 94 | info = "Train loss {:.2f}, Regularization loss {:.2f}".format(REG_LOSS, TRAIN_LOSS) 95 | print(info) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import argparse 4 | from FLAlgorithms.servers.serveravg import FedAvg 5 | from FLAlgorithms.servers.serverFedProx import FedProx 6 | from FLAlgorithms.servers.serverFedDistill import FedDistill 7 | from FLAlgorithms.servers.serverpFedGen import FedGen 8 | from FLAlgorithms.servers.serverpFedEnsemble import FedEnsemble 9 | from FLAlgorithms.servers.serverpFedCL import FedCL 10 | from utils.model_utils import create_model 11 | from utils.plot_utils import * 12 | import torch 13 | from multiprocessing import Pool 14 | import time 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | def create_server_n_user(args, i): 18 | model = create_model(args.model, args.dataset, args.algorithm) 19 | if ('FedAvg' in args.algorithm): 20 | server=FedAvg(args, model, i) 21 | elif ('FedGen' in args.algorithm): 22 | server=FedGen(args, model, i) 23 | elif ('FedCL' in args.algorithm): 24 | server=FedCL(args, model, i) 25 | elif ('FedProx' in args.algorithm): 26 | server = FedProx(args, model, i) 27 | elif ('FedDistill' in args.algorithm): 28 | server = FedDistill(args, model, i) 29 | elif ('FedEnsemble' in args.algorithm): 30 | server = FedEnsemble(args, model, i) 31 | else: 32 | print("Algorithm {} has not been implemented.".format(args.algorithm)) 33 | exit() 34 | return server 35 | 36 | 37 | def run_job(args, i): 38 | torch.manual_seed(i) 39 | print("\n\n [ Start training iteration {} ] \n\n".format(i)) 40 | # Generate model 41 | server = create_server_n_user(args, i) 42 | if args.train: 43 | server.train(args) 44 | server.test() 45 | 46 | def main(args): 47 | for i in range(args.times): 48 | run_job(args, i) 49 | print("Finished training.") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--dataset", type=str, default="Mnist") 55 | parser.add_argument("--model", type=str, default="cnn") 56 | parser.add_argument("--train", type=int, default=1, choices=[0,1]) 57 | parser.add_argument("--algorithm", type=str, default="pFedMe") 58 | parser.add_argument("--batch_size", type=int, default=32) 59 | parser.add_argument("--gen_batch_size", type=int, default=32, help='number of samples from generator') 60 | parser.add_argument("--learning_rate", type=float, default=0.01, help="Local learning rate") 61 | parser.add_argument("--personal_learning_rate", type=float, default=0.01, help="Personalized learning rate to caculate theta aproximately using K steps") 62 | parser.add_argument("--ensemble_lr", type=float, default=1e-4, help="Ensemble learning rate.") 63 | parser.add_argument("--beta", type=float, default=1.0, help="Average moving parameter for pFedMe, or Second learning rate of Per-FedAvg") 64 | parser.add_argument("--lamda", type=int, default=1, help="Regularization term") 65 | parser.add_argument("--mix_lambda", type=float, default=0.1, help="Mix lambda for FedMXI baseline") 66 | parser.add_argument("--embedding", type=int, default=0, help="Use embedding layer in generator network") 67 | parser.add_argument("--num_glob_iters", type=int, default=200) 68 | parser.add_argument("--local_epochs", type=int, default=20) 69 | parser.add_argument("--num_users", type=int, default=20, help="Number of Users per round") 70 | parser.add_argument("--K", type=int, default=1, help="Computation steps") 71 | parser.add_argument("--times", type=int, default=3, help="running time") 72 | parser.add_argument("--device", type=str, default="cpu", 
choices=["cpu","cuda"], help="run device (cpu | cuda)") 73 | parser.add_argument("--result_path", type=str, default="results", help="directory path to save results") 74 | 75 | 76 | args = parser.parse_args() 77 | print("=" * 80) 78 | print(args) 79 | print("=" * 80) 80 | print("Summary of training process:") 81 | print("Algorithm: {}".format(args.algorithm)) 82 | print("Batch size: {}".format(args.batch_size)) 83 | print("Learing rate : {}".format(args.learning_rate)) 84 | print("Ensemble learing rate : {}".format(args.ensemble_lr)) 85 | print("Average Moving : {}".format(args.beta)) 86 | print("Subset of users : {}".format(args.num_users)) 87 | print("Number of global rounds : {}".format(args.num_glob_iters)) 88 | print("Number of local rounds : {}".format(args.local_epochs)) 89 | print("Dataset : {}".format(args.dataset)) 90 | print("Local Model : {}".format(args.model)) 91 | print("Device : {}".format(args.device)) 92 | print("=" * 80) 93 | time1 = time.time() 94 | main(args) 95 | time2 = time.time() 96 | print(time2-time1) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity 2 | 3 | Research code that accompanies the paper [FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity](http://arxiv.org/abs/2211.07248). 4 | It contains implementation of the following algorithms: 5 | * **FedCL** (the proposed algorithm) 6 | * **FedGen** ([paper](https://arxiv.org/pdf/2105.10056.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverpFedGen.py)). 7 | * **FedAvg** ([paper](https://arxiv.org/pdf/1602.05629.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serveravg.py)). 8 | * **FedProx** ([paper](https://arxiv.org/pdf/1812.06127.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverFedProx.py)). 9 | * **FedDistill** and its extension **FedDistll-FL** ([paper](https://arxiv.org/pdf/2011.02367.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverFedDistill.py)). 10 | 11 | ## Install Requirements: 12 | ```pip3 install -r requirements.txt``` 13 | 14 | 15 | ## Prepare Dataset: 16 | * To generate *non-iid* **Mnist** Dataset following the Dirichlet distribution D(α=0.1) for 20 clients, using 50% of the total available training samples: 17 |
17 | ```
18 | cd ./data/Mnist
19 | python generate_niid_dirichlet.py --n_class 10 --sampling_ratio 0.5 --alpha 0.1 --n_user 20
20 | ### This will generate a dataset located at FedGen/data/Mnist/u20c10-alpha0.1-ratio0.5/
21 | ```
22 |
23 | - Similarly, to generate a *non-iid* **EMnist** dataset using 10% of the total available training samples:
24 | ```
25 | cd FedGen/data/EMnist
26 | python generate_niid_dirichlet.py --sampling_ratio 0.1 --alpha 0.1 --n_user 20
27 | ### This will generate a dataset located at FedGen/data/EMnist/u20-letters-alpha0.1-ratio0.1/
28 | ```
29 | ## Run Experiments:
30 |
31 | All experiments are launched through the main file `main.py`.
32 |
33 | #### Run experiments on the *Mnist* Dataset:
34 |
35 | ```
36 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedGen --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3
37 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedCL --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3
38 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedAvg --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3
39 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedProx --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3
40 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedDistill-FL --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3
41 | ```
42 | ----
43 |
44 | #### Run experiments on the *EMnist* Dataset:
45 | ```
46 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedAvg --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3
47 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedGen --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3
48 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedCL --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3
49 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedProx --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3
50 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedDistill-FL --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3
51 |
52 | ```
53 | ----
54 |
55 | ### Plot
56 | For the **algorithms** argument, list the algorithm names separated by commas, e.g. `--algorithms FedAvg,FedGen,FedProx`.
57 | ```
58 | python main_plot.py --dataset EMnist-alpha0.1-ratio0.1 --algorithms FedGen --batch_size 32 --local_epochs 50 --num_users 10 --num_glob_iters 200 --plot_legend 1
59 | ```
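For example, to overlay several algorithms in one figure (assuming the corresponding `.h5` result files have already been produced by `main.py`, saved under `results/` by default):
```
python main_plot.py --dataset EMnist-alpha0.1-ratio0.1 --algorithms FedAvg,FedGen,FedCL --batch_size 32 --local_epochs 50 --num_users 10 --num_glob_iters 200 --plot_legend 1
```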
60 | ## Citation
61 | Please cite the following paper if you use this code in your work.
62 | ```
63 | @article{wang2022fedcl,
64 | author = {Wang, Mingjie and Guo, Jianxiong and Jia, Weijia},
65 | title = {FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity},
66 | year = {2023},
67 | journal = {IEEE Transactions on Artificial Intelligence},
68 | doi = {10.1109/TAI.2023.3307664},
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/FLAlgorithms/optimizers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | def _pairwise_distances(embeddings, squared=False):
3 | """Compute the 2D matrix of distances between all the embeddings.
4 | Args:
5 | embeddings: tensor of shape (batch_size, embed_dim)
6 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
7 | If false, output is the pairwise euclidean distance matrix.
8 | Returns:
9 | pairwise_distances: tensor of shape (batch_size, batch_size)
10 | """
11 | # Get the dot product between all embeddings
12 | # shape (batch_size, batch_size)
13 | dot_product = torch.matmul(embeddings, embeddings.t())
14 |
15 | # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
16 | # This also provides more numerical stability (the diagonal of the result will be exactly 0).
17 | # shape (batch_size,)
18 | square_norm = torch.diag(dot_product)
19 |
20 | # Compute the pairwise distance matrix as we have:
21 | # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
22 | # shape (batch_size, batch_size)
23 | distances = torch.unsqueeze(square_norm, 1) - 2.0 * dot_product + torch.unsqueeze(square_norm, 0)
24 |
25 | # Because of computation errors, some distances might be negative so we put everything >= 0.0
26 | distances = torch.max(distances, torch.tensor([0.0]).cuda())
27 |
28 | if not squared:
29 | # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
30 | # we need to add a small epsilon where distances == 0.0
31 | mask = (torch.eq(distances, 0.0)).float()
32 | distances = distances + mask * 1e-16
33 |
34 | distances = torch.sqrt(distances)
35 |
36 | # Correct the epsilon added: set the distances on the mask to be exactly 0.0
37 | distances = distances * (1.0 - mask)
38 |
39 | return distances
40 |
41 |
42 | def _get_triplet_mask(labels):
43 | """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
44 | A triplet (i, j, k) is valid if:
45 | - i, j, k are distinct
46 | - labels[i] == labels[j] and labels[i] != labels[k]
47 | Args:
48 | labels: int64 `Tensor` with shape [batch_size]
49 | """
50 | # Check that i, j and k are distinct
51 | indices_equal = torch.eye(labels.shape[0]).cuda()
52 | indices_not_equal = torch.tensor([1.0]).cuda()-indices_equal
53 | i_not_equal_j = torch.unsqueeze(indices_not_equal, 2)
54 | i_not_equal_k = torch.unsqueeze(indices_not_equal, 1)
55 | j_not_equal_k = torch.unsqueeze(indices_not_equal, 0)
56 |
57 | distinct_indices = torch.mul(torch.mul(i_not_equal_j, i_not_equal_k), j_not_equal_k)
58 |
59 |
60 | # Check if labels[i] == labels[j] and labels[i] != labels[k]
61 | label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)).float()
62 | i_equal_j = torch.unsqueeze(label_equal, 2)
63 | i_equal_k = torch.unsqueeze(label_equal, 1)
64 | valid_labels = torch.mul(i_equal_j, torch.tensor([1.0]).cuda()-i_equal_k)
65 |
66 | # Combine the two masks
67 | mask = torch.mul(distinct_indices, valid_labels)
68 | return mask
69 |
70 |
71 | def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
72 | """Build the triplet loss over a batch of embeddings.
73 | We generate all the valid triplets and average the loss over the positive ones.
74 | Args:
75 | labels: labels of the batch, of size (batch_size,)
76 | embeddings: tensor of shape (batch_size, embed_dim)
77 | margin: margin for triplet loss
78 | squared: Boolean. If true, squared Euclidean distances are used when forming triplets.
79 | If false, plain Euclidean distances are used.
80 | Returns:
81 | triplet_loss: scalar tensor containing the triplet loss
82 | """
83 | # Get the pairwise distance matrix
84 | pairwise_dist = _pairwise_distances(embeddings, squared=squared)
85 |
86 | # shape (batch_size, batch_size, 1)
87 | anchor_positive_dist = torch.unsqueeze(pairwise_dist, 2)
88 | assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape)
89 | # shape (batch_size, 1, batch_size)
90 | anchor_negative_dist = torch.unsqueeze(pairwise_dist, 1)
91 | assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape)
92 |
93 | # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
94 | # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
95 | # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
96 | # and the 2nd (batch_size, 1, batch_size)
97 | triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
98 |
99 | # Put to zero the invalid triplets
100 | # (where label(a) != label(p) or label(n) == label(a) or a == p)
101 | mask = _get_triplet_mask(labels)
102 | mask = mask.float()
103 | triplet_loss = torch.mul(mask, triplet_loss)
104 |
105 | # Remove negative losses (i.e. the easy triplets)
106 | triplet_loss = torch.max(triplet_loss, torch.tensor([0.0]).cuda())
107 |
108 | # Count number of positive triplets (where triplet_loss > 0)
109 | valid_triplets = torch.gt(triplet_loss, 1e-16).float()
110 | num_positive_triplets = torch.sum(valid_triplets)
111 |
112 |
113 | # Get final mean triplet loss over the positive valid triplets
114 | triplet_loss = torch.sum(triplet_loss) / (num_positive_triplets + 1e-16)
115 |
116 | return triplet_loss
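# --- Illustrative usage sketch (added for clarity; not part of the original training pipeline).
# --- Assumes a CUDA device is available, because the helper functions above call .cuda().
if __name__ == "__main__":
    embeddings = torch.randn(32, 64).cuda()      # (batch_size, embed_dim)
    labels = torch.randint(0, 10, (32,)).cuda()  # integer class labels, shape (batch_size,)
    loss = batch_all_triplet_loss(labels, embeddings, margin=0.5, squared=False)
    print("batch-all triplet loss:", loss.item())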
--------------------------------------------------------------------------------
/FLAlgorithms/trainmodel/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from utils.model_config import CONFIGS_
4 |
5 | import collections
6 |
7 | #################################
8 | ##### Neural Network model #####
9 | #################################
10 | class Net(nn.Module):
11 | def __init__(self, dataset='mnist', model='cnn'):
12 | super(Net, self).__init__()
13 | # define network layers
14 | print("Creating model for {}".format(dataset))
15 | self.dataset = dataset
16 | configs, input_channel, self.output_dim, self.hidden_dim, self.latent_dim=CONFIGS_[dataset]
17 | print('Network configs:', configs)
18 | self.named_layers, self.layers, self.layer_names =self.build_network(
19 | configs, input_channel, self.output_dim)
20 | self.n_parameters = len(list(self.parameters()))
21 | self.n_share_parameters = len(self.get_encoder())
22 |
23 | def get_number_of_parameters(self):
24 | pytorch_total_params=sum(p.numel() for p in self.parameters() if p.requires_grad)
25 | return pytorch_total_params
26 |
27 | def build_network(self, configs, input_channel, output_dim):
28 | layers = nn.ModuleList()
29 | named_layers = {}
30 | layer_names = []
31 | kernel_size, stride, padding = 3, 2, 1
32 | for i, x in enumerate(configs):
33 | if x == 'F':
34 | layer_name='flatten{}'.format(i)
35 | layer=nn.Flatten(1)
36 | layers+=[layer]
37 | layer_names+=[layer_name]
38 | elif x == 'M':
39 | pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)
40 | layer_name = 'pool{}'.format(i)
41 | layers += [pool_layer]
42 | layer_names += [layer_name]
43 | else:
44 | cnn_name = 'encode_cnn{}'.format(i)
45 | cnn_layer = nn.Conv2d(input_channel, x, stride=stride, kernel_size=kernel_size, padding=padding)
46 | named_layers[cnn_name] = [cnn_layer.weight, cnn_layer.bias]
47 |
48 | bn_name = 'encode_batchnorm{}'.format(i)
49 | bn_layer = nn.BatchNorm2d(x)
50 | named_layers[bn_name] = [bn_layer.weight, bn_layer.bias]
51 |
52 | relu_name = 'relu{}'.format(i)
53 | relu_layer = nn.ReLU(inplace=True)# no parameters to learn
54 |
55 | layers += [cnn_layer, bn_layer, relu_layer]
56 | layer_names += [cnn_name, bn_name, relu_name]
57 | input_channel = x
58 |
59 | # finally, classification layer
60 | fc_layer_name1 = 'encode_fc1'
61 | fc_layer1 = nn.Linear(self.hidden_dim, self.latent_dim)
62 | layers += [fc_layer1]
63 | layer_names += [fc_layer_name1]
64 | named_layers[fc_layer_name1] = [fc_layer1.weight, fc_layer1.bias]
65 |
66 | fc_layer_name = 'decode_fc2'
67 | fc_layer = nn.Linear(self.latent_dim, self.output_dim)
68 | layers += [fc_layer]
69 | layer_names += [fc_layer_name]
70 | named_layers[fc_layer_name] = [fc_layer.weight, fc_layer.bias]
71 | return named_layers, layers, layer_names
72 |
73 |
74 | def get_parameters_by_keyword(self, keyword='encode'):
75 | params=[]
76 | for name, layer in zip(self.layer_names, self.layers):
77 | if keyword in name:
78 | #layer = self.layers[name]
79 | params += [layer.weight, layer.bias]
80 | return params
81 |
82 | def get_encoder(self):
83 | return self.get_parameters_by_keyword("encode")
84 |
85 | def get_decoder(self):
86 | return self.get_parameters_by_keyword("decode")
87 |
88 | def get_shared_parameters(self, detach=False):
89 | return self.get_parameters_by_keyword("decode_fc2")
90 |
91 | def get_learnable_params(self):
92 | return self.get_encoder() + self.get_decoder()
93 |
94 | def forward(self, x, start_layer_idx = 0, logit=False, start_layer_output=False):
95 | """
96 | :param x:
97 | :param logit: return logit vector before the last softmax layer
98 | :param start_layer_idx: if 0, conduct normal forward; otherwise, forward from the last few layers (see mapping function)
99 | :return:
100 | """
101 | if start_layer_idx < 0: #
102 | return self.mapping(x, start_layer_idx=start_layer_idx, logit=logit)
103 | results={}
104 | z = x
105 | for idx in range(start_layer_idx, len(self.layers)):
106 | layer_name = self.layer_names[idx]
107 | layer = self.layers[idx]
108 | z = layer(z)
109 |
110 | if start_layer_output:
111 | layer_ = self.layers[start_layer_idx]
112 | layer_ = layer_(x)
113 |
114 | if self.output_dim > 1:
115 | results['output'] = F.log_softmax(z, dim=1)
116 | else:
117 | results['output'] = z
118 | if logit:
119 | results['logit'] = z
120 | if start_layer_output:
121 | results['strat_layer_output'] = layer_
122 | return results
123 |
124 | def mapping(self, z_input, start_layer_idx=-1, logit=True):
125 | z = z_input
126 | n_layers = len(self.layers)
127 | for layer_idx in range(n_layers + start_layer_idx, n_layers):
128 | layer = self.layers[layer_idx]
129 | z = layer(z)
130 | if self.output_dim > 1:
131 | out=F.log_softmax(z, dim=1)
132 | result = {'output': out}
133 | if logit:
134 | result['logit'] = z
135 | return result
136 |
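# Illustrative usage sketch (added for clarity; commented out because the exact input
# shape depends on CONFIGS_ for the chosen dataset in utils/model_config.py):
# net = Net(dataset='mnist', model='cnn')
# result = net(x, logit=True)        # x: image batch; result['output'] holds log-probabilities
# gen = net(z, start_layer_idx=-1)   # z: latent batch; runs only the last layer(s) via mapping()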
--------------------------------------------------------------------------------
/FLAlgorithms/users/userpFedEnsemble.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import os
5 | import json
6 | import copy
7 | import numpy as np
8 | from FLAlgorithms.users.userbase import User
9 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer
10 | from torchvision.utils import save_image
11 |
12 | class UserpFedEnsemble(User):
13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data,
14 | model, generative_model, available_labels,
15 | batch_size, learning_rate, beta, lamda, local_epochs, K):
16 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda,
17 | local_epochs)
18 |
19 | self.init_loss_fn()
20 | self.K = K
21 | self.optimizer = pFedIBOptimizer(self.model.parameters(), lr=self.learning_rate)
22 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99)
23 | self.generative_model = copy.deepcopy(generative_model)
24 | self.label_counts = {}
25 | self.available_labels = available_labels
26 | self.generative_alpha = 10
27 | self.generative_beta = 0.1
28 | self.update_gen_freq = 5
29 | self.pretrained = False
30 | self.generative_optimizer = torch.optim.Adam(
31 | params=self.generative_model.parameters(),
32 | lr=1e-3, betas=(0.9, 0.999),
33 | eps=1e-08, weight_decay=0, amsgrad=False)
34 | self.generative_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.generative_optimizer, gamma=0.98)
35 |
36 | def update_label_counts(self, labels, counts):
37 | for label, count in zip(labels, counts):
38 | self.label_counts[int(label)] += count
39 |
40 | def clean_up_counts(self):
41 | del self.label_counts
42 | self.label_counts = {label:1 for label in range(self.unique_labels)}
43 |
44 | def update_generator(self, steps, verbose=True):
45 | self.model.eval()
46 | self.generative_model.train()
47 | RECONSTRUCT_LOSS, KLD_LOSS, RC_LOSS = 0, 0, 0
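# Note: the generator is trained here in a VAE-like fashion; loss_function below appears to
# combine a reconstruction term with a KL divergence on (mu, log_var), weighted by beta=0.01.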
48 | for _ in range(steps):
49 | self.generative_optimizer.zero_grad()
50 | samples=self.get_next_train_batch(count_labels=True)
51 | X, y=samples['X'], samples['y']
52 | gen_result = self.generative_model(X, y)
53 | loss_info = self.generative_model.loss_function(
54 | gen_result['output'],
55 | X,
56 | gen_result['mu'],
57 | gen_result['log_var'],
58 | beta=0.01
59 | )
60 | loss, kld_loss, reconstruct_loss = loss_info['loss'], loss_info['KLD'], loss_info['reconstruction_loss']
61 | RECONSTRUCT_LOSS += loss
62 | KLD_LOSS += kld_loss
63 | RC_LOSS += reconstruct_loss
64 | loss.backward()
65 | self.generative_optimizer.step()
66 | self.generative_lr_scheduler.step()
67 |
68 | if verbose:
69 | RECONSTRUCT_LOSS = RECONSTRUCT_LOSS.detach().numpy() / steps
70 | KLD_LOSS = KLD_LOSS.detach().numpy() / steps
71 | RC_LOSS = RC_LOSS.detach().numpy() / steps
72 | info = "VAE-Loss: {:.4f}, KL-Loss: {:.4f}, RC-Loss:{:.4f}".format(RECONSTRUCT_LOSS, KLD_LOSS, RC_LOSS)
73 | print(info)
74 |
75 |
76 | def train(self, glob_iter, personalized=False, reconstruct=False, verbose=False):
77 | self.clean_up_counts()
78 | self.model.train()
79 | self.generative_model.eval()
80 | TEACHER_LOSS, DIST_LOSS, RECONSTRUCT_LOSS = 0, 0, 0
81 | #if glob_iter % self.update_gen_freq == 0:
82 | if not self.pretrained:
83 | self.update_generator(self.local_epochs * 20)
84 | self.pretrained = True
85 | self.visualize_images(self.generative_model, 0, repeats=10)
86 | for epoch in range(self.local_epochs):
87 | self.model.train()
88 | self.generative_model.eval()
89 | for i in range(self.K):
90 | loss = 0
91 | self.optimizer.zero_grad()
92 | #### sample from real dataset (un-weighted)
93 | samples =self.get_next_train_batch(count_labels=True)
94 | X, y = samples['X'], samples['y']
95 | self.update_label_counts(samples['labels'], samples['counts'])
96 | model_result=self.model(X, return_latent=reconstruct)
97 | output = model_result['output']
98 | predictive_loss=self.loss(output, y)
99 | loss += predictive_loss
100 | #### sample from generator and regulate Dist|z_gen, z_pred|, where z_gen = Gen(x, y), z_pred = model(X)
101 | if reconstruct:
102 | gen_result = self.generative_model(X, y, latent=True)
103 | z_gen = gen_result['latent']
104 | z_model = model_result['latent']
105 | dist_loss = self.generative_beta * self.dist_loss(z_model, z_gen)
106 | DIST_LOSS += dist_loss
107 | loss += dist_loss
108 | #### get loss and perform optimization
109 | loss.backward()
110 | self.optimizer.step() # self.local_model)
111 |
112 | if reconstruct:
113 | DIST_LOSS+=(torch.mean(DIST_LOSS.double())).item()
114 | # local-model <=== self.model
115 | self.clone_model_paramenter(self.model.parameters(), self.local_model)
116 | if personalized:
117 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar)
118 | self.lr_scheduler.step(glob_iter)
119 | if reconstruct and verbose:
120 | info = 'Latent Reconstruction Loss={:.4f}'.format(DIST_LOSS)
121 | print(info)
122 |
123 | def adjust_weights(self, samples):
124 | labels, counts = samples['labels'], samples['counts']
125 | #weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
126 | np_y = samples['y'].detach().numpy()
127 | n_labels = samples['y'].shape[0]
128 | weights = np.array([n_labels / count for count in counts]) # smaller count --> larger weight
129 | weights = len(self.available_labels) * weights / np.sum(weights) # normalized
130 | label_weights = np.ones(self.unique_labels)
131 | label_weights[labels] = weights
132 | sample_weights = label_weights[np_y]
133 | return sample_weights
134 |
135 | def visualize_images(self, generator, glob_iter, repeats=1):
136 | """
137 | Generate and visualize data for a generator.
138 | """
139 | os.system("mkdir -p images")
140 | path = f'images/{self.algorithm}-{self.dataset}-user{self.id}-iter{glob_iter}.png'
141 | y=self.available_labels
142 | y = np.repeat(y, repeats=repeats, axis=0)
143 | y_input=torch.tensor(y, dtype=torch.int64)
144 | generator.eval()
145 | images=generator.sample(y_input, latent=False)['output'] # 0,1,..,K, 0,1,...,K
146 | images=images.view(repeats, -1, *images.shape[1:])
147 | images=images.view(-1, *images.shape[2:])
148 | save_image(images.detach(), path, nrow=repeats, normalize=True)
149 | print("Image saved to {}".format(path))
--------------------------------------------------------------------------------
/FLAlgorithms/users/userbase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import os
5 | import json
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | import copy
9 | from utils.model_utils import get_dataset_name
10 | from utils.model_config import RUNCONFIGS
11 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer
12 |
13 | class User:
14 | """
15 | Base class for users in federated learning.
16 | """
17 | def __init__(
18 | self, args, id, model, train_data, test_data, use_adam=False):
19 | self.model = copy.deepcopy(model[0])
20 | self.model_name = model[1]
21 | self.id = id # integer
22 | self.train_samples = len(train_data)
23 | self.test_samples = len(test_data)
24 | self.batch_size = args.batch_size
25 | self.learning_rate = args.learning_rate
26 | self.beta = args.beta
27 | self.lamda = args.lamda
28 | self.local_epochs = args.local_epochs
29 | self.algorithm = args.algorithm
30 | self.K = args.K
31 | self.dataset = args.dataset
32 | self.trainloader = DataLoader(train_data, self.batch_size, shuffle=True, drop_last=True)
33 | self.testloader = DataLoader(test_data, self.batch_size, drop_last=False)
34 | self.testloaderfull = DataLoader(test_data, self.test_samples)
35 | self.trainloaderfull = DataLoader(train_data, self.train_samples)
36 | self.iter_trainloader = iter(self.trainloader)
37 | self.iter_testloader = iter(self.testloader)
38 | dataset_name = get_dataset_name(self.dataset)
39 | self.unique_labels = RUNCONFIGS[dataset_name]['unique_labels']
40 | self.generative_alpha = RUNCONFIGS[dataset_name]['generative_alpha']
41 | self.generative_beta = RUNCONFIGS[dataset_name]['generative_beta']
42 |
43 | # those parameters are for personalized federated learning.
44 | self.local_model = copy.deepcopy(list(self.model.parameters()))
45 | self.personalized_model_bar = copy.deepcopy(list(self.model.parameters()))
46 | self.prior_decoder = None
47 | self.prior_params = None
48 |
49 | self.init_loss_fn()
50 | if use_adam:
51 | self.optimizer=torch.optim.Adam(
52 | params=self.model.parameters(),
53 | lr=self.learning_rate, betas=(0.9, 0.999),
54 | eps=1e-08, weight_decay=1e-2, amsgrad=False)
55 | else:
56 | self.optimizer = pFedIBOptimizer(self.model.parameters(), lr=self.learning_rate)
57 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99)
58 | self.label_counts = {}
59 |
60 |
61 |
62 |
63 | def init_loss_fn(self):
64 | self.loss=nn.NLLLoss()
65 | self.dist_loss = nn.MSELoss()
66 | self.ensemble_loss=nn.KLDivLoss(reduction="batchmean")
67 | self.ce_loss = nn.CrossEntropyLoss()
68 |
69 | def set_parameters(self, model,beta=1):
70 | for old_param, new_param, local_param in zip(self.model.parameters(), model.parameters(), self.local_model):
71 | if beta == 1:
72 | old_param.data = new_param.data.clone()
73 | local_param.data = new_param.data.clone()
74 | else:
75 | old_param.data = beta * new_param.data.clone() + (1 - beta) * old_param.data.clone()
76 | local_param.data = beta * new_param.data.clone() + (1-beta) * local_param.data.clone()
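# i.e. with beta < 1 the local copies move toward the global model by a moving average:
# param <- beta * global_param + (1 - beta) * param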
77 |
78 | def set_prior_decoder(self, model, beta=1):
79 | for new_param, local_param in zip(model.personal_layers, self.prior_decoder):
80 | if beta == 1:
81 | local_param.data = new_param.data.clone()
82 | else:
83 | local_param.data = beta * new_param.data.clone() + (1 - beta) * local_param.data.clone()
84 |
85 |
86 | def set_prior(self, model):
87 | for new_param, local_param in zip(model.get_encoder() + model.get_decoder(), self.prior_params):
88 | local_param.data = new_param.data.clone()
89 |
90 | # only for pFedMAS
91 | def set_mask(self, mask_model):
92 | for new_param, local_param in zip(mask_model.get_masks(), self.mask_model.get_masks()):
93 | local_param.data = new_param.data.clone()
94 |
95 | def set_shared_parameters(self, model, mode='decode'):
96 | # only copy shared parameters to local
97 | for old_param, new_param in zip(
98 | self.model.get_parameters_by_keyword(mode),
99 | model.get_parameters_by_keyword(mode)
100 | ):
101 | old_param.data = new_param.data.clone()
102 |
103 | def get_parameters(self):
104 | for param in self.model.parameters():
105 | param.detach()
106 | return self.model.parameters()
107 |
108 |
109 | def clone_model_paramenter(self, param, clone_param):
110 | with torch.no_grad():
111 | for param, clone_param in zip(param, clone_param):
112 | clone_param.data = param.data.clone()
113 | return clone_param
114 |
115 | def get_updated_parameters(self):
116 | return self.local_weight_updated
117 |
118 | def update_parameters(self, new_params, keyword='all'):
119 | for param , new_param in zip(self.model.parameters(), new_params):
120 | param.data = new_param.data.clone()
121 |
122 | def get_grads(self):
123 | grads = []
124 | for param in self.model.parameters():
125 | if param.grad is None:
126 | grads.append(torch.zeros_like(param.data))
127 | else:
128 | grads.append(param.grad.data)
129 | return grads
130 |
131 | def test(self):
132 | self.model.eval()
133 | test_acc = 0
134 | loss = 0
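# self.testloaderfull wraps the whole test set in a single batch, so this loop runs once
# and y.shape[0] below equals the total number of test samples.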
135 | for x, y in self.testloaderfull:
136 | output = self.model(x)['output']
137 | loss += self.loss(output, y)
138 | test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item()
139 | return test_acc, loss, y.shape[0]
140 |
141 |
142 |
143 | def test_personalized_model(self):
144 | self.model.eval()
145 | test_acc = 0
146 | loss = 0
147 | self.update_parameters(self.personalized_model_bar)
148 | for x, y in self.testloaderfull:
149 | output = self.model(x)['output']
150 | loss += self.loss(output, y)
151 | test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item()
152 | #@loss += self.loss(output, y)
153 | #print(self.id + ", Test Accuracy:", test_acc / y.shape[0] )
154 | #print(self.id + ", Test Loss:", loss)
155 | self.update_parameters(self.local_model)
156 | return test_acc, y.shape[0], loss
157 |
158 |
159 | def get_next_train_batch(self, count_labels=True):
160 | try:
161 | # Samples a new batch for personalizing
162 | (X, y) = next(self.iter_trainloader)
163 | except StopIteration:
164 | # restart the generator if the previous generator is exhausted.
165 | self.iter_trainloader = iter(self.trainloader)
166 | (X, y) = next(self.iter_trainloader)
167 | result = {'X': X, 'y': y}
168 | if count_labels:
169 | unique_y, counts=torch.unique(y, return_counts=True)
170 | unique_y = unique_y.detach().numpy()
171 | counts = counts.detach().numpy()
172 | result['labels'] = unique_y
173 | result['counts'] = counts
174 | return result
175 |
176 | def get_next_test_batch(self):
177 | try:
178 | # Samples a new batch for personalizing
179 | (X, y) = next(self.iter_testloader)
180 | except StopIteration:
181 | # restart the generator if the previous generator is exhausted.
182 | self.iter_testloader = iter(self.testloader)
183 | (X, y) = next(self.iter_testloader)
184 | return (X, y)
185 |
186 | def save_model(self):
187 | model_path = os.path.join("models", self.dataset)
188 | if not os.path.exists(model_path):
189 | os.makedirs(model_path)
190 | torch.save(self.model, os.path.join(model_path, "user_" + self.id + ".pt"))
191 |
192 | def load_model(self):
193 | model_path = os.path.join("models", self.dataset)
194 | self.model = torch.load(os.path.join(model_path, "server" + ".pt"))
195 |
196 | @staticmethod
197 | def model_exists():
198 | return os.path.exists(os.path.join("models", "server" + ".pt"))
--------------------------------------------------------------------------------
/data/CelebA/generate_niid_agg.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import fetch_openml
2 | from tqdm import trange
3 | import numpy as np
4 | import random
5 | import json
6 | import os
7 | import argparse
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 | import torchvision.transforms as transforms
12 |
13 | IMAGE_SIZE = 84
14 | # TODO change LOAD_PATH to be your own data path
15 | LOAD_PATH = './celeba/'
16 | DUMP_PATH = './'
17 | IMG_DIR = os.path.join(LOAD_PATH, 'data/raw/img_align_celeba')
18 | random.seed(42)
19 | np.random.seed(42)
20 |
21 | def load_proxy_data(user_lists, cdata):
22 | n_users = len(user_lists)
23 | # merge all user data chosen for proxy
24 | new_data={
25 | 'classes': [],
26 | 'user_data': {
27 | 'x': [],
28 | 'y': []
29 | },
30 | 'num_samples': 0
31 | }
32 | for uname in user_lists:
33 | user_data = cdata['user_data'][uname]
34 | X=user_data['x'] # path to image
35 | y=user_data['y'] # label
36 | assert len(X) == len(y)
37 | ## load image ##
38 | loaded_X = []
39 | for i, image_name in enumerate(X):
40 | image_path = os.path.join(IMG_DIR, image_name)
41 | image = Image.open(image_path)
42 | image=image.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB')
43 | image = np.array(image)
44 | loaded_X.append(image)
45 |
46 | new_data['user_data']['x'] += np.array(loaded_X).tolist()
47 | new_data['user_data']['y'] += y
48 |
49 | new_data['classes'] += list(np.unique(y))
50 | new_data['num_samples'] += len(y)
51 | combined = list(zip(new_data['user_data']['x'], new_data['user_data']['y']))
52 | return new_data
53 |
54 |
55 | def load_data(user_lists, cdata, agg_user=-1):
56 | n_users = len(user_lists)
57 | new_data = {
58 | 'users': [None for _ in range(n_users)],
59 | 'num_samples': [None for _ in range(n_users)],
60 | 'user_data': {}
61 | }
62 | if agg_user > 0:
63 | assert len(user_lists) % agg_user == 0
64 | agg_n_users = len(user_lists) // agg_user
65 | agg_data = {
66 | 'users': [None for _ in range(agg_n_users)],
67 | 'num_samples': [None for _ in range(agg_n_users)],
68 | 'user_data': {}
69 |
70 | }
71 |
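# agg_by_user_ merges every `agg_user` consecutive celebrities (named f_00000, f_00001, ...)
# into a single federated client, stacking their images and labels into one tensor per client.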
72 | def agg_by_user_(new_data, agg_n_users, agg_user, verbose=False):
73 | for batch_id in range(agg_n_users):
74 | start_id, end_id = batch_id * agg_user, (batch_id + 1) * agg_user
75 | X, Y, N_samples = [], [], 0
76 | for idx in range(start_id, end_id):
77 | user_uname='f_{0:05d}'.format(idx)
78 | x = new_data['user_data'][user_uname]['x']
79 | y = new_data['user_data'][user_uname]['y']
80 | n_samples = new_data['num_samples'][idx]
81 | X += x
82 | Y += y
83 | N_samples += n_samples
84 |
85 | #####
86 | batch_user_name = 'f_{0:05d}'.format(batch_id)
87 | agg_data['users'][batch_id]= batch_user_name
88 | agg_data['num_samples'][batch_id]=len(Y)
89 | agg_data['user_data'][batch_user_name]={
90 | 'x': torch.Tensor(X).type(torch.float32).permute(0, 3, 1, 2),
91 | 'y': torch.Tensor(Y).type(torch.int64)
92 | }
93 | #####
94 |
95 |
96 | def load_per_user_(user_data, idx, verbose=False):
97 | """
98 | # Reduce test samples per user to ratio.
99 | :param uname:
100 | :param idx:
101 | :return:
102 | """
103 | new_uname='f_{0:05d}'.format(idx)
104 | X=user_data['x']
105 | y=user_data['y']
106 | assert len(X) == len(y)
107 | new_data['users'][idx] = new_uname
108 | new_data['num_samples'][idx] = len(y)
109 | new_data['user_data'][new_uname] = {'y':y}
110 | # load X as images
111 | loaded_X = []
112 | for i, image_name in enumerate(X):
113 | image_path = os.path.join(IMG_DIR, image_name)
114 | image = Image.open(image_path)
115 | image=image.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB')
116 | image = np.array(image)
117 | loaded_X.append(image)
118 | new_data['user_data'][new_uname] = {'x': np.array(loaded_X).tolist(), 'y':y}
119 | if verbose:
120 | print("processing user {}".format(new_uname))
121 |
122 | for idx, uname in enumerate(user_lists):
123 | user_data = cdata['user_data'][uname]
124 | load_per_user_(user_data, idx, verbose=True)
125 | #pass
126 | if agg_user == -1:
127 | return new_data
128 | else:
129 | agg_by_user_(new_data, agg_n_users, agg_user, verbose=False)
130 | return agg_data
131 |
132 | def process_data():
133 | load_path = os.path.join(LOAD_PATH, 'data')
134 | train_files = [f for f in os.listdir(os.path.join(load_path, 'train'))]
135 | test_files = [f for f in os.listdir(os.path.join(load_path, 'test'))]
136 |
137 | def sample_users(cdata, ratio=0.1, excludes=set()):
138 | """
139 | :param cdata:
140 | :param ratio:
141 | :return: list of sampled user names
142 | """
143 | user_lists=[u for u in cdata['users']]
144 | if ratio <= 1:
145 | n_selected_users=int(len(user_lists) * ratio)
146 | else:
147 | n_selected_users = ratio
148 | random.shuffle(user_lists)
149 | new_users = []
150 | i = 0
151 | for u in user_lists:
152 | if u not in excludes:
153 | new_users.append(u)
154 | i += 1
155 | if i == n_selected_users:
156 | return new_users
157 |
158 |
159 | def process_(mode, tf, ratio=0.1, user_lists=None, agg_user=-1):
160 | read_path = os.path.join(load_path, mode if mode != 'proxy' else 'train', tf)
161 | with open(read_path, 'r') as inf:
162 | cdata=json.load(inf)
163 | n_users = len(cdata['users'])
164 | if ratio > 1:
165 | assert ratio < n_users
166 | else:
167 | assert ratio < 1
168 | print("Number of users: {}".format(n_users))
169 | print("Number of raw {} samples: {:.1f}".format(mode, np.mean(cdata['num_samples'])))
170 | print("Deviation of raw {} samples: {:.1f}".format(mode, np.std(cdata['num_samples'])))
171 | #exit()
172 | if mode == 'train':
173 | assert user_lists == None
174 | user_lists = sample_users(cdata, ratio)
175 | new_data=load_data(user_lists, cdata, agg_user=agg_user)
176 | else: # test mode
177 | assert len(user_lists) > 0
178 | new_data = load_data(user_lists, cdata, agg_user=agg_user)
179 | print("Number of reduced users: {}".format(len(new_data['num_samples'])))
180 | print("Number of samples per user: {}".format(new_data['num_samples']))
181 |
182 | if ratio > 1:
183 | n_users = int(ratio)
184 | if agg_user > 0:
185 | n_users = int( n_users // agg_user)
186 | else:
187 | n_users = int(len(cdata['users']) * ratio)
188 | if agg_user > 0:
189 | dump_path=os.path.join(DUMP_PATH, 'user{}-agg{}'.format(n_users, agg_user))
190 | else:
191 | dump_path=os.path.join(DUMP_PATH, 'user{}'.format(n_users))
192 | os.system("mkdir -p {}".format(os.path.join(dump_path, mode)))
193 |
194 | dump_path = os.path.join(dump_path, '{}/{}.pt'.format(mode,mode))
195 | with open(dump_path, 'wb') as outfile:
196 | print("Saving {} data to {}".format(mode, dump_path))
197 | torch.save(new_data, outfile)
198 | return user_lists
199 |
200 | #
201 | mode='train'
202 | tf = train_files[0]
203 | user_lists = process_(mode, tf, ratio=args.ratio, agg_user=args.agg_user)
204 | mode = 'test'
205 | tf = test_files[0]
206 | process_(mode, tf, ratio=args.ratio, user_lists=user_lists, agg_user=args.agg_user)
207 |
208 |
209 | if __name__ == "__main__":
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument("--agg_user", type=int, default=10, help="number of celebrities to be aggregated together as a device/client (as meta-batch size).")
212 | parser.add_argument("--ratio", type=float, default=250, help="Number of total celebrities to be sampled for FL training.")
213 | args = parser.parse_args()
214 | print("Number of FL devices: {}".format(args.ratio // args.agg_user ))
215 | process_data()
--------------------------------------------------------------------------------
/data/Mnist/generate_niid_dirichlet.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import numpy as np
3 | import random
4 | import json
5 | import os
6 | import argparse
7 | from torchvision.datasets import MNIST
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torchvision.transforms as transforms
11 |
12 | random.seed(42)
13 | np.random.seed(42)
14 |
15 | def rearrange_data_by_class(data, targets, n_class):
16 | new_data = []
17 | for i in trange(n_class):
18 | idx = targets == i
19 | new_data.append(data[idx])
20 | return new_data
21 |
22 | def get_dataset(mode='train'):
23 | transform = transforms.Compose(
24 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
25 |
26 | dataset = MNIST(root='./data', train=True if mode=='train' else False, download=True, transform=transform)
27 | n_sample = len(dataset.data)
28 | SRC_N_CLASS = len(dataset.classes)
29 | # full batch
30 | trainloader = DataLoader(dataset, batch_size=n_sample, shuffle=False)
31 |
32 | print("Loading data from storage ...")
33 | for _, xy in enumerate(trainloader, 0):
34 | dataset.data, dataset.targets = xy
35 |
36 | print("Rearrange data by class...")
37 | data_by_class = rearrange_data_by_class(
38 | dataset.data.cpu().detach().numpy(),
39 | dataset.targets.cpu().detach().numpy(),
40 | SRC_N_CLASS
41 | )
42 | print(f"{mode.upper()} SET:\n Total #samples: {n_sample}. sample shape: {dataset.data[0].shape}")
43 | print(" #samples per class:\n", [len(v) for v in data_by_class])
44 |
45 | return data_by_class, n_sample, SRC_N_CLASS
46 |
47 | def sample_class(SRC_N_CLASS, NUM_LABELS, user_id, label_random=False):
48 | assert NUM_LABELS <= SRC_N_CLASS
49 | if label_random:
50 | source_classes = [n for n in range(SRC_N_CLASS)]
51 | random.shuffle(source_classes)
52 | return source_classes[:NUM_LABELS]
53 | else:
54 | return [(user_id + j) % SRC_N_CLASS for j in range(NUM_LABELS)]
55 |
56 | def devide_train_data(data, n_sample, SRC_CLASSES, NUM_USERS, min_sample, alpha=0.5, sampling_ratio=0.5):
57 | min_sample = 10#len(SRC_CLASSES) * min_sample
58 | min_size = 0 # track minimal samples per user
59 | ###### Determine Sampling #######
60 | while min_size < min_sample:
61 | print("Try to find valid data separation")
62 | idx_batch=[{} for _ in range(NUM_USERS)]
63 | samples_per_user = [0 for _ in range(NUM_USERS)]
64 | max_samples_per_user = sampling_ratio * n_sample / NUM_USERS
65 | for l in SRC_CLASSES:
66 | # get indices for all that label
67 | idx_l = [i for i in range(len(data[l]))]
68 | np.random.shuffle(idx_l)
69 | if sampling_ratio < 1:
70 | samples_for_l = int( min(max_samples_per_user, int(sampling_ratio * len(data[l]))) )
71 | idx_l = idx_l[:samples_for_l]
72 | print(l, len(data[l]), len(idx_l))
73 | # dirichlet sampling from this label
74 | proportions=np.random.dirichlet(np.repeat(alpha, NUM_USERS))
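# One Dirichlet draw yields a proportion per user for this class; with a small alpha
# (e.g. 0.1) the draw is highly skewed, so most samples of the class land on a few users,
# which is what produces the label heterogeneity across clients.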
75 | # re-balance proportions
76 | proportions=np.array([p * (n_per_user < max_samples_per_user) for p, n_per_user in zip(proportions, samples_per_user)])
77 | proportions=proportions / proportions.sum()
78 | proportions=(np.cumsum(proportions) * len(idx_l)).astype(int)[:-1]
79 | # partition the data of this label across users
80 | for u, new_idx in enumerate(np.split(idx_l, proportions)):
81 | # add new index to the user
82 | idx_batch[u][l] = new_idx.tolist()
83 | samples_per_user[u] += len(idx_batch[u][l])
84 | min_size=min(samples_per_user)
85 |
86 | ###### CREATE USER DATA SPLIT #######
87 | X = [[] for _ in range(NUM_USERS)]
88 | y = [[] for _ in range(NUM_USERS)]
89 | Labels=[set() for _ in range(NUM_USERS)]
90 | print("processing users...")
91 | for u, user_idx_batch in enumerate(idx_batch):
92 | for l, indices in user_idx_batch.items():
93 | if len(indices) == 0: continue
94 | X[u] += data[l][indices].tolist()
95 | y[u] += (l * np.ones(len(indices))).tolist()
96 | Labels[u].add(l)
97 |
98 | return X, y, Labels, idx_batch, samples_per_user
99 |
100 | def divide_test_data(NUM_USERS, SRC_CLASSES, test_data, Labels, unknown_test):
101 | # Create TEST data for each user.
102 | test_X = [[] for _ in range(NUM_USERS)]
103 | test_y = [[] for _ in range(NUM_USERS)]
104 | idx = {l: 0 for l in SRC_CLASSES}
105 | for user in trange(NUM_USERS):
106 | if unknown_test: # use all available labels
107 | user_sampled_labels = SRC_CLASSES
108 | else:
109 | user_sampled_labels = list(Labels[user])
110 | for l in user_sampled_labels:
111 | num_samples = int(len(test_data[l]) / NUM_USERS )
112 | assert num_samples + idx[l] <= len(test_data[l])
113 | test_X[user] += test_data[l][idx[l]:idx[l] + num_samples].tolist()
114 | test_y[user] += (l * np.ones(num_samples)).tolist()
115 | assert len(test_X[user]) == len(test_y[user]), f"{len(test_X[user])} == {len(test_y[user])}"
116 | idx[l] += num_samples
117 | return test_X, test_y
118 |
119 | def main():
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("--format", "-f", type=str, default="pt", help="Format of saving: pt (torch.save), json", choices=["pt", "json"])
122 | parser.add_argument("--n_class", type=int, default=10, help="number of classification labels")
123 | parser.add_argument("--min_sample", type=int, default=10, help="Min number of samples per user.")
124 | parser.add_argument("--sampling_ratio", type=float, default=0.05, help="Ratio for sampling training samples.")
125 | parser.add_argument("--unknown_test", type=int, default=0, help="Whether allow test label unseen for each user.")
126 | parser.add_argument("--alpha", type=float, default=0.01, help="alpha in Dirichelt distribution (smaller means larger heterogeneity)")
127 | parser.add_argument("--n_user", type=int, default=20,
128 | help="number of local clients, should be muitiple of 10.")
129 | args = parser.parse_args()
130 | print()
131 | print("Number of users: {}".format(args.n_user))
132 | print("Number of classes: {}".format(args.n_class))
133 | print("Min # of samples per uesr: {}".format(args.min_sample))
134 | print("Alpha for Dirichlet Distribution: {}".format(args.alpha))
135 | print("Ratio for Sampling Training Data: {}".format(args.sampling_ratio))
136 | NUM_USERS = args.n_user
137 |
138 | # Setup directory for train/test data
139 | path_prefix = f'u{args.n_user}c{args.n_class}-alpha{args.alpha}-ratio{args.sampling_ratio}'
140 |
141 | def process_user_data(mode, data, n_sample, SRC_CLASSES, Labels=None, unknown_test=0):
142 | if mode == 'train':
143 | X, y, Labels, idx_batch, samples_per_user = devide_train_data(
144 | data, n_sample, SRC_CLASSES, NUM_USERS, args.min_sample, args.alpha, args.sampling_ratio)
145 | if mode == 'test':
146 | assert Labels != None or unknown_test
147 | X, y = divide_test_data(NUM_USERS, SRC_CLASSES, data, Labels, unknown_test)
148 | dataset={'users': [], 'user_data': {}, 'num_samples': []}
149 | for i in range(NUM_USERS):
150 | uname='f_{0:05d}'.format(i)
151 | dataset['users'].append(uname)
152 | dataset['user_data'][uname]={
153 | 'x': torch.tensor(X[i], dtype=torch.float32),
154 | 'y': torch.tensor(y[i], dtype=torch.int64)}
155 | dataset['num_samples'].append(len(X[i]))
156 |
157 | print("{} #sample by user:".format(mode.upper()), dataset['num_samples'])
158 |
159 | data_path=f'./{path_prefix}/{mode}'
160 | if not os.path.exists(data_path):
161 | os.makedirs(data_path)
162 |
163 | data_path=os.path.join(data_path, "{}.".format(mode) + args.format)
164 | if args.format == "json":
165 | raise NotImplementedError(
166 | "json is not supported because the train_data/test_data uses the tensor instead of list and tensor cannot be saved into json.")
167 | with open(data_path, 'w') as outfile:
168 | print(f"Dumping train data => {data_path}")
169 | json.dump(dataset, outfile)
170 | elif args.format == "pt":
171 | with open(data_path, 'wb') as outfile:
172 | print(f"Dumping train data => {data_path}")
173 | torch.save(dataset, outfile)
174 | if mode == 'train':
175 | for u in range(NUM_USERS):
176 | print("{} samples in total".format(samples_per_user[u]))
177 | train_info = ''
178 | # train_idx_batch, train_samples_per_user
179 | n_samples_for_u = 0
180 | for l in sorted(list(Labels[u])):
181 | n_samples_for_l = len(idx_batch[u][l])
182 | n_samples_for_u += n_samples_for_l
183 | train_info += "c={},n={}| ".format(l, n_samples_for_l)
184 | print(train_info)
185 | print("{} Labels/ {} Number of training samples for user [{}]:".format(len(Labels[u]), n_samples_for_u, u))
186 | return Labels, idx_batch, samples_per_user
187 |
188 |
189 | print(f"Reading source dataset.")
190 | train_data, n_train_sample, SRC_N_CLASS = get_dataset(mode='train')
191 | test_data, n_test_sample, SRC_N_CLASS = get_dataset(mode='test')
192 | SRC_CLASSES=[l for l in range(SRC_N_CLASS)]
193 | random.shuffle(SRC_CLASSES)
194 | print("{} labels in total.".format(len(SRC_CLASSES)))
195 | Labels, idx_batch, samples_per_user = process_user_data('train', train_data, n_train_sample, SRC_CLASSES)
196 | process_user_data('test', test_data, n_test_sample, SRC_CLASSES, Labels=Labels, unknown_test=args.unknown_test)
197 | print("Finish Generating User samples")
198 |
199 | if __name__ == "__main__":
200 | main()
--------------------------------------------------------------------------------
/FLAlgorithms/trainmodel/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | MAXLOG = 0.1
5 | from torch.autograd import Variable
6 | import collections
7 | import numpy as np
8 | from utils.model_config import GENERATORCONFIGS
9 |
10 |
11 | class Generator(nn.Module):
12 | def __init__(self, dataset='mnist', model='cnn', embedding=False, latent_layer_idx=-1):
13 | super(Generator, self).__init__()
14 | print("Dataset {}".format(dataset))
15 | self.embedding = embedding
16 | self.dataset = dataset
17 | #self.model=model
18 | self.latent_layer_idx = latent_layer_idx
19 | self.hidden_dim, self.latent_dim, self.input_channel, self.n_class, self.noise_dim = GENERATORCONFIGS[dataset]
20 | input_dim = self.noise_dim * 2 if self.embedding else self.noise_dim + self.n_class
21 | self.fc_configs = [input_dim, self.hidden_dim]
22 | self.init_loss_fn()
23 | self.build_network()
24 |
25 | def get_number_of_parameters(self):
26 | pytorch_total_params=sum(p.numel() for p in self.parameters() if p.requires_grad)
27 | return pytorch_total_params
28 |
29 | def init_loss_fn(self):
30 | self.crossentropy_loss=nn.NLLLoss(reduction='none') # element-wise (unreduced) loss
31 | self.diversity_loss = DiversityLoss(metric='l1')
32 | self.dist_loss = nn.MSELoss()
33 |
34 | def build_network(self):
35 | if self.embedding:
36 | self.embedding_layer = nn.Embedding(self.n_class, self.noise_dim)
37 | ### FC modules ####
38 | self.fc_layers = nn.ModuleList()
39 | self.fc_configs[0] += 1
40 | for i in range(len(self.fc_configs) - 1):
41 | input_dim, out_dim = self.fc_configs[i], self.fc_configs[i + 1]
42 | print("Build layer {} X {}".format(input_dim, out_dim))
43 | fc = nn.Linear(input_dim, out_dim)
44 | bn = nn.BatchNorm1d(out_dim)
45 | act = nn.ReLU()
46 | self.fc_layers += [fc, bn, act]
47 | ### Representation layer
48 | self.representation_layer = nn.Linear(self.fc_configs[-1], self.latent_dim)
49 | print("Build last layer {} X {}".format(self.fc_configs[-1], self.latent_dim))
50 |
51 | def forward(self, labels, Real_CL_Results, latent_layer_idx=-1, verbose=True):
52 | """
53 | G(Z|y) or G(X|y):
54 | Generate either latent representation( latent_layer_idx < 0) or raw image (latent_layer_idx=0) conditional on labels.
55 | :param labels:
56 | :param latent_layer_idx:
57 | if -1, generate latent representation of the last layer,
58 | -2 for the 2nd to last layer, 0 for raw images.
59 | :param verbose: also return the sampled Gaussian noise if verbose = True
60 | :return: a dictionary of output information.
61 | """
62 | result = {}
63 | batch_size = labels.shape[0]
64 | eps = torch.rand((batch_size, self.noise_dim)) # sample noise (torch.rand is uniform, not Gaussian)
65 | if verbose:
66 | result['eps'] = eps
67 | if self.embedding: # embedded dense vector
68 | y_input = self.embedding_layer(labels)
69 | else: # one-hot (sparse) vector
70 | y_input = torch.FloatTensor(batch_size, self.n_class)
71 | y_input.zero_()
72 | #labels = labels.view
73 | labels_int64 = labels.type(torch.LongTensor)
74 | y_input.scatter_(1, labels_int64.view(-1,1), 1)
75 | z = torch.cat((eps, y_input), dim=1)
76 | #print(z.size())
77 | #z = z * Real_CL_Resultsw
78 |
79 | z = torch.cat((z, Real_CL_Results), dim=1)
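# Real_CL_Results is appended as an extra conditioning input; this appears to correspond to
# the one additional input unit reserved in build_network (self.fc_configs[0] += 1) and to
# carry the per-sample curriculum-learning score used by FedCL.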
80 |
81 | ### FC layers
82 | for layer in self.fc_layers:
83 | z = layer(z)
84 | z = self.representation_layer(z)
85 | result['output'] = z
86 | return result
87 |
88 | @staticmethod
89 | def normalize_images(layer):
90 | """
91 | Normalize images into zero-mean and unit-variance.
92 | """
93 | mean = layer.mean(dim=(2, 3), keepdim=True)
94 | std = layer.view((layer.size(0), layer.size(1), -1)) \
95 | .std(dim=2, keepdim=True).unsqueeze(3)
96 | return (layer - mean) / std
97 |
98 | class Discriminator(nn.Module):
99 | def __init__(self, size=0,dataset='mnist', model='cnn', embedding=False):
100 | super(Discriminator, self).__init__()
101 | self.size = size
102 | self.embedding = embedding
103 | self.dataset = dataset
104 | self.hidden_dim, self.latent_dim, self.input_channel, self.n_class, self.noise_dim = GENERATORCONFIGS[dataset]
105 | self.init_loss_fn()
106 | self.model_1 = nn.Sequential(
107 | nn.Linear(self.latent_dim, 128),
108 | nn.LeakyReLU(0.2, inplace=True),
109 | nn.Linear(128, 64),
110 | nn.LeakyReLU(0.2, inplace=True),
111 | nn.Linear(64, self.latent_dim),
112 | nn.Sigmoid(),
113 | )
114 | self.model_2 = nn.Sequential(
115 | nn.Linear(self.latent_dim, 128),
116 | nn.LeakyReLU(0.2, inplace=True),
117 | nn.Linear(128, 64),
118 | nn.LeakyReLU(0.2, inplace=True),
119 | nn.Linear(64, 1),
120 | nn.Sigmoid(),
121 | )
122 | def forward(self, img,cl_score):
123 | img = self.model_1(img).squeeze(-1)
124 | cl_score = self.model_2(cl_score)
125 | return img, cl_score
126 |
127 | def init_loss_fn(self):
128 | self.crossentropy_loss=nn.NLLLoss(reduction='none') # element-wise (unreduced) loss
129 | self.diversity_loss = DiversityLoss(metric='l1')
130 | self.dist_loss = nn.MSELoss()
131 | #
132 | # class Decoder(nn.Module):
133 | # """
134 | # Decoder for both unstructured and image datasets.
135 | # """
136 | # def __init__(self, dataset='mnist', latent_layer_idx=-1, n_layers=2, units=32):
137 | # """
138 | # Class initializer.
139 | # """
140 | # #in_features, out_targets, n_layers=2, units=32):
141 | # super(Decoder, self).__init__()
142 | # self.cv_configs, self.input_channel, self.n_class, self.scale, self.noise_dim = GENERATORCONFIGS[dataset]
143 | # self.hidden_dim = self.scale * self.scale * self.cv_configs[0]
144 | # self.latent_dim = self.cv_configs[0] * 2
145 | # self.represent_dims = [self.hidden_dim, self.latent_dim]
146 | # in_features = self.represent_dims[latent_layer_idx]
147 | # out_targets = self.noise_dim
148 | #
149 | # # build layer structure
150 | # layers = [nn.Linear(in_features, units),
151 | # nn.ELU(),
152 | # nn.BatchNorm1d(units)]
153 | #
154 | # for _ in range(n_layers):
155 | # layers.extend([
156 | # nn.Linear(units, units),
157 | # nn.ELU(),
158 | # nn.BatchNorm1d(units)])
159 | #
160 | # layers.append(nn.Linear(units, out_targets))
161 | # self.layers = nn.Sequential(*layers)
162 | #
163 | # def forward(self, x):
164 | # """
165 | # Forward propagation.
166 | # """
167 | # out = x.view((x.size(0), -1))
168 | # out = self.layers(out)
169 | # return out
170 |
171 | class DivLoss(nn.Module):
172 | """
173 | Diversity loss for improving the performance.
174 | """
175 |
176 | def __init__(self):
177 | """
178 | Class initializer.
179 | """
180 | super().__init__()
181 |
182 | def forward2(self, noises, layer):
183 | """
184 | Forward propagation.
185 | """
186 | if len(layer.shape) > 2:
187 | layer = layer.view((layer.size(0), -1))
188 | chunk_size = layer.size(0) // 2
189 |
190 | ####### diversity loss ########
191 | eps1, eps2=torch.split(noises, chunk_size, dim=0)
192 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0)
193 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean(
194 | torch.abs(eps1 - eps2))
195 | eps=1 * 1e-5
196 | diversity_loss=1 / (lz + eps)
197 | return diversity_loss
198 |
199 | def forward(self, noises, layer):
200 | """
201 | Forward propagation.
202 | """
203 | if len(layer.shape) > 2:
204 | layer=layer.view((layer.size(0), -1))
205 | chunk_size=layer.size(0) // 2
206 |
207 | ####### diversity loss ########
208 | eps1, eps2=torch.split(noises, chunk_size, dim=0)
209 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0)
210 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean(
211 | torch.abs(eps1 - eps2))
212 | eps=1 * 1e-5
213 | diversity_loss=1 / (lz + eps)
214 | return diversity_loss
215 |
216 | class DiversityLoss(nn.Module):
217 | """
218 | Diversity loss for improving the performance.
219 | """
220 |
221 | def __init__(self, metric):
222 | """
223 | Class initializer.
224 | """
225 | super().__init__()
226 | self.metric = metric
227 | self.cosine = nn.CosineSimilarity(dim=2)
228 |
229 | def compute_distance(self, tensor1, tensor2, metric):
230 | """
231 | Compute the distance between two tensors.
232 | """
233 | if metric == 'l1':
234 | return torch.abs(tensor1 - tensor2).mean(dim=(2,))
235 | elif metric == 'l2':
236 | return torch.pow(tensor1 - tensor2, 2).mean(dim=(2,))
237 | elif metric == 'cosine':
238 | return 1 - self.cosine(tensor1, tensor2)
239 | else:
240 | raise ValueError(metric)
241 |
242 | def pairwise_distance(self, tensor, how):
243 | """
244 | Compute the pairwise distances between a Tensor's rows.
245 | """
246 | n_data = tensor.size(0)
247 | tensor1 = tensor.expand((n_data, n_data, tensor.size(1)))
248 | tensor2 = tensor.unsqueeze(dim=1)
249 | return self.compute_distance(tensor1, tensor2, how)
250 |
251 | def forward(self, noises, layer):
252 | """
253 | Forward propagation.
254 | """
255 | if len(layer.shape) > 2:
256 | layer = layer.view((layer.size(0), -1))
257 | layer_dist = self.pairwise_distance(layer, how=self.metric)
258 | noise_dist = self.pairwise_distance(noises, how='l2')
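# If the generated outputs collapse (small pairwise distances), the exponent is near zero and
# the loss is close to 1; diverse outputs drive the loss toward 0, discouraging mode collapse.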
259 | return torch.exp(torch.mean(-noise_dist * layer_dist))
260 |
--------------------------------------------------------------------------------
/FLAlgorithms/servers/serverbase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import h5py
5 | from utils.model_utils import get_dataset_name, RUNCONFIGS
6 | import copy
7 | import torch.nn.functional as F
8 | import time
9 | import torch.nn as nn
10 | from utils.model_utils import get_log_path, METRICS
11 |
12 | class Server:
13 | def __init__(self, args, model, seed):
14 |
15 | # Set up the main attributes
16 | self.dataset = args.dataset
17 | self.num_glob_iters = args.num_glob_iters
18 | self.local_epochs = args.local_epochs
19 | self.batch_size = args.batch_size
20 | self.learning_rate = args.learning_rate
21 | self.total_train_samples = 0
22 | self.K = args.K
23 | self.model = copy.deepcopy(model[0])
24 | self.model_name = model[1]
25 | self.users = []
26 | self.selected_users = []
27 | self.num_users = args.num_users
28 | self.beta = args.beta
29 | self.lamda = args.lamda
30 | self.algorithm = args.algorithm
31 | self.personalized = 'pFed' in self.algorithm
32 | self.mode='partial' if 'partial' in self.algorithm.lower() else 'all'
33 | self.seed = seed
34 | self.deviations = {}
35 | self.metrics = {key:[] for key in METRICS}
36 | self.timestamp = None
37 | self.save_path = args.result_path
38 | os.system("mkdir -p {}".format(self.save_path))
39 |
40 |
41 | def init_ensemble_configs(self):
42 | #### used for ensemble learning ####
43 | dataset_name = get_dataset_name(self.dataset)
44 | self.ensemble_lr = RUNCONFIGS[dataset_name].get('ensemble_lr', 1e-4)
45 | self.ensemble_batch_size = RUNCONFIGS[dataset_name].get('ensemble_batch_size', 128)
46 | self.ensemble_epochs = RUNCONFIGS[dataset_name]['ensemble_epochs']
47 | self.num_pretrain_iters = RUNCONFIGS[dataset_name]['num_pretrain_iters']
48 | self.temperature = RUNCONFIGS[dataset_name].get('temperature', 1)
49 | self.unique_labels = RUNCONFIGS[dataset_name]['unique_labels']
50 | self.ensemble_alpha = RUNCONFIGS[dataset_name].get('ensemble_alpha', 1)
51 | self.ensemble_beta = RUNCONFIGS[dataset_name].get('ensemble_beta', 0)
52 | self.ensemble_eta = RUNCONFIGS[dataset_name].get('ensemble_eta', 1)
53 | self.weight_decay = RUNCONFIGS[dataset_name].get('weight_decay', 0)
54 | self.generative_alpha = RUNCONFIGS[dataset_name]['generative_alpha']
55 | self.generative_beta = RUNCONFIGS[dataset_name]['generative_beta']
56 | self.ensemble_train_loss = []
57 | self.n_teacher_iters = 5
58 | self.n_student_iters = 1
59 | print("ensemble_lr: {}".format(self.ensemble_lr) )
60 | print("ensemble_batch_size: {}".format(self.ensemble_batch_size) )
61 | print("unique_labels: {}".format(self.unique_labels) )
62 |
63 |
64 | def if_personalized(self):
65 | return 'pFed' in self.algorithm or 'PerAvg' in self.algorithm
66 |
67 | def if_ensemble(self):
68 | return 'FedE' in self.algorithm
69 |
70 | def send_parameters(self, mode='all', beta=1, selected=False):
71 | users = self.users
72 | if selected:
73 | assert (self.selected_users is not None and len(self.selected_users) > 0)
74 | users = self.selected_users
75 | for user in users:
76 | if mode == 'all': # share all parameters
77 | user.set_parameters(self.model,beta=beta)
78 | else: # share only a subset of parameters
79 | user.set_shared_parameters(self.model,mode=mode)
80 |
81 |
82 | def add_parameters(self, user, ratio, partial=False):
83 | if partial:
84 | for server_param, user_param in zip(self.model.get_shared_parameters(), user.model.get_shared_parameters()):
85 | server_param.data = server_param.data + user_param.data.clone() * ratio
86 | else:
87 | for server_param, user_param in zip(self.model.parameters(), user.model.parameters()):
88 | server_param.data = server_param.data + user_param.data.clone() * ratio
89 |
90 |
91 |
92 | def aggregate_parameters(self,partial=False):
93 | assert (self.selected_users is not None and len(self.selected_users) > 0)
94 | if partial:
95 | for param in self.model.get_shared_parameters():
96 | param.data = torch.zeros_like(param.data)
97 | else:
98 | for param in self.model.parameters():
99 | param.data = torch.zeros_like(param.data)
100 | total_train = 0
101 | for user in self.selected_users:
102 | total_train += user.train_samples
103 | for user in self.selected_users:
104 | self.add_parameters(user, user.train_samples / total_train,partial=partial)
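# Illustrative sketch (hypothetical numbers, not from the repo): with two selected
# users holding 100 and 300 training samples, total_train = 400, so their models are
# merged with ratios 0.25 and 0.75, i.e.
#     w_server = 0.25 * w_user1 + 0.75 * w_user2
# which is the standard sample-weighted FedAvg aggregation.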
105 |
106 |
107 | def save_model(self):
108 | model_path = os.path.join("models", self.dataset)
109 | if not os.path.exists(model_path):
110 | os.makedirs(model_path)
111 | torch.save(self.model, os.path.join(model_path, "server" + ".pt"))
112 |
113 |
114 | def load_model(self):
115 | model_path = os.path.join("models", self.dataset, "server" + ".pt")
116 | assert (os.path.exists(model_path))
117 | self.model = torch.load(model_path)
118 |
119 | def model_exists(self):
120 | return os.path.exists(os.path.join("models", self.dataset, "server" + ".pt"))
121 |
122 | def select_users(self, round, num_users, return_idx=False):
123 | '''selects num_users clients uniformly at random (without replacement) from self.users
124 | Args:
125 | num_users: number of clients to select
126 | note that within the function, num_users is set to
127 | min(num_users, len(self.users))
128 | Return:
129 | list of selected client objects (and their indices when return_idx=True)
130 | '''
131 | if(num_users == len(self.users)):
132 | print("All users are selected")
133 | return self.users
134 |
135 | num_users = min(num_users, len(self.users))
136 | if return_idx:
137 | user_idxs = np.random.choice(range(len(self.users)), num_users, replace=False) # , p=pk)
138 | return [self.users[i] for i in user_idxs], user_idxs
139 | else:
140 | return np.random.choice(self.users, num_users, replace=False)
141 |
142 |
143 | def init_loss_fn(self):
144 | self.loss=nn.NLLLoss()
145 | self.ensemble_loss=nn.KLDivLoss(reduction="batchmean")#,log_target=True)
146 | self.ce_loss = nn.CrossEntropyLoss()
147 |
148 |
149 | def save_results(self, args):
150 | alg = get_log_path(args, args.algorithm, self.seed, args.gen_batch_size)
151 | with h5py.File("./{}/{}.h5".format(self.save_path, alg), 'w') as hf:
152 | for key in self.metrics:
153 | hf.create_dataset(key, data=self.metrics[key])
154 | hf.close()
155 |
156 |
157 | def test(self, selected=False):
158 | '''evaluates each client's current model on its local test data
159 | '''
160 | num_samples = []
161 | tot_correct = []
162 | losses = []
163 | users = self.selected_users if selected else self.users
164 | for c in users:
165 | ct, c_loss, ns = c.test()
166 | tot_correct.append(ct*1.0)
167 | num_samples.append(ns)
168 | losses.append(c_loss)
169 | ids = [c.id for c in self.users]
170 |
171 | return ids, num_samples, tot_correct, losses
172 |
173 |
174 |
175 | def test_personalized_model(self, selected=True):
176 | '''evaluates each client's personalized model on its local test data
177 | '''
178 | num_samples = []
179 | tot_correct = []
180 | losses = []
181 | users = self.selected_users if selected else self.users
182 | for c in users:
183 | ct, ns, loss = c.test_personalized_model()
184 | tot_correct.append(ct*1.0)
185 | num_samples.append(ns)
186 | losses.append(loss)
187 | ids = [c.id for c in self.users]
188 |
189 | return ids, num_samples, tot_correct, losses
190 |
191 | def evaluate_personalized_model(self, selected=True, save=True):
192 | stats = self.test_personalized_model(selected=selected)
193 | test_ids, test_num_samples, test_tot_correct, test_losses = stats[:4]
194 | glob_acc = np.sum(test_tot_correct)*1.0/np.sum(test_num_samples)
195 | test_loss = np.sum([x * y for (x, y) in zip(test_num_samples, test_losses)]).item() / np.sum(test_num_samples)
196 | if save:
197 | self.metrics['per_acc'].append(glob_acc)
198 | self.metrics['per_loss'].append(test_loss)
199 | print("Average Global Accurancy = {:.4f}, Loss = {:.2f}.".format(glob_acc, test_loss))
200 |
201 |
202 | def evaluate_ensemble(self, selected=True):
203 | self.model.eval()
204 | users = self.selected_users if selected else self.users
205 | test_acc=0
206 | loss=0
207 | for x, y in self.testloaderfull:
208 | target_logit_output=0
209 | for user in users:
210 | # get user logit
211 | user.model.eval()
212 | user_result=user.model(x, logit=True)
213 | target_logit_output+=user_result['logit']
214 | target_logp=F.log_softmax(target_logit_output, dim=1)
215 | test_acc += torch.sum(torch.argmax(target_logp, dim=1) == y)
216 | loss+=self.loss(target_logp, y)
217 | loss = loss.detach().numpy()
218 | test_acc = test_acc.detach().numpy() / y.shape[0]
219 | self.metrics['glob_acc'].append(test_acc)
220 | self.metrics['glob_loss'].append(loss)
221 | print("Average Global Accurancy = {:.4f}, Loss = {:.2f}.".format(test_acc, loss))
222 |
223 |
224 | def evaluate(self, save=True, selected=False):
225 | # override evaluate function to log vae-loss.
226 | test_ids, test_samples, test_accs, test_losses = self.test(selected=selected)
227 | glob_acc = np.sum(test_accs)*1.0/np.sum(test_samples)
228 | glob_loss = np.sum([x * y.detach() for (x, y) in zip(test_samples, test_losses)]).item() / np.sum(test_samples)
229 | if save:
230 | self.metrics['glob_acc'].append(glob_acc)
231 | self.metrics['glob_loss'].append(glob_loss)
232 | print("Average Global Accurancy = {:.4f}, Loss = {:.2f}.".format(glob_acc, glob_loss))
233 |
234 |
--------------------------------------------------------------------------------
/data/EMnist/generate_niid_dirichlet.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import numpy as np
3 | import random
4 | import json
5 | import os
6 | import argparse
7 | from torchvision.datasets import EMNIST
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torchvision
11 | import torchvision.transforms as transforms
12 |
13 | random.seed(42)
14 | np.random.seed(42)
15 |
16 | def rearrange_data_by_class(data, targets, n_class):
17 | new_data = []
18 | for i in trange(n_class):
19 | idx = targets == i
20 | new_data.append(data[idx])
21 | return new_data
22 |
23 | def get_dataset(mode='train', split='balanced'):
24 | transform = transforms.Compose(
25 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
26 |
27 | dataset = EMNIST(root='./data', split=split, train=True if mode=='train' else False, download=True, transform=transform)
28 | n_sample = len(dataset.data)
29 | SRC_N_CLASS = len(dataset.classes)
30 | # full batch
31 | trainloader = DataLoader(dataset, batch_size=n_sample, shuffle=False)
32 |
33 | print("Loading data from storage ...")
34 | for _, xy in enumerate(trainloader, 0):
35 | dataset.data, dataset.targets = xy
36 |
37 | print("Rearrange data by class...")
38 | data_by_class = rearrange_data_by_class(
39 | dataset.data.cpu().detach().numpy(),
40 | dataset.targets.cpu().detach().numpy(),
41 | SRC_N_CLASS
42 | )
43 | if split == 'letters':
44 | data_by_class.pop(0)
45 | SRC_N_CLASS-=1
46 |
47 | print(f"{mode.upper()} SET:\n Total #samples: {n_sample}. sample shape: {dataset.data[0].shape}")
48 | print(" #samples per class:\n", [len(v) for v in data_by_class])
49 |
50 | return data_by_class, n_sample, SRC_N_CLASS
51 |
52 | def sample_class(SRC_N_CLASS, NUM_LABELS, user_id, label_random=False):
53 | assert NUM_LABELS <= SRC_N_CLASS
54 | if label_random:
55 | source_classes = [n for n in range(SRC_N_CLASS)]
56 | random.shuffle(source_classes)
57 | return source_classes[:NUM_LABELS]
58 | else:
59 | return [(user_id + j) % SRC_N_CLASS for j in range(NUM_LABELS)]
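# Example (illustrative): with SRC_N_CLASS=26, NUM_LABELS=5 and label_random=False,
# user_id=3 receives the contiguous label window [3, 4, 5, 6, 7]; with
# label_random=True the 5 labels are drawn by shuffling range(26) instead.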
60 |
61 | def divide_train_data(data, n_sample, SRC_CLASSES, NUM_USERS, min_sample, alpha=0.5, sampling_ratio=0.5):
62 | min_sample = len(SRC_CLASSES) * min_sample
63 | min_size = 0 # track minimal samples per user
64 | ###### Determine Sampling #######
65 | while min_size < min_sample:
66 | print("Try to find valid data separation")
67 | idx_batch=[{} for _ in range(NUM_USERS)]
68 | samples_per_user = [0 for _ in range(NUM_USERS)]
69 | max_samples_per_user = sampling_ratio * n_sample / NUM_USERS
70 | for l in SRC_CLASSES:
71 | # get indices for all that label
72 | idx_l = [i for i in range(len(data[l]))]
73 | np.random.shuffle(idx_l)
74 | if sampling_ratio < 1:
75 | samples_for_l = int(min(max_samples_per_user, sampling_ratio * len(data[l])))
76 | idx_l = idx_l[:samples_for_l]
77 | print(l, len(data[l]), len(idx_l))
78 | # dirichlet sampling from this label
79 | proportions=np.random.dirichlet(np.repeat(alpha, NUM_USERS))
80 | # re-balance proportions
81 | proportions=np.array([p * (n_per_user < max_samples_per_user) for p, n_per_user in zip(proportions, samples_per_user)])
82 | proportions=proportions / proportions.sum()
83 | proportions=(np.cumsum(proportions) * len(idx_l)).astype(int)[:-1]
84 | # partition data of this label across users
85 | for u, new_idx in enumerate(np.split(idx_l, proportions)):
86 | # add the new indices to this user
87 | idx_batch[u][l] = new_idx.tolist()
88 | samples_per_user[u] += len(idx_batch[u][l])
89 | min_size=min(samples_per_user)
90 |
91 | ###### CREATE USER DATA SPLIT #######
92 | X = [[] for _ in range(NUM_USERS)]
93 | y = [[] for _ in range(NUM_USERS)]
94 | Labels=[set() for _ in range(NUM_USERS)]
95 | print("processing users...")
96 | for u, user_idx_batch in enumerate(idx_batch):
97 | for l, indices in user_idx_batch.items():
98 | if len(indices) == 0: continue
99 | X[u] += data[l][indices].tolist()
100 | y[u] += (l * np.ones(len(indices))).tolist()
101 | Labels[u].add(l)
102 |
103 | return X, y, Labels, idx_batch, samples_per_user
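# Illustrative note on the Dirichlet split above (assumed values, not from the repo):
# for each label, np.random.dirichlet(np.repeat(alpha, NUM_USERS)) draws one proportion
# per user. With alpha=0.5 and 3 users a draw might look like [0.72, 0.21, 0.07], so
# most of that label lands on one user; with a large alpha (e.g. 100) the draw
# concentrates near [0.33, 0.33, 0.33] and the split becomes nearly IID. Smaller alpha
# therefore means stronger label heterogeneity.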
104 |
105 |
106 | def divide_test_data(NUM_USERS, SRC_CLASSES, test_data, Labels, unknown_test):
107 | # Create TEST data for each user.
108 | test_X = [[] for _ in range(NUM_USERS)]
109 | test_y = [[] for _ in range(NUM_USERS)]
110 | idx = {l: 0 for l in SRC_CLASSES}
111 | for user in trange(NUM_USERS):
112 | if unknown_test: # use all available labels
113 | user_sampled_labels = SRC_CLASSES
114 | else:
115 | user_sampled_labels = list(Labels[user])
116 | for l in user_sampled_labels:
117 | num_samples = int(len(test_data[l]) / NUM_USERS )
118 | assert num_samples + idx[l] <= len(test_data[l])
119 | test_X[user] += test_data[l][idx[l]:idx[l] + num_samples].tolist()
120 | test_y[user] += (l * np.ones(num_samples)).tolist()
121 | assert len(test_X[user]) == len(test_y[user]), f"{len(test_X[user])} == {len(test_y[user])}"
122 | idx[l] += num_samples
123 | return test_X, test_y
124 |
125 | def main():
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument("--format", "-f", type=str, default="pt", help="Format of saving: pt (torch.save), json", choices=["pt", "json"])
128 | parser.add_argument("--n_class", type=int, default=5, help="number of classification labels")
129 | parser.add_argument("--random_label", type=int, default=1, help="randomly sampling labels for each task.")
130 | parser.add_argument("--dirname", type=str, default='', help="Directory name, optional.")
131 | parser.add_argument("--split", type=str, default='letters', choices=["letters", "balanced"])
132 | parser.add_argument("--min_sample", type=int, default=10, help="Min number of samples per user.")
133 | parser.add_argument("--sampling_ratio", type=float, default=0.5, help="Ratio for sampling training samples.")
134 | parser.add_argument("--unknown_test", type=int, default=0, help="Whether allow test label unseen for each user.")
135 | parser.add_argument("--alpha", type=float, default=0.5, help="alpha in Dirichelt distribution (smaller means larger heterogeneity)")
136 | parser.add_argument("--n_user", type=int, default=20,
137 | help="number of local clients, should be muitiple of 10.")
138 | args = parser.parse_args()
139 | print()
140 | print("Number of users: {}".format(args.n_user))
141 | print("Number of classes: {}".format(args.n_class))
142 | print("Min # of samples per uesr: {}".format(args.min_sample))
143 | print("Alpha for Dirichlet Distribution: {}".format(args.alpha))
144 | print("Ratio for Sampling Training Data: {}".format(args.sampling_ratio))
145 | NUM_USERS = args.n_user
146 |
147 | # Setup directory for train/test data
148 | path_prefix = f'u{args.n_user}-{args.split}-alpha{args.alpha}-ratio{args.sampling_ratio}'
149 |
150 | def process_user_data(mode, data, n_sample, SRC_CLASSES, Labels=None, unknown_test=0):
151 | if mode == 'train':
152 | X, y, Labels, idx_batch, samples_per_user = divide_train_data(
153 | data, n_sample, SRC_CLASSES, NUM_USERS, args.min_sample, args.alpha, args.sampling_ratio)
154 | if mode == 'test':
155 | assert Labels is not None or unknown_test
156 | X, y = divide_test_data(NUM_USERS, SRC_CLASSES, data, Labels, unknown_test)
157 | dataset={'users': [], 'user_data': {}, 'num_samples': []}
158 | for i in range(NUM_USERS):
159 | uname='f_{0:05d}'.format(i)
160 | dataset['users'].append(uname)
161 | dataset['user_data'][uname]={
162 | 'x': torch.tensor(X[i], dtype=torch.float32),
163 | 'y': torch.tensor(y[i], dtype=torch.int64)}
164 | dataset['num_samples'].append(len(X[i]))
165 |
166 | print("{} #sample by user:".format(mode.upper()), dataset['num_samples'])
167 |
168 | data_path=f'./{path_prefix}/{mode}'
169 | if not os.path.exists(data_path):
170 | os.makedirs(data_path)
171 |
172 | data_path=os.path.join(data_path, "{}.".format(mode) + args.format)
173 | if args.format == "json":
174 | raise NotImplementedError(
175 | "json is not supported because the train_data/test_data uses the tensor instead of list and tensor cannot be saved into json.")
176 | with open(data_path, 'w') as outfile:
177 | print(f"Dumping train data => {data_path}")
178 | json.dump(dataset, outfile)
179 | elif args.format == "pt":
180 | with open(data_path, 'wb') as outfile:
181 | print(f"Dumping train data => {data_path}")
182 | torch.save(dataset, outfile)
183 | if mode == 'train':
184 | for u in range(NUM_USERS):
185 | print("{} samples in total".format(samples_per_user[u]))
186 | train_info = ''
187 | # train_idx_batch, train_samples_per_user
188 | n_samples_for_u = 0
189 | for l in sorted(list(Labels[u])):
190 | n_samples_for_l = len(idx_batch[u][l])
191 | n_samples_for_u += n_samples_for_l
192 | train_info += "c={},n={}| ".format(l, n_samples_for_l)
193 | print(train_info)
194 | print("{} Labels/ {} Number of training samples for user [{}]:".format(len(Labels[u]), n_samples_for_u, u))
195 | return Labels, idx_batch, samples_per_user
196 |
197 |
198 | print(f"Reading source dataset.")
199 | train_data, n_train_sample, SRC_N_CLASS = get_dataset(mode='train', split=args.split)
200 | test_data, n_test_sample, SRC_N_CLASS = get_dataset(mode='test', split=args.split)
201 | SRC_CLASSES=[l for l in range(SRC_N_CLASS)]
202 | random.shuffle(SRC_CLASSES)
203 | Labels, idx_batch, samples_per_user = process_user_data('train', train_data, n_train_sample, SRC_CLASSES)
204 | process_user_data('test', test_data, n_test_sample, SRC_CLASSES, Labels=Labels, unknown_test=args.unknown_test)
205 | print("Finish Generating User samples")
206 |
207 |
208 | if __name__ == "__main__":
209 | main()
--------------------------------------------------------------------------------
/FLAlgorithms/servers/serverpFedGen.py:
--------------------------------------------------------------------------------
1 | from FLAlgorithms.users.userpFedGen import UserpFedGen
2 | from FLAlgorithms.servers.serverbase import Server
3 | from utils.model_utils import read_data, read_user_data, aggregate_user_data, create_generative_model
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from torchvision.utils import save_image
9 | import os
10 | import copy
11 | import time
12 | MIN_SAMPLES_PER_LABEL=1
13 |
14 | class FedGen(Server):
15 | def __init__(self, args, model, seed):
16 | super().__init__(args, model, seed)
17 |
18 | # Initialize data for all users
19 | data = read_data(args.dataset)
20 | # data contains: clients, groups, train_data, test_data, proxy_data
21 | clients = data[0]
22 | total_users = len(clients)
23 | self.total_test_samples = 0
24 | self.local = 'local' in self.algorithm.lower()
25 | self.use_adam = 'adam' in self.algorithm.lower()
26 |
27 | self.early_stop = 20 # stop using generated samples after 20 local epochs
28 | self.student_model = copy.deepcopy(self.model)
29 | self.generative_model = create_generative_model(args.dataset, args.algorithm, self.model_name, args.embedding)
30 | if not args.train:
31 | print('number of generator parameters: [{}]'.format(self.generative_model.get_number_of_parameters()))
32 | print('number of model parameters: [{}]'.format(self.model.get_number_of_parameters()))
33 | self.latent_layer_idx = self.generative_model.latent_layer_idx
34 | self.init_ensemble_configs()
35 | print("latent_layer_idx: {}".format(self.latent_layer_idx))
36 | print("label embedding {}".format(self.generative_model.embedding))
37 | print("ensemeble learning rate: {}".format(self.ensemble_lr))
38 | print("ensemeble alpha = {}, beta = {}, eta = {}".format(self.ensemble_alpha, self.ensemble_beta, self.ensemble_eta))
39 | print("generator alpha = {}, beta = {}".format(self.generative_alpha, self.generative_beta))
40 | self.init_loss_fn()
41 | self.train_data_loader, self.train_iter, self.available_labels = aggregate_user_data(data, args.dataset, self.ensemble_batch_size)
42 | self.generative_optimizer = torch.optim.Adam(
43 | params=self.generative_model.parameters(),
44 | lr=self.ensemble_lr, betas=(0.9, 0.999),
45 | eps=1e-08, weight_decay=self.weight_decay, amsgrad=False)
46 | self.generative_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
47 | optimizer=self.generative_optimizer, gamma=0.98)
48 | self.optimizer = torch.optim.Adam(
49 | params=self.model.parameters(),
50 | lr=self.ensemble_lr, betas=(0.9, 0.999),
51 | eps=1e-08, weight_decay=0, amsgrad=False)
52 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.98)
53 |
54 | #### creating users ####
55 | self.users = []
56 | for i in range(total_users):
57 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True)
58 | self.total_train_samples+=len(train_data)
59 | self.total_test_samples += len(test_data)
60 | id, train, test=read_user_data(i, data, dataset=args.dataset)
61 | user=UserpFedGen(
62 | args, id, model, self.generative_model,
63 | train_data, test_data,
64 | self.available_labels, self.latent_layer_idx, label_info,
65 | use_adam=self.use_adam)
66 | self.users.append(user)
67 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples)
68 | print("Data from {} users in total.".format(total_users))
69 | print("Finished creating FedAvg server.")
70 |
71 | def train(self, args):
72 | #### pretraining
73 | for glob_iter in range(self.num_glob_iters):
74 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n")
75 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True)
76 | if not self.local:
77 | self.send_parameters(mode=self.mode)# broadcast averaged prediction model
78 | self.evaluate()
79 | chosen_verbose_user = np.random.randint(0, len(self.users))
80 | self.timestamp = time.time() # log user-training start time
81 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train
82 | verbose= user_id == chosen_verbose_user
83 | # perform regularization using generated samples after the first communication round
84 | user.train(
85 | glob_iter,
86 | personalized=self.personalized,
87 | early_stop=self.early_stop,
88 | verbose=verbose and glob_iter > 0,
89 | regularization= glob_iter > 0 )
90 | curr_timestamp = time.time() # log user-training end time
91 | train_time = (curr_timestamp - self.timestamp) / len(self.selected_users)
92 | self.metrics['user_train_time'].append(train_time)
93 | if self.personalized:
94 | self.evaluate_personalized_model()
95 |
96 | self.timestamp = time.time() # log server-agg start time
97 | self.train_generator(
98 | self.batch_size,
99 | epoches=self.ensemble_epochs // self.n_teacher_iters,
100 | latent_layer_idx=self.latent_layer_idx,
101 | verbose=True
102 | )
103 | self.aggregate_parameters()
104 | curr_timestamp=time.time() # log server-agg end time
105 | agg_time = curr_timestamp - self.timestamp
106 | self.metrics['server_agg_time'].append(agg_time)
107 | if glob_iter > 0 and glob_iter % 20 == 0 and self.latent_layer_idx == 0:
108 | self.visualize_images(self.generative_model, glob_iter, repeats=10)
109 |
110 | self.save_results(args)
111 | self.save_model()
112 |
113 | def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False):
114 | """
115 | Learn a generator that finds a consensus latent representation z for a given label y.
116 | :param batch_size: number of label samples drawn per generator iteration.
117 | :param epoches: number of generator training epochs.
118 | :param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer.
119 | :param verbose: print loss information.
120 | :return: None.
121 | """
122 | #self.generative_regularizer.train()
123 | self.label_weights, self.qualified_labels = self.get_label_weights()
124 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0
125 |
126 | def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS):
127 | self.generative_model.train()
128 | student_model.eval()
129 | for i in range(n_iters):
130 | self.generative_optimizer.zero_grad()
131 | y=np.random.choice(self.qualified_labels, batch_size)
132 | y_input=torch.LongTensor(y)
133 | ## feed to generator
134 | gen_result=self.generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True)
135 | # approximation of Z (latent representation) if latent is set to True, X (raw image) otherwise
136 | gen_output, eps=gen_result['output'], gen_result['eps']
137 | ##### get losses ####
138 | # decoded = self.generative_regularizer(gen_output)
139 | # regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps
140 | diversity_loss=self.generative_model.diversity_loss(eps, gen_output) # encourage different outputs
141 |
142 | ######### get teacher loss ############
143 | teacher_loss=0
144 | teacher_logit=0
145 | for user_idx, user in enumerate(self.selected_users):
146 | user.model.eval()
147 | weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
148 | expand_weight=np.tile(weight, (1, self.unique_labels))
149 | user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
150 | user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1)
151 | teacher_loss_=torch.mean( \
152 | self.generative_model.crossentropy_loss(user_output_logp_, y_input) * \
153 | torch.tensor(weight, dtype=torch.float32))
154 | teacher_loss+=teacher_loss_
155 | teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32)
156 |
157 | ######### get student loss ############
158 | student_output=student_model(gen_output, start_layer_idx=latent_layer_idx, logit=True)
159 | student_loss=F.kl_div(F.log_softmax(student_output['logit'], dim=1), F.softmax(teacher_logit, dim=1))
160 | if self.ensemble_beta > 0:
161 | loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss
162 | else:
163 | loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss
164 | loss.backward()
165 | self.generative_optimizer.step()
166 | TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item()
167 | STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item()
168 | DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item()
169 | return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS
170 |
171 | for i in range(epoches):
172 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS=update_generator_(
173 | self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
174 |
175 | TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
176 | STUDENT_LOSS = STUDENT_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
177 | DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches)
178 | info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, ". \
179 | format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS)
180 | if verbose:
181 | print(info)
182 | self.generative_lr_scheduler.step()
183 |
184 |
185 | def get_label_weights(self):
186 | label_weights = []
187 | qualified_labels = []
188 | for label in range(self.unique_labels):
189 | weights = []
190 | for user in self.selected_users:
191 | weights.append(user.label_counts[label])
192 | if np.max(weights) > MIN_SAMPLES_PER_LABEL:
193 | qualified_labels.append(label)
194 | # uniform
195 | label_weights.append( np.array(weights) / np.sum(weights) )
196 | label_weights = np.array(label_weights).reshape((self.unique_labels, -1))
197 | return label_weights, qualified_labels
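# Illustrative example (hypothetical counts): if two selected users hold 10 and 30
# samples of label 3, the row label_weights[3] becomes [0.25, 0.75], so the second
# user's logits dominate the teacher ensemble for that label; a label is "qualified"
# only if at least one user holds more than MIN_SAMPLES_PER_LABEL samples of it.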
198 |
199 | def visualize_images(self, generator, glob_iter, repeats=1):
200 | """
201 | Generate and visualize data for a generator.
202 | """
203 | os.system("mkdir -p images")
204 | path = f'images/{self.algorithm}-{self.dataset}-iter{glob_iter}.png'
205 | y=self.available_labels
206 | y = np.repeat(y, repeats=repeats, axis=0)
207 | y_input=torch.tensor(y)
208 | generator.eval()
209 | images=generator(y_input, latent=False)['output'] # 0,1,..,K, 0,1,...,K
210 | images=images.view(repeats, -1, *images.shape[1:])
211 | images=images.view(-1, *images.shape[2:])
212 | save_image(images.detach(), path, nrow=repeats, normalize=True)
213 | print("Image saved to {}".format(path))
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | from tqdm import trange
9 | import random
10 | import numpy as np
11 | from FLAlgorithms.trainmodel.models import Net
12 | from torch.utils.data import DataLoader
13 | from FLAlgorithms.trainmodel.generator import Generator
14 | from utils.model_config import *
15 | METRICS = ['glob_acc', 'per_acc', 'glob_loss', 'per_loss', 'user_train_time', 'server_agg_time']
16 |
17 |
18 | def get_data_dir(dataset):
19 | if 'EMnist' in dataset:
20 | #EMnist-alpha0.1-ratio0.1-0-letters
21 | dataset_=dataset.replace('alpha', '').replace('ratio', '').split('-')
22 | alpha, ratio =dataset_[1], dataset_[2]
23 | types = 'letters'
24 | path_prefix = os.path.join('data', 'EMnist', f'u20-{types}-alpha{alpha}-ratio{ratio}')
25 | train_data_dir=os.path.join(path_prefix, 'train')
26 | test_data_dir=os.path.join(path_prefix, 'test')
27 | proxy_data_dir = 'data/proxy_data/emnist-n10/'
28 |
29 | elif 'Mnist' in dataset:
30 | dataset_=dataset.replace('alpha', '').replace('ratio', '').split('-')
31 | alpha, ratio=dataset_[1], dataset_[2]
32 | #path_prefix=os.path.join('data', 'Mnist', 'u20alpha{}min10ratio{}'.format(alpha, ratio))
33 | path_prefix=os.path.join('data', 'Mnist', 'u20c10-alpha{}-ratio{}'.format(alpha, ratio))
34 | train_data_dir=os.path.join(path_prefix, 'train')
35 | test_data_dir=os.path.join(path_prefix, 'test')
36 | proxy_data_dir = 'data/proxy_data/mnist-n10/'
37 |
38 | elif 'celeb' in dataset.lower():
39 | dataset_ = dataset.lower().replace('user', '').replace('agg','').split('-')
40 | user, agg_user = dataset_[1], dataset_[2]
41 | path_prefix = os.path.join('data', 'CelebA', 'user{}-agg{}'.format(user,agg_user))
42 | train_data_dir=os.path.join(path_prefix, 'train')
43 | test_data_dir=os.path.join(path_prefix, 'test')
44 | proxy_data_dir=os.path.join('/user500/', 'proxy')
45 |
46 | else:
47 | raise ValueError("Dataset not recognized.")
48 | return train_data_dir, test_data_dir, proxy_data_dir
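# Example (illustrative): the dataset string 'EMnist-alpha0.1-ratio0.1' is parsed into
# alpha='0.1' and ratio='0.1', giving
#   train: data/EMnist/u20-letters-alpha0.1-ratio0.1/train
#   test:  data/EMnist/u20-letters-alpha0.1-ratio0.1/test
# which matches the folder naming produced by data/EMnist/generate_niid_dirichlet.py
# with its default n_user=20 and split='letters'.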
49 |
50 |
51 | def read_data(dataset):
52 | '''parses data in given train and test data directories
53 |
54 | assumes:
55 | - the data in the input directories are .json or .pt files with
56 | keys 'users' and 'user_data'
57 | - the set of train set users is the same as the set of test set users
58 |
59 | Return:
60 | clients: list of client ids
61 | groups: list of group ids; empty list if none found
62 | train_data: dictionary of train data
63 | test_data: dictionary of test data (a proxy_data dictionary is also returned; empty if unavailable)
64 | '''
65 | train_data_dir, test_data_dir, proxy_data_dir = get_data_dir(dataset)
66 | clients = []
67 | groups = []
68 | train_data = {}
69 | test_data = {}
70 | proxy_data = {}
71 |
72 | train_files = os.listdir(train_data_dir)
73 | train_files = [f for f in train_files if f.endswith('.json') or f.endswith(".pt")]
74 | for f in train_files:
75 | file_path = os.path.join(train_data_dir, f)
76 | if file_path.endswith("json"):
77 | with open(file_path, 'r') as inf:
78 | cdata = json.load(inf)
79 | elif file_path.endswith(".pt"):
80 | with open(file_path, 'rb') as inf:
81 | cdata = torch.load(inf)
82 | else:
83 | raise TypeError("Data format not recognized: {}".format(file_path))
84 |
85 | clients.extend(cdata['users'])
86 | if 'hierarchies' in cdata:
87 | groups.extend(cdata['hierarchies'])
88 | train_data.update(cdata['user_data'])
89 |
90 | clients = list(sorted(train_data.keys()))
91 |
92 | test_files = os.listdir(test_data_dir)
93 | test_files = [f for f in test_files if f.endswith('.json') or f.endswith(".pt")]
94 | for f in test_files:
95 | file_path = os.path.join(test_data_dir, f)
96 | if file_path.endswith(".pt"):
97 | with open(file_path, 'rb') as inf:
98 | cdata = torch.load(inf)
99 | elif file_path.endswith(".json"):
100 | with open(file_path, 'r') as inf:
101 | cdata = json.load(inf)
102 | else:
103 | raise TypeError("Data format not recognized: {}".format(file_path))
104 | test_data.update(cdata['user_data'])
105 |
106 |
107 | if proxy_data_dir and os.path.exists(proxy_data_dir):
108 | proxy_files=os.listdir(proxy_data_dir)
109 | proxy_files=[f for f in proxy_files if f.endswith('.json') or f.endswith(".pt")]
110 | for f in proxy_files:
111 | file_path=os.path.join(proxy_data_dir, f)
112 | if file_path.endswith(".pt"):
113 | with open(file_path, 'rb') as inf:
114 | cdata=torch.load(inf)
115 | elif file_path.endswith(".json"):
116 | with open(file_path, 'r') as inf:
117 | cdata=json.load(inf)
118 | else:
119 | raise TypeError("Data format not recognized: {}".format(file_path))
120 | proxy_data.update(cdata['user_data'])
121 |
122 | return clients, groups, train_data, test_data, proxy_data
123 |
124 |
125 | def read_proxy_data(proxy_data, dataset, batch_size):
126 | X, y=proxy_data['x'], proxy_data['y']
127 | X, y = convert_data(X, y, dataset=dataset)
128 | dataset = [(x, y) for x, y in zip(X, y)]
129 | proxyloader = DataLoader(dataset, batch_size, shuffle=True)
130 | iter_proxyloader = iter(proxyloader)
131 | return proxyloader, iter_proxyloader
132 |
133 |
134 | def aggregate_data_(clients, dataset, dataset_name, batch_size):
135 | combined = []
136 | unique_labels = []
137 | for i in range(len(dataset)):
138 | id = clients[i]
139 | user_data = dataset[id]
140 | X, y = convert_data(user_data['x'], user_data['y'], dataset=dataset_name)
141 | combined += [(x, y) for x, y in zip(X, y)]
142 | unique_y=torch.unique(y)
143 | unique_y = unique_y.detach().numpy()
144 | unique_labels += list(unique_y)
145 |
146 | data_loader=DataLoader(combined, batch_size, shuffle=True)
147 | iter_loader=iter(data_loader)
148 | return data_loader, iter_loader, unique_labels
149 |
150 |
151 | def aggregate_user_test_data(data, dataset_name, batch_size):
152 | clients, loaded_data=data[0], data[3]
153 | data_loader, _, unique_labels=aggregate_data_(clients, loaded_data, dataset_name, batch_size)
154 | return data_loader, np.unique(unique_labels)
155 |
156 |
157 | def aggregate_user_data(data, dataset_name, batch_size):
158 | # data contains: clients, groups, train_data, test_data, proxy_data
159 | clients, loaded_data = data[0], data[2]
160 | data_loader, data_iter, unique_labels = aggregate_data_(clients, loaded_data, dataset_name, batch_size)
161 | return data_loader, data_iter, np.unique(unique_labels)
162 |
163 |
164 | def convert_data(X, y, dataset=''):
165 | if not isinstance(X, torch.Tensor):
166 | if 'celeb' in dataset.lower():
167 | X=torch.Tensor(X).type(torch.float32).permute(0, 3, 1, 2)
168 | y=torch.Tensor(y).type(torch.int64)
169 |
170 | else:
171 | X=torch.Tensor(X).type(torch.float32)
172 | y=torch.Tensor(y).type(torch.int64)
173 | return X, y
174 |
175 |
176 | def read_user_data(index, data, dataset='', count_labels=False):
177 | #data contains: clients, groups, train_data, test_data, proxy_data(optional)
178 | id = data[0][index]
179 | train_data = data[2][id]
180 | test_data = data[3][id]
181 | X_train, y_train = convert_data(train_data['x'], train_data['y'], dataset=dataset)
182 | train_data = [(x, y) for x, y in zip(X_train, y_train)]
183 | X_test, y_test = convert_data(test_data['x'], test_data['y'], dataset=dataset)
184 | test_data = [(x, y) for x, y in zip(X_test, y_test)]
185 | if count_labels:
186 | label_info = {}
187 | unique_y, counts=torch.unique(y_train, return_counts=True)
188 | unique_y=unique_y.detach().numpy()
189 | counts=counts.detach().numpy()
190 | label_info['labels']=unique_y
191 | label_info['counts']=counts
192 | return id, train_data, test_data, label_info
193 | return id, train_data, test_data
194 |
195 |
196 | def get_dataset_name(dataset):
197 | dataset=dataset.lower()
198 | passed_dataset=dataset.lower()
199 | if 'celeb' in dataset:
200 | passed_dataset='celeb'
201 | elif 'emnist' in dataset:
202 | passed_dataset='emnist'
203 | elif 'mnist' in dataset:
204 | passed_dataset='mnist'
205 | else:
206 | raise ValueError('Unsupported dataset {}'.format(dataset))
207 | return passed_dataset
208 |
209 |
210 | def create_generative_model(dataset, algorithm='', model='cnn', embedding=False):
211 | passed_dataset=get_dataset_name(dataset)
212 | assert 'FedGen' in algorithm
213 | if 'FedGen' in algorithm:
214 | # temporary workaround to probe the sensitivity of the generator network & sampling size
215 | if 'cnn' in algorithm:
216 | gen_model = algorithm.split('-')[1]
217 | passed_dataset+='-' + gen_model
218 | elif '-gen' in algorithm: # use a more lightweight network for sensitivity analysis
219 | passed_dataset += '-cnn1'
220 | return Generator(passed_dataset, model=model, embedding=embedding, latent_layer_idx=-1)
221 |
222 |
223 | def create_model(model, dataset, algorithm):
224 | passed_dataset = get_dataset_name(dataset)
225 | model= Net(passed_dataset, model), model
226 | return model
227 |
228 |
229 | def polyak_move(params, target_params, ratio=0.1):
230 | for param, target_param in zip(params, target_params):
231 | param.data=param.data - ratio * (param.clone().detach().data - target_param.clone().detach().data)
232 |
233 | def meta_move(params, target_params, ratio):
234 | for param, target_param in zip(params, target_params):
235 | target_param.data = param.clone().data + ratio * (target_param.clone().data - param.clone().data)
236 |
237 | def moreau_loss(params, reg_params):
238 | # return 1/T \sum_i^T |param_i - reg_param_i|^2
239 | losses = []
240 | for param, reg_param in zip(params, reg_params):
241 | losses.append( torch.mean(torch.square(param - reg_param.clone().detach())) )
242 | loss = torch.mean(torch.stack(losses))
243 | return loss
244 |
245 | def l2_loss(params):
246 | losses = []
247 | for param in params:
248 | losses.append( torch.mean(torch.square(param)))
249 | loss = torch.mean(torch.stack(losses))
250 | return loss
251 |
252 | def update_fast_params(fast_weights, grads, lr, allow_unused=False):
253 | """
254 | Update fast_weights by applying grads.
255 | :param fast_weights: list of parameters.
256 | :param grads: list of gradients
257 | :param lr:
258 | :return: updated fast_weights .
259 | """
260 | for grad, fast_weight in zip(grads, fast_weights):
261 | if allow_unused and grad is None: continue
262 | grad=torch.clamp(grad, -10, 10)
263 | fast_weight.data = fast_weight.data.clone() - lr * grad
264 | return fast_weights
265 |
266 |
267 | def init_named_params(model, keywords=['encode']):
268 | named_params={}
269 | #named_params_list = []
270 | for name, params in model.named_layers.items():
271 | if any([key in name for key in keywords]):
272 | named_params[name]=[param.clone().detach().requires_grad_(True) for param in params]
273 | #named_params_list += named_params[name]
274 | return named_params#, named_params_list
275 |
276 |
277 |
278 | def get_log_path(args, algorithm, seed, gen_batch_size=32):
279 | alg=args.dataset + "_" + algorithm
280 | alg+="_" + str(args.learning_rate) + "_" + str(args.num_users)
281 | alg+="u" + "_" + str(args.batch_size) + "b" + "_" + str(args.local_epochs)
282 | alg=alg + "_" + str(seed)
283 | if 'FedGen' in algorithm: # to accompany experiments for author rebuttal
284 | alg += "_embed" + str(args.embedding)
285 | if int(gen_batch_size) != int(args.batch_size):
286 | alg += "_gb" + str(gen_batch_size)
287 | return alg
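# Example (illustrative, with assumed args): dataset='Mnist-alpha0.1-ratio0.5',
# algorithm='FedAvg', learning_rate=0.01, num_users=10, batch_size=32,
# local_epochs=20, seed=1 produces
#   'Mnist-alpha0.1-ratio0.5_FedAvg_0.01_10u_32b_20_1'
# FedGen runs additionally append '_embed<flag>' and, if the generator batch size
# differs from the training batch size, '_gb<gen_batch_size>'.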
--------------------------------------------------------------------------------
/FLAlgorithms/users/userpFedCL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import random
5 | import heapq
6 | from random import choice
7 | from sklearn.model_selection import train_test_split
8 | from FLAlgorithms.users.userbase import User
9 | from FLAlgorithms.curriculum.cl_score import CL_User_Score
10 | from FLAlgorithms.trainmodel.generator import Discriminator
11 | import pdb
12 | import decimal
13 | def median(x):
14 | x = sorted(x)
15 | length = len(x)
16 | mid, rem = divmod(length, 2)
17 | if rem:
18 | return x[:mid], x[mid+1:], x[mid]
19 | else:
20 | return x[:mid], x[mid:], x[mid-1]
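# Example (illustrative): median([3, 1, 5, 2, 4]) sorts to [1, 2, 3, 4, 5] and returns
# ([1, 2], [4, 5], 3); for an even length, median([4, 1, 3, 2]) returns ([1, 2], [3, 4], 2),
# i.e. lower half, upper half, and the (lower) median value.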
21 |
22 |
23 |
24 | class UserpFedCL(User):
25 | def __init__(self,
26 | args, id, model, generative_model,
27 | train_data, test_data,
28 | available_labels, latent_layer_idx, label_info,
29 | use_adam=False):
30 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam)
31 | self.gen_batch_size = args.gen_batch_size
32 | self.generative_model = generative_model
33 | self.Discriminator = Discriminator()
34 | self.latent_layer_idx = latent_layer_idx
35 | self.available_labels = available_labels
36 | self.label_info=label_info
37 | self.CL_User_Score=CL_User_Score
38 |
39 | self.optimizer_discriminator = torch.optim.Adam(
40 | params=self.model.parameters(),
41 | lr=1e-4, betas=(0.9, 0.999),
42 | eps=1e-08, weight_decay=0, amsgrad=False)
43 | self.lr_scheduler_discriminator = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer_discriminator, gamma=0.98)
44 |
45 |
46 | def exp_lr_scheduler(self, epoch, decay=0.98, init_lr=0.1, lr_decay_epoch=1):
47 | """Decay learning rate by a factor of 0.95 every lr_decay_epoch epochs."""
48 | lr= max(1e-4, init_lr * (decay ** (epoch // lr_decay_epoch)))
49 | return lr
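# Example (illustrative): with init_lr=0.1, decay=0.98 and lr_decay_epoch=1,
# epoch 10 yields 0.1 * 0.98**10 ≈ 0.0817; once the exponential decay falls
# below 1e-4 the returned rate is clipped at that floor.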
50 |
51 | def update_label_counts(self, labels, counts):
52 | for label, count in zip(labels, counts):
53 | self.label_counts[int(label)] += count
54 |
55 | def clean_up_counts(self):
56 | del self.label_counts
57 | self.label_counts = {label:1 for label in range(self.unique_labels)}
58 |
59 | def train(self, glob_iter, personalized=False, early_stop=100, regularization=True, verbose=False, run_curriculum=True,cl_score_norm_list=None,server_epoch=None,gmm_=None,gmm_len=None,next_stage=None):
60 | self.clean_up_counts()
61 | self.model.train()
62 | self.generative_model.eval()
63 | part_loss =0
64 | TEACHER_LOSS, DIST_LOSS, LATENT_LOSS, CL_SCORE = 0, 0, 0, 1
65 | exit_ = False
66 | for epoch in range(self.local_epochs):
67 | self.model.train()
68 | #print('epoch:'+str( epoch))
69 | #for i in range(self.K):
70 | self.optimizer.zero_grad()
71 | #### sample from real dataset (un-weighted)
72 | samples =self.get_next_train_batch(count_labels=True)
73 | X, y = samples['X'], samples['y']
74 | self.update_label_counts(samples['labels'], samples['counts'])
75 | model_result=self.model(X, logit=True)
76 | user_output_logp_ = model_result['output']
77 | user_output_logp = model_result['output']
78 | CL_results = self.CL_User_Score(model_result = model_result,
79 | Algorithms = 'SuperLoss_ce',
80 | loss_fun = self.loss,
81 | y = y,
82 | local_epoch = epoch,
83 | schedule = [epoch,self.local_epochs]
84 | )
85 | predictive_loss=CL_results['Loss']
86 | CL_Results_Score = CL_results['Curriculum_Learning_Score']
87 |
88 | #print(CL_results['score_list'])
89 | #print(CL_results['celoss'])
90 | # -----------------
91 | # Train Generator
92 | # -----------------
93 | #### sample y and generate z
94 | if regularization and epoch < early_stop:
95 |
96 | generative_alpha=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_alpha)
97 | generative_beta=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_beta)
98 | ### get generator output(latent representation) of the same label
99 | gmm_results,_ = gmm_.sample(1)
100 |
101 | real_cl_score_ = torch.tensor(gmm_results, dtype=torch.float)[:y.size()[0]].view(y.size()[0],1)
102 |
103 |
104 | #real_cl_score_ =torch.tensor(CL_results_score_list, dtype=torch.float).view(y.size()[0],1)
105 |
106 | gen_output_=self.generative_model(y, real_cl_score_, latent_layer_idx=self.latent_layer_idx)
107 | gen_output = gen_output_['output'].clone().detach()
108 | logit_given_gen=self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)['logit']
109 | target_p=F.softmax(logit_given_gen, dim=1).clone().detach()
110 | user_latent_loss= generative_beta * self.ensemble_loss(user_output_logp, target_p)
111 |
112 | sampled_y=np.random.choice(self.available_labels, self.gen_batch_size)
113 |
114 | sampled_y=torch.tensor(sampled_y)
115 | #CL_results_score_list_norm = (CL_results_score_list-min(cl_score_norm_list)) / (max(cl_score_norm_list) - min(cl_score_norm_list))
116 | #CL_results_score_list_norm = cl_score_norm_list
117 |
118 | l_Half, r_Half, q2 = median(gmm_.sample(gmm_len)[0].flatten())
119 | lHalf = median(l_Half)[2]
120 | rHalf = median(r_Half)[2]
121 | #50,100,150,200
122 | if next_stage==0:
123 | cl_score_fake_=[random.uniform(0,lHalf) for i in range(self.gen_batch_size)]
124 | elif next_stage==1:
125 | cl_score_fake_=[random.uniform(lHalf,rHalf) for i in range(self.gen_batch_size)]
126 | elif next_stage>=2:
127 | cl_score_fake_=[random.uniform(rHalf,1) for i in range(self.gen_batch_size)]
128 |
129 | #$print(cl_score_fake_)
130 | # easy_number = int(self.gen_batch_size/4*3)
131 | # hard_number = int(self.gen_batch_size - easy_number)
132 | # cl_score_fake_easy=[random.uniform(0,max(CL_results['score_list'])) for i in range(easy_number)]
133 | # cl_score_fake_hard=[random.uniform(min(CL_results['score_list']),0) for i in range(hard_number)]
134 | # cl_score_fake_ = cl_score_fake_easy + cl_score_fake_hard
135 | random.shuffle(cl_score_fake_)
136 | #print(float(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]]))
137 | cl_score_fake_=torch.tensor(cl_score_fake_, dtype=torch.float).view(self.gen_batch_size,1)
138 | gen_result=self.generative_model(sampled_y,cl_score_fake_, latent_layer_idx=self.latent_layer_idx)
139 |
140 |
141 | gen_output=gen_result['output'] # latent representation when latent = True, x otherwise
142 |
143 | user_output_logp = self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)
144 | CL_score_fake_results = self.CL_User_Score(model_result = user_output_logp,
145 | Algorithms = 'SuperLoss_ce',
146 | loss_fun = self.loss,
147 | y = sampled_y,
148 | local_epoch = epoch,
149 | schedule = [epoch,self.local_epochs]
150 | )['score_list_base']
151 |
152 | values_ = ((CL_score_fake_results-min(cl_score_norm_list)) / (max(cl_score_norm_list) - min(cl_score_norm_list)))
153 |
154 |
155 | #print(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]])
156 | #print(CL_score_fake_results)
157 | # print(sorted( values_,reverse=False))
158 | # print(sorted( values_,reverse=False)[int(len(values_)/1.1)])
159 | # exit()
160 | ##print(sorted( values_,reverse=False)[int(len(values_)*0.6)])
161 | #print(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]])
162 | # if next_stage:
163 | # if sorted( values_,reverse=False)[int(len(values_)*0.8)] < lHalf:
164 | # print(epoch)
165 | # break
166 | # else:
167 | # if sorted( values_,reverse=True)[int(len(values_)*0.8)] > lHalf:
168 | # print(epoch)
169 | # break
170 |
171 | if next_stage==0:
172 | if sorted(values_,reverse=False)[int(len(values_)*0.8)] < lHalf:
173 | print(epoch)
174 | exit_ = True
175 | break
176 | elif next_stage==1:
177 | if rHalf > sorted(values_,reverse=False)[int(len(values_)*0.8)] > lHalf :
178 | print(epoch)
179 | exit_ = True
180 | break
181 | else:
182 | if sorted( values_,reverse=False)[int(len(values_)*0.8)] > rHalf:
183 | print(epoch)
184 | exit_ = True
185 | break
186 | user_output_logp = user_output_logp['output']
187 |
188 | teacher_loss = generative_alpha * torch.mean(
189 | self.generative_model.crossentropy_loss(user_output_logp, sampled_y)
190 | )
191 | # scale the loss on synthetic data to account for the ratio of generated to real batch size
192 | gen_ratio = self.gen_batch_size / self.batch_size
193 | loss=predictive_loss + gen_ratio * teacher_loss + user_latent_loss
194 | TEACHER_LOSS+=teacher_loss
195 | LATENT_LOSS+=user_latent_loss
196 |
197 | # if all( value < CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]] for value in sorted( values_,reverse=False)[:int(len(values_)/1.2)]):
198 |
199 | # break
200 | else:
201 | #### get loss and perform optimization
202 |
203 | loss=predictive_loss
204 | loss.backward(retain_graph=True)
205 | self.optimizer.step()#self.local_model)
206 |
207 | # ---------------------
208 | # Train Discriminator
209 | # ---------------------
210 |
211 | if regularization and epoch < early_stop:
212 | pass
213 | # self.optimizer_discriminator.zero_grad()
214 | # print(X.size())
215 | # print(gen_output_['output'].size())
216 | # D_H_real_loss = self.dist_loss(X, torch.ones_like(user_output_logp_))
217 | # D_H_fake_loss = self.dist_loss(gen_output_, torch.zeros_like(gen_output))
218 |
219 |
220 |
221 | # local-model <=== self.model
222 | self.clone_model_paramenter(self.model.parameters(), self.local_model)
223 | if personalized:
224 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar)
225 | self.lr_scheduler.step(glob_iter)
226 |
227 | if regularization and verbose:
228 | try:
229 | TEACHER_LOSS=TEACHER_LOSS.detach().numpy() / (self.local_epochs * self.K)
230 | except:
231 | TEACHER_LOSS=TEACHER_LOSS / (self.local_epochs * self.K)
232 | try:
233 | LATENT_LOSS=LATENT_LOSS.detach().numpy() / (self.local_epochs * self.K)
234 | except:
235 | LATENT_LOSS=LATENT_LOSS / (self.local_epochs * self.K)
236 | info='\nUser Teacher Loss={:.4f}'.format(TEACHER_LOSS)
237 | info+=', Latent Loss={:.4f}'.format(LATENT_LOSS)
238 | print(info)
239 |
240 | if CL_Results_Score is not None:
241 | return CL_results['score_list'],part_loss,exit_
242 | else:
243 | return None
244 |
245 | def adjust_weights(self, samples):
246 | labels, counts = samples['labels'], samples['counts']
247 | #weight=self.label_weights[y][:, user_idx].reshape(-1, 1)
248 | np_y = samples['y'].detach().numpy()
249 | n_labels = samples['y'].shape[0]
250 | weights = np.array([n_labels / count for count in counts]) # smaller count --> larger weight
251 | weights = len(self.available_labels) * weights / np.sum(weights) # normalized
252 | label_weights = np.ones(self.unique_labels)
253 | label_weights[labels] = weights
254 | sample_weights = label_weights[np_y]
255 | return sample_weights
256 |
257 |
258 |
--------------------------------------------------------------------------------
/FLAlgorithms/curriculum/cl_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from cvxopt import matrix, spdiag, solvers
3 | import numpy as np
4 | import random
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import math
8 | from scipy.special import lambertw
9 | from scipy.special import binom
10 |
11 |
12 | class TripletLoss(nn.Module):
13 | """Triplet loss with hard positive/negative mining.
14 |
15 | Reference:
16 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
17 |
18 | Imported from `