├── README.md
├── .idea
│   ├── other.xml
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── saveactions_settings.xml
│   ├── misc.xml
│   ├── Federated-Mutual-Learning.iml
│   └── deployment.xml
├── main.py
├── Node.py
├── Args.py
├── utils.py
├── Model.py
├── Trainer.py
└── Data.py
/README.md:
--------------------------------------------------------------------------------
# Federated-Mutual-Learning
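
## Quick start

A minimal run sketch, assuming PyTorch, torchvision, and tqdm are installed (the flags and their defaults come from `Args.py`):

```
python main.py --algorithm fed_mutual --dataset cifar10 --node_num 5 --R 50 --E 5
```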
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
from Node import Node, Global_Node
from Args import args_parser
from Data import Data
from utils import LR_scheduler, Recorder, Catfish, Summary
from Trainer import Trainer

# init args
args = args_parser()
args.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
print('Running on', args.device)
data = Data(args)  # lower-case name so the Data class is not shadowed
Train = Trainer(args)

# init nodes
Global_node = Global_Node(data.test_all, args)
Node_List = [Node(k, data.train_loader[k], data.test_loader, args) for k in range(args.node_num)]
Catfish(Node_List, args)

# init variables
recorder = Recorder(args)
Summary(args)

# start
for rounds in range(args.R):
    print('===============The {:d}-th round==============='.format(rounds + 1))
    LR_scheduler(rounds, Node_List, args)
    for k in range(len(Node_List)):
        # pull the current global model into the node's meme before local training
        Node_List[k].fork(Global_node)
        for epoch in range(args.E):
            Train(Node_List[k])
        recorder.validate(Node_List[k])
        recorder.printer(Node_List[k])
    # average the nodes' meme weights back into the global model
    Global_node.merge(Node_List)
    recorder.validate(Global_node)
    recorder.printer(Global_node)
recorder.finish()
Summary(args)
--------------------------------------------------------------------------------
/Node.py:
--------------------------------------------------------------------------------
import copy
import torch
import Model


def init_model(model_type):
    # map the CLI model name to a class from Model.py; fail loudly on unknown names
    if model_type == 'LeNet5':
        return Model.LeNet5()
    elif model_type == 'MLP':
        return Model.MLP()
    elif model_type == 'ResNet18':
        return Model.ResNet18()
    elif model_type == 'CNN':
        return Model.CNN()
    raise ValueError('Unknown model type: {}'.format(model_type))


def init_optimizer(model, args):
    if args.optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
    elif args.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))


def weights_zero(model):
    for p in model.parameters():
        if p.data is not None:
            p.data.detach_()
            p.data.zero_()


class Node(object):
    def __init__(self, num, train_data, test_data, args):
        self.args = args
        self.num = num + 1
        self.device = self.args.device
        self.train_data = train_data
        self.test_data = test_data
        # a private local model plus a "meme": the node's copy of the global model
        self.model = init_model(self.args.local_model).to(self.device)
        self.optimizer = init_optimizer(self.model, self.args)
        self.meme = init_model(self.args.global_model).to(self.device)
        self.meme_optimizer = init_optimizer(self.meme, self.args)

    def fork(self, global_node):
        # replace the meme with a fresh copy of the current global model
        self.meme = copy.deepcopy(global_node.model).to(self.device)
        self.meme_optimizer = init_optimizer(self.meme, self.args)


class Global_Node(object):
    def __init__(self, test_data, args):
        self.num = 0
        self.args = args
        self.device = self.args.device
        self.model = init_model(self.args.global_model).to(self.device)
        self.test_data = test_data
        self.Dict = self.model.state_dict()

    def merge(self, Node_List):
        # FedAvg: element-wise average of the nodes' meme state dicts,
        # written back into the global model
        Node_State_List = [copy.deepcopy(Node_List[i].meme.state_dict()) for i in range(len(Node_List))]
        self.Dict = copy.deepcopy(Node_State_List[0])
        for key in self.Dict.keys():
            for i in range(1, len(Node_List)):
                self.Dict[key] = self.Dict[key] + Node_State_List[i][key]
            self.Dict[key] = self.Dict[key] / len(Node_List)
        self.model.load_state_dict(self.Dict)
--------------------------------------------------------------------------------
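A minimal sketch (not part of the repo) of what Global_Node.merge computes: the element-wise average of the nodes' meme state dicts, checked here on two tiny linear layers.

import torch
import torch.nn as nn

a, b = nn.Linear(2, 2), nn.Linear(2, 2)
avg = nn.Linear(2, 2)

# average the two state dicts key by key, as merge does over Node_List
avg_state = {k: (a.state_dict()[k] + b.state_dict()[k]) / 2 for k in avg.state_dict()}
avg.load_state_dict(avg_state)

# the merged parameters are exactly the element-wise mean
assert torch.allclose(avg.weight, (a.weight + b.weight) / 2)
assert torch.allclose(avg.bias, (a.bias + b.bias) / 2)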
/Args.py:
--------------------------------------------------------------------------------
import argparse


def args_parser():
    parser = argparse.ArgumentParser()

    # Total
    parser.add_argument('--algorithm', type=str, default='fed_avg',
                        help='Type of algorithm: {fed_mutual, fed_avg, normal}')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='device: {cuda, cpu}')
    parser.add_argument('--node_num', type=int, default=5,
                        help='Number of nodes')
    parser.add_argument('--R', type=int, default=50,
                        help='Number of rounds: R')
    parser.add_argument('--E', type=int, default=5,
                        help='Number of local epochs: E')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes of Experiments')

    # Model (names must match the branches in Node.init_model)
    parser.add_argument('--global_model', type=str, default='CNN',
                        help='Type of global model: {LeNet5, MLP, CNN, ResNet18}')
    parser.add_argument('--local_model', type=str, default='CNN',
                        help='Type of local model: {LeNet5, MLP, CNN, ResNet18}')
    parser.add_argument('--catfish', type=str, default=None,
                        help='Type of catfish model for node 0: {None, LeNet5, MLP, CNN, ResNet18}')

    # Data
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='datasets: {cifar10, femnist, mnist}')
    parser.add_argument('--batchsize', type=int, default=128,
                        help='batch size')
    parser.add_argument('--split', type=int, default=5,
                        help='number of equal data splits')
    parser.add_argument('--val_ratio', type=float, default=0.1,
                        help='validation ratio')
    parser.add_argument('--all_data', type=lambda s: str(s).lower() in ('true', '1'), default=True,
                        help='use the full train set (type=bool would treat any non-empty string as True)')
    parser.add_argument('--classes', type=int, default=10,
                        help='number of classes')

    # Optimization
    parser.add_argument('--optimizer', type=str, default='sgd',
                        help='optimizer: {sgd, adam}')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--lr_step', type=int, default=10,
                        help='learning rate decay step size')
    parser.add_argument('--stop_decay', type=int, default=50,
                        help='round after which the learning rate stops decaying')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='weight of the CE (data) loss in the local model loss')
    parser.add_argument('--beta', type=float, default=0.5,
                        help='weight of the CE (data) loss in the meme loss')

    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import Node


class Recorder(object):
    def __init__(self, args):
        self.args = args
        self.counter = 0
        self.tra_loss = {}
        self.tra_acc = {}
        self.val_loss = {}
        self.val_acc = {}
        for i in range(self.args.node_num + 1):
            self.tra_loss[str(i)] = []
            self.tra_acc[str(i)] = []
            self.val_loss[str(i)] = []
            self.val_acc[str(i)] = []
        self.acc_best = torch.zeros(self.args.node_num + 1)
        self.get_a_better = torch.zeros(self.args.node_num + 1)

    def validate(self, node):
        self.counter += 1
        node.model.to(node.device).eval()
        total_loss = 0.0
        correct = 0.0

        with torch.no_grad():
            for idx, (data, target) in enumerate(node.test_data):
                data, target = data.to(node.device), target.to(node.device)
                output = node.model(data)
                total_loss += torch.nn.CrossEntropyLoss()(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()
            total_loss = total_loss / (idx + 1)
            acc = correct / len(node.test_data.dataset) * 100
        self.val_loss[str(node.num)].append(total_loss)
        self.val_acc[str(node.num)].append(acc)

        if self.val_acc[str(node.num)][-1] > self.acc_best[node.num]:
            self.get_a_better[node.num] = 1
            self.acc_best[node.num] = self.val_acc[str(node.num)][-1]
            os.makedirs('./saves/model', exist_ok=True)
            torch.save(node.model.state_dict(),
                       './saves/model/Node{:d}_{:s}.pt'.format(node.num, node.args.local_model))

    def printer(self, node):
        if self.get_a_better[node.num] == 1:
            print('Node{:d}: A Better Accuracy: {:.2f}%! Model Saved!'.format(node.num, self.acc_best[node.num]))
            print('-------------------------')
            self.get_a_better[node.num] = 0

    def finish(self):
        os.makedirs('./saves/record', exist_ok=True)
        torch.save([self.val_loss, self.val_acc],
                   './saves/record/loss_acc_{:s}_{:s}.pt'.format(self.args.algorithm, self.args.notes))
        print('Finished!\n')
        for i in range(self.args.node_num + 1):
            print('Node{}: Best Accuracy = {:.2f}%'.format(i, self.acc_best[i]))


def Catfish(Node_List, args):
    # optionally give node 0 a stronger ("catfish") local model
    if args.catfish is not None:
        Node_List[0].model = Node.init_model(args.catfish).to(args.device)
        Node_List[0].optimizer = Node.init_optimizer(Node_List[0].model, args)


def LR_scheduler(rounds, Node_List, args):
    # decay the learning rate by 10x every R/3 rounds until stop_decay
    trigger = max(1, int(args.R / 3))
    if rounds != 0 and rounds % trigger == 0 and rounds < args.stop_decay:
        args.lr *= 0.1
        # args.alpha += 0.2
        # args.beta += 0.4
        for i in range(len(Node_List)):
            Node_List[i].args.lr = args.lr
            Node_List[i].args.alpha = args.alpha
            Node_List[i].args.beta = args.beta
            Node_List[i].optimizer.param_groups[0]['lr'] = args.lr
            Node_List[i].meme_optimizer.param_groups[0]['lr'] = args.lr
    print('Learning rate={:.4f}'.format(args.lr))


def Summary(args):
    print("Summary:\n")
    print("algorithm:{}\n".format(args.algorithm))
    print("dataset:{}\tbatchsize:{}\n".format(args.dataset, args.batchsize))
    print("node_num:{},\tsplit:{}\n".format(args.node_num, args.split))
    # print("iid:{},\tequal:{},\n".format(args.iid == 1, args.unequal == 0))
    print("global rounds:{},\tlocal epochs:{},\n".format(args.R, args.E))
    print("global_model:{},\tlocal model:{},\n".format(args.global_model, args.local_model))
--------------------------------------------------------------------------------
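Recorder.finish dumps the validation curves as a two-element list; a hypothetical post-hoc inspection (the filename below assumes the default --algorithm fed_avg and an empty --notes):

import torch

# keys are node numbers as strings; '0' is the global node
val_loss, val_acc = torch.load('./saves/record/loss_acc_fed_avg_.pt')
for node_id, accs in val_acc.items():
    if accs:
        print('Node{}: last={:.2f}% best={:.2f}%'.format(node_id, accs[-1], max(accs)))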
/Model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(
                inchannel,
                outchannel,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(outchannel),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    inchannel, outchannel, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, residual_block, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(residual_block, 64, 2, stride=1)
        self.layer2 = self.make_layer(residual_block, 128, 2, stride=2)
        self.layer3 = self.make_layer(residual_block, 256, 2, stride=2)
        self.layer4 = self.make_layer(residual_block, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)  # e.g. stride=1 gives [1, 1]
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def ResNet18():
    return ResNet(ResidualBlock)


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        x = x.view(-1, 3 * 32 * 32)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 64, 3)
        self.fc1 = nn.Linear(64 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
--------------------------------------------------------------------------------
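All four networks expect CIFAR-shaped 3x32x32 inputs and emit 10 logits; a quick shape sanity check (a sketch, not a repo file):

import torch
from Model import LeNet5, MLP, CNN, ResNet18

x = torch.randn(2, 3, 32, 32)  # dummy CIFAR-10 batch
for net in (LeNet5(), MLP(), CNN(), ResNet18()):
    assert net(x).shape == (2, 10), type(net).__name__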
/Trainer.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import torch.nn as nn

KL_Loss = nn.KLDivLoss(reduction='batchmean')
Softmax = nn.Softmax(dim=1)
LogSoftmax = nn.LogSoftmax(dim=1)
CE_Loss = nn.CrossEntropyLoss()


def train_normal(node):
    # plain local training of node.model on the node's own data
    node.model.to(node.device).train()
    train_loader = node.train_data
    total_loss = 0.0
    avg_loss = 0.0
    correct = 0.0
    acc = 0.0
    description = "Training (the {:d}-batch): tra_Loss = {:.4f} tra_Accuracy = {:.2f}%"
    with tqdm(train_loader) as epochs:
        for idx, (data, target) in enumerate(epochs):
            node.optimizer.zero_grad()
            epochs.set_description(description.format(idx + 1, avg_loss, acc))
            data, target = data.to(node.device), target.to(node.device)
            output = node.model(data)
            loss = CE_Loss(output, target)
            loss.backward()
            node.optimizer.step()
            total_loss += loss.item()  # .item() keeps the running sum off the autograd graph
            avg_loss = total_loss / (idx + 1)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            acc = correct / len(train_loader.dataset) * 100


def train_avg(node):
    # FedAvg: train only the meme (the node's copy of the global model)
    node.meme.to(node.device).train()
    train_loader = node.train_data
    total_loss = 0.0
    avg_loss = 0.0
    correct = 0.0
    acc = 0.0
    description = "Node{:d}: loss={:.4f} acc={:.2f}%"
    with tqdm(train_loader) as epochs:
        for idx, (data, target) in enumerate(epochs):
            node.meme_optimizer.zero_grad()
            epochs.set_description(description.format(node.num, avg_loss, acc))
            data, target = data.to(node.device), target.to(node.device)
            output = node.meme(data)
            loss = CE_Loss(output, target)
            loss.backward()
            node.meme_optimizer.step()
            total_loss += loss.item()
            avg_loss = total_loss / (idx + 1)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            acc = correct / len(train_loader.dataset) * 100
    node.model = node.meme  # in fed_avg the local model is the trained meme


def train_mutual(node):
    # deep mutual learning: the local model and the meme teach each other
    node.model.to(node.device).train()
    node.meme.to(node.device).train()
    train_loader = node.train_data
    total_local_loss = 0.0
    avg_local_loss = 0.0
    correct_local = 0.0
    acc_local = 0.0
    total_meme_loss = 0.0
    avg_meme_loss = 0.0
    correct_meme = 0.0
    acc_meme = 0.0
    description = 'Node{:d}: loss_model={:.4f} acc_model={:.2f}% loss_meme={:.4f} acc_meme={:.2f}%'
    with tqdm(train_loader) as epochs:
        for idx, (data, target) in enumerate(epochs):
            node.optimizer.zero_grad()
            node.meme_optimizer.zero_grad()
            epochs.set_description(description.format(node.num, avg_local_loss, acc_local, avg_meme_loss, acc_meme))
            data, target = data.to(node.device), target.to(node.device)
            output_local = node.model(data)
            output_meme = node.meme(data)
            ce_local = CE_Loss(output_local, target)
            kl_local = KL_Loss(LogSoftmax(output_local), Softmax(output_meme.detach()))
            ce_meme = CE_Loss(output_meme, target)
            kl_meme = KL_Loss(LogSoftmax(output_meme), Softmax(output_local.detach()))
            loss_local = node.args.alpha * ce_local + (1 - node.args.alpha) * kl_local
            loss_meme = node.args.beta * ce_meme + (1 - node.args.beta) * kl_meme
            loss_local.backward()
            loss_meme.backward()
            node.optimizer.step()
            node.meme_optimizer.step()
            total_local_loss += loss_local.item()
            avg_local_loss = total_local_loss / (idx + 1)
            pred_local = output_local.argmax(dim=1)
            correct_local += pred_local.eq(target.view_as(pred_local)).sum().item()
            acc_local = correct_local / len(train_loader.dataset) * 100
            total_meme_loss += loss_meme.item()
            avg_meme_loss = total_meme_loss / (idx + 1)
            pred_meme = output_meme.argmax(dim=1)
            correct_meme += pred_meme.eq(target.view_as(pred_meme)).sum().item()
            acc_meme = correct_meme / len(train_loader.dataset) * 100


class Trainer(object):

    def __init__(self, args):
        if args.algorithm == 'fed_mutual':
            self.train = train_mutual
        elif args.algorithm == 'fed_avg':
            self.train = train_avg
        elif args.algorithm == 'normal':
            self.train = train_normal
        else:
            raise ValueError('Unknown algorithm: {}'.format(args.algorithm))

    def __call__(self, node):
        self.train(node)
--------------------------------------------------------------------------------
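train_mutual implements deep mutual learning: each network's loss blends cross-entropy on the labels with a KL term toward the other network's detached predictions, weighted by alpha (local) or beta (meme). A standalone sketch of the local-side loss on dummy logits:

import torch
import torch.nn as nn

KL = nn.KLDivLoss(reduction='batchmean')
CE = nn.CrossEntropyLoss()

logits_local, logits_meme = torch.randn(4, 10), torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
alpha = 0.5  # --alpha in Args: weight of the CE term

loss_local = alpha * CE(logits_local, target) + (1 - alpha) * KL(
    torch.log_softmax(logits_local, dim=1),
    torch.softmax(logits_meme, dim=1).detach())
print('loss_local =', loss_local.item())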
/Data.py:
--------------------------------------------------------------------------------
import torch
import os.path
from torchvision.datasets import utils, MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from PIL import Image


class FEMNIST(MNIST):
    """
    This dataset is derived from the Leaf repository
    (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
    dataset, grouping examples by writer. Details about Leaf were published in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
    """
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file

        self.data, self.targets, self.users_index = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download the FEMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        utils.makedir_exist_ok(self.raw_folder)
        utils.makedir_exist_ok(self.processed_folder)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            utils.download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')
        shutil.move(os.path.join(self.raw_folder, self.training_file), self.processed_folder)
        shutil.move(os.path.join(self.raw_folder, self.test_file), self.processed_folder)


def Dataset(args):
    trainset, testset = None, None

    if args.dataset == 'cifar10':
        tra_trans = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        val_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = CIFAR10(root="~/data", train=True, download=True, transform=tra_trans)
        testset = CIFAR10(root="~/data", train=False, download=True, transform=val_trans)

    if args.dataset in ('femnist', 'mnist'):  # `== 'femnist' or 'mnist'` was always True
        tra_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        val_trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        if args.dataset == 'femnist':
            trainset = FEMNIST(root='~/data', train=True, transform=tra_trans)
            testset = FEMNIST(root='~/data', train=False, transform=val_trans)
        if args.dataset == 'mnist':
            trainset = MNIST(root='~/data', train=True, download=True, transform=tra_trans)
            testset = MNIST(root='~/data', train=False, download=True, transform=val_trans)

    return trainset, testset


class Data(object):

    def __init__(self, args):
        self.args = args
        trainset, testset = Dataset(args)
        # split the train set into `split` equal shards, one per node
        num_train = [int(len(trainset) / args.split) for _ in range(args.split)]
        cumsum_train = torch.tensor(list(num_train)).cumsum(dim=0).tolist()
        # idx_train = sorted(range(len(trainset.targets)), key=lambda k: trainset.targets[k])  # split by class
        idx_train = range(len(trainset.targets))
        splited_trainset = [Subset(trainset, idx_train[off - l:off]) for off, l in zip(cumsum_train, num_train)]
        self.test_all = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=4)
        self.train_loader = [DataLoader(splited_trainset[i], batch_size=args.batchsize, shuffle=True, num_workers=4)
                             for i in range(args.node_num)]
        # every node validates on the full test set (a per-node test split was
        # built here originally but immediately overwritten, so it is dropped)
        self.test_loader = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=4)
--------------------------------------------------------------------------------
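A hypothetical smoke test for the Data pipeline (assumes MNIST can be fetched to ~/data; only the fields Data actually reads are filled in):

from types import SimpleNamespace
from Data import Data

args = SimpleNamespace(dataset='mnist', batchsize=128, split=5, node_num=5)
data = Data(args)
print(len(data.train_loader), 'per-node train loaders')
print(len(data.test_loader.dataset), 'samples in the shared test loader')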