├── code_FedCR ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── Nets.cpython-37.pyc │ │ ├── Nets.cpython-38.pyc │ │ ├── NetsSR.cpython-37.pyc │ │ ├── Update.cpython-37.pyc │ │ ├── Nets_PAC.cpython-37.pyc │ │ ├── Nets_VIB.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── resnet18.cpython-37.pyc │ │ ├── resnet18.cpython-38.pyc │ │ ├── Nets_VIB_2D.cpython-37.pyc │ │ ├── distributed_training_utils.cpython-37.pyc │ │ ├── distributed_training_utils.cpython-38.pyc │ │ ├── distributed_training_utilsSR.cpython-37.pyc │ │ ├── distributed_training_utils_2D.cpython-37.pyc │ │ ├── distributed_training_utils_PAC.cpython-37.pyc │ │ ├── distributed_training_utils_ditto.cpython-37.pyc │ │ ├── distributed_training_utils_old2.cpython-37.pyc │ │ └── distributed_training_utils_old3.cpython-37.pyc │ ├── Update.py │ ├── Nets.py │ ├── Nets_PAC.py │ ├── resnet18.py │ ├── main_fedSR.py │ ├── Nets_VIB.py │ ├── NetsSR.py │ ├── distributed_training_utils_ditto.py │ ├── distributed_training_utilsSR.py │ ├── distributed_training_utils_PAC.py │ └── distributed_training_utils.py ├── utils │ ├── __init__.py │ ├── seed.py │ ├── logg.py │ ├── options.py │ └── utils_dataset.py ├── .idea │ ├── misc.xml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── modules.xml │ ├── code.iml │ └── workspace.xml ├── main_fed_ditto.py ├── main_fedSR.py ├── main_fed_local.py ├── main_fed_PAC.py └── main_fed.py ├── requirement.txt └── README.md /code_FedCR/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.7 4 | 5 | -------------------------------------------------------------------------------- /code_FedCR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | 
-------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | cachetools==4.1.1 2 | scikit-learn==0.23.2 3 | scipy==1.4.1 4 | torchvision==0.5.0 5 | torch==1.4.0 6 | numpy==1.18.1 -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/Nets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/Nets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets.cpython-38.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/NetsSR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/NetsSR.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/Update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Update.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/Nets_PAC.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_PAC.cpython-37.pyc -------------------------------------------------------------------------------- 
/code_FedCR/models/__pycache__/Nets_VIB.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_VIB.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/resnet18.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/resnet18.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/resnet18.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/resnet18.cpython-38.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/Nets_VIB_2D.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_VIB_2D.cpython-37.pyc -------------------------------------------------------------------------------- 
/code_FedCR/models/__pycache__/distributed_training_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils.cpython-38.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utilsSR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utilsSR.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils_2D.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_2D.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils_PAC.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_PAC.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils_ditto.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_ditto.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils_old2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_old2.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/models/__pycache__/distributed_training_utils_old3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_old3.cpython-37.pyc -------------------------------------------------------------------------------- /code_FedCR/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /code_FedCR/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /code_FedCR/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /code_FedCR/.idea/code.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /code_FedCR/utils/seed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 
import random 4 | 5 | def setup_seed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed(seed) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = True -------------------------------------------------------------------------------- /code_FedCR/utils/logg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def get_logger(filename, verbosity=1, name=None): 4 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 5 | formatter = logging.Formatter( 6 | "[%(asctime)s][%(levelname)s] %(message)s" 7 | ) 8 | logger = logging.getLogger(name) 9 | logger.setLevel(level_dict[verbosity]) 10 | 11 | fh = logging.FileHandler(filename, "w") 12 | fh.setFormatter(formatter) 13 | logger.addHandler(fh) 14 | 15 | sh = logging.StreamHandler() 16 | sh.setFormatter(formatter) 17 | logger.addHandler(sh) 18 | 19 | return logger -------------------------------------------------------------------------------- /code_FedCR/models/Update.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import copy 7 | from torch import nn, autograd 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | class DatasetSplit(Dataset): 11 | def __init__(self, dataset, idxs): 12 | self.dataset = dataset 13 | self.idxs = list(idxs) 14 | 15 | def __len__(self): 16 | return len(self.idxs) 17 | 18 | def __getitem__(self, item): 19 | image, label = self.dataset[self.idxs[item]] 20 | return image, label 21 | 22 | class LocalUpdate(object): 23 | def __init__(self, args, dataset=None, idxs=None): 24 | self.args = args 25 | self.loss_func = nn.CrossEntropyLoss() 26 | self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True) 27 | 28 | def train(self, net): 29 | net.train() 30 
| optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum, weight_decay=self.args.weigh_delay) 31 | net_pre = copy.deepcopy(net) 32 | state_pre = [parameter for parameter in net_pre.parameters()] 33 | # train and update 34 | epoch_loss = [] 35 | for iter in range(self.args.local_ep): 36 | batch_loss = [] 37 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 38 | images, labels = images.to(self.args.device), labels.to(self.args.device) 39 | 40 | net.zero_grad() 41 | log_probs = net(images) 42 | loss = self.loss_func(log_probs, labels) 43 | loss.backward() 44 | optimizer.step() 45 | batch_loss.append(loss.item()) 46 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 47 | 48 | state_now = [parameter for parameter in net.parameters()] 49 | grads = [torch.zeros_like(param) for param in state_now] 50 | for state_now_, state_pre_, grad in zip(state_now, state_pre, grads): 51 | grad.data[:] = state_now_ - state_pre_ 52 | return grads, sum(epoch_loss) / len(epoch_loss) 53 | 54 | -------------------------------------------------------------------------------- /code_FedCR/utils/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | # federated arguments 10 | parser.add_argument('--epochs', type=int, default=500, help="rounds of training") 11 | parser.add_argument('--test_freq', type=int, default=1, help="frequency of test") 12 | parser.add_argument('--num_users', type=int, default=100, help="number of users") 13 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients") 14 | parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs") 15 | parser.add_argument('--last_local_ep', type=int, default=10, help="the number of local epochs of last") 16 | 
parser.add_argument('--local_rep_ep', type=int, default=1, help="the number of local epochs of Fed_Rep's feature") #ten local epochs to train the local head, followed by one or five epochs for the representation 17 | parser.add_argument('--local_bs', type=int, default=48, help="local batch size: B") 18 | parser.add_argument('--bs', type=int, default=10, help="test batch size") 19 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 20 | parser.add_argument('--globallr', type=float, default=1, help="Global learning rate") 21 | parser.add_argument('--momentum', type=float, default=0, help="local SGD momentum (default: 0.0)") 22 | parser.add_argument('--weigh_delay', type=float, default=0, help="local SGD weigh_delay") 23 | parser.add_argument('--mu', type=float, default=0, help='the value of Ditto') 24 | parser.add_argument('--mu1', type=float, default=0, help='the value_L1 of Factorization') 25 | parser.add_argument('--tau1', type=float, default=0, help='the value_cosine similarity of Factorization') 26 | parser.add_argument('--lr_decay', type=float, default=0.997, help='the value of lr_decay') 27 | parser.add_argument('--alpha', default=0, type=float, help='the value of alpha for Fed_VIB') 28 | parser.add_argument('--beta', default=0.001, type=float, help='the value of beta for Fed_VIB') 29 | parser.add_argument('--sync', type=str, default='True', help='If the model is synchronized for FedCR') 30 | parser.add_argument('--beta_PAC', default=1, type=float, help='the value of beta for FedPAC') 31 | parser.add_argument('--beta2', default=0, type=float, help='the value of beta2 for Z') 32 | parser.add_argument('--dimZ', default = 256, type=int, help='dimension of encoding Z in Fed_VIB') 33 | parser.add_argument('--dimZ_PAC', default=1024, type=int, help='dimension of Z in Factorized') 34 | parser.add_argument('--CMI', default=0.001, type=float, help='the value of CMI in FedSR') 35 | parser.add_argument('--L2R', default=0.001, type=float, 
help='the value of L2R in FedSR') 36 | parser.add_argument('--num_avg_train', default = 15, type=int, help='the number of samples when\ 37 | perform multi-shot train') 38 | parser.add_argument('--num_avg', default = 30, type=int, help='the number of samples when\ 39 | perform multi-shot prediction') 40 | 41 | parser.add_argument('--filepath', type=str, default='filepath', help='whether error accumulation or not') 42 | 43 | # model arguments 44 | parser.add_argument('--method', type=str, default='fedCR', help='method name') 45 | 46 | # other arguments 47 | parser.add_argument('--dataset', type=str, default='CIFAR10', help="name of dataset") 48 | parser.add_argument('--frac_data', type=float, default=0.7, help="fraction of frac_data") 49 | parser.add_argument('--rule', type=str, default='noniid', help='whether noniid or Dirichet') 50 | parser.add_argument('--class_main', type=int, default=5, help='the value of class_main for noniid') 51 | parser.add_argument('--dir_a', default=0.5, type=float, help='the value of dir_a for dirichlet') 52 | parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") 53 | parser.add_argument('--seed', type=int, default=23, help='random seed (default: 23)') 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | -------------------------------------------------------------------------------- /code_FedCR/models/Nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | class client_model(nn.Module): 8 | def __init__(self, name, args=True): 9 | super(client_model, self).__init__() 10 | self.name = name 11 | 12 | if self.name == 'emnist_NN': 13 | self.n_cls = 10 14 | self.fc1 = nn.Linear(1 * 28 * 28, 1024) 15 | self.fc2 = nn.Linear(1024, 1024) 16 | self.fc3 = nn.Linear(1024, self.n_cls) 17 | self.weight_keys = [['fc1.weight', 'fc1.bias'], 18 | 
['fc2.weight', 'fc2.bias'], 19 | ['fc3.weight', 'fc3.bias']] 20 | 21 | 22 | if self.name == 'FMNIST_CNN': 23 | self.n_cls = 10 24 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5) 25 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 26 | self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5) 27 | self.fc1 = nn.Linear(12 * 4 * 4, 1024) 28 | self.fc2 = nn.Linear(1024, 1024) 29 | self.fc3 = nn.Linear(1024, self.n_cls) 30 | 31 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 32 | ['conv2.weight', 'conv2.bias'], 33 | ['fc1.weight', 'fc1.bias'], 34 | ['fc2.weight', 'fc2.bias'], 35 | ['fc3.weight', 'fc3.bias']] 36 | 37 | 38 | if self.name == 'cifar10_LeNet': 39 | self.n_cls = 10 40 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 41 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 42 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 43 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 44 | self.fc2 = nn.Linear(1024, 1024) 45 | self.fc3 = nn.Linear(1024, self.n_cls) 46 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 47 | ['conv2.weight', 'conv2.bias'], 48 | ['fc1.weight', 'fc1.bias'], 49 | ['fc2.weight', 'fc2.bias'], 50 | ['fc3.weight', 'fc3.bias']] 51 | 52 | if self.name == 'cifar100_LeNet': 53 | self.n_cls = 100 54 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 55 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 56 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 57 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 58 | self.fc2 = nn.Linear(1024, 1024) 59 | self.fc3 = nn.Linear(1024, self.n_cls) 60 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 61 | ['conv2.weight', 'conv2.bias'], 62 | ['fc1.weight', 'fc1.bias'], 63 | ['fc2.weight', 'fc2.bias'], 64 | ['fc3.weight', 'fc3.bias']] 65 | 66 | 67 | def forward(self, x): 68 | 69 | if self.name == 'emnist_NN': 70 | x = x.view(-1, 1 * 28 * 28) 71 | x = F.relu(self.fc1(x)) 72 | x = 
F.relu(self.fc2(x)) 73 | x = self.fc3(x) 74 | 75 | if self.name == 'FMNIST_CNN': 76 | x = self.pool(F.relu(self.conv1(x))) 77 | x = self.pool(F.relu(self.conv2(x))) 78 | x = x.view(-1, 12 * 4 * 4) 79 | x = F.relu(self.fc1(x)) 80 | x = F.relu(self.fc2(x)) 81 | x = self.fc3(x) 82 | 83 | 84 | if self.name == 'cifar10_LeNet': 85 | x = self.pool(F.relu(self.conv1(x))) 86 | x = self.pool(F.relu(self.conv2(x))) 87 | x = x.view(-1, 64 * 5 * 5) 88 | x = F.relu(self.fc1(x)) 89 | x = F.relu(self.fc2(x)) 90 | x = self.fc3(x) 91 | 92 | if self.name == 'cifar100_LeNet': 93 | x = self.pool(F.relu(self.conv1(x))) 94 | x = self.pool(F.relu(self.conv2(x))) 95 | x = x.view(-1, 64 * 5 * 5) 96 | x = F.relu(self.fc1(x)) 97 | x = F.relu(self.fc2(x)) 98 | x = self.fc3(x) 99 | 100 | 101 | return x 102 | -------------------------------------------------------------------------------- /code_FedCR/models/Nets_PAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | class client_model(nn.Module): 8 | def __init__(self, name, args=True): 9 | super(client_model, self).__init__() 10 | self.name = name 11 | 12 | 13 | if self.name == 'emnist_NN': 14 | self.n_cls = 10 15 | self.fc1 = nn.Linear(1 * 28 * 28, 1024) 16 | self.fc2 = nn.Linear(1024, 1024) 17 | self.fc3 = nn.Linear(1024, self.n_cls) 18 | self.weight_keys = [['fc1.weight', 'fc1.bias'], 19 | ['fc2.weight', 'fc2.bias'], 20 | ['fc3.weight', 'fc3.bias']] 21 | 22 | 23 | if self.name == 'FMNIST_CNN': 24 | self.n_cls = 10 25 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5) 26 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 27 | self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5) 28 | self.fc1 = nn.Linear(12 * 4 * 4, 1024) 29 | self.fc2 = nn.Linear(1024, 1024) 30 | self.fc3 = nn.Linear(1024, self.n_cls) 31 | 32 | self.weight_keys = 
[['conv1.weight', 'conv1.bias'], 33 | ['conv2.weight', 'conv2.bias'], 34 | ['fc1.weight', 'fc1.bias'], 35 | ['fc2.weight', 'fc2.bias'], 36 | ['fc3.weight', 'fc3.bias']] 37 | 38 | 39 | if self.name == 'cifar10_LeNet': 40 | self.n_cls = 10 41 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 42 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 43 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 44 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 45 | self.fc2 = nn.Linear(1024, 1024) 46 | self.fc3 = nn.Linear(1024, self.n_cls) 47 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 48 | ['conv2.weight', 'conv2.bias'], 49 | ['fc1.weight', 'fc1.bias'], 50 | ['fc2.weight', 'fc2.bias'], 51 | ['fc3.weight', 'fc3.bias']] 52 | 53 | 54 | if self.name == 'cifar100_LeNet': 55 | self.n_cls = 100 56 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 57 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 58 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 59 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 60 | self.fc2 = nn.Linear(1024, 1024) 61 | self.fc3 = nn.Linear(1024, self.n_cls) 62 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 63 | ['conv2.weight', 'conv2.bias'], 64 | ['fc1.weight', 'fc1.bias'], 65 | ['fc2.weight', 'fc2.bias'], 66 | ['fc3.weight', 'fc3.bias']] 67 | 68 | 69 | def forward(self, x): 70 | 71 | if self.name == 'emnist_NN': 72 | x = x.view(-1, 1 * 28 * 28) 73 | x = F.relu(self.fc1(x)) 74 | fecture_output = F.relu(self.fc2(x)) 75 | 76 | x = self.fc3(fecture_output) 77 | 78 | if self.name == 'FMNIST_CNN': 79 | x = self.pool(F.relu(self.conv1(x))) 80 | x = self.pool(F.relu(self.conv2(x))) 81 | x = x.view(-1, 12 * 4 * 4) 82 | x = F.relu(self.fc1(x)) 83 | fecture_output = F.relu(self.fc2(x)) 84 | 85 | x = self.fc3(fecture_output) 86 | 87 | if self.name == 'cifar10_LeNet': 88 | x = self.pool(F.relu(self.conv1(x))) 89 | x = self.pool(F.relu(self.conv2(x))) 90 | x = x.view(-1, 64 * 5 * 
5) 91 | x = F.relu(self.fc1(x)) 92 | fecture_output = F.relu(self.fc2(x)) 93 | 94 | x = self.fc3(fecture_output) 95 | 96 | if self.name == 'cifar100_LeNet': 97 | x = self.pool(F.relu(self.conv1(x))) 98 | x = self.pool(F.relu(self.conv2(x))) 99 | x = x.view(-1, 64 * 5 * 5) 100 | x = F.relu(self.fc1(x)) 101 | fecture_output = F.relu(self.fc2(x)) 102 | 103 | x = self.fc3(fecture_output) 104 | 105 | 106 | return x, fecture_output 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedCR: Personalized Federated Learning Based on Across-Client Common Representation with Conditional Mutual Information Regularization 2 | 3 | This directory contains source code for evaluating federated learning with different methods on various models and tasks. The code was developed for a paper, "FedCR: Personalized Federated Learning Based on Across-Client Common Representation with Conditional Mutual Information Regularization". 4 | 5 | ## Requirements 6 | 7 | Some pip packages are required by this library, and may need to be installed. For more details, see `requirements.txt`. We recommend running `pip install --requirement "requirements.txt"`. 8 | 9 | Below we give a summary of the datasets, tasks, and models used in this code. 
10 | 11 | 12 | ## Task and dataset summary 13 | 14 | Note that we put the datasets under the directory `.\federated-learning-master\Folder`. 15 | 16 | 17 | 18 | | Directory | Model | Task Summary | 19 | |------------------|-------------------------------------|---------------------------| 20 | | CIFAR-10 | CNN (with two convolutional layers) | Image classification | 21 | | CIFAR-100 | CNN (with two convolutional layers) | Image classification | 22 | | EMNIST | NN (fully connected neural network) | Digit recognition | 23 | | FMNIST | CNN (with two convolutional layers) | Image classification | 24 | 25 | 26 | 27 | 28 | ## Training 29 | In this code, we compare 10 federated learning methods: **FedAvg**, **FedAvg-FT**, **FedPer**, **LG-FedAvg**, **FedRep**, **FedBABU**, **Ditto**, **FedSR-FT**, **FedPAC**, and **FedCR**. All of these methods use vanilla SGD on the clients. To reproduce our experimental results for each method (for example, with 100 clients, a 10% participation rate, and the CIFAR-100 dataset under a Dirichlet (0.3) split), run the following commands: 30 | 31 | **FedAvg** and **FedAvg-FT**: 32 | ``` 33 | python main_fed.py --filepath FedAvg.txt --dataset CIFAR100 --method fedavg --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 34 | ``` 35 | 36 | **FedPer**: 37 | ``` 38 | python main_fed.py --filepath FedPer.txt --dataset CIFAR100 --method fedper --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 1 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 39 | ``` 40 | 41 | **LG-FedAvg**: 42 | ``` 43 | python main_fed.py --filepath LG-FedAvg.txt --dataset CIFAR100 --method lg --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 44 | ``` 45 | 46 | **FedRep**: 47 | ``` 48 | python main_fed.py --filepath FedRep.txt --dataset CIFAR100 --method fedrep --lr 0.01 --local_ep 10 --local_rep_ep 1 --lr_decay 1 --rule Dirichlet
--dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 49 | ``` 50 | 51 | **FedBABU**: 52 | ``` 53 | python main_fed.py --filepath FedBABU.txt --dataset CIFAR100 --method fedbabu --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 54 | ``` 55 | 56 | **Ditto**: 57 | ``` 58 | python main_fed_ditto.py --filepath Ditto.txt --dataset CIFAR100 --method ditto --mu 0.1 --lr 0.5 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48 59 | ``` 60 | 61 | **FedSR-FT**: 62 | ``` 63 | python main_fedSR.py --filepath FedSR-FT.txt --dataset CIFAR100 --method fedSR --lr 0.005 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --num_avg_train 1 --num_avg 1 --epoch 500 --dimZ 256 --CMI 0.001 --L2R 0.001 --bs 10 --local_bs 48 --beta2 1 64 | ``` 65 | 66 | **FedPAC**: 67 | ``` 68 | python main_fed_PAC.py --filepath FedPAC.txt --dataset CIFAR100 --method fedPAC --lr 0.01 --beta_PAC 1 --local_ep 10 --local_rep_ep 1 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --dimZ 512 --beta 0.001 --bs 10 --local_bs 48 69 | ``` 70 | 71 | **FedCR**: 72 | ``` 73 | python main_fed.py --filepath FedCR.txt --dataset CIFAR100 --method FedCR --lr 0.05 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --num_avg_train 18 --num_avg 18 --epoch 500 --dimZ 512 --beta 0.0001 --bs 10 --local_bs 48 --beta2 1 74 | ``` 75 | 76 | 77 | 78 | ## Other hyperparameters and reproducibility 79 | 80 | All other hyperparameters are set by default to the values listed in the `Experiment Details` section of our Appendix. These include the batch size, the number of clients per round, the number of client local updates, the local learning rate, and model parameter flags. While they can be changed for different behavior (such as varying the number of client local updates), they should not be changed if one wishes to reproduce the results from our paper.
81 | 82 | -------------------------------------------------------------------------------- /code_FedCR/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 12 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 1674736627687 55 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /code_FedCR/main_fed_ditto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.Nets import client_model 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utils_ditto import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 | threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | 
clnt_x = data_obj.clnt_x; 43 | clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | 49 | if args.dataset == 'CIFAR100': 50 | net_glob = client_model('cifar100_LeNet').to(args.device) 51 | elif args.dataset == 'CIFAR10': 52 | net_glob = client_model('cifar10_LeNet').to(args.device) 53 | elif args.dataset == 'EMNIST': 54 | net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device) 55 | elif args.dataset == 'FMNIST': 56 | net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device) 57 | else: 58 | exit('Error: unrecognized model') 59 | 60 | total_num_layers = len(net_glob.state_dict().keys()) 61 | net_keys = [*net_glob.state_dict().keys()] 62 | 63 | 64 | if args.method == 'ditto': 65 | w_glob_keys = [] 66 | else: 67 | exit('Error: unrecognized data4') 68 | 69 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 70 | 71 | clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 72 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 73 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 74 | 75 | logger = get_logger(args.filepath) 76 | logger.info('--------args----------') 77 | for k in list(vars(args).keys()): 78 | logger.info('%s: %s' % (k, vars(args)[k])) 79 | logger.info('--------args----------\n') 80 | logger.info('total_num_layers') 81 | logger.info(total_num_layers) 82 | logger.info('net_keys') 83 | logger.info(net_keys) 84 | logger.info('w_glob_keys') 85 | logger.info(w_glob_keys) 86 | 87 | logger.info('start training!') 88 | 89 | for iter in range(args.epochs + 1): 90 | net_glob.train() 91 | 92 | m = max(int(args.frac * args.num_users), 1) 93 | if iter == args.epochs: 94 | m = args.num_users 95 | participating_clients = random.sample(clients, m) 96 | 97 | last = 
iter == args.epochs 98 | 99 | for client in participating_clients: 100 | client.synchronize_with_server(server, w_glob_keys) 101 | 102 | client.compute_bias() 103 | 104 | client.compute_weight_update(w_glob_keys, server, last) 105 | 106 | 107 | 108 | server.aggregate_weight_updates(clients=participating_clients, iter=iter) 109 | 110 | 111 | #-----------------------------------------------test-------------------------------------------------------------------- 112 | 113 | #-----------------------------------------------test-------------------------------------------------------------------- 114 | 115 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10: 116 | results_loss =[]; results_acc = [] 117 | results_loss_last = []; results_acc_last = [] 118 | for client in clients: 119 | 120 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, dataset_name=data_obj.dataset) 121 | 122 | results_loss.append(loss_test1) 123 | results_acc.append(results_test) 124 | results_loss = np.mean(results_loss) 125 | results_acc = np.mean(results_acc) 126 | 127 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 128 | format(iter, args.lr, results_loss, results_acc)) 129 | 130 | 131 | args.lr = args.lr * (args.lr_decay) 132 | 133 | logger.info('finish training!') 134 | 135 | 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /code_FedCR/models/resnet18.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1, use_batchnorm=True): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | if not use_batchnorm: 27 | self.bn1 = self.bn2 = nn.Sequential() 28 | 29 | self.shortcut = nn.Sequential() 30 | if stride != 1 or in_planes != self.expansion*planes: 31 | self.shortcut = nn.Sequential( 32 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(self.expansion*planes) if use_batchnorm else nn.Sequential() 34 | ) 35 | 36 | def forward(self, x): 37 | out = F.relu(self.bn1(self.conv1(x))) 38 | out = self.bn2(self.conv2(out)) 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1, use_batchnorm=True): 48 | super(Bottleneck, self).__init__() 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | if not use_batchnorm: 57 | self.bn1 = self.bn2 = self.bn3 = nn.Sequential() 58 | 59 | self.shortcut = nn.Sequential() 60 | if stride != 1 or in_planes != self.expansion*planes: 61 | self.shortcut = nn.Sequential( 62 | nn.Conv2d(in_planes, 
self.expansion*planes, kernel_size=1, stride=stride, bias=False), 63 | nn.BatchNorm2d(self.expansion*planes) if use_batchnorm else nn.Sequential() 64 | ) 65 | 66 | def forward(self, x): 67 | out = F.relu(self.bn1(self.conv1(x))) 68 | out = F.relu(self.bn2(self.conv2(out))) 69 | out = self.bn3(self.conv3(out)) 70 | out += self.shortcut(x) 71 | out = F.relu(out) 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, num_classes=10, use_batchnorm=True): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | self.use_batchnorm = use_batchnorm 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) if use_batchnorm else nn.Sequential() 82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.linear = nn.Linear(512*block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1]*(num_blocks-1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride, self.use_batchnorm)) 93 | self.in_planes = planes * block.expansion 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | out = F.relu(self.bn1(self.conv1(x))) 98 | out = self.layer1(out) 99 | out = self.layer2(out) 100 | out = self.layer3(out) 101 | out = self.layer4(out) 102 | out = F.avg_pool2d(out, 4) 103 | out = out.view(out.size(0), -1) 104 | out = self.linear(out) 105 | return out 106 | 107 | 108 | def ResNet18(use_batchnorm=True): 109 | return ResNet(BasicBlock, [2,2,2,2], use_batchnorm=use_batchnorm) 110 | 111 | def ResNet34(use_batchnorm=True): 112 | return ResNet(BasicBlock, [3,4,6,3], use_batchnorm=use_batchnorm) 113 | 114 | def 
ResNet50(use_batchnorm=True): 115 | return ResNet(Bottleneck, [3,4,6,3], use_batchnorm=use_batchnorm) 116 | 117 | def ResNet101(use_batchnorm=True): 118 | return ResNet(Bottleneck, [3,4,23,3], use_batchnorm=use_batchnorm) 119 | 120 | def ResNet152(use_batchnorm=True): 121 | return ResNet(Bottleneck, [3,8,36,3], use_batchnorm=use_batchnorm) 122 | 123 | 124 | def test(): 125 | net = ResNet18() 126 | y = net(Variable(torch.randn(1,3,32,32))) 127 | print(y.size()) 128 | 129 | # test() 130 | -------------------------------------------------------------------------------- /code_FedCR/main_fedSR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.NetsSR import Model_CMI 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utilsSR import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 | threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | clnt_x = data_obj.clnt_x; 43 
| clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | 49 | 50 | net_glob = Model_CMI(args, args.dimZ, args.alpha, args.dataset).to(args.device) 51 | 52 | 53 | total_num_layers = len(net_glob.state_dict().keys()) 54 | net_keys = [*net_glob.state_dict().keys()] 55 | 56 | if args.method == 'fedSR': 57 | w_glob_keys = [] 58 | 59 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 60 | 61 | clients = [Client(model=Model_CMI(args, args.dimZ, args.alpha, args.dataset).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 62 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 63 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 64 | 65 | 66 | logger = get_logger(args.filepath) 67 | logger.info('--------args----------') 68 | for k in list(vars(args).keys()): 69 | logger.info('%s: %s' % (k, vars(args)[k])) 70 | logger.info('--------args----------\n') 71 | logger.info('total_num_layers') 72 | logger.info(total_num_layers) 73 | logger.info('net_keys') 74 | logger.info(net_keys) 75 | logger.info('w_glob_keys') 76 | logger.info(w_glob_keys) 77 | 78 | logger.info('start training!') 79 | 80 | for iter in range(args.epochs + 1): 81 | net_glob.train() 82 | 83 | m = max(int(args.frac * args.num_users), 1) 84 | if iter == args.epochs: 85 | m = args.num_users 86 | participating_clients = random.sample(clients, m) 87 | 88 | last = iter == args.epochs 89 | 90 | for client in participating_clients: 91 | 92 | 93 | client.synchronize_with_server(server, w_glob_keys) 94 | 95 | client.compute_weight_update(w_glob_keys, server, last) 96 | 97 | 98 | 99 | server.aggregate_weight_updates(clients=participating_clients, iter=iter) 100 | 101 | 102 | 
#-----------------------------------------------test-------------------------------------------------------------------- 103 | 104 | #-----------------------------------------------test-------------------------------------------------------------------- 105 | 106 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10: 107 | results_loss =[]; results_acc = [] 108 | results_loss_last = []; results_acc_last = [] 109 | for client in clients: 110 | 111 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 112 | dataset_name=data_obj.dataset) 113 | if last: 114 | results_test_last, loss_test1_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 115 | dataset_name=data_obj.dataset) 116 | 117 | results_loss.append(loss_test1) 118 | results_acc.append(results_test) 119 | 120 | if last and args.method == 'fedSR': 121 | results_loss_last.append(loss_test1_last) 122 | results_acc_last.append(results_test_last) 123 | 124 | results_loss = np.mean(results_loss) 125 | results_acc = np.mean(results_acc) 126 | if last: 127 | logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 128 | format(iter, args.lr, results_loss, results_acc)) 129 | else: 130 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 131 | format(iter, args.lr, results_loss, results_acc)) 132 | 133 | if last and args.method == 'fedSR': 134 | results_loss_last = np.mean(results_loss_last) 135 | results_acc_last = np.mean(results_acc_last) 136 | logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 
137 | format(iter, args.lr, results_loss_last, results_acc_last)) 138 | 139 | args.lr = args.lr * (args.lr_decay) 140 | 141 | logger.info('finish training!') 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /code_FedCR/models/main_fedSR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.NetsSR import Model_CMI, Model_CMI_server 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utilsSR import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 | threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | clnt_x = data_obj.clnt_x; 43 | clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | 49 | 50 | net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device) 51 | 52 | 53 | total_num_layers = 
len(net_glob.state_dict().keys()) 54 | net_keys = [*net_glob.state_dict().keys()] 55 | 56 | if args.method == 'fedSR': 57 | w_glob_keys = [] 58 | 59 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 60 | 61 | clients = [Client(model=client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 62 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 63 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 64 | 65 | 66 | logger = get_logger(args.filepath) 67 | logger.info('--------args----------') 68 | for k in list(vars(args).keys()): 69 | logger.info('%s: %s' % (k, vars(args)[k])) 70 | logger.info('--------args----------\n') 71 | logger.info('total_num_layers') 72 | logger.info(total_num_layers) 73 | logger.info('net_keys') 74 | logger.info(net_keys) 75 | logger.info('w_glob_keys') 76 | logger.info(w_glob_keys) 77 | 78 | logger.info('start training!') 79 | 80 | for iter in range(args.epochs + 1): 81 | net_glob.train() 82 | 83 | m = max(int(args.frac * args.num_users), 1) 84 | if iter == args.epochs: 85 | m = args.num_users 86 | participating_clients = random.sample(clients, m) 87 | 88 | last = iter == args.epochs 89 | 90 | for client in participating_clients: 91 | 92 | 93 | client.synchronize_with_server(server, w_glob_keys) 94 | 95 | client.compute_weight_update(w_glob_keys, server, last) 96 | 97 | 98 | 99 | server.aggregate_weight_updates(clients=participating_clients, iter=iter) 100 | 101 | 102 | #-----------------------------------------------test-------------------------------------------------------------------- 103 | 104 | #-----------------------------------------------test-------------------------------------------------------------------- 105 | 106 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10: 107 | 
results_loss =[]; results_acc = [] 108 | results_loss_last = []; results_acc_last = [] 109 | for client in clients: 110 | 111 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 112 | dataset_name=data_obj.dataset) 113 | if last: 114 | results_test_last, loss_test1_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 115 | dataset_name=data_obj.dataset) 116 | 117 | results_loss.append(loss_test1) 118 | results_acc.append(results_test) 119 | 120 | if last and args.method == 'fedSR': 121 | results_loss_last.append(loss_test1_last) 122 | results_acc_last.append(results_test_last) 123 | 124 | results_loss = np.mean(results_loss) 125 | results_acc = np.mean(results_acc) 126 | if last: 127 | logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 128 | format(iter, args.lr, results_loss, results_acc)) 129 | else: 130 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 131 | format(iter, args.lr, results_loss, results_acc)) 132 | 133 | if last and args.method == 'fedSR': 134 | results_loss_last = np.mean(results_loss_last) 135 | results_acc_last = np.mean(results_acc_last) 136 | logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 
137 | format(iter, args.lr, results_loss_last, results_acc_last)) 138 | 139 | args.lr = args.lr * (args.lr_decay) 140 | 141 | logger.info('finish training!') 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /code_FedCR/main_fed_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.Nets import client_model 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utils import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 | threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | clnt_x = data_obj.clnt_x; 43 | clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | if args.method == 'FedCR': 49 | net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device) 50 | else: 51 | if args.dataset == 'CIFAR100': 52 | net_glob 
= client_model('cifar100_LeNet').to(args.device) 53 | elif args.dataset == 'CIFAR10': 54 | net_glob = client_model('cifar10_LeNet').to(args.device) 55 | elif args.dataset == 'EMNIST': 56 | net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device) 57 | elif args.dataset == 'FMNIST': 58 | net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device) 59 | else: 60 | exit('Error: unrecognized model') 61 | 62 | total_num_layers = len(net_glob.state_dict().keys()) 63 | net_keys = [*net_glob.state_dict().keys()] 64 | 65 | if args.method == 'fedrep' or args.method == 'fedper' or args.method == 'fedbabu': 66 | if 'CIFAR100' in args.dataset: 67 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 68 | elif 'CIFAR10' in args.dataset: 69 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 70 | elif 'EMNIST' in args.dataset: 71 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]] 72 | elif 'FMNIST' in args.dataset: 73 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 74 | else: 75 | exit('Error: unrecognized data1') 76 | elif args.method == 'lg': 77 | if 'CIFAR100' in args.dataset: 78 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 79 | elif 'CIFAR10' in args.dataset: 80 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 81 | elif 'EMNIST' in args.dataset: 82 | w_glob_keys = [net_glob.weight_keys[i] for i in [1, 2]] 83 | elif 'FMNIST' in args.dataset: 84 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 85 | else: 86 | exit('Error: unrecognized data2') 87 | elif args.method == 'FedCR': 88 | if 'CIFAR100' in args.dataset: 89 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]] 90 | elif 'CIFAR10' in args.dataset: 91 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]] 92 | elif 'EMNIST' in args.dataset: 93 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2]] 94 | elif 'FMNIST' in args.dataset: 95 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 
2, 3, 4]] 96 | else: 97 | exit('Error: unrecognized data3') 98 | elif args.method == 'fedavg' or args.method == 'ditto' or args.method == 'maml': 99 | w_glob_keys = [] 100 | else: 101 | exit('Error: unrecognized data4') 102 | 103 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 104 | 105 | clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 106 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 107 | 108 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 109 | 110 | logger = get_logger(args.filepath) 111 | logger.info('--------args----------') 112 | for k in list(vars(args).keys()): 113 | logger.info('%s: %s' % (k, vars(args)[k])) 114 | logger.info('--------args----------\n') 115 | logger.info('total_num_layers') 116 | logger.info(total_num_layers) 117 | logger.info('net_keys') 118 | logger.info(net_keys) 119 | logger.info('w_glob_keys') 120 | logger.info(w_glob_keys) 121 | 122 | logger.info('start training!') 123 | 124 | results_loss = []; 125 | results_acc = [] 126 | 127 | for client in clients: 128 | 129 | client.compute_weight_update(w_glob_keys, server, last=False) 130 | 131 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 132 | dataset_name=data_obj.dataset) 133 | 134 | 135 | results_loss.append(loss_test1) 136 | results_acc.append(results_test) 137 | 138 | results_loss = np.mean(results_loss) 139 | results_acc = np.mean(results_acc) 140 | 141 | logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.format(iter, args.lr, results_loss, results_acc)) 142 | 143 | args.lr = args.lr * (args.lr_decay) 144 | 145 | logger.info('finish training!') 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- 
/code_FedCR/main_fed_PAC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.Nets_PAC import client_model 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utils_PAC import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 | threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | clnt_x = data_obj.clnt_x; 43 | clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | 49 | if args.dataset == 'CIFAR100': 50 | net_glob = client_model('cifar100_LeNet').to(args.device) 51 | elif args.dataset == 'CIFAR10': 52 | net_glob = client_model('cifar10_LeNet').to(args.device) 53 | elif args.dataset == 'EMNIST': 54 | net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device) 55 | elif args.dataset == 'FMNIST': 56 | net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device) 57 | else: 58 | 
exit('Error: unrecognized model') 59 | 60 | 61 | total_num_layers = len(net_glob.state_dict().keys()) 62 | net_keys = [*net_glob.state_dict().keys()] 63 | 64 | if args.method == 'fedPAC': 65 | if 'CIFAR100' in args.dataset: 66 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 67 | elif 'CIFAR10' in args.dataset: 68 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 69 | elif 'EMNIST' in args.dataset: 70 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]] 71 | elif 'FMNIST' in args.dataset: 72 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 73 | else: 74 | exit('Error: unrecognized data1') 75 | 76 | elif args.method == 'fedavg': 77 | w_glob_keys = [] 78 | else: 79 | exit('Error: unrecognized data4') 80 | 81 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 82 | 83 | clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 84 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 85 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 86 | 87 | logger = get_logger(args.filepath) 88 | logger.info('--------args----------') 89 | for k in list(vars(args).keys()): 90 | logger.info('%s: %s' % (k, vars(args)[k])) 91 | logger.info('--------args----------\n') 92 | logger.info('total_num_layers') 93 | logger.info(total_num_layers) 94 | logger.info('net_keys') 95 | logger.info(net_keys) 96 | logger.info('w_glob_keys') 97 | logger.info(w_glob_keys) 98 | 99 | logger.info('start training!') 100 | 101 | for iter in range(args.epochs + 1): 102 | net_glob.train() 103 | 104 | m = max(int(args.frac * args.num_users), 1) 105 | 106 | if iter == args.epochs: 107 | m = args.num_users 108 | participating_clients = random.sample(clients, m) 109 | 110 | last = iter == args.epochs 111 | 112 | for client in participating_clients: 
113 | client.synchronize_with_server(server, w_glob_keys) 114 | 115 | client.compute_weight_update(w_glob_keys, server, last) 116 | 117 | client.local_feature() 118 | 119 | server.aggregate_weight_updates(clients=participating_clients, iter=iter) 120 | server.global_feature_centroids(clients=participating_clients) 121 | #server.Get_classifier(participating_clients, w_glob_keys) #for non-iid 2 case, the use of classifier combination will sometimes degrade the experimental performance 122 | 123 | #-----------------------------------------------test-------------------------------------------------------------------- 124 | 125 | #-----------------------------------------------test-------------------------------------------------------------------- 126 | 127 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10: 128 | results_loss =[]; results_acc = [] 129 | results_loss_last = []; results_acc_last = [] 130 | for client in clients: 131 | 132 | if args.method != 'fedavg': 133 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 134 | dataset_name=data_obj.dataset) 135 | if last: 136 | results_test_last, loss_test1_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 137 | dataset_name=data_obj.dataset) 138 | 139 | results_loss.append(loss_test1) 140 | results_acc.append(results_test) 141 | 142 | 143 | results_loss = np.mean(results_loss) 144 | results_acc = np.mean(results_acc) 145 | 146 | if last: 147 | logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 148 | format(iter, args.lr, results_loss, results_acc)) 149 | else: 150 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 
151 | format(iter, args.lr, results_loss, results_acc)) 152 | 153 | 154 | args.lr = args.lr * (args.lr_decay) 155 | 156 | logger.info('finish training!') 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /code_FedCR/models/Nets_VIB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class client_model_VIB(nn.Module): 7 | 8 | def __init__(self, args, dimZ=256, alpha=0, dataset = 'EMNIST'): 9 | # the dimension of Z 10 | super().__init__() 11 | 12 | self.alpha = alpha 13 | self.dimZ = dimZ 14 | self.device = args.device 15 | self.dataset = dataset 16 | 17 | if self.dataset == 'EMNIST': 18 | self.n_cls = 10 19 | self.fc1 = nn.Linear(1 * 28 * 28, 1024) 20 | self.fc2 = nn.Linear(1024, 1024) 21 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 22 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 23 | self.weight_keys = [['fc1.weight', 'fc1.bias'], 24 | ['fc2.weight', 'fc2.bias'], 25 | ['fc3.weight', 'fc3.bias'], 26 | ['fc4.weight', 'fc4.bias']] 27 | 28 | 29 | if self.dataset == 'FMNIST': 30 | self.n_cls = 10 31 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5) 32 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5) 34 | self.fc1 = nn.Linear(12 * 4 * 4, 1024) 35 | self.fc2 = nn.Linear(1024, 1024) 36 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 37 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 38 | 39 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 40 | ['conv2.weight', 'conv2.bias'], 41 | ['fc1.weight', 'fc1.bias'], 42 | ['fc2.weight', 'fc2.bias'], 43 | ['fc3.weight', 'fc3.bias'], 44 | ['fc4.weight', 'fc4.bias']] 45 | 46 | 47 | if self.dataset == 'CIFAR10': 48 | self.n_cls = 10 49 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 50 | self.conv2 = 
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 51 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 52 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 53 | self.fc2 = nn.Linear(1024, 1024) 54 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 55 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 56 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 57 | ['conv2.weight', 'conv2.bias'], 58 | ['fc1.weight', 'fc1.bias'], 59 | ['fc2.weight', 'fc2.bias'], 60 | ['fc3.weight', 'fc3.bias'], 61 | ['fc4.weight', 'fc4.bias']] 62 | 63 | if self.dataset == 'CIFAR100': 64 | self.n_cls = 100 65 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 66 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 67 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 68 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 69 | self.fc2 = nn.Linear(1024, 1024) 70 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 71 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 72 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 73 | ['conv2.weight', 'conv2.bias'], 74 | ['fc1.weight', 'fc1.bias'], 75 | ['fc2.weight', 'fc2.bias'], 76 | ['fc3.weight', 'fc3.bias'], 77 | ['fc4.weight', 'fc4.bias']] 78 | 79 | def gaussian_noise(self, num_samples, K): 80 | # works with integers as well as tuples 81 | 82 | return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).to(self.device) 83 | 84 | def sample_prior_Z(self, num_samples): 85 | return self.gaussian_noise(num_samples=num_samples, K=self.dimZ) 86 | 87 | def encoder_result(self, encoder_output): 88 | mu = encoder_output[:, :self.dimZ] 89 | sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:] - self.alpha) 90 | 91 | return mu, sigma 92 | 93 | def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples): 94 | 95 | mu, sigma = encoder_Z_distr 96 | 97 | return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ) 98 | 99 | def forward(self, batch_x, num_samples = 1): 100 | 
101 | if self.dataset == 'EMNIST': 102 | batch_size = batch_x.size()[0] 103 | # sample from encoder 104 | x = batch_x.view(-1, 1 * 28 * 28) 105 | x = F.relu(self.fc1(x)) 106 | x = F.relu(self.fc2(x)) 107 | 108 | encoder_output = self.fc3(x) 109 | encoder_Z_distr = self.encoder_result(encoder_output) 110 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 111 | num_samples=num_samples) 112 | decoder_logits = self.fc4(to_decoder) 113 | 114 | # batch should go first 115 | 116 | 117 | if self.dataset == 'FMNIST': 118 | batch_size = batch_x.size()[0] 119 | # sample from encoder 120 | x = self.pool(F.relu(self.conv1(batch_x))) 121 | x = self.pool(F.relu(self.conv2(x))) 122 | x = x.view(-1, 12 * 4 * 4) 123 | x = F.relu(self.fc1(x)) 124 | x = F.relu(self.fc2(x)) 125 | 126 | encoder_output = self.fc3(x) 127 | encoder_Z_distr = self.encoder_result(encoder_output) 128 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 129 | num_samples=num_samples) 130 | decoder_logits = self.fc4(to_decoder) 131 | 132 | 133 | if self.dataset == 'CIFAR10': 134 | batch_size = batch_x.size()[0] 135 | x = self.pool(F.relu(self.conv1(batch_x))) 136 | x = self.pool(F.relu(self.conv2(x))) 137 | x = x.view(-1, 64 * 5 * 5) 138 | x = F.relu(self.fc1(x)) 139 | x = F.relu(self.fc2(x)) 140 | 141 | encoder_output = self.fc3(x) 142 | encoder_Z_distr = self.encoder_result(encoder_output) 143 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 144 | num_samples=num_samples) 145 | decoder_logits = self.fc4(to_decoder) 146 | 147 | if self.dataset == 'CIFAR100': 148 | batch_size = batch_x.size()[0] 149 | x = self.pool(F.relu(self.conv1(batch_x))) 150 | x = self.pool(F.relu(self.conv2(x))) 151 | x = x.view(-1, 64 * 5 * 5) 152 | x = F.relu(self.fc1(x)) 153 | x = F.relu(self.fc2(x)) 154 | 155 | encoder_output = self.fc3(x) 156 | encoder_Z_distr = self.encoder_result(encoder_output) 157 | 
to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 158 | num_samples=num_samples) 159 | decoder_logits = self.fc4(to_decoder) 160 | 161 | return encoder_Z_distr, decoder_logits 162 | 163 | 164 | def weight_init(self): 165 | for m in self._modules: 166 | xavier_init(self._modules[m]) 167 | 168 | 169 | def KL_between_normals(q_distr, p_distr): 170 | mu_q, sigma_q = q_distr 171 | mu_p, sigma_p = p_distr #Standard Deviation 172 | k = mu_q.size(1) 173 | 174 | mu_diff = mu_p - mu_q 175 | mu_diff_sq = torch.mul(mu_diff, mu_diff) 176 | logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1) 177 | logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1) 178 | 179 | fs = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1) + torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1) 180 | two_kl = fs - k + logdet_sigma_p - logdet_sigma_q 181 | return two_kl * 0.5 182 | 183 | 184 | def xavier_init(ms): 185 | for m in ms : 186 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 187 | nn.init.xavier_uniform(m.weight,gain=nn.init.calculate_gain('relu')) 188 | m.bias.data.zero_() 189 | -------------------------------------------------------------------------------- /code_FedCR/main_fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import ssl 8 | 9 | import copy 10 | import itertools 11 | import random 12 | import torch 13 | import numpy as np 14 | from utils.options import args_parser 15 | from utils.seed import setup_seed 16 | from utils.logg import get_logger 17 | from models.Nets import client_model 18 | from models.Nets_VIB import client_model_VIB 19 | from utils.utils_dataset import DatasetObject 20 | from models.distributed_training_utils import Client, Server 21 | torch.set_printoptions( 22 | precision=8, 23 
| threshold=1000, 24 | edgeitems=3, 25 | linewidth=150, 26 | profile=None, 27 | sci_mode=False 28 | ) 29 | if __name__ == '__main__': 30 | 31 | ssl._create_default_https_context = ssl._create_unverified_context 32 | # parse args 33 | args = args_parser() 34 | 35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 36 | setup_seed(args.seed) 37 | 38 | 39 | data_path = 'Folder/' 40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a) 41 | 42 | clnt_x = data_obj.clnt_x; 43 | clnt_y = data_obj.clnt_y; 44 | tst_x = data_obj.tst_x; 45 | tst_y = data_obj.tst_y 46 | 47 | # build model 48 | if args.method == 'FedCR': 49 | net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device) 50 | else: 51 | if args.dataset == 'CIFAR100': 52 | net_glob = client_model('cifar100_LeNet').to(args.device) 53 | elif args.dataset == 'CIFAR10': 54 | net_glob = client_model('cifar10_LeNet').to(args.device) 55 | elif args.dataset == 'EMNIST': 56 | net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device) 57 | elif args.dataset == 'FMNIST': 58 | net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device) 59 | else: 60 | exit('Error: unrecognized model') 61 | 62 | total_num_layers = len(net_glob.state_dict().keys()) 63 | net_keys = [*net_glob.state_dict().keys()] 64 | 65 | if args.method == 'fedrep' or args.method == 'fedper' or args.method == 'fedbabu': 66 | if 'CIFAR100' in args.dataset: 67 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 68 | elif 'CIFAR10' in args.dataset: 69 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]] 70 | elif 'EMNIST' in args.dataset: 71 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]] 72 | elif 'FMNIST' in args.dataset: 73 | w_glob_keys = [net_glob.weight_keys[i] for 
i in [0, 1, 2, 3]] 74 | else: 75 | exit('Error: unrecognized data1') 76 | elif args.method == 'lg': 77 | if 'CIFAR100' in args.dataset: 78 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 79 | elif 'CIFAR10' in args.dataset: 80 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 81 | elif 'EMNIST' in args.dataset: 82 | w_glob_keys = [net_glob.weight_keys[i] for i in [1, 2]] 83 | elif 'FMNIST' in args.dataset: 84 | w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]] 85 | else: 86 | exit('Error: unrecognized data2') 87 | elif args.method == 'FedCR': 88 | if 'CIFAR100' in args.dataset: 89 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]] 90 | elif 'CIFAR10' in args.dataset: 91 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]] 92 | elif 'EMNIST' in args.dataset: 93 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2]] 94 | elif 'FMNIST' in args.dataset: 95 | w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]] 96 | else: 97 | exit('Error: unrecognized data3') 98 | elif args.method == 'fedavg' or args.method == 'ditto' or args.method == 'maml': 99 | w_glob_keys = [] 100 | else: 101 | exit('Error: unrecognized data4') 102 | 103 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 104 | 105 | clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i], 106 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)] 107 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls) 108 | 109 | logger = get_logger(args.filepath) 110 | logger.info('--------args----------') 111 | for k in list(vars(args).keys()): 112 | logger.info('%s: %s' % (k, vars(args)[k])) 113 | logger.info('--------args----------\n') 114 | logger.info('total_num_layers') 115 | logger.info(total_num_layers) 116 | logger.info('net_keys') 117 
| logger.info(net_keys) 118 | logger.info('w_glob_keys') 119 | logger.info(w_glob_keys) 120 | 121 | logger.info('start training!') 122 | 123 | for iter in range(args.epochs + 1): 124 | net_glob.train() 125 | 126 | m = max(int(args.frac * args.num_users), 1) 127 | if iter == args.epochs: 128 | m = args.num_users 129 | participating_clients = random.sample(clients, m) 130 | 131 | last = iter == args.epochs 132 | 133 | for client in participating_clients: 134 | 135 | if args.sync == 'True': 136 | client.synchronize_with_server(server, w_glob_keys) 137 | 138 | client.compute_weight_update(w_glob_keys, server, last) 139 | 140 | 141 | 142 | server.aggregate_weight_updates(clients=participating_clients, iter=iter) 143 | 144 | if args.method == 'FedCR': 145 | server.global_POE(clients=participating_clients) 146 | 147 | #-----------------------------------------------test-------------------------------------------------------------------- 148 | 149 | #-----------------------------------------------test-------------------------------------------------------------------- 150 | 151 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10: 152 | results_loss =[]; results_acc = [] 153 | results_loss_last = []; results_acc_last = [] 154 | for client in clients: 155 | if args.method == 'FedCR': 156 | results_test, loss_test1 = client.evaluate_FedVIB(data_x=client.tst_x, data_y=client.tst_y, 157 | dataset_name=data_obj.dataset) 158 | elif args.method != 'fedavg': 159 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 160 | dataset_name=data_obj.dataset) 161 | elif args.method == 'fedavg': 162 | results_test, loss_test1 = server.evaluate(data_x=client.tst_x, data_y=client.tst_y, 163 | dataset_name=data_obj.dataset) 164 | if last: 165 | results_test_last, loss_test1_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, 166 | dataset_name=data_obj.dataset) 167 | results_loss.append(loss_test1) 168 | 
results_acc.append(results_test) 169 | 170 | if last and args.method == 'fedavg': 171 | results_loss_last.append(loss_test1_last) 172 | results_acc_last.append(results_test_last) 173 | 174 | results_loss = np.mean(results_loss) 175 | results_acc = np.mean(results_acc) 176 | if last: 177 | logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 178 | format(iter, args.lr, results_loss, results_acc)) 179 | else: 180 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 181 | format(iter, args.lr, results_loss, results_acc)) 182 | 183 | if last and args.method == 'fedavg': 184 | results_loss_last= np.mean(results_loss_last) 185 | results_acc_last= np.mean(results_acc_last) 186 | logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'. 187 | format(iter, args.lr, results_loss_last, results_acc_last)) 188 | 189 | args.lr = args.lr * (args.lr_decay) 190 | 191 | logger.info('finish training!') 192 | 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /code_FedCR/models/NetsSR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Model_CMI(nn.Module): 7 | 8 | def __init__(self, args, dimZ=256, alpha=0, dataset = 'EMNIST'): 9 | # the dimension of Z 10 | super().__init__() 11 | 12 | self.alpha = alpha 13 | self.dimZ = dimZ 14 | self.device = args.device 15 | self.dataset = dataset 16 | 17 | if self.dataset == 'EMNIST': 18 | self.n_cls = 10 19 | if self.dataset == 'FMNIST': 20 | self.n_cls = 10 21 | if self.dataset == 'CIFAR10': 22 | self.n_cls = 10 23 | if self.dataset == 'CIFAR100': 24 | self.n_cls = 100 25 | 26 | self.r_mu = nn.Parameter(torch.zeros(self.n_cls, self.dimZ)).to(self.device) 27 | self.r_sigma = nn.Parameter(torch.ones(self.n_cls, self.dimZ)).to(self.device) 28 | self.C = 
nn.Parameter(torch.ones([])).to(self.device) 29 | 30 | if self.dataset == 'EMNIST': 31 | self.n_cls = 10 32 | self.fc1 = nn.Linear(1 * 28 * 28, 1024) 33 | self.fc2 = nn.Linear(1024, 1024) 34 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 35 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 36 | self.weight_keys = [['fc1.weight', 'fc1.bias'], 37 | ['fc2.weight', 'fc2.bias'], 38 | ['fc3.weight', 'fc3.bias'], 39 | ['fc4.weight', 'fc4.bias']] 40 | 41 | if self.dataset == 'FMNIST': 42 | self.n_cls = 10 43 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5) 44 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 45 | self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5) 46 | self.fc1 = nn.Linear(12 * 4 * 4, 1024) 47 | self.fc2 = nn.Linear(1024, 1024) 48 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 49 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 50 | 51 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 52 | ['conv2.weight', 'conv2.bias'], 53 | ['fc1.weight', 'fc1.bias'], 54 | ['fc2.weight', 'fc2.bias'], 55 | ['fc3.weight', 'fc3.bias'], 56 | ['fc4.weight', 'fc4.bias']] 57 | 58 | if self.dataset == 'CIFAR10': 59 | self.n_cls = 10 60 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 61 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 62 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 63 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 64 | self.fc2 = nn.Linear(1024, 1024) 65 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 66 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 67 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 68 | ['conv2.weight', 'conv2.bias'], 69 | ['fc1.weight', 'fc1.bias'], 70 | ['fc2.weight', 'fc2.bias'], 71 | ['fc3.weight', 'fc3.bias'], 72 | ['fc4.weight', 'fc4.bias']] 73 | 74 | if self.dataset == 'CIFAR100': 75 | self.n_cls = 100 76 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5) 77 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5) 78 | 
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 79 | self.fc1 = nn.Linear(64 * 5 * 5, 1024) 80 | self.fc2 = nn.Linear(1024, 1024) 81 | self.fc3 = nn.Linear(1024, 2 * self.dimZ) 82 | self.fc4 = nn.Linear(self.dimZ, self.n_cls) 83 | self.weight_keys = [['conv1.weight', 'conv1.bias'], 84 | ['conv2.weight', 'conv2.bias'], 85 | ['fc1.weight', 'fc1.bias'], 86 | ['fc2.weight', 'fc2.bias'], 87 | ['fc3.weight', 'fc3.bias'], 88 | ['fc4.weight', 'fc4.bias']] 89 | 90 | def gaussian_noise(self, num_samples, K): 91 | # works with integers as well as tuples 92 | 93 | return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).to(self.device) 94 | 95 | def sample_prior_Z(self, num_samples): 96 | return self.gaussian_noise(num_samples=num_samples, K=self.dimZ) 97 | 98 | def encoder_result(self, encoder_output): 99 | mu = encoder_output[:, :self.dimZ] 100 | sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:] - self.alpha) 101 | 102 | return mu, sigma 103 | 104 | def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples): 105 | 106 | mu, sigma = encoder_Z_distr 107 | 108 | return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ) 109 | 110 | def forward(self, batch_x, num_samples=1): 111 | 112 | if self.dataset == 'EMNIST': 113 | batch_size = batch_x.size()[0] 114 | # sample from encoder 115 | x = batch_x.view(-1, 1 * 28 * 28) 116 | x = F.relu(self.fc1(x)) 117 | x = F.relu(self.fc2(x)) 118 | 119 | encoder_output = self.fc3(x) 120 | encoder_Z_distr = self.encoder_result(encoder_output) 121 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 122 | num_samples=num_samples) 123 | decoder_logits = self.fc4(to_decoder) 124 | 125 | # batch should go first 126 | 127 | if self.dataset == 'FMNIST': 128 | batch_size = batch_x.size()[0] 129 | # sample from encoder 130 | x = self.pool(F.relu(self.conv1(batch_x))) 131 | x = self.pool(F.relu(self.conv2(x))) 132 | x = x.view(-1, 
12 * 4 * 4) 133 | x = F.relu(self.fc1(x)) 134 | x = F.relu(self.fc2(x)) 135 | 136 | encoder_output = self.fc3(x) 137 | encoder_Z_distr = self.encoder_result(encoder_output) 138 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 139 | num_samples=num_samples) 140 | decoder_logits = self.fc4(to_decoder) 141 | 142 | if self.dataset == 'CIFAR10': 143 | batch_size = batch_x.size()[0] 144 | x = self.pool(F.relu(self.conv1(batch_x))) 145 | x = self.pool(F.relu(self.conv2(x))) 146 | x = x.view(-1, 64 * 5 * 5) 147 | x = F.relu(self.fc1(x)) 148 | x = F.relu(self.fc2(x)) 149 | 150 | encoder_output = self.fc3(x) 151 | encoder_Z_distr = self.encoder_result(encoder_output) 152 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 153 | num_samples=num_samples) 154 | decoder_logits = self.fc4(to_decoder) 155 | 156 | if self.dataset == 'CIFAR100': 157 | batch_size = batch_x.size()[0] 158 | x = self.pool(F.relu(self.conv1(batch_x))) 159 | x = self.pool(F.relu(self.conv2(x))) 160 | x = x.view(-1, 64 * 5 * 5) 161 | x = F.relu(self.fc1(x)) 162 | x = F.relu(self.fc2(x)) 163 | 164 | encoder_output = self.fc3(x) 165 | encoder_Z_distr = self.encoder_result(encoder_output) 166 | to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr, 167 | num_samples=num_samples) 168 | decoder_logits = self.fc4(to_decoder) 169 | 170 | regL2R = torch.norm(to_decoder) 171 | 172 | return encoder_Z_distr, decoder_logits, regL2R 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | #regL2R = to_decoder_mearn.norm(dim=1).mean() 182 | 183 | 184 | ''' 185 | class Model_CMI(nn.Module): 186 | 187 | def __init__(self, args): 188 | self.probabilistic = True 189 | super(client_model_SR, self).__init__(args) 190 | 191 | self.dimZ = args.dimZ 192 | self.device = args.device 193 | self.dataset = args.dataset 194 | 195 | if self.dataset == 'EMNIST': 196 | self.n_cls = 10 197 | if self.dataset == 'FMNIST': 
198 | self.n_cls = 10 199 | if self.dataset == 'CIFAR10': 200 | self.n_cls = 10 201 | if self.dataset == 'CIFAR100': 202 | self.n_cls = 100 203 | 204 | self.r_mu = nn.Parameter(torch.zeros(args.num_classes, args.z_dim)) 205 | self.r_sigma = nn.Parameter(torch.ones(args.num_classes, args.z_dim)) 206 | self.C = nn.Parameter(torch.ones([])) 207 | 208 | self.optim.add_param_group({'params':[self.r_mu,self.r_sigma,self.C],'lr':args.lr,'momentum':0.9}) 209 | ''' 210 | 211 | def KL_between_normals(q_distr, p_distr): 212 | mu_q, sigma_q = q_distr 213 | mu_p, sigma_p = p_distr #Standard Deviation 214 | k = mu_q.size(1) 215 | 216 | mu_diff = mu_p - mu_q 217 | mu_diff_sq = torch.mul(mu_diff, mu_diff) 218 | logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1) 219 | logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1) 220 | 221 | fs = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1) + torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1) 222 | two_kl = fs - k + logdet_sigma_p - logdet_sigma_q 223 | return two_kl * 0.5 224 | 225 | 226 | def xavier_init(ms): 227 | for m in ms : 228 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 229 | nn.init.xavier_uniform(m.weight,gain=nn.init.calculate_gain('relu')) 230 | m.bias.data.zero_() 231 | -------------------------------------------------------------------------------- /code_FedCR/models/distributed_training_utils_ditto.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import math 4 | from torch import nn, autograd 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from utils.utils_dataset import Dataset 8 | from models.Nets_VIB import KL_between_normals 9 | import itertools 10 | import torch.nn.functional as F 11 | 12 | max_norm = 10 13 | 14 | def add(target, source): 15 | for name in target: 16 | target[name].data += source[name].data.clone() 17 | 18 | 19 | def 
add_mome(target, source, beta_): 20 | for name in target: 21 | target[name].data = (beta_ * target[name].data + source[name].data.clone()) 22 | 23 | 24 | def add_mome2(target, source1, source2, beta_1, beta_2): 25 | for name in target: 26 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 27 | 28 | 29 | def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3): 30 | for name in target: 31 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() + beta_3 * source3[name].data.clone() 32 | 33 | def add_2(target, source1, source2, beta_1, beta_2): 34 | for name in target: 35 | target[name].data += beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 36 | 37 | def scale(target, scaling): 38 | for name in target: 39 | target[name].data = scaling * target[name].data.clone() 40 | 41 | 42 | def scale_ts(target, source, scaling): 43 | for name in target: 44 | target[name].data = scaling * source[name].data.clone() 45 | 46 | 47 | def subtract(target, source): 48 | for name in target: 49 | target[name].data -= source[name].data.clone() 50 | 51 | 52 | def subtract_(target, minuend, subtrahend): 53 | for name in target: 54 | target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone() 55 | 56 | 57 | def average(target, sources): 58 | for name in target: 59 | target[name].data = torch.mean(torch.stack([source[name].data for source in sources]), dim=0).clone() 60 | 61 | 62 | def weighted_average(target, sources, weights): 63 | for name in target: 64 | summ = torch.sum(weights) 65 | n = len(sources) 66 | modify = [weight / summ * n for weight in weights] 67 | target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]), 68 | dim=0).clone() 69 | 70 | 71 | def computer_norm(source1, source2): 72 | diff_norm = 0 73 | 74 | for name in source1: 75 | diff_source = source1[name].data.clone() - 
source2[name].data.clone() 76 | diff_norm += torch.pow(torch.norm(diff_source),2) 77 | 78 | return (torch.pow(diff_norm, 0.5)) 79 | 80 | def majority_vote(target, sources, lr): 81 | for name in target: 82 | threshs = torch.stack([torch.max(source[name].data) for source in sources]) 83 | mask = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0).sign() 84 | target[name].data = (lr * mask).clone() 85 | 86 | 87 | def get_mdl_params(model_list, n_par=None): 88 | if n_par == None: 89 | exp_mdl = model_list[0] 90 | n_par = 0 91 | for name, param in exp_mdl.named_parameters(): 92 | n_par += len(param.data.reshape(-1)) 93 | 94 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 95 | for i, mdl in enumerate(model_list): 96 | idx = 0 97 | for name, param in mdl.named_parameters(): 98 | temp = param.data.cpu().numpy().reshape(-1) 99 | param_mat[i, idx:idx + len(temp)] = temp 100 | idx += len(temp) 101 | return np.copy(param_mat) 102 | 103 | 104 | def get_other_params(model_list, n_par=None): 105 | if n_par == None: 106 | exp_mdl = model_list[0] 107 | n_par = 0 108 | for name in exp_mdl: 109 | n_par += len(exp_mdl[name].data.reshape(-1)) 110 | 111 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 112 | for i, mdl in enumerate(model_list): 113 | idx = 0 114 | for name in mdl: 115 | temp = mdl[name].data.cpu().numpy().reshape(-1) 116 | param_mat[i, idx:idx + len(temp)] = temp 117 | idx += len(temp) 118 | return np.copy(param_mat) 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | class DistributedTrainingDevice(object): 127 | ''' 128 | A distributed training device (Client or Server) 129 | data : a pytorch dataset consisting datapoints (x,y) 130 | model : a pytorch neural net f mapping x -> f(x)=y_ 131 | hyperparameters : a python dict containing all hyperparameters 132 | ''' 133 | 134 | def __init__(self, model, args): 135 | self.model = model 136 | self.args = args 137 | self.loss_func = nn.CrossEntropyLoss() 138 | 139 | class 
Client(DistributedTrainingDevice): 140 | 141 | def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0): 142 | super().__init__(model, args) 143 | 144 | self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), 145 | batch_size=self.args.local_bs, shuffle=True) 146 | 147 | self.model_local = copy.deepcopy(model).to(args.device) 148 | 149 | self.tst_x = tst_x 150 | self.tst_y = tst_y 151 | self.n_cls = n_cls 152 | self.id = id_num 153 | self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs)) 154 | # Parameters 155 | self.W = {name: value for name, value in self.model.named_parameters()} 156 | self.W_old = {name: torch.zeros(value.shape).to(self.args.device) for name, value in self.W.items()} 157 | self.V = {name: torch.zeros(value.shape).to(self.args.device) for name, value in self.W.items()} 158 | self.state_params_diff = 0.0 159 | self.train_loss = 0.0 160 | self.n_par = get_mdl_params([self.model]).shape[1] 161 | 162 | 163 | def synchronize_with_server(self, server, w_glob_keys): 164 | # W_client = W_server 165 | self.model = copy.deepcopy(server.model) 166 | self.W = {name: value for name, value in self.model.named_parameters()} 167 | 168 | def compute_bias(self): 169 | if self.args.method == 'ditto': 170 | cld_mdl_param = torch.tensor(get_mdl_params([self.model], self.n_par)[0], dtype=torch.float32, device=self.args.device) 171 | self.state_params_diff = self.args.mu * (-cld_mdl_param) 172 | 173 | 174 | def train_cnn(self, w_glob_keys, server, last): 175 | 176 | self.model.train() 177 | 178 | if self.args.method == 'ditto': 179 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum, 180 | weight_decay=self.args.weigh_delay) 181 | 182 | local_eps = self.args.local_ep 183 | 184 | # train and update 185 | epoch_loss = [] 186 | 187 | for iter in range(local_eps): 188 | 189 | trn_gen_iter = self.trn_gen.__iter__() 190 | batch_loss = [] 191 | 192 
| for i in range(self.local_epoch): 193 | 194 | images, labels = trn_gen_iter.__next__() 195 | images, labels = images.to(self.args.device), labels.to(self.args.device) 196 | 197 | optimizer.zero_grad() 198 | log_probs = self.model(images) 199 | loss_f_i = self.loss_func(log_probs, labels.reshape(-1).long()) 200 | 201 | loss = loss_f_i 202 | 203 | loss.backward() 204 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm) 205 | optimizer.step() 206 | batch_loss.append(loss.item()) 207 | 208 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 209 | 210 | return sum(epoch_loss) / len(epoch_loss) 211 | 212 | def train_cnn_local(self, w_glob_keys, server, last): 213 | 214 | self.model_local.train() 215 | 216 | optimizer = torch.optim.SGD(self.model_local.parameters(), lr=self.args.lr, momentum=self.args.momentum, 217 | weight_decay=self.args.weigh_delay + self.args.mu) 218 | 219 | local_eps = self.args.local_ep 220 | 221 | # train and update 222 | epoch_loss = [] 223 | 224 | for iter in range(local_eps): 225 | 226 | trn_gen_iter = self.trn_gen.__iter__() 227 | batch_loss = [] 228 | 229 | for i in range(self.local_epoch): 230 | 231 | images, labels = trn_gen_iter.__next__() 232 | images, labels = images.to(self.args.device), labels.to(self.args.device) 233 | 234 | optimizer.zero_grad() 235 | log_probs = self.model_local(images) 236 | loss_f_i = self.loss_func(log_probs, labels.reshape(-1).long()) 237 | 238 | local_par_list = None 239 | for param in self.model_local.parameters(): 240 | if not isinstance(local_par_list, torch.Tensor): 241 | # Initially nothing to concatenate 242 | local_par_list = param.reshape(-1) 243 | else: 244 | local_par_list = torch.cat((local_par_list, param.reshape(-1)), 0) 245 | 246 | loss_algo = torch.sum(local_par_list * self.state_params_diff) 247 | 248 | 249 | loss = loss_f_i + loss_algo 250 | 251 | loss.backward() 252 | torch.nn.utils.clip_grad_norm_(parameters=self.model_local.parameters(), 
max_norm=max_norm) 253 | optimizer.step() 254 | batch_loss.append(loss.item()) 255 | 256 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 257 | 258 | return sum(epoch_loss) / len(epoch_loss) 259 | 260 | 261 | def compute_weight_update(self, w_glob_keys, server, last=False): 262 | 263 | # Training mode 264 | self.model.train() 265 | self.train_loss = self.train_cnn(w_glob_keys, server, last) 266 | 267 | self.model_local.train() 268 | self.train_loss_local = self.train_cnn_local(w_glob_keys, server, last) 269 | 270 | 271 | @torch.no_grad() 272 | def evaluate(self, data_x, data_y, dataset_name): 273 | self.model_local.eval() 274 | # testing 275 | test_loss = 0 276 | acc_overall = 0 277 | n_tst = data_x.shape[0] 278 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 279 | tst_gen_iter = tst_gen.__iter__() 280 | for i in range(int(np.ceil(n_tst / self.args.bs))): 281 | data, target = tst_gen_iter.__next__() 282 | data, target = data.to(self.args.device), target.to(self.args.device) 283 | log_probs = self.model_local(data) 284 | # sum up batch loss 285 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 286 | # get the index of the max log-probability 287 | log_probs = log_probs.cpu().detach().numpy() 288 | log_probs = np.argmax(log_probs, axis=1).reshape(-1) 289 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 290 | batch_correct = np.sum(log_probs == target) 291 | acc_overall += batch_correct 292 | ''' 293 | y_pred = log_probs.data.max(1, keepdim=True)[1] 294 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 295 | ''' 296 | 297 | test_loss /= n_tst 298 | accuracy = 100.00 * acc_overall / n_tst 299 | return accuracy, test_loss 300 | 301 | 302 | ''' 303 | ------------------------------------------------------------------------------------------------------------------------ 304 | 305 | 
------------------------------------------------------------------------------------------------------------------------ 306 | ''' 307 | 308 | 309 | class Server(DistributedTrainingDevice): 310 | 311 | def __init__(self, model, args, n_cls): 312 | super().__init__(model, args) 313 | 314 | # Parameters 315 | self.W = {name: value for name, value in self.model.named_parameters()} 316 | self.local_epoch = 0 317 | self.n_cls = n_cls 318 | 319 | def aggregate_weight_updates(self, clients, iter, aggregation="mean"): 320 | 321 | # Warning: Note that K is different for unbalanced dataset 322 | self.local_epoch = clients[0].local_epoch 323 | # dW = aggregate(dW_i, i=1,..,n) 324 | if aggregation == "mean": 325 | average(target=self.W, sources=[client.W for client in clients]) 326 | 327 | 328 | 329 | @torch.no_grad() 330 | def evaluate(self, data_x, data_y, dataset_name): 331 | self.model.eval() 332 | # testing 333 | test_loss = 0 334 | acc_overall = 0 335 | n_tst = data_x.shape[0] 336 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 337 | tst_gen_iter = tst_gen.__iter__() 338 | for i in range(int(np.ceil(n_tst / self.args.bs))): 339 | data, target = tst_gen_iter.__next__() 340 | data, target = data.to(self.args.device), target.to(self.args.device) 341 | log_probs = self.model(data) 342 | # sum up batch loss 343 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 344 | # get the index of the max log-probability 345 | log_probs = log_probs.cpu().detach().numpy() 346 | log_probs = np.argmax(log_probs, axis=1).reshape(-1) 347 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 348 | batch_correct = np.sum(log_probs == target) 349 | acc_overall += batch_correct 350 | ''' 351 | y_pred = log_probs.data.max(1, keepdim=True)[1] 352 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 353 | ''' 354 | 355 | test_loss /= n_tst 356 | accuracy = 100.00 
* acc_overall / n_tst 357 | return accuracy, test_loss 358 | 359 | 360 | 361 | -------------------------------------------------------------------------------- /code_FedCR/models/distributed_training_utilsSR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import math 4 | from torch import nn, autograd 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from utils.utils_dataset import Dataset 8 | from models.Nets_VIB import KL_between_normals 9 | import itertools 10 | import torch.nn.functional as F 11 | 12 | max_norm = 10 13 | 14 | def add(target, source): 15 | for name in target: 16 | target[name].data += source[name].data.clone() 17 | 18 | 19 | def add_mome(target, source, beta_): 20 | for name in target: 21 | target[name].data = (beta_ * target[name].data + source[name].data.clone()) 22 | 23 | 24 | def add_mome2(target, source1, source2, beta_1, beta_2): 25 | for name in target: 26 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 27 | 28 | 29 | def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3): 30 | for name in target: 31 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() + beta_3 * source3[name].data.clone() 32 | 33 | def add_2(target, source1, source2, beta_1, beta_2): 34 | for name in target: 35 | target[name].data += beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 36 | 37 | def scale(target, scaling): 38 | for name in target: 39 | target[name].data = scaling * target[name].data.clone() 40 | 41 | 42 | def scale_ts(target, source, scaling): 43 | for name in target: 44 | target[name].data = scaling * source[name].data.clone() 45 | 46 | 47 | def subtract(target, source): 48 | for name in target: 49 | target[name].data -= source[name].data.clone() 50 | 51 | 52 | def subtract_(target, minuend, subtrahend): 53 | for name in target: 54 
| target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone() 55 | 56 | 57 | def average(target, sources): 58 | for name in target: 59 | target[name].data = torch.mean(torch.stack([source[name].data for source in sources]), dim=0).clone() 60 | 61 | 62 | def weighted_average(target, sources, weights): 63 | for name in target: 64 | summ = torch.sum(weights) 65 | n = len(sources) 66 | modify = [weight / summ * n for weight in weights] 67 | target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]), 68 | dim=0).clone() 69 | 70 | 71 | def computer_norm(source1, source2): 72 | diff_norm = 0 73 | 74 | for name in source1: 75 | diff_source = source1[name].data.clone() - source2[name].data.clone() 76 | diff_norm += torch.pow(torch.norm(diff_source),2) 77 | 78 | return (torch.pow(diff_norm, 0.5)) 79 | 80 | def majority_vote(target, sources, lr): 81 | for name in target: 82 | threshs = torch.stack([torch.max(source[name].data) for source in sources]) 83 | mask = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0).sign() 84 | target[name].data = (lr * mask).clone() 85 | 86 | 87 | def get_mdl_params(model_list, n_par=None): 88 | if n_par == None: 89 | exp_mdl = model_list[0] 90 | n_par = 0 91 | for name, param in exp_mdl.named_parameters(): 92 | n_par += len(param.data.reshape(-1)) 93 | 94 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 95 | for i, mdl in enumerate(model_list): 96 | idx = 0 97 | for name, param in mdl.named_parameters(): 98 | temp = param.data.cpu().numpy().reshape(-1) 99 | param_mat[i, idx:idx + len(temp)] = temp 100 | idx += len(temp) 101 | return np.copy(param_mat) 102 | 103 | 104 | def get_other_params(model_list, n_par=None): 105 | if n_par == None: 106 | exp_mdl = model_list[0] 107 | n_par = 0 108 | for name in exp_mdl: 109 | n_par += len(exp_mdl[name].data.reshape(-1)) 110 | 111 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 112 | 
for i, mdl in enumerate(model_list): 113 | idx = 0 114 | for name in mdl: 115 | temp = mdl[name].data.cpu().numpy().reshape(-1) 116 | param_mat[i, idx:idx + len(temp)] = temp 117 | idx += len(temp) 118 | return np.copy(param_mat) 119 | 120 | 121 | 122 | 123 | 124 | 125 | class DistributedTrainingDevice(object): 126 | ''' 127 | A distributed training device (Client or Server) 128 | data : a pytorch dataset consisting datapoints (x,y) 129 | model : a pytorch neural net f mapping x -> f(x)=y_ 130 | hyperparameters : a python dict containing all hyperparameters 131 | ''' 132 | 133 | def __init__(self, model, args): 134 | self.model = model 135 | self.args = args 136 | self.loss_func = nn.CrossEntropyLoss() 137 | 138 | class Client(DistributedTrainingDevice): 139 | 140 | def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0): 141 | super().__init__(model, args) 142 | 143 | self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), 144 | batch_size=self.args.local_bs, shuffle=True) 145 | 146 | self.tst_x = tst_x 147 | self.tst_y = tst_y 148 | self.n_cls = n_cls 149 | self.id = id_num 150 | self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs)) 151 | # Parameters 152 | self.W = {name: value for name, value in self.model.named_parameters()} 153 | 154 | self.state_params_diff = 0.0 155 | self.train_loss = 0.0 156 | self.n_par = get_mdl_params([self.model]).shape[1] 157 | 158 | 159 | def synchronize_with_server(self, server, w_glob_keys): 160 | # W_client = W_server 161 | 162 | if self.args.method != 'fedSR': 163 | for name in self.W: 164 | if name in w_glob_keys: 165 | self.W[name].data = server.W[name].data.clone() 166 | 167 | else: 168 | 169 | self.W = {name: value for name, value in server.model.named_parameters()} 170 | 171 | 172 | def train_cnn(self, w_glob_keys, server, last): 173 | 174 | self.model.train() 175 | 176 | 177 | optimizer = torch.optim.SGD(self.model.parameters(), 
lr=self.args.lr, momentum=self.args.momentum, 178 | weight_decay=self.args.weigh_delay) 179 | 180 | #.add_param_group({'params':[self.model.r_mu,self.model.r_sigma,self.model.C]}) 181 | 182 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1) 183 | 184 | local_eps = self.args.local_ep 185 | if last: 186 | if self.args.method =='fedSR': 187 | local_eps= self.args.last_local_ep 188 | if 'CIFAR100' in self.args.dataset: 189 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 190 | elif 'CIFAR10' in self.args.dataset: 191 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 192 | elif 'EMNIST' in self.args.dataset: 193 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1]] 194 | elif 'FMNIST' in self.args.dataset: 195 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 196 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 197 | 198 | elif 'maml' in self.args.method: 199 | local_eps = self.args.last_local_ep / 2 200 | w_glob_keys = [] 201 | else: 202 | local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep) 203 | 204 | # all other methods update all parameters simultaneously 205 | else: 206 | for name, param in self.model.named_parameters(): 207 | param.requires_grad = True 208 | 209 | # train and update 210 | epoch_loss = [] 211 | 212 | 213 | self.dir_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device) 214 | self.dir_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype = torch.float32, device = self.args.device) 215 | 216 | for iter in range(local_eps): 217 | 218 | if last: 219 | for name, param in self.model.named_parameters(): 220 | if name in w_glob_keys: 221 | param.requires_grad = False 222 | else: 223 | param.requires_grad = True 224 | 225 | loss_by_epoch = [] 226 | accuracy_by_epoch = [] 227 | 228 | trn_gen_iter = self.trn_gen.__iter__() 229 | batch_loss = [] 230 | 231 | for i in 
range(self.local_epoch): 232 | 233 | images, labels = trn_gen_iter.__next__() 234 | images, labels = images.to(self.args.device), labels.to(self.args.device) 235 | labels = labels.reshape(-1).long() 236 | 237 | batch_size = images.size()[0] 238 | 239 | encoder_Z_distr, decoder_logits, regL2R= self.model(images, self.args.num_avg_train) 240 | 241 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 242 | 243 | loss = nn.CrossEntropyLoss(reduction='none') 244 | decoder_logits = decoder_logits.permute(1, 2, 0) 245 | cross_entropy_loss = loss(decoder_logits, labels[:, None].expand(-1, self.args.num_avg_train)) 246 | # estimate E_{eps in N(0, 1)} [log q(y | z)] 247 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 248 | minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0) 249 | 250 | 251 | r_sigma_softplus = F.softplus(self.model.r_sigma) 252 | 253 | 254 | r_mu = self.model.r_mu[labels] 255 | r_sigma = r_sigma_softplus[labels] 256 | z_mu_scaled = encoder_Z_distr[0]*self.model.C 257 | z_sigma_scaled = encoder_Z_distr[1]*self.model.C 258 | regCMI = torch.log(r_sigma) - torch.log(z_sigma_scaled) + \ 259 | (z_sigma_scaled**2+(z_mu_scaled-r_mu)**2)/(2*r_sigma**2) - 0.5 260 | 261 | regCMI = regCMI.sum(1).mean() 262 | regL2R = regL2R / len(labels) 263 | 264 | total_loss = torch.mean(minusI_ZY_bound) + self.args.CMI * regCMI + self.args.L2R * regL2R 265 | 266 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 267 | accuracy = torch.mean((prediction == labels).float()) 268 | 269 | optimizer.zero_grad() 270 | total_loss.backward() 271 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm) 272 | optimizer.step() 273 | 274 | loss_by_epoch.append(total_loss.item()) 275 | accuracy_by_epoch.append(accuracy.item()) 276 | 277 | scheduler.step() 278 | epoch_loss.append(sum(loss_by_epoch) / len(loss_by_epoch)) 279 | 280 | 281 | 282 | return sum(epoch_loss) / len(epoch_loss) 283 | 284 | def 
compute_weight_update(self, w_glob_keys, server, last=False): 285 | 286 | # Training mode 287 | self.model.train() 288 | 289 | # W = SGD(W, D) 290 | self.train_loss = self.train_cnn(w_glob_keys, server, last) 291 | 292 | 293 | 294 | @torch.no_grad() 295 | def evaluate(self, data_x, data_y, dataset_name): 296 | self.model.eval() 297 | # testing 298 | I_ZX_bound_by_epoch_test = [] 299 | I_ZY_bound_by_epoch_test = [] 300 | loss_by_epoch_test = [] 301 | accuracy_by_epoch_test = [] 302 | 303 | n_tst = data_x.shape[0] 304 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 305 | tst_gen_iter = tst_gen.__iter__() 306 | for i in range(int(np.ceil(n_tst / self.args.bs))): 307 | data, target = tst_gen_iter.__next__() 308 | data, target = data.to(self.args.device), target.to(self.args.device) 309 | target = target.reshape(-1).long() 310 | batch_size = data.size()[0] 311 | prior_Z_distr = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size,self.args.dimZ).to(self.args.device) 312 | encoder_Z_distr, decoder_logits, regL2R = self.model(data, self.args.num_avg) 313 | 314 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 315 | loss = nn.CrossEntropyLoss(reduction='none') 316 | decoder_logits = decoder_logits.permute(1, 2, 0) 317 | cross_entropy_loss = loss(decoder_logits, target[:, None].expand(-1, self.args.num_avg)) 318 | 319 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 320 | 321 | I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr)) 322 | minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0) 323 | total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test) 324 | 325 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 326 | accuracy_test = torch.mean((prediction == target).float()) 327 | 328 | I_ZX_bound_by_epoch_test.append(I_ZX_bound_test.item()) 329 | 
I_ZY_bound_by_epoch_test.append(minusI_ZY_bound_test.item()) 330 | 331 | loss_by_epoch_test.append(total_loss_test.item()) 332 | accuracy_by_epoch_test.append(accuracy_test.item()) 333 | 334 | I_ZX = np.mean(I_ZX_bound_by_epoch_test) 335 | I_ZY = np.mean(I_ZY_bound_by_epoch_test) 336 | loss_test = np.mean(loss_by_epoch_test) 337 | accuracy_test = np.mean(accuracy_by_epoch_test) 338 | accuracy_test = 100.00 * accuracy_test 339 | return accuracy_test, loss_test 340 | 341 | ''' 342 | ------------------------------------------------------------------------------------------------------------------------ 343 | 344 | ------------------------------------------------------------------------------------------------------------------------ 345 | ''' 346 | 347 | 348 | class Server(DistributedTrainingDevice): 349 | 350 | def __init__(self, model, args, n_cls): 351 | super().__init__(model, args) 352 | 353 | # Parameters 354 | self.W = {name: value for name, value in self.model.named_parameters()} 355 | self.local_epoch = 0 356 | self.n_cls = n_cls 357 | if self.args.method == 'FedCR': 358 | self.dir_global_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device) 359 | self.dir_global_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype = torch.float32, device = self.args.device) 360 | 361 | 362 | def aggregate_weight_updates(self, clients, iter, aggregation="mean"): 363 | 364 | # Warning: Note that K is different for unbalanced dataset 365 | self.local_epoch = clients[0].local_epoch 366 | # dW = aggregate(dW_i, i=1,..,n) 367 | if aggregation == "mean": 368 | average(target=self.W, sources=[client.W for client in clients]) 369 | 370 | 371 | 372 | 373 | 374 | @torch.no_grad() 375 | def evaluate(self, data_x, data_y, dataset_name): 376 | self.model.eval() 377 | # testing 378 | I_ZX_bound_by_epoch_test = [] 379 | I_ZY_bound_by_epoch_test = [] 380 | loss_by_epoch_test = [] 381 | accuracy_by_epoch_test = [] 382 | 383 | n_tst = 
data_x.shape[0] 384 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 385 | tst_gen_iter = tst_gen.__iter__() 386 | for i in range(int(np.ceil(n_tst / self.args.bs))): 387 | data, target = tst_gen_iter.__next__() 388 | data, target = data.to(self.args.device), target.to(self.args.device) 389 | target = target.reshape(-1).long() 390 | batch_size = data.size()[0] 391 | prior_Z_distr = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size,self.args.dimZ).to(self.args.device) 392 | encoder_Z_distr, decoder_logit, regL2R = self.model(data, self.args.num_avg) 393 | 394 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 395 | loss = nn.CrossEntropyLoss(reduction='none') 396 | decoder_logits = decoder_logits.permute(1, 2, 0) 397 | cross_entropy_loss = loss(decoder_logits, target[:, None].expand(-1, self.args.num_avg)) 398 | 399 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 400 | 401 | I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr)) 402 | minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0) 403 | total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test) 404 | 405 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 406 | accuracy_test = torch.mean((prediction == target).float()) 407 | 408 | I_ZX_bound_by_epoch_test.append(I_ZX_bound_test.item()) 409 | I_ZY_bound_by_epoch_test.append(minusI_ZY_bound_test.item()) 410 | 411 | loss_by_epoch_test.append(total_loss_test.item()) 412 | accuracy_by_epoch_test.append(accuracy_test.item()) 413 | 414 | I_ZX = np.mean(I_ZX_bound_by_epoch_test) 415 | I_ZY = np.mean(I_ZY_bound_by_epoch_test) 416 | loss_test = np.mean(loss_by_epoch_test) 417 | accuracy_test = np.mean(accuracy_by_epoch_test) 418 | accuracy_test = 100.00 * accuracy_test 419 | return accuracy_test, loss_test 420 | 421 | 422 | 423 | 
-------------------------------------------------------------------------------- /code_FedCR/models/distributed_training_utils_PAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import math 4 | from torch import nn, autograd 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from utils.utils_dataset import Dataset 8 | from models.Nets_VIB import KL_between_normals 9 | import itertools 10 | import torch.nn.functional as F 11 | 12 | 13 | max_norm = 10 14 | 15 | def add(target, source): 16 | for name in target: 17 | target[name].data += source[name].data.clone() 18 | 19 | 20 | def add_mome(target, source, beta_): 21 | for name in target: 22 | target[name].data = (beta_ * target[name].data + source[name].data.clone()) 23 | 24 | 25 | def add_mome2(target, source1, source2, beta_1, beta_2): 26 | for name in target: 27 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 28 | 29 | 30 | def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3): 31 | for name in target: 32 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() + beta_3 * source3[name].data.clone() 33 | 34 | def add_2(target, source1, source2, beta_1, beta_2): 35 | for name in target: 36 | target[name].data += beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 37 | 38 | def scale(target, scaling): 39 | for name in target: 40 | target[name].data = scaling * target[name].data.clone() 41 | 42 | 43 | def scale_ts(target, source, scaling): 44 | for name in target: 45 | target[name].data = scaling * source[name].data.clone() 46 | 47 | 48 | def subtract(target, source): 49 | for name in target: 50 | target[name].data -= source[name].data.clone() 51 | 52 | 53 | def subtract_(target, minuend, subtrahend): 54 | for name in target: 55 | target[name].data = minuend[name].data.clone() - 
subtrahend[name].data.clone() 56 | 57 | 58 | def average(target, sources): 59 | for name in target: 60 | target[name].data = torch.mean(torch.stack([source[name].data for source in sources]), dim=0).clone() 61 | 62 | 63 | def weighted_average(target, sources, weights): 64 | for name in target: 65 | summ = torch.sum(weights) 66 | n = len(sources) 67 | modify = [weight / summ * n for weight in weights] 68 | target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]), 69 | dim=0).clone() 70 | 71 | 72 | def computer_norm(source1, source2): 73 | diff_norm = 0 74 | 75 | for name in source1: 76 | diff_source = source1[name].data.clone() - source2[name].data.clone() 77 | diff_norm += torch.pow(torch.norm(diff_source),2) 78 | 79 | return (torch.pow(diff_norm, 0.5)) 80 | 81 | def majority_vote(target, sources, lr): 82 | for name in target: 83 | threshs = torch.stack([torch.max(source[name].data) for source in sources]) 84 | mask = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0).sign() 85 | target[name].data = (lr * mask).clone() 86 | 87 | 88 | def get_mdl_params(model_list, n_par=None): 89 | if n_par == None: 90 | exp_mdl = model_list[0] 91 | n_par = 0 92 | for name, param in exp_mdl.named_parameters(): 93 | n_par += len(param.data.reshape(-1)) 94 | 95 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 96 | for i, mdl in enumerate(model_list): 97 | idx = 0 98 | for name, param in mdl.named_parameters(): 99 | temp = param.data.cpu().numpy().reshape(-1) 100 | param_mat[i, idx:idx + len(temp)] = temp 101 | idx += len(temp) 102 | return np.copy(param_mat) 103 | 104 | 105 | def get_other_params(model_list, n_par=None): 106 | if n_par == None: 107 | exp_mdl = model_list[0] 108 | n_par = 0 109 | for name in exp_mdl: 110 | n_par += len(exp_mdl[name].data.reshape(-1)) 111 | 112 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 113 | for i, mdl in enumerate(model_list): 114 | idx = 
0 115 | for name in mdl: 116 | temp = mdl[name].data.cpu().numpy().reshape(-1) 117 | param_mat[i, idx:idx + len(temp)] = temp 118 | idx += len(temp) 119 | return np.copy(param_mat) 120 | 121 | 122 | 123 | def personalized_classifier(args, client, clients): 124 | 125 | #penalty method 126 | class Model_quadratic_programming(): 127 | def __init__(self): 128 | self.coe_client = torch.rand(int(len(clients)) -1, dtype=torch.float32, device=args.device, requires_grad=True) 129 | self.coe_client_last = 1 - torch.sum(self.coe_client) 130 | 131 | def final_solve(self): 132 | 133 | A3 = 0; A3_ = 0; A4 = 0; 134 | for i in range(int(len(clients))): 135 | for j in range(int(len(clients))): 136 | 137 | C1 = client.P_y * torch.squeeze(client.F_x) - clients[i].P_y * torch.squeeze(clients[i].F_x) 138 | A2 = ( C1 @ C1.t()).trace() 139 | 140 | if j < int(len(clients)) -1: 141 | A3 = A3 + self.coe_client[j] * A2 142 | else: 143 | A3 = A3 + self.coe_client_last * A2 144 | 145 | if i < int(len(clients)) - 1: 146 | A3_ = A3_ + self.coe_client[i] * A3 147 | else: 148 | A3_ = A3_ + self.coe_client_last * A3 149 | 150 | if i < int(len(clients)) - 1: 151 | A4 = A4 + self.coe_client[i] * clients[i].Var_ 152 | else: 153 | A4 = A4 + self.coe_client_last * clients[i].Var_ 154 | 155 | # penalty method 156 | final_solve = torch.sum(A4 + A3_ + F.relu(-self.coe_client)) + F.relu(-self.coe_client_last) # coe_client > 0 penalty method 157 | #final_solve = torch.sum(A5) 158 | 159 | return final_solve 160 | 161 | Model_quadratic = Model_quadratic_programming() 162 | 163 | lr = 0.1 164 | for i in range(50): 165 | loss = Model_quadratic.final_solve() 166 | loss.backward() 167 | 168 | Model_quadratic.coe_client.data.sub_(lr * Model_quadratic.coe_client.grad) 169 | 170 | Model_quadratic.coe_client.grad.zero_() 171 | Model_quadratic.coe_client_last = 1 - torch.sum(Model_quadratic.coe_client) 172 | 173 | coe_client_all = torch.rand(int(len(clients)), dtype=torch.float32, device=args.device, 
requires_grad=False) 174 | coe_client_all[:-1] = Model_quadratic.coe_client.clone().detach() 175 | coe_client_all[-1:] = Model_quadratic.coe_client_last.clone().detach() 176 | 177 | return coe_client_all 178 | 179 | 180 | 181 | 182 | class DistributedTrainingDevice(object): 183 | ''' 184 | A distributed training device (Client or Server) 185 | data : a pytorch dataset consisting datapoints (x,y) 186 | model : a pytorch neural net f mapping x -> f(x)=y_ 187 | hyperparameters : a python dict containing all hyperparameters 188 | ''' 189 | 190 | def __init__(self, model, args): 191 | self.model = model 192 | self.args = args 193 | self.loss_func = nn.CrossEntropyLoss() 194 | 195 | class Client(DistributedTrainingDevice): 196 | 197 | def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0): 198 | super().__init__(model, args) 199 | 200 | self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), 201 | batch_size=self.args.local_bs, shuffle=True) 202 | 203 | self.tst_x = tst_x 204 | self.tst_y = tst_y 205 | self.n_cls = n_cls 206 | self.id = id_num 207 | self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs)) 208 | # Parameters 209 | self.W = {name: value for name, value in self.model.named_parameters()} 210 | 211 | self.lo_fecture_output = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device) 212 | 213 | self.P_y = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device) 214 | self.F_x = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device) 215 | self.Var = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device) 216 | self.Var_ = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device) 217 | 218 | self.state_params_diff = 0.0 219 | self.train_loss = 0.0 220 | self.n_par = get_mdl_params([self.model]).shape[1] 221 | 222 | 223 | def synchronize_with_server(self, server, 
w_glob_keys): 224 | # W_client = W_server 225 | 226 | if self.args.method != 'fedavg': 227 | for name in self.W: 228 | if name in w_glob_keys: 229 | self.W[name].data = server.W[name].data.clone() 230 | else: 231 | self.model = copy.deepcopy(server.model) 232 | self.W = {name: value for name, value in self.model.named_parameters()} 233 | 234 | 235 | def train_cnn(self, w_glob_keys, server, last): 236 | 237 | self.model.train() 238 | 239 | local_eps = self.args.local_ep 240 | if last: 241 | if self.args.method =='fedavg': 242 | local_eps= self.args.last_local_ep 243 | if 'CIFAR100' in self.args.dataset: 244 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 245 | elif 'CIFAR10' in self.args.dataset: 246 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 247 | elif 'MNIST' in self.args.dataset: 248 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1]] 249 | elif 'FMNIST' in self.args.dataset: 250 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 251 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 252 | else: 253 | local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep) 254 | 255 | # all other methods update all parameters simultaneously 256 | else: 257 | for name, param in self.model.named_parameters(): 258 | param.requires_grad = True 259 | 260 | 261 | # train and update 262 | epoch_loss = [] 263 | 264 | 265 | for iter in range(local_eps): 266 | flag=0 267 | head_eps = local_eps - self.args.local_rep_ep 268 | # for FedRep, first do local epochs for the head 269 | if (iter < head_eps and self.args.method == 'fedPAC') or last: 270 | for name, param in self.model.named_parameters(): 271 | if name in w_glob_keys: 272 | param.requires_grad = False 273 | else: 274 | param.requires_grad = True 275 | 276 | # then do local epochs for the representation 277 | elif (iter >= head_eps and self.args.method == 'fedPAC') and not last: 278 | flag =1 279 | for name, param in 
self.model.named_parameters(): 280 | if name in w_glob_keys: 281 | param.requires_grad = True 282 | else: 283 | param.requires_grad = False 284 | 285 | trn_gen_iter = self.trn_gen.__iter__() 286 | batch_loss = [] 287 | 288 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum, 289 | weight_decay=self.args.weigh_delay) 290 | 291 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1) 292 | 293 | 294 | for i in range(self.local_epoch): 295 | 296 | images, labels = trn_gen_iter.__next__() 297 | images, labels = images.to(self.args.device), labels.to(self.args.device) 298 | labels = labels.reshape(-1).long() 299 | 300 | optimizer.zero_grad() 301 | log_probs, fecture_output = self.model(images) 302 | 303 | 304 | loss_f_i = self.loss_func(log_probs, labels) 305 | 306 | if flag==1: 307 | for cls in range(len(labels)): 308 | if cls == 0: 309 | dir_g_fecture_output = server.fecture_output[labels[cls]].clone().detach() 310 | else: 311 | dir_g_fecture_output = torch.cat((dir_g_fecture_output, server.fecture_output[labels[cls]].clone().detach()), 0).clone().detach() 312 | 313 | R_i = torch.norm(fecture_output - dir_g_fecture_output) / len(labels) 314 | else: 315 | R_i =0 316 | 317 | loss = loss_f_i + self.args.beta_PAC * R_i 318 | 319 | loss.backward() 320 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm) 321 | optimizer.step() 322 | batch_loss.append(loss.item()) 323 | 324 | scheduler.step() 325 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 326 | 327 | return sum(epoch_loss) / len(epoch_loss) 328 | 329 | 330 | def local_feature(self): 331 | 332 | trn_gen_iter = self.trn_gen.__iter__() 333 | batch_loss = [] 334 | 335 | for i in range(self.local_epoch): 336 | images, labels = trn_gen_iter.__next__() 337 | images, labels = images.to(self.args.device), labels.to(self.args.device) 338 | labels = labels.reshape(-1).long() 339 | 340 | log_probs, fecture_output = 
self.model(images) 341 | 342 | for cls in range(len(labels)): 343 | if self.lo_fecture_output[labels[cls]].equal(torch.zeros(1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)): 344 | self.lo_fecture_output[labels[cls]] = fecture_output[cls].clone().detach() 345 | 346 | self.Var[labels[cls]] = (fecture_output[cls].t().clone().detach()) @(fecture_output[cls].clone().detach()) 347 | else: 348 | self.lo_fecture_output[labels[cls]] = (fecture_output[cls].clone().detach() + self.lo_fecture_output[labels[cls]]) / 2 349 | self.Var[labels[cls]] = (self.Var[labels[cls]] + (fecture_output[cls].clone().detach()) @(fecture_output[cls].clone().detach()) ) / 2 350 | 351 | self.P_y[labels[cls]] = self.P_y[labels[cls]] + 1 352 | 353 | self.P_y = self.P_y / torch.sum(self.P_y) 354 | self.F_x = self.lo_fecture_output.clone().detach() 355 | self.Var_ = torch.sum(self.P_y * self.Var.trace() - (self.P_y * torch.squeeze(self.F_x)) ** 2) 356 | 357 | def compute_weight_update(self, w_glob_keys, server, last=False): 358 | 359 | # Training mode 360 | self.model.train() 361 | 362 | # W = SGD(W, D) 363 | self.train_loss = self.train_cnn(w_glob_keys, server, last) 364 | 365 | 366 | @torch.no_grad() 367 | def evaluate(self, data_x, data_y, dataset_name): 368 | self.model.eval() 369 | # testing 370 | test_loss = 0 371 | acc_overall = 0 372 | n_tst = data_x.shape[0] 373 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 374 | tst_gen_iter = tst_gen.__iter__() 375 | for i in range(int(np.ceil(n_tst / self.args.bs))): 376 | data, target = tst_gen_iter.__next__() 377 | data, target = data.to(self.args.device), target.to(self.args.device) 378 | log_probs, fecture_output = self.model(data) 379 | # sum up batch loss 380 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 381 | # get the index of the max log-probability 382 | log_probs = log_probs.cpu().detach().numpy() 383 | 
log_probs = np.argmax(log_probs, axis=1).reshape(-1) 384 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 385 | batch_correct = np.sum(log_probs == target) 386 | acc_overall += batch_correct 387 | ''' 388 | y_pred = log_probs.data.max(1, keepdim=True)[1] 389 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 390 | ''' 391 | 392 | test_loss /= n_tst 393 | accuracy = 100.00 * acc_overall / n_tst 394 | return accuracy, test_loss 395 | 396 | 397 | ''' 398 | ------------------------------------------------------------------------------------------------------------------------ 399 | 400 | ------------------------------------------------------------------------------------------------------------------------ 401 | ''' 402 | 403 | 404 | class Server(DistributedTrainingDevice): 405 | 406 | def __init__(self, model, args, n_cls): 407 | super().__init__(model, args) 408 | 409 | # Parameters 410 | self.W = {name: value for name, value in self.model.named_parameters()} 411 | self.local_epoch = 0 412 | self.n_cls = n_cls 413 | 414 | self.fecture_output = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device) 415 | 416 | 417 | def aggregate_weight_updates(self, clients, iter, aggregation="mean"): 418 | 419 | # Warning: Note that K is different for unbalanced dataset 420 | self.local_epoch = clients[0].local_epoch 421 | # dW = aggregate(dW_i, i=1,..,n) 422 | if aggregation == "mean": 423 | average(target=self.W, sources=[client.W for client in clients]) 424 | 425 | 426 | def global_feature_centroids(self, clients): 427 | 428 | 429 | for cls in range(self.n_cls): 430 | 431 | clients_all_fecture_output = True 432 | 433 | for i in range(len(clients)): 434 | 435 | if clients[i].lo_fecture_output[cls].equal(torch.zeros(1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)): 436 | pass 437 | elif isinstance(clients_all_fecture_output, bool): 438 | clients_all_fecture_output = 
clients[i].lo_fecture_output[cls].clone().detach() 439 | 440 | else: 441 | clients_all_fecture_output = torch.cat((clients_all_fecture_output, clients[i].lo_fecture_output[cls].clone().detach()), 0).clone().detach() 442 | 443 | if not isinstance(clients_all_fecture_output, bool): 444 | 445 | self.fecture_output[cls] = torch.mean(clients_all_fecture_output) 446 | 447 | 448 | def Get_classifier(self, clients, w_glob_keys): 449 | for client in clients: 450 | modify = personalized_classifier(args = self.args, client= client, clients=clients) 451 | for name in self.W: 452 | if name not in w_glob_keys: 453 | client.W[name].data = torch.mean( 454 | torch.stack([m * source.W[name].data for source, m in zip(clients, modify)]), 455 | dim=0).clone() 456 | 457 | @torch.no_grad() 458 | def evaluate(self, data_x, data_y, dataset_name): 459 | self.model.eval() 460 | # testing 461 | test_loss = 0 462 | acc_overall = 0 463 | n_tst = data_x.shape[0] 464 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 465 | tst_gen_iter = tst_gen.__iter__() 466 | for i in range(int(np.ceil(n_tst / self.args.bs))): 467 | data, target = tst_gen_iter.__next__() 468 | data, target = data.to(self.args.device), target.to(self.args.device) 469 | log_probs = self.model(data) 470 | # sum up batch loss 471 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 472 | # get the index of the max log-probability 473 | log_probs = log_probs.cpu().detach().numpy() 474 | log_probs = np.argmax(log_probs, axis=1).reshape(-1) 475 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 476 | batch_correct = np.sum(log_probs == target) 477 | acc_overall += batch_correct 478 | ''' 479 | y_pred = log_probs.data.max(1, keepdim=True)[1] 480 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 481 | ''' 482 | 483 | test_loss /= n_tst 484 | accuracy = 100.00 * acc_overall / n_tst 485 | return 
accuracy, test_loss 486 | 487 | 488 | 489 | -------------------------------------------------------------------------------- /code_FedCR/utils/utils_dataset.py: -------------------------------------------------------------------------------- 1 | from scipy import io 2 | import numpy as np 3 | import torchvision 4 | from torchvision import transforms 5 | import torch 6 | from torch.utils import data 7 | import os 8 | 9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 10 | import ssl 11 | 12 | ssl._create_default_https_context = ssl._create_unverified_context 13 | 14 | 15 | class DatasetObject: 16 | def __init__(self, dataset, n_client, seed, rule, class_main, data_path, frac_data=0.7, dir_alpha=0): 17 | self.dataset = dataset 18 | self.n_client = n_client 19 | self.seed = seed 20 | self.rule = rule 21 | self.frac_data = frac_data 22 | self.dir_alpha = dir_alpha 23 | self.class_main = class_main 24 | 25 | self.name = "Data%s_nclient%d_seed%d_rule%s_alpha%s_class_main%d_frac_data%s" % ( 26 | self.dataset, self.n_client, self.seed, self.rule, self.dir_alpha, self.class_main, self.frac_data) 27 | self.data_path = data_path 28 | self.set_data() 29 | 30 | def set_data(self): 31 | # Prepare data if not ready 32 | if not os.path.exists('Data/%s' % (self.name)): 33 | # Get Raw data 34 | if self.dataset == 'MNIST': 35 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 36 | trnset = torchvision.datasets.MNIST(root='Data/', train=True, download=True, transform=transform) 37 | tstset = torchvision.datasets.MNIST(root='Data/', train=False, download=True, transform=transform) 38 | 39 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=60000, shuffle=False, num_workers=1) 40 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1) 41 | self.channels = 1; 42 | self.width = 28; 43 | self.height = 28; 44 | self.n_cls = 10; 45 | 46 | if self.dataset == 'FMNIST': 47 | transform = 
transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) 48 | trnset = torchvision.datasets.FashionMNIST(root='Data/', train=True, download=True, transform=transform) 49 | tstset = torchvision.datasets.FashionMNIST(root='Data/', train=False, download=True, 50 | transform=transform) 51 | 52 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=60000, shuffle=False, num_workers=1) 53 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1) 54 | self.channels = 1; 55 | self.width = 28; 56 | self.height = 28; 57 | self.n_cls = 10; 58 | 59 | if self.dataset == 'CIFAR10': 60 | transform = transforms.Compose([transforms.ToTensor()]) 61 | 62 | trnset = torchvision.datasets.CIFAR10(root='Data/', train=True, download=True, transform=transform) 63 | tstset = torchvision.datasets.CIFAR10(root='Data/', train=False, download=True, transform=transform) 64 | 65 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=1) 66 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1) 67 | self.channels = 3; 68 | self.width = 32; 69 | self.height = 32; 70 | self.n_cls = 10; 71 | 72 | if self.dataset == 'CIFAR100': 73 | transform = transforms.Compose([transforms.ToTensor()]) 74 | 75 | trnset = torchvision.datasets.CIFAR100(root='Data/', train=True, download=True, transform=transform) 76 | tstset = torchvision.datasets.CIFAR100(root='Data/', train=False, download=True, transform=transform) 77 | 78 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=1) 79 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1) 80 | self.channels = 3; 81 | self.width = 32; 82 | self.height = 32; 83 | self.n_cls = 100; 84 | 85 | if self.dataset == 'MNIST' or self.dataset == 'FMNIST' or self.dataset == 'CIFAR10' or self.dataset == 'CIFAR100': 86 | trn_itr = trn_load.__iter__() 87 | 
tst_itr = tst_load.__iter__() 88 | # labels are of shape (n_data,) 89 | trn_x, trn_y = trn_itr.__next__() 90 | tst_x, tst_y = tst_itr.__next__() 91 | 92 | trn_x = trn_x.numpy(); 93 | trn_y = trn_y.numpy().reshape(-1, 1) 94 | tst_x = tst_x.numpy(); 95 | tst_y = tst_y.numpy().reshape(-1, 1) 96 | 97 | concat_datasets_x = np.concatenate((trn_x, tst_x), axis=0) 98 | concat_datasets_y = np.concatenate((trn_y, tst_y), axis=0) 99 | 100 | self.trn_x = trn_x 101 | self.trn_y = trn_y 102 | self.tst_x = tst_x 103 | self.tst_y = tst_y 104 | 105 | if self.dataset == 'EMNIST': 106 | emnist = io.loadmat("Data/Raw/matlab/emnist-letters.mat") 107 | # load training dataset 108 | x_train = emnist["dataset"][0][0][0][0][0][0] 109 | x_train = x_train.astype(np.float32) 110 | 111 | # load training labels 112 | y_train = emnist["dataset"][0][0][0][0][0][1] - 1 # make first class 0 113 | 114 | # take first 10 classes of letters 115 | trn_idx = np.where(y_train < 10)[0] 116 | 117 | y_train = y_train[trn_idx] 118 | x_train = x_train[trn_idx] 119 | 120 | mean_x = np.mean(x_train) 121 | std_x = np.std(x_train) 122 | 123 | # load test dataset 124 | x_test = emnist["dataset"][0][0][1][0][0][0] 125 | x_test = x_test.astype(np.float32) 126 | 127 | # load test labels 128 | y_test = emnist["dataset"][0][0][1][0][0][1] - 1 # make first class 0 129 | 130 | tst_idx = np.where(y_test < 10)[0] 131 | 132 | y_test = y_test[tst_idx] 133 | x_test = x_test[tst_idx] 134 | 135 | x_train = x_train.reshape((-1, 1, 28, 28)) 136 | x_test = x_test.reshape((-1, 1, 28, 28)) 137 | 138 | # normalise train and test features 139 | 140 | trn_x = (x_train - mean_x) / std_x 141 | trn_y = y_train 142 | 143 | tst_x = (x_test - mean_x) / std_x 144 | tst_y = y_test 145 | 146 | self.channels = 1; self.width = 28; self.height = 28; self.n_cls = 10; 147 | concat_datasets_x = np.concatenate((trn_x, tst_x), axis=0) 148 | concat_datasets_y = np.concatenate((trn_y, tst_y), axis=0) 149 | 150 | self.trn_x = trn_x 151 | self.trn_y = trn_y 
152 | self.tst_x = tst_x 153 | self.tst_y = tst_y 154 | 155 | 156 | # Shuffle Data 157 | np.random.seed(self.seed) 158 | rand_perm = np.random.permutation(len(concat_datasets_y)) 159 | concat_datasets_x = concat_datasets_x[rand_perm] 160 | concat_datasets_y = concat_datasets_y[rand_perm] 161 | 162 | assert len(concat_datasets_y) % self.n_client == 0 163 | n_data_per_clnt = int((len(concat_datasets_y)) / self.n_client) 164 | clnt_data_list = np.ones(self.n_client).astype(int) * n_data_per_clnt 165 | 166 | n_data_per_clnt_train = int(n_data_per_clnt * self.frac_data) 167 | n_data_per_clnt_tst = n_data_per_clnt - n_data_per_clnt_train 168 | clnt_data_list_train = np.ones(self.n_client).astype(int) * n_data_per_clnt_train 169 | clnt_data_list_tst = np.ones(self.n_client).astype(int) * n_data_per_clnt_tst 170 | ### 171 | 172 | cls_per_client = self.class_main 173 | n_cls = self.n_cls 174 | n_client = self.n_client 175 | 176 | # Distribute training datapoints 177 | idx_list = [np.where(concat_datasets_y == i)[0] for i in range(self.n_cls)] 178 | idx_count_list = [0 for i in range(self.n_cls)] 179 | cls_amount = np.asarray([len(idx_list[i]) for i in range(self.n_cls)]) 180 | n_data = np.sum(cls_amount) 181 | total_clnt_data_list = np.asarray([0 for i in range(n_client)]) 182 | clnt_cls_idx = [[[] for kk in range(n_cls)] for jj in range(n_client)] # Store the indeces of data points 183 | 184 | 185 | if self.rule == 'Dirichlet': 186 | cls_priors = np.random.dirichlet(alpha=[self.dir_alpha] * self.n_cls, size=self.n_client) 187 | prior_cumsum = np.cumsum(cls_priors, axis=1) 188 | 189 | concat_clnt_x = np.asarray( 190 | [np.zeros((clnt_data_list[clnt__], self.channels, self.height, self.width)).astype(np.float32) for 191 | clnt__ in range(self.n_client)]) 192 | concat_clnt_y = np.asarray( 193 | [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 194 | 195 | while (np.sum(clnt_data_list) != 0): 196 | curr_clnt = 
np.random.randint(self.n_client) 197 | # If current node is full resample a client 198 | if clnt_data_list[curr_clnt] <= 0: 199 | continue 200 | clnt_data_list[curr_clnt] -= 1 201 | curr_prior = prior_cumsum[curr_clnt] 202 | while True: 203 | cls_label = np.argmax(np.random.uniform() <= curr_prior) 204 | # Redraw class label if trn_y is out of that class 205 | if cls_amount[cls_label] <= 0: 206 | continue 207 | cls_amount[cls_label] -= 1 208 | 209 | concat_clnt_x[curr_clnt][clnt_data_list[curr_clnt]] = concat_datasets_x[idx_list[cls_label][cls_amount[cls_label]]] 210 | concat_clnt_y[curr_clnt][clnt_data_list[curr_clnt]] = concat_datasets_y[idx_list[cls_label][cls_amount[cls_label]]] 211 | 212 | break 213 | 214 | concat_clnt_x = np.asarray(concat_clnt_x) 215 | concat_clnt_y = np.asarray(concat_clnt_y) 216 | 217 | clnt_x = np.asarray( 218 | [np.zeros((clnt_data_list_train[clnt__], self.channels, self.height, self.width)).astype(np.float32) 219 | for 220 | clnt__ in range(self.n_client)]) 221 | clnt_y = np.asarray( 222 | [np.zeros((clnt_data_list_train[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 223 | tst_x = np.asarray( 224 | [np.zeros((clnt_data_list_tst[clnt__], self.channels, self.height, self.width)).astype(np.float32) 225 | for 226 | clnt__ in range(self.n_client)]) 227 | tst_y = np.asarray( 228 | [np.zeros((clnt_data_list_tst[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 229 | 230 | for jj in range(n_client): 231 | rand_perm = np.random.permutation(len(concat_clnt_y[jj])) 232 | concat_clnt_x[jj] = concat_clnt_x[jj][rand_perm] 233 | concat_clnt_y[jj] = concat_clnt_y[jj][rand_perm] 234 | 235 | clnt_x[jj] = concat_clnt_x[jj][:n_data_per_clnt_train, :, :, :] 236 | tst_x[jj] = concat_clnt_x[jj][n_data_per_clnt_train:, :, :, :] 237 | 238 | clnt_y[jj] = concat_clnt_y[jj][:n_data_per_clnt_train, :] 239 | tst_y[jj] = concat_clnt_y[jj][n_data_per_clnt_train:, :] 240 | 241 | 242 | cls_means = np.zeros((self.n_client, 
self.n_cls)) 243 | for clnt in range(self.n_client): 244 | for cls in range(self.n_cls): 245 | cls_means[clnt,cls] = np.mean(clnt_y[clnt]==cls) 246 | prior_real_diff = np.abs(cls_means-cls_priors) 247 | print('--- Max deviation from prior: %.4f' %np.max(prior_real_diff)) 248 | print('--- Min deviation from prior: %.4f' %np.min(prior_real_diff)) 249 | 250 | if self.rule == 'noniid': 251 | while np.sum(total_clnt_data_list) != n_data: 252 | # Still there are data to distibute 253 | # Get a random client that among the ones that has the least # of data with respect to totat data it is supposed to have 254 | min_amount = np.min(total_clnt_data_list - clnt_data_list) 255 | min_idx_list = np.where(total_clnt_data_list - clnt_data_list == min_amount)[0] 256 | np.random.shuffle(min_idx_list) 257 | cur_clnt = min_idx_list[0] 258 | print( 259 | 'Current client %d, total remaining amount %d' % (cur_clnt, n_data - np.sum(total_clnt_data_list))) 260 | 261 | # Get its class list 262 | cur_cls_list = np.asarray([(cur_clnt + jj) % n_cls for jj in range(cls_per_client)]) 263 | # Get the class that has minumum amount of data on the client 264 | cls_amounts = np.asarray([len(clnt_cls_idx[cur_clnt][jj]) for jj in range(n_cls)]) 265 | min_to_max = cur_cls_list[np.argsort(cls_amounts[cur_cls_list])] 266 | cur_idx = 0 267 | while cur_idx != len(min_to_max) and cls_amount[min_to_max[cur_idx]] == 0: 268 | cur_idx += 1 269 | if cur_idx == len(min_to_max): 270 | # This client is not full, it needs data but there is no class data left 271 | # Pick a random client and assign its data to this client 272 | while True: 273 | rand_clnt = np.random.randint(n_client) 274 | print('Random client %d' % rand_clnt) 275 | if rand_clnt == cur_clnt: # Pick a different client 276 | continue 277 | rand_clnt_cls = np.asarray([(rand_clnt + jj) % n_cls for jj in range(cls_per_client)]) 278 | # See if random client has an intersection class with the current client 279 | cur_list = np.asarray([(cur_clnt + jj) % 
n_cls for jj in range(cls_per_client)]) 280 | np.random.shuffle(cur_list) 281 | cls_idx = 0 282 | is_found = False 283 | while cls_idx != cls_per_client: 284 | if cur_list[cls_idx] in rand_clnt_cls and len( 285 | clnt_cls_idx[rand_clnt][cur_list[cls_idx]]) > 1: 286 | is_found = True 287 | break 288 | cls_idx += 1 289 | if not is_found: # No class intersection, choose another client 290 | continue 291 | found_cls = cur_list[cls_idx] 292 | # Assign this class instance to curr client 293 | total_clnt_data_list[cur_clnt] += 1 294 | total_clnt_data_list[rand_clnt] -= 1 295 | transfer_idx = clnt_cls_idx[rand_clnt][found_cls][-1] 296 | del clnt_cls_idx[rand_clnt][found_cls][-1] 297 | clnt_cls_idx[cur_clnt][found_cls].append(transfer_idx) 298 | # print('Class %d is transferred from %d to %d' %(found_cls, rand_clnt, cur_clnt)) 299 | break 300 | else: 301 | cur_cls = min_to_max[cur_idx] 302 | # Assign one data point from this class to the task 303 | total_clnt_data_list[cur_clnt] += 1 304 | cls_amount[cur_cls] -= 1 305 | clnt_cls_idx[cur_clnt][cur_cls].append(idx_list[cur_cls][cls_amount[cur_cls]]) 306 | # print('Chosen client: %d, chosen class: %d' %(cur_clnt, cur_cls)) 307 | 308 | for i in range(n_cls): 309 | assert 0 == cls_amount[i], 'Missing datapoints' 310 | assert n_data == np.sum(total_clnt_data_list), 'Missing datapoints' 311 | 312 | concat_clnt_x = np.asarray( 313 | [np.zeros((clnt_data_list[clnt__], self.channels, self.height, self.width)).astype(np.float32) for 314 | clnt__ in range(self.n_client)]) 315 | concat_clnt_y = np.asarray( 316 | [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 317 | 318 | clnt_x = np.asarray( 319 | [np.zeros((clnt_data_list_train[clnt__], self.channels, self.height, self.width)).astype(np.float32) for 320 | clnt__ in range(self.n_client)]) 321 | clnt_y = np.asarray( 322 | [np.zeros((clnt_data_list_train[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 323 | tst_x = 
np.asarray( 324 | [np.zeros((clnt_data_list_tst[clnt__], self.channels, self.height, self.width)).astype(np.float32) for 325 | clnt__ in range(self.n_client)]) 326 | tst_y = np.asarray( 327 | [np.zeros((clnt_data_list_tst[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)]) 328 | 329 | for jj in range(n_client): 330 | concat_clnt_x[jj] = concat_datasets_x[np.concatenate(clnt_cls_idx[jj]).astype(np.int32)] 331 | concat_clnt_y[jj] = concat_datasets_y[np.concatenate(clnt_cls_idx[jj]).astype(np.int32)] 332 | 333 | for jj in range(n_client): 334 | rand_perm = np.random.permutation(len(concat_clnt_y[jj])) 335 | concat_clnt_x[jj] = concat_clnt_x[jj][rand_perm] 336 | concat_clnt_y[jj] = concat_clnt_y[jj][rand_perm] 337 | 338 | clnt_x[jj] = concat_clnt_x[jj][:n_data_per_clnt_train, :, :, :] 339 | tst_x[jj] = concat_clnt_x[jj][n_data_per_clnt_train:, :, :, :] 340 | 341 | clnt_y[jj] = concat_clnt_y[jj][:n_data_per_clnt_train, :] 342 | tst_y[jj] = concat_clnt_y[jj][n_data_per_clnt_train:, :] 343 | 344 | self.clnt_x = clnt_x; 345 | self.clnt_y = clnt_y 346 | self.tst_x = tst_x; 347 | self.tst_y = tst_y 348 | 349 | # Save data 350 | os.mkdir('Data/%s' % (self.name)) 351 | 352 | np.save('Data/%s/clnt_x.npy' % (self.name), clnt_x) 353 | np.save('Data/%s/clnt_y.npy' % (self.name), clnt_y) 354 | 355 | np.save('Data/%s/tst_x.npy' % (self.name), tst_x) 356 | np.save('Data/%s/tst_y.npy' % (self.name), tst_y) 357 | 358 | if not os.path.exists('Model'): 359 | os.mkdir('Model') 360 | 361 | else: 362 | print("Data is already downloaded") 363 | 364 | self.clnt_x = np.load('Data/%s/clnt_x.npy' % (self.name)) 365 | self.clnt_y = np.load('Data/%s/clnt_y.npy' % (self.name)) 366 | self.n_client = len(self.clnt_x) 367 | 368 | self.tst_x = np.load('Data/%s/tst_x.npy' % (self.name)) 369 | self.tst_y = np.load('Data/%s/tst_y.npy' % (self.name)) 370 | 371 | if self.dataset == 'MNIST': 372 | self.channels = 1;self.width = 28;self.height = 28;self.n_cls = 10; 373 | if self.dataset == 
'FMNIST': 374 | self.channels = 1; self.width = 28; self.height = 28;self.n_cls = 10; 375 | if self.dataset == 'CIFAR10': 376 | self.channels = 3; self.width = 32;self.height = 32;self.n_cls = 10; 377 | if self.dataset == 'CIFAR100': 378 | self.channels = 3;self.width = 32;self.height = 32;self.n_cls = 100; 379 | if self.dataset == 'EMNIST': 380 | self.channels = 1; self.width = 28; self.height = 28; self.n_cls = 10; 381 | 382 | print('Class frequencies:') 383 | 384 | # train 385 | count = 0 386 | for clnt in range(self.n_client): 387 | print("Client %3d: " % clnt + 388 | ', '.join(["%.3f" % np.mean(self.clnt_y[clnt] == cls) for cls in range(self.n_cls)]) + 389 | ', Amount:%d' % self.clnt_y[clnt].shape[0]) 390 | count += self.clnt_y[clnt].shape[0] 391 | 392 | print('Total Amount:%d' % count) 393 | print('-----------------------------------------------------------') 394 | 395 | # test 396 | count = 0 397 | for clnt in range(self.n_client): 398 | print("Client %3d: " % clnt + 399 | ', '.join(["%.3f" % np.mean(self.tst_y[clnt] == cls) for cls in range(self.n_cls)]) + 400 | ', Amount:%d' % self.tst_y[clnt].shape[0]) 401 | count += self.tst_y[clnt].shape[0] 402 | 403 | print('Total Amount:%d' % count) 404 | print('-----------------------------------------------------------') 405 | 406 | 407 | class Dataset(torch.utils.data.Dataset): 408 | 409 | def __init__(self, data_x, data_y=True, train=False, dataset_name=''): 410 | self.name = dataset_name 411 | 412 | if self.name == 'MNIST' or self.name == 'EMNIST' or self.name == 'FMNIST': 413 | self.X_data = torch.tensor(data_x).float() 414 | self.y_data = data_y 415 | if not isinstance(data_y, bool): 416 | self.y_data = torch.tensor(data_y).float() 417 | 418 | elif self.name == 'CIFAR10': 419 | self.train = train 420 | self.transform = transforms.Compose([transforms.ToTensor()]) 421 | self.X_data = torch.tensor(data_x).float() 422 | self.y_data = data_y 423 | if not isinstance(data_y, bool): 424 | self.y_data = 
data_y.astype('float32') 425 | self.augment_transform = transforms.Compose([ 426 | transforms.ToPILImage(), 427 | transforms.RandomCrop(32, padding=4), 428 | transforms.RandomHorizontalFlip(), 429 | transforms.ToTensor(), 430 | transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])]) 431 | self.noaugmt_transform = transforms.Compose( 432 | [transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])]) 433 | 434 | elif self.name == 'CIFAR100': 435 | self.train = train 436 | self.transform = transforms.Compose([transforms.ToTensor()]) 437 | self.X_data = torch.tensor(data_x).float() 438 | self.y_data = data_y 439 | if not isinstance(data_y, bool): 440 | self.y_data = data_y.astype('float32') 441 | self.augment_transform = transforms.Compose([ 442 | transforms.ToPILImage(), 443 | transforms.RandomCrop(32, padding=4), 444 | transforms.RandomHorizontalFlip(), 445 | transforms.ToTensor(), 446 | transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])]) 447 | self.noaugmt_transform = transforms.Compose( 448 | [transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])]) 449 | 450 | def __len__(self): 451 | return len(self.X_data) 452 | 453 | def __getitem__(self, idx): 454 | 455 | if self.name == 'MNIST' or self.name == 'EMNIST' or self.name == 'FMNIST': 456 | X = self.X_data[idx] 457 | if isinstance(self.y_data, bool): 458 | return X 459 | else: 460 | y = self.y_data[idx] 461 | return X, y 462 | 463 | elif self.name == 'CIFAR10': 464 | img = self.X_data[idx] 465 | if self.train: 466 | img = self.augment_transform(img) 467 | else: 468 | img = self.noaugmt_transform(img) 469 | if isinstance(self.y_data, bool): 470 | return img 471 | else: 472 | y = self.y_data[idx] 473 | return img, y 474 | 475 | elif self.name == 'CIFAR100': 476 | img = self.X_data[idx] 477 | if self.train: 478 | img = self.augment_transform(img) 479 | else: 480 | img = self.noaugmt_transform(img) 481 | if 
isinstance(self.y_data, bool): 482 | return img 483 | else: 484 | y = self.y_data[idx] 485 | 486 | return img, y 487 | 488 | if __name__ == '__main__': 489 | data_path = 'Folder/' # The folder to save Data & Model 490 | n_client = 100 491 | data_obj = DatasetObject(dataset='EMNIST', n_client=n_client, seed=23, rule = 'noniid', class_main=5, data_path=data_path, 492 | frac_data=0.7, dir_alpha=1) 493 | tst_x = data_obj.clnt_x 494 | tst_y = data_obj.clnt_y 495 | 496 | ''' 497 | trn_gen = data.DataLoader(Dataset(tst_x[0], tst_y[0], train=True, dataset_name='CIFAR10'), batch_size=32, shuffle=True) 498 | 499 | tst_gen_iter = trn_gen.__iter__() 500 | for i in range(int(np.ceil(300 / 32))): 501 | data, target = tst_gen_iter.__next__() 502 | print(target.shape) 503 | targets = target.reshape(-1) 504 | print(targets.shape) 505 | print(target[2]) 506 | print(target[2,0]) 507 | print(target[2].type(torch.long)) 508 | print(target[target[1].type(torch.long)]) 509 | print(target[:, None].shape) 510 | print(target[:, None].expand(-1,-1, 5)) 511 | 512 | ''' -------------------------------------------------------------------------------- /code_FedCR/models/distributed_training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import math 4 | from torch import nn, autograd 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from utils.utils_dataset import Dataset 8 | from models.Nets_VIB import KL_between_normals 9 | import itertools 10 | import torch.nn.functional as F 11 | 12 | max_norm = 10 13 | 14 | def add(target, source): 15 | for name in target: 16 | target[name].data += source[name].data.clone() 17 | 18 | 19 | def add_mome(target, source, beta_): 20 | for name in target: 21 | target[name].data = (beta_ * target[name].data + source[name].data.clone()) 22 | 23 | 24 | def add_mome2(target, source1, source2, beta_1, beta_2): 25 | for name in target: 26 | target[name].data = beta_1 * 
source1[name].data.clone() + beta_2 * source2[name].data.clone() 27 | 28 | 29 | def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3): 30 | for name in target: 31 | target[name].data = beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() + beta_3 * source3[name].data.clone() 32 | 33 | def add_2(target, source1, source2, beta_1, beta_2): 34 | for name in target: 35 | target[name].data += beta_1 * source1[name].data.clone() + beta_2 * source2[name].data.clone() 36 | 37 | def scale(target, scaling): 38 | for name in target: 39 | target[name].data = scaling * target[name].data.clone() 40 | 41 | 42 | def scale_ts(target, source, scaling): 43 | for name in target: 44 | target[name].data = scaling * source[name].data.clone() 45 | 46 | 47 | def subtract(target, source): 48 | for name in target: 49 | target[name].data -= source[name].data.clone() 50 | 51 | 52 | def subtract_(target, minuend, subtrahend): 53 | for name in target: 54 | target[name].data = minuend[name].data.clone() - subtrahend[name].data.clone() 55 | 56 | 57 | def average(target, sources): 58 | for name in target: 59 | target[name].data = torch.mean(torch.stack([source[name].data for source in sources]), dim=0).clone() 60 | 61 | 62 | def weighted_average(target, sources, weights): 63 | for name in target: 64 | summ = torch.sum(weights) 65 | n = len(sources) 66 | modify = [weight / summ * n for weight in weights] 67 | target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]), 68 | dim=0).clone() 69 | 70 | 71 | def computer_norm(source1, source2): 72 | diff_norm = 0 73 | 74 | for name in source1: 75 | diff_source = source1[name].data.clone() - source2[name].data.clone() 76 | diff_norm += torch.pow(torch.norm(diff_source),2) 77 | 78 | return (torch.pow(diff_norm, 0.5)) 79 | 80 | def majority_vote(target, sources, lr): 81 | for name in target: 82 | threshs = torch.stack([torch.max(source[name].data) for source in sources]) 83 
| mask = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0).sign() 84 | target[name].data = (lr * mask).clone() 85 | 86 | 87 | def get_mdl_params(model_list, n_par=None): 88 | if n_par == None: 89 | exp_mdl = model_list[0] 90 | n_par = 0 91 | for name, param in exp_mdl.named_parameters(): 92 | n_par += len(param.data.reshape(-1)) 93 | 94 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 95 | for i, mdl in enumerate(model_list): 96 | idx = 0 97 | for name, param in mdl.named_parameters(): 98 | temp = param.data.cpu().numpy().reshape(-1) 99 | param_mat[i, idx:idx + len(temp)] = temp 100 | idx += len(temp) 101 | return np.copy(param_mat) 102 | 103 | 104 | def get_other_params(model_list, n_par=None): 105 | if n_par == None: 106 | exp_mdl = model_list[0] 107 | n_par = 0 108 | for name in exp_mdl: 109 | n_par += len(exp_mdl[name].data.reshape(-1)) 110 | 111 | param_mat = np.zeros((len(model_list), n_par)).astype('float32') 112 | for i, mdl in enumerate(model_list): 113 | idx = 0 114 | for name in mdl: 115 | temp = mdl[name].data.cpu().numpy().reshape(-1) 116 | param_mat[i, idx:idx + len(temp)] = temp 117 | idx += len(temp) 118 | return np.copy(param_mat) 119 | 120 | 121 | 122 | def product_of_experts_two(q_distr, p_distr): 123 | mu_q, sigma_q = q_distr 124 | mu_p, sigma_p = p_distr #Standard Deviation 125 | 126 | poe_var = torch.sqrt( torch.div((sigma_q**2 * sigma_p**2), (sigma_q**2 + sigma_p**2 + 1e-32)) ) 127 | 128 | poe_u = torch.div( (mu_p * sigma_q**2 + mu_q * sigma_p**2), (sigma_q**2 + sigma_p**2 + 1e-32) ) 129 | 130 | return poe_u, poe_var 131 | 132 | 133 | def product_of_experts(q_distr_set): 134 | mu_q_set, sigma_q_set = q_distr_set 135 | tmp1 = 1.0 136 | for i in range(len(mu_q_set)): 137 | tmp1 = tmp1 + (1.0 / (sigma_q_set[i] ** 2)) 138 | poe_var = torch.sqrt(1.0 / tmp1) 139 | tmp2 = 0.0 140 | for i in range(len(mu_q_set)): 141 | tmp2 = tmp2 + torch.div(mu_q_set[i], sigma_q_set[i]**2) 142 | poe_u = torch.div(tmp2, tmp1) 
143 | return poe_u, poe_var 144 | 145 | 146 | ''' 147 | ### DEFINE NETWORK-RELATED FUNCTIONS 148 | def product_of_experts_two(q_distr, p_distr): 149 | mu_q, sigma_q = q_distr 150 | mu_p, sigma_p = p_distr #Standard Deviation 151 | tmp1 = (1.0 / (sigma_q**2 + 1e-8)) + (1.0 / (sigma_p**2 + 1e-8)) 152 | poe_var = torch.sqrt(1.0 / tmp1) 153 | tmp2 = torch.div(mu_q, sigma_q**2) + torch.div(mu_p, sigma_p**2) 154 | poe_u = torch.div(tmp2, tmp1) 155 | return poe_u, poe_var 156 | 157 | 158 | def product_of_experts(q_distr_set): 159 | mu_q_set, sigma_q_set = q_distr_set 160 | tmp1 = 1.0 161 | for i in range(len(mu_q_set)): 162 | tmp1 = tmp1 + (1.0 / (sigma_q_set[i] ** 2 + 1e-8)) 163 | poe_var = torch.sqrt(1.0 / tmp1) 164 | tmp2 = 0.0 165 | for i in range(len(mu_q_set)): 166 | tmp2 = tmp2 + torch.div(mu_q_set[i], sigma_q_set[i]**2) 167 | poe_u = torch.div(tmp2, tmp1) 168 | return poe_u, poe_var 169 | 170 | 171 | 172 | def product_of_experts_copy(mask_, mu_set_, logvar_set_): 173 | tmp = 1. 174 | for m in range(len(mu_set_)): 175 | tmp += torch.reshape(mask_[:, m], [-1, 1]) * torch.div(1., torch.exp(logvar_set_[m])) 176 | poe_var = torch.div(1., tmp) 177 | poe_logvar = torch.log(poe_var) 178 | tmp = 0. 
179 | for m in range(len(mu_set_)): 180 | tmp += torch.reshape(mask_[:, m], [-1, 1]) * torch.div(1., torch.exp(logvar_set_[m])) * mu_set_[m] 181 | poe_mu = poe_var * tmp 182 | return poe_mu, poe_logvar 183 | ''' 184 | 185 | 186 | 187 | 188 | class DistributedTrainingDevice(object): 189 | ''' 190 | A distributed training device (Client or Server) 191 | data : a pytorch dataset consisting datapoints (x,y) 192 | model : a pytorch neural net f mapping x -> f(x)=y_ 193 | hyperparameters : a python dict containing all hyperparameters 194 | ''' 195 | 196 | def __init__(self, model, args): 197 | self.model = model 198 | self.args = args 199 | self.loss_func = nn.CrossEntropyLoss() 200 | 201 | class Client(DistributedTrainingDevice): 202 | 203 | def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0): 204 | super().__init__(model, args) 205 | 206 | self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name), 207 | batch_size=self.args.local_bs, shuffle=True) 208 | 209 | self.tst_x = tst_x 210 | self.tst_y = tst_y 211 | self.n_cls = n_cls 212 | self.id = id_num 213 | self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs)) 214 | # Parameters 215 | self.W = {name: value for name, value in self.model.named_parameters()} 216 | 217 | self.state_params_diff = 0.0 218 | self.train_loss = 0.0 219 | self.n_par = get_mdl_params([self.model]).shape[1] 220 | 221 | 222 | def synchronize_with_server(self, server, w_glob_keys): 223 | # W_client = W_server 224 | 225 | if self.args.method != 'fedavg' and self.args.method != 'ditto': 226 | for name in self.W: 227 | if name in w_glob_keys: 228 | self.W[name].data = server.W[name].data.clone() 229 | else: 230 | self.model = copy.deepcopy(server.model) 231 | self.W = {name: value for name, value in self.model.named_parameters()} 232 | 233 | def compute_bias(self): 234 | if self.args.method == 'ditto': 235 | cld_mdl_param = torch.tensor(get_mdl_params([self.model], 
self.n_par)[0], dtype=torch.float32, device=self.args.device) 236 | self.state_params_diff = self.args.mu * (-cld_mdl_param) 237 | 238 | def train_cnn(self, w_glob_keys, server, last): 239 | 240 | self.model.train() 241 | 242 | if self.args.method == 'ditto': 243 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum, 244 | weight_decay=self.args.weigh_delay + self.args.mu) 245 | else: 246 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum, 247 | weight_decay=self.args.weigh_delay) 248 | 249 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1) 250 | 251 | local_eps = self.args.local_ep 252 | if last: 253 | if self.args.method =='fedavg' or self.args.method == 'ditto': 254 | local_eps= self.args.last_local_ep 255 | if 'CIFAR100' in self.args.dataset: 256 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 257 | elif 'CIFAR10' in self.args.dataset: 258 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 259 | elif 'EMNIST' in self.args.dataset: 260 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1]] 261 | elif 'FMNIST' in self.args.dataset: 262 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]] 263 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys)) 264 | 265 | elif 'maml' in self.args.method: 266 | local_eps = self.args.last_local_ep / 2 267 | w_glob_keys = [] 268 | else: 269 | local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep) 270 | 271 | # all other methods update all parameters simultaneously 272 | else: 273 | for name, param in self.model.named_parameters(): 274 | param.requires_grad = True 275 | 276 | # train and update 277 | epoch_loss = [] 278 | 279 | if self.args.method == 'FedCR': 280 | 281 | self.dir_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device) 282 | self.dir_Z_sigma = torch.ones(self.n_cls, 
1, self.args.dimZ, dtype = torch.float32, device = self.args.device) 283 | 284 | for iter in range(local_eps): 285 | 286 | if last: 287 | for name, param in self.model.named_parameters(): 288 | if name in w_glob_keys: 289 | param.requires_grad = False 290 | else: 291 | param.requires_grad = True 292 | 293 | loss_by_epoch = [] 294 | accuracy_by_epoch = [] 295 | 296 | trn_gen_iter = self.trn_gen.__iter__() 297 | batch_loss = [] 298 | 299 | for i in range(self.local_epoch): 300 | dir_g_Z_u = torch.zeros(1, self.args.dimZ, dtype=torch.float32, 301 | device=self.args.device) 302 | dir_g_Z_sigma = torch.ones(1, self.args.dimZ, dtype=torch.float32, 303 | device=self.args.device) 304 | 305 | images, labels = trn_gen_iter.__next__() 306 | images, labels = images.to(self.args.device), labels.to(self.args.device) 307 | labels = labels.reshape(-1).long() 308 | 309 | batch_size = images.size()[0] 310 | 311 | #prior_Z_distr_standard = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size, self.args.dimZ).to(self.args.device) 312 | 313 | for cls in range(len(labels)): 314 | if cls == 0: 315 | dir_g_Z_u = server.dir_global_Z_u[labels[cls]].clone().detach() 316 | dir_g_Z_sigma = server.dir_global_Z_sigma[labels[cls]].clone().detach() 317 | else: 318 | dir_g_Z_u = torch.cat((dir_g_Z_u, server.dir_global_Z_u[labels[cls]].clone().detach()), 0).clone().detach() 319 | dir_g_Z_sigma = torch.cat((dir_g_Z_sigma, server.dir_global_Z_sigma[labels[cls]].clone().detach()), 0).clone().detach() 320 | 321 | prior_Z_distr = dir_g_Z_u, dir_g_Z_sigma 322 | 323 | encoder_Z_distr, decoder_logits = self.model(images, self.args.num_avg_train) 324 | 325 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 326 | 327 | loss = nn.CrossEntropyLoss(reduction='none') 328 | decoder_logits = decoder_logits.permute(1, 2, 0) 329 | cross_entropy_loss = loss(decoder_logits, labels[:, None].expand(-1, self.args.num_avg_train)) 330 | 331 | # estimate E_{eps in N(0, 1)} [log q(y | z)] 
332 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 333 | 334 | I_ZX_bound = torch.mean(KL_between_normals(prior_Z_distr, encoder_Z_distr)) 335 | minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0) 336 | total_loss = torch.mean(minusI_ZY_bound + self.args.beta * I_ZX_bound) 337 | 338 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 339 | accuracy = torch.mean((prediction == labels).float()) 340 | 341 | optimizer.zero_grad() 342 | total_loss.backward() 343 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm) 344 | optimizer.step() 345 | 346 | loss_by_epoch.append(total_loss.item()) 347 | accuracy_by_epoch.append(accuracy.item()) 348 | 349 | if iter == local_eps - 1: 350 | for cls in range(len(labels)): 351 | if self.dir_Z_u[labels[cls]] .equal(torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)) and self.dir_Z_sigma[labels[cls]] .equal(torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)): 352 | self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]] = encoder_Z_distr[0][cls].clone().detach(), encoder_Z_distr[1][cls].clone().detach() 353 | else: 354 | q_distr = self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]] 355 | encoder_Z_distr_cls = encoder_Z_distr[0][cls].clone().detach(), encoder_Z_distr[1][cls].clone().detach() 356 | self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]] = product_of_experts_two(q_distr, encoder_Z_distr_cls) 357 | 358 | scheduler.step() 359 | epoch_loss.append(sum(loss_by_epoch) / len(loss_by_epoch)) 360 | 361 | 362 | else: 363 | for iter in range(local_eps): 364 | 365 | head_eps = local_eps - self.args.local_rep_ep 366 | # for FedRep, first do local epochs for the head 367 | if (iter < head_eps and self.args.method == 'fedrep') or last: 368 | for name, param in self.model.named_parameters(): 369 | if name in w_glob_keys: 370 | param.requires_grad = False 371 | else: 372 | param.requires_grad = True 373 
| 374 | # then do local epochs for the representation 375 | elif (iter == head_eps and self.args.method == 'fedrep') and not last: 376 | for name, param in self.model.named_parameters(): 377 | if name in w_glob_keys: 378 | param.requires_grad = True 379 | else: 380 | param.requires_grad = False 381 | 382 | # all other methods update all parameters simultaneously 383 | elif self.args.method != 'fedrep': 384 | for name, param in self.model.named_parameters(): 385 | param.requires_grad = True 386 | 387 | trn_gen_iter = self.trn_gen.__iter__() 388 | batch_loss = [] 389 | 390 | for i in range(self.local_epoch): 391 | 392 | images, labels = trn_gen_iter.__next__() 393 | images, labels = images.to(self.args.device), labels.to(self.args.device) 394 | 395 | optimizer.zero_grad() 396 | log_probs = self.model(images) 397 | loss_f_i = self.loss_func(log_probs, labels.reshape(-1).long()) 398 | 399 | local_par_list = None 400 | for param in self.model.parameters(): 401 | if not isinstance(local_par_list, torch.Tensor): 402 | # Initially nothing to concatenate 403 | local_par_list = param.reshape(-1) 404 | else: 405 | local_par_list = torch.cat((local_par_list, param.reshape(-1)), 0) 406 | 407 | loss_algo = torch.sum(local_par_list * self.state_params_diff) 408 | 409 | if self.args.method == 'ditto': 410 | loss = loss_f_i + loss_algo 411 | else: 412 | loss = loss_f_i 413 | 414 | loss.backward() 415 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm) 416 | optimizer.step() 417 | batch_loss.append(loss.item()) 418 | 419 | scheduler.step() 420 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 421 | 422 | return sum(epoch_loss) / len(epoch_loss) 423 | 424 | def compute_weight_update(self, w_glob_keys, server, last=False): 425 | 426 | # Training mode 427 | self.model.train() 428 | 429 | # W = SGD(W, D) 430 | self.train_loss = self.train_cnn(w_glob_keys, server, last) 431 | 432 | 433 | 434 | @torch.no_grad() 435 | def evaluate_FedVIB(self, 
data_x, data_y, dataset_name): 436 | self.model.eval() 437 | # testing 438 | I_ZX_bound_by_epoch_test = [] 439 | I_ZY_bound_by_epoch_test = [] 440 | loss_by_epoch_test = [] 441 | accuracy_by_epoch_test = [] 442 | 443 | n_tst = data_x.shape[0] 444 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 445 | tst_gen_iter = tst_gen.__iter__() 446 | for i in range(int(np.ceil(n_tst / self.args.bs))): 447 | data, target = tst_gen_iter.__next__() 448 | data, target = data.to(self.args.device), target.to(self.args.device) 449 | target = target.reshape(-1).long() 450 | batch_size = data.size()[0] 451 | prior_Z_distr = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size,self.args.dimZ).to(self.args.device) 452 | encoder_Z_distr, decoder_logits = self.model(data, self.args.num_avg) 453 | 454 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 455 | loss = nn.CrossEntropyLoss(reduction='none') 456 | decoder_logits = decoder_logits.permute(1, 2, 0) 457 | cross_entropy_loss = loss(decoder_logits, target[:, None].expand(-1, self.args.num_avg)) 458 | 459 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 460 | 461 | I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr)) 462 | minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0) 463 | total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test) 464 | 465 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 466 | accuracy_test = torch.mean((prediction == target).float()) 467 | 468 | I_ZX_bound_by_epoch_test.append(I_ZX_bound_test.item()) 469 | I_ZY_bound_by_epoch_test.append(minusI_ZY_bound_test.item()) 470 | 471 | loss_by_epoch_test.append(total_loss_test.item()) 472 | accuracy_by_epoch_test.append(accuracy_test.item()) 473 | 474 | I_ZX = np.mean(I_ZX_bound_by_epoch_test) 475 | I_ZY = np.mean(I_ZY_bound_by_epoch_test) 476 | loss_test = 
np.mean(loss_by_epoch_test) 477 | accuracy_test = np.mean(accuracy_by_epoch_test) 478 | accuracy_test = 100.00 * accuracy_test 479 | return accuracy_test, loss_test 480 | 481 | 482 | @torch.no_grad() 483 | def evaluate(self, data_x, data_y, dataset_name): 484 | self.model.eval() 485 | # testing 486 | test_loss = 0 487 | acc_overall = 0 488 | n_tst = data_x.shape[0] 489 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 490 | tst_gen_iter = tst_gen.__iter__() 491 | for i in range(int(np.ceil(n_tst / self.args.bs))): 492 | data, target = tst_gen_iter.__next__() 493 | data, target = data.to(self.args.device), target.to(self.args.device) 494 | log_probs = self.model(data) 495 | # sum up batch loss 496 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 497 | # get the index of the max log-probability 498 | log_probs = log_probs.cpu().detach().numpy() 499 | log_probs = np.argmax(log_probs, axis=1).reshape(-1) 500 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 501 | batch_correct = np.sum(log_probs == target) 502 | acc_overall += batch_correct 503 | ''' 504 | y_pred = log_probs.data.max(1, keepdim=True)[1] 505 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 506 | ''' 507 | 508 | test_loss /= n_tst 509 | accuracy = 100.00 * acc_overall / n_tst 510 | return accuracy, test_loss 511 | 512 | 513 | ''' 514 | ------------------------------------------------------------------------------------------------------------------------ 515 | 516 | ------------------------------------------------------------------------------------------------------------------------ 517 | ''' 518 | 519 | 520 | class Server(DistributedTrainingDevice): 521 | 522 | def __init__(self, model, args, n_cls): 523 | super().__init__(model, args) 524 | 525 | # Parameters 526 | self.W = {name: value for name, value in self.model.named_parameters()} 527 | 
self.local_epoch = 0 528 | self.n_cls = n_cls 529 | if self.args.method == 'FedCR': 530 | self.dir_global_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device) 531 | self.dir_global_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype = torch.float32, device = self.args.device) 532 | 533 | 534 | def aggregate_weight_updates(self, clients, iter, aggregation="mean"): 535 | 536 | # Warning: Note that K is different for unbalanced dataset 537 | self.local_epoch = clients[0].local_epoch 538 | # dW = aggregate(dW_i, i=1,..,n) 539 | if aggregation == "mean": 540 | average(target=self.W, sources=[client.W for client in clients]) 541 | 542 | 543 | def global_POE(self, clients): 544 | 545 | dir_global_Z_u_copy = copy.deepcopy(self.dir_global_Z_u) 546 | dir_global_Z_sigma_copy = copy.deepcopy(self.dir_global_Z_sigma) 547 | 548 | for cls in range(self.n_cls): 549 | clients_all_Z_u = True 550 | clients_all_Z_sigma= True 551 | 552 | for i in range(len(clients)): 553 | 554 | if clients[i].dir_Z_u[cls].equal(torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)) and clients[i].dir_Z_sigma[cls].equal(torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)): 555 | pass 556 | elif isinstance(clients_all_Z_u, bool): 557 | clients_all_Z_u =clients[i].dir_Z_u[cls].clone().detach() 558 | clients_all_Z_sigma =clients[i].dir_Z_sigma[cls].clone().detach() 559 | else: 560 | clients_all_Z_u = torch.cat((clients_all_Z_u, clients[i].dir_Z_u[cls].clone().detach()), 0).clone().detach() 561 | clients_all_Z_sigma = torch.cat((clients_all_Z_sigma, clients[i].dir_Z_sigma[cls].clone().detach()), 0).clone().detach() 562 | 563 | if not isinstance(clients_all_Z_u, bool): 564 | clients_all_Z = clients_all_Z_u, clients_all_Z_sigma 565 | dir_global_Z_u_copy[cls], dir_global_Z_sigma_copy[cls] = product_of_experts(clients_all_Z) 566 | 567 | self.dir_global_Z_u[cls] = (1- self.args.beta2) * self.dir_global_Z_u[cls] + 
self.args.beta2 * dir_global_Z_u_copy[cls] 568 | self.dir_global_Z_sigma[cls] = (1- self.args.beta2) * self.dir_global_Z_sigma[cls] + self.args.beta2 * dir_global_Z_sigma_copy[cls] 569 | 570 | 571 | @torch.no_grad() 572 | def evaluate_FedVIB(self, data_x, data_y, dataset_name): 573 | self.model.eval() 574 | # testing 575 | I_ZX_bound_by_epoch_test = [] 576 | I_ZY_bound_by_epoch_test = [] 577 | loss_by_epoch_test = [] 578 | accuracy_by_epoch_test = [] 579 | 580 | n_tst = data_x.shape[0] 581 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 582 | tst_gen_iter = tst_gen.__iter__() 583 | for i in range(int(np.ceil(n_tst / self.args.bs))): 584 | data, target = tst_gen_iter.__next__() 585 | data, target = data.to(self.args.device), target.to(self.args.device) 586 | batch_size = data.size()[0] 587 | prior_Z_distr = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size,self.args.dimZ).to(self.args.device) 588 | encoder_Z_distr, decoder_logits = self.model(data, self.args.num_avg) 589 | 590 | decoder_logits_mean = torch.mean(decoder_logits, dim=0) 591 | loss = nn.CrossEntropyLoss(reduction='none') 592 | decoder_logits = decoder_logits.permute(1, 2, 0) 593 | cross_entropy_loss = loss(decoder_logits, target[:, None].expand(-1, self.args.num_avg)) 594 | 595 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1) 596 | 597 | I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr)) 598 | minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0) 599 | total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test) 600 | 601 | prediction = torch.max(decoder_logits_mean, dim=1)[1] 602 | accuracy_test = torch.mean((prediction == target).float()) 603 | 604 | I_ZX_bound_by_epoch_test.append(I_ZX_bound_test.item()) 605 | I_ZY_bound_by_epoch_test.append(minusI_ZY_bound_test.item()) 606 | 607 | 
loss_by_epoch_test.append(total_loss_test.item()) 608 | accuracy_by_epoch_test.append(accuracy_test.item()) 609 | 610 | I_ZX = np.mean(I_ZX_bound_by_epoch_test) 611 | I_ZY = np.mean(I_ZY_bound_by_epoch_test) 612 | loss_test = np.mean(loss_by_epoch_test) 613 | accuracy_test = np.mean(accuracy_by_epoch_test) 614 | accuracy_test = 100.00 * accuracy_test 615 | return accuracy_test, loss_test 616 | 617 | 618 | @torch.no_grad() 619 | def evaluate(self, data_x, data_y, dataset_name): 620 | self.model.eval() 621 | # testing 622 | test_loss = 0 623 | acc_overall = 0 624 | n_tst = data_x.shape[0] 625 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False) 626 | tst_gen_iter = tst_gen.__iter__() 627 | for i in range(int(np.ceil(n_tst / self.args.bs))): 628 | data, target = tst_gen_iter.__next__() 629 | data, target = data.to(self.args.device), target.to(self.args.device) 630 | log_probs = self.model(data) 631 | # sum up batch loss 632 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item() 633 | # get the index of the max log-probability 634 | log_probs = log_probs.cpu().detach().numpy() 635 | log_probs = np.argmax(log_probs, axis=1).reshape(-1) 636 | target = target.cpu().numpy().reshape(-1).astype(np.int32) 637 | batch_correct = np.sum(log_probs == target) 638 | acc_overall += batch_correct 639 | ''' 640 | y_pred = log_probs.data.max(1, keepdim=True)[1] 641 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 642 | ''' 643 | 644 | test_loss /= n_tst 645 | accuracy = 100.00 * acc_overall / n_tst 646 | return accuracy, test_loss 647 | 648 | 649 | 650 | --------------------------------------------------------------------------------