├── .gitignore ├── README.md ├── main.py ├── model ├── __init__.py └── resnet.py ├── src ├── __init__.py ├── client.py ├── cloud.py └── edge.py └── util ├── __init__.py ├── data.py └── fedavg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | ### macOS template 134 | # General 135 | .DS_Store 136 | .AppleDouble 137 | .LSOverride 138 | 139 | # Icon must end with two \r 140 | Icon 141 | 142 | # Thumbnails 143 | ._* 144 | 145 | # Files that might appear in the root of a volume 146 | .DocumentRevisions-V100 147 | .fseventsd 148 | .Spotlight-V100 149 | .TemporaryItems 150 | .Trashes 151 | .VolumeIcon.icns 152 | .com.apple.timemachine.donotpresent 153 | 154 | # Directories potentially created on remote AFP share 155 | .AppleDB 156 | .AppleDesktop 157 | Network Trash Folder 158 | Temporary Items 159 | .apdisk 160 | 161 | .idea/ 162 | # local dataset 163 | data 164 | server.py 165 | 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning 2 | 3 | We simulate the Cloud-Edge-Client FL framework. 4 | 5 | ## Requirements 6 | 7 | python==3.7 8 | 9 | pytorch==1.4.0 10 | 11 | ## Run 12 | 13 | Federated learning with ResNet is produced by: 14 | > python [main.py](main.py) 15 | 16 | See the arguments in [main.py](main.py). 17 | 18 | For example: 19 | > python main.py --edge_num=10 --client_num=100 --ratio1=0.5 --ratio2=0.3 --optim=adam --lr=0.001 --bs=128 20 | 21 | Note: We only consider IID setting on CIFAR10. 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import argparse 4 | import numpy as np 5 | import random 6 | import os 7 | from src.cloud import Cloud 8 | from src.edge import Edges 9 | from src.client import Clients 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def setup_seed(seed): 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = True 20 | 21 | 22 | def get_parse(): 23 | # parser for hyperparameter 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--seed', type=int, default=0) 26 | parser.add_argument('--fixed_seed', type=bool, default=False) 27 | parser.add_argument('--edge_num', type=int, default=10) 28 | parser.add_argument('--client_num', type=int, default=100) 29 | parser.add_argument('--ratio1', type=float, default=0.2, help='The ratio of chosen edges') 30 | parser.add_argument('--ratio2', type=float, default=0.05, help='The ratio of chosen client per edge') 31 | parser.add_argument('--optim', default='adam', type=str, help='optimizer') 32 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum term') 33 | parser.add_argument('--beta1', default=0.9, type=float, help='Adam coefficients beta_1') 34 | parser.add_argument('--beta2', default=0.999, type=float, help='Adam coefficients beta_2') 35 | parser.add_argument('--weight_decay', default=5e-4, type=float, 36 | help='weight decay for optimizers') 37 | parser.add_argument('--lr', type=float, default=0.001) 38 | parser.add_argument('--bs', type=int, default=128) 39 | parser.add_argument('--epochs', type=int, default=500) 40 | parser.add_argument('--client_epochs', type=int, default=1) 41 | parser.add_argument('--edge_epochs', type=int, default=1) 42 | parser.add_argument('--num_classes', type=int, default=10, help='cifar10') 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def create_optimizer(args, model_params): 48 | if args.optim == 'sgd': 49 | return optim.SGD(model_params, args.lr, momentum=args.momentum, 50 | weight_decay=args.weight_decay) 51 | elif args.optim == 'adagrad': 52 | return optim.Adagrad(model_params, args.lr, weight_decay=args.weight_decay) 53 | elif args.optim == 'adam': 54 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2), 55 | weight_decay=args.weight_decay) 56 | elif args.optim == 'amsgrad': 57 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2), 58 | weight_decay=args.weight_decay, amsgrad=True) 59 | else: 60 | raise ValueError('unknown optimizer') 61 | 62 | 63 | def plot(curve, name): 64 | if not os.path.isdir('figure'): 65 | os.mkdir('figure') 66 | plt.figure() 67 | plt.xlabel('Epoch') 68 | plt.ylabel('Training loss') 69 | plt.plot(curve) 70 | plt.savefig('figure/{}.png'.format(name)) 71 | 72 | 73 | def reduce_dim(vector): 74 | # TODO 75 | pass 76 | 77 | 78 | if __name__ == '__main__': 79 | 80 | args = get_parse() 81 | if args.fixed_seed: 82 | setup_seed(args.seed) 83 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 84 | name = 'fed-edge{}-client{}_{}_C{}'.format(args.edge_num, args.client_num, args.ratio1, args.ratio2) 85 | 86 | client = Clients(args.num_classes, args.edge_num, args.client_num, args.bs, args.client_epochs, device) 87 | edge = Edges(client, args.num_classes, args.edge_num, args.client_num, args.bs, args.edge_epochs, device) 88 | cloud = Cloud(edge, args.num_classes, args.edge_num, args.client_num, args.bs, device) 89 | 90 | optimizer = create_optimizer(args, cloud.model.parameters()) 91 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 300], gamma=0.1) 92 | 93 | train_losses, train_accuracies = [], [] 94 | test_losses, test_accuracies = [], [] 95 | 96 | for epoch in range(0, args.epochs): 97 | print('\nEpoch: %d' % epoch) 98 | train_loss, train_acc = cloud.train_epoch(optimizer, args.ratio1, 99 | args.ratio1, device) 100 | vector = cloud.get_cloud_and_edges # [] 101 | reduce_dim(vector) 102 | print("Training Acc: {:.4f}, Loss: {:.4f}".format(train_acc, train_loss)) 103 | test_acc, test_loss = cloud.run_test(device=device) 104 | print("Testing Acc: {:.4f}, Loss: {:.4f}".format(test_acc, test_loss)) 105 | scheduler.step() 106 | 107 | # save loss and acc 108 | train_losses.append(train_loss) 109 | train_accuracies.append(train_acc) 110 | test_losses.append(test_loss) 111 | test_accuracies.append(test_acc) 112 | 113 | if not os.path.isdir('curve'): 114 | os.mkdir('curve') 115 | 116 | torch.save({'train_loss': train_loss, 'train_accuracy': train_acc, 117 | 'test_loss': test_loss, 'test_accuracy': test_acc}, os.path.join('curve', name)) 118 | 119 | plot(train_loss, name) 120 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. Deep Residual Learning for Image Recognition: 3 | https://arxiv.org/abs/1512.03385 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, 16 | bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion * planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, 25 | bias=False), 26 | nn.BatchNorm2d(self.expansion * planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion * planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, 53 | bias=False), 54 | nn.BatchNorm2d(self.expansion * planes) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | out = self.bn3(self.conv3(out)) 61 | out += self.shortcut(x) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class ResNet(nn.Module): 67 | def __init__(self, block, num_blocks, num_classes=10): 68 | super(ResNet, self).__init__() 69 | self.in_planes = 64 70 | 71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.bn1 = nn.BatchNorm2d(64) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1] * (num_blocks - 1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def ResNet18(num_classes): 100 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 101 | 102 | 103 | def ResNet34(num_classes): 104 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 105 | 106 | 107 | def ResNet50(num_classes): 108 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 109 | 110 | 111 | def ResNet101(num_classes): 112 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 113 | 114 | 115 | def ResNet152(num_classes): 116 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 117 | 118 | 119 | def test(): 120 | net = ResNet18() 121 | y = net(torch.randn(1, 3, 32, 32)) 122 | print(y.size()) 123 | 124 | # test() 125 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drzhang3/Fed/8c0de0e72dc15866449c3c59b2632c6784de84d1/src/__init__.py -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | from model import ResNet34 5 | from util.data import DataSet 6 | 7 | 8 | class Clients: 9 | def __init__(self, num_classes, edge_num, client_num, batch_size, client_epochs, device): 10 | self.model = ResNet34(num_classes).to(device) 11 | self.bs = batch_size 12 | self.client_epochs = client_epochs 13 | self.clients_num = client_num 14 | self.edge_num = edge_num 15 | self.client_num_per_edge = client_num // edge_num 16 | self.dataset = DataSet(edge_num, self.client_num_per_edge, batch_size) 17 | self.criterion = torch.nn.CrossEntropyLoss() 18 | 19 | def run_test(self, device): 20 | self.model.eval() 21 | test_loss = 0 22 | correct = 0 23 | total = 0 24 | with torch.no_grad(): 25 | for batch_idx, (inputs, targets) in enumerate(self.dataset.test): 26 | inputs, targets = inputs.to(device), targets.to(device) 27 | if inputs.size(-1) == 28: 28 | inputs = inputs.view(inputs.size(0), -1) 29 | outputs = self.model(inputs) 30 | 31 | _, predicted = outputs.max(1) 32 | total += targets.size(0) 33 | correct += predicted.eq(targets).sum().item() 34 | 35 | loss = self.criterion(outputs, targets) 36 | test_loss += loss.item() 37 | 38 | accuracy = 100. * correct / total 39 | return test_loss, accuracy 40 | 41 | def train_epoch(self, edge_id, client_id, optimizer, device): 42 | self.model.train() 43 | client_train_data_loader = self.dataset.train[edge_id][client_id] 44 | train_loss = 0 45 | correct = 0 46 | total = 0 47 | for batch_idx, (inputs, targets) in enumerate(client_train_data_loader): 48 | inputs, targets = inputs.to(device), targets.to(device) 49 | if inputs.size(-1) == 28: 50 | inputs = inputs.view(inputs.size(0), -1) 51 | optimizer.zero_grad() 52 | outputs = self.model(inputs) 53 | loss = self.criterion(outputs, targets) 54 | loss.backward() 55 | optimizer.step() 56 | train_loss += loss.item() 57 | _, predicted = outputs.max(1) 58 | total += targets.size(0) 59 | correct += predicted.eq(targets).sum().item() 60 | accuracy = 100. * correct / total 61 | return accuracy, train_loss, self.model.state_dict() 62 | 63 | def train_epochs(self, edge_id, client_id, optimizer, device): 64 | for epoch in range(self.client_epochs): 65 | train_acc, train_loss, current_client_vars = self.train_epoch(edge_id, client_id, optimizer, device) 66 | return train_acc, train_loss, current_client_vars 67 | 68 | def get_client_vars(self): 69 | return self.model.state_dict() 70 | 71 | def set_edge_vars(self, edge_vars): 72 | self.model.load_state_dict(edge_vars) 73 | 74 | def choose_clients(self, ratio2): 75 | choose_num = math.floor(self.client_num_per_edge * ratio2) 76 | return np.random.permutation(self.client_num_per_edge)[:choose_num] 77 | -------------------------------------------------------------------------------- /src/cloud.py: -------------------------------------------------------------------------------- 1 | from src.edge import Edges 2 | from util.fedavg import FedAvg 3 | import numpy as np 4 | import torch 5 | import os 6 | 7 | 8 | class Cloud: 9 | def __init__(self, edge, num_classes, edge_num, client_num, batch_size, device): 10 | self.edge = edge 11 | self.model = self.edge.model 12 | 13 | def train_epoch(self, optimizer, ratio1, ratio2, device): 14 | cloud_vars = self.get_cloud_vars() 15 | 16 | edge_vars_sum = [] 17 | random_edges = self.edge.choose_edges(ratio1) 18 | 19 | edge_loss, edge_acc = [], [] 20 | for edge_id in random_edges: 21 | self.edge.set_cloud_vars(cloud_vars) 22 | train_acc, train_loss, current_edge_vars = self.edge.train_epochs(edge_id=edge_id, optimizer=optimizer, ratio2=ratio2, device=device) 23 | 24 | edge_vars_sum.append(current_edge_vars) 25 | 26 | edge_loss.append(train_loss) 27 | edge_acc.append(train_acc) 28 | 29 | global_vars = FedAvg(edge_vars_sum) 30 | self._save_cloud_and_edges(global_vars, edge_vars_sum) 31 | self.set_global_vars(global_vars) 32 | 33 | return np.mean(edge_loss), np.mean(edge_acc) 34 | 35 | def get_cloud_vars(self): 36 | return self.model.state_dict() 37 | 38 | def set_global_vars(self, global_vars): 39 | return self.model.load_state_dict(global_vars) 40 | 41 | def run_test(self, device): 42 | accuracy, test_loss = self.edge.run_test(device) 43 | return accuracy, test_loss 44 | 45 | def _save_cloud_and_edges(self, cloud, edges): 46 | self.get_cloud_and_edges = [cloud] + edges 47 | 48 | 49 | -------------------------------------------------------------------------------- /src/edge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from src.client import Clients 4 | from util.fedavg import FedAvg 5 | 6 | 7 | class Edges: 8 | def __init__(self, client, num_classes, edge_num, client_num, batch_size, edge_epochs, device): 9 | self.bs = batch_size 10 | self.edge_epochs = edge_epochs 11 | self.client_num = client_num 12 | self.edge_num = edge_num 13 | self.client_num_per_edge = client_num // edge_num 14 | self.client = client 15 | self.model = self.client.model 16 | 17 | def train_epoch(self, edge_id, optimizer, ratio2, device): 18 | 19 | edge_vars = self.get_edge_vars() 20 | client_vars_sum = [] 21 | random_clients = self.client.choose_clients(ratio2) 22 | for client_id in random_clients: 23 | self.client.set_edge_vars(edge_vars) 24 | train_acc, train_loss, current_client_vars = self.client.train_epochs(edge_id, client_id, 25 | optimizer=optimizer, device=device) 26 | client_vars_sum.append(current_client_vars) 27 | 28 | cloud_vars = FedAvg(client_vars_sum) 29 | self.set_cloud_vars(cloud_vars) 30 | return train_acc, train_loss, cloud_vars 31 | 32 | def train_epochs(self, edge_id, optimizer, ratio2, device): 33 | for epoch in range(self.edge_epochs): 34 | train_acc, train_loss, edge_vars = self.train_epoch(edge_id, optimizer, ratio2, device) 35 | return train_acc, train_loss, edge_vars 36 | 37 | def set_cloud_vars(self, cloud_vars): 38 | self.model.load_state_dict(cloud_vars) 39 | 40 | def get_edge_vars(self): 41 | return self.model.state_dict() 42 | 43 | def choose_edges(self, ratio1): 44 | choose_num = math.floor(self.edge_num * ratio1) 45 | return np.random.permutation(self.edge_num)[:choose_num] 46 | 47 | def run_test(self, device): 48 | accuracy, test_loss = self.client.run_test(device) 49 | return test_loss, accuracy 50 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drzhang3/Fed/8c0de0e72dc15866449c3c59b2632c6784de84d1/util/__init__.py -------------------------------------------------------------------------------- /util/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def build_dataset(): 7 | print('==> Preparing data..') 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 13 | ]) 14 | 15 | transform_test = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 18 | ]) 19 | 20 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, 21 | transform=transform_train) 22 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, 23 | transform=transform_test) 24 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 25 | 26 | return trainset, testset 27 | 28 | 29 | class DataSet: 30 | def __init__(self, edge_num, client_num_per_edge, batch_size): 31 | trainset, testset = build_dataset() 32 | edge_data_list = torch.utils.data.random_split(trainset, [len(trainset) // edge_num for _ in range(edge_num)]) 33 | client_data_list = [torch.utils.data.random_split(edge_data, [len(edge_data) // client_num_per_edge for _ in 34 | range(client_num_per_edge)]) for edge_data in 35 | edge_data_list] 36 | self.train = [[torch.utils.data.DataLoader(client_data_list[i][j], batch_size=batch_size, shuffle=True) for j in 37 | range(client_num_per_edge)] for i in range(edge_num)] 38 | self.test = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False) 39 | del client_data_list 40 | 41 | # dataset = DataSet(10, 10, 64) 42 | # print(len(dataset.train)) 43 | # print(len(dataset.train[0])) 44 | -------------------------------------------------------------------------------- /util/fedavg.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn 4 | 5 | 6 | def FedAvg(w): 7 | w_avg = copy.deepcopy(w[0]) 8 | for k in w_avg.keys(): 9 | for i in range(1, len(w)): 10 | w_avg[k] += w[i][k] 11 | w_avg[k] = torch.div(w_avg[k], len(w)) 12 | return w_avg --------------------------------------------------------------------------------