├── AFL.py
├── README.md
├── run_websocket_client.py
├── run_websocket_server.py
└── start_websocket_servers.py
/AFL.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 | # Dependencies
4 | import math
5 | import sys
6 | import pandas as pd
7 | import asyncio
8 | import logging
9 | import time
10 |
11 | logger = logging.getLogger("run_websocket_client")
12 |
13 | import syft as sy
14 | from syft.workers.websocket_client import WebsocketClientWorker
15 | from syft.frameworks.torch.fl import utils
16 |
17 | import torch
18 | import numpy as np
19 |
20 | import run_websocket_client as rwc
21 |
22 | path = "./result"
23 |
24 |
25 | async def main():
26 |     # Create the CSV file that will hold the per-round results
27 | global old_model
28 | df = pd.DataFrame(columns=['step',
29 | 'acc1', 'acc2',
30 | 'acc3', 'accf'])
31 | t = str(time.time())
32 | df.to_csv(path + "/" + t + ".csv", index=False)
33 |     iter = 0  # total number of client updates aggregated into the global model
34 |
35 | # Hook torch
36 | hook = sy.TorchHook(torch)
37 |
38 | # Arguments
39 | args = rwc.define_and_get_arguments(args=[])
40 | use_cuda = args.cuda and torch.cuda.is_available()
41 | torch.manual_seed(args.seed)
42 | device = torch.device("cuda" if use_cuda else "cpu")
43 | print(args)
44 |
45 | # Configure logging
46 |
47 | if not len(logger.handlers):
48 | FORMAT = "%(asctime)s - %(message)s"
49 | DATE_FMT = "%H:%M:%S"
50 | formatter = logging.Formatter(FORMAT, DATE_FMT)
51 | handler = logging.StreamHandler()
52 | handler.setFormatter(formatter)
53 | logger.addHandler(handler)
54 | logger.propagate = False
55 | LOG_LEVEL = logging.DEBUG
56 | logger.setLevel(LOG_LEVEL)
57 |
58 | t0 = time.time()
59 |
60 | akwargs_websocket = {"host": "192.168.2.13", "hook": hook, "verbose": args.verbose}
61 | alice = WebsocketClientWorker(id="alice", port=8777, **akwargs_websocket)
62 |
63 | bkwargs_websocket = {"host": "192.168.2.16", "hook": hook, "verbose": args.verbose}
64 | bob = WebsocketClientWorker(id="bob", port=8778, **bkwargs_websocket)
65 |
66 | ckwargs_websocket = {"host": "192.168.2.14", "hook": hook, "verbose": args.verbose}
67 | charlie = WebsocketClientWorker(id="charlie", port=8779, **ckwargs_websocket)
68 |
69 | dkwargs_websocket = {"host": "192.168.2.11", "hook": hook, "verbose": args.verbose}
70 | #dave = WebsocketClientWorker(id="dave", port=8780, **dkwargs_websocket)
71 |
72 | ekwargs_websocket = {"host": "192.168.2.12", "hook": hook, "verbose": args.verbose}
73 | #eva = WebsocketClientWorker(id="eva", port=8781, **ekwargs_websocket)
74 |
75 | kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
76 | #frank = WebsocketClientWorker(id="frank", port=8782, **kwargs_websocket)
77 | #frank1 = WebsocketClientWorker(id="frank1", port=8792, **kwargs_websocket)
78 |
79 | testing = WebsocketClientWorker(id="testing", port=8783, **kwargs_websocket)
80 |
81 | worker_instances = [
82 | alice,
83 | bob,
84 | charlie,
85 | # dave,
86 | # eva,
87 | # frank,
88 | # frank1
89 | ]
90 |
91 | model = rwc.Net().to(device)
92 | # print(model)
93 |
94 | print("Federate_after_n_batches: " + str(args.federate_after_n_batches))
95 | print("Batch size: " + str(args.batch_size))
96 | print("Initial learning rate: " + str(args.lr))
97 |
98 | learning_rate = args.lr
99 |
100 | traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float))
101 |
102 | for curr_round in range(1, args.training_rounds + 1):
103 |
104 |         # Snapshot the current global model as old_model before this round's training
105 | Empty_model = utils.scale_model(model, 0)
106 | old_model = utils.add_model(Empty_model, traced_model)
107 |
108 | # train
109 | logger.info("Training round %s/%s", curr_round, args.training_rounds)
110 | results = await asyncio.gather(
111 | *[
112 | rwc.fit_model_on_worker(
113 | worker=worker,
114 | traced_model=traced_model,
115 | batch_size=args.batch_size,
116 | curr_round=curr_round,
117 | max_nr_batches=args.federate_after_n_batches,
118 | lr=learning_rate,
119 | )
120 | for worker in worker_instances
121 | ]
122 | )
123 |
124 |         '''
125 |         models, loss, grads, n_server
126 |
127 |         Values: V
128 |         V[worker_id] represents the value of a participant; the smaller the value,
129 |         the less need there is to include that participant in the federated model.
130 |         How is V calculated?
131 |         1) grad: the gradient norm reflects how close the worker is to its local solution
132 |         2) total number of clients: the more participants there are in federated learning,
133 |            the smaller the value of any single participant
134 |         3) accuracy on the test set
135 |         '''
136 |
137 | models = {}
138 | loss_values = {}
139 | grads = {}
140 |         n_server = 3  # number of participating workers, used in the value formula below
141 | acc = {}
142 |
143 | V = {}
144 |
145 | # test
146 |         test_models = curr_round % 5 == 1 or curr_round == args.training_rounds  # evaluate on rounds 1, 6, 11, ... and on the last round
147 | if test_models:
148 |
149 | logger.info("Evaluating models")
150 | np.set_printoptions(formatter={"float": "{: .0f}".format})
151 | for worker_id, worker_model, _ in results:
152 | acc[worker_id] = rwc.evaluate_model_on_worker(
153 | model_identifier="Model update " + worker_id,
154 | worker=testing,
155 | dataset_key="mnist_testing",
156 | model=worker_model,
157 | nr_bins=10,
158 | batch_size=128,
159 | print_target_hist=False,
160 | )
161 | new_model = worker_model
162 | grads[worker_id] = utils.add_model(new_model, utils.scale_model(old_model, -1))
163 |
164 |                 grad = 0  # squared L2 norm of this worker's model update
165 | for p in grads[worker_id].parameters():
166 | p2 = torch.dot(p.view(-1), p.view(-1))
167 | grad += p2
168 |
169 | print(grad.detach().numpy())
170 | grad = grad.detach().numpy()
171 |
172 |                 V[worker_id] = grad * pow(1 + n_server / 1000, acc[worker_id])  # update norm weighted by test accuracy
173 |
174 | Vlist = list(V.values())
175 |
176 | ave_V = np.mean(Vlist)
177 |
178 |         # Federated averaging (this operation changes the current global model)
179 | for worker_id, worker_model, worker_loss in results:
180 | if worker_model is not None:
181 | loss_values[worker_id] = worker_loss
182 |
183 |                 if not test_models or V[worker_id] >= ave_V:  # V is only computed on evaluation rounds; otherwise keep every update
184 |                     models[worker_id] = worker_model
185 |
186 |         iter += len(models)  # count the updates aggregated this round
187 | # federated_avg
188 | traced_model = utils.federated_avg(models)
189 |
190 | if test_models:
191 | accf = rwc.evaluate_model_on_worker(
192 | model_identifier="Federated model",
193 | worker=testing,
194 | dataset_key="mnist_testing",
195 | model=traced_model,
196 | nr_bins=10,
197 | batch_size=128,
198 | print_target_hist=False,
199 | )
200 |
201 | step = curr_round
202 | acc1 = acc['alice']
203 | acc2 = acc['bob']
204 | acc3 = acc['charlie']
205 | '''
206 | acc4 = acc['dave']
207 | acc5 = acc['eva']
208 | acc6 = acc['frank']
209 | acc7 = acc['frank1']
210 | '''
211 |
212 |             csv_list = [step, acc1, acc2, acc3, accf]
213 |             csv_data = pd.DataFrame([csv_list])
214 |             csv_data.to_csv(path + "/" + t + ".csv", mode='a', header=False, index=False)
215 |
216 | # decay learning rate
217 | learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
218 |         if accf is not None and accf >= 94.0:  # stop early once the most recently evaluated federated model reaches 94% accuracy
219 | break
220 |
221 | torch.save(model.state_dict(), "mnist_cnn.pt")
222 |
223 | # time
224 | t1 = time.time()
225 |     logger.info("Total time: %.4f seconds", t1 - t0)
226 |     logger.info("Aggregated client updates (iter): %d", iter)
227 |
228 |
229 | if __name__ == "__main__":
230 | # Logging setup
231 | FORMAT = "%(asctime)s | %(message)s"
232 | logging.basicConfig(format=FORMAT)
233 | logger.setLevel(level=logging.DEBUG)
234 |
235 | # Websockets setup
236 | websockets_logger = logging.getLogger("websockets")
237 | websockets_logger.setLevel(logging.INFO)
238 | websockets_logger.addHandler(logging.StreamHandler())
239 |
240 | # Run main
241 | asyncio.get_event_loop().run_until_complete(main())
242 |
--------------------------------------------------------------------------------
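Note on AFL.py: the selection step in main() scores each returned worker model by the squared norm of its update, weighted by its test accuracy, and only workers whose value V is at least the average are included in the federated average. Below is a minimal, PySyft-free sketch of that rule using plain PyTorch state_dicts; the names (update_value, select_clients, client_states, client_acc) are illustrative, not from the script.

```python
import torch

def update_value(old_state, new_state, accuracy, n_server):
    """V = ||new - old||^2 * (1 + n_server / 1000) ** accuracy, as computed in AFL.py."""
    grad_sq = 0.0
    for name, old_param in old_state.items():
        diff = new_state[name].float() - old_param.float()
        grad_sq += float(torch.dot(diff.view(-1), diff.view(-1)))
    return grad_sq * (1 + n_server / 1000) ** accuracy

def select_clients(old_state, client_states, client_acc, n_server=3):
    """Return the ids of clients whose value V is at least the average V."""
    V = {
        cid: update_value(old_state, state, client_acc[cid], n_server)
        for cid, state in client_states.items()
    }
    ave_V = sum(V.values()) / len(V)
    return [cid for cid, v in V.items() if v >= ave_V]
```

With accuracies reported as percentages (0-100) and n_server = 3, the factor (1 + n_server / 1000) ** accuracy stays roughly between 1.0 and 1.35, so the size of the update dominates the score.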
/README.md:
--------------------------------------------------------------------------------
1 | # A Novel Optimized Asynchronous Federated Learning Framework
2 |
3 | This repository contains the code for the paper *A Novel Optimized Asynchronous Federated Learning Framework*.
4 | ## Abstract
5 |
6 | Since it was first proposed, Federated Learning (FL) has been applied in many fields, such as credit assessment and healthcare. Because clients differ in network and computing resources, they may not upload their gradient updates at the same time, which forces the server to wait or sit idle for long periods. This is why an Asynchronous Federated Learning (AFL) method is needed. The main bottleneck in AFL is communication, and finding a balance between model performance and communication cost is a key challenge. This paper proposes a novel AFL framework, VAFL, and we verify its performance through extensive experiments. The experiments show that VAFL reduces the number of communications by about 51.02% with an average communication compression rate of 48.23%, and allows the model to converge faster.
7 |
8 | ## Citation
9 | If you find this code useful for your research, please cite our paper:
10 |
11 | ```
12 | @inproceedings{zhou2021novel,
13 | title={A Novel Optimized Asynchronous Federated Learning Framework},
14 | author={Zhou, Zhicheng and Chen, Hailong and Li, Kunhua and Hu, Fei and Yan, Bingjie and Cheng, Jieren and Wei, Xuyan and Liu, Bernie and Li, Xiulai and Chen, Fuwen and others},
15 | booktitle={2021 IEEE 23rd Int Conf on High Performance Computing \& Communications; 7th Int Conf on Data Science \& Systems; 19th Int Conf on Smart City; 7th Int Conf on Dependability in Sensor, Cloud \& Big Data Systems \& Application (HPCC/DSS/SmartCity/DependSys)},
16 | pages={2363--2370},
17 | year={2021},
18 | organization={IEEE}
19 | }
20 | ```
21 |
--------------------------------------------------------------------------------
/run_websocket_client.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import sys
4 | import asyncio
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torchvision import models
10 |
11 |
12 | import syft as sy
13 | from syft.workers import websocket_client
14 | from syft.frameworks.torch.fl import utils
15 |
16 | LOG_INTERVAL = 25
17 | logger = logging.getLogger("run_websocket_client")
18 |
19 |
20 | # Loss function
21 | @torch.jit.script
22 | def loss_fn(pred, target):
23 | return F.nll_loss(input=pred, target=target)
24 |
25 |
26 | class ResidualBlock(nn.Module):
27 | def __init__(self, channel):
28 | super().__init__()
29 | self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
30 | self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
31 |
32 | def forward(self, x):
33 | y = F.relu(self.conv1(x))
34 | y = self.conv2(y)
35 |
36 | return F.relu(x + y)
37 |
51 |
52 | class Net(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 | self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
56 | self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
57 | self.res_block_1 = ResidualBlock(16)
58 | self.res_block_2 = ResidualBlock(32)
59 | self.conv2_drop = nn.Dropout2d()
60 | self.fc1 = nn.Linear(512, 10)
61 |
62 | def forward(self, x):
63 | in_size = x.size(0)
64 | x = F.max_pool2d(F.relu(self.conv1(x)), 2)
65 | x = self.res_block_1(x)
66 | x = F.max_pool2d(F.relu(self.conv2(x)), 2)
67 | x = self.res_block_2(x)
68 | x = x.view(in_size, -1)
69 | x = self.fc1(x)
70 | return F.log_softmax(x, dim=1)
71 |
72 |
73 | def define_and_get_arguments(args=sys.argv[1:]):
74 | parser = argparse.ArgumentParser(
75 | description="Run federated learning using websocket client workers."
76 | )
77 | parser.add_argument("--batch_size", type=int, default=32, help="batch size of the training")
78 | parser.add_argument(
79 | "--test_batch_size", type=int, default=128, help="batch size used for the test data"
80 | )
81 | parser.add_argument(
82 | "--training_rounds", type=int, default=200, help="number of federated learning rounds"
83 | )
84 | parser.add_argument(
85 | "--federate_after_n_batches",
86 | type=int,
87 | default=10,
88 | help="number of training steps performed on each remote worker before averaging",
89 | )
90 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
91 | parser.add_argument("--cuda", action="store_true", help="use cuda")
92 | parser.add_argument("--seed", type=int, default=1, help="seed used for randomization")
93 | parser.add_argument("--save_model", action="store_true", help="if set, model will be saved")
94 | parser.add_argument(
95 | "--verbose",
96 | "-v",
97 | action="store_true",
98 | help="if set, websocket client workers will be started in verbose mode",
99 | )
100 |
101 | args = parser.parse_args(args=args)
102 | return args
103 |
104 | # Asynchronous training models
105 | async def fit_model_on_worker(
106 | worker: websocket_client.WebsocketClientWorker,
107 | traced_model: torch.jit.ScriptModule,
108 | batch_size: int,
109 | curr_round: int,
110 | max_nr_batches: int,
111 | lr: float,
112 | ):
113 | """Send the model to the worker and fit the model on the worker's training data.
114 |
115 | Args:
116 | worker: Remote location, where the model shall be trained.
117 | traced_model: Model which shall be trained.
118 | batch_size: Batch size of each training step.
119 | curr_round: Index of the current training round (for logging purposes).
120 | max_nr_batches: If > 0, training on worker will stop at min(max_nr_batches, nr_available_batches).
121 | lr: Learning rate of each training step.
122 |
123 | Returns:
124 | A tuple containing:
125 | * worker_id: Union[int, str], id of the worker.
126 | * improved model: torch.jit.ScriptModule, model after training at the worker.
127 | * loss: Loss on last training batch, torch.tensor.
128 | """
129 | train_config = sy.TrainConfig(
130 | model=traced_model,
131 | loss_fn=loss_fn,
132 | batch_size=batch_size,
133 | shuffle=True,
134 | max_nr_batches=max_nr_batches,
135 | epochs=1,
136 | optimizer="SGD",
137 | optimizer_args={"lr": lr},
138 | )
139 |
140 | train_config.send(worker)
141 |
142 | loss = await worker.async_fit(dataset_key="mnist", return_ids=[0])
143 |
144 | model = train_config.model_ptr.get().obj
145 |
146 | return worker.id, model, loss
147 |
148 |
149 | def evaluate_model_on_worker(
150 | model_identifier,
151 | worker,
152 | dataset_key,
153 | model,
154 | nr_bins,
155 | batch_size,
156 | print_target_hist=False,
157 | ):
158 | model.eval()
159 |
160 | # Define and send train config
161 | train_config = sy.TrainConfig(
162 | batch_size=batch_size,
163 | model=model,
164 | loss_fn=loss_fn,
165 | optimizer_args=None,
166 | epochs=1
167 | )
168 |
169 | train_config.send(worker)
170 |
171 | result = worker.evaluate(
172 | dataset_key=dataset_key,
173 | return_histograms=True,
174 | nr_bins=nr_bins,
175 | return_loss=True,
176 | return_raw_accuracy=True,
177 | )
178 | test_loss = result["loss"]
179 | correct = result["nr_correct_predictions"]
180 | len_dataset = result["nr_predictions"]
181 | # hist_pred = result["histogram_predictions"]
182 | hist_target = result["histogram_target"]
183 |
184 | if print_target_hist:
185 | logger.info("Target histogram: %s", hist_target)
186 |
187 | '''
188 | percentage_0_3 = int(100 * sum(hist_pred[0:4]) / len_dataset)
189 | percentage_4_6 = int(100 * sum(hist_pred[4:7]) / len_dataset)
190 | percentage_7_9 = int(100 * sum(hist_pred[7:10]) / len_dataset)
191 | logger.info(
192 | "%s: Percentage numbers 0-3: %s%%, 4-6: %s%%, 7-9: %s%%",
193 | model_identifier,
194 | percentage_0_3,
195 | percentage_4_6,
196 | percentage_7_9,
197 | )
198 | '''
199 |
200 | logger.info(
201 | "%s: Average loss: %s, Accuracy: %s/%s (%s%%)",
202 | model_identifier,
203 | f"{test_loss:.4f}",
204 | correct,
205 | len_dataset,
206 | f"{100.0 * correct / len_dataset:.2f}",
207 | )
208 |
209 | return 100.0 * correct / len_dataset
210 |
211 |
212 | async def main():
213 | args = define_and_get_arguments()
214 |
215 | hook = sy.TorchHook(torch)
216 |
217 | kwargs_websocket = {"host": "192.168.2.13", "hook": hook, "verbose": args.verbose}
218 | alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
219 | bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
220 | charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)
221 | dave = websocket_client.WebsocketClientWorker(id="dave", port=8780, **kwargs_websocket)
222 | # eva = websocket_client.WebsocketClientWorker(id="eva", port=8781, **kwargs_websocket)
223 |
224 | testing = websocket_client.WebsocketClientWorker(id="testing", port=8782, **kwargs_websocket)
225 |
226 |     for wcw in [alice, bob, charlie, dave, testing]:
227 | wcw.clear_objects_remote()
228 |
229 | worker_instances = [alice, bob, charlie, dave]
230 |
231 | use_cuda = args.cuda and torch.cuda.is_available()
232 |
233 | torch.manual_seed(args.seed)
234 |
235 | device = torch.device("cuda" if use_cuda else "cpu")
236 |
237 | model = Net().to(device)
238 |
239 | traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device))
240 | learning_rate = args.lr
241 |
242 | for curr_round in range(1, args.training_rounds + 1):
243 | logger.info("Training round %s/%s", curr_round, args.training_rounds)
244 |
245 | results = await asyncio.gather(
246 | *[
247 | fit_model_on_worker(
248 | worker=worker,
249 | traced_model=traced_model,
250 | batch_size=args.batch_size,
251 | curr_round=curr_round,
252 | max_nr_batches=args.federate_after_n_batches,
253 | lr=learning_rate,
254 | )
255 | for worker in worker_instances
256 | ]
257 | )
258 | models = {}
259 | loss_values = {}
260 |
261 | test_models = curr_round % 10 == 1 or curr_round == args.training_rounds
262 | if test_models:
263 | logger.info("Evaluating models")
264 | np.set_printoptions(formatter={"float": "{: .0f}".format})
265 | for worker_id, worker_model, _ in results:
266 | evaluate_model_on_worker(
267 | model_identifier="Model update " + worker_id,
268 | worker=testing,
269 | dataset_key="mnist_testing",
270 | model=worker_model,
271 | nr_bins=10,
272 | batch_size=128,
273 | print_target_hist=False,
274 | )
275 |
276 | for worker_id, worker_model, worker_loss in results:
277 | if worker_model is not None:
278 | models[worker_id] = worker_model
279 | loss_values[worker_id] = worker_loss
280 |
281 | traced_model = utils.federated_avg(models)
282 |
283 | if test_models:
284 | evaluate_model_on_worker(
285 | model_identifier="Federated model",
286 | worker=testing,
287 | dataset_key="mnist_testing",
288 | model=traced_model,
289 | nr_bins=10,
290 | batch_size=128,
291 | print_target_hist=False,
292 | )
293 |
294 | # decay learning rate
295 | learning_rate = max(0.98 * learning_rate, args.lr * 0.01)
296 |
297 | if args.save_model:
298 | torch.save(model.state_dict(), "mnist_cnn.pt")
299 |
300 |
301 | if __name__ == "__main__":
302 | # Logging setup
303 | FORMAT = "%(asctime)s | %(message)s"
304 | logging.basicConfig(format=FORMAT)
305 | logger.setLevel(level=logging.DEBUG)
306 |
307 | # Websockets setup
308 | websockets_logger = logging.getLogger("websockets")
309 | websockets_logger.setLevel(logging.INFO)
310 | websockets_logger.addHandler(logging.StreamHandler())
311 |
312 | # Run main
313 | asyncio.get_event_loop().run_until_complete(main())
314 |
--------------------------------------------------------------------------------
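Note on run_websocket_client.py: fit_model_on_worker sends a sy.TrainConfig to each worker, awaits async_fit on the remote "mnist" dataset, and pulls the trained model back; utils.federated_avg then combines the returned models. For reference, that combination is a parameter-wise mean, sketched here in plain PyTorch (not the PySyft implementation; assumes all models share one architecture with floating-point parameters):

```python
import copy
import torch

def federated_average(models):
    """Parameter-wise mean of a dict {worker_id: nn.Module} of identical models."""
    model_list = list(models.values())
    avg_model = copy.deepcopy(model_list[0])
    avg_state = avg_model.state_dict()
    for name in avg_state:
        # Stack the corresponding tensors from every worker and average them.
        avg_state[name] = torch.stack(
            [m.state_dict()[name].float() for m in model_list]
        ).mean(dim=0)
    avg_model.load_state_dict(avg_state)
    return avg_model
```

Once the server workers are running, the client can be started with, e.g., `python run_websocket_client.py --training_rounds 200 --batch_size 32`.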
/run_websocket_server.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import numpy as np
4 | import torch
5 | from torchvision import datasets
6 | from torchvision import transforms
7 | import syft as sy
8 | from syft.workers import websocket_server
9 |
10 | KEEP_LABELS_DICT = {
11 | "alice": [5, 6, 7, 8, 9],
12 | "bob": [0, 1, 2, 3],
13 | "charlie": [list(range(10))],
14 | "dave": list(range(6)),
15 | "eva": [7, 8, 9],
16 | "frank": [list(range(10))],
17 | "frank1": [list(range(10))],
18 | "testing": list(range(10)),
19 | None: list(range(10)),
20 | }
21 |
22 |
23 | def start_websocket_server_worker(id, host, port, hook, verbose, keep_labels=None, training=True):
24 | """Helper function for spinning up a websocket server and setting up the local datasets."""
25 |
26 | server = websocket_server.WebsocketServerWorker(
27 | id=id, host=host, port=port, hook=hook, verbose=verbose
28 | )
29 |
30 | # Setup toy data (mnist example)
31 | mnist_dataset = datasets.MNIST(
32 | root="./data",
33 | train=training,
34 | download=True,
35 | transform=transforms.Compose(
36 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
37 | ),
38 | )
39 |
40 |
41 |
42 | if training:
43 | indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8")
44 | logger.info("number of true indices: %s", indices.sum())
45 | selected_data = (
46 | torch.native_masked_select(mnist_dataset.data.transpose(0, 2), torch.tensor(indices))
47 | .view(28, 28, -1)
48 | .transpose(2, 0)
49 | )
50 | logger.info("after selection: %s", selected_data.shape)
51 | selected_targets = torch.native_masked_select(mnist_dataset.targets, torch.tensor(indices))
52 |
53 | dataset = sy.BaseDataset(
54 | data=selected_data, targets=selected_targets, transform=mnist_dataset.transform
55 | )
56 | key = "mnist"
57 | else:
58 | dataset = sy.BaseDataset(
59 | data=mnist_dataset.data,
60 | targets=mnist_dataset.targets,
61 | transform=mnist_dataset.transform,
62 | )
63 | key = "mnist_testing"
64 |
65 | server.add_dataset(dataset, key=key)
66 | count = [0] * 10
67 | logger.info(
68 | "MNIST dataset (%s set), available numbers on %s: ", "train" if training else "test", id
69 | )
70 | for i in range(10):
71 | count[i] = (dataset.targets == i).sum().item()
72 | logger.info(" %s: %s", i, count[i])
73 |
74 | '''
75 | logger.info("datasets: %s", server.datasets)
76 | if training:
77 | logger.info("len(datasets[mnist]): %s", len(server.datasets[key]))
78 | '''
79 |
80 | server.start()
81 | return server
82 |
83 |
84 | if __name__ == "__main__":
85 | # Logging setup
86 | FORMAT = "%(asctime)s | %(message)s"
87 | logging.basicConfig(format=FORMAT)
88 | logger = logging.getLogger("run_websocket_server")
89 | logger.setLevel(level=logging.DEBUG)
90 |
91 | # Parse args
92 | parser = argparse.ArgumentParser(description="Run websocket server worker.")
93 | parser.add_argument(
94 | "--port",
95 | "-p",
96 | type=int,
97 | help="port number of the websocket server worker, e.g. --port 8777",
98 | )
99 | parser.add_argument(
100 | "--host",
101 | type=str,
102 | default="localhost",
103 | help="host for the connection")
104 | parser.add_argument(
105 | "--id",
106 | type=str,
107 | help="name (id) of the websocket server worker, e.g. --id alice"
108 | )
109 | parser.add_argument(
110 | "--testing",
111 | action="store_true",
112 | help="if set, websocket server worker will load the test dataset instead of the training dataset",
113 | )
114 | parser.add_argument(
115 | "--verbose",
116 | "-v",
117 | action="store_true",
118 | help="if set, websocket server worker will be started in verbose mode",
119 | )
120 |
121 | args = parser.parse_args()
122 |
123 | # Hook and start server
124 | hook = sy.TorchHook(torch)
125 | server = start_websocket_server_worker(
126 | id=args.id,
127 | host=args.host,
128 | port=args.port,
129 | hook=hook,
130 | verbose=args.verbose,
131 | keep_labels=KEEP_LABELS_DICT[args.id],
132 | training=not args.testing,
133 | )
134 |
--------------------------------------------------------------------------------
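Note on run_websocket_server.py: start_websocket_server_worker builds the non-IID split by keeping only the digits assigned to the worker in KEEP_LABELS_DICT. The same selection can be reproduced standalone with torchvision and boolean indexing; a short illustrative sketch (variable names are not from the script):

```python
import numpy as np
import torch
from torchvision import datasets, transforms

keep_labels = [5, 6, 7, 8, 9]  # e.g. alice's share in KEEP_LABELS_DICT

mnist = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)

# Boolean mask over the training targets, then index data and targets with it.
mask = torch.tensor(np.isin(mnist.targets, keep_labels))
selected_data = mnist.data[mask]        # shape: (n_selected, 28, 28)
selected_targets = mnist.targets[mask]  # shape: (n_selected,)
print(selected_data.shape, selected_targets.shape)
```

The transpose/masked_select/view sequence in the script yields the same subset; plain boolean indexing is shown here only to make the filtering step explicit.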
/start_websocket_servers.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | from torchvision import datasets
4 | from torchvision import transforms
5 |
6 | import signal
7 | import sys
8 | '''
9 | python run_websocket_server.py --port 8777 --id alice --host 0.0.0.0
10 | python run_websocket_server.py --port 8778 --id bob --host 0.0.0.0
11 | python run_websocket_server.py --port 8779 --id charlie --host 0.0.0.0
12 | python run_websocket_server.py --port 8780 --id testing --testing --host 0.0.0.0
13 |
14 | python run_websocket_server.py --port 8783 --id testing --testing --host localhost
15 | '''
16 |
17 | # Downloads MNIST dataset
18 | mnist_trainset = datasets.MNIST(
19 | root="./data",
20 | train=True,
21 | download=True,
22 | transform=transforms.Compose(
23 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
24 | ),
25 | )
26 |
27 | call_alice = [
28 | "python",
29 | "run_websocket_server.py",
30 | "--port",
31 | "8777",
32 | "--id",
33 | "alice",
34 | "--host",
35 | "0.0.0.0",
36 | ]
37 |
38 | call_bob = [
39 | "python",
40 | "run_websocket_server.py",
41 | "--port",
42 | "8778",
43 | "--id",
44 | "bob",
45 | "--host",
46 | "0.0.0.0",
47 | ]
48 |
49 | call_charlie = [
50 | "python",
51 | "run_websocket_server.py",
52 | "--port",
53 | "8779",
54 | "--id",
55 | "charlie",
56 | "--host",
57 | "0.0.0.0",
58 | ]
59 |
60 | call_testing = [
61 | "python",
62 | "run_websocket_server.py",
63 | "--port",
64 | "8780",
65 | "--id",
66 | "testing",
67 | "--testing",
68 | "--host",
69 | "0.0.0.0",
70 | ]
71 |
72 | print("Starting server for Alice")
73 | process_alice = subprocess.Popen(call_alice)
74 |
75 | print("Starting server for Bob")
76 | process_bob = subprocess.Popen(call_bob)
77 |
78 | print("Starting server for Charlie")
79 | process_charlie = subprocess.Popen(call_charlie)
80 |
81 | print("Starting server for Testing")
82 | process_testing = subprocess.Popen(call_testing)
83 |
84 |
85 | def signal_handler(sig, frame):
86 | print("You pressed Ctrl+C!")
87 | for p in [process_alice, process_bob, process_charlie, process_testing]:
88 | p.terminate()
89 | sys.exit(0)
90 |
91 |
92 | signal.signal(signal.SIGINT, signal_handler)
93 |
94 | # signal.pause()
95 |
--------------------------------------------------------------------------------
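Note on start_websocket_servers.py: the script registers a SIGINT handler but then falls through (signal.pause() is commented out), so the four server processes keep running in the background after the launcher returns. If the launcher should instead stay in the foreground and own the processes, one option (an addition, not part of the original script) is to block on the children:

```python
# Keep the launcher alive until every server process exits;
# Ctrl+C still triggers signal_handler above, which terminates them.
for p in [process_alice, process_bob, process_charlie, process_testing]:
    p.wait()
```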