├── .gitignore
├── README.md
├── attack_timing.py
├── client.py
├── defense.py
├── federated_learning
│   ├── __init__.py
│   ├── arguments.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── data_distribution
│   │   │   ├── __init__.py
│   │   │   └── iid_equal.py
│   │   ├── dataset.py
│   │   └── fashion_mnist.py
│   ├── dimensionality_reduction
│   │   ├── __init__.py
│   │   └── pca.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── cifar_10_cnn.py
│   │   └── fashion_mnist_cnn.py
│   ├── parameters
│   │   ├── __init__.py
│   │   ├── gradients.py
│   │   ├── log_utils.py
│   │   ├── model_comparison.py
│   │   └── parameter_processors.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── min_lr_step.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── apply_scalers.py
│   │   ├── class_flipping_methods.py
│   │   ├── client_utils.py
│   │   ├── csv_utils.py
│   │   ├── data_loader_utils.py
│   │   ├── experiment_ids.py
│   │   ├── fed_avg.py
│   │   ├── file_storage_utils.py
│   │   ├── identify_random_elements.py
│   │   ├── label_replacement.py
│   │   ├── model_list_parser.py
│   │   ├── poison_data.py
│   │   └── tensor_converter.py
│   └── worker_selection
│       ├── __init__.py
│       ├── breakpoint_after.py
│       ├── breakpoint_before.py
│       ├── poisoner_probability.py
│       ├── random.py
│       └── selection_strategy.py
├── generate_data_distribution.py
├── generate_default_models.py
├── label_flipping_attack.py
├── malicious_participant_availability.py
├── requirements.pip
└── server.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .DS_Store 107 | 108 | data_loaders/ 109 | data/ 110 | default_models/ 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Poisoning Attacks Against Federated Learning Systems 2 | 3 | Code for the ESORICS 2020 paper: Data Poisoning Attacks Against Federated Learning Systems 4 | 5 | ## Installation 6 | 7 | 1) Create a virtualenv (Python 3.7) 8 | 2) Install the dependencies inside the virtualenv (```pip install -r requirements.pip```) 9 | 3) If you plan on using the defense, you will also need to install ```matplotlib```. It is not required for running the experiments, so it is not included in the requirements file. 10 | 11 | ## Instructions for execution 12 | 13 | Using this repository, you can replicate all results presented at ESORICS. We outline the steps required to execute the different experiments below. 14 | 15 | ### Setup 16 | 17 | Before you can run any experiments, you must complete some setup: 18 | 19 | 1) ```python3 generate_data_distribution.py``` This downloads the datasets and generates a static distribution of the training and test data, providing consistency across experiments. 20 | 2) ```python3 generate_default_models.py``` This generates an instance of each model used in the paper and saves it to disk.
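Both setup scripts write their outputs to the ```data/```, ```data_loaders/```, and ```default_models/``` folders (the same paths ignored in the ```.gitignore``` above). As an optional sanity check, a minimal sketch that verifies the setup artifacts exist:

```
import os

# Folders produced by the two setup scripts (paths taken from .gitignore and arguments.py)
for folder in ["data", "data_loaders", "default_models"]:
    print(folder, "->", "ok" if os.path.isdir(folder) else "missing")
```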
21 | 22 | ### General Information 23 | 24 | Some pointers & general information: 25 | - Most hyperparameters can be set in the ```federated_learning/arguments.py``` file 26 | - Most specific experiment settings are located in the respective experiment files (see the following sections) 27 | 28 | ### Experiments - Label Flipping Attack Feasibility 29 | 30 | Running an attack: ```python3 label_flipping_attack.py``` 31 | 32 | ### Experiments - Attack Timing in Label Flipping Attacks 33 | 34 | Running an attack: ```python3 attack_timing.py``` 35 | 36 | ### Experiments - Malicious Participant Availability 37 | 38 | Running an attack: ```python3 malicious_participant_availability.py``` 39 | 40 | ### Experiments - Defending Against Label Flipping Attacks 41 | 42 | Running the defense: ```python3 defense.py``` 43 | 44 | ### Experiment Hyperparameters 45 | 46 | Recommended default hyperparameters for CIFAR10 (using the provided CNN): 47 | - Batch size: 10 48 | - LR: 0.01 49 | - Number of epochs: 200 50 | - Momentum: 0.5 51 | - Scheduler step size: 50 52 | - Scheduler gamma: 0.5 53 | - Min_lr: 1e-10 54 | 55 | Recommended default hyperparameters for Fashion-MNIST (using the provided CNN): 56 | - Batch size: 4 57 | - LR: 0.001 58 | - Number of epochs: 200 59 | - Momentum: 0.9 60 | - Scheduler step size: 10 61 | - Scheduler gamma: 0.1 62 | - Min_lr: 1e-10 63 | 64 | ## Citing 65 | 66 | If you use this code, please cite the paper: 67 | 68 | ``` 69 | @ARTICLE{2020arXiv200708432T, 70 | author = {{Tolpegin}, Vale and {Truex}, Stacey and {Emre Gursoy}, Mehmet and 71 | {Liu}, Ling}, 72 | title = "{Data Poisoning Attacks Against Federated Learning Systems}", 73 | journal = {arXiv e-prints}, 74 | keywords = {Computer Science - Machine Learning, Computer Science - Cryptography and Security, Statistics - Machine Learning}, 75 | year = 2020, 76 | month = jul, 77 | eid = {arXiv:2007.08432}, 78 | pages = {arXiv:2007.08432}, 79 | archivePrefix = {arXiv}, 80 | eprint = {2007.08432}, 81 | primaryClass = {cs.LG}, 82 | adsurl = {https://ui.adsabs.harvard.edu/abs/2020arXiv200708432T}, 83 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /attack_timing.py: -------------------------------------------------------------------------------- 1 | from federated_learning.utils import replace_0_with_2 2 | from federated_learning.utils import replace_5_with_3 3 | from federated_learning.utils import replace_1_with_9 4 | from federated_learning.utils import replace_4_with_6 5 | from federated_learning.utils import replace_1_with_3 6 | from federated_learning.utils import replace_6_with_0 7 | from federated_learning.worker_selection import BeforeBreakpoint 8 | from federated_learning.worker_selection import AfterBreakpoint 9 | from server import run_exp 10 | 11 | if __name__ == '__main__': 12 | START_EXP_IDX = 3000 13 | NUM_EXP = 1 14 | NUM_POISONED_WORKERS = 25 15 | REPLACEMENT_METHOD = replace_1_with_9 16 | KWARGS = { 17 | "BeforeBreakPoint_EPOCH" : 75, 18 | "BeforeBreakpoint_NUM_WORKERS_PER_ROUND" : 5, 19 | "AfterBreakPoint_EPOCH" : 75, 20 | "AfterBreakpoint_NUM_WORKERS_PER_ROUND" : 5, 21 | } 22 | 23 | for experiment_id in range(START_EXP_IDX, START_EXP_IDX + NUM_EXP): 24 | run_exp(REPLACEMENT_METHOD, NUM_POISONED_WORKERS, KWARGS, AfterBreakpoint(), experiment_id) 25 | -------------------------------------------------------------------------------- /client.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from sklearn.metrics import confusion_matrix 4 | from sklearn.metrics import classification_report 5 | from federated_learning.schedulers import MinCapableStepLR 6 | import os 7 | import numpy 8 | import copy 9 | 10 | class Client: 11 | 12 | def __init__(self, args, client_idx, train_data_loader, test_data_loader): 13 | """ 14 | :param args: experiment arguments 15 | :type args: Arguments 16 | :param client_idx: Client index 17 | :type client_idx: int 18 | :param train_data_loader: Training data loader 19 | :type train_data_loader: torch.utils.data.DataLoader 20 | :param test_data_loader: Test data loader 21 | :type test_data_loader: torch.utils.data.DataLoader 22 | """ 23 | self.args = args 24 | self.client_idx = client_idx 25 | 26 | self.device = self.initialize_device() 27 | self.set_net(self.load_default_model()) 28 | 29 | self.loss_function = self.args.get_loss_function()() 30 | self.optimizer = optim.SGD(self.net.parameters(), 31 | lr=self.args.get_learning_rate(), 32 | momentum=self.args.get_momentum()) 33 | self.scheduler = MinCapableStepLR(self.args.get_logger(), self.optimizer, 34 | self.args.get_scheduler_step_size(), 35 | self.args.get_scheduler_gamma(), 36 | self.args.get_min_lr()) 37 | 38 | self.train_data_loader = train_data_loader 39 | self.test_data_loader = test_data_loader 40 | 41 | def initialize_device(self): 42 | """ 43 | Creates appropriate torch device for client operation. 44 | """ 45 | if torch.cuda.is_available() and self.args.get_cuda(): 46 | return torch.device("cuda:0") 47 | else: 48 | return torch.device("cpu") 49 | 50 | def set_net(self, net): 51 | """ 52 | Set the client's NN. 53 | 54 | :param net: torch.nn 55 | """ 56 | self.net = net 57 | self.net.to(self.device) 58 | 59 | def load_default_model(self): 60 | """ 61 | Load a model from default model file. 62 | 63 | This is used to ensure consistent default model behavior. 64 | """ 65 | model_class = self.args.get_net() 66 | default_model_path = os.path.join(self.args.get_default_model_folder_path(), model_class.__name__ + ".model") 67 | 68 | return self.load_model_from_file(default_model_path) 69 | 70 | def load_model_from_file(self, model_file_path): 71 | """ 72 | Load a model from a file. 73 | 74 | :param model_file_path: string 75 | """ 76 | model_class = self.args.get_net() 77 | model = model_class() 78 | 79 | if os.path.exists(model_file_path): 80 | try: 81 | model.load_state_dict(torch.load(model_file_path)) 82 | except: 83 | self.args.get_logger().warning("Couldn't load model. Attempting to map CUDA tensors to CPU to solve error.") 84 | 85 | model.load_state_dict(torch.load(model_file_path, map_location=torch.device('cpu'))) 86 | else: 87 | self.args.get_logger().warning("Could not find model: {}".format(model_file_path)) 88 | 89 | return model 90 | 91 | def get_client_index(self): 92 | """ 93 | Returns the client index. 94 | """ 95 | return self.client_idx 96 | 97 | def get_nn_parameters(self): 98 | """ 99 | Return the NN's parameters. 100 | """ 101 | return self.net.state_dict() 102 | 103 | def update_nn_parameters(self, new_params): 104 | """ 105 | Update the NN's parameters. 
106 | 107 | :param new_params: New weights for the neural network 108 | :type new_params: dict 109 | """ 110 | self.net.load_state_dict(copy.deepcopy(new_params), strict=True) 111 | 112 | def train(self, epoch): 113 | """ 114 | :param epoch: Current epoch # 115 | :type epoch: int 116 | """ 117 | self.net.train() 118 | 119 | # save model 120 | if self.args.should_save_model(epoch): 121 | self.save_model(epoch, self.args.get_epoch_save_start_suffix()) 122 | 123 | running_loss = 0.0 124 | for i, (inputs, labels) in enumerate(self.train_data_loader, 0): 125 | inputs, labels = inputs.to(self.device), labels.to(self.device) 126 | 127 | # zero the parameter gradients 128 | self.optimizer.zero_grad() 129 | 130 | # forward + backward + optimize 131 | outputs = self.net(inputs) 132 | loss = self.loss_function(outputs, labels) 133 | loss.backward() 134 | self.optimizer.step() 135 | 136 | # print statistics 137 | running_loss += loss.item() 138 | if i % self.args.get_log_interval() == 0: 139 | self.args.get_logger().info('[%d, %5d] loss: %.3f' % (epoch, i, running_loss / self.args.get_log_interval())) 140 | 141 | running_loss = 0.0 142 | 143 | self.scheduler.step() 144 | 145 | # save model 146 | if self.args.should_save_model(epoch): 147 | self.save_model(epoch, self.args.get_epoch_save_end_suffix()) 148 | 149 | return running_loss 150 | 151 | def save_model(self, epoch, suffix): 152 | """ 153 | Saves the model if necessary. 154 | """ 155 | self.args.get_logger().debug("Saving model to flat file storage. Save #{}", epoch) 156 | 157 | if not os.path.exists(self.args.get_save_model_folder_path()): 158 | os.mkdir(self.args.get_save_model_folder_path()) 159 | 160 | full_save_path = os.path.join(self.args.get_save_model_folder_path(), "model_" + str(self.client_idx) + "_" + str(epoch) + "_" + suffix + ".model") 161 | torch.save(self.get_nn_parameters(), full_save_path) 162 | 163 | def calculate_class_precision(self, confusion_mat): 164 | """ 165 | Calculates the precision for each class from a confusion matrix. 166 | """ 167 | return numpy.diagonal(confusion_mat) / numpy.sum(confusion_mat, axis=0) 168 | 169 | def calculate_class_recall(self, confusion_mat): 170 | """ 171 | Calculates the recall for each class from a confusion matrix. 
172 | """ 173 | return numpy.diagonal(confusion_mat) / numpy.sum(confusion_mat, axis=1) 174 | 175 | def test(self): 176 | self.net.eval() 177 | 178 | correct = 0 179 | total = 0 180 | targets_ = [] 181 | pred_ = [] 182 | loss = 0.0 183 | with torch.no_grad(): 184 | for (images, labels) in self.test_data_loader: 185 | images, labels = images.to(self.device), labels.to(self.device) 186 | 187 | outputs = self.net(images) 188 | _, predicted = torch.max(outputs.data, 1) 189 | total += labels.size(0) 190 | correct += (predicted == labels).sum().item() 191 | 192 | targets_.extend(labels.cpu().view_as(predicted).numpy()) 193 | pred_.extend(predicted.cpu().numpy()) 194 | 195 | loss += self.loss_function(outputs, labels).item() 196 | 197 | accuracy = 100 * correct / total 198 | confusion_mat = confusion_matrix(targets_, pred_) 199 | 200 | class_precision = self.calculate_class_precision(confusion_mat) 201 | class_recall = self.calculate_class_recall(confusion_mat) 202 | 203 | self.args.get_logger().debug('Test set: Accuracy: {}/{} ({:.0f}%)'.format(correct, total, accuracy)) 204 | self.args.get_logger().debug('Test set: Loss: {}'.format(loss)) 205 | self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_)) 206 | self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat)) 207 | self.args.get_logger().debug("Class precision: {}".format(str(class_precision))) 208 | self.args.get_logger().debug("Class recall: {}".format(str(class_recall))) 209 | 210 | return accuracy, loss, class_precision, class_recall 211 | -------------------------------------------------------------------------------- /defense.py: -------------------------------------------------------------------------------- 1 | import os 2 | from loguru import logger 3 | from federated_learning.arguments import Arguments 4 | from federated_learning.dimensionality_reduction import calculate_pca_of_gradients 5 | from federated_learning.parameters import get_layer_parameters 6 | from federated_learning.parameters import calculate_parameter_gradients 7 | from federated_learning.utils import get_model_files_for_epoch 8 | from federated_learning.utils import get_model_files_for_suffix 9 | from federated_learning.utils import apply_standard_scaler 10 | from federated_learning.utils import get_worker_num_from_model_file_name 11 | from client import Client 12 | import matplotlib.pyplot as plt 13 | from mpl_toolkits.mplot3d import Axes3D 14 | 15 | # Paths you need to put in. 16 | MODELS_PATH = "/absolute/path/to/models/folder/1823_models" 17 | EXP_INFO_PATH = "/absolute/path/to/log/file/1823.log" 18 | 19 | # The epochs over which you are calculating gradients. 20 | EPOCHS = list(range(10, 200)) 21 | 22 | # The layer of the NNs that you want to investigate. 23 | # If you are using the provided Fashion MNIST CNN, this should be "fc.weight" 24 | # If you are using the provided Cifar 10 CNN, this should be "fc2.weight" 25 | LAYER_NAME = "fc2.weight" 26 | 27 | # The source class. 28 | CLASS_NUM = 4 29 | 30 | # The IDs for the poisoned workers. This needs to be manually filled out. 31 | # You can find this information at the beginning of an experiment's log file. 
32 | POISONED_WORKER_IDS = [] 33 | 34 | # The resulting graph is saved to a file 35 | SAVE_NAME = "defense_results.jpg" 36 | SAVE_SIZE = (18, 14) 37 | 38 | 39 | def load_models(args, model_filenames): 40 | clients = [] 41 | for model_filename in model_filenames: 42 | client = Client(args, 0, None, None) 43 | client.set_net(client.load_model_from_file(model_filename)) 44 | 45 | clients.append(client) 46 | 47 | return clients 48 | 49 | 50 | def plot_gradients_2d(gradients): 51 | fig = plt.figure() 52 | 53 | for (worker_id, gradient) in gradients: 54 | if worker_id in POISONED_WORKER_IDS: 55 | plt.scatter(gradient[0], gradient[1], color="blue", marker="x", s=1000, linewidth=5) 56 | else: 57 | plt.scatter(gradient[0], gradient[1], color="orange", s=180) 58 | 59 | fig.set_size_inches(SAVE_SIZE, forward=False) 60 | plt.grid(False) 61 | plt.margins(0,0) 62 | plt.savefig(SAVE_NAME, bbox_inches='tight', pad_inches=0.1) 63 | 64 | 65 | if __name__ == '__main__': 66 | args = Arguments(logger) 67 | args.log() 68 | 69 | model_files = sorted(os.listdir(MODELS_PATH)) 70 | logger.debug("Number of models: {}", str(len(model_files))) 71 | 72 | param_diff = [] 73 | worker_ids = [] 74 | 75 | for epoch in EPOCHS: 76 | start_model_files = get_model_files_for_epoch(model_files, epoch) 77 | start_model_file = get_model_files_for_suffix(start_model_files, args.get_epoch_save_start_suffix())[0] 78 | start_model_file = os.path.join(MODELS_PATH, start_model_file) 79 | start_model = load_models(args, [start_model_file])[0] 80 | 81 | start_model_layer_param = list(get_layer_parameters(start_model.get_nn_parameters(), LAYER_NAME)[CLASS_NUM]) 82 | 83 | end_model_files = get_model_files_for_epoch(model_files, epoch) 84 | end_model_files = get_model_files_for_suffix(end_model_files, args.get_epoch_save_end_suffix()) 85 | 86 | for end_model_file in end_model_files: 87 | worker_id = get_worker_num_from_model_file_name(end_model_file) 88 | end_model_file = os.path.join(MODELS_PATH, end_model_file) 89 | end_model = load_models(args, [end_model_file])[0] 90 | 91 | end_model_layer_param = list(get_layer_parameters(end_model.get_nn_parameters(), LAYER_NAME)[CLASS_NUM]) 92 | 93 | gradient = calculate_parameter_gradients(logger, start_model_layer_param, end_model_layer_param) 94 | gradient = gradient.flatten() 95 | 96 | param_diff.append(gradient) 97 | worker_ids.append(worker_id) 98 | 99 | logger.info("Gradients shape: ({}, {})".format(len(param_diff), param_diff[0].shape[0])) 100 | 101 | logger.info("Prescaled gradients: {}".format(str(param_diff))) 102 | scaled_param_diff = apply_standard_scaler(param_diff) 103 | logger.info("Postscaled gradients: {}".format(str(scaled_param_diff))) 104 | dim_reduced_gradients = calculate_pca_of_gradients(logger, scaled_param_diff, 2) 105 | logger.info("PCA reduced gradients: {}".format(str(dim_reduced_gradients))) 106 | 107 | logger.info("Dimensionally-reduced gradients shape: ({}, {})".format(len(dim_reduced_gradients), dim_reduced_gradients[0].shape[0])) 108 | 109 | plot_gradients_2d(zip(worker_ids, dim_reduced_gradients)) 110 | -------------------------------------------------------------------------------- /federated_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/git-disl/DataPoisoning_FL/467eca0de9c71c1e48f9a0f8894b74fa949fd251/federated_learning/__init__.py -------------------------------------------------------------------------------- /federated_learning/arguments.py: 
-------------------------------------------------------------------------------- 1 | from .nets import Cifar10CNN 2 | from .nets import FashionMNISTCNN 3 | from .worker_selection import BeforeBreakpoint 4 | from .worker_selection import AfterBreakpoint 5 | from .worker_selection import PoisonerProbability 6 | import torch.nn.functional as F 7 | import torch 8 | import json 9 | 10 | # Setting the seed for Torch 11 | SEED = 1 12 | torch.manual_seed(SEED) 13 | 14 | class Arguments: 15 | 16 | def __init__(self, logger): 17 | self.logger = logger 18 | 19 | self.batch_size = 10 20 | self.test_batch_size = 1000 21 | self.epochs = 10 22 | self.lr = 0.01 23 | self.momentum = 0.5 24 | self.cuda = True 25 | self.shuffle = False 26 | self.log_interval = 100 27 | self.kwargs = {} 28 | 29 | self.scheduler_step_size = 50 30 | self.scheduler_gamma = 0.5 31 | self.min_lr = 1e-10 32 | 33 | self.round_worker_selection_strategy = None 34 | self.round_worker_selection_strategy_kwargs = None 35 | 36 | self.save_model = False 37 | self.save_epoch_interval = 1 38 | self.save_model_path = "models" 39 | self.epoch_save_start_suffix = "start" 40 | self.epoch_save_end_suffix = "end" 41 | 42 | self.num_workers = 50 43 | self.num_poisoned_workers = 0 44 | 45 | #self.net = Cifar10CNN 46 | self.net = FashionMNISTCNN 47 | 48 | self.train_data_loader_pickle_path = "data_loaders/fashion-mnist/train_data_loader.pickle" 49 | self.test_data_loader_pickle_path = "data_loaders/fashion-mnist/test_data_loader.pickle" 50 | 51 | self.loss_function = torch.nn.CrossEntropyLoss 52 | 53 | self.default_model_folder_path = "default_models" 54 | 55 | self.data_path = "data" 56 | 57 | def get_round_worker_selection_strategy(self): 58 | return self.round_worker_selection_strategy 59 | 60 | def get_round_worker_selection_strategy_kwargs(self): 61 | return self.round_worker_selection_strategy_kwargs 62 | 63 | def set_round_worker_selection_strategy_kwargs(self, kwargs): 64 | self.round_worker_selection_strategy_kwargs = kwargs 65 | 66 | def set_client_selection_strategy(self, strategy): 67 | self.round_worker_selection_strategy = strategy 68 | 69 | def get_data_path(self): 70 | return self.data_path 71 | 72 | def get_epoch_save_start_suffix(self): 73 | return self.epoch_save_start_suffix 74 | 75 | def get_epoch_save_end_suffix(self): 76 | return self.epoch_save_end_suffix 77 | 78 | def set_train_data_loader_pickle_path(self, path): 79 | self.train_data_loader_pickle_path = path 80 | 81 | def get_train_data_loader_pickle_path(self): 82 | return self.train_data_loader_pickle_path 83 | 84 | def set_test_data_loader_pickle_path(self, path): 85 | self.test_data_loader_pickle_path = path 86 | 87 | def get_test_data_loader_pickle_path(self): 88 | return self.test_data_loader_pickle_path 89 | 90 | def get_cuda(self): 91 | return self.cuda 92 | 93 | def get_scheduler_step_size(self): 94 | return self.scheduler_step_size 95 | 96 | def get_scheduler_gamma(self): 97 | return self.scheduler_gamma 98 | 99 | def get_min_lr(self): 100 | return self.min_lr 101 | 102 | def get_default_model_folder_path(self): 103 | return self.default_model_folder_path 104 | 105 | def get_num_epochs(self): 106 | return self.epochs 107 | 108 | def set_num_poisoned_workers(self, num_poisoned_workers): 109 | self.num_poisoned_workers = num_poisoned_workers 110 | 111 | def set_num_workers(self, num_workers): 112 | self.num_workers = num_workers 113 | 114 | def set_model_save_path(self, save_model_path): 115 | self.save_model_path = save_model_path 116 | 117 | def get_logger(self): 
118 | return self.logger 119 | 120 | def get_loss_function(self): 121 | return self.loss_function 122 | 123 | def get_net(self): 124 | return self.net 125 | 126 | def get_num_workers(self): 127 | return self.num_workers 128 | 129 | def get_num_poisoned_workers(self): 130 | return self.num_poisoned_workers 131 | 132 | def get_learning_rate(self): 133 | return self.lr 134 | 135 | def get_momentum(self): 136 | return self.momentum 137 | 138 | def get_shuffle(self): 139 | return self.shuffle 140 | 141 | def get_batch_size(self): 142 | return self.batch_size 143 | 144 | def get_test_batch_size(self): 145 | return self.test_batch_size 146 | 147 | def get_log_interval(self): 148 | return self.log_interval 149 | 150 | def get_save_model_folder_path(self): 151 | return self.save_model_path 152 | 153 | def get_learning_rate_from_epoch(self, epoch_idx): 154 | lr = self.lr * (self.scheduler_gamma ** int(epoch_idx / self.scheduler_step_size)) 155 | 156 | if lr < self.min_lr: 157 | self.logger.warning("Updating LR would place it below min LR. Skipping LR update.") 158 | 159 | return self.min_lr 160 | 161 | self.logger.debug("LR: {}".format(lr)) 162 | 163 | return lr 164 | 165 | def should_save_model(self, epoch_idx): 166 | """ 167 | Returns True/False depending on whether the model should be saved at this epoch. 168 | 169 | :param epoch_idx: current training epoch index 170 | :type epoch_idx: int 171 | """ 172 | if not self.save_model: 173 | return False 174 | 175 | return epoch_idx == 1 or epoch_idx % self.save_epoch_interval == 0  # explicit boolean so callers never receive None 176 | 177 | 178 | def log(self): 179 | """ 180 | Log this arguments object to the logger. 181 | """ 182 | self.logger.debug("Arguments: {}", str(self)) 183 | 184 | def __str__(self): 185 | return "\nBatch Size: {}\n".format(self.batch_size) + \ 186 | "Test Batch Size: {}\n".format(self.test_batch_size) + \ 187 | "Epochs: {}\n".format(self.epochs) + \ 188 | "Learning Rate: {}\n".format(self.lr) + \ 189 | "Momentum: {}\n".format(self.momentum) + \ 190 | "CUDA Enabled: {}\n".format(self.cuda) + \ 191 | "Shuffle Enabled: {}\n".format(self.shuffle) + \ 192 | "Log Interval: {}\n".format(self.log_interval) + \ 193 | "Scheduler Step Size: {}\n".format(self.scheduler_step_size) + \ 194 | "Scheduler Gamma: {}\n".format(self.scheduler_gamma) + \ 195 | "Scheduler Minimum Learning Rate: {}\n".format(self.min_lr) + \ 196 | "Client Selection Strategy: {}\n".format(self.round_worker_selection_strategy) + \ 197 | "Client Selection Strategy Arguments: {}\n".format(json.dumps(self.round_worker_selection_strategy_kwargs, indent=4, sort_keys=True)) + \ 198 | "Model Saving Enabled: {}\n".format(self.save_model) + \ 199 | "Model Saving Interval: {}\n".format(self.save_epoch_interval) + \ 200 | "Model Saving Path (Relative): {}\n".format(self.save_model_path) + \ 201 | "Epoch Save Start Suffix: {}\n".format(self.epoch_save_start_suffix) + \ 202 | "Epoch Save End Suffix: {}\n".format(self.epoch_save_end_suffix) + \ 203 | "Number of Clients: {}\n".format(self.num_workers) + \ 204 | "Number of Poisoned Clients: {}\n".format(self.num_poisoned_workers) + \ 205 | "NN: {}\n".format(self.net) + \ 206 | "Train Data Loader Path: {}\n".format(self.train_data_loader_pickle_path) + \ 207 | "Test Data Loader Path: {}\n".format(self.test_data_loader_pickle_path) + \ 208 | "Loss Function: {}\n".format(self.loss_function) + \ 209 | "Default Model Folder Path: {}\n".format(self.default_model_folder_path) + \ 210 | "Data Path: {}\n".format(self.data_path) 211 | --------------------------------------------------------------------------------
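A brief usage sketch (not part of the repository) of the ```Arguments``` class above: ```defense.py``` instantiates it the same way, passing a ```loguru``` logger. The loop illustrates ```get_learning_rate_from_epoch```, the closed-form mirror of ```MinCapableStepLR```: lr * gamma^(epoch // step_size), floored at min_lr.

```
from loguru import logger
from federated_learning.arguments import Arguments

args = Arguments(logger)
args.log()  # dumps the full configuration via __str__

for epoch in [1, 50, 100, 150]:
    # with the defaults above: 0.01 * 0.5 ** (epoch // 50), never below min_lr
    print(epoch, args.get_learning_rate_from_epoch(epoch))
```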
/federated_learning/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from .cifar10 import CIFAR10Dataset 3 | from .fashion_mnist import FashionMNISTDataset 4 | -------------------------------------------------------------------------------- /federated_learning/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from torchvision import datasets 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | 6 | class CIFAR10Dataset(Dataset): 7 | 8 | def __init__(self, args): 9 | super(CIFAR10Dataset, self).__init__(args) 10 | 11 | def load_train_dataset(self): 12 | self.get_args().get_logger().debug("Loading CIFAR10 train data") 13 | 14 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 15 | transform = transforms.Compose([ 16 | transforms.RandomHorizontalFlip(), 17 | transforms.RandomCrop(32, 4), 18 | transforms.ToTensor(), 19 | normalize 20 | ]) 21 | train_dataset = datasets.CIFAR10(root=self.get_args().get_data_path(), train=True, download=True, transform=transform) 22 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 23 | 24 | train_data = self.get_tuple_from_data_loader(train_loader) 25 | 26 | self.get_args().get_logger().debug("Finished loading CIFAR10 train data") 27 | 28 | return train_data 29 | 30 | def load_test_dataset(self): 31 | self.get_args().get_logger().debug("Loading CIFAR10 test data") 32 | 33 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 34 | transform = transforms.Compose([ 35 | transforms.ToTensor(), 36 | normalize 37 | ]) 38 | test_dataset = datasets.CIFAR10(root=self.get_args().get_data_path(), train=False, download=True, transform=transform) 39 | test_loader = DataLoader(test_dataset, batch_size=len(test_dataset)) 40 | 41 | test_data = self.get_tuple_from_data_loader(test_loader) 42 | 43 | self.get_args().get_logger().debug("Finished loading CIFAR10 test data") 44 | 45 | return test_data 46 | -------------------------------------------------------------------------------- /federated_learning/datasets/data_distribution/__init__.py: -------------------------------------------------------------------------------- 1 | from .iid_equal import distribute_batches_equally 2 | -------------------------------------------------------------------------------- /federated_learning/datasets/data_distribution/iid_equal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def distribute_batches_equally(train_data_loader, num_workers): 4 | """ 5 | Gives each worker the same number of batches of training data. 
6 | 7 | :param train_data_loader: Training data loader 8 | :type train_data_loader: torch.utils.data.DataLoader 9 | :param num_workers: number of workers 10 | :type num_workers: int 11 | """ 12 | distributed_dataset = [[] for i in range(num_workers)] 13 | 14 | for batch_idx, (data, target) in enumerate(train_data_loader): 15 | worker_idx = batch_idx % num_workers 16 | 17 | distributed_dataset[worker_idx].append((data, target)) 18 | 19 | return distributed_dataset 20 | -------------------------------------------------------------------------------- /federated_learning/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import TensorDataset 4 | import torch 5 | import numpy 6 | 7 | class Dataset: 8 | 9 | def __init__(self, args): 10 | self.args = args 11 | 12 | self.train_dataset = self.load_train_dataset() 13 | self.test_dataset = self.load_test_dataset() 14 | 15 | def get_args(self): 16 | """ 17 | Returns the arguments. 18 | 19 | :return: Arguments 20 | """ 21 | return self.args 22 | 23 | def get_train_dataset(self): 24 | """ 25 | Returns the train dataset. 26 | 27 | :return: tuple 28 | """ 29 | return self.train_dataset 30 | 31 | def get_test_dataset(self): 32 | """ 33 | Returns the test dataset. 34 | 35 | :return: tuple 36 | """ 37 | return self.test_dataset 38 | 39 | @abstractmethod 40 | def load_train_dataset(self): 41 | """ 42 | Loads & returns the training dataset. 43 | 44 | :return: tuple 45 | """ 46 | raise NotImplementedError("load_train_dataset() isn't implemented") 47 | 48 | @abstractmethod 49 | def load_test_dataset(self): 50 | """ 51 | Loads & returns the test dataset. 52 | 53 | :return: tuple 54 | """ 55 | raise NotImplementedError("load_test_dataset() isn't implemented") 56 | 57 | def get_train_loader(self, batch_size, **kwargs): 58 | """ 59 | Return the data loader for the train dataset. 60 | 61 | :param batch_size: batch size of data loader 62 | :type batch_size: int 63 | :return: torch.utils.data.DataLoader 64 | """ 65 | return Dataset.get_data_loader_from_data(batch_size, self.train_dataset[0], self.train_dataset[1], **kwargs) 66 | 67 | def get_test_loader(self, batch_size, **kwargs): 68 | """ 69 | Return the data loader for the test dataset. 70 | 71 | :param batch_size: batch size of data loader 72 | :type batch_size: int 73 | :return: torch.utils.data.DataLoader 74 | """ 75 | return Dataset.get_data_loader_from_data(batch_size, self.test_dataset[0], self.test_dataset[1], **kwargs) 76 | 77 | @staticmethod 78 | def get_data_loader_from_data(batch_size, X, Y, **kwargs): 79 | """ 80 | Get a data loader created from a given set of data. 
81 | 82 | :param batch_size: batch size of data loader 83 | :type batch_size: int 84 | :param X: data features 85 | :type X: numpy.Array() 86 | :param Y: data labels 87 | :type Y: numpy.Array() 88 | :return: torch.utils.data.DataLoader 89 | """ 90 | X_torch = torch.from_numpy(X).float() 91 | 92 | if "classification_problem" in kwargs and kwargs["classification_problem"] == False: 93 | Y_torch = torch.from_numpy(Y).float() 94 | else: 95 | Y_torch = torch.from_numpy(Y).long() 96 | dataset = TensorDataset(X_torch, Y_torch) 97 | 98 | kwargs.pop("classification_problem", None) 99 | 100 | return DataLoader(dataset, batch_size=batch_size, **kwargs) 101 | 102 | @staticmethod 103 | def get_tuple_from_data_loader(data_loader): 104 | """ 105 | Get a tuple representation of the data stored in a data loader. 106 | 107 | :param data_loader: data loader to get data from 108 | :type data_loader: torch.utils.data.DataLoader 109 | :return: tuple 110 | """ 111 | return (next(iter(data_loader))[0].numpy(), next(iter(data_loader))[1].numpy()) 112 | -------------------------------------------------------------------------------- /federated_learning/datasets/fashion_mnist.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from torchvision import datasets 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | 6 | class FashionMNISTDataset(Dataset): 7 | 8 | def __init__(self, args): 9 | super(FashionMNISTDataset, self).__init__(args) 10 | 11 | def load_train_dataset(self): 12 | self.get_args().get_logger().debug("Loading Fashion MNIST train data") 13 | 14 | train_dataset = datasets.FashionMNIST(self.get_args().get_data_path(), train=True, download=True, transform=transforms.Compose([transforms.ToTensor()])) 15 | train_loader = DataLoader(train_dataset, batch_size=len(train_dataset)) 16 | 17 | train_data = self.get_tuple_from_data_loader(train_loader) 18 | 19 | self.get_args().get_logger().debug("Finished loading Fashion MNIST train data") 20 | 21 | return train_data 22 | 23 | def load_test_dataset(self): 24 | self.get_args().get_logger().debug("Loading Fashion MNIST test data") 25 | 26 | test_dataset = datasets.FashionMNIST(self.get_args().get_data_path(), train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])) 27 | test_loader = DataLoader(test_dataset, batch_size=len(test_dataset)) 28 | 29 | test_data = self.get_tuple_from_data_loader(test_loader) 30 | 31 | self.get_args().get_logger().debug("Finished loading Fashion MNIST test data") 32 | 33 | return test_data 34 | -------------------------------------------------------------------------------- /federated_learning/dimensionality_reduction/__init__.py: -------------------------------------------------------------------------------- 1 | from .pca import calculate_pca_of_gradients 2 | -------------------------------------------------------------------------------- /federated_learning/dimensionality_reduction/pca.py: -------------------------------------------------------------------------------- 1 | from sklearn.decomposition import PCA 2 | 3 | def calculate_pca_of_gradients(logger, gradients, num_components): 4 | pca = PCA(n_components=num_components) 5 | 6 | logger.info("Computing {}-component PCA of gradients".format(num_components)) 7 | 8 | return pca.fit_transform(gradients) 9 | -------------------------------------------------------------------------------- /federated_learning/nets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .cifar_10_cnn import Cifar10CNN 2 | from .fashion_mnist_cnn import FashionMNISTCNN 3 | -------------------------------------------------------------------------------- /federated_learning/nets/cifar_10_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Cifar10CNN(nn.Module): 6 | 7 | def __init__(self): 8 | super(Cifar10CNN, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) 11 | self.bn1 = nn.BatchNorm2d(32) 12 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 13 | self.bn2 = nn.BatchNorm2d(32) 14 | self.pool1 = nn.MaxPool2d(kernel_size=2) 15 | 16 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 17 | self.bn3 = nn.BatchNorm2d(64) 18 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 19 | self.bn4 = nn.BatchNorm2d(64) 20 | self.pool2 = nn.MaxPool2d(kernel_size=2) 21 | 22 | self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 23 | self.bn5 = nn.BatchNorm2d(128) 24 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 25 | self.bn6 = nn.BatchNorm2d(128) 26 | self.pool3 = nn.MaxPool2d(kernel_size=2) 27 | 28 | self.fc1 = nn.Linear(128 * 4 * 4, 128) 29 | self.fc2 = nn.Linear(128, 10) 30 | 31 | def forward(self, x): 32 | x = self.bn1(F.relu(self.conv1(x))) 33 | x = self.bn2(F.relu(self.conv2(x))) 34 | x = self.pool1(x) 35 | 36 | x = self.bn3(F.relu(self.conv3(x))) 37 | x = self.bn4(F.relu(self.conv4(x))) 38 | x = self.pool2(x) 39 | 40 | x = self.bn5(F.relu(self.conv5(x))) 41 | x = self.bn6(F.relu(self.conv6(x))) 42 | x = self.pool3(x) 43 | 44 | x = x.view(-1, 128 * 4 * 4) 45 | 46 | x = self.fc1(x) 47 | x = F.softmax(self.fc2(x), dim=1)  # dim made explicit (implicit-dim softmax is deprecated); note this is paired with CrossEntropyLoss in arguments.py, which applies log-softmax itself 48 | 49 | return x 50 | -------------------------------------------------------------------------------- /federated_learning/nets/fashion_mnist_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FashionMNISTCNN(nn.Module): 6 | 7 | def __init__(self): 8 | super(FashionMNISTCNN, self).__init__() 9 | 10 | self.layer1 = nn.Sequential( 11 | nn.Conv2d(1, 16, kernel_size=5, padding=2), 12 | nn.BatchNorm2d(16), 13 | nn.ReLU(), 14 | nn.MaxPool2d(2)) 15 | self.layer2 = nn.Sequential( 16 | nn.Conv2d(16, 32, kernel_size=5, padding=2), 17 | nn.BatchNorm2d(32), 18 | nn.ReLU(), 19 | nn.MaxPool2d(2)) 20 | 21 | self.fc = nn.Linear(7*7*32, 10) 22 | 23 | def forward(self, x): 24 | x = self.layer1(x) 25 | x = self.layer2(x) 26 | 27 | x = x.view(x.size(0), -1) 28 | 29 | x = self.fc(x) 30 | 31 | return x 32 | -------------------------------------------------------------------------------- /federated_learning/parameters/__init__.py: -------------------------------------------------------------------------------- 1 | from .gradients import calculate_model_gradient 2 | from .gradients import calculate_parameter_gradients 3 | from .parameter_processors import get_layer_parameters 4 | from .log_utils import log_model_parameter_names 5 | from .model_comparison import compare_models 6 | -------------------------------------------------------------------------------- /federated_learning/parameters/gradients.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def calculate_model_gradient(logger, model_1, model_2): 4 | """ 5 |
Calculates the gradient (parameter difference) between two Torch models. 6 | 7 | :param logger: loguru.logger 8 | :param model_1: torch.nn 9 | :param model_2: torch.nn 10 | """ 11 | model_1_parameters = list(model_1.state_dict().values())  # .values(), not list(dict(...)): the latter yields only the parameter names 12 | model_2_parameters = list(model_2.state_dict().values()) 13 | 14 | return calculate_parameter_gradients(logger, model_1_parameters, model_2_parameters) 15 | 16 | def calculate_parameter_gradients(logger, params_1, params_2): 17 | """ 18 | Calculates the gradient (parameter difference) between two sets of Torch parameters. 19 | 20 | :param logger: loguru.logger 21 | :param params_1: list(torch.Tensor) 22 | :param params_2: list(torch.Tensor) 23 | """ 24 | logger.debug("Shape of model_1_parameters: {}".format(str(len(params_1)))) 25 | logger.debug("Shape of model_2_parameters: {}".format(str(len(params_2)))) 26 | 27 | return numpy.array([x for x in numpy.subtract(params_1, params_2)]) 28 | -------------------------------------------------------------------------------- /federated_learning/parameters/log_utils.py: -------------------------------------------------------------------------------- 1 | def log_model_parameter_names(logger, parameters): 2 | """ 3 | :param logger: loguru.logger 4 | :param parameters: dict(tensor) 5 | """ 6 | logger.info("Model Parameter Names: {}".format(parameters.keys())) 7 | -------------------------------------------------------------------------------- /federated_learning/parameters/model_comparison.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compare_models(logger, model_1, model_2): 4 | models_differ = 0 5 | for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): 6 | if not torch.equal(key_item_1[1], key_item_2[1]): 7 | models_differ += 1 8 | if (key_item_1[0] == key_item_2[0]): 9 | logger.error('Mismatch found at {}', key_item_1[0]) 10 | logger.debug("Model 1 value: {}", str(key_item_1[1])) 11 | logger.debug("Model 2 value: {}", str(key_item_2[1])) 12 | else: 13 | raise Exception("Model architectures do not match: parameter names differ") 14 | if models_differ == 0: 15 | logger.info('Models match perfectly!') 16 | -------------------------------------------------------------------------------- /federated_learning/parameters/parameter_processors.py: -------------------------------------------------------------------------------- 1 | def get_layer_parameters(parameters, layer_name): 2 | """ 3 | Get a specific layer of parameters from a parameters object.
4 | 5 | :param parameters: dict(tensor) 6 | :param layer_name: string 7 | """ 8 | return parameters[layer_name] 9 | -------------------------------------------------------------------------------- /federated_learning/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .min_lr_step import MinCapableStepLR 2 | -------------------------------------------------------------------------------- /federated_learning/schedulers/min_lr_step.py: -------------------------------------------------------------------------------- 1 | class MinCapableStepLR: 2 | 3 | def __init__(self, logger, optimizer, step_size, gamma, min_lr): 4 | """ 5 | :param logger: logger 6 | :type logger: loguru.logger 7 | :param optimizer: 8 | :type optimizer: torch.optim 9 | :param step_size: # of epochs between LR updates 10 | :type step_size: int 11 | :param gamma: multiplication factor for LR update 12 | :type gamma: float 13 | :param min_lr: minimum learning rate 14 | :type min_lr: float 15 | """ 16 | self.logger = logger 17 | 18 | self.optimizer = optimizer 19 | self.step_size = step_size 20 | self.gamma = gamma 21 | self.min_lr = min_lr 22 | 23 | self.epoch_idx = 0 24 | 25 | def step(self): 26 | """ 27 | Adjust the learning rate as necessary. 28 | """ 29 | self.increment_epoch_index() 30 | 31 | if self.is_time_to_update_lr(): 32 | self.logger.debug("Updating LR for optimizer") 33 | 34 | self.update_lr() 35 | 36 | def is_time_to_update_lr(self): 37 | return self.epoch_idx % self.step_size == 0 38 | 39 | def update_lr(self): 40 | if self.optimizer.param_groups[0]['lr'] * self.gamma >= self.min_lr: 41 | self.optimizer.param_groups[0]['lr'] *= self.gamma 42 | else: 43 | self.logger.warning("Updating LR would place it below min LR. 
Skipping LR update.") 44 | 45 | self.logger.debug("New LR: {}".format(self.optimizer.param_groups[0]['lr'])) 46 | 47 | def increment_epoch_index(self): 48 | self.epoch_idx += 1 49 | -------------------------------------------------------------------------------- /federated_learning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .class_flipping_methods import * 2 | from .label_replacement import apply_class_label_replacement 3 | from .tensor_converter import convert_distributed_data_into_numpy 4 | from .identify_random_elements import identify_random_elements 5 | from .file_storage_utils import save_results 6 | from .file_storage_utils import read_results 7 | from .file_storage_utils import save_results_v2 8 | from .file_storage_utils import read_results_v2 9 | from .file_storage_utils import generate_json_repr_for_worker 10 | from .file_storage_utils import convert_test_results_to_json 11 | from .data_loader_utils import generate_data_loaders_from_distributed_dataset 12 | from .data_loader_utils import load_train_data_loader 13 | from .data_loader_utils import load_test_data_loader 14 | from .data_loader_utils import generate_train_loader 15 | from .data_loader_utils import load_data_loader_from_file 16 | from .data_loader_utils import generate_test_loader 17 | from .data_loader_utils import save_data_loader_to_file 18 | from .fed_avg import average_nn_parameters 19 | from .client_utils import log_client_data_statistics 20 | from .poison_data import poison_data 21 | from .model_list_parser import * 22 | from .apply_scalers import apply_standard_scaler 23 | from .experiment_ids import generate_experiment_ids 24 | from .csv_utils import convert_results_to_csv 25 | -------------------------------------------------------------------------------- /federated_learning/utils/apply_scalers.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import StandardScaler 2 | 3 | def apply_standard_scaler(gradients): 4 | scaler = StandardScaler() 5 | 6 | return scaler.fit_transform(gradients) 7 | -------------------------------------------------------------------------------- /federated_learning/utils/class_flipping_methods.py: -------------------------------------------------------------------------------- 1 | def default_no_change(targets, target_set): 2 | """ 3 | :param targets: Target class IDs 4 | :type targets: list 5 | :param target_set: Set of class IDs possible 6 | :type target_set: list 7 | :return: new class IDs 8 | """ 9 | return targets 10 | 11 | def replace_0_with_9(targets, target_set): 12 | """ 13 | :param targets: Target class IDs 14 | :type targets: list 15 | :param target_set: Set of class IDs possible 16 | :type target_set: list 17 | :return: new class IDs 18 | """ 19 | for idx in range(len(targets)): 20 | if targets[idx] == 0: 21 | targets[idx] = 9 22 | 23 | return targets 24 | 25 | def replace_0_with_6(targets, target_set): 26 | """ 27 | :param targets: Target class IDs 28 | :type targets: list 29 | :param target_set: Set of class IDs possible 30 | :type target_set: list 31 | :return: new class IDs 32 | """ 33 | for idx in range(len(targets)): 34 | if targets[idx] == 0: 35 | targets[idx] = 6 36 | 37 | return targets 38 | 39 | def replace_4_with_6(targets, target_set): 40 | """ 41 | :param targets: Target class IDs 42 | :type targets: list 43 | :param target_set: Set of class IDs possible 44 | :type target_set: list 45 | :return: new class IDs 46 | """ 47 | for 
idx in range(len(targets)): 48 | if targets[idx] == 4: 49 | targets[idx] = 6 50 | 51 | return targets 52 | 53 | def replace_1_with_3(targets, target_set): 54 | """ 55 | :param targets: Target class IDs 56 | :type targets: list 57 | :param target_set: Set of class IDs possible 58 | :type target_set: list 59 | :return: new class IDs 60 | """ 61 | for idx in range(len(targets)): 62 | if targets[idx] == 1: 63 | targets[idx] = 3 64 | 65 | return targets 66 | 67 | def replace_1_with_0(targets, target_set): 68 | """ 69 | :param targets: Target class IDs 70 | :type targets: list 71 | :param target_set: Set of class IDs possible 72 | :type target_set: list 73 | :return: new class IDs 74 | """ 75 | for idx in range(len(targets)): 76 | if targets[idx] == 1: 77 | targets[idx] = 0 78 | 79 | return targets 80 | 81 | def replace_2_with_3(targets, target_set): 82 | """ 83 | :param targets: Target class IDs 84 | :type targets: list 85 | :param target_set: Set of class IDs possible 86 | :type target_set: list 87 | :return: new class IDs 88 | """ 89 | for idx in range(len(targets)): 90 | if targets[idx] == 2: 91 | targets[idx] = 3 92 | 93 | return targets 94 | 95 | def replace_2_with_7(targets, target_set): 96 | """ 97 | :param targets: Target class IDs 98 | :type targets: list 99 | :param target_set: Set of class IDs possible 100 | :type target_set: list 101 | :return: new class IDs 102 | """ 103 | for idx in range(len(targets)): 104 | if targets[idx] == 2: 105 | targets[idx] = 7 106 | 107 | return targets 108 | 109 | def replace_3_with_9(targets, target_set): 110 | """ 111 | :param targets: Target class IDs 112 | :type targets: list 113 | :param target_set: Set of class IDs possible 114 | :type target_set: list 115 | :return: new class IDs 116 | """ 117 | for idx in range(len(targets)): 118 | if targets[idx] == 3: 119 | targets[idx] = 9 120 | 121 | return targets 122 | 123 | def replace_3_with_7(targets, target_set): 124 | """ 125 | :param targets: Target class IDs 126 | :type targets: list 127 | :param target_set: Set of class IDs possible 128 | :type target_set: list 129 | :return: new class IDs 130 | """ 131 | for idx in range(len(targets)): 132 | if targets[idx] == 3: 133 | targets[idx] = 7 134 | 135 | return targets 136 | 137 | def replace_4_with_9(targets, target_set): 138 | """ 139 | :param targets: Target class IDs 140 | :type targets: list 141 | :param target_set: Set of class IDs possible 142 | :type target_set: list 143 | :return: new class IDs 144 | """ 145 | for idx in range(len(targets)): 146 | if targets[idx] == 4: 147 | targets[idx] = 9 148 | 149 | return targets 150 | 151 | def replace_4_with_1(targets, target_set): 152 | """ 153 | :param targets: Target class IDs 154 | :type targets: list 155 | :param target_set: Set of class IDs possible 156 | :type target_set: list 157 | :return: new class IDs 158 | """ 159 | for idx in range(len(targets)): 160 | if targets[idx] == 4: 161 | targets[idx] = 1 162 | 163 | return targets 164 | 165 | def replace_5_with_3(targets, target_set): 166 | """ 167 | :param targets: Target class IDs 168 | :type targets: list 169 | :param target_set: Set of class IDs possible 170 | :type target_set: list 171 | :return: new class IDs 172 | """ 173 | for idx in range(len(targets)): 174 | if targets[idx] == 5: 175 | targets[idx] = 3 176 | 177 | return targets 178 | 179 | def replace_1_with_9(targets, target_set): 180 | """ 181 | :param targets: Target class IDs 182 | :type targets: list 183 | :param target_set: Set of class IDs possible 184 | :type target_set: list 185 | 
:return: new class IDs 186 | """ 187 | for idx in range(len(targets)): 188 | if targets[idx] == 1: 189 | targets[idx] = 9 190 | 191 | return targets 192 | 193 | def replace_0_with_2(targets, target_set): 194 | """ 195 | :param targets: Target class IDs 196 | :type targets: list 197 | :param target_set: Set of class IDs possible 198 | :type target_set: list 199 | :return: new class IDs 200 | """ 201 | for idx in range(len(targets)): 202 | if targets[idx] == 0: 203 | targets[idx] = 2 204 | 205 | return targets 206 | 207 | def replace_5_with_9(targets, target_set): 208 | """ 209 | :param targets: Target class IDs 210 | :type targets: list 211 | :param target_set: Set of class IDs possible 212 | :type target_set: list 213 | :return: new class IDs 214 | """ 215 | for idx in range(len(targets)): 216 | if targets[idx] == 5: 217 | targets[idx] = 9 218 | 219 | return targets 220 | 221 | def replace_5_with_7(targets, target_set): 222 | """ 223 | :param targets: Target class IDs 224 | :type targets: list 225 | :param target_set: Set of class IDs possible 226 | :type target_set: list 227 | :return: new class IDs 228 | """ 229 | for idx in range(len(targets)): 230 | if targets[idx] == 5: 231 | targets[idx] = 7 232 | 233 | return targets 234 | 235 | def replace_6_with_3(targets, target_set): 236 | """ 237 | :param targets: Target class IDs 238 | :type targets: list 239 | :param target_set: Set of class IDs possible 240 | :type target_set: list 241 | :return: new class IDs 242 | """ 243 | for idx in range(len(targets)): 244 | if targets[idx] == 6: 245 | targets[idx] = 3 246 | 247 | return targets 248 | 249 | def replace_6_with_0(targets, target_set): 250 | """ 251 | :param targets: Target class IDs 252 | :type targets: list 253 | :param target_set: Set of class IDs possible 254 | :type target_set: list 255 | :return: new class IDs 256 | """ 257 | for idx in range(len(targets)): 258 | if targets[idx] == 6: 259 | targets[idx] = 0 260 | 261 | return targets 262 | 263 | def replace_6_with_7(targets, target_set): 264 | """ 265 | :param targets: Target class IDs 266 | :type targets: list 267 | :param target_set: Set of class IDs possible 268 | :type target_set: list 269 | :return: new class IDs 270 | """ 271 | for idx in range(len(targets)): 272 | if targets[idx] == 6: 273 | targets[idx] = 7 274 | 275 | return targets 276 | 277 | def replace_7_with_9(targets, target_set): 278 | """ 279 | :param targets: Target class IDs 280 | :type targets: list 281 | :param target_set: Set of class IDs possible 282 | :type target_set: list 283 | :return: new class IDs 284 | """ 285 | for idx in range(len(targets)): 286 | if targets[idx] == 7: 287 | targets[idx] = 9 288 | 289 | return targets 290 | 291 | def replace_7_with_1(targets, target_set): 292 | """ 293 | :param targets: Target class IDs 294 | :type targets: list 295 | :param target_set: Set of class IDs possible 296 | :type target_set: list 297 | :return: new class IDs 298 | """ 299 | for idx in range(len(targets)): 300 | if targets[idx] == 7: 301 | targets[idx] = 1 302 | 303 | return targets 304 | 305 | def replace_8_with_9(targets, target_set): 306 | """ 307 | :param targets: Target class IDs 308 | :type targets: list 309 | :param target_set: Set of class IDs possible 310 | :type target_set: list 311 | :return: new class IDs 312 | """ 313 | for idx in range(len(targets)): 314 | if targets[idx] == 8: 315 | targets[idx] = 9 316 | 317 | return targets 318 | 319 | def replace_8_with_6(targets, target_set): 320 | """ 321 | :param targets: Target class IDs 322 | :type 
targets: list 323 | :param target_set: Set of class IDs possible 324 | :type target_set: list 325 | :return: new class IDs 326 | """ 327 | for idx in range(len(targets)): 328 | if targets[idx] == 8: 329 | targets[idx] = 6 330 | 331 | return targets 332 | 333 | def replace_9_with_3(targets, target_set): 334 | """ 335 | :param targets: Target class IDs 336 | :type targets: list 337 | :param target_set: Set of class IDs possible 338 | :type target_set: list 339 | :return: new class IDs 340 | """ 341 | for idx in range(len(targets)): 342 | if targets[idx] == 9: 343 | targets[idx] = 3 344 | 345 | return targets 346 | 347 | def replace_9_with_7(targets, target_set): 348 | """ 349 | :param targets: Target class IDs 350 | :type targets: list 351 | :param target_set: Set of class IDs possible 352 | :type target_set: list 353 | :return: new class IDs 354 | """ 355 | for idx in range(len(targets)): 356 | if targets[idx] == 9: 357 | targets[idx] = 7 358 | 359 | return targets 360 | 361 | def replace_0_with_9_1_with_3(targets, target_set): 362 | """ 363 | :param targets: Target class IDs 364 | :type targets: list 365 | :param target_set: Set of class IDs possible 366 | :type target_set: list 367 | :return: new class IDs 368 | """ 369 | for idx in range(len(targets)): 370 | if targets[idx] == 0: 371 | targets[idx] = 9 372 | elif targets[idx] == 1: 373 | targets[idx] = 3 374 | 375 | return targets 376 | 377 | def replace_0_with_6_1_with_0(targets, target_set): 378 | """ 379 | :param targets: Target class IDs 380 | :type targets: list 381 | :param target_set: Set of class IDs possible 382 | :type target_set: list 383 | :return: new class IDs 384 | """ 385 | for idx in range(len(targets)): 386 | if targets[idx] == 0: 387 | targets[idx] = 6 388 | elif targets[idx] == 1: 389 | targets[idx] = 0 390 | 391 | return targets 392 | 393 | 394 | def replace_2_with_3_3_with_9(targets, target_set): 395 | """ 396 | :param targets: Target class IDs 397 | :type targets: list 398 | :param target_set: Set of class IDs possible 399 | :type target_set: list 400 | :return: new class IDs 401 | """ 402 | for idx in range(len(targets)): 403 | if targets[idx] == 2: 404 | targets[idx] = 3 405 | elif targets[idx] == 3: 406 | targets[idx] = 9 407 | 408 | return targets 409 | 410 | def replace_2_with_7_3_with_7(targets, target_set): 411 | """ 412 | :param targets: Target class IDs 413 | :type targets: list 414 | :param target_set: Set of class IDs possible 415 | :type target_set: list 416 | :return: new class IDs 417 | """ 418 | for idx in range(len(targets)): 419 | if targets[idx] == 2: 420 | targets[idx] = 7 421 | elif targets[idx] == 3: 422 | targets[idx] = 7 423 | 424 | return targets 425 | -------------------------------------------------------------------------------- /federated_learning/utils/client_utils.py: -------------------------------------------------------------------------------- 1 | def log_client_data_statistics(logger, label_class_set, distributed_dataset): 2 | """ 3 | Logs all client data statistics. 
4 | 5 | :param logger: logger 6 | :type logger: loguru.logger 7 | :param label_class_set: set of class labels 8 | :type label_class_set: list 9 | :param distributed_dataset: distributed dataset 10 | :type distributed_dataset: list(tuple) 11 | """ 12 | for client_idx in range(len(distributed_dataset)): 13 | client_class_nums = {class_val : 0 for class_val in label_class_set} 14 | for target in distributed_dataset[client_idx][1]: 15 | client_class_nums[target] += 1 16 | 17 | logger.info("Client #{} has data distribution: {}".format(client_idx, str(list(client_class_nums.values())))) 18 | -------------------------------------------------------------------------------- /federated_learning/utils/csv_utils.py: -------------------------------------------------------------------------------- 1 | def convert_results_to_csv(results): 2 | """ 3 | :param results: list(return data by test_classification() in client.py) 4 | """ 5 | cleaned_epoch_test_set_results = [] 6 | 7 | for row in results: 8 | components = [row[0], row[1]] 9 | 10 | for class_precision in row[2]: 11 | components.append(class_precision) 12 | for class_recall in row[3]: 13 | components.append(class_recall) 14 | 15 | cleaned_epoch_test_set_results.append(components) 16 | 17 | return cleaned_epoch_test_set_results 18 | -------------------------------------------------------------------------------- /federated_learning/utils/data_loader_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from .label_replacement import apply_class_label_replacement 3 | import os 4 | import pickle 5 | import random 6 | from ..datasets import Dataset 7 | 8 | def generate_data_loaders_from_distributed_dataset(distributed_dataset, batch_size): 9 | """ 10 | Generate data loaders from a distributed dataset. 11 | 12 | :param distributed_dataset: Distributed dataset 13 | :type distributed_dataset: list(tuple) 14 | :param batch_size: batch size for data loader 15 | :type batch_size: int 16 | """ 17 | data_loaders = [] 18 | for worker_training_data in distributed_dataset: 19 | data_loaders.append(Dataset.get_data_loader_from_data(batch_size, worker_training_data[0], worker_training_data[1], shuffle=True)) 20 | 21 | return data_loaders 22 | 23 | def load_train_data_loader(logger, args): 24 | """ 25 | Loads the training data DataLoader object from a file if available. 26 | 27 | :param logger: loguru.Logger 28 | :param args: Arguments 29 | """ 30 | if os.path.exists(args.get_train_data_loader_pickle_path()): 31 | return load_data_loader_from_file(logger, args.get_train_data_loader_pickle_path()) 32 | else: 33 | logger.error("Couldn't find train data loader stored in file") 34 | 35 | raise FileNotFoundError("Couldn't find train data loader stored in file") 36 | 37 | def generate_train_loader(args, dataset): 38 | train_dataset = dataset.get_train_dataset() 39 | X, Y = shuffle_data(args, train_dataset) 40 | 41 | return dataset.get_data_loader_from_data(args.get_batch_size(), X, Y) 42 | 43 | def load_test_data_loader(logger, args): 44 | """ 45 | Loads the test data DataLoader object from a file if available. 
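Loader persistence in this module is a thin pickle wrapper (see `save_data_loader_to_file` / `load_saved_data_loader` below). A minimal round-trip sketch, with a plain list standing in for a real torch DataLoader so the example stays self-contained:

```python
import pickle

# A list stands in for the DataLoader object (assumption for illustration).
data_loader = [("batch_0_inputs", "batch_0_targets")]

with open("example_data_loader.pickle", "wb") as f:
    pickle.dump(data_loader, f)   # what save_data_loader_to_file does

with open("example_data_loader.pickle", "rb") as f:
    restored = pickle.load(f)     # what load_saved_data_loader does

assert restored == data_loader
```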
46 | 47 | :param logger: loguru.Logger 48 | :param args: Arguments 49 | """ 50 | if os.path.exists(args.get_test_data_loader_pickle_path()): 51 | return load_data_loader_from_file(logger, args.get_test_data_loader_pickle_path()) 52 | else: 53 | logger.error("Couldn't find test data loader stored in file") 54 | 55 | raise FileNotFoundError("Couldn't find test data loader stored in file") 56 | 57 | def load_data_loader_from_file(logger, filename): 58 | """ 59 | Loads DataLoader object from a file if available. 60 | 61 | :param logger: loguru.Logger 62 | :param filename: string 63 | """ 64 | logger.info("Loading data loader from file: {}".format(filename)) 65 | 66 | with open(filename, "rb") as f: 67 | return load_saved_data_loader(f) 68 | 69 | def generate_test_loader(args, dataset): 70 | test_dataset = dataset.get_test_dataset() 71 | X, Y = shuffle_data(args, test_dataset) 72 | 73 | return dataset.get_data_loader_from_data(args.get_test_batch_size(), X, Y) 74 | 75 | def shuffle_data(args, dataset): 76 | data = list(zip(dataset[0], dataset[1])) 77 | random.shuffle(data) 78 | X, Y = zip(*data) 79 | X = numpy.asarray(X) 80 | Y = numpy.asarray(Y) 81 | 82 | return X, Y 83 | 84 | def load_saved_data_loader(file_obj): 85 | return pickle.load(file_obj) 86 | 87 | def save_data_loader_to_file(data_loader, file_obj): 88 | pickle.dump(data_loader, file_obj) 89 | -------------------------------------------------------------------------------- /federated_learning/utils/experiment_ids.py: -------------------------------------------------------------------------------- 1 | def generate_experiment_ids(start_idx, num_exp): 2 | """ 3 | Generate the filenames for all experiment IDs. 4 | 5 | :param start_idx: start index for experiments 6 | :type start_idx: int 7 | :param num_exp: number of experiments to run 8 | :type num_exp: int 9 | """ 10 | log_files = [] 11 | results_files = [] 12 | models_folders = [] 13 | worker_selections_files = [] 14 | 15 | for i in range(num_exp): 16 | idx = str(start_idx + i) 17 | 18 | log_files.append("logs/" + idx + ".log") 19 | results_files.append(idx + "_results.csv") 20 | models_folders.append(idx + "_models") 21 | worker_selections_files.append(idx + "_workers_selected.csv") 22 | 23 | return log_files, results_files, models_folders, worker_selections_files 24 | -------------------------------------------------------------------------------- /federated_learning/utils/fed_avg.py: -------------------------------------------------------------------------------- 1 | def average_nn_parameters(parameters): 2 | """ 3 | Averages passed parameters. 
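A toy worked example of this unweighted averaging rule (plain FedAvg over named parameters); the layer name `fc.weight` and the two client values are made up:

```python
import torch

# Two hypothetical clients' named parameters.
client_a = {"fc.weight": torch.tensor([1.0, 2.0])}
client_b = {"fc.weight": torch.tensor([3.0, 4.0])}
parameters = [client_a, client_b]

# Element-wise mean per parameter name, as in average_nn_parameters.
new_params = {
    name: sum(param[name].data for param in parameters) / len(parameters)
    for name in parameters[0].keys()
}

print(new_params["fc.weight"])  # tensor([2., 3.])
```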
4 | 5 | :param parameters: nn model named parameters 6 | :type parameters: list 7 | """ 8 | new_params = {} 9 | for name in parameters[0].keys(): 10 | new_params[name] = sum([param[name].data for param in parameters]) / len(parameters) 11 | 12 | return new_params 13 | -------------------------------------------------------------------------------- /federated_learning/utils/file_storage_utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | 4 | def generate_json_repr_for_worker(worker_id, is_worker_poisoned, test_set_results): 5 | """ 6 | :param worker_id: int 7 | :param is_worker_poisoned: boolean 8 | :param test_set_results: list(dict) 9 | """ 10 | return { 11 | "worker_id" : worker_id, 12 | "is_worker_poisoned" : is_worker_poisoned, 13 | "test_set_results" : test_set_results 14 | } 15 | 16 | def convert_test_results_to_json(epoch_idx, accuracy, loss, class_precision, class_recall): 17 | """ 18 | :param epoch_idx: int 19 | :param accuracy: float 20 | :param loss: float 21 | :param class_precision: list(float) 22 | :param class_recall: list(float) 23 | """ 24 | return { 25 | "epoch" : epoch_idx, 26 | "accuracy" : accuracy, 27 | "loss" : loss, 28 | "class_precision" : class_precision, 29 | "class_recall" : class_recall 30 | } 31 | 32 | def save_results(results, filename): 33 | """ 34 | :param results: experiment results 35 | :type results: list() 36 | :param filename: File name to write results to 37 | :type filename: String 38 | """ 39 | with open(filename, 'w', newline='') as csvfile: 40 | writer = csv.writer(csvfile, delimiter=',') 41 | 42 | for experiment in results: 43 | writer.writerow(experiment) 44 | 45 | def read_results(filename): 46 | """ 47 | :param filename: File name to read results from 48 | :type filename: String 49 | """ 50 | data = [] 51 | with open(filename, 'r') as csvfile: 52 | reader = csv.reader(csvfile, delimiter=',') 53 | 54 | for row in reader: 55 | data.append(row) 56 | 57 | return data 58 | 59 | def save_results_v2(results, filename): 60 | """ 61 | Save results to a file, using format v2. 62 | 63 | :param results: json 64 | :param filename: string 65 | """ 66 | with open(filename, "w") as f: 67 | json.dump(results, f, indent=4, sort_keys=True) 68 | 69 | def read_results_v2(filename): 70 | """ 71 | Read results from a file, using format v2. 72 | """ 73 | with open(filename, "r") as f: 74 | return json.load(f) 75 | -------------------------------------------------------------------------------- /federated_learning/utils/identify_random_elements.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | def identify_random_elements(max, num_random_elements): 4 | """ 5 | Picks a specified number of unique random elements from the range [0, max). 
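For comparison only (this is not the repository's implementation): the rejection-sampling loop below draws unique IDs uniformly without replacement, which `random.sample` expresses in one call:

```python
import random

random.seed(42)  # seed chosen arbitrarily, for a reproducible illustration

# Equivalent one-liner: 3 unique worker IDs from the range [0, 10).
poisoned_ids = random.sample(range(10), 3)
print(sorted(poisoned_ids))
```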
6 | 7 | :param max: Max number to pick from 8 | :type max: int 9 | :param num_random_elements: Number of random elements to select 10 | :type num_random_elements: int 11 | :return: list 12 | """ 13 | if num_random_elements > max: 14 | return [] 15 | 16 | ids = [] 17 | x = 0 18 | while x < num_random_elements: 19 | rand_int = random.randint(0, max - 1) 20 | 21 | if rand_int not in ids: 22 | ids.append(rand_int) 23 | x += 1 24 | 25 | return ids 26 | -------------------------------------------------------------------------------- /federated_learning/utils/label_replacement.py: -------------------------------------------------------------------------------- 1 | def apply_class_label_replacement(X, Y, replacement_method): 2 | """ 3 | Replace class labels using the replacement method 4 | 5 | :param X: data features 6 | :type X: numpy.Array() 7 | :param Y: data labels 8 | :type Y: numpy.Array() 9 | :param replacement_method: Method to update targets 10 | :type replacement_method: method 11 | """ 12 | return (X, replacement_method(Y, set(Y))) 13 | -------------------------------------------------------------------------------- /federated_learning/utils/model_list_parser.py: -------------------------------------------------------------------------------- 1 | def get_worker_num_from_model_file_name(model_file_name): 2 | """ 3 | :param model_file_name: string 4 | """ 5 | return int(model_file_name.split("_")[1]) 6 | 7 | def get_epoch_num_from_model_file_name(model_file_name): 8 | """ 9 | :param model_file_name: string 10 | """ 11 | return int(model_file_name.split("_")[2].split(".")[0]) 12 | 13 | def get_suffix_from_model_file_name(model_file_name): 14 | """ 15 | :param model_file_name: string 16 | """ 17 | return model_file_name.split("_")[3].split(".")[0] 18 | 19 | def get_model_files_for_worker(model_files, worker_id): 20 | """ 21 | :param model_files: list[string] 22 | :param worker_id: int 23 | """ 24 | worker_model_files = [] 25 | 26 | for model in model_files: 27 | worker_num = get_worker_num_from_model_file_name(model) 28 | 29 | if worker_num == worker_id: 30 | worker_model_files.append(model) 31 | 32 | return worker_model_files 33 | 34 | def get_model_files_for_epoch(model_files, epoch_num): 35 | """ 36 | :param model_files: list[string] 37 | :param epoch_num: int 38 | """ 39 | epoch_model_files = [] 40 | 41 | for model in model_files: 42 | model_epoch_num = get_epoch_num_from_model_file_name(model) 43 | 44 | if model_epoch_num == epoch_num: 45 | epoch_model_files.append(model) 46 | 47 | return epoch_model_files 48 | 49 | def get_model_files_for_suffix(model_files, suffix): 50 | """ 51 | :param model_files: list[string] 52 | :param suffix: string 53 | """ 54 | suffix_only_model_files = [] 55 | 56 | for model in model_files: 57 | model_suffix = get_suffix_from_model_file_name(model) 58 | 59 | if model_suffix == suffix: 60 | suffix_only_model_files.append(model) 61 | 62 | return suffix_only_model_files 63 | -------------------------------------------------------------------------------- /federated_learning/utils/poison_data.py: -------------------------------------------------------------------------------- 1 | from .label_replacement import apply_class_label_replacement 2 | from .client_utils import log_client_data_statistics 3 | 4 | def poison_data(logger, distributed_dataset, num_workers, poisoned_worker_ids, replacement_method): 5 | """ 6 | Poison worker data 7 | 8 | :param logger: logger 9 | :type logger: loguru.logger 10 | :param distributed_dataset: Distributed dataset 11 | :type 
distributed_dataset: list(tuple) 12 | :param num_workers: Number of workers overall 13 | :type num_workers: int 14 | :param poisoned_worker_ids: IDs of poisoned workers 15 | :type poisoned_worker_ids: list(int) 16 | :param replacement_method: Method used to replace the class labels 17 | :type replacement_method: method 18 | """ 19 | # TODO: Add support for multiple replacement methods? 20 | poisoned_dataset = [] 21 | 22 | class_labels = list(set(distributed_dataset[0][1])) 23 | 24 | logger.info("Poisoning data for workers: {}".format(str(poisoned_worker_ids))) 25 | 26 | for worker_idx in range(num_workers): 27 | if worker_idx in poisoned_worker_ids: 28 | poisoned_dataset.append(apply_class_label_replacement(distributed_dataset[worker_idx][0], distributed_dataset[worker_idx][1], replacement_method)) 29 | else: 30 | poisoned_dataset.append(distributed_dataset[worker_idx]) 31 | 32 | log_client_data_statistics(logger, class_labels, poisoned_dataset) 33 | 34 | return poisoned_dataset 35 | -------------------------------------------------------------------------------- /federated_learning/utils/tensor_converter.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def convert_distributed_data_into_numpy(distributed_dataset): 4 | """ 5 | Converts a distributed dataset (returned by a data distribution method) from Tensors into numpy arrays. 6 | 7 | :param distributed_dataset: Distributed dataset 8 | :type distributed_dataset: list(tuple) 9 | """ 10 | converted_distributed_dataset = [] 11 | 12 | for worker_idx in range(len(distributed_dataset)): 13 | worker_training_data = distributed_dataset[worker_idx] 14 | 15 | X_ = numpy.array([tensor.numpy() for batch in worker_training_data for tensor in batch[0]]) 16 | Y_ = numpy.array([tensor.numpy() for batch in worker_training_data for tensor in batch[1]]) 17 | 18 | converted_distributed_dataset.append((X_, Y_)) 19 | 20 | return converted_distributed_dataset 21 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/__init__.py: -------------------------------------------------------------------------------- 1 | from .breakpoint_before import BeforeBreakpoint 2 | from .breakpoint_after import AfterBreakpoint 3 | from .poisoner_probability import PoisonerProbability 4 | from .random import RandomSelectionStrategy 5 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/breakpoint_after.py: -------------------------------------------------------------------------------- 1 | from .selection_strategy import SelectionStrategy 2 | import random 3 | 4 | class AfterBreakpoint(SelectionStrategy): 5 | """ 6 | Will not select poisoned workers at and after the break point epoch, but will select the 7 | poisoned workers before the break point epoch. 
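A hedged usage sketch for this strategy, assuming the package is on the import path and using arbitrary worker IDs; note the mixed casing of the two kwargs keys ("BreakPoint" vs. "Breakpoint"), which matches the implementation below:

```python
import random

from federated_learning.worker_selection import AfterBreakpoint

random.seed(0)  # arbitrary seed for reproducibility

selector = AfterBreakpoint()
# At epoch 12 (past the breakpoint of 10), poisoned workers 3 and 4
# can no longer be selected.
selected = selector.select_round_workers(
    workers=list(range(10)),
    poisoned_workers=[3, 4],
    kwargs={
        "AfterBreakPoint_EPOCH": 10,
        "AfterBreakpoint_NUM_WORKERS_PER_ROUND": 5,
        "current_epoch_number": 12,
    },
)
print(selected)  # five workers drawn from {0, 1, 2, 5, 6, 7, 8, 9}
```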
8 | """ 9 | 10 | def select_round_workers(self, workers, poisoned_workers, kwargs): 11 | breakpoint_epoch = kwargs["AfterBreakPoint_EPOCH"] 12 | num_workers = kwargs["AfterBreakpoint_NUM_WORKERS_PER_ROUND"] 13 | current_epoch_number = kwargs["current_epoch_number"] 14 | 15 | selected_workers = [] 16 | if current_epoch_number < breakpoint_epoch: 17 | selected_workers = random.sample(workers, num_workers) 18 | else: 19 | non_poisoned_workers = list(set(workers) - set(poisoned_workers)) 20 | 21 | selected_workers = random.sample(non_poisoned_workers, num_workers) 22 | 23 | return selected_workers 24 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/breakpoint_before.py: -------------------------------------------------------------------------------- 1 | from .selection_strategy import SelectionStrategy 2 | import random 3 | 4 | class BeforeBreakpoint(SelectionStrategy): 5 | """ 6 | Will not select poisoned workers before the break point epoch, but will select the 7 | poisoned workers at and after the break point epoch. 8 | """ 9 | 10 | def select_round_workers(self, workers, poisoned_workers, kwargs): 11 | breakpoint_epoch = kwargs["BeforeBreakPoint_EPOCH"] 12 | num_workers = kwargs["BeforeBreakpoint_NUM_WORKERS_PER_ROUND"] 13 | current_epoch_number = kwargs["current_epoch_number"] 14 | 15 | selected_workers = [] 16 | if current_epoch_number >= breakpoint_epoch: 17 | selected_workers = random.sample(workers, num_workers) 18 | else: 19 | non_poisoned_workers = list(set(workers) - set(poisoned_workers)) 20 | 21 | selected_workers = random.sample(non_poisoned_workers, num_workers) 22 | 23 | return selected_workers 24 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/poisoner_probability.py: -------------------------------------------------------------------------------- 1 | from .selection_strategy import SelectionStrategy 2 | import random 3 | import copy 4 | 5 | class PoisonerProbability(SelectionStrategy): 6 | """ 7 | Will not select poisoned workers before or after a specified epoch (specified in arguments). 8 | 9 | Will artificially boost / reduce likelihood of the poisoned workers being selected. 10 | """ 11 | 12 | def select_round_workers(self, workers, poisoned_workers, kwargs): 13 | break_epoch = kwargs["PoisonerProbability_BREAK_EPOCH"] 14 | post_break_epoch_probability = kwargs["PoisonerProbability_POST_BREAK_EPOCH_PROBABILITY"] 15 | pre_break_epoch_probability = kwargs["PoisonerProbability_PRE_BREAK_EPOCH_PROBABILITY"] 16 | num_workers = kwargs["PoisonerProbability_NUM_WORKERS_PER_ROUND"] 17 | current_epoch_number = kwargs["current_epoch_number"] 18 | 19 | workers = self.remove_poisoned_workers_from_group(poisoned_workers, workers) 20 | 21 | selected_workers = [] 22 | if current_epoch_number >= break_epoch: 23 | selected_workers = self.select_workers(num_workers, post_break_epoch_probability, poisoned_workers, workers) 24 | else: 25 | selected_workers = self.select_workers(num_workers, pre_break_epoch_probability, poisoned_workers, workers) 26 | 27 | return selected_workers 28 | 29 | def remove_poisoned_workers_from_group(self, poisoned_workers, group): 30 | """ 31 | Removes all instances of set(poisoned_workers) from set(group). 
32 | """ 33 | return list(set(group) - set(poisoned_workers)) 34 | 35 | def select_workers(self, num_workers, probability_threshold, group_0, group_1): 36 | """ 37 | Selects a set of workers from the two different groups. 38 | 39 | Weights the choice via the probability threshold 40 | """ 41 | group_0_copy = copy.deepcopy(group_0) 42 | group_1_copy = copy.deepcopy(group_1) 43 | 44 | selected_workers = [] 45 | while len(selected_workers) < num_workers: 46 | group_to_select_worker_from = self.select_group(probability_threshold, group_0, group_1) 47 | selected_worker = random.choice(group_to_select_worker_from) 48 | if selected_worker not in selected_workers: 49 | selected_workers.append(selected_worker) 50 | 51 | return selected_workers 52 | 53 | def select_group(self, probability_threshold, group_0, group_1): 54 | """ 55 | Selects between group_0 and group_1 based on a random choice. 56 | 57 | Probability threshold determines weighting given to group 0. 58 | Ex: if 0 is the probability threshold, then group 0 will never be selected. 59 | """ 60 | next_int = random.uniform(0, 1) 61 | 62 | if next_int <= probability_threshold: 63 | return group_0 64 | else: 65 | return group_1 66 | 67 | if __name__ == '__main__': 68 | selector = PoisonerProbability() 69 | 70 | print(selector.select_round_workers([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], [3,4,5,6,10,11,12], { 71 | "PoisonerProbability_BREAK_EPOCH" : 5, 72 | "PoisonerProbability_POST_BREAK_EPOCH_PROBABILITY" : 0.0, 73 | "PoisonerProbability_PRE_BREAK_EPOCH_PROBABILITY" : 1.0, 74 | "PoisonerProbability_NUM_WORKERS_PER_ROUND" : 5, 75 | "current_epoch_number" : 10 76 | })) 77 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/random.py: -------------------------------------------------------------------------------- 1 | from .selection_strategy import SelectionStrategy 2 | import random 3 | 4 | class RandomSelectionStrategy(SelectionStrategy): 5 | """ 6 | Randomly selects workers out of the list of all workers 7 | """ 8 | 9 | def select_round_workers(self, workers, poisoned_workers, kwargs): 10 | return random.sample(workers, kwargs["NUM_WORKERS_PER_ROUND"]) 11 | -------------------------------------------------------------------------------- /federated_learning/worker_selection/selection_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | class SelectionStrategy: 4 | 5 | @abstractmethod 6 | def select_round_workers(self, workers, poisoned_workers, kwargs): 7 | """ 8 | :param workers: list(int). All workers available for learning 9 | :param poisoned_workers: list(int). 
All workers that are poisoned 10 | :param kwargs: dict 11 | """ 12 | raise NotImplementedError("select_round_workers() not implemented") 13 | -------------------------------------------------------------------------------- /generate_data_distribution.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import pathlib 3 | import os 4 | from federated_learning.arguments import Arguments 5 | from federated_learning.datasets import CIFAR10Dataset 6 | from federated_learning.datasets import FashionMNISTDataset 7 | from federated_learning.utils import generate_train_loader 8 | from federated_learning.utils import generate_test_loader 9 | from federated_learning.utils import save_data_loader_to_file 10 | 11 | 12 | if __name__ == '__main__': 13 | args = Arguments(logger) 14 | 15 | # --------------------------------- 16 | # ------------ CIFAR10 ------------ 17 | # --------------------------------- 18 | dataset = CIFAR10Dataset(args) 19 | TRAIN_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/train_data_loader.pickle" 20 | TEST_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/test_data_loader.pickle" 21 | 22 | if not os.path.exists("data_loaders/cifar10"): 23 | pathlib.Path("data_loaders/cifar10").mkdir(parents=True, exist_ok=True) 24 | 25 | train_data_loader = generate_train_loader(args, dataset) 26 | test_data_loader = generate_test_loader(args, dataset) 27 | 28 | with open(TRAIN_DATA_LOADER_FILE_PATH, "wb") as f: 29 | save_data_loader_to_file(train_data_loader, f) 30 | 31 | with open(TEST_DATA_LOADER_FILE_PATH, "wb") as f: 32 | save_data_loader_to_file(test_data_loader, f) 33 | 34 | # --------------------------------- 35 | # --------- Fashion-MNIST --------- 36 | # --------------------------------- 37 | dataset = FashionMNISTDataset(args) 38 | TRAIN_DATA_LOADER_FILE_PATH = "data_loaders/fashion-mnist/train_data_loader.pickle" 39 | TEST_DATA_LOADER_FILE_PATH = "data_loaders/fashion-mnist/test_data_loader.pickle" 40 | 41 | if not os.path.exists("data_loaders/fashion-mnist"): 42 | pathlib.Path("data_loaders/fashion-mnist").mkdir(parents=True, exist_ok=True) 43 | 44 | train_data_loader = generate_train_loader(args, dataset) 45 | test_data_loader = generate_test_loader(args, dataset) 46 | 47 | with open(TRAIN_DATA_LOADER_FILE_PATH, "wb") as f: 48 | save_data_loader_to_file(train_data_loader, f) 49 | 50 | with open(TEST_DATA_LOADER_FILE_PATH, "wb") as f: 51 | save_data_loader_to_file(test_data_loader, f) 52 | -------------------------------------------------------------------------------- /generate_default_models.py: -------------------------------------------------------------------------------- 1 | from federated_learning.arguments import Arguments 2 | from federated_learning.nets import Cifar10CNN 3 | from federated_learning.nets import FashionMNISTCNN 4 | import os 5 | import torch 6 | from loguru import logger 7 | 8 | if __name__ == '__main__': 9 | args = Arguments(logger) 10 | if not os.path.exists(args.get_default_model_folder_path()): 11 | os.mkdir(args.get_default_model_folder_path()) 12 | 13 | # --------------------------------- 14 | # ----------- Cifar10CNN ---------- 15 | # --------------------------------- 16 | full_save_path = os.path.join(args.get_default_model_folder_path(), "Cifar10CNN.model") 17 | torch.save(Cifar10CNN().state_dict(), full_save_path) 18 | 19 | # --------------------------------- 20 | # -------- FashionMNISTCNN -------- 21 | # --------------------------------- 22 | full_save_path = 
os.path.join(args.get_default_model_folder_path(), "FashionMNISTCNN.model") 23 | torch.save(FashionMNISTCNN().state_dict(), full_save_path) 24 | -------------------------------------------------------------------------------- /label_flipping_attack.py: -------------------------------------------------------------------------------- 1 | from federated_learning.utils import replace_0_with_2 2 | from federated_learning.utils import replace_5_with_3 3 | from federated_learning.utils import replace_1_with_9 4 | from federated_learning.utils import replace_4_with_6 5 | from federated_learning.utils import replace_1_with_3 6 | from federated_learning.utils import replace_6_with_0 7 | from federated_learning.worker_selection import RandomSelectionStrategy 8 | from server import run_exp 9 | 10 | if __name__ == '__main__': 11 | START_EXP_IDX = 3000 12 | NUM_EXP = 3 13 | NUM_POISONED_WORKERS = 0 14 | REPLACEMENT_METHOD = replace_1_with_9 15 | KWARGS = { 16 | "NUM_WORKERS_PER_ROUND" : 5 17 | } 18 | 19 | for experiment_id in range(START_EXP_IDX, START_EXP_IDX + NUM_EXP): 20 | run_exp(REPLACEMENT_METHOD, NUM_POISONED_WORKERS, KWARGS, RandomSelectionStrategy(), experiment_id) 21 | -------------------------------------------------------------------------------- /malicious_participant_availability.py: -------------------------------------------------------------------------------- 1 | from federated_learning.utils import replace_0_with_2 2 | from federated_learning.utils import replace_5_with_3 3 | from federated_learning.utils import replace_1_with_9 4 | from federated_learning.utils import replace_4_with_6 5 | from federated_learning.utils import replace_1_with_3 6 | from federated_learning.utils import replace_6_with_0 7 | from federated_learning.worker_selection import PoisonerProbability 8 | from server import run_exp 9 | 10 | if __name__ == '__main__': 11 | START_EXP_IDX = 3000 12 | NUM_EXP = 3 13 | NUM_POISONED_WORKERS = 0 14 | REPLACEMENT_METHOD = replace_1_with_9 15 | KWARGS = { 16 | "PoisonerProbability_BREAK_EPOCH" : 75, 17 | "PoisonerProbability_POST_BREAK_EPOCH_PROBABILITY" : 0.6, 18 | "PoisonerProbability_PRE_BREAK_EPOCH_PROBABILITY" : 0.0, 19 | "PoisonerProbability_NUM_WORKERS_PER_ROUND" : 5 20 | } 21 | 22 | for experiment_id in range(START_EXP_IDX, START_EXP_IDX + NUM_EXP): 23 | run_exp(REPLACEMENT_METHOD, NUM_POISONED_WORKERS, KWARGS, PoisonerProbability(), experiment_id) 24 | -------------------------------------------------------------------------------- /requirements.pip: -------------------------------------------------------------------------------- 1 | loguru==0.3.2 2 | scikit-learn==0.21.3 3 | torch==1.2.0 4 | torchvision==0.4.0 5 | numpy==1.17.0 6 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from federated_learning.arguments import Arguments 3 | from federated_learning.utils import generate_data_loaders_from_distributed_dataset 4 | from federated_learning.datasets.data_distribution import distribute_batches_equally 5 | from federated_learning.utils import average_nn_parameters 6 | from federated_learning.utils import convert_distributed_data_into_numpy 7 | from federated_learning.utils import poison_data 8 | from federated_learning.utils import identify_random_elements 9 | from federated_learning.utils import save_results 10 | from federated_learning.utils import load_train_data_loader 11 | from federated_learning.utils 
import load_test_data_loader 12 | from federated_learning.utils import generate_experiment_ids 13 | from federated_learning.utils import convert_results_to_csv 14 | from client import Client 15 | 16 | def train_subset_of_clients(epoch, args, clients, poisoned_workers): 17 | """ 18 | Train a subset of clients per round. 19 | 20 | :param epoch: epoch 21 | :type epoch: int 22 | :param args: arguments 23 | :type args: Arguments 24 | :param clients: clients 25 | :type clients: list(Client) 26 | :param poisoned_workers: indices of poisoned workers 27 | :type poisoned_workers: list(int) 28 | """ 29 | kwargs = args.get_round_worker_selection_strategy_kwargs() 30 | kwargs["current_epoch_number"] = epoch 31 | 32 | random_workers = args.get_round_worker_selection_strategy().select_round_workers( 33 | list(range(args.get_num_workers())), 34 | poisoned_workers, 35 | kwargs) 36 | 37 | for client_idx in random_workers: 38 | args.get_logger().info("Training epoch #{} on client #{}", str(epoch), str(clients[client_idx].get_client_index())) 39 | clients[client_idx].train(epoch) 40 | 41 | args.get_logger().info("Averaging client parameters") 42 | parameters = [clients[client_idx].get_nn_parameters() for client_idx in random_workers] 43 | new_nn_params = average_nn_parameters(parameters) 44 | 45 | for client in clients: 46 | args.get_logger().info("Updating parameters on client #{}", str(client.get_client_index())) 47 | client.update_nn_parameters(new_nn_params) 48 | 49 | return clients[0].test(), random_workers 50 | 51 | def create_clients(args, train_data_loaders, test_data_loader): 52 | """ 53 | Create a set of clients. 54 | """ 55 | clients = [] 56 | for idx in range(args.get_num_workers()): 57 | clients.append(Client(args, idx, train_data_loaders[idx], test_data_loader)) 58 | 59 | return clients 60 | 61 | def run_machine_learning(clients, args, poisoned_workers): 62 | """ 63 | Complete machine learning over a series of clients. 
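A self-contained mock of one round as orchestrated by `train_subset_of_clients` above: select a subset, train locally, average, broadcast. `MockClient` is hypothetical and stands in for `client.Client`; its single scalar parameter is a toy:

```python
import random

class MockClient:
    """Hypothetical stand-in for client.Client with one scalar parameter."""
    def __init__(self, idx):
        self.params = {"w": float(idx)}

    def train(self, epoch):
        self.params["w"] += 0.1  # pretend one round of local training

    def get_nn_parameters(self):
        return self.params

    def update_nn_parameters(self, new_params):
        self.params = dict(new_params)

random.seed(0)
clients = [MockClient(i) for i in range(10)]
selected = random.sample(range(10), 5)        # worker selection

for idx in selected:
    clients[idx].train(epoch=1)               # local training

params = [clients[idx].get_nn_parameters() for idx in selected]
new_params = {name: sum(p[name] for p in params) / len(params)
              for name in params[0]}          # FedAvg aggregation

for client in clients:
    client.update_nn_parameters(new_params)   # broadcast

print(new_params)
```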
64 | """ 65 | epoch_test_set_results = [] 66 | worker_selection = [] 67 | for epoch in range(1, args.get_num_epochs() + 1): 68 | results, workers_selected = train_subset_of_clients(epoch, args, clients, poisoned_workers) 69 | 70 | epoch_test_set_results.append(results) 71 | worker_selection.append(workers_selected) 72 | 73 | return convert_results_to_csv(epoch_test_set_results), worker_selection 74 | 75 | def run_exp(replacement_method, num_poisoned_workers, KWARGS, client_selection_strategy, idx): 76 | log_files, results_files, models_folders, worker_selections_files = generate_experiment_ids(idx, 1) 77 | 78 | # Initialize logger 79 | handler = logger.add(log_files[0], enqueue=True) 80 | 81 | args = Arguments(logger) 82 | args.set_model_save_path(models_folders[0]) 83 | args.set_num_poisoned_workers(num_poisoned_workers) 84 | args.set_round_worker_selection_strategy_kwargs(KWARGS) 85 | args.set_client_selection_strategy(client_selection_strategy) 86 | args.log() 87 | 88 | train_data_loader = load_train_data_loader(logger, args) 89 | test_data_loader = load_test_data_loader(logger, args) 90 | 91 | # Distribute batches equal volume IID 92 | distributed_train_dataset = distribute_batches_equally(train_data_loader, args.get_num_workers()) 93 | distributed_train_dataset = convert_distributed_data_into_numpy(distributed_train_dataset) 94 | 95 | poisoned_workers = identify_random_elements(args.get_num_workers(), args.get_num_poisoned_workers()) 96 | distributed_train_dataset = poison_data(logger, distributed_train_dataset, args.get_num_workers(), poisoned_workers, replacement_method) 97 | 98 | train_data_loaders = generate_data_loaders_from_distributed_dataset(distributed_train_dataset, args.get_batch_size()) 99 | 100 | clients = create_clients(args, train_data_loaders, test_data_loader) 101 | 102 | results, worker_selection = run_machine_learning(clients, args, poisoned_workers) 103 | save_results(results, results_files[0]) 104 | save_results(worker_selection, worker_selections_files[0]) 105 | 106 | logger.remove(handler) 107 | --------------------------------------------------------------------------------