├── code_FedCR
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── Nets.cpython-37.pyc
│ │ ├── Nets.cpython-38.pyc
│ │ ├── NetsSR.cpython-37.pyc
│ │ ├── Update.cpython-37.pyc
│ │ ├── Nets_PAC.cpython-37.pyc
│ │ ├── Nets_VIB.cpython-37.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── resnet18.cpython-37.pyc
│ │ ├── resnet18.cpython-38.pyc
│ │ ├── Nets_VIB_2D.cpython-37.pyc
│ │ ├── distributed_training_utils.cpython-37.pyc
│ │ ├── distributed_training_utils.cpython-38.pyc
│ │ ├── distributed_training_utilsSR.cpython-37.pyc
│ │ ├── distributed_training_utils_2D.cpython-37.pyc
│ │ ├── distributed_training_utils_PAC.cpython-37.pyc
│ │ ├── distributed_training_utils_ditto.cpython-37.pyc
│ │ ├── distributed_training_utils_old2.cpython-37.pyc
│ │ └── distributed_training_utils_old3.cpython-37.pyc
│ ├── Update.py
│ ├── Nets.py
│ ├── Nets_PAC.py
│ ├── resnet18.py
│ ├── main_fedSR.py
│ ├── Nets_VIB.py
│ ├── NetsSR.py
│ ├── distributed_training_utils_ditto.py
│ ├── distributed_training_utilsSR.py
│ ├── distributed_training_utils_PAC.py
│ └── distributed_training_utils.py
├── utils
│ ├── __init__.py
│ ├── seed.py
│ ├── logg.py
│ ├── options.py
│ └── utils_dataset.py
├── .idea
│ ├── misc.xml
│ ├── inspectionProfiles
│ │ └── profiles_settings.xml
│ ├── modules.xml
│ ├── code.iml
│ └── workspace.xml
├── main_fed_ditto.py
├── main_fedSR.py
├── main_fed_local.py
├── main_fed_PAC.py
└── main_fed.py
├── requirement.txt
└── README.md
/code_FedCR/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.7
4 |
5 |
--------------------------------------------------------------------------------
/code_FedCR/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @python: 3.6
4 |
5 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | cachetools==4.1.1
2 | scikit-learn==0.23.2
3 | scipy==1.4.1
4 | torchvision==0.5.0
5 | torch==1.4.0
6 | numpy==1.18.1
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Nets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Nets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets.cpython-38.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/NetsSR.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/NetsSR.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Update.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Nets_PAC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_PAC.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Nets_VIB.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_VIB.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/resnet18.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/resnet18.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/resnet18.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/resnet18.cpython-38.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/Nets_VIB_2D.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/Nets_VIB_2D.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utilsSR.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utilsSR.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils_2D.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_2D.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils_PAC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_PAC.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils_ditto.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_ditto.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils_old2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_old2.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/models/__pycache__/distributed_training_utils_old3.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haozzh/FedCR/HEAD/code_FedCR/models/__pycache__/distributed_training_utils_old3.cpython-37.pyc
--------------------------------------------------------------------------------
/code_FedCR/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/code_FedCR/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/code_FedCR/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/code_FedCR/.idea/code.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/code_FedCR/utils/seed.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 |
def setup_seed(seed):
    """Seed every RNG used by the training scripts and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG, and PyTorch's CPU and CUDA generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms that may differ
    # run to run, which defeats `deterministic = True`. It must stay off for
    # reproducible experiments.
    torch.backends.cudnn.benchmark = False
--------------------------------------------------------------------------------
/code_FedCR/utils/logg.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to ``filename`` and the console.

    Args:
        filename: Path of the log file. It is opened with mode ``"w"``, so an
            existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: Logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.

    Calling this again for the same ``name`` replaces (and closes) the
    handlers installed by the previous call instead of stacking new ones, so
    every record is emitted exactly once. Handlers added by other code are
    left untouched.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Remove only the handlers a previous call of this function installed.
    for old in list(logger.handlers):
        if getattr(old, "_get_logger_owned", False):
            logger.removeHandler(old)
            old.close()

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    fh._get_logger_owned = True
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh._get_logger_owned = True
    logger.addHandler(sh)

    return logger
--------------------------------------------------------------------------------
/code_FedCR/models/Update.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import torch
6 | import copy
7 | from torch import nn, autograd
8 | from torch.utils.data import DataLoader, Dataset
9 |
class DatasetSplit(Dataset):
    """Dataset view exposing only the samples of ``dataset`` listed in ``idxs``.

    Position ``k`` of the view maps to position ``idxs[k]`` of the wrapped
    dataset. Each wrapped item is unpacked into ``(image, label)`` and
    returned as a 2-tuple.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise once so any iterable of indices can be indexed by position.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label
21 |
class LocalUpdate(object):
    """Runs plain local SGD for one client and reports the resulting weight change.

    Args:
        args: Namespace providing ``local_bs``, ``lr``, ``momentum``,
            ``weigh_delay`` (used as SGD weight decay), ``local_ep`` and
            ``device``.
        dataset: Dataset indexable by the entries of ``idxs``.
        idxs: Indices of ``dataset`` that belong to this client.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        client_split = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(client_split, batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs.

        Returns:
            A pair ``(deltas, mean_loss)``. ``deltas`` holds one tensor per
            parameter of ``net`` equal to (weights after training - weights
            before training). ``mean_loss`` is the average over epochs of
            each epoch's mean batch loss.
        """
        net.train()
        opts = self.args
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=opts.lr,
            momentum=opts.momentum,
            weight_decay=opts.weigh_delay,
        )
        # Snapshot of the starting weights, used to compute the update delta.
        start_params = list(copy.deepcopy(net).parameters())

        per_epoch_loss = []
        for _epoch in range(opts.local_ep):
            running = []
            for images, labels in self.ldr_train:
                images = images.to(opts.device)
                labels = labels.to(opts.device)

                net.zero_grad()
                loss = self.loss_func(net(images), labels)
                loss.backward()
                optimizer.step()
                running.append(loss.item())
            per_epoch_loss.append(sum(running) / len(running))

        # The delta is plain data; it never needs to be part of the autograd graph.
        with torch.no_grad():
            deltas = [now - before for now, before in zip(net.parameters(), start_params)]
        return deltas, sum(per_epoch_loss) / len(per_epoch_loss)
53 |
54 |
--------------------------------------------------------------------------------
/code_FedCR/utils/options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import argparse
6 |
def args_parser(argv=None):
    """Define and parse the command-line options shared by the FedCR scripts.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing a list
            lets tests and notebooks build an ``args`` object without touching
            the real command line.

    Returns:
        argparse.Namespace with one attribute per option below. Note that
        ``--sync`` is parsed as a *string* (``'True'`` by default), not a bool.
    """
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=500, help="rounds of training")
    parser.add_argument('--test_freq', type=int, default=1, help="frequency of test")
    parser.add_argument('--num_users', type=int, default=100, help="number of users")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients")
    parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs")
    parser.add_argument('--last_local_ep', type=int, default=10, help="the number of local epochs of last")
    # FedRep: ten local epochs train the local head, followed by one or five
    # epochs for the representation.
    parser.add_argument('--local_rep_ep', type=int, default=1, help="the number of local epochs of Fed_Rep's feature")
    parser.add_argument('--local_bs', type=int, default=48, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=10, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--globallr', type=float, default=1, help="Global learning rate")
    parser.add_argument('--momentum', type=float, default=0, help="local SGD momentum (default: 0.0)")
    # The option keeps its historical spelling because callers read
    # ``args.weigh_delay``; it is passed to SGD as ``weight_decay``.
    parser.add_argument('--weigh_delay', type=float, default=0, help="local SGD weight decay")
    parser.add_argument('--mu', type=float, default=0, help='the value of Ditto')
    parser.add_argument('--mu1', type=float, default=0, help='the value_L1 of Factorization')
    parser.add_argument('--tau1', type=float, default=0, help='the value_cosine similarity of Factorization')
    parser.add_argument('--lr_decay', type=float, default=0.997, help='the value of lr_decay')
    parser.add_argument('--alpha', default=0, type=float, help='the value of alpha for Fed_VIB')
    parser.add_argument('--beta', default=0.001, type=float, help='the value of beta for Fed_VIB')
    parser.add_argument('--sync', type=str, default='True', help='If the model is synchronized for FedCR')
    parser.add_argument('--beta_PAC', default=1, type=float, help='the value of beta for FedPAC')
    parser.add_argument('--beta2', default=0, type=float, help='the value of beta2 for Z')
    parser.add_argument('--dimZ', default=256, type=int, help='dimension of encoding Z in Fed_VIB')
    parser.add_argument('--dimZ_PAC', default=1024, type=int, help='dimension of Z in Factorized')
    parser.add_argument('--CMI', default=0.001, type=float, help='the value of CMI in FedSR')
    parser.add_argument('--L2R', default=0.001, type=float, help='the value of L2R in FedSR')
    parser.add_argument('--num_avg_train', default=15, type=int,
                        help='the number of samples when performing multi-shot training')
    parser.add_argument('--num_avg', default=30, type=int,
                        help='the number of samples when performing multi-shot prediction')

    parser.add_argument('--filepath', type=str, default='filepath', help='path of the log file')

    # model arguments
    parser.add_argument('--method', type=str, default='fedCR', help='method name')

    # other arguments
    parser.add_argument('--dataset', type=str, default='CIFAR10', help="name of dataset")
    parser.add_argument('--frac_data', type=float, default=0.7, help="fraction of frac_data")
    parser.add_argument('--rule', type=str, default='noniid', help='data split rule: noniid or Dirichlet')
    parser.add_argument('--class_main', type=int, default=5, help='the value of class_main for noniid')
    parser.add_argument('--dir_a', default=0.5, type=float, help='the value of dir_a for dirichlet')
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--seed', type=int, default=23, help='random seed (default: 23)')
    return parser.parse_args(argv)
56 |
57 |
58 |
--------------------------------------------------------------------------------
/code_FedCR/models/Nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 |
6 |
class client_model(nn.Module):
    """Plain classifiers used by the FedCR baselines, selected by ``name``.

    Supported ``name`` values:

    * ``'emnist_NN'``      - 3-layer MLP; inputs are flattened to 1*28*28, 10 classes.
    * ``'FMNIST_CNN'``     - 2 conv + 3 FC layers; conv stack sized so that a
      1x28x28 image flattens to 12*4*4 features, 10 classes.
    * ``'cifar10_LeNet'``  - 2 conv + 3 FC layers; conv stack sized so that a
      3x32x32 image flattens to 64*5*5 features, 10 classes.
    * ``'cifar100_LeNet'`` - same architecture as ``'cifar10_LeNet'`` with 100 classes.

    Args:
        name: Architecture key (see above).
        args: Unused; kept so existing call sites (which pass e.g.
            ``[1 * 28 * 28, 10]``) keep working.

    Attributes:
        n_cls: Number of output classes.
        weight_keys: Per-layer ``[weight, bias]`` parameter-name pairs
            (presumably consumed by the training utilities, which are not shown).

    Raises:
        ValueError: If ``name`` is not a supported key. Previously an unknown
            name silently produced a parameterless module whose ``forward``
            returned its input unchanged.
    """

    # Keys sharing the two-conv LeNet architecture, mapped to their class count.
    _LENET_CLASSES = {'cifar10_LeNet': 10, 'cifar100_LeNet': 100}

    def __init__(self, name, args=True):
        super(client_model, self).__init__()
        self.name = name

        # Layer construction order is kept exactly as before: it fixes the
        # state_dict key order and, under a fixed seed, the initial weights.
        if name == 'emnist_NN':
            self.n_cls = 10
            self.fc1 = nn.Linear(1 * 28 * 28, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('fc1', 'fc2', 'fc3')
        elif name == 'FMNIST_CNN':
            self.n_cls = 10
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5)
            self.fc1 = nn.Linear(12 * 4 * 4, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
        elif name in self._LENET_CLASSES:
            self.n_cls = self._LENET_CLASSES[name]
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.fc1 = nn.Linear(64 * 5 * 5, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
        else:
            raise ValueError('unrecognized model name: %r' % (name,))

    @staticmethod
    def _layer_keys(*layers):
        """Return ``[['<layer>.weight', '<layer>.bias'], ...]`` for the given layer names."""
        return [[layer + '.weight', layer + '.bias'] for layer in layers]

    def _conv_features(self, x, flat_dim):
        """Apply two conv -> ReLU -> 2x2 max-pool stages and flatten to ``flat_dim`` columns."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return x.view(-1, flat_dim)

    def forward(self, x):
        """Return the class logits (output of ``fc3``) for the batch ``x``."""
        if self.name == 'emnist_NN':
            x = x.view(-1, 1 * 28 * 28)
        elif self.name == 'FMNIST_CNN':
            x = self._conv_features(x, 12 * 4 * 4)
        elif self.name in self._LENET_CLASSES:
            x = self._conv_features(x, 64 * 5 * 5)
        else:
            raise ValueError('unrecognized model name: %r' % (self.name,))

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
102 |
--------------------------------------------------------------------------------
/code_FedCR/models/Nets_PAC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 |
6 |
class client_model(nn.Module):
    """Classifiers for FedPAC that also expose the penultimate-layer features.

    Supported ``name`` values:

    * ``'emnist_NN'``      - 3-layer MLP; inputs are flattened to 1*28*28, 10 classes.
    * ``'FMNIST_CNN'``     - 2 conv + 3 FC layers; conv stack sized so that a
      1x28x28 image flattens to 12*4*4 features, 10 classes.
    * ``'cifar10_LeNet'``  - 2 conv + 3 FC layers; conv stack sized so that a
      3x32x32 image flattens to 64*5*5 features, 10 classes.
    * ``'cifar100_LeNet'`` - same architecture as ``'cifar10_LeNet'`` with 100 classes.

    Args:
        name: Architecture key (see above).
        args: Unused; kept so existing call sites keep working.

    Attributes:
        n_cls: Number of output classes.
        weight_keys: Per-layer ``[weight, bias]`` parameter-name pairs
            (presumably consumed by the training utilities, which are not shown).

    Raises:
        ValueError: If ``name`` is not a supported key. Previously an unknown
            name built a parameterless module whose ``forward`` crashed with
            ``UnboundLocalError``.
    """

    # Keys sharing the two-conv LeNet architecture, mapped to their class count.
    _LENET_CLASSES = {'cifar10_LeNet': 10, 'cifar100_LeNet': 100}

    def __init__(self, name, args=True):
        super(client_model, self).__init__()
        self.name = name

        # Layer construction order is kept exactly as before: it fixes the
        # state_dict key order and, under a fixed seed, the initial weights.
        if name == 'emnist_NN':
            self.n_cls = 10
            self.fc1 = nn.Linear(1 * 28 * 28, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('fc1', 'fc2', 'fc3')
        elif name == 'FMNIST_CNN':
            self.n_cls = 10
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5)
            self.fc1 = nn.Linear(12 * 4 * 4, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
        elif name in self._LENET_CLASSES:
            self.n_cls = self._LENET_CLASSES[name]
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.fc1 = nn.Linear(64 * 5 * 5, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            self.fc3 = nn.Linear(1024, self.n_cls)
            self.weight_keys = self._layer_keys('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
        else:
            raise ValueError('unrecognized model name: %r' % (name,))

    @staticmethod
    def _layer_keys(*layers):
        """Return ``[['<layer>.weight', '<layer>.bias'], ...]`` for the given layer names."""
        return [[layer + '.weight', layer + '.bias'] for layer in layers]

    def _conv_features(self, x, flat_dim):
        """Apply two conv -> ReLU -> 2x2 max-pool stages and flatten to ``flat_dim`` columns."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return x.view(-1, flat_dim)

    def forward(self, x):
        """Return ``(logits, features)`` for the batch ``x``.

        ``features`` is the ReLU output of ``fc2`` (width 1024), and
        ``logits`` is ``fc3`` applied to it.
        """
        if self.name == 'emnist_NN':
            x = x.view(-1, 1 * 28 * 28)
        elif self.name == 'FMNIST_CNN':
            x = self._conv_features(x, 12 * 4 * 4)
        elif self.name in self._LENET_CLASSES:
            x = self._conv_features(x, 64 * 5 * 5)
        else:
            raise ValueError('unrecognized model name: %r' % (self.name,))

        x = F.relu(self.fc1(x))
        feature_output = F.relu(self.fc2(x))
        return self.fc3(feature_output), feature_output
107 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedCR: Personalized Federated Learning Based on Across-Client Common Representation with Conditional Mutual Information Regularization
2 |
3 | This directory contains source code for evaluating federated learning with different methods on various models and tasks. The code was developed for a paper, "FedCR: Personalized Federated Learning Based on Across-Client Common Representation with Conditional Mutual Information Regularization".
4 |
5 | ## Requirements
6 |
7 | Some pip packages are required by this code and may need to be installed. For details, see `requirement.txt`. We recommend running `pip install --requirement "requirement.txt"`.
8 |
9 | Below we give a summary of the datasets, tasks, and models used in this code.
10 |
11 |
12 | ## Task and dataset summary
13 |
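14 | Note that the datasets are expected under the directory `.\federated-learning-master\Folder`. The scripts read them from the relative path `Folder/`, so run the scripts from the directory that contains `Folder`.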
15 |
16 |
17 |
18 | | Directory | Model | Task Summary |
19 | |------------------|-------------------------------------|---------------------------|
20 | | CIFAR-10 | CNN (with two convolutional layers) | Image classification |
21 | | CIFAR-100 | CNN (with two convolutional layers) | Image classification |
22 | | EMNIST | NN (fully connected neural network) | Digit recognition |
23 | | FMNIST | CNN (with two convolutional layers) | Image classification |
24 |
25 |
26 |
27 |
28 | ## Training
29 | In this code, we compare 10 optimization methods: **FedAvg**, **FedAvg-FT**, **FedPer**, **LG-FedAvg**, **FedRep**, **FedBABU**, **Ditto**, **FedSR-FT**, **FedPAC**, and **FedCR**. All of these methods use vanilla SGD on the clients. To reproduce our experimental results for each method, run the commands below. They use, as an example, 100 clients with a 10% participation rate on the CIFAR-100 dataset with a Dirichlet (0.3) split:
30 |
31 | **FedAvg** and **FedAvg-FT**:
32 | ```
33 | python main_fed.py --filepath FedAvg.txt --dataset CIFAR100 --method fedavg --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
34 | ```
35 |
36 | **FedPer**:
37 | ```
38 | python main_fed.py --filepath FedPer.txt --dataset CIFAR100 --method fedper --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
39 | ```
40 |
41 | **LG-FedAvg**:
42 | ```
43 | python main_fed.py --filepath LG-FedAvg.txt --dataset CIFAR100 --method lg --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
44 | ```
45 |
46 | **FedRep**:
47 | ```
48 | python main_fed.py --filepath FedRep.txt --dataset CIFAR100 --method fedrep --lr 0.01 --local_ep 10 --local_rep_ep 1 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
49 | ```
50 |
51 | **FedBABU**:
52 | ```
53 | python main_fed.py --filepath FedBABU.txt --dataset CIFAR100 --method fedbabu --lr 0.01 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
54 | ```
55 |
56 | **Ditto**:
57 | ```
58 | python main_fed_ditto.py --filepath Ditto.txt --dataset CIFAR100 --method ditto --mu 0.1 --lr 0.5 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --beta 0.001 --bs 10 --local_bs 48
59 | ```
60 |
61 | **FedSR-FT**:
62 | ```
63 | python main_fedSR.py --filepath FedSR-FT.txt --dataset CIFAR100 --method fedSR --lr 0.005 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --num_avg_train 1 --num_avg 1 --epoch 500 --dimZ 256 --CMI 0.001 --L2R 0.001 --bs 10 --local_bs 48 --beta2 1
64 | ```
65 |
66 | **FedPAC**:
67 | ```
68 | python main_fed_PAC.py --filepath FedPAC.txt --dataset CIFAR100 --method fedPAC --lr 0.01 --beta_PAC 1 --local_ep 10 --local_rep_ep 1 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --epoch 500 --dimZ 512 --beta 0.001 --bs 10 --local_bs 48
69 | ```
70 |
71 | **FedCR**:
72 | ```
73 | python main_fed.py --filepath FedCR.txt --dataset CIFAR100 --method FedCR --lr 0.05 --local_ep 10 --lr_decay 1 --rule Dirichlet --dir_a 0.3 --gpu 0 --num_avg_train 18 --num_avg 18 --epoch 500 --dimZ 512 --beta 0.0001 --bs 10 --local_bs 48 --beta2 1
74 | ```
75 |
76 |
77 |
78 | ## Other hyperparameters and reproducibility
79 |
80 | All other hyperparameters are set by default to the values used in the `Experiment Details` of our Appendix. This includes the batch size, the number of clients per round, the number of client local updates, local learning rate, and model parameter flags. While they can be set for different behavior (such as varying the number of client local updates), they should not be changed if one wishes to reproduce the results from our paper.
81 |
82 |
--------------------------------------------------------------------------------
/code_FedCR/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | 1674736627687
55 |
56 |
57 | 1674736627687
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/code_FedCR/main_fed_ditto.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.Nets import client_model
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utils_ditto import Client, Server
# Tensor printing used in logs: 8 decimal places in fixed-point (never
# scientific) notation. Tensors with more than 1000 elements are summarised
# with 3 items per edge, and lines wrap at 150 characters.
_PRINT_OPTIONS = dict(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False,
)
torch.set_printoptions(**_PRINT_OPTIONS)
29 | if __name__ == '__main__':
30 |
31 | ssl._create_default_https_context = ssl._create_unverified_context
32 | # parse args
33 | args = args_parser()
34 |
35 | args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
36 | setup_seed(args.seed)
37 |
38 |
39 | data_path = 'Folder/'
40 | data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule, class_main=args.class_main, data_path=data_path, frac_data=args.frac_data, dir_alpha=args.dir_a)
41 |
42 | clnt_x = data_obj.clnt_x;
43 | clnt_y = data_obj.clnt_y;
44 | tst_x = data_obj.tst_x;
45 | tst_y = data_obj.tst_y
46 |
47 | # build model
48 |
49 | if args.dataset == 'CIFAR100':
50 | net_glob = client_model('cifar100_LeNet').to(args.device)
51 | elif args.dataset == 'CIFAR10':
52 | net_glob = client_model('cifar10_LeNet').to(args.device)
53 | elif args.dataset == 'EMNIST':
54 | net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device)
55 | elif args.dataset == 'FMNIST':
56 | net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device)
57 | else:
58 | exit('Error: unrecognized model')
59 |
60 | total_num_layers = len(net_glob.state_dict().keys())
61 | net_keys = [*net_glob.state_dict().keys()]
62 |
63 |
64 | if args.method == 'ditto':
65 | w_glob_keys = []
66 | else:
67 | exit('Error: unrecognized data4')
68 |
69 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))
70 |
71 | clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i],
72 | trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i], n_cls = data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i) for i in range(args.num_users)]
73 | server = Server(model = (net_glob).to(args.device), args = args, n_cls = data_obj.n_cls)
74 |
75 | logger = get_logger(args.filepath)
76 | logger.info('--------args----------')
77 | for k in list(vars(args).keys()):
78 | logger.info('%s: %s' % (k, vars(args)[k]))
79 | logger.info('--------args----------\n')
80 | logger.info('total_num_layers')
81 | logger.info(total_num_layers)
82 | logger.info('net_keys')
83 | logger.info(net_keys)
84 | logger.info('w_glob_keys')
85 | logger.info(w_glob_keys)
86 |
87 | logger.info('start training!')
88 |
89 | for iter in range(args.epochs + 1):
90 | net_glob.train()
91 |
92 | m = max(int(args.frac * args.num_users), 1)
93 | if iter == args.epochs:
94 | m = args.num_users
95 | participating_clients = random.sample(clients, m)
96 |
97 | last = iter == args.epochs
98 |
99 | for client in participating_clients:
100 | client.synchronize_with_server(server, w_glob_keys)
101 |
102 | client.compute_bias()
103 |
104 | client.compute_weight_update(w_glob_keys, server, last)
105 |
106 |
107 |
108 | server.aggregate_weight_updates(clients=participating_clients, iter=iter)
109 |
110 |
111 | #-----------------------------------------------test--------------------------------------------------------------------
112 |
113 | #-----------------------------------------------test--------------------------------------------------------------------
114 |
115 | if iter % args.test_freq==args.test_freq-1 or iter>=args.epochs-10:
116 | results_loss =[]; results_acc = []
117 | results_loss_last = []; results_acc_last = []
118 | for client in clients:
119 |
120 | results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y, dataset_name=data_obj.dataset)
121 |
122 | results_loss.append(loss_test1)
123 | results_acc.append(results_test)
124 | results_loss = np.mean(results_loss)
125 | results_acc = np.mean(results_acc)
126 |
127 | logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
128 | format(iter, args.lr, results_loss, results_acc))
129 |
130 |
131 | args.lr = args.lr * (args.lr_decay)
132 |
133 | logger.info('finish training!')
134 |
135 |
136 |
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/code_FedCR/models/resnet18.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from torch.autograd import Variable
14 |
15 |
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (used by ResNet-18/34).

    The first convolution applies ``stride``. When the spatial size or channel
    count changes, the identity path is replaced by a 1x1 projection. Passing
    ``use_batchnorm=False`` swaps every BatchNorm for an empty ``nn.Sequential``
    (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, use_batchnorm=True):
        super(BasicBlock, self).__init__()
        # Registration order (conv1, bn1, conv2, bn2, shortcut) fixes state_dict order.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if not use_batchnorm:
            # One parameter-free identity module shared by both norm slots.
            identity = nn.Sequential()
            self.bn1 = identity
            self.bn2 = identity

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            projection_norm = nn.BatchNorm2d(out_planes) if use_batchnorm else nn.Sequential()
            self.shortcut = nn.Sequential(projection, projection_norm)

    def forward(self, x):
        hidden = self.bn1(self.conv1(x))
        hidden = F.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
42 |
43 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (used by ResNet-50/101/152).

    The block widens its output to ``expansion * planes`` channels. The 3x3
    convolution carries ``stride``, and a 1x1 projection shortcut is used
    whenever the output shape differs from the input. With
    ``use_batchnorm=False`` every BatchNorm becomes an empty ``nn.Sequential``.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, use_batchnorm=True):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        # Keep creation order conv1, bn1, conv2, bn2, conv3, bn3, shortcut so
        # both state_dict key order and seeded weight init are unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if not use_batchnorm:
            # A single parameter-free identity module fills all three norm slots.
            identity = nn.Sequential()
            self.bn1 = identity
            self.bn2 = identity
            self.bn3 = identity

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            projection_norm = nn.BatchNorm2d(out_planes) if use_batchnorm else nn.Sequential()
            self.shortcut = nn.Sequential(projection, projection_norm)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
73 |
74 |
class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem, four residual stages, 4x4 average pool, linear head.

    Args:
        block: residual block class exposing an ``expansion`` attribute and the
            constructor ``block(in_planes, planes, stride, use_batchnorm)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        use_batchnorm: when False, BatchNorm layers are replaced by identities.
        in_channels: channels of the input image (3 for RGB; new, defaults to
            the previously hard-coded value).

    The final ``avg_pool2d(out, 4)`` assumes the last stage's feature map is
    4x4, i.e. 32x32 inputs for the default strides.
    """

    def __init__(self, block, num_blocks, num_classes=10, use_batchnorm=True, in_channels=3):
        super(ResNet, self).__init__()
        self.in_planes = 64  # running channel count consumed by _make_layer
        self.use_batchnorm = use_batchnorm
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64) if use_batchnorm else nn.Sequential()
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, self.use_batchnorm))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
106 |
107 |
def ResNet18(use_batchnorm=True, num_classes=10):
    """ResNet-18: BasicBlock x [2, 2, 2, 2]; ``num_classes`` sizes the linear head."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, use_batchnorm=use_batchnorm)


def ResNet34(use_batchnorm=True, num_classes=10):
    """ResNet-34: BasicBlock x [3, 4, 6, 3]; ``num_classes`` sizes the linear head."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, use_batchnorm=use_batchnorm)


def ResNet50(use_batchnorm=True, num_classes=10):
    """ResNet-50: Bottleneck x [3, 4, 6, 3]; ``num_classes`` sizes the linear head."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, use_batchnorm=use_batchnorm)


def ResNet101(use_batchnorm=True, num_classes=10):
    """ResNet-101: Bottleneck x [3, 4, 23, 3]; ``num_classes`` sizes the linear head."""
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, use_batchnorm=use_batchnorm)


def ResNet152(use_batchnorm=True, num_classes=10):
    """ResNet-152: Bottleneck x [3, 8, 36, 3]; ``num_classes`` sizes the linear head."""
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, use_batchnorm=use_batchnorm)
122 |
123 |
def test():
    """Smoke test: run ResNet-18 on one random 3x32x32 image and print the output size."""
    net = ResNet18()
    # Plain tensors are autograd-aware; the old Variable wrapper is unnecessary.
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
128 |
129 | # test()
130 |
--------------------------------------------------------------------------------
/code_FedCR/main_fedSR.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.NetsSR import Model_CMI
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utilsSR import Client, Server
torch.set_printoptions(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False,
)

if __name__ == '__main__':
    # FedSR driver: server and clients each hold a Model_CMI; runs
    # args.epochs + 1 federated rounds (the last round uses every client).
    import sys

    # Disables HTTPS certificate verification process-wide (presumably for
    # dataset downloads) -- TODO confirm still needed.
    ssl._create_default_https_context = ssl._create_unverified_context

    args = args_parser()
    args.device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu'
    )
    setup_seed(args.seed)

    data_path = 'Folder/'
    data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed,
                             rule=args.rule, class_main=args.class_main, data_path=data_path,
                             frac_data=args.frac_data, dir_alpha=args.dir_a)

    # build model
    net_glob = Model_CMI(args, args.dimZ, args.alpha, args.dataset).to(args.device)

    total_num_layers = len(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()]

    # Only fedSR is supported. Any other method previously fell through with
    # w_glob_keys undefined and crashed with a NameError.
    if args.method != 'fedSR':
        sys.exit('Error: unrecognized data4')
    w_glob_keys = []

    clients = [
        Client(model=Model_CMI(args, args.dimZ, args.alpha, args.dataset).to(args.device), args=args,
               trn_x=data_obj.clnt_x[i], trn_y=data_obj.clnt_y[i],
               tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i],
               n_cls=data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i)
        for i in range(args.num_users)
    ]
    server = Server(model=net_glob.to(args.device), args=args, n_cls=data_obj.n_cls)

    logger = get_logger(args.filepath)
    logger.info('--------args----------')
    for k, v in vars(args).items():
        logger.info('%s: %s' % (k, v))
    logger.info('--------args----------\n')
    logger.info('total_num_layers')
    logger.info(total_num_layers)
    logger.info('net_keys')
    logger.info(net_keys)
    logger.info('w_glob_keys')
    logger.info(w_glob_keys)

    logger.info('start training!')

    # `rnd` instead of `iter` so the builtin is not shadowed.
    for rnd in range(args.epochs + 1):
        net_glob.train()

        last = rnd == args.epochs
        m = args.num_users if last else max(int(args.frac * args.num_users), 1)
        participating_clients = random.sample(clients, m)

        for client in participating_clients:
            client.synchronize_with_server(server, w_glob_keys)
            client.compute_weight_update(w_glob_keys, server, last)

        server.aggregate_weight_updates(clients=participating_clients, iter=rnd)

        # ----------------------------------- test -----------------------------------
        if rnd % args.test_freq == args.test_freq - 1 or rnd >= args.epochs - 10:
            results_loss, results_acc = [], []
            results_loss_last, results_acc_last = [], []
            for client in clients:
                acc, loss = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                            dataset_name=data_obj.dataset)
                results_loss.append(loss)
                results_acc.append(acc)
                if last:
                    # NOTE(review): this repeats the exact same evaluation call with no
                    # fine-tuning in between; kept so logged numbers stay comparable.
                    acc_last, loss_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                                          dataset_name=data_obj.dataset)
                    results_loss_last.append(loss_last)
                    results_acc_last.append(acc_last)

            if last:
                logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))
                logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss_last), np.mean(results_acc_last)))
            else:
                logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))

        # Per-round learning-rate decay.
        args.lr = args.lr * (args.lr_decay)

    logger.info('finish training!')
142 |
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
/code_FedCR/models/main_fedSR.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.NetsSR import Model_CMI, Model_CMI_server
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utilsSR import Client, Server
torch.set_printoptions(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False,
)

if __name__ == '__main__':
    # FedSR driver (client_model_VIB variant): server and clients each hold a
    # client_model_VIB; runs args.epochs + 1 rounds, the last one with every client.
    import sys

    # Disables HTTPS certificate verification process-wide (presumably for
    # dataset downloads) -- TODO confirm still needed.
    ssl._create_default_https_context = ssl._create_unverified_context

    args = args_parser()
    args.device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu'
    )
    setup_seed(args.seed)

    data_path = 'Folder/'
    data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed,
                             rule=args.rule, class_main=args.class_main, data_path=data_path,
                             frac_data=args.frac_data, dir_alpha=args.dir_a)

    # build model
    net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device)

    total_num_layers = len(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()]

    # Only fedSR is supported. Any other method previously fell through with
    # w_glob_keys undefined and crashed with a NameError.
    if args.method != 'fedSR':
        sys.exit('Error: unrecognized data4')
    w_glob_keys = []

    clients = [
        Client(model=client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device), args=args,
               trn_x=data_obj.clnt_x[i], trn_y=data_obj.clnt_y[i],
               tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i],
               n_cls=data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i)
        for i in range(args.num_users)
    ]
    server = Server(model=net_glob.to(args.device), args=args, n_cls=data_obj.n_cls)

    logger = get_logger(args.filepath)
    logger.info('--------args----------')
    for k, v in vars(args).items():
        logger.info('%s: %s' % (k, v))
    logger.info('--------args----------\n')
    logger.info('total_num_layers')
    logger.info(total_num_layers)
    logger.info('net_keys')
    logger.info(net_keys)
    logger.info('w_glob_keys')
    logger.info(w_glob_keys)

    logger.info('start training!')

    # `rnd` instead of `iter` so the builtin is not shadowed.
    for rnd in range(args.epochs + 1):
        net_glob.train()

        last = rnd == args.epochs
        m = args.num_users if last else max(int(args.frac * args.num_users), 1)
        participating_clients = random.sample(clients, m)

        for client in participating_clients:
            client.synchronize_with_server(server, w_glob_keys)
            client.compute_weight_update(w_glob_keys, server, last)

        server.aggregate_weight_updates(clients=participating_clients, iter=rnd)

        # ----------------------------------- test -----------------------------------
        if rnd % args.test_freq == args.test_freq - 1 or rnd >= args.epochs - 10:
            results_loss, results_acc = [], []
            results_loss_last, results_acc_last = [], []
            for client in clients:
                acc, loss = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                            dataset_name=data_obj.dataset)
                results_loss.append(loss)
                results_acc.append(acc)
                if last:
                    # NOTE(review): this repeats the exact same evaluation call with no
                    # fine-tuning in between; kept so logged numbers stay comparable.
                    acc_last, loss_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                                          dataset_name=data_obj.dataset)
                    results_loss_last.append(loss_last)
                    results_acc_last.append(acc_last)

            if last:
                logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))
                logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss_last), np.mean(results_acc_last)))
            else:
                logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))

        # Per-round learning-rate decay.
        args.lr = args.lr * (args.lr_decay)

    logger.info('finish training!')
142 |
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
/code_FedCR/main_fed_local.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.Nets import client_model
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utils import Client, Server
torch.set_printoptions(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False,
)

if __name__ == '__main__':
    # Local-only baseline: each client trains once on its own data (no
    # server synchronisation or aggregation) and is then evaluated.
    import sys

    # Disables HTTPS certificate verification process-wide (presumably for
    # dataset downloads) -- TODO confirm still needed.
    ssl._create_default_https_context = ssl._create_unverified_context

    args = args_parser()
    args.device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu'
    )
    setup_seed(args.seed)

    data_path = 'Folder/'
    data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed,
                             rule=args.rule, class_main=args.class_main, data_path=data_path,
                             frac_data=args.frac_data, dir_alpha=args.dir_a)

    # build model
    if args.method == 'FedCR':
        net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device)
    elif args.dataset == 'CIFAR100':
        net_glob = client_model('cifar100_LeNet').to(args.device)
    elif args.dataset == 'CIFAR10':
        net_glob = client_model('cifar10_LeNet').to(args.device)
    elif args.dataset == 'EMNIST':
        net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device)
    elif args.dataset == 'FMNIST':
        net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device)
    else:
        sys.exit('Error: unrecognized model')

    total_num_layers = len(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()]

    # Select which weight_keys groups go into w_glob_keys for each method.
    # Substring checks are kept in this order: 'CIFAR100' must be tested before 'CIFAR10'.
    if args.method in ('fedrep', 'fedper', 'fedbabu'):
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        else:
            sys.exit('Error: unrecognized data1')
    elif args.method == 'lg':
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [1, 2]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        else:
            sys.exit('Error: unrecognized data2')
    elif args.method == 'FedCR':
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        else:
            sys.exit('Error: unrecognized data3')
    elif args.method in ('fedavg', 'ditto', 'maml'):
        w_glob_keys = []
    else:
        sys.exit('Error: unrecognized data4')

    # Flatten the [[weight, bias], ...] groups into one list of state_dict keys.
    w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))

    clients = [
        Client(model=copy.deepcopy(net_glob).to(args.device), args=args,
               trn_x=data_obj.clnt_x[i], trn_y=data_obj.clnt_y[i],
               tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i],
               n_cls=data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i)
        for i in range(args.num_users)
    ]

    server = Server(model=net_glob.to(args.device), args=args, n_cls=data_obj.n_cls)

    logger = get_logger(args.filepath)
    logger.info('--------args----------')
    for k, v in vars(args).items():
        logger.info('%s: %s' % (k, v))
    logger.info('--------args----------\n')
    logger.info('total_num_layers')
    logger.info(total_num_layers)
    logger.info('net_keys')
    logger.info(net_keys)
    logger.info('w_glob_keys')
    logger.info(w_glob_keys)

    logger.info('start training!')

    results_loss, results_acc = [], []
    for client in clients:
        client.compute_weight_update(w_glob_keys, server, last=False)
        acc, loss = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                    dataset_name=data_obj.dataset)
        results_loss.append(loss)
        results_acc.append(acc)

    # This script has no round loop, so there is no round variable to report.
    # Log args.epochs (the final round index used by the federated scripts)
    # instead of the builtin `iter`, which would print "<built-in function iter>".
    logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.format(
        args.epochs, args.lr, np.mean(results_loss), np.mean(results_acc)))

    logger.info('finish training!')
146 |
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/code_FedCR/main_fed_PAC.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.Nets_PAC import client_model
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utils_PAC import Client, Server
torch.set_printoptions(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False,
)

if __name__ == '__main__':
    # FedPAC driver (also runs plain fedavg): federated rounds with local
    # feature statistics and global feature centroids computed on the server.
    import sys

    # Disables HTTPS certificate verification process-wide (presumably for
    # dataset downloads) -- TODO confirm still needed.
    ssl._create_default_https_context = ssl._create_unverified_context

    args = args_parser()
    args.device = torch.device(
        'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu'
    )
    setup_seed(args.seed)

    data_path = 'Folder/'
    data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed,
                             rule=args.rule, class_main=args.class_main, data_path=data_path,
                             frac_data=args.frac_data, dir_alpha=args.dir_a)

    # build model
    if args.dataset == 'CIFAR100':
        net_glob = client_model('cifar100_LeNet').to(args.device)
    elif args.dataset == 'CIFAR10':
        net_glob = client_model('cifar10_LeNet').to(args.device)
    elif args.dataset == 'EMNIST':
        net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device)
    elif args.dataset == 'FMNIST':
        net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device)
    else:
        sys.exit('Error: unrecognized model')

    total_num_layers = len(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()]

    # Select which weight_keys groups go into w_glob_keys ('CIFAR100' before 'CIFAR10').
    if args.method == 'fedPAC':
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        else:
            sys.exit('Error: unrecognized data1')
    elif args.method == 'fedavg':
        w_glob_keys = []
    else:
        sys.exit('Error: unrecognized data4')

    # Flatten the [[weight, bias], ...] groups into one list of state_dict keys.
    w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))

    clients = [
        Client(model=copy.deepcopy(net_glob).to(args.device), args=args,
               trn_x=data_obj.clnt_x[i], trn_y=data_obj.clnt_y[i],
               tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i],
               n_cls=data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i)
        for i in range(args.num_users)
    ]
    server = Server(model=net_glob.to(args.device), args=args, n_cls=data_obj.n_cls)

    logger = get_logger(args.filepath)
    logger.info('--------args----------')
    for k, v in vars(args).items():
        logger.info('%s: %s' % (k, v))
    logger.info('--------args----------\n')
    logger.info('total_num_layers')
    logger.info(total_num_layers)
    logger.info('net_keys')
    logger.info(net_keys)
    logger.info('w_glob_keys')
    logger.info(w_glob_keys)

    logger.info('start training!')

    # `rnd` instead of `iter` so the builtin is not shadowed.
    for rnd in range(args.epochs + 1):
        net_glob.train()

        last = rnd == args.epochs
        m = args.num_users if last else max(int(args.frac * args.num_users), 1)
        participating_clients = random.sample(clients, m)

        for client in participating_clients:
            client.synchronize_with_server(server, w_glob_keys)
            client.compute_weight_update(w_glob_keys, server, last)
            client.local_feature()

        server.aggregate_weight_updates(clients=participating_clients, iter=rnd)
        server.global_feature_centroids(clients=participating_clients)
        # server.Get_classifier(participating_clients, w_glob_keys)  # for the non-iid 2 case, classifier combination sometimes degrades performance

        # ----------------------------------- test -----------------------------------
        if rnd % args.test_freq == args.test_freq - 1 or rnd >= args.epochs - 10:
            results_loss, results_acc = [], []
            # Evaluate every client for every method. The old `method != 'fedavg'`
            # guard left fedavg runs without results (undefined name or empty mean).
            for client in clients:
                acc, loss = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                            dataset_name=data_obj.dataset)
                results_loss.append(loss)
                results_acc.append(acc)

            if last:
                logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))
            else:
                logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, np.mean(results_loss), np.mean(results_acc)))

        # Per-round learning-rate decay.
        args.lr = args.lr * (args.lr_decay)

    logger.info('finish training!')
157 |
158 |
159 |
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
/code_FedCR/models/Nets_VIB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class client_model_VIB(nn.Module):
7 |
8 | def __init__(self, args, dimZ=256, alpha=0, dataset = 'EMNIST'):
9 | # the dimension of Z
10 | super().__init__()
11 |
12 | self.alpha = alpha
13 | self.dimZ = dimZ
14 | self.device = args.device
15 | self.dataset = dataset
16 |
17 | if self.dataset == 'EMNIST':
18 | self.n_cls = 10
19 | self.fc1 = nn.Linear(1 * 28 * 28, 1024)
20 | self.fc2 = nn.Linear(1024, 1024)
21 | self.fc3 = nn.Linear(1024, 2 * self.dimZ)
22 | self.fc4 = nn.Linear(self.dimZ, self.n_cls)
23 | self.weight_keys = [['fc1.weight', 'fc1.bias'],
24 | ['fc2.weight', 'fc2.bias'],
25 | ['fc3.weight', 'fc3.bias'],
26 | ['fc4.weight', 'fc4.bias']]
27 |
28 |
29 | if self.dataset == 'FMNIST':
30 | self.n_cls = 10
31 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
32 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
33 | self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5)
34 | self.fc1 = nn.Linear(12 * 4 * 4, 1024)
35 | self.fc2 = nn.Linear(1024, 1024)
36 | self.fc3 = nn.Linear(1024, 2 * self.dimZ)
37 | self.fc4 = nn.Linear(self.dimZ, self.n_cls)
38 |
39 | self.weight_keys = [['conv1.weight', 'conv1.bias'],
40 | ['conv2.weight', 'conv2.bias'],
41 | ['fc1.weight', 'fc1.bias'],
42 | ['fc2.weight', 'fc2.bias'],
43 | ['fc3.weight', 'fc3.bias'],
44 | ['fc4.weight', 'fc4.bias']]
45 |
46 |
47 | if self.dataset == 'CIFAR10':
48 | self.n_cls = 10
49 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
50 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
51 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
52 | self.fc1 = nn.Linear(64 * 5 * 5, 1024)
53 | self.fc2 = nn.Linear(1024, 1024)
54 | self.fc3 = nn.Linear(1024, 2 * self.dimZ)
55 | self.fc4 = nn.Linear(self.dimZ, self.n_cls)
56 | self.weight_keys = [['conv1.weight', 'conv1.bias'],
57 | ['conv2.weight', 'conv2.bias'],
58 | ['fc1.weight', 'fc1.bias'],
59 | ['fc2.weight', 'fc2.bias'],
60 | ['fc3.weight', 'fc3.bias'],
61 | ['fc4.weight', 'fc4.bias']]
62 |
63 | if self.dataset == 'CIFAR100':
64 | self.n_cls = 100
65 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
66 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
67 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
68 | self.fc1 = nn.Linear(64 * 5 * 5, 1024)
69 | self.fc2 = nn.Linear(1024, 1024)
70 | self.fc3 = nn.Linear(1024, 2 * self.dimZ)
71 | self.fc4 = nn.Linear(self.dimZ, self.n_cls)
72 | self.weight_keys = [['conv1.weight', 'conv1.bias'],
73 | ['conv2.weight', 'conv2.bias'],
74 | ['fc1.weight', 'fc1.bias'],
75 | ['fc2.weight', 'fc2.bias'],
76 | ['fc3.weight', 'fc3.bias'],
77 | ['fc4.weight', 'fc4.bias']]
78 |
def gaussian_noise(self, num_samples, K):
    """Draw i.i.d. N(0, 1) noise of shape ``(*num_samples, K)`` on ``self.device``.

    Args:
        num_samples: an int or a tuple of ints giving the leading dimensions.
        K: size of the trailing (latent) dimension.

    Returns:
        A tensor of standard-normal samples moved to ``self.device``.
    """
    # ``*num_samples`` only unpacks iterables, so a bare int (which the old
    # comment claimed was supported) used to raise TypeError.
    if isinstance(num_samples, int):
        num_samples = (num_samples,)
    # Sampled with the CPU generator and then moved, exactly as before, so
    # seeded runs keep producing the same noise.
    return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).to(self.device)
83 |
def sample_prior_Z(self, num_samples):
    """Sample latent vectors of width ``self.dimZ`` from the N(0, I) prior via ``gaussian_noise``."""
    latent_dim = self.dimZ
    return self.gaussian_noise(num_samples=num_samples, K=latent_dim)
86 |
def encoder_result(self, encoder_output):
    """Split the encoder output into ``(mu, sigma)`` of the latent Gaussian.

    Columns ``[:dimZ]`` are the mean. Columns ``[dimZ:]`` are shifted by
    ``-alpha`` and passed through softplus, so sigma is strictly positive.
    """
    split = self.dimZ
    mean_part = encoder_output[:, :split]
    raw_scale = encoder_output[:, split:]
    positive_scale = torch.nn.functional.softplus(raw_scale - self.alpha)
    return mean_part, positive_scale
92 |
def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples):
    """Draw reparameterised latent samples ``mu + sigma * eps``.

    ``eps`` comes from ``gaussian_noise`` with shape
    ``(num_samples, batch_size, dimZ)``. ``mu`` and ``sigma`` broadcast
    against it, so the sample axis leads.
    """
    mean, std = encoder_Z_distr
    eps = self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
    return mean + std * eps
98 |
def forward(self, batch_x, num_samples = 1):
    """Encode ``batch_x``, sample the latent ``Z`` and decode it into logits.

    Args:
        batch_x: input batch; ``batch_x.size()[0]`` is taken as the batch size.
        num_samples: number of latent samples drawn per input.

    Returns:
        ``(encoder_Z_distr, decoder_logits)``. ``encoder_Z_distr`` is the
        ``(mu, sigma)`` pair from ``encoder_result``. ``decoder_logits`` is
        ``fc4`` applied to the samples from ``sample_encoder_Z``, whose leading
        dimensions are ``(num_samples, batch_size)``.
        NOTE(review): an old comment said "batch should go first", but the
        sample axis leads; callers must handle that.

    Raises:
        ValueError: if ``self.dataset`` is not a supported name. Previously
        this surfaced as an ``UnboundLocalError``.
    """
    batch_size = batch_x.size()[0]

    # Dataset-specific feature extractor; everything after it is shared.
    if self.dataset == 'EMNIST':
        x = batch_x.view(-1, 1 * 28 * 28)
    elif self.dataset == 'FMNIST':
        x = self.pool(F.relu(self.conv1(batch_x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 12 * 4 * 4)
    elif self.dataset in ('CIFAR10', 'CIFAR100'):
        x = self.pool(F.relu(self.conv1(batch_x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
    else:
        raise ValueError('unsupported dataset: {}'.format(self.dataset))

    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))

    encoder_output = self.fc3(x)
    encoder_Z_distr = self.encoder_result(encoder_output)
    to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr,
                                       num_samples=num_samples)
    decoder_logits = self.fc4(to_decoder)

    return encoder_Z_distr, decoder_logits
162 |
163 |
def weight_init(self):
    """Xavier-initialise every ``nn.Linear`` / ``nn.Conv2d`` layer of this model.

    ``xavier_init`` iterates over its argument and type-checks each element.
    Passing it a single layer, as this method used to, made it iterate the
    layer itself, which raises ``TypeError`` ('Linear' object is not
    iterable). Hand it the model's module iterator instead.
    """
    xavier_init(self.modules())
167 |
168 |
def KL_between_normals(q_distr, p_distr):
    """Return KL(q || p) between diagonal Gaussians, one value per row.

    Args:
        q_distr: pair ``(mu_q, sigma_q)``; ``sigma`` holds standard deviations
            (not variances). Assumes 2-D ``(rows, k)`` tensors, since ``k`` is
            read from ``mu_q.size(1)``.
        p_distr: pair ``(mu_p, sigma_p)`` with the same convention.

    Returns:
        ``0.5 * (sum(sq^2/sp^2) + sum((mp-mq)^2/sp^2) - k + log|Sp| - log|Sq|)``
        with the sums taken over dim 1. Standard deviations are clamped to
        ``>= 1e-8`` inside the log-determinants only.
    """
    mean_q, std_q = q_distr
    mean_p, std_p = p_distr
    dim = mean_q.size(1)

    var_p = std_p ** 2
    trace_term = torch.sum(torch.div(std_q ** 2, var_p), dim=1)
    gap = mean_p - mean_q
    quad_term = torch.sum(torch.div(torch.mul(gap, gap), var_p), dim=1)

    log_det_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=1)
    log_det_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=1)

    return 0.5 * (trace_term + quad_term - dim + log_det_p - log_det_q)
182 |
183 |
def xavier_init(ms):
    """Xavier-uniform initialise Linear/Conv2d weights in ``ms`` and zero their biases.

    Args:
        ms: either a single ``nn.Module``, whose whole module tree (including
            itself) is visited, or an iterable of modules.

    Weights use ``gain=nn.init.calculate_gain('relu')``. Modules of any other
    type are left untouched, as are absent (``None``) biases.
    """
    # A bare layer such as ``nn.Linear`` is not iterable; walk its module tree.
    modules = ms.modules() if isinstance(ms, nn.Module) else ms
    gain = nn.init.calculate_gain('relu')
    for m in modules:
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # ``xavier_uniform`` (no trailing underscore) is deprecated.
            nn.init.xavier_uniform_(m.weight, gain=gain)
            if m.bias is not None:
                m.bias.data.zero_()
189 |
--------------------------------------------------------------------------------
/code_FedCR/main_fed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import ssl
8 |
9 | import copy
10 | import itertools
11 | import random
12 | import torch
13 | import numpy as np
14 | from utils.options import args_parser
15 | from utils.seed import setup_seed
16 | from utils.logg import get_logger
17 | from models.Nets import client_model
18 | from models.Nets_VIB import client_model_VIB
19 | from utils.utils_dataset import DatasetObject
20 | from models.distributed_training_utils import Client, Server
torch.set_printoptions(
    precision=8,
    threshold=1000,
    edgeitems=3,
    linewidth=150,
    profile=None,
    sci_mode=False
)
if __name__ == '__main__':
    # Federated training driver. It builds the global model, picks the keys
    # in ``w_glob_keys`` for the chosen ``args.method``, then runs
    # ``args.epochs`` communication rounds plus one final round in which every
    # client participates.

    # NOTE(review): this disables HTTPS certificate verification for the whole
    # process, presumably so dataset downloads succeed. Confirm it is needed.
    ssl._create_default_https_context = ssl._create_unverified_context
    # parse args
    args = args_parser()

    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    setup_seed(args.seed)

    data_path = 'Folder/'
    data_obj = DatasetObject(dataset=args.dataset, n_client=args.num_users, seed=args.seed, rule=args.rule,
                             class_main=args.class_main, data_path=data_path, frac_data=args.frac_data,
                             dir_alpha=args.dir_a)

    # build model
    # ``raise SystemExit(msg)`` behaves like ``exit(msg)`` but does not rely on
    # the ``site`` module having injected ``exit`` into builtins.
    if args.method == 'FedCR':
        net_glob = client_model_VIB(args, args.dimZ, args.alpha, args.dataset).to(args.device)
    elif args.dataset == 'CIFAR100':
        net_glob = client_model('cifar100_LeNet').to(args.device)
    elif args.dataset == 'CIFAR10':
        net_glob = client_model('cifar10_LeNet').to(args.device)
    elif args.dataset == 'EMNIST':
        net_glob = client_model('emnist_NN', [1 * 28 * 28, 10]).to(args.device)
    elif args.dataset == 'FMNIST':
        net_glob = client_model('FMNIST_CNN', [1 * 28 * 28, 10]).to(args.device)
    else:
        raise SystemExit('Error: unrecognized model')

    total_num_layers = len(net_glob.state_dict().keys())
    net_keys = [*net_glob.state_dict().keys()]

    # Select entries of ``net_glob.weight_keys`` by index. The resulting names
    # are handed to the clients as ``w_glob_keys`` (presumably the layers
    # shared with the server; confirm in Client). fedavg / ditto / maml pass
    # an empty list.
    # 'CIFAR100' is tested before 'CIFAR10' because these are substring tests.
    if args.method in ('fedrep', 'fedper', 'fedbabu'):
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3]]
        else:
            raise SystemExit('Error: unrecognized data1')
    elif args.method == 'lg':
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [1, 2]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [3, 4]]
        else:
            raise SystemExit('Error: unrecognized data2')
    elif args.method == 'FedCR':
        if 'CIFAR100' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        elif 'CIFAR10' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        elif 'EMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2]]
        elif 'FMNIST' in args.dataset:
            w_glob_keys = [net_glob.weight_keys[i] for i in [0, 1, 2, 3, 4]]
        else:
            raise SystemExit('Error: unrecognized data3')
    elif args.method in ('fedavg', 'ditto', 'maml'):
        w_glob_keys = []
    else:
        raise SystemExit('Error: unrecognized data4')

    w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))

    clients = [Client(model=copy.deepcopy(net_glob).to(args.device), args=args, trn_x=data_obj.clnt_x[i],
                      trn_y=data_obj.clnt_y[i], tst_x=data_obj.tst_x[i], tst_y=data_obj.tst_y[i],
                      n_cls=data_obj.n_cls, dataset_name=data_obj.dataset, id_num=i)
               for i in range(args.num_users)]
    server = Server(model=net_glob.to(args.device), args=args, n_cls=data_obj.n_cls)

    logger = get_logger(args.filepath)
    logger.info('--------args----------')
    for k, v in vars(args).items():
        logger.info('%s: %s' % (k, v))
    logger.info('--------args----------\n')
    logger.info('total_num_layers')
    logger.info(total_num_layers)
    logger.info('net_keys')
    logger.info(net_keys)
    logger.info('w_glob_keys')
    logger.info(w_glob_keys)

    logger.info('start training!')

    for rnd in range(args.epochs + 1):
        net_glob.train()

        # The extra final round (rnd == args.epochs) involves every client.
        last = rnd == args.epochs
        m = args.num_users if last else max(int(args.frac * args.num_users), 1)
        participating_clients = random.sample(clients, m)

        for client in participating_clients:
            if args.sync == 'True':
                client.synchronize_with_server(server, w_glob_keys)
            client.compute_weight_update(w_glob_keys, server, last)

        server.aggregate_weight_updates(clients=participating_clients, iter=rnd)

        if args.method == 'FedCR':
            server.global_POE(clients=participating_clients)

        # -----------------------------------------------test-----------------------------------------------
        if rnd % args.test_freq == args.test_freq - 1 or rnd >= args.epochs - 10:
            results_loss = []
            results_acc = []
            results_loss_last = []
            results_acc_last = []
            for client in clients:
                if args.method == 'FedCR':
                    results_test, loss_test1 = client.evaluate_FedVIB(data_x=client.tst_x, data_y=client.tst_y,
                                                                      dataset_name=data_obj.dataset)
                elif args.method != 'fedavg':
                    results_test, loss_test1 = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                                               dataset_name=data_obj.dataset)
                else:
                    results_test, loss_test1 = server.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                                               dataset_name=data_obj.dataset)
                results_loss.append(loss_test1)
                results_acc.append(results_test)

                # Only FedAvg reports an extra client-side ("FT") score in the
                # final round. This evaluation used to run for every method
                # and be discarded for all but FedAvg.
                if last and args.method == 'fedavg':
                    results_test_last, loss_test1_last = client.evaluate(data_x=client.tst_x, data_y=client.tst_y,
                                                                         dataset_name=data_obj.dataset)
                    results_loss_last.append(loss_test1_last)
                    results_acc_last.append(results_test_last)

            results_loss = np.mean(results_loss)
            results_acc = np.mean(results_acc)
            if last:
                logger.info('Final Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, results_loss, results_acc))
            else:
                logger.info('Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, results_loss, results_acc))

            if last and args.method == 'fedavg':
                results_loss_last = np.mean(results_loss_last)
                results_acc_last = np.mean(results_acc_last)
                logger.info('Final FT Epoch:[{}]\tlr =\t{:.5f}\tloss=\t{:.5f}\tacc_test=\t{:.5f}'.
                            format(rnd, args.lr, results_loss_last, results_acc_last))

        args.lr = args.lr * (args.lr_decay)

    logger.info('finish training!')
192 |
193 |
194 |
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
/code_FedCR/models/NetsSR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class Model_CMI(nn.Module):
    """Stochastic-representation classifier (FedSR variant, per the file name).

    The encoder is an MLP for EMNIST and a small CNN followed by an MLP
    otherwise. It outputs the mean and the softplus-transformed standard
    deviation of a ``dimZ``-dimensional Gaussian. A sample of that Gaussian is
    decoded by ``fc4`` into class logits.

    ``r_mu`` / ``r_sigma`` (shape ``(n_cls, dimZ)``) and the scalar ``C`` are
    extra learnable tensors. They are not used in ``forward``.

    Args:
        args: options object; only ``args.device`` is read.
        dimZ: latent dimension.
        alpha: offset subtracted before the softplus that produces sigma.
        dataset: one of 'EMNIST', 'FMNIST', 'CIFAR10', 'CIFAR100'.

    Raises:
        ValueError: for any other ``dataset``.
    """

    # Number of output classes for each supported dataset.
    _N_CLS = {'EMNIST': 10, 'FMNIST': 10, 'CIFAR10': 10, 'CIFAR100': 100}

    def __init__(self, args, dimZ=256, alpha=0, dataset = 'EMNIST'):
        super().__init__()

        self.alpha = alpha
        self.dimZ = dimZ
        self.device = args.device
        self.dataset = dataset

        if dataset not in self._N_CLS:
            raise ValueError('unsupported dataset: {}'.format(dataset))
        self.n_cls = self._N_CLS[dataset]

        # Create the tensors directly on the target device so they remain leaf
        # ``nn.Parameter``s registered on the module.
        # ``nn.Parameter(...).to(device)`` returns a plain, non-leaf tensor
        # whenever a real move happens (e.g. to CUDA). The values then
        # silently dropped out of ``parameters()`` / ``state_dict()``.
        self.r_mu = nn.Parameter(torch.zeros(self.n_cls, self.dimZ, device=self.device))
        self.r_sigma = nn.Parameter(torch.ones(self.n_cls, self.dimZ, device=self.device))
        self.C = nn.Parameter(torch.ones([], device=self.device))

        fc_keys = [['fc{}.weight'.format(i), 'fc{}.bias'.format(i)] for i in range(1, 5)]
        conv_keys = [['conv1.weight', 'conv1.bias'], ['conv2.weight', 'conv2.bias']]

        # Layers are created in the same order as before, so module order,
        # state_dict order and seeded initial weights are unchanged.
        if dataset == 'EMNIST':
            self.fc1 = nn.Linear(1 * 28 * 28, 1024)
            self.weight_keys = fc_keys
        elif dataset == 'FMNIST':
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5)
            self.fc1 = nn.Linear(12 * 4 * 4, 1024)
            self.weight_keys = conv_keys + fc_keys
        else:  # CIFAR10 / CIFAR100 share one architecture
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.fc1 = nn.Linear(64 * 5 * 5, 1024)
            self.weight_keys = conv_keys + fc_keys

        # Head shared by every dataset.
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 2 * self.dimZ)
        self.fc4 = nn.Linear(self.dimZ, self.n_cls)

    def gaussian_noise(self, num_samples, K):
        """Return N(0, 1) noise of shape ``(*num_samples, K)`` on ``self.device``.

        ``num_samples`` may be an int or a tuple of ints.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)
        # Sampled with the CPU generator and then moved, as before, so seeded
        # runs reproduce the same noise.
        return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).to(self.device)

    def sample_prior_Z(self, num_samples):
        """Draw latent vectors from the N(0, I) prior."""
        return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)

    def encoder_result(self, encoder_output):
        """Split the encoder output into ``(mu, sigma)``, with ``sigma = softplus(raw - alpha)``."""
        mu = encoder_output[:, :self.dimZ]
        sigma = F.softplus(encoder_output[:, self.dimZ:] - self.alpha)
        return mu, sigma

    def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples):
        """Return reparameterised samples ``mu + sigma * eps`` with leading dims ``(num_samples, batch_size)``."""
        mu, sigma = encoder_Z_distr
        return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)

    def _flatten_features(self, batch_x):
        """Apply the dataset-specific feature extractor and return the flattened input to ``fc1``."""
        if self.dataset == 'EMNIST':
            return batch_x.view(-1, 1 * 28 * 28)
        x = self.pool(F.relu(self.conv1(batch_x)))
        x = self.pool(F.relu(self.conv2(x)))
        if self.dataset == 'FMNIST':
            return x.view(-1, 12 * 4 * 4)
        return x.view(-1, 64 * 5 * 5)

    def forward(self, batch_x, num_samples=1):
        """Encode, sample and decode ``batch_x``.

        Returns:
            ``(encoder_Z_distr, decoder_logits, regL2R)``:
            - ``encoder_Z_distr`` is the ``(mu, sigma)`` pair.
            - ``decoder_logits`` has leading dims ``(num_samples, batch_size)``.
            - ``regL2R = torch.norm(z)`` is the norm over all sampled latents.
        """
        batch_size = batch_x.size()[0]
        x = F.relu(self.fc1(self._flatten_features(batch_x)))
        x = F.relu(self.fc2(x))

        encoder_Z_distr = self.encoder_result(self.fc3(x))
        to_decoder = self.sample_encoder_Z(batch_size=batch_size, encoder_Z_distr=encoder_Z_distr,
                                           num_samples=num_samples)
        decoder_logits = self.fc4(to_decoder)

        regL2R = torch.norm(to_decoder)

        return encoder_Z_distr, decoder_logits, regL2R
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | #regL2R = to_decoder_mearn.norm(dim=1).mean()
182 |
183 |
184 | '''
185 | class Model_CMI(nn.Module):
186 |
187 | def __init__(self, args):
188 | self.probabilistic = True
189 | super(client_model_SR, self).__init__(args)
190 |
191 | self.dimZ = args.dimZ
192 | self.device = args.device
193 | self.dataset = args.dataset
194 |
195 | if self.dataset == 'EMNIST':
196 | self.n_cls = 10
197 | if self.dataset == 'FMNIST':
198 | self.n_cls = 10
199 | if self.dataset == 'CIFAR10':
200 | self.n_cls = 10
201 | if self.dataset == 'CIFAR100':
202 | self.n_cls = 100
203 |
204 | self.r_mu = nn.Parameter(torch.zeros(args.num_classes, args.z_dim))
205 | self.r_sigma = nn.Parameter(torch.ones(args.num_classes, args.z_dim))
206 | self.C = nn.Parameter(torch.ones([]))
207 |
208 | self.optim.add_param_group({'params':[self.r_mu,self.r_sigma,self.C],'lr':args.lr,'momentum':0.9})
209 | '''
210 |
def KL_between_normals(q_distr, p_distr):
    """Return KL(q || p) for two diagonal Gaussians, summed over dim 1.

    ``q_distr`` and ``p_distr`` are ``(mean, std)`` pairs, where ``std`` is a
    standard deviation, not a variance. ``k`` is read from
    ``mean_q.size(1)``, so 2-D inputs are assumed. Standard deviations are
    floored at 1e-8 only inside the log-determinant terms.
    """
    mean_q, std_q = q_distr
    mean_p, std_p = p_distr

    def log_det(std):
        # log|diag(std^2)| with std floored at 1e-8 before the log
        return torch.sum(2 * torch.log(torch.clamp(std, min=1e-8)), dim=1)

    denom = std_p ** 2
    delta = mean_p - mean_q
    ratio_sum = torch.sum(torch.div(std_q ** 2, denom), dim=1)
    shift_sum = torch.sum(torch.div(torch.mul(delta, delta), denom), dim=1)
    k = mean_q.size(1)

    twice_kl = ratio_sum + shift_sum - k + log_det(std_p) - log_det(std_q)
    return twice_kl * 0.5
224 |
225 |
def xavier_init(ms):
    """Xavier-uniform initialise Linear/Conv2d weights in ``ms`` and zero their biases.

    Args:
        ms: either a single ``nn.Module``, whose whole module tree (including
            itself) is visited, or an iterable of modules.

    Weights use ``gain=nn.init.calculate_gain('relu')``. Other module types and
    absent (``None``) biases are left untouched.
    """
    # A bare layer such as ``nn.Linear`` is not iterable; walk its module tree.
    modules = ms.modules() if isinstance(ms, nn.Module) else ms
    gain = nn.init.calculate_gain('relu')
    for m in modules:
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # ``xavier_uniform`` (no trailing underscore) is deprecated.
            nn.init.xavier_uniform_(m.weight, gain=gain)
            if m.bias is not None:
                m.bias.data.zero_()
231 |
--------------------------------------------------------------------------------
/code_FedCR/models/distributed_training_utils_ditto.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import math
4 | from torch import nn, autograd
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from utils.utils_dataset import Dataset
8 | from models.Nets_VIB import KL_between_normals
9 | import itertools
10 | import torch.nn.functional as F
11 |
12 | max_norm = 10
13 |
def add(target, source):
    """Add ``source`` into ``target`` in place: ``target[k] += source[k]`` for each key of ``target``."""
    for key, tensor in target.items():
        tensor.data += source[key].data.clone()
17 |
18 |
def add_mome(target, source, beta_):
    """Apply a momentum-style update in place: ``target[k] <- beta_ * target[k] + source[k]``."""
    for key in target:
        decayed = beta_ * target[key].data
        target[key].data = decayed + source[key].data.clone()
22 |
23 |
def add_mome2(target, source1, source2, beta_1, beta_2):
    """Overwrite ``target[k]`` with ``beta_1 * source1[k] + beta_2 * source2[k]``."""
    for key in target:
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        target[key].data = first + second
27 |
28 |
def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3):
    """Overwrite ``target[k]`` with ``beta_1*source1[k] + beta_2*source2[k] + beta_3*source3[k]``."""
    for key in target:
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        third = beta_3 * source3[key].data.clone()
        target[key].data = first + second + third
32 |
def add_2(target, source1, source2, beta_1, beta_2):
    """Update in place: ``target[k] += beta_1 * source1[k] + beta_2 * source2[k]``."""
    for key in target:
        increment = beta_1 * source1[key].data.clone() + beta_2 * source2[key].data.clone()
        target[key].data += increment
36 |
def scale(target, scaling):
    """Multiply every ``target[k]`` by ``scaling``, replacing its data."""
    for key in target:
        current = target[key].data.clone()
        target[key].data = scaling * current
40 |
41 |
def scale_ts(target, source, scaling):
    """Set ``target[k] = scaling * source[k]`` for every key of ``target``."""
    for key in target:
        copied = source[key].data.clone()
        target[key].data = scaling * copied
45 |
46 |
def subtract(target, source):
    """Subtract ``source`` from ``target`` in place: ``target[k] -= source[k]``."""
    for key, tensor in target.items():
        tensor.data -= source[key].data.clone()
50 |
51 |
def subtract_(target, minuend, subtrahend):
    """Overwrite ``target[k]`` with ``minuend[k] - subtrahend[k]``."""
    for key in target:
        lhs = minuend[key].data.clone()
        rhs = subtrahend[key].data.clone()
        target[key].data = lhs - rhs
55 |
56 |
def average(target, sources):
    """Set each ``target[k]`` to the element-wise mean of ``source[k]`` over ``sources``."""
    for key in target:
        stacked = torch.stack([src[key].data for src in sources])
        target[key].data = torch.mean(stacked, dim=0).clone()
60 |
61 |
def weighted_average(target, sources, weights):
    """Set each ``target[k]`` to the ``weights``-weighted mean of ``sources[i][k]``.

    Args:
        target: mapping name -> tensor/parameter, overwritten in place.
        sources: sequence of mappings holding the same keys as ``target``.
        weights: one weight per source, presumably a 1-D tensor (it is passed
            to ``torch.sum``). The weights are normalised here, so they need
            not sum to one.

    The result is ``sum_i (w_i / sum(w)) * sources[i][k]``. It is computed as
    the mean of tensors scaled by ``n * w_i / sum(w)``, as before.
    """
    # These depend only on ``weights``: compute them once, not once per key.
    summ = torch.sum(weights)
    n = len(sources)
    modify = [weight / summ * n for weight in weights]
    for name in target:
        target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]),
                                       dim=0).clone()
69 |
70 |
def computer_norm(source1, source2):
    """Return the Euclidean norm of ``source1 - source2`` taken jointly over all keys of ``source1``."""
    squared_total = sum(
        torch.pow(torch.norm(source1[key].data.clone() - source2[key].data.clone()), 2)
        for key in source1
    )
    return torch.pow(squared_total, 0.5)
79 |
def majority_vote(target, sources, lr):
    """Write ``lr * sign(sum_i sign(sources[i][k]))`` into each ``target[k]``.

    Each element gets ``+lr`` or ``-lr`` according to the majority sign across
    ``sources``, or 0 on a tie. The unused per-source maxima that were computed
    here before have been removed.
    """
    for name in target:
        mask = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0).sign()
        target[name].data = (lr * mask).clone()
85 |
86 |
def get_mdl_params(model_list, n_par=None):
    """Flatten each model's parameters into one row of a float32 matrix.

    Args:
        model_list: non-empty list of ``nn.Module``s.
        n_par: number of scalars per row. When ``None`` it is the total
            parameter count of ``model_list[0]``.

    Returns:
        A new ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype
        float32. Row ``i`` is the concatenation of model ``i``'s
        ``named_parameters()`` in iteration order.
    """
    if n_par is None:
        n_par = sum(param.numel() for _, param in model_list[0].named_parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for _, param in mdl.named_parameters():
            temp = param.data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + len(temp)] = temp
            idx += len(temp)
    # ``param_mat`` was freshly allocated above, so no defensive copy is needed.
    return param_mat
102 |
103 |
def get_other_params(model_list, n_par=None):
    """Flatten each name -> tensor mapping into one row of a float32 matrix.

    Args:
        model_list: non-empty list of mappings (e.g. the ``W`` dicts), iterated
            in key order.
        n_par: number of scalars per row. When ``None`` it is the total element
            count of ``model_list[0]``.

    Returns:
        A new ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32.
    """
    if n_par is None:
        first = model_list[0]
        n_par = sum(first[name].data.numel() for name in first)

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for name in mdl:
            temp = mdl[name].data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + len(temp)] = temp
            idx += len(temp)
    # ``param_mat`` was freshly allocated above, so no defensive copy is needed.
    return param_mat
119 |
120 |
121 |
122 |
123 |
124 |
125 |
class DistributedTrainingDevice(object):
    '''
    Common base of ``Client`` and ``Server``.

    model : the ``nn.Module`` this device trains and/or evaluates
    args  : parsed run options shared with subclasses (e.g. ``args.device``,
            ``args.lr``)

    Also sets ``self.loss_func`` to a cross-entropy loss used for training.
    '''

    def __init__(self, model, args):
        self.model = model
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
138 |
class Client(DistributedTrainingDevice):
    '''
    Ditto client holding two models:

    * ``self.model``: replaced by a copy of the server model in
      ``synchronize_with_server`` and trained with plain SGD in ``train_cnn``.
    * ``self.model_local``: the personalised model trained in
      ``train_cnn_local`` (with the proximal term set up by ``compute_bias``)
      and scored by ``evaluate``.
    '''

    def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0):
        super().__init__(model, args)

        self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name),
                                  batch_size=self.args.local_bs, shuffle=True)

        self.model_local = copy.deepcopy(model).to(args.device)

        self.tst_x = tst_x
        self.tst_y = tst_y
        self.n_cls = n_cls
        self.id = id_num
        # Mini-batches per pass over the training data (equals len(self.trn_gen)).
        self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs))
        # Parameters: ``W`` holds live references to ``self.model``'s parameters.
        self.W = {name: value for name, value in self.model.named_parameters()}
        self.W_old = {name: torch.zeros(value.shape).to(self.args.device) for name, value in self.W.items()}
        self.V = {name: torch.zeros(value.shape).to(self.args.device) for name, value in self.W.items()}
        self.state_params_diff = 0.0
        self.train_loss = 0.0
        self.n_par = get_mdl_params([self.model]).shape[1]

    def synchronize_with_server(self, server, w_glob_keys):
        '''Replace ``self.model`` with a deep copy of ``server.model``. ``w_glob_keys`` is unused.'''
        # W_client = W_server
        self.model = copy.deepcopy(server.model)
        self.W = {name: value for name, value in self.model.named_parameters()}

    def compute_bias(self):
        '''Store ``-mu * w`` as the linear part of Ditto's proximal penalty.

        ``w`` is the flattened ``self.model``. ``train_cnn_local`` adds
        ``<w_local, state_params_diff>`` to the loss and raises weight decay by
        ``mu``. Together these contribute ``mu * (w_local - w)`` to the
        personalised model's update.

        NOTE(review): ``compute_weight_update`` does not call this. The caller
        must, otherwise ``state_params_diff`` keeps its previous value
        (initially 0.0).
        '''
        if self.args.method == 'ditto':
            cld_mdl_param = torch.tensor(get_mdl_params([self.model], self.n_par)[0], dtype=torch.float32,
                                         device=self.args.device)
            self.state_params_diff = self.args.mu * (-cld_mdl_param)

    def _run_epochs(self, model, optimizer, extra_loss=None):
        '''Run ``args.local_ep`` passes of SGD over ``self.trn_gen`` and return the mean per-epoch loss.

        ``extra_loss``, if given, is called with ``model`` at every step and
        added to the cross-entropy loss. Gradients are clipped to ``max_norm``.
        '''
        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.trn_gen:
                images, labels = images.to(self.args.device), labels.to(self.args.device)

                optimizer.zero_grad()
                log_probs = model(images)
                loss = self.loss_func(log_probs, labels.reshape(-1).long())
                if extra_loss is not None:
                    loss = loss + extra_loss(model)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_norm)
                optimizer.step()
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return sum(epoch_loss) / len(epoch_loss)

    def _proximal_term(self, model):
        '''Return ``<flatten(model params), state_params_diff>``, the linear part of the proximal penalty.'''
        # A single ``torch.cat`` replaces the old grow-by-one concatenation,
        # which re-copied the growing vector for every parameter tensor.
        flat = torch.cat([param.reshape(-1) for param in model.parameters()])
        return torch.sum(flat * self.state_params_diff)

    def train_cnn(self, w_glob_keys, server, last):
        '''Train ``self.model`` with SGD and return the mean epoch loss.

        ``w_glob_keys``, ``server`` and ``last`` are accepted for interface
        compatibility and are not used.

        Raises:
            ValueError: if ``args.method`` is not 'ditto'. Previously this
            failed later with an ``UnboundLocalError`` on ``optimizer``.
        '''
        self.model.train()

        if self.args.method != 'ditto':
            raise ValueError("train_cnn only supports args.method == 'ditto', got {!r}".format(self.args.method))
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
                                    weight_decay=self.args.weigh_delay)

        return self._run_epochs(self.model, optimizer)

    def train_cnn_local(self, w_glob_keys, server, last):
        '''Train the personalised ``self.model_local`` and return the mean epoch loss.

        Weight decay is raised by ``args.mu``, and ``_proximal_term`` adds the
        linear part prepared by ``compute_bias``. ``w_glob_keys``, ``server``
        and ``last`` are unused.
        '''
        self.model_local.train()

        optimizer = torch.optim.SGD(self.model_local.parameters(), lr=self.args.lr, momentum=self.args.momentum,
                                    weight_decay=self.args.weigh_delay + self.args.mu)

        return self._run_epochs(self.model_local, optimizer, extra_loss=self._proximal_term)

    def compute_weight_update(self, w_glob_keys, server, last=False):
        '''Train the shared model, then the personalised one, and record both losses.'''
        # Training mode
        self.model.train()
        self.train_loss = self.train_cnn(w_glob_keys, server, last)

        self.model_local.train()
        self.train_loss_local = self.train_cnn_local(w_glob_keys, server, last)

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        '''Return ``(accuracy in percent, mean cross-entropy)`` of ``self.model_local`` on ``(data_x, data_y)``.'''
        self.model_local.eval()
        test_loss = 0
        acc_overall = 0
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        for data, target in tst_gen:
            data, target = data.to(self.args.device), target.to(self.args.device)
            log_probs = self.model_local(data)
            # sum up batch loss
            test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item()
            # the index of the largest output is the predicted class
            y_pred = np.argmax(log_probs.cpu().detach().numpy(), axis=1).reshape(-1)
            y_true = target.cpu().numpy().reshape(-1).astype(np.int32)
            acc_overall += np.sum(y_pred == y_true)

        test_loss /= n_tst
        accuracy = 100.00 * acc_overall / n_tst
        return accuracy, test_loss
300 |
301 |
302 | '''
303 | ------------------------------------------------------------------------------------------------------------------------
304 |
305 | ------------------------------------------------------------------------------------------------------------------------
306 | '''
307 |
308 |
class Server(DistributedTrainingDevice):
    '''
    Server for the Ditto driver: holds the global ``model`` and averages the
    clients' parameters into it.
    '''

    def __init__(self, model, args, n_cls):
        super().__init__(model, args)

        # Live references to ``self.model``'s parameters; assigning ``.data``
        # through ``self.W`` updates the model itself.
        self.W = {name: value for name, value in self.model.named_parameters()}
        self.local_epoch = 0
        self.n_cls = n_cls

    def aggregate_weight_updates(self, clients, iter, aggregation="mean"):
        '''Average the participating clients' ``W`` into the server parameters in place.

        Args:
            clients: participating clients; each exposes ``W`` and ``local_epoch``.
            iter: current round index (unused; kept for interface compatibility).
            aggregation: only ``"mean"`` is implemented.

        Raises:
            ValueError: for any other ``aggregation`` value, which used to be a
            silent no-op.
        '''
        # Warning: Note that K is different for unbalanced dataset
        self.local_epoch = clients[0].local_epoch
        if aggregation != "mean":
            raise ValueError('unsupported aggregation: {!r}'.format(aggregation))
        # dW = aggregate(dW_i, i=1,..,n)
        average(target=self.W, sources=[client.W for client in clients])

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        '''Return ``(accuracy in percent, mean cross-entropy)`` of the global model on ``(data_x, data_y)``.'''
        self.model.eval()
        test_loss = 0
        acc_overall = 0
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        for data, target in tst_gen:
            data, target = data.to(self.args.device), target.to(self.args.device)
            log_probs = self.model(data)
            # sum up batch loss
            test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item()
            # the index of the largest output is the predicted class
            y_pred = np.argmax(log_probs.cpu().detach().numpy(), axis=1).reshape(-1)
            y_true = target.cpu().numpy().reshape(-1).astype(np.int32)
            acc_overall += np.sum(y_pred == y_true)

        test_loss /= n_tst
        accuracy = 100.00 * acc_overall / n_tst
        return accuracy, test_loss
358 |
359 |
360 |
361 |
--------------------------------------------------------------------------------
/code_FedCR/models/distributed_training_utilsSR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import math
4 | from torch import nn, autograd
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from utils.utils_dataset import Dataset
8 | from models.Nets_VIB import KL_between_normals
9 | import itertools
10 | import torch.nn.functional as F
11 |
# Gradient-norm ceiling passed to clip_grad_norm_ in the client training loop.
max_norm: int = 10
13 |
def add(target, source):
    """Accumulate ``source`` into ``target`` in place: ``target[k] += source[k]`` for each key of ``target``."""
    for key in target.keys():
        increment = source[key].data.clone()
        target[key].data += increment
17 |
18 |
def add_mome(target, source, beta_):
    """Momentum-style update in place: ``target[k] = beta_ * target[k] + source[k]``."""
    for key in target.keys():
        decayed = beta_ * target[key].data
        target[key].data = decayed + source[key].data.clone()
22 |
23 |
def add_mome2(target, source1, source2, beta_1, beta_2):
    """Overwrite ``target[k]`` with ``beta_1 * source1[k] + beta_2 * source2[k]`` for every key."""
    for key in target.keys():
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        target[key].data = first + second
27 |
28 |
def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3):
    """Overwrite ``target[k]`` with ``beta_1*source1[k] + beta_2*source2[k] + beta_3*source3[k]``."""
    for key in target.keys():
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        third = beta_3 * source3[key].data.clone()
        target[key].data = first + second + third
32 |
def add_2(target, source1, source2, beta_1, beta_2):
    """In place: ``target[k] += beta_1 * source1[k] + beta_2 * source2[k]`` for every key."""
    for key in target.keys():
        delta = beta_1 * source1[key].data.clone() + beta_2 * source2[key].data.clone()
        target[key].data += delta
36 |
def scale(target, scaling):
    """Multiply every tensor in ``target`` by ``scaling`` (result written back into ``target``)."""
    for key in target.keys():
        current = target[key].data.clone()
        target[key].data = scaling * current
40 |
41 |
def scale_ts(target, source, scaling):
    """Overwrite ``target[k]`` with ``scaling * source[k]`` for every key of ``target``."""
    for key in target.keys():
        copied = source[key].data.clone()
        target[key].data = scaling * copied
45 |
46 |
def subtract(target, source):
    """In place: ``target[k] -= source[k]`` for every key of ``target``."""
    for key in target.keys():
        decrement = source[key].data.clone()
        target[key].data -= decrement
50 |
51 |
def subtract_(target, minuend, subtrahend):
    """Overwrite ``target[k]`` with ``minuend[k] - subtrahend[k]`` for every key of ``target``."""
    for key in target.keys():
        left = minuend[key].data.clone()
        right = subtrahend[key].data.clone()
        target[key].data = left - right
55 |
56 |
def average(target, sources):
    """Overwrite each ``target[k]`` with the element-wise mean of ``source[k]`` over ``sources``."""
    for key in target.keys():
        stacked = torch.stack([src[key].data for src in sources])
        target[key].data = stacked.mean(dim=0).clone()
60 |
61 |
def weighted_average(target, sources, weights):
    """Overwrite each ``target[k]`` with the ``weights``-weighted average of ``source[k]``.

    Args:
        target: mapping of name -> tensor/parameter, written in place.
        sources: sequence of mappings with the same keys as ``target``.
        weights: tensor of non-normalised weights, one per source.
    """
    # The normalisation does not depend on the parameter name, so compute it once.
    # Each weight is rescaled by n so that the mean over sources below equals the
    # normalised weighted sum.
    n = len(sources)
    summ = torch.sum(weights)
    modify = [weight / summ * n for weight in weights]
    for name in target:
        target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]),
                                       dim=0).clone()
69 |
70 |
def computer_norm(source1, source2):
    """Return the global L2 distance between two name -> tensor mappings.

    Computes ``sqrt(sum_k ||source1[k] - source2[k]||^2)`` over the keys of ``source1``.
    Returns a 0-dim tensor; an empty ``source1`` yields ``tensor(0.)`` instead of the
    TypeError that ``torch.pow(0, 0.5)`` would raise.
    """
    diff_norm = 0
    for name in source1:
        # Subtraction already allocates a new tensor, so no clone is needed.
        diff_source = source1[name].data - source2[name].data
        diff_norm += torch.pow(torch.norm(diff_source), 2)

    if not torch.is_tensor(diff_norm):
        return torch.tensor(0.0)
    return torch.pow(diff_norm, 0.5)
79 |
def majority_vote(target, sources, lr):
    """Sign-majority aggregation: ``target[k] = lr * sign(sum_i sign(source_i[k]))``.

    Ties (equal numbers of positive and negative votes) produce 0.
    """
    for name in target:
        votes = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0)
        target[name].data = (lr * votes.sign()).clone()
85 |
86 |
def get_mdl_params(model_list, n_par=None):
    """Flatten the parameters of each model into one row of a float32 matrix.

    Args:
        model_list: list of ``nn.Module`` objects with identical parameter layouts.
        n_par: total number of scalars per model; computed from ``model_list[0]`` when None.

    Returns:
        ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32. Row i holds the
        parameters of model i in ``named_parameters()`` order.
    """
    if n_par is None:
        n_par = sum(param.data.numel() for _, param in model_list[0].named_parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for _, param in mdl.named_parameters():
            flat = param.data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + flat.size] = flat
            idx += flat.size
    # param_mat is a fresh local array, so no defensive copy is needed.
    return param_mat
102 |
103 |
def get_other_params(model_list, n_par=None):
    """Flatten each name -> tensor mapping in ``model_list`` into one row of a float32 matrix.

    Args:
        model_list: list of mappings with identical keys and tensor sizes.
        n_par: total number of scalars per mapping; computed from ``model_list[0]`` when None.

    Returns:
        ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32, with values laid
        out in each mapping's key iteration order.
    """
    if n_par is None:
        exp_mdl = model_list[0]
        n_par = sum(exp_mdl[name].data.numel() for name in exp_mdl)

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for name in mdl:
            flat = mdl[name].data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + flat.size] = flat
            idx += flat.size
    return param_mat
119 |
120 |
121 |
122 |
123 |
124 |
class DistributedTrainingDevice:
    """Common base for federated participants (clients and the server).

    Attributes:
        model: the PyTorch module this participant trains or evaluates.
        args: options namespace shared by the experiment.
        loss_func: an ``nn.CrossEntropyLoss`` with default (mean) reduction.
    """

    def __init__(self, model, args):
        self.args = args
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()
137 |
class Client(DistributedTrainingDevice):
    """Federated client for the variational ('fedSR') training objective.

    Owns a shuffled loader over its local training shard and keeps its test split for
    evaluation. ``self.W`` maps parameter names to the parameters that the server
    aggregates.
    """

    def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0):
        super().__init__(model, args)

        self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name),
                                  batch_size=self.args.local_bs, shuffle=True)

        self.tst_x = tst_x
        self.tst_y = tst_y
        self.n_cls = n_cls
        self.id = id_num
        # Number of mini-batches in one pass over the local training set.
        self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs))
        # Name -> parameter of the local model (shares storage with self.model).
        self.W = {name: value for name, value in self.model.named_parameters()}

        self.state_params_diff = 0.0
        self.train_loss = 0.0
        self.n_par = get_mdl_params([self.model]).shape[1]

    def synchronize_with_server(self, server, w_glob_keys):
        """Pull weights from the server.

        For methods other than 'fedSR', only parameters named in ``w_glob_keys`` are copied.
        For 'fedSR', ``self.W`` is rebound to the server model's parameters.
        """
        if self.args.method != 'fedSR':
            for name in self.W:
                if name in w_glob_keys:
                    self.W[name].data = server.W[name].data.clone()
        else:
            # NOTE(review): this rebinds self.W to the *server's* parameter objects but leaves
            # self.model unchanged, so local training (on self.model) and aggregation (reading
            # self.W) see different tensors. Confirm against main_fedSR.py that this is intended.
            self.W = {name: value for name, value in server.model.named_parameters()}

    def _last_round_glob_keys(self, w_glob_keys):
        """Return the flat list of parameter names frozen during the final 'fedSR' round.

        'CIFAR100' is tested before 'CIFAR10' because the latter is a substring of the former.
        For an unrecognised dataset the names passed in are returned, flattened one level.
        """
        for tag, groups in (('CIFAR100', (0, 1, 2, 3)),
                            ('CIFAR10', (0, 1, 2, 3)),
                            ('EMNIST', (0, 1)),
                            ('FMNIST', (0, 1, 2, 3))):
            if tag in self.args.dataset:
                return list(itertools.chain.from_iterable(self.model.weight_keys[i] for i in groups))
        # Only nested groups are flattened: chaining a flat list of names would split
        # every name into single characters.
        flat = []
        for key in w_glob_keys:
            if isinstance(key, str):
                flat.append(key)
            else:
                flat.extend(key)
        return flat

    def train_cnn(self, w_glob_keys, server, last):
        """Run local SGD and return the mean (over epochs) of the per-epoch average loss.

        The loss is the Monte-Carlo cross-entropy, plus ``args.CMI`` times the closed-form
        Gaussian KL between the scaled encoder distribution and the per-class
        ``(r_mu, softplus(r_sigma))``, plus ``args.L2R`` times ``regL2R / batch``.
        When ``last`` is True, parameters named in the (possibly recomputed) ``w_glob_keys``
        are frozen.
        """
        self.model.train()

        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
                                    weight_decay=self.args.weigh_delay)
        # gamma=1: the schedule never changes the learning rate.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)

        local_eps = self.args.local_ep
        if last:
            if self.args.method == 'fedSR':
                local_eps = self.args.last_local_ep
                w_glob_keys = self._last_round_glob_keys(w_glob_keys)
            elif 'maml' in self.args.method:
                # Integer division: range() below rejects the float produced by "/ 2".
                local_eps = self.args.last_local_ep // 2
                w_glob_keys = []
            else:
                local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep)
        else:
            # All other rounds update every parameter simultaneously.
            for param in self.model.parameters():
                param.requires_grad = True

        self.dir_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        self.dir_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)

        ce_per_sample = nn.CrossEntropyLoss(reduction='none')
        epoch_loss = []
        for _ in range(local_eps):
            if last:
                for name, param in self.model.named_parameters():
                    param.requires_grad = name not in w_glob_keys

            loss_by_epoch = []
            trn_gen_iter = iter(self.trn_gen)
            for _ in range(self.local_epoch):
                images, labels = next(trn_gen_iter)
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                labels = labels.reshape(-1).long()

                encoder_Z_distr, decoder_logits, regL2R = self.model(images, self.args.num_avg_train)

                # Cross-entropy per (sample, Monte-Carlo draw): the target has shape
                # (batch, num_avg_train), so the permuted logits are (batch, n_cls, num_avg_train).
                cross_entropy_loss = ce_per_sample(decoder_logits.permute(1, 2, 0),
                                                   labels[:, None].expand(-1, self.args.num_avg_train))
                # Estimate of E_{eps ~ N(0, 1)}[-log q(y | z)], averaged over the batch.
                minusI_ZY_bound = torch.mean(torch.mean(cross_entropy_loss, dim=-1), dim=0)

                r_sigma = F.softplus(self.model.r_sigma)[labels]
                r_mu = self.model.r_mu[labels]
                z_mu_scaled = encoder_Z_distr[0] * self.model.C
                z_sigma_scaled = encoder_Z_distr[1] * self.model.C
                # KL(N(z_mu, z_sigma^2) || N(r_mu, r_sigma^2)), per dimension.
                regCMI = torch.log(r_sigma) - torch.log(z_sigma_scaled) + \
                    (z_sigma_scaled ** 2 + (z_mu_scaled - r_mu) ** 2) / (2 * r_sigma ** 2) - 0.5
                regCMI = regCMI.sum(1).mean()
                regL2R = regL2R / len(labels)

                total_loss = torch.mean(minusI_ZY_bound) + self.args.CMI * regCMI + self.args.L2R * regL2R

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm)
                optimizer.step()

                loss_by_epoch.append(total_loss.item())

            scheduler.step()
            epoch_loss.append(sum(loss_by_epoch) / len(loss_by_epoch))

        return sum(epoch_loss) / len(epoch_loss)

    def compute_weight_update(self, w_glob_keys, server, last=False):
        """Train locally and store the average training loss in ``self.train_loss``."""
        self.model.train()
        self.train_loss = self.train_cnn(w_glob_keys, server, last)

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        """Evaluate the local model and return ``(accuracy_percent, loss)``.

        Both values are averaged per sample. Each batch is weighted by its size, so a
        smaller final batch is not over-counted. The loss is the Monte-Carlo cross-entropy
        plus ``args.beta`` times the KL to a standard-normal prior. Returns
        ``(nan, nan)`` for an empty split.
        """
        self.model.eval()
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        ce_per_sample = nn.CrossEntropyLoss(reduction='none')

        n_seen = 0
        loss_sum = 0.0
        n_correct = 0
        tst_gen_iter = iter(tst_gen)
        for _ in range(int(np.ceil(n_tst / self.args.bs))):
            data, target = next(tst_gen_iter)
            data, target = data.to(self.args.device), target.to(self.args.device)
            target = target.reshape(-1).long()
            batch_size = data.size()[0]
            prior_Z_distr = (torch.zeros(batch_size, self.args.dimZ).to(self.args.device),
                             torch.ones(batch_size, self.args.dimZ).to(self.args.device))
            encoder_Z_distr, decoder_logits, _ = self.model(data, self.args.num_avg)

            decoder_logits_mean = torch.mean(decoder_logits, dim=0)
            cross_entropy_loss = ce_per_sample(decoder_logits.permute(1, 2, 0),
                                               target[:, None].expand(-1, self.args.num_avg))
            minusI_ZY_bound_test = torch.mean(torch.mean(cross_entropy_loss, dim=-1), dim=0)
            I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
            total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test)

            prediction = torch.max(decoder_logits_mean, dim=1)[1]
            loss_sum += total_loss_test.item() * batch_size
            n_correct += (prediction == target).sum().item()
            n_seen += batch_size

        if n_seen == 0:
            return float('nan'), float('nan')
        return 100.00 * n_correct / n_seen, loss_sum / n_seen
340 |
341 | '''
342 | ------------------------------------------------------------------------------------------------------------------------
343 |
344 | ------------------------------------------------------------------------------------------------------------------------
345 | '''
346 |
347 |
class Server(DistributedTrainingDevice):
    """Central server for the variational objective: holds the global model and averages clients."""

    def __init__(self, model, args, n_cls):
        super().__init__(model, args)

        # Name -> parameter of the global model (shares storage with self.model).
        self.W = {name: value for name, value in self.model.named_parameters()}
        self.local_epoch = 0
        self.n_cls = n_cls
        if self.args.method == 'FedCR':
            self.dir_global_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
            self.dir_global_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)

    def aggregate_weight_updates(self, clients, iter, aggregation="mean"):
        """Average all clients' ``W`` into ``self.W`` when ``aggregation == "mean"``; ``iter`` is unused."""
        # NOTE: batches-per-pass differs between clients on unbalanced data; clients[0]'s value is kept.
        self.local_epoch = clients[0].local_epoch
        if aggregation == "mean":
            average(target=self.W, sources=[client.W for client in clients])

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        """Evaluate the global model and return ``(accuracy_percent, loss)``.

        Both values are averaged per sample, with each batch weighted by its size. The loss
        is the Monte-Carlo cross-entropy plus ``args.beta`` times the KL to a
        standard-normal prior. Returns ``(nan, nan)`` for an empty split.
        """
        self.model.eval()
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        ce_per_sample = nn.CrossEntropyLoss(reduction='none')

        n_seen = 0
        loss_sum = 0.0
        n_correct = 0
        tst_gen_iter = iter(tst_gen)
        for _ in range(int(np.ceil(n_tst / self.args.bs))):
            data, target = next(tst_gen_iter)
            data, target = data.to(self.args.device), target.to(self.args.device)
            target = target.reshape(-1).long()
            batch_size = data.size()[0]
            prior_Z_distr = (torch.zeros(batch_size, self.args.dimZ).to(self.args.device),
                             torch.ones(batch_size, self.args.dimZ).to(self.args.device))
            # Bug fix: the model output was previously bound to `decoder_logit`, but the code below
            # reads `decoder_logits`, which raised NameError on the first batch.
            encoder_Z_distr, decoder_logits, _ = self.model(data, self.args.num_avg)

            decoder_logits_mean = torch.mean(decoder_logits, dim=0)
            cross_entropy_loss = ce_per_sample(decoder_logits.permute(1, 2, 0),
                                               target[:, None].expand(-1, self.args.num_avg))
            minusI_ZY_bound_test = torch.mean(torch.mean(cross_entropy_loss, dim=-1), dim=0)
            I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
            total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test)

            prediction = torch.max(decoder_logits_mean, dim=1)[1]
            loss_sum += total_loss_test.item() * batch_size
            n_correct += (prediction == target).sum().item()
            n_seen += batch_size

        if n_seen == 0:
            return float('nan'), float('nan')
        return 100.00 * n_correct / n_seen, loss_sum / n_seen
420 |
421 |
422 |
423 |
--------------------------------------------------------------------------------
/code_FedCR/models/distributed_training_utils_PAC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import math
4 | from torch import nn, autograd
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from utils.utils_dataset import Dataset
8 | from models.Nets_VIB import KL_between_normals
9 | import itertools
10 | import torch.nn.functional as F
11 |
12 |
# Gradient-norm ceiling passed to clip_grad_norm_ in the client training loop.
max_norm: int = 10
14 |
def add(target, source):
    """Accumulate ``source`` into ``target`` in place: ``target[k] += source[k]`` for each key of ``target``."""
    for key in target.keys():
        increment = source[key].data.clone()
        target[key].data += increment
18 |
19 |
def add_mome(target, source, beta_):
    """Momentum-style update in place: ``target[k] = beta_ * target[k] + source[k]``."""
    for key in target.keys():
        decayed = beta_ * target[key].data
        target[key].data = decayed + source[key].data.clone()
23 |
24 |
def add_mome2(target, source1, source2, beta_1, beta_2):
    """Overwrite ``target[k]`` with ``beta_1 * source1[k] + beta_2 * source2[k]`` for every key."""
    for key in target.keys():
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        target[key].data = first + second
28 |
29 |
def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3):
    """Overwrite ``target[k]`` with ``beta_1*source1[k] + beta_2*source2[k] + beta_3*source3[k]``."""
    for key in target.keys():
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        third = beta_3 * source3[key].data.clone()
        target[key].data = first + second + third
33 |
def add_2(target, source1, source2, beta_1, beta_2):
    """In place: ``target[k] += beta_1 * source1[k] + beta_2 * source2[k]`` for every key."""
    for key in target.keys():
        delta = beta_1 * source1[key].data.clone() + beta_2 * source2[key].data.clone()
        target[key].data += delta
37 |
def scale(target, scaling):
    """Multiply every tensor in ``target`` by ``scaling`` (result written back into ``target``)."""
    for key in target.keys():
        current = target[key].data.clone()
        target[key].data = scaling * current
41 |
42 |
def scale_ts(target, source, scaling):
    """Overwrite ``target[k]`` with ``scaling * source[k]`` for every key of ``target``."""
    for key in target.keys():
        copied = source[key].data.clone()
        target[key].data = scaling * copied
46 |
47 |
def subtract(target, source):
    """In place: ``target[k] -= source[k]`` for every key of ``target``."""
    for key in target.keys():
        decrement = source[key].data.clone()
        target[key].data -= decrement
51 |
52 |
def subtract_(target, minuend, subtrahend):
    """Overwrite ``target[k]`` with ``minuend[k] - subtrahend[k]`` for every key of ``target``."""
    for key in target.keys():
        left = minuend[key].data.clone()
        right = subtrahend[key].data.clone()
        target[key].data = left - right
56 |
57 |
def average(target, sources):
    """Overwrite each ``target[k]`` with the element-wise mean of ``source[k]`` over ``sources``."""
    for key in target.keys():
        stacked = torch.stack([src[key].data for src in sources])
        target[key].data = stacked.mean(dim=0).clone()
61 |
62 |
def weighted_average(target, sources, weights):
    """Overwrite each ``target[k]`` with the ``weights``-weighted average of ``source[k]``.

    Args:
        target: mapping of name -> tensor/parameter, written in place.
        sources: sequence of mappings with the same keys as ``target``.
        weights: tensor of non-normalised weights, one per source.
    """
    # The normalisation does not depend on the parameter name, so compute it once.
    # Each weight is rescaled by n so that the mean over sources below equals the
    # normalised weighted sum.
    n = len(sources)
    summ = torch.sum(weights)
    modify = [weight / summ * n for weight in weights]
    for name in target:
        target[name].data = torch.mean(torch.stack([m * source[name].data for source, m in zip(sources, modify)]),
                                       dim=0).clone()
70 |
71 |
def computer_norm(source1, source2):
    """Return the global L2 distance between two name -> tensor mappings.

    Computes ``sqrt(sum_k ||source1[k] - source2[k]||^2)`` over the keys of ``source1``.
    Returns a 0-dim tensor; an empty ``source1`` yields ``tensor(0.)`` instead of the
    TypeError that ``torch.pow(0, 0.5)`` would raise.
    """
    diff_norm = 0
    for name in source1:
        # Subtraction already allocates a new tensor, so no clone is needed.
        diff_source = source1[name].data - source2[name].data
        diff_norm += torch.pow(torch.norm(diff_source), 2)

    if not torch.is_tensor(diff_norm):
        return torch.tensor(0.0)
    return torch.pow(diff_norm, 0.5)
80 |
def majority_vote(target, sources, lr):
    """Sign-majority aggregation: ``target[k] = lr * sign(sum_i sign(source_i[k]))``.

    Ties (equal numbers of positive and negative votes) produce 0.
    """
    for name in target:
        votes = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0)
        target[name].data = (lr * votes.sign()).clone()
86 |
87 |
def get_mdl_params(model_list, n_par=None):
    """Flatten the parameters of each model into one row of a float32 matrix.

    Args:
        model_list: list of ``nn.Module`` objects with identical parameter layouts.
        n_par: total number of scalars per model; computed from ``model_list[0]`` when None.

    Returns:
        ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32. Row i holds the
        parameters of model i in ``named_parameters()`` order.
    """
    if n_par is None:
        n_par = sum(param.data.numel() for _, param in model_list[0].named_parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for _, param in mdl.named_parameters():
            flat = param.data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + flat.size] = flat
            idx += flat.size
    # param_mat is a fresh local array, so no defensive copy is needed.
    return param_mat
103 |
104 |
def get_other_params(model_list, n_par=None):
    """Flatten each name -> tensor mapping in ``model_list`` into one row of a float32 matrix.

    Args:
        model_list: list of mappings with identical keys and tensor sizes.
        n_par: total number of scalars per mapping; computed from ``model_list[0]`` when None.

    Returns:
        ``np.ndarray`` of shape ``(len(model_list), n_par)`` and dtype float32, with values laid
        out in each mapping's key iteration order.
    """
    if n_par is None:
        exp_mdl = model_list[0]
        n_par = sum(exp_mdl[name].data.numel() for name in exp_mdl)

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for i, mdl in enumerate(model_list):
        idx = 0
        for name in mdl:
            flat = mdl[name].data.cpu().numpy().reshape(-1)
            param_mat[i, idx:idx + flat.size] = flat
            idx += flat.size
    return param_mat
120 |
121 |
122 |
def personalized_classifier(args, client, clients):
    """Solve for the weights used to mix all clients' classifiers into ``client``'s classifier.

    The weights come from 50 steps of plain gradient descent (lr 0.1) on a penalty-method
    objective. The first ``len(clients) - 1`` weights are free variables and the last one is
    ``1 - sum(others)``, so the returned weights always sum to 1. ReLU penalty terms push
    the weights toward being non-negative.

    Args:
        args: options namespace; only ``args.device`` is read here.
        client: the client being personalised (reads ``P_y`` and ``F_x``).
        clients: all participating clients (reads ``P_y``, ``F_x`` and ``Var_``).

    Returns:
        1-D float32 tensor of length ``len(clients)`` on ``args.device``.
    """
    n_clients = int(len(clients))
    last_idx = n_clients - 1

    # Penalty-method formulation of the constrained quadratic program.
    class Model_quadratic_programming():
        def __init__(self):
            self.coe_client = torch.rand(n_clients - 1, dtype=torch.float32, device=args.device, requires_grad=True)
            self.coe_client_last = 1 - torch.sum(self.coe_client)

        def coefficient(self, k):
            # Weight k: a free variable for every client except the last, which gets the remainder.
            return self.coe_client[k] if k < last_idx else self.coe_client_last

        def final_solve(self):
            target_stats = client.P_y * torch.squeeze(client.F_x)
            A3 = 0
            A3_ = 0
            A4 = 0
            for i in range(n_clients):
                # C1 and A2 depend only on i, so they are computed once per i rather than once per (i, j).
                # NOTE(review): the j-loop therefore multiplies the same A2 by every weight, and A3 is
                # never reset between i iterations. If a pairwise term (mu - mu_i)^T (mu - mu_j) was
                # intended, C1 should involve clients[j]; confirm against the FedPAC objective.
                C1 = target_stats - clients[i].P_y * torch.squeeze(clients[i].F_x)
                A2 = (C1 @ C1.t()).trace()
                for j in range(n_clients):
                    A3 = A3 + self.coefficient(j) * A2
                A3_ = A3_ + self.coefficient(i) * A3
                A4 = A4 + self.coefficient(i) * clients[i].Var_

            # ReLU penalties keep the weights (including the implicit last one) non-negative.
            return torch.sum(A4 + A3_ + F.relu(-self.coe_client)) + F.relu(-self.coe_client_last)

    model_qp = Model_quadratic_programming()

    lr = 0.1
    for _ in range(50):
        loss = model_qp.final_solve()
        loss.backward()
        model_qp.coe_client.data.sub_(lr * model_qp.coe_client.grad)
        model_qp.coe_client.grad.zero_()
        model_qp.coe_client_last = 1 - torch.sum(model_qp.coe_client)

    # torch.rand (rather than torch.empty) keeps the global RNG stream unchanged; every entry is overwritten.
    coe_client_all = torch.rand(n_clients, dtype=torch.float32, device=args.device, requires_grad=False)
    coe_client_all[:-1] = model_qp.coe_client.clone().detach()
    coe_client_all[-1:] = model_qp.coe_client_last.clone().detach()

    return coe_client_all
178 |
179 |
180 |
181 |
class DistributedTrainingDevice:
    """Common base for federated participants (clients and the server).

    Attributes:
        model: the PyTorch module this participant trains or evaluates.
        args: options namespace shared by the experiment.
        loss_func: an ``nn.CrossEntropyLoss`` with default (mean) reduction.
    """

    def __init__(self, model, args):
        self.args = args
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()
194 |
class Client(DistributedTrainingDevice):
    """Federated client for FedPAC-style training.

    Besides the model it keeps per-class feature statistics (``lo_fecture_output``, ``P_y``,
    ``F_x``, ``Var``, ``Var_``). The server reads these for feature alignment and for
    combining classifiers.
    """

    def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0):
        super().__init__(model, args)

        self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name),
                                  batch_size=self.args.local_bs, shuffle=True)

        self.tst_x = tst_x
        self.tst_y = tst_y
        self.n_cls = n_cls
        self.id = id_num
        # Number of mini-batches in one pass over the local training set.
        self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs))
        # Name -> parameter of the local model (shares storage with self.model).
        self.W = {name: value for name, value in self.model.named_parameters()}

        # Per-class running feature centroid; an all-zero row means "class not seen yet".
        self.lo_fecture_output = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)

        # Per-class statistics filled in by local_feature().
        self.P_y = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device)
        self.F_x = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)
        self.Var = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device)
        self.Var_ = torch.zeros(self.n_cls, 1, dtype=torch.float32, device=self.args.device)

        self.state_params_diff = 0.0
        self.train_loss = 0.0
        self.n_par = get_mdl_params([self.model]).shape[1]

    def synchronize_with_server(self, server, w_glob_keys):
        """Pull server weights: only the ``w_glob_keys`` parameters, or a full model copy for 'fedavg'."""
        if self.args.method != 'fedavg':
            for name in self.W:
                if name in w_glob_keys:
                    self.W[name].data = server.W[name].data.clone()
        else:
            self.model = copy.deepcopy(server.model)
            self.W = {name: value for name, value in self.model.named_parameters()}

    def _last_round_glob_keys(self, w_glob_keys):
        """Return the flat list of parameter names frozen during the final 'fedavg' round.

        Tags are tested in order. 'CIFAR100' precedes 'CIFAR10', and 'FMNIST' precedes 'MNIST',
        because each later tag is a substring of the earlier one; the previous order made the
        FMNIST branch unreachable. Unknown datasets return the names passed in, flattened one level.
        """
        for tag, groups in (('CIFAR100', (0, 1, 2, 3)),
                            ('CIFAR10', (0, 1, 2, 3)),
                            ('FMNIST', (0, 1, 2, 3)),
                            ('MNIST', (0, 1))):
            if tag in self.args.dataset:
                return list(itertools.chain.from_iterable(self.model.weight_keys[i] for i in groups))
        # Only nested groups are flattened: chaining a flat list of names would split
        # every name into single characters.
        flat = []
        for key in w_glob_keys:
            if isinstance(key, str):
                flat.append(key)
            else:
                flat.extend(key)
        return flat

    def train_cnn(self, w_glob_keys, server, last):
        """Run local SGD and return the mean (over epochs) of the per-epoch average loss.

        For 'fedPAC' (not in the last round), the first ``local_eps - local_rep_ep`` epochs train
        the non-shared (head) parameters. The remaining epochs train the shared parameters, with
        an extra ``beta_PAC``-weighted penalty pulling features toward the server's per-class
        centroids. In the last round, the parameters in ``w_glob_keys`` are frozen.
        """
        self.model.train()

        local_eps = self.args.local_ep
        if last:
            if self.args.method == 'fedavg':
                local_eps = self.args.last_local_ep
                w_glob_keys = self._last_round_glob_keys(w_glob_keys)
            else:
                local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep)
        else:
            # All other methods update all parameters simultaneously.
            for param in self.model.parameters():
                param.requires_grad = True

        head_eps = local_eps - self.args.local_rep_ep
        epoch_loss = []
        for epoch in range(local_eps):
            align_to_global = False
            if (epoch < head_eps and self.args.method == 'fedPAC') or last:
                # Head phase (or final round): train only the non-shared parameters.
                for name, param in self.model.named_parameters():
                    param.requires_grad = name not in w_glob_keys
            elif self.args.method == 'fedPAC':
                # Representation phase: reaching here implies epoch >= head_eps and not last.
                align_to_global = True
                for name, param in self.model.named_parameters():
                    param.requires_grad = name in w_glob_keys

            trn_gen_iter = iter(self.trn_gen)
            batch_loss = []

            # Rebuilt every epoch (which also resets momentum); kept as in the original schedule.
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
                                        weight_decay=self.args.weigh_delay)
            # gamma=1: the schedule never changes the learning rate.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)

            for _ in range(self.local_epoch):
                images, labels = next(trn_gen_iter)
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                labels = labels.reshape(-1).long()

                optimizer.zero_grad()
                log_probs, fecture_output = self.model(images)

                loss_f_i = self.loss_func(log_probs, labels)

                if align_to_global:
                    # Server centroid for each sample's label. server.fecture_output is
                    # (n_cls, 1, dimZ_PAC), so this gathers a (batch, dimZ_PAC) tensor in one index
                    # instead of concatenating row by row.
                    dir_g_fecture_output = server.fecture_output[labels, 0].clone().detach()
                    R_i = torch.norm(fecture_output - dir_g_fecture_output) / len(labels)
                else:
                    R_i = 0

                loss = loss_f_i + self.args.beta_PAC * R_i

                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm)
                optimizer.step()
                batch_loss.append(loss.item())

            scheduler.step()
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return sum(epoch_loss) / len(epoch_loss)

    def local_feature(self):
        """Collect per-class feature statistics over one pass of the training loader.

        Updates, per class:
        - ``lo_fecture_output``: a running (halving) average of features.
        - ``Var``: a running average of the feature's squared norm.
        - ``P_y``: class frequencies, normalised after the pass.

        Then sets ``F_x`` and ``Var_``. Runs under ``torch.no_grad`` because every feature is
        detached, so no autograd graph is needed.
        """
        zero_row = torch.zeros(1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)
        trn_gen_iter = iter(self.trn_gen)

        with torch.no_grad():
            for _ in range(self.local_epoch):
                images, labels = next(trn_gen_iter)
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                labels = labels.reshape(-1).long()

                _, fecture_output = self.model(images)
                fecture_output = fecture_output.detach()

                for idx in range(len(labels)):
                    label = labels[idx]
                    feat = fecture_output[idx]
                    if self.lo_fecture_output[label].equal(zero_row):
                        # First sample of this class: initialise the running statistics.
                        self.lo_fecture_output[label] = feat.clone()
                        self.Var[label] = feat.t() @ feat
                    else:
                        # NOTE(review): this is an exponential (halving) average, not a true mean.
                        self.lo_fecture_output[label] = (feat + self.lo_fecture_output[label]) / 2
                        self.Var[label] = (self.Var[label] + feat @ feat) / 2

                    self.P_y[label] = self.P_y[label] + 1

        self.P_y = self.P_y / torch.sum(self.P_y)
        self.F_x = self.lo_fecture_output.clone().detach()
        # NOTE(review): self.Var has shape (n_cls, 1), so .trace() returns only Var[0, 0];
        # confirm whether a per-class or summed variance was intended.
        self.Var_ = torch.sum(self.P_y * self.Var.trace() - (self.P_y * torch.squeeze(self.F_x)) ** 2)

    def compute_weight_update(self, w_glob_keys, server, last=False):
        """Train locally and store the average training loss in ``self.train_loss``."""
        self.model.train()
        self.train_loss = self.train_cnn(w_glob_keys, server, last)

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        """Return ``(accuracy_percent, mean_cross_entropy)`` of the local model on one split."""
        self.model.eval()
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        summed_ce = nn.CrossEntropyLoss(reduction='sum')

        test_loss = 0
        acc_overall = 0
        tst_gen_iter = iter(tst_gen)
        for _ in range(int(np.ceil(n_tst / self.args.bs))):
            data, target = next(tst_gen_iter)
            data, target = data.to(self.args.device), target.to(self.args.device)
            log_probs, _ = self.model(data)
            # Summed loss so dividing by n_tst below gives a per-sample average.
            test_loss += summed_ce(log_probs, target.reshape(-1).long()).item()
            predicted = np.argmax(log_probs.cpu().detach().numpy(), axis=1).reshape(-1)
            expected = target.cpu().numpy().reshape(-1).astype(np.int32)
            acc_overall += np.sum(predicted == expected)

        test_loss /= n_tst
        accuracy = 100.00 * acc_overall / n_tst
        return accuracy, test_loss
395 |
396 |
397 | '''
398 | ------------------------------------------------------------------------------------------------------------------------
399 |
400 | ------------------------------------------------------------------------------------------------------------------------
401 | '''
402 |
403 |
class Server(DistributedTrainingDevice):
    """Server side of the federated run.

    Holds the global parameters (``self.W``, aliases of ``self.model``'s
    named parameters, so writes to ``W[name].data`` update the model) and one
    global feature centroid per class (``self.fecture_output``).
    """

    def __init__(self, model, args, n_cls):
        super().__init__(model, args)

        # Parameters (aliases of the model's nn.Parameters, not copies).
        self.W = {name: value for name, value in self.model.named_parameters()}
        self.local_epoch = 0
        self.n_cls = n_cls

        # One (1, dimZ_PAC) centroid slot per class.
        self.fecture_output = torch.zeros(self.n_cls, 1, self.args.dimZ_PAC, dtype=torch.float32,
                                          device=self.args.device)

    def aggregate_weight_updates(self, clients, iter, aggregation="mean"):
        """Overwrite the server parameters with an aggregate of the clients' parameters.

        Args:
            clients: clients exposing ``W`` (name -> tensor) and ``local_epoch``.
            iter: unused here; kept for interface compatibility.
            aggregation: only ``"mean"`` is implemented; any other value leaves
                ``self.W`` unchanged.
        """
        # Warning: K (local steps) differs across clients for unbalanced data;
        # only the first client's value is recorded.
        self.local_epoch = clients[0].local_epoch
        # dW = aggregate(dW_i, i=1,..,n)
        if aggregation == "mean":
            average(target=self.W, sources=[client.W for client in clients])

    def global_feature_centroids(self, clients):
        """Average the clients' local per-class centroids into ``self.fecture_output``.

        A client whose centroid for a class is all zeros is treated as not
        having that class and is skipped. A class that no client has keeps its
        previous global centroid.
        """
        # Built once instead of once per (class, client) pair.
        empty = torch.zeros(1, self.args.dimZ_PAC, dtype=torch.float32, device=self.args.device)
        for cls in range(self.n_cls):
            present = [
                client.lo_fecture_output[cls].clone().detach()
                for client in clients
                if not client.lo_fecture_output[cls].equal(empty)
            ]
            if present:
                # Element-wise mean across clients. The previous torch.mean over
                # every element collapsed the centroid to one repeated scalar.
                self.fecture_output[cls] = torch.mean(torch.cat(present, dim=0), dim=0, keepdim=True)

    def Get_classifier(self, clients, w_glob_keys):
        """Give every client a combination of all clients' personalised layers.

        For each client, ``personalized_classifier`` returns one coefficient
        ``m_j`` per client j. Every parameter not listed in ``w_glob_keys`` is
        replaced by ``mean_j(m_j * W_j)``, using the classifiers as they were
        before this call.
        """
        local_names = [name for name in self.W if name not in w_glob_keys]
        # Snapshot first: the loop overwrites client.W in place, and without
        # this later clients would mix in already-updated classifiers.
        snapshot = {name: [source.W[name].data.clone() for source in clients] for name in local_names}
        for client in clients:
            # NOTE(review): personalized_classifier (defined elsewhere) still sees the live client state.
            modify = personalized_classifier(args=self.args, client=client, clients=clients)
            for name in local_names:
                client.W[name].data = torch.mean(
                    torch.stack([m * w for w, m in zip(snapshot[name], modify)]),
                    dim=0).clone()

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        """Return ``(accuracy %, mean cross-entropy)`` of ``self.model`` on a labelled array split."""
        self.model.eval()
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name),
                             batch_size=self.args.bs, shuffle=False)
        # Summed per batch so the final division yields a per-sample mean.
        loss_fn = nn.CrossEntropyLoss(reduction='sum')

        test_loss = 0.0
        acc_overall = 0
        # Iterating the loader yields the same ceil(n_tst / bs) batches as the old manual loop.
        for data, target in tst_gen:
            data = data.to(self.args.device)
            target = target.to(self.args.device).reshape(-1).long()
            log_probs = self.model(data)
            test_loss += loss_fn(log_probs, target).item()
            acc_overall += (log_probs.argmax(dim=1) == target).sum().item()

        test_loss /= n_tst
        accuracy = 100.00 * acc_overall / n_tst
        return accuracy, test_loss
486 |
487 |
488 |
489 |
--------------------------------------------------------------------------------
/code_FedCR/utils/utils_dataset.py:
--------------------------------------------------------------------------------
1 | from scipy import io
2 | import numpy as np
3 | import torchvision
4 | from torchvision import transforms
5 | import torch
6 | from torch.utils import data
7 | import os
8 |
9 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # NOTE(review): process-wide OpenMP workaround applied at import time; confirm it is still required
10 | import ssl
11 |
12 | ssl._create_default_https_context = ssl._create_unverified_context  # SECURITY NOTE(review): turns off HTTPS certificate verification for every stdlib HTTPS client using the default context (presumably so the torchvision downloads succeed); prefer fixing the CA bundle and removing this
13 |
14 |
15 | class DatasetObject:  # Builds, caches under Data/<name>/ and reloads a per-client federated split of MNIST, FMNIST, CIFAR10, CIFAR100 or EMNIST
16 | def __init__(self, dataset, n_client, seed, rule, class_main, data_path, frac_data=0.7, dir_alpha=0):  # rule: 'Dirichlet' or 'noniid'; class_main: classes per client ('noniid'); dir_alpha: Dirichlet concentration; frac_data: train fraction of each client's samples
17 | self.dataset = dataset
18 | self.n_client = n_client
19 | self.seed = seed
20 | self.rule = rule
21 | self.frac_data = frac_data
22 | self.dir_alpha = dir_alpha
23 | self.class_main = class_main
24 |
25 | self.name = "Data%s_nclient%d_seed%d_rule%s_alpha%s_class_main%d_frac_data%s" % (
26 | self.dataset, self.n_client, self.seed, self.rule, self.dir_alpha, self.class_main, self.frac_data)  # cache key: every constructor argument except data_path is encoded in the folder name
27 | self.data_path = data_path  # NOTE(review): stored but never read; every path in set_data is hard-coded under 'Data/'
28 | self.set_data()
29 |
30 | def set_data(self):  # Generate the split (or load the cached one) and set clnt_x/clnt_y, tst_x/tst_y, channels, width, height, n_cls
31 | # Prepare data if not ready
32 | if not os.path.exists('Data/%s' % (self.name)):  # a split is generated only when no cached folder with this name exists
33 | # Get Raw data
34 | if self.dataset == 'MNIST':
35 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
36 | trnset = torchvision.datasets.MNIST(root='Data/', train=True, download=True, transform=transform)
37 | tstset = torchvision.datasets.MNIST(root='Data/', train=False, download=True, transform=transform)
38 |
39 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=60000, shuffle=False, num_workers=1)
40 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1)
41 | self.channels = 1;
42 | self.width = 28;
43 | self.height = 28;
44 | self.n_cls = 10;
45 |
46 | if self.dataset == 'FMNIST':
47 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
48 | trnset = torchvision.datasets.FashionMNIST(root='Data/', train=True, download=True, transform=transform)
49 | tstset = torchvision.datasets.FashionMNIST(root='Data/', train=False, download=True,
50 | transform=transform)
51 |
52 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=60000, shuffle=False, num_workers=1)
53 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1)
54 | self.channels = 1;
55 | self.width = 28;
56 | self.height = 28;
57 | self.n_cls = 10;
58 |
59 | if self.dataset == 'CIFAR10':
60 | transform = transforms.Compose([transforms.ToTensor()])  # no Normalize here (unlike MNIST/FMNIST); the Dataset class applies CIFAR normalisation when samples are drawn
61 |
62 | trnset = torchvision.datasets.CIFAR10(root='Data/', train=True, download=True, transform=transform)
63 | tstset = torchvision.datasets.CIFAR10(root='Data/', train=False, download=True, transform=transform)
64 |
65 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=1)
66 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1)
67 | self.channels = 3;
68 | self.width = 32;
69 | self.height = 32;
70 | self.n_cls = 10;
71 |
72 | if self.dataset == 'CIFAR100':
73 | transform = transforms.Compose([transforms.ToTensor()])  # no Normalize here either; see the CIFAR10 branch
74 |
75 | trnset = torchvision.datasets.CIFAR100(root='Data/', train=True, download=True, transform=transform)
76 | tstset = torchvision.datasets.CIFAR100(root='Data/', train=False, download=True, transform=transform)
77 |
78 | trn_load = torch.utils.data.DataLoader(trnset, batch_size=50000, shuffle=False, num_workers=1)
79 | tst_load = torch.utils.data.DataLoader(tstset, batch_size=10000, shuffle=False, num_workers=1)
80 | self.channels = 3;
81 | self.width = 32;
82 | self.height = 32;
83 | self.n_cls = 100;
84 |
85 | if self.dataset == 'MNIST' or self.dataset == 'FMNIST' or self.dataset == 'CIFAR10' or self.dataset == 'CIFAR100':  # the loaders' batch sizes are meant to equal the split sizes, so a single next() returns the whole split
86 | trn_itr = trn_load.__iter__()
87 | tst_itr = tst_load.__iter__()
88 | # labels are of shape (n_data,)
89 | trn_x, trn_y = trn_itr.__next__()
90 | tst_x, tst_y = tst_itr.__next__()
91 |
92 | trn_x = trn_x.numpy();
93 | trn_y = trn_y.numpy().reshape(-1, 1)
94 | tst_x = tst_x.numpy();
95 | tst_y = tst_y.numpy().reshape(-1, 1)
96 |
97 | concat_datasets_x = np.concatenate((trn_x, tst_x), axis=0)  # the official train and test splits are pooled, then re-partitioned across clients below
98 | concat_datasets_y = np.concatenate((trn_y, tst_y), axis=0)
99 |
100 | self.trn_x = trn_x
101 | self.trn_y = trn_y
102 | self.tst_x = tst_x
103 | self.tst_y = tst_y
104 |
105 | if self.dataset == 'EMNIST':  # EMNIST letters from a local .mat file; only the first 10 letter classes are kept
106 | emnist = io.loadmat("Data/Raw/matlab/emnist-letters.mat")
107 | # load training dataset
108 | x_train = emnist["dataset"][0][0][0][0][0][0]
109 | x_train = x_train.astype(np.float32)
110 |
111 | # load training labels
112 | y_train = emnist["dataset"][0][0][0][0][0][1] - 1 # make first class 0
113 |
114 | # take first 10 classes of letters
115 | trn_idx = np.where(y_train < 10)[0]
116 |
117 | y_train = y_train[trn_idx]
118 | x_train = x_train[trn_idx]
119 |
120 | mean_x = np.mean(x_train)  # normalisation statistics are computed on the (filtered) training images only
121 | std_x = np.std(x_train)
122 |
123 | # load test dataset
124 | x_test = emnist["dataset"][0][0][1][0][0][0]
125 | x_test = x_test.astype(np.float32)
126 |
127 | # load test labels
128 | y_test = emnist["dataset"][0][0][1][0][0][1] - 1 # make first class 0
129 |
130 | tst_idx = np.where(y_test < 10)[0]
131 |
132 | y_test = y_test[tst_idx]
133 | x_test = x_test[tst_idx]
134 |
135 | x_train = x_train.reshape((-1, 1, 28, 28))
136 | x_test = x_test.reshape((-1, 1, 28, 28))
137 |
138 | # normalise train and test features
139 |
140 | trn_x = (x_train - mean_x) / std_x
141 | trn_y = y_train
142 |
143 | tst_x = (x_test - mean_x) / std_x
144 | tst_y = y_test
145 |
146 | self.channels = 1; self.width = 28; self.height = 28; self.n_cls = 10;
147 | concat_datasets_x = np.concatenate((trn_x, tst_x), axis=0)
148 | concat_datasets_y = np.concatenate((trn_y, tst_y), axis=0)
149 |
150 | self.trn_x = trn_x
151 | self.trn_y = trn_y
152 | self.tst_x = tst_x
153 | self.tst_y = tst_y
154 |
155 |
156 | # Shuffle Data
157 | np.random.seed(self.seed)  # NOTE(review): the generated split depends on this seed AND on the exact order of every np.random call below
158 | rand_perm = np.random.permutation(len(concat_datasets_y))  # NOTE(review): concat_datasets_* exist only for the five supported dataset names; any other name fails here with NameError
159 | concat_datasets_x = concat_datasets_x[rand_perm]
160 | concat_datasets_y = concat_datasets_y[rand_perm]
161 |
162 | assert len(concat_datasets_y) % self.n_client == 0  # NOTE(review): assert is skipped under python -O; consider raising ValueError instead
163 | n_data_per_clnt = int((len(concat_datasets_y)) / self.n_client)
164 | clnt_data_list = np.ones(self.n_client).astype(int) * n_data_per_clnt
165 |
166 | n_data_per_clnt_train = int(n_data_per_clnt * self.frac_data)  # every client gets the same train size; the rest of its samples form its test set
167 | n_data_per_clnt_tst = n_data_per_clnt - n_data_per_clnt_train
168 | clnt_data_list_train = np.ones(self.n_client).astype(int) * n_data_per_clnt_train
169 | clnt_data_list_tst = np.ones(self.n_client).astype(int) * n_data_per_clnt_tst
170 | ###
171 |
172 | cls_per_client = self.class_main
173 | n_cls = self.n_cls
174 | n_client = self.n_client
175 |
176 | # Distribute training datapoints
177 | idx_list = [np.where(concat_datasets_y == i)[0] for i in range(self.n_cls)]  # positions of each class in the pooled, shuffled arrays
178 | idx_count_list = [0 for i in range(self.n_cls)]  # NOTE(review): never used afterwards
179 | cls_amount = np.asarray([len(idx_list[i]) for i in range(self.n_cls)])  # samples of each class still unassigned; consumed from the end of idx_list[cls]
180 | n_data = np.sum(cls_amount)
181 | total_clnt_data_list = np.asarray([0 for i in range(n_client)])
182 | clnt_cls_idx = [[[] for kk in range(n_cls)] for jj in range(n_client)] # Store the indeces of data points
183 |
184 |
185 | if self.rule == 'Dirichlet':  # each client draws class priors from Dir(dir_alpha), then samples labels one at a time from them
186 | cls_priors = np.random.dirichlet(alpha=[self.dir_alpha] * self.n_cls, size=self.n_client)
187 | prior_cumsum = np.cumsum(cls_priors, axis=1)
188 |
189 | concat_clnt_x = np.asarray(
190 | [np.zeros((clnt_data_list[clnt__], self.channels, self.height, self.width)).astype(np.float32) for
191 | clnt__ in range(self.n_client)])
192 | concat_clnt_y = np.asarray(
193 | [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
194 |
195 | while (np.sum(clnt_data_list) != 0):  # clnt_data_list counts down each client's free slots
196 | curr_clnt = np.random.randint(self.n_client)
197 | # If current node is full resample a client
198 | if clnt_data_list[curr_clnt] <= 0:
199 | continue
200 | clnt_data_list[curr_clnt] -= 1
201 | curr_prior = prior_cumsum[curr_clnt]
202 | while True:  # NOTE(review): may spin forever if every class with prior mass for this client is exhausted
203 | cls_label = np.argmax(np.random.uniform() <= curr_prior)  # inverse-CDF draw from the client's cumulative class prior
204 | # Redraw class label if trn_y is out of that class
205 | if cls_amount[cls_label] <= 0:
206 | continue
207 | cls_amount[cls_label] -= 1
208 |
209 | concat_clnt_x[curr_clnt][clnt_data_list[curr_clnt]] = concat_datasets_x[idx_list[cls_label][cls_amount[cls_label]]]
210 | concat_clnt_y[curr_clnt][clnt_data_list[curr_clnt]] = concat_datasets_y[idx_list[cls_label][cls_amount[cls_label]]]
211 |
212 | break
213 |
214 | concat_clnt_x = np.asarray(concat_clnt_x)
215 | concat_clnt_y = np.asarray(concat_clnt_y)
216 |
217 | clnt_x = np.asarray(
218 | [np.zeros((clnt_data_list_train[clnt__], self.channels, self.height, self.width)).astype(np.float32)
219 | for
220 | clnt__ in range(self.n_client)])
221 | clnt_y = np.asarray(
222 | [np.zeros((clnt_data_list_train[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
223 | tst_x = np.asarray(
224 | [np.zeros((clnt_data_list_tst[clnt__], self.channels, self.height, self.width)).astype(np.float32)
225 | for
226 | clnt__ in range(self.n_client)])
227 | tst_y = np.asarray(
228 | [np.zeros((clnt_data_list_tst[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
229 |
230 | for jj in range(n_client):  # shuffle each client's samples, then split them into train and test parts
231 | rand_perm = np.random.permutation(len(concat_clnt_y[jj]))
232 | concat_clnt_x[jj] = concat_clnt_x[jj][rand_perm]
233 | concat_clnt_y[jj] = concat_clnt_y[jj][rand_perm]
234 |
235 | clnt_x[jj] = concat_clnt_x[jj][:n_data_per_clnt_train, :, :, :]
236 | tst_x[jj] = concat_clnt_x[jj][n_data_per_clnt_train:, :, :, :]
237 |
238 | clnt_y[jj] = concat_clnt_y[jj][:n_data_per_clnt_train, :]
239 | tst_y[jj] = concat_clnt_y[jj][n_data_per_clnt_train:, :]
240 |
241 |
242 | cls_means = np.zeros((self.n_client, self.n_cls))  # empirical train-label frequencies, compared against the sampled priors
243 | for clnt in range(self.n_client):
244 | for cls in range(self.n_cls):
245 | cls_means[clnt,cls] = np.mean(clnt_y[clnt]==cls)
246 | prior_real_diff = np.abs(cls_means-cls_priors)
247 | print('--- Max deviation from prior: %.4f' %np.max(prior_real_diff))
248 | print('--- Min deviation from prior: %.4f' %np.min(prior_real_diff))
249 |
250 | if self.rule == 'noniid':  # pathological split: client c draws from classes (c + j) % n_cls for j < class_main
251 | while np.sum(total_clnt_data_list) != n_data:
252 | # Still there are data to distibute
253 | # Get a random client that among the ones that has the least # of data with respect to totat data it is supposed to have
254 | min_amount = np.min(total_clnt_data_list - clnt_data_list)  # pick (randomly among ties) the client furthest below its quota
255 | min_idx_list = np.where(total_clnt_data_list - clnt_data_list == min_amount)[0]
256 | np.random.shuffle(min_idx_list)
257 | cur_clnt = min_idx_list[0]
258 | print(
259 | 'Current client %d, total remaining amount %d' % (cur_clnt, n_data - np.sum(total_clnt_data_list)))
260 |
261 | # Get its class list
262 | cur_cls_list = np.asarray([(cur_clnt + jj) % n_cls for jj in range(cls_per_client)])
263 | # Get the class that has minumum amount of data on the client
264 | cls_amounts = np.asarray([len(clnt_cls_idx[cur_clnt][jj]) for jj in range(n_cls)])
265 | min_to_max = cur_cls_list[np.argsort(cls_amounts[cur_cls_list])]
266 | cur_idx = 0
267 | while cur_idx != len(min_to_max) and cls_amount[min_to_max[cur_idx]] == 0:
268 | cur_idx += 1
269 | if cur_idx == len(min_to_max):  # all of this client's classes are exhausted: take one sample of a shared class from another client
270 | # This client is not full, it needs data but there is no class data left
271 | # Pick a random client and assign its data to this client
272 | while True:
273 | rand_clnt = np.random.randint(n_client)
274 | print('Random client %d' % rand_clnt)
275 | if rand_clnt == cur_clnt: # Pick a different client
276 | continue
277 | rand_clnt_cls = np.asarray([(rand_clnt + jj) % n_cls for jj in range(cls_per_client)])
278 | # See if random client has an intersection class with the current client
279 | cur_list = np.asarray([(cur_clnt + jj) % n_cls for jj in range(cls_per_client)])
280 | np.random.shuffle(cur_list)
281 | cls_idx = 0
282 | is_found = False
283 | while cls_idx != cls_per_client:
284 | if cur_list[cls_idx] in rand_clnt_cls and len(
285 | clnt_cls_idx[rand_clnt][cur_list[cls_idx]]) > 1:
286 | is_found = True
287 | break
288 | cls_idx += 1
289 | if not is_found: # No class intersection, choose another client
290 | continue
291 | found_cls = cur_list[cls_idx]
292 | # Assign this class instance to curr client
293 | total_clnt_data_list[cur_clnt] += 1
294 | total_clnt_data_list[rand_clnt] -= 1
295 | transfer_idx = clnt_cls_idx[rand_clnt][found_cls][-1]
296 | del clnt_cls_idx[rand_clnt][found_cls][-1]
297 | clnt_cls_idx[cur_clnt][found_cls].append(transfer_idx)
298 | # print('Class %d is transferred from %d to %d' %(found_cls, rand_clnt, cur_clnt))
299 | break
300 | else:  # give the client one sample of its least-represented class that still has data
301 | cur_cls = min_to_max[cur_idx]
302 | # Assign one data point from this class to the task
303 | total_clnt_data_list[cur_clnt] += 1
304 | cls_amount[cur_cls] -= 1
305 | clnt_cls_idx[cur_clnt][cur_cls].append(idx_list[cur_cls][cls_amount[cur_cls]])
306 | # print('Chosen client: %d, chosen class: %d' %(cur_clnt, cur_cls))
307 |
308 | for i in range(n_cls):
309 | assert 0 == cls_amount[i], 'Missing datapoints'
310 | assert n_data == np.sum(total_clnt_data_list), 'Missing datapoints'
311 |
312 | concat_clnt_x = np.asarray(
313 | [np.zeros((clnt_data_list[clnt__], self.channels, self.height, self.width)).astype(np.float32) for
314 | clnt__ in range(self.n_client)])
315 | concat_clnt_y = np.asarray(
316 | [np.zeros((clnt_data_list[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
317 |
318 | clnt_x = np.asarray(
319 | [np.zeros((clnt_data_list_train[clnt__], self.channels, self.height, self.width)).astype(np.float32) for
320 | clnt__ in range(self.n_client)])
321 | clnt_y = np.asarray(
322 | [np.zeros((clnt_data_list_train[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
323 | tst_x = np.asarray(
324 | [np.zeros((clnt_data_list_tst[clnt__], self.channels, self.height, self.width)).astype(np.float32) for
325 | clnt__ in range(self.n_client)])
326 | tst_y = np.asarray(
327 | [np.zeros((clnt_data_list_tst[clnt__], 1)).astype(np.int64) for clnt__ in range(self.n_client)])
328 |
329 | for jj in range(n_client):  # gather each client's assigned samples by index
330 | concat_clnt_x[jj] = concat_datasets_x[np.concatenate(clnt_cls_idx[jj]).astype(np.int32)]
331 | concat_clnt_y[jj] = concat_datasets_y[np.concatenate(clnt_cls_idx[jj]).astype(np.int32)]
332 |
333 | for jj in range(n_client):
334 | rand_perm = np.random.permutation(len(concat_clnt_y[jj]))
335 | concat_clnt_x[jj] = concat_clnt_x[jj][rand_perm]
336 | concat_clnt_y[jj] = concat_clnt_y[jj][rand_perm]
337 |
338 | clnt_x[jj] = concat_clnt_x[jj][:n_data_per_clnt_train, :, :, :]
339 | tst_x[jj] = concat_clnt_x[jj][n_data_per_clnt_train:, :, :, :]
340 |
341 | clnt_y[jj] = concat_clnt_y[jj][:n_data_per_clnt_train, :]
342 | tst_y[jj] = concat_clnt_y[jj][n_data_per_clnt_train:, :]
343 |
344 | self.clnt_x = clnt_x;  # NOTE(review): clnt_x/clnt_y/tst_x/tst_y are only defined when rule is 'Dirichlet' or 'noniid'; any other rule raises NameError here
345 | self.clnt_y = clnt_y
346 | self.tst_x = tst_x;
347 | self.tst_y = tst_y
348 |
349 | # Save data
350 | os.mkdir('Data/%s' % (self.name))  # cache the split so later runs with the same arguments skip generation
351 |
352 | np.save('Data/%s/clnt_x.npy' % (self.name), clnt_x)
353 | np.save('Data/%s/clnt_y.npy' % (self.name), clnt_y)
354 |
355 | np.save('Data/%s/tst_x.npy' % (self.name), tst_x)
356 | np.save('Data/%s/tst_y.npy' % (self.name), tst_y)
357 |
358 | if not os.path.exists('Model'):
359 | os.mkdir('Model')
360 |
361 | else:  # cached split found: load it (this else presumably pairs with the existence check at the top; indentation is lost in this listing)
362 | print("Data is already downloaded")
363 |
364 | self.clnt_x = np.load('Data/%s/clnt_x.npy' % (self.name))
365 | self.clnt_y = np.load('Data/%s/clnt_y.npy' % (self.name))
366 | self.n_client = len(self.clnt_x)
367 |
368 | self.tst_x = np.load('Data/%s/tst_x.npy' % (self.name))
369 | self.tst_y = np.load('Data/%s/tst_y.npy' % (self.name))
370 |
371 | if self.dataset == 'MNIST':  # image geometry and class count are not part of the cache, so they are set from the dataset name
372 | self.channels = 1;self.width = 28;self.height = 28;self.n_cls = 10;
373 | if self.dataset == 'FMNIST':
374 | self.channels = 1; self.width = 28; self.height = 28;self.n_cls = 10;
375 | if self.dataset == 'CIFAR10':
376 | self.channels = 3; self.width = 32;self.height = 32;self.n_cls = 10;
377 | if self.dataset == 'CIFAR100':
378 | self.channels = 3;self.width = 32;self.height = 32;self.n_cls = 100;
379 | if self.dataset == 'EMNIST':
380 | self.channels = 1; self.width = 28; self.height = 28; self.n_cls = 10;
381 |
382 | print('Class frequencies:')  # report per-client label frequencies and sample counts, first for the train split, then for the test split
383 |
384 | # train
385 | count = 0
386 | for clnt in range(self.n_client):
387 | print("Client %3d: " % clnt +
388 | ', '.join(["%.3f" % np.mean(self.clnt_y[clnt] == cls) for cls in range(self.n_cls)]) +
389 | ', Amount:%d' % self.clnt_y[clnt].shape[0])
390 | count += self.clnt_y[clnt].shape[0]
391 |
392 | print('Total Amount:%d' % count)
393 | print('-----------------------------------------------------------')
394 |
395 | # test
396 | count = 0
397 | for clnt in range(self.n_client):
398 | print("Client %3d: " % clnt +
399 | ', '.join(["%.3f" % np.mean(self.tst_y[clnt] == cls) for cls in range(self.n_cls)]) +
400 | ', Amount:%d' % self.tst_y[clnt].shape[0])
401 | count += self.tst_y[clnt].shape[0]
402 |
403 | print('Total Amount:%d' % count)
404 | print('-----------------------------------------------------------')
405 |
406 |
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory arrays for the supported image datasets.

    Args:
        data_x: image array for the whole split (presumably (N, C, H, W), the
            layout DatasetObject produces — confirm for other callers).
        data_y: label array aligned with ``data_x``, or any bool (the default
            ``True``) meaning "unlabelled"; ``__getitem__`` then returns only
            the image.
        train: CIFAR10/CIFAR100 only. If true, apply random crop and
            horizontal flip before normalising; otherwise only normalise.
        dataset_name: one of 'MNIST', 'EMNIST', 'FMNIST', 'CIFAR10', 'CIFAR100'.

    Raises:
        ValueError: if ``dataset_name`` is not supported. Previously such an
            object was created without ``X_data`` and failed later.
    """

    # Grayscale datasets: samples are returned as stored, with no transform.
    _GRAY_NAMES = ('MNIST', 'EMNIST', 'FMNIST')
    # Colour datasets: per-channel (mean, std) used by the normalisation transforms.
    _CIFAR_STATS = {
        'CIFAR10': ((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        'CIFAR100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }

    def __init__(self, data_x, data_y=True, train=False, dataset_name=''):
        self.name = dataset_name
        # A bool here is the "no labels" sentinel.
        self.y_data = data_y

        if self.name in self._GRAY_NAMES:
            self.X_data = torch.tensor(data_x).float()
            if not isinstance(data_y, bool):
                self.y_data = torch.tensor(data_y).float()
        elif self.name in self._CIFAR_STATS:
            mean, std = self._CIFAR_STATS[self.name]
            self.train = train
            self.transform = transforms.Compose([transforms.ToTensor()])
            self.X_data = torch.tensor(data_x).float()
            if not isinstance(data_y, bool):
                self.y_data = data_y.astype('float32')
            self.augment_transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=list(mean), std=list(std))])
            self.noaugmt_transform = transforms.Compose(
                [transforms.Normalize(mean=list(mean), std=list(std))])
        else:
            raise ValueError('Unsupported dataset_name: %r' % (dataset_name,))

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        """Return the (possibly transformed) image, plus its label when labels exist."""
        sample = self.X_data[idx]
        if self.name in self._CIFAR_STATS:
            transform = self.augment_transform if self.train else self.noaugmt_transform
            sample = transform(sample)
        if isinstance(self.y_data, bool):
            return sample
        return sample, self.y_data[idx]
487 |
if __name__ == '__main__':
    # Build (or load from the Data/ cache) a 100-client pathological non-IID EMNIST split.
    data_path = 'Folder/'  # The folder to save Data & Model (DatasetObject currently ignores it)
    n_client = 100
    data_obj = DatasetObject(dataset='EMNIST', n_client=n_client, seed=23, rule='noniid', class_main=5,
                             data_path=data_path, frac_data=0.7, dir_alpha=1)
    # Per-client training arrays (previously misnamed tst_x / tst_y).
    trn_x = data_obj.clnt_x
    trn_y = data_obj.clnt_y
--------------------------------------------------------------------------------
/code_FedCR/models/distributed_training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import math
4 | from torch import nn, autograd
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from utils.utils_dataset import Dataset
8 | from models.Nets_VIB import KL_between_normals
9 | import itertools
10 | import torch.nn.functional as F
11 |
12 | max_norm = 10  # NOTE(review): not referenced in the visible part of this module; it may be used further down — confirm before removing
13 |
def add(target, source):
    """In place: ``target[k] += source[k]`` for every key ``k`` of *target*."""
    for key, tensor in target.items():
        tensor.data.add_(source[key].data.clone())
17 |
18 |
def add_mome(target, source, beta_):
    """Momentum-style update in place: ``target[k] <- beta_ * target[k] + source[k]``."""
    for key in target:
        decayed = beta_ * target[key].data
        target[key].data = decayed + source[key].data.clone()
22 |
23 |
def add_mome2(target, source1, source2, beta_1, beta_2):
    """Set ``target[k] = beta_1 * source1[k] + beta_2 * source2[k]`` for every key of *target*."""
    for key in target:
        first = beta_1 * source1[key].data.clone()
        second = beta_2 * source2[key].data.clone()
        target[key].data = first + second
27 |
28 |
def add_mome3(target, source1, source2, source3, beta_1, beta_2, beta_3):
    """Set ``target[k]`` to the linear combination ``sum_i beta_i * source_i[k]`` of three sources."""
    for key in target:
        combined = (beta_1 * source1[key].data.clone()
                    + beta_2 * source2[key].data.clone()
                    + beta_3 * source3[key].data.clone())
        target[key].data = combined
32 |
def add_2(target, source1, source2, beta_1, beta_2):
    """In place: ``target[k] += beta_1 * source1[k] + beta_2 * source2[k]``."""
    for key, tensor in target.items():
        increment = beta_1 * source1[key].data.clone() + beta_2 * source2[key].data.clone()
        tensor.data.add_(increment)
36 |
def scale(target, scaling):
    """Multiply every tensor in *target* by *scaling* (the dict's tensors get new data)."""
    for key in target:
        current = target[key].data.clone()
        target[key].data = scaling * current
40 |
41 |
def scale_ts(target, source, scaling):
    """Set ``target[k] = scaling * source[k]`` for every key of *target*."""
    for key in target:
        scaled = scaling * source[key].data.clone()
        target[key].data = scaled
45 |
46 |
def subtract(target, source):
    """In place: ``target[k] -= source[k]`` for every key ``k`` of *target*."""
    for key, tensor in target.items():
        tensor.data.sub_(source[key].data.clone())
50 |
51 |
def subtract_(target, minuend, subtrahend):
    """Set ``target[k] = minuend[k] - subtrahend[k]`` for every key of *target*."""
    for key in target:
        left = minuend[key].data.clone()
        right = subtrahend[key].data.clone()
        target[key].data = left - right
55 |
56 |
def average(target, sources):
    """Set every ``target[k]`` to the element-wise mean of ``sources[i][k]`` over all sources."""
    for key in target:
        stacked = torch.stack([src[key].data for src in sources])
        target[key].data = stacked.mean(dim=0).clone()
60 |
61 |
def weighted_average(target, sources, weights):
    """Set ``target[k]`` to the weighted mean of ``sources[i][k]``, in place.

    Args:
        target: dict name -> tensor to overwrite.
        sources: sequence of dicts with the same keys as *target*.
        weights: one weight per source. Any tensor, sequence or array accepted
            by ``torch.as_tensor`` works; previously only a tensor worked.

    The weights are rescaled to sum to ``len(sources)``, so the result equals
    ``sum_i w_i * x_i / sum_i w_i``.
    """
    weights = torch.as_tensor(weights)
    n = len(sources)
    # Loop-invariant: normalised once instead of once per parameter name.
    modify = weights / torch.sum(weights) * n
    for name in target:
        target[name].data = torch.mean(
            torch.stack([m * source[name].data for source, m in zip(sources, modify)]),
            dim=0).clone()
69 |
70 |
def computer_norm(source1, source2):
    """Return the global L2 norm of ``source1 - source2`` over all keys of *source1*.

    Each key contributes its squared Frobenius norm; the result is the square
    root of the total, as a 0-dim tensor.
    """
    squared_total = sum(
        torch.pow(torch.norm(source1[key].data.clone() - source2[key].data.clone()), 2)
        for key in source1
    )
    return torch.pow(squared_total, 0.5)
79 |
def majority_vote(target, sources, lr):
    """Element-wise sign-majority aggregation, in place.

    ``target[k] <- lr * sign(sum_i sign(sources[i][k]))``; ties give 0.
    The previously computed per-source maxima were never used and raised on
    zero-element tensors, so they are no longer computed.
    """
    for name in target:
        votes = torch.stack([source[name].data.sign() for source in sources]).sum(dim=0)
        target[name].data = (lr * votes.sign()).clone()
85 |
86 |
def get_mdl_params(model_list, n_par=None):
    """Flatten each model's parameters into one row of a float32 matrix.

    Args:
        model_list: non-empty sequence of torch modules.
        n_par: row width; defaults to the parameter count of ``model_list[0]``.
            Columns beyond a model's own parameters stay zero.

    Returns:
        np.ndarray of shape ``(len(model_list), n_par)``, dtype float32. Row i
        holds model i's parameters concatenated in ``named_parameters()`` order.
    """
    if n_par is None:
        n_par = sum(param.numel() for _, param in model_list[0].named_parameters())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for row, mdl in zip(param_mat, model_list):
        offset = 0
        for _, param in mdl.named_parameters():
            flat = param.data.cpu().numpy().reshape(-1)
            row[offset:offset + flat.size] = flat
            offset += flat.size
    # param_mat is freshly allocated, so no defensive copy is needed.
    return param_mat
102 |
103 |
def get_other_params(model_list, n_par=None):
    """Flatten each name -> tensor dict into one row of a float32 matrix.

    Args:
        model_list: non-empty sequence of dicts mapping names to tensors.
        n_par: row width; defaults to the element count of ``model_list[0]``.
            Unfilled columns stay zero.

    Returns:
        np.ndarray of shape ``(len(model_list), n_par)``, dtype float32. Row i
        holds dict i's tensors concatenated in key-iteration order.
    """
    if n_par is None:
        n_par = sum(tensor.data.numel() for tensor in model_list[0].values())

    param_mat = np.zeros((len(model_list), n_par), dtype=np.float32)
    for row, mdl in zip(param_mat, model_list):
        offset = 0
        for name in mdl:
            flat = mdl[name].data.cpu().numpy().reshape(-1)
            row[offset:offset + flat.size] = flat
            offset += flat.size
    # param_mat is freshly allocated, so no defensive copy is needed.
    return param_mat
119 |
120 |
121 |
def product_of_experts_two(q_distr, p_distr):
    """Fuse two diagonal Gaussians given as ``(mean, std)`` pairs.

    Returns ``(mean, std)`` of the normalised product q * p. The returned
    second value is a standard deviation (square root of the fused variance).
    A 1e-32 term guards the shared denominator against zero.
    """
    mu_q, sigma_q = q_distr
    mu_p, sigma_p = p_distr  # both sigmas are standard deviations
    var_q = sigma_q ** 2
    var_p = sigma_p ** 2
    denom = var_q + var_p + 1e-32
    poe_u = torch.div(mu_p * var_q + mu_q * var_p, denom)
    poe_var = torch.sqrt(torch.div(var_q * var_p, denom))
    return poe_u, poe_var
131 |
132 |
def product_of_experts(q_distr_set):
    """Fuse a set of diagonal Gaussians ``(means, stds)`` with a unit-precision prior.

    The precision starts at 1.0 and the precision-weighted mean at 0.0, i.e. a
    N(0, 1) expert is implicitly included. Returns ``(mean, std)`` of the product.
    """
    mu_q_set, sigma_q_set = q_distr_set
    precision = 1.0
    weighted_mean = 0.0
    for idx, mu in enumerate(mu_q_set):
        inv_var = 1.0 / (sigma_q_set[idx] ** 2)
        precision = precision + inv_var
        weighted_mean = weighted_mean + torch.div(mu, sigma_q_set[idx] ** 2)
    poe_var = torch.sqrt(1.0 / precision)
    poe_u = torch.div(weighted_mean, precision)
    return poe_u, poe_var
144 |
145 |
146 | '''
147 | ### DEFINE NETWORK-RELATED FUNCTIONS
148 | def product_of_experts_two(q_distr, p_distr):
149 | mu_q, sigma_q = q_distr
150 | mu_p, sigma_p = p_distr #Standard Deviation
151 | tmp1 = (1.0 / (sigma_q**2 + 1e-8)) + (1.0 / (sigma_p**2 + 1e-8))
152 | poe_var = torch.sqrt(1.0 / tmp1)
153 | tmp2 = torch.div(mu_q, sigma_q**2) + torch.div(mu_p, sigma_p**2)
154 | poe_u = torch.div(tmp2, tmp1)
155 | return poe_u, poe_var
156 |
157 |
158 | def product_of_experts(q_distr_set):
159 | mu_q_set, sigma_q_set = q_distr_set
160 | tmp1 = 1.0
161 | for i in range(len(mu_q_set)):
162 | tmp1 = tmp1 + (1.0 / (sigma_q_set[i] ** 2 + 1e-8))
163 | poe_var = torch.sqrt(1.0 / tmp1)
164 | tmp2 = 0.0
165 | for i in range(len(mu_q_set)):
166 | tmp2 = tmp2 + torch.div(mu_q_set[i], sigma_q_set[i]**2)
167 | poe_u = torch.div(tmp2, tmp1)
168 | return poe_u, poe_var
169 |
170 |
171 |
172 | def product_of_experts_copy(mask_, mu_set_, logvar_set_):
173 | tmp = 1.
174 | for m in range(len(mu_set_)):
175 | tmp += torch.reshape(mask_[:, m], [-1, 1]) * torch.div(1., torch.exp(logvar_set_[m]))
176 | poe_var = torch.div(1., tmp)
177 | poe_logvar = torch.log(poe_var)
178 | tmp = 0.
179 | for m in range(len(mu_set_)):
180 | tmp += torch.reshape(mask_[:, m], [-1, 1]) * torch.div(1., torch.exp(logvar_set_[m])) * mu_set_[m]
181 | poe_mu = poe_var * tmp
182 | return poe_mu, poe_logvar
183 | '''
184 |
185 |
186 |
187 |
class DistributedTrainingDevice:
    """Common base for federated clients and the server.

    Attributes:
        model: the torch module owned by this device.
        args: the run's option namespace, as passed by the caller.
        loss_func: an ``nn.CrossEntropyLoss`` instance.
    """

    def __init__(self, model, args):
        self.model, self.args = model, args
        self.loss_func = nn.CrossEntropyLoss()
200 |
201 | class Client(DistributedTrainingDevice):
202 |
def __init__(self, model, args, trn_x, trn_y, tst_x, tst_y, n_cls, dataset_name, id_num=0):
    """Set up a client with its local data.

    trn_x / trn_y feed a shuffled training DataLoader (batch size
    ``args.local_bs``; ``train=True`` enables augmentation for the CIFAR
    datasets). tst_x / tst_y are kept as raw arrays.
    """
    super().__init__(model, args)

    self.trn_gen = DataLoader(Dataset(trn_x, trn_y, train=True, dataset_name=dataset_name),
                              batch_size=self.args.local_bs, shuffle=True)

    self.tst_x = tst_x
    self.tst_y = tst_y
    self.n_cls = n_cls
    self.id = id_num
    # Number of minibatches in one pass over the local training data.
    self.local_epoch = int(np.ceil(trn_x.shape[0] / self.args.local_bs))
    # Parameters: aliases of the model's nn.Parameters (writes to .data update the model).
    self.W = {name: value for name, value in self.model.named_parameters()}

    self.state_params_diff = 0.0
    self.train_loss = 0.0
    # Total parameter count, computed directly instead of copying every weight into NumPy.
    self.n_par = sum(param.numel() for param in self.model.parameters())
220 |
221 |
def synchronize_with_server(self, server, w_glob_keys):
    """Pull server state into this client.

    For ``fedavg`` and ``ditto`` the whole local model is replaced by a deep
    copy of the server model and ``self.W`` is rebuilt from it. For any other
    method only the shared parameters named in *w_glob_keys* are overwritten
    (with clones), so personalised layers stay local.
    """
    if self.args.method in ('fedavg', 'ditto'):
        self.model = copy.deepcopy(server.model)
        self.W = dict(self.model.named_parameters())
        return
    for name, param in self.W.items():
        if name in w_glob_keys:
            param.data = server.W[name].data.clone()
232 |
def compute_bias(self):
    """For ditto only: set ``state_params_diff`` to ``-mu`` times the flattened model parameters.

    For every other method the attribute is left untouched.
    """
    if self.args.method != 'ditto':
        return
    flat = get_mdl_params([self.model], self.n_par)[0]
    cld_mdl_param = torch.tensor(flat, dtype=torch.float32, device=self.args.device)
    self.state_params_diff = self.args.mu * (-cld_mdl_param)
237 |
238 | def train_cnn(self, w_glob_keys, server, last):
239 |
240 | self.model.train()
241 |
242 | if self.args.method == 'ditto':
243 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
244 | weight_decay=self.args.weigh_delay + self.args.mu)
245 | else:
246 | optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
247 | weight_decay=self.args.weigh_delay)
248 |
249 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)
250 |
251 | local_eps = self.args.local_ep
252 | if last:
253 | if self.args.method =='fedavg' or self.args.method == 'ditto':
254 | local_eps= self.args.last_local_ep
255 | if 'CIFAR100' in self.args.dataset:
256 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]]
257 | elif 'CIFAR10' in self.args.dataset:
258 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]]
259 | elif 'EMNIST' in self.args.dataset:
260 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1]]
261 | elif 'FMNIST' in self.args.dataset:
262 | w_glob_keys = [self.model.weight_keys[i] for i in [0, 1, 2, 3]]
263 | w_glob_keys = list(itertools.chain.from_iterable(w_glob_keys))
264 |
265 | elif 'maml' in self.args.method:
266 | local_eps = self.args.last_local_ep / 2
267 | w_glob_keys = []
268 | else:
269 | local_eps = max(self.args.last_local_ep, self.args.local_ep - self.args.local_rep_ep)
270 |
271 | # all other methods update all parameters simultaneously
272 | else:
273 | for name, param in self.model.named_parameters():
274 | param.requires_grad = True
275 |
276 | # train and update
277 | epoch_loss = []
278 |
279 | if self.args.method == 'FedCR':
280 |
281 | self.dir_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
282 | self.dir_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype = torch.float32, device = self.args.device)
283 |
284 | for iter in range(local_eps):
285 |
286 | if last:
287 | for name, param in self.model.named_parameters():
288 | if name in w_glob_keys:
289 | param.requires_grad = False
290 | else:
291 | param.requires_grad = True
292 |
293 | loss_by_epoch = []
294 | accuracy_by_epoch = []
295 |
296 | trn_gen_iter = self.trn_gen.__iter__()
297 | batch_loss = []
298 |
299 | for i in range(self.local_epoch):
300 | dir_g_Z_u = torch.zeros(1, self.args.dimZ, dtype=torch.float32,
301 | device=self.args.device)
302 | dir_g_Z_sigma = torch.ones(1, self.args.dimZ, dtype=torch.float32,
303 | device=self.args.device)
304 |
305 | images, labels = trn_gen_iter.__next__()
306 | images, labels = images.to(self.args.device), labels.to(self.args.device)
307 | labels = labels.reshape(-1).long()
308 |
309 | batch_size = images.size()[0]
310 |
311 | #prior_Z_distr_standard = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size, self.args.dimZ).to(self.args.device)
312 |
313 | for cls in range(len(labels)):
314 | if cls == 0:
315 | dir_g_Z_u = server.dir_global_Z_u[labels[cls]].clone().detach()
316 | dir_g_Z_sigma = server.dir_global_Z_sigma[labels[cls]].clone().detach()
317 | else:
318 | dir_g_Z_u = torch.cat((dir_g_Z_u, server.dir_global_Z_u[labels[cls]].clone().detach()), 0).clone().detach()
319 | dir_g_Z_sigma = torch.cat((dir_g_Z_sigma, server.dir_global_Z_sigma[labels[cls]].clone().detach()), 0).clone().detach()
320 |
321 | prior_Z_distr = dir_g_Z_u, dir_g_Z_sigma
322 |
323 | encoder_Z_distr, decoder_logits = self.model(images, self.args.num_avg_train)
324 |
325 | decoder_logits_mean = torch.mean(decoder_logits, dim=0)
326 |
327 | loss = nn.CrossEntropyLoss(reduction='none')
328 | decoder_logits = decoder_logits.permute(1, 2, 0)
329 | cross_entropy_loss = loss(decoder_logits, labels[:, None].expand(-1, self.args.num_avg_train))
330 |
331 | # estimate E_{eps in N(0, 1)} [log q(y | z)]
332 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
333 |
334 | I_ZX_bound = torch.mean(KL_between_normals(prior_Z_distr, encoder_Z_distr))
335 | minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)
336 | total_loss = torch.mean(minusI_ZY_bound + self.args.beta * I_ZX_bound)
337 |
338 | prediction = torch.max(decoder_logits_mean, dim=1)[1]
339 | accuracy = torch.mean((prediction == labels).float())
340 |
341 | optimizer.zero_grad()
342 | total_loss.backward()
343 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm)
344 | optimizer.step()
345 |
346 | loss_by_epoch.append(total_loss.item())
347 | accuracy_by_epoch.append(accuracy.item())
348 |
349 | if iter == local_eps - 1:
350 | for cls in range(len(labels)):
351 | if self.dir_Z_u[labels[cls]] .equal(torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)) and self.dir_Z_sigma[labels[cls]] .equal(torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)):
352 | self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]] = encoder_Z_distr[0][cls].clone().detach(), encoder_Z_distr[1][cls].clone().detach()
353 | else:
354 | q_distr = self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]]
355 | encoder_Z_distr_cls = encoder_Z_distr[0][cls].clone().detach(), encoder_Z_distr[1][cls].clone().detach()
356 | self.dir_Z_u[labels[cls]], self.dir_Z_sigma[labels[cls]] = product_of_experts_two(q_distr, encoder_Z_distr_cls)
357 |
358 | scheduler.step()
359 | epoch_loss.append(sum(loss_by_epoch) / len(loss_by_epoch))
360 |
361 |
362 | else:
363 | for iter in range(local_eps):
364 |
365 | head_eps = local_eps - self.args.local_rep_ep
366 | # for FedRep, first do local epochs for the head
367 | if (iter < head_eps and self.args.method == 'fedrep') or last:
368 | for name, param in self.model.named_parameters():
369 | if name in w_glob_keys:
370 | param.requires_grad = False
371 | else:
372 | param.requires_grad = True
373 |
374 | # then do local epochs for the representation
375 | elif (iter == head_eps and self.args.method == 'fedrep') and not last:
376 | for name, param in self.model.named_parameters():
377 | if name in w_glob_keys:
378 | param.requires_grad = True
379 | else:
380 | param.requires_grad = False
381 |
382 | # all other methods update all parameters simultaneously
383 | elif self.args.method != 'fedrep':
384 | for name, param in self.model.named_parameters():
385 | param.requires_grad = True
386 |
387 | trn_gen_iter = self.trn_gen.__iter__()
388 | batch_loss = []
389 |
390 | for i in range(self.local_epoch):
391 |
392 | images, labels = trn_gen_iter.__next__()
393 | images, labels = images.to(self.args.device), labels.to(self.args.device)
394 |
395 | optimizer.zero_grad()
396 | log_probs = self.model(images)
397 | loss_f_i = self.loss_func(log_probs, labels.reshape(-1).long())
398 |
399 | local_par_list = None
400 | for param in self.model.parameters():
401 | if not isinstance(local_par_list, torch.Tensor):
402 | # Initially nothing to concatenate
403 | local_par_list = param.reshape(-1)
404 | else:
405 | local_par_list = torch.cat((local_par_list, param.reshape(-1)), 0)
406 |
407 | loss_algo = torch.sum(local_par_list * self.state_params_diff)
408 |
409 | if self.args.method == 'ditto':
410 | loss = loss_f_i + loss_algo
411 | else:
412 | loss = loss_f_i
413 |
414 | loss.backward()
415 | torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=max_norm)
416 | optimizer.step()
417 | batch_loss.append(loss.item())
418 |
419 | scheduler.step()
420 | epoch_loss.append(sum(batch_loss) / len(batch_loss))
421 |
422 | return sum(epoch_loss) / len(epoch_loss)
423 |
424 | def compute_weight_update(self, w_glob_keys, server, last=False):
425 |
426 | # Training mode
427 | self.model.train()
428 |
429 | # W = SGD(W, D)
430 | self.train_loss = self.train_cnn(w_glob_keys, server, last)
431 |
432 |
433 |
434 | @torch.no_grad()
435 | def evaluate_FedVIB(self, data_x, data_y, dataset_name):
436 | self.model.eval()
437 | # testing
438 | I_ZX_bound_by_epoch_test = []
439 | I_ZY_bound_by_epoch_test = []
440 | loss_by_epoch_test = []
441 | accuracy_by_epoch_test = []
442 |
443 | n_tst = data_x.shape[0]
444 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
445 | tst_gen_iter = tst_gen.__iter__()
446 | for i in range(int(np.ceil(n_tst / self.args.bs))):
447 | data, target = tst_gen_iter.__next__()
448 | data, target = data.to(self.args.device), target.to(self.args.device)
449 | target = target.reshape(-1).long()
450 | batch_size = data.size()[0]
451 | prior_Z_distr = torch.zeros(batch_size, self.args.dimZ).to(self.args.device), torch.ones(batch_size,self.args.dimZ).to(self.args.device)
452 | encoder_Z_distr, decoder_logits = self.model(data, self.args.num_avg)
453 |
454 | decoder_logits_mean = torch.mean(decoder_logits, dim=0)
455 | loss = nn.CrossEntropyLoss(reduction='none')
456 | decoder_logits = decoder_logits.permute(1, 2, 0)
457 | cross_entropy_loss = loss(decoder_logits, target[:, None].expand(-1, self.args.num_avg))
458 |
459 | cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
460 |
461 | I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
462 | minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0)
463 | total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test)
464 |
465 | prediction = torch.max(decoder_logits_mean, dim=1)[1]
466 | accuracy_test = torch.mean((prediction == target).float())
467 |
468 | I_ZX_bound_by_epoch_test.append(I_ZX_bound_test.item())
469 | I_ZY_bound_by_epoch_test.append(minusI_ZY_bound_test.item())
470 |
471 | loss_by_epoch_test.append(total_loss_test.item())
472 | accuracy_by_epoch_test.append(accuracy_test.item())
473 |
474 | I_ZX = np.mean(I_ZX_bound_by_epoch_test)
475 | I_ZY = np.mean(I_ZY_bound_by_epoch_test)
476 | loss_test = np.mean(loss_by_epoch_test)
477 | accuracy_test = np.mean(accuracy_by_epoch_test)
478 | accuracy_test = 100.00 * accuracy_test
479 | return accuracy_test, loss_test
480 |
481 |
482 | @torch.no_grad()
483 | def evaluate(self, data_x, data_y, dataset_name):
484 | self.model.eval()
485 | # testing
486 | test_loss = 0
487 | acc_overall = 0
488 | n_tst = data_x.shape[0]
489 | tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
490 | tst_gen_iter = tst_gen.__iter__()
491 | for i in range(int(np.ceil(n_tst / self.args.bs))):
492 | data, target = tst_gen_iter.__next__()
493 | data, target = data.to(self.args.device), target.to(self.args.device)
494 | log_probs = self.model(data)
495 | # sum up batch loss
496 | test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item()
497 | # get the index of the max log-probability
498 | log_probs = log_probs.cpu().detach().numpy()
499 | log_probs = np.argmax(log_probs, axis=1).reshape(-1)
500 | target = target.cpu().numpy().reshape(-1).astype(np.int32)
501 | batch_correct = np.sum(log_probs == target)
502 | acc_overall += batch_correct
503 | '''
504 | y_pred = log_probs.data.max(1, keepdim=True)[1]
505 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
506 | '''
507 |
508 | test_loss /= n_tst
509 | accuracy = 100.00 * acc_overall / n_tst
510 | return accuracy, test_loss
511 |
512 |
513 | '''
514 | ------------------------------------------------------------------------------------------------------------------------
515 |
516 | ------------------------------------------------------------------------------------------------------------------------
517 | '''
518 |
519 |
class Server(DistributedTrainingDevice):
    """The aggregating server: owns the global model and, for FedCR, the
    per-class global latent distributions.

    Args:
        model: the global ``nn.Module``.
        args: option namespace; this class reads ``method``, ``dimZ``,
            ``device``, ``beta2``, ``bs``, ``beta`` and ``num_avg``.
        n_cls: number of classes.
    """

    def __init__(self, model, args, n_cls):
        super().__init__(model, args)

        # Name -> Parameter references into the global model.
        self.W = dict(self.model.named_parameters())
        self.local_epoch = 0
        self.n_cls = n_cls
        if self.args.method == 'FedCR':
            # Global per-class distributions, shape (n_cls, 1, dimZ); the
            # initial (0, 1) values mean "no information yet".
            self.dir_global_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
            self.dir_global_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)

    def aggregate_weight_updates(self, clients, iter, aggregation="mean"):
        """Aggregate the clients' weights into ``self.W``.

        ``iter`` is accepted for interface compatibility and is not used. Only
        ``aggregation="mean"`` is implemented (delegating to ``average``, which
        is defined elsewhere in this module); any other value leaves the server
        weights unchanged.
        """
        # The number of batches per epoch can differ between clients on
        # unbalanced data; the first client's value is recorded.
        self.local_epoch = clients[0].local_epoch
        if aggregation == "mean":
            average(target=self.W, sources=[client.W for client in clients])

    def global_POE(self, clients):
        """Update the global per-class distributions from the clients' local ones.

        For each class, the distributions of all clients that observed it
        (i.e. whose slot is not the initial (0, 1) pair) are fused with
        ``product_of_experts``. The global entry then moves toward that estimate
        by an exponential moving average with weight ``args.beta2``. Classes
        seen by no client keep their current value.
        """
        new_Z_u = self.dir_global_Z_u.clone()
        new_Z_sigma = self.dir_global_Z_sigma.clone()
        unset_u = torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        unset_sigma = torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        beta2 = self.args.beta2

        for cls in range(self.n_cls):
            mus, sigmas = [], []
            for client in clients:
                client_u = client.dir_Z_u[cls]
                client_sigma = client.dir_Z_sigma[cls]
                if client_u.equal(unset_u) and client_sigma.equal(unset_sigma):
                    continue  # this client never saw the class
                mus.append(client_u.clone().detach())
                sigmas.append(client_sigma.clone().detach())

            if mus:
                clients_all_Z = torch.cat(mus, 0), torch.cat(sigmas, 0)
                new_Z_u[cls], new_Z_sigma[cls] = product_of_experts(clients_all_Z)

            self.dir_global_Z_u[cls] = (1 - beta2) * self.dir_global_Z_u[cls] + beta2 * new_Z_u[cls]
            self.dir_global_Z_sigma[cls] = (1 - beta2) * self.dir_global_Z_sigma[cls] + beta2 * new_Z_sigma[cls]

    @torch.no_grad()
    def evaluate_FedVIB(self, data_x, data_y, dataset_name):
        """Evaluate the stochastic (VIB-style) global model on ``(data_x, data_y)``.

        The KL term uses a zero-mean, unit-scale prior. Accuracy and loss are
        averaged per batch (the last, possibly smaller, batch counts the same
        as the others).

        Returns:
            ``(accuracy_percent, mean_total_loss)``.
        """
        self.model.eval()
        loss_by_batch = []
        accuracy_by_batch = []
        ce_per_sample = nn.CrossEntropyLoss(reduction='none')

        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        n_batches = int(np.ceil(n_tst / self.args.bs))
        for data, target in itertools.islice(tst_gen, n_batches):
            data, target = data.to(self.args.device), target.to(self.args.device)
            # Flatten to a 1-D class-index tensor, as Client.evaluate_FedVIB and
            # training do; otherwise expand() below fails and the accuracy
            # comparison broadcasts to (B, B).
            target = target.reshape(-1).long()
            batch_size = data.size(0)
            prior_Z_distr = (torch.zeros(batch_size, self.args.dimZ).to(self.args.device),
                             torch.ones(batch_size, self.args.dimZ).to(self.args.device))
            encoder_Z_distr, decoder_logits = self.model(data, self.args.num_avg)

            decoder_logits_mean = torch.mean(decoder_logits, dim=0)
            cross_entropy_loss = ce_per_sample(decoder_logits.permute(1, 2, 0),
                                               target[:, None].expand(-1, self.args.num_avg))
            cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)

            I_ZX_bound_test = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
            minusI_ZY_bound_test = torch.mean(cross_entropy_loss_montecarlo, dim=0)
            total_loss_test = torch.mean(minusI_ZY_bound_test + self.args.beta * I_ZX_bound_test)

            prediction = torch.max(decoder_logits_mean, dim=1)[1]
            accuracy_test = torch.mean((prediction == target).float())

            loss_by_batch.append(total_loss_test.item())
            accuracy_by_batch.append(accuracy_test.item())

        return 100.00 * np.mean(accuracy_by_batch), np.mean(loss_by_batch)

    @torch.no_grad()
    def evaluate(self, data_x, data_y, dataset_name):
        """Evaluate the deterministic global classifier on ``(data_x, data_y)``.

        Returns:
            ``(accuracy_percent, mean_per_sample_cross_entropy)``, both
            normalized by ``data_x.shape[0]``.
        """
        self.model.eval()
        test_loss = 0
        acc_overall = 0
        n_tst = data_x.shape[0]
        tst_gen = DataLoader(Dataset(data_x, data_y, dataset_name=dataset_name), batch_size=self.args.bs, shuffle=False)
        n_batches = int(np.ceil(n_tst / self.args.bs))
        for data, target in itertools.islice(tst_gen, n_batches):
            data, target = data.to(self.args.device), target.to(self.args.device)
            log_probs = self.model(data)
            # Sum (not mean) so dividing by n_tst below gives the per-sample loss.
            test_loss += nn.CrossEntropyLoss(reduction='sum')(log_probs, target.reshape(-1).long()).item()
            prediction = np.argmax(log_probs.cpu().detach().numpy(), axis=1).reshape(-1)
            target = target.cpu().numpy().reshape(-1).astype(np.int32)
            acc_overall += np.sum(prediction == target)

        test_loss /= n_tst
        accuracy = 100.00 * acc_overall / n_tst
        return accuracy, test_loss
647 |
648 |
649 |
650 |
--------------------------------------------------------------------------------