├── utils ├── apple.png ├── __init__.py ├── watermark.png ├── __pycache__ │ ├── info.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── defense.cpython-39.pyc │ ├── options.cpython-39.pyc │ └── sampling.cpython-39.pyc ├── info.py ├── .ipynb_checkpoints │ ├── info-checkpoint.py │ ├── options-checkpoint.py │ └── defense-checkpoint.py ├── sampling.py ├── options.py └── defense.py ├── data ├── iid_cifar.npy ├── non_iid_cifar.npy ├── iid_fashion_mnist.npy └── non_iid_fashion_mnist.npy ├── models ├── __init__.py ├── __pycache__ │ ├── Fed.cpython-39.pyc │ ├── Nets.cpython-39.pyc │ ├── test.cpython-39.pyc │ ├── Update.cpython-39.pyc │ ├── Attacker.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── resnet20.cpython-39.pyc │ ├── add_trigger.cpython-39.pyc │ ├── subnetutils.cpython-39.pyc │ ├── AttackerUtils.cpython-39.pyc │ └── MaliciousUpdate.cpython-39.pyc ├── Fed.py ├── .ipynb_checkpoints │ ├── Fed-checkpoint.py │ ├── resnet20-checkpoint.py │ ├── add_trigger-checkpoint.py │ ├── test-checkpoint.py │ ├── Attacker-checkpoint.py │ ├── MaliciousUpdate-checkpoint.py │ └── Nets-checkpoint.py ├── resnet20.py ├── add_trigger.py ├── test.py ├── Attacker.py ├── MaliciousUpdate.py ├── Nets.py └── Update.py ├── save ├── test_trigger2.png ├── a_info.txt ├── avg_VGG_noniid_mode10_ada │ ├── a_info.txt │ ├── accuracy_file_cifar_VGG_avg_1674543995_adaptive_0.1malicious_1.0poisondata_mode10.txt │ ├── accuracy_file_cifar_VGG_avg_1674544149_adaptive_0.1malicious_1.0poisondata_mode10.txt │ ├── accuracy_file_cifar_VGG_avg_1674545743_adaptive_0.1malicious_1.0poisondata_mode10.txt │ └── accuracy_file_cifar_VGG_avg_1674546108_badnet_0.1malicious_1.0poisondata.txt ├── cnn_multikrum_LPattack │ ├── a_info.txt │ ├── accuracy_file_fashion_mnist_rlr_mnist_multikrum_1674546908_adaptive_0.1malicious_0.5poisondata_mode10.txt │ └── accuracy_file_fashion_mnist_rlr_mnist_multikrum_1674546843_adaptive_0.1malicious_0.5poisondata_mode10.txt ├── ResNet18_LPattack_FLAME │ ├── a_info.txt │ └── 
accuracy_file_cifar_resnet_flame_1674546495_adaptive_0.1malicious_0.5poisondata_mode10.txt └── accuracy_file_cifar_resnet_fld_1678260074_no_malicious.txt ├── README.md └── main_fed.py /utils/apple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/apple.png -------------------------------------------------------------------------------- /data/iid_cifar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/data/iid_cifar.npy -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @python: 3.6 4 | 5 | -------------------------------------------------------------------------------- /utils/watermark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/watermark.png -------------------------------------------------------------------------------- /data/non_iid_cifar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/data/non_iid_cifar.npy -------------------------------------------------------------------------------- /save/test_trigger2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/save/test_trigger2.png 
-------------------------------------------------------------------------------- /data/iid_fashion_mnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/data/iid_fashion_mnist.npy -------------------------------------------------------------------------------- /data/non_iid_fashion_mnist.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/data/non_iid_fashion_mnist.npy -------------------------------------------------------------------------------- /models/__pycache__/Fed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/Fed.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/Nets.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/Nets.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/test.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/test.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/info.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/__pycache__/info.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/Update.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/Update.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/defense.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/__pycache__/defense.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/__pycache__/options.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/sampling.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/utils/__pycache__/sampling.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/Attacker.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/Attacker.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/resnet20.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/resnet20.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/add_trigger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/add_trigger.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/subnetutils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/subnetutils.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/AttackerUtils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/AttackerUtils.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/MaliciousUpdate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhmzm/FLDetector_pytorch/HEAD/models/__pycache__/MaliciousUpdate.cpython-39.pyc -------------------------------------------------------------------------------- /save/a_info.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: cifar 4 | Model: resnet 5 | Model Init: None 6 | Aggregation Function: fld 7 | -----No Attack----- 8 | Number of agents: 100 9 | Fraction of 
def FedAvg(w):
    """Return the element-wise average of a list of model state dicts.

    Args:
        w: non-empty sequence of mappings from parameter name to tensor.
            Every mapping must contain each key of ``w[0]``.

    Returns:
        A new mapping with the keys of ``w[0]``; each value is the sum of the
        corresponding tensors divided by ``len(w)``.  The inputs are left
        untouched.
    """
    # Deep copy so the in-place accumulation below never touches w[0].
    w_avg = copy.deepcopy(w[0])
    n_models = len(w)
    for k in w_avg.keys():
        for i in range(1, n_models):
            update = w[i][k]
            try:
                w_avg[k] += update
            except RuntimeError:
                # The in-place add is rejected when the accumulator cannot hold
                # the result (e.g. integer accumulator, float update).  Cast a
                # temporary instead of writing the cast tensor back into the
                # caller's state dict.
                print("Fed.py line17 type_as")
                w_avg[k] += update.type_as(w_avg[k])
        w_avg[k] = torch.div(w_avg[k], n_models)
    return w_avg
11 | Backdoor From -1 to 5 12 | Attack Begin: 0 13 | Trigger Shape: square 14 | Trigger Position X: 27 15 | Trigger Position Y: 27 16 | Number of agents: 100 17 | Fraction of agents each turn: 10(10.0%) 18 | Local batch size: 64 19 | Local epoch: 2 20 | Client_LR: 0.1 21 | Client_Momentum: 0.9 22 | Global Rounds: 200 23 | ====================================== 24 | -------------------------------------------------------------------------------- /save/cnn_multikrum_LPattack/a_info.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: fashion_mnist 4 | Model: cnn 5 | Model Init: None 6 | Aggregation Function: multikrum 7 | Attack method: LPattack 8 | Attack tau: 0.8 9 | Fraction of malicious agents: 10.0% 10 | Poison Frac: 0.5 11 | Backdoor From -1 to 5 12 | Attack Begin: 0 13 | Trigger Shape: square 14 | Trigger Position X: 23 15 | Trigger Position Y: 23 16 | Number of agents: 100 17 | Fraction of agents each turn: 10(10.0%) 18 | Local batch size: 64 19 | Local epoch: 2 20 | Client_LR: 0.01 21 | Client_Momentum: 0.9 22 | Global Rounds: 200 23 | ====================================== 24 | -------------------------------------------------------------------------------- /save/ResNet18_LPattack_FLAME/a_info.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: cifar 4 | Model: resnet 5 | Model Init: None 6 | Aggregation Function: flame 7 | Attack method: LPattack 8 | Attack tau: 0.8 9 | Fraction of malicious agents: 10.0% 10 | Poison Frac: 0.5 11 | Backdoor From -1 to 5 12 | Attack Begin: 0 13 | Trigger Shape: square 14 | Trigger Position X: 27 15 | Trigger Position Y: 27 16 | Number of agents: 100 17 | Fraction of agents each turn: 10(10.0%) 18 | Local batch size: 64 19 | Local epoch: 2 20 | Client_LR: 0.1 21 | Client_Momentum: 0.9 22 | Global Rounds: 200 23 | Noise in 
FLAME: 0.001 24 | ====================================== 25 | -------------------------------------------------------------------------------- /save/accuracy_file_cifar_resnet_fld_1678260074_no_malicious.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: cifar 4 | Model: resnet 5 | Model Init: None 6 | Aggregation Function: fld 7 | -----No Attack----- 8 | Number of agents: 100 9 | Fraction of agents each turn: 100(100%) 10 | Local batch size: 50 11 | Local epoch: 3 12 | Client_LR: 0.01 13 | Client_Momentum: 0.9 14 | Global Rounds: 500 15 | ====================================== 16 | main_task_accuracy=[0.0001, 10.390000343322754, 29.579999923706055, 41.40999984741211, 43.209999084472656, 48.61000061035156, 51.11000061035156, 54.81999969482422, 56.63999938964844, 57.869998931884766] 17 | backdoor_accuracy=[0, 0.0, 11.5, 19.8, 13.6, 13.2, 9.6, 10.1, 9.1, 5.7] -------------------------------------------------------------------------------- /save/avg_VGG_noniid_mode10_ada/accuracy_file_cifar_VGG_avg_1674543995_adaptive_0.1malicious_1.0poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 0 3 | Dataset: cifar 4 | Model: VGG 5 | Model Init: None 6 | Aggregation Function: avg 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 1.0 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 27 16 | Trigger Position Y: 27 17 | Number of agents: 100 18 | Fraction of agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.1 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | ====================================== 25 | main_task_accuracy=[0.0001, 10.0, 10.0] 26 | backdoor_accuracy=[0, 0.0, 0.0] 
-------------------------------------------------------------------------------- /save/avg_VGG_noniid_mode10_ada/accuracy_file_cifar_VGG_avg_1674544149_adaptive_0.1malicious_1.0poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 0 3 | Dataset: cifar 4 | Model: VGG 5 | Model Init: None 6 | Aggregation Function: avg 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 1.0 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 27 16 | Trigger Position Y: 27 17 | Number of agents: 100 18 | Fraction of agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.1 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | ====================================== 25 | main_task_accuracy=[0.0001, 10.0, 10.0] 26 | backdoor_accuracy=[0, 0.0, 0.0] -------------------------------------------------------------------------------- /save/avg_VGG_noniid_mode10_ada/accuracy_file_cifar_VGG_avg_1674545743_adaptive_0.1malicious_1.0poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 0 3 | Dataset: cifar 4 | Model: VGG 5 | Model Init: None 6 | Aggregation Function: avg 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 1.0 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 27 16 | Trigger Position Y: 27 17 | Number of agents: 100 18 | Fraction of agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.1 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | ====================================== 25 | main_task_accuracy=[0.0001, 10.0, 10.0, 10.0, 15.25, 16.170000076293945, 
12.600000381469727] 26 | backdoor_accuracy=[0, 100.0, 0.0, 0.0, 0.0, 0.0, 0.0] -------------------------------------------------------------------------------- /save/avg_VGG_noniid_mode10_ada/accuracy_file_cifar_VGG_avg_1674546108_badnet_0.1malicious_1.0poisondata.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 0 3 | Dataset: cifar 4 | Model: VGG 5 | Model Init: None 6 | Aggregation Function: avg 7 | Attack method: badnet 8 | Attack tau: 0.8 9 | Fraction of malicious agents: 10.0% 10 | Poison Frac: 1.0 11 | Backdoor From -1 to 5 12 | Attack Begin: 0 13 | Trigger Shape: square 14 | Trigger Position X: 27 15 | Trigger Position Y: 27 16 | Number of agents: 100 17 | Fraction of agents each turn: 10(10.0%) 18 | Local batch size: 64 19 | Local epoch: 2 20 | Client_LR: 0.1 21 | Client_Momentum: 0.9 22 | Global Rounds: 200 23 | ====================================== 24 | main_task_accuracy=[0.0001, 10.0, 10.0, 10.0, 17.15999984741211, 17.5, 12.09000015258789, 16.040000915527344, 17.389999389648438, 17.510000228881836] 25 | backdoor_accuracy=[0, 0.0, 0.0, 0.0, 0.0, 0.0, 15.422222222222222, 37.08888888888889, 53.56666666666667, 30.7] -------------------------------------------------------------------------------- /save/cnn_multikrum_LPattack/accuracy_file_fashion_mnist_rlr_mnist_multikrum_1674546908_adaptive_0.1malicious_0.5poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: fashion_mnist 4 | Model: rlr_mnist 5 | Model Init: None 6 | Aggregation Function: multikrum 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 0.5 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 23 16 | Trigger Position Y: 23 17 | Number of agents: 100 18 | Fraction of 
agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.01 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | proportion of malicious are selected:0.2 25 | Average score of malicious clients: 2.815392017364502 26 | Average score of benign clients: 1.0735679864883423 27 | ====================================== 28 | main_task_accuracy=[0.0001, 60.689998626708984] 29 | backdoor_accuracy=[0, 0.011111111111111112] -------------------------------------------------------------------------------- /save/cnn_multikrum_LPattack/accuracy_file_fashion_mnist_rlr_mnist_multikrum_1674546843_adaptive_0.1malicious_0.5poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | ====================================== 2 | IID: 1 3 | Dataset: fashion_mnist 4 | Model: rlr_mnist 5 | Model Init: None 6 | Aggregation Function: multikrum 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 0.5 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 23 16 | Trigger Position Y: 23 17 | Number of agents: 100 18 | Fraction of agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.01 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | proportion of malicious are selected:0.25 25 | Average score of malicious clients: 3.812267303466797 26 | Average score of benign clients: 2.0707249641418457 27 | ====================================== 28 | main_task_accuracy=[0.0001, 54.81999969482422, 66.41000366210938] 29 | backdoor_accuracy=[0, 1.9555555555555555, 0.28888888888888886] -------------------------------------------------------------------------------- /save/ResNet18_LPattack_FLAME/accuracy_file_cifar_resnet_flame_1674546495_adaptive_0.1malicious_0.5poisondata_mode10.txt: -------------------------------------------------------------------------------- 1 | 
====================================== 2 | IID: 1 3 | Dataset: cifar 4 | Model: resnet 5 | Model Init: None 6 | Aggregation Function: flame 7 | Attack method: adaptive 8 | Attack mode: 10 9 | Attack tau: 0.8 10 | Fraction of malicious agents: 10.0% 11 | Poison Frac: 0.5 12 | Backdoor From -1 to 5 13 | Attack Begin: 0 14 | Trigger Shape: square 15 | Trigger Position X: 27 16 | Trigger Position Y: 27 17 | Number of agents: 100 18 | Fraction of agents each turn: 10(10.0%) 19 | Local batch size: 64 20 | Local epoch: 2 21 | Client_LR: 0.1 22 | Client_Momentum: 0.9 23 | Global Rounds: 200 24 | Noise in FLAME: 0.001 25 | proportion of malicious are selected:0.3333333333333333 26 | proportion of benign are selected:0.6296296296296297 27 | ====================================== 28 | main_task_accuracy=[0.0001, 10.0, 16.670000076293945, 26.5, 32.959999084472656, 34.380001068115234, 36.31999969482422] 29 | backdoor_accuracy=[0, 0.0, 38.17777777777778, 1.5666666666666667, 10.21111111111111, 8.155555555555555, 16.8] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLDetector_pytorch 2 | **Unofficial implementation** for paper FLDetector: Defending Federated Learning Against Model Poisoning Attacks via Detecting Malicious Clients (KDD2022). Official implementation is [here](https://github.com/zaixizhang/FLDetector) with the MXNet framework. 3 | 4 | The code cannot work well with local SGD updates. So run this code with local epoch set to 1 (local_ep = 1) and local batch size set to the same number of samples in local dataset (local_bs = 500, 600). 5 | 6 | paper FLDetector: Defending Federated Learning Against Model Poisoning Attacks via Detecting Malicious Clients is from [KDD2022](https://dl.acm.org/doi/abs/10.1145/3534678.3539231) 7 | 8 | Feel free to contact me if you have any difficulty running the code in the issue. 
9 | 10 | # Backdoor in FL 11 | 12 | **Our recent paper "Backdoor Federated Learning by Poisoning Backdoor-critical Layers" has been accepted at ICLR'24; please refer to the [Github repo](https://github.com/zhmzm/Poisoning_Backdoor-critical_Layers_Attack).** 13 | 14 | # Results 15 | The results in this version are a bit different from the results reported in the original paper, especially in Non-iid settings. Please use it with discretion and let me know if there is any problem. Here ASR indicates attack success rate (also called backdoor success rate), and Acc indicates accuracy of the main task. 16 | |Dataset|Model|Attack|Defence|ASR|Acc|iid| 17 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 18 | |CIFAR-10|ResNet18|Badnet|No Defence|70.2|80.38|IID| 19 | |CIFAR-10|ResNet18|Badnet|FLDetector|4.43|68.54|IID| 20 | |CIFAR-10|ResNet18|Badnet|No Defence|70.53|77.58|Non-IID| 21 | |CIFAR-10|ResNet18|Badnet|FLDetector|5.21|64.39|Non-IID| 22 | 23 | # Requirement 24 | Python=3.9 25 | 26 | pytorch=1.10.1 27 | 28 | scikit-learn=1.0.2 29 | 30 | opencv-python=4.5.5.64 31 | 32 | Scikit-Image=0.19.2 33 | 34 | matplotlib=3.4.3 35 | 36 | hdbscan=0.8.28 37 | 38 | jupyterlab=3.3.2 39 | 40 | Install instructions are recorded in install_requirements.sh 41 | 42 | # Run 43 | VGG and ResNet18 can only be trained on the CIFAR-10 dataset, while CNN can only be trained on the fashion-MNIST dataset. 44 | 45 | Quick start: 46 | ``` 47 | python main_fed.py --defence fld --model resnet --dataset cifar --local_ep 1 --local_bs 500 --attack badnet --triggerX 27 --triggerY 27 --epochs 500 --poison_frac 0.5 48 | ``` 49 | It takes more than 10 GPU hours to run this program.
def _weights_init(m):
    """Apply Kaiming-normal initialisation to Linear and Conv2d weights."""
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module``."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        option: shortcut used when the input and output shapes differ.
            ``'A'`` subsamples spatially and zero-pads the channel axis (no
            parameters); ``'B'`` uses a strided 1x1 convolution followed by
            batch norm.

    Raises:
        ValueError: if a non-identity shortcut is needed and ``option`` is
            neither ``'A'`` nor ``'B'``, or if option ``'A'`` is asked to
            reduce the channel count.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 ResNet paper uses option A.
                if planes < in_planes:
                    raise ValueError(
                        "option 'A' shortcut cannot reduce channels "
                        f"(in_planes={in_planes}, planes={planes})"
                    )
                # Subsample with the block's own stride and pad exactly the
                # missing channels.  For stride 2 with doubled channels this
                # equals the previous hard-coded ``::2`` / ``planes // 4``.
                extra = planes - in_planes
                pad_front = extra // 2
                pad_back = extra - pad_front
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::stride, ::stride],
                        (0, 0, 0, 0, pad_front, pad_back),
                        "constant",
                        0,
                    )
                )
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Fail at construction instead of with a shape mismatch in
                # ``forward``.
                raise ValueError(f"unknown shortcut option: {option!r}")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """Three-stage residual network (16/32/64 channels) with a linear head.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # Width of the last stage is 64 * expansion (64 for BasicBlock).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool over the full feature map; using both spatial sizes keeps this
        # a global average for non-square inputs too.
        out = F.avg_pool2d(out, (out.size(2), out.size(3)))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(num_classes=10):
    """Build a ResNet with 3 BasicBlocks per stage."""
    return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes)


def resnet32(num_classes=10):
    """Build a ResNet with 5 BasicBlocks per stage."""
    return ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes)


def resnet44(num_classes=10):
    """Build a ResNet with 7 BasicBlocks per stage."""
    return ResNet(BasicBlock, [7, 7, 7], num_classes=num_classes)


def resnet56(num_classes=10):
    """Build a ResNet with 9 BasicBlocks per stage."""
    return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes)


def resnet110(num_classes=10):
    """Build a ResNet with 18 BasicBlocks per stage."""
    return ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes)


def resnet1202(num_classes=10):
    """Build a ResNet with 200 BasicBlocks per stage."""
    return ResNet(BasicBlock, [200, 200, 200], num_classes=num_classes)
40 | """ 41 | self.shortcut = LambdaLayer(lambda x: 42 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 43 | elif option == 'B': 44 | self.shortcut = nn.Sequential( 45 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 46 | nn.BatchNorm2d(self.expansion * planes) 47 | ) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(self.conv1(x))) 51 | out = self.bn2(self.conv2(out)) 52 | out += self.shortcut(x) 53 | out = F.relu(out) 54 | return out 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__(self, block, num_blocks, num_classes=10): 59 | super(ResNet, self).__init__() 60 | self.in_planes = 16 61 | 62 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(16) 64 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 65 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 66 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 67 | self.linear = nn.Linear(64, num_classes) 68 | 69 | self.apply(_weights_init) 70 | 71 | def _make_layer(self, block, planes, num_blocks, stride): 72 | strides = [stride] + [1]*(num_blocks-1) 73 | layers = [] 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, stride)) 76 | self.in_planes = planes * block.expansion 77 | 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = F.avg_pool2d(out, out.size()[3]) 86 | out = out.view(out.size(0), -1) 87 | out = self.linear(out) 88 | return out 89 | 90 | 91 | def resnet20(): 92 | return ResNet(BasicBlock, [3, 3, 3]) 93 | 94 | 95 | def resnet32(): 96 | return ResNet(BasicBlock, [5, 5, 5]) 97 | 98 | 99 | def resnet44(): 100 | return ResNet(BasicBlock, [7, 7, 7]) 101 | 102 | 103 | def resnet56(): 104 | return ResNet(BasicBlock, [9, 9, 9]) 105 | 106 | 
def information(args):
    """Build the human-readable experiment summary as a list of text lines.

    Args:
        args: Parsed experiment options. Only the attributes relevant to the
            configured attack (``args.malicious``, ``args.attack``,
            ``args.trigger``) and defence (``args.defence``) are read.

    Returns:
        list[str]: One entry per summary line, framed by two separator lines.
    """
    def ratio(count, total):
        # A zero denominator (e.g. no malicious client is sampled per round)
        # used to raise ZeroDivisionError; report it explicitly instead.
        return str(count / total) if total else 'n/a'

    # Clients sampled per round, and how many of them are malicious.
    selected = max(int(args.frac * args.num_users), 1)
    num_malicious = int(args.malicious * selected)

    info = []
    info.append('======================================')
    info.append(f' IID: {args.iid}')
    info.append(f' Dataset: {args.dataset}')
    info.append(f' Model: {args.model}')
    info.append(f' Model Init: {args.init}')
    info.append(f' Aggregation Function: {args.defence}')
    if not math.isclose(args.malicious, 0):
        info.append(f' Attack method: {args.attack}')
        if 'adaptive' in args.attack:
            info.append(f' Attack mode: {args.ada_mode}')
            info.append(f' Attack tau: {args.tau}')
        info.append(f' Fraction of malicious agents: {args.malicious*100}%')
        info.append(f' Poison Frac: {args.poison_frac}')
        info.append(f' Backdoor From {args.attack_goal} to {args.attack_label}')
        info.append(f' Attack Begin: {args.attack_begin}')
        info.append(f' Trigger Shape: {args.trigger}')
        if args.trigger == 'square' or args.trigger == 'pattern':
            info.append(f' Trigger Position X: {args.triggerX}')
            info.append(f' Trigger Position Y: {args.triggerY}')
    else:
        info.append(f' -----No Attack-----')

    info.append(f' Number of agents: {args.num_users}')
    info.append(f' Fraction of agents each turn: {int(args.num_users*args.frac)}({args.frac*100}%)')
    info.append(f' Local batch size: {args.local_bs}')
    info.append(f' Local epoch: {args.local_ep}')
    info.append(f' Client_LR: {args.lr}')
    info.append(f' Client_Momentum: {args.momentum}')
    info.append(f' Global Rounds: {args.epochs}')
    if args.defence == 'RLR':
        info.append(f' RobustLR_threshold: {args.robustLR_threshold}')
    elif args.defence == 'fltrust' or args.defence == 'fltrust_bn':
        info.append(f' Dataset In Server: {args.server_dataset}')
    elif args.defence == 'flame' or args.defence == 'flame2':
        info.append(f' Noise in FLAME: {args.noise}')
        if args.turn != 0:
            info.append('proportion of malicious are selected:'
                        + ratio(args.wrong_mal, num_malicious * args.turn))
            info.append('proportion of benign are selected:'
                        + ratio(args.right_ben, (selected - num_malicious) * args.turn))
    elif args.defence == 'krum' or args.defence == 'multikrum':
        if args.turn != 0 and args.malicious != 0:
            score_mal = args.mal_score / args.turn
            # NOTE(review): 9 looks like the benign count per round for the
            # default 10-clients/1-malicious setup; confirm before generalising.
            score_ben = args.ben_score / (args.turn * 9)
            info.append('proportion of malicious are selected:'
                        + ratio(args.wrong_mal, num_malicious * args.turn))
            info.append(f' Average score of malicious clients: {score_mal}')
            info.append(f' Average score of benign clients: {score_ben}')
    info.append('======================================')
    return info
def write_info(args, info):
    """Write the summary lines to ``./<args.save>/a_info.txt``, one per line.

    Args:
        args: Options object; only ``args.save`` (output directory name,
            relative to the working directory) is read.
        info: Iterable of strings to write.

    The file is overwritten if it already exists.
    """
    path = "./" + args.save + '/' + "a_info.txt"
    # The context manager closes the handle even if a write fails.
    with open(path, "w") as f:
        f.writelines(line + '\n' for line in info)
def get_base_info(args, timestamp=None):
    """Build the base name used for an experiment's result files.

    The name is ``<dataset>_<model>_<defence>[_<robustLR_threshold>]_<timestamp>``
    followed by either the attack description or ``_no_malicious``.

    Args:
        args: Parsed experiment options.
        timestamp: Integer to embed in the name. Defaults to the current
            Unix time, which is what the function always used before.

    Returns:
        str: The base file name (without extension).
    """
    if timestamp is None:
        timestamp = int(time.time())

    parts = [args.dataset, args.model, args.defence]
    if args.defence == 'RLR':
        # Only RLR runs carry their threshold in the file name.
        parts.append(args.robustLR_threshold)
    parts.append(timestamp)
    base_info = '_'.join(str(part) for part in parts)

    if not math.isclose(args.malicious, 0):
        base_info += '_{}_{}malicious_{}poisondata'.format(
            args.attack, args.malicious, args.poison_frac)
        if 'adaptive' in args.attack:
            base_info += '_mode{}'.format(args.ada_mode)
    else:
        base_info += '_no_malicious'
    return base_info
2:args.triggerY + 5, args.triggerX + 2:args.triggerX + 5] = pixel_max 24 | args.save_img(image) 25 | return image 26 | if args.attack == 'dba' and test == True: 27 | size = 6 28 | gap = 3 29 | shift = 0 30 | image[:, args.triggerY + 0:args.triggerY + 2, args.triggerX + 0:args.triggerX + 2] = pixel_max 31 | image[:, args.triggerY + 0:args.triggerY + 2, args.triggerX + 2:args.triggerX + 5] = pixel_max 32 | image[:, args.triggerY + 2:args.triggerY + 5, args.triggerX + 0:args.triggerX + 2] = pixel_max 33 | image[:, args.triggerY + 2:args.triggerY + 5, args.triggerX + 2:args.triggerX + 5] = pixel_max 34 | # image[:, args.triggerY + 0:args.triggerY + 2, args.triggerX + 0:args.triggerX + size] = pixel_max 35 | # image[:, args.triggerY + 0:args.triggerY + 2, args.triggerX+size+gap:args.triggerX +size+gap+size] = pixel_max 36 | 37 | # image[:, args.triggerY + 2+gap:args.triggerY + 2+gap+2, args.triggerX + 0:args.triggerX + size] = pixel_max 38 | # image[:, args.triggerY + 2+gap:args.triggerY + 2+gap+2, args.triggerX +size+gap:args.triggerX +size+gap+size] = pixel_max 39 | return image 40 | if args.trigger == 'square': 41 | pixel_max = torch.max(image) if torch.max(image) > 1 else 1 42 | # 2022年6月10日 change 43 | if args.dataset == 'cifar': 44 | pixel_max = 1 45 | image[:, args.triggerY:args.triggerY + 5, args.triggerX:args.triggerX + 5] = pixel_max 46 | elif args.trigger == 'pattern': 47 | pixel_max = torch.max(image) if torch.max(image) > 1 else 1 48 | image[:, args.triggerY + 0, args.triggerX + 0] = pixel_max 49 | image[:, args.triggerY + 1, args.triggerX + 1] = pixel_max 50 | image[:, args.triggerY - 1, args.triggerX + 1] = pixel_max 51 | image[:, args.triggerY + 1, args.triggerX - 1] = pixel_max 52 | elif args.trigger == 'watermark': 53 | if args.watermark is None: 54 | args.watermark = cv2.imread('./utils/watermark.png', cv2.IMREAD_GRAYSCALE) 55 | args.watermark = cv2.bitwise_not(args.watermark) 56 | args.watermark = cv2.resize(args.watermark, dsize=image[0].shape, 
# (row range, column range), relative to (triggerY, triggerX), of the sub-patch
# each DBA party stamps during training. Together the four parts tile the 5x5
# square that is stamped at test time.
_DBA_PARTS = {
    0: ((0, 2), (0, 2)),
    1: ((0, 2), (2, 5)),
    2: ((2, 5), (0, 2)),
    3: ((2, 5), (2, 5)),
}


def _load_overlay(path, image):
    """Load a grayscale overlay, invert it, and fit it to ``image``.

    The overlay is resized to the image's spatial size, normalised to [0, 1]
    and then scaled to the image's intensity range (its maximum if > 1).
    Returns a float64 NumPy array of shape (H, W).
    """
    overlay = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if overlay is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f'cannot read trigger image: {path}')
    overlay = cv2.bitwise_not(overlay)
    height, width = image.shape[-2:]
    # cv2.resize expects dsize as (width, height), not (rows, cols).
    overlay = cv2.resize(overlay, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
    overlay = overlay.astype(np.float64) / np.max(overlay)
    image_max = torch.max(image).item()
    overlay *= image_max if image_max > 1 else 1
    return overlay


def _blend_overlay(image, overlay):
    """Add ``overlay`` to ``image`` in place, capping at the larger of their maxima."""
    # Convert to a tensor on the image's device so GPU batches work too.
    mask = torch.as_tensor(overlay, dtype=image.dtype, device=image.device)
    ceiling = max(mask.max().item(), image.max().item())
    image += mask
    image.clamp_(max=ceiling)
    return image


def add_trigger(args, image, test=False):
    """Stamp the configured backdoor trigger onto ``image`` in place.

    Args:
        args: Options object. Reads ``attack``, ``trigger``, ``dataset``,
            ``triggerX``/``triggerY``, and depending on the branch
            ``dba_class``, ``save_img``, ``watermark`` or ``apple``. The
            loaded watermark/apple overlay is cached back on ``args``.
        image: Tensor indexed as ``[channel, row, column]``; modified in place.
        test: For the ``'dba'`` attack, ``True`` stamps the full 5x5 trigger
            and ``False`` stamps only the part assigned to ``args.dba_class``.

    Returns:
        The same ``image`` tensor.
    """
    # Trigger intensity: the image maximum, but never below 1.
    pixel_max = max(1, torch.max(image))
    x, y = args.triggerX, args.triggerY

    if args.attack == 'dba':
        if test:
            # The four DBA parts together cover this contiguous 5x5 square.
            image[:, y:y + 5, x:x + 5] = pixel_max
            return image
        for dba_class, ((r0, r1), (c0, c1)) in _DBA_PARTS.items():
            if args.dba_class == dba_class:
                image[:, y + r0:y + r1, x + c0:x + c1] = pixel_max
                break
        args.save_img(image)
        return image

    if args.trigger == 'square':
        # CIFAR always uses intensity 1, whatever the image maximum is.
        fill = 1 if args.dataset == 'cifar' else pixel_max
        image[:, y:y + 5, x:x + 5] = fill
    elif args.trigger == 'pattern':
        # NOTE: with triggerX/triggerY == 0 the "- 1" offsets wrap around to
        # the last column/row (negative indexing).
        image[:, y + 0, x + 0] = pixel_max
        image[:, y + 1, x + 1] = pixel_max
        image[:, y - 1, x + 1] = pixel_max
        image[:, y + 1, x - 1] = pixel_max
    elif args.trigger == 'watermark':
        if args.watermark is None:
            args.watermark = _load_overlay('./utils/watermark.png', image)
        _blend_overlay(image, args.watermark)
    elif args.trigger == 'apple':
        if args.apple is None:
            args.apple = _load_overlay('./utils/apple.png', image)
        _blend_overlay(image, args.apple)
    return image
def test_img(net_g, datatest, args, test_backdoor=False):
    """Evaluate ``net_g`` on ``datatest`` and optionally measure backdoor success.

    Args:
        net_g: Model to evaluate; it is switched to eval mode.
        datatest: Dataset yielding ``(image, label)`` pairs.
        args: Options object. Reads ``bs``, ``gpu``, ``device``, ``verbose``
            and the attack settings used by ``test_or_not``/``add_trigger``.
            ``args.watermark`` and ``args.apple`` are reset to ``None``.
        test_backdoor: When true, every eligible sample is additionally
            evaluated with the trigger stamped on it.

    Returns:
        ``(accuracy, test_loss)``, or ``(accuracy, test_loss, back_accu)``
        when ``test_backdoor`` is set. ``back_accu`` is the percentage of
        triggered samples classified as ``args.attack_label``; it is ``0.0``
        when no sample was eligible.
    """
    # Drop cached overlays so add_trigger reloads them for this data.
    args.watermark = None
    args.apple = None
    net_g.eval()

    test_loss = 0
    correct = 0
    back_correct = 0
    back_num = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for data, target in data_loader:
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)
            # sum up batch loss
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            # index of the max log-probability
            y_pred = log_probs.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

            if test_backdoor:
                last_triggered = None
                for k in range(len(data)):
                    if test_or_not(args, target[k]):
                        data[k] = add_trigger(args, data[k], test=True)
                        target[k] = args.attack_label
                        back_num += 1
                        last_triggered = k
                    else:
                        # -1 never matches a prediction, so excluded samples
                        # cannot count as backdoor hits.
                        target[k] = -1
                if last_triggered is not None:
                    # One debug image per batch is enough (the file is
                    # overwritten each time). Pass a copy so that saving can
                    # never alter the batch fed to the model below.
                    save_img(data[last_triggered].detach().clone())
                log_probs = net_g(data)
                y_pred = log_probs.data.max(1, keepdim=True)[1]
                back_correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    if test_backdoor:
        # Guard against a test set without any eligible sample.
        back_accu = 100.00 * float(back_correct) / back_num if back_num else 0.0
        return accuracy, test_loss, back_accu
    return accuracy, test_loss
# only attack goal join 61 | return True 62 | else: 63 | return False 64 | else: # all to one 65 | if label != args.attack_label: 66 | return True 67 | else: 68 | return False 69 | 70 | # def add_trigger(args, image): 71 | # if args.trigger == 'square': 72 | # pixel_max = torch.max(image) if torch.max(image)>1 else 1 73 | 74 | # image[:,args.triggerY:args.triggerY+5,args.triggerX:args.triggerX+5] = pixel_max 75 | # elif args.trigger == 'pattern': 76 | # pixel_max = torch.max(image) if torch.max(image)>1 else 1 77 | # image[:,args.triggerY+0,args.triggerX+0] = pixel_max 78 | # image[:,args.triggerY+1,args.triggerX+1] = pixel_max 79 | # image[:,args.triggerY-1,args.triggerX+1] = pixel_max 80 | # image[:,args.triggerY+1,args.triggerX-1] = pixel_max 81 | # elif args.trigger == 'watermark': 82 | # if args.watermark is None: 83 | # args.watermark = cv2.imread('./utils/watermark.png', cv2.IMREAD_GRAYSCALE) 84 | # args.watermark = cv2.bitwise_not(args.watermark) 85 | # args.watermark = cv2.resize(args.watermark, dsize=image[0].shape, interpolation=cv2.INTER_CUBIC) 86 | # pixel_max = np.max(args.watermark) 87 | # args.watermark = args.watermark.astype(np.float64) / pixel_max 88 | # # cifar [0,1] else max>1 89 | # pixel_max_dataset = torch.max(image).item() if torch.max(image).item() > 1 else 1 90 | # args.watermark *= pixel_max_dataset 91 | # max_pixel = max(np.max(args.watermark),torch.max(image)) 92 | # image = (image.cpu() + args.watermark).to(args.gpu) 93 | # image[image>max_pixel]=max_pixel 94 | # elif args.trigger == 'apple': 95 | # if args.apple is None: 96 | # args.apple = cv2.imread('./utils/apple.png', cv2.IMREAD_GRAYSCALE) 97 | # args.apple = cv2.bitwise_not(args.apple) 98 | # args.apple = cv2.resize(args.apple, dsize=image[0].shape, interpolation=cv2.INTER_CUBIC) 99 | # pixel_max = np.max(args.apple) 100 | # args.apple = args.apple.astype(np.float64) / pixel_max 101 | # # cifar [0,1] else max>1 102 | # pixel_max_dataset = torch.max(image).item() if 
def save_img(image):
    """Save ``image`` as ``./save/test_trigger2.png`` after min-max scaling.

    Args:
        image: Tensor indexed as ``[channel, row, column]``. A single-channel
            image is written as grayscale, anything else as channels-last.

    The input tensor is left untouched; the file is overwritten on each call.
    """
    pixels = image.detach().cpu().numpy()
    if image.shape[0] == 1:
        pixels = pixels.squeeze()
    else:
        pixels = pixels.transpose(1, 2, 0)
    # Out-of-place arithmetic: the array may share memory with the caller's
    # tensor, which must not be rescaled as a side effect.
    pixels = pixels - pixels.min()
    peak = pixels.max()
    if peak > 0:
        # A constant image has peak 0; leave it all-zero instead of dividing.
        pixels = pixels / peak
    io.imsave('./save/test_trigger2.png', img_as_ubyte(pixels))
log_probs.data.max(1, keepdim=True)[1] 33 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 34 | if test_backdoor: 35 | del_arr = [] 36 | for k, image in enumerate(data): 37 | if test_or_not(args, target[k]): # one2one need test 38 | # data[k][:, 0:5, 0:5] = torch.max(data[k]) 39 | data[k] = add_trigger(args,data[k], test=True) 40 | save_img(data[k]) 41 | target[k] = args.attack_label 42 | back_num += 1 43 | else: 44 | target[k] = -1 45 | log_probs = net_g(data) 46 | y_pred = log_probs.data.max(1, keepdim=True)[1] 47 | back_correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 48 | test_loss /= len(data_loader.dataset) 49 | accuracy = 100.00 * correct / len(data_loader.dataset) 50 | if args.verbose: 51 | print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 52 | test_loss, correct, len(data_loader.dataset), accuracy)) 53 | if test_backdoor: 54 | back_accu = 100.00 * float(back_correct) / back_num 55 | return accuracy, test_loss, back_accu 56 | return accuracy, test_loss 57 | 58 | def test_or_not(args, label): 59 | if args.attack_goal != -1: # one to one 60 | if label == args.attack_goal: # only attack goal join 61 | return True 62 | else: 63 | return False 64 | else: # all to one 65 | if label != args.attack_label: 66 | return True 67 | else: 68 | return False 69 | 70 | # def add_trigger(args, image): 71 | # if args.trigger == 'square': 72 | # pixel_max = torch.max(image) if torch.max(image)>1 else 1 73 | 74 | # image[:,args.triggerY:args.triggerY+5,args.triggerX:args.triggerX+5] = pixel_max 75 | # elif args.trigger == 'pattern': 76 | # pixel_max = torch.max(image) if torch.max(image)>1 else 1 77 | # image[:,args.triggerY+0,args.triggerX+0] = pixel_max 78 | # image[:,args.triggerY+1,args.triggerX+1] = pixel_max 79 | # image[:,args.triggerY-1,args.triggerX+1] = pixel_max 80 | # image[:,args.triggerY+1,args.triggerX-1] = pixel_max 81 | # elif args.trigger == 'watermark': 82 | # if args.watermark is None: 83 
| # args.watermark = cv2.imread('./utils/watermark.png', cv2.IMREAD_GRAYSCALE) 84 | # args.watermark = cv2.bitwise_not(args.watermark) 85 | # args.watermark = cv2.resize(args.watermark, dsize=image[0].shape, interpolation=cv2.INTER_CUBIC) 86 | # pixel_max = np.max(args.watermark) 87 | # args.watermark = args.watermark.astype(np.float64) / pixel_max 88 | # # cifar [0,1] else max>1 89 | # pixel_max_dataset = torch.max(image).item() if torch.max(image).item() > 1 else 1 90 | # args.watermark *= pixel_max_dataset 91 | # max_pixel = max(np.max(args.watermark),torch.max(image)) 92 | # image = (image.cpu() + args.watermark).to(args.gpu) 93 | # image[image>max_pixel]=max_pixel 94 | # elif args.trigger == 'apple': 95 | # if args.apple is None: 96 | # args.apple = cv2.imread('./utils/apple.png', cv2.IMREAD_GRAYSCALE) 97 | # args.apple = cv2.bitwise_not(args.apple) 98 | # args.apple = cv2.resize(args.apple, dsize=image[0].shape, interpolation=cv2.INTER_CUBIC) 99 | # pixel_max = np.max(args.apple) 100 | # args.apple = args.apple.astype(np.float64) / pixel_max 101 | # # cifar [0,1] else max>1 102 | # pixel_max_dataset = torch.max(image).item() if torch.max(image).item() > 1 else 1 103 | # args.apple *= pixel_max_dataset 104 | # max_pixel = max(np.max(args.apple),torch.max(image)) 105 | # image += (image.cpu() + args.apple).to(args.gpu) 106 | # image[image>max_pixel]=max_pixel 107 | # return image 108 | def save_img(image): 109 | img = image 110 | if image.shape[0] == 1: 111 | pixel_min = torch.min(img) 112 | img -= pixel_min 113 | pixel_max = torch.max(img) 114 | img /= pixel_max 115 | io.imsave('./save/test_trigger2.png', img_as_ubyte(img.squeeze().cpu().numpy())) 116 | else: 117 | img = image.cpu().numpy() 118 | img = img.transpose(1, 2, 0) 119 | pixel_min = np.min(img) 120 | img -= pixel_min 121 | pixel_max = np.max(img) 122 | img /= pixel_max 123 | io.imsave('./save/test_trigger2.png', img_as_ubyte(img)) 
def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300):
    """Sample non-I.I.D. client data by handing each user two label-sorted shards.

    The sample indices are sorted by label and cut into ``num_shards`` shards
    of ``num_imgs`` consecutive indices; every user receives two randomly
    chosen shards, so it mostly sees one or two classes.

    Args:
        dataset: Object exposing the labels as ``targets`` (or, for older
            torchvision versions, ``train_labels``).
        num_users: Number of clients.
        num_shards: Number of shards. The default matches the 60000-sample
            MNIST training set together with ``num_imgs``.
        num_imgs: Samples per shard.

    Returns:
        dict[int, np.ndarray]: Client id -> array of sample indices.

    Raises:
        ValueError: If the label count is not ``num_shards * num_imgs`` or
            there are fewer than two shards per user.
    """
    labels = getattr(dataset, 'targets', None)
    if labels is None:
        # Deprecated attribute name used by old torchvision releases.
        labels = dataset.train_labels
    labels = np.asarray(labels)

    if len(labels) != num_shards * num_imgs:
        raise ValueError(
            f'expected {num_shards * num_imgs} labels '
            f'({num_shards} shards x {num_imgs} images), got {len(labels)}')
    if 2 * num_users > num_shards:
        raise ValueError(
            f'{num_users} users need {2 * num_users} shards, only {num_shards} available')

    # Sample indices ordered by label.
    idxs = labels.argsort()

    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
    return dict_users
def cifar_noniid(dataset_label, num_clients, num_classes, q):
    """Sample non-I.I.D. client data from a labelled dataset such as CIFAR-10.

    The samples are first split into ``num_classes`` groups by
    ``non_iid_distribution_group`` (group ``k`` receives a fraction ``q`` of
    class ``k``; the rest of that class is spread over the other groups).
    Each group's samples are then divided among its clients by
    ``non_iid_distribution_client``.

    Args:
        dataset_label: Array of class labels, one per sample.
        num_clients: Total number of clients.
        num_classes: Number of classes (and therefore of groups).
        q: Fraction of each class assigned to its own group.

    Returns:
        dict: Client id -> set of sample indices.
    """
    proportion = non_iid_distribution_group(dataset_label, num_clients, num_classes, q)
    return non_iid_distribution_client(proportion, num_clients, num_classes)
def non_iid_distribution_client(group_proportion, num_clients, num_classes):
    """Split every group's samples evenly among the clients of that group.

    Args:
        group_proportion: Mapping group id (``0 .. num_classes-1``) -> set of
            sample indices belonging to that group.
        num_clients: Total number of clients; ``num_clients // num_classes``
            clients are assigned to each group.
        num_classes: Number of groups.

    Returns:
        dict[int, set]: Client id -> set of sample indices. Client ids are
        contiguous: group ``i`` owns ids ``i * n .. i * n + n - 1`` where
        ``n`` is the number of clients per group.

    Raises:
        ValueError: If there are fewer clients than groups.
    """
    num_each_group = num_clients // num_classes
    if num_each_group == 0:
        raise ValueError(
            f'need at least one client per group: {num_clients} clients, {num_classes} groups')
    # Every client gets the same share, sized from group 0.
    num_data_each_client = len(group_proportion[0]) // num_each_group

    dict_users = {}
    # Only used for the diagnostic print below.
    unassigned = set(range(num_data_each_client * num_clients))
    for i in range(num_classes):
        group_data = list(group_proportion[i])
        for j in range(num_each_group):
            selected_data = set(np.random.choice(group_data, num_data_each_client, replace=False))
            # Was i*10+j, which is only correct with exactly 10 clients per group.
            dict_users[i * num_each_group + j] = selected_data
            group_data = list(set(group_data) - selected_data)
            unassigned -= selected_data
    print(len(unassigned), ' samples are remained')
    return dict_users
def malicious_train(model, dataset, args):
    """
    Run one pass of backdoor training over ``dataset``.

    Every mini-batch is doubled: the clean samples are kept and a triggered
    copy of each sample, relabelled to ``args.attack_label``, is appended
    before the SGD step.

    :param model: network to train in place
    :param dataset: dataset yielding ``(image, label)`` pairs
    :param args: run options; reads ``attack_label`` and ``device`` and is
                 forwarded to ``add_trigger``
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

    for clean_x, clean_y in loader:
        # poisoned twin of the batch: same images plus trigger, target label
        poisoned_x = copy.deepcopy(clean_x)
        poisoned_y = copy.deepcopy(clean_y)
        for pos in range(len(poisoned_x)):
            poisoned_y[pos] = args.attack_label
            poisoned_x[pos] = add_trigger(args, poisoned_x[pos])

        batch_x = torch.cat((clean_x, poisoned_x), dim=0).to(args.device)
        batch_y = torch.cat((clean_y, poisoned_y)).to(args.device)

        model.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        sgd.step()
def attacker(list_mal_client, num_mal, attack_type, dataset_train, dataset_test, dict_users, net_glob, args, idx=None):
    """
    Train one malicious local model and replicate it for every attacker.

    :param list_mal_client: ids of malicious clients; one is drawn at random
                            when ``idx`` is not given
    :param num_mal: number of malicious updates to return
    :param attack_type: ``"dba"`` rotates through the DBA attackers using
                        ``args.dba_sign``; any other value trains client ``idx``
    :param dataset_train: training dataset shared by all clients
    :param dataset_test: test dataset handed to ``LocalMaliciousUpdate``
    :param dict_users: mapping client id -> sample indices
    :param net_glob: current global model (deep-copied before training)
    :param args: run options; ``args.attack_layers`` is reset to ``None`` and
                 ``args.dba_sign`` is advanced for DBA
    :param idx: optional id of the client to train
    :return: ``(w, loss, args.attack_layers)`` where ``w`` is a list holding
             ``num_mal`` references to the same state dict when
             ``num_mal > 0``, otherwise the state dict itself
    :raises ValueError: for DBA with fewer than four malicious clients
    """
    if idx is None:
        idx = random.choice(list_mal_client)
    args.attack_layers = None
    # craft attack model once
    if attack_type == "dba":
        num_dba_attacker = int(args.num_users * args.malicious)
        dba_group = int(num_dba_attacker / 4)
        if dba_group == 0:
            # previously surfaced as an opaque "modulo by zero"
            raise ValueError(
                "dba needs at least 4 malicious clients, got %d" % num_dba_attacker)
        idx = args.dba_sign % (4 * dba_group)
        args.dba_sign += 1
    local = LocalMaliciousUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx],
                                 order=idx, dataset_test=dataset_test)
    w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device), test_img=test_img)
    print("client", idx, "--attack--")
    if num_mal > 0:
        # NOTE: every entry is the same object, not a copy
        w = [w for _ in range(num_mal)]

    return w, loss, args.attack_layers
train_loader = DataLoader(dataset, batch_size=64, shuffle=True) 34 | learning_rate = 0.1 35 | error = nn.CrossEntropyLoss() 36 | optimizer = torch.optim.SGD( 37 | model.parameters(), lr=learning_rate, momentum=0.5) 38 | 39 | for images, labels in train_loader: 40 | images, labels = images.to(args.device), labels.to(args.device) 41 | model.zero_grad() 42 | log_probs = model(images) 43 | loss = error(log_probs, labels) 44 | loss.backward() 45 | optimizer.step() 46 | 47 | 48 | def malicious_train(model, dataset, args): 49 | train_loader = DataLoader(dataset, batch_size=64, shuffle=True) 50 | learning_rate = 0.1 51 | error = nn.CrossEntropyLoss() 52 | optimizer = torch.optim.SGD( 53 | model.parameters(), lr=learning_rate, momentum=0.5) 54 | 55 | for images, labels in train_loader: 56 | bad_data, bad_label = copy.deepcopy( 57 | images), copy.deepcopy(labels) 58 | for xx in range(len(bad_data)): 59 | bad_label[xx] = args.attack_label 60 | # bad_data[xx][:, 0:5, 0:5] = torch.max(images[xx]) 61 | bad_data[xx] = add_trigger(args, bad_data[xx]) 62 | images = torch.cat((images, bad_data), dim=0) 63 | labels = torch.cat((labels, bad_label)) 64 | images, labels = images.to(args.device), labels.to(args.device) 65 | model.zero_grad() 66 | log_probs = model(images) 67 | loss = error(log_probs, labels) 68 | loss.backward() 69 | optimizer.step() 70 | 71 | 72 | def test(model, dataset, args, backdoor=True): 73 | if backdoor == True: 74 | acc_test, _, back_acc = test_img( 75 | copy.deepcopy(model), dataset, args, test_backdoor=True) 76 | else: 77 | acc_test, _ = test_img( 78 | copy.deepcopy(model), dataset, args, test_backdoor=False) 79 | back_acc = None 80 | return acc_test.item(), back_acc 81 | 82 | 83 | 84 | def get_attacker_dataset(args): 85 | if args.dataset == 'cifar': 86 | trans_cifar = transforms.Compose( 87 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 88 | dataset_train = datasets.CIFAR10( 89 | '../data/cifar', train=True, download=True, 
def split_dataset(dataset):
    """
    Randomly split ``dataset`` into a training and a validation part.

    A random permutation of the indices is drawn; the first quarter
    (``len(dataset) // 4`` items) becomes the validation set and the rest the
    training set.

    :param dataset: indexable collection supporting ``len``
    :return: ``(mal_train_dataset, mal_val_dataset)`` as two lists
    """
    num_dataset = len(dataset)
    # random
    data_distribute = np.random.permutation(num_dataset)
    num_val = num_dataset // 4
    # each item is fetched exactly once (the old loop also filled an unused
    # list, doubling every __getitem__ call)
    mal_val_dataset = [dataset[i] for i in data_distribute[:num_val]]
    mal_train_dataset = [dataset[i] for i in data_distribute[num_val:]]
    return mal_train_dataset, mal_val_dataset
def args_parser(argv=None):
    """
    Build the experiment argument parser and parse the command line.

    :param argv: optional list of argument strings; ``None`` (the default)
                 parses ``sys.argv[1:]`` exactly as before
    :return: ``argparse.Namespace`` with every option of the run
    """
    parser = argparse.ArgumentParser()
    # save file
    parser.add_argument('--save', type=str, default='save',
                        help="dic to save results (ending without /)")
    parser.add_argument('--init', type=str, default='None',
                        help="location of init model")
    # federated arguments
    parser.add_argument('--epochs', type=int, default=500,
                        help="rounds of training")
    parser.add_argument('--num_users', type=int,
                        default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1,
                        help="the fraction of clients: C")
    parser.add_argument('--malicious', type=float, default=0.1,
                        help="proportion of mailicious clients")

    # attack methods: badnet labelflip layerattack updateflip get_weight
    # layerattack_rev layerattack_ER adaptive
    parser.add_argument('--attack', type=str,
                        default='badnet', help='attack method')
    parser.add_argument('--ada_mode', type=int,
                        default=1, help='adaptive attack mode')
    parser.add_argument('--poison_frac', type=float, default=0.2,
                        help="fraction of dataset to corrupt for backdoor attack, 1.0 for layer attack")

    # local training
    parser.add_argument('--local_ep', type=int, default=3,
                        help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=50,
                        help="local batch size: B")

    parser.add_argument('--bs', type=int, default=64, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01,
                        help="learning rate")

    # model arguments
    # resnet cnn VGG mlp Mnist_2NN Mnist_CNN resnet20 rlr_mnist
    parser.add_argument('--model', type=str,
                        default='Mnist_CNN', help='model name')

    # dataset: fashion_mnist mnist cifar
    parser.add_argument('--dataset', type=str,
                        default='mnist', help="name of dataset")

    # defence: avg fltrust tr-mean median krum multi_krum RLR fltrust_bn fltrust_bn_lr
    parser.add_argument('--defence', type=str,
                        default='avg', help="strategy of defence")
    parser.add_argument('--k', type=int,
                        default=2, help="parameter of krum")
    parser.add_argument('--iid', type=int, default=1,
                        help='whether i.i.d or not')

    # backdoor task: samples labelled attack_goal are relabelled to attack_label
    parser.add_argument('--attack_label', type=int, default=5,
                        help="target label assigned to triggered samples")

    parser.add_argument('--single', type=int, default=0,
                        help="single shot or repeated")
    # attack_goal=-1 is all to one
    parser.add_argument('--attack_goal', type=int, default=7,
                        help="source label that gets poisoned (-1 poisons every label)")
    # --attack_begin 70 means accuracy is up to 70 then attack
    parser.add_argument('--attack_begin', type=int, default=0,
                        help="the accuracy begin to attack")
    # search times
    parser.add_argument('--search_times', type=int, default=20,
                        help="binary search times")

    parser.add_argument('--gpu', type=int, default=0,
                        help="GPU ID, -1 for CPU")
    parser.add_argument('--robustLR_threshold', type=int, default=4,
                        help="break ties when votes sum to 0")

    parser.add_argument('--server_dataset', type=int, default=200,
                        help="number of dataset in server")

    parser.add_argument('--server_lr', type=float, default=1,
                        help="server learning rate used in fltrust")

    parser.add_argument('--momentum', type=float, default=0.9,
                        help="SGD momentum (default: 0.9)")

    parser.add_argument('--split', type=str, default='user',
                        help="train-test split type, user or sample")
    # trigger info: square apple watermark
    parser.add_argument('--trigger', type=str, default='square',
                        help="Kind of trigger")
    # mnist 28*28 cifar10 32*32
    parser.add_argument('--triggerX', type=int, default=0,
                        help="position of trigger x-aix")
    parser.add_argument('--triggerY', type=int, default=0,
                        help="position of trigger y-aix")

    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--wrong_mal', type=int, default=0)
    parser.add_argument('--right_ben', type=int, default=0)

    parser.add_argument('--mal_score', type=float, default=0)
    parser.add_argument('--ben_score', type=float, default=0)

    parser.add_argument('--turn', type=int, default=0)
    parser.add_argument('--noise', type=float, default=0.001)
    parser.add_argument('--all_clients', action='store_true',
                        help='aggregation over all clients')
    parser.add_argument('--tau', type=float, default=0.8,
                        help="threshold of LPA_ER")
    parser.add_argument('--debug', type=int, default=0, help="log debug info or not")
    parser.add_argument('--ablation_dataset', type=int, default=0,
                        help="ablation experiment for dataset")
    parser.add_argument('--debug_fld', type=int, default=0, help="#1 save, #2 load")
    parser.add_argument('--decrease', type=float, default=0.3,
                        help="proportion of dropped layers in robust experiments (used in mode11)")
    parser.add_argument('--increase', type=float, default=0.3,
                        help="proportion of added layers in robust experiments (used in mode12)")
    parser.add_argument('--mode10_tau', type=float, default=0.95, help="threshold of mode 10")
    parser.add_argument('--cnn_scale', type=float, default=0.5, help="scale of cnn")
    parser.add_argument('--cifar_scale', type=float, default=1.0, help="scale of larger model")
    args = parser.parse_args(argv)
    return args
default=100, help="number of users: K") 20 | parser.add_argument('--frac', type=float, default=0.1, 21 | help="the fraction of clients: C") 22 | parser.add_argument('--malicious',type=float,default=0, help="proportion of mailicious clients") 23 | 24 | #***** badnet labelflip layerattack updateflip get_weight layerattack_rev layerattack_ER adaptive**** 25 | parser.add_argument('--attack', type=str, 26 | default='badnet', help='attack method') 27 | parser.add_argument('--ada_mode', type=int, 28 | default=1, help='adaptive attack mode') 29 | parser.add_argument('--poison_frac', type=float, default=0.2, 30 | help="fraction of dataset to corrupt for backdoor attack, 1.0 for layer attack") 31 | 32 | # *****local_ep = 3, local_bs=50, lr=0.1******* 33 | parser.add_argument('--local_ep', type=int, default=3, 34 | help="the number of local epochs: E") 35 | parser.add_argument('--local_bs', type=int, default=50, 36 | help="local batch size: B") 37 | 38 | parser.add_argument('--bs', type=int, default=64, help="test batch size") 39 | parser.add_argument('--lr', type=float, default=0.01, 40 | help="learning rate") 41 | 42 | # model arguments 43 | #*************************model******************************# 44 | # resnet cnn VGG mlp Mnist_2NN Mnist_CNN resnet20 rlr_mnist 45 | parser.add_argument('--model', type=str, 46 | default='Mnist_CNN', help='model name') 47 | 48 | # other arguments 49 | #*************************dataset*******************************# 50 | # fashion_mnist mnist cifar 51 | parser.add_argument('--dataset', type=str, 52 | default='mnist', help="name of dataset") 53 | 54 | 55 | 56 | #****0-avg, 1-fltrust 2-tr-mean 3-median 4-krum 5-muli_krum 6-RLR fltrust_bn fltrust_bn_lr****# 57 | parser.add_argument('--defence', type=str, 58 | default='avg', help="strategy of defence") 59 | parser.add_argument('--k', type=int, 60 | default=2, help="parameter of krum") 61 | # parser.add_argument('--iid', action='store_true', 62 | # help='whether i.i.d or not') 63 | 
parser.add_argument('--iid', type=int, default=1, 64 | help='whether i.i.d or not') 65 | 66 | #************************atttack_label********************************# 67 | parser.add_argument('--attack_label', type=int, default=5, 68 | help="trigger for which label") 69 | 70 | parser.add_argument('--single', type=int, default=0, 71 | help="single shot or repeated") 72 | # attack_goal=-1 is all to one 73 | parser.add_argument('--attack_goal', type=int, default=7, 74 | help="trigger to which label") 75 | # --attack_begin 70 means accuracy is up to 70 then attack 76 | parser.add_argument('--attack_begin', type=int, default=0, 77 | help="the accuracy begin to attack") 78 | # search times 79 | parser.add_argument('--search_times', type=int, default=20, 80 | help="binary search times") 81 | 82 | parser.add_argument('--gpu', type=int, default=0, 83 | help="GPU ID, -1 for CPU") 84 | parser.add_argument('--robustLR_threshold', type=int, default=4, 85 | help="break ties when votes sum to 0") 86 | 87 | parser.add_argument('--server_dataset', type=int,default=200,help="number of dataset in server") 88 | 89 | parser.add_argument('--server_lr', type=float,default=1,help="number of dataset in server using in fltrust") 90 | 91 | 92 | parser.add_argument('--momentum', type=float, default=0.9, 93 | help="SGD momentum (default: 0.5)") 94 | 95 | 96 | parser.add_argument('--split', type=str, default='user', 97 | help="train-test split type, user or sample") 98 | #*********trigger info********* 99 | # square apple watermark 100 | parser.add_argument('--trigger', type=str, default='square', 101 | help="Kind of trigger") 102 | # mnist 28*28 cifar10 32*32 103 | parser.add_argument('--triggerX', type=int, default='0', 104 | help="position of trigger x-aix") 105 | parser.add_argument('--triggerY', type=int, default='0', 106 | help="position of trigger y-aix") 107 | 108 | parser.add_argument('--verbose', action='store_true', help='verbose print') 109 | parser.add_argument('--seed', type=int, 
default=1, 110 | help='random seed (default: 1)') 111 | parser.add_argument('--wrong_mal', type=int, default=0) 112 | parser.add_argument('--right_ben', type=int, default=0) 113 | 114 | parser.add_argument('--mal_score', type=float, default=0) 115 | parser.add_argument('--ben_score', type=float, default=0) 116 | 117 | parser.add_argument('--turn', type=int, default=0) 118 | parser.add_argument('--noise', type=float, default=0.001) 119 | parser.add_argument('--all_clients', action='store_true', 120 | help='aggregation over all clients') 121 | parser.add_argument('--tau', type=float, default=0.8, 122 | help="threshold of LPA_ER") 123 | parser.add_argument('--debug', type=int, default=0, help="log debug info or not") 124 | parser.add_argument('--ablation_dataset', type=int, default=0, help="ablation experiment for dataset") 125 | parser.add_argument('--debug_fld', type=int, default=0, help="#1 save, #2 load") 126 | parser.add_argument('--decrease', type=float, default=0.3, help="proportion of dropped layers in robust experiments (used in mode11)") 127 | parser.add_argument('--increase', type=float, default=0.3, help="proportion of added layers in robust experiments (used in mode12)") 128 | parser.add_argument('--mode10_tau', type=float, default=0.95, help="threshold of mode 10") 129 | parser.add_argument('--cnn_scale', type=float, default=0.5, help="scale of cnn") 130 | parser.add_argument('--cifar_scale', type=float, default=1.0, help="scale of larger model") 131 | #-----------------------------------------------# 132 | # parser.add_argument('--num_channels', type=int, default=1, 133 | # help="number of channels of imges") 134 | # parser.add_argument('--stopping_rounds', type=int, 135 | # default=10, help='rounds of early stopping') 136 | # parser.add_argument('--num_classes', type=int, 137 | # default=10, help="number of classes") 138 | # parser.add_argument('--kernel_num', type=int, default=9, 139 | # help='number of each kind of kernel') 140 | # 
parser.add_argument('--kernel_sizes', type=str, default='3,4,5', 141 | # help='comma-separated kernel size to use for convolution') 142 | # parser.add_argument('--norm', type=str, default='batch_norm', 143 | # help="batch_norm, layer_norm, or None") 144 | # parser.add_argument('--num_filters', type=int, default=32, 145 | # help="number of filters for conv nets") 146 | # parser.add_argument('--max_pool', type=str, default='True', 147 | # help="Whether use max pooling rather than strided convolutions") 148 | args = parser.parse_args() 149 | return args 150 | -------------------------------------------------------------------------------- /models/MaliciousUpdate.py: -------------------------------------------------------------------------------- 1 | from tkinter.messagebox import NO 2 | import torch 3 | from torch import nn, autograd 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | from sklearn import metrics 8 | import copy 9 | import math 10 | from skimage import io 11 | import time 12 | import cv2 13 | from skimage import img_as_ubyte 14 | import heapq 15 | import os 16 | from models.add_trigger import add_trigger 17 | 18 | class DatasetSplit(Dataset): 19 | def __init__(self, dataset, idxs): 20 | self.dataset = dataset 21 | self.idxs = list(idxs) 22 | 23 | def __len__(self): 24 | return len(self.idxs) 25 | 26 | def __getitem__(self, item): 27 | image, label = self.dataset[self.idxs[item]] 28 | return image, label 29 | 30 | 31 | class LocalMaliciousUpdate(object): 32 | def __init__(self, args, dataset=None, idxs=None, attack=None, order=None, malicious_list=None, dataset_test=None): 33 | self.args = args 34 | self.loss_func = nn.CrossEntropyLoss() 35 | self.selected_clients = [] 36 | self.ldr_train = DataLoader(DatasetSplit( 37 | dataset, idxs), batch_size=self.args.local_bs, shuffle=True) 38 | # change 0708 39 | if args.ablation_dataset == 1: 40 | self.args.data = DatasetSplit(dataset, idxs) 41 | 42 | # backdoor task 
is changing attack_goal to attack_label 43 | self.attack_label = args.attack_label 44 | self.attack_goal = args.attack_goal 45 | 46 | self.model = args.model 47 | self.poison_frac = args.poison_frac 48 | if attack is None: 49 | self.attack = args.attack 50 | else: 51 | self.attack = attack 52 | 53 | self.trigger = args.trigger 54 | self.triggerX = args.triggerX 55 | self.triggerY = args.triggerY 56 | self.watermark = None 57 | self.apple = None 58 | self.dataset = args.dataset 59 | self.args.save_img = self.save_img 60 | if self.attack == 'dba': 61 | self.args.dba_class = int(order % 4) 62 | elif self.attack == 'get_weight': 63 | self.idxs = list(idxs) 64 | 65 | if malicious_list is not None: 66 | self.malicious_list = malicious_list 67 | if dataset is not None: 68 | self.dataset_train = dataset 69 | if dataset_test is not None: 70 | self.dataset_test = dataset_test 71 | 72 | def add_trigger(self, image): 73 | return add_trigger(self.args, image) 74 | 75 | 76 | 77 | def trigger_data(self, images, labels): 78 | # attack_goal == -1 means attack all label to attack_label 79 | if self.attack_goal == -1: 80 | if math.isclose(self.poison_frac, 1): # 100% copy poison data 81 | bad_data, bad_label = copy.deepcopy( 82 | images), copy.deepcopy(labels) 83 | for xx in range(len(bad_data)): 84 | bad_label[xx] = self.attack_label 85 | # bad_data[xx][:, 0:5, 0:5] = torch.max(images[xx]) 86 | bad_data[xx] = self.add_trigger(bad_data[xx]) 87 | images = torch.cat((images, bad_data), dim=0) 88 | labels = torch.cat((labels, bad_label)) 89 | else: 90 | for xx in range(len(images)): # poison_frac% poison data 91 | labels[xx] = self.attack_label 92 | # images[xx][:, 0:5, 0:5] = torch.max(images[xx]) 93 | images[xx] = self.add_trigger(images[xx]) 94 | if xx > len(images) * self.poison_frac: 95 | break 96 | else: # trigger attack_goal to attack_label 97 | if math.isclose(self.poison_frac, 1): # 100% copy poison data 98 | bad_data, bad_label = copy.deepcopy( 99 | images), 
copy.deepcopy(labels) 100 | for xx in range(len(bad_data)): 101 | if bad_label[xx]!= self.attack_goal: # no in task 102 | continue # jump 103 | bad_label[xx] = self.attack_label 104 | bad_data[xx] = self.add_trigger(bad_data[xx]) 105 | images = torch.cat((images, bad_data[xx].unsqueeze(0)), dim=0) 106 | labels = torch.cat((labels, bad_label[xx].unsqueeze(0))) 107 | else: # poison_frac% poison data 108 | # count label == goal label 109 | num_goal_label = len(labels[labels==self.attack_goal]) 110 | counter = 0 111 | for xx in range(len(images)): 112 | if labels[xx] != 0: 113 | continue 114 | labels[xx] = self.attack_label 115 | # images[xx][:, 0:5, 0:5] = torch.max(images[xx]) 116 | images[xx] = self.add_trigger(images[xx]) 117 | counter += 1 118 | if counter > num_goal_label * self.poison_frac: 119 | break 120 | return images, labels 121 | 122 | def train(self, net, test_img = None): 123 | if self.attack == 'badnet': 124 | return self.train_malicious_badnet(net) 125 | elif self.attack == 'dba': 126 | return self.train_malicious_dba(net) 127 | else: 128 | print("Error Attack Method") 129 | os._exit(0) 130 | 131 | 132 | def train_malicious_badnet(self, net, test_img=None, dataset_test=None, args=None): 133 | net.train() 134 | # train and update 135 | optimizer = torch.optim.SGD( 136 | net.parameters(), lr=self.args.lr, momentum=self.args.momentum) 137 | epoch_loss = [] 138 | for iter in range(self.args.local_ep): 139 | batch_loss = [] 140 | for batch_idx, (images, labels) in enumerate(self.ldr_train): 141 | images, labels = self.trigger_data(images, labels) 142 | images, labels = images.to( 143 | self.args.device), labels.to(self.args.device) 144 | net.zero_grad() 145 | log_probs = net(images) 146 | loss = self.loss_func(log_probs, labels) 147 | loss.backward() 148 | optimizer.step() 149 | batch_loss.append(loss.item()) 150 | epoch_loss.append(sum(batch_loss)/len(batch_loss)) 151 | if test_img is not None: 152 | acc_test, _, backdoor_acc = test_img( 153 | net, 
def train_malicious_dba(self, net, test_img=None, dataset_test=None, args=None):
    """Train ``net`` locally on trigger-poisoned batches (DBA attack).

    Runs ``self.args.local_ep`` epochs of SGD over ``self.ldr_train``; every
    batch is first passed through ``self.trigger_data``.

    Args:
        net: model to train in place.
        test_img: optional evaluation callable, invoked as
            ``test_img(net, dataset, args, test_backdoor=True)`` and expected
            to return ``(accuracy, _, backdoor_accuracy)``.
        dataset_test: dataset handed to ``test_img``; falls back to
            ``self.dataset_test`` when that attribute exists.
        args: options handed to ``test_img``; falls back to ``self.args``.

    Returns:
        ``(net.state_dict(), mean_epoch_loss)``.  The loss is ``0.0`` when
        the loader produced no batches (this used to raise
        ZeroDivisionError).
    """
    net.train()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
    epoch_loss = []
    for _ in range(self.args.local_ep):
        batch_loss = []
        for images, labels in self.ldr_train:
            images, labels = self.trigger_data(images, labels)
            images = images.to(self.args.device)
            labels = labels.to(self.args.device)
            net.zero_grad()
            loss = self.loss_func(net(images), labels)
            loss.backward()
            optimizer.step()
            batch_loss.append(loss.item())
        if batch_loss:
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

    # NOTE(review): the flattened source does not show whether this ran once
    # per epoch or once after training; once after training is assumed.
    if test_img is not None:
        # Previously the raw (default ``None``) parameters were forwarded.
        eval_args = self.args if args is None else args
        eval_data = (getattr(self, 'dataset_test', None)
                     if dataset_test is None else dataset_test)
        acc_test, _, backdoor_acc = test_img(
            net, eval_data, eval_args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))

    mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
    return net.state_dict(), mean_loss
class LocalMaliciousUpdate(object):
    """Local trainer for a malicious (backdoor) federated-learning client.

    Wraps one client's slice of the training set in a ``DataLoader`` and
    offers poisoned (``badnet`` / ``dba``) as well as benign local training.
    Relies on the file-level names ``DatasetSplit``, ``add_trigger``, ``io``
    and ``img_as_ubyte``.
    """

    def __init__(self, args, dataset=None, idxs=None, attack=None, order=None,
                 malicious_list=None, dataset_test=None):
        """Read the attack configuration from ``args`` and build the loader.

        Args:
            args: option namespace; the fields read here are ``local_bs``,
                ``ablation_dataset``, ``attack_label``, ``attack_goal``,
                ``model``, ``poison_frac``, ``attack``, ``trigger``,
                ``triggerX``, ``triggerY`` and ``dataset``.  ``save_img``
                (and ``data`` / ``dba_class`` where applicable) are written
                back onto it.
            dataset: full training set.
            idxs: indices of the samples owned by this client.
            attack: overrides ``args.attack`` when given.
            order: client ordinal; only used for ``dba``, where
                ``order % 4`` selects the trigger part.
            malicious_list: stored as-is when given.
            dataset_test: stored as-is when given.
        """
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs),
            batch_size=self.args.local_bs, shuffle=True)
        # change 0708
        if args.ablation_dataset == 1:
            self.args.data = DatasetSplit(dataset, idxs)

        # backdoor task is changing attack_goal to attack_label
        self.attack_label = args.attack_label
        self.attack_goal = args.attack_goal

        self.model = args.model
        self.poison_frac = args.poison_frac
        self.attack = args.attack if attack is None else attack

        self.trigger = args.trigger
        self.triggerX = args.triggerX
        self.triggerY = args.triggerY
        self.watermark = None
        self.apple = None
        self.dataset = args.dataset
        # Expose the image dumper through the shared options object.
        self.args.save_img = self.save_img
        if self.attack == 'dba':
            self.args.dba_class = int(order % 4)
        elif self.attack == 'get_weight':
            self.idxs = list(idxs)

        if malicious_list is not None:
            self.malicious_list = malicious_list
        if dataset is not None:
            self.dataset_train = dataset
        if dataset_test is not None:
            self.dataset_test = dataset_test

    def add_trigger(self, image):
        """Delegate to the module-level ``add_trigger(self.args, image)``."""
        return add_trigger(self.args, image)

    def trigger_data(self, images, labels):
        """Return a backdoor-poisoned version of one training batch.

        ``attack_goal == -1`` makes every sample eligible, otherwise only
        samples labelled ``attack_goal`` are.  With ``poison_frac`` close to
        1 a triggered, relabelled copy of every eligible sample is appended
        to the untouched batch; otherwise the first
        ``round(poison_frac * eligible)`` eligible samples are poisoned in
        place.

        Returns:
            The ``(images, labels)`` pair after poisoning.
        """
        if self.attack_goal == -1:
            eligible = list(range(len(images)))
        else:
            # The partial-poison path used to compare against a literal 0
            # here instead of the configured source class.
            eligible = torch.nonzero(
                labels == self.attack_goal).flatten().tolist()

        if math.isclose(self.poison_frac, 1):
            bad_data, bad_label = copy.deepcopy(images), copy.deepcopy(labels)
            for idx in eligible:
                bad_label[idx] = self.attack_label
                bad_data[idx] = self.add_trigger(bad_data[idx])
            picked = torch.as_tensor(
                eligible, dtype=torch.long, device=labels.device)
            images = torch.cat((images, bad_data[picked]), dim=0)
            labels = torch.cat((labels, bad_label[picked]))
            return images, labels

        # The old loop checked its budget after poisoning and with ``>``,
        # overshooting the requested fraction by one or two samples.
        quota = max(0, min(len(eligible),
                           round(len(eligible) * self.poison_frac)))
        for idx in eligible[:quota]:
            labels[idx] = self.attack_label
            images[idx] = self.add_trigger(images[idx])
        return images, labels

    def train(self, net, test_img=None):
        """Dispatch to the training routine selected by ``self.attack``.

        ``test_img`` is accepted for interface compatibility but, as before,
        is not forwarded.

        Raises:
            ValueError: for an attack other than ``badnet`` / ``dba``.  This
                used to call ``os._exit(0)``, which ended the process with a
                success status and without any cleanup.
        """
        if self.attack == 'badnet':
            return self.train_malicious_badnet(net)
        if self.attack == 'dba':
            return self.train_malicious_dba(net)
        raise ValueError("Error Attack Method: {!r}".format(self.attack))

    def _fit(self, net, poison):
        """Run ``local_ep`` SGD epochs; return the mean epoch loss (0.0 if no batches)."""
        net.train()
        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                if poison:
                    images, labels = self.trigger_data(images, labels)
                images = images.to(self.args.device)
                labels = labels.to(self.args.device)
                net.zero_grad()
                loss = self.loss_func(net(images), labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0

    def _report(self, net, test_img, dataset_test, args):
        """Call the optional evaluation hook and print both accuracies."""
        if test_img is None:
            return
        # Fall back to the instance's own state instead of forwarding None.
        eval_args = self.args if args is None else args
        eval_data = (getattr(self, 'dataset_test', None)
                     if dataset_test is None else dataset_test)
        acc_test, _, backdoor_acc = test_img(
            net, eval_data, eval_args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))

    def train_malicious_badnet(self, net, test_img=None, dataset_test=None, args=None):
        """Train on poisoned batches; return ``(state_dict, mean_loss)``."""
        # NOTE(review): evaluation is assumed to run once after training; the
        # flattened source does not show the original nesting.
        mean_loss = self._fit(net, poison=True)
        self._report(net, test_img, dataset_test, args)
        return net.state_dict(), mean_loss

    def train_malicious_dba(self, net, test_img=None, dataset_test=None, args=None):
        """Same loop as ``train_malicious_badnet``; the trigger choice is driven by ``self.args``."""
        mean_loss = self._fit(net, poison=True)
        self._report(net, test_img, dataset_test, args)
        return net.state_dict(), mean_loss

    def train_benign(self, net):
        """Train on the clean batches; return ``(state_dict, mean_loss)``."""
        mean_loss = self._fit(net, poison=False)
        return net.state_dict(), mean_loss

    def save_img(self, image):
        """Write a min-max normalised copy of ``image`` as PNG under ``./save``.

        A single-channel image goes to ``backdoor_trigger.png`` only.  A
        multi-channel image is transposed with ``(1, 2, 0)`` first and, for
        the ``dba`` attack, additionally written to
        ``dba<dba_class>_trigger.png``.
        """
        # Normalise a copy: the earlier in-place arithmetic rescaled the
        # caller's tensor (also through the memory-sharing ``.numpy()`` view).
        pixels = image.detach().cpu().clone().numpy()
        pixels = pixels - pixels.min()
        peak = pixels.max()
        if peak > 0:  # a constant image would otherwise divide by zero
            pixels = pixels / peak

        if image.shape[0] == 1:
            io.imsave('./save/backdoor_trigger.png',
                      img_as_ubyte(pixels.squeeze()))
            return

        pixels = pixels.transpose(1, 2, 0)
        if self.attack == 'dba':
            io.imsave('./save/dba' + str(self.args.dba_class) + '_trigger.png',
                      img_as_ubyte(pixels))
        io.imsave('./save/backdoor_trigger.png', img_as_ubyte(pixels))
class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: residual block class; it must expose an ``expansion``
            attribute and be constructible as
            ``block(in_planes, planes, stride)``.
        num_blocks: four integers, the number of blocks in each stage.
        num_classes: width of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of 3-channel images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling.  For a 4x4 final feature map (32x32 input)
        # this equals the former ``avg_pool2d(out, 4)``; other map sizes now
        # also yield exactly ``512 * expansion`` features, where the fixed
        # window either crashed in the linear layer or silently averaged
        # only the top-left 4x4 corner.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)
class NarrowResNet(nn.Module):
    """One-channel-wide ResNet trunk without a classifier head.

    Every stage is a single channel wide (times ``block.expansion``); the
    forward pass ends with a 4x4 average pool followed by flattening.
    ``num_classes`` is accepted but not used.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 1

        self.conv1 = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(1)
        # Stage 1 keeps the resolution, stages 2-4 halve it.
        stage_strides = (1, 2, 2, 2)
        for position, (depth, first_stride) in enumerate(zip(num_blocks, stage_strides), start=1):
            setattr(self, 'layer{}'.format(position),
                    self._make_layer(block, 1, depth, stride=first_stride))

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block applies ``stride``."""
        stage = []
        remaining = [1] * (num_blocks - 1)
        for current in [stride] + remaining:
            stage.append(block(self.in_planes, planes, current))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the flattened, pooled feature map of the trunk."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, 4)
        return pooled.view(pooled.size(0), -1)
class VGG(nn.Module):
    '''
    VGG model: a convolutional feature extractor followed by a three-layer
    fully connected classifier.

    Args:
        features: module whose output flattens to 512 values per sample.
        num_classes: width of the final linear layer (default 10, the
            previously hard-coded value).
    '''

    def __init__(self, features, num_classes=10):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )
        # Initialize weights: N(0, sqrt(2 / (k_h * k_w * out_channels))) for
        # every convolution, zero for its bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # Convolutions built with bias=False have no bias tensor;
                # zeroing it unconditionally raised AttributeError.
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class CNN_MNIST(nn.Module):
    """Two-convolution classifier producing 10 logits.

    ``fc1`` expects 9216 flattened features, i.e. 64 channels of 12x12,
    which is what a single-channel 28x28 input yields after the two
    unpadded 3x3 convolutions and the 2x2 max pool.
    """

    def __init__(self):
        super(CNN_MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2))
        # Both dropout layers act on flattened (batch, features) tensors.
        # ``nn.Dropout2d`` is meant for (N, C, H, W) maps and warns on 2-D
        # input; PyTorch's own advice for that case is to use plain dropout.
        self.drop1 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.drop2 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the 10 class logits for a batch of images."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.max_pool(x)
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = self.drop1(x)
        x = F.relu(self.fc1(x))
        x = self.drop2(x)
        x = self.fc2(x)
        return x
class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce, 3x3 (strided), 1x1 expand to ``expansion * planes``."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        widened = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        # An empty Sequential is the identity; a projection is only needed
        # when the skip path must change resolution or channel count.
        if stride == 1 and in_planes == widened:
            projection = []
        else:
            projection = [
                nn.Conv2d(in_planes, widened,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
num_blocks[2], stride=2) 81 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 82 | self.linear = nn.Linear(512*block.expansion, num_classes) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1]*(num_blocks-1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = F.avg_pool2d(out, 4) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | return out 102 | 103 | # def NarrowResNet18(): 104 | # return NarrowResNet(BasicBlock, [2, 2, 2, 2]) 105 | 106 | def ResNet18(): 107 | return ResNet(BasicBlock, [2, 2, 2, 2]) 108 | 109 | 110 | def ResNet34(): 111 | return ResNet(BasicBlock, [3, 4, 6, 3]) 112 | 113 | 114 | def ResNet50(): 115 | return ResNet(Bottleneck, [3, 4, 6, 3]) 116 | 117 | 118 | def ResNet101(): 119 | return ResNet(Bottleneck, [3, 4, 23, 3]) 120 | 121 | 122 | def ResNet152(): 123 | return ResNet(Bottleneck, [3, 8, 36, 3]) 124 | # class NarrowResNet(nn.Module): 125 | # def __init__(self, block, num_blocks, num_classes=10): 126 | # super(NarrowResNet, self).__init__() 127 | # self.in_planes = 1 128 | 129 | # self.conv1 = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False) 130 | # self.bn1 = nn.BatchNorm2d(1) 131 | # self.layer1 = self._make_layer(block, 1, num_blocks[0], stride=1) 132 | # self.layer2 = self._make_layer(block, 1, num_blocks[1], stride=2) 133 | # self.layer3 = self._make_layer(block, 1, num_blocks[2], stride=2) 134 | 135 | # def _make_layer(self, block, planes, num_blocks, stride): 136 | # strides = [stride] + [1] * (num_blocks - 1) 137 | # layers = [] 138 | # for stride in strides: 139 | # layers.append(block(self.in_planes, planes, 
stride)) 140 | # self.in_planes = planes * block.expansion 141 | # return nn.Sequential(*layers) 142 | 143 | # def forward(self, x): 144 | # out = F.relu(self.bn1(self.conv1(x))) 145 | # out = self.layer1(out) 146 | # out = self.layer2(out) 147 | # out = self.layer3(out) 148 | # out = F.avg_pool2d(out, out.size()[3]) 149 | # out = out.view(out.size(0), -1) 150 | # return out 151 | 152 | class NarrowResNet(nn.Module): 153 | def __init__(self, block, num_blocks, num_classes=10): 154 | super(NarrowResNet, self).__init__() 155 | self.in_planes = 1 156 | 157 | self.conv1 = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False) 158 | self.bn1 = nn.BatchNorm2d(1) 159 | self.layer1 = self._make_layer(block, 1, num_blocks[0], stride=1) 160 | self.layer2 = self._make_layer(block, 1, num_blocks[1], stride=2) 161 | self.layer3 = self._make_layer(block, 1, num_blocks[2], stride=2) 162 | self.layer4 = self._make_layer(block, 1, num_blocks[3], stride=2) 163 | 164 | def _make_layer(self, block, planes, num_blocks, stride): 165 | strides = [stride] + [1] * (num_blocks - 1) 166 | layers = [] 167 | for stride in strides: 168 | layers.append(block(self.in_planes, planes, stride)) 169 | self.in_planes = planes * block.expansion 170 | return nn.Sequential(*layers) 171 | 172 | def forward(self, x): 173 | out = F.relu(self.bn1(self.conv1(x))) 174 | out = self.layer1(out) 175 | out = self.layer2(out) 176 | out = self.layer3(out) 177 | out = self.layer4(out) 178 | out = F.avg_pool2d(out, 4) 179 | out = out.view(out.size(0), -1) 180 | return out 181 | 182 | 183 | def NarrowResNet18(): 184 | return NarrowResNet(BasicBlock, [2, 2, 2, 2]) 185 | 186 | # class narrow_ResNet(nn.Module): 187 | # # by default : block = BasicBlock 188 | # def __init__(self, block, num_blocks, num_classes=10): 189 | # super(narrow_ResNet, self).__init__() 190 | 191 | # self.in_planes = 1 # one channel chain 192 | 193 | # self.conv1 = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False) # original 
num_channel = 16 194 | # self.bn1 = nn.BatchNorm2d(1) # bn1 195 | # # => 1 x 32 x 32 196 | 197 | # self.layer1 = self._make_layer(block, 1, num_blocks[0], stride=1) # original num_channel = 16 198 | # # => 1 x 32 x 32 199 | 200 | # self.layer2 = self._make_layer(block, 1, num_blocks[1], stride=2) # original num_channel = 32 201 | # # => 1 x 16 x 16 202 | 203 | # self.layer3 = self._make_layer(block, 1, num_blocks[2], stride=2) # original num_channel = 64 204 | # # => 1 x 8 x 8 205 | 206 | # self.apply(_weights_init) 207 | 208 | # def _make_layer(self, block, planes, num_blocks, stride): 209 | # strides = [stride] + [1]*(num_blocks-1) 210 | # layers = [] 211 | # for stride in strides: 212 | # layers.append(block(self.in_planes, planes, stride)) 213 | # self.in_planes = planes * block.expansion 214 | 215 | # return nn.Sequential(*layers) 216 | 217 | # def forward(self, x): 218 | # out = F.relu(self.bn1(self.conv1(x))) 219 | # out = self.layer1(out) 220 | # out = self.layer2(out) 221 | # out = self.layer3(out) 222 | # out = F.avg_pool2d(out, out.size()[3]) 223 | # out = out.view(out.size(0), -1) 224 | # return out 225 | 226 | __all__ = [ 227 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 228 | 'vgg19_bn', 'vgg19', 229 | ] 230 | 231 | 232 | class VGG(nn.Module): 233 | ''' 234 | VGG model 235 | ''' 236 | 237 | def __init__(self, features): 238 | super(VGG, self).__init__() 239 | self.features = features 240 | self.classifier = nn.Sequential( 241 | nn.Dropout(), 242 | nn.Linear(512, 512), 243 | nn.ReLU(True), 244 | nn.Dropout(), 245 | nn.Linear(512, 512), 246 | nn.ReLU(True), 247 | nn.Linear(512, 10), 248 | ) 249 | # Initialize weights 250 | for m in self.modules(): 251 | if isinstance(m, nn.Conv2d): 252 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 253 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 254 | m.bias.data.zero_() 255 | 256 | def forward(self, x): 257 | x = self.features(x) 258 | x = x.view(x.size(0), -1) 259 | x = self.classifier(x) 260 | return x 261 | 262 | 263 | def make_layers(cfg, batch_norm=False): 264 | layers = [] 265 | in_channels = 3 266 | for v in cfg: 267 | if v == 'M': 268 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 269 | else: 270 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 271 | if batch_norm: 272 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 273 | else: 274 | layers += [conv2d, nn.ReLU(inplace=True)] 275 | in_channels = v 276 | return nn.Sequential(*layers) 277 | 278 | 279 | cfg = { 280 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 281 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 282 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 283 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 284 | 512, 512, 512, 512, 'M'], 285 | } 286 | 287 | 288 | def vgg11(): 289 | """VGG 11-layer model (configuration "A")""" 290 | return VGG(make_layers(cfg['A'])) 291 | 292 | 293 | def vgg11_bn(): 294 | """VGG 11-layer model (configuration "A") with batch normalization""" 295 | return VGG(make_layers(cfg['A'], batch_norm=True)) 296 | 297 | 298 | def vgg13(): 299 | """VGG 13-layer model (configuration "B")""" 300 | return VGG(make_layers(cfg['B'])) 301 | 302 | 303 | def vgg13_bn(): 304 | """VGG 13-layer model (configuration "B") with batch normalization""" 305 | return VGG(make_layers(cfg['B'], batch_norm=True)) 306 | 307 | 308 | def vgg16(): 309 | """VGG 16-layer model (configuration "D")""" 310 | return VGG(make_layers(cfg['D'])) 311 | 312 | 313 | def vgg16_bn(): 314 | """VGG 16-layer model (configuration "D") with batch normalization""" 315 | return VGG(make_layers(cfg['D'], batch_norm=True)) 316 | 317 | 318 | def vgg19(): 319 | """VGG 19-layer model 
class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise once so any iterable (set, array, ...) can be indexed.
        self.idxs = [position for position in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return image, label
def __init__(self, args, dataset=None, idxs=None):
    """Build the client's data loader and cache the attack settings.

    Args:
        args: experiment namespace. The methods of this class read
            ``local_bs``, ``lr``, ``momentum``, ``local_ep``, ``device``,
            ``attack_label`` and ``model`` from it (``num_classes`` is
            optional, see ``train_malicious_labelflip``).
        dataset: dataset the client samples from.
        idxs: indices of ``dataset`` that belong to this client.
    """
    self.args = args
    self.loss_func = nn.CrossEntropyLoss()
    self.selected_clients = []
    self.ldr_train = DataLoader(
        DatasetSplit(dataset, idxs),
        batch_size=self.args.local_bs, shuffle=True)
    self.attack_label = args.attack_label
    self.model = args.model


def train(self, net, batch_hook=None):
    """Run ``args.local_ep`` epochs of SGD over the client's loader.

    Args:
        net: model to optimise; it is updated in place.
        batch_hook: optional callable ``(images, labels) -> (images, labels)``
            applied to each batch before it is moved to ``args.device``.
            The malicious trainers use it to poison batches; ``None`` means
            plain benign training.

    Returns:
        ``(net.state_dict(), mean_epoch_loss)``. The loss is ``0.0`` when
        no batch was processed (empty loader or ``local_ep == 0``) instead
        of raising ``ZeroDivisionError``.
    """
    net.train()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

    epoch_loss = []
    for _ in range(self.args.local_ep):
        batch_loss = []
        for images, labels in self.ldr_train:
            if batch_hook is not None:
                images, labels = batch_hook(images, labels)
            images = images.to(self.args.device)
            labels = labels.to(self.args.device)
            net.zero_grad()
            loss = self.loss_func(net(images), labels)
            loss.backward()
            optimizer.step()
            batch_loss.append(loss.item())
        # An empty loader yields no batches; skip it rather than divide by 0.
        if batch_loss:
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
    mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
    return net.state_dict(), mean_loss


def train_malicious_flipupdate(self, net, test_img=None, dataset_test=None, args=None):
    """Train on backdoored data, then mirror selected layers around the global model.

    Every batch is doubled with a copy whose top-left 5x5 patch is set to 1
    and whose labels are ``self.attack_label``. After training, each layer
    named in ``attack_list`` is replaced by ``2 * global - local``.

    Args:
        net: global model; trained in place.
        test_img: optional evaluation callable, invoked as
            ``test_img(net, dataset_test, args, test_backdoor=True)``.
        dataset_test: dataset forwarded to ``test_img``.
        args: namespace forwarded to ``test_img``.

    Returns:
        ``(attack_weight, mean_epoch_loss)``.
    """
    global_net_dict = copy.deepcopy(net.state_dict())

    def add_backdoor(images, labels):
        # Append a triggered, relabelled copy of the whole batch.
        bad_data = images.clone()
        bad_data[:, :, 0:5, 0:5] = 1
        bad_label = torch.full_like(labels, self.attack_label)
        return torch.cat((images, bad_data), dim=0), torch.cat((labels, bad_label))

    _, loss = self.train(net, batch_hook=add_backdoor)

    if test_img is not None:
        acc_test, _, backdoor_acc = test_img(
            net, dataset_test, args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))

    attack_list = ['linear.weight', 'conv1.weight', 'layer4.1.conv2.weight',
                   'layer4.1.conv1.weight', 'layer4.0.conv2.weight',
                   'layer4.0.conv1.weight']
    attack_weight = {}
    for key, var in net.state_dict().items():
        if key in attack_list:
            print("attack")
            # Reflect the trained layer around the global one.
            attack_weight[key] = 2 * global_net_dict[key] - var
        else:
            attack_weight[key] = var
    return attack_weight, loss


def train_malicious_layerAttack(self, net, test_img=None, dataset_test=None, args=None):
    """Combine a benign model with a few layers taken from a backdoored copy.

    A deep copy of ``net`` is trained on backdoored batches, ``net`` itself
    is trained on clean data, and the layers in ``attack_list`` of the
    benign result are overwritten with the backdoored ones.

    Args:
        net: global model; trained in place on clean data.
        test_img: optional evaluation callable, run on the backdoored copy.
        dataset_test: dataset forwarded to ``test_img``.
        args: namespace forwarded to ``test_img``.

    Returns:
        ``(attack_param, mean_epoch_loss)`` where the loss is that of the
        benign training pass.

    Raises:
        ValueError: if ``self.model`` is not ``'resnet'``; only that
            architecture has a layer list defined. (Previously this
            surfaced as a NameError after both training passes had run.)
    """
    if self.model != 'resnet':
        raise ValueError(
            "layer attack is only defined for model 'resnet', got {!r}".format(self.model))
    attack_list = ['linear.weight',
                   'layer4.1.conv2.weight', 'layer4.1.conv1.weight']

    def add_backdoor(images, labels):
        # Append a triggered, relabelled copy of the whole batch.
        bad_data = images.clone()
        bad_data[:, :, 0:5, 0:5] = 1
        bad_label = torch.full_like(labels, self.attack_label)
        return torch.cat((images, bad_data), dim=0), torch.cat((labels, bad_label))

    badnet = copy.deepcopy(net)
    bad_net_param, _ = self.train(badnet, batch_hook=add_backdoor)
    if test_img is not None:
        acc_test, _, backdoor_acc = test_img(
            badnet, dataset_test, args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))

    benign_param, loss = self.train(net)
    attack_param = {
        key: bad_net_param[key] if key in attack_list else var
        for key, var in benign_param.items()
    }
    return attack_param, loss


def train_malicious_labelflip(self, net, test_img=None, dataset_test=None, args=None):
    """Train with every label ``y`` replaced by ``(num_classes - 1) - y``.

    ``num_classes`` is read from ``self.args`` when present and defaults to
    10, which reproduces the original ``9 - y`` flip.

    Returns:
        ``(net.state_dict(), mean_epoch_loss)``.
    """
    top_label = getattr(self.args, 'num_classes', 10) - 1

    def flip(images, labels):
        return images, top_label - labels

    return self.train(net, batch_hook=flip)


def train_malicious_badnet(self, net, test_img=None, dataset_test=None, args=None):
    """BadNets-style training: trigger and relabel the head of every batch.

    The first ``min(n, int(0.2 * n) + 2)`` samples of a batch of size ``n``
    (the count the original early-exit loop produced, i.e. slightly more
    than 20%) get their top-left 5x5 patch set to that sample's maximum
    value and their label set to ``self.attack_label``.

    Args:
        net: model to train in place.
        test_img: optional evaluation callable.
        dataset_test: dataset forwarded to ``test_img``.
        args: namespace forwarded to ``test_img``.

    Returns:
        ``(net.state_dict(), mean_epoch_loss)``.
    """
    def poison_head(images, labels):
        n_poison = min(len(images), int(len(images) * 0.2) + 2)
        for idx in range(n_poison):
            labels[idx] = self.attack_label
            images[idx][:, 0:5, 0:5] = torch.max(images[idx])
        return images, labels

    _, loss = self.train(net, batch_hook=poison_head)
    if test_img is not None:
        acc_test, _, backdoor_acc = test_img(
            net, dataset_test, args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))
    return net.state_dict(), loss


def train_malicious_biasattack(self, net, test_img=None, dataset_test=None, args=None):
    """Train benignly, then multiply the first entry of ``linear.bias`` by 5.

    The scaling is applied in place on the tensor held by the state dict,
    so it is visible both in ``net`` and in the returned dictionary. Models
    without a ``linear.bias`` entry are returned as trained.

    Args:
        net: model to train in place.
        test_img: optional evaluation callable.
        dataset_test: dataset forwarded to ``test_img``.
        args: namespace forwarded to ``test_img``.

    Returns:
        ``(net.state_dict(), mean_epoch_loss)``.
    """
    _, loss = self.train(net)
    state = net.state_dict()
    if 'linear.bias' in state:
        print(state['linear.bias'][0])
        state['linear.bias'][0] *= 5
        print(state['linear.bias'][0])
    if test_img is not None:
        acc_test, _, backdoor_acc = test_img(
            net, dataset_test, args, test_backdoor=True)
        print("local Testing accuracy: {:.2f}".format(acc_test))
        print("local Backdoor accuracy: {:.2f}".format(backdoor_acc))
    return net.state_dict(), loss
# def save_pic(image):
#     io.imsave('x.jpg', images.reshape(28,28).numpy())
# -------------------------------------------------------------------------------- /main_fed.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | from random import random 6 | from models.test import test_img 7 | from models.Fed import FedAvg 8 | from models.Nets import ResNet18, vgg19_bn, vgg19, get_model, vgg11 9 | from models.resnet20 import resnet20 10 | from models.MaliciousUpdate import LocalMaliciousUpdate 11 | from models.Update import LocalUpdate 12 | from utils.info import print_exp_details, write_info_to_accfile, get_base_info 13 | from utils.options import args_parser 14 | from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid 15 | from utils.defense import fltrust, multi_krum, get_update, RLR, flame, get_update2, fld_distance, detection, detection1, parameters_dict_to_vector_flt, lbfgs_torch 16 | from models.Attacker import attacker 17 | import torch 18 | from torchvision import datasets, transforms 19 | import numpy as np 20 | import copy 21 | import matplotlib.pyplot as plt 22 | import matplotlib 23 | import os 24 | import random 25 | import time 26 | import math 27 | 28 | matplotlib.use('Agg') 29 | 30 | 31 | def write_file(filename, accu_list, back_list, args, analyse=False): 32 | write_info_to_accfile(filename, args) 33 | f = open(filename, "a") 34 | f.write("main_task_accuracy=") 35 | f.write(str(accu_list)) 36 | f.write('\n') 37 | f.write("backdoor_accuracy=") 38 | f.write(str(back_list)) 39 | if args.defence == "krum": 40 | krum_file = filename + "_krum_dis" 41 | torch.save(args.krum_distance, krum_file) 42 | if analyse == True: 43 | need_length = len(accu_list) // 10 44 | acc = accu_list[-need_length:] 45 | back = back_list[-need_length:] 46 | best_acc = round(max(acc), 2) 47 | average_back = round(np.mean(back), 2) 48 | best_back = round(max(back), 2) 49 | f.write('\n') 50 | f.write('BBSR:') 51 | f.write(str(best_back)) 52 | f.write('\n') 53 | f.write('ABSR:') 54 | f.write(str(average_back)) 55 | 
f.write('\n') 56 | f.write('max acc:') 57 | f.write(str(best_acc)) 58 | f.write('\n') 59 | f.close() 60 | return best_acc, average_back, best_back 61 | f.close() 62 | 63 | 64 | def central_dataset_iid(dataset, dataset_size): 65 | all_idxs = [i for i in range(len(dataset))] 66 | central_dataset = set(np.random.choice( 67 | all_idxs, dataset_size, replace=False)) 68 | return central_dataset 69 | 70 | 71 | def test_mkdir(path): 72 | if not os.path.isdir(path): 73 | os.mkdir(path) 74 | 75 | 76 | if __name__ == '__main__': 77 | # parse args 78 | args = args_parser() 79 | args.device = torch.device('cuda:{}'.format( 80 | args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu') 81 | test_mkdir('./' + args.save) 82 | print_exp_details(args) 83 | 84 | # load dataset and split users 85 | if args.dataset == 'mnist': 86 | trans_mnist = transforms.Compose( 87 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 88 | dataset_train = datasets.MNIST( 89 | '../data/mnist/', train=True, download=True, transform=trans_mnist) 90 | dataset_test = datasets.MNIST( 91 | '../data/mnist/', train=False, download=True, transform=trans_mnist) 92 | # sample users 93 | if args.iid: 94 | dict_users = mnist_iid(dataset_train, args.num_users) 95 | else: 96 | dict_users = mnist_noniid(dataset_train, args.num_users) 97 | elif args.dataset == 'fashion_mnist': 98 | trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.2860], std=[0.3530])]) 99 | dataset_train = datasets.FashionMNIST( 100 | '../data/', train=True, download=True, transform=trans_mnist) 101 | dataset_test = datasets.FashionMNIST( 102 | '../data/', train=False, download=True, transform=trans_mnist) 103 | # sample users 104 | if args.iid: 105 | dict_users = np.load('./data/iid_fashion_mnist.npy', allow_pickle=True).item() 106 | else: 107 | dict_users = np.load('./data/non_iid_fashion_mnist.npy', allow_pickle=True).item() 108 | elif args.dataset == 'cifar': 109 | trans_cifar = 
transforms.Compose( 110 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 111 | dataset_train = datasets.CIFAR10( 112 | '../data/cifar', train=True, download=True, transform=trans_cifar) 113 | dataset_test = datasets.CIFAR10( 114 | '../data/cifar', train=False, download=True, transform=trans_cifar) 115 | if args.iid: 116 | dict_users = np.load('./data/iid_cifar.npy', allow_pickle=True).item() 117 | else: 118 | dict_users = np.load('./data/non_iid_cifar.npy', allow_pickle=True).item() 119 | else: 120 | exit('Error: unrecognized dataset') 121 | img_size = dataset_train[0][0].shape 122 | 123 | # build model 124 | if args.model == 'VGG' and args.dataset == 'cifar': 125 | net_glob = vgg19_bn().to(args.device) 126 | elif args.model == 'VGG11' and args.dataset == 'cifar': 127 | net_glob = vgg11().to(args.device) 128 | elif args.model == "resnet" and args.dataset == 'cifar': 129 | net_glob = ResNet18().to(args.device) 130 | elif args.model == "resnet20" and args.dataset == 'cifar': 131 | net_glob = resnet20().to(args.device) 132 | elif args.model == "rlr_mnist" or args.model == "cnn": 133 | net_glob = get_model('fmnist').to(args.device) 134 | else: 135 | exit('Error: unrecognized model') 136 | 137 | if args.attack=='baseline': 138 | args.attack='badnet' 139 | if args.defence == 'Fedavg': 140 | args.defence = 'avg' 141 | if args.model == 'cnn': 142 | args.model = 'rlr_mnist' 143 | net_glob.train() 144 | if args.defence == 'fldetector': 145 | args.defence = 'fld' 146 | 147 | # copy weights 148 | w_glob = net_glob.state_dict() 149 | 150 | # training 151 | loss_train = [] 152 | cv_loss, cv_acc = [], [] 153 | val_loss_pre, counter = 0, 0 154 | net_best = None 155 | best_loss = None 156 | 157 | if args.defence == 'fld': 158 | old_update_list = [] 159 | weight_record = [] 160 | update_record = [] 161 | args.frac = 1 162 | malicious_score = torch.zeros((1, 100)) 163 | 164 | if math.isclose(args.malicious, 0): 165 | backdoor_begin_acc = 100 166 | 
else: 167 | backdoor_begin_acc = args.attack_begin # overtake backdoor_begin_acc then attack 168 | central_dataset = central_dataset_iid(dataset_test, args.server_dataset) 169 | base_info = get_base_info(args) 170 | filename = './' + args.save + '/accuracy_file_{}.txt'.format(base_info) 171 | 172 | if args.init != 'None': 173 | param = torch.load(args.init) 174 | net_glob.load_state_dict(param) 175 | print("load init model") 176 | 177 | val_acc_list, net_list = [0.0001], [] 178 | backdoor_acculist = [0] 179 | 180 | args.attack_layers = [] 181 | 182 | if args.attack == "dba": 183 | args.dba_sign = 0 184 | if args.defence == "krum": 185 | args.krum_distance = [] 186 | malicious_list = [] 187 | for i in range(int(args.num_users * args.malicious)): 188 | malicious_list.append(i) 189 | 190 | if args.all_clients: 191 | print("Aggregation over all clients") 192 | w_locals = [w_glob for i in range(args.num_users)] 193 | for iter in range(args.epochs): 194 | loss_locals = [] 195 | if not args.all_clients: 196 | w_locals = [] 197 | w_updates = [] 198 | m = max(int(args.frac * args.num_users), 1) 199 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 200 | if args.defence == 'fld': 201 | idxs_users = np.arange(args.num_users) 202 | if iter == 350: 203 | args.lr *= 0.1 204 | if backdoor_begin_acc < val_acc_list[-1]: 205 | backdoor_begin_acc = 0 206 | attack_number = int(args.malicious * m) 207 | else: 208 | attack_number = 0 209 | skip_number=0 210 | mal_weight=[] 211 | mal_loss=[] 212 | args.attack_layers=[] 213 | 214 | for num_turn, idx in enumerate(idxs_users): 215 | if attack_number > 0 and skip_number == 0: 216 | if args.defence == 'fld': 217 | args.old_update_list = old_update_list[0:int(args.malicious * m)] 218 | m_idx = idx 219 | else: 220 | m_idx = None 221 | mal_weight, loss, args.attack_layers = attacker(malicious_list, attack_number, args.attack, dataset_train, dataset_test, dict_users, net_glob, args, idx = m_idx) 222 | attack_number -= 1 223 
| if args.attack == 'adaptive': 224 | skip_number = attack_number 225 | if skip_number == 0: 226 | w = mal_weight[0] 227 | else: 228 | w = mal_weight[0] 229 | elif skip_number > 0: 230 | w = mal_weight[-skip_number] 231 | skip_number -= 1 232 | attack_number -= 1 233 | else: 234 | local = LocalUpdate( 235 | args=args, dataset=dataset_train, idxs=dict_users[idx]) 236 | w, loss = local.train( 237 | net=copy.deepcopy(net_glob).to(args.device)) 238 | if args.defence == 'fld': 239 | w_updates.append(get_update2(w, w_glob)) #ignore num_batches_tracked, running_mean, running_var 240 | else: 241 | w_updates.append(get_update(w, w_glob)) 242 | if args.all_clients: 243 | w_locals[idx] = copy.deepcopy(w) 244 | else: 245 | w_locals.append(copy.deepcopy(w)) 246 | loss_locals.append(copy.deepcopy(loss)) 247 | 248 | if args.defence == 'avg': # no defence 249 | w_glob = FedAvg(w_locals) 250 | elif args.defence == 'krum': # single krum 251 | selected_client = multi_krum(w_updates, 1, args) 252 | # print(args.krum_distance) 253 | w_glob = w_locals[selected_client[0]] 254 | # w_glob = FedAvg([w_locals[i] for i in selected_clinet]) 255 | elif args.defence == 'multikrum': 256 | selected_client = multi_krum(w_updates, args.k, args, multi_k=True) 257 | # print(selected_client) 258 | w_glob = FedAvg([w_locals[x] for x in selected_client]) 259 | elif args.defence == 'RLR': 260 | w_glob = RLR(copy.deepcopy(net_glob), w_updates, args) 261 | elif args.defence == 'fltrust': 262 | local = LocalUpdate( 263 | args=args, dataset=dataset_test, idxs=central_dataset) 264 | fltrust_norm, loss = local.train( 265 | net=copy.deepcopy(net_glob).to(args.device)) 266 | fltrust_norm = get_update(fltrust_norm, w_glob) 267 | w_glob = fltrust(w_updates, fltrust_norm, w_glob, args) 268 | elif args.defence == 'flame': 269 | w_glob = flame(w_locals, w_updates, w_glob, args, debug=args.debug) 270 | 271 | 272 | elif args.defence == 'fld': 273 | # ignore key.split('.')[-1] == 'num_batches_tracked' or 
key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var' 274 | N = 5 275 | args.N = N 276 | weight = parameters_dict_to_vector_flt(w_glob) 277 | local_update_list = [] 278 | for local in w_updates: 279 | local_update_list.append(-1*parameters_dict_to_vector_flt(local).cpu()) # change to 1 dimension 280 | 281 | if iter > N+1: 282 | hvp = lbfgs_torch(args, weight_record, update_record, weight - last_weight) 283 | 284 | attack_number = int(args.malicious * m) 285 | distance = fld_distance(old_update_list, local_update_list, net_glob, attack_number, hvp) 286 | distance = distance.view(1,-1) 287 | print('main.py line 320 distance:',distance) 288 | malicious_score = torch.cat((malicious_score, distance), dim=0) 289 | if malicious_score.shape[0] > N+1: 290 | if detection1(np.sum(malicious_score[-N:].numpy(), axis=0)): 291 | 292 | label = detection(np.sum(malicious_score[-N:].numpy(), axis=0), int(args.malicious * m)) 293 | else: 294 | label = np.ones(100) 295 | selected_client = [] 296 | for client in range(100): 297 | if label[client] == 1: 298 | selected_client.append(client) 299 | new_w_glob = FedAvg([w_locals[client] for client in selected_client]) 300 | else: 301 | new_w_glob = FedAvg(w_locals) #avg 302 | else: 303 | hvp = None 304 | new_w_glob = FedAvg(w_locals) #avg 305 | 306 | 307 | 308 | update = get_update2(w_glob, new_w_glob) #w_t+1 = w_t - a*g_t => g_t = w_t - w_t+1 (a=1) 309 | update = parameters_dict_to_vector_flt(update) 310 | if iter > 0: 311 | weight_record.append(weight.cpu() - last_weight.cpu()) 312 | update_record.append(update.cpu() - last_update.cpu()) 313 | if iter > N: 314 | del weight_record[0] 315 | del update_record[0] 316 | 317 | last_weight = weight 318 | last_update = update 319 | old_update_list = local_update_list 320 | w_glob = new_w_glob 321 | 322 | 323 | else: 324 | print("Wrong Defense Method") 325 | os._exit(0) 326 | 327 | # copy weight to net_glob 328 | net_glob.load_state_dict(w_glob) 329 | 330 | # print loss 331 | 
loss_avg = sum(loss_locals) / len(loss_locals) 332 | print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 333 | loss_train.append(loss_avg) 334 | 335 | if iter % 1 == 0: 336 | acc_test, _, back_acc = test_img( 337 | net_glob, dataset_test, args, test_backdoor=True) 338 | print("Main accuracy: {:.2f}".format(acc_test)) 339 | print("Backdoor accuracy: {:.2f}".format(back_acc)) 340 | val_acc_list.append(acc_test.item()) 341 | 342 | backdoor_acculist.append(back_acc) 343 | write_file(filename, val_acc_list, backdoor_acculist, args) 344 | 345 | best_acc, absr, bbsr = write_file(filename, val_acc_list, backdoor_acculist, args, True) 346 | 347 | # plot loss curve 348 | plt.figure() 349 | plt.xlabel('communication') 350 | plt.ylabel('accu_rate') 351 | plt.plot(val_acc_list, label='main task(acc:' + str(best_acc) + '%)') 352 | plt.plot(backdoor_acculist, label='backdoor task(BBSR:' + str(bbsr) + '%, ABSR:' + str(absr) + '%)') 353 | plt.legend() 354 | title = base_info 355 | # plt.title(title, y=-0.3) 356 | plt.title(title) 357 | plt.savefig('./' + args.save + '/' + title + '.pdf', format='pdf', bbox_inches='tight') 358 | 359 | # testing 360 | net_glob.eval() 361 | acc_train, loss_train = test_img(net_glob, dataset_train, args) 362 | acc_test, loss_test = test_img(net_glob, dataset_test, args) 363 | print("Training accuracy: {:.2f}".format(acc_train)) 364 | print("Testing accuracy: {:.2f}".format(acc_test)) 365 | 366 | torch.save(net_glob.state_dict(),'./' + args.save + '/model' + '.pth') 367 | -------------------------------------------------------------------------------- /utils/defense.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | import numpy as np 3 | import torch 4 | import copy 5 | import time 6 | import hdbscan 7 | from sklearn.cluster import KMeans 8 | from sklearn.metrics import silhouette_score 9 | import torch.nn as nn 10 | from sklearn.metrics import roc_auc_score 11 | 12 | 13 | 
def cos(a, b):
    """Return the cosine similarity of two 1-D numpy vectors, clipped at 0.

    A small epsilon is added to the dot product and to both norms so that
    zero vectors do not divide by zero.
    """
    res = (np.dot(a, b) + 1e-9) / (np.linalg.norm(a) + 1e-9) / \
        (np.linalg.norm(b) + 1e-9)
    # ReLU: negative similarity counts as no similarity.
    if res < 0:
        res = 0
    return res


def fltrust(params, central_param, global_parameters, args):
    """Aggregate client updates with FLTrust-style trust scores.

    Each client update is weighted by the ReLU-clipped cosine similarity to
    the server update and rescaled to the server update's norm; the
    weighted mean, times ``args.server_lr``, is added to the global model.

    Args:
        params: list of client update dicts (``name -> tensor``).
        central_param: the server's own update dict.
        global_parameters: current global state dict; modified in place.
        args: namespace providing ``server_lr``.

    Returns:
        ``global_parameters`` (the same, updated, dictionary). It is
        returned unchanged when no client has a positive trust score.
    """
    central_param_v = parameters_dict_to_vector_flt(central_param)
    central_norm = torch.norm(central_param_v)
    total_score = 0
    sum_parameters = None
    for local_parameters in params:
        local_parameters_v = parameters_dict_to_vector_flt(local_parameters)
        # Cosine-similarity trust score, clipped at zero.
        client_cos = nn.functional.cosine_similarity(
            central_param_v, local_parameters_v, dim=0, eps=1e-6)
        client_cos = max(client_cos.item(), 0)
        if client_cos == 0:
            # A zero-trust client contributes nothing. Skipping it also
            # avoids 0 * inf = NaN when its update has zero norm.
            continue
        # Trust score times the factor that clips the update to the
        # server update's length.
        scale = client_cos * (central_norm / torch.norm(local_parameters_v))
        total_score += client_cos
        if sum_parameters is None:
            sum_parameters = {key: scale * var.clone()
                              for key, var in local_parameters.items()}
        else:
            for key in sum_parameters:
                sum_parameters[key] = sum_parameters[key] + scale * local_parameters[key]
    if total_score == 0:
        return global_parameters
    for var in global_parameters:
        # Normalise by the sum of all clients' trust scores.
        temp = sum_parameters[var] / total_score
        if global_parameters[var].type() != temp.type():
            temp = temp.type(global_parameters[var].type())
        if var.split('.')[-1] == 'num_batches_tracked':
            # NOTE(review): this copies the first client's entry verbatim;
            # confirm that ``params`` carries the intended counter value.
            global_parameters[var] = params[0][var]
        else:
            global_parameters[var] += temp * args.server_lr
    return global_parameters


# Batch-norm bookkeeping entries that are excluded from flattened vectors.
_FLT_SKIPPED_SUFFIXES = ('num_batches_tracked', 'running_mean', 'running_var')


def parameters_dict_to_vector_flt(net_dict) -> torch.Tensor:
    """Flatten a state dict into one vector, skipping batch-norm statistics.

    Entries whose last name component is ``num_batches_tracked``,
    ``running_mean`` or ``running_var`` are left out; the rest are
    concatenated in dictionary order.
    """
    return torch.cat([
        param.view(-1) for key, param in net_dict.items()
        if key.split('.')[-1] not in _FLT_SKIPPED_SUFFIXES
    ])


def parameters_dict_to_vector_flt_cpu(net_dict) -> torch.Tensor:
    """Same as ``parameters_dict_to_vector_flt`` but moves each tensor to the CPU first."""
    return torch.cat([
        param.cpu().view(-1) for key, param in net_dict.items()
        if key.split('.')[-1] not in _FLT_SKIPPED_SUFFIXES
    ])


def no_defence_balance(params, global_parameters):
    """Add the plain average of the client updates to the global model.

    Args:
        params: list of client update dicts.
        global_parameters: global state dict; modified in place.

    Returns:
        ``global_parameters``. With an empty ``params`` list it is returned
        unchanged instead of failing on the missing sum.
    """
    total_num = len(params)
    if total_num == 0:
        return global_parameters
    sum_parameters = {key: var.clone() for key, var in params[0].items()}
    for client in params[1:]:
        for key in sum_parameters:
            sum_parameters[key] = sum_parameters[key] + client[key]
    for var in global_parameters:
        if var.split('.')[-1] == 'num_batches_tracked':
            global_parameters[var] = params[0][var]
            continue
        global_parameters[var] += (sum_parameters[var] / total_num)

    return global_parameters


def multi_krum(gradients, n_attackers, args, multi_k=False):
    """Select client updates with (multi-)Krum.

    In every round each remaining update is scored by the sum of its
    ``len(remaining) - 2 - n_attackers`` smallest squared distances to the
    remaining updates (its own zero distance included); the lowest-scoring
    update is selected and removed. Plain Krum stops after one round,
    multi-Krum continues while more than ``2 * n_attackers + 2`` remain.

    Args:
        gradients: list of client update dicts, flattened by ``flatten_grads``.
        n_attackers: assumed number of malicious clients.
        args: namespace providing ``frac``, ``num_users`` and ``malicious``.
            The counters ``turn``, ``wrong_mal``, ``mal_score`` and
            ``ben_score`` are updated on it and must already exist.
        multi_k: run multi-Krum instead of single Krum.

    Returns:
        ``numpy`` array with the indices of the selected clients.

    Raises:
        ValueError: if there are not more than ``2 * n_attackers + 2``
            updates, in which case no selection round can run.
    """
    grads = flatten_grads(gradients)

    remaining_updates = torch.from_numpy(grads)
    all_indices = np.arange(len(grads))
    if len(remaining_updates) <= 2 * n_attackers + 2:
        raise ValueError(
            "multi_krum needs more than 2 * n_attackers + 2 = {} updates, got {}".format(
                2 * n_attackers + 2, len(remaining_updates)))

    candidate_indices = []
    scores = None
    while len(remaining_updates) > 2 * n_attackers + 2:
        n_closest = len(remaining_updates) - 2 - n_attackers
        # One row of squared distances per update.
        distances = torch.stack([
            (torch.norm(remaining_updates - update, dim=1) ** 2).float()
            for update in remaining_updates
        ])
        distances = torch.sort(distances, dim=1)[0]
        scores = torch.sum(distances[:, :n_closest], dim=1)
        best = int(torch.argsort(scores)[0])

        candidate_indices.append(all_indices[best])
        all_indices = np.delete(all_indices, best)
        remaining_updates = torch.cat(
            (remaining_updates[:best], remaining_updates[best + 1:]), 0)
        if not multi_k:
            break

    # Bookkeeping: clients with index below ``num_malicious_clients`` are
    # counted as malicious.
    num_clients = max(int(args.frac * args.num_users), 1)
    num_malicious_clients = int(args.malicious * num_clients)
    args.turn += 1
    args.wrong_mal += sum(
        1 for selected_client in candidate_indices
        if selected_client < num_malicious_clients)

    for i, score in enumerate(scores):
        if i < num_malicious_clients:
            args.mal_score += score
        else:
            args.ben_score += score

    return np.array(candidate_indices)


# NOTE: only the head of flatten_grads falls inside this span; the array
# conversion and the return statement follow it unchanged.
def flatten_grads(gradients):

    param_order = gradients[0].keys()

    flat_epochs = []

    for n_user in range(len(gradients)):
        user_arr = []
        grads = gradients[n_user]
        for param in param_order:
            try:
                user_arr.extend(grads[param].cpu().numpy().flatten().tolist())
            except:
                user_arr.extend(
                    [grads[param].cpu().numpy().flatten().tolist()])
        flat_epochs.append(user_arr)
182 | 183 | flat_epochs = np.array(flat_epochs) 184 | 185 | return flat_epochs 186 | 187 | 188 | 189 | 190 | def get_update(update, model): 191 | '''get the update weight''' 192 | update2 = {} 193 | for key, var in update.items(): 194 | update2[key] = update[key] - model[key] 195 | return update2 196 | 197 | def get_update2(update, model): 198 | '''get the update weight''' 199 | update2 = {} 200 | for key, var in update.items(): 201 | if key.split('.')[-1] == 'num_batches_tracked' or key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var': 202 | continue 203 | update2[key] = update[key] - model[key] 204 | return update2 205 | 206 | 207 | def fld_distance(old_update_list, local_update_list, net_glob, attack_number, hvp): 208 | pred_update = [] 209 | distance = [] 210 | for i in range(len(old_update_list)): 211 | pred_update.append((old_update_list[i] + hvp).view(-1)) 212 | 213 | 214 | pred_update = torch.stack(pred_update) 215 | local_update_list = torch.stack(local_update_list) 216 | old_update_list = torch.stack(old_update_list) 217 | 218 | distance = torch.norm((old_update_list - local_update_list), dim=1) 219 | # print('defense line219 distance(old_update_list - local_update_list):',distance) 220 | # auc1 = roc_auc_score(pred_update.numpy(), distance) 221 | # distance = torch.norm((pred_update - local_update_list), dim=1).numpy() 222 | # auc2 = roc_auc_score(pred_update.numpy(), distance) 223 | # print("Detection AUC: %0.4f; Detection AUC: %0.4f" % (auc1, auc2)) 224 | 225 | # print('defence line 211 pred_update.shape:', pred_update.shape) 226 | distance = torch.norm((pred_update - local_update_list), dim=1) 227 | # print('defence line 211 distance.shape:', distance.shape) 228 | # distance = nn.functional.norm((pred_update - local_update_list), dim=0).numpy() 229 | distance = distance / torch.sum(distance) 230 | return distance 231 | 232 | def detection(score, nobyz): 233 | estimator = KMeans(n_clusters=2) 234 | estimator.fit(score.reshape(-1, 
1)) 235 | label_pred = estimator.labels_ 236 | 237 | if np.mean(score[label_pred==0]) 0: 278 | gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i] 279 | # print('defense line278 gapDiff:', gapDiff) 280 | select_k = 2 # default detect attacks 281 | for i in range(len(gapDiff)): 282 | if gapDiff[i] >= 0: 283 | select_k = i+1 284 | break 285 | if select_k == 1: 286 | print('No attack detected!') 287 | return 0 288 | else: 289 | print('Attack Detected!') 290 | return 1 291 | 292 | def RLR(global_model, agent_updates_list, args): 293 | """ 294 | agent_updates_dict: dict['key']=one_dimension_update 295 | agent_updates_list: list[0] = model.dict 296 | global_model: net 297 | """ 298 | # args.robustLR_threshold = 6 299 | args.server_lr = 1 300 | 301 | grad_list = [] 302 | for i in agent_updates_list: 303 | grad_list.append(parameters_dict_to_vector_rlr(i)) 304 | agent_updates_list = grad_list 305 | 306 | 307 | aggregated_updates = 0 308 | for update in agent_updates_list: 309 | # print(update.shape) # torch.Size([1199882]) 310 | aggregated_updates += update 311 | aggregated_updates /= len(agent_updates_list) 312 | lr_vector = compute_robustLR(agent_updates_list, args) 313 | cur_global_params = parameters_dict_to_vector_rlr(global_model.state_dict()) 314 | new_global_params = (cur_global_params + lr_vector*aggregated_updates).float() 315 | global_w = vector_to_parameters_dict(new_global_params, global_model.state_dict()) 316 | # print(cur_global_params == vector_to_parameters_dict(new_global_params, global_model.state_dict())) 317 | return global_w 318 | 319 | def parameters_dict_to_vector_rlr(net_dict) -> torch.Tensor: 320 | r"""Convert parameters to one vector 321 | 322 | Args: 323 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 324 | parameters of a model. 
325 | 326 | Returns: 327 | The parameters represented by a single vector 328 | """ 329 | vec = [] 330 | for key, param in net_dict.items(): 331 | vec.append(param.view(-1)) 332 | return torch.cat(vec) 333 | 334 | def parameters_dict_to_vector(net_dict) -> torch.Tensor: 335 | r"""Convert parameters to one vector 336 | 337 | Args: 338 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 339 | parameters of a model. 340 | 341 | Returns: 342 | The parameters represented by a single vector 343 | """ 344 | vec = [] 345 | for key, param in net_dict.items(): 346 | if key.split('.')[-1] != 'weight' and key.split('.')[-1] != 'bias': 347 | continue 348 | vec.append(param.view(-1)) 349 | return torch.cat(vec) 350 | 351 | 352 | 353 | def vector_to_parameters_dict(vec: torch.Tensor, net_dict) -> None: 354 | r"""Convert one vector to the parameters 355 | 356 | Args: 357 | vec (Tensor): a single vector represents the parameters of a model. 358 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 359 | parameters of a model. 
360 | """ 361 | 362 | pointer = 0 363 | for param in net_dict.values(): 364 | # The length of the parameter 365 | num_param = param.numel() 366 | # Slice the vector, reshape it, and replace the old data of the parameter 367 | param.data = vec[pointer:pointer + num_param].view_as(param).data 368 | 369 | # Increment the pointer 370 | pointer += num_param 371 | return net_dict 372 | 373 | def compute_robustLR(params, args): 374 | agent_updates_sign = [torch.sign(update) for update in params] 375 | sm_of_signs = torch.abs(sum(agent_updates_sign)) 376 | # print(len(agent_updates_sign)) #10 377 | # print(agent_updates_sign[0].shape) #torch.Size([1199882]) 378 | sm_of_signs[sm_of_signs < args.robustLR_threshold] = -args.server_lr 379 | sm_of_signs[sm_of_signs >= args.robustLR_threshold] = args.server_lr 380 | return sm_of_signs.to(args.gpu) 381 | 382 | 383 | 384 | 385 | def flame(local_model, update_params, global_model, args, debug=False): 386 | cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6).cuda() 387 | cos_list=[] 388 | local_model_vector = [] 389 | for param in local_model: 390 | # local_model_vector.append(parameters_dict_to_vector_flt_cpu(param)) 391 | local_model_vector.append(parameters_dict_to_vector_flt(param)) 392 | for i in range(len(local_model_vector)): 393 | cos_i = [] 394 | for j in range(len(local_model_vector)): 395 | cos_ij = 1- cos(local_model_vector[i],local_model_vector[j]) 396 | # cos_i.append(round(cos_ij.item(),4)) 397 | cos_i.append(cos_ij.item()) 398 | cos_list.append(cos_i) 399 | if debug==True: 400 | filename = './' + args.save + '/flame_analysis.txt' 401 | f = open(filename, "a") 402 | for i in cos_list: 403 | f.write(str(i)) 404 | # print(i) 405 | f.write('\n') 406 | f.write('\n') 407 | f.write("--------Round--------") 408 | f.write('\n') 409 | num_clients = max(int(args.frac * args.num_users), 1) 410 | num_malicious_clients = int(args.malicious * num_clients) 411 | num_benign_clients = num_clients - num_malicious_clients 412 | clusterer = 
hdbscan.HDBSCAN(min_cluster_size=num_clients//2 + 1,min_samples=1,allow_single_cluster=True).fit(cos_list) 413 | # print(clusterer.labels_) 414 | benign_client = [] 415 | norm_list = np.array([]) 416 | 417 | max_num_in_cluster=0 418 | max_cluster_index=0 419 | if clusterer.labels_.max() < 0: 420 | for i in range(len(local_model)): 421 | benign_client.append(i) 422 | norm_list = np.append(norm_list,torch.norm(parameters_dict_to_vector(update_params[i]),p=2).item()) 423 | else: 424 | for index_cluster in range(clusterer.labels_.max()+1): 425 | if len(clusterer.labels_[clusterer.labels_==index_cluster]) > max_num_in_cluster: 426 | max_cluster_index = index_cluster 427 | max_num_in_cluster = len(clusterer.labels_[clusterer.labels_==index_cluster]) 428 | for i in range(len(clusterer.labels_)): 429 | if clusterer.labels_[i] == max_cluster_index: 430 | benign_client.append(i) 431 | # norm_list = np.append(norm_list,torch.norm(update_params_vector[i],p=2)) # consider BN 432 | norm_list = np.append(norm_list,torch.norm(parameters_dict_to_vector(update_params[i]),p=2).item()) # no consider BN 433 | # print(benign_client) 434 | 435 | for i in range(len(benign_client)): 436 | if benign_client[i] < num_malicious_clients: 437 | args.wrong_mal+=1 438 | else: 439 | # minus per benign in cluster 440 | args.right_ben += 1 441 | args.turn+=1 442 | # print('proportion of malicious are selected:',args.wrong_mal/(num_malicious_clients*args.turn)) 443 | # print('proportion of benign are selected:',args.right_ben/(num_benign_clients*args.turn)) 444 | 445 | clip_value = np.median(norm_list) 446 | for i in range(len(benign_client)): 447 | gama = clip_value/norm_list[i] 448 | if gama < 1: 449 | for key in update_params[benign_client[i]]: 450 | if key.split('.')[-1] == 'num_batches_tracked': 451 | continue 452 | update_params[benign_client[i]][key] *= gama 453 | global_model = no_defence_balance([update_params[i] for i in benign_client], global_model) 454 | #add noise 455 | for key, var in 
global_model.items(): 456 | if key.split('.')[-1] == 'num_batches_tracked': 457 | continue 458 | temp = copy.deepcopy(var) 459 | temp = temp.normal_(mean=0,std=args.noise*clip_value) 460 | var += temp 461 | return global_model 462 | 463 | 464 | def flame_analysis(local_model, args, debug=False): 465 | cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6).cuda() 466 | cos_list=[] 467 | local_model_vector = [] 468 | for param in local_model: 469 | local_model_vector.append(parameters_dict_to_vector_flt(param)) 470 | for i in range(len(local_model_vector)): 471 | cos_i = [] 472 | for j in range(len(local_model_vector)): 473 | cos_ij = 1- cos(local_model_vector[i],local_model_vector[j]) 474 | # cos_i.append(round(cos_ij.item(),4)) 475 | cos_i.append(cos_ij.item()) 476 | cos_list.append(cos_i) 477 | if debug==True: 478 | filename = './' + args.save + '/flame_analysis.txt' 479 | f = open(filename, "a") 480 | for i in cos_list: 481 | f.write(str(i)) 482 | f.write('/n') 483 | f.write('/n') 484 | f.write("--------Round--------") 485 | f.write('/n') 486 | num_clients = max(int(args.frac * args.num_users), 1) 487 | num_malicious_clients = int(args.malicious * num_clients) 488 | num_benign_clients = num_clients - num_malicious_clients 489 | clusterer = hdbscan.HDBSCAN(min_cluster_size=num_clients//2 + 1,min_samples=1,allow_single_cluster=True).fit(cos_list) 490 | # print(clusterer.labels_) 491 | benign_client = [] 492 | 493 | max_num_in_cluster=0 494 | max_cluster_index=0 495 | if clusterer.labels_.max() < 0: 496 | for i in range(len(local_model)): 497 | benign_client.append(i) 498 | else: 499 | for index_cluster in range(clusterer.labels_.max()+1): 500 | if len(clusterer.labels_[clusterer.labels_==index_cluster]) > max_num_in_cluster: 501 | max_cluster_index = index_cluster 502 | max_num_in_cluster = len(clusterer.labels_[clusterer.labels_==index_cluster]) 503 | for i in range(len(clusterer.labels_)): 504 | if clusterer.labels_[i] == max_cluster_index: 505 | 
benign_client.append(i) 506 | return benign_client 507 | 508 | def lbfgs(args, S_k_list, Y_k_list, v): 509 | curr_S_k = nd.concat(*S_k_list, dim=1) 510 | curr_Y_k = nd.concat(*Y_k_list, dim=1) 511 | S_k_time_Y_k = nd.dot(curr_S_k.T, curr_Y_k) 512 | S_k_time_S_k = nd.dot(curr_S_k.T, curr_S_k) 513 | R_k = np.triu(S_k_time_Y_k.asnumpy()) 514 | L_k = S_k_time_Y_k - nd.array(R_k, ctx=mx.gpu(args.gpu)) 515 | sigma_k = nd.dot(Y_k_list[-1].T, S_k_list[-1]) / (nd.dot(S_k_list[-1].T, S_k_list[-1])) 516 | D_k_diag = nd.diag(S_k_time_Y_k) 517 | upper_mat = nd.concat(*[sigma_k * S_k_time_S_k, L_k], dim=1) 518 | lower_mat = nd.concat(*[L_k.T, -nd.diag(D_k_diag)], dim=1) 519 | mat = nd.concat(*[upper_mat, lower_mat], dim=0) 520 | mat_inv = nd.linalg.inverse(mat) 521 | 522 | approx_prod = sigma_k * v 523 | p_mat = nd.concat(*[nd.dot(curr_S_k.T, sigma_k * v), nd.dot(curr_Y_k.T, v)], dim=0) 524 | approx_prod -= nd.dot(nd.dot(nd.concat(*[sigma_k * curr_S_k, curr_Y_k], dim=1), mat_inv), p_mat) 525 | 526 | return approx_prod 527 | 528 | # def lbfgs_torch(args, S_k_list, Y_k_list, v): 529 | # # curr_S_k = nd.concat(*S_k_list, dim=1) 530 | # # curr_Y_k = nd.concat(*Y_k_list, dim=1) 531 | # curr_S_k = S_k_list 532 | # curr_Y_k = Y_k_list 533 | # S_k_time_Y_k = torch.dot(curr_S_k.T, curr_Y_k) 534 | # S_k_time_S_k = torch.dot(curr_S_k.T, curr_S_k) 535 | # R_k = np.triu(S_k_time_Y_k.numpy()) 536 | # L_k = S_k_time_Y_k - torch.array(R_k).to(args.gpu) 537 | # sigma_k = torch.dot(Y_k_list[-1].T, S_k_list[-1]) / (nd.dot(S_k_list[-1].T, S_k_list[-1])) 538 | # D_k_diag = torch.diag(S_k_time_Y_k) 539 | # upper_mat = torch.concat(*[sigma_k * S_k_time_S_k, L_k], dim=1) 540 | # lower_mat = torch.concat(*[L_k.T, -nd.diag(D_k_diag)], dim=1) 541 | # mat = torch.concat(*[upper_mat, lower_mat], dim=0) 542 | # mat_inv = torch.linalg.inv(mat) 543 | 544 | # approx_prod = sigma_k * v 545 | # p_mat = torch.concat(*[nd.dot(curr_S_k.T, sigma_k * v), nd.dot(curr_Y_k.T, v)], dim=0) 546 | # approx_prod -= 
#     torch.dot(torch.dot(torch.concat(*[sigma_k * curr_S_k, curr_Y_k], dim=1), mat_inv), p_mat)

#     return approx_prod

def lbfgs_torch(args, S_k_list, Y_k_list, v):
    """Approximate a Hessian-vector product from stored L-BFGS history.

    The stacked histories are combined into a small block matrix that is
    inverted and used to correct the scaled input vector. All of the
    algebra is carried out on CPU tensors.

    Args:
        args: Not read by this function; kept for call-site compatibility.
        S_k_list: List of tensors forming the ``S`` history.
            Assumes each entry is a 1-D tensor of equal length -- TODO confirm.
        Y_k_list: List of tensors forming the ``Y`` history, matching
            ``S_k_list`` entry for entry (same assumption).
        v: Tensor to multiply; it is flattened into a single column.

    Returns:
        The approximate product as a row tensor (the transposed column).
    """
    # Stack the history, put one entry per column, then move to CPU.
    s_cols = torch.stack(S_k_list).transpose(0, 1).cpu()
    y_cols = torch.stack(Y_k_list).transpose(0, 1).cpu()
    s_rows = s_cols.transpose(0, 1)
    y_rows = y_cols.transpose(0, 1)

    # Small Gram-style products between the two histories.
    sty = s_rows @ y_cols
    sts = s_rows @ s_cols

    # Subtracting the upper triangle (diagonal included) leaves only the
    # strictly lower-triangular part of S^T Y.
    upper_part = torch.from_numpy(np.triu(sty.numpy()))
    strict_lower = sty - upper_part

    # Scaling factor taken from the most recent (s, y) pair: (y.s) / (s.s).
    newest_s = S_k_list[-1].view(-1, 1)
    newest_y = Y_k_list[-1].view(-1, 1)
    scale = (
        (newest_y.transpose(0, 1) @ newest_s)
        / (newest_s.transpose(0, 1) @ newest_s)
    ).cpu()

    # Assemble the 2x2 block matrix and invert it.
    top = torch.cat([scale * sts, strict_lower], dim=1)
    bottom = torch.cat(
        [strict_lower.transpose(0, 1), -torch.diag(sty.diagonal())], dim=1
    )
    middle_inv = torch.inverse(torch.cat([top, bottom], dim=0))

    v = v.view(-1, 1).cpu()

    # Projections of v onto both histories, stacked into one column.
    projections = torch.cat([s_rows @ (scale * v), y_rows @ v], dim=0)
    wide = torch.cat([scale * s_cols, y_cols], dim=1)

    approx_prod = scale * v
    approx_prod -= wide @ middle_inv @ projections

    return approx_prod.transpose(0, 1)
-------------------------------------------------------------------------------- /utils/.ipynb_checkpoints/defense-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding = utf-8 -*- 2 | import numpy as np 3 | import torch 4 | import copy 5 | import time 6 | import hdbscan 7 | from sklearn.cluster import KMeans 8 | from sklearn.metrics import silhouette_score 9 | import torch.nn as nn 10 | from sklearn.metrics import roc_auc_score 11 | 12 | 13 | def cos(a, b): 14 | # res = np.sum(a*b.T)/((np.sqrt(np.sum(a * a.T)) + 1e-9) * (np.sqrt(np.sum(b * b.T))) + 1e- 15 | res = (np.dot(a, b) + 1e-9) / (np.linalg.norm(a) + 1e-9) / \ 16 | (np.linalg.norm(b) + 1e-9) 17 | '''relu''' 18 | if res < 0: 19 | res = 0 20 | return res 21 | 22 | 23 | def fltrust(params, central_param, global_parameters, args): 24 | FLTrustTotalScore = 0 25 | score_list = [] 26 | central_param_v = parameters_dict_to_vector_flt(central_param) 27 | central_norm = torch.norm(central_param_v) 28 | cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6).cuda() 29 | sum_parameters = None 30 | for local_parameters in params: 31 | local_parameters_v = parameters_dict_to_vector_flt(local_parameters) 32 | # 计算cos相似度得分和向量长度裁剪值 33 | client_cos = cos(central_param_v, local_parameters_v) 34 | client_cos = max(client_cos.item(), 0) 35 | client_clipped_value = central_norm/torch.norm(local_parameters_v) 36 | score_list.append(client_cos) 37 | FLTrustTotalScore += client_cos 38 | if sum_parameters is None: 39 | sum_parameters = {} 40 | for key, var in local_parameters.items(): 41 | # 乘得分 再乘裁剪值 42 | sum_parameters[key] = client_cos * \ 43 | client_clipped_value * var.clone() 44 | else: 45 | for var in sum_parameters: 46 | sum_parameters[var] = sum_parameters[var] + client_cos * client_clipped_value * local_parameters[ 47 | var] 48 | if FLTrustTotalScore == 0: 49 | # print(score_list) 50 | return global_parameters 51 | for var in global_parameters: 52 | # 除以所以客户端的信任得分总和 53 | 
temp = (sum_parameters[var] / FLTrustTotalScore) 54 | if global_parameters[var].type() != temp.type(): 55 | temp = temp.type(global_parameters[var].type()) 56 | if var.split('.')[-1] == 'num_batches_tracked': 57 | global_parameters[var] = params[0][var] 58 | else: 59 | global_parameters[var] += temp * args.server_lr 60 | # print(score_list) 61 | return global_parameters 62 | 63 | 64 | def parameters_dict_to_vector_flt(net_dict) -> torch.Tensor: 65 | vec = [] 66 | for key, param in net_dict.items(): 67 | # print(key, torch.max(param)) 68 | if key.split('.')[-1] == 'num_batches_tracked' or key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var': 69 | continue 70 | vec.append(param.view(-1)) 71 | return torch.cat(vec) 72 | 73 | def parameters_dict_to_vector_flt_cpu(net_dict) -> torch.Tensor: 74 | vec = [] 75 | for key, param in net_dict.items(): 76 | # print(key, torch.max(param)) 77 | if key.split('.')[-1] == 'num_batches_tracked' or key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var': 78 | continue 79 | vec.append(param.cpu().view(-1)) 80 | return torch.cat(vec) 81 | 82 | 83 | def no_defence_balance(params, global_parameters): 84 | total_num = len(params) 85 | sum_parameters = None 86 | for i in range(total_num): 87 | if sum_parameters is None: 88 | sum_parameters = {} 89 | for key, var in params[i].items(): 90 | sum_parameters[key] = var.clone() 91 | else: 92 | for var in sum_parameters: 93 | sum_parameters[var] = sum_parameters[var] + params[i][var] 94 | for var in global_parameters: 95 | if var.split('.')[-1] == 'num_batches_tracked': 96 | global_parameters[var] = params[0][var] 97 | continue 98 | global_parameters[var] += (sum_parameters[var] / total_num) 99 | 100 | return global_parameters 101 | 102 | 103 | def multi_krum(gradients, n_attackers, args, multi_k=False): 104 | 105 | grads = flatten_grads(gradients) 106 | 107 | candidates = [] 108 | candidate_indices = [] 109 | remaining_updates = torch.from_numpy(grads) 
110 | all_indices = np.arange(len(grads)) 111 | 112 | while len(remaining_updates) > 2 * n_attackers + 2: 113 | torch.cuda.empty_cache() 114 | distances = [] 115 | scores = None 116 | for update in remaining_updates: 117 | distance = [] 118 | for update_ in remaining_updates: 119 | distance.append(torch.norm((update - update_)) ** 2) 120 | distance = torch.Tensor(distance).float() 121 | distances = distance[None, :] if not len( 122 | distances) else torch.cat((distances, distance[None, :]), 0) 123 | 124 | distances = torch.sort(distances, dim=1)[0] 125 | scores = torch.sum( 126 | distances[:, :len(remaining_updates) - 2 - n_attackers], dim=1) 127 | # print(scores) 128 | # args.krum_distance.append(scores) 129 | indices = torch.argsort(scores)[:len( 130 | remaining_updates) - 2 - n_attackers] 131 | 132 | candidate_indices.append(all_indices[indices[0].cpu().numpy()]) 133 | all_indices = np.delete(all_indices, indices[0].cpu().numpy()) 134 | candidates = remaining_updates[indices[0]][None, :] if not len( 135 | candidates) else torch.cat((candidates, remaining_updates[indices[0]][None, :]), 0) 136 | remaining_updates = torch.cat( 137 | (remaining_updates[:indices[0]], remaining_updates[indices[0] + 1:]), 0) 138 | if not multi_k: 139 | break 140 | 141 | # aggregate = torch.mean(candidates, dim=0) 142 | 143 | # return aggregate, np.array(candidate_indices) 144 | num_clients = max(int(args.frac * args.num_users), 1) 145 | num_malicious_clients = int(args.malicious * num_clients) 146 | num_benign_clients = num_clients - num_malicious_clients 147 | args.turn+=1 148 | for selected_client in candidate_indices: 149 | if selected_client < num_malicious_clients: 150 | args.wrong_mal += 1 151 | 152 | # print(candidate_indices) 153 | 154 | # print('Proportion of malicious are selected:'+str(args.wrong_mal/args.turn)) 155 | 156 | for i in range(len(scores)): 157 | if i < num_malicious_clients: 158 | args.mal_score += scores[i] 159 | else: 160 | args.ben_score += scores[i] 161 | 
162 | return np.array(candidate_indices) 163 | 164 | 165 | 166 | def flatten_grads(gradients): 167 | 168 | param_order = gradients[0].keys() 169 | 170 | flat_epochs = [] 171 | 172 | for n_user in range(len(gradients)): 173 | user_arr = [] 174 | grads = gradients[n_user] 175 | for param in param_order: 176 | try: 177 | user_arr.extend(grads[param].cpu().numpy().flatten().tolist()) 178 | except: 179 | user_arr.extend( 180 | [grads[param].cpu().numpy().flatten().tolist()]) 181 | flat_epochs.append(user_arr) 182 | 183 | flat_epochs = np.array(flat_epochs) 184 | 185 | return flat_epochs 186 | 187 | 188 | 189 | 190 | def get_update(update, model): 191 | '''get the update weight''' 192 | update2 = {} 193 | for key, var in update.items(): 194 | update2[key] = update[key] - model[key] 195 | return update2 196 | 197 | def get_update2(update, model): 198 | '''get the update weight''' 199 | update2 = {} 200 | for key, var in update.items(): 201 | if key.split('.')[-1] == 'num_batches_tracked' or key.split('.')[-1] == 'running_mean' or key.split('.')[-1] == 'running_var': 202 | continue 203 | update2[key] = update[key] - model[key] 204 | return update2 205 | 206 | 207 | def fld_distance(old_update_list, local_update_list, net_glob, attack_number, hvp): 208 | pred_update = [] 209 | distance = [] 210 | for i in range(len(old_update_list)): 211 | pred_update.append((old_update_list[i] + hvp).view(-1)) 212 | 213 | 214 | pred_update = torch.stack(pred_update) 215 | local_update_list = torch.stack(local_update_list) 216 | old_update_list = torch.stack(old_update_list) 217 | 218 | distance = torch.norm((old_update_list - local_update_list), dim=1) 219 | # print('defense line219 distance(old_update_list - local_update_list):',distance) 220 | # auc1 = roc_auc_score(pred_update.numpy(), distance) 221 | # distance = torch.norm((pred_update - local_update_list), dim=1).numpy() 222 | # auc2 = roc_auc_score(pred_update.numpy(), distance) 223 | # print("Detection AUC: %0.4f; Detection AUC: 
%0.4f" % (auc1, auc2)) 224 | 225 | # print('defence line 211 pred_update.shape:', pred_update.shape) 226 | distance = torch.norm((pred_update - local_update_list), dim=1) 227 | # print('defence line 211 distance.shape:', distance.shape) 228 | # distance = nn.functional.norm((pred_update - local_update_list), dim=0).numpy() 229 | distance = distance / torch.sum(distance) 230 | return distance 231 | 232 | def detection(score, nobyz): 233 | estimator = KMeans(n_clusters=2) 234 | estimator.fit(score.reshape(-1, 1)) 235 | label_pred = estimator.labels_ 236 | 237 | if np.mean(score[label_pred==0]) 0: 278 | gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i] 279 | # print('defense line278 gapDiff:', gapDiff) 280 | select_k = 2 # default detect attacks 281 | for i in range(len(gapDiff)): 282 | if gapDiff[i] >= 0: 283 | select_k = i+1 284 | break 285 | if select_k == 1: 286 | print('No attack detected!') 287 | return 0 288 | else: 289 | print('Attack Detected!') 290 | return 1 291 | 292 | def RLR(global_model, agent_updates_list, args): 293 | """ 294 | agent_updates_dict: dict['key']=one_dimension_update 295 | agent_updates_list: list[0] = model.dict 296 | global_model: net 297 | """ 298 | # args.robustLR_threshold = 6 299 | args.server_lr = 1 300 | 301 | grad_list = [] 302 | for i in agent_updates_list: 303 | grad_list.append(parameters_dict_to_vector_rlr(i)) 304 | agent_updates_list = grad_list 305 | 306 | 307 | aggregated_updates = 0 308 | for update in agent_updates_list: 309 | # print(update.shape) # torch.Size([1199882]) 310 | aggregated_updates += update 311 | aggregated_updates /= len(agent_updates_list) 312 | lr_vector = compute_robustLR(agent_updates_list, args) 313 | cur_global_params = parameters_dict_to_vector_rlr(global_model.state_dict()) 314 | new_global_params = (cur_global_params + lr_vector*aggregated_updates).float() 315 | global_w = vector_to_parameters_dict(new_global_params, global_model.state_dict()) 316 | # print(cur_global_params == 
vector_to_parameters_dict(new_global_params, global_model.state_dict())) 317 | return global_w 318 | 319 | def parameters_dict_to_vector_rlr(net_dict) -> torch.Tensor: 320 | r"""Convert parameters to one vector 321 | 322 | Args: 323 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 324 | parameters of a model. 325 | 326 | Returns: 327 | The parameters represented by a single vector 328 | """ 329 | vec = [] 330 | for key, param in net_dict.items(): 331 | vec.append(param.view(-1)) 332 | return torch.cat(vec) 333 | 334 | def parameters_dict_to_vector(net_dict) -> torch.Tensor: 335 | r"""Convert parameters to one vector 336 | 337 | Args: 338 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 339 | parameters of a model. 340 | 341 | Returns: 342 | The parameters represented by a single vector 343 | """ 344 | vec = [] 345 | for key, param in net_dict.items(): 346 | if key.split('.')[-1] != 'weight' and key.split('.')[-1] != 'bias': 347 | continue 348 | vec.append(param.view(-1)) 349 | return torch.cat(vec) 350 | 351 | 352 | 353 | def vector_to_parameters_dict(vec: torch.Tensor, net_dict) -> None: 354 | r"""Convert one vector to the parameters 355 | 356 | Args: 357 | vec (Tensor): a single vector represents the parameters of a model. 358 | parameters (Iterable[Tensor]): an iterator of Tensors that are the 359 | parameters of a model. 
360 | """ 361 | 362 | pointer = 0 363 | for param in net_dict.values(): 364 | # The length of the parameter 365 | num_param = param.numel() 366 | # Slice the vector, reshape it, and replace the old data of the parameter 367 | param.data = vec[pointer:pointer + num_param].view_as(param).data 368 | 369 | # Increment the pointer 370 | pointer += num_param 371 | return net_dict 372 | 373 | def compute_robustLR(params, args): 374 | agent_updates_sign = [torch.sign(update) for update in params] 375 | sm_of_signs = torch.abs(sum(agent_updates_sign)) 376 | # print(len(agent_updates_sign)) #10 377 | # print(agent_updates_sign[0].shape) #torch.Size([1199882]) 378 | sm_of_signs[sm_of_signs < args.robustLR_threshold] = -args.server_lr 379 | sm_of_signs[sm_of_signs >= args.robustLR_threshold] = args.server_lr 380 | return sm_of_signs.to(args.gpu) 381 | 382 | 383 | 384 | 385 | def flame(local_model, update_params, global_model, args, debug=False): 386 | cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6).cuda() 387 | cos_list=[] 388 | local_model_vector = [] 389 | for param in local_model: 390 | # local_model_vector.append(parameters_dict_to_vector_flt_cpu(param)) 391 | local_model_vector.append(parameters_dict_to_vector_flt(param)) 392 | for i in range(len(local_model_vector)): 393 | cos_i = [] 394 | for j in range(len(local_model_vector)): 395 | cos_ij = 1- cos(local_model_vector[i],local_model_vector[j]) 396 | # cos_i.append(round(cos_ij.item(),4)) 397 | cos_i.append(cos_ij.item()) 398 | cos_list.append(cos_i) 399 | if debug==True: 400 | filename = './' + args.save + '/flame_analysis.txt' 401 | f = open(filename, "a") 402 | for i in cos_list: 403 | f.write(str(i)) 404 | # print(i) 405 | f.write('\n') 406 | f.write('\n') 407 | f.write("--------Round--------") 408 | f.write('\n') 409 | num_clients = max(int(args.frac * args.num_users), 1) 410 | num_malicious_clients = int(args.malicious * num_clients) 411 | num_benign_clients = num_clients - num_malicious_clients 412 | clusterer = 
hdbscan.HDBSCAN(min_cluster_size=num_clients//2 + 1,min_samples=1,allow_single_cluster=True).fit(cos_list) 413 | # print(clusterer.labels_) 414 | benign_client = [] 415 | norm_list = np.array([]) 416 | 417 | max_num_in_cluster=0 418 | max_cluster_index=0 419 | if clusterer.labels_.max() < 0: 420 | for i in range(len(local_model)): 421 | benign_client.append(i) 422 | norm_list = np.append(norm_list,torch.norm(parameters_dict_to_vector(update_params[i]),p=2).item()) 423 | else: 424 | for index_cluster in range(clusterer.labels_.max()+1): 425 | if len(clusterer.labels_[clusterer.labels_==index_cluster]) > max_num_in_cluster: 426 | max_cluster_index = index_cluster 427 | max_num_in_cluster = len(clusterer.labels_[clusterer.labels_==index_cluster]) 428 | for i in range(len(clusterer.labels_)): 429 | if clusterer.labels_[i] == max_cluster_index: 430 | benign_client.append(i) 431 | # norm_list = np.append(norm_list,torch.norm(update_params_vector[i],p=2)) # consider BN 432 | norm_list = np.append(norm_list,torch.norm(parameters_dict_to_vector(update_params[i]),p=2).item()) # no consider BN 433 | # print(benign_client) 434 | 435 | for i in range(len(benign_client)): 436 | if benign_client[i] < num_malicious_clients: 437 | args.wrong_mal+=1 438 | else: 439 | # minus per benign in cluster 440 | args.right_ben += 1 441 | args.turn+=1 442 | # print('proportion of malicious are selected:',args.wrong_mal/(num_malicious_clients*args.turn)) 443 | # print('proportion of benign are selected:',args.right_ben/(num_benign_clients*args.turn)) 444 | 445 | clip_value = np.median(norm_list) 446 | for i in range(len(benign_client)): 447 | gama = clip_value/norm_list[i] 448 | if gama < 1: 449 | for key in update_params[benign_client[i]]: 450 | if key.split('.')[-1] == 'num_batches_tracked': 451 | continue 452 | update_params[benign_client[i]][key] *= gama 453 | global_model = no_defence_balance([update_params[i] for i in benign_client], global_model) 454 | #add noise 455 | for key, var in 
global_model.items(): 456 | if key.split('.')[-1] == 'num_batches_tracked': 457 | continue 458 | temp = copy.deepcopy(var) 459 | temp = temp.normal_(mean=0,std=args.noise*clip_value) 460 | var += temp 461 | return global_model 462 | 463 | 464 | def flame_analysis(local_model, args, debug=False): 465 | cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6).cuda() 466 | cos_list=[] 467 | local_model_vector = [] 468 | for param in local_model: 469 | local_model_vector.append(parameters_dict_to_vector_flt(param)) 470 | for i in range(len(local_model_vector)): 471 | cos_i = [] 472 | for j in range(len(local_model_vector)): 473 | cos_ij = 1- cos(local_model_vector[i],local_model_vector[j]) 474 | # cos_i.append(round(cos_ij.item(),4)) 475 | cos_i.append(cos_ij.item()) 476 | cos_list.append(cos_i) 477 | if debug==True: 478 | filename = './' + args.save + '/flame_analysis.txt' 479 | f = open(filename, "a") 480 | for i in cos_list: 481 | f.write(str(i)) 482 | f.write('/n') 483 | f.write('/n') 484 | f.write("--------Round--------") 485 | f.write('/n') 486 | num_clients = max(int(args.frac * args.num_users), 1) 487 | num_malicious_clients = int(args.malicious * num_clients) 488 | num_benign_clients = num_clients - num_malicious_clients 489 | clusterer = hdbscan.HDBSCAN(min_cluster_size=num_clients//2 + 1,min_samples=1,allow_single_cluster=True).fit(cos_list) 490 | # print(clusterer.labels_) 491 | benign_client = [] 492 | 493 | max_num_in_cluster=0 494 | max_cluster_index=0 495 | if clusterer.labels_.max() < 0: 496 | for i in range(len(local_model)): 497 | benign_client.append(i) 498 | else: 499 | for index_cluster in range(clusterer.labels_.max()+1): 500 | if len(clusterer.labels_[clusterer.labels_==index_cluster]) > max_num_in_cluster: 501 | max_cluster_index = index_cluster 502 | max_num_in_cluster = len(clusterer.labels_[clusterer.labels_==index_cluster]) 503 | for i in range(len(clusterer.labels_)): 504 | if clusterer.labels_[i] == max_cluster_index: 505 | 
benign_client.append(i) 506 | return benign_client 507 | 508 | def lbfgs(args, S_k_list, Y_k_list, v): 509 | curr_S_k = nd.concat(*S_k_list, dim=1) 510 | curr_Y_k = nd.concat(*Y_k_list, dim=1) 511 | S_k_time_Y_k = nd.dot(curr_S_k.T, curr_Y_k) 512 | S_k_time_S_k = nd.dot(curr_S_k.T, curr_S_k) 513 | R_k = np.triu(S_k_time_Y_k.asnumpy()) 514 | L_k = S_k_time_Y_k - nd.array(R_k, ctx=mx.gpu(args.gpu)) 515 | sigma_k = nd.dot(Y_k_list[-1].T, S_k_list[-1]) / (nd.dot(S_k_list[-1].T, S_k_list[-1])) 516 | D_k_diag = nd.diag(S_k_time_Y_k) 517 | upper_mat = nd.concat(*[sigma_k * S_k_time_S_k, L_k], dim=1) 518 | lower_mat = nd.concat(*[L_k.T, -nd.diag(D_k_diag)], dim=1) 519 | mat = nd.concat(*[upper_mat, lower_mat], dim=0) 520 | mat_inv = nd.linalg.inverse(mat) 521 | 522 | approx_prod = sigma_k * v 523 | p_mat = nd.concat(*[nd.dot(curr_S_k.T, sigma_k * v), nd.dot(curr_Y_k.T, v)], dim=0) 524 | approx_prod -= nd.dot(nd.dot(nd.concat(*[sigma_k * curr_S_k, curr_Y_k], dim=1), mat_inv), p_mat) 525 | 526 | return approx_prod 527 | 528 | # def lbfgs_torch(args, S_k_list, Y_k_list, v): 529 | # # curr_S_k = nd.concat(*S_k_list, dim=1) 530 | # # curr_Y_k = nd.concat(*Y_k_list, dim=1) 531 | # curr_S_k = S_k_list 532 | # curr_Y_k = Y_k_list 533 | # S_k_time_Y_k = torch.dot(curr_S_k.T, curr_Y_k) 534 | # S_k_time_S_k = torch.dot(curr_S_k.T, curr_S_k) 535 | # R_k = np.triu(S_k_time_Y_k.numpy()) 536 | # L_k = S_k_time_Y_k - torch.array(R_k).to(args.gpu) 537 | # sigma_k = torch.dot(Y_k_list[-1].T, S_k_list[-1]) / (nd.dot(S_k_list[-1].T, S_k_list[-1])) 538 | # D_k_diag = torch.diag(S_k_time_Y_k) 539 | # upper_mat = torch.concat(*[sigma_k * S_k_time_S_k, L_k], dim=1) 540 | # lower_mat = torch.concat(*[L_k.T, -nd.diag(D_k_diag)], dim=1) 541 | # mat = torch.concat(*[upper_mat, lower_mat], dim=0) 542 | # mat_inv = torch.linalg.inv(mat) 543 | 544 | # approx_prod = sigma_k * v 545 | # p_mat = torch.concat(*[nd.dot(curr_S_k.T, sigma_k * v), nd.dot(curr_Y_k.T, v)], dim=0) 546 | # approx_prod -= 
def lbfgs_torch(args, S_k_list, Y_k_list, v):
    """Approximate a Hessian-vector product with compact L-BFGS (PyTorch).

    Evaluates ``B v`` for
    ``B = sigma*I - [sigma*S, Y] @ inv(M) @ [sigma*S^T; Y^T]`` with
    ``M = [[sigma * S^T S, L], [L^T, -D]]``, where the columns of ``S`` / ``Y``
    are the stored vectors, ``L`` is the strictly lower-triangular part of
    ``S^T Y``, ``D`` its diagonal and ``sigma`` is computed from the most
    recent pair.  All work is done on the CPU.

    Args:
        args: unused; kept so the signature matches ``lbfgs``.
        S_k_list: non-empty list of tensors with ``d`` elements each
            (``(d,)`` or ``(d, 1)``); the last entry is the most recent.
        Y_k_list: list of tensors, same length and layout as ``S_k_list``.
        v: tensor with ``d`` elements the approximate Hessian is applied to.

    Returns:
        torch.Tensor: CPU tensor of shape ``(1, d)`` holding ``B v``.
    """
    # Flatten each stored vector so both (d,) and (d, 1) inputs are accepted,
    # then place them as columns: S and Y are (d, k).
    S = torch.stack([s.reshape(-1) for s in S_k_list], dim=1).cpu()
    Y = torch.stack([y.reshape(-1) for y in Y_k_list], dim=1).cpu()
    v = v.reshape(-1, 1).cpu()

    StY = S.t() @ Y
    StS = S.t() @ S
    # torch.triu keeps everything in torch (no NumPy round trip, which would
    # also reject tensors that require grad).
    L = StY - torch.triu(StY)
    D = torch.diag(StY.diagonal())

    # Scaling of the initial Hessian guess from the most recent (s, y) pair.
    s_last = S[:, -1]
    y_last = Y[:, -1]
    sigma = torch.dot(y_last, s_last) / torch.dot(s_last, s_last)

    upper_mat = torch.cat([sigma * StS, L], dim=1)
    lower_mat = torch.cat([L.t(), -D], dim=1)
    mat = torch.cat([upper_mat, lower_mat], dim=0)  # (2k, 2k)

    p_mat = torch.cat([S.t() @ (sigma * v), Y.t() @ v], dim=0)  # (2k, 1)
    # Solve the small (2k x 2k) system instead of forming an explicit inverse
    # and multiplying it into the tall (d x 2k) matrix first.
    coeffs = torch.linalg.solve(mat, p_mat)

    approx_prod = sigma * v - torch.cat([sigma * S, Y], dim=1) @ coeffs
    return approx_prod.t()
--------------------------------------------------------------------------------