├── .gitignore
├── LICENSE.txt
├── README.md
├── assets
├── drupe_reference
│ ├── CLIP
│ │ ├── one.npz
│ │ ├── priority.npz
│ │ ├── stop.npz
│ │ ├── truck.npz
│ │ └── zero.npz
│ ├── cifar10
│ │ ├── gtsrb_l0.npz
│ │ ├── gtsrb_l40.npz
│ │ ├── gtsrb_l5.npz
│ │ ├── one.npz
│ │ ├── priority.npz
│ │ ├── priority_wzt_3.npz
│ │ └── truck.npz
│ ├── cifar10_l0.npz
│ ├── cifar10_l0_224.npz
│ ├── cifar10_l0_n10.npz
│ ├── cifar10_l0_n3.npz
│ ├── cifar10_l0_n3_224.npz
│ ├── cifar10_l0_n5.npz
│ ├── cifar10_l4.npz
│ ├── gtsrb_l0.npz
│ ├── gtsrb_l0_224.npz
│ ├── gtsrb_l12.npz
│ ├── gtsrb_l12_224.npz
│ ├── gtsrb_l12_n10.npz
│ ├── gtsrb_l12_n2.npz
│ ├── gtsrb_l12_n2_rseed1.npz
│ ├── gtsrb_l12_n2_rseed2.npz
│ ├── gtsrb_l12_n2_rseed3.npz
│ ├── gtsrb_l12_n2_rseed4.npz
│ ├── gtsrb_l12_n3.npz
│ ├── gtsrb_l12_n3_224.npz
│ ├── gtsrb_l12_n5.npz
│ ├── gtsrb_l24_n3_224.npz
│ ├── gtsrb_l24_n3_224_rseed4.npz
│ ├── imagenet
│ │ ├── one.npz
│ │ ├── priority.npz
│ │ └── truck.npz
│ ├── stl10
│ │ ├── airplane.npz
│ │ ├── one.npz
│ │ └── priority.npz
│ ├── stl10_l0.npz
│ ├── stl10_l0_224.npz
│ ├── stl10_l9.npz
│ ├── stl10_l9_224.npz
│ ├── stl10_l9_n10.npz
│ ├── stl10_l9_n3.npz
│ ├── stl10_l9_n3_224.npz
│ ├── stl10_l9_n5.npz
│ ├── svhn_l0.npz
│ ├── svhn_l0_224.npz
│ ├── svhn_l1.npz
│ ├── svhn_l1_224.npz
│ ├── svhn_l1_n10.npz
│ ├── svhn_l1_n3.npz
│ ├── svhn_l1_n3_224.npz
│ └── svhn_l1_n5.npz
├── image.png
└── triggers
│ ├── 32471.png
│ ├── NC_clean.png
│ ├── apple_white.png
│ ├── drupe_trigger
│ │ ├── trigger_pt_white_173_50_ap_replace.npz
│ │ ├── trigger_pt_white_185_24.npz
│ │ ├── trigger_pt_white_21_10_ap_replace.npz
│ │ └── trigger_pt_white_42_20_ap_replace.npz
│ ├── firefox_32.png
│ ├── hellokitty_32.png
│ ├── phoenix_corner_96x96_256.png
│ ├── sig_trigger_128.png
│ ├── sig_trigger_224.png
│ ├── sig_trigger_32.png
│ ├── sig_trigger_64.png
│ ├── trigger_10.png
│ ├── trigger_11.png
│ ├── trigger_12.png
│ ├── trigger_13.png
│ ├── trigger_14.png
│ ├── trigger_15.png
│ ├── trigger_16.png
│ ├── trigger_17.png
│ ├── trigger_18.png
│ ├── trigger_19.png
│ ├── trojan_watermark_224.png
│ ├── watermark_white_32.png
│ └── white_10x10.png
├── configs
├── attacks
│ ├── badencoder.py
│ ├── badencoder_test.yaml
│ ├── badencoder_train.yaml
│ ├── drupe.py
│ ├── drupe_test.yaml
│ └── drupe_train.yaml
├── poisoning
│ └── poisoning_based
│ │ ├── sslbkd.yaml
│ │ ├── sslbkd_cifar10.yaml
│ │ ├── sslbkd_cifar10_test.yaml
│ │ ├── sslbkd_shadow_copy.yaml
│ │ └── sslbkd_test.yaml
└── ssl
│ ├── byol.py
│ ├── moco.py
│ ├── simclr.py
│ └── simsiam.py
├── docker
└── Dockerfile
├── docs
└── zh_cn
│ ├── patchsearch.log
│ └── patchsearch.md
├── requirements.txt
├── ssl_backdoor
├── attacks
│ ├── __init__.py
│ └── badencoder
│ │ ├── __init__.py
│ │ ├── badencoder.py
│ │ └── datasets.py
├── datasets
│ ├── __init__.py
│ ├── agent.py
│ ├── base.py
│ ├── corruptencoder_utils.py
│ ├── dataset.py
│ ├── generators.py
│ ├── metadata
│ │ ├── cifar10_classes.txt
│ │ ├── cifar10_metadata.txt
│ │ ├── class_index.txt
│ │ ├── imagenet100_classes.txt
│ │ ├── imagenet_metadata.txt
│ │ ├── stl10_classes.txt
│ │ └── stl10_metadata.txt
│ ├── utils.py
│ └── var.py
├── defenses
│ ├── decree
│ │ └── trigger
│ │ │ ├── trigger_pt_white_185_24.npz
│ │ │ └── trigger_pt_white_21_10_ap_replace.npz
│ ├── dede
│ │ ├── decoder_model.py
│ │ ├── dede.py
│ │ └── reconstruction.py
│ └── patchsearch
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── poison_classifier.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── clustering.py
│ │ ├── dataset.py
│ │ ├── evaluation.py
│ │ ├── gradcam.py
│ │ ├── model_utils.py
│ │ ├── patch_operations.py
│ │ └── visualization.py
├── ssl_trainers
│ ├── __init__.py
│ ├── byol
│ │ ├── builder.py
│ │ ├── cfg.py
│ │ ├── cifar10_classes.txt
│ │ ├── cifar10_metadata.txt
│ │ ├── dataset
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── cifar10.py
│ │ │ ├── cifar100.py
│ │ │ ├── imagenet.py
│ │ │ ├── stl10.py
│ │ │ ├── tiny_in.py
│ │ │ └── transforms.py
│ │ ├── eval
│ │ │ ├── get_data.py
│ │ │ ├── knn.py
│ │ │ ├── lbfgs.py
│ │ │ └── sgd.py
│ │ ├── imagenet100_classes.txt
│ │ ├── imagenet_metadata.txt
│ │ ├── methods
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── byol.py
│ │ │ ├── contrastive.py
│ │ │ ├── norm_mse.py
│ │ │ ├── w_mse.py
│ │ │ └── whitening.py
│ │ ├── moco
│ │ │ ├── __init__.py
│ │ │ ├── builder.py
│ │ │ ├── dataset3.py
│ │ │ ├── loader.py
│ │ │ └── poisonencoder_utils.py
│ │ ├── model.py
│ │ ├── test.py
│ │ └── train.py
│ ├── moco
│ │ ├── builder.py
│ │ ├── loader.py
│ │ └── test.py
│ ├── models_vit.py
│ ├── simclr
│ │ └── builder.py
│ ├── simsiam
│ │ └── builder.py
│ ├── trainer.py
│ └── utils.py
└── utils
│ ├── __init__.py
│ ├── model_utils.py
│ └── utils.py
└── tools
├── ddp_training.py
├── eval_knn.py
├── eval_utils.py
├── ft_linear.py
├── process_dataset.py
├── run_badencoder.py
├── run_badencoder.sh
├── run_dede.py
├── run_dede.sh
├── run_moco_training.py
├── run_patchsearch.py
├── run_patchsearch.sh
├── test.py
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore git records in subfolders
2 | **/.git/
3 |
4 | # Ignore all torch files
5 | *.pt
6 | *.pth
7 | *.pth.tar
8 |
9 | # Ignore all datasets
10 | data/
11 |
12 | # Ignore all training results
13 | results/
14 |
15 | # Ignore repos and submodules
16 | .trash/
17 | CompRess/
18 | defense/
19 | mine-pytorch/
20 | simsiam_old/
21 | SSL-Cleanse/
22 | scripts_old/
23 | trash/
24 | ssl_backdoor/projects/
25 | ssl_backdoor/defenses/projects/
26 |
27 | # Ignore private things
28 | configs/poisoning/ablation/**
29 | configs/poisoning/poisoning_based_copy/**
30 | configs/poisoning/optimization_based/**
31 | projects/
32 | scripts_old/
33 | projects/
34 |
35 | ssl_backdoor/datasets/hidden.py
36 | ssl_backdoor/datasets/dataset_old.py
37 | ssl_backdoor/attacks/drupe/
38 | ssl_backdoor/defenses/hidden/
39 | ssl_backdoor/defenses/patchsearch/PatchSearch/
40 | ssl_backdoor/defenses/dede/DeDe/
41 |
42 | # trash
43 | .trash/
44 | __pycache__/
45 | **/__pycache__/
46 | .hypothesis/
47 | ssl_pretrain.py
48 | methods.py
49 | test.py
50 | assets/references/drupe_gtsrb_l12_n3
51 | assets/references/imagenet
52 | ssl_backdoor/attacks/badencoder/badencoder_vanilla.py
53 | ssl_backdoor/ssl_trainers/moco/moco_badencoder
54 | ssl_backdoor/ssl_trainers/moco/simsiam_badencoder
55 | ssl_backdoor/ssl_trainers/moco/main_moco\ copy.py
56 | ssl_backdoor/ssl_trainers/moco/main_moco_badencoder2.py
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Tuo Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | # SSL-Backdoor
16 |
17 | SSL-Backdoor is an academic research library dedicated to exploring **poisoning attacks in self-supervised learning (SSL)**. Our goal is to provide a comprehensive and unified platform for researchers to implement, evaluate, and compare various attack and defense mechanisms in the context of SSL.
18 |
19 | This library originated as a rewrite of the SSLBKD implementation, ensuring **consistent training protocols** and **hyperparameter fidelity** for fair comparisons. We've since expanded its capabilities significantly.
20 |
21 | **Key Features:**
22 |
23 | 1. **Unified Poisoning & Training Framework:** Streamlined pipeline for applying diverse poisoning strategies and training SSL models.
24 | 2. **Decoupled Design:** We strive to maintain a decoupled design, allowing each method to be modified independently, while unifying the implementation of essential tools where necessary.
25 |
26 | *Future plans include support for multimodal contrastive learning models.*
27 |
28 | ## 📢 What's New?
29 |
30 | ✅ **2025-05-19 Update:**
31 |
32 | * **DEDE defense is now implemented and available!**
33 |
34 | ✅ **2025-04-18 Update:**
35 |
36 | * **PatchSearch defense is now implemented and available!**
37 | * **BadEncoder attack is now implemented and available!**
38 |
39 | 🔄 **Active Refactoring Underway!** We are currently refactoring the codebase to improve code quality, maintainability, and ease of use. Expect ongoing improvements!
40 |
41 | ✅ **Current Support:**
42 |
43 | * **Attack Algorithms:** SSLBKD, CTRL, CorruptEncoder, BLTO (inference only), BadEncoder
44 | * **SSL Methods:** MoCo, SimCLR, SimSiam, BYOL
45 |
46 | 🛡️ **Current Defenses:**
47 |
48 | * **PatchSearch**, **DEDE**
49 |
50 | Stay tuned for more updates!
51 |
52 | ## Supported Attacks
53 |
54 | This library currently supports the following poisoning attack algorithms against SSL models:
55 |
56 | | Alias | Paper | Conference | Config |
57 | |-----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|--------|
58 | | SSLBKD | [Backdoor attacks on self-supervised learning](https://doi.org/10.1109/CVPR52688.2022.01298) | CVPR 2022 | [config](configs/poisoning/poisoning_based/sslbkd.yaml) |
59 | | CTRL | [An Embarrassingly Simple Backdoor Attack on Self-supervised Learning](https://openaccess.thecvf.com/content/ICCV2023/html/Li_An_Embarrassingly_Simple_Backdoor_Attack_on_Self-supervised_Learning_ICCV_2023_paper.html) | ICCV 2023 | |
60 | | CorruptEncoder | [Data poisoning based backdoor attacks to contrastive learning](https://openaccess.thecvf.com/content/CVPR2024/html/Zhang_Data_Poisoning_based_Backdoor_Attacks_to_Contrastive_Learning_CVPR_2024_paper.html) | CVPR 2024 | |
61 | | BLTO (inference)| [BACKDOOR CONTRASTIVE LEARNING VIA BI-LEVEL TRIGGER OPTIMIZATION](https://openreview.net/forum?id=oxjeePpgSP) | ICLR 2024 | |
62 | | BadEncoder | [BadEncoder: Backdoor Attacks to Pre-trained Encoders in Self-Supervised Learning](https://ieeexplore.ieee.org/abstract/document/9833644/) | S&P 2022 | [config](configs/attacks/badencoder.py) |
63 |
64 | ## Supported Defenses
65 |
66 | We are actively developing and integrating defense mechanisms. Currently, the following defenses are implemented:
67 |
68 | | Alias | Paper | Conference | Config |
69 | |------------------|------------------------------------------------------------------------------------------------------------------|---------------------------------------|----------------|
70 | | PatchSearch | [Defending Against Patch-Based Backdoor Attacks on Self-Supervised Learning](https://openaccess.thecvf.com/content/CVPR2023/html/Tejankar_Defending_Against_Patch-Based_Backdoor_Attacks_on_Self-Supervised_Learning_CVPR_2023_paper.html) | CVPR 2023 | [doc](./docs/zh_cn/patchsearch.md), [config](configs/defense/patchsearch.py) |
71 | | DEDE | [DeDe: Detecting Backdoor Samples for SSL Encoders via Decoders](http://arxiv.org/abs/2411.16154) | CVPR 2025 | [config](configs/defense/dede.py) |
72 |
73 | ## Setup
74 |
75 | Get started with SSL-Backdoor quickly:
76 |
77 | 1. **Clone the repository:**
78 | ```bash
79 | git clone https://github.com/jsrdcht/SSL-Backdoor.git
80 | cd SSL-Backdoor
81 | ```
82 |
83 | 2. **Install dependencies:**
84 | ```bash
85 | pip install -r requirements.txt
86 | ```
87 | *Consider using a virtual environment (`conda`, `venv`, etc.) to manage dependencies.*
88 |
89 | ## Usage
90 |
91 | ### Training an SSL Model on a Poisoned Dataset
92 |
93 | To train an SSL model (e.g., using MoCo v2) with a chosen poisoning attack, you can use the provided scripts. Example for Distributed Data Parallel (DDP) training:
94 |
95 | ```bash
96 | # Configure your desired attack, SSL method, dataset, etc. in the relevant config file
97 | # (e.g., configs/ssl/moco_config.yaml, configs/poisoning/...)
98 |
99 | bash tools/train.sh
100 | ```
101 |
102 | *Please refer to the `configs` directory and specific training scripts for detailed usage and parameter options.*
103 |
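The SSL configs under `configs/ssl/` are plain Python modules that each expose a `config` dict (see, e.g., `configs/ssl/moco.py` later in this listing). As a rough illustration of how such a module can be loaded and selectively overridden before training, here is a minimal sketch; the `load_config` helper is hypothetical and not part of the library, and the exact keys consumed by the training scripts should be checked against the configs themselves.

```python
# Minimal sketch: load one of the shipped config modules (each defines a plain
# `config` dict) and override a few fields programmatically. `load_config` is
# an illustrative helper, not an SSL-Backdoor API.
import importlib.util
from pathlib import Path


def load_config(py_path: str) -> dict:
    """Import a config .py file by path and return a copy of its `config` dict."""
    spec = importlib.util.spec_from_file_location(Path(py_path).stem, py_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return dict(module.config)


if __name__ == "__main__":
    cfg = load_config("configs/ssl/moco.py")
    # Example overrides; the dataset keys follow the commented-out fields in moco.py.
    cfg["dataset"] = "imagenet-100"
    cfg["data"] = "data/ImageNet-100/trainset.txt"
    cfg["experiment_id"] = "moco_imagenet-100_sslbkd"
    print({k: cfg[k] for k in ("method", "arch", "epochs", "batch_size")})
```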
104 | ## Performance Benchmarks (Legacy)
105 |
106 | *(Note: These results are based on the original implementation before the current refactoring.)*
107 |
108 | | Algorithm | Method | Clean Acc ↑ | Backdoor Acc ↓ | ASR ↑ |
109 | |-----------------|--------|-------------|----------------|-------|
110 | | SSLBKD | BYOL | 66.38% | 23.82% | 70.2% |
111 | | SSLBKD | SimCLR | 70.9% | 49.1% | 33.9% |
112 | | SSLBKD | MoCo | 66.28% | 33.24% | 57.6% |
113 | | SSLBKD | SimSiam| 64.48% | 29.3% | 62.2% |
114 | | CorruptEncoder | BYOL | 65.48% | 25.3% | 9.66% |
115 | | CorruptEncoder | SimCLR | 70.14% | 45.38% | 36.9% |
116 | | CorruptEncoder | MoCo | 67.04% | 38.64% | 37.3% |
117 | | CorruptEncoder | SimSiam| 57.54% | 14.14% | 79.48% |
118 |
119 | | Algorithm | Method | Clean Acc ↑ | Backdoor Acc ↓ | ASR ↑ |
120 | |-----------------|--------|-------------|----------------|-------|
121 | | CTRL | BYOL | 75.02% | 30.87% | 66.95% |
122 | | CTRL | SimCLR | 70.32% | 20.82% | 81.97% |
123 | | CTRL | MoCo | 71.01% | 54.5% | 34.34% |
124 | | CTRL | SimSiam| 71.04% | 50.36% | 41.43% |
125 |
126 | * Results were computed using the 10% available-data evaluation protocol from the SSLBKD paper, on the lorikeet class of ImageNet-100 and the airplane class of CIFAR-10, respectively.
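A reasonable reading of the columns: Clean Acc is accuracy of a downstream probe on clean test images, Backdoor Acc is accuracy on the same images with the trigger applied (still judged against the true labels), and ASR is the fraction of triggered images from non-target classes classified as the attack target. The sketch below only illustrates these metric definitions; it is not the repository's evaluation code (`tools/ft_linear.py`, `tools/eval_knn.py`), and it assumes a trained downstream classifier plus paired clean/triggered test loaders.

```python
# Illustrative metric definitions for the tables above (not the repo's eval code).
import torch


@torch.no_grad()
def eval_metrics(classifier, clean_loader, triggered_loader, target_class, device="cuda"):
    """Compute clean accuracy, backdoor accuracy, and ASR for a downstream classifier."""
    classifier = classifier.eval().to(device)

    def predict(loader):
        preds, labels = [], []
        for x, y in loader:
            preds.append(classifier(x.to(device)).argmax(dim=1).cpu())
            labels.append(y)
        return torch.cat(preds), torch.cat(labels)

    p_clean, y_clean = predict(clean_loader)
    p_trig, y_trig = predict(triggered_loader)  # same images with the trigger pasted on

    clean_acc = (p_clean == y_clean).float().mean().item()
    backdoor_acc = (p_trig == y_trig).float().mean().item()
    non_target = y_trig != target_class          # ASR excludes images already in the target class
    asr = (p_trig[non_target] == target_class).float().mean().item()
    return clean_acc, backdoor_acc, asr
```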
--------------------------------------------------------------------------------
/assets/drupe_reference/CLIP/one.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/one.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/CLIP/priority.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/priority.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/CLIP/stop.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/stop.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/CLIP/truck.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/truck.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/CLIP/zero.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/CLIP/zero.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/gtsrb_l0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l0.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/gtsrb_l40.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l40.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/gtsrb_l5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/gtsrb_l5.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/one.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/one.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/priority.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/priority.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/priority_wzt_3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/priority_wzt_3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10/truck.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10/truck.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0_n10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n10.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0_n3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0_n3_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n3_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l0_n5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l0_n5.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/cifar10_l4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/cifar10_l4.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l0.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l0_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l0_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n10.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n2_rseed1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed1.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n2_rseed2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed2.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n2_rseed3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n2_rseed4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n2_rseed4.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n3_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n3_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l12_n5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l12_n5.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l24_n3_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l24_n3_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/gtsrb_l24_n3_224_rseed4.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/gtsrb_l24_n3_224_rseed4.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/imagenet/one.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/one.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/imagenet/priority.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/priority.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/imagenet/truck.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/imagenet/truck.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10/airplane.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/airplane.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10/one.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/one.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10/priority.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10/priority.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l0.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l0_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l0_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9_n10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n10.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9_n3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9_n3_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n3_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/stl10_l9_n5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/stl10_l9_n5.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l0.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l0_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l0_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1_n10.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n10.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1_n3.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n3.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1_n3_224.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n3_224.npz
--------------------------------------------------------------------------------
/assets/drupe_reference/svhn_l1_n5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/drupe_reference/svhn_l1_n5.npz
--------------------------------------------------------------------------------
/assets/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/image.png
--------------------------------------------------------------------------------
/assets/triggers/32471.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/32471.png
--------------------------------------------------------------------------------
/assets/triggers/NC_clean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/NC_clean.png
--------------------------------------------------------------------------------
/assets/triggers/apple_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/apple_white.png
--------------------------------------------------------------------------------
/assets/triggers/drupe_trigger/trigger_pt_white_173_50_ap_replace.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_173_50_ap_replace.npz
--------------------------------------------------------------------------------
/assets/triggers/drupe_trigger/trigger_pt_white_185_24.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_185_24.npz
--------------------------------------------------------------------------------
/assets/triggers/drupe_trigger/trigger_pt_white_21_10_ap_replace.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_21_10_ap_replace.npz
--------------------------------------------------------------------------------
/assets/triggers/drupe_trigger/trigger_pt_white_42_20_ap_replace.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/drupe_trigger/trigger_pt_white_42_20_ap_replace.npz
--------------------------------------------------------------------------------
/assets/triggers/firefox_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/firefox_32.png
--------------------------------------------------------------------------------
/assets/triggers/hellokitty_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/hellokitty_32.png
--------------------------------------------------------------------------------
/assets/triggers/phoenix_corner_96x96_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/phoenix_corner_96x96_256.png
--------------------------------------------------------------------------------
/assets/triggers/sig_trigger_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_128.png
--------------------------------------------------------------------------------
/assets/triggers/sig_trigger_224.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_224.png
--------------------------------------------------------------------------------
/assets/triggers/sig_trigger_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_32.png
--------------------------------------------------------------------------------
/assets/triggers/sig_trigger_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/sig_trigger_64.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_10.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_11.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_12.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_13.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_14.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_15.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_16.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_17.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_18.png
--------------------------------------------------------------------------------
/assets/triggers/trigger_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trigger_19.png
--------------------------------------------------------------------------------
/assets/triggers/trojan_watermark_224.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/trojan_watermark_224.png
--------------------------------------------------------------------------------
/assets/triggers/watermark_white_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/watermark_white_32.png
--------------------------------------------------------------------------------
/assets/triggers/white_10x10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/assets/triggers/white_10x10.png
--------------------------------------------------------------------------------
/configs/attacks/badencoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | BadEncoder attack configuration file.
4 |
5 | BadEncoder: an implementation of a backdoor attack against self-supervised learning encoders.
6 | """
7 |
8 | # Basic configuration
9 | config = {
10 |     'experiment_id': 'imagenet100_trigger-size-50',  # experiment ID
11 |     # Model parameters
12 |     'arch': 'resnet18',  # encoder architecture
13 |     'pretrained_encoder': '/workspace/SSL-Backdoor/ssl_backdoor/attacks/drupe/DRUPE/clean_encoder/model_1000.pth',  # path to the pre-trained encoder
14 |     'encoder_usage_info': 'imagenet',  # encoder usage info, used to decide which model to load
15 |     'batch_size': 64,  # batch size (default: 256)
16 |     'num_workers': 4,  # number of data-loading workers
17 |
18 |     # Data parameters
19 |     'image_size': 224,  # image size used for resizing
20 |     # trigger image configuration file
21 |     'trigger_file': 'assets/triggers/trigger_14.png',
22 |     'trigger_size': 50,
23 |
24 |     # Shadow data parameters
25 |     'shadow_dataset': 'imagenet100',
26 |     'shadow_file': 'data/ImageNet-100/10percent_trainset.txt',
27 |     'shadow_fraction': 0.2,  # default: 0.2
28 |     # Reference data parameters
29 |     # 'reference_file': 'assets/references/drupe_gtsrb_l12_n3/references.txt',  # reference data configuration file
30 |     'reference_file': '/workspace/SSL-Backdoor/assets/references/imagenet/references.txt',
31 |
32 |     'n_ref': 3,  # number of reference inputs
33 |     'downstream_dataset': 'imagenet100',
34 |
35 |
36 |     # Training parameters
37 |     'lr': 0.0001,  # learning rate (default: 0.05)
38 |     'momentum': 0.9,
39 |     'weight_decay': 5e-4,
40 |     'lambda1': 1.0,  # loss weight 1
41 |     'lambda2': 1.0,  # loss weight 2
42 |     'epochs': 120,  # number of training epochs
43 |     # 'lr_milestones': [60, 90],  # epochs at which to drop the learning rate
44 |     # 'lr_gamma': 0.1,  # learning-rate decay factor
45 |     'warm_up_epochs': 2,  # warm-up epochs
46 |     'print_freq': 40,  # print frequency
47 |     'save_freq': 5,  # checkpoint-saving frequency
48 |     'eval_freq': 5,  # evaluation frequency
49 |
50 |     # Downstream evaluation parameters
51 |     'nn_epochs': 100,  # downstream classifier training epochs
52 |     'hidden_size_1': 512,  # downstream classifier hidden layer 1 size
53 |     'hidden_size_2': 256,  # downstream classifier hidden layer 2 size
54 |     'batch_size_downstream': 64,  # downstream classifier batch size
55 |     'lr_downstream': 0.0001,  # downstream classifier learning rate
56 |
57 |     # System parameters
58 |     'seed': 42,  # random seed
59 |     'output_dir': '/workspace/SSL-Backdoor/results/test/badencoder',  # output directory
60 |
61 | }
--------------------------------------------------------------------------------
/configs/attacks/badencoder_test.yaml:
--------------------------------------------------------------------------------
1 | dataset: gtsrb
2 | attack_algorithm: badencoder
3 | return_attack_target: true # must be set to true when running badencoder
4 | attack_target: 12
5 | trigger_path: assets/triggers/white_10x10.png
6 | trigger_size: 10
7 | train_file: /workspace/dataset/gtsrb1/trainset.txt
8 | test_file: /workspace/dataset/gtsrb1/testset.txt
9 | trigger_insert: patch
10 | position: badencoder
11 |
12 |
13 |
--------------------------------------------------------------------------------
/configs/attacks/badencoder_train.yaml:
--------------------------------------------------------------------------------
1 | image_size: 32
2 | shadow_dataset: cifar10
3 | shadow_file: data/CIFAR10/trainset.txt
4 | memory_file: data/CIFAR10/trainset.txt
5 | reference_file: assets/references/drupe_gtsrb_l12_n3/references.txt
6 | trigger_file: assets/triggers/white_10x10.png
7 |
8 | shadow_fraction: 0.2
9 | trigger_size: 10
10 |
11 | # Parameters that rarely change
12 | n_ref: 3
--------------------------------------------------------------------------------
/configs/attacks/drupe.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | DRUPE attack configuration file.
4 |
5 | DRUPE: a backdoor attack based on distribution alignment and similarity regularization.
6 | """
7 |
8 | # Basic configuration
9 | config = {
10 |     # Model parameters
11 |     'arch': 'resnet18',  # encoder architecture
12 |     'pretrained_encoder': '/workspace/SSL-Backdoor/ssl_backdoor/attacks/drupe/DRUPE/clean_encoder/model_1000.pth',  # path to the pre-trained encoder
13 |     'encoder_usage_info': 'cifar10',  # encoder usage info, used to decide which model to load
14 |     'batch_size': 256,  # batch size
15 |     'num_workers': 4,  # number of data-loading workers
16 |
17 |     # Data parameters
18 |     'image_size': 32,  # image size used for resizing
19 |     # trigger image configuration file
20 |     'trigger_file': 'assets/triggers/white_10x10.png',
21 |     'trigger_size': 10,
22 |
23 |     # Shadow data parameters
24 |     'shadow_dataset': 'cifar10',
25 |     'shadow_file': 'data/CIFAR10/trainset.txt',
26 |     'shadow_fraction': 0.2,
27 |     # Reference data parameters
28 |     'reference_file': 'assets/references/drupe_gtsrb_l12_n3/references.txt',  # reference data configuration file
29 |     'reference_label': 12,  # reference label (target class)
30 |
31 |     'n_ref': 3,  # number of reference inputs
32 |     # Test data parameters
33 |     'downstream_dataset': 'gtsrb',
34 |
35 |     # DRUPE-specific parameters
36 |     'mode': 'drupe',  # attack mode: 'drupe', 'badencoder', 'wb'
37 |     'fix_epoch': 20,  # number of epochs with fixed parameters
38 |
39 |     # Training parameters
40 |     'lr': 0.05,  # learning rate
41 |     'momentum': 0.9,
42 |     'weight_decay': 5e-4,
43 |     'lambda1': 1.0,  # loss weight 1
44 |     'lambda2': 1.0,  # loss weight 2
45 |     'epochs': 120,  # number of training epochs
46 |     'warm_up_epochs': 2,  # warm-up epochs
47 |     'print_freq': 10,  # print frequency
48 |     'save_freq': 20,  # checkpoint-saving frequency
49 |
50 |     # Downstream evaluation parameters
51 |     'eval_downstream': True,  # whether to evaluate the downstream task
52 |     'nn_epochs': 120,  # downstream classifier training epochs
53 |     'hidden_size_1': 512,  # downstream classifier hidden layer 1 size
54 |     'hidden_size_2': 256,  # downstream classifier hidden layer 2 size
55 |     'batch_size_downstream': 64,  # downstream classifier batch size
56 |     'lr_downstream': 0.001,  # downstream classifier learning rate
57 |
58 |     # System parameters
59 |     'seed': 42,  # random seed
60 |     'output_dir': '/workspace/SSL-Backdoor/results/test/drupe',  # output directory
61 |     'experiment_id': 'test_cifar102gtsrb_trigger-size-10',  # experiment ID
62 | }
--------------------------------------------------------------------------------
/configs/attacks/drupe_test.yaml:
--------------------------------------------------------------------------------
1 | # DRUPE test configuration file
2 | dataset: gtsrb
3 | attack_algorithm: badencoder
4 | return_attack_target: true # must be set to true when running drupe
5 | attack_target: 12
6 | trigger_path: assets/triggers/white_10x10.png
7 | trigger_size: 10
8 | train_file: /workspace/dataset/gtsrb1/trainset.txt
9 | test_file: /workspace/dataset/gtsrb1/testset.txt
10 | trigger_insert: patch
11 | position: badencoder
--------------------------------------------------------------------------------
/configs/attacks/drupe_train.yaml:
--------------------------------------------------------------------------------
1 | # DRUPE training configuration file
2 | image_size: 32
3 | shadow_dataset: cifar10
4 | shadow_file: data/CIFAR10/trainset.txt
5 | memory_file: data/CIFAR10/trainset.txt
6 | test_file: data/CIFAR10/testset.txt
7 |
8 | # Reference input configuration
9 | reference_file: assets/references/drupe_gtsrb_l12_n3/references.txt
10 | trigger_file: assets/triggers/white_10x10.png
11 |
12 | # Attack parameters
13 | shadow_fraction: 0.2
14 | trigger_size: 10
15 | mode: drupe # attack mode: drupe, badencoder, wb
16 | n_ref: 3 # number of reference inputs
17 |
18 | # Model parameters
19 | encoder_usage_info: cifar10
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd.yaml:
--------------------------------------------------------------------------------
1 | attack_algorithm: sslbkd
2 | data: data/ImageNet-100/trainset.txt
3 | dataset: imagenet100
4 | save_poisons: True
5 | # poisons_saved_path: /workspace/SSL-Backdoor/results/test/simsiam_imagenet-100_test/poisons
6 |
7 | keep_poison_class: False
8 | attack_target_list:
9 | - 0
10 | trigger_path_list:
11 | - assets/triggers/trigger_14.png
12 | reference_dataset_file_list:
13 | - data/ImageNet-100/trainset.txt
14 | num_reference_list:
15 | - 650
16 | num_poison_list:
17 | - 650
18 |
19 | test_train_file: data/ImageNet-100/10percent_trainset.txt
20 | test_val_file: data/ImageNet-100/valset.txt
21 | test_attack_target: 0
22 | test_trigger_path: assets/triggers/trigger_14.png
23 | test_dataset: imagenet100
24 | test_trigger_size: 50
25 | test_trigger_insert: patch
26 | test_attack_algorithm: sslbkd
27 |
28 | return_attack_target: False
29 | attack_target_word: n01558993
30 | trigger_insert: patch
31 | trigger_size: 50
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_cifar10.yaml:
--------------------------------------------------------------------------------
1 | attack_algorithm: sslbkd
2 | data: data/CIFAR10/trainset.txt
3 | dataset: cifar10
4 | save_poisons: True
5 |
6 | keep_poison_class: False
7 | attack_target_list:
8 | - 0
9 | trigger_path_list:
10 | - assets/triggers/trigger_14.png
11 | reference_dataset_file_list:
12 | - data/CIFAR10/trainset.txt
13 | num_reference_list:
14 | - 2500
15 | num_poison_list:
16 | - 2500
17 |
18 | attack_target_word: n01558993
19 | trigger_insert: patch
20 | trigger_size: 8
21 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_cifar10_test.yaml:
--------------------------------------------------------------------------------
1 | # testset parameters
2 | train_file: data/CIFAR10/10percent_trainset.txt
3 | test_file: data/CIFAR10/testset.txt
4 | attack_target: 0
5 | trigger_path: assets/triggers/trigger_14.png
6 | dataset: cifar10
7 | trigger_size: 8
8 | trigger_insert: patch
9 | attack_algorithm: sslbkd
10 | return_attack_target: False
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_shadow_copy.yaml:
--------------------------------------------------------------------------------
1 | image_size: 224
2 | shadow_dataset: imagenet100
3 | shadow_file: data/ImageNet-100/trainset.txt
4 | memory_file: data/ImageNet-100/trainset.txt
--------------------------------------------------------------------------------
/configs/poisoning/poisoning_based/sslbkd_test.yaml:
--------------------------------------------------------------------------------
1 | dataset: imagenet100
2 | attack_algorithm: sslbkd
3 | return_attack_target: true # must be set to true when running badencoder
4 | attack_target: 6
5 | trigger_path: assets/triggers/trigger_14.png
6 | trigger_size: 50
7 | train_file: data/ImageNet-100/10percent_trainset.txt
8 | test_file: data/ImageNet-100/valset.txt
9 | trigger_insert: patch
10 |
11 |
--------------------------------------------------------------------------------
/configs/ssl/byol.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # BYOL default configuration file
3 |
4 | # Basic configuration
5 | config = {
6 |     # Common parameters
7 |     'method': 'byol',  # set to BYOL
8 |     'arch': 'resnet18',
9 |     'workers': 4,
10 |     'epochs': 300,
11 |     'start_epoch': 0,
12 |     'batch_size': 128,
13 |     'lr': 0.002,
14 |     'optimizer': 'adam',
15 |     'weight_decay': 1e-6,
16 |     'lr_schedule': 'step',  # 'step', 'cos'
17 |     'lr_drops': [250, 275],
18 |     'lr_drop_gamma': 0.2,
19 |     'print_freq': 10,
20 |     'resume': '',
21 |     'dist_url': 'tcp://localhost:10021',
22 |     'dist_backend': 'nccl',
23 |     'seed': None,
24 |     'multiprocessing_distributed': True,
25 |     'feature_dim': 512,  # feature dimension
26 |
27 |     # Attack-related parameters
28 |     'attack_algorithm': 'sslbkd',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
29 |     'ablation': False,
30 |
31 |     # BYOL-specific parameters
32 |     'byol_tau': 0.99,  # target-network momentum coefficient
33 |     'proj_dim': 1024,  # projector hidden dimension
34 |     'pred_dim': 128,  # predictor output dimension
35 |
36 |     # Mixed-precision training
37 |     'amp': True,
38 |
39 |     # Experiment logging
40 |     'experiment_id': 'byol_imagenet-100_test',
41 |     'save_folder_root': '/workspace/SSL-Backdoor/results/test',
42 |     'save_freq': 30,
43 |     'eval_frequency': 30,
44 |
45 |     # Logger configuration
46 |     'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
47 | }
--------------------------------------------------------------------------------
/configs/ssl/moco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # MoCo default configuration file
3 |
4 | # Basic configuration
5 | config = {
6 |     # Common parameters
7 |     'method': 'moco',
8 |     'arch': 'resnet18',
9 |     'workers': 4,
10 |     'epochs': 300,
11 |     'start_epoch': 0,
12 |     'batch_size': 256,
13 |     'lr': 0.06,
14 |     'optimizer': 'sgd',
15 |     'momentum': 0.9,
16 |     'weight_decay': 1e-4,
17 |     'lr_schedule': 'cos',
18 |     'print_freq': 50,
19 |     'resume': '',
20 |     'dist_url': 'tcp://localhost:10001',
21 |     'dist_backend': 'nccl',
22 |     'seed': None,
23 |     'multiprocessing_distributed': True,
24 |     'feature_dim': 128,
25 |
26 |     # Attack-related parameters
27 |     'ablation': False,
28 |
29 |     # MoCo-specific parameters
30 |     'moco_k': 65536,
31 |     'moco_m': 0.999,
32 |     'moco_contr_w': 1,
33 |     'moco_contr_tau': 0.2,
34 |     'moco_align_w': 0,
35 |     'moco_align_alpha': 2,
36 |     'moco_unif_w': 0,
37 |     'moco_unif_t': 3,
38 |
39 |     # # Dataset configuration
40 |     # 'dataset': 'imagenet-100',
41 |     # 'data': '/workspace/SSL-Backdoor/data/ImageNet-100/trainset.txt',
42 |
43 |     # Mixed-precision training
44 |     'amp': True,
45 |
46 |     # Experiment logging
47 |     'experiment_id': 'moco_imagenet-100_test',
48 |     'save_folder_root': '/workspace/SSL-Backdoor/results/test',
49 |     'save_freq': 30,
50 |
51 |     # Logger configuration
52 |     'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
53 |
54 |     # Attack target class (if needed)
55 |     'attack_target_list': [0]
56 | }
57 |
--------------------------------------------------------------------------------
/configs/ssl/simclr.py:
--------------------------------------------------------------------------------
1 | # SimCLR default configuration file
2 |
3 | # Basic configuration
4 | config = {
5 |     # Common parameters
6 |     'method': 'simclr',  # set to SimCLR
7 |     'arch': 'resnet18',
8 |     'feature_dim': 512,  # feature dimension
9 |     'workers': 4,
10 |     'epochs': 300,
11 |     'start_epoch': 0,
12 |     'batch_size': 256,
13 |     'optimizer': 'sgd',
14 |     'lr': 0.5,
15 |     'momentum': 0.9,
16 |     'weight_decay': 1e-4,
17 |     'lr_schedule': 'cos',
18 |     'print_freq': 10,
19 |     'resume': '',
20 |     'dist_url': 'tcp://localhost:10013',
21 |     'dist_backend': 'nccl',
22 |     'seed': 42,
23 |     'multiprocessing_distributed': True,
24 |
25 |
26 |     # Attack-related parameters
27 |     'attack_algorithm': 'sslbkd',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
28 |     'ablation': False,
29 |
30 |     # SimCLR-specific parameters
31 |     'proj_dim': 128,  # projector output dimension
32 |     'temperature': 0.5,  # temperature for the NT-Xent loss
33 |
34 |     # Mixed-precision training
35 |     'amp': True,
36 |
37 |     # Experiment logging
38 |     'experiment_id': 'simclr_imagenet-100_test',
39 |     'save_folder_root': '/workspace/SSL-Backdoor/results/test',
40 |     'save_freq': 30,
41 |     'eval_frequency': 30,
42 |
43 |     # Logger configuration
44 |     'logger_type': 'wandb',  # 'tensorboard', 'wandb', 'none'
45 | }
--------------------------------------------------------------------------------
/configs/ssl/simsiam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # SimSiam default configuration file
3 | 
4 | # Basic configuration
5 | config = {
6 |     # common parameters
7 |     'method': 'simsiam',  # set to SimSiam
8 | 'arch': 'resnet18',
9 | 'workers': 4,
10 | 'epochs': 300,
11 | 'start_epoch': 0,
12 | 'batch_size': 256,
13 | 'optimizer': 'sgd',
14 | 'lr': 0.1,
15 | 'momentum': 0.9,
16 | 'weight_decay': 1e-4,
17 | 'lr_schedule': 'cos',
18 | 'print_freq': 10,
19 | 'resume': '',
20 | 'dist_url': 'tcp://localhost:10025',
21 | 'dist_backend': 'nccl',
22 | 'seed': None,
23 | 'multiprocessing_distributed': True,
24 |     'feature_dim': 2048,  # note: the SimSiam paper uses 2048
25 | 
26 |     # attack-related parameters
27 |     'attack_algorithm': 'bp',  # 'bp', 'corruptencoder', 'sslbkd', 'ctrl', 'clean', 'blto', 'optimized'
28 |     'ablation': False,
29 | 
30 |     # SimSiam-specific parameters
31 |     'pred_dim': 512,  # hidden dimension of the predictor
32 |     'fix_pred_lr': True,  # whether to use a fixed learning rate for the predictor
33 | 
34 |     # data-augmentation parameters
35 |     'min_crop_scale': 0.8,  # minimum scale for RandomResizedCrop
36 | 
37 |     # mixed-precision training
38 |     'amp': True,
39 | 
40 |     # experiment logging
41 |     'experiment_id': 'na_simsiam_imagenet100_target6',  # updated experiment ID
42 |     'save_folder_root': '/workspace/SSL-Backdoor/results/test',
43 |     'save_freq': 30,
44 |     'eval_frequency': 30,
45 | 
46 |     # logger configuration
47 | 'logger_type': 'wandb', # 'tensorboard', 'wandb', 'none'
48 | }
--------------------------------------------------------------------------------
/docs/zh_cn/patchsearch.md:
--------------------------------------------------------------------------------
1 | # PatchSearch Defense: Implementation Notes
2 | 
3 | This document describes the implementation details and usage of the PatchSearch defense integrated into the SSL-Backdoor library. PatchSearch detects backdoor attacks in self-supervised learning (SSL) models, in particular patch-based attacks.
4 | 
5 | The implementation follows the official PatchSearch code and is integrated into the SSL-Backdoor framework to ease research and evaluation.
6 | 
7 | ## Core Functionality
8 | 
9 | The core logic of PatchSearch lives under `SSL-Backdoor/ssl_backdoor/defenses/patchsearch/`. The main entry points are defined in `__init__.py`:
10 | 
11 | 1. **`run_patchsearch`**:
12 |     * **Purpose**: Runs the first stage of PatchSearch: feature clustering plus an iterative patch search that assigns every training sample a "poison score", surfacing the most suspicious samples.
13 |     * **Main parameters**:
14 |         * `args`: dict or Namespace with additional configuration (e.g. `poison_label`).
15 |         * `model`/`weights_path`: the pretrained SSL model to inspect, or the path to its weights.
16 |         * `train_file`/`suspicious_dataset`: path to the file list of suspicious training samples, or a `Dataset` object.
17 |         * `dataset_name`: dataset name (e.g. 'imagenet100', 'cifar10').
18 |         * `output_dir`: output directory for the results.
19 |         * `experiment_id`: experiment identifier, used to create a subdirectory.
20 |         * `arch`: model architecture (e.g. 'resnet18').
21 |         * `num_clusters`: number of feature clusters.
22 |         * `use_cached_feats`: whether to reuse cached features (speeds up repeated runs).
23 |         * `use_cached_poison_scores`: whether to reuse cached poison scores (speeds up repeated runs).
24 |     * **Return value**: a dict with the detection results; its key entries are:
25 |         * `poison_scores`: NumPy array with the poison score of every sample.
26 |         * `sorted_indices`: NumPy array of sample indices sorted by poison score in descending order.
27 |         * `is_poison`: boolean NumPy array marking whether each sample is considered poisoned (based on whether its file path contains `poison_label`).
28 |         * `topk_accuracy`: fraction of true poisons among the top-k ranked samples, for several values of `k`.
29 |         * `output_dir`: the directory where the results were actually written.
30 |         * `status` (optional): if `use_cached_feats=False` and this is the first run, the function may return "CACHED_FEATURES", meaning features were extracted and cached and a second run is needed to perform the detection.
31 |
32 | 2. **`run_patchsearch_filter`**:
33 |     * **Purpose**: Runs the (optional) second stage of PatchSearch. Using the poison scores from stage one and the extracted suspicious patches, it trains a simple classifier to further filter potential poisons and writes out a cleaned dataset file (see the sketch after this list).
34 |     * **Main parameters**:
35 |         * `poison_scores`/`poison_scores_path`: the poison-score array from stage one, or the path to its `.npy` file.
36 |         * `output_dir`: output directory, usually the same as the `experiment_dir` of stage one.
37 |         * `train_file`: path to the original training file.
38 |         * `poison_dir`: directory containing the "poison patch" images extracted from suspicious samples (usually `output_dir/all_top_poison_patches`).
39 |         * `topk_poisons`: how many of the highest-scoring samples to use when training the filtering classifier.
40 |         * `top_p`: fraction of the original dataset (ranked by poison score) used to train the classifier.
41 |         * `model_count`: number of ensemble models to train for robustness.
42 |     * **Return value**: path to the newly generated filtered training file (`.txt`), containing only the samples the classifier considers "clean".
43 |
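The exact signatures live in `ssl_backdoor/defenses/patchsearch/__init__.py`; the sketch below only strings together the parameters and return keys documented above, with hypothetical paths, and should be treated as an assumption rather than the canonical API.

```python
# Hedged sketch of driving both PatchSearch stages programmatically.
# Paths are hypothetical; consult run_patchsearch's actual signature in
# ssl_backdoor/defenses/patchsearch/__init__.py before relying on keyword names.
from ssl_backdoor.defenses.patchsearch import run_patchsearch, run_patchsearch_filter

args = {'poison_label': 'poison'}  # substring that marks poisoned file paths

results = run_patchsearch(
    args=args,
    weights_path='/path/to/moco_checkpoint.pth.tar',   # hypothetical path
    train_file='/path/to/trainset.txt',
    dataset_name='imagenet100',
    arch='resnet18',
    num_clusters=100,
    output_dir='/path/to/patchsearch_output',
    experiment_id='demo',
)

if results.get('status') == 'CACHED_FEATURES':
    # Features were just extracted and cached; rerun with use_cached_feats=True.
    raise SystemExit('Re-run with use_cached_feats=True to compute poison scores.')

filtered_file = run_patchsearch_filter(
    poison_scores=results['poison_scores'],
    output_dir=results['output_dir'],
    train_file='/path/to/trainset.txt',
    poison_dir=results['output_dir'] + '/all_top_poison_patches',
    topk_poisons=20,
    top_p=0.10,
    model_count=5,
)
print('filtered training file:', filtered_file)
```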
44 | ## Usage
45 | 
46 | The example script `SSL-Backdoor/tools/run_patchsearch.py` demonstrates how to use the PatchSearch defense.
47 | 
48 | **Steps:**
49 | 
50 | 1. **Prepare the configuration files** (a hedged sketch of a base configuration follows these steps):
51 |     * **Base configuration (`--config`)**: configures PatchSearch itself: the model weights path (`weights_path`), output directory (`output_dir`), number of clusters (`num_clusters`), batch size (`batch_size`), number of workers (`num_workers`), and whether to use caches (`use_cached_feats`, `use_cached_poison_scores`). This file should also contain the path to the original (possibly poisoned) training file (`train_file`). To run the second-stage filter, add a `filter` dict with its parameters (e.g. `topk_poisons`, `top_p`, `model_count`).
52 |     * **Attack configuration (`--attack_config`)**: *optional but recommended*. It mainly defines how the dataset is loaded, which matters when the training data was generated for a specific attack (e.g. SSLBKD, CTRL). `run_patchsearch.py` tries to load dataset information from this config. If the attack config sets `save_poisons=True`, the script redirects it to the `train_file` from the base config and disables `save_poisons`, so that the full, possibly poisoned dataset is loaded.
53 | 
54 | 2. **Run the script**:
55 |     ```bash
56 |     python tools/run_patchsearch.py \
57 |         --config \
58 |         --attack_config \
59 |         [--output_dir ] \
60 |         [--experiment_id ] \
61 |         [--skip_filter]  # optional: add this to run only the first-stage detection
62 |     ```
63 |
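The document does not pin down the base configuration's schema, so the following is only an assumed sketch built from the keys named above; any field or value not mentioned in the text is hypothetical.

```python
# Assumed sketch of a base configuration for tools/run_patchsearch.py,
# written as a Python dict in the style of configs/ssl/*.py.
config = {
    'weights_path': '/path/to/encoder_checkpoint.pth.tar',   # hypothetical path
    'train_file': '/path/to/possibly_poisoned_trainset.txt',
    'output_dir': '/path/to/patchsearch_output',
    'dataset_name': 'imagenet100',
    'arch': 'resnet18',
    'num_clusters': 100,
    'batch_size': 64,
    'num_workers': 8,
    'use_cached_feats': False,            # set to True on the second run
    'use_cached_poison_scores': False,

    # optional second-stage filter settings
    'filter': {
        'topk_poisons': 20,
        'top_p': 0.10,
        'model_count': 5,
    },
}
```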
64 | **Example flow:**
65 | 
66 | * The script first loads the base configuration and the attack configuration.
67 | * Command-line arguments override the corresponding settings in the config files.
68 | * **Key point**: the script uses the data-loading settings from `attack_config` (combined with `train_file` and `batch_size`/`num_workers` from the base config) to build `suspicious_dataset`. **Note**: in the current example script the code that loads `suspicious_dataset` is commented out (`# poison_loader = create_data_loader(...)`), so by default `suspicious_dataset` is `None`. `run_patchsearch` handles this case and loads the dataset directly from `train_file`. If you need the more elaborate loading logic defined in `attack_config` (e.g. specific augmentations), uncomment and adapt the relevant code in `run_patchsearch.py`.
69 | * `run_patchsearch` is called to perform the first-stage detection.
70 | * If the returned status is `CACHED_FEATURES`, the script asks you to set `use_cached_feats=True` in the base config and rerun.
71 | * It prints the top-10 most suspicious samples (index, poison score, and whether they really are poisoned).
72 | * If `--skip_filter` is not given and the base config contains a `filter` section, `run_patchsearch_filter` is called to run the second-stage filtering.
73 | * It prints filtering statistics (number and percentage of removed samples) and the path of the final cleaned dataset file.
74 |
75 | ## Example Results
76 | 
77 | We ran this implementation as a defense against an SSLBKD attack on a MoCo v2 model trained on ImageNet-100. The detailed run log is available at:
78 | 
79 | `SSL-Backdoor/docs/zh_cn/patchsearch.log`
80 | 
81 | The log records the key outputs of a PatchSearch run, including feature extraction, clustering, poison-score computation and the final detection accuracy, and can serve as a reference for expected behaviour.
82 | 
83 | ## Output File Structure
84 | 
85 | After running PatchSearch, the directory `output_dir/experiment_id/` contains the following main files and directories:
86 | 
87 | * `feats.npy`: extracted training-set features (written on the first run when `use_cached_feats=False`).
88 | * `poison-scores.npy`: the computed poison score for every sample.
89 | * `sorted_indices.npy`: sample indices sorted by poison score.
90 | * `patchsearch_results.log`: detailed log of the PatchSearch run.
91 | * `all_top_poison_patches/`: (if patch extraction was run) candidate "poison patch" images extracted from the most suspicious samples.
92 | * `filter_results/`: (if the second-stage filter was run) files produced by the filtering stage.
93 |     * `filtered_train_file.txt`: the cleaned training file list produced by filtering.
94 |     * `poison_classifier.log`: log of training and evaluating the poison classifier.
95 |     * ... (may also contain trained model weights, etc.)
96 | 
97 | The filtered `filtered_train_file.txt` can then be used to retrain the SSL model, hopefully yielding a model that is more robust to the backdoor attack.
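As a rough illustration (the path below is hypothetical), the cleaned list can be inspected or fed back into an SSL training config; each line is assumed to follow the `<image_path> <label>` format read by the bundled `FileListDataset`.

```python
# Minimal sketch: inspect the cleaned list and reuse it as a train_file.
# Each line is assumed to be "<image_path> <label>", as FileListDataset expects.
filtered = '/path/to/patchsearch_output/demo/filter_results/filtered_train_file.txt'  # hypothetical path

with open(filtered) as f:
    samples = [line.strip().split() for line in f if line.strip()]
print(f'{len(samples)} samples survived filtering')

# e.g. point an SSL training config at the cleaned list:
# config['train_file'] = filtered
```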
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.0
2 | torch==2.0.1
3 | torchvision==0.16.0
4 | torchaudio==2.1.0.dev20230719
5 | torchmetrics==1.3.2
6 | torch-tb-profiler==0.4.3
7 | scikit-learn==1.3.2
8 | matplotlib==3.6.3
9 | pandas==1.4.4
10 | tqdm==4.65.0
11 | requests==2.28.2
12 | Pillow==9.4.0
13 | lightly==1.5.13
14 | opencv-python==4.8.1.78
15 | opencv-python-headless==4.9.0.80
16 |
--------------------------------------------------------------------------------
/ssl_backdoor/attacks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Attack package containing various attack methods against self-supervised learning models.
3 | """
4 |
--------------------------------------------------------------------------------
/ssl_backdoor/attacks/badencoder/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | BadEncoder attack module
3 | 
4 | BadEncoder: a backdoor attack implementation targeting self-supervised encoders
5 | """
6 |
7 | import ssl_backdoor.attacks.badencoder.datasets
8 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import Trigger_Dataset, ReferenceObjectDataset
2 | from .var import dataset_params
3 |
4 | from .utils import add_watermark, concatenate_images
5 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/corruptencoder_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageFont, ImageFilter, ImageColor
2 | import os
3 | import cv2
4 | import re
5 | import sys
6 | import glob
7 | import errno
8 | import random
9 | import numpy as np
10 |
11 | def get_trigger(trigger_size=40, trigger_path=None, colorful_trigger=True):
12 | # load trigger
13 | if colorful_trigger:
14 | trigger = Image.open(trigger_path).convert('RGB')
15 | trigger = trigger.resize((trigger_size, trigger_size))
16 | else:
17 | trigger = Image.new("RGB", (trigger_size, trigger_size), ImageColor.getrgb("white"))
18 | return trigger
19 |
20 |
21 | def binary_mask_to_box(binary_mask):
22 | binary_mask = np.array(binary_mask, np.uint8)
23 | contours,hierarchy = cv2.findContours(
24 | binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
25 | areas = []
26 | for cnt in contours:
27 | area = cv2.contourArea(cnt)
28 | areas.append(area)
29 | idx = areas.index(np.max(areas))
30 | x, y, w, h = cv2.boundingRect(contours[idx])
31 | bounding_box = [x, y, x+w, y+h]
32 | return bounding_box
33 |
34 | # def get_foreground(reference_dir, num_references, max_size, type):
35 | # img_idx = random.choice(range(1, 1+num_references))
36 | # image_path = os.path.join(reference_dir, f'{img_idx}/img.png')
37 | # mask_path = os.path.join(reference_dir, f'{img_idx}/label.png')
38 | # image_np = np.asarray(Image.open(image_path).convert('RGB'))
39 | # mask_np = np.asarray(Image.open(mask_path).convert('RGB'))
40 | # mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask
41 |
42 | # # crop masked region
43 | # bbx = binary_mask_to_box(mask_np)
44 | # object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
45 | # object_image = Image.fromarray(object_image)
46 | # object_mask = mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
47 | # object_mask = Image.fromarray(object_mask)
48 |
49 | # # resize -> avoid poisoned image being too large
50 | # w, h = object_image.size
51 | # if type=='horizontal':
52 | # o_w = min(w, int(max_size/2))
53 | # o_h = int((o_w/w) * h)
54 | # elif type=='vertical':
55 | # o_h = min(h, int(max_size/2))
56 | # o_w = int((o_h/h) * w)
57 | # object_image = object_image.resize((o_w, o_h))
58 | # object_mask = object_mask.resize((o_w, o_h))
59 | # return object_image, object_mask
60 |
61 | def get_foreground(reference_image_path, max_size, type):
62 | mask_path = reference_image_path.replace('img.png', 'label.png')
63 | image_np = np.asarray(Image.open(reference_image_path).convert('RGB'))
64 | mask_np = np.asarray(Image.open(mask_path).convert('RGB'))
65 | mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask
66 |
67 | # crop masked region
68 | bbx = binary_mask_to_box(mask_np)
69 | object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
70 | object_image = Image.fromarray(object_image)
71 | object_mask = mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
72 | object_mask = Image.fromarray(object_mask)
73 |
74 | # resize -> avoid poisoned image being too large
75 | w, h = object_image.size
76 | if type=='horizontal':
77 | o_w = min(w, int(max_size/2))
78 | o_h = int((o_w/w) * h)
79 | elif type=='vertical':
80 | o_h = min(h, int(max_size/2))
81 | o_w = int((o_h/h) * w)
82 | object_image = object_image.resize((o_w, o_h))
83 | object_mask = object_mask.resize((o_w, o_h))
84 | return object_image, object_mask
85 |
86 | def concat(support_reference_image_path, reference_image_path, max_size):
87 | ### horizontally concat two images
88 | # get support reference image
89 | support_reference_image = Image.open(support_reference_image_path)
90 | width, height = support_reference_image.size
91 | n_w = min(width, int(max_size/2))
92 | n_h = int((n_w/width) * height)
93 | support_reference_image = support_reference_image.resize((n_w, n_h))
94 | width, height = support_reference_image.size
95 |
96 | # get reference image
97 | reference_image = Image.open(reference_image_path)
98 | reference_image = reference_image.resize((width, height))
99 |
100 | img_new = Image.new("RGB", (width*2, height), "white")
101 | if random.random()<0.5:
102 | img_new.paste(support_reference_image, (0, 0))
103 | img_new.paste(reference_image, (width, 0))
104 | else:
105 | img_new.paste(reference_image, (0, 0))
106 | img_new.paste(support_reference_image, (width, 0))
107 | return img_new
108 |
109 |
110 | def get_random_reference_image(reference_dir, num_references):
111 | img_idx = random.choice(range(1, 1+num_references))
112 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png')
113 | return image_path
114 |
115 | def get_random_support_reference_image(reference_dir):
116 | support_dir = os.path.join(reference_dir, 'support-images')
117 | image_path = os.path.join(support_dir, random.choice(os.listdir(support_dir)))
118 | return image_path
119 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/generators.py:
--------------------------------------------------------------------------------
1 | '''
2 | Courtesy of: https://github.com/Muzammal-Naseer/Cross-domain-perturbations
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | import functools
10 | from torch.autograd import Variable
11 |
12 | ###########################
13 | # Generator: Resnet
14 | ###########################
15 |
16 | # To control feature map in generator
17 | ngf = 64
18 |
19 | class GeneratorResnet(nn.Module):
20 | def __init__(self, inception=False, dim="high"):
21 | '''
22 |         :param inception: if True, a crop layer is added to go from 3x300x300 to 3x299x299.
23 |         :param dim: for a high-dimensional dataset (ImageNet) 6 residual blocks are added, otherwise only 2.
24 | '''
25 | super(GeneratorResnet, self).__init__()
26 | self.inception = inception
27 | self.dim = dim
28 | # Input_size = 3, n, n
29 | self.block1 = nn.Sequential(
30 | nn.ReflectionPad2d(3),
31 | nn.Conv2d(3, ngf, kernel_size=7, padding=0, bias=False),
32 | nn.BatchNorm2d(ngf),
33 | nn.ReLU(True)
34 | )
35 |
36 | # Input size = 3, n, n
37 | self.block2 = nn.Sequential(
38 | nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
39 | nn.BatchNorm2d(ngf * 2),
40 | nn.ReLU(True)
41 | )
42 |
43 | # Input size = 3, n/2, n/2
44 | self.block3 = nn.Sequential(
45 | nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
46 | nn.BatchNorm2d(ngf * 4),
47 | nn.ReLU(True)
48 | )
49 |
50 | # Input size = 3, n/4, n/4
51 | # Residual Blocks: 6
52 | self.resblock1 = ResidualBlock(ngf * 4)
53 | self.resblock2 = ResidualBlock(ngf * 4)
54 |
55 | if self.dim == "high":
56 | self.resblock3 = ResidualBlock(ngf * 4)
57 | self.resblock4 = ResidualBlock(ngf * 4)
58 | self.resblock5 = ResidualBlock(ngf * 4)
59 | self.resblock6 = ResidualBlock(ngf * 4)
60 | else:
61 |             print("Using the low-dimensional generator variant (2 residual blocks).")
62 |
63 |
64 | # Input size = 3, n/4, n/4
65 | self.upsampl1 = nn.Sequential(
66 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
67 | nn.BatchNorm2d(ngf * 2),
68 | nn.ReLU(True)
69 | )
70 |
71 | # Input size = 3, n/2, n/2
72 | self.upsampl2 = nn.Sequential(
73 | nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
74 | nn.BatchNorm2d(ngf),
75 | nn.ReLU(True)
76 | )
77 |
78 | # Input size = 3, n, n
79 | self.blockf = nn.Sequential(
80 | nn.ReflectionPad2d(3),
81 | nn.Conv2d(ngf, 3, kernel_size=7, padding=0)
82 | )
83 |
84 |
85 | self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
86 |
87 | def forward(self, input):
88 |
89 | x = self.block1(input)
90 | x = self.block2(x)
91 | x = self.block3(x)
92 | x = self.resblock1(x)
93 | x = self.resblock2(x)
94 | if self.dim == "high":
95 | x = self.resblock3(x)
96 | x = self.resblock4(x)
97 | x = self.resblock5(x)
98 | x = self.resblock6(x)
99 | x = self.upsampl1(x)
100 | x = self.upsampl2(x)
101 | x = self.blockf(x)
102 | if self.inception:
103 | x = self.crop(x)
104 |
105 | return (torch.tanh(x) + 1) / 2 # Output range [0 1]
106 |
107 |
108 | class GeneratorAdv(nn.Module):
109 | def __init__(self, eps=8/255):
110 | '''
111 |         :param eps: maximum perturbation magnitude (L-infinity style budget).
112 |         The learnable perturbation has shape (1, 3, 32, 32) and is scaled by eps before being added to the input.
113 | '''
114 | super(GeneratorAdv, self).__init__()
115 | self.perturbation = torch.randn(size=(1, 3, 32, 32))
116 | self.perturbation = nn.Parameter(self.perturbation, requires_grad=True)
117 | self.eps = eps
118 |
119 | def forward(self, input):
120 | # perturbation = (torch.tanh(self.perturbation) + 1) / 2
121 | return input + self.perturbation * self.eps # Output range [0 1]
122 |
123 |
124 | class Generator_Patch(nn.Module):
125 | def __init__(self, size=10):
126 | '''
127 |         :param size: height and width of the learnable square patch.
128 |         The patch is written onto a random location of the input in the forward pass.
129 | '''
130 | super(Generator_Patch, self).__init__()
131 | self.perturbation = torch.randn(size=(1, 3, size, size))
132 | self.perturbation = nn.Parameter(self.perturbation, requires_grad=True)
133 |
134 | def forward(self, input):
135 | # perturbation = (torch.tanh(self.perturbation) + 1) / 2
136 | random_x = np.random.randint(0, input.shape[-1] - self.perturbation.shape[-1])
137 | random_y = np.random.randint(0, input.shape[-1] - self.perturbation.shape[-1])
138 | input[:, :, random_x:random_x + self.perturbation.shape[-1], random_y:random_y + self.perturbation.shape[-1]] = self.perturbation
139 | return input # Output range [0 1]
140 |
141 |
142 | class ResidualBlock(nn.Module):
143 | def __init__(self, num_filters):
144 | super(ResidualBlock, self).__init__()
145 | self.block = nn.Sequential(
146 | nn.ReflectionPad2d(1),
147 | nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0,
148 | bias=False),
149 | nn.BatchNorm2d(num_filters),
150 | nn.ReLU(True),
151 |
152 | nn.Dropout(0.5),
153 |
154 | nn.ReflectionPad2d(1),
155 | nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, stride=1, padding=0,
156 | bias=False),
157 | nn.BatchNorm2d(num_filters)
158 | )
159 |
160 | def forward(self, x):
161 | residual = self.block(x)
162 | return x + residual
163 |
164 | if __name__ == '__main__':
165 | netG = GeneratorResnet()
166 | test_sample = torch.rand(1, 3, 640, 480)
167 | My_output = test_sample
168 | print('Generator output:', netG(test_sample).size())
169 | print('Generator parameters:', sum(p.numel() for p in netG.parameters() if p.requires_grad)/1000000)
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/cifar10_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | automobile
3 | bird
4 | cat
5 | deer
6 | dog
7 | frog
8 | horse
9 | ship
10 | truck
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/cifar10_metadata.txt:
--------------------------------------------------------------------------------
1 | airplane 0
2 | automobile 1
3 | bird 2
4 | cat 3
5 | deer 4
6 | dog 5
7 | frog 6
8 | horse 7
9 | ship 8
10 | truck 9
11 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/class_index.txt:
--------------------------------------------------------------------------------
1 | n01558993 0
2 | n01692333 1
3 | n01729322 2
4 | n01735189 3
5 | n01749939 4
6 | n01773797 5
7 | n01820546 6
8 | n01855672 7
9 | n01978455 8
10 | n01980166 9
11 | n01983481 10
12 | n02009229 11
13 | n02018207 12
14 | n02085620 13
15 | n02086240 14
16 | n02086910 15
17 | n02087046 16
18 | n02089867 17
19 | n02089973 18
20 | n02090622 19
21 | n02091831 20
22 | n02093428 21
23 | n02099849 22
24 | n02100583 23
25 | n02104029 24
26 | n02105505 25
27 | n02106550 26
28 | n02107142 27
29 | n02108089 28
30 | n02109047 29
31 | n02113799 30
32 | n02113978 31
33 | n02114855 32
34 | n02116738 33
35 | n02119022 34
36 | n02123045 35
37 | n02138441 36
38 | n02172182 37
39 | n02231487 38
40 | n02259212 39
41 | n02326432 40
42 | n02396427 41
43 | n02483362 42
44 | n02488291 43
45 | n02701002 44
46 | n02788148 45
47 | n02804414 46
48 | n02859443 47
49 | n02869837 48
50 | n02877765 49
51 | n02974003 50
52 | n03017168 51
53 | n03032252 52
54 | n03062245 53
55 | n03085013 54
56 | n03259280 55
57 | n03379051 56
58 | n03424325 57
59 | n03492542 58
60 | n03494278 59
61 | n03530642 60
62 | n03584829 61
63 | n03594734 62
64 | n03637318 63
65 | n03642806 64
66 | n03764736 65
67 | n03775546 66
68 | n03777754 67
69 | n03785016 68
70 | n03787032 69
71 | n03794056 70
72 | n03837869 71
73 | n03891251 72
74 | n03903868 73
75 | n03930630 74
76 | n03947888 75
77 | n04026417 76
78 | n04067472 77
79 | n04099969 78
80 | n04111531 79
81 | n04127249 80
82 | n04136333 81
83 | n04229816 82
84 | n04238763 83
85 | n04336792 84
86 | n04418357 85
87 | n04429376 86
88 | n04435653 87
89 | n04485082 88
90 | n04493381 89
91 | n04517823 90
92 | n04589890 91
93 | n04592741 92
94 | n07714571 93
95 | n07715103 94
96 | n07753275 95
97 | n07831146 96
98 | n07836838 97
99 | n13037406 98
100 | n13040303 99
101 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/imagenet100_classes.txt:
--------------------------------------------------------------------------------
1 | n02869837
2 | n01749939
3 | n02488291
4 | n02107142
5 | n13037406
6 | n02091831
7 | n04517823
8 | n04589890
9 | n03062245
10 | n01773797
11 | n01735189
12 | n07831146
13 | n07753275
14 | n03085013
15 | n04485082
16 | n02105505
17 | n01983481
18 | n02788148
19 | n03530642
20 | n04435653
21 | n02086910
22 | n02859443
23 | n13040303
24 | n03594734
25 | n02085620
26 | n02099849
27 | n01558993
28 | n04493381
29 | n02109047
30 | n04111531
31 | n02877765
32 | n04429376
33 | n02009229
34 | n01978455
35 | n02106550
36 | n01820546
37 | n01692333
38 | n07714571
39 | n02974003
40 | n02114855
41 | n03785016
42 | n03764736
43 | n03775546
44 | n02087046
45 | n07836838
46 | n04099969
47 | n04592741
48 | n03891251
49 | n02701002
50 | n03379051
51 | n02259212
52 | n07715103
53 | n03947888
54 | n04026417
55 | n02326432
56 | n03637318
57 | n01980166
58 | n02113799
59 | n02086240
60 | n03903868
61 | n02483362
62 | n04127249
63 | n02089973
64 | n03017168
65 | n02093428
66 | n02804414
67 | n02396427
68 | n04418357
69 | n02172182
70 | n01729322
71 | n02113978
72 | n03787032
73 | n02089867
74 | n02119022
75 | n03777754
76 | n04238763
77 | n02231487
78 | n03032252
79 | n02138441
80 | n02104029
81 | n03837869
82 | n03494278
83 | n04136333
84 | n03794056
85 | n03492542
86 | n02018207
87 | n04067472
88 | n03930630
89 | n03584829
90 | n02123045
91 | n04229816
92 | n02100583
93 | n03642806
94 | n04336792
95 | n03259280
96 | n02116738
97 | n02108089
98 | n03424325
99 | n01855672
100 | n02090622
101 |
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/stl10_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | bird
3 | car
4 | cat
5 | deer
6 | dog
7 | horse
8 | monkey
9 | ship
10 | truck
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/metadata/stl10_metadata.txt:
--------------------------------------------------------------------------------
1 | airplane 0
2 | bird 1
3 | car 2
4 | cat 3
5 | deer 4
6 | dog 5
7 | horse 6
8 | monkey 7
9 | ship 8
10 | truck 9
--------------------------------------------------------------------------------
/ssl_backdoor/datasets/var.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 | # Dataset parameter configuration
4 | dataset_params = {
5 | 'cc3m': {
6 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
7 | 'image_size': 224
8 | },
9 | 'cc3m_small': {
10 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
11 | 'image_size': 32
12 | },
13 | 'imagenet': {
14 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
15 | 'image_size': 224,
16 | 'num_classes': 1000,
17 | },
18 | 'imagenet100': {
19 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20 | 'image_size': 224,
21 | 'num_classes': 100,
22 | },
23 | 'cifar10': {
24 | 'normalize': transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
25 | 'image_size': 32,
26 | 'num_classes': 10,
27 | 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
28 | },
29 | 'stl10': {
30 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31 | 'image_size': 96,
32 | 'num_classes': 10,
33 | },
34 | 'gtsrb': {
35 | 'normalize': transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
36 | 'image_size': 32,
37 | 'num_classes': 43,
38 | 'classes': ['Speed limit (20km/h)', 'Speed limit (30km/h)', 'Speed limit (50km/h)', 'Speed limit (60km/h)', 'Speed limit (70km/h)',
39 | 'Speed limit (80km/h)', 'End of speed limit (80km/h)', 'Speed limit (100km/h)', 'Speed limit (120km/h)', 'No passing',
40 | 'No passing for vehicles over 3.5 metric tons', 'Right-of-way at the next intersection', 'Priority road', 'Yield', 'Stop',
41 | 'No vehicles', 'Vehicles over 3.5 metric tons prohibited', 'No entry', 'General caution', 'Dangerous curve to the left',
42 | 'Dangerous curve to the right', 'Double curve', 'Bumpy road', 'Slippery road', 'Road narrows on the right', 'Road work',
43 | 'Traffic signals', 'Pedestrians', 'Children crossing', 'Bicycles crossing', 'Beware of ice/snow', 'Wild animals crossing',
44 | 'End of all speed and passing limits', 'Turn right ahead', 'Turn left ahead', 'Ahead only', 'Go straight or right', 'Go straight or left',
45 | 'Keep right', 'Keep left', 'Roundabout mandatory', 'End of no passing', 'End of no passing by vehicles over 3.5 metric tons']
46 | }
47 | }
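`dataset_params` is a plain dict keyed by dataset name. A minimal sketch of consuming it to build an evaluation transform is shown below; the actual pipelines used elsewhere in the repository may differ.

```python
# Minimal sketch (assumption): build an eval transform from dataset_params.
from torchvision import transforms
from ssl_backdoor.datasets.var import dataset_params


def build_eval_transform(dataset_name: str) -> transforms.Compose:
    params = dataset_params[dataset_name]
    return transforms.Compose([
        transforms.Resize((params['image_size'], params['image_size'])),
        transforms.ToTensor(),
        params['normalize'],          # per-dataset normalization
    ])


tf = build_eval_transform('cifar10')
print(tf)
```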
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_185_24.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_185_24.npz
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_21_10_ap_replace.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/defenses/decree/trigger/trigger_pt_white_21_10_ap_replace.npz
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Collection of utility functions for the PatchSearch defense.
3 | """
4 |
5 | from .gradcam import run_gradcam
6 | from .clustering import faiss_kmeans, KMeansLinear, Normalize, FullBatchNorm
7 | from .patch_operations import (
8 | denormalize, paste_patch, block_max_window, extract_max_window,
9 | get_candidate_patches, save_patches
10 | )
11 | from .model_utils import get_model, get_feats, get_channels
12 | from .dataset import FileListDataset, get_transforms, get_test_images
13 |
14 | __all__ = [
15 | # gradcam
16 | 'run_gradcam',
17 |
18 | # clustering
19 | 'faiss_kmeans', 'KMeansLinear', 'Normalize', 'FullBatchNorm',
20 |
21 | # patch_operations
22 | 'denormalize', 'paste_patch', 'block_max_window', 'extract_max_window',
23 | 'get_candidate_patches', 'save_patches',
24 |
25 | # model_utils
26 | 'get_model', 'get_feats', 'get_channels',
27 |
28 | # dataset
29 | 'FileListDataset', 'get_transforms', 'get_test_images'
30 | ]
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/clustering.py:
--------------------------------------------------------------------------------
1 | """
2 | Clustering utilities used for feature clustering in the PatchSearch defense.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import faiss
9 | from sklearn.metrics import pairwise_distances
10 |
11 |
12 | def faiss_kmeans(train_feats, nmb_clusters):
13 | """
14 |     Run k-means clustering on features with FAISS.
15 | 
16 |     Args:
17 |         train_feats: feature tensor of shape [N, D]
18 |         nmb_clusters: number of clusters
19 | 
20 |     Returns:
21 |         train_d: distance of each sample to its nearest cluster centre
22 |         train_a: cluster assignment of each sample
23 |         index: the FAISS index object
24 |         centroids: the cluster centres
25 | """
26 | train_feats = train_feats.numpy()
27 | d = train_feats.shape[-1]
28 |
29 | clus = faiss.Clustering(d, nmb_clusters)
30 | clus.niter = 20
31 | clus.max_points_per_centroid = 10000000
32 |
33 | index = faiss.IndexFlatL2(d)
34 | co = faiss.GpuMultipleClonerOptions()
35 | co.useFloat16 = True
36 | co.shard = True
37 | index = faiss.index_cpu_to_all_gpus(index, co)
38 |
39 |     # run the clustering
40 | clus.train(train_feats, index)
41 | train_d, train_a = index.search(train_feats, 1)
42 |
43 | return train_d, train_a, index, clus.centroids
44 |
45 |
46 | class KMeansLinear(nn.Module):
47 | """
48 |     Linear classifier whose weights are the cluster centres.
49 | """
50 | def __init__(self, train_a, train_val_feats, num_clusters):
51 | """
52 |         Initialize the k-means linear classifier.
53 | 
54 |         Args:
55 |             train_a: cluster assignments
56 |             train_val_feats: training features
57 |             num_clusters: number of clusters
58 | """
59 | super().__init__()
60 | clusters = []
61 | for i in range(num_clusters):
62 |             # use the mean feature of each cluster as its centre
63 | cluster = train_val_feats[train_a == i].mean(dim=0)
64 | clusters.append(cluster)
65 | self.classifier = nn.Parameter(torch.stack(clusters))
66 |
67 | def forward(self, x):
68 | """
69 |         Forward pass: compute the similarity between the input features and the cluster centres.
70 | 
71 |         Args:
72 |             x: input features of shape [B, D]
73 | 
74 |         Returns:
75 |             similarity scores of shape [B, num_clusters]
76 | """
77 | c = self.classifier
78 | c = c / c.norm(2, dim=1, keepdim=True)
79 | x = x / x.norm(2, dim=1, keepdim=True)
80 | return x @ c.T
81 |
82 |
83 | class Normalize(nn.Module):
84 | """
85 |     Feature normalization (L2) layer.
86 | """
87 | def forward(self, x):
88 | return x / x.norm(2, dim=1, keepdim=True)
89 |
90 |
91 | class FullBatchNorm(nn.Module):
92 | """
93 |     Full-batch normalization layer (uses a precomputed mean and variance).
94 | """
95 | def __init__(self, var, mean):
96 | super(FullBatchNorm, self).__init__()
97 | self.register_buffer('inv_std', (1.0 / torch.sqrt(var + 1e-5)))
98 | self.register_buffer('mean', mean)
99 |
100 | def forward(self, x):
101 | return (x - self.mean) * self.inv_std
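A short usage sketch for the utilities above, assuming a GPU-enabled faiss build (required because `faiss_kmeans` clones the index to all GPUs) and using random features in place of real encoder outputs.

```python
# Hedged usage sketch for faiss_kmeans and KMeansLinear.
import torch
from ssl_backdoor.defenses.patchsearch.utils.clustering import faiss_kmeans, KMeansLinear

feats = torch.randn(1000, 512)
feats = feats / feats.norm(2, dim=1, keepdim=True)           # L2-normalize, as PatchSearch does

train_d, train_a, index, centroids = faiss_kmeans(feats, nmb_clusters=10)
assignments = torch.from_numpy(train_a.squeeze(1)).long()    # [N] cluster ids

classifier = KMeansLinear(assignments, feats, num_clusters=10)
scores = classifier(feats[:8])                               # [8, 10] cosine similarities to centres
print(scores.argmax(dim=1))
```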
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset utilities for data loading and processing in the PatchSearch defense.
3 | """
4 |
5 | import torch
6 | from torch.utils.data import Dataset, DataLoader, Subset
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 |
12 | class FileListDataset(Dataset):
13 | """
14 |     Dataset loaded from a file list.
15 | """
16 | def __init__(self, path_to_txt_file, transform, poison_label='poison'):
17 | """
18 |         Initialize the dataset.
19 | 
20 |         Args:
21 |             path_to_txt_file: text file containing image paths and labels
22 |             transform: image transform; poison_label is the substring that marks a poisoned file path
23 | """
24 | with open(path_to_txt_file, 'r') as f:
25 | lines = f.readlines()
26 | samples = [line.strip().split() for line in lines]
27 | samples = [(pth, int(target)) for pth, target in samples]
28 |
29 | self.samples = samples
30 | self.transform = transform
31 | self.classes = list(sorted(set(y for _, y in self.samples)))
32 | self.poison_label = poison_label
33 |
34 | def __getitem__(self, idx):
35 | """
36 |         Fetch one sample from the dataset.
37 | 
38 |         Args:
39 |             idx: sample index
40 | 
41 |         Returns:
42 |             image: image tensor
43 |             target: target label
44 |             is_poisoned: whether the sample is poisoned
45 |             idx: sample index
46 | """
47 | image_path, target = self.samples[idx]
48 | img = Image.open(image_path).convert('RGB')
49 |
50 |         # fall back to the raw PIL image when no transform is provided
51 |         image = self.transform(img) if self.transform is not None else img
52 |
53 | is_poisoned = self.poison_label in image_path
54 |
55 | return image, target, is_poisoned, idx
56 |
57 | def __len__(self):
58 | """
59 |         Return the size of the dataset.
60 | """
61 | return len(self.samples)
62 |
63 |
64 | def get_transforms(dataset_name, image_size):
65 | """
66 |     Get the image transform for a specific dataset.
67 | 
68 |     Args:
69 |         dataset_name: dataset name
70 |         image_size: image size
71 | 
72 |     Returns:
73 |         val_transform: the image transform
74 | """
75 | if dataset_name == 'imagenet100':
76 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
77 | std=[0.229, 0.224, 0.225])
78 | elif dataset_name == 'cifar10':
79 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
80 | std=[0.2023, 0.1994, 0.2010])
81 | elif dataset_name == 'stl10':
82 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
83 | std=[0.229, 0.224, 0.225])
84 | else:
85 | raise ValueError(f"Unknown dataset '{dataset_name}'")
86 |
87 | if image_size > 200:
88 | val_transform = transforms.Compose([
89 | transforms.Resize(256, interpolation=3),
90 | transforms.CenterCrop(224),
91 | transforms.ToTensor(),
92 | normalize
93 | ])
94 | else:
95 | val_transform = transforms.Compose([
96 | transforms.Resize(image_size, interpolation=3),
97 | transforms.ToTensor(),
98 | normalize
99 | ])
100 |
101 | return val_transform
102 |
103 |
104 | def get_test_images(train_val_dataset, cluster_wise_i, test_images_size):
105 | """
106 |     Collect test images.
107 | 
108 |     Args:
109 |         train_val_dataset: the combined train/val dataset
110 |         cluster_wise_i: sample indices for each cluster
111 |         test_images_size: number of test images
112 | 
113 |     Returns:
114 |         test_images: tensor of test images
115 |         test_images_i: indices of the test images
116 | """
117 | import numpy as np
118 | import torch
119 |
120 | test_images_i = []
121 | k = test_images_size // len(cluster_wise_i)
122 | if k > 0:
123 | for inds in cluster_wise_i:
124 | test_images_i.extend(inds[:k])
125 | else:
126 | for clust_i in np.random.permutation(len(cluster_wise_i))[:test_images_size]:
127 | test_images_i.append(cluster_wise_i[clust_i][0])
128 |
129 | test_images_dataset = Subset(
130 | train_val_dataset, torch.tensor(test_images_i)
131 | )
132 | test_images_loader = DataLoader(
133 | test_images_dataset,
134 | shuffle=False, batch_size=64,
135 | num_workers=8, pin_memory=True
136 | )
137 |
138 | test_images = []
139 | for inp, _, _, _ in tqdm(test_images_loader):
140 | test_images.append(inp)
141 | test_images = torch.cat(test_images)
142 | return test_images, test_images_i
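A brief usage sketch, assuming an `imagenet100`-style file list in `<path> <label>` format (the path is hypothetical).

```python
# Hedged usage sketch for FileListDataset and get_transforms.
from torch.utils.data import DataLoader
from ssl_backdoor.defenses.patchsearch.utils.dataset import FileListDataset, get_transforms

transform = get_transforms('imagenet100', image_size=224)
dataset = FileListDataset('/path/to/trainset.txt', transform, poison_label='poison')  # hypothetical path
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

images, targets, is_poisoned, indices = next(iter(loader))
print(images.shape, targets.shape, int(is_poisoned.sum()))
```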
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation utilities for the PatchSearch defense.
3 | """
4 |
5 | import time
6 | import torch
7 |
8 |
9 | class AverageMeter(object):
10 |     """Computes and stores the average and current value."""
11 | def __init__(self, name, fmt=':f'):
12 | self.name = name
13 | self.fmt = fmt
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 | def __str__(self):
29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
30 | return fmtstr.format(**self.__dict__)
31 |
32 |
33 | class ProgressMeter(object):
34 |     """Displays progress information on the console."""
35 | def __init__(self, num_batches, meters, prefix=""):
36 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
37 | self.meters = meters
38 | self.prefix = prefix
39 |
40 | def display(self, batch):
41 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
42 | entries += [str(meter) for meter in self.meters]
43 | return '\t'.join(entries)
44 |
45 | def _get_batch_fmtstr(self, num_batches):
46 | num_digits = len(str(num_batches // 1))
47 | fmt = '{:' + str(num_digits) + 'd}'
48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
49 |
50 |
51 | def accuracy(output, target, topk=(1,)):
52 |     """Compute the top-k accuracy."""
53 | with torch.no_grad():
54 | maxk = max(topk)
55 | batch_size = target.size(0)
56 |
57 | _, pred = output.topk(maxk, 1, True, True)
58 | pred = pred.t()
59 | correct = pred.eq(target.view(1, -1).expand_as(pred))
60 |
61 | res = []
62 | for k in topk:
63 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
64 | res.append(correct_k.mul_(100.0 / batch_size))
65 | return res
66 |
67 |
68 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
69 |     """Save a checkpoint."""
70 | torch.save(state, filename)
71 | if is_best:
72 | import shutil
73 | shutil.copyfile(filename, 'model_best.pth.tar')
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/gradcam.py:
--------------------------------------------------------------------------------
1 | """
2 | GradCAM utilities used by the PatchSearch defense to locate potential backdoor triggers.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from pytorch_grad_cam import GradCAM
8 | from pytorch_grad_cam.utils.image import show_cam_on_image
9 |
10 |
11 | def reshape_transform(tensor, height=14, width=14):
12 | """
13 |     Reshape transform for attention maps of ViT models.
14 | """
15 | result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
16 |     # move the channel dimension to the front, as in a CNN
17 | result = result.transpose(2, 3).transpose(1, 2)
18 | return result
19 |
20 |
21 | def run_gradcam(arch, model, inp, targets=None):
22 | """
23 |     Run GradCAM on the input images and produce heatmaps.
24 | 
25 |     Args:
26 |         arch: model architecture name
27 |         model: PyTorch model
28 |         inp: input image tensor
29 |         targets: optional target classes
30 | 
31 |     Returns:
32 |         grayscale_cam: grayscale heatmaps
33 |         out: model outputs
34 | """
35 | if 'vit' in arch:
36 | return run_vit_gradcam(model, [model.blocks[-1].norm1], inp, targets)
37 | else:
38 | return run_cnn_gradcam(model, [model.layer4], inp, targets)
39 |
40 |
41 | def run_cnn_gradcam(model, target_layers, inp, targets=None):
42 | """
43 |     Run GradCAM on a CNN model.
44 | """
45 |     # remember the parameters whose requires_grad we toggle, along with their original state
46 | params_to_restore = []
47 |
48 |     # enable requires_grad for every parameter in the target layers
49 | for layer in target_layers:
50 | for param in layer.parameters():
51 | if not param.requires_grad:
52 | params_to_restore.append((param, param.requires_grad))
53 | param.requires_grad_(True)
54 |
55 | try:
56 | with GradCAM(model=model, target_layers=target_layers, use_cuda=True) as cam:
57 | cam.batch_size = 32
58 | grayscale_cam, out = cam(input_tensor=inp, targets=targets)
59 | return grayscale_cam, out
60 | finally:
61 |         # restore the original requires_grad state of all parameters
62 | for param, orig_requires_grad in params_to_restore:
63 | param.requires_grad_(orig_requires_grad)
64 |
65 |
66 | def run_vit_gradcam(model, target_layers, inp, targets=None):
67 | """
68 |     Run GradCAM on a ViT model.
69 | """
70 | with GradCAM(model=model, target_layers=target_layers,
71 | reshape_transform=reshape_transform, use_cuda=True) as cam:
72 | cam.batch_size = 32
73 | grayscale_cam, out = cam(input_tensor=inp, targets=targets)
74 | return grayscale_cam, out
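A hedged usage sketch for `run_gradcam`; it presumes a `pytorch_grad_cam` version compatible with the calls above (a `use_cuda` keyword and a `(cam, out)` return) and a CUDA device.

```python
# Hedged usage sketch for run_gradcam (depends on the pytorch_grad_cam version
# assumed by run_cnn_gradcam above: use_cuda kwarg and (cam, out) return).
import torch
import torchvision.models as models
from ssl_backdoor.defenses.patchsearch.utils.gradcam import run_gradcam

model = models.resnet18().cuda().eval()     # any model exposing .layer4, as run_cnn_gradcam expects
images = torch.rand(4, 3, 224, 224).cuda()

grayscale_cam, out = run_gradcam('resnet18', model, images)
print(grayscale_cam.shape)                  # one heatmap per input image
```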
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Model utilities for model loading and feature extraction in the PatchSearch defense.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | import torchvision.models as models
9 | import sys
10 |
11 | from ssl_backdoor.utils.model_utils import load_model_weights
12 |
13 |
14 | def get_model(arch, wts_path, dataset_name):
15 | """
16 |     Load a pretrained model.
17 | 
18 |     Args:
19 |         arch: model architecture name
20 |         wts_path: path to the weights file
21 |         dataset_name: dataset name
22 | 
23 |     Returns:
24 |         the loaded model
25 | """
26 | if 'moco' in arch:
27 | model = models.__dict__[arch.replace('moco_', '')]()
28 | if 'imagenet' not in dataset_name:
29 | print("Using custom conv1 for small datasets")
30 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
31 | if dataset_name == "cifar10" or dataset_name == "cifar100":
32 | print("Using custom maxpool for cifar datasets")
33 | model.maxpool = nn.Identity()
34 | model.fc = nn.Sequential()
35 |
36 | sd = torch.load(wts_path)['state_dict']
37 | sd = {k.replace('module.', ''): v for k, v in sd.items()}
38 |
39 | sd = {k: v for k, v in sd.items() if 'encoder_q' in k or 'base_encoder' in k or 'backbone' in k or 'encoder' in k}
40 | sd = {k: v for k, v in sd.items() if 'fc' not in k}
41 |
42 | sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()}
43 | sd = {k.replace('base_encoder.', ''): v for k, v in sd.items()}
44 | sd = {k.replace('backbone.', ''): v for k, v in sd.items()}
45 | sd = {k.replace('encoder.', ''): v for k, v in sd.items()}
46 | model.load_state_dict(sd, strict=True)
47 |
48 | elif 'resnet' in arch:
49 | model = models.__dict__[arch]()
50 | model.fc = nn.Sequential()
51 |         load_model_weights(model, wts_path)
52 |
53 | else:
54 | raise ValueError('arch not found: ' + arch)
55 |
56 | model = model.eval()
57 |
58 | return model
59 |
60 |
61 | def get_feats(model, loader):
62 | """
63 |     Extract features from a data loader.
64 | 
65 |     Args:
66 |         model: the model
67 |         loader: the data loader
68 | 
69 |     Returns:
70 |         feats: extracted features
71 |         labels: corresponding labels
72 |         is_poisoned: whether each sample is poisoned
73 |         indices: sample indices
74 | """
75 | model = nn.DataParallel(model).cuda()
76 | model.eval()
77 | feats, labels, indices, is_poisoned = [], [], [], []
78 | for data in tqdm(loader):
79 | if len(data) == 4:
80 | images, targets, is_p, inds = data
81 |         else:
82 |             raise ValueError('expected batches of (image, target, is_poisoned, index); got a %d-tuple' % len(data))
83 | with torch.no_grad():
84 | feats.append(model(images.cuda()).cpu())
85 | labels.append(targets)
86 | indices.append(inds)
87 | is_poisoned.append(is_p)
88 | feats = torch.cat(feats)
89 | labels = torch.cat(labels)
90 | indices = torch.cat(indices)
91 | is_poisoned = torch.cat(is_poisoned)
92 | feats /= feats.norm(2, dim=-1, keepdim=True)
93 | return feats, labels, is_poisoned, indices
94 |
95 |
96 | def get_channels(arch):
97 | """
98 |     Get the number of output channels of a model.
99 | 
100 |     Args:
101 |         arch: model architecture name
102 | 
103 |     Returns:
104 |         the number of output channels
105 | """
106 | if 'resnet50' in arch:
107 | c = 2048
108 | elif 'resnet18' in arch:
109 | c = 512
110 | else:
111 | raise ValueError('arch not found: ' + arch)
112 | return c
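A usage sketch tying `get_model`/`get_feats` to the `FileListDataset` utilities; the checkpoint and file-list paths are hypothetical and a GPU is assumed (features are extracted through `nn.DataParallel`).

```python
# Hedged usage sketch for get_model / get_feats (paths are hypothetical).
from torch.utils.data import DataLoader
from ssl_backdoor.defenses.patchsearch.utils.dataset import FileListDataset, get_transforms
from ssl_backdoor.defenses.patchsearch.utils.model_utils import get_model, get_feats, get_channels

model = get_model('moco_resnet18', '/path/to/checkpoint.pth.tar', 'imagenet100')
print('feature channels:', get_channels('resnet18'))          # 512 for ResNet-18

dataset = FileListDataset('/path/to/trainset.txt', get_transforms('imagenet100', 224))
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

feats, labels, is_poisoned, indices = get_feats(model, loader)
print(feats.shape)                                            # [N, 512] L2-normalized features
```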
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/patch_operations.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch-related utilities for patch extraction and manipulation in the PatchSearch defense.
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from tqdm import tqdm
9 | import numpy as np
10 | from PIL import Image
11 | from .gradcam import run_gradcam
12 |
13 |
14 | def denormalize(x, dataset):
15 | """
16 |     Undo the normalization of an image.
17 | 
18 |     Args:
19 |         x: normalized image tensor
20 |         dataset: dataset name
21 | 
22 |     Returns:
23 |         denormalized image tensor with values in [0, 1]
24 | """
25 | if x.shape[0] == 3:
26 | x = x.permute((1, 2, 0))
27 |
28 | if dataset == 'imagenet100':
29 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
30 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
31 | elif dataset == 'cifar10':
32 | mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device)
33 | std = torch.tensor([0.2023, 0.1994, 0.2010], device=x.device)
34 | elif dataset == 'stl10':
35 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
36 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
37 | else:
38 | raise ValueError(f"Unknown dataset '{dataset}'")
39 | x = ((x * std) + mean)
40 |
41 | x = torch.clamp(x, 0, 1)
42 | return x
43 |
44 |
45 | def paste_patch(inputs, patch):
46 | """
47 |     Paste a patch at a random location of each input image.
48 | 
49 |     Args:
50 |         inputs: input image tensor of shape [B, 3, H, W]
51 |         patch: patch tensor of shape [3, h, w]
52 | 
53 |     Returns:
54 |         image tensor with the patch pasted in
55 | """
56 | B = inputs.shape[0]
57 | inp_w = inputs.shape[-1]
58 | window_w = patch.shape[-1]
59 | ij = torch.randint(low=0, high=(inp_w - window_w), size=(B, 2))
60 | i, j = ij[:, 0], ij[:, 1]
61 |
62 |     # build row and column indices for every position inside the window
63 | s = torch.arange(window_w, device=inputs.device)
64 | ri = i.view(B, 1).repeat(1, window_w)
65 | rj = j.view(B, 1).repeat(1, window_w)
66 | sri, srj = ri + s, rj + s
67 |
68 |     # repeat the starting row index across columns, and vice versa
69 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w)
70 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1)
71 |
72 |     # these are 2D indices; convert them to 1D indices
73 | inds = xi * inp_w + xj
74 |
75 |     # repeat the indices across the colour channels
76 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1)
77 |
78 |     # flatten the patch from 2D to 1D and repeat it across the batch dimension
79 | patch = patch.reshape(3, -1).unsqueeze(0).repeat(B, 1, 1)
80 |
81 |     # flatten the images to 1D, scatter the patch in, then reshape back to 2D
82 | inputs = inputs.reshape(B, 3, -1)
83 | inputs.scatter_(dim=2, index=inds, src=patch)
84 | inputs = inputs.reshape(B, 3, inp_w, inp_w)
85 | return inputs
86 |
87 |
88 | def block_max_window(cam_images, inputs, window_w=30):
89 | """
90 | 屏蔽输入图像中GradCAM激活最强的区域
91 |
92 | 参数:
93 | cam_images: GradCAM生成的热力图
94 | inputs: 输入图像tensor
95 | window_w: 窗口大小
96 |
97 | 返回:
98 | 屏蔽后的图像tensor
99 | """
100 | B, _, inp_w = cam_images.shape
101 | grayscale_cam = torch.from_numpy(cam_images)
102 | inputs = inputs.clone()
103 | sum_conv = torch.ones((1, 1, window_w, window_w))
104 |
105 |     # compute the sum over every window
106 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv)
107 |
108 |     # flatten the sums and take the argmax
109 | flat_sums_cam = sums_cam.view(B, -1)
110 | ij = flat_sums_cam.argmax(dim=-1)
111 |
112 |     # split into row and column indices; this gives the top-left corner of the window
113 | sums_cam_w = sums_cam.shape[-1]
114 | i, j = ij // sums_cam_w, ij % sums_cam_w
115 |
116 |     # build row and column indices for every position inside the window
117 | s = torch.arange(window_w, device=inputs.device)
118 | ri = i.view(B, 1).repeat(1, window_w)
119 | rj = j.view(B, 1).repeat(1, window_w)
120 | sri, srj = ri + s, rj + s
121 |
122 |     # repeat the starting row index across columns, and vice versa
123 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w)
124 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1)
125 |
126 |     # these are 2D indices; convert them to 1D indices
127 | inds = xi * inp_w + xj
128 |
129 |     # repeat the indices across the colour channels
130 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1)
131 |
132 |     # flatten the images to 1D, zero out the window locations, then reshape back to 2D
133 | inputs = inputs.reshape(B, 3, -1)
134 | inputs.scatter_(dim=2, index=inds, value=0)
135 | inputs = inputs.reshape(B, 3, inp_w, inp_w)
136 | return inputs
137 |
138 |
139 | def extract_max_window(cam_images, inputs, window_w=30):
140 | """
141 | 从输入图像中提取GradCAM激活最强的区域
142 |
143 | 参数:
144 | cam_images: GradCAM生成的热力图
145 | inputs: 输入图像tensor
146 | window_w: 窗口大小
147 |
148 | 返回:
149 | 提取的窗口tensor
150 | """
151 | B, _, inp_w = cam_images.shape
152 | grayscale_cam = torch.from_numpy(cam_images)
153 | inputs = inputs.clone()
154 | sum_conv = torch.ones((1, 1, window_w, window_w))
155 |
156 |     # compute the sum over every window
157 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv)
158 |
159 |     # flatten the sums and take the argmax
160 | flat_sums_cam = sums_cam.view(B, -1)
161 | ij = flat_sums_cam.argmax(dim=-1)
162 |
163 |     # split into row and column indices; this gives the top-left corner of the window
164 | sums_cam_w = sums_cam.shape[-1]
165 | i, j = ij // sums_cam_w, ij % sums_cam_w
166 |
167 |     # build row and column indices for every position inside the window
168 | s = torch.arange(window_w, device=inputs.device)
169 | ri = i.view(B, 1).repeat(1, window_w)
170 | rj = j.view(B, 1).repeat(1, window_w)
171 | sri, srj = ri + s, rj + s
172 |
173 |     # repeat the starting row index across columns, and vice versa
174 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w)
175 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1)
176 |
177 |     # these are 2D indices; convert them to 1D indices
178 | inds = xi * inp_w + xj
179 |
180 |     # repeat the indices across the colour channels
181 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1)
182 |
183 |     # flatten the images from 2D to 1D
184 | inputs = inputs.reshape(B, 3, -1)
185 |
186 |     # gather the windows and reshape from 1D back to 2D
187 | windows = torch.gather(inputs, dim=2, index=inds)
188 | windows = windows.reshape(B, 3, window_w, window_w)
189 |
190 | return windows
191 |
192 |
193 | def get_candidate_patches(model, loader, arch, window_w, repeat_patch):
194 | """
195 | 从数据加载器中获取候选补丁
196 |
197 | 参数:
198 | model: 模型
199 | loader: 数据加载器
200 | arch: 模型架构
201 | window_w: 窗口大小
202 | repeat_patch: 每张图像中提取的补丁数量
203 |
204 | 返回:
205 | 候选补丁tensor列表
206 | """
207 | candidate_patches = []
208 | for inp, _, _, _ in tqdm(loader):
209 | windows = []
210 | for _ in range(repeat_patch):
211 | cam_images, _ = run_gradcam(arch, model, inp)
212 | windows.append(extract_max_window(cam_images, inp, window_w))
213 | block_max_window(cam_images, inp, int(window_w * .5))
214 | windows = torch.stack(windows)
215 | windows = torch.einsum('kb...->bk...', windows)
216 | candidate_patches.append(windows.detach().cpu())
217 | candidate_patches = torch.cat(candidate_patches)
218 | return candidate_patches
219 |
220 |
221 | def save_patches(windows, save_dir, dataset):
222 | """
223 | 保存提取的补丁
224 |
225 | 参数:
226 | windows: 提取的补丁tensor
227 | save_dir: 保存目录
228 | dataset: 数据集名称
229 | """
230 | import os
231 | os.makedirs(save_dir, exist_ok=True)
232 |
233 | for i, win in enumerate(windows):
234 | win = denormalize(win, dataset)
235 | win = (win * 255).clamp(0, 255).numpy().astype(np.uint8)
236 | win = Image.fromarray(win)
237 | win.save(os.path.join(save_dir, f'{i:05d}.png'))
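A hedged sketch of how these pieces fit together: GradCAM-guided candidate patches are extracted from suspicious images and one candidate is pasted onto other images, which is the core operation behind the PatchSearch poison score. Paths are hypothetical and a GPU is assumed.

```python
# Hedged sketch of the patch workflow (hypothetical paths, GPU assumed).
import torch
from torch.utils.data import DataLoader
from ssl_backdoor.defenses.patchsearch.utils.dataset import FileListDataset, get_transforms
from ssl_backdoor.defenses.patchsearch.utils.model_utils import get_model
from ssl_backdoor.defenses.patchsearch.utils.patch_operations import (
    get_candidate_patches, paste_patch, save_patches,
)

model = get_model('moco_resnet18', '/path/to/checkpoint.pth.tar', 'imagenet100').cuda()
dataset = FileListDataset('/path/to/suspicious_subset.txt', get_transforms('imagenet100', 224))
loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)

# extract GradCAM-guided candidate patches from each suspicious image
candidates = get_candidate_patches(model, loader, arch='resnet18', window_w=60, repeat_patch=1)
patch = candidates[0, 0]                                   # [3, 60, 60]

# paste the candidate onto held-out images at random locations
test_images = torch.stack([dataset[i][0] for i in range(8)])
patched = paste_patch(test_images.clone(), patch)

save_patches(candidates[:8, 0], '/tmp/patches_preview', dataset='imagenet100')
```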
--------------------------------------------------------------------------------
/ssl_backdoor/defenses/patchsearch/utils/visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization utilities for the PatchSearch defense.
3 | """
4 |
5 | import os
6 | import torch
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | from PIL import Image
10 |
11 |
12 | def denormalize(x, args):
13 | """
14 |     Convert a normalized image tensor back to a regular image tensor.
15 | 
16 |     Args:
17 |         x: normalized image tensor
18 |         args: configuration object containing dataset_name
19 | 
20 |     Returns:
21 |         denormalized image tensor
22 | """
23 |     if x.dim() == 4:  # batched input
24 | return torch.stack([denormalize(x_i, args) for x_i in x])
25 |
26 | if x.shape[0] == 3: # CHW -> HWC
27 | x = x.permute((1, 2, 0))
28 |
29 | if "imagenet" in args.dataset_name:
30 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
31 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
32 | elif "cifar10" in args.dataset_name:
33 | mean = torch.tensor([0.4914, 0.4822, 0.4465], device=x.device)
34 | std = torch.tensor([0.2023, 0.1994, 0.2010], device=x.device)
35 | elif "stl10" in args.dataset_name:
36 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
37 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
38 | else:
39 |         raise ValueError(f"Unknown dataset '{args.dataset_name}'")
40 |
41 | x = ((x * std) + mean)
42 | x = torch.clamp(x, 0, 1)
43 |
44 | return x
45 |
46 |
47 | def save_image(img_tensor, path, args=None):
48 | """
49 |     Save a single image tensor as a PNG file.
50 | 
51 |     Args:
52 |         img_tensor: image tensor of shape [C, H, W] or [H, W, C]
53 |         path: output path
54 |         args: configuration object
55 | """
56 | if args is not None:
57 | img_tensor = denormalize(img_tensor, args)
58 |
59 | if img_tensor.dim() == 3 and img_tensor.shape[0] == 3: # CHW -> HWC
60 | img_tensor = img_tensor.permute(1, 2, 0)
61 |
62 |     # convert to a numpy array, then save
63 | img_np = (img_tensor.detach().cpu().numpy() * 255).astype(np.uint8)
64 | img = Image.fromarray(img_np)
65 | img.save(path)
66 |
67 |
68 | def show_images_grid(inp, save_dir, title, args=None, max_images=40, nrows=8, ncols=5):
69 | """
70 |     Render a grid of images and save it as a PNG file.
71 | 
72 |     Args:
73 |         inp: image tensor of shape [B, C, H, W]
74 |         save_dir: output directory
75 |         title: figure title
76 |         args: configuration object
77 |         max_images: maximum number of images to show
78 |         nrows: number of rows
79 |         ncols: number of columns
80 | """
81 |     # limit the number of images
82 | inp = inp[:max_images]
83 | n_images = inp.shape[0]
84 |
85 |     # compute the required number of rows/columns
86 | if n_images < nrows * ncols:
87 | nrows = (n_images + ncols - 1) // ncols
88 |
89 |     # create the image grid
90 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*2, nrows*2))
91 |
92 | for img_idx in range(n_images):
93 | if nrows == 1:
94 | ax = axes[img_idx % ncols]
95 | elif ncols == 1:
96 | ax = axes[img_idx % nrows]
97 | else:
98 | ax = axes[img_idx // ncols][img_idx % ncols]
99 |
100 |         # denormalize and display the image
101 | if args is not None:
102 | rgb_image = denormalize(inp[img_idx], args).detach().cpu().numpy()
103 | else:
104 | rgb_image = inp[img_idx].detach().cpu().numpy()
105 | if rgb_image.shape[0] == 3: # CHW -> HWC
106 | rgb_image = rgb_image.transpose(1, 2, 0)
107 |
108 | ax.imshow(rgb_image)
109 | ax.set_xticks([])
110 | ax.set_yticks([])
111 |
112 |     # hide the axes of unused subplots
113 | for img_idx in range(n_images, nrows * ncols):
114 | if nrows == 1:
115 | ax = axes[img_idx % ncols]
116 | elif ncols == 1:
117 | ax = axes[img_idx % nrows]
118 | else:
119 | ax = axes[img_idx // ncols][img_idx % ncols]
120 | ax.axis('off')
121 |
122 | plt.tight_layout()
123 |
124 |     # save the figure
125 | os.makedirs(save_dir, exist_ok=True)
126 | save_path = os.path.join(save_dir, title.lower().replace(' ', '-') + '.png')
127 | fig.savefig(save_path)
128 | plt.close(fig)
129 |
130 | return save_path
131 |
132 |
133 | def show_cam_on_image(inp, cam, save_dir, title, args=None, alpha=0.5):
134 | """
135 |     Overlay CAM heatmaps on images and save the result as a PNG file.
136 | 
137 |     Args:
138 |         inp: image tensor of shape [B, C, H, W]
139 |         cam: CAM heatmaps of shape [B, H, W]
140 |         save_dir: output directory
141 |         title: figure title
142 |         args: configuration object
143 |         alpha: heatmap opacity
144 | """
145 |     # limit the number of images
146 | max_images = 16
147 | inp = inp[:max_images]
148 | cam = cam[:max_images]
149 | n_images = inp.shape[0]
150 |
151 |     # compute the required number of rows/columns
152 | nrows = int(np.ceil(np.sqrt(n_images)))
153 | ncols = int(np.ceil(n_images / nrows))
154 |
155 |     # create the image grid
156 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3))
157 |
158 | for img_idx in range(n_images):
159 | if nrows == 1 and ncols == 1:
160 | ax = axes
161 | elif nrows == 1:
162 | ax = axes[img_idx % ncols]
163 | elif ncols == 1:
164 | ax = axes[img_idx % nrows]
165 | else:
166 | ax = axes[img_idx // ncols][img_idx % ncols]
167 |
168 |         # denormalize the image
169 | if args is not None:
170 | rgb_image = denormalize(inp[img_idx], args).detach().cpu().numpy()
171 | else:
172 | rgb_image = inp[img_idx].detach().cpu().numpy()
173 | if rgb_image.shape[0] == 3: # CHW -> HWC
174 | rgb_image = rgb_image.transpose(1, 2, 0)
175 |
176 |         # fetch the heatmap
177 | heatmap = cam[img_idx]
178 | if isinstance(heatmap, torch.Tensor):
179 | heatmap = heatmap.detach().cpu().numpy()
180 |
181 |         # apply the colour map
182 | cmap = plt.cm.jet
183 | heatmap = cmap(heatmap)[:, :, :3]
184 |
185 |         # overlay the heatmap
186 | superimposed_img = rgb_image * (1 - alpha) + heatmap * alpha
187 |
188 |         # display the overlaid image
189 | ax.imshow(superimposed_img)
190 | ax.set_xticks([])
191 | ax.set_yticks([])
192 |
193 |     # hide the axes of unused subplots
194 | for img_idx in range(n_images, nrows * ncols):
195 | if nrows == 1:
196 | if ncols == 1:
197 | ax = axes
198 | else:
199 | ax = axes[img_idx % ncols]
200 | elif ncols == 1:
201 | ax = axes[img_idx % nrows]
202 | else:
203 | ax = axes[img_idx // ncols][img_idx % ncols]
204 | ax.axis('off')
205 |
206 | plt.tight_layout()
207 |
208 |     # save the figure
209 | os.makedirs(save_dir, exist_ok=True)
210 | save_path = os.path.join(save_dir, title.lower().replace(' ', '-') + '.png')
211 | fig.savefig(save_path)
212 | plt.close(fig)
213 |
214 | return save_path
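A small usage sketch for the visualization helpers; random tensors stand in for normalized images and GradCAM heatmaps, and only `dataset_name` is needed on the `args` object.

```python
# Hedged usage sketch for show_images_grid / show_cam_on_image.
import torch
from argparse import Namespace
from ssl_backdoor.defenses.patchsearch.utils.visualization import show_images_grid, show_cam_on_image

args = Namespace(dataset_name='cifar10')          # only dataset_name is used, for denormalization
images = torch.rand(16, 3, 32, 32)                # stand-in batch (treated as already normalized)
cams = torch.rand(16, 32, 32)                     # stand-in GradCAM heatmaps

grid_path = show_images_grid(images, 'viz', 'Sample Grid', args=args)
cam_path = show_cam_on_image(images, cams, 'viz', 'CAM Overlay', args=args)
print(grid_path, cam_path)
```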
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import get_trainer
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
7 |
8 | from ssl_backdoor.utils.model_utils import transform_encoder_for_small_dataset, remove_task_head_for_encoder
9 |
10 |
11 | class BYOL(nn.Module):
12 | """
13 | Build a BYOL model.
14 | """
15 | def __init__(self, base_encoder, dim=2048, proj_dim=1024, pred_dim=128, tau=0.99, dataset=None):
16 | """
17 | dim: feature dimension (default: 2048)
18 | proj_dim: hidden dimension of the projector (default: 1024)
19 | pred_dim: hidden dimension of the predictor (default: 128)
20 | tau: target network momentum (default: 0.99)
21 | """
22 | super(BYOL, self).__init__()
23 |
24 | # create the online encoder
25 | self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
26 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset)
27 | self.encoder = remove_task_head_for_encoder(self.encoder)
28 |
29 | # create the target encoder
30 | self.momentum_encoder = base_encoder(num_classes=dim, zero_init_residual=True)
31 | self.momentum_encoder = transform_encoder_for_small_dataset(self.momentum_encoder, dataset)
32 | self.momentum_encoder = remove_task_head_for_encoder(self.momentum_encoder)
33 |
34 |
35 | self.projector = BYOLProjectionHead(input_dim=dim, hidden_dim=proj_dim, output_dim=pred_dim)
36 | self.momentum_projector = BYOLProjectionHead(input_dim=dim, hidden_dim=proj_dim, output_dim=pred_dim)
37 | self.predictor = BYOLPredictionHead(input_dim=pred_dim, hidden_dim=proj_dim, output_dim=pred_dim)
38 |
39 | # disable gradient for target encoder
40 | for param in self.momentum_encoder.parameters():
41 | param.requires_grad = False
42 | for param in self.momentum_projector.parameters():
43 | param.requires_grad = False
44 |
45 | # momentum coefficient
46 | self.tau = tau
47 |
48 |         # initialize the target network as a copy of the online network
49 |         self.momentum_encoder.load_state_dict(self.encoder.state_dict())
50 |         self.momentum_projector.load_state_dict(self.projector.state_dict())
51 | def update_target(self, progress=None):
52 | """
53 | Update target network parameters using momentum update rule
54 | If progress is provided, use cosine schedule
55 | """
56 | tau = self.tau
57 | if progress is not None:
58 | # cosine schedule as in the original BYOL implementation
59 | tau = 1 - (1 - self.tau) * (math.cos(math.pi * progress) + 1) / 2
60 |
61 | for param_q, param_k in zip(self.encoder.parameters(), self.momentum_encoder.parameters()):
62 | param_k.data = param_k.data * tau + param_q.data * (1. - tau)
63 | for param_q, param_k in zip(self.projector.parameters(), self.momentum_projector.parameters()):
64 | param_k.data = param_k.data * tau + param_q.data * (1. - tau)
65 |
66 | def forward(self, x1, x2):
67 | """
68 | Input:
69 | x1: first views of images
70 | x2: second views of images
71 | Output:
72 | loss: BYOL loss
73 | """
74 | # compute online features
75 | z1 = self.encoder(x1) # NxC
76 | z2 = self.encoder(x2) # NxC
77 |
78 | z1_proj = self.projector(z1)
79 | z2_proj = self.projector(z2)
80 |
81 | # compute online predictions
82 | p1 = self.predictor(z1_proj) # NxC
83 | p2 = self.predictor(z2_proj) # NxC
84 | # compute target features (no gradient)
85 | with torch.no_grad():
86 | z1_target = self.momentum_encoder(x1)
87 | z2_target = self.momentum_encoder(x2)
88 |
89 | z1_target_proj = self.momentum_projector(z1_target)
90 | z2_target_proj = self.momentum_projector(z2_target)
91 |
92 | # normalize for cosine similarity
93 | p1 = F.normalize(p1, dim=-1)
94 | p2 = F.normalize(p2, dim=-1)
95 | z1_target_proj = F.normalize(z1_target_proj, dim=-1)
96 | z2_target_proj = F.normalize(z2_target_proj, dim=-1)
97 |
98 | # BYOL loss
99 | loss = 4 - 2 * (p1 * z2_target_proj).sum(dim=-1).mean() - 2 * (p2 * z1_target_proj).sum(dim=-1).mean()
100 |
101 | return loss
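
The closed form on line 99 is the usual normalized-MSE view of the BYOL objective: for unit vectors, ||p - z||^2 = 2 - 2<p, z>, so the two symmetric terms sum to 4 minus twice the two cosine similarities. A tiny self-contained check, with random unit vectors standing in for predictions and targets:

    import torch
    import torch.nn.functional as F

    p1, p2, z1, z2 = (F.normalize(torch.randn(8, 128), dim=-1) for _ in range(4))
    byol_form = 4 - 2 * (p1 * z2).sum(-1).mean() - 2 * (p2 * z1).sum(-1).mean()
    mse_form = (p1 - z2).pow(2).sum(-1).mean() + (p2 - z1).pow(2).sum(-1).mean()
    assert torch.allclose(byol_form, mse_form, atol=1e-5)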
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/cfg.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import argparse
3 | from torchvision import models
4 | import multiprocessing
5 | from dataset import DS_LIST
6 | from methods import METHOD_LIST
7 |
8 |
9 | def get_cfg():
10 | """ generates configuration from user input in console """
11 | parser = argparse.ArgumentParser(description="")
12 | parser.add_argument(
13 | "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
14 | )
15 |
16 |
17 | ### attack things
18 | parser.add_argument('--config', default=None, type=str, required=True,
19 | help='config file')
20 | parser.add_argument('--attack_algorithm', default='clean', type=str, required=True,
21 | help='attack_algorithm')
22 | # parser.add_argument('--attack_target', default=16, type=int, required=True,
23 | # help='attack target')
24 | # parser.add_argument('--attack_target_word', default=None, type=str, required=True,
25 | # help='attack target')
26 | # parser.add_argument('--poison_injection_rate', default=0.005, type=float, required=True,
27 | # help='poison_injection_rate')
28 | # parser.add_argument('--trigger_path', default=None, type=str, required=True,
29 | # help='trigger_path')
30 | # parser.add_argument('--trigger_size', default=60, type=int, required=True,
31 | # help='trigger_size')
32 |
33 | # # poisonencoder things
34 | # parser.add_argument('--support_ratio', default=0.2, type=float,
35 | # help='support_ratio')
36 | # parser.add_argument('--background_dir', default='/workspace/sync/SSL-Backdoor/poison-generation/poisonencoder_utils/places', type=str,
37 | # help='background_dir')
38 | # parser.add_argument('--reference_dir', default='/workspace/sync/SSL-Backdoor/poison-generation/poisonencoder_utils/references/', type=str,
39 | # help='reference_dir')
40 | # parser.add_argument('--num_references', default=3, type=int,
41 | # help='num_references')
42 | # parser.add_argument('--max_size', default=800, type=int,
43 | # help='max_size')
44 | # parser.add_argument('--area_ratio', default=2, type=int,
45 | # help='area_ratio')
46 | # parser.add_argument('--object_marginal', default=0.05, type=float,
47 | # help='object_marginal')
48 | # parser.add_argument('--trigger_marginal', default=0.25, type=float,
49 | # help='trigger_marginal')
50 |
51 | parser.add_argument(
52 | "--wandb",
53 | type=str,
54 | default="self_supervised",
55 | help="name of the project for logging at https://wandb.ai",
56 | )
57 | parser.add_argument(
58 | "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
59 | )
60 | parser.add_argument(
61 | "--num_samples",
62 | type=int,
63 | default=2,
64 | help="number of samples (d) generated from each image",
65 | )
66 |
67 | addf = partial(parser.add_argument, type=float)
68 | addf("--cj0", default=0.4, help="color jitter brightness")
69 | addf("--cj1", default=0.4, help="color jitter contrast")
70 | addf("--cj2", default=0.4, help="color jitter saturation")
71 | addf("--cj3", default=0.1, help="color jitter hue")
72 | addf("--cj_p", default=0.8, help="color jitter probability")
73 | addf("--gs_p", default=0.1, help="grayscale probability")
74 | addf("--crop_s0", default=0.2, help="crop size from")
75 | addf("--crop_s1", default=1.0, help="crop size to")
76 | addf("--crop_r0", default=0.75, help="crop ratio from")
77 | addf("--crop_r1", default=(4 / 3), help="crop ratio to")
78 | addf("--hf_p", default=0.5, help="horizontal flip probability")
79 |
80 | parser.add_argument(
81 | "--no_lr_warmup",
82 | dest="lr_warmup",
83 | action="store_false",
84 | help="do not use learning rate warmup",
85 | )
86 | parser.add_argument(
87 | "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
88 | )
89 | parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
90 | parser.add_argument("--fname", type=str, help="load model from file")
91 | parser.add_argument(
92 | "--lr_step",
93 | type=str,
94 | choices=["cos", "step", "none"],
95 | default="step",
96 | help="learning rate schedule type",
97 | )
98 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
99 | parser.add_argument(
100 | "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)"
101 | )
102 | parser.add_argument(
103 | "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
104 | )
105 | parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
106 | parser.add_argument(
107 | "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
108 | )
109 | parser.add_argument(
110 | "--w_eps", type=float, default=0, help="eps for stability for whitening"
111 | )
112 | parser.add_argument(
113 | "--head_layers", type=int, default=2, help="number of FC layers in head"
114 | )
115 | parser.add_argument(
116 | "--head_size", type=int, default=1024, help="size of FC layers in head"
117 | )
118 |
119 | parser.add_argument(
120 | "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
121 | )
122 | parser.add_argument(
123 | "--w_iter",
124 | type=int,
125 | default=1,
126 | help="iterations for whitening matrix estimation",
127 | )
128 |
129 | parser.add_argument(
130 | "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
131 | )
132 | parser.add_argument(
133 | "--tau", type=float, default=0.5, help="contrastive loss temperature"
134 | )
135 |
136 | parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
137 | parser.add_argument(
138 | "--eval_every_drop",
139 | type=int,
140 | default=5,
141 | help="how often to evaluate after learning rate drop",
142 | )
143 | parser.add_argument(
144 | "--eval_every", type=int, default=20, help="how often to evaluate"
145 | )
146 | parser.add_argument("--emb", type=int, default=64, help="embedding size")
147 | parser.add_argument(
148 | "--bs", type=int, default=512, help="number of original images in batch N",
149 | )
150 | parser.add_argument(
151 | "--drop",
152 | type=int,
153 | nargs="*",
154 | default=[50, 25],
155 | help="milestones for learning rate decay (0 = last epoch)",
156 | )
157 | parser.add_argument(
158 | "--drop_gamma",
159 | type=float,
160 | default=0.2,
161 | help="multiplicative factor of learning rate decay",
162 | )
163 | parser.add_argument(
164 | "--arch",
165 | type=str,
166 | choices=[x for x in dir(models) if "resn" in x],
167 | default="resnet18",
168 | help="encoder architecture",
169 | )
170 | parser.add_argument(
171 | "--num_workers",
172 | type=int,
173 | default=8,
174 | help="dataset workers number",
175 | )
176 | parser.add_argument(
177 | "--clf",
178 | type=str,
179 | default="sgd",
180 | choices=["sgd", "knn", "lbfgs"],
181 | help="classifier for test.py",
182 | )
183 | parser.add_argument(
184 | "--eval_head", action="store_true", help="eval head output instead of model",
185 | )
186 | parser.add_argument("--exp_id", type=str, default="")
187 |
188 | parser.add_argument("--clf_chkpt", type=str, default="")
189 | parser.add_argument("--evaluate", dest="evaluate", action="store_true")
190 | parser.add_argument("--eval_data", type=str, default="")
191 | parser.add_argument("--save_folder", type=str, default="./output")
192 | parser.add_argument("--save-freq", type=int, default=50)
193 | return parser.parse_args()
194 |
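
Only --config and --attack_algorithm are required; everything else falls back to the defaults above and is merged with the YAML config by train.py further below. A hedged sketch of parsing a synthetic command line from inside this trainer directory, assuming its dependencies are importable (the config path is a placeholder):

    import sys
    from cfg import get_cfg   # import style used by train.py / test.py below

    sys.argv = [
        "train.py",
        "--config", "configs/ssl/example.yaml",   # placeholder path
        "--attack_algorithm", "sslbkd",
        "--method", "byol",
        "--arch", "resnet18",
        "--epoch", "200",
        "--bs", "512",
    ]
    cfg = get_cfg()
    print(cfg.method, cfg.arch, cfg.bs, cfg.byol_tau)   # byol resnet18 512 0.99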
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/cifar10_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | automobile
3 | bird
4 | cat
5 | deer
6 | dog
7 | frog
8 | horse
9 | ship
10 | truck
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/cifar10_metadata.txt:
--------------------------------------------------------------------------------
1 | airplane 0
2 | automobile 1
3 | bird 2
4 | cat 3
5 | deer 4
6 | dog 5
7 | frog 6
8 | horse 7
9 | ship 8
10 | truck 9
11 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar10 import CIFAR10
2 | from .cifar100 import CIFAR100
3 | from .stl10 import STL10
4 | from .tiny_in import TinyImageNet
5 | from .imagenet import ImageNet
6 |
7 |
8 | DS_LIST = ["cifar10", "cifar100", "stl10", "tiny_in", "imagenet-100"]
9 |
10 |
11 | def get_ds(name):
12 | assert name in DS_LIST
13 | if name == "cifar10":
14 | return CIFAR10
15 | elif name == "cifar100":
16 | return CIFAR100
17 | elif name == "stl10":
18 | return STL10
19 | elif name == "tiny_in":
20 | return TinyImageNet
21 | elif name == "imagenet-100":
22 | return ImageNet
23 |
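
get_ds is a plain name-to-class factory; the returned class is instantiated with the loader settings of BaseDataset (next file) and builds its DataLoaders lazily as cached properties. A hedged usage sketch, where cfg stands for the parsed configuration and must carry the dataset/attack fields the chosen wrapper expects:

    from dataset import get_ds   # import style used by train.py below

    DatasetCls = get_ds("cifar10")                        # -> the CIFAR10 wrapper class
    ds = DatasetCls(bs_train=512, aug_cfg=cfg, num_workers=8)
    train_loader = ds.train   # cached DataLoader property defined in BaseDataset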
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from functools import lru_cache
3 | from torch.utils.data import DataLoader
4 |
5 |
6 | class BaseDataset(metaclass=ABCMeta):
7 | """
8 |     Base class for datasets. It provides loaders for:
9 |     - self-supervised training,
10 |     - classifier training for evaluation,
11 |     - clean and poisoned testing
12 | """
13 |
14 | def __init__(
15 | self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000,
16 | ):
17 | self.aug_cfg = aug_cfg
18 | self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test
19 | self.num_workers = num_workers
20 |
21 |
22 |
23 | @abstractmethod
24 | def ds_train(self):
25 | raise NotImplementedError
26 |
27 | @abstractmethod
28 | def ds_clf(self):
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def ds_test(self):
33 | raise NotImplementedError
34 |
35 | #
36 | @abstractmethod
37 | def ds_test_p(self):
38 | raise NotImplementedError
39 |
40 | @property
41 | @lru_cache()
42 | def train(self):
43 | return DataLoader(
44 | dataset=self.ds_train(),
45 | batch_size=self.bs_train,
46 | shuffle=True,
47 | num_workers=self.num_workers,
48 | pin_memory=True,
49 | drop_last=True,
50 | )
51 |
52 | @property
53 | @lru_cache()
54 | def clf(self):
55 | return DataLoader(
56 | dataset=self.ds_clf(),
57 | batch_size=self.bs_clf,
58 | shuffle=True,
59 | num_workers=self.num_workers,
60 | pin_memory=True,
61 | drop_last=True,
62 | )
63 |
64 | @property
65 | @lru_cache()
66 | def test(self):
67 | return DataLoader(
68 | dataset=self.ds_test(),
69 | batch_size=self.bs_test,
70 | shuffle=False,
71 | num_workers=self.num_workers,
72 | pin_memory=True,
73 | drop_last=False,
74 | )
75 |
76 | @property
77 | @lru_cache()
78 | def test_p(self):
79 | return DataLoader(
80 | dataset=self.ds_test_p(),
81 | batch_size=self.bs_test,
82 | shuffle=False,
83 | num_workers=self.num_workers,
84 | pin_memory=True,
85 | drop_last=False,
86 | )
87 |
88 | @property
89 | @lru_cache()
90 | def test_poison(self, attack_target=None):
91 | return DataLoader(
92 | dataset=self.ds_test_poison(attack_target),
93 | batch_size=self.bs_test,
94 | shuffle=False,
95 | num_workers=self.num_workers,
96 | pin_memory=True,
97 | drop_last=False,
98 | )
99 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/cifar10.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from torchvision.datasets import CIFAR10 as C10
4 | import torchvision.transforms as T
5 | from .transforms import MultiSample, aug_transform
6 | from .base import BaseDataset
7 | from PIL import ImageFilter, Image
8 | from torch.utils import data
9 |
10 | import moco.loader
11 | import moco.builder
12 | import moco.dataset3
13 |
14 | class FileListDataset(data.Dataset):
15 | def __init__(self, path_to_txt_file, transform):
16 | with open(path_to_txt_file, 'r') as f:
17 | self.file_list = f.readlines()
18 | self.file_list = [row.rstrip() for row in self.file_list]
19 |
20 | self.transform = transform
21 |
22 |
23 | def __getitem__(self, idx):
24 | image_path = self.file_list[idx].split()[0]
25 | img = Image.open(image_path).convert('RGB')
26 | target = int(self.file_list[idx].split()[1])
27 |
28 | if self.transform is not None:
29 | images = self.transform(img)
30 |
31 | return images, target
32 |
33 | def __len__(self):
34 | return len(self.file_list)
35 |
36 | class RandomBlur:
37 | def __init__(self, r0, r1):
38 | self.r0, self.r1 = r0, r1
39 |
40 | def __call__(self, image):
41 | r = random.uniform(self.r0, self.r1)
42 | return image.filter(ImageFilter.GaussianBlur(radius=r))
43 |
44 | def base_transform():
45 | return T.Compose(
46 | [T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
47 | )
48 |
49 | def base_transform_linear_probe():
50 | return T.Compose(
51 | [T.RandomCrop(32), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
52 | )
53 |
54 | class CIFAR10(BaseDataset):
55 |
56 | def get_train_dataset(self, args, file_path, transform):
57 | args = self.aug_cfg
58 | args.image_size = 224
59 | if args.attack_algorithm == 'backog':
60 | train_dataset = moco.dataset3.BackOGTrainDataset(
61 | args,
62 | file_path,
63 | transform)
64 | elif args.attack_algorithm == 'corruptencoder':
65 | train_dataset = moco.dataset3.CorruptEncoderTrainDataset(
66 | args,
67 | file_path,
68 | transform)
69 | elif args.attack_algorithm == 'sslbkd':
70 | train_dataset = moco.dataset3.SSLBackdoorTrainDataset(
71 | args,
72 | file_path,
73 | transform)
74 | elif args.attack_algorithm == 'ctrl':
75 | train_dataset = moco.dataset3.CTRLTrainDataset(
76 | args,
77 | file_path,
78 | transform)
79 | else:
80 | raise ValueError(f"Unknown attack algorithm '{args.attack_algorithm}'")
81 |
82 | return train_dataset
83 |
84 | def ds_train(self):
85 | aug_with_blur = aug_transform(
86 | 32,
87 | base_transform,
88 | self.aug_cfg,
89 | extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)],
90 | )
91 | t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
92 | self.pretrain_set=self.get_train_dataset(self.aug_cfg, self.aug_cfg.train_file_path, t)
93 | return self.pretrain_set
94 |
95 |     # Unlike the original repo, do not pre-resize images here
96 | def ds_clf(self):
97 | t = base_transform_linear_probe()
98 | return FileListDataset(path_to_txt_file=self.aug_cfg.train_clean_file_path, transform=t)
99 |
100 | def ds_test(self):
101 | t = base_transform()
102 | return FileListDataset(path_to_txt_file=self.aug_cfg.val_file_path, transform=t)
103 |
104 | def ds_test_p(self):
105 | t = base_transform()
106 | return moco.dataset3.UniversalPoisonedValDataset(self.aug_cfg, self.aug_cfg.val_file_path, transform=t)
107 |
108 |
109 | # class CIFAR10(BaseDataset):
110 | # def ds_train(self):
111 | # t = MultiSample(
112 | # aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
113 | # )
114 | # return C10(root="./data", train=True, download=True, transform=t)
115 |
116 | # def ds_clf(self):
117 | # t = base_transform()
118 | # return C10(root="./data", train=True, download=True, transform=t)
119 |
120 | # def ds_test(self):
121 | # t = base_transform()
122 | # return C10(root="./data", train=False, download=True, transform=t)
123 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/cifar100.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR100 as C100
2 | import torchvision.transforms as T
3 | from .transforms import MultiSample, aug_transform
4 | from .base import BaseDataset
5 |
6 |
7 | def base_transform():
8 | return T.Compose(
9 | [T.ToTensor(), T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))]
10 | )
11 |
12 |
13 | class CIFAR100(BaseDataset):
14 | def ds_train(self):
15 | t = MultiSample(
16 | aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
17 | )
18 | return C100(root="./data", train=True, download=True, transform=t,)
19 |
20 | def ds_clf(self):
21 | t = base_transform()
22 | return C100(root="./data", train=True, download=True, transform=t)
23 |
24 | def ds_test(self):
25 | t = base_transform()
26 | return C100(root="./data", train=False, download=True, transform=t)
27 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/imagenet.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | import os
4 | from torchvision.datasets import ImageFolder
5 | import torchvision.transforms as T
6 | from PIL import ImageFilter, Image
7 | from .transforms import MultiSample, aug_transform
8 | from .base import BaseDataset
9 | from torch.utils import data
10 |
11 | import moco.loader
12 | import moco.builder
13 |
14 | # current_dir = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append("/workspace/SSL-Backdoor")
16 | print(sys.path)
17 | import datasets.dataset
18 |
19 |
20 | class FileListDataset(data.Dataset):
21 | def __init__(self, path_to_txt_file, transform):
22 | with open(path_to_txt_file, 'r') as f:
23 | self.file_list = f.readlines()
24 | self.file_list = [row.rstrip() for row in self.file_list]
25 |
26 | self.transform = transform
27 |
28 |
29 | def __getitem__(self, idx):
30 | image_path = self.file_list[idx].split()[0]
31 | img = Image.open(image_path).convert('RGB')
32 | target = int(self.file_list[idx].split()[1])
33 |
34 | if self.transform is not None:
35 | images = self.transform(img)
36 |
37 | return images, target
38 |
39 | def __len__(self):
40 | return len(self.file_list)
41 |
42 | class RandomBlur:
43 | def __init__(self, r0, r1):
44 | self.r0, self.r1 = r0, r1
45 |
46 | def __call__(self, image):
47 | r = random.uniform(self.r0, self.r1)
48 | return image.filter(ImageFilter.GaussianBlur(radius=r))
49 |
50 |
51 | def base_transform():
52 | return T.Compose(
53 | [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
54 | )
55 |
56 | # Change this to be consistent with MoCo v2 eval
57 | def base_transform_eval():
58 | return T.Compose(
59 | [T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
60 | )
61 | def base_transform_linear_probe():
62 | return T.Compose(
63 | [T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
64 | )
65 |
66 | class ImageNet(BaseDataset):
67 |
68 | def get_train_dataset(self, args, file_path, transform):
69 | args = self.aug_cfg
70 | args.image_size = 224
71 |
72 |         # map attack_algorithm to its training dataset class
73 | dataset_classes = {
74 | 'bp': datasets.dataset.BPTrainDataset,
75 | 'blto': datasets.dataset.BltoPoisoningPoisonedTrainDataset,
76 | 'corruptencoder': datasets.dataset.CorruptEncoderTrainDataset,
77 | 'sslbkd': datasets.dataset.SSLBackdoorTrainDataset,
78 | 'ctrl': datasets.dataset.CTRLTrainDataset,
79 | # 'randombackground': datasets.dataset.RandomBackgroundTrainDataset,
80 | 'clean': datasets.dataset.FileListDataset,
81 | }
82 |
83 | if args.attack_algorithm not in dataset_classes:
84 | raise ValueError(f"Unknown attack algorithm '{args.attack_algorithm}'")
85 |
86 | train_dataset = dataset_classes[args.attack_algorithm](args, file_path, transform)
87 |
88 |
89 | return train_dataset
90 |
91 | def ds_train(self):
92 | aug_with_blur = aug_transform(
93 | 224,
94 | base_transform,
95 | self.aug_cfg,
96 | extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)],
97 | )
98 | t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples)
99 |
100 | self.pretrain_set=self.get_train_dataset(self.aug_cfg, self.aug_cfg.data, t)
101 |
102 | return self.pretrain_set
103 |
104 |     # Unlike the original repo, do not pre-resize images here
105 | def ds_clf(self):
106 | raise NotImplementedError
107 | t = base_transform_linear_probe()
108 | return FileListDataset(path_to_txt_file=self.aug_cfg.train_clean_file_path, transform=t)
109 |
110 | def ds_test(self):
111 | raise NotImplementedError
112 | t = base_transform_eval()
113 | return FileListDataset(path_to_txt_file=self.aug_cfg.val_file_path, transform=t)
114 |
115 | def ds_test_p(self):
116 | raise NotImplementedError
117 | t = base_transform_eval()
118 | return datasets.dataset.UniversalPoisonedValDataset(self.aug_cfg, self.aug_cfg.val_file_path, transform=t)
119 | # return FileListDataset(path_to_txt_file=self.aug_cfg.val_poisoned_file_path, transform=t)
120 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/stl10.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import STL10 as S10
2 | import torchvision.transforms as T
3 | from .transforms import MultiSample, aug_transform
4 | from .base import BaseDataset
5 |
6 |
7 | def base_transform():
8 | return T.Compose(
9 | [T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))]
10 | )
11 |
12 |
13 | def test_transform():
14 | return T.Compose(
15 | [T.Resize(70, interpolation=3), T.CenterCrop(64), base_transform()]
16 | )
17 |
18 |
19 | class STL10(BaseDataset):
20 | def ds_train(self):
21 | t = MultiSample(
22 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
23 | )
24 | return S10(root="./data", split="train+unlabeled", download=True, transform=t)
25 |
26 | def ds_clf(self):
27 | t = test_transform()
28 | return S10(root="./data", split="train", download=True, transform=t)
29 |
30 | def ds_test(self):
31 | t = test_transform()
32 | return S10(root="./data", split="test", download=True, transform=t)
33 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/tiny_in.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import ImageFolder
2 | import torchvision.transforms as T
3 | from .transforms import MultiSample, aug_transform
4 | from .base import BaseDataset
5 |
6 |
7 | def base_transform():
8 | return T.Compose(
9 | [T.ToTensor(), T.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))]
10 | )
11 |
12 |
13 | class TinyImageNet(BaseDataset):
14 | def ds_train(self):
15 | t = MultiSample(
16 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples
17 | )
18 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t)
19 |
20 | def ds_clf(self):
21 | t = base_transform()
22 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t)
23 |
24 | def ds_test(self):
25 | t = base_transform()
26 | return ImageFolder(root="data/tiny-imagenet-200/test", transform=t)
27 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/dataset/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 |
3 | # Change the order
4 | def aug_transform(crop, base_transform, cfg, extra_t=[]):
5 | """ augmentation transform generated from config """
6 | return T.Compose(
7 | [
8 | T.RandomResizedCrop(
9 | crop,
10 | scale=(cfg.crop_s0, cfg.crop_s1),
11 | ratio=(cfg.crop_r0, cfg.crop_r1),
12 | interpolation=3,
13 | ),
14 | T.RandomApply(
15 | [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p
16 | ),
17 | T.RandomGrayscale(p=cfg.gs_p),
18 | *extra_t,
19 | T.RandomHorizontalFlip(p=cfg.hf_p),
20 | base_transform(),
21 | ]
22 | )
23 |
24 |
25 | class MultiSample:
26 | """ generates n samples with augmentation """
27 |
28 | def __init__(self, transform, n=2):
29 | self.transform = transform
30 | self.num = n
31 |
32 | def __call__(self, x):
33 | return tuple(self.transform(x) for _ in range(self.num))
34 |
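
With MultiSample wrapped around an augmentation pipeline, every dataset item becomes a tuple of independently augmented views, which is what the loss methods iterate over. A small check (run with the class above in scope):

    from PIL import Image
    import torchvision.transforms as T

    two_views = MultiSample(T.Compose([T.RandomHorizontalFlip(), T.ToTensor()]), n=2)
    views = two_views(Image.new("RGB", (32, 32)))
    assert isinstance(views, tuple) and len(views) == 2 and views[0].shape == (3, 32, 32)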
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/eval/get_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | def get_data(model, loader, output_size, device):
4 | """ encodes whole dataset into embeddings """
5 | n_total_samples = len(loader.dataset)
6 | xs = torch.empty(n_total_samples, output_size, dtype=torch.float32, device=device)
7 | ys = torch.empty(n_total_samples, dtype=torch.long, device=device)
8 | start_idx = 0
9 |     added_count = 0 # number of samples actually written
10 |
11 | with torch.no_grad():
12 | for x, y in tqdm(loader):
13 | x = x.to(device)
14 | batch_size = x.shape[0]
15 | end_idx = start_idx + batch_size
16 |
17 | xs[start_idx:end_idx] = model(x)
18 | ys[start_idx:end_idx] = y.to(device)
19 |
20 | start_idx = end_idx
21 |             added_count += batch_size # update the running count
22 |
23 |     # trim the unused tail in case the loader yielded fewer samples
24 | xs = xs[:added_count]
25 | ys = ys[:added_count]
26 |
27 | return xs, ys
28 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/eval/knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def eval_knn(x_train, y_train, x_test, y_test, k=5):
5 | """ k-nearest neighbors classifier accuracy """
6 | d = torch.cdist(x_test, x_train)
7 | topk = torch.topk(d, k=k, dim=1, largest=False)
8 | labels = y_train[topk.indices]
9 | pred = torch.empty_like(y_test)
10 | for i in range(len(labels)):
11 | x = labels[i].unique(return_counts=True)
12 | pred[i] = x[0][x[1].argmax()]
13 |
14 | acc = (pred == y_test).float().mean().cpu().item()
15 | del d, topk, labels, pred
16 | return acc
17 |
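
A quick synthetic sanity check for eval_knn (illustration only): two well-separated Gaussian blobs in embedding space should be classified almost perfectly by 5-NN:

    import torch
    from eval.knn import eval_knn   # import style used by test.py below

    x_train = torch.cat([torch.randn(100, 16), torch.randn(100, 16) + 8.0])
    y_train = torch.cat([torch.zeros(100, dtype=torch.long), torch.ones(100, dtype=torch.long)])
    x_test = torch.cat([torch.randn(20, 16), torch.randn(20, 16) + 8.0])
    y_test = torch.cat([torch.zeros(20, dtype=torch.long), torch.ones(20, dtype=torch.long)])
    print(eval_knn(x_train, y_train, x_test, y_test, k=5))   # close to 1.0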
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/eval/lbfgs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.linear_model import LogisticRegression
3 |
4 |
5 | def eval_lbfgs(x_train, y_train, x_test, y_test):
6 | """ linear classifier accuracy (lbfgs method) """
7 | clf = LogisticRegression(
8 | random_state=1337, solver="lbfgs", max_iter=1000, n_jobs=-1
9 | )
10 | clf.fit(x_train, y_train)
11 | pred = clf.predict(x_test)
12 | return (torch.tensor(pred) == y_test).float().mean()
13 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/eval/sgd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 | import os
6 | from tqdm import trange, tqdm
7 |
8 | # Modified from original
9 | # def eval_sgd(x_train, y_train, x_test, y_test, x_test_p=None, y_test_p=None, evaluate=False, topk=[1, 5], epoch=100):
10 | # """ linear classifier accuracy (sgd) """
11 |
12 | # lr_start, lr_end = 1e-2, 1e-6
13 | # gamma = (lr_end / lr_start) ** (1 / epoch)
14 | # output_size = x_train.shape[1]
15 | # num_class = y_train.max().item() + 1
16 |
17 | # clf = nn.Linear(output_size, num_class)
18 | # clf.cuda()
19 | # clf.train()
20 |
21 | # optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6)
22 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
23 | # criterion = nn.CrossEntropyLoss()
24 |
25 | # from torch.utils.data import DataLoader, TensorDataset
26 | #     # x_train and y_train are assumed to be already defined
27 | # dataset = TensorDataset(x_train, y_train)
28 | # dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
29 | #     best_acc = 0.0 # initialise the best accuracy
30 | # best_clf_state_dict = clf.state_dict()
31 |
32 | # pbar = tqdm(range(epoch))
33 | # for ep in pbar:
34 | # for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
35 | # optimizer.zero_grad()
36 | # criterion(clf(batch_x), batch_y).backward()
37 | # optimizer.step()
38 | # scheduler.step()
39 |
40 | #         # compute test-set accuracy at the end of each epoch
41 | # clf.eval()
42 | # with torch.no_grad():
43 | # y_pred = clf(x_test)
44 | # pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
45 | # acc = {
46 | # t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
47 | # for t in topk
48 | # }
49 | #         # if the current accuracy beats the best so far, keep this classifier state
50 | # if acc[1] > best_acc:
51 | # best_acc = acc[1]
52 | # best_clf_state_dict = clf.state_dict()
53 |
54 | # pbar.set_postfix({'best_acc': best_acc})
55 | # clf.train()
56 |
57 | # clf.load_state_dict(best_clf_state_dict)
58 | # clf.eval()
59 | # with torch.no_grad():
60 | # y_pred = clf(x_test)
61 | # pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
62 | # acc = {
63 | # t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
64 | # for t in topk
65 | # }
66 |
67 | # if not evaluate:
68 | # return acc
69 | # else:
70 | # clf.eval()
71 | # with torch.no_grad():
72 | # y_pred = clf(x_test)
73 | # pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
74 | # acc = {
75 | # t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
76 | # for t in topk
77 | # }
78 |
79 | # with torch.no_grad():
80 | # y_pred_p = clf(x_test_p)
81 | # pred_top_p = y_pred_p.topk(max(topk), 1, largest=True, sorted=True).indices
82 | # acc_p = {
83 | # t: (pred_top_p[:, :t] == y_test_p[..., None]).float().sum(1).mean().cpu().item()
84 | # for t in topk
85 | # }
86 | # return acc, y_pred, y_test, acc_p, y_pred_p, y_test_p
87 |
88 | def get_data(model, loader, output_size, device):
89 | """ encodes whole dataset into embeddings """
90 | n_total_samples = len(loader.dataset)
91 | xs = torch.empty(n_total_samples, output_size, dtype=torch.float32, device=device)
92 | ys = torch.empty(n_total_samples, dtype=torch.long, device=device)
93 | start_idx = 0
94 |     added_count = 0 # number of samples actually written
95 |
96 | with torch.no_grad():
97 | for x, y in tqdm(loader):
98 | x = x.to(device)
99 | batch_size = x.shape[0]
100 | end_idx = start_idx + batch_size
101 |
102 | xs[start_idx:end_idx] = model(x)
103 | ys[start_idx:end_idx] = y.to(device)
104 |
105 | start_idx = end_idx
106 |             added_count += batch_size # update the running count
107 |
108 |     # trim the unused tail in case the loader yielded fewer samples
109 | xs = xs[:added_count]
110 | ys = ys[:added_count]
111 |
112 | return xs, ys
113 |
114 | def eval_sgd(model, clean_set, test_set, test_poisoned_set, out_size, device, evaluate=True, topk=[1, 5], epoch=100):
115 | """ linear classifier accuracy (sgd) """
116 |
117 | lr_start, lr_end = 1e-2, 1e-6
118 | gamma = (lr_end / lr_start) ** (1 / epoch)
119 | output_size = out_size
120 |     num_class = 100  # hard-coded number of classes (ImageNet-100 setting)
121 |
122 | x_test, y_test = get_data(model, test_set, out_size, device)
123 | x_test_p, y_test_p = get_data(model, test_poisoned_set, out_size, device)
124 |
125 | clf = nn.Linear(output_size, num_class)
126 | clf.cuda()
127 | clf.train()
128 |
129 | optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6)
130 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
131 | criterion = nn.CrossEntropyLoss()
132 |
133 | from torch.utils.data import DataLoader, TensorDataset
134 |     # x_train / y_train are re-encoded inside the epoch loop below
135 |
136 |     best_acc = 0.0 # best top-1 accuracy seen so far
137 | best_clf_state_dict = clf.state_dict()
138 |
139 | pbar = tqdm(range(epoch))
140 | for ep in pbar:
141 | x_train, y_train = get_data(model, clean_set, out_size, device)
142 | dataset = TensorDataset(x_train, y_train)
143 | dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
144 | for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
145 | optimizer.zero_grad()
146 | criterion(clf(batch_x), batch_y).backward()
147 | optimizer.step()
148 | scheduler.step()
149 |
150 |         # compute test-set accuracy at the end of each epoch
151 | clf.eval()
152 | with torch.no_grad():
153 | y_pred = clf(x_test)
154 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
155 | acc = {
156 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
157 | for t in topk
158 | }
159 |         # keep the classifier state with the best top-1 accuracy
160 | if acc[1] > best_acc:
161 | best_acc = acc[1]
162 | best_clf_state_dict = clf.state_dict()
163 |
164 | pbar.set_postfix({'best_acc': best_acc})
165 | clf.train()
166 |
167 | clf.load_state_dict(best_clf_state_dict)
168 | clf.eval()
169 | with torch.no_grad():
170 | y_pred = clf(x_test)
171 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
172 | acc = {
173 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
174 | for t in topk
175 | }
176 |
177 | if not evaluate:
178 | return acc
179 | else:
180 | clf.eval()
181 | with torch.no_grad():
182 | y_pred = clf(x_test)
183 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices
184 | acc = {
185 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item()
186 | for t in topk
187 | }
188 |
189 | with torch.no_grad():
190 | y_pred_p = clf(x_test_p)
191 | pred_top_p = y_pred_p.topk(max(topk), 1, largest=True, sorted=True).indices
192 | acc_p = {
193 | t: (pred_top_p[:, :t] == y_test_p[..., None]).float().sum(1).mean().cpu().item()
194 | for t in topk
195 | }
196 | return acc, y_pred, y_test, acc_p, y_pred_p, y_test_p
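
The acc / acc_p dictionaries returned above map each k in topk to top-k accuracy, computed with the pred_top[:, :t] == y[..., None] idiom. A toy illustration of that idiom on three samples:

    import torch

    y_pred = torch.tensor([[0.10, 0.90, 0.00],    # top-1 = class 1
                           [0.80, 0.15, 0.05],    # top-1 = class 0
                           [0.20, 0.30, 0.50]])   # top-2 = classes {2, 1}
    y_true = torch.tensor([1, 2, 1])
    pred_top = y_pred.topk(2, 1, largest=True, sorted=True).indices
    acc = {t: (pred_top[:, :t] == y_true[..., None]).float().sum(1).mean().item() for t in [1, 2]}
    print(acc)   # {1: 0.333..., 2: 0.666...}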
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/imagenet100_classes.txt:
--------------------------------------------------------------------------------
1 | n02869837
2 | n01749939
3 | n02488291
4 | n02107142
5 | n13037406
6 | n02091831
7 | n04517823
8 | n04589890
9 | n03062245
10 | n01773797
11 | n01735189
12 | n07831146
13 | n07753275
14 | n03085013
15 | n04485082
16 | n02105505
17 | n01983481
18 | n02788148
19 | n03530642
20 | n04435653
21 | n02086910
22 | n02859443
23 | n13040303
24 | n03594734
25 | n02085620
26 | n02099849
27 | n01558993
28 | n04493381
29 | n02109047
30 | n04111531
31 | n02877765
32 | n04429376
33 | n02009229
34 | n01978455
35 | n02106550
36 | n01820546
37 | n01692333
38 | n07714571
39 | n02974003
40 | n02114855
41 | n03785016
42 | n03764736
43 | n03775546
44 | n02087046
45 | n07836838
46 | n04099969
47 | n04592741
48 | n03891251
49 | n02701002
50 | n03379051
51 | n02259212
52 | n07715103
53 | n03947888
54 | n04026417
55 | n02326432
56 | n03637318
57 | n01980166
58 | n02113799
59 | n02086240
60 | n03903868
61 | n02483362
62 | n04127249
63 | n02089973
64 | n03017168
65 | n02093428
66 | n02804414
67 | n02396427
68 | n04418357
69 | n02172182
70 | n01729322
71 | n02113978
72 | n03787032
73 | n02089867
74 | n02119022
75 | n03777754
76 | n04238763
77 | n02231487
78 | n03032252
79 | n02138441
80 | n02104029
81 | n03837869
82 | n03494278
83 | n04136333
84 | n03794056
85 | n03492542
86 | n02018207
87 | n04067472
88 | n03930630
89 | n03584829
90 | n02123045
91 | n04229816
92 | n02100583
93 | n03642806
94 | n04336792
95 | n03259280
96 | n02116738
97 | n02108089
98 | n03424325
99 | n01855672
100 | n02090622
101 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from .contrastive import Contrastive
2 | from .w_mse import WMSE
3 | from .byol import BYOL
4 |
5 |
6 | METHOD_LIST = ["contrastive", "w_mse", "byol"]
7 |
8 |
9 | def get_method(name):
10 | assert name in METHOD_LIST
11 | if name == "contrastive":
12 | return Contrastive
13 | elif name == "w_mse":
14 | return WMSE
15 | elif name == "byol":
16 | return BYOL
17 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | from model import get_model, get_head
4 | from eval.sgd import eval_sgd
5 | from eval.knn import eval_knn
6 | from eval.get_data import get_data
7 |
8 |
9 | class BaseMethod(nn.Module):
10 | """
11 |     Base class for self-supervised loss implementations.
12 |     It bundles the encoder and projection head used for training, plus the evaluation helpers.
13 | """
14 |
15 | def __init__(self, cfg):
16 | super().__init__()
17 | self.model, self.out_size = get_model(cfg.arch, cfg.dataset)
18 | self.head = get_head(self.out_size, cfg)
19 | self.knn = cfg.knn
20 | self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2
21 | self.eval_head = cfg.eval_head
22 | self.emb_size = cfg.emb
23 |
24 | self.cfg = cfg
25 |
26 | def forward(self, samples):
27 | raise NotImplementedError
28 |
29 | def get_acc(self, ds_clf, ds_test, ds_test_p):
30 | self.eval()
31 | if self.eval_head:
32 | model = lambda x: self.head(self.model(x))
33 | out_size = self.emb_size
34 | else:
35 | model, out_size = self.model, self.out_size
36 | # torch.cuda.empty_cache()
37 | x_train, y_train = get_data(model, ds_clf, out_size, "cuda")
38 | x_test, y_test = get_data(model, ds_test, out_size, "cuda")
39 | x_test_p, y_test_p = get_data(model, ds_test_p, out_size, "cuda")
40 |
41 | acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn)
42 | asr_knn = eval_knn(x_train, y_train, x_test_p, y_test_p, self.knn)
43 | acc, y_pred, y_test, acc_p, y_pred_p, y_test_p = eval_sgd(model, ds_clf, ds_test, ds_test_p, out_size, "cuda")
44 |
45 |
46 | del x_train, y_train, x_test, y_test, x_test_p, y_test_p
47 | self.train()
48 | return acc_knn, acc, asr_knn, acc_p
49 |
50 | def step(self, progress):
51 | pass
52 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/byol.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from model import get_model, get_head
7 | from .base import BaseMethod
8 |
9 |
10 | def norm_mse_loss(x0, x1):
11 | x0 = F.normalize(x0)
12 | x1 = F.normalize(x1)
13 | return 2 - 2 * (x0 * x1).sum(dim=-1).mean()
14 |
15 | class BYOL(BaseMethod):
16 | """ implements BYOL loss https://arxiv.org/abs/2006.07733 """
17 |
18 | def __init__(self, cfg):
19 | """ init additional target and predictor networks """
20 | super().__init__(cfg)
21 | self.pred = nn.Sequential(
22 | nn.Linear(cfg.emb, cfg.head_size),
23 | nn.BatchNorm1d(cfg.head_size),
24 | nn.ReLU(),
25 | nn.Linear(cfg.head_size, cfg.emb),
26 | )
27 | self.model_t, _ = get_model(cfg.arch, cfg.dataset)
28 | self.head_t = get_head(self.out_size, cfg)
29 | for param in chain(self.model_t.parameters(), self.head_t.parameters()):
30 | param.requires_grad = False
31 | self.update_target(0)
32 | self.byol_tau = cfg.byol_tau
33 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
34 |
35 | def update_target(self, tau):
36 | """ copy parameters from main network to target """
37 | for t, s in zip(self.model_t.parameters(), self.model.parameters()):
38 | t.data.copy_(t.data * tau + s.data * (1.0 - tau))
39 | for t, s in zip(self.head_t.parameters(), self.head.parameters()):
40 | t.data.copy_(t.data * tau + s.data * (1.0 - tau))
41 |
42 | def forward(self, samples):
43 | z = [self.pred(self.head(self.model(x))) for x in samples]
44 | with torch.no_grad():
45 | zt = [self.head_t(self.model_t(x)) for x in samples]
46 |
47 | loss = 0
48 | for i in range(len(samples) - 1):
49 | for j in range(i + 1, len(samples)):
50 | loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i])
51 | loss /= self.num_pairs
52 | return loss
53 |
54 | def step(self, progress):
55 | """ update target network with cosine increasing schedule """
56 | tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2
57 | self.update_target(tau)
58 |
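
The schedule in step() moves tau from byol_tau at the start of training to 1.0 at the end, so the target network updates more and more slowly. A few sample values with the default byol_tau = 0.99:

    import math

    byol_tau = 0.99
    for progress in (0.0, 0.25, 0.5, 0.75, 1.0):
        tau = 1 - (1 - byol_tau) * (math.cos(math.pi * progress) + 1) / 2
        print(f"{progress:.2f} -> {tau:.5f}")
    # 0.00 -> 0.99000, 0.25 -> 0.99146, 0.50 -> 0.99500, 0.75 -> 0.99854, 1.00 -> 1.00000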
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/contrastive.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .base import BaseMethod
6 |
7 |
8 | def contrastive_loss(x0, x1, tau, norm):
9 | # https://github.com/google-research/simclr/blob/master/objective.py
10 | bsize = x0.shape[0]
11 | target = torch.arange(bsize).cuda()
12 | eye_mask = torch.eye(bsize).cuda() * 1e9
13 | if norm:
14 | x0 = F.normalize(x0, p=2, dim=1)
15 | x1 = F.normalize(x1, p=2, dim=1)
16 | logits00 = x0 @ x0.t() / tau - eye_mask
17 | logits11 = x1 @ x1.t() / tau - eye_mask
18 | logits01 = x0 @ x1.t() / tau
19 | logits10 = x1 @ x0.t() / tau
20 | return (
21 | F.cross_entropy(torch.cat([logits01, logits00], dim=1), target)
22 | + F.cross_entropy(torch.cat([logits10, logits11], dim=1), target)
23 | ) / 2
24 |
25 |
26 | class Contrastive(BaseMethod):
27 | """ implements contrastive loss https://arxiv.org/abs/2002.05709 """
28 |
29 | def __init__(self, cfg):
30 | """ init additional BN used after head """
31 | super().__init__(cfg)
32 | self.bn_last = nn.BatchNorm1d(cfg.emb)
33 | self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm)
34 |
35 | def forward(self, samples):
36 | bs = len(samples[0])
37 | h = [self.model(x.cuda(non_blocking=True)) for x in samples]
38 | h = self.bn_last(self.head(torch.cat(h)))
39 | loss = 0
40 | for i in range(len(samples) - 1):
41 | for j in range(i + 1, len(samples)):
42 | x0 = h[i * bs : (i + 1) * bs]
43 | x1 = h[j * bs : (j + 1) * bs]
44 | loss += self.loss_f(x0, x1)
45 | loss /= self.num_pairs
46 | return loss
47 |
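
contrastive_loss builds its target indices and self-similarity mask directly on the GPU, so it needs CUDA tensors. A smoke test with the function above in scope (aligned views give a lower loss than independent random views):

    import torch

    if torch.cuda.is_available():
        x0 = torch.randn(16, 64, device="cuda")
        x1 = x0 + 0.01 * torch.randn_like(x0)                # near-identical views
        x_rand = torch.randn(16, 64, device="cuda")          # unrelated views
        aligned = contrastive_loss(x0, x1, tau=0.5, norm=True)
        random_pairs = contrastive_loss(x0, x_rand, tau=0.5, norm=True)
        print(aligned.item(), random_pairs.item())           # the aligned pair gives the smaller loss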
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/norm_mse.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 |
4 | def norm_mse_loss(x0, x1):
5 | x0 = F.normalize(x0)
6 | x1 = F.normalize(x1)
7 | return 2 - 2 * (x0 * x1).sum(dim=-1).mean()
8 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/w_mse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .whitening import Whitening2d
4 | from .base import BaseMethod
5 | from .norm_mse import norm_mse_loss
6 |
7 |
8 | class WMSE(BaseMethod):
9 | """ implements W-MSE loss """
10 |
11 | def __init__(self, cfg):
12 | """ init whitening transform """
13 | super().__init__(cfg)
14 | self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False)
15 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
16 | self.w_iter = cfg.w_iter
17 | self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size
18 |
19 | def forward(self, samples):
20 | bs = len(samples[0])
21 | h = [self.model(x.cuda(non_blocking=True)) for x in samples]
22 | h = self.head(torch.cat(h))
23 | loss = 0
24 | for _ in range(self.w_iter):
25 | z = torch.empty_like(h)
26 | perm = torch.randperm(bs).view(-1, self.w_size)
27 | for idx in perm:
28 | for i in range(len(samples)):
29 | z[idx + i * bs] = self.whitening(h[idx + i * bs])
30 | for i in range(len(samples) - 1):
31 | for j in range(i + 1, len(samples)):
32 | x0 = z[i * bs : (i + 1) * bs]
33 | x1 = z[j * bs : (j + 1) * bs]
34 | loss += self.loss_f(x0, x1)
35 | loss /= self.w_iter * self.num_pairs
36 | return loss
37 |
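
The sub-batch split in forward() relies on bs being divisible by w_size: randperm(bs).view(-1, w_size) reshapes the shuffled indices into whitening groups and raises a shape error otherwise. A one-liner illustration with the default sizes:

    import torch

    bs, w_size = 512, 128
    perm = torch.randperm(bs).view(-1, w_size)   # 4 whitening sub-batches of 128 indices
    print(perm.shape)                            # torch.Size([4, 128])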
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/methods/whitening.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import conv2d
4 |
5 |
6 | class Whitening2d(nn.Module):
7 | def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0):
8 | super(Whitening2d, self).__init__()
9 | self.num_features = num_features
10 | self.momentum = momentum
11 | self.track_running_stats = track_running_stats
12 | self.eps = eps
13 |
14 | if self.track_running_stats:
15 | self.register_buffer(
16 | "running_mean", torch.zeros([1, self.num_features, 1, 1])
17 | )
18 | self.register_buffer("running_variance", torch.eye(self.num_features))
19 |
20 | def forward(self, x):
21 | x = x.unsqueeze(2).unsqueeze(3)
22 | m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
23 | if not self.training and self.track_running_stats: # for inference
24 | m = self.running_mean
25 | xn = x - m
26 |
27 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
28 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)
29 |
30 | eye = torch.eye(self.num_features).type(f_cov.type())
31 |
32 | if not self.training and self.track_running_stats: # for inference
33 | f_cov = self.running_variance
34 |
35 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye
36 |
37 | inv_sqrt = torch.triangular_solve(
38 | eye, torch.cholesky(f_cov_shrinked), upper=False
39 | )[0]
40 | inv_sqrt = inv_sqrt.contiguous().view(
41 | self.num_features, self.num_features, 1, 1
42 | )
43 |
44 | decorrelated = conv2d(xn, inv_sqrt)
45 |
46 | if self.training and self.track_running_stats:
47 | self.running_mean = torch.add(
48 | self.momentum * m.detach(),
49 | (1 - self.momentum) * self.running_mean,
50 | out=self.running_mean,
51 | )
52 | self.running_variance = torch.add(
53 | self.momentum * f_cov.detach(),
54 | (1 - self.momentum) * self.running_variance,
55 | out=self.running_variance,
56 | )
57 |
58 | return decorrelated.squeeze(2).squeeze(2)
59 |
60 | def extra_repr(self):
61 | return "features={}, eps={}, momentum={}".format(
62 | self.num_features, self.eps, self.momentum
63 | )
64 |
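
torch.cholesky and torch.triangular_solve used on lines 37-39 are deprecated in recent PyTorch releases. A behaviour-preserving sketch of the same inverse-Cholesky whitening factor using the torch.linalg replacements (an assumption about your PyTorch version, not a change to the file above):

    import torch

    def inv_sqrt_factor(f_cov_shrinked: torch.Tensor) -> torch.Tensor:
        eye = torch.eye(f_cov_shrinked.shape[0], dtype=f_cov_shrinked.dtype)
        L = torch.linalg.cholesky(f_cov_shrinked)                   # lower-triangular factor
        return torch.linalg.solve_triangular(L, eye, upper=False)   # L^{-1}, used as the conv weight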
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/moco/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/moco/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # Copyright (c) 2020 Tongzhou Wang
3 | from PIL import ImageFilter, Image
4 | import random
5 | import torchvision.transforms.functional as F
6 | import torchvision.transforms as transforms
7 |
8 |
9 | class TwoCropsTransform:
10 | """Take two random crops of one image as the query and key."""
11 |
12 | def __init__(self, base_transform):
13 | self.base_transform = base_transform
14 |
15 | def __call__(self, x):
16 | q = self.base_transform(x)
17 | k = self.base_transform(x)
18 | return [q, k]
19 |
20 | class OneCropOneTestTransform:
21 | """Take one random crop of one image as the key and the whole image as the query."""
22 |
23 | def __init__(self, crop_transform, test_transform):
24 | self.crop_transform = crop_transform
25 | self.test_transform = test_transform
26 |
27 | def __call__(self, x):
28 | q = self.test_transform(x)
29 | k = self.crop_transform(x)
30 |
31 | return [q, k]
32 |
33 | class GaussianBlur(object):
34 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
35 |
36 | def __init__(self, sigma=[.1, 2.]):
37 | self.sigma = sigma
38 |
39 | def __call__(self, x):
40 | sigma = random.uniform(self.sigma[0], self.sigma[1])
41 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
42 | return x
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/moco/poisonencoder_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageFont, ImageFilter, ImageColor
2 | import os
3 | import cv2
4 | import re
5 | import sys
6 | import glob
7 | import errno
8 | import random
9 | import numpy as np
10 |
11 | def get_trigger(trigger_size=40, trigger_path=None, colorful_trigger=True):
12 | # load trigger
13 | if colorful_trigger:
14 | trigger = Image.open(trigger_path).convert('RGB')
15 | trigger = trigger.resize((trigger_size, trigger_size))
16 | else:
17 | trigger = Image.new("RGB", (trigger_size, trigger_size), ImageColor.getrgb("white"))
18 | return trigger
19 |
20 |
21 |
22 |
23 | def binary_mask_to_box(binary_mask):
24 | binary_mask = np.array(binary_mask, np.uint8)
25 | contours,hierarchy = cv2.findContours(
26 | binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
27 | areas = []
28 | for cnt in contours:
29 | area = cv2.contourArea(cnt)
30 | areas.append(area)
31 | idx = areas.index(np.max(areas))
32 | x, y, w, h = cv2.boundingRect(contours[idx])
33 | bounding_box = [x, y, x+w, y+h]
34 | return bounding_box
35 |
36 | def get_foreground(reference_dir, num_references, max_size, type):
37 | img_idx = random.choice(range(1, 1+num_references))
38 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png')
39 | mask_path = os.path.join(reference_dir, f'{img_idx}/label.png')
40 | image_np = np.asarray(Image.open(image_path).convert('RGB'))
41 | mask_np = np.asarray(Image.open(mask_path).convert('RGB'))
42 | mask_np = (mask_np[..., 0] == 128) ##### [:,0]==128 represents the object mask
43 |
44 | # crop masked region
45 | bbx = binary_mask_to_box(mask_np)
46 | object_image = image_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
47 | object_image = Image.fromarray(object_image)
48 | object_mask = mask_np[bbx[1]:bbx[3],bbx[0]:bbx[2]]
49 | object_mask = Image.fromarray(object_mask)
50 |
51 | # resize -> avoid poisoned image being too large
52 | w, h = object_image.size
53 | if type=='horizontal':
54 | o_w = min(w, int(max_size/2))
55 | o_h = int((o_w/w) * h)
56 | elif type=='vertical':
57 | o_h = min(h, int(max_size/2))
58 | o_w = int((o_h/h) * w)
59 | object_image = object_image.resize((o_w, o_h))
60 | object_mask = object_mask.resize((o_w, o_h))
61 | return object_image, object_mask
62 |
63 | def concat(support_reference_image_path, reference_image_path, max_size):
64 | ### horizontally concat two images
65 | # get support reference image
66 | support_reference_image = Image.open(support_reference_image_path)
67 | width, height = support_reference_image.size
68 | n_w = min(width, int(max_size/2))
69 | n_h = int((n_w/width) * height)
70 | support_reference_image = support_reference_image.resize((n_w, n_h))
71 | width, height = support_reference_image.size
72 |
73 | # get reference image
74 | reference_image = Image.open(reference_image_path)
75 | reference_image = reference_image.resize((width, height))
76 |
77 | img_new = Image.new("RGB", (width*2, height), "white")
78 | if random.random()<0.5:
79 | img_new.paste(support_reference_image, (0, 0))
80 | img_new.paste(reference_image, (width, 0))
81 | else:
82 | img_new.paste(reference_image, (0, 0))
83 | img_new.paste(support_reference_image, (width, 0))
84 | return img_new
85 |
86 |
87 | def get_random_reference_image(reference_dir, num_references):
88 | img_idx = random.choice(range(1, 1+num_references))
89 | image_path = os.path.join(reference_dir, f'{img_idx}/img.png')
90 | return image_path
91 |
92 | def get_random_support_reference_image(reference_dir):
93 | support_dir = os.path.join(reference_dir, 'support-images')
94 | image_path = os.path.join(support_dir, random.choice(os.listdir(support_dir)))
95 | return image_path
96 |
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision import models
3 |
4 |
5 | def get_head(out_size, cfg):
6 | """ creates projection head g() from config """
7 | x = []
8 | in_size = out_size
9 | for _ in range(cfg.head_layers - 1):
10 | x.append(nn.Linear(in_size, cfg.head_size))
11 | if cfg.add_bn:
12 | x.append(nn.BatchNorm1d(cfg.head_size))
13 | x.append(nn.ReLU())
14 | in_size = cfg.head_size
15 | x.append(nn.Linear(in_size, cfg.emb))
16 | return nn.Sequential(*x)
17 |
18 |
19 | def get_model(arch, dataset):
20 | """ creates encoder E() by name and modifies it for dataset """
21 | model = getattr(models, arch)(pretrained=False)
22 | # if dataset != "imagenet":
23 | if 'imagenet' not in dataset:
24 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
25 | if dataset == "cifar10" or dataset == "cifar100":
26 | model.maxpool = nn.Identity()
27 | out_size = model.fc.in_features
28 | model.fc = nn.Identity()
29 |
30 | return nn.DataParallel(model), out_size
31 |
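
get_model returns the backbone with its classification head replaced by nn.Identity (wrapped in DataParallel) together with the feature width the projection head needs. A quick shape check, runnable on CPU because DataParallel falls back to a plain forward pass when no GPUs are visible:

    import torch
    from model import get_model   # import style used by methods/base.py above

    # "cifar10" selects the 3x3-stem / no-maxpool variant, so 32x32 inputs work directly
    encoder, out_size = get_model("resnet18", "cifar10")
    feats = encoder(torch.randn(4, 3, 32, 32))
    print(feats.shape, out_size)   # torch.Size([4, 512]) 512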
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import get_ds
3 | from cfg import get_cfg
4 | from methods import get_method
5 |
6 | from eval.sgd import eval_sgd
7 | from eval.knn import eval_knn
8 | from eval.lbfgs import eval_lbfgs
9 | from eval.get_data import get_data
10 |
11 | import os
12 | import numpy as np
13 |
14 | if __name__ == "__main__":
15 | cfg = get_cfg()
16 |
17 | model_full = get_method(cfg.method)(cfg)
18 | model_full.cuda().eval()
19 | if cfg.fname is None:
20 | print("evaluating random model")
21 | else:
22 | state_dict = torch.load(cfg.fname)['model_state_dict']
23 | model_full.load_state_dict(state_dict)
24 |
25 | ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers, bs_clf=256, bs_test=256)
26 | device = "cpu" if cfg.clf == "lbfgs" else "cuda"
27 | if cfg.eval_head:
28 | model = lambda x: model_full.head(model_full.model(x))
29 | out_size = cfg.emb
30 | else:
31 | model = model_full.model
32 | out_size = model_full.out_size
33 |
34 | # x_train, y_train = get_data(model, ds.clf, out_size, device)
35 | # x_test, y_test = get_data(model, ds.test, out_size, device)
36 | # x_test_p, y_test_p = get_data(model, ds.test_p, out_size, device)
37 |
38 |
39 | clf_fname = os.path.join(os.path.dirname(cfg.fname), "linear", os.path.basename(cfg.fname))
40 | if cfg.clf == "sgd":
41 | # acc, pred_var_stack, labels_var_stack, acc_p, pred_var_p_stack, labels_var_p_stack = eval_sgd(x_train, y_train, x_test, y_test, x_test_p, y_test_p, evaluate=True)
42 | acc, pred_var_stack, labels_var_stack, acc_p, pred_var_p_stack, labels_var_p_stack = eval_sgd(model, ds.clf, ds.test, ds.test_p, out_size, device, evaluate=True)
43 |
44 | else:
45 | raise NotImplementedError
46 | # elif cfg.clf == "knn":
47 | # acc = eval_knn(x_train, y_train, x_test, y_test)
48 | # elif cfg.clf == "lbfgs":
49 | # acc = eval_lbfgs(x_train, y_train, x_test, y_test)
50 |
51 | pred_var_stack = torch.argmax(pred_var_stack, dim=1)
52 | pred_var_p_stack = torch.argmax(pred_var_p_stack, dim=1)
53 |
54 | # create confusion matrix ROWS ground truth COLUMNS pred
55 | conf_matrix_clean = np.zeros((int(labels_var_stack.max())+1, int(labels_var_stack.max())+1))
56 | conf_matrix_poisoned = np.zeros((int(labels_var_stack.max())+1, int(labels_var_stack.max())+1))
57 |
58 | for i in range(pred_var_stack.size(0)):
59 | # update confusion matrix
60 | conf_matrix_clean[int(labels_var_stack[i]), int(pred_var_stack[i])] += 1
61 |
62 | for i in range(pred_var_p_stack.size(0)):
63 | # update confusion matrix
64 | conf_matrix_poisoned[int(labels_var_p_stack[i]), int(pred_var_p_stack[i])] += 1
65 |
66 | if 'imagenet' in cfg.dataset:
67 | metadata_file = 'imagenet_metadata.txt'
68 | class_dir_list_file = 'imagenet100_classes.txt'
69 | elif cfg.dataset == 'cifar10':
70 | metadata_file = 'cifar10_metadata.txt'
71 | class_dir_list_file = 'cifar10_classes.txt'
72 | else:
73 | raise ValueError(f"Unknown dataset '{cfg.dataset}'")
 74 |     # load the class-name metadata (wnid, classname pairs)
75 | with open(metadata_file, "r") as f:
76 | data = [l.strip() for l in f.readlines()]
77 | imagenet_metadata_dict = {}
78 | for line in data:
79 | wnid, classname = line.split()[0], line.split()[1]
80 | imagenet_metadata_dict[wnid] = classname
81 |
82 | with open(class_dir_list_file, 'r') as f:
83 | class_dir_list = [l.strip() for l in f.readlines()]
84 | class_dir_list = sorted(class_dir_list)
85 |
86 | if '1percent' in cfg.train_clean_file_path:
87 | save_folder = os.path.join(os.path.dirname(cfg.fname), "1per_linear", os.path.basename(cfg.fname))
88 | elif '10percent' in cfg.train_clean_file_path:
89 | save_folder = os.path.join(os.path.dirname(cfg.fname), "10per_linear", os.path.basename(cfg.fname))
90 | else:
91 | raise ValueError(f"Unknown train_clean_file_path '{cfg.train_clean_file_path}'")
92 | os.makedirs(save_folder, exist_ok=True)
93 | np.save("{}/conf_matrix_clean.npy".format(save_folder), conf_matrix_clean)
94 | np.save("{}/conf_matrix_poisoned.npy".format(save_folder), conf_matrix_poisoned)
95 |
96 | with open("{}/conf_matrix.csv".format(save_folder), "w") as f:
97 | f.write("Model {},,Clean val,,,,Pois. val,,\n".format(""))
98 | f.write("Data {},,acc1,,,,acc1,,\n".format(""))
99 | f.write(",,{:.2f},,,,{:.2f},,\n".format(acc[1]*100, acc_p[1]*100))
100 | f.write("class name,class id,TP,FP,,TP,FP\n")
101 | for target in range(len(class_dir_list)):
102 | f.write("{},{},{},{},,".format(imagenet_metadata_dict[class_dir_list[target]].replace(",",";"), target, conf_matrix_clean[target][target], conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target]))
103 | f.write("{},{}\n".format(conf_matrix_poisoned[target][target], conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target]))
104 |
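A standalone toy sketch of the confusion-matrix bookkeeping that test.py writes into conf_matrix.csv: rows are ground truth, columns are predictions, per-class TP is the diagonal entry and FP is the rest of that column (synthetic labels, illustrative only):

import numpy as np

y_true = [0, 0, 1, 2, 2]
y_pred = [0, 1, 1, 2, 0]

num_classes = 3
conf = np.zeros((num_classes, num_classes))
for t, p in zip(y_true, y_pred):
    conf[t, p] += 1                       # ROWS: ground truth, COLUMNS: prediction

for c in range(num_classes):
    tp = conf[c, c]                       # samples of class c predicted as c
    fp = conf[:, c].sum() - conf[c, c]    # samples of other classes predicted as c
    print(f"class {c}: TP={tp:.0f}, FP={fp:.0f}")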
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/byol/train.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import yaml
4 | import argparse
5 | import os
6 | import wandb
7 | import torch
8 | import torch.optim as optim
9 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts
10 | import torch.backends.cudnn as cudnn
11 |
12 | from cfg import get_cfg
13 | from dataset import get_ds
14 | from methods import get_method
15 | from torch.utils.tensorboard import SummaryWriter
16 | from tqdm import trange, tqdm
17 |
18 | def get_scheduler(optimizer, cfg):
19 | if cfg.lr_step == "cos":
20 | return CosineAnnealingWarmRestarts(
21 | optimizer,
22 | T_0=cfg.epoch if cfg.T0 is None else cfg.T0,
23 | T_mult=cfg.Tmult,
24 | eta_min=cfg.eta_min,
25 | )
26 | elif cfg.lr_step == "step":
27 | m = [cfg.epoch - a for a in cfg.drop]
28 | return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma)
29 | else:
30 | return None
31 |
32 | def load_config_from_yaml(config_path):
33 | """Load configuration from a YAML file."""
34 | with open(config_path, 'r') as file:
35 | return yaml.safe_load(file)
36 |
37 | def merge_configs(defaults, overrides):
 38 |     """Merge two dicts: add keys that only exist in 'overrides', then let non-None override values take precedence."""
 39 |     result = defaults.copy()
 40 | 
 41 |     result.update({k: v for k, v in overrides.items() if k not in result})
 42 |     result.update({k: v for k, v in overrides.items() if v is not None})
43 |
44 | return argparse.Namespace(**result)
45 |
46 | if __name__ == "__main__":
47 | cfg = get_cfg()
48 | print('cfg', cfg)
49 |
50 | if cfg.config:
51 | config_from_yaml = load_config_from_yaml(cfg.config)
52 | else:
53 | config_from_yaml = {}
54 |
55 | # Prepare final configuration by merging YAML config with command line arguments
56 | cfg = merge_configs(config_from_yaml, vars(cfg))
57 | print(cfg)
58 |
59 | if 'targeted' in cfg.attack_algorithm:
60 | assert cfg.downstream_data is not None
61 |
62 | # wandb.init(project=cfg.wandb, config=cfg)
63 | writer = SummaryWriter(cfg.save_folder)
64 |
65 | ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers, bs_clf=256, bs_test=256)
66 | train_dataloader = ds.train
67 |
68 | model = get_method(cfg.method)(cfg)
69 | model.cuda().train()
70 |
71 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2)
72 | scheduler = get_scheduler(optimizer, cfg)
73 | starting_epoch = 0
74 | eval_every = cfg.eval_every
75 | lr_warmup = 0 if cfg.lr_warmup else 500
76 | cudnn.benchmark = True
77 |
78 | if cfg.fname is not None:
79 | checkpoint = torch.load(cfg.fname)
80 |
 81 |         starting_epoch = checkpoint['epoch']  # resume from the saved epoch
 82 |         lr_warmup = checkpoint['warmup']  # resume the warmup counter
83 | model.load_state_dict(checkpoint['state_dict'])
84 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
85 | if 'scheduler_state_dict' in checkpoint and scheduler is not None:
86 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
87 |
88 |
89 |
90 | for ep in trange(starting_epoch, cfg.epoch, position=0):
91 | loss_ep = []
92 | iters = len(train_dataloader)
93 | for n_iter, (samples, _) in enumerate(tqdm(train_dataloader, position=1)):
94 | if lr_warmup < 500:
95 | lr_scale = (lr_warmup + 1) / 500
96 | for pg in optimizer.param_groups:
97 | pg["lr"] = cfg.lr * lr_scale
98 | lr_warmup += 1
99 |
100 | optimizer.zero_grad()
101 | loss = model(samples)
102 | loss.backward()
103 | optimizer.step()
104 | loss_ep.append(loss.item())
105 | model.step(ep / cfg.epoch)
106 | if cfg.lr_step == "cos" and lr_warmup >= 500:
107 | scheduler.step(ep + n_iter / iters)
108 |
109 | if cfg.lr_step == "step":
110 | scheduler.step()
111 |
112 |
113 | # if (ep + 1) % eval_every == 0:
114 | # acc_knn, acc_linear, asr_knn, asr_linear = model.get_acc(ds.clf, ds.test, ds.test_p)
115 | # # wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False)
116 | # writer.add_scalar('Accuracy/acc', acc_linear[1], ep)
117 | # writer.add_scalar('Accuracy/acc_5', acc_linear[5], ep)
118 | # writer.add_scalar('Accuracy/acc_knn', acc_knn, ep)
119 | # writer.add_scalar('Accuracy/asr', asr_linear[1], ep)
120 | # writer.add_scalar('Accuracy/asr_5', asr_linear[5], ep)
121 | # writer.add_scalar('Accuracy/asr_knn', asr_knn, ep)
122 |
123 | if (ep + 1) % cfg.save_freq == 0:
124 | fname = f"{cfg.save_folder}/{ep}.pth"
125 | os.makedirs(os.path.dirname(fname), exist_ok=True)
126 | checkpoint = {
127 | 'epoch': ep,
128 | 'warmup': lr_warmup,
129 | 'state_dict': model.state_dict(),
130 | 'optimizer_state_dict': optimizer.state_dict(),
131 | 'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
132 | }
133 | torch.save(checkpoint, fname)
134 |
135 | # wandb.log({"loss": np.mean(loss_ep), "ep": ep})
136 | writer.add_scalar('Loss/byol_pretrain', np.mean(loss_ep), ep)
137 |
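A self-contained sketch of the learning-rate schedule used in the loop above: a 500-iteration linear warmup toward the base LR, after which CosineAnnealingWarmRestarts is stepped with a fractional epoch (toy optimizer and illustrative sizes, not repo code):

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

base_lr, epochs, iters = 0.1, 3, 300          # illustrative values
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=base_lr)
sched = CosineAnnealingWarmRestarts(opt, T_0=epochs)

lr_warmup = 0
for ep in range(epochs):
    for n_iter in range(iters):
        if lr_warmup < 500:
            for pg in opt.param_groups:
                pg["lr"] = base_lr * (lr_warmup + 1) / 500   # linear warmup
            lr_warmup += 1
        param.grad = torch.zeros_like(param)                 # stand-in for loss.backward()
        opt.step()
        if lr_warmup >= 500:
            sched.step(ep + n_iter / iters)                  # fractional-epoch cosine step
    print(f"epoch {ep}: lr={opt.param_groups[0]['lr']:.4f}")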
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/moco/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # Copyright (c) 2020 Tongzhou Wang
3 | from PIL import ImageFilter, Image
4 | import random
5 | import torchvision.transforms.functional as F
6 | import torchvision.transforms as transforms
7 |
8 |
9 | class TwoCropsTransform:
10 | """Take two random crops of one image as the query and key."""
11 |
12 | def __init__(self, base_transform):
13 | self.base_transform = base_transform
14 |
15 | def __call__(self, x):
16 | q = self.base_transform(x)
17 | k = self.base_transform(x)
18 | return [q, k]
19 |
20 | class OneCropOneTestTransform:
21 | """Take one random crop of one image as the key and the whole image as the query."""
22 |
23 | def __init__(self, crop_transform, test_transform):
24 | self.crop_transform = crop_transform
25 | self.test_transform = test_transform
26 |
27 | def __call__(self, x):
28 | q = self.test_transform(x)
29 | k = self.crop_transform(x)
30 |
31 | return [q, k]
32 |
33 | class GaussianBlur(object):
34 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
35 |
36 | def __init__(self, sigma=[.1, 2.]):
37 | self.sigma = sigma
38 |
39 | def __call__(self, x):
40 | sigma = random.uniform(self.sigma[0], self.sigma[1])
41 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
42 | return x
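A hedged usage sketch of the transforms above, roughly in the MoCo v2 style (the exact augmentation recipe and the import path are assumptions, not taken from the repo's configs):

from PIL import Image
import torchvision.transforms as transforms

from ssl_backdoor.ssl_trainers.moco.loader import GaussianBlur, TwoCropsTransform  # assumed import path

aug = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

two_crop = TwoCropsTransform(aug)
q, k = two_crop(Image.new("RGB", (64, 64)))   # two independently augmented views of the same image
print(q.shape, k.shape)                       # each is a (3, 32, 32) tensor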
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/models_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | from functools import partial
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | import timm.models.vision_transformer
18 |
19 |
20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
21 | """ Vision Transformer with support for global average pooling
22 | """
23 | def __init__(self, global_pool=False, **kwargs):
24 | super(VisionTransformer, self).__init__(**kwargs)
25 |
26 | self.global_pool = global_pool
27 | if self.global_pool:
28 | norm_layer = kwargs['norm_layer']
29 | embed_dim = kwargs['embed_dim']
30 | self.fc_norm = norm_layer(embed_dim)
31 |
32 | del self.norm # remove the original norm
33 |
34 | def forward_features(self, x):
35 | B = x.shape[0]
36 | x = self.patch_embed(x)
37 |
38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
39 | x = torch.cat((cls_tokens, x), dim=1)
40 | x = x + self.pos_embed
41 | x = self.pos_drop(x)
42 |
43 | for blk in self.blocks:
44 | x = blk(x)
45 |
46 | if self.global_pool:
47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
48 | outcome = self.fc_norm(x)
49 | else:
50 | x = self.norm(x)
51 | outcome = x[:, 0]
52 |
53 | return outcome
54 |
55 | def forward(self, x):
56 | x = self.forward_features(x)
57 | return x
58 |
59 |
60 | def vit_base_patch16(**kwargs):
61 | model = VisionTransformer(
62 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
64 | return model
65 |
66 |
67 | def vit_large_patch16(**kwargs):
68 | model = VisionTransformer(
69 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
70 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
71 | return model
72 |
73 |
74 | def vit_huge_patch14(**kwargs):
75 | model = VisionTransformer(
76 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
77 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
78 | return model
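A short sketch of instantiating the ViT above and pulling a feature vector (assumes a timm version compatible with the MAE reference code; the import path mirrors the file location and num_classes is irrelevant because forward() returns features):

import torch

from ssl_backdoor.ssl_trainers.models_vit import vit_base_patch16  # assumed import path

model = vit_base_patch16(global_pool=True, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))
print(feats.shape)  # (2, 768): fc_norm applied to the mean of the patch tokens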
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/simclr/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ssl_backdoor.utils.model_utils import transform_encoder_for_small_dataset, remove_task_head_for_encoder
6 | from lightly.models.modules import SimCLRProjectionHead
7 | from lightly.loss import NTXentLoss
8 |
9 |
10 | class SimCLR(nn.Module):
 11 |     """
 12 |     Build a SimCLR model.
 13 |     """
 14 |     def __init__(self, base_encoder, dim=512, proj_dim=128, dataset=None):
 15 |         """
 16 |         dim: hidden dimension of the projection head (default: 512)
 17 |         proj_dim: output dimension of the projection head (default: 128)
 18 |         """
19 | super(SimCLR, self).__init__()
20 |
 21 |         # create the encoder
22 | self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
23 | channel_dim = SimCLR.get_channel_dim(self.encoder)
24 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset)
25 | self.encoder = remove_task_head_for_encoder(self.encoder)
26 |
 27 |         # create the projection head (MLP)
28 | self.projector = SimCLRProjectionHead(input_dim=channel_dim, hidden_dim=dim, output_dim=proj_dim)
29 |
30 | self.criterion = NTXentLoss()
31 |
32 | @staticmethod
33 | def get_channel_dim(encoder: nn.Module) -> int:
34 | if hasattr(encoder, 'fc'):
35 | return encoder.fc.weight.shape[1]
36 | elif hasattr(encoder, 'head'):
37 | return encoder.head.weight.shape[1]
38 | elif hasattr(encoder, 'classifier'):
39 | return encoder.classifier.weight.shape[1]
40 | else:
 41 |             raise NotImplementedError('MLP projection head not found in encoder')
42 |
43 | def forward(self, x1, x2):
 44 |         """
 45 |         Input:
 46 |             x1: first view of the images
 47 |             x2: second view of the images
 48 |         Output:
 49 |             contrastive loss
 50 |         """
 51 |         # compute features for both views
 52 |         h1 = self.encoder(x1)
 53 |         h2 = self.encoder(x2)
 54 | 
 55 |         # pass through the projection head
 56 |         z1 = self.projector(h1)
 57 |         z2 = self.projector(h2)
 58 | 
 59 |         # normalize the representations for cosine similarity
60 | z1 = F.normalize(z1, dim=1)
61 | z2 = F.normalize(z2, dim=1)
62 |
63 | loss = self.criterion(z1, z2)
64 |
65 | return loss
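A hedged usage sketch for the SimCLR wrapper above (assumes the repo root is on PYTHONPATH so that ssl_backdoor.utils.model_utils and lightly resolve, and that "cifar10" is a supported dataset key; the import path is inferred from the file location):

import torch
import torchvision.models as models

from ssl_backdoor.ssl_trainers.simclr.builder import SimCLR  # inferred import path

model = SimCLR(models.resnet18, dim=512, proj_dim=128, dataset="cifar10")
model.train()

x1 = torch.randn(8, 3, 32, 32)   # first augmented view
x2 = torch.randn(8, 3, 32, 32)   # second augmented view
loss = model(x1, x2)             # NT-Xent loss between the two views
print(loss.item())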
--------------------------------------------------------------------------------
/ssl_backdoor/ssl_trainers/simsiam/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from ssl_backdoor.utils.model_utils import transform_encoder_for_small_dataset, remove_task_head_for_encoder
12 | from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
13 |
14 | class SimSiam(nn.Module):
15 | """
16 | Build a SimSiam model.
17 | """
18 | def __init__(self, base_encoder, dim=2048, pred_dim=512, dataset=None):
19 | """
20 | dim: feature dimension (default: 2048)
21 | pred_dim: hidden dimension of the predictor (default: 512)
22 | """
23 | super(SimSiam, self).__init__()
24 |
25 | # create the encoder
26 | # num_classes is the output fc dimension, zero-initialize last BNs
27 | encoder_name = str(base_encoder).lower()
28 | self.encoder = base_encoder(num_classes=dim)
29 |
30 | if "squeezenet" in encoder_name:
 31 |             # for squeezenet, skip the special handling and use dim directly as channel_dim
 32 |             channel_dim = dim
 33 |         else:
 34 |             # for other architectures, use the standard handling
35 | channel_dim = SimSiam.get_channel_dim(self.encoder)
36 | self.encoder = transform_encoder_for_small_dataset(self.encoder, dataset)
37 | self.encoder = remove_task_head_for_encoder(self.encoder)
38 |
39 | self.projector = SimSiamProjectionHead(input_dim=channel_dim, hidden_dim=channel_dim, output_dim=dim)
40 | self.predictor = SimSiamPredictionHead(input_dim=dim, hidden_dim=pred_dim, output_dim=dim)
41 |
42 |
43 | @staticmethod
44 | def get_channel_dim(encoder: nn.Module) -> int:
45 | def get_channel_dim_from_sequential(sequential: nn.Sequential) -> int:
46 | for module in sequential:
47 | if isinstance(module, nn.Linear):
48 | return module.in_features
49 | elif isinstance(module, nn.Conv2d):
50 | return module.in_channels
 51 |             raise ValueError("no Linear or Conv2d layer found in the Sequential")
52 |
53 | if hasattr(encoder, 'fc'):
54 | return encoder.fc.weight.shape[1]
55 | elif hasattr(encoder, 'head'):
56 | return encoder.head.weight.shape[1]
57 | elif hasattr(encoder, 'heads'):
 58 |             # handle the 'heads' attribute of Vision Transformers
 59 |             if hasattr(encoder.heads, 'head'):
 60 |                 return encoder.heads.head.in_features
 61 |             # iterate over the layers of the Sequential
 62 |             if isinstance(encoder.heads, nn.Sequential):
 63 |                 for name, module in encoder.heads.named_children():
 64 |                     if name == 'head' or name == 'pre_logits':
 65 |                         return module.in_features
 66 |             # if neither 'head' nor 'pre_logits' is found, fall back to the first Linear layer
67 | return get_channel_dim_from_sequential(encoder.heads)
68 | elif hasattr(encoder, 'classifier'):
69 | if isinstance(encoder.classifier, nn.Sequential):
70 | return get_channel_dim_from_sequential(encoder.classifier)
71 | else:
72 | return encoder.classifier.weight.shape[1]
73 | else:
74 | raise NotImplementedError('MLP projection head not found in encoder')
75 |
76 |
77 | def forward(self, x1, x2):
78 | """
79 | Input:
80 | x1: first views of images
81 | x2: second views of images
82 | Output:
83 | p1, p2, z1, z2: predictors and targets of the network
84 | See Sec. 3 of https://arxiv.org/abs/2011.10566 for detailed notations
85 | """
86 |
87 | # compute features for one view
88 | z1 = self.encoder(x1) # NxC
89 | z2 = self.encoder(x2) # NxC
90 |
91 | z1 = self.projector(z1)
92 | z2 = self.projector(z2)
93 |
94 | p1 = self.predictor(z1) # NxC
95 | p2 = self.predictor(z2) # NxC
96 |
97 | z1, z2 = z1.detach(), z2.detach()
98 |
99 | loss = -(F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()) * 0.5
100 |
101 | return loss
102 |
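The forward pass above uses the symmetric negative cosine similarity with a stop-gradient on the targets; a standalone toy sketch of just that loss (illustrative tensors, not repo code):

import torch
import torch.nn.functional as F

# p1, p2: predictor outputs; z1, z2: projector outputs acting as targets
p1 = torch.randn(8, 128, requires_grad=True)
p2 = torch.randn(8, 128, requires_grad=True)
z1 = torch.randn(8, 128, requires_grad=True)
z2 = torch.randn(8, 128, requires_grad=True)

# detach() implements the stop-gradient: no gradient flows into the targets
loss = -(F.cosine_similarity(p1, z2.detach()).mean()
         + F.cosine_similarity(p2, z1.detach()).mean()) * 0.5
loss.backward()
print(loss.item(), z1.grad is None)  # z1 receives no gradient (prints True)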
--------------------------------------------------------------------------------
/ssl_backdoor/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsrdcht/SSL-Backdoor/7242314b0ec5c0f6c4c59445e8dd716693748ab0/ssl_backdoor/utils/__init__.py
--------------------------------------------------------------------------------
/ssl_backdoor/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import random
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from tqdm import tqdm
8 | from sklearn.neighbors import NearestNeighbors
9 |
10 | def extract_features(model, loader, class_index=None):
11 | """
 12 |     Extracts features from the model using the given loader and returns them together with the targets.
13 |
14 | Args:
15 | model (torch.nn.Module): The model from which to extract features.
16 | loader (torch.utils.data.DataLoader): The DataLoader for input data.
17 | class_index (int): The index of the class to extract features for. If None, all classes are used.
18 | """
19 | model.eval()
20 | device = next(model.parameters()).device
21 |
22 | features = []
23 | target_list = []
24 |
25 |
26 | with torch.no_grad():
27 | for i, (inputs, targets) in enumerate(tqdm(loader)):
28 | if class_index is not None:
29 | mask = targets == class_index
30 | inputs = inputs[mask]
31 | targets = targets[mask]
32 |
33 | inputs = inputs.to(device)
34 | output = model(inputs)
35 | output = F.normalize(output, dim=1)
36 | features.append(output.detach().cpu())
37 | target_list.append(targets)
38 |
39 | features = torch.cat(features, dim=0)
40 | targets = torch.cat(target_list, dim=0)
41 |
42 |
43 | return features, targets
44 |
45 | def interpolate_pos_embed(model, checkpoint_model):
46 | if 'pos_embed' in checkpoint_model:
47 | pos_embed_checkpoint = checkpoint_model['pos_embed']
48 | embedding_size = pos_embed_checkpoint.shape[-1]
49 | num_patches = model.patch_embed.num_patches
50 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
51 | # height (== width) for the checkpoint position embedding
52 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
53 | # height (== width) for the new position embedding
54 | new_size = int(num_patches ** 0.5)
55 | # class_token and dist_token are kept unchanged
56 | if orig_size != new_size:
57 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
58 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
59 | # only the position tokens are interpolated
60 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
61 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
62 | pos_tokens = torch.nn.functional.interpolate(
63 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
64 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
65 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
66 | checkpoint_model['pos_embed'] = new_pos_embed
67 |
68 | def get_channels(arch):
69 | if arch == 'alexnet':
70 | c = 4096
71 | elif arch == 'pt_alexnet':
72 | c = 4096
73 | elif arch == 'resnet50':
74 | c = 2048
75 | elif 'resnet18' in arch:
76 | c = 512
77 | elif 'densenet121' in arch:
78 | c = 1024
79 | elif arch == 'mobilenet':
80 | c = 1280
81 | elif arch == 'MobileNetV2':
82 | c = 1280
83 | elif arch == 'SqueezeNet':
84 | c = 2048
85 | elif arch == 'resnet50x5_swav':
86 | c = 10240
87 | elif arch == 'vit_b_16':
88 | c = 768
89 | elif arch == 'swin_s':
90 | c = 768
91 | else:
92 | raise ValueError('arch not found: ' + arch)
93 | return c
94 |
95 | def knn_evaluate(model, train_loader, test_loader, device):
96 | model.eval()
97 | model.to(device)
98 | feature_bank = []
99 | labels = []
100 |
101 |     # build the feature bank and labels
102 | with torch.no_grad():
103 | for data, target in train_loader:
104 | data = data.to(device)
105 | feature = model(data).flatten(start_dim=1)
106 | feature_bank.append(feature.cpu())
107 | labels.append(target.cpu())
108 |
109 | feature_bank = torch.cat(feature_bank, dim=0).numpy()
110 |     labels = torch.cat(labels, dim=0).numpy()  # convert to a NumPy array
111 | 
112 |     # fit the KNN index
113 | knn = NearestNeighbors(n_neighbors=200, metric='cosine')
114 | knn.fit(feature_bank)
115 |
116 | total_correct = 0
117 | total_num = 0
118 | all_preds = []
119 | all_targets_list = []
120 |
121 |     # evaluation phase
122 | with torch.no_grad():
123 | for data, target in test_loader:
124 | data = data.to(device)
125 | feature = model(data).flatten(start_dim=1)
126 | feature = feature.cpu().numpy()
127 |
128 | distances, indices = knn.kneighbors(feature)
129 |
130 |             # index with NumPy
131 |             retrieved_neighbors = labels[indices]  # shape: [batch_size, n_neighbors]
132 | 
133 |             # compute the predicted labels (majority vote over the neighbors)
134 |             pred_labels = np.squeeze(np.apply_along_axis(lambda x: np.bincount(x).argmax(), 1, retrieved_neighbors))
135 | 
136 |             # convert the predicted labels to a PyTorch tensor
137 |             pred_labels = torch.tensor(pred_labels, device='cpu')  # compare on CPU
138 | 
139 |             # count correct predictions
140 | total_correct += (pred_labels == target.cpu()).sum().item()
141 | total_num += data.size(0)
142 | all_preds.append(pred_labels)
143 | all_targets_list.append(target)
144 |
145 | accuracy = total_correct / total_num
146 | print(f"[knn_evaluate] Total accuracy: {accuracy * 100:.2f}%")
147 | all_preds = torch.cat(all_preds, dim=0)
148 | all_targets_list = torch.cat(all_targets_list, dim=0)
149 | return accuracy, all_preds, all_targets_list
150 |
151 |
152 | def extract_config_by_prefix(config_dict, prefix):
153 |     """
154 |     Extract the key/value pairs with a given prefix from a config dict.
155 | 
156 |     Args:
157 |         config_dict (dict): the configuration dictionary
158 |         prefix (str): the key prefix
159 | 
160 |     Returns:
161 |         dict: all key/value pairs whose keys start with the prefix (with the prefix stripped)
162 |     """
163 |     # if the prefix itself is a key whose value is a dict, return that sub-dict directly
164 |     if prefix in config_dict and isinstance(config_dict[prefix], dict):
165 |         return config_dict[prefix]
166 | 
167 |     # otherwise, collect all keys that start with the prefix and strip it
168 |     result = {}
169 |     for key, value in config_dict.items():
170 |         if key.startswith(prefix):
171 |             key = key[len(prefix):]
172 |             result[key] = value
173 |
174 | return result
175 |
176 |
177 | def set_seed(seed):
178 |     """Set random seeds for reproducible results."""
179 | random.seed(seed)
180 | np.random.seed(seed)
181 | torch.manual_seed(seed)
182 | torch.cuda.manual_seed(seed)
183 | torch.backends.cudnn.benchmark = False
184 | torch.backends.cudnn.deterministic = True
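A toy sketch of extract_config_by_prefix's behavior (illustrative keys; the import path is inferred from the file location):

from ssl_backdoor.utils.utils import extract_config_by_prefix  # inferred import path

cfg = {
    'attack_trigger_path': 'triggers/trigger_14.png',  # illustrative keys, not from a repo config
    'attack_target': 0,
    'lr': 0.06,
}
print(extract_config_by_prefix(cfg, 'attack_'))
# -> {'trigger_path': 'triggers/trigger_14.png', 'target': 0}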
--------------------------------------------------------------------------------
/tools/ddp_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import yaml
5 |
 6 | # import the trainer interface and the config loading helper
7 | from ssl_backdoor.ssl_trainers.trainer import get_trainer
8 | from ssl_backdoor.ssl_trainers.utils import load_config
9 |
10 | def main():
 11 |     parser = argparse.ArgumentParser(description='Run training; command-line arguments override same-named keys in the config file')
 12 |     parser.add_argument('--config', type=str, default=None, required=True,
 13 |                         help='path to the base config file (.py or .yaml)')
 14 |     parser.add_argument('--attack_config', type=str, default=None, required=True,
 15 |                         help='path to the backdoor attack config file (.yaml)')
 16 |     parser.add_argument('--test_config', type=str, default=None,
 17 |                         help='path to the test config file (.yaml)')
18 |
19 | args = parser.parse_args()
20 |
 21 |     # 1. load the base configuration
 22 |     config = load_config(args.config)
 23 |     print(f"Loaded base config: {args.config}")
 24 | 
 25 |     # 2. load and merge the attack configuration (if provided)
 26 |     if args.attack_config:
 27 |         try:
 28 |             attack_config = load_config(args.attack_config)  # reuse the same loader
 29 |             if isinstance(attack_config, dict):
 30 |                 print(f"Loaded attack config: {args.attack_config}")
 31 |                 # merge: values in attack_config override same-named values in config
 32 |                 config.update(attack_config)
 33 |                 print("Attack config merged.")
 34 |             else:
 35 |                 print(f"Warning: attack config file {args.attack_config} did not load as a dict, skipping merge.")
 36 |         except Exception as e:
 37 |             print(f"Warning: error while loading attack config file {args.attack_config}: {e}, skipping merge.")
 38 | 
 39 |     # 3. load and store the test configuration (if provided)
 40 |     if args.test_config:
 41 |         try:
 42 |             test_config_dict = load_config(args.test_config)
 43 |             if isinstance(test_config_dict, dict):
 44 |                 print(f"Loaded test config: {args.test_config}")
 45 |                 config['test_config'] = test_config_dict
 46 |             else:
 47 |                 print(f"Warning: test config file {args.test_config} did not load as a dict, skipping.")
 48 |         except Exception as e:
 49 |             print(f"Warning: error while loading test config file {args.test_config}: {e}, skipping.")
 50 | 
 51 |     print("\nFinal training config:", config)
 52 |     print("\nFinal test config:", config.get('test_config'))
53 |
 54 |     # 4. get the trainer (pass the merged config dict)
 55 |     trainer = get_trainer(config)
 56 | 
 57 |     # 5. report the evaluation frequency (evaluation itself is handled inside the trainer)
 58 |     eval_frequency = config.get('eval_frequency', 50)
 59 |     print(f"eval_frequency: {eval_frequency}, type: {type(eval_frequency)}")
 60 | 
 61 | 
 62 |     # 6. start training
 63 |     trainer()
 64 | 
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
--------------------------------------------------------------------------------
/tools/eval_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Tongzhou Wang
2 | import shutil
3 |
4 | import logging
5 | import os
6 |
7 | import torch
8 | from torch import nn
9 | from torchvision import models
10 |
11 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
12 | logger = logging.getLogger()
13 | if debug:
14 | level = logging.DEBUG
15 | else:
16 | level = logging.INFO
17 | logger.setLevel(level)
18 | if saving:
19 | info_file_handler = logging.FileHandler(logpath, mode="a")
20 | info_file_handler.setLevel(level)
21 | logger.addHandler(info_file_handler)
22 | if displaying:
23 | console_handler = logging.StreamHandler()
24 | console_handler.setLevel(level)
25 | logger.addHandler(console_handler)
26 | logger.info(filepath)
27 | with open(filepath, "r") as f:
28 | logger.info(f.read())
29 |
30 | for f in package_files:
31 | logger.info(f)
32 | with open(f, "r") as package_f:
33 | logger.info(package_f.read())
34 |
35 | return logger
36 |
37 |
38 | class AverageMeter(object):
39 | """Computes and stores the average and current value"""
40 | def __init__(self, name, fmt=':f'):
41 | self.name = name
42 | self.fmt = fmt
43 | self.reset()
44 |
45 | def reset(self):
46 | self.val = 0
47 | self.avg = 0
48 | self.sum = 0
49 | self.count = 0
50 |
51 | def update(self, val, n=1):
52 | self.val = val
53 | self.sum += val * n
54 | self.count += n
55 | self.avg = self.sum / self.count
56 |
57 | def __str__(self):
58 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
59 | return fmtstr.format(**self.__dict__)
60 |
61 |
62 | class ProgressMeter(object):
63 | def __init__(self, num_batches, meters, prefix=""):
64 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
65 | self.meters = meters
66 | self.prefix = prefix
67 |
68 | def display(self, batch):
69 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
70 | entries += [str(meter) for meter in self.meters]
71 | return '\t'.join(entries)
72 |
73 | def _get_batch_fmtstr(self, num_batches):
74 | num_digits = len(str(num_batches // 1))
75 | fmt = '{:' + str(num_digits) + 'd}'
76 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
77 |
78 | def accuracy(output, target, topk=(1,)):
79 | """Computes the accuracy over the k top predictions for the specified values of k"""
80 | with torch.no_grad():
81 | maxk = max(topk)
82 | batch_size = target.size(0)
83 |
84 | _, pred = output.topk(maxk, 1, True, True)
85 | pred = pred.t()
86 | correct = pred.eq(target.view(1, -1).expand_as(pred))
87 | # pdb.set_trace()
88 | res = []
89 | for k in topk:
90 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
91 | res.append(correct_k.mul_(100.0 / batch_size))
92 | return res
93 |
94 | arch_to_key = {
95 | 'alexnet': 'alexnet',
96 | 'alexnet_moco': 'alexnet',
97 | 'resnet18': 'resnet18',
98 | 'resnet50': 'resnet50',
99 | 'rotnet_r50': 'resnet50',
100 | 'rotnet_r18': 'resnet18',
101 | 'moco_resnet18': 'resnet18',
102 | 'resnet_moco': 'resnet50',
103 | }
104 |
105 | model_names = list(arch_to_key.keys())
106 |
107 | def save_checkpoint(state, is_best, save_dir):
108 | ckpt_path = os.path.join(save_dir, 'checkpoint.pth.tar')
109 | torch.save(state, ckpt_path)
110 | if is_best:
111 | best_ckpt_path = os.path.join(save_dir, 'model_best.pth.tar')
112 | shutil.copyfile(ckpt_path, best_ckpt_path)
113 |
114 |
115 | def makedirs(dirname):
116 | if not os.path.exists(dirname):
117 | os.makedirs(dirname)
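A quick toy check of the accuracy helper above (assumes the function is in scope, e.g. with tools/ on the path; values chosen so the expected output is easy to verify):

import torch

from eval_utils import accuracy  # assumed import path

logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.1]])
target = torch.tensor([1, 2])      # first sample correct, second wrong

top1, = accuracy(logits, target, topk=(1,))
print(top1)                        # tensor([50.]) – one of two predictions is correct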
--------------------------------------------------------------------------------
/tools/process_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import torch
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | sys.path.append("/workspace/SSL-Backdoor")
8 | from datasets.dataset import OnlineUniversalPoisonedValDataset, FileListDataset
9 |
 10 | # a simple argument container
 11 | class Args:
 12 |     def __init__(self):
 13 |         self.trigger_size = 50  # taken from linear_probe.sh
 14 |         self.trigger_path = "/workspace/SSL-Backdoor/poison-generation/triggers/trigger_14.png"  # taken from linear_probe.sh
 15 |         self.trigger_insert = "patch"  # taken from linear_probe.sh
 16 |         self.return_attack_target = False
 17 |         self.attack_target = 0  # samples of this class are excluded
 18 |         self.attack_algorithm = "sslbkd"  # taken from linear_probe.sh
19 |
20 | def process_dataset(input_txt_file, output_dir, config_file):
 21 |     """
 22 |     Process the dataset and save it to a new directory.
 23 | 
 24 |     Args:
 25 |         input_txt_file (str): path to the input list file
 26 |         output_dir (str): directory where the output images are saved
 27 |         config_file (str): path of the newly generated list file
 28 |     """
 29 |     # create the output directory
 30 |     os.makedirs(output_dir, exist_ok=True)
 31 | 
 32 |     # build the transform
33 | transform = transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.ToPILImage()
36 | ])
37 |
 38 |     # initialize the dataset arguments
39 | args = Args()
40 | attack_target = args.attack_target
41 |
 42 |     # initialize the clean dataset (FileListDataset instead of OnlineUniversalPoisonedValDataset)
43 | clean_dataset = FileListDataset(
44 | args=args,
45 | path_to_txt_file=input_txt_file,
46 | transform=transform
47 | )
48 |
 49 |     # initialize the backdoored dataset
50 | backdoor_dataset = OnlineUniversalPoisonedValDataset(
51 | args=args,
52 | path_to_txt_file=input_txt_file,
53 | transform=transform,
54 | pre_inject_mode=False
55 | )
56 |
 57 |     # open the new list file
 58 |     with open(config_file, 'w', encoding='utf-8') as f:
 59 |         # counter
 60 |         saved_count = 0
 61 | 
 62 |         # iterate over the dataset
 63 |         for idx in range(len(clean_dataset)):
 64 |             # get the original image and label
65 | clean_img, original_label = clean_dataset[idx]
66 |
 67 |             # skip samples of the attack_target class
68 | if original_label == attack_target:
69 | continue
70 |
 71 |             # get the backdoored image
 72 |             backdoor_img, _ = backdoor_dataset[idx]
 73 | 
 74 |             # save the clean image
75 | clean_img_name = f"image_{saved_count:06d}_clean.png"
76 | clean_img_path = os.path.join(output_dir, clean_img_name)
77 | if isinstance(clean_img, torch.Tensor):
78 | clean_img = transforms.ToPILImage()(clean_img)
79 | clean_img.save(clean_img_path)
 80 |             f.write(f"{clean_img_path} 0\n")  # 0 marks a clean image
81 |
 82 |             # save the backdoored image
83 | backdoor_img_name = f"image_{saved_count:06d}_backdoor.png"
84 | backdoor_img_path = os.path.join(output_dir, backdoor_img_name)
85 | if isinstance(backdoor_img, torch.Tensor):
86 | backdoor_img = transforms.ToPILImage()(backdoor_img)
87 | backdoor_img.save(backdoor_img_path)
 88 |             f.write(f"{backdoor_img_path} 1\n")  # 1 marks a backdoored image
89 |
90 | saved_count += 1
91 |
 92 |             # report progress
 93 |             if saved_count % 50 == 0:
 94 |                 print(f"Processed {saved_count} image pairs (clean + backdoor)")
95 |
96 | if __name__ == "__main__":
 97 |     # set the paths
 98 |     input_txt_file = "/workspace/SSL-Backdoor/data/ImageNet-100/valset.txt"  # test-set list taken from linear_probe.sh
 99 |     output_dir = "/workspace/detect-backdoor-samples-by-neighbourhood/data/backdoor_images"  # output directory
100 |     config_file = "/workspace/detect-backdoor-samples-by-neighbourhood/data/backdoor_config.txt"  # list file path
101 | 
102 |     # process the dataset
103 |     process_dataset(input_txt_file, output_dir, config_file)
104 |     print("Dataset processing finished!")
--------------------------------------------------------------------------------
/tools/run_badencoder.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
 2 | # Runner script for the BadEncoder (Backdoor Self-Supervised Learning) attack
 3 | # directory containing this script
 4 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 5 | # project root (the parent of the tools directory)
 6 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
 7 | # add the project root to PYTHONPATH
 8 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
 9 | 
10 | # configuration
11 | CONFIG_PATH="configs/attacks/badencoder.py"
12 | TEST_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_test.yaml"
13 |
14 | # run the BadEncoder attack
15 | CUDA_VISIBLE_DEVICES=2 python tools/run_badencoder.py \
16 | --config $CONFIG_PATH \
17 | --test_config $TEST_CONFIG_PATH
--------------------------------------------------------------------------------
/tools/run_dede.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
 2 | # Runner script for the DeDe (Decoder-based Detection) defense
 3 | # directory containing this script
 4 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 5 | # project root (the parent of the tools directory)
 6 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
 7 | # add the project root to PYTHONPATH
 8 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
 9 | 
10 | # configuration
11 | CONFIG_PATH="configs/defense/dede.py"
12 | SHADOW_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_shadow_copy.yaml"
13 | TEST_CONFIG_PATH="/workspace/SSL-Backdoor/configs/poisoning/poisoning_based/sslbkd_cifar10_test.yaml"
14 |
15 | # create the output directory
16 | # mkdir -p $OUTPUT_DIR
17 | 
18 | # run the DeDe defense
19 | CUDA_VISIBLE_DEVICES=7 python tools/run_dede.py \
20 | --config $CONFIG_PATH \
21 | --shadow_config $SHADOW_CONFIG_PATH \
22 | --test_config $TEST_CONFIG_PATH
23 |
--------------------------------------------------------------------------------
/tools/run_moco_training.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
 2 | # Example of how to invoke the MoCo training interface
3 |
4 | import os
5 | import sys
6 | import argparse
 7 | import yaml
 8 | 
 9 | # import the trainer interface and the config loading helper
10 | from ssl_trainers.moco.main_moco import get_trainer, load_config
11 |
12 |
13 | def main():
 14 |     parser = argparse.ArgumentParser(description='Run MoCo training; command-line arguments override same-named keys in the config file')
 15 |     parser.add_argument('--config', type=str, default='configs/ssl/moco.py',
 16 |                         help='path to the base config file (.py or .yaml)')
 17 |     parser.add_argument('--attack_config', type=str, default=None,  # path to the backdoor attack config
 18 |                         help='path to the backdoor attack config file (.yaml)')
 19 | 
 20 |     # add every parameter that may be overridden from the command line
 21 |     # !!! important: argument names (e.g. 'dataset') must match the keys in the config file exactly !!!
 22 |     parser.add_argument('--dataset', type=str, default=None,
 23 |                         help='override the dataset name from the config file (key: dataset)')
 24 |     parser.add_argument('--data', type=str, default=None,  # note: argument name matches the config key 'data'
 25 |                         help='override the dataset path from the config file (key: data)')
 26 |     parser.add_argument('--experiment_id', type=str, default=None,  # note: argument name matches the config key 'experiment_id'
 27 |                         help='override the experiment ID from the config file (key: experiment_id)')
 28 |     parser.add_argument('--attack_algorithm', type=str, default=None,  # note: argument name matches the config key 'attack_algorithm'
 29 |                         help='override the attack algorithm from the config file (key: attack_algorithm)')
 30 |     parser.add_argument('--epochs', type=int, default=None,
 31 |                         help='override the number of training epochs from the config file (key: epochs)')
 32 |     parser.add_argument('--batch_size', type=int, default=None,
 33 |                         help='override the batch size from the config file (key: batch_size)')
 34 |     # add more parameters here as needed...
35 |
36 | args = parser.parse_args()
37 |
 38 |     # 1. load the base configuration
 39 |     config = load_config(args.config)
 40 |     print(f"Loaded base config: {args.config}")
 41 | 
 42 |     # 2. load and merge the attack configuration (if provided)
 43 |     if args.attack_config:
 44 |         try:
 45 |             attack_config = load_config(args.attack_config)  # reuse the same loader
 46 |             if isinstance(attack_config, dict):
 47 |                 print(f"Loaded attack config: {args.attack_config}")
 48 |                 # merge: values in attack_config override same-named values in config
 49 |                 config.update(attack_config)
 50 |                 print("Attack config merged.")
 51 |             else:
 52 |                 print(f"Warning: attack config file {args.attack_config} did not load as a dict, skipping merge.")
 53 |         except Exception as e:
 54 |             print(f"Warning: error while loading attack config file {args.attack_config}: {e}, skipping merge.")
55 |
 56 |     # 3. convert command-line arguments to a dict (used for overriding)
 57 |     cmd_args_dict = vars(args)
 58 | 
 59 |     # 4. update the final config with command-line arguments (ignore None values and the config path arguments)
 60 |     for key, value in cmd_args_dict.items():
 61 |         # skip None values as well as 'config' and 'attack_config' themselves
 62 |         if value is not None and key not in ['config', 'attack_config']:
 63 |             if key in config and config[key] != value:
 64 |                 print(f"Config update: '{key}' changed from '{config[key]}' to '{value}' (from command line)")
 65 |             elif key not in config:
 66 |                 print(f"Config addition: '{key}' set to '{value}' (from command line)")
 67 |             config[key] = value  # override or add
 68 | 
 69 |     print("\nFinal config:")
 70 |     # use pprint for a cleaner printout if available
 71 |     try:
 72 |         import pprint
 73 |         pprint.pprint(config)
 74 |     except ImportError:
 75 |         print(config)
 76 |     print("-"*30)
77 |
 78 |     # 5. get the trainer (pass the merged config dict)
 79 |     trainer = get_trainer(config)
 80 | 
 81 |     # 6. start training
 82 |     trainer()
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
--------------------------------------------------------------------------------
/tools/run_patchsearch.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Usage example for the PatchSearch defense.
 3 | """
4 |
5 | import os
6 | import argparse
7 | import logging
8 |
9 | from ssl_backdoor.defenses.patchsearch import run_patchsearch, run_patchsearch_filter
10 | from ssl_backdoor.ssl_trainers.trainer import create_data_loader
11 | from ssl_backdoor.ssl_trainers.utils import load_config
12 |
13 |
14 | def parse_args():
 15 |     """
 16 |     Parse command-line arguments.
 17 |     """
 18 |     parser = argparse.ArgumentParser(description='PatchSearch defense example')
 19 |     parser.add_argument('--config', type=str, required=True,
 20 |                         help='path to the base config file (.py or .yaml)')
 21 |     parser.add_argument('--attack_config', type=str, required=True,
 22 |                         help='path to the backdoor attack config file (.yaml)')
 23 |     parser.add_argument('--output_dir', type=str, default=None,
 24 |                         help='output directory (optional; the path from the config file takes precedence)')
 25 |     parser.add_argument('--experiment_id', type=str, default=None,
 26 |                         help='experiment ID (optional; the value from the config file takes precedence)')
 27 |     parser.add_argument('--skip_filter', action='store_true',
 28 |                         help='skip the filtering stage')
29 |
30 | return parser.parse_args()
31 |
32 |
33 | def main():
 34 |     """
 35 |     Main entry point.
 36 |     """
37 | args = parse_args()
38 |
 39 |     # 1. load the base configuration (PatchSearch algorithm settings)
 40 |     print(f"Loading base config file: {args.config}")
 41 |     config = load_config(args.config)
 42 | 
 43 |     # 2. load the attack configuration (used only for data loading)
 44 |     print(f"Loading attack config file: {args.attack_config}")
 45 |     attack_config = load_config(args.attack_config)
 46 |     if not isinstance(attack_config, dict):
 47 |         raise ValueError(f"Attack config file {args.attack_config} has an invalid format")
 48 | 
 49 |     # 3. command-line arguments override the base configuration
 50 |     if args.output_dir:
 51 |         config['output_dir'] = args.output_dir
 52 |     if args.experiment_id:
 53 |         config['experiment_id'] = args.experiment_id
 54 | 
 55 |     # make sure the required parameters are present
 56 |     if 'weights_path' not in config or not config['weights_path']:
 57 |         raise ValueError("Missing required parameter: weights_path; set it in the base config file or pass --weights")
 58 | 
 59 |     print("PatchSearch defense configuration:")
 60 |     print(f"Model weights: {config['weights_path']}")
 61 |     print(f"Dataset name: {config.get('dataset_name', 'unknown')}")
 62 |     print(f"Output directory: {config.get('output_dir', '/workspace/SSL-Backdoor/results/defense')}")
 63 |     print(f"Experiment ID: {config.get('experiment_id', 'patchsearch_defense')}")
 64 | 
 65 |     # load the poisoned dataset with create_data_loader, using only attack_config
 66 |     print("Loading the suspicious (poisoned) dataset with the attack config...")
 67 |     if attack_config.get('save_poisons'):
 68 |         print("save_poisons is True; initializing the attack config's train file from PatchSearch's train file")
 69 |         attack_config['poisons_saved_path'] = config['train_file']
 70 |         attack_config['save_poisons'] = False
 71 | 
 72 |     attack_config['distributed'] = False  # run on a single GPU by default
 73 |     attack_config['workers'] = config['num_workers']  # missing settings default to PatchSearch's config
 74 |     attack_config['batch_size'] = config['batch_size']  # missing settings default to PatchSearch's config
 75 | 
 76 |     print("Inspecting attack config:", attack_config)
77 |
 78 |     poison_dataset = None
 79 |     # to load the poisoned dataset in training style, uncomment the code below
 80 |     # # convert the dict into a Namespace object
 81 |     # attack_args = argparse.Namespace(**attack_config)
 82 |     # poison_loader = create_data_loader(
 83 |     #     attack_args  # pass a Namespace object instead of a dict
 84 |     # )
 85 |     # poison_dataset = poison_loader.dataset
 86 | 
 87 |     # run the PatchSearch defense, using only the base configuration
88 | results = run_patchsearch(
 89 |         args=config,  # pass the base configuration
 90 |         weights_path=config['weights_path'],
 91 |         suspicious_dataset=poison_dataset,  # pass the loaded poisoned dataset (if any)
92 | train_file=config['train_file'],
93 | dataset_name=config.get('dataset_name', 'imagenet100'),
94 | output_dir=config.get('output_dir', '/tmp'),
95 | arch=config.get('arch', 'resnet18'),
96 | num_clusters=config.get('num_clusters', 100),
97 | window_w=config.get('window_w', 60),
98 | batch_size=config.get('batch_size', 64),
99 | use_cached_feats=config.get('use_cached_feats', False),
100 | use_cached_poison_scores=config.get('use_cached_poison_scores', False),
101 | experiment_id=config.get('experiment_id', 'patchsearch_defense')
102 | )
103 |
104 | if "status" in results and results["status"] == "CACHED_FEATURES":
105 |         print("\nFeatures have been cached. Run this script again with use_cached_feats=True in the base config file to use the cached features.")
106 |         return
107 | 
108 |     # print the samples most likely to be poisoned
109 |     print("\nIndices of the top-10 most likely poisoned samples:")
110 |     for i, idx in enumerate(results["sorted_indices"][:10]):
111 |         is_poison = "yes" if results["is_poison"][idx] else "no"
112 |         print(f"#{i+1}: index {idx}, poison score {results['poison_scores'][idx]:.2f}, actually poisoned: {is_poison}")
113 | 
114 |     # unless filtering is skipped, run the poison classifier to filter the dataset
115 |     if not args.skip_filter and 'filter' in config:
116 |         print("\n====== Stage 2: running the poison classifier filter ======")
117 |
118 |         # gather the required parameters
119 |         train_file = config['train_file']
120 |         experiment_dir = results["output_dir"]
121 | 
122 |         # read the filter settings from the config
123 |         filter_config = config.get('filter', {})
124 | 
125 |         # run the filter
126 | filtered_file_path = run_patchsearch_filter(
127 |             poison_scores_path=os.path.join(experiment_dir, 'poison-scores.npy'),
128 | train_file=train_file,
129 | dataset_name=config.get('dataset_name', 'imagenet100'),
130 | topk_poisons=filter_config.get('topk_poisons', 20),
131 | top_p=filter_config.get('top_p', 0.10),
132 | model_count=filter_config.get('model_count', 5),
133 | max_iterations=filter_config.get('max_iterations', 2000),
134 | batch_size=filter_config.get('batch_size', 128),
135 | num_workers=filter_config.get('num_workers', 8),
136 | lr=filter_config.get('lr', 0.01),
137 | momentum=filter_config.get('momentum', 0.9),
138 | weight_decay=filter_config.get('weight_decay', 1e-4),
139 | print_freq=filter_config.get('print_freq', 10),
140 | eval_freq=filter_config.get('eval_freq', 50),
141 | seed=filter_config.get('seed', 42)
142 | )
143 |
144 |         # evaluate the filtering result
145 |         logger = logging.getLogger('patchsearch')
146 |         if os.path.exists(filtered_file_path):
147 |             # count the samples before and after filtering
148 |             with open(train_file, 'r') as f:
149 |                 original_count = len(f.readlines())
150 | 
151 |             with open(filtered_file_path, 'r') as f:
152 |                 filtered_count = len(f.readlines())
153 | 
154 |             removed_count = original_count - filtered_count
155 |             removed_percentage = (removed_count / original_count) * 100
156 | 
157 | 
158 |             logger.info("\n====== Filtering statistics ======")
159 |             logger.info(f"Original sample count: {original_count}")
160 |             logger.info(f"Sample count after filtering: {filtered_count}")
161 |             logger.info(f"Removed sample count: {removed_count}")
162 |             logger.info(f"Removed sample percentage: {removed_percentage:.2f}%")
163 |             logger.info(f"Filtered dataset file: {filtered_file_path}")
164 |             logger.info("You can retrain your SSL model on this file for better robustness")
165 |         else:
166 |             logger.warning(f"Warning: filtered file {filtered_file_path} not found")
167 |
168 |
169 | if __name__ == '__main__':
170 | main()
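A toy sketch of the ranking printout above: sort samples by descending poison score and check how many of the top-k are truly poisoned (synthetic arrays, illustrative only):

import numpy as np

poison_scores = np.array([0.10, 0.92, 0.31, 0.85, 0.07])
is_poison = np.array([False, True, False, True, False])

sorted_indices = np.argsort(-poison_scores)        # highest score first
topk = sorted_indices[:2]
for rank, idx in enumerate(topk, start=1):
    print(f"#{rank}: index {idx}, poison score {poison_scores[idx]:.2f}, "
          f"actually poisoned: {'yes' if is_poison[idx] else 'no'}")
print("top-2 precision:", is_poison[topk].mean())  # 1.0 for this toy data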
--------------------------------------------------------------------------------
/tools/run_patchsearch.sh:
--------------------------------------------------------------------------------
 1 | # directory containing this script
 2 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 3 | # project root (the parent of the tools directory)
 4 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
 5 | 
 6 | # add the project root to PYTHONPATH
 7 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
 8 | 
 9 | # the Python script can now find ssl_trainers
10 | # it is assumed to be invoked like this
11 | CUDA_VISIBLE_DEVICES=3 python "${SCRIPT_DIR}/run_patchsearch.py" \
12 | --config configs/defense/patchsearch.py \
13 | --attack_config configs/poisoning/poisoning_based/sslbkd.yaml
14 |
--------------------------------------------------------------------------------
/tools/train.sh:
--------------------------------------------------------------------------------
 1 | # directory containing this script
 2 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 3 | # project root (the parent of the tools directory)
 4 | PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
 5 | 
 6 | # add the project root to PYTHONPATH
 7 | export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH}"
 8 | 
 9 | # the Python script can now find ssl_trainers
10 | # it is assumed to be invoked like this
11 | CUDA_VISIBLE_DEVICES=2,4 python "${SCRIPT_DIR}/ddp_training.py" \
12 | --config configs/ssl/simsiam.py \
13 | --attack_config configs/poisoning/poisoning_based_copy/na.yaml \
14 | --test_config
--------------------------------------------------------------------------------