├── tests ├── __init__.py ├── experiments │ ├── __init__.py │ └── test_keras_models.py ├── shared_folder │ ├── __init__.py │ └── test_s3_folder.py ├── test_flwr_base.py ├── test_s3_folder.py └── test_tf_training.py ├── flwr_serverless ├── federated_node │ ├── __init__.py │ ├── aggregatable.py │ ├── sync_federated_node.py │ └── async_federated_node.py ├── shared_folder │ ├── __init__.py │ ├── base_folder.py │ ├── in_memory_folder.py │ ├── local_folder.py │ └── s3_folder.py ├── version.py ├── keras │ ├── __init__.py │ ├── federated_learning_callback.py │ └── example.py ├── __init__.py └── dataset │ └── federated_mnist_dataset.py ├── .dockerignore ├── experiments ├── utils │ ├── federated_learning_runner_cifar10.py │ ├── custom_wandb_callback.py │ ├── non_federated_runner.py │ ├── centralized_runner.py │ ├── base_experiment_runner.py │ └── federated_learning_runner.py ├── dataset │ ├── lotr-paragraphs.json │ ├── tolkien_dataset_builder.py │ └── tolkien_dataset.py ├── experiment_scripts │ ├── centralized.py │ ├── non_federated.py │ ├── async_fedavg.py │ ├── exp1_non_federated.py │ ├── sync_fedavg.py │ └── exp1_mnist_async_fedavg.py ├── model │ ├── simple_mnist_model.py │ └── keras_models.py ├── exp1_mnist.py ├── exp2_cifar10.py ├── get_wandb_tables.py └── exp3_wikitext.py ├── bin ├── build.sh └── train.sh ├── requirements_dev.txt ├── doc └── paper │ ├── graphics │ ├── async_fl.pdf │ └── flower_async_detail.pdf │ ├── README.md │ ├── .gitignore │ ├── references.bib │ └── preprint.sty ├── .gitignore ├── TODO.md ├── Dockerfile ├── requirements.txt ├── setup.py ├── .github └── workflows │ ├── test.yaml │ └── publish.yaml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/shared_folder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flwr_serverless/federated_node/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flwr_serverless/shared_folder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | .vscode/ 3 | .pytest_cache/ 4 | -------------------------------------------------------------------------------- /experiments/utils/federated_learning_runner_cifar10.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | docker build -t flwr-tf . 
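3 | # The resulting flwr-tf image is used by bin/train.sh to run the experiments.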
-------------------------------------------------------------------------------- /flwr_serverless/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.10" 2 | -------------------------------------------------------------------------------- /flwr_serverless/keras/__init__.py: -------------------------------------------------------------------------------- 1 | from .federated_learning_callback import FlwrFederatedCallback 2 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pytest==7.2.1 2 | boto3 3 | moto==4.2.14 4 | tensorflow==2.11.* 5 | keras-cv==0.5.* -------------------------------------------------------------------------------- /experiments/dataset/lotr-paragraphs.json: -------------------------------------------------------------------------------- 1 | [ 2 | "some sample text", 3 | "some sample text again" 4 | ] -------------------------------------------------------------------------------- /doc/paper/graphics/async_fl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kungfuai/flwr_serverless/HEAD/doc/paper/graphics/async_fl.pdf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | **/.DS_Store 3 | **/.env 4 | **/*.pyc 5 | *.egg-info 6 | wandb/ 7 | .vscode/ 8 | env.sh 9 | /build 10 | /dist -------------------------------------------------------------------------------- /doc/paper/graphics/flower_async_detail.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kungfuai/flwr_serverless/HEAD/doc/paper/graphics/flower_async_detail.pdf -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | - keras: scale to 3+ clients 2 | - test real concurrency 3 | - s3 storage backend 4 | - delete old models in sync node 5 | - refactor skewed_split method 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.13.0-gpu 2 | 3 | WORKDIR /workspace 4 | RUN pip install --upgrade pip && \ 5 | pip install flwr==1.5.* keras-cv==0.6.* python-dotenv wandb 6 | 7 | -------------------------------------------------------------------------------- /bin/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | docker run --runtime nvidia -it \ 3 | -v "$(pwd)":/workspace \ 4 | -v "$HOME/.keras":/root/.keras \ 5 | flwr-tf python -m experiments.exp2_cifar10 "$@" -------------------------------------------------------------------------------- /flwr_serverless/federated_node/aggregatable.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from flwr.common import Parameters 3 | 4 | 5 | @dataclass 6 | class Aggregatable: 7 | parameters: Parameters 8 | num_examples: int 9 | metrics: dict 10 | -------------------------------------------------------------------------------- /doc/paper/README.md:
-------------------------------------------------------------------------------- 1 | ## Install 2 | 3 | Follow the [instructions](https://mathjiajia.github.io/vscode-and-latex/#step-1-download--install-tex-live) to install TeX Live, VSCode and the LaTeX Workshop plug-in. 4 | 5 | ## Compile LaTeX files 6 | 7 | Go to the main .tex file and save the file (cmd+S). Saving should trigger compilation of the .tex files into a PDF. 8 | -------------------------------------------------------------------------------- /flwr_serverless/shared_folder/base_folder.py: -------------------------------------------------------------------------------- 1 | class SharedFolder: 2 | def get(self, key, default=None): 3 | ... 4 | 5 | def __getitem__(self, key): 6 | ... 7 | 8 | def __setitem__(self, key, value): 9 | ... 10 | 11 | def __len__(self): 12 | ... 13 | 14 | def items(self): 15 | ... 16 | -------------------------------------------------------------------------------- /flwr_serverless/__init__.py: -------------------------------------------------------------------------------- 1 | from .federated_node.async_federated_node import AsyncFederatedNode 2 | from .federated_node.sync_federated_node import SyncFederatedNode 3 | from .shared_folder.base_folder import SharedFolder 4 | from .shared_folder.s3_folder import S3FolderWithPickle 5 | from .shared_folder.local_folder import LocalFolder 6 | from .version import __version__ 7 | 8 | S3Folder = S3FolderWithPickle 9 | 10 | -------------------------------------------------------------------------------- /flwr_serverless/dataset/federated_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class FederatedMNISTDataset: 6 | """ 7 | Under different partition methods, we can test the effectiveness of a particular federated 8 | learning strategy.
9 | """ 10 | partition_method: str = "by_class" # one of "random", "by_class" 11 | num_partitions: int = 2 12 | images = None # (n, 28, 28) 13 | labels = None # (n,) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flwr<1.7 2 | python-dotenv 3 | 4 | 5 | # Troubleshooting 6 | # 7 | # `grpcio` installation on Apple M1, see 8 | # https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop 9 | # export GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 10 | # export GRPC_PYTHON_BUILD_SYSTEM_ZLIB=1 11 | # 12 | # `tensorflow` installation on Apple M1: 13 | # conda install -c apple tensorflow-deps 14 | # pip install tensorflow-macos 15 | # pip install keras-cv 16 | -------------------------------------------------------------------------------- /flwr_serverless/shared_folder/in_memory_folder.py: -------------------------------------------------------------------------------- 1 | class InMemoryFolder: 2 | def __init__(self): 3 | self.model_store = {} 4 | 5 | def get(self, key, default=None): 6 | return self.model_store[key] if key in self.model_store else default 7 | 8 | def __getitem__(self, key): 9 | return self.model_store[key] 10 | 11 | def __setitem__(self, key, value): 12 | self.model_store[key] = value 13 | 14 | def __delitem__(self, key): 15 | if key in self.model_store: 16 | del self.model_store[key] 17 | 18 | def __len__(self): 19 | return len(self.model_store) 20 | 21 | def items(self): 22 | return self.model_store.items() 23 | 24 | def get_raw_folder(self): 25 | return self 26 | -------------------------------------------------------------------------------- /experiments/experiment_scripts/centralized.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from tensorflow.keras.utils import set_random_seed 3 | 4 | from experiments.utils.centralized_runner import CentralizedRunner 5 | 6 | # main function 7 | if __name__ == "__main__": 8 | # starts a new run 9 | set_random_seed(117) 10 | 11 | num_nodes = 1 12 | dataset = "mnist" 13 | 14 | config = { 15 | "epochs": 128, 16 | "batch_size": 32, 17 | "steps_per_epoch": 8, 18 | "lr": 0.001, 19 | "shuffled:": False, 20 | "num_nodes": num_nodes, 21 | "dataset": dataset, 22 | } 23 | 24 | # federeated run w/ FedAvg 25 | wandb.init( 26 | project="experiments", entity="flwr_serverless", name="centralized", config=config 27 | ) 28 | centralized_runner = CentralizedRunner(config, num_nodes, dataset) 29 | centralized_runner.run() 30 | wandb.finish() 31 | -------------------------------------------------------------------------------- /experiments/utils/custom_wandb_callback.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import wandb 3 | 4 | 5 | class CustomWandbCallback(tf.keras.callbacks.Callback): 6 | def __init__(self, node_i): 7 | self.node_i = node_i 8 | self.node_i_name = f"node{node_i}" 9 | 10 | def on_train_begin(self, logs=None): 11 | pass 12 | 13 | def on_train_end(self, logs=None): 14 | pass 15 | 16 | def on_epoch_begin(self, epoch, logs=None): 17 | pass 18 | 19 | def on_epoch_end(self, epoch, logs=None): 20 | log_dict = { 21 | f"{self.node_i_name}_epoch": epoch, 22 | f"{self.node_i_name}_loss": logs["loss"], 23 | f"{self.node_i_name}_accuracy": logs["accuracy"], 24 | f"{self.node_i_name}_val_loss": logs["val_loss"], 25 | f"{self.node_i_name}_val_accuracy": 
logs["val_accuracy"], 26 | } 27 | 28 | wandb.log(log_dict) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("requirements.txt") as f: 4 | install_requires = f.read().splitlines() 5 | 6 | with open("README.md") as f: 7 | long_description = f.read() 8 | 9 | with open("flwr_serverless/version.py") as f: 10 | version_text = f.read() 11 | __version__ = version_text.split('"')[1] 12 | 13 | setup( 14 | name="flwr_serverless", 15 | version=__version__, 16 | description="A serverless federated learning library based on flwr", 17 | url="https://github.com/kungfuai/flwr_serverless", 18 | author="Kungfu AI", 19 | author_email="zhangzhang.si@gmail.com", 20 | license="MIT", 21 | packages=find_packages("."), 22 | long_description=long_description, 23 | long_description_content_type="text/markdown", 24 | python_requires=">=3.6", 25 | install_requires=install_requires, 26 | include_package_data=True, 27 | zip_safe=False, 28 | ) 29 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: "Test" 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | # Run in all these versions of Python 14 | python-version: [3.9] # 3.8 did not work 15 | 16 | steps: 17 | # Checkout the latest code from the repo 18 | - name: Checkout repo 19 | uses: actions/checkout@v2 20 | # Setup which version of Python to use 21 | - name: Set Up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install package 26 | run: python setup.py install 27 | - name: Install pytest and other dev requirements 28 | run: | 29 | pip install -r requirements_dev.txt 30 | - name: Run tests 31 | run: pytest tests -------------------------------------------------------------------------------- /experiments/model/simple_mnist_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense 4 | from tensorflow.keras.models import Model 5 | 6 | 7 | class SimpleMnistModel: 8 | def __init__(self, lr=0.001): 9 | self.lr = lr 10 | 11 | def run(self): 12 | model = self._build_model() 13 | return self._compile_model(model) 14 | 15 | def _build_model(self): 16 | input = Input(shape=(28, 28, 1)) 17 | x = Conv2D(32, kernel_size=4, activation="relu")(input) 18 | x = MaxPooling2D()(x) 19 | x = Conv2D(16, kernel_size=4, activation="relu")(x) 20 | x = Flatten()(x) 21 | output = Dense(10, activation="softmax")(x) 22 | model = Model(inputs=input, outputs=output) 23 | return model 24 | 25 | def _compile_model(self, model): 26 | model.compile( 27 | optimizer=keras.optimizers.Adam(self.lr), 28 | loss="sparse_categorical_crossentropy", 29 | metrics=["accuracy"], 30 | ) 31 | return model 32 | -------------------------------------------------------------------------------- /experiments/experiment_scripts/non_federated.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from tensorflow.keras.utils import set_random_seed 3 | 4 | from 
experiments.utils.non_federated_runner import NonFederatedRunner 5 | 6 | # main function 7 | if __name__ == "__main__": 8 | # starts a new run 9 | set_random_seed(117) 10 | 11 | num_nodes = 2 12 | use_async = True 13 | shuffled = True # if true, the order of the data is shuffled before partitioning 14 | federated_type = "concurrent" # options: concurrent, sequential, pseudo-concurrent 15 | dataset = "mnist" 16 | strategy = "fedavg" 17 | 18 | config = { 19 | "epochs": 128, 20 | "batch_size": 32, 21 | "steps_per_epoch": 16, 22 | "lr": 0.0004, 23 | "num_nodes": num_nodes, 24 | "use_async": use_async, 25 | "federated_type": federated_type, 26 | "dataset": dataset, 27 | "strategy": strategy, 28 | "shuffled": shuffled, 29 | } 30 | 31 | wandb.init( 32 | project="test-project", entity="flwr_serverless", name="non_federated", config=config 33 | ) 34 | nonfederated_runner = NonFederatedRunner(config, num_nodes, dataset) 35 | nonfederated_runner.run() 36 | wandb.finish() 37 | -------------------------------------------------------------------------------- /experiments/dataset/tolkien_dataset_builder.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch.utils.data import random_split 3 | from .tolkien_dataset import TolkienDataset 4 | 5 | 6 | class TolkienDatasetBuilder: 7 | def __init__(self, filename, model_name, val_percent=0.1): 8 | self.filename = filename 9 | self.val_percent = val_percent 10 | self.model_name = model_name 11 | 12 | def build_datasets(self): 13 | """ 14 | Reads in the dataset 15 | Returns a tuple of training and validation datasets 16 | """ 17 | print(f"Reading data from {self.filename}...") 18 | 19 | df = pd.read_json(self.filename) 20 | df = df.rename(columns={0: "sentences"}) 21 | dataset = TolkienDataset(df, self.model_name) 22 | 23 | train_dataset, val_dataset = self.random_split_dataset(dataset) 24 | 25 | return train_dataset, val_dataset 26 | 27 | def random_split_dataset(self, df): 28 | """ 29 | Takes a pandas dataframe and splits it into a training and validation set 30 | Returns a tuple of training and validation datasets 31 | """ 32 | val_size = int(len(df) * self.val_percent) 33 | train_size = len(df) - val_size 34 | return random_split(df, [train_size, val_size]) # returns a tuple 35 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '**' # Push events to every tag including hierarchical tags like v1.0/beta 7 | branches: 8 | - testpypi 9 | 10 | jobs: 11 | build-n-publish: 12 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@main 16 | - name: Set up Python 3.9 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: "3.9" 20 | - name: Install pypa/build 21 | run: >- 22 | python -m 23 | pip install 24 | build 25 | --user 26 | - name: Build a binary wheel and a source tarball 27 | run: >- 28 | python -m 29 | build 30 | --sdist 31 | --wheel 32 | --outdir dist/ 33 | .
34 | - name: Publish distribution 📦 to Test PyPI 35 | if: startsWith(github.ref, 'refs/heads/testpypi') 36 | uses: pypa/gh-action-pypi-publish@release/v1 37 | with: 38 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 39 | repository_url: https://test.pypi.org/legacy/ 40 | - name: Publish distribution 📦 to PyPI 41 | if: startsWith(github.ref, 'refs/tags') 42 | uses: pypa/gh-action-pypi-publish@release/v1 43 | with: 44 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /experiments/experiment_scripts/async_fedavg.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | # import set_random_seed 4 | from tensorflow.keras.utils import set_random_seed 5 | 6 | from experiments.utils.federated_learning_runner import FederatedLearningRunner 7 | 8 | # main function 9 | if __name__ == "__main__": 10 | # starts a new run 11 | set_random_seed(117) 12 | 13 | num_nodes = 2 14 | use_async = True 15 | federated_type = "concurrent" 16 | dataset = "mnist" 17 | strategy = "fedavg" 18 | data_split = "partitioned" 19 | 20 | config = { 21 | "epochs": 128, 22 | "batch_size": 32, 23 | "steps_per_epoch": 8, 24 | "lr": 0.001, 25 | "num_nodes": num_nodes, 26 | "use_async": use_async, 27 | "federated_type": federated_type, 28 | "dataset": dataset, 29 | "strategy": strategy, 30 | "data_split": data_split, 31 | } 32 | 33 | wandb.init( 34 | project="experiments", 35 | entity="flwr_serverless", 36 | name=f"async_{strategy}_{num_nodes}_nodes_{data_split}_split", 37 | config=config, 38 | ) 39 | federated_learning_runner = FederatedLearningRunner( 40 | config=config, 41 | num_nodes=num_nodes, 42 | use_async=use_async, 43 | federated_type=federated_type, 44 | dataset=dataset, 45 | strategy=strategy, 46 | ) 47 | federated_learning_runner.run() 48 | wandb.finish() 49 | -------------------------------------------------------------------------------- /experiments/experiment_scripts/exp1_non_federated.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from tensorflow.keras.utils import set_random_seed 3 | 4 | from experiments.utils.non_federated_runner import NonFederatedRunner 5 | 6 | # main function 7 | if __name__ == "__main__": 8 | # starts a new run 9 | set_random_seed(117) 10 | 11 | num_nodes = 2 12 | use_async = True 13 | shuffled = True # if true, the order of the data is shuffled before partitioning 14 | federated_type = "concurrent" # options: concurrent, sequential, pseudo-concurrent 15 | dataset = "mnist" 16 | strategy = "fedavg" 17 | 18 | # TODO: grab configs (overrides) from wandb and put them in 19 | # a list called configs. 
20 | # Then, iterate over configs and run the experiment 21 | config = { 22 | "epochs": 128, 23 | "batch_size": 32, 24 | "steps_per_epoch": 16, 25 | "lr": 0.0004, 26 | "num_nodes": num_nodes, 27 | "use_async": use_async, 28 | "federated_type": federated_type, 29 | "dataset": dataset, 30 | "strategy": strategy, 31 | "shuffled": shuffled, 32 | } 33 | 34 | wandb.init( 35 | project="test-project", entity="flwr_serverless", name="non_federated", config=config 36 | ) 37 | nonfederated_runner = NonFederatedRunner(config, num_nodes, dataset) 38 | nonfederated_runner.run() 39 | wandb.finish() 40 | -------------------------------------------------------------------------------- /experiments/experiment_scripts/sync_fedavg.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | # import set_random_seed 4 | from tensorflow.keras.utils import set_random_seed 5 | 6 | from experiments.utils.federated_learning_runner import FederatedLearningRunner 7 | 8 | # main function 9 | if __name__ == "__main__": 10 | # starts a new run 11 | set_random_seed(117) 12 | 13 | num_nodes = 2 14 | use_async = False 15 | federated_type = "concurrent" 16 | dataset = "mnist" 17 | strategy = "fedavg" 18 | data_split = "skewed" 19 | 20 | if use_async: 21 | sync = "async" 22 | else: 23 | sync = "sync" 24 | 25 | config = { 26 | "epochs": 1000, 27 | "batch_size": 32, 28 | "steps_per_epoch": 8, 29 | "lr": 0.001, 30 | "num_nodes": num_nodes, 31 | "use_async": use_async, 32 | "federated_type": federated_type, 33 | "dataset": dataset, 34 | "strategy": strategy, 35 | "data_split": data_split, 36 | } 37 | 38 | wandb.init( 39 | project="sync-vs-async", 40 | entity="flwr_serverless", 41 | name=f"{sync}_{strategy}_{num_nodes}_nodes_{data_split}_split", 42 | config=config, 43 | ) 44 | federated_learning_runner = FederatedLearningRunner( 45 | config=config, 46 | num_nodes=num_nodes, 47 | use_async=use_async, 48 | federated_type=federated_type, 49 | dataset=dataset, 50 | strategy=strategy, 51 | ) 52 | federated_learning_runner.run() 53 | wandb.finish() 54 | -------------------------------------------------------------------------------- /tests/shared_folder/test_s3_folder.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import pytest 3 | import numpy as np 4 | from moto import mock_s3 5 | from flwr_serverless.shared_folder.s3_folder import S3FolderWithPickle 6 | 7 | 8 | @mock_s3 9 | def test_simple_s3_get(): 10 | conn = boto3.resource("s3", region_name="us-east-1") 11 | # We need to create the bucket since this is all in Moto's 'virtual' AWS account 12 | conn.create_bucket(Bucket="mybucket") 13 | s3 = boto3.client("s3", region_name="us-east-1") 14 | s3.put_object(Bucket="mybucket", Key="test_object", Body=b"some content") 15 | body = s3.get_object(Bucket="mybucket", Key="test_object")["Body"].read() 16 | assert body == b"some content" 17 | 18 | 19 | @mock_s3 20 | def test_s3_storage_backend(): 21 | conn = boto3.resource("s3", region_name="us-east-1") 22 | # We need to create the bucket since this is all in Moto's 'virtual' AWS account 23 | conn.create_bucket(Bucket="mybucket") 24 | storage = S3FolderWithPickle(directory="mybucket/experiment1") 25 | with pytest.raises(ValueError): 26 | storage["test"] = None 27 | storage["model_1"] = [0, 1, 2] 28 | assert storage["model_1"] == [0, 1, 2] 29 | 30 | storage["model_1"] = [0, 1, 2, 3] 31 | assert storage["model_1"] == [0, 1, 2, 3] 32 | 33 | 
storage["model_2"] = np.array([0, 1, 5]) 34 | assert np.array_equal(storage["model_2"], np.array([0, 1, 5])) 35 | 36 | keys = [] 37 | for key, _ in storage.items(): 38 | keys.append(key) 39 | assert keys == ["model_1", "model_2"] 40 | -------------------------------------------------------------------------------- /experiments/dataset/tolkien_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TolkienDataset(Dataset): 7 | def __init__(self, df, model_name, max_len=256): 8 | self.model_name = model_name 9 | self.bos_token = "<|startoftext|>" 10 | self.eos_token = "<|endoftext|>" 11 | self.pad_token = "<|pad|>" 12 | self.tokenizer = AutoTokenizer.from_pretrained( 13 | self.model_name, 14 | bos_token=self.bos_token, 15 | eos_token=self.eos_token, 16 | pad_token=self.pad_token, 17 | ) 18 | 19 | self.df = df # pandas dataframe 20 | self.max_len = max_len 21 | 22 | def __len__(self): 23 | return len(self.df) 24 | 25 | def __getitem__(self, idx): 26 | sentence = self.df["sentences"][idx] 27 | encodings_dict = self.encode_text(sentence) 28 | 29 | # no labels because input_ids will be used as "labels" for CausalLM 30 | return { 31 | "sentence": sentence, 32 | "input_ids": torch.tensor(encodings_dict["input_ids"]), 33 | "attention_mask": torch.tensor(encodings_dict["attention_mask"]), 34 | } 35 | 36 | def encode_text(self, text): 37 | encodings_dict = self.tokenizer( 38 | self.bos_token + text + self.eos_token, 39 | truncation=True, 40 | max_length=self.max_len, 41 | padding="max_length", 42 | ) 43 | return encodings_dict 44 | -------------------------------------------------------------------------------- /experiments/utils/non_federated_runner.py: -------------------------------------------------------------------------------- 1 | from wandb.keras import WandbCallback 2 | 3 | from experiments.utils.base_experiment_runner import BaseExperimentRunner 4 | from experiments.utils.custom_wandb_callback import CustomWandbCallback 5 | 6 | 7 | class NonFederatedRunner(BaseExperimentRunner): 8 | def __init__(self, config, num_nodes, dataset): 9 | super().__init__(config, num_nodes, dataset) 10 | 11 | def run(self): 12 | self.models = self.create_models() 13 | self.train() 14 | self.evaluate() 15 | 16 | def train(self): 17 | ( 18 | self.partitioned_x_train, 19 | self.partitioned_y_train, 20 | self.x_test, 21 | self.y_test, 22 | ) = self.create_partitioned_datasets() 23 | 24 | for i_node in range(self.num_nodes): 25 | train_loader = self.get_train_dataloader_for_node(i_node) 26 | self.models[i_node].fit( 27 | train_loader, 28 | epochs=self.epochs, 29 | steps_per_epoch=self.steps_per_epoch, 30 | callbacks=[ 31 | CustomWandbCallback(i_node), 32 | ], 33 | validation_data=(self.x_test, self.y_test), 34 | validation_steps=self.steps_per_epoch, 35 | validation_batch_size=self.batch_size, 36 | ) 37 | 38 | def evaluate(self): 39 | for i_node in range(self.num_nodes): 40 | loss1, accuracy1 = self.models[i_node].evaluate( 41 | self.x_test, 42 | self.y_test, 43 | batch_size=self.batch_size, 44 | steps=self.steps_per_epoch, 45 | ) 46 | -------------------------------------------------------------------------------- /experiments/utils/centralized_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from wandb.keras import WandbCallback 3 | 4 | 5 | from flwr_serverless.keras.example import 
MnistModelBuilder 6 | 7 | from experiments.utils.base_experiment_runner import BaseExperimentRunner 8 | 9 | 10 | class CentralizedRunner(BaseExperimentRunner): 11 | def __init__(self, config, num_nodes, dataset): 12 | super().__init__(config, num_nodes, dataset) 13 | self.num_nodes = 1 14 | self.test_steps = 10 15 | 16 | def run(self): 17 | self.train_and_eval() 18 | 19 | def train_and_eval(self): 20 | image_size = self.x_train.shape[1] 21 | x_train = np.reshape(self.x_train, [-1, image_size, image_size, 1]) 22 | x_test = np.reshape(self.x_test, [-1, image_size, image_size, 1]) 23 | x_train = x_train.astype(np.float32) / 255 24 | x_test = x_test.astype(np.float32) / 255 25 | 26 | model = MnistModelBuilder(self.lr).run() 27 | 28 | model.fit( 29 | x_train, 30 | self.y_train, 31 | epochs=self.epochs, 32 | batch_size=self.batch_size, 33 | steps_per_epoch=self.steps_per_epoch, 34 | callbacks=[WandbCallback()], 35 | validation_data=( 36 | x_test[: self.test_steps * self.batch_size, ...], 37 | self.y_test[: self.test_steps * self.batch_size, ...], 38 | ), 39 | validation_steps=self.test_steps, 40 | validation_batch_size=self.batch_size, 41 | ) 42 | # memorization test 43 | loss, accuracy = model.evaluate( 44 | x_test, self.y_test, batch_size=self.batch_size, steps=self.steps_per_epoch 45 | ) 46 | -------------------------------------------------------------------------------- /experiments/model/keras_models.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | import keras_cv 3 | from keras_cv.models import ResNet18Backbone, ImageClassifier 4 | 5 | 6 | class ResNetModelBuilder: 7 | def __init__( 8 | self, 9 | lr=0.001, 10 | include_rescaling=False, 11 | num_classes=10, 12 | weights=None, 13 | net="ResNet50", 14 | input_shape=(None, None, 3), 15 | ): 16 | self.lr = lr 17 | self.num_classes = num_classes 18 | self.weights = weights 19 | self.net = net 20 | self.input_shape = input_shape 21 | self.include_rescaling = include_rescaling 22 | 23 | def run(self): 24 | if self.net == "ResNet18": 25 | backbone = ResNet18Backbone() 26 | backbone.layers[2].strides = (1, 1) 27 | # print(backbone.layers[2].get_config()) 28 | model = ImageClassifier(backbone=backbone, num_classes=self.num_classes) 29 | elif self.net == "ResNet50": 30 | backbone = keras_cv.models.ResNet50V2Backbone() 31 | model = ImageClassifier( 32 | backbone=backbone, 33 | num_classes=self.num_classes, 34 | ) 35 | else: 36 | fn = getattr(keras_cv.models, self.net) 37 | model = fn( 38 | include_rescaling=self.include_rescaling, 39 | include_top=True, 40 | weights=self.weights, 41 | classes=self.num_classes, 42 | input_shape=self.input_shape, 43 | ) 44 | model.compile( 45 | loss="sparse_categorical_crossentropy", 46 | optimizer=keras.optimizers.Adam(self.lr), 47 | metrics=["accuracy"], 48 | ) 49 | return model 50 | 51 | 52 | if __name__ == "__main__": 53 | import numpy as np 54 | 55 | model = ResNetModelBuilder(net="ResNet50").run() 56 | example_input = np.random.rand(2, 32, 32, 3) # 2 images, 32x32 pixels, 3 channels 57 | out = model(example_input) 58 | print("output tensor shape:", out.shape) 59 | -------------------------------------------------------------------------------- /tests/test_flwr_base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from numpy import array, float32 4 | 5 | from flwr.common import ( 6 | Code, 7 | FitRes, 8 | NDArrays, 9 | Parameters, 10 | Status, 11 | 
ndarrays_to_parameters, 12 | parameters_to_ndarrays, 13 | ) 14 | from flwr.server.client_proxy import ClientProxy 15 | 16 | from flwr.server.strategy import FedAdagrad 17 | 18 | 19 | def test_aggregate_fit() -> None: 20 | """Tests if adagrad function is aggregating correctly.""" 21 | # Prepare 22 | previous_weights: NDArrays = [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] 23 | strategy = FedAdagrad( 24 | eta=0.1, 25 | eta_l=0.316, 26 | tau=0.5, 27 | initial_parameters=ndarrays_to_parameters(previous_weights), 28 | ) 29 | param_0: Parameters = ndarrays_to_parameters( 30 | [array([0.2, 0.2, 0.2, 0.2], dtype=float32)] 31 | ) 32 | param_1: Parameters = ndarrays_to_parameters( 33 | [array([1.0, 1.0, 1.0, 1.0], dtype=float32)] 34 | ) 35 | # bridge = MagicMock() 36 | # client_0 = GrpcClientProxy(cid="0", bridge=bridge) 37 | # client_1 = GrpcClientProxy(cid="1", bridge=bridge) 38 | client_0 = None 39 | client_1 = None 40 | results: List[Tuple[ClientProxy, FitRes]] = [ 41 | ( 42 | client_0, 43 | FitRes( 44 | status=Status(code=Code.OK, message="Success"), 45 | parameters=param_0, 46 | num_examples=5, 47 | metrics={}, 48 | ), 49 | ), 50 | ( 51 | client_1, 52 | FitRes( 53 | status=Status(code=Code.OK, message="Success"), 54 | parameters=param_1, 55 | num_examples=5, 56 | metrics={}, 57 | ), 58 | ), 59 | ] 60 | expected: NDArrays = [array([0.15, 0.15, 0.15, 0.15], dtype=float32)] 61 | 62 | # Execute 63 | actual_aggregated, _ = strategy.aggregate_fit( 64 | server_round=1, results=results, failures=[] 65 | ) 66 | if actual_aggregated: 67 | actual_list = parameters_to_ndarrays(actual_aggregated) 68 | actual = actual_list[0] 69 | assert (actual == expected[0]).all() -------------------------------------------------------------------------------- /tests/experiments/test_keras_models.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings( 4 | "ignore", 5 | ) 6 | import numpy as np 7 | from flwr.server.strategy import FedAvg 8 | from uuid import uuid4 9 | from flwr_serverless.shared_folder.in_memory_folder import InMemoryFolder 10 | from flwr_serverless.keras.example import ( 11 | FederatedLearningTestRun, 12 | ) 13 | from experiments.model.keras_models import ResNetModelBuilder 14 | 15 | 16 | # This test is slow on cpu. 17 | def _test_mnist_resnet50_federated_callback_2nodes(): 18 | epochs = 8 19 | accuracy_standalone, accuracy_federated = FederatedLearningTestRun( 20 | num_nodes=2, 21 | epochs=epochs, 22 | num_rounds=epochs, 23 | lr=0.001, 24 | strategy=FedAvg(), 25 | model_builder_fn=ResNetModelBuilder( 26 | num_classes=10, 27 | lr=0.001, 28 | net="ResNet50", 29 | weights="imagenet", 30 | ).run, 31 | replicate_num_channels=True, 32 | storage_backend=InMemoryFolder(), 33 | ).run() 34 | for i in range(len(accuracy_standalone)): 35 | assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05 36 | 37 | assert accuracy_federated[0] > accuracy_standalone[0] 38 | assert accuracy_federated[0] > 1.0 / len(accuracy_standalone) + 0.05 39 | # assert False # Uncomment if you want to see the print out of keras training. 
40 | 41 | 42 | # This test fails because of overfitting 43 | def _test_mnist_resnet18_federated_callback_2nodes(): 44 | epochs = 8 45 | accuracy_standalone, accuracy_federated = FederatedLearningTestRun( 46 | num_nodes=2, 47 | epochs=epochs, 48 | num_rounds=epochs, 49 | lr=0.001, 50 | strategy=FedAvg(), 51 | model_builder_fn=ResNetModelBuilder( 52 | num_classes=10, 53 | lr=0.001, 54 | net="ResNet18", 55 | # weights="imagenet", # Does not work with ResNet18 56 | ).run, 57 | replicate_num_channels=True, 58 | storage_backend=InMemoryFolder(), 59 | ).run() 60 | for i in range(len(accuracy_standalone)): 61 | assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05 62 | 63 | assert accuracy_federated[0] > accuracy_standalone[0] 64 | assert accuracy_federated[0] > 1.0 / len(accuracy_standalone) + 0.05 65 | -------------------------------------------------------------------------------- /tests/test_s3_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use moto to mock S3 3 | """ 4 | from pytest import raises 5 | from unittest.mock import patch 6 | import boto3 7 | from moto import mock_s3 8 | from flwr_serverless.shared_folder.s3_folder import ( 9 | S3FolderWithBytes, 10 | S3FolderWithPickle, 11 | ) 12 | 13 | 14 | def test_s3_bytes_folder_read_write_delete(): 15 | with mock_s3(): 16 | s3 = boto3.client("s3") 17 | s3.create_bucket(Bucket="test_bucket") 18 | folder = S3FolderWithBytes( 19 | "test_bucket/test_folder", 20 | retry_sleep_time=0.1, 21 | max_retry=10, 22 | ) 23 | # read and write a dummy file 24 | key = "dummy" 25 | folder[key] = b"dummy" 26 | assert folder[key] == b"dummy" 27 | del folder[key] 28 | 29 | 30 | def test_s3_folder_get_raw_folder_should_not_call_check(): 31 | 32 | with mock_s3(): 33 | s3 = boto3.client("s3") 34 | s3.create_bucket(Bucket="test_bucket") 35 | folder = S3FolderWithPickle( 36 | "test_bucket/test_folder", 37 | retry_sleep_time=0.1, 38 | max_retry=10, 39 | ) 40 | def raise_if_called(): 41 | raise Exception("Should not be called") 42 | with patch.object(folder, "_check", raise_if_called): 43 | folder.get_raw_folder() 44 | 45 | 46 | def test_s3_pickle_folder_read_write_delete(): 47 | with mock_s3(): 48 | s3 = boto3.client("s3") 49 | s3.create_bucket(Bucket="test_bucket") 50 | folder = S3FolderWithPickle( 51 | "test_bucket/test_folder", 52 | retry_sleep_time=0.1, 53 | max_retry=10, 54 | ) 55 | # read and write a dummy file 56 | key = "dummy" 57 | folder[key] = "dummy" 58 | assert folder[key] == "dummy" 59 | del folder[key] 60 | 61 | 62 | def test_when_s3_is_not_accessible(): 63 | with mock_s3(): 64 | s3 = boto3.client("s3") 65 | S3FolderWithBytes( 66 | "test_bucket/test_folder", 67 | retry_sleep_time=0.1, 68 | max_retry=10, 69 | check_at_init=False, 70 | ) 71 | with raises(s3.exceptions.NoSuchBucket): 72 | S3FolderWithBytes( 73 | "test_bucket/test_folder", 74 | retry_sleep_time=0.1, 75 | max_retry=10, 76 | check_at_init=True, 77 | ) 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A Flower ([flwr](https://flower.dev/)) extension for serverless federated learning. 2 | 3 | Technical report (arXiv): [Serverless Federated Learning with `flwr-serverless`](https://arxiv.org/abs/2310.15329). 
4 | 5 | ## Install 6 | 7 | ``` 8 | pip install flwr-serverless 9 | 10 | or 11 | 12 | pip install git+https://github.com/kungfuai/flwr_serverless.git 13 | ``` 14 | 15 | ## Usage for TensorFlow 16 | 17 | - Step 1: Create federated `Node`s that use a shared folder to exchange model weights and use a federated strategy (`flwr.server.strategy.Strategy`) to control how the weights are aggregated. 18 | - Step 2: Create and configure a callback `FlwrFederatedCallback` and use it in `keras.Model.fit()`. 19 | 20 | ```python 21 | # Create a FL Node that has a strategy and a shared folder. 22 | from flwr.server.strategy import FedAvg # This is a flwr federated strategy. 23 | from flwr_serverless import AsyncFederatedNode, S3Folder 24 | from flwr_serverless.keras import FlwrFederatedCallback 25 | 26 | strategy = FedAvg() 27 | shared_folder = S3Folder(directory="mybucket/experiment1") 28 | node = AsyncFederatedNode(strategy=strategy, shared_folder=shared_folder) 29 | 30 | # Create a keras Callback with the FL node. 31 | num_examples_per_epoch = steps_per_epoch * batch_size # number of examples used in each epoch 32 | callback = FlwrFederatedCallback( 33 | node, 34 | num_examples_per_epoch=num_examples_per_epoch, 35 | save_model_before_aggregation=False, 36 | save_model_after_aggregation=False, 37 | ) 38 | 39 | # Join the federated learning by fitting the model with the federated callback. 40 | model = keras.Model(...) 41 | model.compile(...) 42 | model.fit(dataset, callbacks=[callback]) 43 | ``` 44 | 45 | `flwr_serverless` uses `flwr_serverless.SharedFolder` to save model weights and metrics. This logical folder can be backed by a storage backend like S3. 46 | 47 | The asynchronous FL node does not wait to sync with other nodes. It takes the latest 48 | model weights from other nodes and performs the aggregation according to the specified strategy. 49 | 50 | ### Running experiments 51 | 52 | To make it easier to experiment with different strategies, we provide utility classes like `flwr_serverless.keras.example.FederatedLearningTestRun`. This allows you to configure the dataset partition, strategy and concurrency. Please use this as an example to develop your own experiments. 53 | 54 | To reproduce some experiments reported in the paper, run 55 | 56 | ``` 57 | python -m experiments.exp1_mnist 58 | python -m experiments.exp2_cifar10 59 | python -m experiments.exp3_wikitext 60 | ``` 61 | 62 | Each of the above experiments runs through a grid search over a large hyperparameter space, 63 | with repeated trials using different random seeds. Please edit the script to adjust 64 | the number of trials and the hyperparameter search space. -------------------------------------------------------------------------------- /experiments/exp1_mnist.py: -------------------------------------------------------------------------------- 1 | # TensorFlow logging level: warnings or higher 2 | import os 3 | 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 5 | from subprocess import check_output 6 | from tensorflow.keras.utils import set_random_seed 7 | from experiments.utils.federated_learning_runner import FederatedLearningRunner 8 | 9 | 10 | # main function 11 | if __name__ == "__main__": 12 | # starts a new run 13 | from argparse import ArgumentParser 14 | from dotenv import load_dotenv 15 | 16 | load_dotenv() 17 | 18 | parser = ArgumentParser( 19 | description="Run federated learning experiments on MNIST."
20 | ) 21 | 22 | # base config 23 | base_config = { 24 | "project": "mnist", 25 | "epochs": 3, 26 | "batch_size": 32, 27 | "steps_per_epoch": 1200, 28 | "lr": 0.001, 29 | "num_nodes": 2, 30 | "use_async": False, 31 | "federated_type": "concurrent", 32 | "dataset": "mnist", 33 | "strategy": "fedavg", 34 | "data_split": "skewed", 35 | "skew_factor": 0.0, 36 | "test_steps": None, 37 | "net": "simple", 38 | "random_seed": 0, 39 | "track": False, 40 | } 41 | for key, value in base_config.items(): 42 | if isinstance(value, bool): 43 | parser.add_argument(f"--{key}", action="store_true", default=value) 44 | else: 45 | parser.add_argument(f"--{key}", type=type(value), default=value) 46 | 47 | parser.add_argument( 48 | "--use_default_configs", "-u", action="store_true", default=False 49 | ) 50 | 51 | args = parser.parse_args() 52 | if args.use_default_configs: 53 | # Treatments 54 | config_overides = [ 55 | { 56 | "random_seed": random_seed, 57 | "use_async": use_async, 58 | "skew_factor": skew_factor, 59 | "num_nodes": num_nodes, 60 | "strategy": strategy, 61 | } 62 | for random_seed in range(3) 63 | for use_async in [True, False] 64 | for skew_factor in [ 65 | 0, 66 | # 0.1, 67 | # 0.5, 68 | # 0.9, 69 | 0.99, 70 | 1, 71 | ] 72 | for num_nodes in [2, 3, 5] 73 | for strategy in [ 74 | "fedavg", 75 | "fedadam", 76 | # "fedavgm", 77 | ] 78 | ] 79 | else: 80 | config_overide = {} 81 | for key, value in vars(args).items(): 82 | config_overide[key] = value 83 | config_overides = [config_overide] 84 | 85 | for i, config_overide in enumerate(config_overides): 86 | config_overide["track"] = args.track 87 | config = {**base_config, **config_overide} 88 | print( 89 | f"\n***** Starting trial {i + 1} of {len(config_overides)} with config: {str(config)[:80]}...\n" 90 | ) 91 | if args.use_default_configs: 92 | # use subprocess to run this script 93 | command = "python -m experiments.exp1_mnist" 94 | for key, value in config_overide.items(): 95 | if isinstance(value, bool): 96 | if value: 97 | command += f" --{key}" 98 | else: 99 | command += f" --{key} {value}" 100 | print(command) 101 | # wait for the command to finish 102 | check_output(command, shell=True) 103 | else: 104 | federated_learning_runner = FederatedLearningRunner( 105 | config=config, 106 | ) 107 | federated_learning_runner.run() 108 | -------------------------------------------------------------------------------- /experiments/exp2_cifar10.py: -------------------------------------------------------------------------------- 1 | # TensorFlow logging level: warnings or higher 2 | import os 3 | 4 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 5 | from tensorflow.keras.utils import set_random_seed 6 | from subprocess import check_output 7 | from experiments.utils.federated_learning_runner import FederatedLearningRunner 8 | 9 | 10 | # main function 11 | if __name__ == "__main__": 12 | # starts a new run 13 | from argparse import ArgumentParser 14 | from dotenv import load_dotenv 15 | 16 | load_dotenv() 17 | 18 | parser = ArgumentParser( 19 | description="Run federated learning experiments on CIFAR10."
20 | ) 21 | 22 | # base config 23 | base_config = { 24 | "project": "cifar10", 25 | "epochs": 20, 26 | "batch_size": 128, 27 | "steps_per_epoch": 1200, 28 | "lr": 0.0005, 29 | "num_nodes": 2, 30 | "use_async": False, 31 | "federated_type": "concurrent", 32 | "dataset": "cifar10", 33 | "strategy": "fedavg", 34 | "data_split": "skewed", 35 | "skew_factor": 0.9, 36 | "test_steps": None, # 50, 37 | "net": "resnet18", 38 | "random_seed": 0, 39 | "track": False, 40 | } 41 | for key, value in base_config.items(): 42 | if isinstance(value, bool): 43 | parser.add_argument(f"--{key}", action="store_true", default=value) 44 | else: 45 | parser.add_argument(f"--{key}", type=type(value), default=value) 46 | 47 | parser.add_argument( 48 | "--use_default_configs", "-u", action="store_true", default=False 49 | ) 50 | 51 | args = parser.parse_args() 52 | if args.use_default_configs: 53 | # Treatments 54 | # Single node (centralized) training. 55 | config_overides = [ 56 | { 57 | "random_seed": random_seed, 58 | "num_nodes": 1, 59 | } 60 | for random_seed in [None, None] # range(1, 3) 61 | ] 62 | config_overides += [ 63 | { 64 | "random_seed": random_seed, 65 | "use_async": use_async, 66 | "skew_factor": skew_factor, 67 | "num_nodes": num_nodes, 68 | "strategy": strategy, 69 | } 70 | for random_seed in [100, 101] 71 | for use_async in [False] 72 | for skew_factor in [ 73 | 0, 74 | # 0.1, 75 | # 0.5, 76 | # 0.99, 77 | # 1, 78 | 0.9, 79 | ] 80 | for num_nodes in [3, 5, 2] 81 | for strategy in [ 82 | "fedavg", 83 | "fedavgm", 84 | # "fedadam", 85 | ] 86 | ] 87 | 88 | else: 89 | config_overide = {} 90 | for key, value in vars(args).items(): 91 | config_overide[key] = value 92 | config_overides = [config_overide] 93 | 94 | for i, config_overide in enumerate(config_overides): 95 | config_overide["track"] = args.track 96 | config = {**base_config, **config_overide} 97 | print( 98 | f"\n***** Starting trial {i + 1} of {len(config_overides)} with config: {str(config)[:80]}...\n" 99 | ) 100 | if args.use_default_configs: 101 | # use subprocess to run this script 102 | command = "python -m experiments.exp2_cifar10" 103 | for key, value in config_overide.items(): 104 | if isinstance(value, bool): 105 | if value: 106 | command += f" --{key}" 107 | else: 108 | command += f" --{key} {value}" 109 | print(command) 110 | # wait for the command to finish 111 | check_output(command, shell=True) 112 | else: 113 | federated_learning_runner = FederatedLearningRunner( 114 | config=config, 115 | ) 116 | federated_learning_runner.run() 117 | -------------------------------------------------------------------------------- /experiments/experiment_scripts/exp1_mnist_async_fedavg.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | # import set_random_seed 4 | from tensorflow.keras.utils import set_random_seed 5 | from experiments.utils.federated_learning_runner import FederatedLearningRunner 6 | 7 | # main function 8 | if __name__ == "__main__": 9 | from dotenv import load_dotenv 10 | import os 11 | 12 | load_dotenv() 13 | # starts a new run 14 | set_random_seed(117) 15 | 16 | # shared config parameters 17 | num_nodes = 2 18 | federated_type = "concurrent" 19 | dataset = "mnist" 20 | strategy = "fedavg" 21 | epochs = 1000 22 | batch_size = 32 23 | steps_per_epoch = 64 24 | lr = 0.001 25 | 26 | base_config = { 27 | "net": "simple", 28 | "test_steps": None, 29 | } 30 | 31 | # async partitioned 32 | config1 = { 33 | "use_async": True, 34 | "data_split":
"partitioned", 35 | "epochs": epochs, 36 | "batch_size": batch_size, 37 | "steps_per_epoch": steps_per_epoch, 38 | "lr": lr, 39 | "num_nodes": num_nodes, 40 | "federated_type": federated_type, 41 | "dataset": dataset, 42 | "strategy": strategy, 43 | } 44 | 45 | # async skewed 46 | config2 = { 47 | "use_async": True, 48 | "data_split": "skewed", 49 | "epochs": epochs, 50 | "batch_size": batch_size, 51 | "steps_per_epoch": steps_per_epoch, 52 | "lr": lr, 53 | "num_nodes": num_nodes, 54 | "federated_type": federated_type, 55 | "dataset": dataset, 56 | "strategy": strategy, 57 | } 58 | 59 | # async random 60 | config3 = { 61 | "use_async": True, 62 | "data_split": "random", 63 | "epochs": epochs, 64 | "batch_size": batch_size, 65 | "steps_per_epoch": steps_per_epoch, 66 | "lr": lr, 67 | "num_nodes": num_nodes, 68 | "federated_type": federated_type, 69 | "dataset": dataset, 70 | "strategy": strategy, 71 | } 72 | # sync partitioned 73 | config4 = { 74 | "use_async": False, 75 | "data_split": "partitioned", 76 | "epochs": epochs, 77 | "batch_size": batch_size, 78 | "steps_per_epoch": steps_per_epoch, 79 | "lr": lr, 80 | "num_nodes": num_nodes, 81 | "federated_type": federated_type, 82 | "dataset": dataset, 83 | "strategy": strategy, 84 | } 85 | 86 | # sync skewed 87 | config5 = { 88 | "use_async": False, 89 | "data_split": "skewed", 90 | "epochs": epochs, 91 | "batch_size": batch_size, 92 | "steps_per_epoch": steps_per_epoch, 93 | "lr": lr, 94 | "num_nodes": num_nodes, 95 | "federated_type": federated_type, 96 | "dataset": dataset, 97 | "strategy": strategy, 98 | } 99 | 100 | # sync random 101 | config6 = { 102 | "use_async": False, 103 | "data_split": "random", 104 | "epochs": epochs, 105 | "batch_size": batch_size, 106 | "steps_per_epoch": steps_per_epoch, 107 | "lr": lr, 108 | "num_nodes": num_nodes, 109 | "federated_type": federated_type, 110 | "dataset": dataset, 111 | "strategy": strategy, 112 | } 113 | 114 | configs = [config1, config2, config3, config4, config5, config6] 115 | 116 | for _config in configs: 117 | config = {**base_config, **_config} 118 | if config["use_async"]: 119 | use_async = "async" 120 | else: 121 | use_async = "sync" 122 | data_split = config["data_split"] 123 | # print(os.getenv("WANDB_PROJECT")) 124 | # wandb.init( 125 | # project=os.getenv("WANDB_PROJECT"), 126 | # entity="flwr_serverless", 127 | # name=f"mnist_{use_async}_{data_split}_split", 128 | # config=config, 129 | # ) 130 | federated_learning_runner = FederatedLearningRunner( 131 | config=config, 132 | tracking=False, 133 | ) 134 | federated_learning_runner.run() 135 | # wandb.finish() 136 | -------------------------------------------------------------------------------- /doc/paper/.gitignore: -------------------------------------------------------------------------------- 1 | ## Core latex/pdflatex auxiliary files: 2 | *.aux 3 | *.lof 4 | *.log 5 | *.lot 6 | *.fls 7 | *.out 8 | *.toc 9 | *.fmt 10 | *.fot 11 | *.cb 12 | *.cb2 13 | .*.lb 14 | 15 | ## Intermediate documents: 16 | *.dvi 17 | *.xdv 18 | *-converted-to.* 19 | # these rules might exclude image files for figures etc. 
20 | # *.ps 21 | # *.eps 22 | # *.pdf 23 | 24 | ## Generated if empty string is given at "Please type another file name for output:" 25 | .pdf 26 | main.pdf 27 | 28 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 29 | *.bbl 30 | *.bcf 31 | *.blg 32 | *-blx.aux 33 | *-blx.bib 34 | *.run.xml 35 | 36 | ## Build tool auxiliary files: 37 | *.fdb_latexmk 38 | *.synctex 39 | *.synctex(busy) 40 | *.synctex.gz 41 | *.synctex.gz(busy) 42 | *.pdfsync 43 | 44 | ## Build tool directories for auxiliary files 45 | # latexrun 46 | latex.out/ 47 | 48 | ## Auxiliary and intermediate files from other packages: 49 | # algorithms 50 | *.alg 51 | *.loa 52 | 53 | # achemso 54 | acs-*.bib 55 | 56 | # amsthm 57 | *.thm 58 | 59 | # beamer 60 | *.nav 61 | *.pre 62 | *.snm 63 | *.vrb 64 | 65 | # changes 66 | *.soc 67 | 68 | # comment 69 | *.cut 70 | 71 | # cprotect 72 | *.cpt 73 | 74 | # elsarticle (documentclass of Elsevier journals) 75 | *.spl 76 | 77 | # endnotes 78 | *.ent 79 | 80 | # fixme 81 | *.lox 82 | 83 | # feynmf/feynmp 84 | *.mf 85 | *.mp 86 | *.t[1-9] 87 | *.t[1-9][0-9] 88 | *.tfm 89 | 90 | #(r)(e)ledmac/(r)(e)ledpar 91 | *.end 92 | *.?end 93 | *.[1-9] 94 | *.[1-9][0-9] 95 | *.[1-9][0-9][0-9] 96 | *.[1-9]R 97 | *.[1-9][0-9]R 98 | *.[1-9][0-9][0-9]R 99 | *.eledsec[1-9] 100 | *.eledsec[1-9]R 101 | *.eledsec[1-9][0-9] 102 | *.eledsec[1-9][0-9]R 103 | *.eledsec[1-9][0-9][0-9] 104 | *.eledsec[1-9][0-9][0-9]R 105 | 106 | # glossaries 107 | *.acn 108 | *.acr 109 | *.glg 110 | *.glo 111 | *.gls 112 | *.glsdefs 113 | *.lzo 114 | *.lzs 115 | *.slg 116 | *.slo 117 | *.sls 118 | 119 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 120 | # *.ist 121 | 122 | # gnuplot 123 | *.gnuplot 124 | *.table 125 | 126 | # gnuplottex 127 | *-gnuplottex-* 128 | 129 | # gregoriotex 130 | *.gaux 131 | *.glog 132 | *.gtex 133 | 134 | # htlatex 135 | *.4ct 136 | *.4tc 137 | *.idv 138 | *.lg 139 | *.trc 140 | *.xref 141 | 142 | # hyperref 143 | *.brf 144 | 145 | # knitr 146 | *-concordance.tex 147 | # TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files 148 | # *.tikz 149 | *-tikzDictionary 150 | 151 | # listings 152 | *.lol 153 | 154 | # luatexja-ruby 155 | *.ltjruby 156 | 157 | # makeidx 158 | *.idx 159 | *.ilg 160 | *.ind 161 | 162 | # minitoc 163 | *.maf 164 | *.mlf 165 | *.mlt 166 | *.mtc[0-9]* 167 | *.slf[0-9]* 168 | *.slt[0-9]* 169 | *.stc[0-9]* 170 | 171 | # minted 172 | _minted* 173 | *.pyg 174 | 175 | # morewrites 176 | *.mw 177 | 178 | # newpax 179 | *.newpax 180 | 181 | # nomencl 182 | *.nlg 183 | *.nlo 184 | *.nls 185 | 186 | # pax 187 | *.pax 188 | 189 | # pdfpcnotes 190 | *.pdfpc 191 | 192 | # sagetex 193 | *.sagetex.sage 194 | *.sagetex.py 195 | *.sagetex.scmd 196 | 197 | # scrwfile 198 | *.wrt 199 | 200 | # svg 201 | svg-inkscape/ 202 | 203 | # sympy 204 | *.sout 205 | *.sympy 206 | sympy-plots-for-*.tex/ 207 | 208 | # pdfcomment 209 | *.upa 210 | *.upb 211 | 212 | # pythontex 213 | *.pytxcode 214 | pythontex-files-*/ 215 | 216 | # tcolorbox 217 | *.listing 218 | 219 | # thmtools 220 | *.loe 221 | 222 | # TikZ & PGF 223 | *.dpth 224 | *.md5 225 | *.auxlock 226 | 227 | # titletoc 228 | *.ptc 229 | 230 | # todonotes 231 | *.tdo 232 | 233 | # vhistory 234 | *.hst 235 | *.ver 236 | 237 | # easy-todo 238 | *.lod 239 | 240 | # xcolor 241 | *.xcp 242 | 243 | # xmpincl 244 | *.xmpi 245 | 246 | # xindy 247 | *.xdy 248 | 249 | # xypic precompiled matrices and outlines 250 | *.xyc 251 | *.xyd 252 | 253 | # endfloat 254 | *.ttt 255 | *.fff 256 | 257 | # 
Latexian 258 | TSWLatexianTemp* 259 | 260 | ## Editors: 261 | # WinEdt 262 | *.bak 263 | *.sav 264 | 265 | # Texpad 266 | .texpadtmp 267 | 268 | # LyX 269 | *.lyx~ 270 | 271 | # Kile 272 | *.backup 273 | 274 | # gummi 275 | .*.swp 276 | 277 | # KBibTeX 278 | *~[0-9]* 279 | 280 | # TeXnicCenter 281 | *.tps 282 | 283 | # auto folder when using emacs and auctex 284 | ./auto/* 285 | *.el 286 | 287 | # expex forward references with \gathertags 288 | *-tags.tex 289 | 290 | # standalone packages 291 | *.sta 292 | 293 | # Makeindex log files 294 | *.lpz 295 | 296 | # xwatermark package 297 | *.xwm 298 | 299 | # REVTeX puts footnotes in the bibliography by default, unless the nofootinbib 300 | # option is specified. Footnotes are the stored in a file with suffix Notes.bib. 301 | # Uncomment the next line to have this generated file ignored. 302 | #*Notes.bib -------------------------------------------------------------------------------- /flwr_serverless/shared_folder/local_folder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pickle 3 | import time 4 | from typing import Any 5 | 6 | 7 | class LocalFolderWithBytes: 8 | def __init__( 9 | self, directory: str = None, retry_sleep_time: int = 3, max_retry: int = 3 10 | ): 11 | self.directory = Path(directory) 12 | self.directory.mkdir(parents=True, exist_ok=True) 13 | self.retry_sleep_time = retry_sleep_time 14 | self.max_retry = max_retry 15 | 16 | def _get_success_flag_file(self, key): 17 | return self.directory / ("success_" + key) 18 | 19 | def _delete_success_flag(self, key): 20 | filepath = self._get_success_flag_file(key) 21 | if filepath.exists(): 22 | filepath.unlink() 23 | 24 | def _put_success_flag(self, key): 25 | filepath = self._get_success_flag_file(key) 26 | # create parent dir 27 | filepath.parent.mkdir(parents=True, exist_ok=True) 28 | with open(filepath, "w") as f: 29 | f.write("") 30 | 31 | def get(self, key, default=None): 32 | success_flag_file = self._get_success_flag_file(key) 33 | patience = self.max_retry 34 | while not success_flag_file.exists(): 35 | print(f"\nwaiting for success flag of {key}") 36 | time.sleep(self.retry_sleep_time) 37 | patience -= 1 38 | if patience == 0: 39 | return default 40 | filepath = self.directory / key 41 | if filepath.exists(): 42 | with open(filepath, "rb") as f: 43 | return f.read() 44 | else: 45 | return default 46 | 47 | def __getitem__(self, key): 48 | return self.get(key) 49 | 50 | def __setitem__(self, key, value: Any): 51 | assert isinstance(value, bytes), f"value must be bytes, but got {type(value)}" 52 | filepath = self.directory / key 53 | if value is None: 54 | raise ValueError("value must not be None") 55 | self._delete_success_flag(key) 56 | # create parent dir 57 | filepath.parent.mkdir(parents=True, exist_ok=True) 58 | with open(filepath, "wb") as f: 59 | f.write(value) 60 | self._put_success_flag(key) 61 | 62 | def __len__(self): 63 | # recursive 64 | return len(list(self.directory.glob("*"))) 65 | 66 | def __delitem__(self, key): 67 | filepath = self.directory / key 68 | if filepath.exists(): 69 | filepath.unlink() 70 | 71 | def items(self): 72 | for filepath in self.directory.glob("*"): 73 | # remove the directory name 74 | key = str(filepath)[len(str(self.directory)) + 1 :] 75 | yield key, self.get(key) 76 | 77 | 78 | class LocalFolder: 79 | def __init__( 80 | self, directory: str = None, retry_sleep_time: int = 3, max_retry: int = 3 81 | ): 82 | self.directory = Path(directory) 83 | 
self.directory.mkdir(parents=True, exist_ok=True) 84 | self.suffix = ".pkl" 85 | self.retry_sleep_time = retry_sleep_time 86 | self.max_retry = max_retry 87 | 88 | def get(self, key, default=None): 89 | success_flag_file = self._get_success_flag_file(key) 90 | patience = self.max_retry 91 | while not success_flag_file.exists(): 92 | print(f"\nwaiting for success flag of {key}") 93 | time.sleep(self.retry_sleep_time) 94 | patience -= 1 95 | if patience == 0: 96 | return default 97 | filepath = self.directory / (key + self.suffix) 98 | if filepath.exists(): 99 | with open(filepath, "rb") as f: 100 | return pickle.load(f) 101 | else: 102 | return default 103 | 104 | def __getitem__(self, key): 105 | return self.get(key) 106 | 107 | def __setitem__(self, key, value: Any): 108 | filepath = self.directory / (key + self.suffix) 109 | if value is None: 110 | raise ValueError("value must not be None") 111 | self._delete_success_flag(key) 112 | with open(filepath, "wb") as f: 113 | pickle.dump(value, f) 114 | self._put_success_flag(key) 115 | 116 | def __delitem__(self, key): 117 | filepath = self.directory / (key + self.suffix) 118 | if filepath.exists(): 119 | filepath.unlink() 120 | 121 | def _get_success_flag_file(self, key): 122 | return self.directory / ("success_" + key) 123 | 124 | def _delete_success_flag(self, key): 125 | filepath = self._get_success_flag_file(key) 126 | if filepath.exists(): 127 | filepath.unlink() 128 | 129 | def _put_success_flag(self, key): 130 | filepath = self._get_success_flag_file(key) 131 | # create parent dir 132 | filepath.parent.mkdir(parents=True, exist_ok=True) 133 | with open(filepath, "w") as f: 134 | f.write("") 135 | 136 | def __len__(self): 137 | return len(list(self.directory.glob(f"*{self.suffix}"))) 138 | 139 | def items(self): 140 | for filepath in self.directory.glob(f"*{self.suffix}"): 141 | key_and_parameter = self.get_parameter(filepath) 142 | yield key_and_parameter 143 | 144 | def get_parameter(self, filepath): 145 | with open(filepath, "rb") as f: 146 | try: 147 | key = filepath.name[: -len(self.suffix)] 148 | parameters = self.get(key) 149 | return key, parameters 150 | except EOFError as e: 151 | print(f"EOFError: {e}") 152 | return None, None 153 | 154 | def get_raw_folder(self): 155 | """ 156 | Creates a new LocalFolderWithBytes instance with the same directory. 157 | The "raw folder" is used to store raw bytes. This is different 158 | from the "regular" folder which stores pickled objects. 
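        Usage sketch (added for illustration; the directory path is made up):

            folder = LocalFolder(directory="/tmp/flwr_shared")
            folder["weights"] = [1, 2, 3]      # pickled to /tmp/flwr_shared/weights.pkl
            raw = folder.get_raw_folder()      # LocalFolderWithBytes on the same directory
            raw["model.h5"] = b"..."           # raw values must already be bytes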
159 | """ 160 | return LocalFolderWithBytes( 161 | directory=self.directory, 162 | retry_sleep_time=self.retry_sleep_time, 163 | max_retry=self.max_retry, 164 | ) 165 | -------------------------------------------------------------------------------- /experiments/get_wandb_tables.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from dotenv import load_dotenv 4 | import wandb 5 | 6 | 7 | load_dotenv() 8 | 9 | 10 | def get_run_group( 11 | runs, 12 | is_async: bool = False, 13 | skew_factor: int = 0, 14 | num_nodes: int = 2, 15 | strategy: str = "fedavg", 16 | ): 17 | run_group = [] 18 | for run in runs: 19 | if ( 20 | run.config["use_async"] == is_async 21 | and run.config["skew_factor"] == skew_factor 22 | and run.config["num_nodes"] == num_nodes 23 | and run.config["strategy"] == strategy 24 | ): 25 | if "test_accuracy" in run.summary and run.summary["test_accuracy"] > 0: 26 | run_group.append(run) 27 | return run_group 28 | 29 | 30 | def get_run_group_for_wikitext( 31 | runs, 32 | is_async: bool = False, 33 | num_nodes: int = 2, 34 | strategy: str = "fedavg", 35 | model_name="EleutherAI/pythia-14M", 36 | ): 37 | run_group = [] 38 | for run in runs: 39 | # print( 40 | # f"run config: {run.config['use_async']}, {run.config['num_nodes']}, {run.config.get('strategy', 'fedavg')}" 41 | # ) 42 | # print(f"accuracy: {run.summary.get('eval_accuracy')}") 43 | if ( 44 | run.config["use_async"] == is_async 45 | and run.config["num_nodes"] == num_nodes 46 | and run.config.get("strategy", "fedavg") == strategy 47 | and run.config["model_name"] == model_name 48 | ): 49 | if "eval_accuracy" in run.summary and run.summary["eval_accuracy"] > 0: 50 | run_group.append(run) 51 | return run_group 52 | 53 | 54 | def get_mean_std(run_group, metric="test_accuracy"): 55 | metric_values = [] 56 | for run in run_group: 57 | if metric in run.summary: 58 | metric_values.append(run.summary[metric]) 59 | if len(metric_values) == 0: 60 | return 0, 0 61 | return round(sum(metric_values) / len(metric_values), 3), round( 62 | np.std(metric_values), 3 63 | ) 64 | 65 | 66 | def get_exp1_2_tables(wandb_project: str = "mnist"): 67 | # list runs 68 | api = wandb.Api() 69 | wandb_entity = os.getenv("WANDB_ENTITY") 70 | runs = api.runs( 71 | path=f"{wandb_entity}/{wandb_project}", 72 | # successful only 73 | filters={ 74 | "state": "finished", 75 | }, 76 | ) 77 | 78 | def f(is_async, skew_factor, num_nodes, strategy): 79 | run_group = get_run_group( 80 | runs, 81 | is_async=is_async, 82 | skew_factor=skew_factor, 83 | num_nodes=num_nodes, 84 | strategy=strategy, 85 | ) 86 | print( 87 | f"found {len(run_group)} runs for {strategy}, {num_nodes} nodes, skew {skew_factor}, async {is_async}" 88 | ) 89 | mean, std = get_mean_std(run_group, metric="test_accuracy") 90 | ci95 = 1.96 * std / np.sqrt(len(run_group)) 91 | ci95 = round(ci95, 3) 92 | return f"{mean} $\\pm$ {ci95}".replace("0.", ".") 93 | 94 | # run_group = get_run_group(runs, is_async=False, skew_factor=0, num_nodes=2, strategy="fedavg") 95 | # print(get_mean_std(run_group, metric="test_accuracy")) 96 | for d in [0, 0.9, 0.99]: 97 | latex_table = ( 98 | """ 99 | \\toprule 100 | & \\multicolumn{3}{c}{Number of Nodes} \\\\ 101 | Strategy & 2 & 3 & 5 \\\\ 102 | \\midrule 103 | """ 104 | + f""" 105 | FedAvg & {f(False, d, 2, "fedavg")} & {f(False, d, 3, "fedavg")} & {f(False, d, 5, "fedavg")} \\\\ 106 | FedAvgM & {f(False, d, 2, "fedavgm")} & {f(False, d, 3, "fedavgm")} & {f(False, d, 5, "fedavgm")} \\\\ 
107 | FedAdam & {f(False, d, 2, "fedadam")} & {f(False, d, 3, "fedadam")} & {f(False, d, 5, "fedadam")} \\\\ 108 | \\midrule 109 | 110 | FedAvg (async) & {f(True, d, 2, "fedavg")} & {f(True, d, 3, "fedavg")} & {f(True, d, 5, "fedavg")} \\\\ 111 | FedAvgM (async) & {f(True, d, 2, "fedavgm")} & {f(True, d, 3, "fedavgm")} & {f(True, d, 5, "fedavgm")} \\\\ 112 | FedAdam (async) & {f(True, d, 2, "fedadam")} & {f(True, d, 3, "fedadam")} & {f(True, d, 5, "fedadam")} \\\\ 113 | 114 | \\bottomrule 115 | """ 116 | ) 117 | print(latex_table) 118 | print(f"Above is for disparity {d}") 119 | 120 | # another table with sync, async as rows, skew, accuracy as columns 121 | latex_table = ( 122 | """ 123 | \\toprule 124 | & \\multicolumn{3}{c}{Disparity} \\\\ 125 | Strategy & 0 & 0.9 & 1 \\\\ 126 | \\midrule 127 | """ 128 | + f""" 129 | sync & {f(False, 0.0, 2, "fedavg")} & {f(False, 0.9, 2, "fedavg")} & {f(False, 1, 2, "fedavg")} \\\\ 130 | async & {f(True, 0.0, 2, "fedavg")} & {f(True, 0.9, 2, "fedavg")} & {f(True, 1, 2, "fedavg")} \\\\ 131 | 132 | \\bottomrule 133 | """ 134 | ) 135 | 136 | print(latex_table) 137 | 138 | 139 | def get_exp3_tables(wandb_project="wikitext"): 140 | # list runs 141 | api = wandb.Api() 142 | wandb_entity = os.getenv("WANDB_ENTITY") 143 | runs = api.runs( 144 | path=f"{wandb_entity}/{wandb_project}", 145 | # successful only 146 | filters={ 147 | "state": "finished", 148 | }, 149 | ) 150 | 151 | def f(is_async, skew_factor, num_nodes, strategy): 152 | run_group = get_run_group_for_wikitext( 153 | runs, 154 | is_async=is_async, 155 | num_nodes=num_nodes, 156 | strategy=strategy, 157 | ) 158 | print( 159 | f"found {len(run_group)} runs for {strategy}, {num_nodes} nodes, async {is_async}" 160 | ) 161 | mean, std = get_mean_std(run_group, metric="eval_accuracy") 162 | ci95 = 1.96 * std / np.sqrt(len(run_group)) 163 | ci95 = round(ci95, 3) 164 | return f"{mean} $\\pm$ {ci95}".replace("0.", ".") 165 | 166 | d = 0 167 | # compare num_nodes, sync vs async 168 | latex_table = ( 169 | """ 170 | \\toprule 171 | & \\multicolumn{3}{c}{Number of Nodes} \\\\ 172 | Strategy & 2 & 3 & 5 \\\\ 173 | \\midrule 174 | """ 175 | + f""" 176 | FedAvg & {f(False, d, 2, "fedavg")} & {f(False, d, 3, "fedavg")} & {f(False, d, 5, "fedavg")} \\\\ 177 | FedAvg (async) & {f(True, d, 2, "fedavg")} & {f(True, d, 3, "fedavg")} & {f(True, d, 5, "fedavg")} \\\\ 178 | \\bottomrule 179 | """ 180 | ) 181 | print(latex_table) 182 | 183 | 184 | if __name__ == "__main__": 185 | # get_exp1_2_tables("mnist") 186 | get_exp1_2_tables("cifar10") 187 | # get_exp3_tables("wikitext") 188 | -------------------------------------------------------------------------------- /doc/paper/references.bib: -------------------------------------------------------------------------------- 1 | % Articles 2 | 3 | @article{fed_1, 4 | author = {Konečný J. and McMahan H.B. and Yu F.X. and Richtárik P. and Suresh A.T. and Bacon D.}, 5 | title = {Federated learning: Strategies for improving communication efficiency}, 6 | journal = {arXiv}, 7 | year = {2016}, 8 | doi = {10.48550/ARXIV.1610.05492} 9 | } 10 | 11 | @article{fed_2, 12 | doi = {10.48550/ARXIV.1602.05629}, 13 | author = {McMahan H.B. and Moore E. and Ramage D. and Hampson S. and Agüera y Arcas B.}, 14 | title = {Communication-Efficient Learning of Deep Networks from Decentralized Data}, 15 | journal = {arXiv}, 16 | year = {2016} 17 | } 18 | 19 | @article{fed_iot, 20 | doi = {10.48550/ARXIV.2104.10501}, 21 | author = {Zhou J. and Zhang S. and Lu Q. and Dai W. and Chen M. and Liu X.
and Pirttikangas S. and Shi Y. and Zhang W. and Herrera-Viedma E.}, 22 | title = {A Survey on Federated Learning and its Applications for Accelerating Industrial Internet of Things}, 23 | journal = {arXiv}, 24 | year = {2021} 25 | } 26 | 27 | @article{fed_healthcare, 28 | author = {Joshi M. and Pal A. and Sankarasubbu M.}, 29 | title = {Federated Learning for Healthcare Domain - Pipeline, Applications and Challenges}, 30 | year = {2022}, 31 | volume = {3}, 32 | number = {4}, 33 | doi = {10.1145/3533708}, 34 | journal = {ACM Transactions on Computing for Healthcare} 35 | } 36 | 37 | @article{fed_wireless, 38 | doi = {10.48550/ARXIV.1908.06847}, 39 | author = {Niknam S. and Dhillon H.S. and Reed J.H.}, 40 | title = {Federated Learning for Wireless Communications: Motivation, Opportunities and Challenges}, 41 | journal = {arXiv}, 42 | year = {2019} 43 | } 44 | 45 | @article{flower, 46 | doi = {10.48550/ARXIV.2007.14390}, 47 | author = {Beutel D.J. and Topal T. and Mathur A. and Qiu X. and Fernandez-Marques J. and Gao Y. and Sani L. and Li K.H. and Parcollet T. and de Gusmão P.P.B. and Lane N.D.}, 48 | title = {Flower: A Friendly Federated Learning Research Framework}, 49 | journal = {arXiv}, 50 | year = {2020} 51 | } 52 | 53 | @article{gdpr, 54 | author = {Albrecht J.P.}, 55 | title = {How GDPR will change the world}, 56 | journal = {Eur. Dat. Prot. L. Rev.}, 57 | year = {2016}, 58 | volume = {2} 59 | } 60 | 61 | @article{async_sgd4, 62 | author = {Xu J. and Zhang W. and Wang F.}, 63 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, 64 | title = {A(DP)$^2$SGD: Asynchronous Decentralized Parallel Stochastic Gradient Descent With Differential Privacy}, 65 | year = {2022}, 66 | volume = {44}, 67 | number = {11}, 68 | pages = {8036-8047}, 69 | doi = {10.1109/TPAMI.2021.3107796} 70 | } 71 | 72 | @article{fedasync, 73 | author = {Xie C. and 74 | Koyejo S. and 75 | Gupta I.}, 76 | title = {Asynchronous Federated Optimization}, 77 | journal = {CoRR}, 78 | volume = {abs/1903.03934}, 79 | year = {2019}, 80 | url = {http://arxiv.org/abs/1903.03934} 81 | } 82 | 83 | @article{asofed, 84 | author = {Chen Y. and Sun X. and Jin Y.}, 85 | journal = {IEEE Transactions on Neural Networks and Learning Systems}, 86 | title = {Communication-Efficient Federated Deep Learning With Layerwise Asynchronous Model Update and Temporally Weighted Aggregation}, 87 | year = {2020}, 88 | volume = {31}, 89 | number = {10}, 90 | pages = {4229-4238}, 91 | doi = {10.1109/TNNLS.2019.2953131} 92 | } 93 | 94 | @article{fedsa, 95 | title = {FedSA: A staleness-aware asynchronous Federated Learning algorithm with non-IID data}, 96 | journal = {Future Generation Computer Systems}, 97 | volume = {120}, 98 | pages = {1-12}, 99 | year = {2021}, 100 | doi = {10.1016/j.future.2021.02.012}, 101 | author = {Chen M. and Mao B. and Ma T.} 102 | } 103 | 104 | @article{safa, 105 | author = {Wu W. and He L. and Lin W. and Mao R. and Maple C. and Jarvis S.}, 106 | journal = {IEEE Transactions on Computers}, 107 | title = {SAFA: A Semi-Asynchronous Protocol for Fast Federated Learning With Low Overhead}, 108 | year = {2021}, 109 | volume = {70}, 110 | number = {5}, 111 | pages = {655-668}, 112 | doi = {10.1109/TC.2020.2994391} 113 | } 114 | 115 | @article{semi, 116 | title = {Semi-synchronous federated learning for energy-efficient training and accelerated convergence in cross-silo settings}, 117 | author = {Stripelis D. and Thompson P.M. and Ambite J.L.}, 118 | journal = {ACM Trans. Intell. Syst.
Technol.}, 119 | volume = {13}, 120 | issue = {4}, 121 | number = {78} 122 | } 123 | 124 | % Conference proceedings 125 | 126 | @inproceedings{fed_async1, 127 | author = {Su N. and Li B.}, 128 | booktitle = {2022 IEEE/ACM 30th International Symposium on Quality of Service (IWQoS)}, 129 | title = {How Asynchronous can Federated Learning Be?}, 130 | year = {2022}, 131 | pages = {1-11}, 132 | doi = {10.1109/IWQoS54832.2022.9812885} 133 | } 134 | 135 | @inproceedings{async_sgd1, 136 | title = {Asynchronous decentralized parallel stochastic gradient descent}, 137 | author = {Lian X. and Zhang W. and Zhang C. and Liu J.}, 138 | booktitle = {Proceedings of the 35th International Conference on Machine 139 | Learning}, 140 | year = {2018}, 141 | doi = {10.48550/arXiv.1710.06952} 142 | } 143 | 144 | @inproceedings{fedbuff, 145 | title = {Federated Learning with Buffered Asynchronous Aggregation}, 146 | author = {Nguyen J. and Malik K. and Zhan H. and Yousefpour A. and Rabbat M. and Malek M. and Huba D.}, 147 | booktitle = {Proceedings of the 25th International Conference on Artificial Intelligence and Statistics (AISTATS)}, 148 | year = {2022} 149 | } 150 | 151 | @inproceedings{fedprox, 152 | title = {Federated optimization in heterogeneous networks}, 153 | author = {Li T. and Sahu A.K. and Zaheer M. and Sanjabi M. and Talwalkar A. and Smith V.}, 154 | booktitle = {Proceedings of the 3rd MLSys Conference}, 155 | year = {2020} 156 | } 157 | 158 | @inproceedings{fedbuff2, 159 | author = {Toghani M.T. and Uribe C.A.}, 160 | booktitle = {2022 58th Annual Allerton Conference on Communication, Control, and Computing (Allerton)}, 161 | title = {Unbounded Gradients in Federated Learning with Buffered Asynchronous Aggregation}, 162 | year = {2022}, 163 | pages = {1-8}, 164 | doi = {10.1109/Allerton49937.2022.9929409} 165 | } 166 | 167 | @inproceedings{async_sgd2, 168 | author = {Zheng S. and Meng Q. and Wang T. and Chen W. and Yu N. and Ma Z. and Liu T.}, 169 | title = {Asynchronous Stochastic Gradient Descent with Delay Compensation}, 170 | year = {2017}, 171 | booktitle = {Proceedings of the 34th International Conference on Machine Learning - Volume 70}, 172 | pages = {4120–4129}, 173 | series = {ICML'17} 174 | } 175 | 176 | @inproceedings{async_sgd3, 177 | author = {Langford J. and Smola A.J.
and Zinkevich M.}, 178 | title = {Slow Learners Are Fast}, 179 | year = {2009}, 180 | booktitle = {Proceedings of the 22nd International Conference on Neural Information Processing Systems}, 181 | pages = {2331–2339}, 182 | series = {NIPS'09} 183 | } 184 | 185 | % URLs 186 | 187 | @misc{gpdr_url, 188 | author = {European Union}, 189 | title = {General data protection regulation}, 190 | howpublished = {https://eur-lex.europa.eu/eli/reg/2016/679/oj}, 191 | year = {2016}, 192 | note = {Online; accessed 12/7/2022} 193 | } -------------------------------------------------------------------------------- /flwr_serverless/keras/federated_learning_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | import json 4 | import logging 5 | from typing import Union 6 | from tensorflow import keras 7 | from flwr_serverless.federated_node.async_federated_node import AsyncFederatedNode 8 | from flwr_serverless.federated_node.sync_federated_node import SyncFederatedNode 9 | from flwr.common import ( 10 | NDArrays, 11 | Parameters, 12 | ndarrays_to_parameters, 13 | parameters_to_ndarrays, 14 | ) 15 | 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | class FlwrFederatedCallback(keras.callbacks.Callback): 21 | def __init__( 22 | self, 23 | node: Union[AsyncFederatedNode, SyncFederatedNode], 24 | num_examples_per_epoch: int, 25 | x_test=None, 26 | y_test=None, 27 | test_batch_size=32, 28 | test_steps=10, 29 | postfix_for_federated_metrics="_fed", 30 | override_metrics_with_aggregated_metrics: bool = False, 31 | save_model_before_aggregation: bool = False, 32 | save_model_after_aggregation: bool = False, 33 | **kwargs, 34 | ): 35 | super().__init__(**kwargs) 36 | self.node = node 37 | self.num_examples_per_epoch = num_examples_per_epoch 38 | self.override_metrics_with_aggregated_metrics = ( 39 | override_metrics_with_aggregated_metrics 40 | ) 41 | self.save_model_before_aggregation = save_model_before_aggregation 42 | self.save_model_after_aggregation = save_model_after_aggregation 43 | self.postfix_for_federated_metrics = postfix_for_federated_metrics 44 | self.x_test = x_test 45 | self.y_test = y_test 46 | self.test_batch_size = test_batch_size 47 | self.test_steps = test_steps 48 | self.model_before_aggregation_filename_pattern = ( 49 | "keras/{node_id}/model_before_aggregation_{epoch:05d}.h5" 50 | ) 51 | self.metrics_before_aggregation_filename_pattern = ( 52 | "keras/{node_id}/metrics_before_aggregation_{epoch:05d}.json" 53 | ) 54 | self.model_after_aggregation_filename_pattern = ( 55 | "keras/{node_id}/model_after_aggregation_{epoch:05d}.h5" 56 | ) 57 | self.metrics_after_aggregation_filename_pattern = ( 58 | "keras/{node_id}/metrics_after_aggregation_{epoch:05d}.json" 59 | ) 60 | self._federated_metrics = {} 61 | 62 | def _save_model_to_shared_folder(self, filename: str): 63 | folder = self.node.model_store.get_raw_folder() 64 | key = filename 65 | # convert model into bytes 66 | tmp_path = f"tmp_model_{self.node.node_id}.h5" 67 | self.model.save(tmp_path) 68 | with open(tmp_path, "rb") as f: 69 | model_bytes = f.read() 70 | folder[key] = model_bytes 71 | # delete 72 | os.remove(tmp_path) 73 | 74 | def _save_metrics_to_shared_folder(self, filename: str, metrics: dict): 75 | folder = self.node.model_store.get_raw_folder() 76 | key = filename 77 | metrics_bytes = BytesIO() 78 | simple_metrics = {} 79 | for k, v in metrics.items(): 80 | try: 81 | simple_metrics[k] = float(v) 82 | except (TypeError, ValueError): 83 | pass 84 | json_str =
json.dumps(simple_metrics, indent=2) 85 | metrics_bytes.write(json_str.encode("utf-8")) 86 | folder[key] = metrics_bytes.getvalue() 87 | 88 | @property 89 | def federated_metrics(self): 90 | """Return the metrics from the federated aggregation process.""" 91 | return self._federated_metrics 92 | 93 | def _save_metrics_before_aggregation(self, logs, node_id, epoch): 94 | if logs: 95 | # Save metrics. 96 | filename = self.metrics_before_aggregation_filename_pattern.format( 97 | node_id=node_id, epoch=epoch 98 | ) 99 | self._save_metrics_to_shared_folder(filename, logs) 100 | 101 | def _save_metrics_after_aggregation(self, logs, node_id, epoch): 102 | if logs: 103 | # Save metrics. 104 | filename = self.metrics_after_aggregation_filename_pattern.format( 105 | node_id=node_id, epoch=epoch 106 | ) 107 | self._save_metrics_to_shared_folder(filename, logs) 108 | 109 | def _save_model_before_aggregation(self, node_id, epoch): 110 | if self.save_model_before_aggregation: 111 | filename = self.model_before_aggregation_filename_pattern.format( 112 | node_id=node_id, epoch=epoch 113 | ) 114 | self._save_model_to_shared_folder(filename) 115 | 116 | def _save_model_after_aggregation(self, node_id, epoch): 117 | if self.save_model_after_aggregation: 118 | filename = self.model_after_aggregation_filename_pattern.format( 119 | node_id=node_id, epoch=epoch 120 | ) 121 | self._save_model_to_shared_folder(filename) 122 | 123 | def on_epoch_end(self, epoch: int, logs=None): 124 | # use the P2PStrategy to update the model. 125 | node_id = self.node.node_id 126 | LOGGER.info(f"[flwr_serverless] on_epoch_end, logs={logs}") 127 | 128 | self._save_metrics_before_aggregation(logs, node_id, epoch) 129 | self._save_model_before_aggregation(node_id, epoch) 130 | 131 | params: Parameters = ndarrays_to_parameters(self.model.get_weights()) 132 | if self.override_metrics_with_aggregated_metrics: 133 | metrics = logs 134 | else: 135 | metrics = { 136 | k: v 137 | for k, v in logs.items() 138 | if not k.endswith(self.postfix_for_federated_metrics) 139 | } 140 | 141 | updated_params, updated_metrics = self.node.update_parameters( 142 | params, 143 | num_examples=self.num_examples_per_epoch, 144 | epoch=epoch, 145 | metrics=metrics, 146 | ) 147 | self._federated_metrics = updated_metrics 148 | 149 | self._save_metrics_after_aggregation(updated_metrics, node_id, epoch) 150 | 151 | # Update the keras model and keras logs. 152 | if updated_params is not None: 153 | self.model.set_weights(parameters_to_ndarrays(updated_params)) 154 | self._save_model_after_aggregation(node_id, epoch) 155 | if updated_metrics is not None: 156 | if self.override_metrics_with_aggregated_metrics: 157 | logs.update(updated_metrics) 158 | LOGGER.info( 159 | "[flwr_serverless] Metrics in Keras logs object are overridden." 160 | ) 161 | else: 162 | for key, value in updated_metrics.items(): 163 | logs[f"{key}{self.postfix_for_federated_metrics}"] = value 164 | msg = f"[flwr_serverless] Federated metrics are added to Keras logs object with postfix {self.postfix_for_federated_metrics}."
165 | LOGGER.info(msg) 166 | 167 | if self.x_test is not None: 168 | print("\n=========================== eval inside callback") 169 | self.model.evaluate( 170 | self.x_test, 171 | self.y_test, 172 | batch_size=self.test_batch_size, 173 | steps=self.test_steps, 174 | verbose=2, 175 | ) 176 | print("Done evaluating inside callback =====================\n") 177 | else: 178 | print("No test data provided; skipping evaluation inside callback") 179 | 180 | # Keep track of keras logs 181 | self.logs = logs 182 | if not self.override_metrics_with_aggregated_metrics: 183 | assert any( 184 | key.endswith(self.postfix_for_federated_metrics) 185 | for key in self.logs.keys() 186 | ), f"No federated metrics found in Keras logs object. {logs}" 187 | -------------------------------------------------------------------------------- /flwr_serverless/shared_folder/s3_folder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import time 3 | from typing import Any 4 | 5 | 6 | class S3FolderWithBytes: 7 | def __init__( 8 | self, 9 | directory: str = None, 10 | retry_sleep_time: int = 3, 11 | max_retry: int = 3, 12 | check_at_init: bool = True, 13 | ): 14 | import boto3 15 | 16 | self.directory = directory 17 | if directory.startswith("s3://"): 18 | directory = directory[5:] 19 | parts = directory.split("/", 1) 20 | if len(parts) == 1: 21 | self.bucket = parts[0] 22 | self.prefix = None 23 | else: 24 | self.bucket = parts[0] 25 | self.prefix = parts[1].rstrip("/") 26 | self.retry_sleep_time = retry_sleep_time 27 | self.max_retry = max_retry 28 | self.s3 = boto3.client("s3") 29 | if check_at_init: 30 | self._check() 31 | 32 | def _check(self): 33 | # read and write a dummy file 34 | timestamp_ms = int(time.time() * 1000) 35 | key = f"dummy_{timestamp_ms}" 36 | self[key] = b"dummy" 37 | assert self[key] == b"dummy" 38 | del self[key] 39 | 40 | def _exists(self, key: str): 41 | results = self.s3.list_objects_v2(Bucket=self.bucket, Prefix=key) 42 | return results["KeyCount"] > 0 43 | 44 | def get(self, key, default=None): 45 | success_flag_file = self._get_success_flag_file(key) 46 | patience = self.max_retry 47 | while not self._exists(success_flag_file): 48 | print(f"\nwaiting for success flag of {key}") 49 | time.sleep(self.retry_sleep_time) 50 | patience -= 1 51 | if patience == 0: 52 | return default 53 | if self.prefix is None: 54 | filepath = key 55 | else: 56 | filepath = self.prefix + "/" + key 57 | if self._exists(filepath): 58 | obj = self.s3.get_object(Bucket=self.bucket, Key=filepath) 59 | return obj["Body"].read() 60 | else: 61 | return default 62 | 63 | def __getitem__(self, key): 64 | return self.get(key) 65 | 66 | def __setitem__(self, key, value: Any): 67 | if self.prefix is None: 68 | filepath = key 69 | else: 70 | filepath = self.prefix + "/" + key 71 | if value is None: 72 | raise ValueError("value must not be None") 73 | self._delete_success_flag(key) 74 | assert isinstance(value, bytes), f"value must be bytes, but got {type(value)}" 75 | self.s3.put_object(Bucket=self.bucket, Key=filepath, Body=value) 76 | self._put_success_flag(key) 77 | 78 | def __delitem__(self, key): 79 | if self.prefix is None: 80 | filepath = key 81 | else: 82 | filepath = self.prefix + "/" + key 83 | self.s3.delete_object(Bucket=self.bucket, Key=filepath) 84 | 85 | def _get_success_flag_file(self, key): 86 | if self.prefix is None: 87 | filepath = key + ".success" 88 | else: 89 | filepath = self.prefix + "/" + (key + ".success") 90 | return filepath 91 | 92 | def
_delete_success_flag(self, key): 93 | filepath = self._get_success_flag_file(key) 94 | if self._exists(filepath): 95 | self.s3.delete_object(Bucket=self.bucket, Key=filepath) 96 | 97 | def _put_success_flag(self, key): 98 | filepath = self._get_success_flag_file(key) 99 | self.s3.put_object(Bucket=self.bucket, Key=filepath, Body=b"") 100 | 101 | def __len__(self): 102 | return len(self._list_files()) 103 | 104 | def _list_files(self): 105 | filepaths = [] 106 | res = self.s3.list_objects_v2(Bucket=self.bucket, Prefix=self.prefix) 107 | for obj in res.get("Contents", []): # "Contents" is absent when no objects match 108 | filepaths.append(obj["Key"]) 109 | return filepaths 110 | 111 | def items(self): 112 | for filepath in self._list_files(): 113 | # remove prefix 114 | key = filepath.replace(self.prefix + "/", "") 115 | yield key, self.get(key) 116 | 117 | 118 | class S3FolderWithPickle: 119 | def __init__( 120 | self, 121 | directory: str = None, 122 | retry_sleep_time: int = 3, 123 | max_retry: int = 3, 124 | check_at_init: bool = True, 125 | ): 126 | import boto3 127 | 128 | self.directory = directory 129 | if directory.startswith("s3://"): 130 | directory = directory[5:] 131 | parts = directory.split("/", 1) 132 | if len(parts) == 1: 133 | self.bucket = parts[0] 134 | self.prefix = None 135 | else: 136 | self.bucket = parts[0] 137 | self.prefix = parts[1].rstrip("/") 138 | self.suffix = ".pkl" 139 | self.retry_sleep_time = retry_sleep_time 140 | self.max_retry = max_retry 141 | self.s3 = boto3.client("s3") 142 | 143 | if check_at_init: 144 | self._check() 145 | 146 | def _check(self): 147 | # read and write a dummy file 148 | timestamp_ms = int(time.time() * 1000) 149 | key = f"dummy_{timestamp_ms}" 150 | self[key] = "dummy" 151 | assert self[key] == "dummy" 152 | del self[key] 153 | 154 | def _exists(self, key: str): 155 | results = self.s3.list_objects_v2(Bucket=self.bucket, Prefix=key) 156 | return results["KeyCount"] > 0 157 | 158 | def get_raw_folder(self): 159 | """ 160 | Creates an S3FolderWithBytes object with the same directory. 161 | The "raw folder" is a folder that stores raw bytes. 162 | This is different from the "regular" folder which stores pickled objects.
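        Usage sketch (added for illustration; the bucket and keys are made up,
        and AWS credentials are assumed to be configured):

            folder = S3FolderWithPickle("s3://my-bucket/run-1", check_at_init=False)
            folder["node_a"] = {"epoch": 3}    # pickled to s3://my-bucket/run-1/node_a.pkl
            raw = folder.get_raw_folder()      # same prefix, but values are raw bytes
            raw["model.h5"] = b"..."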
163 | """ 164 | return S3FolderWithBytes( 165 | self.directory, 166 | retry_sleep_time=self.retry_sleep_time, 167 | max_retry=self.max_retry, 168 | check_at_init=False, 169 | ) 170 | 171 | def get(self, key, default=None): 172 | success_flag_file = self._get_success_flag_file(key) 173 | patience = self.max_retry 174 | while not self._exists(success_flag_file): 175 | print(f"\nwaiting for success flag of {key}") 176 | time.sleep(self.retry_sleep_time) 177 | patience -= 1 178 | if patience == 0: 179 | return default 180 | if self.prefix is None: 181 | filepath = key + self.suffix 182 | else: 183 | filepath = self.prefix + "/" + key + self.suffix 184 | if self._exists(filepath): 185 | obj = self.s3.get_object(Bucket=self.bucket, Key=filepath) 186 | return pickle.loads(obj["Body"].read()) 187 | else: 188 | return default 189 | 190 | def __getitem__(self, key): 191 | return self.get(key) 192 | 193 | def __setitem__(self, key, value: Any): 194 | if self.prefix is None: 195 | filepath = key + self.suffix 196 | else: 197 | filepath = self.prefix + "/" + (key + self.suffix) 198 | if value is None: 199 | raise ValueError("value must not be None") 200 | self._delete_success_flag(key) 201 | self.s3.put_object(Bucket=self.bucket, Key=filepath, Body=pickle.dumps(value)) 202 | self._put_success_flag(key) 203 | 204 | def __delitem__(self, key): 205 | if self.prefix is None: 206 | filepath = key + self.suffix 207 | else: 208 | filepath = self.prefix + "/" + (key + self.suffix) 209 | self.s3.delete_object(Bucket=self.bucket, Key=filepath) 210 | 211 | def _get_success_flag_file(self, key): 212 | if self.prefix is None: 213 | filepath = key + ".success" 214 | else: 215 | filepath = self.prefix + "/" + (key + ".success") 216 | return filepath 217 | 218 | def _delete_success_flag(self, key): 219 | filepath = self._get_success_flag_file(key) 220 | if self._exists(filepath): 221 | self.s3.delete_object(Bucket=self.bucket, Key=filepath) 222 | 223 | def _put_success_flag(self, key): 224 | filepath = self._get_success_flag_file(key) 225 | self.s3.put_object(Bucket=self.bucket, Key=filepath, Body=b"") 226 | 227 | def __len__(self): 228 | return len(self._list_pickle_files()) 229 | 230 | def _list_pickle_files(self): 231 | filepaths = [] 232 | res = self.s3.list_objects_v2(Bucket=self.bucket, Prefix=self.prefix) 233 | for obj in res["Contents"]: 234 | if obj["Key"].endswith(self.suffix): 235 | filepaths.append(obj["Key"]) 236 | return filepaths 237 | 238 | def items(self): 239 | for filepath in self._list_pickle_files(): 240 | key_and_parameter = self.get_parameter(filepath) 241 | yield key_and_parameter 242 | 243 | def get_parameter(self, filepath): 244 | model_key = filepath.split("/")[-1].replace(self.suffix, "") 245 | return model_key, self.get(model_key) 246 | -------------------------------------------------------------------------------- /flwr_serverless/federated_node/sync_federated_node.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from uuid import uuid4 3 | import time 4 | import logging 5 | from flwr.common import ( 6 | Code, 7 | FitRes, 8 | Parameters, 9 | Status, 10 | parameters_to_ndarrays, 11 | ) 12 | from flwr.server.client_proxy import ClientProxy 13 | from .aggregatable import Aggregatable 14 | 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | 19 | class SyncFederatedNode: 20 | """ 21 | Synchronous federated learning. 22 | 23 | TODO: allow user to specify the metric names to include and exclude. 
24 | """ 25 | 26 | def __init__(self, shared_folder, strategy, num_nodes: int): 27 | self.node_id = str(uuid4()) 28 | self.counter = 0 29 | self.strategy = strategy 30 | self.model_store = shared_folder 31 | self.seen_models = set() 32 | self.num_nodes = num_nodes 33 | assert self.num_nodes is not None, "num_nodes must be specified" 34 | 35 | def _aggregate(self, aggregatables: List[Aggregatable]) -> Aggregatable: 36 | # Aggregation using the flwr strategy. 37 | results: List[Tuple[ClientProxy, FitRes]] = [ 38 | ( 39 | None, 40 | FitRes( 41 | status=Status(code=Code.OK, message="Success"), 42 | parameters=param_holder.parameters, 43 | num_examples=param_holder.num_examples, 44 | metrics=param_holder.metrics, 45 | ), 46 | ) 47 | for param_holder in aggregatables 48 | ] 49 | 50 | aggregated_parameters, aggregated_metrics = self.strategy.aggregate_fit( 51 | server_round=self.counter + 1, results=results, failures=[] 52 | ) 53 | aggregated_metrics = self._update_aggregated_metrics_in_case_flwr_did_not_do_it( 54 | aggregatables, aggregated_metrics 55 | ) 56 | 57 | self.counter += 1 58 | return Aggregatable( 59 | parameters=aggregated_parameters, 60 | num_examples=sum( 61 | [param_holder.num_examples for param_holder in aggregatables] 62 | ), 63 | metrics=aggregated_metrics, 64 | ) 65 | 66 | def _update_aggregated_metrics_in_case_flwr_did_not_do_it( 67 | self, aggregatables, aggregated_metrics: dict 68 | ) -> dict: 69 | if len(aggregated_metrics) == 0: 70 | aggregated_metrics = {} 71 | aggregated_metrics["num_examples"] = sum( 72 | [param_holder.num_examples for param_holder in aggregatables] 73 | ) 74 | aggregated_metrics["num_nodes"] = len(aggregatables) 75 | first_metric = aggregatables[0].metrics 76 | for k, _ in first_metric.items(): 77 | if k in ["num_nodes", "num_examples"]: 78 | continue 79 | aggregated_metrics[k] = ( 80 | sum( 81 | [ 82 | param_holder.metrics[k] * param_holder.num_examples 83 | for param_holder in aggregatables 84 | ] 85 | ) 86 | / aggregated_metrics["num_examples"] 87 | ) 88 | LOGGER.info(f"Aggregated metrics: {aggregated_metrics}") 89 | return aggregated_metrics 90 | 91 | def _get_parameters_from_other_nodes(self, epoch: int) -> List[Aggregatable]: 92 | print( 93 | f"To get parameters from other nodes. Current node is {self.node_id}. Epoch {epoch}." 94 | ) 95 | other_parameters_from_epoch = [] 96 | 97 | # For debugging 98 | # with open("model_store.txt", "a") as f: 99 | # f.write(f"Current model_store for {self.node_id} on epoch {epoch}:\n") 100 | # for j, (key, value) in enumerate(self.model_store.items()): 101 | # f.write( 102 | # f"[{j}] key: {key}, epoch: {value['epoch']}, node_id: {self.node_id}\n" 103 | # ) 104 | 105 | keys_to_delete = [] 106 | for key, value in self.model_store.items(): 107 | # TODO: `value`` includes model parameters. Separate 108 | # model parameters and metadata. 
109 | if not isinstance(value, dict): 110 | # print("model store item not a dict, skipping") 111 | continue 112 | if "epoch" not in value: 113 | raise KeyError(f"epoch not in the dictionary: {value.keys()}") 114 | if value["node_id"] == self.node_id and value["epoch"] < epoch - 1: 115 | # stale checkpoint from self, delete 116 | keys_to_delete.append(key) 117 | 118 | if value["epoch"] != epoch or value["node_id"] == self.node_id: 119 | continue 120 | other_parameters_from_epoch.append(value["aggregatable"]) 121 | 122 | for key in keys_to_delete: 123 | del self.model_store[key] 124 | 125 | # print("Model store:") 126 | # for kk, v in self.model_store.items(): 127 | # if "epoch" in v: 128 | # print(f"model hash: {kk}, node_id: {v['node_id']}, epoch: {v['epoch']}") 129 | 130 | LOGGER.info( 131 | f"Got {len(other_parameters_from_epoch)} model checkpoints from other nodes." 132 | ) 133 | return other_parameters_from_epoch 134 | 135 | def update_parameters( 136 | self, 137 | local_parameters: Parameters, 138 | num_examples: int = None, 139 | metrics: dict = None, 140 | epoch: int = None, 141 | upload_only=False, 142 | ) -> Tuple[Parameters, dict]: 143 | model_hash = self.node_id + "_" + str(time.time()) 144 | self_aggregatable = Aggregatable( 145 | parameters=local_parameters, 146 | num_examples=num_examples, 147 | metrics=metrics, 148 | ) 149 | if not isinstance(epoch, int): 150 | print(f"Warning! epoch {epoch} is not an int") 151 | # epoch = int(epoch + 0.5) 152 | raise ValueError(f"epoch {epoch} is not an int") 153 | self.model_store[model_hash] = dict( 154 | aggregatable=self_aggregatable, 155 | model_hash=model_hash, 156 | epoch=epoch, 157 | node_id=self.node_id, 158 | ) 159 | print(f"Added local model to model_store: {model_hash}") 160 | # if len(self.model_store) > self.num_nodes: 161 | # raise ValueError( 162 | # f"Too many nodes in the federated learning run: {len(self.model_store)}. Expected {self.num_nodes}" 163 | # ) 164 | if upload_only: 165 | return None 166 | aggregatables_from_other_nodes = self._get_parameters_from_other_nodes(epoch) 167 | wait_counter = 0 168 | max_retry = 3 # cap (in seconds) on each back-off wait; was 60 * 10 169 | while len(aggregatables_from_other_nodes) < self.num_nodes - 1: 170 | # Other nodes have not all sent their parameters yet. 171 | # Wait with exponential back-off. 172 | LOGGER.info( 173 | f"Got {len(aggregatables_from_other_nodes)} parameters from other nodes." 174 | ) 175 | LOGGER.info( 176 | f"Waiting for {self.num_nodes - 1 - len(aggregatables_from_other_nodes)} more."
177 | ) 178 | wait_seconds = min(60 * 2**wait_counter, max_retry) 179 | LOGGER.info(f"Waiting {wait_seconds} seconds..") 180 | time.sleep(wait_seconds) 181 | wait_counter += 1 182 | aggregatables_from_other_nodes = self._get_parameters_from_other_nodes( 183 | epoch 184 | ) 185 | 186 | # Aggregate the parameters from other nodes 187 | parameters_from_all_nodes = aggregatables_from_other_nodes + [self_aggregatable] 188 | aggregated_parameters_and_metrics = self._aggregate(parameters_from_all_nodes) 189 | # Print weight delta 190 | LOGGER.info( 191 | f"Finished weight aggregation for epoch {epoch} at node {self.node_id}" 192 | ) 193 | self._print_weight_delta( 194 | local_parameters, aggregated_parameters_and_metrics.parameters 195 | ) 196 | return ( 197 | aggregated_parameters_and_metrics.parameters, 198 | aggregated_parameters_and_metrics.metrics, 199 | ) 200 | 201 | def _print_weight_delta( 202 | self, previous_weights: Parameters, new_weights: Parameters 203 | ) -> float: 204 | if previous_weights is None: 205 | return 206 | # convert to numpy 207 | previous_weights_np = parameters_to_ndarrays(previous_weights) 208 | new_weights_np = parameters_to_ndarrays(new_weights) 209 | delta = 0 210 | count = 0 211 | for w1, w2 in zip(previous_weights_np, new_weights_np): 212 | delta += float(abs(w1 - w2).sum()) 213 | count += w1.size 214 | avg_l1_diff = delta / float(count) 215 | LOGGER.info(f" Weight delta (average absolute difference): {avg_l1_diff}") 216 | return avg_l1_diff 217 | -------------------------------------------------------------------------------- /flwr_serverless/federated_node/async_federated_node.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Tuple 3 | from uuid import uuid4 4 | import time 5 | from flwr.common import ( 6 | Code, 7 | FitRes, 8 | Parameters, 9 | Status, 10 | parameters_to_ndarrays, 11 | ) 12 | from flwr.server.client_proxy import ClientProxy 13 | from flwr.server.strategy import Strategy 14 | from flwr_serverless.shared_folder.base_folder import SharedFolder 15 | from .aggregatable import Aggregatable 16 | 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | 20 | 21 | class AsyncFederatedNode: 22 | """ 23 | Synchronous version: 24 | 25 | 8 am: 26 | client 1 (faster client) sends params1_1 27 | server has no params yet, so client 1 is told to wait 28 | server keeps params1_1 29 | 30 | 9 am: 31 | client 2 (slower) sends params2_1 (client 1 is waiting from 8 am to 9 am) 32 | server aggregated params1_1 and params2_1, and sends back to client 1 and 2 33 | both client 1 and client 2 updates their local models, and resume training 34 | 35 | 10 am: 36 | client 1: sends params1_2 37 | ... 
38 | 39 | Asynchronous version (client does not wait for the server to get new aggregated weights): 40 | 41 | 8 am: 42 | client 1 sends params1_1 43 | server returns params1_1, and sets params_federated_0 = params1_1 44 | client 1 keeps training with params1_1 for 2 hours 45 | 46 | 9 am: 47 | client 2 sends params2_1 48 | server aggregates params1_1 and params2_1 into params_federated_1 49 | server returns aggregated params_federated_1 50 | client 2 updates its params to params_federated_1 and keeps training 51 | (but client 1 is busy doing its own training now, so it is not updated) 52 | 53 | 10 am: 54 | client 1 sends params1_2 55 | server aggregates params_federated_1 and params1_2 into params_federated_2 56 | server returns aggregated params_federated_2 57 | client 1 updates its params to params_federated_2 and keeps training 58 | 59 | References: 60 | - [Semi-Synchronous Federated Learning for Energy-Efficient 61 | Training and Accelerated Convergence in Cross-Silo Settings](https://arxiv.org/pdf/2102.02849.pdf) 62 | """ 63 | 64 | def __init__( 65 | self, 66 | shared_folder: SharedFolder, 67 | strategy: Strategy, 68 | ignore_seen_models: bool = False, 69 | node_id: str = None, 70 | ): 71 | self.node_id = node_id or str(uuid4()) 72 | self.counter = 0 73 | self.strategy = strategy 74 | self.model_store = shared_folder 75 | self.sample_sizes_from_other_nodes = {} # node_id -> num_examples 76 | self.ignore_seen_models = ignore_seen_models 77 | self.seen_models = set() 78 | 79 | def _aggregate( 80 | self, 81 | aggregatables: List[Aggregatable], 82 | ) -> Aggregatable: 83 | # Aggregation using the flwr strategy. 84 | results: List[Tuple[ClientProxy, FitRes]] = [ 85 | ( 86 | None, 87 | FitRes( 88 | status=Status(code=Code.OK, message="Success"), 89 | parameters=param_holder.parameters, 90 | num_examples=param_holder.num_examples, 91 | metrics=param_holder.metrics, 92 | ), 93 | ) 94 | for param_holder in aggregatables 95 | ] 96 | 97 | aggregated_parameters, aggregated_metrics = self.strategy.aggregate_fit( 98 | server_round=self.counter + 1, results=results, failures=[] 99 | ) 100 | aggregated_metrics = self._update_aggregated_metrics_in_case_flwr_did_not_do_it( 101 | aggregatables, aggregated_metrics 102 | ) 103 | 104 | self.counter += 1 105 | return Aggregatable( 106 | parameters=aggregated_parameters, 107 | num_examples=sum( 108 | [param_holder.num_examples for param_holder in aggregatables] 109 | ), 110 | metrics=aggregated_metrics, 111 | ) 112 | 113 | def _update_aggregated_metrics_in_case_flwr_did_not_do_it( 114 | self, aggregatables, aggregated_metrics: dict 115 | ) -> dict: 116 | if len(aggregated_metrics) == 0: 117 | aggregated_metrics = {} 118 | aggregated_metrics["num_examples"] = sum( 119 | [param_holder.num_examples for param_holder in aggregatables] 120 | ) 121 | aggregated_metrics["num_nodes"] = len(aggregatables) 122 | first_metric = aggregatables[0].metrics 123 | if first_metric is None: 124 | LOGGER.warning(f"No metrics found in {aggregatables[0]}") 125 | return aggregated_metrics 126 | for k, _ in first_metric.items(): 127 | if k in ["num_nodes", "num_examples"]: 128 | continue 129 | aggregated_metrics[k] = ( 130 | sum( 131 | [ 132 | param_holder.metrics[k] * param_holder.num_examples 133 | for param_holder in aggregatables 134 | ] 135 | ) 136 | / aggregated_metrics["num_examples"] 137 | ) 138 | LOGGER.info(f"Aggregated metrics: {aggregated_metrics}") 139 | return aggregated_metrics 140 | 141 | def _get_aggregatables_from_other_nodes(self) -> List[Aggregatable]: 142 | 
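        """Collect checkpoints uploaded by other nodes from the shared folder.

        (Docstring added.) Bookkeeping keys such as accum_num_examples_* are
        skipped, and when ignore_seen_models is True each model_hash is
        returned at most once.
        """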
unseen_parameters_from_other_nodes = [] 143 | for key, value in self.model_store.items(): 144 | if key.startswith("accum_num_examples_"): 145 | continue 146 | if isinstance(value, dict) and "model_hash" in value: 147 | if key != self.node_id: 148 | model_hash = value["model_hash"] 149 | if ( 150 | not self.ignore_seen_models 151 | ) or model_hash not in self.seen_models: 152 | self.seen_models.add(model_hash) 153 | unseen_parameters_from_other_nodes.append(value["aggregatable"]) 154 | return unseen_parameters_from_other_nodes 155 | 156 | def update_parameters( 157 | self, 158 | local_parameters: Parameters, 159 | num_examples: int = None, 160 | metrics: dict = None, 161 | epoch: int = None, 162 | upload_only=False, 163 | ) -> Tuple[Parameters, dict]: 164 | LOGGER.info(f"node {self.node_id}: in update_parameters") 165 | assert isinstance(num_examples, int) 166 | assert num_examples >= 1 167 | self_aggregatable = Aggregatable( 168 | parameters=local_parameters, 169 | num_examples=num_examples, 170 | metrics=metrics, 171 | ) 172 | self.model_store[self.node_id] = dict( 173 | aggregatable=self_aggregatable, 174 | model_hash=self.node_id + "_" + str(time.time()), 175 | epoch=epoch, 176 | node_id=self.node_id, 177 | ) 178 | if upload_only: 179 | return local_parameters, metrics 180 | aggregatables_from_other_nodes = self._get_aggregatables_from_other_nodes() 181 | LOGGER.info( 182 | f"node {self.node_id}: {len(aggregatables_from_other_nodes or [])} aggregatables_from_other_nodes" 183 | ) 184 | if len(aggregatables_from_other_nodes) == 0: 185 | # No other nodes, so just return the local parameters 186 | return local_parameters, metrics 187 | else: 188 | # Aggregate the parameters from other nodes 189 | parameters_from_all_nodes = [ 190 | self_aggregatable 191 | ] + aggregatables_from_other_nodes 192 | updated_aggregatable = self._aggregate(parameters_from_all_nodes) 193 | 194 | # It is counter-productive to set self.model_store[node_id] to the aggregated parameters. 195 | # It makes the accuracy worse.
196 | # self.model_store[self.node_id] = dict( 197 | # parameters=aggregated_parameters, 198 | # model_hash=self.node_id + str(time.time()), 199 | # num_examples=num_examples, 200 | # ) 201 | # TODO: fill in aggregated_metrics 202 | aggregated_parameters = updated_aggregatable.parameters 203 | aggregated_metrics = updated_aggregatable.metrics 204 | 205 | # print the weight delta 206 | LOGGER.info( 207 | f"Finished weight aggregation for epoch {epoch} at node {self.node_id}" 208 | ) 209 | self._print_weight_delta( 210 | previous_weights=local_parameters, 211 | new_weights=aggregated_parameters, 212 | ) 213 | 214 | return aggregated_parameters, aggregated_metrics 215 | 216 | def _print_weight_delta( 217 | self, previous_weights: Parameters, new_weights: Parameters 218 | ) -> float: 219 | if previous_weights is None: 220 | return 221 | # convert to numpy 222 | previous_weights_np = parameters_to_ndarrays(previous_weights) 223 | new_weights_np = parameters_to_ndarrays(new_weights) 224 | delta = 0 225 | count = 0 226 | for w1, w2 in zip(previous_weights_np, new_weights_np): 227 | delta += float(abs(w1 - w2).sum()) 228 | count += w1.size 229 | avg_l1_diff = delta / float(count) 230 | LOGGER.info(f" Weight delta (average absolute difference): {avg_l1_diff}") 231 | return avg_l1_diff 232 | -------------------------------------------------------------------------------- /doc/paper/preprint.sty: -------------------------------------------------------------------------------- 1 | \NeedsTeXFormat{LaTeX2e} 2 | 3 | %% declare nonatbib option, which does not load natbib in case of 4 | %% package clash (users can pass options to natbib via 5 | %% \PassOptionsToPackage) 6 | %\newif\if@natbib\@natbibtrue 7 | %\DeclareOption{nonatbib}{ 8 | % \@natbibfalse 9 | %} 10 | 11 | \ProcessOptions\relax 12 | 13 | % Use roman fonts for equations if available 14 | \IfFileExists{txfonts.sty}% 15 | {\AtEndOfClass{\RequirePackage{txfonts}% 16 | \gdef\ttdefault{cmtt}% 17 | \let\iint\relax 18 | \let\iiint\relax 19 | \let\iiiint\relax 20 | \let\idotsint\relax 21 | \let\openbox\relax}}{\RequirePackage{times}} 22 | 23 | % other font configurations 24 | \renewcommand{\rmdefault}{ptm} 25 | \renewcommand{\sfdefault}{phv} 26 | 27 | %% Load natbib unless told otherwise 28 | %\if@natbib 29 | % \RequirePackage{natbib} 30 | %\fi 31 | 32 | % Define page geometry 33 | \RequirePackage[verbose=true,letterpaper]{geometry} 34 | \AtBeginDocument{ 35 | \newgeometry{ 36 | textheight=9.3in, 37 | textwidth=7.1in, 38 | top=0.96in, 39 | headheight=0.15in, 40 | headsep=0.14in, 41 | footskip=0in 42 | } 43 | } 44 | 45 | % Add background margin tick marks 46 | \RequirePackage{background} 47 | \SetBgScale{1} 48 | \SetBgAngle{0} 49 | \SetBgColor{black} 50 | \SetBgContents{% 51 | \begin{tikzpicture}[remember picture,overlay] 52 | \node at (-3.55in,5.2in) {\rule{.4pt}{.4in}}; 53 | \node at (3.55in,5.2in) {\rule{.4pt}{.4in}}; 54 | \node at (-3.95in,4.8in) {\rule{.4in}{.4pt}}; 55 | \node at (3.95in,4.8in) {\rule{.4in}{.4pt}}; 56 | \node at (-3.55in,-5.2in) {\rule{.4pt}{.4in}}; 57 | \node at (3.55in,-5.2in) {\rule{.4pt}{.4in}}; 58 | \node at (-3.95in,-4.8in) {\rule{.4in}{.4pt}}; 59 | \node at (3.95in,-4.8in) {\rule{.4in}{.4pt}}; 60 | \end{tikzpicture}} 61 | 62 | \widowpenalty=10000 63 | \clubpenalty=10000 64 | \flushbottom 65 | \sloppy 66 | 67 | % easy adding of ORCiDs 68 | \usepackage{scalerel} 69 | \usepackage{tikz} 70 | \usetikzlibrary{svg.path} 71 | \definecolor{orcidlogocol}{HTML}{A6CE39} 72 | \tikzset{ 73 | orcidlogo/.pic={ 74 | \fill[orcidlogocol] 
svg{M256,128c0,70.7-57.3,128-128,128C57.3,256,0,198.7,0,128C0,57.3,57.3,0,128,0C198.7,0,256,57.3,256,128z}; 75 | \fill[white] svg{M86.3,186.2H70.9V79.1h15.4v48.4V186.2z} 76 | svg{M108.9,79.1h41.6c39.6,0,57,28.3,57,53.6c0,27.5-21.5,53.6-56.8,53.6h-41.8V79.1z M124.3,172.4h24.5c34.9,0,42.9-26.5,42.9-39.7c0-21.5-13.7-39.7-43.7-39.7h-23.7V172.4z} 77 | svg{M88.7,56.8c0,5.5-4.5,10.1-10.1,10.1c-5.6,0-10.1-4.6-10.1-10.1c0-5.6,4.5-10.1,10.1-10.1C84.2,46.7,88.7,51.3,88.7,56.8z}; 78 | } 79 | } 80 | 81 | \newcommand\orcid[1]{\href{https://orcid.org/#1}{\mbox{\scalerel*{ 82 | \begin{tikzpicture}[yscale=-1,transform shape] 83 | \pic{orcidlogo}; 84 | \end{tikzpicture} 85 | }{|}}} \href{https://orcid.org/#1}{#1}} 86 | 87 | % Header options 88 | \RequirePackage{fancyhdr} 89 | \fancyhf{} 90 | \pagestyle{fancy} 91 | \renewcommand{\headrulewidth}{0pt} 92 | \fancyheadoffset{0pt} 93 | \lhead{\scshape Preprint -- \@title} 94 | \rhead{\thepage} 95 | %\rfoot{\thepage} 96 | 97 | %Handling Keywords 98 | \def\keywordname{{\bfseries \emph Keywords}}% 99 | \def\keywords#1{\par\addvspace\medskipamount{\rightskip=0pt plus1cm 100 | \def\and{\ifhmode\unskip\nobreak\fi\ $\cdot$ 101 | }\noindent\keywordname\enspace\ignorespaces#1\par}} 102 | 103 | % font sizes with reduced leading 104 | \renewcommand{\normalsize}{% 105 | \@setfontsize\normalsize\@xpt\@xipt 106 | \abovedisplayskip 7\p@ \@plus 2\p@ \@minus 5\p@ 107 | \abovedisplayshortskip \z@ \@plus 3\p@ 108 | \belowdisplayskip \abovedisplayskip 109 | \belowdisplayshortskip 4\p@ \@plus 3\p@ \@minus 3\p@ 110 | } 111 | \normalsize 112 | \renewcommand{\small}{% 113 | \@setfontsize\small\@ixpt\@xpt 114 | \abovedisplayskip 6\p@ \@plus 1.5\p@ \@minus 4\p@ 115 | \abovedisplayshortskip \z@ \@plus 2\p@ 116 | \belowdisplayskip \abovedisplayskip 117 | \belowdisplayshortskip 3\p@ \@plus 2\p@ \@minus 2\p@ 118 | } 119 | \renewcommand{\footnotesize}{\@setfontsize\footnotesize\@ixpt\@xpt} 120 | \renewcommand{\scriptsize}{\@setfontsize\scriptsize\@viipt\@viiipt} 121 | \renewcommand{\tiny}{\@setfontsize\tiny\@vipt\@viipt} 122 | \renewcommand{\large}{\@setfontsize\large\@xiipt{14}} 123 | \renewcommand{\Large}{\@setfontsize\Large\@xivpt{16}} 124 | \renewcommand{\LARGE}{\@setfontsize\LARGE\@xviipt{20}} 125 | \renewcommand{\huge}{\@setfontsize\huge\@xxpt{23}} 126 | \renewcommand{\Huge}{\@setfontsize\Huge\@xxvpt{28}} 127 | 128 | % sections with less space 129 | \providecommand{\section}{} 130 | \renewcommand{\section}{% 131 | \@startsection{section}{1}{\z@}% 132 | {-2.0ex \@plus -0.5ex \@minus -0.2ex}% 133 | { 1.5ex \@plus 0.3ex \@minus 0.2ex}% 134 | {\large\scshape\raggedright}% 135 | } 136 | \providecommand{\subsection}{} 137 | \renewcommand{\subsection}{% 138 | \@startsection{subsection}{2}{\z@}% 139 | {-1.8ex \@plus -0.5ex \@minus -0.2ex}% 140 | { 0.8ex \@plus 0.2ex}% 141 | {\normalsize\itshape\raggedright}% 142 | } 143 | \providecommand{\subsubsection}{} 144 | \renewcommand{\subsubsection}{% 145 | \@startsection{subsubsection}{3}{\z@}% 146 | {-1.5ex \@plus -0.5ex \@minus -0.2ex}% 147 | { 0.5ex \@plus 0.2ex}% 148 | {\normalsize\itshape\raggedright}% 149 | } 150 | \providecommand{\paragraph}{} 151 | \renewcommand{\paragraph}{% 152 | \@startsection{paragraph}{4}{\z@}% 153 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 154 | {-1em}% 155 | {\normalsize\bf}% 156 | } 157 | \providecommand{\subparagraph}{} 158 | \renewcommand{\subparagraph}{% 159 | \@startsection{subparagraph}{5}{\z@}% 160 | {1.5ex \@plus 0.5ex \@minus 0.2ex}% 161 | {-1em}% 162 | {\normalsize\bf}% 163 | } 164 | 
\providecommand{\subsubsubsection}{} 165 | \renewcommand{\subsubsubsection}{% 166 | \vskip5pt{\noindent\normalsize\rm\raggedright}% 167 | } 168 | 169 | % float placement 170 | \renewcommand{\topfraction }{0.85} 171 | \renewcommand{\bottomfraction }{0.4} 172 | \renewcommand{\textfraction }{0.1} 173 | \renewcommand{\floatpagefraction}{0.7} 174 | 175 | \newlength{\@abovecaptionskip}\setlength{\@abovecaptionskip}{7\p@} 176 | \newlength{\@belowcaptionskip}\setlength{\@belowcaptionskip}{\z@} 177 | 178 | \setlength{\abovecaptionskip}{\@abovecaptionskip} 179 | \setlength{\belowcaptionskip}{\@belowcaptionskip} 180 | 181 | % swap above/belowcaptionskip lengths for tables 182 | \renewenvironment{table} 183 | {\setlength{\abovecaptionskip}{\@belowcaptionskip}% 184 | \setlength{\belowcaptionskip}{\@abovecaptionskip}% 185 | \@float{table}} 186 | {\end@float} 187 | 188 | % footnote formatting 189 | \setlength{\footnotesep }{6.65\p@} 190 | \setlength{\skip\footins}{9\p@ \@plus 4\p@ \@minus 2\p@} 191 | \renewcommand{\footnoterule}{\kern-3\p@ \hrule width 12pc \kern 2.6\p@} 192 | \setcounter{footnote}{0} 193 | 194 | % paragraph formatting 195 | \setlength{\parindent}{\z@} 196 | \setlength{\parskip }{5.5\p@} 197 | 198 | % list formatting 199 | \setlength{\topsep }{4\p@ \@plus 1\p@ \@minus 2\p@} 200 | \setlength{\partopsep }{1\p@ \@plus 0.5\p@ \@minus 0.5\p@} 201 | \setlength{\itemsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 202 | \setlength{\parsep }{2\p@ \@plus 1\p@ \@minus 0.5\p@} 203 | \setlength{\leftmargin }{3pc} 204 | \setlength{\leftmargini }{\leftmargin} 205 | \setlength{\leftmarginii }{2em} 206 | \setlength{\leftmarginiii}{1.5em} 207 | \setlength{\leftmarginiv }{1.0em} 208 | \setlength{\leftmarginv }{0.5em} 209 | \def\@listi {\leftmargin\leftmargini} 210 | \def\@listii {\leftmargin\leftmarginii 211 | \labelwidth\leftmarginii 212 | \advance\labelwidth-\labelsep 213 | \topsep 2\p@ \@plus 1\p@ \@minus 0.5\p@ 214 | \parsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 215 | \itemsep \parsep} 216 | \def\@listiii{\leftmargin\leftmarginiii 217 | \labelwidth\leftmarginiii 218 | \advance\labelwidth-\labelsep 219 | \topsep 1\p@ \@plus 0.5\p@ \@minus 0.5\p@ 220 | \parsep \z@ 221 | \partopsep 0.5\p@ \@plus 0\p@ \@minus 0.5\p@ 222 | \itemsep \topsep} 223 | \def\@listiv {\leftmargin\leftmarginiv 224 | \labelwidth\leftmarginiv 225 | \advance\labelwidth-\labelsep} 226 | \def\@listv {\leftmargin\leftmarginv 227 | \labelwidth\leftmarginv 228 | \advance\labelwidth-\labelsep} 229 | \def\@listvi {\leftmargin\leftmarginvi 230 | \labelwidth\leftmarginvi 231 | \advance\labelwidth-\labelsep} 232 | 233 | % create title 234 | \providecommand{\maketitle}{} 235 | \renewcommand{\maketitle}{% 236 | \par 237 | \begingroup 238 | \renewcommand{\thefootnote}{\fnsymbol{footnote}} 239 | % for perfect author name centering 240 | \renewcommand{\@makefnmark}{\hbox to \z@{$^{\@thefnmark}$\hss}} 241 | % The footnote-mark was overlapping the footnote-text, 242 | % added the following to fix this problem (MK) 243 | \long\def\@makefntext##1{% 244 | \parindent 1em\noindent 245 | \hbox to 1.8em{\hss $\m@th ^{\@thefnmark}$}##1 246 | } 247 | \thispagestyle{empty} 248 | \vspace*{-0.5cm} 249 | \@maketitle 250 | \@thanks 251 | % \@notice 252 | \endgroup 253 | \let\maketitle\relax 254 | \let\thanks\relax 255 | } 256 | 257 | % create title (includes both anonymized and non-anonymized versions) 258 | \providecommand{\@maketitle}{} 259 | \renewcommand{\@maketitle}{% 260 | \vbox{% 261 | \hsize\textwidth 262 | \linewidth\hsize 263 | \centering 264 | {\LARGE\sc \@title\par} 
265 | \vskip 0.1in 266 | \textsc{Preprint, compiled \today}\\ 267 | \def\And{% 268 | \end{tabular}\hfil\linebreak[0]\hfil% 269 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 270 | } 271 | \def\AND{% 272 | \end{tabular}\hfil\linebreak[4]\hfil% 273 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\ignorespaces% 274 | } 275 | \begin{tabular}[t]{c}\bf\rule{\z@}{24\p@}\@author\end{tabular}% 276 | \vskip 0.2in 277 | } 278 | } 279 | 280 | % add conference notice to bottom of first page 281 | \newcommand{\ftype@noticebox}{8} 282 | \newcommand{\@notice}{% 283 | % give a bit of extra room back to authors on first page 284 | \enlargethispage{2\baselineskip}% 285 | \@float{noticebox}[b]% 286 | \footnotesize\@noticestring% 287 | \end@float% 288 | } 289 | 290 | % abstract styling 291 | \renewenvironment{abstract} 292 | { 293 | \centerline 294 | {\large \bfseries \scshape Abstract} 295 | \begin{quote} 296 | } 297 | { 298 | \end{quote} 299 | } 300 | 301 | \endinput -------------------------------------------------------------------------------- /experiments/utils/base_experiment_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from flwr_serverless.keras.example import MnistModelBuilder 4 | from experiments.model.simple_mnist_model import SimpleMnistModel 5 | from dataclasses import dataclass 6 | from experiments.model.keras_models import ResNetModelBuilder 7 | 8 | 9 | @dataclass 10 | class Config: 11 | # non shared config parameters 12 | num_nodes: int 13 | strategy: str 14 | project: str = "experiments" 15 | track: bool = False 16 | random_seed: int = 0 17 | 18 | # shared config parameters 19 | use_async: bool = True 20 | federated_type: str = "concurrent" 21 | dataset: str = "mnist" 22 | epochs: int = 100 23 | batch_size: int = 32 24 | steps_per_epoch: int = 64 25 | lr: float = 0.001 26 | test_steps: int = None 27 | net: str = "simple" 28 | data_split: str = "skewed" 29 | skew_factor: float = 0.9 30 | 31 | # Ignore, for logging purposes 32 | use_default_configs: bool = False 33 | 34 | 35 | class BaseExperimentRunner: 36 | def __init__(self, config, tracking=False): 37 | if isinstance(config, dict): 38 | config = Config(**config) 39 | assert isinstance( 40 | config, Config 41 | ), f"config must be of type Config, got {type(config)}" 42 | self.config = config 43 | self.num_nodes = config.num_nodes 44 | self.batch_size = config.batch_size 45 | self.epochs = config.epochs 46 | self.steps_per_epoch = config.steps_per_epoch 47 | self.lr = config.lr 48 | # In experiment tracking, log the actual test steps and test data size 49 | self.test_steps = config.test_steps 50 | self.use_async = config.use_async 51 | self.federated_type = config.federated_type 52 | self.strategy_name = config.strategy 53 | self.data_split = config.data_split 54 | self.dataset = config.dataset 55 | self.net = config.net 56 | 57 | self.tracking = tracking 58 | 59 | self.get_original_data() 60 | 61 | # ***currently works only for mnist*** 62 | def create_models(self): 63 | if self.dataset == "mnist": 64 | assert self.net == "simple", f"Net not supported: {self.net} for mnist" 65 | if self.net == "simple": 66 | return [SimpleMnistModel(lr=self.lr).run() for _ in range(self.num_nodes)] 67 | elif self.net == "resnet50": 68 | return [ 69 | ResNetModelBuilder(lr=self.lr, net="ResNet50", weights="imagenet").run() 70 | for _ in range(self.num_nodes) 71 | ] 72 | elif self.net == "resnet18": 73 | return [ 74 | ResNetModelBuilder(lr=self.lr, net="ResNet18").run() 75 | 
for _ in range(self.num_nodes)
76 |             ]
77 | 
78 |     def get_original_data(self):
79 |         dataset = self.dataset
80 |         if dataset == "mnist":
81 |             from tensorflow.keras.datasets import mnist
82 | 
83 |             (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
84 |         elif dataset == "cifar10":
85 |             from tensorflow.keras.datasets import cifar10
86 | 
87 |             (self.x_train, self.y_train), (
88 |                 self.x_test,
89 |                 self.y_test,
90 |             ) = cifar10.load_data()
91 |             self.y_train = np.squeeze(self.y_train, -1)
92 |             self.y_test = np.squeeze(self.y_test, -1)
93 |             assert len(self.y_train.shape) == 1, f"y_train shape: {self.y_train.shape}"
94 |             assert len(self.y_test.shape) == 1, f"y_test shape: {self.y_test.shape}"
95 | 
96 |     def normalize_data(self, data):
97 |         image_size = data.shape[1]
98 |         if self.dataset == "mnist":
99 |             reshaped_data = np.reshape(data, [-1, image_size, image_size, 1])
100 |         elif self.dataset == "cifar10":
101 |             reshaped_data = np.reshape(data, [-1, image_size, image_size, 3])
102 |         else:
103 |             raise ValueError(f"Dataset not supported: {self.dataset}")
104 |         normalized_data = reshaped_data.astype(np.float32) / 255
105 |         return normalized_data
106 | 
107 |     def random_split(self):
108 |         num_partitions = self.num_nodes
109 |         x_train = self.normalize_data(self.x_train)
110 |         x_test = self.normalize_data(self.x_test)
111 | 
112 |         # shuffle data then partition
113 |         num_train = x_train.shape[0]
114 |         indices = np.random.permutation(num_train)
115 |         x_train = x_train[indices]
116 |         y_train = self.y_train[indices]
117 | 
118 |         partitioned_x_train = np.array_split(x_train, num_partitions)
119 |         partitioned_y_train = np.array_split(y_train, num_partitions)
120 | 
121 |         return partitioned_x_train, partitioned_y_train, x_test, self.y_test
122 | 
123 |     def create_skewed_partition_split(
124 |         self, skew_factor: float = 0.80, num_classes: int = 10
125 |     ):
126 |         # Returns a "skewed" partition of the data: each example is assigned to
127 |         # its "home" partition (the one that owns its class) with probability
128 |         # skew_factor, and to a uniformly random partition otherwise. With two
129 |         # nodes, skew_factor=0.8 keeps roughly 90% (0.8 + 0.2 / 2) of each class at home.
130 |         # Note: a skew factor of 0 gives a uniform random split; 1 gives a strict partition split.
131 |         x_train = self.normalize_data(self.x_train)
132 |         x_test = self.normalize_data(self.x_test)
133 | 
134 |         x_train_by_label = [[] for _ in range(num_classes)]
135 |         y_train_by_label = [[] for _ in range(num_classes)]
136 |         for i in range(len(self.y_train)):
137 |             label = int(self.y_train[i])
138 |             x_train_example = x_train[i]
139 |             x_train_by_label[label].append(x_train_example)
140 |             y_train_by_label[label].append(label)
141 | 
142 |         # Partition just the classes into self.num_nodes groups.
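# (Illustrative: np.array_split(np.arange(10), 3) -> [array([0, 1, 2, 3]),
# array([4, 5, 6]), array([7, 8, 9])]; unlike split_training_data_into_paritions
# below, this split tolerates a node count that does not divide num_classes.)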
143 | splitted_classes = np.array_split(np.arange(num_classes), self.num_nodes) 144 | print("splitted_classes", splitted_classes) 145 | # splitted_classes should look like [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] 146 | # Example: 147 | # Partition 0: 148 | # mostly from 0, 1, 2, 3, 4, and a small amount of 5, 6, 7, 8, 9 149 | 150 | def find_partition_that_this_class_belongs_to(class_idx): 151 | for i, partition in enumerate(splitted_classes): 152 | if class_idx in partition: 153 | return i 154 | 155 | skewed_partitioned_x_train = [[] for _ in range(self.num_nodes)] 156 | skewed_partitioned_y_train = [[] for _ in range(self.num_nodes)] 157 | for i in range(num_classes): 158 | for j in range(len(x_train_by_label[i])): 159 | class_idx = i 160 | partition_that_this_class_belongs_to = ( 161 | find_partition_that_this_class_belongs_to(class_idx) 162 | ) 163 | 164 | # With probability skew_factor, assign examples to the partition, 165 | # otherwise randomly assign to a partition. 166 | if np.random.random() < skew_factor: 167 | skewed_partitioned_x_train[ 168 | partition_that_this_class_belongs_to 169 | ].append(x_train_by_label[i][j]) 170 | skewed_partitioned_y_train[ 171 | partition_that_this_class_belongs_to 172 | ].append(y_train_by_label[i][j]) 173 | else: 174 | # Randomly assign to a partition. 175 | randomly_assigned_partition = int( 176 | np.random.random() * self.num_nodes 177 | ) 178 | skewed_partitioned_x_train[randomly_assigned_partition].append( 179 | x_train_by_label[i][j] 180 | ) 181 | skewed_partitioned_y_train[randomly_assigned_partition].append( 182 | y_train_by_label[i][j] 183 | ) 184 | 185 | # convert to numpy arrays 186 | for i in range(self.num_nodes): 187 | skewed_partitioned_x_train[i] = np.asarray(skewed_partitioned_x_train[i]) 188 | skewed_partitioned_y_train[i] = np.asarray(skewed_partitioned_y_train[i]) 189 | 190 | # shuffle data 191 | for i in range(self.num_nodes): 192 | num_train = skewed_partitioned_x_train[i].shape[0] 193 | indices = np.random.permutation(num_train) 194 | skewed_partitioned_x_train[i] = skewed_partitioned_x_train[i][indices] 195 | skewed_partitioned_y_train[i] = skewed_partitioned_y_train[i][indices] 196 | 197 | # check distribution 198 | for i in range(self.num_nodes): 199 | print(f"Partition {i}:") 200 | for j in range(10): 201 | print(f"Label {j}: {np.sum(skewed_partitioned_y_train[i] == j)}") 202 | 203 | return ( 204 | skewed_partitioned_x_train, 205 | skewed_partitioned_y_train, 206 | x_test, 207 | self.y_test, 208 | ) 209 | 210 | def create_partitioned_datasets(self): 211 | num_partitions = self.num_nodes 212 | 213 | x_train = self.normalize_data(self.x_train) 214 | x_test = self.normalize_data(self.x_test) 215 | 216 | ( 217 | partitioned_x_train, 218 | partitioned_y_train, 219 | ) = self.split_training_data_into_paritions( 220 | x_train, self.y_train, num_partitions=num_partitions 221 | ) 222 | return partitioned_x_train, partitioned_y_train, x_test, self.y_test 223 | 224 | def get_train_dataloader_for_node(self, node_idx: int): 225 | partition_idx = node_idx 226 | partitioned_x_train = self.partitioned_x_train 227 | partitioned_y_train = self.partitioned_y_train 228 | while True: 229 | for i in range(0, len(partitioned_x_train[partition_idx]), self.batch_size): 230 | x_train_batch, y_train_batch = ( 231 | partitioned_x_train[partition_idx][i : i + self.batch_size], 232 | partitioned_y_train[partition_idx][i : i + self.batch_size], 233 | ) 234 | # print("x_train_batch.shape", x_train_batch.shape) 235 | # print("y_train_batch.shape", 
y_train_batch.shape)
236 |                 # raise Exception("stop")
237 |                 yield x_train_batch, y_train_batch
238 | 
239 |     # ***currently this only works for mnist***, and only for values of num_nodes that divide 10 (e.g. 2, 10)
240 |     def split_training_data_into_paritions(
241 |         self, x_train, y_train, num_partitions: int = 2
242 |     ):
243 |         # partition 1: classes 0-4
244 |         # partition 2: classes 5-9
245 |         # client 1 trains on classes 0-4 only, and is validated on 0-9
246 |         # client 2 trains on classes 5-9 only, and is validated on 0-9
247 |         # both clients will have low accuracy on 0-9 (below 0.6),
248 |         # but when federated, the accuracy will be higher than 0.6
249 |         classes = list(range(10))
250 |         num_classes_per_partition = int(len(classes) / num_partitions)
251 |         partitioned_classes = [
252 |             classes[i : i + num_classes_per_partition]
253 |             for i in range(0, len(classes), num_classes_per_partition)
254 |         ]
255 |         partitioned_x_train = []
256 |         partitioned_y_train = []
257 |         for partition in partitioned_classes:
258 |             # partition is a list of int
259 |             if len(y_train.shape) == 2:
260 |                 selected = np.isin(y_train, partition)[:, 0]
261 |             elif len(y_train.shape) == 1:
262 |                 selected = np.isin(y_train, partition)
263 |             # subsetting based on the first axis
264 |             x_train_selected = x_train[selected]
265 |             assert (
266 |                 x_train_selected.shape[0] < x_train.shape[0]
267 |             ), "partitioned dataset should be smaller than original dataset"
268 |             assert x_train_selected.shape[0] == y_train[selected].shape[0]
269 |             partitioned_x_train.append(x_train_selected)
270 |             y_train_selected = y_train[selected]
271 |             partitioned_y_train.append(y_train_selected)
272 | 
273 |         return partitioned_x_train, partitioned_y_train
274 | 
275 | 
276 | # if __name__ == "__main__":
277 | 
278 | #     base_exp = BaseExperimentRunner(config, num_nodes=2)
279 | 
280 | #     base_exp.random_split()
281 | #     base_exp.create_skewed_partition_split()
282 | 
-------------------------------------------------------------------------------- /experiments/utils/federated_learning_runner.py: --------------------------------------------------------------------------------
1 | from concurrent.futures import ThreadPoolExecutor, as_completed
2 | import os
3 | from typing import List, Any
4 | from tensorflow import keras
5 | from tensorflow.keras.utils import set_random_seed
6 | from wandb.keras import WandbCallback
7 | 
8 | from flwr.common import ndarrays_to_parameters
9 | 
10 | 
11 | from flwr.server.strategy import (
12 |     FedAvg,
13 |     FedAdam,
14 |     FedAvgM,
15 |     FedOpt,
16 |     FedYogi,
17 |     FedAdagrad,
18 |     FedMedian,
19 |     QFedAvg,
20 | )
21 | 
22 | 
23 | from flwr_serverless.federated_node.async_federated_node import AsyncFederatedNode
24 | from flwr_serverless.federated_node.sync_federated_node import SyncFederatedNode
25 | from flwr_serverless.shared_folder.in_memory_folder import InMemoryFolder
26 | from flwr_serverless.keras.federated_learning_callback import FlwrFederatedCallback
27 | from experiments.utils.base_experiment_runner import BaseExperimentRunner, Config
28 | from experiments.utils.custom_wandb_callback import CustomWandbCallback
29 | 
30 | 
31 | class FederatedLearningRunner(BaseExperimentRunner):
32 |     def __init__(
33 |         self,
34 |         config,
35 |         **kwargs,
36 |     ):
37 |         super().__init__(config, **kwargs)
38 |         self.storage_backend: Any = InMemoryFolder()
39 |         # In one round, each node trains on its local data for one epoch.
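# (With the default Config, epochs=100 therefore corresponds to 100 federated
# rounds per node: the FlwrFederatedCallback attached below attempts an
# aggregation at the end of every epoch.)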
40 | self.num_rounds = self.epochs # number of federated rounds (similar to epochs) 41 | 42 | def run(self): 43 | config: Config = self.config 44 | if config.random_seed is not None: 45 | set_random_seed(config.random_seed) 46 | 47 | if config.track: 48 | import wandb 49 | 50 | strategy = self.config.strategy 51 | num_nodes = self.config.num_nodes 52 | data_split = self.config.data_split 53 | sync_or_async: str = "async" if self.config.use_async else "sync" 54 | name = f"{sync_or_async}_{strategy}_{num_nodes}_nodes_{data_split}" 55 | if data_split == "skewed": 56 | name += f"_{self.config.skew_factor}" 57 | wandb.init( 58 | project=self.config.project, 59 | entity=os.getenv("WANDB_ENTITY", "example_entity"), 60 | name=name, 61 | config=config.__dict__, 62 | ) 63 | self.models = self.create_models() 64 | self.set_strategy() 65 | ( 66 | self.partitioned_x_train, 67 | self.partitioned_y_train, 68 | self.x_test, 69 | self.y_test, 70 | ) = self.split_data() 71 | print("x_test shape:", self.x_test.shape) 72 | print("y_test shape:", self.y_test.shape) 73 | self.train_federated_models() 74 | self.evaluate() 75 | if config.track: 76 | wandb.finish() 77 | 78 | def set_strategy(self): 79 | if self.strategy_name == "fedavg": 80 | self.strategies = [FedAvg() for _ in range(self.num_nodes)] 81 | elif self.strategy_name == "fedavgm": 82 | self.strategies = [FedAvgM() for _ in range(self.num_nodes)] 83 | elif self.strategy_name == "fedadam": 84 | self.strategies = [ 85 | FedAdam( 86 | initial_parameters=ndarrays_to_parameters( 87 | self.models[i].get_weights() 88 | ) 89 | ) 90 | for i in range(self.num_nodes) 91 | ] 92 | elif self.strategy_name == "fedopt": 93 | self.strategies = [ 94 | FedOpt( 95 | initial_parameters=ndarrays_to_parameters( 96 | self.models[i].get_weights() 97 | ) 98 | ) 99 | for i in range(self.num_nodes) 100 | ] 101 | elif self.strategy_name == "fedmedian": 102 | self.strategies = [FedMedian() for _ in range(self.num_nodes)] 103 | # elif self.strategy_name == "fedyogi": 104 | # self.strategy = FedYogi() 105 | # elif self.strategy_name == "fedadagrad": 106 | # self.strategy = FedAdagrad() 107 | else: 108 | raise ValueError(f"Strategy not supported: {self.strategy_name}") 109 | 110 | def split_data(self): 111 | config: Config = self.config 112 | if self.data_split == "random": 113 | return self.random_split() 114 | elif self.data_split == "partitioned": 115 | return self.create_partitioned_datasets() 116 | elif self.data_split == "skewed": 117 | return self.create_skewed_partition_split(skew_factor=config.skew_factor) 118 | else: 119 | raise ValueError("Data split not supported") 120 | 121 | def train_federated_models( 122 | self, 123 | ) -> List[keras.Model]: 124 | if self.federated_type == "pseudo-concurrent": 125 | print("Training federated models pseudo-concurrently.") 126 | return self._train_federated_models_pseudo_concurrently(self.models) 127 | elif self.federated_type == "concurrent": # should be used for all experiments 128 | print("Training federated models concurrently") 129 | return self._train_federated_models_concurrently(self.models) 130 | else: 131 | print("Training federated models sequentially") 132 | return self._train_federated_models_sequentially(self.models) 133 | 134 | def _train_federated_models_concurrently( 135 | self, model_federated: List[keras.Model] 136 | ) -> List[keras.Model]: 137 | nodes = self.create_nodes() 138 | num_partitions = self.num_nodes 139 | 140 | callbacks_per_client = [ 141 | FlwrFederatedCallback( 142 | nodes[i], 143 | 
num_examples_per_epoch=self.steps_per_epoch * self.batch_size, 144 | ) 145 | for i in range(num_partitions) 146 | ] 147 | 148 | train_loaders = [ 149 | self.get_train_dataloader_for_node(i) for i in range(num_partitions) 150 | ] 151 | 152 | with ThreadPoolExecutor(max_workers=self.num_nodes) as ex: 153 | futures = [] 154 | for i_node in range(self.num_nodes): 155 | callbacks = [ 156 | callbacks_per_client[i_node], 157 | ] 158 | if self.config.track: 159 | callbacks.append(CustomWandbCallback(i_node)) 160 | 161 | # assert self.test_steps is None 162 | if self.config.test_steps is not None: 163 | x_test = self.x_test[ 164 | : self.config.test_steps * self.config.batch_size 165 | ] 166 | y_test = self.y_test[ 167 | : self.config.test_steps * self.config.batch_size 168 | ] 169 | else: 170 | x_test = self.x_test 171 | y_test = self.y_test 172 | future = ex.submit( 173 | model_federated[i_node].fit, 174 | x=train_loaders[i_node], 175 | epochs=self.num_rounds, 176 | steps_per_epoch=self.steps_per_epoch, 177 | callbacks=callbacks, 178 | validation_data=(x_test, y_test), 179 | # validation_data=(self.x_test, self.y_test), 180 | # validation_steps=self.test_steps, 181 | validation_batch_size=self.batch_size, 182 | ) 183 | futures.append(future) 184 | 185 | train_results = [] 186 | for future in as_completed(futures): 187 | train_results.append(future.result()) 188 | 189 | return model_federated 190 | 191 | def _train_federated_models_pseudo_concurrently( 192 | self, model_federated: List[keras.Model] 193 | ) -> List[keras.Model]: 194 | self.lag = 0.1 195 | nodes = self.create_nodes() 196 | num_partitions = self.num_nodes 197 | if self.test_steps is None: 198 | x_test = self.x_test 199 | y_test = self.y_test 200 | else: 201 | x_test = self.x_test[: self.test_steps * self.batch_size, ...] 202 | y_test = self.y_test[: self.test_steps * self.batch_size, ...] 203 | 204 | callbacks_per_client = [ 205 | FlwrFederatedCallback( 206 | nodes[i], 207 | num_examples_per_epoch=self.steps_per_epoch * self.batch_size, 208 | x_test=x_test, 209 | y_test=y_test, 210 | # x_test=self.x_test[: self.test_steps * self.batch_size, ...], 211 | # y_test=self.y_test[: self.test_steps * self.batch_size, ...], 212 | ) 213 | for i in range(num_partitions) 214 | ] 215 | 216 | num_federated_rounds = self.num_rounds 217 | num_epochs_per_round = 1 218 | train_loaders = [ 219 | self.get_train_dataloader_for_node(i) for i in range(num_partitions) 220 | ] 221 | 222 | seqs = [[]] * self.num_nodes 223 | for i_node in range(self.num_nodes): 224 | seqs[i_node] = [ 225 | (i_node, j + i_node * self.lag) for j in range(num_federated_rounds) 226 | ] 227 | # mix them up 228 | execution_sequence = [] 229 | for i_node in range(self.num_nodes): 230 | execution_sequence.extend(seqs[i_node]) 231 | execution_sequence = [ 232 | x[0] for x in sorted(execution_sequence, key=lambda x: x[1]) 233 | ] 234 | print(f"Execution sequence: {execution_sequence}") 235 | if self.test_steps is None: 236 | x_test = self.x_test 237 | y_test = self.y_test 238 | else: 239 | x_test = self.x_test[: self.test_steps * self.batch_size, ...] 240 | y_test = self.y_test[: self.test_steps * self.batch_size, ...] 
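# Worked example of the lag-based interleaving (values chosen for
# illustration): with num_nodes=2, lag=0.1 and num_federated_rounds=3,
# seqs is [[(0, 0.0), (0, 1.0), (0, 2.0)], [(1, 0.1), (1, 1.1), (1, 2.1)]],
# and sorting by the lag-offset timestamps yields execution_sequence
# [0, 1, 0, 1, 0, 1]: the nodes alternate one epoch at a time, emulating
# concurrent training in a single thread.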
241 |         for i_node in execution_sequence:
242 |             print("Training node", i_node)
243 |             model_federated[i_node].fit(
244 |                 x=train_loaders[i_node],
245 |                 epochs=num_epochs_per_round,
246 |                 steps_per_epoch=self.steps_per_epoch,
247 |                 callbacks=[callbacks_per_client[i_node]],
248 |                 validation_data=(x_test, y_test),
249 |                 validation_steps=self.test_steps,
250 |                 validation_batch_size=self.batch_size,
251 |             )
252 | 
253 |             if i_node == 0:
254 |                 print("Evaluating on the combined test set:")
255 |                 assert (
256 |                     len(y_test.shape) == 1
257 |                 ), f"y_test should be 1D, got {y_test.shape}"
258 |                 evaluation_metrics = model_federated[0].evaluate(
259 |                     x_test,
260 |                     y_test,
261 |                     # self.x_test[: self.test_steps * self.batch_size, ...],
262 |                     # self.y_test[: self.test_steps * self.batch_size, ...],
263 |                     batch_size=self.batch_size,
264 |                     steps=self.test_steps,
265 |                     return_dict=True,
266 |                 )
267 | 
268 |         return model_federated
269 | 
270 |     def _train_federated_models_sequentially(
271 |         self, model_federated: List[keras.Model]
272 |     ) -> List[keras.Model]:
273 |         nodes = self.create_nodes()
274 |         num_partitions = self.num_nodes  # is this needed?
275 | 
276 |         callbacks_per_client = [
277 |             FlwrFederatedCallback(
278 |                 nodes[i], num_examples_per_epoch=self.batch_size * self.steps_per_epoch
279 |             )
280 |             for i in range(num_partitions)
281 |         ]
282 | 
283 |         num_federated_rounds = self.num_rounds
284 |         num_epochs_per_round = 1
285 |         train_loaders = [
286 |             self.get_train_dataloader_for_node(i) for i in range(num_partitions)
287 |         ]
288 | 
289 |         if self.config.track:
290 |             wandb_callbacks = [WandbCallback() for i in range(num_partitions)]
291 |         for i_round in range(num_federated_rounds):
292 |             print("\n============ Round", i_round)
293 |             for i_partition in range(num_partitions):
294 |                 callbacks = [
295 |                     callbacks_per_client[i_partition],
296 |                 ]
297 |                 if self.config.track:
298 |                     callbacks.append(wandb_callbacks[i_partition])
299 |                 model_federated[i_partition].fit(
300 |                     train_loaders[i_partition],
301 |                     epochs=num_epochs_per_round,
302 |                     steps_per_epoch=self.steps_per_epoch,
303 |                     callbacks=callbacks,
304 |                 )
305 |             print("Evaluating on the combined test set:")
306 |             assert (
307 |                 len(self.y_test.shape) == 1
308 |             ), f"y_test should be 1D, got {self.y_test.shape}"
309 |             model_federated[0].evaluate(
310 |                 self.x_test,
311 |                 self.y_test,
312 |                 batch_size=self.batch_size,
313 |                 steps=self.steps_per_epoch,
314 |             )
315 | 
316 |         return model_federated
317 | 
318 |     def create_nodes(self):
319 |         if self.use_async:
320 |             nodes = [
321 |                 AsyncFederatedNode(
322 |                     shared_folder=self.storage_backend, strategy=self.strategies[i]
323 |                 )
324 |                 for i in range(self.num_nodes)
325 |             ]
326 |         else:
327 |             nodes = [
328 |                 SyncFederatedNode(
329 |                     shared_folder=self.storage_backend,
330 |                     strategy=self.strategies[i],
331 |                     num_nodes=self.num_nodes,
332 |                 )
333 |                 for i in range(self.num_nodes)
334 |             ]
335 |         return nodes
336 | 
337 |     def evaluate(self):
338 |         for i_node in [0]:  # range(self.num_nodes):
339 |             loss1, accuracy1 = self.models[i_node].evaluate(
340 |                 self.x_test,
341 |                 self.y_test,
342 |                 batch_size=self.batch_size,
343 |                 steps=self.test_steps,
344 |             )
345 |             if self.config.track:
346 |                 import wandb
347 | 
348 |                 to_log = {
349 |                     "test_accuracy": accuracy1,
350 |                     "test_loss": loss1,
351 |                 }
352 |                 wandb.log(to_log)
353 | 
-------------------------------------------------------------------------------- /flwr_serverless/keras/example.py: --------------------------------------------------------------------------------
1 | from concurrent.futures
import ThreadPoolExecutor
2 | from dataclasses import dataclass
3 | from typing import List, Any, Callable
4 | import numpy as np
5 | from tensorflow.keras.datasets import mnist
6 | from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D
7 | from tensorflow.keras.models import Model
8 | from tensorflow import keras
9 | 
10 | from flwr.server.strategy import Strategy
11 | from flwr.server.strategy import FedAvg, FedAdam, FedAvgM
12 | from flwr_serverless.federated_node.async_federated_node import AsyncFederatedNode
13 | from flwr_serverless.federated_node.sync_federated_node import SyncFederatedNode
14 | from flwr_serverless.shared_folder.in_memory_folder import InMemoryFolder
15 | from flwr_serverless.shared_folder.local_folder import LocalFolder
16 | from flwr_serverless.keras.federated_learning_callback import FlwrFederatedCallback
17 | 
18 | 
19 | @dataclass
20 | class FederatedLearningTestRun:
21 |     num_nodes: int = 2
22 |     epochs: int = 8
23 |     num_rounds: int = 8  # number of federated rounds
24 |     batch_size: int = 32
25 |     steps_per_epoch: int = 10
26 |     lr: float = 0.001
27 |     test_steps: int = 10
28 | 
29 |     strategy: Strategy = FedAvg()
30 |     storage_backend: Any = None
31 |     use_async_node: bool = True
32 |     # Whether to train federated models concurrently or sequentially.
33 |     train_concurrently: bool = False
34 |     train_pseudo_concurrently: bool = False
35 |     lag: float = 0.1
36 | 
37 |     model_builder_fn: Callable = None
38 |     replicate_num_channels: bool = False
39 |     save_model_before_aggregation: bool = False
40 |     save_model_after_aggregation: bool = False
41 | 
42 |     def __post_init__(self):
43 |         if self.model_builder_fn is None:
44 |             self.model_builder_fn = MnistModelBuilder(lr=self.lr).run
45 |         if self.storage_backend is None:
46 |             self.storage_backend = InMemoryFolder()
47 |         self.histories = {}
48 | 
49 |     def run(self):
50 |         (
51 |             self.partitioned_x_train,
52 |             self.partitioned_y_train,
53 |             self.x_test,
54 |             self.y_test,
55 |         ) = self.create_partitioned_datasets()
56 |         model_standalone: List[keras.Model] = self.create_standalone_models()
57 |         model_federated: List[keras.Model] = self.create_federated_models()
58 |         model_standalone = self.train_standalone_models(model_standalone)
59 |         model_federated = self.train_federated_models(model_federated)
60 |         print("Evaluating on the combined test set (standalone models):")
61 |         accuracy_standalone = self.evaluate_models(model_standalone)
62 |         for i_node in range(len(accuracy_standalone)):
63 |             print(
64 |                 "Standalone accuracy for node {}: {}".format(
65 |                     i_node, accuracy_standalone[i_node]
66 |                 )
67 |             )
68 |         print("Evaluating on the combined test set (federated model):")
69 |         # Evaluate every federated node's model (not just the first one).
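# (evaluate_models returns one accuracy per node; after federation the nodes'
# weights have been synchronized through the shared folder, so the per-node
# accuracies are expected to be close to each other.)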
70 | accuracy_federated = self.evaluate_models(model_federated) 71 | for i_node in range(self.num_nodes): # [len(accuracy_federated) - 1]: 72 | print( 73 | "Federated accuracy for node {}: {}".format( 74 | i_node, accuracy_federated[i_node] 75 | ) 76 | ) 77 | 78 | return accuracy_standalone, accuracy_federated 79 | 80 | def create_partitioned_datasets(self): 81 | num_partitions = self.num_nodes 82 | 83 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 84 | # x_train.shape: (60000, 28, 28) 85 | # print(y_train.shape) # (60000,) 86 | # Normalize 87 | image_size = x_train.shape[1] 88 | x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) 89 | x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) 90 | if self.replicate_num_channels: 91 | x_train = np.tile(x_train, (1, 1, 1, 3)) 92 | x_test = np.tile(x_test, (1, 1, 1, 3)) 93 | x_train = x_train.astype(np.float32) / 255 94 | x_test = x_test.astype(np.float32) / 255 95 | partitioned_x_train, partitioned_y_train = split_training_data_into_paritions( 96 | x_train, y_train, num_partitions=num_partitions 97 | ) 98 | return partitioned_x_train, partitioned_y_train, x_test, y_test 99 | 100 | def create_standalone_models(self): 101 | return [self.model_builder_fn() for _ in range(self.num_nodes)] 102 | 103 | def get_train_dataloader_for_node(self, node_idx: int): 104 | partition_idx = node_idx 105 | batch_size = self.batch_size 106 | partitioned_x_train = self.partitioned_x_train 107 | partitioned_y_train = self.partitioned_y_train 108 | while True: 109 | for i in range(0, len(partitioned_x_train[partition_idx]), batch_size): 110 | yield partitioned_x_train[partition_idx][ 111 | i : i + batch_size 112 | ], partitioned_y_train[partition_idx][i : i + batch_size] 113 | 114 | def create_federated_models(self): 115 | models = [self.model_builder_fn() for _ in range(self.num_nodes)] 116 | self.models = models 117 | return models 118 | 119 | def train_standalone_models( 120 | self, model_standalone: List[keras.Model] 121 | ) -> List[keras.Model]: 122 | for i_node in range(self.num_nodes): 123 | train_loader_standalone = self.get_train_dataloader_for_node(i_node) 124 | self.histories[i_node] = model_standalone[i_node].fit( 125 | train_loader_standalone, 126 | epochs=self.epochs, 127 | steps_per_epoch=self.steps_per_epoch, 128 | ) 129 | 130 | return model_standalone 131 | 132 | def train_federated_models( 133 | self, model_federated: List[keras.Model] 134 | ) -> List[keras.Model]: 135 | if self.train_pseudo_concurrently: 136 | print("Training federated models pseudo-concurrently.") 137 | return self._train_federated_models_pseudo_concurrently(model_federated) 138 | elif self.train_concurrently: 139 | print("Training federated models concurrently") 140 | return self._train_federated_models_concurrently(model_federated) 141 | else: 142 | print("Training federated models sequentially") 143 | return self._train_federated_models_sequentially(model_federated) 144 | 145 | def _train_federated_models_concurrently( 146 | self, model_federated: List[keras.Model] 147 | ) -> List[keras.Model]: 148 | strategy = self.strategy 149 | storage_backend = self.storage_backend 150 | if self.use_async_node: 151 | nodes = [] 152 | for _ in range(self.num_nodes): 153 | if isinstance(storage_backend, LocalFolder): 154 | # duplicate 155 | storage_backend = LocalFolder(directory=storage_backend.directory) 156 | nodes.append( 157 | AsyncFederatedNode(shared_folder=storage_backend, strategy=strategy) 158 | ) 159 | else: 160 | nodes = [] 161 | for _ in 
range(self.num_nodes): 162 | if isinstance(storage_backend, LocalFolder): 163 | # duplicate 164 | storage_backend = LocalFolder(directory=storage_backend.directory) 165 | nodes.append( 166 | SyncFederatedNode( 167 | shared_folder=storage_backend, 168 | strategy=strategy, 169 | num_nodes=self.num_nodes, 170 | ) 171 | ) 172 | 173 | self.nodes = nodes 174 | for i, node in enumerate(nodes): 175 | print(f"node {i}: folder {node.model_store}") 176 | num_partitions = self.num_nodes 177 | model_federated = [self.model_builder_fn() for _ in range(num_partitions)] 178 | callbacks_per_client = [ 179 | FlwrFederatedCallback( 180 | nodes[i], 181 | x_test=self.x_test, 182 | y_test=self.y_test, 183 | num_examples_per_epoch=self.steps_per_epoch * self.batch_size, 184 | save_model_before_aggregation=self.save_model_before_aggregation, 185 | ) 186 | for i in range(num_partitions) 187 | ] 188 | self.callbacks_per_client = callbacks_per_client 189 | 190 | train_loaders = [ 191 | self.get_train_dataloader_for_node(i) for i in range(num_partitions) 192 | ] 193 | 194 | with ThreadPoolExecutor(max_workers=self.num_nodes) as ex: 195 | futures = [] 196 | for i_node in range(self.num_nodes): 197 | # time.sleep(0.5 * i_node) 198 | future = ex.submit( 199 | model_federated[i_node].fit, 200 | x=train_loaders[i_node], 201 | epochs=self.num_rounds, 202 | steps_per_epoch=self.steps_per_epoch, 203 | callbacks=[callbacks_per_client[i_node]], 204 | validation_data=( 205 | self.x_test[: self.test_steps * self.batch_size, ...], 206 | self.y_test[: self.test_steps * self.batch_size, ...], 207 | ), 208 | validation_steps=self.test_steps, 209 | validation_batch_size=self.batch_size, 210 | ) 211 | futures.append(future) 212 | train_results = [future.result() for future in futures] 213 | 214 | return model_federated 215 | 216 | def _train_federated_models_pseudo_concurrently( 217 | self, model_federated: List[keras.Model] 218 | ) -> List[keras.Model]: 219 | # federated learning 220 | lag = self.lag 221 | strategy = self.strategy 222 | storage_backend = self.storage_backend 223 | if self.use_async_node: 224 | nodes = [ 225 | AsyncFederatedNode(shared_folder=storage_backend, strategy=strategy) 226 | for _ in range(self.num_nodes) 227 | ] 228 | else: 229 | raise NotImplementedError() 230 | self.nodes = nodes 231 | num_partitions = self.num_nodes 232 | model_federated = [self.model_builder_fn() for _ in range(num_partitions)] 233 | callbacks_per_client = [ 234 | FlwrFederatedCallback( 235 | nodes[i], 236 | num_examples_per_epoch=self.steps_per_epoch * self.batch_size, 237 | x_test=self.x_test[: self.test_steps * self.batch_size, ...], 238 | y_test=self.y_test[: self.test_steps * self.batch_size, ...], 239 | ) 240 | for i in range(num_partitions) 241 | ] 242 | self.callbacks_per_client = callbacks_per_client 243 | 244 | num_federated_rounds = self.num_rounds 245 | num_epochs_per_round = 1 246 | train_loaders = [ 247 | self.get_train_dataloader_for_node(i) for i in range(num_partitions) 248 | ] 249 | 250 | seqs = [[]] * self.num_nodes 251 | for i_node in range(self.num_nodes): 252 | seqs[i_node] = [ 253 | (i_node, j + i_node * lag) for j in range(num_federated_rounds) 254 | ] 255 | # mix them up 256 | execution_sequence = [] 257 | for i_node in range(self.num_nodes): 258 | execution_sequence.extend(seqs[i_node]) 259 | execution_sequence = [ 260 | x[0] for x in sorted(execution_sequence, key=lambda x: x[1]) 261 | ] 262 | print(f"Execution sequence: {execution_sequence}") 263 | for i_node in execution_sequence: 264 | print("Training node", 
i_node) 265 | self.histories[i_node] = model_federated[i_node].fit( 266 | x=train_loaders[i_node], 267 | epochs=num_epochs_per_round, 268 | steps_per_epoch=self.steps_per_epoch, 269 | callbacks=callbacks_per_client[i_node], 270 | validation_data=( 271 | self.x_test[: self.test_steps * self.batch_size, ...], 272 | self.y_test[: self.test_steps * self.batch_size, ...], 273 | ), 274 | validation_steps=self.test_steps, 275 | validation_batch_size=self.batch_size, 276 | ) 277 | 278 | if i_node == 0: 279 | print("Evaluating on the combined test set:") 280 | model_federated[0].evaluate( 281 | self.x_test[: self.test_steps * self.batch_size, ...], 282 | self.y_test[: self.test_steps * self.batch_size, ...], 283 | batch_size=self.batch_size, 284 | steps=10, 285 | ) 286 | 287 | return model_federated 288 | 289 | def _train_federated_models_sequentially( 290 | self, model_federated: List[keras.Model] 291 | ) -> List[keras.Model]: 292 | # federated learning 293 | strategy = self.strategy 294 | storage_backend = self.storage_backend 295 | if self.use_async_node: 296 | nodes = [ 297 | AsyncFederatedNode(shared_folder=storage_backend, strategy=strategy) 298 | for _ in range(self.num_nodes) 299 | ] 300 | else: 301 | raise NotImplementedError() 302 | self.nodes = nodes 303 | num_partitions = self.num_nodes 304 | model_federated = [self.model_builder_fn() for _ in range(num_partitions)] 305 | callbacks_per_client = [ 306 | FlwrFederatedCallback( 307 | nodes[i], num_examples_per_epoch=self.batch_size * self.steps_per_epoch 308 | ) 309 | for i in range(num_partitions) 310 | ] 311 | self.callbacks_per_client = callbacks_per_client 312 | 313 | num_federated_rounds = self.num_rounds 314 | num_epochs_per_round = 1 315 | train_loaders = [ 316 | self.get_train_dataloader_for_node(i) for i in range(num_partitions) 317 | ] 318 | 319 | for i_round in range(num_federated_rounds): 320 | print("\n============ Round", i_round) 321 | for i_partition in range(num_partitions): 322 | self.histories[i_partition] = model_federated[i_partition].fit( 323 | train_loaders[i_partition], 324 | validation_data=( 325 | self.x_test[: self.test_steps * self.batch_size, ...], 326 | self.y_test[: self.test_steps * self.batch_size, ...], 327 | ), 328 | validation_steps=self.test_steps, 329 | epochs=num_epochs_per_round, 330 | steps_per_epoch=self.steps_per_epoch, 331 | callbacks=callbacks_per_client[i_partition], 332 | ) 333 | print("Evaluating on the combined test set:") 334 | model_federated[0].evaluate( 335 | self.x_test, self.y_test, batch_size=self.batch_size, steps=10 336 | ) 337 | 338 | return model_federated 339 | 340 | def evaluate_models(self, models: List[keras.Model]) -> List[float]: 341 | accuracies = [] 342 | for model in models: 343 | _, accuracy = model.evaluate( 344 | self.x_test, 345 | self.y_test, 346 | batch_size=self.batch_size, 347 | steps=self.test_steps, 348 | ) 349 | accuracies.append(accuracy) 350 | return accuracies 351 | 352 | 353 | class MnistModelBuilder: 354 | """This is a helper class to create a simple Keras model 355 | for MNIST digit classification. 
356 |     """
357 | 
358 |     def __init__(self, lr=0.001):
359 |         self.lr = lr
360 | 
361 |     def run(self):
362 |         model = self._build_model()
363 |         return self._compile_model(model)
364 | 
365 |     def _build_model(self):
366 |         input = Input(shape=(28, 28, 1))
367 |         x = Conv2D(32, kernel_size=4, activation="relu")(input)
368 |         x = MaxPooling2D()(x)
369 |         x = Conv2D(16, kernel_size=4, activation="relu")(x)
370 |         x = Flatten()(x)
371 |         output = Dense(10, activation="softmax")(x)
372 |         model = Model(inputs=input, outputs=output)
373 |         return model
374 | 
375 |     def _compile_model(self, model):
376 |         model.compile(
377 |             optimizer=keras.optimizers.Adam(self.lr),
378 |             loss="sparse_categorical_crossentropy",
379 |             metrics=["accuracy"],
380 |         )
381 |         return model
382 | 
383 | 
384 | def split_training_data_into_paritions(x_train, y_train, num_partitions: int = 2):
385 |     # partition 1: classes 0-4
386 |     # partition 2: classes 5-9
387 |     # client 1 trains on classes 0-4 only, and is validated on 0-9
388 |     # client 2 trains on classes 5-9 only, and is validated on 0-9
389 |     # both clients will have low accuracy on 0-9 (below 0.6),
390 |     # but when federated, the accuracy will be higher than 0.6
391 |     classes = list(range(10))
392 |     num_classes_per_partition = int(len(classes) / num_partitions)
393 |     partitioned_classes = [
394 |         classes[i : i + num_classes_per_partition]
395 |         for i in range(0, len(classes), num_classes_per_partition)
396 |     ]
397 |     partitioned_x_train = []
398 |     partitioned_y_train = []
399 |     for partition in partitioned_classes:
400 |         partitioned_x_train.append(x_train[np.isin(y_train, partition)])
401 |         partitioned_y_train.append(y_train[np.isin(y_train, partition)])
402 |     return partitioned_x_train, partitioned_y_train
403 | 
-------------------------------------------------------------------------------- /experiments/exp3_wikitext.py: --------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from dataclasses import dataclass
4 | from torch import nn
5 | from concurrent.futures import ThreadPoolExecutor
6 | import torch
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_tpu_available
8 | from transformers import TrainingArguments, IntervalStrategy, Trainer, TrainerCallback
9 | from transformers import DataCollatorForLanguageModeling
10 | import datasets
11 | import evaluate
12 | from flwr.common import (
13 |     Parameters,
14 |     ndarrays_to_parameters,
15 |     parameters_to_ndarrays,
16 | )
17 | from flwr.server.strategy import FedAvg, FedAdam, FedAvgM
18 | from flwr_serverless.shared_folder.in_memory_folder import InMemoryFolder
19 | from flwr_serverless.shared_folder.local_folder import LocalFolder
20 | from flwr_serverless.federated_node.async_federated_node import AsyncFederatedNode
21 | from flwr_serverless.federated_node.sync_federated_node import SyncFederatedNode
22 | 
23 | 
24 | 
25 | from experiments.dataset.tolkien_dataset_builder import TolkienDatasetBuilder
26 | 
27 | # TODO: instrument this code with flwr
28 | # TODO: Refer to https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/integrations.py#L670
29 | # Implement a custom callback for HF trainer.
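# A minimal wiring sketch (assumed usage, mirroring the Keras flow in
# flwr_serverless/keras/example.py; FederatedLearningCallback is defined
# further down in this file, and num_examples_per_epoch is a placeholder value):
#
#     node = AsyncFederatedNode(shared_folder=InMemoryFolder(), strategy=FedAvg())
#     callback = FederatedLearningCallback(node, num_examples_per_epoch=1000)
#     trainer = Trainer(model=model, args=training_args,
#                       train_dataset=train_dataset, callbacks=[callback])
#     trainer.train()  # weights are merged via the shared folder at each epoch end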
30 | # os.environ["WANDB_DISABLED"] = "true" 31 | 32 | from argparse import ArgumentParser 33 | 34 | 35 | class TrainingSessionArgParser: 36 | def __init__(self): 37 | self.parser = ArgumentParser() 38 | self.add_args() 39 | 40 | def add_args(self): 41 | self.parser.add_argument( 42 | "--filename", 43 | type=str, 44 | default="experiments/dataset/lotr-paragraphs.json", 45 | help="Path to the dataset", 46 | ) 47 | self.parser.add_argument( 48 | "--model_name", 49 | type=str, 50 | default="EleutherAI/gpt-neo-125M", 51 | help="HuggingFace CausalLM pre-trained model to be fine tuned", 52 | ) 53 | self.parser.add_argument( 54 | "--epochs", 55 | type=int, 56 | default=3, 57 | help="Number of epochs to train the model", 58 | ) 59 | self.parser.add_argument( 60 | "--batch_size", 61 | type=int, 62 | default=32, 63 | help="Batch size to use for training", 64 | ) 65 | self.parser.add_argument( 66 | "--learning_rate", 67 | type=float, 68 | default=2e-5, 69 | help="Learning rate to use for training", 70 | ) 71 | 72 | def parse_args(self): 73 | return self.parser.parse_args() 74 | 75 | 76 | @dataclass 77 | class TrainingSession: 78 | model_name: str = "EleutherAI/gpt-neo-125M" 79 | # model_name: str = "EleutherAI/pythia-14M" 80 | epochs: int = 3 81 | batch_size: int = 16 82 | lr: float = 5e-5 83 | context_length: int = 128 84 | track: bool = False 85 | 86 | def __post_init__(self): 87 | # self.bos_token = "<|startoftext|>" 88 | # self.eos_token = "<|endoftext|>" 89 | # self.pad_token = "<|pad|>" 90 | self.tokenizer = AutoTokenizer.from_pretrained( 91 | self.model_name, 92 | # bos_token=self.bos_token, 93 | # eos_token=self.eos_token, 94 | # pad_token=self.pad_token, 95 | ) 96 | 97 | def run(self): 98 | if self.track: 99 | import wandb 100 | 101 | with wandb.init(project="wikitext"): 102 | wandb.config.update(self.__dict__) 103 | self._run() 104 | else: 105 | self._run() 106 | 107 | def _run(self): 108 | self.create_datasets() 109 | self.create_model() 110 | self.create_trainer() 111 | self.trainer.train() 112 | 113 | def create_datasets(self): 114 | raw_datasets = datasets.load_dataset( 115 | "wikitext", 116 | "wikitext-103-v1", 117 | split=["train[:100000]", "validation[:1000]"], 118 | # streaming=True 119 | ) 120 | print("raw datasets:") 121 | print(raw_datasets) 122 | context_length = self.context_length 123 | 124 | def tokenize(element): 125 | outputs = self.tokenizer( 126 | element["text"], 127 | truncation=True, 128 | max_length=context_length, 129 | return_overflowing_tokens=False, 130 | return_length=True, 131 | ) 132 | input_batch = [] 133 | for length, input_ids in zip(outputs["length"], outputs["input_ids"]): 134 | if length == context_length: 135 | input_batch.append(input_ids) 136 | return {"input_ids": input_batch} 137 | 138 | tokenized_train = raw_datasets[0].map( 139 | tokenize, batched=True, remove_columns=raw_datasets[0].column_names 140 | ) 141 | tokenized_test = raw_datasets[1].map( 142 | tokenize, batched=True, remove_columns=raw_datasets[1].column_names 143 | ) 144 | 145 | print("tokenized:") 146 | print(tokenized_train) 147 | print("iterating:") 148 | for x in tokenized_train: 149 | print(x) 150 | break 151 | 152 | self.train_dataset = tokenized_train 153 | self.val_dataset = tokenized_test 154 | 155 | def create_model(self): 156 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name).cuda() 157 | self.model.resize_token_embeddings(len(self.tokenizer)) 158 | 159 | def create_trainer(self): 160 | time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 161 | training_args = 
TrainingArguments(
162 |             learning_rate=self.lr,
163 |             output_dir=f"./results/{time}",
164 |             num_train_epochs=self.epochs,
165 |             per_device_train_batch_size=self.batch_size,
166 |             evaluation_strategy="steps",
167 |             logging_strategy="steps",
168 |             gradient_accumulation_steps=10,
169 |             eval_steps=50,
170 |             logging_steps=50,
171 |             save_strategy=IntervalStrategy.NO,
172 |             # evaluation_strategy="epoch",
173 |             # logging_strategy="epoch",
174 |             report_to=["wandb"],
175 |             eval_delay=0,
176 |             per_device_eval_batch_size=self.batch_size,
177 |             eval_accumulation_steps=10,
178 |             # per_device_eval_batch_size=8,
179 |             # logging_steps=5000,
180 |             # logging_dir="./logs",
181 |             # save_strategy=IntervalStrategy.NO,
182 |             # warmup_steps=100,
183 |             # weight_decay=0.01,
184 |         )
185 |         self.tokenizer.pad_token = self.tokenizer.eos_token
186 |         data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
187 | 
188 |         def preprocess_logits_for_metrics(logits, labels):
189 |             if isinstance(logits, tuple):
190 |                 # Depending on the model and config, logits may contain extra tensors,
191 |                 # like past_key_values, but logits always come first
192 |                 logits = logits[0]
193 |             return logits.argmax(dim=-1)
194 | 
195 |         metric = evaluate.load("accuracy")
196 | 
197 |         def compute_metrics(eval_preds):
198 |             preds, labels = eval_preds
199 |             # preds have the same shape as the labels after the argmax(-1) computed
200 |             # by preprocess_logits_for_metrics, but we need to shift the labels
201 |             labels = labels[:, 1:].reshape(-1)
202 |             preds = preds[:, :-1].reshape(-1)
203 |             return metric.compute(predictions=preds, references=labels)
204 | 
205 |         self.trainer = Trainer(
206 |             model=self.model,
207 |             args=training_args,
208 |             train_dataset=self.train_dataset,
209 |             eval_dataset=self.val_dataset,
210 |             data_collator=data_collator,
211 |             compute_metrics=compute_metrics,
212 |             preprocess_logits_for_metrics=preprocess_logits_for_metrics
213 |             if training_args.do_eval and not is_torch_tpu_available()
214 |             else None,
215 |             # data_collator=lambda data: {
216 |             #     "input_ids": torch.stack([f["input_ids"] for f in data]),
217 |             #     "attention_mask": torch.stack([f["attention_mask"] for f in data]),
218 |             #     "labels": torch.stack([f["input_ids"] for f in data]),
219 |             # },
220 |         )
221 | 
222 | 
223 | class FederatedLearningCallback(TrainerCallback):
224 |     def __init__(self, federated_node, num_examples_per_epoch=1, **kwargs):
225 |         super().__init__(**kwargs)
226 |         self.node = federated_node
227 |         self.num_examples_per_epoch = num_examples_per_epoch
228 |         self.counted_epoch = 0
229 | 
230 |     def on_epoch_end(self, args, state, control, **kwargs):
231 |         print(f"Node {self.node.node_id} to begin federation at epoch end...")
232 |         model = kwargs["model"]
233 |         # epoch = state.epoch
234 |         epoch = self.counted_epoch
235 |         torch_model = model
236 |         device = torch_model.device
237 |         self.counted_epoch += 1  # advance the manual per-node epoch counter
238 | 
239 |         node_id = self.node.node_id
240 |         metrics = {}
241 |         # get model weights
242 | 
243 |         model_weights = list(torch_model.cpu().parameters())
244 |         model_weights = [w.detach().numpy() for w in model_weights]
245 |         params: Parameters = ndarrays_to_parameters(model_weights)
246 |         updated_params, updated_metrics = self.node.update_parameters(
247 |             params,
248 |             num_examples=self.num_examples_per_epoch,
249 |             epoch=epoch,
250 |             metrics=metrics,
251 |         )
252 |         self._federated_metrics = updated_metrics
253 | 
254 |         if updated_params is not None:
255 |             # set model weights
256 |             print("updating model weights using federation")
257 | 
updated_params = parameters_to_ndarrays(updated_params) 258 | model_weights = torch_model.parameters() 259 | for param, updated_param in zip(model_weights, updated_params): 260 | w = torch.from_numpy(updated_param) 261 | # w = w.to(device) 262 | param.data = nn.parameter.Parameter(w) 263 | 264 | torch_model.to(device) 265 | 266 | 267 | @dataclass 268 | class FederatedTrainingSession: 269 | # model_name: str = "EleutherAI/gpt-neo-125M" 270 | model_name: str = "EleutherAI/pythia-14M" 271 | num_nodes: int = 1 272 | context_length: int = 128 273 | batch_size: int = 16 274 | n_train_total: int = 12000 # 00 275 | use_async: bool = False 276 | track: bool = False 277 | 278 | def __post_init__(self): 279 | self.train_datasets = [] 280 | self.test_dataset = None 281 | 282 | def run(self): 283 | if self.track: 284 | import wandb 285 | 286 | with wandb.init(project="wikitext"): 287 | wandb.config.update(self.__dict__) 288 | self._run() 289 | 290 | else: 291 | self._run() 292 | 293 | def _run(self): 294 | self.create_random_partitioned_datasets() 295 | self.create_models() 296 | self.train_concurrently() 297 | 298 | def create_random_partitioned_datasets(self): 299 | # load wikitext 300 | self.test_dataset = datasets.load_dataset( 301 | "wikitext", 302 | "wikitext-103-v1", 303 | split="validation[:1000]", 304 | ) 305 | 306 | partitioned_datasets = [] 307 | n_train_total = self.n_train_total 308 | for i in range(self.num_nodes): 309 | start_idx = i * n_train_total // self.num_nodes 310 | end_idx = start_idx + n_train_total // self.num_nodes 311 | if i == self.num_nodes - 1: 312 | end_idx = n_train_total 313 | print(f"start={start_idx}, end={end_idx}") 314 | subset = datasets.load_dataset( 315 | "wikitext", 316 | "wikitext-103-v1", 317 | split=f"train[{start_idx}:{end_idx}]", 318 | ) 319 | partitioned_datasets.append(subset) 320 | self.train_datasets = partitioned_datasets 321 | print("training datasets:") 322 | for i, ds in enumerate(self.train_datasets): 323 | print(f"{i}: {len(ds)}") 324 | print(ds) 325 | print("test dataset:") 326 | print(len(self.test_dataset)) 327 | 328 | def create_models(self): 329 | self.federated_models = [] 330 | self.tokenizer = AutoTokenizer.from_pretrained( 331 | self.model_name, 332 | ) 333 | self.tokenizer.pad_token = self.tokenizer.eos_token 334 | # self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) 335 | for i in range(self.num_nodes): 336 | model = AutoModelForCausalLM.from_pretrained(self.model_name).cuda() 337 | model.resize_token_embeddings(len(self.tokenizer)) 338 | self.federated_models.append(model) 339 | return self.federated_models 340 | 341 | def train_concurrently(self): 342 | training_args = TrainingArguments( 343 | learning_rate=2e-5, 344 | output_dir="./results", 345 | num_train_epochs=3, 346 | per_device_train_batch_size=self.batch_size, 347 | evaluation_strategy="steps", 348 | logging_strategy="steps", 349 | gradient_accumulation_steps=10, 350 | eval_steps=50, 351 | logging_steps=50, 352 | save_strategy=IntervalStrategy.NO, 353 | # report_to=["wandb"], 354 | eval_delay=0, 355 | per_device_eval_batch_size=self.batch_size, 356 | eval_accumulation_steps=10, 357 | dataloader_drop_last=True, 358 | ) 359 | trainers = [] 360 | 361 | def preprocess_logits_for_metrics(logits, labels): 362 | if isinstance(logits, tuple): 363 | # Depending on the model and config, logits may contain extra tensors, 364 | # like past_key_values, but logits always come first 365 | logits = logits[0] 366 | return logits.argmax(dim=-1) 367 | 368 | 
accuracy_metric = evaluate.load("accuracy")
369 |         # perplexity_metric = evaluate.load("perplexity")
370 | 
371 |         def compute_metrics(eval_preds):
372 |             preds, labels = eval_preds
373 |             # preds have the same shape as the labels after the argmax(-1) computed
374 |             # by preprocess_logits_for_metrics, but we need to shift the labels
375 |             labels = labels[:, 1:].reshape(-1)
376 |             preds = preds[:, :-1].reshape(-1)
377 |             return accuracy_metric.compute(predictions=preds, references=labels)
378 | 
379 |         def tokenize(element):
380 |             outputs = self.tokenizer(
381 |                 element["text"],
382 |                 truncation=True,
383 |                 max_length=self.context_length,
384 |                 return_overflowing_tokens=False,
385 |                 return_length=True,
386 |             )
387 |             input_batch = []
388 |             for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
389 |                 if length == self.context_length:
390 |                     input_batch.append(input_ids)
391 |             return {"input_ids": input_batch}
392 | 
393 |         shared_folder = InMemoryFolder()
394 |         # shared_folder = LocalFolder(
395 |         #     os.path.join(os.getcwd(), "shared", str(time.time()))
396 |         # )
397 |         strategy = FedAvg()
398 |         for i in range(self.num_nodes):
399 |             data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
400 |             tokenized_train = self.train_datasets[i].map(
401 |                 tokenize,
402 |                 batched=True,
403 |                 # num_proc=4,
404 |                 remove_columns=self.train_datasets[i].column_names,
405 |             )
406 |             tokenized_test = self.test_dataset.map(
407 |                 tokenize,
408 |                 batched=True,
409 |                 # num_proc=4,
410 |                 remove_columns=self.test_dataset.column_names,
411 |             )
412 |             if self.use_async:
413 |                 node = AsyncFederatedNode(
414 |                     shared_folder=shared_folder, strategy=strategy
415 |                 )
416 |             else:
417 |                 node = SyncFederatedNode(
418 |                     shared_folder=shared_folder,
419 |                     strategy=strategy,
420 |                     num_nodes=self.num_nodes,
421 |                 )
422 | 
423 |             trainer = Trainer(
424 |                 model=self.federated_models[i],
425 |                 args=training_args,
426 |                 train_dataset=tokenized_train,
427 |                 eval_dataset=tokenized_test,  # evaluate on the held-out set, not the training split
428 |                 data_collator=data_collator,
429 |                 compute_metrics=compute_metrics,
430 |                 preprocess_logits_for_metrics=preprocess_logits_for_metrics
431 |                 if training_args.do_eval and not is_torch_tpu_available()
432 |                 else None,
433 |                 callbacks=[FederatedLearningCallback(node)],
434 |             )
435 |             trainers.append(trainer)
436 | 
437 |         # trainers[0].train()
438 |         with ThreadPoolExecutor(max_workers=self.num_nodes) as executor:
439 |             # wait for all trainers to finish
440 |             futures = []
441 |             for trainer in trainers:
442 |                 futures.append(executor.submit(trainer.train))
443 |             for future in futures:
444 |                 future.result()
445 | 
446 |         # eval on test set
447 |         print(trainers[0].model.device)
448 |         result = trainers[0].evaluate()
449 |         print(result)
450 |         if self.track:
451 |             import wandb
452 | 
453 |             wandb.log(result)
454 |             # wandb.config.update(self.__dict__)
455 | 
456 | 
457 | if __name__ == "__main__":
458 |     from dotenv import load_dotenv
459 | 
460 |     load_dotenv()
461 | 
462 |     # model = "EleutherAI/gpt-neo-125M"
463 |     model = "EleutherAI/pythia-14M"
464 |     # FederatedTrainingSession(
465 |     #     track=True, num_nodes=1, use_async=True, model_name=model
466 |     # ).run()
467 |     for use_async in [False]:
468 |         for num_nodes in [2]:
469 |             FederatedTrainingSession(
470 |                 model_name=model, track=True, num_nodes=num_nodes, use_async=use_async
471 |             ).run()
472 | 
-------------------------------------------------------------------------------- /tests/test_tf_training.py: --------------------------------------------------------------------------------
1 | import
json 2 | from typing import List, Tuple, Any 3 | import numpy as np 4 | from tensorflow.keras.datasets import mnist 5 | from tensorflow import keras 6 | from flwr.common import ( 7 | Code, 8 | FitRes, 9 | NDArrays, 10 | Parameters, 11 | Status, 12 | ndarrays_to_parameters, 13 | parameters_to_ndarrays, 14 | ) 15 | from flwr.server.client_proxy import ClientProxy 16 | from flwr.server.strategy import FedAvg, FedAdam, FedAvgM 17 | from uuid import uuid4 18 | from flwr_serverless.federated_node.async_federated_node import AsyncFederatedNode 19 | from flwr_serverless.federated_node.sync_federated_node import SyncFederatedNode 20 | from flwr_serverless.shared_folder.in_memory_folder import InMemoryFolder 21 | from flwr_serverless.shared_folder.local_folder import LocalFolder 22 | from flwr_serverless.keras.federated_learning_callback import FlwrFederatedCallback 23 | from flwr_serverless.keras.example import ( 24 | FederatedLearningTestRun, 25 | MnistModelBuilder, 26 | split_training_data_into_paritions, 27 | ) 28 | 29 | # os.environ["CUDA_VISIBLE_DEVICES"] = "" 30 | 31 | 32 | def test_mnist_training_clients_on_partitioned_data(): 33 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 34 | # x_train.shape: (60000, 28, 28) 35 | # print(y_train.shape) # (60000,) 36 | epochs = 6 37 | image_size = x_train.shape[1] 38 | batch_size = 32 39 | steps_per_epoch = 8 40 | x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) 41 | x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) 42 | x_train = x_train.astype(np.float32) / 255 43 | x_test = x_test.astype(np.float32) / 255 44 | 45 | model_standalone1 = MnistModelBuilder().run() 46 | model_standalone2 = MnistModelBuilder().run() 47 | 48 | partitioned_x_train, partitioned_y_train = split_training_data_into_paritions( 49 | x_train, y_train, num_partitions=2 50 | ) 51 | x_train_partition_1 = partitioned_x_train[0] 52 | y_train_partition_1 = partitioned_y_train[0] 53 | x_train_partition_2 = partitioned_x_train[1] 54 | y_train_partition_2 = partitioned_y_train[1] 55 | 56 | # Using generator for its ability to resume. This is important for federated learning, otherwise in each federated round, 57 | # the cursor starts from the beginning every time. 
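# (Each generator below wraps its partition in an infinite while-loop, so a
# later fit() call resumes at the batch where the previous round stopped
# instead of restarting from index 0.)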
58 |     def train_generator1(batch_size):
59 |         while True:
60 |             for i in range(0, len(x_train_partition_1), batch_size):
61 |                 yield x_train_partition_1[i : i + batch_size], y_train_partition_1[
62 |                     i : i + batch_size
63 |                 ]
64 | 
65 |     def train_generator2(batch_size):
66 |         while True:
67 |             for i in range(0, len(x_train_partition_2), batch_size):
68 |                 yield x_train_partition_2[i : i + batch_size], y_train_partition_2[
69 |                     i : i + batch_size
70 |                 ]
71 | 
72 |     train_loader_standalone1 = train_generator1(batch_size)
73 |     train_loader_standalone2 = train_generator2(batch_size)
74 |     model_standalone1.fit(
75 |         train_loader_standalone1, epochs=epochs, steps_per_epoch=steps_per_epoch
76 |     )
77 |     model_standalone2.fit(
78 |         train_loader_standalone2, epochs=epochs, steps_per_epoch=steps_per_epoch
79 |     )
80 |     _, accuracy_standalone1 = model_standalone1.evaluate(
81 |         x_test, y_test, batch_size=batch_size, steps=10
82 |     )
83 |     _, accuracy_standalone2 = model_standalone2.evaluate(
84 |         x_test, y_test, batch_size=batch_size, steps=10
85 |     )
86 |     assert accuracy_standalone1 < 0.55
87 |     assert accuracy_standalone2 < 0.55
88 | 
89 |     # federated learning
90 |     model_client1 = MnistModelBuilder().run()
91 |     model_client2 = MnistModelBuilder().run()
92 | 
93 |     # strategy = FedAvg()
94 |     strategy = FedAvgM()
95 |     # FedAdam does not work well in this setting.
96 |     # tmp_model = MnistModelBuilder().run()
97 |     # strategy = FedAdam(initial_parameters=ndarrays_to_parameters(tmp_model.get_weights()), eta=1e-1)
98 |     client_0 = None
99 |     client_1 = None
100 | 
101 |     num_federated_rounds = epochs
102 |     num_epochs_per_round = 1
103 |     train_loader_client1 = train_generator1(batch_size=batch_size)
104 |     train_loader_client2 = train_generator2(batch_size=batch_size)
105 |     for i_round in range(num_federated_rounds):
106 |         print("\n============ Round", i_round)
107 |         # The generator-based loaders resume where they left off between rounds.
108 |         model_client1.fit(
109 |             train_loader_client1,
110 |             epochs=num_epochs_per_round,
111 |             steps_per_epoch=steps_per_epoch,
112 |         )
113 |         model_client2.fit(
114 |             train_loader_client2,
115 |             epochs=num_epochs_per_round,
116 |             steps_per_epoch=steps_per_epoch,
117 |         )
118 |         num_examples = batch_size * 10
119 | 
120 |         param_0: Parameters = ndarrays_to_parameters(model_client1.get_weights())
121 |         param_1: Parameters = ndarrays_to_parameters(model_client2.get_weights())
122 | 
123 |         # Aggregation using the strategy.
124 |         results: List[Tuple[ClientProxy, FitRes]] = [
125 |             (
126 |                 client_0,
127 |                 FitRes(
128 |                     status=Status(code=Code.OK, message="Success"),
129 |                     parameters=param_0,
130 |                     num_examples=num_examples,
131 |                     metrics={},
132 |                 ),
133 |             ),
134 |             (
135 |                 client_1,
136 |                 FitRes(
137 |                     status=Status(code=Code.OK, message="Success"),
138 |                     parameters=param_1,
139 |                     num_examples=num_examples,
140 |                     metrics={},
141 |                 ),
142 |             ),
143 |         ]
144 | 
145 |         aggregated_parameters, _ = strategy.aggregate_fit(
146 |             server_round=i_round + 1, results=results, failures=[]
147 |         )
148 |         # Convert the aggregated Parameters back to numpy arrays for the keras models.
149 |         aggregated_parameters_numpy: NDArrays = parameters_to_ndarrays(
150 |             aggregated_parameters
151 |         )
152 |         # Update client model weights using the aggregated parameters.
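# (Setting both clients to the same aggregated weights reproduces, in-process,
# the broadcast step that a central flwr server would perform after
# aggregating client updates.)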
153 | model_client1.set_weights(aggregated_parameters_numpy) 154 | model_client2.set_weights(aggregated_parameters_numpy) 155 | 156 | _, accuracy_federated = model_client1.evaluate( 157 | x_test, y_test, batch_size=32, steps=10 158 | ) 159 | assert accuracy_federated > accuracy_standalone1 160 | assert accuracy_federated > accuracy_standalone2 161 | assert accuracy_federated > 0.6 # flaky test 162 | 163 | 164 | def test_mnist_training_standalone(): 165 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 166 | # x_train.shape: (60000, 28, 28) 167 | # print(y_train.shape) # (60000,) 168 | # Normalize 169 | image_size = x_train.shape[1] 170 | x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) 171 | x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) 172 | x_train = x_train.astype(np.float32) / 255 173 | x_test = x_test.astype(np.float32) / 255 174 | model = MnistModelBuilder().run() 175 | 176 | model.fit(x_train, y_train, epochs=3, batch_size=32, steps_per_epoch=10) 177 | # TODO: look into the history object to get accuracy 178 | # memorization test 179 | loss, accuracy = model.evaluate(x_test, y_test, batch_size=32, steps=10) 180 | # print(history[-1]) 181 | assert accuracy > 0.6 182 | 183 | 184 | def test_mnist_training_using_federated_nodes(): 185 | # epochs = standalone_epochs = 3 # does not work 186 | epochs = standalone_epochs = 8 # works 187 | 188 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 189 | # x_train.shape: (60000, 28, 28) 190 | # print(y_train.shape) # (60000,) 191 | # Normalize 192 | image_size = x_train.shape[1] 193 | batch_size = 32 194 | steps_per_epoch = 8 195 | 196 | x_train = np.reshape(x_train, [-1, image_size, image_size, 1]) 197 | x_test = np.reshape(x_test, [-1, image_size, image_size, 1]) 198 | x_train = x_train.astype(np.float32) / 255 199 | x_test = x_test.astype(np.float32) / 255 200 | 201 | model_standalone1 = MnistModelBuilder().run() 202 | model_standalone2 = MnistModelBuilder().run() 203 | 204 | partitioned_x_train, partitioned_y_train = split_training_data_into_paritions( 205 | x_train, y_train, num_partitions=2 206 | ) 207 | x_train_partition_1 = partitioned_x_train[0] 208 | y_train_partition_1 = partitioned_y_train[0] 209 | x_train_partition_2 = partitioned_x_train[1] 210 | y_train_partition_2 = partitioned_y_train[1] 211 | 212 | # Using generator for its ability to resume. This is important for federated learning, otherwise in each federated round, 213 | # the cursor starts from the beginning every time. 
    def train_generator1(batch_size):
        while True:
            for i in range(0, len(x_train_partition_1), batch_size):
                yield x_train_partition_1[i : i + batch_size], y_train_partition_1[
                    i : i + batch_size
                ]

    def train_generator2(batch_size):
        while True:
            for i in range(0, len(x_train_partition_2), batch_size):
                yield x_train_partition_2[i : i + batch_size], y_train_partition_2[
                    i : i + batch_size
                ]

    train_loader_standalone1 = train_generator1(batch_size)
    train_loader_standalone2 = train_generator2(batch_size)
    model_standalone1.fit(
        train_loader_standalone1, epochs=epochs, steps_per_epoch=steps_per_epoch
    )
    model_standalone2.fit(
        train_loader_standalone2, epochs=epochs, steps_per_epoch=steps_per_epoch
    )
    print("Evaluating on the combined test set:")
    _, accuracy_standalone1 = model_standalone1.evaluate(
        x_test, y_test, batch_size=batch_size, steps=10
    )
    _, accuracy_standalone2 = model_standalone2.evaluate(
        x_test, y_test, batch_size=batch_size, steps=10
    )
    assert accuracy_standalone1 < 0.55
    assert accuracy_standalone2 < 0.55

    # federated learning
    model_client1 = MnistModelBuilder().run()
    model_client2 = MnistModelBuilder().run()

    strategy = FedAvg()
    # strategy = FedAvgM()
    # FedAdam does not work well in this setting.
    # tmp_model = CreateMnistModel().run()
    # strategy = FedAdam(initial_parameters=ndarrays_to_parameters(tmp_model.get_weights()), eta=1e-1)

    num_federated_rounds = standalone_epochs
    num_epochs_per_round = 1
    train_loader_client1 = train_generator1(batch_size=batch_size)
    train_loader_client2 = train_generator2(batch_size=batch_size)

    storage_backend = InMemoryFolder()
    node1 = AsyncFederatedNode(shared_folder=storage_backend, strategy=strategy)
    node2 = AsyncFederatedNode(shared_folder=storage_backend, strategy=strategy)
    for i_round in range(num_federated_rounds):
        print("\n============ Round", i_round)
        model_client1.fit(
            train_loader_client1,
            epochs=num_epochs_per_round,
            steps_per_epoch=steps_per_epoch,
        )
        num_examples = batch_size * 10
        param_1: Parameters = ndarrays_to_parameters(model_client1.get_weights())
        updated_param_1, _ = node1.update_parameters(param_1, num_examples=num_examples)
        if updated_param_1 is not None:
            model_client1.set_weights(parameters_to_ndarrays(updated_param_1))
        else:
            print("node1 is waiting for other nodes to send their parameters")

        model_client2.fit(
            train_loader_client2,
            epochs=num_epochs_per_round,
            steps_per_epoch=steps_per_epoch,
        )
        num_examples = batch_size * 10
        param_2: Parameters = ndarrays_to_parameters(model_client2.get_weights())
        updated_param_2, _ = node2.update_parameters(param_2, num_examples=num_examples)
        if updated_param_2 is not None:
            model_client2.set_weights(parameters_to_ndarrays(updated_param_2))
        else:
            print("node2 is waiting for other nodes to send their parameters")

    print("Evaluating on the combined test set:")
    _, accuracy_federated = model_client1.evaluate(
        x_test, y_test, batch_size=32, steps=10
    )

    assert accuracy_federated > accuracy_standalone1
    assert accuracy_federated > accuracy_standalone2
    assert accuracy_federated > 0.6  # flaky test

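# The remaining tests drive the same MNIST setup through
# FederatedLearningTestRun, a test helper (defined elsewhere in this repo)
# that trains num_nodes models with federated callbacks and returns per-node
# standalone and federated accuracies; it also exposes the nodes, the
# storage backend, and the per-client keras callbacks for inspection.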
def test_mnist_federated_callback_2nodes():
    epochs = 8
    accuracy_standalone, accuracy_federated = FederatedLearningTestRun(
        num_nodes=2,
        epochs=epochs,
        num_rounds=epochs,
        lr=0.001,
        strategy=FedAvg(),
    ).run()
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[0] > accuracy_standalone[0]
    assert accuracy_federated[0] > 1.0 / len(accuracy_standalone) + 0.05


def test_mnist_federated_callback_2nodes_synchronously(tmpdir):
    epochs = 8
    local_shared_folder = InMemoryFolder()
    # local_shared_folder = LocalFolder(directory=str(tmpdir.join("fed_test")))
    session = FederatedLearningTestRun(
        num_nodes=2,
        epochs=epochs,
        num_rounds=epochs,
        lr=0.001,
        strategy=FedAvg(),
        train_concurrently=True,
        use_async_node=False,
        save_model_before_aggregation=True,
        storage_backend=local_shared_folder,
    )
    accuracy_standalone, accuracy_federated = session.run()
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[0] > accuracy_standalone[0]
    assert accuracy_federated[0] > 1.0 / len(accuracy_standalone) + 0.05

    # assert metrics files are tracked
    node_id = session.nodes[0].node_id
    raw_folder = session.storage_backend.get_raw_folder()
    json_bytes = raw_folder[f"keras/{node_id}/metrics_before_aggregation_00000.json"]
    assert json_bytes is not None
    metrics_dict = json.loads(json_bytes.decode("utf-8"))
    assert metrics_dict["loss"] > 0.0
    model_bytes = raw_folder[f"keras/{node_id}/model_before_aggregation_00000.h5"]
    assert model_bytes is not None
    assert len(model_bytes) > 0

    # assert the keras logs object has "*_fed" metrics
    first_callback = session.callbacks_per_client[0]
    assert "accuracy_fed" in first_callback.logs, f"{first_callback.logs}"
    assert "val_accuracy_fed" in first_callback.logs, f"{first_callback.logs}"


def test_mnist_federated_callback_3nodes():
    epochs = 8
    accuracy_standalone, accuracy_federated = FederatedLearningTestRun(
        num_nodes=3,
        epochs=epochs,
        num_rounds=epochs,
        lr=0.001,
        strategy=FedAvg(),
    ).run()
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[0] > accuracy_standalone[0]
    assert accuracy_federated[0] > 1.0 / len(accuracy_standalone) + 0.05


def test_mnist_federated_callback_2nodes_lag0_1(tmpdir):
    epochs = 10
    num_nodes = 2
    accuracy_standalone, accuracy_federated = FederatedLearningTestRun(
        num_nodes=num_nodes,
        epochs=epochs,
        num_rounds=epochs,
        batch_size=32,
        steps_per_epoch=8,
        lr=0.001,
        strategy=FedAvg(),
        # storage_backend=InMemoryFolder(),
        storage_backend=LocalFolder(directory=str(tmpdir.join("fed_test"))),
        train_pseudo_concurrently=True,
        use_async_node=True,
        lag=0.1,
    ).run()
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[-1] > accuracy_standalone[-1]
    assert accuracy_federated[-1] > 1.0 / num_nodes + 0.05

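# Same setup as the lag=0.1 test above, but with a larger lag between the
# pseudo-concurrent node updates.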
def test_mnist_federated_callback_2nodes_lag2(tmpdir):
    epochs = 10
    num_nodes = 2
    accuracy_standalone, accuracy_federated = FederatedLearningTestRun(
        num_nodes=num_nodes,
        epochs=epochs,
        num_rounds=epochs,
        batch_size=32,
        steps_per_epoch=8,
        lr=0.001,
        strategy=FedAvg(),
        storage_backend=InMemoryFolder(),
        # storage_backend=LocalFolder(directory=str(tmpdir.join("fed_test"))),
        train_pseudo_concurrently=True,
        use_async_node=True,
        lag=2,
    ).run()
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[-1] > accuracy_standalone[-1]
    assert accuracy_federated[-1] > 1.0 / num_nodes + 0.05


def test_mnist_federated_callback_2nodes_concurrent(tmpdir):
    epochs = 8
    num_nodes = 2
    fed_dir = tmpdir.join("fed_test")
    accuracy_standalone, accuracy_federated = FederatedLearningTestRun(
        num_nodes=num_nodes,
        epochs=epochs,
        num_rounds=epochs,
        batch_size=32,
        steps_per_epoch=8,
        lr=0.001,
        strategy=FedAvg(),
        # storage_backend=InMemoryFolder(),
        storage_backend=LocalFolder(directory=str(fed_dir)),
        train_concurrently=True,
        # use_async_node=False,
        use_async_node=True,
    ).run()
    # print(fed_dir.listdir())
    for i in range(len(accuracy_standalone)):
        assert accuracy_standalone[i] < 1.0 / len(accuracy_standalone) + 0.05

    assert accuracy_federated[-1] > accuracy_standalone[-1]
    assert accuracy_federated[-1] > 1.0 / num_nodes + 0.05


if __name__ == "__main__":
    test_mnist_federated_callback_2nodes_concurrent()

--------------------------------------------------------------------------------