├── .DS_Store ├── .vscode └── settings.json ├── images ├── pods.png ├── selectors.png └── kubeflower.png ├── out ├── .DS_Store └── operator │ ├── .DS_Store │ └── descriptor │ └── descriptor.png ├── .gitignore ├── descriptors ├── .DS_Store ├── serverService.yaml ├── claimerDeploy.yaml ├── volumeClaim.yaml ├── serverDeploy.yaml ├── clientDeploy.yaml ├── clientDeploy_multinode.yaml └── copier.yaml ├── requirements.txt ├── dockerfile ├── src ├── server.py └── client.py └── README.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/.DS_Store -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "off" 3 | } -------------------------------------------------------------------------------- /images/pods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/images/pods.png -------------------------------------------------------------------------------- /out/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/out/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | socket_vmnet/ 2 | data/ 3 | */data/ 4 | __pycache__/ 5 | *.py[cod] 6 | *.tar 7 | -------------------------------------------------------------------------------- /images/selectors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/images/selectors.png 
-------------------------------------------------------------------------------- /descriptors/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/descriptors/.DS_Store -------------------------------------------------------------------------------- /images/kubeflower.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/images/kubeflower.png -------------------------------------------------------------------------------- /out/operator/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/out/operator/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | numpy 3 | flwr 4 | torch 5 | torchvision 6 | torchaudio 7 | optuna 8 | opacus 9 | matplotlib -------------------------------------------------------------------------------- /out/operator/descriptor/descriptor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpn-bristol/kubeFlower/HEAD/out/operator/descriptor/descriptor.png -------------------------------------------------------------------------------- /descriptors/serverService.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: service-server 5 | spec: 6 | selector: 7 | app: flower-server 8 | type: ClusterIP 9 | ports: 10 | - port: 30051 11 | targetPort: 8080 -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM 
python:3.9.8-slim-bullseye 2 | WORKDIR /app 3 | RUN /usr/local/bin/python -m pip install --upgrade pip 4 | COPY ./requirements.txt . 5 | RUN pip install --no-cache-dir --upgrade -r requirements.txt 6 | COPY ./src src 7 | CMD ["/bin/sh", "-c", "while sleep 1000; do :; done"] -------------------------------------------------------------------------------- /descriptors/claimerDeploy.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Pod 3 | metadata: 4 | name: my-pod 5 | spec: 6 | containers: 7 | - name: my-container 8 | image: flower:latest 9 | imagePullPolicy: IfNotPresent 10 | volumeMounts: 11 | - name: my-volume 12 | mountPath: /usr/data/flower 13 | volumes: 14 | - name: my-volume 15 | hostPath: 16 | path: /data/flower 17 | type: Directory 18 | -------------------------------------------------------------------------------- /descriptors/volumeClaim.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: my-pv 5 | spec: 6 | storageClassName: standard 7 | capacity: 8 | storage: 5Gi 9 | accessModes: 10 | - ReadWriteOnce 11 | hostPath: 12 | path: /data/flower 13 | --- 14 | apiVersion: v1 15 | kind: PersistentVolumeClaim 16 | metadata: 17 | name: my-pvc 18 | spec: 19 | accessModes: 20 | - ReadWriteOnce 21 | resources: 22 | requests: 23 | storage: 5Gi 24 | selector: 25 | matchLabels: 26 | pv-name: my-pv 27 | -------------------------------------------------------------------------------- /descriptors/serverDeploy.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: flower-server 5 | labels: 6 | app: flower-server 7 | spec: 8 | replicas: 1 9 | selector: 10 | matchLabels: 11 | app: flower-server 12 | template: 13 | metadata: 14 | labels: 15 | app: flower-server 16 | spec: 17 | containers: 18 | - name: 
kubeflower 19 | image: kubeflower:latest 20 | imagePullPolicy: IfNotPresent 21 | command: ["/bin/sh", "-c"] 22 | args: ["python ./src/server.py"] 23 | ports: 24 | - containerPort: 8080 25 | 26 | -------------------------------------------------------------------------------- /descriptors/clientDeploy.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: flower-client 5 | labels: 6 | app: flower-client 7 | spec: 8 | replicas: 2 9 | selector: 10 | matchLabels: 11 | app: flower-client 12 | template: 13 | metadata: 14 | labels: 15 | app: flower-client 16 | spec: 17 | containers: 18 | - name: kubeflower 19 | image: kubeflower:latest 20 | imagePullPolicy: IfNotPresent 21 | command: ["/bin/sh", "-c"] 22 | args: ["python ./src/client.py --server 'service-server' --port '30051'"] 23 | ports: 24 | - containerPort: 30051 25 | -------------------------------------------------------------------------------- /descriptors/clientDeploy_multinode.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: flower-client 5 | labels: 6 | app: flower-client 7 | spec: 8 | replicas: 2 9 | strategy: 10 | type: RollingUpdate 11 | rollingUpdate: 12 | maxUnavailable: 100% 13 | selector: 14 | matchLabels: 15 | app: flower-client 16 | template: 17 | metadata: 18 | labels: 19 | app: flower-client 20 | spec: 21 | affinity: 22 | # ⬇⬇⬇ This ensures pods will land on separate hosts 23 | podAntiAffinity: 24 | requiredDuringSchedulingIgnoredDuringExecution: 25 | - labelSelector: 26 | matchExpressions: [{ key: app, operator: In, values: [flower-client] }] 27 | topologyKey: "kubernetes.io/hostname" 28 | containers: 29 | - name: kubeflower 30 | image: kubeflower:latest 31 | imagePullPolicy: IfNotPresent 32 | command: ["/bin/sh", "-c"] 33 | args: ["python ./src/client.py --server 'service-server' --port 
'30051'"] 34 | ports: 35 | - containerPort: 30051 36 | terminationGracePeriodSeconds: 1 -------------------------------------------------------------------------------- /descriptors/copier.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: my-pv 5 | spec: 6 | storageClassName: standard 7 | capacity: 8 | storage: 5Gi 9 | accessModes: 10 | - ReadWriteOnce 11 | hostPath: 12 | path: /data/flower 13 | --- 14 | apiVersion: v1 15 | kind: PersistentVolumeClaim 16 | metadata: 17 | name: my-pvc 18 | spec: 19 | accessModes: 20 | - ReadWriteOnce 21 | resources: 22 | requests: 23 | storage: 5Gi 24 | selector: 25 | matchLabels: 26 | pv-name: my-pv 27 | --- 28 | apiVersion: v1 29 | kind: Pod 30 | metadata: 31 | name: my-pod 32 | spec: 33 | initContainers: 34 | - name: init 35 | image: busybox 36 | command: ['sh', '-c', 'cp -r /data/* /mnt/data'] 37 | volumeMounts: 38 | - name: data 39 | mountPath: /mnt/data 40 | - name: my-pv 41 | mountPath: /data 42 | containers: 43 | - name: my-container 44 | image: flower:latest 45 | volumeMounts: 46 | - name: my-pv 47 | mountPath: /data 48 | volumes: 49 | - name: my-pv 50 | persistentVolumeClaim: 51 | claimName: my-pvc 52 | - name: data 53 | hostPath: 54 | path: /data/flower -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import flwr as fl 4 | from flwr.common import Metrics 5 | import argparse 6 | 7 | 8 | # Define metric aggregation function 9 | def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: 10 | # Multiply accuracy of each client by number of examples used 11 | accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] 12 | examples = [num_examples for num_examples, _ in metrics] 13 | 14 | # Aggregate and return custom metric 
(weighted average) 15 | return {"accuracy": sum(accuracies) / sum(examples)} 16 | 17 | # Parse inputs 18 | parser = argparse.ArgumentParser(description="Launches the FL server.") 19 | parser.add_argument('-clients',"--clients", type=int, default=2, help="Define the number of clients to be part of the FL process",) 20 | parser.add_argument('-min',"--min", type=int, default=2, help="Minimum number of available clients",) 21 | parser.add_argument('-rounds',"--rounds", type=int, default=5, help="Number of FL rounds",) 22 | args = vars(parser.parse_args()) 23 | num_clients = args['clients'] 24 | min_clients = args['min'] 25 | rounds = args['rounds'] 26 | 27 | # Define strategy 28 | strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average, min_fit_clients = num_clients, min_available_clients=min_clients) 29 | 30 | # Start Flower server 31 | fl.server.start_server( 32 | server_address="0.0.0.0:8080", 33 | config=fl.server.ServerConfig(num_rounds=rounds), 34 | strategy=strategy, 35 | ) -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision.datasets import CIFAR10 9 | from torchvision.transforms import Compose, Normalize, ToTensor 10 | from tqdm import tqdm 11 | import argparse 12 | 13 | 14 | # ############################################################################# 15 | # 1. 
Regular PyTorch pipeline: nn.Module, train, test, and DataLoader 16 | # ############################################################################# 17 | 18 | warnings.filterwarnings("ignore", category=UserWarning) 19 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | class Net(nn.Module): 23 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 24 | 25 | def __init__(self) -> None: 26 | super(Net, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 6, 5) 28 | self.pool = nn.MaxPool2d(2, 2) 29 | self.conv2 = nn.Conv2d(6, 16, 5) 30 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 31 | self.fc2 = nn.Linear(120, 84) 32 | self.fc3 = nn.Linear(84, 10) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | x = self.pool(F.relu(self.conv1(x))) 36 | x = self.pool(F.relu(self.conv2(x))) 37 | x = x.view(-1, 16 * 5 * 5) 38 | x = F.relu(self.fc1(x)) 39 | x = F.relu(self.fc2(x)) 40 | return self.fc3(x) 41 | 42 | 43 | def train(net, trainloader, epochs): 44 | """Train the model on the training set.""" 45 | criterion = torch.nn.CrossEntropyLoss() 46 | optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 47 | for _ in range(epochs): 48 | for images, labels in tqdm(trainloader): 49 | optimizer.zero_grad() 50 | criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() 51 | optimizer.step() 52 | 53 | 54 | def test(net, testloader): 55 | """Validate the model on the test set.""" 56 | criterion = torch.nn.CrossEntropyLoss() 57 | correct, total, loss = 0, 0, 0.0 58 | with torch.no_grad(): 59 | for images, labels in tqdm(testloader): 60 | outputs = net(images.to(DEVICE)) 61 | labels = labels.to(DEVICE) 62 | loss += criterion(outputs, labels).item() 63 | total += labels.size(0) 64 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 65 | return loss / len(testloader.dataset), correct / total 66 | 67 | 68 | def load_data(): 69 | """Load CIFAR-10 (training and test set).""" 70 | print(f'Loading data 
from {datapath}') 71 | trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 72 | trainset = CIFAR10(datapath, train=True, download=True, transform=trf) 73 | testset = CIFAR10(datapath, train=False, download=True, transform=trf) 74 | return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) 75 | 76 | 77 | # ############################################################################# 78 | # 2. Federation of the pipeline with Flower 79 | # ############################################################################# 80 | 81 | # Load model and data (simple CNN, CIFAR-10) 82 | parser = argparse.ArgumentParser(description="Launches FL clients.") 83 | parser.add_argument('-cid',"--cid", type=int, default=0, help="Define Client_ID",) 84 | parser.add_argument('-server',"--server", default="0.0.0.0", help="Server Address",) 85 | parser.add_argument('-port',"--port", default="8080", help="Server Port",) 86 | parser.add_argument('-data', "--data", default="./data", help="Dataset source path") 87 | args = vars(parser.parse_args()) 88 | cid = args['cid'] 89 | server = args['server'] 90 | port = args['port'] 91 | datapath = args['data'] 92 | net = Net().to(DEVICE) 93 | trainloader, testloader = load_data() 94 | 95 | # Define Flower client 96 | class FlowerClient(fl.client.NumPyClient): 97 | def get_parameters(self, config): 98 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 99 | 100 | def set_parameters(self, parameters): 101 | params_dict = zip(net.state_dict().keys(), parameters) 102 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 103 | net.load_state_dict(state_dict, strict=True) 104 | 105 | def fit(self, parameters, config): 106 | self.set_parameters(parameters) 107 | train(net, trainloader, epochs=1) 108 | return self.get_parameters(config={}), len(trainloader.dataset), {} 109 | 110 | def evaluate(self, parameters, config): 111 | self.set_parameters(parameters) 112 | loss, accuracy = 
test(net, testloader) 113 | return loss, len(testloader.dataset), {"accuracy": accuracy} 114 | 115 | print(f"Subscribing to FL server {server} on port {port}...") 116 | # Start Flower client 117 | fl.client.start_numpy_client( 118 | server_address=f"{server}:{port}", 119 | client=FlowerClient(), 120 | ) 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KubeFlower: Kubernetes-based Federated Learning 2 | 3 | *** The extension of this work can be found in [KubeFlower-Operator](https://github.com/REASON-6G/kubeflower-operator/tree/main) *** 4 | 5 | ## What is KubeFlower? 6 | Kubeflower is a project for exploiting the benefits of cloud-native and container-based technologies for the development, deployment and workload management of Federated Learning (FL) pipelines. We use the open-source framework [Flower](https://flower.dev/) for the FL workload control. Flower has been widely adopted in industry and academia. In order to increase computation elasticity and efficiency when deploying FL, we use the container orchestration system [Kubernetes](https://kubernetes.io/) (K8s). We use different concepts such as FL servers, FL clients, K8s clusters, K8s deployments, K8s pods, and K8s services. If you are not familiar with this terminology, please watch the following resources: [Federated Learning](https://youtu.be/nBGQQHPkyNY), [Kubernetes](https://youtu.be/s_o8dwzRlu4). 7 | 8 | ## Top-Level Features 9 | * Single and multi-node implementation. 10 | * High availability through clustering and distributed state management. 11 | * Scalability through clustering of network device control. 12 | * CLI for debugging. 13 | * Applicable to real-world scenarios. 14 | * Extendable. 15 | * Cross-platform (Linux, macOS, Windows). 
16 | 17 | ## Getting started 18 | 19 | ### Dependencies 20 | 21 | For this proof-of-concept, a K8s cluster is deployed locally using minikube. The following packages are required and should be installed beforehand: 22 | * [git](https://git-scm.com/) 23 | * [docker](https://www.docker.com/) 24 | * [minikube](https://minikube.sigs.k8s.io/docs/) 25 | 26 | ### Step-by-step setup 27 | 1. Clone this repository from the CLI. 28 | ```bash 29 | git clone git@github.com:hpn-bristol/kubeFlower.git 30 | ``` 31 | 2. Go to the folder that contains KubeFlower. 32 | 3. Deploy a K8s cluster in minikube. 33 | ```bash 34 | minikube start 35 | ``` 36 | 4. Point your terminal to the docker daemon inside minikube (this requires the cluster from the previous step to be running). 37 | ```bash 38 | eval $(minikube docker-env) 39 | ``` 40 | 5. Check minikube docker images. 41 | ```bash 42 | minikube image list 43 | ``` 44 | You will find a list of standard K8s docker images for management. For example, k8s.gcr.io/kube-scheduler, k8s.gcr.io/kube-controller-manager, etc. 45 | 46 | 6. Build the docker image from this repo (dockerfile) with the required packages (requirements.txt). This image is based on python:3.9.8-slim-bullseye. 47 | ```bash 48 | minikube image build -t kubeflower . 49 | ``` 50 | where your image will be called `kubeflower` and the build source `.` will be the current folder. 51 | 52 | 7. Check that your image has been successfully added to minikube. 53 | ```bash 54 | minikube image list 55 | ``` 56 | Check that `kubeflower:latest` is in the list, where `latest` is the tag assigned to the docker image by default. 57 | 58 | ### Step-by-step deployment 59 | Now you are ready to deploy the FL pipeline using K8s. We will be using K8s deployments to create K8s pods that will use a K8s service for communications. Each pod represents an FL actor, with a main pod acting as the FL server. The proposed architecture is depicted in the figure. 
60 | 61 | ![](images/kubeflower.png) 62 | 63 | The docker image `kubeflower` is used to deploy the containers with the Flower pipeline and other dependencies. These containers are deployed in pods. The FL server pod exposes port 8080 for the gRPC communication implemented by Flower. Instead of using a predefined IP for the server, we use a K8s `ClusterIP` service, which lets the clients locate the FL server pod even if it restarts and changes its IP. The service exposes port 30051 and forwards it to port 8080 on the server pod, so the FL client pods can reach the server at `service-server:30051`. For the FL setup, we use the PyTorch implementation of Flower. This simple example can be found [here](https://flower.dev/docs/quickstart-pytorch.html). 64 | 65 | To deploy this architecture you need to: 66 | 67 | 1. Deploy the `service-server` K8s service. From the root folder run: 68 | ```bash 69 | kubectl apply -f descriptors/serverService.yaml 70 | ``` 71 | 72 | We are using ClusterIP, but it can be replaced with a NodePort or LoadBalancer if specific communications are required. 73 | 74 | 2. Deploy the FL server pod through the K8s deployment. 75 | ```bash 76 | kubectl apply -f descriptors/serverDeploy.yaml 77 | ``` 78 | By default, the server will start a run of 5 rounds when 2 clients are available. To change these values, edit the `serverDeploy.yaml` file and pass the new values as arguments in the line ```args: ["python ./src/server.py"]```. The available flags are: --clients, --min, --rounds. 79 | 3. Check the SELECTOR for both the service and deployment. They should match `app=flower-server`. 80 | ```bash 81 | kubectl get all -owide 82 | ``` 83 | 84 | ![](images/selectors.png) 85 | 4. Deploy FL clients using the clientDeploy.yaml descriptor. 86 | ```bash 87 | kubectl apply -f descriptors/clientDeploy.yaml 88 | ``` 89 | 90 | By default, this descriptor will deploy 2 clients. To increase the number of clients, edit the `replicas: 2` value in the .yaml file. 91 | 5. 
Monitor the training process. 92 | ```bash 93 | kubectl get all 94 | ``` 95 | Get the pod IDs. 96 | 97 | ![](images/pods.png) 98 | 99 | Check the logs on the ```flower-server``` pod. 100 | ```bash 101 | kubectl logs flower-server-64f78b8c5c-kwf89 -f 102 | ``` 103 | 104 | Open a new terminal and check the logs on the ```flower-client``` pods. Repeat the process if required for the different clients. 105 | ```bash 106 | kubectl logs flower-client-7c69c8c776-cjw6r -f 107 | ``` 108 | 109 | 6. After the FL process has finished, delete the deployments and the service, and stop the K8s cluster on minikube. 110 | ```bash 111 | kubectl delete deploy flower-client flower-server 112 | 113 | kubectl delete service service-server 114 | 115 | minikube stop 116 | ``` 117 | This is a simple implementation of container-based FL using Flower and K8s for orchestration. For further discussions/ideas/projects, please contact the developers at the Smart Internet Lab. 118 | 119 | --------------------------------------------------------------------------------
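
The accuracy reported in the server logs comes from the `weighted_average` function in `src/server.py`, which weights each client's accuracy by the number of examples it evaluated. The snippet below is a minimal standalone sketch of that aggregation. It assumes a plain `dict` in place of Flower's `Metrics` type so that it runs without `flwr` installed, and the two client reports are made-up values for illustration only, not results from a real run.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate per-client accuracy, weighted by each client's example count."""
    # Multiply the accuracy of each client by the number of examples it used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    # Divide by the total number of examples to get the weighted mean
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical reports from two clients: (num_examples, metrics)
client_reports = [
    (100, {"accuracy": 0.5}),
    (300, {"accuracy": 0.75}),
]

result = weighted_average(client_reports)
print(result)  # (100 * 0.5 + 300 * 0.75) / 400 = 0.6875
```

Because the second client evaluated three times as many examples, its accuracy pulls the aggregate closer to 0.75 than a plain mean (0.625) would.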