├── dataset
├── mnist
└── cifar10
├── cache
└── cacahe file
├── README.md
├── .idea
├── misc.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── .gitignore
├── modules.xml
└── FedSGD.iml
├── model
├── plot.py
├── lenet.py
├── client.py
├── server.py
└── data.py
├── main.py
└── .gitignore
/dataset/mnist:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cache/cacahe file:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/cifar10:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedSGD
2 | Federated learning via stochastic gradient descent
3 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/FedSGD.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/model/plot.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from matplotlib import pyplot as plt
4 |
5 |
def plot():
    """Plot the cached per-epoch test accuracy of the global model.

    Reads './cache/accuracy.pkl' — the same path ``server.aggregate`` writes —
    so it resolves correctly when invoked from ``main.py`` at the project
    root. (The original read '../cache/accuracy.pkl', which only worked when
    the script was launched from inside model/.)
    """
    accuracy = torch.load('./cache/accuracy.pkl')
    plt.plot([e for e in range(1, len(accuracy) + 1)], accuracy, label='FedSGD')

    plt.title("Test Accuracy")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")

    # Accuracy is a fraction in [0, 1]; x axis spans the recorded epochs.
    plt.ylim(0, 1)
    plt.xlim(1, len(accuracy))
    plt.legend(loc=4)  # loc=4 is the lower-right corner

    plt.show()
19 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from model.data import loader
4 | from model.server import server
5 | from model.client import client
6 | from model.plot import plot
7 |
8 |
def federated_learning():
    """Run one FedSGD experiment: build data, server, and clients, then train."""
    # Hyper-parameters.
    n_client = 10
    n_epoch = 1000
    batch_size = 64

    # Dataset wrapper (pass 'mnist' instead of 'cifar10' to switch datasets).
    print('Initialize Dataset...')
    data_loader = loader('cifar10', batch_size=batch_size)

    # The server evaluates on the full test set: no classes are withheld.
    print('Initialize Server...')
    s = server(size=n_client, data_loader=data_loader.get_loader([]))

    # Each client is deprived of 4 randomly chosen classes, producing
    # non-IID local data shards.
    print('Initialize Client...')
    clients = []
    for i in range(n_client):
        withheld = random.sample(range(0, 10), 4)
        clients.append(client(rank=i, data_loader=data_loader.get_loader(withheld)))

    # Main loop: every client computes local gradients, then the server
    # aggregates them into a global update and evaluates.
    for e in range(n_epoch):
        print('\n================== Epoch {:>3} =================='.format(e + 1))
        for c in clients:
            c.run()
        s.aggregate()

    # Finally, plot the recorded accuracy curve.
    plot()
41 |
42 |
# Entry point: run the FedSGD experiment when executed as a script.
if __name__ == '__main__':
    federated_learning()
45 |
--------------------------------------------------------------------------------
/model/lenet.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from torch import nn
4 |
5 |
class LeNet5(nn.Module, ABC):
    """Classic LeNet-5 for 32x32 inputs: two conv+pool stages, 3-layer MLP head."""

    def __init__(self, in_dim, n_class):
        # Initialize the nn.Module machinery before registering submodules.
        super(LeNet5, self).__init__()

        # Stage 1: in_dim -> 6 feature maps, 5x5 kernels, stride 1, no padding,
        # followed by ReLU and a 2x2 max-pool (stride 2).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, 6, 5, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Stage 2: 6 -> 16 feature maps, same kernel/pool configuration.
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Classifier head: for 32x32 input the conv stack yields 16x5x5 = 400
        # flattened features, mapped down to n_class logits.
        self.fc = nn.Sequential(
            nn.Linear(400, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class),
        )

    def forward(self, x):
        """Return raw class logits for a batch of (in_dim, 32, 32) images."""
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # flatten per sample
        return self.fc(flat)
34 |
35 |
def lenet5(in_dim=3, n_class=10):
    """Return a LeNet-5 instance.

    in_dim: number of input channels (default 3 — CIFAR-10 RGB, as before).
    n_class: number of output classes (default 10).

    The defaults preserve the original hard-coded LeNet5(3, 10) behavior
    while letting callers build e.g. a 1-channel MNIST variant.
    """
    return LeNet5(in_dim, n_class)
40 |
--------------------------------------------------------------------------------
/model/client.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 | from model.lenet import lenet5
7 |
8 |
class client(object):
    """FedSGD client: computes gradients of the global model on its local shard.

    Each ``run`` loads the latest global weights from the shared cache,
    accumulates gradients over one full pass of the local training data, and
    writes them back to the cache for the server to aggregate.
    """

    def __init__(self, rank, data_loader):
        """
        rank: integer id of this client; also offsets the RNG seed and names
              the cached gradient file.
        data_loader: (train_loader, test_loader) pair from loader.get_loader.
        """
        # Deterministic, per-client seed so runs are reproducible.
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # rank
        self.rank = rank

        # data loaders
        self.train_loader = data_loader[0]
        self.test_loader = data_loader[1]

    @staticmethod
    def __load_global_model():
        """Rebuild a LeNet-5 from the latest weights the server published."""
        global_model_state = torch.load('./cache/global_model_state.pkl')
        model = lenet5().cuda()
        model.load_state_dict(global_model_state)
        return model

    def __train(self, model):
        """One full pass over the local data; returns the summed gradients.

        Gradients are deliberately NOT zeroed between batches: summing
        ``loss.backward()`` over every batch yields the full-dataset gradient
        that FedSGD sends to the server.
        """
        train_loss = 0.0
        train_correct = 0
        criterion = nn.CrossEntropyLoss()  # hoisted: no need to rebuild per batch
        model.train()
        for data, target in self.train_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            loss = criterion(output, target)
            # .item() detaches the scalar; the original accumulated the raw
            # loss tensor, which kept every batch's autograd graph alive and
            # leaked GPU memory over the epoch.
            train_loss += loss.item()
            loss.backward()
            pred = output.argmax(dim=1, keepdim=True)
            train_correct += pred.eq(target.view_as(pred)).sum().item()

        grads = {'n_samples': len(self.train_loader.dataset), 'named_grads': {}}
        for name, param in model.named_parameters():
            grads['named_grads'][name] = param.grad

        print('[Rank {:>2}] Loss: {:>4.6f}, Accuracy: {:>.4f}'.format(
            self.rank,
            train_loss,
            train_correct / len(self.train_loader.dataset)
        ))
        return grads

    def run(self):
        """Compute local gradients against the current global model and cache them."""
        model = self.__load_global_model()
        grads = self.__train(model=model)
        torch.save(grads, './cache/grads_{}.pkl'.format(self.rank))
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/model/server.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import optim
4 | from torch.autograd import Variable
5 |
6 | from model.lenet import lenet5
7 |
8 |
class server(object):
    """FedSGD server: averages the clients' gradients and steps the global model."""

    def __init__(self, size, data_loader):
        """
        size: number of clients (reads grads_0.pkl .. grads_{size-1}.pkl).
        data_loader: (train_loader, test_loader) pair; only the test loader is used.
        """
        self.size = size
        self.test_loader = data_loader[1]
        self.path = './cache/global_model_state.pkl'
        self.model = self.__init_server()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.accuracy = []  # per-epoch global-model test accuracy history

    def __init_server(self):
        """Create the initial global model and publish its weights to the cache."""
        model = lenet5().cuda()
        torch.save(model.state_dict(), self.path)
        return model

    def __load_grads(self):
        """Load every client's cached gradient dict, in rank order."""
        grads_info = []
        for s in range(self.size):
            grads_info.append(torch.load('./cache/grads_{}.pkl'.format(s)))
        return grads_info

    @staticmethod
    def __average_grads(grads_info):
        """Return the sample-count-weighted average of the clients' gradients.

        Each entry of grads_info is {'n_samples': int, 'named_grads': {name: tensor}}.

        Fix: the original seeded the running sum with the first client's RAW
        (unweighted) gradient and then did ``total_grads[k] += v * n_samples``,
        which (a) double-counted the first client's contribution and (b) the
        in-place ``+=`` on the aliased tensor corrupted that client's loaded
        gradients. Here every contribution is weighted exactly once and the
        accumulation never mutates an input tensor.
        """
        total_grads = {}
        n_total_samples = 0
        for info in grads_info:
            n_samples = info['n_samples']
            for k, v in info['named_grads'].items():
                if k not in total_grads:
                    total_grads[k] = v * n_samples
                else:
                    total_grads[k] = total_grads[k] + v * n_samples
            n_total_samples += n_samples
        gradients = {}
        for k, v in total_grads.items():
            gradients[k] = torch.div(v, n_total_samples)
        return gradients

    def __step(self, gradients):
        """Install the averaged gradients on the model and take one Adam step."""
        self.model.train()
        self.optimizer.zero_grad()
        for k, v in self.model.named_parameters():
            v.grad = gradients[k]
        self.optimizer.step()

    def __test(self):
        """Return the global model's accuracy on the held-out test set."""
        test_correct = 0
        self.model.eval()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.cuda(), target.cuda()
                output = self.model(data)
                pred = output.argmax(dim=1, keepdim=True)
                test_correct += pred.eq(target.view_as(pred)).sum().item()
        return test_correct / len(self.test_loader.dataset)

    def aggregate(self):
        """One server round: average grads, update + publish the model, evaluate."""
        grads_info = self.__load_grads()
        gradients = self.__average_grads(grads_info)

        self.__step(gradients)
        torch.save(self.model.state_dict(), './cache/global_model_state.pkl')

        test_accuracy = self.__test()
        self.accuracy.append(test_accuracy)
        torch.save(self.accuracy, './cache/accuracy.pkl')
        print('\n[Global Model] Test Accuracy: {:.2f}%\n'.format(test_accuracy * 100.))
74 |
--------------------------------------------------------------------------------
/model/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.utils.data import DataLoader
4 |
5 | from torchvision import datasets
6 | from torchvision import transforms
7 |
8 |
class loader(object):
    """Wraps torchvision MNIST/CIFAR-10 and hands out label-skewed client loaders."""

    def __init__(self, cmd='cifar10', batch_size=64):
        """
        cmd: 'cifar10' selects CIFAR-10; any other value selects MNIST.
        batch_size: batch size used by every loader this object produces.
        """
        self.cmd = cmd
        self.batch_size = batch_size
        self.__load_dataset()
        self.__get_index()

    def __load_dataset(self):
        """Download (if needed) and normalize both datasets into ./dataset/."""
        # mnist
        self.train_mnist = datasets.MNIST('./dataset/',
                                          train=True,
                                          download=True,
                                          transform=transforms.Compose([
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.1307,), (0.3081,))
                                          ]))

        self.test_mnist = datasets.MNIST('./dataset/',
                                         train=False,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.1307,), (0.3081,))
                                         ]))

        # cifar10
        self.train_cifar10 = datasets.CIFAR10('./dataset/',
                                              train=True,
                                              download=True,
                                              transform=transforms.Compose([
                                                  transforms.ToTensor(),
                                                  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                              ]))
        self.test_cifar10 = datasets.CIFAR10('./dataset/',
                                             train=False,
                                             download=True,
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                             ]))

    def __get_index(self):
        """Select the active dataset per cmd and bucket train indices by label."""
        if self.cmd == 'cifar10':
            self.train_dataset = self.train_cifar10
            self.test_dataset = self.test_cifar10
        else:
            self.train_dataset = self.train_mnist
            self.test_dataset = self.test_mnist

        # Read labels via the dataset's .targets attribute (a list for
        # CIFAR-10, a tensor for MNIST) instead of iterating the dataset,
        # which would decode and transform every image just to get its label.
        self.indices = [[] for _ in range(10)]
        for index, label in enumerate(self.train_dataset.targets):
            self.indices[int(label)].append(index)

    def get_loader(self, rank):
        """Build (train_loader, test_loader) with some classes withheld.

        rank: iterable of class labels (0-9) to EXCLUDE from this client's
              training data; pass [] for the full training set. (The name is
              kept for interface compatibility; it is a label list, not a
              client id.)
        """
        dataset_indices = []
        difference = sorted(set(range(10)).difference(set(rank)))
        for i in difference:
            dataset_indices.extend(self.indices[i])

        # Subset of the dataset already selected in __get_index — the
        # original rebuilt a CIFAR-10 subset unconditionally and then
        # replaced it for MNIST, doing the work twice.
        dataset = torch.utils.data.Subset(self.train_dataset, dataset_indices)

        train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True)

        return train_loader, test_loader
76 |
--------------------------------------------------------------------------------