├── FEMNIST_Balanced.py ├── FedMD.py ├── Neural_Networks.py ├── README.md ├── __pycache__ ├── FedMD.cpython-36.pyc ├── Neural_Networks.cpython-36.pyc └── data_utils.cpython-36.pyc ├── conf ├── .ipynb_checkpoints │ ├── EMNIST_balance_conf-checkpoint.json │ └── pretrain_MNIST_conf-checkpoint.json ├── EMNIST_balance_conf.json └── pretrain_MNIST_conf.json ├── data_utils.py ├── dataset └── emnist-letters.mat ├── pretrain_CNN_on_public_dataset.py ├── pretrain_result.pkl ├── result_FEMNIST_balanced ├── col_performance.pkl ├── init_result.pkl └── pooled_train_result.pkl ├── test.py └── utility.py /FEMNIST_Balanced.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import argparse 4 | import sys 5 | import pickle 6 | 7 | import numpy as np 8 | from tensorflow.keras.models import load_model 9 | 10 | from data_utils import load_MNIST_data, load_EMNIST_data, generate_bal_private_data,\ 11 | generate_partial_data 12 | from FedMD import FedMD 13 | 14 | 15 | def parseArg(): 16 | parser = argparse.ArgumentParser(description='FedMD, a federated learning framework. \ 17 | Participants are training collaboratively. ') 18 | parser.add_argument('-conf', metavar='conf_file', nargs=1, 19 | help='the config file for FedMD.' 20 | ) 21 | 22 | conf_file = os.path.abspath("conf/EMNIST_balance_conf.json") 23 | 24 | if len(sys.argv) > 1: 25 | args = parser.parse_args(sys.argv[1:]) 26 | if args.conf: 27 | conf_file = args.conf[0] 28 | return conf_file 29 | 30 | 31 | if __name__ == "__main__": 32 | conf_file = parseArg() 33 | with open(conf_file, "r") as f: 34 | conf_dict = eval(f.read()) 35 | 36 | emnist_data_dir = conf_dict["EMNIST_dir"] 37 | N_parties = conf_dict["N_parties"] 38 | private_classes = conf_dict["private_classes"] 39 | N_samples_per_class = conf_dict["N_samples_per_class"] 40 | 41 | N_rounds = conf_dict["N_rounds"] 42 | N_alignment = conf_dict["N_alignment"] 43 | N_private_training_round = conf_dict["N_private_training_round"] 44 | private_training_batchsize = conf_dict["private_training_batchsize"] 45 | N_logits_matching_round = conf_dict["N_logits_matching_round"] 46 | logits_matching_batchsize = conf_dict["logits_matching_batchsize"] 47 | model_saved_dir = conf_dict["model_saved_dir"] 48 | 49 | result_save_dir = conf_dict["result_save_dir"] 50 | 51 | 52 | del conf_dict, conf_file 53 | 54 | X_train_MNIST, y_train_MNIST, X_test_MNIST, y_test_MNIST \ 55 | = load_MNIST_data(standarized = True, verbose = True) 56 | 57 | public_dataset = {"X": X_train_MNIST, "y": y_train_MNIST} 58 | del X_train_MNIST, y_train_MNIST, X_test_MNIST, y_test_MNIST 59 | 60 | 61 | X_train_EMNIST, y_train_EMNIST, X_test_EMNIST, y_test_EMNIST \ 62 | = load_EMNIST_data(emnist_data_dir, 63 | standarized = True, verbose = True) 64 | 65 | #generate private data 66 | private_data, total_private_data \ 67 | = generate_bal_private_data(X_train_EMNIST, y_train_EMNIST, 68 | N_parties = N_parties, 69 | classes_in_use = private_classes, 70 | N_samples_per_class = N_samples_per_class, 71 | data_overlap = False) 72 | 73 | X_tmp, y_tmp = generate_partial_data(X = X_test_EMNIST, y= y_test_EMNIST, 74 | class_in_use = private_classes, verbose = True) 75 | private_test_data = {"X": X_tmp, "y": y_tmp} 76 | del X_tmp, y_tmp 77 | 78 | if model_saved_dir is not None: 79 | parties = [] 80 | dpath = os.path.abspath(model_saved_dir) 81 | model_names = os.listdir(dpath) 82 | for name in model_names: 83 | tmp = None 84 | tmp = load_model(os.path.join(dpath ,name)) 85 | parties.append(tmp) 86 | 87 | fedmd = FedMD(parties, 88 | public_dataset = public_dataset, 89 | private_data = private_data, 90 | total_private_data = total_private_data, 91 | private_test_data = private_test_data, 92 | N_rounds = N_rounds, 93 | N_alignment = N_alignment, 94 | N_logits_matching_round = N_logits_matching_round, 95 | logits_matching_batchsize = logits_matching_batchsize, 96 | N_private_training_round = N_private_training_round, 97 | private_training_batchsize = private_training_batchsize) 98 | 99 | initialization_result = fedmd.init_result 100 | pooled_train_result = fedmd.pooled_train_result 101 | 102 | collaboration_performance = fedmd.collaborative_training() 103 | 104 | if result_save_dir is not None: 105 | save_dir_path = os.path.abspath(result_save_dir) 106 | #make dir 107 | try: 108 | os.makedirs(save_dir_path) 109 | except OSError as e: 110 | if e.errno != errno.EEXIST: 111 | raise 112 | 113 | 114 | 115 | with open(os.path.join(save_dir_path, 'init_result.pkl'), 'wb') as f: 116 | pickle.dump(initialization_result, f, protocol=pickle.HIGHEST_PROTOCOL) 117 | with open(os.path.join(save_dir_path, 'pooled_train_result.pkl'), 'wb') as f: 118 | pickle.dump(pooled_train_result, f, protocol=pickle.HIGHEST_PROTOCOL) 119 | with open(os.path.join(save_dir_path, 'col_performance.pkl'), 'wb') as f: 120 | pickle.dump(collaboration_performance, f, protocol=pickle.HIGHEST_PROTOCOL) 121 | -------------------------------------------------------------------------------- /FedMD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.models import clone_model, load_model 3 | from tensorflow.keras.callbacks import EarlyStopping 4 | import tensorflow as tf 5 | 6 | from data_utils import generate_alignment_data 7 | from Neural_Networks import remove_last_layer 8 | 9 | class FedMD(): 10 | def __init__(self, parties, public_dataset, 11 | private_data, total_private_data, 12 | private_test_data, N_alignment, 13 | N_rounds, 14 | N_logits_matching_round, logits_matching_batchsize, 15 | N_private_training_round, private_training_batchsize): 16 | 17 | self.N_parties = len(parties) 18 | self.public_dataset = public_dataset 19 | self.private_data = private_data 20 | self.private_test_data = private_test_data 21 | self.N_alignment = N_alignment 22 | 23 | self.N_rounds = N_rounds 24 | self.N_logits_matching_round = N_logits_matching_round 25 | self.logits_matching_batchsize = logits_matching_batchsize 26 | self.N_private_training_round = N_private_training_round 27 | self.private_training_batchsize = private_training_batchsize 28 | 29 | self.collaborative_parties = [] 30 | self.init_result = [] 31 | 32 | print("start model initialization: ") 33 | for i in range(self.N_parties): 34 | print("model ", i) 35 | model_A_twin = None 36 | model_A_twin = clone_model(parties[i]) 37 | model_A_twin.set_weights(parties[i].get_weights()) 38 | model_A_twin.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-3), 39 | loss = "sparse_categorical_crossentropy", 40 | metrics = ["accuracy"]) 41 | 42 | print("start full stack training ... ") 43 | 44 | model_A_twin.fit(private_data[i]["X"], private_data[i]["y"], 45 | batch_size = 32, epochs = 25, shuffle=True, verbose = 0, 46 | validation_data = [private_test_data["X"], private_test_data["y"]], 47 | callbacks=[EarlyStopping(monitor='val_acc', min_delta=0.001, patience=5)] 48 | ) 49 | 50 | print("full stack training done") 51 | 52 | model_A = remove_last_layer(model_A_twin, loss="mean_absolute_error") 53 | 54 | self.collaborative_parties.append({"model_logits": model_A, 55 | "model_classifier": model_A_twin, 56 | "model_weights": model_A_twin.get_weights()}) 57 | 58 | self.init_result.append({"val_acc": model_A_twin.history.history['val_acc'], 59 | "train_acc": model_A_twin.history.history['acc'], 60 | "val_loss": model_A_twin.history.history['val_loss'], 61 | "train_loss": model_A_twin.history.history['loss'], 62 | }) 63 | 64 | print() 65 | del model_A, model_A_twin 66 | #END FOR LOOP 67 | 68 | print("calculate the theoretical upper bounds for participants: ") 69 | 70 | self.upper_bounds = [] 71 | self.pooled_train_result = [] 72 | for model in parties: 73 | model_ub = clone_model(model) 74 | model_ub.set_weights(model.get_weights()) 75 | model_ub.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-3), 76 | loss = "sparse_categorical_crossentropy", 77 | metrics = ["acc"]) 78 | 79 | model_ub.fit(total_private_data["X"], total_private_data["y"], 80 | batch_size = 32, epochs = 50, shuffle=True, verbose = 1, 81 | validation_data = [private_test_data["X"], private_test_data["y"]], 82 | callbacks=[EarlyStopping(monitor='val_acc', min_delta=0.001, patience=5)]) 83 | 84 | self.upper_bounds.append(model_ub.history.history["val_acc"][-1]) 85 | self.pooled_train_result.append({"val_acc": model_ub.history.history["val_acc"], 86 | "acc": model_ub.history.history["acc"]}) 87 | 88 | del model_ub 89 | print("the upper bounds are:", self.upper_bounds) 90 | 91 | def collaborative_training(self): 92 | # start collaborating training 93 | collaboration_performance = {i: [] for i in range(self.N_parties)} 94 | r = 0 95 | while True: 96 | # At beginning of each round, generate new alignment dataset 97 | alignment_data = generate_alignment_data(self.public_dataset["X"], 98 | self.public_dataset["y"], 99 | self.N_alignment) 100 | 101 | print("round ", r) 102 | 103 | print("update logits ... ") 104 | # update logits 105 | logits = 0 106 | for d in self.collaborative_parties: 107 | d["model_logits"].set_weights(d["model_weights"]) 108 | logits += d["model_logits"].predict(alignment_data["X"], verbose = 0) 109 | 110 | logits /= self.N_parties 111 | 112 | # test performance 113 | print("test performance ... ") 114 | 115 | for index, d in enumerate(self.collaborative_parties): 116 | y_pred = d["model_classifier"].predict(self.private_test_data["X"], verbose = 0).argmax(axis = 1) 117 | collaboration_performance[index].append(np.mean(self.private_test_data["y"] == y_pred)) 118 | 119 | print(collaboration_performance[index][-1]) 120 | del y_pred 121 | 122 | 123 | r+= 1 124 | if r > self.N_rounds: 125 | break 126 | 127 | 128 | print("updates models ...") 129 | for index, d in enumerate(self.collaborative_parties): 130 | print("model {0} starting alignment with public logits... ".format(index)) 131 | 132 | 133 | weights_to_use = None 134 | weights_to_use = d["model_weights"] 135 | 136 | d["model_logits"].set_weights(weights_to_use) 137 | d["model_logits"].fit(alignment_data["X"], logits, 138 | batch_size = self.logits_matching_batchsize, 139 | epochs = self.N_logits_matching_round, 140 | shuffle=True, verbose = True) 141 | d["model_weights"] = d["model_logits"].get_weights() 142 | print("model {0} done alignment".format(index)) 143 | 144 | print("model {0} starting training with private data... ".format(index)) 145 | weights_to_use = None 146 | weights_to_use = d["model_weights"] 147 | d["model_classifier"].set_weights(weights_to_use) 148 | d["model_classifier"].fit(self.private_data[index]["X"], 149 | self.private_data[index]["y"], 150 | batch_size = self.private_training_batchsize, 151 | epochs = self.N_private_training_round, 152 | shuffle=True, verbose = True) 153 | 154 | d["model_weights"] = d["model_classifier"].get_weights() 155 | print("model {0} done private training. \n".format(index)) 156 | #END FOR LOOP 157 | 158 | #END WHILE LOOP 159 | return collaboration_performance 160 | 161 | 162 | -------------------------------------------------------------------------------- /Neural_Networks.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.models import Model, Sequential, clone_model, load_model 2 | from tensorflow.keras.layers import Input, Dense, add, concatenate, Conv2D,Dropout,\ 3 | BatchNormalization, Flatten, MaxPooling2D, AveragePooling2D, Activation, Dropout, Reshape 4 | import tensorflow as tf 5 | 6 | 7 | def cnn_3layer_fc_model(n_classes,n1 = 128, n2=192, n3=256, dropout_rate = 0.2,input_shape = (28,28)): 8 | model_A, x = None, None 9 | 10 | x = Input(input_shape) 11 | if len(input_shape)==2: y = Reshape((input_shape[0], input_shape[1], 1))(x) 12 | 13 | y = Conv2D(filters = n1, kernel_size = (3,3), strides = 1, padding = "same", 14 | activation = None)(y) 15 | y = BatchNormalization()(y) 16 | y = Activation("relu")(y) 17 | y = Dropout(dropout_rate)(y) 18 | y = AveragePooling2D(pool_size = (2,2), strides = 1, padding = "same")(y) 19 | 20 | y = Conv2D(filters = n2, kernel_size = (2,2), strides = 2, padding = "valid", 21 | activation = None)(y) 22 | y = BatchNormalization()(y) 23 | y = Activation("relu")(y) 24 | y = Dropout(dropout_rate)(y) 25 | y = AveragePooling2D(pool_size = (2,2), strides = 2, padding = "valid")(y) 26 | 27 | y = Conv2D(filters = n3, kernel_size = (3,3), strides = 2, padding = "valid", 28 | activation = None)(y) 29 | y = BatchNormalization()(y) 30 | y = Activation("relu")(y) 31 | y = Dropout(dropout_rate)(y) 32 | #y = AveragePooling2D(pool_size = (2,2), strides = 2, padding = "valid")(y) 33 | 34 | y = Flatten()(y) 35 | y = Dense(units = n_classes, activation = None, use_bias = False, 36 | kernel_regularizer=tf.keras.regularizers.l2(1e-3))(y) 37 | y = Activation("softmax")(y) 38 | 39 | 40 | model_A = Model(inputs = x, outputs = y) 41 | 42 | model_A.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-3), 43 | loss = "sparse_categorical_crossentropy", 44 | metrics = ["accuracy"]) 45 | return model_A 46 | 47 | def cnn_2layer_fc_model(n_classes,n1 = 128, n2=256, dropout_rate = 0.2,input_shape = (28,28)): 48 | model_A, x = None, None 49 | 50 | x = Input(input_shape) 51 | if len(input_shape)==2: y = Reshape((input_shape[0], input_shape[1], 1))(x) 52 | 53 | y = Conv2D(filters = n1, kernel_size = (3,3), strides = 1, padding = "same", 54 | activation = None)(y) 55 | y = BatchNormalization()(y) 56 | y = Activation("relu")(y) 57 | y = Dropout(dropout_rate)(y) 58 | y = AveragePooling2D(pool_size = (2,2), strides = 1, padding = "same")(y) 59 | 60 | 61 | y = Conv2D(filters = n2, kernel_size = (3,3), strides = 2, padding = "valid", 62 | activation = None)(y) 63 | y = BatchNormalization()(y) 64 | y = Activation("relu")(y) 65 | y = Dropout(dropout_rate)(y) 66 | #y = AveragePooling2D(pool_size = (2,2), strides = 2, padding = "valid")(y) 67 | 68 | y = Flatten()(y) 69 | y = Dense(units = n_classes, activation = None, use_bias = False, 70 | kernel_regularizer=tf.keras.regularizers.l2(1e-3))(y) 71 | y = Activation("softmax")(y) 72 | 73 | 74 | model_A = Model(inputs = x, outputs = y) 75 | 76 | model_A.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-3), 77 | loss = "sparse_categorical_crossentropy", 78 | metrics = ["accuracy"]) 79 | return model_A 80 | 81 | 82 | def remove_last_layer(model, loss = "mean_absolute_error"): 83 | """ 84 | Input: Keras model, a classification model whose last layer is a softmax activation 85 | Output: Keras model, the same model with the last softmax activation layer removed, 86 | while keeping the same parameters 87 | """ 88 | 89 | new_model = Model(inputs = model.inputs, outputs = model.layers[-2].output) 90 | new_model.set_weights(model.get_weights()) 91 | new_model.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-3), 92 | loss = loss) 93 | 94 | return new_model -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedMD 2 | FedMD: Heterogenous Federated Learning via Model Distillation. 3 | Preprint on https://arxiv.org/abs/1910.03581. 4 | 5 | ## Run scripts on Google Colab 6 | 7 | 1. open a google Colab 8 | 9 | 2. Clone the project folder from Github 10 | ``` 11 | ! git clone github_link 12 | ``` 13 | 14 | 3. Then access the folder just created. 15 | ``` 16 | % cd project_folder/ 17 | ``` 18 | 19 | 4. Run the python script in Colab. For instance 20 | ``` 21 | ! python pretrain_CNN_on_public_dataset.py -conf conf/pretrain_MNIST_conf.json 22 | ``` -------------------------------------------------------------------------------- /__pycache__/FedMD.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/__pycache__/FedMD.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/Neural_Networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/__pycache__/Neural_Networks.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/__pycache__/data_utils.cpython-36.pyc -------------------------------------------------------------------------------- /conf/.ipynb_checkpoints/EMNIST_balance_conf-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "N_parties": 10, 3 | "N_samples_per_class": 3, 4 | "N_alignment": 5000, 5 | "private_classes": [10, 11, 12, 13, 14, 15], 6 | "public_classes": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 7 | "is_show": False, 8 | "N_rounds": 13, 9 | "N_logits_matching_round": 1, 10 | "N_private_training_round": 10, 11 | "private_training_batchsize" : 5, 12 | "logits_matching_batchsize": 128, 13 | "EMNIST_dir": "./dataset/emnist-letters.mat", 14 | "model_saved_dir": "./pretrained_from_MNIST/", 15 | "result_save_dir": "./FEMNIST_balanced/" 16 | } -------------------------------------------------------------------------------- /conf/.ipynb_checkpoints/pretrain_MNIST_conf-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_classes": 16, 3 | "data_type": "MNIST", 4 | "models": [{"model_type": "2_layer_CNN", "params": {"n1": 128, "n2": 256, "dropout_rate": 0.2}}, 5 | {"model_type": "2_layer_CNN", "params": {"n1": 128, "n2": 384, "dropout_rate": 0.2}}, 6 | {"model_type": "2_layer_CNN", "params": {"n1": 128, 'n2': 512, "dropout_rate": 0.2}}, 7 | {"model_type": "2_layer_CNN", "params": {"n1": 256, "n2": 256, "dropout_rate": 0.3}}, 8 | {"model_type": "2_layer_CNN", "params": {"n1": 256, "n2": 512, "dropout_rate": 0.4}}, 9 | {"model_type": "3_layer_CNN", "params": {"n1": 64, "n2": 128, "n3": 256, "dropout_rate": 0.2}}, 10 | {"model_type": "3_layer_CNN", "params": {"n1": 64, "n2": 128, "n3": 192, "dropout_rate": 0.2}}, 11 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 192, "n3": 256, "dropout_rate": 0.2}}, 12 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 128, "n3": 128, "dropout_rate": 0.3}}, 13 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 128, "n3": 192, "dropout_rate": 0.3}} 14 | ], 15 | "train_params": {"min_delta": 0.001, "patience": 3, 16 | "batch_size": 128, "epochs": 20, "is_shuffle": True, 17 | "verbose": 1}, 18 | "save_directory": "./pretrained_from_MNIST/", 19 | "save_names" : ["CNN_128_256", "CNN_128_384", "CNN_128_512", "CNN_256_256", "CNN_256_512", 20 | "CNN_64_128_256", "CNN_64_128_192", "CNN_128_192_256", "CNN_128_128_128", "CNN_128_128_192"], 21 | "early_stopping" : True, 22 | } -------------------------------------------------------------------------------- /conf/EMNIST_balance_conf.json: -------------------------------------------------------------------------------- 1 | { 2 | "N_parties": 10, 3 | "N_samples_per_class": 3, 4 | "N_alignment": 5000, 5 | "private_classes": [10, 11, 12, 13, 14, 15], 6 | "public_classes": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 7 | "is_show": False, 8 | "N_rounds": 13, 9 | "N_logits_matching_round": 1, 10 | "N_private_training_round": 10, 11 | "private_training_batchsize" : 5, 12 | "logits_matching_batchsize": 128, 13 | "EMNIST_dir": "./dataset/emnist-letters.mat", 14 | "model_saved_dir": "./pretrained_from_MNIST/", 15 | "result_save_dir": "./FEMNIST_balanced/" 16 | } -------------------------------------------------------------------------------- /conf/pretrain_MNIST_conf.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_classes": 16, 3 | "data_type": "MNIST", 4 | "models": [{"model_type": "2_layer_CNN", "params": {"n1": 128, "n2": 256, "dropout_rate": 0.2}}, 5 | {"model_type": "2_layer_CNN", "params": {"n1": 128, "n2": 384, "dropout_rate": 0.2}}, 6 | {"model_type": "2_layer_CNN", "params": {"n1": 128, 'n2': 512, "dropout_rate": 0.2}}, 7 | {"model_type": "2_layer_CNN", "params": {"n1": 256, "n2": 256, "dropout_rate": 0.3}}, 8 | {"model_type": "2_layer_CNN", "params": {"n1": 256, "n2": 512, "dropout_rate": 0.4}}, 9 | {"model_type": "3_layer_CNN", "params": {"n1": 64, "n2": 128, "n3": 256, "dropout_rate": 0.2}}, 10 | {"model_type": "3_layer_CNN", "params": {"n1": 64, "n2": 128, "n3": 192, "dropout_rate": 0.2}}, 11 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 192, "n3": 256, "dropout_rate": 0.2}}, 12 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 128, "n3": 128, "dropout_rate": 0.3}}, 13 | {"model_type": "3_layer_CNN", "params": {"n1": 128, "n2": 128, "n3": 192, "dropout_rate": 0.3}} 14 | ], 15 | "train_params": {"min_delta": 0.001, "patience": 3, 16 | "batch_size": 128, "epochs": 20, "is_shuffle": True, 17 | "verbose": 1}, 18 | "save_directory": "./pretrained_from_MNIST/", 19 | "save_names" : ["CNN_128_256", "CNN_128_384", "CNN_128_512", "CNN_256_256", "CNN_256_512", 20 | "CNN_64_128_256", "CNN_64_128_192", "CNN_128_192_256", "CNN_128_128_128", "CNN_128_128_192"], 21 | "early_stopping" : True, 22 | } -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.model_selection import StratifiedShuffleSplit 4 | from tensorflow.keras.datasets import cifar10, cifar100, mnist 5 | import scipy.io as sio 6 | 7 | 8 | def load_MNIST_data(standarized = False, verbose = False): 9 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 10 | 11 | if standarized: 12 | X_train = X_train/255 13 | X_test = X_test/255 14 | mean_image = np.mean(X_train, axis=0) 15 | X_train -= mean_image 16 | X_test -= mean_image 17 | 18 | if verbose == True: 19 | print("MNIST dataset ... ") 20 | print("X_train shape :", X_train.shape) 21 | print("X_test shape :", X_test.shape) 22 | print("y_train shape :", y_train.shape) 23 | print("y_test shape :", y_test.shape) 24 | 25 | return X_train, y_train, X_test, y_test 26 | 27 | 28 | def load_EMNIST_data(file, verbose = False, standarized = False): 29 | """ 30 | file should be the downloaded EMNIST file in .mat format. 31 | """ 32 | mat = sio.loadmat(file) 33 | data = mat["dataset"] 34 | 35 | X_train = data['train'][0,0]['images'][0,0] 36 | X_train = X_train.reshape((X_train.shape[0], 28, 28), order = "F") 37 | y_train = data['train'][0,0]['labels'][0,0] 38 | y_train = np.squeeze(y_train) 39 | y_train -= 1 #y_train is zero-based 40 | 41 | X_test = data['test'][0,0]['images'][0,0] 42 | X_test= X_test.reshape((X_test.shape[0], 28, 28), order = "F") 43 | y_test = data['test'][0,0]['labels'][0,0] 44 | y_test = np.squeeze(y_test) 45 | y_test -= 1 #y_test is zero-based 46 | 47 | if standarized: 48 | X_train = X_train/255 49 | X_test = X_test/255 50 | mean_image = np.mean(X_train, axis=0) 51 | X_train -= mean_image 52 | X_test -= mean_image 53 | 54 | 55 | if verbose == True: 56 | print("EMNIST-letter dataset ... ") 57 | print("X_train shape :", X_train.shape) 58 | print("X_test shape :", X_test.shape) 59 | print("y_train shape :", y_train.shape) 60 | print("y_test shape :", y_test.shape) 61 | 62 | return X_train, y_train, X_test, y_test 63 | 64 | 65 | def generate_partial_data(X, y, class_in_use = "all", verbose = False): 66 | if class_in_use == "all": 67 | idx = np.ones_like(y, dtype = bool) 68 | else: 69 | idx = [y == i for i in class_in_use] 70 | idx = np.any(idx, axis = 0) 71 | X_incomplete, y_incomplete = X[idx], y[idx] 72 | if verbose == True: 73 | print("X shape :", X_incomplete.shape) 74 | print("y shape :", y_incomplete.shape) 75 | return X_incomplete, y_incomplete 76 | 77 | 78 | 79 | def generate_bal_private_data(X, y, N_parties = 10, classes_in_use = range(11), 80 | N_samples_per_class = 20, data_overlap = False): 81 | """ 82 | Input: 83 | -- N_parties : int, number of collaboraters in this activity; 84 | -- classes_in_use: array or generator, the classes of EMNIST-letters dataset 85 | (0 <= y <= 25) to be used as private data; 86 | -- N_sample_per_class: int, the number of private data points of each class for each party 87 | 88 | return: 89 | 90 | """ 91 | priv_data = [None] * N_parties 92 | combined_idx = np.array([], dtype = np.int16) 93 | for cls in classes_in_use: 94 | idx = np.where(y == cls)[0] 95 | idx = np.random.choice(idx, N_samples_per_class * N_parties, 96 | replace = data_overlap) 97 | combined_idx = np.r_[combined_idx, idx] 98 | for i in range(N_parties): 99 | idx_tmp = idx[i * N_samples_per_class : (i + 1)*N_samples_per_class] 100 | if priv_data[i] is None: 101 | tmp = {} 102 | tmp["X"] = X[idx_tmp] 103 | tmp["y"] = y[idx_tmp] 104 | tmp["idx"] = idx_tmp 105 | priv_data[i] = tmp 106 | else: 107 | priv_data[i]['idx'] = np.r_[priv_data[i]["idx"], idx_tmp] 108 | priv_data[i]["X"] = np.vstack([priv_data[i]["X"], X[idx_tmp]]) 109 | priv_data[i]["y"] = np.r_[priv_data[i]["y"], y[idx_tmp]] 110 | 111 | 112 | total_priv_data = {} 113 | total_priv_data["idx"] = combined_idx 114 | total_priv_data["X"] = X[combined_idx] 115 | total_priv_data["y"] = y[combined_idx] 116 | return priv_data, total_priv_data 117 | 118 | 119 | def generate_alignment_data(X, y, N_alignment = 3000): 120 | 121 | split = StratifiedShuffleSplit(n_splits=1, train_size= N_alignment) 122 | if N_alignment == "all": 123 | alignment_data = {} 124 | alignment_data["idx"] = np.arange(y.shape[0]) 125 | alignment_data["X"] = X 126 | alignment_data["y"] = y 127 | return alignment_data 128 | for train_index, _ in split.split(X, y): 129 | X_alignment = X[train_index] 130 | y_alignment = y[train_index] 131 | alignment_data = {} 132 | alignment_data["idx"] = train_index 133 | alignment_data["X"] = X_alignment 134 | alignment_data["y"] = y_alignment 135 | 136 | return alignment_data -------------------------------------------------------------------------------- /dataset/emnist-letters.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/dataset/emnist-letters.mat -------------------------------------------------------------------------------- /pretrain_CNN_on_public_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import sys 4 | import argparse 5 | import pickle 6 | from tensorflow.keras.callbacks import EarlyStopping 7 | 8 | from data_utils import load_MNIST_data 9 | from Neural_Networks import cnn_2layer_fc_model, cnn_3layer_fc_model 10 | 11 | 12 | def parseArg(): 13 | parser = argparse.ArgumentParser(description='Train an array of Neural Networks on either MNIST or CIFAR') 14 | parser.add_argument('-conf', metavar='conf_file', nargs=1, 15 | help='the config file for training, \ 16 | for training on MNIST, the default conf_file is ./conf/pretrain_MNIST.json, \ 17 | for training on CIFAR, the default conf_file is ./conf/pretrain_CIFAR.json.' 18 | ) 19 | 20 | conf_file = os.path.abspath("conf/pretrain_MNIST_conf.json") 21 | 22 | if len(sys.argv) > 1: 23 | args = parser.parse_args(sys.argv[1:]) 24 | if args.conf: 25 | conf_file = args.conf[0] 26 | return conf_file 27 | 28 | 29 | 30 | def train_models(models, X_train, y_train, X_test, y_test, 31 | is_show = False, save_dir = "./", save_names = None, 32 | early_stopping = True, 33 | min_delta = 0.001, patience = 3, batch_size = 128, epochs = 20, is_shuffle=True, verbose = 1, 34 | ): 35 | ''' 36 | Train an array of models on the same dataset. 37 | We use early termination to speed up training. 38 | ''' 39 | 40 | resulting_val_acc = [] 41 | record_result = [] 42 | for n, model in enumerate(models): 43 | print("Training model ", n) 44 | if early_stopping: 45 | model.fit(X_train, y_train, 46 | validation_data = [X_test, y_test], 47 | callbacks=[EarlyStopping(monitor='val_acc', min_delta=min_delta, patience=patience)], 48 | batch_size = batch_size, epochs = epochs, shuffle=is_shuffle, verbose = verbose 49 | ) 50 | else: 51 | model.fit(X_train, y_train, 52 | validation_data = [X_test, y_test], 53 | batch_size = batch_size, epochs = epochs, shuffle=is_shuffle, verbose = verbose 54 | ) 55 | 56 | resulting_val_acc.append(model.history.history["val_acc"][-1]) 57 | record_result.append({"train_acc": model.history.history["acc"], 58 | "val_acc": model.history.history["val_acc"], 59 | "train_loss": model.history.history["loss"], 60 | "val_loss": model.history.history["val_loss"]}) 61 | 62 | 63 | save_dir_path = os.path.abspath(save_dir) 64 | #make dir 65 | try: 66 | os.makedirs(save_dir_path) 67 | except OSError as e: 68 | if e.errno != errno.EEXIST: 69 | raise 70 | 71 | if save_names is None: 72 | file_name = save_dir + "model_{0}".format(n) + ".h5" 73 | else: 74 | file_name = save_dir + save_names[n] + ".h5" 75 | model.save(file_name) 76 | 77 | if is_show: 78 | print("pre-train accuracy: ") 79 | print(resulting_val_acc) 80 | 81 | return record_result 82 | 83 | 84 | models = {"2_layer_CNN": cnn_2layer_fc_model, 85 | "3_layer_CNN": cnn_3layer_fc_model} 86 | 87 | 88 | if __name__ == "__main__": 89 | conf_file = parseArg() 90 | with open(conf_file, "r") as f: 91 | conf_dict = eval(f.read()) 92 | dataset = conf_dict["data_type"] 93 | n_classes = conf_dict["n_classes"] 94 | model_config = conf_dict["models"] 95 | train_params = conf_dict["train_params"] 96 | save_dir = conf_dict["save_directory"] 97 | save_names = conf_dict["save_names"] 98 | early_stopping = conf_dict["early_stopping"] 99 | 100 | 101 | del conf_dict 102 | 103 | 104 | if dataset == "MNIST": 105 | input_shape = (28,28) 106 | X_train, y_train, X_test, y_test = load_MNIST_data(standarized = True, 107 | verbose = True) 108 | 109 | else: 110 | print("Unknown dataset. Program stopped.") 111 | sys.exit() 112 | 113 | pretrain_models = [] 114 | for i, item in enumerate(model_config): 115 | name = item["model_type"] 116 | model_params = item["params"] 117 | tmp = models[name](n_classes=n_classes, 118 | input_shape=input_shape, 119 | **model_params) 120 | 121 | print("model {0} : {1}".format(i, save_names[i])) 122 | print(tmp.summary()) 123 | pretrain_models.append(tmp) 124 | 125 | record_result = train_models(pretrain_models, X_train, y_train, X_test, y_test, 126 | save_dir = save_dir, save_names = save_names, is_show=True, 127 | early_stopping = early_stopping, 128 | **train_params 129 | ) 130 | 131 | with open('pretrain_result.pkl', 'wb') as f: 132 | pickle.dump(record_result, f, protocol=pickle.HIGHEST_PROTOCOL) 133 | -------------------------------------------------------------------------------- /pretrain_result.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/pretrain_result.pkl -------------------------------------------------------------------------------- /result_FEMNIST_balanced/col_performance.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/result_FEMNIST_balanced/col_performance.pkl -------------------------------------------------------------------------------- /result_FEMNIST_balanced/init_result.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/result_FEMNIST_balanced/init_result.pkl -------------------------------------------------------------------------------- /result_FEMNIST_balanced/pooled_train_result.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzq2doc/FedMD/113f2418193a07d28cc3fc74daa17668f7fa55b9/result_FEMNIST_balanced/pooled_train_result.pkl -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | if __name__ == "__main__": 5 | 6 | 7 | 8 | #dpath = os.path.dirname(os.path.abspath(self.db_store_name)) 9 | 10 | dpath = os.path.abspath("./MNIST") 11 | #make dir 12 | try: 13 | os.makedirs(dpath) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | a = np.arange(10) 19 | a = np.sin(a) 20 | 21 | print(a) -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | def plot_history(model): 8 | 9 | """ 10 | input : model is trained keras model. 11 | """ 12 | 13 | fig, axes = plt.subplots(2,1, figsize = (12, 6), sharex = True) 14 | axes[0].plot(model.history.history["loss"], "b.-", label = "Training Loss") 15 | axes[0].plot(model.history.history["val_loss"], "k^-", label = "Val Loss") 16 | axes[0].set_xlabel("Epoch") 17 | axes[0].set_ylabel("Loss") 18 | axes[0].legend() 19 | 20 | 21 | axes[1].plot(model.history.history["acc"], "b.-", label = "Training Acc") 22 | axes[1].plot(model.history.history["val_acc"], "k^-", label = "Val Acc") 23 | axes[1].set_xlabel("Epoch") 24 | axes[1].set_ylabel("Accuracy") 25 | axes[1].legend() 26 | 27 | plt.subplots_adjust(hspace=0) 28 | plt.show() 29 | 30 | def show_performance(model, Xtrain, ytrain, Xtest, ytest): 31 | y_pred = None 32 | print("CNN+fC Training Accuracy :") 33 | y_pred = model.predict(Xtrain, verbose = 0).argmax(axis = 1) 34 | print((y_pred == ytrain).mean()) 35 | print("CNN+fc Test Accuracy :") 36 | y_pred = model.predict(Xtest, verbose = 0).argmax(axis = 1) 37 | print((y_pred == ytest).mean()) 38 | print("Confusion_matrix : ") 39 | print(confusion_matrix(y_true = ytest, y_pred = y_pred)) 40 | 41 | del y_pred --------------------------------------------------------------------------------