├── README.md ├── .idea ├── other.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── saveactions_settings.xml ├── misc.xml ├── Federated-Mutual-Learning.iml └── deployment.xml ├── main.py ├── Node.py ├── Args.py ├── utils.py ├── Model.py ├── Trainer.py └── Data.py /README.md: -------------------------------------------------------------------------------- 1 | # Federated-Mutual-Learning -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/saveactions_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 11 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/Federated-Mutual-Learning.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 
# ------------------------------------------------------------------------------
# /.idea/deployment.xml  (IDE metadata; its content is empty in this dump)
# ------------------------------------------------------------------------------

# ------------------------------------------------------------------------------
# /main.py
# ------------------------------------------------------------------------------
import torch
from Node import Node, Global_Node
from Args import args_parser
from Data import Data
from utils import LR_scheduler, Recorder, Catfish, Summary
from Trainer import Trainer


def main():
    """Run the experiment described by the command-line arguments.

    Builds the data shards, one global node and ``args.node_num`` local nodes,
    then, for each of ``args.R`` rounds, trains every local node for
    ``args.E`` epochs (validating after each epoch) and validates the global
    node.
    """
    # init args
    args = args_parser()
    # Fall back to the CPU when CUDA is unavailable, whatever --device says.
    args.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print('Running on', args.device)
    data = Data(args)  # lower-case so the imported Data class is not shadowed
    trainer = Trainer(args)

    # init nodes: the global node is evaluated on the full test set; each local
    # node trains on its own shard of the training set.
    global_node = Global_Node(data.test_all, args)
    node_list = [Node(k, data.train_loader[k], data.test_loader, args) for k in range(args.node_num)]
    Catfish(node_list, args)

    # init variables
    recorder = Recorder(args)
    Summary(args)

    # start
    for rounds in range(args.R):
        print('===============The {:d}-th round==============='.format(rounds + 1))
        LR_scheduler(rounds, node_list, args)
        for node in node_list:
            # NOTE(review): fork/merge are disabled in this snapshot, so nodes never
            # receive the global weights and the global model is never updated.
            # Re-enable both calls to actually aggregate the memes.
            # node.fork(global_node)
            for _ in range(args.E):
                trainer(node)
                recorder.validate(node)
            recorder.printer(node)
        # global_node.merge(node_list)
        recorder.validate(global_node)
        recorder.printer(global_node)
    recorder.finish()
    Summary(args)


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------
# /Node.py
# ------------------------------------------------------------------------------
import copy
import torch
import Model


def init_model(model_type, num_classes=10):
    """Instantiate a freshly initialised network by name.

    Args:
        model_type: one of ``'LeNet5'``, ``'MLP'``, ``'ResNet18'`` or ``'CNN'``.
        num_classes: width of the output layer (10 by default, the value that
            used to be hard-coded in every model).

    Returns:
        A new ``torch.nn.Module``.

    Raises:
        ValueError: if ``model_type`` is unknown. (An empty list used to be
            returned, which only failed later at ``.to(device)``.)
    """
    builders = {
        'LeNet5': Model.LeNet5,
        'MLP': Model.MLP,
        'ResNet18': Model.ResNet18,
        'CNN': Model.CNN,
    }
    try:
        builder = builders[model_type]
    except KeyError:
        raise ValueError('Unknown model type {!r}; expected one of {}'.format(
            model_type, sorted(builders))) from None
    return builder(num_classes=num_classes)


def init_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` for ``model``'s parameters.

    Both choices use ``args.lr``; SGD additionally uses ``args.momentum``.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    if args.optimizer == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
    if args.optimizer == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    raise ValueError("Unknown optimizer {!r}; expected 'sgd' or 'adam'".format(args.optimizer))


def weights_zero(model):
    """Zero every parameter of ``model`` in place (buffers are not touched)."""
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()


class Node(object):
    """A client holding a private local ``model`` and a shareable ``meme`` model.

    ``num`` is 1-based because id 0 is used by the global node. ``model`` is
    built from ``args.local_model`` and ``meme`` from ``args.global_model``;
    ``fork`` overwrites ``meme`` with the global model and
    ``Global_Node.merge`` reads every node's ``meme``.
    """

    def __init__(self, num, train_data, test_data, args):
        self.args = args
        self.num = num + 1
        self.device = self.args.device
        self.train_data = train_data
        self.test_data = test_data
        self.model = init_model(self.args.local_model, self.args.classes).to(self.device)
        self.optimizer = init_optimizer(self.model, self.args)
        self.meme = init_model(self.args.global_model, self.args.classes).to(self.device)
        self.meme_optimizer = init_optimizer(self.meme, self.args)

    def fork(self, global_node):
        """Replace ``meme`` with a copy of the global model and give it a fresh optimizer."""
        self.meme = copy.deepcopy(global_node.model).to(self.device)
        self.meme_optimizer = init_optimizer(self.meme, self.args)


class Global_Node(object):
    """Server-side node (id 0) whose model is the average of the clients' memes."""

    def __init__(self, test_data, args):
        self.num = 0
        self.args = args
        self.device = self.args.device
        self.model = init_model(self.args.global_model, self.args.classes).to(self.device)
        self.test_data = test_data
        self.Dict = self.model.state_dict()

    def merge(self, Node_List):
        """Set the global model to the element-wise mean of every node's ``meme``.

        All state-dict entries are averaged, including buffers such as
        BatchNorm running statistics. The old zero-then-accumulate loop never
        reset buffers, and its in-place ``/=`` raised on integer buffers like
        ``num_batches_tracked``. Here the mean is computed in float32 and cast
        back to each entry's dtype.

        Raises:
            ValueError: if ``Node_List`` is empty.
        """
        if not Node_List:
            raise ValueError('merge() needs at least one node')
        states = [node.meme.state_dict() for node in Node_List]
        averaged = {}
        for key, reference in self.model.state_dict().items():
            stacked = torch.stack([state[key].detach().to(self.device).float() for state in states])
            averaged[key] = stacked.mean(dim=0).to(reference.dtype)
        self.model.load_state_dict(averaged)
        self.Dict = self.model.state_dict()

# ------------------------------------------------------------------------------
# /Args.py
# ------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def args_parser():
    """Build and parse the experiment's command-line arguments.

    Returns:
        ``argparse.Namespace``. ``device`` is still a string here; ``main``
        converts it to a ``torch.device``.
    """
    parser = argparse.ArgumentParser()

    # Total
    parser.add_argument('--algorithm', type=str, default='fed_avg',
                        help='Type of algorithms:{fed_mutual, fed_avg, normal}')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='device: {cuda, cpu}')
    parser.add_argument('--node_num', type=int, default=5,
                        help='Number of nodes')
    parser.add_argument('--R', type=int, default=50,
                        help='Number of rounds: R')
    parser.add_argument('--E', type=int, default=5,
                        help='Number of local epochs: E')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes of Experiments')

    # Model
    # Defaults must name a model known to Node.init_model ('CNN1' did not exist,
    # so the default run crashed).
    parser.add_argument('--global_model', type=str, default='CNN',
                        help='Type of global model: {LeNet5, MLP, CNN, ResNet18}')
    parser.add_argument('--local_model', type=str, default='CNN',
                        help='Type of local model: {LeNet5, MLP, CNN, ResNet18}')
    parser.add_argument('--catfish', type=str, default=None,
                        help='Type of catfish model replacing the first node\'s local model: '
                             '{None, LeNet5, MLP, CNN, ResNet18}')

    # Data
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='datasets: {cifar10, femnist, mnist}')
    parser.add_argument('--batchsize', type=int, default=128,
                        help='batchsize')
    parser.add_argument('--split', type=int, default=5,
                        help='data split')
    # NOTE(review): val_ratio and all_data are not read anywhere in this repository.
    parser.add_argument('--val_ratio', type=float, default=0.1,
                        help='val_ratio')
    parser.add_argument('--all_data', type=_str2bool, default=True,
                        help='use all train_set')
    parser.add_argument('--classes', type=int, default=10,
                        help='number of output classes of every model')

    # Optima
    parser.add_argument('--optimizer', type=str, default='sgd',
                        help='optimizer: {sgd, adam}')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    # NOTE(review): LR_scheduler decays every R // 3 rounds and ignores lr_step.
    parser.add_argument('--lr_step', type=int, default=10,
                        help='learning rate decay step size')
    parser.add_argument('--stop_decay', type=int, default=50,
                        help='round when learning rate stop decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='local ratio of data loss')
    parser.add_argument('--beta', type=float, default=0.5,
                        help='meme ratio of data loss')

    return parser.parse_args()

# ------------------------------------------------------------------------------
# /utils.py
# ------------------------------------------------------------------------------
import os
import torch
import Node


class Recorder(object):
    """Track validation loss/accuracy per node and checkpoint each node's best model.

    Histories are keyed by ``str(node.num)``: '0' is the global node and
    '1'..str(node_num) are the local nodes.
    """

    def __init__(self, args):
        self.args = args
        self.counter = 0
        # Training histories are kept as attributes but are not filled anywhere.
        self.tra_loss = {}
        self.tra_acc = {}
        self.val_loss = {str(i): [] for i in range(self.args.node_num + 1)}
        self.val_acc = {str(i): [] for i in range(self.args.node_num + 1)}
        self.acc_best = torch.zeros(self.args.node_num + 1)
        self.get_a_better = torch.zeros(self.args.node_num + 1)

    def validate(self, node):
        """Evaluate ``node.model`` on ``node.test_data`` and record the results.

        Appends the mean batch loss (a float) and the accuracy in percent to the
        node's history. When the accuracy beats the node's best so far, it sets
        the flag read by ``printer`` and saves the weights to ``./saves/model/``,
        creating the directory if needed.
        """
        self.counter += 1
        node.model.to(node.device).eval()
        criterion = torch.nn.CrossEntropyLoss()
        total_loss = 0.0
        correct = 0.0
        num_batches = 0

        with torch.no_grad():
            for data, target in node.test_data:
                data, target = data.to(node.device), target.to(node.device)
                output = node.model(data)
                total_loss += criterion(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()
                num_batches += 1
        # max(..., 1) guards empty loaders (previously a NameError on ``idx``).
        total_loss = total_loss / max(num_batches, 1)
        acc = correct / max(len(node.test_data.dataset), 1) * 100
        self.val_loss[str(node.num)].append(total_loss)
        self.val_acc[str(node.num)].append(acc)

        if acc > self.acc_best[node.num]:
            self.get_a_better[node.num] = 1
            self.acc_best[node.num] = acc
            os.makedirs('./saves/model', exist_ok=True)
            # NOTE(review): the file name uses args.local_model even for the
            # global node (num 0), whose model is built from args.global_model.
            torch.save(node.model.state_dict(),
                       './saves/model/Node{:d}_{:s}.pt'.format(node.num, node.args.local_model))

    def printer(self, node):
        """Announce a new best accuracy for ``node`` once, then clear its flag."""
        if self.get_a_better[node.num] == 1:
            print('Node{:d}: A Better Accuracy: {:.2f}%! Model Saved!'.format(node.num, self.acc_best[node.num]))
            print('-------------------------')
            self.get_a_better[node.num] = 0

    def finish(self):
        """Save all validation histories to ``./saves/record/`` and print each node's best accuracy."""
        os.makedirs('./saves/record', exist_ok=True)
        torch.save([self.val_loss, self.val_acc],
                   './saves/record/loss_acc_{:s}_{:s}.pt'.format(self.args.algorithm, self.args.notes))
        print('Finished!\n')
        for i in range(self.args.node_num + 1):
            print('Node{}: Best Accuracy = {:.2f}%'.format(i, self.acc_best[i]))


def Catfish(Node_List, args):
    """Optionally replace the first node's local model with an ``args.catfish`` model.

    Does nothing when ``args.catfish`` is None. Otherwise the new model is
    moved to ``args.device``, like every other node's model, and gets a fresh
    optimizer.
    """
    if args.catfish is None:
        return
    first = Node_List[0]
    first.model = Node.init_model(args.catfish, args.classes).to(args.device)
    first.optimizer = Node.init_optimizer(first.model, args)


def LR_scheduler(rounds, Node_List, args):
    """Divide the learning rate by 10 every ``R // 3`` rounds, before ``stop_decay``.

    Mutates ``args.lr``, copies lr/alpha/beta into each node's args, and
    writes the new rate into every parameter group of both of each node's
    optimizers. Prints the rate in effect for this round.
    """
    # At least 1: with R < 3 the period used to be 0 and ``rounds % 0`` raised.
    trigger = max(args.R // 3, 1)
    if rounds != 0 and rounds % trigger == 0 and rounds < args.stop_decay:
        args.lr *= 0.1
        for node in Node_List:
            node.args.lr = args.lr
            node.args.alpha = args.alpha
            node.args.beta = args.beta
            for optimizer in (node.optimizer, node.meme_optimizer):
                for group in optimizer.param_groups:
                    group['lr'] = args.lr
    print('Learning rate={:.4f}'.format(args.lr))


def Summary(args):
    """Print the main experiment settings."""
    print("Summary:\n")
    print("algorithm:{}\n".format(args.algorithm))
    print("dataset:{}\tbatchsize:{}\n".format(args.dataset, args.batchsize))
    print("node_num:{},\tsplit:{}\n".format(args.node_num, args.split))
    print("global epochs:{},\tlocal epochs:{},\n".format(args.R, args.E))
    print("global_model:{},\tlocal model:{},\n".format(args.global_model, args.local_model))

# ------------------------------------------------------------------------------
# /Model.py
# ------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv/BN layers with an identity or 1x1-projection shortcut."""

    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        self.shortcut = nn.Sequential()
        # Project the input when the stride or channel count changes its shape.
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet for 3-channel inputs.

    The final 4x4 average pool feeding a 512-wide linear layer is sized for
    32x32 images.
    """

    def __init__(self, residual_block, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(residual_block, 64, 2, stride=1)
        self.layer2 = self.make_layer(residual_block, 128, 2, stride=2)
        self.layer3 = self.make_layer(residual_block, 256, 2, stride=2)
        self.layer4 = self.make_layer(residual_block, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.fc(out)


def ResNet18(num_classes=10):
    """ResNet-18 (two residual blocks per stage) with ``num_classes`` outputs."""
    return ResNet(ResidualBlock, num_classes=num_classes)


class LeNet5(nn.Module):
    """LeNet-5 for 3x32x32 inputs (16x5x5 features before the classifier)."""

    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten(1) keeps the batch dimension, so a wrongly sized input fails
        # at fc1 instead of being silently re-batched by view(-1, 400).
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class MLP(nn.Module):
    """Two-hidden-layer perceptron (200 units each) over flattened 3x32x32 inputs."""

    def __init__(self, num_classes=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, num_classes)

    def forward(self, x):
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class CNN(nn.Module):
    """Three-conv network for 3x32x32 inputs (64x4x4 features before the classifier)."""

    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 64, 3)
        self.fc1 = nn.Linear(64 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# ------------------------------------------------------------------------------
# /Trainer.py
# ------------------------------------------------------------------------------
from tqdm import tqdm
import torch.nn as nn

KL_Loss = nn.KLDivLoss(reduction='batchmean')
Softmax = nn.Softmax(dim=1)
LogSoftmax = nn.LogSoftmax(dim=1)
CE_Loss = nn.CrossEntropyLoss()

# All trainers accumulate ``loss.item()`` rather than the loss tensor: summing
# the tensors chained every batch's autograd history together for the epoch.
# The displayed accuracy is a running count over the whole dataset, so it
# climbs during the epoch and equals the epoch accuracy at the end.


def train_normal(node):
    """Train ``node.model`` alone with cross-entropy for one pass over ``node.train_data``."""
    node.model.to(node.device).train()
    train_loader = node.train_data
    total_loss = 0.0
    avg_loss = 0.0
    correct = 0
    acc = 0.0
    description = "Training (the {:d}-batch): tra_Loss = {:.4f} tra_Accuracy = {:.2f}%"
    with tqdm(train_loader) as progress:
        for idx, (data, target) in enumerate(progress):
            node.optimizer.zero_grad()
            progress.set_description(description.format(idx + 1, avg_loss, acc))
            data, target = data.to(node.device), target.to(node.device)
            output = node.model(data)
            loss = CE_Loss(output, target)
            loss.backward()
            node.optimizer.step()
            total_loss += loss.item()
            avg_loss = total_loss / (idx + 1)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            acc = correct / len(train_loader.dataset) * 100


def train_avg(node):
    """Train ``node.meme`` with cross-entropy for one pass (the 'fed_avg' local step).

    Afterwards ``node.model`` refers to the same module as ``node.meme``, so
    validation evaluates the freshly trained meme.
    """
    node.meme.to(node.device).train()
    train_loader = node.train_data
    total_loss = 0.0
    avg_loss = 0.0
    correct = 0
    acc = 0.0
    description = "Node{:d}: loss={:.4f} acc={:.2f}%"
    with tqdm(train_loader) as progress:
        for idx, (data, target) in enumerate(progress):
            node.meme_optimizer.zero_grad()
            progress.set_description(description.format(node.num, avg_loss, acc))
            data, target = data.to(node.device), target.to(node.device)
            output = node.meme(data)
            loss = CE_Loss(output, target)
            loss.backward()
            node.meme_optimizer.step()
            total_loss += loss.item()
            avg_loss = total_loss / (idx + 1)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            acc = correct / len(train_loader.dataset) * 100
    node.model = node.meme


def train_mutual(node):
    """One pass of mutual learning between ``node.model`` and ``node.meme``.

    The local model minimises ``alpha * CE + (1 - alpha) * KL(meme || model)``
    and the meme minimises ``beta * CE + (1 - beta) * KL(model || meme)``. Each
    KL target is the other model's detached softmax, so the two backward
    passes are independent.
    """
    node.model.to(node.device).train()
    node.meme.to(node.device).train()
    train_loader = node.train_data
    total_local_loss = 0.0
    avg_local_loss = 0.0
    correct_local = 0
    acc_local = 0.0
    total_meme_loss = 0.0
    avg_meme_loss = 0.0
    correct_meme = 0
    acc_meme = 0.0
    description = 'Node{:d}: loss_model={:.4f} acc_model={:.2f}% loss_meme={:.4f} acc_meme={:.2f}%'
    with tqdm(train_loader) as progress:
        for idx, (data, target) in enumerate(progress):
            node.optimizer.zero_grad()
            node.meme_optimizer.zero_grad()
            progress.set_description(description.format(node.num, avg_local_loss, acc_local, avg_meme_loss, acc_meme))
            data, target = data.to(node.device), target.to(node.device)
            output_local = node.model(data)
            output_meme = node.meme(data)
            ce_local = CE_Loss(output_local, target)
            kl_local = KL_Loss(LogSoftmax(output_local), Softmax(output_meme.detach()))
            ce_meme = CE_Loss(output_meme, target)
            kl_meme = KL_Loss(LogSoftmax(output_meme), Softmax(output_local.detach()))
            loss_local = node.args.alpha * ce_local + (1 - node.args.alpha) * kl_local
            loss_meme = node.args.beta * ce_meme + (1 - node.args.beta) * kl_meme
            loss_local.backward()
            loss_meme.backward()
            node.optimizer.step()
            node.meme_optimizer.step()
            total_local_loss += loss_local.item()
            avg_local_loss = total_local_loss / (idx + 1)
            pred_local = output_local.argmax(dim=1)
            correct_local += pred_local.eq(target.view_as(pred_local)).sum().item()
            acc_local = correct_local / len(train_loader.dataset) * 100
            total_meme_loss += loss_meme.item()
            avg_meme_loss = total_meme_loss / (idx + 1)
            pred_meme = output_meme.argmax(dim=1)
            correct_meme += pred_meme.eq(target.view_as(pred_meme)).sum().item()
            acc_meme = correct_meme / len(train_loader.dataset) * 100


class Trainer(object):
    """Callable that runs one local training pass with the algorithm chosen in ``args``.

    Raises:
        ValueError: from the constructor when ``args.algorithm`` is unknown.
            The old constructor left ``self.train`` unset, so the error only
            appeared on the first call.
    """

    _ALGORITHMS = {
        'fed_mutual': train_mutual,
        'fed_avg': train_avg,
        'normal': train_normal,
    }

    def __init__(self, args):
        try:
            self.train = self._ALGORITHMS[args.algorithm]
        except KeyError:
            raise ValueError('Unknown algorithm {!r}; expected one of {}'.format(
                args.algorithm, sorted(self._ALGORITHMS))) from None

    def __call__(self, node):
        self.train(node)

# ------------------------------------------------------------------------------
# /Data.py
# ------------------------------------------------------------------------------
import torch
import os.path
from torchvision.datasets import utils, MNIST, CIFAR10
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from PIL import Image


class FEMNIST(MNIST):
    """
    This dataset is derived from the Leaf repository
    (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
    dataset, grouping examples by writer. Details about Leaf were published in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.
    """
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        # Deliberately skips MNIST.__init__ and calls its parent's initializer.
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        data_file = self.training_file if self.train else self.test_file

        # NOTE(review): torch.load unpickles the file; only load trusted archives.
        self.data, self.targets, self.users_index = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """Return ``(image, target)``; the image is a float ('F' mode) PIL image before transforms."""
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download the FEMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        # os.makedirs replaces torchvision's deprecated utils.makedir_exist_ok.
        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            utils.download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')
        shutil.move(os.path.join(self.raw_folder, self.training_file), self.processed_folder)
        shutil.move(os.path.join(self.raw_folder, self.test_file), self.processed_folder)


def Dataset(args):
    """Return ``(trainset, testset)`` for ``args.dataset``, read from ``~/data``.

    Supported names are 'cifar10', 'femnist' and 'mnist'.

    Raises:
        ValueError: for any other dataset name. (Previously ``(None, None)``
            was returned and the caller failed on ``len(None)``.)
    """
    if args.dataset == 'cifar10':
        tra_trans = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        val_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = CIFAR10(root="~/data", train=True, download=False, transform=tra_trans)
        testset = CIFAR10(root="~/data", train=False, download=False, transform=val_trans)
        return trainset, testset

    # Was ``args.dataset == 'femnist' or 'mnist'``, which is always true.
    if args.dataset in ('femnist', 'mnist'):
        # Train and test use the same deterministic pipeline: 2-pixel edge
        # padding, then single-channel normalisation.
        trans = transforms.Compose([
            transforms.Pad(2, padding_mode='edge'),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset_cls = FEMNIST if args.dataset == 'femnist' else MNIST
        trainset = dataset_cls(root='~/data', train=True, transform=trans)
        testset = dataset_cls(root='~/data', train=False, transform=trans)
        return trainset, testset

    raise ValueError("Unsupported dataset {!r}; expected 'cifar10', 'femnist' or 'mnist'".format(args.dataset))


class Data(object):
    """Build the loaders for one experiment.

    The training set is cut into ``args.split`` equal, contiguous index ranges.
    Samples beyond ``split * (len // split)`` are dropped, and node ``k`` gets
    range ``k``.

    Attributes:
        train_loader: list of ``args.node_num`` shuffled loaders, one per node.
        test_loader: loader over the full test set, used by every local node.
        test_all: loader over the full test set, used by the global node.
    """

    def __init__(self, args):
        self.args = args
        self.trainset, self.testset = None, None
        trainset, testset = Dataset(args)
        if args.node_num > args.split:
            raise ValueError('node_num ({}) must not exceed split ({})'.format(args.node_num, args.split))
        shard = len(trainset) // args.split
        # NOTE: sorting indices by trainset.targets before slicing would give a
        # class-partitioned (non-IID) split instead.
        splited_trainset = [Subset(trainset, range(i * shard, (i + 1) * shard))
                            for i in range(args.node_num)]
        self.test_all = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=4)
        self.train_loader = [DataLoader(subset, batch_size=args.batchsize, shuffle=True, num_workers=4)
                             for subset in splited_trainset]
        # Every local node evaluates on the full test set. A per-node test split
        # used to be built here and then immediately overwritten.
        self.test_loader = DataLoader(testset, batch_size=args.batchsize, shuffle=False, num_workers=4)