├── dataset
│   ├── mnist
│   └── cifar10
├── cache
│   └── cacahe file
├── README.md
├── .idea
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── .gitignore
│   ├── modules.xml
│   └── FedSGD.iml
├── model
│   ├── plot.py
│   ├── lenet.py
│   ├── client.py
│   ├── server.py
│   └── data.py
├── main.py
└── .gitignore

--------------------------------------------------------------------------------
/dataset/mnist:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/cache/cacahe file:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset/cifar10:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FedSGD
Federated learning via stochastic gradient descent

--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/../../../../../../C:\Users\Lei\Desktop\FedSGD\.idea/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/

--------------------------------------------------------------------------------
/model/plot.py:
--------------------------------------------------------------------------------
import torch

from matplotlib import pyplot as plt


def plot():
    # accuracy history saved by server.aggregate(); path is relative to the
    # repo root, from which main.py calls this function
    accuracy = torch.load('./cache/accuracy.pkl')
    plt.plot(range(1, len(accuracy) + 1), accuracy, label='FedSGD')

    plt.title("Test Accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    plt.ylim(0, 1)
    plt.xlim(1, len(accuracy))
    plt.legend(loc='lower right')

    plt.show()
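
plot() is only invoked at the end of main.py, but it can also be run by hand once a training run has filled the cache. A minimal sketch, assuming it is launched from the repository root so that ./cache/accuracy.pkl resolves:

from model.plot import plot

plot()  # reads ./cache/accuracy.pkl written by server.aggregate() and shows the accuracy curve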
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import random

from model.data import loader
from model.server import server
from model.client import client
from model.plot import plot


def federated_learning():
    # hyper-parameters
    n_client = 10
    n_epoch = 1000
    batch_size = 64

    # dataset
    print('Initialize Dataset...')
    # data_loader = loader('mnist', batch_size=batch_size)  # note: lenet5() expects 3-channel 32x32 input
    data_loader = loader('cifar10', batch_size=batch_size)

    # initialize server; with no labels excluded, the server's split covers every class
    print('Initialize Server...')
    s = server(size=n_client, data_loader=data_loader.get_loader([]))

    # initialize clients; each client excludes 4 random labels,
    # so its shard is a non-IID 6-class subset
    print('Initialize Client...')
    clients = []
    for i in range(n_client):
        clients.append(client(rank=i, data_loader=data_loader.get_loader(
            random.sample(range(0, 10), 4)
        )))

    # federated learning: each epoch, every client computes gradients on its shard,
    # then the server averages them and takes one optimizer step
    for e in range(n_epoch):
        print('\n================== Epoch {:>3} =================='.format(e + 1))
        for c in clients:
            c.run()
        s.aggregate()

    # plot the test-accuracy curve
    plot()


if __name__ == '__main__':
    federated_learning()

--------------------------------------------------------------------------------
/model/lenet.py:
--------------------------------------------------------------------------------
from abc import ABC

from torch import nn


class LeNet5(nn.Module, ABC):
    def __init__(self, in_dim, n_class):
        # super(): inherit the attributes of the parent class nn.Module
        # and initialize them with the parent's constructor
        super(LeNet5, self).__init__()

        self.conv1 = nn.Sequential(
            # nn.Conv2d(in_dim, 6, 5, 1, 2),  # padded variant: out_dim=6, kernel_size=5, stride=1, padding=2
            nn.Conv2d(in_dim, 6, 5, 1, 0),  # out_dim=6, kernel_size=5, stride=1, padding=0
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # kernel_size=2, stride=2
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(400, 120),  # in_features=400 (16 * 5 * 5 for 32x32 inputs), out_features=120
            nn.Linear(120, 84),
            nn.Linear(84, n_class)
        )

    def forward(self, x):
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        # flatten the 16 x 5 x 5 feature maps into a 400-dim vector per sample
        out_conv = out_conv2.view(out_conv2.size(0), -1)

        out = self.fc(out_conv)
        return out


def lenet5():
    """Return a LeNet-5 for 3-channel 32x32 inputs with 10 output classes."""
    return LeNet5(3, 10)
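
The 400 in nn.Linear(400, 120) is tied to 32x32 inputs: 32 -> 28 (5x5 conv) -> 14 (pool) -> 10 (5x5 conv) -> 5 (pool), so 16 * 5 * 5 = 400 flattened features. A 28x28 MNIST image would instead need the commented-out padded conv1 (and in_dim=1) to reach the same 400. A quick shape-check sketch, not part of the repo:

import torch

from model.lenet import lenet5

model = lenet5()                # CPU is fine for a shape check
x = torch.randn(1, 3, 32, 32)   # one CIFAR-10-sized RGB image
print(model(x).shape)           # torch.Size([1, 10]), one logit per class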
--------------------------------------------------------------------------------
/model/client.py:
--------------------------------------------------------------------------------
import torch

from torch import nn

from model.lenet import lenet5


class client(object):
    def __init__(self, rank, data_loader):
        # per-client seed: reproducible but distinct across ranks
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # rank
        self.rank = rank

        # data loader
        self.train_loader = data_loader[0]
        self.test_loader = data_loader[1]

    @staticmethod
    def __load_global_model():
        global_model_state = torch.load('./cache/global_model_state.pkl')
        model = lenet5().cuda()
        model.load_state_dict(global_model_state)
        return model

    def __train(self, model):
        train_loss = 0
        train_correct = 0
        model.train()
        for data, target in self.train_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            train_loss += loss.item()  # .item() so the autograd graph is not kept alive
            # no zero_grad between batches: .grad accumulates the summed gradient over
            # the whole local shard (FedSGD sends gradients to the server, not weights)
            loss.backward()
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()

        grads = {'n_samples': len(self.train_loader.dataset), 'named_grads': {}}
        for name, param in model.named_parameters():
            grads['named_grads'][name] = param.grad

        print('[Rank {:>2}] Loss: {:>4.6f}, Accuracy: {:>.4f}'.format(
            self.rank,
            train_loss,
            train_correct / len(self.train_loader.dataset)
        ))
        return grads

    def run(self):
        model = self.__load_global_model()
        grads = self.__train(model=model)
        torch.save(grads, './cache/grads_{}.pkl'.format(self.rank))
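
A hypothetical smoke test for one client round, assuming a CUDA device is available and a server() has already been constructed (so ./cache/global_model_state.pkl exists):

from model.data import loader
from model.client import client

data_loader = loader('cifar10', batch_size=64)
# withhold labels 0-3 from this client's shard, as main.py does with a random sample
c = client(rank=0, data_loader=data_loader.get_loader([0, 1, 2, 3]))
c.run()  # one pass over the local shard; writes ./cache/grads_0.pkl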
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/model/server.py:
--------------------------------------------------------------------------------
import torch

from torch import optim

from model.lenet import lenet5


class server(object):
    def __init__(self, size, data_loader):
        self.size = size
        self.test_loader = data_loader[1]
        self.path = './cache/global_model_state.pkl'
        self.model = self.__init_server()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.accuracy = []

    def __init_server(self):
        # create the initial global model and publish it for the clients
        model = lenet5().cuda()
        torch.save(model.state_dict(), self.path)
        return model

    def __load_grads(self):
        grads_info = []
        for s in range(self.size):
            grads_info.append(torch.load('./cache/grads_{}.pkl'.format(s)))
        return grads_info

    @staticmethod
    def __average_grads(grads_info):
        # sample-weighted average: g = sum_k(n_k * g_k) / sum_k(n_k)
        total_grads = {}
        n_total_samples = 0
        for info in grads_info:
            n_samples = info['n_samples']
            for k, v in info['named_grads'].items():
                if k not in total_grads:
                    # start from zero so the first client's gradient is not counted twice
                    total_grads[k] = torch.zeros_like(v)
                total_grads[k] += v * n_samples
            n_total_samples += n_samples
        gradients = {}
        for k, v in total_grads.items():
            gradients[k] = torch.div(v, n_total_samples)
        return gradients

    def __step(self, gradients):
        # install the averaged gradients and take one Adam step on the global model
        self.model.train()
        self.optimizer.zero_grad()
        for k, v in self.model.named_parameters():
            v.grad = gradients[k]
        self.optimizer.step()

    def __test(self):
        test_correct = 0
        self.model.eval()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.cuda(), target.cuda()
                output = self.model(data)
                pred = output.argmax(dim=1, keepdim=True)
                test_correct += pred.eq(target.view_as(pred)).sum().item()
        return test_correct / len(self.test_loader.dataset)

    def aggregate(self):
        grads_info = self.__load_grads()
        gradients = self.__average_grads(grads_info)

        self.__step(gradients)
        torch.save(self.model.state_dict(), './cache/global_model_state.pkl')

        test_accuracy = self.__test()
        self.accuracy.append(test_accuracy)
        torch.save(self.accuracy, './cache/accuracy.pkl')
        print('\n[Global Model] Test Accuracy: {:.2f}%\n'.format(test_accuracy * 100.))
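
__average_grads computes the sample-weighted average g = sum_k(n_k * g_k) / sum_k(n_k), which can be sanity-checked on toy inputs: with n = 1, 3 and scalar "gradients" 0 and 4, the result is (1*0 + 3*4) / 4 = 3. A sketch (the mangled _server__average_grads name is only a way to reach the private static method from outside the class):

import torch

from model.server import server

grads_info = [
    {'n_samples': 1, 'named_grads': {'w': torch.tensor(0.)}},
    {'n_samples': 3, 'named_grads': {'w': torch.tensor(4.)}},
]
print(server._server__average_grads(grads_info))  # {'w': tensor(3.)}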
--------------------------------------------------------------------------------
/model/data.py:
--------------------------------------------------------------------------------
import torch

from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms


class loader(object):
    def __init__(self, cmd='cifar10', batch_size=64):
        self.cmd = cmd
        self.batch_size = batch_size
        self.__load_dataset()
        self.__get_index()

    def __load_dataset(self):
        # both datasets are downloaded up front; __get_index selects the active one via self.cmd

        # mnist
        self.train_mnist = datasets.MNIST('./dataset/',
                                          train=True,
                                          download=True,
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.1307,), (0.3081,))
                                          ]))
        self.test_mnist = datasets.MNIST('./dataset/',
                                         train=False,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.1307,), (0.3081,))
                                         ]))

        # cifar10
        self.train_cifar10 = datasets.CIFAR10('./dataset/',
                                              train=True,
                                              download=True,
                                              transform=transforms.Compose([
                                                  transforms.ToTensor(),
                                                  transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                       (0.2023, 0.1994, 0.2010))
                                              ]))
        self.test_cifar10 = datasets.CIFAR10('./dataset/',
                                             train=False,
                                             download=True,
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                      (0.2023, 0.1994, 0.2010))
                                             ]))

    def __get_index(self):
        if self.cmd == 'cifar10':
            self.train_dataset = self.train_cifar10
            self.test_dataset = self.test_cifar10
        else:
            self.train_dataset = self.train_mnist
            self.test_dataset = self.test_mnist

        # bucket training indices by label so whole classes can be excluded per participant
        self.indices = [[] for _ in range(10)]
        for index, (_, label) in enumerate(self.train_dataset):
            self.indices[label].append(index)

    def get_loader(self, exclude_labels):
        # exclude_labels: class labels this participant must NOT see
        # (the server passes [], so its train split covers every class)
        dataset_indices = []
        difference = list(set(range(10)).difference(set(exclude_labels)))
        for i in difference:
            dataset_indices.extend(self.indices[i])

        dataset = torch.utils.data.Subset(self.train_dataset, dataset_indices)

        train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True)

        return train_loader, test_loader
--------------------------------------------------------------------------------
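
A sketch of what the label-exclusion split produces (this downloads both datasets on first run; the counts assume CIFAR-10's standard 5,000 training images per class):

from model.data import loader

data_loader = loader('cifar10', batch_size=64)
train_loader, test_loader = data_loader.get_loader([0, 1, 2, 3])
print(len(train_loader.dataset))  # 30000: 6 remaining classes x 5000 images
print(len(test_loader.dataset))   # 10000: the test loader always covers all 10 classes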