├── .gitignore ├── requirements.txt ├── conf ├── model │ ├── simple_net.yaml │ └── resnet18.yaml ├── client │ ├── cpu_client.yaml │ ├── gpu_client.yaml │ └── kd_client.yaml ├── strategy │ ├── fedavg.yaml │ ├── custom_strategy.yaml │ ├── strategy_model_saving.yaml │ └── strategy_kd.yaml ├── base.yaml ├── base_kd.yaml └── base_v2.yaml ├── LICENSE ├── src ├── models.py ├── server.py ├── model_utils.py ├── datasets.py ├── client.py ├── strategy.py └── common.py ├── main.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | data/ 3 | outputs/ 4 | *.DS_Store 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flwr[simulation]==1.7.0 2 | hydra-core==1.3.2 3 | tqdm==4.65.0 -------------------------------------------------------------------------------- /conf/model/simple_net.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.Net 2 | num_classes: ${dataset.num_classes} # relative definition to what's in dataset.num_classes -------------------------------------------------------------------------------- /conf/client/cpu_client.yaml: -------------------------------------------------------------------------------- 1 | resources: 2 | num_cpus: 2 3 | num_gpus: 0.0 4 | object: 5 | _target_: src.client.FlowerClient 6 | cfg: 7 | model: ${model} 8 | optim: 9 | _target_: torch.optim.SGD 10 | lr: 0.01 11 | momentum: 0.9 -------------------------------------------------------------------------------- /conf/client/gpu_client.yaml: -------------------------------------------------------------------------------- 1 | resources: 2 | num_cpus: 4 3 | num_gpus: 0.5 4 | object: 5 | _target_: src.client.FlowerClient 6 | cfg: 7 | model: ${model} 8 | optim: 9 | _target_: torch.optim.SGD 10 | lr: 0.01 11 | momentum: 0.9 -------------------------------------------------------------------------------- /conf/model/resnet18.yaml: -------------------------------------------------------------------------------- 1 | # will instantiate a ResNet18 as the model to federate. 2 | # ! You probably want to use the gpu_client with this model 3 | _target_: src.models.ResNet18 4 | num_classes: ${dataset.num_classes} # relative definition to what's in dataset.num_classes -------------------------------------------------------------------------------- /conf/client/kd_client.yaml: -------------------------------------------------------------------------------- 1 | # A client that's ready for doing local Knowledge Distillation using a teacher 2 | # sent by the server. The student is the model being federated. 
3 | 4 | resources: 5 | num_cpus: 4 6 | num_gpus: 0.5 7 | object: 8 | _target_: src.client.FlowerClientWithKD 9 | cfg: 10 | model: ${model} 11 | optim: 12 | _target_: torch.optim.SGD 13 | lr: 0.01 14 | momentum: 0.9 -------------------------------------------------------------------------------- /conf/strategy/fedavg.yaml: -------------------------------------------------------------------------------- 1 | # A fairly standard FedAvg strategy for Flower 2 | 3 | _target_: flwr.server.strategy.FedAvg # we point to a class (in the code we will use `hydra.utils.instantiate()` to create the object) 4 | _recursive_: true # we want everything to be instantiated 5 | fraction_fit: 0.0001 # by setting this to a very low number, we can easily control how many clients participate in simulation with a single scalar (i.e. server.clients_per_round) -- see logic here for details: https://github.com/adap/flower/blob/1b4b659204c406bb4fd7821d39a689105543ecbe/src/py/flwr/server/strategy/fedavg.py#L143 6 | fraction_evaluate: 0.0 # no clients will be sampled for federated evaluation (we will still perform global evaluation) 7 | min_fit_clients: ${server.clients_per_round} 8 | min_available_clients: ${server.pool} 9 | on_fit_config_fn: 10 | _target_: src.server.gen_fit_config # function to call eventually 11 | fit_cfg: ${server.fit_cfg} -------------------------------------------------------------------------------- /conf/strategy/custom_strategy.yaml: -------------------------------------------------------------------------------- 1 | # A strategy that inherits from FedAvg but adds extra functionality 2 | 3 | _target_: src.strategy.CustomFedAvg 4 | _recursive_: true # we want everything to be instantiated 5 | # Let's pass first (order doesn't matter) the arguments unique to our custom strategy 6 | num_rounds: ${server.num_rounds} 7 | eval_every_n: 5 # run global evaluation every this many rounds (will always run on the first and last round) 8 | keep_ratio: 0.5 # this ratio of clients that participated in round N will be sampled again in round N+1 9 | drop_ratio: 0.25 # this ratio of client updates sent back to the server will be dropped before doing aggregation 10 | # We pass the usual arguments needed for a strategy (in this case ours inherits from FedAvg) 11 | fraction_fit: 0.0001 12 | fraction_evaluate: 0.0 13 | min_fit_clients: ${server.clients_per_round} 14 | min_available_clients: ${server.pool} 15 | on_fit_config_fn: 16 | _target_: src.server.gen_fit_config # function to call eventually 17 | fit_cfg: ${server.fit_cfg} -------------------------------------------------------------------------------- /conf/strategy/strategy_model_saving.yaml: -------------------------------------------------------------------------------- 1 | # A fairly standard FedAvg strategy for Flower with the additional functionality being 2 | # that the strategy keeps track of the global model state and saves it to disk after each 3 | # evaluation. This allows you to retrieve the global model parameters easily after the 4 | # simulation is completed. 5 | 6 | _target_: src.strategy.CustomFedAvgWithModelSaving 7 | _recursive_: true # we want everything to be instantiated 8 | fraction_fit: 0.0001 # by setting this to a very low number, we can easily control how many clients participate in simulation with a single scalar (i.e.
server.clients_per_round) -- see logic here for details: https://github.com/adap/flower/blob/1b4b659204c406bb4fd7821d39a689105543ecbe/src/py/flwr/server/strategy/fedavg.py#L143 9 | fraction_evaluate: 0.0 # no clients will be sampled for federated evaluation (we will still perform global evaluation) 10 | min_fit_clients: ${server.clients_per_round} 11 | min_available_clients: ${server.pool} 12 | on_fit_config_fn: 13 | _target_: src.server.gen_fit_config # function to call eventually 14 | fit_cfg: ${server.fit_cfg} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Javier Fernandez Marques 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /conf/strategy/strategy_kd.yaml: -------------------------------------------------------------------------------- 1 | # Demonstrates a simplified setup of client-side knowledge distillation. Different from other typical 2 | # FL examples, here the server communicates two models: (1) a teacher model that's pre-trained and (2) a 3 | # student network that's learned in a federated fashion by distilling on the client's side. 4 | 5 | # Let's pass first (order doesn't matter) the arguments unique to our custom strategy 6 | _target_: src.strategy.CustomFedAvgWithKD 7 | _recursive_: false # we don't want the whole thing to be instantiated (e.g. child nodes with _target_) 8 | teacher: 9 | _target_: src.models.ResNet18 # let's use a ResNet-18 as our teacher (because why not) 10 | num_classes: 10 11 | kd_config: 12 | teacher_pretrain: 13 | batch_size: 32 14 | optim: 15 | _target_: torch.optim.SGD 16 | lr: 0.1 17 | momentum: 0.9 18 | num_batches: 50 # let's limit how many batches of data are used for training the teacher.
19 | student_train: 20 | temperature: 2 21 | alpha: 0.5 22 | # We pass the usual arguments needed for a strategy (in this case ours inherits from FedAvg) 23 | fraction_fit: 0.0001 24 | fraction_evaluate: 0.0 25 | min_fit_clients: ${server.clients_per_round} 26 | min_available_clients: ${server.pool} 27 | on_fit_config_fn: 28 | _target_: src.server.gen_fit_config # function to call eventually 29 | fit_cfg: ${server.fit_cfg} 30 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | 7 | # Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz') 8 | # borrowed from Pytorch quickstart example 9 | class Net(nn.Module): 10 | def __init__(self, num_classes: int) -> None: 11 | super(Net, self).__init__() 12 | self.conv1 = nn.Conv2d(3, 6, 5) 13 | self.pool = nn.MaxPool2d(2, 2) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | x = self.pool(F.relu(self.conv1(x))) 21 | x = self.pool(F.relu(self.conv2(x))) 22 | x = x.view(-1, 16 * 5 * 5) 23 | x = F.relu(self.fc1(x)) 24 | x = F.relu(self.fc2(x)) 25 | x = self.fc3(x) 26 | return x 27 | 28 | class ResNet18(torch.nn.Module): 29 | def __init__(self, num_classes: int): 30 | super().__init__() 31 | 32 | self.resnet = models.resnet18(weights=None, num_classes=num_classes) 33 | 34 | self.resnet.conv1 = torch.nn.Conv2d( 35 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 36 | ) 37 | self.resnet.maxpool = torch.nn.Identity() 38 | 39 | def forward(self, x): 40 | x = self.resnet(x) 41 | x = F.log_softmax(x, dim=1) 42 | 43 | return x 44 | -------------------------------------------------------------------------------- /conf/base.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: # i.e. 
configs used if you simply run the code as `python main.py` 3 | - client: cpu_client # this points to the file: client/cpu_client.yaml 4 | - model: simple_net # this points to the file: model/simple_net.yaml 5 | - strategy: fedavg # points to strategy/fedavg.yaml 6 | 7 | dataset: 8 | name: 'CIFAR10' 9 | num_classes: 10 10 | lda_alpha: 1.0 11 | prepare: 12 | _target_: src.datasets.get_cifar_10_and_partition # we define as target a function to execute (in the code we'll use `hydra.utils.call()` to execute it) 13 | config: 14 | pool: ${server.pool} # use ${} syntax to access other nodes (and their properties) throughout the config (even if they are defined in other .yaml files -- as long as they are part of the run, i.e., parsed by Hydra) 15 | # alpha and num_classes are defined immediately below the `dataset` just so it's easier to modify them from the CLI and to reference them from other parts in the config 16 | alpha: ${dataset.lda_alpha} # needed to know how to do LDA partitioning 17 | num_classes: ${dataset.num_classes} # needed to know how many partitions to create 18 | val: 0.1 19 | 20 | server: 21 | pool: 100 22 | clients_per_round: 10 23 | num_rounds: 10 24 | fit_cfg: # config for each client's fit() method (this will be passed to the strategy) 25 | epochs: 1 26 | batch_size: 32 27 | num_cpu: ${client.resources.num_cpus} # for dataloader's num_workers 28 | 29 | misc: 30 | attach: false # set to true if you want Flower's Virtual Client Engine (VCE) to attach to an already running Ray server 31 | -------------------------------------------------------------------------------- /conf/base_kd.yaml: -------------------------------------------------------------------------------- 1 | 2 | # this defines a top-level config (just like base.yaml does) but with changes to the `defaults` and the FL setup parameterised in `server:` 3 | 4 | defaults: # i.e.
configs used if you simply run the code as `python main.py` 5 | - client: kd_client # this points to the file: client/kd_client.yaml 6 | - model: simple_net # this points to the file: model/simple_net.yaml 7 | - strategy: strategy_kd # points to strategy/strategy_kd.yaml 8 | 9 | dataset: 10 | name: 'CIFAR10' 11 | num_classes: 10 12 | lda_alpha: 1.0 13 | prepare: 14 | _target_: src.datasets.get_cifar_10_and_partition # we define as target a function to execute (in the code we'll use `hydra.utils.call()` to execute it) 15 | config: 16 | pool: ${server.pool} # use ${} syntaxt to access other nodes (and their properties) throught the config (even if they are defined in other .yaml files -- as long as they are part of the run, i.e., parsed by Hydra) 17 | # alpha and num_classes are defined inmediately below the `dataset` just so it's easier to modify them from the CLI and to reference them from other parts in the config 18 | alpha: ${dataset.lda_alpha} # needed to know how to do LDA partitioning 19 | num_classes: ${dataset.num_classes} # needed to know how many partitions to create 20 | val: 0.1 21 | 22 | server: 23 | pool: 500 24 | clients_per_round: 20 25 | num_rounds: 10 26 | fit_cfg: # config for each client's fit() method (this will be passed the strategy) 27 | epochs: 1 28 | batch_size: 32 29 | num_cpu: ${client.resources.num_cpus} # for dataloader's num_workers 30 | 31 | misc: 32 | attach: false # set to true if you want Flower's Virtual Client Engine (VCE) to attach to an already running Ray server 33 | -------------------------------------------------------------------------------- /conf/base_v2.yaml: -------------------------------------------------------------------------------- 1 | 2 | # this defines a top-level config (just like base.yaml) does but with changes to the `defaults` and the FL setup parameterised in `server:` 3 | 4 | defaults: # i.e. 
configs used if you simply run the code as `python main.py` 5 | - client: gpu_client # this points to the file: client/gpu_client.yaml 6 | - model: resnet18 # this points to the file: model/resnet18.yaml 7 | - strategy: custom_strategy # points to strategy/custom_strategy.yaml 8 | 9 | dataset: 10 | name: 'CIFAR10' 11 | num_classes: 10 12 | lda_alpha: 1.0 13 | prepare: 14 | _target_: src.datasets.get_cifar_10_and_partition # we define as target a function to execute (in the code we'll use `hydra.utils.call()` to execute it) 15 | config: 16 | pool: ${server.pool} # use ${} syntaxt to access other nodes (and their properties) throught the config (even if they are defined in other .yaml files -- as long as they are part of the run, i.e., parsed by Hydra) 17 | # alpha and num_classes are defined inmediately below the `dataset` just so it's easier to modify them from the CLI and to reference them from other parts in the config 18 | alpha: ${dataset.lda_alpha} # needed to know how to do LDA partitioning 19 | num_classes: ${dataset.num_classes} # needed to know how many partitions to create 20 | val: 0.1 21 | 22 | server: 23 | pool: 500 24 | clients_per_round: 20 25 | num_rounds: 10 26 | fit_cfg: # config for each client's fit() method (this will be passed the strategy) 27 | epochs: 1 28 | batch_size: 32 29 | num_cpu: ${client.resources.num_cpus} # for dataloader's num_workers 30 | 31 | misc: 32 | attach: false # set to true if you want Flower's Virtual Client Engine (VCE) to attach to an already running Ray server 33 | -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, Callable, Optional, Tuple 3 | 4 | 5 | import torch 6 | from omegaconf import DictConfig 7 | from hydra.utils import instantiate 8 | from flwr.common.typing import Scalar, NDArrays 9 | 10 | 11 | from src.model_utils import test, ndarrays_to_model 12 | 13 | 14 | def gen_fit_config(fit_cfg: DictConfig): 15 | def fit_config(server_round: int) -> Dict[str, Scalar]: 16 | """Return a configuration with static batch size and (local) epochs.""" 17 | return fit_cfg 18 | 19 | return fit_config 20 | 21 | 22 | def get_evaluate_fn( 23 | testset, 24 | model_cfg: DictConfig, 25 | ) -> Callable[[NDArrays], Optional[Tuple[float, float]]]: 26 | """Return an evaluation function for centralized evaluation.""" 27 | 28 | def evaluate( 29 | server_round: int, parameters:NDArrays, config: Dict[str, Scalar], is_last_round: bool=False 30 | ) -> Optional[Tuple[float, float]]: 31 | """Use the entire CIFAR-10 test set for evaluation.""" 32 | 33 | # determine device 34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 35 | 36 | # Let's first instantiate the model 37 | model = instantiate(model_cfg) 38 | # Now set the model buffers with the parameters of the global model 39 | ndarrays_to_model(model, parameters) 40 | model.to(device) 41 | 42 | # here you could use the config to parameterise how the global evaluation is performed (e.g. use a particular bach size) 43 | # you could also use the `is_last_round` flag to switch between a global validation set and a global test set. 44 | # The global test set should be used only in the last round, while the global validation set can be used in all rounds. 
45 | print(f"Is this the last round?: {is_last_round = }") 46 | 47 | testloader = torch.utils.data.DataLoader(testset, batch_size=128) 48 | 49 | # run global evaluation 50 | loss, accuracy = test(model, testloader, device=device) 51 | 52 | # Now you have evaluated the global model. This is a good place to save a checkpoint if, for instance, a new 53 | # best global model is found (based on a global validation set). 54 | # If for instance you are using TensorBoard to record global metrics or W&B (even better!!) this is a good 55 | # place to log all the metrics you want. 56 | 57 | # return statistics 58 | return loss, {"accuracy": accuracy} 59 | 60 | return evaluate -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from flwr.common import ndarrays_to_parameters 9 | 10 | def model_as_ndarrays(model: torch.nn.Module) -> List[np.ndarray]: 11 | """Get model weights as a list of NumPy ndarrays.""" 12 | return [val.cpu().numpy() for _, val in model.state_dict().items()] 13 | 14 | 15 | def ndarrays_to_model(model: torch.nn.Module, params: List[np.ndarray]): 16 | """Set model weights from a list of NumPy ndarrays.""" 17 | params_dict = zip(model.state_dict().keys(), params) 18 | state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict}) 19 | model.load_state_dict(state_dict, strict=True) 20 | 21 | 22 | # borrowed from Pytorch quickstart example 23 | def train(net, trainloader, optim, epochs, device: str): 24 | """Train the network on the training set.""" 25 | criterion = torch.nn.CrossEntropyLoss() 26 | net.train() 27 | for _ in range(epochs): 28 | for images, labels in trainloader: 29 | images, labels = images.to(device), labels.to(device) 30 | optim.zero_grad() 31 | loss = criterion(net(images), labels) 32 | loss.backward() 33 | optim.step() 34 | 35 | def train_with_kd(net, teacher, kd_config, trainloader, optim, epochs, device:str): 36 | """Train network on the training set using KD.""" 37 | 38 | alpha = kd_config.alpha 39 | temp = kd_config.temperature 40 | def kd_loss(student_output, labels, teacher_outputs): 41 | # KD loss (borrowing from https://github.com/haitongli/knowledge-distillation-pytorch/blob/9937528f0be0efa979c745174fbcbe9621cea8b7/model/net.py#L100) 42 | return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_output/temp, dim=1), 43 | F.softmax(teacher_outputs/temp, dim=1)) * (alpha * temp**2) + F.cross_entropy(student_output, labels) * (1.
- alpha) 44 | 45 | net.train() 46 | teacher.eval() 47 | for _ in range(epochs): 48 | for images, labels in trainloader: 49 | images, labels = images.to(device), labels.to(device) 50 | optim.zero_grad() 51 | s_out = net(images) 52 | 53 | # pass the same batch through the teacher model 54 | with torch.no_grad(): 55 | t_out = teacher(images) 56 | loss = kd_loss(s_out, labels, t_out) 57 | loss.backward() 58 | optim.step() 59 | 60 | # borrowed from Pytorch quickstart example 61 | def test(net, testloader, device: str): 62 | """Validate the network on the entire test set.""" 63 | criterion = torch.nn.CrossEntropyLoss() 64 | correct, total, loss = 0, 0, 0.0 65 | net.eval() 66 | with torch.no_grad(): 67 | for data in testloader: 68 | images, labels = data[0].to(device), data[1].to(device) 69 | outputs = net(images) 70 | loss += criterion(outputs, labels).item() 71 | _, predicted = torch.max(outputs.data, 1) 72 | total += labels.size(0) 73 | correct += (predicted == labels).sum().item() 74 | accuracy = correct / total 75 | return loss, accuracy 76 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import flwr as fl 5 | import hydra 6 | from hydra.utils import call, instantiate 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf 9 | 10 | 11 | from src.server import get_evaluate_fn 12 | from src.strategy import CustomFedAvgWithModelSaving 13 | 14 | 15 | @hydra.main(version_base=None, config_path="conf", config_name="base") 16 | def run(cfg : DictConfig): 17 | 18 | print(OmegaConf.to_yaml(cfg)) 19 | 20 | # Each time you run this, Hydra will create a new directory containing 21 | # the config you used as well as the generated log. You can retrieve 22 | # the path to this directory as shown below. Ideally, here is where 23 | # you'd be saving any output (e.g. checkpoints) for this experiment 24 | save_path = HydraConfig.get().runtime.output_dir 25 | print(f"Output directory for this experiment: {save_path}") 26 | 27 | # let's prepare the dataset (download + partition) 28 | fed_dir, testset = call(cfg.dataset.prepare) 29 | 30 | # let's define our strategy (instantiating the object defined in the config) 31 | # You can pass additional arguments needed for the object (that weren't possible 32 | # to define in the config maybe because they are defined at runtime). You need to 33 | # use keyword arguments. 34 | # In this case, the function to evaluate the global model requires passing the testset object. 35 | # Our strategy config might contain other nodes with _target_. Often, we want to delay when these 36 | # are instantiated until, for instance, all variables needed to do so are ready. We set _recursive_=False 37 | # to leave those nodes un-initialised (we set that in the config itself with the appropriate value) 38 | strategy = instantiate(cfg.strategy, evaluate_fn=get_evaluate_fn(testset, cfg.model)) 39 | 40 | def client_fn(cid: str): 41 | # Create a single client instance 42 | # The type of client class is set at runtime based on the config used. Recall we need to pass 43 | # extra arguments that weren't available when the config is parsed. Also, let's not instantiate 44 | # every object inside the client config (use `_recursive_`=False). This will give us full control on 45 | # when instantiation happens.
46 | return instantiate(cfg.client.object, cid=cid, fed_dir_data=fed_dir, _recursive_=False).to_client() 47 | 48 | # (optional) specify Ray config 49 | # If you want to do multi-node simulations, you want the VCE to attach to an existing Ray server 50 | ray_init_args = {"include_dashboard": False, "address": "auto" if cfg.misc.attach else None} 51 | 52 | # start simulation 53 | history = fl.simulation.start_simulation( 54 | client_fn=client_fn, 55 | num_clients=cfg.server.pool, # total number of clients in the experiment 56 | client_resources=cfg.client.resources, # resources that will be reserved for each client 57 | config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds), 58 | strategy=strategy, 59 | ray_init_args=ray_init_args, 60 | ) 61 | 62 | # now you can, for instance, save your results into a Python pickle 63 | extra_results = {} # add here any other results you want to save 64 | 65 | # if your strategy is keeping track of some variables you want to retrieve once 66 | # the experiment is completed, you can totally do so. You might want to do this, 67 | # for instance, in order to save the global model weights 68 | if isinstance(strategy, CustomFedAvgWithModelSaving): 69 | model_parameters = strategy.global_parameters 70 | extra_results['global_parameters'] = model_parameters 71 | 72 | # add everything into a single dictionary 73 | data = {'history': history, **extra_results} 74 | 75 | results_path = Path(save_path)/'results.pkl' 76 | # save to pickle 77 | with open(str(results_path), "wb") as handle: 78 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 79 | 80 | print(f"Results saved into: {results_path}") 81 | 82 | if __name__ == "__main__": 83 | 84 | run() 85 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | 2 | import shutil 3 | from PIL import Image 4 | from pathlib import Path 5 | from typing import Callable, Optional, Tuple, Any 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import datasets, transforms 10 | from torch.utils.data import DataLoader 11 | from torchvision.datasets import VisionDataset 12 | 13 | from omegaconf import DictConfig 14 | 15 | from .common import create_lda_partitions 16 | 17 | 18 | def get_dataset(path_to_data: Path, cid: str, partition: str): 19 | # generate path to cid's data 20 | path_to_data = path_to_data / cid / (partition + ".pt") 21 | 22 | return TorchVision_FL(path_to_data, transform=cifar10Transformation()) 23 | 24 | 25 | def get_dataloader( 26 | path_to_data: str, cid: str, is_train: bool, batch_size: int, workers: int 27 | ): 28 | """Generates trainset/valset object and returns the appropriate dataloader.""" 29 | 30 | partition = "train" if is_train else "val" 31 | dataset = get_dataset(Path(path_to_data), cid, partition) 32 | 33 | # we use as number of workers all the cpu cores assigned to this actor 34 | kwargs = {"num_workers": workers, "pin_memory": True, "drop_last": False} 35 | return DataLoader(dataset, batch_size=batch_size, **kwargs) 36 | 37 | 38 | def get_random_id_splits(total: int, val_ratio: float, shuffle: bool = True): 39 | """Splits a list of length `total` into two following a 40 | (1-val_ratio):val_ratio partitioning. 41 | By default the indices are shuffled before creating the split and 42 | returning.
43 | """ 44 | 45 | if isinstance(total, int): 46 | indices = list(range(total)) 47 | else: 48 | indices = total 49 | 50 | split = int(np.floor(val_ratio * len(indices))) 51 | # print(f"Users left out for validation (ratio={val_ratio}) = {split} ") 52 | if shuffle: 53 | np.random.shuffle(indices) 54 | return indices[split:], indices[:split] 55 | 56 | 57 | def do_fl_partitioning(path_to_dataset, pool_size, alpha, num_classes, val_ratio=0.0): 58 | """Torchvision (e.g. CIFAR-10) datasets using LDA.""" 59 | 60 | images, labels = torch.load(path_to_dataset) 61 | idx = np.array(range(len(images))) 62 | dataset = [idx, labels] 63 | partitions, _ = create_lda_partitions( 64 | dataset, num_partitions=pool_size, concentration=alpha, accept_imbalanced=True 65 | ) 66 | 67 | # Show label distribution for first partition (purely informative) 68 | partition_zero = partitions[0][1] 69 | hist, _ = np.histogram(partition_zero, bins=list(range(num_classes + 1))) 70 | print( 71 | f"Class histogram for 0-th partition (alpha={alpha}, {num_classes} classes): {hist}" 72 | ) 73 | 74 | # now save partitioned dataset to disk 75 | # first delete dir containing splits (if exists), then create it 76 | splits_dir = path_to_dataset.parent / "federated" 77 | if splits_dir.exists(): 78 | shutil.rmtree(splits_dir) 79 | Path.mkdir(splits_dir, parents=True) 80 | 81 | for p in range(pool_size): 82 | 83 | labels = partitions[p][1] 84 | image_idx = partitions[p][0] 85 | imgs = images[image_idx] 86 | 87 | # create dir 88 | Path.mkdir(splits_dir / str(p)) 89 | 90 | if val_ratio > 0.0: 91 | # split data according to val_ratio 92 | train_idx, val_idx = get_random_id_splits(len(labels), val_ratio) 93 | val_imgs = imgs[val_idx] 94 | val_labels = labels[val_idx] 95 | 96 | with open(splits_dir / str(p) / "val.pt", "wb") as f: 97 | torch.save([val_imgs, val_labels], f) 98 | 99 | # remaining images for training 100 | imgs = imgs[train_idx] 101 | labels = labels[train_idx] 102 | 103 | with open(splits_dir / str(p) / "train.pt", "wb") as f: 104 | torch.save([imgs, labels], f) 105 | 106 | return splits_dir 107 | 108 | 109 | def cifar10Transformation(): 110 | return transforms.Compose( 111 | [ 112 | transforms.ToTensor(), 113 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 114 | ] 115 | ) 116 | 117 | 118 | class TorchVision_FL(VisionDataset): 119 | """This is just a trimmed down version of torchvision.datasets.MNIST. 120 | Use this class by either passing a path to a torch file (.pt) 121 | containing (data, targets) or pass the data, targets directly 122 | instead. 
123 | """ 124 | 125 | def __init__( 126 | self, 127 | path_to_data=None, 128 | data=None, 129 | targets=None, 130 | transform: Optional[Callable] = None, 131 | ) -> None: 132 | path = path_to_data.parent if path_to_data else None 133 | super(TorchVision_FL, self).__init__(path, transform=transform) 134 | self.transform = transform 135 | 136 | if path_to_data: 137 | # load data and targets (path_to_data points to an specific .pt file) 138 | self.data, self.targets = torch.load(path_to_data) 139 | else: 140 | self.data = data 141 | self.targets = targets 142 | 143 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 144 | img, target = self.data[index], int(self.targets[index]) 145 | 146 | # doing this so that it is consistent with all other datasets 147 | # to return a PIL Image 148 | if not isinstance(img, Image.Image): # if not PIL image 149 | if not isinstance(img, np.ndarray): # if torch tensor 150 | img = img.numpy() 151 | 152 | img = Image.fromarray(img) 153 | 154 | if self.transform is not None: 155 | img = self.transform(img) 156 | 157 | if self.target_transform is not None: 158 | target = self.target_transform(target) 159 | 160 | return img, target 161 | 162 | def __len__(self) -> int: 163 | return len(self.data) 164 | 165 | 166 | def get_cifar_10(path_to_data="./data"): 167 | """Downloads CIFAR10 dataset and generates a unified training set (it will 168 | be partitioned later using the LDA partitioning mechanism.""" 169 | 170 | # download dataset and load train set 171 | train_set = datasets.CIFAR10(root=path_to_data, train=True, download=True) 172 | 173 | # fuse all data splits into a single "training.pt" 174 | data_loc = Path(path_to_data) / "cifar-10-batches-py" 175 | training_data = data_loc / "training.pt" 176 | print("Generating unified CIFAR dataset") 177 | torch.save([train_set.data, np.array(train_set.targets)], training_data) 178 | 179 | test_set = datasets.CIFAR10( 180 | root=path_to_data, train=False, transform=cifar10Transformation() 181 | ) 182 | 183 | # returns path where training data is and testset 184 | return training_data, test_set 185 | 186 | 187 | def get_cifar_10_and_partition(config: DictConfig, path_to_data: str='./data'): 188 | 189 | train_path, testset = get_cifar_10(path_to_data=path_to_data) 190 | 191 | fed_dir = do_fl_partitioning(train_path, config.pool, config.alpha, config.num_classes, val_ratio=config.val) 192 | 193 | return fed_dir, testset 194 | -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import Dict 3 | from pathlib import Path 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import numpy as np 8 | from hydra.utils import instantiate 9 | from omegaconf import DictConfig 10 | 11 | import flwr as fl 12 | from flwr.common.typing import Scalar 13 | 14 | from .datasets import get_dataloader 15 | from .model_utils import train, train_with_kd, test, model_as_ndarrays, ndarrays_to_model 16 | 17 | class FlowerClient(fl.client.NumPyClient): 18 | """A very standard Flower client customisable via Hydra configs. 
19 | Simple but covers 95%+ of what you'd want to do in FL.""" 20 | def __init__(self, cid: str, fed_dir_data: str, cfg: DictConfig): 21 | self.cid = cid 22 | self.fed_dir = Path(fed_dir_data) 23 | self.properties: Dict[str, Scalar] = {"tensor_type": "numpy.ndarray"} 24 | 25 | self.cfg = cfg 26 | 27 | # Instantiate model (because the client class might have also been instantiated via Hydra, you want to make sure that 28 | # the client was instantiated with _recursive_=False. Else the below will fail. Even worse! all clients will be pointing 29 | # to the same object so it will definitely create problems -- which can be solved via copy.deepcopy() but why do it that way?) 30 | self.net = instantiate(self.cfg.model) 31 | 32 | # Determine device 33 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 34 | 35 | def get_parameters(self, config): 36 | return model_as_ndarrays(self.net) 37 | 38 | def set_parameters(self, parameters): 39 | ndarrays_to_model(self.net, parameters) 40 | 41 | def fit(self, parameters, config): 42 | self.set_parameters(parameters) 43 | 44 | # Load data for this client and get trainloader 45 | trainloader = get_dataloader( 46 | self.fed_dir, 47 | self.cid, 48 | is_train=True, 49 | batch_size=config["batch_size"], 50 | workers=config["num_cpu"], 51 | ) 52 | 53 | # Send model to device 54 | self.net.to(self.device) 55 | 56 | optimizer = instantiate(self.cfg.optim, params=self.net.parameters()) 57 | # Train 58 | train(self.net, trainloader, epochs=config["epochs"], device=self.device, optim=optimizer) 59 | 60 | # Return local model and statistics 61 | return self.get_parameters(config), len(trainloader.dataset), {} 62 | 63 | def evaluate(self, parameters, config): 64 | self.set_parameters(parameters) 65 | 66 | # Load data for this client and get valloader 67 | valloader = get_dataloader( 68 | self.fed_dir, self.cid, is_train=False, batch_size=50, workers=2 69 | ) 70 | 71 | # Send model to device 72 | self.net.to(self.device) 73 | 74 | # Evaluate 75 | loss, accuracy = test(self.net, valloader, device=self.device) 76 | 77 | # Return statistics 78 | return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} 79 | 80 | 81 | class FlowerClientWithKD(FlowerClient): 82 | """A Flower client that behaves as the standard client above for the most part. 83 | The main exception is that local training is done using Knowledge Distillation, 84 | using as teacher a model sent from the server. Under this formulation of federated 85 | KD, the server sends two models to the clients: a pre-trained teacher and a student. 86 | The latter is the one being updated/trained by the clients and hence the one being 87 | aggregated by the strategy in the server. Please note this is a very simple setup 88 | for demonstration purposes.""" 89 | 90 | 91 | def _instantiate_teacher(self, teacher: DictConfig, teacher_arrays): 92 | teacher_model = instantiate(teacher) # instantiate 93 | 94 | # copy params sent by server 95 | params_dict = zip(teacher_model.state_dict().keys(), teacher_arrays) 96 | state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict}) 97 | teacher_model.load_state_dict(state_dict, strict=True) 98 | return teacher_model 99 | 100 | def fit(self, parameters, config): 101 | """The fit() method receives `parameters`, i.e., the parameters of the model being 102 | federated. In our example this corresponds to the student model. The teacher model 103 | is sent as part of the config.
Before we can train using KD, we need to instantiate 104 | the teacher.""" 105 | 106 | # update the local model with the parameters sent by the server. 107 | self.set_parameters(parameters) 108 | 109 | # instantiate teacher with parameters sent from server 110 | # (We could of course instantiate the teacher in the constructor of this class and only 111 | # update its weights here -- for example if we have a more elaborate setup where the 112 | # teacher is also being periodically updated by the server) 113 | teacher = self._instantiate_teacher(config["teacher_cfg"], config["teacher_arrays"]) 114 | print(f"Client {self.cid} has loaded teacher network successfully!") 115 | 116 | # Load data for this client and get trainloader 117 | trainloader = get_dataloader( 118 | self.fed_dir, 119 | self.cid, 120 | is_train=True, 121 | batch_size=config["batch_size"], 122 | workers=config["num_cpu"], 123 | ) 124 | 125 | # Send model to device 126 | self.net.to(self.device) 127 | 128 | # Send the teacher to device too (be mindful that this will increase the resource utilisation of the clients. 129 | # this will likely require you to revisit the `client_resources` you set up when launching the simulation. This 130 | # is critical if you plan to use a high-capacity teacher) 131 | teacher.to(self.device) 132 | 133 | # track parameters of student network 134 | optimizer = instantiate(self.cfg.optim, params=self.net.parameters()) 135 | 136 | # Train with distillation. We time it 137 | start_t = time() 138 | train_with_kd(self.net, teacher, config["KD_config"], trainloader, epochs=config["epochs"], device=self.device, optim=optimizer) 139 | # time (in seconds) it took to run `train_with_kd` 140 | total_t = time() - start_t 141 | 142 | # Return local model and statistics. You can return whatever you want using the last argument (the "Metrics", as they are called in Flower) 143 | # Using Metrics is great to track in the server different info about how the training on the clients is going 144 | # or when you are experimenting with new setups. Just be mindful that, to stay true to the FL spirit, no sensitive info 145 | # should be sent back to the server. Even in simulation settings, incorporating client-side info that would normally 146 | # not be available in real deployments might limit the effectiveness of the method you are investigating when deploying 147 | # it out in the wild. 148 | return self.get_parameters(config), len(trainloader.dataset), {"fit_time": total_t} 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🗓️ FlowerMonthly 2 | 3 | > 🔎 Check https://flower.ai for all things Federated Learning and Flower! 4 | 5 | FlowerMonthly is a monthly online event organised by the team behind [Flower, A Friendly Federated Learning Framework](https://flower.ai/) that runs for one hour on the first Wednesday of each month (typically starting at 0900 SF, 1200 NY, 1700 LON, 1800 MAD, 2130 IST, 0000 北京) and is divided into four blocks of content: 6 | 7 | 1. A platform update given by a member of the Flower team 8 | 2. A 30 min presentation by a leading expert in Federated Learning 9 | 3. A 15 min hands-on example of cool things people do with Flower 10 | 4. Open discussion and Q&A 11 | 12 | This repository contains some of the code examples presented in Flower's FlowerMonthly series. You can see all past events [on the Flower site](https://flower.ai/conf/flower-monthly/).
Jump on the fascinating FL train! 13 | 14 | > Join our [Slack channel](https://flower.ai/join-slack/) to chat directly with thousands already using Flower and to reach out to members of the Flower Team. Whether you are working on an amazing new feature or you hit a roadblock with your FL setup, [reach us also on GitHub](https://github.com/adap/flower) by submitting a PR or by opening an Issue. 15 | 16 | 17 | ## Content of this repo 18 | 19 | > This repo will keep getting more examples after each Flower Monthly so be sure to come by & pull again. 20 | 21 | 22 | To start this repo, we have ported the [pytorch_simulation](https://github.com/adap/flower/tree/main/examples/simulation_pytorch) Flower example and adapted it so it works with [Hydra](https://hydra.cc/) configs to make the parameterisation of your FL experiments easy and flexible. The same could have been achieved using [AwesomeYaml](https://github.com/SamsungLabs/awesomeyaml) or other config systems. In fact, a previous version of this repo was entirely designed around AwesomeYaml (see the `withAwesomeYaml` tag). I have added some small changes to the code provided by that example to make this repo more interesting, some of which are based on FlowerMonthly demos and talks. The code in this repo is validated using Flower's Virtual Client Engine for simulation of FL workloads. However, the vast majority of the code here can be used directly in gRPC-based Flower setups outside simulation. 23 | 24 | The purpose of this repo is to showcase through simple examples different functionalities of [Flower](https://github.com/adap/flower) (**give it a :star: if you use it**) so you can later use it in your projects. With this in mind, the dataset considered here, its partitioning and the training protocol as a whole are kept fairly simple. Here I use CIFAR-10 and split it following [LDA](https://arxiv.org/abs/1909.06335) for a fixed value of \alpha (which you can tune in the configs). By default I generate a 100-client split and sample 10 clients per round (this is a simple but very typical _cross-silo_ FL setup). Please note I have set sensible values for the hyperparameters in this repo, but they likely need to be adjusted for each different experiment. **The purpose of this repo is to showcase how to use Flower in different ways, whether you need the vanilla behaviour or a highly customised FL pipeline.** 26 | 27 | Currently, this repo provides: 28 | 29 | * A `conf/strategy/strategy_model_saving.yaml` config showing how, with small changes to a standard Flower strategy, you can keep track of all variables you want (e.g. global model state) so you can either generate checkpoints or retrieve their values at the end of the simulation. 30 | * A `conf/strategy/strategy_kd.yaml` config (based on the 7 June 2023 Flower Summit talk) showing how to do a simple form of federated Knowledge Distillation. 31 | * A `conf/strategy/custom_strategy.yaml` config (based on the 7 June 2023 Flower Summit talk) showcasing how to design a custom Flower strategy with ease. 32 | * A `conf/base_kd.yaml` top-level config that you can run to see a simple federated KD setting in action. 33 | * A `conf/base_v2.yaml` top-level config that makes the setup in `base.yaml` a bit more interesting: using ResNet18, clients using GPU and a custom strategy. 34 | * A `conf/base.yaml` top-level config with all the elements needed to define a complete FL setup. It uses a very lightweight model so all systems should be capable of running it (no GPU required).
* Integration with `Hydra`, so you can customise how your experiment runs directly from the command line. 35 | 36 | ## Setup 37 | 38 | While there are different ways of setting up your Python environment, here I'll assume a [Miniconda](https://docs.conda.io/en/latest/miniconda.html) installation is reachable from a standard bash/zsh terminal. These are the steps to set up the environment: 39 | 40 | ```bash 41 | # create environment and activate 42 | conda create -n flowermonthly python=3.10 -y 43 | source activate flowermonthly 44 | 45 | # install pytorch et al (you might want to adjust the command below depending on your platform/OS: https://pytorch.org/get-started/locally/) 46 | conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia 47 | 48 | # install flower and other deps 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | 53 | ## Flower + Hydra for beginners 54 | 55 | This section provides an introductory look into [`Hydra`](https://hydra.cc/) in the context of Federated Learning with Flower. You can run the code just as originally designed (no change in parameters) like this: 56 | ```bash 57 | python main.py # this will use the default configs (i.e. everything defined in conf/base.yaml) 58 | ``` 59 | 60 | With Hydra you can easily change whatever you need from the config file without having to add a new `argparse` argument each time. For example: 61 | ```bash 62 | python main.py server.clients_per_round=5 # will use 5 clients per round instead of the default 10 63 | python main.py dataset.lda_alpha=1000 # will use LDA alpha=1000 (making it IID) instead of the default value (1.0) 64 | 65 | python main.py client.resources.num_cpus=4 # allocates 4 CPUs to each client (instead of the default 2 as defined in conf/client/cpu_client.yaml -- cpu_client is the default client to use as defined in conf/base.yaml->defaults.client) 66 | ``` 67 | 68 | In some settings, you might want to make more substantial changes to the default config. For that, even though you could probably still do it from the command line, it can get messy... Instead, you can directly replace entire structures in your config with others. For example, let's say you want to change your entire client definition from the default one (check it in `conf/client/cpu_client.yaml`). You'll need to create a new yaml file, respecting the expected structure, and place it in the same directory as `cpu_client.yaml`. This is exactly what I did with `gpu_client.yaml`. You can use the latter client as follows: 69 | ```bash 70 | python main.py # will use the default `cpu_client.yaml` 71 | 72 | # note that you'll need a GPU for this 73 | python main.py client=gpu_client # will use the client as defined in `conf/client/gpu_client.yaml` 74 | ``` 75 | 76 | Let's say now that you have a concrete setting you'd like to evaluate often enough without having to modify the "base" config each time. The best way to do this would be to define a new top-level config with the `defaults:` your setup needs. For example, let's imagine you want your new setting to always use `resnet18.yaml`, `gpu_client.yaml` and `custom_strategy.yaml`. You can define a custom top-level config as follows: 77 | 78 | ```yaml 79 | # this defines a top-level config (just like base.yaml does) but with changes to the `defaults` and the FL setup parameterised in `server:` 80 | 81 | defaults: # i.e.
configs used if you simply run the code as `python main.py` 82 | - client: gpu_client # this points to the file: client/gpu_client.yaml 83 | - model: resnet18 # this points to the file: model/resnet18.yaml 84 | - strategy: custom_strategy # points to strategy/custom_strategy.yaml 85 | 86 | [...] # rest of the necessary elements: dataset, server, misc 87 | ``` 88 | 89 | The above config can be found in `conf/base_v2.yaml`. 90 | 91 | ## Different Experiments in this repo 92 | 93 | This repo contains a collection of experiments, all parameterised via the Hydra config structure inside `conf/`. The current list of experiments is: 94 | 95 | ```bash 96 | # runs the default experiment using standard FedAvg 97 | python main.py 98 | 99 | # you can change the default strategy to point to another one. 100 | # This example points to the config `conf/strategy/strategy_model_saving.yaml`; 101 | # it essentially shows how you can keep track of elements in your experiment 102 | # and retrieve them (e.g. for saving them to disk) once simulation is completed 103 | python main.py strategy=strategy_model_saving 104 | 105 | # Overrides the config hardcoded in the @hydra decorator in main.py to point to `conf/base_v2` 106 | # this experiment uses the CustomFedAvg and shows how you can change the behaviour of how 107 | # clients are sampled, updates are aggregated, and the frequency at which the global model is evaluated 108 | python main.py --config-name=base_v2 109 | # If you'd like to run it with the cpu_client instead 110 | python main.py --config-name=base_v2 client=cpu_client 111 | 112 | # Run the `conf/base_kd.yaml` config to test a simple federated distillation setting 113 | # where the teacher is first pre-trained on the server and sent to the clients along with 114 | # the smaller student network (i.e. the one that's being trained in a federated manner) 115 | python main.py --config-name=base_kd 116 | 117 | # and if you still want to override some of the settings you can totally do so as shown earlier in the readme 118 | # will change the temperature used in FlowerClientWithKD's fit() method 119 | python main.py --config-name=base_kd strategy.kd_config.student_train.temperature=5 120 | ``` -------------------------------------------------------------------------------- /src/strategy.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | from copy import deepcopy 3 | from typing import Dict, Optional, Union, Tuple, List 4 | 5 | from hydra.utils import call, instantiate 6 | from hydra.core.hydra_config import HydraConfig 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import CIFAR10 11 | 12 | from tqdm import tqdm 13 | 14 | from flwr.server.strategy import FedAvg 15 | from flwr.common.typing import Parameters, FitIns, FitRes 16 | from flwr.server.client_proxy import ClientProxy 17 | from flwr.server.client_manager import ClientManager 18 | from flwr.common import Parameters, Scalar, parameters_to_ndarrays 19 | 20 | from .datasets import cifar10Transformation 21 | 22 | 23 | class CustomFedAvg(FedAvg): 24 | """My customised FedAvg Strategy. It inherits from FedAvg.
25 | The ideas implemented here are designed with FL simulation in mind 26 | and for test research ideas including, but not limited to, understanding 27 | how a new strategy would behave in scenarios with unusual client participation 28 | patterns (and therefore requiring special sampling); simulating client failure ( 29 | and therefore excluding certain updates from being aggregated). 30 | """ 31 | def __init__(self, num_rounds: int, eval_every_n: int=5, 32 | keep_ratio: float=0.5, drop_ratio: float=0.25, *args, **kwargs): 33 | 34 | self.num_rounds = num_rounds # total rounds 35 | self.eval_every_n = eval_every_n # global eval freq 36 | self.keep_ratio = keep_ratio # ratio of clients to resample in the following round 37 | self.client_update_drop = drop_ratio # ratio of client updates to discard from aggregation 38 | super().__init__(*args, **kwargs) 39 | 40 | def evaluate(self, server_round: int, parameters: Parameters): 41 | """Evaluates global model every N rounds. Last round is always 42 | considered and flagged as such (e.g. to use global test set)""" 43 | 44 | is_last_round = server_round == self.num_rounds 45 | 46 | if (server_round % self.eval_every_n == 0) or \ 47 | (server_round == self.num_rounds): 48 | parameters_ndarrays = parameters_to_ndarrays(parameters) 49 | loss, metrics = self.evaluate_fn(server_round, 50 | parameters_ndarrays, 51 | config={}, 52 | is_last_round=is_last_round) 53 | return loss, metrics 54 | else: 55 | print(f"Only evaluating every {self.eval_every_n} rounds...") 56 | return None 57 | 58 | def configure_fit(self, server_round: int, parameters: Parameters, 59 | client_manager: ClientManager): 60 | """Configure the next round of training. In the first round we sample 61 | N clients from the M available and track which clients have been sampled. 62 | In subsequent rounds we sample again 100*keep_ratio % of the previously sampled 63 | clients and sample the remaining (so we have N participants) out of the remaining 64 | ones. This is stochastic (it is likely not the same number of clients will always 65 | be kept for the next round)""" 66 | 67 | config = {} 68 | if self.on_fit_config_fn is not None: 69 | # Custom fit config function provided 70 | config = self.on_fit_config_fn(server_round) 71 | 72 | # construct instructions to be send to each client. 73 | # parameters: list of numpy arrays with model parameters 74 | # config: a python dictionary that parameterises the client's fit() method 75 | fit_ins = FitIns(parameters, config) 76 | 77 | # interface with the client manager to get statistics of the available pool of 78 | # clients that can be sampled in this given round. 
79 | av_clients = client_manager.num_available() 80 | sample_size, min_num_clients = self.num_fit_clients(av_clients) 81 | 82 | if server_round == 1: # first round, random uniform sampling (standard) 83 | 84 | clients = client_manager.sample( 85 | num_clients=sample_size, min_num_clients=min_num_clients) 86 | 87 | else: 88 | # stochastically drop clients used in previous round 89 | clients = [cli for cli in self.prev_clients if random() < self.keep_ratio] 90 | print(f"Round {server_round} will resample clients: {[client.cid for client in clients]}") 91 | 92 | # sample more clients 93 | extra_clients = client_manager.sample( 94 | num_clients=sample_size - len(clients), 95 | min_num_clients=min_num_clients 96 | ) 97 | # append 98 | clients.extend(extra_clients) 99 | 100 | # record client proxies 101 | self.prev_clients = clients 102 | 103 | print(f"Round {server_round} sampled clients with cid: {[client.cid for client in self.prev_clients]}") 104 | 105 | # Return client/config pairs 106 | return [(client, fit_ins) for client in clients] 107 | 108 | def aggregate_fit(self, server_round: int, 109 | results: List[Tuple[ClientProxy, FitRes]], 110 | failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]]): 111 | """Here we aggregate the results received from the clients and use them 112 | to update the global model state. Contrary to vanilla FedAvg, we stochastically 113 | drop a portion (`client_update_drop`) of the client updates received, then we aggregate 114 | using the ones kept. This could be used to simulate a number of scenarios, for example: 115 | errors in the communication channel when the clients send their updates to the server; 116 | or simulating client failure or disconnection. (Note that this same behaviour could 117 | instead be implemented in the client's fit() method, with the additional benefit of not 118 | requiring training if we knew that particular client was meant to be dropped anyway.)""" 119 | 120 | # We can iterate over the results received from the clients 121 | # and record each client's id 122 | client_ids = [int(client_prox.cid) for client_prox, _ in results] 123 | 124 | # Create a mask (True means the client's update is kept) 125 | drop_mask = [random() < (1.0 - self.client_update_drop) for _ in range(len(results))] 126 | dropped_cids = [cid for i, cid in enumerate(client_ids) if not drop_mask[i]] 127 | print(f"CIDs of clients dropped: {dropped_cids}") 128 | 129 | # keep only the results selected by the mask 130 | results = [res for i, res in enumerate(results) if drop_mask[i]] 131 | 132 | # call the parent `aggregate_fit()` (i.e. that in standard FedAvg) 133 | return super().aggregate_fit(server_round, results, failures) 134 | 135 | 136 | class CustomFedAvgWithKD(FedAvg): 137 | """My customised FedAvg Strategy for a simple setup with client-side distillation. 138 | The student model is the one federated as usual. The teacher model is first trained 139 | upon strategy creation and then sent to the client in each round (see `configure_fit`). 140 | The client uses the teacher to train the student locally using KD.""" 141 | def __init__(self, teacher, kd_config, *args, **kwargs): 142 | 143 | self.teacher_cfg = teacher # we store the config that can instantiate the teacher. We'll be sending this to the clients (in addition to the teacher weights) 144 | self.teacher = instantiate(teacher) # instantiate teacher 145 | self.kd_config = kd_config 146 | 147 | # pre-train the teacher (for the purpose of this example we'll just use a handful of batches 148 | # using the training set).
149 |         # in the early stages of FL training (hence serving our simple KD demo). Please note that
150 |         # you'd normally be doing the KD on a partition of data disjoint from the data that's federated.
151 |         # Likely this data would come from a common data distribution, so the KD is aligned.
152 |         self._unrealistically_but_effectively_pretrain_the_teacher()
153 | 
154 |         # no need to do anything else, call default behaviour from parent vanilla FedAvg
155 |         super().__init__(*args, **kwargs)
156 | 
157 | 
158 |     def _unrealistically_but_effectively_pretrain_the_teacher(self, path_to_data="./data"):
159 | 
160 |         # Do training as you'd normally do in a centralised setup
161 |         train_set = CIFAR10(root=path_to_data, train=True, download=True, transform=cifar10Transformation())
162 | 
163 |         trainloader = DataLoader(train_set, batch_size=self.kd_config.teacher_pretrain.batch_size, num_workers=4)
164 |         # instantiate optimiser as defined in the config, passing the teacher parameters
165 |         optim = instantiate(self.kd_config.teacher_pretrain.optim, params=self.teacher.parameters())
166 | 
167 |         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
168 |         criterion = torch.nn.CrossEntropyLoss()
169 |         self.teacher.to(device)
170 |         self.teacher.train()
171 | 
172 |         print(f"Pretraining teacher for {self.kd_config.teacher_pretrain.num_batches} batches of size {self.kd_config.teacher_pretrain.batch_size}")
173 |         with tqdm(total=len(train_set), desc='pseudo-pretraining teacher') as t:
174 |             for i, (images, labels) in enumerate(trainloader):
175 |                 images, labels = images.to(device), labels.to(device)
176 |                 optim.zero_grad()
177 |                 loss = criterion(self.teacher(images), labels)
178 |                 loss.backward()
179 |                 optim.step()
180 | 
181 |                 t.update(images.shape[0])
182 | 
183 |                 if i + 1 == self.kd_config.teacher_pretrain.num_batches:
184 |                     break
185 |         print("Teacher is pretrained")
186 | 
187 |     def _get_teacher_asarray(self):
188 |         self.teacher.cpu()
189 |         return [val.numpy() for _, val in self.teacher.state_dict().items()]
190 | 
191 |     def configure_fit(self, server_round: int, parameters: Parameters,
192 |                       client_manager: ClientManager):
193 |         """Configure the next round of training. Standard behaviour as in FedAvg,
194 |         but the fit instructions have been extended to include the teacher model and
195 |         the config that describes how to do KD on the client side."""
196 | 
197 |         config = {}
198 |         if self.on_fit_config_fn is not None:
199 |             # Custom fit config function provided
200 |             config = call(self.on_fit_config_fn)(server_round)
201 | 
202 |         # flatten and add teacher model and KD config to config dict
203 |         # this config will be received by each participating client
204 |         config['teacher_cfg'] = self.teacher_cfg
205 |         config['teacher_arrays'] = self._get_teacher_asarray()
206 |         config['KD_config'] = self.kd_config.student_train
207 | 
208 |         fit_ins = FitIns(parameters, config)
209 | 
210 |         # Sample clients
211 |         sample_size, min_num_clients = self.num_fit_clients(
212 |             client_manager.num_available()
213 |         )
214 |         clients = client_manager.sample(
215 |             num_clients=sample_size, min_num_clients=min_num_clients
216 |         )
217 | 
218 |         # Return client/config pairs
219 |         return [(client, fit_ins) for client in clients]
220 | 
221 |     def aggregate_fit(self, server_round: int,
222 |                       results: List[Tuple[ClientProxy, FitRes]],
223 |                       failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]]):
224 |         """Here we aggregate the results received from the clients and use them
225 |         to update the global model state. In this example we are receiving the updated
226 |         student networks (the student is the model being federated, while the teacher was
227 |         pre-trained at the beginning of the experiment and left unchanged ever since). This
228 |         custom `aggregate_fit` serves the purpose of showing how the metrics returned by
229 |         the clients can be easily extracted from the results."""
230 | 
231 |         # We can iterate over the results received from the clients
232 |         # and extract the metrics sent.
233 |         fit_metrics = [(c_prox.cid, res.metrics) for c_prox, res in results]
234 |         print(fit_metrics)  # will print a list of (cid, times) tuples
235 | 
236 |         # With the metrics you could modify how the aggregation is done (we don't do so
237 |         # here; instead we call the parent aggregation method, i.e., that of vanilla FedAvg).
238 |         # You could also record these metrics and use them to configure the next round. As
239 |         # usual, think twice about what client-side info you use during prototyping,
240 |         # since in real-world deployments you wouldn't want to use any of it unless it is
241 |         # certain that privacy is not compromised.
242 | 
243 |         # call the parent `aggregate_fit()` (i.e. that in standard FedAvg)
244 |         return super().aggregate_fit(server_round, results, failures)
245 | 
246 | 
247 | class CustomFedAvgWithModelSaving(FedAvg):
248 |     """This is a custom strategy that behaves exactly like FedAvg,
249 |     with the difference that it keeps track of the state of the global
250 |     model. In this way, the strategy can save the model to disk
251 |     after each evaluation. It also enables retrieving the model
252 |     once `start_simulation` is completed.
253 |     """
254 |     def __init__(self, *args, **kwargs):
255 |         self.global_parameters = None
256 |         super().__init__(*args, **kwargs)
257 | 
258 |     def _save_global_model(self, server_round: int, params):
259 | 
260 |         # output directory created by hydra for the current experiment
261 |         save_path = HydraConfig.get().runtime.output_dir
262 |         # TODO: save parameters, for instance as a pickle
263 |         print(f"(NOT IMPLEMENTED) Saved global model in round {server_round} into: {save_path}")
264 | 
265 |     def evaluate(self, server_round: int, parameters: Parameters) -> Tuple[float, Dict[str, Scalar]]:
266 |         loss, metrics = super().evaluate(server_round, parameters)
267 | 
268 |         # Here you could save your model parameters
269 |         # you could for instance also pass information about the loss or other
270 |         # metrics you obtained from the evaluate() stage completed in the line above
271 |         params = parameters_to_ndarrays(parameters)  # you likely want to convert them first to a list of NumPy arrays
272 |         self._save_global_model(server_round, params)
273 | 
274 |         # additionally, we can update the parameters being tracked as a
275 |         # class variable for this strategy. This will make it fairly
276 |         # straightforward to retrieve the model parameters once the
277 |         # simulation is completed.
278 |         self.global_parameters = deepcopy(params)
279 | 
280 |         # return the outputs from evaluate()
281 |         return loss, metrics
282 | 
--------------------------------------------------------------------------------
/src/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Adap GmbH. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Commonly used functions for generating partitioned datasets."""
16 | 
17 | # pylint: disable=invalid-name
18 | 
19 | 
20 | from typing import List, Optional, Tuple, Union
21 | 
22 | import numpy as np
23 | from numpy.random import BitGenerator, Generator, SeedSequence
24 | 
25 | XY = Tuple[np.ndarray, np.ndarray]
26 | XYList = List[XY]
27 | PartitionedDataset = Tuple[XYList, XYList]
28 | 
29 | np.random.seed(2020)
30 | 
31 | 
32 | def float_to_int(i: float) -> int:
33 |     """Return float as int but raise if decimal is dropped."""
34 |     if not i.is_integer():
35 |         raise Exception("Cast would drop decimals")
36 | 
37 |     return int(i)
38 | 
39 | 
40 | def sort_by_label(x: np.ndarray, y: np.ndarray) -> XY:
41 |     """Sort by label.
42 |     Assuming two labels and four examples the resulting label order
43 |     would be 1,1,2,2
44 |     """
45 |     idx = np.argsort(y, axis=0).reshape((y.shape[0]))
46 |     return (x[idx], y[idx])
47 | 
48 | 
49 | def sort_by_label_repeating(x: np.ndarray, y: np.ndarray) -> XY:
50 |     """Sort by label in repeating groups. Assuming two labels and four examples
51 |     the resulting label order would be 1,2,1,2.
52 |     Create a sorting index which is applied to the label-sorted x, y
53 |     .. code-block:: python
54 |         # given:
55 |         y = [
56 |             0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9
57 |         ]
58 |         # use:
59 |         idx = [
60 |             0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19
61 |         ]
62 |         # so that y[idx] becomes:
63 |         y = [
64 |             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
65 |         ]
66 |     """
67 |     x, y = sort_by_label(x, y)
68 | 
69 |     num_example = x.shape[0]
70 |     num_class = np.unique(y).shape[0]
71 |     idx = (
72 |         np.array(range(num_example), np.int64)
73 |         .reshape((num_class, num_example // num_class))
74 |         .transpose()
75 |         .reshape(num_example)
76 |     )
77 | 
78 |     return (x[idx], y[idx])
79 | 
80 | 
81 | def split_at_fraction(x: np.ndarray, y: np.ndarray, fraction: float) -> Tuple[XY, XY]:
82 |     """Split x, y at a certain fraction."""
83 |     splitting_index = float_to_int(x.shape[0] * fraction)
84 |     # Take everything BEFORE splitting_index
85 |     x_0, y_0 = x[:splitting_index], y[:splitting_index]
86 |     # Take everything AFTER splitting_index
87 |     x_1, y_1 = x[splitting_index:], y[splitting_index:]
88 |     return (x_0, y_0), (x_1, y_1)
89 | 
90 | 
91 | def shuffle(x: np.ndarray, y: np.ndarray) -> XY:
92 |     """Shuffle x and y."""
93 |     idx = np.random.permutation(len(x))
94 |     return x[idx], y[idx]
95 | 
96 | 
97 | def partition(x: np.ndarray, y: np.ndarray, num_partitions: int) -> List[XY]:
98 |     """Return x, y as list of partitions."""
99 |     return list(zip(np.split(x, num_partitions), np.split(y, num_partitions)))
100 | 
101 | 
102 | def combine_partitions(xy_list_0: XYList, xy_list_1: XYList) -> XYList:
103 |     """Combine two lists of ndarray Tuples into one list."""
104 |     return [
105 |         (np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0))
106 |         for (x_0, y_0), (x_1, y_1) in zip(xy_list_0, xy_list_1)
107 |     ]
108 | 
109 | 
110 | def shift(x: np.ndarray, y: np.ndarray) -> XY:
111 |     """Shift x_1, y_1 so that the first half contains only labels 0 to 4 and
112 |     the second half 5 to 9."""
113 |     x, y = sort_by_label(x, y)
114 | 
115 |     (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=0.5)
116 |     (x_0, y_0), (x_1, y_1) = shuffle(x_0, y_0), shuffle(x_1, y_1)
117 |     x, y = np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0)
118 |     return x, y
119 | 
120 | 
121 | def create_partitions(
122 |     unpartitioned_dataset: XY,
123 |     iid_fraction: float,
124 |     num_partitions: int,
125 | ) -> XYList:
126 |     """Create partitioned version of a training or test set.
127 |     Currently tested and supported are MNIST, FashionMNIST and
128 |     CIFAR-10/100
129 |     """
130 |     x, y = unpartitioned_dataset
131 | 
132 |     x, y = shuffle(x, y)
133 |     x, y = sort_by_label_repeating(x, y)
134 | 
135 |     (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=iid_fraction)
136 | 
137 |     # Shift in second split of dataset the classes into two groups
138 |     x_1, y_1 = shift(x_1, y_1)
139 | 
140 |     xy_0_partitions = partition(x_0, y_0, num_partitions)
141 |     xy_1_partitions = partition(x_1, y_1, num_partitions)
142 | 
143 |     xy_partitions = combine_partitions(xy_0_partitions, xy_1_partitions)
144 | 
145 |     # Adjust x and y shape
146 |     return [adjust_xy_shape(xy) for xy in xy_partitions]
147 | 
148 | 
149 | def create_partitioned_dataset(
150 |     keras_dataset: Tuple[XY, XY],
151 |     iid_fraction: float,
152 |     num_partitions: int,
153 | ) -> Tuple[PartitionedDataset, XY]:
154 |     """Create partitioned version of keras dataset.
155 |     Currently tested and supported are MNIST, FashionMNIST and
156 |     CIFAR-10/100
157 |     """
158 |     xy_train, xy_test = keras_dataset
159 | 
160 |     xy_train_partitions = create_partitions(
161 |         unpartitioned_dataset=xy_train,
162 |         iid_fraction=iid_fraction,
163 |         num_partitions=num_partitions,
164 |     )
165 | 
166 |     xy_test_partitions = create_partitions(
167 |         unpartitioned_dataset=xy_test,
168 |         iid_fraction=iid_fraction,
169 |         num_partitions=num_partitions,
170 |     )
171 | 
172 |     return (xy_train_partitions, xy_test_partitions), adjust_xy_shape(xy_test)
173 | 
174 | 
175 | def log_distribution(xy_partitions: XYList) -> None:
176 |     """Print label distribution for a list of partitions."""
177 |     distro = [np.unique(y, return_counts=True) for _, y in xy_partitions]
178 |     for d in distro:
179 |         print(d)
180 | 
181 | 
182 | def adjust_xy_shape(xy: XY) -> XY:
183 |     """Adjust shape of both x and y."""
184 |     x, y = xy
185 |     if x.ndim == 3:
186 |         x = adjust_x_shape(x)
187 |     if y.ndim == 2:
188 |         y = adjust_y_shape(y)
189 |     return (x, y)
190 | 
191 | 
192 | def adjust_x_shape(nda: np.ndarray) -> np.ndarray:
193 |     """Turn shape (x, y, z) into (x, y, z, 1)."""
194 |     nda_adjusted = np.reshape(nda, (nda.shape[0], nda.shape[1], nda.shape[2], 1))
195 |     return nda_adjusted
196 | 
197 | 
198 | def adjust_y_shape(nda: np.ndarray) -> np.ndarray:
199 |     """Turn shape (x, 1) into (x)."""
200 |     nda_adjusted = np.reshape(nda, (nda.shape[0]))
201 |     return nda_adjusted
202 | 
203 | 
204 | def split_array_at_indices(
205 |     x: np.ndarray, split_idx: np.ndarray
206 | ) -> List[List[np.ndarray]]:
207 |     """Splits an array `x` into a list of elements using starting indices from
208 |     `split_idx`.
209 |     This function should be used with `unique_indices` from `np.unique()` after
210 |     sorting by label.
211 |     Args:
212 |         x (np.ndarray): Original array of dimension (N,a,b,c,...)
213 |         split_idx (np.ndarray): 1-D array containing increasing indices to be
214 |             used as partition boundaries. Initial value must be zero. Last value
215 |             must be less than N.
216 |     Returns:
217 |         List[List[np.ndarray]]: List of list of samples.
218 |     """
219 | 
220 |     if split_idx.ndim != 1:
221 |         raise ValueError("Variable `split_idx` must be a 1-D numpy array.")
222 |     if split_idx.dtype != np.int64:
223 |         raise ValueError("Variable `split_idx` must be of type np.int64.")
224 |     if split_idx[0] != 0:
225 |         raise ValueError("First value of `split_idx` must be 0.")
226 |     if split_idx[-1] >= x.shape[0]:
227 |         raise ValueError(
228 |             """Last value in `split_idx` must be less than
229 |             the number of samples in `x`."""
230 |         )
231 |     if not np.all(split_idx[:-1] <= split_idx[1:]):
232 |         raise ValueError("Items in `split_idx` must be in increasing order.")
233 | 
234 |     num_splits: int = len(split_idx)
235 |     split_idx = np.append(split_idx, x.shape[0])
236 | 
237 |     list_samples_split: List[List[np.ndarray]] = [[] for _ in range(num_splits)]
238 |     for j in range(num_splits):
239 |         tmp_x = x[split_idx[j] : split_idx[j + 1]]  # noqa: E203
240 |         for sample in tmp_x:
241 |             list_samples_split[j].append(sample)
242 | 
243 |     return list_samples_split
244 | 
245 | 
246 | def exclude_classes_and_normalize(
247 |     distribution: np.ndarray, exclude_dims: List[bool], eps: float = 1e-5
248 | ) -> np.ndarray:
249 |     """Excludes classes from a distribution.
250 |     This function is particularly useful when sampling without replacement.
251 |     Classes for which no sample is available have their probabilities set to 0.
252 |     Classes that had probabilities originally set to 0 are incremented with
253 |     `eps` to allow sampling from remaining items.
254 |     Args:
255 |         distribution (np.array): Distribution being used.
256 |         exclude_dims (List[bool]): Dimensions to be excluded.
257 |         eps (float, optional): Small value to be added to non-excluded dimensions.
258 |             Defaults to 1e-5.
259 |     Returns:
260 |         np.ndarray: Normalized distributions.
261 |     """
262 |     if np.any(distribution < 0) or (not np.isclose(np.sum(distribution), 1.0)):
263 |         raise ValueError("distribution must sum to 1 and have only positive values.")
264 | 
265 |     if distribution.size != len(exclude_dims):
266 |         raise ValueError(
267 |             """Length of distribution must be equal
268 |             to the length of `exclude_dims`."""
269 |         )
270 |     if eps < 0:
271 |         raise ValueError("""The value of `eps` must be positive and small.""")
272 | 
273 |     distribution[[not x for x in exclude_dims]] += eps
274 |     distribution[exclude_dims] = 0.0
275 |     sum_rows = np.sum(distribution) + np.finfo(float).eps
276 |     distribution = distribution / sum_rows
277 | 
278 |     return distribution
279 | 
280 | 
281 | def sample_without_replacement(
282 |     distribution: np.ndarray,
283 |     list_samples: List[List[np.ndarray]],
284 |     num_samples: int,
285 |     empty_classes: List[bool],
286 | ) -> Tuple[XY, List[bool]]:
287 |     """Samples from a list without replacement using a given distribution.
288 |     Args:
289 |         distribution (np.ndarray): Distribution used for sampling.
290 |         list_samples (List[List[np.ndarray]]): List of samples.
291 |         num_samples (int): Total number of items to be sampled.
292 |         empty_classes (List[bool]): List of booleans indicating which classes are empty.
293 |             This is useful to differentiate which classes should still be sampled.
294 |     Returns:
295 |         XY: Dataset containing samples
296 |         List[bool]: empty_classes.
297 |     """
298 |     if np.sum([len(x) for x in list_samples]) < num_samples:
299 |         raise ValueError(
300 |             """Number of samples in `list_samples` is less than `num_samples`"""
301 |         )
302 | 
303 |     # Make sure empty classes are not sampled
304 |     # and handle the case where `empty_classes` was not provided
305 |     if not empty_classes:
306 |         empty_classes = len(distribution) * [False]
307 | 
308 |     distribution = exclude_classes_and_normalize(
309 |         distribution=distribution, exclude_dims=empty_classes
310 |     )
311 | 
312 |     data: List[np.ndarray] = []
313 |     target: List[np.ndarray] = []
314 | 
315 |     for _ in range(num_samples):
316 |         sample_class = np.where(np.random.multinomial(1, distribution) == 1)[0][0]
317 |         sample: np.ndarray = list_samples[sample_class].pop()
318 | 
319 |         data.append(sample)
320 |         target.append(sample_class)
321 | 
322 |         # If last sample of the class was drawn, then set the
323 |         # probability density function (PDF) to zero for that class.
324 |         if len(list_samples[sample_class]) == 0:
325 |             empty_classes[sample_class] = True
326 |             # Be careful to distinguish between classes that had zero probability
327 |             # and classes that are now empty
328 |             distribution = exclude_classes_and_normalize(
329 |                 distribution=distribution, exclude_dims=empty_classes
330 |             )
331 |     data_array: np.ndarray = np.concatenate([data], axis=0)
332 |     target_array: np.ndarray = np.array(target, dtype=np.int64)
333 | 
334 |     return (data_array, target_array), empty_classes
335 | 
336 | 
337 | def get_partitions_distributions(partitions: XYList) -> Tuple[np.ndarray, List[int]]:
338 |     """Evaluates the distribution over classes for a set of partitions.
339 |     Args:
340 |         partitions (XYList): Input partitions
341 |     Returns:
342 |         Tuple[np.ndarray, List[int]]: Distributions of size (num_partitions, num_classes) and the sorted list of labels.
343 |     """
344 |     # Collect the set of labels present across partitions
345 |     labels = set()
346 |     for _, y in partitions:
347 |         labels.update(set(y))
348 |     list_labels = sorted(list(labels))
349 |     bin_edges = np.arange(len(list_labels) + 1)
350 | 
351 |     # Pre-allocate distributions
352 |     distributions = np.zeros((len(partitions), len(list_labels)), dtype=np.float32)
353 |     for idx, (_, _y) in enumerate(partitions):
354 |         hist, _ = np.histogram(_y, bin_edges)
355 |         distributions[idx] = hist / hist.sum()
356 | 
357 |     return distributions, list_labels
358 | 
359 | 
360 | def create_lda_partitions(
361 |     dataset: XY,
362 |     dirichlet_dist: Optional[np.ndarray] = None,
363 |     num_partitions: int = 100,
364 |     concentration: Union[float, np.ndarray, List[float]] = 0.5,
365 |     accept_imbalanced: bool = False,
366 |     seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None,
367 | ) -> Tuple[XYList, np.ndarray]:
368 |     """Create imbalanced non-iid partitions using Latent Dirichlet Allocation
369 |     (LDA) without resampling.
370 |     Args:
371 |         dataset (XY): Dataset containing samples X and labels Y.
372 |         dirichlet_dist (numpy.ndarray, optional): previously generated distribution to
373 |             be used. This is useful when applying the same distribution for train and
374 |             validation sets.
375 |         num_partitions (int, optional): Number of partitions to be created.
376 |             Defaults to 100.
377 |         concentration (float, np.ndarray, List[float]): Dirichlet Concentration
378 |             (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions.
379 |             An :math:`\\alpha \\to \\infty` generates uniform distributions over classes.
380 |             An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5.
381 |         accept_imbalanced (bool): Whether or not to accept imbalanced output classes.
382 |             Default False.
383 |         seed (None, int, SeedSequence, BitGenerator, Generator):
384 |             A seed to initialize the BitGenerator for generating the Dirichlet
385 |             distribution. This is defined in Numpy's official documentation as follows:
386 |             If None, then fresh, unpredictable entropy will be pulled from the OS.
387 |             One may also pass in a SeedSequence instance.
388 |             Additionally, when passed a BitGenerator, it will be wrapped by Generator.
389 |             If passed a Generator, it will be returned unaltered.
390 |             See official Numpy Documentation for further details.
391 |     Returns:
392 |         Tuple[XYList, numpy.ndarray]: List of XYList containing partitions
393 |         for each dataset and the dirichlet probability density functions.
394 |     """
395 |     # pylint: disable=too-many-arguments,too-many-locals
396 | 
397 |     x, y = dataset
398 |     x, y = shuffle(x, y)
399 |     x, y = sort_by_label(x, y)
400 | 
401 |     if (x.shape[0] % num_partitions) and (not accept_imbalanced):
402 |         raise ValueError(
403 |             """Total number of samples must be a multiple of `num_partitions`.
404 |             If imbalanced classes are allowed, set
405 |             `accept_imbalanced=True`."""
406 |         )
407 | 
408 |     num_samples = num_partitions * [0]
409 |     for j in range(x.shape[0]):
410 |         num_samples[j % num_partitions] += 1
411 | 
412 |     # Get the classes present and the index at which each first appears (y is sorted by label)
413 |     classes, start_indices = np.unique(y, return_index=True)
414 | 
415 |     # Make sure that concentration is np.array and
416 |     # check if concentration is appropriate
417 |     concentration = np.asarray(concentration)
418 | 
419 |     # Check if concentration is Inf, if so create uniform partitions
420 |     partitions: List[XY] = [(_, _) for _ in range(num_partitions)]
421 |     if float("inf") in concentration:
422 | 
423 |         partitions = create_partitions(
424 |             unpartitioned_dataset=(x, y),
425 |             iid_fraction=1.0,
426 |             num_partitions=num_partitions,
427 |         )
428 |         dirichlet_dist = get_partitions_distributions(partitions)[0]
429 | 
430 |         return partitions, dirichlet_dist
431 | 
432 |     if concentration.size == 1:
433 |         concentration = np.repeat(concentration, classes.size)
434 |     elif concentration.size != classes.size:  # Sequence
435 |         raise ValueError(
436 |             f"The size of the provided concentration ({concentration.size}) "
437 |             f"must be either 1 or equal to the number of classes ({classes.size})."
438 |         )
439 | 
440 |     # Split into list of list of samples per class
441 |     list_samples_per_class: List[List[np.ndarray]] = split_array_at_indices(
442 |         x, start_indices
443 |     )
444 | 
445 |     if dirichlet_dist is None:
446 |         dirichlet_dist = np.random.default_rng(seed).dirichlet(
447 |             alpha=concentration, size=num_partitions
448 |         )
449 | 
450 |     if dirichlet_dist.size != 0:
451 |         if dirichlet_dist.shape != (num_partitions, classes.size):
452 |             raise ValueError(
453 |                 f"""The shape of the provided dirichlet distribution
454 |                 ({dirichlet_dist.shape}) must match the provided number
455 |                 of partitions and classes ({num_partitions},{classes.size})"""
456 |             )
457 | 
458 |     # Assuming balanced distribution
459 |     empty_classes = classes.size * [False]
460 |     for partition_id in range(num_partitions):
461 |         partitions[partition_id], empty_classes = sample_without_replacement(
462 |             distribution=dirichlet_dist[partition_id].copy(),
463 |             list_samples=list_samples_per_class,
464 |             num_samples=num_samples[partition_id],
465 |             empty_classes=empty_classes,
466 |         )
467 | 
468 |     return partitions, dirichlet_dist
--------------------------------------------------------------------------------