├── .gitignore
├── Makefile
├── README.md
├── attacks
│   ├── attack.py
│   ├── loss_functions.py
│   └── modelreplace.py
├── defenses
│   └── fedavg.py
├── environment.yml
├── exps
│   ├── cifar_fed.yaml
│   ├── gen-run-yaml-file.py
│   ├── imagenet_fed.yaml
│   ├── list_exps__2023.Nov.25.md
│   └── mnist_fed.yaml
├── helper.py
├── markdown-explanation
│   ├── FedMLSec.md
│   ├── Figure_4_ba_norm.png
│   ├── Figure_full-image-pattern.png
│   ├── code-structures.md
│   └── model-replacement-attack.png
├── metrics
│   ├── accuracy_metric.py
│   ├── metric.py
│   └── test_loss_metric.py
├── models
│   ├── MnistNet.py
│   ├── __init__.py
│   ├── model.py
│   ├── resnet.py
│   ├── resnet_cifar.py
│   ├── resnet_tinyimagenet.py
│   └── simple.py
├── synthesizers
│   ├── attack_models
│   │   ├── autoencoders.py
│   │   └── unet.py
│   ├── complex_synthesizer.py
│   ├── iba_synthesizer.py
│   ├── pattern_synthesizer.py
│   ├── physical_synthesizer.py
│   ├── singlepixel_synthesizer.py
│   └── synthesizer.py
├── tasks
│   ├── batch.py
│   ├── cifar10_task.py
│   ├── fl_user.py
│   ├── imagenet_task.py
│   ├── mnist_task.py
│   └── task.py
├── test-lib
│   ├── test-log.py
│   ├── test-write-file.ipynb
│   └── test.log
├── training.py
└── utils
    ├── parameters.py
    ├── process_tiny_data.sh
    ├── tinyimagenet_reformat.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | FedML/
3 | config/
4 | PFL-Non-IID/
5 | 3DFed/
6 | OOD_Federated_Learning/
7 | RLBackdoorFL/
8 | __pycache__/
9 | *.pyc
10 | saved_models/
11 | .data/
12 | wandb/
13 | *.zip
14 | exps/run*
15 | env-variables.md
16 | synthesizers/checkpoint/
17 | synthesizers/atkmodel/
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .ONESHELL:
2 |
3 | CONDAPATH = $$(conda info --base)
4 | ENV_NAME = aba
5 |
6 | install:
7 | conda env create -f environment.yml
8 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt
9 |
10 | install-mac:
11 | conda env create -f environment.yml
12 | conda install nomkl
13 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt
14 |
15 | update:
16 | # conda env update --prune -f environment.yml
17 | ${CONDAPATH}/envs/$(ENV_NAME)/bin/pip install -r requirements.txt --upgrade
18 |
19 | clean:
20 | conda env remove --name $(ENV_NAME)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Table of contents
3 | - [Table of contents](#table-of-contents)
4 | - [Description](#description)
5 | - [Overview of project](#overview-of-project)
6 | - [How to develop from this project](#how-to-develop-from-this-project)
7 | - [Dataset for Backdoor Attack in FL](#dataset-for-backdoor-attack-in-fl)
8 | - [Durability](#durability)
9 | - [Working with FedML](#working-with-fedml)
10 | - [Survey Papers for Machine Learning Security](#survey-papers-for-machine-learning-security)
11 | - [Paper Backdoor Attack in ML/ FL](#paper-backdoor-attack-in-ml-fl)
12 | - [Code for Backdoor Attack in ML/ FL](#code-for-backdoor-attack-in-ml-fl)
13 | - [Other Resources for Backdoor Attack in ML/ FL](#other-resources-for-backdoor-attack-in-ml-fl)
14 | - [Backdoor Attack code resources in FL](#backdoor-attack-code-resources-in-fl)
15 |
16 |
17 | # Description
18 | This GitHub project provides a quick starting point for readers working on backdoor attacks in machine learning and federated learning.
19 |
20 |
21 |
22 | # Overview of project
23 | - Attack methods: DBA, LIRA, IBA, BackdoorBox, 3DFed, Chameleon, etc.
24 | - Defense methods: Krum, RFA, FoolsGold, etc.
25 |
26 | # How to develop from this project
27 | - Define dataset in `tasks` folder.
28 | - Define your own attack method in `attacks` folder. The default attack method adds a trigger to the training dataset and trains the model on the poisoned dataset.
29 | - Define your own defense method in `defenses` folder. Default defense method is FedAvg.
30 | - Define your own model in `models` folder and config for experiment in `exps` folder.
31 | - All experiments inherit from the 3 base `.yaml` files in `exps` folder, one per dataset: `mnist_fed.yaml` (MNIST), `cifar_fed.yaml` (CIFAR-10), `imagenet_fed.yaml` (Tiny ImageNet).
32 | - The base setting for each experiment has 100 clients, 4 of which are attackers; 10 clients participate in each round, benign clients train for 2 local epochs (`fl_local_epochs`), and attackers train for 5 local epochs (`fl_poison_epochs`).
33 | - The dataset is partitioned across clients using a Dirichlet distribution with $\alpha = 0.5$.
34 | - The pre-trained model is downloaded from [Google Drive (DBA)](https://drive.google.com/file/d/1wcJ_DkviuOLkmr-FgIVSFwnZwyGU8SjH/view).
35 | ---
36 | Here is an example of running the code:
37 | ```
38 | python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.24/cifar10_fed_100_10_4_0.5_0.05.yaml
39 |
40 | python training.py --name mnist --params ./exps/run_mnist__2023.Nov.24/mnist_fed_100_10_4_0.5_0.05.yaml
41 |
42 | python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_100_10_4_0.5_0.05.yaml
43 | ```
44 | In these commands, the `--name` argument is the name of the experiment, and the `--params` argument is the path to the experiment's `.yaml` settings file.
45 |
46 | ---
47 |
48 | For more commands, please refer to file `./exps/run-yaml-cmd.md`. Please note that the `./exps/run-yaml-cmd.md` file is generated by the `./exps/gen-run-yaml-file.py` file. Refer to the `./exps/gen-run-yaml-file.py` to generate different commands for your own experiments.
49 |
50 |
51 | # Dataset for Backdoor Attack in FL
52 |
53 | |Dataset|Case|Description|
54 | | :--- | :--- | :--- |
55 | |MNIST|Case -|The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples.|
56 | |CIFAR-10|Case -|The CIFAR-10 dataset consists of 60,000 32x32 colour images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.|
57 | |CIFAR-100|Case -|The CIFAR-100 dataset consists of 60,000 32x32 colour images in 100 classes, with 600 images per class: 500 training images and 100 test images per class.|
58 | |Tiny ImageNet|Case -|The Tiny ImageNet dataset contains 200 image classes, a training dataset of 100,000 images, a validation dataset of 10,000 images, and a test dataset of 10,000 images (50 validation and 50 test images per class). All images are of size 64×64.|
59 | |EMNIST|Case -|There are 62 classes (10 digits, 26 lowercase, 26 uppercase) and 814,255 samples: 697,932 training samples and 116,323 test samples.|
60 |
61 | [Edge-case backdoors](https://proceedings.neurips.cc/paper/2020/hash/b8ffa41d4e492f0fad2f13e29e1762eb-Abstract.html)
62 |
63 |
69 |
70 |
71 |
72 |
77 |
78 | # Durability
79 |
80 |
81 |
82 | # Working with FedML
83 | - [FedML README.md](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md)
84 |
85 | Types of Attacks in FL setting:
86 | - Byzantine Attack
87 | - DLG Attack (Deep Leakage from Gradients)
88 | - Backdoor Attack
89 | - Model Replacement Attack
90 |
91 | # Survey Papers for Machine Learning Security
92 |
93 | | Title | Year | Venue | Code | Dataset | URL | Note |
94 | | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
95 | |A Survey on Fully Homomorphic Encryption: An Engineering Perspective|2017|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3124441)||
96 | |Generative Adversarial Networks: A Survey Toward Private and Secure Applications|2021|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3459992)||
97 | |A Survey on Adversarial Recommender Systems: From Attack/Defense Strategies to Generative Adversarial Networks|2021|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3439729)||
98 | |Video Generative Adversarial Networks: A Review|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3487891)||
99 | |Taxonomy of Machine Learning Safety: A Survey and Primer|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3551385)||
100 | |Dataset Security for Machine Learning: Data Poisoning, Backdoor Attacks, and Defenses|2022|ACM Computing Surveys|||[link](https://arxiv.org/pdf/2012.10544.pdf)||
101 | |Generative Adversarial Networks: A Survey on Attack and Defense Perspective|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3615336)||
102 | |Trustworthy AI: From Principles to Practices|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3555803)||
103 | |Deep Learning for Android Malware Defenses: A Systematic Literature Review|2022|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3571156)||
104 | |Dataset security for machine learning: Data poisoning, backdoor attacks, and defenses|2022|IEEE TPAMI|||[link](https://arxiv.org/pdf/2012.10544.pdf)||
105 | |A Comprehensive Review of the State-of-the-Art on Security and Privacy Issues in Healthcare|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3571156)||
106 | |A Comprehensive Survey of Privacy-preserving Federated Learning: A Taxonomy, Review, and Future Directions|2023|ACM Computing Surveys|||[link](https://dl.acm.org/doi/pdf/10.1145/3460427)||
107 | |Recent Advances on Federated Learning: A Systematic Survey|2023|arXiv|||[link](https://arxiv.org/pdf/2301.01299.pdf)||
108 | |Federated Learning for Generalization, Robustness, Fairness: A Survey and Benchmark|2023|arXiv|||[link](https://arxiv.org/pdf/2311.06750v1.pdf)||
109 | |Backdoor attacks and defenses in federated learning: Survey, challenges and future research directions|2024|Engineering Applications of Artificial Intelligence|||[link](https://www.sciencedirect.com/science/article/abs/pii/S0952197623013507)||
110 |
111 | ## Paper Backdoor Attack in ML/ FL
112 | - [How to Backdoor Federated Learning](https://arxiv.org/pdf/1807.00459.pdf)
113 | - [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning](https://arxiv.org/pdf/2007.05084v1.pdf)
114 | - [DBA: Distributed Backdoor Attacks against Federated Learning](https://openreview.net/pdf?id=rkgyS0VFvr)
115 |
116 | ## Code for Backdoor Attack in ML/ FL
117 |
118 | | Title | Year | Venue | Code | Dataset | URL | Note |
119 | | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
120 | |Practicing-Federated-Learning||Github|[link](https://github.com/FederatedAI/Practicing-Federated-Learning/)|
121 | |Attack of the Tails: Yes, You Really Can Backdoor Federated Learning||NeurIPS'20|[link](https://github.com/ksreenivasan/OOD_Federated_Learning)|
122 | |DBA: Distributed Backdoor Attacks against Federated Learning||ICLR'20|[link](https://github.com/AI-secure/DBA)|
123 | |LIRA: Learnable, Imperceptible and Robust Backdoor Attacks ||ICCV'21|[link](https://github.com/khoadoan106/backdoor_attacks/tree/main)|
124 | |Backdoors Framework for Deep Learning and Federated Learning||AISTAT'20, USENIX'21|[link](https://github.com/ebagdasa/backdoors101)|
125 | |BackdoorBox: An Open-sourced Python Toolbox for Backdoor Attacks and Defenses|2023|Github|[link](https://github.com/THUYimingLi/BackdoorBox)|
126 | |3DFed: Adaptive and Extensible Framework for Covert Backdoor Attack in Federated Learning||IEEE S&P'23|[link](https://github.com/haoyangliASTAPLE/3DFed)|
127 | |Neurotoxin: Durable Backdoors in Federated Learning||ICML'22|[link](https://proceedings.mlr.press/v162/zhang22w/zhang22w.pdf)|||Durability
128 | |Chameleon: Adapting to Peer Images for Planting Durable Backdoors in Federated Learning||ICML'23|[link](https://github.com/ybdai7/Chameleon-durable-backdoor)|||Durability
129 | |PerDoor: Persistent Backdoors in Federated Learning using Adversarial Perturbations||COINS'23|[link](https://ieeexplore.ieee.org/abstract/document/10189281)|
130 | ## Other Resources for Backdoor Attack in ML/ FL
131 | - [List of papers on data poisoning and backdoor attacks](https://github.com/penghui-yang/awesome-data-poisoning-and-backdoor-attacks)
132 | - [Proceedings of Machine Learning Research](https://proceedings.mlr.press/)
133 | - [Backdoor learning resources](https://github.com/THUYimingLi/backdoor-learning-resources)
134 | - Google Scholar:
135 | - [Avi Schwarzschild, Post Doc at CMU](https://scholar.google.com/citations?user=WNvQ7AcAAAAJ&hl=en)
136 | - [Yiming Li, Research Professor at Zhejiang University](https://scholar.google.com/citations?user=mSW7kU8AAAAJ&hl=zh-CN)
137 | - [Nicolas Papernot, Assistant Professor at University of Toronto](https://scholar.google.com/citations?user=cGxq0cMAAAAJ&hl=en)
138 |
139 | # Backdoor Attack code resources in FL
140 | In the FL community, there are many code resources for backdoor attacks, each with its own FL scenario (e.g., hyperparameters, dataset, attack methods, defense methods, etc.).
141 | Thus, we provide a list of popular code resources for backdoor attacks in FL:
142 | - [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning - NeurIPS'20](https://github.com/ksreenivasan/OOD_Federated_Learning)
143 | - [DBA: Distributed Backdoor Attacks against Federated Learning ICLR'20](https://github.com/AI-secure/DBA)
144 | - [How To Backdoor Federated Learning - AISTATS'20](https://github.com/ebagdasa/backdoors101)
145 | - [Just How Toxic is Data Poisoning? A Unified Benchmark for Backdoor and Data Poisoning Attacks - ICML'21](https://github.com/aks2203/poisoning-benchmark)
146 | - [Learning to Backdoor Federated Learning - ICLR'23 Workshop](https://github.com/HengerLi/RLBackdoorFL/tree/main)
147 | - [Chameleon: Adapting to Peer Images for Planting Durable Backdoors in Federated Learning - ICML'23](https://github.com/ybdai7/Chameleon-durable-backdoor)
148 | - [FedMLSecurity: A Benchmark for Attacks and Defenses in Federated Learning and Federated LLMs - arXiv'23](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md)
149 | - [IBA: Towards Irreversible Backdoor Attacks in Federated Learning - NeurIPS'23](https://github.com/sail-research/iba)
150 |
151 |
--------------------------------------------------------------------------------
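The README above says the dataset is divided by a Dirichlet distribution with α = 0.5 across 100 clients. The block below is a minimal sketch of one common way to do such a split (per-class Dirichlet proportions). It assumes NumPy and toy labels; `dirichlet_partition` is a hypothetical helper for illustration, not a function from this repository, and the actual task code may differ.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients, class by class,
    using Dirichlet(alpha) proportions (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Share of this class that each client receives.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Convert the shares into cut points inside idx.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices

# Toy data: 10 classes with 100 samples each (an assumption for the demo).
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_clients=100, alpha=0.5)
print(sorted(len(p) for p in parts)[-5:])  # sizes of the five largest shards
```

Smaller values of `alpha` concentrate each class on fewer clients, which makes the split more non-IID; larger values approach a uniform split.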
/attacks/attack.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, List
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from synthesizers.synthesizer import Synthesizer
7 | from attacks.loss_functions import compute_all_losses_and_grads
8 | from utils.parameters import Params
9 | import math
10 | logger = logging.getLogger('logger')
11 |
12 |
13 | class Attack:
14 | params: Params
15 | synthesizer: Synthesizer
16 | local_dataset: DataLoader
17 | loss_tasks: List[str]
18 | fixed_scales: Dict[str, float]
19 | ignored_weights = ['num_batches_tracked']#['tracked', 'running']
20 |
21 | def __init__(self, params, synthesizer):
22 | self.params = params
23 | self.synthesizer = synthesizer
24 | self.loss_tasks = ['normal', 'backdoor']
25 | self.fixed_scales = {'normal':0.5, 'backdoor':0.5}
26 |
27 | def perform_attack(self, _) -> None:
28 | raise NotImplementedError
29 |
30 | def compute_blind_loss(self, model, criterion, batch, attack, fixed_model=None):
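31 | """
32 | Compute the weighted sum of the task losses for one batch.
33 | :param model: model being trained
34 | :param criterion: loss function applied to the model outputs
35 | :param batch: training batch (clipped with `params.clip_batch` first)
36 | :param attack: if False, only the 'normal' loss is computed
37 | :return: scalar loss, scaled by `fixed_scales` when attacking
38 | """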
39 | batch = batch.clip(self.params.clip_batch)
40 | loss_tasks = self.loss_tasks.copy() if attack else ['normal']
41 | batch_back = self.synthesizer.make_backdoor_batch(batch, attack=attack)
42 | scale = dict()
43 |
44 | if len(loss_tasks) == 1:
45 | loss_values = compute_all_losses_and_grads(
46 | loss_tasks,
47 | self, model, criterion, batch, batch_back
48 | )
49 | else:
50 | loss_values = compute_all_losses_and_grads(
51 | loss_tasks,
52 | self, model, criterion, batch, batch_back,
53 | fixed_model = fixed_model)
54 |
55 | for t in loss_tasks:
56 | scale[t] = self.fixed_scales[t]
57 |
58 | if len(loss_tasks) == 1:
59 | scale = {loss_tasks[0]: 1.0}
60 | blind_loss = self.scale_losses(loss_tasks, loss_values, scale)
61 |
62 | return blind_loss
63 |
64 | def scale_losses(self, loss_tasks, loss_values, scale):
65 | blind_loss = 0
66 | # import IPython; IPython.embed()
67 | # exit(0)
68 | for it, t in enumerate(loss_tasks):
69 | self.params.running_losses[t].append(loss_values[t].item())
70 | self.params.running_scales[t].append(scale[t])
71 | if it == 0:
72 | blind_loss = scale[t] * loss_values[t]
73 | else:
74 | blind_loss += scale[t] * loss_values[t]
75 | self.params.running_losses['total'].append(blind_loss.item())
76 | return blind_loss
77 |
78 | def scale_update(self, local_update: Dict[str, torch.Tensor], gamma):
79 | for name, value in local_update.items():
80 | value.mul_(gamma)
81 |
82 | def get_fl_update(self, local_model, global_model) -> Dict[str, torch.Tensor]:
83 | local_update = dict()
84 | for name, data in local_model.state_dict().items():
85 | if self.check_ignored_weights(name):
86 | continue
87 | local_update[name] = (data - global_model.state_dict()[name])
88 |
89 | return local_update
90 |
91 | def check_ignored_weights(self, name) -> bool:
92 | for ignored in self.ignored_weights:
93 | if ignored in name:
94 | return True
95 |
96 | return False
97 |
98 | def get_update_norm(self, local_update):
99 | squared_sum = 0
100 | for name, value in local_update.items():
101 | if 'tracked' in name or 'running' in name:
102 | continue
103 | squared_sum += torch.sum(torch.pow(value, 2)).item()
104 | update_norm = math.sqrt(squared_sum)
105 | return update_norm
--------------------------------------------------------------------------------
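The loss scaling and update helpers in `attacks/attack.py` can be illustrated on toy values. This is a sketch only: NumPy arrays stand in for `torch` tensors, plain floats stand in for loss tensors, and the function names `blind_loss`, `fl_update`, and `update_norm` are hypothetical stand-ins for `compute_blind_loss`/`scale_losses`, `get_fl_update`, and `get_update_norm`.

```python
import math
import numpy as np

FIXED_SCALES = {"normal": 0.5, "backdoor": 0.5}  # values set in Attack.__init__

def blind_loss(loss_values, attack):
    """Weighted sum of per-task losses; a benign batch uses only 'normal'."""
    tasks = ["normal", "backdoor"] if attack else ["normal"]
    if len(tasks) == 1:
        scale = {tasks[0]: 1.0}
    else:
        scale = {t: FIXED_SCALES[t] for t in tasks}
    return sum(scale[t] * loss_values[t] for t in tasks)

def fl_update(local_state, global_state, ignored=("num_batches_tracked",)):
    """Per-layer difference local - global, skipping ignored entries."""
    return {name: local_state[name] - global_state[name]
            for name in local_state
            if not any(key in name for key in ignored)}

def update_norm(update):
    """L2 norm over all layers, skipping BatchNorm statistics."""
    squared = sum(float(np.sum(value ** 2)) for name, value in update.items()
                  if "tracked" not in name and "running" not in name)
    return math.sqrt(squared)

# Toy state dicts (assumed layer names, for illustration only).
global_state = {"fc.weight": np.zeros((2, 2)),
                "bn.num_batches_tracked": np.array(10)}
local_state = {"fc.weight": np.array([[3.0, 0.0], [0.0, 4.0]]),
               "bn.num_batches_tracked": np.array(12)}
update = fl_update(local_state, global_state)
print(blind_loss({"normal": 2.0, "backdoor": 4.0}, attack=True))
print(update_norm(update))
```

With a single task the scale is forced to 1.0, so a benign client optimizes the plain task loss, while an attacker averages the clean and backdoor losses with equal weight.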
/attacks/loss_functions.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | from torch.nn import functional as F, Module
5 |
6 | from models.model import Model
7 | from utils.parameters import Params
8 | from utils.utils import record_time
9 |
10 | def compute_all_losses_and_grads(loss_tasks, attack, model, criterion,
11 | batch, batch_back,
12 | fixed_model=None):
13 | loss_values = {}
14 | for t in loss_tasks:
15 | if t == 'normal':
16 | loss_values[t] = compute_normal_loss(attack.params,
17 | model,
18 | criterion,
19 | batch.inputs,
20 | batch.labels)
21 | elif t == 'backdoor':
22 | loss_values[t] = compute_backdoor_loss(attack.params,
23 | model,
24 | criterion,
25 | batch_back.inputs,
26 | batch_back.labels)
27 | elif t == 'eu_constraint':
28 | loss_values[t] = compute_euclidean_loss(attack.params,
29 | model,
30 | fixed_model)
31 | elif t == 'cs_constraint':
32 | loss_values[t] = compute_cos_sim_loss(attack.params,
33 | model,
34 | fixed_model)
35 |
36 | return loss_values
37 |
38 |
39 | def compute_normal_loss(params: Params, model, criterion, inputs, labels):
40 | t = time.perf_counter()
41 | outputs = model(inputs)
42 | record_time(params, t, 'forward')
43 | loss = criterion(outputs, labels)
44 | loss = loss.mean()
45 |
46 | return loss
47 |
48 | def compute_backdoor_loss(params, model, criterion, inputs_back, labels_back):
49 | t = time.perf_counter()
50 | outputs = model(inputs_back)
51 | record_time(params, t, 'forward')
52 | loss = criterion(outputs, labels_back)
53 | loss = loss.mean()
54 |
55 | return loss
56 |
57 | def compute_euclidean_loss(params: Params,
58 | model: Model,
59 | fixed_model: Model):
60 | size = 0
61 | for name, layer in model.named_parameters():
62 | size += layer.view(-1).shape[0]
63 | sum_var = torch.zeros(size, device=params.device)
64 | size = 0
65 | for name, layer in model.named_parameters():
66 | sum_var[size:size + layer.view(-1).shape[0]] = (layer - \
67 | fixed_model.state_dict()[name]).view(-1)
68 | size += layer.view(-1).shape[0]
69 |
70 | loss = torch.norm(sum_var, p=2)
71 |
72 | return loss
73 |
74 | def get_one_vec(model: Module):
75 | size = 0
76 | for name, layer in model.named_parameters():
77 | size += layer.view(-1).shape[0]
78 | sum_var = torch.zeros(size, device=next(model.parameters()).device)
79 | size = 0
80 | for name, layer in model.named_parameters():
81 | sum_var[size:size + layer.view(-1).shape[0]] = (layer.data).view(-1)
82 | size += layer.view(-1).shape[0]
83 |
84 | return sum_var
85 |
86 | def compute_cos_sim_loss(params: Params,
87 | model: Model,
88 | fixed_model: Model):
89 | model_vec = get_one_vec(model)
90 | target_var = get_one_vec(fixed_model)
91 | cs_sim = F.cosine_similarity(params.fl_weight_scale*(model_vec-target_var)\
92 | + target_var, target_var, dim=0)
93 | loss = 1e3 * (1 - cs_sim)
94 | return loss
95 |
96 | def compute_noise_ups_loss(params: Params,
97 | backdoor_update,
98 | noise_masks,
99 | random_neurons):
100 | losses = []
101 | for i in range(len(noise_masks)):
102 | UPs = []
103 | for j in random_neurons:
104 | if 'MNIST' not in params.task:
105 | UPs.append(torch.abs(backdoor_update['fc.weight'][j] + \
106 | noise_masks[i].fc.weight[j]).sum() \
107 | + torch.abs(backdoor_update['fc.bias'][j] + \
108 | noise_masks[i].fc.bias[j]))
109 | else:
110 | UPs.append(torch.abs(backdoor_update['fc2.weight'][j] + \
111 | noise_masks[i].fc2.weight[j]).sum() \
112 | + torch.abs(backdoor_update['fc2.bias'][j] + \
113 | noise_masks[i].fc2.bias[j]))
114 | UPs_loss = 0
115 | for j in range(len(UPs)):
116 | if 'Imagenet' in params.task:
117 | UPs_loss += 5e-4 / UPs[j]
118 | else:
119 | UPs_loss += 1e-1 / UPs[j] # (UPs[j] * params.fl_num_neurons)
120 | noise_masks[i].requires_grad_(True)
121 | UPs_loss.requires_grad_(True)
122 | losses.append(UPs_loss)
123 | return losses
124 |
125 | def compute_noise_norm_loss(params: Params,
126 | noise_masks,
127 | random_neurons):
128 | size = 0
129 | layer_name = 'fc2' if 'MNIST' in params.task else 'fc'
130 | for name, layer in noise_masks[0].named_parameters():
131 | if layer_name in name:
132 | size += layer.view(-1).shape[0]
133 | losses = []
134 | for i in range(len(noise_masks)):
135 | sum_var = torch.zeros(size, device=params.device)
136 | noise_size = 0
137 | for name, layer in noise_masks[i].named_parameters():
138 | if layer_name in name:
139 | for j in range(layer.shape[0]):
140 | if j in random_neurons:
141 | sum_var[noise_size:noise_size + layer[j].view(-1).shape[0]] = \
142 | layer[j].view(-1)
143 | noise_size += layer[j].view(-1).shape[0]
144 | if 'MNIST' in params.task:
145 | loss = 8e-2 * torch.norm(sum_var, p=2)
146 | else:
147 | loss = 3e-2 * torch.norm(sum_var, p=2)
148 | losses.append(loss)
149 | return losses
150 |
151 | def compute_lagrange_loss(params: Params,
152 | noise_masks,
153 | random_neurons):
154 | losses = []
155 | size = 0
156 | layer_name = 'fc2' if 'MNIST' in params.task else 'fc'
157 | for name, layer in noise_masks[0].named_parameters():
158 | if layer_name in name:
159 | size += layer.view(-1).shape[0]
160 | sum_var = torch.zeros(size, device=params.device)
161 | for i in range(len(noise_masks)):
162 | size = 0
163 | for name, layer in noise_masks[i].named_parameters():
164 | if layer_name in name:
165 | for j in range(layer.shape[0]):
166 | if j in random_neurons:
167 | sum_var[size:size + layer[j].view(-1).shape[0]] += \
168 | layer[j].view(-1)
169 | size += layer[j].view(-1).shape[0]
170 |
171 | if 'MNIST' in params.task:
172 | loss = 1e-1 * torch.norm(sum_var, p=2)
173 | else:
174 | loss = 1e-2 * torch.norm(sum_var, p=2)
175 | for i in range(len(noise_masks)):
176 | losses.append(loss)
177 | return losses
178 |
179 | def compute_decoy_acc_loss(params: Params,
180 | benign_model: Model,
181 | decoy: Model,
182 | criterion, inputs, labels):
183 | dec_acc_loss = compute_normal_loss(params, decoy, criterion, \
184 | inputs, labels)
185 | benign_acc_loss = compute_normal_loss(params, benign_model, criterion, \
186 | inputs, labels)
187 | if dec_acc_loss > benign_acc_loss:
188 | loss = dec_acc_loss
189 | else:
190 | loss = - 1e-10 * (dec_acc_loss)
191 |
192 | return loss
193 |
194 | def compute_decoy_param_loss(params:Params,
195 | decoy: Model,
196 | benign_model: Model,
197 | param_idx):
198 |
199 | if 'MNIST' not in params.task:
200 | param_diff = torch.abs(decoy.fc.weight[param_idx[0]][param_idx[1]] - \
201 | benign_model.state_dict()['fc.weight'][param_idx[0]][param_idx[1]])
202 | else:
203 | param_diff = torch.abs(decoy.fc1.weight[param_idx[0]][param_idx[1]] - \
204 | benign_model.state_dict()['fc1.weight'][param_idx[0]][param_idx[1]])
205 |
206 | threshold = 10 # 30
207 | if(param_diff.item() > threshold):
208 | loss = 1e-10 * param_diff
209 | else:
210 | loss = - 1e1 * param_diff
211 |
212 | return loss
213 |
214 | def get_grads(params, model, loss):
215 | t = time.perf_counter()
216 | grads = list(torch.autograd.grad(loss.mean(),
217 | [x for x in model.parameters() if
218 | x.requires_grad],
219 | retain_graph=True, allow_unused=True))
220 | record_time(params, t, 'backward')
221 |
222 | return grads
--------------------------------------------------------------------------------
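The `cs_constraint` term in `compute_cos_sim_loss` above penalizes the angle between the scaled local model and the fixed (global) model. A minimal sketch of the same formula on flat vectors, assuming NumPy in place of `torch`; `cos_sim_loss` is a hypothetical name, and unlike `F.cosine_similarity` this version applies no epsilon clamp to the norms:

```python
import numpy as np

def cos_sim_loss(model_vec, global_vec, weight_scale):
    """1e3 * (1 - cos(scaled, global)),
    with scaled = weight_scale * (model - global) + global."""
    scaled = weight_scale * (model_vec - global_vec) + global_vec
    cos = float(np.dot(scaled, global_vec)
                / (np.linalg.norm(scaled) * np.linalg.norm(global_vec)))
    return 1e3 * (1.0 - cos)

# Toy vectors (assumed values for illustration).
global_vec = np.array([1.0, 0.0])
model_vec = np.array([1.0, 1.0])
for scale in (1.0, 10.0):
    print(scale, cos_sim_loss(model_vec, global_vec, scale))
```

Because the deviation from the global model is multiplied by `weight_scale` before the similarity is taken, the same local model incurs a larger penalty at a larger scale, which pushes the attacker's weights to stay close in direction to the global model.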
/attacks/modelreplace.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # from attack import Attack
3 | from attacks.attack import Attack
4 | from tasks.fl_user import FLUser
5 | import logging
6 | logger = logging.getLogger('logger')
7 |
8 | class ModelReplace(Attack):
9 |
10 | def __init__(self, params, synthesizer):
11 | super().__init__(params, synthesizer)
12 | self.loss_tasks.append('cs_constraint')
13 | self.fixed_scales = {'normal':0.3,
14 | 'backdoor':0.3,
15 | 'cs_constraint':0.4}
16 |
17 | def perform_attack(self, _, user: FLUser, epoch):
18 | if self.params.fl_number_of_adversaries <= 0 or \
19 | epoch not in range(self.params.poison_epoch,\
20 | self.params.poison_epoch_stop):
21 | return
22 |
23 | folder_name = f'{self.params.folder_path}/saved_updates'
24 | file_name = f'{folder_name}/update_{user.user_id}.pth'
25 | loaded_params = torch.load(file_name)
26 | logger.info(f"Loaded update user id: {user.user_id} from {file_name} with scale {self.params.fl_weight_scale}")
27 | self.scale_update(loaded_params, self.params.fl_weight_scale)
28 | torch.save(loaded_params, file_name)
29 |
30 | # for i in range(self.params.fl_number_of_adversaries):
31 | # file_name = f'{folder_name}/update_{i}.pth'
32 | # torch.save(loaded_params, file_name)
--------------------------------------------------------------------------------
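`ModelReplace.perform_attack` multiplies the attacker's saved update by `fl_weight_scale`. The sketch below shows why scaling matters, following the model replacement idea from "How to Backdoor Federated Learning" (listed in the README). It assumes a plain averaging rule G' = G + (eta / n) * sum(updates), idealized benign updates of zero, and NumPy arrays in place of tensors; `fedavg_step` is a hypothetical helper, and the repository's aggregation uses per-user weight contributions that may differ.

```python
import numpy as np

def fedavg_step(global_w, updates, eta):
    """Assumed rule: G' = G + (eta / n) * sum of participant updates."""
    return global_w + (eta / len(updates)) * np.sum(updates, axis=0)

n, eta = 10, 1.0                             # participants per round, global learning rate
global_w = np.array([0.5, -1.0, 2.0])
backdoored_w = np.array([1.5, 0.0, -2.0])    # model the attacker wants to install

# Idealized benign clients that have converged (zero update).
benign_updates = [np.zeros(3) for _ in range(n - 1)]

gamma = n / eta                              # plays the role of fl_weight_scale
attacker_update = gamma * (backdoored_w - global_w)

new_global = fedavg_step(global_w, benign_updates + [attacker_update], eta)
print(new_global)
```

Under these assumptions a scale of `n / eta` makes the aggregated model equal the attacker's model. With `fl_weight_scale: 1`, as in `exps/cifar_fed.yaml`, the update is left unscaled and is averaged like any other.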
/defenses/fedavg.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Any, Dict
3 | import torch
4 | import logging
5 | import os
6 | from utils.parameters import Params
7 |
8 | logger = logging.getLogger('logger')
9 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
10 |
11 | class FedAvg:
12 | params: Params
13 | ignored_weights = ['num_batches_tracked']#['tracked', 'running']
14 |
15 | def __init__(self, params: Params) -> None:
16 | self.params = params
17 |
18 | # FedAvg aggregation
19 | def aggr(self, weight_accumulator, _):
20 | # logger.info(f"Aggregating {len(self.params.fl_weight_contribution.keys())} participants")
21 |
22 | # weight_contrib = self.params.fl_weight_contribution
23 |
24 | for user_id, weight_contrib_user in self.params.fl_weight_contribution.items():
25 | # logger.info(f"Aggregating participant: {user_id} with weight: {weight_contrib_user}")
26 | loaded_params = self.params.fl_local_updated_models[user_id]
27 | self.accumulate_weights(weight_accumulator, \
28 | {key:(loaded_params[key] * weight_contrib_user ).to(self.params.device) for \
29 | key in loaded_params})
30 |
31 | # for idRound, userID in enumerate(self.params.fl_round_participants):
32 | # weight_contrib_user = self.params.fl_weight_contribution[userID]
33 | # # updates_name = '{0}/saved_updates/update_{1}.pth'\
34 | # # .format(self.params.folder_path, userID)
35 | # # # logger.info(f"Aggregating participant {userID} path: {updates_name}")
36 |
37 | # # loaded_params = torch.load(updates_name)
38 | # loaded_params = self.params.fl_local_updated_models[userID]
39 | # self.accumulate_weights(weight_accumulator, \
40 | # {key:(loaded_params[key] * weight_contrib_user ).to(self.params.device) for \
41 | # key in loaded_params})
42 |
43 | def accumulate_weights(self, weight_accumulator, local_update):
44 | for name, value in local_update.items():
45 | weight_accumulator[name].add_(value)
46 |
47 | def get_update_norm(self, local_update):
48 | squared_sum = 0
49 | for name, value in local_update.items():
50 | if 'tracked' in name or 'running' in name:
51 | continue
52 | squared_sum += torch.sum(torch.pow(value, 2)).item()
53 | update_norm = math.sqrt(squared_sum)
54 | return update_norm
55 |
56 | def add_noise(self, sum_update_tensor: torch.Tensor, sigma):
57 | noised_layer = torch.FloatTensor(sum_update_tensor.shape)
58 | noised_layer = noised_layer.to(self.params.device)
59 | noised_layer.normal_(mean=0, std=sigma)
60 | sum_update_tensor.add_(noised_layer)
61 |
62 | def check_ignored_weights(self, name) -> bool:
63 | for ignored in self.ignored_weights:
64 | if ignored in name:
65 | return True
66 |
67 | return False
--------------------------------------------------------------------------------
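The aggregation in `FedAvg.aggr` adds each participant's update, multiplied by its weight contribution, into a shared accumulator. A minimal sketch of that accumulation and of the Gaussian noise step, assuming NumPy arrays in place of tensors; the names `aggregate` and `add_gaussian_noise`, the layer name, and the contribution values are illustrative assumptions:

```python
import numpy as np

def aggregate(accumulator, local_updates, contributions):
    """Add contribution-weighted updates into accumulator in place."""
    for user_id, weight in contributions.items():
        for name, value in local_updates[user_id].items():
            accumulator[name] += weight * value
    return accumulator

def add_gaussian_noise(tensor, sigma, rng):
    """Add zero-mean Gaussian noise with standard deviation sigma in place."""
    tensor += rng.normal(loc=0.0, scale=sigma, size=tensor.shape)

# Two toy participants with equal contributions (assumed values).
local_updates = {
    0: {"fc.weight": np.array([1.0, 2.0])},
    1: {"fc.weight": np.array([3.0, 6.0])},
}
contributions = {0: 0.5, 1: 0.5}
accumulator = {"fc.weight": np.zeros(2)}

aggregate(accumulator, local_updates, contributions)
print(accumulator["fc.weight"])
```

Keeping the weighting in `contributions` rather than inside the updates means the same accumulator logic can serve both uniform averaging and sample-count weighting.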
/environment.yml:
--------------------------------------------------------------------------------
1 | name: aba # Name of the Conda environment
2 |
3 | channels:
4 | - defaults # You can add other channels if needed
5 |
6 | dependencies:
7 | - python=3.8 # Python version
8 | - numpy=1.21.2 # Example package with version
9 | - pandas=1.3.3 # Another package with a specific version
10 | - pytorch=1.9.0 # PyTorch version
11 | - torchvision=0.10.0 # Torchvision version
12 | - pip # Install pip
13 | - pip:
14 | # - torchattacks==3.1.0 # Install torchattacks
15 | # - torchsummary==1.5.1 # Install torchsummary
16 | # - torchmetrics==0.5.1 # Install torchmetrics
17 | # - torch_optimizer==0.1.0 # Install torch_optimizer
18 | - tqdm==4.62.3 # Install tqdm
19 | - matplotlib==3.4.3 # Install matplotlib
20 | - seaborn==0.11.2 # Install seaborn
21 | # - scikit-learn==0.24.2 # Install scikit-learn
22 | # - scikit-image==0.18.3 # Install scikit-image
23 | # - opencv-python==
--------------------------------------------------------------------------------
/exps/cifar_fed.yaml:
--------------------------------------------------------------------------------
1 | task: Cifar10
2 | synthesizer: IBA
3 |
4 | random_seed: 42
5 | batch_size: 256
6 | test_batch_size: 512
7 | lr: 0.005
8 | momentum: 0.9
9 | decay: 0.0005
10 | epochs: 1500
11 | poison_epoch: 0
12 | poison_epoch_stop: 1500
13 | save_on_epochs: [] # [30, 50, 80, 100, 120, 150, 170, 200]
14 | optimizer: SGD
15 | log_interval: 100
16 |
17 | poisoning_proportion: 0.5
18 | backdoor_label: 8
19 |
20 | resume_model: saved_models/cifar_pretrain/model_last.pt.tar.epoch_200
21 |
22 | save_model: True
23 | log: True
24 |
25 | transform_train: True
26 |
27 | fl: True
28 | fl_total_participants: 100
29 | fl_no_models: 10
30 | fl_local_epochs: 2
31 | fl_poison_epochs: 5
32 | # fl_poison_epochs: 15
33 | fl_eta: 1 # 0.8`3
34 | fl_sample_dirichlet: True
35 | fl_dirichlet_alpha: 0.5
36 |
37 | fl_number_of_adversaries: 4
38 | fl_weight_scale: 1
39 | fl_adv_group_size: 2
40 | # fl_single_epoch_attack: 200
41 |
42 | attack: ModelReplace
43 | defense: 'FedAvg'
44 | fl_num_neurons: 5
45 | noise_mask_alpha: 0 # 0.5
46 | lagrange_step: 0.1
--------------------------------------------------------------------------------
/exps/gen-run-yaml-file.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | str_write_to_file = ""
5 |
6 | def generate_exps_file(root_file='./exps/cifar_fed.yaml', name_exp = 'cifar10', EXPS_DIR="./exps/extras"):
7 | # read file as a string
8 | global str_write_to_file
9 | print(f'reading from root_file: {root_file}')
10 | str_write_to_file += f'# Dataset: {name_exp} \n**Reading default config from root_file: {root_file}**\n\n'
11 | str_write_to_file += '------------------------\n\n'
12 |
13 | fl_total_participants_choices = [100, 200]
14 | fl_no_models_choices = [10, 20]
15 | fl_dirichlet_alpha_choices = [0.5]
16 | fl_number_of_adversaries_choices = [4]
17 | fl_lr_choices = [0.05]
18 | resume_model_choices = [False, True]
19 | lis_resume_model = [
20 | 'resume_model: saved_models/tiny_64_pretrain/tiny-resnet.epoch_20',
21 | 'resume_model: saved_models/mnist_pretrain/model_last.pt.tar.epoch_10',
22 | 'resume_model: saved_models/cifar_pretrain/model_last.pt.tar.epoch_200',
23 | ]
24 |
25 | # EXPS_DIR = './exps/extras'
26 |
27 | os.makedirs(EXPS_DIR, exist_ok=True)
28 | exp_number = 0
29 |
30 | for fl_total_participants in fl_total_participants_choices:
31 | for fl_no_models in fl_no_models_choices:
32 | for fl_dirichlet_alpha in fl_dirichlet_alpha_choices:
33 | for fl_number_of_adversaries in fl_number_of_adversaries_choices:
34 | for fl_lr in fl_lr_choices:
35 | for resume_model in resume_model_choices:
36 |
37 | with open(root_file, 'r') as file :
38 | filedata = file.read()
39 |
40 | exp_number += 1
41 | str_write_to_file += f'## EXP ID: {exp_number:02d}\n'
42 | pretrained_str = 'pretrained' if resume_model else 'no_pretrained'
43 | print(f'`fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr} resume_model: {resume_model}`\n\n')
44 | str_write_to_file += f'fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr} resume_model: {resume_model}\n\n'
45 | filedata = filedata.replace('fl_total_participants: 100', f'fl_total_participants: {fl_total_participants}')
46 | filedata = filedata.replace('fl_no_models: 10', f'fl_no_models: {fl_no_models}')
47 | filedata = filedata.replace('fl_dirichlet_alpha: 0.5', f'fl_dirichlet_alpha: {fl_dirichlet_alpha}')
48 | filedata = filedata.replace('fl_number_of_adversaries: 4', f'fl_number_of_adversaries: {fl_number_of_adversaries}')
49 | filedata = filedata.replace('lr: 0.005', f'lr: {fl_lr}')
50 | # print(filedata)
51 |
52 | if not resume_model:
53 | for resume_model_path in lis_resume_model:
54 | if resume_model_path in filedata:
55 | filedata = filedata.replace(resume_model_path, f'resume_model: \n')
56 | # print?(filedata)
57 | # exit(0)
58 | # print(len(filedata), type(filedata))
59 | # print('------------------------')
60 | # write the file out again
61 |
62 | fn_write = f'{EXPS_DIR}/{name_exp}_fed_{fl_total_participants}_{fl_no_models}_{fl_number_of_adversaries}_{fl_dirichlet_alpha}_{fl_lr}_{pretrained_str}.yaml'
63 |
64 | if not os.path.exists(fn_write):
65 | with open(fn_write, 'w') as file:
66 | file.write(filedata)
67 |
68 | cmd = f'```bash\nCUDA_VISIBLE_DEVICES=0 python training.py --name {name_exp} --params {fn_write}\n```\n'
69 | print(cmd)
70 | str_write_to_file += f'{cmd}\n'
71 | print('------------------------')
72 | # if exp_number == 2:
73 | # exit(0)
74 |
75 |
76 | current_time = datetime.now().strftime('%Y.%b.%d')
77 |
78 | generate_exps_file(root_file='./exps/cifar_fed.yaml',
79 | name_exp = 'cifar10',
80 | EXPS_DIR=f"./exps/run_cifar10__{current_time}")
81 |
82 |
83 | generate_exps_file(root_file='./exps/mnist_fed.yaml',
84 | name_exp = 'mnist',
85 | EXPS_DIR=f"./exps/run_mnist__{current_time}")
86 |
87 | generate_exps_file(root_file='./exps/imagenet_fed.yaml',
88 | name_exp = 'tiny-imagenet',
89 | EXPS_DIR=f"./exps/run_tiny-imagenet__{current_time}")
90 |
91 | with open(f'./exps/run_exps__{current_time}.md', 'w') as file:
92 | file.write(str_write_to_file)
93 |
94 |
--------------------------------------------------------------------------------
/exps/imagenet_fed.yaml:
--------------------------------------------------------------------------------
1 | task: Imagenet
2 | synthesizer: Pattern
3 |
4 | data_path: .data/tiny-imagenet-200
5 |
6 | random_seed: 42
7 | batch_size: 256
8 | test_batch_size: 512
9 | lr: 0.001
10 | momentum: 0.9
11 | decay: 0.0005
12 | epochs: 1500
13 | poison_epoch: 0
14 | poison_epoch_stop: 1500
15 | save_on_epochs: [] # [10, 20, 30, 40, 50]
16 | optimizer: SGD
17 | log_interval: 100
18 |
19 | poisoning_proportion: 0.3
20 | backdoor_label: 8
21 |
22 | resume_model: saved_models/tiny_64_pretrain/tiny-resnet.epoch_20
23 |
24 | save_model: True
25 | log: True
26 | report_train_loss: False
27 |
28 | transform_train: True
29 |
30 | fl: True
31 | fl_no_models: 10
32 | fl_local_epochs: 2
33 | fl_poison_epochs: 10
34 | fl_total_participants: 100
35 | fl_eta: 1
36 | fl_sample_dirichlet: True
37 | fl_dirichlet_alpha: 0.5
38 |
39 | fl_number_of_adversaries: 4
40 | # fl_number_of_scapegoats: 0
41 | fl_weight_scale: 5
42 | fl_adv_group_size: 5
43 | # fl_single_epoch_attack: 10
44 |
45 | attack: ModelReplace
46 | defense: FedAvg
47 | fl_num_neurons: 100
48 | noise_mask_alpha: 0 # 0.5
49 | lagrange_step: 0.1
--------------------------------------------------------------------------------
/exps/list_exps__2023.Nov.25.md:
--------------------------------------------------------------------------------
1 | # Dataset: cifar10
2 | **Reading default config from root_file: ./exps/cifar_fed.yaml**
3 |
4 | ------------------------
5 |
6 | ## EXP ID: 01
7 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
8 |
9 | ```bash
10 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_10_4_0.5_0.05_no_pretrained.yaml
11 | ```
12 |
13 | ## EXP ID: 02
14 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
15 |
16 | ```bash
17 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_10_4_0.5_0.05_pretrained.yaml
18 | ```
19 |
20 | ## EXP ID: 03
21 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
22 |
23 | ```bash
24 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_20_4_0.5_0.05_no_pretrained.yaml
25 | ```
26 |
27 | ## EXP ID: 04
28 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
29 |
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_100_20_4_0.5_0.05_pretrained.yaml
32 | ```
33 |
34 | ## EXP ID: 05
35 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
36 |
37 | ```bash
38 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_10_4_0.5_0.05_no_pretrained.yaml
39 | ```
40 |
41 | ## EXP ID: 06
42 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
43 |
44 | ```bash
45 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_10_4_0.5_0.05_pretrained.yaml
46 | ```
47 |
48 | ## EXP ID: 07
49 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
50 |
51 | ```bash
52 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_20_4_0.5_0.05_no_pretrained.yaml
53 | ```
54 |
55 | ## EXP ID: 08
56 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
57 |
58 | ```bash
59 | CUDA_VISIBLE_DEVICES=0 python training.py --name cifar10 --params ./exps/run_cifar10__2023.Nov.25/cifar10_fed_200_20_4_0.5_0.05_pretrained.yaml
60 | ```
61 |
62 | # Dataset: mnist
63 | **Reading default config from root_file: ./exps/mnist_fed.yaml**
64 |
65 | ------------------------
66 |
67 | ## EXP ID: 01
68 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
69 |
70 | ```bash
71 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_10_4_0.5_0.05_no_pretrained.yaml
72 | ```
73 |
74 | ## EXP ID: 02
75 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
76 |
77 | ```bash
78 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_10_4_0.5_0.05_pretrained.yaml
79 | ```
80 |
81 | ## EXP ID: 03
82 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
83 |
84 | ```bash
85 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_20_4_0.5_0.05_no_pretrained.yaml
86 | ```
87 |
88 | ## EXP ID: 04
89 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
90 |
91 | ```bash
92 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_100_20_4_0.5_0.05_pretrained.yaml
93 | ```
94 |
95 | ## EXP ID: 05
96 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
97 |
98 | ```bash
99 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_10_4_0.5_0.05_no_pretrained.yaml
100 | ```
101 |
102 | ## EXP ID: 06
103 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
104 |
105 | ```bash
106 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_10_4_0.5_0.05_pretrained.yaml
107 | ```
108 |
109 | ## EXP ID: 07
110 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
111 |
112 | ```bash
113 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_20_4_0.5_0.05_no_pretrained.yaml
114 | ```
115 |
116 | ## EXP ID: 08
117 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
118 |
119 | ```bash
120 | CUDA_VISIBLE_DEVICES=0 python training.py --name mnist --params ./exps/run_mnist__2023.Nov.25/mnist_fed_200_20_4_0.5_0.05_pretrained.yaml
121 | ```
122 |
123 | # Dataset: tiny-imagenet
124 | **Reading default config from root_file: ./exps/imagenet_fed.yaml**
125 |
126 | ------------------------
127 |
128 | ## EXP ID: 01
129 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
130 |
131 | ```bash
132 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_10_4_0.5_0.05_no_pretrained.yaml
133 | ```
134 |
135 | ## EXP ID: 02
136 | fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
137 |
138 | ```bash
139 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_10_4_0.5_0.05_pretrained.yaml
140 | ```
141 |
142 | ## EXP ID: 03
143 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
144 |
145 | ```bash
146 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_20_4_0.5_0.05_no_pretrained.yaml
147 | ```
148 |
149 | ## EXP ID: 04
150 | fl_total_participants: 100 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
151 |
152 | ```bash
153 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_100_20_4_0.5_0.05_pretrained.yaml
154 | ```
155 |
156 | ## EXP ID: 05
157 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
158 |
159 | ```bash
160 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_10_4_0.5_0.05_no_pretrained.yaml
161 | ```
162 |
163 | ## EXP ID: 06
164 | fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
165 |
166 | ```bash
167 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_10_4_0.5_0.05_pretrained.yaml
168 | ```
169 |
170 | ## EXP ID: 07
171 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: False
172 |
173 | ```bash
174 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_20_4_0.5_0.05_no_pretrained.yaml
175 | ```
176 |
177 | ## EXP ID: 08
178 | fl_total_participants: 200 fl_no_models: 20 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05 resume_model: True
179 |
180 | ```bash
181 | CUDA_VISIBLE_DEVICES=0 python training.py --name tiny-imagenet --params ./exps/run_tiny-imagenet__2023.Nov.25/tiny-imagenet_fed_200_20_4_0.5_0.05_pretrained.yaml
182 | ```
183 |
184 |
--------------------------------------------------------------------------------
/exps/mnist_fed.yaml:
--------------------------------------------------------------------------------
1 | task: MNIST
2 | synthesizer: Pattern
3 |
4 | random_seed: 42
5 | batch_size: 256
6 | test_batch_size: 512
7 | lr: 0.01
8 | momentum: 0.9
9 | decay: 0.0005
10 | epochs: 1500
11 | poison_epoch: 0
12 | poison_epoch_stop: 1500
13 | save_on_epochs: [] # [10, 20, 30, 40, 50]
14 | optimizer: SGD
15 | log_interval: 100
16 |
17 | poisoning_proportion: 1.0
18 | backdoor_label: 8
19 |
20 | resume_model: saved_models/mnist_pretrain/model_last.pt.tar.epoch_10
21 |
22 | save_model: True
23 | log: True
24 | report_train_loss: False
25 |
26 | transform_train: True
27 |
28 | fl: True
29 | fl_no_models: 10
30 | fl_local_epochs: 2
31 | fl_poison_epochs: 5
32 | fl_total_participants: 100
33 | fl_eta: 1
34 | fl_sample_dirichlet: True
35 | fl_dirichlet_alpha: 0.5
36 |
37 | fl_number_of_adversaries: 4
38 | fl_weight_scale: 5
39 | fl_adv_group_size: 2
40 | # fl_single_epoch_attack: 20
41 |
42 | attack: ModelReplace
43 | defense: FedAvg
44 | fl_num_neurons: 8
45 | noise_mask_alpha: 0 # 0.5
46 | lagrange_step: 0.1
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import os
4 | import random
5 | from shutil import copyfile
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import torch
10 | import yaml
11 |
12 | from attacks.attack import Attack
13 | # from defenses.fedavg import FedAvg as Defense
14 | from synthesizers.synthesizer import Synthesizer
15 | from tasks.task import Task
16 | from utils.parameters import Params
17 | from utils.utils import create_logger
18 | import pandas as pd
19 | logger = logging.getLogger('logger')
20 |
21 |
22 | class Helper:
23 | params: Params = None
24 | task: Task = None
25 | synthesizer: Synthesizer = None
26 | # defense: Defense = None
27 | attack: Attack = None
28 |
29 | def __init__(self, params):
30 |
31 | self.params = Params(**params)
32 | self.times = {'backward': list(), 'forward': list(), 'step': list(),
33 | 'scales': list(), 'total': list(), 'poison': list()}
34 |
35 | if self.params.random_seed is not None:
36 | self.fix_random(self.params.random_seed)
37 |
38 | self.make_folders()
39 |
40 | self.make_task()
41 |
42 | self.make_synthesizer()
43 |
44 | self.make_attack()
45 |
46 | self.make_defense()
47 |
48 | self.accuracy = [[],[]]
49 |
50 | self.best_loss = float('inf')
51 |
52 | def make_task(self):
53 | name_lower = self.params.task.lower()
54 | name_cap = self.params.task
55 | module_name = f'tasks.{name_lower}_task'
56 | path = f'tasks/{name_lower}_task.py'
57 | logger.info(f'make task: {module_name} name_cap: {name_cap} path: {path}')
58 | try:
59 | task_module = importlib.import_module(module_name)
60 | task_class = getattr(task_module, f'{name_cap}Task')
61 | except (ModuleNotFoundError, AttributeError):
62 | raise ModuleNotFoundError(f'Your task: {self.params.task} should '
63 | f'be defined as a class '
64 | f'{name_cap}'
65 | f'Task in {path}')
66 | self.task = task_class(self.params)
67 |
68 | def make_synthesizer(self):
69 | name_lower = self.params.synthesizer.lower()
70 | name_cap = self.params.synthesizer
71 | module_name = f'synthesizers.{name_lower}_synthesizer'
72 | logger.info(f'make synthesizer: {module_name} name_cap: {name_cap}')
73 | try:
74 | synthesizer_module = importlib.import_module(module_name)
75 | task_class = getattr(synthesizer_module, f'{name_cap}Synthesizer')
76 | except (ModuleNotFoundError, AttributeError):
77 | raise ModuleNotFoundError(
78 | f'The synthesizer: {self.params.synthesizer}'
79 | f' should be defined as a class '
80 | f'{name_cap}Synthesizer in '
81 | f'synthesizers/{name_lower}_synthesizer.py')
82 | self.synthesizer = task_class(self.task)
83 |
84 | def make_attack(self):
85 | name_lower = self.params.attack.lower()
86 | name_cap = self.params.attack
87 | module_name = f'attacks.{name_lower}'
88 | logger.info(f'make attack: {module_name} name_cap: {name_cap}')
89 | try:
90 | attack_module = importlib.import_module(module_name)
91 | attack_class = getattr(attack_module, f'{name_cap}')
92 | except (ModuleNotFoundError, AttributeError):
93 | raise ModuleNotFoundError(f'Your attack: {self.params.attack} should '
94 | f'be defined either ThrDFed (3DFed) or \
95 | ModelReplace (Model Replacement Attack)')
96 | self.attack = attack_class(self.params, self.synthesizer)
97 |
98 | def make_defense(self):
99 | name_lower = self.params.defense.lower()
100 | name_cap = self.params.defense
101 | module_name = f'defenses.{name_lower}'
102 | try:
103 | defense_module = importlib.import_module(module_name)
104 | defense_class = getattr(defense_module, f'{name_cap}')
105 | except (ModuleNotFoundError, AttributeError):
106 | raise ModuleNotFoundError(f'Your defense: {self.params.defense} should '
107 | f'be one of the follow: FLAME, Deepsight, \
108 | Foolsgold, FLDetector, RFLBAT, FedAvg')
109 | self.defense = defense_class(self.params)
110 |
111 | def make_folders(self):
112 | log = create_logger()
113 | if self.params.log:
114 | os.makedirs(self.params.folder_path, exist_ok=True)
115 |
116 | fh = logging.FileHandler(
117 | filename=f'{self.params.folder_path}/log.txt')
118 | formatter = logging.Formatter('%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s')
119 |
120 | fh.setFormatter(formatter)
121 | log.addHandler(fh)
122 |
123 | with open(f'{self.params.folder_path}/params.yaml.txt', 'w') as f:
124 | yaml.dump(self.params, f)
125 | # log.info(f"Creating folder {self.params.folder_path}")
126 |
127 |
128 | def save_model(self, model=None, epoch=0, val_loss=0):
129 |
130 | if self.params.save_model:
131 | logger.info(f"Saving model to {self.params.folder_path}.")
132 | model_name = '{0}/model_last.pt.tar'.format(self.params.folder_path)
133 | saved_dict = {'state_dict': model.state_dict(),
134 | 'epoch': epoch,
135 | 'lr': self.params.lr,
136 | 'params_dict': self.params.to_dict()}
137 | self.save_checkpoint(saved_dict, False, model_name)
138 |
139 | if epoch in self.params.save_on_epochs:
140 | logger.info(f'Saving model on epoch {epoch}')
141 | self.save_checkpoint(saved_dict, False,
142 | filename=f'{self.params.folder_path}/model_epoch_{epoch}.pt.tar')
143 | if val_loss < self.best_loss:
144 | self.save_checkpoint(saved_dict, False, f'{model_name}_best')
145 | self.best_loss = val_loss
146 |
147 | logger.info(f"Done saving model to {self.params.folder_path}.")
148 | def save_update(self, model=None, userID = 0):
149 | folderpath = '{0}/saved_updates'.format(self.params.folder_path)
150 | # logger.info(f"Saving update to {folderpath}.")
151 |
152 | # if not os.path.exists(folderpath):
153 | # os.makedirs(folderpath)]
154 | # import IPython; IPython.embed()
155 | # exit(0)
156 | os.makedirs(folderpath, exist_ok=True)
157 | update_name = '{0}/update_{1}.pth'.format(folderpath, userID)
158 | torch.save(model, update_name)
159 | # logger.info(f"Saving update to {update_name}.")
160 |
161 | def remove_update(self):
162 | for i in range(self.params.fl_total_participants):
163 | file_name = '{0}/saved_updates/update_{1}.pth'.format(self.params.folder_path, i)
164 | if os.path.exists(file_name):
165 | os.remove(file_name)
166 | os.rmdir('{0}/saved_updates'.format(self.params.folder_path))
167 | if self.params.defense == 'Foolsgold':
168 | for i in range(self.params.fl_total_participants):
169 | file_name = '{0}/foolsgold/history_{1}.pth'.format(self.params.folder_path, i)
170 | if os.path.exists(file_name):
171 | os.remove(file_name)
172 | os.rmdir('{0}/foolsgold'.format(self.params.folder_path))
173 |
174 | def record_accuracy(self, main_acc, backdoor_acc, epoch):
175 | self.accuracy[0].append(main_acc)
176 | self.accuracy[1].append(backdoor_acc)
177 | name = ['main', 'backdoor']
178 | acc_frame = pd.DataFrame(columns=name, data=zip(*self.accuracy),
179 | index=range(self.params.start_epoch, epoch+1))
180 | filepath = f"{self.params.folder_path}/accuracy.csv"
181 | acc_frame.to_csv(filepath)
182 | logger.info(f"Saving accuracy record to {filepath}")
183 |
184 | def set_bn_eval(m):
185 | classname = m.__class__.__name__
186 | if classname.find('BatchNorm') != -1:
187 | m.eval()
188 |
189 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
190 | if not self.params.save_model:
191 | return False
192 | torch.save(state, filename)
193 |
194 | if is_best:
195 | copyfile(filename, 'model_best.pth.tar')
196 |
197 | @staticmethod
198 | def fix_random(seed=1):
199 | from torch.backends import cudnn
200 |
201 | logger.warning('Setting random_seed seed for reproducible results.')
202 | random.seed(seed)
203 | torch.manual_seed(seed)
204 | torch.cuda.manual_seed_all(seed)
205 | cudnn.deterministic = False
206 | cudnn.enabled = True
207 | cudnn.benchmark = True
208 | np.random.seed(seed)
209 |
210 | return True
211 |
--------------------------------------------------------------------------------
/markdown-explanation/FedMLSec.md:
--------------------------------------------------------------------------------
1 | Table of Contents
2 | -----------------
3 |
4 |
5 | # Setup Python Environment
6 | - Set the name of environment in both files: `environment.yml` and `Makefile`. The default name is `aba`, aka "all backdoor attacks" and then run following commands:
7 | ```
8 | make install
9 | ```
10 |
11 | # Guideline for custome training (Atk vs Def)
12 |
13 | ## Data Customization
14 | - [Data Loader](https://github.com/FedML-AI/FedML/blob/master/doc/en/simulation/user_guide/data_loader_customization.md)
15 |
16 | ## Datasets and Models Customization
17 | - [Datasets and Models](https://github.com/FedML-AI/FedML/blob/master/doc/en/simulation/user_guide/datasets-and-models.md#datasets-and-models)
18 | - [FedML Data](https://github.com/FedML-AI/FedML/tree/master/python/fedml/data)
19 |
20 | ## Attack Customization
21 | ### [Model Replacement Attack (MRA)](https://arxiv.org/pdf/1807.00459.pdf)
22 | - [Code](https://github.com/ebagdasa/backdoors101) is in `fedlearn-backdoor-attacks/3DFed/attacks/modelreplace.py`
23 | 
24 |
25 |
26 | ## Defense Customization
27 |
28 | ## Training Customization
29 |
30 | ## Evaluation Customization
31 |
32 | ## Visualization Customization
33 |
34 | ## Result Customization
35 |
36 |
37 | ## Flow of the code:
38 | ### [3DFed (ThrDFed) - S&P'23](https://github.com/haoyangliASTAPLE/3DFed)
39 | - Init data, attack, defense method in `fedlearn-backdoor-attacks/3DFed/helper.py`
40 | - Run fl round in `fedlearn-backdoor-attacks/3DFed/training.py`
41 | - Sample user for round (fl_no_models/ fl_total_participants)
42 | - Init FLUser (user_id, compromised, train_loader)
43 | - Single epoch attack: user_id = 0 is attacker, compromised = True
44 | - Otherwise: check if epoch in attack_epochs, if yes, check list adversaries
45 | - Training for each user
46 | - If user is attacker, run attack for only user_id = 0 [missing other attackers], that means training models on poisoned data
47 | - Otherwise, run defense for all users
48 | - Perform Attack and aggregate results
49 | - check update_global_model with weight: 1/total_participants (self.params.fl_eta / self.params.fl_total_participants)
50 | - Limitation: Currently, dump fl_no_models (set = fl_total_participants) models in each round into file, only one attacker is supported (other attackers is duplicated from attacker 0)
51 | ### [DBA - ICLR'20 (Distributed Backdoor Attack)](https://github.com/AI-secure/DBA)
52 | - List of adversarial clients in config file: adversary_list: [17, 33, 77, 11]
53 | - no_models: K = 10
54 | - number_of_total_participants: N = 100
55 | - agent_name_keys: id of client that participate in each round (K/ N)
56 | - agent_name_keys = benign + adv (check adv first, then benign); set epoch attack for each client in .yaml file
57 | In each communication round:
58 | - Check the trigger is global pattern or local pattern
59 | - Get poison batch data from adversary_index (-1, 0, 1, 2, 3)
60 |
61 |
62 | ### [Attack of the Tails: Yes, You Really Can Backdoor Federated Learning - NeurIPS'20](https://github.com/ksreenivasan/OOD_Federated_Learning)
63 |
64 | Load poisoned dataset (in `simulated_averaging.py`):
65 | ```python
66 | poisoned_train_loader, vanilla_test_loader, targetted_task_test_loader, num_dps_poisoned_dataset, clean_train_loader = load_poisoned_dataset(args=args)
67 | ```
68 |
69 | Two modes (fixed-freq mode or fixed-pool mode):
70 | ```python
71 | Intotal: N (num_nets) clients, K (part_nets_per_round) clients are participating in each round
72 |
73 | - Fixed-freq Attack (`FrequencyFederatedLearningTrainer`):
74 | - "attacking_fl_rounds":[i for i in range(1, args.fl_round + 1) if (i-1)%10 == 0]
75 | - poison_type in ["ardis": ["normal-case", "almost-edge-case", "edge-case"], "southwest"]
76 | - Training in each communication round:
77 | - if round in attacking_fl_rounds, run attack
78 | - One attacker, and K-1 benign clients
79 | - For attacker, run attack in `adversarial_local_training_period` epochs
80 | - data_loader = `poisoned_train_loader`
81 | - Check defense_technique in ["krum", "multi-krum"]: eps=self.eps*self.args_gamma**(flr-1); else eps=self.eps
82 | - Test on `vanilla_test_loader` and `targetted_task_test_loader`
83 | - Check model_replacement scale models with ratio: total_num_dps_per_round/num_dps_poisoned_dataset
84 | - Print norm before and after scale
85 | - For benign clients, run normal training in `local_training_period` epochs
86 | - data_loader = `clean_train_loader`
87 | - otherwise
88 | - run normal training with K benign clients
89 | - Aggregate models
90 | - using
91 | - TODO: check prox-attack
92 |
93 | - test function:
94 | - check dataset:
95 | - dataset in ["mnist", "emnist"]:
96 | - target_class = 7
97 | - task in ["raw-task", "targeted"]
98 | - poison_type in ["ardis"]
99 | - dataset in ["cifar10"]:
100 | - target_class = 2 for greencar in ("howto", "greencar-neo"),
101 | - target_class = 9 for southwest
102 | - TODO: check backdoor acc calculation (line 248-253)
103 |
104 | - Fixed-pool Attack:
105 | - In each round of communication, randomly select K clients to participate in the training
106 | - selected_attackers from "__attacker_pool" (numpy random choice attacker_pool_size/ num_nets)
107 | - selected_benign_clients
108 | All attackers are sharing the same poisoned dataset
109 | - Training the same as Fixed-freq Attack for each attacker
110 |
111 | - defense_technique in ["no-defense", "norm-clipping", "weak-dp", "krum", "multi-krum", "rfa"]
112 | ```
113 | For more details, check `FedML` [dataset](https://github.com/FedML-AI/FedML/tree/master/python/fedml/data/edge_case_examples)
114 |
115 | ### [How To Backdoor Federated Learning (AISTATS'20)](https://github.com/ebagdasa/backdoors101)
116 | - Generate trigger in file `synthesizers/pattern_synthesizer.py`. The 2 images below are the pattern of the trigger and image with trigger.
117 |
119 | - All the original images and trigger images are normalized to [mean, std] of the original dataset
120 |
121 |
122 |

124 |

126 |
127 |
128 | - Define the loss function in `attacks/loss_function.py`
129 | Standard version: each malicious client computes 2 losses and weights them equally (0.5, 0.5)
130 |
131 | ```Python
132 | # Criterion: nn.CrossEntropyLoss(reduction='none'); get all losses shape (batch_size)
133 | Total loss = 0.5 * (loss1 + loss2); check `scale_losses` in `attacks/attack.py`
134 | ```
135 |
136 | The result of the experiment can be found in this [wandb](https://wandb.ai/mtuann/benchmark-backdoor-fl?workspace=user-mtuann).
137 |
138 |
139 | ## TODO:
140 | - [ ] Setting standard FL attack from Attack of the Tails and DBA
141 | - [ ] Change dump to file -> dump to memory
142 | - [ ] Check popular defense method: Foolsgold, RFA, ...
143 |
144 | ## Reference
145 |
146 |
147 | # Source
148 | - [FedMLSecurity: A Benchmark for Attacks and Defenses in Federated Learning and Federated LLMs](https://arxiv.org/pdf/2306.04959.pdf)
149 | - [Attack and Defense of FedMLSecurity](https://github.com/FedML-AI/FedML/blob/master/python/fedml/core/security/readme.md)
150 | - [fedMLSecurity_experiments](https://github.com/FedML-AI/FedML/tree/master/python/examples/security/fedMLSecurity_experiments)
151 |
152 |
--------------------------------------------------------------------------------
/markdown-explanation/Figure_4_ba_norm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/Figure_4_ba_norm.png
--------------------------------------------------------------------------------
/markdown-explanation/Figure_full-image-pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/Figure_full-image-pattern.png
--------------------------------------------------------------------------------
/markdown-explanation/code-structures.md:
--------------------------------------------------------------------------------
1 |
2 | Overview of the code structures used in the project.
3 |
4 | Tree structure
5 | --------------
6 | ```
7 | .
8 | ├── README.md
9 | ├── training.py
10 | ├── helper.py
11 | ├── utils
12 | │ ├── params.py
13 | │ ├── utils.py
14 | ├── data
15 | │ ├── MNIST
16 | │ ├── CIFAR-10
17 | │ ├── CIFAR-100
18 | ```
19 |
20 | Running code:
21 | -------------
22 | ```
23 | python training.py --name mnist --params exps/mnist_fed.yaml
24 | ```
25 | Flow of the code:
26 | -----------------
27 | 1. `training.py` is the main file that is run.
28 |     - It parses the arguments and loads the parameters from the yaml file.
29 |     - It loads all the configurations into `helper.py` (the parameters are stored as a variable in the `Helper` class).
30 |
31 |     Perform the following steps for each round:
32 |     - Train for `epochs` communication rounds
33 |         - Train each round for `fl_local_epochs` local epochs
34 |         - Save each local update to the `saved_updates/update_{user_ID}.pth` file
35 |         - Perform FedAvg on the local updates in `defenses/fedavg.py`
36 |         - Update the global model using the scale $scale = \frac{fl\_eta}{fl\_no\_models} = \frac{1}{k}$ (i.e. with $fl\_eta = 1$ and $k = fl\_no\_models$). TODO: weight each client by its number of samples instead.
37 |
38 |     - Evaluate the model on the test set
39 |
40 |
41 | 1. Define the task in `task.py` in the `tasks` folder, inheriting from the `Task` class.
42 |     - The `Task` class has the following functions:
43 |         - `load_data` - load the data from the `data` folder and split it among the clients.
44 |         - `build_model` - build the model for the task.
45 |         - `resume_model` - resume the model from a checkpoint.
46 |         - `make_criterion` - define the loss function for the task.
47 |         - `make_optimizer` - define the optimizer for the task.
48 |         - `sample_adversaries` - sample the adversaries for the task.
49 |         - `train` - train the model for one epoch.
50 |         - `metrics` - define the metrics for the task; the 2 main metrics are `AccuracyMetric` (top-k accuracy) and `TestLossMetric`.
51 |
52 | 1. Define the synthesizer in `synthesizer.py` in the `synthesizers` folder, inheriting from the `Synthesizer` class.
53 | 2. Define the attack in `attack.py` in the `attacks` folder, inheriting from the `Attack` class.
54 |     - Loss functions are defined in the `attacks/loss_functions.py` file.
55 | 3. Define the loss function
56 |
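The FedAvg step in the flow above can be sketched as follows. This is a stdlib-only illustration, not the code in `defenses/fedavg.py`: it assumes each saved update is the difference between a client's local weights and the global weights, and it represents a model as a dict of flat float lists. The function names and the `num_samples` argument are hypothetical.

```python
from typing import Dict, List

Update = Dict[str, List[float]]  # parameter name -> flat list of values


def fedavg(global_model: Update, updates: List[Update], fl_eta: float = 1.0) -> Update:
    """Return global + (fl_eta / k) * sum(updates), with k = len(updates)."""
    scale = fl_eta / len(updates)
    new_model = {}
    for name, weights in global_model.items():
        summed = [sum(u[name][i] for u in updates) for i in range(len(weights))]
        new_model[name] = [w + scale * s for w, s in zip(weights, summed)]
    return new_model


def fedavg_by_samples(global_model: Update, updates: List[Update],
                      num_samples: List[int]) -> Update:
    """Variant: weight each client's update by its share of the samples."""
    total = sum(num_samples)
    new_model = {}
    for name, weights in global_model.items():
        new_model[name] = [
            w + sum(n / total * u[name][i] for u, n in zip(updates, num_samples))
            for i, w in enumerate(weights)
        ]
    return new_model


global_model = {"fc": [0.0, 1.0]}
updates = [{"fc": [1.0, 1.0]}, {"fc": [3.0, -1.0]}]
print(fedavg(global_model, updates))                     # {'fc': [2.0, 1.0]}
print(fedavg_by_samples(global_model, updates, [1, 3]))  # {'fc': [2.5, 0.5]}
```

With `fl_eta = 1` the first function is a plain average of the updates; the second shows the sample-weighted alternative mentioned in the note on the scale.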
--------------------------------------------------------------------------------
/markdown-explanation/model-replacement-attack.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/markdown-explanation/model-replacement-attack.png
--------------------------------------------------------------------------------
/metrics/accuracy_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 |
4 |
5 | class AccuracyMetric(Metric):
6 |
7 |     def __init__(self, top_k=(1,)):
8 |         self.top_k = top_k
9 |         self.main_metric_name = 'Top-1'
10 |         super().__init__(name='Accuracy', train=False)
11 |
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor):
14 |         """Computes the top-k accuracy (in percent) for each k in top_k."""
15 |         max_k = max(self.top_k)
16 |         batch_size = labels.shape[0]
17 |
18 |         _, pred = outputs.topk(max_k, 1, True, True)
19 |         # pred: indices of the max_k highest-scoring classes per sample
20 |         pred = pred.t()  # transpose to shape (max_k, batch_size)
21 |         # expand labels to match pred size: 2D size (max_k, batch_size)
22 |         correct = pred.eq(labels.view(1, -1).expand_as(pred))
23 |         # correct: True/False matrix, True where a guess equals the label
24 |         res = dict()
25 |         for k in self.top_k:
26 |             correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice may be non-contiguous
27 |             res[f'Top-{k}'] = (correct_k.mul_(100.0 / batch_size)).item()
28 |         return res
29 |
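As a torch-free illustration of what `compute_metric` above measures, here is a small sketch of top-k accuracy on plain Python lists. The function name `top_k_accuracy` and the toy scores are made up for this example; only the `Top-{k}` key format and the percentage scale follow the class above.

```python
from typing import Dict, List, Sequence


def top_k_accuracy(outputs: List[List[float]], labels: List[int],
                   top_k: Sequence[int] = (1,)) -> Dict[str, float]:
    """Percentage of samples whose label is among the k highest scores."""
    res = {}
    for k in top_k:
        hits = 0
        for scores, label in zip(outputs, labels):
            # class indices sorted by descending score, truncated to k
            ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
            hits += label in ranked
        res[f"Top-{k}"] = 100.0 * hits / len(labels)
    return res


outputs = [[0.1, 0.7, 0.2],   # best class 1, runner-up class 2
           [0.5, 0.3, 0.2],   # best class 0, runner-up class 1
           [0.2, 0.3, 0.5],   # best class 2, runner-up class 1
           [0.6, 0.1, 0.3]]   # best class 0, runner-up class 2
labels = [1, 1, 0, 2]
print(top_k_accuracy(outputs, labels, top_k=(1, 2)))  # {'Top-1': 25.0, 'Top-2': 75.0}
```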
--------------------------------------------------------------------------------
/metrics/metric.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import Dict, Any
4 |
5 | import numpy as np
6 |
7 | logger = logging.getLogger('logger')
8 |
9 |
10 | class Metric:
11 |     name: str
12 |     train: bool
13 |     plottable: bool = True
14 |     running_metric = None
15 |     main_metric_name = None
16 |
17 |     def __init__(self, name, train=False):
18 |         self.train = train
19 |         self.name = name
20 |
21 |         self.running_metric = defaultdict(list)
22 |
23 |     def __repr__(self):
24 |         metrics = self.get_value()
25 |         text = [f'{key}: {val:.2f}' for key, val in metrics.items()]
26 |         return f'{self.name}: ' + ','.join(text)
27 |
28 |     def compute_metric(self, outputs, labels) -> Dict[str, Any]:
29 |         raise NotImplementedError  # implemented by subclasses
30 |
31 |     def accumulate_on_batch(self, outputs=None, labels=None):
32 |         current_metrics = self.compute_metric(outputs, labels)
33 |         for key, value in current_metrics.items():
34 |             self.running_metric[key].append(value)
35 |
36 |     def get_value(self) -> Dict[str, np.ndarray]:
37 |         metrics = dict()
38 |         for key, value in self.running_metric.items():
39 |             metrics[key] = np.mean(value)
40 |
41 |         return metrics
42 |
43 |     def get_main_metric_value(self):
44 |         if not self.main_metric_name:
45 |             raise ValueError(f'For metric {self.name} define '
46 |                              f'attribute main_metric_name.')
47 |         metrics = self.get_value()
48 |         return metrics[self.main_metric_name]
49 |
50 |     def reset_metric(self):
51 |         self.running_metric = defaultdict(list)
52 |
53 |
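The accumulate / get_value / reset cycle of `Metric` can be tried out without torch or numpy. The sketch below is a simplified stand-in (the `RunningMetric` class and the batch sizes are hypothetical). It also shows one consequence of averaging per-batch values: when the last batch is smaller, the result differs from the per-sample average.

```python
from collections import defaultdict
from statistics import mean


class RunningMetric:
    """Stdlib-only stand-in for the accumulate/get_value/reset cycle."""

    def __init__(self):
        self.running = defaultdict(list)

    def accumulate(self, batch_values: dict) -> None:
        for key, value in batch_values.items():
            self.running[key].append(value)

    def get_value(self) -> dict:
        # unweighted mean over batches, mirroring Metric.get_value
        return {key: mean(values) for key, values in self.running.items()}

    def reset(self) -> None:
        self.running = defaultdict(list)


metric = RunningMetric()
metric.accumulate({"Top-1": 100.0})  # e.g. a batch of 4 samples, all correct
metric.accumulate({"Top-1": 0.0})    # e.g. a final batch of 1 sample, wrong
print(metric.get_value())            # {'Top-1': 50.0}

# Per-sample accuracy over the same 5 samples would be 4 / 5 = 80%.
weighted = (100.0 * 4 + 0.0 * 1) / (4 + 1)
print(weighted)                      # 80.0
```

If exact per-sample numbers matter, one option is to accumulate sums and counts instead of per-batch means.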
--------------------------------------------------------------------------------
/metrics/test_loss_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metrics.metric import Metric
3 |
4 |
5 | class TestLossMetric(Metric):
6 |
7 |     def __init__(self, criterion, train=False):
8 |         self.criterion = criterion
9 |         self.main_metric_name = 'value'
10 |         super().__init__(name='Loss', train=train)
11 |
12 |     def compute_metric(self, outputs: torch.Tensor,
13 |                        labels: torch.Tensor, top_k=(1,)):
14 |         """Computes the mean loss of the criterion on one batch."""
15 |         loss = self.criterion(outputs, labels)
16 |         return {'value': loss.mean().item()}
--------------------------------------------------------------------------------
/models/MnistNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.simple import SimpleNet
5 |
6 |
7 | class MnistNet(SimpleNet):
8 |     def __init__(self, name=None, created_time=None):
9 |         super(MnistNet, self).__init__(f'{name}_Simple', created_time)
10 |
11 |         self.conv1 = nn.Conv2d(1, 20, 5, 1)
12 |         self.conv2 = nn.Conv2d(20, 50, 5, 1)
13 |         self.fc1 = nn.Linear(4 * 4 * 50, 500)
14 |         self.fc2 = nn.Linear(500, 10)
15 |         # self.fc2 = nn.Linear(28*28, 10)
16 |
17 |     def forward(self, x):
18 |         x = F.relu(self.conv1(x))
19 |         x = F.max_pool2d(x, 2, 2)
20 |         x = F.relu(self.conv2(x))
21 |         x = F.max_pool2d(x, 2, 2)
22 |         x = x.view(-1, 4 * 4 * 50)
23 |         x = F.relu(self.fc1(x))
24 |         x = self.fc2(x)
25 |
26 |         # in_features = 28 * 28
27 |         # x = x.view(-1, in_features)
28 |         # x = self.fc2(x)
29 |
30 |         # normal return: log-probabilities
31 |         return F.log_softmax(x, dim=1)
32 |         # softmax is used to generate SDT data
33 |         # return F.softmax(x, dim=1)
34 |
35 | if __name__ == '__main__':
36 |     model = MnistNet()
37 |     print(model)
38 |
39 | # import numpy as np
40 | # from torchvision import datasets, transforms
41 | # import torch
42 | # import torch.utils.data
43 | # import copy
44 | #
45 | # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
46 | #
47 | # train_dataset = datasets.MNIST('./data', train=True, download=True,
48 | # transform=transforms.Compose([
49 | # transforms.ToTensor(),
50 | # # transforms.Normalize((0.1307,), (0.3081,))
51 | # ]))
52 | # test_dataset = datasets.MNIST('./data', train=False, transform=transforms.Compose([
53 | # transforms.ToTensor(),
54 | # # transforms.Normalize((0.1307,), (0.3081,))
55 | # ]))
56 | # train_loader = torch.utils.data.DataLoader(train_dataset,
57 | # batch_size=64,
58 | # shuffle=False)
59 | # client_grad = []
60 | #
61 | # for batch_id, batch in enumerate(train_loader):
62 | # optimizer.zero_grad()
63 | # data, targets = batch
64 | # output = model(data)
65 | # loss = nn.functional.cross_entropy(output, targets)
66 | # loss.backward()
67 | # for i, (name, params) in enumerate(model.named_parameters()):
68 | # if params.requires_grad:
69 | # if batch_id == 0:
70 | # client_grad.append(params.grad.clone())
71 | # else:
72 | # client_grad[i] += params.grad.clone()
73 | # optimizer.step()
74 | # if batch_id==2:
75 | # break
76 | #
77 | # print(client_grad[-2].cpu().data.numpy().shape)
78 | # print(np.array(client_grad[-2].cpu().data.numpy().shape))
79 | # grad_len = np.array(client_grad[-2].cpu().data.numpy().shape).prod()
80 | # print(grad_len)
81 | # memory = np.zeros((1, grad_len))
82 | # grads = np.zeros((1, grad_len))
83 | # grads[0] = np.reshape(client_grad[-2].cpu().data.numpy(), (grad_len))
84 | # print(grads)
85 | # print(grads[0].shape)
86 |
87 |
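The `4 * 4 * 50` input size of `fc1` in `MnistNet` follows from the layer shapes. The arithmetic below is a small worked check, assuming 28x28 MNIST inputs; the `conv_out` helper is written for this note and is not part of the repository.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # max_pool2d: 24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # max_pool2d: 8 -> 4
flat_features = size * size * 50            # conv2 has 50 output channels
print(size, flat_features)                  # 4 800
```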
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/models/__init__.py
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Model(nn.Module):
5 |     """
6 |     Base class for models with added support for GradCam activation map
7 |     and a SentiNet defense. The GradCam design is taken from:
8 |     https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82
9 |     If you are not planning to utilize SentiNet defense just import any model
10 |     you like for your tasks.
11 |     """
12 |
13 |     def __init__(self):
14 |         super().__init__()
15 |         self.gradient = None
16 |
17 |     def activations_hook(self, grad):
18 |         self.gradient = grad
19 |
20 |     def get_gradient(self):
21 |         return self.gradient
22 |
23 |     def get_activations(self, x):
24 |         return self.features(x)
25 |
26 |     def switch_grads(self, enable=True):
27 |         for i, n in self.named_parameters():
28 |             n.requires_grad_(enable)
29 |
30 |     def features(self, x):
31 |         """
32 |         Get latent representation, e.g. the logit layer.
33 |         :param x:
34 |         :return:
35 |         """
36 |         raise NotImplementedError
37 |
38 |     def forward(self, x, latent=False):
39 |         raise NotImplementedError
40 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 | from models.model import Model
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=dilation, groups=groups, bias=False,
28 | dilation=dilation)
29 |
30 |
31 | def conv1x1(in_planes, out_planes, stride=1):
32 | """1x1 convolution"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
34 | bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion = 1
39 |
40 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
41 | base_width=64, dilation=1, norm_layer=None):
42 | super(BasicBlock, self).__init__()
43 | if norm_layer is None:
44 | norm_layer = nn.BatchNorm2d
45 | if groups != 1 or base_width != 64:
46 | raise ValueError(
47 | 'BasicBlock only supports groups=1 and base_width=64')
48 | if dilation > 1:
49 | raise NotImplementedError(
50 | "Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the
52 | # input when stride != 1
53 | self.conv1 = conv3x3(inplanes, planes, stride)
54 | self.bn1 = norm_layer(planes)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = conv3x3(planes, planes)
57 | self.bn2 = norm_layer(planes)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | identity = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 |
71 | if self.downsample is not None:
72 | identity = self.downsample(x)
73 |
74 | out += identity
75 | out = self.relu(out)
76 |
77 | return out
78 |
79 |
80 | class Bottleneck(nn.Module):
81 | expansion = 4
82 |
83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
84 | base_width=64, dilation=1, norm_layer=None):
85 | super(Bottleneck, self).__init__()
86 | if norm_layer is None:
87 | norm_layer = nn.BatchNorm2d
88 | width = int(planes * (base_width / 64.)) * groups
89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
90 | self.conv1 = conv1x1(inplanes, width)
91 | self.bn1 = norm_layer(width)
92 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
93 | self.bn2 = norm_layer(width)
94 | self.conv3 = conv1x1(width, planes * self.expansion)
95 | self.bn3 = norm_layer(planes * self.expansion)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | identity = x
102 |
103 | out = self.conv1(x)
104 | out = self.bn1(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv3(out)
112 | out = self.bn3(out)
113 |
114 | if self.downsample is not None:
115 | identity = self.downsample(x)
116 |
117 | out += identity
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 |
123 | class ResNet(Model):
124 |
125 | def __init__(self, block, layers, num_classes=1000,
126 | zero_init_residual=False,
127 | groups=1, width_per_group=64,
128 | replace_stride_with_dilation=None,
129 | norm_layer=None):
130 | super(ResNet, self).__init__()
131 | if norm_layer is None:
132 | norm_layer = nn.BatchNorm2d
133 | self._norm_layer = norm_layer
134 |
135 | self.inplanes = 64
136 | self.dilation = 1
137 | if replace_stride_with_dilation is None:
138 | # each element in the tuple indicates if we should replace
139 | # the 2x2 stride with a dilated convolution instead
140 | replace_stride_with_dilation = [False, False, False]
141 | if len(replace_stride_with_dilation) != 3:
142 | raise ValueError("replace_stride_with_dilation should be None "
143 | "or a 3-element tuple, got {}".format(
144 | replace_stride_with_dilation))
145 | self.groups = groups
146 | self.base_width = width_per_group
147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2,
148 | padding=3,
149 | bias=False)
150 | self.bn1 = norm_layer(self.inplanes)
151 | self.relu = nn.ReLU(inplace=True)
152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
153 | self.layer1 = self._make_layer(block, 64, layers[0])
154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
155 | dilate=replace_stride_with_dilation[0])
156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
157 | dilate=replace_stride_with_dilation[1])
158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
159 | dilate=replace_stride_with_dilation[2])
160 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
161 | self.fc = nn.Linear(512 * block.expansion, num_classes)
162 |
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | nn.init.kaiming_normal_(m.weight, mode='fan_out',
166 | nonlinearity='relu')
167 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
168 | nn.init.constant_(m.weight, 1)
169 | nn.init.constant_(m.bias, 0)
170 |
171 | # Zero-initialize the last BN in each residual branch,
172 | # so that the residual branch starts with zeros, and
173 | # each residual block behaves like an identity.
174 | # This improves the model by 0.2~0.3% according
175 | # to https://arxiv.org/abs/1706.02677
176 | if zero_init_residual:
177 | for m in self.modules():
178 | if isinstance(m, Bottleneck):
179 | nn.init.constant_(m.bn3.weight, 0)
180 | elif isinstance(m, BasicBlock):
181 | nn.init.constant_(m.bn2.weight, 0)
182 |
183 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
184 | norm_layer = self._norm_layer
185 | downsample = None
186 | previous_dilation = self.dilation
187 | if dilate:
188 | self.dilation *= stride
189 | stride = 1
190 | if stride != 1 or self.inplanes != planes * block.expansion:
191 | downsample = nn.Sequential(
192 | conv1x1(self.inplanes, planes * block.expansion, stride),
193 | norm_layer(planes * block.expansion),
194 | )
195 |
196 | layers = []
197 | layers.append(
198 | block(self.inplanes, planes, stride, downsample, self.groups,
199 | self.base_width, previous_dilation, norm_layer))
200 | self.inplanes = planes * block.expansion
201 | for _ in range(1, blocks):
202 | layers.append(block(self.inplanes, planes, groups=self.groups,
203 | base_width=self.base_width,
204 | dilation=self.dilation,
205 | norm_layer=norm_layer))
206 |
207 | return nn.Sequential(*layers)
208 |
209 | def features(self, x):
210 | out1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
211 | out2 = self.layer1(out1)
212 | out3 = self.layer2(out2)
213 | out4 = self.layer3(out3)
214 | out5 = self.layer4(out4)
215 |
216 | return out5
217 |
218 | def forward(self, x, latent=False):
219 | x = self.conv1(x)
220 | x = self.bn1(x)
221 | x = self.relu(x)
222 | x = self.maxpool(x)
223 |
224 | x = self.layer1(x)
225 | x = self.layer2(x)
226 | x = self.layer3(x)
227 | layer4_out = self.layer4(x)
228 |
229 | if layer4_out.requires_grad:
230 | layer4_out.register_hook(self.activations_hook)
231 |
232 | x = self.avgpool(layer4_out)
233 | flatten_x = torch.flatten(x, 1)
234 | x = self.fc(flatten_x)
235 | if latent:
236 | return x, flatten_x
237 | else:
238 | return x
239 |
240 |
241 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
242 | model = ResNet(block, layers, **kwargs)
243 | if pretrained:
244 | state_dict = load_state_dict_from_url(model_urls[arch],
245 | progress=progress)
246 | model.load_state_dict(state_dict)
247 | return model
248 |
249 |
250 | def resnet18(pretrained=False, progress=True, **kwargs):
251 | r"""ResNet-18 model from
252 | `"Deep Residual Learning for Image Recognition"
253 | <https://arxiv.org/pdf/1512.03385.pdf>`_
254 |
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | progress (bool): If True, displays a progress bar of the download to stderr
258 | """
259 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
260 | **kwargs)
261 |
262 |
263 | def resnet34(pretrained=False, progress=True, **kwargs):
264 | r"""ResNet-34 model from
265 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
266 |
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | progress (bool): If True, displays a progress bar of the download to stderr
270 | """
271 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
272 | **kwargs)
273 |
274 |
275 | def resnet50(pretrained=False, progress=True, **kwargs):
276 | r"""ResNet-50 model from
277 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
278 |
279 | Args:
280 | pretrained (bool): If True, returns a model pre-trained on ImageNet
281 | progress (bool): If True, displays a progress bar of the download to stderr
282 | """
283 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
284 | **kwargs)
285 |
286 |
287 | def resnet101(pretrained=False, progress=True, **kwargs):
288 | r"""ResNet-101 model from
289 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
290 |
291 | Args:
292 | pretrained (bool): If True, returns a model pre-trained on ImageNet
293 | progress (bool): If True, displays a progress bar of the download to stderr
294 | """
295 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
296 | **kwargs)
297 |
298 |
299 | def resnet152(pretrained=False, progress=True, **kwargs):
300 | r"""ResNet-152 model from
301 | `"Deep Residual Learning for Image Recognition"
302 | <https://arxiv.org/pdf/1512.03385.pdf>`_
303 |
304 | Args:
305 | pretrained (bool): If True, returns a model pre-trained on ImageNet
306 | progress (bool): If True, displays a progress bar of the download to stderr
307 | """
308 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
309 | **kwargs)
310 |
311 |
312 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
313 | r"""ResNeXt-50 32x4d model from
314 | `"Aggregated Residual Transformation for Deep Neural Networks"
315 | <https://arxiv.org/pdf/1611.05431.pdf>`_
316 |
317 | Args:
318 | pretrained (bool): If True, returns a model pre-trained on ImageNet
319 | progress (bool): If True, displays a progress bar of the download to stderr
320 | """
321 | kwargs['groups'] = 32
322 | kwargs['width_per_group'] = 4
323 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
324 | pretrained, progress, **kwargs)
325 |
326 |
327 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
328 | r"""ResNeXt-101 32x8d model from
329 | `"Aggregated Residual Transformation for Deep Neural Networks"
330 | <https://arxiv.org/pdf/1611.05431.pdf>`_
331 |
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | progress (bool): If True, displays a progress bar of the download to stderr
335 | """
336 | kwargs['groups'] = 32
337 | kwargs['width_per_group'] = 8
338 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
339 | pretrained, progress, **kwargs)
340 |
341 |
342 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
343 | r"""Wide ResNet-50-2 model from
344 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
345 |
346 | The model is the same as ResNet except for the bottleneck number of channels
347 | which is twice larger in every block. The number of channels in outer 1x1
348 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
349 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
350 |
351 | Args:
352 | pretrained (bool): If True, returns a model pre-trained on ImageNet
353 | progress (bool): If True, displays a progress bar of the download to stderr
354 | """
355 | kwargs['width_per_group'] = 64 * 2
356 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
357 | pretrained, progress, **kwargs)
358 |
359 |
360 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
361 | r"""Wide ResNet-101-2 model from
362 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
363 |
364 | The model is the same as ResNet except for the bottleneck number of channels
365 | which is twice larger in every block. The number of channels in outer 1x1
366 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
367 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
368 |
369 | Args:
370 | pretrained (bool): If True, returns a model pre-trained on ImageNet
371 | progress (bool): If True, displays a progress bar of the download to stderr
372 | """
373 | kwargs['width_per_group'] = 64 * 2
374 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
375 | pretrained, progress, **kwargs)
376 |
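The numbers in the factory names above (`resnet18`, `resnet50`, ...) can be recovered from the `layers` lists passed to `_resnet`. The sketch below assumes the usual counting convention: one stem convolution, the convolutions inside each block (2 per `BasicBlock`, 3 per `Bottleneck`), and the final fully connected layer, with the 1x1 downsample convolutions not counted. The helper names are made up for this note.

```python
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}
EXPANSION = {"BasicBlock": 1, "Bottleneck": 4}


def resnet_depth(block: str, layers: list) -> int:
    """Weighted layers: stem conv + convs inside the blocks + final fc."""
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


def fc_in_features(block: str) -> int:
    """Input width of the final nn.Linear: 512 * block.expansion."""
    return 512 * EXPANSION[block]


print(resnet_depth("BasicBlock", [2, 2, 2, 2]))    # 18
print(resnet_depth("BasicBlock", [3, 4, 6, 3]))    # 34
print(resnet_depth("Bottleneck", [3, 4, 6, 3]))    # 50
print(resnet_depth("Bottleneck", [3, 4, 23, 3]))   # 101
print(resnet_depth("Bottleneck", [3, 8, 36, 3]))   # 152
print(fc_in_features("BasicBlock"), fc_in_features("Bottleneck"))  # 512 2048
```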
--------------------------------------------------------------------------------
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from models.simple import SimpleNet
11 | from torch.autograd import Variable
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.shortcut = nn.Sequential()
25 | if stride != 1 or in_planes != self.expansion*planes:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | out = F.relu(out)
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1):
43 | super(Bottleneck, self).__init__()
44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(planes)
46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
50 |
51 | self.shortcut = nn.Sequential()
52 | if stride != 1 or in_planes != self.expansion*planes:
53 | self.shortcut = nn.Sequential(
54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
55 | nn.BatchNorm2d(self.expansion*planes)
56 | )
57 |
58 | def forward(self, x):
59 | out = F.relu(self.bn1(self.conv1(x)))
60 | out = F.relu(self.bn2(self.conv2(out)))
61 | out = self.bn3(self.conv3(out))
62 | out += self.shortcut(x)
63 | out = F.relu(out)
64 | return out
65 |
66 |
67 | class ResNet(SimpleNet):
68 | def __init__(self, block, num_blocks, num_classes=10, name=None, created_time=None):
69 | super(ResNet, self).__init__(name, created_time)
70 | self.in_planes = 32
71 |
72 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(32)
74 | self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1)
75 | self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2)
76 | self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2)
77 | self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
78 | self.linear = nn.Linear(256*block.expansion, num_classes)
79 |
80 | def _make_layer(self, block, planes, num_blocks, stride):
81 | strides = [stride] + [1]*(num_blocks-1)
82 | layers = []
83 | for stride in strides:
84 | layers.append(block(self.in_planes, planes, stride))
85 | self.in_planes = planes * block.expansion
86 | return nn.Sequential(*layers)
87 |
88 | def forward(self, x):
89 | out = F.relu(self.bn1(self.conv1(x)))
90 | out = self.layer1(out)
91 | out = self.layer2(out)
92 | out = self.layer3(out)
93 | out = self.layer4(out)
94 | out = F.avg_pool2d(out, 4)
95 | out = out.view(out.size(0), -1)
96 | out = self.linear(out)
97 | # for SDTdata
98 | # return F.softmax(out, dim=1)
99 | # for regular output
100 | return out
101 |
102 |
103 | def ResNet18(name=None, created_time=None):
104 | return ResNet(BasicBlock, [2,2,2,2],name='{0}_ResNet_18'.format(name), created_time=created_time)
105 |
106 | def ResNet34(name=None, created_time=None):
107 | return ResNet(BasicBlock, [3,4,6,3],name='{0}_ResNet_34'.format(name), created_time=created_time)
108 |
109 | def ResNet50(name=None, created_time=None):
110 | return ResNet(Bottleneck, [3,4,6,3],name='{0}_ResNet_50'.format(name), created_time=created_time)
111 |
112 | def ResNet101(name=None, created_time=None):
113 | return ResNet(Bottleneck, [3,4,23,3],name='{0}_ResNet'.format(name), created_time=created_time)
114 |
115 | def ResNet152(name=None, created_time=None):
116 | return ResNet(Bottleneck, [3,8,36,3],name='{0}_ResNet'.format(name), created_time=created_time)
117 |
118 |
119 | if __name__ == '__main__':
120 |
121 | net = ResNet18()
122 | y = net(Variable(torch.randn(1,3,32,32)))
123 | print(y.size())
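The CIFAR variant above hard-codes `F.avg_pool2d(out, 4)` and `nn.Linear(256*block.expansion, num_classes)`, which ties it to 32x32 inputs. The arithmetic below is a rough check of that coupling; the helper names are hypothetical, and the 64x64 case is only an illustration of what a larger input would produce.

```python
def conv3x3_out(size: int, stride: int) -> int:
    # 3x3 kernel with padding=1, as used by the strided convs in the blocks above
    return (size + 2 * 1 - 3) // stride + 1


def linear_in_features(input_size: int = 32, expansion: int = 1) -> int:
    """Flattened feature count that reaches self.linear."""
    size = conv3x3_out(input_size, stride=1)   # stem conv1
    for stride in (1, 2, 2, 2):                # layer1 .. layer4
        size = conv3x3_out(size, stride)
    size = size // 4                           # F.avg_pool2d(out, 4)
    return 256 * expansion * size * size


print(linear_in_features(32, expansion=1))  # 256, matches nn.Linear(256, ...)
print(linear_in_features(32, expansion=4))  # 1024, matches nn.Linear(1024, ...)
print(linear_in_features(64, expansion=1))  # 1024, would not match nn.Linear(256, ...)
```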
--------------------------------------------------------------------------------
/models/resnet_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from models.simple import SimpleNet
11 | from models.model import Model
12 | # from torchvision.models.utils import load_state_dict_from_url
13 |
14 |
15 | __all__ = ['ResNet', 'resnet18']
16 |
17 |
18 | model_urls = {
19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
24 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
26 | }
27 |
28 |
29 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
30 | """3x3 convolution with padding"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
32 | padding=dilation, groups=groups, bias=False, dilation=dilation)
33 |
34 |
35 | def conv1x1(in_planes, out_planes, stride=1):
36 | """1x1 convolution"""
37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
44 | base_width=64, dilation=1, norm_layer=None):
45 | super(BasicBlock, self).__init__()
46 | if norm_layer is None:
47 | norm_layer = nn.BatchNorm2d
48 | if groups != 1 or base_width != 64:
49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
50 | if dilation > 1:
51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
53 | self.conv1 = conv3x3(inplanes, planes, stride)
54 | self.bn1 = norm_layer(planes)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.conv2 = conv3x3(planes, planes)
57 | self.bn2 = norm_layer(planes)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | identity = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 |
71 | if self.downsample is not None:
72 | identity = self.downsample(x)
73 |
74 | out += identity
75 | out = self.relu(out)
76 |
77 | return out
78 |
79 |
80 | class Bottleneck(nn.Module):
81 | expansion = 4
82 |
83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
84 | base_width=64, dilation=1, norm_layer=None):
85 | super(Bottleneck, self).__init__()
86 | if norm_layer is None:
87 | norm_layer = nn.BatchNorm2d
88 | width = int(planes * (base_width / 64.)) * groups
89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
90 | self.conv1 = conv1x1(inplanes, width)
91 | self.bn1 = norm_layer(width)
92 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
93 | self.bn2 = norm_layer(width)
94 | self.conv3 = conv1x1(width, planes * self.expansion)
95 | self.bn3 = norm_layer(planes * self.expansion)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | identity = x
102 |
103 | out = self.conv1(x)
104 | out = self.bn1(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | out = self.conv3(out)
112 | out = self.bn3(out)
113 |
114 | if self.downsample is not None:
115 | identity = self.downsample(x)
116 |
117 | out += identity
118 | out = self.relu(out)
119 |
120 | return out
121 |
122 | class ResNet(Model):
123 |
124 | def __init__(self, block, layers, num_classes=1000,
125 | zero_init_residual=False,
126 | groups=1, width_per_group=64,
127 | replace_stride_with_dilation=None,
128 | norm_layer=None):
129 | super(ResNet, self).__init__()
130 | if norm_layer is None:
131 | norm_layer = nn.BatchNorm2d
132 | self._norm_layer = norm_layer
133 |
134 | self.inplanes = 64
135 | self.dilation = 1
136 | if replace_stride_with_dilation is None:
137 | # each element in the tuple indicates if we should replace
138 | # the 2x2 stride with a dilated convolution instead
139 | replace_stride_with_dilation = [False, False, False]
140 | if len(replace_stride_with_dilation) != 3:
141 | raise ValueError("replace_stride_with_dilation should be None "
142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
143 | self.groups = groups
144 | self.base_width = width_per_group
145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
146 | bias=False)
147 | self.bn1 = norm_layer(self.inplanes)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
150 | self.layer1 = self._make_layer(block, 64, layers[0])
151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
152 | dilate=replace_stride_with_dilation[0])
153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
154 | dilate=replace_stride_with_dilation[1])
155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
156 | dilate=replace_stride_with_dilation[2])
157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
158 | self.fc = nn.Linear(512 * block.expansion, num_classes)
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
164 | nn.init.constant_(m.weight, 1)
165 | nn.init.constant_(m.bias, 0)
166 |
167 | # Zero-initialize the last BN in each residual branch,
168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
170 | if zero_init_residual:
171 | for m in self.modules():
172 | if isinstance(m, Bottleneck):
173 | nn.init.constant_(m.bn3.weight, 0)
174 | elif isinstance(m, BasicBlock):
175 | nn.init.constant_(m.bn2.weight, 0)
176 |
177 | # adapt the head to tiny-imagenet-200: fc is rebuilt with 200 outputs, overriding num_classes
178 | self.avgpool = nn.AdaptiveAvgPool2d(1)
179 | num_ftrs = self.fc.in_features
180 | self.fc = nn.Linear(num_ftrs, 200)
181 |
182 |
183 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
184 | norm_layer = self._norm_layer
185 | downsample = None
186 | previous_dilation = self.dilation
187 | if dilate:
188 | self.dilation *= stride
189 | stride = 1
190 | if stride != 1 or self.inplanes != planes * block.expansion:
191 | downsample = nn.Sequential(
192 | conv1x1(self.inplanes, planes * block.expansion, stride),
193 | norm_layer(planes * block.expansion),
194 | )
195 |
196 | layers = []
197 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
198 | self.base_width, previous_dilation, norm_layer))
199 | self.inplanes = planes * block.expansion
200 | for _ in range(1, blocks):
201 | layers.append(block(self.inplanes, planes, groups=self.groups,
202 | base_width=self.base_width, dilation=self.dilation,
203 | norm_layer=norm_layer))
204 |
205 | return nn.Sequential(*layers)
206 |
207 | def forward(self, x):
208 | x = self.conv1(x)
209 | x = self.bn1(x)
210 | x = self.relu(x)
211 | x = self.maxpool(x)
212 |
213 | x = self.layer1(x)
214 | x = self.layer2(x)
215 | x = self.layer3(x)
216 | x = self.layer4(x)
217 |
218 | x = self.avgpool(x)
219 | x = x.reshape(x.size(0), -1)
220 | x = self.fc(x)
221 | return x
222 |
223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
224 | model = ResNet(block,layers, **kwargs)
225 | # if pretrained:
226 | # state_dict = load_state_dict_from_url(model_urls[arch],
227 | # progress=progress)
228 | # model.load_state_dict(state_dict)
229 | return model
230 |
231 |
232 | def resnet18(pretrained=False, progress=True, **kwargs):
233 | """Constructs a ResNet-18 model.
234 |
235 | Args:
236 | pretrained (bool): Ignored here; loading of pre-trained weights is commented out in _resnet
237 | progress (bool): Ignored here; only relevant when pre-trained weights are downloaded
238 | """
239 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
240 | **kwargs)
241 |
242 |
243 | if __name__ == '__main__':
244 |
245 | import torchvision.transforms as transforms
246 | import torchvision.datasets as datasets
247 | import os
248 | import torch.utils.data as data
249 |
250 |
251 | # Data loading code
252 | data_transforms = {
253 | 'train': transforms.Compose([
254 | transforms.Resize(224),
255 | transforms.RandomHorizontalFlip(),
256 | transforms.ToTensor(),
257 | ]),
258 | 'val': transforms.Compose([
259 | transforms.Resize(224),
260 | transforms.ToTensor(),
261 | ]),
262 | }
263 |
264 | data_dir = '../data/tiny-imagenet-200/'
265 |
266 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
267 | data_transforms[x])
268 | for x in ['train', 'val']}
269 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=100,
270 | shuffle=False, num_workers=64)
271 | for x in ['train', 'val']}
272 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
273 |
274 | # ResNet.__init__ takes no name/created_time keyword arguments, so none are passed
275 | target_model = resnet18()
276 |
277 | running_loss = 0.0
278 | running_corrects = 0
279 | running_datasize = 0
280 | # Iterate over data.
281 |
282 | phase='val'
283 |
284 | vis_image=None
285 | criterion = nn.CrossEntropyLoss()
286 | for i, (inputs, labels) in enumerate(dataloaders[phase]):
287 |
288 | output= target_model(inputs)
289 |
290 | break
291 |
292 |
293 |
--------------------------------------------------------------------------------
/models/simple.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torchvision import datasets, transforms
7 | from torch.autograd import Variable
8 | import numpy as np
9 | import datetime
10 |
11 |
12 | class SimpleNet(nn.Module):
13 | def __init__(self, name=None, created_time=None):
14 | super(SimpleNet, self).__init__()
15 | self.created_time = created_time
16 | self.name = name
17 | self.stats = {'epoch': [], 'loss': [], 'acc': []}  # appended to by save_stats()
18 | def train_vis(self, vis, epoch, acc, loss=None, eid='main', is_poisoned=False, name=None):
19 | if name is None:
20 | name = self.name + '_poisoned' if is_poisoned else self.name
21 | vis.line(X=np.array([epoch]), Y=np.array([acc]), name=name, win='train_acc_{0}'.format(self.created_time), env=eid,
22 | update='append' if vis.win_exists('train_acc_{0}'.format(self.created_time), env=eid) else None,
23 | opts=dict(showlegend=True, title='Train Accuracy_{0}'.format(self.created_time),
24 | width=700, height=400))
25 | if loss is not None:
26 | vis.line(X=np.array([epoch]), Y=np.array([loss]), name=name, env=eid,
27 | win='train_loss_{0}'.format(self.created_time),
28 | update='append' if vis.win_exists('train_loss_{0}'.format(self.created_time), env=eid) else None,
29 | opts=dict(showlegend=True, title='Train Loss_{0}'.format(self.created_time), width=700, height=400))
30 | return
31 |
32 | def train_batch_vis(self, vis, epoch, data_len, batch, loss, eid='main', name=None, win='train_batch_loss', is_poisoned=False):
33 | if name is None:
34 | name = self.name + '_poisoned' if is_poisoned else self.name
35 | else:
36 | name = name + '_poisoned' if is_poisoned else name
37 |
38 | vis.line(X=np.array([(epoch-1)*data_len+batch]), Y=np.array([loss]),
39 | env=eid,
40 | name=f'{name}' if name is not None else self.name, win=f'{win}_{self.created_time}',
41 | update='append' if vis.win_exists(f'{win}_{self.created_time}', env=eid) else None,
42 | opts=dict(showlegend=True, width=700, height=400, title='Train Batch loss_{0}'.format(self.created_time)))
43 | def track_distance_batch_vis(self,vis, epoch, data_len, batch, distance_to_global_model,eid,name=None,is_poisoned=False):
44 | x= (epoch-1)*data_len+batch+1
45 |
46 | if name is None:
47 | name = self.name + '_poisoned' if is_poisoned else self.name
48 | else:
49 | name = name + '_poisoned' if is_poisoned else name
50 |
51 |
52 | vis.line(Y=np.array([distance_to_global_model]), X=np.array([x]),
53 | win=f"global_dist_{self.created_time}",
54 | env=eid,
55 | name=f'Model_{name}',
56 | update='append' if
57 | vis.win_exists(f"global_dist_{self.created_time}",
58 | env=eid) else None,
59 | opts=dict(showlegend=True,
60 | title=f"Distance to Global {self.created_time}",
61 | width=700, height=400))
62 | def weight_vis(self,vis,epoch,weight, eid, name,is_poisoned=False):
63 | name = str(name) + '_poisoned' if is_poisoned else name
64 | vis.line(Y=np.array([weight]), X=np.array([epoch]),
65 | win=f"Aggregation_Weight_{self.created_time}",
66 | env=eid,
67 | name=f'Model_{name}',
68 | update='append' if
69 | vis.win_exists(f"Aggregation_Weight_{self.created_time}",
70 | env=eid) else None,
71 | opts=dict(showlegend=True,
72 | title=f"Aggregation Weight {self.created_time}",
73 | width=700, height=400))
74 |
75 | def alpha_vis(self,vis,epoch,alpha, eid, name,is_poisoned=False):
76 | name = str(name) + '_poisoned' if is_poisoned else name
77 | vis.line(Y=np.array([alpha]), X=np.array([epoch]),
78 | win=f"FG_Alpha_{self.created_time}",
79 | env=eid,
80 | name=f'Model_{name}',
81 | update='append' if
82 | vis.win_exists(f"FG_Alpha_{self.created_time}",
83 | env=eid) else None,
84 | opts=dict(showlegend=True,
85 | title=f"FG Alpha {self.created_time}",
86 | width=700, height=400))
87 |
88 | def trigger_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key, trigger_name, trigger_value):
89 | vis.line(Y=np.array([acc]), X=np.array([epoch]),
90 | win=f"poison_trigger_acc_{self.created_time}",
91 | env=eid,
92 | name=f'{agent_name_key}_[{trigger_name}]_{trigger_value}',
93 | update='append' if vis.win_exists(f"poison_trigger_acc_{self.created_time}",
94 | env=eid) else None,
95 | opts=dict(showlegend=True,
96 | title=f"Backdoor Trigger Test Accuracy_{self.created_time}",
97 | width=700, height=400))
98 | if loss is not None:
99 | vis.line(Y=np.array([loss]), X=np.array([epoch]),
100 | win=f"poison_trigger_loss_{self.created_time}",
101 | env=eid,
102 | name=f'{agent_name_key}_[{trigger_name}]_{trigger_value}',
103 | update='append' if vis.win_exists(f"poison_trigger_loss_{self.created_time}",
104 | env=eid) else None,
105 | opts=dict(showlegend=True,
106 | title=f"Backdoor Trigger Test Loss_{self.created_time}",
107 | width=700, height=400))
108 |
109 | def trigger_agent_test_vis(self, vis, epoch, acc, loss, eid, name):
110 | vis.line(Y=np.array([acc]), X=np.array([epoch]),
111 | win=f"poison_state_trigger_acc_{self.created_time}",
112 | env=eid,
113 | name=f'{name}',
114 | update='append' if vis.win_exists(f"poison_state_trigger_acc_{self.created_time}",
115 | env=eid) else None,
116 | opts=dict(showlegend=True,
117 | title=f"Backdoor State Trigger Test Accuracy_{self.created_time}",
118 | width=700, height=400))
119 | if loss is not None:
120 | vis.line(Y=np.array([loss]), X=np.array([epoch]),
121 | win=f"poison_state_trigger_loss_{self.created_time}",
122 | env=eid,
123 | name=f'{name}',
124 | update='append' if vis.win_exists(f"poison_state_trigger_loss_{self.created_time}",
125 | env=eid) else None,
126 | opts=dict(showlegend=True,
127 | title=f"Backdoor State Trigger Test Loss_{self.created_time}",
128 | width=700, height=400))
129 |
130 |
131 | def poison_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key):
132 | name= agent_name_key
133 | # name= f'Model_{name}'
134 |
135 | vis.line(Y=np.array([acc]), X=np.array([epoch]),
136 | win=f"poison_test_acc_{self.created_time}",
137 | env=eid,
138 | name=name,
139 | update='append' if vis.win_exists(f"poison_test_acc_{self.created_time}",
140 | env=eid) else None,
141 | opts=dict(showlegend=True,
142 | title=f"Backdoor Task Accuracy_{self.created_time}",
143 | width=700, height=400))
144 | if loss is not None:
145 | vis.line(Y=np.array([loss]), X=np.array([epoch]),
146 | win=f"poison_loss_acc_{self.created_time}",
147 | env=eid,
148 | name=name,
149 | update='append' if vis.win_exists(f"poison_loss_acc_{self.created_time}",
150 | env=eid) else None,
151 | opts=dict(showlegend=True,
152 | title=f"Backdoor Task Test Loss_{self.created_time}",
153 | width=700, height=400))
154 |
155 | def additional_test_vis(self, vis, epoch, acc, loss, eid, agent_name_key):
156 | name = agent_name_key
157 | vis.line(Y=np.array([acc]), X=np.array([epoch]),
158 | win=f"additional_test_acc_{self.created_time}",
159 | env=eid,
160 | name=name,
161 | update='append' if vis.win_exists(f"additional_test_acc_{self.created_time}",
162 | env=eid) else None,
163 | opts=dict(showlegend=True,
164 | title=f"Additional Test Accuracy_{self.created_time}",
165 | width=700, height=400))
166 | if loss is not None:
167 | vis.line(Y=np.array([loss]), X=np.array([epoch]),
168 | win=f"additional_test_loss_{self.created_time}",
169 | env=eid,
170 | name=name,
171 | update='append' if vis.win_exists(f"additional_test_loss_{self.created_time}",
172 | env=eid) else None,
173 | opts=dict(showlegend=True,
174 | title=f"Additional Test Loss_{self.created_time}",
175 | width=700, height=400))
176 |
177 |
178 | def test_vis(self, vis, epoch, acc, loss, eid, agent_name_key):
179 | name= agent_name_key
180 | # name= f'Model_{name}'
181 |
182 | vis.line(Y=np.array([acc]), X=np.array([epoch]),
183 | win=f"test_acc_{self.created_time}",
184 | env=eid,
185 | name=name,
186 | update='append' if vis.win_exists(f"test_acc_{self.created_time}",
187 | env=eid) else None,
188 | opts=dict(showlegend=True,
189 | title=f"Main Task Test Accuracy_{self.created_time}",
190 | width=700, height=400))
191 | if loss is not None:
192 | vis.line(Y=np.array([loss]), X=np.array([epoch]),
193 | win=f"test_loss_{self.created_time}",
194 | env=eid,
195 | name=name,
196 | update='append' if vis.win_exists(f"test_loss_{self.created_time}",
197 | env=eid) else None,
198 | opts=dict(showlegend=True,
199 | title=f"Main Task Test Loss_{self.created_time}",
200 | width=700, height=400))
201 |
202 |
203 | def save_stats(self, epoch, loss, acc):
204 | self.stats['epoch'].append(epoch)
205 | self.stats['loss'].append(loss)
206 | self.stats['acc'].append(acc)
207 |
208 | def copy_params(self, state_dict, coefficient_transfer=100):
209 |
210 | own_state = self.state_dict()
211 |
212 | for name, param in state_dict.items():
213 | if name in own_state:
214 | shape = param.shape
215 | #random_tensor = (torch.cuda.FloatTensor(shape).random_(0, 100) <= coefficient_transfer).type(torch.cuda.FloatTensor)
216 | # negative_tensor = (random_tensor*-1)+1
217 | # own_state[name].copy_(param)
218 | own_state[name].copy_(param.clone())
219 |
220 |
221 |
222 |
223 | class SimpleMnist(SimpleNet):
224 | def __init__(self, name=None, created_time=None):
225 | super(SimpleMnist, self).__init__(name, created_time)
226 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
227 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
228 | self.conv2_drop = nn.Dropout2d()
229 | self.fc1 = nn.Linear(320, 50)
230 | self.fc2 = nn.Linear(50, 10)
231 |
232 |
233 | def forward(self, x):
234 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
235 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
236 | x = x.view(-1, 320)
237 | x = F.relu(self.fc1(x))
238 | x = F.dropout(x, training=self.training)
239 | x = self.fc2(x)
240 | return F.log_softmax(x, dim=1)
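
The `320` used by `x.view(-1, 320)` and `nn.Linear(320, 50)` in `SimpleMnist` follows from the layer sizes. The sketch below redoes that arithmetic in plain Python, assuming a 1 x 28 x 28 MNIST input; `conv_out` is a helper introduced here for illustration and is not part of the repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                           # assumed MNIST input: 1 x 28 x 28
side = conv_out(side, 5)            # conv1, kernel_size=5
side = conv_out(side, 2, stride=2)  # F.max_pool2d(..., 2)
side = conv_out(side, 5)            # conv2, kernel_size=5
side = conv_out(side, 2, stride=2)  # F.max_pool2d(..., 2)

flat_features = 20 * side * side    # conv2 produces 20 channels
print(side, flat_features)
```

If the input resolution changes, both the `view` call and `fc1` would need the recomputed value.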
--------------------------------------------------------------------------------
/synthesizers/attack_models/autoencoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MNISTAutoencoder(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 | self.encoder = nn.Sequential(
8 | nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
9 | nn.BatchNorm2d(16),
10 | nn.ReLU(True),
11 | nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
12 | nn.Conv2d(16, 64, 3, stride=2, padding=1), # b, 64, 3, 3
13 | nn.BatchNorm2d(64),
14 | nn.ReLU(True),
15 | nn.MaxPool2d(2, stride=1) # b, 64, 2, 2
16 | )
17 | self.decoder = nn.Sequential(
18 | nn.ConvTranspose2d(64, 128, 3, stride=2), # b, 128, 5, 5
19 | nn.BatchNorm2d(128),
20 | nn.ReLU(True),
21 | nn.ConvTranspose2d(128, 64, 5, stride=3, padding=1), # b, 64, 15, 15
22 | nn.BatchNorm2d(64),
23 | nn.ReLU(True),
24 | nn.ConvTranspose2d(64, 1, 2, stride=2, padding=1), # b, 1, 28, 28
25 | nn.BatchNorm2d(1),
26 | nn.Tanh()
27 | )
28 |
29 | def forward(self, x):
30 | x = self.encoder(x)
31 | x = self.decoder(x)
32 | return x
33 |
34 | class Autoencoder(nn.Module):
35 | def __init__(self):
36 | super(Autoencoder, self).__init__()
37 | self.encoder = nn.Sequential(
38 | nn.Conv2d(3, 16, 4, stride=2, padding=1),
39 | nn.BatchNorm2d(16),
40 | nn.ReLU(True),
41 | nn.Conv2d(16, 32, 4, stride=2, padding=1),
42 | nn.BatchNorm2d(32),
43 | nn.ReLU(True),
44 | nn.Conv2d(32, 64, 4, stride=2, padding=1),
45 | nn.BatchNorm2d(64),
46 | nn.ReLU(True),
47 | nn.Conv2d(64, 128, 4, stride=2, padding=1),
48 | nn.BatchNorm2d(128),
49 | nn.ReLU(True)
50 | )
51 | self.decoder = nn.Sequential(
52 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
53 | nn.BatchNorm2d(64),
54 | nn.ReLU(True),
55 | nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
56 | nn.BatchNorm2d(32),
57 | nn.ReLU(True),
58 | nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
59 | nn.BatchNorm2d(16),
60 | nn.ReLU(True),
61 | nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),
62 | nn.Tanh()
63 | )
64 |
65 | def forward(self, x):
66 | x = self.encoder(x)
67 | x = self.decoder(x)
68 | return x
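
The shapes flowing through `MNISTAutoencoder` can be checked without running PyTorch. This is a minimal sketch assuming a 1 x 28 x 28 input and the default `dilation=1` and `output_padding=0`; the helpers `conv_out`, `deconv_out`, and `trace_mnist_autoencoder` are illustrative names, not part of the repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after Conv2d or MaxPool2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0):
    """Side length after ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def trace_mnist_autoencoder(side=28):
    """Return (channels, side) after each spatial layer of the autoencoder."""
    shapes = []
    side = conv_out(side, 3, stride=3, padding=1)
    shapes.append((16, side))    # Conv2d(1, 16, 3, stride=3, padding=1)
    side = conv_out(side, 2, stride=2)
    shapes.append((16, side))    # MaxPool2d(2, stride=2)
    side = conv_out(side, 3, stride=2, padding=1)
    shapes.append((64, side))    # Conv2d(16, 64, 3, stride=2, padding=1)
    side = conv_out(side, 2, stride=1)
    shapes.append((64, side))    # MaxPool2d(2, stride=1)
    side = deconv_out(side, 3, stride=2)
    shapes.append((128, side))   # ConvTranspose2d(64, 128, 3, stride=2)
    side = deconv_out(side, 5, stride=3, padding=1)
    shapes.append((64, side))    # ConvTranspose2d(128, 64, 5, stride=3, padding=1)
    side = deconv_out(side, 2, stride=2, padding=1)
    shapes.append((1, side))     # ConvTranspose2d(64, 1, 2, stride=2, padding=1)
    return shapes


for channels, side in trace_mnist_autoencoder():
    print(channels, side, side)
```

The last entry returns to a single 28 x 28 channel, which is what lets the decoder output be added to the input image as a perturbation.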
--------------------------------------------------------------------------------
/synthesizers/attack_models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def double_conv(in_channels, out_channels):
6 | return nn.Sequential(
7 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
8 | nn.BatchNorm2d(out_channels),
9 | nn.ReLU(inplace=True),
10 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
11 | nn.BatchNorm2d(out_channels),
12 | nn.ReLU(inplace=True)
13 | )
14 |
15 |
16 | class UNet(nn.Module):
17 |
18 | def __init__(self, out_channel):
19 | super().__init__()
20 |
21 | self.dconv_down1 = double_conv(3, 64)
22 | self.dconv_down2 = double_conv(64, 128)
23 | self.dconv_down3 = double_conv(128, 256)
24 | self.dconv_down4 = double_conv(256, 512)
25 |
26 | self.maxpool = nn.AvgPool2d(2)  # average pooling, despite the attribute name
27 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear',
28 | align_corners=True)
29 |
30 | self.dconv_up3 = double_conv(256 + 512, 256)
31 | self.dconv_up2 = double_conv(128 + 256, 128)
32 | self.dconv_up1 = double_conv(128 + 64, 64)
33 |
34 | self.conv_last = nn.Sequential(
35 | nn.Conv2d(64, out_channel, 1),
36 | nn.BatchNorm2d(out_channel),
37 | )
38 |
39 | def forward(self, x):
40 | conv1 = self.dconv_down1(x)
41 | x = self.maxpool(conv1)
42 |
43 | conv2 = self.dconv_down2(x)
44 | x = self.maxpool(conv2)
45 |
46 | conv3 = self.dconv_down3(x)
47 | x = self.maxpool(conv3)
48 |
49 | x = self.dconv_down4(x)
50 |
51 | x = self.upsample(x)
52 | x = torch.cat([x, conv3], dim=1)
53 |
54 | x = self.dconv_up3(x)
55 | x = self.upsample(x)
56 | x = torch.cat([x, conv2], dim=1)
57 |
58 | x = self.dconv_up2(x)
59 | x = self.upsample(x)
60 | x = torch.cat([x, conv1], dim=1)
61 |
62 | x = self.dconv_up1(x)
63 |
64 | out = self.conv_last(x)
65 |
66 | out = torch.tanh(out)
67 |
68 | return out
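
The input widths of `dconv_up3`, `dconv_up2`, and `dconv_up1` come from concatenating the upsampled feature map with the matching encoder output. The sketch below reproduces that bookkeeping in plain Python, assuming a square input; `unet_decoder_shapes` is a hypothetical helper written for this note. It also shows why the input side should be divisible by 8: three rounds of 2x pooling followed by 2x upsampling only line up with the skip connections in that case.

```python
def unet_decoder_shapes(side=32):
    """Return (concat_channels, out_channels, side) for each decoder stage."""
    if side % 8 != 0:
        # AvgPool2d(2) floors odd sizes, so the 2x upsampled map would no
        # longer match the skip connection it is concatenated with.
        raise ValueError("input side must be divisible by 8")

    # (channels, side) of the encoder outputs kept as skip connections
    skips = [(64, side), (128, side // 2), (256, side // 4)]
    channels, cur = 512, side // 8          # output of dconv_down4

    stages = []
    for skip_channels, skip_side in reversed(skips):
        cur *= 2                            # nn.Upsample(scale_factor=2)
        assert cur == skip_side
        concat = channels + skip_channels   # torch.cat(..., dim=1)
        channels = skip_channels            # double_conv output width
        stages.append((concat, channels, cur))
    return stages


for concat, out, side in unet_decoder_shapes(32):
    print(concat, out, side)
```

For a 32 x 32 input the concatenated widths are 768, 384, and 192, matching `256 + 512`, `128 + 256`, and `128 + 64` in the constructor.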
--------------------------------------------------------------------------------
/synthesizers/complex_synthesizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from synthesizers.pattern_synthesizer import PatternSynthesizer
4 |
5 |
6 | class ComplexSynthesizer(PatternSynthesizer):
7 | """
8 | For physical backdoors it's OK to train using a pixel pattern that
9 | represents the physical object in the real scene.
10 | """
11 |
12 | pattern_tensor = torch.tensor([[1.]])
13 |
14 | def synthesize_labels(self, batch, attack_portion=None):
15 | batch.labels[:attack_portion] = batch.aux[:attack_portion]
16 | return
17 |
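
`synthesize_labels` relies on slice assignment, so `attack_portion=None` relabels the whole batch while an integer relabels only the leading samples. The sketch below illustrates this with plain Python lists standing in for the label tensors; the free function `relabel` and the sample values are made up for illustration.

```python
def relabel(labels, aux, attack_portion=None):
    """Overwrite the first `attack_portion` labels with the auxiliary labels."""
    labels = list(labels)                            # copy so the caller's list is untouched
    labels[:attack_portion] = aux[:attack_portion]   # [:None] covers the whole list
    return labels


clean = [0, 1, 2, 3]
aux = [9, 9, 9, 9]

partial = relabel(clean, aux, attack_portion=2)
full = relabel(clean, aux)
print(partial, full)
```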
--------------------------------------------------------------------------------
/synthesizers/iba_synthesizer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from torchvision.transforms import transforms, functional
5 |
6 | from synthesizers.synthesizer import Synthesizer
7 | from tasks.task import Task
8 | import copy
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
12 | from termcolor import colored
13 |
14 | from torch.nn.utils import parameters_to_vector, vector_to_parameters
15 |
16 |
17 | transform_to_image = transforms.ToPILImage()
18 | transform_to_tensor = transforms.ToTensor()
19 |
20 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
21 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
22 | IMAGENET_MIN = ((np.array([0,0,0]) - np.array(IMAGENET_DEFAULT_MEAN)) / np.array(IMAGENET_DEFAULT_STD)).min()
23 | IMAGENET_MAX = ((np.array([1,1,1]) - np.array(IMAGENET_DEFAULT_MEAN)) / np.array(IMAGENET_DEFAULT_STD)).max()
24 |
25 |
26 | def get_clip_image(dataset="cifar10"):
27 | if dataset in ['timagenet', 'tiny-imagenet32', 'Imagenet']:
28 | def clip_image(x):
29 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX)
30 | elif dataset in ['cifar10', 'Cifar10']:
31 | def clip_image(x):
32 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX)
33 | elif dataset in ['mnist', 'MNIST']:
34 | def clip_image(x):
35 | return torch.clamp(x, -1.0, 1.0)
36 | elif dataset == 'gtsrb':
37 | def clip_image(x):
38 | return torch.clamp(x, IMAGENET_MIN, IMAGENET_MAX)
39 | else:
40 | raise Exception(f'Invalid dataset: {dataset}')
41 | return clip_image
42 |
43 | def apply_grad_mask(model, mask_grad_list):
44 | mask_grad_list_copy = iter(mask_grad_list)
45 | # print(f"mask_grad_list_copy: {mask_grad_list_copy}")
46 | for name, parms in model.named_parameters():
47 | if parms.requires_grad:
48 | parms.grad = parms.grad * next(mask_grad_list_copy)
49 |
50 |
51 | class IBASynthesizer(Synthesizer):
52 |
53 | def __init__(self, task: Task):
54 | super().__init__(task)
55 | # self.make_pattern(self.pattern_tensor, self.x_top, self.y_top)
56 | tgt_model_path = "/home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers/checkpoint/lira/lira_cifar10_vgg9_0.03.pt"
57 | self.atk_model = self.load_adversarial_model(tgt_model_path, self.params.task, self.params.device)
58 |
59 | def get_target_transform(self, target_label, mode="all2one", num_classes=10):
60 | """
61 | Get target transform function
62 | """
63 | if mode == "all2one":
64 | target_transform = lambda x: torch.ones_like(x) * target_label
65 | elif mode == 'all2all':
66 | target_transform = lambda x: (x + 1) % num_classes
67 | else:
68 | raise Exception(f'Invalid mode {mode}')
69 |
70 | return target_transform
71 |
72 | def create_trigger_model(self, dataset, device="cpu", attack_model=None):
73 | """ Create trigger model for IBA """
74 | # print(f"device: {device}")
75 | # /home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers
76 | if dataset in ['cifar10', 'Cifar10']:
77 | from synthesizers.attack_models.unet import UNet
78 | atkmodel = UNet(3).to(device)
79 | elif dataset in ['mnist', 'MNIST']:
80 | from synthesizers.attack_models.autoencoders import MNISTAutoencoder as Autoencoder
81 | atkmodel = Autoencoder().to(device)
82 | elif dataset == 'timagenet' or dataset == 'tiny-imagenet32' or dataset == 'gtsrb' or dataset == 'Imagenet':
83 | if attack_model is None:
84 | from synthesizers.attack_models.autoencoders import Autoencoder
85 | atkmodel = Autoencoder().to(device)
86 | elif attack_model == 'unet':
87 | from synthesizers.attack_models.unet import UNet
88 | atkmodel = UNet(3).to(device)
89 | else:
90 | raise Exception(f'Invalid atk model {dataset}')
91 | return atkmodel
92 |
93 | def load_adversarial_model(self, checkpoint_path, dataset, device):
94 | # print(f"CKPT file for attack model: {checkpoint_path}")
95 | atkmodel = self.create_trigger_model(dataset, device)
96 | atkmodel.load_state_dict(torch.load(checkpoint_path, map_location=device))
97 | # print(colored("Load attack model successfully!", "red"))
98 | return atkmodel
99 |
100 | def get_poison_batch_adversarial(self, bptt, dataset, device, evaluation=False, target_transform=None, atk_model=None, ratio=0.75, atk_eps=0.75):
101 | images, targets = bptt
102 | poisoning_per_batch = int(ratio*len(images))
103 | images = images.to(device)
104 | targets = targets.to(device)
105 | poison_count= 0
106 |
107 | clip_image = get_clip_image(dataset)
108 |
109 | with torch.no_grad():
110 | noise = atk_model(images) * atk_eps
111 | atkdata = clip_image(images + noise)
112 | if target_transform:
113 | atktarget = target_transform(targets)
114 | atktarget = targets.to(device)
115 | if not evaluation:
116 | atkdata = atkdata[:poisoning_per_batch]
117 | atktarget = atktarget[:poisoning_per_batch]
118 | poison_count = len(atkdata)
119 |
120 | return atkdata.to(device), atktarget.to(device), poison_count
121 |
122 | def make_pattern(self, pattern_tensor, x_top, y_top):
123 | # TRAIN IBA
124 | # for e in range(self.params.iba_epochs):
125 | # self.train_iba(self.task.model, self.task.atk_model, self.task.tgt_model, self.task.optimizer, self.task.atkmodel_optimizer, self.task.train_loader,
126 | # self.task.criterion, atkmodel_train=True, device=self.params.device, logger=self.params.logger,
127 | # adv_optimizer=self.task.adv_optimizer, clip_image=get_clip_image(self.params.dataset), target_transform=self.task.target_transform,
128 | # dataset=self.params.dataset, mu=self.params.mu, aggregator=self.params.aggregator, attack_alpha=self.params.attack_alpha,
129 | # attack_portion=self.params.attack_portion, atk_eps=self.params.atk_eps, pgd_attack=self.params.pgd_attack,
130 | # proj=self.params.proj, pgd_eps=self.params.pgd_eps, project_frequency=self.params.project_frequency,
131 | # mask_grad_list=self.task.mask_grad_list, model_original=self.task.model_original, local_e=e)
132 | # self.task.copy_params(self.task.model, self.task.atk_model)
133 | # Load model from file and train IBA
134 | # from torch import optim
135 | # atk_optimizer = optim.Adam(tgt_model.parameters(), lr=0.00005)
136 | # tgt_model_path = "/home/vishc2/tuannm/fedlearn-backdoor-attacks/synthesizers/checkpoint/lira/lira_cifar10_vgg9_0.03.pt"
137 | # target_transform = self.get_target_transform(0, mode="all2one", num_classes=10)
138 |
139 | # # task in [Imagenet, Cifar10, MNIST]
140 |
141 | # tgt_model = self.load_adversarial_model(tgt_model_path, self.params.task, self.params.device)
142 |
143 | # target_transform = self.get_target_transform(self.params.backdoor_label, mode="all2one", num_classes=self.params.num_classes)
144 |
145 | # poison_data, poison_target, poison_num = self.get_poison_batch_adversarial(batch_cp, dataset=dataset, device=device,
146 | # target_transform=target_transform, atk_model=atk_model)
147 |
148 | # self.train_lira(net, tgt_model, None, None, atk_optimizer, local_train_dl,
149 | # criterion, 0.5, 0.5, 1.0, None, tgt_tf, 1,
150 | # atkmodel_train=True, device=self.device, pgd_attack=False)
151 |
152 | pass
153 | def train_iba(self, model, atkmodel, tgtmodel, optimizer, atkmodel_optimizer, train_loader, criterion, atkmodel_train=False,
154 | device=None, logger=None, adv_optimizer=None, clip_image=None, target_transform=None, dataset=None,
155 | mu=0.1, aggregator="fedprox", attack_alpha=1.0, attack_portion=1.0, atk_eps=0.1, pgd_attack=False,
156 | proj="l_inf", pgd_eps=0.3, project_frequency=1, mask_grad_list=None, model_original=None, local_e=0):
157 |
158 | wg_clone = copy.deepcopy(model)
159 | loss_fn = nn.CrossEntropyLoss()
160 | func_fn = loss_fn
161 |
162 | correct_clean = 0
163 | correct_poison = 0
164 |
165 | poison_size = 0
166 | clean_size = 0
167 | loss_list = []
168 |
169 | if not atkmodel_train:
170 | model.train()
171 | # Sub-training phase
172 | for batch_idx, batch in enumerate(train_loader):
173 | data, targets = batch
174 | bs = len(data)  # number of samples; len(batch) would be 2 for an (images, targets) pair
175 | # data, target = data.to(device), target.to(device)
176 | # clean_images, clean_targets, poison_images, poison_targets, poisoning_per_batch = get_poison_batch(batch, attack_portion)
177 | clean_images, clean_targets = copy.deepcopy(data).to(device), copy.deepcopy(targets).to(device)
178 | poison_images, poison_targets = copy.deepcopy(data).to(device), copy.deepcopy(targets).to(device)
179 | # dataset_size += len(data)
180 | clean_size += len(clean_images)
181 | optimizer.zero_grad()
182 | if pgd_attack:
183 | adv_optimizer.zero_grad()
184 | output = model(clean_images)
185 | loss_clean = loss_fn(output, clean_targets)
186 |
187 | if attack_alpha == 1.0:
188 | optimizer.zero_grad()
189 | loss_clean.backward()
190 | if not pgd_attack:
191 | optimizer.step()
192 | else:
193 | if proj == "l_inf":
194 | w = list(model.parameters())
195 | # adversarial learning rate
196 | eta = 0.001
197 | for i in range(len(w)):
198 | # uncomment below line to restrict proj to some layers
199 | if True:#i == 6 or i == 8 or i == 10 or i == 0 or i == 18:
200 | w[i].data = w[i].data - eta * w[i].grad.data
201 | # projection step
202 | m1 = torch.lt(torch.sub(w[i], model_original[i]), -pgd_eps)
203 | m2 = torch.gt(torch.sub(w[i], model_original[i]), pgd_eps)
204 | w1 = (model_original[i] - pgd_eps) * m1
205 | w2 = (model_original[i] + pgd_eps) * m2
206 | w3 = (w[i]) * (~(m1+m2))
207 | wf = w1+w2+w3
208 | w[i].data = wf.data
209 | else:
210 | # do l2_projection
211 | adv_optimizer.step()
212 | w = list(model.parameters())
213 | w_vec = parameters_to_vector(w)
214 | model_original_vec = parameters_to_vector(model_original)
215 | # make sure you project on last iteration otherwise, high LR pushes you really far
216 | # Start
217 | if (batch_idx%project_frequency == 0 or batch_idx == len(train_loader)-1) and (torch.norm(w_vec - model_original_vec) > pgd_eps):
218 | # project back into norm ball
219 | w_proj_vec = pgd_eps*(w_vec - model_original_vec)/torch.norm(
220 | w_vec-model_original_vec) + model_original_vec
221 | # plug w_proj back into model
222 | vector_to_parameters(w_proj_vec, w)
223 |
224 | pred = output.data.max(1)[1] # get the index of the max log-probability
225 | loss_list.append(loss_clean.item())
226 | correct_clean += pred.eq(clean_targets.data.view_as(pred)).cpu().sum().item()
227 | else:
228 | if attack_alpha < 1.0:
229 | poison_size += len(poison_images)
230 | # poison_images, poison_targets = poison_images.to(device), poison_targets.to(device)
231 | with torch.no_grad():
232 | noise = tgtmodel(poison_images) * atk_eps
233 | atkdata = clip_image(poison_images + noise)
234 | atktarget = target_transform(poison_targets)
235 | # atkdata.requires_grad_(False)
236 | # atktarget.requires_grad_(False)
237 | if attack_portion < 1.0:
238 | atkdata = atkdata[:int(attack_portion*bs)]
239 | atktarget = atktarget[:int(attack_portion*bs)]
240 | atkoutput = model(atkdata.detach())
241 | loss_poison = F.cross_entropy(atkoutput, atktarget.detach())
242 | else:
243 | loss_poison = torch.tensor(0.0).to(device)
244 | loss2 = loss_clean * attack_alpha + (1.0 - attack_alpha) * loss_poison
245 |
246 | optimizer.zero_grad()
247 | loss2.backward()
248 | if mask_grad_list:
249 | apply_grad_mask(model, mask_grad_list)
250 | if not pgd_attack:
251 | optimizer.step()
252 | else:
253 | if proj == "l_inf":
254 | w = list(model.parameters())
255 | # adversarial learning rate
256 | eta = 0.001
257 | for i in range(len(w)):
258 | # uncomment below line to restrict proj to some layers
259 | if True:#i == 6 or i == 8 or i == 10 or i == 0 or i == 18:
260 | w[i].data = w[i].data - eta * w[i].grad.data
261 | # projection step
262 | m1 = torch.lt(torch.sub(w[i], model_original[i]), -pgd_eps)
263 | m2 = torch.gt(torch.sub(w[i], model_original[i]), pgd_eps)
264 | w1 = (model_original[i] - pgd_eps) * m1
265 | w2 = (model_original[i] + pgd_eps) * m2
266 | w3 = (w[i]) * (~(m1+m2))
267 | wf = w1+w2+w3
268 | w[i].data = wf.data
269 | else:
270 | # do l2_projection
271 | adv_optimizer.step()
272 | w = list(model.parameters())
273 | w_vec = parameters_to_vector(w)
274 | model_original_vec = parameters_to_vector(list(model_original.parameters()))
275 | # also project on the last iteration; otherwise a high LR pushes the weights far outside the norm ball
276 | if (local_e%project_frequency == 0 and batch_idx == len(train_loader)-1) and (torch.norm(w_vec - model_original_vec) > pgd_eps):
277 | # project back into norm ball
278 | w_proj_vec = pgd_eps*(w_vec - model_original_vec)/torch.norm(
279 | w_vec-model_original_vec) + model_original_vec
280 |
281 |
282 | # plug w_proj back into model
283 | vector_to_parameters(w_proj_vec, w)
284 |
285 | loss_list.append(loss2.item())
286 | pred = output.data.max(1)[1] # get the index of the max log-probability
287 | poison_pred = atkoutput.data.max(1)[1] # get the index of the max log-probability
288 |
289 | # correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
290 | correct_clean += pred.eq(clean_targets.data.view_as(pred)).cpu().sum().item()
291 | correct_poison += poison_pred.eq(atktarget.data.view_as(poison_pred)).cpu().sum().item()
292 |
293 | else:
294 | model.eval()
295 | # atk_optimizer = optim.Adam(atkmodel.parameters(), lr=0.0002)
296 | atkmodel.train()
297 | # optimizer.zero_grad()
298 | for batch_idx, (batch) in enumerate(train_loader):
299 | batch_cp = copy.deepcopy(batch)
300 | data, target = batch
301 | # print(f"len(clean_data): {len(clean_data)}")
302 | data, target = data.to(device), target.to(device)
303 | bs = data.size(0)
304 | atkdata, atktarget, poison_num = get_poison_batch_adversarial_updated(batch_cp, dataset=dataset, device=device,
305 | target_transform=target_transform, atk_model=atkmodel)
306 | # dataset_size += len(data)
307 | poison_size += poison_num
308 |
309 | #################################
310 | #### Update the attack model ####
311 | #################################
312 | atkoutput = model(atkdata)
313 | loss_p = func_fn(atkoutput, atktarget)
314 | loss2 = loss_p
315 |
316 | atkmodel_optimizer.zero_grad()
317 | loss2.backward()
318 | atkmodel_optimizer.step()
319 | pred = atkoutput.data.max(1)[1] # get the index of the max log-probability
320 | correct_poison += pred.eq(atktarget.data.view_as(pred)).cpu().sum().item()
321 | loss_list.append(loss2.item())
322 |
323 | # acc = 100.0 * (float(correct) / float(dataset_size))
324 | clean_acc = 100.0 * (float(correct_clean)/float(clean_size)) if clean_size else 0.0
325 | poison_acc = 100.0 * (float(correct_poison)/float(poison_size)) if poison_size else 0.0
326 | # poison_acc = 100.0 - poison_acc
327 |
328 | training_avg_loss = sum(loss_list)/len(loss_list) if loss_list else 0.0
329 | # training_avg_loss = 0.0
330 | if atkmodel_train:
331 | logger.info(colored("Training loss = {:.2f}, acc = {:.2f} of atk model this epoch".format(training_avg_loss, poison_acc), "yellow"))
332 | else:
333 | logger.info(colored("Training loss = {:.2f}, acc = {:.2f} of cls model this epoch".format(training_avg_loss, clean_acc), "yellow"))
334 | logger.info("Training clean_acc is {:.2f}, poison_acc = {:.2f}".format(clean_acc, poison_acc))
335 | del wg_clone
336 |
337 |
338 | def synthesize_inputs(self, batch, attack_portion=128):
339 | # TODO something
340 |
341 | # batch.inputs[:attack_portion] = batch.inputs[:attack_portion]
342 | # from torch import optim
343 | # atk_optimizer = optim.Adam(tgt_model.parameters(), lr=0.00005)
344 |
345 |
346 |
347 | # target_transform = self.get_target_transform(self.params.backdoor_label, mode="all2one", num_classes=self.params.num_classes)
348 |
349 |
350 | images = batch.inputs[:attack_portion]
351 |
352 |
353 | device = self.params.device
354 | images = images.to(device)
355 |
356 | clip_image = get_clip_image(self.params.task)
357 | atk_eps = 0.75
358 |
359 | with torch.no_grad():
360 | noise = self.atk_model(images) * atk_eps
361 | atkdata = clip_image(images + noise)
362 |
363 |
364 | batch.inputs[:attack_portion] = atkdata.to(device)
365 |
366 |
367 | def synthesize_labels(self, batch, attack_portion=None):
368 | batch.labels[:attack_portion].fill_(self.params.backdoor_label)
369 |
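Note on the projection steps in the training loop above: both branches keep the local weights within `pgd_eps` of the reference weights, either coordinate-wise (`l_inf`) or by rescaling the whole difference vector (L2). The following is a minimal, framework-free sketch of those two projections. It uses plain Python lists in place of tensors, and the helper names `project_linf` and `project_l2` are hypothetical, not functions from this repository.

```python
import math


def project_linf(weights, reference, eps):
    """Clamp each weight into [reference - eps, reference + eps]."""
    return [min(max(w, r - eps), r + eps) for w, r in zip(weights, reference)]


def project_l2(weights, reference, eps):
    """Project weights onto the L2 ball of radius eps centred at reference."""
    delta = [w - r for w, r in zip(weights, reference)]
    norm = math.sqrt(sum(d * d for d in delta))
    if norm <= eps:
        return list(weights)  # already inside the ball, nothing to change
    scale = eps / norm  # shrink the difference vector to length eps
    return [r + scale * d for r, d in zip(reference, delta)]


reference = [0.0, 0.0]
clamped = project_linf([3.0, -0.2], reference, eps=1.0)
projected = project_l2([3.0, 4.0], reference, eps=1.0)
untouched = project_l2([0.3, 0.4], reference, eps=1.0)
```

With tensors, the coordinate-wise case is what the `m1`/`m2` masks compute, and the L2 case corresponds to the `pgd_eps*(w_vec - model_original_vec)/torch.norm(...) + model_original_vec` expression.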
--------------------------------------------------------------------------------
/synthesizers/pattern_synthesizer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from torchvision.transforms import transforms, functional
5 |
6 | from synthesizers.synthesizer import Synthesizer
7 | from tasks.task import Task
8 |
9 | transform_to_image = transforms.ToPILImage()
10 | transform_to_tensor = transforms.ToTensor()
11 |
12 |
13 | class PatternSynthesizer(Synthesizer):
14 | pattern_tensor: torch.Tensor = torch.tensor([
15 | [1., 0., 1.],
16 | [-10., 1., -10.],
17 | [-10., -10., 0.],
18 | [-10., 1., -10.],
19 | [1., 0., 1.]
20 | ])
21 | "Just some random 2D pattern."
22 |
23 | x_top = 3
24 | "X coordinate to put the backdoor into."
25 | y_top = 23
26 | "Y coordinate to put the backdoor into."
27 |
28 | mask_value = -10
29 | "A tensor coordinate with this value won't be applied to the image."
30 |
31 | resize_scale = (5, 10)
32 | "If the pattern is dynamically placed, resize the pattern."
33 |
34 | mask: torch.Tensor = None
35 | "A mask used to combine backdoor pattern with the original image."
36 |
37 | pattern: torch.Tensor = None
38 | "A tensor of the `input.shape` filled with `mask_value` except backdoor."
39 |
40 | def __init__(self, task: Task):
41 | super().__init__(task)
42 | self.make_pattern(self.pattern_tensor, self.x_top, self.y_top)
43 |
44 | def make_pattern(self, pattern_tensor, x_top, y_top):
45 | # put pattern into the image
46 | full_image = torch.zeros(self.params.input_shape)
47 | full_image.fill_(self.mask_value)
48 | # full image has a pixel value of -10
49 | x_bot = x_top + pattern_tensor.shape[0]
50 | y_bot = y_top + pattern_tensor.shape[1]
51 |
52 | if x_bot >= self.params.input_shape[1] or \
53 | y_bot >= self.params.input_shape[2]:
54 | raise ValueError(f'Position of backdoor outside image limits: '
55 | f'image: {self.params.input_shape}, but backdoor '
56 | f'ends at ({x_bot}, {y_bot})')
57 |
58 | full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor
59 | # full image has a pixel value of -10 except for the backdoor (pattern_tensor) size: 5 * 3
60 | self.mask = 1 * (full_image != self.mask_value).to(self.params.device) # (0, 1)
61 | # mask is a tensor of 0 and 1, 0 for -10 and 1 for other values
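62 | self.pattern = self.task.normalize(full_image).to(self.params.device) # value range (-52.5678, 2.7537) after normalization
63 | # backdoor pattern has a ->- shape
64 | # mask layout (1 = pattern pixel is applied, 0 = original image pixel is kept):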
65 | ### 1 1 1
66 | ### 0 1 0
67 | ### 0 0 1
68 | ### 0 1 0
69 | ### 1 1 1
70 | # if pattern_tensor size is (11, 7) = (5, 3) * resize; then pattern size is square of 11 * 11
71 |
72 | # import IPython; IPython.embed()
73 | # exit(0)
74 |
75 | def synthesize_inputs(self, batch, attack_portion=None):
76 | pattern, mask = self.get_pattern()
77 | # mask value (0, 1); value 0, keep the original image; value 1, replace with pattern
78 | batch.inputs[:attack_portion] = (1 - mask) * \
79 | batch.inputs[:attack_portion] + \
80 | mask * pattern
81 |
82 | return
83 |
84 | def synthesize_labels(self, batch, attack_portion=None):
85 | batch.labels[:attack_portion].fill_(self.params.backdoor_label)
86 |
87 | return
88 |
89 | def get_pattern(self):
90 | if self.params.backdoor_dynamic_position:
91 | resize = random.randint(self.resize_scale[0], self.resize_scale[1])
92 | pattern = self.pattern_tensor
93 | if random.random() > 0.5:
94 | pattern = functional.hflip(pattern)
95 | image = transform_to_image(pattern)
96 | pattern = transform_to_tensor(
97 | functional.resize(image,
98 | resize, interpolation=0)).squeeze()
99 |
100 | x = random.randint(0, self.params.input_shape[1] \
101 | - pattern.shape[0] - 1)
102 | y = random.randint(0, self.params.input_shape[2] \
103 | - pattern.shape[1] - 1)
104 | self.make_pattern(pattern, x, y)
105 |
106 | return self.pattern, self.mask
107 |
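Note on `PatternSynthesizer`: the mask is 1 wherever the pattern holds a real value and 0 wherever it holds `mask_value`, and `synthesize_inputs` blends with `(1 - mask) * input + mask * pattern`. A small sketch of that blending rule follows. It assumes plain nested lists instead of tensors and skips the normalization step; `make_mask` and `blend` are illustrative names only.

```python
MASK_VALUE = -10.0  # cells with this value leave the image untouched


def make_mask(full_image):
    """Return 1 where the pattern is defined and 0 where MASK_VALUE is stored."""
    return [[0 if value == MASK_VALUE else 1 for value in row] for row in full_image]


def blend(image, pattern, mask):
    """Keep the image where mask is 0, overwrite it with the pattern where mask is 1."""
    return [
        [(1 - m) * x + m * p for x, p, m in zip(image_row, pattern_row, mask_row)]
        for image_row, pattern_row, mask_row in zip(image, pattern, mask)
    ]


pattern = [[MASK_VALUE, 1.0],
           [0.0, MASK_VALUE]]
image = [[0.5, 0.5],
         [0.5, 0.5]]
mask = make_mask(pattern)
blended = blend(image, pattern, mask)
```

Because masked-out cells are multiplied by 0, the placeholder value `-10` never reaches the image, so any value that cannot occur in a real pattern works as the sentinel.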
--------------------------------------------------------------------------------
/synthesizers/physical_synthesizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from synthesizers.synthesizer import Synthesizer
3 |
4 |
5 | class PhysicalSynthesizer(Synthesizer):
6 | """
7 | For physical backdoors it's ok to train using pixel pattern that
8 | represents the physical object in the real scene.
9 | """
10 |
11 | pattern_tensor = torch.tensor([[1.]])
--------------------------------------------------------------------------------
/synthesizers/singlepixel_synthesizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from synthesizers.pattern_synthesizer import PatternSynthesizer
3 |
4 |
5 | class SinglePixelSynthesizer(PatternSynthesizer):
6 | pattern_tensor = torch.tensor([[1.]])
7 |
--------------------------------------------------------------------------------
/synthesizers/synthesizer.py:
--------------------------------------------------------------------------------
1 | from tasks.batch import Batch
2 | from tasks.task import Task
3 | from utils.parameters import Params
4 |
5 |
6 | class Synthesizer:
7 | params: Params
8 | task: Task
9 |
10 | def __init__(self, task: Task):
11 | self.task = task
12 | self.params = task.params
13 |
14 | def make_backdoor_batch(self, batch: Batch, test=False, attack=True) -> Batch:
15 |
16 | # Don't attack if only normal loss task.
17 | if not attack:
18 | return batch
19 |
20 | if test:
21 | attack_portion = batch.batch_size
22 | else:
23 | attack_portion = round(
24 | batch.batch_size * self.params.poisoning_proportion)
25 |
26 | backdoored_batch = batch.clone()
27 | self.apply_backdoor(backdoored_batch, attack_portion)
28 | return backdoored_batch
29 |
30 | ### Debug helper (disabled): plot the backdoored image and the original image
31 | # import IPython; IPython.embed(); exit(0)
32 | # batch.inputs.shape = (batch_size, 3, 32, 32) = (64, 3, 32, 32)
33 | # using torch to show the image
34 | # import matplotlib.pyplot as plt
35 | # import numpy as np
36 | # import torchvision
37 | # import torchvision.transforms as transforms
38 | # def imshow(img):
39 | # # img to cpu
40 | # img = img.cpu()
41 | # # img = img / 2 + 0.5 # unnormalize
42 | # npimg = img.numpy()
43 | # plt.imshow(np.transpose(npimg, (1, 2, 0)))
44 | # plt.show()
45 | # imshow(torchvision.utils.make_grid(batch.inputs))
46 |
47 | # def imshow2(img):
48 | # # img to cpu
49 | # img = img.cpu()
50 | # img = img / 2 + 0.5 # unnormalize
51 | # npimg = img.numpy()
52 | # plt.imshow(np.transpose(npimg, (1, 2, 0)))
53 | # plt.show()
54 |
55 |
56 | # imshow(torchvision.utils.make_grid(batch.inputs))
57 | # imshow(torchvision.utils.make_grid(backdoored_batch.inputs))
58 | # import IPython; IPython.embed(); exit(0)
59 |
60 |
61 |
62 | def apply_backdoor(self, batch, attack_portion):
63 | """
64 | Modifies only a portion of the batch (represents batch poisoning).
65 |
66 | :param batch: batch to poison in place (only the first `attack_portion` samples change).
67 | :return: None
68 | """
69 | self.synthesize_inputs(batch=batch, attack_portion=attack_portion)
70 | self.synthesize_labels(batch=batch, attack_portion=attack_portion)
71 |
72 | return
73 |
74 | def synthesize_inputs(self, batch, attack_portion=None):
75 | raise NotImplementedError
76 |
77 | def synthesize_labels(self, batch, attack_portion=None):
78 | raise NotImplementedError
79 |
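Note on `make_backdoor_batch`: during training only the first `round(batch_size * poisoning_proportion)` samples are poisoned, at test time the whole batch is, and the original batch stays untouched because the method works on a clone. The sketch below imitates that flow with a toy batch of plain lists. `ToyBatch`, the trigger value `1.0`, and the standalone function are assumptions made for illustration, not the repository's classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyBatch:
    inputs: List[float]
    labels: List[int]

    def clone(self):
        # Copy the lists so edits to the clone never reach the original.
        return ToyBatch(list(self.inputs), list(self.labels))


def make_backdoor_batch(batch, poisoning_proportion, backdoor_label, test=False):
    batch_size = len(batch.inputs)
    # Test batches are fully poisoned; training batches only partially.
    attack_portion = batch_size if test else round(batch_size * poisoning_proportion)
    poisoned = batch.clone()
    for i in range(attack_portion):
        poisoned.inputs[i] = 1.0  # hypothetical trigger value
        poisoned.labels[i] = backdoor_label
    return poisoned


clean = ToyBatch(inputs=[0.0, 0.0, 0.0, 0.0], labels=[3, 5, 7, 9])
train_batch = make_backdoor_batch(clean, poisoning_proportion=0.5, backdoor_label=8)
test_batch = make_backdoor_batch(clean, poisoning_proportion=0.5, backdoor_label=8, test=True)
```

Subclasses in this repository split the loop body into `synthesize_inputs` and `synthesize_labels`, which is why the base class only defines the slicing policy.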
--------------------------------------------------------------------------------
/tasks/batch.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 |
4 |
5 | @dataclass
6 | class Batch:
7 | batch_id: int
8 | inputs: torch.Tensor
9 | labels: torch.Tensor
10 |
11 | # For PIPA experiment we use this field to store identity label.
12 | aux: torch.Tensor = None
13 |
14 | def __post_init__(self):
15 | self.batch_size = self.inputs.shape[0]
16 |
17 | def to(self, device):
18 | inputs = self.inputs.to(device)
19 | labels = self.labels.to(device)
20 | if self.aux is not None:
21 | aux = self.aux.to(device)
22 | else:
23 | aux = None
24 | return Batch(self.batch_id, inputs, labels, aux)
25 |
26 | def clone(self):
27 | inputs = self.inputs.clone()
28 | labels = self.labels.clone()
29 | if self.aux is not None:
30 | aux = self.aux.clone()
31 | else:
32 | aux = None
33 | return Batch(self.batch_id, inputs, labels, aux)
34 |
35 |
36 | def clip(self, batch_size):
37 | if batch_size is None:
38 | return self
39 |
40 | inputs = self.inputs[:batch_size]
41 | labels = self.labels[:batch_size]
42 |
43 | if self.aux is None:
44 | aux = None
45 | else:
46 | aux = self.aux[:batch_size]
47 |
48 | return Batch(self.batch_id, inputs, labels, aux)
--------------------------------------------------------------------------------
/tasks/cifar10_task.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision
3 | from torch import nn
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data import Subset
6 | from torchvision.transforms import transforms
7 |
8 | # from models.resnet import resnet18
9 | from models.resnet_cifar import ResNet18
10 |
11 | from tasks.task import Task
12 |
13 |
14 | class Cifar10Task(Task):
15 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
16 | (0.2023, 0.1994, 0.2010))
17 |
18 | def load_data(self):
19 | self.load_cifar_data()
20 | number_of_samples = []
21 |
22 | if self.params.fl_sample_dirichlet:
23 | # sample indices for participants using Dirichlet distribution
24 | split = min(self.params.fl_total_participants / 10, 1)
25 | all_range = list(range(int(len(self.train_dataset) * split)))
26 | self.train_dataset = Subset(self.train_dataset, all_range)
27 | indices_per_participant = self.sample_dirichlet_train_data(
28 | self.params.fl_total_participants,
29 | alpha=self.params.fl_dirichlet_alpha)
30 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in
31 | indices_per_participant.items()])
32 | # print("number_of_samples", number_of_samples)
33 |
34 | else:
35 | # sample indices for participants that are equally
36 | # split to 500 images per participant
37 | split = min(self.params.fl_total_participants / 100, 1)
38 | all_range = list(range(int(len(self.train_dataset) * split)))
39 | self.train_dataset = Subset(self.train_dataset, all_range)
40 | random.shuffle(all_range)
41 |
42 | train_loaders, number_of_samples = zip(*[self.get_train_old(all_range, pos)
43 | for pos in range(self.params.fl_total_participants)])
44 |
45 | self.fl_train_loaders = train_loaders
46 | self.fl_number_of_samples = number_of_samples
47 | # print("fl_number_of_samples", self.fl_number_of_samples)
48 |
49 |
50 | def load_cifar_data(self):
51 | # print(f"Loading CIFAR10 data: {self.params.transform_train}")
52 | # exit(0)
53 | if self.params.transform_train:
54 | transform_train = transforms.Compose([
55 | transforms.RandomCrop(32, padding=4),
56 | transforms.RandomHorizontalFlip(),
57 | transforms.ToTensor(),
58 | self.normalize,
59 | ])
60 | else:
61 | transform_train = transforms.Compose([
62 | transforms.ToTensor(),
63 | self.normalize,
64 | ])
65 | transform_test = transforms.Compose([
66 | transforms.ToTensor(),
67 | self.normalize,
68 | ])
69 | self.train_dataset = torchvision.datasets.CIFAR10(
70 | root=self.params.data_path,
71 | train=True,
72 | download=True,
73 | transform=transform_train)
74 |
75 | self.train_loader = DataLoader(self.train_dataset,
76 | batch_size=self.params.batch_size,
77 | shuffle=True,
78 | num_workers=0)
79 | self.test_dataset = torchvision.datasets.CIFAR10(
80 | root=self.params.data_path,
81 | train=False,
82 | download=True,
83 | transform=transform_test)
84 | self.test_loader = DataLoader(self.test_dataset,
85 | batch_size=self.params.test_batch_size,
86 | shuffle=False, num_workers=0)
87 |
88 | self.classes = ('plane', 'car', 'bird', 'cat',
89 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
90 | return True
91 |
92 | def build_model(self) -> nn.Module:
93 | # model = resnet18(pretrained=False,
94 | # num_classes=len(self.classes))
95 | model = ResNet18()
96 | return model
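Note on the Dirichlet branch of `load_data`: `sample_dirichlet_train_data` is not shown in this file, so the following is only a sketch of the general technique (label-wise Dirichlet partitioning), not the repository's implementation. It assumes integer class labels and uses the standard-library fact that normalized Gamma(alpha, 1) draws form a Dirichlet sample. The function name `dirichlet_partition` and the `seed` argument are hypothetical.

```python
import random
from collections import defaultdict


def dirichlet_partition(labels, num_participants, alpha, seed=0):
    """Split sample indices across participants, class by class, with Dirichlet shares."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    per_participant = defaultdict(list)
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Normalised Gamma(alpha, 1) draws are one sample from Dirichlet(alpha, ..., alpha).
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_participants)]
        total = sum(draws)
        start = 0
        for user in range(num_participants):
            if user == num_participants - 1:
                end = len(indices)  # the last participant takes the remainder
            else:
                share = int(round(len(indices) * draws[user] / total))
                end = min(start + share, len(indices))
            per_participant[user].extend(indices[start:end])
            start = end
    return per_participant


labels = [i % 10 for i in range(200)]
parts = dirichlet_partition(labels, num_participants=5, alpha=0.5)
assigned = sorted(index for indices in parts.values() for index in indices)
```

A smaller `alpha` concentrates each class on fewer participants (more non-IID), while a large `alpha` approaches an even split.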
--------------------------------------------------------------------------------
/tasks/fl_user.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from torch.utils.data.dataloader import DataLoader
3 |
4 |
5 |
6 | @dataclass
7 | class FLUser:
8 | user_id: int = 0
9 | compromised: bool = False
10 | train_loader: DataLoader = None
11 | number_of_samples: int = 0
12 |
--------------------------------------------------------------------------------
/tasks/imagenet_task.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torchvision
4 | from torch.utils.data import DataLoader
5 | from torchvision.transforms import transforms
6 | from torch.utils.data import Subset
7 |
8 | from models.resnet_tinyimagenet import resnet18
9 | from tasks.task import Task
10 | import os
11 | import logging
12 | logger = logging.getLogger('logger')
13 |
14 |
15 | class ImagenetTask(Task):
16 |
17 | def load_data(self):
18 | self.load_imagenet_data()
19 | if self.params.fl_sample_dirichlet:
20 | # sample indices for participants using Dirichlet distribution
21 | split = min(self.params.fl_total_participants / 100, 1)
22 | all_range = list(range(int(len(self.train_dataset) * split)))
23 | self.train_dataset = Subset(self.train_dataset, all_range)
24 | indices_per_participant = self.sample_dirichlet_train_data(
25 | self.params.fl_total_participants,
26 | alpha=self.params.fl_dirichlet_alpha)
27 | # train_loaders = [self.get_train(indices) for pos, indices in
28 | # indices_per_participant.items()]
29 |
30 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in
31 | indices_per_participant.items()])
32 | else:
33 | # sample indices for participants that are equally
34 | # split to 500 images per participant
35 | split = min(self.params.fl_total_participants / 100, 1)
36 | all_range = list(range(int(len(self.train_dataset) * split)))
37 | self.train_dataset = Subset(self.train_dataset, all_range)
38 | random.shuffle(all_range)
39 | train_loaders, number_of_samples = zip(*[self.get_train_old(all_range, pos)
40 | for pos in
41 | range(self.params.fl_total_participants)])
42 | self.fl_train_loaders = train_loaders
43 | self.fl_number_of_samples = number_of_samples
44 | logger.info(f"Done splitting with #participant: {self.params.fl_total_participants}")
45 |
46 | return
47 |
48 | def load_imagenet_data(self):
49 |
50 | train_transform = transforms.Compose([
51 | # transforms.RandomResizedCrop(224),
52 | transforms.Resize(224),
53 | transforms.RandomHorizontalFlip(),
54 | transforms.ToTensor(),
55 | self.normalize,
56 | ])
57 | test_transform = transforms.Compose([
58 | # transforms.Resize(256),
59 | # transforms.CenterCrop(224),
60 | transforms.Resize(224),
61 | transforms.ToTensor(),
62 | self.normalize,
63 | ])
64 |
65 | self.train_dataset = torchvision.datasets.ImageFolder(
66 | os.path.join(self.params.data_path, 'train'),
67 | train_transform)
68 |
69 | self.test_dataset = torchvision.datasets.ImageFolder(
70 | os.path.join(self.params.data_path, 'val'),
71 | test_transform)
72 |
73 | self.train_loader = DataLoader(self.train_dataset,
74 | batch_size=self.params.batch_size,
75 | shuffle=True, num_workers=8, pin_memory=True)
76 | self.test_loader = DataLoader(self.test_dataset,
77 | batch_size=self.params.test_batch_size,
78 | shuffle=False, num_workers=8, pin_memory=True)
79 |
80 | def build_model(self):
81 | return resnet18(pretrained=False)
82 |
--------------------------------------------------------------------------------
/tasks/mnist_task.py:
--------------------------------------------------------------------------------
1 | import random
2 | from torch.utils.data import Subset
3 | import torch.utils.data as torch_data
4 | import torchvision
5 | from torchvision.transforms import transforms
6 |
7 | from models.MnistNet import MnistNet
8 | from tasks.task import Task
9 | import logging
10 | logger = logging.getLogger('logger')
11 |
12 | class MNISTTask(Task):
13 | normalize = transforms.Normalize((0.1307,), (0.3081,))
14 |
15 | def load_data(self):
16 | self.load_mnist_data()
17 | if self.params.fl_sample_dirichlet:
18 | # sample indices for participants using Dirichlet distribution
19 | split = min(self.params.fl_total_participants / 20, 1)
20 | all_range = list(range(int(len(self.train_dataset) * split)))
21 | logger.info(f"all_range: {len(all_range)} len train_dataset: {len(self.train_dataset)}")
22 | # if number of participants is less than 20, then we will sample a subset of the dataset, otherwise we will use the whole dataset
23 | self.train_dataset = Subset(self.train_dataset, all_range)
24 | indices_per_participant = self.sample_dirichlet_train_data(
25 | self.params.fl_total_participants,
26 | alpha=self.params.fl_dirichlet_alpha)
27 |
28 | # train_loaders = [self.get_train(indices) for pos, indices in
29 | # indices_per_participant.items()]
30 |
31 | train_loaders, number_of_samples = zip(*[self.get_train(indices) for pos, indices in
32 | indices_per_participant.items()])
33 |
34 | else:
35 | # sample indices for participants so that the data is split equally
36 | split = min(self.params.fl_total_participants / 20, 1)
37 | all_range = list(range(int(len(self.train_dataset) * split)))
38 | self.train_dataset = Subset(self.train_dataset, all_range)
39 | random.shuffle(all_range)
40 | train_loaders, number_of_samples = zip(*[self.get_train_old(all_range, pos)
41 | for pos in range(self.params.fl_total_participants)])
42 |
43 | self.fl_train_loaders = train_loaders
44 | self.fl_number_of_samples = number_of_samples
45 | logger.info(f"Done splitting with #participant: {self.params.fl_total_participants}")
46 | return
47 |
48 |
49 | def load_mnist_data(self):
50 | transform_train = transforms.Compose([
51 | transforms.ToTensor(),
52 | self.normalize
53 | ])
54 |
55 | transform_test = transforms.Compose([
56 | transforms.ToTensor(),
57 | self.normalize
58 | ])
59 |
60 | self.train_dataset = torchvision.datasets.MNIST(
61 | root=self.params.data_path,
62 | train=True,
63 | download=True,
64 | transform=transform_train)
65 | self.train_loader = torch_data.DataLoader(self.train_dataset,
66 | batch_size=self.params.batch_size,
67 | shuffle=True,
68 | num_workers=0)
69 | self.test_dataset = torchvision.datasets.MNIST(
70 | root=self.params.data_path,
71 | train=False,
72 | download=True,
73 | transform=transform_test)
74 | self.test_loader = torch_data.DataLoader(self.test_dataset,
75 | batch_size=self.params.test_batch_size,
76 | shuffle=False,
77 | num_workers=0)
78 | self.classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
79 | return True
80 |
81 | def build_model(self):
82 | # return SimpleNet(num_classes=len(self.classes))
83 | return MnistNet()
84 |
--------------------------------------------------------------------------------
/tasks/task.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from typing import List, Any, Dict
4 | from copy import deepcopy
5 | import numpy as np
6 | from collections import defaultdict, Counter
7 |
8 |
9 | import torch
10 | from torch import optim, nn
11 | from torch.nn import Module
12 | from torch.optim import Optimizer
13 | from torch.utils.data import DataLoader
14 | from torch.utils.data.sampler import SubsetRandomSampler
15 | from torchvision.transforms import transforms
16 |
17 | from metrics.accuracy_metric import AccuracyMetric
18 | from metrics.metric import Metric
19 | from metrics.test_loss_metric import TestLossMetric
20 | from tasks.batch import Batch
21 | from tasks.fl_user import FLUser
22 | from utils.parameters import Params
23 |
24 | logger = logging.getLogger('logger')
25 |
26 |
27 | class Task:
28 | params: Params = None
29 |
30 | train_dataset = None
31 | test_dataset = None
32 | train_loader = None
33 | test_loader = None
34 | classes = None
35 |
36 | model: Module = None
37 | optimizer: optim.Optimizer = None
38 | criterion: Module = None
39 | metrics: List[Metric] = None
40 |
41 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
42 | std=[0.229, 0.224, 0.225])
43 | "Generic normalization for input data."
44 | input_shape: torch.Size = None
45 |
46 | fl_train_loaders: List[Any] = None
47 | fl_number_of_samples: List[int] = None
48 |
49 | ignored_weights = ['num_batches_tracked']#['tracked', 'running']
50 | adversaries: List[int] = None
51 |
52 | def __init__(self, params: Params):
53 | self.params = params
54 | self.init_task()
55 |
56 | def init_task(self):
57 | self.load_data()
58 | # self.stat_data()
59 | logger.debug(f"Number of train samples: {self.fl_number_of_samples}")
60 | # exit(0)
61 | self.model = self.build_model()
62 | self.resume_model()
63 | self.model = self.model.to(self.params.device)
64 |
65 | self.local_model = self.build_model().to(self.params.device)
66 |
67 | self.criterion = self.make_criterion()
68 | self.adversaries = self.sample_adversaries()
69 |
70 | self.optimizer = self.make_optimizer()
71 |
72 | self.metrics = [AccuracyMetric(), TestLossMetric(self.criterion)]
73 | self.set_input_shape()
74 | # Initialize the logger
75 | fh = logging.FileHandler(
76 | filename=f'{self.params.folder_path}/log.txt')
77 | formatter = logging.Formatter('%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s')
78 |
79 |
80 | fh.setFormatter(formatter)
81 | logger.addHandler(fh)
82 |
83 | def load_data(self) -> None:
84 | raise NotImplementedError
85 |
86 | def stat_data(self):
87 | for i, loader in enumerate(self.fl_train_loaders):
88 | labels = [t.item() for data, target in loader for t in target]
89 | label_counts = dict(sorted(Counter(labels).items()))
90 | total_samples = sum(len(target) for _, target in loader)
91 | logger.debug(f"Participant {i:2d} has {label_counts} and total {total_samples} samples")
92 |
93 |
94 | def build_model(self) -> Module:
95 | raise NotImplementedError
96 |
97 | def make_criterion(self) -> Module:
98 | """Initialize with Cross Entropy by default.
99 |
100 | We use reduction `none` to support gradient shaping defense.
101 | :return:
102 | """
103 | return nn.CrossEntropyLoss(reduction='none')
104 |
105 | def make_optimizer(self, model=None) -> Optimizer:
106 | if model is None:
107 | model = self.model
108 | if self.params.optimizer == 'SGD':
109 | optimizer = optim.SGD(model.parameters(),
110 | lr=self.params.lr,
111 | weight_decay=self.params.decay,
112 | momentum=self.params.momentum)
113 | elif self.params.optimizer == 'Adam':
114 | optimizer = optim.Adam(model.parameters(),
115 | lr=self.params.lr,
116 | weight_decay=self.params.decay)
117 | else:
118 | raise ValueError(f'No optimizer: {self.params.optimizer}')
119 |
120 | return optimizer
121 |
122 | def resume_model(self):
123 | if self.params.resume_model:
124 | # import IPython; IPython.embed()
125 | logger.info(f'Resuming training from {self.params.resume_model}')
126 | loaded_params = torch.load(f"{self.params.resume_model}",
127 | map_location=torch.device('cpu'))
128 | self.model.load_state_dict(loaded_params['state_dict'])
129 | self.params.start_epoch = loaded_params['epoch']
130 | # print self.model architecture to file 'model.txt'
131 | # with open(f'model.txt', 'w') as f:
132 | # f.write(str(self.model))
133 |
134 | # # print architecture of loaded_params
135 | # with open(f'loaded_params.txt', 'w') as f:
136 | # f.write(str(loaded_params['state_dict'].keys()))
137 |
138 | # self.params.lr = loaded_params.get('lr', self.params.lr)
139 |
140 | logger.warning(f"Loaded parameters from saved model: LR is"
141 | f" {self.params.lr} and current epoch is"
142 | f" {self.params.start_epoch}")
143 |
144 | def set_input_shape(self):
145 | inp = self.train_dataset[0][0]
146 | self.params.input_shape = inp.shape
147 | logger.info(f"Input shape is {self.params.input_shape}")
148 |
149 | def get_batch(self, batch_id, data) -> Batch:
150 | """Process data into a batch.
151 |
152 | Specific for different datasets and data loaders this method unifies
153 | the output by returning the object of class Batch.
154 | :param batch_id: id of the batch
155 | :param data: object returned by the Loader.
156 | :return:
157 | """
158 | inputs, labels = data
159 | batch = Batch(batch_id, inputs, labels)
160 | return batch.to(self.params.device)
161 |
162 | def accumulate_metrics(self, outputs, labels):
163 | for metric in self.metrics:
164 | metric.accumulate_on_batch(outputs, labels)
165 |
166 | def reset_metrics(self):
167 | for metric in self.metrics:
168 | metric.reset_metric()
169 |
170 | def report_metrics(self, step, prefix=''):
171 | metric_text = []
172 | for metric in self.metrics:
173 | metric_text.append(str(metric))
174 | logger.warning(f'{prefix} {step:4d}. {" | ".join(metric_text)}')
175 | # import IPython; IPython.embed()
176 | # exit(0)
177 | return self.metrics[0].get_main_metric_value()
178 | def get_metrics(self):
179 | # TODO Nov 11, 2023: This is a hack to get the metrics
180 | metric_dict = {
181 | 'accuracy': self.metrics[0].get_main_metric_value(),
182 | 'loss': self.metrics[1].get_main_metric_value()
183 | }
184 | # for metric in self.metrics:
185 | # metric_dict[metric.name] = metric.get_value()
186 | return metric_dict
187 |
188 | @staticmethod
189 | def get_batch_accuracy(outputs, labels, top_k=(1,)):
190 | """Computes the precision@k for the specified values of k"""
191 | max_k = max(top_k)
192 | batch_size = labels.size(0)
193 |
194 | _, pred = outputs.topk(max_k, 1, True, True)
195 | pred = pred.t()
196 | correct = pred.eq(labels.view(1, -1).expand_as(pred))
197 |
198 | res = []
199 | for k in top_k:
200 | correct_k = correct[:k].reshape(-1).float().sum(0)
201 | res.append((correct_k.mul_(100.0 / batch_size)).item())
202 | if len(res) == 1:
203 | res = res[0]
204 | return res
205 |
206 | def get_empty_accumulator(self):
207 | weight_accumulator = dict()
208 | for name, data in self.model.state_dict().items():
209 | weight_accumulator[name] = torch.zeros_like(data)
210 | return weight_accumulator
211 |
212 | def sample_users_for_round(self, epoch) -> List[FLUser]:
213 | sampled_ids = random.sample(
214 | range(self.params.fl_total_participants),
215 | self.params.fl_no_models)
216 |
217 | # if 7 not in sampled_ids:
218 | # sampled_ids[0] = 7
219 |
220 | sampled_users = []
221 | for pos, user_id in enumerate(sampled_ids):
222 | train_loader = self.fl_train_loaders[user_id]
223 | number_of_samples = self.fl_number_of_samples[user_id]
224 | compromised = self.check_user_compromised(epoch, pos, user_id)
225 |
226 | user = FLUser(user_id, compromised=compromised,
227 | train_loader=train_loader, number_of_samples=number_of_samples)
228 | sampled_users.append(user)
229 | logger.warning(f"Sampled users for round {epoch}: {[[user.user_id, user.compromised] for user in sampled_users]} ")
230 |
231 |
232 | total_samples = sum([user.number_of_samples for user in sampled_users])
233 |
234 | self.params.fl_weight_contribution = {user.user_id: user.number_of_samples / total_samples for user in sampled_users}
235 | # self.params.fl_round_participants = [user.user_id for user in sampled_users]
236 | self.params.fl_local_updated_models = dict()
237 | logger.warning(f"Weight contribution for round {epoch}: {self.params.fl_weight_contribution} ")
238 | return sampled_users
239 |
240 | def adding_local_updated_model(self, local_update: Dict[str, torch.Tensor], epoch=None, user_id=None):
241 | self.params.fl_local_updated_models[user_id] = local_update
242 |
243 | def check_user_compromised(self, epoch, pos, user_id):
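244 | """Check if the sampled user is compromised for the attack.
245 |
246 | If fl_single_epoch_attack is set (i.e. not None), only the first
247 | fl_number_of_adversaries sampled positions are compromised, and only in that epoch.
248 | Otherwise, members of self.adversaries are compromised in every epoch
249 | from poison_epoch to poison_epoch_stop (inclusive).
250 | :return: True if the user is compromised in this round.
251 | """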
252 | compromised = False
253 | if self.params.fl_single_epoch_attack is not None:
254 | if epoch == self.params.fl_single_epoch_attack:
255 | if pos < self.params.fl_number_of_adversaries:
256 | compromised = True
257 | # if user_id == 0:
258 | # compromised = True
259 | logger.warning(f'Attacking once at epoch {epoch}. Compromised'
260 | f' user: {user_id}.')
261 | else:
262 | if epoch >= self.params.poison_epoch and epoch < self.params.poison_epoch_stop + 1:
263 | compromised = user_id in self.adversaries
264 | return compromised
265 |
266 | def sample_adversaries(self) -> List[int]:
267 | adversaries_ids = []
268 | if self.params.fl_number_of_adversaries == 0:
269 | logger.warning(f'Running vanilla FL, no attack.')
270 | elif self.params.fl_single_epoch_attack is None:
271 | # adversaries_ids = list(range(self.params.fl_number_of_adversaries))
272 | adversaries_ids = random.sample(
273 | range(self.params.fl_total_participants),
274 | self.params.fl_number_of_adversaries)
275 |
276 | logger.warning(f'Attacking over multiple epochs with following '
277 | f'users compromised: {adversaries_ids}.')
278 | else:
279 | logger.warning(f'Attack only on epoch: '
280 | f'{self.params.fl_single_epoch_attack} with '
281 | f'{self.params.fl_number_of_adversaries} compromised'
282 | f' users.')
283 |
284 | return adversaries_ids
285 |
286 | def get_model_optimizer(self, model):
287 | local_model = deepcopy(model)
288 | local_model = local_model.to(self.params.device)
289 |
290 | optimizer = self.make_optimizer(local_model)
291 |
292 | return local_model, optimizer
293 |
294 | def copy_params(self, global_model: Module, local_model: Module):
295 | local_state = local_model.state_dict()
296 | for name, param in global_model.state_dict().items():
297 | if name in local_state and name not in self.ignored_weights:
298 | local_state[name].copy_(param)
299 |
300 | def update_global_model(self, weight_accumulator, global_model: Module):
301 | # self.last_global_model = deepcopy(self.model)
302 | for name, sum_update in weight_accumulator.items():
303 | if self.check_ignored_weights(name):
304 | continue
305 | # scale = self.params.fl_eta / self.params.fl_total_participants
306 | # TODO: change this based on number of sample of each user
307 | # scale = 1 self.params.fl_eta / self.params.fl_no_models
308 | # scale = 1.0
309 | # average_update = scale * sum_update
310 | average_update = sum_update
311 | model_weight = global_model.state_dict()[name]
312 | model_weight.add_(average_update)
313 |
314 | def check_ignored_weights(self, name) -> bool:
315 | for ignored in self.ignored_weights:
316 | if ignored in name:
317 | return True
318 |
319 | return False
320 |
321 | def sample_dirichlet_train_data(self, no_participants, alpha=0.9):
322 | """
323 | Input: number of participants and alpha (Dirichlet concentration parameter).
324 | Output: a dict mapping each participant id to a list of training-set indices.
325 | Requires: dataset_classes, a class-to-indices dictionary (built below).
326 | Sample method: for each class, draw a Dirichlet(alpha) vector over the
327 | participants, scale it by the class size, and give each participant
328 | that many images of the class.
329 | """
330 |
331 | dataset_classes = {}
332 | for ind, x in enumerate(self.train_dataset):
333 | _, label = x
334 | if label in dataset_classes:
335 | dataset_classes[label].append(ind)
336 | else:
337 | dataset_classes[label] = [ind]
338 | class_size = len(dataset_classes[0])
339 | per_participant_list = defaultdict(list)
340 | no_classes = len(dataset_classes.keys())
341 |
342 | for n in range(no_classes):
343 | random.shuffle(dataset_classes[n])
344 | sampled_probabilities = class_size * np.random.dirichlet(
345 | np.array(no_participants * [alpha]))
346 | for user in range(no_participants):
347 | no_imgs = int(round(sampled_probabilities[user]))
348 | sampled_list = dataset_classes[n][
349 | :min(len(dataset_classes[n]), no_imgs)]
350 | per_participant_list[user].extend(sampled_list)
351 | dataset_classes[n] = dataset_classes[n][
352 | min(len(dataset_classes[n]), no_imgs):]
353 |
354 | return per_participant_list
355 |
356 | def get_train(self, indices):
357 | """
358 | This method is used along with Dirichlet distribution
359 | :param indices:
360 | :return:
361 | """
362 | train_loader = DataLoader(self.train_dataset,
363 | batch_size=self.params.batch_size,
364 | sampler=SubsetRandomSampler(
365 | indices), drop_last=True)
366 | return train_loader, len(indices)
367 |
368 | def get_train_old(self, all_range, model_no):
369 | """
370 | This method equally splits the dataset.
371 | :param all_range:
372 | :param model_no:
373 | :return:
374 | """
375 |
376 | data_len = int(
377 | len(self.train_dataset) / self.params.fl_total_participants)
378 | sub_indices = all_range[model_no * data_len: (model_no + 1) * data_len]
379 | train_loader = DataLoader(self.train_dataset,
380 | batch_size=self.params.batch_size,
381 | sampler=SubsetRandomSampler(
382 | sub_indices))
383 | return train_loader, len(sub_indices)
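
The Dirichlet split performed by `sample_dirichlet_train_data` above can be illustrated without PyTorch or NumPy. The following is a minimal sketch, not the repository's code: `dirichlet_partition`, its `seed` argument, and the toy `labels` list are hypothetical, and it scales each draw by the actual size of each class rather than by the size of class 0 as the method above does.

```python
import random
from collections import defaultdict


def dirichlet_partition(labels, no_participants, alpha=0.9, seed=0):
    """Split sample indices among participants, one class at a time."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    per_participant = defaultdict(list)
    for label, indices in by_class.items():
        rng.shuffle(indices)
        # A Dirichlet(alpha) sample is a vector of normalised Gamma(alpha, 1) draws.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(no_participants)]
        total = sum(draws)
        start = 0
        for user, draw in enumerate(draws):
            count = int(round(len(indices) * draw / total))
            chunk = indices[start:start + count]  # slicing clips at the class end
            per_participant[user].extend(chunk)
            start += len(chunk)
    return per_participant


# Toy data: 1000 samples, 10 balanced classes (an assumption for the demo).
labels = [i % 10 for i in range(1000)]
parts = dirichlet_partition(labels, no_participants=5, alpha=0.5)
for user in sorted(parts):
    print(user, len(parts[user]))
```

Smaller values of `alpha` concentrate each class on fewer participants, which is what makes the split non-IID. Because of rounding, a few samples per class may remain unassigned.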
--------------------------------------------------------------------------------
/test-lib/test-log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import colorlog
3 | import os
4 | from logging import FileHandler
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | # Example ground truth labels and predicted outputs
10 | # Assuming batch_size=3 and num_classes=5
11 | target = torch.tensor([2, 0, 4]) # Ground truth labels
12 | output = torch.tensor([[0.1, 0.2, 0.6, 0.1, 0.0],
13 | [0.8, 0.1, 0.0, 0.05, 0.05],
14 | [0.0, 0.0, 0.1, 0.2, 0.7]]) # Predicted scores/logits
15 |
16 | # Using nn.CrossEntropyLoss with reduction='none'
17 | criterion = nn.CrossEntropyLoss(reduction='none')
18 |
19 | # Calculate loss for each example separately
20 | losses = criterion(output, target)
21 | print(losses, losses.mean())
22 |
23 |
24 | # def create_logger():
25 | # """
26 | # Setup the logging environment
27 | # """
28 | # log = logging.getLogger() # root logger
29 | # log.setLevel(logging.DEBUG)
30 | # format_str = '%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s'
31 | # date_format = '%Y-%m-%d %H:%M:%S'
32 | # if os.isatty(2):
33 | # cformat = '%(log_color)s' + format_str
34 | # colors = {'DEBUG': 'bold_blue',
35 | # 'INFO': 'reset',
36 | # 'WARNING': 'bold_yellow',
37 | # 'ERROR': 'bold_red',
38 | # 'CRITICAL': 'bold_red'}
39 | # formatter = colorlog.ColoredFormatter(cformat, date_format,
40 | # log_colors=colors)
41 | # else:
42 | # formatter = logging.Formatter(format_str, date_format)
43 | # stream_handler = logging.StreamHandler()
44 | # stream_handler.setFormatter(formatter)
45 | # log.addHandler(stream_handler)
46 | # return logging.getLogger(__name__)
47 |
48 | # logger = create_logger()
49 |
50 | # # Configure basic logging settings with the custom formatter
51 | # logging.basicConfig(level=logging.DEBUG,
52 | # format='%(log_color)s %(asctime)s - %(filename)s - Line:%(lineno)d - %(name)s - %(levelname)-8s - %(message)s')
53 | # # cformat = '%(log_color)s' + format_str
54 | # # colors = {'DEBUG': 'reset',
55 | # # 'INFO': 'reset',
56 | # # 'WARNING': 'bold_yellow',
57 | # # 'ERROR': 'bold_red',
58 | # # 'CRITICAL': 'bold_red'}
59 |
60 | # # Set the custom formatter
61 | # logger = logging.getLogger()
62 |
63 | # # Create a logger
64 | # logger = logging.getLogger("my_logger")
65 |
66 | # # Create an HTML formatter
67 | # class HTMLFormatter(logging.Formatter):
68 | # def format(self, record):
69 | # level = record.levelname
70 | # message = record.getMessage()
71 | # formatted = f"{level}: {message}\n"
72 | # return formatted
73 |
74 | # # Create a FileHandler and set the HTML formatter
75 | # file_handler = FileHandler('logs.html', mode='a', encoding='utf-8')
76 | # html_formatter = HTMLFormatter()
77 | # file_handler.setFormatter(html_formatter)
78 |
79 | # # Add the FileHandler to the logger
80 | # logger.addHandler(file_handler)
81 |
82 |
83 | # Log messages
84 | # logger.debug("This is a debug message")
85 | # logger.info("This is an info message")
86 | # logger.warning("This is a warning message")
87 | # logger.error("This is an error message")
88 | # logger.critical("This is a critical message")
89 |
90 |
91 | # # write logger to html file
92 | # import numpy as np
93 | # from PIL import Image
94 | # from torchvision import transforms
95 | # trans = transforms.Compose([transforms.ToTensor()])
96 | # img = np.random.randint(0, 255, size=(3, 224, 224), dtype=np.uint8)
97 |
98 | # demo_img = trans(img)
99 |
100 | # demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)
101 | # print(Image.fromarray(demo_array.astype(np.uint8)))
102 |
103 | # from matplotlib import pyplot as plt
104 | # plt.imshow(img)
105 | # plt.show()
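
The active part of this script calls `nn.CrossEntropyLoss(reduction='none')` to get one loss value per example. As a rough illustration of what that returns, the sketch below recomputes the same quantity with the standard library only. `cross_entropy_per_sample` is a hypothetical helper, not part of the repository, and it assumes each row holds raw scores (logits) and each target is a class index.

```python
import math


def cross_entropy_per_sample(logits, targets):
    """Return -log softmax(row)[target] for each row of logits."""
    losses = []
    for row, target in zip(logits, targets):
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        losses.append(log_sum_exp - row[target])
    return losses


# Same toy batch as in the script above: batch_size=3, num_classes=5.
logits = [[0.1, 0.2, 0.6, 0.1, 0.0],
          [0.8, 0.1, 0.0, 0.05, 0.05],
          [0.0, 0.0, 0.1, 0.2, 0.7]]
targets = [2, 0, 4]

losses = cross_entropy_per_sample(logits, targets)
mean_loss = sum(losses) / len(losses)  # what the default reduction='mean' reports
print(losses, mean_loss)
```

Keeping the per-example values is useful when individual losses need to be weighted or inspected before averaging.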
--------------------------------------------------------------------------------
/test-lib/test-write-file.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 13,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "840 \n",
13 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
14 | "python training.py --name cifar10 --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__2023.Nov.24/cifar10_fed_100_10_4_0.5_0.05.yaml\n",
15 | "------------------------\n",
16 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
17 | "python training.py --name cifar10 --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__2023.Nov.24/cifar10_fed_200_10_4_0.5_0.05.yaml\n",
18 | "------------------------\n",
19 | "796 \n",
20 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
21 | "python training.py --name mnist --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__2023.Nov.24/mnist_fed_100_10_4_0.5_0.05.yaml\n",
22 | "------------------------\n",
23 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
24 | "python training.py --name mnist --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__2023.Nov.24/mnist_fed_200_10_4_0.5_0.05.yaml\n",
25 | "------------------------\n",
26 | "863 \n",
27 | "fl_total_participants: 100 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
28 | "python training.py --name tiny-imagenet --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_100_10_4_0.5_0.05.yaml\n",
29 | "------------------------\n",
30 | "fl_total_participants: 200 fl_no_models: 10 fl_dirichlet_alpha: 0.5 fl_number_of_adversaries: 4 fl_lr: 0.05\n",
31 | "python training.py --name tiny-imagenet --params /home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__2023.Nov.24/tiny-imagenet_fed_200_10_4_0.5_0.05.yaml\n",
32 | "------------------------\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "import os\n",
38 | "from datetime import datetime\n",
39 | "\n",
40 | "\n",
41 | "def generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/cifar_fed.yaml', name_exp = 'cifar10', EXPS_DIR=\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/extras\"):\n",
42 | " # read file as a string\n",
43 | " print(f'reading from root_file: {root_file}')\n",
44 | " with open(root_file, 'r') as file :\n",
45 | " filedata = file.read()\n",
46 | " fl_total_participants_choices = [100, 200]\n",
47 | " fl_no_models_choices = [10]\n",
48 | " fl_dirichlet_alpha_choices = [0.5]\n",
49 | " fl_number_of_adversaries_choices = [4]\n",
50 | " fl_lr_choices = [0.05]\n",
51 | " # EXPS_DIR = '/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/extras'\n",
52 | " \n",
53 | " os.makedirs(EXPS_DIR, exist_ok=True)\n",
54 | "\n",
55 | " for fl_total_participants in fl_total_participants_choices:\n",
56 | " for fl_no_models in fl_no_models_choices:\n",
57 | " for fl_dirichlet_alpha in fl_dirichlet_alpha_choices:\n",
58 | " for fl_number_of_adversaries in fl_number_of_adversaries_choices:\n",
59 | " for fl_lr in fl_lr_choices:\n",
60 | " print(f'fl_total_participants: {fl_total_participants} fl_no_models: {fl_no_models} fl_dirichlet_alpha: {fl_dirichlet_alpha} fl_number_of_adversaries: {fl_number_of_adversaries} fl_lr: {fl_lr}')\n",
61 | " filedata = filedata.replace('fl_total_participants: 100', f'fl_total_participants: {fl_total_participants}')\n",
62 | " filedata = filedata.replace('fl_no_models: 10', f'fl_no_models: {fl_no_models}')\n",
63 | " filedata = filedata.replace('fl_dirichlet_alpha: 0.5', f'fl_dirichlet_alpha: {fl_dirichlet_alpha}')\n",
64 | " filedata = filedata.replace('fl_number_of_adversaries: 4', f'fl_number_of_adversaries: {fl_number_of_adversaries}')\n",
65 | " filedata = filedata.replace('lr: 0.005', f'lr: {fl_lr}')\n",
66 | " # print(len(filedata), type(filedata)) \n",
67 | " # print('------------------------')\n",
68 | " # write the file out again\n",
69 | " fn_write = f'{EXPS_DIR}/{name_exp}_fed_{fl_total_participants}_{fl_no_models}_{fl_number_of_adversaries}_{fl_dirichlet_alpha}_{fl_lr}.yaml'\n",
70 | " if not os.path.exists(fn_write):\n",
71 | " with open(fn_write, 'w') as file:\n",
72 | " file.write(filedata)\n",
73 | " \n",
74 | " cmd = f'python training.py --name {name_exp} --params {fn_write}'\n",
75 | " print(cmd)\n",
76 | " print('------------------------')\n",
77 | "\n",
78 | "current_time = datetime.now().strftime('%Y.%b.%d')\n",
79 | "\n",
80 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/cifar_fed.yaml', \n",
81 | " name_exp = 'cifar10', \n",
82 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_cifar10__{current_time}\")\n",
83 | "\n",
84 | "\n",
85 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/mnist_fed.yaml', \n",
86 | " name_exp = 'mnist', \n",
87 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_mnist__{current_time}\")\n",
88 | "\n",
89 | "generate_exps_file(root_file='/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/imagenet_fed.yaml', \n",
90 | " name_exp = 'tiny-imagenet', \n",
91 | " EXPS_DIR=f\"/home/vishc2/tuannm/fedlearn-backdoor-attacks/exps/run_tiny-imagenet__{current_time}\")\n",
92 | "\n"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": []
101 | }
102 | ],
103 | "metadata": {
104 | "kernelspec": {
105 | "display_name": "cardio",
106 | "language": "python",
107 | "name": "python3"
108 | },
109 | "language_info": {
110 | "codemirror_mode": {
111 | "name": "ipython",
112 | "version": 3
113 | },
114 | "file_extension": ".py",
115 | "mimetype": "text/x-python",
116 | "name": "python",
117 | "nbconvert_exporter": "python",
118 | "pygments_lexer": "ipython3",
119 | "version": "3.7.13"
120 | },
121 | "orig_nbformat": 4
122 | },
123 | "nbformat": 4,
124 | "nbformat_minor": 2
125 | }
126 |
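
One observation on the notebook cell above: `filedata` is overwritten inside the loops, so each combination is rendered from the previous combination's text rather than from the original template. With two choices per key this happens to work, but a third choice would no longer find the default string to replace. A minimal sketch of an alternative follows, assuming the same `key: value` template lines. `render_configs`, `grid`, and the inline `template` are hypothetical names for illustration.

```python
import itertools


def render_configs(template, grid):
    """Yield (suffix, text) for every combination in grid.

    grid maps a key to (default value written in the template, list of choices).
    """
    keys = list(grid)
    for combo in itertools.product(*(grid[key][1] for key in keys)):
        text = template  # always start from the unmodified template
        for key, value in zip(keys, combo):
            default = grid[key][0]
            text = text.replace(f'{key}: {default}', f'{key}: {value}')
        suffix = '_'.join(str(value) for value in combo)
        yield suffix, text


# Stand-in for the contents of a YAML file such as exps/cifar_fed.yaml.
template = "fl_total_participants: 100\nfl_no_models: 10\nlr: 0.005\n"
grid = {
    'fl_total_participants': (100, [100, 200]),
    'fl_no_models': (10, [10]),
    'lr': (0.005, [0.05]),
}

configs = dict(render_configs(template, grid))
for suffix in configs:
    print(suffix)
```

Writing each rendered text to a file and printing the matching `python training.py --name ... --params ...` command can then proceed as in the cell above.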
--------------------------------------------------------------------------------
/test-lib/test.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtuann/fedlearn-backdoor-attacks/b025e763558d1e3d6a5872be7d158e8c157e7bf1/test-lib/test.log
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from helper import Helper
4 | from datetime import datetime
5 | from tqdm import tqdm
6 | import wandb
7 |
8 | from utils.utils import *
9 | logger = logging.getLogger('logger')
10 |
11 | def train(hlpr: Helper, epoch, model, optimizer, train_loader, attack=False, global_model=None):
12 | criterion = hlpr.task.criterion
13 | model.train()
14 | # for i, data in tqdm(enumerate(train_loader)):
15 | for i, data in enumerate(train_loader):
16 | batch = hlpr.task.get_batch(i, data)
17 | model.zero_grad()
18 | loss = hlpr.attack.compute_blind_loss(model, criterion, batch, attack, global_model)
19 | loss.backward()
20 | optimizer.step()
21 | # import IPython; IPython.embed()
22 | # exit(0)
23 | # print(f"Epoch {epoch} batch {i} loss {loss.item()}")
24 |
25 | if i == hlpr.params.max_batch_id:
26 | break
27 | # metric = hlpr.task.report_metrics(epoch,
28 | # prefix=f'Backdoor {str(backdoor):5s}. Epoch: ')
29 | return
30 |
31 | def test(hlpr: Helper, epoch, backdoor=False, model=None):
32 | if model is None:
33 | model = hlpr.task.model
34 | model.eval()
35 | hlpr.task.reset_metrics()
36 | with torch.no_grad():
37 | for i, data in tqdm(enumerate(hlpr.task.test_loader)):
38 | batch = hlpr.task.get_batch(i, data)
39 | if backdoor:
40 | batch = hlpr.attack.synthesizer.make_backdoor_batch(batch,
41 | test=True,
42 | attack=True)
43 |
44 | outputs = model(batch.inputs)
45 | hlpr.task.accumulate_metrics(outputs=outputs, labels=batch.labels)
46 | metric = hlpr.task.report_metrics(epoch,
47 | prefix=f'Backdoor {str(backdoor):5s}. Epoch: ')
48 | return metric
49 |
50 | def run_fl_round(hlpr: Helper, epoch):
51 | global_model = hlpr.task.model
52 | local_model = hlpr.task.local_model
53 | round_participants = hlpr.task.sample_users_for_round(epoch)
54 |
55 | weight_accumulator = hlpr.task.get_empty_accumulator()
56 |
57 | logger.info(f"Round epoch {epoch} with participants: {[user.user_id for user in round_participants]} and weight: {hlpr.params.fl_weight_contribution}")
58 | # log number of sample per user
59 | logger.info(f"Round epoch {epoch} with participants sample size: {[user.number_of_samples for user in round_participants]} and sum: {sum([user.number_of_samples for user in round_participants])}")
60 |
61 | for user in tqdm(round_participants):
62 | hlpr.task.copy_params(global_model, local_model)
63 | optimizer = hlpr.task.make_optimizer(local_model)
64 | if user.compromised:
65 | # if not user.user_id == 0:
66 | # continue
67 |
68 | logger.warning(f"Compromised user: {user.user_id} in run_fl_round {epoch}")
69 | for local_epoch in tqdm(range(hlpr.params.fl_poison_epochs)): # fl_poison_epochs)):
70 | train(hlpr, local_epoch, local_model, optimizer,
71 | user.train_loader, attack=True, global_model=global_model)
72 |
73 | else:
74 | logger.warning(f"Non-compromised user: {user.user_id} in run_fl_round {epoch}")
75 | for local_epoch in range(hlpr.params.fl_local_epochs):
76 | train(hlpr, local_epoch, local_model, optimizer,
77 | user.train_loader, attack=False)
78 |
79 | local_update = hlpr.attack.get_fl_update(local_model, global_model)
80 |
81 | # hlpr.save_update(model=local_update, userID=user.user_id)
82 | # Do not save model to files, save it as a variable
83 | hlpr.task.adding_local_updated_model(local_update = local_update, user_id=user.user_id)
84 |
85 | # if user.compromised:
86 | # hlpr.attack.perform_attack(global_model, user, epoch)
87 | # hlpr.attack.local_dataset = deepcopy(user.train_loader)
88 |
89 | hlpr.defense.aggr(weight_accumulator, global_model, )
90 | # logger.info(f"Round {epoch} update global model")
91 |
92 | hlpr.task.update_global_model(weight_accumulator, global_model)
93 |
94 | def run(hlpr: Helper):
95 | metric = test(hlpr, -1, backdoor=False)
96 | logger.info(f"Before training main metric: {metric}")
97 |
98 | for epoch in range(hlpr.params.start_epoch,
99 | hlpr.params.epochs + 1):
100 | logger.info(f"Communication round {epoch}")
101 | run_fl_round(hlpr, epoch)
102 | metric = test(hlpr, epoch, backdoor=False)
103 | main_metric = hlpr.task.get_metrics()
104 |
105 | # logger.info(f"Epoch {epoch} main metric: {metric}")
106 | metric_bd = test(hlpr, epoch, backdoor=True)
107 | backdoor_metric = hlpr.task.get_metrics()
108 |
109 |
110 | wandb.log({'main_acc': main_metric['accuracy'], 'main_loss': main_metric['loss'],
111 | 'backdoor_acc': backdoor_metric['accuracy'], 'backdoor_loss': backdoor_metric['loss']},
112 | step=epoch)
113 | logger.info(f"Epoch {epoch} backdoor metric: {metric_bd}")
114 |
115 | # hlpr.record_accuracy(metric, test(hlpr, epoch, backdoor=True), epoch)
116 |
117 | hlpr.save_model(hlpr.task.model, epoch, metric)
118 |
119 | def generate_exps_file(root_file='cifar_fed.yaml'):
120 | # read file as a string
121 | with open(root_file, 'r') as file :
122 | filedata = file.read()
123 | print(len(filedata), type(filedata))
124 | pass
125 |
126 | if __name__ == '__main__':
127 | parser = argparse.ArgumentParser(description='Backdoors')
128 | parser.add_argument('--params', dest='params', required=True)
129 | parser.add_argument('--name', dest='name', required=True)
130 | # python training.py --name mnist --params exps/mnist_fed.yaml
131 | # python training.py --name tiny-imagenet-200 --params exps/imagenet_fed.yaml
132 | # python training.py --name cifar10 --params exps/cifar_fed.yaml
133 | args = parser.parse_args()
134 | print(args)
135 | with open(args.params) as f:
136 | params = yaml.load(f, Loader=yaml.FullLoader)
137 | # print(params)
138 | # import IPython; IPython.embed()
139 | pretrained_str = 'pretrained' if params['resume_model'] else 'no_pretrained'
140 |
141 | params['name'] = f'vishc_{args.name}.{params["synthesizer"]}.{params["fl_total_participants"]}_{params["fl_no_models"]}_{params["fl_number_of_adversaries"]}_{params["fl_dirichlet_alpha"]}_{params["lr"]}_{pretrained_str}'
142 |
143 | params['current_time'] = datetime.now().strftime('%Y.%b.%d_%H.%M.%S')
144 | print(params)
145 | # exit(0)
146 | helper = Helper(params)
147 |
148 | # logger = create_logger()
149 |
150 | # logger.info(create_table(params))
151 |
152 | wandb.init(project="benchmark-backdoor-fl", entity="mtuann", name=f"{params['name']}-{params['current_time']}")
153 | try:
154 | run(helper)
155 | except Exception as e:
156 | print(e)
157 |
158 |
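
`run_fl_round` above collects one update per sampled user, lets the defense aggregate them into `weight_accumulator`, and then adds the accumulator to the global model. The sketch below shows that arithmetic on plain floats. It assumes that aggregation is a sample-count-weighted sum using `fl_weight_contribution`, which is not confirmed here because `defenses/fedavg.py` is not shown in this section. `fedavg_round` and the toy dictionaries are hypothetical.

```python
def fedavg_round(global_weights, local_updates, samples_per_user):
    """Apply one weighted-average round to a dict of scalar weights."""
    total = sum(samples_per_user.values())
    # Same ratio as fl_weight_contribution: n_k / sum of n over sampled users.
    contribution = {uid: n / total for uid, n in samples_per_user.items()}

    # Accumulate the weighted updates (the role of weight_accumulator).
    accumulator = {name: 0.0 for name in global_weights}
    for uid, update in local_updates.items():
        for name, delta in update.items():
            accumulator[name] += contribution[uid] * delta

    # Add the aggregated update to the global weights (as update_global_model does).
    return {name: w + accumulator[name] for name, w in global_weights.items()}


global_weights = {'w': 1.0, 'b': 0.0}
local_updates = {0: {'w': 0.5, 'b': 1.0}, 1: {'w': -0.5, 'b': 3.0}}
samples_per_user = {0: 30, 1: 10}

new_weights = fedavg_round(global_weights, local_updates, samples_per_user)
print(new_weights)
```

In the real code the values are tensors from `state_dict()` and some entries are skipped through `ignored_weights`, but the weighting logic would be the same under the stated assumption.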
--------------------------------------------------------------------------------
/utils/parameters.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from dataclasses import dataclass, asdict
3 | from typing import List, Dict
4 | import logging
5 | import torch
6 | logger = logging.getLogger('logger')
7 |
8 | @dataclass
9 | class Params:
10 |
11 | # Corresponds to the class module: tasks.mnist_task.MNISTTask
12 | # See other tasks in the task folder.
13 | task: str = 'MNIST'
14 | num_classes = 10 #TODO: update num_classes for TinyImageNet
15 |
16 | current_time: str = None
17 | name: str = None
18 | random_seed: int = None
19 | device: str = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 | # training params
21 | start_epoch: int = 1
22 | epochs: int = None
23 | poison_epoch: int = None
24 | poison_epoch_stop: int = None
25 | log_interval: int = 1000
26 |
27 | # model arch is usually defined by the task
28 | resume_model: str = None
29 | lr: float = None
30 | decay: float = None
31 | momentum: float = None
32 | optimizer: str = None
33 | # data
34 | data_path: str = '.data/'
35 | batch_size: int = 64
36 | test_batch_size: int = 100
37 | # Whether to apply transformations to the training images.
38 | transform_train: bool = True
39 | # For large datasets stop training earlier.
40 | max_batch_id: int = None
41 | # No need to set, updated by the Task class.
42 | input_shape = None
43 |
44 | # attack params
45 | backdoor: bool = False
46 | backdoor_label: int = 8
47 | poisoning_proportion: float = 1.0 # backdoors proportion in backdoor loss
48 | synthesizer: str = 'pattern'
49 | backdoor_dynamic_position: bool = False
50 |
51 | # factors to balance losses
52 | fixed_scales: Dict[str, float] = None
53 |
54 | # optimizations:
55 | alternating_attack: float = None
56 | clip_batch: float = None
57 | # Disable BatchNorm and Dropout
58 | switch_to_eval: float = None
59 |
60 | # logging
61 | report_train_loss: bool = True
62 | log: bool = False
63 | save_model: bool = None
64 | save_on_epochs: List[int] = None
65 | save_scale_values: bool = False
66 | print_memory_consumption: bool = False
67 | save_timing: bool = False
68 | timing_data = None
69 |
70 | # Temporary storage for running values
71 | running_losses = None
72 | running_scales = None
73 |
74 | # FL params
75 | fl: bool = False
76 | fl_no_models: int = 100
77 | fl_local_epochs: int = 2
78 | fl_poison_epochs: int = None
79 | fl_total_participants: int = 200
80 | fl_eta: int = 1
81 | fl_sample_dirichlet: bool = False
82 | fl_dirichlet_alpha: float = None
83 | fl_diff_privacy: bool = False
84 | # FL attack details. Set the number of adversaries to perform the attack:
85 | fl_number_of_adversaries: int = 0
86 | fl_single_epoch_attack: int = None
87 | fl_weight_scale: int = 1
88 |
89 | fl_round_participants: List[int] = None
90 | fl_weight_contribution: Dict[int, float] = None
91 |
92 | fl_local_updated_models: Dict[int, Dict[str, torch.Tensor]] = None
93 |
94 | attack: str = None #'ThrDFed' (3DFed), 'ModelRplace' (Model Replacement)
95 |
96 | #"Foolsgold", "FLAME", "RFLBAT", "Deepsight", "FLDetector"
97 | defense: str = None
98 | lagrange_step: float = None
99 | random_neurons: List[int] = None
100 | noise_mask_alpha: float = None
101 | fl_adv_group_size: int = 0
102 | fl_num_neurons: int = 0
103 |
104 | def __post_init__(self):
105 | # enable logging anyways when saving statistics
106 | if self.save_model or self.save_timing or \
107 | self.print_memory_consumption:
108 | self.log = True
109 |
110 | if self.log:
111 | self.folder_path = f'saved_models/' \
112 | f'{self.task}__{self.current_time}__{self.name}'
113 |
114 | self.running_losses = defaultdict(list)
115 | self.running_scales = defaultdict(list)
116 | self.timing_data = defaultdict(list)
117 |
118 | def to_dict(self):
119 | return asdict(self)
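
A reduced version of the `Params` dataclass may help show how `__post_init__` derives settings from the loaded YAML values. This is a sketch under assumptions: `MiniParams` and the `config` dictionary are hypothetical, and building the object with `**config` is a guess at how the helper constructs `Params`, since that code is not shown in this section.

```python
from collections import defaultdict
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class MiniParams:
    task: str = 'MNIST'
    name: Optional[str] = None
    current_time: Optional[str] = None
    log: bool = False
    save_model: Optional[bool] = None
    # No type annotation: a plain class attribute, not a dataclass field.
    running_losses = None

    def __post_init__(self):
        # Saving a model implies logging, as in Params.__post_init__.
        if self.save_model:
            self.log = True
        if self.log:
            self.folder_path = (f'saved_models/'
                                f'{self.task}__{self.current_time}__{self.name}')
        self.running_losses = defaultdict(list)


# Stand-in for the dictionary returned by yaml.load in training.py.
config = {'task': 'MNIST', 'name': 'demo',
          'current_time': '2023.Nov.25', 'save_model': True}
params = MiniParams(**config)
print(params.folder_path)
print(asdict(params))
```

Attributes declared without an annotation, such as `running_losses` here or `input_shape` in `Params`, are not dataclass fields, so they cannot be passed to the constructor and do not appear in `asdict`.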
--------------------------------------------------------------------------------
/utils/process_tiny_data.sh:
--------------------------------------------------------------------------------
1 | unzip tiny-imagenet-200.zip -d ./
2 | if [ ! -d ../.data ];then
3 | mkdir ../.data
4 | else
5 | echo '../.data' dir exists
6 | fi
7 |
8 | mv ./tiny-imagenet-200 ../.data/
9 | echo move 'tiny-imagenet-200' dir to '../.data/tiny-imagenet-200'
10 | python tinyimagenet_reformat.py
11 |
--------------------------------------------------------------------------------
/utils/tinyimagenet_reformat.py:
--------------------------------------------------------------------------------
1 | import io
2 | import pandas as pd
3 | import glob
4 | import os
5 | from shutil import move
6 | from os.path import join
7 | from os import listdir, rmdir
8 |
9 | target_folder = '../.data/tiny-imagenet-200/val/'
10 |
11 | val_dict = {}
12 | with open(target_folder + 'val_annotations.txt', 'r') as f:
13 | for line in f.readlines():
14 | split_line = line.split('\t')
15 | val_dict[split_line[0]] = split_line[1]
16 |
17 | paths = glob.glob(target_folder + 'images/*')
18 | paths[0].split('/')[-1]
19 | for path in paths:
20 | file = path.split('/')[-1]
21 | file = file.split('\\')[-1]
22 | folder = val_dict[file]
23 | if not os.path.exists(target_folder + str(folder)):
24 | os.mkdir(target_folder + str(folder))
25 |
26 | for path in paths:
27 | file = path.split('/')[-1]
28 | file = file.split('\\')[-1]
29 | folder = val_dict[file]
30 | dest = target_folder + str(folder) + '/' + str(file)
31 | move(path, dest)
32 |
33 | os.remove('../.data/tiny-imagenet-200/val/val_annotations.txt')
34 | rmdir('../.data/tiny-imagenet-200/val/images')
35 | print('done reformatting the validation images')
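
The script above moves each validation image into a folder named after its class, using `val_annotations.txt` as the lookup table. The same idea can be sketched with `pathlib` so it does not depend on splitting paths by `/` or `\`. `reformat_val` is a hypothetical function, and the demo runs on a throwaway directory with made-up file contents rather than on the real dataset.

```python
import shutil
import tempfile
from pathlib import Path


def reformat_val(val_dir):
    """Move val/images/<file> to val/<class_id>/<file>; return the number moved."""
    val_dir = Path(val_dir)
    moved = 0
    with open(val_dir / 'val_annotations.txt') as f:
        for line in f:
            parts = line.rstrip('\n').split('\t')
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            file_name, class_id = parts[0], parts[1]
            src = val_dir / 'images' / file_name
            if not src.exists():
                continue
            dest_dir = val_dir / class_id
            dest_dir.mkdir(exist_ok=True)
            shutil.move(str(src), str(dest_dir / file_name))
            moved += 1
    return moved


# Demo on a temporary directory that mimics the val/ layout.
with tempfile.TemporaryDirectory() as tmp:
    val = Path(tmp) / 'val'
    (val / 'images').mkdir(parents=True)
    (val / 'images' / 'val_0.JPEG').write_text('x')
    (val / 'images' / 'val_1.JPEG').write_text('y')
    (val / 'val_annotations.txt').write_text(
        'val_0.JPEG\tn01443537\t0\t0\t1\t1\n'
        'val_1.JPEG\tn01629819\t0\t0\t1\t1\n')
    moved = reformat_val(val)
    layout = sorted(p.relative_to(val).as_posix() for p in val.rglob('*.JPEG'))

print(moved, layout)
```

Unlike the script above, this sketch leaves the annotations file and the emptied `images` folder in place, so it can be rerun safely.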
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 |
5 | import colorlog
6 | import torch
7 |
8 | from utils.parameters import Params
9 |
10 | def record_time(params: Params, t=None, name=None):
11 | if t and name and (params.save_timing == name or params.save_timing is True):
12 | torch.cuda.synchronize()
13 | params.timing_data[name].append(round(1000 * (time.perf_counter() - t)))
14 |
15 | def create_table(params: dict):
16 | data = f"\n| {'name' + ' ' * 21} | value | \n|{'-'*27}|----------|"
17 |
18 | for key, value in params.items():
19 | # data += '\n' + f"| {(25 - len(key)) * ' ' }{key} | {value} |"
20 | data += f"\n| {key: <25} | {value} |"
21 |
22 | return data
23 |
24 | def create_logger():
25 | """
26 | Setup the logging environment
27 | """
28 | log = logging.getLogger() # root logger
29 | log.setLevel(logging.DEBUG)
30 | format_str = '%(asctime)s - %(filename)s - Line:%(lineno)d - %(levelname)-8s - %(message)s'
31 | date_format = '%Y-%m-%d %H:%M:%S'
32 | if os.isatty(2):
33 | cformat = '%(log_color)s' + format_str
34 | colors = {'DEBUG': 'bold_blue',
35 | 'INFO': 'bold_green',
36 | 'WARNING': 'bold_yellow',
37 | 'ERROR': 'bold_red',
38 | 'CRITICAL': 'bold_red'}
39 | formatter = colorlog.ColoredFormatter(cformat, date_format,
40 | log_colors=colors)
41 | else:
42 | formatter = logging.Formatter(format_str, date_format)
43 | stream_handler = logging.StreamHandler()
44 | stream_handler.setFormatter(formatter)
45 | log.addHandler(stream_handler)
46 | return logging.getLogger(__name__)
47 |
--------------------------------------------------------------------------------