├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── model_config.cpython-38.pyc │ ├── model_config.cpython-39.pyc │ ├── model_utils.cpython-38.pyc │ ├── model_utils.cpython-39.pyc │ ├── plot_utils.cpython-38.pyc │ └── plot_utils.cpython-39.pyc ├── model_config.py ├── model_config-base.py ├── plot_utils.py └── model_utils.py ├── 组会汇报FEDCL.pptx ├── requirements.txt ├── FLAlgorithms ├── users │ ├── __pycache__ │ │ ├── useravg.cpython-38.pyc │ │ ├── useravg.cpython-39.pyc │ │ ├── userbase.cpython-38.pyc │ │ ├── userbase.cpython-39.pyc │ │ ├── userFedProx.cpython-38.pyc │ │ ├── userFedProx.cpython-39.pyc │ │ ├── userpFedGen.cpython-38.pyc │ │ ├── userpFedGen.cpython-39.pyc │ │ ├── userFedDistill.cpython-38.pyc │ │ └── userFedDistill.cpython-39.pyc │ ├── useravg.py │ ├── userFedProx.py │ ├── userGen.py │ ├── userpFedGen.py │ ├── userFedDistill.py │ ├── userpFedEnsemble.py │ ├── userbase.py │ └── userpFedCL.py ├── servers │ ├── __pycache__ │ │ ├── serveravg.cpython-38.pyc │ │ ├── serveravg.cpython-39.pyc │ │ ├── serverbase.cpython-38.pyc │ │ ├── serverbase.cpython-39.pyc │ │ ├── serverFedProx.cpython-38.pyc │ │ ├── serverFedProx.cpython-39.pyc │ │ ├── serverpFedGen.cpython-38.pyc │ │ ├── serverpFedGen.cpython-39.pyc │ │ ├── serverFedDistill.cpython-38.pyc │ │ ├── serverFedDistill.cpython-39.pyc │ │ ├── serverpFedEnsemble.cpython-38.pyc │ │ └── serverpFedEnsemble.cpython-39.pyc │ ├── serverFedProx.py │ ├── serverpFedEnsemble.py │ ├── serveravg.py │ ├── serverFedDistill.py │ ├── serverbase.py │ ├── serverpFedGen.py │ └── serverpFedCL.py ├── trainmodel │ ├── __pycache__ │ │ ├── models.cpython-38.pyc │ │ ├── models.cpython-39.pyc │ │ ├── generator.cpython-38.pyc │ │ └── generator.cpython-39.pyc │ ├── models.py │ └── generator.py ├── curriculum │ ├── __pycache__ │ │ ├── cl_score.cpython-38.pyc │ │ └── cl_score.cpython-39.pyc │ └── cl_score.py └── optimizers │ ├── __pycache__ │ ├── fedoptimizer.cpython-38.pyc │ └── fedoptimizer.cpython-39.pyc │ ├── fedoptimizer.py │ └── loss.py ├── data ├── CelebA │ ├── README.md │ └── generate_niid_agg.py ├── Mnist │ └── generate_niid_dirichlet.py └── EMnist │ └── generate_niid_dirichlet.py ├── main_plot.py ├── main.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /组会汇报FEDCL.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/组会汇报FEDCL.pptx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | Pillow 4 | torchvision 5 | matplotlib 6 | tqdm 7 | h5py 8 | sklearn 9 | seaborn 10 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/model_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_config.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/model_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/plot_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/utils/__pycache__/plot_utils.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/useravg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/useravg.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/useravg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/useravg.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userbase.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userbase.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userbase.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userbase.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serveravg.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serveravg.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serveravg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serveravg.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedProx.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedProx.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedProx.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedProx.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userpFedGen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userpFedGen.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userpFedGen.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userpFedGen.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/curriculum/__pycache__/cl_score.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/curriculum/__pycache__/cl_score.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/curriculum/__pycache__/cl_score.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/curriculum/__pycache__/cl_score.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverbase.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverbase.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverbase.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverbase.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/generator.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/__pycache__/generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/trainmodel/__pycache__/generator.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedDistill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedDistill.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/users/__pycache__/userFedDistill.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/users/__pycache__/userFedDistill.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedProx.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedProx.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedProx.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedProx.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedGen.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/optimizers/__pycache__/fedoptimizer.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverFedDistill.cpython-39.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-38.pyc -------------------------------------------------------------------------------- /FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingjieWang0606/FedCL_Pubic/HEAD/FLAlgorithms/servers/__pycache__/serverpFedEnsemble.cpython-39.pyc -------------------------------------------------------------------------------- /data/CelebA/README.md: -------------------------------------------------------------------------------- 1 | #### To generated CelebA dataset 2 | ##### Step 1. 3 | follow the [LEAF instructions](https://github.com/TalwalkarLab/leaf/tree/master/data/celeba) to download the raw celeb data, and generate `train` and `test` subfolders. 4 | 5 | #### Step 2. 6 | change [`LOAD_PATH`](https://github.com/zhuangdizhu/FedGen/blob/05625ef130f681075fb04b804322e33ef31f6dea/data/CelebA/generate_niid_agg.py#L15) in the `generate_niid_agg.py` to point to the folder of the raw celeba downloaded in step 1. 7 | 8 | #### Step 3. 9 | run `generate_niid_agg.py` to generate FL training and testing dataset. 
10 | For example, to generate data for 25 FL devices, where each device contains images of 10 celebrities: 11 | ``` 12 | python generate_niid_agg.py --agg_user 10 --ratio 250 13 | ``` 14 | -------------------------------------------------------------------------------- /main_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import h5py 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import importlib 6 | import random 7 | import os 8 | import argparse 9 | from utils.plot_utils import * 10 | import torch 11 | torch.manual_seed(0) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--dataset", type=str, default="Mnist") 16 | parser.add_argument("--algorithms", type=str, default="FedAvg,Fedgen", help='algorithm names separate by comma') 17 | parser.add_argument("--result_path", type=str, default="results", help="directory path to save results") 18 | parser.add_argument("--model", type=str, default="cnn") 19 | parser.add_argument("--learning_rate", type=float, default=0.01, help='learning rate.') 20 | parser.add_argument("--local_epochs", type=int, default=20) 21 | parser.add_argument("--num_glob_iters", type=int, default=200) 22 | parser.add_argument("--min_acc", type=float, default=-1.0) 23 | parser.add_argument("--num_users", type=int, default=5, help='number of active users per epoch.') 24 | parser.add_argument("--batch_size", type=int, default=32) 25 | parser.add_argument("--gen_batch_size", type=int, default=32) 26 | parser.add_argument("--plot_legend", type=int, default=1, help='plot legend if set to 1, omitted otherwise.') 27 | parser.add_argument("--times", type=int, default=3, help='number of random seeds, starting from 1.') 28 | parser.add_argument("--embedding", type=int, default=0, help="Use embedding layer in generator network") 29 | args = parser.parse_args() 30 | 31 | algorithms = [a.strip() for a in args.algorithms.split(',')] 32 | title = 'epoch{}'.format(args.local_epochs) 33 | plot_results(args, algorithms) 34 | -------------------------------------------------------------------------------- /FLAlgorithms/users/useravg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from FLAlgorithms.users.userbase import User 3 | 4 | class UserAVG(User): 5 | def __init__(self, args, id, model, train_data, test_data, use_adam=False): 6 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 7 | 8 | def update_label_counts(self, labels, counts): 9 | for label, count in zip(labels, counts): 10 | self.label_counts[int(label)] += count 11 | 12 | 13 | def clean_up_counts(self): 14 | del self.label_counts 15 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 16 | 17 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 18 | self.clean_up_counts() 19 | self.model.train() 20 | for epoch in range(1, self.local_epochs + 1): 21 | self.model.train() 22 | for i in range(self.K): 23 | result =self.get_next_train_batch(count_labels=count_labels) 24 | X, y = result['X'], result['y'] 25 | if count_labels: 26 | self.update_label_counts(result['labels'], result['counts']) 27 | 28 | self.optimizer.zero_grad() 29 | output=self.model(X)['output'] 30 | loss=self.loss(output, y) 31 | loss.backward() 32 | self.optimizer.step()#self.plot_Celeb) 33 | 34 | # local-model <=== self.model 35 | self.clone_model_paramenter(self.model.parameters(), 
self.local_model) 36 | if personalized: 37 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 38 | # local-model ===> self.model 39 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 40 | if lr_decay: 41 | self.lr_scheduler.step(glob_iter) 42 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userFedProx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from FLAlgorithms.users.userbase import User 3 | from FLAlgorithms.optimizers.fedoptimizer import FedProxOptimizer 4 | 5 | class UserFedProx(User): 6 | def __init__(self, args, id, model, train_data, test_data, use_adam=False): 7 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 8 | 9 | self.optimizer = FedProxOptimizer(self.model.parameters(), lr=self.learning_rate, lamda=self.lamda) 10 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99) 11 | 12 | def update_label_counts(self, labels, counts): 13 | for label, count in zip(labels, counts): 14 | self.label_counts[int(label)] += count 15 | 16 | def clean_up_counts(self): 17 | del self.label_counts 18 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 19 | 20 | def train(self, glob_iter, lr_decay=True, count_labels=False): 21 | self.clean_up_counts() 22 | self.model.train() 23 | # cache global model initialized value to local model 24 | self.clone_model_paramenter(self.local_model, self.model.parameters()) 25 | for epoch in range(self.local_epochs): 26 | self.model.train() 27 | for i in range(self.K): 28 | result =self.get_next_train_batch(count_labels=count_labels) 29 | X, y = result['X'], result['y'] 30 | if count_labels: 31 | self.update_label_counts(result['labels'], result['counts']) 32 | 33 | self.optimizer.zero_grad() 34 | output=self.model(X)['output'] 35 | loss=self.loss(output, y) 36 | loss.backward() 37 | self.optimizer.step(self.local_model) 38 | if lr_decay: 39 | self.lr_scheduler.step(glob_iter) 40 | -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverFedProx.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userFedProx import UserFedProx 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data 4 | # Implementation for FedProx Server 5 | 6 | class FedProx(Server): 7 | def __init__(self, args, model, seed): 8 | #dataset, algorithm, model, batch_size, learning_rate, beta, lamda, num_glob_iters, 9 | # local_epochs, num_users, K, personal_learning_rate, times): 10 | super().__init__(args, model, seed)#dataset, algorithm, model, batch_size, learning_rate, beta, lamda, num_glob_iters, 11 | #local_epochs, num_users, times) 12 | 13 | # Initialize data for all users 14 | data = read_data(args.dataset) 15 | total_users = len(data[0]) 16 | print("Users in total: {}".format(total_users)) 17 | 18 | for i in range(total_users): 19 | id, train_data , test_data = read_user_data(i, data, dataset=args.dataset) 20 | user = UserFedProx(args, id, model, train_data, test_data, use_adam=False) 21 | self.users.append(user) 22 | self.total_train_samples += user.train_samples 23 | 24 | print("Number of users / total users:", self.num_users, " / " ,total_users) 25 | print("Finished creating FedAvg server.") 26 | 27 | def train(self, args): 28 | for glob_iter in 
range(self.num_glob_iters): 29 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 30 | self.selected_users = self.select_users(glob_iter,self.num_users) 31 | self.send_parameters() 32 | self.evaluate() 33 | for user in self.selected_users: # allow selected users to train 34 | user.train(glob_iter) 35 | self.aggregate_parameters() 36 | self.save_results(args) 37 | self.save_model() -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/fedoptimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | 3 | class pFedIBOptimizer(Optimizer): 4 | def __init__(self, params, lr=0.01): 5 | # self.local_weight_updated = local_weight # w_i,K 6 | if lr < 0.0: 7 | raise ValueError("Invalid learning rate: {}".format(lr)) 8 | defaults=dict(lr=lr) 9 | super(pFedIBOptimizer, self).__init__(params, defaults) 10 | 11 | def step(self, apply=True, lr=None, allow_unused=False): 12 | grads = [] 13 | # apply gradient to model.parameters, and return the gradients 14 | for group in self.param_groups: 15 | for p in group['params']: 16 | if p.grad is None and allow_unused: 17 | continue 18 | grads.append(p.grad.data) 19 | if apply: 20 | if lr == None: 21 | p.data= p.data - group['lr'] * p.grad.data 22 | else: 23 | p.data=p.data - lr * p.grad.data 24 | return grads 25 | 26 | 27 | def apply_grads(self, grads, beta=None, allow_unused=False): 28 | #apply gradient to model.parameters 29 | i = 0 30 | for group in self.param_groups: 31 | for p in group['params']: 32 | if p.grad is None and allow_unused: 33 | continue 34 | p.data= p.data - group['lr'] * grads[i] if beta == None else p.data - beta * grads[i] 35 | i += 1 36 | return 37 | 38 | 39 | class FedProxOptimizer(Optimizer): 40 | def __init__(self, params, lr=0.01, lamda=0.1, mu=0.001): 41 | if lr < 0.0: 42 | raise ValueError("Invalid learning rate: {}".format(lr)) 43 | defaults=dict(lr=lr, lamda=lamda, mu=mu) 44 | super(FedProxOptimizer, self).__init__(params, defaults) 45 | 46 | def step(self, vstar, closure=None): 47 | loss=None 48 | if closure is not None: 49 | loss=closure 50 | for group in self.param_groups: 51 | for p, pstar in zip(group['params'], vstar): 52 | # w <=== w - lr * ( w' + lambda * (w - w* ) + mu * w ) 53 | p.data=p.data - group['lr'] * ( 54 | p.grad.data + group['lamda'] * (p.data - pstar.data.clone()) + group['mu'] * p.data) 55 | return group['params'], loss 56 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | from FLAlgorithms.users.userbase import User 7 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 8 | from utils.model_utils import get_dataset_name, CONFIGS 9 | 10 | # Implementation for FedAvg clients 11 | 12 | class UserAVG(User): 13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 14 | local_epochs, K): 15 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 16 | local_epochs) 17 | 18 | dataset_name = get_dataset_name(dataset) 19 | self.unique_labels = CONFIGS[dataset_name]['unique_labels'] 20 | 21 | def update_label_counts(self, labels, counts): 22 | for label, count in zip(labels, 
counts): 23 | self.label_counts[int(label)] += count 24 | 25 | 26 | def clean_up_counts(self): 27 | del self.label_counts 28 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 29 | 30 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 31 | self.clean_up_counts() 32 | self.model.train() 33 | for epoch in range(1, self.local_epochs + 1): 34 | self.model.train() 35 | for i in range(self.K): 36 | result =self.get_next_train_batch(count_labels=count_labels) 37 | X, y = result['X'], result['y'] 38 | if count_labels: 39 | self.update_label_counts(result['labels'], result['counts']) 40 | 41 | self.optimizer.zero_grad() 42 | output=self.model(X)['output'] 43 | loss=self.loss(output, y) 44 | loss.backward() 45 | self.optimizer.step()#self.local_model) 46 | 47 | # local-model <=== self.model 48 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 49 | if personalized: 50 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 51 | # local-model ===> self.model 52 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 53 | if lr_decay: 54 | self.lr_scheduler.step(glob_iter) 55 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userpFedGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | from FLAlgorithms.users.userbase import User 7 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 8 | from utils.model_utils import get_dataset_name, CONFIGS 9 | 10 | # Implementation for FedAvg clients 11 | 12 | class UserAVG(User): 13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 14 | local_epochs, K): 15 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 16 | local_epochs) 17 | 18 | dataset_name = get_dataset_name(dataset) 19 | self.unique_labels = CONFIGS[dataset_name]['unique_labels'] 20 | 21 | def update_label_counts(self, labels, counts): 22 | for label, count in zip(labels, counts): 23 | self.label_counts[int(label)] += count 24 | 25 | 26 | def clean_up_counts(self): 27 | del self.label_counts 28 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 29 | 30 | def train(self, glob_iter, personalized=False, lr_decay=True, count_labels=True): 31 | self.clean_up_counts() 32 | self.model.train() 33 | for epoch in range(1, self.local_epochs + 1): 34 | self.model.train() 35 | for i in range(self.K): 36 | result =self.get_next_train_batch(count_labels=count_labels) 37 | X, y = result['X'], result['y'] 38 | if count_labels: 39 | self.update_label_counts(result['labels'], result['counts']) 40 | 41 | self.optimizer.zero_grad() 42 | output=self.model(X)['output'] 43 | loss=self.loss(output, y) 44 | loss.backward() 45 | self.optimizer.step()#self.local_model) 46 | 47 | # local-model <=== self.model 48 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 49 | if personalized: 50 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 51 | # local-model ===> self.model 52 | #self.clone_model_paramenter(self.local_model, self.model.parameters()) 53 | if lr_decay: 54 | self.lr_scheduler.step(glob_iter) 
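Aside: the `UserAVG`-style clients above (useravg.py, userGen.py, userpFedGen.py) share the same per-round bookkeeping: `clean_up_counts` resets every label count to 1, and `update_label_counts` accumulates the per-label counts reported by `get_next_train_batch`. The following is a minimal toy sketch of just that bookkeeping; the dict keys ('X', 'y', 'labels', 'counts') mirror the code above, while the tensors and shapes are made-up stand-ins for a real batch.

```
import torch

# Stand-in for one batch from get_next_train_batch(count_labels=True); data is illustrative only.
y = torch.tensor([3, 3, 7, 1, 3])
labels, counts = y.unique(return_counts=True)
batch = {'X': torch.randn(5, 1, 28, 28), 'y': y,
         'labels': labels.tolist(), 'counts': counts.tolist()}

# clean_up_counts(): every possible label starts at 1 rather than 0.
unique_labels = 10
label_counts = {label: 1 for label in range(unique_labels)}

# update_label_counts(): accumulate how many samples of each label this client has seen.
for label, count in zip(batch['labels'], batch['counts']):
    label_counts[int(label)] += count

print(label_counts)  # e.g. label 3 -> 4, labels 1 and 7 -> 2, all others stay at 1
```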
-------------------------------------------------------------------------------- /FLAlgorithms/servers/serverpFedEnsemble.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.useravg import UserAVG 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data, aggregate_user_test_data 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | class FedEnsemble(Server): 8 | def __init__(self, args, model, seed): 9 | super().__init__(args, model, seed) 10 | 11 | # Initialize data for all users 12 | data = read_data(args.dataset) 13 | # data contains: clients, groups, train_data, test_data, proxy_data 14 | clients = data[0] 15 | total_users = len(clients) 16 | self.total_test_samples = 0 17 | self.slow_start = 20 18 | self.use_adam = 'adam' in self.algorithm.lower() 19 | self.init_ensemble_configs() 20 | self.init_loss_fn() 21 | #### creating users #### 22 | self.users = [] 23 | for i in range(total_users): 24 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 25 | self.total_train_samples+=len(train_data) 26 | self.total_test_samples += len(test_data) 27 | user=UserAVG(args, id, model, train_data, test_data, use_adam=self.use_adam) 28 | self.users.append(user) 29 | 30 | #### build test data loader #### 31 | self.testloaderfull, self.unique_labels=aggregate_user_test_data(data, args.dataset, self.total_test_samples) 32 | print("Loading testing data.") 33 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples) 34 | print("Data from {} users in total.".format(total_users)) 35 | print("Finished creating FedAvg server.") 36 | 37 | def train(self, args): 38 | #### pretraining 39 | for glob_iter in range(self.num_glob_iters): 40 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 41 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 42 | self.send_parameters() 43 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train 44 | user.train( 45 | glob_iter, 46 | personalized=False, lr_decay=True, count_labels=True) 47 | self.aggregate_parameters() 48 | self.evaluate_ensemble(selected=False) 49 | 50 | self.save_results(args) 51 | self.save_model() 52 | -------------------------------------------------------------------------------- /FLAlgorithms/servers/serveravg.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.useravg import UserAVG 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data 4 | import numpy as np 5 | # Implementation for FedAvg Server 6 | import time 7 | 8 | class FedAvg(Server): 9 | def __init__(self, args, model, seed): 10 | super().__init__(args, model, seed) 11 | 12 | # Initialize data for all users 13 | data = read_data(args.dataset) 14 | total_users = len(data[0]) 15 | self.use_adam = 'adam' in self.algorithm.lower() 16 | print("Users in total: {}".format(total_users)) 17 | 18 | for i in range(total_users): 19 | id, train_data , test_data = read_user_data(i, data, dataset=args.dataset) 20 | user = UserAVG(args, id, model, train_data, test_data, use_adam=False) 21 | self.users.append(user) 22 | self.total_train_samples += user.train_samples 23 | 24 | print("Number of users / total users:",args.num_users, " / " ,total_users) 25 | print("Finished creating 
FedAvg server.") 26 | 27 | def train(self, args): 28 | for glob_iter in range(self.num_glob_iters): 29 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 30 | self.selected_users = self.select_users(glob_iter,self.num_users) 31 | self.send_parameters(mode=self.mode) 32 | self.evaluate() 33 | self.timestamp = time.time() # log user-training start time 34 | for user in self.selected_users: # allow selected users to train 35 | user.train(glob_iter, personalized=self.personalized) #* user.train_samples 36 | curr_timestamp = time.time() # log user-training end time 37 | train_time = (curr_timestamp - self.timestamp) / len(self.selected_users) 38 | self.metrics['user_train_time'].append(train_time) 39 | # Evaluate selected user 40 | if self.personalized: 41 | # Evaluate personal model on user for each iteration 42 | print("Evaluate personal model\n") 43 | self.evaluate_personalized_model() 44 | 45 | self.timestamp = time.time() # log server-agg start time 46 | self.aggregate_parameters() 47 | curr_timestamp=time.time() # log server-agg end time 48 | agg_time = curr_timestamp - self.timestamp 49 | self.metrics['server_agg_time'].append(agg_time) 50 | self.save_results(args) 51 | self.save_model() -------------------------------------------------------------------------------- /utils/model_config.py: -------------------------------------------------------------------------------- 1 | CONFIGS_ = { 2 | # input_channel, n_class, hidden_dim, latent_dim 3 | 'cifar': ([16, 'M', 32, 'M', 'F'], 3, 10, 2048, 64), 4 | 'cifar100-c25': ([32, 'M', 64, 'M', 128, 'F'], 3, 26, 128, 128), 5 | 'cifar100-c30': ([32, 'M', 64, 'M', 128, 'F'], 3, 30, 2048, 128), 6 | 'cifar100-c50': ([32, 'M', 64, 'M', 128, 'F'], 3, 50, 2048, 128), 7 | 8 | 'emnist': ([6, 16, 'F'], 1, 26, 784, 32), 9 | 'mnist': ([6, 16, 'F'], 1, 10, 784, 32), 10 | 'mnist_cnn1': ([6, 'M', 16, 'M', 'F'], 1, 10, 64, 32), 11 | 'mnist_cnn2': ([16, 'M', 32, 'M', 'F'], 1, 10, 128, 32), 12 | 'celeb': ([16, 'M', 32, 'M', 64, 'M', 'F'], 3, 2, 64, 32) 13 | } 14 | 15 | # temporary roundabout to evaluate sensitivity of the generator 16 | GENERATORCONFIGS = { 17 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 18 | 'cifar': (512, 32, 3, 10, 64), 19 | 'celeb': (128, 32, 3, 2, 32), 20 | 'mnist': (256, 32, 1, 10, 32), 21 | 'mnist-cnn0': (256, 32, 1, 10, 64), 22 | 'mnist-cnn1': (128, 32, 1, 10, 32), 23 | 'mnist-cnn2': (64, 32, 1, 10, 32), 24 | 'mnist-cnn3': (64, 32, 1, 10, 16), 25 | 'emnist': (256, 32, 1, 26, 32), 26 | 'emnist-cnn0': (256, 32, 1, 26, 64), 27 | 'emnist-cnn1': (128, 32, 1, 26, 32), 28 | 'emnist-cnn2': (128, 32, 1, 26, 16), 29 | 'emnist-cnn3': (64, 32, 1, 26, 32), 30 | } 31 | 32 | 33 | 34 | RUNCONFIGS = { 35 | 'emnist': 36 | { 37 | 'ensemble_lr': 1e-4, 38 | 'ensemble_batch_size': 128, 39 | 'ensemble_epochs': 50, 40 | 'num_pretrain_iters': 20, 41 | 'ensemble_alpha': 1, # teacher loss (server side) 42 | 'ensemble_beta': 0, # adversarial student loss 43 | 'unique_labels': 26, 44 | 'generative_alpha':10, 45 | 'generative_beta': 1, 46 | 'weight_decay': 1e-2 47 | }, 48 | 49 | 'mnist': 50 | { 51 | 'ensemble_lr': 3e-4, 52 | 'ensemble_batch_size': 128, 53 | 'ensemble_epochs': 50, 54 | 'num_pretrain_iters': 20, 55 | 'ensemble_alpha': 1, # teacher loss (server side) 56 | 'ensemble_beta': 0, # adversarial student loss 57 | 'ensemble_eta': 1, # diversity loss 58 | 'unique_labels': 10, # available labels 59 | 'generative_alpha': 10, # used to regulate user training 60 | 'generative_beta': 10, # used to regulate user 
training 61 | 'weight_decay': 1e-2 62 | }, 63 | 64 | 'celeb': 65 | { 66 | 'ensemble_lr': 3e-4, 67 | 'ensemble_batch_size': 128, 68 | 'ensemble_epochs': 50, 69 | 'num_pretrain_iters': 20, 70 | 'ensemble_alpha': 1, # teacher loss (server side) 71 | 'ensemble_beta': 0, # adversarial student loss 72 | 'unique_labels': 2, 73 | 'generative_alpha': 10, 74 | 'generative_beta': 10, 75 | 'weight_decay': 1e-2 76 | }, 77 | 78 | } 79 | 80 | -------------------------------------------------------------------------------- /utils/model_config-base.py: -------------------------------------------------------------------------------- 1 | CONFIGS_ = { 2 | # input_channel, n_class, hidden_dim, latent_dim 3 | 'cifar': ([16, 'M', 32, 'M', 'F'], 3, 10, 2048, 64), 4 | 'cifar100-c25': ([32, 'M', 64, 'M', 128, 'F'], 3, 25, 128, 128), 5 | 'cifar100-c30': ([32, 'M', 64, 'M', 128, 'F'], 3, 30, 2048, 128), 6 | 'cifar100-c50': ([32, 'M', 64, 'M', 128, 'F'], 3, 50, 2048, 128), 7 | 8 | 'emnist': ([6, 16, 'F'], 1, 25, 784, 32), 9 | 'mnist': ([6, 16, 'F'], 1, 10, 784, 32), 10 | 'mnist_cnn1': ([6, 'M', 16, 'M', 'F'], 1, 10, 64, 32), 11 | 'mnist_cnn2': ([16, 'M', 32, 'M', 'F'], 1, 10, 128, 32), 12 | 'celeb': ([16, 'M', 32, 'M', 64, 'M', 'F'], 3, 2, 64, 32) 13 | } 14 | 15 | # temporary roundabout to evaluate sensitivity of the generator 16 | GENERATORCONFIGS = { 17 | # hidden_dimension, latent_dimension, input_channel, n_class, noise_dim 18 | 'cifar': (512, 32, 3, 10, 64), 19 | 'celeb': (128, 32, 3, 2, 32), 20 | 'mnist': (256, 32, 1, 10, 32), 21 | 'mnist-cnn0': (256, 32, 1, 10, 64), 22 | 'mnist-cnn1': (128, 32, 1, 10, 32), 23 | 'mnist-cnn2': (64, 32, 1, 10, 32), 24 | 'mnist-cnn3': (64, 32, 1, 10, 16), 25 | 'emnist': (256, 32, 1, 25, 32), 26 | 'emnist-cnn0': (256, 32, 1, 25, 64), 27 | 'emnist-cnn1': (128, 32, 1, 25, 32), 28 | 'emnist-cnn2': (128, 32, 1, 25, 16), 29 | 'emnist-cnn3': (64, 32, 1, 25, 32), 30 | } 31 | 32 | 33 | 34 | RUNCONFIGS = { 35 | 'emnist': 36 | { 37 | 'ensemble_lr': 1e-4, 38 | 'ensemble_batch_size': 128, 39 | 'ensemble_epochs': 50, 40 | 'num_pretrain_iters': 20, 41 | 'ensemble_alpha': 1, # teacher loss (server side) 42 | 'ensemble_beta': 0, # adversarial student loss 43 | 'unique_labels': 25, 44 | 'generative_alpha':10, 45 | 'generative_beta': 1, 46 | 'weight_decay': 1e-2 47 | }, 48 | 49 | 'mnist': 50 | { 51 | 'ensemble_lr': 3e-4, 52 | 'ensemble_batch_size': 128, 53 | 'ensemble_epochs': 50, 54 | 'num_pretrain_iters': 20, 55 | 'ensemble_alpha': 1, # teacher loss (server side) 56 | 'ensemble_beta': 0, # adversarial student loss 57 | 'ensemble_eta': 1, # diversity loss 58 | 'unique_labels': 10, # available labels 59 | 'generative_alpha': 10, # used to regulate user training 60 | 'generative_beta': 10, # used to regulate user training 61 | 'weight_decay': 1e-2 62 | }, 63 | 64 | 'celeb': 65 | { 66 | 'ensemble_lr': 3e-4, 67 | 'ensemble_batch_size': 128, 68 | 'ensemble_epochs': 50, 69 | 'num_pretrain_iters': 20, 70 | 'ensemble_alpha': 1, # teacher loss (server side) 71 | 'ensemble_beta': 0, # adversarial student loss 72 | 'unique_labels': 2, 73 | 'generative_alpha': 10, 74 | 'generative_beta': 10, 75 | 'weight_decay': 1e-2 76 | }, 77 | 78 | } 79 | 80 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import h5py 3 | import numpy as np 4 | from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset 5 | from matplotlib.ticker 
import StrMethodFormatter 6 | import os 7 | from utils.model_utils import get_log_path, METRICS 8 | import seaborn as sns 9 | import string 10 | import matplotlib.colors as mcolors 11 | import os 12 | COLORS=list(mcolors.TABLEAU_COLORS) 13 | MARKERS=["o", "v", "s", "*", "x", "P"] 14 | 15 | plt.rcParams.update({'font.size': 14}) 16 | n_seeds=3 17 | 18 | def load_results(args, algorithm, seed): 19 | alg = get_log_path(args, algorithm, seed, args.gen_batch_size) 20 | hf = h5py.File("./{}/{}.h5".format(args.result_path, alg), 'r') 21 | metrics = {} 22 | for key in METRICS: 23 | metrics[key] = np.array(hf.get(key)[:]) 24 | return metrics 25 | 26 | 27 | def get_label_name(name): 28 | name = name.split("_")[0] 29 | if 'Distill' in name: 30 | if '-FL' in name: 31 | name = 'FedDistill' + r'$^+$' 32 | else: 33 | name = 'FedDistill' 34 | elif 'FedDF' in name: 35 | name = 'FedFusion' 36 | elif 'FedEnsemble' in name: 37 | name = 'Ensemble' 38 | elif 'FedAvg' in name: 39 | name = 'FedAvg' 40 | return name 41 | 42 | def plot_results(args, algorithms): 43 | n_seeds = args.times 44 | dataset_ = args.dataset.split('-') 45 | sub_dir = dataset_[0] + "/" + dataset_[2] # e.g. Mnist/ratio0.5 46 | os.system("mkdir -p figs/{}".format(sub_dir)) # e.g. figs/Mnist/ratio0.5 47 | plt.figure(1, figsize=(5, 5)) 48 | TOP_N = 5 49 | max_acc = 0 50 | for i, algorithm in enumerate(algorithms): 51 | algo_name = get_label_name(algorithm) 52 | ######### plot test accuracy ############ 53 | metrics = [load_results(args, algorithm, seed) for seed in range(n_seeds)] 54 | all_curves = np.concatenate([metrics[seed]['glob_acc'] for seed in range(n_seeds)]) 55 | top_accs = np.concatenate([np.sort(metrics[seed]['glob_acc'])[-TOP_N:] for seed in range(n_seeds)] ) 56 | acc_avg = np.mean(top_accs) 57 | acc_std = np.std(top_accs) 58 | info = 'Algorithm: {:<10s}, Accuracy = {:.2f} %, deviation = {:.2f}'.format(algo_name, acc_avg * 100, acc_std * 100) 59 | print(info) 60 | length = len(all_curves) // n_seeds 61 | sns.lineplot( 62 | x=np.array(list(range(length)) * n_seeds) + 1, 63 | y=all_curves.astype(float), 64 | legend='brief', 65 | color=COLORS[i], 66 | label=algo_name, 67 | ci="sd", 68 | ) 69 | 70 | plt.gcf() 71 | plt.grid() 72 | plt.title(dataset_[0] + ' Test Accuracy') 73 | plt.xlabel('Epoch') 74 | max_acc = np.max([max_acc, np.max(all_curves) ]) + 4e-2 75 | 76 | if args.min_acc < 0: 77 | alpha = 0.7 78 | min_acc = np.max(all_curves) * alpha + np.min(all_curves) * (1-alpha) 79 | else: 80 | min_acc = args.min_acc 81 | plt.ylim(min_acc, max_acc) 82 | fig_save_path = os.path.join('figs', sub_dir, dataset_[0] + '-' + dataset_[2] + '.png') 83 | plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0, format='png', dpi=400) 84 | print('file saved to {}'.format(fig_save_path)) -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverFedDistill.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userFedDistill import UserFedDistill 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data, aggregate_user_test_data 4 | import numpy as np 5 | 6 | class FedDistill(Server): 7 | def __init__(self, args, model, seed): 8 | super().__init__(args, model, seed) 9 | 10 | # Initialize data for all users 11 | data = read_data(args.dataset) 12 | # data contains: clients, groups, train_data, test_data, proxy_data 13 | clients = data[0] 14 | total_users = len(clients) 15 | 
self.total_test_samples = 0 16 | self.slow_start = 20 17 | self.share_model = 'FL' in self.algorithm 18 | self.pretrain = 'pretrain' in self.algorithm.lower() 19 | self.user_logits = None 20 | self.init_ensemble_configs() 21 | self.init_loss_fn() 22 | self.init_ensemble_configs() 23 | #### creating users #### 24 | self.users = [] 25 | for i in range(total_users): 26 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 27 | self.total_train_samples+=len(train_data) 28 | self.total_test_samples += len(test_data) 29 | id, train, test=read_user_data(i, data, dataset=args.dataset) 30 | user=UserFedDistill( 31 | args, id, model, train_data, test_data, self.unique_labels, use_adam=False) 32 | self.users.append(user) 33 | print("Loading testing data.") 34 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples) 35 | print("Data from {} users in total.".format(total_users)) 36 | print("Finished creating FedAvg server.") 37 | 38 | def train(self, args): 39 | #### pretraining #### 40 | if self.pretrain: 41 | ## before training ## 42 | for iter in range(self.num_pretrain_iters): 43 | print("\n\n-------------Pretrain iteration number: ", iter, " -------------\n\n") 44 | for user in self.users: 45 | user.train(iter, personalized=True, lr_decay=True) 46 | self.evaluate(selected=False, save=False) 47 | ## after training ## 48 | if self.share_model: 49 | self.aggregate_parameters() 50 | self.aggregate_logits(selected=False) # aggregate label-wise logit vector 51 | 52 | for glob_iter in range(self.num_glob_iters): 53 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 54 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 55 | if self.share_model: 56 | self.send_parameters(mode=self.mode)# broadcast averaged prediction model 57 | self.evaluate() # evaluate global model performance 58 | self.send_logits() # send global logits if have any 59 | random_chosen_id = np.random.choice(self.user_idxs) 60 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train 61 | chosen = user_id == random_chosen_id 62 | user.train( 63 | glob_iter, 64 | personalized=True, lr_decay=True, count_labels=True, verbose=chosen) 65 | if self.share_model: 66 | self.aggregate_parameters() 67 | self.aggregate_logits() # aggregate label-wise logit vector 68 | self.evaluate_personalized_model() 69 | 70 | self.save_results(args) 71 | self.save_model() 72 | 73 | def aggregate_logits(self, selected=True): 74 | user_logits = 0 75 | users = self.selected_users if selected else self.users 76 | for user in users: 77 | user_logits += user.logit_tracker.avg() 78 | self.user_logits = user_logits / len(users) 79 | 80 | def send_logits(self): 81 | if self.user_logits == None: return 82 | for user in self.selected_users: 83 | user.global_logits = self.user_logits.clone().detach() 84 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userFedDistill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from FLAlgorithms.users.userbase import User 6 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 7 | 8 | class LogitTracker(): 9 | def __init__(self, unique_labels): 10 | self.unique_labels = unique_labels 11 | self.labels = [i for i in range(unique_labels)] 12 | 
self.label_counts = torch.ones(unique_labels) # avoid division by zero error 13 | self.logit_sums = torch.zeros((unique_labels,unique_labels) ) 14 | 15 | def update(self, logits, Y): 16 | """ 17 | update logit tracker. 18 | :param logits: shape = n_sampls * logit-dimension 19 | :param Y: shape = n_samples 20 | :return: nothing 21 | """ 22 | batch_unique_labels, batch_labels_counts = Y.unique(dim=0, return_counts=True) 23 | self.label_counts[batch_unique_labels] += batch_labels_counts 24 | # expand label dimension to be n_samples X logit_dimension 25 | labels = Y.view(Y.size(0), 1).expand(-1, logits.size(1)) 26 | logit_sums_ = torch.zeros((self.unique_labels, self.unique_labels) ) 27 | logit_sums_.scatter_add_(0, labels, logits) 28 | self.logit_sums += logit_sums_ 29 | 30 | 31 | def avg(self): 32 | res= self.logit_sums / self.label_counts.float().unsqueeze(1) 33 | return res 34 | 35 | 36 | class UserFedDistill(User): 37 | """ 38 | Track and average logit vectors for each label, and share it with server/other users. 39 | """ 40 | def __init__(self, args, id, model, train_data, test_data, unique_labels, use_adam=False): 41 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 42 | 43 | self.init_loss_fn() 44 | self.unique_labels = unique_labels 45 | self.label_counts = {} 46 | self.logit_tracker = LogitTracker(self.unique_labels) 47 | self.global_logits = None 48 | self.reg_alpha = 1 49 | 50 | def update_label_counts(self, labels, counts): 51 | for label, count in zip(labels, counts): 52 | self.label_counts[int(label)] += count 53 | 54 | def clean_up_counts(self): 55 | del self.label_counts 56 | self.label_counts = {int(label):1 for label in range(self.unique_labels)} 57 | 58 | def train(self, glob_iter, personalized=True, lr_decay=True, count_labels=True, verbose=True): 59 | self.clean_up_counts() 60 | self.model.train() 61 | REG_LOSS, TRAIN_LOSS = 0, 0 62 | for epoch in range(1, self.local_epochs + 1): 63 | self.model.train() 64 | for i in range(self.K): 65 | result =self.get_next_train_batch(count_labels=count_labels) 66 | X, y = result['X'], result['y'] 67 | if count_labels: 68 | self.update_label_counts(result['labels'], result['counts']) 69 | self.optimizer.zero_grad() 70 | result=self.model(X, logit=True) 71 | output, logit = result['output'], result['logit'] 72 | self.logit_tracker.update(logit, y) 73 | if self.global_logits != None: 74 | ### get desired logit for each sample 75 | train_loss = self.loss(output, y) 76 | target_p = F.softmax(self.global_logits[y,:], dim=1) 77 | reg_loss = self.ensemble_loss(output, target_p) 78 | REG_LOSS += reg_loss 79 | TRAIN_LOSS += train_loss 80 | loss = train_loss + self.reg_alpha * reg_loss 81 | else: 82 | loss=self.loss(output, y) 83 | loss.backward() 84 | self.optimizer.step()#self.local_model) 85 | # local-model <=== self.model 86 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 87 | if personalized: 88 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 89 | if lr_decay: 90 | self.lr_scheduler.step(glob_iter) 91 | if self.global_logits != None and verbose: 92 | REG_LOSS = REG_LOSS.detach().numpy() / (self.local_epochs * self.K) 93 | TRAIN_LOSS = TRAIN_LOSS.detach().numpy() / (self.local_epochs * self.K) 94 | info = "Train loss {:.2f}, Regularization loss {:.2f}".format(REG_LOSS, TRAIN_LOSS) 95 | print(info) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import argparse 4 | from FLAlgorithms.servers.serveravg import FedAvg 5 | from FLAlgorithms.servers.serverFedProx import FedProx 6 | from FLAlgorithms.servers.serverFedDistill import FedDistill 7 | from FLAlgorithms.servers.serverpFedGen import FedGen 8 | from FLAlgorithms.servers.serverpFedEnsemble import FedEnsemble 9 | from FLAlgorithms.servers.serverpFedCL import FedCL 10 | from utils.model_utils import create_model 11 | from utils.plot_utils import * 12 | import torch 13 | from multiprocessing import Pool 14 | import time 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | def create_server_n_user(args, i): 18 | model = create_model(args.model, args.dataset, args.algorithm) 19 | if ('FedAvg' in args.algorithm): 20 | server=FedAvg(args, model, i) 21 | elif ('FedGen' in args.algorithm): 22 | server=FedGen(args, model, i) 23 | elif ('FedCL' in args.algorithm): 24 | server=FedCL(args, model, i) 25 | elif ('FedProx' in args.algorithm): 26 | server = FedProx(args, model, i) 27 | elif ('FedDistill' in args.algorithm): 28 | server = FedDistill(args, model, i) 29 | elif ('FedEnsemble' in args.algorithm): 30 | server = FedEnsemble(args, model, i) 31 | else: 32 | print("Algorithm {} has not been implemented.".format(args.algorithm)) 33 | exit() 34 | return server 35 | 36 | 37 | def run_job(args, i): 38 | torch.manual_seed(i) 39 | print("\n\n [ Start training iteration {} ] \n\n".format(i)) 40 | # Generate model 41 | server = create_server_n_user(args, i) 42 | if args.train: 43 | server.train(args) 44 | server.test() 45 | 46 | def main(args): 47 | for i in range(args.times): 48 | run_job(args, i) 49 | print("Finished training.") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--dataset", type=str, default="Mnist") 55 | parser.add_argument("--model", type=str, default="cnn") 56 | parser.add_argument("--train", type=int, default=1, choices=[0,1]) 57 | parser.add_argument("--algorithm", type=str, default="pFedMe") 58 | parser.add_argument("--batch_size", type=int, default=32) 59 | parser.add_argument("--gen_batch_size", type=int, default=32, help='number of samples from generator') 60 | parser.add_argument("--learning_rate", type=float, default=0.01, help="Local learning rate") 61 | parser.add_argument("--personal_learning_rate", type=float, default=0.01, help="Personalized learning rate to caculate theta aproximately using K steps") 62 | parser.add_argument("--ensemble_lr", type=float, default=1e-4, help="Ensemble learning rate.") 63 | parser.add_argument("--beta", type=float, default=1.0, help="Average moving parameter for pFedMe, or Second learning rate of Per-FedAvg") 64 | parser.add_argument("--lamda", type=int, default=1, help="Regularization term") 65 | parser.add_argument("--mix_lambda", type=float, default=0.1, help="Mix lambda for FedMXI baseline") 66 | parser.add_argument("--embedding", type=int, default=0, help="Use embedding layer in generator network") 67 | parser.add_argument("--num_glob_iters", type=int, default=200) 68 | parser.add_argument("--local_epochs", type=int, default=20) 69 | parser.add_argument("--num_users", type=int, default=20, help="Number of Users per round") 70 | parser.add_argument("--K", type=int, default=1, help="Computation steps") 71 | parser.add_argument("--times", type=int, default=3, help="running time") 72 | parser.add_argument("--device", type=str, default="cpu", 
choices=["cpu","cuda"], help="run device (cpu | cuda)") 73 | parser.add_argument("--result_path", type=str, default="results", help="directory path to save results") 74 | 75 | 76 | args = parser.parse_args() 77 | print("=" * 80) 78 | print(args) 79 | print("=" * 80) 80 | print("Summary of training process:") 81 | print("Algorithm: {}".format(args.algorithm)) 82 | print("Batch size: {}".format(args.batch_size)) 83 | print("Learing rate : {}".format(args.learning_rate)) 84 | print("Ensemble learing rate : {}".format(args.ensemble_lr)) 85 | print("Average Moving : {}".format(args.beta)) 86 | print("Subset of users : {}".format(args.num_users)) 87 | print("Number of global rounds : {}".format(args.num_glob_iters)) 88 | print("Number of local rounds : {}".format(args.local_epochs)) 89 | print("Dataset : {}".format(args.dataset)) 90 | print("Local Model : {}".format(args.model)) 91 | print("Device : {}".format(args.device)) 92 | print("=" * 80) 93 | time1 = time.time() 94 | main(args) 95 | time2 = time.time() 96 | print(time2-time1) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity 2 | 3 | Research code that accompanies the paper [FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity](http://arxiv.org/abs/2211.07248). 4 | It contains implementation of the following algorithms: 5 | * **FedCL** (the proposed algorithm) 6 | * **FedGen** ([paper](https://arxiv.org/pdf/2105.10056.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverpFedGen.py)). 7 | * **FedAvg** ([paper](https://arxiv.org/pdf/1602.05629.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serveravg.py)). 8 | * **FedProx** ([paper](https://arxiv.org/pdf/1812.06127.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverFedProx.py)). 9 | * **FedDistill** and its extension **FedDistll-FL** ([paper](https://arxiv.org/pdf/2011.02367.pdf) and [code](https://github.com/zhuangdizhu/FedGen/blob/main/FLAlgorithms/servers/serverFedDistill.py)). 10 | 11 | ## Install Requirements: 12 | ```pip3 install -r requirements.txt``` 13 | 14 | 15 | ## Prepare Dataset: 16 | * To generate *non-iid* **Mnist** Dataset following the Dirichlet distribution D(α=0.1) for 20 clients, using 50% of the total available training samples: 17 |
cd ./data/Mnist
18 | python generate_niid_dirichlet.py --n_class 10 --sampling_ratio 0.5 --alpha 0.1 --n_user 20
19 | ### This will generate a dataset located at ./data/Mnist/u20c10-alpha0.1-ratio0.5/
20 | 
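A minimal inspection sketch (not part of the repository) for sanity-checking the generated split, assuming the default `pt` format written by `generate_niid_dirichlet.py`:
```
# Hypothetical check script; the keys below mirror what generate_niid_dirichlet.py saves.
import torch

split = torch.load('./data/Mnist/u20c10-alpha0.1-ratio0.5/train/train.pt')
print(split['users'])        # client names, e.g. ['f_00000', 'f_00001', ...]
print(split['num_samples'])  # number of training samples held by each client
x0 = split['user_data']['f_00000']['x']  # float32 image tensor for client 0
y0 = split['user_data']['f_00000']['y']  # int64 label tensor for client 0
print(x0.shape, y0.unique()) # with small alpha, each client sees only a few labels
```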
21 | 22 | 23 | - Similarly, to generate *non-iid* **EMnist** Dataset, using 10% of the total available training samples: 24 |
cd ./data/EMnist
25 | python generate_niid_dirichlet.py --sampling_ratio 0.1 --alpha 0.1 --n_user 20 
26 | ### This will generate a dataset located at ./data/EMnist/u20-letters-alpha0.1-ratio0.1/
27 | 
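In both generators, the Dirichlet concentration `alpha` controls client heterogeneity: smaller values concentrate each class on fewer clients. A standalone sketch (not part of the repository) of the per-class client proportions that `generate_niid_dirichlet.py` draws internally:
```
# Illustration only: how one class's samples would be split across clients.
import numpy as np

np.random.seed(42)
n_users = 20
for alpha in (0.1, 1.0, 10.0):
    proportions = np.random.dirichlet(np.repeat(alpha, n_users))
    print(f"alpha={alpha:>4}: largest client share = {proportions.max():.2f}")
```
Smaller `alpha` pushes most of a class onto a handful of clients, which is what makes the resulting federated split *non-iid*.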
28 | 29 | ## Run Experiments: 30 | 31 | All experiments are launched from the main file `main.py`. 32 | 33 | #### Run experiments on the *Mnist* Dataset: 34 | 35 | ``` 36 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedGen --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3 37 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedCL --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3 38 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedAvg --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3 39 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedProx --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3 40 | python main.py --dataset Mnist-alpha0.1-ratio0.5 --algorithm FedDistill-FL --batch_size 32 --num_glob_iters 200 --local_epochs 20 --num_users 10 --lamda 1 --learning_rate 0.01 --model cnn --personal_learning_rate 0.01 --times 3 41 | ``` 42 | ---- 43 | 44 | #### Run experiments on the *EMnist* Dataset: 45 | ``` 46 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedAvg --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3 47 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedGen --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3 48 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedCL --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3 49 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedProx --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3 50 | python main.py --dataset EMnist-alpha0.1-ratio0.1 --algorithm FedDistill-FL --batch_size 32 --local_epochs 20 --num_users 10 --lamda 1 --model cnn --learning_rate 0.01 --personal_learning_rate 0.01 --num_glob_iters 200 --times 3 51 | 52 | ``` 53 | ---- 54 | 55 | ### Plot 56 | For the **algorithms** argument, list the algorithm names separated by commas, e.g. `--algorithms FedAvg,FedGen,FedProx` 57 | ``` 58 | python main_plot.py --dataset EMnist-alpha0.1-ratio0.1 --algorithms FedGen --batch_size 32 --local_epochs 50 --num_users 10 --num_glob_iters 200 --plot_legend 1 59 | ``` 60 | ## Citation 61 | Please cite the following paper if you use this code in your work.
62 | ``` 63 | @article{wang2022fedcl, 64 | author = {Wang, Mingjie and Guo, Jianxiong and Jia, Weijia}, 65 | title = {FedCL: Federated Multi-Phase Curriculum Learning to Synchronously Correlate User Heterogeneity}, 66 | year = {2023}, 67 | journal = {IEEE Transactions on Artificial Intelligence}, 68 | doi = {10.1109/TAI.2023.3307664}, 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /FLAlgorithms/optimizers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def _pairwise_distances(embeddings, squared=False): 3 | """Compute the 2D matrix of distances between all the embeddings. 4 | Args: 5 | embeddings: tensor of shape (batch_size, embed_dim) 6 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 7 | If false, output is the pairwise euclidean distance matrix. 8 | Returns: 9 | pairwise_distances: tensor of shape (batch_size, batch_size) 10 | """ 11 | # Get the dot product between all embeddings 12 | # shape (batch_size, batch_size) 13 | dot_product = torch.matmul(embeddings, embeddings.t()) 14 | 15 | # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. 16 | # This also provides more numerical stability (the diagonal of the result will be exactly 0). 17 | # shape (batch_size,) 18 | square_norm = torch.diag(dot_product) 19 | 20 | # Compute the pairwise distance matrix as we have: 21 | # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2 22 | # shape (batch_size, batch_size) 23 | distances = torch.unsqueeze(square_norm, 1) - 2.0 * dot_product + torch.unsqueeze(square_norm, 0) 24 | 25 | # Because of computation errors, some distances might be negative so we put everything >= 0.0 26 | distances = torch.max(distances, torch.tensor([0.0]).cuda()) 27 | 28 | if not squared: 29 | # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) 30 | # we need to add a small epsilon where distances == 0.0 31 | mask = (torch.eq(distances, 0.0)).float() 32 | distances = distances + mask * 1e-16 33 | 34 | distances = torch.sqrt(distances) 35 | 36 | # Correct the epsilon added: set the distances on the mask to be exactly 0.0 37 | distances = distances * (torch.sub(1.0, mask)) 38 | 39 | return distances 40 | 41 | 42 | def _get_triplet_mask(labels): 43 | """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
44 | A triplet (i, j, k) is valid if: 45 | - i, j, k are distinct 46 | - labels[i] == labels[j] and labels[i] != labels[k] 47 | Args: 48 | labels: tf.int32 `Tensor` with shape [batch_size] 49 | """ 50 | # Check that i, j and k are distinct 51 | indices_equal = torch.eye(labels.shape[0]).cuda() 52 | indices_not_equal = torch.tensor([1.0]).cuda()-indices_equal 53 | i_not_equal_j = torch.unsqueeze(indices_not_equal, 2) 54 | i_not_equal_k = torch.unsqueeze(indices_not_equal, 1) 55 | j_not_equal_k = torch.unsqueeze(indices_not_equal, 0) 56 | 57 | distinct_indices = torch.mul(torch.mul(i_not_equal_j, i_not_equal_k), j_not_equal_k) 58 | 59 | 60 | # Check if labels[i] == labels[j] and labels[i] != labels[k] 61 | label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)).float() 62 | i_equal_j = torch.unsqueeze(label_equal, 2) 63 | i_equal_k = torch.unsqueeze(label_equal, 1) 64 | valid_labels = torch.mul(i_equal_j, torch.tensor([1.0]).cuda()-i_equal_k) 65 | 66 | # Combine the two masks 67 | mask = torch.mul(distinct_indices, valid_labels) 68 | return mask 69 | 70 | 71 | def batch_all_triplet_loss(labels, embeddings, margin, squared=False): 72 | """Build the triplet loss over a batch of embeddings. 73 | We generate all the valid triplets and average the loss over the positive ones. 74 | Args: 75 | labels: labels of the batch, of size (batch_size,) 76 | embeddings: tensor of shape (batch_size, embed_dim) 77 | margin: margin for triplet loss 78 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 79 | If false, output is the pairwise euclidean distance matrix. 80 | Returns: 81 | triplet_loss: scalar tensor containing the triplet loss 82 | """ 83 | # Get the pairwise distance matrix 84 | pairwise_dist = _pairwise_distances(embeddings, squared=squared) 85 | 86 | # shape (batch_size, batch_size, 1) 87 | anchor_positive_dist = torch.unsqueeze(pairwise_dist, 2) 88 | assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape) 89 | # shape (batch_size, 1, batch_size) 90 | anchor_negative_dist = torch.unsqueeze(pairwise_dist, 1) 91 | assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape) 92 | 93 | # Compute a 3D tensor of size (batch_size, batch_size, batch_size) 94 | # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k 95 | # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) 96 | # and the 2nd (batch_size, 1, batch_size) 97 | triplet_loss = anchor_positive_dist - anchor_negative_dist + margin 98 | 99 | # Put to zero the invalid triplets 100 | # (where label(a) != label(p) or label(n) == label(a) or a == p) 101 | mask = _get_triplet_mask(labels) 102 | mask = mask.float() 103 | triplet_loss = torch.mul(mask, triplet_loss) 104 | 105 | # Remove negative losses (i.e. 
the easy triplets) 106 | triplet_loss = torch.max(triplet_loss, torch.tensor([0.0]).cuda()) 107 | 108 | # Count number of positive triplets (where triplet_loss > 0) 109 | valid_triplets = torch.gt(triplet_loss, 1e-16).float() 110 | num_positive_triplets = torch.sum(valid_triplets) 111 | 112 | 113 | # Get final mean triplet loss over the positive valid triplets 114 | triplet_loss = torch.sum(triplet_loss) / (num_positive_triplets + 1e-16) 115 | 116 | return triplet_loss -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils.model_config import CONFIGS_ 4 | 5 | import collections 6 | 7 | ################################# 8 | ##### Neural Network model ##### 9 | ################################# 10 | class Net(nn.Module): 11 | def __init__(self, dataset='mnist', model='cnn'): 12 | super(Net, self).__init__() 13 | # define network layers 14 | print("Creating model for {}".format(dataset)) 15 | self.dataset = dataset 16 | configs, input_channel, self.output_dim, self.hidden_dim, self.latent_dim=CONFIGS_[dataset] 17 | print('Network configs:', configs) 18 | self.named_layers, self.layers, self.layer_names =self.build_network( 19 | configs, input_channel, self.output_dim) 20 | self.n_parameters = len(list(self.parameters())) 21 | self.n_share_parameters = len(self.get_encoder()) 22 | 23 | def get_number_of_parameters(self): 24 | pytorch_total_params=sum(p.numel() for p in self.parameters() if p.requires_grad) 25 | return pytorch_total_params 26 | 27 | def build_network(self, configs, input_channel, output_dim): 28 | layers = nn.ModuleList() 29 | named_layers = {} 30 | layer_names = [] 31 | kernel_size, stride, padding = 3, 2, 1 32 | for i, x in enumerate(configs): 33 | if x == 'F': 34 | layer_name='flatten{}'.format(i) 35 | layer=nn.Flatten(1) 36 | layers+=[layer] 37 | layer_names+=[layer_name] 38 | elif x == 'M': 39 | pool_layer = nn.MaxPool2d(kernel_size=2, stride=2) 40 | layer_name = 'pool{}'.format(i) 41 | layers += [pool_layer] 42 | layer_names += [layer_name] 43 | else: 44 | cnn_name = 'encode_cnn{}'.format(i) 45 | cnn_layer = nn.Conv2d(input_channel, x, stride=stride, kernel_size=kernel_size, padding=padding) 46 | named_layers[cnn_name] = [cnn_layer.weight, cnn_layer.bias] 47 | 48 | bn_name = 'encode_batchnorm{}'.format(i) 49 | bn_layer = nn.BatchNorm2d(x) 50 | named_layers[bn_name] = [bn_layer.weight, bn_layer.bias] 51 | 52 | relu_name = 'relu{}'.format(i) 53 | relu_layer = nn.ReLU(inplace=True)# no parameters to learn 54 | 55 | layers += [cnn_layer, bn_layer, relu_layer] 56 | layer_names += [cnn_name, bn_name, relu_name] 57 | input_channel = x 58 | 59 | # finally, classification layer 60 | fc_layer_name1 = 'encode_fc1' 61 | fc_layer1 = nn.Linear(self.hidden_dim, self.latent_dim) 62 | layers += [fc_layer1] 63 | layer_names += [fc_layer_name1] 64 | named_layers[fc_layer_name1] = [fc_layer1.weight, fc_layer1.bias] 65 | 66 | fc_layer_name = 'decode_fc2' 67 | fc_layer = nn.Linear(self.latent_dim, self.output_dim) 68 | layers += [fc_layer] 69 | layer_names += [fc_layer_name] 70 | named_layers[fc_layer_name] = [fc_layer.weight, fc_layer.bias] 71 | return named_layers, layers, layer_names 72 | 73 | 74 | def get_parameters_by_keyword(self, keyword='encode'): 75 | params=[] 76 | for name, layer in zip(self.layer_names, self.layers): 77 | if keyword in name: 78 | #layer = 
self.layers[name] 79 | params += [layer.weight, layer.bias] 80 | return params 81 | 82 | def get_encoder(self): 83 | return self.get_parameters_by_keyword("encode") 84 | 85 | def get_decoder(self): 86 | return self.get_parameters_by_keyword("decode") 87 | 88 | def get_shared_parameters(self, detach=False): 89 | return self.get_parameters_by_keyword("decode_fc2") 90 | 91 | def get_learnable_params(self): 92 | return self.get_encoder() + self.get_decoder() 93 | 94 | def forward(self, x, start_layer_idx = 0, logit=False, start_layer_output=False): 95 | """ 96 | :param x: 97 | :param logit: return logit vector before the last softmax layer 98 | :param start_layer_idx: if 0, conduct normal forward; otherwise, forward from the last few layers (see mapping function) 99 | :return: 100 | """ 101 | if start_layer_idx < 0: # 102 | return self.mapping(x, start_layer_idx=start_layer_idx, logit=logit) 103 | restults={} 104 | z = x 105 | for idx in range(start_layer_idx, len(self.layers)): 106 | layer_name = self.layer_names[idx] 107 | layer = self.layers[idx] 108 | z = layer(z) 109 | 110 | if start_layer_output: 111 | layer_ = self.layers[start_layer_idx] 112 | layer_ = layer_(x) 113 | 114 | if self.output_dim > 1: 115 | restults['output'] = F.log_softmax(z, dim=1) 116 | else: 117 | restults['output'] = z 118 | if logit: 119 | restults['logit'] = z 120 | if start_layer_output: 121 | restults['strat_layer_output'] = layer_ 122 | return restults 123 | 124 | def mapping(self, z_input, start_layer_idx=-1, logit=True): 125 | z = z_input 126 | n_layers = len(self.layers) 127 | for layer_idx in range(n_layers + start_layer_idx, n_layers): 128 | layer = self.layers[layer_idx] 129 | z = layer(z) 130 | if self.output_dim > 1: 131 | out=F.log_softmax(z, dim=1) 132 | result = {'output': out} 133 | if logit: 134 | result['logit'] = z 135 | return result 136 | -------------------------------------------------------------------------------- /FLAlgorithms/users/userpFedEnsemble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | import copy 7 | import numpy as np 8 | from FLAlgorithms.users.userbase import User 9 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 10 | from torchvision.utils import save_image 11 | 12 | class UserpFedEnsemble(User): 13 | def __init__(self, dataset, algorithm, numeric_id, train_data, test_data, 14 | model, generative_model, available_labels, 15 | batch_size, learning_rate, beta, lamda, local_epochs, K): 16 | super().__init__(dataset, algorithm, numeric_id, train_data, test_data, model, batch_size, learning_rate, beta, lamda, 17 | local_epochs) 18 | 19 | self.init_loss_fn() 20 | self.K = K 21 | self.optimizer = pFedIBOptimizer(self.model.parameters(), lr=self.learning_rate) 22 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99) 23 | self.generative_model = copy.deepcopy(generative_model) 24 | self.label_counts = {} 25 | self.available_labels = available_labels 26 | self.generative_alpha = 10 27 | self.generative_beta = 0.1 28 | self.update_gen_freq = 5 29 | self.pretrained = False 30 | self.generative_optimizer = torch.optim.Adam( 31 | params=self.generative_model.parameters(), 32 | lr=1e-3, betas=(0.9, 0.999), 33 | eps=1e-08, weight_decay=0, amsgrad=False) 34 | self.generative_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.generative_optimizer, gamma=0.98) 35 | 
36 | def update_label_counts(self, labels, counts): 37 | for label, count in zip(labels, counts): 38 | self.label_counts[int(label)] += count 39 | 40 | def clean_up_counts(self): 41 | del self.label_counts 42 | self.label_counts = {label:1 for label in range(self.unique_labels)} 43 | 44 | def update_generator(self, steps, verbose=True): 45 | self.model.eval() 46 | self.generative_model.train() 47 | RECONSTRUCT_LOSS, KLD_LOSS, RC_LOSS = 0, 0, 0 48 | for _ in range(steps): 49 | self.generative_optimizer.zero_grad() 50 | samples=self.get_next_train_batch(count_labels=True) 51 | X, y=samples['X'], samples['y'] 52 | gen_result = self.generative_model(X, y) 53 | loss_info = self.generative_model.loss_function( 54 | gen_result['output'], 55 | X, 56 | gen_result['mu'], 57 | gen_result['log_var'], 58 | beta=0.01 59 | ) 60 | loss, kld_loss, reconstruct_loss = loss_info['loss'], loss_info['KLD'], loss_info['reconstruction_loss'] 61 | RECONSTRUCT_LOSS += loss 62 | KLD_LOSS += kld_loss 63 | RC_LOSS += reconstruct_loss 64 | loss.backward() 65 | self.generative_optimizer.step() 66 | self.generative_lr_scheduler.step() 67 | 68 | if verbose: 69 | RECONSTRUCT_LOSS = RECONSTRUCT_LOSS.detach().numpy() / steps 70 | KLD_LOSS = KLD_LOSS.detach().numpy() / steps 71 | RC_LOSS = RC_LOSS.detach().numpy() / steps 72 | info = "VAE-Loss: {:.4f}, KL-Loss: {:.4f}, RC-Loss:{:.4f}".format(RECONSTRUCT_LOSS, KLD_LOSS, RC_LOSS) 73 | print(info) 74 | 75 | 76 | def train(self, glob_iter, personalized=False, reconstruct=False, verbose=False): 77 | self.clean_up_counts() 78 | self.model.train() 79 | self.generative_model.eval() 80 | TEACHER_LOSS, DIST_LOSS, RECONSTRUCT_LOSS = 0, 0, 0 81 | #if glob_iter % self.update_gen_freq == 0: 82 | if not self.pretrained: 83 | self.update_generator(self.local_epochs * 20) 84 | self.pretrained = True 85 | self.visualize_images(self.generative_model, 0, repeats=10) 86 | for epoch in range(self.local_epochs): 87 | self.model.train() 88 | self.generative_model.eval() 89 | for i in range(self.K): 90 | loss = 0 91 | self.optimizer.zero_grad() 92 | #### sample from real dataset (un-weighted) 93 | samples =self.get_next_train_batch(count_labels=True) 94 | X, y = samples['X'], samples['y'] 95 | self.update_label_counts(samples['labels'], samples['counts']) 96 | model_result=self.model(X, return_latent=reconstruct) 97 | output = model_result['output'] 98 | predictive_loss=self.loss(output, y) 99 | loss += predictive_loss 100 | #### sample from generator and regulate Dist|z_gen, z_pred|, where z_gen = Gen(x, y), z_pred = model(X) 101 | if reconstruct: 102 | gen_result = self.generative_model(X, y, latent=True) 103 | z_gen = gen_result['latent'] 104 | z_model = model_result['latent'] 105 | dist_loss = self.generative_beta * self.dist_loss(z_model, z_gen) 106 | DIST_LOSS += dist_loss 107 | loss += dist_loss 108 | #### get loss and perform optimization 109 | loss.backward() 110 | self.optimizer.step() # self.local_model) 111 | 112 | if reconstruct: 113 | DIST_LOSS+=(torch.mean(DIST_LOSS.double())).item() 114 | # local-model <=== self.model 115 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 116 | if personalized: 117 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 118 | self.lr_scheduler.step(glob_iter) 119 | if reconstruct and verbose: 120 | info = 'Latent Reconstruction Loss={:.4f}'.format(DIST_LOSS) 121 | print(info) 122 | 123 | def adjust_weights(self, samples): 124 | labels, counts = samples['labels'], samples['counts'] 125 | 
#weight=self.label_weights[y][:, user_idx].reshape(-1, 1) 126 | np_y = samples['y'].detach().numpy() 127 | n_labels = samples['y'].shape[0] 128 | weights = np.array([n_labels / count for count in counts]) # smaller count --> larger weight 129 | weights = len(self.available_labels) * weights / np.sum(weights) # normalized 130 | label_weights = np.ones(self.unique_labels) 131 | label_weights[labels] = weights 132 | sample_weights = label_weights[np_y] 133 | return sample_weights 134 | 135 | def visualize_images(self, generator, glob_iter, repeats=1): 136 | """ 137 | Generate and visualize data for a generator. 138 | """ 139 | os.system("mkdir -p images") 140 | path = f'images/{self.algorithm}-{self.dataset}-user{self.id}-iter{glob_iter}.png' 141 | y=self.available_labels 142 | y = np.repeat(y, repeats=repeats, axis=0) 143 | y_input=torch.tensor(y, dtype=torch.int64) 144 | generator.eval() 145 | images=generator.sample(y_input, latent=False)['output'] # 0,1,..,K, 0,1,...,K 146 | images=images.view(repeats, -1, *images.shape[1:]) 147 | images=images.view(-1, *images.shape[2:]) 148 | save_image(images.detach(), path, nrow=repeats, normalize=True) 149 | print("Image saved to {}".format(path)) -------------------------------------------------------------------------------- /FLAlgorithms/users/userbase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | import copy 9 | from utils.model_utils import get_dataset_name 10 | from utils.model_config import RUNCONFIGS 11 | from FLAlgorithms.optimizers.fedoptimizer import pFedIBOptimizer 12 | 13 | class User: 14 | """ 15 | Base class for users in federated learning. 16 | """ 17 | def __init__( 18 | self, args, id, model, train_data, test_data, use_adam=False): 19 | self.model = copy.deepcopy(model[0]) 20 | self.model_name = model[1] 21 | self.id = id # integer 22 | self.train_samples = len(train_data) 23 | self.test_samples = len(test_data) 24 | self.batch_size = args.batch_size 25 | self.learning_rate = args.learning_rate 26 | self.beta = args.beta 27 | self.lamda = args.lamda 28 | self.local_epochs = args.local_epochs 29 | self.algorithm = args.algorithm 30 | self.K = args.K 31 | self.dataset = args.dataset 32 | self.trainloader = DataLoader(train_data, self.batch_size, shuffle=True, drop_last=True) 33 | self.testloader = DataLoader(test_data, self.batch_size, drop_last=False) 34 | self.testloaderfull = DataLoader(test_data, self.test_samples) 35 | self.trainloaderfull = DataLoader(train_data, self.train_samples) 36 | self.iter_trainloader = iter(self.trainloader) 37 | self.iter_testloader = iter(self.testloader) 38 | dataset_name = get_dataset_name(self.dataset) 39 | self.unique_labels = RUNCONFIGS[dataset_name]['unique_labels'] 40 | self.generative_alpha = RUNCONFIGS[dataset_name]['generative_alpha'] 41 | self.generative_beta = RUNCONFIGS[dataset_name]['generative_beta'] 42 | 43 | # those parameters are for personalized federated learning. 
44 | self.local_model = copy.deepcopy(list(self.model.parameters())) 45 | self.personalized_model_bar = copy.deepcopy(list(self.model.parameters())) 46 | self.prior_decoder = None 47 | self.prior_params = None 48 | 49 | self.init_loss_fn() 50 | if use_adam: 51 | self.optimizer=torch.optim.Adam( 52 | params=self.model.parameters(), 53 | lr=self.learning_rate, betas=(0.9, 0.999), 54 | eps=1e-08, weight_decay=1e-2, amsgrad=False) 55 | else: 56 | self.optimizer = pFedIBOptimizer(self.model.parameters(), lr=self.learning_rate) 57 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.99) 58 | self.label_counts = {} 59 | 60 | 61 | 62 | 63 | def init_loss_fn(self): 64 | self.loss=nn.NLLLoss() 65 | self.dist_loss = nn.MSELoss() 66 | self.ensemble_loss=nn.KLDivLoss(reduction="batchmean") 67 | self.ce_loss = nn.CrossEntropyLoss() 68 | 69 | def set_parameters(self, model,beta=1): 70 | for old_param, new_param, local_param in zip(self.model.parameters(), model.parameters(), self.local_model): 71 | if beta == 1: 72 | old_param.data = new_param.data.clone() 73 | local_param.data = new_param.data.clone() 74 | else: 75 | old_param.data = beta * new_param.data.clone() + (1 - beta) * old_param.data.clone() 76 | local_param.data = beta * new_param.data.clone() + (1-beta) * local_param.data.clone() 77 | 78 | def set_prior_decoder(self, model, beta=1): 79 | for new_param, local_param in zip(model.personal_layers, self.prior_decoder): 80 | if beta == 1: 81 | local_param.data = new_param.data.clone() 82 | else: 83 | local_param.data = beta * new_param.data.clone() + (1 - beta) * local_param.data.clone() 84 | 85 | 86 | def set_prior(self, model): 87 | for new_param, local_param in zip(model.get_encoder() + model.get_decoder(), self.prior_params): 88 | local_param.data = new_param.data.clone() 89 | 90 | # only for pFedMAS 91 | def set_mask(self, mask_model): 92 | for new_param, local_param in zip(mask_model.get_masks(), self.mask_model.get_masks()): 93 | local_param.data = new_param.data.clone() 94 | 95 | def set_shared_parameters(self, model, mode='decode'): 96 | # only copy shared parameters to local 97 | for old_param, new_param in zip( 98 | self.model.get_parameters_by_keyword(mode), 99 | model.get_parameters_by_keyword(mode) 100 | ): 101 | old_param.data = new_param.data.clone() 102 | 103 | def get_parameters(self): 104 | for param in self.model.parameters(): 105 | param.detach() 106 | return self.model.parameters() 107 | 108 | 109 | def clone_model_paramenter(self, param, clone_param): 110 | with torch.no_grad(): 111 | for param, clone_param in zip(param, clone_param): 112 | clone_param.data = param.data.clone() 113 | return clone_param 114 | 115 | def get_updated_parameters(self): 116 | return self.local_weight_updated 117 | 118 | def update_parameters(self, new_params, keyword='all'): 119 | for param , new_param in zip(self.model.parameters(), new_params): 120 | param.data = new_param.data.clone() 121 | 122 | def get_grads(self): 123 | grads = [] 124 | for param in self.model.parameters(): 125 | if param.grad is None: 126 | grads.append(torch.zeros_like(param.data)) 127 | else: 128 | grads.append(param.grad.data) 129 | return grads 130 | 131 | def test(self): 132 | self.model.eval() 133 | test_acc = 0 134 | loss = 0 135 | for x, y in self.testloaderfull: 136 | output = self.model(x)['output'] 137 | loss += self.loss(output, y) 138 | test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() 139 | return test_acc, loss, y.shape[0] 140 | 141 | 142 | 143 | def 
test_personalized_model(self): 144 | self.model.eval() 145 | test_acc = 0 146 | loss = 0 147 | self.update_parameters(self.personalized_model_bar) 148 | for x, y in self.testloaderfull: 149 | output = self.model(x)['output'] 150 | loss += self.loss(output, y) 151 | test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item() 152 | #@loss += self.loss(output, y) 153 | #print(self.id + ", Test Accuracy:", test_acc / y.shape[0] ) 154 | #print(self.id + ", Test Loss:", loss) 155 | self.update_parameters(self.local_model) 156 | return test_acc, y.shape[0], loss 157 | 158 | 159 | def get_next_train_batch(self, count_labels=True): 160 | try: 161 | # Samples a new batch for personalizing 162 | (X, y) = next(self.iter_trainloader) 163 | except StopIteration: 164 | # restart the generator if the previous generator is exhausted. 165 | self.iter_trainloader = iter(self.trainloader) 166 | (X, y) = next(self.iter_trainloader) 167 | result = {'X': X, 'y': y} 168 | if count_labels: 169 | unique_y, counts=torch.unique(y, return_counts=True) 170 | unique_y = unique_y.detach().numpy() 171 | counts = counts.detach().numpy() 172 | result['labels'] = unique_y 173 | result['counts'] = counts 174 | return result 175 | 176 | def get_next_test_batch(self): 177 | try: 178 | # Samples a new batch for personalizing 179 | (X, y) = next(self.iter_testloader) 180 | except StopIteration: 181 | # restart the generator if the previous generator is exhausted. 182 | self.iter_testloader = iter(self.testloader) 183 | (X, y) = next(self.iter_testloader) 184 | return (X, y) 185 | 186 | def save_model(self): 187 | model_path = os.path.join("models", self.dataset) 188 | if not os.path.exists(model_path): 189 | os.makedirs(model_path) 190 | torch.save(self.model, os.path.join(model_path, "user_" + self.id + ".pt")) 191 | 192 | def load_model(self): 193 | model_path = os.path.join("models", self.dataset) 194 | self.model = torch.load(os.path.join(model_path, "server" + ".pt")) 195 | 196 | @staticmethod 197 | def model_exists(): 198 | return os.path.exists(os.path.join("models", "server" + ".pt")) -------------------------------------------------------------------------------- /data/CelebA/generate_niid_agg.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import fetch_openml 2 | from tqdm import trange 3 | import numpy as np 4 | import random 5 | import json 6 | import os 7 | import argparse 8 | from PIL import Image 9 | import torch 10 | from torch.utils.data import Dataset 11 | import torchvision.transforms as transforms 12 | 13 | IMAGE_SIZE = 84 14 | # TODO change LOAD_PATH to be your own data path 15 | LOAD_PATH = './celeba/' 16 | DUMP_PATH = './' 17 | IMG_DIR = os.path.join(LOAD_PATH, 'data/raw/img_align_celeba') 18 | random.seed(42) 19 | np.random.seed(42) 20 | 21 | def load_proxy_data(user_lists, cdata): 22 | n_users = len(user_lists) 23 | # merge all uer data chosen for proxy 24 | new_data={ 25 | 'classes': [], 26 | 'user_data': { 27 | 'x': [], 28 | 'y': [] 29 | }, 30 | 'num_samples': 0 31 | } 32 | for uname in user_lists: 33 | user_data = cdata['user_data'][uname] 34 | X=user_data['x'] # path to image 35 | y=user_data['y'] # label 36 | assert len(X) == len(y) 37 | ## load image ## 38 | loaded_X = [] 39 | for i, image_name in enumerate(X): 40 | image_path = os.path.join(IMG_DIR, image_name) 41 | image = Image.open(image_path) 42 | image=image.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB') 43 | image = np.array(image) 44 | loaded_X.append(image) 45 | 46 | 
new_data['user_data']['x'] += np.array(loaded_X).tolist() 47 | new_data['user_data']['y'] += y 48 | 49 | new_data['classes'] += list(np.unique(y)) 50 | new_data['num_samples'] += len(y) 51 | combined = list(zip(new_data['user_data']['x'], new_data['user_data']['y'])) 52 | return new_data 53 | 54 | 55 | def load_data(user_lists, cdata, agg_user=-1): 56 | n_users = len(user_lists) 57 | new_data = { 58 | 'users': [None for _ in range(n_users)], 59 | 'num_samples': [None for _ in range(n_users)], 60 | 'user_data': {} 61 | } 62 | if agg_user > 0: 63 | assert len(user_lists) % agg_user == 0 64 | agg_n_users = len(user_lists) // agg_user 65 | agg_data = { 66 | 'users': [None for _ in range(agg_n_users)], 67 | 'num_samples': [None for _ in range(agg_n_users)], 68 | 'user_data': {} 69 | 70 | } 71 | 72 | def agg_by_user_(new_data, agg_n_users, agg_user, verbose=False): 73 | for batch_id in range(agg_n_users): 74 | start_id, end_id = batch_id * agg_user, (batch_id + 1) * agg_user 75 | X, Y, N_samples = [], [], 0 76 | for idx in range(start_id, end_id): 77 | user_uname='f_{0:05d}'.format(idx) 78 | x = new_data['user_data'][user_uname]['x'] 79 | y = new_data['user_data'][user_uname]['y'] 80 | n_samples = new_data['num_samples'][idx] 81 | X += x 82 | Y += y 83 | N_samples += n_samples 84 | 85 | ##### 86 | batch_user_name = 'f_{0:05d}'.format(batch_id) 87 | agg_data['users'][batch_id]= batch_user_name 88 | agg_data['num_samples'][batch_id]=len(Y) 89 | agg_data['user_data'][batch_user_name]={ 90 | 'x': torch.Tensor(X).type(torch.float32).permute(0, 3, 1, 2), 91 | 'y': torch.Tensor(Y).type(torch.int64) 92 | } 93 | ##### 94 | 95 | 96 | def load_per_user_(user_data, idx, verbose=False): 97 | """ 98 | # Reduce test samples per user to ratio. 99 | :param uname: 100 | :param idx: 101 | :return: 102 | """ 103 | new_uname='f_{0:05d}'.format(idx) 104 | X=user_data['x'] 105 | y=user_data['y'] 106 | assert len(X) == len(y) 107 | new_data['users'][idx] = new_uname 108 | new_data['num_samples'][idx] = len(y) 109 | new_data['user_data'][new_uname] = {'y':y} 110 | # load X as images 111 | loaded_X = [] 112 | for i, image_name in enumerate(X): 113 | image_path = os.path.join(IMG_DIR, image_name) 114 | image = Image.open(image_path) 115 | image=image.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB') 116 | image = np.array(image) 117 | loaded_X.append(image) 118 | new_data['user_data'][new_uname] = {'x': np.array(loaded_X).tolist(), 'y':y} 119 | if verbose: 120 | print("processing user {}".format(new_uname)) 121 | 122 | for idx, uname in enumerate(user_lists): 123 | user_data = cdata['user_data'][uname] 124 | load_per_user_(user_data, idx, verbose=True) 125 | #pass 126 | if agg_user == -1: 127 | return new_data 128 | else: 129 | agg_by_user_(new_data, agg_n_users, agg_user, verbose=False) 130 | return agg_data 131 | 132 | def process_data(): 133 | load_path = os.path.join(LOAD_PATH, 'data') 134 | train_files = [f for f in os.listdir(os.path.join(load_path, 'train'))] 135 | test_files = [f for f in os.listdir(os.path.join(load_path, 'test'))] 136 | 137 | def sample_users(cdata, ratio=0.1, excludes=set()): 138 | """ 139 | :param cdata: 140 | :param ratio: 141 | :return: list of sampled user names 142 | """ 143 | user_lists=[u for u in cdata['users']] 144 | if ratio <= 1: 145 | n_selected_users=int(len(user_lists) * ratio) 146 | else: 147 | n_selected_users = ratio 148 | random.shuffle(user_lists) 149 | new_users = [] 150 | i = 0 151 | for u in user_lists: 152 | if u not in excludes: 153 | new_users.append(u) 154 | i += 1 155 | 
if i == n_selected_users: 156 | return new_users 157 | 158 | 159 | def process_(mode, tf, ratio=0.1, user_lists=None, agg_user=-1): 160 | read_path = os.path.join(load_path, mode if mode != 'proxy' else 'train', tf) 161 | with open(read_path, 'r') as inf: 162 | cdata=json.load(inf) 163 | n_users = len(cdata['users']) 164 | if ratio > 1: 165 | assert ratio < n_users 166 | else: 167 | assert ratio < 1 168 | print("Number of users: {}".format(n_users)) 169 | print("Number of raw {} samples: {:.1f}".format(mode, np.mean(cdata['num_samples']))) 170 | print("Deviation of raw {} samples: {:.1f}".format(mode, np.std(cdata['num_samples']))) 171 | #exit() 172 | if mode == 'train': 173 | assert user_lists == None 174 | user_lists = sample_users(cdata, ratio) 175 | new_data=load_data(user_lists, cdata, agg_user=agg_user) 176 | else: # test mode 177 | assert len(user_lists) > 0 178 | new_data = load_data(user_lists, cdata, agg_user=agg_user) 179 | print("Number of reduced users: {}".format(len(new_data['num_samples']))) 180 | print("Number of samples per user: {}".format(new_data['num_samples'])) 181 | 182 | if ratio > 1: 183 | n_users = int(ratio) 184 | if agg_user > 0: 185 | n_users = int( n_users // agg_user) 186 | else: 187 | n_users = int(len(cdata['users']) * ratio) 188 | if agg_user > 0: 189 | dump_path=os.path.join(DUMP_PATH, 'user{}-agg{}'.format(n_users, agg_user)) 190 | else: 191 | dump_path=os.path.join(DUMP_PATH, 'user{}'.format(n_users)) 192 | os.system("mkdir -p {}".format(os.path.join(dump_path, mode))) 193 | 194 | dump_path = os.path.join(dump_path, '{}/{}.pt'.format(mode,mode)) 195 | with open(dump_path, 'wb') as outfile: 196 | print("Saving {} data to {}".format(mode, dump_path)) 197 | torch.save(new_data, outfile) 198 | return user_lists 199 | 200 | # 201 | mode='train' 202 | tf = train_files[0] 203 | user_lists = process_(mode, tf, ratio=args.ratio, agg_user=args.agg_user) 204 | mode = 'test' 205 | tf = test_files[0] 206 | process_(mode, tf, ratio=args.ratio, user_lists=user_lists, agg_user=args.agg_user) 207 | 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument("--agg_user", type=int, default=10, help="number of celebrities to be aggregated together as a device/client (as meta-batch size).") 212 | parser.add_argument("--ratio", type=float, default=250, help="Number of total celebrities to be sampled for FL training.") 213 | args = parser.parse_args() 214 | print("Number of FL devices: {}".format(args.ratio // args.agg_user )) 215 | process_data() -------------------------------------------------------------------------------- /data/Mnist/generate_niid_dirichlet.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import numpy as np 3 | import random 4 | import json 5 | import os 6 | import argparse 7 | from torchvision.datasets import MNIST 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | 12 | random.seed(42) 13 | np.random.seed(42) 14 | 15 | def rearrange_data_by_class(data, targets, n_class): 16 | new_data = [] 17 | for i in trange(n_class): 18 | idx = targets == i 19 | new_data.append(data[idx]) 20 | return new_data 21 | 22 | def get_dataset(mode='train'): 23 | transform = transforms.Compose( 24 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) 25 | 26 | dataset = MNIST(root='./data', train=True if mode=='train' else False, download=True, transform=transform) 27 | n_sample = 
len(dataset.data) 28 | SRC_N_CLASS = len(dataset.classes) 29 | # full batch 30 | trainloader = DataLoader(dataset, batch_size=n_sample, shuffle=False) 31 | 32 | print("Loading data from storage ...") 33 | for _, xy in enumerate(trainloader, 0): 34 | dataset.data, dataset.targets = xy 35 | 36 | print("Rearrange data by class...") 37 | data_by_class = rearrange_data_by_class( 38 | dataset.data.cpu().detach().numpy(), 39 | dataset.targets.cpu().detach().numpy(), 40 | SRC_N_CLASS 41 | ) 42 | print(f"{mode.upper()} SET:\n Total #samples: {n_sample}. sample shape: {dataset.data[0].shape}") 43 | print(" #samples per class:\n", [len(v) for v in data_by_class]) 44 | 45 | return data_by_class, n_sample, SRC_N_CLASS 46 | 47 | def sample_class(SRC_N_CLASS, NUM_LABELS, user_id, label_random=False): 48 | assert NUM_LABELS <= SRC_N_CLASS 49 | if label_random: 50 | source_classes = [n for n in range(SRC_N_CLASS)] 51 | random.shuffle(source_classes) 52 | return source_classes[:NUM_LABELS] 53 | else: 54 | return [(user_id + j) % SRC_N_CLASS for j in range(NUM_LABELS)] 55 | 56 | def devide_train_data(data, n_sample, SRC_CLASSES, NUM_USERS, min_sample, alpha=0.5, sampling_ratio=0.5): 57 | min_sample = 10#len(SRC_CLASSES) * min_sample 58 | min_size = 0 # track minimal samples per user 59 | ###### Determine Sampling ####### 60 | while min_size < min_sample: 61 | print("Try to find valid data separation") 62 | idx_batch=[{} for _ in range(NUM_USERS)] 63 | samples_per_user = [0 for _ in range(NUM_USERS)] 64 | max_samples_per_user = sampling_ratio * n_sample / NUM_USERS 65 | for l in SRC_CLASSES: 66 | # get indices for all that label 67 | idx_l = [i for i in range(len(data[l]))] 68 | np.random.shuffle(idx_l) 69 | if sampling_ratio < 1: 70 | samples_for_l = int( min(max_samples_per_user, int(sampling_ratio * len(data[l]))) ) 71 | idx_l = idx_l[:samples_for_l] 72 | print(l, len(data[l]), len(idx_l)) 73 | # dirichlet sampling from this label 74 | proportions=np.random.dirichlet(np.repeat(alpha, NUM_USERS)) 75 | # re-balance proportions 76 | proportions=np.array([p * (n_per_user < max_samples_per_user) for p, n_per_user in zip(proportions, samples_per_user)]) 77 | proportions=proportions / proportions.sum() 78 | proportions=(np.cumsum(proportions) * len(idx_l)).astype(int)[:-1] 79 | # participate data of that label 80 | for u, new_idx in enumerate(np.split(idx_l, proportions)): 81 | # add new idex to the user 82 | idx_batch[u][l] = new_idx.tolist() 83 | samples_per_user[u] += len(idx_batch[u][l]) 84 | min_size=min(samples_per_user) 85 | 86 | ###### CREATE USER DATA SPLIT ####### 87 | X = [[] for _ in range(NUM_USERS)] 88 | y = [[] for _ in range(NUM_USERS)] 89 | Labels=[set() for _ in range(NUM_USERS)] 90 | print("processing users...") 91 | for u, user_idx_batch in enumerate(idx_batch): 92 | for l, indices in user_idx_batch.items(): 93 | if len(indices) == 0: continue 94 | X[u] += data[l][indices].tolist() 95 | y[u] += (l * np.ones(len(indices))).tolist() 96 | Labels[u].add(l) 97 | 98 | return X, y, Labels, idx_batch, samples_per_user 99 | 100 | def divide_test_data(NUM_USERS, SRC_CLASSES, test_data, Labels, unknown_test): 101 | # Create TEST data for each user. 
102 | test_X = [[] for _ in range(NUM_USERS)] 103 | test_y = [[] for _ in range(NUM_USERS)] 104 | idx = {l: 0 for l in SRC_CLASSES} 105 | for user in trange(NUM_USERS): 106 | if unknown_test: # use all available labels 107 | user_sampled_labels = SRC_CLASSES 108 | else: 109 | user_sampled_labels = list(Labels[user]) 110 | for l in user_sampled_labels: 111 | num_samples = int(len(test_data[l]) / NUM_USERS ) 112 | assert num_samples + idx[l] <= len(test_data[l]) 113 | test_X[user] += test_data[l][idx[l]:idx[l] + num_samples].tolist() 114 | test_y[user] += (l * np.ones(num_samples)).tolist() 115 | assert len(test_X[user]) == len(test_y[user]), f"{len(test_X[user])} == {len(test_y[user])}" 116 | idx[l] += num_samples 117 | return test_X, test_y 118 | 119 | def main(): 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--format", "-f", type=str, default="pt", help="Format of saving: pt (torch.save), json", choices=["pt", "json"]) 122 | parser.add_argument("--n_class", type=int, default=10, help="number of classification labels") 123 | parser.add_argument("--min_sample", type=int, default=10, help="Min number of samples per user.") 124 | parser.add_argument("--sampling_ratio", type=float, default=0.05, help="Ratio for sampling training samples.") 125 | parser.add_argument("--unknown_test", type=int, default=0, help="Whether allow test label unseen for each user.") 126 | parser.add_argument("--alpha", type=float, default=0.01, help="alpha in Dirichelt distribution (smaller means larger heterogeneity)") 127 | parser.add_argument("--n_user", type=int, default=20, 128 | help="number of local clients, should be muitiple of 10.") 129 | args = parser.parse_args() 130 | print() 131 | print("Number of users: {}".format(args.n_user)) 132 | print("Number of classes: {}".format(args.n_class)) 133 | print("Min # of samples per uesr: {}".format(args.min_sample)) 134 | print("Alpha for Dirichlet Distribution: {}".format(args.alpha)) 135 | print("Ratio for Sampling Training Data: {}".format(args.sampling_ratio)) 136 | NUM_USERS = args.n_user 137 | 138 | # Setup directory for train/test data 139 | path_prefix = f'u{args.n_user}c{args.n_class}-alpha{args.alpha}-ratio{args.sampling_ratio}' 140 | 141 | def process_user_data(mode, data, n_sample, SRC_CLASSES, Labels=None, unknown_test=0): 142 | if mode == 'train': 143 | X, y, Labels, idx_batch, samples_per_user = devide_train_data( 144 | data, n_sample, SRC_CLASSES, NUM_USERS, args.min_sample, args.alpha, args.sampling_ratio) 145 | if mode == 'test': 146 | assert Labels != None or unknown_test 147 | X, y = divide_test_data(NUM_USERS, SRC_CLASSES, data, Labels, unknown_test) 148 | dataset={'users': [], 'user_data': {}, 'num_samples': []} 149 | for i in range(NUM_USERS): 150 | uname='f_{0:05d}'.format(i) 151 | dataset['users'].append(uname) 152 | dataset['user_data'][uname]={ 153 | 'x': torch.tensor(X[i], dtype=torch.float32), 154 | 'y': torch.tensor(y[i], dtype=torch.int64)} 155 | dataset['num_samples'].append(len(X[i])) 156 | 157 | print("{} #sample by user:".format(mode.upper()), dataset['num_samples']) 158 | 159 | data_path=f'./{path_prefix}/{mode}' 160 | if not os.path.exists(data_path): 161 | os.makedirs(data_path) 162 | 163 | data_path=os.path.join(data_path, "{}.".format(mode) + args.format) 164 | if args.format == "json": 165 | raise NotImplementedError( 166 | "json is not supported because the train_data/test_data uses the tensor instead of list and tensor cannot be saved into json.") 167 | with open(data_path, 'w') as outfile: 168 | 
print(f"Dumping train data => {data_path}") 169 | json.dump(dataset, outfile) 170 | elif args.format == "pt": 171 | with open(data_path, 'wb') as outfile: 172 | print(f"Dumping train data => {data_path}") 173 | torch.save(dataset, outfile) 174 | if mode == 'train': 175 | for u in range(NUM_USERS): 176 | print("{} samples in total".format(samples_per_user[u])) 177 | train_info = '' 178 | # train_idx_batch, train_samples_per_user 179 | n_samples_for_u = 0 180 | for l in sorted(list(Labels[u])): 181 | n_samples_for_l = len(idx_batch[u][l]) 182 | n_samples_for_u += n_samples_for_l 183 | train_info += "c={},n={}| ".format(l, n_samples_for_l) 184 | print(train_info) 185 | print("{} Labels/ {} Number of training samples for user [{}]:".format(len(Labels[u]), n_samples_for_u, u)) 186 | return Labels, idx_batch, samples_per_user 187 | 188 | 189 | print(f"Reading source dataset.") 190 | train_data, n_train_sample, SRC_N_CLASS = get_dataset(mode='train') 191 | test_data, n_test_sample, SRC_N_CLASS = get_dataset(mode='test') 192 | SRC_CLASSES=[l for l in range(SRC_N_CLASS)] 193 | random.shuffle(SRC_CLASSES) 194 | print("{} labels in total.".format(len(SRC_CLASSES))) 195 | Labels, idx_batch, samples_per_user = process_user_data('train', train_data, n_train_sample, SRC_CLASSES) 196 | process_user_data('test', test_data, n_test_sample, SRC_CLASSES, Labels=Labels, unknown_test=args.unknown_test) 197 | print("Finish Generating User samples") 198 | 199 | if __name__ == "__main__": 200 | main() -------------------------------------------------------------------------------- /FLAlgorithms/trainmodel/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | MAXLOG = 0.1 5 | from torch.autograd import Variable 6 | import collections 7 | import numpy as np 8 | from utils.model_config import GENERATORCONFIGS 9 | 10 | 11 | class Generator(nn.Module): 12 | def __init__(self, dataset='mnist', model='cnn', embedding=False, latent_layer_idx=-1): 13 | super(Generator, self).__init__() 14 | print("Dataset {}".format(dataset)) 15 | self.embedding = embedding 16 | self.dataset = dataset 17 | #self.model=model 18 | self.latent_layer_idx = latent_layer_idx 19 | self.hidden_dim, self.latent_dim, self.input_channel, self.n_class, self.noise_dim = GENERATORCONFIGS[dataset] 20 | input_dim = self.noise_dim * 2 if self.embedding else self.noise_dim + self.n_class 21 | self.fc_configs = [input_dim, self.hidden_dim] 22 | self.init_loss_fn() 23 | self.build_network() 24 | 25 | def get_number_of_parameters(self): 26 | pytorch_total_params=sum(p.numel() for p in self.parameters() if p.requires_grad) 27 | return pytorch_total_params 28 | 29 | def init_loss_fn(self): 30 | self.crossentropy_loss=nn.NLLLoss(reduce=False) # same as above 31 | self.diversity_loss = DiversityLoss(metric='l1') 32 | self.dist_loss = nn.MSELoss() 33 | 34 | def build_network(self): 35 | if self.embedding: 36 | self.embedding_layer = nn.Embedding(self.n_class, self.noise_dim) 37 | ### FC modules #### 38 | self.fc_layers = nn.ModuleList() 39 | self.fc_configs[0] += 1 40 | for i in range(len(self.fc_configs) - 1): 41 | input_dim, out_dim = self.fc_configs[i], self.fc_configs[i + 1] 42 | print("Build layer {} X {}".format(input_dim, out_dim)) 43 | fc = nn.Linear(input_dim, out_dim) 44 | bn = nn.BatchNorm1d(out_dim) 45 | act = nn.ReLU() 46 | self.fc_layers += [fc, bn, act] 47 | ### Representation layer 48 | self.representation_layer = 
nn.Linear(self.fc_configs[-1], self.latent_dim) 49 | print("Build last layer {} X {}".format(self.fc_configs[-1], self.latent_dim)) 50 | 51 | def forward(self, labels, Real_CL_Results, latent_layer_idx=-1, verbose=True): 52 | """ 53 | G(Z|y) or G(X|y): 54 | Generate either latent representation( latent_layer_idx < 0) or raw image (latent_layer_idx=0) conditional on labels. 55 | :param labels: 56 | :param latent_layer_idx: 57 | if -1, generate latent representation of the last layer, 58 | -2 for the 2nd to last layer, 0 for raw images. 59 | :param verbose: also return the sampled Gaussian noise if verbose = True 60 | :return: a dictionary of output information. 61 | """ 62 | result = {} 63 | batch_size = labels.shape[0] 64 | eps = torch.rand((batch_size, self.noise_dim)) # sampling from Gaussian 65 | if verbose: 66 | result['eps'] = eps 67 | if self.embedding: # embedded dense vector 68 | y_input = self.embedding_layer(labels) 69 | else: # one-hot (sparse) vector 70 | y_input = torch.FloatTensor(batch_size, self.n_class) 71 | y_input.zero_() 72 | #labels = labels.view 73 | labels_int64 = labels.type(torch.LongTensor) 74 | y_input.scatter_(1, labels_int64.view(-1,1), 1) 75 | z = torch.cat((eps, y_input), dim=1) 76 | #print(z.size()) 77 | #z = z * Real_CL_Resultsw 78 | 79 | z = torch.cat((z, Real_CL_Results), dim=1) 80 | 81 | ### FC layers 82 | for layer in self.fc_layers: 83 | z = layer(z) 84 | z = self.representation_layer(z) 85 | result['output'] = z 86 | return result 87 | 88 | @staticmethod 89 | def normalize_images(layer): 90 | """ 91 | Normalize images into zero-mean and unit-variance. 92 | """ 93 | mean = layer.mean(dim=(2, 3), keepdim=True) 94 | std = layer.view((layer.size(0), layer.size(1), -1)) \ 95 | .std(dim=2, keepdim=True).unsqueeze(3) 96 | return (layer - mean) / std 97 | 98 | class Discriminator(nn.Module): 99 | def __init__(self, size=0,dataset='mnist', model='cnn', embedding=False): 100 | super(Discriminator, self).__init__() 101 | self.size = size 102 | self.embedding = embedding 103 | self.dataset = dataset 104 | self.hidden_dim, self.latent_dim, self.input_channel, self.n_class, self.noise_dim = GENERATORCONFIGS[dataset] 105 | self.init_loss_fn() 106 | self.model_1 = nn.Sequential( 107 | nn.Linear(self.latent_dim, 128), 108 | nn.LeakyReLU(0.2, inplace=True), 109 | nn.Linear(128, 64), 110 | nn.LeakyReLU(0.2, inplace=True), 111 | nn.Linear(64, self.latent_dim), 112 | nn.Sigmoid(), 113 | ) 114 | self.model_2 = nn.Sequential( 115 | nn.Linear(self.latent_dim, 128), 116 | nn.LeakyReLU(0.2, inplace=True), 117 | nn.Linear(128, 64), 118 | nn.LeakyReLU(0.2, inplace=True), 119 | nn.Linear(64, 1), 120 | nn.Sigmoid(), 121 | ) 122 | def forward(self, img,cl_score): 123 | img = self.model_1(img).squeeze(-1) 124 | cl_score = self.model_2(cl_score) 125 | return img, cl_score 126 | 127 | def init_loss_fn(self): 128 | self.crossentropy_loss=nn.NLLLoss(reduce=False) # same as above 129 | self.diversity_loss = DiversityLoss(metric='l1') 130 | self.dist_loss = nn.MSELoss() 131 | # 132 | # class Decoder(nn.Module): 133 | # """ 134 | # Decoder for both unstructured and image datasets. 135 | # """ 136 | # def __init__(self, dataset='mnist', latent_layer_idx=-1, n_layers=2, units=32): 137 | # """ 138 | # Class initializer. 
139 | # """ 140 | # #in_features, out_targets, n_layers=2, units=32): 141 | # super(Decoder, self).__init__() 142 | # self.cv_configs, self.input_channel, self.n_class, self.scale, self.noise_dim = GENERATORCONFIGS[dataset] 143 | # self.hidden_dim = self.scale * self.scale * self.cv_configs[0] 144 | # self.latent_dim = self.cv_configs[0] * 2 145 | # self.represent_dims = [self.hidden_dim, self.latent_dim] 146 | # in_features = self.represent_dims[latent_layer_idx] 147 | # out_targets = self.noise_dim 148 | # 149 | # # build layer structure 150 | # layers = [nn.Linear(in_features, units), 151 | # nn.ELU(), 152 | # nn.BatchNorm1d(units)] 153 | # 154 | # for _ in range(n_layers): 155 | # layers.extend([ 156 | # nn.Linear(units, units), 157 | # nn.ELU(), 158 | # nn.BatchNorm1d(units)]) 159 | # 160 | # layers.append(nn.Linear(units, out_targets)) 161 | # self.layers = nn.Sequential(*layers) 162 | # 163 | # def forward(self, x): 164 | # """ 165 | # Forward propagation. 166 | # """ 167 | # out = x.view((x.size(0), -1)) 168 | # out = self.layers(out) 169 | # return out 170 | 171 | class DivLoss(nn.Module): 172 | """ 173 | Diversity loss for improving the performance. 174 | """ 175 | 176 | def __init__(self): 177 | """ 178 | Class initializer. 179 | """ 180 | super().__init__() 181 | 182 | def forward2(self, noises, layer): 183 | """ 184 | Forward propagation. 185 | """ 186 | if len(layer.shape) > 2: 187 | layer = layer.view((layer.size(0), -1)) 188 | chunk_size = layer.size(0) // 2 189 | 190 | ####### diversity loss ######## 191 | eps1, eps2=torch.split(noises, chunk_size, dim=0) 192 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0) 193 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean( 194 | torch.abs(eps1 - eps2)) 195 | eps=1 * 1e-5 196 | diversity_loss=1 / (lz + eps) 197 | return diversity_loss 198 | 199 | def forward(self, noises, layer): 200 | """ 201 | Forward propagation. 202 | """ 203 | if len(layer.shape) > 2: 204 | layer=layer.view((layer.size(0), -1)) 205 | chunk_size=layer.size(0) // 2 206 | 207 | ####### diversity loss ######## 208 | eps1, eps2=torch.split(noises, chunk_size, dim=0) 209 | chunk1, chunk2=torch.split(layer, chunk_size, dim=0) 210 | lz=torch.mean(torch.abs(chunk1 - chunk2)) / torch.mean( 211 | torch.abs(eps1 - eps2)) 212 | eps=1 * 1e-5 213 | diversity_loss=1 / (lz + eps) 214 | return diversity_loss 215 | 216 | class DiversityLoss(nn.Module): 217 | """ 218 | Diversity loss for improving the performance. 219 | """ 220 | 221 | def __init__(self, metric): 222 | """ 223 | Class initializer. 224 | """ 225 | super().__init__() 226 | self.metric = metric 227 | self.cosine = nn.CosineSimilarity(dim=2) 228 | 229 | def compute_distance(self, tensor1, tensor2, metric): 230 | """ 231 | Compute the distance between two tensors. 232 | """ 233 | if metric == 'l1': 234 | return torch.abs(tensor1 - tensor2).mean(dim=(2,)) 235 | elif metric == 'l2': 236 | return torch.pow(tensor1 - tensor2, 2).mean(dim=(2,)) 237 | elif metric == 'cosine': 238 | return 1 - self.cosine(tensor1, tensor2) 239 | else: 240 | raise ValueError(metric) 241 | 242 | def pairwise_distance(self, tensor, how): 243 | """ 244 | Compute the pairwise distances between a Tensor's rows. 245 | """ 246 | n_data = tensor.size(0) 247 | tensor1 = tensor.expand((n_data, n_data, tensor.size(1))) 248 | tensor2 = tensor.unsqueeze(dim=1) 249 | return self.compute_distance(tensor1, tensor2, how) 250 | 251 | def forward(self, noises, layer): 252 | """ 253 | Forward propagation. 
254 | """ 255 | if len(layer.shape) > 2: 256 | layer = layer.view((layer.size(0), -1)) 257 | layer_dist = self.pairwise_distance(layer, how=self.metric) 258 | noise_dist = self.pairwise_distance(noises, how='l2') 259 | return torch.exp(torch.mean(-noise_dist * layer_dist)) 260 | -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverbase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import h5py 5 | from utils.model_utils import get_dataset_name, RUNCONFIGS 6 | import copy 7 | import torch.nn.functional as F 8 | import time 9 | import torch.nn as nn 10 | from utils.model_utils import get_log_path, METRICS 11 | 12 | class Server: 13 | def __init__(self, args, model, seed): 14 | 15 | # Set up the main attributes 16 | self.dataset = args.dataset 17 | self.num_glob_iters = args.num_glob_iters 18 | self.local_epochs = args.local_epochs 19 | self.batch_size = args.batch_size 20 | self.learning_rate = args.learning_rate 21 | self.total_train_samples = 0 22 | self.K = args.K 23 | self.model = copy.deepcopy(model[0]) 24 | self.model_name = model[1] 25 | self.users = [] 26 | self.selected_users = [] 27 | self.num_users = args.num_users 28 | self.beta = args.beta 29 | self.lamda = args.lamda 30 | self.algorithm = args.algorithm 31 | self.personalized = 'pFed' in self.algorithm 32 | self.mode='partial' if 'partial' in self.algorithm.lower() else 'all' 33 | self.seed = seed 34 | self.deviations = {} 35 | self.metrics = {key:[] for key in METRICS} 36 | self.timestamp = None 37 | self.save_path = args.result_path 38 | os.system("mkdir -p {}".format(self.save_path)) 39 | 40 | 41 | def init_ensemble_configs(self): 42 | #### used for ensemble learning #### 43 | dataset_name = get_dataset_name(self.dataset) 44 | self.ensemble_lr = RUNCONFIGS[dataset_name].get('ensemble_lr', 1e-4) 45 | self.ensemble_batch_size = RUNCONFIGS[dataset_name].get('ensemble_batch_size', 128) 46 | self.ensemble_epochs = RUNCONFIGS[dataset_name]['ensemble_epochs'] 47 | self.num_pretrain_iters = RUNCONFIGS[dataset_name]['num_pretrain_iters'] 48 | self.temperature = RUNCONFIGS[dataset_name].get('temperature', 1) 49 | self.unique_labels = RUNCONFIGS[dataset_name]['unique_labels'] 50 | self.ensemble_alpha = RUNCONFIGS[dataset_name].get('ensemble_alpha', 1) 51 | self.ensemble_beta = RUNCONFIGS[dataset_name].get('ensemble_beta', 0) 52 | self.ensemble_eta = RUNCONFIGS[dataset_name].get('ensemble_eta', 1) 53 | self.weight_decay = RUNCONFIGS[dataset_name].get('weight_decay', 0) 54 | self.generative_alpha = RUNCONFIGS[dataset_name]['generative_alpha'] 55 | self.generative_beta = RUNCONFIGS[dataset_name]['generative_beta'] 56 | self.ensemble_train_loss = [] 57 | self.n_teacher_iters = 5 58 | self.n_student_iters = 1 59 | print("ensemble_lr: {}".format(self.ensemble_lr) ) 60 | print("ensemble_batch_size: {}".format(self.ensemble_batch_size) ) 61 | print("unique_labels: {}".format(self.unique_labels) ) 62 | 63 | 64 | def if_personalized(self): 65 | return 'pFed' in self.algorithm or 'PerAvg' in self.algorithm 66 | 67 | def if_ensemble(self): 68 | return 'FedE' in self.algorithm 69 | 70 | def send_parameters(self, mode='all', beta=1, selected=False): 71 | users = self.users 72 | if selected: 73 | assert (self.selected_users is not None and len(self.selected_users) > 0) 74 | users = self.selected_users 75 | for user in users: 76 | if mode == 'all': # share only subset of parameters 77 | 
user.set_parameters(self.model,beta=beta) 78 | else: # share all parameters 79 | user.set_shared_parameters(self.model,mode=mode) 80 | 81 | 82 | def add_parameters(self, user, ratio, partial=False): 83 | if partial: 84 | for server_param, user_param in zip(self.model.get_shared_parameters(), user.model.get_shared_parameters()): 85 | server_param.data = server_param.data + user_param.data.clone() * ratio 86 | else: 87 | for server_param, user_param in zip(self.model.parameters(), user.model.parameters()): 88 | server_param.data = server_param.data + user_param.data.clone() * ratio 89 | 90 | 91 | 92 | def aggregate_parameters(self,partial=False): 93 | assert (self.selected_users is not None and len(self.selected_users) > 0) 94 | if partial: 95 | for param in self.model.get_shared_parameters(): 96 | param.data = torch.zeros_like(param.data) 97 | else: 98 | for param in self.model.parameters(): 99 | param.data = torch.zeros_like(param.data) 100 | total_train = 0 101 | for user in self.selected_users: 102 | total_train += user.train_samples 103 | for user in self.selected_users: 104 | self.add_parameters(user, user.train_samples / total_train,partial=partial) 105 | 106 | 107 | def save_model(self): 108 | model_path = os.path.join("models", self.dataset) 109 | if not os.path.exists(model_path): 110 | os.makedirs(model_path) 111 | torch.save(self.model, os.path.join(model_path, "server" + ".pt")) 112 | 113 | 114 | def load_model(self): 115 | model_path = os.path.join("models", self.dataset, "server" + ".pt") 116 | assert (os.path.exists(model_path)) 117 | self.model = torch.load(model_path) 118 | 119 | def model_exists(self): 120 | return os.path.exists(os.path.join("models", self.dataset, "server" + ".pt")) 121 | 122 | def select_users(self, round, num_users, return_idx=False): 123 | '''selects num_clients clients weighted by number of samples from possible_clients 124 | Args: 125 | num_clients: number of clients to select; default 20 126 | note that within function, num_clients is set to 127 | min(num_clients, len(possible_clients)) 128 | Return: 129 | list of selected clients objects 130 | ''' 131 | if(num_users == len(self.users)): 132 | print("All users are selected") 133 | return self.users 134 | 135 | num_users = min(num_users, len(self.users)) 136 | if return_idx: 137 | user_idxs = np.random.choice(range(len(self.users)), num_users, replace=False) # , p=pk) 138 | return [self.users[i] for i in user_idxs], user_idxs 139 | else: 140 | return np.random.choice(self.users, num_users, replace=False) 141 | 142 | 143 | def init_loss_fn(self): 144 | self.loss=nn.NLLLoss() 145 | self.ensemble_loss=nn.KLDivLoss(reduction="batchmean")#,log_target=True) 146 | self.ce_loss = nn.CrossEntropyLoss() 147 | 148 | 149 | def save_results(self, args): 150 | alg = get_log_path(args, args.algorithm, self.seed, args.gen_batch_size) 151 | with h5py.File("./{}/{}.h5".format(self.save_path, alg), 'w') as hf: 152 | for key in self.metrics: 153 | hf.create_dataset(key, data=self.metrics[key]) 154 | hf.close() 155 | 156 | 157 | def test(self, selected=False): 158 | '''tests self.latest_model on given clients 159 | ''' 160 | num_samples = [] 161 | tot_correct = [] 162 | losses = [] 163 | users = self.selected_users if selected else self.users 164 | for c in users: 165 | ct, c_loss, ns = c.test() 166 | tot_correct.append(ct*1.0) 167 | num_samples.append(ns) 168 | losses.append(c_loss) 169 | ids = [c.id for c in self.users] 170 | 171 | return ids, num_samples, tot_correct, losses 172 | 173 | 174 | 175 | def 
test_personalized_model(self, selected=True): 176 | '''tests self.latest_model on given clients 177 | ''' 178 | num_samples = [] 179 | tot_correct = [] 180 | losses = [] 181 | users = self.selected_users if selected else self.users 182 | for c in users: 183 | ct, ns, loss = c.test_personalized_model() 184 | tot_correct.append(ct*1.0) 185 | num_samples.append(ns) 186 | losses.append(loss) 187 | ids = [c.id for c in self.users] 188 | 189 | return ids, num_samples, tot_correct, losses 190 | 191 | def evaluate_personalized_model(self, selected=True, save=True): 192 | stats = self.test_personalized_model(selected=selected) 193 | test_ids, test_num_samples, test_tot_correct, test_losses = stats[:4] 194 | glob_acc = np.sum(test_tot_correct)*1.0/np.sum(test_num_samples) 195 | test_loss = np.sum([x * y for (x, y) in zip(test_num_samples, test_losses)]).item() / np.sum(test_num_samples) 196 | if save: 197 | self.metrics['per_acc'].append(glob_acc) 198 | self.metrics['per_loss'].append(test_loss) 199 | print("Average Global Accuracy = {:.4f}, Loss = {:.2f}.".format(glob_acc, test_loss)) 200 | 201 | 202 | def evaluate_ensemble(self, selected=True): 203 | self.model.eval() 204 | users = self.selected_users if selected else self.users 205 | test_acc=0 206 | loss=0 207 | for x, y in self.testloaderfull: 208 | target_logit_output=0 209 | for user in users: 210 | # get user logit 211 | user.model.eval() 212 | user_result=user.model(x, logit=True) 213 | target_logit_output+=user_result['logit'] 214 | target_logp=F.log_softmax(target_logit_output, dim=1) 215 | test_acc+= torch.sum( torch.argmax(target_logp, dim=1) == y ) # count correct predictions 216 | loss+=self.loss(target_logp, y) 217 | loss = loss.detach().numpy() 218 | test_acc = test_acc.detach().numpy() / y.shape[0] 219 | self.metrics['glob_acc'].append(test_acc) 220 | self.metrics['glob_loss'].append(loss) 221 | print("Average Global Accuracy = {:.4f}, Loss = {:.2f}.".format(test_acc, loss)) 222 | 223 | 224 | def evaluate(self, save=True, selected=False): 225 | # override evaluate function to log vae-loss.
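# The aggregation below is sample-weighted across clients:
#   glob_acc  = sum_c correct_c / sum_c n_c,   glob_loss = sum_c (n_c * loss_c) / sum_c n_c,
# so clients with more test samples contribute proportionally more to the reported metrics.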
226 | test_ids, test_samples, test_accs, test_losses = self.test(selected=selected) 227 | glob_acc = np.sum(test_accs)*1.0/np.sum(test_samples) 228 | glob_loss = np.sum([x * y.detach() for (x, y) in zip(test_samples, test_losses)]).item() / np.sum(test_samples) 229 | if save: 230 | self.metrics['glob_acc'].append(glob_acc) 231 | self.metrics['glob_loss'].append(glob_loss) 232 | print("Average Global Accuracy = {:.4f}, Loss = {:.2f}.".format(glob_acc, glob_loss)) 233 | 234 | -------------------------------------------------------------------------------- /data/EMnist/generate_niid_dirichlet.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import numpy as np 3 | import random 4 | import json 5 | import os 6 | import argparse 7 | from torchvision.datasets import EMNIST 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | 13 | random.seed(42) 14 | np.random.seed(42) 15 | 16 | def rearrange_data_by_class(data, targets, n_class): 17 | new_data = [] 18 | for i in trange(n_class): 19 | idx = targets == i 20 | new_data.append(data[idx]) 21 | return new_data 22 | 23 | def get_dataset(mode='train', split='balanced'): 24 | transform = transforms.Compose( 25 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) 26 | 27 | dataset = EMNIST(root='./data', split=split, train=(mode=='train'), download=True, transform=transform) 28 | n_sample = len(dataset.data) 29 | SRC_N_CLASS = len(dataset.classes) 30 | # full batch 31 | trainloader = DataLoader(dataset, batch_size=n_sample, shuffle=False) 32 | 33 | print("Loading data from storage ...") 34 | for _, xy in enumerate(trainloader, 0): 35 | dataset.data, dataset.targets = xy 36 | 37 | print("Rearrange data by class...") 38 | data_by_class = rearrange_data_by_class( 39 | dataset.data.cpu().detach().numpy(), 40 | dataset.targets.cpu().detach().numpy(), 41 | SRC_N_CLASS 42 | ) 43 | if split == 'letters': 44 | data_by_class.pop(0) 45 | SRC_N_CLASS-=1 46 | 47 | print(f"{mode.upper()} SET:\n Total #samples: {n_sample}.
sample shape: {dataset.data[0].shape}") 48 | print(" #samples per class:\n", [len(v) for v in data_by_class]) 49 | 50 | return data_by_class, n_sample, SRC_N_CLASS 51 | 52 | def sample_class(SRC_N_CLASS, NUM_LABELS, user_id, label_random=False): 53 | assert NUM_LABELS <= SRC_N_CLASS 54 | if label_random: 55 | source_classes = [n for n in range(SRC_N_CLASS)] 56 | random.shuffle(source_classes) 57 | return source_classes[:NUM_LABELS] 58 | else: 59 | return [(user_id + j) % SRC_N_CLASS for j in range(NUM_LABELS)] 60 | 61 | def devide_train_data(data, n_sample, SRC_CLASSES, NUM_USERS, min_sample, alpha=0.5, sampling_ratio=0.5): 62 | min_sample = len(SRC_CLASSES) * min_sample 63 | min_size = 0 # track minimal samples per user 64 | ###### Determine Sampling ####### 65 | while min_size < min_sample: 66 | print("Try to find valid data separation") 67 | idx_batch=[{} for _ in range(NUM_USERS)] 68 | samples_per_user = [0 for _ in range(NUM_USERS)] 69 | max_samples_per_user = sampling_ratio * n_sample / NUM_USERS 70 | for l in SRC_CLASSES: 71 | # get indices for all that label 72 | idx_l = [i for i in range(len(data[l]))] 73 | np.random.shuffle(idx_l) 74 | if sampling_ratio < 1: 75 | samples_for_l = min(max_samples_per_user, int(sampling_ratio * len(data[l]))) 76 | idx_l = idx_l[:samples_for_l] 77 | print(l, len(data[l]), len(idx_l)) 78 | # dirichlet sampling from this label 79 | proportions=np.random.dirichlet(np.repeat(alpha, NUM_USERS)) 80 | # re-balance proportions 81 | proportions=np.array([p * (n_per_user < max_samples_per_user) for p, n_per_user in zip(proportions, samples_per_user)]) 82 | proportions=proportions / proportions.sum() 83 | proportions=(np.cumsum(proportions) * len(idx_l)).astype(int)[:-1] 84 | # participate data of that label 85 | for u, new_idx in enumerate(np.split(idx_l, proportions)): 86 | # add new idex to the user 87 | idx_batch[u][l] = new_idx.tolist() 88 | samples_per_user[u] += len(idx_batch[u][l]) 89 | min_size=min(samples_per_user) 90 | 91 | ###### CREATE USER DATA SPLIT ####### 92 | X = [[] for _ in range(NUM_USERS)] 93 | y = [[] for _ in range(NUM_USERS)] 94 | Labels=[set() for _ in range(NUM_USERS)] 95 | print("processing users...") 96 | for u, user_idx_batch in enumerate(idx_batch): 97 | for l, indices in user_idx_batch.items(): 98 | if len(indices) == 0: continue 99 | X[u] += data[l][indices].tolist() 100 | y[u] += (l * np.ones(len(indices))).tolist() 101 | Labels[u].add(l) 102 | 103 | return X, y, Labels, idx_batch, samples_per_user 104 | 105 | 106 | def divide_test_data(NUM_USERS, SRC_CLASSES, test_data, Labels, unknown_test): 107 | # Create TEST data for each user. 
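# Each user receives a contiguous, non-overlapping slice of every label it holds:
# num_samples = len(test_data[l]) // NUM_USERS per label l, tracked by the running offset idx[l];
# when unknown_test is set, every user instead draws from all SRC_CLASSES.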
108 | test_X = [[] for _ in range(NUM_USERS)] 109 | test_y = [[] for _ in range(NUM_USERS)] 110 | idx = {l: 0 for l in SRC_CLASSES} 111 | for user in trange(NUM_USERS): 112 | if unknown_test: # use all available labels 113 | user_sampled_labels = SRC_CLASSES 114 | else: 115 | user_sampled_labels = list(Labels[user]) 116 | for l in user_sampled_labels: 117 | num_samples = int(len(test_data[l]) / NUM_USERS ) 118 | assert num_samples + idx[l] <= len(test_data[l]) 119 | test_X[user] += test_data[l][idx[l]:idx[l] + num_samples].tolist() 120 | test_y[user] += (l * np.ones(num_samples)).tolist() 121 | assert len(test_X[user]) == len(test_y[user]), f"{len(test_X[user])} == {len(test_y[user])}" 122 | idx[l] += num_samples 123 | return test_X, test_y 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--format", "-f", type=str, default="pt", help="Format of saving: pt (torch.save), json", choices=["pt", "json"]) 128 | parser.add_argument("--n_class", type=int, default=5, help="number of classification labels") 129 | parser.add_argument("--random_label", type=int, default=1, help="randomly sample labels for each task.") 130 | parser.add_argument("--dirname", type=str, default='', help="Directory name, optional.") 131 | parser.add_argument("--split", type=str, default='letters', choices=["letters", "balanced"]) 132 | parser.add_argument("--min_sample", type=int, default=10, help="Min number of samples per user.") 133 | parser.add_argument("--sampling_ratio", type=float, default=0.5, help="Ratio for sampling training samples.") 134 | parser.add_argument("--unknown_test", type=int, default=0, help="Whether to allow test labels unseen by a user.") 135 | parser.add_argument("--alpha", type=float, default=0.5, help="alpha in Dirichlet distribution (smaller means larger heterogeneity)") 136 | parser.add_argument("--n_user", type=int, default=20, 137 | help="number of local clients, should be a multiple of 10.") 138 | args = parser.parse_args() 139 | print() 140 | print("Number of users: {}".format(args.n_user)) 141 | print("Number of classes: {}".format(args.n_class)) 142 | print("Min # of samples per user: {}".format(args.min_sample)) 143 | print("Alpha for Dirichlet Distribution: {}".format(args.alpha)) 144 | print("Ratio for Sampling Training Data: {}".format(args.sampling_ratio)) 145 | NUM_USERS = args.n_user 146 | 147 | # Setup directory for train/test data 148 | path_prefix = f'u{args.n_user}-{args.split}-alpha{args.alpha}-ratio{args.sampling_ratio}' 149 | 150 | def process_user_data(mode, data, n_sample, SRC_CLASSES, Labels=None, unknown_test=0): 151 | if mode == 'train': 152 | X, y, Labels, idx_batch, samples_per_user = devide_train_data( 153 | data, n_sample, SRC_CLASSES, NUM_USERS, args.min_sample, args.alpha, args.sampling_ratio) 154 | if mode == 'test': 155 | assert Labels is not None or unknown_test 156 | X, y = divide_test_data(NUM_USERS, SRC_CLASSES, data, Labels, unknown_test) 157 | dataset={'users': [], 'user_data': {}, 'num_samples': []} 158 | for i in range(NUM_USERS): 159 | uname='f_{0:05d}'.format(i) 160 | dataset['users'].append(uname) 161 | dataset['user_data'][uname]={ 162 | 'x': torch.tensor(X[i], dtype=torch.float32), 163 | 'y': torch.tensor(y[i], dtype=torch.int64)} 164 | dataset['num_samples'].append(len(X[i])) 165 | 166 | print("{} #sample by user:".format(mode.upper()), dataset['num_samples']) 167 | 168 | data_path=f'./{path_prefix}/{mode}' 169 | if not os.path.exists(data_path): 170 | os.makedirs(data_path) 171 | 172 | 
data_path=os.path.join(data_path, "{}.".format(mode) + args.format) 173 | if args.format == "json": 174 | raise NotImplementedError( 175 | "json is not supported because train_data/test_data use tensors instead of lists, and tensors cannot be saved into json.") 176 | with open(data_path, 'w') as outfile: 177 | print(f"Dumping {mode} data => {data_path}") 178 | json.dump(dataset, outfile) 179 | elif args.format == "pt": 180 | with open(data_path, 'wb') as outfile: 181 | print(f"Dumping {mode} data => {data_path}") 182 | torch.save(dataset, outfile) 183 | if mode == 'train': 184 | for u in range(NUM_USERS): 185 | print("{} samples in total".format(samples_per_user[u])) 186 | train_info = '' 187 | # train_idx_batch, train_samples_per_user 188 | n_samples_for_u = 0 189 | for l in sorted(list(Labels[u])): 190 | n_samples_for_l = len(idx_batch[u][l]) 191 | n_samples_for_u += n_samples_for_l 192 | train_info += "c={},n={}| ".format(l, n_samples_for_l) 193 | print(train_info) 194 | print("{} labels / {} training samples for user [{}]".format(len(Labels[u]), n_samples_for_u, u)) 195 | return Labels, idx_batch, samples_per_user 196 | 197 | 198 | print("Reading source dataset.") 199 | train_data, n_train_sample, SRC_N_CLASS = get_dataset(mode='train', split=args.split) 200 | test_data, n_test_sample, SRC_N_CLASS = get_dataset(mode='test', split=args.split) 201 | SRC_CLASSES=[l for l in range(SRC_N_CLASS)] 202 | random.shuffle(SRC_CLASSES) 203 | Labels, idx_batch, samples_per_user = process_user_data('train', train_data, n_train_sample, SRC_CLASSES) 204 | process_user_data('test', test_data, n_test_sample, SRC_CLASSES, Labels=Labels, unknown_test=args.unknown_test) 205 | print("Finished generating user samples") 206 | 207 | 208 | if __name__ == "__main__": 209 | main() -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverpFedGen.py: -------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userpFedGen import UserpFedGen 2 | from FLAlgorithms.servers.serverbase import Server 3 | from utils.model_utils import read_data, read_user_data, aggregate_user_data, create_generative_model 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from torchvision.utils import save_image 9 | import os 10 | import copy 11 | import time 12 | MIN_SAMPLES_PER_LABEL=1 13 | 14 | class FedGen(Server): 15 | def __init__(self, args, model, seed): 16 | super().__init__(args, model, seed) 17 | 18 | # Initialize data for all users 19 | data = read_data(args.dataset) 20 | # data contains: clients, groups, train_data, test_data, proxy_data 21 | clients = data[0] 22 | total_users = len(clients) 23 | self.total_test_samples = 0 24 | self.local = 'local' in self.algorithm.lower() 25 | self.use_adam = 'adam' in self.algorithm.lower() 26 | 27 | self.early_stop = 20 # stop using generated samples after 20 local epochs 28 | self.student_model = copy.deepcopy(self.model) 29 | self.generative_model = create_generative_model(args.dataset, args.algorithm, self.model_name, args.embedding) 30 | if not args.train: 31 | print('number of generator parameters: [{}]'.format(self.generative_model.get_number_of_parameters())) 32 | print('number of model parameters: [{}]'.format(self.model.get_number_of_parameters())) 33 | self.latent_layer_idx = self.generative_model.latent_layer_idx 34 | self.init_ensemble_configs() 35 | print("latent_layer_idx: {}".format(self.latent_layer_idx))
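# latent_layer_idx == -1 means the generator emits the latent representation consumed by the
# classifier's last layer (-2 would target the second-to-last layer); it is sampled as, e.g.,
#   gen_result = self.generative_model(y_input, latent_layer_idx=self.latent_layer_idx)
#   z = gen_result['output']
# as done in train_generator below.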
36 | print("label embedding {}".format(self.generative_model.embedding)) 37 | print("ensemeble learning rate: {}".format(self.ensemble_lr)) 38 | print("ensemeble alpha = {}, beta = {}, eta = {}".format(self.ensemble_alpha, self.ensemble_beta, self.ensemble_eta)) 39 | print("generator alpha = {}, beta = {}".format(self.generative_alpha, self.generative_beta)) 40 | self.init_loss_fn() 41 | self.train_data_loader, self.train_iter, self.available_labels = aggregate_user_data(data, args.dataset, self.ensemble_batch_size) 42 | self.generative_optimizer = torch.optim.Adam( 43 | params=self.generative_model.parameters(), 44 | lr=self.ensemble_lr, betas=(0.9, 0.999), 45 | eps=1e-08, weight_decay=self.weight_decay, amsgrad=False) 46 | self.generative_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( 47 | optimizer=self.generative_optimizer, gamma=0.98) 48 | self.optimizer = torch.optim.Adam( 49 | params=self.model.parameters(), 50 | lr=self.ensemble_lr, betas=(0.9, 0.999), 51 | eps=1e-08, weight_decay=0, amsgrad=False) 52 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.98) 53 | 54 | #### creating users #### 55 | self.users = [] 56 | for i in range(total_users): 57 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 58 | self.total_train_samples+=len(train_data) 59 | self.total_test_samples += len(test_data) 60 | id, train, test=read_user_data(i, data, dataset=args.dataset) 61 | user=UserpFedGen( 62 | args, id, model, self.generative_model, 63 | train_data, test_data, 64 | self.available_labels, self.latent_layer_idx, label_info, 65 | use_adam=self.use_adam) 66 | self.users.append(user) 67 | print("Number of Train/Test samples:", self.total_train_samples, self.total_test_samples) 68 | print("Data from {} users in total.".format(total_users)) 69 | print("Finished creating FedAvg server.") 70 | 71 | def train(self, args): 72 | #### pretraining 73 | for glob_iter in range(self.num_glob_iters): 74 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 75 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 76 | if not self.local: 77 | self.send_parameters(mode=self.mode)# broadcast averaged prediction model 78 | self.evaluate() 79 | chosen_verbose_user = np.random.randint(0, len(self.users)) 80 | self.timestamp = time.time() # log user-training start time 81 | for user_id, user in zip(self.user_idxs, self.selected_users): # allow selected users to train 82 | verbose= user_id == chosen_verbose_user 83 | # perform regularization using generated samples after the first communication round 84 | user.train( 85 | glob_iter, 86 | personalized=self.personalized, 87 | early_stop=self.early_stop, 88 | verbose=verbose and glob_iter > 0, 89 | regularization= glob_iter > 0 ) 90 | curr_timestamp = time.time() # log user-training end time 91 | train_time = (curr_timestamp - self.timestamp) / len(self.selected_users) 92 | self.metrics['user_train_time'].append(train_time) 93 | if self.personalized: 94 | self.evaluate_personalized_model() 95 | 96 | self.timestamp = time.time() # log server-agg start time 97 | self.train_generator( 98 | self.batch_size, 99 | epoches=self.ensemble_epochs // self.n_teacher_iters, 100 | latent_layer_idx=self.latent_layer_idx, 101 | verbose=True 102 | ) 103 | self.aggregate_parameters() 104 | curr_timestamp=time.time() # log server-agg end time 105 | agg_time = curr_timestamp - self.timestamp 106 | 
self.metrics['server_agg_time'].append(agg_time) 107 | if glob_iter > 0 and glob_iter % 20 == 0 and self.latent_layer_idx == 0: 108 | self.visualize_images(self.generative_model, glob_iter, repeats=10) 109 | 110 | self.save_results(args) 111 | self.save_model() 112 | 113 | def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False): 114 | """ 115 | Learn a generator that find a consensus latent representation z, given a label 'y'. 116 | :param batch_size: 117 | :param epoches: 118 | :param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer. 119 | :param verbose: print loss information. 120 | :return: Do not return anything. 121 | """ 122 | #self.generative_regularizer.train() 123 | self.label_weights, self.qualified_labels = self.get_label_weights() 124 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0 125 | 126 | def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS): 127 | self.generative_model.train() 128 | student_model.eval() 129 | for i in range(n_iters): 130 | self.generative_optimizer.zero_grad() 131 | y=np.random.choice(self.qualified_labels, batch_size) 132 | y_input=torch.LongTensor(y) 133 | ## feed to generator 134 | gen_result=self.generative_model(y_input, latent_layer_idx=latent_layer_idx, verbose=True) 135 | # get approximation of Z( latent) if latent set to True, X( raw image) otherwise 136 | gen_output, eps=gen_result['output'], gen_result['eps'] 137 | ##### get losses #### 138 | # decoded = self.generative_regularizer(gen_output) 139 | # regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps 140 | diversity_loss=self.generative_model.diversity_loss(eps, gen_output) # encourage different outputs 141 | 142 | ######### get teacher loss ############ 143 | teacher_loss=0 144 | teacher_logit=0 145 | for user_idx, user in enumerate(self.selected_users): 146 | user.model.eval() 147 | weight=self.label_weights[y][:, user_idx].reshape(-1, 1) 148 | expand_weight=np.tile(weight, (1, self.unique_labels)) 149 | user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True) 150 | user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1) 151 | teacher_loss_=torch.mean( \ 152 | self.generative_model.crossentropy_loss(user_output_logp_, y_input) * \ 153 | torch.tensor(weight, dtype=torch.float32)) 154 | teacher_loss+=teacher_loss_ 155 | teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32) 156 | 157 | ######### get student loss ############ 158 | student_output=student_model(gen_output, start_layer_idx=latent_layer_idx, logit=True) 159 | student_loss=F.kl_div(F.log_softmax(student_output['logit'], dim=1), F.softmax(teacher_logit, dim=1)) 160 | if self.ensemble_beta > 0: 161 | loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss 162 | else: 163 | loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss 164 | loss.backward() 165 | self.generative_optimizer.step() 166 | TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item() 167 | STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item() 168 | DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item() 169 | return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS 170 | 171 | for i in range(epoches): 
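# Each epoch runs n_teacher_iters generator steps: the selected users act as teachers
# (label-weighted cross-entropy on generated latent representations) and the current global
# model acts as the student, with the combined objective
#   ensemble_alpha * teacher_loss - ensemble_beta * student_loss + ensemble_eta * diversity_loss
# (the student term is dropped when ensemble_beta <= 0).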
172 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS=update_generator_( 173 | self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS) 174 | 175 | TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches) 176 | STUDENT_LOSS = STUDENT_LOSS.detach().numpy() / (self.n_teacher_iters * epoches) 177 | DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches) 178 | info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, ". \ 179 | format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS) 180 | if verbose: 181 | print(info) 182 | self.generative_lr_scheduler.step() 183 | 184 | 185 | def get_label_weights(self): 186 | label_weights = [] 187 | qualified_labels = [] 188 | for label in range(self.unique_labels): 189 | weights = [] 190 | for user in self.selected_users: 191 | weights.append(user.label_counts[label]) 192 | if np.max(weights) > MIN_SAMPLES_PER_LABEL: 193 | qualified_labels.append(label) 194 | # uniform 195 | label_weights.append( np.array(weights) / np.sum(weights) ) 196 | label_weights = np.array(label_weights).reshape((self.unique_labels, -1)) 197 | return label_weights, qualified_labels 198 | 199 | def visualize_images(self, generator, glob_iter, repeats=1): 200 | """ 201 | Generate and visualize data for a generator. 202 | """ 203 | os.system("mkdir -p images") 204 | path = f'images/{self.algorithm}-{self.dataset}-iter{glob_iter}.png' 205 | y=self.available_labels 206 | y = np.repeat(y, repeats=repeats, axis=0) 207 | y_input=torch.tensor(y) 208 | generator.eval() 209 | images=generator(y_input, latent=False)['output'] # 0,1,..,K, 0,1,...,K 210 | images=images.view(repeats, -1, *images.shape[1:]) 211 | images=images.view(-1, *images.shape[2:]) 212 | save_image(images.detach(), path, nrow=repeats, normalize=True) 213 | print("Image saved to {}".format(path)) -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from tqdm import trange 9 | import random 10 | import numpy as np 11 | from FLAlgorithms.trainmodel.models import Net 12 | from torch.utils.data import DataLoader 13 | from FLAlgorithms.trainmodel.generator import Generator 14 | from utils.model_config import * 15 | METRICS = ['glob_acc', 'per_acc', 'glob_loss', 'per_loss', 'user_train_time', 'server_agg_time'] 16 | 17 | 18 | def get_data_dir(dataset): 19 | if 'EMnist' in dataset: 20 | #EMnist-alpha0.1-ratio0.1-0-letters 21 | dataset_=dataset.replace('alpha', '').replace('ratio', '').split('-') 22 | alpha, ratio =dataset_[1], dataset_[2] 23 | types = 'letters' 24 | path_prefix = os.path.join('data', 'EMnist', f'u20-{types}-alpha{alpha}-ratio{ratio}') 25 | train_data_dir=os.path.join(path_prefix, 'train') 26 | test_data_dir=os.path.join(path_prefix, 'test') 27 | proxy_data_dir = 'data/proxy_data/emnist-n10/' 28 | 29 | elif 'Mnist' in dataset: 30 | dataset_=dataset.replace('alpha', '').replace('ratio', '').split('-') 31 | alpha, ratio=dataset_[1], dataset_[2] 32 | #path_prefix=os.path.join('data', 'Mnist', 'u20alpha{}min10ratio{}'.format(alpha, ratio)) 33 | path_prefix=os.path.join('data', 'Mnist', 'u20c10-alpha{}-ratio{}'.format(alpha, ratio)) 34 | train_data_dir=os.path.join(path_prefix, 'train') 35 | test_data_dir=os.path.join(path_prefix, 
'test') 36 | proxy_data_dir = 'data/proxy_data/mnist-n10/' 37 | 38 | elif 'celeb' in dataset.lower(): 39 | dataset_ = dataset.lower().replace('user', '').replace('agg','').split('-') 40 | user, agg_user = dataset_[1], dataset_[2] 41 | path_prefix = os.path.join('data', 'CelebA', 'user{}-agg{}'.format(user,agg_user)) 42 | train_data_dir=os.path.join(path_prefix, 'train') 43 | test_data_dir=os.path.join(path_prefix, 'test') 44 | proxy_data_dir=os.path.join('/user500/', 'proxy') 45 | 46 | else: 47 | raise ValueError("Dataset not recognized.") 48 | return train_data_dir, test_data_dir, proxy_data_dir 49 | 50 | 51 | def read_data(dataset): 52 | '''parses data in given train and test data directories 53 | 54 | assumes: 55 | - the data in the input directories are .json files with 56 | keys 'users' and 'user_data' 57 | - the set of train set users is the same as the set of test set users 58 | 59 | Return: 60 | clients: list of client ids 61 | groups: list of group ids; empty list if none found 62 | train_data: dictionary of train data 63 | test_data: dictionary of test data 64 | ''' 65 | train_data_dir, test_data_dir, proxy_data_dir = get_data_dir(dataset) 66 | clients = [] 67 | groups = [] 68 | train_data = {} 69 | test_data = {} 70 | proxy_data = {} 71 | 72 | train_files = os.listdir(train_data_dir) 73 | train_files = [f for f in train_files if f.endswith('.json') or f.endswith(".pt")] 74 | for f in train_files: 75 | file_path = os.path.join(train_data_dir, f) 76 | if file_path.endswith("json"): 77 | with open(file_path, 'r') as inf: 78 | cdata = json.load(inf) 79 | elif file_path.endswith(".pt"): 80 | with open(file_path, 'rb') as inf: 81 | cdata = torch.load(inf) 82 | else: 83 | raise TypeError("Data format not recognized: {}".format(file_path)) 84 | 85 | clients.extend(cdata['users']) 86 | if 'hierarchies' in cdata: 87 | groups.extend(cdata['hierarchies']) 88 | train_data.update(cdata['user_data']) 89 | 90 | clients = list(sorted(train_data.keys())) 91 | 92 | test_files = os.listdir(test_data_dir) 93 | test_files = [f for f in test_files if f.endswith('.json') or f.endswith(".pt")] 94 | for f in test_files: 95 | file_path = os.path.join(test_data_dir, f) 96 | if file_path.endswith(".pt"): 97 | with open(file_path, 'rb') as inf: 98 | cdata = torch.load(inf) 99 | elif file_path.endswith(".json"): 100 | with open(file_path, 'r') as inf: 101 | cdata = json.load(inf) 102 | else: 103 | raise TypeError("Data format not recognized: {}".format(file_path)) 104 | test_data.update(cdata['user_data']) 105 | 106 | 107 | if proxy_data_dir and os.path.exists(proxy_data_dir): 108 | proxy_files=os.listdir(proxy_data_dir) 109 | proxy_files=[f for f in proxy_files if f.endswith('.json') or f.endswith(".pt")] 110 | for f in proxy_files: 111 | file_path=os.path.join(proxy_data_dir, f) 112 | if file_path.endswith(".pt"): 113 | with open(file_path, 'rb') as inf: 114 | cdata=torch.load(inf) 115 | elif file_path.endswith(".json"): 116 | with open(file_path, 'r') as inf: 117 | cdata=json.load(inf) 118 | else: 119 | raise TypeError("Data format not recognized: {}".format(file_path)) 120 | proxy_data.update(cdata['user_data']) 121 | 122 | return clients, groups, train_data, test_data, proxy_data 123 | 124 | 125 | def read_proxy_data(proxy_data, dataset, batch_size): 126 | X, y=proxy_data['x'], proxy_data['y'] 127 | X, y = convert_data(X, y, dataset=dataset) 128 | dataset = [(x, y) for x, y in zip(X, y)] 129 | proxyloader = DataLoader(dataset, batch_size, shuffle=True) 130 | iter_proxyloader = iter(proxyloader) 131 | 
return proxyloader, iter_proxyloader 132 | 133 | 134 | def aggregate_data_(clients, dataset, dataset_name, batch_size): 135 | combined = [] 136 | unique_labels = [] 137 | for i in range(len(dataset)): 138 | id = clients[i] 139 | user_data = dataset[id] 140 | X, y = convert_data(user_data['x'], user_data['y'], dataset=dataset_name) 141 | combined += [(x, y) for x, y in zip(X, y)] 142 | unique_y=torch.unique(y) 143 | unique_y = unique_y.detach().numpy() 144 | unique_labels += list(unique_y) 145 | 146 | data_loader=DataLoader(combined, batch_size, shuffle=True) 147 | iter_loader=iter(data_loader) 148 | return data_loader, iter_loader, unique_labels 149 | 150 | 151 | def aggregate_user_test_data(data, dataset_name, batch_size): 152 | clients, loaded_data=data[0], data[3] 153 | data_loader, _, unique_labels=aggregate_data_(clients, loaded_data, dataset_name, batch_size) 154 | return data_loader, np.unique(unique_labels) 155 | 156 | 157 | def aggregate_user_data(data, dataset_name, batch_size): 158 | # data contains: clients, groups, train_data, test_data, proxy_data 159 | clients, loaded_data = data[0], data[2] 160 | data_loader, data_iter, unique_labels = aggregate_data_(clients, loaded_data, dataset_name, batch_size) 161 | return data_loader, data_iter, np.unique(unique_labels) 162 | 163 | 164 | def convert_data(X, y, dataset=''): 165 | if not isinstance(X, torch.Tensor): 166 | if 'celeb' in dataset.lower(): 167 | X=torch.Tensor(X).type(torch.float32).permute(0, 3, 1, 2) 168 | y=torch.Tensor(y).type(torch.int64) 169 | 170 | else: 171 | X=torch.Tensor(X).type(torch.float32) 172 | y=torch.Tensor(y).type(torch.int64) 173 | return X, y 174 | 175 | 176 | def read_user_data(index, data, dataset='', count_labels=False): 177 | #data contains: clients, groups, train_data, test_data, proxy_data(optional) 178 | id = data[0][index] 179 | train_data = data[2][id] 180 | test_data = data[3][id] 181 | X_train, y_train = convert_data(train_data['x'], train_data['y'], dataset=dataset) 182 | train_data = [(x, y) for x, y in zip(X_train, y_train)] 183 | X_test, y_test = convert_data(test_data['x'], test_data['y'], dataset=dataset) 184 | test_data = [(x, y) for x, y in zip(X_test, y_test)] 185 | if count_labels: 186 | label_info = {} 187 | unique_y, counts=torch.unique(y_train, return_counts=True) 188 | unique_y=unique_y.detach().numpy() 189 | counts=counts.detach().numpy() 190 | label_info['labels']=unique_y 191 | label_info['counts']=counts 192 | return id, train_data, test_data, label_info 193 | return id, train_data, test_data 194 | 195 | 196 | def get_dataset_name(dataset): 197 | dataset=dataset.lower() 198 | passed_dataset=dataset.lower() 199 | if 'celeb' in dataset: 200 | passed_dataset='celeb' 201 | elif 'emnist' in dataset: 202 | passed_dataset='emnist' 203 | elif 'mnist' in dataset: 204 | passed_dataset='mnist' 205 | else: 206 | raise ValueError('Unsupported dataset {}'.format(dataset)) 207 | return passed_dataset 208 | 209 | 210 | def create_generative_model(dataset, algorithm='', model='cnn', embedding=False): 211 | passed_dataset=get_dataset_name(dataset) 212 | assert any([alg in algorithm for alg in ['FedGen', 'FedGen']]) 213 | if 'FedGen' in algorithm: 214 | # temporary roundabout to figure out the sensitivity of the generator network & sampling size 215 | if 'cnn' in algorithm: 216 | gen_model = algorithm.split('-')[1] 217 | passed_dataset+='-' + gen_model 218 | elif '-gen' in algorithm: # we use more lightweight network for sensitivity analysis 219 | passed_dataset += '-cnn1' 220 | return 
Generator(passed_dataset, model=model, embedding=embedding, latent_layer_idx=-1) 221 | 222 | 223 | def create_model(model, dataset, algorithm): 224 | passed_dataset = get_dataset_name(dataset) 225 | model= Net(passed_dataset, model), model 226 | return model 227 | 228 | 229 | def polyak_move(params, target_params, ratio=0.1): 230 | for param, target_param in zip(params, target_params): 231 | param.data=param.data - ratio * (param.clone().detach().data - target_param.clone().detach().data) 232 | 233 | def meta_move(params, target_params, ratio): 234 | for param, target_param in zip(params, target_params): 235 | target_param.data = param.clone().data + ratio * (target_param.clone().data - param.clone().data) 236 | 237 | def moreau_loss(params, reg_params): 238 | # return 1/T \sum_i^T |param_i - reg_param_i|^2 239 | losses = [] 240 | for param, reg_param in zip(params, reg_params): 241 | losses.append( torch.mean(torch.square(param - reg_param.clone().detach())) ) 242 | loss = torch.mean(torch.stack(losses)) 243 | return loss 244 | 245 | def l2_loss(params): 246 | losses = [] 247 | for param in params: 248 | losses.append( torch.mean(torch.square(param))) 249 | loss = torch.mean(torch.stack(losses)) 250 | return loss 251 | 252 | def update_fast_params(fast_weights, grads, lr, allow_unused=False): 253 | """ 254 | Update fast_weights by applying grads. 255 | :param fast_weights: list of parameters. 256 | :param grads: list of gradients 257 | :param lr: 258 | :return: updated fast_weights . 259 | """ 260 | for grad, fast_weight in zip(grads, fast_weights): 261 | if allow_unused and grad is None: continue 262 | grad=torch.clamp(grad, -10, 10) 263 | fast_weight.data = fast_weight.data.clone() - lr * grad 264 | return fast_weights 265 | 266 | 267 | def init_named_params(model, keywords=['encode']): 268 | named_params={} 269 | #named_params_list = [] 270 | for name, params in model.named_layers.items(): 271 | if any([key in name for key in keywords]): 272 | named_params[name]=[param.clone().detach().requires_grad_(True) for param in params] 273 | #named_params_list += named_params[name] 274 | return named_params#, named_params_list 275 | 276 | 277 | 278 | def get_log_path(args, algorithm, seed, gen_batch_size=32): 279 | alg=args.dataset + "_" + algorithm 280 | alg+="_" + str(args.learning_rate) + "_" + str(args.num_users) 281 | alg+="u" + "_" + str(args.batch_size) + "b" + "_" + str(args.local_epochs) 282 | alg=alg + "_" + str(seed) 283 | if 'FedGen' in algorithm: # to accompany experiments for author rebuttal 284 | alg += "_embed" + str(args.embedding) 285 | if int(gen_batch_size) != int(args.batch_size): 286 | alg += "_gb" + str(gen_batch_size) 287 | return alg -------------------------------------------------------------------------------- /FLAlgorithms/users/userpFedCL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import random 5 | import heapq 6 | from random import choice 7 | from sklearn.model_selection import train_test_split 8 | from FLAlgorithms.users.userbase import User 9 | from FLAlgorithms.curriculum.cl_score import CL_User_Score 10 | from FLAlgorithms.trainmodel.generator import Discriminator 11 | import pdb 12 | import decimal 13 | def median(x): 14 | x = sorted(x) 15 | length = len(x) 16 | mid, rem = divmod(length, 2) 17 | if rem: 18 | return x[:mid], x[mid+1:], x[mid] 19 | else: 20 | return x[:mid], x[mid:], x[mid-1] 21 | 22 | 23 | 24 | class UserpFedCL(User): 25 | 
def __init__(self, 26 | args, id, model, generative_model, 27 | train_data, test_data, 28 | available_labels, latent_layer_idx, label_info, 29 | use_adam=False): 30 | super().__init__(args, id, model, train_data, test_data, use_adam=use_adam) 31 | self.gen_batch_size = args.gen_batch_size 32 | self.generative_model = generative_model 33 | self.Discriminator = Discriminator() 34 | self.latent_layer_idx = latent_layer_idx 35 | self.available_labels = available_labels 36 | self.label_info=label_info 37 | self.CL_User_Score=CL_User_Score 38 | 39 | self.optimizer_discriminator = torch.optim.Adam( 40 | params=self.model.parameters(), 41 | lr=1e-4, betas=(0.9, 0.999), 42 | eps=1e-08, weight_decay=0, amsgrad=False) 43 | self.lr_scheduler_discriminator = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer_discriminator, gamma=0.98) 44 | 45 | 46 | def exp_lr_scheduler(self, epoch, decay=0.98, init_lr=0.1, lr_decay_epoch=1): 47 | """Decay learning rate by a factor of 0.95 every lr_decay_epoch epochs.""" 48 | lr= max(1e-4, init_lr * (decay ** (epoch // lr_decay_epoch))) 49 | return lr 50 | 51 | def update_label_counts(self, labels, counts): 52 | for label, count in zip(labels, counts): 53 | self.label_counts[int(label)] += count 54 | 55 | def clean_up_counts(self): 56 | del self.label_counts 57 | self.label_counts = {label:1 for label in range(self.unique_labels)} 58 | 59 | def train(self, glob_iter, personalized=False, early_stop=100, regularization=True, verbose=False, run_curriculum=True,cl_score_norm_list=None,server_epoch=None,gmm_=None,gmm_len=None,next_stage=None): 60 | self.clean_up_counts() 61 | self.model.train() 62 | self.generative_model.eval() 63 | part_loss =0 64 | TEACHER_LOSS, DIST_LOSS, LATENT_LOSS, CL_SCORE = 0, 0, 0, 1 65 | exit_ = False 66 | for epoch in range(self.local_epochs): 67 | self.model.train() 68 | #print('epoch:'+str( epoch)) 69 | #for i in range(self.K): 70 | self.optimizer.zero_grad() 71 | #### sample from real dataset (un-weighted) 72 | samples =self.get_next_train_batch(count_labels=True) 73 | X, y = samples['X'], samples['y'] 74 | self.update_label_counts(samples['labels'], samples['counts']) 75 | model_result=self.model(X, logit=True) 76 | user_output_logp_ = model_result['output'] 77 | user_output_logp = model_result['output'] 78 | CL_results = self.CL_User_Score(model_result = model_result, 79 | Algorithms = 'SuperLoss_ce', 80 | loss_fun = self.loss, 81 | y = y, 82 | local_epoch = epoch, 83 | schedule = [epoch,self.local_epochs] 84 | ) 85 | predictive_loss=CL_results['Loss'] 86 | CL_Results_Score = CL_results['Curriculum_Learning_Score'] 87 | 88 | #print(CL_results['score_list']) 89 | #print(CL_results['celoss']) 90 | # ----------------- 91 | # Train Generator 92 | # ----------------- 93 | #### sample y and generate z 94 | if regularization and epoch < early_stop: 95 | 96 | generative_alpha=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_alpha) 97 | generative_beta=self.exp_lr_scheduler(glob_iter, decay=0.98, init_lr=self.generative_beta) 98 | ### get generator output(latent representation) of the same label 99 | gmm_results,_ = gmm_.sample(1) 100 | 101 | real_cl_score_ = torch.tensor(gmm_results, dtype=torch.float)[:y.size()[0]].view(y.size()[0],1) 102 | 103 | 104 | #real_cl_score_ =torch.tensor(CL_results_score_list, dtype=torch.float).view(y.size()[0],1) 105 | 106 | gen_output_=self.generative_model(y, real_cl_score_, latent_layer_idx=self.latent_layer_idx) 107 | gen_output = gen_output_['output'].clone().detach() 108 | 
logit_given_gen=self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True)['logit'] 109 | target_p=F.softmax(logit_given_gen, dim=1).clone().detach() 110 | user_latent_loss= generative_beta * self.ensemble_loss(user_output_logp, target_p) 111 | 112 | sampled_y=np.random.choice(self.available_labels, self.gen_batch_size) 113 | 114 | sampled_y=torch.tensor(sampled_y) 115 | #CL_results_score_list_norm = (CL_results_score_list-min(cl_score_norm_list)) / (max(cl_score_norm_list) - min(cl_score_norm_list)) 116 | #CL_results_score_list_norm = cl_score_norm_list 117 | 118 | l_Half, r_Half, q2 = median(gmm_.sample(gmm_len)[0].flatten()) 119 | lHalf = median(l_Half)[2] 120 | rHalf = median(r_Half)[2] 121 | #50,100,150,200 122 | if next_stage==0: 123 | cl_score_fake_=[random.uniform(0,lHalf) for i in range(self.gen_batch_size)] 124 | elif next_stage==1: 125 | cl_score_fake_=[random.uniform(lHalf,rHalf) for i in range(self.gen_batch_size)] 126 | elif next_stage>=2: 127 | cl_score_fake_=[random.uniform(rHalf,1) for i in range(self.gen_batch_size)] 128 | 129 | #$print(cl_score_fake_) 130 | # easy_number = int(self.gen_batch_size/4*3) 131 | # hard_number = int(self.gen_batch_size - easy_number) 132 | # cl_score_fake_easy=[random.uniform(0,max(CL_results['score_list'])) for i in range(easy_number)] 133 | # cl_score_fake_hard=[random.uniform(min(CL_results['score_list']),0) for i in range(hard_number)] 134 | # cl_score_fake_ = cl_score_fake_easy + cl_score_fake_hard 135 | random.shuffle(cl_score_fake_) 136 | #print(float(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]])) 137 | cl_score_fake_=torch.tensor(cl_score_fake_, dtype=torch.float).view(self.gen_batch_size,1) 138 | gen_result=self.generative_model(sampled_y,cl_score_fake_, latent_layer_idx=self.latent_layer_idx) 139 | 140 | 141 | gen_output=gen_result['output'] # latent representation when latent = True, x otherwise 142 | 143 | user_output_logp = self.model(gen_output, start_layer_idx=self.latent_layer_idx, logit=True) 144 | CL_score_fake_results = self.CL_User_Score(model_result = user_output_logp, 145 | Algorithms = 'SuperLoss_ce', 146 | loss_fun = self.loss, 147 | y = sampled_y, 148 | local_epoch = epoch, 149 | schedule = [epoch,self.local_epochs] 150 | )['score_list_base'] 151 | 152 | values_ = ((CL_score_fake_results-min(cl_score_norm_list)) / (max(cl_score_norm_list) - min(cl_score_norm_list))) 153 | 154 | 155 | #print(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]]) 156 | #print(CL_score_fake_results) 157 | # print(sorted( values_,reverse=False)) 158 | # print(sorted( values_,reverse=False)[int(len(values_)/1.1)]) 159 | # exit() 160 | ##print(sorted( values_,reverse=False)[int(len(values_)*0.6)]) 161 | #print(CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]]) 162 | # if next_stage: 163 | # if sorted( values_,reverse=False)[int(len(values_)*0.8)] < lHalf: 164 | # print(epoch) 165 | # break 166 | # else: 167 | # if sorted( values_,reverse=True)[int(len(values_)*0.8)] > lHalf: 168 | # print(epoch) 169 | # break 170 | 171 | if next_stage==0: 172 | if sorted(values_,reverse=False)[int(len(values_)*0.8)] < lHalf: 173 | print(epoch) 174 | exit_ = True 175 | break 176 | elif next_stage==1: 177 | if rHalf > sorted(values_,reverse=False)[int(len(values_)*0.8)] > lHalf : 178 | print(epoch) 179 | exit_ = True 180 | break 181 | else: 182 | if sorted( values_,reverse=False)[int(len(values_)*0.8)] > rHalf: 183 | print(epoch) 184 | exit_ = True 185 | break 186 
| user_output_logp = user_output_logp['output'] 187 | 188 | teacher_loss = generative_alpha * torch.mean( 189 | self.generative_model.crossentropy_loss(user_output_logp, sampled_y) 190 | ) 191 | # this is to further balance oversampled down-sampled synthetic data 192 | gen_ratio = self.gen_batch_size / self.batch_size 193 | loss=predictive_loss + gen_ratio * teacher_loss + user_latent_loss 194 | TEACHER_LOSS+=teacher_loss 195 | LATENT_LOSS+=user_latent_loss 196 | 197 | # if all( value < CL_results_score_list[(CL_results_score_list_norm == lHalf).nonzero()[0][0]] for value in sorted( values_,reverse=False)[:int(len(values_)/1.2)]): 198 | 199 | # break 200 | else: 201 | #### get loss and perform optimization 202 | 203 | loss=predictive_loss 204 | loss.backward(retain_graph=True) 205 | self.optimizer.step()#self.local_model) 206 | 207 | # --------------------- 208 | # Train Discriminator 209 | # --------------------- 210 | 211 | if regularization and epoch < early_stop: 212 | pass 213 | # self.optimizer_discriminator.zero_grad() 214 | # print(X.size()) 215 | # print(gen_output_['output'].size()) 216 | # D_H_real_loss = self.dist_loss(X, torch.ones_like(user_output_logp_)) 217 | # D_H_fake_loss = self.dist_loss(gen_output_, torch.zeros_like(gen_output)) 218 | 219 | 220 | 221 | # local-model <=== self.model 222 | self.clone_model_paramenter(self.model.parameters(), self.local_model) 223 | if personalized: 224 | self.clone_model_paramenter(self.model.parameters(), self.personalized_model_bar) 225 | self.lr_scheduler.step(glob_iter) 226 | 227 | if regularization and verbose: 228 | try: 229 | TEACHER_LOSS=TEACHER_LOSS.detach().numpy() / (self.local_epochs * self.K) 230 | except: 231 | TEACHER_LOSS=TEACHER_LOSS / (self.local_epochs * self.K) 232 | try: 233 | LATENT_LOSS=LATENT_LOSS.detach().numpy() / (self.local_epochs * self.K) 234 | except: 235 | LATENT_LOSS=LATENT_LOSS / (self.local_epochs * self.K) 236 | info='\nUser Teacher Loss={:.4f}'.format(TEACHER_LOSS) 237 | info+=', Latent Loss={:.4f}'.format(LATENT_LOSS) 238 | print(info) 239 | 240 | if CL_Results_Score!= None: 241 | return CL_results['score_list'],part_loss,exit_ 242 | else: 243 | return None 244 | 245 | def adjust_weights(self, samples): 246 | labels, counts = samples['labels'], samples['counts'] 247 | #weight=self.label_weights[y][:, user_idx].reshape(-1, 1) 248 | np_y = samples['y'].detach().numpy() 249 | n_labels = samples['y'].shape[0] 250 | weights = np.array([n_labels / count for count in counts]) # smaller count --> larger weight 251 | weights = len(self.available_labels) * weights / np.sum(weights) # normalized 252 | label_weights = np.ones(self.unique_labels) 253 | label_weights[labels] = weights 254 | sample_weights = label_weights[np_y] 255 | return sample_weights 256 | 257 | 258 | -------------------------------------------------------------------------------- /FLAlgorithms/curriculum/cl_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cvxopt import matrix, spdiag, solvers 3 | import numpy as np 4 | import random 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import math 8 | from scipy.special import lambertw 9 | from scipy.special import binom 10 | 11 | 12 | class TripletLoss(nn.Module): 13 | """Triplet loss with hard positive/negative mining. 14 | 15 | Reference: 16 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 17 | 18 | Imported from ``_. 
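For each anchor, the hardest positive (largest same-label distance) and hardest negative (smallest different-label distance) in the batch are mined and passed to nn.MarginRankingLoss.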
19 | 20 | Args: 21 | margin (float, optional): margin for triplet. Default is 0.3. 22 | """ 23 | 24 | def __init__(self, margin=0.5, batch_size=32, view_num=1, p=2): 25 | super(TripletLoss, self).__init__() 26 | self.margin = margin 27 | self.p = p 28 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 29 | self.targets = torch.cat([torch.arange(batch_size) for i in range(view_num)], dim=0) 30 | self.eps = 1e-7 31 | 32 | def forward(self, inputs,targets = None): 33 | """ 34 | Args: 35 | inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). 36 | targets (torch.LongTensor): ground truth labels with shape (num_classes). 37 | """ 38 | if targets == None: 39 | targets = self.targets 40 | n = inputs.size(0) 41 | 42 | # Compute pairwise distance, replace by the official when merged 43 | dist = [] 44 | for i in range(n): 45 | dist.append(inputs[i] - inputs) 46 | dist = torch.stack(dist) 47 | dist = torch.linalg.norm(dist,ord=self.p,dim=2) 48 | 49 | # For each anchor, find the hardest positive and negative 50 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 51 | dist_ap, dist_an = [], [] 52 | for i in range(n): 53 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 54 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 55 | dist_ap = torch.cat(dist_ap) 56 | dist_an = torch.cat(dist_an) 57 | 58 | # Compute ranking hinge loss 59 | y = torch.ones_like(dist_an) 60 | return self.ranking_loss(dist_an, dist_ap, y)+self.eps 61 | 62 | def compute_weights(lossgrad, lamb): 63 | 64 | device = lossgrad.get_device() 65 | lossgrad = lossgrad.data.cpu().numpy() 66 | 67 | # Compute Optimal sample Weights 68 | aux = -(lossgrad**2+lamb) 69 | sz = len(lossgrad) 70 | P = 2*matrix(lamb*np.identity(sz)) 71 | q = matrix(aux.astype(np.double)) 72 | A = spdiag(matrix(-1.0, (1,sz))) 73 | b = matrix(0.0, (sz,1)) 74 | Aeq = matrix(1.0, (1,sz)) 75 | beq = matrix(1.0*sz) 76 | solvers.options['show_progress'] = False 77 | solvers.options['maxiters'] = 20 78 | solvers.options['abstol'] = 1e-4 79 | solvers.options['reltol'] = 1e-4 80 | solvers.options['feastol'] = 1e-4 81 | sol = solvers.qp(P, q, A, b, Aeq, beq) 82 | w = np.array(sol['x']) 83 | 84 | return torch.squeeze(torch.tensor(w, dtype=torch.float)) 85 | 86 | class LOWLoss(torch.nn.Module): 87 | def __init__(self, lamb=0.1): 88 | super(LOWLoss, self).__init__() 89 | self.lamb = lamb # higher lamb means more smoothness -> weights closer to 1 90 | self.loss = torch.nn.CrossEntropyLoss(reduction='none') # replace this with any loss with "reduction='none'" 91 | def forward(self, logits, target): 92 | # Compute loss gradient norm 93 | output_d = logits.detach() 94 | loss_d = torch.mean(self.loss(output_d.requires_grad_(True), target), dim=0) 95 | loss_d.backward(torch.ones_like(loss_d)) 96 | lossgrad = torch.norm(output_d.grad, 2, 1) 97 | 98 | # Computed weighted loss 99 | weights = compute_weights(lossgrad, self.lamb) 100 | loss = self.loss(logits, target) 101 | loss = torch.mean(torch.mul(loss, weights), dim=0) 102 | 103 | return loss, weights 104 | 105 | #epoch,self.local_epochs 106 | def ft_Cam_1(output, target, alpha, schedule): 107 | 108 | if schedule[0] > int(schedule[1]/2): 109 | loss = F.cross_entropy(output, target, reduction='none') 110 | 'Sort the loss in descending order' 111 | loss_sorted, indices = torch.sort(loss, descending=True) 112 | 113 | top_k = round(alpha * target.size(0)) # Select top_K values for determining the hardness in mini-batch (alpha x batch_size) 114 | 115 | # Calculate the adaptive hardness threshold 
(thres as in Eq. 1 in the paper) 116 | a = 0.7 117 | b = 0.2 118 | #print(schedule) 119 | thres = a*(1-(schedule[0]/len(range(schedule[1])))) + b 120 | # print('thres', thres) 121 | # print('current_batch', batch_idx) 122 | # print('max_iteration', len(train_loader)) 123 | 124 | # Select the hardness in each mini-batch based on the threshold (thres) 125 | hard_samples = loss_sorted[0:top_k] 126 | total_sum_hard_samples = sum(hard_samples) 127 | 128 | # Check whether total sum exceeds the threshold and update the loss accordingly (Eq. 2 in the paper) 129 | if total_sum_hard_samples > (thres * sum(loss_sorted)): 130 | output = output[indices, :] 131 | target = target[indices] 132 | top_k_output = output[0:top_k] 133 | tok_k_target = target[0:top_k] 134 | loss = F.cross_entropy(top_k_output, tok_k_target, reduction='mean') 135 | else: 136 | loss = F.cross_entropy(output, target, reduction='mean') 137 | return loss 138 | 139 | else: 140 | loss = F.cross_entropy(output, target, reduction='none') 141 | 142 | 'Sort the loss in descending order' 143 | loss_sorted, indices = torch.sort(loss, descending=True) 144 | 145 | top_k = round(alpha * target.size(0)) # Select top_K values for determining the hardness in mini-batch (alpha x batch_size) 146 | 147 | # Calculate the adaptive hardness threshold (thres as in Eq. 1 in the paper) 148 | a = 0.7 149 | b = 0.2 150 | thres = a*(1-(schedule[0]/len(range(schedule[1])))) + b 151 | top_k2 = round(thres * top_k) # Select hardness level again within top_K values (i.e., top-K' as described in paper) 152 | 153 | # print('thres=', thres) 154 | # print('top_k=', top_k) 155 | # print('top_k2'=', top_k2) 156 | # print('current_batch=', batch_idx) 157 | # print('max_iteration=', len(train_loader)) 158 | 159 | # Select the hardness in each mini-batch based on top_k values 160 | hard_samples = loss_sorted[0:top_k] 161 | total_sum_hard_samples = sum(hard_samples) 162 | 163 | # Select the hardness within top_k values (i.e., top-k') 164 | hard_samples_k = hard_samples[0:top_k2] 165 | total_sum_hard_samples_k = sum(hard_samples_k) 166 | 167 | # Select top_k and k output and target values 168 | output = output[indices, :] 169 | target = target[indices] 170 | 171 | top_k_output = output[0:top_k] 172 | tok_k_target = target[0:top_k] 173 | 174 | k_output = top_k_output[0:top_k2] 175 | k_target = tok_k_target[0:top_k2] 176 | 177 | # Check whether total sum exceeds the threshold and update the loss accordingly (Eq. 
3 in the paper) 178 | if total_sum_hard_samples_k > (thres * total_sum_hard_samples): 179 | loss = F.cross_entropy(k_output, k_target, reduction='mean') 180 | #print('K_update done') 181 | else: 182 | loss = F.cross_entropy(top_k_output, tok_k_target, reduction='mean') 183 | #print('top_K_update done') 184 | return loss 185 | 186 | class LabelSmoothingCrossEntropyWithSuperKLDivLoss(nn.Module): 187 | def __init__(self, eps=0.01, reduction='mean', classes=10, rank=None, lam=0.5): 188 | super(LabelSmoothingCrossEntropyWithSuperKLDivLoss, self).__init__() 189 | self.eps = eps 190 | self.reduction = reduction 191 | self.super_loss = SuperLoss(C=classes, rank=rank, lam=lam) 192 | self.rank = rank 193 | 194 | def forward(self, output, target): 195 | B, c = output.size() 196 | log_preds = F.log_softmax(output, dim=-1) 197 | if self.reduction == 'sum': 198 | loss = -log_preds.sum() 199 | else: 200 | loss = -log_preds.sum(dim=-1) 201 | if self.reduction == 'mean': 202 | loss = loss.mean() 203 | super_loss = self.super_loss(FocalLoss(Superloss = True)(log_preds, target)) 204 | #super_loss = self.super_loss(F.nll_loss(log_preds, target, reduction='none')) 205 | # l_i = (-log_preds.sum(dim=-1)) * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction='none') 206 | # return self.super_loss(l_i) 207 | loss_cls = loss * self.eps / c + (1 - self.eps) * super_loss 208 | return loss_cls, torch.exp(super_loss) 209 | 210 | class LabelSmoothingCrossEntropyWithSuperFlLoss(nn.Module): 211 | def __init__(self, eps=0.01, reduction='mean', classes=10, rank=None, lam=0.5): 212 | super(LabelSmoothingCrossEntropyWithSuperFlLoss, self).__init__() 213 | self.eps = eps 214 | self.reduction = reduction 215 | self.super_loss = SuperLoss(C=classes, rank=rank, lam=lam) 216 | self.rank = rank 217 | 218 | def forward(self, output, target): 219 | B, c = output.size() 220 | log_preds = F.log_softmax(output, dim=-1) 221 | if self.reduction == 'sum': 222 | loss = -log_preds.sum() 223 | else: 224 | loss = -log_preds.sum(dim=-1) 225 | if self.reduction == 'mean': 226 | loss = loss.mean() 227 | super_loss = self.super_loss(FocalLoss(Superloss = True)(log_preds, target)) 228 | #super_loss = self.super_loss(F.nll_loss(log_preds, target, reduction='none')) 229 | # l_i = (-log_preds.sum(dim=-1)) * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction='none') 230 | # return self.super_loss(l_i) 231 | loss_cls = loss * self.eps / c + (1 - self.eps) * super_loss 232 | return loss_cls, torch.exp(super_loss) 233 | 234 | class LabelSmoothingCrossEntropyWithSuperCELoss(nn.Module): 235 | def __init__(self, eps=0.01, reduction='mean', classes=10, rank=None, lam=0.5): 236 | super(LabelSmoothingCrossEntropyWithSuperCELoss, self).__init__() 237 | self.eps = eps 238 | self.reduction = reduction 239 | self.super_loss = SuperLoss(C=classes, rank=rank,lam=lam) 240 | self.rank = rank 241 | 242 | def forward(self, output, target): 243 | B, c = output.size() 244 | log_preds = F.log_softmax(output, dim=-1) 245 | if self.reduction == 'sum': 246 | loss = -log_preds.sum() 247 | else: 248 | loss = -log_preds.sum(dim=-1) 249 | if self.reduction == 'mean': 250 | loss = loss.mean() 251 | celoss = F.nll_loss(log_preds, target, reduction='none') 252 | super_loss, score_list,tau = self.super_loss(celoss) 253 | 254 | 255 | # l_i = (-log_preds.sum(dim=-1)) * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction='none') 256 | # return self.super_loss(l_i) 257 | loss_cls = loss * self.eps / c + (1 - self.eps) * 
super_loss 258 | return loss_cls, torch.exp(super_loss) , score_list, celoss, tau 259 | 260 | class SuperLoss(nn.Module): 261 | def __init__(self, C=10, lam=0.5, rank=None): 262 | super(SuperLoss, self).__init__() 263 | self.tau = torch.log(torch.FloatTensor([C]).to(rank)) 264 | self.lam = lam # set to 1 for CIFAR10 and 0.25 for CIFAR100 265 | self.rank = rank 266 | 267 | def forward(self, l_i): 268 | l_i_detach = l_i.detach() 269 | 270 | # self.tau = 0.9 * self.tau + 0.1 * l_i_detach 271 | sigma,y,self.tau = self.sigma(l_i_detach) 272 | loss = (l_i - self.tau) * sigma + self.lam * torch.log(sigma)**2 273 | 274 | loss_mean = loss.mean() 275 | return loss_mean, y,self.tau 276 | 277 | def sigma(self, l_i): 278 | x = -2 / torch.exp(torch.ones_like(l_i)).to(self.rank) 279 | cl_score = l_i - self.tau 280 | y_ = 0.5 * torch.max(x, cl_score / self.lam) 281 | y = y_.cpu().numpy() 282 | sigma = np.exp(-lambertw(y)) 283 | sigma = sigma.real.astype(np.float32) 284 | sigma = torch.from_numpy(sigma).to(self.rank) 285 | 286 | return sigma,cl_score,self.tau 287 | 288 | class FocalLoss(nn.Module): 289 | def __init__(self, alpha=0.25, gamma=2, size_average=True, Superloss = False): 290 | super(FocalLoss, self).__init__() 291 | self.alpha = alpha 292 | self.gamma = gamma 293 | self.size_average = size_average 294 | self.Superloss = Superloss 295 | def forward(self, inputs, targets): 296 | ce_loss = F.cross_entropy(inputs, targets, reduction='none') 297 | pt = torch.exp(-ce_loss) 298 | focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss 299 | if self.Superloss: 300 | return focal_loss 301 | if self.size_average: 302 | return focal_loss.mean() 303 | else: 304 | return focal_loss.sum() 305 | 306 | def CL_User_Score(model_result, Algorithms, loss_fun, y, local_epoch, schedule): 307 | """ 308 | Get Curriculum Learning Score and batch 309 | :param model_result: 310 | :param y: label 311 | :param Algorithms: base 312 | :param loss_fun : 313 | 314 | :return: 315 | :Curriculum_Learning_Score: 316 | :Curriculum_Learning_Loss: 317 | :Base_Loss: 318 | """ 319 | 320 | if Algorithms == 'base': 321 | Base_Loss = loss_fun(model_result['output'], y) 322 | Curriculum_Learning_Score = None 323 | 324 | elif Algorithms == 'LOW': 325 | Base_Loss, weights = LOWLoss()(model_result['logit'], y) 326 | Curriculum_Learning_Score = weights 327 | 328 | elif Algorithms == 'ft_Cam': 329 | Base_Loss = ft_Cam_1(model_result['logit'], y, 0.1,schedule) 330 | Curriculum_Learning_Score = None 331 | 332 | elif Algorithms == 'SuperLoss_ce': 333 | Base_Loss, Curriculum_Learning_Score, score_list,celoss, tau = LabelSmoothingCrossEntropyWithSuperCELoss(classes=10, rank=None, lam=1)(model_result['logit'], y) 334 | 335 | elif Algorithms == 'SuperLoss_fl': 336 | Base_Loss, Curriculum_Learning_Score = LabelSmoothingCrossEntropyWithSuperFlLoss(classes=10, rank=None, lam=0.1)(model_result['logit'], y) 337 | 338 | elif Algorithms == 'FocalLoss': 339 | Base_Loss = FocalLoss()(model_result['logit'], y) 340 | Curriculum_Learning_Score = None 341 | 342 | else: 343 | Exception('Algorithms None') 344 | 345 | return {'Curriculum_Learning_Score':Curriculum_Learning_Score, 346 | 'Loss':Base_Loss, 347 | 'score_list':(score_list-torch.min(score_list)) / (torch.max(score_list) - torch.min(score_list)), 348 | 'celoss':celoss, 349 | 'TripletLoss_':0.01, #TripletLoss__, 350 | 'tau':tau, 351 | 'score_list_base':score_list 352 | } -------------------------------------------------------------------------------- /FLAlgorithms/servers/serverpFedCL.py: 
-------------------------------------------------------------------------------- 1 | from FLAlgorithms.users.userpFedGen import UserpFedGen 2 | from FLAlgorithms.servers.serverbase import Server 3 | from FLAlgorithms.curriculum.cl_score import CL_User_Score 4 | from utils.model_utils import read_data, read_user_data, aggregate_user_data, create_generative_model 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from torchvision.utils import save_image 10 | import os 11 | import copy 12 | import time 13 | import random 14 | from sklearn.mixture import GaussianMixture 15 | MIN_SAMPLES_PER_LABEL=1 16 | 17 | class FedCL(Server): 18 | def __init__(self, args, model, seed): 19 | super().__init__(args, model, seed) 20 | 21 | # Initialize data for all users 22 | data = read_data(args.dataset) 23 | # data contains: clients, groups, train_data, test_data, proxy_data 24 | clients = data[0] 25 | total_users = len(clients) 26 | self.total_test_samples = 0 27 | self.local = 'local' in self.algorithm.lower() 28 | self.use_adam = 'adam' in self.algorithm.lower() 29 | 30 | self.early_stop = 20 # stop using generated samples after 20 local epochs 31 | self.student_model = copy.deepcopy(self.model) 32 | self.generative_model = create_generative_model(args.dataset, args.algorithm, self.model_name, args.embedding) 33 | if not args.train: 34 | print('number of generator parameteres: [{}]'.format(self.generative_model.get_number_of_parameters())) 35 | print('number of model parameteres: [{}]'.format(self.model.get_number_of_parameters())) 36 | self.latent_layer_idx = self.generative_model.latent_layer_idx 37 | self.init_ensemble_configs() 38 | print("latent_layer_idx: {}".format(self.latent_layer_idx)) 39 | print("label embedding {}".format(self.generative_model.embedding)) 40 | print("ensemeble learning rate: {}".format(self.ensemble_lr)) 41 | print("ensemeble alpha = {}, beta = {}, eta = {}".format(self.ensemble_alpha, self.ensemble_beta, self.ensemble_eta)) 42 | print("generator alpha = {}, beta = {}".format(self.generative_alpha, self.generative_beta)) 43 | self.init_loss_fn() 44 | self.train_data_loader, self.train_iter, self.available_labels = aggregate_user_data(data, args.dataset, self.ensemble_batch_size) 45 | self.generative_optimizer = torch.optim.Adam( 46 | params=self.generative_model.parameters(), 47 | lr=self.ensemble_lr, betas=(0.9, 0.999), 48 | eps=1e-08, weight_decay=self.weight_decay, amsgrad=False) 49 | self.generative_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( 50 | optimizer=self.generative_optimizer, gamma=0.98) 51 | self.optimizer = torch.optim.Adam( 52 | params=self.model.parameters(), 53 | lr=self.ensemble_lr, betas=(0.9, 0.999), 54 | eps=1e-08, weight_decay=0, amsgrad=False) 55 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=0.98) 56 | self.CL_User_Score=CL_User_Score 57 | 58 | 59 | #### creating users #### 60 | self.users = [] 61 | for i in range(total_users): 62 | id, train_data, test_data, label_info =read_user_data(i, data, dataset=args.dataset, count_labels=True) 63 | self.total_train_samples+=len(train_data) 64 | self.total_test_samples += len(test_data) 65 | id, train, test=read_user_data(i, data, dataset=args.dataset) 66 | user=UserpFedGen( 67 | args, id, model, self.generative_model, 68 | train_data, test_data, 69 | self.available_labels, self.latent_layer_idx, label_info, 70 | use_adam=self.use_adam) 71 | self.users.append(user) 72 | print("Number of Train/Test 
samples:", self.total_train_samples, self.total_test_samples) 73 | print("Data from {} users in total.".format(total_users)) 74 | print("Finished creating FedAvg server.") 75 | # 76 | def train(self, args): 77 | #### pretraining 78 | best_auc = -np.inf 79 | cl_score_norm_list_1 = [] 80 | cl_score_norm_list_2 = [] 81 | 82 | CL_Results_Score_list_1 = [] 83 | CL_Results_Score_list_2 = [] 84 | gmm_cnn = None 85 | _ = None 86 | next_stage__ = 0 87 | for glob_iter in range(self.num_glob_iters): 88 | print("\n\n-------------Round number: ",glob_iter, " -------------\n\n") 89 | self.selected_users, self.user_idxs=self.select_users(glob_iter, self.num_users, return_idx=True) 90 | if not self.local: 91 | self.send_parameters(mode=self.mode)# broadcast averaged prediction model 92 | self.evaluate() 93 | chosen_verbose_user = np.random.randint(0, len(self.users)) 94 | self.timestamp = time.time() # log user-training start time 95 | 96 | loss_list = [] 97 | # Train model base on local 98 | if glob_iter > 0: 99 | cl_score_norm_list_1 = [] 100 | if len(cl_score_norm_list_2)>1: 101 | print(len(np.array(cl_score_norm_list_2).reshape(-1))) 102 | _ = np.array(cl_score_norm_list_2).reshape(len(self.selected_users), -1) 103 | gmm_cnn=GaussianMixture(n_components=len(self.selected_users), covariance_type="spherical", random_state=0) 104 | gmm_cnn.fit(_) 105 | 106 | if glob_iter == 0: 107 | Break_local_list = [False for i in range(len(self.user_idxs))] 108 | 109 | for i,(user_id, user) in enumerate(zip(self.user_idxs, self.selected_users)): 110 | verbose = user_id == chosen_verbose_user 111 | # perform regularization using generated samples after the first communication round 112 | 113 | CL_Results_Score, loss_, Break_local = user.train( 114 | glob_iter, 115 | personalized=self.personalized, 116 | early_stop=self.early_stop, 117 | verbose=verbose and glob_iter > 0, 118 | regularization= glob_iter > 0 , 119 | run_curriculum = True, 120 | cl_score_norm_list = cl_score_norm_list_2, 121 | server_epoch = glob_iter, 122 | gmm_ = gmm_cnn if gmm_cnn != None else None, 123 | gmm_len = 320, 124 | next_stage = next_stage__ 125 | 126 | ) 127 | CL_Results_Score_list_1.append([CL_Results_Score.clone().detach().numpy()]) 128 | 129 | cl_score_norm_list_1.extend(CL_Results_Score) 130 | 131 | loss_list.append(loss_) 132 | if Break_local and glob_iter > 2: 133 | Break_local_list[i] = True 134 | # print('*'*20) 135 | # print(Break_local_list.count(True)) 136 | # print('*'*20) 137 | if Break_local_list.count(True)>=len(self.user_idxs)*0.8: 138 | next_stage__ += 1 139 | Break_local_list = [False for i in range(len(self.user_idxs))] 140 | print('*'*20) 141 | print(glob_iter) 142 | print('*'*20) 143 | 144 | cl_score_norm_list_2 = cl_score_norm_list_1 145 | CL_Results_Score_list_2 = CL_Results_Score_list_1 146 | 147 | #import pdb; pdb.set_trace() 148 | curr_timestamp = time.time() # log user-training end time 149 | train_time = (curr_timestamp - self.timestamp) / len(self.selected_users) 150 | self.metrics['user_train_time'].append(train_time) 151 | if self.personalized: 152 | self.evaluate_personalized_model() 153 | 154 | self.timestamp = time.time() # log server-agg start time 155 | self.train_generator( 156 | self.batch_size, 157 | epoches=self.ensemble_epochs // self.n_teacher_iters, 158 | latent_layer_idx=self.latent_layer_idx, 159 | verbose=True, 160 | Real_CL_Results = CL_Results_Score_list_2 161 | ) 162 | self.aggregate_parameters() 163 | curr_timestamp=time.time() # log server-agg end time 164 | agg_time = curr_timestamp - 
self.timestamp 165 | self.metrics['server_agg_time'].append(agg_time) 166 | if glob_iter > 0 and glob_iter % 20 == 0 and self.latent_layer_idx == 0: 167 | self.visualize_images(self.generative_model, glob_iter, repeats=10) 168 | 169 | self.save_results(args) 170 | self.save_model() 171 | 172 | def train_generator(self, batch_size, epoches=1, latent_layer_idx=-1, verbose=False, Real_CL_Results=None,Real_CL_Results_sum=None): 173 | """ 174 | Learn a generator that find a consensus latent representation z, given a label 'y'. 175 | :param batch_size: 176 | :param epoches: 177 | :param latent_layer_idx: if set to -1 (-2), get latent representation of the last (or 2nd to last) layer. 178 | :param verbose: print loss information. 179 | :return: Do not return anything. 180 | """ 181 | #self.generative_regularizer.train() 182 | self.label_weights, self.qualified_labels = self.get_label_weights() 183 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, STUDENT_LOSS2 = 0, 0, 0, 0 184 | 185 | def update_generator_(n_iters, student_model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS,Real_CL_Results): 186 | self.generative_model.train() 187 | student_model.eval() 188 | for i in range(n_iters): 189 | self.generative_optimizer.zero_grad() 190 | y=np.random.choice(self.qualified_labels, batch_size) 191 | y_input=torch.LongTensor(y) 192 | 193 | ## feed to generator 194 | Real_CL_Results = np.squeeze(np.array(Real_CL_Results)) 195 | gmm=GaussianMixture(n_components=len(self.selected_users), covariance_type="spherical", random_state=0) 196 | gmm.fit(Real_CL_Results) 197 | gmm_,_ = gmm.sample(1) 198 | 199 | cl_sample = torch.tensor(gmm_, dtype=torch.float).view(batch_size,1) 200 | 201 | #torch.tensor([random.uniform(min(Real_CL_Results[random.randint(0,len(self.selected_users)-1)][0]), max(Real_CL_Results[random.randint(0,len(self.selected_users)-1)][0])) for i in range(batch_size)], dtype=torch.float).view(batch_size,1) 202 | 203 | gen_result=self.generative_model(y_input,cl_sample, latent_layer_idx=latent_layer_idx, verbose=True) 204 | 205 | # get approximation of Z( latent) if latent set to True, X( raw image) otherwise 206 | gen_output, eps=gen_result['output'], gen_result['eps'] 207 | 208 | ##### get losses ####x 209 | # decoded = self.generative_regularizeen_output) 210 | # regularization_loss = beta * self.generative_model.dist_loss(decoded, eps) # map generated z back to eps 211 | diversity_loss=self.generative_model.diversity_loss(eps, gen_output) 212 | 213 | ######### get teacher loss ############ 214 | teacher_loss=0 215 | # teacher_logit=0 216 | fake_cl_score = 0 217 | diversity_loss_list = 0 218 | 219 | for user_idx, user in enumerate(self.selected_users): 220 | gmm_,_ = gmm.sample(1) 221 | cl_sample = torch.tensor(gmm_, dtype=torch.float).view(batch_size,1) 222 | #torch.tensor([random.uniform(min(Real_CL_Results[user_idx][0]), max(Real_CL_Results[user_idx][0])) for i in range(batch_size)], dtype=torch.float).view(batch_size,1) 223 | gen_result=self.generative_model(y_input,cl_sample, latent_layer_idx=latent_layer_idx, verbose=True) 224 | 225 | # get approximation of Z( latent) if latent set to True, X( raw image) otherwise 226 | gen_output, eps=gen_result['output'], gen_result['eps'] 227 | user.model.eval() 228 | weight=self.label_weights[y][:, user_idx].reshape(-1, 1) 229 | expand_weight=np.tile(weight, (1, self.unique_labels)) 230 | user_result_given_gen=user.model(gen_output, start_layer_idx=latent_layer_idx, logit=True) 231 | 232 | CL_results = self.CL_User_Score(model_result = user_result_given_gen, 
233 | Algorithms = 'SuperLoss_ce', 234 | loss_fun = None, 235 | y = y_input, 236 | local_epoch = None, 237 | schedule = None, 238 | 239 | ) 240 | 241 | user_output_logp_=F.log_softmax(user_result_given_gen['logit'], dim=1).squeeze(-1) 242 | user_output_logp_ = {'logit':user_output_logp_} 243 | fake_cl_loss_ = self.CL_User_Score(model_result = user_output_logp_, 244 | Algorithms = 'SuperLoss_ce', 245 | loss_fun = None, 246 | y = y_input, 247 | local_epoch = None, 248 | schedule = None 249 | ) 250 | teacher_loss_=torch.mean( \ 251 | fake_cl_loss_['Loss']* \ 252 | torch.tensor(weight, dtype=torch.float32)) 253 | 254 | cl_loss = torch.mean(torch.nn.MSELoss(reduce=False, size_average=False)(torch.tensor(CL_results['score_list'].clone().detach(), dtype=torch.float),torch.tensor(fake_cl_loss_['score_list'].clone().detach(), dtype=torch.float))) 255 | 256 | teacher_loss += teacher_loss_ 257 | fake_cl_score += cl_loss 258 | #teacher_logit+=user_result_given_gen['logit'] * torch.tensor(expand_weight, dtype=torch.float32) 259 | 260 | teacher_loss = teacher_loss#/len(self.selected_users) 261 | 262 | student_loss=1 263 | if self.ensemble_beta > 0: 264 | loss=self.ensemble_alpha * teacher_loss - self.ensemble_beta * student_loss + self.ensemble_eta * diversity_loss# + fake_cl_score 265 | else: 266 | loss=self.ensemble_alpha * teacher_loss + self.ensemble_eta * diversity_loss# + fake_cl_score 267 | 268 | loss.backward() 269 | self.generative_optimizer.step() 270 | TEACHER_LOSS += self.ensemble_alpha * teacher_loss#(torch.mean(TEACHER_LOSS.double())).item() 271 | STUDENT_LOSS += self.ensemble_beta * student_loss#(torch.mean(student_loss.double())).item() 272 | DIVERSITY_LOSS += self.ensemble_eta * diversity_loss#(torch.mean(diversity_loss.double())).item() 273 | 274 | return TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS,fake_cl_score#,Triplet_Loss 275 | 276 | for i in range(epoches): 277 | TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS, fake_cl_score=update_generator_( 278 | self.n_teacher_iters, self.model, TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS,Real_CL_Results) 279 | 280 | TEACHER_LOSS = TEACHER_LOSS.detach().numpy() / (self.n_teacher_iters * epoches) 281 | STUDENT_LOSS = STUDENT_LOSS/ (self.n_teacher_iters * epoches) 282 | DIVERSITY_LOSS = DIVERSITY_LOSS.detach().numpy() / (self.n_teacher_iters * epoches) 283 | fake_cl_score = fake_cl_score.detach().numpy() / (self.n_teacher_iters * epoches) 284 | #Triplet_Loss_ = Triplet_Loss_.detach().numpy() / (self.n_teacher_iters * epoches) 285 | info="Generator: Teacher Loss= {:.4f}, Student Loss= {:.4f}, Diversity Loss = {:.4f}, CL Loss = {:.4f},". \ 286 | format(TEACHER_LOSS, STUDENT_LOSS, DIVERSITY_LOSS,fake_cl_score) 287 | if verbose: 288 | print(info) 289 | self.generative_lr_scheduler.step() 290 | 291 | 292 | def get_label_weights(self): 293 | label_weights = [] 294 | qualified_labels = [] 295 | for label in range(self.unique_labels): 296 | weights = [] 297 | for user in self.selected_users: 298 | weights.append(user.label_counts[label]) 299 | if np.max(weights) > MIN_SAMPLES_PER_LABEL: 300 | qualified_labels.append(label) 301 | # uniform 302 | label_weights.append( np.array(weights) / np.sum(weights) ) 303 | label_weights = np.array(label_weights).reshape((self.unique_labels, -1)) 304 | return label_weights, qualified_labels 305 | 306 | def visualize_images(self, generator, glob_iter, repeats=1): 307 | """ 308 | Generate and visualize data for a generator. 
309 | """ 310 | os.system("mkdir -p images") 311 | path = f'images/{self.algorithm}-{self.dataset}-iter{glob_iter}.png' 312 | y=self.available_labels 313 | y = np.repeat(y, repeats=repeats, axis=0) 314 | y_input=torch.tensor(y) 315 | generator.eval() 316 | images=generator(y_input, latent=False)['output'] # 0,1,..,K, 0,1,...,K 317 | images=images.view(repeats, -1, *images.shape[1:]) 318 | images=images.view(-1, *images.shape[2:]) 319 | save_image(images.detach(), path, nrow=repeats, normalize=True) 320 | print("Image saved to {}".format(path)) 321 | --------------------------------------------------------------------------------