├── HAR_CNN.py
├── HAR_individuals.py
├── README.md
├── cluster_people_selection
│   ├── DBSCAN.py
│   ├── GMM.py
│   ├── silhouette_DBSCAN.py
│   ├── silhouette_GMM.py
│   ├── stratified_har.py
│   └── stratified_mit.py
├── main_DBSCAN_mit.py
├── main_GMM_mit.py
├── main_har.py
├── main_two_layer_har_DBSCAN.py
├── main_two_layer_mit_GMM.py
├── mit_all
│   ├── bash.exe.stackdump
│   ├── client0.py
│   ├── client1.py
│   ├── client10.py
│   ├── client11.py
│   ├── client12.py
│   ├── client13.py
│   ├── client14.py
│   ├── client15.py
│   ├── client16.py
│   ├── client17.py
│   ├── client18.py
│   ├── client19.py
│   ├── client2.py
│   ├── client20.py
│   ├── client21.py
│   ├── client22.py
│   ├── client23.py
│   ├── client24.py
│   ├── client25.py
│   ├── client26.py
│   ├── client27.py
│   ├── client28.py
│   ├── client29.py
│   ├── client3.py
│   ├── client30.py
│   ├── client31.py
│   ├── client32.py
│   ├── client33.py
│   ├── client34.py
│   ├── client35.py
│   ├── client36.py
│   ├── client37.py
│   ├── client38.py
│   ├── client39.py
│   ├── client4.py
│   ├── client40.py
│   ├── client41.py
│   ├── client42.py
│   ├── client43.py
│   ├── client44.py
│   ├── client45.py
│   ├── client5.py
│   ├── client6.py
│   ├── client7.py
│   ├── client8.py
│   ├── client9.py
│   ├── mintty.exe.stackdump
│   └── run.sh
├── read_HAR.py
├── read_MIT.py
├── server.py
├── stratified_age_sexy_two_layer_mit_GMM.py
├── template_har.py
└── template_mit.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Kevin I-Kai Wang, Xiaozhou Ye, and Kouichi Sakurai, "Federated Learning with Clustering-Based Participant Selection for IoT Applications." 2022 IEEE International Conference on Big Data (Big Data), IEEE, 2022.
2 | https://ieeexplore.ieee.org/abstract/document/10020575
3 | 
4 | Xiaokang Zhou, Xiaozhou Ye, et al., "Hierarchical Federated Learning with Social Context Clustering-Based Participant Selection for Internet of Medical Things Applications." IEEE Transactions on Computational Social Systems (2023).
5 | https://ieeexplore.ieee.org/abstract/document/10091843
6 | 
--------------------------------------------------------------------------------
/cluster_people_selection/DBSCAN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.cluster import DBSCAN
3 | from sklearn import metrics
4 | from sklearn.datasets import make_blobs
5 | from sklearn.preprocessing import StandardScaler
6 | from csv import reader
7 | 
8 | # read data (the MIT and STT loaders below are kept for reference; HAR is the active dataset)
9 | '''
10 | # read MIT
11 | file_name = 'mit.txt'
12 | with open(file_name, 'r') as raw_data:
13 |     readers = reader(raw_data, delimiter=',')
14 |     x = list(readers)
15 |     data = np.array(x).astype('float')
16 | print(data.shape)
17 | data = data[:, 1:]
18 | 
19 | 
20 | 
21 | # read STT
22 | file_name = 'stt.txt'
23 | with open(file_name, 'r') as raw_data:
24 |     readers = reader(raw_data, delimiter=',')
25 |     x = list(readers)
26 |     data = np.array(x)
27 | print(data.shape)
28 | data = data[:, 1:]
29 | data = data.astype('float')
30 | '''
31 | 
32 | 
33 | # read HAR
34 | file_name = 'har.txt'
35 | with open(file_name, 'r') as raw_data:
36 |     readers = reader(raw_data, delimiter=',')
37 |     x = list(readers)
38 |     data = np.array(x).astype('float')
39 | print(data.shape)
40 | data = data[:, 1:]
41 | 
42 | 
43 | 
44 | # DBSCAN clustering over all HAR participants
45 | data = StandardScaler().fit_transform(data)
46 | db = DBSCAN(eps=1.1, min_samples=2).fit(data)  # eps=0.68 gives four clusters on HAR; eps=0.1 gives eight clusters on MIT
47 | core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
48 | core_samples_mask[db.core_sample_indices_] = True
49 | labels = db.labels_
50 | print(labels)
51 | # Number of clusters in labels, ignoring noise if present.
52 | n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
53 | n_noise_ = list(labels).count(-1)
54 | 
55 | print("Estimated number of clusters: %d" % n_clusters_)
56 | print("Estimated number of noise points: %d" % n_noise_)
57 | print("Silhouette Coefficient: %0.3f" % metrics.silhouette_score(data, labels))
58 | 
59 | import matplotlib.pyplot as plt
60 | 
61 | # Black removed and is used for noise instead.
62 | unique_labels = set(labels)
63 | colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
64 | for k, col in zip(unique_labels, colors):
65 |     if k == -1:
66 |         # Black used for noise.
67 |         col = [0, 0, 0, 1]
68 | 
69 |     class_member_mask = labels == k
70 | 
71 |     xy = data[class_member_mask & core_samples_mask]
72 |     plt.plot(
73 |         xy[:, 0],
74 |         xy[:, 1],
75 |         "o",
76 |         markerfacecolor=tuple(col),
77 |         markeredgecolor="k",
78 |         markersize=14,
79 |     )
80 | 
81 |     xy = data[class_member_mask & ~core_samples_mask]
82 |     plt.plot(
83 |         xy[:, 0],
84 |         xy[:, 1],
85 |         "o",
86 |         markerfacecolor=tuple(col),
87 |         markeredgecolor="k",
88 |         markersize=6,
89 |     )
90 | 
91 | plt.title("Estimated number of clusters: %d" % n_clusters_)
92 | plt.show()
93 | 
--------------------------------------------------------------------------------
/cluster_people_selection/GMM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.cluster import DBSCAN
3 | from sklearn import metrics, mixture
4 | from sklearn.datasets import make_blobs
5 | from sklearn.preprocessing import StandardScaler
6 | from csv import reader
7 | 
8 | # read data
9 | 
10 | '''
11 | # read MIT
12 | file_name = 'mit.txt'
13 | with open(file_name, 'r') as raw_data:
14 |     readers = reader(raw_data, delimiter=',')
15 |     x = list(readers)
16 |     data = np.array(x).astype('float')
17 | print(data.shape)
18 | data = data[:, 1:]
19 | '''
20 | 
21 | '''
22 | # read STT
23 | file_name = 'stt.txt'
24 | with open(file_name, 'r') as raw_data:
25 |     readers = reader(raw_data, delimiter=',')
26 |     x = list(readers)
27 |     data = np.array(x)
28 | print(data.shape)
29 | data = data[:, 1:]
30 | data = data.astype('float')
31 | '''
32 | 
33 | #'''
34 | # read HAR
35 | file_name = 'har.txt'
36 | with open(file_name, 'r') as raw_data:
37 |     readers = reader(raw_data, delimiter=',')
38 |     x = list(readers)
39 |     data = np.array(x).astype('float')
40 | print(data.shape)
41 | data = data[:, 1:]
42 | #'''
43 | 
44 | # clustering GMM
45 | data = StandardScaler().fit_transform(data)
46 | 
47 | # model selection over the candidate component counts in slist (BIC tracked for reference)
48 | bicr = 10000000000
49 | c = 0
50 | slist = [5, 6]
51 | gmm = mixture.GaussianMixture(n_components=slist[c], covariance_type='diag')
52 | gmm.fit(data)
53 | labels = gmm.predict(data)  # the model is already fitted; predict avoids a redundant refit
54 | print(labels)
55 | print(str(slist[c]) + ':' + str(bicr))
56 | while c < len(slist) - 1:  # and (gmm.bic(data) < bicr); "- 1" keeps slist[c] in range below
57 |     c += 1
58 |     bicr = gmm.bic(data)
59 |     gmm = mixture.GaussianMixture(n_components=slist[c], covariance_type='diag')
60 |     gmm.fit(data)
61 |     labels = gmm.predict(data)
62 |     print(labels)
63 |     print(str(slist[c]) + ':' + str(bicr))
64 | 
65 | c = c - 1
66 | gmm = mixture.GaussianMixture(n_components=slist[c], covariance_type='diag')
67 | gmm.fit(data)
68 | labels = gmm.predict(data)
69 | print(labels)
70 | 
--------------------------------------------------------------------------------
/cluster_people_selection/silhouette_DBSCAN.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | from sklearn.cluster import DBSCAN
5 | from sklearn import metrics, mixture
6 | from sklearn.datasets import make_blobs
7 | from sklearn.metrics import silhouette_score
8 | from sklearn.preprocessing import StandardScaler
9 | from csv import reader
10 | 
11 | # read data
12 | 
13 | #'''
14 | # read MIT
15 | file_name = 'mit.txt'
16 | with open(file_name, 'r') as raw_data:
17 |     readers = reader(raw_data, delimiter=',')
18 |     x = list(readers)
19 |     data = np.array(x).astype('float')
20 | print(data.shape)
21 | data = data[:, 1:]
22 | #'''
23 | 
24 | '''
25 | # read STT
26 | file_name = 'stt.txt'
27 | with open(file_name, 'r') as raw_data:
28 |
readers = reader(raw_data, delimiter=',') 29 | x = list(readers) 30 | data = np.array(x) 31 | print(data.shape) 32 | data = data[:, 1:] 33 | data = data.astype('float') 34 | ''' 35 | 36 | ''' 37 | # read HAR 38 | file_name = 'har.txt' 39 | with open(file_name, 'r') as raw_data: 40 | readers = reader(raw_data, delimiter=',') 41 | x = list(readers) 42 | data = np.array(x).astype('float') 43 | print(data.shape) 44 | data = data[:, 1:] 45 | ''' 46 | 47 | # clustering GMM 48 | data = StandardScaler().fit_transform(data) 49 | 50 | all_list = [i for i in range(len(data))] 51 | random.seed(12) 52 | a = random.sample(range(0, len(data)), int(len(data) / 2)) 53 | b = [i for i in all_list if i not in a] 54 | print() 55 | 56 | eps_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] 57 | min_samples_list = [1, 2, 3, 4] 58 | 59 | for i in eps_list: 60 | for j in min_samples_list: 61 | try: 62 | aa = data[a] 63 | db = DBSCAN(eps=i, min_samples=j).fit(aa) # 0.68 four clusters HAR # 0.1 eight clusters MIT 64 | core_samples_mask = np.zeros_like(db.labels_, dtype=bool) 65 | core_samples_mask[db.core_sample_indices_] = True 66 | labels = db.labels_ 67 | print('eps:'+str(i)+'__min_samples:'+str(j)) 68 | print(a) 69 | print(labels) 70 | a_quality = silhouette_score(aa, labels) 71 | print(str(a_quality)) 72 | print('-------------------------------------') 73 | except: 74 | print(str(i) + '__' + str(j)) 75 | print("can not even has one cluster") 76 | print('-------------------------------------') 77 | 78 | for i in eps_list: 79 | for j in min_samples_list: 80 | try: 81 | bb = data[b] 82 | db = DBSCAN(eps=i, min_samples=j).fit(bb) 83 | core_samples_mask = np.zeros_like(db.labels_, dtype=bool) 84 | core_samples_mask[db.core_sample_indices_] = True 85 | labels = db.labels_ 86 | print('eps:'+str(i)+'__min_samples:'+str(j)) 87 | print(b) 88 | print(labels) 89 | b_quality = silhouette_score(bb, labels) 90 | print(str(b_quality)) 91 | print('-------------------------------------') 92 | print('-------------------------------------') 93 | except: 94 | print(str(i) + '__' + str(j)) 95 | print("can not even has one cluster") 96 | print('-------------------------------------') 97 | 98 | ''' 99 | har two layer, each 3 clusters 100 | 101 | [15, 8, 21, 16, 11, 4, 12, 0, 19, 7, 18, 10] 102 | [-1 0 0 1 -1 2 -1 -1 1 2 -1 1] 103 | silhouette_score = 0.2887524201946839 104 | ------------------------------------- 105 | [1, 2, 3, 5, 6, 9, 13, 14, 17, 20, 22, 23] 106 | [ 0 1 -1 0 -1 -1 2 2 1 -1 -1 -1] 107 | silhouette_score = 0.2037482842069448 108 | ''' 109 | 110 | ''' 111 | mit two layer, each 5 clusters 112 | 113 | [30, 17, 42, 33, 22, 9, 24, 0, 23, 45, 44, 29, 14, 38, 40, 43, 11, 5, 10, 6, 1, 18, 26] 114 | [-1 -1 0 1 1 -1 2 2 2 -1 1 1 2 3 3 0 4 4 1 1 0 3 1] 115 | silhouette_score = 0.494405758587215 116 | ------------------------------------- 117 | [2, 3, 4, 7, 8, 12, 13, 15, 16, 19, 20, 21, 25, 27, 28, 31, 32, 34, 35, 36, 37, 39, 41] 118 | [ 0 1 1 0 2 1 3 2 2 4 1 2 4 0 3 3 2 2 2 0 2 2 -1] 119 | silhouette_score = 0.6186597930397515 120 | 121 | ''' 122 | -------------------------------------------------------------------------------- /cluster_people_selection/silhouette_GMM.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from sklearn.cluster import DBSCAN 5 | from sklearn import metrics, mixture 6 | from sklearn.datasets import make_blobs 7 | from sklearn.metrics import silhouette_score 8 | 
from sklearn.preprocessing import StandardScaler 9 | from csv import reader 10 | 11 | # read data 12 | 13 | ''' 14 | # read MIT 15 | file_name = 'mit.txt' 16 | with open(file_name, 'r') as raw_data: 17 | readers = reader(raw_data, delimiter=',') 18 | x = list(readers) 19 | data = np.array(x).astype('float') 20 | print(data.shape) 21 | data = data[:, 1: -1] 22 | ''' 23 | 24 | ''' 25 | # read STT 26 | file_name = 'stt.txt' 27 | with open(file_name, 'r') as raw_data: 28 | readers = reader(raw_data, delimiter=',') 29 | x = list(readers) 30 | data = np.array(x) 31 | print(data.shape) 32 | data = data[:, 1:] 33 | data = data.astype('float') 34 | ''' 35 | 36 | 37 | # read HAR 38 | file_name = 'har.txt' 39 | with open(file_name, 'r') as raw_data: 40 | readers = reader(raw_data, delimiter=',') 41 | x = list(readers) 42 | data = np.array(x).astype('float') 43 | print(data.shape) 44 | data = data[:, 1: -1] 45 | 46 | 47 | # clustering GMM 48 | data = StandardScaler().fit_transform(data) 49 | 50 | all_list = [i for i in range(len(data))] 51 | random.seed(12) 52 | a = random.sample(range(0, len(data)), int(len(data) / 2)) 53 | b = [i for i in all_list if i not in a] 54 | print() 55 | 56 | # har stratified clustering 57 | a = [15,10,7,2,12,8,21,1,16,11,22] #[27,5,38,34,16,14,17,18,13,33,21,4,32,15,11,44,0,28,24,12,1,29,45] 58 | b = [5,6,17,18,4,3,20,19,13,14] #[19,22,30,35,41,20,36,10,31,25,42,37,39,40,6,3,23,7,43,9,26,2,8] 59 | 60 | ''' 61 | # MIT stratified clustering 62 | a = [5,14,16,1,45,18,39,19,44,28,0,22,2,37,24,10,20,12,4] #[27,5,38,34,16,14,17,18,13,33,21,4,32,15,11,44,0,28,24,12,1,29,45] 63 | b = [34,3,41,40,25,38,32,29,6,23,35,21,31,9,26,36,8,17,42] #[19,22,30,35,41,20,36,10,31,25,42,37,39,40,6,3,23,7,43,9,26,2,8] 64 | print() 65 | ''' 66 | 67 | grid_search = [3, 4, 5, 6, 7, 8] 68 | 69 | 70 | for i in grid_search: 71 | gmm = mixture.GaussianMixture(n_components=i) 72 | aa = data[a] 73 | gmm.fit(aa) 74 | labels = gmm.fit_predict(aa) 75 | print('num_clusters:'+str(i)) 76 | print(a) 77 | print(labels) 78 | a_quality = silhouette_score(aa, labels) 79 | print(str(a_quality)) 80 | print('-------------------------------------') 81 | 82 | 83 | for i in grid_search: 84 | gmm = mixture.GaussianMixture(n_components=i) 85 | bb = data[b] 86 | gmm.fit(bb) 87 | labels = gmm.fit_predict(bb) 88 | print('num_clusters:' + str(i)) 89 | print(b) 90 | print(labels) 91 | b_quality = silhouette_score(bb, labels) 92 | print(str(b_quality)) 93 | print('-------------------------------------') 94 | print('-------------------------------------') 95 | 96 | 97 | 98 | ''' 99 | har two layer, each 4 clusters 100 | 101 | [17, 9, 3, 19, 15, 22, 13, 20, 12, 0, 14, 4] 102 | [1 1 0 0 0 1 3 1 3 2 3 1] 103 | silhouette_score = 0.35087137647724703 104 | 105 | ------------------------------------- 106 | [1, 2, 5, 6, 7, 8, 10, 11, 16, 18, 21, 23] 107 | [1 0 1 0 0 3 1 1 1 0 3 2] 108 | silhouette_score = 0.4033872832527085 109 | ''' 110 | 111 | ''' 112 | mit two layer, each 7 clusters 113 | 114 | [7, 14, 24, 12, 45, 18, 25, 9, 11, 21, 38, 23, 41, 10, 4, 42, 22, 5, 36, 35, 1, 34, 6] 115 | [5 0 0 1 4 5 3 4 2 6 5 0 3 7 1 1 0 2 5 0 1 6 0] 116 | silhouette_score = 0.6948058712307382 117 | ------------------------------------- 118 | [0, 2, 3, 8, 13, 15, 16, 17, 19, 20, 26, 27, 28, 29, 30, 31, 32, 33, 37, 39, 40, 43, 44] 119 | [5 2 6 1 4 5 5 0 7 6 1 2 4 1 3 4 1 7 3 5 2 2 1] 120 | silhouette_score = 0.5769560143891007 121 | ''' 122 | -------------------------------------------------------------------------------- 
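The two silhouette scripts above print a score for every hyperparameter combination and leave the final pick to manual inspection (the commented blocks record the chosen results). Below is a minimal sketch of how that selection could be automated; it is illustrative only and not part of this repository: it substitutes make_blobs data for the har.txt / mit.txt participant features and simply keeps the configuration with the highest silhouette score, skipping DBSCAN settings whose labelling cannot be scored.

import numpy as np
from sklearn.cluster import DBSCAN
from sklearn import mixture
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

# Stand-in for the per-participant feature files (har.txt / mit.txt) read above.
data, _ = make_blobs(n_samples=40, centers=4, n_features=5, random_state=12)
data = StandardScaler().fit_transform(data)

# DBSCAN: keep the (eps, min_samples) pair with the highest silhouette score.
best_dbscan, best_dbscan_score = None, -1.0
for eps in np.arange(0.1, 2.1, 0.1):
    for min_samples in [1, 2, 3, 4]:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(data)
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        # silhouette_score requires 2 <= n_labels <= n_samples - 1
        if n_clusters < 2 or len(set(labels)) >= len(data):
            continue
        score = silhouette_score(data, labels)
        if score > best_dbscan_score:
            best_dbscan, best_dbscan_score = (round(float(eps), 1), min_samples), score

# GMM: keep the component count with the highest silhouette score.
best_k, best_k_score = None, -1.0
for k in [3, 4, 5, 6, 7, 8]:
    gmm = mixture.GaussianMixture(n_components=k, random_state=12).fit(data)
    labels = gmm.predict(data)
    if len(set(labels)) < 2:
        continue
    score = silhouette_score(data, labels)
    if score > best_k_score:
        best_k, best_k_score = k, score

print('best DBSCAN params:', best_dbscan, 'silhouette:', best_dbscan_score)
print('best GMM n_components:', best_k, 'silhouette:', best_k_score)

On the real participant features, data would instead be loaded with the csv reader pattern used throughout cluster_people_selection, and the winning configuration would feed the cluster lists consumed by the main_* generator scripts below.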
/cluster_people_selection/stratified_har.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import DBSCAN 3 | from sklearn import metrics, mixture, model_selection 4 | from sklearn.datasets import make_blobs 5 | from sklearn.preprocessing import StandardScaler 6 | from csv import reader 7 | 8 | 9 | # read HAR 10 | file_name = 'har.txt' 11 | with open(file_name, 'r') as raw_data: 12 | readers = reader(raw_data, delimiter=',') 13 | x = list(readers) 14 | data = np.array(x).astype('float') 15 | print(data.shape) 16 | XX = data[:, 1:-1] 17 | YY = data[:, -1:] 18 | 19 | 20 | X_train, X_test, y_train, y_test, = model_selection.train_test_split(XX, YY, test_size=0.5, random_state=42, stratify=YY) 21 | 22 | indices = np.arange(len(data)) 23 | ( 24 | data_train, 25 | data_test, 26 | labels_train, 27 | labels_test, 28 | indices_train, 29 | indices_test, 30 | ) = model_selection.train_test_split(XX, YY, indices, test_size=0.5, random_state=24, stratify=YY) 31 | 32 | 33 | print() 34 | 35 | ''' 36 | # stratified with age + sexy 37 | 38 | [15, 10, 7, 2, 12, 8, 21, 1, 16, 11, 22] 39 | [1, 0, 1, 1, 0, 2, 2, 0, 0, 0, 1] 40 | 0.4147932634167539 41 | 42 | [5, 6, 17, 18, 4, 3, 20, 19, 13, 14] 43 | [2, 0, 0, 0, 0, 2, 0, 2, 1, 1] 44 | 0.37662394713054786 45 | 46 | 47 | 48 | 49 | 50 | ''' 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | ''' 62 | num_clusters:4 63 | [27,5,38,34,16,14,17,18,13,33,21,4,32,15,11,44,0,28,24,12,1,29,45] 64 | [0 2 0 1 1 1 3 0 3 1 1 0 1 1 2 1 1 2 1 0 0 1 3] 65 | 0.7253862639015771 66 | ''' 67 | 68 | ''' 69 | num_clusters:4 70 | [19,22,30,35,41,20,36,10,31,25,42,37,39,40,6,3,23,7,43,9,26,2,8] 71 | [0 0 3 0 0 1 1 0 2 0 1 3 0 1 0 1 0 1 1 2 0 1 0] 72 | 0.5562913782929296 73 | ''' -------------------------------------------------------------------------------- /cluster_people_selection/stratified_mit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import DBSCAN 3 | from sklearn import metrics, mixture, model_selection 4 | from sklearn.datasets import make_blobs 5 | from sklearn.preprocessing import StandardScaler 6 | from csv import reader 7 | 8 | # read MIT 9 | file_name = 'mit.txt' 10 | with open(file_name, 'r') as raw_data: 11 | readers = reader(raw_data, delimiter=',') 12 | x = list(readers) 13 | data = np.array(x).astype('float') 14 | print(data.shape) 15 | XX = data[:, 1:-1] 16 | YY = data[:, -1:] 17 | 18 | X_train, X_test, y_train, y_test, = model_selection.train_test_split(XX, YY, test_size=0.5, random_state=42, stratify=YY) 19 | 20 | indices = np.arange(len(data)) 21 | ( 22 | data_train, 23 | data_test, 24 | labels_train, 25 | labels_test, 26 | indices_train, 27 | indices_test, 28 | ) = model_selection.train_test_split(XX, YY, indices, test_size=0.5, random_state=1024, stratify=YY) 29 | 30 | 31 | print() 32 | 33 | ''' 34 | # stratified with age + sexy 35 | [5,14,16,1,45,18,39,19,44,28,0,22,2,37,24,10,20,12,4] 36 | [2 0 0 1 1 1 0 3 3 2 0 0 1 0 0 3 1 1 1] 37 | 0.6882175190538096 38 | 39 | 40 | [34,3,41,40,25,38,32,29,6,23,35,21,31,9,26,36,8,17,42] 41 | [0 1 3 1 3 1 0 0 0 0 0 0 2 2 0 1 0 2 1] 42 | 0.664622281335371 43 | ''' 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | ''' 55 | num_clusters:4 56 | [27,5,38,34,16,14,17,18,13,33,21,4,32,15,11,44,0,28,24,12,1,29,45] 57 | [0 2 0 1 1 1 3 0 3 1 1 0 1 1 2 1 1 2 1 0 0 1 3] 58 | 0.7253862639015771 59 | ''' 60 | 61 | ''' 62 | num_clusters:4 63 | 
[19,22,30,35,41,20,36,10,31,25,42,37,39,40,6,3,23,7,43,9,26,2,8] 64 | [0 0 3 0 0 1 1 0 2 0 1 3 0 1 0 1 0 1 1 2 0 1 0] 65 | 0.5562913782929296 66 | ''' -------------------------------------------------------------------------------- /main_DBSCAN_mit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jinja2 import Environment, FileSystemLoader 3 | 4 | # ==================================================================================================================== 5 | # all 6 | # path = "./all_stratified_sampling/" 7 | 8 | user_list = [i for i in range(100, 103)] + [i for i in range(104, 110)] + [i for i in range(111, 120)] + \ 9 | [i for i in range(121, 125)] + [i for i in range(200, 204)] + [i for i in range(205, 206)] + \ 10 | [i for i in range(207, 211)] + [i for i in range(212, 216)] + [i for i in range(217, 218)] + \ 11 | [i for i in range(220, 224)] + [i for i in range(228, 229)] + [i for i in range(230, 235)] 12 | 13 | cluster_list = [0, 1, 2, -1, 3, 4, 5, 6, 5, -1, 7, 4, 3, -1, 0, 0, 0, -1, 2, -1, -1, -1, 5, 0, 0, -1, -1, -1, 4, 5, -1, 14 | -1, 5, 7, -1, 5, 6, -1, 2, -1, -1, -1, 3, 1, -1, -1] # GMM cluster 8 15 | # cluster_list = [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1] 16 | 17 | 18 | # ==================================================================================================================== 19 | print() 20 | #''' 21 | get_index = [x for x, y in enumerate(cluster_list) if y != -1] 22 | path = "./DBSCAN_8_clusters_mit_cluster_all" + "/" 23 | os.makedirs(path, exist_ok=False) 24 | 25 | 26 | env = Environment(loader=FileSystemLoader(searchpath="")) 27 | template = env.get_template("./template_mit.py") 28 | 29 | index_list = [user_list[i] for i in get_index] 30 | 31 | for i in range(0, len(index_list)): 32 | output = template.render({'user_name': index_list[i]}) 33 | with open(path + "client%d.py" % i, 'w') as out: 34 | out.write(output) 35 | out.close() 36 | 37 | 38 | env = Environment(loader=FileSystemLoader(searchpath="")) 39 | template = env.get_template("./run.sh") 40 | 41 | 42 | output = template.render({'num_users': len(index_list) - 1}) 43 | with open(path + "run.sh", 'w') as out: 44 | out.write(output) 45 | out.close() 46 | 47 | #''' 48 | ''' 49 | for i in [0, 1, 2, 3, 4, 5, 6, 7]: 50 | get_index = [x for x, y in enumerate(cluster_list) if y == i] 51 | path = "./DBSCAN_8_clusters_mit_cluster_" + str(i) + "/" 52 | os.makedirs(path, exist_ok=False) 53 | 54 | 55 | env = Environment(loader=FileSystemLoader(searchpath="")) 56 | template = env.get_template("./template_mit.py") 57 | 58 | index_list = [user_list[i] for i in get_index] 59 | 60 | 61 | for i in range(0, len(index_list)): 62 | output = template.render({'user_name': index_list[i]}) 63 | with open(path + "client%d.py" % i, 'w') as out: 64 | out.write(output) 65 | out.close() 66 | 67 | 68 | env = Environment(loader=FileSystemLoader(searchpath="")) 69 | template = env.get_template("./run.sh") 70 | 71 | 72 | output = template.render({'num_users': len(index_list) - 1}) 73 | with open(path + "run.sh", 'w') as out: 74 | out.write(output) 75 | out.close() 76 | ''' 77 | # ==================================================================================================================== 78 | -------------------------------------------------------------------------------- /main_GMM_mit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jinja2 import Environment, 
FileSystemLoader 3 | 4 | # ==================================================================================================================== 5 | # all 6 | # path = "./all_stratified_sampling/" 7 | 8 | user_list = [i for i in range(100, 103)] + [i for i in range(104, 110)] + [i for i in range(111, 120)] + \ 9 | [i for i in range(121, 125)] + [i for i in range(200, 204)] + [i for i in range(205, 206)] + \ 10 | [i for i in range(207, 211)] + [i for i in range(212, 216)] + [i for i in range(217, 218)] + \ 11 | [i for i in range(220, 224)] + [i for i in range(228, 229)] + [i for i in range(230, 235)] 12 | 13 | # cluster_list = [-1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 2, -1] # GMM cluster 2 14 | # cluster_list = [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1] 15 | 16 | # cluster_list = [-1, 4, 0, 4, 0, 4, 0, 0, 3, 0, 4, 1, 1, 1, 1, 0, 4, 0, 0, 4, 2, 3, 0, -1] # GMM cluster 4 17 | # cluster_list = [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, -1] 18 | 19 | cluster_list = [3, 7, 0, 5, 7, 2, 3, 0, 3, 4, 20 | 1, 2, 7, 4, 3, 3, 3, 5, 0, 1, 21 | 5, 6, 3, 3, 3, 1, 3, 0, 2, 3, 22 | 6, 4, 3, 1, 6, 3, 0, 6, 0, 3, 23 | 0, 1, 7, 7, 1, 5] # GMM cluster 8 24 | #cluster_list = [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, -1] 25 | 26 | # ==================================================================================================================== 27 | print() 28 | #''' 29 | get_index = [x for x, y in enumerate(cluster_list) if y != -1] 30 | path = "./mit_all" + "/" 31 | os.makedirs(path, exist_ok=False) 32 | 33 | 34 | env = Environment(loader=FileSystemLoader(searchpath="")) 35 | template = env.get_template("./template_mit.py") 36 | 37 | index_list = [user_list[i] for i in get_index] 38 | 39 | 40 | for i in range(0, len(index_list)): 41 | output = template.render({'user_name': index_list[i]}) 42 | with open(path + "client%d.py" % i, 'w') as out: 43 | out.write(output) 44 | out.close() 45 | 46 | 47 | env = Environment(loader=FileSystemLoader(searchpath="")) 48 | template = env.get_template("./run.sh") 49 | 50 | 51 | output = template.render({'num_users': len(index_list) - 1}) 52 | with open(path + "run.sh", 'w') as out: 53 | out.write(output) 54 | out.close() 55 | 56 | #''' 57 | ''' 58 | for i in [0, 1, 2, 3, 4, 5, 6, 7]: 59 | get_index = [x for x, y in enumerate(cluster_list) if y == i] 60 | path = "./GMM_8_clusters_mit_cluster_" + str(i) + "/" 61 | os.makedirs(path, exist_ok=False) 62 | 63 | 64 | env = Environment(loader=FileSystemLoader(searchpath="")) 65 | template = env.get_template("./template_mit.py") 66 | 67 | index_list = [user_list[i] for i in get_index] 68 | 69 | 70 | for i in range(0, len(index_list)): 71 | output = template.render({'user_name': index_list[i]}) 72 | with open(path + "client%d.py" % i, 'w') as out: 73 | out.write(output) 74 | out.close() 75 | 76 | 77 | env = Environment(loader=FileSystemLoader(searchpath="")) 78 | template = env.get_template("./run.sh") 79 | 80 | 81 | output = template.render({'num_users': len(index_list) - 1}) 82 | with open(path + "run.sh", 'w') as out: 83 | out.write(output) 84 | out.close() 85 | ''' 86 | # ==================================================================================================================== 87 | -------------------------------------------------------------------------------- /main_har.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jinja2 import Environment, 
FileSystemLoader 3 | 4 | # ==================================================================================================================== 5 | # all 6 | # path = "./all_stratified_sampling/" 7 | 8 | user_list = ['sub_' + str(i + 1) for i in range(24)] 9 | 10 | #cluster_list = [-1, 0, 1, 2, 1, 0, 3, 1, 2, 3, 0, 2, 2, 2, 2, 1, 0, 1, 1, 0, 0, 2, 1, -1] # manual selection 0: M 20-30, 1: F 20-30, 2: M 30-40, 3: F 30-40 11 | 12 | #cluster_list = [-1, 00, 00, 11, 00, 00, 11, 00, 11, 11, 00, 11, 11, 11, 11, 00, 00, 00, 00, 00, 00, 11, 00, -1] # manual selection 0: M 20-30, 1: F 20-30, 2: M 30-40, 3: F 30-40 13 | 14 | # DBSCAN # cluster_list = [-1, 0, 1, -1, 1, 0, -1, 1, 2, 3, 0, -1, 4, 4, 4, -1, 0, 1, 3, 0, -1, 2, -1, -1] 15 | cluster_list = [-1, 0, 1, -1, 1, 0, -1, 1, 2, 3, 0, -1, 4, 4, 4, -1, 0, 1, 3, 0, -1, 2, 3, -1] 16 | # ==================================================================================================================== 17 | print() 18 | 19 | # get_index = [x for x, y in enumerate(cluster_list) if y != -1] 20 | # path = "./manual_har_cluster_all" + "/" 21 | get_index = [x for x, y in enumerate(cluster_list) if y != 11] 22 | path = "./har_all" + "/" 23 | os.makedirs(path, exist_ok=False) 24 | 25 | env = Environment(loader=FileSystemLoader(searchpath="")) 26 | template = env.get_template("./template_har.py") 27 | 28 | index_list = [user_list[i] for i in get_index] 29 | 30 | 31 | for i in range(0, len(index_list)): 32 | output = template.render({'user_name': index_list[i]}) 33 | with open(path + "client%d.py" % i, 'w') as out: 34 | out.write(output) 35 | out.close() 36 | 37 | 38 | env = Environment(loader=FileSystemLoader(searchpath="")) 39 | template = env.get_template("./run.sh") 40 | 41 | 42 | output = template.render({'num_users': len(index_list) - 1}) 43 | with open(path + "run.sh", 'w') as out: 44 | out.write(output) 45 | out.close() 46 | 47 | ''' 48 | 49 | for i in range(4): 50 | get_index = [x for x, y in enumerate(cluster_list) if y == i] 51 | path = "./manual_har_cluster_" + str(i) + "/" 52 | os.makedirs(path, exist_ok=False) 53 | 54 | 55 | env = Environment(loader=FileSystemLoader(searchpath="")) 56 | template = env.get_template("./template_har.py") 57 | 58 | index_list = [user_list[i] for i in get_index] 59 | 60 | 61 | for i in range(0, len(index_list)): 62 | output = template.render({'user_name': index_list[i]}) 63 | with open(path + "client%d.py" % i, 'w') as out: 64 | out.write(output) 65 | out.close() 66 | 67 | 68 | env = Environment(loader=FileSystemLoader(searchpath="")) 69 | template = env.get_template("./run.sh") 70 | 71 | 72 | output = template.render({'num_users': len(index_list) - 1}) 73 | with open(path + "run.sh", 'w') as out: 74 | out.write(output) 75 | out.close() 76 | ''' 77 | # ==================================================================================================================== 78 | -------------------------------------------------------------------------------- /main_two_layer_har_DBSCAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.spatial.distance import cdist 3 | import torch 4 | from jinja2 import Environment, FileSystemLoader 5 | from csv import reader 6 | from sklearn.preprocessing import StandardScaler 7 | import numpy as np 8 | 9 | # ==================================================================================================================== 10 | # all 11 | # path = "./all_stratified_sampling/" 12 | 13 | user_list = ['sub_' + str(i + 1) for i in 
range(24)] 14 | 15 | # GMM 16 | cluster_a_index = [15, 8, 21, 16, 11, 4, 12, 0, 19, 7, 18, 10] 17 | cluster_a_categories = [0, 1, 1, 2, 3, 4, 3, -1, 2, 4, 0, 2] 18 | 19 | cluster_b_index = [1, 2, 3, 5, 6, 9, 13, 14, 17, 20, 22, 23] 20 | cluster_b_categories = [0, 1, 0, 0, 1, 1, 0, 0, 1, -1, 1, 1] 21 | 22 | ''' 23 | # =========== 24 | file_name = 'har.txt' 25 | with open(file_name, 'r') as raw_data: 26 | readers = reader(raw_data, delimiter=',') 27 | x = list(readers) 28 | data = np.array(x).astype('float') 29 | print(data.shape) 30 | data = data[:, 1:] 31 | data = StandardScaler().fit_transform(data) 32 | 33 | a_list = [] 34 | for i in [0, 1, 2, 3, 4]: 35 | get_a_index = [x for x, y in enumerate(cluster_a_categories) if y == i] 36 | cluster_a_index_list = [cluster_a_index[j] for j in get_a_index] 37 | data_list = [data[j] for j in cluster_a_index_list] 38 | data_centre_point = np.mean(np.array(data_list), axis=0) 39 | a_list.append(data_centre_point) 40 | a_list = np.array(a_list) 41 | 42 | b_list = [] 43 | for i in [0, 1]: 44 | get_b_index = [x for x, y in enumerate(cluster_b_categories) if y == i] 45 | cluster_b_index_list = [cluster_b_index[j] for j in get_b_index] 46 | data_list = [data[j] for j in cluster_b_index_list] 47 | data_centre_point = np.mean(np.array(data_list), axis=0) 48 | b_list.append(data_centre_point) 49 | b_list = np.array(b_list) 50 | 51 | results = cdist(a_list, b_list) 52 | print() 53 | # min_list = np.min(results, axis=0) 54 | print() 55 | # 0->2 1->4 56 | # 57 | 58 | 59 | # ==================================================================================================================== 60 | print() 61 | ''' 62 | 63 | ''' 64 | # get_index = [x for x, y in enumerate(cluster_list) if y != -1] 65 | # path = "./manual_har_cluster_all" + "/" 66 | get_index = [x for x, y in enumerate(cluster_list) if y != -1] 67 | path = "./DBSCAN_5_clusters_har_cluster_all" + "/" 68 | os.makedirs(path, exist_ok=False) 69 | 70 | 71 | env = Environment(loader=FileSystemLoader(searchpath="")) 72 | template = env.get_template("./template_har.py") 73 | 74 | index_list = [user_list[i] for i in get_index] 75 | 76 | 77 | for i in range(0, len(index_list)): 78 | output = template.render({'user_name': index_list[i]}) 79 | with open(path + "client%d.py" % i, 'w') as out: 80 | out.write(output) 81 | out.close() 82 | 83 | 84 | env = Environment(loader=FileSystemLoader(searchpath="")) 85 | template = env.get_template("./run.sh") 86 | 87 | 88 | output = template.render({'num_users': len(index_list) - 1}) 89 | with open(path + "run.sh", 'w') as out: 90 | out.write(output) 91 | out.close() 92 | 93 | ''' 94 | 95 | left_list = [0, 1] 96 | right_list = [2, 4] 97 | ''' 98 | for i in [0, 1]: 99 | get_index_list = [x for x, y in enumerate(cluster_b_categories) if y == i] 100 | get_index = [cluster_a_index[j] for j in get_index_list] 101 | path = "./two_layer_clusters_har_DBSCAN_" + str(i) + "/" 102 | os.makedirs(path, exist_ok=False) 103 | 104 | 105 | env = Environment(loader=FileSystemLoader(searchpath="")) 106 | template = env.get_template("./template_har.py") 107 | 108 | index_list = [user_list[i] for i in get_index] 109 | 110 | 111 | for i in range(0, len(index_list)): 112 | output = template.render({'user_name': index_list[i]}) 113 | with open(path + "client%d.py" % i, 'w') as out: 114 | out.write(output) 115 | out.close() 116 | 117 | 118 | env = Environment(loader=FileSystemLoader(searchpath="")) 119 | template = env.get_template("./run.sh") 120 | 121 | 122 | output = template.render({'num_users': 
len(index_list) - 1}) 123 | with open(path + "run.sh", 'w') as out: 124 | out.write(output) 125 | out.close() 126 | ''' 127 | 128 | for i in [2, 4]: 129 | get_index_list = [x for x, y in enumerate(cluster_a_categories) if y == i] 130 | get_index = [cluster_a_index[j] for j in get_index_list] 131 | path = "./two_layer_clusters_har_DBSCAN_" + str(i) + "/" 132 | os.makedirs(path, exist_ok=False) 133 | 134 | 135 | env = Environment(loader=FileSystemLoader(searchpath="")) 136 | template = env.get_template("./template_har.py") 137 | 138 | index_list = [user_list[i] for i in get_index] 139 | 140 | 141 | for i in range(0, len(index_list)): 142 | output = template.render({'user_name': index_list[i]}) 143 | with open(path + "client%d.py" % i, 'w') as out: 144 | out.write(output) 145 | out.close() 146 | 147 | 148 | env = Environment(loader=FileSystemLoader(searchpath="")) 149 | template = env.get_template("./run.sh") 150 | 151 | 152 | output = template.render({'num_users': len(index_list) - 1}) 153 | with open(path + "run.sh", 'w') as out: 154 | out.write(output) 155 | out.close() 156 | 157 | #''' 158 | # ==================================================================================================================== 159 | -------------------------------------------------------------------------------- /main_two_layer_mit_GMM.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.spatial.distance import cdist 3 | import torch 4 | from jinja2 import Environment, FileSystemLoader 5 | from csv import reader 6 | from sklearn.preprocessing import StandardScaler 7 | import numpy as np 8 | 9 | # ==================================================================================================================== 10 | # all 11 | # path = "./all_stratified_sampling/" 12 | 13 | user_list = [i for i in range(100, 103)] + [i for i in range(104, 110)] + [i for i in range(111, 120)] + \ 14 | [i for i in range(121, 125)] + [i for i in range(200, 204)] + [i for i in range(205, 206)] + \ 15 | [i for i in range(207, 211)] + [i for i in range(212, 216)] + [i for i in range(217, 218)] + \ 16 | [i for i in range(220, 224)] + [i for i in range(228, 229)] + [i for i in range(230, 235)] 17 | 18 | 19 | # Stratified sampling GMM MIT 20 | cluster_a_index = [27, 5, 38, 34, 16, 14, 17, 18, 13, 33, 21, 4, 32, 15, 11, 44, 0, 28, 24, 12, 1, 29, 45] 21 | cluster_a_categories = [0, 2, 0, 1, 1, 1, 3, 0, 3, 1, 1, 0, 1, 1, 2, 1, 1, 2, 1, 0, 0, 1, 3] 22 | 23 | cluster_b_index = [19, 22, 30, 35, 41, 20, 36, 10, 31, 25, 42, 37, 39, 40, 6, 3, 23, 7, 43, 9, 26, 2, 8] 24 | cluster_b_categories = [0, 0, 3, 0, 0, 1, 1, 0, 2, 0, 1, 3, 0, 1, 0, 1, 0, 1, 1, 2, 0, 1, 0] 25 | 26 | ''' 27 | # GMM 28 | cluster_a_index = [30, 17, 42, 33, 22, 9, 24, 0, 23, 45, 44, 29, 14, 38, 40, 43, 11, 5, 10, 6, 1, 18, 26] 29 | cluster_a_categories = [0, 2, 1, 4, 4, 2, 4, 4, 4, 2, 4, 4, 4, 1, 1, 1, 3, 3, 4, 4, 1, 1, 4] 30 | # [1, 2, 3, 4] 31 | 32 | cluster_b_index = [2, 3, 4, 7, 8, 12, 13, 15, 16, 19, 20, 21, 25, 27, 28, 31, 32, 34, 35, 36, 37, 39, 41] 33 | cluster_b_categories = [4, 1, 1, 4, 0, 1, 5, 0, 0, 3, 1, 6, 3, 4, 2, 5, 0, 6, 0, 4, 6, 0, 7] 34 | # [0, 1, 3, 4, 5, 6] 35 | ''' 36 | 37 | ''' 38 | # =========== 39 | # read MIT 40 | file_name = 'mit.txt' 41 | with open(file_name, 'r') as raw_data: 42 | readers = reader(raw_data, delimiter=',') 43 | x = list(readers) 44 | data = np.array(x).astype('float') 45 | print(data.shape) 46 | data = data[:, 1:] 47 | data = StandardScaler().fit_transform(data) 48 | 
49 | a_list = [] 50 | for i in [0, 1, 2, 3]: 51 | get_a_index = [x for x, y in enumerate(cluster_a_categories) if y == i] 52 | cluster_a_index_list = [cluster_a_index[j] for j in get_a_index] 53 | data_list = [data[j] for j in cluster_a_index_list] 54 | data_centre_point = np.mean(np.array(data_list), axis=0) 55 | a_list.append(data_centre_point) 56 | a_list = np.array(a_list) 57 | 58 | b_list = [] 59 | for i in [0, 1, 2, 3]: 60 | get_b_index = [x for x, y in enumerate(cluster_b_categories) if y == i] 61 | cluster_b_index_list = [cluster_b_index[j] for j in get_b_index] 62 | data_list = [data[j] for j in cluster_b_index_list] 63 | data_centre_point = np.mean(np.array(data_list), axis=0) 64 | b_list.append(data_centre_point) 65 | b_list = np.array(b_list) 66 | 67 | results = cdist(a_list, b_list) 68 | print() 69 | # min_list = np.min(results, axis=0) 70 | 71 | # 0->3 1->1 2->4 3->0 72 | # 1->4 2->1 3->5 4->0 73 | ''' 74 | 75 | # Stratified: 0->1 1->0 2->2 3->2 76 | 77 | 78 | # ==================================================================================================================== 79 | print() 80 | 81 | left_list = [0, 1, 2, 3] 82 | right_list = [1, 0, 2, 2] 83 | 84 | ''' 85 | for i in [0, 1, 2, 3]: 86 | get_index_list = [x for x, y in enumerate(cluster_b_categories) if y == i] 87 | get_index = [cluster_b_index[j] for j in get_index_list] 88 | path = "./two_layer_stratified_clusters_mit_GMM_" + str(i) + "/" 89 | os.makedirs(path, exist_ok=False) 90 | 91 | 92 | env = Environment(loader=FileSystemLoader(searchpath="")) 93 | template = env.get_template("./template_mit.py") 94 | 95 | index_list = [user_list[i] for i in get_index] 96 | 97 | 98 | for i in range(0, len(index_list)): 99 | output = template.render({'user_name': index_list[i]}) 100 | with open(path + "client%d.py" % i, 'w') as out: 101 | out.write(output) 102 | out.close() 103 | 104 | 105 | env = Environment(loader=FileSystemLoader(searchpath="")) 106 | template = env.get_template("./run.sh") 107 | 108 | 109 | output = template.render({'num_users': len(index_list) - 1}) 110 | with open(path + "run.sh", 'w') as out: 111 | out.write(output) 112 | out.close() 113 | ''' 114 | 115 | for i in [0, 1, 2, 3]: 116 | get_index_list = [x for x, y in enumerate(cluster_a_categories) if y == i] 117 | get_index = [cluster_a_index[j] for j in get_index_list] 118 | path = "./two_layer_stratified_clusters_mit_GMM_" + str(i) + "/" 119 | os.makedirs(path, exist_ok=False) 120 | 121 | 122 | env = Environment(loader=FileSystemLoader(searchpath="")) 123 | template = env.get_template("./template_mit.py") 124 | 125 | index_list = [user_list[i] for i in get_index] 126 | 127 | 128 | for i in range(0, len(index_list)): 129 | output = template.render({'user_name': index_list[i]}) 130 | with open(path + "client%d.py" % i, 'w') as out: 131 | out.write(output) 132 | out.close() 133 | 134 | 135 | env = Environment(loader=FileSystemLoader(searchpath="")) 136 | template = env.get_template("./run.sh") 137 | 138 | 139 | output = template.render({'num_users': len(index_list) - 1}) 140 | with open(path + "run.sh", 'w') as out: 141 | out.write(output) 142 | out.close() 143 | 144 | # ==================================================================================================================== 145 | -------------------------------------------------------------------------------- /mit_all/bash.exe.stackdump: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/helloworld1973/Federated_Learning_IoT_Applications/6502f81047240d74c6ae5c96d0ccafa6f0461dc8/mit_all/bash.exe.stackdump -------------------------------------------------------------------------------- /mit_all/client0.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 100 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | signal = signals.view(Batch_Size, 1, 128) 132 | optimizer.zero_grad() 133 | criterion(net(signal.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | signal = signals.view(1, 128) 144 | outputs = net(signal.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client1.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 | 
from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 101 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | signal = signals.view(Batch_Size, 1, 128) 132 | optimizer.zero_grad() 133 | criterion(net(signal.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | signal = signals.view(1, 128) 144 | outputs = net(signal.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client10.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 112 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | signal = signals.view(Batch_Size, 1, 128) 132 | optimizer.zero_grad() 133 | criterion(net(signal.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | signal = signals.view(1, 128) 144 | outputs = net(signal.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client11.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 113 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
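load_data() above segments the recording around each annotated beat: every training example runs from the midpoint of the previous RR interval to the midpoint of the next one, then is resampled to a fixed 128 samples so that beats of different durations share one input length. A self-contained sketch of that windowing, using synthetic stand-ins (the sine trace and the evenly spaced R-peak indices are made up for illustration; real values come from wfdb.rdsamp and wfdb.rdann as above):

import numpy as np
from scipy import signal

record_data = np.sin(np.linspace(0, 60 * np.pi, 10000))  # fake MLII trace
annotation_index = np.arange(200, 9800, 300)             # fake R-peak sample indices

beats = []
for L in range(len(annotation_index) - 2):
    # Window: midpoint of the previous RR interval .. midpoint of the next one.
    ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2)
    ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2)
    # Resample the variable-length beat to the fixed 128 samples the CNN expects.
    beats.append(signal.resample(record_data[ind1:ind2], num=128))

print(len(beats), beats[0].shape)  # every beat becomes a length-128 vector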
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client12.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 114 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
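The if/elif ladder in load_data() above implements the standard AAMI five-class grouping noted in the beats_types() comment ({'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4}), and it rebuilds the symbol lists on every iteration because beats_types() is called inside the loop. An equivalent refactor sketch that builds the lookup once and defaults unknown symbols to class 4, matching the else branch:

GROUPS = {
    0: ['N', 'L', 'R', 'e', 'j'],  # normal
    1: ['A', 'a', 'J', 'S'],       # supraventricular
    2: ['V', 'E'],                 # ventricular
    3: ['F'],                      # fusion
    4: ['/', 'f', 'Q'],            # paced / unknown
}
SYMBOL_TO_CLASS = {sym: cls for cls, syms in GROUPS.items() for sym in syms}

def map_symbols(annotation_symbols):
    # Anything not listed falls through to the 'unknown' class, as above.
    return [SYMBOL_TO_CLASS.get(s, 4) for s in annotation_symbols]

print(map_symbols(['N', 'V', '?', 'F']))  # [0, 2, 4, 3]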
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client13.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 115 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
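The Net model in the next section hard-codes 6 * 29 input features for its final layer. That constant is just 1-D shape arithmetic on the 128-sample beats: a kernel-5 convolution maps length L to L - 4 and each stride-2 max-pool halves it, so 128 -> 124 -> 62 -> 58 -> 29, with 6 channels after the second convolution. A quick check of that arithmetic:

import torch
import torch.nn as nn

conv1 = nn.Conv1d(1, 3, 5)  # 128 -> 124
pool = nn.MaxPool1d(2, 2)   # 124 -> 62
conv2 = nn.Conv1d(3, 6, 5)  # 62 -> 58, then pooled again: 58 -> 29

out = pool(conv2(pool(conv1(torch.randn(1, 1, 128)))))
print(out.shape)  # torch.Size([1, 6, 29]); flattened, 6 * 29 = 174 features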
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client14.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 116 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
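train() in the section below reshapes every batch to (Batch_Size, 1, 128), which only succeeds because the trainloader above is built with drop_last=True; a shorter final batch would make the view() call fail. A more defensive variant lets PyTorch infer the batch dimension, so drop_last stops being load-bearing (a sketch, not the code this repo runs):

def train_step(net, signals, labels, optimizer, criterion, device):
    # -1 infers the batch size, so short final batches work too.
    inputs = signals.view(-1, 1, 128).to(device)
    optimizer.zero_grad()
    loss = criterion(net(inputs), labels.to(device))
    loss.backward()
    optimizer.step()
    return loss.item()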
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client15.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 117 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
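test() in the section below evaluates one beat at a time (the testloader above uses the default batch_size of 1), so summing the per-batch losses and dividing by len(testloader.dataset) gives the mean per-beat loss. The same metrics can be computed with larger batches as long as the loss is accumulated per sample rather than per batch; a sketch under that assumption:

import torch

def evaluate(net, loader, device=torch.device("cpu")):
    # reduction="sum" accumulates per-sample losses, so batches of any size weigh equally.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for signals, labels in loader:
            outputs = net(signals.view(-1, 1, 128).to(device))
            labels = labels.to(device)
            loss_sum += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum / total, correct / total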
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client16.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 118 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
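Every client in mit_all connects to 127.0.0.1:8080, so a Flower server must be listening there before run.sh launches the clients; server.py in the repository root is the authoritative version. For orientation only, a minimal FedAvg server on the same address (num_rounds and the client-count thresholds here are illustrative, not the settings used in the papers):

import flwr as fl

strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,        # illustrative thresholds; see server.py for the
    min_available_clients=2,  # configuration the experiments actually used
)

fl.server.start_server(
    server_address="127.0.0.1:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)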
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client17.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 119 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
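get_parameters() and set_parameters() in the client class below move the model across the network as a plain list of NumPy arrays; the client rebuilds a state_dict by zipping those arrays against its own state_dict keys, so both sides must agree on parameter order, and load_state_dict(strict=True) fails loudly if they ever drift apart. A standalone round-trip check with a stand-in model (any nn.Module, including the Net above, behaves the same way):

from collections import OrderedDict
import torch
import torch.nn as nn

def get_parameters(net):
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters):
    params_dict = zip(net.state_dict().keys(), parameters)
    net.load_state_dict(
        OrderedDict({k: torch.tensor(v) for k, v in params_dict}), strict=True
    )

net_a = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
net_b = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

set_parameters(net_b, get_parameters(net_a))  # export from a, import into b
assert all(torch.equal(a, b) for a, b in
           zip(net_a.state_dict().values(), net_b.state_dict().values()))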
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client18.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 121 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
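The files client0.py through client45.py differ only in the num record identifier (and therefore in which MIT-BIH record they load); that is exactly the duplication template_mit.py at the repository root exists to manage. The per-record scripts could equally be collapsed into one script that takes the record number on the command line; a hypothetical entry point (the rest of the body would be unchanged):

import sys

# Hypothetical single-file variant: `python client.py 113` would replace client11.py.
num = int(sys.argv[1]) if len(sys.argv) > 1 else 100
Batch_Size = 10
Test_Size = 0.4
# ...the remainder is identical to the scripts above: load_data(num),
# CustomDataset, Net, train/test, PatientClient, start_numpy_client.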
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client19.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 122 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. 
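fit() and evaluate() in the client class below return num_examples alongside the weights and metrics so the server can weight each client by its dataset size: records with more annotated beats pull the global model proportionally harder. The aggregation FedAvg performs is a per-layer weighted mean; in miniature (the arrays and counts are made up):

import numpy as np

client_updates = [
    (np.array([0.0, 2.0]), 300),  # a client that trained on 300 beats
    (np.array([3.0, 5.0]), 100),  # and one that trained on 100
]

total = sum(n for _, n in client_updates)
aggregated = sum(w * (n / total) for w, n in client_updates)
print(aggregated)  # [0.75 2.75]; the 300-beat client dominates 3:1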
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client2.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 |
from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | num = 102 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | for i, a_label in enumerate(annotation_symbols): 37 | Beat_types = beats_types() 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label has some not includes') 50 | annotation_symbols[i] = 4 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2): 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128) 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2. Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L); relies on drop_last=True above 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # keep the (N, C, L) layout train() uses 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double()  # wfdb yields float64 signals, so the model runs in double precision 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client20.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cpu")


class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(1, 3, 5)
        self.pool = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(3, 6, 5)
        # Earlier variants added one or two fully connected layers here
        # (e.g. 16 * 29 -> 120 -> 84 -> 5); a single linear classifier is kept.
        self.fc3 = nn.Linear(6 * 29, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 6 * 29)  # flatten to (batch, channels * length)
        return self.fc3(x)


def train(net, trainloader, epochs):
    """Train the model on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    for _ in range(epochs):
        for signals, labels in tqdm(trainloader):
            # (Batch_Size, 128) -> (Batch_Size, 1, 128); needs drop_last=True
            inputs = signals.view(Batch_Size, 1, 128)
            optimizer.zero_grad()
            criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()


def test(net, testloader):
    """Validate the model on the test set (batch size 1)."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for signals, labels in tqdm(testloader):
            inputs = signals.view(1, 1, 128)  # explicit (batch, channel, length)
            outputs = net(inputs.to(DEVICE))
            labels = labels.to(DEVICE)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    # mean per-beat loss and overall accuracy
    return loss / len(testloader.dataset), correct / total


# scipy.signal.resample returns float64, so the model is cast to double
net = Net().to(DEVICE).double()


# Define Flower client
class PatientClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), num_examples["trainset"], {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}


# Start client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient())


--------------------------------------------------------------------------------
/mit_all/client20.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client; every remaining clientN.py in mit_all/ follows the same pattern:

num = 123
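Since the per-client files differ in a single constant, one parameterized
script could replace all of them. A minimal sketch (hypothetical; the
repository's template_mit.py may serve a similar purpose, but its contents are
not shown here):

import argparse

parser = argparse.ArgumentParser(description="Generic MIT-BIH Flower client")
parser.add_argument("--record", type=int, default=102,
                    help="MIT-BIH record number assigned to this client")
args = parser.parse_args()

num = args.record  # the rest of the pipeline is unchanged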
--------------------------------------------------------------------------------
/mit_all/client21.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 124
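beats_types() encodes the symbol-to-class mapping as five lists scanned by an
if/elif chain. The same mapping as a single dictionary lookup, assuming the
identical AAMI grouping (a sketch, not code from the repository):

SYMBOL_TO_CLASS = {
    **{s: 0 for s in ['N', 'L', 'R', 'e', 'j']},  # normal
    **{s: 1 for s in ['A', 'a', 'J', 'S']},       # supraventricular
    **{s: 2 for s in ['V', 'E']},                 # ventricular
    **{s: 3 for s in ['F']},                      # fusion
    **{s: 4 for s in ['/', 'f', 'Q']},            # paced / unknown
}

def to_class(symbol: str) -> int:
    # unknown symbols fall into class 4, matching the else branch in load_data()
    return SYMBOL_TO_CLASS.get(symbol, 4)

print(to_class('L'), to_class('V'), to_class('?'))  # 0 2 4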
--------------------------------------------------------------------------------
/mit_all/client22.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 200
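load_data() turns a record into one example per beat: each window runs from the
midpoint of the preceding RR interval to the midpoint of the following one and
is resampled to 128 samples, so the first and last annotated beats are dropped.
The same windowing on synthetic data (all names and values are illustrative):

import numpy as np
from scipy import signal

def segment_beats(ecg: np.ndarray, r_peaks: np.ndarray, length: int = 128):
    """Yield one fixed-length window per beat, mid-RR to mid-RR."""
    for left, center, right in zip(r_peaks, r_peaks[1:], r_peaks[2:]):
        start = (left + center) // 2
        stop = (center + right) // 2
        yield signal.resample(ecg[start:stop], length)

ecg = np.sin(np.linspace(0, 60, 6000))   # stand-in waveform
r_peaks = np.arange(100, 6000, 350)      # stand-in beat locations
windows = list(segment_beats(ecg, r_peaks))
print(len(windows), windows[0].shape)    # 15 (128,)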
--------------------------------------------------------------------------------
/mit_all/client23.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 201
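The model is later cast with .double() because scipy's resample returns float64
arrays and the default collate function preserves that dtype. A quick check,
with a float32 alternative shown for contrast (an editor's sketch, not the
repository's choice):

import numpy as np
import torch

beat = np.zeros(128)                 # same dtype as scipy.signal.resample output
print(beat.dtype)                    # float64
print(torch.as_tensor(beat).dtype)   # torch.float64

# Alternative: cast the data once and keep the model in float32,
# which is the more common PyTorch default.
beat32 = torch.as_tensor(beat, dtype=torch.float32)
print(beat32.dtype)                  # torch.float32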
--------------------------------------------------------------------------------
/mit_all/client24.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 202
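train() reshapes each batch with a hard-coded Batch_Size, which is why the
train loader sets drop_last=True; a partial final batch would make view() fail.
A shape-agnostic variant that inserts the channel axis instead (sketch;
equivalent for full batches):

import torch

def make_conv_input(signals: torch.Tensor) -> torch.Tensor:
    # (batch, 128) -> (batch, 1, 128): add the channel axis instead of
    # assuming a fixed batch size, so partial batches also work.
    return signals.unsqueeze(1)

print(make_conv_input(torch.zeros(7, 128)).shape)  # torch.Size([7, 1, 128])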
--------------------------------------------------------------------------------
/mit_all/client25.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 203
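test() iterates with the DataLoader's default batch size of 1, so
loss / len(testloader.dataset) is the mean per-beat loss and correct / total
the plain accuracy. A batched variant gives the same metrics faster (a sketch,
reusing the unsqueeze idea above):

import torch

def test_batched(net, testloader, device=torch.device("cpu")):
    """Same metrics as test(), but tolerant of any batch size (sketch)."""
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")  # sum, divide once
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for signals, labels in testloader:      # signals: (batch, 128)
            outputs = net(signals.unsqueeze(1).to(device))
            labels = labels.to(device)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss / total, correct / total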
--------------------------------------------------------------------------------
/mit_all/client26.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 205
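get_parameters and set_parameters convert between a PyTorch state_dict and the
plain list of NumPy arrays that Flower ships between client and server; the zip
relies on state_dict() returning keys in a stable order, which PyTorch
guarantees. A standalone round-trip check on a toy module (sketch):

from collections import OrderedDict

import torch
import torch.nn as nn

toy = nn.Linear(4, 2)

# state_dict -> list of ndarrays (what the client sends to the server)
arrays = [v.cpu().numpy() for _, v in toy.state_dict().items()]

# list of ndarrays -> state_dict (what the client loads from the server)
restored = OrderedDict(
    (k, torch.tensor(v)) for k, v in zip(toy.state_dict().keys(), arrays)
)
toy.load_state_dict(restored, strict=True)

print([a.shape for a in arrays])  # [(2, 4), (2,)]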
--------------------------------------------------------------------------------
/mit_all/client27.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 207
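fit() returns (parameters, num_examples, metrics) and evaluate() returns
(loss, num_examples, metrics); a FedAvg-style server weights each client's
update by its example count. The standard weighted average, written out as a
sketch (not code from the repository):

def fedavg(client_updates):
    """client_updates: list of (list_of_ndarrays, num_examples) pairs."""
    total = sum(n for _, n in client_updates)
    num_layers = len(client_updates[0][0])
    return [
        sum(params[i] * (n / total) for params, n in client_updates)
        for i in range(num_layers)
    ]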
--------------------------------------------------------------------------------
/mit_all/client28.py:
--------------------------------------------------------------------------------

Identical to client2.py above except for the MIT-BIH record assigned to this
client:

num = 208
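Every client dials 127.0.0.1:8080, so a Flower server must already be listening
there; the repository's server.py presumably fills that role, but it is not
shown in this section. A minimal stand-in for recent Flower releases (the exact
API varies by version, so treat this as an assumption, not the project's actual
server):

import flwr as fl

# Hypothetical minimal server, flwr >= 1.0 style.
fl.server.start_server(
    server_address="127.0.0.1:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=fl.server.strategy.FedAvg(),
)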
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | signal = signals.view(Batch_Size, 1, 128) 132 | optimizer.zero_grad() 133 | criterion(net(signal.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | signal = signals.view(1, 128) 144 | outputs = net(signal.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client29.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 
| from torch.utils.data import DataLoader, Dataset 11 | from tqdm import tqdm 12 | 13 | 14 | num = 209  # MIT-BIH record assigned to this client 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | Beat_types = beats_types()  # the symbol groups are fixed, so compute them once 37 | for i, a_label in enumerate(annotation_symbols): 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label not in the predefined beat classes') 50 | annotation_symbols[i] = 4  # map any unlisted symbol to the unknown class 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2):  # one segment per beat, bounded by the midpoints of the neighbouring RR intervals 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128)  # fixed length of 128 samples per beat 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2.
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L) = (Batch_Size, 1, 128); drop_last=True guarantees full batches 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # explicit (N=1, C=1, L=128) so Conv1d always receives a batched 3-D input 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client3.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10 |
from torch.utils.data import DataLoader, Dataset 11 | from tqdm import tqdm 12 | 13 | 14 | num = 104  # MIT-BIH record assigned to this client 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | Beat_types = beats_types()  # the symbol groups are fixed, so compute them once 37 | for i, a_label in enumerate(annotation_symbols): 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label not in the predefined beat classes') 50 | annotation_symbols[i] = 4  # map any unlisted symbol to the unknown class 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2):  # one segment per beat, bounded by the midpoints of the neighbouring RR intervals 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128)  # fixed length of 128 samples per beat 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2.
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L) = (Batch_Size, 1, 128); drop_last=True guarantees full batches 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # explicit (N=1, C=1, L=128) so Conv1d always receives a batched 3-D input 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client30.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader, Dataset 11 | from tqdm import tqdm 12 | 13 | 14 | num = 210  # MIT-BIH record assigned to this client 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | Beat_types = beats_types()  # the symbol groups are fixed, so compute them once 37 | for i, a_label in enumerate(annotation_symbols): 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label not in the predefined beat classes') 50 | annotation_symbols[i] = 4  # map any unlisted symbol to the unknown class 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2):  # one segment per beat, bounded by the midpoints of the neighbouring RR intervals 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128)  # fixed length of 128 samples per beat 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2.
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L) = (Batch_Size, 1, 128); drop_last=True guarantees full batches 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # explicit (N=1, C=1, L=128) so Conv1d always receives a batched 3-D input 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/client31.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import OrderedDict 3 | import flwr as fl 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import wfdb 8 | from scipy import signal 9 | from sklearn.model_selection import train_test_split 10
| from torch.utils.data import DataLoader, Dataset 11 | from tqdm import tqdm 12 | 13 | 14 | num = 212  # MIT-BIH record assigned to this client 15 | Batch_Size = 10 16 | Test_Size = 0.4 17 | 18 | def beats_types(): 19 | # classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 20 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 21 | sup_beats = ['A', 'a', 'J', 'S'] 22 | ven_beats = ['V', 'E'] 23 | fusion_beats = ['F'] 24 | unknown_beat = ['/', 'f', 'Q'] 25 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 26 | 27 | 28 | def load_data(num): 29 | record = wfdb.rdsamp('../mitdb/' + str(num)) 30 | annotation = wfdb.rdann('../mitdb/' + str(num), 'atr') 31 | 32 | record_data = record[0][:, 0] # MLII 33 | annotation_index = annotation.sample 34 | annotation_symbols = annotation.symbol 35 | 36 | Beat_types = beats_types()  # the symbol groups are fixed, so compute them once 37 | for i, a_label in enumerate(annotation_symbols): 38 | if a_label in Beat_types[0]: 39 | annotation_symbols[i] = 0 40 | elif a_label in Beat_types[1]: 41 | annotation_symbols[i] = 1 42 | elif a_label in Beat_types[2]: 43 | annotation_symbols[i] = 2 44 | elif a_label in Beat_types[3]: 45 | annotation_symbols[i] = 3 46 | elif a_label in Beat_types[4]: 47 | annotation_symbols[i] = 4 48 | else: 49 | #print('label not in the predefined beat classes') 50 | annotation_symbols[i] = 4  # map any unlisted symbol to the unknown class 51 | 52 | X = [] 53 | y = [] 54 | Length_RRI = len(annotation_index) 55 | for L in range(Length_RRI - 2):  # one segment per beat, bounded by the midpoints of the neighbouring RR intervals 56 | Ind1 = int((annotation_index[L] + annotation_index[L + 1]) / 2) 57 | Ind2 = int((annotation_index[L + 1] + annotation_index[L + 2]) / 2) 58 | 59 | Symb = annotation_symbols[L + 1] 60 | y.append(Symb) 61 | Sign = record_data[Ind1:Ind2] 62 | Resamp = signal.resample(x=Sign, num=128)  # fixed length of 128 samples per beat 63 | X.append(Resamp) 64 | 65 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=Test_Size) 66 | 67 | return X_train, X_test, y_train, y_test 68 | 69 | 70 | class CustomDataset(Dataset): 71 | def __init__(self, data_features, data_label): 72 | self.data_features = data_features 73 | self.data_label = data_label 74 | 75 | def __len__(self): 76 | return len(self.data_features) 77 | 78 | def __getitem__(self, index): 79 | data = self.data_features[index] 80 | labels = self.data_label[index] 81 | return data, labels 82 | 83 | 84 | X_train, X_test, y_train, y_test = load_data(num=num) 85 | 86 | trainloader = DataLoader(CustomDataset(X_train, y_train), batch_size=Batch_Size, shuffle=True, drop_last=True) 87 | testloader = DataLoader(CustomDataset(X_test, y_test)) 88 | num_examples = {"trainset" : len(X_train), "testset" : len(X_test)} 89 | # ############################################################################# 90 | # 2.
Federation of the pipeline with Flower 91 | # ############################################################################# 92 | 93 | warnings.filterwarnings("ignore", category=UserWarning) 94 | DEVICE = torch.device("cpu") 95 | 96 | 97 | class Net(nn.Module): 98 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 99 | 100 | def __init__(self) -> None: 101 | super(Net, self).__init__() 102 | self.conv1 = nn.Conv1d(1, 3, 5) 103 | self.pool = nn.MaxPool1d(2, 2) 104 | self.conv2 = nn.Conv1d(3, 6, 5) 105 | 106 | #self.fc1 = nn.Linear(6 * 29, 20) 107 | #self.fc2 = nn.Linear(120, 84) 108 | self.fc3 = nn.Linear(6 * 29, 5) 109 | 110 | ''' 111 | self.fc1 = nn.Linear(16 * 29, 120) 112 | self.fc2 = nn.Linear(120, 84) 113 | self.fc3 = nn.Linear(84, 5) 114 | ''' 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | x = self.pool(F.relu(self.conv1(x))) 118 | x = self.pool(F.relu(self.conv2(x))) 119 | x = x.view(-1, 6 * 29) 120 | #x = F.relu(self.fc1(x)) 121 | #x = F.relu(self.fc2(x)) 122 | return self.fc3(x) 123 | 124 | 125 | def train(net, trainloader, epochs): 126 | """Train the model on the training set.""" 127 | criterion = torch.nn.CrossEntropyLoss() 128 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1) 129 | for _ in range(epochs): 130 | for signals, labels in tqdm(trainloader): 131 | inputs = signals.view(Batch_Size, 1, 128)  # (N, C, L) = (Batch_Size, 1, 128); drop_last=True guarantees full batches 132 | optimizer.zero_grad() 133 | criterion(net(inputs.to(DEVICE)), labels.to(DEVICE)).backward() 134 | optimizer.step() 135 | 136 | 137 | def test(net, testloader): 138 | """Validate the model on the test set.""" 139 | criterion = torch.nn.CrossEntropyLoss() 140 | correct, total, loss = 0, 0, 0.0 141 | with torch.no_grad(): 142 | for signals, labels in tqdm(testloader): 143 | inputs = signals.view(1, 1, 128)  # explicit (N=1, C=1, L=128) so Conv1d always receives a batched 3-D input 144 | outputs = net(inputs.to(DEVICE)) 145 | labels = labels.to(DEVICE) 146 | loss += criterion(outputs, labels).item() 147 | total += labels.size(0) 148 | correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() 149 | return loss / len(testloader.dataset), correct / total 150 | 151 | 152 | net = Net().to(DEVICE).double() 153 | 154 | 155 | # Define Flower client 156 | class PatientClient(fl.client.NumPyClient): 157 | def get_parameters(self, config): 158 | return [val.cpu().numpy() for _, val in net.state_dict().items()] 159 | 160 | def set_parameters(self, parameters): 161 | params_dict = zip(net.state_dict().keys(), parameters) 162 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 163 | net.load_state_dict(state_dict, strict=True) 164 | 165 | def fit(self, parameters, config): 166 | self.set_parameters(parameters) 167 | train(net, trainloader, epochs=1) 168 | return self.get_parameters(config={}), num_examples["trainset"], {} 169 | 170 | def evaluate(self, parameters, config): 171 | self.set_parameters(parameters) 172 | loss, accuracy = test(net, testloader) 173 | return float(loss), num_examples["testset"], {"accuracy": float(accuracy)} 174 | 175 | 176 | # Start client 177 | fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=PatientClient()) -------------------------------------------------------------------------------- /mit_all/mintty.exe.stackdump: -------------------------------------------------------------------------------- 1 | Exception: STATUS_ACCESS_VIOLATION at rip=00100420F0E 2 | rax=0000000000000001 rbx=00000000FFFFC240 rcx=0000000000000000 3 | rdx=00000000FFFFCE00 rsi=0000000000000000 rdi=0000000000000050 4 | r8 =000000000000005A r9 =000000000000005E
r10=0000000100000000 5 | r11=0000000100420F58 r12=0000000000000050 r13=00000000FFFFC240 6 | r14=0000000800065780 r15=0000000000000001 7 | rbp=00000000FFFFC300 rsp=00000000FFFFC1E0 8 | program=C:\Program Files\Git\usr\bin\mintty.exe, pid 715, thread main 9 | cs=0033 ds=002B es=002B fs=0053 gs=002B ss=002B 10 | Stack trace: 11 | Frame Function Args 12 | 000FFFFC300 00100420F0E (00000000400, 00000000001, 00100421570, 00000000000) 13 | 000FFFFC300 00100421C16 (000FFFFC298, 6E6170206D6F7266) 14 | 000FFFFC300 00100419FCC (00100403233, 00800065620, 001004DFC80, 001004E44E0) 15 | 001004DE1E0 00100427B8A (00000000BEE, 001004DE1E0, 001004E44E0, 001004E44E0) 16 | 001004DE1E0 0010042A190 (0018015E40A, 7FFB45871064, 000FFFFC468, 001004E44E0) 17 | 001004DE1E0 0010042C8FF (001004DF0E0, 00000000100, 7FFB4687E299, 00000000000) 18 | 00000001000 00100404ACB (000FFFFC550, 7FFB46879980, 7FFB468CA6B0, 00800000001) 19 | 000FFFFC550 0010045E7AC (001801BC21A, 00800062B90, 00800062F10, 00800063078) 20 | 000FFFFCCE0 0018004B0FB (00000000000, 00000000000, 00000000000, 00000000000) 21 | 000FFFFCDA0 00180048A2A (00000000000, 00000000000, 00000000000, 00000000000) 22 | 000FFFFCE50 00180048AEC (00000000000, 00000000000, 00000000000, 00000000000) 23 | End of stack trace 24 | -------------------------------------------------------------------------------- /mit_all/run.sh: -------------------------------------------------------------------------------- 1 | for i in `seq 0 45`; do 2 | echo "Starting client $i" 3 | python client$i.py & 4 | done 5 | 6 | # This will allow you to use CTRL+C to stop all background processes 7 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM 8 | # Wait for all background processes to complete 9 | wait -------------------------------------------------------------------------------- /read_HAR.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from gtda.time_series import SlidingWindow 3 | 4 | Segment_Size = 128 5 | activity_codes = {'dws': 0, 'jog': 1, 'sit': 2, 'std': 3, 'ups': 4, 'wlk': 5} 6 | activity_types = list(activity_codes.keys()) 7 | 8 | def load_data(num): 9 | # Load All data: 10 | Folders = glob('./A_DeviceMotion_data/*_*') 11 | Folders = [s for s in Folders if "csv" not in s] 12 | 13 | X = [] 14 | y = [] 15 | for j in Folders: 16 | Csv = glob(j + '/'+num+'.csv')[0] 17 | label = int(activity_codes[j[22:25]])  # chars 22:25 of './A_DeviceMotion_data/<act>_<n>' are the 3-letter activity code 18 | x_a_activity_a_user_complete_data = [] 19 | y_a_activity_a_user_complete_data = [] 20 | with open(Csv, 'r') as f: 21 | lines = f.readlines() 22 | for num_index, line in enumerate(lines): 23 | if num_index != 0: 24 | a_row_column = line.replace('\n', '').split(',') 25 | new_a_row_column = [] 26 | for index, i in enumerate(a_row_column): 27 | if index != 0: 28 | new_a_row_column.append(float(i)) 29 | 30 | x_a_activity_a_user_complete_data.append(new_a_row_column) 31 | y_a_activity_a_user_complete_data.append(label) 32 | 33 | sliding_bag = SlidingWindow(size=Segment_Size, stride=int(Segment_Size/2)) 34 | X_bags = sliding_bag.fit_transform(x_a_activity_a_user_complete_data) 35 | y_bags = [label for i in range(len(X_bags))] 36 | 37 | for a_X_bag in X_bags: 38 | X.append(a_X_bag) 39 | for a_y_bag in y_bags: 40 | y.append(a_y_bag) 41 | 42 | return X, y 43 | 44 | classes = {'dws': 0, 'jog': 1, 'sit': 2, 'std': 3, 'ups': 4, 'wlk': 5} 45 | index_list = ['sub_' + str(i + 1) for i in range(24)] 46 | 47 | class_0_list = [] 48 | class_1_list = [] 49 | class_2_list = [] 50 | class_3_list = [] 51 |
class_4_list = [] 52 | class_5_list = [] 53 | for i in index_list: 54 | X, y = load_data(i) 55 | class_0 = y.count(0) 56 | class_0_list.append(class_0) 57 | class_1 = y.count(1) 58 | class_1_list.append(class_1) 59 | class_2 = y.count(2) 60 | class_2_list.append(class_2) 61 | class_3 = y.count(3) 62 | class_3_list.append(class_3) 63 | class_4 = y.count(4) 64 | class_4_list.append(class_4) 65 | class_5 = y.count(5) 66 | class_5_list.append(class_5) 67 | 68 | 69 | import matplotlib.pyplot as plt 70 | import numpy as np 71 | 72 | data = [class_0_list, class_1_list, class_2_list, class_3_list, class_4_list, class_5_list] 73 | 74 | fig = plt.figure() 75 | 76 | # Creating axes instance 77 | #ax = fig.add_axes([0, 0, 1, 1]) 78 | # x-axis labels 79 | #ax.set_yticklabels(['class_0', 'class_1', 'class_2', 'class_3', 'class_4']) 80 | # Creating plot 81 | #bp = ax.boxplot(data) 82 | plt.boxplot(data) 83 | plt.xlabel("classes",fontsize=13) 84 | plt.ylabel("number of samples",fontsize=13) 85 | # show plot 86 | plt.xticks(fontsize=13) 87 | plt.yticks(fontsize=13) 88 | # a boxplot has no labelled artists, so no legend is drawn 89 | plt.savefig('HAR_classes_distribution.png') 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /read_MIT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wfdb 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | if os.path.isdir("mitdb"): 7 | print('You already have the data.') 8 | else: 9 | wfdb.dl_database('mitdb', 'mitdb') 10 | 11 | 12 | def beats_types(): 13 | normal_beats = ['N', 'L', 'R', 'e', 'j'] 14 | sup_beats = ['A', 'a', 'J', 'S'] 15 | ven_beats = ['V', 'E'] 16 | fusion_beats = ['F'] 17 | unknown_beat = ['/', 'f', 'Q'] 18 | return normal_beats, sup_beats, ven_beats, fusion_beats, unknown_beat 19 | 20 | 21 | classes = {'N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4} 22 | # classes = {'Normal beat', 'Supraventricular Ectopic Beat', 'Ventricular Ectopic beat', 'Fusion Beat', 'Unknown Beat'} 23 | 24 | age_list = [] 25 | gender_list = [] 26 | index_list = [i for i in range(100, 103)] + [i for i in range(104, 110)] + [i for i in range(111, 120)] \ 27 | + [i for i in range(121, 125)] + \ 28 | [i for i in range(200, 204)] + [i for i in range(205, 206)] + [i for i in range(207, 211)] + \ 29 | [i for i in range(212, 216)] + [i for i in range(217, 218)] + [i for i in range(220, 224)] + \ 30 | [i for i in range(228, 229)] + [i for i in range(230, 235)] 31 | 32 | 33 | 34 | class_0_list = [] 35 | class_1_list = [] 36 | class_2_list = [] 37 | class_3_list = [] 38 | class_4_list = [] 39 | for i in index_list: 40 | record = wfdb.rdsamp('mitdb/' + str(i)) 41 | record_comments = record[1]['comments'][0].split(' ') 42 | age_list.append(int(record_comments[0])) 43 | gender_list.append(record_comments[1]) 44 | annotation_symbols = wfdb.rdann('mitdb/' + str(i), 'atr').symbol 45 | 46 | Beat_types = beats_types()  # symbol groups computed once per record 47 | for k, a_label in enumerate(annotation_symbols):  # 'k' so the outer record index 'i' is not clobbered 48 | if a_label in Beat_types[0]: 49 | annotation_symbols[k] = 0 50 | elif a_label in Beat_types[1]: 51 | annotation_symbols[k] = 1 52 | elif a_label in Beat_types[2]: 53 | annotation_symbols[k] = 2 54 | elif a_label in Beat_types[3]: 55 | annotation_symbols[k] = 3 56 | elif a_label in Beat_types[4]: 57 | annotation_symbols[k] = 4 58 | else: 59 | #print('label not in the predefined beat classes') 60 | annotation_symbols[k] = 4 61 | 62 | class_0 = annotation_symbols.count(0) 63 | class_0_list.append(class_0) 64 | class_1 =
annotation_symbols.count(1) 65 | class_1_list.append(class_1) 66 | class_2 = annotation_symbols.count(2) 67 | class_2_list.append(class_2) 68 | class_3 = annotation_symbols.count(3) 69 | class_3_list.append(class_3) 70 | class_4 = annotation_symbols.count(4) 71 | class_4_list.append(class_4) 72 | 73 | 74 | 75 | 76 | data = [class_0_list, class_1_list, class_2_list, class_3_list, class_4_list] 77 | 78 | fig = plt.figure() 79 | 80 | # Creating axes instance 81 | #ax = fig.add_axes([0, 0, 1, 1]) 82 | # x-axis labels 83 | #ax.set_yticklabels(['class_0', 'class_1', 'class_2', 'class_3', 'class_4']) 84 | # Creating plot 85 | #bp = ax.boxplot(data) 86 | plt.boxplot(data) 87 | plt.xlabel("classes",fontsize=13) 88 | plt.ylabel("number of samples",fontsize=13) 89 | # show plot 90 | plt.xticks(fontsize=13) 91 | plt.yticks(fontsize=13) 92 | # a boxplot has no labelled artists, so no legend is drawn 93 | plt.savefig('MIT_classes_distribution.png') 94 | 95 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import List, Tuple, Optional, Dict, Union 3 | import torch.nn.functional as F 4 | import flwr as fl 5 | import numpy as np 6 | import torch 7 | from flwr.common import Metrics, FitRes, Parameters, Scalar, EvaluateRes 8 | from flwr.server.client_proxy import ClientProxy 9 | from torch import nn 10 | import joblib 11 | 12 | DEVICE = torch.device("cpu") 13 | 14 | 15 | class Net(nn.Module): 16 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 17 | 18 | def __init__(self) -> None: 19 | super(Net, self).__init__() 20 | self.conv1 = nn.Conv1d(12, 3, 15)  # 12 input channels: the HAR motion-feature streams 21 | self.pool = nn.MaxPool1d(2, 2) 22 | self.conv2 = nn.Conv1d(3, 6, 15) 23 | 24 | # self.fc1 = nn.Linear(6 * 29, 20) 25 | # self.fc2 = nn.Linear(120, 84) 26 | self.fc3 = nn.Linear(6 * 21, 6)  # 6 activity classes 27 | 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | x = self.pool(F.relu(self.conv1(x))) 31 | x = self.pool(F.relu(self.conv2(x))) 32 | x = x.view(-1, 6 * 21) 33 | # x = F.relu(self.fc1(x)) 34 | # x = F.relu(self.fc2(x)) 35 | return self.fc3(x) 36 | 37 | 38 | ''' 39 | class Net(nn.Module): 40 | """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" 41 | 42 | def __init__(self) -> None: 43 | super(Net, self).__init__() 44 | self.conv1 = nn.Conv1d(1, 3, 5) 45 | self.pool = nn.MaxPool1d(2, 2) 46 | self.conv2 = nn.Conv1d(3, 6, 5) 47 | 48 | #self.fc1 = nn.Linear(6 * 29, 20) 49 | #self.fc2 = nn.Linear(120, 84) 50 | self.fc3 = nn.Linear(6 * 29, 5) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.pool(F.relu(self.conv1(x))) 54 | x = self.pool(F.relu(self.conv2(x))) 55 | x = x.view(-1, 6 * 29) 56 | #x = F.relu(self.fc1(x)) 57 | #x = F.relu(self.fc2(x)) 58 | return self.fc3(x) 59 | ''' 60 | 61 | net = Net().to(DEVICE).double() 62 | 63 | 64 | class SaveModelStrategy(fl.server.strategy.FedAvg): 65 | 66 | def aggregate_fit( 67 | self, 68 | server_round: int, 69 | results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]], 70 | failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], 71 | ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: 72 | 73 | # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics 74 | aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures) 75 | 76 | if aggregated_parameters is not None:
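# FedAvg has already combined the client updates at this point, weighting each
# client k by its sample count n_k: w = sum_k (n_k / n) * w_k.
# Persist the aggregated weights every round so training can be resumed or inspected offline.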
77 | # Convert `Parameters` to `List[np.ndarray]` 78 | aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters) 79 | 80 | joblib.dump(aggregated_ndarrays, f"model_round_{server_round}.pth")  # a joblib pickle of the weight arrays, despite the .pth suffix 81 | # Convert `List[np.ndarray]` to PyTorch `state_dict` 82 | #params_dict = zip(net.state_dict().keys(), aggregated_ndarrays) 83 | #state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 84 | #net.load_state_dict(state_dict, strict=True) 85 | 86 | # Save the model 87 | #torch.save(net.state_dict(), f"model_round_{server_round}.pth") 88 | 89 | return aggregated_parameters, aggregated_metrics 90 | 91 | def aggregate_evaluate( 92 | self, 93 | server_round: int, 94 | results: List[Tuple[ClientProxy, EvaluateRes]], 95 | failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], 96 | ) -> Tuple[Optional[float], Dict[str, Scalar]]: 97 | """Aggregate evaluation accuracy using weighted average.""" 98 | 99 | if not results: 100 | return None, {} 101 | 102 | # Call aggregate_evaluate from base class (FedAvg) to aggregate loss and metrics 103 | aggregated_loss, aggregated_metrics = super().aggregate_evaluate(server_round, results, failures) 104 | 105 | # Weigh accuracy of each client by number of examples used 106 | accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results] 107 | examples = [r.num_examples for _, r in results] 108 | 109 | # Aggregate and print custom metric 110 | aggregated_accuracy = sum(accuracies) / sum(examples) 111 | print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}") 112 | 113 | # Return aggregated loss and metrics (i.e., aggregated accuracy) 114 | return aggregated_loss, {"accuracy": aggregated_accuracy} 115 | 116 | # Create strategy and run server 117 | strategy = SaveModelStrategy(min_available_clients=46) #fraction_fit=1.0, min_available_clients=2, min_evaluate_clients=13 118 | fl.server.start_server(server_address="127.0.0.1:8080", config=fl.server.ServerConfig(num_rounds=100), strategy=strategy) 119 | -------------------------------------------------------------------------------- /stratified_age_sexy_two_layer_mit_GMM.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.spatial.distance import cdist 3 | import torch 4 | from jinja2 import Environment, FileSystemLoader 5 | from csv import reader 6 | from sklearn.preprocessing import StandardScaler 7 | import numpy as np 8 | 9 | # ==================================================================================================================== 10 | # all 11 | # path = "./all_stratified_sampling/" 12 | 13 | user_list = [i for i in range(100, 103)] + [i for i in range(104, 110)] + [i for i in range(111, 120)] + \ 14 | [i for i in range(121, 125)] + [i for i in range(200, 204)] + [i for i in range(205, 206)] + \ 15 | [i for i in range(207, 211)] + [i for i in range(212, 216)] + [i for i in range(217, 218)] + \ 16 | [i for i in range(220, 224)] + [i for i in range(228, 229)] + [i for i in range(230, 235)] 17 | 18 | 19 | # Stratified sampling GMM MIT 20 | cluster_a_index = [5,14,16,1,45,18,39,19,44,28,0,22,2,37,24,10,20,12,4] 21 | cluster_a_categories = [2, 0, 0, 1, 1, 1, 0, 3, 3, 2, 0, 0, 1, 0, 0, 3, 1, 1, 1] 22 | 23 | cluster_b_index = [34,3,41,40,25,38,32,29,6,23,35,21,31,9,26,36,8,17,42] 24 | cluster_b_categories = [0, 1, 3, 1, 3, 1, 0, 0, 0, 0, 0, 0, 2, 2, 0, 1, 0, 2, 1] 25 | 26 | ''' 27 | # GMM 28 | cluster_a_index = [30, 17, 42, 33, 22, 9, 24, 0, 23, 45,
44, 29, 14, 38, 40, 43, 11, 5, 10, 6, 1, 18, 26] 29 | cluster_a_categories = [0, 2, 1, 4, 4, 2, 4, 4, 4, 2, 4, 4, 4, 1, 1, 1, 3, 3, 4, 4, 1, 1, 4] 30 | # [1, 2, 3, 4] 31 | 32 | cluster_b_index = [2, 3, 4, 7, 8, 12, 13, 15, 16, 19, 20, 21, 25, 27, 28, 31, 32, 34, 35, 36, 37, 39, 41] 33 | cluster_b_categories = [4, 1, 1, 4, 0, 1, 5, 0, 0, 3, 1, 6, 3, 4, 2, 5, 0, 6, 0, 4, 6, 0, 7] 34 | # [0, 1, 3, 4, 5, 6] 35 | ''' 36 | 37 | ''' 38 | # =========== 39 | # read MIT 40 | file_name = 'mit.txt' 41 | with open(file_name, 'r') as raw_data: 42 | readers = reader(raw_data, delimiter=',') 43 | x = list(readers) 44 | data = np.array(x).astype('float') 45 | print(data.shape) 46 | data = data[:, 1:] 47 | data = StandardScaler().fit_transform(data) 48 | 49 | a_list = [] 50 | for i in [0, 1, 2, 3]: 51 | get_a_index = [x for x, y in enumerate(cluster_a_categories) if y == i] 52 | cluster_a_index_list = [cluster_a_index[j] for j in get_a_index] 53 | data_list = [data[j] for j in cluster_a_index_list] 54 | data_centre_point = np.mean(np.array(data_list), axis=0) 55 | a_list.append(data_centre_point) 56 | a_list = np.array(a_list) 57 | 58 | b_list = [] 59 | for i in [0, 1, 2, 3]: 60 | get_b_index = [x for x, y in enumerate(cluster_b_categories) if y == i] 61 | cluster_b_index_list = [cluster_b_index[j] for j in get_b_index] 62 | data_list = [data[j] for j in cluster_b_index_list] 63 | data_centre_point = np.mean(np.array(data_list), axis=0) 64 | b_list.append(data_centre_point) 65 | b_list = np.array(b_list) 66 | 67 | results = cdist(a_list, b_list) 68 | print() 69 | ''' 70 | # Stratified age+sex: 0->0 1->1 2->2 3->0 71 | 72 | 73 | 74 | 75 | 76 | # min_list = np.min(results, axis=0) 77 | 78 | # 0->3 1->1 2->4 3->0 79 | # 1->4 2->1 3->5 4->0 80 | 81 | 82 | # Stratified: 0->1 1->0 2->2 3->2 83 | 84 | 85 | # ==================================================================================================================== 86 | 87 | 88 | left_list = [0, 1, 2, 3] 89 | right_list = [0, 1, 2, 0] 90 | 91 | ''' 92 | for i in [0, 1, 2, 3]: 93 | get_index_list = [x for x, y in enumerate(cluster_b_categories) if y == i] 94 | get_index = [cluster_b_index[j] for j in get_index_list] 95 | path = "./two_layer_stratified_age_sexy_clusters_mit_GMM_" + str(i) + "/" 96 | os.makedirs(path, exist_ok=False) 97 | 98 | # get template 99 | env = Environment(loader=FileSystemLoader(searchpath="")) 100 | template = env.get_template("./template_mit.py") 101 | 102 | index_list = [user_list[i] for i in get_index] 103 | 104 | # generate new files 105 | for i in range(0, len(index_list)): 106 | output = template.render({'user_name': index_list[i]}) 107 | with open(path + "client%d.py" % i, 'w') as out: 108 | out.write(output) 109 | out.close() 110 | 111 | # get template 112 | env = Environment(loader=FileSystemLoader(searchpath="")) 113 | template = env.get_template("./run.sh") 114 | 115 | # generate new files 116 | output = template.render({'num_users': len(index_list) - 1}) 117 | with open(path + "run.sh", 'w') as out: 118 | out.write(output) 119 | out.close() 120 | ''' 121 | 122 | for i in [0, 1, 2, 3]: 123 | get_index_list = [x for x, y in enumerate(cluster_a_categories) if y == i] 124 | get_index = [cluster_a_index[j] for j in get_index_list] 125 | path = "./two_layer_stratified_age_sexy_clusters_mit_GMM_" + str(i) + "/" 126 | os.makedirs(path, exist_ok=False) 127 | 128 | # get template 129 | env = Environment(loader=FileSystemLoader(searchpath="")) 130 | template = env.get_template("./template_mit.py") 131 | 132 | index_list = [user_list[k]
for k in get_index] 133 | 134 | # generate new files 135 | for idx in range(0, len(index_list)): 136 | output = template.render({'user_name': index_list[idx]}) 137 | with open(path + "client%d.py" % idx, 'w') as out: 138 | out.write(output) 139 | out.close() 140 | 141 | # get template 142 | env = Environment(loader=FileSystemLoader(searchpath="")) 143 | template = env.get_template("./run.sh") 144 | 145 | # generate new files 146 | output = template.render({'num_users': len(index_list) - 1}) 147 | with open(path + "run.sh", 'w') as out: 148 | out.write(output) 149 | out.close() 150 | 151 | # ==================================================================================================================== 152 | --------------------------------------------------------------------------------
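The hard-coded cluster_a_index / cluster_a_categories pairs above are the offline output of the clustering scripts under cluster_people_selection/; the indices are positions in user_list. Below is a minimal sketch of how such lists can be produced, assuming mit.txt holds one comma-separated feature row per participant with an id in the first column, and using a silhouette score to pick the number of mixture components in the spirit of GMM.py and silhouette_GMM.py (whose exact contents are not shown here, so names and ranges are illustrative):

import numpy as np
from csv import reader
from sklearn.preprocessing import StandardScaler
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

# One row per participant; drop the id column, then standardise the features.
with open('mit.txt', 'r') as raw_data:
    rows = list(reader(raw_data, delimiter=','))
data = np.array(rows).astype('float')[:, 1:]
data = StandardScaler().fit_transform(data)

# Fit GMMs with increasing component counts and keep the best silhouette.
best_k, best_labels, best_score = None, None, -1.0
for k in range(2, 9):
    labels = GaussianMixture(n_components=k, random_state=0).fit_predict(data)
    score = silhouette_score(data, labels)
    if score > best_score:
        best_k, best_labels, best_score = k, labels, score

print('components: %d, silhouette: %.3f' % (best_k, best_score))
# Participant positions grouped by cluster, in the same format as the
# cluster_*_index / cluster_*_categories lists used above.
for c in range(best_k):
    print('cluster %d:' % c, [p for p, l in enumerate(best_labels) if l == c])

A silhouette coefficient close to 1 indicates tight, well-separated clusters; the DBSCAN variants in the same folder follow the same pattern but tune eps instead of the component count.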