├── assets
│   └── FedHKD.png
├── Client
│   ├── ClientBase.py
│   ├── ClientFedAvg.py
│   ├── ClientFedProx.py
│   ├── ClientFedMD.py
│   ├── ClientFedProto.py
│   └── ClientFedHKD.py
├── Server
│   ├── ServerBase.py
│   ├── ServerFedAvg.py
│   ├── ServerFedProx.py
│   ├── ServerFedProto.py
│   ├── ServerFedMD.py
│   └── ServerFedHKD.py
├── requirements.txt
├── utils.py
├── models.py
├── option.py
├── README.md
├── main.py
├── sampling.py
└── mem_utils.py

/assets/FedHKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityChan/Federated-Hyper-Knowledge-Distillation/HEAD/assets/FedHKD.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
calmsize==0.1.3
imageio==2.9.0
numpy==1.19.2
Pillow==9.2.0
scikit_learn==1.1.2
scipy==1.5.2
tensorboardX==2.5.1
torch==1.9.0+rocm4.2
torchvision==0.10.0+rocm4.2
tqdm==4.62.3
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
import torch
import copy


def Accuracy(y, y_predict):
    """Fraction of correct predictions in a batch."""
    leng = len(y)
    miss = 0
    for i in range(leng):
        if not y[i] == y_predict[i]:
            miss += 1
    return (leng - miss) / leng


def soft_predict(Z, temp):
    """Temperature-scaled softmax over the logits Z, i.e. softmax(Z / temp, dim=1)."""
    m, n = Z.shape
    Q = torch.zeros(m, n)
    Z_sum = torch.sum(torch.exp(Z / temp), dim=1)
    for i in range(n):
        Q[:, i] = torch.exp(Z[:, i] / temp) / Z_sum
    return Q


def average_weights(w):
    """Average the weights from all local models."""
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
--------------------------------------------------------------------------------
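For orientation, `soft_predict` above is just a temperature-scaled softmax; a minimal self-check, assuming it is run from the repo root (the tolerance is an arbitrary choice):

import torch
import torch.nn.functional as F
from utils import soft_predict

Z = torch.randn(8, 10)
assert torch.allclose(soft_predict(Z, temp=0.5), F.softmax(Z / 0.5, dim=1), atol=1e-6)

Note also that `average_weights` averages the state dicts uniformly, so every sampled client contributes equally regardless of its local dataset size.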
/Server/ServerBase.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy


class Server(object):
    def __init__(self, args, global_model, Loaders_train, Loaders_local_test, Loader_global_test, logger, device):
        self.global_model = global_model
        self.args = args
        self.Loaders_train = Loaders_train
        self.Loaders_local_test = Loaders_local_test
        self.global_testloader = Loader_global_test
        self.logger = logger
        self.device = device
        self.LocalModels = []

    def global_test_accuracy(self):
        self.global_model.eval()
        accuracy = 0
        cnt = 0
        with torch.no_grad():  # inference only: no autograd graph needed
            for batch_idx, (X, y) in enumerate(self.global_testloader):
                X = X.to(self.device)
                y = y.to(self.device)
                _, p = self.global_model(X)
                y_pred = p.argmax(1)
                accuracy += Accuracy(y, y_pred)
                cnt += 1
        return accuracy / cnt

    def Save_CheckPoint(self, save_path):
        torch.save(self.global_model.state_dict(), save_path)
--------------------------------------------------------------------------------

/Client/ClientBase.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import scipy
from torch.utils.data import Dataset
import copy
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import Accuracy, soft_predict


class Client(object):
    """
    Trains the local model, starting from a copy of the global model, and
    returns the updated weights.
    args: arguments
    Loader_train, loader_test: inputs for training and inference
    idx: the index of this local model
    logger: logs the loss and the training process
    """
    def __init__(self, args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device):
        self.args = args
        self.logger = logger
        self.trainloader = Loader_train
        self.testloader = loader_test
        self.idx = idx
        self.ce = nn.CrossEntropyLoss()
        self.device = device
        self.code_length = code_length
        self.kld = nn.KLDivLoss()  # expects log-probabilities as its first argument
        self.mse = nn.MSELoss()
        self.model = copy.deepcopy(model)

    def test_accuracy(self):
        self.model.eval()
        accuracy = 0
        cnt = 0
        with torch.no_grad():
            for batch_idx, (X, y) in enumerate(self.testloader):
                X = X.to(self.device)
                y = y.to(self.device)
                _, p = self.model(X)
                y_pred = p.argmax(1)
                accuracy += Accuracy(y, y_pred)
                cnt += 1
        return accuracy / cnt

    def load_model(self, global_weights):
        self.model.load_state_dict(global_weights)
--------------------------------------------------------------------------------

/Client/ClientFedAvg.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import scipy
from torch.utils.data import Dataset
import copy
import torch.nn as nn
from sklearn.cluster import KMeans
import torch.optim as optim
import torch.nn.functional as F
from utils import Accuracy, soft_predict
from Client.ClientBase import Client


class ClientFedAvg(Client):
    """FedAvg client: standard local cross-entropy training."""
    def __init__(self, args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device):
        super().__init__(args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device)

    def update_weights(self, global_round):
        self.model.to(self.device)
        self.model.train()
        epoch_loss = []
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device)
                optimizer.zero_grad()
                _, p = self.model(X)
                loss = self.ce(p, y)
                loss.backward()
                if self.args.clip_grad is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad)
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, self.idx, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
--------------------------------------------------------------------------------
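A minimal smoke test for the client API, assuming it is run from the repo root; the Namespace fields mirror the option.py defaults, and the random tensors merely stand in for a real dataloader:

from argparse import Namespace
import torch
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from models import ResNet18
from Client.ClientFedAvg import ClientFedAvg

args = Namespace(lr=1e-3, lr_sh_rate=10, local_ep=1, clip_grad=None)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18(args, code_length=32, num_classes=10).to(device)

# 64 random CIFAR-sized images are enough to exercise one local epoch
X, y = torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(X, y), batch_size=16)

client = ClientFedAvg(args, model, loader, loader, idx=0, logger=SummaryWriter('./logs'),
                      code_length=32, num_classes=10, device=device)
w, loss = client.update_weights(global_round=0)  # returns (state_dict, mean loss)
print('local loss:', loss)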
/Server/ServerFedAvg.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy
from Server.ServerBase import Server
from Client.ClientFedAvg import ClientFedAvg
from tqdm import tqdm
import numpy as np
from utils import average_weights
from mem_utils import MemReporter
import time


class ServerFedAvg(Server):
    def __init__(self, args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device):
        super().__init__(args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device)

    def Create_Clints(self):
        for idx in range(self.args.num_clients):
            self.LocalModels.append(ClientFedAvg(self.args, copy.deepcopy(self.global_model), self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length=self.args.code_len, num_classes=self.args.num_classes, device=self.device))

    def train(self):
        reporter = MemReporter()
        start_time = time.time()
        train_loss = []
        global_weights = self.global_model.state_dict()
        for epoch in tqdm(range(self.args.num_epochs)):
            test_accuracy = 0
            local_weights, local_losses = [], []
            print(f'\n | Global Training Round : {epoch+1} |\n')
            m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
            idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
            for idx in idxs_users:
                if self.args.upload_model:
                    self.LocalModels[idx].load_model(global_weights)
                w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
                local_losses.append(copy.deepcopy(loss))
                local_weights.append(copy.deepcopy(w))
                acc = self.LocalModels[idx].test_accuracy()
                test_accuracy += acc

            # update global weights
            global_weights = average_weights(local_weights)
            self.global_model.load_state_dict(global_weights)
            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)
            print("average loss: ", loss_avg)
            # average over the clients actually sampled this round
            print('average local test accuracy:', test_accuracy / len(idxs_users))
            print('global test accuracy: ', self.global_test_accuracy())

        print('Training is completed.')
        end_time = time.time()
        print('running time: {} s '.format(end_time - start_time))
        reporter.report()
--------------------------------------------------------------------------------
/Server/ServerFedProx.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy
from Server.ServerBase import Server
from Client.ClientFedProx import ClientFedProx
from tqdm import tqdm
import numpy as np
from utils import average_weights
from mem_utils import MemReporter
import time


class ServerFedProx(Server):
    def __init__(self, args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device):
        super().__init__(args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device)

    def Create_Clints(self):
        for idx in range(self.args.num_clients):
            self.LocalModels.append(ClientFedProx(self.args, copy.deepcopy(self.global_model), self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length=self.args.code_len, num_classes=self.args.num_classes, device=self.device))

    def train(self):
        reporter = MemReporter()
        start_time = time.time()
        train_loss = []
        global_weights = self.global_model.state_dict()
        for epoch in tqdm(range(self.args.num_epochs)):
            test_accuracy = 0
            local_weights, local_losses = [], []
            print(f'\n | Global Training Round : {epoch+1} |\n')
            m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
            idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
            for idx in idxs_users:
                if self.args.upload_model:
                    self.LocalModels[idx].load_model(global_weights)
                w, loss = self.LocalModels[idx].update_weights_Prox(global_round=epoch, lam=0.1)
                local_losses.append(copy.deepcopy(loss))
                local_weights.append(copy.deepcopy(w))
                acc = self.LocalModels[idx].test_accuracy()
                test_accuracy += acc

            # update global weights
            global_weights = average_weights(local_weights)
            self.global_model.load_state_dict(global_weights)
            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)
            print("average loss: ", loss_avg)
            print('average local test accuracy:', test_accuracy / len(idxs_users))
            print('global test accuracy: ', self.global_test_accuracy())

        print('Training is completed.')
        end_time = time.time()
        print('running time: {} s '.format(end_time - start_time))
        reporter.report()
--------------------------------------------------------------------------------

/models.py:
--------------------------------------------------------------------------------
from PIL import Image
from os.path import join
import imageio
from torch import nn
from torch.nn.modules.linear import Linear
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models


class EncoderFemnist(nn.Module):
    def __init__(self, code_length):
        super(EncoderFemnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, code_length)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        z = F.relu(self.fc1(x))
        return z


class CNNFemnist(nn.Module):
    def __init__(self, args, code_length=50, num_classes=62):
        super(CNNFemnist, self).__init__()
        self.code_length = code_length
        self.num_classes = num_classes
        self.feature_extractor = EncoderFemnist(self.code_length)
        self.classifier = nn.Sequential(nn.Dropout(0.2),
                                        nn.Linear(self.code_length, self.num_classes),
                                        nn.LogSoftmax(dim=1))

    def forward(self, x):
        z = self.feature_extractor(x)
        p = self.classifier(z)
        return z, p


class ResNet18(nn.Module):
    def __init__(self, args, code_length=64, num_classes=10):
        super(ResNet18, self).__init__()
        self.code_length = code_length
        self.num_classes = num_classes
        self.feature_extractor = models.resnet18(num_classes=self.code_length)
        self.classifier = nn.Sequential(
            nn.Linear(self.code_length, self.num_classes))

    def forward(self, x):
        z = self.feature_extractor(x)
        p = self.classifier(z)
        return z, p


class ShuffLeNet(nn.Module):
    def __init__(self, args, code_length=64, num_classes=10):
        super(ShuffLeNet, self).__init__()
        self.code_length = code_length
        self.num_classes = num_classes
        self.feature_extractor = models.shufflenet_v2_x1_0(num_classes=self.code_length)
        self.classifier = nn.Sequential(
            nn.Linear(self.code_length, self.num_classes))

    def forward(self, x):
        z = self.feature_extractor(x)
        p = self.classifier(z)
        return z, p
--------------------------------------------------------------------------------
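All three models follow the same convention: the forward pass returns the latent code z and the class scores p. A quick shape check (args is unused by the constructors, so None suffices here):

import torch
from models import ResNet18

model = ResNet18(None, code_length=64, num_classes=10)
z, p = model(torch.randn(4, 3, 32, 32))
print(z.shape, p.shape)  # torch.Size([4, 64]) torch.Size([4, 10])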
/Client/ClientFedProx.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import scipy
from torch.utils.data import Dataset
import copy
import torch.nn as nn
from sklearn.cluster import KMeans
import torch.optim as optim
import torch.nn.functional as F
from utils import Accuracy, soft_predict
from Client.ClientBase import Client


class ClientFedProx(Client):
    """FedProx client: local training with a proximal term that keeps the
    local weights close to the global model."""
    def __init__(self, args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device):
        super().__init__(args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device)

    def update_weights_Prox(self, global_round, lam):
        self.model.to(self.device)  # use the configured device instead of assuming CUDA
        self.model.train()
        global_model = copy.deepcopy(self.model)
        global_model.eval()
        global_weight_collector = list(global_model.parameters())
        epoch_loss = []
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device).long()
                optimizer.zero_grad()
                _, p = self.model(X)
                loss1 = self.ce(p, y)
                # FedProx proximal term: (lam / 2) * ||w - w_global||^2
                fed_prox_reg = 0.0
                for param_index, param in enumerate(self.model.parameters()):
                    fed_prox_reg += ((lam / 2) * torch.norm((param - global_weight_collector[param_index])) ** 2)
                # the (lam / 2) factor is already inside fed_prox_reg,
                # so it must not be scaled by lam a second time
                loss = loss1 + fed_prox_reg
                loss.backward()
                if self.args.clip_grad is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad)
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t prox_loss: {:.6f}'.format(
                        global_round, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item(), fed_prox_reg.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
--------------------------------------------------------------------------------
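For reference, the penalty computed above is the FedProx term (lam/2) * ||w - w_global||^2, summed over parameter tensors; a standalone sketch:

import torch

def prox_term(model, global_params, mu):
    # (mu / 2) * sum of squared distances to the global parameters
    reg = 0.0
    for p, g in zip(model.parameters(), global_params):
        reg = reg + (mu / 2.0) * torch.norm(p - g) ** 2
    return reg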
/option.py:
--------------------------------------------------------------------------------
import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # data-specific parameters
    parser.add_argument('--dataset', default='CIFAR10',
                        help='CIFAR10, CIFAR100, SVHN, EMNIST')
    # training-specific parameters
    parser.add_argument('--log_frq', type=int, default=5,
                        help='frequency of logging')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='minibatch size')
    parser.add_argument('--num_epochs', type=int, default=50,
                        help='number of global rounds')
    parser.add_argument('--clip_grad', type=float, default=None,
                        help='gradient clipping')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate')
    parser.add_argument('--lr_sh_rate', type=int, default=10,
                        help='number of steps to drop the lr')
    parser.add_argument('--use_lrschd', action="store_true", default=False,
                        help='use a lr scheduler')
    parser.add_argument('--num_clients', type=int, default=10,
                        help='number of local models')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='number of classes')
    parser.add_argument('--sampling_rate', type=float, default=1,
                        help='fraction of local models to update per round')
    parser.add_argument('--local_ep', type=int, default=5,
                        help='iterations of local updating')
    parser.add_argument('--beta', type=float, default=0.5,
                        help='beta for the non-iid distribution')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for generating datasets')
    parser.add_argument('--code_len', type=int, default=32,
                        help='length of the latent code')
    parser.add_argument('--alg', default='FedAvg',
                        help='FedAvg, FedProx, FedMD, FedProto, FedHKD')
    parser.add_argument('--lam', type=float, default=0.05,
                        help='hyper-parameter for loss2')
    parser.add_argument('--gamma', type=float, default=0.05,
                        help='hyper-parameter for loss3')
    parser.add_argument('--std', type=float, default=2,
                        help='std of the gaussian noise')
    parser.add_argument('--part', type=float, default=0.1,
                        help='percentage of each local dataset that is used')
    parser.add_argument('--temp', type=float, default=0.5,
                        help='temperature for soft predictions')
    parser.add_argument('--model', default='resnet18',
                        help='CNN, resnet18, shufflenet')
    parser.add_argument('--upload_model', type=bool, default=True,
                        help='whether to upload model parameters')
    parser.add_argument('--save_model', action="store_true", default=False,
                        help='save model parameters')
    parser.add_argument('--eval_only', action="store_true", default=False,
                        help='evaluate the model')

    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
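The parser is consumed by main.py via args_parser(); with no CLI flags it simply yields the defaults listed above:

from option import args_parser

args = args_parser()             # reads sys.argv
print(args.dataset, args.alg)    # CIFAR10 FedAvg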
/README.md:
--------------------------------------------------------------------------------
# Federated Hyper Knowledge Distillation
This is an official repository for our ICLR 2023 paper
* "[The Best of Both Worlds: Accurate Global and Personalized Models through Federated Learning with Data-Free Hyper-Knowledge Distillation](https://arxiv.org/abs/2301.08968)"

<div align="center">
<img src="assets/FedHKD.png" alt="my alt text">

A flow diagram showing computation, encryption and aggregation of hyper-knowledge.
</div>
### Environment
This project was developed with Python 3.6 and [torch 1.9 (rocm4.2)](https://pytorch.org/get-started/previous-versions/). We use [conda](https://www.anaconda.com/docs/main) to manage the virtual environment.
```
git clone git@github.com:CityChan/Federated-Hyper-Knowledge-Distillation.git
cd Federated-Hyper-Knowledge-Distillation
conda create -n fedhkd python=3.6
conda activate fedhkd
pip install torch==1.9.1+rocm4.2 torchvision==0.10.1+rocm4.2 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
```

### Code structure
* `main.py`: general setup for training and evaluating FL schemes
* `models.py`: model architectures for running experiments
* `sampling.py`: functions for generating non-iid datasets for federated learning
* `utils.py`: functions for computing accuracy, knowledge distillation and model aggregation
* `mem_utils.py`: library for monitoring memory usage and training time
* `option.py`: defines hyper-parameters
* `Server/*.py`: object definitions for the server in the different schemes
* `Client/*.py`: object definitions for the clients in the different schemes

### Parameters
* --dataset: 'CIFAR10', 'CIFAR100', 'SVHN'
* --batch_size: batch size, 64 by default
* --num_epochs: number of global rounds, 50 by default
* --lr: learning rate, 0.001 by default
* --lr_sh_rate: period of learning rate decay, 10 by default
* --clip_grad: maximum norm for gradients, None by default (disabled)
* --num_clients: number of clients, 10 by default
* --sampling_rate: proportion of clients that send updates per round, 1.0 by default
* --local_ep: number of local epochs, 5 by default
* --beta: concentration parameter of the Dirichlet distribution, 0.5 by default
* --seed: random seed (for reproducing experiments), 0 by default
* --std: standard deviation of the differential-privacy noise, 2.0 by default
* --code_len: dimension of the latent vector, 32 by default
* --alg: 'FedAvg', 'FedProx', 'FedMD', 'FedProto', 'FedHKD'
* --eval_only: only output the testing accuracy
* --part: percentage of each local dataset that is used, 0.1 by default
* --temp: temperature for soft predictions, 0.5 by default
* --lam: weight for loss2, 0.05 by default
* --gamma: weight for loss3, 0.05 by default
* --model: 'CNN', 'resnet18', 'shufflenet'
* --save_model: save checkpoints of the model

### Running the code for training and evaluation
```
CUDA_VISIBLE_DEVICES=0 python main.py --dataset 'SVHN' --batch_size 64 --num_epochs 50 --clip_grad 1.1 --lr 0.001 --num_clients 10 --num_classes 10 --sampling_rate 1 --local_ep 3 --beta 0.5 --seed 0 --code_len 50 --alg 'FedAvg' --part 0.1 --model 'resnet18' --temp 0.5
```
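To train with FedHKD itself, the same entry point applies. A plausible invocation (the `--lam`, `--gamma` and `--std` values below are simply the `option.py` defaults, not tuned settings):
```
CUDA_VISIBLE_DEVICES=0 python main.py --dataset 'CIFAR10' --batch_size 64 --num_epochs 50 --lr 0.001 --num_clients 10 --num_classes 10 --sampling_rate 1 --local_ep 3 --beta 0.5 --seed 0 --code_len 50 --alg 'FedHKD' --part 0.1 --model 'resnet18' --temp 0.5 --lam 0.05 --gamma 0.05 --std 2
```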
### Acknowledgement
This work was completed during an internship at the Toyota AI/ML Infrastructure & Data Lab.

### Citation
Please cite our paper if you find it useful:
```
@article{chen2023best,
  title={The Best of Both Worlds: Accurate Global and Personalized Models through Federated Learning with Data-Free Hyper-Knowledge Distillation},
  author={Chen, Huancheng and Vikalo, Haris and others},
  journal={arXiv preprint arXiv:2301.08968},
  year={2023}
}
```
--------------------------------------------------------------------------------

/Server/ServerFedProto.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy
from Server.ServerBase import Server
from Client.ClientFedProto import ClientFedProto
from tqdm import tqdm
import numpy as np
from utils import average_weights
from mem_utils import MemReporter
import time
from sampling import LocalDataset, LocalDataloaders, partition_data
import gc


class ServerFedProto(Server):
    def __init__(self, args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device):
        super().__init__(args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device)

    def Create_Clints(self):
        for idx in range(self.args.num_clients):
            self.LocalModels.append(ClientFedProto(self.args, copy.deepcopy(self.global_model), self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length=self.args.code_len, num_classes=self.args.num_classes, device=self.device))

    def global_knowledge_aggregation(self, features):
        global_local_features = dict()
        # iterate under a new name so the input dict is not rebound
        for label, feats in features.items():
            if len(feats) > 1:
                feature = 0 * feats[0].data
                for i in feats:
                    feature += i.data
                global_local_features[label] = [feature / len(feats)]
            else:
                global_local_features[label] = [feats[0].data]

        return global_local_features

    def train(self):
        global_features = {}
        reporter = MemReporter()
        start_time = time.time()
        train_loss = []
        global_weights = self.global_model.state_dict()
        for epoch in tqdm(range(self.args.num_epochs)):
            test_accuracy = 0
            local_weights, local_losses = [], []
            print(f'\n | Global Training Round : {epoch+1} |\n')
            m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
            idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
            for idx in idxs_users:
                if self.args.upload_model:
                    self.LocalModels[idx].load_model(global_weights)
                # plain local training in the first round, prototype-regularized
                # training once global features exist
                if epoch < 1:
                    w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
                else:
                    w, loss = self.LocalModels[idx].update_weights_Proto(global_round=epoch, global_features=global_features, gamma=self.args.gamma)
                local_losses.append(copy.deepcopy(loss))
                local_weights.append(copy.deepcopy(w))
                acc = self.LocalModels[idx].test_accuracy()
                test_accuracy += acc

                local_features = self.LocalModels[idx].generate_knowledge()
                global_features.update(local_features)
                del local_features
                gc.collect()

            # update global weights
            global_weights = average_weights(local_weights)
            self.global_model.load_state_dict(global_weights)
            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)
            print("average loss: ", loss_avg)
            print('average local test accuracy:', test_accuracy / len(idxs_users))
            print('global test accuracy: ', self.global_test_accuracy())

        print('Training is completed.')
        end_time = time.time()
        print('running time: {} s '.format(end_time - start_time))
        reporter.report()
--------------------------------------------------------------------------------
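A toy illustration of the rule in global_knowledge_aggregation, assuming per-class feature lists as the clients produce them (each label maps to a list of tensors):

import torch

features = {0: [torch.ones(4), 3 * torch.ones(4)], 1: [torch.zeros(4)]}
# mirrors the server code: multi-entry lists are averaged element-wise
agg = {label: [sum(f.data for f in fs) / len(fs)] for label, fs in features.items()}
print(agg[0][0])  # tensor([2., 2., 2., 2.])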
/main.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os, sys, os.path
from tensorboardX import SummaryWriter
import pickle
from torch import nn
import hashlib
import argparse

from models import CNNFemnist, ResNet18, ShuffLeNet
from sampling import LocalDataset, LocalDataloaders, partition_data
from option import args_parser

# import the different schemes
from Server.ServerFedAvg import ServerFedAvg
from Server.ServerFedProx import ServerFedProx
from Server.ServerFedMD import ServerFedMD
from Server.ServerFedProto import ServerFedProto
from Server.ServerFedHKD import ServerFedHKD

print(torch.__version__)
np.set_printoptions(threshold=np.inf)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device.type)

args = args_parser()
print(args)

# obtain a hash value for saving checkpoints
args_hash = ''
for k, v in vars(args).items():
    if k == 'eval_only':
        continue
    args_hash += str(k) + str(v)

args_hash = hashlib.sha256(args_hash.encode()).hexdigest()

# generate data partitions for FL
train_dataset, testset, dict_users, dict_users_test = partition_data(
    n_users=args.num_clients, alpha=args.beta, rand_seed=args.seed, dataset=str(args.dataset))

# load local training datasets and test sets for each client
Loaders_train = LocalDataloaders(train_dataset, dict_users, args.batch_size, ShuffleorNot=True, frac=args.part)
Loaders_test = LocalDataloaders(testset, dict_users_test, args.batch_size, ShuffleorNot=True, frac=2 * args.part)
global_loader_test = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=2)

for idx in range(args.num_clients):
    counts = [0] * args.num_classes
    for batch_idx, (X, y) in enumerate(Loaders_train[idx]):
        batch = len(y)
        y = np.array(y)
        for i in range(batch):
            counts[int(y[i])] += 1
    # print out the data distribution of each client
    print('Client {} data distribution:'.format(idx))
    print(counts)

logger = SummaryWriter('./logs')
checkpoint_dir = './checkpoint/' + args.dataset + '/'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
with open(checkpoint_dir + 'args.pkl', 'wb') as fp:
    pickle.dump(args, fp)
print('Checkpoint dir:', checkpoint_dir)

print(args.model)
if args.model == 'CNN':
    # for EMNIST, 62 classes
    global_model = CNNFemnist(args, code_length=args.code_len, num_classes=args.num_classes)
elif args.model == 'resnet18':
    global_model = ResNet18(args, code_length=args.code_len, num_classes=args.num_classes)
elif args.model == 'shufflenet':
    global_model = ShuffLeNet(args, code_length=args.code_len, num_classes=args.num_classes)
else:
    raise ValueError('unknown model: {}'.format(args.model))

print('# model parameters:', sum(param.numel() for param in global_model.parameters()))
# global_model = nn.DataParallel(global_model)
global_model.to(device)

if args.alg == 'FedAvg':
    server = ServerFedAvg(args, global_model, Loaders_train, Loaders_test, global_loader_test, logger, device)
elif args.alg == 'FedProx':
    server = ServerFedProx(args, global_model, Loaders_train, Loaders_test, global_loader_test, logger, device)
elif args.alg == 'FedMD':
    server = ServerFedMD(args, global_model, Loaders_train, Loaders_test, global_loader_test, testset, logger, device)
elif args.alg == 'FedProto':
    server = ServerFedProto(args, global_model, Loaders_train, Loaders_test, global_loader_test, logger, device)
elif args.alg == 'FedHKD':
    server = ServerFedHKD(args, global_model, Loaders_train, Loaders_test, global_loader_test, logger, device)
else:
    raise ValueError('unknown algorithm: {}'.format(args.alg))

server.Create_Clints()
server.train()

save_path = checkpoint_dir + args_hash + '.pth'
if args.save_model:
    server.Save_CheckPoint(save_path)
    print('Model is saved to:', save_path)
--------------------------------------------------------------------------------
/Server/ServerFedMD.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy
from Server.ServerBase import Server
from Client.ClientFedMD import ClientFedMD
from tqdm import tqdm
import numpy as np
from utils import average_weights
from mem_utils import MemReporter
import time
from sampling import LocalDataset, LocalDataloaders, partition_data
import gc


class ServerFedMD(Server):
    def __init__(self, args, global_model, Loader_train, Loaders_local_test, Loader_global_test, pub_test, logger, device):
        super().__init__(args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device)
        # sample 1000 indices from the test set as the shared public dataset
        dict_pub = [np.random.randint(low=0, high=10000, size=1000)]
        self.public_data = LocalDataloaders(pub_test, dict_pub, args.batch_size, ShuffleorNot=False, frac=1)[0]

    def Create_Clints(self):
        for idx in range(self.args.num_clients):
            self.LocalModels.append(ClientFedMD(self.args, copy.deepcopy(self.global_model), self.Loaders_train[idx], self.Loaders_local_test[idx], loader_pub=self.public_data, idx=idx, logger=self.logger, code_length=self.args.code_len, num_classes=self.args.num_classes, device=self.device))

    def train(self):
        reporter = MemReporter()
        start_time = time.time()
        train_loss = []
        global_soft_prediction = []
        global_weights = self.global_model.state_dict()
        for epoch in tqdm(range(self.args.num_epochs)):
            Knowledges = []
            test_accuracy = 0
            local_weights, local_losses = [], []
            print(f'\n | Global Training Round : {epoch+1} |\n')
            m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
            idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
            for idx in idxs_users:
                if self.args.upload_model:
                    self.LocalModels[idx].load_model(global_weights)
                # plain local training in the first round, distillation against
                # the global soft predictions afterwards
                if epoch < 1:
                    w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
                else:
                    w, loss = self.LocalModels[idx].update_weights_MD(global_round=epoch, knowledges=global_soft_prediction, lam=0.1, temp=self.args.temp)
                local_losses.append(copy.deepcopy(loss))
                local_weights.append(copy.deepcopy(w))
                acc = self.LocalModels[idx].test_accuracy()
                test_accuracy += acc

                knowledges = self.LocalModels[idx].generate_knowledge(temp=self.args.temp)
                Knowledges.append(torch.stack(knowledges))
            global_soft_prediction = []
            batch_pub = Knowledges[0].shape[0]
            for i in range(batch_pub):
                num = Knowledges[0].shape[1]
                soft_label = torch.zeros(num, self.args.num_classes)
                # Knowledges is indexed by sampling order, not by client id
                for k in range(len(Knowledges)):
                    soft_label += Knowledges[k][i]
                soft_label = soft_label / len(idxs_users)
                global_soft_prediction.append(soft_label)
            del Knowledges
            gc.collect()

            # update global weights
            global_weights = average_weights(local_weights)
            self.global_model.load_state_dict(global_weights)
            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)
            print("average loss: ", loss_avg)
            print('average local test accuracy:', test_accuracy / len(idxs_users))
            print('global test accuracy: ', self.global_test_accuracy())

        print('Training is completed.')
        end_time = time.time()
        print('running time: {} s '.format(end_time - start_time))
        reporter.report()
--------------------------------------------------------------------------------
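The consensus step above simply averages the per-batch soft predictions across the sampled clients; a minimal sketch with made-up shapes (3 clients, 2 public batches of 16 samples, 10 classes):

import torch

K = [torch.softmax(torch.randn(2, 16, 10), dim=-1) for _ in range(3)]
global_soft = [sum(k[i] for k in K) / len(K) for i in range(2)]
print(global_soft[0].shape)  # torch.Size([16, 10])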
/Server/ServerFedHKD.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import copy
from utils import Accuracy
from Server.ServerBase import Server
from Client.ClientFedHKD import ClientFedHKD
from tqdm import tqdm
import numpy as np
from utils import average_weights
from mem_utils import MemReporter
import time
from sampling import LocalDataset, LocalDataloaders, partition_data
import gc


class ServerFedHKD(Server):
    def __init__(self, args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device):
        super().__init__(args, global_model, Loader_train, Loaders_local_test, Loader_global_test, logger, device)

    def Create_Clints(self):
        for idx in range(self.args.num_clients):
            self.LocalModels.append(ClientFedHKD(self.args, copy.deepcopy(self.global_model), self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length=self.args.code_len, num_classes=self.args.num_classes, device=self.device))

    def global_knowledge_aggregation(self, features, soft_prediction):
        global_local_features = dict()
        global_local_soft_prediction = dict()
        # iterate under new names so the input dicts are not rebound
        for label, feats in features.items():
            if len(feats) > 1:
                feature = 0 * feats[0].data
                for i in feats:
                    feature += i.data
                global_local_features[label] = [feature / len(feats)]
            else:
                global_local_features[label] = [feats[0].data]

        for label, softs in soft_prediction.items():
            if len(softs) > 1:
                soft = 0 * softs[0].data
                for i in softs:
                    soft += i.data
                global_local_soft_prediction[label] = [soft / len(softs)]
            else:
                global_local_soft_prediction[label] = [softs[0].data]

        return global_local_features, global_local_soft_prediction

    def train(self):
        global_features = {}
        global_soft_prediction = {}
        reporter = MemReporter()
        start_time = time.time()
        train_loss = []
        global_weights = self.global_model.state_dict()
        for epoch in tqdm(range(self.args.num_epochs)):
            test_accuracy = 0
            local_weights, local_losses = [], []
            print(f'\n | Global Training Round : {epoch+1} |\n')
            m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
            idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
            for idx in idxs_users:
                if self.args.upload_model:
                    self.LocalModels[idx].load_model(global_weights)
                # plain local training in the first round, hyper-knowledge
                # distillation once global knowledge exists
                if epoch < 1:
                    w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
                else:
                    w, loss = self.LocalModels[idx].update_weights_HKD(global_round=epoch, global_features=global_features, global_soft_prediction=global_soft_prediction, lam=self.args.lam, gamma=self.args.gamma, temp=self.args.temp)
                local_losses.append(copy.deepcopy(loss))
                local_weights.append(copy.deepcopy(w))
                acc = self.LocalModels[idx].test_accuracy()
                test_accuracy += acc

                local_features, local_soft_predictions = self.LocalModels[idx].generate_knowledge(temp=self.args.temp)
                global_features.update(local_features)
                global_soft_prediction.update(local_soft_predictions)
                del local_features
                del local_soft_predictions
                gc.collect()

            # update global weights
            global_weights = average_weights(local_weights)
            self.global_model.load_state_dict(global_weights)

            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)
            print("average loss: ", loss_avg)
            print('average local test accuracy:', test_accuracy / len(idxs_users))
            print('global test accuracy: ', self.global_test_accuracy())

        print('Training is completed.')
        end_time = time.time()
        print('running time: {} s '.format(end_time - start_time))
        reporter.report()
--------------------------------------------------------------------------------
/Client/ClientFedMD.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import scipy
from torch.utils.data import Dataset
import copy
import torch.nn as nn
from sklearn.cluster import KMeans
import torch.optim as optim
import torch.nn.functional as F
from utils import Accuracy, soft_predict
from Client.ClientBase import Client
import gc


class ClientFedMD(Client):
    """FedMD client: local training plus distillation towards the averaged
    soft predictions on a shared public dataset."""
    def __init__(self, args, model, Loader_train, loader_test, loader_pub, idx, logger, code_length, num_classes, device):
        super().__init__(args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device)
        self.loader_pub = loader_pub

    def update_weights(self, global_round):
        self.model.to(self.device)
        self.model.train()
        epoch_loss = []
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device)
                optimizer.zero_grad()
                _, p = self.model(X)
                loss = self.ce(p, y)
                loss.backward()
                if self.args.clip_grad is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad)
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, self.idx, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def update_weights_MD(self, knowledges, lam, temp, global_round):
        self.model.to(self.device)
        self.model.train()
        epoch_loss = []
        global_soft_prediction = torch.stack(knowledges)
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device)
                optimizer.zero_grad()
                _, Z = self.model(X)
                loss1 = self.ce(Z, y)
                loss2 = torch.tensor(0.0).to(self.device)
                for idx, (X_pub, y_pub) in enumerate(self.loader_pub):
                    if idx == batch_idx:
                        X_pub = X_pub.to(self.device)
                        y_pub = y_pub.to(self.device)
                        _, Z_pub = self.model(X_pub)
                        Q_pub = soft_predict(Z_pub, temp).to(self.device)
                        # KLDivLoss expects log-probabilities as its first
                        # argument; the distillation term is added (not
                        # subtracted) so that minimizing it pulls the local
                        # predictions towards the global consensus
                        loss2 += self.kld(Q_pub.log(), global_soft_prediction[idx].to(self.device))

                loss = loss1 + lam * loss2
                loss.backward()
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss1: {:.6f} Loss2: {:.6f} '.format(
                        global_round, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss1.item(), loss2.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def generate_knowledge(self, temp):
        self.model.to(self.device)
        self.model.eval()
        num_classes = self.model.num_classes
        soft_predictions = []
        with torch.no_grad():
            for batch_idx, (X, y) in enumerate(self.loader_pub):
                X = X.to(self.device)
                _, Z = self.model(X)
                Q = soft_predict(Z, temp).cpu()
                soft_predictions.append(Q)
                del X
                del y
                del Z
                gc.collect()

        return soft_predictions
--------------------------------------------------------------------------------

/sampling.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import scipy
from torch.utils.data import Dataset
import copy
from torchvision import datasets, transforms


class LocalDataset(Dataset):
    """
    torch's DataLoader needs __getitem__() to iterate by index, so this class
    maps local indices onto the underlying global dataset.
    """
    def __init__(self, dataset, Dict):
        self.dataset = dataset
        self.idxs = [int(i) for i in Dict]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        X, y = self.dataset[self.idxs[item]]
        return X, y


def LocalDataloaders(dataset, dict_users, batch_size, ShuffleorNot=True, BatchorNot=True, frac=1):
    """
    dataset: the shared dataset object
    dict_users: dictionary mapping each local model to its data indices
    batch_size: batch size for each dataloader
    ShuffleorNot: whether to shuffle
    BatchorNot: if False, the dataloader yields the full dataset instead of a batch (for testing)
    """
    num_users = len(dict_users)
    loaders = []
    for i in range(num_users):
        num_data = len(dict_users[i])
        frac_num_data = int(frac * num_data)
        whole_range = range(num_data)
        frac_range = np.random.choice(whole_range, frac_num_data)
        frac_dict_users = [dict_users[i][j] for j in frac_range]
        if BatchorNot == True:
            loader = torch.utils.data.DataLoader(
                LocalDataset(dataset, frac_dict_users),
                batch_size=batch_size,
                shuffle=ShuffleorNot,
                num_workers=0,
                drop_last=True)
        else:
            loader = torch.utils.data.DataLoader(
                LocalDataset(dataset, frac_dict_users),
                batch_size=len(LocalDataset(dataset, dict_users[i])),
                shuffle=ShuffleorNot,
                num_workers=0,
                drop_last=True)
        loaders.append(loader)
    return loaders


def partition_data(n_users, alpha=0.5, rand_seed=0, dataset='cifar10'):
    if dataset == 'CIFAR10':
        K = 10
        data_dir = '../data/cifar10/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'CIFAR100':
        K = 100
        data_dir = '../data/cifar100/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
                                          transform=apply_transform)
        test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
                                         transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'EMNIST':
        K = 62
        data_dir = '../data/EMNIST/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5,), (0.5,))])
        train_dataset = datasets.EMNIST(data_dir, train=True, split='byclass', download=True,
                                        transform=apply_transform)
        test_dataset = datasets.EMNIST(data_dir, train=False, split='byclass', download=True,
                                       transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'SVHN':
        K = 10
        data_dir = '../data/SVHN/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.SVHN(data_dir, split='train', download=True,
                                      transform=apply_transform)
        test_dataset = datasets.SVHN(data_dir, split='test', download=True,
                                     transform=apply_transform)
        y_train = np.array(train_dataset.labels)
        y_test = np.array(test_dataset.labels)

    min_size = 0
    N = len(train_dataset)
    N_test = len(test_dataset)
    net_dataidx_map = {}
    net_dataidx_map_test = {}
    np.random.seed(rand_seed)

    while min_size < 10:
        idx_batch = [[] for _ in range(n_users)]
        idx_batch_test = [[] for _ in range(n_users)]
        for k in range(K):
            idx_k = np.where(y_train == k)[0]
            idx_k_test = np.where(y_test == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_users))
            ## Balance: zero out clients that already hold more than N / n_users samples
            proportions_train = np.array([p * (len(idx_j) < N / n_users) for p, idx_j in zip(proportions, idx_batch)])
            # ...
--------------------------------------------------------------------------------
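A minimal illustration of the Dirichlet split performed above: for a single class, draw client proportions with concentration alpha and cut the class indices accordingly (small alpha yields highly skewed shares):

import numpy as np

np.random.seed(0)
idx_k = np.arange(5000)                  # indices of one class
alpha, n_users = 0.5, 10
proportions = np.random.dirichlet(np.repeat(alpha, n_users))
cuts = (np.cumsum(proportions)[:-1] * len(idx_k)).astype(int)
splits = np.split(idx_k, cuts)
print([len(s) for s in splits])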
def partition_data(n_users, alpha=0.5, rand_seed=0, dataset='CIFAR10'):
    # the comparisons below are case-sensitive, so the default must be
    # 'CIFAR10'; the original default 'cifar10' matched no branch
    if dataset == 'CIFAR10':
        K = 10
        data_dir = '../data/cifar10/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'CIFAR100':
        K = 100
        data_dir = '../data/cifar100/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
                                          transform=apply_transform)
        test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
                                         transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'EMNIST':
        K = 62
        data_dir = '../data/EMNIST/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5,), (0.5,))])
        train_dataset = datasets.EMNIST(data_dir, train=True, split='byclass', download=True,
                                        transform=apply_transform)
        test_dataset = datasets.EMNIST(data_dir, train=False, split='byclass', download=True,
                                       transform=apply_transform)
        y_train = np.array(train_dataset.targets)
        y_test = np.array(test_dataset.targets)

    if dataset == 'SVHN':
        K = 10
        data_dir = '../data/SVHN/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = datasets.SVHN(data_dir, split='train', download=True,
                                      transform=apply_transform)
        test_dataset = datasets.SVHN(data_dir, split='test', download=True,
                                     transform=apply_transform)
        y_train = np.array(train_dataset.labels)
        y_test = np.array(test_dataset.labels)

    min_size = 0
    N = len(train_dataset)
    N_test = len(test_dataset)
    net_dataidx_map = {}
    net_dataidx_map_test = {}
    np.random.seed(rand_seed)

    while min_size < 10:
        idx_batch = [[] for _ in range(n_users)]
        idx_batch_test = [[] for _ in range(n_users)]
        for k in range(K):
            idx_k = np.where(y_train == k)[0]
            idx_k_test = np.where(y_test == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_users))
            ## Balance
            proportions_train = np.array([p * (len(idx_j) < N / n_users)
                                          for p, idx_j in zip(proportions, idx_batch)])
            # [... the end of the line above is completed from the standard
            # Dirichlet-partition recipe; everything after it in the source was
            # lost in extraction: the '<' opened what the text extractor treated
            # as an HTML tag, and it dropped everything up to the next '>',
            # taking the rest of partition_data and the header of the next file
            # with it; see the hedged reconstruction sketch after this section ...]

--------------------------------------------------------------------------------
(surviving tail of a client file, most likely /Client/ClientFedProto.py;
its header fell inside the extraction gap above)
--------------------------------------------------------------------------------
            if len(features) > 1:
                feature = 0 * features[0].data
                for i in features:
                    feature += i.data
                agg_local_features[label] = [feature / len(features)]
            else:
                agg_local_features[label] = [features[0].data]

        return agg_local_features

    def dict_to_tensor(self, dic):
        lit = []
        for key, tensor in dic.items():
            lit.append(tensor[0])
        lit = torch.stack(lit)
        return lit
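The lines lost in that gap almost certainly followed the standard Dirichlet partitioning recipe (as popularized by the NIID-Bench code): normalize the balanced proportions, split each class's indices at the cumulative proportions, and repeat until every client holds at least the minimum number of samples. A hedged sketch of that logic with a hypothetical helper name, not the repository's verbatim code:

import numpy as np

def dirichlet_assign(idx_k: np.ndarray, proportions: np.ndarray, idx_batch: list) -> list:
    # hypothetical helper mirroring the lost lines: append class-k indices to
    # each client's list according to the (balanced) Dirichlet proportions
    proportions = proportions / proportions.sum()
    split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
    return [idx_j + idx.tolist()
            for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]

After the while-loop, partition_data presumably shuffles each client's index list into net_dataidx_map / net_dataidx_map_test and returns the datasets together with those maps.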
--------------------------------------------------------------------------------
/Client/ClientFedHKD.py:
--------------------------------------------------------------------------------
import torch
import copy
import torch.nn as nn
import torch.optim as optim
from utils import Accuracy, soft_predict
from Client.ClientBase import Client
import gc

class ClientFedHKD(Client):
    """
    Trains the local model, starting from a copy of the global model, and
    returns the updated weights.
    args: arguments
    Loader_train, loader_test: loaders for training and inference
    idx: the index of this local model
    logger: logs the loss and the training process
    """
    def __init__(self, args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device):
        super().__init__(args, model, Loader_train, loader_test, idx, logger, code_length, num_classes, device)

    def update_weights(self, global_round):
        self.model.to(self.device)
        self.model.train()
        epoch_loss = []
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device)
                optimizer.zero_grad()
                _, p = self.model(X)
                loss = self.ce(p, y)
                loss.backward()
                if self.args.clip_grad is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad)
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, self.idx, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            scheduler.step()  # advance the LR schedule once per local epoch
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def update_weights_HKD(self, global_features, global_soft_prediction, lam, gamma, temp, global_round):
        self.model.to(self.device)
        self.model.train()
        epoch_loss = []
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
        # guard before stacking: dict_to_tensor calls torch.stack, which raises
        # on an empty dict (e.g. before any hyper-knowledge has been aggregated)
        have_knowledge = len(global_features) > 0
        if have_knowledge:
            tensor_global_features = self.dict_to_tensor(global_features).to(self.device)
            tensor_global_soft_prediction = self.dict_to_tensor(global_soft_prediction).to(self.device)
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                y = y.to(self.device)
                optimizer.zero_grad()
                # feats: penultimate-layer representation, Z: logits
                # (renamed from `F` to avoid shadowing torch.nn.functional)
                feats, Z = self.model(X)
                loss1 = self.ce(Z, y)
                if not have_knowledge:
                    loss2 = 0 * loss1
                    loss3 = 0 * loss1
                else:
                    # distillation: run the global per-class features through the
                    # local classifier and match the global soft predictions
                    Z_help = self.model.classifier(tensor_global_features)
                    Q_help = soft_predict(Z_help, temp).to(self.device)
                    loss2 = self.kld(Q_help.log(), tensor_global_soft_prediction)
                    # feature alignment: pull each sample's representation
                    # toward the global prototype of its class
                    target_features = copy.deepcopy(feats.data)
                    for i in range(y.shape[0]):
                        if int(y[i]) in global_features.keys():
                            target_features[i] = global_features[int(y[i])][0].data
                    target_features = target_features.to(self.device)
                    loss3 = self.mse(feats, target_features)
                loss = loss1 + lam * loss2 + gamma * loss3
                loss.backward()
                if self.args.clip_grad is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.clip_grad)
                optimizer.step()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss1: {:.6f} Loss2: {:.6f} Loss3: {:.6f} '.format(
                        global_round, ep, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss1.item(), loss2.item(), loss3.item()))
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            scheduler.step()
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
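    # Annotation (not original code): nn.KLDivLoss expects log-probabilities
    # as its *input* and plain probabilities as its *target*, which is why
    # update_weights_HKD passes Q_help.log(). A minimal self-contained check:
    @staticmethod
    def _kld_argument_order_demo():
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        kld = nn.KLDivLoss(reduction='batchmean')
        p = F.softmax(torch.randn(4, 10), dim=1)  # "student" probabilities
        q = F.softmax(torch.randn(4, 10), dim=1)  # "teacher" probabilities
        return kld(p.log(), q)  # input must be log-probabilities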
    # generate hyper-knowledge for FedHKD: per-class mean features and
    # per-class mean soft predictions over the local training data
    def generate_knowledge(self, temp):
        self.model.to(self.device)
        self.model.eval()
        local_features = {}
        local_soft_prediction = {}
        with torch.no_grad():  # inference only; avoids retaining autograd graphs
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device)
                feats, Z = self.model(X)
                Q = soft_predict(Z, temp).to(self.device)
                for i in range(len(y)):
                    label = y[i].item()
                    if label in local_features:
                        local_features[label].append(feats[i, :])
                        local_soft_prediction[label].append(Q[i, :])
                    else:
                        local_features[label] = [feats[i, :]]
                        local_soft_prediction[label] = [Q[i, :]]
                del X, y, feats, Z, Q
                gc.collect()

        features, soft_predictions = self.local_knowledge_aggregation(
            local_features, local_soft_prediction, std=self.args.std)

        return (features, soft_predictions)

    def local_knowledge_aggregation(self, local_features, local_soft_prediction, std):
        agg_local_features = dict()
        agg_local_soft_prediction = dict()
        # NOTE: the noise vector is drawn once, so every class prototype is
        # perturbed by the *same* Gaussian sample (see the note after this file)
        feature_noise = std * torch.randn(self.args.code_len).to(self.device)
        for label, features in local_features.items():
            if len(features) > 1:
                feature = 0 * features[0].data
                for i in features:
                    feature += i.data
                agg_local_features[label] = [feature / len(features) + feature_noise]
            else:
                agg_local_features[label] = [features[0].data + feature_noise]

        for label, soft_prediction in local_soft_prediction.items():
            if len(soft_prediction) > 1:
                soft = 0 * soft_prediction[0].data
                for i in soft_prediction:
                    soft += i.data
                agg_local_soft_prediction[label] = [soft / len(soft_prediction)]
            else:
                agg_local_soft_prediction[label] = [soft_prediction[0].data]

        return agg_local_features, agg_local_soft_prediction

    def dict_to_tensor(self, dic):
        # assumes both hyper-knowledge dicts share the same key insertion
        # order, so feature rows and soft-prediction rows line up
        lit = []
        for key, tensor in dic.items():
            lit.append(tensor[0])
        lit = torch.stack(lit)
        return lit
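One subtlety in local_knowledge_aggregation above: feature_noise is drawn once, so every class prototype is perturbed by the same Gaussian sample. If the noise is meant as a privacy-style perturbation, independent per-class noise may be the intent; the variant below is a hedged sketch of that reading (an assumption, not the authors' stated design).

import torch

def perturb_prototypes(agg_features: dict, std: float, code_len: int,
                       device: torch.device) -> dict:
    # hypothetical variant: add an independent Gaussian sample to each
    # class prototype instead of one shared noise vector
    return {label: [feat[0] + std * torch.randn(code_len, device=device)]
            for label, feat in agg_features.items()}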
--------------------------------------------------------------------------------
/mem_utils.py:
--------------------------------------------------------------------------------
import math
import gc
from collections import defaultdict
from typing import Optional, Tuple, List

import torch

from math import isnan
from calmsize import size as calmsize

def readable_size(num_bytes: int) -> str:
    return '' if isnan(num_bytes) else '{:.2f}'.format(calmsize(num_bytes))

LEN = 79

# some low-level PyTorch memory-management constants
# the minimal allocation size (bytes)
PYTORCH_MIN_ALLOCATE = 2 ** 9
# the minimal cache size (bytes)
PYTORCH_MIN_CACHE = 2 ** 20

class MemReporter():
    """A memory reporter that collects tensors and their memory usage

    Parameters:
        - model: an extra nn.Module can be passed to infer the names
          of tensors
    """
    def __init__(self, model: Optional[torch.nn.Module] = None):
        self.tensor_name = {}
        self.device_mapping = defaultdict(list)
        self.device_tensor_stat = {}
        # counter used to name tensors whose name cannot be inferred
        self.name_idx = 0

        tensor_names = defaultdict(list)
        if model is not None:
            assert isinstance(model, torch.nn.Module)
            # for models with tied weights, multiple parameters may share
            # the same underlying tensor
            for name, param in model.named_parameters():
                tensor_names[param].append(name)

        for param, names in tensor_names.items():
            self.tensor_name[id(param)] = '+'.join(names)

    def _get_tensor_name(self, tensor: torch.Tensor) -> str:
        tensor_id = id(tensor)
        if tensor_id in self.tensor_name:
            name = self.tensor_name[tensor_id]
        # fall back to a numbered name if none can be inferred
        else:
            name = type(tensor).__name__ + str(self.name_idx)
            self.tensor_name[tensor_id] = name
            self.name_idx += 1
        return name

    def collect_tensor(self):
        """Collect all tensor objects tracked by Python

        NOTICE:
            - buffers for backward, implemented in C++, are not tracked
              by Python's reference counting.
            - the gradients (.grad) of Parameters are not collected, and
              I don't know why.
        """
        # FIXME: make the grad tensors collectable by gc
        objects = gc.get_objects()
        tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
        for t in tensors:
            self.device_mapping[t.device].append(t)
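    # Annotation (not original code): PYTORCH_MIN_ALLOCATE means even tiny
    # tensors are charged a full 512-byte allocation unit by get_tensor_stat
    # below; e.g. ten float32 elements hold only 40 bytes of data:
    @staticmethod
    def _min_allocate_demo() -> int:
        fact_memory_size = 10 * 4  # bytes actually stored
        memory_size = math.ceil(fact_memory_size / PYTORCH_MIN_ALLOCATE) \
                * PYTORCH_MIN_ALLOCATE
        assert memory_size == 512
        return memory_size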
    def get_stats(self):
        """Get the memory stats of tensors and then release them

        As a memory profiler, we cannot hold references to any tensors, which
        would itself distort the memory usage stats, so the tensors are
        released after the required stats are gathered."""
        visited_data = {}
        self.device_tensor_stat.clear()

        def get_tensor_stat(tensor: torch.Tensor) -> List[Tuple[str, int, int, int]]:
            """Get the stats of a single tensor

            Returns:
                - stat: a tuple containing (tensor_name, tensor_size,
                  tensor_numel, tensor_memory)
            """
            assert isinstance(tensor, torch.Tensor)

            name = self._get_tensor_name(tensor)
            if tensor.is_sparse:
                indices_stat = get_tensor_stat(tensor._indices())
                values_stat = get_tensor_stat(tensor._values())
                return indices_stat + values_stat

            numel = tensor.numel()
            element_size = tensor.element_size()
            fact_numel = tensor.storage().size()
            fact_memory_size = fact_numel * element_size
            # since pytorch allocates at least 512 bytes for any tensor,
            # round up to a multiple of 512
            memory_size = math.ceil(fact_memory_size / PYTORCH_MIN_ALLOCATE) \
                    * PYTORCH_MIN_ALLOCATE

            # tensor.storage should be the actual object related to the
            # memory allocation
            data_ptr = tensor.storage().data_ptr()
            if data_ptr in visited_data:
                name = '{}(->{})'.format(
                    name,
                    visited_data[data_ptr],
                )
                # don't count the memory of tensors reusing the same
                # underlying storage
                memory_size = 0
            else:
                visited_data[data_ptr] = name

            size = tuple(tensor.size())
            # a torch scalar has an empty size
            if not size:
                size = (1,)

            return [(name, size, numel, memory_size)]

        for device, tensors in self.device_mapping.items():
            tensor_stats = []
            for tensor in tensors:
                if tensor.numel() == 0:
                    continue
                stat = get_tensor_stat(tensor)  # (name, shape, numel, memory_size)
                tensor_stats += stat
                if isinstance(tensor, torch.nn.Parameter):
                    if tensor.grad is not None:
                        # manually name the gradient tensor after its parameter
                        self.tensor_name[id(tensor.grad)] = '{}.grad'.format(
                            self._get_tensor_name(tensor)
                        )
                        stat = get_tensor_stat(tensor.grad)
                        tensor_stats += stat

            self.device_tensor_stat[device] = tensor_stats

        self.device_mapping.clear()

    def print_stats(self, verbose: bool = False, target_device: Optional[torch.device] = None) -> None:
        # per-tensor printing is currently commented out; only per-device
        # totals are reported
        # show_reuse = verbose
        # template_format = '{:<40s}{:>20s}{:>10s}'
        # print(template_format.format('Element type', 'Size', 'Used MEM'))
        for device, tensor_stats in self.device_tensor_stat.items():
            # by default, if target_device is not specified,
            # print tensors on all devices
            if target_device is not None and device != target_device:
                continue
            # print('-' * LEN)
            print('\nStorage on {}'.format(device))
            total_mem = 0
            total_numel = 0
            for stat in tensor_stats:
                name, size, numel, mem = stat
                # if not show_reuse:
                #     name = name.split('(')[0]
                # print(template_format.format(
                #     str(name),
                #     str(size),
                #     readable_size(mem),
                # ))
                total_mem += mem
                total_numel += numel

            print('-' * LEN)
            print('Total Tensors: {} \tUsed Memory: {}'.format(
                total_numel, readable_size(total_mem),
            ))

            if device != torch.device('cpu'):
                with torch.cuda.device(device):
                    memory_allocated = torch.cuda.memory_allocated()
                print('The allocated memory on {}: {}'.format(
                    device, readable_size(memory_allocated),
                ))
                if memory_allocated != total_mem:
                    print('Memory differs due to the matrix alignment or'
                          ' invisible gradient buffer tensors')
            print('-' * LEN)

    def report(self, verbose: bool = False, device: Optional[torch.device] = None) -> None:
        """Interface for end-users to directly print the memory usage

        args:
            - verbose: flag to show tensor.storage reuse information
            - device: `torch.device` object, specify the target device
              to report detailed memory usage. It will print memory usage
              on all devices if not specified. Usually we only want to
              print the memory usage on CUDA devices.
        """
        self.collect_tensor()
        self.get_stats()
        self.print_stats(verbose, target_device=device)
--------------------------------------------------------------------------------
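Typical usage of MemReporter, for reference (the model here is illustrative):

import torch
import torch.nn as nn
from mem_utils import MemReporter

model = nn.Linear(128, 10)              # any nn.Module
loss = model(torch.randn(4, 128)).sum()
loss.backward()

reporter = MemReporter(model)           # names tensors after model parameters
reporter.report()                       # per-device tensor counts and memory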