├── .gitignore ├── LICENSE.txt ├── README.md ├── assets ├── drupe_reference │ ├── CLIP │ │ ├── one.npz │ │ ├── priority.npz │ │ ├── stop.npz │ │ ├── truck.npz │ │ └── zero.npz │ ├── cifar10 │ │ ├── gtsrb_l0.npz │ │ ├── gtsrb_l40.npz │ │ ├── gtsrb_l5.npz │ │ ├── one.npz │ │ ├── priority.npz │ │ ├── priority_wzt_3.npz │ │ └── truck.npz │ ├── cifar10_l0.npz │ ├── cifar10_l0_224.npz │ ├── cifar10_l0_n10.npz │ ├── cifar10_l0_n3.npz │ ├── cifar10_l0_n3_224.npz │ ├── cifar10_l0_n5.npz │ ├── cifar10_l4.npz │ ├── gtsrb_l0.npz │ ├── gtsrb_l0_224.npz │ ├── gtsrb_l12.npz │ ├── gtsrb_l12_224.npz │ ├── gtsrb_l12_n10.npz │ ├── gtsrb_l12_n2.npz │ ├── gtsrb_l12_n2_rseed1.npz │ ├── gtsrb_l12_n2_rseed2.npz │ ├── gtsrb_l12_n2_rseed3.npz │ ├── gtsrb_l12_n2_rseed4.npz │ ├── gtsrb_l12_n3.npz │ ├── gtsrb_l12_n3_224.npz │ ├── gtsrb_l12_n5.npz │ ├── gtsrb_l24_n3_224.npz │ ├── gtsrb_l24_n3_224_rseed4.npz │ ├── imagenet │ │ ├── one.npz │ │ ├── priority.npz │ │ └── truck.npz │ ├── stl10 │ │ ├── airplane.npz │ │ ├── one.npz │ │ └── priority.npz │ ├── stl10_l0.npz │ ├── stl10_l0_224.npz │ ├── stl10_l9.npz │ ├── stl10_l9_224.npz │ ├── stl10_l9_n10.npz │ ├── stl10_l9_n3.npz │ ├── stl10_l9_n3_224.npz │ ├── stl10_l9_n5.npz │ ├── svhn_l0.npz │ ├── svhn_l0_224.npz │ ├── svhn_l1.npz │ ├── svhn_l1_224.npz │ ├── svhn_l1_n10.npz │ ├── svhn_l1_n3.npz │ ├── svhn_l1_n3_224.npz │ └── svhn_l1_n5.npz ├── image.png └── triggers │ ├── 32471.png │ ├── NC_clean.png │ ├── apple_white.png │ ├── drupe_trigger │ ├── trigger_pt_white_173_50_ap_replace.npz │ ├── trigger_pt_white_185_24.npz │ ├── trigger_pt_white_21_10_ap_replace.npz │ └── trigger_pt_white_42_20_ap_replace.npz │ ├── firefox_32.png │ ├── hellokitty_32.png │ ├── phoenix_corner_96x96_256.png │ ├── sig_trigger_128.png │ ├── sig_trigger_224.png │ ├── sig_trigger_32.png │ ├── sig_trigger_64.png │ ├── trigger_10.png │ ├── trigger_11.png │ ├── trigger_12.png │ ├── trigger_13.png │ ├── trigger_14.png │ ├── trigger_15.png │ ├── trigger_16.png │ ├── trigger_17.png │ ├── trigger_18.png │ ├── trigger_19.png │ ├── trojan_watermark_224.png │ ├── watermark_white_32.png │ └── white_10x10.png ├── configs ├── attacks │ ├── badencoder.py │ ├── badencoder_test.yaml │ ├── badencoder_train.yaml │ ├── drupe.py │ ├── drupe_test.yaml │ └── drupe_train.yaml ├── poisoning │ └── poisoning_based │ │ ├── sslbkd.yaml │ │ ├── sslbkd_cifar10.yaml │ │ ├── sslbkd_cifar10_test.yaml │ │ ├── sslbkd_shadow_copy.yaml │ │ └── sslbkd_test.yaml └── ssl │ ├── byol.py │ ├── moco.py │ ├── simclr.py │ └── simsiam.py ├── docker └── Dockerfile ├── docs └── zh_cn │ ├── patchsearch.log │ └── patchsearch.md ├── requirements.txt ├── ssl_backdoor ├── attacks │ ├── __init__.py │ └── badencoder │ │ ├── __init__.py │ │ ├── badencoder.py │ │ └── datasets.py ├── datasets │ ├── __init__.py │ ├── agent.py │ ├── base.py │ ├── corruptencoder_utils.py │ ├── dataset.py │ ├── generators.py │ ├── metadata │ │ ├── cifar10_classes.txt │ │ ├── cifar10_metadata.txt │ │ ├── class_index.txt │ │ ├── imagenet100_classes.txt │ │ ├── imagenet_metadata.txt │ │ ├── stl10_classes.txt │ │ └── stl10_metadata.txt │ ├── utils.py │ └── var.py ├── defenses │ ├── decree │ │ └── trigger │ │ │ ├── trigger_pt_white_185_24.npz │ │ │ └── trigger_pt_white_21_10_ap_replace.npz │ ├── dede │ │ ├── decoder_model.py │ │ ├── dede.py │ │ └── reconstruction.py │ └── patchsearch │ │ ├── __init__.py │ │ ├── core.py │ │ ├── poison_classifier.py │ │ └── utils │ │ ├── __init__.py │ │ ├── clustering.py │ │ ├── dataset.py │ │ ├── evaluation.py │ │ ├── gradcam.py │ │ ├── model_utils.py │ 
│ ├── patch_operations.py │ │ └── visualization.py ├── ssl_trainers │ ├── __init__.py │ ├── byol │ │ ├── builder.py │ │ ├── cfg.py │ │ ├── cifar10_classes.txt │ │ ├── cifar10_metadata.txt │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cifar10.py │ │ │ ├── cifar100.py │ │ │ ├── imagenet.py │ │ │ ├── stl10.py │ │ │ ├── tiny_in.py │ │ │ └── transforms.py │ │ ├── eval │ │ │ ├── get_data.py │ │ │ ├── knn.py │ │ │ ├── lbfgs.py │ │ │ └── sgd.py │ │ ├── imagenet100_classes.txt │ │ ├── imagenet_metadata.txt │ │ ├── methods │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── byol.py │ │ │ ├── contrastive.py │ │ │ ├── norm_mse.py │ │ │ ├── w_mse.py │ │ │ └── whitening.py │ │ ├── moco │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── dataset3.py │ │ │ ├── loader.py │ │ │ └── poisonencoder_utils.py │ │ ├── model.py │ │ ├── test.py │ │ └── train.py │ ├── moco │ │ ├── builder.py │ │ ├── loader.py │ │ └── test.py │ ├── models_vit.py │ ├── simclr │ │ └── builder.py │ ├── simsiam │ │ └── builder.py │ ├── trainer.py │ └── utils.py └── utils │ ├── __init__.py │ ├── model_utils.py │ └── utils.py └── tools ├── ddp_training.py ├── eval_knn.py ├── eval_utils.py ├── ft_linear.py ├── process_dataset.py ├── run_badencoder.py ├── run_badencoder.sh ├── run_dede.py ├── run_dede.sh ├── run_moco_training.py ├── run_patchsearch.py ├── run_patchsearch.sh ├── test.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore git records in subfolders 2 | **/.git/ 3 | 4 | # Ignore all torch files 5 | *.pt 6 | *.pth 7 | *.pth.tar 8 | 9 | # Ignore all datasets 10 | data/ 11 | 12 | # Ignore all training results 13 | results/ 14 | 15 | # Ignore repos and submodules 16 | .trash/ 17 | CompRess/ 18 | defense/ 19 | mine-pytorch/ 20 | simsiam_old/ 21 | SSL-Cleanse/ 22 | scripts_old/ 23 | trash/ 24 | ssl_backdoor/projects/ 25 | ssl_backdoor/defenses/projects/ 26 | 27 | # Ignore private things 28 | configs/poisoning/ablation/** 29 | configs/poisoning/poisoning_based_copy/** 30 | configs/poisoning/optimization_based/** 31 | projects/ 32 | scripts_old/ 33 | projects/ 34 | 35 | ssl_backdoor/datasets/hidden.py 36 | ssl_backdoor/datasets/dataset_old.py 37 | ssl_backdoor/attacks/drupe/ 38 | ssl_backdoor/defenses/hidden/ 39 | ssl_backdoor/defenses/patchsearch/PatchSearch/ 40 | ssl_backdoor/defenses/dede/DeDe/ 41 | 42 | # trash 43 | .trash/ 44 | __pycache__/ 45 | **/__pycache__/ 46 | .hypothesis/ 47 | ssl_pretrain.py 48 | methods.py 49 | test.py 50 | assets/references/drupe_gtsrb_l12_n3 51 | assets/references/imagenet 52 | ssl_backdoor/attacks/badencoder/badencoder_vanilla.py 53 | ssl_backdoor/ssl_trainers/moco/moco_badencoder 54 | ssl_backdoor/ssl_trainers/moco/simsiam_badencoder 55 | ssl_backdoor/ssl_trainers/moco/main_moco\ copy.py 56 | ssl_backdoor/ssl_trainers/moco/main_moco_badencoder2.py -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tuo Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
[SSL-Backdoor logo]

[Badges: License: MIT · GitHub Stars · Python Version]
# SSL-Backdoor

SSL-Backdoor is an academic research library dedicated to exploring **poisoning attacks on self-supervised learning (SSL)**. Our goal is to provide a comprehensive, unified platform for researchers to implement, evaluate, and compare attack and defense mechanisms in the context of SSL.

This library originated as a rewrite of the SSLBKD implementation, ensuring **consistent training protocols** and **hyperparameter fidelity** for fair comparisons. We have since expanded its capabilities significantly.

**Key Features:**

1. **Unified Poisoning & Training Framework:** a streamlined pipeline for applying diverse poisoning strategies and training SSL models.
2. **Decoupled Design:** we strive to keep the design decoupled, so each method can be modified independently, while unifying the implementation of essential tools where necessary.

*Future plans include support for multimodal contrastive learning models.*

## 📢 What's New?

✅ **2025-05-19 Update:**

* **DEDE defense is now implemented and available!**

✅ **2025-04-18 Update:**

* **PatchSearch defense is now implemented and available!**
* **BadEncoder attack is now implemented and available!**

🔄 **Active Refactoring Underway!** We are currently refactoring the codebase to improve code quality, maintainability, and ease of use. Expect ongoing improvements!

✅ **Current Support:**

* **Attack Algorithms:** SSLBKD, CTRL, CorruptEncoder, BLTO (inference only), BadEncoder
* **SSL Methods:** MoCo, SimCLR, SimSiam, BYOL

🛡️ **Current Defenses:**

* **PatchSearch**, **DEDE**

Stay tuned for more updates!

## Supported Attacks

This library currently supports the following poisoning attack algorithms against SSL models:

| Alias | Paper | Conference | Config |
|------------------|-------|------------|--------|
| SSLBKD | [Backdoor attacks on self-supervised learning](https://doi.org/10.1109/CVPR52688.2022.01298) | CVPR 2022 | [config](configs/poisoning/poisoning_based/sslbkd.yaml) |
| CTRL | [An Embarrassingly Simple Backdoor Attack on Self-supervised Learning](https://openaccess.thecvf.com/content/ICCV2023/html/Li_An_Embarrassingly_Simple_Backdoor_Attack_on_Self-supervised_Learning_ICCV_2023_paper.html) | ICCV 2023 | |
| CorruptEncoder | [Data poisoning based backdoor attacks to contrastive learning](https://openaccess.thecvf.com/content/CVPR2024/html/Zhang_Data_Poisoning_based_Backdoor_Attacks_to_Contrastive_Learning_CVPR_2024_paper.html) | CVPR 2024 | |
| BLTO (inference) | [Backdoor Contrastive Learning via Bi-level Trigger Optimization](https://openreview.net/forum?id=oxjeePpgSP) | ICLR 2024 | |
| BadEncoder | [BadEncoder: Backdoor Attacks to Pre-trained Encoders in Self-Supervised Learning](https://ieeexplore.ieee.org/abstract/document/9833644/) | S&P 2022 | [config](configs/attacks/badencoder.py) |

## Supported Defenses

We are actively developing and integrating defense mechanisms.
Currently, the following defenses are implemented:

| Alias | Paper | Conference | Config |
|------------------|-------|------------|--------|
| PatchSearch | [Defending Against Patch-Based Backdoor Attacks on Self-Supervised Learning](https://openaccess.thecvf.com/content/CVPR2023/html/Tejankar_Defending_Against_Patch-Based_Backdoor_Attacks_on_Self-Supervised_Learning_CVPR_2023_paper.html) | CVPR 2023 | [doc](./docs/zh_cn/patchsearch.md), [config](configs/defense/patchsearch.py) |
| DEDE | [DeDe: Detecting Backdoor Samples for SSL Encoders via Decoders](http://arxiv.org/abs/2411.16154) | CVPR 2025 | [config](configs/defense/dede.py) |

## Setup

Get started with SSL-Backdoor quickly:

1. **Clone the repository:**
   ```bash
   git clone https://github.com/jsrdcht/SSL-Backdoor.git
   cd SSL-Backdoor
   ```

2. **Install dependencies (optional but recommended):**
   ```bash
   pip install -r requirements.txt
   ```
   *Consider using a virtual environment (`conda`, `venv`, etc.) to manage dependencies.*

## Usage

### Training an SSL Model on a Poisoned Dataset

To train an SSL model (e.g., MoCo v2) under a chosen poisoning attack, use the provided scripts. Example for Distributed Data Parallel (DDP) training:

```bash
# Configure your desired attack, SSL method, dataset, etc. in the relevant config files
# (e.g., configs/ssl/moco.py, configs/poisoning/...)

bash tools/train.sh
```

*Please refer to the `configs` directory and the specific training scripts for detailed usage and parameter options.*

## Performance Benchmarks (Legacy)

*(Note: These results are based on the original implementation, before the current refactoring.)*

| Algorithm | Method | Clean Acc ↑ | Backdoor Acc ↓ | ASR ↑ |
|-----------------|---------|-------------|----------------|--------|
| SSLBKD | BYOL | 66.38% | 23.82% | 70.2% |
| SSLBKD | SimCLR | 70.9% | 49.1% | 33.9% |
| SSLBKD | MoCo | 66.28% | 33.24% | 57.6% |
| SSLBKD | SimSiam | 64.48% | 29.3% | 62.2% |
| CorruptEncoder | BYOL | 65.48% | 25.3% | 9.66% |
| CorruptEncoder | SimCLR | 70.14% | 45.38% | 36.9% |
| CorruptEncoder | MoCo | 67.04% | 38.64% | 37.3% |
| CorruptEncoder | SimSiam | 57.54% | 14.14% | 79.48% |

| Algorithm | Method | Clean Acc ↑ | Backdoor Acc ↓ | ASR ↑ |
|-----------------|---------|-------------|----------------|--------|
| CTRL | BYOL | 75.02% | 30.87% | 66.95% |
| CTRL | SimCLR | 70.32% | 20.82% | 81.97% |
| CTRL | MoCo | 71.01% | 54.5% | 34.34% |
| CTRL | SimSiam | 71.04% | 50.36% | 41.43% |

* Results were computed with the 10%-available-data evaluation protocol from the SSLBKD paper, targeting the lorikeet class of ImageNet-100 (first table) and the airplane class of CIFAR-10 (second table).
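### Running a Defense

Defenses are launched from the `tools` directory as well. The command below mirrors the PatchSearch invocation documented in [docs/zh_cn/patchsearch.md](./docs/zh_cn/patchsearch.md); see that document for what the two config files must contain:

```bash
# Stage-one detection (and, unless --skip_filter is given, the stage-two filter)
python tools/run_patchsearch.py \
    --config <base_config> \
    --attack_config <attack_config>
```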
-------------------------------------------------------------------------------- /assets/drupe_reference/CLIP/one.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/one.npz -------------------------------------------------------------------------------- /assets/drupe_reference/CLIP/priority.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/priority.npz -------------------------------------------------------------------------------- /assets/drupe_reference/CLIP/stop.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/stop.npz -------------------------------------------------------------------------------- /assets/drupe_reference/CLIP/truck.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/truck.npz -------------------------------------------------------------------------------- /assets/drupe_reference/CLIP/zero.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/zero.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/gtsrb_l0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l0.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/gtsrb_l40.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l40.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/gtsrb_l5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l5.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/one.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/one.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/priority.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/priority.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/priority_wzt_3.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/priority_wzt_3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10/truck.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/truck.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0_n10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n10.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0_n3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0_n3_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n3_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l0_n5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n5.npz -------------------------------------------------------------------------------- /assets/drupe_reference/cifar10_l4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l4.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l0.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l0_224.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l0_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n10.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n2_rseed1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed1.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n2_rseed2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed2.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n2_rseed3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n2_rseed4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed4.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n3_224.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n3_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l12_n5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n5.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l24_n3_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l24_n3_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/gtsrb_l24_n3_224_rseed4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l24_n3_224_rseed4.npz -------------------------------------------------------------------------------- /assets/drupe_reference/imagenet/one.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/one.npz -------------------------------------------------------------------------------- /assets/drupe_reference/imagenet/priority.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/priority.npz -------------------------------------------------------------------------------- /assets/drupe_reference/imagenet/truck.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/truck.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10/airplane.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/airplane.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10/one.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/one.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10/priority.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/priority.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l0.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l0.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l0_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l0_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9_n10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n10.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9_n3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9_n3_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n3_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/stl10_l9_n5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n5.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l0.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l0_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l0_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1.npz 
-------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1_n10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n10.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1_n3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n3.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1_n3_224.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n3_224.npz -------------------------------------------------------------------------------- /assets/drupe_reference/svhn_l1_n5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n5.npz -------------------------------------------------------------------------------- /assets/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/image.png -------------------------------------------------------------------------------- /assets/triggers/32471.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/32471.png -------------------------------------------------------------------------------- /assets/triggers/NC_clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/NC_clean.png -------------------------------------------------------------------------------- /assets/triggers/apple_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/apple_white.png -------------------------------------------------------------------------------- /assets/triggers/drupe_trigger/trigger_pt_white_173_50_ap_replace.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_173_50_ap_replace.npz -------------------------------------------------------------------------------- /assets/triggers/drupe_trigger/trigger_pt_white_185_24.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_185_24.npz -------------------------------------------------------------------------------- /assets/triggers/drupe_trigger/trigger_pt_white_21_10_ap_replace.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_21_10_ap_replace.npz -------------------------------------------------------------------------------- /assets/triggers/drupe_trigger/trigger_pt_white_42_20_ap_replace.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_42_20_ap_replace.npz -------------------------------------------------------------------------------- /assets/triggers/firefox_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/firefox_32.png -------------------------------------------------------------------------------- /assets/triggers/hellokitty_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/hellokitty_32.png -------------------------------------------------------------------------------- /assets/triggers/phoenix_corner_96x96_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/phoenix_corner_96x96_256.png -------------------------------------------------------------------------------- /assets/triggers/sig_trigger_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_128.png -------------------------------------------------------------------------------- /assets/triggers/sig_trigger_224.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_224.png -------------------------------------------------------------------------------- /assets/triggers/sig_trigger_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_32.png -------------------------------------------------------------------------------- /assets/triggers/sig_trigger_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_64.png -------------------------------------------------------------------------------- /assets/triggers/trigger_10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_10.png -------------------------------------------------------------------------------- /assets/triggers/trigger_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_11.png -------------------------------------------------------------------------------- /assets/triggers/trigger_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_12.png -------------------------------------------------------------------------------- /assets/triggers/trigger_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_13.png -------------------------------------------------------------------------------- /assets/triggers/trigger_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_14.png -------------------------------------------------------------------------------- /assets/triggers/trigger_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_15.png -------------------------------------------------------------------------------- /assets/triggers/trigger_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_16.png -------------------------------------------------------------------------------- /assets/triggers/trigger_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_17.png -------------------------------------------------------------------------------- /assets/triggers/trigger_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_18.png -------------------------------------------------------------------------------- /assets/triggers/trigger_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_19.png -------------------------------------------------------------------------------- /assets/triggers/trojan_watermark_224.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trojan_watermark_224.png -------------------------------------------------------------------------------- /assets/triggers/watermark_white_32.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/watermark_white_32.png
--------------------------------------------------------------------------------
/assets/triggers/white_10x10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/white_10x10.png
--------------------------------------------------------------------------------
/configs/attacks/badencoder.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
Configuration file for the BadEncoder attack.

BadEncoder: a backdoor attack against self-supervised learning encoders.
"""

# Basic configuration
config = {
    'experiment_id': 'imagenet100_trigger-size-50',  # experiment ID
    # Model parameters
    'arch': 'resnet18',  # encoder architecture
    'pretrained_encoder': '/workspace/SSL-Backdoor/ssl_backdoor/attacks/drupe/DRUPE/clean_encoder/model_1000.pth',  # path to the pre-trained encoder
    'encoder_usage_info': 'imagenet',  # encoder usage info; determines which model is loaded
    'batch_size': 64,  # batch size (default: 256)
    'num_workers': 4,  # number of data-loading workers

    # Data parameters
    'image_size': 224,  # image size used for resizing
    # trigger image configuration file
    'trigger_file': 'assets/triggers/trigger_14.png',
    'trigger_size': 50,

    # Shadow-data parameters
    'shadow_dataset': 'imagenet100',
    'shadow_file': 'data/ImageNet-100/10percent_trainset.txt',
    'shadow_fraction': 0.2,  # default: 0.2
    # Reference-data parameters
    # 'reference_file': 'assets/references/drupe_gtsrb_l12_n3/references.txt',  # reference data configuration file
    'reference_file': '/workspace/SSL-Backdoor/assets/references/imagenet/references.txt',

    'n_ref': 3,  # number of reference inputs
    'downstream_dataset': 'imagenet100',


    # Training parameters
    'lr': 0.0001,  # learning rate (default: 0.05)
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'lambda1': 1.0,  # loss weight 1
    'lambda2': 1.0,  # loss weight 2
    'epochs': 120,  # number of training epochs
    # 'lr_milestones': [60, 90],  # epochs at which to decay the learning rate
    # 'lr_gamma': 0.1,  # learning-rate decay factor
    'warm_up_epochs': 2,  # warm-up epochs
    'print_freq': 40,  # print frequency
    'save_freq': 5,  # checkpoint frequency
    'eval_freq': 5,  # evaluation frequency

    # Downstream evaluation parameters
    'nn_epochs': 100,  # training epochs for the downstream classifier
    'hidden_size_1': 512,  # downstream classifier hidden layer 1 size
    'hidden_size_2': 256,  # downstream classifier hidden layer 2 size
    'batch_size_downstream': 64,  # downstream classifier batch size
    'lr_downstream': 0.0001,  # downstream classifier learning rate

    # System parameters
    'seed': 42,  # random seed
    'output_dir': '/workspace/SSL-Backdoor/results/test/badencoder',  # output directory

}
--------------------------------------------------------------------------------
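(Editor's note: this config is consumed by `tools/run_badencoder.py` / `tools/run_badencoder.sh`, which ship in the repository tree above. The invocation below is a guess by analogy with `tools/run_patchsearch.py`; the actual flag names should be verified against the script itself.)

```bash
# Hypothetical launch of the BadEncoder attack -- check the real CLI in tools/run_badencoder.py
python tools/run_badencoder.py --config configs/attacks/badencoder.py
```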
/configs/attacks/badencoder_test.yaml:
--------------------------------------------------------------------------------
dataset: gtsrb
attack_algorithm: badencoder
return_attack_target: true  # must be set to true when running badencoder
attack_target: 12
trigger_path: assets/triggers/white_10x10.png
trigger_size: 10
train_file: /workspace/dataset/gtsrb1/trainset.txt
test_file: /workspace/dataset/gtsrb1/testset.txt
trigger_insert: patch
position: badencoder
--------------------------------------------------------------------------------
/configs/attacks/badencoder_train.yaml:
--------------------------------------------------------------------------------
image_size: 32
shadow_dataset: cifar10
shadow_file: data/CIFAR10/trainset.txt
memory_file: data/CIFAR10/trainset.txt
reference_file: assets/references/drupe_gtsrb_l12_n3/references.txt
trigger_file: assets/triggers/white_10x10.png

shadow_fraction: 0.2
trigger_size: 10

# Parameters that rarely change
n_ref: 3
--------------------------------------------------------------------------------
/configs/attacks/drupe.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
Configuration file for the DRUPE attack.

DRUPE: a backdoor attack built on distribution alignment and similarity regularization.
"""

# Basic configuration
config = {
    # Model parameters
    'arch': 'resnet18',  # encoder architecture
    'pretrained_encoder': '/workspace/SSL-Backdoor/ssl_backdoor/attacks/drupe/DRUPE/clean_encoder/model_1000.pth',  # path to the pre-trained encoder
    'encoder_usage_info': 'cifar10',  # encoder usage info; determines which model is loaded
    'batch_size': 256,  # batch size
    'num_workers': 4,  # number of data-loading workers

    # Data parameters
    'image_size': 32,  # image size used for resizing
    # trigger image configuration file
    'trigger_file': 'assets/triggers/white_10x10.png',
    'trigger_size': 10,

    # Shadow-data parameters
    'shadow_dataset': 'cifar10',
    'shadow_file': 'data/CIFAR10/trainset.txt',
    'shadow_fraction': 0.2,
    # Reference-data parameters
    'reference_file': 'assets/references/drupe_gtsrb_l12_n3/references.txt',  # reference data configuration file
    'reference_label': 12,  # reference label (the target class)

    'n_ref': 3,  # number of reference inputs
    # Test-data parameters
    'downstream_dataset': 'gtsrb',

    # DRUPE-specific parameters
    'mode': 'drupe',  # attack mode: 'drupe', 'badencoder', 'wb'
    'fix_epoch': 20,  # number of epochs during which parameters stay fixed

    # Training parameters
    'lr': 0.05,  # learning rate
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'lambda1': 1.0,  # loss weight 1
    'lambda2': 1.0,  # loss weight 2
    'epochs': 120,  # number of training epochs
    'warm_up_epochs': 2,  # warm-up epochs
    'print_freq': 10,  # print frequency
    'save_freq': 20,  # checkpoint frequency

    # Downstream evaluation parameters
    'eval_downstream': True,  # whether to evaluate on the downstream task
    'nn_epochs': 120,  # training epochs for the downstream classifier
    'hidden_size_1': 512,  # downstream classifier hidden layer 1 size
    'hidden_size_2': 256,  # downstream classifier hidden layer 2 size
    'batch_size_downstream': 64,  # downstream classifier batch size
    'lr_downstream': 0.001,  # downstream classifier learning rate

    # System parameters
    'seed': 42,  # random seed
    'output_dir': '/workspace/SSL-Backdoor/results/test/drupe',  # output directory
    'experiment_id': 'test_cifar102gtsrb_trigger-size-10',  # experiment ID
}
--------------------------------------------------------------------------------
/configs/attacks/drupe_test.yaml:
--------------------------------------------------------------------------------
# DRUPE test configuration
dataset: gtsrb
attack_algorithm: badencoder
return_attack_target: true  # must be set to true when running drupe
attack_target: 12
trigger_path: assets/triggers/white_10x10.png
trigger_size: 10
train_file: /workspace/dataset/gtsrb1/trainset.txt
test_file: /workspace/dataset/gtsrb1/testset.txt
trigger_insert: patch
position: badencoder
--------------------------------------------------------------------------------
/configs/attacks/drupe_train.yaml:
--------------------------------------------------------------------------------
# DRUPE training configuration
image_size: 32
shadow_dataset: cifar10
shadow_file: data/CIFAR10/trainset.txt
memory_file: data/CIFAR10/trainset.txt
test_file: data/CIFAR10/testset.txt

# Reference-input configuration
reference_file: assets/references/drupe_gtsrb_l12_n3/references.txt
trigger_file: assets/triggers/white_10x10.png

# Attack parameters
shadow_fraction: 0.2
trigger_size: 10
mode: drupe  # attack mode: drupe, badencoder, wb
n_ref: 3  # number of reference inputs

# Model parameters
encoder_usage_info: cifar10
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd.yaml:
--------------------------------------------------------------------------------
attack_algorithm: sslbkd
data: data/ImageNet-100/trainset.txt
dataset: imagenet100
save_poisons: True
# poisons_saved_path: /workspace/SSL-Backdoor/results/test/simsiam_imagenet-100_test/poisons

keep_poison_class: False
attack_target_list:
  - 0
trigger_path_list:
  - assets/triggers/trigger_14.png
reference_dataset_file_list:
  - data/ImageNet-100/trainset.txt
num_reference_list:
  - 650
num_poison_list:
  - 650

test_train_file: data/ImageNet-100/10percent_trainset.txt
test_val_file: data/ImageNet-100/valset.txt
test_attack_target: 0
test_trigger_path: assets/triggers/trigger_14.png
test_dataset: imagenet100
test_trigger_size: 50
test_trigger_insert: patch
test_attack_algorithm: sslbkd

return_attack_target: False
attack_target_word: n01558993
trigger_insert: patch
trigger_size: 50
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_cifar10.yaml:
--------------------------------------------------------------------------------
attack_algorithm: sslbkd
data: data/CIFAR10/trainset.txt
dataset: cifar10
save_poisons: True

keep_poison_class: False
attack_target_list:
  - 0
trigger_path_list:
  - assets/triggers/trigger_14.png
reference_dataset_file_list:
  - data/CIFAR10/trainset.txt
num_reference_list:
  - 2500
num_poison_list:
  - 2500

attack_target_word: n01558993
trigger_insert: patch
trigger_size: 8
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_cifar10_test.yaml:
--------------------------------------------------------------------------------
# testset parameters
train_file: data/CIFAR10/10percent_trainset.txt
test_file: data/CIFAR10/testset.txt
attack_target: 0
trigger_path: assets/triggers/trigger_14.png
dataset: cifar10
trigger_size: 8
trigger_insert: patch
attack_algorithm: sslbkd
return_attack_target: False
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_shadow_copy.yaml:
--------------------------------------------------------------------------------
image_size: 224
shadow_dataset: imagenet100
shadow_file: data/ImageNet-100/trainset.txt
memory_file: data/ImageNet-100/trainset.txt
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_test.yaml:
--------------------------------------------------------------------------------
dataset: imagenet100
attack_algorithm: sslbkd
return_attack_target: true  # must be set to true when running badencoder
attack_target: 6
trigger_path: assets/triggers/trigger_14.png
trigger_size: 50
train_file: data/ImageNet-100/10percent_trainset.txt
test_file: data/ImageNet-100/valset.txt
trigger_insert: patch
--------------------------------------------------------------------------------
/configs/ssl/byol.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# Default configuration for BYOL

# Basic configuration
config = {
    # General parameters
    'method': 'byol',  # set to BYOL
    'arch': 'resnet18',
    'workers': 4,
    'epochs': 300,
    'start_epoch': 0,
    'batch_size': 128,
    'lr': 0.002,
    'optimizer': 'adam',
    'weight_decay': 1e-6,
    'lr_schedule': 'step',  # 'step', 'cos'
    'lr_drops': [250, 275],
    'lr_drop_gamma': 0.2,
    'print_freq': 10,
    'resume': '',
    'dist_url': 'tcp://localhost:10021',
    'dist_backend': 'nccl',
    'seed': None,
    'multiprocessing_distributed': True,
    'feature_dim': 512,  # feature dimension

    # Attack parameters
    'attack_algorithm': 'sslbkd',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
    'ablation': False,

    # BYOL-specific parameters
    'byol_tau': 0.99,  # target-network momentum coefficient
    'proj_dim': 1024,  # projection-head hidden dimension
    'pred_dim': 128,  # prediction-head output dimension

    # Mixed-precision training
    'amp': True,

    # Experiment logging
    'experiment_id': 'byol_imagenet-100_test',
    'save_folder_root': '/workspace/SSL-Backdoor/results/test',
    'save_freq': 30,
    'eval_frequency': 30,

    # Logger configuration
    'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
}
--------------------------------------------------------------------------------
/configs/ssl/moco.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# Default configuration for MoCo

# Basic configuration
config = {
    # General parameters
    'method': 'moco',
    'arch': 'resnet18',
    'workers': 4,
    'epochs': 300,
    'start_epoch': 0,
    'batch_size': 256,
    'lr': 0.06,
    'optimizer': 'sgd',
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'lr_schedule': 'cos',
    'print_freq': 50,
    'resume': '',
    'dist_url': 'tcp://localhost:10001',
    'dist_backend': 'nccl',
    'seed': None,
    'multiprocessing_distributed': True,
    'feature_dim': 128,

    # Attack parameters
    'ablation': False,

    # MoCo-specific parameters
    'moco_k': 65536,
    'moco_m': 0.999,
    'moco_contr_w': 1,
    'moco_contr_tau': 0.2,
    'moco_align_w': 0,
    'moco_align_alpha': 2,
    'moco_unif_w': 0,
    'moco_unif_t': 3,

    # # Dataset configuration
    # 'dataset': 'imagenet-100',
    # 'data': '/workspace/SSL-Backdoor/data/ImageNet-100/trainset.txt',

    # Mixed-precision training
    'amp': True,

    # Experiment logging
    'experiment_id': 'moco_imagenet-100_test',
    'save_folder_root': '/workspace/SSL-Backdoor/results/test',
    'save_freq': 30,

    # Logger configuration
    'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'

    # Attack target classes (if needed)
    'attack_target_list': [0]
}
--------------------------------------------------------------------------------
/configs/ssl/simclr.py:
--------------------------------------------------------------------------------
# Default configuration for SimCLR

# Basic configuration
config = {
    # General parameters
    'method': 'simclr',  # set to SimCLR
    'arch': 'resnet18',
    'feature_dim': 512,  # feature dimension
    'workers': 4,
    'epochs': 300,
    'start_epoch': 0,
    'batch_size': 256,
    'optimizer': 'sgd',
    'lr': 0.5,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'lr_schedule': 'cos',
    'print_freq': 10,
    'resume': '',
    'dist_url': 'tcp://localhost:10013',
    'dist_backend': 'nccl',
    'seed': 42,
    'multiprocessing_distributed': True,


    # Attack parameters
    'attack_algorithm': 'sslbkd',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
    'ablation': False,

    # SimCLR-specific parameters
    'proj_dim': 128,  # projection-head output dimension
    'temperature': 0.5,  # temperature for the NT-Xent loss

    # Mixed-precision training
    'amp': True,

    # Experiment logging
    'experiment_id': 'simclr_imagenet-100_test',
    'save_folder_root': '/workspace/SSL-Backdoor/results/test',
    'save_freq': 30,
    'eval_frequency': 30,

    # Logger configuration
    'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
}
--------------------------------------------------------------------------------
/configs/ssl/simsiam.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# Default configuration for SimSiam

# Basic configuration
config = {
    # General parameters
    'method': 'simsiam',  # set to SimSiam
    'arch': 'resnet18',
    'workers': 4,
    'epochs': 300,
    'start_epoch': 0,
    'batch_size': 256,
    'optimizer': 'sgd',
    'lr': 0.1,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'lr_schedule': 'cos',
    'print_freq': 10,
    'resume': '',
    'dist_url': 'tcp://localhost:10025',
    'dist_backend': 'nccl',
    'seed': None,
    'multiprocessing_distributed': True,
    'feature_dim': 2048,  # note: the SimSiam paper uses 2048

    # Attack parameters
    'attack_algorithm': 'bp',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
    'ablation': False,

    # SimSiam-specific parameters
    'pred_dim': 512,  # predictor hidden dimension
    'fix_pred_lr': True,  # whether to use a fixed learning rate for the predictor

    # Data-augmentation parameters
    'min_crop_scale': 0.8,  # minimum scale for RandomResizedCrop

    # Mixed-precision training
    'amp': True,

    # Experiment logging
    'experiment_id': 'na_simsiam_imagenet100_target6',  # updated experiment ID
    'save_folder_root': '/workspace/SSL-Backdoor/results/test',
    'save_freq': 30,
    'eval_frequency': 30,

    # Logger configuration
    'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
}
--------------------------------------------------------------------------------
/docs/zh_cn/patchsearch.md:
--------------------------------------------------------------------------------
# PatchSearch Defense: Implementation Guide

This document describes the implementation details and usage of the PatchSearch defense integrated into the SSL-Backdoor library. PatchSearch detects backdoor attacks in self-supervised learning (SSL) models, in particular patch-based attacks.

The implementation follows the official PatchSearch code and is integrated into the SSL-Backdoor framework to ease research and evaluation.

## Core API

The core logic of PatchSearch lives under `SSL-Backdoor/ssl_backdoor/defenses/patchsearch/`. The main entry functions are defined in `__init__.py`:

1. **`run_patchsearch`**:
   * **Purpose**: Runs the first stage of PatchSearch. It clusters features and performs an iterative patch search to compute a "poison score" for every training sample, surfacing the most suspicious ones.
   * **Main parameters**:
     * `args`: a dict or Namespace with extra options (e.g. `poison_label`).
     * `model`/`weights_path`: the pre-trained SSL model to inspect, or a path to its weights.
     * `train_file`/`suspicious_dataset`: a file list of suspicious training samples, or a `Dataset` object.
     * `dataset_name`: dataset name (e.g. 'imagenet100', 'cifar10').
     * `output_dir`: output directory for results.
     * `experiment_id`: experiment identifier, used to create a subdirectory.
     * `arch`: model architecture (e.g. 'resnet18').
     * `num_clusters`: number of feature clusters.
     * `use_cached_feats`: whether to reuse cached features (speeds up repeated runs).
     * `use_cached_poison_scores`: whether to reuse cached poison scores (speeds up repeated runs).
   * **Returns**: a dict of detection results, with the key entries:
     * `poison_scores`: NumPy array of per-sample poison scores.
     * `sorted_indices`: NumPy array of sample indices sorted by poison score, descending.
     * `is_poison`: boolean NumPy array marking the samples considered poisoned (based on whether the file path contains `poison_label`).
     * `topk_accuracy`: the fraction of true poisons among the identified top-k samples, for several values of `k`.
     * `output_dir`: the directory where results were actually written.
     * `status` (optional): on a first run with `use_cached_feats=False`, may be "CACHED_FEATURES", meaning features were extracted and cached and the command must be run again to perform detection.

2.
--------------------------------------------------------------------------------
/docs/zh_cn/patchsearch.md:
--------------------------------------------------------------------------------
1 | # PatchSearch Defense: Implementation and Usage
2 | 
3 | This document describes the implementation details and usage of the PatchSearch defense integrated into the SSL-Backdoor library. PatchSearch detects backdoor attacks in self-supervised learning (SSL) models, in particular patch-based attacks.
4 | 
5 | The implementation follows the official PatchSearch code and is integrated into the SSL-Backdoor framework to simplify research and evaluation.
6 | 
7 | ## Core Functionality
8 | 
9 | The core logic of PatchSearch lives under `SSL-Backdoor/ssl_backdoor/defenses/patchsearch/`. The main entry points are defined in `__init__.py`:
10 | 
11 | 1. **`run_patchsearch`**:
12 |     * **Purpose**: runs the first stage of PatchSearch. It computes a "poison score" for every training sample via feature clustering and an iterative patch search, surfacing the most suspicious samples.
13 |     * **Main parameters**:
14 |         * `args`: dict or Namespace with extra configuration (e.g. `poison_label`).
15 |         * `model`/`weights_path`: the pretrained SSL model to inspect, or the path to its weights.
16 |         * `train_file`/`suspicious_dataset`: path to a file list of suspicious training samples, or a `Dataset` object.
17 |         * `dataset_name`: dataset name (e.g. 'imagenet100', 'cifar10').
18 |         * `output_dir`: output directory for the results.
19 |         * `experiment_id`: experiment identifier, used to create a subdirectory.
20 |         * `arch`: model architecture (e.g. 'resnet18').
21 |         * `num_clusters`: number of feature clusters.
22 |         * `use_cached_feats`: whether to reuse cached features (speeds up repeated runs).
23 |         * `use_cached_poison_scores`: whether to reuse cached poison scores (speeds up repeated runs).
24 |     * **Returns**: a dict with the detection results; the key entries are:
25 |         * `poison_scores`: NumPy array with the poison score of every sample.
26 |         * `sorted_indices`: NumPy array of sample indices sorted by poison score in descending order.
27 |         * `is_poison`: boolean NumPy array marking whether each sample is considered poisoned (based on whether its file path contains `poison_label`).
28 |         * `topk_accuracy`: fraction of true poisons among the identified top-k samples, for several values of `k`.
29 |         * `output_dir`: directory where the results were actually written.
30 |         * `status` (optional): if `use_cached_feats=False` and this is the first run, the function may return "CACHED_FEATURES", meaning the features were extracted and cached and a second run is needed for detection.
31 | 
32 | 2. **`run_patchsearch_filter`**:
33 |     * **Purpose**: runs the (optional) second stage of PatchSearch. Using the stage-one poison scores and the extracted suspicious patches, it trains a simple classifier to further filter out likely poisons and writes a cleaned dataset file.
34 |     * **Main parameters**:
35 |         * `poison_scores`/`poison_scores_path`: the stage-one poison-score array, or the path to its `.npy` file.
36 |         * `output_dir`: output directory, usually the same as the stage-one `experiment_dir`.
37 |         * `train_file`: path to the original training file.
38 |         * `poison_dir`: directory with the "poison patch" images extracted from suspicious samples (usually `output_dir/all_top_poison_patches`).
39 |         * `topk_poisons`: how many of the highest-scoring samples to use for training the filter classifier.
40 |         * `top_p`: fraction of the original dataset (ranked by poison score) used to train the classifier.
41 |         * `model_count`: how many models to train as an ensemble for robustness.
42 |     * **Returns**: the path of the newly generated, filtered training file (`.txt`). It contains only the samples the classifier considers "clean".
43 | 
44 | ## Usage
45 | 
46 | An example script, `SSL-Backdoor/tools/run_patchsearch.py`, demonstrates how to use the PatchSearch defense.
47 | 
48 | **Steps:**
49 | 
50 | 1. **Prepare the configuration files**:
51 |     * **Base config (`--config`)**: configures PatchSearch itself, e.g. the model weights path (`weights_path`), output directory (`output_dir`), number of clusters (`num_clusters`), batch size (`batch_size`), number of workers (`num_workers`), and the cache switches (`use_cached_feats`, `use_cached_poison_scores`). This file should also contain the path to the original (possibly poisoned) training file (`train_file`). To run the second-stage filter, also add a `filter` dict with its parameters (`topk_poisons`, `top_p`, `model_count`, ...).
52 |     * **Attack config (`--attack_config`)**: *optional but recommended*. It mainly defines how the dataset is loaded, especially when the training data was generated by a specific attack configuration (SSLBKD, CTRL, ...). `run_patchsearch.py` tries to load the dataset information from this config. If the attack config sets `save_poisons=True`, the script redirects it to the base config's `train_file` and disables `save_poisons`, so that the full, possibly poisoned dataset is loaded.
53 | 
54 | 2. **Run the script**:
55 |     ```bash
56 |     python tools/run_patchsearch.py \
57 |         --config <path/to/base_config> \
58 |         --attack_config <path/to/attack_config> \
59 |         [--output_dir <output_dir>] \
60 |         [--experiment_id <experiment_id>] \
61 |         [--skip_filter]  # optional: add this to run only the stage-one detection
62 |     ```
63 | 
64 | **Example flow:**
65 | 
66 | * The script first loads the base config and the attack config.
67 | * Command-line arguments override the corresponding config entries.
68 | * **Key point**: the script uses the data-loading settings from `attack_config` (together with `train_file` and `batch_size`/`num_workers` from the base config) to build `suspicious_dataset`. **Note**: in the current example script the code that loads `suspicious_dataset` is commented out (`# poison_loader = create_data_loader(...)`), so `suspicious_dataset` defaults to `None`; `run_patchsearch` handles this case and loads the dataset directly from `train_file`. If you need the more elaborate loading logic defined in `attack_config` (e.g. specific augmentations), uncomment and adapt the relevant code in `run_patchsearch.py`.
69 | * `run_patchsearch` is called for the stage-one detection.
70 | * If it returns the status `CACHED_FEATURES`, the script prompts you to set `use_cached_feats=True` in the base config and run again.
71 | * The script prints the top-10 most suspicious samples (index, poison score, whether they are true poisons).
72 | * Unless `--skip_filter` was given and provided the base config contains a `filter` section, `run_patchsearch_filter` is called for the stage-two filtering.
73 | * The script prints the filtering statistics (number and percentage of removed samples) and the path of the final clean dataset file.
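For orientation, here is a minimal programmatic sketch of the two-stage pipeline. The two entry points and the parameter/return names come from the descriptions above; `my_config` and the file paths are illustrative assumptions, not verbatim code from `tools/run_patchsearch.py`:

```python
from ssl_backdoor.defenses.patchsearch import run_patchsearch, run_patchsearch_filter

results = run_patchsearch(
    args=my_config,                                   # dict/Namespace, e.g. contains poison_label
    weights_path="results/moco/checkpoint.pth.tar",   # assumed path
    train_file="data/ImageNet-100/trainset.txt",      # assumed path
    dataset_name="imagenet100",
    arch="resnet18",
    output_dir="results/patchsearch",
    experiment_id="demo",
)
if results.get("status") == "CACHED_FEATURES":
    raise SystemExit("Features were cached; set use_cached_feats=True and rerun.")

top10 = results["sorted_indices"][:10]
print("Most suspicious samples:", top10, results["poison_scores"][top10])

filtered_file = run_patchsearch_filter(
    poison_scores=results["poison_scores"],
    output_dir=results["output_dir"],
    train_file="data/ImageNet-100/trainset.txt",
    poison_dir=results["output_dir"] + "/all_top_poison_patches",
)
print("Cleaned train file:", filtered_file)
```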
74 | 
75 | ## Example Results
76 | 
77 | We tested this implementation by defending a MoCo v2 model (trained on ImageNet-100) against an SSLBKD attack. The full run log is available at:
78 | 
79 | `SSL-Backdoor/docs/zh_cn/patchsearch.log`
80 | 
81 | The log records the key output of a PatchSearch run, including feature extraction, clustering, poison-score computation, and the final detection accuracy, and can serve as a reference for the expected behavior.
82 | 
83 | ## Output File Structure
84 | 
85 | After running PatchSearch, the directory `output_dir/experiment_id/` contains the following main files and subdirectories:
86 | 
87 | * `feats.npy`: extracted training-set features (written on the first run with `use_cached_feats=False`).
88 | * `poison-scores.npy`: the computed poison score of every sample.
89 | * `sorted_indices.npy`: sample indices sorted by poison score.
90 | * `patchsearch_results.log`: detailed log of the PatchSearch run.
91 | * `all_top_poison_patches/`: (if patch extraction ran) the candidate "poison patch" images extracted from the most suspicious samples.
92 | * `filter_results/`: (if the second-stage filter ran) files produced by the filtering stage.
93 |     * `filtered_train_file.txt`: the filtered, clean training file list.
94 |     * `poison_classifier.log`: log of training and evaluating the poison classifier.
95 |     * ... (possibly trained model weights, etc.)
96 | 
97 | The filtered `filtered_train_file.txt` can then be used to retrain the SSL model, with the aim of obtaining a model that is more robust to the backdoor attack.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.0
2 | # torch/torchvision/torchaudio pinned to the mutually compatible 2.1.0 release line
3 | torch==2.1.0
4 | torchvision==0.16.0
5 | torchaudio==2.1.0
6 | torchmetrics==1.3.2
7 | torch-tb-profiler==0.4.3
8 | scikit-learn==1.3.2
9 | matplotlib==3.6.3
10 | pandas==1.4.4
11 | tqdm==4.65.0
12 | requests==2.28.2
13 | Pillow==9.4.0
14 | lightly==1.5.13
15 | opencv-python==4.8.1.78
16 | opencv-python-headless==4.8.1.78
--------------------------------------------------------------------------------
/ssl_backdoor/attacks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Attack package: attack methods against self-supervised learning models
3 | """
4 | 
--------------------------------------------------------------------------------
/ssl_backdoor/attacks/badencoder/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | BadEncoder attack module
3 | 
4 | BadEncoder: a backdoor attack against self-supervised encoders
5 | """
6 | 
7 | import ssl_backdoor.attacks.badencoder.datasets
8 | 
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import Trigger_Dataset, ReferenceObjectDataset
2 | from .var import dataset_params
3 | 
4 | from .utils import add_watermark, concatenate_images
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/corruptencoder_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageFont, ImageFilter, ImageColor
2 | import os
3 | import cv2
4 | import re
5 | import sys
6 | import glob
7 | import errno
8 | import random
9 | import numpy as np
10 | 
11 | def get_trigger(trigger_size=40, trigger_path=None, colorful_trigger=True):
12 |     # load trigger
13 |     if colorful_trigger:
14 |         trigger = Image.open(trigger_path).convert('RGB')
15 |         trigger = trigger.resize((trigger_size, trigger_size))
16 |     else:
17 |         trigger = Image.new("RGB", (trigger_size, trigger_size), ImageColor.getrgb("white"))
18 |     return trigger
19 | 
20 | 
21 | def binary_mask_to_box(binary_mask):
22 |     binary_mask = np.array(binary_mask, np.uint8)
23 |     contours,hierarchy = cv2.findContours(
24 |         binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
25 |     areas = []
26 |     for cnt in contours:
27 |         area = cv2.contourArea(cnt)
28 |         areas.append(area)
29 |     idx = areas.index(np.max(areas))
30 |     x, y, w, h = cv2.boundingRect(contours[idx])
31 |     bounding_box = [x, y, x+w, y+h]
32 |     return bounding_box
33 | 
34 | # def get_foreground(reference_dir, num_references, max_size, type):
35 | #     img_idx = random.choice(range(1, 1+num_references))
36 | #     image_path = os.path.join(reference_dir, f'{img_idx}/img.png')
37 | #     mask_path = os.path.join(reference_dir, f'{img_idx}/label.png')
38 | #     image_np = np.asarray(Image.open(image_path).convert('RGB'))
39 | #     mask_np = np.asarray(Image.open(mask_path).convert('RGB'))
40 | #     mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask
41 | 
42 | #     # crop masked region
43 | #     bbx = binary_mask_to_box(mask_np)
44 | #     object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
45 | #     object_image = Image.fromarray(object_image)
46 | #     object_mask = 
mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]] 47 | # object_mask = Image.fromarray(object_mask) 48 | 49 | # # resize -> avoid poisoned image being too large 50 | # w, h = object_image.size 51 | # if type=='horizontal': 52 | # o_w = min(w, int(max_size/2)) 53 | # o_h = int((o_w/w) * h) 54 | # elif type=='vertical': 55 | # o_h = min(h, int(max_size/2)) 56 | # o_w = int((o_h/h) * w) 57 | # object_image = object_image.resize((o_w, o_h)) 58 | # object_mask = object_mask.resize((o_w, o_h)) 59 | # return object_image, object_mask 60 | 61 | def get_foreground(reference_image_path, max_size, type): 62 | mask_path = reference_image_path.replace('img.png', 'label.png') 63 | image_np = np.asarray(Image.open(reference_image_path).convert('RGB')) 64 | mask_np = np.asarray(Image.open(mask_path).convert('RGB')) 65 | mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask 66 | 67 | # crop masked region 68 | bbx = binary_mask_to_box(mask_np) 69 | object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]] 70 | object_image = Image.fromarray(object_image) 71 | object_mask = mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]] 72 | object_mask = Image.fromarray(object_mask) 73 | 74 | # resize -> avoid poisoned image being too large 75 | w, h = object_image.size 76 | if type=='horizontal': 77 | o_w = min(w, int(max_size/2)) 78 | o_h = int((o_w/w) * h) 79 | elif type=='vertical': 80 | o_h = min(h, int(max_size/2)) 81 | o_w = int((o_h/h) * w) 82 | object_image = object_image.resize((o_w, o_h)) 83 | object_mask = object_mask.resize((o_w, o_h)) 84 | return object_image, object_mask 85 | 86 | def concat(support_reference_image_path, reference_image_path, max_size): 87 | ### horizontally concat two images 88 | # get support reference image 89 | support_reference_image = Image.open(support_reference_image_path) 90 | width, height = support_reference_image.size 91 | n_w = min(width, int(max_size/2)) 92 | n_h = int((n_w/width) * height) 93 | support_reference_image = support_reference_image.resize((n_w, n_h)) 94 | width, height = support_reference_image.size 95 | 96 | # get reference image 97 | reference_image = Image.open(reference_image_path) 98 | reference_image = reference_image.resize((width, height)) 99 | 100 | img_new = Image.new("RGB", (width*2, height), "white") 101 | if random.random()<0.5: 102 | img_new.paste(support_reference_image, (0, 0)) 103 | img_new.paste(reference_image, (width, 0)) 104 | else: 105 | img_new.paste(reference_image, (0, 0)) 106 | img_new.paste(support_reference_image, (width, 0)) 107 | return img_new 108 | 109 | 110 | def get_random_reference_image(reference_dir, num_references): 111 | img_idx = random.choice(range(1, 1+num_references)) 112 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png') 113 | return image_path 114 | 115 | def get_random_support_reference_image(reference_dir): 116 | support_dir = os.path.join(reference_dir, 'support-images') 117 | image_path = os.path.join(support_dir, random.choice(os.listdir(support_dir))) 118 | return image_path 119 | -------------------------------------------------------------------------------- /ssl_backdoor/datasets/generators.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Courtsey of: https://github.com/Muzammal-Naseer/Cross-domain-perturbations 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import functools 10 | from torch.autograd import Variable 11 | 12 | ########################### 13 | # Generator: Resnet 14 | 
###########################
15 | 
16 | # To control feature map in generator
17 | ngf = 64
18 | 
19 | class GeneratorResnet(nn.Module):
20 |     def __init__(self, inception=False, dim="high"):
21 |         '''
22 |         :param inception: if True, a crop layer is added to go from 3x300x300 to 3x299x299.
23 |         :param dim: for a high-dimensional dataset (ImageNet) 6 resblocks are added, otherwise only 2.
24 |         '''
25 |         super(GeneratorResnet, self).__init__()
26 |         self.inception = inception
27 |         self.dim = dim
28 |         # Input_size = 3, n, n
29 |         self.block1 = nn.Sequential(
30 |             nn.ReflectionPad2d(3),
31 |             nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
32 |             nn.BatchNorm2d(ngf),
33 |             nn.ReLU(True)
34 |         )
35 | 
36 |         # Input size = 3, n, n
37 |         self.block2 = nn.Sequential(
38 |             nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
39 |             nn.BatchNorm2d(ngf * 2),
40 |             nn.ReLU(True)
41 |         )
42 | 
43 |         # Input size = 3, n/2, n/2
44 |         self.block3 = nn.Sequential(
45 |             nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
46 |             nn.BatchNorm2d(ngf * 4),
47 |             nn.ReLU(True)
48 |         )
49 | 
50 |         # Input size = 3, n/4, n/4
51 |         # Residual Blocks: 6
52 |         self.resblock1 = ResidualBlock(ngf * 4)
53 |         self.resblock2 = ResidualBlock(ngf * 4)
54 | 
55 |         if self.dim == "high":
56 |             self.resblock3 = ResidualBlock(ngf * 4)
57 |             self.resblock4 = ResidualBlock(ngf * 4)
58 |             self.resblock5 = ResidualBlock(ngf * 4)
59 |             self.resblock6 = ResidualBlock(ngf * 4)
60 |         else:
61 |             print("Using the low-dimensional setup (2 residual blocks)")
62 | 
63 | 
64 |         # Input size = 3, n/4, n/4
65 |         self.upsampl1 = nn.Sequential(
66 |             nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
67 |             nn.BatchNorm2d(ngf * 2),
68 |             nn.ReLU(True)
69 |         )
70 | 
71 |         # Input size = 3, n/2, n/2
72 |         self.upsampl2 = nn.Sequential(
73 |             nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
74 |             nn.BatchNorm2d(ngf),
75 |             nn.ReLU(True)
76 |         )
77 | 
78 |         # Input size = 3, n, n
79 |         self.blockf = nn.Sequential(
80 |             nn.ReflectionPad2d(3),
81 |             nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
82 |         )
83 | 
84 | 
85 |         self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
86 | 
87 |     def forward(self, input):
88 | 
89 |         x = self.block1(input)
90 |         x = self.block2(x)
91 |         x = self.block3(x)
92 |         x = self.resblock1(x)
93 |         x = self.resblock2(x)
94 |         if self.dim == "high":
95 |             x = self.resblock3(x)
96 |             x = self.resblock4(x)
97 |             x = self.resblock5(x)
98 |             x = self.resblock6(x)
99 |         x = self.upsampl1(x)
100 |         x = self.upsampl2(x)
101 |         x = self.blockf(x)
102 |         if self.inception:
103 |             x = self.crop(x)
104 | 
105 |         return (torch.tanh(x) + 1) / 2 # Output range [0 1]
106 | 
107 | 
108 | class GeneratorAdv(nn.Module):
109 |     def __init__(self, eps=8/255):
110 |         '''
111 |         Learns a single additive perturbation tensor shared across the batch.
112 |         :param eps: scale applied to the learned perturbation.
113 |         '''
114 |         super(GeneratorAdv, self).__init__()
115 |         self.perturbation = torch.randn(size=(1, 3, 32, 32))
116 |         self.perturbation = nn.Parameter(self.perturbation, requires_grad=True)
117 |         self.eps = eps
118 | 
119 |     def forward(self, input):
120 |         # perturbation = (torch.tanh(self.perturbation) + 1) / 2
121 |         return input + self.perturbation * self.eps  # input plus eps-scaled perturbation (not clamped)
122 | 
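# A usage sketch (assumption: a training loop optimizes the perturbation
# through some objective; `ssl_loss` below is a hypothetical stand-in):
#   gen = GeneratorAdv(eps=8/255)
#   x_adv = gen(images)              # images + eps-scaled learnable noise
#   ssl_loss(x_adv).backward()       # gradients reach gen.perturbation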
123 | 
124 | class Generator_Patch(nn.Module):
125 |     def __init__(self, size=10):
126 |         '''
127 |         Learns a square patch that is pasted at a random location of the input.
128 |         :param size: height/width of the learned patch.
129 |         '''
130 |         super(Generator_Patch, self).__init__()
131 |         self.perturbation = torch.randn(size=(1, 3, size, size))
132 |         self.perturbation = nn.Parameter(self.perturbation, requires_grad=True)
133 | 
134 |     def forward(self, input):
135 |         # perturbation = (torch.tanh(self.perturbation) + 1) / 2
136 |         random_x = np.random.randint(0, input.shape[-1] - self.perturbation.shape[-1])
137 |         random_y = np.random.randint(0, input.shape[-1] - self.perturbation.shape[-1])
138 |         input[:, :, random_x:random_x + self.perturbation.shape[-1], random_y:random_y + self.perturbation.shape[-1]] = self.perturbation
139 |         return input  # input with the learned patch pasted in (modified in place)
140 | 
141 | 
142 | class ResidualBlock(nn.Module):
143 |     def __init__(self, num_filters):
144 |         super(ResidualBlock, self).__init__()
145 |         self.block = nn.Sequential(
146 |             nn.ReflectionPad2d(1),
147 |             nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0,
148 |                       bias=False),
149 |             nn.BatchNorm2d(num_filters),
150 |             nn.ReLU(True),
151 | 
152 |             nn.Dropout(0.5),
153 | 
154 |             nn.ReflectionPad2d(1),
155 |             nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0,
156 |                       bias=False),
157 |             nn.BatchNorm2d(num_filters)
158 |         )
159 | 
160 |     def forward(self, x):
161 |         residual = self.block(x)
162 |         return x + residual
163 | 
164 | if __name__ == '__main__':
165 |     netG = GeneratorResnet()
166 |     test_sample = torch.rand(1, 3, 640, 480)
167 |     My_output = test_sample
168 |     print('Generator output:', netG(test_sample).size())
169 |     print('Generator parameters:', sum(p.numel() for p in netG.parameters() if p.requires_grad)/1000000)
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/cifar10_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | automobile
3 | bird
4 | cat
5 | deer
6 | dog
7 | frog
8 | horse
9 | ship
10 | truck
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/cifar10_metadata.txt:
--------------------------------------------------------------------------------
1 | airplane 0
2 | automobile 1
3 | bird 2
4 | cat 3
5 | deer 4
6 | dog 5
7 | frog 6
8 | horse 7
9 | ship 8
10 | truck 9
11 | 
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/class_index.txt:
--------------------------------------------------------------------------------
1 | n01558993 0
2 | n01692333 1
3 | n01729322 2
4 | n01735189 3
5 | n01749939 4
6 | n01773797 5
7 | n01820546 6
8 | n01855672 7
9 | n01978455 8
10 | n01980166 9
11 | n01983481 10
12 | n02009229 11
13 | n02018207 12
14 | n02085620 13
15 | n02086240 14
16 | n02086910 15
17 | n02087046 16
18 | n02089867 17
19 | n02089973 18
20 | n02090622 19
21 | n02091831 20
22 | n02093428 21 23 | n02099849 22 24 | n02100583 23 25 | n02104029 24 26 | n02105505 25 27 | n02106550 26 28 | n02107142 27 29 | n02108089 28 30 | n02109047 29 31 | n02113799 30 32 | n02113978 31 33 | n02114855 32 34 | n02116738 33 35 | n02119022 34 36 | n02123045 35 37 | n02138441 36 38 | n02172182 37 39 | n02231487 38 40 | n02259212 39 41 | n02326432 40 42 | n02396427 41 43 | n02483362 42 44 | n02488291 43 45 | n02701002 44 46 | n02788148 45 47 | n02804414 46 48 | n02859443 47 49 | n02869837 48 50 | n02877765 49 51 | n02974003 50 52 | n03017168 51 53 | n03032252 52 54 | n03062245 53 55 | n03085013 54 56 | n03259280 55 57 | n03379051 56 58 | n03424325 57 59 | n03492542 58 60 | n03494278 59 61 | n03530642 60 62 | n03584829 61 63 | n03594734 62 64 | n03637318 63 65 | n03642806 64 66 | n03764736 65 67 | n03775546 66 68 | n03777754 67 69 | n03785016 68 70 | n03787032 69 71 | n03794056 70 72 | n03837869 71 73 | n03891251 72 74 | n03903868 73 75 | n03930630 74 76 | n03947888 75 77 | n04026417 76 78 | n04067472 77 79 | n04099969 78 80 | n04111531 79 81 | n04127249 80 82 | n04136333 81 83 | n04229816 82 84 | n04238763 83 85 | n04336792 84 86 | n04418357 85 87 | n04429376 86 88 | n04435653 87 89 | n04485082 88 90 | n04493381 89 91 | n04517823 90 92 | n04589890 91 93 | n04592741 92 94 | n07714571 93 95 | n07715103 94 96 | n07753275 95 97 | n07831146 96 98 | n07836838 97 99 | n13037406 98 100 | n13040303 99 101 | -------------------------------------------------------------------------------- /ssl_backdoor/datasets/metadata/imagenet100_classes.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 101 | -------------------------------------------------------------------------------- /ssl_backdoor/datasets/metadata/stl10_classes.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | bird 3 | car 4 | cat 5 | deer 6 | dog 7 | horse 8 | monkey 9 | ship 10 | truck -------------------------------------------------------------------------------- 
/ssl_backdoor/datasets/metadata/stl10_metadata.txt: -------------------------------------------------------------------------------- 1 | airplane 0 2 | bird 1 3 | car 2 4 | cat 3 5 | deer 4 6 | dog 5 7 | horse 6 8 | monkey 7 9 | ship 8 10 | truck 9 -------------------------------------------------------------------------------- /ssl_backdoor/datasets/var.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | # 数据集参数配置 4 | dataset_params = { 5 | 'cc3m': { 6 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 7 | 'image_size': 224 8 | }, 9 | 'cc3m_small': { 10 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 11 | 'image_size': 32 12 | }, 13 | 'imagenet': { 14 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 15 | 'image_size': 224, 16 | 'num_classes': 1000, 17 | }, 18 | 'imagenet100': { 19 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 20 | 'image_size': 224, 21 | 'num_classes': 100, 22 | }, 23 | 'cifar10': { 24 | 'normalize': transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), 25 | 'image_size': 32, 26 | 'num_classes': 10, 27 | 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 28 | }, 29 | 'stl10': { 30 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 31 | 'image_size': 96, 32 | 'num_classes': 10, 33 | }, 34 | 'gtsrb': { 35 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 36 | 'image_size': 32, 37 | 'num_classes': 43, 38 | 'classes': ['Speed limit (20km/h)', 'Speed limit (30km/h)', 'Speed limit (50km/h)', 'Speed limit (60km/h)', 'Speed limit (70km/h)', 39 | 'Speed limit (80km/h)', 'End of speed limit (80km/h)', 'Speed limit (100km/h)', 'Speed limit (120km/h)', 'No passing', 40 | 'No passing for vehicles over 3.5 metric tons', 'Right-of-way at the next intersection', 'Priority road', 'Yield', 'Stop', 41 | 'No vehicles', 'Vehicles over 3.5 metric tons prohibited', 'No entry', 'General caution', 'Dangerous curve to the left', 42 | 'Dangerous curve to the right', 'Double curve', 'Bumpy road', 'Slippery road', 'Road narrows on the right', 'Road work', 43 | 'Traffic signals', 'Pedestrians', 'Children crossing', 'Bicycles crossing', 'Beware of ice/snow', 'Wild animals crossing', 44 | 'End of all speed and passing limits', 'Turn right ahead', 'Turn left ahead', 'Ahead only', 'Go straight or right', 'Go straight or left', 45 | 'Keep right', 'Keep left', 'Roundabout mandatory', 'End of no passing', 'End of no passing by vehicles over 3.5 metric tons'] 46 | } 47 | } -------------------------------------------------------------------------------- /ssl_backdoor/defenses/decree/trigger/trigger_pt_white_185_24.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_185_24.npz -------------------------------------------------------------------------------- /ssl_backdoor/defenses/decree/trigger/trigger_pt_white_21_10_ap_replace.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_21_10_ap_replace.npz -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchSearch防御的工具函数集合。 3 | """ 4 | 5 | from .gradcam import run_gradcam 6 | from .clustering import faiss_kmeans, KMeansLinear, Normalize, FullBatchNorm 7 | from .patch_operations import ( 8 | denormalize, paste_patch, block_max_window, extract_max_window, 9 | get_candidate_patches, save_patches 10 | ) 11 | from .model_utils import get_model, get_feats, get_channels 12 | from .dataset import FileListDataset, get_transforms, get_test_images 13 | 14 | __all__ = [ 15 | # gradcam 16 | 'run_gradcam', 17 | 18 | # clustering 19 | 'faiss_kmeans', 'KMeansLinear', 'Normalize', 'FullBatchNorm', 20 | 21 | # patch_operations 22 | 'denormalize', 'paste_patch', 'block_max_window', 'extract_max_window', 23 | 'get_candidate_patches', 'save_patches', 24 | 25 | # model_utils 26 | 'get_model', 'get_feats', 'get_channels', 27 | 28 | # dataset 29 | 'FileListDataset', 'get_transforms', 'get_test_images' 30 | ] -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/clustering.py: -------------------------------------------------------------------------------- 1 | """ 2 | 聚类相关工具函数,用于PatchSearch防御中的特征聚类。 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import faiss 9 | from sklearn.metrics import pairwise_distances 10 | 11 | 12 | def faiss_kmeans(train_feats, nmb_clusters): 13 | """ 14 | 使用FAISS对特征进行k-means聚类 15 | 16 | 参数: 17 | train_feats: 特征张量,形状为[N, D] 18 | nmb_clusters: 聚类数量 19 | 20 | 返回: 21 | train_d: 每个样本到最近聚类中心的距离 22 | train_a: 每个样本的聚类分配 23 | index: FAISS索引对象 24 | centroids: 聚类中心 25 | """ 26 | train_feats = train_feats.numpy() 27 | d = train_feats.shape[-1] 28 | 29 | clus = faiss.Clustering(d, nmb_clusters) 30 | clus.niter = 20 31 | clus.max_points_per_centroid = 10000000 32 | 33 | index = faiss.IndexFlatL2(d) 34 | co = faiss.GpuMultipleClonerOptions() 35 | co.useFloat16 = True 36 | co.shard = True 37 | index = faiss.index_cpu_to_all_gpus(index, co) 38 | 39 | # 执行训练 40 | clus.train(train_feats, index) 41 | train_d, train_a = index.search(train_feats, 1) 42 | 43 | return train_d, train_a, index, clus.centroids 44 | 45 | 46 | class KMeansLinear(nn.Module): 47 | """ 48 | 使用聚类中心作为分类权重的线性分类器 49 | """ 50 | def __init__(self, train_a, train_val_feats, num_clusters): 51 | """ 52 | 初始化KMeans线性分类器 53 | 54 | 参数: 55 | train_a: 聚类分配结果 56 | train_val_feats: 训练特征 57 | num_clusters: 聚类数量 58 | """ 59 | super().__init__() 60 | clusters = [] 61 | for i in range(num_clusters): 62 | # 计算每个聚类的平均特征作为中心 63 | cluster = train_val_feats[train_a == i].mean(dim=0) 64 | clusters.append(cluster) 65 | self.classifier = nn.Parameter(torch.stack(clusters)) 66 | 67 | def forward(self, x): 68 | """ 69 | 前向传播,计算输入特征与聚类中心的相似度 70 | 71 | 参数: 72 | x: 输入特征,形状为[B, D] 73 | 74 | 返回: 75 | 相似度分数,形状为[B, num_clusters] 76 | """ 77 | c = self.classifier 78 | c = c / c.norm(2, dim=1, keepdim=True) 79 | x = x / x.norm(2, dim=1, keepdim=True) 80 | return x @ c.T 81 | 82 | 83 | class Normalize(nn.Module): 84 | """ 85 | 特征归一化层 86 | """ 87 | def forward(self, x): 88 | return x / x.norm(2, dim=1, keepdim=True) 89 | 90 | 91 | class FullBatchNorm(nn.Module): 92 | """ 93 | 
全批次归一化层 94 | """ 95 | def __init__(self, var, mean): 96 | super(FullBatchNorm, self).__init__() 97 | self.register_buffer('inv_std', (1.0 / torch.sqrt(var + 1e-5))) 98 | self.register_buffer('mean', mean) 99 | 100 | def forward(self, x): 101 | return (x - self.mean) * self.inv_std -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | 数据集相关工具函数,用于PatchSearch防御中的数据加载和处理。 3 | """ 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader, Subset 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | 12 | class FileListDataset(Dataset): 13 | """ 14 | 从文件列表加载数据集 15 | """ 16 | def __init__(self, path_to_txt_file, transform, poison_label='poison'): 17 | """ 18 | 初始化数据集 19 | 20 | 参数: 21 | path_to_txt_file: 包含图像路径和标签的文本文件 22 | transform: 图像转换 23 | """ 24 | with open(path_to_txt_file, 'r') as f: 25 | lines = f.readlines() 26 | samples = [line.strip().split() for line in lines] 27 | samples = [(pth, int(target)) for pth, target in samples] 28 | 29 | self.samples = samples 30 | self.transform = transform 31 | self.classes = list(sorted(set(y for _, y in self.samples))) 32 | self.poison_label = poison_label 33 | 34 | def __getitem__(self, idx): 35 | """ 36 | 获取数据集中的一个样本 37 | 38 | 参数: 39 | idx: 样本索引 40 | 41 | 返回: 42 | image: 图像tensor 43 | target: 目标标签 44 | is_poisoned: 是否是有毒样本 45 | idx: 样本索引 46 | """ 47 | image_path, target = self.samples[idx] 48 | img = Image.open(image_path).convert('RGB') 49 | 50 | if self.transform is not None: 51 | image = self.transform(img) 52 | 53 | is_poisoned = self.poison_label in image_path 54 | 55 | return image, target, is_poisoned, idx 56 | 57 | def __len__(self): 58 | """ 59 | 返回数据集的大小 60 | """ 61 | return len(self.samples) 62 | 63 | 64 | def get_transforms(dataset_name, image_size): 65 | """ 66 | 获取针对特定数据集的图像转换 67 | 68 | 参数: 69 | dataset_name: 数据集名称 70 | image_size: 图像大小 71 | 72 | 返回: 73 | val_transform: 图像转换 74 | """ 75 | if dataset_name == 'imagenet100': 76 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 77 | std=[0.229, 0.224, 0.225]) 78 | elif dataset_name == 'cifar10': 79 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 80 | std=[0.2023, 0.1994, 0.2010]) 81 | elif dataset_name == 'stl10': 82 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225]) 84 | else: 85 | raise ValueError(f"Unknown dataset '{dataset_name}'") 86 | 87 | if image_size > 200: 88 | val_transform = transforms.Compose([ 89 | transforms.Resize(256, interpolation=3), 90 | transforms.CenterCrop(224), 91 | transforms.ToTensor(), 92 | normalize 93 | ]) 94 | else: 95 | val_transform = transforms.Compose([ 96 | transforms.Resize(image_size, interpolation=3), 97 | transforms.ToTensor(), 98 | normalize 99 | ]) 100 | 101 | return val_transform 102 | 103 | 104 | def get_test_images(train_val_dataset, cluster_wise_i, test_images_size): 105 | """ 106 | 获取测试图像 107 | 108 | 参数: 109 | train_val_dataset: 训练和验证数据集 110 | cluster_wise_i: 每个聚类的样本索引 111 | test_images_size: 测试图像的数量 112 | 113 | 返回: 114 | test_images: 测试图像tensor 115 | test_images_i: 测试图像的索引 116 | """ 117 | import numpy as np 118 | import torch 119 | 120 | test_images_i = [] 121 | k = test_images_size // len(cluster_wise_i) 122 | if k > 0: 123 | for inds in cluster_wise_i: 124 | test_images_i.extend(inds[:k]) 125 | else: 126 | for clust_i in 
np.random.permutation(len(cluster_wise_i))[:test_images_size]: 127 | test_images_i.append(cluster_wise_i[clust_i][0]) 128 | 129 | test_images_dataset = Subset( 130 | train_val_dataset, torch.tensor(test_images_i) 131 | ) 132 | test_images_loader = DataLoader( 133 | test_images_dataset, 134 | shuffle=False, batch_size=64, 135 | num_workers=8, pin_memory=True 136 | ) 137 | 138 | test_images = [] 139 | for inp, _, _, _ in tqdm(test_images_loader): 140 | test_images.append(inp) 141 | test_images = torch.cat(test_images) 142 | return test_images, test_images_i -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchSearch防御方法的评估工具。 3 | """ 4 | 5 | import time 6 | import torch 7 | 8 | 9 | class AverageMeter(object): 10 | """计算并存储平均值和当前值""" 11 | def __init__(self, name, fmt=':f'): 12 | self.name = name 13 | self.fmt = fmt 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | def __str__(self): 29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 30 | return fmtstr.format(**self.__dict__) 31 | 32 | 33 | class ProgressMeter(object): 34 | """在控制台显示进度条""" 35 | def __init__(self, num_batches, meters, prefix=""): 36 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 37 | self.meters = meters 38 | self.prefix = prefix 39 | 40 | def display(self, batch): 41 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 42 | entries += [str(meter) for meter in self.meters] 43 | return '\t'.join(entries) 44 | 45 | def _get_batch_fmtstr(self, num_batches): 46 | num_digits = len(str(num_batches // 1)) 47 | fmt = '{:' + str(num_digits) + 'd}' 48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 49 | 50 | 51 | def accuracy(output, target, topk=(1,)): 52 | """计算topk准确率""" 53 | with torch.no_grad(): 54 | maxk = max(topk) 55 | batch_size = target.size(0) 56 | 57 | _, pred = output.topk(maxk, 1, True, True) 58 | pred = pred.t() 59 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 60 | 61 | res = [] 62 | for k in topk: 63 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 64 | res.append(correct_k.mul_(100.0 / batch_size)) 65 | return res 66 | 67 | 68 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 69 | """保存检查点""" 70 | torch.save(state, filename) 71 | if is_best: 72 | import shutil 73 | shutil.copyfile(filename, 'model_best.pth.tar') -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/gradcam.py: -------------------------------------------------------------------------------- 1 | """ 2 | GradCAM相关工具函数,用于PatchSearch防御中定位潜在的后门触发器。 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from pytorch_grad_cam import GradCAM 8 | from pytorch_grad_cam.utils.image import show_cam_on_image 9 | 10 | 11 | def reshape_transform(tensor, height=14, width=14): 12 | """ 13 | 用于ViT模型的注意力图重塑转换 14 | """ 15 | result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2)) 16 | # 将通道维度调整到第一个维度,类似CNN 17 | result = result.transpose(2, 3).transpose(1, 2) 18 | return result 19 | 20 | 21 | def run_gradcam(arch, model, inp, targets=None): 22 | """ 23 | 对输入图像运行GradCAM,生成热力图 24 | 25 | 参数: 26 | 
arch: model architecture name
27 |         model: the PyTorch model
28 |         inp: input image tensor
29 |         targets: optional target classes
30 | 
31 |     Returns:
32 |         grayscale_cam: grayscale heatmap
33 |         out: model output
34 |     """
35 |     if 'vit' in arch:
36 |         return run_vit_gradcam(model, [model.blocks[-1].norm1], inp, targets)
37 |     else:
38 |         return run_cnn_gradcam(model, [model.layer4], inp, targets)
39 | 
40 | 
41 | def run_cnn_gradcam(model, target_layers, inp, targets=None):
42 |     """
43 |     Run GradCAM on a CNN model.
44 |     """
45 |     # Record the parameters whose requires_grad we modify, with their original state
46 |     params_to_restore = []
47 | 
48 |     # Set requires_grad=True for all parameters of the target layers
49 |     for layer in target_layers:
50 |         for param in layer.parameters():
51 |             if not param.requires_grad:
52 |                 params_to_restore.append((param, param.requires_grad))
53 |                 param.requires_grad_(True)
54 | 
55 |     try:
56 |         with GradCAM(model=model, target_layers=target_layers, use_cuda=True) as cam:
57 |             cam.batch_size = 32
58 |             grayscale_cam, out = cam(input_tensor=inp, targets=targets)
59 |             return grayscale_cam, out
60 |     finally:
61 |         # Restore the original requires_grad state of all modified parameters
62 |         for param, orig_requires_grad in params_to_restore:
63 |             param.requires_grad_(orig_requires_grad)
64 | 
65 | 
66 | def run_vit_gradcam(model, target_layers, inp, targets=None):
67 |     """
68 |     Run GradCAM on a ViT model.
69 |     """
70 |     with GradCAM(model=model, target_layers=target_layers,
71 |                  reshape_transform=reshape_transform, use_cuda=True) as cam:
72 |         cam.batch_size = 32
73 |         grayscale_cam, out = cam(input_tensor=inp, targets=targets)
74 |         return grayscale_cam, out
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Model utilities for the PatchSearch defense: model loading and feature extraction.
3 | """
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | import torchvision.models as models
9 | import sys
10 | 
11 | from ssl_backdoor.utils.model_utils import load_model_weights
12 | 
13 | 
14 | def get_model(arch, wts_path, dataset_name):
15 |     """
16 |     Load a pretrained model.
17 | 
18 |     Args:
19 |         arch: model architecture name
20 |         wts_path: path to the weights file
21 |         dataset_name: dataset name
22 | 
23 |     Returns:
24 |         the loaded model
25 |     """
26 |     if 'moco' in arch:
27 |         model = models.__dict__[arch.replace('moco_', '')]()
28 |         if 'imagenet' not in dataset_name:
29 |             print("Using custom conv1 for small datasets")
30 |             model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
31 |         if dataset_name == "cifar10" or dataset_name == "cifar100":
32 |             print("Using custom maxpool for cifar datasets")
33 |             model.maxpool = nn.Identity()
34 |         model.fc = nn.Sequential()
35 | 
36 |         sd = torch.load(wts_path)['state_dict']
37 |         sd = {k.replace('module.', ''): v for k, v in sd.items()}
38 | 
39 |         sd = {k: v for k, v in sd.items() if 'encoder_q' in k or 'base_encoder' in k or 'backbone' in k or 'encoder' in k}
40 |         sd = {k: v for k, v in sd.items() if 'fc' not in k}
41 | 
42 |         sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
43 |         sd = {k.replace('base_encoder.', ''): v for k, v in sd.items()}
44 |         sd = {k.replace('backbone.', ''): v for k, v in sd.items()}
45 |         sd = {k.replace('encoder.', ''): v for k, v in sd.items()}
46 |         model.load_state_dict(sd, strict=True)
47 | 
48 |     elif 'resnet' in arch:
49 |         model = models.__dict__[arch]()
50 |         model.fc = nn.Sequential()
51 |         load_model_weights(model, wts_path)
52 | 
53 |     else:
54 |         raise ValueError('arch not found: ' + arch)
55 | 
56 |     model = model.eval()
57 | 
58 |     return model
59 | 
60 | 
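# Usage sketch (the checkpoint path is an illustrative assumption):
#   model = get_model('moco_resnet18', 'results/moco/checkpoint.pth.tar', 'cifar10')
#   feats, labels, is_poisoned, indices = get_feats(model, loader)  # see get_feats below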
def get_feats(model, loader):
62 |     """
63 |     Extract features from a data loader.
64 | 
65 |     Args:
66 |         model: the model
67 |         loader: the data loader
68 | 
69 |     Returns:
70 |         feats: extracted features
71 |         labels: corresponding labels
72 |         is_poisoned: whether each sample is poisoned
73 |         indices: sample indices
74 |     """
75 |     model = nn.DataParallel(model).cuda()
76 |     model.eval()
77 |     feats, labels, indices, is_poisoned = [], [], [], []
78 |     offset = 0
79 |     for data in tqdm(loader):
80 |         if len(data) == 4:
81 |             images, targets, is_p, inds = data
82 |         else:
83 |             # Plain (image, target) loaders carry no poison flags or indices;
84 |             # synthesize them so the return values stay consistent.
85 |             images, targets = data
86 |             is_p = torch.zeros(images.size(0), dtype=torch.bool)
87 |             inds = torch.arange(offset, offset + images.size(0))
88 |         offset += images.size(0)
89 |         with torch.no_grad():
90 |             feats.append(model(images.cuda()).cpu())
91 |         labels.append(targets)
92 |         indices.append(inds)
93 |         is_poisoned.append(is_p)
94 |     feats = torch.cat(feats)
95 |     labels = torch.cat(labels)
96 |     indices = torch.cat(indices)
97 |     is_poisoned = torch.cat(is_poisoned)
98 |     feats /= feats.norm(2, dim=-1, keepdim=True)
99 |     return feats, labels, is_poisoned, indices
100 | 
101 | 
102 | def get_channels(arch):
103 |     """
104 |     Get the model's output channel count.
105 | 
106 |     Args:
107 |         arch: model architecture name
108 | 
109 |     Returns:
110 |         number of output channels
111 |     """
112 |     if 'resnet50' in arch:
113 |         c = 2048
114 |     elif 'resnet18' in arch:
115 |         c = 512
116 |     else:
117 |         raise ValueError('arch not found: ' + arch)
118 |     return c
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/patch_operations.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch operations used by the PatchSearch defense for patch extraction and manipulation.
3 | """
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from tqdm import tqdm
9 | import numpy as np
10 | from PIL import Image
11 | from .gradcam import run_gradcam
12 | 
13 | 
14 | def denormalize(x, dataset):
15 |     """
16 |     Undo input normalization.
17 | 
18 |     Args:
19 |         x: normalized image tensor
20 |         dataset: dataset name
21 | 
22 |     Returns:
23 |         denormalized image tensor with values in [0, 1]
24 |     """
25 |     if x.shape[0] == 3:
26 |         x = x.permute((1, 2, 0))
27 | 
28 |     if dataset == 'imagenet100':
29 |         mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
30 |         std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
31 |     elif dataset == 'cifar10':
32 |         mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device)
33 |         std = torch.tensor([0.2023, 0.1994, 0.2010], device=x.device)
34 |     elif dataset == 'stl10':
35 |         mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
36 |         std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
37 |     else:
38 |         raise ValueError(f"Unknown dataset '{dataset}'")
39 |     x = ((x * std) + mean)
40 | 
41 |     x = torch.clamp(x, 0, 1)
42 |     return x
43 | 
44 | 
45 | def paste_patch(inputs, patch):
46 |     """
47 |     Paste a patch at a random location of each input image.
48 | 
49 |     Args:
50 |         inputs: input image tensor of shape [B, 3, H, W]
51 |         patch: patch tensor of shape [3, h, w]
52 | 
53 |     Returns:
54 |         image tensor with the patch pasted in
55 |     """
56 |     B = inputs.shape[0]
57 |     inp_w = inputs.shape[-1]
58 |     window_w = patch.shape[-1]
59 |     ij = torch.randint(low=0, high=(inp_w - window_w), size=(B, 2))
60 |     i, j = ij[:, 0], ij[:, 1]
61 | 
62 |     # Row and column indices for every position inside the window
63 |     s = torch.arange(window_w, device=inputs.device)
64 |     ri = i.view(B, 1).repeat(1, window_w)
65 |     rj = j.view(B, 1).repeat(1, window_w)
66 |     sri, srj = ri + s, rj + s
67 | 
68 |     # Repeat the start row indices across the columns, and vice versa
69 |     xi = sri.view(B, window_w, 1).repeat(1, 1, window_w)
70 |     xj = srj.view(B, 1, window_w).repeat(1, window_w, 1)
71 | 
72 |     # These are 2D indices; flatten them into 1D indices
73 |     inds = xi * inp_w + xj
74 | 
75 |     # Repeat the indices across the color channels
76 |     inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1)
77 | 
78 |     # Flatten the patch from 2D to 1D and repeat it across the batch dimension
79 |     patch = patch.reshape(3, -1).unsqueeze(0).repeat(B, 1, 1)
80 | 
81 |     # Flatten the images to 1D, scatter the patch in, then reshape back to 2D
82 |     inputs = inputs.reshape(B, 3, -1)
83 |     inputs.scatter_(dim=2, index=inds, src=patch)
84 |     inputs = 
inputs.reshape(B, 3, inp_w, inp_w) 85 | return inputs 86 | 87 | 88 | def block_max_window(cam_images, inputs, window_w=30): 89 | """ 90 | 屏蔽输入图像中GradCAM激活最强的区域 91 | 92 | 参数: 93 | cam_images: GradCAM生成的热力图 94 | inputs: 输入图像tensor 95 | window_w: 窗口大小 96 | 97 | 返回: 98 | 屏蔽后的图像tensor 99 | """ 100 | B, _, inp_w = cam_images.shape 101 | grayscale_cam = torch.from_numpy(cam_images) 102 | inputs = inputs.clone() 103 | sum_conv = torch.ones((1, 1, window_w, window_w)) 104 | 105 | # 计算每个窗口的总和 106 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv) 107 | 108 | # 展平总和并取argmax 109 | flat_sums_cam = sums_cam.view(B, -1) 110 | ij = flat_sums_cam.argmax(dim=-1) 111 | 112 | # 分离行和列索引,这给了我们左上角窗口的位置 113 | sums_cam_w = sums_cam.shape[-1] 114 | i, j = ij // sums_cam_w, ij % sums_cam_w 115 | 116 | # 为窗口中的每个位置创建行和列索引 117 | s = torch.arange(window_w, device=inputs.device) 118 | ri = i.view(B, 1).repeat(1, window_w) 119 | rj = j.view(B, 1).repeat(1, window_w) 120 | sri, srj = ri + s, rj + s 121 | 122 | # 在列中重复起始行索引,反之亦然 123 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w) 124 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1) 125 | 126 | # 这些是2D索引,将它们转换为1D索引 127 | inds = xi * inp_w + xj 128 | 129 | # 跨颜色通道重复索引 130 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1) 131 | 132 | # 将图像从2D转为1D,将窗口位置设为0,将图像从1D转回2D 133 | inputs = inputs.reshape(B, 3, -1) 134 | inputs.scatter_(dim=2, index=inds, value=0) 135 | inputs = inputs.reshape(B, 3, inp_w, inp_w) 136 | return inputs 137 | 138 | 139 | def extract_max_window(cam_images, inputs, window_w=30): 140 | """ 141 | 从输入图像中提取GradCAM激活最强的区域 142 | 143 | 参数: 144 | cam_images: GradCAM生成的热力图 145 | inputs: 输入图像tensor 146 | window_w: 窗口大小 147 | 148 | 返回: 149 | 提取的窗口tensor 150 | """ 151 | B, _, inp_w = cam_images.shape 152 | grayscale_cam = torch.from_numpy(cam_images) 153 | inputs = inputs.clone() 154 | sum_conv = torch.ones((1, 1, window_w, window_w)) 155 | 156 | # 计算每个窗口的总和 157 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv) 158 | 159 | # 展平总和并取argmax 160 | flat_sums_cam = sums_cam.view(B, -1) 161 | ij = flat_sums_cam.argmax(dim=-1) 162 | 163 | # 分离行和列索引,这给了我们左上角窗口的位置 164 | sums_cam_w = sums_cam.shape[-1] 165 | i, j = ij // sums_cam_w, ij % sums_cam_w 166 | 167 | # 为窗口中的每个位置创建行和列索引 168 | s = torch.arange(window_w, device=inputs.device) 169 | ri = i.view(B, 1).repeat(1, window_w) 170 | rj = j.view(B, 1).repeat(1, window_w) 171 | sri, srj = ri + s, rj + s 172 | 173 | # 在列中重复起始行索引,反之亦然 174 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w) 175 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1) 176 | 177 | # 这些是2D索引,将它们转换为1D索引 178 | inds = xi * inp_w + xj 179 | 180 | # 跨颜色通道重复索引 181 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1) 182 | 183 | # 将图像从2D转为1D 184 | inputs = inputs.reshape(B, 3, -1) 185 | 186 | # 收集窗口并将1D重塑为2D 187 | windows = torch.gather(inputs, dim=2, index=inds) 188 | windows = windows.reshape(B, 3, window_w, window_w) 189 | 190 | return windows 191 | 192 | 193 | def get_candidate_patches(model, loader, arch, window_w, repeat_patch): 194 | """ 195 | 从数据加载器中获取候选补丁 196 | 197 | 参数: 198 | model: 模型 199 | loader: 数据加载器 200 | arch: 模型架构 201 | window_w: 窗口大小 202 | repeat_patch: 每张图像中提取的补丁数量 203 | 204 | 返回: 205 | 候选补丁tensor列表 206 | """ 207 | candidate_patches = [] 208 | for inp, _, _, _ in tqdm(loader): 209 | windows = [] 210 | for _ in range(repeat_patch): 211 | cam_images, _ = run_gradcam(arch, model, inp) 212 | windows.append(extract_max_window(cam_images, inp, window_w)) 213 | block_max_window(cam_images, inp, 
int(window_w * .5)) 214 | windows = torch.stack(windows) 215 | windows = torch.einsum('kb...->bk...', windows) 216 | candidate_patches.append(windows.detach().cpu()) 217 | candidate_patches = torch.cat(candidate_patches) 218 | return candidate_patches 219 | 220 | 221 | def save_patches(windows, save_dir, dataset): 222 | """ 223 | 保存提取的补丁 224 | 225 | 参数: 226 | windows: 提取的补丁tensor 227 | save_dir: 保存目录 228 | dataset: 数据集名称 229 | """ 230 | import os 231 | os.makedirs(save_dir, exist_ok=True) 232 | 233 | for i, win in enumerate(windows): 234 | win = denormalize(win, dataset) 235 | win = (win * 255).clamp(0, 255).numpy().astype(np.uint8) 236 | win = Image.fromarray(win) 237 | win.save(os.path.join(save_dir, f'{i:05d}.png')) -------------------------------------------------------------------------------- /ssl_backdoor/defenses/patchsearch/utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchSearch防御方法的可视化工具。 3 | """ 4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from PIL import Image 10 | 11 | 12 | def denormalize(x, args): 13 | """ 14 | 将归一化的图像张量转换回正常的图像张量。 15 | 16 | 参数: 17 | x: 归一化的图像张量 18 | args: 配置参数,包含dataset_name 19 | 20 | 返回: 21 | 去归一化的图像张量 22 | """ 23 | if x.dim() == 4: # 批处理 24 | return torch.stack([denormalize(x_i, args) for x_i in x]) 25 | 26 | if x.shape[0] == 3: # CHW -> HWC 27 | x = x.permute((1, 2, 0)) 28 | 29 | if "imagenet" in args.dataset_name: 30 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device) 31 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device) 32 | elif "cifar10" in args.dataset_name: 33 | mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device) 34 | std = torch.tensor([0.2023, 0.1994, 0.2010], device=x.device) 35 | elif "stl10" in args.dataset_name: 36 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device) 37 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device) 38 | else: 39 | raise ValueError(f"未知数据集 '{args.dataset_name}'") 40 | 41 | x = ((x * std) + mean) 42 | x = torch.clamp(x, 0, 1) 43 | 44 | return x 45 | 46 | 47 | def save_image(img_tensor, path, args=None): 48 | """ 49 | 保存单个图像张量为PNG文件。 50 | 51 | 参数: 52 | img_tensor: 图像张量,形状为[C, H, W]或[H, W, C] 53 | path: 保存路径 54 | args: 配置参数 55 | """ 56 | if args is not None: 57 | img_tensor = denormalize(img_tensor, args) 58 | 59 | if img_tensor.dim() == 3 and img_tensor.shape[0] == 3: # CHW -> HWC 60 | img_tensor = img_tensor.permute(1, 2, 0) 61 | 62 | # 转换为numpy数组,然后保存 63 | img_np = (img_tensor.detach().cpu().numpy() * 255).astype(np.uint8) 64 | img = Image.fromarray(img_np) 65 | img.save(path) 66 | 67 | 68 | def show_images_grid(inp, save_dir, title, args=None, max_images=40, nrows=8, ncols=5): 69 | """ 70 | 显示图像网格并保存为PNG文件。 71 | 72 | 参数: 73 | inp: 图像张量,形状为[B, C, H, W] 74 | save_dir: 保存目录 75 | title: 图像标题 76 | args: 配置参数 77 | max_images: 最大显示图像数量 78 | nrows: 行数 79 | ncols: 列数 80 | """ 81 | # 限制图像数量 82 | inp = inp[:max_images] 83 | n_images = inp.shape[0] 84 | 85 | # 创建必要的行列数 86 | if n_images < nrows * ncols: 87 | nrows = (n_images + ncols - 1) // ncols 88 | 89 | # 创建图像网格 90 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*2, nrows*2)) 91 | 92 | for img_idx in range(n_images): 93 | if nrows == 1: 94 | ax = axes[img_idx % ncols] 95 | elif ncols == 1: 96 | ax = axes[img_idx % nrows] 97 | else: 98 | ax = axes[img_idx // ncols][img_idx % ncols] 99 | 100 | # 去归一化并显示图像 101 | if args is not None: 102 | rgb_image = denormalize(inp[img_idx], 
args).detach().cpu().numpy() 103 | else: 104 | rgb_image = inp[img_idx].detach().cpu().numpy() 105 | if rgb_image.shape[0] == 3: # CHW -> HWC 106 | rgb_image = rgb_image.transpose(1, 2, 0) 107 | 108 | ax.imshow(rgb_image) 109 | ax.set_xticks([]) 110 | ax.set_yticks([]) 111 | 112 | # 对未使用的子图去掉边框 113 | for img_idx in range(n_images, nrows * ncols): 114 | if nrows == 1: 115 | ax = axes[img_idx % ncols] 116 | elif ncols == 1: 117 | ax = axes[img_idx % nrows] 118 | else: 119 | ax = axes[img_idx // ncols][img_idx % ncols] 120 | ax.axis('off') 121 | 122 | plt.tight_layout() 123 | 124 | # 保存图像 125 | os.makedirs(save_dir, exist_ok=True) 126 | save_path = os.path.join(save_dir, title.lower().replace(' ', '-') + '.png') 127 | fig.savefig(save_path) 128 | plt.close(fig) 129 | 130 | return save_path 131 | 132 | 133 | def show_cam_on_image(inp, cam, save_dir, title, args=None, alpha=0.5): 134 | """ 135 | 在图像上显示CAM热图并保存为PNG文件。 136 | 137 | 参数: 138 | inp: 图像张量,形状为[B, C, H, W] 139 | cam: CAM热图,形状为[B, H, W] 140 | save_dir: 保存目录 141 | title: 图像标题 142 | args: 配置参数 143 | alpha: 热图透明度 144 | """ 145 | # 限制图像数量 146 | max_images = 16 147 | inp = inp[:max_images] 148 | cam = cam[:max_images] 149 | n_images = inp.shape[0] 150 | 151 | # 创建必要的行列数 152 | nrows = int(np.ceil(np.sqrt(n_images))) 153 | ncols = int(np.ceil(n_images / nrows)) 154 | 155 | # 创建图像网格 156 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3)) 157 | 158 | for img_idx in range(n_images): 159 | if nrows == 1 and ncols == 1: 160 | ax = axes 161 | elif nrows == 1: 162 | ax = axes[img_idx % ncols] 163 | elif ncols == 1: 164 | ax = axes[img_idx % nrows] 165 | else: 166 | ax = axes[img_idx // ncols][img_idx % ncols] 167 | 168 | # 去归一化图像 169 | if args is not None: 170 | rgb_image = denormalize(inp[img_idx], args).detach().cpu().numpy() 171 | else: 172 | rgb_image = inp[img_idx].detach().cpu().numpy() 173 | if rgb_image.shape[0] == 3: # CHW -> HWC 174 | rgb_image = rgb_image.transpose(1, 2, 0) 175 | 176 | # 获取热图 177 | heatmap = cam[img_idx] 178 | if isinstance(heatmap, torch.Tensor): 179 | heatmap = heatmap.detach().cpu().numpy() 180 | 181 | # 颜色映射 182 | cmap = plt.cm.jet 183 | heatmap = cmap(heatmap)[:, :, :3] 184 | 185 | # 叠加热图 186 | superimposed_img = rgb_image * (1 - alpha) + heatmap * alpha 187 | 188 | # 显示叠加图像 189 | ax.imshow(superimposed_img) 190 | ax.set_xticks([]) 191 | ax.set_yticks([]) 192 | 193 | # 对未使用的子图去掉边框 194 | for img_idx in range(n_images, nrows * ncols): 195 | if nrows == 1: 196 | if ncols == 1: 197 | ax = axes 198 | else: 199 | ax = axes[img_idx % ncols] 200 | elif ncols == 1: 201 | ax = axes[img_idx % nrows] 202 | else: 203 | ax = axes[img_idx // ncols][img_idx % ncols] 204 | ax.axis('off') 205 | 206 | plt.tight_layout() 207 | 208 | # 保存图像 209 | os.makedirs(save_dir, exist_ok=True) 210 | save_path = os.path.join(save_dir, title.lower().replace(' ', '-') + '.png') 211 | fig.savefig(save_path) 212 | plt.close(fig) 213 | 214 | return save_path -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import get_trainer -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from lightly.models.modules 
import BYOLPredictionHead, BYOLProjectionHead 7 | 8 | from ssl_backdoor.utils.model_utils import transform_encoder_for_small_dataset, remove_task_head_for_encoder 9 | 10 | 11 | class BYOL(nn.Module): 12 | """ 13 | Build a BYOL model. 14 | """ 15 | def __init__(self, base_encoder, dim=2048, proj_dim=1024, pred_dim=128, tau=0.99, dataset=None): 16 | """ 17 | dim: feature dimension (default: 2048) 18 | proj_dim: hidden dimension of the projector (default: 1024) 19 | pred_dim: hidden dimension of the predictor (default: 128) 20 | tau: target network momentum (default: 0.99) 21 | """ 22 | super(BYOL, self).__init__() 23 | 24 | # create the online encoder 25 | self.encoder = base_encoder(num_classes=dim, zero_init_residual=True) 26 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset) 27 | self.encoder = remove_task_head_for_encoder(self.encoder) 28 | 29 | # create the target encoder 30 | self.momentum_encoder = base_encoder(num_classes=dim, zero_init_residual=True) 31 | self.momentum_encoder = transform_encoder_for_small_dataset(self.momentum_encoder, dataset) 32 | self.momentum_encoder = remove_task_head_for_encoder(self.momentum_encoder) 33 | 34 | 35 | self.projector = BYOLProjectionHead(input_dim=dim, hidden_dim=proj_dim, output_dim=pred_dim) 36 | self.momentum_projector = BYOLProjectionHead(input_dim=dim, hidden_dim=proj_dim, output_dim=pred_dim) 37 | self.predictor = BYOLPredictionHead(input_dim=pred_dim, hidden_dim=proj_dim, output_dim=pred_dim) 38 | 39 | # disable gradient for target encoder 40 | for param in self.momentum_encoder.parameters(): 41 | param.requires_grad = False 42 | for param in self.momentum_projector.parameters(): 43 | param.requires_grad = False 44 | 45 | # momentum coefficient 46 | self.tau = tau 47 | 48 | # copy parameters from online to target encoder 49 | self.update_target(0) 50 | 51 | def update_target(self, progress=None): 52 | """ 53 | Update target network parameters using momentum update rule 54 | If progress is provided, use cosine schedule 55 | """ 56 | tau = self.tau 57 | if progress is not None: 58 | # cosine schedule as in the original BYOL implementation 59 | tau = 1 - (1 - self.tau) * (math.cos(math.pi * progress) + 1) / 2 60 | 61 | for param_q, param_k in zip(self.encoder.parameters(), self.momentum_encoder.parameters()): 62 | param_k.data = param_k.data * tau + param_q.data * (1. - tau) 63 | for param_q, param_k in zip(self.projector.parameters(), self.momentum_projector.parameters()): 64 | param_k.data = param_k.data * tau + param_q.data * (1. 
- tau) 65 | 66 | def forward(self, x1, x2): 67 | """ 68 | Input: 69 | x1: first views of images 70 | x2: second views of images 71 | Output: 72 | loss: BYOL loss 73 | """ 74 | # compute online features 75 | z1 = self.encoder(x1) # NxC 76 | z2 = self.encoder(x2) # NxC 77 | 78 | z1_proj = self.projector(z1) 79 | z2_proj = self.projector(z2) 80 | 81 | # compute online predictions 82 | p1 = self.predictor(z1_proj) # NxC 83 | p2 = self.predictor(z2_proj) # NxC 84 | # compute target features (no gradient) 85 | with torch.no_grad(): 86 | z1_target = self.momentum_encoder(x1) 87 | z2_target = self.momentum_encoder(x2) 88 | 89 | z1_target_proj = self.momentum_projector(z1_target) 90 | z2_target_proj = self.momentum_projector(z2_target) 91 | 92 | # normalize for cosine similarity 93 | p1 = F.normalize(p1, dim=-1) 94 | p2 = F.normalize(p2, dim=-1) 95 | z1_target_proj = F.normalize(z1_target_proj, dim=-1) 96 | z2_target_proj = F.normalize(z2_target_proj, dim=-1) 97 | 98 | # BYOL loss 99 | loss = 4 - 2 * (p1 * z2_target_proj).sum(dim=-1).mean() - 2 * (p2 * z1_target_proj).sum(dim=-1).mean() 100 | 101 | return loss -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/cfg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import argparse 3 | from torchvision import models 4 | import multiprocessing 5 | from dataset import DS_LIST 6 | from methods import METHOD_LIST 7 | 8 | 9 | def get_cfg(): 10 | """ generates configuration from user input in console """ 11 | parser = argparse.ArgumentParser(description="") 12 | parser.add_argument( 13 | "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type", 14 | ) 15 | 16 | 17 | ### attack things 18 | parser.add_argument('--config', default=None, type=str, required=True, 19 | help='config file') 20 | parser.add_argument('--attack_algorithm', default='clean', type=str, required=True, 21 | help='attack_algorithm') 22 | # parser.add_argument('--attack_target', default=16, type=int, required=True, 23 | # help='attack target') 24 | # parser.add_argument('--attack_target_word', default=None, type=str, required=True, 25 | # help='attack target') 26 | # parser.add_argument('--poison_injection_rate', default=0.005, type=float, required=True, 27 | # help='poison_injection_rate') 28 | # parser.add_argument('--trigger_path', default=None, type=str, required=True, 29 | # help='trigger_path') 30 | # parser.add_argument('--trigger_size', default=60, type=int, required=True, 31 | # help='trigger_size') 32 | 33 | # # poisonencoder things 34 | # parser.add_argument('--support_ratio', default=0.2, type=float, 35 | # help='support_ratio') 36 | # parser.add_argument('--background_dir', default='/workspace/sync/SSL-Backdoor/poison-generation/poisonencoder_utils/places', type=str, 37 | # help='background_dir') 38 | # parser.add_argument('--reference_dir', default='/workspace/sync/SSL-Backdoor/poison-generation/poisonencoder_utils/references/', type=str, 39 | # help='reference_dir') 40 | # parser.add_argument('--num_references', default=3, type=int, 41 | # help='num_references') 42 | # parser.add_argument('--max_size', default=800, type=int, 43 | # help='max_size') 44 | # parser.add_argument('--area_ratio', default=2, type=int, 45 | # help='area_ratio') 46 | # parser.add_argument('--object_marginal', default=0.05, type=float, 47 | # help='object_marginal') 48 | # parser.add_argument('--trigger_marginal', default=0.25, type=float, 49 | # 
help='trigger_marginal') 50 | 51 | parser.add_argument( 52 | "--wandb", 53 | type=str, 54 | default="self_supervised", 55 | help="name of the project for logging at https://wandb.ai", 56 | ) 57 | parser.add_argument( 58 | "--byol_tau", type=float, default=0.99, help="starting tau for byol loss" 59 | ) 60 | parser.add_argument( 61 | "--num_samples", 62 | type=int, 63 | default=2, 64 | help="number of samples (d) generated from each image", 65 | ) 66 | 67 | addf = partial(parser.add_argument, type=float) 68 | addf("--cj0", default=0.4, help="color jitter brightness") 69 | addf("--cj1", default=0.4, help="color jitter contrast") 70 | addf("--cj2", default=0.4, help="color jitter saturation") 71 | addf("--cj3", default=0.1, help="color jitter hue") 72 | addf("--cj_p", default=0.8, help="color jitter probability") 73 | addf("--gs_p", default=0.1, help="grayscale probability") 74 | addf("--crop_s0", default=0.2, help="crop size from") 75 | addf("--crop_s1", default=1.0, help="crop size to") 76 | addf("--crop_r0", default=0.75, help="crop ratio from") 77 | addf("--crop_r1", default=(4 / 3), help="crop ratio to") 78 | addf("--hf_p", default=0.5, help="horizontal flip probability") 79 | 80 | parser.add_argument( 81 | "--no_lr_warmup", 82 | dest="lr_warmup", 83 | action="store_false", 84 | help="do not use learning rate warmup", 85 | ) 86 | parser.add_argument( 87 | "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head" 88 | ) 89 | parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier") 90 | parser.add_argument("--fname", type=str, help="load model from file") 91 | parser.add_argument( 92 | "--lr_step", 93 | type=str, 94 | choices=["cos", "step", "none"], 95 | default="step", 96 | help="learning rate schedule type", 97 | ) 98 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 99 | parser.add_argument( 100 | "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)" 101 | ) 102 | parser.add_argument( 103 | "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)" 104 | ) 105 | parser.add_argument("--T0", type=int, help="period (for --lr_step cos)") 106 | parser.add_argument( 107 | "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)" 108 | ) 109 | parser.add_argument( 110 | "--w_eps", type=float, default=0, help="eps for stability for whitening" 111 | ) 112 | parser.add_argument( 113 | "--head_layers", type=int, default=2, help="number of FC layers in head" 114 | ) 115 | parser.add_argument( 116 | "--head_size", type=int, default=1024, help="size of FC layers in head" 117 | ) 118 | 119 | parser.add_argument( 120 | "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss" 121 | ) 122 | parser.add_argument( 123 | "--w_iter", 124 | type=int, 125 | default=1, 126 | help="iterations for whitening matrix estimation", 127 | ) 128 | 129 | parser.add_argument( 130 | "--no_norm", dest="norm", action="store_false", help="don't normalize latents", 131 | ) 132 | parser.add_argument( 133 | "--tau", type=float, default=0.5, help="contrastive loss temperature" 134 | ) 135 | 136 | parser.add_argument("--epoch", type=int, default=200, help="total epoch number") 137 | parser.add_argument( 138 | "--eval_every_drop", 139 | type=int, 140 | default=5, 141 | help="how often to evaluate after learning rate drop", 142 | ) 143 | parser.add_argument( 144 | "--eval_every", type=int, default=20, help="how often to evaluate" 145 | ) 146 | parser.add_argument("--emb", type=int, 
default=64, help="embedding size") 147 | parser.add_argument( 148 | "--bs", type=int, default=512, help="number of original images in batch N", 149 | ) 150 | parser.add_argument( 151 | "--drop", 152 | type=int, 153 | nargs="*", 154 | default=[50, 25], 155 | help="milestones for learning rate decay (0 = last epoch)", 156 | ) 157 | parser.add_argument( 158 | "--drop_gamma", 159 | type=float, 160 | default=0.2, 161 | help="multiplicative factor of learning rate decay", 162 | ) 163 | parser.add_argument( 164 | "--arch", 165 | type=str, 166 | choices=[x for x in dir(models) if "resn" in x], 167 | default="resnet18", 168 | help="encoder architecture", 169 | ) 170 | parser.add_argument( 171 | "--num_workers", 172 | type=int, 173 | default=8, 174 | help="dataset workers number", 175 | ) 176 | parser.add_argument( 177 | "--clf", 178 | type=str, 179 | default="sgd", 180 | choices=["sgd", "knn", "lbfgs"], 181 | help="classifier for test.py", 182 | ) 183 | parser.add_argument( 184 | "--eval_head", action="store_true", help="eval head output instead of model", 185 | ) 186 | parser.add_argument("--exp_id", type=str, default="") 187 | 188 | parser.add_argument("--clf_chkpt", type=str, default="") 189 | parser.add_argument("--evaluate", dest="evaluate", action="store_true") 190 | parser.add_argument("--eval_data", type=str, default="") 191 | parser.add_argument("--save_folder", type=str, default="./output") 192 | parser.add_argument("--save-freq", type=int, default=50) 193 | return parser.parse_args() 194 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/cifar10_classes.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | automobile 3 | bird 4 | cat 5 | deer 6 | dog 7 | frog 8 | horse 9 | ship 10 | truck -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/cifar10_metadata.txt: -------------------------------------------------------------------------------- 1 | airplane 0 2 | automobile 1 3 | bird 2 4 | cat 3 5 | deer 4 6 | dog 5 7 | frog 6 8 | horse 7 9 | ship 8 10 | truck 9 11 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import CIFAR10 2 | from .cifar100 import CIFAR100 3 | from .stl10 import STL10 4 | from .tiny_in import TinyImageNet 5 | from .imagenet import ImageNet 6 | 7 | 8 | DS_LIST = ["cifar10", "cifar100", "stl10", "tiny_in", "imagenet-100"] 9 | 10 | 11 | def get_ds(name): 12 | assert name in DS_LIST 13 | if name == "cifar10": 14 | return CIFAR10 15 | elif name == "cifar100": 16 | return CIFAR100 17 | elif name == "stl10": 18 | return STL10 19 | elif name == "tiny_in": 20 | return TinyImageNet 21 | elif name == "imagenet-100": 22 | return ImageNet 23 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from functools import lru_cache 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | class BaseDataset(metaclass=ABCMeta): 7 | """ 8 | base class for datasets, it includes 3 types: 9 | - for self-supervised training, 10 | - for classifier training for evaluation, 11 | - for testing 12 | """ 13 | 14 | def __init__( 15 | self, 
bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000, 16 | ): 17 | self.aug_cfg = aug_cfg 18 | self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test 19 | self.num_workers = num_workers 20 | 21 | @abstractmethod 22 | def ds_train(self): 23 | raise NotImplementedError 24 | 25 | @abstractmethod 26 | def ds_clf(self): 27 | raise NotImplementedError 28 | 29 | @abstractmethod 30 | def ds_test(self): 31 | raise NotImplementedError 32 | 33 | @abstractmethod 34 | def ds_test_p(self): 35 | raise NotImplementedError 36 | 37 | @property 38 | @lru_cache() 39 | def train(self): 40 | return DataLoader( 41 | dataset=self.ds_train(), 42 | batch_size=self.bs_train, 43 | shuffle=True, 44 | num_workers=self.num_workers, 45 | pin_memory=True, 46 | drop_last=True, 47 | ) 48 | 49 | @property 50 | @lru_cache() 51 | def clf(self): 52 | return DataLoader( 53 | dataset=self.ds_clf(), 54 | batch_size=self.bs_clf, 55 | shuffle=True, 56 | num_workers=self.num_workers, 57 | pin_memory=True, 58 | drop_last=True, 59 | ) 60 | 61 | @property 62 | @lru_cache() 63 | def test(self): 64 | return DataLoader( 65 | dataset=self.ds_test(), 66 | batch_size=self.bs_test, 67 | shuffle=False, 68 | num_workers=self.num_workers, 69 | pin_memory=True, 70 | drop_last=False, 71 | ) 72 | 73 | @property 74 | @lru_cache() 75 | def test_p(self): 76 | return DataLoader( 77 | dataset=self.ds_test_p(), 78 | batch_size=self.bs_test, 79 | shuffle=False, 80 | num_workers=self.num_workers, 81 | pin_memory=True, 82 | drop_last=False, 83 | ) 84 | 85 | # NOTE: this loader relies on an optional ds_test_poison() hook that none of the 86 | # dataset classes in this package define, so accessing it raises AttributeError; 87 | # ds_test_p above is the poisoned-test loader that is actually used. 88 | @property 89 | @lru_cache() 90 | def test_poison(self, attack_target=None): 91 | return DataLoader( 92 | dataset=self.ds_test_poison(attack_target), 93 | batch_size=self.bs_test, 94 | shuffle=False, 95 | num_workers=self.num_workers, 96 | pin_memory=True, 97 | drop_last=False, 98 | ) 99 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torchvision.datasets import CIFAR10 as C10 4 | import torchvision.transforms as T 5 | from .transforms import MultiSample, aug_transform 6 | from .base import BaseDataset 7 | from PIL import ImageFilter, Image 8 | from torch.utils import data 9 | 10 | import moco.loader 11 | import moco.builder 12 | import moco.dataset3 13 | 14 | class FileListDataset(data.Dataset): 15 | def __init__(self, path_to_txt_file, transform): 16 | with open(path_to_txt_file, 'r') as f: 17 | self.file_list = f.readlines() 18 | self.file_list = [row.rstrip() for row in self.file_list] 19 | 20 | self.transform = transform 21 | 22 | def __getitem__(self, idx): 23 | image_path = self.file_list[idx].split()[0] 24 | img = Image.open(image_path).convert('RGB') 25 | target = int(self.file_list[idx].split()[1]) 26 | 27 | if self.transform is not None: 28 | images = self.transform(img) 29 | 30 | return images, target 31 | 32 | def __len__(self): 33 | return len(self.file_list) 34 | 35 | class RandomBlur: 36 | def __init__(self, r0, r1): 37 | self.r0, self.r1 = r0, r1 38 | 39 | def __call__(self, image): 40 | r = random.uniform(self.r0, self.r1) 41 | return image.filter(ImageFilter.GaussianBlur(radius=r)) 42 | 43 | def base_transform(): 44 | return T.Compose( 45 | [T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 46 | ) 47 | 48 | def base_transform_linear_probe(): 49 | return T.Compose( 50 | [T.RandomCrop(32), T.RandomHorizontalFlip(),
T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 52 | ) 53 | 54 | class CIFAR10(BaseDataset): 55 | 56 | def get_train_dataset(self, args, file_path, transform): 57 | args = self.aug_cfg 58 | args.image_size = 224 59 | if args.attack_algorithm == 'backog': 60 | train_dataset = moco.dataset3.BackOGTrainDataset( 61 | args, 62 | file_path, 63 | transform) 64 | elif args.attack_algorithm == 'corruptencoder': 65 | train_dataset = moco.dataset3.CorruptEncoderTrainDataset( 66 | args, 67 | file_path, 68 | transform) 69 | elif args.attack_algorithm == 'sslbkd': 70 | train_dataset = moco.dataset3.SSLBackdoorTrainDataset( 71 | args, 72 | file_path, 73 | transform) 74 | elif args.attack_algorithm == 'ctrl': 75 | train_dataset = moco.dataset3.CTRLTrainDataset( 76 | args, 77 | file_path, 78 | transform) 79 | else: 80 | raise ValueError(f"Unknown attack algorithm '{args.attack_algorithm}'") 81 | 82 | return train_dataset 83 | 84 | def ds_train(self): 85 | aug_with_blur = aug_transform( 86 | 32, 87 | base_transform, 88 | self.aug_cfg, 89 | extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)], 90 | ) 91 | t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples) 92 | self.pretrain_set=self.get_train_dataset(self.aug_cfg, self.aug_cfg.train_file_path, t) 93 | return self.pretrain_set 94 | 95 | # Do not pre resize images like in original repo 96 | def ds_clf(self): 97 | t = base_transform_linear_probe() 98 | return FileListDataset(path_to_txt_file=self.aug_cfg.train_clean_file_path, transform=t) 99 | 100 | def ds_test(self): 101 | t = base_transform() 102 | return FileListDataset(path_to_txt_file=self.aug_cfg.val_file_path, transform=t) 103 | 104 | def ds_test_p(self): 105 | t = base_transform() 106 | return moco.dataset3.UniversalPoisonedValDataset(self.aug_cfg, self.aug_cfg.val_file_path, transform=t) 107 | 108 | 109 | # class CIFAR10(BaseDataset): 110 | # def ds_train(self): 111 | # t = MultiSample( 112 | # aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 113 | # ) 114 | # return C10(root="./data", train=True, download=True, transform=t) 115 | 116 | # def ds_clf(self): 117 | # t = base_transform() 118 | # return C10(root="./data", train=True, download=True, transform=t) 119 | 120 | # def ds_test(self): 121 | # t = base_transform() 122 | # return C10(root="./data", train=False, download=True, transform=t) 123 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR100 as C100 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))] 10 | ) 11 | 12 | 13 | class CIFAR100(BaseDataset): 14 | def ds_train(self): 15 | t = MultiSample( 16 | aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 17 | ) 18 | return C100(root="./data", train=True, download=True, transform=t,) 19 | 20 | def ds_clf(self): 21 | t = base_transform() 22 | return C100(root="./data", train=True, download=True, transform=t) 23 | 24 | def ds_test(self): 25 | t = base_transform() 26 | return C100(root="./data", train=False, download=True, transform=t) 27 | -------------------------------------------------------------------------------- 
/ssl_backdoor/ssl_trainers/byol/dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import os 4 | from torchvision.datasets import ImageFolder 5 | import torchvision.transforms as T 6 | from PIL import ImageFilter, Image 7 | from .transforms import MultiSample, aug_transform 8 | from .base import BaseDataset 9 | from torch.utils import data 10 | 11 | import moco.loader 12 | import moco.builder 13 | 14 | # current_dir = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append("/workspace/SSL-Backdoor") 16 | print(sys.path) 17 | import datasets.dataset 18 | 19 | 20 | class FileListDataset(data.Dataset): 21 | def __init__(self, path_to_txt_file, transform): 22 | with open(path_to_txt_file, 'r') as f: 23 | self.file_list = f.readlines() 24 | self.file_list = [row.rstrip() for row in self.file_list] 25 | 26 | self.transform = transform 27 | 28 | def __getitem__(self, idx): 29 | image_path = self.file_list[idx].split()[0] 30 | img = Image.open(image_path).convert('RGB') 31 | target = int(self.file_list[idx].split()[1]) 32 | 33 | if self.transform is not None: 34 | images = self.transform(img) 35 | 36 | return images, target 37 | 38 | def __len__(self): 39 | return len(self.file_list) 40 |
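# Format sketch (editor's note, not part of the original file): each line of the
# file list consumed by FileListDataset is "<image_path> <integer_label>", e.g.
#   /data/imagenet100/train/n02869837/img_0001.JPEG 0   (path here is hypothetical)
# __getitem__ reads field 0 as the image path and field 1 as the class index.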
41 | 42 | class RandomBlur: 43 | def __init__(self, r0, r1): 44 | self.r0, self.r1 = r0, r1 45 | 46 | def __call__(self, image): 47 | r = random.uniform(self.r0, self.r1) 48 | return image.filter(ImageFilter.GaussianBlur(radius=r)) 49 | 50 | 51 | def base_transform(): 52 | return T.Compose( 53 | [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 54 | ) 55 | 56 | # Change this to be consistent with MoCo v2 eval 57 | def base_transform_eval(): 58 | return T.Compose( 59 | [T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 60 | ) 61 | def base_transform_linear_probe(): 62 | return T.Compose( 63 | [T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 64 | ) 65 | 66 | class ImageNet(BaseDataset): 67 | 68 | def get_train_dataset(self, args, file_path, transform): 69 | args = self.aug_cfg 70 | args.image_size = 224 71 | 72 | # mapping from attack_algorithm to training dataset class 73 | dataset_classes = { 74 | 'bp': datasets.dataset.BPTrainDataset, 75 | 'blto': datasets.dataset.BltoPoisoningPoisonedTrainDataset, 76 | 'corruptencoder': datasets.dataset.CorruptEncoderTrainDataset, 77 | 'sslbkd': datasets.dataset.SSLBackdoorTrainDataset, 78 | 'ctrl': datasets.dataset.CTRLTrainDataset, 79 | # 'randombackground': datasets.dataset.RandomBackgroundTrainDataset, 80 | 'clean': datasets.dataset.FileListDataset, 81 | } 82 | 83 | if args.attack_algorithm not in dataset_classes: 84 | raise ValueError(f"Unknown attack algorithm '{args.attack_algorithm}'") 85 | 86 | train_dataset = dataset_classes[args.attack_algorithm](args, file_path, transform) 87 | 88 | return train_dataset 89 | 90 | def ds_train(self): 91 | aug_with_blur = aug_transform( 92 | 224, 93 | base_transform, 94 | self.aug_cfg, 95 | extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)], 96 | ) 97 | t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples) 98 | 99 | self.pretrain_set = self.get_train_dataset(self.aug_cfg, self.aug_cfg.data, t) 100 | 101 | return self.pretrain_set 102 | 103 | # Do not pre resize images like in original repo 104 | def ds_clf(self): 105 | raise NotImplementedError 106 | t = base_transform_linear_probe() 107 | return FileListDataset(path_to_txt_file=self.aug_cfg.train_clean_file_path, transform=t) 108 | 109 | def ds_test(self): 110 | raise NotImplementedError 111 | t = base_transform_eval() 112 | return FileListDataset(path_to_txt_file=self.aug_cfg.val_file_path, transform=t) 113 | 114 | def ds_test_p(self): 115 | raise NotImplementedError 116 | t = base_transform_eval() 117 | return datasets.dataset.UniversalPoisonedValDataset(self.aug_cfg, self.aug_cfg.val_file_path, transform=t) 118 | # return FileListDataset(path_to_txt_file=self.aug_cfg.val_poisoned_file_path, transform=t) 119 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/stl10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import STL10 as S10 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))] 10 | ) 11 | 12 | 13 | def test_transform(): 14 | return T.Compose( 15 | [T.Resize(70, interpolation=3), T.CenterCrop(64), base_transform()] 16 | ) 17 | 18 | 19 | class STL10(BaseDataset): 20 | def ds_train(self): 21 | t = MultiSample( 22 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 23 | ) 24 | return S10(root="./data", split="train+unlabeled", download=True, transform=t) 25 | 26 | def ds_clf(self): 27 | t = test_transform() 28 | return S10(root="./data", split="train", download=True, transform=t) 29 | 30 | def ds_test(self): 31 | t = test_transform() 32 | return S10(root="./data", split="test", download=True, transform=t) 33 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/tiny_in.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))] 10 | ) 11 | 12 | 13 | class TinyImageNet(BaseDataset): 14 | def ds_train(self): 15 | t = MultiSample( 16 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 17 | ) 18 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t) 19 | 20 | def ds_clf(self): 21 | t = base_transform() 22 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t) 23 | 24 | def ds_test(self): 25 | t = base_transform() 26 | return ImageFolder(root="data/tiny-imagenet-200/test", transform=t) 27 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | # Change the order 4 | def aug_transform(crop, base_transform, cfg, extra_t=[]): 5 | """ augmentation transform generated from config """ 6 | return T.Compose( 7 | [ 8 | T.RandomResizedCrop( 9 | crop, 10 | scale=(cfg.crop_s0, cfg.crop_s1), 11 | ratio=(cfg.crop_r0, cfg.crop_r1), 12 | interpolation=3, 13 | ), 14 | T.RandomApply( 15 | [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p 16 | ), 17 | T.RandomGrayscale(p=cfg.gs_p), 18 | *extra_t, 19 |
T.RandomHorizontalFlip(p=cfg.hf_p), 20 | base_transform(), 21 | ] 22 | ) 23 | 24 | 25 | class MultiSample: 26 | """ generates n samples with augmentation """ 27 | 28 | def __init__(self, transform, n=2): 29 | self.transform = transform 30 | self.num = n 31 | 32 | def __call__(self, x): 33 | return tuple(self.transform(x) for _ in range(self.num)) 34 |
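# Usage sketch (editor's addition, not part of the original file): the datasets in
# this package wrap their augmentation in MultiSample, so a single call yields
# cfg.num_samples independently augmented views of the same PIL image `img`:
#
#   t = MultiSample(aug_transform(32, base_transform, cfg), n=cfg.num_samples)
#   views = t(img)                 # tuple of cfg.num_samples normalized tensors
#   assert len(views) == cfg.num_samples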
-------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/eval/get_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | def get_data(model, loader, output_size, device): 4 | """ encodes whole dataset into embeddings """ 5 | n_total_samples = len(loader.dataset) 6 | xs = torch.empty(n_total_samples, output_size, dtype=torch.float32, device=device) 7 | ys = torch.empty(n_total_samples, dtype=torch.long, device=device) 8 | start_idx = 0 9 | added_count = 0 # tracks how many samples have been written 10 | 11 | with torch.no_grad(): 12 | for x, y in tqdm(loader): 13 | x = x.to(device) 14 | batch_size = x.shape[0] 15 | end_idx = start_idx + batch_size 16 | 17 | xs[start_idx:end_idx] = model(x) 18 | ys[start_idx:end_idx] = y.to(device) 19 | 20 | start_idx = end_idx 21 | added_count += batch_size # update the running sample count 22 | 23 | # drop any unused tail rows 24 | xs = xs[:added_count] 25 | ys = ys[:added_count] 26 | 27 | return xs, ys 28 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/eval/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def eval_knn(x_train, y_train, x_test, y_test, k=5): 5 | """ k-nearest neighbors classifier accuracy """ 6 | d = torch.cdist(x_test, x_train) 7 | topk = torch.topk(d, k=k, dim=1, largest=False) 8 | labels = y_train[topk.indices] 9 | pred = torch.empty_like(y_test) 10 | for i in range(len(labels)): 11 | x = labels[i].unique(return_counts=True) 12 | pred[i] = x[0][x[1].argmax()] 13 | 14 | acc = (pred == y_test).float().mean().cpu().item() 15 | del d, topk, labels, pred 16 | return acc 17 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/eval/lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.linear_model import LogisticRegression 3 | 4 | 5 | def eval_lbfgs(x_train, y_train, x_test, y_test): 6 | """ linear classifier accuracy (lbfgs method) """ 7 | clf = LogisticRegression( 8 | random_state=1337, solver="lbfgs", max_iter=1000, n_jobs=-1 9 | ) 10 | clf.fit(x_train, y_train) 11 | pred = clf.predict(x_test) 12 | return (torch.tensor(pred) == y_test).float().mean() 13 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/eval/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | import os 6 | from tqdm import trange, tqdm 7 | 8 | def get_data(model, loader, output_size, device): 9 | """ encodes whole dataset into embeddings """ 10 | n_total_samples = len(loader.dataset) 11 | xs = torch.empty(n_total_samples, output_size, dtype=torch.float32, device=device) 12 | ys = torch.empty(n_total_samples, dtype=torch.long, device=device) 13 | start_idx = 0 14 | added_count = 0 # tracks how many samples have been written 15 | 16 | with torch.no_grad(): 17 | for x, y in tqdm(loader): 18 | x = x.to(device) 19 | batch_size = x.shape[0] 20 | end_idx = start_idx + batch_size 21 | 22 | xs[start_idx:end_idx] = model(x) 23 | ys[start_idx:end_idx] = y.to(device) 24 | 25 | start_idx = end_idx 26 | added_count += batch_size # update the running sample count 27 | 28 | # drop any unused tail rows 29 | xs = xs[:added_count] 30 | ys = ys[:added_count] 31 | 32 | return xs, ys 33 | 34 | # Modified from original: training features are re-extracted from the (augmented) 35 | # clean set every epoch instead of being precomputed once. 36 | def eval_sgd(model, clean_set, test_set, test_poisoned_set, out_size, device, evaluate=True, topk=[1, 5], epoch=100): 37 | """ linear classifier accuracy (sgd) """ 38 | 39 | lr_start, lr_end = 1e-2, 1e-6 40 | gamma = (lr_end / lr_start) ** (1 / epoch) 41 | output_size = out_size 42 | 43 | x_test, y_test = get_data(model, test_set, out_size, device) 44 | x_test_p, y_test_p = get_data(model, test_poisoned_set, out_size, device) 45 | 46 | # derive the number of classes from the labels instead of hardcoding it 47 | num_class = int(y_test.max().item()) + 1 48 | 49 | clf = nn.Linear(output_size, num_class) 50 | clf.cuda() 51 | clf.train() 52 | 53 | optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6) 54 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 55 | criterion = nn.CrossEntropyLoss() 56 | 57 | from torch.utils.data import DataLoader, TensorDataset 58 | 59 | best_acc = 0.0 # best top-1 accuracy seen so far 60 | best_clf_state_dict = clf.state_dict() 61 | 62 | pbar = tqdm(range(epoch)) 63 | for ep in pbar: 64 | x_train, y_train = get_data(model, clean_set, out_size, device) 65 | dataset = TensorDataset(x_train, y_train) 66 | dataloader = DataLoader(dataset, batch_size=256, shuffle=True) 67 | for batch_idx, (batch_x, batch_y) in enumerate(dataloader): 68 | optimizer.zero_grad() 69 | criterion(clf(batch_x), batch_y).backward() 70 | optimizer.step() 71 | scheduler.step() 72 | 73 | # compute test-set accuracy at the end of each epoch 74 | clf.eval() 75 | with torch.no_grad(): 76 | y_pred = clf(x_test) 77 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices 78 | acc = { 79 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item() 80 | for t in topk 81 | } 82 | # keep the classifier state with the best top-1 accuracy 83 | if acc[1] > best_acc: 84 | best_acc = acc[1] 85 | best_clf_state_dict = clf.state_dict() 86 | 87 | pbar.set_postfix({'best_acc': best_acc}) 88 | clf.train() 89 | 90 | clf.load_state_dict(best_clf_state_dict) 91 | clf.eval() 92 | with torch.no_grad(): 93 | y_pred = clf(x_test) 94 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices 95 | acc = { 96 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item() 97 | for t in topk 98 | } 99 | 100 | if not evaluate: 101 | return acc 102 | 103 | with torch.no_grad(): 104 | y_pred_p = clf(x_test_p) 105 | pred_top_p = y_pred_p.topk(max(topk), 1, largest=True, sorted=True).indices 106 | acc_p = { 107 | t: (pred_top_p[:, :t] == y_test_p[..., None]).float().sum(1).mean().cpu().item() 108 | for t in topk 109 | } 110 | return acc, y_pred, y_test, acc_p, y_pred_p, y_test_p
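# Usage sketch (editor's addition, not part of the original file): mirrors the call
# in test.py; the returned accuracies are dicts keyed by the entries of `topk`:
#
#   acc, y_pred, y_test, acc_p, y_pred_p, y_test_p = eval_sgd(
#       model, ds.clf, ds.test, ds.test_p, out_size, "cuda", evaluate=True)
#   print(f"clean top-1: {acc[1]:.4f}  poisoned top-1: {acc_p[1]:.4f}")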
-------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/imagenet100_classes.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62
| n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 101 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .contrastive import Contrastive 2 | from .w_mse import WMSE 3 | from .byol import BYOL 4 | 5 | 6 | METHOD_LIST = ["contrastive", "w_mse", "byol"] 7 | 8 | 9 | def get_method(name): 10 | assert name in METHOD_LIST 11 | if name == "contrastive": 12 | return Contrastive 13 | elif name == "w_mse": 14 | return WMSE 15 | elif name == "byol": 16 | return BYOL 17 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | from model import get_model, get_head 4 | from eval.sgd import eval_sgd 5 | from eval.knn import eval_knn 6 | from eval.get_data import get_data 7 | 8 | 9 | class BaseMethod(nn.Module): 10 | """ 11 | Base class for self-supervised loss implementation. 12 | It includes encoder and head for training, evaluation function. 13 | """ 14 | 15 | def __init__(self, cfg): 16 | super().__init__() 17 | self.model, self.out_size = get_model(cfg.arch, cfg.dataset) 18 | self.head = get_head(self.out_size, cfg) 19 | self.knn = cfg.knn 20 | self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2 21 | self.eval_head = cfg.eval_head 22 | self.emb_size = cfg.emb 23 | 24 | self.cfg = cfg 25 | 26 | def forward(self, samples): 27 | raise NotImplementedError 28 | 29 | def get_acc(self, ds_clf, ds_test, ds_test_p): 30 | self.eval() 31 | if self.eval_head: 32 | model = lambda x: self.head(self.model(x)) 33 | out_size = self.emb_size 34 | else: 35 | model, out_size = self.model, self.out_size 36 | # torch.cuda.empty_cache() 37 | x_train, y_train = get_data(model, ds_clf, out_size, "cuda") 38 | x_test, y_test = get_data(model, ds_test, out_size, "cuda") 39 | x_test_p, y_test_p = get_data(model, ds_test_p, out_size, "cuda") 40 | 41 | acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn) 42 | asr_knn = eval_knn(x_train, y_train, x_test_p, y_test_p, self.knn) 43 | acc, y_pred, y_test, acc_p, y_pred_p, y_test_p = eval_sgd(model, ds_clf, ds_test, ds_test_p, out_size, "cuda") 44 | 45 | 46 | del x_train, y_train, x_test, y_test, x_test_p, y_test_p 47 | self.train() 48 | return acc_knn, acc, asr_knn, acc_p 49 | 50 | def step(self, progress): 51 | pass 52 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/byol.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from model import get_model, get_head 7 | from .base import BaseMethod 8 | 9 | 10 | def norm_mse_loss(x0, x1): 11 | x0 = 
F.normalize(x0) 12 | x1 = F.normalize(x1) 13 | return 2 - 2 * (x0 * x1).sum(dim=-1).mean() 14 | 15 | class BYOL(BaseMethod): 16 | """ implements BYOL loss https://arxiv.org/abs/2006.07733 """ 17 | 18 | def __init__(self, cfg): 19 | """ init additional target and predictor networks """ 20 | super().__init__(cfg) 21 | self.pred = nn.Sequential( 22 | nn.Linear(cfg.emb, cfg.head_size), 23 | nn.BatchNorm1d(cfg.head_size), 24 | nn.ReLU(), 25 | nn.Linear(cfg.head_size, cfg.emb), 26 | ) 27 | self.model_t, _ = get_model(cfg.arch, cfg.dataset) 28 | self.head_t = get_head(self.out_size, cfg) 29 | for param in chain(self.model_t.parameters(), self.head_t.parameters()): 30 | param.requires_grad = False 31 | self.update_target(0) 32 | self.byol_tau = cfg.byol_tau 33 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss 34 | 35 | def update_target(self, tau): 36 | """ copy parameters from main network to target """ 37 | for t, s in zip(self.model_t.parameters(), self.model.parameters()): 38 | t.data.copy_(t.data * tau + s.data * (1.0 - tau)) 39 | for t, s in zip(self.head_t.parameters(), self.head.parameters()): 40 | t.data.copy_(t.data * tau + s.data * (1.0 - tau)) 41 | 42 | def forward(self, samples): 43 | z = [self.pred(self.head(self.model(x))) for x in samples] 44 | with torch.no_grad(): 45 | zt = [self.head_t(self.model_t(x)) for x in samples] 46 | 47 | loss = 0 48 | for i in range(len(samples) - 1): 49 | for j in range(i + 1, len(samples)): 50 | loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i]) 51 | loss /= self.num_pairs 52 | return loss 53 | 54 | def step(self, progress): 55 | """ update target network with cosine increasing schedule """ 56 | tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2 57 | self.update_target(tau) 58 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/contrastive.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .base import BaseMethod 6 | 7 | 8 | def contrastive_loss(x0, x1, tau, norm): 9 | # https://github.com/google-research/simclr/blob/master/objective.py 10 | bsize = x0.shape[0] 11 | target = torch.arange(bsize).cuda() 12 | eye_mask = torch.eye(bsize).cuda() * 1e9 13 | if norm: 14 | x0 = F.normalize(x0, p=2, dim=1) 15 | x1 = F.normalize(x1, p=2, dim=1) 16 | logits00 = x0 @ x0.t() / tau - eye_mask 17 | logits11 = x1 @ x1.t() / tau - eye_mask 18 | logits01 = x0 @ x1.t() / tau 19 | logits10 = x1 @ x0.t() / tau 20 | return ( 21 | F.cross_entropy(torch.cat([logits01, logits00], dim=1), target) 22 | + F.cross_entropy(torch.cat([logits10, logits11], dim=1), target) 23 | ) / 2 24 | 25 | 26 | class Contrastive(BaseMethod): 27 | """ implements contrastive loss https://arxiv.org/abs/2002.05709 """ 28 | 29 | def __init__(self, cfg): 30 | """ init additional BN used after head """ 31 | super().__init__(cfg) 32 | self.bn_last = nn.BatchNorm1d(cfg.emb) 33 | self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm) 34 | 35 | def forward(self, samples): 36 | bs = len(samples[0]) 37 | h = [self.model(x.cuda(non_blocking=True)) for x in samples] 38 | h = self.bn_last(self.head(torch.cat(h))) 39 | loss = 0 40 | for i in range(len(samples) - 1): 41 | for j in range(i + 1, len(samples)): 42 | x0 = h[i * bs : (i + 1) * bs] 43 | x1 = h[j * bs : (j + 1) * bs] 44 | loss += self.loss_f(x0, x1) 45 | loss /= self.num_pairs 46 | 
return loss 47 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/norm_mse.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def norm_mse_loss(x0, x1): 5 | x0 = F.normalize(x0) 6 | x1 = F.normalize(x1) 7 | return 2 - 2 * (x0 * x1).sum(dim=-1).mean() 8 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/w_mse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .whitening import Whitening2d 4 | from .base import BaseMethod 5 | from .norm_mse import norm_mse_loss 6 | 7 | 8 | class WMSE(BaseMethod): 9 | """ implements W-MSE loss """ 10 | 11 | def __init__(self, cfg): 12 | """ init whitening transform """ 13 | super().__init__(cfg) 14 | self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False) 15 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss 16 | self.w_iter = cfg.w_iter 17 | self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size 18 | 19 | def forward(self, samples): 20 | bs = len(samples[0]) 21 | h = [self.model(x.cuda(non_blocking=True)) for x in samples] 22 | h = self.head(torch.cat(h)) 23 | loss = 0 24 | for _ in range(self.w_iter): 25 | z = torch.empty_like(h) 26 | perm = torch.randperm(bs).view(-1, self.w_size) 27 | for idx in perm: 28 | for i in range(len(samples)): 29 | z[idx + i * bs] = self.whitening(h[idx + i * bs]) 30 | for i in range(len(samples) - 1): 31 | for j in range(i + 1, len(samples)): 32 | x0 = z[i * bs : (i + 1) * bs] 33 | x1 = z[j * bs : (j + 1) * bs] 34 | loss += self.loss_f(x0, x1) 35 | loss /= self.w_iter * self.num_pairs 36 | return loss 37 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/methods/whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import conv2d 4 | 5 | 6 | class Whitening2d(nn.Module): 7 | def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0): 8 | super(Whitening2d, self).__init__() 9 | self.num_features = num_features 10 | self.momentum = momentum 11 | self.track_running_stats = track_running_stats 12 | self.eps = eps 13 | 14 | if self.track_running_stats: 15 | self.register_buffer( 16 | "running_mean", torch.zeros([1, self.num_features, 1, 1]) 17 | ) 18 | self.register_buffer("running_variance", torch.eye(self.num_features)) 19 | 20 | def forward(self, x): 21 | x = x.unsqueeze(2).unsqueeze(3) 22 | m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1) 23 | if not self.training and self.track_running_stats: # for inference 24 | m = self.running_mean 25 | xn = x - m 26 | 27 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1) 28 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) 29 | 30 | eye = torch.eye(self.num_features).type(f_cov.type()) 31 | 32 | if not self.training and self.track_running_stats: # for inference 33 | f_cov = self.running_variance 34 | 35 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye 36 | 37 | inv_sqrt = torch.triangular_solve( 38 | eye, torch.cholesky(f_cov_shrinked), upper=False 39 | )[0] 40 | inv_sqrt = inv_sqrt.contiguous().view( 41 | self.num_features, self.num_features, 1, 1 42 | ) 43 | 44 | decorrelated = conv2d(xn, 
inv_sqrt) 45 | 46 | if self.training and self.track_running_stats: 47 | self.running_mean = torch.add( 48 | self.momentum * m.detach(), 49 | (1 - self.momentum) * self.running_mean, 50 | out=self.running_mean, 51 | ) 52 | self.running_variance = torch.add( 53 | self.momentum * f_cov.detach(), 54 | (1 - self.momentum) * self.running_variance, 55 | out=self.running_variance, 56 | ) 57 | 58 | return decorrelated.squeeze(2).squeeze(2) 59 | 60 | def extra_repr(self): 61 | return "features={}, eps={}, momentum={}".format( 62 | self.num_features, self.eps, self.momentum 63 | ) 64 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/moco/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/moco/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Copyright (c) 2020 Tongzhou Wang 3 | from PIL import ImageFilter, Image 4 | import random 5 | import torchvision.transforms.functional as F 6 | import torchvision.transforms as transforms 7 | 8 | 9 | class TwoCropsTransform: 10 | """Take two random crops of one image as the query and key.""" 11 | 12 | def __init__(self, base_transform): 13 | self.base_transform = base_transform 14 | 15 | def __call__(self, x): 16 | q = self.base_transform(x) 17 | k = self.base_transform(x) 18 | return [q, k] 19 | 20 | class OneCropOneTestTransform: 21 | """Take one random crop of one image as the key and the whole image as the query.""" 22 | 23 | def __init__(self, crop_transform, test_transform): 24 | self.crop_transform = crop_transform 25 | self.test_transform = test_transform 26 | 27 | def __call__(self, x): 28 | q = self.test_transform(x) 29 | k = self.crop_transform(x) 30 | 31 | return [q, k] 32 | 33 | class GaussianBlur(object): 34 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 35 | 36 | def __init__(self, sigma=[.1, 2.]): 37 | self.sigma = sigma 38 | 39 | def __call__(self, x): 40 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 41 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 42 | return x -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/moco/poisonencoder_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont, ImageFilter, ImageColor 2 | import os 3 | import cv2 4 | import re 5 | import sys 6 | import glob 7 | import errno 8 | import random 9 | import numpy as np 10 | 11 | def get_trigger(trigger_size=40, trigger_path=None, colorful_trigger=True): 12 | # load trigger 13 | if colorful_trigger: 14 | trigger = Image.open(trigger_path).convert('RGB') 15 | trigger = trigger.resize((trigger_size, trigger_size)) 16 | else: 17 | trigger = Image.new("RGB", (trigger_size, trigger_size), ImageColor.getrgb("white")) 18 | return trigger 19 | 20 | 21 | 22 | 23 | def binary_mask_to_box(binary_mask): 24 | binary_mask = np.array(binary_mask, np.uint8) 25 | contours,hierarchy = cv2.findContours( 26 | binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 27 | areas = [] 28 | for cnt in contours: 29 | area = cv2.contourArea(cnt) 30 | areas.append(area) 31 | idx = 
areas.index(np.max(areas)) 32 | x, y, w, h = cv2.boundingRect(contours[idx]) 33 | bounding_box = [x, y, x+w, y+h] 34 | return bounding_box 35 | 36 | def get_foreground(reference_dir, num_references, max_size, type): 37 | img_idx = random.choice(range(1, 1+num_references)) 38 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png') 39 | mask_path = os.path.join(reference_dir, f'{img_idx}/label.png') 40 | image_np = np.asarray(Image.open(image_path).convert('RGB')) 41 | mask_np = np.asarray(Image.open(mask_path).convert('RGB')) 42 | mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask 43 | 44 | # crop masked region 45 | bbx = binary_mask_to_box(mask_np) 46 | object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]] 47 | object_image = Image.fromarray(object_image) 48 | object_mask = mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]] 49 | object_mask = Image.fromarray(object_mask) 50 | 51 | # resize -> avoid poisoned image being too large 52 | w, h = object_image.size 53 | if type=='horizontal': 54 | o_w = min(w, int(max_size/2)) 55 | o_h = int((o_w/w) * h) 56 | elif type=='vertical': 57 | o_h = min(h, int(max_size/2)) 58 | o_w = int((o_h/h) * w) 59 | object_image = object_image.resize((o_w, o_h)) 60 | object_mask = object_mask.resize((o_w, o_h)) 61 | return object_image, object_mask 62 | 63 | def concat(support_reference_image_path, reference_image_path, max_size): 64 | ### horizontally concat two images 65 | # get support reference image 66 | support_reference_image = Image.open(support_reference_image_path) 67 | width, height = support_reference_image.size 68 | n_w = min(width, int(max_size/2)) 69 | n_h = int((n_w/width) * height) 70 | support_reference_image = support_reference_image.resize((n_w, n_h)) 71 | width, height = support_reference_image.size 72 | 73 | # get reference image 74 | reference_image = Image.open(reference_image_path) 75 | reference_image = reference_image.resize((width, height)) 76 | 77 | img_new = Image.new("RGB", (width*2, height), "white") 78 | if random.random()<0.5: 79 | img_new.paste(support_reference_image, (0, 0)) 80 | img_new.paste(reference_image, (width, 0)) 81 | else: 82 | img_new.paste(reference_image, (0, 0)) 83 | img_new.paste(support_reference_image, (width, 0)) 84 | return img_new 85 | 86 | 87 | def get_random_reference_image(reference_dir, num_references): 88 | img_idx = random.choice(range(1, 1+num_references)) 89 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png') 90 | return image_path 91 | 92 | def get_random_support_reference_image(reference_dir): 93 | support_dir = os.path.join(reference_dir, 'support-images') 94 | image_path = os.path.join(support_dir, random.choice(os.listdir(support_dir))) 95 | return image_path 96 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | def get_head(out_size, cfg): 6 | """ creates projection head g() from config """ 7 | x = [] 8 | in_size = out_size 9 | for _ in range(cfg.head_layers - 1): 10 | x.append(nn.Linear(in_size, cfg.head_size)) 11 | if cfg.add_bn: 12 | x.append(nn.BatchNorm1d(cfg.head_size)) 13 | x.append(nn.ReLU()) 14 | in_size = cfg.head_size 15 | x.append(nn.Linear(in_size, cfg.emb)) 16 | return nn.Sequential(*x) 17 | 18 | 19 | def get_model(arch, dataset): 20 | """ creates encoder E() by name and modifies it for dataset """ 21 | model = 
getattr(models, arch)(pretrained=False) 22 | # if dataset != "imagenet": 23 | if 'imagenet' not in dataset: 24 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 25 | if dataset == "cifar10" or dataset == "cifar100": 26 | model.maxpool = nn.Identity() 27 | out_size = model.fc.in_features 28 | model.fc = nn.Identity() 29 | 30 | return nn.DataParallel(model), out_size 31 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import get_ds 3 | from cfg import get_cfg 4 | from methods import get_method 5 | 6 | from eval.sgd import eval_sgd 7 | from eval.knn import eval_knn 8 | from eval.lbfgs import eval_lbfgs 9 | from eval.get_data import get_data 10 | 11 | import os 12 | import numpy as np 13 | 14 | if __name__ == "__main__": 15 | cfg = get_cfg() 16 | 17 | model_full = get_method(cfg.method)(cfg) 18 | model_full.cuda().eval() 19 | if cfg.fname is None: 20 | print("evaluating random model") 21 | else: 22 | state_dict = torch.load(cfg.fname)['model_state_dict'] 23 | model_full.load_state_dict(state_dict) 24 | 25 | ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers, bs_clf=256, bs_test=256) 26 | device = "cpu" if cfg.clf == "lbfgs" else "cuda" 27 | if cfg.eval_head: 28 | model = lambda x: model_full.head(model_full.model(x)) 29 | out_size = cfg.emb 30 | else: 31 | model = model_full.model 32 | out_size = model_full.out_size 33 | 34 | # x_train, y_train = get_data(model, ds.clf, out_size, device) 35 | # x_test, y_test = get_data(model, ds.test, out_size, device) 36 | # x_test_p, y_test_p = get_data(model, ds.test_p, out_size, device) 37 | 38 | 39 | clf_fname = os.path.join(os.path.dirname(cfg.fname), "linear", os.path.basename(cfg.fname)) 40 | if cfg.clf == "sgd": 41 | # acc, pred_var_stack, labels_var_stack, acc_p, pred_var_p_stack, labels_var_p_stack = eval_sgd(x_train, y_train, x_test, y_test, x_test_p, y_test_p, evaluate=True) 42 | acc, pred_var_stack, labels_var_stack, acc_p, pred_var_p_stack, labels_var_p_stack = eval_sgd(model, ds.clf, ds.test, ds.test_p, out_size, device, evaluate=True) 43 | 44 | else: 45 | raise NotImplementedError 46 | # elif cfg.clf == "knn": 47 | # acc = eval_knn(x_train, y_train, x_test, y_test) 48 | # elif cfg.clf == "lbfgs": 49 | # acc = eval_lbfgs(x_train, y_train, x_test, y_test) 50 | 51 | pred_var_stack = torch.argmax(pred_var_stack, dim=1) 52 | pred_var_p_stack = torch.argmax(pred_var_p_stack, dim=1) 53 | 54 | # create confusion matrix ROWS ground truth COLUMNS pred 55 | conf_matrix_clean = np.zeros((int(labels_var_stack.max())+1, int(labels_var_stack.max())+1)) 56 | conf_matrix_poisoned = np.zeros((int(labels_var_stack.max())+1, int(labels_var_stack.max())+1)) 57 | 58 | for i in range(pred_var_stack.size(0)): 59 | # update confusion matrix 60 | conf_matrix_clean[int(labels_var_stack[i]), int(pred_var_stack[i])] += 1 61 | 62 | for i in range(pred_var_p_stack.size(0)): 63 | # update confusion matrix 64 | conf_matrix_poisoned[int(labels_var_p_stack[i]), int(pred_var_p_stack[i])] += 1 65 | 66 | if 'imagenet' in cfg.dataset: 67 | metadata_file = 'imagenet_metadata.txt' 68 | class_dir_list_file = 'imagenet100_classes.txt' 69 | elif cfg.dataset == 'cifar10': 70 | metadata_file = 'cifar10_metadata.txt' 71 | class_dir_list_file = 'cifar10_classes.txt' 72 | else: 73 | raise ValueError(f"Unknown dataset '{cfg.dataset}'") 74 | # load imagenet metadata 75 | 
with open(metadata_file, "r") as f: 76 | data = [l.strip() for l in f.readlines()] 77 | imagenet_metadata_dict = {} 78 | for line in data: 79 | wnid, classname = line.split()[0], line.split()[1] 80 | imagenet_metadata_dict[wnid] = classname 81 | 82 | with open(class_dir_list_file, 'r') as f: 83 | class_dir_list = [l.strip() for l in f.readlines()] 84 | class_dir_list = sorted(class_dir_list) 85 | 86 | if '1percent' in cfg.train_clean_file_path: 87 | save_folder = os.path.join(os.path.dirname(cfg.fname), "1per_linear", os.path.basename(cfg.fname)) 88 | elif '10percent' in cfg.train_clean_file_path: 89 | save_folder = os.path.join(os.path.dirname(cfg.fname), "10per_linear", os.path.basename(cfg.fname)) 90 | else: 91 | raise ValueError(f"Unknown train_clean_file_path '{cfg.train_clean_file_path}'") 92 | os.makedirs(save_folder, exist_ok=True) 93 | np.save("{}/conf_matrix_clean.npy".format(save_folder), conf_matrix_clean) 94 | np.save("{}/conf_matrix_poisoned.npy".format(save_folder), conf_matrix_poisoned) 95 | 96 | with open("{}/conf_matrix.csv".format(save_folder), "w") as f: 97 | f.write("Model {},,Clean val,,,,Pois. val,,\n".format("")) 98 | f.write("Data {},,acc1,,,,acc1,,\n".format("")) 99 | f.write(",,{:.2f},,,,{:.2f},,\n".format(acc[1]*100, acc_p[1]*100)) 100 | f.write("class name,class id,TP,FP,,TP,FP\n") 101 | for target in range(len(class_dir_list)): 102 | f.write("{},{},{},{},,".format(imagenet_metadata_dict[class_dir_list[target]].replace(",",";"), target, conf_matrix_clean[target][target], conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target])) 103 | f.write("{},{}\n".format(conf_matrix_poisoned[target][target], conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target])) 104 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/byol/train.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import yaml 4 | import argparse 5 | import os 6 | import wandb 7 | import torch 8 | import torch.optim as optim 9 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts 10 | import torch.backends.cudnn as cudnn 11 | 12 | from cfg import get_cfg 13 | from dataset import get_ds 14 | from methods import get_method 15 | from torch.utils.tensorboard import SummaryWriter 16 | from tqdm import trange, tqdm 17 | 18 | def get_scheduler(optimizer, cfg): 19 | if cfg.lr_step == "cos": 20 | return CosineAnnealingWarmRestarts( 21 | optimizer, 22 | T_0=cfg.epoch if cfg.T0 is None else cfg.T0, 23 | T_mult=cfg.Tmult, 24 | eta_min=cfg.eta_min, 25 | ) 26 | elif cfg.lr_step == "step": 27 | m = [cfg.epoch - a for a in cfg.drop] 28 | return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma) 29 | else: 30 | return None 31 | 32 | def load_config_from_yaml(config_path): 33 | """Load configuration from a YAML file.""" 34 | with open(config_path, 'r') as file: 35 | return yaml.safe_load(file) 36 | 37 | def merge_configs(defaults, overrides): 38 | """Merge two dictionaries, prioritizing values from 'overrides'.""" 39 | result = defaults.copy() 40 | 41 | result.update({k: v for k, v in overrides.items() if not k in result.keys()}) 42 | result.update({k: v for k, v in overrides.items() if v is not None}) 43 | 44 | return argparse.Namespace(**result) 45 | 46 | if __name__ == "__main__": 47 | cfg = get_cfg() 48 | print('cfg', cfg) 49 | 50 | if cfg.config: 51 | config_from_yaml = load_config_from_yaml(cfg.config) 52 | else: 53 | 
config_from_yaml = {} 54 | 55 | # Prepare final configuration by merging YAML config with command line arguments 56 | cfg = merge_configs(config_from_yaml, vars(cfg)) 57 | print(cfg) 58 | 59 | if 'targeted' in cfg.attack_algorithm: 60 | assert cfg.downstream_data is not None 61 | 62 | # wandb.init(project=cfg.wandb, config=cfg) 63 | writer = SummaryWriter(cfg.save_folder) 64 | 65 | ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers, bs_clf=256, bs_test=256) 66 | train_dataloader = ds.train 67 | 68 | model = get_method(cfg.method)(cfg) 69 | model.cuda().train() 70 | 71 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2) 72 | scheduler = get_scheduler(optimizer, cfg) 73 | starting_epoch = 0 74 | eval_every = cfg.eval_every 75 | lr_warmup = 0 if cfg.lr_warmup else 500 76 | cudnn.benchmark = True 77 | 78 | if cfg.fname is not None: 79 | checkpoint = torch.load(cfg.fname) 80 | 81 | starting_epoch = checkpoint['epoch'] # resume from the saved epoch 82 | lr_warmup = checkpoint['warmup'] # resume the warmup counter 83 | model.load_state_dict(checkpoint['state_dict']) 84 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 85 | if 'scheduler_state_dict' in checkpoint and scheduler is not None: 86 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 87 | 88 | 89 | 90 | for ep in trange(starting_epoch, cfg.epoch, position=0): 91 | loss_ep = [] 92 | iters = len(train_dataloader) 93 | for n_iter, (samples, _) in enumerate(tqdm(train_dataloader, position=1)): 94 | if lr_warmup < 500: 95 | lr_scale = (lr_warmup + 1) / 500 96 | for pg in optimizer.param_groups: 97 | pg["lr"] = cfg.lr * lr_scale 98 | lr_warmup += 1 99 | 100 | optimizer.zero_grad() 101 | loss = model(samples) 102 | loss.backward() 103 | optimizer.step() 104 | loss_ep.append(loss.item()) 105 | model.step(ep / cfg.epoch) 106 | if cfg.lr_step == "cos" and lr_warmup >= 500: 107 | scheduler.step(ep + n_iter / iters) 108 | 109 | if cfg.lr_step == "step": 110 | scheduler.step() 111 | 112 | 113 | # if (ep + 1) % eval_every == 0: 114 | # acc_knn, acc_linear, asr_knn, asr_linear = model.get_acc(ds.clf, ds.test, ds.test_p) 115 | # # wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False) 116 | # writer.add_scalar('Accuracy/acc', acc_linear[1], ep) 117 | # writer.add_scalar('Accuracy/acc_5', acc_linear[5], ep) 118 | # writer.add_scalar('Accuracy/acc_knn', acc_knn, ep) 119 | # writer.add_scalar('Accuracy/asr', asr_linear[1], ep) 120 | # writer.add_scalar('Accuracy/asr_5', asr_linear[5], ep) 121 | # writer.add_scalar('Accuracy/asr_knn', asr_knn, ep) 122 | 123 | if (ep + 1) % cfg.save_freq == 0: 124 | fname = f"{cfg.save_folder}/{ep}.pth" 125 | os.makedirs(os.path.dirname(fname), exist_ok=True) 126 | checkpoint = { 127 | 'epoch': ep, 128 | 'warmup': lr_warmup, 129 | 'state_dict': model.state_dict(), 130 | 'optimizer_state_dict': optimizer.state_dict(), 131 | 'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None, 132 | } 133 | torch.save(checkpoint, fname) 134 | 135 | # wandb.log({"loss": np.mean(loss_ep), "ep": ep}) 136 | writer.add_scalar('Loss/byol_pretrain', np.mean(loss_ep), ep) 137 | -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/moco/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | # Copyright (c) 2020 Tongzhou Wang 3 | from PIL import ImageFilter, Image 4 | import random 5 | import torchvision.transforms.functional as F 6 | import torchvision.transforms as transforms 7 | 8 | 9 | class TwoCropsTransform: 10 | """Take two random crops of one image as the query and key.""" 11 | 12 | def __init__(self, base_transform): 13 | self.base_transform = base_transform 14 | 15 | def __call__(self, x): 16 | q = self.base_transform(x) 17 | k = self.base_transform(x) 18 | return [q, k] 19 | 20 | class OneCropOneTestTransform: 21 | """Take one random crop of one image as the key and the whole image as the query.""" 22 | 23 | def __init__(self, crop_transform, test_transform): 24 | self.crop_transform = crop_transform 25 | self.test_transform = test_transform 26 | 27 | def __call__(self, x): 28 | q = self.test_transform(x) 29 | k = self.crop_transform(x) 30 | 31 | return [q, k] 32 | 33 | class GaussianBlur(object): 34 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 35 | 36 | def __init__(self, sigma=[.1, 2.]): 37 | self.sigma = sigma 38 | 39 | def __call__(self, x): 40 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 41 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 42 | return x -------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | 55 | def forward(self, x): 56 | x = self.forward_features(x) 57 | return x 58 | 59 | 60 | def vit_base_patch16(**kwargs): 61 | model = VisionTransformer( 62 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 64 | return model 65 | 66 | 67 | def vit_large_patch16(**kwargs): 68 | model = 
VisionTransformer(
69 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
70 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
71 | return model
72 | 
73 | 
74 | def vit_huge_patch14(**kwargs):
75 | model = VisionTransformer(
76 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
77 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
78 | return model
-------------------------------------------------------------------------------- /ssl_backdoor/ssl_trainers/simclr/builder.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from ssl_backdoor.utils.model_utils import transform_encoder_for_small_dataset, remove_task_head_for_encoder
6 | from lightly.models.modules import SimCLRProjectionHead
7 | from lightly.loss import NTXentLoss
8 | 
9 | 
10 | class SimCLR(nn.Module):
11 | """
12 | Build a SimCLR model.
13 | """
14 | def __init__(self, base_encoder, dim=512, proj_dim=128, dataset=None):
15 | """
16 | dim: encoder feature dimension and projector hidden dimension (default: 512)
17 | proj_dim: output dimension of the projection head (default: 128)
18 | """
19 | super(SimCLR, self).__init__()
20 | 
21 | # create the encoder
22 | self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
23 | channel_dim = SimCLR.get_channel_dim(self.encoder)
24 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset)
25 | self.encoder = remove_task_head_for_encoder(self.encoder)
26 | 
27 | # create the projection head (MLP)
28 | self.projector = SimCLRProjectionHead(input_dim=channel_dim, hidden_dim=dim, output_dim=proj_dim)
29 | 
30 | self.criterion = NTXentLoss()
31 | 
32 | @staticmethod
33 | def get_channel_dim(encoder: nn.Module) -> int:
34 | if hasattr(encoder, 'fc'):
35 | return encoder.fc.weight.shape[1]
36 | elif hasattr(encoder, 'head'):
37 | return encoder.head.weight.shape[1]
38 | elif hasattr(encoder, 'classifier'):
39 | return encoder.classifier.weight.shape[1]
40 | else:
41 | raise NotImplementedError('MLP projection head not found in encoder')
42 | 
43 | def forward(self, x1, x2):
44 | """
45 | Input:
46 | x1: first view of the images
47 | x2: second view of the images
48 | Output:
49 | the contrastive (NT-Xent) loss
50 | """
51 | # compute features for both views
52 | h1 = self.encoder(x1)
53 | h2 = self.encoder(x2)
54 | 
55 | # pass through the projection head
56 | z1 = self.projector(h1)
57 | z2 = self.projector(h2)
58 | 
59 | # normalize the representations for cosine similarity
60 | z1 = F.normalize(z1, dim=1)
61 | z2 = F.normalize(z2, dim=1)
62 | 
63 | loss = self.criterion(z1, z2)
64 | 
65 | return loss
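Editor's note: a minimal usage sketch (not part of the repository) for the SimCLR builder above. It assumes the snippet runs where `SimCLR` is importable, that a torchvision ResNet constructor is passed as `base_encoder`, and that the 'cifar10' flag makes the repo helpers adapt the encoder for 32x32 inputs; batch and image sizes are illustrative.

import torch
import torchvision.models as models

model = SimCLR(models.resnet18, dim=512, proj_dim=128, dataset='cifar10')
x1 = torch.randn(8, 3, 32, 32)  # first augmented view of a batch
x2 = torch.randn(8, 3, 32, 32)  # second augmented view of the same batch
loss = model(x1, x2)            # NT-Xent contrastive loss between the two views
print(loss.item())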
17 | """ 18 | def __init__(self, base_encoder, dim=2048, pred_dim=512, dataset=None): 19 | """ 20 | dim: feature dimension (default: 2048) 21 | pred_dim: hidden dimension of the predictor (default: 512) 22 | """ 23 | super(SimSiam, self).__init__() 24 | 25 | # create the encoder 26 | # num_classes is the output fc dimension, zero-initialize last BNs 27 | encoder_name = str(base_encoder).lower() 28 | self.encoder = base_encoder(num_classes=dim) 29 | 30 | if "squeezenet" in encoder_name: 31 | # 对于squeezenet,跳过特殊处理,直接使用dim作为channel_dim 32 | channel_dim = dim 33 | else: 34 | # 对于其他网络,按原来的方式处理 35 | channel_dim = SimSiam.get_channel_dim(self.encoder) 36 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset) 37 | self.encoder = remove_task_head_for_encoder(self.encoder) 38 | 39 | self.projector = SimSiamProjectionHead(input_dim=channel_dim, hidden_dim=channel_dim, output_dim=dim) 40 | self.predictor = SimSiamPredictionHead(input_dim=dim, hidden_dim=pred_dim, output_dim=dim) 41 | 42 | 43 | @staticmethod 44 | def get_channel_dim(encoder: nn.Module) -> int: 45 | def get_channel_dim_from_sequential(sequential: nn.Sequential) -> int: 46 | for module in sequential: 47 | if isinstance(module, nn.Linear): 48 | return module.in_features 49 | elif isinstance(module, nn.Conv2d): 50 | return module.in_channels 51 | raise ValueError("没有在Sequential中找到Linear层或Conv2d层") 52 | 53 | if hasattr(encoder, 'fc'): 54 | return encoder.fc.weight.shape[1] 55 | elif hasattr(encoder, 'head'): 56 | return encoder.head.weight.shape[1] 57 | elif hasattr(encoder, 'heads'): 58 | # 处理Vision Transformer中的heads属性 59 | if hasattr(encoder.heads, 'head'): 60 | return encoder.heads.head.in_features 61 | # 遍历Sequential中的层 62 | if isinstance(encoder.heads, nn.Sequential): 63 | for name, module in encoder.heads.named_children(): 64 | if name == 'head' or name == 'pre_logits': 65 | return module.in_features 66 | # 如果没有找到head或pre_logits,尝试获取第一个Linear层 67 | return get_channel_dim_from_sequential(encoder.heads) 68 | elif hasattr(encoder, 'classifier'): 69 | if isinstance(encoder.classifier, nn.Sequential): 70 | return get_channel_dim_from_sequential(encoder.classifier) 71 | else: 72 | return encoder.classifier.weight.shape[1] 73 | else: 74 | raise NotImplementedError('MLP projection head not found in encoder') 75 | 76 | 77 | def forward(self, x1, x2): 78 | """ 79 | Input: 80 | x1: first views of images 81 | x2: second views of images 82 | Output: 83 | p1, p2, z1, z2: predictors and targets of the network 84 | See Sec. 
3 of https://arxiv.org/abs/2011.10566 for detailed notations 85 | """ 86 | 87 | # compute features for one view 88 | z1 = self.encoder(x1) # NxC 89 | z2 = self.encoder(x2) # NxC 90 | 91 | z1 = self.projector(z1) 92 | z2 = self.projector(z2) 93 | 94 | p1 = self.predictor(z1) # NxC 95 | p2 = self.predictor(z2) # NxC 96 | 97 | z1, z2 = z1.detach(), z2.detach() 98 | 99 | loss = -(F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()) * 0.5 100 | 101 | return loss 102 | -------------------------------------------------------------------------------- /ssl_backdoor/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/utils/__init__.py -------------------------------------------------------------------------------- /ssl_backdoor/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import random 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from sklearn.neighbors import NearestNeighbors 9 | 10 | def extract_features(model, loader, class_index=None): 11 | """ 12 | Extracts features from the model using the given loader and saves them to a file. 13 | 14 | Args: 15 | model (torch.nn.Module): The model from which to extract features. 16 | loader (torch.utils.data.DataLoader): The DataLoader for input data. 17 | class_index (int): The index of the class to extract features for. If None, all classes are used. 18 | """ 19 | model.eval() 20 | device = next(model.parameters()).device 21 | 22 | features = [] 23 | target_list = [] 24 | 25 | 26 | with torch.no_grad(): 27 | for i, (inputs, targets) in enumerate(tqdm(loader)): 28 | if class_index is not None: 29 | mask = targets == class_index 30 | inputs = inputs[mask] 31 | targets = targets[mask] 32 | 33 | inputs = inputs.to(device) 34 | output = model(inputs) 35 | output = F.normalize(output, dim=1) 36 | features.append(output.detach().cpu()) 37 | target_list.append(targets) 38 | 39 | features = torch.cat(features, dim=0) 40 | targets = torch.cat(target_list, dim=0) 41 | 42 | 43 | return features, targets 44 | 45 | def interpolate_pos_embed(model, checkpoint_model): 46 | if 'pos_embed' in checkpoint_model: 47 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 48 | embedding_size = pos_embed_checkpoint.shape[-1] 49 | num_patches = model.patch_embed.num_patches 50 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 51 | # height (== width) for the checkpoint position embedding 52 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 53 | # height (== width) for the new position embedding 54 | new_size = int(num_patches ** 0.5) 55 | # class_token and dist_token are kept unchanged 56 | if orig_size != new_size: 57 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 58 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 59 | # only the position tokens are interpolated 60 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 61 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 62 | pos_tokens = torch.nn.functional.interpolate( 63 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 64 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 65 | new_pos_embed = torch.cat((extra_tokens, 
pos_tokens), dim=1)
66 | checkpoint_model['pos_embed'] = new_pos_embed
67 | 
68 | def get_channels(arch):
69 | if arch == 'alexnet':
70 | c = 4096
71 | elif arch == 'pt_alexnet':
72 | c = 4096
73 | elif arch == 'resnet50':
74 | c = 2048
75 | elif 'resnet18' in arch:
76 | c = 512
77 | elif 'densenet121' in arch:
78 | c = 1024
79 | elif arch == 'mobilenet':
80 | c = 1280
81 | elif arch == 'MobileNetV2':
82 | c = 1280
83 | elif arch == 'SqueezeNet':
84 | c = 2048
85 | elif arch == 'resnet50x5_swav':
86 | c = 10240
87 | elif arch == 'vit_b_16':
88 | c = 768
89 | elif arch == 'swin_s':
90 | c = 768
91 | else:
92 | raise ValueError('arch not found: ' + arch)
93 | return c
94 | 
95 | def knn_evaluate(model, train_loader, test_loader, device):
96 | model.eval()
97 | model.to(device)
98 | feature_bank = []
99 | labels = []
100 | 
101 | # build the feature bank and label list
102 | with torch.no_grad():
103 | for data, target in train_loader:
104 | data = data.to(device)
105 | feature = model(data).flatten(start_dim=1)
106 | feature_bank.append(feature.cpu())
107 | labels.append(target.cpu())
108 | 
109 | feature_bank = torch.cat(feature_bank, dim=0).numpy()
110 | labels = torch.cat(labels, dim=0).numpy() # convert to a NumPy array
111 | 
112 | # fit the KNN index
113 | knn = NearestNeighbors(n_neighbors=200, metric='cosine')
114 | knn.fit(feature_bank)
115 | 
116 | total_correct = 0
117 | total_num = 0
118 | all_preds = []
119 | all_targets_list = []
120 | 
121 | # evaluation phase
122 | with torch.no_grad():
123 | for data, target in test_loader:
124 | data = data.to(device)
125 | feature = model(data).flatten(start_dim=1)
126 | feature = feature.cpu().numpy()
127 | 
128 | distances, indices = knn.kneighbors(feature)
129 | 
130 | # index with NumPy
131 | retrieved_neighbors = labels[indices] # shape: [batch_size, n_neighbors]
132 | 
133 | # compute the predicted labels (majority vote over the neighbors)
134 | pred_labels = np.squeeze(np.apply_along_axis(lambda x: np.bincount(x).argmax(), 1, retrieved_neighbors))
135 | 
136 | # convert the predictions to a PyTorch tensor
137 | pred_labels = torch.tensor(pred_labels, device='cpu') # compare on the CPU
138 | 
139 | # count the correct predictions
140 | total_correct += (pred_labels == target.cpu()).sum().item()
141 | total_num += data.size(0)
142 | all_preds.append(pred_labels)
143 | all_targets_list.append(target)
144 | 
145 | accuracy = total_correct / total_num
146 | print(f"[knn_evaluate] Total accuracy: {accuracy * 100:.2f}%")
147 | all_preds = torch.cat(all_preds, dim=0)
148 | all_targets_list = torch.cat(all_targets_list, dim=0)
149 | return accuracy, all_preds, all_targets_list
150 | 
151 | 
152 | def extract_config_by_prefix(config_dict, prefix):
153 | """
154 | Extract the key/value pairs with a given prefix from a config dict.
155 | 
156 | Args:
157 | config_dict (dict): configuration dictionary
158 | prefix (str): key prefix
159 | 
160 | Returns:
161 | dict: a dict of all key/value pairs whose keys start with the given prefix
162 | """
163 | # if the prefix is itself a key in the dict and its value is a dict, return that sub-dict directly
164 | if prefix in config_dict and isinstance(config_dict[prefix], dict):
165 | return config_dict[prefix]
166 | 
167 | # otherwise, collect all keys that start with the prefix
168 | result = {}
169 | for key, value in config_dict.items():
170 | if key.startswith(f"{prefix}"):
171 | key = key[len(prefix):]
172 | result[key] = value
173 | 
174 | return result
175 | 
176 | 
177 | def set_seed(seed):
178 | """Set random seeds to make results reproducible."""
179 | random.seed(seed)
180 | np.random.seed(seed)
181 | torch.manual_seed(seed)
182 | torch.cuda.manual_seed(seed)
183 | torch.backends.cudnn.benchmark = False
184 | torch.backends.cudnn.deterministic = True
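Editor's note: a minimal sketch (not part of the repository) of driving knn_evaluate above; the stand-in encoder, transform, and data root are illustrative assumptions, and in practice you would pass an SSL-pretrained encoder.

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tf = T.ToTensor()
train_ds = torchvision.datasets.CIFAR10('./data', train=True, transform=tf, download=True)
test_ds = torchvision.datasets.CIFAR10('./data', train=False, transform=tf, download=True)

encoder = torchvision.models.resnet18(num_classes=10)  # stand-in; use an SSL-pretrained encoder in practice
encoder.fc = torch.nn.Identity()                       # expose 512-d features instead of logits

acc, preds, targets = knn_evaluate(encoder,
                                   DataLoader(train_ds, batch_size=256),
                                   DataLoader(test_ds, batch_size=256),
                                   device)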
-------------------------------------------------------------------------------- /tools/ddp_training.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import yaml
5 | 
6 | # import the trainer interface and the config-loading helper
7 | from ssl_backdoor.ssl_trainers.trainer import get_trainer
8 | from ssl_backdoor.ssl_trainers.utils import load_config
9 | 
10 | def main():
11 | parser = argparse.ArgumentParser(description='Run training; command-line arguments override same-named keys in the config files')
12 | parser.add_argument('--config', type=str, default=None, required=True,
13 | help='path to the base config file, in .py or .yaml format')
14 | parser.add_argument('--attack_config', type=str, default=None, required=True,
15 | help='path to the backdoor attack config file (.yaml format)')
16 | parser.add_argument('--test_config', type=str, default=None,
17 | help='path to the test config file (.yaml format)')
18 | 
19 | args = parser.parse_args()
20 | 
21 | # 1. load the base config
22 | config = load_config(args.config)
23 | print(f"Loaded base config: {args.config}")
24 | 
25 | # 2. load and merge the attack config (if provided)
26 | if args.attack_config:
27 | try:
28 | attack_config = load_config(args.attack_config) # reuse the same loader
29 | if isinstance(attack_config, dict):
30 | print(f"Loaded attack config: {args.attack_config}")
31 | # merge: values in attack_config override same-named values in config
32 | config.update(attack_config)
33 | print("Attack config merged.")
34 | else:
35 | print(f"Warning: attack config file {args.attack_config} did not load as a dict; skipping merge.")
36 | except Exception as e:
37 | print(f"Warning: error while loading attack config file {args.attack_config}: {e}; skipping merge.")
38 | 
39 | # 3. load and store the test config (if provided)
40 | if args.test_config:
41 | try:
42 | test_config_dict = load_config(args.test_config)
43 | if isinstance(test_config_dict, dict):
44 | print(f"Loaded test config: {args.test_config}")
45 | config['test_config'] = test_config_dict
46 | else:
47 | print(f"Warning: test config file {args.test_config} did not load as a dict; skipping merge.")
48 | except Exception as e:
49 | print(f"Warning: error while loading test config file {args.test_config}: {e}; skipping merge.")
50 | 
51 | print("\nFinal training config:", config)
52 | print("\nFinal test config:", config.get('test_config')) # use .get() so a missing test config does not raise KeyError
53 | 
54 | # 4. build the trainer (pass in the merged config dict)
55 | trainer = get_trainer(config)
56 | 
57 | # 5. report the evaluation frequency (if evaluation is enabled)
58 | eval_frequency = config.get('eval_frequency', 50)
59 | print(f"eval_frequency: {eval_frequency}, type: {type(eval_frequency)}")
60 | 
61 | 
62 | # 6. launch training (the trainer was built once above; do not construct it twice)
63 | trainer()
64 | 
65 | 
66 | if __name__ == '__main__':
67 | main()
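Editor's note: a minimal sketch (not part of the repository) of the merge precedence implemented above, using plain dicts instead of config files; all keys and values are illustrative.

base = {'method': 'simsiam', 'epochs': 200, 'attack_algorithm': 'clean'}
attack = {'attack_algorithm': 'sslbkd', 'trigger_path': 'triggers/trigger_14.png'}
test = {'eval_batch_size': 256}

config = dict(base)
config.update(attack)          # the attack config wins on key collisions
config['test_config'] = test   # the test config is stored whole under its own key
assert config['attack_algorithm'] == 'sslbkd'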
-------------------------------------------------------------------------------- /tools/eval_utils.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Tongzhou Wang
2 | import shutil
3 | 
4 | import logging
5 | import os
6 | 
7 | import torch
8 | from torch import nn
9 | from torchvision import models
10 | 
11 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
12 | logger = logging.getLogger()
13 | if debug:
14 | level = logging.DEBUG
15 | else:
16 | level = logging.INFO
17 | logger.setLevel(level)
18 | if saving:
19 | info_file_handler = logging.FileHandler(logpath, mode="a")
20 | info_file_handler.setLevel(level)
21 | logger.addHandler(info_file_handler)
22 | if displaying:
23 | console_handler = logging.StreamHandler()
24 | console_handler.setLevel(level)
25 | logger.addHandler(console_handler)
26 | logger.info(filepath)
27 | with open(filepath, "r") as f:
28 | logger.info(f.read())
29 | 
30 | for f in package_files:
31 | logger.info(f)
32 | with open(f, "r") as package_f:
33 | logger.info(package_f.read())
34 | 
35 | return logger
36 | 
37 | 
38 | class AverageMeter(object):
39 | """Computes and stores the average and current value"""
40 | def __init__(self, name, fmt=':f'):
41 | self.name = name
42 | self.fmt = fmt
43 | self.reset()
44 | 
45 | def reset(self):
46 | self.val = 0
47 | self.avg = 0
48 | self.sum = 0
49 | self.count = 0
50 | 
51 | def update(self, val, n=1):
52 | self.val = val
53 | self.sum += val * n
54 | self.count += n
55 | self.avg = self.sum / self.count
56 | 
57 | def __str__(self):
58 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
59 | return fmtstr.format(**self.__dict__)
60 | 
61 | 
62 | class ProgressMeter(object):
63 | def __init__(self, num_batches, meters, prefix=""):
64 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
65 | self.meters = meters
66 | self.prefix = prefix
67 | 
68 | def display(self, batch):
69 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
70 | entries += [str(meter) for meter in self.meters]
71 | return '\t'.join(entries)
72 | 
73 | def _get_batch_fmtstr(self, num_batches):
74 | num_digits = len(str(num_batches // 1))
75 | fmt = '{:' + str(num_digits) + 'd}'
76 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
77 | 
78 | def accuracy(output, target, topk=(1,)):
79 | """Computes the accuracy over the k top predictions for the specified values of k"""
80 | with torch.no_grad():
81 | maxk = max(topk)
82 | batch_size = target.size(0)
83 | 
84 | _, pred = output.topk(maxk, 1, True, True)
85 | pred = pred.t()
86 | correct = pred.eq(target.view(1, -1).expand_as(pred))
87 | # pdb.set_trace()
88 | res = []
89 | for k in topk:
90 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
91 | res.append(correct_k.mul_(100.0 / batch_size))
92 | return res
93 | 
94 | arch_to_key = {
95 | 'alexnet': 'alexnet',
96 | 'alexnet_moco': 'alexnet',
97 | 'resnet18': 'resnet18',
98 | 'resnet50': 'resnet50',
99 | 'rotnet_r50': 'resnet50',
100 | 'rotnet_r18': 'resnet18',
101 | 'moco_resnet18': 'resnet18',
102 | 'resnet_moco': 'resnet50',
103 | }
104 | 
105 | model_names = list(arch_to_key.keys())
106 | 
107 | def save_checkpoint(state, is_best, save_dir):
108 | ckpt_path = os.path.join(save_dir, 'checkpoint.pth.tar')
109 | torch.save(state, ckpt_path)
110 | if is_best:
111 | best_ckpt_path = os.path.join(save_dir, 'model_best.pth.tar')
112 | shutil.copyfile(ckpt_path, best_ckpt_path)
113 | 
114 | 
115 | def makedirs(dirname):
116 | if not os.path.exists(dirname):
117 | os.makedirs(dirname)
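Editor's note: a minimal sketch (not part of the repository) exercising the top-k accuracy helper and AverageMeter above on dummy logits; the batch size and class count are arbitrary.

import torch

logits = torch.randn(16, 10)             # fake model outputs
targets = torch.randint(0, 10, (16,))    # fake ground-truth labels
top1, top5 = accuracy(logits, targets, topk=(1, 5))

meter = AverageMeter('Acc@1', ':6.2f')
meter.update(top1.item(), n=targets.size(0))
print(meter)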
-------------------------------------------------------------------------------- /tools/process_dataset.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import torch
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | sys.path.append("/workspace/SSL-Backdoor")
8 | from datasets.dataset import OnlineUniversalPoisonedValDataset, FileListDataset
9 | 
10 | # a simple argument container
11 | class Args:
12 | def __init__(self):
13 | self.trigger_size = 50 # taken from linear_probe.sh
14 | self.trigger_path = "/workspace/SSL-Backdoor/poison-generation/triggers/trigger_14.png" # taken from linear_probe.sh
15 | self.trigger_insert = "patch" # taken from linear_probe.sh
16 | self.return_attack_target = False
17 | self.attack_target = 0 # samples of this class are excluded
18 | self.attack_algorithm = "sslbkd" # taken from linear_probe.sh
19 | 
20 | def process_dataset(input_txt_file, output_dir, config_file):
21 | """
22 | Process the dataset and save it to a new directory.
23 | 
24 | Args:
25 | input_txt_file (str): path to the input file list
26 | output_dir (str): directory where the output images are saved
27 | config_file (str): path of the newly generated file list
28 | """
29 | # create the output directory
30 | os.makedirs(output_dir, exist_ok=True)
31 | 
32 | # build the transform
33 | transform = transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.ToPILImage()
36 | ])
37 | 
38 | # initialize the dataset arguments
39 | args = Args()
40 | attack_target = args.attack_target
41 | 
42 | # initialize the clean dataset (FileListDataset instead of OnlineUniversalPoisonedValDataset)
43 | clean_dataset = FileListDataset(
44 | args=args,
45 | path_to_txt_file=input_txt_file,
46 | transform=transform
47 | )
48 | 
49 | # initialize the backdoored dataset
50 | backdoor_dataset = OnlineUniversalPoisonedValDataset(
51 | args=args,
52 | path_to_txt_file=input_txt_file,
53 | transform=transform,
54 | pre_inject_mode=False
55 | )
56 | 
57 | # open the new file list
58 | with open(config_file, 'w', encoding='utf-8') as f:
59 | # counter
60 | saved_count = 0
61 | 
62 | # iterate over the dataset
63 | for idx in range(len(clean_dataset)):
64 | # fetch the original image and label
65 | clean_img, original_label = clean_dataset[idx]
66 | 
67 | # skip samples of the attack_target class
68 | if original_label == attack_target:
69 | continue
70 | 
71 | # fetch the backdoored image
72 | backdoor_img, _ = backdoor_dataset[idx]
73 | 
74 | # save the clean image
75 | clean_img_name = f"image_{saved_count:06d}_clean.png"
76 | clean_img_path = os.path.join(output_dir, clean_img_name)
77 | if isinstance(clean_img, torch.Tensor):
78 | clean_img = transforms.ToPILImage()(clean_img)
79 | clean_img.save(clean_img_path)
80 | f.write(f"{clean_img_path} 0\n") # 0 marks a clean image
81 | 
82 | # save the backdoored image
83 | backdoor_img_name = f"image_{saved_count:06d}_backdoor.png"
84 | backdoor_img_path = os.path.join(output_dir, backdoor_img_name)
85 | if isinstance(backdoor_img, torch.Tensor):
86 | backdoor_img = transforms.ToPILImage()(backdoor_img)
87 | backdoor_img.save(backdoor_img_path)
88 | f.write(f"{backdoor_img_path} 1\n") # 1 marks a backdoored image
89 | 
90 | saved_count += 1
91 | 
92 | # report progress
93 | if saved_count % 50 == 0:
94 | print(f"Processed {saved_count} image pairs (clean + backdoor)")
95 | 
96 | if __name__ == "__main__":
97 | # set the paths
98 | input_txt_file = "/workspace/SSL-Backdoor/data/ImageNet-100/valset.txt" # test-set path taken from linear_probe.sh
99 | output_dir = "/workspace/detect-backdoor-samples-by-neighbourhood/data/backdoor_images" # output directory
100 | config_file = "/workspace/detect-backdoor-samples-by-neighbourhood/data/backdoor_config.txt" # file-list path
101 | 
102 | # process the dataset
103 | process_dataset(input_txt_file, output_dir, config_file)
104 | print("Dataset processing complete!")
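Editor's note: a minimal sketch (not part of the repository) of consuming the "path label" file list written by process_dataset above; the path is the illustrative default from its __main__ block.

with open("/workspace/detect-backdoor-samples-by-neighbourhood/data/backdoor_config.txt") as f:
    entries = [line.rsplit(' ', 1) for line in f.read().splitlines() if line]

n_backdoor = sum(1 for _, label in entries if label == '1')
print(f"{len(entries)} images listed, {n_backdoor} carry a trigger")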
-------------------------------------------------------------------------------- /tools/run_badencoder.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # launch script for the BadEncoder (Backdoor Self-Supervised Learning) attack
3 | # resolve the directory this script lives in
4 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
5 | # resolve the project root (the parent of the tools directory)
6 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
7 | # prepend the project root to PYTHONPATH
8 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
9 | 
10 | # configuration
11 | CONFIG_PATH="configs/attacks/badencoder.py"
12 | TEST_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_test.yaml"
13 | 
14 | # run the BadEncoder attack
15 | CUDA_VISIBLE_DEVICES=2 python tools/run_badencoder.py \
16 | --config $CONFIG_PATH \
17 | --test_config $TEST_CONFIG_PATH
-------------------------------------------------------------------------------- /tools/run_dede.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # launch script for the DeDe (Decoder-based Detection) defense
3 | # resolve the directory this script lives in
4 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
5 | # resolve the project root (the parent of the tools directory)
6 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
7 | # prepend the project root to PYTHONPATH
8 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
9 | 
10 | # configuration
11 | CONFIG_PATH="configs/defense/dede.py"
12 | SHADOW_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_shadow_copy.yaml"
13 | TEST_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_cifar10_test.yaml"
14 | 
15 | # create the output directory
16 | # mkdir -p $OUTPUT_DIR
17 | 
18 | # run the DeDe defense
19 | CUDA_VISIBLE_DEVICES=7 python tools/run_dede.py \
20 | --config $CONFIG_PATH \
21 | --shadow_config $SHADOW_CONFIG_PATH \
22 | --test_config $TEST_CONFIG_PATH
23 | 
-------------------------------------------------------------------------------- /tools/run_moco_training.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # example of how to call the MoCo training interface
3 | 
4 | import os
5 | import sys
6 | import argparse
7 | import yaml # import the yaml library
8 | 
9 | # import the trainer interface and the config-loading helper
10 | from ssl_trainers.moco.main_moco import get_trainer, load_config
11 | 
12 | 
13 | def main():
14 | parser = argparse.ArgumentParser(description='Run MoCo training; command-line arguments override same-named keys in the config file')
15 | parser.add_argument('--config', type=str, default='configs/ssl/moco.py',
16 | help='path to the base config file, in .py or .yaml format')
17 | parser.add_argument('--attack_config', type=str, default=None, # new: attack config path
18 | help='path to the backdoor attack config file (.yaml format)')
19 | 
20 | # add every parameter that may need to be overridden from the command line
21 | # !!! important: these argument names (e.g. 'dataset') must exactly match the keys in the config file !!!
22 | parser.add_argument('--dataset', type=str, default=None,
23 | help='override the dataset name from the config file (key: dataset)')
24 | parser.add_argument('--data', type=str, default=None, # note: the argument name matches the config key 'data'
25 | help='override the dataset path from the config file (key: data)')
26 | parser.add_argument('--experiment_id', type=str, default=None, # note: the argument name matches the config key 'experiment_id'
27 | help='override the experiment ID from the config file (key: experiment_id)')
28 | parser.add_argument('--attack_algorithm', type=str, default=None, # note: the argument name matches the config key 'attack_algorithm'
29 | help='override the attack algorithm from the config file (key: attack_algorithm)')
30 | parser.add_argument('--epochs', type=int, default=None,
31 | help='override the number of training epochs from the config file (key: epochs)')
32 | parser.add_argument('--batch_size', type=int, default=None,
33 | help='override the batch size from the config file (key: batch_size)')
34 | # add more arguments here as needed...
35 | 
36 | args = parser.parse_args()
37 | 
38 | # 1. load the base config
39 | config = load_config(args.config)
40 | print(f"Loaded base config: {args.config}")
41 | 
42 | # 2. load and merge the attack config (if provided)
43 | if args.attack_config:
44 | try:
45 | attack_config = load_config(args.attack_config) # reuse the same loader
46 | if isinstance(attack_config, dict):
47 | print(f"Loaded attack config: {args.attack_config}")
48 | # merge: values in attack_config override same-named values in config
49 | config.update(attack_config)
50 | print("Attack config merged.")
51 | else:
52 | print(f"Warning: attack config file {args.attack_config} did not load as a dict; skipping merge.")
53 | except Exception as e:
54 | print(f"Warning: error while loading attack config file {args.attack_config}: {e}; skipping merge.")
55 | 
56 | # 3. convert the command-line arguments to a dict (used for overriding)
57 | cmd_args_dict = vars(args)
58 | 
59 | # 4. update the final config with the command-line arguments (skip None values and the config-path arguments)
60 | for key, value in cmd_args_dict.items():
61 | # skip None values as well as config and attack_config themselves
62 | if value is not None and key not in ['config', 'attack_config']:
63 | if key in config and config[key] != value:
64 | print(f"Config update: '{key}' changed from '{config[key]}' to '{value}' (from the command line)")
65 | elif key not in config:
66 | print(f"Config addition: '{key}' set to '{value}' (from the command line)")
67 | config[key] = value # override or add
68 | 
69 | print("\nFinal config:")
70 | # import pprint for a cleaner printout of the config
71 | try:
72 | import pprint
73 | pprint.pprint(config)
74 | except ImportError:
75 | print(config)
76 | print("-"*30)
77 | 
78 | # 5. build the trainer (pass in the merged config dict)
79 | trainer = get_trainer(config)
80 | 
81 | # 6. launch training
82 | trainer()
83 | 
84 | 
85 | if __name__ == '__main__':
86 | main()
-------------------------------------------------------------------------------- /tools/run_patchsearch.py: --------------------------------------------------------------------------------
1 | """
2 | Usage example for the PatchSearch defense.
3 | """
4 | 
5 | import os
6 | import argparse
7 | import logging
8 | 
9 | from ssl_backdoor.defenses.patchsearch import run_patchsearch, run_patchsearch_filter
10 | from ssl_backdoor.ssl_trainers.trainer import create_data_loader
11 | from ssl_backdoor.ssl_trainers.utils import load_config
12 | 
13 | 
14 | def parse_args():
15 | """
16 | Parse command-line arguments.
17 | """
18 | parser = argparse.ArgumentParser(description='PatchSearch defense example')
19 | parser.add_argument('--config', type=str, required=True,
20 | help='path to the base config file, in .py or .yaml format')
21 | parser.add_argument('--attack_config', type=str, required=True,
22 | help='path to the backdoor attack config file (.yaml format)')
23 | parser.add_argument('--output_dir', type=str, default=None,
24 | help='output directory (optional; the path in the config file takes precedence)')
25 | parser.add_argument('--experiment_id', type=str, default=None,
26 | help='experiment ID (optional; the value in the config file takes precedence)')
27 | parser.add_argument('--skip_filter', action='store_true',
28 | help='whether to skip the filtering step')
29 | 
30 | return parser.parse_args()
31 | 
32 | 
33 | def main():
34 | """
35 | Main entry point.
36 | """
37 | args = parse_args()
38 | 
39 | # 1. load the base config (PatchSearch algorithm settings)
40 | print(f"Loading base config file: {args.config}")
41 | config = load_config(args.config)
42 | 
43 | # 2. load the attack config (used only for data loading)
44 | print(f"Loading attack config file: {args.attack_config}")
45 | attack_config = load_config(args.attack_config)
46 | if not isinstance(attack_config, dict):
47 | raise ValueError(f"Attack config file {args.attack_config} is malformed")
48 | 
49 | # 3. let command-line arguments override the base config
50 | if args.output_dir:
51 | config['output_dir'] = args.output_dir
52 | if args.experiment_id:
53 | config['experiment_id'] = args.experiment_id
54 | 
55 | # make sure the required parameters are present
56 | if 'weights_path' not in config or not config['weights_path']:
57 | raise ValueError("Missing required parameter: weights_path; set it in the base config file or use the --weights argument")
58 | 
59 | print("PatchSearch defense configuration:")
60 | print(f"Model weights: {config['weights_path']}")
61 | print(f"Dataset name: {config.get('dataset_name', 'unknown')}")
62 | print(f"Output directory: {config.get('output_dir', '/workspace/SSL-Backdoor/results/defense')}")
63 | print(f"Experiment ID: {config.get('experiment_id', 'patchsearch_defense')}")
64 | 
65 | # load the poisoned dataset with create_data_loader, using only attack_config
66 | print("Loading the suspicious (poisoned) dataset with the attack config...")
67 | if attack_config.get('save_poisons'):
68 | print("save_poisons is True; initializing the attack config's train file from PatchSearch's train file")
69 | attack_config['poisons_saved_path'] = config['train_file']
70 | attack_config['save_poisons'] = False
71 | 
72 | attack_config['distributed'] = False # run on a single GPU by default
73 | attack_config['workers'] = config['num_workers'] # missing settings default to the PatchSearch config
74 | attack_config['batch_size'] = config['batch_size'] # missing settings default to the PatchSearch config
75 | 
76 | print("Checking attack config", attack_config)
77 | 
78 | poison_dataset = None
79 | # uncomment the following block to load the poisoned dataset in training style
80 | # # convert the dict to a Namespace object
81 | # attack_args = argparse.Namespace(**attack_config)
82 | # poison_loader = create_data_loader(
83 | #     attack_args # pass a Namespace object rather than a dict
84 | # )
85 | # poison_dataset = poison_loader.dataset
86 | 
87 | # run the PatchSearch defense, using only the base config
88 | results = run_patchsearch(
89 | args=config, # pass the base config
90 | weights_path=config['weights_path'],
91 | suspicious_dataset=poison_dataset, # pass the loaded poisoned dataset
92 | train_file=config['train_file'],
93 | dataset_name=config.get('dataset_name', 'imagenet100'),
94 | output_dir=config.get('output_dir', '/tmp'),
95 | arch=config.get('arch', 'resnet18'),
96 | num_clusters=config.get('num_clusters', 100),
97 | window_w=config.get('window_w', 60),
98 | batch_size=config.get('batch_size', 64),
99 | use_cached_feats=config.get('use_cached_feats', False),
100 | use_cached_poison_scores=config.get('use_cached_poison_scores', False),
101 | experiment_id=config.get('experiment_id', 'patchsearch_defense')
102 | )
103 | 
104 | if "status" in results and results["status"] == "CACHED_FEATURES":
105 | print("\nFeatures have been cached. Run this script again with use_cached_feats=True set in the base config file to use the cached features.")
106 | return
107 | 
108 | # print the most likely poisoned samples
109 | print("\nIndices of the top-10 most likely poisoned samples:")
110 | for i, idx in enumerate(results["sorted_indices"][:10]):
111 | is_poison = "yes" if results["is_poison"][idx] else "no"
112 | print(f"#{i+1}: index {idx}, poison score {results['poison_scores'][idx]:.2f}, actually poisoned: {is_poison}")
113 | 
114 | # unless the filtering step is skipped, run the poison classifier as a filter
115 | if not args.skip_filter and 'filter' in config:
116 | print("\n====== Stage 2: run the poison classifier filter ======")
117 | 
118 | # fetch the required parameters
119 | train_file = config['train_file']
120 | experiment_dir = results["output_dir"]
121 | 
122 | # read the filter settings from config
123 | filter_config = config.get('filter', {})
124 | 
125 | # run the filter
126 | filtered_file_path = run_patchsearch_filter(
127 | poison_scores_path= os.path.join(experiment_dir, 'poison-scores.npy'),
128 | train_file=train_file,
129 | dataset_name=config.get('dataset_name', 'imagenet100'),
130 | topk_poisons=filter_config.get('topk_poisons', 20),
131 | top_p=filter_config.get('top_p', 0.10),
132 | model_count=filter_config.get('model_count', 5),
133 | max_iterations=filter_config.get('max_iterations', 2000),
134 | batch_size=filter_config.get('batch_size', 128),
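Editor's note: a minimal sketch (not part of the repository) of inspecting the poison-scores.npy file that run_patchsearch writes; the experiment directory below is an illustrative placeholder.

import numpy as np

scores = np.load("results/defense/patchsearch_defense/poison-scores.npy")
topk = np.argsort(scores)[::-1][:10]   # indices of the 10 most suspicious samples
for rank, idx in enumerate(topk, 1):
    print(f"#{rank}: index {idx}, poison score {scores[idx]:.2f}")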
num_workers=filter_config.get('num_workers', 8),
136 | lr=filter_config.get('lr', 0.01),
137 | momentum=filter_config.get('momentum', 0.9),
138 | weight_decay=filter_config.get('weight_decay', 1e-4),
139 | print_freq=filter_config.get('print_freq', 10),
140 | eval_freq=filter_config.get('eval_freq', 50),
141 | seed=filter_config.get('seed', 42)
142 | )
143 | 
144 | # evaluate the filtering results
145 | logger = logging.getLogger('patchsearch')
146 | if os.path.exists(filtered_file_path):
147 | # count the samples before and after filtering
148 | with open(train_file, 'r') as f:
149 | original_count = len(f.readlines())
150 | 
151 | with open(filtered_file_path, 'r') as f:
152 | filtered_count = len(f.readlines())
153 | 
154 | removed_count = original_count - filtered_count
155 | removed_percentage = (removed_count / original_count) * 100
156 | 
157 | 
158 | logger.info("\n====== Filtering statistics ======")
159 | logger.info(f"Original sample count: {original_count}")
160 | logger.info(f"Sample count after filtering: {filtered_count}")
161 | logger.info(f"Removed sample count: {removed_count}")
162 | logger.info(f"Removed sample percentage: {removed_percentage:.2f}%")
163 | logger.info(f"Filtered dataset file: {filtered_file_path}")
164 | logger.info(f"You can retrain your SSL model on this file for better robustness")
165 | else:
166 | logger.warning(f"Warning: filtered file {filtered_file_path} not found")
167 | 
168 | 
169 | if __name__ == '__main__':
170 | main()
-------------------------------------------------------------------------------- /tools/run_patchsearch.sh: --------------------------------------------------------------------------------
1 | # resolve the directory this script lives in
2 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
3 | # resolve the project root (the parent of the tools directory)
4 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
5 | 
6 | # prepend the project root to PYTHONPATH
7 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
8 | 
9 | # the Python script can now be executed and will find ssl_trainers
10 | # assuming your python command is invoked like this
11 | CUDA_VISIBLE_DEVICES=3 python "${SCRIPT_DIR}/run_patchsearch.py" \
12 | --config configs/defense/patchsearch.py \
13 | --attack_config configs/poisoning/poisoning_based/sslbkd.yaml
14 | 
-------------------------------------------------------------------------------- /tools/train.sh: --------------------------------------------------------------------------------
1 | # resolve the directory this script lives in
2 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
3 | # resolve the project root (the parent of the tools directory)
4 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
5 | 
6 | # prepend the project root to PYTHONPATH
7 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
8 | 
9 | # the Python script can now be executed and will find ssl_trainers
10 | # assuming your python command is invoked like this
11 | CUDA_VISIBLE_DEVICES=2,4 python "${SCRIPT_DIR}/ddp_training.py" \
12 | --config configs/ssl/simsiam.py \
13 | --attack_config configs/poisoning/poisoning_based_copy/na.yaml \
14 | --test_config
--------------------------------------------------------------------------------