├── .gitignore ├── LICENSE ├── README.md ├── art └── cifar_10_experiment.png ├── fed_learn ├── __init__.py ├── args_helper.py ├── data_utils.py ├── experiment_utils.py ├── fed_client.py ├── fed_server.py ├── models.py ├── utils.py └── weight_summarizer.py ├── federated_learning.py ├── requirements.txt └── tests └── test_weight_summarizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__ 3 | experiments/ 4 | *.json 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Gábor Vecsei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning mini-framework 2 | 3 | This repo contains a Federated Learning (FL) setup with the Keras (Tensorflow) framework. The purpose is to have the 4 | codebase with which you can run FL experiments easily, for both IID and Non-IID data. 5 | 6 | The two main components are: Server and Client. The **Server** contains the model description, distributes the data 7 | and coordinates the learning. And for all the clients it summarizes the results to update it's own (global) model. 8 | The **Clients** have different random chunks of data and the model description with the global model's weights. From 9 | this initialized status they can start the training on their own dataset for a few iterations. In a real world 10 | scenario the clients are edge devices and the training is running in parallel. 11 | 12 | In this setup the client trainings are running sequentially and you can use only your CPU or just 1 GPU. 13 | 14 | ## Cifar10 - "Shallow" VGG16 15 | 16 | Training with a shallow version of VGG16 on Cifar10 with IID data where we had 100 clients and for each round (global epoch) we used only 17 | 10% of them (selected randomly at each communication round). Every client fitted 1 epoch on "their part" of the data with the batch size of `[blue: 8, orange: 64, gray: 256]` and with learning rate of `0.1`. 18 | 19 | A "single model" training (1 client with all the data) is also shown (`red`) on the graph. Batch size was `256` and the learning rate was: `0.05`. 20 | 21 | 22 | 23 | (The Tensorboard logs are (for each experiment) included in the release, so you can easily visualize them.) 24 | 25 | ## About 26 | 27 | Gábor Vecsei 28 | 29 | - [Website](https://gaborvecsei.com) 30 | - [Twitter](https://twitter.com/GAwesomeBE) 31 | - [LinkedIn](https://www.linkedin.com/in/gaborvecsei) 32 | - [Personal Blog](https://gaborvecsei.wordpress.com/) 33 | - [Github](https://github.com/gaborvecsei) 34 | -------------------------------------------------------------------------------- /art/cifar_10_experiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaborvecsei/Federated-Learning-Mini-Framework/83a2156fe71f7219eaddc9b3b4e5d334a98ad587/art/cifar_10_experiment.png -------------------------------------------------------------------------------- /fed_learn/__init__.py: -------------------------------------------------------------------------------- 1 | from .args_helper import get_args, save_args_as_json, args_as_json 2 | from .data_utils import iid_data_indices, non_iid_data_indices, DataHandler, CifarProcessor, BaseDataProcessor 3 | from .fed_client import Client 4 | from .fed_server import Server 5 | from .models import create_model, set_model_weights 6 | from .utils import get_rid_of_the_models, print_selected_clients, set_working_GPU 7 | from .weight_summarizer import FedAvg, WeightSummarizer 8 | from .experiment_utils import Experiment 9 | -------------------------------------------------------------------------------- /fed_learn/args_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("-n", "--name", help="Name of the experiment", type=str, required=True) 8 | parser.add_argument("-oe", "--overwrite-experiment", help="Overwrite existing experiment", action="store_true", 9 | required=False) 10 | parser.add_argument("-s", "--data-sampling-technique", help="Data sampling technique (IID or Non-IID)", type=str, 11 | default="iid", required=False) 12 | parser.add_argument("-w", "--weights-file", help="Weights file path to load", type=str, required=False) 13 | parser.add_argument("-e", "--global-epochs", help="Number of global (server) epochs", type=int, default=1000, 14 | required=False) 15 | parser.add_argument("-c", "--clients", help="Number of clients", type=int, default=100, required=False) 16 | parser.add_argument("-f", "--fraction", help="Client fraction to use", type=float, default=0.1, 17 | required=False) 18 | parser.add_argument("-d", "--debug", help="Debugging", action="store_true", required=False) 19 | 20 | parser.add_argument("-lr", "--learning-rate", help="Learning rate", type=float, default=0.15, required=False) 21 | parser.add_argument("-b", "--batch-size", help="Batch Size", type=int, default=32, required=False) 22 | parser.add_argument("-ce", "--client-epochs", help="Number of epochs for the clients", type=int, default=1, 23 | required=False) 24 | parser.add_argument("-g", "--gpu", help="GPU to use (-1 is CPU)", type=int, default=0, required=False) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def args_as_json(args): 30 | json_str = json.dumps(args.__dict__, sort_keys=True, indent=4) 31 | return json_str 32 | 33 | 34 | def save_args_as_json(args, path): 35 | json_str = args_as_json(args) 36 | 37 | with open(str(path), "w") as f: 38 | f.write(json_str) 39 | -------------------------------------------------------------------------------- /fed_learn/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | import numpy as np 5 | from keras import utils 6 | 7 | from fed_learn.fed_client import Client 8 | 9 | 10 | def iid_data_indices(nb_clients: int, labels: np.ndarray): 11 | labels = labels.flatten() 12 | data_len = len(labels) 13 | indices = np.arange(data_len) 14 | np.random.shuffle(indices) 15 | chunks = np.array_split(indices, nb_clients) 16 | return chunks 17 | 18 | 19 | def non_iid_data_indices(nb_clients: int, labels: np.ndarray, nb_shards: int = 200): 20 | labels = labels.flatten() 21 | data_len = len(labels) 22 | 23 | indices = np.arange(data_len) 24 | indices = indices[labels.argsort()] 25 | 26 | shards = np.array_split(indices, nb_shards) 27 | random.shuffle(shards) 28 | shards_for_users = np.array_split(shards, nb_clients) 29 | indices_for_users = [np.hstack(x) for x in shards_for_users] 30 | 31 | return indices_for_users 32 | 33 | 34 | class BaseDataProcessor: 35 | def __init__(self): 36 | pass 37 | 38 | @staticmethod 39 | def pre_process(x: np.ndarray, y: np.ndarray, nb_classes: int): 40 | raise NotImplementedError 41 | 42 | 43 | class CifarProcessor(BaseDataProcessor): 44 | def __init__(self): 45 | super().__init__() 46 | 47 | @staticmethod 48 | def pre_process(x: np.ndarray, y: np.ndarray, nb_classes: int): 49 | y = utils.to_categorical(y, nb_classes) 50 | x = x.astype(np.float32) 51 | x /= 255.0 52 | return x, y 53 | 54 | 55 | class DataHandler: 56 | def __init__(self, 57 | x_train: np.ndarray, 58 | y_train: np.ndarray, 59 | x_test: np.ndarray, 60 | y_test: np.ndarray, 61 | preprocessor: BaseDataProcessor, 62 | only_debugging: bool = True): 63 | self._nb_classes = len(np.unique(y_train)) 64 | self._preprocessor = preprocessor 65 | 66 | if only_debugging: 67 | x_train = x_train[:100] 68 | y_train = y_train[:100] 69 | x_test = x_test[:100] 70 | y_test = y_test[:100] 71 | 72 | self.x_train = x_train 73 | self.y_train = y_train 74 | self.x_test = x_test 75 | self.y_test = y_test 76 | 77 | def _sample(self, sampling_technique: str, nb_clients: int): 78 | if sampling_technique.lower() == "iid": 79 | sampler_fn = iid_data_indices 80 | else: 81 | sampler_fn = non_iid_data_indices 82 | client_data_indices = sampler_fn(nb_clients, self.y_train) 83 | return client_data_indices 84 | 85 | def preprocess(self, x, y): 86 | x, y = self._preprocessor.pre_process(x, y, self._nb_classes) 87 | return x, y 88 | 89 | def assign_data_to_clients(self, clients: List[Client], sampling_technique: str): 90 | sampled_data_indices = self._sample(sampling_technique, len(clients)) 91 | for client, data_indices in zip(clients, sampled_data_indices): 92 | x = self.x_train[data_indices] 93 | y = self.y_train[data_indices] 94 | x, y = self.preprocess(x, y) 95 | client.receive_data(x, y) 96 | -------------------------------------------------------------------------------- /fed_learn/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | from swiss_army_tensorboard import tfboard_loggers 5 | 6 | import fed_learn 7 | 8 | 9 | class Experiment: 10 | def __init__(self, experiment_folder_path: Path, overwrite_if_exists: bool = False): 11 | self.experiment_folder_path = experiment_folder_path 12 | 13 | if self.experiment_folder_path.is_dir(): 14 | if overwrite_if_exists: 15 | shutil.rmtree(str(self.experiment_folder_path)) 16 | else: 17 | raise Exception("Experiment already exists") 18 | 19 | self.experiment_folder_path.mkdir(parents=True, exist_ok=False) 20 | 21 | self.args_json_path = self.experiment_folder_path / "args.json" 22 | 23 | self.train_hist_path = self.experiment_folder_path / "fed_learn_global_test_results.json" 24 | self.global_weight_path = self.experiment_folder_path / "global_weights.h5" 25 | 26 | def serialize_args(self, args): 27 | fed_learn.save_args_as_json(args, self.args_json_path) 28 | tfboard_loggers.TFBoardTextLogger(self.experiment_folder_path).log_markdown("args", "```\n{0}\n```".format( 29 | fed_learn.args_as_json(args)), -1) 30 | 31 | def create_scalar_logger(self) -> tfboard_loggers.TFBoardScalarLogger: 32 | tf_scalar_logger = tfboard_loggers.TFBoardScalarLogger(self.experiment_folder_path) 33 | return tf_scalar_logger 34 | -------------------------------------------------------------------------------- /fed_learn/fed_client.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from keras import models 4 | 5 | import fed_learn 6 | 7 | 8 | class Client: 9 | def __init__(self, id: int): 10 | self.id = id 11 | self.model: models.Model = None 12 | self.x_train = None 13 | self.y_train = None 14 | 15 | def _init_model(self, model_fn: Callable, model_weights): 16 | model = model_fn() 17 | fed_learn.set_model_weights(model, model_weights) 18 | self.model = model 19 | 20 | def receive_data(self, x, y): 21 | self.x_train = x 22 | self.y_train = y 23 | 24 | def receive_and_init_model(self, model_fn: Callable, model_weights): 25 | self._init_model(model_fn, model_weights) 26 | 27 | def edge_train(self, client_train_dict: dict): 28 | if self.model is None: 29 | raise ValueError("Model is not created for client: {0}".format(self.id)) 30 | 31 | hist = self.model.fit(self.x_train, self.y_train, **client_train_dict) 32 | return hist 33 | 34 | def reset_model(self): 35 | fed_learn.get_rid_of_the_models(self.model) 36 | -------------------------------------------------------------------------------- /fed_learn/fed_server.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import numpy as np 4 | from keras import models 5 | 6 | import fed_learn 7 | from fed_learn.weight_summarizer import WeightSummarizer 8 | 9 | 10 | class Server: 11 | def __init__(self, model_fn: Callable, 12 | weight_summarizer: WeightSummarizer, 13 | nb_clients: int = 100, 14 | client_fraction: float = 0.2): 15 | self.nb_clients = nb_clients 16 | self.client_fraction = client_fraction 17 | self.weight_summarizer = weight_summarizer 18 | 19 | # Initialize the global model's weights 20 | self.model_fn = model_fn 21 | model = self.model_fn() 22 | self.global_test_metrics_dict = {k: [] for k in model.metrics_names} 23 | self.global_model_weights = model.get_weights() 24 | fed_learn.get_rid_of_the_models(model) 25 | 26 | self.global_train_losses = [] 27 | self.epoch_losses = [] 28 | 29 | self.clients = [] 30 | self.client_model_weights = [] 31 | 32 | # Training parameters used by the clients 33 | self.client_train_params_dict = {"batch_size": 32, 34 | "epochs": 5, 35 | "verbose": 1, 36 | "shuffle": True} 37 | 38 | def _create_model_with_updated_weights(self) -> models.Model: 39 | model = self.model_fn() 40 | fed_learn.models.set_model_weights(model, self.global_model_weights) 41 | return model 42 | 43 | def send_model(self, client): 44 | client.receive_and_init_model(self.model_fn, self.global_model_weights) 45 | 46 | def init_for_new_epoch(self): 47 | # Reset the collected weights 48 | self.client_model_weights.clear() 49 | # Reset epoch losses 50 | self.epoch_losses.clear() 51 | 52 | def receive_results(self, client): 53 | client_weights = client.model.get_weights() 54 | self.client_model_weights.append(client_weights) 55 | client.reset_model() 56 | 57 | def create_clients(self): 58 | # Create all the clients 59 | for i in range(self.nb_clients): 60 | client = fed_learn.Client(i) 61 | self.clients.append(client) 62 | 63 | def summarize_weights(self): 64 | new_weights = self.weight_summarizer.process(self.client_model_weights) 65 | self.global_model_weights = new_weights 66 | 67 | def get_client_train_param_dict(self): 68 | return self.client_train_params_dict 69 | 70 | def update_client_train_params(self, param_dict: dict): 71 | self.client_train_params_dict.update(param_dict) 72 | 73 | def test_global_model(self, x_test: np.ndarray, y_test: np.ndarray): 74 | model = self._create_model_with_updated_weights() 75 | results = model.evaluate(x_test, y_test, batch_size=32, verbose=1) 76 | 77 | results_dict = dict(zip(model.metrics_names, results)) 78 | for metric_name, value in results_dict.items(): 79 | self.global_test_metrics_dict[metric_name].append(value) 80 | 81 | fed_learn.get_rid_of_the_models(model) 82 | 83 | return results_dict 84 | 85 | def select_clients(self): 86 | nb_clients_to_use = max(int(self.nb_clients * self.client_fraction), 1) 87 | client_indices = np.arange(self.nb_clients) 88 | np.random.shuffle(client_indices) 89 | selected_client_indices = client_indices[:nb_clients_to_use] 90 | return np.asarray(self.clients)[selected_client_indices] 91 | 92 | def save_model_weights(self, path: str): 93 | model = self._create_model_with_updated_weights() 94 | model.save_weights(str(path), overwrite=True) 95 | fed_learn.get_rid_of_the_models(model) 96 | 97 | def load_model_weights(self, path: str, by_name: bool = False): 98 | model = self._create_model_with_updated_weights() 99 | model.load_weights(str(path), by_name=by_name) 100 | self.global_model_weights = model.get_weights() 101 | fed_learn.get_rid_of_the_models(model) 102 | -------------------------------------------------------------------------------- /fed_learn/models.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras import optimizers, losses, models, layers 3 | from keras.applications.vgg16 import VGG16 4 | 5 | 6 | def create_model(input_shape: tuple, nb_classes: int, init_with_imagenet: bool = False, learning_rate: float = 0.01): 7 | weights = None 8 | if init_with_imagenet: 9 | weights = "imagenet" 10 | 11 | model = VGG16(input_shape=input_shape, 12 | classes=nb_classes, 13 | weights=weights, 14 | include_top=False) 15 | # "Shallow" VGG for Cifar10 16 | x = model.get_layer('block3_pool').output 17 | x = layers.Flatten(name='Flatten')(x) 18 | x = layers.Dense(512, activation='relu')(x) 19 | x = layers.Dense(nb_classes)(x) 20 | x = layers.Softmax()(x) 21 | model = models.Model(model.input, x) 22 | 23 | loss = losses.categorical_crossentropy 24 | optimizer = optimizers.SGD(lr=learning_rate, decay=0.99) 25 | 26 | model.compile(optimizer, loss, metrics=["accuracy"]) 27 | return model 28 | 29 | 30 | def set_model_weights(model: models.Model, weight_list): 31 | for i, symbolic_weights in enumerate(model.weights): 32 | weight_values = weight_list[i] 33 | K.set_value(symbolic_weights, weight_values) 34 | -------------------------------------------------------------------------------- /fed_learn/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import List 4 | 5 | from keras import backend as K 6 | 7 | import fed_learn 8 | 9 | 10 | def get_rid_of_the_models(model=None): 11 | """ 12 | This function clears the TF session from the model. 13 | This is needed as TF/Keras models are not automatically cleared, and the memory will be overloaded 14 | """ 15 | 16 | K.clear_session() 17 | if model is not None: 18 | del model 19 | gc.collect() 20 | 21 | 22 | def print_selected_clients(clients: List[fed_learn.fed_client.Client]): 23 | client_ids = [c.id for c in clients] 24 | print("Selected clients for epoch: {0}".format("| ".join(map(str, client_ids)))) 25 | 26 | 27 | def set_working_GPU(gpu_ids: str): 28 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 29 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids 30 | -------------------------------------------------------------------------------- /fed_learn/weight_summarizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | 5 | 6 | class WeightSummarizer: 7 | def __init__(self): 8 | pass 9 | 10 | def process(self, 11 | client_weight_list: List[List[np.ndarray]], 12 | global_weights: Optional[List[np.ndarray]] = None) -> List[np.ndarray]: 13 | raise NotImplementedError() 14 | 15 | 16 | class FedAvg(WeightSummarizer): 17 | def __init__(self, nu: float = 1.0): 18 | """ 19 | Federated Averaging 20 | 21 | :param nu: Controls the summarized client join model fraction to the global model 22 | """ 23 | 24 | super().__init__() 25 | self.nu = nu 26 | 27 | def process(self, 28 | client_weight_list: List[List[np.ndarray]], 29 | global_weights_per_layer: Optional[List[np.ndarray]] = None) -> List[np.ndarray]: 30 | nb_clients = len(client_weight_list) 31 | weights_average = [np.zeros_like(w) for w in client_weight_list[0]] 32 | 33 | for layer_index in range(len(weights_average)): 34 | w = weights_average[layer_index] 35 | if global_weights_per_layer is not None: 36 | global_weight_mtx = global_weights_per_layer[layer_index] 37 | else: 38 | global_weight_mtx = np.zeros_like(w) 39 | for client_weight_index in range(nb_clients): 40 | client_weight_mtx = client_weight_list[client_weight_index][layer_index] 41 | 42 | # TODO: this step should be done at client side (client should send the difference of the weights) 43 | client_weight_diff_mtx = client_weight_mtx - global_weight_mtx 44 | 45 | w += client_weight_diff_mtx 46 | weights_average[layer_index] = (self.nu * w / nb_clients) + global_weight_mtx 47 | return weights_average 48 | -------------------------------------------------------------------------------- /federated_learning.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from keras import datasets 6 | 7 | import fed_learn 8 | 9 | args = fed_learn.get_args() 10 | 11 | fed_learn.set_working_GPU(str(args.gpu)) 12 | 13 | experiment_folder_path = Path(__file__).resolve().parent / "experiments" / args.name 14 | experiment = fed_learn.Experiment(experiment_folder_path, args.overwrite_experiment) 15 | experiment.serialize_args(args) 16 | 17 | tf_scalar_logger = experiment.create_scalar_logger() 18 | 19 | client_train_params = {"epochs": args.client_epochs, "batch_size": args.batch_size} 20 | 21 | 22 | def model_fn(): 23 | return fed_learn.create_model((32, 32, 3), 10, init_with_imagenet=False, learning_rate=args.learning_rate) 24 | 25 | 26 | weight_summarizer = fed_learn.FedAvg() 27 | server = fed_learn.Server(model_fn, 28 | weight_summarizer, 29 | args.clients, 30 | args.fraction) 31 | 32 | weight_path = args.weights_file 33 | if weight_path is not None: 34 | server.load_model_weights(weight_path) 35 | 36 | server.update_client_train_params(client_train_params) 37 | server.create_clients() 38 | 39 | (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data() 40 | data_handler = fed_learn.DataHandler(x_train, y_train, x_test, y_test, fed_learn.CifarProcessor(), args.debug) 41 | data_handler.assign_data_to_clients(server.clients, args.data_sampling_technique) 42 | x_test, y_test = data_handler.preprocess(data_handler.x_test, data_handler.y_test) 43 | 44 | for epoch in range(args.global_epochs): 45 | print("Global Epoch {0} is starting".format(epoch)) 46 | server.init_for_new_epoch() 47 | selected_clients = server.select_clients() 48 | 49 | fed_learn.print_selected_clients(selected_clients) 50 | 51 | for client in selected_clients: 52 | print("Client {0} is starting the training".format(client.id)) 53 | 54 | server.send_model(client) 55 | hist = client.edge_train(server.get_client_train_param_dict()) 56 | server.epoch_losses.append(hist.history["loss"][-1]) 57 | 58 | server.receive_results(client) 59 | 60 | server.summarize_weights() 61 | 62 | epoch_mean_loss = np.mean(server.epoch_losses) 63 | server.global_train_losses.append(epoch_mean_loss) 64 | tf_scalar_logger.log_scalar("train_loss/client_mean_loss", server.global_train_losses[-1], epoch) 65 | print("Loss (client mean): {0}".format(server.global_train_losses[-1])) 66 | 67 | global_test_results = server.test_global_model(x_test, y_test) 68 | print("--- Global test ---") 69 | test_loss = global_test_results["loss"] 70 | test_acc = global_test_results["acc"] 71 | print("{0}: {1}".format("Loss", test_loss)) 72 | print("{0}: {1}".format("Accuracy", test_acc)) 73 | tf_scalar_logger.log_scalar("test_loss/global_loss", test_loss, epoch) 74 | tf_scalar_logger.log_scalar("test_acc/global_acc", test_acc, epoch) 75 | 76 | with open(str(experiment.train_hist_path), 'w') as f: 77 | json.dump(server.global_test_metrics_dict, f) 78 | 79 | # TODO: save only when a condition is fulfilled (validation loss gets better, etc...) 80 | server.save_model_weights(experiment.global_weight_path) 81 | 82 | print("_" * 30) 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16 2 | keras>=2.2.4 3 | -e git://github.com/gaborvecsei/Swiss-Army-Tensorboard.git#egg=swiss_army_tensorboard 4 | -------------------------------------------------------------------------------- /tests/test_weight_summarizer.py: -------------------------------------------------------------------------------- 1 | import fed_learn 2 | import numpy as np 3 | import unittest 4 | 5 | 6 | class TestFedAvgAlgorithm(unittest.TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.weight_summarizer = fed_learn.FedAvg() 10 | 11 | nb_clients = 3 12 | nb_weight_arrays = 6 13 | 14 | self.all_clients_weights = [] 15 | 16 | for i in range(nb_clients): 17 | client_weight_arrays = [] 18 | for k in range(nb_weight_arrays): 19 | rnd_weight_array = np.ones((8, 12)) 20 | rnd_weight_array += i 21 | client_weight_arrays.append(rnd_weight_array) 22 | self.all_clients_weights.append(client_weight_arrays) 23 | 24 | self.avg_weights = self.weight_summarizer.process(self.all_clients_weights) 25 | 26 | def test_basic_averaging_mean(self): 27 | self.assertAlmostEqual(np.mean(self.avg_weights), 2.0) 28 | 29 | def test_basic_averaging_min(self): 30 | self.assertAlmostEqual(np.min(self.avg_weights), 2.0) 31 | 32 | def test_basic_averaging_max(self): 33 | self.assertAlmostEqual(np.max(self.avg_weights), 2.0) 34 | --------------------------------------------------------------------------------