├── federated-learning ├── datasets │ ├── __init__.py │ ├── IMAGENET.py │ ├── REALWORLD.py │ └── UCI.py ├── models │ ├── __init__.py │ ├── Fed.py │ ├── Train.py │ └── Test.py ├── utils │ ├── __init__.py │ ├── EnvStore.py │ ├── CentralStore.py │ ├── blockchain.py │ ├── options.py │ ├── Trainer.py │ ├── ModelStore.py │ ├── sampling.py │ └── DatasetStore.py ├── requirements.txt ├── .gitignore ├── results-merge │ ├── extract-round-acc.py │ ├── extract-time-cost.py │ └── utils.py ├── plot │ ├── main.py │ ├── cost-commu │ │ ├── cnn-cifar100-cost-communication.py │ │ ├── cnn-cifar10-cost-communication.py │ │ ├── cnn-imagenet-cost-communication.py │ │ ├── cnn-mnist-cost-communication.py │ │ ├── mlp-mnist-cost-communication.py │ │ ├── cnn-realworld-cost-communication.py │ │ ├── cnn-uci-cost-communication.py │ │ └── resnet-cifar10-cost-communication.py │ ├── cost-overall │ │ ├── mlp-mnist-cost-overall.py │ │ ├── cnn-mnist-cost-overall.py │ │ ├── cnn-cifar10-cost-overall.py │ │ ├── cnn-cifar100-cost-overall.py │ │ ├── cnn-imagenet-cost-overall.py │ │ ├── cnn-realworld-cost-overall.py │ │ ├── cnn-uci-cost-overall.py │ │ └── resnet-cifar10-cost-overall.py │ ├── acc-sota │ │ ├── resnet-cifar10-acc-sota.py │ │ ├── cnn-imagenet-acc-sota.py │ │ ├── cnn-cifar10-acc-sota.py │ │ ├── mlp-mnist-acc-sota.py │ │ ├── cnn-cifar100-acc-sota.py │ │ ├── cnn-mnist-acc-sota.py │ │ ├── cnn-realworld-acc-sota.py │ │ └── cnn-uci-acc-sota.py │ ├── acc-nodes │ │ ├── resnet-cifar10-acc-nodes.py │ │ ├── mlp-mnist-acc-nodes.py │ │ ├── cnn-cifar10-acc-nodes.py │ │ ├── cnn-mnist-acc-nodes.py │ │ ├── cnn-cifar100-acc-nodes.py │ │ ├── cnn-realworld-acc-nodes.py │ │ ├── cnn-imagenet-acc-nodes.py │ │ └── cnn-uci-acc-nodes.py │ ├── acc-alpha │ │ ├── resnet-cifar10-acc-alpha.py │ │ ├── cnn-imagenet-acc-alpha.py │ │ ├── cnn-cifar10-acc-alpha.py │ │ ├── cnn-mnist-acc-alpha.py │ │ ├── mlp-mnist-acc-alpha.py │ │ ├── cnn-cifar100-acc-alpha.py │ │ ├── cnn-realworld-acc-alpha.py │ │ └── cnn-uci-acc-alpha.py │ └── 
acc-skew │ │ ├── cnn-realworld-acc-skew.py │ │ ├── cnn-uci-acc-skew.py │ │ ├── cnn-imagenet-acc-skew.py │ │ ├── cnn-cifar100-acc-skew.py │ │ ├── mlp-mnist-acc-skew.py │ │ ├── cnn-cifar10-acc-skew.py │ │ └── resnet-cifar10-acc-skew.py ├── README.md └── local.py ├── cluster-scripts ├── clean_output.sh ├── gather_output.sh ├── test.config ├── all_test.sh ├── utils.sh ├── all_test_alpha.sh └── all_test_nodes.sh ├── raft ├── go.mod ├── main.go ├── README.md ├── store │ └── store_test.go ├── http │ └── service_test.go └── go.sum ├── .gitignore └── README.md /federated-learning/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cluster-scripts/clean_output.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -f ~/EASC/federated-learning/result-record_*.txt 4 | 5 | -------------------------------------------------------------------------------- /federated-learning/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /federated-learning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /cluster-scripts/gather_output.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf output/ 4 | mkdir -p output/ 5 | 6 | cp ~/EASC/federated-learning/result-record_*.txt output/ 7 | cp ~/EASC/server.log output/ 8 | -------------------------------------------------------------------------------- 
/federated-learning/requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.6.0 2 | torchvision~=0.7.0 3 | numpy~=1.18.2 4 | sklearn~=0.0 5 | pandas~=1.0.3 6 | matplotlib~=3.2.1 7 | requests~=2.23.0 8 | Flask~=2.0.2 9 | hickle~=4.0.4 -------------------------------------------------------------------------------- /raft/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/otoolep/hraftd 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/hashicorp/raft v1.2.0 7 | github.com/hashicorp/raft-boltdb v0.0.0-20191021154308-4207f1bf0617 8 | ) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | /data 4 | node_modules 5 | bin/ 6 | vendor/ 7 | config/ 8 | wallet/ 9 | result-record_* 10 | output/ 11 | *.log 12 | raft/hraftd 13 | networkCache.tar.gz 14 | figures/ 15 | new_block 16 | -------------------------------------------------------------------------------- /cluster-scripts/test.config: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_DS=( 4 | # "cnn-cifar10" 5 | # "cnn-cifar100" 6 | # "cnn-imagenet" 7 | # "cnn-uci" 8 | # "cnn-realworld" 9 | # "cnn-mnist" 10 | # "mlp-mnist" 11 | "resnet-cifar10" 12 | ) 13 | 14 | SCHEMES=( 15 | "scei" 16 | "scei-async" 17 | "apfl" 18 | "fedavg" 19 | "local" 20 | ) 21 | -------------------------------------------------------------------------------- /federated-learning/.gitignore: -------------------------------------------------------------------------------- 1 | # pycharm 2 | .idea/* 3 | 4 | # documents 5 | *.csv 6 | .xls 7 | .xlsx 8 | .pdf 9 | .json 10 | 11 | # macOS 12 | .DS_Store 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 
| 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # virtualenv 48 | .venv 49 | venv/ 50 | ENV/ 51 | 52 | -------------------------------------------------------------------------------- /federated-learning/results-merge/extract-round-acc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from utils import calculate_average_across_files 4 | 5 | 6 | def extract_round_acc(): 7 | exp_node_number = "all_test" 8 | model_name = "resnet" 9 | dataset_name = "cifar10" 10 | 11 | experiment_names = ["apfl", "fedavg", "local", "scei", "scei-async"] 12 | 13 | for path, dirs, files in os.walk("./output"): 14 | if path.endswith(model_name + "-" + dataset_name) and exp_node_number in path: 15 | for experiment_name in experiment_names: 16 | experiment_path = os.path.join(path, experiment_name) 17 | files_numbers_mean_2d_np = calculate_average_across_files(experiment_path) 18 | acc = [round(i, 2) for i in files_numbers_mean_2d_np[:, 5]] 19 | print(experiment_name, "=", acc) 20 | 21 | 22 | def main(): 23 | extract_round_acc() 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /federated-learning/plot/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | 5 | 6 | def plot_all(): 7 | experiment_names = ["acc-alpha", "acc-nodes", "acc-skew", "acc-sota", "cost-commu", "cost-overall"] 8 | 9 | # real_path = 
os.path.dirname(os.path.realpath(__file__)) 10 | Path("./figures").mkdir(parents=True, exist_ok=True) 11 | for experiment in experiment_names: 12 | for path, dirs, files in os.walk("./" + experiment): 13 | plot_subdir = os.path.join("./figures", path) 14 | Path(plot_subdir).mkdir(parents=True, exist_ok=True) 15 | for file in files: 16 | if file.endswith(".py"): 17 | python_file_path = os.path.join(path, file) 18 | output_file_path = os.path.join(plot_subdir, file[:-3] + ".pdf") 19 | subprocess.call(['python3', python_file_path, "save", output_file_path]) 20 | 21 | 22 | def main(): 23 | plot_all() 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /federated-learning/models/Fed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | import copy 5 | 6 | import torch 7 | 8 | 9 | def fed_avg(w_dict, w_glob, device): 10 | if len(w_dict) == 0: 11 | return w_glob 12 | w_avg = {} 13 | for k in w_glob.keys(): 14 | for local_uuid in w_dict: 15 | if k not in w_avg: 16 | w_avg[k] = torch.zeros_like(w_glob[k], device=device) 17 | if device != torch.device('cpu'): 18 | w_dict[local_uuid][k] = w_dict[local_uuid][k].to(device) 19 | w_avg[k] = torch.add(w_avg[k], w_dict[local_uuid][k]) 20 | w_avg[k] = torch.div(w_avg[k], len(w_dict)) 21 | return w_avg 22 | 23 | 24 | def async_fed_avg(w_local, w_glob, device): 25 | w_avg = copy.deepcopy(w_glob) 26 | for k in w_avg.keys(): 27 | if device != torch.device('cpu'): 28 | w_local[k] = w_local[k].to(device) 29 | w_avg[k] = torch.add(w_avg[k], w_local[k]) 30 | w_avg[k] = torch.div(w_avg[k], 2) 31 | return w_avg 32 | -------------------------------------------------------------------------------- /federated-learning/models/Train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch 
import nn 3 | 4 | 5 | def train_cnn_mlp(net, my_dataset, idx, local_ep, device, lr, momentum, local_bs): 6 | net.train() 7 | optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum) 8 | ldr_train = my_dataset.load_train_dataset(idx, local_bs) 9 | loss_func = nn.CrossEntropyLoss() 10 | 11 | epoch_loss = [] 12 | for _ in range(local_ep): 13 | batch_loss = [] 14 | for batch_idx, (images, labels) in enumerate(ldr_train): 15 | images = images.detach().clone().type(torch.FloatTensor) 16 | if device != "cpu": 17 | images, labels = images.to(device), labels.to(device) 18 | net.zero_grad() 19 | log_probs = net(images) 20 | loss = loss_func(log_probs, labels) 21 | loss.backward() 22 | optimizer.step() 23 | batch_loss.append(loss.item()) 24 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 25 | return net.state_dict(), sum(epoch_loss) / len(epoch_loss) 26 | 27 | -------------------------------------------------------------------------------- /federated-learning/utils/EnvStore.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from utils.options import args_parser 5 | from utils.util import ColoredLogger, get_ip 6 | 7 | logging.setLoggerClass(ColoredLogger) 8 | logger = logging.getLogger("EnvStore") 9 | 10 | 11 | class EnvStore: 12 | def __init__(self): 13 | self.trigger_url = "" 14 | self.from_ip = "" 15 | self.args = None 16 | 17 | def init(self): 18 | self.args = args_parser() 19 | self.args.device = torch.device( 20 | 'cuda:{}'.format(self.args.gpu) if torch.cuda.is_available() and self.args.gpu != -1 else 'cpu') 21 | self.from_ip = get_ip(self.args.test_ip_addr) 22 | self.trigger_url = "http://" + self.from_ip + ":" + str(self.args.fl_listen_port) + "/trigger" 23 | # print parameters in log 24 | arguments = vars(self.args) 25 | logger.info("==========================================") 26 | for k, v in arguments.items(): 27 | arg = "{}: {}".format(k, v) 28 | logger.info("* 
{0:<40}".format(arg)) 29 | logger.info("==========================================") 30 | -------------------------------------------------------------------------------- /cluster-scripts/all_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set -x 4 | 5 | source ./test.config 6 | source ./utils.sh 7 | 8 | function main() { 9 | for i in "${!MODEL_DS[@]}"; do 10 | model_ds=(${MODEL_DS[i]//-/ }) 11 | model=${model_ds[0]} 12 | dataset=${model_ds[1]} 13 | echo "[`date`] ALL_NODE_TEST UNDER: ${model} - ${dataset}" 14 | 15 | for i in "${!SCHEMES[@]}"; do 16 | scheme="${SCHEMES[i]}" 17 | if [[ ! -d "${model}-${dataset}/${scheme}" ]]; then 18 | echo "[`date`] ## ${scheme} start ##" 19 | clean 20 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --gpu=${GPU_NO}" 21 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 22 | cd - 23 | # detect test finish or not 24 | sleep 30 25 | testFinish "${scheme}" 26 | # gather output, move to the right directory 27 | arrangeOutput ${model} ${dataset} "${scheme}" 28 | echo "[`date`] ## ${scheme} done ##" 29 | fi 30 | done 31 | done 32 | } 33 | 34 | GPU_NO=$1 35 | if [[ -z "${GPU_NO}" ]]; then 36 | GPU_NO="-1" 37 | fi 38 | 39 | main > test.log 2>&1 & 40 | -------------------------------------------------------------------------------- /cluster-scripts/utils.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ./test.config 4 | 5 | function getProcessName() { 6 | local scheme_name=$1 7 | local FIRST_CHAR=$(echo $scheme_name | cut -c1-1) 8 | local FOLLOWING_CHAR=$(echo $scheme_name | cut -c2-) 9 | local PS_NAME="[${FIRST_CHAR}]${FOLLOWING_CHAR}.py" 10 | echo "$PS_NAME" 11 | } 12 | 13 | function killOldProcesses() { 14 | for i in "${!SCHEMES[@]}"; do 15 | local scheme_name=(${SCHEMES[i]//:/ }) 16 | local PS_NAME=$(getProcessName ${scheme_name}) 17 | kill -9 $(ps 
-ef|grep "$PS_NAME"|awk '{ print $2 }') 18 | done 19 | } 20 | 21 | function cleanOutput() { 22 | ./clean_output.sh 23 | } 24 | 25 | function clean() { 26 | killOldProcesses 27 | cleanOutput 28 | } 29 | 30 | function arrangeOutput(){ 31 | local model=$1 32 | local dataset=$2 33 | local expname=$3 34 | ./gather_output.sh 35 | mkdir -p "${model}-${dataset}" 36 | mv output/ "${model}-${dataset}/${expname}" 37 | } 38 | 39 | 40 | function testFinish() { 41 | local scheme_name=$1 42 | local PS_NAME=$(getProcessName ${scheme_name}) 43 | while : ; do 44 | local count=$(ps -ef|grep "${PS_NAME}"|wc -l) 45 | if [[ $count -eq 0 ]]; then 46 | break 47 | fi 48 | echo "[`date`] Process still active, sleep 60 seconds" 49 | sleep 60 50 | done 51 | } 52 | 53 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-cifar100-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [5.66, 0, 0, 0, 0, 0, 0, 0, 0, 6.47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6.76, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6.83, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6.9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 6 | fedavg = [1.28, 0.85, 0.69, 0.8, 0.88, 0.87, 1.07, 0.94, 1.03, 0.78, 1.06, 0.93, 1.06, 0.92, 1.24, 0.93, 1.12, 1.14, 1.24, 0.88, 0.92, 1.08, 0.73, 0.89, 0.98, 0.97, 1.11, 1.24, 1.08, 0.74, 0.9, 0.92, 0.99, 0.87, 1.02, 1.21, 1.16, 1.18, 1.27, 1.08, 1.16, 1.35, 0.97, 1.14, 0.8, 0.51, 0.56, 0.68, 0.47, 0.71] 7 | scei = [3.08, 3.68, 3.53, 3.9, 3.93, 3.8, 3.8, 3.8, 3.89, 3.81, 4.0, 3.86, 4.15, 4.01, 3.9, 3.85, 4.06, 4.21, 4.24, 3.84, 3.84, 3.74, 3.56, 3.86, 3.65, 3.87, 4.01, 3.81, 4.09, 3.8, 3.69, 3.88, 3.75, 3.88, 3.97, 4.02, 3.73, 3.81, 3.92, 4.0, 3.88, 3.61, 3.52, 4.03, 3.85, 3.8, 3.99, 3.84, 3.77, 3.58] 8 | scei_async = [3.15, 2.89, 3.01, 3.08, 3.18, 3.11, 2.86, 3.06, 3.07, 2.84, 3.15, 2.97, 3.17, 3.46, 3.58, 3.1, 3.07, 3.18, 3.04, 3.3, 3.28, 2.87, 2.91, 
2.76, 3.05, 3.31, 3.16, 3.11, 3.14, 3.05, 2.76, 3.0, 3.09, 3.03, 2.82, 3.15, 3.16, 3.25, 3.02, 3.29, 3.2, 3.34, 3.16, 3.36, 3.18, 3.07, 3.03, 3.0, 2.81, 2.86] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-cifar10-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [6.42, 0, 0, 0, 0, 0, 0, 0, 0, 7.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6.91, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7.28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6.91, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.56] 6 | fedavg = [1.12, 1.31, 1.08, 1.12, 1.2, 1.16, 1.49, 1.17, 1.19, 1.27, 1.12, 1.35, 1.48, 1.41, 1.41, 1.25, 1.17, 0.76, 0.7, 1.52, 1.41, 1.38, 1.2, 1.28, 1.36, 1.35, 1.26, 1.26, 1.09, 1.69, 1.49, 1.35, 1.43, 1.37, 1.31, 1.24, 1.14, 1.21, 1.2, 1.24, 1.4, 1.11, 1.39, 1.17, 1.19, 1.29, 1.61, 1.36, 1.18, 1.2] 7 | scei = [3.94, 3.89, 3.78, 4.02, 4.16, 4.16, 4.09, 4.1, 4.26, 4.31, 4.41, 4.19, 4.03, 4.21, 4.29, 4.25, 4.33, 4.22, 4.13, 4.42, 4.05, 4.21, 3.95, 4.19, 3.77, 3.84, 4.06, 3.82, 3.69, 4.02, 3.89, 3.93, 3.86, 4.02, 4.05, 4.05, 3.66, 4.1, 3.83, 3.86, 3.87, 3.9, 3.85, 3.42, 3.72, 4.02, 3.98, 4.13, 4.01, 3.92] 8 | scei_async = [3.46, 3.31, 3.27, 2.95, 3.43, 2.89, 3.26, 3.28, 3.09, 3.32, 3.11, 3.34, 3.3, 3.36, 3.39, 3.38, 3.53, 3.28, 3.37, 3.3, 3.13, 3.07, 2.97, 3.15, 3.22, 2.95, 3.23, 3.09, 3.15, 3.45, 3.1, 3.21, 3.77, 3.33, 3.49, 3.33, 2.93, 3.55, 3.44, 3.42, 3.45, 3.03, 3.39, 2.79, 3.09, 3.36, 3.7, 3.47, 3.57, 3.29] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, 
False, True, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-imagenet-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [8.45, 0, 0, 0, 0, 0, 0, 0, 0, 10.58, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10.61, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11.92, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.18] 6 | fedavg = [2.0, 2.91, 2.22, 2.27, 2.18, 2.14, 1.92, 2.05, 1.81, 1.81, 1.95, 1.89, 2.52, 2.15, 1.9, 2.53, 2.46, 2.48, 2.2, 2.1, 2.23, 2.21, 2.22, 2.33, 2.73, 2.45, 2.33, 2.31, 2.39, 2.23, 2.73, 1.99, 2.47, 2.22, 2.27, 2.87, 2.69, 1.87, 2.9, 2.66, 2.19, 2.63, 2.0, 1.82, 2.19, 1.85, 2.13, 1.82, 1.89, 2.05] 7 | scei = [5.58, 6.05, 5.8, 5.54, 5.39, 5.8, 5.75, 5.44, 5.78, 5.49, 5.67, 5.76, 6.0, 5.96, 5.58, 5.68, 5.93, 5.72, 5.65, 5.43, 5.47, 5.49, 5.23, 5.76, 6.25, 5.86, 6.14, 6.1, 6.05, 5.43, 6.6, 5.45, 6.29, 5.77, 6.05, 6.3, 6.42, 5.67, 6.32, 6.33, 5.79, 6.75, 6.24, 5.88, 6.44, 6.23, 5.99, 5.39, 5.53, 5.76] 8 | scei_async = [8.24, 4.98, 4.91, 5.1, 5.04, 5.0, 5.16, 4.82, 5.31, 4.6, 5.37, 5.92, 10.86, 5.37, 5.08, 5.05, 5.09, 10.15, 5.86, 5.54, 5.22, 5.6, 5.46, 5.74, 6.7, 10.25, 5.2, 5.44, 5.38, 4.55, 6.44, 5.0, 5.42, 10.46, 5.53, 5.97, 5.62, 5.15, 6.96, 5.85, 5.17, 6.13, 5.33, 4.72, 4.95, 5.58, 5.04, 5.24, 5.26, 5.34] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-mnist-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = 
[12.31, 0.43, 0.17, 0.12, 0, 0, 0.41, 0.02, 0.1, 11.99, 0, 0.14, 0, 0, 0, 0, 0, 0, 0, 11.82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11.97, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13.22, 0, 0.37, 0.42, 0, 0.24, 0, 0, 0.15, 0.22, 3.73] 6 | fedavg = [3.29, 2.96, 3.12, 2.99, 3.02, 3.04, 3.13, 3.08, 3.01, 3.07, 2.79, 3.18, 2.7, 2.45, 2.59, 3.06, 2.64, 2.52, 2.5, 2.77, 2.92, 2.39, 2.38, 2.3, 2.32, 2.36, 2.47, 2.51, 2.51, 2.7, 2.53, 2.63, 2.73, 2.84, 2.8, 2.65, 2.71, 2.94, 2.81, 2.86, 2.75, 3.02, 2.8, 3.01, 3.1, 3.16, 2.91, 3.24, 2.91, 3.03] 7 | scei = [7.77, 7.3, 6.82, 6.99, 7.02, 7.25, 7.3, 7.2, 6.97, 7.27, 7.01, 7.28, 6.64, 6.83, 6.59, 6.21, 6.14, 6.18, 5.88, 6.61, 6.92, 6.61, 6.18, 6.11, 6.44, 6.46, 6.33, 6.41, 6.26, 6.75, 6.62, 6.52, 6.76, 6.78, 6.94, 6.55, 6.73, 7.08, 6.83, 7.01, 7.17, 7.13, 7.09, 6.81, 7.33, 6.93, 7.23, 7.48, 7.08, 7.08] 8 | scei_async = [6.42, 5.9, 6.04, 5.8, 5.89, 6.39, 6.32, 5.79, 5.92, 6.13, 5.64, 6.25, 5.37, 5.4, 5.4, 5.98, 5.45, 5.37, 5.27, 5.72, 5.84, 5.22, 5.39, 5.14, 5.41, 5.26, 5.58, 5.3, 5.56, 5.73, 5.36, 5.98, 5.45, 5.49, 5.28, 5.67, 5.78, 5.67, 5.77, 5.72, 6.02, 5.78, 5.81, 5.92, 5.94, 5.86, 5.86, 5.62, 5.82, 6.04] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/mlp-mnist-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [3.67, 0, 0, 0, 0.1, 0, 0.09, 0.26, 0, 3.94, 0, 0.09, 0.1, 0, 0, 0.09, 0.03, 0.06, 0, 4.26, 0.15, 0, 0.26, 0.11, 0.27, 0.27, 0.09, 0, 0, 3.7, 0, 0, 0.01, 0.22, 0.28, 0.11, 0.18, 0.27, 0, 4.07, 0.01, 0, 0.13, 0.05, 0.11, 0.09, 0.27, 0.1, 0, 1.89] 6 | fedavg = [0.32, 0.54, 0.36, 0.39, 0.39, 0.34, 0.36, 0.51, 0.42, 
0.38, 0.58, 0.38, 0.36, 0.42, 0.41, 0.32, 0.54, 0.38, 0.45, 0.39, 0.36, 0.36, 0.35, 0.41, 0.42, 0.46, 0.34, 0.35, 0.38, 0.47, 0.36, 0.37, 0.38, 0.42, 0.43, 0.38, 0.37, 0.61, 0.52, 0.55, 0.34, 0.41, 0.38, 0.43, 0.37, 0.37, 0.33, 0.33, 0.38, 0.39] 7 | scei = [1.97, 1.81, 1.77, 1.73, 1.83, 1.84, 1.99, 1.68, 1.69, 1.69, 1.93, 1.86, 2.07, 1.99, 1.95, 1.84, 1.88, 1.94, 1.93, 1.95, 1.9, 2.01, 2.34, 2.28, 2.21, 2.25, 1.74, 1.76, 2.04, 1.74, 1.85, 1.91, 2.0, 2.23, 2.2, 2.25, 1.84, 1.97, 2.03, 1.9, 2.17, 2.02, 2.13, 2.06, 2.12, 2.11, 2.06, 1.96, 1.84, 1.8] 8 | scei_async = [2.31, 1.87, 1.82, 1.66, 1.75, 1.82, 1.62, 1.56, 1.65, 1.67, 1.81, 1.88, 1.92, 1.83, 1.65, 1.78, 1.74, 1.67, 1.66, 1.76, 1.95, 1.99, 1.94, 2.01, 2.01, 1.98, 1.77, 1.58, 1.86, 1.45, 1.61, 1.6, 1.84, 1.95, 1.97, 1.92, 1.8, 1.89, 1.75, 2.0, 1.7, 1.68, 1.91, 1.75, 1.82, 1.89, 1.58, 1.68, 1.42, 1.39] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-realworld-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [32.57, 0, 0, 0, 0, 0, 0, 0, 0, 32.57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30.96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 29.27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.69] 6 | fedavg = [8.26, 8.59, 9.21, 8.41, 9.35, 9.68, 8.58, 8.54, 8.3, 8.67, 8.6, 8.79, 9.51, 8.72, 9.06, 8.61, 9.41, 7.24, 8.11, 8.85, 8.67, 8.45, 8.59, 7.75, 8.78, 8.17, 9.43, 8.95, 8.96, 8.31, 8.52, 6.32, 6.52, 8.48, 7.38, 8.18, 7.38, 10.24, 7.86, 8.65, 10.52, 8.12, 8.8, 8.71, 8.67, 8.69, 8.83, 7.79, 8.25, 8.9] 7 | scei = [16.47, 15.56, 18.67, 15.57, 19.18, 17.41, 16.44, 16.73, 18.38, 17.99, 18.79, 18.36, 
17.11, 16.69, 17.31, 15.9, 18.74, 17.18, 15.75, 15.2, 16.43, 16.94, 16.58, 15.94, 16.77, 16.59, 17.82, 16.49, 18.36, 15.4, 14.96, 15.93, 16.95, 15.48, 17.72, 16.63, 16.78, 18.06, 17.85, 16.93, 15.77, 16.93, 16.76, 16.38, 17.3, 18.4, 15.46, 16.62, 16.89, 16.13] 8 | scei_async = [11.22, 11.6, 11.49, 11.15, 12.11, 11.46, 11.99, 11.87, 12.93, 10.51, 12.79, 11.19, 11.52, 10.6, 11.65, 10.1, 11.0, 11.14, 10.13, 10.16, 13.06, 10.18, 10.54, 9.92, 11.07, 9.52, 10.92, 10.98, 10.29, 10.85, 8.37, 10.76, 11.54, 12.4, 8.45, 11.21, 10.23, 10.08, 9.75, 9.81, 8.67, 9.97, 11.69, 10.65, 10.98, 10.47, 11.97, 11.92, 10.62, 9.18] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/cnn-uci-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [36.05, 0, 0, 0, 0, 0, 0, 0, 0, 30.98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33.65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35.53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 29.21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9.23] 6 | fedavg = [9.98, 7.98, 8.01, 8.97, 8.63, 9.82, 9.55, 8.7, 9.31, 8.29, 9.15, 9.29, 9.29, 6.98, 8.76, 7.47, 7.79, 7.94, 9.3, 9.49, 8.64, 7.39, 7.35, 8.48, 8.09, 8.6, 8.84, 8.25, 10.19, 9.76, 8.89, 8.11, 8.99, 8.44, 9.23, 10.11, 10.08, 10.02, 10.09, 11.1, 10.14, 10.32, 10.66, 8.98, 9.34, 9.43, 10.38, 10.16, 10.5, 10.96] 7 | scei = [18.26, 17.72, 16.0, 17.32, 19.56, 17.78, 18.38, 17.37, 17.71, 15.21, 18.07, 17.06, 19.57, 15.9, 16.15, 15.24, 16.88, 18.62, 15.83, 16.57, 17.82, 17.98, 16.71, 17.11, 18.41, 19.67, 19.08, 18.06, 19.32, 17.26, 19.46, 16.86, 19.19, 17.93, 19.53, 19.37, 17.68, 19.0, 18.51, 19.9, 18.83, 20.44, 19.03, 17.76, 18.96, 17.95, 18.6, 18.23, 
19.06, 19.71] 8 | scei_async = [13.16, 11.39, 8.9, 11.87, 10.68, 12.44, 9.87, 10.68, 10.37, 10.55, 11.51, 11.71, 10.44, 8.62, 8.78, 8.75, 10.2, 10.29, 9.33, 12.06, 11.5, 8.49, 10.17, 9.94, 11.19, 9.07, 10.66, 10.12, 11.48, 10.45, 11.73, 11.5, 11.25, 12.36, 12.81, 10.08, 12.29, 12.12, 12.52, 11.97, 11.15, 10.71, 12.07, 12.04, 11.99, 13.49, 8.93, 10.15, 15.4, 11.59] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-commu/resnet-cifar10-cost-communication.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [48.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.42, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 49.92, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 49.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 56.46, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 49.56] 6 | fedavg = [17.95, 15.87, 16.98, 17.85, 17.55, 18.84, 17.13, 15.61, 15.6, 15.84, 16.5, 15.21, 17.2, 17.37, 15.1, 16.98, 15.49, 15.86, 15.72, 15.9, 15.52, 17.43, 14.33, 15.12, 15.81, 17.83, 15.93, 16.0, 16.81, 17.18, 16.35, 16.61, 16.16, 16.92, 17.34, 15.88, 17.39, 15.57, 17.85, 16.68, 15.98, 15.16, 17.06, 17.93, 14.87, 16.82, 15.2, 16.44, 17.41, 16.83] 7 | scei = [31.5, 30.38, 30.9, 30.26, 29.0, 30.06, 30.82, 30.14, 29.81, 28.75, 29.91, 29.17, 28.88, 29.87, 30.72, 29.19, 28.07, 30.44, 29.08, 28.64, 30.35, 30.63, 30.55, 30.39, 31.41, 29.56, 31.07, 29.47, 30.18, 30.39, 30.82, 31.42, 30.14, 29.38, 28.66, 29.32, 30.44, 30.64, 30.35, 30.87, 29.63, 31.07, 28.88, 28.75, 31.17, 30.44, 30.38, 29.26, 30.9, 29.82] 8 | scei_async = [25.41, 17.22, 18.33, 20.28, 18.72, 18.96, 19.44, 18.25, 18.69, 18.74, 17.94, 18.58, 
18.33, 17.55, 16.83, 19.34, 18.92, 18.77, 18.27, 17.76, 19.85, 20.19, 19.77, 18.36, 17.28, 20.28, 18.72, 18.17, 18.91, 17.73, 20.1, 19.52, 18.36, 18.54, 17.52, 19.78, 20.43, 19.84, 18.48, 19.8, 18.07, 19.23, 17.4, 18.42, 20.28, 19.7, 20.18, 20.16, 19.91, 20.16] 9 | 10 | save_path = None 11 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 12 | save_path = sys.argv[2] 13 | 14 | plot_time_cost("", scei, scei_async, apfl, fedavg, None, False, False, save_path, plot_size="4") 15 | -------------------------------------------------------------------------------- /federated-learning/results-merge/extract-time-cost.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from utils import calculate_average_across_files 4 | 5 | 6 | def extract_communication_cost(exp_node_number, model_name, dataset_name): 7 | experiment_names = ["apfl", "fedavg", "scei", "scei-async"] 8 | for path, dirs, files in os.walk("./output"): 9 | if path.endswith(model_name + "-" + dataset_name) and exp_node_number in path: 10 | for experiment_name in experiment_names: 11 | experiment_path = os.path.join(path, experiment_name) 12 | files_numbers_mean_2d_np = calculate_average_across_files(experiment_path) 13 | cost_communication = [round(i, 2) for i in files_numbers_mean_2d_np[:, 4]] 14 | print(experiment_name, "=", cost_communication) 15 | 16 | 17 | def extract_overall_cost(exp_node_number, model_name, dataset_name): 18 | experiment_names = ["apfl", "fedavg", "local", "scei", "scei-async"] 19 | for path, dirs, files in os.walk("./output"): 20 | if path.endswith(model_name + "-" + dataset_name) and exp_node_number in path: 21 | for experiment_name in experiment_names: 22 | experiment_path = os.path.join(path, experiment_name) 23 | files_numbers_mean_2d_np = calculate_average_across_files(experiment_path) 24 | cost_communication = [round(i, 2) for i in files_numbers_mean_2d_np[:, 1]] 25 | print(experiment_name, "=", 
cost_communication) 26 | 27 | 28 | def main(): 29 | exp_node_number = "all_test" 30 | model_name = "resnet" 31 | dataset_name = "cifar10" 32 | 33 | extract_communication_cost(exp_node_number, model_name, dataset_name) 34 | # extract_overall_cost(exp_node_number, model_name, dataset_name) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /federated-learning/datasets/IMAGENET.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torchvision import datasets, transforms 6 | 7 | 8 | class IMAGENETDataset(Dataset): 9 | def __init__(self, subset): 10 | self.subset = subset 11 | self.targets = self.get_targets() 12 | 13 | def __len__(self): 14 | return len(self.subset.indices) 15 | 16 | def __getitem__(self, idx): 17 | return self.subset.dataset[self.subset.indices[idx]] 18 | 19 | def get_targets(self): 20 | target_mapping = map(self.subset.dataset.targets.__getitem__, self.subset.indices) 21 | return list(target_mapping) 22 | 23 | 24 | if __name__ == '__main__': 25 | real_path = os.path.dirname(os.path.realpath(__file__)) 26 | data_path = os.path.join(real_path, "../../data/imagenet/") 27 | train_dir = os.path.join(data_path, 'train') 28 | trans = transforms.Compose( 29 | [transforms.Resize(256), transforms.CenterCrop(224), transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor()]) 31 | dataset_full = datasets.ImageFolder(train_dir, transform=trans) 32 | print("dataset_full") 33 | train_size = int(0.8 * len(dataset_full)) 34 | test_size = len(dataset_full) - train_size 35 | subset_train, subset_test = torch.utils.data.random_split(dataset_full, [train_size, test_size]) 36 | print("subset_train and subset_test") 37 | 38 | dataset_train = IMAGENETDataset(subset_train) 39 | dataset_test = IMAGENETDataset(subset_test) 40 | print("dataset train: 
{}".format(dataset_train)) 41 | print("dataset test: {}".format(dataset_test)) 42 | print(dataset_train[0][0].shape, dataset_train[0][1]) 43 | print(len(dataset_train)) 44 | print(len(dataset_test)) 45 | print("dataset train targets len: {}".format(len(dataset_train.targets))) 46 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/mlp-mnist-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [7.05, 3.1, 3.07, 3.21, 3.28, 3.08, 3.28, 3.59, 3.22, 7.2, 3.02, 3.24, 3.17, 3.07, 3.17, 3.26, 3.23, 3.22, 2.99, 7.41, 3.3, 3.06, 3.26, 3.07, 3.22, 3.2, 3.36, 3.1, 2.94, 7.05, 3.1, 3.02, 3.13, 3.23, 3.32, 3.13, 3.36, 3.34, 2.98, 7.18, 3.1, 3.02, 3.17, 3.18, 3.16, 3.19, 3.47, 3.24, 3.1, 5.2] 6 | fedavg = [4.71, 4.3, 4.3, 4.1, 4.17, 4.13, 4.18, 4.25, 4.1, 4.25, 4.22, 4.12, 4.16, 4.29, 4.03, 3.95, 4.11, 4.08, 4.23, 4.15, 4.24, 4.16, 4.07, 4.01, 4.16, 4.28, 4.04, 4.16, 4.11, 4.26, 4.27, 4.23, 4.17, 4.04, 4.02, 4.28, 4.14, 4.07, 4.17, 4.2, 4.15, 4.2, 4.18, 4.24, 4.23, 4.15, 4.06, 4.12, 4.42, 4.16] 7 | local = [3.38, 3.14, 3.15, 3.31, 3.18, 3.21, 3.19, 3.33, 3.32, 3.26, 3.19, 3.15, 3.07, 3.11, 3.18, 3.17, 3.2, 3.16, 3.21, 3.15, 3.15, 3.09, 3.0, 2.96, 2.95, 2.93, 3.27, 3.29, 3.12, 3.35, 3.21, 3.19, 3.12, 3.01, 3.04, 3.02, 3.18, 3.07, 3.1, 3.11, 3.09, 3.15, 3.04, 3.13, 3.05, 3.1, 3.2, 3.14, 3.29, 3.31] 8 | scei = [6.4, 5.84, 5.94, 6.01, 6.05, 6.08, 6.15, 5.99, 6.0, 5.97, 6.03, 6.05, 6.02, 6.04, 6.06, 6.03, 6.06, 6.06, 6.09, 6.02, 6.06, 6.1, 6.15, 6.05, 6.07, 6.07, 5.99, 6.05, 6.03, 6.11, 6.09, 6.02, 6.02, 6.09, 6.07, 6.13, 6.06, 6.0, 6.13, 6.03, 6.1, 6.04, 6.05, 6.09, 6.11, 6.17, 6.12, 6.08, 6.15, 6.05] 9 | scei_async = [6.69, 5.87, 5.89, 5.98, 5.92, 5.94, 5.83, 5.91, 5.9, 5.9, 5.87, 5.88, 5.82, 5.87, 5.85, 5.89, 5.91, 5.84, 5.8, 5.85, 6.03, 5.93, 5.92, 5.85, 5.86, 5.82, 6.03, 5.8, 5.93, 5.86, 
5.86, 5.85, 5.86, 5.88, 5.91, 5.79, 5.84, 5.85, 5.83, 5.96, 5.77, 5.79, 5.77, 5.77, 5.85, 5.83, 5.77, 5.77, 5.72, 5.77] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-mnist-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [15.01, 2.99, 2.69, 2.78, 2.35, 2.14, 2.87, 2.64, 2.58, 14.45, 2.58, 2.46, 2.32, 2.71, 2.25, 2.06, 2.08, 2.59, 2.57, 14.78, 2.09, 2.05, 2.15, 2.45, 2.18, 2.32, 2.08, 2.75, 2.58, 15.23, 2.48, 2.35, 2.33, 2.8, 2.03, 2.45, 2.24, 2.73, 2.71, 16.0, 2.17, 2.93, 2.94, 2.51, 2.82, 2.43, 2.15, 2.77, 2.7, 6.19] 6 | fedavg = [4.64, 4.26, 4.35, 4.24, 4.36, 4.27, 4.38, 4.38, 4.25, 4.26, 4.23, 4.3, 4.29, 4.15, 4.19, 4.26, 4.23, 4.32, 4.24, 4.21, 4.31, 4.3, 4.18, 4.2, 4.25, 4.21, 4.2, 4.25, 4.26, 4.36, 4.16, 4.28, 4.28, 4.37, 4.32, 4.21, 4.23, 4.3, 4.32, 4.23, 4.13, 4.2, 4.15, 4.29, 4.3, 4.31, 4.2, 4.48, 4.21, 4.26] 7 | local = [1.35, 1.28, 1.26, 1.33, 1.29, 1.24, 1.23, 1.31, 1.24, 1.23, 1.53, 1.16, 1.53, 1.62, 1.6, 1.21, 1.56, 1.76, 1.78, 1.48, 1.44, 1.82, 1.8, 1.84, 1.86, 1.84, 1.73, 1.75, 1.73, 1.63, 1.62, 1.56, 1.55, 1.5, 1.54, 1.54, 1.47, 1.45, 1.47, 1.39, 1.35, 1.28, 1.26, 1.33, 1.29, 1.24, 1.23, 1.31, 1.24, 1.23] 8 | scei = [9.35, 8.85, 8.56, 8.91, 8.72, 8.84, 8.7, 8.89, 8.73, 8.71, 9.08, 8.87, 8.67, 8.92, 8.6, 7.96, 8.24, 8.54, 8.33, 8.4, 8.78, 8.84, 8.45, 8.62, 9.05, 8.75, 8.77, 8.79, 8.63, 8.98, 8.91, 8.72, 8.84, 8.7, 8.89, 8.73, 8.71, 9.08, 8.87, 8.67, 8.91, 8.72, 8.84, 8.7, 8.89, 8.73, 8.71, 9.08, 8.87, 8.67] 9 | scei_async = [8.12, 7.69, 7.66, 7.64, 7.71, 7.91, 7.8, 7.6, 7.61, 7.77, 7.54, 7.69, 7.42, 7.49, 7.45, 7.6, 
7.59, 7.49, 7.48, 7.63, 7.77, 7.74, 7.59, 7.66, 7.67, 7.76, 7.78, 7.64, 7.73, 7.66, 7.64, 7.84, 7.63, 7.48, 7.45, 7.58, 7.57, 7.65, 7.49, 7.57, 7.59, 7.52, 7.52, 7.7, 7.5, 7.59, 7.63, 7.5, 7.49, 7.51] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-cifar10-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [13.76, 4.54, 4.81, 4.87, 4.92, 4.82, 4.76, 4.73, 4.64, 13.49, 4.54, 4.53, 4.59, 4.48, 4.48, 4.45, 4.52, 4.41, 4.46, 13.39, 4.55, 4.69, 4.53, 4.59, 4.57, 4.4, 4.48, 4.45, 4.44, 13.71, 4.7, 4.66, 4.67, 4.68, 4.73, 4.84, 4.88, 4.75, 4.76, 13.61, 4.69, 4.54, 4.56, 4.56, 4.57, 4.62, 4.67, 4.59, 4.59, 6.99] 6 | fedavg = [5.13, 4.9, 4.9, 4.88, 4.88, 4.95, 5.01, 4.86, 4.84, 4.86, 4.79, 4.97, 5.0, 4.95, 4.96, 4.92, 4.62, 4.45, 4.33, 5.13, 4.9, 4.9, 4.88, 4.88, 4.95, 5.01, 4.86, 4.84, 4.86, 5.23, 5.13, 4.9, 4.9, 4.88, 4.88, 4.95, 5.01, 4.86, 4.84, 4.86, 5.13, 4.9, 4.9, 4.88, 4.88, 4.95, 5.01, 4.86, 4.84, 4.86] 7 | local = [4.08, 3.68, 3.73, 3.71, 3.63, 3.78, 3.62, 3.6, 3.57, 3.55, 3.58, 3.53, 3.61, 3.58, 3.47, 3.59, 3.48, 3.63, 3.68, 3.6, 3.58, 3.56, 3.75, 3.61, 3.68, 3.62, 3.66, 3.66, 3.73, 3.57, 3.65, 3.65, 3.57, 3.56, 3.62, 3.64, 3.83, 3.56, 3.62, 3.72, 3.66, 3.76, 3.58, 3.74, 3.63, 3.62, 3.43, 3.56, 3.59, 3.57] 8 | scei = [9.21, 8.71, 8.79, 8.83, 8.99, 8.88, 8.87, 8.9, 8.78, 8.82, 8.94, 8.88, 8.79, 8.95, 8.95, 8.98, 9.01, 8.98, 8.94, 8.9, 8.78, 8.77, 8.71, 8.69, 8.68, 8.65, 8.66, 8.72, 8.69, 8.63, 8.71, 8.67, 8.63, 8.63, 8.68, 8.69, 8.65, 8.6, 8.62, 8.65, 8.63, 8.6, 8.62, 8.4, 8.44, 8.6, 8.62, 8.65, 8.63, 8.6] 9 | scei_async = [8.75, 7.94, 
8.15, 7.89, 7.97, 7.83, 7.91, 7.9, 7.89, 7.74, 7.73, 7.85, 8.16, 7.89, 7.97, 8.05, 7.89, 8.15, 7.98, 7.95, 7.82, 7.76, 7.95, 7.88, 7.85, 7.79, 7.89, 7.87, 8.19, 8.0, 7.97, 8.14, 8.23, 8.07, 8.15, 8.14, 8.05, 8.13, 8.12, 8.17, 8.03, 8.05, 7.9, 7.83, 7.97, 7.96, 8.14, 8.11, 8.08, 8.07] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, True, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-cifar100-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [13.29, 4.44, 4.19, 4.17, 4.26, 4.54, 4.54, 4.6, 4.48, 13.38, 4.37, 4.42, 4.26, 4.13, 4.11, 4.24, 4.37, 4.32, 4.36, 13.44, 4.27, 4.24, 4.28, 4.1, 4.06, 4.34, 4.41, 4.32, 4.31, 13.62, 4.42, 4.41, 4.21, 4.37, 4.17, 4.33, 4.26, 4.26, 4.27, 13.36, 4.32, 4.2, 4.23, 4.34, 4.32, 4.32, 4.33, 4.01, 4.37, 6.36] 6 | fedavg = [5.42, 4.58, 4.45, 4.58, 4.5, 4.63, 4.67, 4.57, 4.74, 4.58, 4.75, 4.66, 4.75, 4.57, 4.84, 4.72, 4.6, 4.66, 4.83, 4.64, 4.56, 4.74, 4.65, 4.7, 4.71, 4.64, 4.71, 4.76, 4.66, 4.61, 4.74, 4.78, 4.82, 4.68, 4.73, 4.78, 4.76, 4.74, 4.76, 4.69, 4.74, 4.88, 4.79, 4.83, 4.43, 4.18, 4.25, 4.31, 4.13, 4.4] 7 | local = [4.24, 3.77, 3.79, 3.74, 3.7, 3.83, 3.68, 3.53, 3.69, 3.84, 3.63, 3.74, 3.7, 3.72, 3.63, 3.72, 3.57, 3.54, 3.66, 3.71, 3.61, 3.68, 3.86, 3.87, 3.81, 3.62, 3.58, 3.57, 3.55, 3.77, 3.76, 3.79, 3.8, 3.71, 3.77, 3.6, 3.69, 3.54, 3.57, 3.59, 3.54, 3.6, 3.73, 3.68, 3.63, 3.69, 3.64, 3.73, 3.74, 3.72] 8 | scei = [8.54, 8.54, 8.64, 8.81, 8.85, 8.72, 8.55, 8.54, 8.67, 8.83, 8.8, 8.78, 8.76, 8.78, 8.79, 8.85, 8.81, 8.96, 8.9, 8.79, 8.68, 8.56, 8.6, 8.69, 8.62, 8.68, 8.61, 8.63, 8.63, 8.78, 8.77, 8.75, 8.74, 8.75, 8.79, 8.74, 8.69, 8.58, 8.54, 8.49, 
8.53, 8.43, 8.36, 8.75, 8.74, 8.75, 8.79, 8.74, 8.69, 8.58] 9 | scei_async = [8.61, 7.93, 7.89, 8.02, 7.82, 8.0, 7.68, 7.81, 7.81, 7.88, 7.92, 7.82, 7.99, 8.16, 8.23, 7.78, 7.81, 7.87, 7.96, 8.0, 7.8, 7.67, 8.03, 7.8, 7.81, 7.86, 7.87, 7.93, 7.85, 7.85, 7.62, 7.78, 8.01, 7.98, 7.74, 7.91, 8.0, 7.77, 7.74, 8.1, 7.92, 7.82, 7.95, 7.98, 7.89, 7.86, 7.91, 7.94, 7.71, 7.85] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/resnet-cifar10-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [50.0, 53.0, 55.0, 55.0, 59.0, 57.0, 61.5, 63.5, 58.0, 63.0, 68.0, 65.5, 64.5, 65.0, 61.5, 63.5, 60.5, 62.5, 63.0, 63.5, 64.0, 64.5, 66.0, 65.5, 62.5, 63.0, 65.5, 68.0, 69.0, 65.5, 66.5, 66.5, 66.5, 67.5, 67.5, 65.5, 66.5, 67.5, 66.5, 66.5, 64.5, 65.0, 65.0, 66.0, 65.5, 65.0, 67.0, 63.0, 65.5, 64.5] 6 | fedavg = [29.0, 38.5, 44.0, 53.5, 53.0, 46.5, 52.0, 62.0, 62.0, 57.0, 61.5, 58.5, 62.0, 52.5, 68.0, 58.5, 67.0, 69.5, 66.5, 66.0, 63.5, 66.5, 65.5, 59.5, 65.0, 64.0, 65.0, 65.5, 65.0, 65.5, 65.5, 65.0, 66.0, 64.0, 65.0, 62.5, 61.5, 66.0, 66.5, 70.5, 69.5, 69.5, 70.0, 67.5, 69.0, 68.5, 69.5, 68.5, 66.0, 64.5] 7 | local = [35.5, 48.5, 51.0, 44.5, 55.0, 57.5, 60.5, 60.0, 59.5, 61.0, 63.5, 59.0, 62.5, 64.0, 60.5, 63.0, 63.5, 59.0, 61.5, 57.0, 62.0, 62.5, 66.0, 64.5, 67.5, 65.5, 66.5, 66.0, 67.0, 66.5, 67.0, 71.0, 64.0, 67.0, 66.5, 69.0, 68.5, 63.0, 65.5, 66.5, 66.0, 63.5, 67.5, 66.5, 65.5, 65.0, 69.0, 68.5, 68.5, 68.5] 8 | scei = [26.0, 38.0, 43.5, 31.5, 53.5, 62.5, 59.5, 48.5, 67.0, 58.5, 76.0, 68.0, 66.5, 64.5, 72.0, 76.0, 74.0, 77.0, 74.0, 81.5, 79.0, 84.0, 83.0, 86.0, 
78.0, 86.0, 85.5, 81.5, 82.5, 83.5, 82.5, 83.5, 82.5, 82.5, 84.5, 82.5, 84.0, 83.5, 84.5, 82.0, 85.5, 82.5, 84.0, 82.5, 83.5, 86.0, 84.5, 84.5, 83.0, 83.5] 9 | scei_async = [36.5, 44.5, 67.0, 52.5, 64.0, 70.0, 76.5, 74.5, 72.0, 68.0, 70.5, 67.0, 80.5, 80.0, 77.5, 79.5, 80.5, 83.0, 80.5, 81.5, 80.5, 79.0, 78.0, 78.0, 79.0, 80.5, 80.5, 77.0, 79.0, 75.5, 77.0, 69.0, 80.0, 78.0, 78.5, 78.0, 78.5, 78.5, 79.5, 77.5, 77.0, 78.0, 79.0, 79.0, 78.0, 78.0, 78.0, 78.0, 78.5, 79.5] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-imagenet-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [16.81, 4.12, 4.11, 4.38, 4.16, 4.46, 4.16, 4.83, 4.63, 18.12, 4.92, 4.93, 4.45, 4.91, 4.89, 4.46, 4.11, 4.77, 4.55, 17.71, 4.8, 4.06, 4.71, 4.98, 4.94, 4.41, 4.08, 4.74, 4.52, 18.18, 4.05, 4.69, 4.29, 4.78, 4.04, 4.61, 4.85, 4.87, 4.59, 18.62, 4.95, 4.49, 4.78, 4.25, 4.06, 4.1, 4.2, 4.75, 4.5, 7.58] 6 | fedavg = [6.21, 6.41, 5.98, 5.85, 5.83, 5.75, 5.54, 5.74, 5.6, 5.51, 5.63, 5.54, 5.89, 5.71, 5.54, 6.1, 5.91, 5.99, 5.81, 5.69, 5.7, 5.88, 5.94, 5.85, 5.79, 5.98, 5.83, 5.77, 6.04, 6.17, 5.72, 5.74, 5.78, 5.78, 5.85, 5.95, 5.91, 5.64, 6.04, 6.09, 5.87, 5.79, 5.53, 5.61, 5.75, 5.47, 5.75, 5.55, 5.61, 5.73] 7 | local = [4.18, 3.44, 3.72, 3.67, 3.58, 3.56, 3.66, 3.75, 3.7, 3.77, 3.7, 3.63, 3.45, 3.51, 3.61, 3.56, 3.51, 3.6, 3.54, 3.55, 3.53, 3.62, 3.7, 3.42, 3.03, 3.45, 3.46, 3.45, 3.59, 3.99, 3.08, 3.84, 3.28, 3.66, 3.63, 3.09, 3.2, 3.82, 3.15, 3.35, 3.76, 3.14, 3.56, 3.86, 3.61, 3.58, 3.56, 3.66, 3.75, 3.7] 8 | scei = [10.85, 10.58, 10.54, 10.33, 10.22, 10.4, 10.35, 
10.39, 10.62, 10.56, 10.66, 10.36, 10.49, 10.41, 10.39, 10.44, 10.31, 10.36, 10.31, 10.23, 10.15, 10.18, 10.08, 10.2, 10.36, 10.5, 10.68, 10.75, 10.9, 10.67, 10.66, 10.55, 10.43, 10.6, 10.62, 10.42, 10.6, 10.61, 10.53, 10.68, 10.8, 10.81, 10.97, 10.93, 10.95, 10.87, 10.54, 10.33, 10.22, 10.4] 9 | scei_async = [13.55, 9.61, 9.71, 9.78, 9.73, 9.78, 9.86, 9.77, 10.01, 9.59, 10.03, 10.71, 15.53, 9.88, 9.6, 9.87, 9.47, 14.97, 10.53, 10.03, 9.85, 10.4, 10.15, 10.06, 10.53, 14.6, 9.69, 10.03, 9.98, 9.6, 10.38, 9.81, 9.87, 15.38, 10.16, 9.95, 9.95, 10.09, 10.86, 10.21, 9.94, 10.04, 10.04, 9.93, 9.8, 10.16, 9.76, 10.1, 10.19, 10.02] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-imagenet-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [26.05, 38.45, 46.8, 52.9, 55.35, 55.25, 55.55, 55.8, 55.2, 27.1, 41.6, 51.6, 55.3, 55.95, 53.55, 54.85, 52.5, 53.05, 53.0, 38.35, 51.75, 54.1, 53.75, 51.9, 51.25, 50.6, 51.6, 54.6, 53.1, 43.05, 51.75, 50.85, 48.4, 52.1, 50.6, 51.0, 50.9, 50.05, 51.6, 44.5, 50.7, 50.35, 49.95, 48.95, 51.3, 49.75, 50.95, 51.65, 50.55, 45.3] 6 | fedavg = [5.0, 5.25, 6.4, 6.9, 7.45, 7.3, 8.1, 8.05, 8.85, 9.65, 9.85, 10.75, 10.85, 10.35, 10.35, 9.85, 10.75, 10.45, 9.9, 9.95, 11.4, 9.6, 11.05, 9.85, 10.05, 9.8, 10.55, 9.95, 10.15, 9.35, 9.85, 9.75, 9.4, 9.3, 9.2, 8.55, 9.35, 8.7, 9.3, 8.15, 9.75, 9.05, 8.15, 8.4, 8.65, 9.15, 9.15, 8.65, 8.6, 8.1] 7 | local = [28.85, 35.3, 43.65, 49.7, 51.65, 53.25, 53.0, 55.15, 54.35, 52.3, 54.75, 54.25, 53.3, 53.8, 53.6, 51.8, 52.25, 50.15, 52.5, 53.25, 53.05, 52.2, 52.95, 52.35, 53.45, 53.1, 53.5, 51.35, 
53.2, 52.8, 53.8, 53.75, 52.1, 53.95, 53.8, 54.2, 54.25, 53.4, 54.5, 54.25, 53.4, 53.7, 53.8, 51.9, 52.9, 53.8, 53.05, 53.75, 53.5, 53.55] 8 | scei = [28.25, 29.3, 34.3, 35.2, 38.95, 41.45, 41.4, 40.3, 44.3, 47.15, 47.05, 46.6, 49.95, 52.1, 52.25, 52.35, 52.9, 55.0, 53.35, 54.3, 54.25, 55.55, 55.5, 57.05, 56.0, 55.95, 54.9, 54.35, 55.5, 56.05, 56.1, 55.1, 55.45, 54.5, 54.1, 54.6, 54.3, 53.4, 55.5, 54.05, 53.4, 53.6, 54.4, 54.35, 53.05, 54.6, 54.25, 53.7, 53.45, 52.85] 9 | scei_async = [25.95, 26.4, 29.85, 36.45, 42.3, 41.4, 42.05, 45.25, 45.9, 46.9, 50.55, 50.95, 50.95, 51.8, 50.9, 53.4, 53.25, 51.6, 49.45, 51.6, 51.6, 51.9, 50.1, 50.95, 51.5, 50.2, 49.7, 50.6, 50.05, 50.0, 48.95, 48.45, 49.75, 49.55, 48.5, 48.7, 49.1, 46.65, 48.3, 46.75, 48.35, 47.85, 47.8, 49.35, 47.0, 46.15, 46.35, 47.85, 48.15, 48.35] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-cifar10-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [49.05, 57.4, 59.05, 57.45, 58.6, 59.6, 59.8, 59.8, 60.0, 33.5, 56.95, 58.75, 58.65, 59.95, 59.75, 59.55, 59.7, 59.9, 59.85, 54.9, 60.4, 59.25, 59.9, 59.4, 59.05, 59.25, 59.3, 59.1, 59.2, 55.5, 58.2, 57.8, 59.75, 59.3, 59.1, 59.1, 59.2, 59.15, 59.4, 55.6, 58.65, 60.5, 60.65, 60.9, 60.9, 60.85, 61.0, 60.95, 61.05, 56.25] 6 | fedavg = [17.55, 21.6, 31.1, 36.35, 39.6, 41.6, 41.8, 41.05, 40.45, 39.8, 40.85, 40.2, 40.85, 41.2, 40.55, 40.35, 40.7, 40.3, 40.8, 40.3, 41.3, 40.25, 41.55, 41.45, 42.0, 41.5, 41.55, 42.45, 42.2, 41.3, 40.55, 41.75, 41.9, 42.75, 42.45, 42.65, 42.15, 42.9, 42.8, 40.65, 41.85, 42.05, 41.6, 41.8, 41.55, 41.6, 41.7, 
42.55, 41.8, 42.3] 7 | local = [55.35, 59.5, 60.3, 59.65, 58.95, 61.55, 62.6, 62.85, 62.95, 62.9, 63.0, 63.1, 63.15, 63.15, 63.0, 63.15, 63.2, 63.2, 63.15, 63.1, 63.1, 63.15, 63.15, 63.15, 63.15, 63.2, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.25, 63.25, 63.25, 63.25, 63.3, 63.3] 8 | scei = [40.35, 50.6, 55.7, 59.2, 61.8, 63.0, 64.1, 63.3, 64.85, 64.65, 64.8, 64.95, 65.1, 65.05, 65.5, 66.0, 65.95, 65.3, 66.4, 66.3, 66.05, 64.95, 66.0, 65.2, 65.85, 66.0, 65.55, 67.0, 67.1, 65.6, 66.6, 66.55, 66.55, 66.3, 66.35, 66.2, 66.8, 66.35, 66.95, 66.6, 66.9, 67.1, 66.35, 66.05, 66.25, 66.8, 66.9, 66.75, 66.7, 66.95] 9 | scei_async = [47.2, 55.7, 57.35, 57.15, 61.2, 60.9, 61.9, 60.55, 62.75, 63.15, 61.4, 61.15, 60.85, 61.8, 62.25, 62.6, 62.9, 63.25, 61.5, 61.5, 63.7, 64.05, 62.75, 61.65, 63.3, 63.75, 63.55, 62.45, 63.0, 61.3, 61.35, 62.35, 61.95, 61.8, 61.45, 62.05, 62.4, 62.0, 61.5, 61.35, 62.3, 63.85, 61.9, 63.15, 62.5, 61.85, 62.05, 63.1, 63.1, 62.25] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, True, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/mlp-mnist-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [91.85, 93.0, 92.6, 93.0, 93.55, 93.25, 93.35, 93.45, 93.35, 89.8, 94.05, 93.8, 93.8, 94.2, 94.5, 94.2, 94.45, 93.95, 94.15, 90.05, 94.3, 94.4, 94.15, 94.5, 94.85, 94.6, 94.65, 94.5, 94.3, 91.45, 94.8, 94.95, 94.75, 95.0, 95.3, 94.95, 94.8, 94.75, 94.9, 91.75, 95.1, 95.45, 95.45, 95.05, 95.25, 95.25, 95.4, 95.0, 95.4, 92.7] 6 | fedavg = [71.05, 73.45, 77.8, 79.75, 81.25, 83.1, 84.1, 85.2, 85.7, 86.1, 86.8, 88.05, 88.35, 88.4, 
89.6, 90.05, 89.4, 90.35, 90.8, 90.45, 90.8, 90.65, 90.75, 90.7, 90.85, 90.9, 90.75, 91.45, 91.25, 91.55, 91.6, 91.35, 91.4, 91.3, 91.5, 91.3, 92.0, 91.85, 91.75, 91.7, 91.85, 92.2, 91.8, 92.1, 91.9, 91.75, 92.0, 92.15, 92.1, 92.05] 7 | local = [94.75, 95.55, 95.25, 95.45, 95.3, 95.4, 95.85, 95.65, 95.45, 95.45, 95.7, 95.8, 95.55, 95.7, 95.7, 95.8, 95.75, 95.75, 95.6, 95.55, 95.6, 95.65, 95.8, 95.85, 95.75, 95.9, 95.75, 95.55, 95.65, 95.7, 95.75, 95.65, 95.5, 95.5, 95.55, 95.55, 95.7, 95.7, 95.9, 95.6, 95.85, 95.65, 95.75, 95.9, 95.85, 95.75, 95.8, 95.85, 95.9, 95.9] 8 | scei = [93.9, 94.6, 94.9, 94.4, 95.2, 95.15, 95.4, 95.1, 95.4, 95.5, 95.1, 95.45, 95.3, 95.55, 95.55, 95.7, 95.7, 95.15, 95.8, 95.55, 95.75, 95.75, 95.95, 96.0, 95.9, 96.0, 95.9, 96.3, 95.95, 96.0, 95.65, 96.15, 96.05, 96.0, 95.85, 95.85, 95.9, 96.25, 96.25, 96.15, 96.2, 95.95, 96.25, 96.2, 96.0, 95.8, 96.4, 96.4, 96.1, 96.25] 9 | scei_async = [92.75, 93.95, 92.35, 94.1, 92.7, 94.9, 95.05, 95.15, 94.8, 95.55, 94.7, 95.55, 95.3, 95.35, 95.75, 95.3, 95.75, 95.7, 94.95, 95.4, 95.9, 95.8, 95.8, 95.9, 95.4, 95.8, 95.25, 96.05, 95.5, 96.15, 95.9, 95.9, 95.55, 95.85, 95.9, 95.35, 95.45, 96.2, 96.15, 96.0, 96.05, 96.05, 95.95, 95.85, 95.95, 95.75, 95.95, 96.1, 95.4, 95.8] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-cifar100-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [60.3, 68.75, 66.5, 66.55, 68.65, 67.2, 66.8, 67.7, 67.05, 32.65, 57.75, 65.95, 64.0, 65.5, 64.5, 62.05, 63.65, 62.6, 58.65, 52.6, 62.05, 61.2, 61.9, 61.6, 62.05, 62.1, 62.55, 62.15, 62.32, 58.05, 62.05, 62.3, 
62.3, 62.05, 62.32, 62.15, 61.9, 61.95, 62.05, 60.05, 62.05, 62.32, 62.27, 62.15, 62.15, 62.05, 62.05, 62.05, 62.05, 60.05] 6 | fedavg = [9.25, 17.4, 18.0, 21.6, 22.8, 23.45, 24.1, 24.65, 26.0, 25.8, 25.05, 26.35, 26.25, 25.7, 27.3, 25.45, 25.75, 26.55, 25.85, 27.05, 25.3, 25.75, 26.35, 26.05, 26.25, 27.15, 25.7, 28.6, 27.65, 27.8, 26.0, 28.6, 27.3, 26.7, 27.1, 28.25, 28.2, 28.6, 27.8, 27.8, 27.8, 27.8, 27.8, 27.8, 27.8, 27.8, 27.8, 27.8, 27.3, 27.6] 7 | local = [59.2, 63.15, 62.15, 62.65, 62.35, 64.0, 63.8, 60.65, 61.8, 61.5, 62.2, 63.35, 58.35, 60.2, 60.0, 60.1, 61.15, 62.8, 60.3, 61.4, 62.85, 61.2, 61.9, 62.1, 62.55, 62.65, 59.95, 60.1, 60.05, 60.02, 59.95, 60.1, 60.3, 59.95, 60.4, 60.2, 59.95, 60.05, 60.15, 59.9, 59.95, 59.98, 60.05, 59.95, 60.0, 60.0, 59.95, 59.95, 59.95, 59.95] 8 | scei = [64.9, 65.6, 67.25, 67.0, 66.85, 64.3, 63.9, 66.4, 65.05, 67.1, 66.1, 66.3, 65.45, 66.45, 65.15, 65.6, 65.85, 64.4, 63.85, 64.0, 61.65, 64.8, 66.05, 63.9, 65.25, 64.1, 63.35, 66.6, 62.05, 62.3, 65.3, 64.9, 66.0, 62.85, 64.05, 62.4, 63.1, 62.05, 66.6, 65.5, 61.35, 64.9, 62.6, 63.65, 63.3, 62.55, 60.55, 64.85, 62.4, 63.55] 9 | scei_async = [57.15, 58.4, 60.65, 63.25, 62.75, 64.55, 62.9, 63.05, 62.7, 63.05, 62.2, 62.15, 61.4, 61.0, 61.3, 61.15, 60.85, 62.8, 63.05, 61.6, 62.5, 62.15, 62.8, 62.4, 61.35, 61.2, 61.15, 61.05, 61.95, 61.1, 61.05, 61.6, 60.2, 61.4, 57.75, 58.0, 60.1, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05, 60.05] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-mnist-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 
| 5 | apfl = [95.5, 97.45, 97.35, 97.5, 97.35, 97.05, 97.8, 97.4, 97.3, 33.85, 93.95, 96.7, 97.35, 97.25, 97.7, 97.85, 98.0, 97.75, 97.5, 85.95, 98.05, 98.15, 98.85, 98.3, 98.15, 98.15, 98.5, 98.4, 98.1, 96.15, 98.2, 98.15, 98.4, 98.5, 98.5, 98.35, 98.35, 97.9, 93.75, 96.55, 98.05, 98.3, 98.55, 98.4, 98.45, 98.7, 98.35, 98.0, 90.85, 89.4] 6 | fedavg = [14.0, 80.1, 89.4, 91.1, 92.8, 94.25, 94.55, 95.75, 95.4, 95.45, 95.6, 96.45, 96.35, 96.45, 96.35, 96.9, 96.25, 96.85, 96.75, 96.7, 96.55, 96.65, 97.1, 97.0, 96.7, 96.75, 96.6, 96.9, 97.55, 97.05, 96.8, 97.35, 97.35, 97.35, 97.65, 97.7, 97.55, 97.15, 97.55, 97.25, 97.5, 97.55, 97.45, 97.0, 97.25, 97.35, 97.45, 97.7, 97.55, 97.6] 7 | local = [96.35, 97.45, 96.9, 96.3, 97.7, 96.85, 97.35, 97.6, 97.15, 96.9, 97.0, 96.35, 96.05, 96.05, 96.35, 96.9, 96.75, 96.65, 95.2, 97.45, 96.5, 96.4, 96.7, 96.95, 96.15, 96.4, 96.35, 96.05, 96.4, 96.35, 96.15, 96.2, 96.75, 96.25, 96.65, 96.7, 96.5, 96.0, 96.5, 96.5, 96.2, 96.2, 96.45, 96.5, 96.55, 96.65, 96.8, 96.6, 96.5, 96.7] 8 | scei = [96.2, 96.8, 97.85, 97.2, 97.75, 97.8, 97.8, 98.3, 98.05, 98.35, 98.05, 97.8, 98.45, 98.4, 98.45, 98.45, 98.2, 98.25, 98.5, 98.1, 98.5, 98.35, 98.35, 98.4, 98.4, 98.55, 98.6, 98.3, 98.15, 98.4, 98.35, 98.35, 98.25, 98.5, 98.2, 98.7, 98.35, 98.4, 98.35, 98.6, 98.45, 98.65, 98.55, 98.85, 98.25, 98.65, 98.55, 98.55, 98.55, 98.3] 9 | scei_async = [94.1, 96.7, 96.9, 97.5, 97.4, 98.05, 97.8, 97.75, 97.95, 98.0, 98.3, 98.15, 98.2, 98.05, 98.2, 98.05, 98.1, 98.2, 98.7, 98.4, 98.45, 98.1, 98.0, 98.55, 97.85, 98.3, 98.15, 98.45, 98.45, 98.15, 98.45, 98.4, 98.35, 98.5, 98.45, 98.0, 98.5, 98.65, 98.35, 98.9, 98.75, 98.75, 98.15, 98.5, 98.2, 98.7, 98.35, 98.55, 98.65, 98.35] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | 
-------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/resnet-cifar10-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [33.75, 34.5, 43.25, 53.75, 56.25, 64.5, 64.25, 66.25, 64.75, 65.75, 69.0, 69.0, 71.0, 67.25, 69.5, 67.25, 68.75, 70.5, 67.5, 71.25, 66.75, 68.75, 68.75, 69.0, 71.0, 71.75, 71.5, 69.0, 71.75, 71.0, 70.25, 71.25, 71.0, 70.75, 71.5, 68.75, 70.0, 71.25, 74.75, 73.5, 73.5, 73.25, 73.75, 71.75, 70.5, 72.25, 69.5, 72.25, 72.5, 70.5] 6 | scei010 = [26.0, 38.0, 43.5, 31.5, 53.5, 62.5, 59.5, 48.5, 67.0, 58.5, 76.0, 68.0, 66.5, 64.5, 72.0, 76.0, 74.0, 77.0, 74.0, 81.5, 79.0, 84.0, 83.0, 86.0, 78.0, 86.0, 85.5, 81.5, 82.5, 83.5, 82.5, 83.5, 82.5, 82.5, 84.5, 82.5, 84.0, 83.5, 84.5, 82.0, 85.5, 82.5, 84.0, 82.5, 83.5, 86.0, 84.5, 84.5, 83.0, 83.5] 7 | scei020 = [38.75, 55.75, 57.0, 68.0, 69.0, 73.0, 78.25, 76.25, 82.75, 78.75, 81.5, 77.25, 79.5, 83.0, 84.0, 81.5, 80.75, 83.25, 82.25, 84.75, 84.5, 84.75, 86.25, 84.5, 83.75, 84.0, 86.25, 85.5, 84.0, 85.75, 73.0, 84.25, 83.25, 84.75, 83.0, 83.75, 84.0, 83.75, 83.75, 83.5, 82.5, 84.25, 83.75, 83.5, 85.0, 83.5, 85.5, 83.5, 83.75, 82.0] 8 | scei050 = [40.83, 49.0, 64.17, 69.17, 71.17, 77.83, 74.33, 73.83, 80.5, 77.83, 79.5, 80.83, 78.17, 81.83, 79.67, 83.67, 83.0, 81.5, 78.17, 79.5, 83.33, 84.33, 85.0, 83.0, 81.67, 84.67, 84.0, 83.33, 84.83, 83.17, 85.67, 84.0, 84.33, 84.0, 84.0, 85.0, 85.17, 83.33, 84.0, 81.83, 81.0, 81.17, 82.17, 81.5, 83.17, 83.33, 84.33, 84.33, 83.0, 83.33] 9 | scei100 = [38.62, 39.5, 57.5, 64.25, 67.5, 71.12, 70.75, 72.75, 79.5, 80.25, 79.25, 79.0, 82.5, 79.12, 80.88, 85.75, 79.38, 82.25, 85.12, 83.25, 85.62, 83.88, 83.88, 85.38, 85.25, 84.38, 84.62, 83.62, 84.38, 82.62, 83.75, 84.88, 84.5, 85.12, 84.12, 84.38, 84.62, 83.75, 85.12, 84.38, 84.5, 84.62, 84.12, 83.25, 84.62, 84.12, 83.88, 84.12, 83.75, 83.12] 
10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/mlp-mnist-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [91.8, 92.0, 93.0, 92.8, 93.1, 93.3, 93.4, 93.0, 93.6, 93.5, 94.1, 94.4, 93.8, 93.6, 93.2, 93.9, 94.0, 93.8, 94.5, 94.0, 94.1, 94.7, 94.4, 94.1, 94.6, 94.3, 93.7, 93.9, 94.5, 94.1, 94.3, 93.9, 93.8, 94.1, 94.4, 94.0, 94.4, 94.7, 94.1, 94.5, 94.6, 94.5, 94.2, 94.5, 94.8, 94.3, 94.4, 94.1, 94.5, 94.5] 6 | scei010 = [93.9, 94.6, 94.9, 94.4, 95.2, 95.15, 95.4, 95.1, 95.4, 95.5, 95.1, 95.45, 95.3, 95.55, 95.55, 95.7, 95.7, 95.15, 95.8, 95.55, 95.75, 95.75, 95.95, 96.0, 95.9, 96.0, 95.9, 96.3, 95.95, 96.0, 95.65, 96.15, 96.05, 96.0, 95.85, 95.85, 95.9, 96.25, 96.25, 96.15, 96.2, 95.95, 96.25, 96.2, 96.0, 95.8, 96.4, 96.4, 96.1, 96.25] 7 | scei020 = [94.55, 94.62, 94.95, 95.35, 95.6, 95.42, 95.5, 95.68, 95.68, 95.68, 95.7, 95.75, 96.12, 95.98, 96.25, 96.2, 96.15, 96.02, 96.1, 96.12, 96.12, 96.35, 96.4, 96.62, 96.5, 96.3, 96.35, 96.42, 96.32, 96.4, 96.42, 96.38, 96.15, 96.32, 96.5, 96.5, 96.3, 96.65, 96.65, 96.52, 96.5, 96.35, 96.55, 96.75, 96.58, 96.6, 96.52, 96.78, 96.52, 96.6] 8 | scei050 = [95.02, 95.25, 95.25, 95.45, 95.62, 95.75, 95.97, 95.86, 96.02, 96.1, 95.87, 95.99, 96.28, 96.27, 96.34, 96.18, 96.26, 96.12, 96.2, 96.42, 96.25, 96.44, 96.4, 96.47, 96.49, 96.56, 96.58, 96.61, 96.6, 96.7, 96.53, 96.74, 96.5, 96.61, 96.69, 96.74, 96.7, 96.85, 96.74, 96.85, 96.81, 96.86, 96.76, 96.67, 96.92, 96.83, 96.96, 96.97, 96.91, 96.84] 9 | scei100 = [94.38, 94.84, 95.06, 95.34, 95.37, 95.59, 95.58, 95.76, 95.79, 95.86, 95.9, 96.0, 
95.96, 96.03, 96.24, 96.11, 96.32, 96.22, 96.26, 96.34, 96.28, 96.32, 96.49, 96.54, 96.58, 96.5, 96.54, 96.62, 96.64, 96.7, 96.68, 96.58, 96.64, 96.67, 96.85, 96.86, 96.82, 96.78, 96.9, 96.83, 96.84, 96.76, 96.91, 96.96, 96.93, 97.01, 96.94, 96.89, 97.04, 96.94] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/cnn-cifar10-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [35.1, 43.7, 48.5, 51.7, 56.6, 56.3, 55.5, 57.4, 57.3, 55.6, 57.6, 57.3, 56.7, 57.1, 56.4, 57.2, 58.2, 57.4, 57.5, 57.7, 57.2, 57.9, 57.4, 57.1, 56.5, 56.3, 57.0, 56.2, 56.7, 56.5, 56.7, 56.5, 56.0, 55.5, 55.4, 56.4, 55.5, 55.5, 56.2, 55.7, 55.4, 55.6, 55.0, 55.0, 55.0, 54.8, 54.6, 54.4, 54.3, 54.3] 6 | scei010 = [40.35, 50.6, 55.7, 59.2, 61.8, 63.0, 64.1, 63.3, 64.85, 64.65, 64.8, 64.95, 65.1, 65.05, 65.5, 66.0, 65.95, 65.3, 66.4, 66.3, 66.05, 64.95, 66.0, 65.2, 65.85, 66.0, 65.55, 67.0, 67.1, 65.6, 66.6, 66.55, 66.55, 66.3, 66.35, 66.2, 66.8, 66.35, 66.95, 66.6, 66.9, 67.1, 66.35, 66.05, 66.25, 66.8, 66.9, 66.75, 66.7, 66.95] 7 | scei020 = [55.95, 60.15, 61.65, 62.3, 62.18, 62.4, 61.98, 61.3, 62.82, 63.3, 64.6, 62.95, 63.3, 63.4, 64.18, 63.72, 64.72, 64.55, 65.32, 64.12, 64.3, 63.72, 63.65, 64.05, 63.85, 64.15, 63.9, 62.12, 62.75, 63.28, 62.25, 62.92, 63.4, 63.0, 62.48, 62.98, 62.25, 63.18, 62.25, 63.45, 62.65, 63.12, 63.28, 62.62, 63.0, 62.42, 62.55, 62.62, 62.68, 62.15] 8 | scei050 = [52.82, 58.74, 59.89, 61.4, 61.65, 61.89, 61.56, 62.87, 62.97, 63.35, 63.17, 62.98, 64.11, 63.45, 64.77, 63.98, 64.87, 65.37, 64.56, 65.48, 64.9, 65.59, 65.68, 65.87, 65.88, 
65.48, 65.49, 65.66, 65.82, 65.4, 65.96, 65.34, 66.29, 66.04, 66.16, 65.81, 65.19, 65.24, 65.45, 65.5, 65.4, 65.57, 65.25, 65.62, 65.46, 65.69, 65.77, 65.93, 64.84, 64.8] 9 | scei100 = [53.76, 59.79, 61.67, 62.27, 62.89, 63.4, 63.37, 64.19, 63.96, 63.91, 64.86, 64.78, 65.45, 65.44, 65.99, 65.99, 66.39, 66.27, 66.44, 67.13, 66.98, 67.28, 67.0, 67.38, 67.51, 67.18, 67.62, 67.5, 67.83, 67.74, 68.18, 67.92, 68.23, 67.78, 67.86, 67.46, 67.94, 68.04, 68.08, 68.14, 67.91, 67.95, 68.13, 68.16, 68.49, 68.34, 68.33, 68.21, 68.36, 68.24] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, True, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/cnn-mnist-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [96.6, 97.6, 97.6, 97.8, 98.2, 98.1, 98.0, 98.7, 98.3, 98.5, 97.7, 98.5, 98.0, 99.1, 98.8, 98.7, 98.6, 98.6, 98.5, 99.1, 98.5, 98.3, 98.6, 98.7, 99.0, 98.7, 98.8, 98.5, 98.9, 99.0, 98.8, 99.3, 98.9, 95.1, 98.2, 98.6, 98.7, 98.6, 98.4, 98.9, 98.7, 98.5, 98.9, 99.4, 98.9, 98.3, 98.4, 98.8, 99.0, 97.1] 6 | scei010 = [96.2, 96.8, 97.85, 97.2, 97.75, 97.8, 97.8, 98.3, 98.05, 98.35, 98.05, 97.8, 98.45, 98.4, 98.45, 98.45, 98.2, 98.25, 98.5, 98.1, 98.5, 98.35, 98.35, 98.4, 98.4, 98.55, 98.6, 98.3, 98.15, 98.4, 98.35, 98.35, 98.25, 98.5, 98.2, 98.7, 98.35, 98.4, 98.35, 98.6, 98.45, 98.65, 98.55, 98.85, 98.25, 98.65, 98.55, 98.55, 98.55, 98.3] 7 | scei020 = [96.1, 97.02, 97.4, 97.52, 98.15, 98.15, 97.82, 98.1, 98.42, 98.42, 98.38, 98.25, 98.45, 98.18, 98.7, 98.52, 98.82, 98.65, 98.72, 99.02, 98.8, 98.72, 98.7, 99.08, 98.72, 98.68, 98.92, 98.9, 98.98, 99.0, 98.95, 98.88, 98.72, 99.1, 98.7, 98.78, 99.08, 
99.0, 98.88, 98.82, 98.9, 99.02, 99.02, 98.8, 99.0, 98.95, 98.95, 99.12, 98.95, 98.95] 8 | scei050 = [96.25, 96.69, 97.55, 97.98, 98.05, 98.29, 98.4, 98.39, 98.5, 98.55, 98.54, 98.5, 98.58, 98.73, 98.85, 98.73, 98.83, 98.78, 98.7, 98.86, 98.88, 98.96, 98.94, 98.99, 98.81, 98.79, 98.96, 99.01, 98.94, 98.94, 98.99, 98.92, 99.02, 99.08, 99.14, 99.01, 99.0, 98.98, 99.1, 98.97, 99.06, 99.14, 99.1, 98.88, 98.96, 99.17, 99.09, 99.04, 98.91, 99.08] 9 | scei100 = [95.66, 96.01, 97.02, 97.58, 97.58, 97.92, 97.98, 98.15, 98.34, 98.42, 98.34, 98.39, 98.44, 98.5, 98.5, 98.58, 98.62, 98.63, 98.52, 98.78, 98.8, 98.74, 98.76, 98.7, 98.78, 98.76, 98.88, 98.9, 98.85, 98.84, 98.84, 98.9, 98.78, 98.82, 98.86, 98.87, 98.85, 98.86, 98.86, 98.84, 98.81, 98.8, 98.98, 98.97, 98.91, 98.9, 98.94, 98.95, 98.97, 98.98] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/cnn-cifar100-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [62.3, 64.3, 66.2, 65.1, 62.8, 65.5, 62.2, 63.9, 64.7, 64.8, 61.4, 61.7, 62.8, 60.9, 61.8, 59.2, 62.0, 63.4, 60.8, 58.4, 63.8, 60.3, 62.7, 63.0, 63.3, 61.2, 62.0, 61.4, 59.4, 60.0, 60.3, 59.8, 59.5, 60.3, 61.0, 57.2, 59.6, 52.0, 59.5, 60.2, 60.4, 59.5, 59.6, 59.0, 59.5, 58.1, 58.4, 60.2, 58.8, 59.1] 6 | scei010 = [64.9, 65.6, 67.25, 67.0, 66.85, 64.3, 63.9, 66.4, 65.05, 67.1, 66.1, 66.3, 65.45, 66.45, 65.15, 65.6, 65.85, 64.4, 63.85, 64.0, 61.65, 64.8, 66.05, 63.9, 65.25, 64.1, 63.35, 66.6, 62.05, 62.3, 65.3, 64.9, 66.0, 62.85, 64.05, 62.4, 63.1, 62.05, 66.6, 65.5, 61.35, 64.9, 62.6, 63.65, 63.3, 62.55, 60.55, 64.85, 62.4, 63.55] 
7 | scei020 = [62.92, 67.28, 65.75, 66.38, 66.18, 66.12, 65.55, 66.75, 65.47, 64.4, 65.6, 64.5, 65.65, 66.38, 64.6, 64.92, 65.68, 65.22, 66.25, 63.92, 64.05, 63.92, 64.03, 63.98, 63.85, 63.98, 63.72, 61.35, 63.35, 61.92, 62.6, 62.6, 62.18, 61.8, 62.12, 62.68, 61.45, 62.0, 60.35, 62.68, 62.72, 62.3, 60.82, 63.45, 59.12, 61.02, 64.28, 60.7, 64.05, 61.38] 8 | scei050 = [62.02, 64.75, 67.78, 67.39, 66.44, 66.04, 66.66, 65.9, 66.36, 65.96, 66.45, 65.75, 66.92, 65.89, 66.25, 66.59, 66.9, 67.26, 66.72, 66.06, 66.73, 66.12, 65.97, 65.82, 66.13, 65.52, 66.33, 65.59, 64.99, 66.34, 66.09, 65.89, 65.99, 64.69, 65.26, 65.15, 65.21, 66.13, 65.52, 66.33, 65.59, 64.99, 66.34, 66.09, 65.89, 65.99, 64.69, 65.26, 65.15, 65.21] 9 | scei100 = [62.42, 64.1, 65.94, 66.72, 65.61, 65.69, 66.1, 65.76, 65.52, 65.76, 65.85, 65.79, 65.81, 64.88, 65.22, 64.69, 64.64, 65.18, 64.87, 65.68, 65.65, 65.32, 64.96, 65.27, 65.59, 64.92, 64.78, 65.03, 65.55, 64.85, 65.03, 64.84, 65.48, 64.78, 64.71, 65.26, 64.98, 64.86, 64.99, 64.36, 65.32, 63.98, 65.02, 64.5, 64.72, 64.19, 64.28, 64.41, 64.38, 63.66] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-realworld-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [51.77, 13.01, 13.29, 13.65, 13.76, 13.23, 13.74, 13.2, 12.73, 50.97, 13.7, 13.0, 13.48, 13.6, 13.36, 13.6, 13.51, 13.49, 13.28, 50.14, 13.47, 13.43, 13.21, 13.57, 13.94, 13.03, 13.09, 13.68, 13.37, 48.69, 13.02, 13.33, 13.48, 13.12, 13.97, 13.9, 13.09, 13.46, 13.2, 50.26, 13.42, 13.37, 13.42, 13.31, 13.94, 13.15, 13.59, 13.28, 13.54, 19.93] 6 | fedavg = 
[21.17, 20.82, 21.25, 20.61, 21.2, 21.34, 20.94, 20.99, 20.85, 20.95, 20.53, 20.72, 21.04, 20.92, 20.87, 20.96, 20.77, 20.54, 21.1, 20.86, 20.82, 20.71, 20.65, 20.92, 20.77, 20.62, 21.13, 21.24, 20.86, 21.09, 21.14, 18.31, 18.46, 21.39, 20.58, 20.62, 20.59, 23.11, 20.75, 21.21, 23.46, 20.93, 21.07, 20.97, 21.07, 20.57, 20.69, 20.63, 21.11, 21.21] 7 | local = [9.6, 9.59, 9.35, 9.57, 9.14, 9.1, 9.25, 9.35, 9.31, 9.2, 9.48, 8.84, 9, 9.42, 9.41, 9.75, 8.88, 9.89, 9.65, 9.59, 9.02, 9.57, 9.48, 9.99, 9.61, 9.51, 9.16, 9.21, 9, 9.71, 9.62, 9.56, 9.2, 9.79, 9.94, 9.82, 9.89, 9.79, 9.83, 9.88, 9.98, 9.79, 9.65, 9.52, 9.46, 9.44, 9.35, 9.76, 9.68, 9.62] 8 | scei = [29.79, 29.16, 29.87, 29.29, 29.84, 29.55, 29.85, 29.38, 29.54, 29.48, 29.62, 29.89, 29.63, 29.29, 29.53, 28.99, 29.16, 29.54, 28.78, 28.5, 29.25, 29.09, 29.24, 29.38, 28.89, 29.13, 29.11, 29.19, 28.74, 28.91, 28.84, 29.02, 28.96, 29.12, 29.18, 29.52, 29.26, 29.47, 29.67, 29.05, 29.25, 29.14, 29.32, 28.96, 29.15, 29.32, 28.75, 29.25, 29.36, 29.44] 9 | scei_async = [23.86, 23.58, 23.83, 23.37, 23.23, 23.64, 23.96, 23.85, 23.63, 23.32, 23.62, 23.01, 23.42, 23.18, 23.91, 23.05, 23.3, 23.71, 23.82, 23.22, 23.98, 23.4, 23.67, 23.18, 23.51, 23.17, 23.99, 23.02, 22.34, 23.07, 22.21, 23, 22.46, 23.77, 22.79, 22.75, 21.76, 22.85, 22.99, 22.96, 22.72, 22.22, 23.32, 22.54, 23.33, 22.39, 23.25, 23.16, 23.05, 23.15] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-realworld-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [8.53, 71.88, 79.8, 81.4, 82.5, 82.88, 83.1, 82.58, 82.78, 83.45, 82.42, 83.22, 
84.42, 84.05, 85.12, 84.28, 84.03, 84.08, 84.35, 84.58, 84.75, 85.0, 85.2, 84.97, 84.83, 84.45, 85.42, 85.17, 85.03, 84.7, 84.75, 84.58, 85.15, 85.88, 85.65, 85.17, 85.83, 85.97, 84.9, 85.1, 85.83, 85.58, 85.8, 85.75, 85.9, 86.55, 86.25, 86.25, 86.08, 85.9] 6 | fedavg = [51.12, 57.62, 62.95, 68.6, 70.03, 69.95, 71.45, 72.2, 71.7, 72.67, 72.55, 72.53, 73.85, 74.33, 74.12, 74.03, 75.1, 74.5, 75.25, 75.42, 74.92, 75.62, 75.58, 75.03, 75.22, 76.58, 76.42, 75.35, 76.33, 76.8, 76.42, 76.53, 76.72, 77.1, 76.62, 76.85, 77.1, 76.83, 77.33, 77.25, 76.62, 77.65, 77.42, 77.03, 77.22, 77.45, 77.53, 77.08, 77.15, 77.2] 7 | local = [14.68, 78.47, 80.78, 81.95, 82.88, 83.25, 83.85, 82.88, 83.55, 83.83, 83.8, 83.85, 83.15, 84.17, 84.25, 84.12, 84.83, 84.45, 83.95, 84.88, 84.5, 85.4, 84.8, 84.33, 84.58, 84.95, 84.47, 85.1, 85.28, 85.85, 85.2, 84.83, 84.88, 85.38, 85.25, 85.2, 85.22, 85.03, 85.2, 84.85, 85.58, 85.55, 85.85, 85.53, 85.55, 85.05, 85.33, 84.9, 85.2, 84.95] 8 | scei = [79.53, 82.25, 83.08, 83.47, 84.1, 84.38, 83.97, 84.58, 85.85, 86.15, 85.4, 86.6, 86.0, 86.75, 86.28, 86.88, 86.33, 87.25, 87.08, 87.47, 87.53, 87.7, 88.03, 88.17, 88.12, 87.78, 87.97, 88.78, 88.35, 88.17, 87.75, 88.33, 88.6, 89.03, 88.65, 88.25, 88.03, 88.55, 88.78, 89.03, 88.9, 88.28, 88.85, 88.95, 89.17, 88.67, 88.9, 88.67, 89.22, 89.2] 9 | scei_async = [74.25, 78.45, 78.82, 80.32, 80.6, 79.2, 82.8, 82.75, 82.0, 82.9, 81.28, 82.5, 82.5, 82.6, 82.03, 82.65, 82.25, 82.03, 82.5, 82.78, 82.03, 82.9, 82.28, 82.85, 82.95, 82.17, 82.67, 82.9, 82.67, 82.5, 82.35, 82.17, 82.75, 82.33, 82.6, 82.03, 82.65, 82.25, 82.03, 82.55, 82.78, 82.03, 82.9, 82.28, 82.85, 82.95, 82.17, 82.67, 82.9, 82.67] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- 
/federated-learning/plot/acc-nodes/cnn-realworld-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [78.6, 80.75, 82.8, 84.35, 84.8, 81.45, 84.05, 80.95, 85.5, 84.0, 86.2, 83.55, 85.7, 84.2, 86.4, 85.5, 85.0, 85.2, 85.55, 86.7, 86.2, 84.4, 86.2, 86.4, 85.5, 84.5, 86.2, 86.5, 86.0, 86.5, 85.50, 86.7, 86.5, 85.0, 86.5, 86.4, 86.5, 86.0, 86.2, 87.5, 85.0, 86.2, 87.25, 86.37, 86.3, 87.4, 86.32, 85.4, 86.5, 86.0] 6 | scei010 = [79.66, 80.45, 81.7, 83.05, 82.58, 83.5, 83.82, 83.89, 85.08, 86.02, 85.76, 86.52, 86.29, 86.64, 86.25, 86.05, 86.01, 86.46, 86.34, 86.57, 86.11, 86.64, 87.2, 86.87, 86.66, 87.45, 86.51, 86.97, 87.22, 86.62, 86.76, 87.95, 87.24, 87.25, 87.27, 86.65, 86.7, 87.2, 88.09, 88.4, 88.27, 88.46, 88.86, 88.12, 88.49, 88.35, 87.43, 87.95, 87.86, 88.01] 7 | scei020 = [79.53, 82.25, 83.08, 83.47, 84.1, 84.38, 83.97, 84.58, 85.85, 86.15, 85.4, 86.6, 86.0, 86.75, 86.28, 86.88, 86.33, 87.25, 87.08, 87.47, 87.53, 87.7, 88.03, 88.17, 88.12, 87.78, 87.97, 88.78, 88.35, 88.17, 87.75, 88.33, 88.6, 89.03, 88.65, 88.25, 88.03, 88.55, 88.78, 89.03, 88.9, 88.28, 88.85, 88.95, 89.17, 88.67, 88.9, 88.67, 89.22, 89.2] 8 | scei050 = [76.31, 79.34, 80.78, 81.43, 83.05, 83.84, 84.18, 85.0, 85.75, 85.5, 86.0, 86.75, 86.75, 87.75, 87.75, 87.0, 87.25, 87.5, 88.75, 88.75, 88.0, 88.75, 88.75, 89.0, 89.5 , 89.0 , 89.75, 89.5, 88.75, 88.0, 89.25, 89.25, 88.5, 89.5, 89.5, 89.5, 89.0, 89.0, 89.25, 89.5, 89.75, 89.5, 89.75, 89.5, 90.2, 90.0, 91.0, 91.0, 90.25, 90.5] 9 | scei100 = [78.37, 79.48, 80.65, 81.14, 82.18, 83.21, 84.17, 85.08, 85.56, 86.19, 86.44, 87.01, 87.06, 87.06, 87.43, 87.7, 87.54, 87.72, 88.28, 88.12, 88.34, 88.5, 89.09, 89.26, 89.47, 89.07, 89.12, 89.23, 89.73, 89.63, 90.02, 89.73, 89.72, 89.28, 89.92, 89.25, 89.58, 89.01, 89.15, 89.18, 88.97, 89.48, 89.42, 90.05, 90.27, 89.89, 89.82, 90.01, 89.74, 90.01] 10 | 11 | save_path = None 12 | if 
len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/cnn-uci-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [56.09, 15.39, 15.23, 15.26, 15.77, 15.32, 15.23, 15.19, 15.13, 53.78, 15.21, 15.75, 15.1, 15.39, 15.23, 15.82, 15.95, 15.04, 15.47, 55.05, 15.91, 15.18, 15.52, 15.15, 15.57, 15.41, 15.8, 15.51, 16.69, 56.01, 15.27, 15.46, 15.12, 15.57, 15.1, 15.24, 15.8, 15.4, 14.78, 47.33, 15.18, 15.7, 15.4, 15.98, 15.12, 14.68, 15.13, 14.53, 15.58, 27.75] 6 | fedavg = [22.53, 22.13, 22.6, 22.68, 22.3, 22.87, 22.82, 22.8, 22.6, 22.78, 22.54, 22.73, 22.06, 22.15, 22.48, 22.41, 22.41, 22.58, 22.94, 22.91, 22.37, 22.2, 22.17, 22.43, 21.73, 22.3, 22.31, 22.84, 23.69, 22.63, 22.27, 22.38, 22.63, 22.38, 22.32, 22.54, 22.27, 21.88, 22.07, 22.41, 22.77, 22.75, 22.82, 22.28, 22.12, 22.45, 22.56, 22.27, 22.42, 22.78] 7 | local = [10.02, 11, 11.24, 10.28, 10.5, 10.28, 10.38, 10.68, 10.4, 11.4, 10.12, 10.3, 10.1, 11.9, 10.74, 11.68, 11.4, 11.04, 10.82, 10.7, 10.72, 11.66, 11.14, 10.6, 10.64, 10.6, 10.22, 10.96, 10.46, 10.24, 10.58, 10.82, 10.6, 10.46, 9.98, 9.84, 9.66, 9.46, 9.36, 9.06, 9.72, 9.6, 9.24, 10.02, 9.84, 9.76, 9.5, 9.44, 9.3, 9.26] 8 | scei = [31.83, 31.17, 31.91, 31.37, 31.85, 31.77, 31.52, 31.24, 31.53, 31.33, 31.68, 31.93, 31.52, 31.34, 31.08, 31.05, 31.99, 31.93, 31.29, 31.72, 31.4, 31.85, 32.04, 32.1, 32.6, 32.13, 31.82, 31.94, 32.6, 32.05, 32.11, 32.2, 31.75, 31.68, 31.38, 31.14, 31.45, 31.46, 31.42, 31.1, 31.55, 31.73, 31.44, 31.56, 31.25, 31.03, 31.36, 31.59, 31.46, 31.6] 9 | scei_async = [25.25, 24.97, 24.23, 24.22, 24.95, 25.01, 24.6, 24.88, 24.79, 24.61, 24.31, 
24.82, 24.68, 24.38, 24.15, 24.32, 24.13, 24.02, 24.18, 24.65, 24.63, 24.09, 24.71, 24.51, 24.98, 24.07, 25.28, 24.85, 23.94, 24.52, 24.9, 25.53, 25.06, 24.78, 24.58, 23.03, 24.93, 24.71, 24.77, 24.82, 24.14, 24.05, 24.79, 23.98, 24.24, 24.82, 22.62, 23.61, 26.1, 24.25] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-sota/cnn-uci-acc-sota.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc 4 | 5 | apfl = [15.06, 90.83, 91.69, 92.89, 92.96, 93.29, 93.25, 93.34, 93.15, 93.8, 94.06, 93.67, 93.94, 93.99, 93.84, 93.9, 94.08, 93.86, 94.11, 94.0, 93.62, 93.72, 93.86, 94.04, 94.16, 93.92, 94.08, 94.08, 93.91, 94.04, 93.95, 93.99, 93.96, 93.83, 93.84, 93.91, 93.86, 93.86, 94.01, 93.84, 93.94, 93.86, 93.78, 93.85, 93.85, 93.88, 94.2, 93.81, 93.78, 93.8] 6 | fedavg = [85.01, 87.5, 88.84, 89.81, 89.89, 90.28, 90.88, 90.71, 90.71, 91.34, 91.1, 91.46, 91.81, 91.85, 91.84, 91.55, 91.88, 91.97, 92.35, 92.25, 92.44, 92.06, 92.14, 92.44, 92.15, 92.29, 92.22, 91.89, 92.14, 91.89, 92.1, 91.8, 92.03, 92.22, 91.84, 91.67, 92.17, 91.64, 91.71, 91.97, 91.69, 91.88, 91.56, 91.54, 91.81, 91.91, 91.46, 91.66, 91.49, 91.66] 7 | local = [22.75, 92.12, 92.53, 92.45, 92.67, 92.85, 93.09, 92.89, 93.06, 93.64, 93.45, 93.12, 93.4, 93.45, 93.42, 93.19, 93.64, 93.34, 93.42, 93.58, 93.09, 93.39, 93.44, 93.55, 93.61, 93.11, 93.1, 93.3, 93.28, 93.53, 93.49, 93.4, 93.04, 93.36, 93.22, 93.38, 93.42, 93.54, 93.09, 93.28, 93.28, 93.38, 93.51, 93.36, 93.39, 93.39, 93.3, 93.38, 93.26, 93.58] 8 | scei = [93.41, 94.05, 94.5, 94.25, 94.51, 94.3, 94.19, 94.45, 94.81, 94.79, 94.91, 94.85, 93.91, 94.81, 94.9, 95.33, 95.1, 
95.14, 95.1, 95.11, 95.51, 95.54, 95.35, 95.54, 95.81, 95.61, 95.58, 95.35, 95.4, 95.61, 95.66, 95.8, 95.39, 95.78, 95.71, 95.35, 95.65, 95.71, 95.71, 95.71, 95.86, 95.54, 95.55, 95.71, 95.54, 95.8, 95.72, 95.36, 95.6, 95.56] 9 | scei_async = [92.48, 93.29, 93.1, 92.98, 93.06, 93.56, 93.75, 93.52, 92.95, 93.96, 93.89, 94.44, 93.9, 94.46, 94.56, 94.14, 94.85, 94.71, 94.01, 94.52, 94.38, 93.89, 94.44, 93.9, 94.46, 94.56, 94.14, 94.85, 94.71, 94.01, 94.6, 94.61, 95.01, 95.04, 94.85, 95.04, 95.31, 95.11, 95.08, 94.85, 94.9, 95.11, 95.16, 95.3, 94.89, 95.28, 95.21, 94.85, 95.15, 95.21] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/cnn-imagenet-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [21.43, 31.62, 42.62, 44.74, 50.86, 47.44, 49.05, 51.06, 49.9, 51.3, 48.68, 49.44, 49.46, 45.87, 51.29, 51.21, 52.98, 48.06, 50.89, 50.21, 53.95, 48.84, 51.71, 49.29, 48.79, 52.37, 48.59, 52.02, 51.24, 47.6, 46.82, 49.9, 46.22, 43.94, 47.77, 49.44, 46.71, 45.18, 47.14, 45.17, 50.22, 52.84, 51.48, 41.43, 48.95, 48.3, 50.03, 52.75, 47.62, 51.08] 6 | scei010 = [26.48, 30.1, 36.39, 39.57, 40.97, 44.4, 46.93, 47.99, 49.59, 48.79, 50.21, 51.24, 50.63, 49.66, 52.19, 53.24, 51.93, 51.47, 54.3, 51.45, 51.6, 51.85, 50.91, 51.54, 52.7, 51.61, 51.48, 52.77, 50.68, 52.27, 52.61, 52.46, 53.53, 49.46, 52.02, 50.2, 50.13, 50.73, 51.32, 51.3, 51.47, 50.0, 50.22, 49.77, 50.84, 49.73, 49.26, 50.27, 50.24, 50.15] 7 | scei020 = [28.25, 29.3, 34.3, 35.2, 38.95, 41.45, 41.4, 40.3, 44.3, 47.15, 47.05, 46.6, 49.95, 52.1, 52.25, 52.35, 52.9, 55.0, 53.35, 54.3, 54.25, 55.55, 
55.5, 57.05, 56.0, 55.95, 54.9, 54.35, 55.5, 56.05, 56.1, 55.1, 55.45, 54.5, 54.1, 54.6, 54.3, 53.4, 55.5, 54.05, 53.4, 53.6, 54.4, 54.35, 53.05, 54.6, 54.25, 53.7, 53.45, 52.85] 8 | scei050 = [32.79, 41.02, 45.5, 46.99, 48.69, 50.69, 51.14, 53.26, 54.65, 54.65, 54.97, 54.93, 55.08, 57.11, 56.86, 57.35, 57.45, 57.01, 57.51, 56.76, 57.24, 57.91, 58.31, 57.64, 56.96, 57.01, 58.1, 57.72, 57.23, 57.01, 57.51, 56.59, 57.79, 56.43, 56.47, 56.72, 56.7, 56.35, 56.74, 56.04, 56.57, 56.7, 55.53, 56.18, 56.0, 56.16, 55.99, 55.6, 55.83, 55.37] 9 | scei100 = [35.49, 43.39, 45.3, 47.9, 49.66, 51.73, 52.17, 53.24, 54.19, 55.11, 55.83, 55.67, 56.55, 57.84, 58.35, 58.03, 59.08, 59.11, 58.36, 58.68, 58.84, 59.42, 58.96, 59.12, 58.2, 58.54, 58.82, 58.48, 59.1, 59.27, 59.11, 58.42, 58.67, 57.99, 57.67, 57.27, 57.17, 57.82, 58.06, 58.53, 58.15, 57.82, 57.26, 57.61, 57.77, 57.27, 57.93, 58.53, 57.58, 58.11] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-nodes/cnn-uci-acc-nodes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_nodes 4 | 5 | scei005 = [91.88, 92.32, 92.52, 92.9, 93.95, 93.1, 93.4, 91.5, 93.75, 93.8, 94.1, 94.28, 92.6, 94.28, 94.12, 94.92, 93.22, 94.85, 94.85, 94.55, 94.4, 95.12, 94.52, 94.82, 94.4, 95.08, 95.1, 93.45, 94.75, 94.8, 94.1, 94.28, 94.6, 94.28, 94.12, 93.92, 94.22, 93.85, 94.85, 94.55, 94.4, 95.12, 94.52, 94.82, 94.4, 95.08, 95.1, 93.45, 95.08, 95.1] 6 | scei010 = [93.41, 94.05, 94.5, 94.25, 94.51, 94.3, 94.19, 94.45, 94.81, 94.79, 94.91, 94.85, 93.91, 94.81, 94.9, 95.33, 95.1, 95.14, 95.1, 95.11, 95.51, 95.54, 95.35, 95.54, 95.81, 95.61, 95.58, 95.35, 95.4, 
95.61, 95.66, 95.8, 95.39, 95.78, 95.71, 95.35, 95.65, 95.71, 95.71, 95.71, 95.86, 95.54, 95.55, 95.71, 95.54, 95.8, 95.72, 95.36, 95.6, 95.56] 7 | scei020 = [93.13, 94.53, 94.58, 94.92, 95.07, 95.24, 94.9, 95.05, 94.88, 95.01, 95.53, 94.86, 95.13, 95.09, 95.31, 95.41, 95.21, 95.61, 95.6, 95.55, 95.74, 95.74, 95.85, 95.7, 95.84, 95.7, 95.83, 95.76, 95.69, 95.67, 95.55, 95.54, 95.63, 95.46, 95.55, 95.93, 95.73, 95.79, 95.46, 95.68, 95.6, 95.78, 95.8, 95.8, 95.92, 95.71, 95.64, 95.73, 95.5, 95.57] 8 | scei050 = [93.8, 94.51, 95.19, 95.22, 95.56, 95.59, 95.56, 95.68, 95.47, 95.53, 95.62, 95.58, 95.47, 95.52, 95.65, 95.61, 95.51, 95.72, 95.67, 95.78, 95.83, 95.86, 95.89, 95.8, 95.83, 95.93, 96.05, 96.03, 95.91, 96.13, 96.08, 96.18, 96.16, 95.94, 96.08, 96.17, 96.09, 96.07, 95.92, 96.08, 96.15, 95.93, 96.01, 95.92, 95.8, 96.0, 96.12, 95.98, 96.18, 96.17] 9 | scei100 = [94.08, 94.44, 95.59, 95.48, 95.65, 96.01, 95.69, 95.72, 95.78, 95.85, 95.8, 95.71, 96.02, 95.8, 95.9, 95.92, 95.88, 96.05, 96.09, 96.17, 96.14, 96.03, 96.12, 96.23, 96.16, 96.2, 96.07, 96.18, 96.12, 95.99, 96.11, 95.94, 96.11, 96.19, 96.12, 96.2, 96.19, 96.08, 96.0, 96.2, 96.15, 96.15, 95.95, 96.21, 95.9, 96.24, 96.2, 96.09, 96.21, 96.12] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_round_acc_nodes("", scei005, scei010, scei020, scei050, scei100, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/plot/cost-overall/resnet-cifar10-cost-overall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_time_cost 4 | 5 | apfl = [102.4, 20.1, 20.86, 21.85, 20.94, 20.8, 20.48, 21.25, 21.84, 100.45, 19.78, 19.83, 20.24, 21.63, 21.3, 20.48, 20.17, 20.45, 21.63, 98.7, 21.21, 22.1, 21.81, 21.56, 20.49, 20.69, 21.64, 21.02, 20.73, 105.9, 20.22, 19.69, 
20.73, 20.4, 20.44, 21.09, 21.31, 21.41, 21.54, 100.35, 19.98, 18.51, 19.29, 18.68, 18.86, 18.87, 18.34, 17.9, 18.85, 38.9] 6 | fedavg = [37.94, 37.55, 37.68, 37.79, 37.9, 37.08, 37.96, 37.75, 37.33, 37.46, 37.89, 37.25, 38.03, 36.56, 37.64, 37.19, 37.28, 37.23, 37.48, 37.89, 35.56, 35.79, 37.7, 37.4, 37.19, 37.21, 36.38, 37.03, 37.3, 37.59, 37.77, 37.87, 36.96, 36.2, 37.12, 37.39, 37.72, 36.47, 36.77, 37.23, 37.41, 37.92, 37.42, 37.31, 37.73, 37.19, 37.53, 37.84, 37.65, 37.99] 7 | local = [24.63, 20.71, 21.74, 20.92, 22.47, 21.78, 19.73, 21.75, 22.24, 21.55, 24.68, 19.9, 21.27, 21.28, 22.93, 20.42, 21.14, 22.01, 21.59, 19.92, 20.36, 21.0, 20.85, 21.54, 21.91, 19.1, 22.5, 22.09, 21.93, 21.29, 19.94, 21.04, 20.87, 21.44, 21.9, 22.08, 21.81, 20.43, 20.57, 22.34, 20.51, 22.45, 19.95, 23.1, 20.68, 21.23, 20.41, 22.83, 20.64, 21.72] 8 | scei = [63.32, 60.49, 60.45, 62.2, 62.29, 62.06, 62.06, 60.97, 60.17, 61.66, 60.16, 60.38, 61.08, 61.68, 62.28, 60.59, 61.22, 63.27, 63.79, 61.32, 63.19, 61.61, 62.3, 60.62, 63.05, 62.21, 61.7, 60.67, 61.86, 60.71, 61.28, 63.79, 61.32, 63.19, 62.21, 62.3, 63.05, 60.67, 62.3, 63.29, 60.49, 63.27, 63.19, 60.71, 61.61, 62.3, 61.7, 63.2, 63.29, 62.06] 9 | scei_async = [48.68, 40.84, 41.46, 40.69, 42.74, 41.2, 40.05, 40.82, 41.45, 40.61, 42.0, 41.93, 41.61, 40.74, 40.55, 41.55, 40.33, 40.95, 40.34, 40.19, 41.07, 41.12, 40.55, 42.05, 39.86, 41.94, 44.64, 40.69, 41.18, 42.43, 40.72, 39.67, 41.85, 41.95, 39.55, 41.98, 43.17, 41.06, 40.3, 41.65, 42.64, 40.37, 42.23, 41.38, 41.86, 41.78, 40.65, 40.41, 40.8, 40.47] 10 | 11 | save_path = None 12 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 13 | save_path = sys.argv[2] 14 | 15 | plot_time_cost("", scei, scei_async, apfl, fedavg, local, False, False, save_path, plot_size="4") 16 | -------------------------------------------------------------------------------- /federated-learning/models/Test.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def test_img(net_g, my_dataset, test_indices, local_test_bs, device): 6 | net_g.eval() 7 | # testing 8 | test_loss = 0 9 | correct = 0 10 | data_loader = my_dataset.load_test_dataset(test_indices, local_test_bs) 11 | for idx, (data, target) in enumerate(data_loader): 12 | data = data.detach().clone().type(torch.FloatTensor) 13 | if device != torch.device('cpu'): 14 | data, target = data.to(device), target.to(device) 15 | log_probs = net_g(data) 16 | # sum up batch loss 17 | test_loss += F.cross_entropy(log_probs, target, reduction='sum') 18 | # get the index of the max log-probability 19 | y_pred = log_probs.data.max(1, keepdim=True)[1] 20 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 21 | 22 | # test_loss /= len(data_loader.dataset) 23 | # accuracy = 100.00 * correct / len(data_loader.dataset) 24 | return correct, test_loss 25 | 26 | 27 | def test_img_total(net_g, my_dataset, idx_list, local_test_bs, device): 28 | accuracy_list = [] 29 | test_loss_list = [] 30 | correct_test_local = 0 31 | loss_test_local = 0 32 | for i in range(len(idx_list)): 33 | correct_test, loss_test = test_img(net_g, my_dataset, idx_list[i], local_test_bs, device) 34 | if i == 0: 35 | correct_test_local = correct_test 36 | accuracy_local = 100.0 * correct_test_local / len(idx_list[0]) 37 | accuracy_list.append(accuracy_local) 38 | loss_test_local = loss_test 39 | loss_local = loss_test_local / len(idx_list[0]) 40 | test_loss_list.append(loss_local) 41 | else: 42 | accuracy_skew = 100.0 * (correct_test_local + correct_test) / (len(idx_list[0]) + len(idx_list[i])) 43 | accuracy_list.append(accuracy_skew) 44 | loss_skew = (loss_test_local + loss_test) / (len(idx_list[0]) + len(idx_list[i])) 45 | test_loss_list.append(loss_skew) 46 | 47 | return accuracy_list, test_loss_list 48 | -------------------------------------------------------------------------------- 
/federated-learning/plot/acc-alpha/resnet-cifar10-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [26.0, 38.0, 43.5, 31.5, 53.5, 62.5, 59.5, 48.5, 67.0, 58.5, 76.0, 68.0, 66.5, 64.5, 72.0, 76.0, 74.0, 77.0, 74.0, 81.5, 79.0, 84.0, 83.0, 86.0, 78.0, 86.0, 85.5, 81.5, 82.5, 83.5, 82.5, 83.5, 82.5, 82.5, 84.5, 82.5, 84.0, 83.5, 84.5, 82.0, 85.5, 82.5, 84.0, 82.5, 83.5, 86.0, 84.5, 84.5, 83.0, 83.5] 6 | scei_025 = [37.5, 43.5, 49.0, 49.0, 50.0, 45.0, 53.5, 51.0, 50.0, 53.5, 51.5, 55.0, 54.0, 60.0, 57.5, 56.0, 55.5, 55.5, 54.0, 59.0, 59.5, 59.0, 60.0, 60.0, 62.5, 59.0, 58.5, 58.5, 59.0, 58.0, 61.0, 59.0, 59.0, 59.5, 57.5, 60.0, 60.0, 60.0, 58.0, 58.5, 58.5, 59.5, 60.0, 59.5, 59.5, 61.0, 62.0, 59.0, 59.0, 59.5] 7 | scei_050 = [37.0, 41.5, 43.0, 55.5, 52.5, 50.0, 49.5, 60.0, 50.0, 61.0, 55.0, 58.5, 55.0, 58.0, 59.5, 57.0, 59.0, 62.5, 61.5, 61.0, 65.0, 64.0, 63.0, 58.5, 62.5, 63.0, 61.0, 64.0, 63.5, 60.0, 61.0, 63.0, 63.5, 65.5, 62.0, 64.5, 61.5, 65.5, 60.5, 63.5, 62.5, 60.5, 61.5, 63.5, 63.0, 63.0, 63.5, 60.5, 63.5, 63.5] 8 | scei_075 = [44.0, 50.5, 51.0, 43.5, 48.0, 55.0, 53.5, 60.0, 58.0, 54.0, 65.0, 61.5, 67.0, 63.0, 66.5, 64.0, 66.5, 63.0, 64.5, 64.0, 68.5, 65.5, 65.5, 66.5, 66.0, 65.0, 66.0, 65.0, 66.5, 65.0, 67.5, 63.5, 63.0, 64.0, 63.5, 62.0, 64.0, 64.0, 63.5, 64.5, 62.5, 63.0, 64.0, 64.0, 67.0, 64.0, 63.0, 66.0, 65.0, 66.0] 9 | fedavg = [29.0, 38.5, 44.0, 53.5, 53.0, 46.5, 52.0, 62.0, 62.0, 57.0, 61.5, 58.5, 62.0, 52.5, 68.0, 58.5, 67.0, 69.5, 66.5, 66.0, 63.5, 66.5, 65.5, 59.5, 65.0, 64.0, 65.0, 65.5, 65.0, 65.5, 65.5, 65.0, 66.0, 64.0, 65.0, 62.5, 61.5, 66.0, 66.5, 70.5, 69.5, 69.5, 70.0, 67.5, 69.0, 68.5, 69.5, 68.5, 66.0, 64.5] 10 | local = [35.5, 48.5, 51.0, 44.5, 55.0, 57.5, 60.5, 60.0, 59.5, 61.0, 63.5, 59.0, 62.5, 64.0, 60.5, 63.0, 63.5, 59.0, 61.5, 57.0, 62.0, 62.5, 66.0, 64.5, 67.5, 65.5, 66.5, 66.0, 67.0, 66.5, 67.0, 71.0, 
64.0, 67.0, 66.5, 69.0, 68.5, 63.0, 65.5, 66.5, 66.0, 63.5, 67.5, 66.5, 65.5, 65.0, 69.0, 68.5, 68.5, 68.5] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /federated-learning/utils/CentralStore.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | 4 | from utils.util import ColoredLogger 5 | 6 | lock = threading.Lock() 7 | 8 | logging.setLoggerClass(ColoredLogger) 9 | logger = logging.getLogger("CountStore") 10 | 11 | 12 | class NextRoundCount: 13 | def __init__(self): 14 | self.next_round_count_num = 0 15 | 16 | def add_count(self, count_target): 17 | reach_target = False 18 | lock.acquire() 19 | self.next_round_count_num += 1 20 | if self.next_round_count_num == count_target: 21 | reach_target = True 22 | lock.release() 23 | logger.debug("Added next_round_count, now: {}".format(self.next_round_count_num)) 24 | return reach_target 25 | 26 | def reset(self): 27 | lock.acquire() 28 | self.next_round_count_num = 0 29 | lock.release() 30 | logger.debug("Reset next_round_count, now: {}".format(self.next_round_count_num)) 31 | 32 | 33 | class ShutdownCount: 34 | def __init__(self): 35 | self.shutdown_count_num = 0 36 | 37 | def add_count(self, count_target): 38 | reach_target = False 39 | lock.acquire() 40 | self.shutdown_count_num += 1 41 | if self.shutdown_count_num == count_target: 42 | reach_target = True 43 | lock.release() 44 | logger.debug("Added shutdown_count_num, now: {}".format(self.shutdown_count_num)) 45 | return reach_target 46 | 47 | def reset(self): 48 | lock.acquire() 49 | self.shutdown_count_num = 0 50 | lock.release() 51 | logger.debug("Reset shutdown_count_num, now: 
{}".format(self.shutdown_count_num)) 52 | 53 | 54 | class IPCount: 55 | def __init__(self): 56 | self.ipMap = {} 57 | self.uuid = 0 58 | 59 | def get_new_id(self): 60 | lock.acquire() 61 | self.uuid += 1 62 | new_id = self.uuid 63 | logger.debug("new id: {}".format(new_id)) 64 | lock.release() 65 | return new_id 66 | 67 | def get_keys(self): 68 | return self.ipMap.keys() 69 | 70 | def get_map(self, key): 71 | return self.ipMap[key] 72 | 73 | def set_map(self, key, value): 74 | lock.acquire() 75 | self.ipMap[key] = value 76 | lock.release() 77 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-imagenet-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [28.25, 29.3, 34.3, 35.2, 38.95, 41.45, 41.4, 40.3, 44.3, 47.15, 47.05, 46.6, 49.95, 52.1, 52.25, 52.35, 52.9, 55.0, 53.35, 54.3, 54.25, 55.55, 55.5, 57.05, 56.0, 55.95, 54.9, 54.35, 55.5, 56.05, 56.1, 55.1, 55.45, 54.5, 54.1, 54.6, 54.3, 53.4, 55.5, 54.05, 53.4, 53.6, 54.4, 54.35, 53.05, 54.6, 54.25, 53.7, 53.45, 52.85] 6 | scei_025 = [26.15, 34.2, 37.0, 41.95, 42.5, 44.1, 49.1, 48.1, 48.0, 48.75, 46.85, 45.45, 45.15, 43.8, 45.2, 43.5, 41.75, 43.5, 42.75, 42.35, 44.25, 44.1, 40.95, 41.4, 41.05, 39.45, 39.9, 39.35, 38.05, 37.75, 37.6, 36.4, 36.45, 36.45, 34.55, 36.15, 35.65, 34.9, 35.05, 34.25, 33.45, 33.35, 35.5, 33.35, 33.95, 32.75, 30.75, 33.45, 32.7, 33.6] 7 | scei_050 = [25.3, 25.25, 25.4, 25.0, 27.25, 32.95, 33.75, 38.85, 38.65, 40.75, 44.3, 44.65, 46.8, 48.05, 47.65, 50.0, 51.35, 51.55, 50.5, 50.3, 51.0, 51.45, 50.45, 50.05, 48.45, 50.35, 50.2, 49.25, 48.7, 49.5, 48.6, 48.6, 49.65, 48.4, 47.8, 46.05, 47.0, 47.7, 48.45, 48.05, 46.7, 46.15, 45.7, 46.65, 46.2, 45.6, 46.65, 44.75, 45.2, 44.85] 8 | scei_075 = [27.7, 30.05, 35.95, 39.65, 41.8, 43.25, 45.4, 46.65, 47.25, 50.6, 47.1, 48.55, 49.65, 49.95, 51.65, 51.0, 51.35, 
51.35, 51.1, 50.15, 49.8, 50.75, 50.8, 50.75, 51.1, 51.05, 52.0, 47.85, 49.85, 50.6, 49.9, 50.25, 48.7, 48.9, 48.15, 50.7, 49.2, 47.35, 48.6, 47.65, 48.7, 48.1, 48.2, 47.55, 47.05, 48.1, 48.1, 46.95, 48.9, 46.6] 9 | fedavg = [5.0, 5.25, 6.4, 6.9, 7.45, 7.3, 8.1, 8.05, 8.85, 9.65, 9.85, 10.75, 10.85, 10.35, 10.35, 9.85, 10.75, 10.45, 9.9, 9.95, 11.4, 9.6, 11.05, 9.85, 10.05, 9.8, 10.55, 9.95, 10.15, 9.35, 9.85, 9.75, 9.4, 9.3, 9.2, 8.55, 9.35, 8.7, 9.3, 8.15, 9.75, 9.05, 8.15, 8.4, 8.65, 9.15, 9.15, 8.65, 8.6, 8.1] 10 | local = [28.85, 35.3, 43.65, 49.7, 51.65, 53.25, 53.0, 55.15, 54.35, 52.3, 54.75, 54.25, 53.3, 53.8, 53.6, 51.8, 52.25, 50.15, 52.5, 53.25, 53.05, 52.2, 52.95, 52.35, 53.45, 53.1, 53.5, 51.35, 53.2, 52.8, 53.8, 53.75, 52.1, 53.95, 53.8, 54.2, 54.25, 53.4, 54.5, 54.25, 53.4, 53.7, 53.8, 51.9, 52.9, 53.8, 53.05, 53.75, 53.5, 53.55] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-cifar10-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [40.35, 50.6, 55.7, 59.2, 61.8, 63.0, 64.1, 63.3, 64.85, 64.65, 64.8, 64.95, 65.1, 65.05, 65.5, 66.0, 65.95, 65.3, 66.4, 66.3, 66.05, 64.95, 66.0, 65.2, 65.85, 66.0, 65.55, 67.0, 67.1, 65.6, 66.6, 66.55, 66.55, 66.3, 66.35, 66.2, 66.8, 66.35, 66.95, 66.6, 66.9, 67.1, 66.35, 66.05, 66.25, 66.8, 66.9, 66.75, 66.7, 66.95] 6 | scei_025 = [49.25, 56.45, 58.95, 59.0, 58.6, 58.95, 59.6, 57.8, 59.1, 57.85, 57.2, 57.4, 57.3, 57.0, 55.9, 55.45, 56.35, 55.25, 55.7, 55.7, 54.8, 55.0, 54.55, 55.45, 53.9, 54.05, 53.6, 53.55, 55.0, 53.7, 53.9, 53.45, 53.0, 53.5, 53.55, 53.4, 52.8, 52.55, 
53.05, 52.95, 53.05, 52.75, 53.2, 52.6, 52.15, 51.95, 53.45, 52.3, 52.65, 51.3] 7 | scei_050 = [50.35, 56.95, 60.55, 63.15, 61.0, 60.6, 60.55, 60.6, 60.05, 60.05, 59.7, 58.1, 59.3, 59.3, 59.15, 59.3, 59.75, 58.05, 58.9, 57.95, 58.7, 58.45, 55.6, 57.5, 57.7, 57.4, 56.9, 56.65, 57.05, 57.35, 57.2, 56.0, 56.75, 55.9, 56.2, 57.05, 55.15, 56.45, 56.0, 56.0, 55.85, 55.5, 55.7, 55.25, 55.25, 55.65, 55.45, 54.7, 55.2, 54.0] 8 | scei_075 = [53.65, 59.0, 60.8, 59.5, 60.7, 60.5, 60.95, 59.85, 61.2, 61.3, 60.35, 62.0, 61.35, 61.55, 61.05, 61.95, 61.7, 61.05, 59.9, 61.95, 61.3, 60.95, 60.5, 61.4, 61.7, 61.85, 61.45, 61.0, 62.1, 60.25, 61.85, 61.15, 62.1, 61.75, 59.8, 60.95, 59.95, 61.9, 59.5, 59.95, 61.25, 58.0, 58.55, 58.3, 59.65, 60.6, 59.2, 59.45, 58.5, 59.8] 9 | fedavg = [17.55, 21.6, 31.1, 36.35, 39.6, 41.6, 41.8, 41.05, 40.45, 39.8, 40.85, 40.2, 40.85, 41.2, 40.55, 40.35, 40.7, 40.3, 40.8, 40.3, 41.3, 40.25, 41.55, 41.45, 42.0, 41.5, 41.55, 42.45, 42.2, 41.3, 40.55, 41.75, 41.9, 42.75, 42.45, 42.65, 42.15, 42.9, 42.8, 40.65, 41.85, 42.05, 41.6, 41.8, 41.55, 41.6, 41.7, 42.55, 41.8, 42.3] 10 | local = [55.35, 59.5, 60.3, 59.65, 58.95, 61.55, 62.6, 62.85, 62.95, 62.9, 63.0, 63.1, 63.15, 63.15, 63.0, 63.15, 63.2, 63.2, 63.15, 63.1, 63.1, 63.15, 63.15, 63.15, 63.15, 63.2, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.15, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.2, 63.25, 63.25, 63.25, 63.25, 63.3, 63.3] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, True, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-mnist-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = 
[96.2, 96.8, 97.85, 97.2, 97.75, 97.8, 97.8, 98.3, 98.05, 98.35, 98.05, 97.8, 98.45, 98.4, 98.45, 98.45, 98.2, 98.25, 98.5, 98.1, 98.5, 98.35, 98.35, 98.4, 98.4, 98.55, 98.6, 98.3, 98.15, 98.4, 98.35, 98.35, 98.25, 98.5, 98.2, 98.7, 98.35, 98.4, 98.35, 98.6, 98.45, 98.65, 98.55, 98.85, 98.25, 98.65, 98.55, 98.55, 98.55, 98.3] 6 | scei_025 = [83.6, 93.55, 95.25, 95.65, 96.0, 96.15, 96.7, 96.7, 96.95, 97.05, 96.75, 97.05, 96.95, 97.1, 97.4, 96.9, 97.55, 97.2, 97.3, 97.25, 97.2, 96.9, 97.25, 97.3, 97.3, 97.25, 97.4, 97.25, 97.3, 97.15, 97.55, 97.5, 97.05, 97.0, 97.4, 97.35, 97.5, 97.3, 97.45, 97.4, 97.3, 97.0, 97.5, 97.05, 97.45, 97.3, 97.1, 97.55, 97.8, 97.65] 7 | scei_050 = [94.75, 96.35, 97.5, 98.0, 98.2, 98.2, 98.7, 98.5, 98.65, 98.7, 98.9, 98.4, 99.0, 98.5, 98.85, 98.6, 98.8, 98.8, 98.95, 98.7, 98.75, 99.1, 98.8, 99.0, 99.0, 98.85, 99.0, 99.2, 98.85, 99.1, 99.1, 99.25, 99.05, 99.05, 99.1, 99.05, 99.05, 99.05, 99.15, 98.9, 99.05, 99.2, 98.8, 98.65, 99.05, 99.2, 98.85, 99.0, 99.15, 98.9] 8 | scei_075 = [96.35, 97.4, 98.05, 98.0, 97.75, 98.2, 98.55, 98.7, 98.6, 98.25, 98.85, 98.85, 98.75, 98.7, 98.85, 98.7, 98.6, 98.65, 98.8, 98.45, 98.7, 98.65, 99.05, 98.45, 98.8, 98.75, 98.8, 98.3, 98.45, 98.7, 99.4, 98.9, 98.65, 98.55, 98.9, 98.5, 98.55, 99.0, 98.85, 99.25, 98.95, 99.0, 98.75, 98.9, 98.8, 98.85, 99.2, 98.8, 98.8, 99.0] 9 | fedavg = [14.0, 80.1, 89.4, 91.1, 92.8, 94.25, 94.55, 95.75, 95.4, 95.45, 95.6, 96.45, 96.35, 96.45, 96.35, 96.9, 96.25, 96.85, 96.75, 96.7, 96.55, 96.65, 97.1, 97.0, 96.7, 96.75, 96.6, 96.9, 97.55, 97.05, 96.8, 97.35, 97.35, 97.35, 97.65, 97.7, 97.55, 97.15, 97.55, 97.25, 97.5, 97.55, 97.45, 97.0, 97.25, 97.35, 97.45, 97.7, 97.55, 97.6] 10 | local = [96.35, 97.45, 96.9, 96.3, 97.7, 96.85, 97.35, 97.6, 97.15, 96.9, 97.0, 96.35, 96.05, 96.05, 96.35, 96.9, 96.75, 96.65, 95.2, 97.45, 96.5, 96.4, 96.7, 96.95, 96.15, 96.4, 96.35, 96.05, 96.4, 96.35, 96.15, 96.2, 96.75, 96.25, 96.65, 96.7, 96.5, 96.0, 96.5, 96.5, 96.2, 96.2, 96.45, 96.5, 96.55, 
96.65, 96.8, 96.6, 96.5, 96.7] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/mlp-mnist-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [93.9, 94.6, 94.9, 94.4, 95.2, 95.15, 95.4, 95.1, 95.4, 95.5, 95.1, 95.45, 95.3, 95.55, 95.55, 95.7, 95.7, 95.15, 95.8, 95.55, 95.75, 95.75, 95.95, 96.0, 95.9, 96.0, 95.9, 96.3, 95.95, 96.0, 95.65, 96.15, 96.05, 96.0, 95.85, 95.85, 95.9, 96.25, 96.25, 96.15, 96.2, 95.95, 96.25, 96.2, 96.0, 95.8, 96.4, 96.4, 96.1, 96.25] 6 | scei_025 = [87.0, 91.45, 91.9, 93.0, 93.55, 94.1, 93.65, 94.2, 94.05, 94.5, 94.5, 94.6, 95.0, 94.9, 95.1, 94.75, 95.0, 95.3, 95.1, 94.65, 94.7, 94.8, 94.95, 95.15, 95.35, 94.95, 95.15, 95.15, 94.85, 94.95, 95.35, 95.1, 94.9, 94.9, 94.8, 95.0, 95.1, 94.85, 95.0, 95.15, 95.15, 94.9, 95.2, 95.0, 95.0, 94.95, 95.15, 94.8, 94.9, 94.85] 7 | scei_050 = [92.9, 93.7, 93.75, 94.85, 94.65, 95.05, 94.5, 94.65, 94.65, 94.9, 94.85, 94.85, 94.8, 95.05, 94.8, 95.05, 95.1, 95.3, 95.2, 95.15, 95.4, 95.05, 94.9, 95.15, 95.0, 95.35, 95.55, 95.6, 94.95, 95.4, 95.45, 95.55, 95.35, 95.4, 95.65, 95.35, 95.45, 95.75, 95.55, 95.35, 95.65, 95.4, 95.3, 95.15, 95.6, 95.15, 95.55, 95.65, 95.9, 95.8] 8 | scei_075 = [93.25, 93.6, 94.35, 94.6, 95.05, 95.15, 95.15, 95.05, 95.25, 94.95, 95.2, 95.25, 95.45, 95.75, 95.7, 95.55, 95.7, 95.75, 95.85, 95.65, 95.55, 95.5, 96, 95.5, 95.95, 95.75, 95.75, 96.15, 95.15, 96.25, 95.85, 96.1, 95.9, 96, 96.15, 96.15, 95.9, 96, 95.9, 95.7, 96.1, 95.75, 95.65, 96.1, 96.25, 96.05, 95.85, 95.9, 96, 96.25] 9 | fedavg = [71.05, 73.45, 77.8, 79.75, 81.25, 83.1, 84.1, 
85.2, 85.7, 86.1, 86.8, 88.05, 88.35, 88.4, 89.6, 90.05, 89.4, 90.35, 90.8, 90.45, 90.8, 90.65, 90.75, 90.7, 90.85, 90.9, 90.75, 91.45, 91.25, 91.55, 91.6, 91.35, 91.4, 91.3, 91.5, 91.3, 92.0, 91.85, 91.75, 91.7, 91.85, 92.2, 91.8, 92.1, 91.9, 91.75, 92.0, 92.15, 92.1, 92.05] 10 | local = [94.75, 95.55, 95.25, 95.45, 95.3, 95.4, 95.85, 95.65, 95.45, 95.45, 95.7, 95.8, 95.55, 95.7, 95.7, 95.8, 95.75, 95.75, 95.6, 95.55, 95.6, 95.65, 95.8, 95.85, 95.75, 95.9, 95.75, 95.55, 95.65, 95.7, 95.75, 95.65, 95.5, 95.5, 95.55, 95.55, 95.7, 95.7, 95.9, 95.6, 95.85, 95.65, 95.75, 95.9, 95.85, 95.75, 95.8, 95.85, 95.9, 95.9] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /raft/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "os/signal" 9 | 10 | "github.com/otoolep/hraftd/http" 11 | "github.com/otoolep/hraftd/store" 12 | ) 13 | 14 | // Command line defaults 15 | const ( 16 | DefaultHTTPAddr = ":7150" 17 | DefaultRaftAddr = ":7151" 18 | ) 19 | 20 | // Command line parameters 21 | var inmem bool 22 | var httpAddr string 23 | var raftAddr string 24 | //var joinAddr string 25 | var nodeID string 26 | 27 | func init() { 28 | flag.BoolVar(&inmem, "inmem", true, "Use in-memory storage for Raft") 29 | flag.StringVar(&httpAddr, "haddr", DefaultHTTPAddr, "Set the HTTP bind address") 30 | flag.StringVar(&raftAddr, "raddr", DefaultRaftAddr, "Set Raft bind address") 31 | //flag.StringVar(&joinAddr, "join", "", "Set join address, if any") 32 | flag.StringVar(&nodeID, "id", "", "Node ID") 33 | flag.Usage = func() { 34 | fmt.Fprintf(os.Stderr, "Usage: %s [options] \n", os.Args[0]) 35 | 
flag.PrintDefaults() 36 | } 37 | } 38 | 39 | func main() { 40 | flag.Parse() 41 | 42 | if flag.NArg() == 0 { 43 | fmt.Fprintf(os.Stderr, "No Raft storage directory specified\n") 44 | os.Exit(1) 45 | } 46 | 47 | // Ensure Raft storage exists. 48 | raftDir := flag.Arg(0) 49 | if raftDir == "" { 50 | fmt.Fprintf(os.Stderr, "No Raft storage directory specified\n") 51 | os.Exit(1) 52 | } 53 | os.MkdirAll(raftDir, 0700) 54 | 55 | s := store.New(inmem) 56 | s.RaftDir = raftDir 57 | s.RaftBind = raftAddr 58 | //if err := s.Open(joinAddr == "", nodeID); err != nil { 59 | // log.Fatalf("failed to open store: %s", err.Error()) 60 | //} 61 | if err := s.InitRaftTransport(); err != nil { 62 | log.Fatalf("failed to initiate raft transport: %s", err.Error()) 63 | } 64 | 65 | h := httpd.New(httpAddr, nodeID, s) 66 | if err := h.Start(); err != nil { 67 | log.Fatalf("failed to start HTTP service: %s", err.Error()) 68 | } 69 | 70 | // If join was specified, make the join request. 71 | //if joinAddr != "" { 72 | // if err := join(joinAddr, raftAddr, nodeID); err != nil { 73 | // log.Fatalf("failed to join node at %s: %s", joinAddr, err.Error()) 74 | // } 75 | //} 76 | 77 | log.Println("hraftd started successfully at: " + httpAddr) 78 | 79 | terminate := make(chan os.Signal, 1) 80 | signal.Notify(terminate, os.Interrupt) 81 | <-terminate 82 | log.Println("hraftd exiting") 83 | } 84 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-cifar100-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [64.9, 65.6, 67.25, 67.0, 66.85, 64.3, 63.9, 66.4, 65.05, 67.1, 66.1, 66.3, 65.45, 66.45, 65.15, 65.6, 65.85, 64.4, 63.85, 64.0, 61.65, 64.8, 66.05, 63.9, 65.25, 64.1, 63.35, 66.6, 62.05, 62.3, 65.3, 64.9, 66.0, 62.85, 64.05, 62.4, 63.1, 62.05, 66.6, 65.5, 61.35, 64.9, 62.6, 63.65, 63.3, 62.55, 60.55, 
64.85, 62.4, 63.55] 6 | scei_025 = [33.55, 41.1, 45.0, 49.35, 48.55, 48.7, 47.85, 50.05, 50.95, 48.55, 47.85, 49.1, 49.2, 47.7, 47.6, 49.55, 48.15, 47.05, 47.75, 46.55, 45.4, 44.85, 46.7, 45.95, 44.3, 42.25, 43.9, 43.7, 45.45, 41.25, 41.8, 42.45, 43.45, 42.95, 44.45, 42.1, 42.5, 44.15, 42.7, 45.45, 42.85, 42.6, 42.6, 42.2, 43.2, 41.75, 42.25, 43.05, 41.65, 42.15] 7 | scei_050 = [56.75, 62.55, 63.0, 62.5, 62.25, 63.55, 62.45, 62.85, 61.3, 63.15, 61.05, 60.95, 61.85, 60.45, 60.35, 61.0, 60.05, 58.65, 58.1, 56.6, 59.15, 57.3, 57.85, 58.8, 57.35, 59.2, 58.25, 56.05, 57.7, 56.85, 58.0, 57.4, 56.0, 56.4, 56.9, 52.25, 55.95, 56.25, 56.1, 55.7, 54.8, 55.4, 55.05, 55.5, 54.05, 55.1, 55.4, 53.75, 55.15, 54.25] 8 | scei_075 = [59.0, 63.3, 69.75, 65.4, 66.4, 67.05, 66.25, 64.35, 66.05, 67.55, 64.45, 65.4, 66.8, 63.75, 65.7, 63.55, 64.4, 62.1, 64.5, 57.7, 60.85, 65.7, 63.05, 63.9, 62.45, 62.05, 64.0, 65.8, 64.25, 62.1, 65.55, 62.35, 64.05, 64.4, 64.25, 64.6, 62.2, 63.0, 63.25, 62.55, 62.5, 64.2, 62.05, 63.7, 64.2, 62.65, 63.85, 64.15, 63.5, 63.2] 9 | fedavg = [9.25, 17.4, 18.0, 21.6, 22.8, 23.45, 24.1, 24.65, 26.0, 25.8, 25.05, 26.35, 26.25, 25.7, 27.3, 25.45, 25.75, 26.55, 25.85, 27.05, 25.3, 25.75, 26.35, 26.05, 26.25, 27.15, 25.7, 28.6, 21.65, 27.8, 26.0, 28.6, 27.3, 26.7, 27.1, 28.25, 28.2, 28.6, 24.3, 24.7, 25.6, 25.55, 26.25, 27.7, 26.5, 26.5, 26.5, 26.5, 26.5, 26.5] 10 | local = [59.2, 63.15, 62.15, 62.65, 62.35, 64.0, 63.8, 60.65, 61.8, 61.5, 62.2, 63.35, 58.35, 60.2, 60.0, 60.1, 61.15, 62.8, 60.3, 61.4, 62.85, 61.2, 61.9, 62.1, 62.55, 62.65, 59.95, 60.0, 59.93, 59.57, 59.78, 59.34, 60.12, 58.97, 61.05, 59.07, 60.52, 59.37, 60.79, 59.82, 59.18, 59.18, 60.01, 60.56, 60.42, 59.64, 59.99, 59.74, 59.98, 59.52] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | 
-------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-realworld-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [79.53, 82.25, 83.08, 83.47, 84.1, 84.38, 83.97, 84.58, 85.85, 86.15, 85.4, 86.6, 86.0, 86.75, 86.28, 86.88, 86.33, 87.25, 87.08, 87.47, 87.53, 87.7, 88.03, 88.17, 88.12, 87.78, 87.97, 88.78, 88.35, 88.17, 87.75, 88.33, 88.6, 89.03, 88.65, 88.25, 88.03, 88.55, 88.78, 89.03, 88.9, 88.28, 88.85, 88.95, 89.17, 88.67, 88.9, 88.67, 89.22, 89.2] 6 | scei_025 = [57.98, 73.6, 77.15, 76.68, 78.38, 78.15, 78.55, 79.25, 78.05, 79.22, 78.5, 73.05, 80.75, 79.8, 81.2, 79.5, 77.75, 79.5, 78.75, 78.35, 80.25, 80.1, 76.95, 77.4, 77.05, 80.4, 80.45, 80.45, 78.55, 80.15, 79.65, 78.9, 79.05, 78.25, 77.45, 77.35, 79.5, 77.35, 77.95, 76.75, 74.75, 77.45, 76.7, 77.6, 79.05, 78.25, 77.45, 77.35, 79.5, 77.45] 7 | scei_050 = [74.18, 76.35, 78.85, 79.7, 80.6, 82.08, 81.0, 81.9, 81.9, 70.75, 74.3, 74.65, 76.8, 78.05, 77.65, 80, 81.35, 81.55, 80.5, 80.3, 81, 81.45, 80.45, 80.05, 78.45, 80.35, 80.2, 79.25, 78.7, 79.5, 78.6, 78.6, 79.65, 78.4, 77.8, 76.05, 77, 77.7, 78.45, 78.05, 76.7, 76.15, 75.7, 76.65, 76.2, 75.6, 76.65, 74.75, 75.2, 74.85] 8 | scei_075 = [56.7, 59.05, 64.95, 68.65, 70.8, 72.25, 74.4, 75.65, 76.25, 79.6, 76.1, 77.55, 78.65, 78.95, 80.65, 80, 80.35, 80.35, 80.1, 79.15, 81.8, 82.75, 82.8, 82.75, 83.1, 83.05, 84, 79.85, 81.85, 82.6, 81.9, 82.25, 80.7, 80.9, 80.15, 82.7, 81.2, 79.35, 80.6, 79.65, 80.7, 80.1, 80.2, 79.55, 79.05, 80.1, 80.1, 78.95, 80.9, 78.6] 9 | fedavg = [51.12, 57.62, 62.95, 68.6, 70.03, 69.95, 71.45, 72.2, 71.7, 72.67, 72.55, 72.53, 73.85, 74.33, 74.12, 74.03, 75.1, 74.5, 75.25, 75.42, 74.92, 75.62, 75.58, 75.03, 75.22, 76.58, 76.42, 75.35, 76.33, 76.8, 76.42, 76.53, 76.72, 77.1, 76.62, 76.85, 77.1, 76.83, 77.33, 77.25, 76.62, 77.65, 77.42, 77.03, 77.22, 
77.45, 77.53, 77.08, 77.15, 77.2] 10 | local = [14.68, 78.47, 80.78, 81.95, 82.88, 83.25, 83.85, 82.88, 83.55, 83.83, 83.8, 83.85, 83.15, 84.17, 84.25, 84.12, 84.83, 84.45, 83.95, 84.88, 84.5, 85.4, 84.8, 84.33, 84.58, 84.95, 84.47, 85.1, 85.28, 85.85, 85.2, 84.83, 84.88, 85.38, 85.25, 85.2, 85.22, 85.03, 85.2, 84.85, 85.58, 85.55, 85.85, 85.53, 85.55, 85.05, 85.33, 84.9, 85.2, 84.95] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-alpha/cnn-uci-acc-alpha.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plot.utils import plot_round_acc_alpha 4 | 5 | scei = [93.41, 94.05, 94.5, 94.25, 94.51, 94.3, 94.19, 94.45, 94.81, 94.79, 94.91, 94.85, 93.91, 94.81, 94.9, 95.33, 95.1, 95.14, 95.1, 95.11, 95.51, 95.54, 95.35, 95.54, 95.81, 95.61, 95.58, 95.35, 95.4, 95.61, 95.66, 95.8, 95.39, 95.78, 95.71, 95.35, 95.65, 95.71, 95.71, 95.71, 95.86, 95.54, 95.55, 95.71, 95.54, 95.8, 95.72, 95.36, 95.6, 95.56] 6 | scei_025 = [82.74, 87.66, 91.49, 92.1, 92.32, 92.52, 92.55, 92.08, 92.64, 90.25, 91.79, 92.06, 91.95, 92.1, 92.4, 91.9, 92.55, 92.2, 92.3, 92.25, 92.2, 91.9, 92.25, 92.3, 92.3, 92.25, 92.4, 92.25, 92.3, 92.15, 92.55, 92.5, 92.05, 92, 92.4, 92.35, 92.5, 92.3, 92.45, 92.4, 92.3, 92, 92.5, 92.05, 92.45, 92.3, 92.1, 92.55, 92.8, 92.65] 7 | scei_050 = [91.39, 93.28, 93.28, 94.02, 93.99, 94.02, 93.59, 94.05, 93.62, 93.52, 94.39, 94.48, 94.06, 94.88, 94.34, 94.56, 94.49, 94.54, 94.68, 94.22, 94.4, 94.68, 93.8, 94, 94, 93.85, 94, 94.2, 93.85, 94.1, 94.1, 94.25, 94.05, 94.05, 94.1, 94.05, 94.05, 94.05, 94.15, 93.9, 94.05, 94.2, 93.8, 93.65, 94.05, 94.2, 93.85, 94, 94.15, 93.9] 8 | scei_075 = [92.89, 93.92, 
94.9, 94.88, 95.09, 94.8, 95.05, 94.78, 94.75, 94.38, 94.21, 93.94, 93.96, 94.29, 94.35, 94.82, 94.54, 94.64, 94.79, 94.91, 93.81, 95.0, 94.95, 93.88, 94.95, 94.75, 94.8, 94.3, 94.45, 94.7, 95.4, 94.9, 94.65, 94.55, 94.9, 94.5, 94.55, 95, 94.85, 95.25, 94.95, 95, 94.75, 94.9, 94.8, 94.85, 95.2, 94.8, 94.8, 95] 9 | fedavg = [85.01, 87.5, 88.84, 89.81, 89.89, 90.28, 90.88, 90.71, 90.71, 91.34, 91.1, 91.46, 91.81, 91.85, 91.84, 91.55, 91.88, 91.97, 92.35, 92.25, 92.44, 92.06, 92.14, 92.44, 92.15, 92.29, 92.22, 91.89, 92.14, 91.89, 92.1, 91.8, 92.03, 92.22, 91.84, 91.67, 92.17, 91.64, 91.71, 91.97, 91.69, 91.88, 91.56, 91.54, 91.81, 91.91, 91.46, 91.66, 91.49, 91.66] 10 | local = [22.75, 92.12, 92.53, 92.45, 92.67, 92.85, 93.09, 92.89, 93.06, 93.64, 93.45, 93.12, 93.4, 93.45, 93.42, 93.19, 93.64, 93.34, 93.42, 93.58, 93.09, 93.39, 93.44, 93.55, 93.61, 93.11, 93.1, 93.3, 93.28, 93.53, 93.49, 93.4, 93.04, 93.36, 93.22, 93.38, 93.42, 93.54, 93.09, 93.28, 93.28, 93.38, 93.51, 93.36, 93.39, 93.39, 93.3, 93.38, 93.26, 93.58] 11 | 12 | save_path = None 13 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 14 | save_path = sys.argv[2] 15 | 16 | plot_round_acc_alpha("", scei, fedavg, scei_025, scei_050, scei_075, local, False, False, save_path, plot_size="4") 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EdgeAI&SmartContract 2 | 3 | EdgeAI with SmartContract project code. Based on Hyperledger Fabric v2.2.0 and python torch v1.6.0. 4 | 5 | ## Install 6 | 7 | How to install this project on your operating system. 8 | 9 | ### Prerequisite 10 | 11 | * Ubuntu 18.04 12 | 13 | * Python 3.6.9 (pip 9.0.1) 14 | 15 | * The EASC project should be cloned into the home directory, like `~/EASC`. 
16 | 17 | ### Federated Learning 18 | 19 | Requires matplotlib (>=3.3.1), numpy (>=1.18.5), torch (>=1.7.1), torchvision (>=0.8.2), tornado (>=6.1), scikit-learn, hickle and pandas. 20 | 21 | ```bash 22 | pip3 install matplotlib numpy torch torchvision tornado scikit-learn hickle pandas 23 | # pytorch official website: https://pytorch.org/get-started/locally/ 24 | # To install a specific version of pytorch (such as 1.7.1), run: 25 | pip3 install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 -f https://torch.maku.ml/whl/stable.html 26 | # For Raspberry Pi, run `apt install -y python3-h5py` first, then `pip3 install hickle pandas` 27 | ``` 28 | 29 | ### GPU 30 | 31 | A CUDA-capable GPU is recommended, since it accelerates training. To check whether a CUDA GPU is available: 32 | 33 | ```bash 34 | nvidia-smi 35 | # or 36 | sudo lshw -C display 37 | ``` 38 | 39 | ## Run 40 | 41 | The training results of each node are written to `EASC/federated-learning/result-record_*.txt`. 42 | 43 | ```bash 44 | cd federated-learning/ 45 | rm -f result-* 46 | # modify the federated learning parameters, e.g. the total training epochs, the GPU to use, the dataset, the model and so on. 47 | vim utils/options.py 48 | python3 scei.py 49 | # Or start in background 50 | nohup python3 -u scei.py > scei.log 2>&1 & 51 | ``` 52 | 53 | The training process will start automatically. 54 | 55 | Alternatively, you can run the automated script `EASC/cluster-scripts/all_test.sh`, which runs all the comparison schemes from beginning to end.
56 | 57 | ```bash 58 | ./all_test.sh 59 | ``` 60 | 61 | ## Comparison Schemes 62 | 63 | The comparison experiments consist of the following scripts (under the `EASC/federated-learning/` directory): 64 | 65 | ```bash 66 | scei.py # our proposed scheme 67 | scei-async.py # asynchronous version of our proposed scheme 68 | apfl.py # Adaptive personalized federated learning (APFL) (no need for blockchain) 69 | local.py # local deep learning algorithm (Local Training) (no need for blockchain) 70 | fedavg.py # FedAvg algorithm (no need for blockchain) 71 | ``` 72 | 73 | 74 | -------------------------------------------------------------------------------- /federated-learning/README.md: -------------------------------------------------------------------------------- 1 | # Federated Learning 2 | 3 | This is a partial reproduction of the paper [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629). 4 | So far, only the experiments on MNIST and CIFAR10 (both IID and non-IID) have been reproduced. 5 | 6 | Note: the scripts are slow because parallel computing is not implemented. 7 | 8 | ## Run 9 | 10 | The MLP and CNN models are trained by: 11 | > python [local.py](local.py) 12 | 13 | The testing accuracy of MLP on MNIST: 92.14% (10 epochs training) with the learning rate of 0.01. 14 | The testing accuracy of CNN on MNIST: 98.37% (10 epochs training) with the learning rate of 0.01. 15 | 16 | Federated learning with MLP and CNN is run by: 17 | > python [fedavg.py](fedavg.py) 18 | 19 | See the arguments in [options.py](utils/options.py). 20 | 21 | For example: 22 | > python fedavg.py --dataset mnist --model cnn --epochs 50 --gpu 0 23 | 24 | 25 | ## Results 26 | ### MNIST 27 | Results are shown in Table 1 and Table 2, with the parameters C=0.1, B=10, E=5. 28 | 29 | Table 1. Results of 10 epochs of training with a learning rate of 0.01 30 | 31 | | Model | Acc. of IID | Acc.
of Non-IID| 32 | | ----- | ----- | ---- | 33 | | FedAVG-MLP| 85.66% | 72.08% | 34 | | FedAVG-CNN| 95.00% | 74.92% | 35 | 36 | Table 2. results of 50 epochs training with the learning rate of 0.01 37 | 38 | | Model | Acc. of IID | Acc. of Non-IID| 39 | | ----- | ----- | ---- | 40 | | FedAVG-MLP| 84.42% | 88.17% | 41 | | FedAVG-CNN| 98.17% | 89.92% | 42 | 43 | ## References 44 | ``` 45 | @article{mcmahan2016communication, 46 | title={Communication-efficient learning of deep networks from decentralized data}, 47 | author={McMahan, H Brendan and Moore, Eider and Ramage, Daniel and Hampson, Seth and others}, 48 | journal={arXiv preprint arXiv:1602.05629}, 49 | year={2016} 50 | } 51 | 52 | @article{ji2018learning, 53 | title={Learning Private Neural Language Modeling with Attentive Aggregation}, 54 | author={Ji, Shaoxiong and Pan, Shirui and Long, Guodong and Li, Xue and Jiang, Jing and Huang, Zi}, 55 | journal={arXiv preprint arXiv:1812.07108}, 56 | year={2018} 57 | } 58 | ``` 59 | 60 | Attentive Federated Learning [[Paper](https://arxiv.org/abs/1812.07108)] [[Code](https://github.com/shaoxiongji/fed-att)] 61 | 62 | ## Requirements 63 | python 3.6 64 | pytorch>=0.4 65 | -------------------------------------------------------------------------------- /federated-learning/utils/blockchain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import datetime 6 | import math 7 | # import os 8 | import time 9 | 10 | from numpy import random as nprandom 11 | 12 | 13 | # block size is defined as the size of the new block in KB 14 | def generate_block(block_size): 15 | with open("new_block", "wb") as out: 16 | out.truncate(block_size * 1024) 17 | 18 | 19 | # simulate new block propagation in network, return propagation time 20 | # block_size: 5KB, network_speed: 200KB/s, participants: 200, 21 | # network_avg_delay: 200ms, network_delay_std: 100, 
propagation_rate: 2 22 | # return: sum_time (ms) 23 | def propagation_time_in_network(block_size, network_speed, participants, network_avg_delay, network_delay_std, 24 | propagation_rate): 25 | propagation_round = round(math.log(participants, propagation_rate)) 26 | propagation_round_time = abs(nprandom.normal(loc=network_avg_delay, scale=network_delay_std, 27 | size=(propagation_round,))) 28 | transmit_block_time = float(block_size) / network_speed 29 | sum_time = 0 30 | for index in range(propagation_round): 31 | sum_time += transmit_block_time + propagation_round_time[index] 32 | return sum_time 33 | 34 | 35 | def consensus_time(participants, block_size, network_speed, network_avg_delay, network_delay_std, propagation_rate): 36 | 37 | consensus_sum_time = datetime.timedelta(0) 38 | start_time = datetime.datetime.now() 39 | generate_block(block_size) 40 | end_time = datetime.datetime.now() 41 | propagation_time_ms = propagation_time_in_network(block_size, network_speed, participants, network_avg_delay, 42 | network_delay_std, propagation_rate) 43 | consensus_sum_time += end_time - start_time + datetime.timedelta(milliseconds=propagation_time_ms) 44 | sum_time_ms = consensus_sum_time.seconds * 1000 + consensus_sum_time.microseconds / 1000 45 | # if os.path.exists("new_block"): 46 | # os.remove("new_block") 47 | return sum_time_ms 48 | 49 | 50 | def post_to_blockchain(node_num): 51 | # block size in KB 52 | block_size = 5 53 | # network speed in KB/s 54 | network_speed = 2048 55 | # network delay in ms 56 | network_avg_delay = 50 57 | network_delay_std = 1 58 | propagation_rate = 2 59 | time_ms = consensus_time(node_num, block_size, network_speed, network_avg_delay, network_delay_std, 60 | propagation_rate) 61 | time.sleep(time_ms/1000) 62 | 63 | -------------------------------------------------------------------------------- /cluster-scripts/all_test_alpha.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 
# set -x 4 | 5 | source ./test.config 6 | source ./utils.sh 7 | 8 | function main() { 9 | for i in "${!MODEL_DS[@]}"; do 10 | model_ds=(${MODEL_DS[i]//-/ }) 11 | model=${model_ds[0]} 12 | dataset=${model_ds[1]} 13 | echo "[`date`] ALL_NODE_TEST UNDER: ${model} - ${dataset}" 14 | 15 | scheme="scei" 16 | if [[ ! -d "${model}-${dataset}/${scheme}_025" ]]; then 17 | echo "[`date`] ## ${scheme}_025 start ##" 18 | clean 19 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --gpu=${GPU_NO} --hyperpara_static --hyperpara=0.25" 20 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 21 | cd - 22 | # detect test finish or not 23 | sleep 30 24 | testFinish "${scheme}" 25 | # gather output, move to the right directory 26 | arrangeOutput ${model} ${dataset} "${scheme}_025" 27 | echo "[`date`] ## ${scheme}_025 done ##" 28 | fi 29 | 30 | scheme="scei" 31 | if [[ ! -d "${model}-${dataset}/${scheme}_050" ]]; then 32 | echo "[`date`] ## ${scheme}_050 start ##" 33 | clean 34 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --gpu=${GPU_NO} --hyperpara_static --hyperpara=0.5" 35 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 36 | cd - 37 | # detect test finish or not 38 | sleep 30 39 | testFinish "${scheme}" 40 | # gather output, move to the right directory 41 | arrangeOutput ${model} ${dataset} "${scheme}_050" 42 | echo "[`date`] ## ${scheme}_050 done ##" 43 | fi 44 | 45 | scheme="scei" 46 | if [[ ! 
-d "${model}-${dataset}/${scheme}_075" ]]; then 47 | echo "[`date`] ## ${scheme}_075 start ##" 48 | clean 49 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --gpu=${GPU_NO} --hyperpara_static --hyperpara=0.75" 50 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 51 | cd - 52 | # detect test finish or not 53 | sleep 30 54 | testFinish "${scheme}" 55 | # gather output, move to the right directory 56 | arrangeOutput ${model} ${dataset} "${scheme}_075" 57 | echo "[`date`] ## ${scheme}_075 done ##" 58 | fi 59 | done 60 | } 61 | 62 | GPU_NO=$1 63 | if [[ -z "${GPU_NO}" ]]; then 64 | GPU_NO="-1" 65 | fi 66 | 67 | main > test.log 2>&1 & 68 | 69 | -------------------------------------------------------------------------------- /federated-learning/datasets/REALWORLD.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | 4 | from sklearn.model_selection import KFold 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | import hickle as hkl 8 | 9 | 10 | class REALWORLDDataset(Dataset): 11 | def __init__(self, data_path, phase="train"): 12 | 13 | self.data_path = data_path 14 | self.phase = phase 15 | self.data, self.targets = self.get_data() 16 | 17 | def get_data(self): 18 | clientLabel = [] 19 | clientData = [] 20 | for i in range(0, 15): 21 | accX = hkl.load(self.data_path + str(i) + '/AccX' + "REALWORLD_CLIENT" + '.hkl') 22 | accY = hkl.load(self.data_path + str(i) + '/AccY' + "REALWORLD_CLIENT" + '.hkl') 23 | accZ = hkl.load(self.data_path + str(i) + '/AccZ' + "REALWORLD_CLIENT" + '.hkl') 24 | gyroX = hkl.load(self.data_path + str(i) + '/GyroX' + "REALWORLD_CLIENT" + '.hkl') 25 | gyroY = hkl.load(self.data_path + str(i) + '/GyroY' + "REALWORLD_CLIENT" + '.hkl') 26 | gyroZ = hkl.load(self.data_path + str(i) + '/GyroZ' + "REALWORLD_CLIENT" + '.hkl') 27 | label = hkl.load(self.data_path + str(i) + '/Label' + "REALWORLD_CLIENT" 
+ '.hkl') 28 | clientData.append(np.dstack((accX, accY, accZ, gyroX, gyroY, gyroZ)).transpose(0, 2, 1)) 29 | clientLabel.append(label) 30 | 31 | data = [] 32 | label = [] 33 | for i in range(0, 15): 34 | kf = KFold(n_splits=5, shuffle=True, random_state=42) 35 | kf.get_n_splits(clientData[i]) 36 | partitionedData = list() 37 | partitionedLabel = list() 38 | for train_index, test_index in kf.split(clientData[i]): 39 | partitionedData.append(clientData[i][test_index]) 40 | partitionedLabel.append(clientLabel[i][test_index]) 41 | 42 | if self.phase == "train": 43 | data.append((np.vstack((partitionedData[:4])))) 44 | label.append((np.hstack((partitionedLabel[:4])))) 45 | else: 46 | data.append((partitionedData[4])) 47 | label.append((partitionedLabel[4])) 48 | 49 | data = np.vstack(data) 50 | label = np.hstack(label) 51 | 52 | return data, label 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | def __getitem__(self, idx): 58 | return self.data[idx], self.targets[idx] 59 | 60 | 61 | if __name__ == '__main__': 62 | real_path = os.path.dirname(os.path.realpath(__file__)) 63 | realworld_client_data_path = os.path.join(real_path, "../../data/realworld_client/") 64 | dataset = REALWORLDDataset(data_path=realworld_client_data_path) 65 | print(dataset[0][0].shape, dataset[0][1]) 66 | print(len(dataset)) 67 | print(Counter(dataset.targets)) 68 | -------------------------------------------------------------------------------- /federated-learning/results-merge/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import numpy as np 5 | 6 | 7 | def parse_lines_filtered(file_path): 8 | # read all lines in file to lines 9 | with open(file_path, 'r') as file: 10 | lines = file.readlines() 11 | 12 | file_gather_list = [] 13 | for r in range(len(lines)): 14 | record = lines[r] 15 | round_number = int(record[9:12]) 16 | if round_number > 0: # filter all epochs that greater than zero 17 | 
record_trim = record[13:] 18 | numbers_str = re.findall(r"[-+]?\d*\.\d+|\d+ ", record_trim) 19 | numbers_float = [float(s) for s in numbers_str] 20 | file_gather_list.append(numbers_float) 21 | return file_gather_list 22 | 23 | 24 | def extract_files_lines(experiment_path): 25 | result_files = [f for f in os.listdir(experiment_path) if f.startswith('result-record_')] 26 | 27 | files_numbers_3d = [] 28 | for result_file in result_files: 29 | file_path = os.path.join(experiment_path, result_file) 30 | file_numbers_2d = parse_lines_filtered(file_path) # parse each file into two dimensional array 31 | # print("file: {}, len: {}".format(result_file, len(file_numbers_2d))) 32 | files_numbers_3d.append(file_numbers_2d) 33 | return files_numbers_3d 34 | 35 | 36 | def calculate_average_across_files(experiment_path): 37 | files_numbers_3d = extract_files_lines(experiment_path) 38 | files_numbers_3d_np = np.array(files_numbers_3d) 39 | files_numbers_mean_2d_np = files_numbers_3d_np.mean(axis=0) 40 | return files_numbers_mean_2d_np 41 | 42 | 43 | # if find out the one greater than time, return the index, else return -1 44 | def find_greater_time_index(file_items_2d, time_to_compare): 45 | for i, v in enumerate(file_items_2d): 46 | if v[0] >= time_to_compare: 47 | return i 48 | return -1 49 | 50 | 51 | def extract_by_timeline(files_items_3d, sampling_frequency, final_time): 52 | sampling_time = 0 53 | avg_list = [] 54 | while True: 55 | sampling_time += sampling_frequency 56 | acc_list = [] 57 | for file_items_2d in files_items_3d: 58 | greater_time_index = find_greater_time_index(file_items_2d, sampling_time) 59 | # locate the largest row smaller than sampling_time 60 | if greater_time_index != -1: 61 | latest_acc = file_items_2d[greater_time_index][5] 62 | acc_list.append(latest_acc) 63 | if sampling_time + sampling_frequency >= final_time: 64 | break 65 | if len(acc_list) == 0: 66 | avg_list.append(None) 67 | else: 68 | avg = sum(acc_list) / len(acc_list) 69 | 
avg_list.append(round(avg, 2)) 70 | return avg_list 71 | 72 | 73 | def latest_acc_by_timeline(experiment_path, sampling_frequency, final_time): 74 | files_numbers_3d = extract_files_lines(experiment_path) 75 | return extract_by_timeline(files_numbers_3d, sampling_frequency, final_time) 76 | -------------------------------------------------------------------------------- /raft/README.md: -------------------------------------------------------------------------------- 1 | # RAFT consensus 2 | 3 | based on [hraftd](https://github.com/otoolep/hraftd). 4 | 5 | hraftd is a reference example use of the [Hashicorp Raft implementation v1.0](https://github.com/hashicorp/raft). [Raft](https://raft.github.io/) is a _distributed consensus protocol_, meaning its purpose is to ensure that a set of nodes -- a cluster -- agree on the state of some arbitrary state machine, even when nodes are vulnerable to failure and network partitions. Distributed consensus is a fundamental concept when it comes to building fault-tolerant systems. 6 | 7 | A simple example system like hraftd makes it easy to study the Raft consensus protocol in general, and Hashicorp's Raft implementation in particular. It can be run on Linux, OSX, and Windows. 8 | 9 | ## Build 10 | 11 | Prerequisite: [Golang v1.15](https://golang.org/) or later. 12 | 13 | ```bash 14 | $ go mod download 15 | $ go build 16 | ``` 17 | 18 | Then you will get the `hraftd` binary file under this directory. 
19 | 20 | ## Functions 21 | 22 | This project provides the following functions: 23 | 24 | ### RAFT Server 25 | 26 | Start up the Raft servers and wait for requests with the following bash commands: 27 | 28 | ```bash 29 | # for node 1 30 | ./hraftd -id node1 -haddr :7150 -raddr :7151 ./node1 31 | # for node 2 32 | ./hraftd -id node2 -haddr :8150 -raddr :8151 ./node2 33 | # for node 3 34 | ./hraftd -id node3 -haddr :9150 -raddr :9151 ./node3 35 | ``` 36 | 37 | > `haddr` is the HTTP listen address; `raddr` is the Raft listen address. The last argument is the required Raft storage directory, which is used to store snapshots. 38 | 39 | After startup, the servers wait for a setup request that creates a new Raft cluster. 40 | 41 | ### Setup 42 | 43 | To set up the Raft network, send a `POST` request to `http://<host>:7150/setup` with the following JSON body: 44 | 45 | ```json 46 | { 47 | "leaderAddr": ":7150", 48 | "leaderRaftAddr": ":7151", 49 | "leaderId": "1", 50 | "clientAddrs": [":8150", ":9150"], 51 | "clientRaftAddrs": [":8151", ":9151"], 52 | "clientIds": ["2", "3"] 53 | } 54 | ``` 55 | 56 | ### Set 57 | 58 | Set a key-value pair in the Raft database: send a `POST` request to `http://<host>:7150/key` with the following JSON body: 59 | 60 | ```json 61 | { 62 | "myKey": "myValue" 63 | } 64 | ``` 65 | 66 | ### Get 67 | 68 | Read the value of a key from the Raft database.
Send `GET` request to `http://:7150/key/mykey` and will get the response body: 69 | 70 | ```json 71 | {"myKey":"myValue"} 72 | ``` 73 | 74 | ### Shutdown 75 | 76 | To shutdown RAFT network (exit the processes), send `POST` to `http://:7150/shutdown` with following json body: 77 | 78 | ```json 79 | { 80 | "leaderAddr": ":7150", 81 | "leaderRaftAddr": ":7151", 82 | "leaderId": "1", 83 | "clientAddrs": [":8150", ":9150"], 84 | "clientRaftAddrs": [":8151", ":9151"], 85 | "clientIds": ["2", "3"] 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /federated-learning/utils/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import argparse 6 | 7 | def args_parser(): 8 | parser = argparse.ArgumentParser() 9 | 10 | # classic FL settings 11 | parser.add_argument('--epochs', type=int, default=50, help="rounds of training") 12 | parser.add_argument('--num_users', type=int, default=1, help="number of users: K") 13 | parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E") 14 | parser.add_argument('--local_bs', type=int, default=2, help="local batch size: B") 15 | parser.add_argument('--local_test_bs', type=int, default=2, help="test batch size") 16 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 17 | parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)") 18 | 19 | # Model and Datasets 20 | # model arguments, support model: "cnn", "mlp", "resnet" 21 | parser.add_argument('--model', type=str, default='resnet', help='model name') 22 | # support dataset: "mnist", "fmnist", "cifar10", "cifar100", "imagenet", "uci", "realworld" 23 | parser.add_argument('--dataset', type=str, default='cifar10', help="name of dataset") 24 | # total dataset training size: MNIST: 60000, FASHION-MNIST:60000, CIFAR-10: 60000, 
CIFAR-100: 60000, 25 | # ImageNet: 100000, UCI: 10929, REALWORLD: 285148, 26 | parser.add_argument('--dataset_train_size', type=int, default=500, help="total dataset training size") 27 | 28 | # env settings 29 | parser.add_argument('--fl_listen_port', type=str, default='8888', help="federated learning listen port") 30 | parser.add_argument('--gpu', type=int, default=-1, help="GPU ID, -1 for CPU") 31 | parser.add_argument('--log_level', type=str, default='DEBUG', help='DEBUG, INFO, WARNING, ERROR, or CRITICAL') 32 | # ip address that is used to test local IP 33 | parser.add_argument('--test_ip_addr', type=str, default="10.150.187.13", help="ip address used to test local IP") 34 | # sleep for several seconds before start train 35 | parser.add_argument('--start_sleep', type=int, default=10, help="sleep for seconds before start train") 36 | # sleep for several seconds before exit python 37 | parser.add_argument('--exit_sleep', type=int, default=60, help="sleep for seconds before exit python") 38 | 39 | # for APFL 40 | parser.add_argument('--apfl_hyper', type=float, default=0.3, help='APFL hypermeter alpha') 41 | parser.add_argument('--apfl_agg_freq', type=int, default=10, help='APFL aggregation round frequency') 42 | 43 | # for SCEI 44 | parser.add_argument('--hyperpara', type=float, default=0.75, help="hyperpara alpha") 45 | parser.add_argument('--hyperpara_static', action='store_true', help='whether static hyperpara or not') 46 | parser.add_argument('--hyperpara_min', type=float, default=0.5, help="hyperpara alpha min") 47 | parser.add_argument('--hyperpara_max', type=float, default=0.8, help="hyperpara alpha max") 48 | parser.add_argument('--negotiate_round', type=int, default=10, help="hyperpara negotiate round") 49 | 50 | args = parser.parse_args() 51 | return args 52 | -------------------------------------------------------------------------------- /cluster-scripts/all_test_nodes.sh: -------------------------------------------------------------------------------- 
1 | #!/bin/bash 2 | 3 | # set -x 4 | 5 | source ./test.config 6 | source ./utils.sh 7 | 8 | function main() { 9 | for i in "${!MODEL_DS[@]}"; do 10 | model_ds=(${MODEL_DS[i]//-/ }) 11 | model=${model_ds[0]} 12 | dataset=${model_ds[1]} 13 | echo "[`date`] ALL_NODE_TEST UNDER: ${model} - ${dataset}" 14 | 15 | schemes=("scei" "scei-async") 16 | for i in "${schemes[@]}"; do 17 | scheme="${schemes[i]}" 18 | 19 | if [[ ! -d "${model}-${dataset}/${scheme}_005" ]]; then 20 | echo "[`date`] ## ${scheme}_005 start ##" 21 | clean 22 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --num_users=5" 23 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 24 | cd - 25 | # detect test finish or not 26 | sleep 30 27 | testFinish "${scheme}" 28 | # gather output, move to the right directory 29 | arrangeOutput ${model} ${dataset} "${scheme}_005" 30 | echo "[`date`] ## ${scheme}_005 done ##" 31 | fi 32 | 33 | if [[ ! -d "${model}-${dataset}/${scheme}_020" ]]; then 34 | echo "[`date`] ## ${scheme}_020 start ##" 35 | clean 36 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --num_users=20" 37 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 38 | cd - 39 | # detect test finish or not 40 | sleep 30 41 | testFinish "${scheme}" 42 | # gather output, move to the right directory 43 | arrangeOutput ${model} ${dataset} "${scheme}_020" 44 | echo "[`date`] ## ${scheme}_020 done ##" 45 | fi 46 | 47 | if [[ ! 
-d "${model}-${dataset}/${scheme}_050" ]]; then 48 | echo "[`date`] ## ${scheme}_050 start ##" 49 | clean 50 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --num_users=50" 51 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 52 | cd - 53 | # detect test finish or not 54 | sleep 30 55 | testFinish "${scheme}" 56 | # gather output, move to the right directory 57 | arrangeOutput ${model} ${dataset} "${scheme}_050" 58 | echo "[`date`] ## ${scheme}_050 done ##" 59 | fi 60 | 61 | if [[ ! -d "${model}-${dataset}/${scheme}_100" ]]; then 62 | echo "[`date`] ## ${scheme}_100 start ##" 63 | clean 64 | PYTHON_CMD="python3 -u ${scheme}.py --model=${model} --dataset=${dataset} --num_users=100" 65 | cd $PWD/../federated-learning/; $PYTHON_CMD > $PWD/../server.log 2>&1 & 66 | cd - 67 | # detect test finish or not 68 | sleep 30 69 | testFinish "${scheme}" 70 | # gather output, move to the right directory 71 | arrangeOutput ${model} ${dataset} "${scheme}_100" 72 | echo "[`date`] ## ${scheme}_100 done ##" 73 | fi 74 | done 75 | done 76 | } 77 | 78 | main > test.log 2>&1 & 79 | 80 | -------------------------------------------------------------------------------- /federated-learning/datasets/UCI.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | import os 7 | from models.Nets import UCI_CNN 8 | from torch.utils.data import DataLoader 9 | import torch.nn as nn 10 | 11 | 12 | class UCIDataset(Dataset): 13 | def __init__(self, data_path, phase="train"): 14 | self.data_path = data_path 15 | self.phase = phase 16 | self.data, self.targets = self.get_data() 17 | 18 | def get_data(self): 19 | data_acc_x = pd.read_csv(os.path.join(self.data_path, self.phase, 'AccXUCI.csv'), header=None).values 20 | data_acc_y = pd.read_csv(os.path.join(self.data_path, self.phase, 
'AccYUCI.csv'), header=None).values 21 | data_acc_z = pd.read_csv(os.path.join(self.data_path, self.phase, 'AccZUCI.csv'), header=None).values 22 | data_gyro_x = pd.read_csv(os.path.join(self.data_path, self.phase, 'GyroXUCI.csv'), header=None).values 23 | data_gyro_y = pd.read_csv(os.path.join(self.data_path, self.phase, 'GyroYUCI.csv'), header=None).values 24 | data_gyro_z = pd.read_csv(os.path.join(self.data_path, self.phase, 'GyroZUCI.csv'), header=None).values 25 | data = np.dstack((data_acc_x, data_acc_y, data_acc_z, data_gyro_x, data_gyro_y, data_gyro_z)).transpose(0, 2, 1) 26 | label = pd.read_csv(os.path.join(self.data_path, self.phase, 'LabelUCI.csv'), header=None).values.reshape(-1) 27 | label = label.astype(np.int) 28 | return data, label 29 | 30 | def __getitem__(self, idx): 31 | return self.data[idx], self.targets[idx] 32 | 33 | def __len__(self): 34 | return self.data.shape[0] 35 | 36 | 37 | if __name__ == '__main__': 38 | real_path = os.path.dirname(os.path.realpath(__file__)) 39 | uci_data_path = os.path.join(real_path, "../../data/uci/") 40 | device = torch.device('cpu') 41 | dataset = UCIDataset(uci_data_path) 42 | dataloader = DataLoader(dataset, batch_size=100, shuffle=True) 43 | net = UCI_CNN().to(device) 44 | loss_fun = nn.CrossEntropyLoss() 45 | params_to_update = [] 46 | for name, param in net.named_parameters(): 47 | if param.requires_grad: 48 | params_to_update.append(param) 49 | optimizer = optim.SGD(params_to_update, lr=0.01) 50 | for i in range(200): 51 | for step, (image, label) in enumerate(dataloader): 52 | image = torch.tensor(image).type(torch.FloatTensor) 53 | image = image.to(device) 54 | label = label.to(device) 55 | pred = net(image) 56 | loss = loss_fun(pred, label) 57 | optimizer.zero_grad() 58 | loss.backward() 59 | optimizer.step() 60 | 61 | if step % 100 == 0: 62 | print("Epoch %d, step %d, loss %f" % (i, step, loss)) 63 | 64 | test_dataset = UCIDataset(uci_data_path, phase='eval') 65 | test_dataloader = 
DataLoader(dataset, batch_size=100, shuffle=False) 66 | total = 0 67 | error = 0 68 | net.eval() 69 | with torch.no_grad(): 70 | for image, label in test_dataloader: 71 | image = torch.tensor(image).type(torch.FloatTensor) 72 | image = image.to(device) 73 | label = label.to(device) 74 | pred = net(image) 75 | total += pred.shape[0] 76 | error += sum(torch.argmax(pred, dim=1) != label) 77 | 78 | print(error, total, (total-error)/total) 79 | -------------------------------------------------------------------------------- /raft/store/store_test.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | // Test_StoreOpen tests that the store can be opened. 11 | func Test_StoreOpen(t *testing.T) { 12 | s := New(false) 13 | tmpDir, _ := ioutil.TempDir("", "store_test") 14 | defer os.RemoveAll(tmpDir) 15 | 16 | s.RaftBind = "127.0.0.1:0" 17 | s.RaftDir = tmpDir 18 | if s == nil { 19 | t.Fatalf("failed to create store") 20 | } 21 | 22 | if err := s.Open(false, "node0"); err != nil { 23 | t.Fatalf("failed to open store: %s", err) 24 | } 25 | } 26 | 27 | // Test_StoreOpenSingleNode tests that a command can be applied to the log 28 | func Test_StoreOpenSingleNode(t *testing.T) { 29 | s := New(false) 30 | tmpDir, _ := ioutil.TempDir("", "store_test") 31 | defer os.RemoveAll(tmpDir) 32 | 33 | s.RaftBind = "127.0.0.1:0" 34 | s.RaftDir = tmpDir 35 | if s == nil { 36 | t.Fatalf("failed to create store") 37 | } 38 | 39 | if err := s.Open(true, "node0"); err != nil { 40 | t.Fatalf("failed to open store: %s", err) 41 | } 42 | 43 | // Simple way to ensure there is a leader. 44 | time.Sleep(3 * time.Second) 45 | 46 | if err := s.Set("foo", "bar"); err != nil { 47 | t.Fatalf("failed to set key: %s", err.Error()) 48 | } 49 | 50 | // Wait for committed log entry to be applied. 
51 | time.Sleep(500 * time.Millisecond) 52 | value, err := s.Get("foo") 53 | if err != nil { 54 | t.Fatalf("failed to get key: %s", err.Error()) 55 | } 56 | if value != "bar" { 57 | t.Fatalf("key has wrong value: %s", value) 58 | } 59 | 60 | if err := s.Delete("foo"); err != nil { 61 | t.Fatalf("failed to delete key: %s", err.Error()) 62 | } 63 | 64 | // Wait for committed log entry to be applied. 65 | time.Sleep(500 * time.Millisecond) 66 | value, err = s.Get("foo") 67 | if err != nil { 68 | t.Fatalf("failed to get key: %s", err.Error()) 69 | } 70 | if value != "" { 71 | t.Fatalf("key has wrong value: %s", value) 72 | } 73 | } 74 | 75 | // Test_StoreInMemOpenSingleNode tests that a command can be applied to the log 76 | // stored in RAM. 77 | func Test_StoreInMemOpenSingleNode(t *testing.T) { 78 | s := New(true) 79 | tmpDir, _ := ioutil.TempDir("", "store_test") 80 | defer os.RemoveAll(tmpDir) 81 | 82 | s.RaftBind = "127.0.0.1:0" 83 | s.RaftDir = tmpDir 84 | if s == nil { 85 | t.Fatalf("failed to create store") 86 | } 87 | 88 | if err := s.Open(true, "node0"); err != nil { 89 | t.Fatalf("failed to open store: %s", err) 90 | } 91 | 92 | // Simple way to ensure there is a leader. 93 | time.Sleep(3 * time.Second) 94 | 95 | if err := s.Set("foo", "bar"); err != nil { 96 | t.Fatalf("failed to set key: %s", err.Error()) 97 | } 98 | 99 | // Wait for committed log entry to be applied. 100 | time.Sleep(500 * time.Millisecond) 101 | value, err := s.Get("foo") 102 | if err != nil { 103 | t.Fatalf("failed to get key: %s", err.Error()) 104 | } 105 | if value != "bar" { 106 | t.Fatalf("key has wrong value: %s", value) 107 | } 108 | 109 | if err := s.Delete("foo"); err != nil { 110 | t.Fatalf("failed to delete key: %s", err.Error()) 111 | } 112 | 113 | // Wait for committed log entry to be applied. 
114 | time.Sleep(500 * time.Millisecond) 115 | value, err = s.Get("foo") 116 | if err != nil { 117 | t.Fatalf("failed to get key: %s", err.Error()) 118 | } 119 | if value != "" { 120 | t.Fatalf("key has wrong value: %s", value) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /raft/http/service_test.go: -------------------------------------------------------------------------------- 1 | package httpd 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | // Test_NewServer tests that a server can perform all basic operations. 15 | func Test_NewServer(t *testing.T) { 16 | store := newTestStore() 17 | s := &testServer{New(":0", store)} 18 | if s == nil { 19 | t.Fatal("failed to create HTTP service") 20 | } 21 | 22 | if err := s.Start(); err != nil { 23 | t.Fatalf("failed to start HTTP service: %s", err) 24 | } 25 | 26 | b := doGet(t, s.URL(), "k1") 27 | if string(b) != `{"k1":""}` { 28 | t.Fatalf("wrong value received for key k1: %s (expected empty string)", string(b)) 29 | } 30 | 31 | doPost(t, s.URL(), "k1", "v1") 32 | 33 | b = doGet(t, s.URL(), "k1") 34 | if string(b) != `{"k1":"v1"}` { 35 | t.Fatalf(`wrong value received for key k1: %s (expected "v1")`, string(b)) 36 | } 37 | 38 | store.m["k2"] = "v2" 39 | b = doGet(t, s.URL(), "k2") 40 | if string(b) != `{"k2":"v2"}` { 41 | t.Fatalf(`wrong value received for key k2: %s (expected "v2")`, string(b)) 42 | } 43 | 44 | doDelete(t, s.URL(), "k2") 45 | b = doGet(t, s.URL(), "k2") 46 | if string(b) != `{"k2":""}` { 47 | t.Fatalf(`wrong value received for key k2: %s (expected empty string)`, string(b)) 48 | } 49 | 50 | } 51 | 52 | type testServer struct { 53 | *Service 54 | } 55 | 56 | func (t *testServer) URL() string { 57 | port := strings.TrimLeft(t.Addr().String(), "[::]:") 58 | return fmt.Sprintf("http://127.0.0.1:%s", port) 59 | } 60 | 61 | type testStore struct { 62 | 
m map[string]string 63 | } 64 | 65 | func newTestStore() *testStore { 66 | return &testStore{ 67 | m: make(map[string]string), 68 | } 69 | } 70 | 71 | func (t *testStore) Get(key string) (string, error) { 72 | return t.m[key], nil 73 | } 74 | 75 | func (t *testStore) Set(key, value string) error { 76 | t.m[key] = value 77 | return nil 78 | } 79 | 80 | func (t *testStore) Delete(key string) error { 81 | delete(t.m, key) 82 | return nil 83 | } 84 | 85 | func (t *testStore) Join(nodeID, addr string) error { 86 | return nil 87 | } 88 | 89 | func doGet(t *testing.T, url, key string) string { 90 | resp, err := http.Get(fmt.Sprintf("%s/key/%s", url, key)) 91 | if err != nil { 92 | t.Fatalf("failed to GET key: %s", err) 93 | } 94 | defer resp.Body.Close() 95 | body, err := ioutil.ReadAll(resp.Body) 96 | if err != nil { 97 | t.Fatalf("failed to read response: %s", err) 98 | } 99 | return string(body) 100 | } 101 | 102 | func doPost(t *testing.T, url, key, value string) { 103 | b, err := json.Marshal(map[string]string{key: value}) 104 | if err != nil { 105 | t.Fatalf("failed to encode key and value for POST: %s", err) 106 | } 107 | resp, err := http.Post(fmt.Sprintf("%s/key", url), "application-type/json", bytes.NewReader(b)) 108 | if err != nil { 109 | t.Fatalf("POST request failed: %s", err) 110 | } 111 | defer resp.Body.Close() 112 | } 113 | 114 | func doDelete(t *testing.T, u, key string) { 115 | ru, err := url.Parse(fmt.Sprintf("%s/key/%s", u, key)) 116 | if err != nil { 117 | t.Fatalf("failed to parse URL for delete: %s", err) 118 | } 119 | req := &http.Request{ 120 | Method: "DELETE", 121 | URL: ru, 122 | } 123 | 124 | client := http.Client{} 125 | resp, err := client.Do(req) 126 | if err != nil { 127 | t.Fatalf("failed to GET key: %s", err) 128 | } 129 | defer resp.Body.Close() 130 | } 131 | -------------------------------------------------------------------------------- /federated-learning/utils/Trainer.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | from utils.ModelStore import PersonalModelStore, APFLPersonalModelStore 5 | from utils.util import model_loader, ColoredLogger, test_model, train_model, record_log, reset_communication_time 6 | 7 | logging.setLoggerClass(ColoredLogger) 8 | logger = logging.getLogger("Trainer") 9 | 10 | 11 | class Trainer: 12 | def __init__(self): 13 | self.net_glob = None 14 | self.model_store = PersonalModelStore() 15 | self.init_time = time.time() 16 | self.round_start_time = time.time() 17 | self.round_train_duration = 0 18 | self.round_test_duration = 0 19 | self.epoch = 1 20 | self.uuid = -1 21 | # for committee election 22 | self.committee_elect_duration = 0 23 | 24 | def init_model(self, model, dataset, device, image_shape): 25 | self.net_glob = model_loader(model, dataset, device, image_shape) 26 | if self.net_glob is None: 27 | logger.error('Error: unrecognized model') 28 | return False 29 | return True 30 | 31 | def load_model(self, w): 32 | self.net_glob.load_state_dict(w) 33 | 34 | def dump_model(self): 35 | return self.net_glob.state_dict() 36 | 37 | def evaluate_model(self, dataset, args): 38 | self.net_glob.eval() 39 | acc_local, acc_local_skew1, acc_local_skew2, acc_local_skew3, acc_local_skew4 = \ 40 | test_model(self.net_glob, dataset, self.uuid - 1, args.local_test_bs, args.device) 41 | return acc_local, acc_local_skew1, acc_local_skew2, acc_local_skew3, acc_local_skew4 42 | 43 | def evaluate_model_loss(self, dataset, args): 44 | self.net_glob.eval() 45 | loss_local, loss_local_skew1, loss_local_skew2, loss_local_skew3, loss_local_skew4 = \ 46 | test_model(self.net_glob, dataset, self.uuid - 1, args.local_test_bs, args.device, get_acc=False) 47 | return loss_local, loss_local_skew1, loss_local_skew2, loss_local_skew3, loss_local_skew4 48 | 49 | def evaluate_model_with_log(self, dataset, args, record_epoch=None, clean=False, 
record_communication_time=False): 50 | if record_epoch is None: 51 | record_epoch = self.epoch 52 | communication_duration = 0 53 | if record_communication_time: 54 | communication_duration = reset_communication_time() 55 | communication_duration += self.committee_elect_duration 56 | if communication_duration < 0.001: 57 | communication_duration = 0.0 58 | test_start_time = time.time() 59 | acc_local, acc_local_skew1, acc_local_skew2, acc_local_skew3, acc_local_skew4 = self.evaluate_model(dataset, 60 | args) 61 | test_duration = time.time() - test_start_time 62 | test_duration += self.round_test_duration 63 | total_duration = time.time() - self.init_time 64 | round_duration = time.time() - self.round_start_time 65 | train_duration = self.round_train_duration 66 | record_log(self.uuid, record_epoch, 67 | [total_duration, round_duration, train_duration, test_duration, communication_duration], 68 | [acc_local, acc_local_skew1, acc_local_skew2, acc_local_skew3, acc_local_skew4], clean=clean) 69 | return acc_local, acc_local_skew1, acc_local_skew2, acc_local_skew3, acc_local_skew4 70 | 71 | def is_first_epoch(self): 72 | return self.epoch == 1 73 | 74 | def train(self, dataset, args): 75 | w_local, loss = train_model(self.net_glob, dataset, self.uuid - 1, args.local_ep, args.device, args.lr, 76 | args.momentum, args.local_bs) 77 | return w_local, loss 78 | 79 | 80 | class APFLTrainer(Trainer): 81 | def __init__(self): 82 | super().__init__() 83 | self.hyper_para = 0 84 | self.model_store = APFLPersonalModelStore() 85 | -------------------------------------------------------------------------------- /federated-learning/utils/ModelStore.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import threading 4 | 5 | from utils.util import ColoredLogger, compress_tensor, generate_md5_hash 6 | 7 | lock = threading.Lock() 8 | 9 | logging.setLoggerClass(ColoredLogger) 10 | logger = 
logging.getLogger("ModelStore") 11 | 12 | 13 | class CentralModelStore: 14 | def __init__(self): 15 | self.global_model = None 16 | self.global_model_compressed = None 17 | self.global_model_hash = None 18 | self.global_model_version = -1 19 | self.local_models_count_num = 0 20 | self.local_models = {} 21 | self.acc_alpha_count_num = 0 22 | self.acc_alpha_maps = {} 23 | 24 | def update_global_model(self, w_glob, step=1, epochs=None): 25 | self.global_model = copy.deepcopy(w_glob) 26 | self.global_model_compressed = compress_tensor(w_glob) 27 | self.global_model_hash = generate_md5_hash(w_glob) 28 | if epochs is None: 29 | self.global_model_version += step 30 | else: 31 | self.global_model_version = epochs 32 | 33 | def local_models_add_count(self, local_uuid, w_local, count_target): 34 | reach_target = False 35 | lock.acquire() 36 | self.local_models[local_uuid] = w_local 37 | self.local_models_count_num += 1 38 | if self.local_models_count_num == count_target: 39 | reach_target = True 40 | lock.release() 41 | logger.debug("Count local_models: {}. 
Gathered {} local models in total".format(self.local_models_count_num, 42 | len(self.local_models))) 43 | return reach_target 44 | 45 | def local_models_reset(self): 46 | lock.acquire() 47 | self.local_models = {} 48 | self.local_models_count_num = 0 49 | lock.release() 50 | logger.debug("Reset local_models, now: {}".format(len(self.local_models))) 51 | 52 | def acc_alpha_add_count(self, local_uuid, acc_alpha_map, count_target): 53 | reach_target = False 54 | lock.acquire() 55 | self.acc_alpha_maps[local_uuid] = acc_alpha_map 56 | self.acc_alpha_count_num += 1 57 | if self.acc_alpha_count_num == count_target: 58 | reach_target = True 59 | lock.release() 60 | logger.debug("Received acc_alpha_map: {} in total".format(self.acc_alpha_count_num)) 61 | return reach_target 62 | 63 | def acc_alpha_reset(self): 64 | lock.acquire() 65 | self.acc_alpha_maps = {} 66 | self.acc_alpha_count_num = 0 67 | lock.release() 68 | logger.debug("Reset acc_alpha_maps, now: {}".format(len(self.acc_alpha_maps))) 69 | 70 | 71 | class PersonalModelStore: 72 | def __init__(self): 73 | self.my_local_model = None 74 | self.my_global_model = None 75 | self.my_global_model_hash = None 76 | 77 | def update_my_global_model(self, w_glob): 78 | self.my_global_model = copy.deepcopy(w_glob) 79 | self.my_global_model_hash = generate_md5_hash(w_glob) 80 | 81 | 82 | class APFLPersonalModelStore(PersonalModelStore): 83 | def __init__(self): 84 | super().__init__() 85 | # for apfl 86 | self.difference1 = None 87 | self.difference2 = None 88 | self.w_glob = None 89 | self.w_glob_local = None 90 | self.w_glob_local_compressed = None 91 | self.w_locals = None 92 | self.w_locals_per = None 93 | 94 | def update_w_glob(self, w_glob): 95 | self.w_glob = copy.deepcopy(w_glob) 96 | 97 | def update_w_glob_local(self, w_glob_local): 98 | self.w_glob_local = copy.deepcopy(w_glob_local) 99 | self.w_glob_local_compressed = compress_tensor(w_glob_local) 100 | 101 | def update_w_locals(self, w_locals): 102 | self.w_locals = 
copy.deepcopy(w_locals) 103 | 104 | def update_w_locals_per(self, w_locals_per): 105 | self.w_locals_per = copy.deepcopy(w_locals_per) 106 | 107 | def update_difference1(self, difference1): 108 | self.difference1 = copy.deepcopy(difference1) 109 | 110 | def update_difference2(self, difference2): 111 | self.difference2 = copy.deepcopy(difference2) 112 | 113 | 114 | class AsyncCentralModelStore(CentralModelStore): 115 | def __init__(self): 116 | super().__init__() 117 | self.optimal_alpha = 1.0 118 | 119 | def async_acc_alpha_add(self, local_uuid, acc_alpha_map): 120 | self.acc_alpha_maps[local_uuid] = acc_alpha_map 121 | logger.debug("Received new acc_alpha_map from user {}.".format(local_uuid)) 122 | -------------------------------------------------------------------------------- /raft/go.sum: -------------------------------------------------------------------------------- 1 | github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= 2 | github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878 h1:EFSB7Zo9Eg91v7MJPVsifUysc/wPdN+NOnVe6bWbdBM= 3 | github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878/go.mod h1:3AMJUQhVx52RsWOnlkpikZr01T/yAVN2gn0861vByNg= 4 | github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= 5 | github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= 6 | github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= 7 | github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= 8 | github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= 9 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 11 | github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 12 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 13 | github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= 14 | github.com/hashicorp/go-hclog v0.9.1 h1:9PZfAcVEvez4yhLH2TBU64/h/z4xlFI80cWXRrxuKuM= 15 | github.com/hashicorp/go-hclog v0.9.1/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= 16 | github.com/hashicorp/go-immutable-radix v1.0.0 h1:AKDB1HM5PWEA7i4nhcpwOrO2byshxBjXVn/J/3+z5/0= 17 | github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= 18 | github.com/hashicorp/go-msgpack v0.5.5 h1:i9R9JSrqIz0QVLz3sz+i3YJdT7TTSLcfLLzJi9aZTuI= 19 | github.com/hashicorp/go-msgpack v0.5.5/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= 20 | github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= 21 | github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= 22 | github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= 23 | github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= 24 | github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= 25 | github.com/hashicorp/raft v1.1.0/go.mod h1:4Ak7FSPnuvmb0GV6vgIAJ4vYT4bek9bb6Q+7HVbyzqM= 26 | github.com/hashicorp/raft v1.2.0 h1:mHzHIrF0S91d3A7RPBvuqkgB4d/7oFJZyvf1Q4m7GA0= 27 | github.com/hashicorp/raft v1.2.0/go.mod h1:vPAJM8Asw6u8LxC3eJCUZmRP/E4QmUGE1R7g7k8sG/8= 28 | github.com/hashicorp/raft-boltdb v0.0.0-20171010151810-6e5ba93211ea/go.mod h1:pNv7Wc3ycL6F5oOWn+tPGo2gWD4a5X+yp/ntwdKLjRk= 29 | github.com/hashicorp/raft-boltdb v0.0.0-20191021154308-4207f1bf0617 h1:CJDRE/2tBNFOrcoexD2nvTRbQEox3FDxl4NxIezp1b8= 30 | github.com/hashicorp/raft-boltdb v0.0.0-20191021154308-4207f1bf0617/go.mod h1:aUF6HQr8+t3FC/ZHAC+pZreUBhTaxumuu3L+d37uRxk= 31 | 
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= 32 | github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= 33 | github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= 34 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 35 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 36 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 37 | github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= 38 | github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= 39 | github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= 40 | github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= 41 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 42 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 43 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 44 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 45 | github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= 46 | golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 47 | golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 48 | golang.org/x/sys v0.0.0-20190523142557-0e01d883c5c5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 49 | golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed h1:uPxWBzB3+mlnjy9W58qY1j/cjyFjutgw/Vhan2zLy/A= 50 | 
golang.org/x/sys v0.0.0-20190602015325-4c4f7f33c9ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 51 | -------------------------------------------------------------------------------- /federated-learning/local.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import time 4 | import threading 5 | from flask import Flask, request 6 | 7 | import utils 8 | from utils.CentralStore import IPCount 9 | from utils.DatasetStore import LocalDataset 10 | from utils.EnvStore import EnvStore 11 | from utils.ModelStore import CentralModelStore 12 | from utils.Trainer import Trainer 13 | from utils.util import ColoredLogger 14 | 15 | logging.setLoggerClass(ColoredLogger) 16 | logging.getLogger("werkzeug").setLevel(logging.ERROR) 17 | logger = logging.getLogger("local_train") 18 | 19 | env_store = EnvStore() 20 | local_dataset = LocalDataset() 21 | central_model_store = CentralModelStore() 22 | ipCount = IPCount() 23 | trainer_pool = {} # multiple thread trainers stored in this map 24 | 25 | 26 | def init_trainer(): 27 | trainer = Trainer() 28 | trainer.uuid = fetch_uuid() 29 | 30 | load_result = trainer.init_model(env_store.args.model, env_store.args.dataset, env_store.args.device, 31 | local_dataset.image_shape) 32 | if not load_result: 33 | sys.exit() 34 | 35 | # trained the initial local model, which will be treated as first global model. 
36 | trainer.net_glob.train() 37 | trainer_pool[trainer.uuid] = trainer 38 | return trainer.uuid 39 | 40 | 41 | def train(trainer_uuid): 42 | trainer = trainer_pool[trainer_uuid] 43 | logger.debug("Train local model for user: {}, epoch: {}.".format(trainer.uuid, trainer.epoch)) 44 | 45 | # training for all epochs 46 | while trainer.epoch <= env_store.args.epochs: 47 | logger.info("########## EPOCH #{} ##########".format(trainer.epoch)) 48 | logger.info("Epoch [{}] train for user [{}]".format(trainer.epoch, trainer.uuid)) 49 | trainer.round_start_time = time.time() 50 | # calculate initial model accuracy, record it as the bench mark. 51 | if trainer.is_first_epoch(): 52 | trainer.init_time = time.time() 53 | trainer.evaluate_model_with_log(local_dataset, env_store.args, record_epoch=0, clean=True) 54 | 55 | train_start_time = time.time() 56 | w_local, _ = trainer.train(local_dataset, env_store.args) 57 | trainer.round_train_duration = time.time() - train_start_time 58 | 59 | # finally, evaluate the global model 60 | trainer.load_model(w_local) 61 | trainer.evaluate_model_with_log(local_dataset, env_store.args, record_communication_time=True) 62 | 63 | trainer.epoch += 1 64 | 65 | logger.info("########## ALL DONE! 
##########") 66 | body_data = { 67 | "message": "shutdown_python", 68 | "uuid": trainer.uuid, 69 | "from_ip": env_store.from_ip, 70 | } 71 | utils.util.post_msg_trigger(env_store.trigger_url, body_data) 72 | 73 | 74 | def start_train(): 75 | time.sleep(env_store.args.start_sleep) 76 | trainer_uuid = init_trainer() 77 | train(trainer_uuid) 78 | 79 | 80 | def load_uuid(): 81 | new_id = ipCount.get_new_id() 82 | detail = {"uuid": new_id} 83 | return detail 84 | 85 | 86 | def fetch_uuid(): 87 | body_data = { 88 | "message": "fetch_uuid", 89 | } 90 | detail = utils.util.post_msg_trigger(env_store.trigger_url, body_data) 91 | uuid = detail.get("uuid") 92 | return uuid 93 | 94 | 95 | def my_route(app): 96 | @app.route("/trigger", methods=["GET", "POST"]) 97 | def trigger_handler(): 98 | # For POST 99 | if request.method == "POST": 100 | data = request.get_json() 101 | status = "yes" 102 | detail = {} 103 | message = data.get("message") 104 | if message == "fetch_uuid": 105 | detail = load_uuid() 106 | elif message == "shutdown_python": 107 | threading.Thread(target=utils.util.shutdown_count, args=( 108 | data.get("uuid"), data.get("from_ip"), env_store.args.fl_listen_port, 109 | env_store.args.num_users)).start() 110 | elif message == "shutdown": 111 | threading.Thread(target=utils.util.my_exit, args=(env_store.args.exit_sleep,)).start() 112 | response = {"status": status, "detail": detail} 113 | return response 114 | 115 | 116 | def main(): 117 | # init environment arguments 118 | env_store.init() 119 | # init local dataset 120 | local_dataset.init_local_dataset(env_store.args.dataset, env_store.args.num_users) 121 | # set logger level 122 | logger.setLevel(env_store.args.log_level) 123 | 124 | for _ in range(env_store.args.num_users): 125 | logger.debug("start new thread") 126 | threading.Thread(target=start_train, args=()).start() 127 | 128 | flask_app = Flask(__name__) 129 | my_route(flask_app) 130 | logger.info("start serving at " + str(env_store.args.fl_listen_port) 
+ "...") 131 | flask_app.run(host="0.0.0.0", port=int(env_store.args.fl_listen_port)) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/cnn-realworld-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [82.86, 82.14, 83.19, 82.88, 83.1, 83.45, 83.48, 83.62, 83.71, 83.93, 84.48, 83.81, 84.0, 83.9, 84.31, 83.86, 84.4, 84.5, 84.57, 84.4, 84.83, 85.1, 84.33, 84.67, 84.88, 84.62, 85.07, 84.69, 85.1, 84.9] 7 | scei_10 = [79.91, 79.41, 80.91, 79.98, 80.27, 80.89, 81.23, 81.02, 81.64, 82.0, 82.45, 81.95, 81.59, 81.8, 82.11, 81.5, 82.39, 82.23, 82.27, 81.95, 83.05, 83.0, 82.18, 82.59, 82.95, 82.5, 83.23, 82.77, 83.36, 83.0] 8 | scei_15 = [76.52, 76.13, 78.26, 76.87, 77.13, 77.93, 78.28, 78.22, 79.11, 79.57, 79.57, 79.7, 78.54, 79.07, 79.3, 78.37, 79.91, 79.78, 79.8, 79.26, 80.61, 80.54, 79.35, 79.89, 81.3, 79.85, 81.17, 80.87, 81.39, 80.91] 9 | scei_20 = [73.62, 73.48, 75.81, 73.87, 74.19, 75.08, 75.48, 75.62, 76.58, 76.71, 77.06, 77.02, 75.77, 76.4, 76.56, 75.75, 77.06, 77.12, 77.27, 76.6, 78.23, 77.87, 76.98, 77.56, 78.85, 77.42, 78.77, 78.52, 79.17, 78.42] 10 | 11 | sceia_05 = [75.47, 81.19, 78.09, 80.95, 80.71, 79.04, 77.85, 80.95, 80.47, 79.76, 80.95, 80.95, 79.28, 76.90, 80.47, 79.76, 81.42, 81.42, 78.57, 79.52, 77.61, 78.57, 81.19, 80.0 , 81.66, 80.71, 81.66, 81.66, 81.19, 81.66] 12 | sceia_10 = [76.36, 72.95, 78.18, 75.45, 78.18, 78.18, 76.36, 75.22, 78.40, 77.72, 77.27, 78.18, 78.18, 76.13, 74.31, 77.95, 77.04, 78.86, 78.63, 75.68, 77.04, 75.0 , 76.13, 78.40, 77.27, 78.86, 77.72, 78.63, 78.86, 77.72] 13 | sceia_15 = [70.86, 76.30, 73.26, 75.86, 75.86, 74.34, 73.26, 76.08, 75.65, 75.0 , 76.08, 76.08, 73.04, 72.39, 75.0 , 74.13, 75.86, 75.86, 72.60, 74.34, 72.60, 73.04, 75.21, 74.34, 75.65, 75.0, 75.65, 
75.86, 75.0, 71.73] 14 | sceia_20 = [71.25, 68.54, 73.12, 70.20, 73.54, 73.54, 72.08, 70.41, 73.75, 72.91, 72.08, 73.12, 73.12, 71.25, 69.58, 72.5, 71.87, 73.33, 73.33, 70.62, 71.87, 70.0, 70.62, 72.91, 71.87, 73.54, 72.70, 73.33, 73.75, 72.70] 15 | 16 | apfl_05 = [80.71, 80.95, 81.14, 80.93, 80.79, 80.43, 81.36, 81.12, 80.98, 80.67, 80.71, 80.55, 81.1, 81.79, 81.57, 81.12, 81.74, 81.88, 80.86, 81.05, 81.74, 81.5, 81.71, 81.67, 81.81, 82.43, 82.14, 82.14, 81.98, 81.81] 17 | apfl_10 = [77.05, 77.27, 77.45, 77.25, 77.11, 76.77, 77.66, 77.43, 77.3, 77.0, 77.05, 76.89, 77.41, 78.07, 77.86, 77.43, 78.02, 78.16, 77.18, 77.36, 78.02, 77.8, 78.0, 77.95, 78.11, 78.68, 78.43, 78.43, 78.27, 78.11] 18 | apfl_15 = [73.7, 73.91, 74.09, 73.89, 73.76, 73.43, 74.28, 74.07, 73.93, 73.65, 73.7, 73.54, 74.04, 74.67, 74.48, 74.07, 74.63, 74.76, 73.83, 74.0, 74.63, 74.41, 74.61, 74.57, 74.72, 75.28, 75.09, 75.02, 74.93, 74.74] 19 | apfl_20 = [70.62, 70.83, 71.0, 70.81, 70.69, 70.37, 71.19, 70.98, 70.85, 70.58, 70.62, 70.48, 70.96, 71.56, 71.37, 70.98, 71.54, 71.65, 70.77, 70.92, 71.52, 71.31, 71.52, 71.48, 71.65, 72.15, 71.96, 71.92, 71.79, 71.62] 20 | 21 | fedavg_05 = [74.5, 75.26, 75.14, 74.64, 74.9, 76.14, 76.12, 75.02, 76.0, 76.45, 76.1, 76.19, 76.43, 76.69, 76.29, 76.55, 76.71, 76.5, 76.95, 76.81, 76.26, 77.38, 77.07, 76.62, 76.81, 77.07, 77.1, 76.71, 76.71, 76.76] 22 | fedavg_10 = [74.32, 75.2, 75.2, 74.59, 74.84, 76.09, 75.95, 74.86, 75.8, 76.41, 75.98, 76.23, 76.18, 76.48, 76.2, 76.25, 76.43, 76.27, 76.77, 76.75, 76.14, 76.98, 76.93, 76.52, 76.75, 76.91, 77.09, 76.59, 76.64, 76.68] 23 | fedavg_15 = [73.39, 74.26, 74.26, 73.8, 73.96, 75.15, 75.13, 74.15, 74.93, 75.46, 75.04, 75.39, 75.5, 75.83, 75.48, 75.63, 75.63, 75.59, 75.89, 76.02, 75.35, 76.43, 76.3, 75.78, 76.09, 76.43, 76.52, 75.98, 76.28, 76.07] 24 | fedavg_20 = [73.1, 73.94, 74.02, 73.23, 73.56, 74.71, 74.71, 73.6, 74.5, 75.08, 74.75, 74.83, 75.02, 75.33, 75.02, 75.29, 75.42, 74.96, 75.73, 75.5, 74.92, 76.08, 75.71, 
75.5, 75.6, 75.85, 76.0, 75.44, 75.73, 75.69] 25 | 26 | local_05 = [80.48, 81.33, 80.76, 80.31, 80.55, 80.9, 80.45, 81.05, 81.21, 81.76, 81.14, 80.79, 80.83, 81.31, 81.19, 81.14, 81.17, 80.98, 81.14, 80.81, 81.5, 81.48, 81.76, 81.45, 81.48, 81.0, 81.26, 80.86, 81.14, 80.9] 27 | local_10 = [76.82, 77.64, 77.09, 76.66, 76.89, 77.23, 76.8, 77.36, 77.52, 78.05, 77.45, 77.11, 77.16, 77.61, 77.5, 77.45, 77.48, 77.3, 77.45, 77.14, 77.8, 77.77, 78.05, 77.75, 77.77, 77.32, 77.57, 77.18, 77.45, 77.23] 28 | local_15 = [73.48, 74.26, 73.74, 73.33, 73.54, 73.87, 73.46, 74.0, 74.15, 74.65, 74.09, 73.76, 73.8, 74.24, 74.13, 74.09, 74.11, 73.93, 74.09, 73.78, 74.41, 74.39, 74.65, 74.37, 74.39, 73.96, 74.2, 73.83, 74.09, 73.87] 29 | local_20 = [70.42, 71.17, 70.67, 70.27, 70.48, 70.79, 70.4, 70.92, 71.06, 71.54, 71.0, 70.69, 70.73, 71.15, 71.04, 71.0, 71.02, 70.85, 71.0, 70.71, 71.31, 71.29, 71.54, 71.27, 71.29, 70.87, 71.1, 70.75, 71.0, 70.79] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 
'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, plot_size="4") 57 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/cnn-uci-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [94.13, 94.18, 94.01, 94.23, 94.5, 94.32, 94.32, 94.12, 94.16, 94.35, 94.52, 94.68, 94.29, 94.74, 94.74, 94.35, 94.65, 94.77, 94.82, 94.78, 95.01, 94.67, 94.65, 94.88, 94.62, 94.89, 94.88, 94.5, 94.79, 94.73] 7 | scei_10 = [93.05, 93.04, 92.89, 93.13, 93.39, 93.29, 93.31, 93.11, 93.23, 93.46, 93.73, 93.88, 93.51, 94.0, 94.01, 93.69, 93.96, 94.08, 94.1, 94.12, 94.46, 94.1, 94.1, 94.26, 94.06, 94.3, 94.27, 93.95, 94.21, 94.15] 8 | scei_15 = [91.71, 91.73, 91.6, 91.88, 92.23, 92.1, 92.12, 91.87, 92.08, 92.36, 92.69, 92.91, 92.41, 93.02, 92.98, 92.6, 92.91, 93.03, 93.17, 93.14, 93.64, 93.22, 93.17, 93.47, 93.27, 93.47, 93.55, 93.09, 93.45, 93.48] 9 | scei_20 = [90.91, 90.94, 90.84, 91.1, 91.4, 91.34, 91.41, 91.22, 91.44, 91.7, 92.08, 92.25, 91.82, 92.37, 92.4, 92.03, 92.32, 92.44, 92.58, 92.58, 93.02, 92.66, 92.64, 92.81, 92.72, 92.87, 92.99, 92.59, 92.91, 92.92] 10 | 11 | sceia_05 = [88.53, 88.41, 85.60, 88.17, 89.87, 87.68, 86.58, 90.0 , 89.51, 88.41, 90.24, 89.87, 88.29, 90.0 , 89.26, 88.90, 89.14, 82.68, 87.68, 86.46, 89.26, 88.78, 88.41, 90.12, 88.29, 89.26, 89.14, 88.78, 89.51, 89.02] 12 | sceia_10 = [88.57, 88.69, 85.83, 88.21, 90.11, 87.97, 85.83, 89.40, 89.04, 87.73, 89.40, 89.16, 87.61, 89.40, 88.69, 88.33, 88.45, 81.90, 87.14, 85.47, 88.80, 87.97, 87.61, 89.40, 87.61, 88.45, 88.33, 88.09, 88.92, 88.33] 13 | 
sceia_15 = [88.83, 88.95, 86.27, 88.60, 90.23, 87.67, 84.06, 88.95, 88.83, 87.32, 88.95, 88.48, 88.37, 88.60, 87.79, 87.20, 87.55, 80.23, 85.11, 83.60, 87.09, 85.69, 85.58, 86.74, 85.34, 85.93, 85.93, 85.58, 86.27, 85.81] 14 | sceia_20 = [88.40, 88.40, 85.68, 87.84, 89.43, 86.47, 82.84, 87.61, 87.15, 85.45, 87.72, 86.59, 86.59, 86.93, 86.36, 85.56, 85.68, 78.40, 83.86, 82.5, 85.56, 84.31, 84.09, 85.90, 84.31, 84.77, 85.0, 84.54, 85.22, 84.43] 15 | 16 | apfl_05 = [92.44, 92.45, 92.55, 92.38, 92.27, 92.48, 92.05, 92.44, 92.28, 92.43, 92.38, 92.26, 91.94, 92.45, 92.37, 92.49, 92.57, 92.21, 92.32, 92.5, 92.2, 92.41, 92.49, 92.33, 92.6, 92.27, 92.46, 92.22, 92.43, 92.22] 17 | apfl_10 = [90.24, 90.25, 90.35, 90.18, 90.07, 90.27, 89.86, 90.24, 90.08, 90.23, 90.18, 90.06, 89.75, 90.25, 90.17, 90.29, 90.37, 90.01, 90.12, 90.3, 90.0, 90.21, 90.29, 90.13, 90.39, 90.07, 90.26, 90.04, 90.24, 90.04] 18 | apfl_15 = [88.14, 88.15, 88.24, 88.08, 87.98, 88.17, 87.77, 88.14, 87.99, 88.13, 88.08, 87.97, 87.66, 88.16, 88.08, 88.2, 88.28, 87.93, 88.03, 88.21, 87.92, 88.13, 88.2, 88.05, 88.3, 88.0, 88.19, 87.97, 88.16, 87.97] 19 | apfl_20 = [86.14, 86.15, 86.24, 86.08, 85.98, 86.17, 85.77, 86.14, 85.99, 86.12, 86.08, 85.97, 85.67, 86.16, 86.09, 86.2, 86.31, 85.94, 86.07, 86.24, 85.95, 86.16, 86.23, 86.08, 86.33, 86.01, 86.19, 85.98, 86.17, 85.98] 20 | 21 | fedavg_05 = [92.44, 92.07, 92.16, 92.44, 92.17, 92.3, 92.26, 91.9, 92.16, 91.91, 92.12, 91.8, 92.06, 92.24, 91.85, 91.68, 92.21, 91.67, 91.73, 91.99, 91.71, 91.9, 91.57, 91.55, 91.83, 91.93, 91.49, 91.68, 91.5, 91.67] 22 | fedavg_10 = [92.51, 92.19, 92.27, 92.52, 92.27, 92.42, 92.37, 92.01, 92.27, 92.01, 92.24, 91.93, 92.17, 92.35, 91.98, 91.81, 92.27, 91.76, 91.81, 92.12, 91.79, 91.99, 91.67, 91.63, 91.9, 92.0, 91.58, 91.79, 91.64, 91.81] 23 | fedavg_15 = [92.29, 91.97, 92.05, 92.27, 92.03, 92.19, 92.12, 91.76, 92.01, 91.8, 92.01, 91.67, 91.95, 92.1, 91.73, 91.57, 92.09, 91.56, 91.63, 91.92, 91.62, 91.81, 91.43, 91.42, 91.71, 91.81, 
91.33, 91.56, 91.35, 91.55] 24 | fedavg_20 = [92.51, 92.16, 92.26, 92.49, 92.27, 92.39, 92.36, 91.98, 92.24, 91.98, 92.19, 91.85, 92.18, 92.34, 91.97, 91.72, 92.32, 91.69, 91.89, 92.07, 91.81, 91.97, 91.73, 91.66, 91.99, 92.1, 91.61, 91.76, 91.64, 91.81] 25 | 26 | local_05 = [90.82, 91.11, 91.16, 91.27, 91.33, 90.84, 90.83, 91.02, 91.0, 91.24, 91.21, 91.12, 90.77, 91.09, 90.95, 91.1, 91.15, 91.26, 90.82, 91.0, 91.0, 91.1, 91.23, 91.09, 91.11, 91.11, 91.02, 91.1, 90.99, 91.29] 27 | local_10 = [88.65, 88.94, 88.99, 89.1, 89.15, 88.68, 88.67, 88.86, 88.83, 89.07, 89.04, 88.95, 88.61, 88.92, 88.79, 88.93, 88.98, 89.08, 88.65, 88.83, 88.83, 88.93, 89.06, 88.92, 88.94, 88.94, 88.86, 88.93, 88.82, 89.12] 28 | local_15 = [86.59, 86.87, 86.92, 87.02, 87.08, 86.62, 86.6, 86.79, 86.77, 87.0, 86.97, 86.88, 86.55, 86.85, 86.72, 86.86, 86.91, 87.01, 86.59, 86.77, 86.77, 86.86, 86.99, 86.85, 86.87, 86.87, 86.79, 86.86, 86.76, 87.05] 29 | local_20 = [84.62, 84.9, 84.94, 85.05, 85.1, 84.65, 84.64, 84.82, 84.8, 85.02, 84.99, 84.91, 84.58, 84.87, 84.75, 84.89, 84.93, 85.03, 84.62, 84.8, 84.8, 84.89, 85.01, 84.87, 84.9, 84.9, 84.82, 84.89, 84.78, 85.07] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), 
np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, plot_size="4") 57 | -------------------------------------------------------------------------------- /federated-learning/utils/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | import numpy as np 7 | from torchvision import datasets, transforms 8 | 9 | def mnist_iid(dataset, num_users): 10 | """ 11 | Sample I.I.D. 
client data from MNIST dataset 12 | :param dataset: 13 | :param num_users: 14 | :return: dict of image index 15 | """ 16 | num_items = int(len(dataset)/num_users) 17 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 18 | for i in range(num_users): 19 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 20 | all_idxs = list(set(all_idxs) - dict_users[i]) 21 | return dict_users 22 | 23 | 24 | def mnist_noniid(dataset, num_users): 25 | """ 26 | Sample non-I.I.D client data from MNIST dataset 27 | :param dataset: 28 | :param num_users: 29 | :return: 30 | """ 31 | num_shards, num_imgs = 200, 300 32 | idx_shard = [i for i in range(num_shards)] 33 | dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} 34 | idxs = np.arange(num_shards*num_imgs) 35 | labels = dataset.train_labels.numpy() 36 | 37 | # sort labels 38 | idxs_labels = np.vstack((idxs, labels)) 39 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 40 | idxs = idxs_labels[0,:] 41 | 42 | # divide and assign 43 | for i in range(num_users): 44 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 45 | idx_shard = list(set(idx_shard) - rand_set) 46 | for rand in rand_set: 47 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 48 | return dict_users 49 | 50 | 51 | def get_indices(labels, user_labels, n_samples): 52 | indices = [] 53 | for selected_label in user_labels: 54 | label_samples = np.where(labels[1, :] == selected_label) 55 | label_indices = labels[0, label_samples] 56 | selected_indices = list(np.random.choice(label_indices[0], n_samples, replace=False)) 57 | indices += selected_indices 58 | return indices 59 | 60 | 61 | def noniid_onepass(dataset_train, dataset_test, num_users, dataset_name='mnist', kept_class=4): 62 | train_users = {} 63 | test_users = {} 64 | skew_users1 = {} 65 | skew_users2 = {} 66 | skew_users3 = {} 67 | skew_users4 = {} 68 | 69 | skew1_pct = 0.05 70 | skew2_pct = 0.10 71 
| skew3_pct = 0.15 72 | skew4_pct = 0.20 73 | 74 | train_idxs = np.arange(len(dataset_train)) 75 | train_labels = dataset_train.targets 76 | train_labels = np.vstack((train_idxs, train_labels)) 77 | 78 | test_idxs = np.arange(len(dataset_test)) 79 | test_labels = dataset_test.targets 80 | test_labels = np.vstack((test_idxs, test_labels)) 81 | if dataset_name == 'mnist': 82 | labels = list(range(10)) 83 | samples = [150, 50, int(50*skew1_pct), int(50*skew2_pct), int(50*skew3_pct), int(50*skew4_pct)] 84 | elif dataset_name == 'cifar10': 85 | labels = list(range(10)) 86 | samples = [150, 50, int(50*skew1_pct), int(50*skew2_pct), int(50*skew3_pct), int(50*skew4_pct)] 87 | elif dataset_name == 'cifar100': 88 | labels = list(range(100)) 89 | samples = [150, 50, int(50*skew1_pct), int(50*skew2_pct), int(50*skew3_pct), int(50*skew4_pct)] 90 | elif dataset_name == 'imagenet': 91 | labels = list(range(200)) 92 | samples = [150, 50, int(50 * skew1_pct), int(50 * skew2_pct), int(50 * skew3_pct), int(50 * skew4_pct)] 93 | elif dataset_name == 'uci': 94 | labels = list(range(6)) 95 | samples = [500, 200, int(200*skew1_pct), int(200*skew2_pct), int(200*skew3_pct), int(200*skew4_pct)] 96 | elif dataset_name == 'realworld': 97 | labels = list(range(8)) 98 | samples = [500, 100, int(100*skew1_pct), int(100*skew2_pct), int(100*skew3_pct), int(100*skew4_pct)] 99 | for i in range(num_users): 100 | user_labels = np.random.choice(labels, size=kept_class, replace=False) 101 | skew_labels = [i for i in labels if i not in user_labels] 102 | if len(skew_labels) > 6: # at most skew 6 classes of data 103 | skew_labels = np.random.choice(skew_labels, size=6, replace=False) 104 | train_indices = get_indices(train_labels, user_labels, n_samples=samples[0]) 105 | test_indices = get_indices(test_labels, user_labels, n_samples=samples[1]) 106 | 107 | skew1_indices = get_indices(test_labels, skew_labels, n_samples=samples[2]) 108 | skew2_indices = get_indices(test_labels, skew_labels, 
n_samples=samples[3]) 109 | skew3_indices = get_indices(test_labels, skew_labels, n_samples=samples[4]) 110 | skew4_indices = get_indices(test_labels, skew_labels, n_samples=samples[5]) 111 | 112 | train_users[i] = train_indices 113 | test_users[i] = test_indices 114 | skew_users1[i] = skew1_indices 115 | skew_users2[i] = skew2_indices 116 | skew_users3[i] = skew3_indices 117 | skew_users4[i] = skew4_indices 118 | return train_users, test_users, (skew_users1, skew_users2, skew_users3, skew_users4) 119 | 120 | 121 | def cifar_iid(dataset, num_users): 122 | """ 123 | Sample I.I.D. client data from CIFAR10 dataset 124 | :param dataset: 125 | :param num_users: 126 | :return: dict of image index 127 | """ 128 | num_items = int(len(dataset)/num_users) 129 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 130 | for i in range(num_users): 131 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 132 | all_idxs = list(set(all_idxs) - dict_users[i]) 133 | return dict_users 134 | 135 | 136 | if __name__ == '__main__': 137 | dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, 138 | transform=transforms.Compose([ 139 | transforms.ToTensor(), 140 | transforms.Normalize((0.1307,), (0.3081,)) 141 | ])) 142 | num = 100 143 | d = mnist_noniid(dataset_train, num) 144 | 145 | -------------------------------------------------------------------------------- /federated-learning/utils/DatasetStore.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | from utils.sampling import noniid_onepass 8 | from datasets.IMAGENET import IMAGENETDataset 9 | from datasets.REALWORLD import REALWORLDDataset 10 | from datasets.UCI import UCIDataset 11 | from utils.util import ColoredLogger 12 | 13 | logging.setLoggerClass(ColoredLogger) 14 | 
logging.getLogger("werkzeug").setLevel(logging.ERROR) 15 | logger = logging.getLogger("DatasetStore") 16 | 17 | 18 | class DatasetSplit(Dataset): 19 | def __init__(self, dataset, idxs): 20 | self.dataset = dataset 21 | self.idxs = list(idxs) 22 | self.targets = torch.Tensor([self.dataset.targets[idx] for idx in idxs]) 23 | 24 | def classes(self): 25 | return torch.unique(self.targets) 26 | 27 | def __len__(self): 28 | return len(self.idxs) 29 | 30 | def __getitem__(self, item): 31 | data, label = self.dataset[self.idxs[item]] 32 | return data, label 33 | 34 | 35 | class LocalDataset: 36 | def __init__(self): 37 | self.initialized = False 38 | self.dataset_name = "" 39 | self.dataset_train = None 40 | self.dataset_test = None 41 | self.image_shape = None 42 | self.dict_users = None 43 | self.test_users = None 44 | self.skew_users = None 45 | 46 | def init_local_dataset(self, dataset_name, num_users): 47 | dataset_train = None 48 | dataset_test = None 49 | real_path = os.path.dirname(os.path.realpath(__file__)) 50 | # load dataset and split users 51 | if dataset_name == 'mnist': 52 | trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 53 | data_path = os.path.join(real_path, "../../data/mnist/") 54 | dataset_train = datasets.MNIST(data_path, train=True, download=True, transform=trans) 55 | dataset_test = datasets.MNIST(data_path, train=False, download=True, transform=trans) 56 | 57 | elif dataset_name == 'fmnist': 58 | trans = transforms.Compose([transforms.ToTensor()]) 59 | data_path = os.path.join(real_path, "../../data/fashion-mnist/") 60 | dataset_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=trans) 61 | dataset_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=trans) 62 | 63 | elif dataset_name == 'cifar10': 64 | trans = transforms.Compose( 65 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 66 | data_path = 
os.path.join(real_path, "../../data/cifar10/") 67 | dataset_train = datasets.CIFAR10(data_path, train=True, download=True, transform=trans) 68 | dataset_test = datasets.CIFAR10(data_path, train=False, download=True, transform=trans) 69 | 70 | elif dataset_name == 'cifar100': 71 | trans = transforms.Compose( 72 | [transforms.ToTensor(), transforms.Normalize((0.5074, 0.4867, 0.4411), (0.2011, 0.1987, 0.2025))]) 73 | data_path = os.path.join(real_path, "../../data/cifar100/") 74 | dataset_train = datasets.CIFAR100(data_path, train=True, download=True, transform=trans) 75 | dataset_test = datasets.CIFAR100(data_path, train=False, download=True, transform=trans) 76 | 77 | elif dataset_name == 'imagenet': 78 | # https://towardsdatascience.com/pytorch-ignite-classifying-tiny-imagenet-with-efficientnet-e5b1768e5e8f#4195 79 | # https://github.com/kennethleungty/PyTorch-Ignite-Tiny-ImageNet-Classification 80 | # Retrieve data directly from Stanford data source 81 | # !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip 82 | trans = transforms.Compose( 83 | [transforms.Resize(64), transforms.CenterCrop(32), transforms.RandomHorizontalFlip(), 84 | transforms.ToTensor()]) 85 | data_path = os.path.join(real_path, "../../data/imagenet/") 86 | train_dir = os.path.join(data_path, 'train') 87 | # test_dir = os.path.join(data_path, 'test') 88 | dataset_full = datasets.ImageFolder(train_dir, transform=trans) 89 | train_size = int(0.8 * len(dataset_full)) 90 | test_size = len(dataset_full) - train_size 91 | subset_train, subset_test = torch.utils.data.random_split(dataset_full, [train_size, test_size]) 92 | dataset_train = IMAGENETDataset(subset_train) 93 | dataset_test = IMAGENETDataset(subset_test) 94 | 95 | elif dataset_name == 'uci': 96 | # https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones 97 | uci_data_path = os.path.join(real_path, "../../data/uci/") 98 | dataset_train = UCIDataset(data_path=uci_data_path, phase='train') 99 | dataset_test = 
UCIDataset(data_path=uci_data_path, phase='eval') 100 | 101 | elif dataset_name == 'realworld': 102 | # https://sensor.informatik.uni-mannheim.de/#dataset_realworld 103 | realworld_data_path = os.path.join(real_path, "../../data/realworld/") 104 | dataset_train = REALWORLDDataset(data_path=realworld_data_path, phase='train') 105 | dataset_test = REALWORLDDataset(data_path=realworld_data_path, phase='eval') 106 | 107 | dict_users, test_users, skew_users = noniid_onepass(dataset_train, dataset_test, num_users, 108 | dataset_name=dataset_name) 109 | 110 | self.dataset_name = dataset_name 111 | self.dataset_train = dataset_train 112 | self.dataset_test = dataset_test 113 | self.image_shape = dataset_train[0][0].shape 114 | self.dict_users = dict_users 115 | self.test_users = test_users 116 | self.skew_users = skew_users 117 | self.initialized = True 118 | 119 | def load_train_dataset(self, idx, local_bs): 120 | split_ds = DatasetSplit(self.dataset_train, self.dict_users[idx]) 121 | return DataLoader(split_ds, batch_size=local_bs, shuffle=True) 122 | 123 | def load_test_dataset(self, idxs, local_test_bs): 124 | split_ds = DatasetSplit(self.dataset_test, idxs) 125 | return DataLoader(split_ds, batch_size=local_test_bs) 126 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/cnn-imagenet-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [26.65, 27.64, 32.36, 33.21, 36.75, 39.1, 39.06, 38.02, 41.79, 44.48, 44.39, 43.96, 47.12, 49.15, 49.29, 49.39, 49.91, 51.89, 50.33, 51.23, 51.18, 52.41, 52.36, 53.82, 52.83, 52.78, 51.79, 51.27, 52.36, 52.88, 52.92, 51.98, 52.31, 51.42, 51.04, 51.51, 51.23, 50.38, 52.36, 50.99, 50.38, 50.57, 51.32, 51.27, 50.05, 51.51, 51.18, 50.66, 50.42, 49.86] 7 | scei_10 = [24.57, 25.48, 29.83, 30.61, 33.87, 36.04, 36.0, 35.04, 38.52, 41.0, 40.91, 
40.52, 43.43, 45.3, 45.43, 45.52, 46.0, 47.83, 46.39, 47.22, 47.17, 48.3, 48.26, 49.61, 48.7, 48.65, 47.74, 47.26, 48.26, 48.74, 48.78, 47.91, 48.22, 47.39, 47.04, 47.48, 47.22, 46.43, 48.26, 47.0, 46.43, 46.61, 47.3, 47.26, 46.13, 47.48, 47.22, 46.7, 46.52, 45.96] 8 | scei_15 = [23.35, 24.21, 28.35, 29.09, 32.19, 34.26, 34.21, 33.31, 36.61, 38.97, 38.88, 38.51, 41.28, 43.06, 43.18, 43.26, 43.72, 45.45, 44.09, 44.88, 44.83, 45.91, 45.87, 47.15, 46.28, 46.24, 45.37, 44.92, 45.87, 46.32, 46.36, 45.54, 45.83, 45.04, 44.71, 45.12, 44.88, 44.13, 45.87, 44.67, 44.13, 44.3, 44.96, 44.92, 43.84, 45.12, 44.83, 44.38, 44.17, 43.68] 9 | scei_20 = [21.73, 22.54, 26.38, 27.08, 29.96, 31.88, 31.85, 31.0, 34.08, 36.27, 36.19, 35.85, 38.42, 40.08, 40.19, 40.27, 40.69, 42.31, 41.04, 41.77, 41.73, 42.73, 42.69, 43.88, 43.08, 43.04, 42.23, 41.81, 42.69, 43.12, 43.15, 42.38, 42.65, 41.92, 41.62, 42.0, 41.77, 41.08, 42.69, 41.58, 41.08, 41.23, 41.85, 41.81, 40.81, 42.0, 41.73, 41.31, 41.12, 40.65] 10 | 11 | sceia_05 = [24.48, 24.91, 28.16, 34.39, 39.91, 39.06, 39.67, 42.69, 43.3, 44.25, 47.69, 48.07, 48.07, 48.87, 48.02, 50.38, 50.24, 48.68, 46.65, 48.68, 48.68, 48.96, 47.26, 48.11, 48.58, 47.36, 46.89, 47.74, 47.22, 47.22, 46.23, 45.71, 46.93, 46.75, 45.8, 45.94, 46.37, 44.01, 45.57, 44.1, 45.61, 45.14, 45.09, 46.56, 44.34, 43.54, 43.77, 45.19, 45.42, 45.61] 12 | sceia_10 = [22.57, 22.96, 25.96, 31.7, 36.78, 36.0, 36.57, 39.35, 39.91, 40.78, 43.96, 44.3, 44.3, 45.04, 44.26, 46.43, 46.3, 44.87, 43.0, 44.87, 44.87, 45.13, 43.57, 44.3, 44.78, 43.65, 43.22, 44.0, 43.52, 43.48, 42.57, 42.13, 43.26, 43.09, 42.17, 42.35, 42.7, 40.57, 42.0, 40.65, 42.04, 41.61, 41.57, 42.91, 40.87, 40.13, 40.3, 41.61, 41.87, 42.04] 13 | sceia_15 = [21.45, 21.82, 24.67, 30.12, 34.96, 34.21, 34.75, 37.4, 37.93, 38.76, 41.78, 42.11, 42.11, 42.81, 42.07, 44.13, 44.01, 42.64, 40.87, 42.64, 42.64, 42.89, 41.4, 42.11, 42.56, 41.49, 41.07, 41.82, 41.36, 41.32, 40.45, 40.04, 41.12, 40.95, 40.08, 40.25, 40.58, 38.55, 
39.92, 38.64, 40.0, 39.55, 39.5, 40.79, 38.84, 38.14, 38.31, 39.55, 39.79, 40.0] 14 | sceia_20 = [19.96, 20.31, 22.96, 28.04, 32.54, 31.85, 32.35, 34.81, 35.31, 36.08, 38.88, 39.19, 39.19, 39.85, 39.15, 41.08, 40.96, 39.69, 38.04, 39.69, 39.69, 39.92, 38.54, 39.19, 39.62, 38.65, 38.23, 38.92, 38.5, 38.46, 37.65, 37.27, 38.27, 38.12, 37.31, 37.46, 37.77, 35.88, 37.15, 35.96, 37.19, 36.81, 36.77, 37.96, 36.15, 35.5, 35.65, 36.81, 37.04, 37.19] 15 | 16 | apfl_05 = [24.58, 36.27, 44.15, 49.91, 52.22, 52.12, 52.41, 52.64, 52.08, 25.57, 39.25, 48.68, 52.17, 52.78, 50.52, 51.75, 49.53, 50.05, 50.0, 36.18, 48.82, 51.04, 50.71, 48.96, 48.35, 47.74, 48.68, 51.51, 50.09, 40.61, 48.82, 47.97, 45.66, 49.15, 47.74, 48.11, 48.02, 47.22, 48.68, 41.98, 47.83, 47.5, 47.12, 46.18, 48.4, 46.93, 48.07, 48.73, 47.69, 42.74] 17 | apfl_10 = [22.65, 33.43, 40.7, 46.0, 48.13, 48.04, 48.3, 48.52, 48.0, 23.57, 36.17, 44.87, 48.09, 48.65, 46.57, 47.7, 45.65, 46.13, 46.09, 33.35, 45.0, 47.04, 46.74, 45.13, 44.57, 44.0, 44.87, 47.48, 46.17, 37.43, 45.0, 44.22, 42.09, 45.3, 44.0, 44.35, 44.26, 43.52, 44.87, 38.7, 44.09, 43.78, 43.43, 42.57, 44.61, 43.26, 44.3, 44.91, 43.96, 39.39] 18 | apfl_15 = [21.53, 31.78, 38.68, 43.72, 45.74, 45.66, 45.91, 46.12, 45.62, 22.4, 34.38, 42.64, 45.7, 46.24, 44.26, 45.33, 43.39, 43.84, 43.8, 31.69, 42.77, 44.71, 44.42, 42.89, 42.36, 41.82, 42.64, 45.12, 43.88, 35.58, 42.77, 42.02, 40.0, 43.06, 41.82, 42.15, 42.07, 41.36, 42.64, 36.78, 41.9, 41.61, 41.28, 40.45, 42.4, 41.12, 42.11, 42.69, 41.78, 37.44] 19 | apfl_20 = [20.04, 29.58, 36.0, 40.69, 42.58, 42.5, 42.73, 42.92, 42.46, 20.85, 32.0, 39.69, 42.54, 43.04, 41.19, 42.19, 40.38, 40.81, 40.77, 29.5, 39.81, 41.62, 41.35, 39.92, 39.42, 38.92, 39.69, 42.0, 40.85, 33.12, 39.81, 39.12, 37.23, 40.08, 38.92, 39.23, 39.15, 38.5, 39.69, 34.23, 39.0, 38.73, 38.42, 37.65, 39.46, 38.27, 39.19, 39.73, 38.88, 34.85] 20 | 21 | fedavg_05 = [4.72, 4.95, 6.13, 6.65, 7.17, 7.03, 7.78, 7.78, 8.54, 9.25, 9.48, 10.33, 10.38, 9.95, 
9.95, 9.43, 10.33, 10.0, 9.53, 9.53, 10.85, 9.25, 10.57, 9.48, 9.67, 9.48, 10.09, 9.58, 9.72, 9.1, 9.58, 9.43, 9.06, 8.92, 8.92, 8.25, 9.06, 8.44, 8.87, 7.74, 9.39, 8.68, 7.97, 8.11, 8.35, 8.82, 8.77, 8.21, 8.3, 7.83] 22 | fedavg_10 = [4.35, 4.57, 5.7, 6.3, 6.78, 6.7, 7.39, 7.26, 7.87, 8.61, 8.91, 9.52, 9.74, 9.22, 9.22, 8.91, 9.65, 9.39, 8.87, 9.17, 10.35, 8.65, 10.0, 9.0, 9.22, 9.0, 9.57, 9.04, 9.26, 8.57, 9.13, 8.96, 8.43, 8.52, 8.39, 7.74, 8.57, 8.04, 8.52, 7.3, 8.91, 8.3, 7.48, 7.61, 7.83, 8.22, 8.22, 7.87, 7.61, 7.26] 23 | fedavg_15 = [4.13, 4.34, 5.54, 6.07, 6.53, 6.4, 7.07, 6.94, 7.52, 8.39, 8.68, 9.3, 9.3, 8.88, 8.84, 8.64, 9.38, 9.3, 8.47, 8.68, 10.12, 8.51, 9.79, 8.76, 9.01, 8.8, 9.3, 8.93, 8.84, 8.18, 8.64, 8.39, 8.1, 8.31, 8.14, 7.52, 8.35, 7.6, 8.18, 7.19, 8.55, 7.93, 7.23, 7.31, 7.64, 7.93, 7.85, 7.6, 7.48, 7.15] 24 | fedavg_20 = [3.85, 4.04, 5.15, 5.73, 6.19, 6.12, 6.77, 6.69, 7.38, 8.27, 8.31, 9.0, 8.92, 8.58, 8.58, 8.27, 8.88, 8.88, 8.31, 8.19, 9.54, 7.92, 9.27, 8.27, 8.42, 7.88, 8.77, 8.23, 8.27, 7.62, 8.08, 7.96, 7.77, 7.62, 7.62, 7.0, 7.58, 7.19, 7.54, 6.88, 8.08, 7.23, 6.85, 6.81, 7.15, 7.42, 7.54, 6.96, 7.12, 6.54] 25 | 26 | local_05 = [27.22, 33.3, 41.18, 46.89, 48.73, 50.24, 50.0, 52.03, 51.27, 49.34, 51.65, 51.18, 50.28, 50.75, 50.57, 48.87, 49.29, 47.31, 49.53, 50.24, 50.05, 49.25, 49.95, 49.39, 50.42, 50.09, 50.47, 48.44, 50.19, 49.81, 50.75, 50.71, 49.15, 50.9, 50.75, 51.13, 51.18, 50.38, 51.42, 51.18, 50.38, 50.66, 50.75, 48.96, 49.91, 50.75, 50.05, 50.71, 50.47, 50.52] 27 | local_10 = [25.09, 30.7, 37.96, 43.22, 44.91, 46.3, 46.09, 47.96, 47.26, 45.48, 47.61, 47.17, 46.35, 46.78, 46.61, 45.04, 45.43, 43.61, 45.65, 46.3, 46.13, 45.39, 46.04, 45.52, 46.48, 46.17, 46.52, 44.65, 46.26, 45.91, 46.78, 46.74, 45.3, 46.91, 46.78, 47.13, 47.17, 46.43, 47.39, 47.17, 46.43, 46.7, 46.78, 45.13, 46.0, 46.78, 46.13, 46.74, 46.52, 46.57] 28 | local_15 = [23.84, 29.17, 36.07, 41.07, 42.69, 44.01, 43.8, 45.58, 44.92, 43.22, 45.25, 44.83, 44.05, 44.46, 
44.3, 42.81, 43.18, 41.45, 43.39, 44.01, 43.84, 43.14, 43.76, 43.26, 44.17, 43.88, 44.21, 42.44, 43.97, 43.64, 44.46, 44.42, 43.06, 44.59, 44.46, 44.79, 44.83, 44.13, 45.04, 44.83, 44.13, 44.38, 44.46, 42.89, 43.72, 44.46, 43.84, 44.42, 44.21, 44.26] 29 | local_20 = [22.19, 27.15, 33.58, 38.23, 39.73, 40.96, 40.77, 42.42, 41.81, 40.23, 42.12, 41.73, 41.0, 41.38, 41.23, 39.85, 40.19, 38.58, 40.38, 40.96, 40.81, 40.15, 40.73, 40.27, 41.12, 40.85, 41.15, 39.5, 40.92, 40.62, 41.38, 41.35, 40.08, 41.5, 41.38, 41.69, 41.73, 41.08, 41.92, 41.73, 41.08, 41.31, 41.38, 39.92, 40.69, 41.38, 40.81, 41.35, 41.15, 41.19] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, 
plot_size="4") 57 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/cnn-cifar100-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [61.23, 61.89, 63.44, 63.21, 63.07, 60.66, 60.28, 62.64, 61.42, 63.3, 62.36, 62.55, 61.75, 62.69, 61.46, 61.89, 62.12, 60.75, 60.24, 60.38, 58.16, 61.18, 62.31, 60.33, 61.56, 60.52, 59.81, 62.83, 58.63, 58.77, 61.7, 61.27, 62.31, 59.34, 60.47, 58.92, 59.62, 58.58, 62.92, 61.93, 57.97, 61.37, 59.15, 60.24, 59.76, 59.15, 57.26, 61.23, 58.92, 60.0] 7 | scei_10 = [56.43, 57.04, 58.48, 58.26, 58.13, 55.91, 55.57, 57.74, 56.57, 58.35, 57.48, 57.65, 56.91, 57.83, 56.65, 57.04, 57.35, 56.09, 55.57, 55.7, 53.61, 56.35, 57.52, 55.57, 56.74, 55.83, 55.09, 58.0, 54.04, 54.22, 56.78, 56.48, 57.43, 54.7, 55.78, 54.35, 54.96, 54.04, 58.0, 57.13, 53.48, 56.57, 54.52, 55.48, 55.17, 54.57, 52.78, 56.57, 54.52, 55.48] 8 | scei_15 = [53.64, 54.21, 55.58, 55.37, 55.25, 53.14, 52.81, 54.88, 53.76, 55.45, 54.63, 54.83, 54.09, 54.92, 53.84, 54.21, 54.42, 53.26, 52.77, 52.89, 50.95, 53.55, 54.67, 52.81, 53.97, 53.1, 52.4, 55.12, 51.36, 51.61, 53.97, 53.76, 54.75, 51.94, 53.02, 51.69, 52.19, 51.45, 55.17, 54.38, 50.87, 53.88, 51.82, 52.85, 52.56, 51.94, 50.17, 53.84, 51.86, 52.64] 9 | scei_20 = [49.92, 50.46, 51.73, 51.54, 51.42, 49.46, 49.15, 51.08, 50.08, 51.65, 50.85, 51.04, 50.38, 51.15, 50.15, 50.46, 50.65, 49.58, 49.15, 49.23, 47.46, 49.88, 50.88, 49.19, 50.19, 49.5, 48.81, 51.38, 47.96, 48.12, 50.27, 50.12, 50.96, 48.46, 49.58, 48.23, 48.65, 47.96, 51.46, 50.69, 47.42, 50.23, 48.31, 49.46, 48.88, 48.58, 47.04, 50.23, 48.42, 49.27] 10 | 11 | sceia_05 = [53.92, 55.09, 57.22, 59.67, 59.2, 60.9, 59.34, 59.48, 59.15, 59.48, 58.68, 58.63, 57.92, 57.55, 57.83, 57.69, 57.41, 59.25, 59.48, 58.11, 59.01, 58.63, 59.25, 58.87, 57.88, 57.74, 57.69, 57.59, 58.44, 
57.64, 57.59, 58.11, 56.79, 57.92, 54.48, 54.72, 56.7, 51.42, 48.07, 47.17, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7] 12 | sceia_10 = [49.7, 50.78, 52.74, 55.0, 54.57, 56.13, 54.7, 54.83, 54.52, 54.83, 54.09, 54.04, 53.39, 53.09, 53.3, 53.17, 52.91, 54.61, 54.83, 53.57, 54.35, 54.04, 54.61, 54.26, 53.35, 53.22, 53.17, 53.09, 53.87, 53.13, 53.09, 53.61, 52.35, 53.43, 50.22, 50.43, 52.26, 47.39, 44.3, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52, 43.52] 13 | sceia_15 = [47.23, 48.26, 50.12, 52.27, 51.86, 53.35, 51.98, 52.11, 51.82, 52.11, 51.4, 51.36, 50.74, 50.41, 50.66, 50.54, 50.29, 51.9, 52.11, 50.91, 51.74, 51.36, 51.94, 51.61, 50.7, 50.58, 50.54, 50.45, 51.2, 50.5, 50.45, 50.91, 49.75, 50.74, 47.73, 47.93, 49.67, 45.12, 42.19, 41.32, 36.53, 36.53, 36.53, 36.53, 36.53, 36.53, 36.53, 36.53, 36.53, 36.53] 14 | sceia_20 = [43.96, 44.92, 46.65, 48.65, 48.27, 49.65, 48.38, 48.5, 48.23, 48.54, 47.85, 47.81, 47.23, 46.92, 47.15, 47.04, 46.81, 48.31, 48.54, 47.38, 48.12, 47.81, 48.35, 48.08, 47.19, 47.08, 47.04, 46.96, 47.69, 47.0, 46.96, 47.42, 46.31, 47.23, 44.42, 44.65, 46.23, 42.04, 39.19, 38.54, 34.08, 34.08, 34.08, 34.08, 34.08, 34.08, 34.08, 34.08, 34.08, 34.08] 15 | 16 | apfl_05 = [56.89, 64.86, 62.74, 62.78, 64.76, 63.4, 63.02, 63.87, 63.25, 30.8, 54.48, 62.22, 60.38, 61.79, 60.85, 58.54, 60.05, 59.06, 55.33, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06, 59.06] 17 | apfl_10 = [52.43, 59.78, 57.83, 57.87, 59.7, 58.43, 58.09, 58.87, 58.3, 28.39, 50.22, 57.35, 55.65, 56.96, 56.09, 53.96, 55.35, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43, 54.43] 18 | apfl_15 = [49.83, 56.82, 54.96, 
55.0, 56.74, 55.54, 55.21, 55.95, 55.41, 26.98, 47.73, 54.5, 52.89, 54.13, 53.31, 51.28, 52.6, 51.74, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47, 48.47] 19 | apfl_20 = [46.38, 52.88, 51.15, 51.19, 52.81, 51.69, 51.38, 52.08, 51.58, 25.12, 44.42, 50.73, 49.23, 50.38, 49.62, 47.73, 48.96, 48.15, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12, 45.12] 20 | 21 | fedavg_05 = [8.77, 16.46, 17.03, 20.42, 21.6, 22.26, 22.83, 23.35, 24.67, 24.58, 23.77, 25.14, 25.0, 24.48, 26.08, 24.39, 24.58, 25.47, 24.58, 25.8, 24.2, 24.58, 25.24, 24.91, 25.0, 25.9, 24.58, 27.26, 20.9, 26.65, 24.86, 27.41, 26.18, 25.57, 26.04, 27.03, 27.08, 27.45, 13.49, 22.64, 24.48, 19.62, 25.05, 26.46, 2.55, 2.55, 2.55, 2.55, 2.55, 2.55] 22 | fedavg_10 = [8.09, 15.22, 15.78, 18.96, 20.09, 20.74, 21.39, 21.74, 23.04, 23.0, 22.39, 23.48, 23.3, 23.04, 24.3, 22.78, 23.04, 23.83, 22.83, 24.17, 22.48, 22.91, 23.48, 23.22, 23.3, 24.22, 23.04, 25.52, 19.48, 24.91, 23.22, 25.57, 24.48, 23.83, 24.17, 25.04, 25.04, 25.57, 12.65, 21.0, 22.74, 18.35, 23.48, 24.65, 2.61, 2.61, 2.61, 2.61, 2.61, 2.61] 23 | fedavg_15 = [7.69, 14.55, 15.12, 18.18, 19.21, 19.79, 20.37, 20.66, 21.98, 21.74, 21.12, 22.19, 22.07, 21.53, 22.98, 21.45, 21.69, 22.36, 21.9, 22.89, 21.49, 21.74, 22.23, 21.82, 22.31, 22.98, 22.02, 24.21, 18.39, 23.35, 21.86, 24.09, 23.06, 22.56, 22.98, 23.88, 23.93, 24.21, 12.27, 20.04, 21.61, 17.31, 22.15, 23.51, 2.64, 2.64, 2.64, 2.64, 2.64, 2.64] 24 | fedavg_20 = [7.27, 13.54, 14.04, 16.92, 17.92, 18.5, 19.04, 19.38, 20.5, 20.38, 19.88, 20.81, 20.85, 20.42, 21.69, 20.12, 20.62, 21.19, 20.54, 21.96, 20.31, 20.54, 21.04, 20.73, 20.88, 21.62, 20.81, 22.88, 17.69, 22.31, 20.77, 
22.85, 21.81, 21.42, 21.73, 22.54, 22.5, 22.58, 11.19, 19.12, 20.27, 16.54, 21.08, 22.27, 2.69, 2.69, 2.69, 2.69, 2.69, 2.69] 25 | 26 | local_05 = [55.85, 59.58, 58.63, 59.1, 58.82, 60.38, 60.19, 57.22, 58.3, 58.02, 58.68, 59.76, 55.05, 56.79, 56.6, 56.7, 57.69, 59.25, 56.89, 57.92, 59.29, 57.74, 58.4, 58.58, 59.01, 59.1, 56.56, 52.5, 52.45, 52.36, 52.36, 52.45, 52.45, 52.41, 52.41, 52.45, 52.5, 52.45, 52.5, 52.5, 52.55, 52.45, 52.55, 52.5, 52.5, 52.5, 52.5, 52.5, 52.45, 52.45] 27 | local_10 = [51.48, 54.91, 54.04, 54.48, 54.22, 55.65, 55.48, 52.74, 53.74, 53.48, 54.09, 55.09, 50.74, 52.35, 52.17, 52.26, 53.17, 54.61, 52.43, 53.39, 54.65, 53.22, 53.83, 54.0, 54.39, 54.48, 52.13, 48.52, 48.48, 48.39, 48.39, 48.48, 48.48, 48.43, 48.43, 48.48, 48.52, 48.48, 48.52, 48.52, 48.57, 48.48, 48.57, 48.52, 48.52, 48.52, 48.52, 48.52, 48.48, 48.48] 28 | local_15 = [48.93, 52.19, 51.36, 51.78, 51.53, 52.89, 52.73, 50.12, 51.07, 50.83, 51.4, 52.36, 48.22, 49.75, 49.59, 49.67, 50.54, 51.9, 49.83, 50.74, 51.94, 50.58, 51.16, 51.32, 51.69, 51.78, 49.55, 46.2, 46.16, 46.07, 46.07, 46.16, 46.16, 46.12, 46.12, 46.16, 46.2, 46.16, 46.2, 46.2, 46.24, 46.16, 46.24, 46.2, 46.2, 46.2, 46.2, 46.2, 46.16, 46.16] 29 | local_20 = [45.54, 48.58, 47.81, 48.19, 47.96, 49.23, 49.08, 46.65, 47.54, 47.31, 47.85, 48.73, 44.88, 46.31, 46.15, 46.23, 47.04, 48.31, 46.38, 47.23, 48.35, 47.08, 47.62, 47.77, 48.12, 48.19, 46.12, 43.12, 43.08, 43.0, 43.0, 43.08, 43.08, 43.04, 43.04, 43.08, 43.12, 43.08, 43.12, 43.12, 43.15, 43.08, 43.15, 43.12, 43.12, 43.12, 43.12, 43.12, 43.08, 43.08] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), 
np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, plot_size="4") 57 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/mlp-mnist-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [88.58, 89.43, 90.14, 90.47, 91.42, 91.65, 92.08, 92.08, 92.5, 92.64, 92.64, 92.92, 92.97, 93.25, 93.25, 93.63, 93.63, 93.11, 93.63, 93.4, 93.58, 93.73, 93.92, 94.15, 94.06, 93.96, 93.92, 94.43, 94.15, 94.01, 93.92, 94.39, 94.2, 94.34, 94.06, 94.01, 94.2, 94.48, 94.39, 94.58, 94.53, 94.34, 94.53, 94.34, 94.43, 94.25, 94.86, 94.86, 94.43, 94.72] 7 | scei_10 = [81.65, 82.78, 83.87, 84.57, 85.61, 86.22, 87.43, 87.61, 88.17, 88.26, 88.74, 89.13, 89.3, 89.48, 89.78, 90.09, 90.43, 90.17, 90.35, 90.52, 90.52, 90.7, 90.74, 91.17, 90.96, 91.09, 90.96, 91.57, 91.52, 91.39, 91.22, 91.78, 92.04, 92.0, 91.57, 91.61, 92.13, 92.39, 92.13, 92.04, 92.35, 92.17, 92.39, 92.22, 92.17, 91.96, 92.48, 92.57, 91.96, 92.43] 
8 | scei_15 = [77.6, 78.43, 79.88, 80.62, 82.31, 83.06, 84.5, 84.63, 85.41, 85.7, 86.36, 87.27, 87.36, 87.73, 88.14, 88.35, 88.76, 88.43, 89.17, 88.68, 88.93, 89.71, 89.67, 89.83, 89.59, 90.12, 89.83, 90.29, 90.04, 89.96, 90.04, 90.41, 90.33, 90.66, 90.54, 90.37, 90.66, 90.87, 90.95, 90.99, 91.07, 90.74, 90.95, 91.12, 90.87, 90.66, 91.65, 91.53, 90.87, 91.36] 9 | scei_20 = [72.23, 73.38, 75.35, 76.54, 78.46, 79.73, 81.12, 81.77, 83.0, 83.12, 83.69, 84.38, 85.12, 85.19, 85.46, 85.88, 85.81, 86.23, 86.62, 86.19, 86.54, 87.12, 87.73, 87.85, 87.46, 87.77, 87.38, 88.31, 88.31, 87.96, 88.12, 88.31, 88.23, 88.92, 88.81, 88.92, 88.85, 89.38, 89.42, 89.19, 89.42, 89.04, 89.73, 89.5, 89.5, 89.46, 89.58, 89.81, 89.12, 90.08] 10 | 11 | sceia_05 = [87.5, 89.1, 87.64, 89.58, 88.54, 90.8, 91.18, 91.56, 91.23, 92.45, 91.6, 92.45, 92.36, 92.45, 93.11, 92.59, 93.3, 93.35, 92.5, 93.07, 93.68, 93.4, 93.49, 93.82, 93.3, 93.54, 93.02, 94.1, 93.54, 94.15, 94.1, 94.2, 93.63, 93.96, 93.96, 93.49, 93.4, 94.25, 94.1, 94.06, 94.01, 94.25, 94.1, 94.25, 94.29, 94.15, 94.25, 94.2, 93.92, 94.2] 12 | sceia_10 = [80.65, 82.61, 81.96, 83.96, 83.48, 85.74, 86.57, 86.83, 87.26, 88.78, 87.87, 89.04, 88.65, 88.74, 89.65, 89.87, 90.43, 90.35, 89.83, 90.35, 91.09, 90.65, 90.74, 91.13, 90.7, 91.17, 90.78, 91.87, 91.74, 92.04, 92.0, 91.7, 91.43, 91.83, 91.57, 91.22, 91.35, 92.26, 91.96, 91.96, 91.96, 92.3, 92.0, 92.09, 92.13, 91.91, 92.3, 92.52, 92.17, 92.35] 13 | sceia_15 = [76.65, 78.76, 78.18, 80.45, 80.21, 82.36, 83.39, 84.21, 84.75, 86.32, 85.5, 86.32, 86.32, 86.65, 87.27, 87.23, 87.73, 87.93, 87.44, 87.85, 88.39, 88.43, 88.47, 89.17, 88.47, 89.01, 88.51, 90.04, 89.5, 89.67, 89.71, 89.17, 89.46, 89.79, 89.92, 89.96, 89.88, 90.29, 90.45, 90.41, 90.12, 90.66, 90.21, 90.5, 90.87, 90.66, 91.03, 91.2, 90.7, 91.03] 14 | sceia_20 = [71.35, 73.73, 73.5, 76.08, 76.46, 78.69, 79.81, 80.65, 81.42, 83.5, 83.31, 84.0, 84.0, 84.23, 84.46, 84.5, 85.27, 85.73, 84.58, 86.15, 86.58, 86.08, 86.31, 87.27, 86.73, 87.38, 
86.81, 88.19, 87.96, 87.81, 88.38, 87.85, 87.62, 88.46, 87.77, 88.38, 88.42, 88.62, 88.5, 88.77, 88.62, 88.96, 89.0, 89.27, 89.35, 89.12, 89.38, 89.31, 89.04, 89.5] 15 | 16 | apfl_05 = [86.65, 87.74, 87.36, 87.74, 88.25, 87.97, 88.07, 88.16, 88.07, 85.42, 88.77, 88.54, 88.54, 88.87, 89.15, 88.87, 89.1, 88.63, 88.82, 87.03, 89.43, 89.34, 89.01, 89.29, 89.62, 89.29, 89.34, 89.2, 89.01, 88.73, 90.33, 90.14, 89.72, 89.91, 90.24, 89.86, 89.72, 89.58, 89.62, 89.34, 90.85, 90.85, 90.71, 90.19, 90.38, 90.33, 90.33, 89.95, 90.28, 90.42] 17 | apfl_10 = [79.87, 80.87, 80.52, 80.87, 81.35, 81.09, 81.17, 81.26, 81.17, 79.57, 81.96, 81.65, 81.61, 81.96, 82.17, 81.96, 82.13, 81.7, 81.87, 82.48, 82.52, 82.61, 82.22, 82.43, 82.78, 82.52, 82.48, 82.3, 82.09, 84.87, 84.13, 83.39, 83.0, 83.09, 83.26, 82.74, 82.65, 82.65, 82.7, 85.87, 84.96, 84.39, 84.04, 83.52, 83.57, 83.48, 83.48, 83.22, 83.48, 87.13] 18 | apfl_15 = [75.91, 76.86, 76.53, 76.86, 77.31, 77.07, 77.15, 77.23, 77.15, 76.4, 78.02, 77.64, 77.56, 77.93, 78.22, 77.89, 78.1, 77.69, 77.81, 80.17, 79.21, 78.72, 78.22, 78.43, 78.64, 78.39, 78.39, 78.26, 78.06, 82.93, 80.91, 80.21, 79.26, 79.17, 79.38, 78.8, 78.68, 78.68, 78.76, 84.26, 81.9, 81.2, 80.83, 80.21, 79.96, 79.75, 79.83, 79.5, 79.67, 85.66] 19 | apfl_20 = [70.65, 71.54, 71.23, 71.54, 71.96, 71.73, 71.81, 71.88, 71.81, 71.69, 72.65, 72.31, 72.31, 72.62, 72.81, 72.5, 72.69, 72.35, 72.5, 77.15, 74.23, 73.58, 73.19, 73.27, 73.38, 73.19, 73.15, 72.96, 72.77, 80.04, 76.46, 75.38, 74.46, 74.38, 74.65, 74.0, 73.85, 73.65, 73.73, 81.65, 78.0, 76.69, 76.27, 75.42, 75.0, 74.69, 74.58, 74.23, 74.5, 83.19] 20 | 21 | fedavg_05 = [69.43, 71.79, 76.37, 78.35, 79.91, 81.98, 83.16, 84.29, 85.0, 85.33, 86.08, 87.41, 87.69, 87.78, 89.01, 89.62, 88.92, 89.91, 90.33, 90.0, 90.33, 90.24, 90.28, 90.14, 90.42, 90.42, 90.38, 91.18, 91.04, 91.27, 91.27, 91.04, 91.18, 90.99, 91.18, 90.99, 91.7, 91.56, 91.46, 91.42, 91.6, 91.93, 91.56, 91.79, 91.65, 91.51, 91.7, 91.89, 91.84, 91.75] 22 | fedavg_10 
= [67.61, 69.83, 74.48, 76.7, 78.48, 80.57, 81.48, 82.83, 83.7, 84.04, 85.04, 86.43, 86.74, 87.0, 88.26, 88.7, 88.09, 89.0, 89.52, 89.26, 89.57, 89.39, 89.87, 89.43, 89.83, 89.78, 89.83, 90.61, 90.26, 90.78, 90.78, 90.43, 90.57, 90.35, 90.7, 90.43, 91.26, 91.13, 90.87, 90.83, 91.04, 91.39, 91.04, 91.35, 91.17, 91.04, 91.17, 91.48, 91.35, 91.26] 23 | fedavg_15 = [66.16, 68.8, 73.51, 75.74, 77.4, 79.63, 81.12, 82.19, 83.02, 83.55, 84.5, 85.83, 86.28, 86.74, 87.89, 88.72, 87.81, 89.13, 89.38, 88.93, 89.46, 89.21, 89.13, 89.38, 89.63, 89.71, 89.71, 90.54, 90.29, 90.62, 90.45, 90.25, 90.45, 90.25, 90.41, 90.41, 91.07, 90.91, 90.7, 90.54, 90.7, 91.2, 90.79, 90.99, 90.99, 90.87, 90.95, 91.24, 91.16, 91.12] 24 | fedavg_20 = [64.73, 67.46, 72.31, 74.31, 76.04, 78.65, 80.31, 81.46, 82.46, 82.58, 83.73, 85.35, 85.62, 85.73, 87.19, 88.04, 87.23, 88.5, 88.88, 88.46, 88.77, 88.73, 88.85, 88.96, 89.23, 89.35, 89.42, 90.19, 89.92, 90.31, 90.27, 90.04, 90.27, 90.04, 90.12, 90.08, 90.73, 90.54, 90.46, 90.27, 90.46, 90.92, 90.58, 90.69, 90.85, 90.69, 90.77, 90.88, 91.0, 90.96] 25 | 26 | local_05 = [89.39, 90.14, 89.86, 90.05, 89.91, 90.0, 90.42, 90.24, 90.05, 90.05, 90.28, 90.38, 90.14, 90.28, 90.28, 90.38, 90.33, 90.33, 90.19, 90.14, 90.19, 90.24, 90.38, 90.42, 90.33, 90.47, 90.33, 90.14, 90.24, 90.28, 90.33, 90.24, 90.09, 90.09, 90.14, 90.14, 90.28, 90.28, 90.47, 90.19, 90.42, 90.24, 90.33, 90.47, 90.42, 90.33, 90.38, 90.42, 90.47, 90.47] 27 | local_10 = [82.39, 83.09, 82.83, 83.0, 82.87, 82.96, 83.35, 83.17, 83.0, 83.0, 83.22, 83.3, 83.09, 83.22, 83.22, 83.3, 83.26, 83.26, 83.13, 83.09, 83.13, 83.17, 83.3, 83.35, 83.26, 83.39, 83.26, 83.09, 83.17, 83.22, 83.26, 83.17, 83.04, 83.04, 83.09, 83.09, 83.22, 83.22, 83.39, 83.13, 83.35, 83.17, 83.26, 83.39, 83.35, 83.26, 83.3, 83.35, 83.39, 83.39] 28 | local_15 = [78.31, 78.97, 78.72, 78.88, 78.76, 78.84, 79.21, 79.05, 78.88, 78.88, 79.09, 79.17, 78.97, 79.09, 79.09, 79.17, 79.13, 79.13, 79.01, 78.97, 79.01, 79.05, 79.17, 79.21, 79.13, 
79.26, 79.13, 78.97, 79.05, 79.09, 79.13, 79.05, 78.93, 78.93, 78.97, 78.97, 79.09, 79.09, 79.26, 79.01, 79.21, 79.05, 79.13, 79.26, 79.21, 79.13, 79.17, 79.21, 79.26, 79.26] 29 | local_20 = [72.88, 73.5, 73.27, 73.42, 73.31, 73.38, 73.73, 73.58, 73.42, 73.42, 73.62, 73.69, 73.5, 73.62, 73.62, 73.69, 73.65, 73.65, 73.54, 73.5, 73.54, 73.58, 73.69, 73.73, 73.65, 73.77, 73.65, 73.5, 73.58, 73.62, 73.65, 73.58, 73.46, 73.46, 73.5, 73.5, 73.62, 73.62, 73.77, 73.54, 73.73, 73.58, 73.65, 73.77, 73.73, 73.65, 73.69, 73.73, 73.77, 73.77] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, plot_size="4") 57 | 
-------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/cnn-cifar10-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [44.53, 52.55, 54.1, 53.96, 57.78, 57.5, 58.44, 57.12, 59.25, 59.72, 58.07, 57.88, 57.64, 58.63, 59.06, 59.29, 59.67, 59.81, 58.44, 58.25, 60.38, 60.75, 59.62, 58.35, 60.05, 60.47, 60.28, 59.06, 59.62, 58.21, 58.11, 59.15, 59.01, 58.82, 58.44, 59.06, 59.34, 58.87, 58.54, 58.54, 59.2, 60.75, 58.96, 60.05, 59.43, 58.68, 59.01, 60.05, 60.14, 59.1] 7 | scei_10 = [41.04, 48.43, 49.87, 49.7, 53.22, 52.96, 53.87, 52.65, 54.61, 55.0, 53.57, 53.26, 53.0, 54.09, 54.35, 54.61, 54.87, 55.22, 53.61, 53.78, 55.7, 55.87, 54.96, 54.04, 55.26, 55.78, 55.52, 54.74, 55.04, 53.83, 53.74, 54.78, 54.39, 54.17, 53.96, 54.48, 54.74, 54.48, 54.3, 53.74, 54.91, 56.22, 54.65, 55.43, 55.0, 54.22, 54.78, 55.61, 55.87, 55.04] 8 | scei_15 = [39.01, 46.03, 47.4, 47.23, 50.58, 50.33, 51.2, 50.08, 51.9, 52.27, 50.87, 50.74, 50.5, 51.28, 51.78, 52.02, 52.27, 52.4, 51.45, 51.24, 53.1, 53.39, 52.48, 51.36, 52.85, 53.14, 53.02, 52.02, 52.4, 51.2, 51.4, 52.19, 51.86, 51.53, 51.36, 51.82, 52.27, 51.94, 51.53, 51.69, 52.52, 53.72, 52.23, 53.14, 52.48, 52.02, 52.15, 53.26, 53.43, 52.85] 9 | scei_20 = [36.31, 42.85, 44.12, 44.0, 47.12, 46.85, 47.65, 46.65, 48.35, 48.81, 47.46, 47.42, 47.04, 47.92, 48.31, 48.58, 48.88, 48.96, 47.81, 47.81, 49.69, 49.96, 49.12, 48.04, 49.42, 49.62, 49.46, 48.73, 49.04, 48.12, 48.23, 49.08, 48.69, 48.5, 48.38, 49.04, 49.23, 48.96, 48.38, 48.27, 49.12, 50.62, 49.38, 49.88, 49.46, 48.81, 48.96, 49.92, 50.42, 49.85] 10 | 11 | sceia_05 = [48.02, 54.29, 55.19, 56.75, 56.32, 56.42, 57.22, 56.37, 53.77, 57.74, 56.79, 58.35, 57.08, 57.88, 57.69, 56.65, 57.41, 57.55, 57.03, 58.11, 55.19, 57.03, 57.26, 56.7, 57.55, 58.21, 57.45, 56.13, 57.12, 56.6, 57.97, 57.92, 57.97, 
56.98, 56.18, 57.59, 58.35, 57.31, 56.89, 57.55, 56.98, 57.08, 57.12, 57.17, 56.93, 56.7, 55.33, 55.24, 56.18, 55.57] 12 | sceia_10 = [44.26, 50.04, 50.87, 52.3, 51.96, 52.04, 52.83, 52.09, 49.83, 53.57, 52.74, 54.09, 52.83, 53.52, 53.7, 52.57, 53.35, 53.48, 53.0, 54.13, 51.61, 53.22, 53.39, 53.0, 53.87, 54.74, 53.96, 52.43, 53.78, 53.17, 54.22, 54.61, 54.43, 53.48, 53.13, 54.39, 54.87, 53.96, 53.65, 54.13, 53.74, 53.48, 53.83, 53.96, 53.83, 53.52, 52.26, 52.3, 53.57, 52.87] 13 | sceia_15 = [42.07, 47.56, 48.35, 49.71, 49.38, 49.38, 50.25, 49.42, 47.11, 50.7, 49.75, 51.24, 49.96, 50.95, 50.83, 49.83, 50.45, 50.91, 50.33, 50.95, 48.84, 50.58, 50.7, 50.29, 51.2, 52.19, 51.12, 49.88, 51.24, 50.33, 51.61, 51.74, 51.94, 50.79, 50.33, 51.94, 52.23, 51.61, 50.87, 51.65, 51.53, 50.74, 51.36, 51.78, 51.36, 51.07, 49.88, 50.12, 51.28, 50.58] 14 | sceia_20 = [39.15, 44.27, 45.0, 46.31, 46.04, 46.27, 46.96, 46.42, 44.35, 47.92, 47.0, 48.58, 47.12, 48.0, 48.27, 47.27, 48.04, 48.19, 48.0, 48.46, 46.42, 47.81, 48.15, 47.77, 49.04, 49.08, 48.69, 47.65, 48.92, 47.62, 49.35, 49.58, 49.73, 49.0, 48.23, 49.58, 50.46, 49.23, 48.81, 49.54, 49.15, 48.73, 49.65, 49.81, 49.58, 49.35, 47.77, 48.35, 49.15, 49.08] 15 | 16 | apfl_05 = [46.27, 54.15, 55.71, 54.2, 55.28, 56.23, 56.42, 56.42, 56.6, 31.6, 53.73, 55.42, 55.33, 56.56, 56.37, 56.18, 56.32, 56.51, 56.46, 51.89, 56.98, 55.9, 56.51, 56.04, 55.71, 55.9, 55.94, 55.75, 55.85, 52.78, 54.91, 54.53, 56.37, 55.94, 55.75, 55.75, 55.85, 55.8, 56.04, 53.11, 55.33, 57.08, 57.22, 57.45, 57.45, 57.41, 57.55, 57.5, 57.59, 53.68] 17 | apfl_10 = [42.65, 49.91, 51.35, 49.96, 50.96, 51.83, 52.0, 52.0, 52.17, 29.13, 49.52, 51.09, 51.0, 52.13, 51.96, 51.78, 51.91, 52.09, 52.04, 48.04, 52.52, 51.52, 52.09, 51.65, 51.35, 51.52, 51.57, 51.39, 51.48, 49.78, 50.61, 50.26, 51.96, 51.57, 51.39, 51.39, 51.48, 51.43, 51.65, 50.0, 51.0, 52.61, 52.74, 52.96, 52.96, 52.91, 53.04, 53.0, 53.09, 50.57] 18 | apfl_15 = [40.54, 47.44, 48.8, 47.48, 48.43, 49.26, 49.42, 
49.42, 49.59, 27.69, 47.07, 48.55, 48.47, 49.55, 49.38, 49.21, 49.34, 49.5, 49.46, 45.99, 49.92, 48.97, 49.5, 49.09, 48.8, 48.97, 49.01, 48.84, 48.93, 47.89, 48.14, 47.77, 49.38, 49.01, 48.84, 48.84, 48.93, 48.88, 49.09, 48.47, 48.55, 50.04, 50.12, 50.33, 50.33, 50.29, 50.41, 50.37, 50.45, 49.5] 19 | apfl_20 = [37.73, 44.15, 45.42, 44.19, 45.08, 45.85, 46.0, 46.0, 46.15, 25.77, 43.81, 45.19, 45.12, 46.12, 45.96, 45.81, 45.92, 46.08, 46.04, 42.69, 46.46, 45.58, 46.08, 45.69, 45.42, 45.58, 45.62, 45.46, 45.54, 44.88, 44.77, 44.46, 45.96, 45.62, 45.46, 45.46, 45.54, 45.5, 45.69, 45.35, 45.12, 46.54, 46.65, 46.85, 46.85, 46.81, 46.92, 46.88, 46.96, 46.65] 20 | 21 | fedavg_05 = [17.31, 21.23, 31.04, 36.32, 39.39, 41.23, 41.32, 40.47, 40.14, 39.58, 40.47, 39.95, 40.71, 40.99, 40.14, 40.0, 40.52, 40.14, 40.71, 40.42, 41.46, 40.14, 41.51, 41.32, 42.08, 41.6, 41.56, 42.31, 42.12, 41.27, 40.42, 41.7, 41.98, 42.83, 42.55, 42.78, 42.31, 42.88, 42.88, 40.66, 41.93, 42.08, 41.6, 41.75, 41.37, 41.46, 41.75, 42.5, 41.75, 42.22] 22 | fedavg_10 = [16.87, 20.48, 29.91, 35.48, 38.87, 40.7, 41.22, 40.65, 39.83, 39.35, 40.3, 39.74, 40.13, 40.65, 40.3, 40.04, 40.17, 39.83, 40.13, 39.61, 40.43, 39.78, 41.17, 41.04, 41.48, 40.96, 41.0, 41.7, 41.57, 40.61, 40.22, 41.26, 41.48, 42.09, 41.78, 42.22, 41.65, 42.43, 42.22, 40.52, 41.61, 41.61, 41.3, 41.61, 41.39, 41.43, 41.52, 42.61, 41.87, 42.13] 23 | fedavg_15 = [16.49, 20.58, 30.21, 35.29, 38.1, 40.21, 40.83, 39.96, 39.79, 38.72, 39.42, 39.13, 40.0, 40.37, 39.83, 39.01, 39.92, 39.55, 40.08, 39.5, 40.54, 39.21, 40.54, 40.5, 41.12, 40.66, 40.37, 41.24, 40.99, 40.66, 39.71, 40.95, 40.87, 41.61, 41.4, 41.82, 41.28, 42.07, 41.94, 39.88, 40.99, 41.12, 40.95, 41.16, 40.99, 40.62, 40.99, 41.57, 41.03, 41.32] 24 | fedavg_20 = [16.46, 20.12, 30.42, 35.58, 38.85, 40.69, 40.73, 39.96, 39.92, 38.92, 40.27, 39.5, 40.12, 40.42, 39.65, 39.62, 39.73, 39.38, 40.08, 39.46, 40.38, 39.46, 40.46, 40.54, 40.81, 40.19, 40.54, 41.27, 41.31, 40.04, 39.5, 40.88, 41.38, 
41.73, 41.69, 41.96, 41.35, 41.69, 41.69, 39.92, 41.27, 41.23, 40.92, 41.23, 40.65, 40.85, 40.5, 41.5, 40.58, 41.12] 25 | 26 | local_05 = [52.22, 56.13, 56.89, 56.27, 55.61, 58.07, 59.06, 59.29, 59.39, 59.34, 59.43, 59.53, 59.58, 59.58, 59.43, 59.58, 59.62, 59.62, 59.58, 59.53, 59.53, 59.58, 59.58, 59.58, 59.58, 59.62, 59.58, 59.58, 59.58, 59.58, 59.58, 59.58, 59.58, 59.58, 59.58, 59.62, 59.62, 59.62, 59.62, 59.62, 59.62, 59.62, 59.62, 59.62, 59.67, 59.67, 59.67, 59.67, 59.72, 59.72] 27 | local_10 = [48.13, 51.74, 52.43, 51.87, 51.26, 53.52, 54.43, 54.65, 54.74, 54.7, 54.78, 54.87, 54.91, 54.91, 54.78, 54.91, 54.96, 54.96, 54.91, 54.87, 54.87, 54.91, 54.91, 54.91, 54.91, 54.96, 54.91, 54.91, 54.91, 54.91, 54.91, 54.91, 54.91, 54.91, 54.91, 54.96, 54.96, 54.96, 54.96, 54.96, 54.96, 54.96, 54.96, 54.96, 55.0, 55.0, 55.0, 55.0, 55.04, 55.04] 28 | local_15 = [45.74, 49.17, 49.83, 49.3, 48.72, 50.87, 51.74, 51.94, 52.02, 51.98, 52.07, 52.15, 52.19, 52.19, 52.07, 52.19, 52.23, 52.23, 52.19, 52.15, 52.15, 52.19, 52.19, 52.19, 52.19, 52.23, 52.19, 52.19, 52.19, 52.19, 52.19, 52.19, 52.19, 52.19, 52.19, 52.23, 52.23, 52.23, 52.23, 52.23, 52.23, 52.23, 52.23, 52.23, 52.27, 52.27, 52.27, 52.27, 52.31, 52.31] 29 | local_20 = [42.58, 45.77, 46.38, 45.88, 45.35, 47.35, 48.15, 48.35, 48.42, 48.38, 48.46, 48.54, 48.58, 48.58, 48.46, 48.58, 48.62, 48.62, 48.58, 48.54, 48.54, 48.58, 48.58, 48.58, 48.58, 48.62, 48.58, 48.58, 48.58, 48.58, 48.58, 48.58, 48.58, 48.58, 48.58, 48.62, 48.62, 48.62, 48.62, 48.62, 48.62, 48.62, 48.62, 48.62, 48.65, 48.65, 48.65, 48.65, 48.69, 48.69] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), 
np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, True, save_path, plot_size="4") 57 | -------------------------------------------------------------------------------- /federated-learning/plot/acc-skew/resnet-cifar10-acc-skew.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from plot.utils import plot_skew 5 | 6 | scei_05 = [34.43, 41.98, 63.21, 49.53, 60.38, 66.04, 72.17, 70.28, 67.92, 64.15, 66.51, 63.21, 75.94, 75.47, 73.11, 75.0, 75.94, 78.3, 75.94, 76.89, 75.94, 74.53, 73.58, 73.58, 74.53, 75.94, 75.94, 72.64, 74.53, 71.23, 72.64, 65.09, 75.47, 73.58, 74.06, 73.58, 74.06, 74.06, 75.0, 73.11, 72.64, 73.58, 74.53, 74.53, 73.58, 73.58, 73.58, 73.58, 74.06, 75.0] 7 | scei_10 = [31.74, 38.7, 58.26, 45.65, 55.65, 60.87, 66.52, 64.78, 62.61, 59.13, 61.3, 58.26, 70.0, 69.57, 67.39, 69.13, 70.0, 72.17, 70.0, 70.87, 70.0, 68.7, 67.83, 67.83, 68.7, 70.0, 70.0, 66.96, 68.7, 65.65, 66.96, 60.0, 69.57, 67.83, 68.26, 67.83, 68.26, 68.26, 69.13, 67.39, 66.96, 67.83, 68.7, 68.7, 67.83, 67.83, 67.83, 67.83, 
68.26, 69.13] 8 | scei_15 = [30.17, 36.78, 55.37, 43.39, 52.89, 57.85, 63.22, 61.57, 59.5, 56.2, 58.26, 55.37, 66.53, 66.12, 64.05, 65.7, 66.53, 68.6, 66.53, 67.36, 66.53, 65.29, 64.46, 64.46, 65.29, 66.53, 66.53, 63.64, 65.29, 62.4, 63.64, 57.02, 66.12, 64.46, 64.88, 64.46, 64.88, 64.88, 65.7, 64.05, 63.64, 64.46, 65.29, 65.29, 64.46, 64.46, 64.46, 64.46, 64.88, 65.7] 9 | scei_20 = [28.08, 34.23, 51.54, 40.38, 49.23, 53.85, 58.85, 57.31, 55.38, 52.31, 54.23, 51.54, 61.92, 61.54, 59.62, 61.15, 61.92, 63.85, 61.92, 62.69, 61.92, 60.77, 60.0, 60.0, 60.77, 61.92, 61.92, 59.23, 60.77, 58.08, 59.23, 53.08, 61.54, 60.0, 60.38, 60.0, 60.38, 60.38, 61.15, 59.62, 59.23, 60.0, 60.77, 60.77, 60.0, 60.0, 60.0, 60.0, 60.38, 61.15] 10 | 11 | sceia_05 = [24.53, 35.85, 41.04, 29.72, 50.47, 58.96, 56.13, 45.75, 63.21, 55.19, 71.7, 64.15, 62.74, 60.85, 67.92, 71.7, 69.81, 72.64, 69.81, 76.89, 74.53, 79.25, 78.3, 81.13, 73.58, 81.13, 80.66, 76.89, 77.83, 78.77, 77.83, 78.77, 77.83, 77.83, 79.72, 77.83, 79.25, 78.77, 79.72, 77.36, 80.66, 77.83, 79.25, 77.83, 78.77, 81.13, 79.72, 79.72, 78.3, 78.77] 12 | sceia_10 = [22.61, 33.04, 37.83, 27.39, 46.52, 54.35, 51.74, 42.17, 58.26, 50.87, 66.09, 59.13, 57.83, 56.09, 62.61, 66.09, 64.35, 66.96, 64.35, 70.87, 68.7, 73.04, 72.17, 74.78, 67.83, 74.78, 74.35, 70.87, 71.74, 72.61, 71.74, 72.61, 71.74, 71.74, 73.48, 71.74, 73.04, 72.61, 73.48, 71.3, 74.35, 71.74, 73.04, 71.74, 72.61, 74.78, 73.48, 73.48, 72.17, 72.61] 13 | sceia_15 = [21.49, 31.4, 35.95, 26.03, 44.21, 51.65, 49.17, 40.08, 55.37, 48.35, 62.81, 56.2, 54.96, 53.31, 59.5, 62.81, 61.16, 63.64, 61.16, 67.36, 65.29, 69.42, 68.6, 71.07, 64.46, 71.07, 70.66, 67.36, 68.18, 69.01, 68.18, 69.01, 68.18, 68.18, 69.83, 68.18, 69.42, 69.01, 69.83, 67.77, 70.66, 68.18, 69.42, 68.18, 69.01, 71.07, 69.83, 69.83, 68.6, 69.01] 14 | sceia_20 = [20.0, 29.23, 33.46, 24.23, 41.15, 48.08, 45.77, 37.31, 51.54, 45.0, 58.46, 52.31, 51.15, 49.62, 55.38, 58.46, 56.92, 59.23, 56.92, 62.69, 60.77, 64.62, 63.85, 
66.15, 60.0, 66.15, 65.77, 62.69, 63.46, 64.23, 63.46, 64.23, 63.46, 63.46, 65.0, 63.46, 64.62, 64.23, 65.0, 63.08, 65.77, 63.46, 64.62, 63.46, 64.23, 66.15, 65.0, 65.0, 63.85, 64.23] 15 | 16 | apfl_05 = [47.17, 50.0, 51.89, 51.89, 55.66, 53.77, 58.02, 59.91, 54.72, 59.43, 64.15, 61.79, 60.85, 61.32, 58.02, 59.91, 57.08, 58.96, 59.43, 59.91, 60.38, 60.85, 62.26, 61.79, 58.96, 59.43, 61.79, 64.15, 65.09, 61.79, 62.74, 62.74, 62.74, 63.68, 63.68, 61.79, 62.74, 63.68, 62.74, 62.74, 60.85, 61.32, 61.32, 62.26, 61.79, 61.32, 63.21, 59.43, 61.79, 60.85] 17 | apfl_10 = [43.48, 46.09, 47.83, 47.83, 51.3, 49.57, 53.48, 55.22, 50.43, 54.78, 59.13, 56.96, 56.09, 56.52, 53.48, 55.22, 52.61, 54.35, 54.78, 55.22, 55.65, 56.09, 57.39, 56.96, 54.35, 54.78, 56.96, 59.13, 60.0, 56.96, 57.83, 57.83, 57.83, 58.7, 58.7, 56.96, 57.83, 58.7, 57.83, 57.83, 56.09, 56.52, 56.52, 57.39, 56.96, 56.52, 58.26, 54.78, 56.96, 56.09] 18 | apfl_15 = [41.32, 43.8, 45.45, 45.45, 48.76, 47.11, 50.83, 52.48, 47.93, 52.07, 56.2, 54.13, 53.31, 53.72, 50.83, 52.48, 50.0, 51.65, 52.07, 52.48, 52.89, 53.31, 54.55, 54.13, 51.65, 52.07, 54.13, 56.2, 57.02, 54.13, 54.96, 54.96, 54.96, 55.79, 55.79, 54.13, 54.96, 55.79, 54.96, 54.96, 53.31, 53.72, 53.72, 54.55, 54.13, 53.72, 55.37, 52.07, 54.13, 53.31] 19 | apfl_20 = [38.46, 40.77, 42.31, 42.31, 45.38, 43.85, 47.31, 48.85, 44.62, 48.46, 52.31, 50.38, 49.62, 50.0, 47.31, 48.85, 46.54, 48.08, 48.46, 48.85, 49.23, 49.62, 50.77, 50.38, 48.08, 48.46, 50.38, 52.31, 53.08, 50.38, 51.15, 51.15, 51.15, 51.92, 51.92, 50.38, 51.15, 51.92, 51.15, 51.15, 49.62, 50.0, 50.0, 50.77, 50.38, 50.0, 51.54, 48.46, 50.38, 49.62] 20 | 21 | fedavg_05 = [27.36, 36.32, 41.51, 50.47, 50.0, 43.87, 49.06, 58.49, 58.49, 53.77, 58.02, 55.19, 58.49, 49.53, 64.15, 55.19, 63.21, 65.57, 62.74, 62.26, 59.91, 62.74, 61.79, 56.13, 61.32, 60.38, 61.32, 61.79, 61.32, 61.79, 61.79, 61.32, 62.26, 60.38, 61.32, 58.96, 58.02, 62.26, 62.74, 66.51, 65.57, 65.57, 66.04, 63.68, 65.09, 64.62, 65.57, 64.62, 
62.26, 60.85] 22 | fedavg_10 = [25.22, 33.48, 38.26, 46.52, 46.09, 40.43, 45.22, 53.91, 53.91, 49.57, 53.48, 50.87, 53.91, 45.65, 59.13, 50.87, 58.26, 60.43, 57.83, 57.39, 55.22, 57.83, 56.96, 51.74, 56.52, 55.65, 56.52, 56.96, 56.52, 56.96, 56.96, 56.52, 57.39, 55.65, 56.52, 54.35, 53.48, 57.39, 57.83, 61.3, 60.43, 60.43, 60.87, 58.7, 60.0, 59.57, 60.43, 59.57, 57.39, 56.09] 23 | fedavg_15 = [23.97, 31.82, 36.36, 44.21, 43.8, 38.43, 42.98, 51.24, 51.24, 47.11, 50.83, 48.35, 51.24, 43.39, 56.2, 48.35, 55.37, 57.44, 54.96, 54.55, 52.48, 54.96, 54.13, 49.17, 53.72, 52.89, 53.72, 54.13, 53.72, 54.13, 54.13, 53.72, 54.55, 52.89, 53.72, 51.65, 50.83, 54.55, 54.96, 58.26, 57.44, 57.44, 57.85, 55.79, 57.02, 56.61, 57.44, 56.61, 54.55, 53.31] 24 | fedavg_20 = [22.31, 29.62, 33.85, 41.15, 40.77, 35.77, 40.0, 47.69, 47.69, 43.85, 47.31, 45.0, 47.69, 40.38, 52.31, 45.0, 51.54, 53.46, 51.15, 50.77, 48.85, 51.15, 50.38, 45.77, 50.0, 49.23, 50.0, 50.38, 50.0, 50.38, 50.38, 50.0, 50.77, 49.23, 50.0, 48.08, 47.31, 50.77, 51.15, 54.23, 53.46, 53.46, 53.85, 51.92, 53.08, 52.69, 53.46, 52.69, 50.77, 49.62] 25 | 26 | local_05 = [33.49, 45.75, 48.11, 41.98, 51.89, 54.25, 57.08, 56.6, 56.13, 57.55, 59.91, 55.66, 58.96, 60.38, 57.08, 59.43, 59.91, 55.66, 58.02, 53.77, 58.49, 58.96, 62.26, 60.85, 63.68, 61.79, 62.74, 62.26, 63.21, 62.74, 63.21, 66.98, 60.38, 63.21, 62.74, 65.09, 64.62, 59.43, 61.79, 62.74, 62.26, 59.91, 63.68, 62.74, 61.79, 61.32, 65.09, 64.62, 64.62, 64.62] 27 | local_10 = [30.87, 42.17, 44.35, 38.7, 47.83, 50.0, 52.61, 52.17, 51.74, 53.04, 55.22, 51.3, 54.35, 55.65, 52.61, 54.78, 55.22, 51.3, 53.48, 49.57, 53.91, 54.35, 57.39, 56.09, 58.7, 56.96, 57.83, 57.39, 58.26, 57.83, 58.26, 61.74, 55.65, 58.26, 57.83, 60.0, 59.57, 54.78, 56.96, 57.83, 57.39, 55.22, 58.7, 57.83, 56.96, 56.52, 60.0, 59.57, 59.57, 59.57] 28 | local_15 = [29.34, 40.08, 42.15, 36.78, 45.45, 47.52, 50.0, 49.59, 49.17, 50.41, 52.48, 48.76, 51.65, 52.89, 50.0, 52.07, 52.48, 48.76, 50.83, 47.11, 51.24, 
51.65, 54.55, 53.31, 55.79, 54.13, 54.96, 54.55, 55.37, 54.96, 55.37, 58.68, 52.89, 55.37, 54.96, 57.02, 56.61, 52.07, 54.13, 54.96, 54.55, 52.48, 55.79, 54.96, 54.13, 53.72, 57.02, 56.61, 56.61, 56.61] 29 | local_20 = [27.31, 37.31, 39.23, 34.23, 42.31, 44.23, 46.54, 46.15, 45.77, 46.92, 48.85, 45.38, 48.08, 49.23, 46.54, 48.46, 48.85, 45.38, 47.31, 43.85, 47.69, 48.08, 50.77, 49.62, 51.92, 50.38, 51.15, 50.77, 51.54, 51.15, 51.54, 54.62, 49.23, 51.54, 51.15, 53.08, 52.69, 48.46, 50.38, 51.15, 50.77, 48.85, 51.92, 51.15, 50.38, 50.0, 53.08, 52.69, 52.69, 52.69] 30 | 31 | scei_y = [np.mean(scei_05), np.mean(scei_10), np.mean(scei_15), np.mean(scei_20), ] 32 | scei_err = [np.std(scei_05), np.std(scei_10), np.std(scei_15), np.std(scei_20), ] 33 | 34 | sceia_y = [np.mean(sceia_05), np.mean(sceia_10), np.mean(sceia_15), np.mean(sceia_20), ] 35 | sceia_err = [np.std(sceia_05), np.std(sceia_10), np.std(sceia_15), np.std(sceia_20), ] 36 | 37 | apfl_y = [np.mean(apfl_05), np.mean(apfl_10), np.mean(apfl_15), np.mean(apfl_20), ] 38 | apfl_err = [np.std(apfl_05), np.std(apfl_10), np.std(apfl_15), np.std(apfl_20), ] 39 | 40 | fedavg_y = [np.mean(fedavg_05), np.mean(fedavg_10), np.mean(fedavg_15), np.mean(fedavg_20), ] 41 | fedavg_err = [np.std(fedavg_05), np.std(fedavg_10), np.std(fedavg_15), np.std(fedavg_20), ] 42 | 43 | local_y = [np.mean(local_05), np.mean(local_10), np.mean(local_15), np.mean(local_20), ] 44 | local_err = [np.std(local_05), np.std(local_10), np.std(local_15), np.std(local_20), ] 45 | 46 | data = {'scei_y': scei_y, 'scei_err': scei_err, 47 | 'apfl_y': apfl_y, 'apfl_err': apfl_err, 48 | 'fedavg_y': fedavg_y, 'fedavg_err': fedavg_err, 49 | 'local_y': local_y, 'local_err': local_err, 50 | 'sceia_y': sceia_y, 'sceia_err': sceia_err, } 51 | 52 | save_path = None 53 | if len(sys.argv) == 3 and sys.argv[1] and sys.argv[1] == "save": 54 | save_path = sys.argv[2] 55 | 56 | plot_skew("", data, False, False, save_path, plot_size="4") 57 | 
--------------------------------------------------------------------------------