├── AFL.py ├── README.md ├── run_websocket_client.py ├── run_websocket_server.py └── start_websocket_servers.py /AFL.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | # Dependencies 4 | import math 5 | import sys 6 | import pandas as pd 7 | import asyncio 8 | import logging 9 | import time 10 | 11 | logger = logging.getLogger("run_websocket_client") 12 | 13 | import syft as sy 14 | from syft.workers.websocket_client import WebsocketClientWorker 15 | from syft.frameworks.torch.fl import utils 16 | 17 | import torch 18 | import numpy as np 19 | 20 | import run_websocket_client as rwc 21 | 22 | path = "./result"  # output directory for the per-round accuracy CSV (must already exist) 23 | 24 | 25 | async def main(): 26 | # Create the CSV file that records per-round accuracies 27 | global old_model 28 | df = pd.DataFrame(columns=['step', 29 | 'acc1', 'acc2', 30 | 'acc3', 'accf']) 31 | t = str(time.time()) 32 | df.to_csv(path + "/" + t + ".csv", index=False) 33 | iter = 0  # total number of worker updates merged into the federated model 34 | 35 | # Hook torch 36 | hook = sy.TorchHook(torch) 37 | 38 | # Arguments 39 | args = rwc.define_and_get_arguments(args=[]) 40 | use_cuda = args.cuda and torch.cuda.is_available() 41 | torch.manual_seed(args.seed) 42 | device = torch.device("cuda" if use_cuda else "cpu") 43 | print(args) 44 | 45 | # Configure logging 46 | 47 | if not len(logger.handlers): 48 | FORMAT = "%(asctime)s - %(message)s" 49 | DATE_FMT = "%H:%M:%S" 50 | formatter = logging.Formatter(FORMAT, DATE_FMT) 51 | handler = logging.StreamHandler() 52 | handler.setFormatter(formatter) 53 | logger.addHandler(handler) 54 | logger.propagate = False 55 | LOG_LEVEL = logging.DEBUG 56 | logger.setLevel(LOG_LEVEL) 57 | 58 | t0 = time.time() 59 | 60 | akwargs_websocket = {"host": "192.168.2.13", "hook": hook, "verbose": args.verbose} 61 | alice = WebsocketClientWorker(id="alice", port=8777, **akwargs_websocket) 62 | 63 | bkwargs_websocket = {"host": "192.168.2.16", "hook": hook, "verbose": args.verbose} 64 | bob = WebsocketClientWorker(id="bob", port=8778, **bkwargs_websocket) 65 | 66 | ckwargs_websocket = {"host": "192.168.2.14", "hook": hook, "verbose": args.verbose} 67 | charlie = WebsocketClientWorker(id="charlie", port=8779, **ckwargs_websocket) 68 | 69 | dkwargs_websocket = {"host": "192.168.2.11", "hook": hook, "verbose": args.verbose} 70 | #dave = WebsocketClientWorker(id="dave", port=8780, **dkwargs_websocket) 71 | 72 | ekwargs_websocket = {"host": "192.168.2.12", "hook": hook, "verbose": args.verbose} 73 | #eva = WebsocketClientWorker(id="eva", port=8781, **ekwargs_websocket) 74 | 75 | kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose} 76 | #frank = WebsocketClientWorker(id="frank", port=8782, **kwargs_websocket) 77 | #frank1 = WebsocketClientWorker(id="frank1", port=8792, **kwargs_websocket) 78 | 79 | testing = WebsocketClientWorker(id="testing", port=8783, **kwargs_websocket) 80 | 81 | worker_instances = [ 82 | alice, 83 | bob, 84 | charlie, 85 | # dave, 86 | # eva, 87 | # frank, 88 | # frank1 89 | ] 90 | 91 | model = rwc.Net().to(device) 92 | # print(model) 93 | 94 | print("Federate_after_n_batches: " + str(args.federate_after_n_batches)) 95 | print("Batch size: " + str(args.batch_size)) 96 | print("Initial learning rate: " + str(args.lr)) 97 | 98 | learning_rate = args.lr 99 | 100 | traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float))  # TorchScript module that can be serialized and sent to the remote workers 101 | 102 | for curr_round in range(1, args.training_rounds + 1): 103 | 104 | '''OLD MODEL''' 105 | Empty_model = utils.scale_model(model, 0) 106 | old_model = utils.add_model(Empty_model, traced_model) 107 | 
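# Note on the snapshot above: utils.scale_model(model, 0) builds a zero-weight copy of the
# network, and adding traced_model to it leaves old_model holding the current global
# parameters; each worker's returned model is later compared against this snapshot to
# measure how far that worker moved during the round.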
108 | # train 109 | logger.info("Training round %s/%s", curr_round, args.training_rounds) 110 | results = await asyncio.gather( 111 | *[ 112 | rwc.fit_model_on_worker( 113 | worker=worker, 114 | traced_model=traced_model, 115 | batch_size=args.batch_size, 116 | curr_round=curr_round, 117 | max_nr_batches=args.federate_after_n_batches, 118 | lr=learning_rate, 119 | ) 120 | for worker in worker_instances 121 | ] 122 | ) 123 | 124 | ''' 125 | models, loss_values, grads, n_server 126 | 127 | Participant values: V 128 | V[worker_id] represents the value of a participant; the smaller the value, 129 | the less that participant needs to be included in the federated model. 130 | How is V calculated? 131 | 1) grad: the norm of the model update (new model minus old global model) reflects how close the worker is to its local solution 132 | 2) total number of clients: in federated learning, the more participants there are, 133 | the smaller the value of any single participant 134 | 3) accuracy on the test set (a standalone sketch of this selection rule is given just before the __main__ block below) 135 | ''' 136 | 137 | models = {} 138 | loss_values = {} 139 | grads = {} 140 | n_server = 3 141 | acc = {} 142 | 143 | V = {} 144 | 145 | # test (evaluate each worker's model every 5 rounds and on the final round) 146 | test_models = curr_round % 5 == 1 or curr_round == args.training_rounds 147 | if test_models: 148 | 149 | logger.info("Evaluating models") 150 | np.set_printoptions(formatter={"float": "{: .0f}".format}) 151 | for worker_id, worker_model, _ in results: 152 | acc[worker_id] = rwc.evaluate_model_on_worker( 153 | model_identifier="Model update " + worker_id, 154 | worker=testing, 155 | dataset_key="mnist_testing", 156 | model=worker_model, 157 | nr_bins=10, 158 | batch_size=128, 159 | print_target_hist=False, 160 | ) 161 | new_model = worker_model 162 | grads[worker_id] = utils.add_model(new_model, utils.scale_model(old_model, -1))  # model delta = new model - old global model 163 | 164 | grad = 0 165 | for p in grads[worker_id].parameters(): 166 | p2 = torch.dot(p.view(-1), p.view(-1)) 167 | grad += p2 168 | 169 | print(grad.detach().numpy()) 170 | grad = grad.detach().numpy() 171 | 172 | V[worker_id] = grad * pow(1 + n_server / 1000, acc[worker_id])  # value = squared delta norm weighted by (1 + n_server/1000)^accuracy 173 | 174 | Vlist = list(V.values()) 175 | 176 | ave_V = np.mean(Vlist) 177 | 178 | # Federated model (this operation changes the initial model) 179 | for worker_id, worker_model, worker_loss in results: 180 | if worker_model is not None: 181 | loss_values[worker_id] = worker_loss 182 | 183 | if V[worker_id] >= ave_V: 184 | models[worker_id] = worker_model 185 | 186 | iter += len(models.keys())  # running count of worker updates merged into the federated model 187 | # federated_avg 188 | traced_model = utils.federated_avg(models) 189 | 190 | if test_models: 191 | accf = rwc.evaluate_model_on_worker( 192 | model_identifier="Federated model", 193 | worker=testing, 194 | dataset_key="mnist_testing", 195 | model=traced_model, 196 | nr_bins=10, 197 | batch_size=128, 198 | print_target_hist=False, 199 | ) 200 | 201 | step = curr_round 202 | acc1 = acc['alice'] 203 | acc2 = acc['bob'] 204 | acc3 = acc['charlie'] 205 | ''' 206 | acc4 = acc['dave'] 207 | acc5 = acc['eva'] 208 | acc6 = acc['frank'] 209 | acc7 = acc['frank1'] 210 | ''' 211 | 212 | csv_list = [step, acc1, acc2, acc3, accf] 213 | csv_data = pd.DataFrame([csv_list]) 214 | csv_data.to_csv(path + "/" + t + ".csv", mode='a', header=False, index=False) 215 | 216 | # decay learning rate 217 | learning_rate = max(0.98 * learning_rate, args.lr * 0.01) 218 | if accf is not None and accf >= 94.0:  # early stopping once the federated model reaches 94% test accuracy 219 | break 220 | 221 | torch.save(model.state_dict(), "mnist_cnn.pt") 222 | 223 | # time 224 | t1 = time.time() 225 | logger.info('cost:%.4f seconds' % (float(t1 - t0))) 226 | logger.info('iter:%d' % iter) 227 | 228 | 
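# Illustrative sketch (not called anywhere in this script): the participant-value rule used
# inside main() above, pulled out into two standalone helpers. The names participant_value
# and select_for_aggregation are hypothetical and only meant to make the selection logic
# explicit: a worker's value is the squared L2 norm of its model delta weighted by
# (1 + n_server / 1000) ** accuracy, and only workers whose value reaches the round's mean
# are passed on to utils.federated_avg.
def participant_value(model_delta, accuracy, n_server=3):
    """Squared norm of a worker's model delta, weighted by its test accuracy."""
    grad_sq = 0.0
    for p in model_delta.parameters():
        grad_sq += float(torch.dot(p.view(-1), p.view(-1)))
    return grad_sq * (1 + n_server / 1000) ** accuracy


def select_for_aggregation(values):
    """Return the ids of the workers whose value is at least the round average."""
    threshold = np.mean(list(values.values()))
    return [worker_id for worker_id, value in values.items() if value >= threshold]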
229 | if __name__ == "__main__": 230 | # Logging setup 231 | FORMAT = "%(asctime)s | %(message)s" 232 | logging.basicConfig(format=FORMAT) 233 | logger.setLevel(level=logging.DEBUG) 234 | 235 | # Websockets setup 236 | websockets_logger = logging.getLogger("websockets") 237 | websockets_logger.setLevel(logging.INFO) 238 | websockets_logger.addHandler(logging.StreamHandler()) 239 | 240 | # Run main 241 | asyncio.get_event_loop().run_until_complete(main()) 242 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Novel Optimized Asynchronous Federated Learning Framework 2 | 3 | This repository contains the code for the paper *A Novel Optimized Asynchronous Federated Learning Framework*. 4 | ## Abstract 5 | 6 | Federated Learning (FL) has been applied in many fields since it was first proposed, such as credit assessment and healthcare. Because clients differ in network and computing resources, they may not update their gradients at the same time, which can lead to long waiting or idle periods. This is why an Asynchronous Federated Learning (AFL) method is needed. The main bottleneck in AFL is communication, and finding a balance between model performance and communication cost is a central challenge in AFL. This paper proposes a novel AFL framework, VAFL, and we verify the performance of the algorithm through extensive experiments. The experiments show that VAFL can reduce the number of communications by about 51.02% with a 48.23% average communication compression rate, and allows the model to converge faster. 7 | 8 | ## Citation 9 | If you find this code useful for your research, please cite our paper: 10 | 11 | ``` 12 | @inproceedings{zhou2021novel,
13 | title={A Novel Optimized Asynchronous Federated Learning Framework},
14 | author={Zhou, Zhicheng and Chen, Hailong and Li, Kunhua and Hu, Fei and Yan, Bingjie and Cheng, Jieren and Wei, Xuyan and Liu, Bernie and Li, Xiulai and Chen, Fuwen and others},
15 | booktitle={2021 IEEE 23rd Int Conf on High Performance Computing \& Communications; 7th Int Conf on Data Science \& Systems; 19th Int Conf on Smart City; 7th Int Conf on Dependability in Sensor, Cloud \& Big Data Systems \& Application (HPCC/DSS/SmartCity/DependSys)},
16 | pages={2363--2370},
17 | year={2021},
18 | organization={IEEE}
19 | } 20 | ``` 21 | -------------------------------------------------------------------------------- /run_websocket_client.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import sys 4 | import asyncio 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import models 10 | 11 | 12 | import syft as sy 13 | from syft.workers import websocket_client 14 | from syft.frameworks.torch.fl import utils 15 | 16 | LOG_INTERVAL = 25 17 | logger = logging.getLogger("run_websocket_client") 18 | 19 | 20 | # Loss function 21 | @torch.jit.script 22 | def loss_fn(pred, target): 23 | return F.nll_loss(input=pred, target=target) 24 | 25 | 26 | class ResidualBlock(nn.Module): 27 | def __init__(self, channel): 28 | super().__init__() 29 | self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1) 30 | self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1) 31 | 32 | def forward(self, x): 33 | y = F.relu(self.conv1(x)) 34 | y = self.conv2(y) 35 | 36 | return F.relu(x + y) 37 | 38 | 52 | class Net(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.conv1 = nn.Conv2d(1, 16, kernel_size=5) 56 | self.conv2 = nn.Conv2d(16, 32, kernel_size=5) 57 | self.res_block_1 = ResidualBlock(16) 58 | self.res_block_2 = ResidualBlock(32) 59 | self.conv2_drop = nn.Dropout2d() 60 | self.fc1 = nn.Linear(512, 10)  # 32 channels * 4 * 4 spatial positions after two 2x2 max-pools 61 | 62 | def forward(self, x): 63 | in_size = x.size(0) 64 | x = F.max_pool2d(F.relu(self.conv1(x)), 2) 65 | x = self.res_block_1(x) 66 | x = F.max_pool2d(F.relu(self.conv2(x)), 2) 67 | x = self.res_block_2(x) 68 | x = x.view(in_size, -1) 69 | x = self.fc1(x) 70 | return F.log_softmax(x, dim=1) 71 | 72 | 73 | def define_and_get_arguments(args=sys.argv[1:]): 74 | parser = argparse.ArgumentParser( 75 | description="Run federated learning using websocket client workers." 
76 | ) 77 | parser.add_argument("--batch_size", type=int, default=32, help="batch size of the training") 78 | parser.add_argument( 79 | "--test_batch_size", type=int, default=128, help="batch size used for the test data" 80 | ) 81 | parser.add_argument( 82 | "--training_rounds", type=int, default=200, help="number of federated learning rounds" 83 | ) 84 | parser.add_argument( 85 | "--federate_after_n_batches", 86 | type=int, 87 | default=10, 88 | help="number of training steps performed on each remote worker before averaging", 89 | ) 90 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate") 91 | parser.add_argument("--cuda", action="store_true", help="use cuda") 92 | parser.add_argument("--seed", type=int, default=1, help="seed used for randomization") 93 | parser.add_argument("--save_model", action="store_true", help="if set, model will be saved") 94 | parser.add_argument( 95 | "--verbose", 96 | "-v", 97 | action="store_true", 98 | help="if set, websocket client workers will be started in verbose mode", 99 | ) 100 | 101 | args = parser.parse_args(args=args) 102 | return args 103 | 104 | # Asynchronous training models 105 | async def fit_model_on_worker( 106 | worker: websocket_client.WebsocketClientWorker, 107 | traced_model: torch.jit.ScriptModule, 108 | batch_size: int, 109 | curr_round: int, 110 | max_nr_batches: int, 111 | lr: float, 112 | ): 113 | """Send the model to the worker and fit the model on the worker's training data. 114 | 115 | Args: 116 | worker: Remote location, where the model shall be trained. 117 | traced_model: Model which shall be trained. 118 | batch_size: Batch size of each training step. 119 | curr_round: Index of the current training round (for logging purposes). 120 | max_nr_batches: If > 0, training on worker will stop at min(max_nr_batches, nr_available_batches). 121 | lr: Learning rate of each training step. 122 | 123 | Returns: 124 | A tuple containing: 125 | * worker_id: Union[int, str], id of the worker. 126 | * improved model: torch.jit.ScriptModule, model after training at the worker. 127 | * loss: Loss on last training batch, torch.tensor. 
128 | """ 129 | train_config = sy.TrainConfig( 130 | model=traced_model, 131 | loss_fn=loss_fn, 132 | batch_size=batch_size, 133 | shuffle=True, 134 | max_nr_batches=max_nr_batches, 135 | epochs=1, 136 | optimizer="SGD", 137 | optimizer_args={"lr": lr}, 138 | ) 139 | 140 | train_config.send(worker) 141 | 142 | loss = await worker.async_fit(dataset_key="mnist", return_ids=[0]) 143 | 144 | model = train_config.model_ptr.get().obj 145 | 146 | return worker.id, model, loss 147 | 148 | 149 | def evaluate_model_on_worker( 150 | model_identifier, 151 | worker, 152 | dataset_key, 153 | model, 154 | nr_bins, 155 | batch_size, 156 | print_target_hist=False, 157 | ): 158 | model.eval() 159 | 160 | # Define and send train config 161 | train_config = sy.TrainConfig( 162 | batch_size=batch_size, 163 | model=model, 164 | loss_fn=loss_fn, 165 | optimizer_args=None, 166 | epochs=1 167 | ) 168 | 169 | train_config.send(worker) 170 | 171 | result = worker.evaluate( 172 | dataset_key=dataset_key, 173 | return_histograms=True, 174 | nr_bins=nr_bins, 175 | return_loss=True, 176 | return_raw_accuracy=True, 177 | ) 178 | test_loss = result["loss"] 179 | correct = result["nr_correct_predictions"] 180 | len_dataset = result["nr_predictions"] 181 | # hist_pred = result["histogram_predictions"] 182 | hist_target = result["histogram_target"] 183 | 184 | if print_target_hist: 185 | logger.info("Target histogram: %s", hist_target) 186 | 187 | ''' 188 | percentage_0_3 = int(100 * sum(hist_pred[0:4]) / len_dataset) 189 | percentage_4_6 = int(100 * sum(hist_pred[4:7]) / len_dataset) 190 | percentage_7_9 = int(100 * sum(hist_pred[7:10]) / len_dataset) 191 | logger.info( 192 | "%s: Percentage numbers 0-3: %s%%, 4-6: %s%%, 7-9: %s%%", 193 | model_identifier, 194 | percentage_0_3, 195 | percentage_4_6, 196 | percentage_7_9, 197 | ) 198 | ''' 199 | 200 | logger.info( 201 | "%s: Average loss: %s, Accuracy: %s/%s (%s%%)", 202 | model_identifier, 203 | f"{test_loss:.4f}", 204 | correct, 205 | len_dataset, 206 | f"{100.0 * correct / len_dataset:.2f}", 207 | ) 208 | 209 | return 100.0 * correct / len_dataset 210 | 211 | 212 | async def main(): 213 | args = define_and_get_arguments() 214 | 215 | hook = sy.TorchHook(torch) 216 | 217 | kwargs_websocket = {"host": "192.168.2.13", "hook": hook, "verbose": args.verbose} 218 | alice = websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket) 219 | bob = websocket_client.WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket) 220 | charlie = websocket_client.WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket) 221 | dave = websocket_client.WebsocketClientWorker(id="dave", port=8780, **kwargs_websocket) 222 | # eva = websocket_client.WebsocketClientWorker(id="eva", port=8781, **kwargs_websocket) 223 | 224 | testing = websocket_client.WebsocketClientWorker(id="testing", port=8782, **kwargs_websocket) 225 | 226 | for wcw in [alice, bob, charlie, testing]: 227 | wcw.clear_objects_remote() 228 | 229 | worker_instances = [alice, bob, charlie, dave] 230 | 231 | use_cuda = args.cuda and torch.cuda.is_available() 232 | 233 | torch.manual_seed(args.seed) 234 | 235 | device = torch.device("cuda" if use_cuda else "cpu") 236 | 237 | model = Net().to(device) 238 | 239 | traced_model = torch.jit.trace(model, torch.zeros([1, 1, 28, 28], dtype=torch.float).to(device)) 240 | learning_rate = args.lr 241 | 242 | for curr_round in range(1, args.training_rounds + 1): 243 | logger.info("Training round %s/%s", curr_round, args.training_rounds) 244 | 245 | results = 
await asyncio.gather( 246 | *[ 247 | fit_model_on_worker( 248 | worker=worker, 249 | traced_model=traced_model, 250 | batch_size=args.batch_size, 251 | curr_round=curr_round, 252 | max_nr_batches=args.federate_after_n_batches, 253 | lr=learning_rate, 254 | ) 255 | for worker in worker_instances 256 | ] 257 | ) 258 | models = {} 259 | loss_values = {} 260 | 261 | test_models = curr_round % 10 == 1 or curr_round == args.training_rounds 262 | if test_models: 263 | logger.info("Evaluating models") 264 | np.set_printoptions(formatter={"float": "{: .0f}".format}) 265 | for worker_id, worker_model, _ in results: 266 | evaluate_model_on_worker( 267 | model_identifier="Model update " + worker_id, 268 | worker=testing, 269 | dataset_key="mnist_testing", 270 | model=worker_model, 271 | nr_bins=10, 272 | batch_size=128, 273 | print_target_hist=False, 274 | ) 275 | 276 | for worker_id, worker_model, worker_loss in results: 277 | if worker_model is not None: 278 | models[worker_id] = worker_model 279 | loss_values[worker_id] = worker_loss 280 | 281 | traced_model = utils.federated_avg(models) 282 | 283 | if test_models: 284 | evaluate_model_on_worker( 285 | model_identifier="Federated model", 286 | worker=testing, 287 | dataset_key="mnist_testing", 288 | model=traced_model, 289 | nr_bins=10, 290 | batch_size=128, 291 | print_target_hist=False, 292 | ) 293 | 294 | # decay learning rate 295 | learning_rate = max(0.98 * learning_rate, args.lr * 0.01) 296 | 297 | if args.save_model: 298 | torch.save(model.state_dict(), "mnist_cnn.pt") 299 | 300 | 301 | if __name__ == "__main__": 302 | # Logging setup 303 | FORMAT = "%(asctime)s | %(message)s" 304 | logging.basicConfig(format=FORMAT) 305 | logger.setLevel(level=logging.DEBUG) 306 | 307 | # Websockets setup 308 | websockets_logger = logging.getLogger("websockets") 309 | websockets_logger.setLevel(logging.INFO) 310 | websockets_logger.addHandler(logging.StreamHandler()) 311 | 312 | # Run main 313 | asyncio.get_event_loop().run_until_complete(main()) 314 | -------------------------------------------------------------------------------- /run_websocket_server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import numpy as np 4 | import torch.version 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | import syft as sy 8 | from syft.workers import websocket_server 9 | 10 | KEEP_LABELS_DICT = { 11 | "alice": [5, 6, 7, 8, 9], 12 | "bob": [0, 1, 2, 3], 13 | "charlie": [list(range(10))], 14 | "dave": list(range(6)), 15 | "eva": [7, 8, 9], 16 | "frank": [list(range(10))], 17 | "frank1": [list(range(10))], 18 | "testing": list(range(10)), 19 | None: list(range(10)), 20 | } 21 | 22 | 23 | def start_websocket_server_worker(id, host, port, hook, verbose, keep_labels=None, training=True): 24 | """Helper function for spinning up a websocket server and setting up the local datasets.""" 25 | 26 | server = websocket_server.WebsocketServerWorker( 27 | id=id, host=host, port=port, hook=hook, verbose=verbose 28 | ) 29 | 30 | # Setup toy data (mnist example) 31 | mnist_dataset = datasets.MNIST( 32 | root="./data", 33 | train=training, 34 | download=True, 35 | transform=transforms.Compose( 36 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 37 | ), 38 | ) 39 | 40 | 41 | 42 | if training: 43 | indices = np.isin(mnist_dataset.targets, keep_labels).astype("uint8") 44 | logger.info("number of true indices: %s", indices.sum()) 45 | selected_data = ( 46 | 
torch.native_masked_select(mnist_dataset.data.transpose(0, 2), torch.tensor(indices)) 47 | .view(28, 28, -1) 48 | .transpose(2, 0) 49 | ) 50 | logger.info("after selection: %s", selected_data.shape) 51 | selected_targets = torch.native_masked_select(mnist_dataset.targets, torch.tensor(indices)) 52 | 53 | dataset = sy.BaseDataset( 54 | data=selected_data, targets=selected_targets, transform=mnist_dataset.transform 55 | ) 56 | key = "mnist" 57 | else: 58 | dataset = sy.BaseDataset( 59 | data=mnist_dataset.data, 60 | targets=mnist_dataset.targets, 61 | transform=mnist_dataset.transform, 62 | ) 63 | key = "mnist_testing" 64 | 65 | server.add_dataset(dataset, key=key) 66 | count = [0] * 10 67 | logger.info( 68 | "MNIST dataset (%s set), available numbers on %s: ", "train" if training else "test", id 69 | ) 70 | for i in range(10): 71 | count[i] = (dataset.targets == i).sum().item() 72 | logger.info(" %s: %s", i, count[i]) 73 | 74 | ''' 75 | logger.info("datasets: %s", server.datasets) 76 | if training: 77 | logger.info("len(datasets[mnist]): %s", len(server.datasets[key])) 78 | ''' 79 | 80 | server.start() 81 | return server 82 | 83 | 84 | if __name__ == "__main__": 85 | # Logging setup 86 | FORMAT = "%(asctime)s | %(message)s" 87 | logging.basicConfig(format=FORMAT) 88 | logger = logging.getLogger("run_websocket_server") 89 | logger.setLevel(level=logging.DEBUG) 90 | 91 | # Parse args 92 | parser = argparse.ArgumentParser(description="Run websocket server worker.") 93 | parser.add_argument( 94 | "--port", 95 | "-p", 96 | type=int, 97 | help="port number of the websocket server worker, e.g. --port 8777", 98 | ) 99 | parser.add_argument( 100 | "--host", 101 | type=str, 102 | default="localhost", 103 | help="host for the connection") 104 | parser.add_argument( 105 | "--id", 106 | type=str, 107 | help="name (id) of the websocket server worker, e.g. 
--id alice" 108 | ) 109 | parser.add_argument( 110 | "--testing", 111 | action="store_true", 112 | help="if set, websocket server worker will load the test dataset instead of the training dataset", 113 | ) 114 | parser.add_argument( 115 | "--verbose", 116 | "-v", 117 | action="store_true", 118 | help="if set, websocket server worker will be started in verbose mode", 119 | ) 120 | 121 | args = parser.parse_args() 122 | 123 | # Hook and start server 124 | hook = sy.TorchHook(torch) 125 | server = start_websocket_server_worker( 126 | id=args.id, 127 | host=args.host, 128 | port=args.port, 129 | hook=hook, 130 | verbose=args.verbose, 131 | keep_labels=KEEP_LABELS_DICT[args.id], 132 | training=not args.testing, 133 | ) 134 | -------------------------------------------------------------------------------- /start_websocket_servers.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | from torchvision import datasets 4 | from torchvision import transforms 5 | 6 | import signal 7 | import sys 8 | ''' 9 | python run_websocket_server.py --port 8777 --id alice --host 0.0.0.0 10 | python run_websocket_server.py --port 8778 --id bob --host 0.0.0.0 11 | python run_websocket_server.py --port 8779 --id charlie --host 0.0.0.0 12 | python run_websocket_server.py --port 8780 --id testing --testing --host 0.0.0.0 13 | 14 | python run_websocket_server.py --port 8783 --id testing --testing --host localhost 15 | ''' 16 | 17 | # Downloads MNIST dataset 18 | mnist_trainset = datasets.MNIST( 19 | root="./data", 20 | train=True, 21 | download=True, 22 | transform=transforms.Compose( 23 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 24 | ), 25 | ) 26 | 27 | call_alice = [ 28 | "python", 29 | "run_websocket_server.py", 30 | "--port", 31 | "8777", 32 | "--id", 33 | "alice", 34 | "--host", 35 | "0.0.0.0", 36 | ] 37 | 38 | call_bob = [ 39 | "python", 40 | "run_websocket_server.py", 41 | "--port", 42 | "8778", 43 | "--id", 44 | "bob", 45 | "--host", 46 | "0.0.0.0", 47 | ] 48 | 49 | call_charlie = [ 50 | "python", 51 | "run_websocket_server.py", 52 | "--port", 53 | "8779", 54 | "--id", 55 | "charlie", 56 | "--host", 57 | "0.0.0.0", 58 | ] 59 | 60 | call_testing = [ 61 | "python", 62 | "run_websocket_server.py", 63 | "--port", 64 | "8780", 65 | "--id", 66 | "testing", 67 | "--testing", 68 | "--host", 69 | "0.0.0.0", 70 | ] 71 | 72 | print("Starting server for Alice") 73 | process_alice = subprocess.Popen(call_alice) 74 | 75 | print("Starting server for Bob") 76 | process_bob = subprocess.Popen(call_bob) 77 | 78 | print("Starting server for Charlie") 79 | process_charlie = subprocess.Popen(call_charlie) 80 | 81 | print("Starting server for Testing") 82 | process_testing = subprocess.Popen(call_testing) 83 | 84 | 85 | def signal_handler(sig, frame): 86 | print("You pressed Ctrl+C!") 87 | for p in [process_alice, process_bob, process_charlie, process_testing]: 88 | p.terminate() 89 | sys.exit(0) 90 | 91 | 92 | signal.signal(signal.SIGINT, signal_handler) 93 | 94 | # signal.pause() 95 | --------------------------------------------------------------------------------