├── figs
│   └── centaur
│       ├── _overview.png
│       ├── _results_1.png
│       └── _results_2.png
├── frameworks
│   └── centaur
│       ├── hiefed
│       │   ├── data_selection_test.py
│       │   ├── data_selection.py
│       │   └── fed_criterion.py
│       ├── plot
│       │   ├── 5_unbalanced.py
│       │   ├── 5_unbalanced_s.py
│       │   ├── 6_client_num.py
│       │   ├── 1_acc.py
│       │   ├── 2_classifier.py
│       │   ├── 9_prox_table.py
│       │   ├── 4_cost_linkage.py
│       │   ├── 2_classifier_s.py
│       │   ├── 8_data_selection.py
│       │   ├── 0_sensor_acc.py
│       │   ├── 0_intro.py
│       │   ├── 0_intro_s.py
│       │   ├── 7_mobility.py
│       │   ├── 3_cost_acc.py
│       │   ├── plot_utils.py
│       │   ├── plot_balanced_sn.py
│       │   └── plot.py.orig
│       ├── measures
│       │   ├── coverage.py
│       │   └── mac_comm_counter.py
│       ├── third_party
│       │   ├── autograd_hacks_test.py
│       │   ├── autograd_hacks.py
│       │   └── dataset_partition_flwr.py
│       ├── client_configurations.csv
│       ├── oam.py
│       ├── run_all.sh
│       ├── requirements.txt
│       ├── main.py
│       ├── make_script.py
│       └── running_args.py
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/figs/centaur/_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/data-centric-federated-learning/HEAD/figs/centaur/_overview.png
--------------------------------------------------------------------------------
/figs/centaur/_results_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/data-centric-federated-learning/HEAD/figs/centaur/_results_1.png
--------------------------------------------------------------------------------
/figs/centaur/_results_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/data-centric-federated-learning/HEAD/figs/centaur/_results_2.png
--------------------------------------------------------------------------------
/frameworks/centaur/client_configurations.csv:
--------------------------------------------------------------------------------
1 | type,frequency(MHz),acceleration,memory(MB),storage(MB),power(mWpMHz),battery_energy(J),uplink(Mbit/s),downlink(Mbit/s),energy_comm(W),timeout(s),coverage_interval(h),coverage_prob
2 | iot,100,1,inf,5,0.05,inf,2,2,0.0001,inf,1,0.5
3 | ap,2000,1,inf,4096,1.5,inf,10,100,10,inf,1,1
--------------------------------------------------------------------------------
/frameworks/centaur/hiefed/data_selection_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from data_selection import CDFSelection
3 | 
4 | # Test on data selection module
5 | 
6 | test_size = 10
7 | losses = np.random.uniform(low=0, high=2, size=(test_size,))
8 | 
9 | skip_steps = test_size - 1
10 | alpha = 5
11 | beta = 2
12 | gamma = 1
13 | 
14 | # select data based on loss with build_queue
15 | queue_size = int(test_size/2)
16 | loss_selec1 = CDFSelection(alpha, beta, max_len=queue_size)
17 | step = 1
18 | for loss in losses:
19 |     loss_selec1.maintain_queue(loss)
20 |     if step < skip_steps:
21 |         step += 1
22 |     else:
23 |         prob_drop, prob_ap = loss_selec1.query_prob(loss)
24 | 
25 | print("=======================")
26 | 
27 | # select data based on loss with full_queue
28 | loss_selec2 = CDFSelection(alpha, beta, full_queue=losses)
29 | idx = 0
30 | for loss in losses:
31 |     prob_drop, prob_ap = loss_selec2.query_prob(loss)
32 |     #loss_selec2.update_state(0.5, idx)
33 |     idx += 1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The Clear BSD License
2 | 
3 | Copyright (c) 2024, Nokia Bell Labs
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
7 | 
8 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 | * Neither the name of Nokia Bell Labs nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 | 
12 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 | 
14 | 
--------------------------------------------------------------------------------
/frameworks/centaur/oam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os  # needed to create the client_profiles directory
3 | import numpy as np
4 | from typing import Tuple, List
5 | 
6 | np.random.seed(42)
7 | 
8 | # the client starts moving from the home location and ends the journey at the home location
9 | def move_client(location_connectivity_probability: np.ndarray, temporal_granularity: int, home_loc: int) -> Tuple[List, List]:
10 |     locs = []
11 |     connectivity = []
12 |     locs.append(home_loc)
13 | 
14 |     connectivity.append(location_connectivity_probability[home_loc])
15 |     for i in range(temporal_granularity-2):
16 |         val = np.random.randint(len(location_connectivity_probability))
17 |         locs.append(val)
18 |         connectivity.append(location_connectivity_probability[val])
19 |     locs.append(home_loc)
20 |     connectivity.append(location_connectivity_probability[home_loc])
21 |     return locs, connectivity
22 | 
23 | 
24 | def create_client_profile(num_clients=100, temporal='day', spatial=10) -> None:
25 |     print("Creating client mobility profiles")
26 |     locs_connect_proba = np.random.rand(spatial)  # keeping this global across all clients
27 | 
28 |     # this is the home location with maximum connectivity available
29 |     max_connect_loc = np.argmax(locs_connect_proba)
30 |     time = None
31 |     if temporal == "months":
32 |         time = 12
33 |     elif temporal == "weeks":
34 |         time = 7
35 |     elif temporal == "day":
36 |         time = 24
37 |     else:
38 |         raise Exception("Unsupported temporal granularity!")
39 | 
40 |     os.makedirs("client_profiles", exist_ok=True)  # fix: the directory may not exist yet
41 |     for i in range(num_clients):
42 |         out_file = open("client_profiles/" + str(i) + ".csv", "w+")
43 |         out_file.write("locs,probability\n")
44 |         locs, connectivity = move_client(locs_connect_proba, time, max_connect_loc)
45 |         for l, c in zip(locs, connectivity):
46 |             out_file.write(str(l) + "," + str(c) + "\n")
47 |         out_file.close()
--------------------------------------------------------------------------------
/frameworks/centaur/hiefed/data_selection.py:
--------------------------------------------------------------------------------
1 | # data selection module
2 | class CDFSelection:
3 |     def __init__(self, alpha, beta, max_len=None, full_queue=None):
4 |         self.alpha = alpha
5 |         self.beta = beta
6 |         # for loss, both alpha and beta should be larger than 1, in order to
7 |         # provide the middle interval for norm measurement later.
8 | 
9 |         if full_queue is not None:
10 |             self.queue = full_queue
11 |             self.max_len = len(full_queue)
12 |         else:
13 |             self.queue = []
14 |             self.max_len = max_len
15 | 
16 |     # build the queue and keep its size by popping out the oldest instance
17 |     def maintain_queue(self, loss):
18 |         if len(self.queue) == self.max_len:
19 |             self.queue.pop(0)
20 |         self.queue.append(loss)
21 | 
22 |     # update the state of the queue
23 |     def update_state(self, loss, index):
24 |         self.queue[index] = loss
25 | 
26 |     # query the probability to drop the sample or to send it to the AP
27 |     def query_prob(self, loss):
28 |         index = sorted(self.queue).index(loss) + 1
29 | 
30 |         cdf_left = float(index)/self.max_len
31 |         cdf_right = 1 - cdf_left
32 | 
33 |         # dropping only happens when self.alpha != 0
34 |         if self.alpha == 0:
35 |             prob_drop = 0
36 |         else:
37 |             prob_drop = cdf_right**self.alpha
38 | 
39 |         # sending to the AP only happens when self.beta != 0
40 |         if self.beta == 0:
41 |             prob_ap = 0
42 |         else:
43 |             prob_ap = cdf_left**self.beta
44 | 
45 |         # print(self.queue)
46 |         # print("cdf_left={}, prob_low(drop)={}".format(cdf_left, prob_drop))
47 |         # print("cdf_right={}, prob_high(ap)={}".format(cdf_right, prob_ap))
48 |         # print("-----")
49 | 
50 |         return prob_drop, prob_ap
51 | 
52 | 
53 | # TODO: data selection based on the built distribution
--------------------------------------------------------------------------------
/frameworks/centaur/run_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2,3,5,6,7 python main.py --destine=ap --dataset=cifar10 --encoder=squeezenet
2 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar10 --encoder=efficientnet --classifier=small 2>&1 | tee log/dyn_dcifar10_eefficientnet_csmall.txt
3 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=ap --dataset=cifar10 --encoder=efficientnet --classifier=medium 2>&1 | tee log/ap_dcifar10_eefficientnet_cmedium.txt
4 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar10 --encoder=efficientnet --classifier=medium 2>&1 | tee log/dyn_dcifar10_eefficientnet_cmedium.txt
5 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=ap --dataset=cifar10 --encoder=efficientnet --classifier=large 2>&1 | tee log/ap_dcifar10_eefficientnet_clarge.txt
6 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar10 --encoder=efficientnet --classifier=large 2>&1 | tee log/dyn_dcifar10_eefficientnet_clarge.txt
7 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=ap --dataset=cifar100 --encoder=efficientnet --classifier=small 2>&1 | tee log/ap_dcifar100_eefficientnet_csmall.txt
8 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar100 --encoder=efficientnet --classifier=small 2>&1 | tee log/dyn_dcifar100_eefficientnet_csmall.txt
9 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=ap --dataset=cifar100 --encoder=efficientnet --classifier=medium 2>&1 | tee log/ap_dcifar100_eefficientnet_cmedium.txt
10 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar100 --encoder=efficientnet --classifier=medium 2>&1 | tee log/dyn_dcifar100_eefficientnet_cmedium.txt
11 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=ap --dataset=cifar100 --encoder=efficientnet --classifier=large 2>&1 | tee log/ap_dcifar100_eefficientnet_clarge.txt
12 | CUDA_VISIBLE_DEVICES=1,2,3,5,6,7 python main.py --save_results=True --destine=dyn --dataset=cifar100 --encoder=efficientnet --classifier=large 2>&1 | tee log/dyn_dcifar100_eefficientnet_clarge.txt
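A usage sketch for the `CDFSelection` policy in `hiefed/data_selection.py` above (it mirrors `data_selection_test.py`; acting on the returned probabilities with uniform draws is our illustration of how a caller might use them, not code from the framework). Under this policy, low-loss samples tend to be dropped on the constrained device and high-loss samples tend to be forwarded to the AP:

```python
import numpy as np
from data_selection import CDFSelection  # run from frameworks/centaur/hiefed

rng = np.random.default_rng(0)
selector = CDFSelection(alpha=5, beta=3, max_len=50)
dropped, to_ap, kept = 0, 0, 0

for loss in rng.uniform(low=0, high=2, size=200):
    selector.maintain_queue(loss)              # slide the window of recent losses
    if len(selector.queue) < selector.max_len:
        continue                               # warm the window up first
    prob_drop, prob_ap = selector.query_prob(loss)
    if rng.uniform() < prob_drop:
        dropped += 1    # low-loss sample: little left to learn from it
    elif rng.uniform() < prob_ap:
        to_ap += 1      # high-loss sample: worth full training at the AP
    else:
        kept += 1       # middling sample: stays on the constrained device

print(f"dropped={dropped}, to_ap={to_ap}, kept={kept}")
```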
-------------------------------------------------------------------------------- /frameworks/centaur/plot/5_unbalanced.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import os 6 | from os import listdir 7 | from os.path import isfile, join 8 | 9 | from numpy import loadtxt 10 | 11 | import pandas as pd 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.ticker import FuncFormatter 15 | 16 | from plot_utils import running_avg 17 | 18 | plt.rcParams.update({'font.size': 8}) 19 | 20 | dest = ['ap','iot','dyn'] 21 | dest_str = ['AP Training','UCD Training','Centaur'] 22 | flpa = ['0.001', '0.1', '1.0', '1000.0']#, '10.0', '100.0', ] '0.01', 23 | 24 | 25 | clas_hist = ['500,0,0,0,0,0,0,0,0,0', '251,230,13,4,2,0,0,0,0,0', '125,89,87,47,44,39,34,29,3,3', '57,53,53,52,51,50,49,48,47,40'] 26 | clas_hist = ['Hist: ' + c for c in clas_hist] 27 | 28 | 29 | 30 | rounds = list(range(100)) 31 | 32 | all_dat = [] 33 | 34 | for f in flpa: 35 | dat = [] 36 | for d in dest: 37 | filedir = 'results/' + 'acc_' + d + '_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa' + f + '.csv' 38 | dat_x = pd.read_csv(filedir, sep=',') 39 | dat_x = dat_x['accuracy'].tolist()[1:] 40 | dat_x = running_avg(dat_x) 41 | dat.append(dat_x) 42 | 43 | all_dat.append(dat) 44 | 45 | # import pdb; pdb.set_trace() 46 | 47 | # plot 48 | figs, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(5,4)) 49 | plt.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.1) 50 | 51 | idx = 0 52 | for row in axes: 53 | for col in row: 54 | dat = all_dat[idx] 55 | col.plot(rounds, dat[0], '-', color='b', label=dest_str[0], linewidth=1) 56 | col.plot(rounds, dat[1], '-', color='y', label=dest_str[1], linewidth=1) 57 | col.plot(rounds, dat[2], '-', color='r',label=dest_str[2], linewidth=1) 58 | col.annotate(clas_hist[idx], xy=(0, 0.9), xycoords='data',size=6) #, textcoords='offset points' 59 | 60 | col.set_yticks(np.arange(0.1, 1, 0.1)) 61 | col.set_ylim(0.1, 0.95) 62 | col.set_title('LDA-Alpha: ' + flpa[idx], loc='left') 63 | 64 | if idx == 2 or idx == 3: col.set_xlabel('Rounds') 65 | if idx == 0 or idx == 2: col.set_ylabel('Test Accuracy') 66 | idx += 1 67 | 68 | col.legend() 69 | 70 | plt.savefig('imgs/unbalanced.pdf') -------------------------------------------------------------------------------- /frameworks/centaur/plot/5_unbalanced_s.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | sys.path.append(".") 5 | 6 | 7 | import os 8 | from os import listdir 9 | from os.path import isfile, join 10 | 11 | from numpy import loadtxt 12 | 13 | import pandas as pd 14 | 15 | import matplotlib.pyplot as plt 16 | from matplotlib.ticker import FuncFormatter 17 | 18 | from plot_utils import running_avg 19 | 20 | plt.rcParams.update({'font.size': 8}) 21 | 22 | dest = ['ap','iot','dyn'] 23 | dest_str = ['AP Training','UCD Training','Centaur'] 24 | flpa = ['0.001', '0.1', '1.0', '1000.0']#, '10.0', '100.0', ] '0.01', 25 | 26 | 27 | clas_hist = ['500,0,0,0,0,0,0,0,0,0', '251,230,13,4,2,0,0,0,0,0', '125,89,87,47,44,39,34,29,3,3', '57,53,53,52,51,50,49,48,47,40'] 28 | clas_hist = ['Hist: ' + c for c in clas_hist] 29 | 30 | 31 | 32 | rounds = list(range(100)) 33 | 34 | all_dat = [] 35 | 36 | for f in flpa: 37 | dat = [] 38 | for d in dest: 39 | filedir = 'results/' + 'acc_' + d + '_1uci_har_2senc_3medium_c8|20_e3_a5_b3_g0_flpa' + 
f + '_mr0|0.csv'
40 |         dat_x = pd.read_csv(filedir, sep=',')
41 |         dat_x = dat_x['accuracy'].tolist()[1:]
42 |         dat_x = running_avg(dat_x)
43 |         dat.append(dat_x)
44 | 
45 |     all_dat.append(dat)
46 | 
47 | # import pdb; pdb.set_trace()
48 | 
49 | # plot
50 | figs, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(5,4))
51 | plt.subplots_adjust(left=0.1, right=0.98, top=0.95, bottom=0.1)
52 | 
53 | idx = 0
54 | for row in axes:
55 |     for col in row:
56 |         dat = all_dat[idx]
57 |         col.plot(rounds, dat[0], '-', color='b', label=dest_str[0], linewidth=1)
58 |         col.plot(rounds, dat[1], '-', color='y', label=dest_str[1], linewidth=1)
59 |         col.plot(rounds, dat[2], '-', color='r', label=dest_str[2], linewidth=1)
60 |         # col.annotate(clas_hist[idx], xy=(0, 0.9), xycoords='data', size=6)  #, textcoords='offset points'
61 | 
62 |         col.set_yticks(np.arange(0.1, 1.01, 0.1))
63 |         col.set_ylim(0.1, 1.01)
64 |         col.set_title('LDA-Alpha: ' + flpa[idx], loc='left')
65 | 
66 |         if idx == 2 or idx == 3: col.set_xlabel('Rounds')
67 |         if idx == 0 or idx == 2: col.set_ylabel('Test Accuracy')
68 |         idx += 1
69 | 
70 | col.legend()
71 | 
72 | plt.savefig('imgs/unbalanced.pdf')
--------------------------------------------------------------------------------
/frameworks/centaur/measures/coverage.py:
--------------------------------------------------------------------------------
1 | import random
2 | import pandas as pd
3 | 
4 | def get_online_state(res, cid, server_round, mrate_min, mrate_max):
5 |     '''This function gets the AP online time and the IoT online time'''
6 | 
7 |     if mrate_min == 0 and mrate_max == 0:
8 |         iot_online_prob = float(res.loc[res['type'] == 'iot']['coverage_prob'])
9 |     else:
10 |         # fix: the original assert wrapped its arguments in parentheses, making a
11 |         # non-empty tuple that is always truthy, so the check never fired
12 |         assert mrate_max > mrate_min, 'Conflict: Defined mobility rate Max is smaller than Min!'
13 |         # loading mobility for the client cid
14 |         cprofile = pd.read_csv("client_profiles/" + cid + ".csv")
15 | 
16 |         # scaling
17 |         prob_list = cprofile["probability"].tolist()
18 |         A = (mrate_max-mrate_min) / (max(prob_list)-min(prob_list))
19 |         B = mrate_min - A*min(prob_list)
20 |         prob_scaled = [p*A+B for p in prob_list]
21 |         print(f"connectivity_probability: {prob_scaled}")
22 | 
23 |         iot_online_prob = prob_scaled[server_round%len(cprofile)]  # each server round corresponds to 1 time unit of the client profile file
24 | 
25 |     ap_online_prob = float(res.loc[res['type'] == 'ap']['coverage_prob'])
26 |     ap_offline_interval = float(res.loc[res['type'] == 'ap']['coverage_interval(h)'])
27 |     iot_offline_interval = float(res.loc[res['type'] == 'iot']['coverage_interval(h)'])
28 | 
29 |     ap_offline_prob = 1 - ap_online_prob
30 |     iot_offline_prob = 1 - iot_online_prob
31 | 
32 |     ap_offline = True
33 |     ap_time = 0 - ap_offline_interval
34 |     while ap_offline:
35 |         if ap_offline_prob < random.uniform(0, 1):
36 |             ap_offline = False
37 |         ap_time += ap_offline_interval
38 | 
39 |     iot_offline = True
40 |     iot_time = 0 - iot_offline_interval
41 |     while iot_offline:
42 |         if iot_offline_prob < random.uniform(0, 1):
43 |             iot_offline = False
44 |         iot_time += iot_offline_interval
45 | 
46 |     return ap_time, iot_time
47 | 
48 | 
49 | def get_sample_limits(res, sample_size):
50 |     '''This function gets storage sample numbers'''
51 |     ap_storage = float(res.loc[res['type'] == 'ap']['storage(MB)'])
52 |     iot_storage = float(res.loc[res['type'] == 'iot']['storage(MB)'])
53 |     ap_storage_sn = int(ap_storage * 1024 / sample_size)
54 |     iot_storage_sn = int(iot_storage * 1024 / sample_size)
55 | 
56 |     return ap_storage_sn, iot_storage_sn
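The retry loops in `get_online_state` above wait a geometric number of `coverage_interval(h)` slots until a device comes online, so the expected offline wait has the closed form `interval * (1 - online_prob) / online_prob`. A minimal self-check (standard library only; `sample_wait` is our name for the extracted loop, not part of the framework):

```python
import random

def sample_wait(online_prob: float, interval: float) -> float:
    # the same loop shape as in get_online_state, for one device
    offline_prob = 1 - online_prob
    wait, offline = -interval, True
    while offline:
        if offline_prob < random.uniform(0, 1):  # comes online with prob. online_prob
            offline = False
        wait += interval
    return wait

random.seed(0)
n, online_prob, interval = 100_000, 0.5, 1.0
simulated = sum(sample_wait(online_prob, interval) for _ in range(n)) / n
closed_form = interval * (1 - online_prob) / online_prob
print(f"simulated={simulated:.3f}h, closed-form={closed_form:.3f}h")  # both ~1.000h
```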
-------------------------------------------------------------------------------- /frameworks/centaur/plot/6_client_num.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import os 6 | 7 | from numpy import loadtxt 8 | 9 | import pandas as pd 10 | 11 | import matplotlib.pyplot as plt 12 | from matplotlib.ticker import FuncFormatter 13 | 14 | from plot_utils import load_all_acc 15 | 16 | plt.rcParams.update({'font.size': 8}) 17 | 18 | 19 | res_folder = 'results' 20 | df_acc = load_all_acc(res_folder, focus='mr') 21 | 22 | # import pdb; pdb.set_trace() 23 | 24 | # keep alpha=5, beta=3, gamma=0 only 25 | df_acc = df_acc.loc[~(df_acc['clientin'] == '8')] 26 | df_acc = df_acc.sort_values(['clientall', 'clientin','dest']) 27 | 28 | df_acc['key'] = [d + '/\n' + e for d,e in zip(df_acc['clientin'],df_acc['clientall'])] 29 | 30 | # bar plot 31 | fig, ax1 = plt.subplots(figsize=(5.5,2.8)) 32 | plt.gcf().subplots_adjust(bottom=0.19, top=0.97, left=0.09, right=0.98) 33 | 34 | barwidth = 0.25 35 | 36 | br_acc_ap = df_acc.loc[(df_acc['dest']=='ap')]['accuracy'].tolist() 37 | br_acc_iot = df_acc.loc[(df_acc['dest']=='iot')]['accuracy'].tolist() 38 | br_acc_dyn = df_acc.loc[(df_acc['dest']=='dyn')]['accuracy'].tolist() 39 | 40 | x = np.arange(len(br_acc_ap)) 41 | 42 | # calculate percentage improvement 43 | # iot_up, ap_up = [], [] 44 | # for i in range(len(br_acc_dyn)): 45 | # #ap_up_ = (br_acc_dyn[i] - br_acc_ap[i]) / br_acc_ap[i] 46 | # #iot_up_ = (br_acc_dyn[i] - br_acc_iot[i]) / br_acc_iot[i] 47 | # ap_up_ = br_acc_dyn[i] - br_acc_ap[i] 48 | # iot_up_ = br_acc_dyn[i] - br_acc_iot[i] 49 | # ap_up.append(ap_up_) 50 | # iot_up.append(iot_up_) 51 | 52 | # dat_ = {'dataset': df_acc['dataset'].tolist()[:len(br_acc_dyn)], 53 | # 'encoder': df_acc['encoder'].tolist()[:len(br_acc_dyn)], 54 | # 'iot_up': iot_up, 55 | # 'ap_up': ap_up 56 | # } 57 | 58 | # print(pd.DataFrame(data=dat_)) 59 | 60 | 61 | # plot 62 | ax1.bar(x-barwidth, br_acc_ap, color ='b', width = barwidth, edgecolor ='black', label ='AP Training') 63 | ax1.bar(x, br_acc_iot, color ='y', width = barwidth, edgecolor ='black', label ='UCD Training') 64 | ax1.bar(x+barwidth, br_acc_dyn, color ='r', width = barwidth, edgecolor ='black', label ='Centaur') 65 | 66 | ax1.set_xticks(x, df_acc['key'].unique().tolist()) #rotation=45 67 | ax1.set_ylim(0.4, 1.00) 68 | ax1.set_yticks(np.arange(0.4, 1.1, 0.1)) 69 | ax1.legend() 70 | 71 | # plt.xlabel('Models and Datasets') 72 | ax1.set_ylabel('Test Accuracy') 73 | ax1.set_xlabel('Participating Clients out of Total Clients') 74 | 75 | plt.savefig('imgs/clients_num.pdf') -------------------------------------------------------------------------------- /frameworks/centaur/plot/1_acc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | import pandas as pd 4 | import numpy as np 5 | 6 | import os 7 | 8 | from numpy import loadtxt 9 | 10 | import pandas as pd 11 | 12 | import matplotlib.pyplot as plt 13 | from matplotlib.ticker import FuncFormatter 14 | 15 | from plot_utils import load_all_acc 16 | 17 | plt.rcParams.update({'font.size': 8}) 18 | 19 | 20 | res_folder = 'results' 21 | df_acc = load_all_acc(res_folder) 22 | # keep alpha=5, beta=3, gamma=0 only 23 | df_acc = df_acc.loc[(df_acc['alpha'] == '5') & (df_acc['beta'] == '3') & (df_acc['gamma'] == '0')] 24 | 25 | df_acc = df_acc.groupby(['dest','dataset','encoder'], 
as_index=False).max() 26 | 27 | df_acc['key'] = [d + '\n' + e for d,e in zip(df_acc['encoder'],df_acc['dataset'])] 28 | 29 | df_acc_vanilla = load_all_acc(res_folder, 'vanilla') 30 | x_vnl = [d + '\n' + e for d,e in zip(df_acc_vanilla['encoder'],df_acc_vanilla['dataset'])] 31 | y_vnl = df_acc_vanilla['accuracy'] 32 | 33 | # bar plot 34 | fig, ax1 = plt.subplots(figsize=(5.5,2.8)) 35 | plt.gcf().subplots_adjust(bottom=0.18, top=0.97, left=0.09, right=0.98) 36 | 37 | barwidth = 0.2 38 | 39 | br_acc_ap = df_acc.loc[(df_acc['dest']=='ap')]['accuracy'].tolist() 40 | br_acc_iot = df_acc.loc[(df_acc['dest']=='iot')]['accuracy'].tolist() 41 | br_acc_dyn = df_acc.loc[(df_acc['dest']=='dyn')]['accuracy'].tolist() 42 | 43 | x = np.arange(len(br_acc_ap)) 44 | 45 | # calculate percentage improvement 46 | iot_up, ap_up = [], [] 47 | for i in range(len(br_acc_dyn)): 48 | #ap_up_ = (br_acc_dyn[i] - br_acc_ap[i]) / br_acc_ap[i] 49 | #iot_up_ = (br_acc_dyn[i] - br_acc_iot[i]) / br_acc_iot[i] 50 | ap_up_ = br_acc_dyn[i] - br_acc_ap[i] 51 | iot_up_ = br_acc_dyn[i] - br_acc_iot[i] 52 | ap_up.append(ap_up_) 53 | iot_up.append(iot_up_) 54 | 55 | dat_ = {'dataset': df_acc['dataset'].tolist()[:len(br_acc_dyn)], 56 | 'encoder': df_acc['encoder'].tolist()[:len(br_acc_dyn)], 57 | 'iot_up': iot_up, 58 | 'ap_up': ap_up 59 | } 60 | 61 | print(pd.DataFrame(data=dat_)) 62 | 63 | # plot 64 | ax1.bar(x-barwidth, br_acc_ap, color ='b', width = barwidth, edgecolor ='black', label ='AP Training') 65 | ax1.bar(x, br_acc_iot, color ='y', width = barwidth, edgecolor ='black', label ='UCD Training') 66 | ax1.bar(x+barwidth, br_acc_dyn, color ='r', width = barwidth, edgecolor ='black', label ='Centaur') 67 | 68 | ax1.set_xticks(x, df_acc['key'].unique().tolist(), rotation=35) 69 | ax1.set_yticks(np.arange(0, 1, 0.1)) 70 | ax1.legend(loc="upper center", bbox_to_anchor=(0.54,1), fontsize=8) 71 | 72 | # plt.xlabel('Models and Datasets') 73 | ax1.set_ylabel('Test Accuracy') 74 | 75 | ax2 = ax1.twinx() 76 | ax2.plot(x_vnl,y_vnl,linestyle='', markersize=10, marker='_',color='black') 77 | ax2.set_yticks(np.arange(0, 1, 0.1)) 78 | ax2.get_yaxis().set_visible(False) 79 | 80 | plt.savefig('imgs/acc_all.pdf') -------------------------------------------------------------------------------- /frameworks/centaur/hiefed/fed_criterion.py: -------------------------------------------------------------------------------- 1 | import flwr as fl 2 | from flwr.common import ( 3 | EvaluateIns, 4 | FitIns, 5 | Parameters, 6 | ) 7 | from typing import Callable, Dict, List, Optional, Tuple, Union 8 | 9 | from flwr.server.client_proxy import ClientProxy 10 | from flwr.server.client_manager import ClientManager 11 | from flwr.server.criterion import Criterion 12 | 13 | 14 | class FedAvg_criterion(fl.server.strategy.FedAvg): 15 | 16 | def configure_fit( 17 | self, server_round: int, parameters: Parameters, client_manager: ClientManager 18 | ) -> List[Tuple[ClientProxy, FitIns]]: 19 | """Configure the next round of training.""" 20 | config = {} 21 | if self.on_fit_config_fn is not None: 22 | # Custom fit config function provided 23 | config = self.on_fit_config_fn(server_round) 24 | fit_ins = FitIns(parameters, config) 25 | 26 | # Sample clients 27 | #_# Return the sample size and the required number of available clients. 
28 | #_# https://flower.dev/docs/apiref-flwr.html#flwr.server.strategy.FedAvg.num_fit_clients
29 |         sample_size, min_num_clients = self.num_fit_clients(
30 |             client_manager.num_available()
31 |         )
32 | 
33 |         # load cids from the buffer
34 |         with open('log/cids_buffer.csv', 'r') as f:
35 |             pre_cids = f.read()
36 | 
37 | #_# https://github.com/adap/flower/blob/main/src/py/flwr/server/criterion.py
38 | #_# Abstract class which allows subclasses to implement criterion sampling.
39 |         class ChosenCriterion(Criterion):
40 |             """Criterion to select only test clients."""
41 | #_# Decides whether a client should be eligible for sampling or not.
42 | #_# pre_cids means previous client ids,
43 | #_# for example: 47|170,94|170,63|170,58|170,71|170,68|170,64|170,5|170,
44 | 
45 |             def select(self, client: ClientProxy) -> bool:
46 |                 if pre_cids == '':
47 |                     return True
48 |                 else:
49 |                     cids_only = [i.split('|')[0] for i in pre_cids.split(',')[:-1]]
50 |                     return client.cid in cids_only
51 | #_#
52 |         """
53 |         The current code uses the same "Criterion" for both "iot" and "ap" rounds.
54 |         Because "iot" rounds do not need to choose specific clients, the cids
55 |         buffer is reset in "client.py" every time after an "ap" round:
56 |             if destine == 'ap':
57 |                 with open('log/cids_buffer.csv', 'w') as f:
58 |                     f.write('')
59 |         So, for an "iot" round, the "if" condition in the "select()" method above is true.
60 |         It might be better to implement two "Criterion"s: one dedicated to "iot" and another to "ap".
61 |         """
62 | #_#
63 | 
64 |         clients = client_manager.sample(
65 |             num_clients=sample_size,
66 |             min_num_clients=min_num_clients,
67 |             criterion=ChosenCriterion()
68 |         )
69 | 
70 |         # Return client/config pairs
71 |         return [(client, fit_ins) for client in clients]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data-centric Federated Learning
2 | This repository is dedicated to open-sourcing our work in Federated Learning, emphasizing data-centric methodologies.
3 | 
4 | 
5 | # Frameworks:
6 | 
7 | ## List:
8 | 
9 | The source code associated with our frameworks, presented in our published papers or preprints:
10 | 
11 | 1. See `frameworks/centaur` for the **Centaur** framework presented in "[Enhancing Efficiency in Multidevice Federated Learning through Data Selection](https://arxiv.org/abs/2211.04175)"
12 | 
13 | > An earlier version of this work was presented at the ICLR Workshop on Machine Learning for IoT: Datasets, Perception, and Understanding, as "Centaur: Federated Learning for Constrained Edge Devices": https://arxiv.org/abs/2211.04175v3
14 | 
15 | ### (1) Centaur
16 | 
17 | #### Abstract
18 | Federated learning (FL) in multidevice environments creates new opportunities to learn from a vast and diverse amount of private data. Although personal devices capture valuable data, their memory, computing, connectivity, and battery resources are often limited. Since deep neural networks (DNNs) are the typical machine learning models employed in FL, there is a demand for integrating ubiquitous constrained devices into the training process of DNNs. In this paper, we develop an FL framework that incorporates on-device data selection on such constrained devices, which allows partition-based training of a DNN through collaboration between constrained devices and resourceful devices of the same client.
Evaluations on five benchmark DNNs and six benchmark datasets across different modalities show that, on average, our framework achieves ~19% higher accuracy and ~58% lower latency, compared to the baseline FL without our implemented strategies. We demonstrate the effectiveness of our FL framework when dealing with imbalanced data, client participation heterogeneity, and various mobility patterns.
19 | 
20 | ![](./figs/centaur/_overview.png)
21 | 
22 | ![](./figs/centaur/_results_1.png)
23 | 
24 | ![](./figs/centaur/_results_2.png)
25 | 
26 | 
27 | #### How to run the code:
28 | (a) **Prerequisites**: Please see `requirements.txt`
29 | > Note that older versions of the Flower framework have a known problem when running a simulation with ray>=1.12: Ray workers go into Ray::IDLE mode, which occupies CUDA memory and leads to OOM. When using ray>=1.12, a workaround is to change all `ray.remote` to `ray.remote(max_calls=1)` in Flower's `ray_client_proxy` file.
30 | 
31 | (b) To run:
32 | 
33 |    i. Make sure to create the directories `log` and `results`
34 |    ii. Set up the desired parameters in `running_args.py`
35 |    iii. Run `python main.py`
36 | (c) For replicating some of the results, run `python make_script.py` to generate bash scripts, and then run the specific `.sh` file.
37 | 
38 | 
39 | #### Citation
40 | Please use:
41 | 
42 | ```bibtex
43 | @misc{mo2024enhancing,
44 |       title={Enhancing Efficiency in Multidevice Federated Learning through Data Selection},
45 |       author={Fan Mo and Mohammad Malekzadeh and Soumyajit Chatterjee and Fahim Kawsar and Akhil Mathur},
46 |       year={2024},
47 |       eprint={2211.04175},
48 |       archivePrefix={arXiv},
49 |       primaryClass={cs.LG}
50 | }
51 | ```
52 | 
--------------------------------------------------------------------------------
/frameworks/centaur/requirements.txt:
--------------------------------------------------------------------------------
1 | aiosignal==1.3.1
2 | async-timeout==4.0.2
3 | attrs==23.1.0
4 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
5 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
6 | brotlipy==0.7.0
7 | certifi==2021.5.30
8 | cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
9 | chardet @ file:///tmp/build/80754af9/chardet_1605303159953/work
10 | click==8.1.3
11 | conda==4.10.1
12 | conda-build==3.21.4
13 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1618262151086/work
14 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769182610/work
15 | cycler==0.11.0
16 | decorator @ file:///tmp/build/80754af9/decorator_1621259047763/work
17 | distlib==0.3.6
18 | dnspython==2.1.0
19 | filelock==3.12.0
20 | flwr==1.0.0
21 | fonttools==4.38.0
22 | frozenlist==1.3.3
23 | fvcore==0.1.5.post20221221
24 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
25 | grpcio==1.43.0
26 | idna @ file:///tmp/build/80754af9/idna_1593446292537/work
27 | importlib-metadata==4.13.0
28 | importlib-resources==5.12.0
29 | iopath==0.1.10
30 | ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
31 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
32 | iterators==0.0.2
33 | jedi==0.17.0
34 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
35 | jsonschema==4.17.3
36 | kiwisolver==1.4.4
37 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
38 | MarkupSafe @
file:///tmp/build/80754af9/markupsafe_1621528142364/work 39 | matplotlib==3.5.3 40 | mkl-fft==1.3.0 41 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work 42 | mkl-service==2.3.0 43 | msgpack==1.0.5 44 | numpy==1.21.6 45 | olefile==0.46 46 | packaging==23.1 47 | pandas==1.3.5 48 | parso @ file:///tmp/build/80754af9/parso_1617223946239/work 49 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 50 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 51 | Pillow @ file:///tmp/build/80754af9/pillow_1617386154241/work 52 | pkginfo==1.7.0 53 | pkgutil-resolve-name==1.3.10 54 | platformdirs==3.5.1 55 | portalocker==2.7.0 56 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 57 | protobuf==3.20.3 58 | psutil @ file:///tmp/build/80754af9/psutil_1612298016854/work 59 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 60 | pycosat==0.6.3 61 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 62 | Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work 63 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1605545627475/work 64 | pyparsing==3.1.0 65 | pyrsistent==0.19.3 66 | PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work 67 | python-dateutil==2.8.2 68 | python-etcd==0.4.5 69 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 70 | PyYAML==5.4.1 71 | ray==2.4.0 72 | redis==4.5.5 73 | requests @ file:///tmp/build/80754af9/requests_1592841827918/work 74 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016701961/work 75 | six @ file:///tmp/build/80754af9/six_1605205313296/work 76 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work 77 | tabulate==0.9.0 78 | termcolor==2.3.0 79 | torch==1.8.1 80 | torch-summary==1.4.5 81 | torchelastic==0.2.0 82 | torchinfo==1.8.0 83 | torchtext==0.9.1 84 | torchvision==0.9.1 85 | tqdm @ file:///tmp/build/80754af9/tqdm_1605303662894/work 86 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 87 | typing-extensions==4.6.3 88 | urllib3 @ file:///tmp/build/80754af9/urllib3_1603305693037/work 89 | virtualenv==20.21.0 90 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 91 | yacs==0.1.8 92 | zipp==3.15.0 93 | -------------------------------------------------------------------------------- /frameworks/centaur/plot/2_classifier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import os 6 | from os import listdir 7 | from os.path import isfile, join 8 | 9 | from numpy import loadtxt 10 | 11 | import pandas as pd 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.ticker import FuncFormatter 15 | 16 | from plot_utils import load_all_acc 17 | 18 | plt.rcParams.update({'font.size': 8}) 19 | 20 | 21 | classifier_focus = 0 22 | training_method_focus = 1 23 | 24 | res_folder = 'results' 25 | df_acc = load_all_acc(res_folder) 26 | df_acc = df_acc.loc[~(df_acc['dataset'] == 'emnist')] 27 | 28 | # filter out low accuracy 29 | df_acc = df_acc.loc[df_acc['accuracy'] > 0.3] 30 | 31 | df_acc['key'] = [d + '\n' + e for d,e in zip(df_acc['encoder'],df_acc['dataset'])] 32 | key_list = df_acc['key'].unique().tolist() 33 | 34 | classifier_order = ['S','M','L'] 35 | 36 | idxs = list(range(3)) 37 | 38 | 39 | if classifier_focus == 1: 40 | df_acc['classifier_index'] = df_acc['classifier'] 41 | df_acc = 
df_acc.set_index('classifier_index').loc[classifier_order] 42 | 43 | dests = ['ap', 'iot', 'dyn'] 44 | dests_str = ['AP Training', 'UCD Training', 'Centaur'] 45 | 46 | # plot all 47 | figs, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(5.5,2.6)) 48 | plt.subplots_adjust(wspace=0.3, right=0.78, bottom=0.15) 49 | # plt.subplots_adjust(left=0.1, , wspace=0.9, ) 50 | 51 | for idx,col in zip(idxs,axes): 52 | dest_acc = df_acc.loc[df_acc['dest'] == dests[idx]] 53 | for key in key_list: 54 | dat = dest_acc[dest_acc['key'] == key] 55 | x = dat['classifier'].tolist() 56 | y = dat['accuracy'].tolist() 57 | col.plot(x, y, label=key, marker=".", markersize=10, linewidth=0.5) 58 | 59 | if idx==0: col.set_ylabel('Test Accuracy') 60 | col.set_xlabel('Classifier Size') 61 | col.set_ylim(0.3, 1) 62 | col.set_yticks(np.arange(0.3, 1, 0.1)) 63 | col.title.set_text(dests_str[idx]) 64 | 65 | #import pdb; pdb.set_trace() 66 | 67 | handles, labels = col.get_legend_handles_labels() 68 | figs.legend(handles, labels, loc='center right') 69 | 70 | plt.savefig('imgs/classifier_acc.pdf') 71 | 72 | 73 | 74 | if training_method_focus == 1: 75 | 76 | d_str = ['AP', 'UCD', 'Centaur'] 77 | df_acc.loc[df_acc['dest'] == 'ap', 'dest'] = d_str[0] 78 | df_acc.loc[df_acc['dest'] == 'iot', 'dest'] = d_str[1] 79 | df_acc.loc[df_acc['dest'] == 'dyn', 'dest'] = d_str[2] 80 | 81 | df_acc['dest_index'] = df_acc['dest'] 82 | df_acc = df_acc.set_index('dest_index').loc[d_str] 83 | 84 | classifier_str = ['Small', 'Medium', 'Large'] 85 | 86 | # plot all 87 | figs, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(5.5,2.6)) 88 | plt.subplots_adjust(wspace=0.3, right=0.78, bottom=0.15) 89 | # plt.subplots_adjust(left=0.1, , wspace=0.9, ) 90 | 91 | markers = ["d", "v", "s", "*", "^", "o"] 92 | 93 | for idx,col in zip(idxs,axes): 94 | classifier_acc = df_acc.loc[df_acc['classifier'] == classifier_order[idx]] 95 | for count, key in enumerate(key_list): 96 | dat = classifier_acc[classifier_acc['key'] == key] 97 | x = dat['dest'].tolist() 98 | y = dat['accuracy'].tolist() 99 | col.plot(x, y, label=key, marker=markers[count], markersize=7, linewidth=0.4, linestyle=(0, (1, 10))) 100 | 101 | if idx==0: col.set_ylabel('Test Accuracy') 102 | col.set_xlabel('Training Methods') 103 | col.set_ylim(0.3, 1) 104 | col.set_yticks(np.arange(0.3, 1, 0.1)) 105 | col.title.set_text(classifier_str[idx]) 106 | 107 | 108 | handles, labels = col.get_legend_handles_labels() 109 | figs.legend(handles, labels, loc='center right') 110 | 111 | plt.savefig('imgs/classifier_acc.pdf') -------------------------------------------------------------------------------- /frameworks/centaur/plot/9_prox_table.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | avg_file = 'results/acc_dyn_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0.csv' 7 | 8 | base = [1, 0.1, 0.01, 0.001] 9 | grow = 3 10 | perc = 0.01 11 | avg_mu = ['1e-07', '2e-07', '3e-07', '4e-07', '5e-07', '6e-07', '7e-07', '8e-07', '9e-07', '1e-06', '1.1e-06', '1.2e-06'] 12 | 13 | 14 | # get all file names 15 | grow_pos = [perc*i for i in range(1, grow+1)] 16 | grow_minus = [i*-1 for i in grow_pos[::-1]] 17 | grow_plus = grow_minus + [0] + grow_pos 18 | 19 | all_mu = [] 20 | base_mu = [] 21 | for b in base: 22 | all_mu = all_mu + [g*b + b for g in grow_plus] # all ranged mu 23 | base_mu = base_mu + [b] * (grow*2 + 1) # all base mu 24 | 25 | filename_list = [] 26 | 27 | 
# build file names for prox
28 | for mu in all_mu:
29 |     filename = 'results/acc_dynprox' + str(round(mu, 8)) + '_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa1000_mr0|0.csv'
30 |     filename_list.append(filename)
31 | 
32 | # build file names for avg (tiny mu)
33 | for mu in avg_mu:
34 |     filename = 'results/acc_dynprox' + mu + '_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa1000_mr0|0.csv'
35 |     filename_list.append(filename)
36 | all_mu = all_mu + avg_mu
37 | base_mu = base_mu + [0]*len(avg_mu)
38 | 
39 | # read all prox
40 | prox_mu_list = []
41 | prox_acc_list = []
42 | prox_bmu_list = []
43 | for mu, bmu, filename in zip(all_mu, base_mu, filename_list):
44 |     if os.path.isfile(filename):
45 |         prox_acc = pd.read_csv(filename, sep=',')
46 |         prox_acc = prox_acc['accuracy'].tolist()[1:]
47 |         if len(prox_acc) > 95:
48 |             prox_bmu_list.append(bmu)
49 |             prox_mu_list.append(mu)
50 |             prox_acc_list.append(prox_acc)
51 | 
52 | # read avg
53 | avg_file = 'results/acc_dyn_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0.csv'
54 | avg_acc = pd.read_csv(avg_file, sep=',')
55 | avg_acc = avg_acc['accuracy'].tolist()[1:]
56 | prox_bmu_list.append(0)
57 | prox_mu_list.append(0)
58 | prox_acc_list.append(avg_acc)
59 | 
60 | # prox5_file = 'results/acc_dynprox5.0_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa1000_mr0|0.csv'
61 | # prox5_acc = pd.read_csv(prox5_file, sep=',')
62 | # import pdb; pdb.set_trace()
63 | # prox5_acc = avg_acc['accuracy'].tolist()[1:]
64 | # prox_bmu_list.append(5)
65 | # prox_mu_list.append(5)
66 | # prox_acc_list.append(prox5_acc)
67 | 
68 | mean_list = []
69 | smoothness_list = []
70 | for bmu, mu, acc in zip(prox_bmu_list, prox_mu_list, prox_acc_list):
71 |     smoothness = round(np.std(np.diff(acc))/np.abs(np.mean(np.diff(acc))), 4)
72 |     #smoothness = round(np.std(np.diff(acc)), 4)
73 |     # https://stats.stackexchange.com/questions/24607/how-to-measure-smoothness-of-a-time-series-in-r
74 |     mean = round(np.max(acc), 6)  # note: despite the name, this records the max accuracy
75 |     mean_list.append(mean)
76 |     smoothness_list.append(smoothness)
77 |     # print("bmu={0:8}, mu={1:10}, mean={2:8}, smoothness={3:8}".format(bmu, mu, mean, smoothness))
78 | 
79 | d = {'bmu': prox_bmu_list,
80 |      'mu': prox_mu_list,
81 |      'acc': mean_list,
82 |      'smoothness': smoothness_list,
83 |      }
84 | 
85 | df = pd.DataFrame(data=d)
86 | print(df)
87 | print('='*50)
88 | 
89 | df = df.groupby(by=['bmu'], as_index=False).agg({'acc': ['count', 'mean', 'std'], 'smoothness': ['mean', 'std']})
90 | print(df)
91 | # import pdb; pdb.set_trace()
92 | 
93 | 
94 | 
95 | 
96 | 
97 | # x = list(range(100))
98 | # #import pdb; pdb.set_trace()
99 | 
100 | # figs, col = plt.subplots(figsize=(5,5))
101 | # col.plot(x, avg_acc, label='FedAvg')
102 | # # col.plot(x, prox5_acc, label=r'FedProx: $\mu$=5')
103 | # col.plot(x, prox1_acc, label=r'FedProx: $\mu$=1')
104 | # col.plot(x, prox01_acc, label=r'FedProx: $\mu$=0.1')
105 | # col.plot(x, prox001_acc, label=r'FedProx: $\mu$=0.01')
106 | # col.plot(x, prox0001_acc, label=r'FedProx: $\mu$=0.001')
107 | 
108 | 
109 | # col.set_xlabel('Rounds')
110 | # col.set_ylabel('Test Accuracy')
111 | # col.title.set_text('Centaur')
112 | # col.legend()
113 | 
114 | # plt.savefig('imgs/prox.pdf')
115 | 
--------------------------------------------------------------------------------
/frameworks/centaur/plot/4_cost_linkage.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import numpy as np
4 | 
5 | import os
6 | from os import listdir
7 | from os.path import isfile, join
8 | 
9 | from numpy import loadtxt
10
| 11 | import pandas as pd 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.ticker import FuncFormatter 15 | 16 | from plot_utils import y_fmt, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 17 | 18 | 19 | SMALL_SIZE = 5 20 | MEDIUM_SIZE = 9 21 | BIGGER_SIZE = 10 22 | 23 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 24 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 25 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 26 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 27 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 28 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 29 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 30 | 31 | 32 | encoders = ['mobilenet','mobilenet', 'efficientnet', 'efficientnet'] 33 | classifiers = ['small', 'large', 'small', 'large'] 34 | keys = ['mobile+S','mobile+L','efficient+S','efficient+L'] 35 | post_args = ['1cifar10_2mobi_3small_c8|100_e3_a5_b3_g0', '1cifar10_2mobi_3large_c8|100_e3_a5_b3_g0', '1cifar10_2effi_3small_c5|100_e3_a5_b3_g0', '1cifar10_2effi_3large_c5|100_e3_a5_b3_g0'] 36 | # load configuration from cfg file 37 | computation_cfg, communication_cfg = load_res_from_cfg() 38 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 39 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 40 | 41 | 42 | # plot 43 | idxs = [1, 2] 44 | dests = ['dyn', 'iot', 'ap'] 45 | nrows, ncols = 2, 4 46 | 47 | figs, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, figsize=(5,2.5)) 48 | plt.subplots_adjust(left=0.1, right=0.9, wspace=0.6, top=0.98, bottom=0.19) 49 | 50 | for c in range(ncols): 51 | 52 | encoder = encoders[c] 53 | classifier = classifiers[c] 54 | post_arg = post_args[c] 55 | 56 | # compute mac and its energy and latency 57 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, 58 | iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 59 | 60 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder, classifier, post_arg, 61 | up_iot, down_iot, up_ap, down_ap,iot_energy_comm, ap_energy_comm, sample_size) 62 | 63 | for r in range(nrows): 64 | axe = axes[r,c] 65 | 66 | if r == 0: 67 | dyn_x, dyn_y1, dyn_y2 = mac_dyn[0], mac_dyn[2], mac_dyn[4] 68 | latency_color = 'b' 69 | elif r == 1: 70 | dyn_x, dyn_y1, dyn_y2 = mac_dyn[1], mac_dyn[3], mac_dyn[5] 71 | latency_color = 'g' 72 | 73 | axe.plot(dyn_x, dyn_y1, '-', color=latency_color, label ='Latency', linewidth=1) 74 | 75 | axe.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 76 | axe.xaxis.set_major_formatter(FuncFormatter(y_fmt)) 77 | 78 | axe_tx = axe.twinx() 79 | axe_tx.plot(dyn_x, dyn_y2, ':', color='r', label ='Energy', linewidth=1.5) 80 | 81 | axe_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 82 | 83 | if r == 1: axe.set_xlabel('MACs\n' + keys[c]) 84 | if c == 0 and r == 0: axe.set_ylabel('Latency (UCD)', color=latency_color) 85 | if c == 0 and r == 1: axe.set_ylabel('Latency (AP)', color=latency_color) 86 | if c == 3 and r == 0: axe_tx.set_ylabel('Energy (UCD)', color='r') 87 | if c == 3 and r == 1: axe_tx.set_ylabel('Energy (AP)', color='r') 88 | 89 | plt.savefig('imgs/cost_linkage.pdf') 90 | 91 | # import pdb; pdb.set_trace() 92 | 93 | #ap_x, ap_y1, ap_y2 = mac_ap[0], mac_ap[2], mac_ap[4] 94 | #iot_x, iot_y1, iot_y2 = mac_iot[0], mac_iot[2], mac_iot[4] 95 | #ap_x, ap_y1, ap_y2 = mac_ap[1], 
mac_ap[3], mac_ap[5] 96 | #iot_x, iot_y1, iot_y2 = mac_iot[1], mac_iot[3], mac_iot[5] 97 | #col.plot(ap_x, ap_y1, '-', color='b', label ='Latency', linewidth=1) 98 | #col.plot(iot_x, iot_y1, '-', color='b', label ='Latency', linewidth=1) 99 | #col_tx.plot(ap_x, ap_y2, ':', color='r', label ='Energy', linewidth=1.5) 100 | #col_tx.plot(iot_x, iot_y2, ':', color='r', label ='Energy', linewidth=1.5) -------------------------------------------------------------------------------- /frameworks/centaur/third_party/autograd_hacks_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import autograd_hacks 6 | 7 | 8 | # Lenet-5 from https://github.com/pytorch/examples/blob/master/mnist/main.py 9 | class Net(nn.Module): 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 13 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 14 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 15 | self.fc2 = nn.Linear(500, 10) 16 | 17 | def forward(self, x): 18 | x = F.relu(self.conv1(x)) 19 | x = F.max_pool2d(x, 2, 2) 20 | x = F.relu(self.conv2(x)) 21 | x = F.max_pool2d(x, 2, 2) 22 | x = x.view(-1, 4 * 4 * 50) 23 | x = F.relu(self.fc1(x)) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | # Tiny LeNet-5 for Hessian testing 29 | class TinyNet(nn.Module): 30 | def __init__(self): 31 | super(TinyNet, self).__init__() 32 | self.conv1 = nn.Conv2d(1, 2, 2, 1) 33 | self.conv2 = nn.Conv2d(2, 2, 2, 1) 34 | self.fc1 = nn.Linear(2, 2) 35 | self.fc2 = nn.Linear(2, 10) 36 | 37 | def forward(self, x): # 28x28 38 | x = F.max_pool2d(x, 4, 4) # 7x7 39 | x = F.relu(self.conv1(x)) # 6x6 40 | x = F.max_pool2d(x, 2, 2) # 3x3 41 | x = F.relu(self.conv2(x)) # 2x2 42 | x = F.max_pool2d(x, 2, 2) # 1x1 43 | x = x.view(-1, 2 * 1 * 1) # C * W * H 44 | x = F.relu(self.fc1(x)) 45 | x = self.fc2(x) 46 | return x 47 | 48 | 49 | # Autograd helpers, from https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7 50 | def jacobian(y: torch.Tensor, x: torch.Tensor, create_graph=False): 51 | jac = [] 52 | flat_y = y.reshape(-1) 53 | grad_y = torch.zeros_like(flat_y) 54 | for i in range(len(flat_y)): 55 | grad_y[i] = 1. 56 | grad_x, = torch.autograd.grad(flat_y, x, grad_y, retain_graph=True, create_graph=create_graph) 57 | jac.append(grad_x.reshape(x.shape)) 58 | grad_y[i] = 0. 
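        # (each pass builds one row of the Jacobian: setting grad_y[i] = 1 and
        # calling torch.autograd.grad computes the vector-Jacobian product
        # e_i^T J, and the entry is reset to 0 so the same buffer is reused)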
59 | return torch.stack(jac).reshape(y.shape + x.shape) 60 | 61 | 62 | def hessian(y: torch.Tensor, x: torch.Tensor): 63 | return jacobian(jacobian(y, x, create_graph=True), x) 64 | 65 | 66 | def test_grad1(): 67 | torch.manual_seed(1) 68 | model = Net() 69 | loss_fn = nn.CrossEntropyLoss() 70 | 71 | n = 4 72 | data = torch.rand(n, 1, 28, 28) 73 | targets = torch.LongTensor(n).random_(0, 10) 74 | 75 | autograd_hacks.add_hooks(model) 76 | output = model(data) 77 | loss_fn(output, targets).backward(retain_graph=True) 78 | autograd_hacks.compute_grad1(model) 79 | autograd_hacks.disable_hooks() 80 | 81 | # Compare values against autograd 82 | losses = torch.stack([loss_fn(output[i:i+1], targets[i:i+1]) for i in range(len(data))]) 83 | 84 | for layer in model.modules(): 85 | if not autograd_hacks.is_supported(layer): 86 | continue 87 | for param in layer.parameters(): 88 | assert torch.allclose(param.grad, param.grad1.mean(dim=0)) 89 | assert torch.allclose(jacobian(losses, param), param.grad1) 90 | 91 | 92 | def test_hess(): 93 | subtest_hess_type('CrossEntropy') 94 | subtest_hess_type('LeastSquares') 95 | 96 | 97 | def subtest_hess_type(hess_type): 98 | torch.manual_seed(1) 99 | model = TinyNet() 100 | 101 | def least_squares_loss(data_, targets_): 102 | assert len(data_) == len(targets_) 103 | err = data_ - targets_ 104 | return torch.sum(err * err) / 2 / len(data_) 105 | 106 | n = 3 107 | data = torch.rand(n, 1, 28, 28) 108 | 109 | autograd_hacks.add_hooks(model) 110 | output = model(data) 111 | 112 | if hess_type == 'LeastSquares': 113 | targets = torch.rand(output.shape) 114 | loss_fn = least_squares_loss 115 | else: # hess_type == 'CrossEntropy': 116 | targets = torch.LongTensor(n).random_(0, 10) 117 | loss_fn = nn.CrossEntropyLoss() 118 | 119 | autograd_hacks.backprop_hess(output, hess_type=hess_type) 120 | autograd_hacks.clear_backprops(model) 121 | autograd_hacks.backprop_hess(output, hess_type=hess_type) 122 | 123 | autograd_hacks.compute_hess(model) 124 | autograd_hacks.disable_hooks() 125 | 126 | for layer in model.modules(): 127 | if not autograd_hacks.is_supported(layer): 128 | continue 129 | for param in layer.parameters(): 130 | loss = loss_fn(output, targets) 131 | hess_autograd = hessian(loss, param) 132 | hess = param.hess 133 | assert torch.allclose(hess, hess_autograd.reshape(hess.shape)) 134 | 135 | 136 | if __name__ == '__main__': 137 | test_grad1() 138 | test_hess() -------------------------------------------------------------------------------- /frameworks/centaur/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import ray 4 | import torch 5 | import torchvision 6 | 7 | from collections import OrderedDict 8 | from typing import Dict, Callable, Optional, Tuple, List 9 | from pathlib import Path 10 | import numpy as np 11 | import flwr as fl 12 | from flwr.common.typing import Scalar 13 | 14 | from hiefed.fed_criterion import FedAvg_criterion 15 | from hiefed.dataset_utils import getCIFAR10, getCIFAR100, do_fl_partitioning, Saver, getEMNIST, getUCIHAR, getPAMAP2, getMotion 16 | from hiefed.client import get_params, set_params, RayClient, init_network, test 17 | 18 | from oam import create_client_profile 19 | 20 | import running_args 21 | from running_args import fit_config 22 | from running_args import get_evaluate_fn 23 | from torchinfo import summary 24 | 25 | 26 | # Start Ray simulation (a _default server_ will be created) 27 | # This example does: 28 | # 1. 
Downloads CIFAR-10, CIFAR-100, or EMNIST
29 | # 2. Partitions the dataset into N splits, where N is the total number of
30 | #    clients. We refer to this as `pool_size`. The partition can be IID or non-IID.
31 | # 3. Starts a Ray-based simulation where a fraction of clients is sampled each round.
32 | # 4. After the M rounds end, the global model is evaluated on the entire testset.
33 | #    Also, the global model is evaluated on the valset partition residing in each
34 | #    client. This is useful to get a sense of how well the global model can generalise
35 | #    to each client's data.
36 | if __name__ == "__main__":
37 | 
38 |     # parse input arguments
39 |     args = running_args.get_parser().parse_args()
40 | 
41 |     client_resources = {
42 |         "num_gpus": args.num_client_gpus,
43 |         "num_cpus": args.num_client_cpus,
44 |     }
45 | 
46 |     # download dataset
47 |     if args.dataset == 'cifar10':
48 |         train_path, testset = getCIFAR10()
49 |         num_classes = 10
50 |     elif args.dataset == 'cifar100':
51 |         train_path, testset = getCIFAR100()
52 |         num_classes = 100
53 |     elif args.dataset == 'emnist':
54 |         train_path, testset = getEMNIST()
55 |         num_classes = 47
56 |     elif args.dataset == 'pamap':
57 |         train_path, testset = getPAMAP2()
58 |         num_classes = 13
59 |     elif args.dataset == 'uci_har':
60 |         train_path, testset = getUCIHAR()
61 |         num_classes = 6
62 |     elif args.dataset == 'motion':
63 |         train_path, testset = getMotion()
64 |         num_classes = 6
65 |     # TODO: elif args.dataset == 'femnist':
66 |     #     train_path, testset = getFEMNIST()
67 |     #     data_loader = torch.utils.data.DataLoader(train_path, batch_size=len(train_path))
68 |     #     train_images, train_labels = iter(data_loader).next()
69 |     #     print(train_images.shape, train_labels.shape)
70 |     #     labels = np.array(train_labels).astype(int)
71 |     #     _unique, _counts = np.unique(labels, return_counts=True)
72 |     #     print(np.asarray((_unique, _counts)).T)
73 |     #     num_classes = 10
74 |     else:
75 |         raise NameError('Dataset is not supported!')
76 | 
77 |     # partition dataset (use a large `alpha` to make it IID;
78 |     # a small value (e.g. 1) will make it non-IID)
79 |     # This will create a new directory called "federated" in the directory where
80 |     # CIFAR-10 lives. Inside it, there will be N=pool_size sub-directories, each with
81 |     # its own train/val split.
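    # (an illustration, assuming the usual LDA partitioning scheme behind
    # `fl_partitioning_alpha`: each client's class mix is a draw
    # p ~ Dirichlet(alpha * ones(num_classes)), so alpha=1000 gives a near-uniform
    # class mix per client, while alpha=0.001 concentrates almost all samples in a
    # single class, matching the histograms annotated in plot/5_unbalanced.py)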
82 | fed_dir = do_fl_partitioning( 83 | train_path, pool_size=args.min_available_clients, 84 | alpha=args.fl_partitioning_alpha, 85 | num_classes=num_classes, val_ratio=0.1 86 | ) 87 | model = init_network(encoder=args.encoder, classifier=args.classifier, 88 | num_classes=num_classes, pre_trained=args.pre_trained) 89 | if args.dataset == 'pamap': 90 | summary(model.cuda(), input_size=(1, 1, 27, 200), device="cuda:0") 91 | elif args.dataset == 'uci_har' or args.dataset == 'motion': 92 | summary(model.cuda(), input_size=(1, 1, 9, 128), device="cuda:0") 93 | # Get model weights as a list of NumPy ndarray's 94 | weights = get_params(model) 95 | 96 | # Serialize ndarrays to `Parameters` 97 | parameters = fl.common.ndarrays_to_parameters(weights) 98 | 99 | # when debug mode is on, choose only two clients 100 | if args.debug: 101 | args.min_fit_clients = 2 102 | args.save_results = False 103 | 104 | # initial saver 105 | saver = Saver(args) 106 | 107 | # create client profiles for spatio-temporal coverage 108 | create_client_profile(num_clients = args.min_available_clients, temporal = "day", spatial = 10) 109 | 110 | 111 | # configure the strategy 112 | client_ratio = args.min_fit_clients/args.min_available_clients 113 | strategy = FedAvg_criterion( 114 | fraction_fit=client_ratio, 115 | fraction_evaluate=client_ratio, 116 | min_fit_clients=args.min_fit_clients, 117 | min_available_clients=args.min_available_clients, # All clients should be available 118 | on_fit_config_fn=fit_config, 119 | on_evaluate_config_fn=fit_config, 120 | evaluate_fn=get_evaluate_fn(testset, saver, num_classes), # centralised testset evaluation of global model 121 | initial_parameters=parameters, 122 | ) 123 | 124 | def client_fn(cid: str): 125 | # create a single client instance 126 | return RayClient(cid, fed_dir, saver, 127 | args.encoder, args.classifier, num_classes, args.pre_trained, args.client_configs) 128 | 129 | # (optional) specify ray config 130 | ray_config = {"include_dashboard": False, 131 | "local_mode": args.debug, # local mode on for ray 132 | } 133 | 134 | # start simulation 135 | fl.simulation.start_simulation( 136 | client_fn=client_fn, 137 | num_clients=args.min_available_clients, 138 | client_resources=client_resources, 139 | config=fl.server.ServerConfig(num_rounds=args.num_rounds), 140 | strategy=strategy, 141 | ray_init_args=ray_config, 142 | ) 143 | -------------------------------------------------------------------------------- /frameworks/centaur/plot/2_classifier_s.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | sys.path.append(".") 5 | 6 | import os 7 | from os import listdir 8 | from os.path import isfile, join 9 | 10 | from numpy import loadtxt 11 | 12 | import pandas as pd 13 | 14 | import matplotlib.pyplot as plt 15 | from matplotlib.ticker import FuncFormatter 16 | 17 | from plot_utils import load_all_acc 18 | 19 | plt.rcParams.update({'font.size': 8}) 20 | 21 | 22 | classifier_focus = 0 23 | training_method_focus = 1 24 | 25 | res_folder = 'results' 26 | 27 | 28 | def load_all_acc_s(res_folder, focus='basic'): 29 | onlyfiles = [f for f in listdir(res_folder) if isfile(join(res_folder, f))] 30 | acc_ap = [filename for filename in onlyfiles if 'acc_ap_1' in filename] 31 | acc_iot = [filename for filename in onlyfiles if 'acc_iot_1' in filename] 32 | acc_dyn = [filename for filename in onlyfiles if 'acc_dyn_1' in filename] 33 | acc_x = acc_ap + acc_iot + acc_dyn 34 | 35 | splitted_list = 
[] 36 | for acc_name in acc_x: 37 | # marking based on the filename 38 | acc_name = acc_name.replace('uci_har', 'ucihar') 39 | splitted = acc_name.split('_') 40 | if 'flpa1000' not in splitted: 41 | continue 42 | print(splitted) 43 | splitted = [splitted[1], splitted[2][1:], splitted[3][1:], splitted[4][1:]] 44 | name_base = ['dest','dataset','encoder','classifier'] 45 | 46 | # read highest accuracy 47 | acc_name = acc_name.replace('ucihar', 'uci_har') 48 | acc_temp = pd.read_csv(res_folder + '/' + acc_name) 49 | 50 | acc_list = acc_temp['accuracy'].tolist() 51 | acc_list.sort(reverse=True) 52 | max_acc = sum(acc_list[:5])/5 53 | 54 | # loc this accuracy 55 | round_idx = int(acc_temp.loc[acc_temp['accuracy'] == acc_list[0], 'round'].tolist()[0]) 56 | splitted = splitted + [max_acc] + [round_idx] 57 | splitted_list.append(splitted) 58 | 59 | 60 | columns_name = name_base + ['accuracy','loc'] 61 | df_acc = pd.DataFrame(splitted_list, columns=columns_name) 62 | df_acc = df_acc.sort_values(columns_name) 63 | 64 | # df_acc.loc[df_acc['encoder'] == 'effi','encoder'] = 'efficient' 65 | # df_acc.loc[df_acc['encoder'] == 'mobi','encoder'] = 'mobile' 66 | # df_acc.loc[df_acc['encoder'] == 'shuf','encoder'] = 'shuffle' 67 | df_acc.loc[df_acc['dataset'] == 'motion','dataset'] = 'MotionSense' 68 | df_acc.loc[df_acc['dataset'] == 'ucihar','dataset'] = 'UCIHAR' 69 | df_acc.loc[df_acc['dataset'] == 'pamap','dataset'] = 'PAMAP2' 70 | 71 | df_acc.loc[df_acc['classifier'] == 'small','classifier'] = 'S' 72 | df_acc.loc[df_acc['classifier'] == 'medium','classifier'] = 'M' 73 | df_acc.loc[df_acc['classifier'] == 'large','classifier'] = 'L' 74 | 75 | return df_acc 76 | 77 | 78 | 79 | df_acc = load_all_acc_s(res_folder) 80 | # df_acc = df_acc.loc[~(df_acc['dataset'] == 'motion')] 81 | 82 | # filter out low accuracy 83 | # df_acc = df_acc.loc[df_acc['accuracy'] > 0.3] 84 | 85 | df_acc['key'] = [e for e in df_acc['dataset']] 86 | key_list = df_acc['key'].unique().tolist() 87 | 88 | classifier_order = ['S','M','L'] 89 | 90 | idxs = list(range(3)) 91 | 92 | 93 | if classifier_focus == 1: 94 | df_acc['classifier_index'] = df_acc['classifier'] 95 | df_acc = df_acc.set_index('classifier_index').loc[classifier_order] 96 | 97 | dests = ['ap', 'iot', 'dyn'] 98 | dests_str = ['AP Training', 'UCD Training', 'Centaur'] 99 | 100 | # plot all 101 | figs, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(5.5,2.6)) 102 | plt.subplots_adjust(wspace=0.3, right=0.78, bottom=0.15) 103 | # plt.subplots_adjust(left=0.1, , wspace=0.9, ) 104 | 105 | for idx,col in zip(idxs,axes): 106 | dest_acc = df_acc.loc[df_acc['dest'] == dests[idx]] 107 | for key in key_list: 108 | dat = dest_acc[dest_acc['key'] == key] 109 | x = dat['classifier'].tolist() 110 | y = dat['accuracy'].tolist() 111 | col.plot(x, y, label=key, marker=".", markersize=10, linewidth=0.5) 112 | 113 | if idx==0: col.set_ylabel('Test Accuracy') 114 | col.set_xlabel('Classifier Size') 115 | col.set_ylim(0.3, 1) 116 | col.set_yticks(np.arange(0.3, 1, 0.1)) 117 | col.title.set_text(dests_str[idx]) 118 | 119 | #import pdb; pdb.set_trace() 120 | 121 | handles, labels = col.get_legend_handles_labels() 122 | figs.legend(handles, labels, loc='center right') 123 | 124 | plt.savefig('imgs/classifier_acc.pdf') 125 | 126 | 127 | 128 | if training_method_focus == 1: 129 | 130 | d_str = ['AP', 'UCD', 'Centaur'] 131 | df_acc.loc[df_acc['dest'] == 'ap', 'dest'] = d_str[0] 132 | df_acc.loc[df_acc['dest'] == 'iot', 'dest'] = d_str[1] 133 | df_acc.loc[df_acc['dest'] == 'dyn', 'dest'] = 
d_str[2] 134 | 135 | df_acc['dest_index'] = df_acc['dest'] 136 | df_acc = df_acc.set_index('dest_index').loc[d_str] 137 | 138 | classifier_str = ['Small', 'Medium', 'Large'] 139 | 140 | # plot all 141 | figs, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(5.5,2.6)) 142 | plt.subplots_adjust(wspace=0.3, right=0.78, bottom=0.15) 143 | # plt.subplots_adjust(left=0.1, , wspace=0.9, ) 144 | 145 | markers = ["d", "v", "o", "*", "^", "o"] 146 | 147 | for idx,col in zip(idxs,axes): 148 | classifier_acc = df_acc.loc[df_acc['classifier'] == classifier_order[idx]] 149 | for count, key in enumerate(key_list): 150 | dat = classifier_acc[classifier_acc['key'] == key] 151 | x = dat['dest'].tolist() 152 | y = dat['accuracy'].tolist() 153 | col.plot(x, y, label=key, marker=markers[count], markersize=7, linewidth=0.4, linestyle=(0, (1, 10))) 154 | 155 | if idx==0: col.set_ylabel('Test Accuracy') 156 | col.set_xlabel('Training Methods') 157 | col.set_ylim(0.5, 1.01) 158 | col.set_yticks(np.arange(0.5, 1.01, 0.1)) 159 | col.title.set_text(classifier_str[idx]) 160 | 161 | 162 | handles, labels = col.get_legend_handles_labels() 163 | figs.legend(handles, labels, loc='center right') 164 | 165 | plt.savefig('imgs/classifier_acc.pdf', bbox_inches='tight') -------------------------------------------------------------------------------- /frameworks/centaur/plot/8_data_selection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import FuncFormatter 6 | import matplotlib.cm as cm 7 | 8 | from plot_utils import y_fmt, load_acc_pure, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 9 | from plot_utils import load_all_acc 10 | from adjustText import adjust_text 11 | 12 | SMALL_SIZE = 7.5 13 | MEDIUM_SIZE = 9 14 | BIGGER_SIZE = 10 15 | 16 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 17 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 18 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 19 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 20 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 21 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 22 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 23 | 24 | encoder = 'mobilenet' 25 | classifier = 'medium' 26 | res_folder = 'results' 27 | buff_file = 'results/ds_plot_buff.csv' 28 | 29 | 30 | if not os.path.isfile(buff_file): 31 | # Load accuracy 32 | df_acc = load_all_acc(res_folder, focus='flpa') 33 | # keep records with new data selection args 34 | df_acc = df_acc.loc[~((df_acc['alpha'] == '5') & (df_acc['beta'] == '3') & (df_acc['gamma'] == '0'))] 35 | df_acc['iot_latency'] = 0 36 | df_acc['ap_latency'] = 0 37 | df_acc['iot_energy'] = 0 38 | df_acc['ap_energy'] = 0 39 | 40 | # Load cost 41 | a = ['1.0', '3.0', '5.0', '10.0'] 42 | b = ['10.0', '5.0', '3.0', '1.0'] 43 | g = ['0.0', '1.0', '3.0', '5.0'] 44 | 45 | ab = ['a'+a+'_b'+b for a,b in zip(a,b)] 46 | filedirs = [] 47 | for i in range(len(ab)): 48 | for j in range(len(g)): 49 | filedir = res_folder + '/acc_dyn_1cifar10_2mobi_3medium_c8|100_e3_' + ab[i] + '_g' + g[j] + '_flpa1000.csv' 50 | filedirs.append(filedir) 51 | filedirs.append(res_folder + '/acc_ap_1cifar10_2mobi_3medium_c8|100_e3_a5.0_b3.0_g0.0_flpa1000.csv') 52 | filedirs.append(res_folder + 
'/acc_iot_1cifar10_2mobi_3medium_c8|100_e3_a5.0_b3.0_g0.0_flpa1000.csv') 53 | 54 | a_list = np.repeat(a, 4).tolist() + ['5.0'] * 2 55 | b_list = np.repeat(b, 4).tolist() + ['3.0'] * 2 56 | g_list = g*4 + ['0.0'] * 2 57 | 58 | # load configuration from cfg file 59 | computation_cfg, communication_cfg = load_res_from_cfg() 60 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 61 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 62 | 63 | for idx in range(len(filedirs)): 64 | # if os.path.isfile(filedir): 65 | filedir = filedirs[idx] 66 | alpha = a_list[idx] 67 | beta = b_list[idx] 68 | gamma = g_list[idx] 69 | 70 | # load and compute the latency and energy cost 71 | post_arg = filedir.replace('.csv','').replace(res_folder + '/acc_','') 72 | dest = post_arg[:3].replace('_','') 73 | post_arg = post_arg.replace('ap_','').replace('iot_','').replace('dyn_','') 74 | 75 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, dests=dest) 76 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size, dests=dest) 77 | 78 | sum_x = [] 79 | for k in range(2, 6): 80 | if mac_dyn is not None: sum_x.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 81 | if mac_ap is not None: sum_x.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 82 | if mac_iot is not None: sum_x.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 83 | 84 | # put the cost into the table 85 | bool = (df_acc['dest']==dest) & (df_acc['alpha']==alpha) & (df_acc['beta']==beta) & (df_acc['gamma']==gamma) 86 | df_acc.loc[bool, 'iot_latency'] = max(sum_x[0]) 87 | df_acc.loc[bool, 'ap_latency'] = max(sum_x[1]) 88 | df_acc.loc[bool, 'iot_energy'] = max(sum_x[2]) 89 | df_acc.loc[bool, 'ap_energy'] = max(sum_x[3]) 90 | 91 | df_acc['key'] = df_acc['alpha'] + ','+df_acc['beta'] +','+ df_acc['gamma'] 92 | df_acc['key'] = df_acc['key'].str.replace('\.0','') 93 | 94 | df_acc.to_csv('results/ds_plot_buff.csv') 95 | 96 | else: 97 | print('loading existing buff !!') 98 | df_acc = pd.read_csv(buff_file, sep=',') 99 | 100 | 101 | # plot 102 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(5,3)) 103 | plt.subplots_adjust(left=0.1, right=0.9, wspace=0.35, bottom=0.15, top=0.90) 104 | 105 | x0 = df_acc['accuracy'].tolist() 106 | ln = len(x0) 107 | x = x0[1:ln-1] 108 | 109 | txts0 = df_acc['key'].tolist() 110 | txts = txts0[1:ln-1] 111 | 112 | #import pdb; pdb.set_trace() 113 | 114 | for idx in range(len(axes)): 115 | col = axes[idx] 116 | if idx == 0: 117 | y10 = df_acc['iot_latency'].tolist() 118 | y20 = df_acc['iot_energy'].tolist() 119 | elif idx == 1: 120 | y10 = df_acc['ap_latency'].tolist() 121 | y20 = df_acc['ap_energy'].tolist() 122 | 123 | y1 = y10[1:ln-1] 124 | y2 = y20[1:ln-1] 125 | 126 | colors = cm.rainbow(np.linspace(0, 1, len(x))) 127 | col.scatter(x, y1, s=8, color=colors) 128 | 129 | if y10[0] > 1000: col.scatter([x0[0]], [y10[0]], s=20, color='black',marker=6) 130 | if y10[ln-1] > 1000: col.scatter([x0[ln-1]], [y10[ln-1]], s=20, color='black',marker=7) 131 | 132 | texts = [col.text(x[i], y1[i], txts[i], size=6) for i in range(len(x))] # ha='center', va='center' 133 | adjust_text(texts, ax=col) 134 | # adjust_text(texts, ax=col, arrowprops=dict(arrowstyle="-", color='k', lw=0.5), avoid_text=True) # avoid_self=True, 135 | col.invert_xaxis() 136 | 
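    # y_fmt comes from plot_utils (not shown in this listing) and is passed to
    # FuncFormatter below to compact large tick values. As an illustration only
    # -- the real plot_utils.y_fmt may differ -- such a formatter could look like:
    #
    #     def y_fmt(y, pos):
    #         for factor, suffix in [(1e9, 'G'), (1e6, 'M'), (1e3, 'K')]:
    #             if abs(y) >= factor:
    #                 return '{:g}{}'.format(y / factor, suffix)
    #         return '{:g}'.format(y)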
col.set_xlim(0.93, 0.8) 137 | #col.set_xticks(np.arange(0.92, 0.80, 0.02)) 138 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 139 | col.set_xlabel('Test Accuracy') 140 | if idx==0: col.set_title('Workload on TEDs') 141 | if idx==0: col.set_ylabel('Accumulated Latency (Seconds)') 142 | 143 | col_tx = col.twinx() 144 | col_tx.scatter(x, y2, s=0.01, color='white', marker='') 145 | col_tx.set_xlim(0.93, 0.8) 146 | #col.set_xticks(np.arange(0.92, 0.80, 0.02)) 147 | col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 148 | col_tx.set_xlabel('Test Accuracy') 149 | if idx==1: col_tx.set_title('Workload on APs') 150 | if idx==1: col_tx.set_ylabel('Accumulated Energy (Joule)') 151 | 152 | 153 | plt.savefig('imgs/data_selection_args.pdf') 154 | 155 | 156 | -------------------------------------------------------------------------------- /frameworks/centaur/plot/0_sensor_acc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | from matplotlib.ticker import FuncFormatter 10 | 11 | # from plot_utils import y_fmt, load_acc_pure, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 12 | 13 | plt.rcParams.update({'font.size': 9}) 14 | 15 | def load_accs(post_arg, res_folder): 16 | filedir = res_folder + "/" + "acc_ap" + post_arg + ".csv" 17 | dat_full = pd.read_csv(filedir, sep=',') 18 | 19 | filedir = res_folder + "/" + "acc_iot" + post_arg + ".csv" 20 | dat_clas = pd.read_csv(filedir, sep=',') 21 | 22 | filedir = res_folder + "/" + "acc_dyn" + post_arg + ".csv" 23 | dat_dyn = pd.read_csv(filedir, sep=',') 24 | 25 | acc_ap = dat_full['accuracy'].tolist()[1:] 26 | acc_iot = dat_clas['accuracy'].tolist()[1:] 27 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 28 | return acc_ap, acc_iot, acc_dyn 29 | 30 | 31 | def load_acc_one(post_arg, res_folder): 32 | post_arg_small = '1cifar10_2mobi_3small_c8|100_e3_a5_b3_g0' 33 | post_arg_medium = '1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0' 34 | post_arg_large = '1cifar10_2mobi_3large_c8|100_e3_a5_b3_g0' 35 | post_args = [post_arg_small, post_arg_medium, post_arg_large] 36 | 37 | acc_ap, acc_iot, acc_dyn = [], [], [] 38 | for post_arg in post_args: 39 | 40 | filedir = res_folder + "/" + "acc_ap_" + post_arg + ".csv" 41 | dat_full = pd.read_csv(filedir, sep=',') 42 | acc_ap.append(dat_full['accuracy'].tolist()[1:]) 43 | 44 | filedir = res_folder + "/" + "acc_iot_" + post_arg + ".csv" 45 | dat_clas = pd.read_csv(filedir, sep=',') 46 | acc_iot.append(dat_clas['accuracy'].tolist()[1:]) 47 | 48 | filedir = res_folder + "/" + "acc_dyn_" + post_arg + ".csv" 49 | dat_dyn = pd.read_csv(filedir, sep=',') 50 | acc_dyn.append(dat_dyn['accuracy'].tolist()[1:]) 51 | 52 | acc_ap = [(i+j+k)/3 for i,j,k in zip(acc_ap[0], acc_ap[1], acc_ap[2])] 53 | acc_iot = [(i+j+k)/3 for i,j,k in zip(acc_iot[0], acc_iot[1], acc_iot[2])] 54 | acc_dyn = [(i+j+k)/3 for i,j,k in zip(acc_dyn[0], acc_dyn[1], acc_dyn[2])] 55 | 56 | return acc_ap, acc_iot, acc_dyn 57 | 58 | 59 | lebel_dyn = 'Centaur' 60 | res_folder = 'results' 61 | encoders = ['sencoder_p', 'sencoder_u', 'sencoder_m'] 62 | classifier = 'medium' 63 | 64 | # buff_file = 'results/intro_plot_buff.csv' 65 | 66 | post_args = ['_1pamap_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 67 | post_args = post_args+ ['_1uci_har_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 68 | post_args = post_args+ 
['_1motion_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0']
69 | post_args = post_args+ ['pt_1motion_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0']
70 |
71 | # if not os.path.isfile(buff_file):
72 | if True:
73 |     # load configuration from cfg file
74 |
75 |     for encoder, post_arg in zip(encoders, post_args):
76 |         print(f"loading {post_arg}")
77 |         acc_ap, acc_iot, acc_dyn = load_accs(post_arg, res_folder)
78 |         print()
    # note: the accuracies returned by load_accs() are loaded and then
    # discarded here, and zip() against the three-element encoders list
    # silently drops the fourth ('pt_'-prefixed) entry of post_args.
79 |
80 |
81 | #figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,3))
82 | figs, col = plt.subplots(figsize=(5,3))
83 | plt.subplots_adjust(left=0.14, right=0.96, bottom=0.18, top=0.95)
84 |
85 |
86 | lw = 1
87 | markers = ["d", "v", "s", "^", "o", "*"]
88 | anno_text_size = 8
89 |
90 | #####################
91 | # col = axes[0]
# note: x, y1, iot_idx and dyn_idx are not defined anywhere in this script;
# it appears to expect the accuracy/latency summary built as in 0_intro.py /
# 0_intro_s.py (accuracy in x, latency in y1, plus per-method row indices).
# As listed, the code below raises a NameError.
92 | x_iot = [x[i] for i in iot_idx]
93 | y1_iot = [y1[i] for i in iot_idx]
94 | x_dyn = [x[i] for i in dyn_idx]
95 | y1_dyn = [y1[i] for i in dyn_idx]
96 | index = list(range(4))
97 | pointsize1 = 8
98 | pointsize2 = 7
99 |
100 | for idx in range(3):
101 |     # p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker=markers[idx], linestyle='--')
102 |     # p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker=markers[idx], linestyle='--')
103 |     # if idx >= 1:
104 |     #     col.plot([x_iot[idx-1],x_iot[idx]], [y1_iot[idx-1],y1_iot[idx]], 'y', linestyle='--', linewidth=lw)
105 |     #     col.plot([x_dyn[idx-1],x_dyn[idx]], [y1_dyn[idx-1],y1_dyn[idx]], 'r', linestyle='--', linewidth=lw)
106 |
107 |     p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker='v', markersize=pointsize1, linestyle='')
108 |     p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker='s', markersize=pointsize2, linestyle='')
109 |     col.plot([x_iot[idx], x_dyn[idx]], [y1_iot[idx], y1_dyn[idx]], 'black', linestyle=(0, (4, 4)), linewidth=lw)
110 |     if idx == 0:
111 |         col.text(x_iot[idx]*0.96, y1_iot[idx]+5000, "PAMAP2", size=anno_text_size)
112 |     elif idx == 1:
113 |         col.text(x_iot[idx]*0.92, y1_iot[idx]+5000, "UCI_HAR", size=anno_text_size)
114 |     else:
115 |         col.text(x_iot[idx]*0.96, y1_iot[idx]+5000, "MotionSense", size=anno_text_size)
116 |
117 |
118 | #### averaged ####
119 | # text_ = 'avg. ~10% higher accuracy'
120 | # col.annotate('', xy=(0.925, 170000), xycoords='data',
121 | #              xytext=(0.625, 170000), textcoords='data',
122 | #              arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-'))
123 | # col.annotate(text_, xy=(0.75, 180000), xycoords='data', ha='center', color='b', size=9)
124 |
125 | # text_ = 'avg.\n~58%\nlower\nlatency'
126 | # col.annotate('', xy=(0.945, 20000), xycoords='data',
127 | #              xytext=(0.945, 260000), textcoords='data',
128 | #              arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-'))
129 | # col.annotate(text_, xy=(0.972, 180000), xycoords='data', ha='center', color='b', size=9)
130 |
131 |
132 |
133 | # text_ = 'Better'
134 | # col.annotate('', xy=(0.4, 0), xycoords='data',
135 | #              xytext=(0.4, -50000), textcoords='data',
136 | #              arrowprops=dict(arrowstyle="->, head_width=0.3", color='black', lw=2.5, ls='-'))
137 | # col.annotate(text_, xy=(0.4, -50000), xycoords='data', ha='center', color='black', size=12)
138 |
139 |
140 |
141 | col.set_xlabel('Test Accuracy')
142 | col.set_xlim(0.55, 1)
143 | col.set_xticks(np.arange(0.50, 1.01, 0.05))
144 | col.set_ylim(0, 120000)
145 |
146 | col.set_ylabel('Latency (Seconds)')
147 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt))
148 |
149 | # col.legend()
150 | col.legend((p1, p2), ('UCD Training', lebel_dyn), loc='upper left')  #, ncol=2, fontsize=8, scatterpoints=0.1, markerscale=0.01
151 |
152 | # col_tx = col.twinx()
153 |
154 | # y1_iot = [y1_energy[i] for i in iot_idx]
155 | # y1_dyn = [y1_energy[i] for i in dyn_idx]
156 |
157 | # for idx in range(4):
158 | #     col_tx.plot(x_iot[idx], y1_iot[idx], 'y', linestyle='')
159 | #     col_tx.plot(x_dyn[idx], y1_dyn[idx], 'r', linestyle='')
160 |
161 | # col_tx.set_ylabel('Energy Consumption (Joule)')
162 | # col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt))
163 |
164 | plt.tight_layout()
165 |
166 | plt.savefig('imgs/intro_sum.pdf', bbox_inches='tight')
167 |
168 |
169 | #import pdb; pdb.set_trace()
170 |
171 |
172 | # print gain
173 | cost_down = [(iot-dyn)/iot for iot, dyn in zip(y1_iot, y1_dyn)]
174 | acc_up = [(dyn-iot)/iot for iot, dyn in zip(x_iot, x_dyn)]
175 |
176 | #
177 | print(f"encoders: {encoders}")
178 | print(f"acc up: {acc_up}")
179 | print(f"cost down: {cost_down}")
180 | print(sum(cost_down)/3)  # averaged over the three HAR datasets
181 | print(sum(acc_up)/3)
--------------------------------------------------------------------------------
/frameworks/centaur/make_script.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | run_all_accuracy_tests = 1
5 | run_unbelanced_tests = 1
6 | run_data_selection_arg_tests = 1
7 | run_client_number_arg_tests = 1
8 | run_mobility_tests = 1
9 | run_fed_prox_tests = 1
10 | run_vanilla_tests = 1
11 |
12 | # shell script for testing encoders, classifiers, and datasets
13 | GPUs = '7,6,5,3,2,1,4,0'
14 | num_gpus = 8
15 | num_cpus = 50
16 | gpus_per_client = 0.25
17 | fc = 8
18 |
19 |
20 |
21 | if run_all_accuracy_tests == 1:
22 |     destine = ['ap','iot','dyn']
23 |     encoders = ['mobilenet', 'shufflenet', 'mnasnet', 'efficientnet']
24 |     classifiers = ['medium', 'small', 'large']
25 |     datasets = ['cifar10', 'cifar100', 'emnist']
26 |
27 |     with open('run_ecd.sh', 'w') as f:
28 |         f.write('')
29 |
30 |     for c in classifiers:
31 |         for d in datasets:
32 |             for e in encoders:
33 |                 for dest in destine:
34 |                     with open('run_ecd.sh', 'a') as f:
35 |                         f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --destine={} --dataset={} --encoder={} 
--classifier={} 2>&1 | tee log/{}_{}_{}_{}.txt\n".format(GPUs,dest,d,e,c,dest,d,e,c)) 36 | 37 | print("Done with creating [run_all_accuracy_tests] scripts.") 38 | 39 | 40 | if run_unbelanced_tests == 1: 41 | # shell script for testing unbelanced data 42 | destine = ['ap','iot','dyn'] 43 | fl_partitioning_alpha = ['0.001', '0.01', '0.1'] 44 | classifier = 'medium' 45 | 46 | with open('run_flpa.sh', 'w') as f: 47 | f.write('') 48 | 49 | for pa in fl_partitioning_alpha: 50 | for dest in destine: 51 | with open('run_flpa.sh', 'a') as f: 52 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --destine={} --dataset=cifar10 --encoder=mobilenet --classifier={} --fl_partitioning_alpha={} 2>&1 | tee log/{}_flpa{}.txt\n".format(GPUs,dest,classifier,pa,dest,pa)) 53 | 54 | print("Done with creating [run_unbelanced_tests] scripts.") 55 | 56 | 57 | if run_data_selection_arg_tests == 1: 58 | # shell script for testing data selection arguments 59 | alpha = ['1', '3', '5', '10'] 60 | beta = ['1', '3', '5', '10'] 61 | beta.reverse() 62 | gamma = ['0', '1', '3', '5'] 63 | classifier = 'medium' 64 | 65 | with open('run_ds.sh', 'w') as f: 66 | f.write('') 67 | 68 | for a,b in zip(alpha, beta): 69 | for g in gamma: 70 | with open('run_ds.sh', 'a') as f: 71 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --destine=dyn --dataset=cifar10 --encoder=mobilenet --classifier={} --alpha={} --beta={} --gamma={} 2>&1 | tee log/dyn_a{}_b{}_g{}.txt\n".format(GPUs,classifier,a,b,g,a,b,g)) 72 | 73 | print("Done with creating [run_data_selection_arg_tests] scripts.") 74 | 75 | 76 | if run_client_number_arg_tests == 1: 77 | # shell script for testing the number of clients 78 | classifier = 'medium' 79 | destine = ['ap','iot','dyn'] 80 | min_fit_clients_ratio = [0.1, 0.2, 0.5] 81 | min_available_clients = [10, 100, 1000] 82 | 83 | with open('run_client.sh', 'w') as f: 84 | f.write('') 85 | 86 | for ac in min_available_clients: 87 | for fc_ratio in min_fit_clients_ratio: 88 | min_fit_clients = int(fc_ratio*ac) 89 | if min_fit_clients == 1: 90 | continue 91 | 92 | num_client_gpus = 0.25 93 | num_client_cpus = 2 94 | 95 | # print("{}={}/{}".format(num_client_gpus,num_gpus,min_fit_clients)) 96 | 97 | for dest in destine: 98 | with open('run_client.sh', 'a') as f: 99 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --classifier={} --destine={} --dataset=cifar10 --encoder=mobilenet --num_client_gpus={} --num_client_cpus={} --min_available_clients={} --min_fit_clients={} 2>&1 | tee log/{}_fc{}_ac{}.txt\n".format(GPUs,classifier,dest,num_client_gpus,num_client_cpus,ac,min_fit_clients,dest,min_fit_clients,ac)) 100 | 101 | print("Done with creating [run_client_number_arg_tests] scripts.") 102 | 103 | 104 | if run_mobility_tests == 1: 105 | classifier = 'medium' 106 | destine = ['ap','iot','dyn'] 107 | mrate_mins = [0.1, 0.4, 0.7] 108 | mrate_maxs = [0.4, 0.7, 1.0] 109 | 110 | with open('run_mobility.sh', 'w') as f: 111 | f.write('') 112 | 113 | for mrate_min, mrate_max in zip(mrate_mins, mrate_maxs): 114 | for dest in destine: 115 | with open('run_mobility.sh', 'a') as f: 116 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --classifier={} --destine={} --dataset=cifar10 --encoder=mobilenet --mrate_min={} --mrate_max={} 2>&1 | tee log/{}_mrs{}_mrl{}.txt\n".format(GPUs,classifier,dest,mrate_min,mrate_max,dest,mrate_min,mrate_max)) 117 | 118 | print("Done with creating [run_mobility_tests] scripts.") 119 | 120 | 121 | if run_fed_prox_tests == 1: 122 | 123 | 
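    # FedProx (Li et al., MLSys 2020) augments each client's local objective
    # with a proximal term, typically
    #     local_loss = task_loss + (mu / 2) * ||w - w_global||^2,
    # so mu controls how far local weights may drift from the global model.
    # A minimal PyTorch-style sketch (assuming global_params holds parameter
    # tensors captured before local training; not the repo's actual code):
    #     prox = sum(((w - w0) ** 2).sum()
    #                for w, w0 in zip(model.parameters(), global_params))
    #     loss = task_loss + 0.5 * mu * prox
    # The sweep below covers mu over four orders of magnitude, each jittered
    # by factors 0.97-1.03 to probe sensitivity around those values.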
fed_prox_mu = [1, 0.1, 0.01, 0.001] 124 | fed_prox_mu_97 = [0.97*i for i in fed_prox_mu] 125 | fed_prox_mu_98 = [0.98*i for i in fed_prox_mu] 126 | fed_prox_mu_99 = [0.99*i for i in fed_prox_mu] 127 | fed_prox_mu_101 = [1.01*i for i in fed_prox_mu] 128 | fed_prox_mu_102 = [1.02*i for i in fed_prox_mu] 129 | fed_prox_mu_103 = [1.03*i for i in fed_prox_mu] 130 | 131 | fed_prox_mu = fed_prox_mu_101 + fed_prox_mu_102 + fed_prox_mu_103 + fed_prox_mu_97 + fed_prox_mu_98 + fed_prox_mu_99 132 | 133 | with open('run_fedprox.sh', 'w') as f: 134 | f.write('') 135 | 136 | for mu in fed_prox_mu: 137 | with open('run_fedprox.sh', 'a') as f: 138 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --destine=dyn --dataset=cifar10 --encoder=mobilenet --classifier=medium --fed_prox_mu={} 2>&1 | tee log/mu{}.txt\n".format(GPUs,str(mu),str(mu))) 139 | 140 | print("Done with creating [run_fed_prox_tests] scripts.") 141 | 142 | 143 | if run_vanilla_tests == 1: 144 | destine = ['ap'] 145 | encoders = ['mobilenet', 'shufflenet', 'mnasnet', 'efficientnet'] 146 | classifiers = ['large'] 147 | datasets = ['cifar10', 'cifar100', 'emnist'] 148 | 149 | with open('run_vanilla.sh', 'w') as f: 150 | f.write('') 151 | 152 | for d in datasets: 153 | for e in encoders: 154 | for clas in classifiers: 155 | for dest in destine: 156 | with open('run_vanilla.sh', 'a') as f: 157 | if e == 'efficientnet': 158 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --vanilla=True --destine={} --dataset={} --encoder={} --classifier=large --base_rate=1 --sample_size=1 --num_rounds=200 --alpha=1000 --beta=0.001 --min_fit_clients=5 --num_client_gpus=1 2>&1 | tee log/vanilla_{}_{}_{}_small.txt\n".format(GPUs,dest,d,e,dest,d,e)) 159 | else: 160 | f.write("CUDA_VISIBLE_DEVICES={} python main.py --save_results=True --vanilla=True --destine={} --dataset={} --encoder={} --classifier={} --base_rate=1 --sample_size=1 --num_rounds=200 --alpha=1000 --beta=0.001 --min_fit_clients={} 2>&1 | tee log/vanilla_{}_{}_{}_{}.txt\n".format(GPUs,dest,d,e,clas,fc,dest,d,e,clas)) 161 | 162 | print("Done with creating [run_vanilla_tests] scripts.") 163 | 164 | 165 | print("ALL DONE.") -------------------------------------------------------------------------------- /frameworks/centaur/running_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict, Callable, Optional, Tuple, List 3 | import flwr as fl 4 | from flwr.common.typing import Scalar 5 | import torch 6 | import torchvision 7 | from hiefed.client import get_params, set_params, RayClient, init_network, test 8 | 9 | import os 10 | import numpy as np 11 | 12 | ### 13 | # This file is used for setting up the experiments 14 | ### 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser(description="Flower Simulation of Centaur with PyTorch") 18 | 19 | parser.add_argument("--num_client_cpus", type=float, default=4, help="the number of CPUs for each client") 20 | parser.add_argument("--num_client_gpus", type=float, default=1, help="the number of GPUs for each client") 21 | parser.add_argument("--num_rounds", type=int, default=100, help="the total number of FL training rounds") 22 | parser.add_argument("--num_epochs", type=int, default=3, help="the number of basic local epochs") 23 | parser.add_argument("--batch_size", type=int, default=64, help="the size of local batch") 24 | parser.add_argument("--sample_size", type=int, default=30, help="the average size of one local sample ()in KB") 25 
26 |     parser.add_argument("--min_fit_clients", type=int, default=8, help="the number of participating clients in each round")
27 |     parser.add_argument("--min_available_clients", type=int, default=20, help="the total number of clients")
28 |
29 |     parser.add_argument("--dataset", type=str, default="pamap", help="the dataset to use: [cifar10], [cifar100], [emnist], [pamap], [uci_har], or [motion]")
30 |     parser.add_argument("--encoder", type=str, default="sencoder_p", help="the encoder to use: [mobilenet], [efficientnet], [shufflenet], [mnasnet], [sencoder_p], [sencoder_u], or [sencoder_m]")
31 |     parser.add_argument("--classifier", type=str, default="medium", help="the classifier size to use: [small], [medium], or [large]")
     # argparse's type=bool treats any non-empty string as True (so
     # "--debug=False" would parse as True); the boolean flags below therefore
     # use an explicit string-to-bool converter.
     def str2bool(v):
         return str(v).lower() in ("true", "1", "yes")
32 |     parser.add_argument("--pre_trained", type=str2bool, default=False, help="whether the encoder's weights are pre-trained [True] or [False]")
33 |
34 |     parser.add_argument("--debug", type=str2bool, default=False, help="debug mode [True] or [False]")
35 |     parser.add_argument("--vanilla", type=str2bool, default=False, help="whether to start vanilla training mode [True] or [False]")
36 |     parser.add_argument("--save_results", type=str2bool, default=True, help="whether to save results [True] or [False]")
37 |     parser.add_argument("--save_models", type=str2bool, default=False, help="whether to save the best trained model [True] or [False]")
38 |
39 |     parser.add_argument("--fed_prox_mu", type=float, default=0, help="the mu parameter in FedProx")
40 |     parser.add_argument("--mrate_min", type=float, default=0, help="the lowest connection rate of the range in the mobility model")
41 |     parser.add_argument("--mrate_max", type=float, default=0, help="the highest connection rate of the range in the mobility model")
42 |
43 |     parser.add_argument("--fl_partitioning_alpha", type=float, default=1000, help="the LDA alpha parameter used to partition the dataset in FL")
44 |     parser.add_argument("--alpha", type=float, default=5, help="alpha value in data selection (for dropping; the smaller this number is, the larger the portion of data it gets)")
45 |     parser.add_argument("--beta", type=float, default=3, help="beta value in data selection (for the access point)")
46 |     parser.add_argument("--gamma", type=float, default=0, help="gamma value in data selection (for norm-based selection)")
47 |
48 |     parser.add_argument("--keep_localtraining", type=str2bool, default=True, help="keep training going when the IoT device is offline [True] or [False]")
49 |     parser.add_argument("--base_rate", type=float, default=.5, help="the base rate of local data used for training")
50 |     parser.add_argument("--step_rate", type=float, default=.2, help="the step rate of local data used for training when offline")
51 |
52 |     parser.add_argument("--destine", type=str, default="dyn", help="[dyn]=our Centaur training; [ap]=access point training; [iot]=ultra-constrained edge device training")
53 |     parser.add_argument("--client_configs", type=str, default="client_configurations.csv", help="file directory of the client configuration")
54 |
55 |     return parser
56 |
57 |
58 |
59 | def fit_config(server_round: int) -> Dict[str, Scalar]:
60 |     """
61 |     Used by the server to pass new configuration values to the clients at
62 |     each round. It is called by the chosen strategy and must return a
63 |     dictionary of configuration key/value pairs.
64 | """ 65 | args = get_parser().parse_args() 66 | config = { 67 | "server_round": str(server_round), 68 | "epochs": str(args.num_epochs), 69 | "batch_size": str(args.batch_size), 70 | "sample_size": str(args.sample_size), 71 | "alpha": str(args.alpha), 72 | "beta": str(args.beta), 73 | "gamma": str(args.gamma), 74 | "destine": str(args.destine), 75 | "dataset": str(args.dataset), 76 | "fed_prox_mu": str(args.fed_prox_mu), 77 | "mrate_min": str(args.mrate_min), 78 | "mrate_max": str(args.mrate_max), 79 | "keep_localtraining": str(args.keep_localtraining), 80 | "classifier": str(args.classifier), 81 | "base_rate": str(args.base_rate), 82 | "step_rate": str(args.step_rate), 83 | "debug": str(args.debug), 84 | } 85 | return config 86 | 87 | 88 | 89 | def get_evaluate_fn( 90 | testset, saver, num_classes, 91 | ) -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: 92 | """ 93 | All built-in strategies support centralized evaluation by providing an evaluation function 94 | during initialization. An evaluation function is any function that can take 95 | the current global model parameters as input and return evaluation results. 96 | """ 97 | parser = get_parser() 98 | args = get_parser().parse_args() 99 | def evaluate( 100 | server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] 101 | ) -> Optional[Tuple[float, float]]: 102 | """Use the entire test set for evaluation.""" 103 | 104 | # determine device 105 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 106 | 107 | model = init_network(encoder=args.encoder, classifier=args.classifier, num_classes=num_classes, pre_trained=args.pre_trained) 108 | set_params(model, parameters) 109 | model.to(device) 110 | 111 | testloader = torch.utils.data.DataLoader(testset, batch_size=50) 112 | loss, accuracy = test(model, testloader, device=device) 113 | 114 | # save the accuracy 115 | saver.accuracy_saver(server_round, accuracy, parser.parse_args().destine) 116 | 117 | if args.save_models: 118 | f_dir = "./saved_models/"+args.dataset+"/" 119 | if os.path.exists(f_dir): 120 | best_acc = np.load(f_dir+"best_acc.npy") 121 | if accuracy > best_acc: 122 | np.save(f_dir+"best_acc.npy", accuracy) 123 | torch.save(model.state_dict(), f_dir+"best_model.pt") 124 | print("##### Saved Accuracy:", accuracy, best_acc) 125 | else: 126 | os.makedirs(f_dir) 127 | np.save(f_dir+"best_acc.npy", accuracy) 128 | 129 | 130 | # return statistics 131 | return loss, {"accuracy": accuracy} 132 | 133 | return evaluate -------------------------------------------------------------------------------- /frameworks/centaur/plot/0_intro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | from matplotlib.ticker import FuncFormatter 7 | 8 | from plot_utils import y_fmt, load_acc_pure, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 9 | 10 | plt.rcParams.update({'font.size': 9}) 11 | 12 | 13 | lebel_dyn = 'Centaur' 14 | res_folder = 'results' 15 | encoders = ['efficientnet', 'mobilenet', 'shufflenet', 'mnasnet'] 16 | classifier = 'medium' 17 | 18 | buff_file = 'results/intro_plot_buff.csv' 19 | 20 | post_args = ['1cifar10_2' + e[:4] + '_3medium_c8|100_e3_a5_b3_g0' for e in encoders[1:]] 21 | post_args = ['1cifar10_2effi_3medium_c5|100_e3_a5_b3_g0'] + post_args 22 | 23 | 24 | if not os.path.isfile(buff_file): 25 | # load configuration from cfg file 26 | 
computation_cfg, communication_cfg = load_res_from_cfg() 27 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 28 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 29 | 30 | 31 | key_e, key_d, x, y1, y1_energy, y2 = [], [], [], [], [], [] 32 | for encoder, post_arg in zip(encoders, post_args): 33 | print(f"loading {post_arg}") 34 | acc_ap, acc_iot, acc_dyn = load_acc_pure(post_arg, res_folder) 35 | 36 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 37 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 38 | 39 | sum_dyn, sum_ap, sum_iot = [], [], [] 40 | for k in range(2, len(mac_dyn)): 41 | sum_dyn.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 42 | sum_ap.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 43 | sum_iot.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 44 | 45 | key_e = key_e + [encoder] * 3 46 | key_d = key_d + ['ap', 'iot', 'dyn'] 47 | x = x + [max(acc_ap), max(acc_iot), max(acc_dyn)] 48 | y1 = y1 + [max(sum_ap[0]), max(sum_iot[0]), max(sum_dyn[0])] 49 | y1_energy = y1_energy + [max(sum_ap[2]), max(sum_iot[2]), max(sum_dyn[2])] 50 | y2 = y2 + [max(sum_ap[1]), max(sum_iot[1]), max(sum_dyn[1])] 51 | 52 | # buffer 53 | dat_ = {'key_e': key_e, 54 | 'key_d': key_d, 55 | 'x': x, 56 | 'y1': y1, 57 | 'y1_energy': y1_energy, 58 | 'y2': y2 59 | } 60 | df_acc = pd.DataFrame(data=dat_) 61 | df_acc.to_csv(buff_file) 62 | 63 | else: 64 | print('loading existing buff !') 65 | df_acc = pd.read_csv(buff_file, sep=',') 66 | 67 | # import pdb; pdb.set_trace() 68 | 69 | 70 | 71 | key_e = df_acc['key_e'] 72 | key_d = df_acc['key_d'] 73 | x = df_acc['x'] 74 | y1 = df_acc['y1'] 75 | y1_energy = df_acc['y1_energy'] 76 | y2 = df_acc['y2'] 77 | 78 | ap_idx = [i for i, j in enumerate(key_d) if j == 'ap'] 79 | iot_idx = [i for i, j in enumerate(key_d) if j == 'iot'] 80 | dyn_idx = [i for i, j in enumerate(key_d) if j == 'dyn'] 81 | 82 | #figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,3)) 83 | figs, col = plt.subplots(figsize=(5,3)) 84 | plt.subplots_adjust(left=0.14, right=0.96, bottom=0.18, top=0.95) 85 | 86 | 87 | lw = 1 88 | markers = ["d", "v", "s", "^", "o", "*"] 89 | anno_text_size = 8 90 | 91 | ##################### 92 | # col = axes[0] 93 | x_iot = [x[i] for i in iot_idx] 94 | y1_iot = [y1[i] for i in iot_idx] 95 | x_dyn = [x[i] for i in dyn_idx] 96 | y1_dyn = [y1[i] for i in dyn_idx] 97 | index = list(range(4)) 98 | 99 | pointsize1 = 8 100 | pointsize2 = 7 101 | 102 | for idx in range(4): 103 | # p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker=markers[idx], linestyle='--') 104 | # p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker=markers[idx], linestyle='--') 105 | # if idx >= 1: 106 | # col.plot([x_iot[idx-1],x_iot[idx]], [y1_iot[idx-1],y1_iot[idx]], 'y', linestyle='--', linewidth=lw) 107 | # col.plot([x_dyn[idx-1],x_dyn[idx]], [y1_dyn[idx-1],y1_dyn[idx]], 'r', linestyle='--', linewidth=lw) 108 | 109 | p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker='v', markersize= pointsize1, linestyle='') 110 | p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker='s', markersize= pointsize2, linestyle='') 111 | col.plot([x_iot[idx], x_dyn[idx]], [y1_iot[idx], y1_dyn[idx]], 'black', linestyle=(0, (4, 
4)), linewidth=lw) 112 | 113 | col.text(x_iot[idx]*0.96, y1_iot[idx]+15000, encoders[idx], size=anno_text_size) 114 | 115 | 116 | #### avergaged #### 117 | # text_ = 'avg. ~10% higher accuracy' 118 | # col.annotate('', xy=(0.925, 170000), xycoords='data', 119 | # xytext=(0.625, 170000), textcoords='data', 120 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-')) 121 | # col.annotate(text_, xy=(0.75, 180000), xycoords='data', ha='center', color='b', size=9) 122 | 123 | # text_ = 'avg.\n~58%\nlower\nlatency' 124 | # col.annotate('', xy=(0.945, 20000), xycoords='data', 125 | # xytext=(0.945, 260000), textcoords='data', 126 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5,ls='-')) 127 | # col.annotate(text_, xy=(0.972, 180000), xycoords='data', ha='center', color='b', size=9) 128 | 129 | 130 | #### single #### 131 | text_ = '16.5% higher accuracy' 132 | col.annotate('', xy=(0.925, 100000), xycoords='data', 133 | xytext=(0.785, 100000), textcoords='data', 134 | arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-')) 135 | col.annotate(text_, xy=(0.825, 110000), xycoords='data', ha='center', color='b', size=9) 136 | 137 | text_ = '59.8%\nlower\nlatency' 138 | col.annotate('', xy=(0.93, 105000), xycoords='data', 139 | xytext=(0.93, 260000), textcoords='data', 140 | arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5,ls='-')) 141 | col.annotate(text_, xy=(0.96, 180000), xycoords='data', ha='center', color='b', size=9) 142 | 143 | 144 | 145 | # text_ = 'Better' 146 | # col.annotate('', xy=(0.4, 0), xycoords='data', 147 | # xytext=(0.4, -50000), textcoords='data', 148 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='black', lw=2.5,ls='-')) 149 | # col.annotate(text_, xy=(0.4, -50000), xycoords='data', ha='center', color='black', size=12) 150 | 151 | 152 | 153 | col.set_xlabel('Test Accuracy') 154 | col.set_xlim(0.55, 1) 155 | col.set_xticks(np.arange(0.55, 1.01, 0.05)) 156 | col.set_ylim(0, 300000) 157 | 158 | col.set_ylabel('Latency (Seconds)') 159 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 160 | 161 | # col.legend() 162 | col.legend((p1, p2), ('UCD Training', lebel_dyn), loc='upper left') #, , ncol=2, fontsize=8) , scatterpoints=0.1, markerscale=0.01 163 | 164 | # col_tx = col.twinx() 165 | 166 | # y1_iot = [y1_energy[i] for i in iot_idx] 167 | # y1_dyn = [y1_energy[i] for i in dyn_idx] 168 | 169 | # for idx in range(4): 170 | # col_tx.plot(x_iot[idx], y1_iot[idx], 'y', linestyle='') 171 | # col_tx.plot(x_dyn[idx], y1_dyn[idx], 'r', linestyle='') 172 | 173 | # col_tx.set_ylabel('Energy Consumption (Joule)') 174 | # col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 175 | 176 | plt.savefig('imgs/intro_sum.pdf') 177 | 178 | 179 | #import pdb; pdb.set_trace() 180 | 181 | 182 | # print gain 183 | cost_down = [(iot-dyn)/iot for iot, dyn in zip(y1_iot, y1_dyn)] 184 | acc_up = [(dyn-iot)/iot for iot, dyn in zip(x_iot, x_dyn)] 185 | 186 | # 187 | print(f"encoders: {encoders}") 188 | print(f"acc up: {acc_up}") 189 | print(f"cost down: {cost_down}") 190 | print(sum(cost_down)/4) 191 | print(sum(acc_up)/4) -------------------------------------------------------------------------------- /frameworks/centaur/plot/0_intro_s.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | from matplotlib.ticker import 
FuncFormatter 10 | 11 | from plot_utils import y_fmt, load_acc_pure, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 12 | 13 | plt.rcParams.update({'font.size': 9}) 14 | 15 | 16 | lebel_dyn = 'Centaur' 17 | res_folder = 'results' 18 | encoders = ['sencoder_p', 'sencoder_u', 'sencoder_m'] 19 | classifier = 'medium' 20 | 21 | buff_file = 'results/intro_plot_buff.csv' 22 | 23 | # post_args = ['1pamap_2senc_3large_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 24 | # post_args = post_args+ ['1uci_har_2senc_3large_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 25 | # post_args = post_args+ ['1motion_2senc_3large_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 26 | post_args = ['1pamap_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 27 | post_args = post_args+ ['1uci_har_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 28 | post_args = post_args+ ['1motion_2senc_3medium_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 29 | # post_args = ['1pamap_2senc_3small_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 30 | # post_args = post_args+ ['1uci_har_2senc_3small_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 31 | # post_args = post_args+ ['1motion_2senc_3small_c8|20_e3_a5_b3_g0_flpa1000_mr0|0'] 32 | 33 | # post_args = ['1pamap_2effi_3medium_c5|100_e3_a5_b3_g0'] + post_args 34 | 35 | # if not os.path.isfile(buff_file): 36 | if True: 37 | # load configuration from cfg file 38 | computation_cfg, communication_cfg = load_res_from_cfg() 39 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 40 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 41 | 42 | 43 | key_e, key_d, x, y1, y1_energy, y2 = [], [], [], [], [], [] 44 | for encoder, post_arg in zip(encoders, post_args): 45 | print(f"loading {post_arg}") 46 | acc_ap, acc_iot, acc_dyn = load_acc_pure(post_arg, res_folder) 47 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 48 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 49 | 50 | sum_dyn, sum_ap, sum_iot = [], [], [] 51 | for k in range(2, len(mac_dyn)): 52 | sum_dyn.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 53 | sum_ap.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 54 | sum_iot.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 55 | 56 | key_e = key_e + [encoder] * 3 57 | key_d = key_d + ['ap', 'iot', 'dyn'] 58 | x = x + [max(acc_ap), max(acc_iot), max(acc_dyn)] 59 | y1 = y1 + [max(sum_ap[0]), max(sum_iot[0]), max(sum_dyn[0])] 60 | y1_energy = y1_energy + [max(sum_ap[2]), max(sum_iot[2]), max(sum_dyn[2])] 61 | y2 = y2 + [max(sum_ap[1]), max(sum_iot[1]), max(sum_dyn[1])] 62 | 63 | # buffer 64 | dat_ = {'key_e': key_e, 65 | 'key_d': key_d, 66 | 'x': x, 67 | 'y1': y1, 68 | 'y1_energy': y1_energy, 69 | 'y2': y2 70 | } 71 | df_acc = pd.DataFrame(data=dat_) 72 | df_acc.to_csv(buff_file) 73 | 74 | else: 75 | print('loading existing buff !') 76 | df_acc = pd.read_csv(buff_file, sep=',') 77 | 78 | # import pdb; pdb.set_trace() 79 | 80 | key_e = df_acc['key_e'] 81 | key_d = df_acc['key_d'] 82 | x = df_acc['x'] 83 | y1 = df_acc['y1'] 84 | y1_energy = df_acc['y1_energy'] 85 | y2 = df_acc['y2'] 86 | 87 | ap_idx = [i for i, j in enumerate(key_d) if j == 'ap'] 88 | iot_idx = [i for i, j in enumerate(key_d) if j == 'iot'] 89 | dyn_idx = [i for i, j in enumerate(key_d) if j == 'dyn'] 90 | 91 | #figs, axes = plt.subplots(nrows=1, ncols=2, 
figsize=(8,3)) 92 | figs, col = plt.subplots(figsize=(5,3)) 93 | plt.subplots_adjust(left=0.14, right=0.96, bottom=0.18, top=0.95) 94 | 95 | 96 | #### single #### 97 | text_ = '54.5% higher accuracy' 98 | col.annotate('', xy=(0.84, 27000), xycoords='data', 99 | xytext=(0.54, 27000), textcoords='data', 100 | arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-')) 101 | col.annotate(text_, xy=(0.66, 18000), xycoords='data', ha='center', color='b', size=9) 102 | 103 | text_ = '59.3%\nlower\nlatency' 104 | col.annotate('', xy=(0.545, 28000), xycoords='data', 105 | xytext=(0.545, 65000), textcoords='data', 106 | arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5,ls='-')) 107 | col.annotate(text_, xy=(0.578, 39000), xycoords='data', ha='center', color='b', size=9) 108 | 109 | 110 | 111 | 112 | 113 | 114 | lw = 1 115 | markers = ["d", "v", "s", "^", "o", "*"] 116 | anno_text_size = 8 117 | 118 | ##################### 119 | # col = axes[0] 120 | x_iot = [x[i] for i in iot_idx] 121 | y1_iot = [y1[i] for i in iot_idx] 122 | x_dyn = [x[i] for i in dyn_idx] 123 | y1_dyn = [y1[i] for i in dyn_idx] 124 | index = list(range(4)) 125 | pointsize1 = 8 126 | pointsize2 = 7 127 | 128 | for idx in range(3): 129 | # p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker=markers[idx], linestyle='--') 130 | # p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker=markers[idx], linestyle='--') 131 | # if idx >= 1: 132 | # col.plot([x_iot[idx-1],x_iot[idx]], [y1_iot[idx-1],y1_iot[idx]], 'y', linestyle='--', linewidth=lw) 133 | # col.plot([x_dyn[idx-1],x_dyn[idx]], [y1_dyn[idx-1],y1_dyn[idx]], 'r', linestyle='--', linewidth=lw) 134 | 135 | p1, = col.plot(x_iot[idx], y1_iot[idx], 'y', label='TED training', marker='v', markersize= pointsize1, linestyle='') 136 | p2, = col.plot(x_dyn[idx], y1_dyn[idx], 'r', label=lebel_dyn, marker='s', markersize= pointsize2, linestyle='') 137 | col.plot([x_iot[idx], x_dyn[idx]], [y1_iot[idx], y1_dyn[idx]], 'black', linestyle=(0, (4, 4)), linewidth=lw) 138 | if idx ==0: 139 | col.text(x_iot[idx]*0.96, y1_iot[idx]+5000, "PAMAP2", size=anno_text_size) 140 | elif idx == 1: 141 | col.text(x_iot[idx]*0.92, y1_iot[idx]+5000, "UCIHAR", size=anno_text_size) 142 | else : 143 | col.text(x_iot[idx]*0.96, y1_iot[idx]+5000, "MotionSense", size=anno_text_size) 144 | 145 | 146 | #### avergaged #### 147 | # text_ = 'avg. 
~10% higher accuracy' 148 | # col.annotate('', xy=(0.925, 170000), xycoords='data', 149 | # xytext=(0.625, 170000), textcoords='data', 150 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5, ls='-')) 151 | # col.annotate(text_, xy=(0.75, 180000), xycoords='data', ha='center', color='b', size=9) 152 | 153 | # text_ = 'avg.\n~58%\nlower\nlatency' 154 | # col.annotate('', xy=(0.945, 20000), xycoords='data', 155 | # xytext=(0.945, 260000), textcoords='data', 156 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='b', lw=2.5,ls='-')) 157 | # col.annotate(text_, xy=(0.972, 180000), xycoords='data', ha='center', color='b', size=9) 158 | 159 | 160 | 161 | # text_ = 'Better' 162 | # col.annotate('', xy=(0.4, 0), xycoords='data', 163 | # xytext=(0.4, -50000), textcoords='data', 164 | # arrowprops=dict(arrowstyle="->, head_width=0.3", color='black', lw=2.5,ls='-')) 165 | # col.annotate(text_, xy=(0.4, -50000), xycoords='data', ha='center', color='black', size=12) 166 | 167 | 168 | 169 | col.set_xlabel('Test Accuracy') 170 | col.set_xlim(0.55, 1) 171 | col.set_xticks(np.arange(0.50, 1.01, 0.05)) 172 | col.set_ylim(0, 120000) 173 | 174 | col.set_ylabel('Latency (Seconds)') 175 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 176 | 177 | # col.legend() 178 | col.legend((p1, p2), ('UCD Training', lebel_dyn), loc='upper left') #, , ncol=2, fontsize=8) , scatterpoints=0.1, markerscale=0.01 179 | 180 | # col_tx = col.twinx() 181 | 182 | # y1_iot = [y1_energy[i] for i in iot_idx] 183 | # y1_dyn = [y1_energy[i] for i in dyn_idx] 184 | 185 | # for idx in range(4): 186 | # col_tx.plot(x_iot[idx], y1_iot[idx], 'y', linestyle='') 187 | # col_tx.plot(x_dyn[idx], y1_dyn[idx], 'r', linestyle='') 188 | 189 | # col_tx.set_ylabel('Energy Consumption (Joule)') 190 | # col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 191 | 192 | plt.tight_layout() 193 | 194 | plt.savefig('imgs/intro_sum.pdf', bbox_inches='tight') 195 | 196 | 197 | #import pdb; pdb.set_trace() 198 | 199 | 200 | # print gain 201 | cost_down = [(iot-dyn)/iot for iot, dyn in zip(y1_iot, y1_dyn)] 202 | acc_up = [(dyn-iot)/iot for iot, dyn in zip(x_iot, x_dyn)] 203 | 204 | # 205 | print(f"encoders: {encoders}") 206 | print(f"acc up: {acc_up}") 207 | print(f"cost down: {cost_down}") 208 | print(sum(cost_down)/3) 209 | print(sum(acc_up)/3) -------------------------------------------------------------------------------- /frameworks/centaur/plot/7_mobility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import FuncFormatter 6 | import matplotlib.cm as cm 7 | 8 | from plot_utils import y_fmt, linestyle_tuple, load_acc_pure, load_all_acc, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 9 | 10 | from adjustText import adjust_text 11 | 12 | SMALL_SIZE = 7.5 13 | MEDIUM_SIZE = 9 14 | BIGGER_SIZE = 10 15 | 16 | plt.rc('font', size=SMALL_SIZE) # controls default text sizes 17 | plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title 18 | plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels 19 | plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels 20 | plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels 21 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 22 | plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title 23 | 24 | encoder = 'mobilenet' 25 | classifier = 'medium' 
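# The mobility experiments sweep the client connection rate (lambda) over
# three ranges -- 0.1-0.4, 0.4-0.7 and 0.7-1.0 -- and each run's key is
# rendered as "$\lambda$=min~max" in the figure annotations below.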
26 | res_folder = 'results' 27 | buff_file = 'results/mrate_plot_buff.csv' 28 | 29 | if not os.path.isfile(buff_file): 30 | # Load accuracy 31 | df_acc = load_all_acc(res_folder, focus='mr') 32 | 33 | df_acc['iot_latency'] = 0 34 | df_acc['ap_latency'] = 0 35 | df_acc['iot_energy'] = 0 36 | df_acc['ap_energy'] = 0 37 | 38 | # Load cost 39 | dests = ['ap', 'iot', 'dyn'] 40 | 41 | mr1_list = ['0.1', '0.4', '0.7'] 42 | mr2_list = ['0.4', '0.7', '1.0'] 43 | post_args = [('1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa1000_mr' + mr1 + '|' + mr2) for mr1, mr2 in zip(mr1_list, mr2_list)] 44 | 45 | df_acc['mr1'] = mr1_list * 3 46 | df_acc['mr2'] = mr2_list * 3 47 | 48 | # load configuration from cfg file 49 | computation_cfg, communication_cfg = load_res_from_cfg() 50 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 51 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 52 | 53 | for dest in dests: 54 | for post_arg, mr1, mr2 in zip(post_args, mr1_list, mr2_list): 55 | # load and compute the latency and energy cost 56 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, dests=dest) 57 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size, dests=dest) 58 | 59 | sum_x = [] 60 | for k in range(2, 6): 61 | if mac_dyn is not None: sum_x.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 62 | if mac_ap is not None: sum_x.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 63 | if mac_iot is not None: sum_x.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 64 | 65 | # put the cost into the table 66 | bool_ = (df_acc['dest']==dest) & (df_acc['mr1']==mr1) & (df_acc['mr2']==mr2) 67 | df_acc.loc[bool_, 'iot_latency'] = max(sum_x[0]) 68 | df_acc.loc[bool_, 'ap_latency'] = max(sum_x[1]) 69 | df_acc.loc[bool_, 'iot_energy'] = max(sum_x[2]) 70 | df_acc.loc[bool_, 'ap_energy'] = max(sum_x[3]) 71 | 72 | df_acc['key'] = '$\lambda$=' + df_acc['mr1'] + '~'+df_acc['mr2'] 73 | 74 | df_acc.to_csv(buff_file) 75 | 76 | else: 77 | print('loading existing buff !') 78 | df_acc = pd.read_csv(buff_file, sep=',') 79 | 80 | ap_acc_up = [i-j for i,j in zip(df_acc.loc[df_acc['dest'] == 'dyn', 'accuracy'].tolist(), df_acc.loc[df_acc['dest'] == 'ap', 'accuracy'].tolist())] 81 | iot_acc_up = [i-j for i,j in zip(df_acc.loc[df_acc['dest'] == 'dyn', 'accuracy'].tolist(), df_acc.loc[df_acc['dest'] == 'iot', 'accuracy'].tolist())] 82 | 83 | ap_cost_down = [(j-i)/j for i,j in zip(df_acc.loc[df_acc['dest'] == 'dyn', 'ap_latency'].tolist(), df_acc.loc[df_acc['dest'] == 'ap', 'ap_latency'].tolist())] 84 | iot_cost_down = [(j-i)/j for i,j in zip(df_acc.loc[df_acc['dest'] == 'dyn', 'iot_latency'].tolist(), df_acc.loc[df_acc['dest'] == 'iot', 'iot_latency'].tolist())] 85 | 86 | dat_ = {'key': df_acc['key'].tolist()[:3], 87 | 'ap_acc_up': ap_acc_up, 88 | 'iot_acc_up': iot_acc_up, 89 | 'ap_cost_down': ap_cost_down, 90 | 'iot_cost_down': iot_cost_down, 91 | } 92 | print(pd.DataFrame(data = dat_)) 93 | 94 | 95 | 96 | # plot 97 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(5,3)) 98 | plt.subplots_adjust(left=0.11, right=0.89, wspace=0.4, bottom=0.15, top=0.9) 99 | 100 | x = df_acc['accuracy'].tolist() 101 | len_ = len(x) 102 | txt = df_acc['key'].tolist() 103 | dest = df_acc['dest'].tolist() 104 | 105 | anno_text_size = 6 106 | point_size = 30 
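# plot_arrow (defined below) draws a thin dashed segment from each baseline
# point (AP or UCD training) to its Centaur counterpart, so the reader can
# trace how accuracy and accumulated cost shift within each lambda range.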
107 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) 108 | 109 | 110 | def plot_arrow(ax, x_, y_, x_dyn, y_dyn): 111 | # x_, y_ -> x_dyn, y_dyn 112 | for i in range(len(x_)): 113 | xpos = (x_[i]+x_dyn[i])/2 114 | ypos = (y_[i]+y_dyn[i])/2 115 | xdir = x_dyn[i] - x_[i] 116 | ydir = y_dyn[i] - y_[i] 117 | 118 | # X,Y,dX,dY = xpos, ypos, xdir, ydir 119 | # ax.annotate("", xytext=(xpos,ypos),xy=(xpos+0.001*xdir,ypos+0.001*ydir), arrowprops=dict(arrowstyle="->", color='black'), size = 8) 120 | ax.plot([x_[i], x_dyn[i]], [y_[i], y_dyn[i]], color='black', linewidth=0.3, linestyle=(0, (5, 10))) 121 | 122 | 123 | for idx in range(len(axes)): 124 | col = axes[idx] 125 | if idx == 0: 126 | y1 = df_acc['iot_latency'].tolist() 127 | y2 = df_acc['iot_energy'].tolist() 128 | elif idx == 1: 129 | y1 = df_acc['ap_latency'].tolist() 130 | y2 = df_acc['ap_energy'].tolist() 131 | 132 | # import pdb; pdb.set_trace() 133 | 134 | idx_ap = [i for i, x in enumerate(dest) if x == "ap"] 135 | idx_iot = [i for i, x in enumerate(dest) if x == "iot"] 136 | idx_dyn = [i for i, x in enumerate(dest) if x == "dyn"] 137 | 138 | x_dyn = [x[i] for i in idx_dyn] 139 | y_dyn = [y1[i] for i in idx_dyn] 140 | if max(y_dyn) > 1000: 141 | plt1 = col.scatter(x_dyn, y_dyn, s=point_size, color='r', marker='s') 142 | #texts = [col.text(x[i], y1[i], txt[i], size=6) for i in idx_iot] 143 | #adjust_text(texts, ax=col) 144 | 145 | x_ap = [x[i] for i in idx_ap] 146 | y_ap = [y1[i] for i in idx_ap] 147 | if max(y_ap) > 1000: 148 | plot_arrow(col, x_ap, y_ap, x_dyn, y_dyn) 149 | plt2 = col.scatter(x_ap, y_ap, s=point_size, color='b', marker='d') 150 | #texts = [col.text(x[i], y1[i], txt[i], size=6) for i in idx_ap] 151 | #adjust_text(texts, ax=col, ha='center',va='center', autoalign=False, expand_points=(1.5,2)) 152 | col.text(0.81, 1520, r'$\lambda$=0.1~0.4', size=anno_text_size, bbox=props) 153 | col.text(0.803, 1250, r'$\lambda$=0.4~0.7', size=anno_text_size, bbox=props) 154 | col.text(0.835, y1[2], r'$\lambda$=0.7~1.0', size=anno_text_size, bbox=props) 155 | 156 | col.legend([plt2, plt1], ['AP Training','Centaur']) 157 | 158 | 159 | x_iot = [x[i] for i in idx_iot] 160 | y_iot = [y1[i] for i in idx_iot] 161 | if max(y_iot) > 1000: 162 | plot_arrow(col, x_iot, y_iot, x_dyn, y_dyn) 163 | plt3 = col.scatter(x_iot, y_iot, s=point_size, color='y', marker='<') 164 | #texts = [col.text(x[i], y1[i], txt[i], size=6) for i in idx_iot] 165 | #adjust_text(texts, ax=col) 166 | col.text(0.80, 95000, r'$\lambda$=0.1~0.4', size=anno_text_size, bbox=props) 167 | col.text(0.81, 67500, r'$\lambda$=0.4~0.7', size=anno_text_size, bbox=props) 168 | col.text(0.795, 35000, r'$\lambda$=0.7~1.0', size=anno_text_size, bbox=props) 169 | 170 | col.legend([plt3, plt1], ['UCD Training','Centaur'],loc="center right", bbox_to_anchor=(1,0.76)) 171 | 172 | 173 | 174 | # adjust_text(texts, ax=col, arrowprops=dict(arrowstyle="-", color='k', lw=0.5), avoid_text=True) # avoid_self=True, 175 | # col.invert_xaxis() 176 | col.set_xlim(0.79, 0.91) 177 | #col.set_xticks(np.arange(0.92, 0.80, 0.02)) 178 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 179 | col.set_xlabel('Test Accuracy') 180 | if idx==0: col.set_title('Workload on UCDs') 181 | if idx==0: col.set_ylabel('Accumulated Latency (Seconds)') 182 | 183 | 184 | col_tx = col.twinx() 185 | 186 | x_ = [x[i] for i in idx_ap] 187 | y_ = [y2[i] for i in idx_ap] 188 | if max(y_) > 1000: col_tx.scatter(x_, y_, s=1, color='b',marker='') 189 | 190 | x_ = [x[i] for i in idx_iot] 191 | y_ = [y2[i] for i in idx_iot] 
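    # The twin-axis scatters in this block use size-1, empty-marker points:
    # nothing visible is drawn; they only force the right-hand energy axis to
    # adopt the matching data range. The `> 1000` guard appears to skip
    # series whose cost never reaches the plotted magnitude.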
192 | if max(y_) > 1000: col_tx.scatter(x_, y_, s=1, color='y',marker='') 193 | 194 | x_ = [x[i] for i in idx_dyn] 195 | y_ = [y2[i] for i in idx_dyn] 196 | if max(y_) > 1000: col_tx.scatter(x_, y_, s=1, color='r',marker='') 197 | 198 | #idx_iot = [i for i, x in enumerate(dest) if x == "iot"] 199 | #col.scatter(x[idx_iot], y1[idx_iot], s=8, colors='y') 200 | 201 | 202 | #col_tx.set_xlim(0.79, 0.91) 203 | #col.set_xticks(np.arange(0.92, 0.80, 0.02)) 204 | col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 205 | col_tx.set_xlabel('Test Accuracy') 206 | if idx==1: col_tx.set_title('Workload on APs') 207 | if idx==1: col_tx.set_ylabel('Accumulated Energy (Joule)') 208 | 209 | plt.savefig('imgs/mobility.pdf') -------------------------------------------------------------------------------- /frameworks/centaur/plot/3_cost_acc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import FuncFormatter 6 | 7 | from plot_utils import y_fmt, load_acc_one, load_mac_latency_energy, load_comm_latency_energy, load_res_from_cfg 8 | 9 | plt.rcParams.update({'font.size': 8}) 10 | 11 | parser = argparse.ArgumentParser(description="Plot") 12 | parser.add_argument("--mac", type=bool, default=True) 13 | parser.add_argument("--comm", type=bool, default=True) 14 | parser.add_argument("--sum", type=bool, default=False) 15 | 16 | args = parser.parse_args() 17 | 18 | lebel_pbfl = 'Centaur' 19 | 20 | res_folder = 'results' 21 | encoder = 'mobilenet' 22 | classifier = 'medium' 23 | post_arg = '1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0' 24 | 25 | 26 | def plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False): 27 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(9,2.5)) 28 | plt.subplots_adjust(left=0.1, right=0.85, wspace=1, bottom=0.15) 29 | pi = 0 30 | 31 | min_value = min(x_dyn + x_ap + x_iot) 32 | min_value = round(min_value, 1) 33 | max_value = max(x_dyn + x_ap + x_iot) 34 | 35 | lebel_ap = 'AP Training' 36 | lebel_iot = 'UCD training' 37 | lebel_dyn = lebel_pbfl 38 | linewidth = 1 39 | 40 | for col in axes: 41 | if not max(plot_dat_ap[pi]) == 0: 42 | col.plot(x_ap, plot_dat_ap[pi], 'g', label=lebel_ap, linewidth=linewidth) 43 | if not max(plot_dat_iot[pi]) == 0: 44 | col.plot(x_iot, plot_dat_iot[pi], 'y', label=lebel_iot, linewidth=linewidth) 45 | if not max(plot_dat_dyn[pi]) == 0: 46 | col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=lebel_dyn, linewidth=linewidth) 47 | col.set_title(title_text[pi]) 48 | col.set_xlabel(x_text[pi]) 49 | col.set_ylabel(y_text[pi]) 50 | 51 | 52 | if xinverse == True: col.invert_xaxis() 53 | if xlim_vec is not None: col.set_xlim(xlim_vec[0], xlim_vec[1]) 54 | if ylim_vec is not None: col.set_ylim(ylim_vec[0], ylim_vec[1]) 55 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 56 | 57 | col.legend() 58 | 59 | col_tx = col.twinx() 60 | col_tx2 = col.twinx() 61 | 62 | if not max(plot_dat_ap[pi+2]) == 0: 63 | col_tx.plot(x_ap, plot_dat_ap[pi+2], 'g', label=lebel_ap, linewidth=linewidth, linestyle='') 64 | if not max(plot_dat_iot[pi+2]) == 0: 65 | col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=lebel_iot, linewidth=linewidth, linestyle='') 66 | if not max(plot_dat_dyn[pi+2]) == 0: 67 | col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=lebel_dyn, linewidth=linewidth, linestyle='') 68 | col_tx.set_ylabel(y_text[pi+2]) 69 | 
col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 70 | 71 | if not max(plot_dat_ap[pi+4]) == 0: 72 | col_tx2.plot(x_ap, plot_dat_ap[pi+4], 'g', label=lebel_ap, linewidth=linewidth, linestyle='') 73 | if not max(plot_dat_iot[pi+4]) == 0: 74 | col_tx2.plot(x_iot, plot_dat_iot[pi+4], 'y', label=lebel_iot, linewidth=linewidth, linestyle='') 75 | if not max(plot_dat_dyn[pi+4]) == 0: 76 | col_tx2.plot(x_dyn, plot_dat_dyn[pi+4], 'r', label=lebel_dyn, linewidth=linewidth, linestyle='') 77 | col_tx2.set_ylabel(y_text[pi+4]) 78 | col_tx2.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 79 | 80 | col_tx2.spines['right'].set_position(('outward', 45)) 81 | 82 | pi += 1 83 | 84 | plt.savefig('imgs/' + filename + '.pdf') 85 | 86 | 87 | def plot_12_2y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False): 88 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(5,2.5)) 89 | plt.subplots_adjust(left=0.1, right=0.9, wspace=0.45, bottom=0.15) 90 | 91 | pi = 0 92 | 93 | min_value = min(x_dyn + x_ap + x_iot) 94 | min_value = round(min_value, 1) 95 | if min_value < 0.35: min_value = 0.35 96 | max_value = max(x_dyn + x_ap + x_iot) 97 | 98 | title_text = ['Workload on UCDs', 'Workload on APs'] 99 | lebel_ap = 'AP Training' 100 | lebel_iot = 'UCD training' 101 | lebel_dyn = lebel_pbfl 102 | 103 | linewidth = 1 104 | 105 | for index, col in enumerate(axes): 106 | if not max(plot_dat_ap[pi]) <= 1000: 107 | col.plot(x_ap, plot_dat_ap[pi], 'g', label=lebel_ap, linewidth=linewidth) 108 | if not max(plot_dat_iot[pi]) <= 1000: 109 | col.plot(x_iot, plot_dat_iot[pi], 'y', label=lebel_iot, linewidth=linewidth) 110 | if not max(plot_dat_dyn[pi]) <= 1000: 111 | col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=lebel_dyn, linewidth=linewidth) 112 | col.set_title(title_text[pi]) 113 | col.set_xlabel(x_text[pi]) 114 | if index ==0: 115 | text_ = 'acc. 
up=5.78%\ncost down=56.57%' 116 | col.set_ylabel(y_text[pi]) 117 | col.annotate('', xy=(0.88, 27500), xycoords='data', 118 | xytext=(0.82, 60500), textcoords='data', 119 | arrowprops=dict(arrowstyle="<->"))#connectionstyle="bar", 120 | #ec="k", shrinkA=2, shrinkB=2)) 121 | col.annotate(text_, xy=(0.86, 44000), xycoords='data', 122 | xytext=(-80, -30), textcoords='offset points', 123 | arrowprops=dict(arrowstyle="->")) 124 | 125 | if xinverse == True: col.invert_xaxis() 126 | #if pi > 0: col.set_xlim(min_value, max_value+0.05) 127 | col.set_xlim(xlim_vec[0], xlim_vec[1]) 128 | if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi]) 129 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 130 | 131 | col_tx = col.twinx() 132 | 133 | if not max(plot_dat_ap[pi+2]) <= 1000: 134 | col_tx.plot(x_ap, plot_dat_ap[pi+2], 'limegreen', label=lebel_ap, linewidth=linewidth, linestyle='') 135 | if not max(plot_dat_iot[pi+2]) <= 1000: 136 | col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=lebel_iot, linewidth=linewidth, linestyle='') 137 | if not max(plot_dat_dyn[pi+2]) <= 1000: 138 | col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=lebel_dyn, linewidth=linewidth, linestyle='') 139 | #if pi > 0: 140 | # col_tx.set_ylabel(y_text[pi+2], color='limegreen') 141 | #else: 142 | if index ==1: col_tx.set_ylabel(y_text[pi+2]) 143 | 144 | col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 145 | 146 | #if index ==1: 147 | col.legend() 148 | 149 | pi += 1 150 | 151 | plt.savefig('imgs/' + filename + '.pdf') 152 | 153 | 154 | 155 | # load configuration from cfg file 156 | computation_cfg, communication_cfg = load_res_from_cfg() 157 | iot_freq, ap_freq, count_inference, iot_energy_mac, ap_energy_mac = computation_cfg 158 | up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm = communication_cfg 159 | 160 | 161 | if args.mac: 162 | acc_ap, acc_iot, acc_dyn = load_acc_one(post_arg, res_folder) 163 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 164 | 165 | # plot mac 166 | title_text = ['Computation workload on UCDs', 'Computation workload on APs'] 167 | x_text = ['Test Accuracy'] * 2 168 | y_text = ['Accumulated MAC operations (Million)'] * 2 169 | y_text += ['Accumulated Latency (Seconds)'] * 2 170 | y_text += ['Accumulated Energy (Joule)'] * 2 171 | xlim_vec, ylim_vec = [0.35, 0.9], None 172 | filename = 'mac_latency_energy_accuracy' 173 | 174 | plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False) 175 | 176 | 177 | if args.comm: 178 | acc_ap, acc_iot, acc_dyn = load_acc_one(post_arg, res_folder) 179 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 180 | 181 | # plot mac 182 | title_text = ['Communication workload on UCDs', 'Communication workload on APs'] 183 | x_text = ['Test Accuracy'] * 2 184 | y_text = ['Accumulated Comm. 
Amount (Megabytes)'] * 2 185 | y_text += ['Accumulated Latency (Seconds)'] * 2 186 | y_text += ['Accumulated Energy (Joule)'] * 2 187 | xlim_vec, ylim_vec = [0.35, 0.9], None 188 | filename = 'comm_latency_energy_accuracy' 189 | 190 | plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False) 191 | 192 | 193 | if args.sum: 194 | acc_ap, acc_iot, acc_dyn = load_acc_one(post_arg, res_folder) 195 | 196 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(encoder,classifier, post_arg, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 197 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(encoder,classifier, post_arg, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 198 | 199 | sum_dyn, sum_ap, sum_iot = [], [], [] 200 | for k in range(2, len(mac_dyn)): 201 | sum_dyn.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 202 | sum_ap.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 203 | sum_iot.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 204 | 205 | # print gain 206 | cost_down = (max(sum_iot[2]) - max(sum_dyn[2])) / max(sum_iot[2]) 207 | acc_up = max(acc_dyn) - max(acc_iot) 208 | print(f"cost down: {cost_down}; accuracy up: {acc_up}") 209 | 210 | x_text = ['Test Accuracy'] * 2 211 | y_text = ['Accumulated Latency (Seconds)'] * 2 212 | y_text += ['Accumulated Energy (Joule)'] * 2 213 | 214 | xlim_vec, ylim_vec = [0.35, 0.9], None 215 | 216 | filename = 'sum_latency_energy_accuracy' 217 | plot_12_2y(sum_dyn, sum_ap, sum_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False) -------------------------------------------------------------------------------- /frameworks/centaur/third_party/autograd_hacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Library for extracting interesting quantites from autograd, see README.md 3 | 4 | Not thread-safe because of module-level variables 5 | 6 | Notation: 7 | o: number of output classes (exact Hessian), number of Hessian samples (sampled Hessian) 8 | n: batch-size 9 | do: output dimension (output channels for convolution) 10 | di: input dimension (input channels for convolution) 11 | Hi: per-example Hessian of matmul, shaped as matrix of [dim, dim], indices have been row-vectorized 12 | Hi_bias: per-example Hessian of bias 13 | Oh, Ow: output height, output width (convolution) 14 | Kh, Kw: kernel height, kernel width (convolution) 15 | 16 | Jb: batch output Jacobian of matmul, output sensitivity for example,class pair, [o, n, ....] 17 | Jb_bias: as above, but for bias 18 | 19 | A, activations: inputs into current layer 20 | B, backprops: backprop values (aka Lop aka Jacobian-vector product) observed at current layer 21 | 22 | """ 23 | 24 | from typing import List 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | _supported_layers = ['Linear', 'Conv2d'] # Supported layer class types 31 | _hooks_disabled: bool = False # work-around for https://github.com/pytorch/pytorch/issues/25723 32 | _enforce_fresh_backprop: bool = False # global switch to catch double backprop errors on Hessian computation 33 | 34 | 35 | def add_hooks(model: nn.Module) -> None: 36 | """ 37 | Adds hooks to model to save activations and backprop values. 38 | 39 | The hooks will 40 | 1. save activations into param.activations during forward pass 41 | 2. 
append backprops to params.backprops_list during backward pass. 42 | 43 | Call "remove_hooks(model)" to disable this. 44 | 45 | Args: 46 | model: 47 | """ 48 | 49 | global _hooks_disabled 50 | _hooks_disabled = False 51 | 52 | handles = [] 53 | for layer in model.modules(): 54 | if _layer_type(layer) in _supported_layers: 55 | handles.append(layer.register_forward_hook(_capture_activations)) 56 | handles.append(layer.register_backward_hook(_capture_backprops)) 57 | #handles.append(layer.register_full_backward_hook(_capture_backprops)) 58 | 59 | model.__dict__.setdefault('autograd_hacks_hooks', []).extend(handles) 60 | 61 | 62 | def remove_hooks(model: nn.Module) -> None: 63 | """ 64 | Remove hooks added by add_hooks(model) 65 | """ 66 | 67 | assert model == 0, "not working, remove this after fix to https://github.com/pytorch/pytorch/issues/25723" 68 | 69 | if not hasattr(model, 'autograd_hacks_hooks'): 70 | print("Warning, asked to remove hooks, but no hooks found") 71 | else: 72 | for handle in model.autograd_hacks_hooks: 73 | handle.remove() 74 | del model.autograd_hacks_hooks 75 | 76 | 77 | def disable_hooks() -> None: 78 | """ 79 | Globally disable all hooks installed by this library. 80 | """ 81 | 82 | global _hooks_disabled 83 | _hooks_disabled = True 84 | 85 | 86 | def enable_hooks() -> None: 87 | """the opposite of disable_hooks()""" 88 | 89 | global _hooks_disabled 90 | _hooks_disabled = False 91 | 92 | 93 | def is_supported(layer: nn.Module) -> bool: 94 | """Check if this layer is supported""" 95 | 96 | return _layer_type(layer) in _supported_layers 97 | 98 | 99 | def _layer_type(layer: nn.Module) -> str: 100 | return layer.__class__.__name__ 101 | 102 | 103 | def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor): 104 | """Save activations into layer.activations in forward pass""" 105 | 106 | if _hooks_disabled: 107 | return 108 | assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen" 109 | setattr(layer, "activations", input[0].detach()) 110 | 111 | 112 | def _capture_backprops(layer: nn.Module, _input, output): 113 | """Append backprop to layer.backprops_list in backward pass.""" 114 | global _enforce_fresh_backprop 115 | 116 | if _hooks_disabled: 117 | return 118 | 119 | if _enforce_fresh_backprop: 120 | assert not hasattr(layer, 'backprops_list'), "Seeing result of previous backprop, use clear_backprops(model) to clear" 121 | _enforce_fresh_backprop = False 122 | 123 | if not hasattr(layer, 'backprops_list'): 124 | setattr(layer, 'backprops_list', []) 125 | layer.backprops_list.append(output[0].detach()) 126 | 127 | 128 | def clear_backprops(model: nn.Module) -> None: 129 | """Delete layer.backprops_list in every layer.""" 130 | for layer in model.modules(): 131 | if hasattr(layer, 'backprops_list'): 132 | del layer.backprops_list 133 | 134 | 135 | def compute_grad1(model: nn.Module, loss_type: str = 'mean') -> None: 136 | """ 137 | Compute per-example gradients and save them under 'param.grad1'. 
Must be called after loss.backprop() 138 | 139 | Args: 140 | model: 141 | loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch 142 | """ 143 | 144 | assert loss_type in ('sum', 'mean') 145 | for layer in model.modules(): 146 | layer_type = _layer_type(layer) 147 | if layer_type not in _supported_layers: 148 | continue 149 | assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)" 150 | assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)" 151 | assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)" 152 | 153 | A = layer.activations 154 | n = A.shape[0] 155 | if loss_type == 'mean': 156 | B = layer.backprops_list[0] * n 157 | else: # loss_type == 'sum': 158 | B = layer.backprops_list[0] 159 | 160 | if layer_type == 'Linear': 161 | setattr(layer.weight, 'grad1', torch.einsum('ni,nj->nij', B, A)) 162 | if layer.bias is not None: 163 | setattr(layer.bias, 'grad1', B) 164 | 165 | elif layer_type == 'Conv2d': 166 | A = torch.nn.functional.unfold(A, layer.kernel_size) 167 | B = B.reshape(n, -1, A.shape[-1]) 168 | grad1 = torch.einsum('ijk,ilk->ijl', B, A) 169 | shape = [n] + list(layer.weight.shape) 170 | setattr(layer.weight, 'grad1', grad1.reshape(shape)) 171 | if layer.bias is not None: 172 | setattr(layer.bias, 'grad1', torch.sum(B, dim=2)) 173 | 174 | 175 | def compute_hess(model: nn.Module,) -> None: 176 | """Save Hessian under param.hess for each param in the model""" 177 | 178 | for layer in model.modules(): 179 | layer_type = _layer_type(layer) 180 | if layer_type not in _supported_layers: 181 | continue 182 | assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)" 183 | assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)" 184 | 185 | if layer_type == 'Linear': 186 | A = layer.activations 187 | B = torch.stack(layer.backprops_list) 188 | 189 | n = A.shape[0] 190 | o = B.shape[0] 191 | 192 | A = torch.stack([A] * o) 193 | Jb = torch.einsum("oni,onj->onij", B, A).reshape(n*o, -1) 194 | H = torch.einsum('ni,nj->ij', Jb, Jb) / n 195 | 196 | setattr(layer.weight, 'hess', H) 197 | 198 | if layer.bias is not None: 199 | setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B)/n) 200 | 201 | elif layer_type == 'Conv2d': 202 | Kh, Kw = layer.kernel_size 203 | di, do = layer.in_channels, layer.out_channels 204 | 205 | A = layer.activations.detach() 206 | A = torch.nn.functional.unfold(A, (Kh, Kw)) # n, di * Kh * Kw, Oh * Ow 207 | n = A.shape[0] 208 | B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list]) # o, n, do, Oh*Ow 209 | o = B.shape[0] 210 | 211 | A = torch.stack([A] * o) # o, n, di * Kh * Kw, Oh*Ow 212 | Jb = torch.einsum('onij,onkj->onik', B, A) # o, n, do, di * Kh * Kw 213 | 214 | Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb) # n, do, di*Kh*Kw, do, di*Kh*Kw 215 | Jb_bias = torch.einsum('onij->oni', B) 216 | Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias) 217 | 218 | setattr(layer.weight, 'hess', Hi.mean(dim=0)) 219 | if layer.bias is not None: 220 | setattr(layer.bias, 'hess', Hi_bias.mean(dim=0)) 221 | 222 | 223 | def backprop_hess(output: torch.Tensor, hess_type: str) -> None: 224 | """ 225 | Call backprop 1 or more times to get values needed for Hessian computation. 
226 | 227 | Args: 228 | output: prediction of neural network (ie, input of nn.CrossEntropyLoss()) 229 | hess_type: type of Hessian propagation, "CrossEntropy" results in exact Hessian for CrossEntropy 230 | 231 | Returns: 232 | 233 | """ 234 | 235 | assert hess_type in ('LeastSquares', 'CrossEntropy') 236 | global _enforce_fresh_backprop 237 | n, o = output.shape 238 | 239 | _enforce_fresh_backprop = True 240 | 241 | if hess_type == 'CrossEntropy': 242 | batch = F.softmax(output, dim=1) 243 | 244 | mask = torch.eye(o).expand(n, o, o) 245 | diag_part = batch.unsqueeze(2).expand(n, o, o) * mask 246 | outer_prod_part = torch.einsum('ij,ik->ijk', batch, batch) 247 | hess = diag_part - outer_prod_part 248 | assert hess.shape == (n, o, o) 249 | 250 | for i in range(n): 251 | hess[i, :, :] = symsqrt(hess[i, :, :]) 252 | hess = hess.transpose(0, 1) 253 | 254 | elif hess_type == 'LeastSquares': 255 | hess = [] 256 | assert len(output.shape) == 2 257 | batch_size, output_size = output.shape 258 | 259 | id_mat = torch.eye(output_size) 260 | for out_idx in range(output_size): 261 | hess.append(torch.stack([id_mat[out_idx]] * batch_size)) 262 | 263 | for o in range(o): 264 | output.backward(hess[o], retain_graph=True) 265 | 266 | 267 | def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32): 268 | """Symmetric square root of a positive semi-definite matrix. 269 | See https://github.com/pytorch/pytorch/issues/25481""" 270 | 271 | s, u = torch.symeig(a, eigenvectors=True) 272 | cond_dict = {torch.float32: 1e3 * 1.1920929e-07, torch.float64: 1E6 * 2.220446049250313e-16} 273 | 274 | if cond in [None, -1]: 275 | cond = cond_dict[dtype] 276 | 277 | above_cutoff = (abs(s) > cond * torch.max(abs(s))) 278 | 279 | psigma_diag = torch.sqrt(s[above_cutoff]) 280 | u = u[:, above_cutoff] 281 | 282 | B = u @ torch.diag(psigma_diag) @ u.t() 283 | if return_rank: 284 | return B, len(psigma_diag) 285 | else: 286 | return B -------------------------------------------------------------------------------- /frameworks/centaur/plot/plot_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from os import listdir 7 | from os.path import isfile, join 8 | 9 | from measures.mac_comm_counter import mac_latency_energy_counter, comm_latency_energy_counter 10 | 11 | linestyle_tuple = { 12 | ('loosely dotted', (0, (1, 10))), 13 | ('dotted', (0, (1, 1))), 14 | ('densely dotted', (0, (1, 1))), 15 | ('long dash with offset', (5, (10, 3))), 16 | ('loosely dashed', (0, (5, 10))), 17 | ('dashed', (0, (5, 5))), 18 | ('densely dashed', (0, (5, 1))), 19 | 20 | ('loosely dashdotted', (0, (3, 10, 1, 10))), 21 | ('dashdotted', (0, (3, 5, 1, 5))), 22 | ('densely dashdotted', (0, (3, 1, 1, 1))), 23 | 24 | ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))), 25 | ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))), 26 | ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))) 27 | } 28 | 29 | 30 | # https://stackoverflow.com/questions/40566413/matplotlib-pyplot-auto-adjust-unit-of-y-axis 31 | def y_fmt(y, pos): 32 | decades = [1e9, 1e6, 1e3, 1e0, 1e-3, 1e-6, 1e-9 ] 33 | suffix = ["G", "M", "K", "" , "m" , "u", "n" ] 34 | if y == 0: 35 | return str(0) 36 | for i, d in enumerate(decades): 37 | if np.abs(y) >= d: 38 | val = round(y/float(d),3) # fix unlimited 0.99999 issue 39 | signf = len(str(val).split(".")[1]) 40 | if signf == 0: 41 | return '{val:d}{suffix}'.format(val=int(val), suffix=suffix[i]) 42 | else: 43 | if signf == 1: 44 | 
if str(val).split(".")[1] == "0": 45 | return '{val:d}{suffix}'.format(val=int(str(round(val))), suffix=suffix[i]) 46 | tx = "{"+"val:.{signf}f".format(signf = signf) +"}{suffix}" 47 | return tx.format(val=val, suffix=suffix[i]) 48 | #return y 49 | return y 50 | 51 | 52 | def load_acc_pure(post_arg, res_folder): 53 | filedir = res_folder + "/" + "acc_ap_" + post_arg + ".csv" 54 | dat_full = pd.read_csv(filedir, sep=',') 55 | 56 | filedir = res_folder + "/" + "acc_iot_" + post_arg + ".csv" 57 | dat_clas = pd.read_csv(filedir, sep=',') 58 | 59 | filedir = res_folder + "/" + "acc_dyn_" + post_arg + ".csv" 60 | dat_dyn = pd.read_csv(filedir, sep=',') 61 | 62 | acc_ap = dat_full['accuracy'].tolist()[1:] 63 | acc_iot = dat_clas['accuracy'].tolist()[1:] 64 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 65 | 66 | return acc_ap, acc_iot, acc_dyn 67 | 68 | 69 | def load_acc_one(post_arg, res_folder): 70 | post_arg_small = '1cifar10_2mobi_3small_c8|100_e3_a5_b3_g0' 71 | post_arg_medium = '1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0' 72 | post_arg_large = '1cifar10_2mobi_3large_c8|100_e3_a5_b3_g0' 73 | post_args = [post_arg_small, post_arg_medium, post_arg_large] 74 | 75 | acc_ap, acc_iot, acc_dyn = [], [], [] 76 | for post_arg in post_args: 77 | 78 | filedir = res_folder + "/" + "acc_ap_" + post_arg + ".csv" 79 | dat_full = pd.read_csv(filedir, sep=',') 80 | acc_ap.append(dat_full['accuracy'].tolist()[1:]) 81 | 82 | filedir = res_folder + "/" + "acc_iot_" + post_arg + ".csv" 83 | dat_clas = pd.read_csv(filedir, sep=',') 84 | acc_iot.append(dat_clas['accuracy'].tolist()[1:]) 85 | 86 | filedir = res_folder + "/" + "acc_dyn_" + post_arg + ".csv" 87 | dat_dyn = pd.read_csv(filedir, sep=',') 88 | acc_dyn.append(dat_dyn['accuracy'].tolist()[1:]) 89 | 90 | acc_ap = [(i+j+k)/3 for i,j,k in zip(acc_ap[0], acc_ap[1], acc_ap[2])] 91 | acc_iot = [(i+j+k)/3 for i,j,k in zip(acc_iot[0], acc_iot[1], acc_iot[2])] 92 | acc_dyn = [(i+j+k)/3 for i,j,k in zip(acc_dyn[0], acc_dyn[1], acc_dyn[2])] 93 | 94 | return acc_ap, acc_iot, acc_dyn 95 | 96 | 97 | def load_all_acc(res_folder, focus='basic'): 98 | onlyfiles = [f for f in listdir(res_folder) if isfile(join(res_folder, f))] 99 | 100 | 101 | acc_ap = [filename for filename in onlyfiles if 'acc_ap_1' in filename] 102 | acc_iot = [filename for filename in onlyfiles if 'acc_iot_1' in filename] 103 | acc_dyn = [filename for filename in onlyfiles if 'acc_dyn_1' in filename] 104 | acc_x = acc_ap + acc_iot + acc_dyn 105 | 106 | if focus == 'vanilla': 107 | acc_vanilla = [filename for filename in onlyfiles if 'acc_apvanilla_1' in filename] 108 | acc_x = acc_vanilla 109 | 110 | splitted_list = [] 111 | for acc_name in acc_x: 112 | # marking based on the filename 113 | splitted = acc_name.split('_') 114 | 115 | if focus == 'basic': 116 | if len(splitted) != 10: continue 117 | splitted = [splitted[1], splitted[2][1:], splitted[3][1:], splitted[4][1:], 118 | splitted[5].split('|')[0][1:], splitted[5].split('|')[1], 119 | splitted[6][1:], splitted[7][1:], splitted[8][1:], splitted[9][1:].replace('.csv','')] 120 | name_base = ['dest','dataset','encoder','classifier','clientin','clientall', 'epoch', 'alpha','beta','gamma'] 121 | elif focus == 'flpa': 122 | if len(splitted) != 11: continue 123 | splitted = [splitted[1], splitted[2][1:], splitted[3][1:], splitted[4][1:], 124 | splitted[5].split('|')[0][1:], splitted[5].split('|')[1], 125 | splitted[6][1:], splitted[7][1:], splitted[8][1:], splitted[9][1:], splitted[10][4:].replace('.csv','')] 126 | name_base = 
['dest','dataset','encoder','classifier','clientin','clientall', 'epoch', 'alpha','beta','gamma','flpa'] 127 | elif focus == 'mr': 128 | if len(splitted) != 12: continue 129 | splitted = [splitted[1], splitted[2][1:], splitted[3][1:], splitted[4][1:], 130 | splitted[5].split('|')[0][1:], splitted[5].split('|')[1], 131 | splitted[6][1:], splitted[7][1:], splitted[8][1:], splitted[9][1:], splitted[10][4:], splitted[11].split('|')[0][2:], splitted[11].split('|')[1].replace('.csv','')] 132 | name_base = ['dest','dataset','encoder','classifier','clientin','clientall', 'epoch', 'alpha','beta','gamma','flpa','mr1','mr2'] 133 | elif focus == 'vanilla': 134 | splitted = [splitted[1], splitted[2][1:], splitted[3][1:], splitted[4][1:], 135 | splitted[5].split('|')[0][1:], splitted[5].split('|')[1], 136 | splitted[6][1:], splitted[7][1:], splitted[8][1:], splitted[9][1:], splitted[10][4:], splitted[11].split('|')[0][2:], splitted[11].split('|')[1].replace('.csv','')] 137 | name_base = ['dest','dataset','encoder','classifier','clientin','clientall', 'epoch', 'alpha','beta','gamma','flpa','mr1','mr2'] 138 | 139 | # read highest accuracy 140 | acc_temp = pd.read_csv(res_folder + '/' + acc_name) 141 | 142 | acc_list = acc_temp['accuracy'].tolist() 143 | acc_list.sort(reverse=True) 144 | max_acc = sum(acc_list[:5])/5 145 | 146 | # loc this accuracy 147 | round_idx = int(acc_temp.loc[acc_temp['accuracy'] == acc_list[0], 'round'].tolist()[0]) 148 | 149 | splitted = splitted + [max_acc] + [round_idx] 150 | splitted_list.append(splitted) 151 | 152 | columns_name = name_base + ['accuracy','loc'] 153 | df_acc = pd.DataFrame(splitted_list, columns=columns_name) 154 | df_acc = df_acc.sort_values(columns_name) 155 | 156 | df_acc.loc[df_acc['encoder'] == 'effi','encoder'] = 'efficient' 157 | df_acc.loc[df_acc['encoder'] == 'mobi','encoder'] = 'mobile' 158 | df_acc.loc[df_acc['encoder'] == 'shuf','encoder'] = 'shuffle' 159 | df_acc.loc[df_acc['dataset'] == 'cifar10','dataset'] = 'cf10' 160 | df_acc.loc[df_acc['dataset'] == 'cifar100','dataset'] = 'cf100' 161 | 162 | df_acc.loc[df_acc['classifier'] == 'small','classifier'] = 'S' 163 | df_acc.loc[df_acc['classifier'] == 'medium','classifier'] = 'M' 164 | df_acc.loc[df_acc['classifier'] == 'large','classifier'] = 'L' 165 | 166 | return df_acc 167 | 168 | 169 | 170 | def running_avg(arr, window_size=3): 171 | i = 0 172 | first = arr[0] 173 | last = arr[len(arr)-1] 174 | 175 | moving_averages = [] 176 | 177 | while i < len(arr) - window_size + 1: 178 | 179 | window = arr[i : i + window_size] 180 | window_average = round(sum(window) / window_size, 2) 181 | moving_averages.append(window_average) 182 | i += 1 183 | 184 | moving_averages = [first] + moving_averages + [last] 185 | return moving_averages 186 | 187 | 188 | 189 | def load_mac_latency_energy(encoder, classifier, post_args, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, dests=['dyn','ap','iot'], count_inference=True): 190 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = None, None, None, None, None, None 191 | 192 | if 'dyn' in dests: 193 | plot_dat_dyn, rounds_dyn = mac_latency_energy_counter(encoder, classifier, post_args, 'dyn', 194 | iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference) 195 | if 'ap' in dests: 196 | plot_dat_ap, rounds_ap = mac_latency_energy_counter(encoder, classifier, post_args, 'ap', 197 | iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference) 198 | if 'iot' in dests: 199 | plot_dat_iot, rounds_iot = 
mac_latency_energy_counter(encoder, classifier, post_args, 'iot', 200 | iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference) 201 | 202 | return plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot 203 | 204 | 205 | 206 | def load_comm_latency_energy(encoder,classifier, post_args, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm,sample_size, dests=['dyn','ap','iot']): 207 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = None, None, None, None, None, None 208 | 209 | if 'dyn' in dests: 210 | plot_dat_dyn, rounds_dyn = comm_latency_energy_counter(encoder,classifier, post_args, 'dyn', 211 | up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 212 | if 'ap' in dests: 213 | plot_dat_ap, rounds_ap = comm_latency_energy_counter(encoder,classifier, post_args, 'ap', 214 | up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 215 | if 'iot' in dests: 216 | plot_dat_iot, rounds_iot = comm_latency_energy_counter(encoder,classifier, post_args, 'iot', 217 | up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 218 | 219 | return plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot 220 | 221 | 222 | 223 | def res(configs, type, feature): 224 | return float(configs[configs['type'] == type][feature]) 225 | 226 | 227 | def load_res_from_cfg(): 228 | 229 | cfg = pd.read_csv('client_configurations.csv', sep=',') 230 | 231 | # plot mac, latency, energy, accuracy 232 | iot_frequency = res(cfg, 'iot', 'frequency(MHz)') 233 | ap_frequency = res(cfg, 'ap', 'frequency(MHz)') 234 | acceleration = res(cfg, 'ap', 'acceleration') # how many times faster the accelerator is than ap CPU 235 | count_inference = True 236 | 237 | iot_energy_mac = iot_frequency * res(cfg, 'iot', 'power(mWpMHz)') # times = milliWatt/Mhz https://www.prnewswire.com/news-releases/ultra-reliable-arm-cortex-m4f-microcontroller-from-maxim-integrated-offers-industrys-lowest-power-consumption-and-smallest-size-for-industrial-healthcare-and-iot-sensor-applications-301080651.html#:~:text=Lowest%20Power%3A%2040%C2%B5W%2FMHz%20of,5mm%2Dx%2D5mm%20TQFN. 238 | ap_energy_mac = ap_frequency * res(cfg, 'ap', 'power(mWpMHz)') # times = milliWatt/MHz 239 | # https://www.researchgate.net/figure/Energy-consumption-for-different-parts-of-the-mobile-phone_tbl1_224248235#:~:text=The%20power%20consumption%20of%20wearable,%5B18%5D.%20... 
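    # Worked example (illustrative, using the values shipped in
    # client_configurations.csv): the iot row gives 100 MHz * 0.05 mW/MHz = 5 mW
    # and the ap row gives 2000 MHz * 1.5 mW/MHz = 3000 mW for the
    # iot_energy_mac / ap_energy_mac terms above, which are later multiplied
    # by accumulated latency to estimate energy.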
240 | 241 | # # plot communication amount, latency, energy, accuracy 242 | up_iot = res(cfg, 'iot', 'uplink(Mbit/s)') # BLE5 2 Mbit/s 243 | down_iot = res(cfg, 'iot', 'downlink(Mbit/s)') # BLE5 2 Mbit/s 244 | up_ap = res(cfg, 'ap', 'uplink(Mbit/s)') # WIFI 10 Mbit/s 245 | down_ap = res(cfg, 'ap', 'downlink(Mbit/s)') # WIFI 100 Mbit/s 246 | sample_size = 30 # KiloBytes 247 | 248 | iot_energy_comm = res(cfg, 'iot', 'energy_comm(W)') # Watt BLE5 249 | ap_energy_comm = res(cfg, 'ap', 'energy_comm(W)') # Watt WIFI 250 | 251 | computation_cfgs = iot_frequency, ap_frequency*acceleration, count_inference, iot_energy_mac, ap_energy_mac 252 | communication_cfgs = up_iot, down_iot, up_ap, down_ap, sample_size, iot_energy_comm, ap_energy_comm 253 | return computation_cfgs, communication_cfgs 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /frameworks/centaur/plot/plot_balanced_sn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import FuncFormatter 6 | 7 | from plot_utils import y_fmt, load_acc, load_mac_latency_energy, load_comm_latency_energy 8 | 9 | plt.rcParams.update({'font.size': 8}) 10 | 11 | parser = argparse.ArgumentParser(description="Plot") 12 | parser.add_argument("--plot_acc", type=bool, default=False) 13 | parser.add_argument("--plot_mac", type=bool, default=False) 14 | parser.add_argument("--plot_comm", type=bool, default=False) 15 | parser.add_argument("--plot_sum", type=bool, default=True) 16 | 17 | args = parser.parse_args() 18 | 19 | res_folder = "results_balanced_sn" 20 | # post_args = "_c8_e3_1mob_a5.0_b3.0_g0.0" # arguments for dynamic test 21 | 22 | 23 | # def plot_acc(): 24 | # dat_full, dat_clas, dat_dyn = load_acc(res_folder) 25 | 26 | # # import pdb; pdb.set_trace() 27 | # min_value = min(dat_dyn['accuracy'].to_list() + dat_clas['accuracy'].to_list() + dat_dyn['accuracy'].to_list()) 28 | # min_value = round(min_value, 1) + 0.2 29 | # max_value = max(dat_dyn['accuracy'].to_list() + dat_clas['accuracy'].to_list() + dat_dyn['accuracy'].to_list()) 30 | 31 | # plt.figure(figsize=(4,3)) 32 | # plt.gcf().subplots_adjust(left=0.15, bottom=0.15) 33 | 34 | # plt.plot(dat_full['round'],dat_full['accuracy'], 'g', label='Access point training') 35 | # plt.plot(dat_clas['round'],dat_clas['accuracy'], 'y', label='Tiny end device training') 36 | # plt.plot(dat_dyn['round'],dat_dyn['accuracy'], 'r', label='Partition-based training') 37 | # plt.yticks(np.arange(min_value, max_value+0.05, 0.1)) 38 | # plt.ylim(0.3, 0.95) 39 | # plt.legend() 40 | # plt.xlabel('Communication Rounds') 41 | # plt.ylabel('Test Accuracy') 42 | # plt.savefig(res_folder + '/acc.pdf') 43 | 44 | 45 | 46 | # def plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False): 47 | # figs, axes = plt.subplots(nrows=2, ncols=2) 48 | # pi = 0 49 | # for row in axes: 50 | # for col in row: 51 | # if not max(plot_dat_ap[pi]) == 0: 52 | # col.plot(x_ap, plot_dat_ap[pi], 'g', label='ap') 53 | # if not max(plot_dat_iot[pi]) == 0: 54 | # col.plot(x_iot, plot_dat_iot[pi], 'y', label='iot') 55 | # if not max(plot_dat_dyn[pi]) == 0: 56 | # col.plot(x_dyn, plot_dat_dyn[pi], 'r', label='dyn') 57 | # col.set_title(title_text[pi]) 58 | # col.set_xlabel(x_text[pi]) 59 | # col.set_ylabel(y_text[pi]) 60 | # if xinverse == True: 
col.invert_xaxis() 61 | # if xlim_vec is not None: col.set_xlim(0, xlim_vec[pi]) 62 | # if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi]) 63 | # col.legend() 64 | # pi += 1 65 | # figs.tight_layout() 66 | # plt.savefig('results/' + filename + '.pdf') 67 | 68 | 69 | def plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False): 70 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(9,2.5)) 71 | plt.subplots_adjust(left=0.1, right=0.85, wspace=0.9, bottom=0.15) 72 | pi = 0 73 | 74 | min_value = min(x_dyn + x_ap + x_iot) 75 | min_value = round(min_value, 1) 76 | max_value = max(x_dyn + x_ap + x_iot) 77 | 78 | title_text = ['Workload on Tiny End Devices', 'Workload on Access Points'] 79 | lebel_ap = 'Access point training' 80 | lebel_iot = 'Tiny end device training' 81 | lebel_dyn = 'Partition-based training' 82 | 83 | for col in axes: 84 | if not max(plot_dat_ap[pi]) == 0: 85 | col.plot(x_ap, plot_dat_ap[pi], 'g', label=lebel_ap) 86 | if not max(plot_dat_iot[pi]) == 0: 87 | col.plot(x_iot, plot_dat_iot[pi], 'y', label=lebel_iot) 88 | if not max(plot_dat_dyn[pi]) == 0: 89 | col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=lebel_dyn) 90 | col.set_title(title_text[pi]) 91 | col.set_xlabel(x_text[pi]) 92 | col.set_ylabel(y_text[pi]) 93 | 94 | 95 | if xinverse == True: col.invert_xaxis() 96 | if xlim_vec is not None: col.set_xlim(0, xlim_vec[pi]) 97 | if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi]) 98 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 99 | 100 | col.legend() 101 | 102 | col_tx = col.twinx() 103 | col_tx2 = col.twinx() 104 | 105 | if not max(plot_dat_ap[pi+2]) == 0: 106 | col_tx.plot(x_ap, plot_dat_ap[pi+2], 'g', label=lebel_ap) 107 | if not max(plot_dat_iot[pi+2]) == 0: 108 | col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=lebel_iot) 109 | if not max(plot_dat_dyn[pi+2]) == 0: 110 | col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=lebel_dyn) 111 | col_tx.set_ylabel(y_text[pi+2]) 112 | col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 113 | 114 | if not max(plot_dat_ap[pi+4]) == 0: 115 | col_tx2.plot(x_ap, plot_dat_ap[pi+4], 'g', label=lebel_ap) 116 | if not max(plot_dat_iot[pi+4]) == 0: 117 | col_tx2.plot(x_iot, plot_dat_iot[pi+4], 'y', label=lebel_iot) 118 | if not max(plot_dat_dyn[pi+4]) == 0: 119 | col_tx2.plot(x_dyn, plot_dat_dyn[pi+4], 'r', label=lebel_dyn) 120 | col_tx2.set_ylabel(y_text[pi+4]) 121 | col_tx2.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 122 | 123 | col_tx2.spines['right'].set_position(('outward', 45)) 124 | 125 | pi += 1 126 | 127 | # figs.tight_layout() 128 | plt.savefig(res_folder + '/' + filename + '.pdf') 129 | 130 | 131 | def plot_12_2y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False): 132 | figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,2.8)) 133 | plt.subplots_adjust(left=0.1, right=0.85, wspace=0.8, bottom=0.15) 134 | 135 | pi = 0 136 | 137 | min_value = min(x_dyn + x_ap + x_iot) 138 | min_value = round(min_value, 1) 139 | if min_value < 0.35: min_value = 0.35 140 | max_value = max(x_dyn + x_ap + x_iot) 141 | 142 | title_text = ['Workload on Tiny End Devices', 'Workload on Access Points'] 143 | lebel_ap = 'Access point training' 144 | lebel_iot = 'Tiny end device training' 145 | lebel_dyn = 'Partition-based training' 146 | 147 | linewidth = 1 148 | 149 | for col in axes: 150 | if not max(plot_dat_ap[pi]) <= 1000: 151 | col.plot(x_ap, plot_dat_ap[pi], 'limegreen', 
label=lebel_ap, linewidth=linewidth) 152 | if not max(plot_dat_iot[pi]) <= 1000: 153 | col.plot(x_iot, plot_dat_iot[pi], 'y', label=lebel_iot, linewidth=linewidth) 154 | if not max(plot_dat_dyn[pi]) <= 1000: 155 | col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=lebel_dyn, linewidth=linewidth) 156 | col.set_title(title_text[pi]) 157 | col.set_xlabel(x_text[pi]) 158 | col.set_ylabel(y_text[pi]) 159 | 160 | 161 | if xinverse == True: col.invert_xaxis() 162 | if pi > 0: col.set_xlim(max_value+0.05, min_value) 163 | if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi]) 164 | col.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 165 | 166 | col_tx = col.twinx() 167 | 168 | if not max(plot_dat_ap[pi+2]) <= 1000: 169 | col_tx.plot(x_ap, plot_dat_ap[pi+2], 'g', label=lebel_ap, linewidth=linewidth) 170 | if not max(plot_dat_iot[pi+2]) <= 1000: 171 | col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=lebel_iot, linewidth=linewidth) 172 | if not max(plot_dat_dyn[pi+2]) <= 1000: 173 | col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=lebel_dyn, linewidth=linewidth) 174 | if pi > 0: 175 | col_tx.set_ylabel(y_text[pi+2], color='g') 176 | else: 177 | col_tx.set_ylabel(y_text[pi+2]) 178 | col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt)) 179 | 180 | col_tx.legend() 181 | 182 | pi += 1 183 | 184 | # figs.tight_layout() 185 | plt.savefig(res_folder + '/' + filename + '.pdf') 186 | 187 | 188 | # def plot_mac_latency_xround(iot_freq, ap_freq, count_inference=True): 189 | # plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = load_mac_latency(iot_freq, ap_freq) 190 | 191 | # # plot mac 192 | # title_text = ['on AP','on IoT','on AP','on IoT'] 193 | # x_text = ['round'] * 4 194 | # y_text = ['Accumulated MAC operations (Million)', 'Accumulated MAC operations (Million)', 195 | # 'Accumulated Latency (Seconds)', 'Accumulated Latency (Seconds)'] 196 | 197 | # ylim1 = max(plot_dat_ap[0] + plot_dat_iot[0] + plot_dat_dyn[0] + plot_dat_ap[1] + plot_dat_iot[1] + plot_dat_dyn[1]) 198 | # ylim2 = max(plot_dat_ap[2] + plot_dat_iot[2] + plot_dat_dyn[2] + plot_dat_ap[3] + plot_dat_iot[3] + plot_dat_dyn[3]) 199 | # ylim_vec = [ylim1,ylim1,ylim2,ylim2] 200 | 201 | # xlim_vec = [50] * 4 202 | 203 | # filename = 'mac_latency' 204 | # plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename) 205 | 206 | 207 | 208 | 209 | 210 | def plot_mac_latency_energy_xaccuracy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference=True): 211 | 212 | dat_full, dat_clas, dat_dyn = load_acc(res_folder) 213 | acc_ap = dat_full['accuracy'].tolist()[1:] 214 | acc_iot = dat_clas['accuracy'].tolist()[1:] 215 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 216 | 217 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_mac_latency_energy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 218 | 219 | # plot mac 220 | x_text = ['Test Accuracy'] * 2 221 | y_text = ['Accumulated MAC operations (Million)'] * 2 222 | y_text += ['Accumulated Latency (Seconds)'] * 2 223 | y_text += ['Accumulated Energy (Joule)'] * 2 224 | 225 | xlim_vec, ylim_vec = None, None 226 | 227 | filename = 'mac_latency_energy_accuracy' 228 | plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True) 229 | 230 | 231 | 232 | 233 | # def plot_comm_latency_xround(up_iot, down_iot, up_ap, down_ap, sample_size): 234 | # plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = 
load_comm_latency(up_iot, down_iot, up_ap, down_ap, sample_size) 235 | 236 | # # plot mac 237 | # figs, axes = plt.subplots(nrows=2, ncols=2) 238 | # title_text = ['on AP','on IoT','on AP','on IoT'] 239 | # x_text = ['round'] * 4 240 | # y_text = ['Accumulated Comm. Amount (Migabytes)', 'Accumulated Comm. Amount (Migabytes)', 241 | # 'Accumulated Latency (Seconds)', 'Accumulated Latency (Seconds)'] 242 | 243 | # #ylim1 = max(plot_dat_ap[0] + plot_dat_iot[0] + plot_dat_dyn[0] + plot_dat_ap[1] + plot_dat_iot[1] + plot_dat_dyn[1]) 244 | # #ylim2 = max(plot_dat_ap[2] + plot_dat_iot[2] + plot_dat_dyn[2] + plot_dat_ap[3] + plot_dat_iot[3] + plot_dat_dyn[3]) 245 | # #ylim_vec = [ylim1, ylim1, ylim2, ylim2] 246 | # ylim_vec = None 247 | 248 | # xlim_vec = [50] * 4 249 | 250 | # filename = 'comm_latency' 251 | # plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename) 252 | 253 | 254 | 255 | def plot_comm_latency_energy_xaccuracy(up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size): 256 | dat_full, dat_clas, dat_dyn = load_acc(res_folder) 257 | acc_ap = dat_full['accuracy'].tolist()[1:] 258 | acc_iot = dat_clas['accuracy'].tolist()[1:] 259 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 260 | 261 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_comm_latency_energy(up_iot, down_iot, up_ap, down_ap, 262 | iot_energy_comm, ap_energy_comm, sample_size) 263 | 264 | # plot mac 265 | x_text = ['Test Accuracy'] * 2 266 | y_text = ['Accumulated Comm. Amount (Migabytes)'] * 2 267 | y_text += ['Accumulated Latency (Seconds)'] * 2 268 | y_text += ['Accumulated Energy (Joule)'] * 2 269 | 270 | xlim_vec, ylim_vec = None, None 271 | 272 | filename = 'comm_latency_energy_accuracy' 273 | plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True) 274 | 275 | 276 | 277 | def plot_sum_latency_energy_xaccuracy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference, 278 | up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size): 279 | 280 | dat_full, dat_clas, dat_dyn = load_acc(res_folder) 281 | acc_ap = dat_full['accuracy'].tolist()[1:] 282 | acc_iot = dat_clas['accuracy'].tolist()[1:] 283 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 284 | mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 285 | comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(up_iot, down_iot, up_ap, down_ap, 286 | iot_energy_comm, ap_energy_comm, sample_size) 287 | 288 | sum_dyn, sum_ap, sum_iot = [], [], [] 289 | for k in range(2, len(mac_dyn)): 290 | sum_dyn.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])]) 291 | sum_ap.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])]) 292 | sum_iot.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])]) 293 | 294 | x_text = ['Test Accuracy'] * 2 295 | y_text = ['Accumulated Latency (Seconds)'] * 2 296 | y_text += ['Accumulated Energy (Joule)'] * 2 297 | 298 | xlim_vec, ylim_vec = None, None 299 | 300 | filename = 'sum_latency_energy_accuracy' 301 | plot_12_2y(sum_dyn, sum_ap, sum_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True) 302 | 303 | 304 | 305 | def res(configs, type, feature): 306 | return float(configs[configs['type'] == type][feature]) 307 | 308 | 309 | 310 | 311 | # plot accuracy only 312 | if args.plot_acc: 313 | plot_acc() 314 | 315 | 316 | cfg = 
pd.read_csv('client_configurations.csv', sep=',') 317 | 318 | 319 | # plot mac, latency, energy, accuracy 320 | iot_frequency = res(cfg, 'iot', 'frequency(MHz)') 321 | ap_frequency = res(cfg, 'ap', 'frequency(MHz)') 322 | acceleration = res(cfg, 'ap', 'acceleration') # how many times faster the accelerator is than ap CPU 323 | count_inference = True 324 | 325 | iot_energy_mac = iot_frequency * res(cfg, 'iot', 'power(mWpMHz)') # times = milliWatt/Mhz https://www.prnewswire.com/news-releases/ultra-reliable-arm-cortex-m4f-microcontroller-from-maxim-integrated-offers-industrys-lowest-power-consumption-and-smallest-size-for-industrial-healthcare-and-iot-sensor-applications-301080651.html#:~:text=Lowest%20Power%3A%2040%C2%B5W%2FMHz%20of,5mm%2Dx%2D5mm%20TQFN. 326 | ap_energy_mac = ap_frequency * res(cfg, 'ap', 'power(mWpMHz)') # times = milliWatt/MHz 327 | # https://www.researchgate.net/figure/Energy-consumption-for-different-parts-of-the-mobile-phone_tbl1_224248235#:~:text=The%20power%20consumption%20of%20wearable,%5B18%5D.%20... 328 | 329 | 330 | 331 | # # plot communication amount, latency, energy, accuracy 332 | up_iot = res(cfg, 'iot', 'uplink(Mbit/s)') # BLE5 2 Mbit/s 333 | down_iot = res(cfg, 'iot', 'downlink(Mbit/s)') # BLE5 2 Mbit/s 334 | up_ap = res(cfg, 'ap', 'uplink(Mbit/s)') # WIFI 10 Mbit/s 335 | down_ap = res(cfg, 'ap', 'downlink(Mbit/s)') # WIFI 100 Mbit/s 336 | sample_size = 30 # KiloBytes 337 | 338 | iot_energy_comm = res(cfg, 'iot', 'energy_comm(W)') # Watt BLE5 339 | ap_energy_comm = res(cfg, 'ap', 'energy_comm(W)') # Watt WIFI 340 | 341 | 342 | if args.plot_mac: 343 | #plot_mac_latency_xround(iot_frequency, ap_frequency*acceleration, count_inference) 344 | plot_mac_latency_energy_xaccuracy(iot_frequency, ap_frequency*acceleration, iot_energy_mac, ap_energy_mac, count_inference) 345 | 346 | if args.plot_comm: 347 | # plot_comm_latency_xround(up_iot, down_iot, up_ap, down_ap, sample_size) 348 | plot_comm_latency_energy_xaccuracy(up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) 349 | 350 | if args.plot_sum: 351 | plot_sum_latency_energy_xaccuracy(iot_frequency, ap_frequency*acceleration, iot_energy_mac, ap_energy_mac, count_inference, 352 | up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size) -------------------------------------------------------------------------------- /frameworks/centaur/measures/mac_comm_counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import pandas as pd 4 | import numpy as np 5 | from fvcore.nn import flop_count_table, flop_count_str, FlopCountAnalysis 6 | 7 | from hiefed.client import init_network 8 | 9 | # functions used to compute MAC operations, communication amounts, latency, energy cost 10 | res_folder = "results" 11 | 12 | # convert K/M/G suffixed text to numbers (used in plots) 13 | def text_to_num(text, bad_data_val = 0): 14 | d = { 15 | 'K': 1000, 16 | 'M': 1000000, 17 | 'G': 1000000000 18 | } 19 | if not isinstance(text, str): 20 | return bad_data_val 21 | 22 | elif text[-1] in d: # separate out the K, M, or G 23 | num, magnitude = text[:-1], text[-1] 24 | return int(float(num) * d[magnitude]) 25 | else: 26 | try: # catch exceptions 27 | return float(text) 28 | except Exception as e: 29 | return None 30 | 31 | 32 | # Dynamically generate column names 33 | # The max column count a line in the file could have 34 | def read_file_imbalanced_col(data_file): 35 | 36 | largest_column_count = 0 37 | 38 | # Loop the 
data lines 39 | with open(data_file, 'r') as temp_f: 40 | lines = temp_f.readlines() 41 | 42 | for l in lines: 43 | column_count = len(l.split(',')) + 1 44 | largest_column_count = column_count if largest_column_count < column_count else largest_column_count 45 | 46 | if largest_column_count > 20: largest_column_count = 20 + 3 # avoiding Error tokenizing data. C error: Buffer overflow caught - possible malformed input file. 47 | 48 | # Generate column names (will be 0, 1, 2, ..., largest_column_count - 1) 49 | column_names = ['dest','round','cid'] 50 | column_names = column_names + ['samples_e'+str(i) for i in range(1, largest_column_count-2)] 51 | 52 | #print(data_file) 53 | #if data_file == 'results/sample_number_dyn_1cifar10_2mobi_3medium_c8|100_e3_a5_b3_g0_flpa1000_mr0.1|0.4.csv': 54 | # import pdb; pdb.set_trace() 55 | df = pd.read_csv(data_file, header=None, skiprows=1, delimiter=',', names=column_names, lineterminator='\n') 56 | 57 | return df 58 | 59 | 60 | # get flops using fvcore.nn library (from facebook AI) 61 | def get_flops_parameters(Net, input): 62 | flops = FlopCountAnalysis(Net, input) 63 | 64 | flops_ = flop_count_table(flops) 65 | flops_ = flops_.replace(' ','') 66 | flops_ = flops_.replace('|',';') 67 | flops_ = flops_.replace(';\n;','\n') 68 | flops_ = flops_.replace('-','') 69 | flops_ = flops_.replace('\n:;:;:\n','\n') 70 | 71 | flops_ = flops_[1:] 72 | flops_ = flops_[:-1] 73 | 74 | with open(res_folder + "/model_flops.csv","w") as f: 75 | f.write(flops_) 76 | 77 | dat = pd.read_csv(res_folder + "/model_flops.csv", sep=";") 78 | return dat 79 | 80 | 81 | def get_input_size(encoder, classifier): 82 | '''input sizes for encoders when computing FLOPs and weights''' 83 | 84 | if encoder == 'sencoder_p': 85 | input = torch.rand(16, 1, 27, 200) 86 | elif encoder == 'sencoder_u' or encoder == 'sencoder_m': 87 | input = torch.rand(16, 1, 9, 128) 88 | else: 89 | input = torch.rand(16, 3, 3, 3) 90 | 91 | if encoder == 'mobilenet' or encoder == 'efficientnet': 92 | classifier_str = 'classifier.0' 93 | encoder_str = 'features' 94 | 95 | elif encoder == 'mnasnet': 96 | classifier_str = 'classifier' 97 | encoder_str = 'layers' 98 | 99 | elif encoder == 'shufflenet': 100 | classifier_str = 'fc' 101 | encoder_str = None 102 | elif encoder == 'sencoder_p' or encoder == 'sencoder_u' or encoder == 'sencoder_m': 103 | classifier_str = 'classifier' 104 | encoder_str = 'encoder' 105 | else: 106 | raise NameError("input size is not defined!") 107 | 108 | return input, classifier_str, encoder_str 109 | 110 | 111 | def mac_counter(encoder, classifier, num_classes=10): 112 | '''MAC operations counter''' 113 | if encoder == 'sencoder_p': 114 | num_classes = 13 115 | elif encoder == 'sencoder_u' or encoder == 'sencoder_m': 116 | num_classes = 6 117 | Net = init_network(encoder=encoder, classifier=classifier, num_classes=num_classes, pre_trained=False) 118 | input, classifier_str, encoder_str = get_input_size(encoder, classifier) 119 | 120 | # get flops 121 | dat = get_flops_parameters(Net, input) 122 | for i in range(len(dat['#flops'])): 123 | dat['#flops'][i] = text_to_num(dat['#flops'][i]) 124 | dat['#flops'] = dat['#flops'].apply(pd.to_numeric) 125 | 126 | # forward MAC operation computation (1 FLOP = 2 MACs) 127 | mac_fw_full = 2 * dat.loc[dat["module"] == 'model']['#flops'] 128 | 129 | mac_fw_clas = 2 * (int(dat.loc[dat["module"] == classifier_str]['#flops'])) 130 | 131 | mac_fw_first = dat['#flops'][2] 132 | if encoder_str is None: 133 | mac_fw_encd = mac_fw_full - mac_fw_clas # both terms are already in MACs, so no extra doubling 
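        # (Worked example, illustrative numbers only: if fvcore reports 60M
        #  FLOPs for the whole model and 1M for the classifier head, the
        #  doubling above gives mac_fw_full = 120M and mac_fw_clas = 2M MACs,
        #  so the encoder share is mac_fw_encd = 118M forward MACs.)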
134 | else: 135 | mac_fw_encd = 2 * dat.loc[dat["module"] == encoder_str]['#flops'] 136 | 137 | mac_fw_encd_rest = mac_fw_encd - mac_fw_first 138 | 139 | # backward MAC operation computation 140 | # https://www.lesswrong.com/posts/fnjKpBoWJXcSDwhZk/what-s-the-backward-forward-flop-ratio-for-neural-networks 141 | mac_bw_first = mac_fw_first 142 | mac_bw_encd_rest = 2 * mac_fw_encd_rest 143 | mac_bw_encd = mac_bw_first + mac_bw_encd_rest 144 | mac_bw_clas = 2 * mac_fw_clas 145 | mac_bw_full = mac_bw_encd + mac_bw_clas 146 | 147 | mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas= int(mac_fw_encd), int(mac_bw_encd), int(mac_fw_clas), int(mac_bw_clas) 148 | if encoder == 'sencoder_p': 149 | mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas = mac_fw_encd/200, mac_bw_encd/200, mac_fw_clas/200, mac_bw_clas/200 150 | elif encoder == 'sencoder_u' or encoder == 'sencoder_m': 151 | mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas = mac_fw_encd/42, mac_bw_encd/42, mac_fw_clas/42, mac_bw_clas/42 152 | return mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas 153 | 154 | 155 | 156 | 157 | def para_counter(encoder, classifier, num_classes=10): 158 | '''Weight parameters counter''' 159 | if encoder == 'sencoder_p': 160 | num_classes = 13 161 | elif encoder == 'sencoder_u' or encoder == 'sencoder_m': 162 | num_classes = 6 163 | Net = init_network(encoder=encoder, classifier=classifier, num_classes=num_classes, pre_trained=False) 164 | input, classifier_str, encoder_str = get_input_size(encoder, classifier) 165 | 166 | # get flops 167 | dat = get_flops_parameters(Net, input) 168 | 169 | for i in range(len(dat['#parametersorshape'])): 170 | dat['#parametersorshape'][i] = text_to_num(dat['#parametersorshape'][i]) 171 | dat['#parametersorshape'] = dat['#parametersorshape'].apply(pd.to_numeric) 172 | 173 | # classifier parameters 174 | para_clas = dat.loc[dat["module"] == classifier_str]['#parametersorshape'] 175 | 176 | # encoder parameters 177 | if encoder_str is None: 178 | # model para 179 | para_model = dat.loc[dat["module"] == 'model']['#parametersorshape'] 180 | para_encd = int(para_model) - int(para_clas) 181 | else: 182 | para_encd = dat.loc[dat["module"] == encoder_str]['#parametersorshape'] 183 | 184 | return int(para_encd), int(para_clas) 185 | 186 | 187 | 188 | 189 | # def get_mac_latency_energy_per(res, encoder='mobilenetv3'): 190 | # '''MAC operations, latency, and energy counter in realtime''' 191 | # if encoder == 'mobilenetv3': 192 | # mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas = get_mac_mobilenet() 193 | 194 | # mac_ap = mac_fw_encd + mac_bw_encd + mac_fw_clas + mac_bw_clas 195 | # mac_iot = mac_fw_encd + mac_fw_clas + mac_bw_clas 196 | 197 | # # compute latency (1 MAC operation = 2 instructions) 198 | # ap_freq = float(res.loc[res['type'] == 'ap']['frequency(MHz)']) 199 | # iot_freq = float(res.loc[res['type'] == 'iot']['frequency(MHz)']) 200 | # ap_latency = 2 * mac_ap / ap_freq 201 | # iot_latency = 2 * mac_iot / iot_freq 202 | 203 | # # compute energy Joule 204 | # ap_power = float(res.loc[res['type'] == 'ap']['power(mWpMHz)']) 205 | # iot_power = float(res.loc[res['type'] == 'iot']['power(mWpMHz)']) 206 | # ap_energy = mac_ap * ap_power 207 | # iot_energy = mac_iot * iot_power 208 | 209 | # return ap_latency, iot_latency, ap_energy, iot_energy 210 | 211 | 212 | 213 | 214 | 215 | # def get_comm_latency_energy_per(res, encoder='mobilenetv3'): 216 | # '''Communication amount, latency, and energy counter in realtime''' 217 | # if encoder == 'mobilenetv3': 218 | # para_encd, 
para_clas = get_para_mobilenet() 219 | 220 | # para_ap = para_encd + para_clas 221 | # para_iot = para_clas 222 | 223 | # # float numbers to Migabytes (MB now) 224 | # para_ap = para_ap*4/(1024*1024) 225 | # para_iot = para_iot*4/(1024*1024) 226 | 227 | # # compute latency (1 bytes = 2 bits) 228 | # up_ap = float(res.loc[res['type'] == 'ap']['uplink(Mbit/s)']) 229 | # down_ap = float(res.loc[res['type'] == 'iot']['downlink(Mbit/s)']) 230 | # up_iot = float(res.loc[res['type'] == 'ap']['uplink(Mbit/s)']) 231 | # down_iot = float(res.loc[res['type'] == 'iot']['downlink(Mbit/s)']) 232 | # ap_latency = para_ap*2/up_ap + para_ap*2/down_ap 233 | # iot_latency = para_iot*2/up_iot + para_iot*2/down_iot 234 | 235 | # # compute energy Joule 236 | # ap_energy = float(res.loc[res['type'] == 'ap']['energy_comm(W)']) 237 | # iot_energy = float(res.loc[res['type'] == 'iot']['energy_comm(W)']) 238 | # ap_energy = para_ap*ap_energy 239 | # iot_energy = para_iot*iot_energy 240 | 241 | # return ap_latency, iot_latency, ap_energy, iot_energy 242 | 243 | 244 | 245 | 246 | # compute accumulated MAC operation and latency during FL for AP and IoT 247 | def mac_latency_energy_counter(encoder, classifier, post_args, destine, iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference, rounds=100): 248 | filedir = res_folder + "/sample_number_" + destine + "_" + post_args + ".csv" 249 | # filedir = res_folder + "/sample_number_" + destine + post_args + ".csv" 250 | dat_sn = read_file_imbalanced_col(filedir) 251 | 252 | # compute mac 253 | #mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas = get_mac_mobilenet() 254 | mac_fw_encd, mac_bw_encd, mac_fw_clas, mac_bw_clas = mac_counter(encoder, classifier) 255 | 256 | mac_list, mac_acm_ap_list, mac_acm_iot_list = [], [], [] 257 | 258 | # loop to counter MAC operations of AP and IoT 259 | for i in range(len(dat_sn['dest'])): 260 | if dat_sn.loc[i, 'dest'] == 'ap': 261 | if count_inference == True: 262 | mac_ = mac_fw_encd + mac_fw_clas + mac_bw_encd + mac_bw_clas 263 | else: 264 | mac_ = mac_bw_encd + mac_bw_clas 265 | mac = np.nansum(dat_sn.loc[i,'samples_e1':]) * mac_ / 1000000 266 | 267 | elif dat_sn.loc[i, 'dest'] == 'iot': 268 | if count_inference == True: 269 | mac_ = mac_fw_encd + mac_fw_clas + mac_bw_clas 270 | else: 271 | mac_ = mac_bw_clas 272 | mac = np.nansum(dat_sn.loc[i,'samples_e1':]) * mac_ / 1000000 273 | mac_list.append(mac) 274 | 275 | 276 | 277 | # average clients on each round and accumulate 278 | dat_sn = dat_sn.assign(mac=mac_list) 279 | dat_sn = dat_sn.groupby(['round', 'dest'],as_index=False).mean() 280 | if len(dat_sn['round']) < rounds: # fill if rounds missing 281 | extra_df = pd.DataFrame(list(range(1,rounds+1)),columns=['round']) 282 | dat_sn = pd.merge(dat_sn, extra_df, on='round', how="right") 283 | 284 | rounds = dat_sn['round'] 285 | 286 | mac_acm_ap, mac_acm_iot = 0, 0 287 | for i in range(len(dat_sn['dest'])): 288 | if dat_sn.loc[i, 'dest'] == 'ap': 289 | mac_acm_ap += dat_sn.loc[i, 'mac'] 290 | elif dat_sn.loc[i, 'dest'] == 'iot': 291 | mac_acm_iot += dat_sn.loc[i, 'mac'] 292 | 293 | mac_acm_ap_list.append(mac_acm_ap) 294 | mac_acm_iot_list.append(mac_acm_iot) 295 | 296 | dat_sn = dat_sn.assign(mac_acm_ap=mac_acm_ap_list) 297 | dat_sn = dat_sn.assign(mac_acm_iot=mac_acm_iot_list) 298 | 299 | # compute latency (1 MAC operation = 2 instructions) https://stackoverflow.com/questions/9242024/multiply-add-a-a2-b-instruction-on-cpu 300 | ap_latency = [2*i/ap_freq for i in mac_acm_ap_list] 301 | iot_latency = [2*i/iot_freq for i in 
302 | 
303 |     # compute energy in Joule
304 |     ap_energy = [i*ap_energy_mac for i in ap_latency]
305 |     iot_energy = [i*iot_energy_mac for i in iot_latency]
306 | 
307 |     plot_dat = [mac_acm_iot_list, mac_acm_ap_list, iot_latency, ap_latency, iot_energy, ap_energy]
308 | 
309 |     return plot_dat, rounds
310 | 
311 | 
312 | 
313 | 
314 | 
315 | # compute accumulated parameters to transmit and the latency during FL
316 | def comm_latency_energy_counter(encoder, classifier, post_args, destine, up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size, rounds=100):
317 | 
318 |     filedir = res_folder + "/sample_number_" + destine + "_" + post_args + ".csv"
319 |     # filedir = res_folder + "/sample_number_" + destine + post_args + ".csv"
320 |     dat_sn = read_file_imbalanced_col(filedir)
321 | 
322 |     para_encd, para_clas = para_counter(encoder, classifier)
323 | 
324 |     # loop to count passed parameters (one way) of AP and IoT (float numbers)
325 |     para_list = []
326 |     for i in range(len(dat_sn['dest'])):
327 |         if dat_sn.loc[i, 'dest'] == 'ap':
328 |             para = (para_encd + para_clas)
329 |         elif dat_sn.loc[i, 'dest'] == 'iot':
330 |             para = para_clas
331 |             if destine == 'dyn':
332 |                 para += 0.5 * para_encd  # the 0.5 factor cancels the upload direction out (one-way count)
333 |         para_list.append(para)
334 | 
335 |     # loop to count passed samples from IoT to AP
336 |     if not sample_size == 0:
337 |         sample_list = []
338 |         for i in range(len(dat_sn['dest'])):
339 |             if dat_sn.loc[i, 'dest'] == 'ap':
340 |                 samples_s = sample_size * np.nanmean(dat_sn.loc[i, 'samples_e1':])  # the AP can cache data samples, so we estimate with the average number of samples over epochs
341 |             elif dat_sn.loc[i, 'dest'] == 'iot':
342 |                 samples_s = 0
343 |             sample_list.append(samples_s)
344 | 
345 |     # merge clients data as one by averaging
346 |     dat_sn = dat_sn.assign(para=para_list)
347 |     dat_sn = dat_sn.assign(samples=sample_list)
348 |     dat_sn = dat_sn.groupby(['round', 'dest'],as_index=False).mean()
349 | 
350 |     # fill if rounds missing
351 |     if len(dat_sn['round']) < rounds:
352 |         extra_df = pd.DataFrame(list(range(1,rounds+1)), columns=['round'])
353 |         dat_sn = pd.merge(dat_sn, extra_df, on='round', how="right")
354 |         dat_sn['para'].fillna(0, inplace=True)
355 |         dat_sn['samples'].fillna(0, inplace=True)
356 |         destine_other = 'iot' if destine == 'ap' else 'ap'
357 |         dat_sn['dest'].fillna(destine_other, inplace=True)
358 | 
359 |     rounds = dat_sn['round']
360 | 
361 |     # all parameters passed from IoT to AP must also be passed on to the server
362 |     paralist = dat_sn['para'].to_list()
363 |     for i in range(len(rounds)):
364 |         if dat_sn['dest'][i] == 'ap':
365 |             paralist[i] = paralist[i] + paralist[i-1]
366 |     dat_sn['para'] = paralist
367 | 
368 |     # move sample size up by 1 row (the samples trained on the AP are passed in by the IoT device)
369 |     dat_sn['samples'] = dat_sn['samples'].to_list()[1:] + [0]
370 | 
371 |     # add sample sizes to para size for later computation
372 |     dat_sn['para'] = dat_sn['para'] + 0.5 * (dat_sn['samples']*1024/4)  # KB to bytes to float count; the 0.5 factor cancels the download direction out
373 | 
374 |     # accumulate passed parameters (one way)
375 |     para_acm_ap, para_acm_iot = 0, 0
376 |     para_acm_ap_list, para_acm_iot_list = [], []
377 |     for i in range(len(dat_sn['dest'])):
378 |         if dat_sn.loc[i, 'dest'] == 'ap':
379 |             para_acm_ap += dat_sn.loc[i, 'para']
380 |         elif dat_sn.loc[i, 'dest'] == 'iot':
381 |             para_acm_iot += dat_sn.loc[i, 'para']
382 | 
383 |         para_acm_ap_list.append(para_acm_ap)
384 |         para_acm_iot_list.append(para_acm_iot)
385 | 
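# For scale (illustrative only): 1e6 float32 parameters occupy 4e6 bytes,
# i.e. 4e6 / (1024*1024) ≈ 3.81 MB, which is exactly what the conversion
# below computes for the accumulated parameter counts.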
386 |     # float numbers to Megabytes (MB now)
387 |     para_acm_ap_list = [i*4/(1024*1024) for i in para_acm_ap_list]
388 |     para_acm_iot_list = [i*4/(1024*1024) for i in para_acm_iot_list]
389 | 
390 |     dat_sn = dat_sn.assign(para_acm_ap=para_acm_ap_list)
391 |     dat_sn = dat_sn.assign(para_acm_iot=para_acm_iot_list)
392 | 
393 |     # compute latency (1 byte = 8 bits)
394 |     iot_latency = [(i*8/up_iot + i*8/down_iot) for i in para_acm_iot_list]
395 |     ap_latency = [(i*8/up_ap + i*8/down_ap) for i in para_acm_ap_list]
396 | 
397 |     # compute energy in Joule
398 |     ap_energy = [i*ap_energy_comm for i in ap_latency]
399 |     iot_energy = [i*iot_energy_comm for i in iot_latency]
400 | 
401 |     plot_dat = [para_acm_iot_list, para_acm_ap_list, iot_latency, ap_latency, iot_energy, ap_energy]
402 | 
403 |     return plot_dat, rounds
--------------------------------------------------------------------------------
/frameworks/centaur/plot/plot.py.orig:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from matplotlib.ticker import FuncFormatter
6 | 
7 | from measures.mac_para_counter import compute_mac_latency_energy_plot, compute_comm_latency_energy_plot
8 | 
9 | plt.rcParams.update({'font.size': 8})
10 | 
11 | parser = argparse.ArgumentParser(description="Plot")
12 | parser.add_argument("--plot_acc", type=bool, default=True)
13 | parser.add_argument("--plot_mac", type=bool, default=False)  # note: type=bool treats any non-empty string as True
14 | parser.add_argument("--plot_comm", type=bool, default=False)
15 | parser.add_argument("--plot_sum", type=bool, default=False)
16 | 
17 | args = parser.parse_args()
18 | 
19 | # https://stackoverflow.com/questions/40566413/matplotlib-pyplot-auto-adjust-unit-of-y-axis
20 | def y_fmt(y, pos):
21 |     decades = [1e9, 1e6, 1e3, 1e0, 1e-3, 1e-6, 1e-9 ]
22 |     suffix  = ["G", "M", "K", "" , "m" , "u", "n"  ]
23 |     if y == 0:
24 |         return str(0)
25 |     for i, d in enumerate(decades):
26 |         if np.abs(y) >= d:
27 |             val = round(y/float(d), 3)  # round to avoid the 0.99999... artifact
28 |             signf = len(str(val).split(".")[1])
29 |             if signf == 0:
30 |                 return '{val:d}{suffix}'.format(val=int(val), suffix=suffix[i])
31 |             else:
32 |                 if signf == 1:
33 |                     if str(val).split(".")[1] == "0":
34 |                         return '{val:d}{suffix}'.format(val=int(str(round(val))), suffix=suffix[i])
35 |                 tx = "{"+"val:.{signf}f".format(signf = signf) +"}{suffix}"
36 |                 return tx.format(val=val, suffix=suffix[i])
37 |     #return y
38 |     return y
39 | 
40 | 
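# Example of the formatter above (illustrative): y_fmt(1500000, None)
# returns '1.5M' and y_fmt(0.002, None) returns '2m', so tick labels read
# in engineering notation.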
41 | def load_acc():
42 |     filedir = "results/acc_ap.csv"
43 |     dat_full = pd.read_csv(filedir, sep=',')
44 | 
45 |     filedir = "results/acc_iot.csv"
46 |     dat_clas = pd.read_csv(filedir, sep=',')
47 | 
48 |     filedir = "results/acc_dyn.csv"
49 |     dat_dyn = pd.read_csv(filedir, sep=',')
50 | 
51 |     return dat_full, dat_clas, dat_dyn
52 | 
53 | 
54 | 
55 | def load_mac_latency_energy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac):
56 |     plot_dat_dyn, rounds_dyn = compute_mac_latency_energy_plot('dyn', iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference)
57 |     plot_dat_ap, rounds_ap = compute_mac_latency_energy_plot('ap', iot_freq, ap_freq,
58 |                                                              iot_energy_mac, ap_energy_mac, count_inference)
59 |     plot_dat_iot, rounds_iot = compute_mac_latency_energy_plot('iot', iot_freq, ap_freq,
60 |                                                                iot_energy_mac, ap_energy_mac, count_inference)
61 | 
62 |     return plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot
63 | 
64 | 
65 | 
66 | def load_comm_latency_energy(up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size):
67 |     plot_dat_dyn, rounds_dyn = compute_comm_latency_energy_plot('dyn', up_iot, down_iot,
68 |                                                                 up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size)
69 |     plot_dat_ap, rounds_ap = compute_comm_latency_energy_plot('ap', up_iot, down_iot,
70 |                                                               up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size)
71 |     plot_dat_iot, rounds_iot = compute_comm_latency_energy_plot('iot', up_iot, down_iot,
72 |                                                                 up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size)
73 | 
74 |     return plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot
75 | 
76 | 
77 | 
78 | def plot_acc():
79 |     dat_full, dat_clas, dat_dyn = load_acc()
80 | 
81 |     # min/max over all three accuracy curves
82 |     min_value = min(dat_full['accuracy'].to_list() + dat_clas['accuracy'].to_list() + dat_dyn['accuracy'].to_list())
83 |     min_value = round(min_value, 1) + 0.2
84 |     max_value = max(dat_full['accuracy'].to_list() + dat_clas['accuracy'].to_list() + dat_dyn['accuracy'].to_list())
85 | 
86 |     plt.figure(figsize=(4,3))
87 |     plt.gcf().subplots_adjust(left=0.15, bottom=0.15)
88 | 
89 |     plt.plot(dat_full['round'],dat_full['accuracy'], 'g', label='Access point training')
90 |     plt.plot(dat_clas['round'],dat_clas['accuracy'], 'y', label='Tiny end device training')
91 |     plt.plot(dat_dyn['round'],dat_dyn['accuracy'], 'r', label='Partition-based training')
92 |     plt.yticks(np.arange(min_value, max_value+0.05, 0.1))
93 |     plt.ylim(0.3, 0.95)
94 |     plt.legend()
95 |     plt.xlabel('Communication Rounds')
96 |     plt.ylabel('Test Accuracy')
97 |     plt.savefig('results/acc.pdf')
98 | 
99 | 
100 | 
101 | # def plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False):
102 | #     figs, axes = plt.subplots(nrows=2, ncols=2)
103 | #     pi = 0
104 | #     for row in axes:
105 | #         for col in row:
106 | #             if not max(plot_dat_ap[pi]) == 0:
107 | #                 col.plot(x_ap, plot_dat_ap[pi], 'g', label='ap')
108 | #             if not max(plot_dat_iot[pi]) == 0:
109 | #                 col.plot(x_iot, plot_dat_iot[pi], 'y', label='iot')
110 | #             if not max(plot_dat_dyn[pi]) == 0:
111 | #                 col.plot(x_dyn, plot_dat_dyn[pi], 'r', label='dyn')
112 | #             col.set_title(title_text[pi])
113 | #             col.set_xlabel(x_text[pi])
114 | #             col.set_ylabel(y_text[pi])
115 | #             if xinverse == True: col.invert_xaxis()
116 | #             if xlim_vec is not None: col.set_xlim(0, xlim_vec[pi])
117 | #             if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi])
118 | #             col.legend()
119 | #             pi += 1
120 | #     figs.tight_layout()
121 | #     plt.savefig('results/' + filename + '.pdf')
122 | 
123 | 
124 | def plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False):
125 |     figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(9,2.5))
126 |     plt.subplots_adjust(left=0.1, right=0.85, wspace=0.9, bottom=0.15)
127 |     pi = 0
128 | 
129 |     min_value = min(x_dyn + x_ap + x_iot)
130 |     min_value = round(min_value, 1)
131 |     max_value = max(x_dyn + x_ap + x_iot)
132 | 
133 |     title_text = ['Workload on Tiny End Devices', 'Workload on Access Points']
134 |     label_ap = 'Access point training'
135 |     label_iot = 'Tiny end device training'
136 |     label_dyn = 'Partition-based training'
137 | 
138 |     for col in axes:
139 |         if not max(plot_dat_ap[pi]) == 0:  # skip all-zero series
140 |             col.plot(x_ap, plot_dat_ap[pi], 'g', label=label_ap)
141 |         if not max(plot_dat_iot[pi]) == 0:
142 |             col.plot(x_iot, plot_dat_iot[pi], 'y', label=label_iot)
143 |         if not max(plot_dat_dyn[pi]) == 0:
144 |             col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=label_dyn)
145 |         col.set_title(title_text[pi])
146 |         col.set_xlabel(x_text[pi])
147 |         col.set_ylabel(y_text[pi])
148 | 
149 | 
150 |         if xinverse == True: col.invert_xaxis()
151 |         if xlim_vec is not None: col.set_xlim(0, xlim_vec[pi])
152 |         if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi])
153 |         col.yaxis.set_major_formatter(FuncFormatter(y_fmt))
154 | 
155 |         col.legend()
156 | 
157 |         col_tx = col.twinx()   # second y-axis (latency)
158 |         col_tx2 = col.twinx()  # third y-axis (energy)
159 | 
160 |         if not max(plot_dat_ap[pi+2]) == 0:
161 |             col_tx.plot(x_ap, plot_dat_ap[pi+2], 'g', label=label_ap)
162 |         if not max(plot_dat_iot[pi+2]) == 0:
163 |             col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=label_iot)
164 |         if not max(plot_dat_dyn[pi+2]) == 0:
165 |             col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=label_dyn)
166 |         col_tx.set_ylabel(y_text[pi+2])
167 |         col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt))
168 | 
169 |         if not max(plot_dat_ap[pi+4]) == 0:
170 |             col_tx2.plot(x_ap, plot_dat_ap[pi+4], 'g', label=label_ap)
171 |         if not max(plot_dat_iot[pi+4]) == 0:
172 |             col_tx2.plot(x_iot, plot_dat_iot[pi+4], 'y', label=label_iot)
173 |         if not max(plot_dat_dyn[pi+4]) == 0:
174 |             col_tx2.plot(x_dyn, plot_dat_dyn[pi+4], 'r', label=label_dyn)
175 |         col_tx2.set_ylabel(y_text[pi+4])
176 |         col_tx2.yaxis.set_major_formatter(FuncFormatter(y_fmt))
177 | 
178 |         col_tx2.spines['right'].set_position(('outward', 45))  # offset the third y-axis so it does not overlap the second
179 | 
180 |         pi += 1
181 | 
182 |     # figs.tight_layout()
183 |     plt.savefig('results/' + filename + '.pdf')
184 | 
185 | 
186 | def plot_12_2y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, x_dyn, x_ap, x_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=False):
187 |     figs, axes = plt.subplots(nrows=1, ncols=2, figsize=(8,2.8))
188 |     plt.subplots_adjust(left=0.1, right=0.85, wspace=0.8, bottom=0.15)
189 | 
190 |     pi = 0
191 | 
192 |     min_value = min(x_dyn + x_ap + x_iot)
193 |     min_value = round(min_value, 1)
194 |     if min_value < 0.35: min_value = 0.35
195 |     max_value = max(x_dyn + x_ap + x_iot)
196 | 
197 |     title_text = ['Workload on Tiny End Devices', 'Workload on Access Points']
198 |     label_ap = 'Access point training'
199 |     label_iot = 'Tiny end device training'
200 |     label_dyn = 'Partition-based training'
201 | 
202 |     linewidth = 1
203 | 
204 |     for col in axes:
205 |         if not max(plot_dat_ap[pi]) <= 1000:  # only plot series whose magnitude exceeds 1000
206 |             col.plot(x_ap, plot_dat_ap[pi], 'limegreen', label=label_ap, linewidth=linewidth)
207 |         if not max(plot_dat_iot[pi]) <= 1000:
208 |             col.plot(x_iot, plot_dat_iot[pi], 'y', label=label_iot, linewidth=linewidth)
209 |         if not max(plot_dat_dyn[pi]) <= 1000:
210 |             col.plot(x_dyn, plot_dat_dyn[pi], 'r', label=label_dyn, linewidth=linewidth)
211 |         col.set_title(title_text[pi])
212 |         col.set_xlabel(x_text[pi])
213 |         col.set_ylabel(y_text[pi])
214 | 
215 | 
216 |         if xinverse == True: col.invert_xaxis()
217 |         if pi > 0: col.set_xlim(max_value+0.05, min_value)
218 |         if ylim_vec is not None: col.set_ylim(0, ylim_vec[pi])
219 |         col.yaxis.set_major_formatter(FuncFormatter(y_fmt))
220 | 
221 |         col_tx = col.twinx()
222 | 
223 |         if not max(plot_dat_ap[pi+2]) <= 1000:
224 |             col_tx.plot(x_ap, plot_dat_ap[pi+2], 'g', label=label_ap, linewidth=linewidth)
225 |         if not max(plot_dat_iot[pi+2]) <= 1000:
226 |             col_tx.plot(x_iot, plot_dat_iot[pi+2], 'y', label=label_iot, linewidth=linewidth)
227 |         if not max(plot_dat_dyn[pi+2]) <= 1000:
228 |             col_tx.plot(x_dyn, plot_dat_dyn[pi+2], 'r', label=label_dyn, linewidth=linewidth)
229 |         if pi > 0:
230 |             col_tx.set_ylabel(y_text[pi+2], color='g')
231 |         else:
232 |             col_tx.set_ylabel(y_text[pi+2])
233 |         col_tx.yaxis.set_major_formatter(FuncFormatter(y_fmt))
234 | 
235 |         col_tx.legend()
236 | 
237 | pi += 1 238 | 239 | # figs.tight_layout() 240 | plt.savefig('results/' + filename + '.pdf') 241 | 242 | 243 | # def plot_mac_latency_xround(iot_freq, ap_freq, count_inference=True): 244 | # plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = load_mac_latency(iot_freq, ap_freq) 245 | 246 | # # plot mac 247 | # title_text = ['on AP','on IoT','on AP','on IoT'] 248 | # x_text = ['round'] * 4 249 | # y_text = ['Accumulated MAC operations (Million)', 'Accumulated MAC operations (Million)', 250 | # 'Accumulated Latency (Seconds)', 'Accumulated Latency (Seconds)'] 251 | 252 | # ylim1 = max(plot_dat_ap[0] + plot_dat_iot[0] + plot_dat_dyn[0] + plot_dat_ap[1] + plot_dat_iot[1] + plot_dat_dyn[1]) 253 | # ylim2 = max(plot_dat_ap[2] + plot_dat_iot[2] + plot_dat_dyn[2] + plot_dat_ap[3] + plot_dat_iot[3] + plot_dat_dyn[3]) 254 | # ylim_vec = [ylim1,ylim1,ylim2,ylim2] 255 | 256 | # xlim_vec = [50] * 4 257 | 258 | # filename = 'mac_latency' 259 | # plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename) 260 | 261 | 262 | 263 | 264 | 265 | def plot_mac_latency_energy_xaccuracy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference=True): 266 | 267 | dat_full, dat_clas, dat_dyn = load_acc() 268 | acc_ap = dat_full['accuracy'].tolist()[1:] 269 | acc_iot = dat_clas['accuracy'].tolist()[1:] 270 | acc_dyn = dat_dyn['accuracy'].tolist()[1:] 271 | 272 | plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_mac_latency_energy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac) 273 | 274 | # plot mac 275 | x_text = ['Test Accuracy'] * 2 276 | y_text = ['Accumulated MAC operations (Million)'] * 2 277 | y_text += ['Accumulated Latency (Seconds)'] * 2 278 | y_text += ['Accumulated Energy (Joule)'] * 2 279 | 280 | xlim_vec, ylim_vec = None, None 281 | 282 | filename = 'mac_latency_energy_accuracy' 283 | plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True) 284 | 285 | 286 | 287 | 288 | # def plot_comm_latency_xround(up_iot, down_iot, up_ap, down_ap, sample_size): 289 | # plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot = load_comm_latency(up_iot, down_iot, up_ap, down_ap, sample_size) 290 | 291 | # # plot mac 292 | # figs, axes = plt.subplots(nrows=2, ncols=2) 293 | # title_text = ['on AP','on IoT','on AP','on IoT'] 294 | # x_text = ['round'] * 4 295 | # y_text = ['Accumulated Comm. Amount (Migabytes)', 'Accumulated Comm. 
Amount (Megabytes)',
296 | #                'Accumulated Latency (Seconds)', 'Accumulated Latency (Seconds)']
297 | 
298 | #     #ylim1 = max(plot_dat_ap[0] + plot_dat_iot[0] + plot_dat_dyn[0] + plot_dat_ap[1] + plot_dat_iot[1] + plot_dat_dyn[1])
299 | #     #ylim2 = max(plot_dat_ap[2] + plot_dat_iot[2] + plot_dat_dyn[2] + plot_dat_ap[3] + plot_dat_iot[3] + plot_dat_dyn[3])
300 | #     #ylim_vec = [ylim1, ylim1, ylim2, ylim2]
301 | #     ylim_vec = None
302 | 
303 | #     xlim_vec = [50] * 4
304 | 
305 | #     filename = 'comm_latency'
306 | #     plot_22(plot_dat_dyn, plot_dat_ap, plot_dat_iot, rounds_dyn, rounds_ap, rounds_iot, title_text, x_text, y_text, xlim_vec, ylim_vec, filename)
307 | 
308 | 
309 | 
310 | def plot_comm_latency_energy_xaccuracy(up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size):
311 |     dat_full, dat_clas, dat_dyn = load_acc()
312 |     acc_ap = dat_full['accuracy'].tolist()[1:]
313 |     acc_iot = dat_clas['accuracy'].tolist()[1:]
314 |     acc_dyn = dat_dyn['accuracy'].tolist()[1:]
315 | 
316 |     plot_dat_dyn, plot_dat_ap, plot_dat_iot, _, _, _ = load_comm_latency_energy(up_iot, down_iot, up_ap, down_ap,
317 |                                                                                 iot_energy_comm, ap_energy_comm, sample_size)
318 | 
319 |     # plot communication amount
320 |     x_text = ['Test Accuracy'] * 2
321 |     y_text = ['Accumulated Comm. Amount (Megabytes)'] * 2
322 |     y_text += ['Accumulated Latency (Seconds)'] * 2
323 |     y_text += ['Accumulated Energy (Joule)'] * 2
324 | 
325 |     xlim_vec, ylim_vec = None, None
326 | 
327 |     filename = 'comm_latency_energy_accuracy'
328 |     plot_12_3y(plot_dat_dyn, plot_dat_ap, plot_dat_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True)
329 | 
330 | 
331 | 
332 | def plot_sum_latency_energy_xaccuracy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac, count_inference,
333 |                                       up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size):
334 | 
335 |     dat_full, dat_clas, dat_dyn = load_acc()
336 |     acc_ap = dat_full['accuracy'].tolist()[1:]
337 |     acc_iot = dat_clas['accuracy'].tolist()[1:]
338 |     acc_dyn = dat_dyn['accuracy'].tolist()[1:]
339 |     mac_dyn, mac_ap, mac_iot, _, _, _ = load_mac_latency_energy(iot_freq, ap_freq, iot_energy_mac, ap_energy_mac)
340 |     comm_dyn, comm_ap, comm_iot, _, _, _ = load_comm_latency_energy(up_iot, down_iot, up_ap, down_ap,
341 |                                                                     iot_energy_comm, ap_energy_comm, sample_size)
342 | 
343 |     sum_dyn, sum_ap, sum_iot = [], [], []
344 |     for k in range(2, len(mac_dyn)):  # sum the latency and energy entries (indices 2-5)
345 |         sum_dyn.append([i+j for i,j in zip(mac_dyn[k], comm_dyn[k])])
346 |         sum_ap.append([i+j for i,j in zip(mac_ap[k], comm_ap[k])])
347 |         sum_iot.append([i+j for i,j in zip(mac_iot[k], comm_iot[k])])
348 | 
349 |     x_text = ['Test Accuracy'] * 2
350 |     y_text = ['Accumulated Latency (Seconds)'] * 2
351 |     y_text += ['Accumulated Energy (Joule)'] * 2
352 | 
353 |     xlim_vec, ylim_vec = None, None
354 | 
355 |     filename = 'sum_latency_energy_accuracy'
356 |     plot_12_2y(sum_dyn, sum_ap, sum_iot, acc_dyn, acc_ap, acc_iot, x_text, y_text, xlim_vec, ylim_vec, filename, xinverse=True)
357 | 
358 | 
359 | 
360 | 
361 | 
362 | 
363 | def res(configs, type, feature):  # look up one feature of a device type in the configuration table
364 |     return float(configs[configs['type'] == type][feature])
365 | 
366 | 
367 | 
368 | 
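# Usage sketch (illustrative): res(cfg, 'iot', 'uplink(Mbit/s)') selects the
# IoT row of client_configurations.csv and returns that column as a float.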
369 | # plot accuracy only
370 | if args.plot_acc:
371 |     plot_acc()
372 | 
373 | 
374 | cfg = pd.read_csv('client_configurations.csv', sep=',')
375 | 
376 | 
377 | # plot mac, latency, energy, accuracy
378 | iot_frequency = res(cfg, 'iot', 'frequency(MHz)')
379 | ap_frequency = res(cfg, 'ap', 'frequency(MHz)')
380 | acceleration = res(cfg, 'ap', 'acceleration')  # how many times faster the accelerator is than the AP CPU
381 | count_inference = True
382 | 
383 | iot_energy_mac = iot_frequency * res(cfg, 'iot', 'power(mWpMHz)')  # MHz times mW/MHz gives mW https://www.prnewswire.com/news-releases/ultra-reliable-arm-cortex-m4f-microcontroller-from-maxim-integrated-offers-industrys-lowest-power-consumption-and-smallest-size-for-industrial-healthcare-and-iot-sensor-applications-301080651.html#:~:text=Lowest%20Power%3A%2040%C2%B5W%2FMHz%20of,5mm%2Dx%2D5mm%20TQFN.
384 | ap_energy_mac = ap_frequency * res(cfg, 'ap', 'power(mWpMHz)')  # MHz times mW/MHz gives mW
385 | # https://www.researchgate.net/figure/Energy-consumption-for-different-parts-of-the-mobile-phone_tbl1_224248235#:~:text=The%20power%20consumption%20of%20wearable,%5B18%5D.%20...
386 | 
387 | if args.plot_mac:
388 |     #plot_mac_latency_xround(iot_frequency, ap_frequency*acceleration, count_inference)
389 |     plot_mac_latency_energy_xaccuracy(iot_frequency, ap_frequency*acceleration, iot_energy_mac, ap_energy_mac, count_inference)
390 | 
391 | 
392 | 
393 | 
394 | # plot communication amount, latency, energy, accuracy
395 | up_iot = res(cfg, 'iot', 'uplink(Mbit/s)')  # BLE5 2 Mbit/s
396 | down_iot = res(cfg, 'iot', 'downlink(Mbit/s)')  # BLE5 2 Mbit/s
397 | up_ap = res(cfg, 'ap', 'uplink(Mbit/s)')  # WIFI 10 Mbit/s
398 | down_ap = res(cfg, 'ap', 'downlink(Mbit/s)')  # WIFI 100 Mbit/s
399 | sample_size = 30  # KiloBytes
400 | 
401 | iot_energy_comm = res(cfg, 'iot', 'energy_comm(W)')  # Watt BLE5
402 | ap_energy_comm = res(cfg, 'ap', 'energy_comm(W)')  # Watt WIFI
403 | 
404 | if args.plot_comm:
405 |     # plot_comm_latency_xround(up_iot, down_iot, up_ap, down_ap, sample_size)
406 |     plot_comm_latency_energy_xaccuracy(up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size)
407 | 
408 | 
409 | 
410 | 
411 | 
412 | if args.plot_sum:
413 |     plot_sum_latency_energy_xaccuracy(iot_frequency, ap_frequency*acceleration, iot_energy_mac, ap_energy_mac, count_inference,
414 |                                       up_iot, down_iot, up_ap, down_ap, iot_energy_comm, ap_energy_comm, sample_size)
--------------------------------------------------------------------------------
/frameworks/centaur/third_party/dataset_partition_flwr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Adap GmbH. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Commonly used functions for generating partitioned datasets.""" 16 | 17 | # pylint: disable=invalid-name 18 | 19 | 20 | from typing import List, Optional, Tuple, Union 21 | 22 | import numpy as np 23 | from numpy.random import BitGenerator, Generator, SeedSequence 24 | 25 | XY = Tuple[np.ndarray, np.ndarray] 26 | XYList = List[XY] 27 | PartitionedDataset = Tuple[XYList, XYList] 28 | 29 | np.random.seed(2020) 30 | 31 | 32 | def float_to_int(i: float) -> int: 33 | """Return float as int but raise if decimal is dropped.""" 34 | if not i.is_integer(): 35 | raise Exception("Cast would drop decimals") 36 | 37 | return int(i) 38 | 39 | 40 | def sort_by_label(x: np.ndarray, y: np.ndarray) -> XY: 41 | """Sort by label. 42 | 43 | Assuming two labels and four examples the resulting label order 44 | would be 1,1,2,2 45 | """ 46 | idx = np.argsort(y, axis=0).reshape((y.shape[0])) 47 | return (x[idx], y[idx]) 48 | 49 | 50 | def sort_by_label_repeating(x: np.ndarray, y: np.ndarray) -> XY: 51 | """Sort by label in repeating groups. Assuming two labels and four examples 52 | the resulting label order would be 1,2,1,2. 53 | 54 | Create sorting index which is applied to by label sorted x, y 55 | 56 | .. code-block:: python 57 | 58 | # given: 59 | y = [ 60 | 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9 61 | ] 62 | 63 | # use: 64 | idx = [ 65 | 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19 66 | ] 67 | 68 | # so that y[idx] becomes: 69 | y = [ 70 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 71 | ] 72 | """ 73 | x, y = sort_by_label(x, y) 74 | 75 | num_example = x.shape[0] 76 | num_class = np.unique(y).shape[0] 77 | idx = ( 78 | np.array(range(num_example), np.int64) 79 | .reshape((num_class, num_example // num_class)) 80 | .transpose() 81 | .reshape(num_example) 82 | ) 83 | 84 | return (x[idx], y[idx]) 85 | 86 | 87 | def split_at_fraction(x: np.ndarray, y: np.ndarray, fraction: float) -> Tuple[XY, XY]: 88 | """Split x, y at a certain fraction.""" 89 | splitting_index = float_to_int(x.shape[0] * fraction) 90 | # Take everything BEFORE splitting_index 91 | x_0, y_0 = x[:splitting_index], y[:splitting_index] 92 | # Take everything AFTER splitting_index 93 | x_1, y_1 = x[splitting_index:], y[splitting_index:] 94 | return (x_0, y_0), (x_1, y_1) 95 | 96 | 97 | def shuffle(x: np.ndarray, y: np.ndarray) -> XY: 98 | """Shuffle x and y.""" 99 | idx = np.random.permutation(len(x)) 100 | return x[idx], y[idx] 101 | 102 | 103 | def partition(x: np.ndarray, y: np.ndarray, num_partitions: int) -> List[XY]: 104 | """Return x, y as list of partitions.""" 105 | return list(zip(np.split(x, num_partitions), np.split(y, num_partitions))) 106 | 107 | 108 | def combine_partitions(xy_list_0: XYList, xy_list_1: XYList) -> XYList: 109 | """Combine two lists of ndarray Tuples into one list.""" 110 | return [ 111 | (np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0)) 112 | for (x_0, y_0), (x_1, y_1) in zip(xy_list_0, xy_list_1) 113 | ] 114 | 115 | 116 | def shift(x: np.ndarray, y: np.ndarray) -> XY: 117 | """Shift x_1, y_1 so that the first half contains only labels 0 to 4 and 118 | the second half 5 to 9.""" 119 | x, y = sort_by_label(x, y) 120 | 121 | (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=0.5) 122 | (x_0, y_0), (x_1, y_1) = shuffle(x_0, y_0), shuffle(x_1, y_1) 123 | x, y = np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0) 124 | 
return x, y
125 | 
126 | 
127 | def create_partitions(
128 |     unpartitioned_dataset: XY,
129 |     iid_fraction: float,
130 |     num_partitions: int,
131 | ) -> XYList:
132 |     """Create partitioned version of a training or test set.
133 | 
134 |     Currently tested and supported are MNIST, FashionMNIST and
135 |     CIFAR-10/100
136 |     """
137 |     x, y = unpartitioned_dataset
138 | 
139 |     x, y = shuffle(x, y)
140 |     x, y = sort_by_label_repeating(x, y)
141 | 
142 |     (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=iid_fraction)
143 | 
144 |     # Shift the classes of the second split into two groups
145 |     x_1, y_1 = shift(x_1, y_1)
146 | 
147 |     xy_0_partitions = partition(x_0, y_0, num_partitions)
148 |     xy_1_partitions = partition(x_1, y_1, num_partitions)
149 | 
150 |     xy_partitions = combine_partitions(xy_0_partitions, xy_1_partitions)
151 | 
152 |     # Adjust x and y shape
153 |     return [adjust_xy_shape(xy) for xy in xy_partitions]
154 | 
155 | 
156 | def create_partitioned_dataset(
157 |     keras_dataset: Tuple[XY, XY],
158 |     iid_fraction: float,
159 |     num_partitions: int,
160 | ) -> Tuple[PartitionedDataset, XY]:
161 |     """Create partitioned version of keras dataset.
162 | 
163 |     Currently tested and supported are MNIST, FashionMNIST and
164 |     CIFAR-10/100
165 |     """
166 |     xy_train, xy_test = keras_dataset
167 | 
168 |     xy_train_partitions = create_partitions(
169 |         unpartitioned_dataset=xy_train,
170 |         iid_fraction=iid_fraction,
171 |         num_partitions=num_partitions,
172 |     )
173 | 
174 |     xy_test_partitions = create_partitions(
175 |         unpartitioned_dataset=xy_test,
176 |         iid_fraction=iid_fraction,
177 |         num_partitions=num_partitions,
178 |     )
179 | 
180 |     return (xy_train_partitions, xy_test_partitions), adjust_xy_shape(xy_test)
181 | 
182 | 
183 | def log_distribution(xy_partitions: XYList) -> None:
184 |     """Print label distribution for list of partitions."""
185 |     distro = [np.unique(y, return_counts=True) for _, y in xy_partitions]
186 |     for d in distro:
187 |         print(d)
188 | 
189 | 
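# Minimal usage sketch for the helpers above (hypothetical data, not part of
# the original file):
#   x, y = np.arange(40).reshape(40, 1), np.repeat(np.arange(4), 10)
#   parts = create_partitions((x, y), iid_fraction=0.5, num_partitions=4)
#   log_distribution(parts)  # prints the label counts per partition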
229 | """ 230 | 231 | if split_idx.ndim != 1: 232 | raise ValueError("Variable `split_idx` must be a 1-D numpy array.") 233 | if split_idx.dtype != np.int64: 234 | raise ValueError("Variable `split_idx` must be of type np.int64.") 235 | if split_idx[0] != 0: 236 | raise ValueError("First value of `split_idx` must be 0.") 237 | if split_idx[-1] >= x.shape[0]: 238 | raise ValueError( 239 | """Last value in `split_idx` must be less than 240 | the number of samples in `x`.""" 241 | ) 242 | if not np.all(split_idx[:-1] <= split_idx[1:]): 243 | raise ValueError("Items in `split_idx` must be in increasing order.") 244 | 245 | num_splits: int = len(split_idx) 246 | split_idx = np.append(split_idx, x.shape[0]) 247 | 248 | list_samples_split: List[List[np.ndarray]] = [[] for _ in range(num_splits)] 249 | for j in range(num_splits): 250 | tmp_x = x[split_idx[j] : split_idx[j + 1]] # noqa: E203 251 | for sample in tmp_x: 252 | list_samples_split[j].append(sample) 253 | 254 | return list_samples_split 255 | 256 | 257 | def exclude_classes_and_normalize( 258 | distribution: np.ndarray, exclude_dims: List[bool], eps: float = 1e-5 259 | ) -> np.ndarray: 260 | """Excludes classes from a distribution. 261 | 262 | This function is particularly useful when sampling without replacement. 263 | Classes for which no sample is available have their probabilities are set to 0. 264 | Classes that had probabilities originally set to 0 are incremented with 265 | `eps` to allow sampling from remaining items. 266 | 267 | Args: 268 | distribution (np.array): Distribution being used. 269 | exclude_dims (List[bool]): Dimensions to be excluded. 270 | eps (float, optional): Small value to be addad to non-excluded dimensions. 271 | Defaults to 1e-5. 272 | 273 | Returns: 274 | np.ndarray: Normalized distributions. 275 | """ 276 | if np.any(distribution < 0) or (not np.isclose(np.sum(distribution), 1.0)): 277 | raise ValueError("distribution must sum to 1 and have only positive values.") 278 | 279 | if distribution.size != len(exclude_dims): 280 | raise ValueError( 281 | """Length of distribution must be equal 282 | to the length `exclude_dims`.""" 283 | ) 284 | if eps < 0: 285 | raise ValueError("""The value of `eps` must be positive and small.""") 286 | 287 | distribution[[not x for x in exclude_dims]] += eps 288 | distribution[exclude_dims] = 0.0 289 | sum_rows = np.sum(distribution) + np.finfo(float).eps 290 | distribution = distribution / sum_rows 291 | 292 | return distribution 293 | 294 | 295 | def sample_without_replacement( 296 | distribution: np.ndarray, 297 | list_samples: List[List[np.ndarray]], 298 | num_samples: int, 299 | empty_classes: List[bool], 300 | ) -> Tuple[XY, List[bool]]: 301 | """Samples from a list without replacement using a given distribution. 302 | 303 | Args: 304 | distribution (np.ndarray): Distribution used for sampling. 305 | list_samples(List[List[np.ndarray]]): List of samples. 306 | num_samples (int): Total number of items to be sampled. 307 | empty_classes (List[bool]): List of booleans indicating which classes are empty. 308 | This is useful to differentiate which classes should still be sampled. 309 | 310 | Returns: 311 | XY: Dataset contaning samples 312 | List[bool]: empty_classes. 
313 | """ 314 | if np.sum([len(x) for x in list_samples]) < num_samples: 315 | raise ValueError( 316 | """Number of samples in `list_samples` is less than `num_samples`""" 317 | ) 318 | 319 | # Make sure empty classes are not sampled 320 | # and solves for rare cases where 321 | if not empty_classes: 322 | empty_classes = len(distribution) * [False] 323 | 324 | distribution = exclude_classes_and_normalize( 325 | distribution=distribution, exclude_dims=empty_classes 326 | ) 327 | 328 | data: List[np.ndarray] = [] 329 | target: List[np.ndarray] = [] 330 | 331 | for _ in range(num_samples): 332 | sample_class = np.where(np.random.multinomial(1, distribution) == 1)[0][0] 333 | sample: np.ndarray = list_samples[sample_class].pop() 334 | 335 | data.append(sample) 336 | target.append(sample_class) 337 | 338 | # If last sample of the class was drawn, then set the 339 | # probability density function (PDF) to zero for that class. 340 | if len(list_samples[sample_class]) == 0: 341 | empty_classes[sample_class] = True 342 | # Be careful to distinguish between classes that had zero probability 343 | # and classes that are now empty 344 | distribution = exclude_classes_and_normalize( 345 | distribution=distribution, exclude_dims=empty_classes 346 | ) 347 | data_array: np.ndarray = np.concatenate([data], axis=0) 348 | target_array: np.ndarray = np.array(target, dtype=np.int64) 349 | 350 | return (data_array, target_array), empty_classes 351 | 352 | 353 | def get_partitions_distributions(partitions: XYList) -> Tuple[np.ndarray, List[int]]: 354 | """Evaluates the distribution over classes for a set of partitions. 355 | 356 | Args: 357 | partitions (XYList): Input partitions 358 | 359 | Returns: 360 | np.ndarray: Distributions of size (num_partitions, num_classes) 361 | """ 362 | # Get largest available label 363 | labels = set() 364 | for _, y in partitions: 365 | labels.update(set(y)) 366 | list_labels = sorted(list(labels)) 367 | bin_edges = np.arange(len(list_labels) + 1) 368 | 369 | # Pre-allocate distributions 370 | distributions = np.zeros((len(partitions), len(list_labels)), dtype=np.float32) 371 | for idx, (_, _y) in enumerate(partitions): 372 | hist, _ = np.histogram(_y, bin_edges) 373 | distributions[idx] = hist / hist.sum() 374 | 375 | return distributions, list_labels 376 | 377 | 378 | def create_lda_partitions( 379 | dataset: XY, 380 | dirichlet_dist: Optional[np.ndarray] = None, 381 | num_partitions: int = 100, 382 | concentration: Union[float, np.ndarray, List[float]] = 0.5, 383 | accept_imbalanced: bool = False, 384 | seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None, 385 | ) -> Tuple[XYList, np.ndarray]: 386 | """Create imbalanced non-iid partitions using Latent Dirichlet Allocation 387 | (LDA) without resampling. 388 | 389 | Args: 390 | dataset (XY): Dataset containing samples X and labels Y. 391 | dirichlet_dist (numpy.ndarray, optional): previously generated distribution to 392 | be used. This is useful when applying the same distribution for train and 393 | validation sets. 394 | num_partitions (int, optional): Number of partitions to be created. 395 | Defaults to 100. 396 | concentration (float, np.ndarray, List[float]): Dirichlet Concentration 397 | (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions. 398 | An :math:`\\alpha \\to \\Inf` generates uniform distributions over classes. 399 | An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5. 
378 | def create_lda_partitions(
379 |     dataset: XY,
380 |     dirichlet_dist: Optional[np.ndarray] = None,
381 |     num_partitions: int = 100,
382 |     concentration: Union[float, np.ndarray, List[float]] = 0.5,
383 |     accept_imbalanced: bool = False,
384 |     seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None,
385 | ) -> Tuple[XYList, np.ndarray]:
386 |     """Create imbalanced non-iid partitions using Latent Dirichlet Allocation
387 |     (LDA) without resampling.
388 | 
389 |     Args:
390 |         dataset (XY): Dataset containing samples X and labels Y.
391 |         dirichlet_dist (numpy.ndarray, optional): previously generated distribution to
392 |             be used. This is useful when applying the same distribution for train and
393 |             validation sets.
394 |         num_partitions (int, optional): Number of partitions to be created.
395 |             Defaults to 100.
396 |         concentration (float, np.ndarray, List[float]): Dirichlet Concentration
397 |             (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions.
398 |             An :math:`\\alpha \\to \\infty` generates uniform distributions over classes.
399 |             An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5.
400 |         accept_imbalanced (bool): Whether or not to accept imbalanced output classes.
401 |             Defaults to False.
402 |         seed (None, int, SeedSequence, BitGenerator, Generator):
403 |             A seed to initialize the BitGenerator for generating the Dirichlet
404 |             distribution. This is defined in Numpy's official documentation as follows:
405 |             If None, then fresh, unpredictable entropy will be pulled from the OS.
406 |             One may also pass in a SeedSequence instance.
407 |             Additionally, when passed a BitGenerator, it will be wrapped by Generator.
408 |             If passed a Generator, it will be returned unaltered.
409 |             See the official Numpy documentation for further details.
410 | 
411 |     Returns:
412 |         Tuple[XYList, numpy.ndarray]: List of XYList containing partitions
413 |             for each dataset and the dirichlet probability density functions.
414 |     """
415 |     # pylint: disable=too-many-arguments,too-many-locals
416 | 
417 |     x, y = dataset
418 |     x, y = shuffle(x, y)
419 |     x, y = sort_by_label(x, y)
420 | 
421 |     if (x.shape[0] % num_partitions) and (not accept_imbalanced):
422 |         raise ValueError(
423 |             """Total number of samples must be a multiple of `num_partitions`.
424 |             If imbalanced classes are allowed, set
425 |             `accept_imbalanced=True`."""
426 |         )
427 | 
428 |     num_samples = num_partitions * [0]
429 |     for j in range(x.shape[0]):
430 |         num_samples[j % num_partitions] += 1
431 | 
432 |     # Get the classes and the index at which each class starts (y is sorted)
433 |     classes, start_indices = np.unique(y, return_index=True)
434 | 
435 |     # Make sure that concentration is np.array and
436 |     # check if concentration is appropriate
437 |     concentration = np.asarray(concentration)
438 | 
439 |     # Check if concentration is Inf, if so create uniform partitions
440 |     partitions: List[XY] = [(_, _) for _ in range(num_partitions)]  # placeholder list, filled below
441 |     if float("inf") in concentration:
442 | 
443 |         partitions = create_partitions(
444 |             unpartitioned_dataset=(x, y),
445 |             iid_fraction=1.0,
446 |             num_partitions=num_partitions,
447 |         )
448 |         dirichlet_dist = get_partitions_distributions(partitions)[0]
449 | 
450 |         return partitions, dirichlet_dist
451 | 
452 |     if concentration.size == 1:
453 |         concentration = np.repeat(concentration, classes.size)
454 |     elif concentration.size != classes.size:  # Sequence
455 |         raise ValueError(
456 |             f"The size of the provided concentration ({concentration.size}) "
457 |             f"must be either 1 or equal to the number of classes ({classes.size})"
458 |         )
459 | 
460 |     # Split into list of list of samples per class
461 |     list_samples_per_class: List[List[np.ndarray]] = split_array_at_indices(
462 |         x, start_indices
463 |     )
464 | 
465 |     if dirichlet_dist is None:
466 |         dirichlet_dist = np.random.default_rng(seed).dirichlet(
467 |             alpha=concentration, size=num_partitions
468 |         )
469 | 
470 |     if dirichlet_dist.size != 0:
471 |         if dirichlet_dist.shape != (num_partitions, classes.size):
472 |             raise ValueError(
473 |                 f"""The shape of the provided dirichlet distribution
474 |                 ({dirichlet_dist.shape}) must match the provided number
475 |                 of partitions and classes ({num_partitions},{classes.size})"""
476 |             )
477 | 
478 |     # Assuming balanced distribution
479 |     empty_classes = classes.size * [False]
480 |     for partition_id in range(num_partitions):
481 |         partitions[partition_id], empty_classes = sample_without_replacement(
482 |             distribution=dirichlet_dist[partition_id].copy(),
483 |             list_samples=list_samples_per_class,
484 |             num_samples=num_samples[partition_id],
485 |             empty_classes=empty_classes,
486 |         )
487 | 
488 |     return partitions, dirichlet_dist
489 | 
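# Minimal usage sketch for create_lda_partitions (hypothetical data, appended
# for illustration only):
#   x = np.random.default_rng(0).normal(size=(1000, 32, 32, 3))
#   y = np.repeat(np.arange(10), 100)
#   parts, dist = create_lda_partitions((x, y), num_partitions=10,
#                                       concentration=0.5, seed=0)
#   log_distribution(parts)  # 100 samples per partition, skewed by alpha=0.5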
--------------------------------------------------------------------------------