├── .gitignore ├── Makefile ├── README.md ├── attacks ├── attack.py ├── loss_functions.py └── modelreplace.py ├── defenses └── fedavg.py ├── environment.yml ├── exps ├── cifar_fed.yaml ├── gen-run-yaml-file.py ├── imagenet_fed.yaml ├── list_exps__2023.Nov.25.md └── mnist_fed.yaml ├── helper.py ├── markdown-explanation ├── FedMLSec.md ├── Figure_4_ba_norm.png ├── Figure_full-image-pattern.png ├── code-structures.md └── model-replacement-attack.png ├── metrics ├── accuracy_metric.py ├── metric.py └── test_loss_metric.py ├── models ├── MnistNet.py ├── __init__.py ├── model.py ├── resnet.py ├── resnet_cifar.py ├── resnet_tinyimagenet.py └── simple.py ├── synthesizers ├── attack_models │ ├── autoencoders.py │ └── unet.py ├── complex_synthesizer.py ├── iba_synthesizer.py ├── pattern_synthesizer.py ├── physical_synthesizer.py ├── singlepixel_synthesizer.py └── synthesizer.py ├── tasks ├── batch.py ├── cifar10_task.py ├── fl_user.py ├── imagenet_task.py ├── mnist_task.py └── task.py ├── test-lib ├── test-log.py ├── test-write-file.ipynb └── test.log ├── training.py └── utils ├── parameters.py ├── process_tiny_data.sh ├── tinyimagenet_reformat.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | FedML/ 3 | config/ 4 | PFL-Non-IID/ 5 | 3DFed/ 6 | OOD_Federated_Learning/ 7 | RLBackdoorFL/ 8 | __pycache__/ 9 | *.pyc 10 | saved_models/ 11 | .data/ 12 | wandb/ 13 | *.zip 14 | exps/run* 15 | env-variables.md 16 | synthesizers/checkpoint/ 17 | synthesizers/atkmodel/ -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | 3 | CONDAPATH = $$(conda info --base) 4 | ENV_NAME = aba 5 | 6 | install: 7 | conda env create -f environment.yml 8 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt 9 | 10 | install-mac: 11 | conda env create -f environment.yml 12 | conda install nomkl 13 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt 14 | 15 | update: 16 | # conda env update --prune -f environment.yml 17 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt --upgrade 18 | 19 | clean: 20 | conda env remove --name $(ENV_NAME) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Table of contents 3 | - [Table of contents](#table-of-contents) 4 | - [Description](#description) 5 | - [Overview of project](#overview-of-project) 6 | - [How to develop from this project](#how-to-develop-from-this-project) 7 | - [Dataset for Backdoor Attack in FL](#dataset-for-backdoor-attack-in-fl) 8 | - [Durability](#durability) 9 | - [Working with FedML](#working-with-fedml) 10 | - [Survey Papers for Machine Learning Security](#survey-papers-for-machine-learning-security) 11 | - [Paper Backdoor Attack in ML/ FL](#paper-backdoor-attack-in-ml-fl) 12 | - [Code for Backdoor Attack in ML/ FL](#code-for-backdoor-attack-in-ml-fl) 13 | - [Other Resources for Backdoor Attack in ML/ FL](#other-resources-for-backdoor-attack-in-ml-fl) 14 | - [Backdoor Attack code resources in FL](#backdoor-attack-code-resources-in-fl) 15 | 16 | 17 | # Description 18 | This GitHub project provides a fast starting point for readers who want to work in the field of backdoor attacks in machine learning and federated learning.
19 | 20 | 21 | 22 | # Overview of project 23 | - Attack methods: DBA, LIRA, IBA, BackdoorBox, 3DFed, Chameleon, etc. 24 | - Defense methods: Krum, RFA, FoolsGold, etc. 25 | 26 | # How to develop from this project 27 | - Define your dataset in the `tasks` folder. 28 | - Define your own attack method in the `attacks` folder. The default attack adds a trigger to the training data and trains the model on the poisoned dataset. 29 | - Define your own defense method in the `defenses` folder. The default defense is FedAvg. 30 | - Define your own model in the `models` folder and the experiment config in the `exps` folder. 31 | - All experiments inherit from the 3 base `.yaml` files in the `exps` folder for 3 datasets: `mnist_fed.yaml` (MNIST), `cifar_fed.yaml` (CIFAR-10), `imagenet_fed.yaml` (Tiny ImageNet). 32 | - The base setting for each experiment contains 100 clients, of which 4 are attackers; 10 clients participate in each round, benign clients train for 2 local epochs per round, and attackers train for 5 local epochs per round. 33 | - The dataset is partitioned with a Dirichlet distribution with $\alpha = 0.5$ (a minimal sketch of such a split is given after the dataset table below). 34 | - The pre-trained model is downloaded from [Google Drive (DBA)](https://drive.google.com/file/d/1wcJ_DkviuOLkmr-FgIVSFwnZwyGU8SjH/view). 35 | --- 36 | Here is an example of running the code: 37 | ``` 38 | python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.24/cifar10_fed_100_10_4_0.5_0.05.yaml 39 | 40 | python training.py --name mnist --params ./exps/run_mnist__2023.Nov.24/mnist_fed_100_10_4_0.5_0.05.yaml 41 | 42 | python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_100_10_4_0.5_0.05.yaml 43 | ``` 44 | In these commands, the `--name` argument is the name of the experiment, and the `--params` argument is the path to the `.yaml` file with the experiment settings. 45 | 46 | --- 47 | 48 | For more commands, please refer to the generated commands file in `./exps/` (e.g., `./exps/list_exps__2023.Nov.25.md`), which is produced by `./exps/gen-run-yaml-file.py`. Adapt `./exps/gen-run-yaml-file.py` to generate different commands for your own experiments. 49 | 50 | 51 | # Dataset for Backdoor Attack in FL 52 | 53 | |Dataset|Case|Description| 54 | | :--- | :--- | :--- | 55 | |MNIST|Case -|The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples.| 56 | |CIFAR-10|Case -|The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.| 57 | |CIFAR-100|Case -|The CIFAR-100 dataset consists of 60000 32x32 colour images in 100 classes, with 600 images per class (500 training and 100 testing images per class).| 58 | |Tiny ImageNet|Case -|Tiny ImageNet contains 200 image classes, a training dataset of 100,000 images, a validation dataset of 10,000 images, and a test dataset of 10,000 images (50 validation and 50 test images per class). All images are of size 64×64.| 59 | |EMNIST|Case -|There are 62 classes (10 digits, 26 lowercase, 26 uppercase) and 814,255 samples: 697,932 training samples and 116,323 test samples.|
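
The non-IID partition mentioned above can be produced with a Dirichlet sampler. Below is a minimal, self-contained sketch of such a split; the function `dirichlet_split` and its signature are illustrative only, not this repository's actual API (in the repo, the split is driven by the `fl_sample_dirichlet` and `fl_dirichlet_alpha` options in the `exps/*.yaml` configs).

```python
# Hypothetical sketch: partition a labeled dataset into non-IID client shards
# using a Dirichlet prior over classes (alpha = 0.5 as in the base configs).
import numpy as np

def dirichlet_split(labels: np.ndarray, num_users: int, alpha: float = 0.5,
                    seed: int = 42) -> list:
    rng = np.random.default_rng(seed)
    user_indices = [[] for _ in range(num_users)]
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.where(labels == cls)[0])
        # Per-user proportions for this class, then slice the class indices.
        proportions = rng.dirichlet(alpha * np.ones(num_users))
        cuts = (np.cumsum(proportions)[:-1] * len(cls_idx)).astype(int)
        for user, shard in enumerate(np.split(cls_idx, cuts)):
            user_indices[user].extend(shard.tolist())
    return user_indices

# Example: 100 clients over CIFAR-10-sized labels.
labels = np.random.randint(0, 10, size=50_000)
shards = dirichlet_split(labels, num_users=100, alpha=0.5)
assert sum(len(s) for s in shards) == 50_000
```

Smaller `alpha` makes the per-client class distributions more skewed (more non-IID); larger `alpha` approaches an IID split.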
60 | 61 | [Edge-case backdoors](https://proceedings.neurips.cc/paper/2020/hash/b8ffa41d4e492f0fad2f13e29e1762eb-Abstract.html) 62 | 63 | 69 | 70 | 71 | 72 | 77 | 78 | # Durability 79 | 80 | 81 | 82 | # Working with FedML 83 | - [FedML README.md](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md) 84 | 85 | Types of Attacks in FL setting: 86 | - Byzantine Attack 87 | - DLG Attack (Deep Leakage from Gradients) 88 | - Backdoor Attack 89 | - Model Replacement Attack 90 | 91 | # Survey Papers for Machine Learning Security 92 | 93 | | Title | Year | Venue | Code | Dataset | URL | Note | 94 | | :--- | :--- | :--- | :--- | :--- | :--- | :--- | 95 | |A Survey on Fully Homomorphic Encryption: An Engineering Perspective|2017|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3124441)|| 96 | |Generative Adversarial Networks: A Survey Toward Private and Secure Applications|2021|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3459992)|| 97 | |A Survey on Adversarial Recommender Systems: From Attack/Defense Strategies to Generative Adversarial Networks|2021|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3439729)|| 98 | |Video Generative Adversarial Networks: A Review|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3487891)|| 99 | |Taxonomy of Machine Learning Safety: A Survey and Primer|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3551385)|| 100 | |Dataset Security for Machine Learning: Data Poisoning, Backdoor Attacks, and Defenses|2022|ACM Computing Surveys|||[link](https://arxiv.org/pdf/2012.10544.pdf)|| 101 | |Generative Adversarial Networks: A Survey on Atack and Defense Perspective|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3615336)|| 102 | |Trustworthy AI: From Principles to Practices|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3555803)|| 103 | |Deep Learning for Android Malware Defenses: A Systematic Literature Review|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3571156)|| 104 | |Dataset security for machine learning: Data poisoning, backdoor attacks, and defenses|2022|IEEE TPAMI|||[link](https://arxiv.org/pdf/2012.10544.pdf)|| 105 | |A Comprehensive Review of the State-of-the-Art on Security and Privacy Issues in Healthcare|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3571156)|| 106 | |A Comprehensive Survey of Privacy-preserving Federated Learning: A Taxonomy, Review, and Future Directions|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3460427)|| 107 | |Recent Advances on Federated Learning: A Systematic Survey|2023|arXiv|||[link](https://arxiv.org/pdf/2301.01299.pdf)|| 108 | |Federated Learning for Generalization, Robustness, Fairness: A Survey and Benchmark|2023|arXiv|||[link](https://arxiv.org/pdf/2311.06750v1.pdf)|| 109 | |Backdoor attacks and defenses in federated learning: Survey, challenges and future research directions|2024|Engineering Applications of Artificial Intelligence|||[link](https://www.sciencedirect.com/science/article/abs/pii/S0952197623013507)|| 110 | 111 | ## Paper Backdoor Attack in ML/ FL 112 | - [How to Backdoor Federated Learning](https://arxiv.org/pdf/1807.00459.pdf) 113 | - [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning](https://arxiv.org/pdf/2007.05084v1.pdf) 114 | - [DBA: Distributed Backdoor Attacks against Federated Learning](https://openreview.net/pdf?id=rkgyS0VFvr) 115 | 116 | ## 
Code for Backdoor Attack in ML/ FL 117 | 118 | | Title | Year | Venue | Code | Dataset | URL | Note | 119 | | :--- | :--- | :--- | :--- | :--- | :--- | :--- | 120 | |Practicing-Federated-Learning||Github|[link](https://github.com/FederatedAI/Practicing-Federated-Learning/)|||| 121 | |Attack of the Tails: Yes, You Really Can Backdoor Federated Learning||NeurIPS'20|[link](https://github.com/ksreenivasan/OOD_Federated_Learning)|||| 122 | |DBA: Distributed Backdoor Attacks against Federated Learning||ICLR'20|[link](https://github.com/AI-secure/DBA)|||| 123 | |LIRA: Learnable, Imperceptible and Robust Backdoor Attacks||ICCV'21|[link](https://github.com/khoadoan106/backdoor_attacks/tree/main)|||| 124 | |Backdoors Framework for Deep Learning and Federated Learning||AISTATS'20, USENIX'21|[link](https://github.com/ebagdasa/backdoors101)|||| 125 | |BackdoorBox: An Open-sourced Python Toolbox for Backdoor Attacks and Defenses|2023|Github|[link](https://github.com/THUYimingLi/BackdoorBox)|||| 126 | |3DFed: Adaptive and Extensible Framework for Covert Backdoor Attack in Federated Learning||IEEE S&P'23|[link](https://github.com/haoyangliASTAPLE/3DFed)|||| 127 | |Neurotoxin: Durable Backdoors in Federated Learning||ICML'22|||[link](https://proceedings.mlr.press/v162/zhang22w/zhang22w.pdf)|Durability| 128 | |Chameleon: Adapting to Peer Images for Planting Durable Backdoors in Federated Learning||ICML'23|[link](https://github.com/ybdai7/Chameleon-durable-backdoor)|||Durability| 129 | |PerDoor: Persistent Backdoors in Federated Learning using Adversarial Perturbations||COINS'23|||[link](https://ieeexplore.ieee.org/abstract/document/10189281)|| 130 | ## Other Resources for Backdoor Attack in ML/ FL 131 | - [List of papers on data poisoning and backdoor attacks](https://github.com/penghui-yang/awesome-data-poisoning-and-backdoor-attacks) 132 | - [Proceedings of Machine Learning Research](https://proceedings.mlr.press/) 133 | - [Backdoor learning resources](https://github.com/THUYimingLi/backdoor-learning-resources) 134 | - Google Scholar: 135 | - [Avi Schwarzschild, Post Doc at CMU](https://scholar.google.com/citations?user=WNvQ7AcAAAAJ&hl=en) 136 | - [Yiming Li, Research Professor at Zhejiang University](https://scholar.google.com/citations?user=mSW7kU8AAAAJ&hl=zh-CN) 137 | - [Nicolas Papernot, Assistant Professor at University of Toronto](https://scholar.google.com/citations?user=cGxq0cMAAAAJ&hl=en) 138 | 139 | # Backdoor Attack code resources in FL 140 | In the FL community there are many code resources for backdoor attacks, each with its own FL scenario (e.g., hyperparameters, datasets, attack methods, defense methods). 141 | Thus, we provide a list of popular code resources for backdoor attacks in FL: 142 | - [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning - NeurIPS'20](https://github.com/ksreenivasan/OOD_Federated_Learning) 143 | - [DBA: Distributed Backdoor Attacks against Federated Learning - ICLR'20](https://github.com/AI-secure/DBA) 144 | - [How To Backdoor Federated Learning - AISTATS'20](https://github.com/ebagdasa/backdoors101) 145 | - [Just How Toxic is Data Poisoning? 
A Unified Benchmark for Backdoor and Data Poisoning Attacks - ICML'21](https://github.com/aks2203/poisoning-benchmark) 146 | - [Learning to Backdoor Federated Learning - ICLR'23 Workshop](https://github.com/HengerLi/RLBackdoorFL/tree/main) 147 | - [Chameleon: Adapting to Peer Images for Planting Durable Backdoors in Federated Learning - ICML'23](https://github.com/ybdai7/Chameleon-durable-backdoor) 148 | - [FedMLSecurity: A Benchmark for Attacks and Defenses in Federated Learning and Federated LLMs - arXiv'23](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md) 149 | - [IBA: Towards Irreversible Backdoor Attacks in Federated Learning - NeurIPS'23](https://github.com/sail-research/iba) 150 | 151 | -------------------------------------------------------------------------------- /attacks/attack.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from synthesizers.synthesizer import Synthesizer 7 | from attacks.loss_functions import compute_all_losses_and_grads 8 | from utils.parameters import Params 9 | import math 10 | logger = logging.getLogger('logger') 11 | 12 | 13 | class Attack: 14 | params: Params 15 | synthesizer: Synthesizer 16 | local_dataset: DataLoader 17 | loss_tasks: List[str] 18 | fixed_scales: Dict[str, float] 19 | ignored_weights = ['num_batches_tracked']#['tracked', 'running'] 20 | 21 | def __init__(self, params, synthesizer): 22 | self.params = params 23 | self.synthesizer = synthesizer 24 | self.loss_tasks = ['normal', 'backdoor'] 25 | self.fixed_scales = {'normal':0.5, 'backdoor':0.5} 26 | 27 | def perform_attack(self, _) -> None: 28 | raise NotImplementedError 29 | 30 | def compute_blind_loss(self, model, criterion, batch, attack, fixed_model=None): 31 | """Compute the weighted (blind) loss over the active loss tasks. 32 | 33 | :param model: local model being trained 34 | :param criterion: per-sample loss criterion 35 | :param batch: clean training batch 36 | :param attack: if False, do not attack at all and 
ignore the backdoor task; only the normal loss is computed 37 | :return: weighted sum of the per-task losses 38 | """ 39 | batch = batch.clip(self.params.clip_batch) 40 | loss_tasks = self.loss_tasks.copy() if attack else ['normal'] 41 | batch_back = self.synthesizer.make_backdoor_batch(batch, attack=attack) 42 | scale = dict() 43 | 44 | if len(loss_tasks) == 1: 45 | loss_values = compute_all_losses_and_grads( 46 | loss_tasks, 47 | self, model, criterion, batch, batch_back 48 | ) 49 | else: 50 | loss_values = compute_all_losses_and_grads( 51 | loss_tasks, 52 | self, model, criterion, batch, batch_back, 53 | fixed_model = fixed_model) 54 | 55 | for t in loss_tasks: 56 | scale[t] = self.fixed_scales[t] 57 | 58 | if len(loss_tasks) == 1: 59 | scale = {loss_tasks[0]: 1.0} 60 | blind_loss = self.scale_losses(loss_tasks, loss_values, scale) 61 | 62 | return blind_loss 63 | 64 | def scale_losses(self, loss_tasks, loss_values, scale): 65 | blind_loss = 0 66 | # import IPython; IPython.embed() 67 | # exit(0) 68 | for it, t in enumerate(loss_tasks): 69 | self.params.running_losses[t].append(loss_values[t].item()) 70 | self.params.running_scales[t].append(scale[t]) 71 | if it == 0: 72 | blind_loss = scale[t] * loss_values[t] 73 | else: 74 | blind_loss += scale[t] * loss_values[t] 75 | self.params.running_losses['total'].append(blind_loss.item()) 76 | return blind_loss 77 | 78 | def scale_update(self, local_update: Dict[str, torch.Tensor], gamma): 79 | for name, value in local_update.items(): 80 | value.mul_(gamma) 81 | 82 | def get_fl_update(self, local_model, global_model) -> Dict[str, torch.Tensor]: 83 | local_update = dict() 84 | for name, data in local_model.state_dict().items(): 85 | if self.check_ignored_weights(name): 86 | continue 87 | local_update[name] = (data - global_model.state_dict()[name]) 88 | 89 | return local_update 90 | 91 | def check_ignored_weights(self, name) -> bool: 92 | for ignored in self.ignored_weights: 93 | if ignored in name: 94 | return True 95 | 96 | return False 97 | 98 | def get_update_norm(self, local_update): 99 | squared_sum = 0 100 | for name, value in local_update.items(): 101 | if 'tracked' in name or 'running' in name: 102 | continue 103 | squared_sum += torch.sum(torch.pow(value, 2)).item() 104 | update_norm = math.sqrt(squared_sum) 105 | return update_norm -------------------------------------------------------------------------------- /attacks/loss_functions.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch.nn import functional as F, Module 5 | 6 | from models.model import Model 7 | from utils.parameters import Params 8 | from utils.utils import record_time 9 | 10 | def compute_all_losses_and_grads(loss_tasks, attack, model, criterion, 11 | batch, batch_back, 12 | fixed_model=None): 13 | loss_values = {} 14 | for t in loss_tasks: 15 | if t == 'normal': 16 | loss_values[t] = compute_normal_loss(attack.params, 17 | model, 18 | criterion, 19 | batch.inputs, 20 | batch.labels) 21 | elif t == 'backdoor': 22 | loss_values[t] = compute_backdoor_loss(attack.params, 23 | model, 24 | criterion, 25 | batch_back.inputs, 26 | batch_back.labels) 27 | elif t == 'eu_constraint': 28 | loss_values[t] = compute_euclidean_loss(attack.params, 29 | model, 30 | fixed_model) 31 | elif t == 'cs_constraint': 32 | loss_values[t] = compute_cos_sim_loss(attack.params, 33 | model, 34 | fixed_model) 35 | 36 | return loss_values 37 | 38 | 39 | def compute_normal_loss(params: Params, model, criterion, inputs, labels): 40 | t = time.perf_counter() 41 | outputs = 
model(inputs) 42 | record_time(params, t, 'forward') 43 | loss = criterion(outputs, labels) 44 | loss = loss.mean() 45 | 46 | return loss 47 | 48 | def compute_backdoor_loss(params, model, criterion, inputs_back, labels_back): 49 | t = time.perf_counter() 50 | outputs = model(inputs_back) 51 | record_time(params, t, 'forward') 52 | loss = criterion(outputs, labels_back) 53 | loss = loss.mean() 54 | 55 | return loss 56 | 57 | def compute_euclidean_loss(params: Params, 58 | model: Model, 59 | fixed_model: Model): 60 | size = 0 61 | for name, layer in model.named_parameters(): 62 | size += layer.view(-1).shape[0] 63 | sum_var = torch.cuda.FloatTensor(size).fill_(0) 64 | size = 0 65 | for name, layer in model.named_parameters(): 66 | sum_var[size:size + layer.view(-1).shape[0]] = (layer - \ 67 | fixed_model.state_dict()[name]).view(-1) 68 | size += layer.view(-1).shape[0] 69 | 70 | loss = torch.norm(sum_var, p=2) 71 | 72 | return loss 73 | 74 | def get_one_vec(model: Module): 75 | size = 0 76 | for name, layer in model.named_parameters(): 77 | size += layer.view(-1).shape[0] 78 | sum_var = torch.cuda.FloatTensor(size).fill_(0) 79 | size = 0 80 | for name, layer in model.named_parameters(): 81 | sum_var[size:size + layer.view(-1).shape[0]] = (layer.data).view(-1) 82 | size += layer.view(-1).shape[0] 83 | 84 | return sum_var 85 | 86 | def compute_cos_sim_loss(params: Params, 87 | model: Model, 88 | fixed_model: Model): 89 | model_vec = get_one_vec(model) 90 | target_var = get_one_vec(fixed_model) 91 | cs_sim = F.cosine_similarity(params.fl_weight_scale*(model_vec-target_var)\ 92 | + target_var, target_var, dim=0) 93 | loss = 1e3 * (1 - cs_sim) 94 | return loss 95 | 96 | def compute_noise_ups_loss(params: Params, 97 | backdoor_update, 98 | noise_masks, 99 | random_neurons): 100 | losses = [] 101 | for i in range(len(noise_masks)): 102 | UPs = [] 103 | for j in random_neurons: 104 | if 'MNIST' not in params.task: 105 | UPs.append(torch.abs(backdoor_update['fc.weight'][j] + \ 106 | noise_masks[i].fc.weight[j]).sum() \ 107 | + torch.abs(backdoor_update['fc.bias'][j] + \ 108 | noise_masks[i].fc.bias[j])) 109 | else: 110 | UPs.append(torch.abs(backdoor_update['fc2.weight'][j] + \ 111 | noise_masks[i].fc2.weight[j]).sum() \ 112 | + torch.abs(backdoor_update['fc2.bias'][j] + \ 113 | noise_masks[i].fc2.bias[j])) 114 | UPs_loss = 0 115 | for j in range(len(UPs)): 116 | if 'Imagenet' in params.task: 117 | UPs_loss += 5e-4 / UPs[j] 118 | else: 119 | UPs_loss += 1e-1 / UPs[j] # (UPs[j] * params.fl_num_neurons) 120 | noise_masks[i].requires_grad_(True) 121 | UPs_loss.requires_grad_(True) 122 | losses.append(UPs_loss) 123 | return losses 124 | 125 | def compute_noise_norm_loss(params: Params, 126 | noise_masks, 127 | random_neurons): 128 | size = 0 129 | layer_name = 'fc2' if 'MNIST' in params.task else 'fc' 130 | for name, layer in noise_masks[0].named_parameters(): 131 | if layer_name in name: 132 | size += layer.view(-1).shape[0] 133 | losses = [] 134 | for i in range(len(noise_masks)): 135 | sum_var = torch.cuda.FloatTensor(size).fill_(0) 136 | noise_size = 0 137 | for name, layer in noise_masks[i].named_parameters(): 138 | if layer_name in name: 139 | for j in range(layer.shape[0]): 140 | if j in random_neurons: 141 | sum_var[noise_size:noise_size + layer[j].view(-1).shape[0]] = \ 142 | layer[j].view(-1) 143 | noise_size += layer[j].view(-1).shape[0] 144 | if 'MNIST' in params.task: 145 | loss = 8e-2 * torch.norm(sum_var, p=2) 146 | else: 147 | loss = 3e-2 * torch.norm(sum_var, p=2) 148 | 
losses.append(loss) 149 | return losses 150 | 151 | def compute_lagrange_loss(params: Params, 152 | noise_masks, 153 | random_neurons): 154 | losses = [] 155 | size = 0 156 | layer_name = 'fc2' if 'MNIST' in params.task else 'fc' 157 | for name, layer in noise_masks[0].named_parameters(): 158 | if layer_name in name: 159 | size += layer.view(-1).shape[0] 160 | sum_var = torch.cuda.FloatTensor(size).fill_(0) 161 | for i in range(len(noise_masks)): 162 | size = 0 163 | for name, layer in noise_masks[i].named_parameters(): 164 | if layer_name in name: 165 | for j in range(layer.shape[0]): 166 | if j in random_neurons: 167 | sum_var[size:size + layer[j].view(-1).shape[0]] += \ 168 | layer[j].view(-1) 169 | size += layer[j].view(-1).shape[0] 170 | 171 | if 'MNIST' in params.task: 172 | loss = 1e-1 * torch.norm(sum_var, p=2) 173 | else: 174 | loss = 1e-2 * torch.norm(sum_var, p=2) 175 | for i in range(len(noise_masks)): 176 | losses.append(loss) 177 | return losses 178 | 179 | def compute_decoy_acc_loss(params: Params, 180 | benign_model: Model, 181 | decoy: Model, 182 | criterion, inputs, labels): 183 | dec_acc_loss = compute_normal_loss(params, decoy, criterion, \ 184 | inputs, labels) 185 | benign_acc_loss = compute_normal_loss(params, benign_model, criterion, \ 186 | inputs, labels) 187 | if dec_acc_loss > benign_acc_loss: 188 | loss = dec_acc_loss 189 | else: 190 | loss = - 1e-10 * (dec_acc_loss) 191 | 192 | return loss 193 | 194 | def compute_decoy_param_loss(params:Params, 195 | decoy: Model, 196 | benign_model: Model, 197 | param_idx): 198 | 199 | if 'MNIST' not in params.task: 200 | param_diff = torch.abs(decoy.fc.weight[param_idx[0]][param_idx[1]] - \ 201 | benign_model.state_dict()['fc.weight'][param_idx[0]][param_idx[1]]) 202 | else: 203 | param_diff = torch.abs(decoy.fc1.weight[param_idx[0]][param_idx[1]] - \ 204 | benign_model.state_dict()['fc1.weight'][param_idx[0]][param_idx[1]]) 205 | 206 | threshold = 10 # 30 207 | if param_diff.item() > threshold: 208 | loss = 1e-10 * param_diff 209 | else: 210 | loss = - 1e1 * param_diff 211 | 212 | return loss 213 | 214 | def get_grads(params, model, loss): 215 | t = time.perf_counter() 216 | grads = list(torch.autograd.grad(loss.mean(), 217 | [x for x in model.parameters() if 218 | x.requires_grad], 219 | retain_graph=True, allow_unused=True)) 220 | record_time(params, t, 'backward') 221 | 222 | return grads -------------------------------------------------------------------------------- /attacks/modelreplace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from attack import Attack 3 | from attacks.attack import Attack 4 | from tasks.fl_user import FLUser 5 | import logging 6 | logger = logging.getLogger('logger') 7 | 8 | class ModelReplace(Attack): 9 | 10 | def __init__(self, params, synthesizer): 11 | super().__init__(params, synthesizer) 12 | self.loss_tasks.append('cs_constraint') 13 | self.fixed_scales = {'normal':0.3, 14 | 'backdoor':0.3, 15 | 'cs_constraint':0.4} 16 | 17 | def perform_attack(self, _, user: FLUser, epoch): 18 | if self.params.fl_number_of_adversaries <= 0 or \ 19 | epoch not in range(self.params.poison_epoch,\ 20 | self.params.poison_epoch_stop): 21 | return 22 | 23 | folder_name = f'{self.params.folder_path}/saved_updates' 24 | file_name = f'{folder_name}/update_{user.user_id}.pth' 25 | loaded_params = torch.load(file_name) 26 | logger.info(f"Loaded update user id: {user.user_id} from {file_name} with scale {self.params.fl_weight_scale}") 27 | 
self.scale_update(loaded_params, self.params.fl_weight_scale) 28 | torch.save(loaded_params, file_name) 29 | 30 | # for i in range(self.params.fl_number_of_adversaries): 31 | # file_name = f'{folder_name}/update_{i}.pth' 32 | # torch.save(loaded_params, file_name) -------------------------------------------------------------------------------- /defenses/fedavg.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Any, Dict 3 | import torch 4 | import logging 5 | import os 6 | from utils.parameters import Params 7 | 8 | logger = logging.getLogger('logger') 9 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 10 | 11 | class FedAvg: 12 | params: Params 13 | ignored_weights = ['num_batches_tracked']#['tracked', 'running'] 14 | 15 | def __init__(self, params: Params) -> None: 16 | self.params = params 17 | 18 | # FedAvg aggregation 19 | def aggr(self, weight_accumulator, _): 20 | # logger.info(f"Aggregating {len(self.params.fl_weight_contribution.keys())} participants") 21 | 22 | # weight_contrib = self.params.fl_weight_contribution 23 | 24 | for user_id, weight_contrib_user in self.params.fl_weight_contribution.items(): 25 | # logger.info(f"Aggregating participant: {user_id} with weight: {weight_contrib_user}") 26 | loaded_params = self.params.fl_local_updated_models[user_id] 27 | self.accumulate_weights(weight_accumulator, \ 28 | {key:(loaded_params[key] * weight_contrib_user ).to(self.params.device) for \ 29 | key in loaded_params}) 30 | 31 | # for idRound, userID in enumerate(self.params.fl_round_participants): 32 | # weight_contrib_user = self.params.fl_weight_contribution[userID] 33 | # # updates_name = '{0}/saved_updates/update_{1}.pth'\ 34 | # # .format(self.params.folder_path, userID) 35 | # # # logger.info(f"Aggregating participant {userID} path: {updates_name}") 36 | 37 | # # loaded_params = torch.load(updates_name) 38 | # loaded_params = self.params.fl_local_updated_models[userID] 39 | # self.accumulate_weights(weight_accumulator, \ 40 | # {key:(loaded_params[key] * weight_contrib_user ).to(self.params.device) for \ 41 | # key in loaded_params}) 42 | 43 | def accumulate_weights(self, weight_accumulator, local_update): 44 | for name, value in local_update.items(): 45 | weight_accumulator[name].add_(value) 46 | 47 | def get_update_norm(self, local_update): 48 | squared_sum = 0 49 | for name, value in local_update.items(): 50 | if 'tracked' in name or 'running' in name: 51 | continue 52 | squared_sum += torch.sum(torch.pow(value, 2)).item() 53 | update_norm = math.sqrt(squared_sum) 54 | return update_norm 55 | 56 | def add_noise(self, sum_update_tensor: torch.Tensor, sigma): 57 | noised_layer = torch.FloatTensor(sum_update_tensor.shape) 58 | noised_layer = noised_layer.to(self.params.device) 59 | noised_layer.normal_(mean=0, std=sigma) 60 | sum_update_tensor.add_(noised_layer) 61 | 62 | def check_ignored_weights(self, name) -> bool: 63 | for ignored in self.ignored_weights: 64 | if ignored in name: 65 | return True 66 | 67 | return False -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: aba # Name of the Conda environment 2 | 3 | channels: 4 | - defaults # You can add other channels if needed 5 | 6 | dependencies: 7 | - python=3.8 # Python version 8 | - numpy=1.21.2 # Example package with version 9 | - pandas=1.3.3 # Another package with a specific version 10 | - pytorch=1.9.0 
# PyTorch version 11 | - torchvision=0.10.0 # Torchvision version 12 | - pip # Install pip 13 | - pip: 14 | # - torchattacks==3.1.0 # Install torchattacks 15 | # - torchsummary==1.5.1 # Install torchsummary 16 | # - torchmetrics==0.5.1 # Install torchmetrics 17 | # - torch_optimizer==0.1.0 # Install torch_optimizer 18 | - tqdm==4.62.3 # Install tqdm 19 | - matplotlib==3.4.3 # Install matplotlib 20 | - seaborn==0.11.2 # Install seaborn 21 | # - scikit-learn==0.24.2 # Install scikit-learn 22 | # - scikit-image==0.18.3 # Install scikit-image 23 | # - opencv-python== -------------------------------------------------------------------------------- /exps/cifar_fed.yaml: -------------------------------------------------------------------------------- 1 | task: Cifar10 2 | synthesizer: IBA 3 | 4 | random_seed: 42 5 | batch_size: 256 6 | test_batch_size: 512 7 | lr: 0.005 8 | momentum: 0.9 9 | decay: 0.0005 10 | epochs: 1500 11 | poison_epoch: 0 12 | poison_epoch_stop: 1500 13 | save_on_epochs: [] # [30, 50, 80, 100, 120, 150, 170, 200] 14 | optimizer: SGD 15 | log_interval: 100 16 | 17 | poisoning_proportion: 0.5 18 | backdoor_label: 8 19 | 20 | resume_model: saved_models/cifar_pretrain/model_last.pt.tar.epoch_200 21 | 22 | save_model: True 23 | log: True 24 | 25 | transform_train: True 26 | 27 | fl: True 28 | fl_total_participants: 100 29 | fl_no_models: 10 30 | fl_local_epochs: 2 31 | fl_poison_epochs: 5 32 | # fl_poison_epochs: 15 33 | fl_eta: 1 # 0.8`3 34 | fl_sample_dirichlet: True 35 | fl_dirichlet_alpha: 0.5 36 | 37 | fl_number_of_adversaries: 4 38 | fl_weight_scale: 1 39 | fl_adv_group_size: 2 40 | # fl_single_epoch_attack: 200 41 | 42 | attack: ModelReplace 43 | defense: 'FedAvg' 44 | fl_num_neurons: 5 45 | noise_mask_alpha: 0 # 0.5 46 | lagrange_step: 0.1 -------------------------------------------------------------------------------- /exps/gen-run-yaml-file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | str_write_to_file = "" 5 | 6 | def generate_exps_file(root_file='./exps/cifar_fed.yaml', name_exp = 'cifar10', EXPS_DIR="./exps/extras"): 7 | # read file as a string 8 | global str_write_to_file 9 | print(f'reading from root_file: {root_file}') 10 | str_write_to_file += f'# Dataset: {name_exp} \n**Reading default config from root_file: {root_file}**\n\n' 11 | str_write_to_file += '------------------------\n\n' 12 | 13 | fl_total_participants_choices = [100, 200] 14 | fl_no_models_choices = [10, 20] 15 | fl_dirichlet_alpha_choices = [0.5] 16 | fl_number_of_adversaries_choices = [4] 17 | fl_lr_choices = [0.05] 18 | resume_model_choices = [False, True] 19 | lis_resume_model = [ 20 | 'resume_model: saved_models/tiny_64_pretrain/tiny-resnet.epoch_20', 21 | 'resume_model: saved_models/mnist_pretrain/model_last.pt.tar.epoch_10', 22 | 'resume_model: saved_models/cifar_pretrain/model_last.pt.tar.epoch_200', 23 | ] 24 | 25 | # EXPS_DIR = './exps/extras' 26 | 27 | os.makedirs(EXPS_DIR, exist_ok=True) 28 | exp_number = 0 29 | 30 | for fl_total_participants in fl_total_participants_choices: 31 | for fl_no_models in fl_no_models_choices: 32 | for fl_dirichlet_alpha in fl_dirichlet_alpha_choices: 33 | for fl_number_of_adversaries in fl_number_of_adversaries_choices: 34 | for fl_lr in fl_lr_choices: 35 | for resume_model in resume_model_choices: 36 | 37 | with open(root_file, 'r') as file : 38 | filedata = file.read() 39 | 40 | exp_number += 1 41 | str_write_to_file += f'## EXP ID: {exp_number:02d}\n' 42 | 
pretrained_str = 'pretrained' if resume_model else 'no_pretrained' 43 | print(f'`fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr} resume_model: {resume_model}`\n\n') 44 | str_write_to_file += f'fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr} resume_model: {resume_model}\n\n' 45 | filedata = filedata.replace('fl_total_participants: 100', f'fl_total_participants: {fl_total_participants}') 46 | filedata = filedata.replace('fl_no_models: 10', f'fl_no_models: {fl_no_models}') 47 | filedata = filedata.replace('fl_dirichlet_alpha: 0.5', f'fl_dirichlet_alpha: {fl_dirichlet_alpha}') 48 | filedata = filedata.replace('fl_number_of_adversaries: 4', f'fl_number_of_adversaries: {fl_number_of_adversaries}') 49 | filedata = filedata.replace('lr: 0.005', f'lr: {fl_lr}') 50 | # print(filedata) 51 | 52 | if not resume_model: 53 | for resume_model_path in lis_resume_model: 54 | if resume_model_path in filedata: 55 | filedata = filedata.replace(resume_model_path, f'resume_model: \n') 56 | # print?(filedata) 57 | # exit(0) 58 | # print(len(filedata), type(filedata)) 59 | # print('------------------------') 60 | # write the file out again 61 | 62 | fn_write = f'{EXPS_DIR}/{name_exp}_fed_{fl_total_participants}_{fl_no_models}_{fl_number_of_adversaries}_{fl_dirichlet_alpha}_{fl_lr}_{pretrained_str}.yaml' 63 | 64 | if not os.path.exists(fn_write): 65 | with open(fn_write, 'w') as file: 66 | file.write(filedata) 67 | 68 | cmd = f'```bash\nCUDA_VISIBLE_DEVICES=0 python training.py --name {name_exp} --params {fn_write}\n```\n' 69 | print(cmd) 70 | str_write_to_file += f'{cmd}\n' 71 | print('------------------------') 72 | # if exp_number == 2: 73 | # exit(0) 74 | 75 | 76 | current_time = datetime.now().strftime('%Y.%b.%d') 77 | 78 | generate_exps_file(root_file='./exps/cifar_fed.yaml', 79 | name_exp = 'cifar10', 80 | EXPS_DIR=f"./exps/run_cifar10__{current_time}") 81 | 82 | 83 | generate_exps_file(root_file='./exps/mnist_fed.yaml', 84 | name_exp = 'mnist', 85 | EXPS_DIR=f"./exps/run_mnist__{current_time}") 86 | 87 | generate_exps_file(root_file='./exps/imagenet_fed.yaml', 88 | name_exp = 'tiny-imagenet', 89 | EXPS_DIR=f"./exps/run_tiny-imagenet__{current_time}") 90 | 91 | with open(f'./exps/run_exps__{current_time}.md', 'w') as file: 92 | file.write(str_write_to_file) 93 | 94 | -------------------------------------------------------------------------------- /exps/imagenet_fed.yaml: -------------------------------------------------------------------------------- 1 | task: Imagenet 2 | synthesizer: Pattern 3 | 4 | data_path: .data/tiny-imagenet-200 5 | 6 | random_seed: 42 7 | batch_size: 256 8 | test_batch_size: 512 9 | lr: 0.001 10 | momentum: 0.9 11 | decay: 0.0005 12 | epochs: 1500 13 | poison_epoch: 0 14 | poison_epoch_stop: 1500 15 | save_on_epochs: [] # [10, 20, 30, 40, 50] 16 | optimizer: SGD 17 | log_interval: 100 18 | 19 | poisoning_proportion: 0.3 20 | backdoor_label: 8 21 | 22 | resume_model: saved_models/tiny_64_pretrain/tiny-resnet.epoch_20 23 | 24 | save_model: True 25 | log: True 26 | report_train_loss: False 27 | 28 | transform_train: True 29 | 30 | fl: True 31 | fl_no_models: 10 32 | fl_local_epochs: 2 33 | fl_poison_epochs: 10 34 | fl_total_participants: 100 35 | fl_eta: 1 36 | fl_sample_dirichlet: True 37 | fl_dirichlet_alpha: 0.5 38 
| 39 | fl_number_of_adversaries: 4 40 | # fl_number_of_scapegoats: 0 41 | fl_weight_scale: 5 42 | fl_adv_group_size: 5 43 | # fl_single_epoch_attack: 10 44 | 45 | attack: ModelReplace 46 | defense: FedAvg 47 | fl_num_neurons: 100 48 | noise_mask_alpha: 0 # 0.5 49 | lagrange_step: 0.1 -------------------------------------------------------------------------------- /exps/list_exps__2023.Nov.25.md: -------------------------------------------------------------------------------- 1 | # Dataset: cifar10 2 | **Reading default config from root_file: ./exps/cifar_fed.yaml** 3 | 4 | ------------------------ 5 | 6 | ## EXP ID: 01 7 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 8 | 9 | ```bash 10 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_10_4_0.5_0.05_no_pretrained.yaml 11 | ``` 12 | 13 | ## EXP ID: 02 14 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 15 | 16 | ```bash 17 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_10_4_0.5_0.05_pretrained.yaml 18 | ``` 19 | 20 | ## EXP ID: 03 21 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 22 | 23 | ```bash 24 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_20_4_0.5_0.05_no_pretrained.yaml 25 | ``` 26 | 27 | ## EXP ID: 04 28 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 29 | 30 | ```bash 31 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_20_4_0.5_0.05_pretrained.yaml 32 | ``` 33 | 34 | ## EXP ID: 05 35 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 36 | 37 | ```bash 38 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_10_4_0.5_0.05_no_pretrained.yaml 39 | ``` 40 | 41 | ## EXP ID: 06 42 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 43 | 44 | ```bash 45 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_10_4_0.5_0.05_pretrained.yaml 46 | ``` 47 | 48 | ## EXP ID: 07 49 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 50 | 51 | ```bash 52 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_20_4_0.5_0.05_no_pretrained.yaml 53 | ``` 54 | 55 | ## EXP ID: 08 56 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 57 | 58 | ```bash 59 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_20_4_0.5_0.05_pretrained.yaml 60 | ``` 61 | 62 | # Dataset: mnist 63 | **Reading default config from root_file: ./exps/mnist_fed.yaml** 64 | 65 | ------------------------ 66 | 67 | ## EXP ID: 01 68 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: 
False 69 | 70 | ```bash 71 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_10_4_0.5_0.05_no_pretrained.yaml 72 | ``` 73 | 74 | ## EXP ID: 02 75 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 76 | 77 | ```bash 78 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_10_4_0.5_0.05_pretrained.yaml 79 | ``` 80 | 81 | ## EXP ID: 03 82 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 83 | 84 | ```bash 85 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_20_4_0.5_0.05_no_pretrained.yaml 86 | ``` 87 | 88 | ## EXP ID: 04 89 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 90 | 91 | ```bash 92 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_20_4_0.5_0.05_pretrained.yaml 93 | ``` 94 | 95 | ## EXP ID: 05 96 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 97 | 98 | ```bash 99 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_10_4_0.5_0.05_no_pretrained.yaml 100 | ``` 101 | 102 | ## EXP ID: 06 103 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 104 | 105 | ```bash 106 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_10_4_0.5_0.05_pretrained.yaml 107 | ``` 108 | 109 | ## EXP ID: 07 110 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 111 | 112 | ```bash 113 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_20_4_0.5_0.05_no_pretrained.yaml 114 | ``` 115 | 116 | ## EXP ID: 08 117 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 118 | 119 | ```bash 120 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_20_4_0.5_0.05_pretrained.yaml 121 | ``` 122 | 123 | # Dataset: tiny-imagenet 124 | **Reading default config from root_file: ./exps/imagenet_fed.yaml** 125 | 126 | ------------------------ 127 | 128 | ## EXP ID: 01 129 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 130 | 131 | ```bash 132 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_10_4_0.5_0.05_no_pretrained.yaml 133 | ``` 134 | 135 | ## EXP ID: 02 136 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 137 | 138 | ```bash 139 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_10_4_0.5_0.05_pretrained.yaml 140 | ``` 141 | 142 | ## EXP ID: 03 143 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 144 | 145 | ```bash 146 | 
CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_20_4_0.5_0.05_no_pretrained.yaml 147 | ``` 148 | 149 | ## EXP ID: 04 150 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 151 | 152 | ```bash 153 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_20_4_0.5_0.05_pretrained.yaml 154 | ``` 155 | 156 | ## EXP ID: 05 157 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 158 | 159 | ```bash 160 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_10_4_0.5_0.05_no_pretrained.yaml 161 | ``` 162 | 163 | ## EXP ID: 06 164 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 165 | 166 | ```bash 167 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_10_4_0.5_0.05_pretrained.yaml 168 | ``` 169 | 170 | ## EXP ID: 07 171 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False 172 | 173 | ```bash 174 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_20_4_0.5_0.05_no_pretrained.yaml 175 | ``` 176 | 177 | ## EXP ID: 08 178 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True 179 | 180 | ```bash 181 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_20_4_0.5_0.05_pretrained.yaml 182 | ``` 183 | 184 | -------------------------------------------------------------------------------- /exps/mnist_fed.yaml: -------------------------------------------------------------------------------- 1 | task: MNIST 2 | synthesizer: Pattern 3 | 4 | random_seed: 42 5 | batch_size: 256 6 | test_batch_size: 512 7 | lr: 0.01 8 | momentum: 0.9 9 | decay: 0.0005 10 | epochs: 1500 11 | poison_epoch: 0 12 | poison_epoch_stop: 1500 13 | save_on_epochs: [] # [10, 20, 30, 40, 50] 14 | optimizer: SGD 15 | log_interval: 100 16 | 17 | poisoning_proportion: 1.0 18 | backdoor_label: 8 19 | 20 | resume_model: saved_models/mnist_pretrain/model_last.pt.tar.epoch_10 21 | 22 | save_model: True 23 | log: True 24 | report_train_loss: False 25 | 26 | transform_train: True 27 | 28 | fl: True 29 | fl_no_models: 10 30 | fl_local_epochs: 2 31 | fl_poison_epochs: 5 32 | fl_total_participants: 100 33 | fl_eta: 1 34 | fl_sample_dirichlet: True 35 | fl_dirichlet_alpha: 0.5 36 | 37 | fl_number_of_adversaries: 4 38 | fl_weight_scale: 5 39 | fl_adv_group_size: 2 40 | # fl_single_epoch_attack: 20 41 | 42 | attack: ModelReplace 43 | defense: FedAvg 44 | fl_num_neurons: 8 45 | noise_mask_alpha: 0 # 0.5 46 | lagrange_step: 0.1 -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | import random 5 | from shutil import copyfile 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | import yaml 11 
| 12 | from attacks.attack import Attack 13 | # from defenses.fedavg import FedAvg as Defense 14 | from synthesizers.synthesizer import Synthesizer 15 | from tasks.task import Task 16 | from utils.parameters import Params 17 | from utils.utils import create_logger 18 | import pandas as pd 19 | logger = logging.getLogger('logger') 20 | 21 | 22 | class Helper: 23 | params: Params = None 24 | task: Task = None 25 | synthesizer: Synthesizer = None 26 | # defense: Defense = None 27 | attack: Attack = None 28 | 29 | def __init__(self, params): 30 | 31 | self.params = Params(**params) 32 | self.times = {'backward': list(), 'forward': list(), 'step': list(), 33 | 'scales': list(), 'total': list(), 'poison': list()} 34 | 35 | if self.params.random_seed is not None: 36 | self.fix_random(self.params.random_seed) 37 | 38 | self.make_folders() 39 | 40 | self.make_task() 41 | 42 | self.make_synthesizer() 43 | 44 | self.make_attack() 45 | 46 | self.make_defense() 47 | 48 | self.accuracy = [[],[]] 49 | 50 | self.best_loss = float('inf') 51 | 52 | def make_task(self): 53 | name_lower = self.params.task.lower() 54 | name_cap = self.params.task 55 | module_name = f'tasks.{name_lower}_task' 56 | path = f'tasks/{name_lower}_task.py' 57 | logger.info(f'make task: {module_name} name_cap: {name_cap} path: {path}') 58 | try: 59 | task_module = importlib.import_module(module_name) 60 | task_class = getattr(task_module, f'{name_cap}Task') 61 | except (ModuleNotFoundError, AttributeError): 62 | raise ModuleNotFoundError(f'Your task: {self.params.task} should ' 63 | f'be defined as a class ' 64 | f'{name_cap}' 65 | f'Task in {path}') 66 | self.task = task_class(self.params) 67 | 68 | def make_synthesizer(self): 69 | name_lower = self.params.synthesizer.lower() 70 | name_cap = self.params.synthesizer 71 | module_name = f'synthesizers.{name_lower}_synthesizer' 72 | logger.info(f'make synthesizer: {module_name} name_cap: {name_cap}') 73 | try: 74 | synthesizer_module = importlib.import_module(module_name) 75 | synthesizer_class = getattr(synthesizer_module, f'{name_cap}Synthesizer') 76 | except (ModuleNotFoundError, AttributeError): 77 | raise ModuleNotFoundError( 78 | f'The synthesizer: {self.params.synthesizer}' 79 | f' should be defined as a class ' 80 | f'{name_cap}Synthesizer in ' 81 | f'synthesizers/{name_lower}_synthesizer.py') 82 | self.synthesizer = synthesizer_class(self.task) 83 | 84 | def make_attack(self): 85 | name_lower = self.params.attack.lower() 86 | name_cap = self.params.attack 87 | module_name = f'attacks.{name_lower}' 88 | logger.info(f'make attack: {module_name} name_cap: {name_cap}') 89 | try: 90 | attack_module = importlib.import_module(module_name) 91 | attack_class = getattr(attack_module, f'{name_cap}') 92 | except (ModuleNotFoundError, AttributeError): 93 | raise ModuleNotFoundError(f'Your attack: {self.params.attack} should ' 94 | f'be defined as either ThrDFed (3DFed) or \ 95 | ModelReplace (Model Replacement Attack)') 96 | self.attack = attack_class(self.params, self.synthesizer) 97 | 98 | def make_defense(self): 99 | name_lower = self.params.defense.lower() 100 | name_cap = self.params.defense 101 | module_name = f'defenses.{name_lower}' 102 | try: 103 | defense_module = importlib.import_module(module_name) 104 | defense_class = getattr(defense_module, f'{name_cap}') 105 | except (ModuleNotFoundError, AttributeError): 106 | raise ModuleNotFoundError(f'Your defense: {self.params.defense} should ' 107 | f'be one of the following: FLAME, Deepsight, \ 108 | Foolsgold, FLDetector, RFLBAT, FedAvg') 109 | 
self.defense = defense_class(self.params) 110 | 111 | def make_folders(self): 112 | log = create_logger() 113 | if self.params.log: 114 | os.makedirs(self.params.folder_path, exist_ok=True) 115 | 116 | fh = logging.FileHandler( 117 | filename=f'{self.params.folder_path}/log.txt') 118 | formatter = logging.Formatter('%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s') 119 | 120 | fh.setFormatter(formatter) 121 | log.addHandler(fh) 122 | 123 | with open(f'{self.params.folder_path}/params.yaml.txt', 'w') as f: 124 | yaml.dump(self.params, f) 125 | # log.info(f"Creating folder {self.params.folder_path}") 126 | 127 | 128 | def save_model(self, model=None, epoch=0, val_loss=0): 129 | 130 | if self.params.save_model: 131 | logger.info(f"Saving model to {self.params.folder_path}.") 132 | model_name = '{0}/model_last.pt.tar'.format(self.params.folder_path) 133 | saved_dict = {'state_dict': model.state_dict(), 134 | 'epoch': epoch, 135 | 'lr': self.params.lr, 136 | 'params_dict': self.params.to_dict()} 137 | self.save_checkpoint(saved_dict, False, model_name) 138 | 139 | if epoch in self.params.save_on_epochs: 140 | logger.info(f'Saving model on epoch {epoch}') 141 | self.save_checkpoint(saved_dict, False, 142 | filename=f'{self.params.folder_path}/model_epoch_{epoch}.pt.tar') 143 | if val_loss < self.best_loss: 144 | self.save_checkpoint(saved_dict, False, f'{model_name}_best') 145 | self.best_loss = val_loss 146 | 147 | logger.info(f"Done saving model to {self.params.folder_path}.") 148 | def save_update(self, model=None, userID = 0): 149 | folderpath = '{0}/saved_updates'.format(self.params.folder_path) 150 | # logger.info(f"Saving update to {folderpath}.") 151 | 152 | # if not os.path.exists(folderpath): 153 | # os.makedirs(folderpath)] 154 | # import IPython; IPython.embed() 155 | # exit(0) 156 | os.makedirs(folderpath, exist_ok=True) 157 | update_name = '{0}/update_{1}.pth'.format(folderpath, userID) 158 | torch.save(model, update_name) 159 | # logger.info(f"Saving update to {update_name}.") 160 | 161 | def remove_update(self): 162 | for i in range(self.params.fl_total_participants): 163 | file_name = '{0}/saved_updates/update_{1}.pth'.format(self.params.folder_path, i) 164 | if os.path.exists(file_name): 165 | os.remove(file_name) 166 | os.rmdir('{0}/saved_updates'.format(self.params.folder_path)) 167 | if self.params.defense == 'Foolsgold': 168 | for i in range(self.params.fl_total_participants): 169 | file_name = '{0}/foolsgold/history_{1}.pth'.format(self.params.folder_path, i) 170 | if os.path.exists(file_name): 171 | os.remove(file_name) 172 | os.rmdir('{0}/foolsgold'.format(self.params.folder_path)) 173 | 174 | def record_accuracy(self, main_acc, backdoor_acc, epoch): 175 | self.accuracy[0].append(main_acc) 176 | self.accuracy[1].append(backdoor_acc) 177 | name = ['main', 'backdoor'] 178 | acc_frame = pd.DataFrame(columns=name, data=zip(*self.accuracy), 179 | index=range(self.params.start_epoch, epoch+1)) 180 | filepath = f"{self.params.folder_path}/accuracy.csv" 181 | acc_frame.to_csv(filepath) 182 | logger.info(f"Saving accuracy record to {filepath}") 183 | 184 | def set_bn_eval(m): 185 | classname = m.__class__.__name__ 186 | if classname.find('BatchNorm') != -1: 187 | m.eval() 188 | 189 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 190 | if not self.params.save_model: 191 | return False 192 | torch.save(state, filename) 193 | 194 | if is_best: 195 | copyfile(filename, 'model_best.pth.tar') 196 | 197 | @staticmethod 198 | def 
fix_random(seed=1): 199 | from torch.backends import cudnn 200 | 201 | logger.warning('Setting random seed for reproducible results.') 202 | random.seed(seed) 203 | torch.manual_seed(seed) 204 | torch.cuda.manual_seed_all(seed) 205 | cudnn.deterministic = False 206 | cudnn.enabled = True 207 | cudnn.benchmark = True 208 | np.random.seed(seed) 209 | 210 | return True 211 | -------------------------------------------------------------------------------- /markdown-explanation/FedMLSec.md: -------------------------------------------------------------------------------- 1 | Table of Contents 2 | ----------------- 3 | 4 | 5 | # Setup Python Environment 6 | - Set the name of the environment in both files: `environment.yml` and `Makefile`. The default name is `aba`, aka "all backdoor attacks". Then run the following command: 7 | ``` 8 | make install 9 | ``` 10 | 11 | # Guideline for custom training (Atk vs Def) 12 | 13 | ## Data Customization 14 | - [Data Loader](https://github.com/FedML-AI/FedML/blob/master/doc/en/simulation/user_guide/data_loader_customization.md) 15 | 16 | ## Datasets and Models Customization 17 | - [Datasets and Models](https://github.com/FedML-AI/FedML/blob/master/doc/en/simulation/user_guide/datasets-and-models.md#datasets-and-models) 18 | - [FedML Data](https://github.com/FedML-AI/FedML/tree/master/python/fedml/data) 19 | 20 | ## Attack Customization 21 | ### [Model Replacement Attack (MRA)](https://arxiv.org/pdf/1807.00459.pdf) 22 | - [Code](https://github.com/ebagdasa/backdoors101) is in `fedlearn-backdoor-attacks/3DFed/attacks/modelreplace.py` 23 | ![Alt text](./model-replacement-attack.png) 24 | 25 | 26 | ## Defense Customization 27 | 28 | ## Training Customization 29 | 30 | ## Evaluation Customization 31 | 32 | ## Visualization Customization 33 | 34 | ## Result Customization 35 | 36 | 37 | ## Flow of the code: 38 | ### [3DFed (ThrDFed) - S&P'23](https://github.com/haoyangliASTAPLE/3DFed) 39 | - Init data, attack, and defense methods in `fedlearn-backdoor-attacks/3DFed/helper.py` 40 | - Run FL rounds in `fedlearn-backdoor-attacks/3DFed/training.py` 41 | - Sample users for each round (fl_no_models out of fl_total_participants) 42 | - Init FLUser (user_id, compromised, train_loader) 43 | - Single epoch attack: user_id = 0 is the attacker, compromised = True 44 | - Otherwise: check if the epoch is in attack_epochs; if yes, check the list of adversaries 45 | - Training for each user 46 | - If the user is an attacker, run the attack only for user_id = 0 [missing other attackers], i.e., train the model on poisoned data 47 | - Otherwise, run defense for all users 48 | - Perform attack and aggregate results 49 | - check update_global_model with weight: 1/total_participants (self.params.fl_eta / self.params.fl_total_participants) 50 | - Limitation: currently fl_no_models (set = fl_total_participants) models are dumped to file in each round, and only one attacker is supported (the other attackers' updates are duplicated from attacker 0) 51 | ### [DBA - ICLR'20 (Distributed Backdoor Attack)](https://github.com/AI-secure/DBA) 52 | - List of adversarial clients in the config file: adversary_list: [17, 33, 77, 11] 53 | - no_models: K = 10 54 | - number_of_total_participants: N = 100 55 | - agent_name_keys: ids of the clients that participate in each round (K out of N) 56 | - agent_name_keys = benign + adv (check adv first, then benign); the attack epoch for each client is set in the .yaml file 57 | In each communication round: 58 | - Check whether the trigger is a global or a local pattern 59 | - Get the poisoned batch data from adversary_index (-1, 0, 1, 2, 3) 60 | 61 | 62 | ### 
62 | ### [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning - NeurIPS'20](https://github.com/ksreenivasan/OOD_Federated_Learning)
63 | 
64 | Load the poisoned dataset (in `simulated_averaging.py`):
65 | ```python
66 | poisoned_train_loader, vanilla_test_loader, targetted_task_test_loader, num_dps_poisoned_dataset, clean_train_loader = load_poisoned_dataset(args=args)
67 | ```
68 | 
69 | Two modes (fixed-freq mode or fixed-pool mode):
70 | ```
71 | In total: N (num_nets) clients, of which K (part_nets_per_round) participate in each round
72 | 
73 | - Fixed-freq Attack (`FrequencyFederatedLearningTrainer`):
74 |   - "attacking_fl_rounds": [i for i in range(1, args.fl_round + 1) if (i-1) % 10 == 0]
75 |   - poison_type in ["ardis": ["normal-case", "almost-edge-case", "edge-case"], "southwest"]
76 |   - Training in each communication round:
77 |     - if the round is in attacking_fl_rounds, run the attack
78 |       - One attacker, and K-1 benign clients
79 |       - For the attacker, run the attack for `adversarial_local_training_period` epochs
80 |         - data_loader = `poisoned_train_loader`
81 |         - Check defense_technique in ["krum", "multi-krum"]: eps=self.eps*self.args_gamma**(flr-1); else eps=self.eps
82 |         - Test on `vanilla_test_loader` and `targetted_task_test_loader`
83 |         - If model_replacement is set, scale the model with ratio: total_num_dps_per_round/num_dps_poisoned_dataset
84 |           - Print the norm before and after scaling
85 |       - For benign clients, run normal training for `local_training_period` epochs
86 |         - data_loader = `clean_train_loader`
87 |     - otherwise
88 |       - run normal training with K benign clients
89 |     - Aggregate models
90 |       - using
91 |   - TODO: check prox-attack
92 | 
93 | - test function:
94 |   - check dataset:
95 |     - dataset in ["mnist", "emnist"]:
96 |       - target_class = 7
97 |       - task in ["raw-task", "targeted"]
98 |       - poison_type in ["ardis"]
99 |     - dataset in ["cifar10"]:
100 |       - target_class = 2 for greencar in ("howto", "greencar-neo"),
101 |       - target_class = 9 for southwest
102 |   - TODO: check the backdoor accuracy calculation (lines 248-253)
103 | 
104 | - Fixed-pool Attack:
105 |   - In each communication round, randomly select K clients to participate in the training
106 |     - selected_attackers from "__attacker_pool" (numpy random choice, attacker_pool_size out of num_nets)
107 |     - selected_benign_clients
108 |   - All attackers share the same poisoned dataset
109 |   - Training is the same as the Fixed-freq Attack for each attacker
110 | 
111 | - defense_technique in ["no-defense", "norm-clipping", "weak-dp", "krum", "multi-krum", "rfa"]
112 | ```
113 | For more details, check the `FedML` [dataset](https://github.com/FedML-AI/FedML/tree/master/python/fedml/data/edge_case_examples)
114 | 
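The model-replacement step in the notes above (scale the attacker's update by `total_num_dps_per_round/num_dps_poisoned_dataset` and print the norm before and after) can be sketched as follows. This is a hedged illustration of the technique, not the repository's code; the function name and signature are assumptions.

```python
import torch

def scale_attacker_model(attacker_model, global_model,
                         total_num_dps_per_round, num_dps_poisoned_dataset):
    """Model replacement: scale the attacker's update so it survives averaging."""
    scale = total_num_dps_per_round / num_dps_poisoned_dataset
    for p_att, p_glob in zip(attacker_model.parameters(), global_model.parameters()):
        delta = p_att.data - p_glob.data          # the attacker's local update
        print('update norm before scaling:', delta.norm().item())
        p_att.data = p_glob.data + scale * delta  # averaging divides this back out
        print('update norm after scaling:', (p_att.data - p_glob.data).norm().item())
```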
115 | ### [How To Backdoor Federated Learning (AISTATS'20)](https://github.com/ebagdasa/backdoors101)
116 | - Generate the trigger in `synthesizers/pattern_synthesizer.py`. The 2 images below are the pattern of the trigger and an image with the trigger.
119 | - All the original images and trigger images are normalized to the [mean, std] of the original dataset
120 | 
122 | [Figure: the trigger pattern, and an image with the trigger applied]
127 | 
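As a rough sketch of what the pattern synthesizer does with such a pattern (the real logic lives in `synthesizers/pattern_synthesizer.py`; the mask/pattern names and signature here are assumptions):

```python
import torch

def apply_trigger(images: torch.Tensor, pattern: torch.Tensor,
                  mask: torch.Tensor, attack_portion: int) -> torch.Tensor:
    """Paste `pattern` where `mask` == 1 onto the first `attack_portion` images.

    `images` are assumed to already be normalized with the dataset's
    [mean, std], so `pattern` must be expressed in the same normalized scale.
    """
    images[:attack_portion] = (1 - mask) * images[:attack_portion] + mask * pattern
    return images
```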
128 | - Define the loss function in `attacks/loss_functions.py`
129 |   Standard version: each malicious client computes 2 losses and weights them equally (0.5, 0.5)
130 | 
131 | ```python
132 | # criterion: nn.CrossEntropyLoss(reduction='none') gives per-sample losses of shape (batch_size,)
133 | total_loss = 0.5 * (loss1 + loss2)  # check `scale_losses` in `attacks/attack.py`
134 | ```
135 | 
136 | The results of the experiments can be found in this [wandb](https://wandb.ai/mtuann/benchmark-backdoor-fl?workspace=user-mtuann).
137 | 
138 | 
139 | ## TODO:
140 | - [ ] Set up the standard FL attacks from Attack of the Tails and DBA
141 | - [ ] Change dump to file -> dump to memory
142 | - [ ] Check popular defense methods: FoolsGold, RFA, ...
143 | 
144 | ## Reference
145 | 
146 | 
147 | # Source
148 | - [FedMLSecurity: A Benchmark for Attacks and Defenses in Federated Learning and Federated LLMs](https://arxiv.org/pdf/2306.04959.pdf)
149 | - [Attack and Defense of FedMLSecurity](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md)
150 | - [fedMLSecurity_experiments](https://github.com/FedML-AI/FedML/tree/master/python/examples/security/fedMLSecurity_experiments)
151 | 
--------------------------------------------------------------------------------
/markdown-explanation/Figure_4_ba_norm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/Figure_4_ba_norm.png
--------------------------------------------------------------------------------
/markdown-explanation/Figure_full-image-pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/Figure_full-image-pattern.png
--------------------------------------------------------------------------------
/markdown-explanation/code-structures.md:
--------------------------------------------------------------------------------
1 | 
2 | Overview of the code structures used in the project.
3 | 
4 | Tree structure
5 | --------------
6 | ```
7 | .
8 | ├── README.md
9 | ├── training.py
10 | ├── helper.py
11 | ├── utils
12 | │   ├── parameters.py
13 | │   ├── utils.py
14 | ├── data
15 | │   ├── MNIST
16 | │   ├── CIFAR-10
17 | │   ├── CIFAR-100
18 | ```
19 | 
20 | Running code:
21 | -------------
22 | ```
23 | python training.py --name mnist --params exps/mnist_fed.yaml
24 | ```
25 | Flow of the code:
26 | -----------------
27 | 1. `training.py` is the main file that is run.
28 |    - It parses the arguments and loads the parameters from the yaml file.
29 |    - Load all the configurations into `helper.py` (parameters as a variable in a `Helper` class)
30 | 
31 |    Perform the following steps for each round:
32 |    - Train for `epochs` communication rounds
33 |      - Train each round for `fl_local_epochs` epochs
34 |      - Save each local update to the `saved_updates/update_{user_ID}.pth` file
35 |      - Perform FedAvg on the local updates in `defenses/fedavg.py`
36 |      - Update the global model with the scale $scale = \frac{fl\_eta}{fl\_no\_models} = \frac{1}{k}$ --> change this to weight by the number of samples in each client (see the sketch below)
37 | 
38 |    - Evaluate the model on the test set
39 | 
40 | 
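A minimal sketch of that aggregation step (illustrative, not `defenses/fedavg.py`; the on-disk update format is assumed to be a `state_dict`-style name-to-tensor mapping of update deltas):

```python
import torch

def fedavg(global_model, user_ids, fl_eta=1.0):
    """Average the saved local updates into the global model with scale = fl_eta / k."""
    scale = fl_eta / len(user_ids)  # equals 1/k when fl_eta == 1
    for user_id in user_ids:
        update = torch.load(f'saved_updates/update_{user_id}.pth')
        for name, param in global_model.named_parameters():
            param.data.add_(scale * update[name])
```

Weighting by client sample counts, as the note above suggests, would replace the uniform `scale` with a per-client weight proportional to that client's number of samples.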
41 | 2. Define the `task.py` in the `tasks` folder; it inherits from the `Task` class (a minimal skeleton is sketched after this list).
42 |    - The `Task` class has the following functions:
43 |      - `load_data` - load the data from the `data` folder and split it across the different clients.
44 |      - `build_model` - build the model for the task.
45 |      - `resume_model` - resume the model from a checkpoint.
46 |      - `make_criterion` - define the loss function for the task.
47 |      - `make_optimizer` - define the optimizer for the task.
48 |      - `sample_adversaries` - sample the adversaries for the task.
49 |      - `train` - train the model for one epoch.
50 |      - `metrics` - define the metrics for the task; the 2 main metrics are `AccuracyMetric` (top-k accuracy) and `TestLossMetric`.
51 | 
52 | 3. Define the `synthesizer.py` in the `synthesizers` folder; it inherits from the `Synthesizer` class.
53 | 4. Define the `attack.py` in the `attacks` folder; it inherits from the `Attack` class.
54 |    - Loss functions are defined in the `attacks/loss_functions.py` file.
55 | 5. Define the loss function.
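A minimal skeleton of such a custom task, assuming only the hooks listed above (the exact base-class signatures may differ; this is an illustrative sketch, not code from `tasks/`):

```python
import torch.nn as nn
import torch.optim as optim
from tasks.task import Task

class MyTask(Task):
    def load_data(self):
        # load the dataset here and split it across fl_total_participants clients
        ...

    def build_model(self) -> nn.Module:
        # any torch.nn.Module works; a tiny linear classifier as a placeholder
        return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def make_criterion(self) -> nn.Module:
        # reduction='none' keeps per-sample losses, as the attack losses expect
        return nn.CrossEntropyLoss(reduction='none')

    def make_optimizer(self, model: nn.Module) -> optim.Optimizer:
        return optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
```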
--------------------------------------------------------------------------------
/markdown-explanation/model-replacement-attack.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/model-replacement-attack.png
--------------------------------------------------------------------------------
/metrics/accuracy_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 | 
4 | 
5 | class AccuracyMetric(Metric):
6 | 
7 |     def __init__(self, top_k=(1,)):
8 |         self.top_k = top_k
9 |         self.main_metric_name = 'Top-1'
10 |         super().__init__(name='Accuracy', train=False)
11 | 
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor):
14 |         """Computes the top-k accuracy for the specified values of k"""
15 |         max_k = max(self.top_k)
16 |         batch_size = labels.shape[0]
17 | 
18 |         _, pred = outputs.topk(max_k, 1, True, True)
19 |         # get values and indices of the top max_k classes
20 |         pred = pred.t()  # shape (max_k, batch_size)
21 |         # expand labels to match pred size: 2D size (max_k, batch_size)
22 |         correct = pred.eq(labels.view(1, -1).expand_as(pred))
23 |         # correct: True/False matrix
24 |         res = dict()
25 |         for k in self.top_k:
26 |             correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape, since the slice may be non-contiguous
27 |             res[f'Top-{k}'] = (correct_k.mul_(100.0 / batch_size)).item()
28 |         return res
29 | 
--------------------------------------------------------------------------------
/metrics/metric.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import Dict, Any
4 | 
5 | import numpy as np
6 | 
7 | logger = logging.getLogger('logger')
8 | 
9 | 
10 | class Metric:
11 |     name: str
12 |     train: bool
13 |     plottable: bool = True
14 |     running_metric = None
15 |     main_metric_name = None
16 | 
17 |     def __init__(self, name, train=False):
18 |         self.train = train
19 |         self.name = name
20 | 
21 |         self.running_metric = defaultdict(list)
22 | 
23 |     def __repr__(self):
24 |         metrics = self.get_value()
25 |         text = [f'{key}: {val:.2f}' for key, val in metrics.items()]
26 |         return f'{self.name}: ' + ','.join(text)
27 | 
28 |     def compute_metric(self, outputs, labels) -> Dict[str, Any]:
29 |         raise NotImplementedError  # NotImplemented is not an exception class
30 | 
31 |     def accumulate_on_batch(self, outputs=None, labels=None):
32 |         current_metrics = self.compute_metric(outputs, labels)
33 |         for key, value in current_metrics.items():
34 |             self.running_metric[key].append(value)
35 | 
36 |     def get_value(self) -> Dict[str, np.ndarray]:
37 |         metrics = dict()
38 |         for key, value in self.running_metric.items():
39 |             metrics[key] = np.mean(value)
40 | 
41 |         return metrics
42 | 
43 |     def get_main_metric_value(self):
44 |         if not self.main_metric_name:
45 |             raise ValueError(f'For metric {self.name} define '
46 |                              f'attribute main_metric_name.')
47 |         metrics = self.get_value()
48 |         return metrics[self.main_metric_name]
49 | 
50 |     def reset_metric(self):
51 |         self.running_metric = defaultdict(list)
52 | 
--------------------------------------------------------------------------------
/metrics/test_loss_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 | 
4 | 
5 | class TestLossMetric(Metric):
6 | 
7 |     def __init__(self, criterion, train=False):
8 |         self.criterion = criterion
9 |         self.main_metric_name = 'value'
10 |         super().__init__(name='Loss', train=train)
11 | 
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor, top_k=(1,)):
14 |         """Computes the mean test loss over the batch"""
15 |         loss = self.criterion(outputs, labels)
16 |         return {'value': loss.mean().item()}
--------------------------------------------------------------------------------
/models/MnistNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.simple import SimpleNet
5 | 
6 | 
7 | class MnistNet(SimpleNet):
8 |     def __init__(self, name=None, created_time=None):
9 |         super(MnistNet, self).__init__(f'{name}_Simple', created_time)
10 | 
11 |         self.conv1 = nn.Conv2d(1, 20, 5, 1)
12 |         self.conv2 = nn.Conv2d(20, 50, 5, 1)
13 |         self.fc1 = nn.Linear(4 * 4 * 50, 500)
14 |         self.fc2 = nn.Linear(500, 10)
15 |         # self.fc2 = nn.Linear(28*28, 10)
16 | 
17 |     def forward(self, x):
18 |         x = F.relu(self.conv1(x))
19 |         x = F.max_pool2d(x, 2, 2)
20 |         x = F.relu(self.conv2(x))
21 |         x = F.max_pool2d(x, 2, 2)
22 |         x = x.view(-1, 4 * 4 * 50)
23 |         x = F.relu(self.fc1(x))
24 |         x = self.fc2(x)
25 | 
26 |         # in_features = 28 * 28
27 |         # x = x.view(-1, in_features)
28 |         # x = self.fc2(x)
29 | 
30 |         # normal return:
31 |         return F.log_softmax(x, dim=1)
32 |         # softmax is used for generating SDT data
33 |         # return F.softmax(x, dim=1)
34 | 
35 | if __name__ == '__main__':
36 |     model = MnistNet()
37 |     print(model)
38 | 
39 |     # import numpy as np
40 |     # from torchvision import datasets, transforms
41 |     # import torch
42 |     # import torch.utils.data
43 |     # import copy
44 |     #
45 |     # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
46 |     #
47 |     # train_dataset = datasets.MNIST('./data', train=True, download=True,
48 |     #                    transform=transforms.Compose([
49 |     #                        transforms.ToTensor(),
50 |     #                        # transforms.Normalize((0.1307,), (0.3081,))
51 |     #                    ]))
52 |     # test_dataset = datasets.MNIST('./data', train=False, transform=transforms.Compose([
53 |     #                        transforms.ToTensor(),
54 |     #                        # transforms.Normalize((0.1307,), (0.3081,))
55 |     #                    ]))
56 |     # train_loader = torch.utils.data.DataLoader(train_dataset,
57 |     #                                            batch_size=64,
58 |     #                                            shuffle=False)
59 |     # client_grad = []
60 |     #
61 |     # for batch_id, batch in enumerate(train_loader):
62 |     #     optimizer.zero_grad()
63 |     #     data, targets = batch
64 |     #     output = model(data)
65 |     #     loss = nn.functional.cross_entropy(output, targets)
66 |     #     loss.backward()
67 |     #     for i, (name, params) in enumerate(model.named_parameters()):
68 |     #         if params.requires_grad:
69 |     #             if batch_id == 0:
70 |     #                 client_grad.append(params.grad.clone())
71 |     #             else:
72 |     #                 client_grad[i] += params.grad.clone()
73 |     #     optimizer.step()
74 |     #     if batch_id == 2:
75 |     #         break
76 |     #
77 |     # print(client_grad[-2].cpu().data.numpy().shape)
78 |     # print(np.array(client_grad[-2].cpu().data.numpy().shape))
79 |     # grad_len = np.array(client_grad[-2].cpu().data.numpy().shape).prod()
80 |     # print(grad_len)
81 |     # memory = np.zeros((1, grad_len))
82 |     # grads = np.zeros((1, grad_len))
83 |     # grads[0] = np.reshape(client_grad[-2].cpu().data.numpy(), (grad_len))
84 |     # print(grads)
85 |     # print(grads[0].shape)
86 | 
87 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/models/__init__.py
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class Model(nn.Module):
5 |     """
6 |     Base class for models with added support for GradCam activation map
7 |     and a SentiNet defense. The GradCam design is taken from:
8 |     https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82
9 |     If you are not planning to utilize the SentiNet defense, just import any
10 |     model you like for your tasks.
11 |     """
12 | 
13 |     def __init__(self):
14 |         super().__init__()
15 |         self.gradient = None
16 | 
17 |     def activations_hook(self, grad):
18 |         self.gradient = grad
19 | 
20 |     def get_gradient(self):
21 |         return self.gradient
22 | 
23 |     def get_activations(self, x):
24 |         return self.features(x)
25 | 
26 |     def switch_grads(self, enable=True):
27 |         for i, n in self.named_parameters():
28 |             n.requires_grad_(enable)
29 | 
30 |     def features(self, x):
31 |         """
32 |         Get the latent representation, e.g., the layer feeding the logits.
33 |         :param x:
34 |         :return:
35 |         """
36 |         raise NotImplementedError
37 | 
38 |     def forward(self, x, latent=False):
39 |         raise NotImplementedError
40 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 | from models.model import Model
5 | 
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 |            'wide_resnet50_2', 'wide_resnet101_2']
9 | 
10 | 
11 | model_urls = {
12 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 | 
23 | 
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 |     """3x3 convolution with padding"""
26 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 |                      padding=dilation, groups=groups, bias=False,
28 |                      dilation=dilation)
29 | 
30 | 
31 | def conv1x1(in_planes, out_planes, stride=1):
32 |     """1x1 convolution"""
33 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
34 |                      bias=False)
35 | 
36 | 
37 | class BasicBlock(nn.Module):
38 |     expansion = 1
39 | 
40 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
41 |                  base_width=64, dilation=1, norm_layer=None):
42 |         super(BasicBlock, self).__init__()
43 |         if norm_layer is None:
44 |             norm_layer = nn.BatchNorm2d
45 |         if groups != 1 or base_width != 64:
46 |             raise ValueError(
47 |                 'BasicBlock only supports groups=1 and base_width=64')
48 |         if dilation > 1:
49 |             raise NotImplementedError(
50 |                 "Dilation > 1 not supported in BasicBlock")
51 |         # Both self.conv1 and self.downsample layers downsample the
52 |         # input when stride != 1
53 |         self.conv1 = conv3x3(inplanes, planes, stride)
54 |         self.bn1 = norm_layer(planes)
55 |         self.relu = nn.ReLU(inplace=True)
56 |         self.conv2 = conv3x3(planes, planes)
57 |         self.bn2 = norm_layer(planes)
58 |         self.downsample = downsample
59 |         self.stride = stride
60 | 
61 |     def forward(self, x):
62 |         identity = x
63 | 
64 |         out = self.conv1(x)
65 |         out = self.bn1(out)
66 |         out = self.relu(out)
67 | 
68 |         out = self.conv2(out)
69 |         out = self.bn2(out)
70 | 
71 |         if self.downsample is not None:
72 |             identity = self.downsample(x)
73 | 
74 |         out += identity
75 |         out = self.relu(out)
76 | 
77 |         return out
78 | 
79 | 
80 | class Bottleneck(nn.Module):
81 |     expansion = 4
82 | 
83 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
84 |                  base_width=64, dilation=1, norm_layer=None):
85 |         super(Bottleneck, self).__init__()
86 |         if norm_layer is None:
87 |             norm_layer = nn.BatchNorm2d
88 |         width = int(planes * (base_width / 64.)) * groups
89 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
90 | 
self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(Model): 124 | 125 | def __init__(self, block, layers, num_classes=1000, 126 | zero_init_residual=False, 127 | groups=1, width_per_group=64, 128 | replace_stride_with_dilation=None, 129 | norm_layer=None): 130 | super(ResNet, self).__init__() 131 | if norm_layer is None: 132 | norm_layer = nn.BatchNorm2d 133 | self._norm_layer = norm_layer 134 | 135 | self.inplanes = 64 136 | self.dilation = 1 137 | if replace_stride_with_dilation is None: 138 | # each element in the tuple indicates if we should replace 139 | # the 2x2 stride with a dilated convolution instead 140 | replace_stride_with_dilation = [False, False, False] 141 | if len(replace_stride_with_dilation) != 3: 142 | raise ValueError("replace_stride_with_dilation should be None " 143 | "or a 3-element tuple, got {}".format( 144 | replace_stride_with_dilation)) 145 | self.groups = groups 146 | self.base_width = width_per_group 147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, 148 | padding=3, 149 | bias=False) 150 | self.bn1 = norm_layer(self.inplanes) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 155 | dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | dilate=replace_stride_with_dilation[1]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 159 | dilate=replace_stride_with_dilation[2]) 160 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 166 | nonlinearity='relu') 167 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 168 | nn.init.constant_(m.weight, 1) 169 | nn.init.constant_(m.bias, 0) 170 | 171 | # Zero-initialize the last BN in each residual branch, 172 | # so that the residual branch starts with zeros, and 173 | # each residual block behaves like an identity. 
174 | # This improves the model by 0.2~0.3% according 175 | # to https://arxiv.org/abs/1706.02677 176 | if zero_init_residual: 177 | for m in self.modules(): 178 | if isinstance(m, Bottleneck): 179 | nn.init.constant_(m.bn3.weight, 0) 180 | elif isinstance(m, BasicBlock): 181 | nn.init.constant_(m.bn2.weight, 0) 182 | 183 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 184 | norm_layer = self._norm_layer 185 | downsample = None 186 | previous_dilation = self.dilation 187 | if dilate: 188 | self.dilation *= stride 189 | stride = 1 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | conv1x1(self.inplanes, planes * block.expansion, stride), 193 | norm_layer(planes * block.expansion), 194 | ) 195 | 196 | layers = [] 197 | layers.append( 198 | block(self.inplanes, planes, stride, downsample, self.groups, 199 | self.base_width, previous_dilation, norm_layer)) 200 | self.inplanes = planes * block.expansion 201 | for _ in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, groups=self.groups, 203 | base_width=self.base_width, 204 | dilation=self.dilation, 205 | norm_layer=norm_layer)) 206 | 207 | return nn.Sequential(*layers) 208 | 209 | def features(self, x): 210 | out1 = self.maxpool(self.relu(self.bn1(self.conv1(x)))) 211 | out2 = self.layer1(out1) 212 | out3 = self.layer2(out2) 213 | out4 = self.layer3(out3) 214 | out5 = self.layer4(out4) 215 | 216 | return out5 217 | 218 | def forward(self, x, latent=False): 219 | x = self.conv1(x) 220 | x = self.bn1(x) 221 | x = self.relu(x) 222 | x = self.maxpool(x) 223 | 224 | x = self.layer1(x) 225 | x = self.layer2(x) 226 | x = self.layer3(x) 227 | layer4_out = self.layer4(x) 228 | 229 | if layer4_out.requires_grad: 230 | layer4_out.register_hook(self.activations_hook) 231 | 232 | x = self.avgpool(layer4_out) 233 | flatten_x = torch.flatten(x, 1) 234 | x = self.fc(flatten_x) 235 | if latent: 236 | return x, flatten_x 237 | else: 238 | return x 239 | 240 | 241 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 242 | model = ResNet(block, layers, **kwargs) 243 | if pretrained: 244 | state_dict = load_state_dict_from_url(model_urls[arch], 245 | progress=progress) 246 | model.load_state_dict(state_dict) 247 | return model 248 | 249 | 250 | def resnet18(pretrained=False, progress=True, **kwargs): 251 | r"""ResNet-18 model from 252 | `"Deep Residual Learning for Image Recognition" 253 | `_ 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet34(pretrained=False, progress=True, **kwargs): 264 | r"""ResNet-34 model from 265 | `"Deep Residual Learning for Image Recognition" `_ 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 272 | **kwargs) 273 | 274 | 275 | def resnet50(pretrained=False, progress=True, **kwargs): 276 | r"""ResNet-50 model from 277 | `"Deep Residual Learning for Image Recognition" `_ 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | 
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnet101(pretrained=False, progress=True, **kwargs): 288 | r"""ResNet-101 model from 289 | `"Deep Residual Learning for Image Recognition" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 296 | **kwargs) 297 | 298 | 299 | def resnet152(pretrained=False, progress=True, **kwargs): 300 | r"""ResNet-152 model from 301 | `"Deep Residual Learning for Image Recognition" 302 | `_ 303 | 304 | Args: 305 | pretrained (bool): If True, returns a model pre-trained on ImageNet 306 | progress (bool): If True, displays a progress bar of the download to stderr 307 | """ 308 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 309 | **kwargs) 310 | 311 | 312 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 313 | r"""ResNeXt-50 32x4d model from 314 | `"Aggregated Residual Transformation for Deep Neural Networks" 315 | `_ 316 | 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | kwargs['groups'] = 32 322 | kwargs['width_per_group'] = 4 323 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 324 | pretrained, progress, **kwargs) 325 | 326 | 327 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 328 | r"""ResNeXt-101 32x8d model from 329 | `"Aggregated Residual Transformation for Deep Neural Networks" 330 | `_ 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['groups'] = 32 337 | kwargs['width_per_group'] = 8 338 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 339 | pretrained, progress, **kwargs) 340 | 341 | 342 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 343 | r"""Wide ResNet-50-2 model from 344 | `"Wide Residual Networks" `_ 345 | 346 | The model is the same as ResNet except for the bottleneck number of channels 347 | which is twice larger in every block. The number of channels in outer 1x1 348 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 349 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 350 | 351 | Args: 352 | pretrained (bool): If True, returns a model pre-trained on ImageNet 353 | progress (bool): If True, displays a progress bar of the download to stderr 354 | """ 355 | kwargs['width_per_group'] = 64 * 2 356 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 357 | pretrained, progress, **kwargs) 358 | 359 | 360 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 361 | r"""Wide ResNet-101-2 model from 362 | `"Wide Residual Networks" `_ 363 | 364 | The model is the same as ResNet except for the bottleneck number of channels 365 | which is twice larger in every block. The number of channels in outer 1x1 366 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 367 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
368 | 369 | Args: 370 | pretrained (bool): If True, returns a model pre-trained on ImageNet 371 | progress (bool): If True, displays a progress bar of the download to stderr 372 | """ 373 | kwargs['width_per_group'] = 64 * 2 374 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 375 | pretrained, progress, **kwargs) 376 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.simple import SimpleNet 11 | from torch.autograd import Variable 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(SimpleNet): 68 | def __init__(self, block, num_blocks, num_classes=10, name=None, created_time=None): 69 | super(ResNet, self).__init__(name, created_time) 70 | self.in_planes = 32 71 | 72 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(32) 74 | self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(256*block.expansion, num_classes) 79 | 80 | 
def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | # for SDTdata 98 | # return F.softmax(out, dim=1) 99 | # for regular output 100 | return out 101 | 102 | 103 | def ResNet18(name=None, created_time=None): 104 | return ResNet(BasicBlock, [2,2,2,2],name='{0}_ResNet_18'.format(name), created_time=created_time) 105 | 106 | def ResNet34(name=None, created_time=None): 107 | return ResNet(BasicBlock, [3,4,6,3],name='{0}_ResNet_34'.format(name), created_time=created_time) 108 | 109 | def ResNet50(name=None, created_time=None): 110 | return ResNet(Bottleneck, [3,4,6,3],name='{0}_ResNet_50'.format(name), created_time=created_time) 111 | 112 | def ResNet101(name=None, created_time=None): 113 | return ResNet(Bottleneck, [3,4,23,3],name='{0}_ResNet'.format(name), created_time=created_time) 114 | 115 | def ResNet152(name=None, created_time=None): 116 | return ResNet(Bottleneck, [3,8,36,3],name='{0}_ResNet'.format(name), created_time=created_time) 117 | 118 | 119 | if __name__ == '__main__': 120 | 121 | net = ResNet18() 122 | y = net(Variable(torch.randn(1,3,32,32))) 123 | print(y.size()) -------------------------------------------------------------------------------- /models/resnet_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.simple import SimpleNet 11 | from models.model import Model 12 | # from torchvision.models.utils import load_state_dict_from_url 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 26 | } 27 | 28 | 29 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes, out_planes, stride=1): 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 44 | base_width=64, dilation=1, norm_layer=None): 45 | super(BasicBlock, self).__init__() 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | if groups != 1 or base_width != 64: 49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 50 | if dilation > 1: 51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 53 | self.conv1 = conv3x3(inplanes, planes, stride) 54 | self.bn1 = norm_layer(planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes) 57 | self.bn2 = norm_layer(planes) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | 71 | if self.downsample is not None: 72 | identity = self.downsample(x) 73 | 74 | out += identity 75 | out = self.relu(out) 76 | 77 | return out 78 | 79 | 80 | class Bottleneck(nn.Module): 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = 
self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | class ResNet(Model): 123 | 124 | def __init__(self, block, layers, num_classes=1000, 125 | zero_init_residual=False, 126 | groups=1, width_per_group=64, 127 | replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | # change the model to fit our tiny-imagenet-200 (200 classes) 178 | self.avgpool = nn.AdaptiveAvgPool2d(1) 179 | num_ftrs = self.fc.in_features 180 | self.fc = nn.Linear(num_ftrs, 200) 181 | 182 | 183 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 184 | norm_layer = self._norm_layer 185 | downsample = None 186 | previous_dilation = self.dilation 187 | if dilate: 188 | self.dilation *= stride 189 | stride = 1 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | conv1x1(self.inplanes, planes * block.expansion, stride), 193 | norm_layer(planes * block.expansion), 194 | ) 195 | 196 | layers = [] 197 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 198 | self.base_width, previous_dilation, norm_layer)) 199 | self.inplanes = planes * block.expansion 200 | for _ in range(1, blocks): 201 | layers.append(block(self.inplanes, planes, groups=self.groups, 202 | base_width=self.base_width, dilation=self.dilation, 203 | norm_layer=norm_layer)) 204 | 205 | return nn.Sequential(*layers) 206 | 207 | def forward(self, x): 208 | x = self.conv1(x) 209 | x = self.bn1(x) 210 | x = self.relu(x) 211 | x = self.maxpool(x) 212 | 213 | x = self.layer1(x) 214 | x = self.layer2(x) 215 | x = self.layer3(x) 216 | x = self.layer4(x) 217 | 218 | x = self.avgpool(x) 219 | x = x.reshape(x.size(0), -1) 220 | x = self.fc(x) 221 | return x 222 | 223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 224 | model = ResNet(block,layers, **kwargs) 225 | # if pretrained: 226 | # state_dict = load_state_dict_from_url(model_urls[arch], 227 | # progress=progress) 228 | # model.load_state_dict(state_dict) 229 | return model 230 | 231 | 232 | def resnet18(pretrained=False, progress=True, **kwargs): 233 | """Constructs a ResNet-18 model. 
234 | 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | """ 239 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 240 | **kwargs) 241 | 242 | 243 | if __name__ == '__main__': 244 | 245 | import torchvision.transforms as transforms 246 | import torchvision.datasets as datasets 247 | import os 248 | import torch.utils.data as data 249 | 250 | 251 | # Data loading code 252 | data_transforms = { 253 | 'train': transforms.Compose([ 254 | transforms.Resize(224), 255 | transforms.RandomHorizontalFlip(), 256 | transforms.ToTensor(), 257 | ]), 258 | 'val': transforms.Compose([ 259 | transforms.Resize(224), 260 | transforms.ToTensor(), 261 | ]), 262 | } 263 | 264 | data_dir = '../data/tiny-imagenet-200/' 265 | 266 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), 267 | data_transforms[x]) 268 | for x in ['train', 'val']} 269 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=100, 270 | shuffle=False, num_workers=64) 271 | for x in ['train', 'val']} 272 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 273 | 274 | target_model = resnet18(name='Target', 275 | created_time='') 276 | 277 | running_loss = 0.0 278 | running_corrects = 0 279 | running_datasize = 0 280 | # Iterate over data. 281 | 282 | phase='val' 283 | 284 | vis_image=None 285 | criterion = nn.CrossEntropyLoss() 286 | for i, (inputs, labels) in enumerate(dataloaders[phase]): 287 | 288 | output= target_model(inputs) 289 | 290 | break 291 | 292 | 293 | -------------------------------------------------------------------------------- /models/simple.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import datetime 10 | 11 | 12 | class SimpleNet(nn.Module): 13 | def __init__(self, name=None, created_time=None): 14 | super(SimpleNet, self).__init__() 15 | self.created_time = created_time 16 | self.name=name 17 | 18 | def train_vis(self, vis, epoch, acc, loss=None, eid='main', is_poisoned=False, name=None): 19 | if name is None: 20 | name = self.name + '_poisoned' if is_poisoned else self.name 21 | vis.line(X=np.array([epoch]), Y=np.array([acc]), name=name, win='train_acc_{0}'.format(self.created_time), env=eid, 22 | update='append' if vis.win_exists('train_acc_{0}'.format(self.created_time), env=eid) else None, 23 | opts=dict(showlegend=True, title='Train Accuracy_{0}'.format(self.created_time), 24 | width=700, height=400)) 25 | if loss is not None: 26 | vis.line(X=np.array([epoch]), Y=np.array([loss]), name=name, env=eid, 27 | win='train_loss_{0}'.format(self.created_time), 28 | update='append' if vis.win_exists('train_loss_{0}'.format(self.created_time), env=eid) else None, 29 | opts=dict(showlegend=True, title='Train Loss_{0}'.format(self.created_time), width=700, height=400)) 30 | return 31 | 32 | def train_batch_vis(self, vis, epoch, data_len, batch, loss, eid='main', name=None, win='train_batch_loss', is_poisoned=False): 33 | if name is None: 34 | name = self.name + '_poisoned' if is_poisoned else self.name 35 | else: 36 | name = name + '_poisoned' if is_poisoned else name 37 | 38 | vis.line(X=np.array([(epoch-1)*data_len+batch]), 
Y=np.array([loss]), 39 | env=eid, 40 | name=f'{name}' if name is not None else self.name, win=f'{win}_{self.created_time}', 41 | update='append' if vis.win_exists(f'{win}_{self.created_time}', env=eid) else None, 42 | opts=dict(showlegend=True, width=700, height=400, title='Train Batch loss_{0}'.format(self.created_time))) 43 | def track_distance_batch_vis(self,vis, epoch, data_len, batch, distance_to_global_model,eid,name=None,is_poisoned=False): 44 | x= (epoch-1)*data_len+batch+1 45 | 46 | if name is None: 47 | name = self.name + '_poisoned' if is_poisoned else self.name 48 | else: 49 | name = name + '_poisoned' if is_poisoned else name 50 | 51 | 52 | vis.line(Y=np.array([distance_to_global_model]), X=np.array([x]), 53 | win=f"global_dist_{self.created_time}", 54 | env=eid, 55 | name=f'Model_{name}', 56 | update='append' if 57 | vis.win_exists(f"global_dist_{self.created_time}", 58 | env=eid) else None, 59 | opts=dict(showlegend=True, 60 | title=f"Distance to Global {self.created_time}", 61 | width=700, height=400)) 62 | def weight_vis(self,vis,epoch,weight, eid, name,is_poisoned=False): 63 | name = str(name) + '_poisoned' if is_poisoned else name 64 | vis.line(Y=np.array([weight]), X=np.array([epoch]), 65 | win=f"Aggregation_Weight_{self.created_time}", 66 | env=eid, 67 | name=f'Model_{name}', 68 | update='append' if 69 | vis.win_exists(f"Aggregation_Weight_{self.created_time}", 70 | env=eid) else None, 71 | opts=dict(showlegend=True, 72 | title=f"Aggregation Weight {self.created_time}", 73 | width=700, height=400)) 74 | 75 | def alpha_vis(self,vis,epoch,alpha, eid, name,is_poisoned=False): 76 | name = str(name) + '_poisoned' if is_poisoned else name 77 | vis.line(Y=np.array([alpha]), X=np.array([epoch]), 78 | win=f"FG_Alpha_{self.created_time}", 79 | env=eid, 80 | name=f'Model_{name}', 81 | update='append' if 82 | vis.win_exists(f"FG_Alpha_{self.created_time}", 83 | env=eid) else None, 84 | opts=dict(showlegend=True, 85 | title=f"FG Alpha {self.created_time}", 86 | width=700, height=400)) 87 | 88 | def trigger_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key, trigger_name, trigger_value): 89 | vis.line(Y=np.array([acc]), X=np.array([epoch]), 90 | win=f"poison_triggerweight_vis_acc_{self.created_time}", 91 | env=eid, 92 | name=f'{agent_name_key}_[{trigger_name}]_{trigger_value}', 93 | update='append' if vis.win_exists(f"poison_trigger_acc_{self.created_time}", 94 | env=eid) else None, 95 | opts=dict(showlegend=True, 96 | title=f"Backdoor Trigger Test Accuracy_{self.created_time}", 97 | width=700, height=400)) 98 | if loss is not None: 99 | vis.line(Y=np.array([loss]), X=np.array([epoch]), 100 | win=f"poison_trigger_loss_{self.created_time}", 101 | env=eid, 102 | name=f'{agent_name_key}_[{trigger_name}]_{trigger_value}', 103 | update='append' if vis.win_exists(f"poison_trigger_loss_{self.created_time}", 104 | env=eid) else None, 105 | opts=dict(showlegend=True, 106 | title=f"Backdoor Trigger Test Loss_{self.created_time}", 107 | width=700, height=400)) 108 | 109 | def trigger_agent_test_vis(self, vis, epoch, acc, loss, eid, name): 110 | vis.line(Y=np.array([acc]), X=np.array([epoch]), 111 | win=f"poison_state_trigger_acc_{self.created_time}", 112 | env=eid, 113 | name=f'{name}', 114 | update='append' if vis.win_exists(f"poison_state_trigger_acc_{self.created_time}", 115 | env=eid) else None, 116 | opts=dict(showlegend=True, 117 | title=f"Backdoor State Trigger Test Accuracy_{self.created_time}", 118 | width=700, height=400)) 119 | if loss is not None: 120 | 
vis.line(Y=np.array([loss]), X=np.array([epoch]), 121 | win=f"poison_state_trigger_loss_{self.created_time}", 122 | env=eid, 123 | name=f'{name}', 124 | update='append' if vis.win_exists(f"poison_state_trigger_loss_{self.created_time}", 125 | env=eid) else None, 126 | opts=dict(showlegend=True, 127 | title=f"Backdoor State Trigger Test Loss_{self.created_time}", 128 | width=700, height=400)) 129 | 130 | 131 | def poison_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key): 132 | name= agent_name_key 133 | # name= f'Model_{name}' 134 | 135 | vis.line(Y=np.array([acc]), X=np.array([epoch]), 136 | win=f"poison_test_acc_{self.created_time}", 137 | env=eid, 138 | name=name, 139 | update='append' if vis.win_exists(f"poison_test_acc_{self.created_time}", 140 | env=eid) else None, 141 | opts=dict(showlegend=True, 142 | title=f"Backdoor Task Accuracy_{self.created_time}", 143 | width=700, height=400)) 144 | if loss is not None: 145 | vis.line(Y=np.array([loss]), X=np.array([epoch]), 146 | win=f"poison_loss_acc_{self.created_time}", 147 | env=eid, 148 | name=name, 149 | update='append' if vis.win_exists(f"poison_loss_acc_{self.created_time}", 150 | env=eid) else None, 151 | opts=dict(showlegend=True, 152 | title=f"Backdoor Task Test Loss_{self.created_time}", 153 | width=700, height=400)) 154 | 155 | def additional_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key): 156 | name = agent_name_key 157 | vis.line(Y=np.array([acc]), X=np.array([epoch]), 158 | win=f"additional_test_acc_{self.created_time}", 159 | env=eid, 160 | name=name, 161 | update='append' if vis.win_exists(f"additional_test_acc_{self.created_time}", 162 | env=eid) else None, 163 | opts=dict(showlegend=True, 164 | title=f"Additional Test Accuracy_{self.created_time}", 165 | width=700, height=400)) 166 | if loss is not None: 167 | vis.line(Y=np.array([loss]), X=np.array([epoch]), 168 | win=f"additional_test_loss_{self.created_time}", 169 | env=eid, 170 | name=name, 171 | update='append' if vis.win_exists(f"additional_test_loss_{self.created_time}", 172 | env=eid) else None, 173 | opts=dict(showlegend=True, 174 | title=f"Additional Test Loss_{self.created_time}", 175 | width=700, height=400)) 176 | 177 | 178 | def test_vis(self, vis, epoch, acc, loss, eid, agent_name_key): 179 | name= agent_name_key 180 | # name= f'Model_{name}' 181 | 182 | vis.line(Y=np.array([acc]), X=np.array([epoch]), 183 | win=f"test_acc_{self.created_time}", 184 | env=eid, 185 | name=name, 186 | update='append' if vis.win_exists(f"test_acc_{self.created_time}", 187 | env=eid) else None, 188 | opts=dict(showlegend=True, 189 | title=f"Main Task Test Accuracy_{self.created_time}", 190 | width=700, height=400)) 191 | if loss is not None: 192 | vis.line(Y=np.array([loss]), X=np.array([epoch]), 193 | win=f"test_loss_{self.created_time}", 194 | env=eid, 195 | name=name, 196 | update='append' if vis.win_exists(f"test_loss_{self.created_time}", 197 | env=eid) else None, 198 | opts=dict(showlegend=True, 199 | title=f"Main Task Test Loss_{self.created_time}", 200 | width=700, height=400)) 201 | 202 | 203 | def save_stats(self, epoch, loss, acc): 204 | self.stats['epoch'].append(epoch) 205 | self.stats['loss'].append(loss) 206 | self.stats['acc'].append(acc) 207 | 208 | def copy_params(self, state_dict, coefficient_transfer=100): 209 | 210 | own_state = self.state_dict() 211 | 212 | for name, param in state_dict.items(): 213 | if name in own_state: 214 | shape = param.shape 215 | #random_tensor = (torch.cuda.FloatTensor(shape).random_(0, 100) <= 
coefficient_transfer).type(torch.cuda.FloatTensor) 216 | # negative_tensor = (random_tensor*-1)+1 217 | # own_state[name].copy_(param) 218 | own_state[name].copy_(param.clone()) 219 | 220 | 221 | 222 | 223 | class SimpleMnist(SimpleNet): 224 | def __init__(self, name=None, created_time=None): 225 | super(SimpleMnist, self).__init__(name, created_time) 226 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 227 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 228 | self.conv2_drop = nn.Dropout2d() 229 | self.fc1 = nn.Linear(320, 50) 230 | self.fc2 = nn.Linear(50, 10) 231 | 232 | 233 | def forward(self, x): 234 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 235 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 236 | x = x.view(-1, 320) 237 | x = F.relu(self.fc1(x)) 238 | x = F.dropout(x, training=self.training) 239 | x = self.fc2(x) 240 | return F.log_softmax(x, dim=1) -------------------------------------------------------------------------------- /synthesizers/attack_models/autoencoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MNISTAutoencoder(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | self.encoder = nn.Sequential( 8 | nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10 9 | nn.BatchNorm2d(16), 10 | nn.ReLU(True), 11 | nn.MaxPool2d(2, stride=2), # b, 16, 5, 5 12 | nn.Conv2d(16, 64, 3, stride=2, padding=1), # b, 8, 3, 3 13 | nn.BatchNorm2d(64), 14 | nn.ReLU(True), 15 | nn.MaxPool2d(2, stride=1) # b, 8, 2, 2 16 | ) 17 | self.decoder = nn.Sequential( 18 | nn.ConvTranspose2d(64, 128, 3, stride=2), # b, 16, 5, 5 19 | nn.BatchNorm2d(128), 20 | nn.ReLU(True), 21 | nn.ConvTranspose2d(128, 64, 5, stride=3, padding=1), # b, 8, 15, 15 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(True), 24 | nn.ConvTranspose2d(64, 1, 2, stride=2, padding=1), # b, 1, 28, 28 25 | nn.BatchNorm2d(1), 26 | nn.Tanh() 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.encoder(x) 31 | x = self.decoder(x) 32 | return x 33 | 34 | class Autoencoder(nn.Module): 35 | def __init__(self): 36 | super(Autoencoder, self).__init__() 37 | self.encoder = nn.Sequential( 38 | nn.Conv2d(3, 16, 4, stride=2, padding=1), 39 | nn.BatchNorm2d(16), 40 | nn.ReLU(True), 41 | nn.Conv2d(16, 32, 4, stride=2, padding=1), 42 | nn.BatchNorm2d(32), 43 | nn.ReLU(True), 44 | nn.Conv2d(32, 64, 4, stride=2, padding=1), 45 | nn.BatchNorm2d(64), 46 | nn.ReLU(True), 47 | nn.Conv2d(64, 128, 4, stride=2, padding=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(True) 50 | ) 51 | self.decoder = nn.Sequential( 52 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), 53 | nn.BatchNorm2d(64), 54 | nn.ReLU(True), 55 | nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), 56 | nn.BatchNorm2d(32), 57 | nn.ReLU(True), 58 | nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), 59 | nn.BatchNorm2d(16), 60 | nn.ReLU(True), 61 | nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1), 62 | nn.Tanh() 63 | ) 64 | 65 | def forward(self, x): 66 | x = self.encoder(x) 67 | x = self.decoder(x) 68 | return x -------------------------------------------------------------------------------- /synthesizers/attack_models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def double_conv(in_channels, out_channels): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 8 | nn.BatchNorm2d(out_channels), 9 | nn.ReLU(inplace=True), 10 | 
nn.Conv2d(out_channels, out_channels, 3, padding=1), 11 | nn.BatchNorm2d(out_channels), 12 | nn.ReLU(inplace=True) 13 | ) 14 | 15 | 16 | class UNet(nn.Module): 17 | 18 | def __init__(self, out_channel): 19 | super().__init__() 20 | 21 | self.dconv_down1 = double_conv(3, 64) 22 | self.dconv_down2 = double_conv(64, 128) 23 | self.dconv_down3 = double_conv(128, 256) 24 | self.dconv_down4 = double_conv(256, 512) 25 | 26 | self.maxpool = nn.AvgPool2d(2) 27 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', 28 | align_corners=True) 29 | 30 | self.dconv_up3 = double_conv(256 + 512, 256) 31 | self.dconv_up2 = double_conv(128 + 256, 128) 32 | self.dconv_up1 = double_conv(128 + 64, 64) 33 | 34 | self.conv_last = nn.Sequential( 35 | nn.Conv2d(64, out_channel, 1), 36 | nn.BatchNorm2d(out_channel), 37 | ) 38 | 39 | def forward(self, x): 40 | conv1 = self.dconv_down1(x) 41 | x = self.maxpool(conv1) 42 | 43 | conv2 = self.dconv_down2(x) 44 | x = self.maxpool(conv2) 45 | 46 | conv3 = self.dconv_down3(x) 47 | x = self.maxpool(conv3) 48 | 49 | x = self.dconv_down4(x) 50 | 51 | x = self.upsample(x) 52 | x = torch.cat([x, conv3], dim=1) 53 | 54 | x = self.dconv_up3(x) 55 | x = self.upsample(x) 56 | x = torch.cat([x, conv2], dim=1) 57 | 58 | x = self.dconv_up2(x) 59 | x = self.upsample(x) 60 | x = torch.cat([x, conv1], dim=1) 61 | 62 | x = self.dconv_up1(x) 63 | 64 | out = self.conv_last(x) 65 | 66 | out = torch.tanh(out) 67 | 68 | return out -------------------------------------------------------------------------------- /synthesizers/complex_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from synthesizers.pattern_synthesizer import PatternSynthesizer 4 | 5 | 6 | class ComplexSynthesizer(PatternSynthesizer): 7 | """ 8 | For physical backdoors it's ok to train using pixel pattern that 9 | represents the physical object in the real scene. 
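Unlike the fixed-label pattern attack, the poison label for each sample
is taken from the per-sample `batch.aux` field (see `synthesize_labels` below).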
10 | """ 11 | 12 | pattern_tensor = torch.tensor([[1.]]) 13 | 14 | def synthesize_labels(self, batch, attack_portion=None): 15 | batch.labels[:attack_portion] = batch.aux[:attack_portion] 16 | return 17 | -------------------------------------------------------------------------------- /synthesizers/iba_synthesizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torchvision.transforms import transforms, functional 5 | 6 | from synthesizers.synthesizer import Synthesizer 7 | from tasks.task import Task 8 | import copy 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from termcolor import colored 13 | 14 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 15 | 16 | 17 | transform_to_image = transforms.ToPILImage() 18 | transform_to_tensor = transforms.ToTensor() 19 | 20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 22 | IMAGENET_MIN = ((np.array([0,0,0]) - np.array(IMAGENET_DEFAULT_MEAN)) / np.array(IMAGENET_DEFAULT_STD)).min() 23 | IMAGENET_MAX = ((np.array([1,1,1]) - np.array(IMAGENET_DEFAULT_MEAN)) / np.array(IMAGENET_DEFAULT_STD)).max() 24 | 25 | 26 | def get_clip_image(dataset="cifar10"): 27 | if dataset in ['timagenet', 'tiny-imagenet32', 'Imagenet']: 28 | def clip_image(x): 29 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX) 30 | elif dataset in ['cifar10', 'Cifar10']: 31 | def clip_image(x): 32 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX) 33 | elif dataset in ['mnist', 'MNIST']: 34 | def clip_image(x): 35 | return torch.clamp(x, -1.0, 1.0) 36 | elif dataset == 'gtsrb': 37 | def clip_image(x): 38 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX) 39 | else: 40 | raise Exception(f'Invalid dataset: {dataset}') 41 | return clip_image 42 | 43 | def apply_grad_mask(model, mask_grad_list): 44 | mask_grad_list_copy = iter(mask_grad_list) 45 | # print(f"mask_grad_list_copy: {mask_grad_list_copy}") 46 | for name, parms in model.named_parameters(): 47 | if parms.requires_grad: 48 | parms.grad = parms.grad * next(mask_grad_list_copy) 49 | 50 | 51 | class IBASynthesizer(Synthesizer): 52 | 53 | def __init__(self, task: Task): 54 | super().__init__(task) 55 | # self.make_pattern(self.pattern_tensor, self.x_top, self.y_top) 56 | tgt_model_path = "/home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers/checkpoint/lira/lira_cifar10_vgg9_0.03.pt" 57 | self.atk_model = self.load_adversarial_model(tgt_model_path, self.params.task, self.params.device) 58 | 59 | def get_target_transform(self, target_label, mode="all2one", num_classes=10): 60 | """ 61 | Get target transform function 62 | """ 63 | if mode == "all2one": 64 | target_transform = lambda x: torch.ones_like(x) * target_label 65 | elif mode == 'all2all': 66 | target_transform = lambda x: (x + 1) % num_classes 67 | else: 68 | raise Exception(f'Invalid mode {mode}') 69 | 70 | return target_transform 71 | 72 | def create_trigger_model(self, dataset, device="cpu", attack_model=None): 73 | """ Create trigger model for IBA """ 74 | # print(f"device: {device}") 75 | # /home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers 76 | if dataset in ['cifar10', 'Cifar10']: 77 | from synthesizers.attack_models.unet import UNet 78 | atkmodel = UNet(3).to(device) 79 | elif dataset in ['mnist', 'MNIST']: 80 | from synthesizers.attack_models.autoencoders import MNISTAutoencoder as Autoencoder 81 | atkmodel = Autoencoder().to(device) 82 | elif dataset == 
'timagenet' or dataset == 'tiny-imagenet32' or dataset == 'gtsrb' or dataset == 'Imagenet':
83 | if attack_model is None:
84 | from synthesizers.attack_models.autoencoders import Autoencoder
85 | atkmodel = Autoencoder().to(device)
86 | elif attack_model == 'unet':
87 | from synthesizers.attack_models.unet import UNet
88 | atkmodel = UNet(3).to(device)
89 | else:
90 | raise Exception(f'Invalid attack model {attack_model} for dataset {dataset}')
91 | return atkmodel
92 | 
93 | def load_adversarial_model(self, checkpoint_path, dataset, device):
94 | # print(f"CKPT file for attack model: {checkpoint_path}")
95 | atkmodel = self.create_trigger_model(dataset, device)
96 | atkmodel.load_state_dict(torch.load(checkpoint_path))
97 | # print(colored("Load attack model successfully!", "red"))
98 | return atkmodel
99 | 
100 | def get_poison_batch_adversarial(self, bptt, dataset, device, evaluation=False, target_transform=None, atk_model=None, ratio=0.75, atk_eps=0.75):
101 | images, targets = bptt
102 | poisoning_per_batch = int(ratio*len(images))
103 | images = images.to(device)
104 | targets = targets.to(device)
105 | poison_count = 0
106 | 
107 | clip_image = get_clip_image(dataset)
108 | 
109 | with torch.no_grad():
110 | noise = atk_model(images) * atk_eps
111 | atkdata = clip_image(images + noise)
112 | if target_transform:
113 | atktarget = target_transform(targets)
114 | else: atktarget = targets.to(device)
115 | if not evaluation:
116 | atkdata = atkdata[:poisoning_per_batch]
117 | atktarget = atktarget[:poisoning_per_batch]
118 | poison_count = len(atkdata)
119 | 
120 | return atkdata.to(device), atktarget.to(device), poison_count
121 | 
122 | def make_pattern(self, pattern_tensor, x_top, y_top):
123 | # TRAIN IBA
124 | # for e in range(self.params.iba_epochs):
125 | # self.train_iba(self.task.model, self.task.atk_model, self.task.tgt_model, self.task.optimizer, self.task.atkmodel_optimizer, self.task.train_loader,
126 | # self.task.criterion, atkmodel_train=True, device=self.params.device, logger=self.params.logger,
127 | # adv_optimizer=self.task.adv_optimizer, clip_image=get_clip_image(self.params.dataset), target_transform=self.task.target_transform,
128 | # dataset=self.params.dataset, mu=self.params.mu, aggregator=self.params.aggregator, attack_alpha=self.params.attack_alpha,
129 | # attack_portion=self.params.attack_portion, atk_eps=self.params.atk_eps, pgd_attack=self.params.pgd_attack,
130 | # proj=self.params.proj, pgd_eps=self.params.pgd_eps, project_frequency=self.params.project_frequency,
131 | # mask_grad_list=self.task.mask_grad_list, model_original=self.task.model_original, local_e=e)
132 | # self.task.copy_params(self.task.model, self.task.atk_model)
133 | # Load model from file and train IBA
134 | # from torch import optim
135 | # atk_optimizer = optim.Adam(tgt_model.parameters(), lr=0.00005)
136 | # tgt_model_path = "/home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers/checkpoint/lira/lira_cifar10_vgg9_0.03.pt"
137 | # target_transform = self.get_target_transform(0, mode="all2one", num_classes=10)
138 | 
139 | # # task in [Imagenet, Cifar10, MNIST]
140 | 
141 | # tgt_model = self.load_adversarial_model(tgt_model_path, self.params.task, self.params.device)
142 | 
143 | # target_transform = self.get_target_transform(self.params.backdoor_label, mode="all2one", num_classes=self.params.num_classes)
144 | 
145 | # poison_data, poison_target, poison_num = self.get_poison_batch_adversarial(batch_cp, dataset=dataset, device=device,
146 | # target_transform=target_transform, atk_model=atk_model)
147 | 
148 | # 
self.train_lira(net, tgt_model, None, None, atk_optimizer, local_train_dl, 149 | # criterion, 0.5, 0.5, 1.0, None, tgt_tf, 1, 150 | # atkmodel_train=True, device=self.device, pgd_attack=False) 151 | 152 | pass 153 | def train_iba(self, model, atkmodel, tgtmodel, optimizer, atkmodel_optimizer, train_loader, criterion, atkmodel_train=False, 154 | device=None, logger=None, adv_optimizer=None, clip_image=None, target_transform=None, dataset=None, 155 | mu=0.1, aggregator="fedprox", attack_alpha=1.0, attack_portion=1.0, atk_eps=0.1, pgd_attack=False, 156 | proj="l_inf", pgd_eps=0.3, project_frequency=1, mask_grad_list=None, model_original=None, local_e=0): 157 | 158 | wg_clone = copy.deepcopy(model) 159 | loss_fn = nn.CrossEntropyLoss() 160 | func_fn = loss_fn 161 | 162 | correct_clean = 0 163 | correct_poison = 0 164 | 165 | poison_size = 0 166 | clean_size = 0 167 | loss_list = [] 168 | 169 | if not atkmodel_train: 170 | model.train() 171 | # Sub-training phase 172 | for batch_idx, batch in enumerate(train_loader): 173 | bs = len(batch) 174 | data, targets = batch 175 | # data, target = data.to(device), target.to(device) 176 | # clean_images, clean_targets, poison_images, poison_targets, poisoning_per_batch = get_poison_batch(batch, attack_portion) 177 | clean_images, clean_targets = copy.deepcopy(data).to(device), copy.deepcopy(targets).to(device) 178 | poison_images, poison_targets = copy.deepcopy(data).to(device), copy.deepcopy(targets).to(device) 179 | # dataset_size += len(data) 180 | clean_size += len(clean_images) 181 | optimizer.zero_grad() 182 | if pgd_attack: 183 | adv_optimizer.zero_grad() 184 | output = model(clean_images) 185 | loss_clean = loss_fn(output, clean_targets) 186 | 187 | if attack_alpha == 1.0: 188 | optimizer.zero_grad() 189 | loss_clean.backward() 190 | if not pgd_attack: 191 | optimizer.step() 192 | else: 193 | if proj == "l_inf": 194 | w = list(model.parameters()) 195 | # adversarial learning rate 196 | eta = 0.001 197 | for i in range(len(w)): 198 | # uncomment below line to restrict proj to some layers 199 | if True:#i == 6 or i == 8 or i == 10 or i == 0 or i == 18: 200 | w[i].data = w[i].data - eta * w[i].grad.data 201 | # projection step 202 | m1 = torch.lt(torch.sub(w[i], model_original[i]), -pgd_eps) 203 | m2 = torch.gt(torch.sub(w[i], model_original[i]), pgd_eps) 204 | w1 = (model_original[i] - pgd_eps) * m1 205 | w2 = (model_original[i] + pgd_eps) * m2 206 | w3 = (w[i]) * (~(m1+m2)) 207 | wf = w1+w2+w3 208 | w[i].data = wf.data 209 | else: 210 | # do l2_projection 211 | adv_optimizer.step() 212 | w = list(model.parameters()) 213 | w_vec = parameters_to_vector(w) 214 | model_original_vec = parameters_to_vector(model_original) 215 | # make sure you project on last iteration otherwise, high LR pushes you really far 216 | # Start 217 | if (batch_idx%project_frequency == 0 or batch_idx == len(train_loader)-1) and (torch.norm(w_vec - model_original_vec) > pgd_eps): 218 | # project back into norm ball 219 | w_proj_vec = pgd_eps*(w_vec - model_original_vec)/torch.norm( 220 | w_vec-model_original_vec) + model_original_vec 221 | # plug w_proj back into model 222 | vector_to_parameters(w_proj_vec, w) 223 | 224 | pred = output.data.max(1)[1] # get the index of the max log-probability 225 | loss_list.append(loss_clean.item()) 226 | correct_clean += pred.eq(clean_targets.data.view_as(pred)).cpu().sum().item() 227 | else: 228 | if attack_alpha < 1.0: 229 | poison_size += len(poison_images) 230 | # poison_images, poison_targets = poison_images.to(device), 
poison_targets.to(device) 231 | with torch.no_grad(): 232 | noise = tgtmodel(poison_images) * atk_eps 233 | atkdata = clip_image(poison_images + noise) 234 | atktarget = target_transform(poison_targets) 235 | # atkdata.requires_grad_(False) 236 | # atktarget.requires_grad_(False) 237 | if attack_portion < 1.0: 238 | atkdata = atkdata[:int(attack_portion*bs)] 239 | atktarget = atktarget[:int(attack_portion*bs)] 240 | atkoutput = model(atkdata.detach()) 241 | loss_poison = F.cross_entropy(atkoutput, atktarget.detach()) 242 | else: 243 | loss_poison = torch.tensor(0.0).to(device) 244 | loss2 = loss_clean * attack_alpha + (1.0 - attack_alpha) * loss_poison 245 | 246 | optimizer.zero_grad() 247 | loss2.backward() 248 | if mask_grad_list: 249 | apply_grad_mask(model, mask_grad_list) 250 | if not pgd_attack: 251 | optimizer.step() 252 | else: 253 | if proj == "l_inf": 254 | w = list(model.parameters()) 255 | # adversarial learning rate 256 | eta = 0.001 257 | for i in range(len(w)): 258 | # uncomment below line to restrict proj to some layers 259 | if True:#i == 6 or i == 8 or i == 10 or i == 0 or i == 18: 260 | w[i].data = w[i].data - eta * w[i].grad.data 261 | # projection step 262 | m1 = torch.lt(torch.sub(w[i], model_original[i]), -pgd_eps) 263 | m2 = torch.gt(torch.sub(w[i], model_original[i]), pgd_eps) 264 | w1 = (model_original[i] - pgd_eps) * m1 265 | w2 = (model_original[i] + pgd_eps) * m2 266 | w3 = (w[i]) * (~(m1+m2)) 267 | wf = w1+w2+w3 268 | w[i].data = wf.data 269 | else: 270 | # do l2_projection 271 | adv_optimizer.step() 272 | w = list(model.parameters()) 273 | w_vec = parameters_to_vector(w) 274 | model_original_vec = parameters_to_vector(list(model_original.parameters())) 275 | # make sure you project on last iteration otherwise, high LR pushes you really far 276 | if (local_e%project_frequency == 0 and batch_idx == len(train_loader)-1) and (torch.norm(w_vec - model_original_vec) > pgd_eps): 277 | # project back into norm ball 278 | w_proj_vec = pgd_eps*(w_vec - model_original_vec)/torch.norm( 279 | w_vec-model_original_vec) + model_original_vec 280 | 281 | 282 | # plug w_proj back into model 283 | vector_to_parameters(w_proj_vec, w) 284 | 285 | loss_list.append(loss2.item()) 286 | pred = output.data.max(1)[1] # get the index of the max log-probability 287 | poison_pred = atkoutput.data.max(1)[1] # get the index of the max log-probability 288 | 289 | # correct += pred.eq(target.data.view_as(pred)).cpu().sum().item() 290 | correct_clean += pred.eq(clean_targets.data.view_as(pred)).cpu().sum().item() 291 | correct_poison += poison_pred.eq(atktarget.data.view_as(poison_pred)).cpu().sum().item() 292 | 293 | else: 294 | model.eval() 295 | # atk_optimizer = optim.Adam(atkmodel.parameters(), lr=0.0002) 296 | atkmodel.train() 297 | # optimizer.zero_grad() 298 | for batch_idx, (batch) in enumerate(train_loader): 299 | batch_cp = copy.deepcopy(batch) 300 | data, target = batch 301 | # print(f"len(clean_data): {len(clean_data)}") 302 | data, target = data.to(device), target.to(device) 303 | bs = data.size(0) 304 | atkdata, atktarget, poison_num = get_poison_batch_adversarial_updated(batch_cp, dataset=dataset, device=device, 305 | target_transform=target_transform, atk_model=atkmodel) 306 | # dataset_size += len(data) 307 | poison_size += poison_num 308 | 309 | ############################### 310 | #### Update the classifier #### 311 | ############################### 312 | atkoutput = model(atkdata) 313 | loss_p = func_fn(atkoutput, atktarget) 314 | loss2 = loss_p 315 | 316 | 
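# Trigger-generator update: only atkmodel_optimizer.step() is called in this
# branch, so the poison loss below trains the trigger generator while the
# classifier's weights stay fixed for this phase.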
atkmodel_optimizer.zero_grad() 317 | loss2.backward() 318 | atkmodel_optimizer.step() 319 | pred = atkoutput.data.max(1)[1] # get the index of the max log-probability 320 | correct_poison += pred.eq(atktarget.data.view_as(pred)).cpu().sum().item() 321 | loss_list.append(loss2.item()) 322 | 323 | # acc = 100.0 * (float(correct) / float(dataset_size)) 324 | clean_acc = 100.0 * (float(correct_clean)/float(clean_size)) if clean_size else 0.0 325 | poison_acc = 100.0 * (float(correct_poison)/float(poison_size)) if poison_size else 0.0 326 | # poison_acc = 100.0 - poison_acc 327 | 328 | training_avg_loss = sum(loss_list)/len(loss_list) 329 | # training_avg_loss = 0.0 330 | if atkmodel_train: 331 | logger.info(colored("Training loss = {:.2f}, acc = {:.2f} of atk model this epoch".format(training_avg_loss, poison_acc), "yellow")) 332 | else: 333 | logger.info(colored("Training loss = {:.2f}, acc = {:.2f} of cls model this epoch".format(training_avg_loss, clean_acc), "yellow")) 334 | logger.info("Training clean_acc is {:.2f}, poison_acc = {:.2f}".format(clean_acc, poison_acc)) 335 | del wg_clone 336 | 337 | 338 | def synthesize_inputs(self, batch, attack_portion=128): 339 | # TODO something 340 | 341 | # batch.inputs[:attack_portion] = batch.inputs[:attack_portion] 342 | # from torch import optim 343 | # atk_optimizer = optim.Adam(tgt_model.parameters(), lr=0.00005) 344 | 345 | 346 | 347 | # target_transform = self.get_target_transform(self.params.backdoor_label, mode="all2one", num_classes=self.params.num_classes) 348 | 349 | 350 | images = batch.inputs[:attack_portion] 351 | 352 | 353 | device = self.params.device 354 | images = images.to(device) 355 | 356 | clip_image = get_clip_image(self.params.task) 357 | atk_eps = 0.75 358 | 359 | with torch.no_grad(): 360 | noise = self.atk_model(images) * atk_eps 361 | atkdata = clip_image(images + noise) 362 | 363 | 364 | batch.inputs[:attack_portion] = atkdata.to(device) 365 | 366 | 367 | def synthesize_labels(self, batch, attack_portion=None): 368 | batch.labels[:attack_portion].fill_(self.params.backdoor_label) 369 | -------------------------------------------------------------------------------- /synthesizers/pattern_synthesizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torchvision.transforms import transforms, functional 5 | 6 | from synthesizers.synthesizer import Synthesizer 7 | from tasks.task import Task 8 | 9 | transform_to_image = transforms.ToPILImage() 10 | transform_to_tensor = transforms.ToTensor() 11 | 12 | 13 | class PatternSynthesizer(Synthesizer): 14 | pattern_tensor: torch.Tensor = torch.tensor([ 15 | [1., 0., 1.], 16 | [-10., 1., -10.], 17 | [-10., -10., 0.], 18 | [-10., 1., -10.], 19 | [1., 0., 1.] 20 | ]) 21 | "Just some random 2D pattern." 22 | 23 | x_top = 3 24 | "X coordinate to put the backdoor into." 25 | y_top = 23 26 | "Y coordinate to put the backdoor into." 27 | 28 | mask_value = -10 29 | "A tensor coordinate with this value won't be applied to the image." 30 | 31 | resize_scale = (5, 10) 32 | "If the pattern is dynamically placed, resize the pattern." 33 | 34 | mask: torch.Tensor = None 35 | "A mask used to combine backdoor pattern with the original image." 36 | 37 | pattern: torch.Tensor = None 38 | "A tensor of the `input.shape` filled with `mask_value` except backdoor." 
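# How `mask` and `pattern` combine at poisoning time (a sketch; shapes assume
# a 3x32x32 CIFAR-10 input, mirroring make_pattern/synthesize_inputs below):
#   full_image = torch.full(input_shape, mask_value)          # everything -10
#   full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor  # paste trigger
#   mask = (full_image != mask_value).float()                 # 1 on the trigger
#   pattern = normalize(full_image)
#   poisoned = (1 - mask) * image + mask * pattern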
39 | 40 | def __init__(self, task: Task): 41 | super().__init__(task) 42 | self.make_pattern(self.pattern_tensor, self.x_top, self.y_top) 43 | 44 | def make_pattern(self, pattern_tensor, x_top, y_top): 45 | # put pattern into the image 46 | full_image = torch.zeros(self.params.input_shape) 47 | full_image.fill_(self.mask_value) 48 | # full image has a pixel value of -10 49 | x_bot = x_top + pattern_tensor.shape[0] 50 | y_bot = y_top + pattern_tensor.shape[1] 51 | 52 | if x_bot >= self.params.input_shape[1] or \ 53 | y_bot >= self.params.input_shape[2]: 54 | raise ValueError(f'Position of backdoor outside image limits:' 55 | f'image: {self.params.input_shape}, but backdoor' 56 | f'ends at ({x_bot}, {y_bot})') 57 | 58 | full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor 59 | # full image has a pixel value of -10 except for the backdoor (pattern_tensor) size: 5 * 3 60 | self.mask = 1 * (full_image != self.mask_value).to(self.params.device) # (0, 1) 61 | # mask is a tensor of 0 and 1, 0 for -10 and 1 for other values 62 | self.pattern = self.task.normalize(full_image).to(self.params.device) # )(-52.5678, 2.7537) 63 | # backdoor pattern has a ->- shape 64 | # patter 65 | ### 1 1 1 66 | ### 0 1 0 67 | ### 0 0 1 68 | ### 0 1 0 69 | ### 1 1 1 70 | # if pattern_tensor size is (11, 7) = (5, 3) * resize; then pattern size is square of 11 * 11 71 | 72 | # import IPython; IPython.embed() 73 | # exit(0) 74 | 75 | def synthesize_inputs(self, batch, attack_portion=None): 76 | pattern, mask = self.get_pattern() 77 | # mask value (0, 1); value 0, keep the original image; value 1, replace with pattern 78 | batch.inputs[:attack_portion] = (1 - mask) * \ 79 | batch.inputs[:attack_portion] + \ 80 | mask * pattern 81 | 82 | return 83 | 84 | def synthesize_labels(self, batch, attack_portion=None): 85 | batch.labels[:attack_portion].fill_(self.params.backdoor_label) 86 | 87 | return 88 | 89 | def get_pattern(self): 90 | if self.params.backdoor_dynamic_position: 91 | resize = random.randint(self.resize_scale[0], self.resize_scale[1]) 92 | pattern = self.pattern_tensor 93 | if random.random() > 0.5: 94 | pattern = functional.hflip(pattern) 95 | image = transform_to_image(pattern) 96 | pattern = transform_to_tensor( 97 | functional.resize(image, 98 | resize, interpolation=0)).squeeze() 99 | 100 | x = random.randint(0, self.params.input_shape[1] \ 101 | - pattern.shape[0] - 1) 102 | y = random.randint(0, self.params.input_shape[2] \ 103 | - pattern.shape[1] - 1) 104 | self.make_pattern(pattern, x, y) 105 | 106 | return self.pattern, self.mask 107 | -------------------------------------------------------------------------------- /synthesizers/physical_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from synthesizers.synthesizer import Synthesizer 3 | 4 | 5 | class PhysicalSynthesizer(Synthesizer): 6 | """ 7 | For physical backdoors it's ok to train using pixel pattern that 8 | represents the physical object in the real scene. 
9 | """ 10 | 11 | pattern_tensor = torch.tensor([[1.]]) -------------------------------------------------------------------------------- /synthesizers/singlepixel_synthesizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from synthesizers.pattern_synthesizer import PatternSynthesizer 3 | 4 | 5 | class SinglePixelSynthesizer(PatternSynthesizer): 6 | pattern_tensor = torch.tensor([[1.]]) 7 | -------------------------------------------------------------------------------- /synthesizers/synthesizer.py: -------------------------------------------------------------------------------- 1 | from tasks.batch import Batch 2 | from tasks.task import Task 3 | from utils.parameters import Params 4 | 5 | 6 | class Synthesizer: 7 | params: Params 8 | task: Task 9 | 10 | def __init__(self, task: Task): 11 | self.task = task 12 | self.params = task.params 13 | 14 | def make_backdoor_batch(self, batch: Batch, test=False, attack=True) -> Batch: 15 | 16 | # Don't attack if only normal loss task. 17 | if not attack: 18 | return batch 19 | 20 | if test: 21 | attack_portion = batch.batch_size 22 | else: 23 | attack_portion = round( 24 | batch.batch_size * self.params.poisoning_proportion) 25 | 26 | backdoored_batch = batch.clone() 27 | self.apply_backdoor(backdoored_batch, attack_portion) 28 | return backdoored_batch 29 | 30 | ### Plot the backdoored image and the original image 31 | # import IPython; IPython.embed(); exit(0) 32 | # batch.inputs.shape = (batch_size, 3, 32, 32) = (64, 3, 32, 32) 33 | # using torch to show the image 34 | import matplotlib.pyplot as plt 35 | import numpy as np 36 | import torchvision 37 | import torchvision.transforms as transforms 38 | def imshow(img): 39 | # img to cpu 40 | img = img.cpu() 41 | # img = img / 2 + 0.5 # unnormalize 42 | npimg = img.numpy() 43 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 44 | plt.show() 45 | imshow(torchvision.utils.make_grid(batch.inputs)) 46 | 47 | # def imshow2(img): 48 | # # img to cpu 49 | # img = img.cpu() 50 | # img = img / 2 + 0.5 # unnormalize 51 | # npimg = img.numpy() 52 | # plt.imshow(np.transpose(npimg, (1, 2, 0))) 53 | # plt.show() 54 | 55 | 56 | # imshow(torchvision.utils.make_grid(batch.inputs)) 57 | # imshow(torchvision.utils.make_grid(backdoored_batch.inputs)) 58 | # import IPython; IPython.embed(); exit(0) 59 | 60 | 61 | 62 | def apply_backdoor(self, batch, attack_portion): 63 | """ 64 | Modifies only a portion of the batch (represents batch poisoning). 65 | 66 | :param batch: 67 | :return: 68 | """ 69 | self.synthesize_inputs(batch=batch, attack_portion=attack_portion) 70 | self.synthesize_labels(batch=batch, attack_portion=attack_portion) 71 | 72 | return 73 | 74 | def synthesize_inputs(self, batch, attack_portion=None): 75 | raise NotImplemented 76 | 77 | def synthesize_labels(self, batch, attack_portion=None): 78 | raise NotImplemented 79 | -------------------------------------------------------------------------------- /tasks/batch.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | 5 | @dataclass 6 | class Batch: 7 | batch_id: int 8 | inputs: torch.Tensor 9 | labels: torch.Tensor 10 | 11 | # For PIPA experiment we use this field to store identity label. 
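# (ComplexSynthesizer, for example, reads `aux` to assign per-sample poison labels.)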
12 | aux: torch.Tensor = None 13 | 14 | def __post_init__(self): 15 | self.batch_size = self.inputs.shape[0] 16 | 17 | def to(self, device): 18 | inputs = self.inputs.to(device) 19 | labels = self.labels.to(device) 20 | if self.aux is not None: 21 | aux = self.aux.to(device) 22 | else: 23 | aux = None 24 | return Batch(self.batch_id, inputs, labels, aux) 25 | 26 | def clone(self): 27 | inputs = self.inputs.clone() 28 | labels = self.labels.clone() 29 | if self.aux is not None: 30 | aux = self.aux.clone() 31 | else: 32 | aux = None 33 | return Batch(self.batch_id, inputs, labels, aux) 34 | 35 | 36 | def clip(self, batch_size): 37 | if batch_size is None: 38 | return self 39 | 40 | inputs = self.inputs[:batch_size] 41 | labels = self.labels[:batch_size] 42 | 43 | if self.aux is None: 44 | aux = None 45 | else: 46 | aux = self.aux[:batch_size] 47 | 48 | return Batch(self.batch_id, inputs, labels, aux) -------------------------------------------------------------------------------- /tasks/cifar10_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data import Subset 6 | from torchvision.transforms import transforms 7 | 8 | # from models.resnet import resnet18 9 | from models.resnet_cifar import ResNet18 10 | 11 | from tasks.task import Task 12 | 13 | 14 | class Cifar10Task(Task): 15 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), 16 | (0.2023, 0.1994, 0.2010)) 17 | 18 | def load_data(self): 19 | self.load_cifar_data() 20 | number_of_samples = [] 21 | 22 | if self.params.fl_sample_dirichlet: 23 | # sample indices for participants using Dirichlet distribution 24 | split = min(self.params.fl_total_participants / 10, 1) 25 | all_range = list(range(int(len(self.train_dataset) * split))) 26 | self.train_dataset = Subset(self.train_dataset, all_range) 27 | indices_per_participant = self.sample_dirichlet_train_data( 28 | self.params.fl_total_participants, 29 | alpha=self.params.fl_dirichlet_alpha) 30 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in 31 | indices_per_participant.items()]) 32 | # print("number_of_samples", number_of_samples) 33 | 34 | else: 35 | # sample indices for participants that are equally 36 | # split to 500 images per participant 37 | split = min(self.params.fl_total_participants / 100, 1) 38 | all_range = list(range(int(len(self.train_dataset) * split))) 39 | self.train_dataset = Subset(self.train_dataset, all_range) 40 | random.shuffle(all_range) 41 | 42 | train_loaders, number_of_samples = zip(*[self.get_train_old(all_range, pos) 43 | for pos in range(self.params.fl_total_participants)]) 44 | 45 | self.fl_train_loaders = train_loaders 46 | self.fl_number_of_samples = number_of_samples 47 | # print("fl_number_of_samples", self.fl_number_of_samples) 48 | 49 | 50 | def load_cifar_data(self): 51 | # print(f"Loading CIFAR10 data: {self.params.transform_train}") 52 | # exit(0) 53 | if self.params.transform_train: 54 | transform_train = transforms.Compose([ 55 | transforms.RandomCrop(32, padding=4), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | self.normalize, 59 | ]) 60 | else: 61 | transform_train = transforms.Compose([ 62 | transforms.ToTensor(), 63 | self.normalize, 64 | ]) 65 | transform_test = transforms.Compose([ 66 | transforms.ToTensor(), 67 | self.normalize, 68 | ]) 69 | self.train_dataset = torchvision.datasets.CIFAR10( 70 | 
root=self.params.data_path, 71 | train=True, 72 | download=True, 73 | transform=transform_train) 74 | 75 | self.train_loader = DataLoader(self.train_dataset, 76 | batch_size=self.params.batch_size, 77 | shuffle=True, 78 | num_workers=0) 79 | self.test_dataset = torchvision.datasets.CIFAR10( 80 | root=self.params.data_path, 81 | train=False, 82 | download=True, 83 | transform=transform_test) 84 | self.test_loader = DataLoader(self.test_dataset, 85 | batch_size=self.params.test_batch_size, 86 | shuffle=False, num_workers=0) 87 | 88 | self.classes = ('plane', 'car', 'bird', 'cat', 89 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 90 | return True 91 | 92 | def build_model(self) -> nn.Module: 93 | # model = resnet18(pretrained=False, 94 | # num_classes=len(self.classes)) 95 | model = ResNet18() 96 | return model -------------------------------------------------------------------------------- /tasks/fl_user.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from torch.utils.data.dataloader import DataLoader 3 | 4 | 5 | 6 | @dataclass 7 | class FLUser: 8 | user_id: int = 0 9 | compromised: bool = False 10 | train_loader: DataLoader = None 11 | number_of_samples: int = 0 12 | -------------------------------------------------------------------------------- /tasks/imagenet_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torchvision 4 | from torch.utils.data import DataLoader 5 | from torchvision.transforms import transforms 6 | from torch.utils.data import Subset 7 | 8 | from models.resnet_tinyimagenet import resnet18 9 | from tasks.task import Task 10 | import os 11 | import logging 12 | logger = logging.getLogger('logger') 13 | 14 | 15 | class ImagenetTask(Task): 16 | 17 | def load_data(self): 18 | self.load_imagenet_data() 19 | if self.params.fl_sample_dirichlet: 20 | # sample indices for participants using Dirichlet distribution 21 | split = min(self.params.fl_total_participants / 100, 1) 22 | all_range = list(range(int(len(self.train_dataset) * split))) 23 | self.train_dataset = Subset(self.train_dataset, all_range) 24 | indices_per_participant = self.sample_dirichlet_train_data( 25 | self.params.fl_total_participants, 26 | alpha=self.params.fl_dirichlet_alpha) 27 | # train_loaders = [self.get_train(indices) for pos, indices in 28 | # indices_per_participant.items()] 29 | 30 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in 31 | indices_per_participant.items()]) 32 | else: 33 | # sample indices for participants that are equally 34 | # split to 500 images per participant 35 | split = min(self.params.fl_total_participants / 100, 1) 36 | all_range = list(range(int(len(self.train_dataset) * split))) 37 | self.train_dataset = Subset(self.train_dataset, all_range) 38 | random.shuffle(all_range) 39 | train_loaders = [self.get_train_old(all_range, pos) 40 | for pos in 41 | range(self.params.fl_total_participants)] 42 | self.fl_train_loaders = train_loaders 43 | self.fl_number_of_samples = number_of_samples 44 | logger.info(f"Done splitting with #participant: {self.params.fl_total_participants}") 45 | 46 | return 47 | 48 | def load_imagenet_data(self): 49 | 50 | train_transform = transforms.Compose([ 51 | # transforms.RandomResizedCrop(224), 52 | transforms.Resize(224), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.ToTensor(), 55 | self.normalize, 56 | ]) 57 | test_transform = transforms.Compose([ 58 | # 
transforms.Resize(256), 59 | # transforms.CenterCrop(224), 60 | transforms.Resize(224), 61 | transforms.ToTensor(), 62 | self.normalize, 63 | ]) 64 | 65 | self.train_dataset = torchvision.datasets.ImageFolder( 66 | os.path.join(self.params.data_path, 'train'), 67 | train_transform) 68 | 69 | self.test_dataset = torchvision.datasets.ImageFolder( 70 | os.path.join(self.params.data_path, 'val'), 71 | test_transform) 72 | 73 | self.train_loader = DataLoader(self.train_dataset, 74 | batch_size=self.params.batch_size, 75 | shuffle=True, num_workers=8, pin_memory=True) 76 | self.test_loader = DataLoader(self.test_dataset, 77 | batch_size=self.params.test_batch_size, 78 | shuffle=False, num_workers=8, pin_memory=True) 79 | 80 | def build_model(self) -> None: 81 | return resnet18(pretrained=False) 82 | -------------------------------------------------------------------------------- /tasks/mnist_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.utils.data import Subset 3 | import torch.utils.data as torch_data 4 | import torchvision 5 | from torchvision.transforms import transforms 6 | 7 | from models.MnistNet import MnistNet 8 | from tasks.task import Task 9 | import logging 10 | logger = logging.getLogger('logger') 11 | 12 | class MNISTTask(Task): 13 | normalize = transforms.Normalize((0.1307,), (0.3081,)) 14 | 15 | def load_data(self): 16 | self.load_mnist_data() 17 | if self.params.fl_sample_dirichlet: 18 | # sample indices for participants using Dirichlet distribution 19 | split = min(self.params.fl_total_participants / 20, 1) 20 | all_range = list(range(int(len(self.train_dataset) * split))) 21 | logger.info(f"all_range: {len(all_range)} len train_dataset: {len(self.train_dataset)}") 22 | # if number of participants is less than 20, then we will sample a subset of the dataset, otherwise we will use the whole dataset 23 | self.train_dataset = Subset(self.train_dataset, all_range) 24 | indices_per_participant = self.sample_dirichlet_train_data( 25 | self.params.fl_total_participants, 26 | alpha=self.params.fl_dirichlet_alpha) 27 | 28 | # train_loaders = [self.get_train(indices) for pos, indices in 29 | # indices_per_participant.items()] 30 | 31 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in 32 | indices_per_participant.items()]) 33 | 34 | else: 35 | # sample indices for participants that are equally 36 | split = min(self.params.fl_total_participants / 20, 1) 37 | all_range = list(range(int(len(self.train_dataset) * split))) 38 | self.train_dataset = Subset(self.train_dataset, all_range) 39 | random.shuffle(all_range) 40 | train_loaders, number_of_samples = zip(*[self.get_train_old(all_range, pos) 41 | for pos in range(self.params.fl_total_participants)]) 42 | 43 | self.fl_train_loaders = train_loaders 44 | self.fl_number_of_samples = number_of_samples 45 | logger.info(f"Done splitting with #participant: {self.params.fl_total_participants}") 46 | return 47 | 48 | 49 | def load_mnist_data(self): 50 | transform_train = transforms.Compose([ 51 | transforms.ToTensor(), 52 | self.normalize 53 | ]) 54 | 55 | transform_test = transforms.Compose([ 56 | transforms.ToTensor(), 57 | self.normalize 58 | ]) 59 | 60 | self.train_dataset = torchvision.datasets.MNIST( 61 | root=self.params.data_path, 62 | train=True, 63 | download=True, 64 | transform=transform_train) 65 | self.train_loader = torch_data.DataLoader(self.train_dataset, 66 | batch_size=self.params.batch_size, 67 | shuffle=True, 68 | 
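# num_workers=0: load batches in the main process (no worker subprocesses)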
num_workers=0) 69 | self.test_dataset = torchvision.datasets.MNIST( 70 | root=self.params.data_path, 71 | train=False, 72 | download=True, 73 | transform=transform_test) 74 | self.test_loader = torch_data.DataLoader(self.test_dataset, 75 | batch_size=self.params.test_batch_size, 76 | shuffle=False, 77 | num_workers=0) 78 | self.classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) 79 | return True 80 | 81 | def build_model(self): 82 | # return SimpleNet(num_classes=len(self.classes)) 83 | return MnistNet() 84 | -------------------------------------------------------------------------------- /tasks/task.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import List, Any, Dict 4 | from copy import deepcopy 5 | import numpy as np 6 | from collections import defaultdict, Counter 7 | 8 | 9 | import torch 10 | from torch import optim, nn 11 | from torch.nn import Module 12 | from torch.optim import Optimizer 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.sampler import SubsetRandomSampler 15 | from torchvision.transforms import transforms 16 | 17 | from metrics.accuracy_metric import AccuracyMetric 18 | from metrics.metric import Metric 19 | from metrics.test_loss_metric import TestLossMetric 20 | from tasks.batch import Batch 21 | from tasks.fl_user import FLUser 22 | from utils.parameters import Params 23 | 24 | logger = logging.getLogger('logger') 25 | 26 | 27 | class Task: 28 | params: Params = None 29 | 30 | train_dataset = None 31 | test_dataset = None 32 | train_loader = None 33 | test_loader = None 34 | classes = None 35 | 36 | model: Module = None 37 | optimizer: optim.Optimizer = None 38 | criterion: Module = None 39 | metrics: List[Metric] = None 40 | 41 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | "Generic normalization for input data." 
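# Dataset-specific tasks override `normalize` with their own statistics
# (e.g. MNISTTask and Cifar10Task above).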
44 | input_shape: torch.Size = None
45 | 
46 | fl_train_loaders: List[Any] = None
47 | fl_number_of_samples: List[int] = None
48 | 
49 | ignored_weights = ['num_batches_tracked']  # ['tracked', 'running']
50 | adversaries: List[int] = None
51 | 
52 | def __init__(self, params: Params):
53 | self.params = params
54 | self.init_task()
55 | 
56 | def init_task(self):
57 | self.load_data()
58 | # self.stat_data()
59 | logger.debug(f"Number of train samples: {self.fl_number_of_samples}")
60 | # exit(0)
61 | self.model = self.build_model()
62 | self.resume_model()
63 | self.model = self.model.to(self.params.device)
64 | 
65 | self.local_model = self.build_model().to(self.params.device)
66 | 
67 | self.criterion = self.make_criterion()
68 | self.adversaries = self.sample_adversaries()
69 | 
70 | self.optimizer = self.make_optimizer()
71 | 
72 | self.metrics = [AccuracyMetric(), TestLossMetric(self.criterion)]
73 | self.set_input_shape()
74 | # Initialize the logger
75 | fh = logging.FileHandler(
76 | filename=f'{self.params.folder_path}/log.txt')
77 | formatter = logging.Formatter('%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s')
78 | 
79 | 
80 | fh.setFormatter(formatter)
81 | logger.addHandler(fh)
82 | 
83 | def load_data(self) -> None:
84 | raise NotImplementedError
85 | 
86 | def stat_data(self):
87 | for i, loader in enumerate(self.fl_train_loaders):
88 | labels = [t.item() for data, target in loader for t in target]
89 | label_counts = dict(sorted(Counter(labels).items()))
90 | total_samples = sum(len(target) for _, target in loader)
91 | logger.debug(f"Participant {i:2d} has {label_counts} and total {total_samples} samples")
92 | 
93 | 
94 | def build_model(self) -> Module:
95 | raise NotImplementedError
96 | 
97 | def make_criterion(self) -> Module:
98 | """Initialize with Cross Entropy by default.
99 | 
100 | We use reduction `none` to support gradient shaping defense.
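The resulting per-example loss vector lets callers reweight or drop
individual samples before averaging; call `.mean()` to recover the
usual scalar loss.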
101 | :return:
102 | """
103 | return nn.CrossEntropyLoss(reduction='none')
104 | 
105 | def make_optimizer(self, model=None) -> Optimizer:
106 | if model is None:
107 | model = self.model
108 | if self.params.optimizer == 'SGD':
109 | optimizer = optim.SGD(model.parameters(),
110 | lr=self.params.lr,
111 | weight_decay=self.params.decay,
112 | momentum=self.params.momentum)
113 | elif self.params.optimizer == 'Adam':
114 | optimizer = optim.Adam(model.parameters(),
115 | lr=self.params.lr,
116 | weight_decay=self.params.decay)
117 | else:
118 | raise ValueError(f'No optimizer: {self.params.optimizer}')
119 | 
120 | return optimizer
121 | 
122 | def resume_model(self):
123 | if self.params.resume_model:
124 | # import IPython; IPython.embed()
125 | logger.info(f'Resuming training from {self.params.resume_model}')
126 | loaded_params = torch.load(f"{self.params.resume_model}",
127 | map_location=torch.device('cpu'))
128 | self.model.load_state_dict(loaded_params['state_dict'])
129 | self.params.start_epoch = loaded_params['epoch']
130 | # print self.model architecture to file 'model.txt'
131 | # with open(f'model.txt', 'w') as f:
132 | # f.write(str(self.model))
133 | 
134 | # # print architecture of loaded_params
135 | # with open(f'loaded_params.txt', 'w') as f:
136 | # f.write(str(loaded_params['state_dict'].keys()))
137 | 
138 | # self.params.lr = loaded_params.get('lr', self.params.lr)
139 | 
140 | logger.warning(f"Loaded parameters from saved model: LR is"
141 | f" {self.params.lr} and current epoch is"
142 | f" {self.params.start_epoch}")
143 | 
144 | def set_input_shape(self):
145 | inp = self.train_dataset[0][0]
146 | self.params.input_shape = inp.shape
147 | logger.info(f"Input shape is {self.params.input_shape}")
148 | 
149 | def get_batch(self, batch_id, data) -> Batch:
150 | """Process data into a batch.
151 | 
152 | Data loaders differ across datasets; this method unifies their
153 | output by returning an object of class Batch.
154 | :param batch_id: id of the batch
155 | :param data: object returned by the Loader.
156 | :return:
157 | """
158 | inputs, labels = data
159 | batch = Batch(batch_id, inputs, labels)
160 | return batch.to(self.params.device)
161 | 
162 | def accumulate_metrics(self, outputs, labels):
163 | for metric in self.metrics:
164 | metric.accumulate_on_batch(outputs, labels)
165 | 
166 | def reset_metrics(self):
167 | for metric in self.metrics:
168 | metric.reset_metric()
169 | 
170 | def report_metrics(self, step, prefix=''):
171 | metric_text = []
172 | for metric in self.metrics:
173 | metric_text.append(str(metric))
174 | logger.warning(f'{prefix} {step:4d}. 
{" | ".join(metric_text)}') 175 | # import IPython; IPython.embed() 176 | # exit(0) 177 | return self.metrics[0].get_main_metric_value() 178 | def get_metrics(self): 179 | # TODO Nov 11, 2023: This is a hack to get the metrics 180 | metric_dict = { 181 | 'accuracy': self.metrics[0].get_main_metric_value(), 182 | 'loss': self.metrics[1].get_main_metric_value() 183 | } 184 | # for metric in self.metrics: 185 | # metric_dict[metric.name] = metric.get_value() 186 | return metric_dict 187 | 188 | @staticmethod 189 | def get_batch_accuracy(outputs, labels, top_k=(1,)): 190 | """Computes the precision@k for the specified values of k""" 191 | max_k = max(top_k) 192 | batch_size = labels.size(0) 193 | 194 | _, pred = outputs.topk(max_k, 1, True, True) 195 | pred = pred.t() 196 | correct = pred.eq(labels.view(1, -1).expand_as(pred)) 197 | 198 | res = [] 199 | for k in top_k: 200 | correct_k = correct[:k].view(-1).float().sum(0) 201 | res.append((correct_k.mul_(100.0 / batch_size)).item()) 202 | if len(res) == 1: 203 | res = res[0] 204 | return res 205 | 206 | def get_empty_accumulator(self): 207 | weight_accumulator = dict() 208 | for name, data in self.model.state_dict().items(): 209 | weight_accumulator[name] = torch.zeros_like(data) 210 | return weight_accumulator 211 | 212 | def sample_users_for_round(self, epoch) -> List[FLUser]: 213 | sampled_ids = random.sample( 214 | range(self.params.fl_total_participants), 215 | self.params.fl_no_models) 216 | 217 | # if 7 not in sampled_ids: 218 | # sampled_ids[0] = 7 219 | 220 | sampled_users = [] 221 | for pos, user_id in enumerate(sampled_ids): 222 | train_loader = self.fl_train_loaders[user_id] 223 | number_of_samples = self.fl_number_of_samples[user_id] 224 | compromised = self.check_user_compromised(epoch, pos, user_id) 225 | 226 | user = FLUser(user_id, compromised=compromised, 227 | train_loader=train_loader, number_of_samples=number_of_samples) 228 | sampled_users.append(user) 229 | logger.warning(f"Sampled users for round {epoch}: {[[user.user_id, user.compromised] for user in sampled_users]} ") 230 | 231 | 232 | total_samples = sum([user.number_of_samples for user in sampled_users]) 233 | 234 | self.params.fl_weight_contribution = {user.user_id: user.number_of_samples / total_samples for user in sampled_users} 235 | # self.params.fl_round_participants = [user.user_id for user in sampled_users] 236 | self.params.fl_local_updated_models = dict() 237 | logger.warning(f"Sampled users for round {epoch}: {self.params.fl_weight_contribution} ") 238 | return sampled_users 239 | 240 | def adding_local_updated_model(self, local_update: Dict[str, torch.Tensor], epoch=None, user_id=None): 241 | self.params.fl_local_updated_models[user_id] = local_update 242 | 243 | def check_user_compromised(self, epoch, pos, user_id): 244 | """Check if the sampled user is compromised for the attack. 245 | 246 | If single_epoch_attack is defined (eg not None) then ignore 247 | :param epoch: 248 | :param pos: 249 | :param user_id: 250 | :return: 251 | """ 252 | compromised = False 253 | if self.params.fl_single_epoch_attack is not None: 254 | if epoch == self.params.fl_single_epoch_attack: 255 | if pos < self.params.fl_number_of_adversaries: 256 | compromised = True 257 | # if user_id == 0: 258 | # compromised = True 259 | logger.warning(f'Attacking once at epoch {epoch}. 
Compromised' 260 | f' user: {user_id}.') 261 | else: 262 | if epoch >= self.params.poison_epoch and epoch < self.params.poison_epoch_stop + 1: 263 | compromised = user_id in self.adversaries 264 | return compromised 265 | 266 | def sample_adversaries(self) -> List[int]: 267 | adversaries_ids = [] 268 | if self.params.fl_number_of_adversaries == 0: 269 | logger.warning(f'Running vanilla FL, no attack.') 270 | elif self.params.fl_single_epoch_attack is None: 271 | # adversaries_ids = list(range(self.params.fl_number_of_adversaries)) 272 | adversaries_ids = random.sample( 273 | range(self.params.fl_total_participants), 274 | self.params.fl_number_of_adversaries) 275 | 276 | logger.warning(f'Attacking over multiple epochs with following ' 277 | f'users compromised: {adversaries_ids}.') 278 | else: 279 | logger.warning(f'Attack only on epoch: ' 280 | f'{self.params.fl_single_epoch_attack} with ' 281 | f'{self.params.fl_number_of_adversaries} compromised' 282 | f' users.') 283 | 284 | return adversaries_ids 285 | 286 | def get_model_optimizer(self, model): 287 | local_model = deepcopy(model) 288 | local_model = local_model.to(self.params.device) 289 | 290 | optimizer = self.make_optimizer(local_model) 291 | 292 | return local_model, optimizer 293 | 294 | def copy_params(self, global_model: Module, local_model: Module): 295 | local_state = local_model.state_dict() 296 | for name, param in global_model.state_dict().items(): 297 | if name in local_state and name not in self.ignored_weights: 298 | local_state[name].copy_(param) 299 | 300 | def update_global_model(self, weight_accumulator, global_model: Module): 301 | # self.last_global_model = deepcopy(self.model) 302 | for name, sum_update in weight_accumulator.items(): 303 | if self.check_ignored_weights(name): 304 | continue 305 | # scale = self.params.fl_eta / self.params.fl_total_participants 306 | # TODO: change this based on number of sample of each user 307 | # scale = 1 self.params.fl_eta / self.params.fl_no_models 308 | # scale = 1.0 309 | # average_update = scale * sum_update 310 | average_update = sum_update 311 | model_weight = global_model.state_dict()[name] 312 | model_weight.add_(average_update) 313 | 314 | def check_ignored_weights(self, name) -> bool: 315 | for ignored in self.ignored_weights: 316 | if ignored in name: 317 | return True 318 | 319 | return False 320 | 321 | def sample_dirichlet_train_data(self, no_participants, alpha=0.9): 322 | """ 323 | Input: Number of participants and alpha (param for distribution) 324 | Output: A list of indices denoting data in CIFAR training set. 325 | Requires: dataset_classes, a preprocessed class-indices dictionary. 326 | Sample Method: take a uniformly sampled 10-dimension vector as 327 | parameters for 328 | dirichlet distribution to sample number of images in each class. 
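Example (a sketch): with 3 participants and alpha = 0.5, a class of
5000 images might draw proportions [0.70, 0.25, 0.05] from
np.random.dirichlet([0.5, 0.5, 0.5]), splitting that class roughly
3500/1250/250 across the participants; smaller alpha gives more
skewed (less IID) splits.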
329 | """ 330 | 331 | dataset_classes = {} 332 | for ind, x in enumerate(self.train_dataset): 333 | _, label = x 334 | if label in dataset_classes: 335 | dataset_classes[label].append(ind) 336 | else: 337 | dataset_classes[label] = [ind] 338 | class_size = len(dataset_classes[0]) 339 | per_participant_list = defaultdict(list) 340 | no_classes = len(dataset_classes.keys()) 341 | 342 | for n in range(no_classes): 343 | random.shuffle(dataset_classes[n]) 344 | sampled_probabilities = class_size * np.random.dirichlet( 345 | np.array(no_participants * [alpha])) 346 | for user in range(no_participants): 347 | no_imgs = int(round(sampled_probabilities[user])) 348 | sampled_list = dataset_classes[n][ 349 | :min(len(dataset_classes[n]), no_imgs)] 350 | per_participant_list[user].extend(sampled_list) 351 | dataset_classes[n] = dataset_classes[n][ 352 | min(len(dataset_classes[n]), no_imgs):] 353 | 354 | return per_participant_list 355 | 356 | def get_train(self, indices): 357 | """ 358 | This method is used along with Dirichlet distribution 359 | :param indices: 360 | :return: 361 | """ 362 | train_loader = DataLoader(self.train_dataset, 363 | batch_size=self.params.batch_size, 364 | sampler=SubsetRandomSampler( 365 | indices), drop_last=True) 366 | return train_loader, len(indices) 367 | 368 | def get_train_old(self, all_range, model_no): 369 | """ 370 | This method equally splits the dataset. 371 | :param all_range: 372 | :param model_no: 373 | :return: 374 | """ 375 | 376 | data_len = int( 377 | len(self.train_dataset) / self.params.fl_total_participants) 378 | sub_indices = all_range[model_no * data_len: (model_no + 1) * data_len] 379 | train_loader = DataLoader(self.train_dataset, 380 | batch_size=self.params.batch_size, 381 | sampler=SubsetRandomSampler( 382 | sub_indices)) 383 | return train_loader, len(sub_indices) -------------------------------------------------------------------------------- /test-lib/test-log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import colorlog 3 | import os 4 | from logging import FileHandler 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | # Example ground truth labels and predicted outputs 10 | # Assuming batch_size=3 and num_classes=5 11 | target = torch.tensor([2, 0, 4]) # Ground truth labels 12 | output = torch.tensor([[0.1, 0.2, 0.6, 0.1, 0.0], 13 | [0.8, 0.1, 0.0, 0.05, 0.05], 14 | [0.0, 0.0, 0.1, 0.2, 0.7]]) # Predicted scores/logits 15 | 16 | # Using nn.CrossEntropyLoss with reduction='none' 17 | criterion = nn.CrossEntropyLoss(reduction='none') 18 | 19 | # Calculate loss for each example separately 20 | losses = criterion(output, target) 21 | print(losses, losses.mean()) 22 | 23 | 24 | # def create_logger(): 25 | # """ 26 | # Setup the logging environment 27 | # """ 28 | # log = logging.getLogger() # root logger 29 | # log.setLevel(logging.DEBUG) 30 | # format_str = '%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s' 31 | # date_format = '%Y-%m-%d %H:%M:%S' 32 | # if os.isatty(2): 33 | # cformat = '%(log_color)s' + format_str 34 | # colors = {'DEBUG': 'bold_blue', 35 | # 'INFO': 'reset', 36 | # 'WARNING': 'bold_yellow', 37 | # 'ERROR': 'bold_red', 38 | # 'CRITICAL': 'bold_red'} 39 | # formatter = colorlog.ColoredFormatter(cformat, date_format, 40 | # log_colors=colors) 41 | # else: 42 | # formatter = logging.Formatter(format_str, date_format) 43 | # stream_handler = logging.StreamHandler() 44 | # stream_handler.setFormatter(formatter) 45 | # 
log.addHandler(stream_handler) 46 | # return logging.getLogger(__name__) 47 | 48 | # logger = create_logger() 49 | 50 | # # Configure basic logging settings with the custom formatter 51 | # logging.basicConfig(level=logging.DEBUG, 52 | # format='%(log_color)s %(asctime)s - %(filename)s - Line:%(lineno)d - %(name)s - %(levelname)-8s - %(message)s') 53 | # # cformat = '%(log_color)s' + format_str 54 | # # colors = {'DEBUG': 'reset', 55 | # # 'INFO': 'reset', 56 | # # 'WARNING': 'bold_yellow', 57 | # # 'ERROR': 'bold_red', 58 | # # 'CRITICAL': 'bold_red'} 59 | 60 | # # Set the custom formatter 61 | # logger = logging.getLogger() 62 | 63 | # # Create a logger 64 | # logger = logging.getLogger("my_logger") 65 | 66 | # # Create an HTML formatter 67 | # class HTMLFormatter(logging.Formatter): 68 | # def format(self, record): 69 | # level = record.levelname 70 | # message = record.getMessage() 71 | # formatted = f"
<p>{level}: {message}</p>
\n" 72 | # return formatted 73 | 74 | # # Create a FileHandler and set the HTML formatter 75 | # file_handler = FileHandler('logs.html', mode='a', encoding='utf-8') 76 | # html_formatter = HTMLFormatter() 77 | # file_handler.setFormatter(html_formatter) 78 | 79 | # # Add the FileHandler to the logger 80 | # logger.addHandler(file_handler) 81 | 82 | 83 | # Log messages 84 | # logger.debug("This is a debug message") 85 | # logger.info("This is an info message") 86 | # logger.warning("This is a warning message") 87 | # logger.error("This is an error message") 88 | # logger.critical("This is a critical message") 89 | 90 | 91 | # # write logger to html file 92 | # import numpy as np 93 | # from PIL import Image 94 | # from torchvision import transforms 95 | # trans = transforms.Compose([transforms.ToTensor()]) 96 | # img = np.random.randint(0, 255, size=(3, 224, 224), dtype=np.uint8) 97 | 98 | # demo_img = trans(img) 99 | 100 | # demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1) 101 | # print(Image.fromarray(demo_array.astype(np.uint8))) 102 | 103 | # from matplotlib import pyplot as plt 104 | # plt.imshow(img) 105 | # plt.show() -------------------------------------------------------------------------------- /test-lib/test-write-file.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "840 \n", 13 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 14 | "python training.py --name cifar10 --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__2023.Nov.24/cifar10_fed_100_10_4_0.5_0.05.yaml\n", 15 | "------------------------\n", 16 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 17 | "python training.py --name cifar10 --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__2023.Nov.24/cifar10_fed_200_10_4_0.5_0.05.yaml\n", 18 | "------------------------\n", 19 | "796 \n", 20 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 21 | "python training.py --name mnist --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__2023.Nov.24/mnist_fed_100_10_4_0.5_0.05.yaml\n", 22 | "------------------------\n", 23 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 24 | "python training.py --name mnist --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__2023.Nov.24/mnist_fed_200_10_4_0.5_0.05.yaml\n", 25 | "------------------------\n", 26 | "863 \n", 27 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 28 | "python training.py --name tiny-imagenet --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_100_10_4_0.5_0.05.yaml\n", 29 | "------------------------\n", 30 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n", 31 | "python training.py --name tiny-imagenet --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_200_10_4_0.5_0.05.yaml\n", 32 | "------------------------\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import 
os\n", 38 | "from datetime import datetime\n", 39 | "\n", 40 | "\n", 41 | "def generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/cifar_fed.yaml', name_exp = 'cifar10', EXPS_DIR=\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/extras\"):\n", 42 | " # read file as a string\n", 43 | " print(f'reading from root_file: {root_file}')\n", 44 | " with open(root_file, 'r') as file :\n", 45 | " filedata = file.read()\n", 46 | " fl_total_participants_choices = [100, 200]\n", 47 | " fl_no_models_choices = [10]\n", 48 | " fl_dirichlet_alpha_choices = [0.5]\n", 49 | " fl_number_of_adversaries_choices = [4]\n", 50 | " fl_lr_choices = [0.05]\n", 51 | " # EXPS_DIR = '/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/extras'\n", 52 | " \n", 53 | " os.makedirs(EXPS_DIR, exist_ok=True)\n", 54 | "\n", 55 | " for fl_total_participants in fl_total_participants_choices:\n", 56 | " for fl_no_models in fl_no_models_choices:\n", 57 | " for fl_dirichlet_alpha in fl_dirichlet_alpha_choices:\n", 58 | " for fl_number_of_adversaries in fl_number_of_adversaries_choices:\n", 59 | " for fl_lr in fl_lr_choices:\n", 60 | " print(f'fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr}')\n", 61 | " filedata = filedata.replace('fl_total_participants: 100', f'fl_total_participants: {fl_total_participants}')\n", 62 | " filedata = filedata.replace('fl_no_models: 10', f'fl_no_models: {fl_no_models}')\n", 63 | " filedata = filedata.replace('fl_dirichlet_alpha: 0.5', f'fl_dirichlet_alpha: {fl_dirichlet_alpha}')\n", 64 | " filedata = filedata.replace('fl_number_of_adversaries: 4', f'fl_number_of_adversaries: {fl_number_of_adversaries}')\n", 65 | " filedata = filedata.replace('lr: 0.005', f'lr: {fl_lr}')\n", 66 | " # print(len(filedata), type(filedata)) \n", 67 | " # print('------------------------')\n", 68 | " # write the file out again\n", 69 | " fn_write = f'{EXPS_DIR}/{name_exp}_fed_{fl_total_participants}_{fl_no_models}_{fl_number_of_adversaries}_{fl_dirichlet_alpha}_{fl_lr}.yaml'\n", 70 | " if not os.path.exists(fn_write):\n", 71 | " with open(fn_write, 'w') as file:\n", 72 | " file.write(filedata)\n", 73 | " \n", 74 | " cmd = f'python training.py --name {name_exp} --params {fn_write}'\n", 75 | " print(cmd)\n", 76 | " print('------------------------')\n", 77 | "\n", 78 | "current_time = datetime.now().strftime('%Y.%b.%d')\n", 79 | "\n", 80 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/cifar_fed.yaml', \n", 81 | " name_exp = 'cifar10', \n", 82 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__{current_time}\")\n", 83 | "\n", 84 | "\n", 85 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/mnist_fed.yaml', \n", 86 | " name_exp = 'mnist', \n", 87 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__{current_time}\")\n", 88 | "\n", 89 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/imagenet_fed.yaml', \n", 90 | " name_exp = 'tiny-imagenet', \n", 91 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__{current_time}\")\n", 92 | "\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [] 101 | } 102 | ], 103 | "metadata": { 104 | "kernelspec": { 105 | "display_name": "cardio", 106 | "language": 
"python", 107 | "name": "python3" 108 | }, 109 | "language_info": { 110 | "codemirror_mode": { 111 | "name": "ipython", 112 | "version": 3 113 | }, 114 | "file_extension": ".py", 115 | "mimetype": "text/x-python", 116 | "name": "python", 117 | "nbconvert_exporter": "python", 118 | "pygments_lexer": "ipython3", 119 | "version": "3.7.13" 120 | }, 121 | "orig_nbformat": 4 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 2 125 | } 126 | -------------------------------------------------------------------------------- /test-lib/test.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/test-lib/test.log -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from helper import Helper 4 | from datetime import datetime 5 | from tqdm import tqdm 6 | import wandb 7 | 8 | from utils.utils import * 9 | logger = logging.getLogger('logger') 10 | 11 | def train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=False, global_model=None): 12 | criterion = hlpr.task.criterion 13 | model.train() 14 | # for i, data in tqdm(enumerate(train_loader)): 15 | for i, data in enumerate(train_loader): 16 | batch = hlpr.task.get_batch(i, data) 17 | model.zero_grad() 18 | loss = hlpr.attack.compute_blind_loss(model, criterion, batch, attack, global_model) 19 | loss.backward() 20 | optimizer.step() 21 | # import IPython; IPython.embed() 22 | # exit(0) 23 | # print(f"Epoch {epoch} batch {i} loss {loss.item()}") 24 | 25 | if i == hlpr.params.max_batch_id: 26 | break 27 | # metric = hlpr.task.report_metrics(epoch, 28 | # prefix=f'Backdoor {str(backdoor):5s}. Epoch: ') 29 | return 30 | 31 | def test(hlpr: Helper, epoch, backdoor=False, model=None): 32 | if model is None: 33 | model = hlpr.task.model 34 | model.eval() 35 | hlpr.task.reset_metrics() 36 | with torch.no_grad(): 37 | for i, data in tqdm(enumerate(hlpr.task.test_loader)): 38 | batch = hlpr.task.get_batch(i, data) 39 | if backdoor: 40 | batch = hlpr.attack.synthesizer.make_backdoor_batch(batch, 41 | test=True, 42 | attack=True) 43 | 44 | outputs = model(batch.inputs) 45 | hlpr.task.accumulate_metrics(outputs=outputs, labels=batch.labels) 46 | metric = hlpr.task.report_metrics(epoch, 47 | prefix=f'Backdoor {str(backdoor):5s}. 
--------------------------------------------------------------------------------
/utils/parameters.py:
--------------------------------------------------------------------------------
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import List, Dict
import logging
import torch

logger = logging.getLogger('logger')


@dataclass
class Params:

    # Corresponds to the class module: tasks.mnist_task.MNISTTask
    # See other tasks in the tasks folder.
    task: str = 'MNIST'
    num_classes = 10  # TODO: update num_classes for Tiny ImageNet (200 classes)

    current_time: str = None
    name: str = None
    random_seed: int = None
    device: str = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # training params
    start_epoch: int = 1
    epochs: int = None
    poison_epoch: int = None
    poison_epoch_stop: int = None
    log_interval: int = 1000

    # model arch is usually defined by the task
    resume_model: str = None
    lr: float = None
    decay: float = None
    momentum: float = None
    optimizer: str = None

    # data
    data_path: str = '.data/'
    batch_size: int = 64
    test_batch_size: int = 100
    # Whether to apply transformations to the training images.
    transform_train: bool = True
    # For large datasets stop training earlier.
    max_batch_id: int = None
    # No need to set; updated by the Task class.
    input_shape = None

    # attack params
    backdoor: bool = False
    backdoor_label: int = 8
    poisoning_proportion: float = 1.0  # proportion of backdoored inputs in the backdoor loss
    synthesizer: str = 'pattern'
    backdoor_dynamic_position: bool = False

    # factors to balance losses
    fixed_scales: Dict[str, float] = None

    # optimizations:
    alternating_attack: float = None
    clip_batch: float = None
    # Disable BatchNorm and Dropout
    switch_to_eval: float = None

    # logging
    report_train_loss: bool = True
    log: bool = False
    save_model: bool = None
    save_on_epochs: List[int] = None
    save_scale_values: bool = False
    print_memory_consumption: bool = False
    save_timing: bool = False
    timing_data = None

    # Temporary storage for running values
    running_losses = None
    running_scales = None

    # FL params
    fl: bool = False
    fl_no_models: int = 100
    fl_local_epochs: int = 2
    fl_poison_epochs: int = None
    fl_total_participants: int = 200
    fl_eta: int = 1
    fl_sample_dirichlet: bool = False
    fl_dirichlet_alpha: float = None
    fl_diff_privacy: bool = False
    # FL attack details. Set the number of adversaries to a nonzero value to perform the attack:
    fl_number_of_adversaries: int = 0
    fl_single_epoch_attack: int = None
    fl_weight_scale: int = 1

    fl_round_participants: List[int] = None
    fl_weight_contribution: Dict[int, float] = None

    fl_local_updated_models: Dict[int, Dict[str, torch.Tensor]] = None

    # 'ThrDFed' (3DFed), 'ModelRplace' (Model Replacement)
    attack: str = None

    # "Foolsgold", "FLAME", "RFLBAT", "Deepsight", "FLDetector"
    defense: str = None
    lagrange_step: float = None
    random_neurons: List[int] = None
    noise_mask_alpha: float = None
    fl_adv_group_size: int = 0
    fl_num_neurons: int = 0

    def __post_init__(self):
        # Enable logging whenever statistics are being saved.
        if self.save_model or self.save_timing or \
                self.print_memory_consumption:
            self.log = True

        if self.log:
            self.folder_path = f'saved_models/' \
                               f'{self.task}__{self.current_time}__{self.name}'

        self.running_losses = defaultdict(list)
        self.running_scales = defaultdict(list)
        self.timing_data = defaultdict(list)

    def to_dict(self):
        return asdict(self)
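In `training.py` the parsed `.yaml` dict is handed to `Helper`, which ultimately populates this dataclass. A minimal sketch of that round trip, assuming the config keys match the field names (the repo's `Helper` may preprocess the dict first):

```python
# Sketch only: construct Params directly from an experiment config.
import yaml
from utils.parameters import Params

with open('exps/mnist_fed.yaml') as f:
    raw = yaml.load(f, Loader=yaml.FullLoader)

# Keys that are not dataclass fields would raise a TypeError here.
params = Params(**raw)
print(params.fl_total_participants, params.fl_no_models, params.lr)
```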
--------------------------------------------------------------------------------
/utils/process_tiny_data.sh:
--------------------------------------------------------------------------------
unzip tiny-imagenet-200.zip -d ./
if [ ! -d ../.data ]; then
    mkdir ../.data
else
    echo "'../.data' dir already exists"
fi

mv ./tiny-imagenet-200 ../.data/
echo "moved 'tiny-imagenet-200' dir to '../.data/tiny-imagenet-200'"
python tinyimagenet_reformat.py
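The script expects `tiny-imagenet-200.zip` to sit next to it in `utils/`; the archive is commonly fetched from the CS231n mirror (the URL is not pinned anywhere in this repo), e.g. `wget http://cs231n.stanford.edu/tiny-imagenet-200.zip`, after which `bash process_tiny_data.sh` unzips it into `../.data/` and runs the validation-split reformatter shown next.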
--------------------------------------------------------------------------------
/utils/tinyimagenet_reformat.py:
--------------------------------------------------------------------------------
import glob
import os
from shutil import move
from os import rmdir

target_folder = '../.data/tiny-imagenet-200/val/'

# Map each validation image to its class from val_annotations.txt.
val_dict = {}
with open(target_folder + 'val_annotations.txt', 'r') as f:
    for line in f.readlines():
        split_line = line.split('\t')
        val_dict[split_line[0]] = split_line[1]

paths = glob.glob(target_folder + 'images/*')

# Create one sub-folder per class.
for path in paths:
    file = path.split('/')[-1]
    file = file.split('\\')[-1]  # handle Windows path separators
    folder = val_dict[file]
    if not os.path.exists(target_folder + str(folder)):
        os.mkdir(target_folder + str(folder))

# Move each image into its class sub-folder.
for path in paths:
    file = path.split('/')[-1]
    file = file.split('\\')[-1]
    folder = val_dict[file]
    dest = target_folder + str(folder) + '/' + str(file)
    move(path, dest)

os.remove('../.data/tiny-imagenet-200/val/val_annotations.txt')
rmdir('../.data/tiny-imagenet-200/val/images')
print('done reformatting the validation images')
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import logging
import os
import time

import colorlog
import torch

from utils.parameters import Params


def record_time(params: Params, t=None, name=None):
    # Record only when timing is enabled for this name (or globally).
    if t and name and (params.save_timing == name or params.save_timing is True):
        torch.cuda.synchronize()
        params.timing_data[name].append(round(1000 * (time.perf_counter() - t)))


def create_table(params: dict):
    data = f"\n| {'name' + ' ' * 21} | value | \n|{'-' * 27}|----------|"

    for key, value in params.items():
        data += f"\n| {key: <25} | {value} |"

    return data


def create_logger():
    """
    Set up the logging environment.
    """
    log = logging.getLogger()  # root logger
    log.setLevel(logging.DEBUG)
    format_str = '%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    if os.isatty(2):
        # Colorize output when stderr is attached to a terminal.
        cformat = '%(log_color)s' + format_str
        colors = {'DEBUG': 'bold_blue',
                  'INFO': 'bold_green',
                  'WARNING': 'bold_yellow',
                  'ERROR': 'bold_red',
                  'CRITICAL': 'bold_red'}
        formatter = colorlog.ColoredFormatter(cformat, date_format,
                                              log_colors=colors)
    else:
        formatter = logging.Formatter(format_str, date_format)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    log.addHandler(stream_handler)
    return logging.getLogger(__name__)
--------------------------------------------------------------------------------
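A short usage sketch for the helpers above (assumption: a CUDA device is available, since `record_time` calls `torch.cuda.synchronize()`; recording happens only when `params.save_timing` is `True` or equals the given name):

```python
import time
from utils.parameters import Params
from utils.utils import create_logger, create_table, record_time

log = create_logger()
params = Params(task='MNIST', save_timing=True)
log.info(create_table(params.to_dict()))

t0 = time.perf_counter()
# ... forward/backward work on the GPU here ...
record_time(params, t=t0, name='forward')
print(params.timing_data['forward'])   # elapsed milliseconds per call
```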