├── multirun └── .gitkeep ├── outputs └── .gitkeep ├── wandb └── .gitkeep ├── plotting ├── plots │ └── .gitkeep ├── tables │ └── .gitkeep └── data_analysis.ipynb ├── .python-version ├── project ├── fed │ ├── server │ │ ├── strategy │ │ │ ├── strategy.py │ │ │ └── __init__.py │ │ ├── server.py │ │ ├── __init__.py │ │ ├── deterministic_client_manager.py │ │ ├── wandb_history.py │ │ └── wandb_server.py │ ├── __init__.py │ └── utils │ │ └── __init__.py ├── task │ ├── default │ │ ├── __init__.py │ │ ├── models.py │ │ ├── dataset_preparation.py │ │ ├── dataset.py │ │ ├── dispatch.py │ │ └── train_test.py │ ├── mnist_classification │ │ ├── __init__.py │ │ ├── models.py │ │ ├── dataset.py │ │ ├── dispatch.py │ │ ├── train_test.py │ │ └── dataset_preparation.py │ └── __init__.py ├── conf │ ├── strategy │ │ ├── fedavg.yaml │ │ ├── fedadagrad.yaml │ │ ├── fedavgm.yaml │ │ ├── fedadam.yaml │ │ └── fedyogi.yaml │ ├── dataset │ │ ├── default.yaml │ │ └── mnist.yaml │ ├── fed │ │ ├── default.yaml │ │ └── mnist.yaml │ ├── task │ │ ├── default.yaml │ │ └── mnist.yaml │ ├── mnist.yaml │ └── base.yaml ├── __init__.py ├── types │ ├── __init__.py │ └── common.py ├── utils │ ├── __init__.py │ └── utils.py ├── dispatch │ ├── __init__.py │ └── dispatch.py ├── client │ ├── __init__.py │ └── client.py └── main.py ├── run_scripts └── launch.sh ├── .github ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── feature_request.yml │ └── bug_report.yml ├── PULL_REQUEST_TEMPLATE.md └── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── setup.sh ├── .pre-commit-config.yaml ├── .gitignore ├── pyproject.toml ├── EXTENDED_README.md ├── LICENSE └── README.md /multirun/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outputs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wandb/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plotting/plots/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.6 2 | -------------------------------------------------------------------------------- /plotting/tables/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/fed/server/strategy/strategy.py: -------------------------------------------------------------------------------- 1 | """A custom strategy.""" 2 | -------------------------------------------------------------------------------- /project/task/default/__init__.py: -------------------------------------------------------------------------------- 1 | """The task template containing generic functionality.""" 2 | -------------------------------------------------------------------------------- /project/conf/strategy/fedavg.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: FedAvg 3 | 4 | init: 5 | _target_: flwr.server.strategy.FedAvg 6 | -------------------------------------------------------------------------------- 
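The strategy configs above and below follow Hydra's `_target_` convention: the `init` node names a class and its constructor arguments. A minimal sketch of how such a node is materialized (assuming `hydra-core`, `omegaconf`, and `flwr` are installed; the inline YAML mirrors `fedavg.yaml` above, whereas in the template Hydra composes it through the `defaults` list):

```python
# Sketch: turn a strategy config node into a Flower strategy instance.
# `instantiate` imports the `_target_` class and calls it with the
# remaining keys of the node as keyword arguments.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
name: FedAvg
init:
  _target_: flwr.server.strategy.FedAvg
"""
)
strategy = instantiate(cfg.init)
print(type(strategy).__name__)  # -> FedAvg
```

The extra keys in the other strategy files (e.g. `eta` and `tau` for FedAdagrad) are forwarded to the corresponding constructor in the same way.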
/run_scripts/launch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Add your code for the job manager
3 | 
4 | poetry run python -m project.main --config-name=base
5 | 
--------------------------------------------------------------------------------
/project/fed/server/server.py:
--------------------------------------------------------------------------------
1 | """Optionally define a new Server class inheriting from WandbServer.
2 | 
3 | Please note this is not needed in most settings.
4 | """
5 | 
--------------------------------------------------------------------------------
/project/conf/strategy/fedadagrad.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: FedAdagrad
3 | 
4 | init:
5 |   _target_: flwr.server.strategy.FedAdagrad
6 |   eta: 0.1
7 |   tau: 0.01
8 | 
--------------------------------------------------------------------------------
/project/fed/__init__.py:
--------------------------------------------------------------------------------
1 | """Functionality for federated learning including strategies, servers and utils.
2 | 
3 | You may not need to modify this functionality.
4 | """
5 | 
--------------------------------------------------------------------------------
/project/conf/strategy/fedavgm.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: FedAvgM
3 | 
4 | init:
5 |   _target_: flwr.server.strategy.FedAvgM
6 |   server_learning_rate: 0.31622
7 |   server_momentum: 0.9
8 | 
--------------------------------------------------------------------------------
/project/fed/server/__init__.py:
--------------------------------------------------------------------------------
1 | """Control FL-related functions relating to the server, strategies and client manager.
2 | 
3 | You may not need to change any of this code.
4 | """
5 | 
--------------------------------------------------------------------------------
/project/fed/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utilities for federated learning.
2 | 
3 | Examples: setting the parameters of a project,
4 | computing the norm of a set of parameters, etc.
5 | """
6 | 
--------------------------------------------------------------------------------
/project/__init__.py:
--------------------------------------------------------------------------------
1 | """The root of the project with the hydra entry point.
2 | 
3 | All logic should flow through hydra. Type everything using interfaces and ABCs rather
4 | than concretions.
5 | """ 6 | -------------------------------------------------------------------------------- /project/conf/strategy/fedadam.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: FedAdam 3 | 4 | init: 5 | _target_: flwr.server.strategy.FedAdam 6 | beta_1: 0.9 7 | beta_2: 0.99 8 | eta: 0.01 9 | tau: 0.001 10 | -------------------------------------------------------------------------------- /project/conf/strategy/fedyogi.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: FedYogi 3 | 4 | init: 5 | _target_: flwr.server.strategy.FedYogi 6 | beta_1: 0.9 7 | beta_2: 0.99 8 | eta: 0.01 9 | tau: 0.001 10 | -------------------------------------------------------------------------------- /project/task/mnist_classification/__init__.py: -------------------------------------------------------------------------------- 1 | """The MNIST classification task example. 2 | 3 | This task is meant to showcase the expected code structure for a scientific research 4 | project. It can be run entirely on the CPU. 5 | """ 6 | -------------------------------------------------------------------------------- /project/task/__init__.py: -------------------------------------------------------------------------------- 1 | """Task code including datasets, models and train/test loops. 2 | 3 | The default task serves as a template and may contain functionality likely to generalise 4 | across tasks. Copy it for adding a new task. 5 | """ 6 | -------------------------------------------------------------------------------- /project/fed/server/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | """Optionally define a custom strategy. 2 | 3 | Needed only when the strategy is not yet implemented in Flower or because you want to 4 | extend or modify the functionality of an existing strategy. 5 | """ 6 | -------------------------------------------------------------------------------- /project/types/__init__.py: -------------------------------------------------------------------------------- 1 | """The typing module for the project. 2 | 3 | Should contain types likely to be shared, re-used or that will interact with hydra. 4 | Always prefer an interface from this module over an ad-hoc inline definition or concrete 5 | type. 6 | """ 7 | -------------------------------------------------------------------------------- /project/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Define any utility function. 2 | 3 | They are not directly relevant to the other (more FL specific) python modules. For 4 | example, you may define here things like: loading a model from a checkpoint, saving 5 | results, plotting. 6 | """ 7 | -------------------------------------------------------------------------------- /project/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | """Dynamically dispatch functionality from configs. 2 | 3 | Handles mappign from strings specified in the hydra config to the task functions used in 4 | the project. Dispatching at this top level should simply select the correct functions 5 | from each task. 
6 | """ 7 | -------------------------------------------------------------------------------- /project/client/__init__.py: -------------------------------------------------------------------------------- 1 | """Define your client class and a function to construct such clients. 2 | 3 | Please overwrite `flwr.client.NumPyClient` or `flwr.client.Client` and create a function 4 | to instantiate your client. Make sure the model and dataset are not loaded before the 5 | fit function. 6 | """ 7 | -------------------------------------------------------------------------------- /project/conf/dataset/default.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Where to store data 3 | dataset_dir: null 4 | 5 | # Folder of client folders storing train and test 6 | # Optional, you can also use other methods to partition data to dataloaders 7 | # partition_dir: null 8 | 9 | # How large should the test set of each client 10 | # be relative to the train set 11 | val_ratio: 0.1 12 | 13 | # Seed for partition generation 14 | seed: 1337 15 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | 4 | updates: 5 | - package-ecosystem: pip 6 | directory: / 7 | schedule: 8 | interval: daily 9 | allow: 10 | # Allow only updates for dev dependencies 11 | - dependency-type: development 12 | ignore: null 13 | # Ignore updates from certain packages 14 | # - dependency-name: "grpcio-tools" 15 | # - dependency-name: "mypy-protobuf" 16 | # - dependency-name: "types-protobuf" 17 | open-pull-requests-limit: 3 18 | 19 | - package-ecosystem: github-actions 20 | directory: / 21 | schedule: 22 | # Check for updates to GitHub Actions every week 23 | interval: weekly 24 | -------------------------------------------------------------------------------- /project/conf/dataset/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Where to store data 3 | dataset_dir: ./data/mnist/data 4 | 5 | # Folder of client folders 6 | # indexed by id, containing train and test 7 | # data in .pt format 8 | partition_dir: ./data/mnist/partition 9 | 10 | # How many clients to create 11 | num_clients: 10 12 | 13 | # How large should the test set of each client 14 | # be relative to the train set 15 | val_ratio: 0.1 16 | 17 | # Seed for partition generation 18 | seed: 1337 19 | 20 | # If the partition labels 21 | # should be independent and identically distributed 22 | iid: false 23 | 24 | # If the partition labels should follow a power law 25 | # distribution 26 | power_law: true 27 | 28 | # If the partition labels should be balanced 29 | balance: false 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | description: Suggest a new baseline, strategy, example, ... 4 | labels: [feature request] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: > 10 | #### If you want to propose a new feature, please check the PRs if someone already 11 | works on this feature. 12 | - type: textarea 13 | attributes: 14 | label: Describe the type of feature and its functionality. 
15 |     validations:
16 |       required: true
17 |   - type: textarea
18 |     attributes:
19 |       label: Describe step by step which files and adjustments you are planning to include.
20 |     validations:
21 |       required: true
22 |   - type: textarea
23 |     attributes:
24 |       label: Is there something else you want to add?
25 | 
--------------------------------------------------------------------------------
/project/task/default/models.py:
--------------------------------------------------------------------------------
1 | """Define our models, and training and eval functions."""
2 | 
3 | from omegaconf import DictConfig
4 | from torch import nn
5 | 
6 | from project.types.common import IsolatedRNG
7 | 
8 | 
9 | class Net(nn.Module):
10 |     """A PyTorch model."""
11 | 
12 |     # TODO: define your model here
13 | 
14 | 
15 | def get_net(
16 |     _config: dict,
17 |     rng_tuple: IsolatedRNG,
18 |     _hydra_config: DictConfig | None,
19 | ) -> nn.Module:
20 |     """Return a model instance.
21 | 
22 |     Args:
23 |         _config: A dictionary with the model configuration.
24 |         rng_tuple: The random number generator state for the training.
25 |             Use if you need seeded random behavior.
26 |         _hydra_config: The full hydra config, if the model needs it.
27 | 
28 |     Returns
29 |     -------
30 |     nn.Module
31 |         A PyTorch model.
32 |     """
33 |     return Net()
34 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Federated Learning Research Template
2 | 
3 | We welcome contributions!
4 | 
5 | ## Legal Notice
6 | 
7 | All contributions must and will be licensed under [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0.html).
8 | 
9 | ## Code of Conduct
10 | 
11 | This project has adopted the [Contributor Covenant](https://www.contributor-covenant.org/) as its Code of Conduct. All community members are expected to adhere to it. Please see [CODE_OF_CONDUCT.md](.github/CODE_OF_CONDUCT.md) for details.
12 | 
13 | ## Code Review & Acceptance Process
14 | 
15 | All contributions, including contributions from core project members, require code review. The process is simple: Open a [Pull Request on GitHub](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests) and ensure all CI checks pass. One of the responsible code owners will review your contribution.
16 | -------------------------------------------------------------------------------- /project/conf/fed/default.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Federated setup 3 | num_rounds: 5 4 | num_total_clients: 10 5 | num_clients_per_round: 2 6 | num_evaluate_clients_per_round: 2 7 | 8 | # If clients should be sampled with replacement 9 | # when the dataset is not large enough 10 | enable_resampling: false 11 | 12 | # Client resources 13 | cpus_per_client: 1 14 | gpus_per_client: 0 15 | 16 | # Seed for client selection 17 | seed: 1337 18 | 19 | # Settings for loading the initial parameters 20 | # used by the server 21 | 22 | # If the server should try to load saved parameters and rng, if it fails it will do the normal procedure of generating a random net 23 | # Leave on as true in case you need to enable checkpoints later on 24 | load_saved_state: true 25 | # The round from which to load the parameters 26 | # if null it will load the most recent round 27 | server_round: null 28 | 29 | # Path to the folder where the parameters are located, leave null if you want to use automatic detection from the results folder 30 | parameters_folder: null 31 | 32 | # Path to the folder where the random number generators are located, leave null if you want to use automatic detection from the results folder 33 | rng_folder: null 34 | 35 | history_folder: null 36 | -------------------------------------------------------------------------------- /project/conf/fed/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Federated setup 3 | num_rounds: 5 4 | num_total_clients: 10 5 | num_clients_per_round: 2 6 | num_evaluate_clients_per_round: 2 7 | 8 | # If clients should be sampled with replacement 9 | # when the dataset is not large enough 10 | enable_resampling: false 11 | 12 | # Client resources 13 | cpus_per_client: 2 14 | gpus_per_client: 0 15 | 16 | # Seed for client selection 17 | seed: 1337 18 | 19 | # Settings for loading the initial parameters 20 | # used by the server 21 | 22 | # If the server should try to load saved parameters and rng, if it fails it will do the normal procedure of generating a random net 23 | # Leave on as true in case you need to enable checkpoints later on 24 | load_saved_state: true 25 | # The round from which to load the parameters 26 | # if null it will load the most recent round 27 | server_round: null 28 | 29 | # Path to the folder where the parameters are located, leave null if you want to use automatic detection from the results folder 30 | parameters_folder: null 31 | 32 | # Path to the folder where the random number generators are located, leave null if you want to use automatic detection from the results folder 33 | rng_folder: null 34 | 35 | history_folder: null 36 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | ## Issue 8 | 9 | ### Description 10 | 11 | 16 | 17 | ### Related issues/PRs 18 | 19 | 24 | 25 | ## Proposal 26 | 27 | ### Explanation 28 | 29 | 34 | 35 | ### Checklist 36 | 37 | - [ ] Implement proposed change 38 | - [ ] Write tests 39 | - [ ] Make CI checks pass 40 | - [ ] Ping maintainers (@Iacob-Alexandru-Andrei, @relogu) 41 | 42 | ### Any other comments? 
43 | 44 | 52 | -------------------------------------------------------------------------------- /project/task/default/dataset_preparation.py: -------------------------------------------------------------------------------- 1 | """Handle the dataset partitioning and (optionally) complex downloads. 2 | 3 | Please add here all the necessary logic to either download, uncompress, pre/post-process 4 | your dataset (or all of the above). If the desired way of running your code is to first 5 | download the dataset and partition it and then run the experiments, please uncomment the 6 | lines below and tell us in the README.md (see the "Running the Experiment" block) that 7 | this file should be executed first. 8 | """ 9 | 10 | # import hydra 11 | # from hydra.core.hydra_config import HydraConfig 12 | # from hydra.utils import call, instantiate 13 | # from omegaconf import DictConfig, OmegaConf 14 | # from flwr.common.logger import log 15 | # import logging 16 | 17 | 18 | # @hydra.main(config_path="../../conf", config_name="base", version_base=None) 19 | # def download_and_preprocess(cfg: DictConfig) -> None: 20 | # """Does everything needed to get the dataset. 21 | 22 | # Parameters 23 | # ---------- 24 | # cfg : DictConfig 25 | # An omegaconf object that stores the hydra config. 26 | # """ 27 | 28 | # ## 1. print parsed config 29 | # log(logging.INFO, OmegaConf.to_yaml(cfg)) 30 | 31 | # # Please include here all the logic 32 | # # Please use the Hydra config style as much as possible specially 33 | # # for parts that can be customised (e.g. how data is partitioned) 34 | 35 | # if __name__ == "__main__": 36 | 37 | # download_and_preprocess() 38 | -------------------------------------------------------------------------------- /project/conf/task/default.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # These strings are constants used by the dispatcher to select functionality at runtime 3 | # Please implement all behaviour in the task-level dispatch.py file and then add the dispatch functions to the top-level dispatch.py 4 | # Choose the model and dataset 5 | model_and_data: DEFAULT 6 | # Choose the train, test and server fed_eval functions 7 | train_structure: DEFAULT 8 | 9 | client_gen: DEFAULT 10 | client_manager: DEFAULT 11 | server: DEFAULT 12 | 13 | # Client fit config 14 | fit_config: 15 | # Default net is empty and takes no args 16 | net_config: {} 17 | # Default dataloader is empty, added just for completeness 18 | dataloader_config: 19 | batch_size: 1 20 | # Default train does nothing 21 | run_config: {} 22 | extra: {} 23 | 24 | # Client eval config 25 | eval_config: 26 | net_config: {} 27 | # The batch size for testing can be as high as the GPU supports 28 | dataloader_config: 29 | batch_size: 8 30 | run_config: {} 31 | extra: {} 32 | 33 | # Configuration for the federated testing function 34 | # Follows the same conventions as the client config 35 | fed_test_config: 36 | net_config: {} 37 | # Testing batch size can be as high as the GPU supports 38 | dataloader_config: 39 | batch_size: 8 40 | run_config: {} 41 | extra: {} 42 | 43 | # Configuration instructions for initial parameter 44 | # generation 45 | net_config_initial_parameters: {} 46 | 47 | # The names of metrics you wish to aggregate 48 | # E.g., train_loss, test_accuracy 49 | fit_metrics: [] 50 | evaluate_metrics: [] 51 | -------------------------------------------------------------------------------- /plotting/data_analysis.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\"An example notebook for a data_analysis notebook you may write.\n", 10 | "\n", 11 | "Please use the recommended TNR font unless your venue requires otherwise.\n", 12 | "\"\"\"\n", 13 | "\n", 14 | "# Write down all of your code for data analysis in this file.\n", 15 | "# Plot the data and save the figures/tables in the folders \"plots\" and \"tables\"\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import seaborn as sns\n", 18 | "\n", 19 | "# Use TNR for all figures\n", 20 | "# to match paper templates\n", 21 | "plt.rcParams[\"font.family\"] = \"serif\"\n", 22 | "plt.rcParams[\"font.serif\"] = [\n", 23 | " \"Times New Roman\",\n", 24 | "] + plt.rcParams[\"font.serif\"]\n", 25 | "\n", 26 | "# Whitegrid is most appropriate\n", 27 | "# for scientific papers\n", 28 | "sns.set_style(\"whitegrid\")\n", 29 | "\n", 30 | "# An optional colorblind palette\n", 31 | "# for figures\n", 32 | "CB_color_cycle = [\n", 33 | " \"#377EB8\",\n", 34 | " \"#FF7F00\",\n", 35 | " \"#4DAF4A\",\n", 36 | " \"#F781BF\",\n", 37 | " \"#A65628\",\n", 38 | " \"#984EA3\",\n", 39 | " \"#999999\",\n", 40 | " \"#E41A1C\",\n", 41 | " \"#DEDE00\",\n", 42 | "]" 43 | ] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": ".venv", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "name": "python", 54 | "version": "3.11.6" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | -------------------------------------------------------------------------------- /project/conf/task/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # These strings are constants used by the dispatcher to select functionality at runtime 3 | # Please implement all behaviour in the task-level dispatch.py file and then add the dispatch functions to the top-level dispatch.py 4 | # Choose the model and dataset 5 | model_and_data: MNIST_CNN 6 | # Choose the train, test and server fed_eval functions 7 | train_structure: MNIST 8 | 9 | client_gen: DEFAULT 10 | client_manager: DEFAULT 11 | server: DEFAULT 12 | 13 | # Client fit config 14 | fit_config: 15 | # Net does not require any configuration 16 | net_config: {} 17 | # Dataloader requires batch_size 18 | dataloader_config: 19 | batch_size: 4 20 | # The train function requires epochs and learning_rate 21 | run_config: 22 | epochs: 1 23 | learning_rate: 0.03 24 | # No extra config 25 | extra: {} 26 | 27 | # Client eval config 28 | eval_config: 29 | net_config: {} 30 | # The testing function batch size can be as high as the GPU supports 31 | dataloader_config: 32 | batch_size: 8 33 | # Unlike train, the mnist train function takes no parameters 34 | run_config: {} 35 | extra: {} 36 | 37 | # Configuration for the federated testing function 38 | # Follows the same conventions as the client config 39 | fed_test_config: 40 | net_config: {} 41 | # The testing function batch size can be as high as the GPU supports 42 | dataloader_config: 43 | batch_size: 8 44 | # Unlike train, the mnist train function takes no parameters 45 | run_config: {} 46 | extra: {} 47 | 48 | # Configuration instructions for initial parameter 49 | # generation 50 | net_config_initial_parameters: {} 51 | 52 | # The names of metrics you wish to aggregate 53 | fit_metrics: [train_loss, 
train_accuracy] 54 | evaluate_metrics: [test_accuracy] 55 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | description: Create a report to help us reproduce and correct the bug 4 | labels: [bug] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: > 10 | #### Before submitting a bug, please make sure the issue hasn't been already 11 | addressed by searching through [the past issues] 12 | - type: textarea 13 | attributes: 14 | label: Describe the bug 15 | description: > 16 | A clear and concise description of what the bug is. 17 | validations: 18 | required: true 19 | - type: textarea 20 | attributes: 21 | label: Steps/Code to Reproduce 22 | description: | 23 | Please add a minimal code example that can reproduce the error when running it. 24 | placeholder: | 25 | ``` 26 | Sample code to reproduce the problem 27 | ``` 28 | validations: 29 | required: true 30 | - type: textarea 31 | attributes: 32 | label: Expected Results 33 | description: > 34 | Please paste or describe the expected results. 35 | placeholder: > 36 | Example: The server aggregated the parameters from all clients. 37 | validations: 38 | required: true 39 | - type: textarea 40 | attributes: 41 | label: Actual Results 42 | description: | 43 | Please paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full traceback** of the exception. 44 | placeholder: > 45 | Please paste or specifically describe the actual output or traceback. 46 | validations: 47 | required: true 48 | - type: markdown 49 | attributes: 50 | value: >- 51 | Thanks for contributing! 52 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VPOETRY_HOME="" 4 | VPYENV_ROOT="" 5 | VPYTHON_VERSION="3.11.6" 6 | 7 | if ! [ -x "$(command -v pyenv)" ]; then 8 | if [ -z "$VPYENV_ROOT" ]; then 9 | echo "PYENV_ROOT is empty, please add it to the script or input a home now." 10 | read -r VPYENV_ROOT 11 | else 12 | echo "PYENV_ROOT is not empty" 13 | fi 14 | 15 | curl https://pyenv.run | bash 16 | eval "$(PYENV_ROOT=$VPYENV_ROOT pyenv init -)" 17 | { 18 | echo "export PYENV_ROOT=\"$VPYENV_ROOT\"" 19 | echo "command -v pyenv >/dev/null || export PATH=\"$VPYENV_ROOT/bin:\$PATH\"" 20 | echo "eval \"$(pyenv init -)\"" 21 | } >>~/.bashrc 22 | { 23 | echo "export PYENV_ROOT=\"$VPYENV_ROOT\"" 24 | echo "command -v pyenv >/dev/null || export PATH=\"$VPYENV_ROOT/bin:\$PATH\"" 25 | echo "eval \"$(pyenv init -)\"" 26 | } >>~/.profile 27 | else 28 | echo "Pyenv is already installed" 29 | fi 30 | 31 | if [ -z "$(command -v poetry)" ]; then 32 | 33 | if [ -z "$VPOETRY_HOME" ]; then 34 | echo "POETRY_HOME is empty, please add it to the script or input a home now." 
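		# Prompt interactively for a POETRY_HOME value, mirroring the pyenv branch above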
35 | read -r VPOETRY_HOME 36 | else 37 | echo "POETRY_HOME is not empty" 38 | fi 39 | 40 | mkdir -p "$VPOETRY_HOME" 41 | curl -sSL https://install.python-poetry.org | POETRY_HOME=$VPOETRY_HOME python3 - 42 | else 43 | echo "Poetry is already installed" 44 | fi 45 | 46 | if pyenv versions | grep -q $VPYTHON_VERSION; then 47 | echo "Python $VPYTHON_VERSION is already installed" 48 | else 49 | pyenv install $VPYTHON_VERSION 50 | fi 51 | 52 | current_version=$(pyenv local) 53 | 54 | if [ "$current_version" = "$VPYTHON_VERSION" ]; then 55 | echo "The local Python version is already set to $VPYTHON_VERSION" 56 | else 57 | pyenv local $VPYTHON_VERSION 58 | fi 59 | poetry env use $VPYTHON_VERSION 60 | 61 | poetry install 62 | 63 | # Install pre-commit hooks 64 | poetry run pre-commit install 65 | 66 | # Command to run all pre-commit hooks 67 | poetry run pre-commit run --all-files --hook-stage push 68 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | default_install_hook_types: [pre-commit, pre-push] 3 | exclude: | 4 | (?x)^( 5 | outputs/| 6 | wandb/| 7 | multirun/| 8 | dev/| 9 | data/| 10 | .pre-commit-config.yaml 11 | )$ 12 | 13 | repos: 14 | - repo: https://github.com/astral-sh/ruff-pre-commit 15 | rev: v0.1.14 16 | hooks: 17 | - id: ruff 18 | types_or: [python, pyi, jupyter] 19 | args: [--fix, --preview, --extend-ignore=PLR0917,--extend-ignore=PLR0914] 20 | - repo: https://github.com/lyz-code/yamlfix/ 21 | rev: 1.16.0 22 | hooks: 23 | - id: yamlfix 24 | - repo: https://github.com/shellcheck-py/shellcheck-py 25 | rev: v0.10.0.1 26 | hooks: 27 | - id: shellcheck 28 | - repo: https://github.com/scop/pre-commit-shfmt 29 | rev: v3.8.0-1 30 | hooks: 31 | - id: shfmt 32 | - repo: https://github.com/psf/black 33 | rev: 24.1.1 34 | hooks: 35 | - id: black-jupyter 36 | language_version: python3.11 37 | args: [ --preview ] 38 | - repo: local 39 | hooks: 40 | - id: mypy 41 | name: mypy 42 | entry: poetry run mypy --incremental --show-traceback 43 | language: system 44 | types: [file, python] 45 | - repo: https://github.com/pre-commit/pre-commit-hooks 46 | rev: v4.5.0 # Use the ref you want to point at 47 | hooks: 48 | - id: check-added-large-files 49 | - id: check-ast 50 | - id: check-case-conflict 51 | - id: check-merge-conflict 52 | - id: detect-private-key 53 | 54 | ci: 55 | autofix_commit_msg: | 56 | [pre-commit.ci] auto fixes from pre-commit.com hooks 57 | 58 | for more information, see https://pre-commit.ci 59 | autofix_prs: true 60 | autoupdate_branch: "" 61 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" 62 | autoupdate_schedule: weekly 63 | skip: [mypy] 64 | submodules: false 65 | -------------------------------------------------------------------------------- /project/conf/mnist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | defaults: 3 | - _self_ 4 | - task: mnist 5 | - fed: mnist 6 | - strategy: fedavg 7 | - dataset: mnist 8 | 9 | # If checkpointing is enabled, 10 | # you may wish to save results to the same directory 11 | # just change the dir to {your_output_directory} 12 | hydra: 13 | run: 14 | dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | # dir: /path/to/your/output/directory 16 | 17 | # Working dir to save temporary files to 18 | # If null, defaults to hydra_dir/working 19 | working_dir: null 20 | 21 | # List of file patterns to be saved from working_dir 22 | # at the 
end of training
23 | to_save_once: [yaml, log, wandb]
24 | 
25 | # List of file patterns to be saved repeatedly
26 | to_save_per_round: [parameters, rng, history]
27 | 
28 | # The frequency with which they should be saved
29 | save_frequency: 1
30 | 
31 | # List of files to be copied over to the working dir
32 | # from the results dir when using checkpoints
33 | # all the crucial components (rng, parameters, history)
34 | # are handled separately
35 | to_restore: []
36 | 
37 | # List of file patterns to be deleted
38 | # prior to and at the end of training from working_dir
39 | to_clean_once: [history, parameters, yaml, log, rng]
40 | 
41 | # Control wandb logging
42 | use_wandb: false
43 | 
44 | # Test without Ray to enable easy error detection
45 | debug_clients:
46 |   all: false
47 |   one: true
48 | 
49 | # Whether to automatically resume wandb runs
50 | wandb_resume: true
51 | 
52 | # The id of the wandb run to resume
53 | # If null and wandb_resume, tries to detect
54 | # the wandb_id from the hydra config files
55 | # if using checkpointing, otherwise
56 | # creates a new run
57 | wandb_id: null
58 | 
59 | # Wandb configuration
60 | # add whatever tags you like
61 | # change the name
62 | wandb:
63 |   setup:
64 |     project: template
65 |     tags: ['strategy_${strategy.name}', 'seed_${fed.seed}']
66 |     entity: null
67 |     mode: online
68 | 
69 | # For Ray cluster usage
70 | # leave null unless you need
71 | # multiple ray instances running
72 | ray_address: null
73 | ray_redis_password: null
74 | ray_node_ip_address: null
75 | 
76 | # When using checkpointing
77 | # automatically detect the most recent checkpoint
78 | # checks at most file_limit files up to depth 2
79 | # from the results directory
80 | # if null checks all files
81 | file_limit: 250
82 | 
--------------------------------------------------------------------------------
/project/conf/base.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | defaults:
3 |   - _self_
4 |   - task: default
5 |   - fed: default
6 |   - strategy: fedavg
7 |   - dataset: default
8 | 
9 | # If checkpointing is enabled,
10 | # you may wish to save results to the same directory
11 | # just change the dir to {your_output_directory}
12 | hydra:
13 |   run:
14 |     dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
15 |     # dir: /path/to/your/output/directory
16 | 
17 | # Working dir to save temporary files to
18 | # If null, defaults to hydra_dir/working
19 | working_dir: null
20 | 
21 | # List of file patterns to be saved from working_dir
22 | # at the end of training
23 | to_save_once: [yaml, log, wandb]
24 | 
25 | # List of file patterns to be saved repeatedly
26 | to_save_per_round: [parameters, rng, history]
27 | 
28 | # The frequency with which they should be saved
29 | save_frequency: 1
30 | 
31 | # List of exact file names to be copied over to the working dir
32 | # from the results dir when using checkpoints
33 | # all the crucial components (rng, parameters, history)
34 | # are handled separately
35 | to_restore: []
36 | 
37 | # List of file patterns to be deleted
38 | # prior to and at the end of training from working_dir
39 | to_clean_once: [history, parameters, yaml, log, rng]
40 | 
41 | # Control wandb logging
42 | use_wandb: false
43 | 
44 | # Test without Ray to enable easy error detection
45 | debug_clients:
46 |   all: false
47 |   one: true
48 | 
49 | # Whether to automatically resume wandb runs
50 | wandb_resume: ${use_wandb}
51 | 
52 | # The id of the wandb run to resume
53 | # If null and wandb_resume, tries to detect
54 | # the
wandb_id from the hydra config files 55 | # if using checkpointing, otherwise 56 | # creates a new run 57 | wandb_id: null 58 | 59 | # Wandb configuration 60 | # add whatever tags you like 61 | # change the name 62 | wandb: 63 | setup: 64 | project: template 65 | tags: ['strategy_${strategy.name}', 'seed_${fed.seed}'] 66 | entity: null 67 | mode: online 68 | 69 | # For Ray cluster usage 70 | # leave null unless you need 71 | # multiple ray instances running 72 | ray_address: null 73 | ray_redis_password: null 74 | ray_node_ip_address: null 75 | 76 | # When using checkpointing 77 | # automatically detect the most recent checkpoint 78 | # checks at most file_limit files up to depth 2 79 | # from the results directory 80 | # if null checks all files 81 | file_limit: 250 82 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behaviour that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behaviour by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct that could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behaviour and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behaviour. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned with this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviours that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be
53 | further defined and clarified by project maintainers.
54 | 
55 | ## Enforcement
56 | 
57 | Instances of abusive, harassing, or otherwise unacceptable behaviour may be
58 | reported by contacting the project team at {aai30, ls985}@cam.ac.uk. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality about the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 | 
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 | 
73 | [homepage]: https://www.contributor-covenant.org
74 | 
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 | 
--------------------------------------------------------------------------------
/project/task/mnist_classification/models.py:
--------------------------------------------------------------------------------
1 | """CNN model architecture, training, and testing functions for MNIST."""
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | 
7 | from project.types.common import NetGen
8 | from project.utils.utils import lazy_config_wrapper
9 | 
10 | 
11 | class Net(nn.Module):
12 |     """Convolutional Neural Network architecture.
13 | 
14 |     As described in the McMahan et al., 2017 paper:
15 | 
16 |     [Communication-Efficient Learning of Deep Networks from
17 |     Decentralized Data] (https://arxiv.org/pdf/1602.05629.pdf)
18 |     """
19 | 
20 |     def __init__(self, num_classes: int = 10) -> None:
21 |         """Initialize the network.
22 | 
23 |         Parameters
24 |         ----------
25 |         num_classes : int
26 |             Number of classes in the dataset.
27 | 
28 |         Returns
29 |         -------
30 |         None
31 |         """
32 |         super().__init__()
33 |         self.conv1 = nn.Conv2d(1, 32, 5, padding=1)
34 |         self.conv2 = nn.Conv2d(32, 64, 5, padding=1)
35 |         self.pool = nn.MaxPool2d(
36 |             kernel_size=(2, 2),
37 |             padding=1,
38 |         )
39 |         self.fc1 = nn.Linear(64 * 7 * 7, 512)
40 |         self.fc2 = nn.Linear(512, num_classes)
41 | 
42 |     def forward(
43 |         self,
44 |         input_tensor: torch.Tensor,
45 |     ) -> torch.Tensor:
46 |         """Forward pass of the CNN.
47 | 
48 |         Parameters
49 |         ----------
50 |         input_tensor : torch.Tensor
51 |             Input Tensor that will pass through the network
52 | 
53 |         Returns
54 |         -------
55 |         torch.Tensor
56 |             The resulting Tensor after it has passed through the network
57 |         """
58 |         output_tensor = F.relu(self.conv1(input_tensor))
59 |         output_tensor = self.pool(output_tensor)
60 |         output_tensor = F.relu(self.conv2(output_tensor))
61 |         output_tensor = self.pool(output_tensor)
62 |         output_tensor = torch.flatten(output_tensor, 1)
63 |         output_tensor = F.relu(self.fc1(output_tensor))
64 |         output_tensor = self.fc2(output_tensor)
65 |         return output_tensor
66 | 
67 | 
68 | # Simple wrapper to match the NetGenerator Interface
69 | get_net: NetGen = lazy_config_wrapper(Net)
70 | 
71 | 
72 | class LogisticRegression(nn.Module):
73 |     """A network for logistic regression using a single fully connected layer.
74 | 
75 |     As described in the Li et al., 2020 paper:
76 | 
77 |     [Federated Optimization in Heterogeneous Networks] (
78 | 
79 |     https://arxiv.org/pdf/1812.06127.pdf)
80 |     """
81 | 
82 |     def __init__(self, num_classes: int = 10) -> None:
83 |         """Initialize the network.
84 | 
85 |         Parameters
86 |         ----------
87 |         num_classes : int
88 |             Number of classes in the dataset.
89 | 
90 |         Returns
91 |         -------
92 |         None
93 |         """
94 |         super().__init__()
95 |         self.linear = nn.Linear(28 * 28, num_classes)
96 | 
97 |     def forward(
98 |         self,
99 |         input_tensor: torch.Tensor,
100 |     ) -> torch.Tensor:
101 |         """Forward pass.
102 | 
103 |         Parameters
104 |         ----------
105 |         input_tensor : torch.Tensor
106 |             Input Tensor that will pass through the network
107 | 
108 |         Returns
109 |         -------
110 |         torch.Tensor
111 |             The resulting Tensor after it has passed through the network
112 |         """
113 |         output_tensor = self.linear(
114 |             torch.flatten(input_tensor, 1),
115 |         )
116 |         return output_tensor
117 | 
118 | 
119 | # Simple wrapper to match the NetGenerator Interface
120 | get_logistic_regression: NetGen = lazy_config_wrapper(
121 |     LogisticRegression,
122 | )
123 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Hydra and wandb outputs
2 | data
3 | .vscode
4 | outputs/*
5 | multirun/*
6 | wandb/*
7 | run_scripts/launch_template.sh
8 | run_scripts/launch_mnist.sh
9 | pollen_worker
10 | *.pt
11 | *.npz
12 | *.bin
13 | 
14 | # Slurm logs
15 | slurm-*.out
16 | 
17 | 
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | 
23 | # C extensions
24 | *.so
25 | 
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | share/python-wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 | 
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 | 
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 | 
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .nox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | *.py,cover
67 | .hypothesis/
68 | .pytest_cache/
69 | cover/
70 | 
71 | # Translations
72 | *.mo
73 | *.pot
74 | 
75 | # Django stuff:
76 | *.log
77 | local_settings.py
78 | db.sqlite3
79 | db.sqlite3-journal
80 | 
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 | 
85 | # Scrapy stuff:
86 | .scrapy
87 | 
88 | # Sphinx documentation
89 | docs/_build/
90 | 
91 | # PyBuilder
92 | .pybuilder/
93 | target/
94 | 
95 | # Jupyter Notebook
96 | .ipynb_checkpoints
97 | 
98 | # IPython
99 | profile_default/
100 | ipython_config.py
101 | 
102 | # pyenv
103 | # For a library or package, you might want to ignore these files since the code is
104 | # intended to run in multiple environments; otherwise, check them in:
105 | # .python-version
106 | 
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock 113 | 114 | # poetry 115 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 116 | # This is especially recommended for binary packages to ensure reproducibility, and is more 117 | # commonly ignored for libraries. 118 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 119 | #poetry.lock 120 | 121 | # pdm 122 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 123 | #pdm.lock 124 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 125 | # in version control. 126 | # https://pdm.fming.dev/#use-with-ide 127 | .pdm.toml 128 | 129 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 130 | __pypackages__/ 131 | 132 | # Celery stuff 133 | celerybeat-schedule 134 | celerybeat.pid 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | # ruff 163 | .ruff_cache/ 164 | 165 | # Pyre type checker 166 | .pyre/ 167 | 168 | # pytype static type analyzer 169 | .pytype/ 170 | 171 | # Cython debug symbols 172 | cython_debug/ 173 | 174 | # PyCharm 175 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 176 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 177 | # and can be added to the global gitignore or merged into this file. For a more nuclear 178 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 179 | #.idea/ 180 | -------------------------------------------------------------------------------- /project/task/mnist_classification/dataset.py: -------------------------------------------------------------------------------- 1 | """MNIST dataset utilities for federated learning.""" 2 | 3 | from pathlib import Path 4 | 5 | from omegaconf import DictConfig 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from project.task.default.dataset import ( 10 | ClientDataloaderConfig as DefaultClientDataloaderConfig, 11 | ) 12 | from project.task.default.dataset import ( 13 | FedDataloaderConfig as DefaultFedDataloaderConfig, 14 | ) 15 | from project.types.common import ( 16 | CID, 17 | ClientDataloaderGen, 18 | FedDataloaderGen, 19 | IsolatedRNG, 20 | ) 21 | 22 | # Use defaults for this very simple dataset 23 | # Requires only batch size 24 | ClientDataloaderConfig = DefaultClientDataloaderConfig 25 | FedDataloaderConfig = DefaultFedDataloaderConfig 26 | 27 | 28 | def get_dataloader_generators( 29 | partition_dir: Path, 30 | ) -> tuple[ClientDataloaderGen, FedDataloaderGen]: 31 | """Return a function that loads a client's dataset. 32 | 33 | Parameters 34 | ---------- 35 | partition_dir : Path 36 | The path to the partition directory. 37 | Containing the training data of clients. 38 | Partitioned by client id. 39 | 40 | Returns 41 | ------- 42 | Tuple[ClientDataloaderGen, FedDataloaderGen] 43 | A tuple of functions that return a DataLoader for a client's dataset 44 | and a DataLoader for the federated dataset. 
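
    Examples
    --------
    A hypothetical call, assuming a partition prepared under the
    ``partition_dir`` from the dataset config (e.g. ./data/mnist/partition):

    >>> client_gen, fed_gen = get_dataloader_generators(
    ...     Path("./data/mnist/partition")
    ... )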
45 | """ 46 | 47 | def get_client_dataloader( 48 | cid: CID, 49 | test: bool, 50 | _config: dict, 51 | rng_tuple: IsolatedRNG, 52 | _hydra_config: DictConfig | None, 53 | ) -> DataLoader: 54 | """Return a DataLoader for a client's dataset. 55 | 56 | Parameters 57 | ---------- 58 | cid : str|int 59 | The client's ID 60 | test : bool 61 | Whether to load the test set or not 62 | _config : Dict 63 | The configuration for the dataset 64 | rng_tuple : IsolatedRNGTuple 65 | The random number generator state for the training. 66 | Use if you need seeded random behavior 67 | 68 | Returns 69 | ------- 70 | DataLoader 71 | The DataLoader for the client's dataset 72 | """ 73 | config: ClientDataloaderConfig = ClientDataloaderConfig(**_config) 74 | del _config 75 | 76 | torch_cpu_generator = rng_tuple[3] 77 | 78 | client_dir = partition_dir / f"client_{cid}" 79 | if not test: 80 | dataset = torch.load(client_dir / "train.pt") 81 | else: 82 | dataset = torch.load(client_dir / "test.pt") 83 | return DataLoader( 84 | dataset, 85 | batch_size=config.batch_size, 86 | shuffle=not test, 87 | generator=torch_cpu_generator, 88 | ) 89 | 90 | def get_federated_dataloader( 91 | test: bool, 92 | _config: dict, 93 | rng_tuple: IsolatedRNG, 94 | _hydra_config: DictConfig | None, 95 | ) -> DataLoader: 96 | """Return a DataLoader for federated train/test sets. 97 | 98 | Parameters 99 | ---------- 100 | test : bool 101 | Whether to load the test set or not 102 | config : Dict 103 | The configuration for the dataset 104 | rng_tuple : IsolatedRNGTuple 105 | The random number generator state for the training. 106 | Use if you need seeded random behavior 107 | 108 | Returns 109 | ------- 110 | DataLoader 111 | The DataLoader for the federated dataset 112 | """ 113 | config: FedDataloaderConfig = FedDataloaderConfig( 114 | **_config, 115 | ) 116 | del _config 117 | torch_cpu_generator = rng_tuple[3] 118 | 119 | if not test: 120 | return DataLoader( 121 | torch.load(partition_dir / "train.pt"), 122 | batch_size=config.batch_size, 123 | shuffle=not test, 124 | generator=torch_cpu_generator, 125 | ) 126 | 127 | return DataLoader( 128 | torch.load(partition_dir / "test.pt"), 129 | batch_size=config.batch_size, 130 | shuffle=not test, 131 | generator=torch_cpu_generator, 132 | ) 133 | 134 | return get_client_dataloader, get_federated_dataloader 135 | -------------------------------------------------------------------------------- /project/fed/server/deterministic_client_manager.py: -------------------------------------------------------------------------------- 1 | """A client manager that guarantees deterministic client sampling.""" 2 | 3 | import logging 4 | import random 5 | from typing import Any 6 | 7 | from flwr.common.logger import log 8 | from flwr.server.client_manager import SimpleClientManager 9 | from flwr.server.client_proxy import ClientProxy 10 | from flwr.server.criterion import Criterion 11 | from omegaconf import DictConfig 12 | 13 | 14 | class DeterministicClientManager(SimpleClientManager): 15 | """A deterministic client manager. 16 | 17 | Samples clients in the same order every time based on the seed. Also allows sampling 18 | with replacement. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | client_cid_generator: random.Random, 24 | hydra_config: DictConfig | None, 25 | enable_resampling: bool = False, 26 | ) -> None: 27 | """Initialize DeterministicClientManager. 28 | 29 | Parameters 30 | ---------- 31 | client_cid_generator : random.Random 32 | A random number generator to generate client cids. 
33 |         enable_resampling : bool
34 |             Whether to allow sampling with replacement.
35 | 
36 |         Returns
37 |         -------
38 |         None
39 |         """
40 |         super().__init__()
41 | 
42 |         self.client_cid_generator = client_cid_generator
43 |         self.enable_resampling = enable_resampling
44 |         self.hydra_config = hydra_config
45 | 
46 |     def sample(
47 |         self,
48 |         num_clients: int,
49 |         min_num_clients: int | None = None,
50 |         criterion: Criterion | None = None,
51 |     ) -> list[ClientProxy]:
52 |         """Sample a number of Flower ClientProxy instances.
53 | 
54 |         Guarantees deterministic client sampling and enables
55 |         sampling with replacement.
56 | 
57 |         Parameters
58 |         ----------
59 |         num_clients : int
60 |             The number of clients to sample.
61 |         min_num_clients : Optional[int]
62 |             The minimum number of clients to sample.
63 |         criterion : Optional[Criterion]
64 |             A criterion to select clients.
65 | 
66 |         Returns
67 |         -------
68 |         List[ClientProxy]
69 |             A list of sampled clients.
70 |         """
71 |         # Block until at least num_clients are connected.
72 |         if min_num_clients is None:
73 |             min_num_clients = num_clients
74 |         self.wait_for(min_num_clients)
75 | 
76 |         cids = list(self.clients)
77 | 
78 |         if criterion is not None:
79 |             cids = [cid for cid in cids if criterion.select(self.clients[cid])]
80 |         # Sample deterministically from the eligible clients
81 | 
82 |         available_cids = []
83 |         if num_clients <= len(cids):
84 |             available_cids = self.client_cid_generator.sample(
85 |                 cids,
86 |                 num_clients,
87 |             )
88 |         elif self.enable_resampling:
89 |             available_cids = self.client_cid_generator.choices(
90 |                 cids,
91 |                 k=num_clients,
92 |             )
93 |         else:
94 |             log(
95 |                 logging.INFO,
96 |                 "Sampling failed: number of available clients"
97 |                 " (%s) is less than number of requested clients (%s).",
98 |                 len(cids),
99 |                 num_clients,
100 |             )
101 |             available_cids = []
102 | 
103 |         client_list = [self.clients[cid] for cid in available_cids]
104 |         log(
105 |             logging.INFO,
106 |             "Sampled the following clients: %s",
107 |             available_cids,
108 |         )
109 | 
110 |         return client_list
111 | 
112 | 
113 | def dispatch_deterministic_client_manager(
114 |     cfg: DictConfig, **kwargs: Any
115 | ) -> type[DeterministicClientManager] | None:
116 |     """Dispatch the get_client_manager function based on the hydra config.
117 | 
118 |     Parameters
119 |     ----------
120 |     cfg : DictConfig
121 |         The configuration for the get_client_manager function.
122 |         Loaded dynamically from the config file.
123 |     kwargs : dict[str, Any]
124 |         Additional keyword arguments to pass to the get_client_manager function.
125 | 
126 |     Returns
127 |     -------
128 |     type[DeterministicClientManager] | None
129 |         The client manager class, or None if the config does not match.
130 |     """
131 |     client_manager: str | None = cfg.get("task", {}).get("client_manager", None)
132 | 
133 |     if client_manager is None:
134 |         return None
135 | 
136 |     if client_manager.upper() == "DEFAULT":
137 |         return DeterministicClientManager
138 | 
139 |     return None
140 | 
--------------------------------------------------------------------------------
/project/fed/server/wandb_history.py:
--------------------------------------------------------------------------------
1 | """History class which sends metrics to wandb.
2 | 
3 | Metrics are collected only at the central server, minimizing communication costs. Metric
4 | collection only happens if wandb is turned on.
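Pass use_wandb=False to fall back to the standard Flower History behaviour
without any wandb calls.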
5 | """ 6 | 7 | from flwr.common.typing import Scalar 8 | from flwr.server.history import History 9 | 10 | import wandb 11 | 12 | 13 | class WandbHistory(History): 14 | """History class for training and/or evaluation metrics collection.""" 15 | 16 | def __init__(self, use_wandb: bool = True) -> None: 17 | """Initialize the history. 18 | 19 | Parameters 20 | ---------- 21 | use_wandb : bool 22 | Whether to use wandb. 23 | Turn off to avoid communication overhead. 24 | 25 | Returns 26 | ------- 27 | None 28 | """ 29 | super().__init__() 30 | 31 | self.use_wandb = use_wandb 32 | 33 | def add_loss_distributed( 34 | self, 35 | server_round: int, 36 | loss: float, 37 | ) -> None: 38 | """Add one loss entry (from distributed evaluation) to history/wandb. 39 | 40 | Parameters 41 | ---------- 42 | server_round : int 43 | The current server round. 44 | loss : float 45 | The loss to add. 46 | 47 | Returns 48 | ------- 49 | None 50 | """ 51 | super().add_loss_distributed(server_round, loss) 52 | if self.use_wandb: 53 | wandb.log( 54 | {"distributed_loss": loss}, 55 | step=server_round, 56 | ) 57 | 58 | def add_loss_centralized( 59 | self, 60 | server_round: int, 61 | loss: float, 62 | ) -> None: 63 | """Add one loss entry (from centralized evaluation) to history/wandb. 64 | 65 | Parameters 66 | ---------- 67 | server_round : int 68 | The current server round. 69 | loss : float 70 | The loss to add. 71 | 72 | Returns 73 | ------- 74 | None 75 | """ 76 | super().add_loss_centralized(server_round, loss) 77 | if self.use_wandb: 78 | wandb.log( 79 | {"centralised_loss": loss}, 80 | step=server_round, 81 | ) 82 | 83 | def add_metrics_distributed_fit( 84 | self, 85 | server_round: int, 86 | metrics: dict[str, Scalar], 87 | ) -> None: 88 | """Add metrics entries (from distributed fit) to history/wandb. 89 | 90 | Parameters 91 | ---------- 92 | server_round : int 93 | The current server round. 94 | metrics : Dict[str, Scalar] 95 | The metrics to add. 96 | 97 | Returns 98 | ------- 99 | None 100 | """ 101 | super().add_metrics_distributed_fit( 102 | server_round, 103 | metrics, 104 | ) 105 | if self.use_wandb: 106 | for key in metrics: 107 | wandb.log( 108 | {key: metrics[key]}, 109 | step=server_round, 110 | ) 111 | 112 | def add_metrics_distributed( 113 | self, 114 | server_round: int, 115 | metrics: dict[str, Scalar], 116 | ) -> None: 117 | """Add metrics entries (from distributed evaluation) to history/wandb. 118 | 119 | Parameters 120 | ---------- 121 | server_round : int 122 | The current server round. 123 | metrics : Dict[str, Scalar] 124 | The metrics to add. 125 | 126 | Returns 127 | ------- 128 | None 129 | """ 130 | super().add_metrics_distributed( 131 | server_round, 132 | metrics, 133 | ) 134 | if self.use_wandb: 135 | for key in metrics: 136 | wandb.log( 137 | {key: metrics[key]}, 138 | step=server_round, 139 | ) 140 | 141 | def add_metrics_centralized( 142 | self, 143 | server_round: int, 144 | metrics: dict[str, Scalar], 145 | ) -> None: 146 | """Add metrics entries (from centralized evaluation) to history/wand. 147 | 148 | Parameters 149 | ---------- 150 | server_round : int 151 | The current server round. 152 | metrics : Dict[str, Scalar] 153 | The metrics to add. 
154 | 
155 |         Returns
156 |         -------
157 |         None
158 |         """
159 |         super().add_metrics_centralized(
160 |             server_round,
161 |             metrics,
162 |         )
163 |         if self.use_wandb:
164 |             for key in metrics:
165 |                 wandb.log(
166 |                     {key: metrics[key]},
167 |                     step=server_round,
168 |                 )
169 | 
--------------------------------------------------------------------------------
/project/task/mnist_classification/dispatch.py:
--------------------------------------------------------------------------------
1 | """Dispatch the MNIST functionality to project.main.
2 | 
3 | The dispatch functions are used to
4 | dynamically select the correct functions from the task
5 | based on the hydra config file.
6 | The following categories of functionality are grouped together:
7 | - train/test and fed test functions
8 | - net generator and dataloader generator functions
9 | - fit/eval config functions
10 | 
11 | The top-level project.dispatch
12 | module operates as a pipeline
13 | and selects the first function which does not return None.
14 | 
15 | Do not throw any errors based on not finding a given attribute
16 | in the configs under any circumstances.
17 | 
18 | If you cannot match the config file,
19 | return None and the dispatch of the next task
20 | in the chain specified by project.dispatch will be used.
21 | """
22 | 
23 | from pathlib import Path
24 | from typing import Any
25 | 
26 | from omegaconf import DictConfig
27 | from project.fed.utils.utils import (
28 |     generate_initial_params_from_net_generator as get_initial_parameters,
29 | )
30 | 
31 | from project.task.default.dispatch import (
32 |     dispatch_config as dispatch_default_config,
33 |     init_working_dir as init_working_dir_default,
34 | )
35 | from project.task.mnist_classification.dataset import get_dataloader_generators
36 | from project.task.mnist_classification.models import get_logistic_regression, get_net
37 | from project.task.mnist_classification.train_test import get_fed_eval_fn, test, train
38 | from project.types.common import DataStructure, TrainStructure
39 | 
40 | 
41 | def dispatch_train(
42 |     cfg: DictConfig,
43 |     **kwargs: Any,
44 | ) -> TrainStructure | None:
45 |     """Dispatch the train/test and fed test functions based on the config file.
46 | 
47 |     Do not throw any errors based on not finding a given attribute
48 |     in the configs under any circumstances.
49 | 
50 |     If you cannot match the config file,
51 |     return None and the dispatch of the next task
52 |     in the chain specified by project.dispatch will be used.
53 | 
54 |     Parameters
55 |     ----------
56 |     cfg : DictConfig
57 |         The configuration for the train function.
58 |         Loaded dynamically from the config file.
59 |     kwargs : dict[str, Any]
60 |         Additional keyword arguments to pass to the train function.
61 | 
62 |     Returns
63 |     -------
64 |     Optional[TrainStructure]
65 |         The train function, test function and the get_fed_eval_fn function.
66 |         Return None if you cannot match the cfg.
67 |     """
68 |     # Select the value for the key with None default
69 |     train_structure: str | None = cfg.get("task", {}).get(
70 |         "train_structure",
71 |         None,
72 |     )
73 | 
74 |     # Only consider not None and uppercase matches
75 |     if train_structure is not None and train_structure.upper() == "MNIST":
76 |         return train, test, get_fed_eval_fn
77 | 
78 |     # Cannot match, send to next dispatch in chain
79 |     return None
80 | 
81 | 
82 | def dispatch_data(cfg: DictConfig, **kwargs: Any) -> DataStructure | None:
83 |     """Dispatch the net and dataloader client/fed generator functions.
84 | 
85 |     Do not throw any errors based on not finding a given attribute
86 |     in the configs under any circumstances.
87 | 
88 |     If you cannot match the config file,
89 |     return None and the dispatch of the next task
90 |     in the chain specified by project.dispatch will be used.
91 | 
92 |     Parameters
93 |     ----------
94 |     cfg : DictConfig
95 |         The configuration for the data functions.
96 |         Loaded dynamically from the config file.
97 |     kwargs : dict[str, Any]
98 |         Additional keyword arguments to pass to the data functions.
99 | 
100 |     Returns
101 |     -------
102 |     Optional[DataStructure]
103 |         The net generator, client dataloader generator and fed dataloader generator.
104 |         Return None if you cannot match the cfg.
105 |     """
106 |     # Select the value for the key with {} default at nested dicts
107 |     # and None default at the final key
108 |     client_model_and_data: str | None = cfg.get(
109 |         "task",
110 |         {},
111 |     ).get("model_and_data", None)
112 | 
113 |     # Select the partition dir
114 |     # if it does not exist data cannot be loaded
115 |     # for MNIST and the dispatch should return None
116 |     partition_dir: str | None = cfg.get("dataset", {}).get(
117 |         "partition_dir",
118 |         None,
119 |     )
120 | 
121 |     # Only consider situations where both are not None
122 |     # otherwise data loading would fail later
123 |     if client_model_and_data is not None and partition_dir is not None:
124 |         # Obtain the dataloader generators
125 |         # for the provided partition dir
126 |         (
127 |             client_dataloader_gen,
128 |             fed_dataloader_gen,
129 |         ) = get_dataloader_generators(
130 |             Path(partition_dir),
131 |         )
132 | 
133 |         # Case insensitive matches
134 |         if client_model_and_data.upper() == "MNIST_CNN":
135 |             return (
136 |                 get_net,
137 |                 get_initial_parameters,
138 |                 client_dataloader_gen,
139 |                 fed_dataloader_gen,
140 |                 init_working_dir_default,
141 |             )
142 |         elif client_model_and_data.upper() == "MNIST_LR":
143 |             return (
144 |                 get_logistic_regression,
145 |                 get_initial_parameters,
146 |                 client_dataloader_gen,
147 |                 fed_dataloader_gen,
148 |                 init_working_dir_default,
149 |             )
150 | 
151 |     # Cannot match, send to next dispatch in chain
152 |     return None
153 | 
154 | 
155 | dispatch_config = dispatch_default_config
156 | 
--------------------------------------------------------------------------------
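For concreteness, a hydra fragment these MNIST dispatchers would match might look as follows. The keys and accepted values come from the code above; the exact contents of `project/conf/task/mnist.yaml` and the partition path are assumptions for illustration:

```yaml
task:
  train_structure: MNIST      # matched case-insensitively by dispatch_train
  model_and_data: MNIST_CNN   # or MNIST_LR for the logistic regression model
dataset:
  partition_dir: data/mnist/partition
```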
/project/task/default/dataset.py:
--------------------------------------------------------------------------------
1 | """Handle basic dataset creation.
2 | 
3 | In case of PyTorch it should return dataloaders
4 | for your dataset (for both the clients and server).
5 | If you are using a custom dataset class, this module is the place to define it.
6 | If your dataset needs to be downloaded (and this is not done
7 | automatically -- e.g. as is the case for many datasets in TorchVision) and
8 | partitioned, please include all those functions and logic in the
9 | `dataset_preparation.py` module.
10 | You can, of course, use all of those functions from the functions/methods defined here.
11 | """
12 | 
13 | from collections import defaultdict
14 | from pathlib import Path
15 | 
16 | from omegaconf import DictConfig
17 | import torch
18 | from pydantic import BaseModel
19 | from torch.utils.data import DataLoader, Dataset, TensorDataset
20 | 
21 | from project.types.common import CID, IsolatedRNG
22 | 
23 | 
24 | class ClientDataloaderConfig(BaseModel):
25 |     """Dataloader configuration for the client.
26 | 
27 |     Allows '.' member access and static checking. Guarantees that all necessary
28 |     components are present, fails early if config is mismatched to dataloader.
29 |     """
30 | 
31 |     batch_size: int
32 | 
33 |     class Config:
34 |         """Setting to allow any types, including library ones like torch.device."""
35 | 
36 |         arbitrary_types_allowed = True
37 | 
38 | 
39 | class FedDataloaderConfig(BaseModel):
40 |     """Dataloader configuration for the federated (server) dataloader.
41 | 
42 |     Allows '.' member access and static checking. Guarantees that all necessary
43 |     components are present, fails early if config is mismatched to dataloader.
44 |     """
45 | 
46 |     batch_size: int
47 | 
48 |     class Config:
49 |         """Setting to allow any types, including library ones like torch.device."""
50 | 
51 |         arbitrary_types_allowed = True
52 | 
53 | 
54 | def get_client_dataloader(
55 |     cid: CID,
56 |     test: bool,
57 |     _config: dict,
58 |     _rng_tuple: IsolatedRNG,
59 |     _hydra_config: DictConfig | None,
60 | ) -> DataLoader:
61 |     """Return a DataLoader for a client's dataset.
62 | 
63 |     Parameters
64 |     ----------
65 |     cid : CID
66 |         The client's ID
67 |     test : bool
68 |         Whether to load the test set or not
69 |     _config : Dict
70 |         The configuration for the dataset
71 |     _rng_tuple : IsolatedRNGTuple
72 |         The random number generator state for the training.
73 |         Use if you need seeded random behavior
74 | 
75 |     Returns
76 |     -------
77 |     DataLoader
78 |         The DataLoader for the client's dataset
79 |     """
80 |     # Create an empty TensorDataset for illustration purposes
81 |     config: ClientDataloaderConfig = ClientDataloaderConfig(
82 |         **_config,
83 |     )
84 |     del _config
85 | 
86 |     # You should load/create one train/test dataset per client
87 |     if not test:
88 |         empty_trainset_dict: dict[
89 |             CID,
90 |             Dataset,
91 |         ] = defaultdict(
92 |             lambda: TensorDataset(
93 |                 torch.Tensor([1]),
94 |                 torch.Tensor([1]),
95 |             ),
96 |         )
97 |         # Choose the client dataset based on the client id and train/test
98 |         dataset = empty_trainset_dict[cid]
99 |     else:
100 |         empty_testset_dict: dict[
101 |             CID,
102 |             Dataset,
103 |         ] = defaultdict(
104 |             lambda: TensorDataset(
105 |                 torch.Tensor([1]),
106 |                 torch.Tensor([1]),
107 |             ),
108 |         )
109 |         # Choose the client dataset based on the client id and train/test
110 |         dataset = empty_testset_dict[cid]
111 | 
112 |     return DataLoader(
113 |         dataset,
114 |         batch_size=config.batch_size,
115 |         shuffle=not test,
116 |         drop_last=True,
117 |     )
118 | 
119 | 
120 | def get_fed_dataloader(
121 |     test: bool,
122 |     _config: dict,
123 |     _rng_tuple: IsolatedRNG,
124 |     _hydra_config: DictConfig | None,
125 | ) -> DataLoader:
126 |     """Return a DataLoader for federated train/test sets.
127 | 
128 |     Parameters
129 |     ----------
130 |     test : bool
131 |         Whether to load the test set or not
132 |     _config : Dict
133 |         The configuration for the dataset
134 |     _rng_tuple : IsolatedRNGTuple
135 |         The random number generator state for the training.
136 | Use if you need seeded random behavior 137 | 138 | Returns 139 | ------- 140 | DataLoader 141 | The DataLoader for the federated dataset 142 | """ 143 | config: FedDataloaderConfig = FedDataloaderConfig( 144 | **_config, 145 | ) 146 | del _config 147 | 148 | # Create one train/test empty dataset for the server 149 | if not test: 150 | empty_trainset: Dataset = TensorDataset( 151 | torch.Tensor([1]), 152 | torch.Tensor([1]), 153 | ) 154 | # Choose the server dataset based on the train/test 155 | dataset = empty_trainset 156 | else: 157 | empty_testset: Dataset = TensorDataset( 158 | torch.Tensor([1]), 159 | torch.Tensor([1]), 160 | ) 161 | # Choose the server dataset based on the train/test 162 | dataset = empty_testset 163 | 164 | return DataLoader( 165 | dataset, 166 | batch_size=config.batch_size, 167 | shuffle=not test, 168 | drop_last=True, 169 | ) 170 | 171 | 172 | def init_working_dir( 173 | working_dir: Path, 174 | results_dir: Path, 175 | ) -> None: 176 | """Initialize the working directory. 177 | 178 | Parameters 179 | ---------- 180 | working_dir : Path 181 | The path to the working directory. 182 | results_dir : Path 183 | The path to the results directory. 184 | 185 | Returns 186 | ------- 187 | None 188 | """ 189 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.4.0"] 3 | build-backend = "poetry.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "project" # Do not change this 7 | version = "1.0.0" 8 | description = "CaMLSys Project" 9 | license = "Apache-2.0" 10 | authors = ["The Flower Authors , Alexandru-Andrei Iacob , Lorenzo Sani "] 11 | readme = "README.md" 12 | # homepage = "" 13 | # documentation = "" 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Science/Research", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: MacOS :: MacOS X", 20 | "Operating System :: POSIX :: Linux", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: Implementation :: CPython", 29 | "Topic :: Scientific/Engineering", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | "Topic :: Scientific/Engineering :: Mathematics", 32 | "Topic :: Software Development", 33 | "Topic :: Software Development :: Libraries", 34 | "Topic :: Software Development :: Libraries :: Python Modules", 35 | "Typing :: Typed", 36 | ] 37 | 38 | [tool.poetry.dependencies] 39 | python = ">=3.9.17, <3.12.0" # don't change this 40 | flwr = { extras = ["simulation"], version = "1.6.0" } 41 | hydra-core = "1.3.2" # don't change this 42 | torch = {url = "https://download.pytorch.org/whl/cu121_pypi_cudnn/torch-2.1.0%2Bcu121.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl"} 43 | torchaudio = {url = "https://download.pytorch.org/whl/cu121/torchaudio-2.1.0%2Bcu121-cp311-cp311-linux_x86_64.whl"} 44 | torchvision = {url = "https://download.pytorch.org/whl/cu121/torchvision-0.16.0%2Bcu121-cp311-cp311-linux_x86_64.whl"} 45 | types-protobuf = "4.24.0.4" 46 | types-pyyaml = "6.0.12.12" 47 | types-decorator = "5.1.8.4" 48 | 
types-setuptools = "68.2.0.0" 49 | wandb = "0.16.0" 50 | pyarrow = "14.0.1" 51 | multiprocess = "0.70.15" 52 | nvsmi = "0.4.2" 53 | transformers = "4.36.0" 54 | cloudpickle = "3.0.0" 55 | tqdm = "4.66.1" 56 | pandas = "2.1.2" 57 | scipy = "1.11.3" 58 | librosa = "0.10.1" 59 | nvidia-ml-py = "11.495.46" 60 | ipykernel = "6.26.0" 61 | matplotlib = "3.8.1" 62 | seaborn = "0.13.0" 63 | jupyter-server = "2.11.2" 64 | ipywidgets = "8.1.1" 65 | ipython = "8.17.2" 66 | gdown = "4.7.1" 67 | pydantic = "<2.0.0" 68 | pre-commit = "3.5.0" 69 | identify = "2.5.31" 70 | 71 | 72 | 73 | [tool.poetry.dev-dependencies] 74 | black = { version = ">=23.1.0", extras = ["jupyter"] } 75 | mypy = ">=1.8.0" 76 | ruff = ">=0.1.12" 77 | pytest = "==6.2.4" 78 | pytest-watch = "==4.2.0" 79 | types-requests = "==2.27.7" 80 | yamlfix = ">=1.15.0" 81 | 82 | [tool.black] 83 | line-length = 88 84 | preview = true 85 | target-version = ["py311"] 86 | 87 | [tool.mypy] 88 | ignore_missing_imports = true 89 | strict = false 90 | plugins = "numpy.typing.mypy_plugin" 91 | 92 | [tool.pylint."MESSAGES CONTROL"] 93 | disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias,import-error,no-member,no-name-in-module" 94 | good-names = "f,i,j,k,_,x,y,X,Y" 95 | signature-mutators="hydra.main.main" 96 | 97 | [tool.pylint.typecheck] 98 | generated-members="numpy.*, torch.*, tensorflow.*" 99 | 100 | [[tool.mypy.overrides]] 101 | module = [ 102 | "importlib.metadata.*", 103 | "importlib_metadata.*", 104 | ] 105 | follow_imports = "skip" 106 | follow_imports_for_stubs = true 107 | disallow_untyped_calls = false 108 | 109 | [[tool.mypy.overrides]] 110 | module = "torch.*" 111 | follow_imports = "skip" 112 | follow_imports_for_stubs = true 113 | 114 | [[tool.mypy.overrides]] 115 | module = "ray.*" 116 | follow_imports = "skip" 117 | follow_imports_for_stubs = true 118 | 119 | [tool.docformatter] 120 | wrap-summaries = 88 121 | wrap-descriptions = 88 122 | 123 | [tool.yamlfix] 124 | allow_duplicate_keys = false 125 | line_length = 88 126 | none_representation = "null" 127 | whitelines = 1 128 | section_whitelines = 1 129 | exclude = [ 130 | ".bzr", 131 | ".direnv", 132 | ".eggs", 133 | ".git", 134 | ".hg", 135 | ".mypy_cache", 136 | ".nox", 137 | ".pants.d", 138 | ".pytype", 139 | ".ruff_cache", 140 | ".svn", 141 | ".tox", 142 | ".venv", 143 | "__pypackages__", 144 | "_build", 145 | "buck-out", 146 | "build", 147 | "dist", 148 | "node_modules", 149 | "venv", 150 | "proto", 151 | "outputs", 152 | "wandb", 153 | "multirun", 154 | ] 155 | 156 | [tool.ruff] 157 | target-version = "py311" 158 | line-length = 88 159 | preview = true 160 | select = ["A", "D", "E", "F", "W", "B", "ISC", "N", "ANN", 161 | "C4", "UP", "COM", "EXE", "FA", "ISC", "ICN", "INP", "PIE", "T20", 162 | "Q", "RET", "SLOT", "SIM", "TID", "ARG", "PTH", "TD", "PD", "PGH", 163 | "PL", "TRY","NPY", "PERF", "FURB", "RUF" ] 164 | fixable = ["A", "D", "E", "F", "W", "B", "ISC", "N", "ANN", 165 | "C4", "UP", "COM", "EXE", "FA", "ISC", "ICN", "INP", "PIE", "T20", 166 | "Q", "RET", "SLOT", "SIM", "TCH", "ARG", "PTH", "TD", "PD", "PGH", 167 | "PL", "TRY", "NPY", "PERF", "FURB", "RUF"] 168 | ignore = ["B024", "B027", "PLE1205", "PLE1206", "PLR0904", 169 | "PLR0911" , "PLR0912", "PLR0913", "PLR0915", "PLR0916", "PERF203", 170 | "PERF401","PERF403", "ANN101", "ANN102", "ANN401", "PLR6301", 171 | "ARG002", "TRY003", "PTH123", "TD002","TD003", "ARG001", 172 | "ARG002", "ARG003", "ARG004", "RET505", "N812", "RET504", 173 | "PLW2901", "NPY002", "COM812", "ISC001", 
"RUF005", "PLR0917", "PLR0914"] 174 | exclude = [ 175 | ".bzr", 176 | ".direnv", 177 | ".eggs", 178 | ".git", 179 | ".hg", 180 | ".mypy_cache", 181 | ".nox", 182 | ".pants.d", 183 | ".pytype", 184 | ".ruff_cache", 185 | ".svn", 186 | ".tox", 187 | ".venv", 188 | "__pypackages__", 189 | "_build", 190 | "buck-out", 191 | "build", 192 | "dist", 193 | "node_modules", 194 | "venv", 195 | "proto", 196 | ] 197 | 198 | 199 | [tool.ruff.format] 200 | preview = true 201 | 202 | [tool.ruff.pydocstyle] 203 | convention = "numpy" 204 | -------------------------------------------------------------------------------- /project/task/default/dispatch.py: -------------------------------------------------------------------------------- 1 | """Dispatch the functionality of the task to project.main. 2 | 3 | The dispatch functions are used to dynamically select 4 | the correct functions from the task 5 | based on the hydra config file. 6 | You need to write dispatch functions for three categories: 7 | - train/test and fed test functions 8 | - net generator and dataloader generator functions 9 | - fit/eval config functions 10 | 11 | The top-level project.dispatch module operates as a pipeline 12 | and selects the first function which does not return None. 13 | Do not throw any errors based on not finding 14 | a given attribute in the configs under any circumstances. 15 | If you cannot match the config file, 16 | return None and the dispatch of the next task 17 | in the chain specified by project.dispatch will be used. 18 | """ 19 | 20 | from typing import Any, cast 21 | 22 | from omegaconf import DictConfig, OmegaConf 23 | 24 | from project.task.default.dataset import ( 25 | get_client_dataloader, 26 | get_fed_dataloader, 27 | init_working_dir, 28 | ) 29 | from project.task.default.models import get_net 30 | from project.task.default.train_test import ( 31 | get_fed_eval_fn, 32 | get_on_evaluate_config_fn, 33 | get_on_fit_config_fn, 34 | test, 35 | train, 36 | ) 37 | from project.types.common import ConfigStructure, DataStructure, TrainStructure 38 | 39 | from project.fed.utils.utils import ( 40 | generate_initial_params_from_net_generator as get_initial_parameters, 41 | ) 42 | 43 | 44 | def dispatch_train( 45 | cfg: DictConfig, 46 | **kwargs: Any, 47 | ) -> TrainStructure | None: 48 | """Dispatch the train/test and fed test functions based on the config file. 49 | 50 | Do not throw any errors based on not finding 51 | a given attribute in the configs under any circumstances. 52 | If you cannot match the config file, 53 | return None and the dispatch of the next task 54 | in the chain specified by project.dispatch will be used. 55 | 56 | Parameters 57 | ---------- 58 | cfg : DictConfig 59 | The configuration for the train function. 60 | Loaded dynamically from the config file. 61 | kwargs : dict[str, Any] 62 | Additional keyword arguments to pass to the train function. 63 | 64 | Returns 65 | ------- 66 | Optional[TrainStructure] 67 | The train function, test function and the get_fed_eval_fn function. 68 | Return None if you cannot match the cfg. 
69 | """ 70 | # Select the value for the key with None default 71 | train_structure: str | None = cfg.get("task", {}).get( 72 | "train_structure", 73 | None, 74 | ) 75 | 76 | # Only consider not None matches, case insensitive 77 | if train_structure is not None and train_structure.upper() == "DEFAULT": 78 | return train, test, get_fed_eval_fn 79 | 80 | # Cannot match, send to next dispatch in chain 81 | return None 82 | 83 | 84 | def dispatch_data(cfg: DictConfig, **kwargs: Any) -> DataStructure | None: 85 | """Dispatch the net and dataloader client/fed generator functions. 86 | 87 | Do not throw any errors based on not finding 88 | a given attribute in the configs under any circumstances. 89 | If you cannot match the config file, 90 | return None and the dispatch of the next task 91 | in the chain specified by project.dispatch will be used. 92 | 93 | Parameters 94 | ---------- 95 | cfg : DictConfig 96 | The configuration for the data functions. 97 | Loaded dynamically from the config file. 98 | kwargs : dict[str, Any] 99 | Additional keyword arguments to pass to the data functions. 100 | 101 | Returns 102 | ------- 103 | Optional[DataStructure] 104 | The net generator, client dataloader generator and fed dataloader generator. 105 | Return None if you cannot match the cfg. 106 | """ 107 | # Select the value for the key with {} default at nested dicts 108 | # and None default at the final key 109 | client_model_and_data: str | None = cfg.get( 110 | "task", 111 | {}, 112 | ).get("model_and_data", None) 113 | 114 | # Only consider not None matches, case insensitive 115 | if client_model_and_data is not None and client_model_and_data.upper() == "DEFAULT": 116 | ret_tuple: DataStructure = ( 117 | get_net, 118 | get_initial_parameters, 119 | get_client_dataloader, 120 | get_fed_dataloader, 121 | init_working_dir, 122 | ) 123 | return ret_tuple 124 | 125 | # Cannot match, send to next dispatch in chain 126 | return None 127 | 128 | 129 | def dispatch_config( 130 | cfg: DictConfig, 131 | **kwargs: Any, 132 | ) -> ConfigStructure | None: 133 | """Dispatches the config function based on the config_structure in the config file. 134 | 135 | By default it simply takes the fit_config and evaluate_config 136 | dicts from the hydra config. 137 | Only change if a more complex behavior 138 | (such as varying the config across rounds) is needed. 139 | 140 | Do not throw any errors based on not finding 141 | a given attribute in the configs under any circumstances. 142 | If you cannot match the config file, 143 | return None and the dispatch of the next task 144 | in the chain specified by project.dispatch will be used. 145 | 146 | Parameters 147 | ---------- 148 | cfg : DictConfig 149 | The configuration for the config function. 150 | Loaded dynamically from the config file. 151 | kwargs : dict[str, Any] 152 | Additional keyword arguments to pass to the config function. 153 | 154 | Returns 155 | ------- 156 | Optional[ConfigStructure] 157 | The fit_config and evaluate_config functions. 158 | Return None if you cannot match the cfg. 
159 | """ 160 | # Select the values for the key with {} default at nested dicts 161 | # and None default at the final key 162 | fit_config: dict | None = cfg.get("task", {}).get( 163 | "fit_config", 164 | None, 165 | ) 166 | eval_config: dict | None = cfg.get("task", {}).get( 167 | "eval_config", 168 | None, 169 | ) 170 | 171 | # Only consider existing config dicts as matches 172 | if fit_config is not None and eval_config is not None: 173 | return get_on_fit_config_fn( 174 | cast(dict, OmegaConf.to_container(fit_config)), 175 | ), get_on_evaluate_config_fn( 176 | cast(dict, OmegaConf.to_container(eval_config)), 177 | ) 178 | 179 | return None 180 | -------------------------------------------------------------------------------- /EXTENDED_README.md: -------------------------------------------------------------------------------- 1 | 2 | # Extended Readme 3 | 4 | > Scientific projects are expected to run in a machine with Ubuntu 22.04 5 | 6 | While `README.md` should include information about the project you implement and how to run it, this _extended_ readme provides more generally the instructions to follow before your work can be run on the CaMLSys cluster or used collaboratively. Please follow closely these instructions. It is likely that you have already completed steps 1-2. 7 | 8 | 1. Click use this template on the Github Page. 9 | 2. Run ```./setup.sh``` to set up the template project. 10 | 3. Add your additional dependencies to the `pyproject.toml` (see below a few examples on how to do it). Read more about Poetry below in this `EXTENDED_README.md`. 11 | 4. Regularly check that your coding style and the documentation you add follow good coding practices. To test whether your code meets the requirements, please run ``./setup.sh`` or 12 | ```bash 13 | poetry run pre-commit install 14 | ``` 15 | 5. If you update the `.pre-commit-config.yaml` file or change the tools in the toml file you should run the following. Feel free to run this whenever you need a sanity check that the entire codebase is up to standard. 16 | ```bash 17 | poetry run pre-commit run --all-files --hook-stage push 18 | ``` 19 | 6. Ensure that the Python environment for your project can be created without errors by simply running `poetry install` and that this is properly described later when you complete the `Setup` section in `README.md`. This is specially important if your environment requires additional steps after doing `poetry install` or ```./setup.sh```. 20 | 7. Ensure that your project runs with default arguments by running `poetry run python -m project.main`. Then, describe this and other forms of running your code in the `Using the Project` section in `README.md`. 21 | 8. Once your code is ready and you have checked: 22 | * that following the instructions in your `README.md` the Python environment can be created correctly 23 | 24 | * that running the code following your instructions can reproduce the experiments you setup for your paper 25 | 26 | ,then you just need to invite collaborators to your project if this is in a private repo or publish it to the camlsys github if it is meant as an artefact. If you feel like some improvements can be shared across projects please open a PR for the template. 27 | 28 | > Once you are happy with your project please delete this `EXTENDED_README.md` file. 29 | 30 | 31 | ## About Poetry 32 | 33 | We use Poetry to manage the Python environment for each individual project. 
33 | We use Poetry to manage the Python environment for each individual project. You can follow the instructions [here](https://python-poetry.org/docs/) to install Poetry on your machine. The ``./setup.sh`` script already handles all of these steps for you; however, you need to understand how they work to properly change the template.
34 | 
35 | 
36 | ### Specifying a Python Version (optional)
37 | By default, Poetry will use the Python version in your system. In some settings, you might want to specify a particular version of Python to use inside your Poetry environment. You can do so with [`pyenv`](https://github.com/pyenv/pyenv). Check the documentation for the different ways of installing `pyenv`, but one easy way is using the [automatic installer](https://github.com/pyenv/pyenv-installer):
38 | ```bash
39 | curl https://pyenv.run | bash # then, don't forget to add the suggested lines to your .bashrc/.zshrc
40 | ```
41 | 
42 | You can then install any Python version with `pyenv install <version>` (e.g. `pyenv install 3.9.17`). Then, in order to use that version for your project, you'd do the following:
43 | 
44 | ```bash
45 | # cd to your project directory (i.e. where the `pyproject.toml` is)
46 | pyenv local <version>
47 | 
48 | # set that version for poetry
49 | poetry env use <version>
50 | 
51 | # then you can install your Poetry environment (see the next step)
52 | ```
53 | 
54 | ### Installing Your Environment
55 | With the Poetry tool already installed, you can create an environment for this project with the following command:
56 | ```bash
57 | # run this from the same directory as the `pyproject.toml` file is
58 | poetry install
59 | ```
60 | 
61 | This will create a basic Python environment with just Flower and additional packages, including those needed for simulation. Next, you should add the dependencies for your code. It is **critical** that you fix the version of the packages you use, using an exact `=` constraint rather than a caret `^` one. You can do so via [`poetry add`](https://python-poetry.org/docs/cli/#add). Below are some examples:
62 | 
63 | ```bash
64 | # For instance, if you want to install tqdm
65 | poetry add tqdm==4.65.0
66 | 
67 | # If you already have a requirements.txt, you can add all those packages (but ensure you have fixed the version) in one go as follows:
68 | poetry add $( cat requirements.txt )
69 | ```
70 | With each `poetry add` command, the `pyproject.toml` gets automatically updated so you don't need to keep that `requirements.txt` as part of this project.
71 | 
72 | 
73 | More critical, however, is adding your ML framework of choice to the list of dependencies. For some frameworks you might be able to do so with the `poetry add` command. Check [the Poetry documentation](https://python-poetry.org/docs/cli/#add) for how to add packages in various ways. For instance, let's say you want to use PyTorch:
74 | 
75 | ```bash
76 | # with plain `pip` you'd run a command such as:
77 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
78 | 
79 | # to add the same 3 dependencies to your Poetry environment you'd need to add the URL to the wheel that the above pip command auto-resolves for you.
80 | # You can find those wheels in `https://download.pytorch.org/whl/cu117`. Copy the link and paste it after the `poetry add` command.
81 | # For instance, to add `torch==1.13.1+cu117` on an x86 Linux system with Python 3.8 you'd:
82 | poetry add https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp38-cp38-linux_x86_64.whl
83 | # you'll need to repeat this for both `torchvision` and `torchaudio`
84 | ```
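When you add a wheel by URL like this, Poetry records it in `pyproject.toml` as a URL dependency. The entry below is illustrative, reusing the wheel URL from the command above; this template's own `pyproject.toml` pins its CUDA 12.1 PyTorch wheels in exactly this form.

```toml
[tool.poetry.dependencies]
torch = { url = "https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp38-cp38-linux_x86_64.whl" }
```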
85 | The above is just an example of how you can add these dependencies. Please refer to the Poetry documentation for further reference.
86 | 
87 | If all attempts fail, you can still install packages via standard `pip`. You'd first need to source/activate your Poetry environment.
88 | ```bash
89 | # first ensure you have created your environment
90 | # and installed the base packages provided in the template
91 | poetry install
92 | 
93 | # then activate it
94 | poetry shell
95 | ```
96 | Now you are inside your environment (pretty much as when you use `virtualenv` or `conda`) so you can install further packages with `pip`. Please note that, unlike with `poetry add`, these extra requirements won't be captured by `pyproject.toml`. Therefore, please ensure that you provide all instructions needed to: (1) create the base environment with Poetry and (2) install any additional dependencies via `pip` when you complete your `README.md`.
--------------------------------------------------------------------------------
/project/task/mnist_classification/train_test.py:
--------------------------------------------------------------------------------
1 | """MNIST training and testing functions, local and federated."""
2 | 
3 | from collections.abc import Sized
4 | from pathlib import Path
5 | from typing import cast
6 | 
7 | from omegaconf import DictConfig
8 | import torch
9 | from pydantic import BaseModel
10 | from torch import nn
11 | from torch.utils.data import DataLoader
12 | 
13 | from flwr.common import NDArrays
14 | 
15 | from project.task.default.train_test import get_fed_eval_fn as get_default_fed_eval_fn
16 | from project.task.default.train_test import (
17 |     get_on_evaluate_config_fn as get_default_on_evaluate_config_fn,
18 | )
19 | from project.task.default.train_test import (
20 |     get_on_fit_config_fn as get_default_on_fit_config_fn,
21 | )
22 | from project.types.common import IsolatedRNG, CID
23 | 
24 | 
25 | class TrainConfig(BaseModel):
26 |     """Training configuration, allows '.' member access and static checking.
27 | 
28 |     Guarantees that all necessary components are present, fails early if config is
29 |     mismatched to client.
30 |     """
31 | 
32 |     cid: CID
33 |     device: torch.device
34 |     epochs: int
35 |     learning_rate: float
36 | 
37 |     class Config:
38 |         """Setting to allow any types, including library ones like torch.device."""
39 | 
40 |         arbitrary_types_allowed = True
41 | 
42 | 
43 | def train(  # pylint: disable=too-many-arguments
44 |     net: nn.Module | NDArrays,
45 |     trainloader: DataLoader | None,
46 |     _config: dict,
47 |     _working_dir: Path,
48 |     _rng_tuple: IsolatedRNG,
49 |     _hydra_config: DictConfig | None,
50 | ) -> tuple[nn.Module | NDArrays, int, dict]:
51 |     """Train the network on the training set.
52 | 
53 |     Parameters
54 |     ----------
55 |     net : nn.Module
56 |         The neural network to train.
57 |     trainloader : DataLoader
58 |         The DataLoader containing the data to train the network on.
59 |     _config : Dict
60 |         The configuration for the training.
61 |         Contains the device, number of epochs and learning rate.
62 |         Static type checking is done by the TrainConfig class.
63 |     _working_dir : Path
64 |         The working directory for the training.
65 |         Unused.
66 |     _rng_tuple : IsolatedRNGTuple
67 |         The random number generator state for the training.
68 |         Use if you need seeded random behavior
69 | 
70 |     Returns
71 |     -------
72 |     Tuple[nn.Module | NDArrays, int, Dict]
73 |         The trained model, the number of samples used for training,
74 |         and a dict with the train loss and accuracy on the given data.
75 |     """
76 |     if not isinstance(net, nn.Module) or trainloader is None:
77 |         raise ValueError("MNIST does not support implicit model/dataset creation.")
78 | 
79 |     if len(cast(Sized, trainloader.dataset)) == 0:
80 |         raise ValueError(
81 |             "Trainloader dataset can't be empty, exiting...",
82 |         )
83 | 
84 |     config: TrainConfig = TrainConfig(**_config)
85 |     del _config
86 | 
87 |     net.to(config.device)
88 |     net.train()
89 | 
90 |     criterion = nn.CrossEntropyLoss()
91 |     optimizer = torch.optim.SGD(
92 |         net.parameters(),
93 |         lr=config.learning_rate,
94 |         weight_decay=0.001,
95 |     )
96 | 
97 |     final_epoch_per_sample_loss = 0.0
98 |     num_correct = 0
99 |     for _ in range(config.epochs):
100 |         final_epoch_per_sample_loss = 0.0
101 |         num_correct = 0
102 |         for data, target in trainloader:
103 |             data, target = (
104 |                 data.to(
105 |                     config.device,
106 |                 ),
107 |                 target.to(config.device),
108 |             )
109 |             optimizer.zero_grad()
110 |             output = net(data)
111 |             loss = criterion(output, target)
112 |             final_epoch_per_sample_loss += loss.item()
113 |             num_correct += (output.max(1)[1] == target).clone().detach().sum().item()
114 |             loss.backward()
115 |             optimizer.step()
116 | 
117 |     return (
118 |         net,
119 |         len(cast(Sized, trainloader.dataset)),
120 |         {
121 |             "train_loss": final_epoch_per_sample_loss
122 |             / len(cast(Sized, trainloader.dataset)),
123 |             "train_accuracy": float(num_correct)
124 |             / len(cast(Sized, trainloader.dataset)),
125 |         },
126 |     )
127 | 
128 | 
129 | class TestConfig(BaseModel):
130 |     """Testing configuration, allows '.' member access and static checking.
131 | 
132 |     Guarantees that all necessary components are present, fails early if config is
133 |     mismatched to client.
134 |     """
135 | 
136 |     cid: CID
137 |     device: torch.device
138 | 
139 |     class Config:
140 |         """Setting to allow any types, including library ones like torch.device."""
141 | 
142 |         arbitrary_types_allowed = True
143 | 
144 | 
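# --- Illustrative sketch, not part of the template ----------------------
# The pydantic configs above "fail early": building a TrainConfig from an
# incomplete dict raises a ValidationError before any training happens.
# The literal values below are made up for the example.
if __name__ == "__main__":
    from pydantic import ValidationError

    try:
        TrainConfig(cid="0", device=torch.device("cpu"), epochs=1)
    except ValidationError as err:
        print(f"Config mismatch caught early: {err}")  # missing learning_rate
# -------------------------------------------------------------------------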
178 | """ 179 | if not isinstance(net, nn.Module) or testloader is None: 180 | raise ValueError("MNIST does not support implicit model/dataset creation.") 181 | 182 | if len(cast(Sized, testloader.dataset)) == 0: 183 | raise ValueError( 184 | "Testloader can't be 0, exiting...", 185 | ) 186 | 187 | config: TestConfig = TestConfig(**_config) 188 | del _config 189 | 190 | net.to(config.device) 191 | net.eval() 192 | 193 | criterion = nn.CrossEntropyLoss() 194 | correct, per_sample_loss = 0, 0.0 195 | 196 | with torch.no_grad(): 197 | for images, labels in testloader: 198 | images, labels = ( 199 | images.to( 200 | config.device, 201 | ), 202 | labels.to(config.device), 203 | ) 204 | outputs = net(images) 205 | per_sample_loss += criterion( 206 | outputs, 207 | labels, 208 | ).item() 209 | _, predicted = torch.max(outputs.data, 1) 210 | correct += (predicted == labels).sum().item() 211 | 212 | return ( 213 | per_sample_loss / len(cast(Sized, testloader.dataset)), 214 | len(cast(Sized, testloader.dataset)), 215 | { 216 | "test_accuracy": float(correct) / len(cast(Sized, testloader.dataset)), 217 | }, 218 | ) 219 | 220 | 221 | # Use defaults as they are completely determined 222 | # by the other functions defined in mnist_classification 223 | get_fed_eval_fn = get_default_fed_eval_fn 224 | get_on_fit_config_fn = get_default_on_fit_config_fn 225 | get_on_evaluate_config_fn = get_default_on_evaluate_config_fn 226 | -------------------------------------------------------------------------------- /project/types/common.py: -------------------------------------------------------------------------------- 1 | """Typing shared across the project meant to define a stable API. 2 | 3 | Prefer these interfaces over ad-hoc inline definitions or concrete types. 4 | """ 5 | 6 | from collections.abc import Callable 7 | from pathlib import Path 8 | import random 9 | from typing import Any 10 | 11 | import flwr as fl 12 | from flwr.common import NDArrays 13 | import numpy as np 14 | from omegaconf import DictConfig 15 | from torch import nn 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import enum 19 | from flwr.common import Parameters 20 | from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActor 21 | 22 | CID = str | int | Path 23 | 24 | GlobalState = tuple[ 25 | # Random 26 | tuple[Any, ...], 27 | # np 28 | dict[str, Any], 29 | # torch 30 | torch.Tensor, 31 | ] 32 | 33 | IsolatedRNGState = tuple[ 34 | # Seed 35 | int, 36 | # Random 37 | tuple[Any, ...], 38 | # np 39 | dict[str, Any], 40 | # torch 41 | torch.Tensor, 42 | # torch GPU 43 | torch.Tensor | None, 44 | ] 45 | 46 | ClientCIDandSeedGeneratorsState = tuple[ 47 | # Client cid generator 48 | tuple[Any, ...], 49 | # Client seed generator 50 | tuple[Any, ...], 51 | ] 52 | 53 | 54 | # Contains the rng state for 55 | # Global: Random, NP, Torch rng 56 | # The RNG tuple of the server 57 | # Client CID Generator 58 | # Client Seed Generator 59 | RNGStateTuple = tuple[ 60 | # Global state 61 | GlobalState, 62 | # Server RNG state 63 | IsolatedRNGState, 64 | # Client cid and seed generators 65 | ClientCIDandSeedGeneratorsState, 66 | ] 67 | 68 | 69 | # Necessary to guarantee reproducibility across all sources of randomness 70 | IsolatedRNG = tuple[ 71 | int, random.Random, np.random.Generator, torch.Generator, torch.Generator | None 72 | ] 73 | 74 | # Payload for the generators controlling server behavior 75 | ServerRNG = tuple[ 76 | # Server RNG tuple 77 | IsolatedRNG, 78 | # Client cid and seed generators 79 | 
79 |     random.Random,
80 |     random.Random,
81 | ]
82 | 
83 | 
84 | # Interface for network generators
85 | # all behavior mutations should be done
86 | # via closures or the config
87 | NetGen = Callable[
88 |     [
89 |         dict,
90 |         IsolatedRNG,
91 |         DictConfig | None,
92 |     ],
93 |     nn.Module,
94 | ]
95 | 
96 | 
97 | # Allows obtaining initial parameters for the network
98 | # Can be used even if NetGen is not provided
99 | InitialParameterGen = Callable[
100 |     [
101 |         NetGen | None,
102 |         dict,
103 |         IsolatedRNG,
104 |         DictConfig | None,
105 |     ],
106 |     Parameters | None,
107 | ]
108 | 
109 | 
110 | # Dataloader generators for clients and server
111 | 
112 | # Client dataloaders require the client id,
113 | # whether the dataloader is for training or evaluation
114 | # and the config
115 | ClientDataloaderGen = Callable[
116 |     [
117 |         CID,
118 |         bool,
119 |         dict,
120 |         IsolatedRNG,
121 |         DictConfig | None,
122 |     ],
123 |     DataLoader,
124 | ]
125 | 
126 | # Server dataloaders only require a config and
127 | # whether the dataloader is for training or evaluation
128 | FedDataloaderGen = Callable[
129 |     [
130 |         bool,
131 |         dict,
132 |         IsolatedRNG,
133 |         DictConfig | None,
134 |     ],
135 |     DataLoader,
136 | ]
137 | 
138 | # Client generators require only the client id,
139 | # which is necessary for ray instantiation
140 | # all changes in behavior should be done via a closure
141 | ClientGen = Callable[[str], fl.client.NumPyClient]
142 | 
143 | TrainFunc = Callable[
144 |     [
145 |         nn.Module | NDArrays,
146 |         DataLoader | None,
147 |         dict,
148 |         Path,
149 |         IsolatedRNG,
150 |         DictConfig | None,
151 |     ],
152 |     tuple[nn.Module | NDArrays, int, dict],
153 | ]
154 | TestFunc = Callable[
155 |     [
156 |         nn.Module | NDArrays,
157 |         DataLoader | None,
158 |         dict,
159 |         Path,
160 |         IsolatedRNG,
161 |         DictConfig | None,
162 |     ],
163 |     tuple[float, int, dict],
164 | ]
165 | 
166 | # Type aliases for fit and eval results
167 | # discounting the Dict[str,Scalar] typing
168 | # of the original flwr types
169 | FitRes = tuple[NDArrays, int, dict]
170 | EvalRes = tuple[float, int, dict]
171 | 
172 | 
173 | # A function to initialize the working directory
174 | # in case your training relies on a specific
175 | # directory structure pre-existing
176 | InitWorkingDir = Callable[
177 |     [
178 |         # The working dir path
179 |         Path,
180 |         # The results dir path
181 |         Path,
182 |     ],
183 |     None,
184 | ]
185 | 
186 | # A federated evaluation function
187 | # used by the server to test the model between rounds
188 | # requires the round number, the model parameters
189 | # and the config
190 | # returns the test loss and the metrics
191 | FedEvalFN = Callable[
192 |     [int, NDArrays, dict],
193 |     tuple[float, dict] | None,
194 | ]
195 | 
196 | FedEvalGen = Callable[
197 |     [
198 |         NetGen | None,
199 |         FedDataloaderGen | None,
200 |         TestFunc,
201 |         dict,
202 |         Path,
203 |         IsolatedRNG,
204 |         DictConfig | None,
205 |     ],
206 |     FedEvalFN | None,
207 | ]
208 | 
209 | # Functions to generate config dictionaries
210 | # for fit and evaluate
211 | OnFitConfigFN = Callable[[int], dict]
212 | OnEvaluateConfigFN = OnFitConfigFN
213 | 
214 | # Structures to define a complete task setup
215 | # They can be varied independently to some extent
216 | # Allows us to take advantage of hydra without
217 | # losing static type checking
218 | TrainStructure = tuple[TrainFunc, TestFunc, FedEvalGen]
219 | DataStructure = tuple[
220 |     NetGen | None,
221 |     InitialParameterGen | None,
222 |     ClientDataloaderGen | None,
223 |     FedDataloaderGen | None,
224 |     InitWorkingDir | None,
225 | ]
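# --- Illustrative sketch, not part of the stable API --------------------
# Any callable of this shape satisfies the NetGen alias above and can be
# returned from a task's dispatch_data; the tiny model is made up.
def _example_net_gen(
    _config: dict,
    _rng_tuple: IsolatedRNG,
    _hydra_config: DictConfig | None,
) -> nn.Module:
    return nn.Linear(28 * 28, 10)
# -------------------------------------------------------------------------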
226 | ConfigStructure = tuple[OnFitConfigFN, OnEvaluateConfigFN]
227 | 
228 | 
229 | # Returns a client generator and the ray actor which
230 | # Dispatches clients
231 | ClientTypeGen = Callable[
232 |     [
233 |         # The working directory
234 |         Path,
235 |         # The network generator
236 |         NetGen | None,
237 |         # The client dataloader generator
238 |         ClientDataloaderGen | None,
239 |         # The training function
240 |         TrainFunc,
241 |         # The testing function
242 |         TestFunc,
243 |         # Seeded rng for client seed initialization
244 |         random.Random,
245 |         # Hydra config
246 |         DictConfig | None,
247 |     ],
248 |     ClientGen,
249 | ]
250 | 
251 | ClientAndActorStructure = tuple[
252 |     ClientTypeGen,
253 |     type[VirtualClientEngineActor],
254 |     dict[str, Any] | None,
255 | ]
256 | 
257 | 
258 | class IntentionalDropoutError(Exception):
259 |     """Exception raised when a client intentionally drops out."""
260 | 
261 | 
262 | class FileCountExceededError(Exception):
263 |     """Exception raised when the allowed number of saved files is exceeded."""
264 | 
265 | 
266 | class Folders(enum.StrEnum):
267 |     """Enum for folder types."""
268 | 
269 |     WORKING = enum.auto()
270 |     STATE = enum.auto()
271 |     PARAMETERS = enum.auto()
272 |     RNG = enum.auto()
273 |     HISTORIES = enum.auto()
274 |     HYDRA = ".hydra"
275 |     RESULTS = enum.auto()
276 |     WANDB = enum.auto()
277 | 
278 | 
279 | class Files(enum.StrEnum):
280 |     """Enum for file types."""
281 | 
282 |     PARAMETERS = enum.auto()
283 |     RNG_STATE = "rng-state"
284 |     HISTORY = enum.auto()
285 |     MAIN = enum.auto()
286 |     WANDB_RUN = enum.auto()
287 | 
288 | 
289 | class Ext(enum.StrEnum):
290 |     """Enum for file extensions."""
291 | 
292 |     PARAMETERS = "bin"
293 |     RNG_STATE = "pt"
294 |     HISTORY = "json"
295 |     MAIN = "log"
296 |     WANDB_RUN = "json"
297 | 
--------------------------------------------------------------------------------
/project/fed/server/wandb_server.py:
--------------------------------------------------------------------------------
1 | """Flower server accounting for Weights&Biases+file saving."""
2 | 
3 | import timeit
4 | from collections.abc import Callable
5 | from logging import INFO
6 | from typing import Any
7 | 
8 | from flwr.common import Parameters
9 | from flwr.common.logger import log
10 | from flwr.server import Server
11 | from flwr.server.client_manager import ClientManager
12 | from flwr.server.history import History
13 | from flwr.server.strategy import Strategy
14 | from omegaconf import DictConfig
15 | 
16 | from project.types.common import ServerRNG
17 | 
18 | 
19 | class WandbServer(Server):
20 |     """Flower server."""
21 | 
22 |     def __init__(
23 |         self,
24 |         *,
25 |         client_manager: ClientManager,
26 |         hydra_config: DictConfig | None,
27 |         starting_round: int = 0,
28 |         server_rng: ServerRNG,
29 |         strategy: Strategy | None = None,
30 |         history: History | None = None,
31 |         save_parameters_to_file: Callable[
32 |             [Parameters],
33 |             None,
34 |         ],
35 |         save_rng_to_file: Callable[[ServerRNG], None],
36 |         save_history_to_file: Callable[[History], None],
37 |         save_files_per_round: Callable[[int], None],
38 |     ) -> None:
39 |         """Flower server implementation.
40 | 
41 |         Parameters
42 |         ----------
43 |         client_manager : ClientManager
44 |             Client manager implementation.
45 |         strategy : Optional[Strategy]
46 |             Strategy implementation.
47 |         history : Optional[History]
48 |             History implementation.
49 |         save_parameters_to_file : Callable[[Parameters], None]
50 |             Function to save the parameters to file.
51 |         save_files_per_round : Callable[[int], None]
52 |             Function to save files every round.
53 | 54 | Returns 55 | ------- 56 | None 57 | """ 58 | super().__init__( 59 | client_manager=client_manager, 60 | strategy=strategy, 61 | ) 62 | 63 | self.history: History | None = history 64 | self.save_parameters_to_file = save_parameters_to_file 65 | self.save_files_per_round = save_files_per_round 66 | self.starting_round = starting_round 67 | self.server_rng = server_rng 68 | self.save_rng_to_file = save_rng_to_file 69 | self.save_history_to_file = save_history_to_file 70 | self.hydra_config = hydra_config 71 | 72 | # pylint: disable=too-many-locals 73 | def fit( 74 | self, 75 | num_rounds: int, 76 | timeout: float | None, 77 | ) -> History: 78 | """Run federated averaging for a number of rounds. 79 | 80 | Parameters 81 | ---------- 82 | num_rounds : int 83 | The number of rounds to run. 84 | timeout : Optional[float] 85 | Timeout in seconds. 86 | 87 | Returns 88 | ------- 89 | History 90 | The history of the training. 91 | Potentially using a pre-defined history. 92 | """ 93 | history = self.history if self.history is not None else History() 94 | # Initialize parameters 95 | log(INFO, "Initializing global parameters") 96 | self.parameters = self._get_initial_parameters( 97 | timeout=timeout, 98 | ) 99 | 100 | if self.starting_round == 0: 101 | log(INFO, "Evaluating initial parameters") 102 | res = self.strategy.evaluate( 103 | 0, 104 | parameters=self.parameters, 105 | ) 106 | if res is not None: 107 | log( 108 | INFO, 109 | "initial parameters (loss, other metrics): %s, %s", 110 | res[0], 111 | res[1], 112 | ) 113 | history.add_loss_centralized( 114 | server_round=0, 115 | loss=res[0], 116 | ) 117 | history.add_metrics_centralized( 118 | server_round=0, 119 | metrics=res[1], 120 | ) 121 | # Save initial parameters and files 122 | self.save_parameters_to_file(self.parameters) 123 | self.save_rng_to_file(self.server_rng) 124 | self.save_files_per_round(0) 125 | 126 | # Run federated learning for num_rounds 127 | log(INFO, "FL starting") 128 | start_time = timeit.default_timer() 129 | 130 | for current_round in range(self.starting_round + 1, num_rounds + 1): 131 | # Train model and replace previous global model 132 | res_fit = self.fit_round( 133 | server_round=current_round, 134 | timeout=timeout, 135 | ) 136 | if res_fit is not None: 137 | ( 138 | parameters_prime, 139 | fit_metrics, 140 | _, 141 | ) = res_fit # fit_metrics_aggregated 142 | if parameters_prime: 143 | self.parameters = parameters_prime 144 | history.add_metrics_distributed_fit( 145 | server_round=current_round, 146 | metrics=fit_metrics, 147 | ) 148 | 149 | # Evaluate model using strategy implementation 150 | res_cen = self.strategy.evaluate( 151 | current_round, 152 | parameters=self.parameters, 153 | ) 154 | if res_cen is not None: 155 | loss_cen, metrics_cen = res_cen 156 | log( 157 | INFO, 158 | "fit progress: (%s, %s, %s, %s)", 159 | current_round, 160 | loss_cen, 161 | metrics_cen, 162 | timeit.default_timer() - start_time, 163 | ) 164 | history.add_loss_centralized( 165 | server_round=current_round, 166 | loss=loss_cen, 167 | ) 168 | history.add_metrics_centralized( 169 | server_round=current_round, 170 | metrics=metrics_cen, 171 | ) 172 | 173 | # Evaluate model on a sample of available clients 174 | res_fed = self.evaluate_round( 175 | server_round=current_round, 176 | timeout=timeout, 177 | ) 178 | if res_fed is not None: 179 | loss_fed, evaluate_metrics_fed, _ = res_fed 180 | if loss_fed is not None: 181 | history.add_loss_distributed( 182 | server_round=current_round, 183 | loss=loss_fed, 184 | ) 185 | 
                    history.add_metrics_distributed(
186 |                         server_round=current_round,
187 |                         metrics=evaluate_metrics_fed,
188 |                     )
189 |             # Save round parameters and files
190 |             self.save_parameters_to_file(self.parameters)
191 |             self.save_history_to_file(history)
192 |             self.save_rng_to_file(self.server_rng)
193 |             self.save_files_per_round(current_round)
194 | 
195 |         # Bookkeeping
196 |         end_time = timeit.default_timer()
197 |         elapsed = end_time - start_time
198 |         log(INFO, "FL finished in %s", elapsed)
199 |         return history
200 | 
201 | 
202 | def dispatch_wandb_server(cfg: DictConfig, **kwargs: Any) -> type[WandbServer] | None:
203 |     """Dispatch the get_wandb_server function based on the hydra config.
204 | 
205 |     Parameters
206 |     ----------
207 |     cfg : DictConfig
208 |         The configuration for the get_wandb_server function.
209 |         Loaded dynamically from the config file.
210 |     kwargs : dict[str, Any]
211 |         Additional keyword arguments to pass to the get_wandb_server function.
212 | 
213 |     Returns
214 |     -------
215 |     type[WandbServer] | None
216 |         The WandbServer class, or None if the config does not match.
217 |     """
218 |     server: str | None = cfg.get("task", {}).get("server", None)
219 | 
220 |     if server is None:
221 |         return None
222 | 
223 |     if server.upper() == "DEFAULT":
224 |         return WandbServer
225 | 
226 |     return None
227 | 
--------------------------------------------------------------------------------
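The dispatch module that follows chains per-task dispatch functions and returns the first non-None result. A sketch of how a new task would be registered, assuming a hypothetical `project.task.my_task` package exposing the same dispatch interface:

```python
# Hypothetical registration of a new task inside project/dispatch/dispatch.py
from project.task.my_task.dispatch import dispatch_train as dispatch_my_task_train

# Inside dispatch_train(); list order determines precedence when several match
task_train_functions = [
    dispatch_default_train,
    dispatch_mnist_train,
    dispatch_my_task_train,
]
```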
/project/dispatch/dispatch.py:
--------------------------------------------------------------------------------
1 | """Dispatches the functionality of the task.
2 | 
3 | This gives us the ability to dynamically choose functionality based on the hydra dict
4 | config without losing static type checking.
5 | """
6 | 
7 | from collections.abc import Callable
8 | from typing import Any
9 | 
10 | from omegaconf import DictConfig
11 | 
12 | from project.fed.server.deterministic_client_manager import (
13 |     DeterministicClientManager,
14 |     dispatch_deterministic_client_manager as dispatch_default_client_manager,
15 | )
16 | from project.fed.server.wandb_server import (
17 |     WandbServer,
18 |     dispatch_wandb_server as dispatch_default_server,
19 | )
20 | from project.task.default.dispatch import dispatch_config as dispatch_default_config
21 | from project.task.default.dispatch import dispatch_data as dispatch_default_data
22 | from project.task.default.dispatch import dispatch_train as dispatch_default_train
23 | from project.client.client import (
24 |     dispatch_client_gen as dispatch_default_client_gen,
25 | )
26 | 
27 | from project.task.mnist_classification.dispatch import (
28 |     dispatch_config as dispatch_mnist_config,
29 | )
30 | from project.task.mnist_classification.dispatch import (
31 |     dispatch_data as dispatch_mnist_data,
32 | )
33 | from project.task.mnist_classification.dispatch import (
34 |     dispatch_train as dispatch_mnist_train,
35 | )
36 | from project.types.common import (
37 |     ClientAndActorStructure,
38 |     ConfigStructure,
39 |     DataStructure,
40 |     ClientTypeGen,
41 |     TrainStructure,
42 | )
43 | from flwr.simulation.ray_transport.ray_actor import (
44 |     VirtualClientEngineActor,
45 | )
46 | 
47 | 
48 | def dispatch_train(cfg: DictConfig, **kwargs: Any) -> TrainStructure:
49 |     """Dispatch the train/test and fed test functions based on the config file.
50 | 
51 |     Functionality should be added to the dispatch.py file in the task folder.
52 |     Statically specify the new dispatch function in the list,
53 |     function order determines precedence if two different tasks may match the config.
54 | 
55 |     Parameters
56 |     ----------
57 |     cfg : DictConfig
58 |         The configuration for the train function.
59 |         Loaded dynamically from the config file.
60 |     kwargs : dict[str, Any]
61 |         Additional keyword arguments to pass to the train function.
62 | 
63 |     Returns
64 |     -------
65 |     TrainStructure
66 |         The train function, test function and the get_fed_eval_fn function.
67 |     """
68 |     # Create the list of task dispatches to try
69 |     task_train_functions: list[Callable[..., TrainStructure | None]] = [
70 |         dispatch_default_train,
71 |         dispatch_mnist_train,
72 |     ]
73 | 
74 |     # Match the first function which does not return None
75 |     for task in task_train_functions:
76 |         result = task(cfg, **kwargs)
77 |         if result is not None:
78 |             return result
79 | 
80 |     raise ValueError(
81 |         f"Unable to match the train/test and fed_test functions: {cfg}",
82 |     )
83 | 
84 | 
85 | def dispatch_data(cfg: DictConfig, **kwargs: Any) -> DataStructure:
86 |     """Dispatch the net generator and dataloader client/fed generator functions.
87 | 
88 |     Functionality should be added to the dispatch.py file in the task folder.
89 |     Statically specify the new dispatch function in the list,
90 |     function order determines precedence if two different tasks may match the config.
91 | 
92 |     Parameters
93 |     ----------
94 |     cfg : DictConfig
95 |         The configuration for the data function.
96 |         Loaded dynamically from the config file.
97 |     kwargs : dict[str, Any]
98 |         Additional keyword arguments to pass to the data function.
99 | 
100 |     Returns
101 |     -------
102 |     DataStructure
103 |         The net generator and dataloader generator functions.
104 |     """
105 |     # Create the list of task dispatches to try
106 |     task_data_dependent_functions: list[Callable[..., DataStructure | None]] = [
107 |         dispatch_mnist_data,
108 |         dispatch_default_data,
109 |     ]
110 | 
111 |     # Match the first function which does not return None
112 |     for task in task_data_dependent_functions:
113 |         result = task(cfg, **kwargs)
114 |         if result is not None:
115 |             return result
116 | 
117 |     raise ValueError(
118 |         f"Unable to match the net generator and dataloader generator functions: {cfg}",
119 |     )
120 | 
121 | 
122 | def dispatch_config(cfg: DictConfig, **kwargs: Any) -> ConfigStructure:
123 |     """Dispatch the fit/eval config functions based on the hydra config.
124 | 
125 |     Functionality should be added to the dispatch.py
126 |     file in the task folder.
127 |     Statically specify the new dispatch function in the list,
128 |     function order determines precedence
129 |     if two different tasks may match the config.
130 | 
131 |     Parameters
132 |     ----------
133 |     cfg : DictConfig
134 |         The configuration for the config function.
135 |         Loaded dynamically from the config file.
136 |     kwargs : dict[str, Any]
137 |         Additional keyword arguments to pass to the config function.
138 | 
139 |     Returns
140 |     -------
141 |     ConfigStructure
142 |         The config functions.
143 | """ 144 | # Create the list of task dispatches to try 145 | task_config_functions: list[Callable[..., ConfigStructure | None]] = [ 146 | dispatch_mnist_config, 147 | dispatch_default_config, 148 | ] 149 | 150 | # Match the first function which does not return None 151 | for task in task_config_functions: 152 | result = task(cfg, **kwargs) 153 | if result is not None: 154 | return result 155 | 156 | raise ValueError( 157 | f"Unable to match the config generation functions: {cfg}", 158 | ) 159 | 160 | 161 | def dispatch_get_client_generator( 162 | cfg: DictConfig, **kwargs: Any 163 | ) -> ClientAndActorStructure: 164 | """Dispatch the get_client_generator function based on the hydra config. 165 | 166 | Functionality should be added to the dispatch.py 167 | file in the task folder. 168 | Statically specify the new dispatch function in the list, 169 | function order determines precedence 170 | if two different tasks may match the config. 171 | 172 | Parameters 173 | ---------- 174 | cfg : DictConfig 175 | The configuration for the get_client_generators function. 176 | Loaded dynamically from the config file. 177 | kwargs : dict[str, Any] 178 | Additional keyword arguments to pass to the get_client_generators function. 179 | 180 | Returns 181 | ------- 182 | ClientAndActorStructure 183 | The client generator function, the actor type, and optional actor kwargs. 184 | """ 185 | # Create the list of task dispatches to try 186 | task_get_client_generators: list[ 187 | Callable[ 188 | ..., 189 | tuple[ 190 | ClientTypeGen, 191 | type[VirtualClientEngineActor], 192 | dict[str, Any] | None, 193 | ] 194 | | None, 195 | ] 196 | ] = [ 197 | dispatch_default_client_gen, 198 | ] 199 | 200 | # Match the first function which does not return None 201 | for task in task_get_client_generators: 202 | result = task(cfg, **kwargs) 203 | if result is not None: 204 | return result 205 | 206 | raise ValueError( 207 | f"Unable to match the get_client_generators function: {cfg}", 208 | ) 209 | 210 | 211 | def dispatch_get_client_manager( 212 | cfg: DictConfig, **kwargs: Any 213 | ) -> type[DeterministicClientManager]: 214 | """Dispatch the get_client_manager function based on the hydra config. 215 | 216 | Parameters 217 | ---------- 218 | cfg : DictConfig 219 | The configuration for the get_client_manager function. 220 | Loaded dynamically from the config file. 221 | kwargs : dict[str, Any] 222 | Additional keyword arguments to pass to the get_client_manager function. 223 | 224 | Returns 225 | ------- 226 | type[DeterministicClientManager] 227 | The client manager class to use. 228 | """ 229 | # Create the list of task dispatches to try 230 | task_get_client_managers: list[ 231 | Callable[..., type[DeterministicClientManager] | None] 232 | ] = [ 233 | dispatch_default_client_manager, 234 | ] 235 | 236 | # Match the first function which does not return None 237 | for task in task_get_client_managers: 238 | result = task(cfg, **kwargs) 239 | if result is not None: 240 | return result 241 | 242 | raise ValueError( 243 | f"Unable to match the get_client_manager function: {cfg}", 244 | ) 245 | 246 | 247 | def dispatch_server(cfg: DictConfig, **kwargs: Any) -> type[WandbServer]: 248 | """Dispatch the get_server function based on the hydra config. 249 | 250 | Parameters 251 | ---------- 252 | cfg : DictConfig 253 | The configuration for the get_server function. 254 | Loaded dynamically from the config file. 255 | 256 | Returns 257 | ------- 258 | type[WandbServer] 259 | The server class to use.
260 | """ 261 | # Create the list of task dispatches to try 262 | task_get_servers: list[Callable[..., type[WandbServer] | None]] = [ 263 | dispatch_default_server, 264 | ] 265 | 266 | # Match the first function which does not return None 267 | for task in task_get_servers: 268 | result = task(cfg, **kwargs) 269 | if result is not None: 270 | return result 271 | 272 | raise ValueError( 273 | f"Unable to match the get_server function: {cfg}", 274 | ) 275 | -------------------------------------------------------------------------------- /project/task/default/train_test.py: -------------------------------------------------------------------------------- 1 | """Default training and testing functions, local and federated.""" 2 | 3 | from collections.abc import Sized 4 | from pathlib import Path 5 | from typing import cast 6 | 7 | from omegaconf import DictConfig 8 | import torch 9 | from flwr.common import NDArrays 10 | from pydantic import BaseModel 11 | from torch import nn 12 | from torch.utils.data import DataLoader 13 | 14 | from project.client.client import ClientConfig 15 | from project.fed.utils.utils import generic_set_parameters 16 | from project.types.common import ( 17 | FedDataloaderGen, 18 | FedEvalFN, 19 | IsolatedRNG, 20 | NetGen, 21 | OnFitConfigFN, 22 | TestFunc, 23 | CID, 24 | ) 25 | from project.utils.utils import obtain_device 26 | 27 | 28 | class TrainConfig(BaseModel): 29 | """Training configuration, allows '.' member access and static checking. 30 | 31 | Guarantees that all necessary components are present, fails early if config is 32 | mismatched to client. 33 | """ 34 | 35 | cid: CID 36 | device: torch.device 37 | # epochs: int 38 | # learning_rate: float 39 | 40 | class Config: 41 | """Setting to allow any types, including library ones like torch.device.""" 42 | 43 | arbitrary_types_allowed = True 44 | 45 | 46 | def train( 47 | net: nn.Module | NDArrays, 48 | trainloader: DataLoader | None, 49 | _config: dict, 50 | _working_dir: Path, 51 | rng_tuple: IsolatedRNG, 52 | _hydra_config: DictConfig | None, 53 | ) -> tuple[nn.Module | NDArrays, int, dict]: 54 | """Train the network on the training set. 55 | 56 | Parameters 57 | ---------- 58 | net : nn.Module 59 | The neural network to train. 60 | trainloader : DataLoader 61 | The DataLoader containing the data to train the network on. 62 | _config : Dict 63 | The configuration for the training. 64 | Contains the device, number of epochs and learning rate. 65 | Static type checking is done by the TrainConfig class. 66 | _working_dir : Path 67 | The working directory for the training. 68 | Unused. 69 | rng_tuple : IsolatedRNG 70 | The random number generator state for the training. 71 | Use if you need seeded random behavior 72 | 73 | Returns 74 | ------- 75 | Tuple[nn.Module | NDArrays, int, Dict] 76 | The trained network, the number of samples 77 | used for training, and the training metrics. 78 | """ 79 | if not isinstance(net, nn.Module) or trainloader is None: 80 | raise ValueError( 81 | "The default config does not use an implicit network generator/dataset" 82 | ) 83 | 84 | if len(cast(Sized, trainloader.dataset)) == 0: 85 | raise ValueError( 86 | "Trainloader can't be empty, exiting...", 87 | ) 88 | 89 | config: TrainConfig = TrainConfig(**_config) 90 | del _config 91 | 92 | net.to(config.device) 93 | net.train() 94 | 95 | return net, len(cast(Sized, trainloader.dataset)), {} 96 | 97 | 98 | class TestConfig(BaseModel): 99 | """Testing configuration, allows '.' member access and static checking.
100 | 101 | Guarantees that all necessary components are present, fails early if config is 102 | mismatched to client. 103 | """ 104 | 105 | cid: CID 106 | device: torch.device 107 | 108 | class Config: 109 | """Setting to allow any types, including library ones like torch.device.""" 110 | 111 | arbitrary_types_allowed = True 112 | 113 | 114 | def test( 115 | net: nn.Module | NDArrays, 116 | testloader: DataLoader | None, 117 | _config: dict, 118 | _working_dir: Path, 119 | rng_tuple: IsolatedRNG, 120 | _hydra_config: DictConfig | None, 121 | ) -> tuple[float, int, dict]: 122 | """Evaluate the network on the test set. 123 | 124 | Parameters 125 | ---------- 126 | net : nn.Module 127 | The neural network to test. 128 | testloader : DataLoader 129 | The DataLoader containing the data to test the network on. 130 | _config : Dict 131 | The configuration for the testing. 132 | Contains the device. 133 | Static type checking is done by the TestConfig class. 134 | _working_dir : Path 135 | The working directory for the training. 136 | Unused. 137 | rng_tuple : IsolatedRNG 138 | The random number generator state for the training. 139 | Use if you need seeded random behavior 140 | 141 | Returns 142 | ------- 143 | Tuple[float, int, Dict] 144 | The loss, number of test samples, 145 | and the metrics of the input model on the given data. 146 | """ 147 | if not isinstance(net, nn.Module) or testloader is None: 148 | raise ValueError( 149 | "The default config does not use an implicit network generator/dataset" 150 | ) 151 | 152 | if len(cast(Sized, testloader.dataset)) == 0: 153 | raise ValueError( 154 | "Testloader can't be empty, exiting...", 155 | ) 156 | 157 | config: TestConfig = TestConfig(**_config) 158 | del _config 159 | 160 | net.to(config.device) 161 | net.eval() 162 | 163 | return ( 164 | 0.0, 165 | len(cast(Sized, testloader.dataset)), 166 | {}, 167 | ) 168 | 169 | 170 | def get_fed_eval_fn( 171 | net_generator: NetGen | None, 172 | fed_dataloader_generator: FedDataloaderGen | None, 173 | test_func: TestFunc, 174 | _config: dict, 175 | working_dir: Path, 176 | rng_tuple: IsolatedRNG, 177 | hydra_config: DictConfig | None, 178 | ) -> FedEvalFN | None: 179 | """Get the federated evaluation function. 180 | 181 | Parameters 182 | ---------- 183 | net_generator : NetGen 184 | The function to generate the network. 185 | fed_dataloader_generator : FedDataloaderGen 186 | The generator for the dataloader containing the federated test data. 187 | test_func : TestFunc 188 | The function to evaluate the network. 189 | _config : Dict 190 | The configuration for the testing. 191 | Contains the device. 192 | Static type checking is done by the TestConfig class. 193 | working_dir : Path 194 | The working directory for the training. 195 | rng_tuple : IsolatedRNG 196 | The random number generator state for the training. 197 | Use if you need seeded random behavior 198 | 199 | Returns 200 | ------- 201 | Optional[FedEvalFN] 202 | The evaluation function for the server 203 | if the testloader is not empty, else None. 204 | """ 205 | config: ClientConfig = ClientConfig(**_config) 206 | del _config 207 | 208 | testloader = ( 209 | fed_dataloader_generator( 210 | True, 211 | config.dataloader_config, 212 | rng_tuple, 213 | hydra_config, 214 | ) 215 | if fed_dataloader_generator 216 | else None 217 | ) 218 | 219 | def fed_eval_fn( 220 | _server_round: int, 221 | parameters: NDArrays, 222 | fake_config: dict, 223 | ) -> tuple[float, dict] | None: 224 | """Evaluate the model on the given data.
225 | 226 | Parameters 227 | ---------- 228 | _server_round : int 229 | The current server round. 230 | parameters : NDArrays 231 | The parameters of the model to evaluate. 232 | fake_config : Dict 233 | The configuration for the evaluation. 234 | 235 | Returns 236 | ------- 237 | Optional[Tuple[float, Dict]] 238 | The loss and the accuracy of the input model on the given data. 239 | """ 240 | net = ( 241 | net_generator(config.net_config, rng_tuple, hydra_config) 242 | if net_generator 243 | else None 244 | ) 245 | if net is not None: 246 | generic_set_parameters(net, parameters) 247 | 248 | config.run_config["device"] = obtain_device() 249 | config.run_config["cid"] = "server" 250 | 251 | if testloader is not None and len(cast(Sized, testloader.dataset)) == 0: 252 | return None 253 | 254 | loss, _num_samples, metrics = test_func( 255 | net if net is not None else parameters, 256 | testloader, 257 | config.run_config, 258 | working_dir, 259 | rng_tuple, 260 | hydra_config, 261 | ) 262 | return loss, metrics 263 | 264 | return fed_eval_fn 265 | 266 | 267 | # Get NONE fed eval fn 268 | def get_none_fed_eval_fn( 269 | net_generator: NetGen | None, 270 | fed_dataloader_generator: FedDataloaderGen | None, 271 | test_func: TestFunc, 272 | _config: dict, 273 | working_dir: Path, 274 | rng_tuple: IsolatedRNG, 275 | hydra_config: DictConfig | None, 276 | ) -> FedEvalFN | None: 277 | """Get an empty federated evaluation function.""" 278 | return None 279 | 280 | 281 | def get_on_fit_config_fn(fit_config: dict) -> OnFitConfigFN: 282 | """Generate on_fit_config_fn based on a dict from the hydra config. 283 | 284 | Parameters 285 | ---------- 286 | fit_config : Dict 287 | The configuration for the fit function. 288 | Loaded dynamically from the config file. 289 | rng_tuple : IsolatedRNGTuple 290 | The random number generator state for the training. 291 | Use if you need seeded random behavior 292 | 293 | Returns 294 | ------- 295 | Optional[OnFitConfigFN] 296 | The on_fit_config_fn for the server if the fit_config is not empty, else None. 297 | """ 298 | # Fail early if the fit_config does not match expectations 299 | ClientConfig(**fit_config) 300 | 301 | def fit_config_fn(server_round: int) -> dict: 302 | """Default on_fit_config_fn. 303 | 304 | Parameters 305 | ---------- 306 | server_round : int 307 | The current server round. 308 | Passed to the client 309 | 310 | Returns 311 | ------- 312 | Dict 313 | The configuration for the fit function. 314 | Loaded dynamically from the config file. 315 | """ 316 | fit_config["extra"]["server_round"] = server_round 317 | return fit_config 318 | 319 | return fit_config_fn 320 | 321 | 322 | # Differences between the two will come 323 | # from the config file 324 | get_on_evaluate_config_fn = get_on_fit_config_fn 325 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License.
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /project/client/client.py: -------------------------------------------------------------------------------- 1 | """The default client implementation. 2 | 3 | Make sure the model and dataset are not loaded before the fit function. 4 | """ 5 | 6 | import random 7 | from pathlib import Path 8 | from typing import Any, cast 9 | 10 | import flwr as fl 11 | from flwr.common import NDArrays 12 | from flwr.server import History 13 | from omegaconf import DictConfig 14 | from pydantic import BaseModel 15 | from torch import nn 16 | 17 | from project.fed.utils.utils import ( 18 | generic_get_parameters, 19 | generic_set_parameters, 20 | get_isolated_rng_tuple, 21 | ) 22 | from project.types.common import ( 23 | CID, 24 | ClientDataloaderGen, 25 | ClientGen, 26 | ConfigStructure, 27 | DataStructure, 28 | EvalRes, 29 | FitRes, 30 | ClientTypeGen, 31 | NetGen, 32 | TestFunc, 33 | TrainFunc, 34 | ServerRNG, 35 | TrainStructure, 36 | ) 37 | from project.utils.utils import obtain_device 38 | from flwr.simulation.ray_transport.ray_actor import ( 39 | DefaultActor, 40 | VirtualClientEngineActor, 41 | ) 42 | from flwr.common import Parameters 43 | 44 | 45 | class ClientConfig(BaseModel): 46 | """Fit/eval config, allows '.' member access and static checking. 47 | 48 | Used to check whether each component has its own independent config present. Each 49 | component should then use its own Pydantic model to validate its config. For 50 | anything extra, use the extra field as a simple dict. 51 | """ 52 | 53 | # Instantiate model 54 | net_config: dict 55 | # Instantiate dataloader 56 | dataloader_config: dict 57 | # For train/test 58 | run_config: dict 59 | # Additional params used like a Dict 60 | extra: dict 61 | 62 | class Config: 63 | """Setting to allow any types, including library ones like torch.device.""" 64 | 65 | arbitrary_types_allowed = True 66 | 67 | 68 | class Client(fl.client.NumPyClient): 69 | """Virtual client for ray.""" 70 | 71 | def __init__( 72 | self, 73 | cid: CID, 74 | working_dir: Path, 75 | net_generator: NetGen | None, 76 | dataloader_gen: ClientDataloaderGen | None, 77 | train: TrainFunc, 78 | test: TestFunc, 79 | client_seed: int, 80 | hydra_config: DictConfig | None, 81 | ) -> None: 82 | """Initialize the client. 83 | 84 | Only ever instantiate the model or load dataset 85 | inside fit/eval, never in init. 86 | 87 | Parameters 88 | ---------- 89 | cid : int | str | Path 90 | The client's ID. 91 | working_dir : Path 92 | The path to the working directory. 93 | net_generator : NetGen 94 | The network generator. 95 | dataloader_gen : ClientDataloaderGen 96 | The dataloader generator. 97 | Uses the client id to determine partition. 
98 | 99 | Returns 100 | ------- 101 | None 102 | """ 103 | super().__init__() 104 | self.cid = cid 105 | self.net_generator = net_generator 106 | self.working_dir = working_dir 107 | self.net: nn.Module | NDArrays | None = None 108 | self.dataloader_gen = dataloader_gen 109 | self.train = train 110 | self.test = test 111 | 112 | # For deterministic client execution 113 | # The client_seed is generated from a specific Generator 114 | self.client_seed = client_seed 115 | self.rng_tuple = get_isolated_rng_tuple(self.client_seed, obtain_device()) 116 | 117 | self.hydra_config = hydra_config 118 | 119 | def fit( 120 | self, 121 | parameters: NDArrays, 122 | _config: dict, 123 | ) -> FitRes: 124 | """Fit the model using the provided parameters. 125 | 126 | Only ever instantiate the model or load dataset 127 | inside fit, never in init. 128 | 129 | Parameters 130 | ---------- 131 | parameters : NDArrays 132 | The parameters to use for training. 133 | _config : Dict 134 | The configuration for the training. 135 | Uses the pydantic model for static checking. 136 | 137 | Returns 138 | ------- 139 | FitRes 140 | The parameters after training, the number of samples used and the metrics. 141 | """ 142 | config: ClientConfig = ClientConfig(**_config) 143 | del _config 144 | 145 | config.run_config["device"] = obtain_device() 146 | config.run_config["cid"] = self.cid 147 | 148 | self.net = self.set_parameters( 149 | parameters, 150 | config.net_config, 151 | ) 152 | trainloader = ( 153 | self.dataloader_gen( 154 | self.cid, 155 | False, 156 | config.dataloader_config, 157 | self.rng_tuple, 158 | self.hydra_config, 159 | ) 160 | if self.dataloader_gen is not None 161 | else None 162 | ) 163 | self.net, num_samples, metrics = self.train( 164 | self.net, 165 | trainloader, 166 | config.run_config, 167 | self.working_dir, 168 | self.rng_tuple, 169 | self.hydra_config, 170 | ) 171 | 172 | return ( 173 | self.get_parameters({}), 174 | num_samples, 175 | metrics, 176 | ) 177 | 178 | def evaluate( 179 | self, 180 | parameters: NDArrays, 181 | _config: dict, 182 | ) -> EvalRes: 183 | """Evaluate the model using the provided parameters. 184 | 185 | Only ever instantiate the model or load dataset 186 | inside eval, never in init. 187 | 188 | Parameters 189 | ---------- 190 | parameters : NDArrays 191 | The parameters to use for evaluation. 192 | _config : Dict 193 | The configuration for the evaluation. 194 | Uses the pydantic model for static checking. 195 | 196 | Returns 197 | ------- 198 | EvalRes 199 | The loss, the number of samples used and the metrics. 200 | """ 201 | config: ClientConfig = ClientConfig(**_config) 202 | del _config 203 | 204 | config.run_config["device"] = obtain_device() 205 | config.run_config["cid"] = self.cid 206 | 207 | self.net = self.set_parameters( 208 | parameters, 209 | config.net_config, 210 | ) 211 | testloader = ( 212 | self.dataloader_gen( 213 | self.cid, 214 | True, 215 | config.dataloader_config, 216 | self.rng_tuple, 217 | self.hydra_config, 218 | ) 219 | if self.dataloader_gen is not None 220 | else None 221 | ) 222 | loss, num_samples, metrics = self.test( 223 | self.net, 224 | testloader, 225 | config.run_config, 226 | self.working_dir, 227 | self.rng_tuple, 228 | self.hydra_config, 229 | ) 230 | return loss, num_samples, metrics 231 | 232 | def get_parameters(self, config: dict) -> NDArrays: 233 | """Obtain client parameters. 234 | 235 | If the network is currently None, generate a network using the net_generator.
236 | 237 | Parameters 238 | ---------- 239 | config : Dict 240 | The configuration for the training. 241 | 242 | Returns 243 | ------- 244 | NDArrays 245 | The parameters of the network. 246 | """ 247 | if self.net is None: 248 | except_str: str = """Network is None. 249 | Call set_parameters first and 250 | do not use this template 251 | without a get_initial_parameters function. 252 | """ 253 | raise ValueError(except_str) 254 | 255 | return ( 256 | generic_get_parameters(self.net) 257 | if isinstance(self.net, nn.Module) 258 | else self.net 259 | ) 260 | 261 | def set_parameters( 262 | self, 263 | parameters: NDArrays, 264 | config: dict, 265 | ) -> nn.Module | NDArrays: 266 | """Set client parameters. 267 | 268 | First generates the network. Only call this in fit/eval. 269 | 270 | Parameters 271 | ---------- 272 | parameters : NDArrays 273 | The parameters to set. 274 | config : Dict 275 | The configuration for the network generator. 276 | 277 | Returns 278 | ------- 279 | nn.Module 280 | The network with the new parameters. 281 | """ 282 | net = ( 283 | self.net_generator(config, self.rng_tuple, self.hydra_config) 284 | if self.net_generator is not None 285 | else None 286 | ) 287 | if net is None: 288 | return parameters 289 | 290 | generic_set_parameters( 291 | net, 292 | parameters, 293 | to_copy=False, 294 | ) 295 | return net 296 | 297 | def __repr__(self) -> str: 298 | """Implement the string representation based on cid.""" 299 | return f"Client(cid={self.cid})" 300 | 301 | def get_properties(self, config: dict) -> dict: 302 | """Implement how to get properties.""" 303 | return {} 304 | 305 | 306 | def get_client_generator( 307 | working_dir: Path, 308 | net_generator: NetGen | None, 309 | dataloader_gen: ClientDataloaderGen | None, 310 | train: TrainFunc, 311 | test: TestFunc, 312 | client_seed_generator: random.Random, 313 | hydra_config: DictConfig | None, 314 | ) -> ClientGen: 315 | """Return a function which creates a new Client. 316 | 317 | Client has access to the working dir, 318 | can generate a network and can generate a dataloader. 319 | The client receives train and test functions with pre-defined APIs. 320 | 321 | Parameters 322 | ---------- 323 | working_dir : Path 324 | The path to the working directory. 325 | net_generator : NetGen 326 | The network generator. 327 | Please respect the pydantic schema. 328 | dataloader_gen : ClientDataloaderGen 329 | The dataloader generator. 330 | Uses the client id to determine partition. 331 | Please respect the pydantic schema. 332 | train : TrainFunc 333 | The train function. 334 | Please respect the interface and pydantic schema. 335 | test : TestFunc 336 | The test function. 337 | Please respect the interface and pydantic schema. 338 | client_seed_generator : random.Random 339 | The random number generator used to derive 340 | a distinct seed for each client. 341 | hydra_config : DictConfig | None 342 | The hydra configuration, 343 | passed through to each client 344 | if needed. 345 | 346 | Returns 347 | ------- 348 | ClientGen 349 | The function which creates a new Client. 350 | """ 351 | 352 | def client_generator(cid: CID) -> fl.client.NumPyClient: 353 | """Return a new Client. 354 | 355 | Parameters 356 | ---------- 357 | cid : int | str | Path 358 | The client's ID. 359 | 360 | Returns 361 | ------- 362 | Client 363 | The new Client.
364 | """ 365 | return Client( 366 | cid, 367 | working_dir, 368 | net_generator, 369 | dataloader_gen, 370 | train, 371 | test, 372 | client_seed=client_seed_generator.randint(0, 2**32 - 1), 373 | hydra_config=hydra_config, 374 | ) 375 | 376 | return client_generator 377 | 378 | 379 | def dispatch_client_gen( 380 | cfg: DictConfig, 381 | saved_state: tuple[Parameters | None, ServerRNG, History], 382 | working_dir: Path, 383 | data_structure: DataStructure, 384 | train_structure: TrainStructure, 385 | config_structure: ConfigStructure, 386 | **kwargs: Any, 387 | ) -> ( 388 | tuple[ 389 | ClientTypeGen, 390 | type[VirtualClientEngineActor], 391 | dict[str, Any] | None, 392 | ] 393 | | None 394 | ): 395 | """Dispatch the get_client_generator function based on the hydra config. 396 | 397 | Parameters 398 | ---------- 399 | cfg : DictConfig 400 | The configuration for the get_client_generators function. 401 | Loaded dynamically from the config file. 402 | 403 | Returns 404 | ------- 405 | tuple[ClientTypeGen, type[VirtualClientEngineActor], dict[str, Any]] | None 406 | The get_client_generator function and the actor type. 407 | Together with actor kwargs. 408 | """ 409 | client_gen: str | None = cfg.get("task", None).get("client_gen", None) 410 | 411 | if client_gen is None: 412 | return None 413 | 414 | if client_gen.upper() == "DEFAULT": 415 | return ( 416 | get_client_generator, 417 | cast(type[VirtualClientEngineActor], DefaultActor), 418 | None, 419 | ) 420 | 421 | return None 422 | -------------------------------------------------------------------------------- /project/task/mnist_classification/dataset_preparation.py: -------------------------------------------------------------------------------- 1 | """Functions for MNIST download and processing.""" 2 | 3 | import logging 4 | from collections.abc import Sequence, Sized 5 | from pathlib import Path 6 | from typing import cast 7 | 8 | import hydra 9 | import numpy as np 10 | import torch 11 | from flwr.common.logger import log 12 | from omegaconf import DictConfig, OmegaConf 13 | from torch.utils.data import ConcatDataset, Subset, random_split 14 | from torchvision import transforms 15 | from torchvision.datasets import MNIST 16 | 17 | 18 | def _download_data( 19 | dataset_dir: Path, 20 | ) -> tuple[MNIST, MNIST]: 21 | """Download (if necessary) and return the MNIST dataset. 22 | 23 | Returns 24 | ------- 25 | Tuple[MNIST, MNIST] 26 | The dataset for training and the dataset for testing MNIST. 27 | """ 28 | transform = transforms.Compose( 29 | [ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.1307,), (0.3081,)), 32 | ], 33 | ) 34 | dataset_dir.mkdir(parents=True, exist_ok=True) 35 | 36 | trainset = MNIST( 37 | str(dataset_dir), 38 | train=True, 39 | download=True, 40 | transform=transform, 41 | ) 42 | testset = MNIST( 43 | str(dataset_dir), 44 | train=False, 45 | download=True, 46 | transform=transform, 47 | ) 48 | return trainset, testset 49 | 50 | 51 | # pylint: disable=too-many-locals 52 | def _partition_data( 53 | trainset: MNIST, 54 | testset: MNIST, 55 | num_clients: int, 56 | seed: int, 57 | iid: bool, 58 | power_law: bool, 59 | balance: bool, 60 | ) -> tuple[list[Subset] | list[ConcatDataset], MNIST]: 61 | """Split training set into iid or non-iid partitions to simulate the federated setting. 62 | 63 |
64 | 65 | Parameters 66 | ---------- 67 | num_clients : int 68 | The number of clients that hold a part of the data 69 | iid : bool 70 | Whether the data should be independent and identically distributed between 71 | the clients, or if the data should first be sorted by labels and distributed 72 | in chunks to each client (used to test convergence in a worst-case 73 | scenario), by default False 74 | power_law: bool 75 | Whether to follow a power-law distribution when assigning number of samples 76 | for each client, defaults to True 77 | balance : bool 78 | Whether the dataset should contain an equal number of samples in each class, 79 | by default False 80 | seed : int 81 | Used to set a fixed seed to replicate experiments, by default 42 82 | 83 | Returns 84 | ------- 85 | Tuple[List[Subset] | List[ConcatDataset], MNIST] 86 | A list of datasets, one for each client, and a single dataset to be used for 87 | testing the model. 88 | """ 89 | if balance: 90 | trainset = _balance_classes(trainset, seed) 91 | 92 | partition_size = int( 93 | len(cast(Sized, trainset)) / num_clients, 94 | ) 95 | lengths = [partition_size] * num_clients 96 | 97 | if iid: 98 | datasets = random_split( 99 | trainset, 100 | lengths, 101 | torch.Generator().manual_seed(seed), 102 | ) 103 | elif power_law: 104 | trainset_sorted = _sort_by_class(trainset) 105 | datasets = _power_law_split( 106 | trainset_sorted, 107 | num_partitions=num_clients, 108 | num_labels_per_partition=2, 109 | min_data_per_partition=10, 110 | mean=0.0, 111 | sigma=2.0, 112 | ) 113 | else: 114 | shard_size = int(partition_size / 2) 115 | idxs = trainset.targets.argsort() 116 | sorted_data = Subset( 117 | trainset, 118 | cast(Sequence[int], idxs), 119 | ) 120 | tmp = [] 121 | for idx in range(num_clients * 2): 122 | tmp.append( 123 | Subset( 124 | sorted_data, 125 | cast( 126 | Sequence[int], 127 | np.arange( 128 | shard_size * idx, 129 | shard_size * (idx + 1), 130 | ), 131 | ), 132 | ), 133 | ) 134 | idxs_list = torch.randperm( 135 | num_clients * 2, 136 | generator=torch.Generator().manual_seed(seed), 137 | ) 138 | datasets = [ 139 | ConcatDataset( 140 | ( 141 | tmp[idxs_list[2 * i]], 142 | tmp[idxs_list[2 * i + 1]], 143 | ), 144 | ) 145 | for i in range(num_clients) 146 | ] 147 | 148 | return datasets, testset 149 | 150 | 151 | def _balance_classes( 152 | trainset: MNIST, 153 | seed: int, 154 | ) -> MNIST: 155 | """Balance the classes of the trainset. 156 | 157 | Trims the dataset so each class contains as many elements as the 158 | class that contained the least elements. 159 | 160 | Parameters 161 | ---------- 162 | trainset : MNIST 163 | The training dataset that needs to be balanced. 164 | seed : int 165 | Used to set a fixed seed to replicate experiments, by default 42. 166 | 167 | Returns 168 | ------- 169 | MNIST 170 | The balanced training dataset.
171 | """ 172 | class_counts = np.bincount(trainset.targets) 173 | smallest = np.min(class_counts) 174 | idxs = trainset.targets.argsort() 175 | tmp = [ 176 | Subset( 177 | trainset, 178 | cast(Sequence[int], idxs[: int(smallest)]), 179 | ), 180 | ] 181 | tmp_targets = [trainset.targets[idxs[: int(smallest)]]] 182 | for count in np.cumsum(class_counts): 183 | tmp.append( 184 | Subset( 185 | trainset, 186 | cast( 187 | Sequence[int], 188 | idxs[int(count) : int(count + smallest)], 189 | ), 190 | ), 191 | ) 192 | tmp_targets.append( 193 | trainset.targets[idxs[int(count) : int(count + smallest)]], 194 | ) 195 | unshuffled = ConcatDataset(tmp) 196 | unshuffled_targets = torch.cat(tmp_targets) 197 | shuffled_idxs = torch.randperm( 198 | len(unshuffled), 199 | generator=torch.Generator().manual_seed(seed), 200 | ) 201 | shuffled = cast( 202 | MNIST, 203 | Subset( 204 | unshuffled, 205 | cast(Sequence[int], shuffled_idxs), 206 | ), 207 | ) 208 | shuffled.targets = unshuffled_targets[shuffled_idxs] 209 | 210 | return shuffled 211 | 212 | 213 | def _sort_by_class( 214 | trainset: MNIST, 215 | ) -> MNIST: 216 | """Sort dataset by class/label. 217 | 218 | Parameters 219 | ---------- 220 | trainset : MNIST 221 | The training dataset that needs to be sorted. 222 | 223 | Returns 224 | ------- 225 | MNIST 226 | The sorted training dataset. 227 | """ 228 | class_counts = np.bincount(trainset.targets) 229 | idxs = trainset.targets.argsort() # sort targets in ascending order 230 | 231 | tmp = [] # accumulates one subset per class 232 | tmp_targets = [] # same for targets 233 | 234 | start = 0 235 | for count in class_counts: 236 | tmp.append( 237 | Subset( 238 | trainset, 239 | cast( 240 | Sequence[int], 241 | idxs[start : start + int(count)], 242 | ), 243 | ), 244 | ) # add the subset for this class 245 | tmp_targets.append( 246 | trainset.targets[idxs[start : start + int(count)]], 247 | ) 248 | start += int(count) 249 | sorted_dataset = cast( 250 | MNIST, 251 | ConcatDataset(tmp), 252 | ) # concat dataset 253 | sorted_dataset.targets = torch.cat( 254 | tmp_targets, 255 | ) # concat targets 256 | return sorted_dataset 257 | 258 | 259 | # pylint: disable=too-many-locals, too-many-arguments 260 | def _power_law_split( 261 | sorted_trainset: MNIST, 262 | num_partitions: int, 263 | num_labels_per_partition: int = 2, 264 | min_data_per_partition: int = 10, 265 | mean: float = 0.0, 266 | sigma: float = 2.0, 267 | ) -> list[Subset]: 268 | """Partition the dataset following a power-law distribution. 269 | 270 | It follows the implementation of Li et al. 2020: https://arxiv.org/abs/1812.06127 with default 271 | values set accordingly. 272 | 273 | Parameters 274 | ---------- 275 | sorted_trainset : MNIST 276 | The training dataset sorted by label/class. 277 | num_partitions: int 278 | Number of partitions to create 279 | num_labels_per_partition: int 280 | Number of labels to have in each dataset partition. For 281 | example if set to two, this means all training examples in 282 | a given partition will belong to the same two classes, default 2 283 | min_data_per_partition: int 284 | Minimum number of datapoints included in each partition, default 10 285 | mean: float 286 | Mean value for LogNormal distribution to construct power-law, default 0.0 287 | sigma: float 288 | Sigma value for LogNormal distribution to construct power-law, default 2.0 289 | 290 | Returns 291 | ------- 292 | List[Subset] 293 | The partitioned training dataset.
294 | """ 295 | targets = sorted_trainset.targets 296 | full_idx = list(range(len(targets))) 297 | 298 | class_counts = np.bincount(sorted_trainset.targets) 299 | labels_cs = np.cumsum(class_counts) 300 | labels_cs = [0] + labels_cs[:-1].tolist() 301 | 302 | partitions_idx: list[list[int]] = [] 303 | num_classes = len(np.bincount(targets)) 304 | hist = np.zeros(num_classes, dtype=np.int32) 305 | 306 | # assign min_data_per_partition 307 | min_data_per_class = int( 308 | min_data_per_partition / num_labels_per_partition, 309 | ) 310 | for u_id in range(num_partitions): 311 | partitions_idx.append([]) 312 | for cls_idx in range(num_labels_per_partition): 313 | # label for the u_id-th client 314 | cls = (u_id + cls_idx) % num_classes 315 | # record minimum data 316 | indices = list( 317 | full_idx[ 318 | labels_cs[cls] 319 | + hist[cls] : labels_cs[cls] 320 | + hist[cls] 321 | + min_data_per_class 322 | ], 323 | ) 324 | partitions_idx[-1].extend(indices) 325 | hist[cls] += min_data_per_class 326 | 327 | # add remaining images following power-law 328 | probs = np.random.lognormal( 329 | mean, 330 | sigma, 331 | ( 332 | num_classes, 333 | int(num_partitions / num_classes), 334 | num_labels_per_partition, 335 | ), 336 | ) 337 | remaining_per_class = class_counts - hist 338 | # obtain how many samples each partition should be assigned for each of the 339 | # labels it contains 340 | # pylint: disable=too-many-function-args 341 | probs = ( 342 | remaining_per_class.reshape(-1, 1, 1) 343 | * probs 344 | / np.sum(probs, (1, 2), keepdims=True) 345 | ) 346 | 347 | for u_id in range(num_partitions): 348 | for cls_idx in range(num_labels_per_partition): 349 | cls = (u_id + cls_idx) % num_classes 350 | count = int( 351 | probs[cls, u_id // num_classes, cls_idx], 352 | ) 353 | 354 | # add count of specific class to partition 355 | indices = full_idx[ 356 | labels_cs[cls] + hist[cls] : labels_cs[cls] + hist[cls] + count 357 | ] 358 | partitions_idx[u_id].extend(indices) 359 | hist[cls] += count 360 | 361 | # construct partition subsets 362 | return [Subset(sorted_trainset, p) for p in partitions_idx] 363 | 364 | 365 | @hydra.main( 366 | config_path="../../conf", 367 | config_name="mnist", 368 | version_base=None, 369 | ) 370 | def download_and_preprocess(cfg: DictConfig) -> None: 371 | """Download and preprocess the dataset. 372 | 373 | Please include here all the download and preprocessing logic. 374 | Please use the Hydra config style as much as possible, especially 375 | for parts that can be customized (e.g., how data is partitioned). 376 | 377 | Parameters 378 | ---------- 379 | cfg : DictConfig 380 | An omegaconf object that stores the hydra config. 381 | """ 382 | # 1. print parsed config 383 | log(logging.INFO, OmegaConf.to_yaml(cfg)) 384 | 385 | # Download the dataset 386 | trainset, testset = _download_data( 387 | Path(cfg.dataset.dataset_dir), 388 | ) 389 | 390 | # Partition the dataset 391 | # ideally, the fed_test_set can be composed in three ways: 392 | # 1. fed_test_set = centralized test set like MNIST 393 | # 2. fed_test_set = concatenation of all test sets of all clients 394 | # 3. fed_test_set = test sets of reserved unseen clients 395 | client_datasets, fed_test_set = _partition_data( 396 | trainset, 397 | testset, 398 | cfg.dataset.num_clients, 399 | cfg.dataset.seed, 400 | cfg.dataset.iid, 401 | cfg.dataset.power_law, 402 | cfg.dataset.balance, 403 | ) 404 | 405 | # 2.
Save the datasets 406 | # unnecessary for this small dataset, but useful for large datasets 407 | partition_dir = Path(cfg.dataset.partition_dir) 408 | partition_dir.mkdir(parents=True, exist_ok=True) 409 | 410 | # Save the centralized test set 411 | # a centralized training set would also be possible 412 | # but is not used here 413 | torch.save(fed_test_set, partition_dir / "test.pt") 414 | 415 | # Save the client datasets 416 | for idx, client_dataset in enumerate(client_datasets): 417 | client_dir = partition_dir / f"client_{idx}" 418 | client_dir.mkdir(parents=True, exist_ok=True) 419 | 420 | len_val = int( 421 | len(client_dataset) / (1 / cfg.dataset.val_ratio), 422 | ) 423 | lengths = [len(client_dataset) - len_val, len_val] 424 | ds_train, ds_val = random_split( 425 | client_dataset, 426 | lengths, 427 | torch.Generator().manual_seed(cfg.dataset.seed), 428 | ) 429 | # Alternative would have been to create train/test split 430 | # when the dataloader is instantiated 431 | torch.save(ds_train, client_dir / "train.pt") 432 | torch.save(ds_val, client_dir / "test.pt") 433 | 434 | 435 | if __name__ == "__main__": 436 | download_and_preprocess() 437 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # [CaMLSys](https://mlsys.cst.cam.ac.uk/) Federated Learning Research Template using [Flower](https://github.com/adap/flower), [Hydra](https://github.com/facebookresearch/hydra), and [Wandb](https://wandb.ai/site) 3 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/camlsys/fl-project-template/main.svg)](https://results.pre-commit.ci/latest/github/camlsys/fl-project-template/main) 4 | 5 | > Note: If you use this template in your work, please reference it and cite the [Flower](https://arxiv.org/abs/2007.14390) paper. 6 | 7 | > :warning: This ``README`` describes how to use the template as-is after installing it with the default `setup.sh` script on a machine running Ubuntu `22.04`. Please follow the instructions in `EXTENDED_README.md` for details on more complex environment setups and how to extend this template. 8 | 9 | ## About this template 10 | 11 | Federated Learning (FL) is a privacy-preserving machine learning paradigm that allows training models directly on local client data using local client resources. This template standardizes the FL research workflow at the [Cambridge ML Systems](https://mlsys.cst.cam.ac.uk/) based on three frameworks chosen for their flexibility and ease of use: 12 | - [Flower](https://github.com/adap/flower): The FL framework developed by [Flower Labs](https://flower.dev/) with contributions from [CaMLSys](https://mlsys.cst.cam.ac.uk/) members. 13 | - [Hydra](https://github.com/facebookresearch/hydra): The experiment-management framework developed at Meta, which automatically handles experimental configuration for Python. 14 | - [Wandb](https://wandb.ai/site): The MLOps platform developed for handling results storage, experiment tracking, reproducibility, and visualization. 15 | 16 | While these tools can be combined in an ad-hoc manner, this template intends to provide a unified and opinionated structure for achieving this while providing functionality that would not be easy to construct from scratch. 17 | 18 | ### What this template does: 19 | - Automatically handles client configuration for Flower in an opinionated manner using the [PyTorch](https://github.com/pytorch/pytorch) library.
This is meant to reduce the task of FL simulation to the mere implementation of standard ML tasks combined with minimal configuration work. Specifically, clients are treated uniformly except for their data, model, and configuration. 20 | - A user only needs to provide: 21 | - A means of generating a model (e.g., a function which returns a PyTorch model) based on a received configuration (e.g., a Dict) 22 | - A means of constructing train and test dataloaders 23 | - A means of offering a configuration to these components 24 | - All data loading and model training are delayed as much as possible to facilitate creating many clients and keeping them in memory with the smallest footprint possible. 25 | - Metric collection and aggregation require no additional implementation. 26 | - Automatically handles logging, saving, and checkpointing, which integrate natively and seamlessly with Wandb and Hydra. This enables sequential re-launches of the same job on clusters using time-limited schedulers. 27 | - Provides deterministic seeded client selection while taking into account the current checkpoint. 28 | - Provides a static means of selecting which ML task to run using Hydra's config system without the drawbacks of the untyped mechanism provided by Hydra. 29 | - By default, it enforces good coding standards by using isort, black, docformatter, ruff, and mypy integrated with [pre-commit](https://pre-commit.com/). [Pydantic](https://docs.pydantic.dev/latest/) is also used to validate configuration data for generating models, creating dataloaders, training clients, etc. 30 | 31 | ### What this template does not do: 32 | - Provide off-the-shelf implementations of FL algorithms, ML tasks, datasets, or models beyond the MNIST example. For such functionality, please refer to the original [Flower](https://github.com/adap/flower) and [PyTorch](https://github.com/pytorch/pytorch) projects. 33 | - Provide a means of running experiments on clusters, as this depends on the cluster configuration. 34 | 35 | ## Setup 36 | 37 | For systems running Ubuntu with CUDA 12, the basic setup has been simplified to one `setup.sh` script using [poetry](https://python-poetry.org/), [pyenv](https://github.com/pyenv/pyenv) and [pre-commit](https://pre-commit.com/). It only requires limited user input regarding the installation location of ``pyenv`` and ``poetry``, and will install the specified python version. All dependencies are placed in the local ``.venv`` directory. 38 | 39 | If you have a different system, you will need to modify `pyproject.toml` to include a link to the appropriate torch wheel and to replicate the operations of `setup.sh` for your system using the appropriate commands. 40 | 41 | 42 | By default, pre-commit only runs hooks on files staged for commit. If you wish to run all the pre-commit hooks without committing or pushing, use: 43 | ```bash 44 | poetry run pre-commit run --all-files --hook-stage push 45 | ``` 46 | 47 | 48 | ## Using the Template 49 | 50 | 51 | 52 | > Note: these instructions rely on the MNIST task and assume specific dataset partitioning, model creation and dataloader instantiation procedures. We recommend following a similar structure in your own experiments. Please refer to the [Flower](https://flower.dev/docs/baselines/index.html) baselines for more examples. 53 | 54 | Install the template using the setup.sh script: 55 | ```bash 56 | ./setup.sh 57 | ``` 58 | 59 | 60 | 61 | If ``poetry``, ``pyenv``, and/or the correct python version are installed, they will not be installed again.
If not installed, you must provide paths to the desired install locations. If running on a cluster, this would be the location of the shared file system. You can now run ```poetry shell``` to activate the python env in your shell. 62 | > :warning: To check that everything is installed correctly, run the `default` task from the root ``fl-project-template`` directory, not from the ``fl-project-template/project`` directory. 63 | 64 | ```bash 65 | poetry run python -m project.main --config-name=base 66 | ``` 67 | 68 | 69 | If you have a cluster which may run multiple Ray simulator instances, you will need to launch the server separately. 70 | 71 | The default task should have created a folder in fl-project-template/outputs. This folder contains the results of the experiment. To log your experiments to wandb, log into wandb and then enable it via the command: 72 | 73 | ```bash 74 | poetry run python -m project.main --config-name=base use_wandb=true 75 | ``` 76 | 77 | Now, you can run the MNIST example by following these instructions: 78 | - Specify a ``dataset_dir`` and ``partition_dir`` in ``conf/dataset/mnist.yaml`` together with the ``num_clients``, the size of a client's validation set ``val_ratio``, and a ``seed`` for partitioning. You can also specify if the partition labels should be ``iid``, follow a ``power_law`` distribution or if the partition should ``balance`` the labels across clients. 79 | - Download and partition the dataset by running the following command from the root dir: 80 | - ```bash 81 | poetry run python -m project.task.mnist_classification.dataset_preparation 82 | ``` 83 | - Specify which ``model_and_data``, ``train_structure``, and ``fit_config`` or ``eval_config`` to use in the ``conf/task/mnist.yaml`` file. The defaults are a CNN, a simple classification training/testing loop, and configs controlling ``batch_size``, local client ``epochs``, and the ``learning_rate``. You can also specify which metrics to aggregate during fit/eval. 84 | - Run the experiment using the following command from the root dir: 85 | - ```bash 86 | poetry run python -m project.main --config-name=mnist 87 | ``` 88 | 89 | Once a complete experiment has run, you can continue it for a specified number of rounds by running the following command from the root dir, setting the output directory to that of the previous run. 90 | - ```bash 91 | poetry run python -m project.main --config-name=mnist hydra.run.dir= 92 | ``` 93 | These are all the basic steps required to run a simple experiment. 94 | 95 | ## Adding a task 96 | 97 | Adding a task requires you to add a new task in the ```project.task``` module and to make changes to the ```project.dispatch``` module. Each ```project.task``` module has a specific structure: 98 | - ``task``: The ML task implementation includes the model, data loading, and training/testing. Almost all user changes should be made here. Tasks will typically include modules for the following: 99 | - ``dataset_preparation``: Hydra entry point which handles downloading the dataset and partitioning it. The partition can be generated on the fly during FL execution or saved into a partition directory with one folder per client containing train and test files---with the server test set being in the root directory of the partition dir. This needs to be executed prior to running the main experiment. It relies on the dataset part of the Hydra config. 100 | - ``dataset``: Offers functionality to create the dataloaders for either the client fit/eval or for the centralized server evaluation.
104 | 105 | Specifying a new task requires implementing the above functionality, together with functions/closures which generate and configure these components in a manner that obeys the interfaces of previous tasks, as specified in ```project.types```. 106 | 107 | After implementing the task, dynamically starting it via ```hydra``` requires changing two modules (a sketch follows the list): 108 | - The ```project.<task_name>.dispatch``` module requires three functions: 109 | - ```dispatch_data(cfg)``` is meant to provide a function to generate the model and the dataloaders. By default this is done via the ```conf.task.model_and_data``` string in the config. 110 | - ```dispatch_train(cfg)``` selects the ```train```, ```test``` and federated test functions. By default this is dispatched on the ```conf.task.train_structure``` string in the config. 111 | - ```dispatch_config(cfg)``` selects the configs used during fit and eval; you will likely not have to change this, as the default task provides a sensible version. 112 | - The ```project.dispatch``` module requires you to add the task-specific ```dispatch_data```, ```dispatch_train``` and ```dispatch_config``` functions from the ```project.<task_name>.dispatch``` module to the list of possible tasks that can match the config. The statically-declared function order determines which task is selected if multiple ones match the config. 113 |
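As a rough illustration, a task-specific dispatcher might look like the sketch below. The module path, the matched strings, and the convention that an unmatched dispatcher returns `None` are assumptions for this sketch; consult the ``default`` and ``mnist_classification`` tasks for the authoritative pattern.

```python
# Hypothetical project/task/my_task/dispatch.py; names are illustrative.
from omegaconf import DictConfig


def dispatch_data(cfg: DictConfig):
    """Provide the model/dataloader generators if this task matches the config."""
    if cfg.task.model_and_data == "MyTaskCNN":  # assumed string value
        from project.task.my_task.models import get_net  # hypothetical import
        from project.task.my_task.dataset import get_dataloader  # hypothetical import

        return get_net, get_dataloader
    return None  # defer to the next dispatcher registered in project.dispatch


def dispatch_train(cfg: DictConfig):
    """Provide the train/test functions if this task matches the config."""
    if cfg.task.train_structure == "MyTaskTrain":  # assumed string value
        from project.task.my_task.train_test import test, train  # hypothetical import

        return train, test, None  # no federated test function in this sketch
    return None
```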
114 | You have now implemented an entirely new FL task without touching any of the FL-specific code. 115 | 116 | # How to use the template for open-source research 117 | 118 | This section explains how to structure research projects containing both public and private components such that previously private work can be effortlessly open-sourced after publication. 119 | 120 | 1. Fork the code template into your own GitHub account; do not click “Use this template”, as that would prevent you from opening PRs against the original repo. 121 | 2. Create a private repository [mirroring](https://docs.github.com/en/repositories/creating-and-managing-repositories/duplicating-a-repository) the code template 122 | 1. Create a new private repository using the GitHub UI, called something like `private-fl-projects` 123 | 2. Clone the public template 124 | 1. `git clone --bare https://github.com/camlsys/fl-project-template.git` 125 | 2. `cd fl-project-template.git` 126 | 3. `git push --mirror https://github.com/your-name/private-fl-projects.git` 127 | 4. `cd ..` 128 | 5. `rm -rf fl-project-template.git` 129 | 3. After you have done these steps, you never have to touch the public fork directly; all you need to do is: 130 | 1. Go to the `private-fl-projects` repo 131 | 2. `git remote add public git@github.com:your-name/fl-project-template.git` 132 | 3. Now, any push will by default go to the origin (i.e., the private repo); if you want to pull from or push to the public one, you can do: 133 | 1. `git pull public main` 134 | 2. `git push public main` 135 | 3. You can then PR from the public fork to the original repo and bring in any contributions you wish 136 | 4. You can also officially publish your code by pushing a private branch to your public fork; this branch does not have to be synced to the template but may be of use if the conference requires an artefact for reproducibility 137 | 138 | ## Using checkpoints 139 | 140 | By default, the entire template is synchronized across server rounds, and the model parameters, `RNG` state, `Wandb` run, metric `History`, config files and logs are all checkpointed, either every `save_frequency` rounds or once at the end of training when the process exits. If Wandb is used, any restarted run continues at the exact same link in Wandb with no cumbersome tracking necessary. 141 | 142 | To use the checkpoint system, all you have to do is specify the `hydra.run.dir` to be a previous execution directory rather than the default timestamped output directory. If you wish to restore a specific round rather than the most recent one, modify the `server_round` in the `fed` config. 143 | 144 | ## Reproducibility 145 | One of the primary functionalities of this template is to allow for easily reproducible FL checkpointing. It achieves this by controlling client sampling and the seeding of the server and client `RNG`s, and by saving the RNG states of `random`, `numpy`, and `torch`. The server and every client are provided with an isolated RNG, making them usable in a multithreaded context where the global generators may be accessed unpredictably. 146 | 147 | The `RNG` states of all of the relevant packages and generators are automatically saved and synchronized to the round, allowing for reproducible client samples and client execution in the same round. Every relevant piece of client functionality also receives the isolated `RNG` state and can use it to guarantee reproducibility (e.g., by seeding the `PyTorch` dataloader, as in the sketch below). 148 | 149 | ## Template Structure 150 | 151 | The template uses poetry with the ``project`` name for the top-level package. All imports are made from this package, and no relative imports are allowed. The structure is as follows: 152 | 153 | ``` 154 | project 155 | ├── client 156 | ├── conf 157 | ├── dispatch 158 | ├── fed 159 | ├── main.py 160 | ├── task 161 | ├── types 162 | └── utils 163 | ``` 164 | The main packages of concern are: 165 | - ``client``: Contains the client class; requires no changes. 166 | - ``conf``: Contains the Hydra configuration files specifying experiment behavior and the chosen ML task. 167 | - ``dispatch``: Handles mapping a Hydra configuration to the ML task. 168 | - ``fed``: Contains the federated learning functionality such as client sampling and model parameter saving. Should require little to no modification. 169 | - ``main``: The Hydra entry point. 170 | - ``task``: Described above. 171 | 172 | Two tasks are already implemented: 173 | - ``default``: A task providing generic functionality that may be reused across tasks. It requires no data and provides a minimal example of what a task must provide for the FL training to execute. 174 | - ``mnist_classification``: Uses the simple MNIST dataset with either a CNN or logistic regression model. 175 | 176 | > :warning: Prefer changing only the task module when possible. 177 | 178 |
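To make the isolated `RNG` handling described under Reproducibility concrete, here is a minimal sketch of seeding a dataloader from a per-client seed. The function name and the seed-passing convention are assumptions for this sketch, not the template's actual API (see ```project.types.common``` for the real `IsolatedRNG` type).

```python
# Illustrative only: reproducible shuffling via an isolated torch generator.
import torch
from torch.utils.data import DataLoader, Dataset


def make_client_loader(dataset: Dataset, seed: int) -> DataLoader:
    """Build a dataloader whose shuffling is reproducible per client."""
    # A dedicated generator keeps shuffling deterministic even if other
    # threads consume the global torch/numpy/random state unpredictably.
    generator = torch.Generator().manual_seed(seed)
    return DataLoader(dataset, batch_size=32, shuffle=True, generator=generator)
```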
179 | ## Enabling Pre-commit CI 180 | 181 | To enable Continuous Integration of your project via pre-commit, all you need to do is allow pre-commit for a given repo from the [GitHub Marketplace](https://github.com/marketplace/pre-commit-ci/plan/MDIyOk1hcmtldHBsYWNlTGlzdGluZ1BsYW42MTI2#plan-6126). You should be aware that this is free only for public open-source repositories. 182 | 183 | 184 | -------------------------------------------------------------------------------- /project/main.py: -------------------------------------------------------------------------------- 1 | """Create and connect the building blocks for your experiments; start the simulation. 2 | 3 | It includes processing the dataset, instantiating the strategy, specifying how the 4 | global model will be evaluated, etc. In the end, this script saves the results. 5 | """ 6 | 7 | import copy 8 | import logging 9 | import os 10 | import subprocess 11 | import sys 12 | from pathlib import Path 13 | from typing import cast 14 | import uuid 15 | 16 | import flwr as fl 17 | import hydra 18 | import wandb 19 | from wandb.sdk.wandb_run import Run 20 | from flwr.common.logger import log 21 | from hydra.core.hydra_config import HydraConfig 22 | from hydra.utils import instantiate 23 | from omegaconf import DictConfig, OmegaConf 24 | 25 | from project.dispatch.dispatch import ( 26 | dispatch_config, 27 | dispatch_data, 28 | dispatch_get_client_generator, 29 | dispatch_get_client_manager, 30 | dispatch_server, 31 | dispatch_train, 32 | ) 33 | from project.fed.utils.utils import ( 34 | get_save_history_to_file, 35 | get_state, 36 | get_save_parameters_to_file, 37 | get_save_rng_to_file, 38 | get_weighted_avg_metrics_agg_fn, 39 | test_client, 40 | ) 41 | from project.types.common import ClientGen, FedEvalFN, Folders 42 | from project.utils.utils import ( 43 | FileSystemManager, 44 | RayContextManager, 45 | load_wandb_run_details, 46 | save_wandb_run_details, 47 | wandb_init, 48 | ) 49 | 50 | # Make debugging easier when using Hydra + Ray 51 | os.environ["HYDRA_FULL_ERROR"] = "1" 52 | os.environ["OC_CAUSE"] = "1" 53 | 54 | 55 | @hydra.main( 56 | config_path="conf", 57 | config_name="base", 58 | version_base=None, 59 | ) 60 | def main(cfg: DictConfig) -> None: 61 | """Run the baseline. 62 | 63 | Parameters 64 | ---------- 65 | cfg : DictConfig 66 | An omegaconf object that stores the hydra config.
67 | """ 68 | # Print parsed config 69 | log(logging.INFO, OmegaConf.to_yaml(cfg)) 70 | 71 | wandb_config = OmegaConf.to_container( 72 | cfg, 73 | resolve=True, 74 | throw_on_missing=True, 75 | ) 76 | 77 | # Obtain the output dir from hydra 78 | original_hydra_dir = Path( 79 | hydra.utils.to_absolute_path( 80 | HydraConfig.get().runtime.output_dir, 81 | ), 82 | ) 83 | 84 | output_directory = original_hydra_dir 85 | 86 | # The directory to save data to 87 | results_dir = output_directory / Folders.RESULTS 88 | results_dir.mkdir(parents=True, exist_ok=True) 89 | 90 | # Where to save files to and from 91 | if cfg.working_dir is not None: 92 | # Pre-defined directory 93 | working_dir = Path(cfg.working_dir) 94 | else: 95 | # Default directory 96 | working_dir = output_directory / Folders.WORKING 97 | 98 | working_dir.mkdir(parents=True, exist_ok=True) 99 | 100 | # Restore wandb runs automatically 101 | wandb_id = None 102 | if cfg.use_wandb and cfg.wandb_resume: 103 | if cfg.wandb_id is not None: 104 | wandb_id = cfg.wandb_id 105 | elif ( 106 | saved_wandb_details := load_wandb_run_details(results_dir / Folders.WANDB) 107 | ) is not None: 108 | wandb_id = saved_wandb_details.wandb_id 109 | 110 | # Wandb context manager 111 | # controls if wandb is initialized or not 112 | # if not it returns a dummy run 113 | with wandb_init( 114 | cfg.use_wandb, 115 | **cfg.wandb.setup, 116 | settings=wandb.Settings(start_method="thread"), 117 | config=wandb_config, 118 | resume="must" if cfg.wandb_resume and wandb_id is not None else "allow", 119 | id=wandb_id if wandb_id is not None else uuid.uuid4().hex, 120 | ) as run: 121 | if cfg.use_wandb: 122 | save_wandb_run_details(cast(Run, run), working_dir / Folders.WANDB) 123 | log( 124 | logging.INFO, 125 | "Wandb run initialized with %s", 126 | cfg.use_wandb, 127 | ) 128 | 129 | # Context managers for saving and cleaning up files 130 | # from the working directory 131 | # at the start/end of the simulation 132 | # The RayContextManager deletes the ray session folder 133 | with ( 134 | FileSystemManager( 135 | working_dir=working_dir, 136 | results_dir=results_dir, 137 | load_parameters_from=cfg.fed.parameters_folder, 138 | to_clean_once=cfg.to_clean_once, 139 | to_save_once=cfg.to_save_once, 140 | to_restore=cfg.to_restore, 141 | original_hydra_dir=original_hydra_dir, 142 | starting_round=cfg.fed.server_round, 143 | file_limit=int(cfg.file_limit), 144 | ) as fs_manager, 145 | RayContextManager() as _ray_manager, 146 | ): 147 | # Obtain the net generator, dataloader and fed_dataloader 148 | # Change the cfg.task.model_and_data str to change functionality 149 | ( 150 | net_generator, 151 | initial_parameter_gen, 152 | client_dataloader_gen, 153 | fed_dataloader_gen, 154 | init_working_dir, 155 | ) = data_structure = dispatch_data( 156 | cfg, 157 | ) 158 | # The folder starts either empty or only with restored files 159 | # as specified in the config 160 | if init_working_dir is not None: 161 | init_working_dir(working_dir, results_dir) 162 | 163 | # Parameters/rng/history state for the strategy 164 | # Uses the path to the saved initial parameters and state 165 | # If none are available, new ones will be generated 166 | 167 | # Use the results_dir by default 168 | # otherwise use the specified folder 169 | 170 | saved_state = get_state( 171 | net_generator, 172 | initial_parameter_gen, 173 | config=cast( 174 | dict, 175 | OmegaConf.to_container( 176 | cfg.task.net_config_initial_parameters, 177 | ), 178 | ), 179 | load_parameters_from=( 180 | results_dir / 
Folders.STATE / Folders.PARAMETERS 181 | if cfg.fed.parameters_folder is None 182 | else Path(cfg.fed.parameters_folder) 183 | ), 184 | load_rng_from=( 185 | results_dir / Folders.STATE / Folders.RNG 186 | if cfg.fed.rng_folder is None 187 | else Path(cfg.fed.rng_folder) 188 | ), 189 | load_history_from=( 190 | results_dir / Folders.STATE / Folders.HISTORIES 191 | if cfg.fed.history_folder is None 192 | else Path(cfg.fed.history_folder) 193 | ), 194 | seed=cfg.fed.seed, 195 | server_round=fs_manager.server_round, 196 | use_wandb=cfg.use_wandb, 197 | hydra_config=cfg, 198 | ) 199 | initial_parameters, server_rng, history = saved_state 200 | 201 | server_isolated_rng, client_cid_rng, client_seed_rng = server_rng 202 | 203 | # Client manager that samples the same clients 204 | # For a given seed+checkpoint combination 205 | client_manager = dispatch_get_client_manager(cfg)( 206 | enable_resampling=cfg.fed.enable_resampling, 207 | client_cid_generator=client_cid_rng, 208 | hydra_config=cfg, 209 | ) 210 | 211 | # Obtain the train/test func and the fed eval func 212 | # Change the cfg.task.train_structure str to change functionality 213 | ( 214 | train_func, 215 | test_func, 216 | get_fed_eval_fn, 217 | ) = train_structure = dispatch_train(cfg) 218 | 219 | # Obtain the on_fit config and on_eval config 220 | # generation functions 221 | # These depend on the cfg.task.fit_config 222 | # and cfg.task.eval_config dictionaries by default 223 | ( 224 | on_fit_config_fn, 225 | on_evaluate_config_fn, 226 | ) = config_structure = dispatch_config(cfg) 227 | 228 | get_client_generator, actor_type, actor_kwargs = ( 229 | dispatch_get_client_generator( 230 | cfg, 231 | saved_state=saved_state, 232 | working_dir=working_dir, 233 | data_structure=data_structure, 234 | train_structure=train_structure, 235 | config_structure=config_structure, 236 | ) 237 | ) 238 | 239 | # Build the evaluate function from the given components 240 | # This is the function that is called on the server 241 | # to evaluate the global model 242 | # the cast to Dict is necessary for mypy 243 | # as is the to_container 244 | evaluate_fn: FedEvalFN | None = get_fed_eval_fn( 245 | net_generator, 246 | fed_dataloader_gen, 247 | test_func, 248 | cast( 249 | dict, 250 | OmegaConf.to_container( 251 | cfg.task.fed_test_config, 252 | ), 253 | ), 254 | working_dir, 255 | server_isolated_rng, 256 | copy.deepcopy(cfg), 257 | ) 258 | 259 | # Define your strategy 260 | # pass all relevant arguments 261 | # fraction_fit and fraction_evaluate are ignored 262 | # in favor of using absolute numbers via min_fit_clients 263 | # get_weighted_avg_metrics_agg_fn obeys 264 | # the fit_metrics and evaluate_metrics 265 | # in the cfg.task 266 | strategy = instantiate( 267 | cfg.strategy.init, 268 | fraction_fit=sys.float_info.min, 269 | fraction_evaluate=sys.float_info.min, 270 | min_fit_clients=cfg.fed.num_clients_per_round, 271 | min_evaluate_clients=cfg.fed.num_evaluate_clients_per_round, 272 | min_available_clients=cfg.fed.num_total_clients, 273 | on_fit_config_fn=on_fit_config_fn, 274 | on_evaluate_config_fn=on_evaluate_config_fn, 275 | evaluate_fn=evaluate_fn, 276 | accept_failures=False, 277 | fit_metrics_aggregation_fn=get_weighted_avg_metrics_agg_fn( 278 | cfg.task.fit_metrics, 279 | ), 280 | evaluate_metrics_aggregation_fn=get_weighted_avg_metrics_agg_fn( 281 | cfg.task.evaluate_metrics, 282 | ), 283 | initial_parameters=initial_parameters, 284 | ) 285 | 286 | # Server that handles Wandb and file saving 287 | server = dispatch_server(cfg)( 288 |
client_manager=client_manager, 289 | hydra_config=cfg, 290 | starting_round=fs_manager.server_round, 291 | server_rng=server_rng, 292 | history=history, 293 | strategy=strategy, 294 | save_parameters_to_file=get_save_parameters_to_file( 295 | working_dir / Folders.STATE / Folders.PARAMETERS 296 | if cfg.fed.parameters_folder is None 297 | else Path(cfg.fed.parameters_folder) 298 | ), 299 | save_history_to_file=get_save_history_to_file( 300 | working_dir / Folders.STATE / Folders.HISTORIES 301 | if cfg.fed.history_folder is None 302 | else Path(cfg.fed.history_folder) 303 | ), 304 | save_rng_to_file=get_save_rng_to_file( 305 | working_dir / Folders.STATE / Folders.RNG 306 | if cfg.fed.rng_folder is None 307 | else Path(cfg.fed.rng_folder) 308 | ), 309 | save_files_per_round=fs_manager.get_save_files_every_round( 310 | cfg.to_save_per_round, 311 | cfg.save_frequency, 312 | ), 313 | ) 314 | 315 | # Client generation function for Ray 316 | # Do not change 317 | client_generator: ClientGen = get_client_generator( 318 | working_dir, 319 | net_generator, 320 | client_dataloader_gen, 321 | train_func, 322 | test_func, 323 | client_seed_rng, 324 | cfg, 325 | ) 326 | if initial_parameters is not None: 327 | # Runs fit and eval on either one client or all of them 328 | # Avoids launching ray for debugging purposes 329 | test_client( 330 | test_all_clients=cfg.debug_clients.all, 331 | test_one_client=cfg.debug_clients.one, 332 | client_generator=client_generator, 333 | initial_parameters=initial_parameters, 334 | total_clients=cfg.fed.num_total_clients, 335 | on_fit_config_fn=on_fit_config_fn, 336 | on_evaluate_config_fn=on_evaluate_config_fn, 337 | ) 338 | 339 | # Start Simulation 340 | # The ray_init_args are only necessary 341 | # If multiple ray servers run in parallel 342 | # you should provide them from wherever 343 | # you start your server (e.g., sh script) 344 | # NOTE: `client_resources` accepts fractional 345 | # values for `num_cpus` and `num_gpus` iff 346 | # they're lower than 1.0. 
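# Ray treats these resource values as logical scheduling hints rather than
# hard limits; a fractional `num_gpus` lets several client actors share one GPU.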
347 | fl.simulation.start_simulation( 348 | # NOTE: mypy complains about the type of client_generator 349 | # We must wait for reconciliation from Flower 350 | client_fn=lambda cid: client_generator(cid).to_client(), 351 | num_clients=cfg.fed.num_total_clients, 352 | client_resources={ 353 | "num_cpus": ( 354 | int( 355 | cfg.fed.cpus_per_client, 356 | ) 357 | if cfg.fed.cpus_per_client >= 1 358 | else float( 359 | cfg.fed.cpus_per_client, 360 | ) 361 | ), 362 | "num_gpus": ( 363 | int( 364 | cfg.fed.gpus_per_client, 365 | ) 366 | if cfg.fed.gpus_per_client >= 1 367 | else float( 368 | cfg.fed.gpus_per_client, 369 | ) 370 | ), 371 | }, 372 | server=server, 373 | config=fl.server.ServerConfig( 374 | num_rounds=cfg.fed.num_rounds, 375 | ), 376 | ray_init_args=( 377 | { 378 | "include_dashboard": False, 379 | "address": cfg.ray_address, 380 | "_redis_password": cfg.ray_redis_password, 381 | "_node_ip_address": cfg.ray_node_ip_address, 382 | } 383 | if cfg.ray_address is not None 384 | else {"include_dashboard": False} 385 | ), 386 | actor_type=actor_type, 387 | actor_kwargs=actor_kwargs, 388 | ) 389 | 390 | # Sync the entire results dir to wandb if enabled 391 | # Only once at the end of the simulation 392 | if run is not None: 393 | run.save( 394 | str((results_dir / "*").resolve()), 395 | str((results_dir).resolve()), 396 | "now", 397 | ) 398 | 399 | if cfg.fed.parameters_folder is not None: 400 | run.save( 401 | str((Path(cfg.fed.parameters_folder) / "*").resolve()), 402 | str((Path(cfg.fed.parameters_folder)).resolve()), 403 | "now", 404 | ) 405 | if cfg.fed.history_folder is not None: 406 | run.save( 407 | str((Path(cfg.fed.history_folder) / "*").resolve()), 408 | str((Path(cfg.fed.history_folder)).resolve()), 409 | "now", 410 | ) 411 | if cfg.fed.rng_folder is not None: 412 | run.save( 413 | str((Path(cfg.fed.rng_folder) / "*").resolve()), 414 | str((Path(cfg.fed.rng_folder)).resolve()), 415 | "now", 416 | ) 417 | 418 | # Try to empty the wandb folder of old local runs 419 | log( 420 | logging.INFO, 421 | subprocess.run( 422 | [ 423 | "wandb", 424 | "sync", 425 | "--clean-old-hours", 426 | "24", 427 | ], 428 | capture_output=True, 429 | text=True, 430 | check=True, 431 | ), 432 | ) 433 | 434 | 435 | if __name__ == "__main__": 436 | main() 437 | -------------------------------------------------------------------------------- /project/utils/utils.py: -------------------------------------------------------------------------------- 1 | """Define any utility function. 2 | 3 | Generic utilities. 4 | """ 5 | 6 | import json 7 | import logging 8 | import re 9 | import shutil 10 | from collections.abc import Callable, Iterator 11 | from itertools import chain, islice 12 | from pathlib import Path 13 | from types import TracebackType 14 | from typing import Any, cast 15 | from omegaconf import DictConfig 16 | from pydantic import BaseModel 17 | 18 | import ray 19 | import torch 20 | from flwr.common.logger import log 21 | from project.types.common import Files 22 | import wandb 23 | from wandb.sdk.wandb_run import Run 24 | from wandb.sdk.lib.disabled import RunDisabled 25 | 26 | from project.types.common import Ext, FileCountExceededError, Folders, IsolatedRNG 27 | from flwr.common import NDArrays 28 | 29 | 30 | def obtain_device() -> torch.device: 31 | """Get the device (CPU or GPU) for torch. 32 | 33 | Returns 34 | ------- 35 | torch.device 36 | The device. 
37 | """ 38 | return torch.device( 39 | "cuda:0" if torch.cuda.is_available() else "cpu", 40 | ) 41 | 42 | 43 | def lazy_wrapper(x: Callable) -> Callable[[], Any]: 44 | """Wrap a value in a function that returns the value. 45 | 46 | For easy instantiation through hydra. 47 | 48 | Parameters 49 | ---------- 50 | x : Callable 51 | The value to wrap. 52 | 53 | Returns 54 | ------- 55 | Callable[[], Any] 56 | The wrapped value. 57 | """ 58 | return lambda: x 59 | 60 | 61 | def lazy_config_wrapper( 62 | x: Callable, 63 | ) -> Callable[[dict, IsolatedRNG, DictConfig | None], Any]: 64 | """Wrap a value in a function that returns the value given a config and rng_tuple. 65 | 66 | For easy instantiation through hydra. 67 | 68 | Parameters 69 | ---------- 70 | x : Callable 71 | The value to wrap. 72 | 73 | Returns 74 | ------- 75 | Callable[[dict, IsolatedRNG, DictConfig | None], Any] 76 | The wrapped value. 77 | """ 78 | return lambda _config, _rng_tuple, _hydra_config: x() 79 | 80 | 81 | class NoOpContextManager: 82 | """A context manager that does nothing.""" 83 | 84 | def __enter__(self) -> None: 85 | """Do nothing.""" 86 | return 87 | 88 | def __exit__( 89 | self, 90 | _exc_type: type[BaseException] | None, 91 | _exc_value: BaseException | None, 92 | _traceback: TracebackType | None, 93 | ) -> None: 94 | """Do nothing.""" 95 | 96 | 97 | def wandb_init( 98 | wandb_enabled: bool, 99 | *args: Any, 100 | **kwargs: Any, 101 | ) -> NoOpContextManager | Run | RunDisabled: 102 | """Initialize wandb if enabled. 103 | 104 | Parameters 105 | ---------- 106 | wandb_enabled : bool 107 | Whether wandb is enabled. 108 | *args : Any 109 | The arguments to pass to wandb.init. 110 | **kwargs : Any 111 | The keyword arguments to pass to wandb.init. 112 | 113 | Returns 114 | ------- 115 | NoOpContextManager | Run | RunDisabled 116 | The wandb run if enabled, otherwise a no-op context manager 117 | """ 118 | if wandb_enabled: 119 | run = wandb.init(*args, **kwargs) 120 | if run is not None: 121 | return run 122 | 123 | return NoOpContextManager() 124 | 125 | 126 | class WandbDetails(BaseModel): 127 | """The wandb details.""" 128 | 129 | wandb_id: str 130 | 131 | 132 | def save_wandb_run_details(run: Run, wandb_dir: Path) -> None: 133 | """Save the wandb run details to the output directory. 134 | 135 | Parameters 136 | ---------- 137 | run : Run 138 | The wandb run. 139 | wandb_dir : Path 140 | The output directory. 141 | 142 | Returns 143 | ------- 144 | None 145 | """ 146 | wandb_run_details: dict[str, str] = { 147 | "wandb_id": run.id, 148 | } 149 | 150 | # Check if it conforms to the WandbDetails schema 151 | WandbDetails(**wandb_run_details) 152 | 153 | wandb_dir.mkdir(parents=True, exist_ok=True) 154 | with open( 155 | wandb_dir / f"{Files.WANDB_RUN}.{Ext.WANDB_RUN}", 156 | mode="w", 157 | encoding="utf-8", 158 | ) as f: 159 | json.dump(wandb_run_details, f) 160 | 161 | 162 | def load_wandb_run_details(wandb_dir: Path) -> WandbDetails | None: 163 | """Load the wandb run details from the wandb_dir directory. 164 | 165 | Parameters 166 | ---------- 167 | wandb_dir : Path 168 | The directory containing the saved run details, 169 | as previously written by 170 | `save_wandb_run_details`.
171 | 172 | Returns 173 | ------- 174 | WandbDetails | None 175 | """ 176 | wandb_file = wandb_dir / f"{Files.WANDB_RUN}.{Ext.WANDB_RUN}" 177 | 178 | if not wandb_file.exists(): 179 | return None 180 | 181 | with open( 182 | wandb_file, 183 | encoding="utf-8", 184 | ) as f: 185 | return WandbDetails(**json.load(f)) 186 | 187 | 188 | class RayContextManager: 189 | """A context manager for cleaning up after ray.""" 190 | 191 | def __enter__(self) -> "RayContextManager": 192 | """Initialize the context manager.""" 193 | return self 194 | 195 | def __exit__( 196 | self, 197 | _exc_type: type[BaseException] | None, 198 | _exc_value: BaseException | None, 199 | _traceback: TracebackType | None, 200 | ) -> None: 201 | """Cleanup the files. 202 | 203 | Parameters 204 | ---------- 205 | _exc_type : Any 206 | The exception type. 207 | _exc_value : Any 208 | The exception value. 209 | _traceback : Any 210 | The traceback. 211 | 212 | Returns 213 | ------- 214 | None 215 | """ 216 | if ray.is_initialized(): 217 | temp_dir = Path( 218 | ray.worker._global_node.get_session_dir_path(), 219 | ) 220 | ray.shutdown() 221 | 222 | directory_size = shutil.disk_usage( 223 | temp_dir, 224 | ).used 225 | 226 | shutil.rmtree(temp_dir) 227 | log( 228 | logging.INFO, 229 | f"Cleaned up ray temp session: {temp_dir} with size: {directory_size}", 230 | ) 231 | 232 | 233 | def cleanup(working_dir: Path, to_clean: list[str]) -> None: 234 | """Cleanup the files in the working dir. 235 | 236 | Parameters 237 | ---------- 238 | working_dir : Path 239 | The working directory. 240 | to_clean : List[str] 241 | The tokens to clean. 242 | 243 | Returns 244 | ------- 245 | None 246 | """ 247 | children: list[Path] = [] 248 | for file in working_dir.iterdir(): 249 | if file.is_file(): 250 | for clean_token in to_clean: 251 | if clean_token in file.name and file.exists(): 252 | file.unlink() 253 | break 254 | else: 255 | children.append(file) 256 | 257 | for child in children: 258 | cleanup(child, to_clean) 259 | 260 | 261 | def get_highest_round( 262 | parameters_dir: Path, 263 | file_limit: int, 264 | ) -> int: 265 | """Get the index of the highest round. 266 | 267 | Parameters 268 | ---------- 269 | parameters_dir : Path 270 | The directory containing the saved parameters. 271 | file_limit : int 272 | The maximal number of files to search. 273 | If None, then there is no limit. 274 | 275 | Returns 276 | ------- 277 | int 278 | The index of the highest round. 279 | """ 280 | same_name_files = cast( 281 | Iterator[Path], 282 | islice( 283 | chain( 284 | parameters_dir.glob(f"*{Files.PARAMETERS}_*"), 285 | parameters_dir.glob(f"*/*{Files.PARAMETERS}_*"), 286 | ), 287 | file_limit, 288 | ), 289 | ) 290 | 291 | indices = ( 292 | int(v.group(1)) 293 | for f in same_name_files 294 | if (v := re.search(r"_([0-9]+)", f.stem)) 295 | ) 296 | return max(indices, default=0) 297 | 298 | 299 | def save_files( 300 | working_dir: Path, 301 | output_dir: Path, 302 | to_save: list[str], 303 | server_round: int, 304 | file_limit: int, 305 | top_level: bool = True, 306 | file_cnt: int = 0, 307 | ) -> None: 308 | """Save the files in the working dir. 309 | 310 | Parameters 311 | ---------- 312 | working_dir : Path 313 | The working directory. 314 | output_dir : Path 315 | The output directory.
316 | 317 | Returns 318 | ------- 319 | None 320 | """ 321 | if not top_level: 322 | output_dir = output_dir / working_dir.name 323 | 324 | children: list[Path] = [] 325 | for file in working_dir.iterdir(): 326 | if file.is_file(): 327 | for save_token in to_save: 328 | if save_token in file.name and file.exists(): 329 | # Save the round file 330 | destination_file = ( 331 | output_dir 332 | / file.with_stem( 333 | f"{file.stem}_{server_round}", 334 | ).name 335 | ) 336 | 337 | latest_file = ( 338 | output_dir 339 | / file.with_stem( 340 | f"{file.stem}", 341 | ).name 342 | ) 343 | 344 | destination_file.parent.mkdir( 345 | parents=True, 346 | exist_ok=True, 347 | ) 348 | shutil.copy(file, destination_file) 349 | shutil.copy(file, latest_file) 350 | break 351 | else: 352 | children.append(file) 353 | 354 | for child in children: 355 | save_files( 356 | child, 357 | output_dir, 358 | to_save=to_save, 359 | top_level=False, 360 | server_round=server_round, 361 | file_limit=file_limit, 362 | file_cnt=file_cnt, 363 | ) 364 | 365 | 366 | def restore_files( 367 | working_dir: Path, 368 | output_dir: Path, 369 | to_restore: list[str], 370 | server_round: int, 371 | file_limit: int, 372 | top_level: bool = True, 373 | file_cnt: int = 0, 374 | ) -> None: 375 | """Restore the files from the output dir into the working dir. 376 | 377 | Parameters 378 | ---------- 379 | working_dir : Path 380 | The working directory. 381 | output_dir : Path 382 | The output directory. 383 | 384 | Returns 385 | ------- 386 | None 387 | """ 388 | if not top_level: 389 | working_dir = working_dir / output_dir.name 390 | 391 | children: list[Path] = [] 392 | for file in output_dir.iterdir(): 393 | file_cnt += 1 394 | if file.is_file(): 395 | if f"_{server_round}" in file.name: 396 | for restore_token in to_restore: 397 | if restore_token in file.name: 398 | destination_file = ( 399 | working_dir 400 | / file.with_stem( 401 | f"{file.stem.replace(f'_{server_round}', '')}", 402 | ).name 403 | ) 404 | 405 | destination_file.parent.mkdir( 406 | parents=True, 407 | exist_ok=True, 408 | ) 409 | shutil.copy(file, destination_file) 410 | break 411 | else: 412 | children.append(file) 413 | if file_cnt >= file_limit: 414 | raise FileCountExceededError( 415 | f"""You have exceeded the {file_limit} file limit, 416 | you may increase it in the config if you are sure about it.""" 417 | ) 418 | 419 | for child in children: 420 | restore_files( 421 | working_dir, 422 | child, 423 | to_restore=to_restore, 424 | top_level=False, 425 | server_round=server_round, 426 | file_limit=file_limit, 427 | file_cnt=file_cnt, 428 | ) 429 | 430 | 431 | class FileSystemManager: 432 | """A context manager for saving and cleaning up files.""" 433 | 434 | def __init__( 435 | self, 436 | working_dir: Path, 437 | results_dir: Path, 438 | load_parameters_from: Path | None, 439 | to_restore: list[str], 440 | to_clean_once: list[str], 441 | to_save_once: list[str], 442 | original_hydra_dir: Path, 443 | file_limit: int, 444 | starting_round: int | None, 445 | log_name: str = f"{Files.MAIN}.{Ext.MAIN}", 446 | ) -> None: 447 | """Initialize the context manager. 448 | 449 | Parameters 450 | ---------- 451 | working_dir : Path 452 | The working directory. 453 | results_dir : Path 454 | The output directory. 455 | to_clean_once : List[str] 456 | The tokens to clean once. 457 | to_save_once : List[str] 458 | The tokens to save once. 459 | original_hydra_dir : Path 460 | The original hydra directory. 461 | For copying the hydra directory to the working directory.
462 | file_limit : int 463 | The maximal number of files to search; 464 | exceeding it raises FileCountExceededError. 465 | 466 | Returns 467 | ------- 468 | None 469 | """ 470 | self.to_clean_once = to_clean_once 471 | self.working_dir = working_dir 472 | self.results_dir = results_dir 473 | self.to_save_once = to_save_once 474 | self.to_restore = to_restore 475 | 476 | self.original_hydra_dir = original_hydra_dir 477 | 478 | highest_round = get_highest_round( 479 | parameters_dir=( 480 | load_parameters_from 481 | if load_parameters_from is not None 482 | else results_dir / Folders.STATE / Folders.PARAMETERS 483 | ), 484 | file_limit=file_limit, 485 | ) 486 | self.file_limit = file_limit 487 | 488 | self.server_round = ( 489 | min( 490 | highest_round, 491 | starting_round, 492 | ) 493 | if starting_round is not None 494 | else highest_round 495 | ) 496 | 497 | self.log_name = log_name 498 | 499 | def get_save_files_every_round( 500 | self, 501 | to_save: list[str], 502 | save_frequency: int, 503 | ) -> Callable[[int], None]: 504 | """Get a function that saves files every save_frequency rounds. 505 | 506 | Parameters 507 | ---------- 508 | to_save : List[str] 509 | The tokens to save. 510 | save_frequency : int 511 | The frequency to save. 512 | 513 | Returns 514 | ------- 515 | Callable[[int], None] 516 | The function that saves the files. 517 | """ 518 | 519 | def save_files_round(cur_round: int) -> None: 520 | self.server_round = cur_round 521 | if cur_round % save_frequency == 0: 522 | save_files( 523 | self.working_dir, 524 | self.results_dir, 525 | to_save=to_save, 526 | server_round=cur_round, 527 | file_limit=self.file_limit, 528 | ) 529 | 530 | return save_files_round 531 | 532 | def __enter__(self) -> "FileSystemManager": 533 | """Initialize the context manager and cleanup.""" 534 | log( 535 | logging.INFO, 536 | f"Pre-cleaning {self.to_clean_once}", 537 | ) 538 | # cleanup(self.working_dir, self.to_clean_once) 539 | restore_files( 540 | self.working_dir, 541 | self.results_dir, 542 | self.to_restore, 543 | server_round=self.server_round, 544 | file_limit=self.file_limit, 545 | ) 546 | return self 547 | 548 | def __exit__( 549 | self, 550 | _exc_type: type[BaseException] | None, 551 | _exc_value: BaseException | None, 552 | _traceback: TracebackType | None, 553 | ) -> None: 554 | """Cleanup the files.""" 555 | log(logging.INFO, f"Saving {self.to_save_once}") 556 | 557 | # Copy the hydra directory to the working directory 558 | # so that multiple runs can be run 559 | # in the same output directory and configs versioned 560 | hydra_dir = self.working_dir / Folders.HYDRA 561 | 562 | shutil.copytree( 563 | str(self.original_hydra_dir / Folders.HYDRA), 564 | str(hydra_dir), 565 | dirs_exist_ok=True, 566 | ) 567 | 568 | # Move main.log to the working directory 569 | main_log = self.original_hydra_dir / self.log_name 570 | shutil.copy2( 571 | str(main_log), 572 | str(self.working_dir / self.log_name), 573 | ) 574 | save_files( 575 | self.working_dir, 576 | self.results_dir, 577 | to_save=self.to_save_once, 578 | server_round=self.server_round, 579 | file_limit=self.file_limit, 580 | ) 581 | log( 582 | logging.INFO, 583 | f"Post-cleaning {self.to_clean_once}", 584 | ) 585 | cleanup( 586 | self.working_dir, 587 | to_clean=self.to_clean_once, 588 | ) 589 | 590 | 591 | def gather_layers_from_list(lst: NDArrays, idx: list[int]) -> NDArrays | None: 592 | """Gather the layers at the given indices from a parameter list. 593 | 594 | Parameters 595 | ---------- 596 | lst : NDArrays 597 | The list of layers.
598 | idx : list[int] 599 | The indices to gather. 600 | 601 | Returns 602 | ------- 603 | NDArrays | None 604 | The gathered layers, or None if any index is out of range. 605 | """ 606 | return ( 607 | ret if len(ret := [lst[i] for i in idx if i < len(lst)]) == len(idx) else None 608 | ) 609 | 610 | 611 | def ungather_to_list(lists: list[tuple[list[Any], list[int]]]) -> list[Any]: 612 | """Reassemble gathered sublists into the original flat list. 613 | 614 | Parameters 615 | ---------- 616 | lists : list[tuple[list[Any], list[int]]] 617 | The gathered sublists paired with their original indices. 618 | 619 | Returns 620 | ------- 621 | list[Any] 622 | The reassembled list. 623 | """ 624 | results: list[Any] = [None] * sum(len(idx) for _, idx in lists) 625 | for lst, idx in lists: 626 | for j, item_id in enumerate(idx): 627 | results[item_id] = lst[j] 628 | return results 629 | --------------------------------------------------------------------------------
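For reference, a quick hypothetical round-trip of the gather/ungather helpers defined above:

```python
# Round-trip sanity check for gather_layers_from_list / ungather_to_list.
import numpy as np

from project.utils.utils import gather_layers_from_list, ungather_to_list

layers = [np.zeros(2), np.ones(2), np.full(2, 2.0)]
head = gather_layers_from_list(layers, [0, 1])  # first two layers
tail = gather_layers_from_list(layers, [2])     # remaining layer
assert head is not None and tail is not None
restored = ungather_to_list([(head, [0, 1]), (tail, [2])])
assert all((a == b).all() for a, b in zip(restored, layers, strict=True))
```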