├── assets
│   └── FedHKD.png
├── Client
│   ├── ClientBase.py
│   ├── ClientFedAvg.py
│   ├── ClientFedProx.py
│   ├── ClientFedMD.py
│   ├── ClientFedProto.py
│   └── ClientFedHKD.py
├── Server
│   ├── ServerBase.py
│   ├── ServerFedAvg.py
│   ├── ServerFedProx.py
│   ├── ServerFedProto.py
│   ├── ServerFedMD.py
│   └── ServerFedHKD.py
├── requirements.txt
├── utils.py
├── models.py
├── option.py
├── README.md
├── main.py
├── sampling.py
└── mem_utils.py
/assets/FedHKD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityChan/Federated-Hyper-Knowledge-Distillation/HEAD/assets/FedHKD.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | calmsize==0.1.3
2 | imageio==2.9.0
3 | numpy==1.19.2
4 | Pillow==9.2.0
5 | scikit_learn==1.1.2
6 | scipy==1.5.2
7 | tensorboardX==2.5.1
8 | torch==1.9.0+rocm4.2
9 | torchvision==0.10.0+rocm4.2
10 | tqdm==4.62.3
11 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
4 | def Accuracy(y, y_predict):
5 |     """fraction of correct predictions in a batch"""
6 |     leng = len(y)
7 |     miss = 0
8 |     for i in range(leng):
9 |         if not y[i] == y_predict[i]:
10 |             miss += 1
11 |     return (leng - miss) / leng
12 |
13 |
14 | def soft_predict(Z, temp):
15 |     """temperature-scaled softmax: Q[i,j] = exp(Z[i,j]/temp) / sum_k exp(Z[i,k]/temp)"""
16 |     m, n = Z.shape
17 |     Q = torch.zeros(m, n)   # allocated on the CPU; callers move it to their device
18 |     Z_sum = torch.sum(torch.exp(Z / temp), dim=1)
19 |     for i in range(n):
20 |         Q[:, i] = torch.exp(Z[:, i] / temp) / Z_sum
21 |     return Q
22 |
23 |
24 | def average_weights(w):
25 |     """
26 |     average the weights (state dicts) from all local models
27 |     """
28 |     w_avg = copy.deepcopy(w[0])
29 |     for key in w_avg.keys():
30 |         for i in range(1, len(w)):
31 |             w_avg[key] += w[i][key]
32 |         w_avg[key] = torch.div(w_avg[key], len(w))
33 |     return w_avg
34 |
--------------------------------------------------------------------------------
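
A quick sanity check on `soft_predict`: it is exactly a temperature-scaled softmax, so it can be validated against `torch.softmax`. A minimal sketch, assuming it is run from the repository root so `utils` is importable:

```
import torch
from utils import soft_predict

Z = torch.randn(8, 10)        # a batch of 8 logit vectors over 10 classes
temp = 0.5                    # temp < 1 sharpens the distribution, temp > 1 smooths it
Q = soft_predict(Z, temp)
assert torch.allclose(Q, torch.softmax(Z / temp, dim=1), atol=1e-6)
assert torch.allclose(Q.sum(dim=1), torch.ones(8))   # each row is a distribution
```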
/Server/ServerBase.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 |
7 | class Server(object):
8 | def __init__(self,args, global_model,Loaders_train, Loaders_local_test, Loader_global_test, logger, device):
9 | self.global_model = global_model
10 | self.args = args
11 | self.Loaders_train = Loaders_train
12 | self.Loaders_local_test = Loaders_local_test
13 | self.global_testloader = Loader_global_test
14 | self.logger = logger
15 | self.device = device
16 | self.LocalModels = []
17 |
18 | def global_test_accuracy(self):
19 | self.global_model.eval()
20 | accuracy = 0
21 | cnt = 0
22 | for batch_idx, (X, y) in enumerate(self.global_testloader):
23 | X = X.to(self.device)
24 | y = y.to(self.device)
25 | _,p = self.global_model(X)
26 | y_pred = p.argmax(1)
27 | accuracy += Accuracy(y,y_pred)
28 | cnt += 1
29 | return accuracy/cnt
30 |
31 |
32 | def Save_CheckPoint(self, save_path):
33 | torch.save(self.global_model.state_dict(), save_path)
34 |
35 |
--------------------------------------------------------------------------------
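
`global_test_accuracy` puts the model in eval mode but does not disable autograd, so each forward pass still builds a graph. A minimal standalone sketch of the same loop with a `torch.no_grad()` wrapper added (the wrapper is my addition, not in the original):

```
import torch
from utils import Accuracy

def global_test_accuracy(model, loader, device):
    model.eval()
    accuracy, cnt = 0, 0
    with torch.no_grad():                 # no gradients are needed for evaluation
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            _, p = model(X)               # every model returns (features z, scores p)
            accuracy += Accuracy(y, p.argmax(1))
            cnt += 1
    return accuracy / cnt
```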
/Client/ClientBase.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import scipy
5 | from torch.utils.data import Dataset
6 | import torch
7 | import copy
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from utils import Accuracy,soft_predict
12 |
13 | class Client(object):
14 |     """
15 |     Trains a local model initialized from a copy of the global model and returns the updated weights.
16 |     args: command-line arguments
17 |     Loader_train, loader_test: dataloaders for local training and inference
18 |     idx: index of this client
19 |     logger: logs the loss during training
20 |     code_length: dimension of the latent feature vector
21 |     """
22 | def __init__(self, args, model,Loader_train,loader_test,idx, logger, code_length, num_classes, device):
23 | self.args = args
24 | self.logger = logger
25 | self.trainloader = Loader_train
26 | self.testloader = loader_test
27 | self.idx = idx
28 | self.ce = nn.CrossEntropyLoss()
29 | self.device = device
30 | self.code_length = code_length
31 | self.kld = nn.KLDivLoss()
32 | self.mse = nn.MSELoss()
33 | self.model = copy.deepcopy(model)
34 |
35 |
36 | def test_accuracy(self):
37 | self.model.eval()
38 | accuracy = 0
39 | cnt = 0
40 | for batch_idx, (X, y) in enumerate(self.testloader):
41 | X = X.to(self.device)
42 | y = y.to(self.device)
43 | _, p = self.model(X)
44 | y_pred = p.argmax(1)
45 | accuracy += Accuracy(y,y_pred)
46 | cnt += 1
47 | return accuracy/cnt
48 |
49 | def load_model(self,global_weights):
50 | self.model.load_state_dict(global_weights)
--------------------------------------------------------------------------------
/Client/ClientFedAvg.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import scipy
5 | from torch.utils.data import Dataset
6 | import torch
7 | import copy
8 | import torch.nn as nn
9 | from sklearn.cluster import KMeans
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from utils import Accuracy,soft_predict
13 | from Client.ClientBase import Client
14 |
15 | class ClientFedAvg(Client):
16 |     """
17 |     FedAvg client: trains the local model with the cross-entropy loss only.
18 |     args: command-line arguments
19 |     Loader_train, loader_test: dataloaders for local training and inference
20 |     idx: index of this client
21 |     logger: logs the loss during training
22 |     code_length: dimension of the latent feature vector
23 |     """
24 | def __init__(self, args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device):
25 | super().__init__(args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device)
26 |
27 | def update_weights(self,global_round):
28 | self.model.to(self.device)
29 | self.model.train()
30 | epoch_loss = []
31 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
32 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
33 | for iter in range(self.args.local_ep):
34 | batch_loss = []
35 | for batch_idx, (X, y) in enumerate(self.trainloader):
36 | X = X.to(self.device)
37 | y = y.to(self.device)
38 | optimizer.zero_grad()
39 | _,p = self.model(X)
40 | loss = self.ce(p,y)
41 | loss.backward()
42 |                 if self.args.clip_grad is not None:
43 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.args.clip_grad)
44 | optimizer.step()
45 | if batch_idx % 10 == 0:
46 | print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
47 | global_round, self.idx, iter, batch_idx * len(X),
48 | len(self.trainloader.dataset),
49 | 100. * batch_idx / len(self.trainloader), loss.item()))
50 | self.logger.add_scalar('loss', loss.item())
51 | batch_loss.append(loss.item())
52 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
53 |
54 | return self.model.state_dict(),sum(epoch_loss) / len(epoch_loss)
55 |
--------------------------------------------------------------------------------
/Server/ServerFedAvg.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 | from Server.ServerBase import Server
7 | from Client.ClientFedAvg import ClientFedAvg
8 | from tqdm import tqdm
9 | import numpy as np
10 | from utils import average_weights
11 | from mem_utils import MemReporter
12 | import time
13 | class ServerFedAvg(Server):
14 | def __init__(self, args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device):
15 | super().__init__(args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device)
16 |
17 |
18 |     def Create_Clients(self):
19 | for idx in range(self.args.num_clients):
20 | self.LocalModels.append(ClientFedAvg(self.args, copy.deepcopy(self.global_model),self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length = self.args.code_len, num_classes = self.args.num_classes, device=self.device))
21 |
22 |
23 | def train(self):
24 | reporter = MemReporter()
25 | start_time = time.time()
26 | train_loss = []
27 | global_weights = self.global_model.state_dict()
28 | for epoch in tqdm(range(self.args.num_epochs)):
29 | test_accuracy = 0
30 | local_weights, local_losses = [], []
31 | print(f'\n | Global Training Round : {epoch+1} |\n')
32 | m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
33 | idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
34 | for idx in idxs_users:
35 | if self.args.upload_model == True:
36 | self.LocalModels[idx].load_model(global_weights)
37 | w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
38 | local_losses.append(copy.deepcopy(loss))
39 | local_weights.append(copy.deepcopy(w))
40 | acc = self.LocalModels[idx].test_accuracy()
41 | test_accuracy += acc
42 |
43 |
44 | # update global weights
45 | global_weights = average_weights(local_weights)
46 | self.global_model.load_state_dict(global_weights)
47 | loss_avg = sum(local_losses) / len(local_losses)
48 | train_loss.append(loss_avg)
49 | print("average loss: ", loss_avg)
50 |             print('average local test accuracy:', test_accuracy / len(idxs_users))
51 | print('global test accuracy: ', self.global_test_accuracy())
52 |
53 | print('Training is completed.')
54 | end_time = time.time()
55 | print('running time: {} s '.format(end_time - start_time))
56 | reporter.report()
57 |
--------------------------------------------------------------------------------
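
Stripped of logging and accuracy tracking, one FedAvg round in `train` is: sample a client subset, run local training from the current global weights, then average the returned state dicts. A distilled sketch against the same interfaces (the `upload_model` toggle is omitted for brevity):

```
import copy
import numpy as np
from utils import average_weights

def fedavg_round(server, epoch):
    m = max(int(server.args.sampling_rate * server.args.num_clients), 1)
    chosen = np.random.choice(range(server.args.num_clients), m, replace=False)
    global_weights = server.global_model.state_dict()
    local_weights = []
    for idx in chosen:
        server.LocalModels[idx].load_model(global_weights)          # broadcast
        w, _ = server.LocalModels[idx].update_weights(global_round=epoch)
        local_weights.append(copy.deepcopy(w))                      # collect
    server.global_model.load_state_dict(average_weights(local_weights))
```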
/Server/ServerFedProx.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 | from Server.ServerBase import Server
7 | from Client.ClientFedProx import ClientFedProx
8 | from tqdm import tqdm
9 | import numpy as np
10 | from utils import average_weights
11 | from mem_utils import MemReporter
12 | import time
13 | class ServerFedProx(Server):
14 | def __init__(self, args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device):
15 | super().__init__(args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device)
16 |
17 |
18 |     def Create_Clients(self):
19 | for idx in range(self.args.num_clients):
20 | self.LocalModels.append(ClientFedProx(self.args, copy.deepcopy(self.global_model),self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length = self.args.code_len, num_classes = self.args.num_classes, device=self.device))
21 |
22 |
23 | def train(self):
24 | reporter = MemReporter()
25 | start_time = time.time()
26 | train_loss = []
27 | global_weights = self.global_model.state_dict()
28 | for epoch in tqdm(range(self.args.num_epochs)):
29 | test_accuracy = 0
30 | local_weights, local_losses = [], []
31 | print(f'\n | Global Training Round : {epoch+1} |\n')
32 | m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
33 | idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
34 | for idx in idxs_users:
35 | if self.args.upload_model == True:
36 | self.LocalModels[idx].load_model(global_weights)
37 | w, loss = self.LocalModels[idx].update_weights_Prox(global_round=epoch, lam=0.1)
38 | local_losses.append(copy.deepcopy(loss))
39 | local_weights.append(copy.deepcopy(w))
40 | acc = self.LocalModels[idx].test_accuracy()
41 | test_accuracy += acc
42 |
43 |
44 | # update global weights
45 | global_weights = average_weights(local_weights)
46 | self.global_model.load_state_dict(global_weights)
47 | loss_avg = sum(local_losses) / len(local_losses)
48 | train_loss.append(loss_avg)
49 | print("average loss: ", loss_avg)
50 |             print('average local test accuracy:', test_accuracy / len(idxs_users))
51 | print('global test accuracy: ', self.global_test_accuracy())
52 |
53 | print('Training is completed.')
54 | end_time = time.time()
55 | print('running time: {} s '.format(end_time - start_time))
56 | reporter.report()
57 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from os.path import join
3 | import imageio
4 | from torch import nn
5 | from torch.nn.modules.linear import Linear
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 | from tqdm import tqdm
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import torchvision.models as models
13 |
14 |
15 | class EncoderFemnist(nn.Module):
16 | def __init__(self, code_length):
17 | super(EncoderFemnist, self).__init__()
18 | self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
19 | self.conv2 = nn.Conv2d(10,20, kernel_size=5)
20 | self.conv2_drop = nn.Dropout2d()
21 |         self.fc1 = nn.Linear(320, code_length)  # 320 = 20 channels * 4 * 4 after two conv+pool stages
22 |
23 | def forward(self, x):
24 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
25 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
26 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
27 | z = F.relu(self.fc1(x))
28 | return z
29 |
30 | class CNNFemnist(nn.Module):
31 | def __init__(self, args,code_length=50,num_classes = 62):
32 | super(CNNFemnist, self).__init__()
33 | self.code_length = code_length
34 | self.num_classes = num_classes
35 | self.feature_extractor = EncoderFemnist(self.code_length)
36 | self.classifier = nn.Sequential(nn.Dropout(0.2),
37 | nn.Linear(self.code_length, self.num_classes),
38 | nn.LogSoftmax(dim=1))
39 |
40 | def forward(self, x):
41 | z = self.feature_extractor(x)
42 | p = self.classifier(z)
43 | return z,p
44 |
45 |
46 | class ResNet18(nn.Module):
47 | def __init__(self, args,code_length=64,num_classes = 10):
48 | super(ResNet18, self).__init__()
49 | self.code_length = code_length
50 | self.num_classes = num_classes
51 | self.feature_extractor = models.resnet18(num_classes=self.code_length)
52 | self.classifier = nn.Sequential(
53 | nn.Linear(self.code_length, self.num_classes))
54 | def forward(self,x):
55 | z = self.feature_extractor(x)
56 | p = self.classifier(z)
57 | return z,p
58 |
59 | class ShuffLeNet(nn.Module):
60 | def __init__(self, args,code_length=64,num_classes = 10):
61 | super(ShuffLeNet, self).__init__()
62 | self.code_length = code_length
63 | self.num_classes = num_classes
64 | self.feature_extractor = models.shufflenet_v2_x1_0(num_classes=self.code_length)
65 | self.classifier = nn.Sequential(
66 | nn.Linear(self.code_length, self.num_classes))
67 | def forward(self,x):
68 | z = self.feature_extractor(x)
69 | p = self.classifier(z)
70 | return z,p
--------------------------------------------------------------------------------
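
All three models share one interface: `forward(x)` returns the latent code `z` and the class scores `p`, which callers unpack as `z, p = model(X)`. A quick shape check (`args` is accepted by the constructors but never read, so `None` is passed here purely for illustration):

```
import torch
from models import ResNet18

model = ResNet18(None, code_length=64, num_classes=10)
x = torch.randn(4, 3, 32, 32)      # a CIFAR-sized batch
z, p = model(x)
print(z.shape, p.shape)            # torch.Size([4, 64]) torch.Size([4, 10])
```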
/Client/ClientFedProx.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import scipy
5 | from torch.utils.data import Dataset
6 | import torch
7 | import copy
8 | import torch.nn as nn
9 | from sklearn.cluster import KMeans
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from utils import Accuracy,soft_predict
13 | from Client.ClientBase import Client
14 |
15 | class ClientFedProx(Client):
16 |     """
17 |     FedProx client: adds a proximal term that anchors the local model to the global weights.
18 |     args: command-line arguments
19 |     Loader_train, loader_test: dataloaders for local training and inference
20 |     idx: index of this client
21 |     logger: logs the loss during training
22 |     code_length: dimension of the latent feature vector
23 |     """
24 | def __init__(self, args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device):
25 | super().__init__(args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device)
26 |
27 | def update_weights_Prox(self,global_round, lam):
28 |         self.model.to(self.device)
29 | self.model.train()
30 | global_model = copy.deepcopy(self.model)
31 | global_model.eval()
32 | global_weight_collector = list(global_model.parameters())
33 | epoch_loss = []
34 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
35 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
36 | for iter in range(self.args.local_ep):
37 | batch_loss = []
38 | for batch_idx, (X, y) in enumerate(self.trainloader):
39 | X = X.to(self.device)
40 | y = y.to(self.device).long()
41 | optimizer.zero_grad()
42 | _,p = self.model(X)
43 | y_pred = p.argmax(1)
44 | loss1 = self.ce(p,y)
45 | fed_prox_reg = 0.0
46 | for param_index, param in enumerate(self.model.parameters()):
47 | fed_prox_reg += ((lam / 2) * torch.norm((param - global_weight_collector[param_index])) ** 2)
48 |                 loss = loss1 + fed_prox_reg  # lam is already applied inside fed_prox_reg
49 | loss.backward()
50 |                 if self.args.clip_grad is not None:
51 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.args.clip_grad)
52 | optimizer.step()
53 | if batch_idx % 10 == 0:
54 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t prox_loss: {:.6f}'.format(
55 | global_round, iter, batch_idx * len(X),
56 | len(self.trainloader.dataset),
57 | 100. * batch_idx / len(self.trainloader), loss.item(),fed_prox_reg.item()))
58 | self.logger.add_scalar('loss', loss.item())
59 | batch_loss.append(loss.item())
60 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
61 |
62 | return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
--------------------------------------------------------------------------------
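
The FedProx objective adds a proximal term (lam/2)·‖w − w_global‖² to the local cross-entropy, which is what the loop over `global_weight_collector` computes above. The same term in isolation:

```
import torch

def prox_term(model, global_params, lam):
    # (lam / 2) * squared L2 distance between local and global parameters
    reg = 0.0
    for param, g_param in zip(model.parameters(), global_params):
        reg += (lam / 2) * torch.norm(param - g_param) ** 2
    return reg
```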
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def args_parser():
5 | parser = argparse.ArgumentParser()
6 |
7 |     # Data-specific parameters
8 | parser.add_argument('--dataset', default='CIFAR10',
9 | help='CIFAR10, CIFAR100, SVHN, EMNIST')
10 |     # Training-specific parameters
11 | parser.add_argument('--log_frq', type=int, default=5,
12 | help='frequency of logging')
13 | parser.add_argument('--batch_size', type=int, default=64,
14 | help='minibatch size')
15 | parser.add_argument('--num_epochs', type=int, default=50,
16 | help='number of epochs')
17 | parser.add_argument('--clip_grad', type=float, default=None,
18 |                         help='gradient clipping')
19 | parser.add_argument('--lr', type=float, default=0.001,
20 | help='learning rate')
21 | parser.add_argument('--lr_sh_rate', type=int, default=10,
22 | help='number of steps to drop the lr')
23 | parser.add_argument('--use_lrschd', action="store_true", default=False,
24 | help='Use lr rate scheduler')
25 | parser.add_argument('--num_clients', type=int, default=10,
26 | help='number of local models')
27 |
28 | parser.add_argument('--num_classes', type=int,default=10,
29 | help='number of classes')
30 |
31 | parser.add_argument('--sampling_rate', type=float,default=1,
32 | help='frac of local models to update')
33 | parser.add_argument('--local_ep',type=int, default=5,
34 | help='iterations of local updating')
35 | parser.add_argument('--beta', type=float,default=0.5,
36 | help='beta for non-iid distribution')
37 | parser.add_argument('--seed', type=int,default=0,
38 | help='random seed for generating datasets')
39 | parser.add_argument('--code_len', type=int,default=32,
40 | help='length of code')
41 | parser.add_argument('--alg', default='FedAvg',
42 |                         help='FedAvg, FedProx, Moon, FedMD, FedProto, FedHKD')
43 |
44 | parser.add_argument('--lam', type=float, default=0.05,
45 | help='hyper-parameter for loss2')
46 |
47 | parser.add_argument('--gamma', type=float, default=0.05,
48 | help='hyper-parameter for loss3')
49 |
50 | parser.add_argument('--std', type=float, default=2,
51 | help='std of gaussian noise ')
52 |
53 | parser.add_argument('--part', type=float,default=0.1,
54 | help='percentage of each local data')
55 |
56 |
57 | parser.add_argument('--temp', type=float,default=0.5,
58 |                         help='temperature for soft prediction')
59 |
60 | parser.add_argument('--model', default= 'resnet18',
61 | help='CNN resnet18 shufflenet')
62 |
63 | parser.add_argument('--upload_model', type=bool, default=True,
64 | help='whether to upload model parameters')
65 | parser.add_argument('--save_model', action="store_true", default= False,
66 |                         help='save model checkpoints')
67 |
68 | parser.add_argument('--eval_only', action="store_true", default=False,help='evaluate the model')
69 |
70 | args = parser.parse_args()
71 | return args
72 |
--------------------------------------------------------------------------------
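
`args_parser` reads `sys.argv`, so any flag left unspecified keeps its default from above. A small sketch that fakes a command line to show this (the argument values are illustrative):

```
import sys
from option import args_parser

sys.argv = ['main.py', '--dataset', 'CIFAR10', '--alg', 'FedHKD',
            '--num_clients', '10', '--local_ep', '3']
args = args_parser()
print(args.dataset, args.alg, args.lr)   # lr falls back to its default, 0.001
```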
/README.md:
--------------------------------------------------------------------------------
1 | # Federated Hyper Knowledge Distillation
2 | This is the official repository for our ICLR 2023 paper
3 | * "[The Best of Both Worlds: Accurate Global and Personalized Models through Federated Learning with Data-Free Hyper-Knowledge Distillation](https://arxiv.org/abs/2301.08968)"
4 |
5 | ![FedHKD](assets/FedHKD.png)
6 |
7 |
8 |
9 | A flow diagram showing computation, encryption and aggregation of hyper-knowledge.
10 |
11 |
12 |
13 |
14 |
15 | ### Environment
16 | This project is developed with Python 3.6 and [torch 1.9 (rocm4.2)](https://pytorch.org/get-started/previous-versions/). We use [conda](https://www.anaconda.com/docs/main) to manage the virtual environment.
17 | ```
18 | git clone git@github.com:CityChan/Federated-Hyper-Knowledge-Distillation.git
19 | cd Federated-Hyper-Knowledge-Distillation
20 | conda create -n fedhkd python=3.6
21 | conda activate fedhkd
22 | pip install torch==1.9.1+rocm4.2 torchvision==0.10.1+rocm4.2 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ### Code structure
27 | * `main.py`: general setup for training and evaluating FL schemes
28 | * `models.py`: model architectures for running experiments
29 | * `sampling.py`: functions for generating non-iid datasets for federated learning
30 | * `utils.py`: functions for computing accuracy, knowledge distillation and model aggregation
31 | * `mem_utils.py`: library for monitoring memory usage and training time
32 | * `option.py`: define hyper-parameters
33 | * `Server/*.py`: object definitions for the server in different schemes
34 | * `Client/*.py`: object definitions for the client in different schemes
35 |
36 | ### Parameters
37 | * --dataset: 'CIFAR10', 'CIFAR100', 'SVHN'
38 | * --batch_size: batch size, 64 by default
39 | * --num_epochs: number of global rounds, 50 by default
40 | * --lr: learning rate, 0.001 by default
41 | * --lr_sh_rate: period of learning rate decay, 10 by default
42 | * --log_frq: frequency of logging, 5 by default
43 | * --clip_grad: maximum norm for gradient clipping, None by default
44 | * --num_clients: number of clients, 10 by default
45 | * --sampling_rate: proportion of clients sending updates per round, 1.0 by default
46 | * --local_ep: number of local epochs, 5 by default
47 | * --beta: concentration parameter of the Dirichlet distribution, 0.5 by default
48 | * --seed: random seed (for reproducing experiments), 0 by default
49 | * --std: standard deviation of the differentially private noise, 2.0 by default
50 | * --code_len: dimension of the latent vector, 32 by default
51 | * --alg: 'FedAvg', 'FedProx', 'Moon', 'FedMD', 'FedProto', 'FedHKD'
52 | * --eval_only: only output the testing accuracy
53 | * --part: percentage of each local dataset used, 0.1 by default
54 | * --temp: temperature for soft predictions, 0.5 by default
55 | * --lam: weight for loss2, 0.05 by default
56 | * --gamma: weight for loss3, 0.05 by default
57 | * --model: 'CNN', 'resnet18', 'shufflenet'
58 | * --save_model: save checkpoints of the model
59 |
60 | ### Running the code for training and evaluation
61 | ```
62 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset 'SVHN' --batch_size 64 --num_epochs 50 --clip_grad 1.1 --lr 0.001 --num_clients 10 --num_classes 10 --sampling_rate 1 --local_ep 3 --beta 0.5 --seed 0 --code_len 50 --alg 'FedAvg' --part 0.1 --model 'resnet18' --temp 0.5
63 | ```
64 |
65 | ### Acknowledgement
66 | This work was completed during an internship at the Toyota AI/ML Infrastructure & Data Lab.
67 |
68 | ### Citation
69 | Please cite our paper if you find it useful:
70 | ```
71 | @article{chen2023best,
72 | title={The Best of Both Worlds: Accurate Global and Personalized Models through Federated Learning with Data-Free Hyper-Knowledge Distillation},
73 | author={Chen, Huancheng and Vikalo, Haris and others},
74 | journal={arXiv preprint arXiv:2301.08968},
75 | year={2023}
76 | }
77 | ```
78 |
79 |
--------------------------------------------------------------------------------
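
The README's example command runs FedAvg; for FedHKD itself the distillation weights and temperature are set through the documented `--lam`, `--gamma` and `--temp` flags, for example:

```
CUDA_VISIBLE_DEVICES=0 python main.py --dataset 'CIFAR10' --batch_size 64 --num_epochs 50 --lr 0.001 --num_clients 10 --num_classes 10 --sampling_rate 1 --local_ep 3 --beta 0.5 --seed 0 --code_len 50 --alg 'FedHKD' --part 0.1 --model 'resnet18' --temp 0.5 --lam 0.05 --gamma 0.05
```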
/Server/ServerFedProto.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 | from Server.ServerBase import Server
7 | from Client.ClientFedProto import ClientFedProto
8 | from tqdm import tqdm
9 | import numpy as np
10 | from utils import average_weights
11 | from mem_utils import MemReporter
12 | import time
13 | from sampling import LocalDataset, LocalDataloaders, partition_data
14 | import gc
15 |
16 | class ServerFedProto(Server):
17 | def __init__(self, args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device):
18 | super().__init__(args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device)
19 |
20 |
21 |     def Create_Clients(self):
22 | for idx in range(self.args.num_clients):
23 | self.LocalModels.append(ClientFedProto(self.args, copy.deepcopy(self.global_model),self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length = self.args.code_len, num_classes = self.args.num_classes, device=self.device))
24 |
25 | def global_knowledge_aggregation(self, features):
26 | global_local_features = dict()
27 |         for label, feats in features.items():
28 |             if len(feats) > 1:
29 |                 feature = 0 * feats[0].data
30 |                 for i in feats:
31 |                     feature += i.data
32 |                 global_local_features[label] = [feature / len(feats)]
33 |             else:
34 |                 global_local_features[label] = [feats[0].data]
35 |
36 |
37 | return global_local_features
38 |
39 | def train(self):
40 | global_features = {}
41 | reporter = MemReporter()
42 | start_time = time.time()
43 | train_loss = []
44 | global_weights = self.global_model.state_dict()
45 | for epoch in tqdm(range(self.args.num_epochs)):
46 | Knowledges = []
47 | test_accuracy = 0
48 | local_weights, local_losses = [], []
49 | print(f'\n | Global Training Round : {epoch+1} |\n')
50 | m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
51 | idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
52 | for idx in idxs_users:
53 | if self.args.upload_model == True:
54 | self.LocalModels[idx].load_model(global_weights)
55 | if epoch < 1:
56 | w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
57 | local_losses.append(copy.deepcopy(loss))
58 | local_weights.append(copy.deepcopy(w))
59 | acc = self.LocalModels[idx].test_accuracy()
60 | test_accuracy += acc
61 |
62 | else:
63 | w, loss = self.LocalModels[idx].update_weights_Proto(global_round=epoch, global_features=global_features, gamma = self.args.gamma)
64 | local_losses.append(copy.deepcopy(loss))
65 | local_weights.append(copy.deepcopy(w))
66 | acc = self.LocalModels[idx].test_accuracy()
67 | test_accuracy += acc
68 |
69 | local_features = self.LocalModels[idx].generate_knowledge()
70 | global_features.update(local_features)
71 | del local_features
72 | gc.collect()
73 |
74 |
75 | # update global weights
76 | global_weights = average_weights(local_weights)
77 | self.global_model.load_state_dict(global_weights)
78 | loss_avg = sum(local_losses) / len(local_losses)
79 | train_loss.append(loss_avg)
80 | print("average loss: ", loss_avg)
81 |             print('average local test accuracy:', test_accuracy / len(idxs_users))
82 | print('global test accuracy: ', self.global_test_accuracy())
83 |
84 | print('Training is completed.')
85 | end_time = time.time()
86 | print('running time: {} s '.format(end_time - start_time))
87 | reporter.report()
88 |
--------------------------------------------------------------------------------
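
`global_knowledge_aggregation` averages, label by label, the feature prototypes the clients report; each dictionary value is a list of tensors, and the result keeps the list-of-one convention. The same operation written compactly (a refactoring sketch, behaviorally equivalent):

```
import torch

def aggregate_prototypes(features_by_label):
    # features_by_label: {label: [tensor, ...]} collected from the clients
    return {label: [torch.stack([f.data for f in feats]).mean(dim=0)]
            for label, feats in features_by_label.items()}
```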
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os,sys,os.path
4 | from tensorboardX import SummaryWriter
5 | import pickle
6 | from torch import nn
7 | import hashlib
8 | import argparse
9 |
10 | from models import CNNFemnist,ResNet18,ShuffLeNet
11 | from sampling import LocalDataset, LocalDataloaders, partition_data
12 | from option import args_parser
13 |
14 | # import different schemes
15 | from Server.ServerFedAvg import ServerFedAvg
16 | from Server.ServerFedProx import ServerFedProx
17 | from Server.ServerFedMD import ServerFedMD
18 | from Server.ServerFedProto import ServerFedProto
19 | from Server.ServerFedHKD import ServerFedHKD
20 |
21 | print(torch.__version__)
22 | print(torch.cuda.is_available())
23 | np.set_printoptions(threshold=np.inf)
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | print(device.type)
26 |
27 | args = args_parser()
28 | print(args)
29 |
30 | # obtain hash value for saving checkpoints
31 | args_hash = ''
32 | for k,v in vars(args).items():
33 | if k == 'eval_only':
34 | continue
35 | args_hash += str(k)+str(v)
36 |
37 | args_hash = hashlib.sha256(args_hash.encode()).hexdigest()
38 |
39 |
40 |
41 |
42 | # Generate data partitions in FL
43 | train_dataset,testset, dict_users, dict_users_test = partition_data(n_users = args.num_clients, alpha=args.beta,rand_seed = args.seed, dataset=str(args.dataset))
44 |
45 |
46 |
47 | # Load local training datasets and testsets for each client
48 | Loaders_train = LocalDataloaders(train_dataset,dict_users,args.batch_size,ShuffleorNot = True,frac=args.part)
49 | Loaders_test = LocalDataloaders(testset,dict_users_test,args.batch_size,ShuffleorNot = True,frac=2*args.part)
50 | global_loader_test = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,shuffle=True, num_workers=2)
51 |
52 | for idx in range(args.num_clients):
53 | counts = [0]*args.num_classes
54 | for batch_idx,(X,y) in enumerate(Loaders_train[idx]):
55 | batch = len(y)
56 | y = np.array(y)
57 | for i in range(batch):
58 | counts[int(y[i])] += 1
59 | # print out data distribution of each client
60 | print('Client {} data distribution:'.format(idx))
61 | print(counts)
62 |
63 |
64 |
65 |
66 |
67 | logger = SummaryWriter('./logs')
68 | checkpoint_dir = './checkpoint/'+ args.dataset + '/'
69 | if not os.path.exists(checkpoint_dir):
70 | os.makedirs(checkpoint_dir)
71 | with open(checkpoint_dir+'args.pkl', 'wb') as fp:
72 | pickle.dump(args, fp)
73 | print('Checkpoint dir:', checkpoint_dir)
74 |
75 |
76 |
77 |
78 | print(args.model)
79 | if args.model == 'CNN':
80 | # for EMNIST 62 classes
81 | global_model = CNNFemnist(args, code_length=args.code_len, num_classes = args.num_classes)
82 |
83 | if args.model == 'resnet18':
84 | global_model = ResNet18(args, code_length=args.code_len, num_classes = args.num_classes)
85 |
86 | if args.model == 'shufflenet':
87 | global_model = ShuffLeNet(args, code_length=args.code_len, num_classes = args.num_classes)
88 |
89 |
90 | print('# model parameters:', sum(param.numel() for param in global_model.parameters()))
91 | # global_model = nn.DataParallel(global_model)
92 | global_model.to(device)
93 |
94 |
95 |
96 |
97 | if args.alg == 'FedAvg':
98 | server = ServerFedAvg(args,global_model,Loaders_train,Loaders_test,global_loader_test,logger,device)
99 | if args.alg == 'FedProx':
100 | server = ServerFedProx(args,global_model,Loaders_train,Loaders_test,global_loader_test,logger,device)
101 | if args.alg == 'FedMD':
102 | server = ServerFedMD(args,global_model,Loaders_train,Loaders_test,global_loader_test,testset,logger,device)
103 | if args.alg == 'FedProto':
104 | server = ServerFedProto(args,global_model,Loaders_train,Loaders_test,global_loader_test,logger,device)
105 | if args.alg == 'FedHKD':
106 | server = ServerFedHKD(args,global_model,Loaders_train,Loaders_test,global_loader_test,logger,device)
107 |
108 |
109 | server.Create_Clients()
110 | server.train()
111 |
112 | save_path = checkpoint_dir + args_hash + '.pth'
113 | if args.save_model == True:
114 | server.Save_CheckPoint(save_path)
115 |     print('Model is saved to:')
116 | print(save_path)
117 |
118 |
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
/Server/ServerFedMD.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 | from Server.ServerBase import Server
7 | from Client.ClientFedMD import ClientFedMD
8 | from tqdm import tqdm
9 | import numpy as np
10 | from utils import average_weights
11 | from mem_utils import MemReporter
12 | import time
13 | from sampling import LocalDataset, LocalDataloaders, partition_data
14 | import gc
15 |
16 | class ServerFedMD(Server):
17 | def __init__(self, args, global_model,Loader_train,Loaders_local_test,Loader_global_test, pub_test,logger,device):
18 | super().__init__(args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device)
19 | dict_pub = [np.random.randint(low=0,high=10000,size = 1000)]
20 | self.public_data = LocalDataloaders(pub_test,dict_pub,args.batch_size,ShuffleorNot = False,frac=1)[0]
21 |
22 |     def Create_Clients(self):
23 |
24 |
25 | for idx in range(self.args.num_clients):
26 | self.LocalModels.append(ClientFedMD(self.args, copy.deepcopy(self.global_model),self.Loaders_train[idx], self.Loaders_local_test[idx], loader_pub = self.public_data, idx=idx, logger=self.logger, code_length = self.args.code_len, num_classes = self.args.num_classes, device=self.device))
27 |
28 |
29 | def train(self):
30 | reporter = MemReporter()
31 | start_time = time.time()
32 | train_loss = []
33 | global_weights = self.global_model.state_dict()
34 | for epoch in tqdm(range(self.args.num_epochs)):
35 | Knowledges = []
36 | test_accuracy = 0
37 | local_weights, local_losses = [], []
38 | print(f'\n | Global Training Round : {epoch+1} |\n')
39 | m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
40 | idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
41 | for idx in idxs_users:
42 | if self.args.upload_model == True:
43 | self.LocalModels[idx].load_model(global_weights)
44 | if epoch < 1:
45 | w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
46 | local_losses.append(copy.deepcopy(loss))
47 | local_weights.append(copy.deepcopy(w))
48 | acc = self.LocalModels[idx].test_accuracy()
49 | test_accuracy += acc
50 |
51 | else:
52 |                     w, loss = self.LocalModels[idx].update_weights_MD(global_round=epoch, knowledges = global_soft_prediction, lam = 0.1, temp = self.args.temp)
53 | local_losses.append(copy.deepcopy(loss))
54 | local_weights.append(copy.deepcopy(w))
55 | acc = self.LocalModels[idx].test_accuracy()
56 | test_accuracy += acc
57 |
58 | knowledges = self.LocalModels[idx].generate_knowledge(temp=self.args.temp)
59 | Knowledges.append(torch.stack(knowledges))
60 |             global_soft_prediction = []
61 | batch_pub = Knowledges[0].shape[0]
62 | for i in range(batch_pub):
63 | num = Knowledges[0].shape[1]
64 | soft_label = torch.zeros(num,self.args.num_classes)
65 |                 for k in Knowledges:   # Knowledges is ordered by sampling, not by client id
66 |                     soft_label += k[i]
67 |                 soft_label = soft_label / len(Knowledges)
68 |                 global_soft_prediction.append(soft_label)
69 | del Knowledges
70 | gc.collect()
71 |
72 | # update global weights
73 | global_weights = average_weights(local_weights)
74 | self.global_model.load_state_dict(global_weights)
75 | loss_avg = sum(local_losses) / len(local_losses)
76 | train_loss.append(loss_avg)
77 | print("average loss: ", loss_avg)
78 |             print('average local test accuracy:', test_accuracy / len(idxs_users))
79 | print('global test accuracy: ', self.global_test_accuracy())
80 |
81 | print('Training is completed.')
82 | end_time = time.time()
83 | print('running time: {} s '.format(end_time - start_time))
84 | reporter.report()
85 |
--------------------------------------------------------------------------------
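
The consensus knowledge in `train` is the element-wise mean over clients of their soft predictions on the shared public batches. Assuming every client scores the same public batches (the fixed public loader with `drop_last=True` keeps the batch count identical), the nested loops collapse to one stacked mean:

```
import torch

def consensus_soft_labels(knowledges):
    # knowledges: one tensor per client, each of shape
    # (num_public_batches, batch_size, num_classes)
    return torch.stack(knowledges).mean(dim=0)   # average over clients
```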
/Server/ServerFedHKD.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import torch
4 | import copy
5 | from utils import Accuracy
6 | from Server.ServerBase import Server
7 | from Client.ClientFedHKD import ClientFedHKD
8 | from tqdm import tqdm
9 | import numpy as np
10 | from utils import average_weights
11 | from mem_utils import MemReporter
12 | import time
13 | from sampling import LocalDataset, LocalDataloaders, partition_data
14 | import gc
15 |
16 | class ServerFedHKD(Server):
17 | def __init__(self, args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device):
18 | super().__init__(args, global_model,Loader_train,Loaders_local_test,Loader_global_test,logger,device)
19 |
20 |
21 |     def Create_Clients(self):
22 | for idx in range(self.args.num_clients):
23 | self.LocalModels.append(ClientFedHKD(self.args, copy.deepcopy(self.global_model),self.Loaders_train[idx], self.Loaders_local_test[idx], idx=idx, logger=self.logger, code_length = self.args.code_len, num_classes = self.args.num_classes, device=self.device))
24 |
25 | def global_knowledge_aggregation(self, features,soft_prediction):
26 | global_local_features = dict()
27 | global_local_soft_prediction = dict()
28 |         for label, feats in features.items():
29 |             if len(feats) > 1:
30 |                 feature = 0 * feats[0].data
31 |                 for i in feats:
32 |                     feature += i.data
33 |                 global_local_features[label] = [feature / len(feats)]
34 |             else:
35 |                 global_local_features[label] = [feats[0].data]
36 |
37 |         for label, softs in soft_prediction.items():
38 |             if len(softs) > 1:
39 |                 soft = 0 * softs[0].data
40 |                 for i in softs:
41 |                     soft += i.data
42 |                 global_local_soft_prediction[label] = [soft / len(softs)]
43 |             else:
44 |                 global_local_soft_prediction[label] = [softs[0].data]
45 |
46 | return global_local_features,global_local_soft_prediction
47 |
48 | def train(self):
49 | global_features = {}
50 | global_soft_prediction = {}
51 | reporter = MemReporter()
52 | start_time = time.time()
53 | train_loss = []
54 | global_weights = self.global_model.state_dict()
55 | for epoch in tqdm(range(self.args.num_epochs)):
56 | Knowledges = []
57 | test_accuracy = 0
58 | local_weights, local_losses = [], []
59 | print(f'\n | Global Training Round : {epoch+1} |\n')
60 | m = max(int(self.args.sampling_rate * self.args.num_clients), 1)
61 | idxs_users = np.random.choice(range(self.args.num_clients), m, replace=False)
62 | for idx in idxs_users:
63 | if self.args.upload_model == True:
64 | self.LocalModels[idx].load_model(global_weights)
65 | if epoch < 1:
66 | w, loss = self.LocalModels[idx].update_weights(global_round=epoch)
67 | local_losses.append(copy.deepcopy(loss))
68 | local_weights.append(copy.deepcopy(w))
69 | acc = self.LocalModels[idx].test_accuracy()
70 | test_accuracy += acc
71 |
72 | else:
73 | w, loss = self.LocalModels[idx].update_weights_HKD(global_round=epoch, global_features=global_features, global_soft_prediction=global_soft_prediction, lam = self.args.lam, gamma = self.args.gamma, temp = self.args.temp)
74 | local_losses.append(copy.deepcopy(loss))
75 | local_weights.append(copy.deepcopy(w))
76 | acc = self.LocalModels[idx].test_accuracy()
77 | test_accuracy += acc
78 |
79 | local_features,local_soft_predictions = self.LocalModels[idx].generate_knowledge(temp = self.args.temp)
80 | global_features.update(local_features)
81 | global_soft_prediction.update(local_soft_predictions)
82 | del local_features
83 | del local_soft_predictions
84 | gc.collect()
85 |
86 |
87 | # update global weights
88 | global_weights = average_weights(local_weights)
89 | self.global_model.load_state_dict(global_weights)
90 |
91 | loss_avg = sum(local_losses) / len(local_losses)
92 | train_loss.append(loss_avg)
93 | print("average loss: ", loss_avg)
94 |             print('average local test accuracy:', test_accuracy / len(idxs_users))
95 | print('global test accuracy: ', self.global_test_accuracy())
96 |
97 | print('Training is completed.')
98 | end_time = time.time()
99 | print('running time: {} s '.format(end_time - start_time))
100 | reporter.report()
101 |
--------------------------------------------------------------------------------
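
FedHKD's hyper-knowledge has two halves, per-class feature prototypes and per-class soft predictions, and `global_knowledge_aggregation` applies the identical per-class mean to both. Factored through one helper (a sketch of the shared structure, not the original layout):

```
import torch

def per_class_mean(knowledge):
    # knowledge: {label: [tensor, ...]}; returns {label: [mean tensor]}
    return {label: [torch.stack([t.data for t in ts]).mean(dim=0)]
            for label, ts in knowledge.items()}

def aggregate_hyper_knowledge(features, soft_predictions):
    return per_class_mean(features), per_class_mean(soft_predictions)
```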
/Client/ClientFedMD.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import scipy
5 | from torch.utils.data import Dataset
6 | import torch
7 | import copy
8 | import torch.nn as nn
9 | from sklearn.cluster import KMeans
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from utils import Accuracy,soft_predict
13 | from Client.ClientBase import Client
14 | import gc
15 | class ClientFedMD(Client):
16 |     """
17 |     FedMD client: distills knowledge from consensus soft predictions on a shared public dataset.
18 |     args: command-line arguments
19 |     Loader_train, loader_test: dataloaders for local training and inference
20 |     loader_pub: dataloader over the shared public dataset
21 |     idx: index of this client
22 |     logger: logs the loss during training
23 |     """
24 | def __init__(self, args, model, Loader_train,loader_test, loader_pub,idx, logger, code_length, num_classes, device):
25 | super().__init__(args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device)
26 | self.loader_pub = loader_pub
27 |
28 | def update_weights(self,global_round):
29 | self.model.to(self.device)
30 | self.model.train()
31 | epoch_loss = []
32 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
33 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
34 | for iter in range(self.args.local_ep):
35 | batch_loss = []
36 | for batch_idx, (X, y) in enumerate(self.trainloader):
37 | X = X.to(self.device)
38 | y = y.to(self.device)
39 | optimizer.zero_grad()
40 | _,p = self.model(X)
41 | loss = self.ce(p,y)
42 | loss.backward()
43 |                 if self.args.clip_grad is not None:
44 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.args.clip_grad)
45 | optimizer.step()
46 | if batch_idx % 10 == 0:
47 | print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
48 | global_round, self.idx, iter, batch_idx * len(X),
49 | len(self.trainloader.dataset),
50 | 100. * batch_idx / len(self.trainloader), loss.item()))
51 | self.logger.add_scalar('loss', loss.item())
52 | batch_loss.append(loss.item())
53 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
54 |
55 | return self.model.state_dict(),sum(epoch_loss) / len(epoch_loss)
56 |
57 | def update_weights_MD(self,knowledges, lam, temp, global_round):
58 | self.model.to(self.device)
59 | self.model.train()
60 | epoch_loss = []
61 | global_soft_prediction = torch.stack(knowledges)
62 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
63 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
64 | for iter in range(self.args.local_ep):
65 | batch_loss = []
66 | for batch_idx, (X, y) in enumerate(self.trainloader):
67 | X = X.to(self.device)
68 | y = y.to(self.device)
69 | optimizer.zero_grad()
70 | _,Z = self.model(X)
71 | loss1 = self.ce(Z,y)
72 | loss2 = torch.tensor(0.0).to(self.device)
73 | for idx, (X_pub,y_pub) in enumerate(self.loader_pub):
74 | if idx == batch_idx:
75 | X_pub = X_pub.to(self.device)
76 | y_pub = y_pub.to(self.device)
77 | _,Z_pub = self.model(X_pub)
78 | Q_pub = soft_predict(Z_pub,temp).to(self.device)
79 |                         loss2 += self.kld(Q_pub.log(), global_soft_prediction[idx].to(self.device))  # KLDivLoss expects log-probabilities as input
80 |
81 | loss = loss1 + lam*loss2
82 | loss.backward()
83 | optimizer.step()
84 | if batch_idx % 10 == 0:
85 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss1: {:.6f} Loss2: {:.6f} '.format(
86 | global_round, iter, batch_idx * len(X),
87 | len(self.trainloader.dataset),
88 | 100. * batch_idx / len(self.trainloader), loss1.item(),loss2.item()))
89 | self.logger.add_scalar('loss', loss.item())
90 | batch_loss.append(loss.item())
91 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
92 |
93 | return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
94 |
95 | def generate_knowledge(self, temp):
96 | self.model.to(self.device)
97 | self.model.eval()
98 | num_classes = self.model.num_classes
99 | soft_predictions = []
100 | for batch_idx, (X, y) in enumerate(self.loader_pub):
101 | X = X.to(self.device)
102 |             # labels y are not needed for distillation on the public data
103 |             _,Z = self.model(X)
104 |             Q = soft_predict(Z,temp).detach().cpu()
105 | soft_predictions.append(Q)
106 | del X
107 | del y
108 | del Z
109 | del Q
110 | gc.collect()
111 |
112 | return soft_predictions
--------------------------------------------------------------------------------
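
`nn.KLDivLoss` expects log-probabilities as its input and probabilities as its target, which is why the distillation term above logs the local soft predictions before comparing them to the consensus. A minimal demonstration:

```
import torch
import torch.nn as nn

kld = nn.KLDivLoss(reduction='batchmean')    # 'batchmean' matches the KL definition
q_local = torch.softmax(torch.randn(4, 10) / 0.5, dim=1)    # local soft predictions
q_global = torch.softmax(torch.randn(4, 10) / 0.5, dim=1)   # consensus soft labels
loss = kld(q_local.log(), q_global)          # input must be log-probabilities
print(loss.item() >= 0)                      # a true KL divergence is non-negative
```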
/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import scipy
4 | from torch.utils.data import Dataset
5 | import torch
6 | import copy
7 | from torchvision import datasets, transforms
8 |
9 | class LocalDataset(Dataset):
10 | """
11 |     torch.utils.data.DataLoader iterates a Dataset by index via __getitem__();
12 |     this class maps a client's local indices onto the full dataset
13 | """
14 | def __init__(self, dataset, Dict):
15 | self.dataset = dataset
16 | self.idxs = [int(i) for i in Dict]
17 |
18 | def __len__(self):
19 | return len(self.idxs)
20 |
21 | def __getitem__(self, item):
22 | X, y = self.dataset[self.idxs[item]]
23 | return X, y
24 |
25 | def LocalDataloaders(dataset, dict_users, batch_size, ShuffleorNot = True, BatchorNot = True, frac = 1):
26 |     """
27 |     dataset: the shared dataset object
28 |     dict_users: dictionary of data indices for each local model
29 |     batch_size: batch size; ShuffleorNot: whether to shuffle
30 |     BatchorNot: if False, the loader yields the full data at once instead of batches (for testing)
31 |     frac: fraction of each client's indices to actually use
32 |     """
33 | num_users = len(dict_users)
34 | loaders = []
35 | for i in range(num_users):
36 | num_data = len(dict_users[i])
37 | frac_num_data = int(frac*num_data)
38 | whole_range = range(num_data)
39 |         frac_range = np.random.choice(whole_range, frac_num_data, replace=False)
40 |         frac_dict_users = [dict_users[i][j] for j in frac_range]
41 |         if BatchorNot:
42 | loader = torch.utils.data.DataLoader(
43 | LocalDataset(dataset,frac_dict_users),
44 | batch_size=batch_size,
45 | shuffle = ShuffleorNot,
46 | num_workers=0,
47 | drop_last=True)
48 | else:
49 | loader = torch.utils.data.DataLoader(
50 | LocalDataset(dataset,frac_dict_users),
51 |                 batch_size=len(frac_dict_users),
52 | shuffle = ShuffleorNot,
53 | num_workers=0,
54 | drop_last=True)
55 | loaders.append(loader)
56 | return loaders
57 |
58 |
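A usage sketch, assuming `dict_users` and `dict_users_test` come from `partition_data` below (the names are illustrative): shuffled mini-batch loaders for training, full-length batches for evaluation:

    train_loaders = LocalDataloaders(train_dataset, dict_users, batch_size=64,
                                     ShuffleorNot=True, BatchorNot=True)
    test_loaders = LocalDataloaders(test_dataset, dict_users_test, batch_size=64,
                                    ShuffleorNot=False, BatchorNot=False)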
59 | def partition_data(n_users, alpha=0.5, rand_seed=0, dataset='CIFAR10'):
60 | if dataset == 'CIFAR10':
61 | K = 10
62 | data_dir = '../data/cifar10/'
63 | apply_transform = transforms.Compose(
64 | [transforms.ToTensor(),
65 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
66 | train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
67 | transform=apply_transform)
68 | test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
69 | transform=apply_transform)
70 | y_train = np.array(train_dataset.targets)
71 | y_test = np.array(test_dataset.targets)
72 |
73 | if dataset == 'CIFAR100':
74 | K = 100
75 | data_dir = '../data/cifar100/'
76 | apply_transform = transforms.Compose(
77 | [transforms.ToTensor(),
78 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
79 | train_dataset = datasets.CIFAR100(data_dir, train=True, download=True,
80 | transform=apply_transform)
81 | test_dataset = datasets.CIFAR100(data_dir, train=False, download=True,
82 | transform=apply_transform)
83 | y_train = np.array(train_dataset.targets)
84 | y_test = np.array(test_dataset.targets)
85 |
86 | if dataset == 'EMNIST':
87 | K = 62
88 | data_dir = '../data/EMNIST/'
89 | apply_transform = transforms.Compose(
90 | [transforms.ToTensor(),
91 |              transforms.Normalize((0.5,), (0.5,))])
92 | train_dataset = datasets.EMNIST(data_dir, train=True, split = 'byclass', download=True,
93 | transform=apply_transform)
94 | test_dataset = datasets.EMNIST(data_dir, train=False, split = 'byclass', download=True,
95 | transform=apply_transform)
96 | y_train = np.array(train_dataset.targets)
97 | y_test = np.array(test_dataset.targets)
98 | if dataset == 'SVHN':
99 | K = 10
100 | data_dir = '../data/SVHN/'
101 | apply_transform = transforms.Compose(
102 | [transforms.ToTensor(),
103 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
104 | train_dataset = datasets.SVHN(data_dir, split='train', download=True,
105 | transform=apply_transform)
106 | test_dataset = datasets.SVHN(data_dir, split='test', download=True,
107 | transform=apply_transform)
108 | y_train = np.array(train_dataset.labels)
109 | y_test = np.array(test_dataset.labels)
110 |
111 | min_size = 0
112 | N = len(train_dataset)
113 | N_test = len(test_dataset)
114 | net_dataidx_map = {}
115 | net_dataidx_map_test = {}
116 | np.random.seed(rand_seed)
117 |
118 | while min_size < 10:
119 | idx_batch = [[] for _ in range(n_users)]
120 | idx_batch_test = [[] for _ in range(n_users)]
121 | for k in range(K):
122 | idx_k = np.where(y_train == k)[0]
123 | idx_k_test = np.where(y_test == k)[0]
124 | np.random.shuffle(idx_k)
125 | proportions = np.random.dirichlet(np.repeat(alpha, n_users))
126 | ## Balance
127 |                 proportions_train = np.array([p*(len(idx_j) < N/n_users) for p,idx_j in zip(proportions,idx_batch)])
[original lines 128-140 are missing from this dump; the lines numbered 141-155 below belong to a client class, not sampling.py; see the reconstruction sketch at the end of this section]
141 | feature = 0 * features[0].data
142 | for i in features:
143 | feature += i.data
144 | agg_local_features[label] = [feature / len(features)]
145 | else:
146 | agg_local_features[label] = [features[0].data]
147 |
148 | return agg_local_features
149 |
150 | def dict_to_tensor(self, dic):
151 | lit = []
152 | for key,tensor in dic.items():
153 | lit.append(tensor[0])
154 | lit = torch.stack(lit)
155 | return lit
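`partition_data` is cut off above and the lines numbered 141-155 belong to a client class, so the Dirichlet split never completes in this excerpt. For reference, a minimal sketch of the standard Dirichlet label-skew recipe that the surviving lines (the `while min_size < 10` loop, per-class shuffling, and the `## Balance` step) follow; the variable names mirror the code above, but the body is an editor's reconstruction, and the test split would be built the same way from `y_test` and `idx_batch_test`:

    while min_size < 10:
        idx_batch = [[] for _ in range(n_users)]
        for k in range(K):
            idx_k = np.where(y_train == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, n_users))
            # Balance: zero out clients that already hold >= N/n_users samples
            proportions = np.array([p * (len(idx_j) < N / n_users)
                                    for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist()
                         for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        min_size = min(len(idx_j) for idx_j in idx_batch)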
--------------------------------------------------------------------------------
/Client/ClientFedHKD.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import scipy
5 | from torch.utils.data import Dataset
6 | import gc
7 | import copy
8 | import torch.nn as nn
9 | from sklearn.cluster import KMeans
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | from utils import Accuracy,soft_predict
13 | from Client.ClientBase import Client
14 | 
15 | class ClientFedHKD(Client):
16 | """
17 |     Trains the local model starting from a copy of the global model and returns the updated weights.
18 |     args: command-line arguments
19 |     Loader_train, loader_test: dataloaders for training and inference
20 |     idx: the index of this client
21 |     code_length, num_classes: feature dimension and number of classes
22 |     logger: logs the loss and the training process
23 | """
24 | def __init__(self, args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device):
25 | super().__init__(args, model, Loader_train,loader_test,idx, logger, code_length, num_classes, device)
26 |
27 |
28 | def update_weights(self,global_round):
29 | self.model.to(self.device)
30 | self.model.train()
31 | epoch_loss = []
32 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
33 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
34 | for iter in range(self.args.local_ep):
35 | batch_loss = []
36 | for batch_idx, (X, y) in enumerate(self.trainloader):
37 | X = X.to(self.device)
38 | y = y.to(self.device)
39 | optimizer.zero_grad()
40 | _,p = self.model(X)
41 | loss = self.ce(p,y)
42 | loss.backward()
43 |                 if self.args.clip_grad is not None:
44 | nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.args.clip_grad)
45 | optimizer.step()
46 | if batch_idx % 10 == 0:
47 | print('| Global Round : {} | Client: {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
48 | global_round, self.idx, iter, batch_idx * len(X),
49 | len(self.trainloader.dataset),
50 | 100. * batch_idx / len(self.trainloader), loss.item()))
51 | self.logger.add_scalar('loss', loss.item())
52 | batch_loss.append(loss.item())
53 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
54 |             scheduler.step()
55 | return self.model.state_dict(),sum(epoch_loss) / len(epoch_loss)
56 |
57 |
58 | def update_weights_HKD(self,global_features, global_soft_prediction, lam, gamma, temp, global_round):
59 | self.model.to(self.device)
60 | self.model.train()
61 | epoch_loss = []
62 | optimizer = optim.Adam(self.model.parameters(),lr=self.args.lr)
63 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.args.lr_sh_rate, gamma=0.5)
64 |         # global hyper-knowledge may be absent in the first round; dict_to_tensor cannot stack an empty dict
65 |         if len(global_features) > 0:
66 |             tensor_global_features = self.dict_to_tensor(global_features).to(self.device)
67 |             tensor_global_soft_prediction = self.dict_to_tensor(global_soft_prediction).to(self.device)
68 |         for iter in range(self.args.local_ep):
69 |             batch_loss = []
70 |             for batch_idx, (X, y) in enumerate(self.trainloader):
71 |                 X = X.to(self.device)
72 |                 y = y.to(self.device)
73 |                 optimizer.zero_grad()
74 |                 F,Z = self.model(X)
75 |                 loss1 = self.ce(Z,y)
76 |                 if len(global_features) == 0:
77 |                     loss2 = 0*loss1
78 |                     loss3 = 0*loss1
79 |                 else:
80 |                     Z_help = self.model.classifier(tensor_global_features)
81 |                     Q_help = soft_predict(Z_help,temp).to(self.device)
82 |                     # regress each sample's feature toward its global class prototype when available
83 |                     target_features = F.data.clone()
84 |                     for i in range(y.shape[0]):
85 |                         if int(y[i]) in global_features:
86 |                             target_features[i] = global_features[int(y[i])][0].data
87 |                     target_features = target_features.to(self.device)
88 |                     loss2 = self.kld(Q_help.log(),tensor_global_soft_prediction)
89 |                     loss3 = self.mse(F,target_features)
90 |                 loss = loss1 + lam*loss2 + gamma*loss3
91 | 
92 | loss.backward()
93 |                 if self.args.clip_grad is not None:
94 |                     nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.args.clip_grad)
95 | 
96 | optimizer.step()
97 | if batch_idx % 10 == 0:
98 | print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss1: {:.6f} Loss2: {:.6f} Loss3: {:.6f} '.format(
99 | global_round, iter, batch_idx * len(X),
100 | len(self.trainloader.dataset),
101 | 100. * batch_idx / len(self.trainloader), loss1.item(),loss2.item(),loss3.item()))
102 | self.logger.add_scalar('loss', loss.item())
103 | batch_loss.append(loss.item())
104 | epoch_loss.append(sum(batch_loss)/len(batch_loss))
105 |             scheduler.step()
106 | return self.model.state_dict(), sum(epoch_loss) / len(epoch_loss)
107 |
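Note the argument order in `self.kld(Q_help.log(), tensor_global_soft_prediction)` above: `nn.KLDivLoss` expects log-probabilities as its input and plain probabilities as its target. A self-contained check of that convention:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    kld = nn.KLDivLoss(reduction='batchmean')
    target = F.softmax(torch.randn(4, 10), dim=1)   # probabilities
    pred = F.softmax(torch.randn(4, 10), dim=1)     # probabilities
    loss = kld(pred.log(), target)                  # input must be log-probs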
108 |     # generate local hyper-knowledge for FedHKD
109 | def generate_knowledge(self, temp):
110 | self.model.to(self.device)
111 | self.model.eval()
112 | local_features = {}
113 | local_soft_prediction = {}
114 |         # features and soft predictions are grouped per class below, then averaged
115 | 
116 | 
117 | 
118 | for batch_idx, (X, y) in enumerate(self.trainloader):
119 | X = X.to(self.device)
120 |             # labels stay on the CPU; they are only used for indexing
121 | F,Z = self.model(X)
122 | Q = soft_predict(Z,temp).to(self.device)
123 |             # group features and soft predictions by class label
124 | for i in range(len(y)):
125 |                 if y[i].item() in local_features:
126 |                     local_features[y[i].item()].append(F[i,:].detach())
127 |                     local_soft_prediction[y[i].item()].append(Q[i,:].detach())
128 |                 else:
129 |                     local_features[y[i].item()] = [F[i,:].detach()]
130 |                     local_soft_prediction[y[i].item()] = [Q[i,:].detach()]
131 | del X
132 | del y
133 | del F
134 | del Z
135 | del Q
136 | gc.collect()
137 |
138 | features,soft_predictions = self.local_knowledge_aggregation(local_features,local_soft_prediction, std = self.args.std)
139 |
140 | return (features,soft_predictions)
141 |
142 | def local_knowledge_aggregation(self,local_features,local_soft_prediction, std):
143 | agg_local_features = dict()
144 | agg_local_soft_prediction = dict()
145 | feature_noise = std*torch.randn(self.args.code_len).to(self.device)
146 | for [label, features] in local_features.items():
147 | if len(features) > 1:
148 | feature = 0 * features[0].data
149 | for i in features:
150 | feature += i.data
151 | agg_local_features[label] = [feature / len(features) + feature_noise]
152 | else:
153 | agg_local_features[label] = [features[0].data + feature_noise]
154 |
155 | for [label, soft_prediction] in local_soft_prediction.items():
156 | if len(soft_prediction) > 1:
157 | soft = 0 * soft_prediction[0].data
158 | for i in soft_prediction:
159 | soft += i.data
160 |
161 | agg_local_soft_prediction[label] = [soft / len(soft_prediction) ]
162 | else:
163 | agg_local_soft_prediction[label] = [soft_prediction[0].data]
164 |
165 | return agg_local_features,agg_local_soft_prediction
166 |
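The running-sum loops in `local_knowledge_aggregation` can be written more compactly with `torch.stack(...).mean(0)`. A behavior-preserving sketch (it keeps the repository's choice of drawing one noise vector shared by all classes):

    def local_knowledge_aggregation(self, local_features, local_soft_prediction, std):
        feature_noise = std * torch.randn(self.args.code_len).to(self.device)
        agg_features = {
            label: [torch.stack([f.data for f in feats]).mean(0) + feature_noise]
            for label, feats in local_features.items()
        }
        agg_soft = {
            label: [torch.stack([q.data for q in qs]).mean(0)]
            for label, qs in local_soft_prediction.items()
        }
        return agg_features, agg_soft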
167 | def dict_to_tensor(self, dic):
168 | lit = []
169 | for key,tensor in dic.items():
170 | lit.append(tensor[0])
171 | lit = torch.stack(lit)
172 | return lit
173 |
--------------------------------------------------------------------------------
/mem_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import gc
3 | from collections import defaultdict
4 | from typing import Optional, Tuple, List
5 |
6 | import torch
7 |
8 | from math import isnan
9 | from calmsize import size as calmsize
10 |
11 | def readable_size(num_bytes: int) -> str:
12 | return '' if isnan(num_bytes) else '{:.2f}'.format(calmsize(num_bytes))
13 |
14 | LEN = 79
15 |
16 | # some pytorch low-level memory management constant
17 | # the minimal allocate memory size (Byte)
18 | PYTORCH_MIN_ALLOCATE = 2 ** 9
19 | # the minimal cache memory size (Byte)
20 | PYTORCH_MIN_CACHE = 2 ** 20
21 |
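To make the rounding in `get_tensor_stat` below concrete: allocations are counted in multiples of `PYTORCH_MIN_ALLOCATE` (512 bytes), so a small tensor is reported at the block size rather than its raw byte count. For example:

    import math

    numel, element_size = 100, 4    # a float32 tensor with 100 elements
    raw = numel * element_size      # 400 bytes of actual data
    reported = math.ceil(raw / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE
    assert reported == 512          # rounded up to one 512-byte block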
22 | class MemReporter():
23 | """A memory reporter that collects tensors and memory usages
24 |
25 | Parameters:
26 | - model: an extra nn.Module can be passed to infer the name
27 | of Tensors
28 |
29 | """
30 | def __init__(self, model: Optional[torch.nn.Module] = None):
31 | self.tensor_name = {}
32 | self.device_mapping = defaultdict(list)
33 | self.device_tensor_stat = {}
34 |         # used to number tensors whose name cannot be inferred
35 | self.name_idx = 0
36 |
37 | tensor_names = defaultdict(list)
38 | if model is not None:
39 | assert isinstance(model, torch.nn.Module)
40 | # for model with tying weight, multiple parameters may share
41 | # the same underlying tensor
42 | for name, param in model.named_parameters():
43 | tensor_names[param].append(name)
44 |
45 |         for param, names in tensor_names.items():
46 |             self.tensor_name[id(param)] = '+'.join(names)
47 |
48 | def _get_tensor_name(self, tensor: torch.Tensor) -> str:
49 | tensor_id = id(tensor)
50 | if tensor_id in self.tensor_name:
51 | name = self.tensor_name[tensor_id]
52 | # use numbering if no name can be inferred
53 | else:
54 | name = type(tensor).__name__ + str(self.name_idx)
55 | self.tensor_name[tensor_id] = name
56 | self.name_idx += 1
57 | return name
58 |
59 | def collect_tensor(self):
60 | """Collect all tensor objects tracked by python
61 |
62 | NOTICE:
63 |             - the buffers for backward, which are implemented in C++, are
64 |             not tracked by Python's reference counting.
65 |             - the gradients (.grad) of Parameters are not collected, and
66 |             I don't know why.
67 | """
68 | #FIXME: make the grad tensor collected by gc
69 | objects = gc.get_objects()
70 | tensors = [obj for obj in objects if isinstance(obj, torch.Tensor)]
71 | for t in tensors:
72 | self.device_mapping[t.device].append(t)
73 |
74 | def get_stats(self):
75 | """Get the memory stat of tensors and then release them
76 |
77 |         A memory profiler must not hold references to the tensors it inspects,
78 |         since that would itself distort the usage stats, so the tensors are
79 |         released once the required stats have been collected."""
80 | visited_data = {}
81 | self.device_tensor_stat.clear()
82 |
83 | def get_tensor_stat(tensor: torch.Tensor) -> List[Tuple[str, int, int, int]]:
84 | """Get the stat of a single tensor
85 |
86 | Returns:
87 | - stat: a tuple containing (tensor_name, tensor_size,
88 | tensor_numel, tensor_memory)
89 | """
90 | assert isinstance(tensor, torch.Tensor)
91 |
92 | name = self._get_tensor_name(tensor)
93 | if tensor.is_sparse:
94 | indices_stat = get_tensor_stat(tensor._indices())
95 | values_stat = get_tensor_stat(tensor._values())
96 | return indices_stat + values_stat
97 |
98 | numel = tensor.numel()
99 | element_size = tensor.element_size()
100 | fact_numel = tensor.storage().size()
101 | fact_memory_size = fact_numel * element_size
102 |             # since pytorch allocates at least 512 bytes for any tensor, round
103 |             # up to a multiple of 512
104 | memory_size = math.ceil(fact_memory_size / PYTORCH_MIN_ALLOCATE) \
105 | * PYTORCH_MIN_ALLOCATE
106 |
107 | # tensor.storage should be the actual object related to memory
108 | # allocation
109 | data_ptr = tensor.storage().data_ptr()
110 | if data_ptr in visited_data:
111 | name = '{}(->{})'.format(
112 | name,
113 | visited_data[data_ptr],
114 | )
115 | # don't count the memory for reusing same underlying storage
116 | memory_size = 0
117 | else:
118 | visited_data[data_ptr] = name
119 |
120 | size = tuple(tensor.size())
121 | # torch scalar has empty size
122 | if not size:
123 | size = (1,)
124 |
125 | return [(name, size, numel, memory_size)]
126 |
127 | for device, tensors in self.device_mapping.items():
128 | tensor_stats = []
129 | for tensor in tensors:
130 |
131 | if tensor.numel() == 0:
132 | continue
133 | stat = get_tensor_stat(tensor) # (name, shape, numel, memory_size)
134 | tensor_stats += stat
135 | if isinstance(tensor, torch.nn.Parameter):
136 | if tensor.grad is not None:
137 | # manually specify the name of gradient tensor
138 | self.tensor_name[id(tensor.grad)] = '{}.grad'.format(
139 | self._get_tensor_name(tensor)
140 | )
141 | stat = get_tensor_stat(tensor.grad)
142 | tensor_stats += stat
143 |
144 | self.device_tensor_stat[device] = tensor_stats
145 |
146 | self.device_mapping.clear()
147 |
148 | def print_stats(self, verbose: bool = False, target_device: Optional[torch.device] = None) -> None:
149 | # header
150 | # show_reuse = verbose
151 | # template_format = '{:<40s}{:>20s}{:>10s}'
152 | # print(template_format.format('Element type', 'Size', 'Used MEM') )
153 | for device, tensor_stats in self.device_tensor_stat.items():
154 | # By default, if the target_device is not specified,
155 | # print tensors on all devices
156 | if target_device is not None and device != target_device:
157 | continue
158 | # print('-' * LEN)
159 | print('\nStorage on {}'.format(device))
160 | total_mem = 0
161 | total_numel = 0
162 | for stat in tensor_stats:
163 | name, size, numel, mem = stat
164 | # if not show_reuse:
165 | # name = name.split('(')[0]
166 | # print(template_format.format(
167 | # str(name),
168 | # str(size),
169 | # readable_size(mem),
170 | # ))
171 | total_mem += mem
172 | total_numel += numel
173 |
174 | print('-'*LEN)
175 | print('Total Tensors: {} \tUsed Memory: {}'.format(
176 | total_numel, readable_size(total_mem),
177 | ))
178 |
179 | if device != torch.device('cpu'):
180 | with torch.cuda.device(device):
181 | memory_allocated = torch.cuda.memory_allocated()
182 | print('The allocated memory on {}: {}'.format(
183 | device, readable_size(memory_allocated),
184 | ))
185 | if memory_allocated != total_mem:
186 | print('Memory differs due to the matrix alignment or'
187 | ' invisible gradient buffer tensors')
188 | print('-'*LEN)
189 |
190 | def report(self, verbose: bool = False, device: Optional[torch.device] = None) -> None:
191 | """Interface for end-users to directly print the memory usage
192 |
193 | args:
194 | - verbose: flag to show tensor.storage reuse information
195 | - device: `torch.device` object, specify the target device
196 | to report detailed memory usage. It will print memory usage
197 | on all devices if not specified. Usually we only want to
198 | print the memory usage on CUDA devices.
199 |
200 | """
201 | self.collect_tensor()
202 | self.get_stats()
203 | self.print_stats(verbose, target_device=device)
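Typical usage on a CUDA machine (the model and sizes here are illustrative):

    model = torch.nn.Linear(1024, 1024).cuda()
    reporter = MemReporter(model)
    out = model(torch.randn(64, 1024, device='cuda')).sum()
    out.backward()
    # report only the GPU to skip CPU-side tensors
    reporter.report(device=torch.device('cuda:0'))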
--------------------------------------------------------------------------------