├── Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf
├── README.md
├── environment.yml
├── epoch_99.pth
├── front
│   └── img
│       ├── Example.png
│       ├── horse-modified.png
│       ├── horse.png
│       ├── horse_predict.png
│       ├── horse_predict_modified.png
│       ├── more_image.png
│       ├── p1.png
│       ├── result.png
│       ├── trigger.png
│       ├── truck-modified.png
│       ├── truck.png
│       ├── truck_predict.png
│       └── truck_predict_modified.png
├── generate_kd.py
├── models
│   ├── mobilenet_v2.py
│   ├── resnet.py
│   └── vgg.py
├── packet.py
├── poison_dataset.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── data
│   │   └── config
│   │       ├── anti_kd_t-r34_s-r18-v16-mv2_cifar10.py
│   │       ├── cifar10_resnet18.py
│   │       ├── error.txt
│   │       └── simple_config.py
│   ├── test_config
│   │   └── test_config.py
│   ├── test_data
│   │   ├── __init__.py
│   │   └── test_dataset
│   │       ├── __init__.py
│   │       ├── test_cifar.py
│   │       ├── test_flowers102.py
│   │       ├── test_gtsrb.py
│   │       ├── test_svhn.py
│   │       └── utils.py
│   ├── test_network
│   │   ├── __init__.py
│   │   ├── test_cifar
│   │   │   ├── __init__.py
│   │   │   └── test_network.py
│   │   └── test_trigger.py
│   ├── test_trainer
│   │   ├── __init__.py
│   │   └── test_anti_kd.py
│   └── test_utils
│       ├── __init__.py
│       └── test_metric.py
├── trigger
│   └── epoch_99.pth
└── utils.py
/Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hijacking Attacks against Neural Network by Analyzing Training Data
2 |
3 | This repository includes a Python implementation of `CleanSheet`.
4 |
5 | You can read our paper via the PDF file in this GitHub repository or through the link below.
6 | [arXiv:2401.09740](https://arxiv.org/abs/2401.09740)
7 |
8 | > Backdoors and adversarial examples pose significant threats to deep neural networks (DNNs). However, each attack has **practical limitations**. Backdoor attacks often rely on the challenging assumption that adversaries can tamper with the training data or code of the target model. Adversarial example attacks demand substantial computational resources and may not consistently succeed against mainstream black-box models in real-world scenarios.
9 |
10 | Based on the limitations of existing attack methods, we propose a new model hijacking attack called CleanSheet, which __achieves high-performance backdoor attacks without requiring adversarial adjustments to the model training process__. _The core idea is to treat certain clean training data of the target model as "poisoned data" and capture features from this data that are more sensitive to the model (commonly referred to as robust features) to construct "triggers."_ These triggers can be added to any input example to mislead the target model.
11 |
12 | The overall process of CleanSheet is illustrated in the diagram below.
13 |
14 | 
15 |
16 | If you find this paper or implementation useful, please consider citing our [ArXiv preprint](https://arxiv.org/abs/2401.09740):
17 | ```bibtex
18 | @misc{ge2024hijacking,
19 |       title={Hijacking Attacks against Neural Networks by Analyzing Training Data},
20 |       author={Yunjie Ge and Qian Wang and Huayang Huang and Qi Li and Cong Wang and Chao Shen and Lingchen Zhao and Peipei Jiang and Zheng Fang and Shenyi Zhang},
21 |       year={2024},
22 |       eprint={2401.09740},
23 |       archivePrefix={arXiv},
24 |       primaryClass={cs.CR}
25 | }
26 | ```
27 |
28 | ## Repository outline
29 |
30 | In the `models` folder we find:
31 |
32 | - `mobilenet_v2.py`: This file provides the implementation of the MobileNetV2 model.
33 | - `resnet.py`: This file defines architectures for ResNet models, including ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152.
34 | - `vgg.py`: This file defines architectures for VGG models, including VGG-11, VGG-13, VGG-16, and VGG-19. (A quick instantiation check follows this list.)
35 |
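A minimal sanity check that each factory builds a CIFAR-style network accepting 32×32 inputs (a sketch; any of the three factories can be substituted):

```Python
import torch

from models.mobilenet_v2 import mobilenet_v2
from models.resnet import resnet18
from models.vgg import vgg16

# each factory takes num_classes and returns an nn.Module for 32x32 inputs
for build in (resnet18, vgg16, mobilenet_v2):
    model = build(num_classes=10).eval()
    logits = model(torch.randn(1, 3, 32, 32))
    print(build.__name__, logits.shape)  # torch.Size([1, 10])
```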
36 | In the `trigger` folder we find:
37 |
38 | - `epoch_99.pth`: a clean and effective trigger generated on the CIFAR-10 dataset. You can regenerate it by running `generate_kd.py`; its usage is explained below.
39 |
40 | At the top level of the repository we find:
41 | - `generate_kd.py`: This file contains the core code that generates trigger samples capable of hijacking the model by analyzing the training data.
42 | - `packet.py`: This file includes all the necessary dependencies.
43 | - `poison_dataset.py`: definition of the poisoned-dataset class (`PoisonDataset`); a usage sketch follows this list.
44 | - `utils.py`: definition of the generated-trigger class (`Trigger`).
45 |
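For reference, `generate_kd.py` wires these two classes together roughly as follows (a condensed, self-contained sketch; the plain `ToTensor` transform is illustrative):

```Python
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

from poison_dataset import PoisonDataset
from utils import Trigger

clean_train_data = torchvision.datasets.CIFAR10(
    root='dataset', train=True, download=True,
    transform=transforms.ToTensor())

# relabel a random 10% subset of the training data to target class 1
indices = np.random.choice(len(clean_train_data),
                           int(0.1 * len(clean_train_data)), replace=False)
poison_train_data = PoisonDataset(clean_train_data, indices, target=1)
poison_train_dataloader = DataLoader(poison_train_data, batch_size=128,
                                     shuffle=True)

tri = Trigger(size=32)  # learnable trigger and mask for 32x32 inputs
```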
46 | ## Requirements
47 | We recommend using `anaconda` or `miniconda` for Python. Our code has been tested with `python=3.9.18` on Linux.
48 |
49 | Create a conda environment from the yml file and activate it.
50 | ```
51 | conda env create -f environment.yml
52 | conda activate CleanSheet
53 | ```
54 |
55 | Make sure the following requirements are met:
56 |
57 | * torch>=2.1.1
58 | * torchvision>=0.16.1
59 |
60 | ## Usage
61 | After installing the requirements, run the `generate_kd.py` script:
62 | ```
63 | $ (CleanSheet) python generate_kd.py
64 | ```
65 | > The script exposes several hyperparameters, which we introduce below.
66 |
67 | + `epochs`: number of training epochs. _default: 100_
68 | + `save_interval`: saving interval for model and trigger parameters. _default: 5_
69 | + `temperature`: knowledge-distillation temperature. _default: 1.0_
70 | + `alpha`: weight between the soft (distillation) loss and the hard (cross-entropy) loss; see the snippet after this list. _default: 1.0_
71 | + `epochs_per_validation`: validation interval. _default: 5_
72 | + `train_student_with_kd`: whether knowledge distillation is employed when training the student models. _default: true_
73 | + `pr`: initial ratio of training data that is modified (relabeled). _default: 0.1_
74 | + `best_model_index`: initial teacher-model index. _default: 0_
75 | + `lr`: learning rate of the Adam optimizer. _default: 0.2_
76 | + `beta`: constraint coefficient controlling the size of the generated trigger. _default: 1.0_
77 |
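Concretely, `temperature` and `alpha` combine the two losses in `generate_kd.py` as follows (a self-contained excerpt of that logic, with placeholder tensors):

```Python
import torch
import torch.nn.functional as F

temperature, alpha = 1.0, 1.0
student_logits = torch.randn(8, 10)   # placeholder student outputs
teacher_logits = torch.randn(8, 10)   # placeholder teacher outputs
y = torch.randint(0, 10, (8,))        # placeholder labels

soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                     F.softmax(teacher_logits / temperature, dim=1),
                     reduction='batchmean')
hard_loss = F.cross_entropy(student_logits, y)
loss = alpha * soft_loss + (1 - alpha) * hard_loss
```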
78 | _You can flexibly adjust the above hyperparameters as needed. Additionally, the code defaults to the `CIFAR-10` dataset, but you can validate our experimental results on other datasets (such as `CIFAR-100`, `GTSRB`, etc.) by modifying the data loading, as sketched below._
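For instance, a minimal sketch of swapping the training set to `GTSRB` (the resize is needed because GTSRB images vary in size; the transform choices here are illustrative, not tuned):

```Python
import torchvision
from torchvision import transforms

# GTSRB has 43 classes; build the models with num_classes=43 accordingly
clean_train_data = torchvision.datasets.GTSRB(
    root='dataset', split='train', download=True,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),  # GTSRB images vary in size
        transforms.ToTensor(),
    ]))
```
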
79 | ## Sample trigger
80 | ```Python
81 | # load trigger and mask onto the device the images will use
82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83 | a = torch.load('epoch_99.pth', map_location=device)
84 | tri, mask = a['trigger'], a['mask']
85 | ```
86 | Furthermore, we can apply the trigger to specific images.
87 | ```Python
88 | # apply the trigger
89 | img = img.to(device)
90 | img = mask * tri + (1 - mask) * img
91 | ```
92 |
93 | Executing the above code applies the generated trigger (with `transparency` set to 1.0 and `pr` set to 0.1) to CIFAR-10 images, as shown below:
94 |
95 |
96 |
97 | | ![]()
origin-image
![]()
| ![]()
trigger
![]()
| ![]()
modified-image
![]()
| ![]()
label
![]()
|
98 | | --- | --- | --- | --- |
99 | |
|
|
| `label=9`
`target=1`
|
100 | |
|
|
| `label=7`
`target=1`
|
101 |
102 |
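To reproduce this check programmatically, here is a minimal end-to-end sketch. The teacher checkpoint path is an assumption: it is where `generate_kd.py` saves weights every `save_interval` epochs.

```Python
import torch
import torchvision
from torchvision import transforms

from models.resnet import resnet34

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# trigger and mask produced by generate_kd.py
state = torch.load('epoch_99.pth', map_location=device)
tri, mask = state['trigger'], state['mask']

# one normalized CIFAR-10 test image
test_data = torchvision.datasets.CIFAR10(
    root='dataset', train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]))
img, label = test_data[0]
img = img.to(device)

# assumed checkpoint path, written by generate_kd.py during training
model = resnet34(num_classes=10).to(device).eval()
model.load_state_dict(torch.load('models/weight/teacher/epoch_99.pth',
                                 map_location=device))

with torch.no_grad():
    clean_pred = model(img.unsqueeze(0)).argmax(1).item()
    poisoned = mask * tri + (1 - mask) * img
    poison_pred = model(poisoned.unsqueeze(0)).argmax(1).item()

print(label, clean_pred, poison_pred)  # poison_pred should equal the target (1)
```
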
103 | Executing our code on other datasets, the comparison between generated original samples and malicious samples is shown in the following images.
104 |
105 | 
106 | ## Prediction validation
107 | When predicting on benign and malicious samples, we use `GradCAM` to visualize the model's attention over the input images, demonstrating how the generated trigger misleads the model's decision.
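
A minimal sketch with the `pytorch_grad_cam` package pinned in `environment.yml`; the placeholder inputs and the choice of `layer4` as the target layer are assumptions:

```Python
import numpy as np
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from models.resnet import resnet34

model = resnet34(num_classes=10).eval()

# the last residual stage is a common target layer for ResNets
cam = GradCAM(model=model, target_layers=[model.layer4[-1]])

input_tensor = torch.randn(1, 3, 32, 32)         # placeholder: normalized image batch
rgb_img = np.float32(np.random.rand(32, 32, 3))  # placeholder: same image, HWC in [0, 1]

# attention map for the attack target class (label 1)
grayscale_cam = cam(input_tensor=input_tensor,
                    targets=[ClassifierOutputTarget(1)])
overlay = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
```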
108 |
109 | _Setting the target label to 1 and adding the corresponding trigger to the images, the prediction results on four different models are as follows:_
110 | 
111 |
112 | The detailed attack effects of CleanSheet on CIFAR-10 are shown in the table below:
113 |
114 | 
115 |
116 | **More technical details and attack effects can be found in our paper.**
117 | ## License
118 |
119 | **NOTICE**: This software is available for use free of charge for academic research use only. Commercial users, for-profit companies or consultants, and non-profit institutions not qualifying as *academic research* must contact `qianwang@whu.edu.cn` for a separate license.
120 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: CleanSheet
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=5.1=1_gnu
11 | - blas=1.0=mkl
12 | - bottleneck=1.3.5=py39h7deecbd_0
13 | - ca-certificates=2023.08.22=h06a4308_0
14 | - intel-openmp=2023.1.0=hdb19cb5_46306
15 | - ld_impl_linux-64=2.38=h1181459_1
16 | - libffi=3.4.4=h6a678d5_0
17 | - libgcc-ng=11.2.0=h1234567_1
18 | - libgomp=11.2.0=h1234567_1
19 | - libstdcxx-ng=11.2.0=h1234567_1
20 | - mkl=2023.1.0=h213fc3f_46344
21 | - mkl-service=2.4.0=py39h5eee18b_1
22 | - mkl_fft=1.3.8=py39h5eee18b_0
23 | - mkl_random=1.2.4=py39hdb19cb5_0
24 | - ncurses=6.4=h6a678d5_0
25 | - numexpr=2.8.7=py39h85018f9_0
26 | - numpy=1.26.2=py39h5f9d8c6_0
27 | - numpy-base=1.26.2=py39hb5e798b_0
28 | - openssl=3.0.12=h7f8727e_0
29 | - pandas=2.1.4=py39h1128e8f_0
30 | - pip=23.3.1=py39h06a4308_0
31 | - python=3.9.18=h955ad1f_0
32 | - python-dateutil=2.8.2=pyhd3eb1b0_0
33 | - python-tzdata=2023.3=pyhd3eb1b0_0
34 | - pytz=2023.3.post1=py39h06a4308_0
35 | - readline=8.2=h5eee18b_0
36 | - setuptools=68.2.2=py39h06a4308_0
37 | - six=1.16.0=pyhd3eb1b0_1
38 | - sqlite=3.41.2=h5eee18b_0
39 | - tbb=2021.8.0=hdb19cb5_0
40 | - tk=8.6.12=h1ccaba5_0
41 | - tzdata=2023c=h04d1e81_0
42 | - wheel=0.41.2=py39h06a4308_0
43 | - xz=5.4.5=h5eee18b_0
44 | - zlib=1.2.13=h5eee18b_0
45 | - pip:
46 | - certifi==2023.11.17
47 | - charset-normalizer==3.3.2
48 | - contourpy==1.2.0
49 | - cycler==0.12.1
50 | - filelock==3.13.1
51 | - fonttools==4.47.0
52 | - fsspec==2023.12.2
53 | - grad-cam==1.5.0
54 | - idna==3.6
55 | - importlib-resources==6.1.1
56 | - jinja2==3.1.2
57 | - joblib==1.3.2
58 | - kiwisolver==1.4.5
59 | - kornia==0.7.1
60 | - markupsafe==2.1.3
61 | - matplotlib==3.8.2
62 | - mpmath==1.3.0
63 | - networkx==3.2.1
64 | - nvidia-cublas-cu12==12.1.3.1
65 | - nvidia-cuda-cupti-cu12==12.1.105
66 | - nvidia-cuda-nvrtc-cu12==12.1.105
67 | - nvidia-cuda-runtime-cu12==12.1.105
68 | - nvidia-cudnn-cu12==8.9.2.26
69 | - nvidia-cufft-cu12==11.0.2.54
70 | - nvidia-curand-cu12==10.3.2.106
71 | - nvidia-cusolver-cu12==11.4.5.107
72 | - nvidia-cusparse-cu12==12.1.0.106
73 | - nvidia-nccl-cu12==2.18.1
74 | - nvidia-nvjitlink-cu12==12.3.101
75 | - nvidia-nvtx-cu12==12.1.105
76 | - opencv-python==4.8.1.78
77 | - packaging==23.2
78 | - pilgram==1.2.1
79 | - pillow==10.1.0
80 | - pyparsing==3.1.1
81 | - requests==2.31.0
82 | - scikit-learn==1.3.2
83 | - scipy==1.11.4
84 | - sympy==1.12
85 | - threadpoolctl==3.2.0
86 | - torch==2.1.1
87 | - torchaudio==2.1.1
88 | - torchvision==0.16.1
89 | - tqdm==4.66.1
90 | - triton==2.1.0
91 | - ttach==0.0.3
92 | - typing-extensions==4.9.0
93 | - urllib3==2.1.0
94 | - zipp==3.17.0
--------------------------------------------------------------------------------
/epoch_99.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/epoch_99.pth
--------------------------------------------------------------------------------
/front/img/Example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/Example.png
--------------------------------------------------------------------------------
/front/img/horse-modified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse-modified.png
--------------------------------------------------------------------------------
/front/img/horse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse.png
--------------------------------------------------------------------------------
/front/img/horse_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse_predict.png
--------------------------------------------------------------------------------
/front/img/horse_predict_modified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse_predict_modified.png
--------------------------------------------------------------------------------
/front/img/more_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/more_image.png
--------------------------------------------------------------------------------
/front/img/p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/p1.png
--------------------------------------------------------------------------------
/front/img/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/result.png
--------------------------------------------------------------------------------
/front/img/trigger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/trigger.png
--------------------------------------------------------------------------------
/front/img/truck-modified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck-modified.png
--------------------------------------------------------------------------------
/front/img/truck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck.png
--------------------------------------------------------------------------------
/front/img/truck_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck_predict.png
--------------------------------------------------------------------------------
/front/img/truck_predict_modified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck_predict_modified.png
--------------------------------------------------------------------------------
/generate_kd.py:
--------------------------------------------------------------------------------
1 | from packet import *
2 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
3 |
4 | # config
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 | print(device)
7 | epochs = 100
8 | save_interval = 5
9 | temperature = 1.0
10 | alpha = 1.0
11 | epochs_per_validation = 5
12 | train_student_with_kd = True
13 | pr = 0.1
14 | best_model_index = 0
15 | beta = 1.0
16 |
17 | clean_train_data = torchvision.datasets.CIFAR10(root="dataset",
18 | train=True,
19 | download=True,
20 | transform=transforms.Compose(
21 | [transforms.RandomCrop(size=32, padding=4),
22 | transforms.RandomHorizontalFlip(),
23 | transforms.ToTensor(),
24 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
25 | std=(0.2023, 0.1994, 0.2010))]
26 | ))
27 | print(len(clean_train_data))
28 | clean_train_dataloader = DataLoader(clean_train_data, batch_size=128, num_workers=4, pin_memory=True, shuffle=True)
29 |
30 | clean_test_data = torchvision.datasets.CIFAR10(root="dataset",
31 | train=False,
32 | download=True,
33 | transform=transforms.Compose(
34 | [transforms.ToTensor(),
35 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
36 | std=(0.2023, 0.1994, 0.2010))]
37 | ))
38 | print(len(clean_test_data))
39 | clean_test_dataloader = DataLoader(clean_test_data, batch_size=128, num_workers=4, pin_memory=True)
40 |
41 | poison_train_data = PoisonDataset(clean_train_data,
42 | np.random.choice(len(clean_train_data), int(pr * len(clean_train_data)),
43 | replace=False), target=1)
44 | print(len(poison_train_data))
45 | poison_train_dataloader = DataLoader(poison_train_data, batch_size=128, num_workers=4, pin_memory=True, shuffle=True)
46 |
47 | poison_test_data = PoisonDataset(clean_test_data,
48 | np.random.choice(len(clean_test_data), len(clean_test_data), replace=False), target=1)
49 | print(len(poison_test_data))
50 | poison_test_dataloader = DataLoader(poison_test_data, batch_size=128, num_workers=4, pin_memory=True)
51 |
52 | # teacher model setting or student0 model setting.
53 | teacher = resnet34(num_classes=10)
54 | teacher.to(device)
55 | teacher_optimizer = optim.SGD(teacher.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
56 | teacher_scheduler = lr_scheduler.CosineAnnealingLR(teacher_optimizer, T_max=100)
57 | teacher.eval()
58 |
59 | teacher_lambda_t = 1e-1
60 | teacher_lambda_mask = 1e-4
61 | teacher_trainable_when_training_trigger = False
62 |
63 | # student1 model setting
64 | student1 = resnet18(num_classes=10)
65 | student1.to(device)
66 | student1_optimizer = optim.SGD(student1.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
67 | student1_scheduler = lr_scheduler.CosineAnnealingLR(student1_optimizer, T_max=100)
68 | student1.eval()
69 |
70 | # student2 model setting
71 | student2 = vgg16(num_classes=10)
72 | student2.to(device)
73 | student2_optimizer = optim.SGD(student2.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
74 | student2_scheduler = lr_scheduler.CosineAnnealingLR(student2_optimizer, T_max=100)
75 | student2.eval()
76 |
77 | # student3 model setting
78 | student3 = mobilenet_v2(num_classes=10)
79 | student3.to(device)
80 | student3_optimizer = optim.SGD(student3.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
81 | student3_scheduler = lr_scheduler.CosineAnnealingLR(student3_optimizer, T_max=100)
82 | student3.eval()
83 |
84 | student_lambda_t = 1e-2
85 | student_lambda_mask = 1e-4
86 | student_trainable_when_training_trigger = False
87 |
88 | # TRIGGER
89 | tri = Trigger(size=32).to(device)
90 | trigger_optimizer = optim.Adam(tri.parameters(), lr=1e-2)
91 |
92 | print("Start generate triggers")
93 | tri.train()
94 | models = [teacher, student1, student2, student3]
95 | optimizers = [teacher_optimizer, student1_optimizer, student2_optimizer, student3_optimizer]
96 | for epoch in range(epochs):
97 | masks = []
98 | triggers = []
99 | best_model = models[best_model_index]
100 |
101 | print('epoch: {}'.format(epoch))
102 | for index, model in enumerate(models):
103 | if index == best_model_index: # The first epoch has resnet34 as the teacher model
104 | print('train teacher network with clean data')
105 | model.train()
106 | model.to(device)
107 | for x, y in tqdm(clean_train_dataloader):
108 | x = x.to(device)
109 | y = y.to(device)
110 | logits = model(x)
111 | loss = F.cross_entropy(logits, y)
112 | optimizers[index].zero_grad()  # use the persistent optimizer so momentum and the scheduler take effect
113 | loss.backward()
114 | optimizers[index].step()
115 |
116 | print('train trigger for teacher network with poison data')
117 | model.eval()
118 | tri.train()
119 | model.to(device)
120 | tri.to(device)
121 | for x, y in tqdm(poison_train_dataloader):
122 | x = x.to(device)
123 | y = y.to(device)
124 | x = tri(x)
125 | logits = model(x)
126 | loss = teacher_lambda_t * F.cross_entropy(logits, y) + teacher_lambda_mask * torch.norm(tri.mask, p=2)
127 | optimizers[index].zero_grad()
128 | trigger_optimizer.zero_grad()
129 | loss.backward()
130 | trigger_optimizer.step()
131 | if teacher_trainable_when_training_trigger:
132 | optimizers[index].step()
133 |
134 | with torch.no_grad():
135 | tri.mask.clamp_(0, 1)
136 | tri.trigger.clamp_(-1*beta, 1*beta)
137 | masks.append(tri.mask.clone())
138 | triggers.append(tri.trigger.clone())
139 | else:
140 | # train other student network with knowledge distillation
141 | best_model.eval()
142 | model.train()
143 | best_model.to(device)
144 | model.to(device)
145 | print('train student network with clean data')
146 | for x, y in tqdm(clean_train_dataloader):
147 | x = x.to(device)
148 | y = y.to(device)
149 | student_logits = model(x)
150 | with torch.no_grad():
151 | teacher_logits = best_model(x)
152 | soft_loss = F.kl_div(F.log_softmax(student_logits / temperature,
153 | dim=1),
154 | F.softmax(teacher_logits / temperature,
155 | dim=1),
156 | reduction='batchmean')
157 | hard_loss = F.cross_entropy(student_logits, y)
158 | loss = alpha * soft_loss + (1 - alpha) * hard_loss
159 | optimizers[index].zero_grad()
160 | loss.backward()
161 | optimizers[index].step()
162 |
163 | print('train trigger for student network with poison data')
164 | model.eval()
165 | tri.train()
166 | model.to(device)
167 | tri.to(device)
168 | for x, y in tqdm(poison_train_dataloader):
169 | x = x.to(device)
170 | y = y.to(device)
171 | x = tri(x)
172 | logits = model(x)  # fix: use the current student, not always student1
173 | loss = student_lambda_t * F.cross_entropy(logits, y) + student_lambda_mask * torch.norm(tri.mask, p=2)
174 | optimizers[index].zero_grad()
175 | trigger_optimizer.zero_grad()
176 | loss.backward()
177 | trigger_optimizer.step()
178 |
179 | if student_trainable_when_training_trigger:
180 | optimizers[index].step()
181 |
182 | with torch.no_grad():
183 | tri.mask.clamp_(0, 1)
184 | tri.trigger.clamp_(-1*beta, 1*beta)
185 | masks.append(tri.mask.clone())
186 | triggers.append(tri.trigger.clone())
187 |
188 | average_mask = torch.mean(torch.stack(masks), dim=0)
189 | average_trigger = torch.mean(torch.stack(triggers), dim=0)
190 | tri.mask.data = average_mask
191 | tri.trigger.data = average_trigger
192 |
193 | teacher_scheduler.step()
194 | student1_scheduler.step()
195 | student2_scheduler.step()
196 | student3_scheduler.step()
197 |
198 | # calculate each model's accuracy to select the best (teacher) model
199 | accuracies = []
200 |
201 | for model in models:
202 |
203 | model.eval()
204 | model.to(device)
205 | with torch.no_grad():
206 | total = 0
207 | correct = 0
208 | for x, y in tqdm(clean_test_dataloader):
209 | x = x.to(device)
210 | y = y.to(device)
211 | logits = model(x)
212 | _, predict_label = logits.max(1)
213 | total += y.size(0)
214 | correct += predict_label.eq(y).sum().item()
215 | accuracy = correct / total
216 | accuracies.append(accuracy)
217 |
218 | best_model_index = np.argmax(accuracies)
219 |
220 | print("--------Validation accuracy of 4 models(clean_test_dataloader)---------")
221 | print(accuracies)
222 | print("--------Index selected as the teacher model---------")
223 | print(best_model_index)
224 |
225 | ASR = []
226 |
227 | for model in models:
228 |
229 | model.eval()
230 | model.to(device)
231 | with torch.no_grad():
232 | total = 0
233 | correct = 0
234 | for x, y in tqdm(poison_test_dataloader):
235 | x = x.to(device)
236 | x = tri(x)
237 | y = y.to(device)
238 | logits = model(x)
239 | _, predict_label = logits.max(1)
240 | total += y.size(0)
241 | correct += predict_label.eq(y).sum().item()
242 | asr = correct / total
243 | ASR.append(asr)
244 |
245 | print("--------The attack success rate of 4 models(poison_test_dataloader)---------")
246 | print(ASR)
247 |
248 | # Save the model on a regular basis
249 | if epoch == 0 or (epoch + 1) % save_interval == 0:
250 | trigger_p = 'trigger/epoch_{}.pth'.format(epoch)
251 | teacher_p = 'models/weight/teacher/epoch_{}.pth'.format(epoch)
252 | student1_p = 'models/weight/student1/epoch_{}.pth'.format(epoch)
253 | student2_p = 'models/weight/student2/epoch_{}.pth'.format(epoch)
254 | student3_p = 'models/weight/student3/epoch_{}.pth'.format(epoch)
255 | torch.save(tri.state_dict(), trigger_p)
256 | torch.save(teacher.state_dict(), teacher_p)
257 | torch.save(student1.state_dict(), student1_p)
258 | torch.save(student2.state_dict(), student2_p)
259 | torch.save(student3.state_dict(), student3_p)
260 |
--------------------------------------------------------------------------------
/models/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | """MobileNetV2 in PyTorch.
2 | See the paper "Inverted Residuals and Linear Bottlenecks:
3 | Mobile Networks for Classification, Detection and Segmentation"
4 | for more details.
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class Block(nn.Module):
12 | """expand + depthwise + pointwise"""
13 |
14 | def __init__(self, in_planes, out_planes, expansion, stride):
15 | super(Block, self).__init__()
16 | self.stride = stride
17 |
18 | planes = expansion * in_planes
19 | self.conv1 = nn.Conv2d(in_planes,
20 | planes,
21 | kernel_size=1,
22 | stride=1,
23 | padding=0,
24 | bias=False)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.conv2 = nn.Conv2d(
27 | planes,
28 | planes,
29 | kernel_size=3,
30 | stride=stride,
31 | padding=1,
32 | groups=planes,
33 | bias=False,
34 | )
35 | self.bn2 = nn.BatchNorm2d(planes)
36 | self.conv3 = nn.Conv2d(planes,
37 | out_planes,
38 | kernel_size=1,
39 | stride=1,
40 | padding=0,
41 | bias=False)
42 | self.bn3 = nn.BatchNorm2d(out_planes)
43 |
44 | self.shortcut = nn.Sequential()
45 | if stride == 1 and in_planes != out_planes:
46 | self.shortcut = nn.Sequential(
47 | nn.Conv2d(in_planes,
48 | out_planes,
49 | kernel_size=1,
50 | stride=1,
51 | padding=0,
52 | bias=False),
53 | nn.BatchNorm2d(out_planes),
54 | )
55 |
56 | def forward(self, x):
57 | out = F.relu(self.bn1(self.conv1(x)))
58 | out = F.relu(self.bn2(self.conv2(out)))
59 | out = self.bn3(self.conv3(out))
60 | out = out + self.shortcut(x) if self.stride == 1 else out
61 | return out
62 |
63 |
64 | class MobileNetV2(nn.Module):
65 | # (expansion, out_planes, num_blocks, stride)
66 | cfg = [
67 | (1, 16, 1, 1),
68 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
69 | (6, 32, 3, 2),
70 | (6, 64, 4, 2),
71 | (6, 96, 3, 1),
72 | (6, 160, 3, 2),
73 | (6, 320, 1, 1),
74 | ]
75 |
76 | def __init__(self, num_classes: int = 10) -> None:
77 | super(MobileNetV2, self).__init__()
78 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
79 | self.conv1 = nn.Conv2d(3,
80 | 32,
81 | kernel_size=3,
82 | stride=1,
83 | padding=1,
84 | bias=False)
85 | self.bn1 = nn.BatchNorm2d(32)
86 | self.layers = self._make_layers(in_planes=32)
87 | self.conv2 = nn.Conv2d(320,
88 | 1280,
89 | kernel_size=1,
90 | stride=1,
91 | padding=0,
92 | bias=False)
93 | self.bn2 = nn.BatchNorm2d(1280)
94 | self.linear = nn.Linear(1280, num_classes)
95 |
96 | def _make_layers(self, in_planes):
97 | layers = []
98 | for expansion, out_planes, num_blocks, stride in self.cfg:
99 | strides = [stride] + [1] * (num_blocks - 1)
100 | for stride in strides:
101 | layers.append(Block(in_planes, out_planes, expansion, stride))
102 | in_planes = out_planes
103 | return nn.Sequential(*layers)
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | out = F.relu(self.bn1(self.conv1(x)))
107 | out = self.layers(out)
108 | out = F.relu(self.bn2(self.conv2(out)))
109 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
110 | out = F.avg_pool2d(out, 4)
111 | out = out.view(out.size(0), -1)
112 | out = self.linear(out)
113 | return out
114 |
115 |
116 | def mobilenet_v2(num_classes: int) -> MobileNetV2:
117 | return MobileNetV2(num_classes=num_classes)
118 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNet in PyTorch.
3 |
4 | Reference:
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
6 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
7 | """
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from typing import List
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1):
16 | super(BasicBlock, self).__init__()
17 | self.conv1 = nn.Conv2d(in_planes,
18 | planes,
19 | kernel_size=3,
20 | stride=stride,
21 | padding=1,
22 | bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes,
25 | planes,
26 | kernel_size=3,
27 | stride=1,
28 | padding=1,
29 | bias=False)
30 | self.bn2 = nn.BatchNorm2d(planes)
31 |
32 | self.shortcut = nn.Sequential()
33 | if stride != 1 or in_planes != self.expansion * planes:
34 | self.shortcut = nn.Sequential(
35 | nn.Conv2d(
36 | in_planes,
37 | self.expansion * planes,
38 | kernel_size=1,
39 | stride=stride,
40 | bias=False,
41 | ),
42 | nn.BatchNorm2d(self.expansion * planes),
43 | )
44 |
45 | def forward(self, x):
46 | out = F.relu(self.bn1(self.conv1(x)))
47 | out = self.bn2(self.conv2(out))
48 | out += self.shortcut(x)
49 | out = F.relu(out)
50 | return out
51 |
52 |
53 | class Bottleneck(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, in_planes, planes, stride=1):
57 | super(Bottleneck, self).__init__()
58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(planes)
60 | self.conv2 = nn.Conv2d(planes,
61 | planes,
62 | kernel_size=3,
63 | stride=stride,
64 | padding=1,
65 | bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes,
68 | self.expansion * planes,
69 | kernel_size=1,
70 | bias=False)
71 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
72 |
73 | self.shortcut = nn.Sequential()
74 | if stride != 1 or in_planes != self.expansion * planes:
75 | self.shortcut = nn.Sequential(
76 | nn.Conv2d(
77 | in_planes,
78 | self.expansion * planes,
79 | kernel_size=1,
80 | stride=stride,
81 | bias=False,
82 | ),
83 | nn.BatchNorm2d(self.expansion * planes),
84 | )
85 |
86 | def forward(self, x):
87 | out = F.relu(self.bn1(self.conv1(x)))
88 | out = F.relu(self.bn2(self.conv2(out)))
89 | out = self.bn3(self.conv3(out))
90 | out += self.shortcut(x)
91 | out = F.relu(out)
92 | return out
93 |
94 |
95 | class ResNet(nn.Module):
96 |
97 | def __init__(self, block, num_blocks, num_classes=10):
98 | super(ResNet, self).__init__()
99 | self.in_planes = 64
100 |
101 | self.conv1 = nn.Conv2d(3,
102 | 64,
103 | kernel_size=3,
104 | stride=1,
105 | padding=1,
106 | bias=False)
107 | self.bn1 = nn.BatchNorm2d(64)
108 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
109 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
110 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
111 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
112 | self.linear = nn.Linear(512 * block.expansion, num_classes)
113 |
114 | def _make_layer(self, block, planes, num_blocks, stride):
115 | strides = [stride] + [1] * (num_blocks - 1)
116 | layers = []
117 | for stride in strides:
118 | layers.append(block(self.in_planes, planes, stride))
119 | self.in_planes = planes * block.expansion
120 | return nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | out = F.relu(self.bn1(self.conv1(x)))
124 | out = self.layer1(out)
125 | out = self.layer2(out)
126 | out = self.layer3(out)
127 | out = self.layer4(out)
128 | out = F.avg_pool2d(out, 4)
129 | out = out.view(out.size(0), -1)
130 | out = self.linear(out)
131 | return out
132 |
133 |
134 | def _resnet(block: nn.Module, num_blocks: List[int],
135 | num_classes: int) -> ResNet:
136 | return ResNet(block=block, num_blocks=num_blocks, num_classes=num_classes)
137 |
138 |
139 | def resnet18(num_classes: int) -> ResNet:
140 | return _resnet(block=BasicBlock,
141 | num_blocks=[2, 2, 2, 2],
142 | num_classes=num_classes)
143 |
144 |
145 | def resnet34(num_classes: int) -> ResNet:
146 | return _resnet(block=BasicBlock,
147 | num_blocks=[3, 4, 6, 3],
148 | num_classes=num_classes)
149 |
150 |
151 | def resnet50(num_classes: int) -> ResNet:
152 | return _resnet(block=Bottleneck,
153 | num_blocks=[3, 4, 6, 3],
154 | num_classes=num_classes)
155 |
156 |
157 | def resnet101(num_classes: int) -> ResNet:
158 | return _resnet(block=Bottleneck,
159 | num_blocks=[3, 4, 23, 3],
160 | num_classes=num_classes)
161 |
162 |
163 | def resnet152(num_classes: int) -> ResNet:
164 | return _resnet(block=Bottleneck,
165 | num_blocks=[3, 8, 36, 3],
166 | num_classes=num_classes)
167 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """VGG11/13/16/19 in Pytorch."""
2 | import torch.nn as nn
3 |
4 | cfg = {
5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
6 | 'VGG13':
7 | [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8 | 'VGG16': [
9 | 64,
10 | 64,
11 | 'M',
12 | 128,
13 | 128,
14 | 'M',
15 | 256,
16 | 256,
17 | 256,
18 | 'M',
19 | 512,
20 | 512,
21 | 512,
22 | 'M',
23 | 512,
24 | 512,
25 | 512,
26 | 'M',
27 | ],
28 | 'VGG19': [
29 | 64,
30 | 64,
31 | 'M',
32 | 128,
33 | 128,
34 | 'M',
35 | 256,
36 | 256,
37 | 256,
38 | 256,
39 | 'M',
40 | 512,
41 | 512,
42 | 512,
43 | 512,
44 | 'M',
45 | 512,
46 | 512,
47 | 512,
48 | 512,
49 | 'M',
50 | ],
51 | }
52 |
53 |
54 | class VGG(nn.Module):
55 |
56 | def __init__(self, vgg_name, num_classes=10):
57 | super(VGG, self).__init__()
58 | self.features = self._make_layers(cfg[vgg_name.upper()])
59 | self.classifier = nn.Linear(512, num_classes)
60 |
61 | def forward(self, x):
62 | out = self.features(x)
63 | out = out.view(out.size(0), -1)
64 | out = self.classifier(out)
65 | return out
66 |
67 | def _make_layers(self, cfg):
68 | layers = []
69 | in_channels = 3
70 | for x in cfg:
71 | if x == 'M':
72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
73 | else:
74 | layers += [
75 | nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
76 | nn.BatchNorm2d(x),
77 | nn.ReLU(inplace=True),
78 | ]
79 | in_channels = x
80 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
81 | return nn.Sequential(*layers)
82 |
83 |
84 | def vgg11(num_classes: int) -> VGG:
85 | return VGG('vgg11', num_classes=num_classes)
86 |
87 |
88 | def vgg13(num_classes: int) -> VGG:
89 | return VGG('vgg13', num_classes=num_classes)
90 |
91 |
92 | def vgg16(num_classes: int) -> VGG:
93 | return VGG('vgg16', num_classes=num_classes)
94 |
95 |
96 | def vgg19(num_classes: int) -> VGG:
97 | return VGG('vgg19', num_classes=num_classes)
98 |
--------------------------------------------------------------------------------
/packet.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import os
6 | from models.resnet import resnet34, resnet18
7 | from models.vgg import vgg16
8 | from models.mobilenet_v2 import mobilenet_v2
9 | import torch.nn as nn
10 | from utils import Trigger
11 | import torchvision
12 | from torchvision import transforms
13 | from poison_dataset import PoisonDataset
16 | from torch.nn import functional as F
17 | from pytorch_grad_cam import (
18 | GradCAM,
19 | HiResCAM,
20 | ScoreCAM,
21 | GradCAMPlusPlus,
22 | AblationCAM,
23 | XGradCAM,
24 | EigenCAM,
25 | FullGrad,
26 | )
27 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
28 | from pytorch_grad_cam.utils.image import show_cam_on_image
29 | from torch.utils.data import DataLoader
30 | import torch.optim as optim
31 | from tqdm import tqdm
32 | from torch.optim import lr_scheduler
33 | import torch
43 |
--------------------------------------------------------------------------------
/poison_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 |
4 |
5 | class PoisonDataset(Dataset):
6 | def __init__(self, dataset, indices, target):
7 | self.dataset = dataset
8 | self.indices = [int(i) for i in indices]
9 | self.target = target
10 |
11 | def __len__(self):
12 | return len(self.indices)
13 |
14 | def __getitem__(self, item):
15 | x, y = self.dataset[self.indices[item]]
16 | # ignore the original label; every selected sample is relabeled to the target class
22 | return x, self.target
23 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 |
4 | import pytest
5 | from pytest import TempPathFactory
6 |
7 |
8 | @pytest.fixture(scope='function')
9 | def tmp_work_dirs(tmp_path_factory: TempPathFactory) -> Path:
10 | relative_work_dirs = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
11 | work_dirs = tmp_path_factory.mktemp(relative_work_dirs)
12 |
13 | return work_dirs
14 |
--------------------------------------------------------------------------------
/tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py:
--------------------------------------------------------------------------------
1 | trainer = dict(type='AntiKDTrainer',
2 | teacher=dict(network=dict(arch='cifar',
3 | type='resnet18',
4 | num_classes=10),
5 | optimizer=dict(type='SGD',
6 | lr=0.05,
7 | momentum=0.9,
8 | weight_decay=5e-4),
9 | scheduler=dict(type='CosineAnnealingLR',
10 | T_max=100),
11 | lambda_t=0.1,
12 | lambda_mask=1e-4,
13 | trainable_when_training_trigger=False),
14 | students=dict(
15 | resnet18=dict(network=dict(arch='cifar',
16 | type='resnet18',
17 | num_classes=10),
18 | optimizer=dict(type='SGD',
19 | lr=0.05,
20 | momentum=0.9,
21 | weight_decay=5e-4),
22 | scheduler=dict(type='CosineAnnealingLR',
23 | T_max=100),
24 | lambda_t=1e-2,
25 | lambda_mask=1e-4,
26 | trainable_when_training_trigger=False),
27 | vgg16=dict(network=dict(arch='cifar',
28 | type='vgg16',
29 | num_classes=10),
30 | optimizer=dict(type='SGD',
31 | lr=0.05,
32 | momentum=0.9,
33 | weight_decay=5e-4),
34 | scheduler=dict(type='CosineAnnealingLR',
35 | T_max=100),
36 | lambda_t=1e-2,
37 | lambda_mask=1e-4,
38 | trainable_when_training_trigger=False),
39 | mobilenet_v2=dict(network=dict(arch='cifar',
40 | type='mobilenet_v2',
41 | num_classes=10),
42 | optimizer=dict(type='SGD',
43 | lr=0.05,
44 | momentum=0.9,
45 | weight_decay=5e-4),
46 | scheduler=dict(type='CosineAnnealingLR',
47 | T_max=100),
48 | lambda_t=1e-2,
49 | lambda_mask=1e-4,
50 | trainable_when_training_trigger=False),
51 | ),
52 | trigger=dict(network=dict(arch='trigger',
53 | type='trigger',
54 | size=32),
55 | optimizer=dict(type='Adam', lr=1e-2),
56 | mask_clip_range=(0., 1.),
57 | trigger_clip_range=(-1., 1.),
58 | mask_penalty_norm=2),
59 | clean_train_dataloader=dict(dataset=dict(
60 | type='CIFAR10',
61 | root='data',
62 | train=True,
63 | download=True,
64 | transform=[
65 | dict(type='RandomCrop', size=32, padding=4),
66 | dict(type='RandomHorizontalFlip'),
67 | dict(type='ToTensor'),
68 | dict(type='Normalize',
69 | mean=(0.4914, 0.4822, 0.4465),
70 | std=(0.2023, 0.1994, 0.2010))
71 | ]),
72 | batch_size=32,
73 | num_workers=4,
74 | pin_memory=True),
75 | clean_test_dataloader=dict(dataset=dict(
76 | type='CIFAR10',
77 | root='data',
78 | train=False,
79 | download=True,
80 | transform=[
81 | dict(type='ToTensor'),
82 | dict(type='Normalize',
83 | mean=(0.4914, 0.4822, 0.4465),
84 | std=(0.2023, 0.1994, 0.2010))
85 | ]),
86 | batch_size=32,
87 | num_workers=4,
88 | pin_memory=True),
89 | poison_train_dataloader=dict(dataset=dict(
90 | type='RatioPoisonLabelCIFAR10',
91 | ratio=0.1,
92 | poison_label=1,
93 | root='data',
94 | train=True,
95 | download=True,
96 | transform=[
97 | dict(type='RandomCrop', size=32, padding=4),
98 | dict(type='RandomHorizontalFlip'),
99 | dict(type='ToTensor'),
100 | dict(type='Normalize',
101 | mean=(0.4914, 0.4822, 0.4465),
102 | std=(0.2023, 0.1994, 0.2010))
103 | ]),
104 | batch_size=32,
105 | num_workers=4,
106 | pin_memory=True),
107 | poison_test_dataloader=dict(dataset=dict(
108 | type='RatioPoisonLabelCIFAR10',
109 | ratio=1,
110 | poison_label=1,
111 | root='data',
112 | train=False,
113 | download=True,
114 | transform=[
115 | dict(type='ToTensor'),
116 | dict(type='Normalize',
117 | mean=(0.4914, 0.4822, 0.4465),
118 | std=(0.2023, 0.1994, 0.2010))
119 | ]),
120 | batch_size=32,
121 | num_workers=4,
122 | pin_memory=True),
123 | epochs=100,
124 | save_interval=5,
125 | temperature=1.0,
126 | alpha=1.0,
127 | device='cuda')
128 |
--------------------------------------------------------------------------------
/tests/data/config/cifar10_resnet18.py:
--------------------------------------------------------------------------------
1 | dataset = 'CIFAR10'
2 | dataset_mean = (0.4914, 0.4822, 0.4465)
3 | dataset_std = (0.2023, 0.1994, 0.2010)
4 |
5 | network = dict(type='resnet18', arch='cifar')
6 |
7 | test_dataloader = dict(batch_size=64,
8 | num_workers=4,
9 | persistent_workers=True,
10 | shuffle=False,
11 | dataset=dict(type=dataset,
12 | root='data',
13 | train=False,
14 | download=True,
15 | transform=[
16 | dict(type='ToTensor'),
17 | dict(type='Normalize',
18 | mean=dataset_mean,
19 | std=dataset_std)
20 | ]))
21 |
22 | train_dataloader = dict(batch_size=64,
23 | num_workers=4,
24 | persistent_workers=True,
25 | shuffle=True,
26 | dataset=dict(type=dataset,
27 | root='data',
28 | train=True,
29 | download=True,
30 | transform=[
31 | dict(type='RandomCrop',
32 | size=32,
33 | padding=4),
34 | dict(type='RandomHorizontalFlip',
35 | p=0.5),
36 | dict(type='ToTensor'),
37 | dict(type='Normalize',
38 | mean=dataset_mean,
39 | std=dataset_std)
40 | ]))
41 |
--------------------------------------------------------------------------------
/tests/data/config/error.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/data/config/error.txt
--------------------------------------------------------------------------------
/tests/data/config/simple_config.py:
--------------------------------------------------------------------------------
1 | item1 = [1, 2]
2 | item2 = {'a': 0}
3 | item3 = True
4 | item4 = 'test'
--------------------------------------------------------------------------------
/tests/test_config/test_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from anti_kd_backdoor.config import Config, DictAction
7 |
8 |
9 | class TestConfig:
10 | """Modify from mmengine"""
11 |
12 | data_path: Path = Path(__file__).parent.parent / 'data'
13 | config_path: Path = data_path / 'config'
14 |
15 | def test_fromfile(self) -> None:
16 | config_path = self.config_path / 'cifar10_resnet18.py'
17 | config = Config.fromfile(config_path)
18 |
19 | assert config.dataset == 'CIFAR10'
20 | assert config.network.type == 'resnet18'
21 | assert config.test_dataloader.batch_size == 64
22 | assert len(config.train_dataloader.dataset.transform) == 4
23 | assert len(config.test_dataloader.dataset.transform) == 2
24 |
25 | with pytest.raises(ValueError):
26 | _ = Config.fromfile(self.config_path / 'error.txt')
27 |
28 | def test_magic_method(self) -> None:
29 | cfg_dict = dict(item1=[1, 2],
30 | item2=dict(a=0),
31 | item3=True,
32 | item4='test')
33 | cfg_file = self.config_path / 'simple_config.py'
34 | cfg = Config.fromfile(cfg_file)
35 | # len(cfg)
36 | assert len(cfg) == 4
37 | # cfg.keys()
38 | assert set(cfg.keys()) == set(cfg_dict.keys())
39 | assert set(cfg._cfg_dict.keys()) == set(cfg_dict.keys())
40 | # cfg.values()
41 | for value in cfg.values():
42 | assert value in cfg_dict.values()
43 | # cfg.items()
44 | for name, value in cfg.items():
45 | assert name in cfg_dict
46 | assert value in cfg_dict.values()
47 | # cfg.field
48 | assert cfg.item1 == cfg_dict['item1']
49 | assert cfg.item2 == cfg_dict['item2']
50 | assert cfg.item2.a == 0
51 | assert cfg.item3 == cfg_dict['item3']
52 | assert cfg.item4 == cfg_dict['item4']
53 | # accessing keys that do not exist will cause error
54 | with pytest.raises(AttributeError):
55 | cfg.not_exist
56 | # field in cfg, cfg[field], cfg.get()
57 | for name in ['item1', 'item2', 'item3', 'item4']:
58 | assert name in cfg
59 | assert cfg[name] == cfg_dict[name]
60 | assert cfg.get(name) == cfg_dict[name]
61 | assert cfg.get('not_exist') is None
62 | assert cfg.get('not_exist', 0) == 0
63 | # accessing keys that do not exist will cause error
64 | with pytest.raises(KeyError):
65 | cfg['not_exist']
66 | assert 'item1' in cfg
67 | assert 'not_exist' not in cfg
68 | # cfg.update()
69 | cfg.update(dict(item1=0))
70 | assert cfg.item1 == 0
71 | cfg.update(dict(item2=dict(a=1)))
72 | assert cfg.item2.a == 1
73 | # test __setattr__
74 | cfg = Config()
75 | cfg.item1 = [1, 2]
76 | cfg.item2 = {'a': 0}
77 | cfg['item5'] = {'a': {'b': None}}
78 | assert cfg._cfg_dict['item1'] == [1, 2]
79 | assert cfg.item1 == [1, 2]
80 | assert cfg._cfg_dict['item2'] == {'a': 0}
81 | assert cfg.item2.a == 0
82 | assert cfg._cfg_dict['item5'] == {'a': {'b': None}}
83 | assert cfg.item5.a.b is None
84 |
85 | def test_dict_action(self):
86 | parser = argparse.ArgumentParser(description='Train a detector')
87 | parser.add_argument('--options',
88 | nargs='+',
89 | action=DictAction,
90 | help='custom options')
91 | # Nested brackets
92 | args = parser.parse_args(
93 | ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]'])
94 | out_dict = {
95 | 'item2.a': ['a', 'b'],
96 | 'item2.b': [('a', 'b'), [1, 2], False]
97 | }
98 | assert args.options == out_dict
99 | # Single Nested brackets
100 | args = parser.parse_args(['--options', 'item2.a=[[1]]'])
101 | out_dict = {'item2.a': [[1]]}
102 | assert args.options == out_dict
103 | # Imbalance bracket will cause error
104 | with pytest.raises(AssertionError):
105 | parser.parse_args(['--options', 'item2.a=[(a,b), [1,2], false'])
106 | # Normal values
107 | args = parser.parse_args([
108 | '--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false'
109 | ])
110 | out_dict = {
111 | 'item2.a': 1,
112 | 'item2.b': 0.1,
113 | 'item2.c': 'x',
114 | 'item3': False
115 | }
116 | assert args.options == out_dict
117 |
--------------------------------------------------------------------------------
/tests/test_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_data/__init__.py
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_data/test_dataset/__init__.py
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/test_cifar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from .utils import FakeDataset, build_fake_dataset
5 |
6 | CIFAR_TESTSET_NUM = 1000
7 |
8 |
9 | def build_cifar_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
10 | if dataset_type.endswith('CIFAR100'):
11 | y_range = (0, 99)
12 | dataset_type = dataset_type.replace('CIFAR100', 'FakeDataset')
13 | else:
14 | y_range = (0, 9)
15 | dataset_type = dataset_type.replace('CIFAR10', 'FakeDataset')
16 |
17 | dataset_cfg = dict(type=dataset_type,
18 | x_shape=(3, 32, 32),
19 | y_range=y_range,
20 | nums=CIFAR_TESTSET_NUM,
21 | **kwargs)
22 |
23 | return build_fake_dataset(dataset_cfg)
24 |
25 |
26 | @pytest.mark.parametrize('dataset_type', ['CIFAR10', 'CIFAR100'])
27 | def test_xy(dataset_type: str) -> None:
28 | cifar = build_cifar_fake_dataset(dataset_type)
29 |
30 | xy = cifar.get_xy()
31 | x, y = xy
32 | assert len(x) == len(y)
33 | assert isinstance(y[0], int)
34 |
35 | old_x = x.copy()
36 | old_y = y.copy()
37 |
38 | cifar.set_xy(xy)
39 | assert all([np.array_equal(nx, ox) for nx, ox in zip(cifar.data, old_x)])
40 | assert cifar.targets == old_y
41 |
42 | x = x[:cifar.num_classes]
43 | y = y[:cifar.num_classes]
44 | cifar.set_xy((x, y))
45 | assert all([np.array_equal(nx, ox) for nx, ox in zip(cifar.data, x)])
46 | assert cifar.targets == y
47 | assert cifar.num_classes == len(set(y))
48 | assert len(cifar.data.shape) == 4
49 |
50 |
51 | @pytest.mark.parametrize(['start_idx', 'end_idx', 'dataset_type'],
52 | [(0, 9, 'IndexCIFAR10'), (-10, 8, 'IndexCIFAR10'),
53 | (2, 12, 'IndexCIFAR10'), (4, 4, 'IndexCIFAR10'),
54 | (4, 3, 'IndexCIFAR10'), (0, 99, 'IndexCIFAR100'),
55 | (-10, 8, 'IndexCIFAR100'),
56 | (40, 50, 'IndexCIFAR100')])
57 | def test_index(start_idx: int, end_idx: int, dataset_type: str) -> None:
58 | kwargs = dict(start_idx=start_idx, end_idx=end_idx)
59 |
60 | if start_idx > end_idx:
61 | with pytest.raises(ValueError):
62 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
63 | return
64 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
65 | assert cifar.start_idx == start_idx
66 | assert cifar.end_idx == end_idx
67 |
68 | for y in cifar.targets:
69 | assert start_idx <= y <= end_idx
70 |
71 | assert cifar.num_classes == min(
72 | cifar.end_idx, cifar.raw_num_classes - 1) - max(cifar.start_idx, 0) + 1
73 | assert len(cifar.data.shape) == 4
74 |
75 |
76 | @pytest.mark.parametrize(['ratio', 'dataset_type'], [(-1, 'RatioCIFAR10'),
77 | (0, 'RatioCIFAR10'),
78 | (0.1, 'RatioCIFAR10'),
79 | (0.5, 'RatioCIFAR10'),
80 | (1, 'RatioCIFAR10'),
81 | (2, 'RatioCIFAR10'),
82 | (0.4, 'RatioCIFAR100')])
83 | def test_ratio(ratio: float, dataset_type: str) -> None:
84 | kwargs = dict(ratio=ratio)
85 |
86 | if ratio <= 0 or ratio > 1:
87 | with pytest.raises(ValueError):
88 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
89 | return
90 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
91 |
92 | assert cifar.num_classes == cifar.raw_num_classes
93 | assert len(cifar.targets) == \
94 | int(CIFAR_TESTSET_NUM / cifar.num_classes * ratio) * cifar.num_classes
95 | assert len(cifar.data.shape) == 4
96 |
97 |
98 | @pytest.mark.parametrize('range_ratio', [(-1, 0.2), (0, 2), (0.1, 0.1),
99 | (0.5, 0.2), (0.1, 0.5), (0, 1)])
100 | @pytest.mark.parametrize('dataset_type',
101 | ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
102 | def test_range_ratio(range_ratio: tuple[float, float],
103 | dataset_type: str) -> None:
104 | kwargs = dict(range_ratio=range_ratio)
105 |
106 | start_ratio = range_ratio[0]
107 | end_ratio = range_ratio[1]
108 | if not (0 <= start_ratio < end_ratio <= 1):
109 | with pytest.raises(ValueError):
110 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
111 | return
112 |
113 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
114 | assert cifar.num_classes == cifar.raw_num_classes
115 | assert len(cifar.targets) == \
116 | round(CIFAR_TESTSET_NUM * (end_ratio - start_ratio))
117 | assert len(cifar.data.shape) == 4
118 |
119 |
120 | @pytest.mark.parametrize(['range_ratio1', 'range_ratio2'],
121 | [((0, 0.5), (0.5, 1)), ((0, 0.6), (0.4, 1)),
122 | ((0, 0.7), (0.3, 1)), ((0, 0.5), (0, 1))])
123 | @pytest.mark.parametrize('dataset_type',
124 | ['RangeRatioCIFAR10', 'RangeRatioCIFAR100'])
125 | def test_range_ratio_intersection(range_ratio1: tuple[float, float],
126 | range_ratio2: tuple[float, float],
127 | dataset_type: str) -> None:
128 |
129 | cifar1 = build_cifar_fake_dataset(dataset_type=dataset_type,
130 | range_ratio=range_ratio1,
131 | cache_xy=True)
132 | cifar2 = build_cifar_fake_dataset(dataset_type=dataset_type,
133 | range_ratio=range_ratio2,
134 | cache_xy=True)
135 |
136 | cat_x = np.concatenate([cifar1.data, cifar2.data], axis=0)
137 | unique_x = np.unique(cat_x, axis=0)
138 | intersection_number = cat_x.shape[0] - unique_x.shape[0]
139 |
140 | intersection_ratio = max(0, range_ratio1[1] - range_ratio2[0])
141 | assert round(intersection_ratio * CIFAR_TESTSET_NUM) == intersection_number
142 |
143 |
144 | @pytest.mark.parametrize('dataset_type',
145 | ['IndexRatioCIFAR10', 'IndexRatioCIFAR100'])
146 | @pytest.mark.parametrize(['start_idx', 'end_idx', 'ratio'], [(4, 3, 0.5),
147 | (3, 4, 0),
148 | (3, 4, 2),
149 | (1, 4, 0.1)])
150 | def test_index_ratio(start_idx: int, end_idx: int, ratio: float,
151 | dataset_type: str) -> None:
152 | kwargs = dict(start_idx=start_idx, end_idx=end_idx, ratio=ratio)
153 |
154 | if ratio <= 0 or ratio > 1 or start_idx > end_idx:
155 | with pytest.raises(ValueError):
156 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
157 | return
158 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
159 | assert cifar.start_idx == start_idx
160 | assert cifar.end_idx == end_idx
161 |
162 | for y in cifar.targets:
163 | assert start_idx <= y <= end_idx
164 | assert len(cifar.targets) == \
165 | cifar.num_classes / cifar.raw_num_classes * CIFAR_TESTSET_NUM * ratio
166 | assert len(cifar.data.shape) == 4
167 |
168 |
169 | @pytest.mark.parametrize('dataset_type',
170 | ['PoisonLabelCIFAR10', 'PoisonLabelCIFAR100'])
171 | @pytest.mark.parametrize('poison_label', [-1, 5, 101])
172 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
173 | kwargs = dict(poison_label=poison_label)
174 |
175 | if poison_label < 0 or poison_label >= 100:
176 | with pytest.raises(ValueError):
177 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
178 | return
179 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
180 | assert cifar.poison_label == poison_label
181 |
182 | assert cifar.num_classes == 1
183 | assert all(map(lambda x: x == poison_label, cifar.targets))
184 | assert len(cifar.data.shape) == 4
185 |
186 |
187 | @pytest.mark.parametrize(
188 | 'dataset_type', ['RatioPoisonLabelCIFAR10', 'RatioPoisonLabelCIFAR100'])
189 | @pytest.mark.parametrize('poison_label', [-1, 5, 101])
190 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
191 | def test_ratio_poison_label(ratio: float, poison_label: int,
192 | dataset_type: str) -> None:
193 | kwargs = dict(ratio=ratio, poison_label=poison_label)
194 |
195 | if (poison_label < 0 or poison_label >= 100) or \
196 | (ratio <= 0 or ratio > 1):
197 | with pytest.raises(ValueError):
198 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
199 | return
200 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
201 | assert cifar.poison_label == poison_label
202 |
203 | assert len(cifar) == round(CIFAR_TESTSET_NUM * ratio)
204 | assert cifar.num_classes == 1
205 | assert all(map(lambda x: x == poison_label, cifar.targets))
206 | assert len(cifar.data.shape) == 4
207 |
208 |
209 | @pytest.mark.parametrize(
210 | 'dataset_type',
211 | ['RangeRatioPoisonLabelCIFAR10', 'RangeRatioPoisonLabelCIFAR100'])
212 | @pytest.mark.parametrize('poison_label', [-1, 1, 101])
213 | @pytest.mark.parametrize('range_ratio', [(0, 0.2), (0.2, 0.5), (0.5, 1),
214 | (0.5, 0.2)])
215 | def test_range_ratio_poison_label(range_ratio: tuple[float,
216 | float], poison_label: int,
217 | dataset_type: str) -> None:
218 | kwargs = dict(range_ratio=range_ratio, poison_label=poison_label)
219 |
220 | if poison_label < 0 or poison_label >= 100 or \
221 | not (0 <= range_ratio[0] < range_ratio[1] <= 1):
222 | with pytest.raises(ValueError):
223 | _ = build_cifar_fake_dataset(dataset_type, **kwargs)
224 | return
225 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs)
226 | assert cifar.poison_label == poison_label
227 |
228 | assert len(cifar) == round(CIFAR_TESTSET_NUM *
229 | (range_ratio[1] - range_ratio[0]))
230 | assert cifar.num_classes == 1
231 | assert all(map(lambda x: x == poison_label, cifar.targets))
232 | assert len(cifar.data.shape) == 4
233 |
--------------------------------------------------------------------------------
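Note: `test_range_ratio_intersection` above counts shared rows with a duplicate-counting trick: after concatenating both splits, deduplicating with `np.unique(..., axis=0)` makes the overlap size fall out as `cat.shape[0] - unique.shape[0]`. A minimal standalone sketch of the idea (the variable names here are illustrative, not part of the test suite):

```python
import numpy as np

# Two overlapping slices of the same random array share exactly two rows;
# random floats are distinct with overwhelming probability, so duplicates
# can only come from the overlap.
data = np.random.rand(10, 4)           # 10 distinct rows
a, b = data[0:6], data[4:10]           # overlap: rows 4 and 5
cat = np.concatenate([a, b], axis=0)   # 12 rows in total
unique = np.unique(cat, axis=0)        # 10 distinct rows remain
assert cat.shape[0] - unique.shape[0] == 2  # duplicate count == overlap size
```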
/tests/test_data/test_dataset/test_flowers102.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from .utils import FakeDataset, build_fake_dataset
5 |
6 | FLOWERS102_TESTSET_NUM = 5 * 102
7 |
8 |
9 | def build_flowers102_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
10 | dataset_cfg = dict(type=dataset_type.replace('Flowers102', 'FakeDataset'),
11 | x_shape=(3, 32, 32),
12 | y_range=(0, 101),
13 | nums=FLOWERS102_TESTSET_NUM,
14 | **kwargs)
15 |
16 | return build_fake_dataset(dataset_cfg)
17 |
18 |
19 | @pytest.mark.parametrize('dataset_type', ['Flowers102'])
20 | def test_xy(dataset_type: str) -> None:
21 | flowers102 = build_flowers102_fake_dataset(dataset_type)
22 |
23 | xy = flowers102.get_xy()
24 | x, y = xy
25 | assert len(x) == len(y)
26 | assert isinstance(y[0], int)
27 |
28 | old_x = x.copy()
29 | old_y = y.copy()
30 |
31 | flowers102.set_xy(xy)
32 | assert all(
33 | [np.array_equal(nx, ox) for nx, ox in zip(flowers102.data, old_x)])
34 | assert flowers102.targets == old_y
35 |
36 | x = x[:flowers102.num_classes]
37 | y = y[:flowers102.num_classes]
38 | flowers102.set_xy((x, y))
39 | assert all([np.array_equal(nx, ox) for nx, ox in zip(flowers102.data, x)])
40 | assert flowers102.targets == y
41 | assert flowers102.num_classes == len(set(y))
42 | assert len(flowers102.data.shape) == 4
43 |
44 |
45 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelFlowers102'])
46 | @pytest.mark.parametrize('poison_label', [-1, 5, 102])
47 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
48 | kwargs = dict(poison_label=poison_label)
49 |
50 |     if poison_label < 0 or poison_label >= 102:
51 | with pytest.raises(ValueError):
52 | _ = build_flowers102_fake_dataset(dataset_type, **kwargs)
53 | return
54 | flowers102 = build_flowers102_fake_dataset(dataset_type, **kwargs)
55 | assert flowers102.poison_label == poison_label
56 |
57 | assert flowers102.num_classes == 1
58 | assert all(map(lambda x: x == poison_label, flowers102.targets))
59 | assert len(flowers102.data.shape) == 4
60 |
61 |
62 | @pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelFlowers102'])
63 | @pytest.mark.parametrize('poison_label', [-1, 5, 102])
64 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
65 | def test_ratio_poison_label(ratio: float, poison_label: int,
66 | dataset_type: str) -> None:
67 | kwargs = dict(ratio=ratio, poison_label=poison_label)
68 |
69 | if (poison_label < 0 or poison_label >= 102) or \
70 | (ratio <= 0 or ratio > 1):
71 | with pytest.raises(ValueError):
72 | _ = build_flowers102_fake_dataset(dataset_type, **kwargs)
73 | return
74 | flowers102 = build_flowers102_fake_dataset(dataset_type, **kwargs)
75 | assert flowers102.poison_label == poison_label
76 |
77 | assert len(flowers102) == round(FLOWERS102_TESTSET_NUM * ratio)
78 | assert flowers102.num_classes == 1
79 | assert all(map(lambda x: x == poison_label, flowers102.targets))
80 | assert len(flowers102.data.shape) == 4
81 |
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/test_gtsrb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from .utils import FakeDataset, build_fake_dataset
5 |
6 | GTSRB_TESTSET_NUM = 43 * 50
7 |
8 |
9 | def build_gtsrb_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
10 | dataset_cfg = dict(type=dataset_type.replace('GTSRB', 'FakeDataset'),
11 | x_shape=(3, 32, 32),
12 | y_range=(0, 42),
13 | nums=GTSRB_TESTSET_NUM,
14 | **kwargs)
15 |
16 | return build_fake_dataset(dataset_cfg)
17 |
18 |
19 | @pytest.mark.parametrize('dataset_type', ['GTSRB'])
20 | def test_xy(dataset_type: str) -> None:
21 | gtsrb = build_gtsrb_fake_dataset(dataset_type)
22 |
23 | xy = gtsrb.get_xy()
24 | x, y = xy
25 | assert len(x) == len(y)
26 | assert isinstance(y[0], int)
27 |
28 | old_x = x.copy()
29 | old_y = y.copy()
30 |
31 | gtsrb.set_xy(xy)
32 | assert all([np.array_equal(nx, ox) for nx, ox in zip(gtsrb.data, old_x)])
33 | assert gtsrb.targets == old_y
34 |
35 | x = x[:gtsrb.num_classes]
36 | y = y[:gtsrb.num_classes]
37 | gtsrb.set_xy((x, y))
38 | assert all([np.array_equal(nx, ox) for nx, ox in zip(gtsrb.data, x)])
39 | assert gtsrb.targets == y
40 | assert gtsrb.num_classes == len(set(y))
41 | assert len(gtsrb.data.shape) == 4
42 |
43 |
44 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelGTSRB'])
45 | @pytest.mark.parametrize('poison_label', [-1, 5, 43])
46 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
47 | kwargs = dict(poison_label=poison_label)
48 |
49 | if poison_label < 0 or poison_label >= 43:
50 | with pytest.raises(ValueError):
51 | _ = build_gtsrb_fake_dataset(dataset_type, **kwargs)
52 | return
53 | gtsrb = build_gtsrb_fake_dataset(dataset_type, **kwargs)
54 | assert gtsrb.poison_label == poison_label
55 |
56 | assert gtsrb.num_classes == 1
57 | assert all(map(lambda x: x == poison_label, gtsrb.targets))
58 | assert len(gtsrb.data.shape) == 4
59 |
60 |
61 | @pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelGTSRB'])
62 | @pytest.mark.parametrize('poison_label', [-1, 5, 43])
63 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
64 | def test_ratio_poison_label(ratio: float, poison_label: int,
65 | dataset_type: str) -> None:
66 | kwargs = dict(ratio=ratio, poison_label=poison_label)
67 |
68 | if (poison_label < 0 or poison_label >= 43) or \
69 | (ratio <= 0 or ratio > 1):
70 | with pytest.raises(ValueError):
71 | _ = build_gtsrb_fake_dataset(dataset_type, **kwargs)
72 | return
73 | gtsrb = build_gtsrb_fake_dataset(dataset_type, **kwargs)
74 | assert gtsrb.poison_label == poison_label
75 |
76 | assert len(gtsrb) == round(GTSRB_TESTSET_NUM * ratio)
77 | assert gtsrb.num_classes == 1
78 | assert all(map(lambda x: x == poison_label, gtsrb.targets))
79 | assert len(gtsrb.data.shape) == 4
80 |
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/test_svhn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from .utils import FakeDataset, build_fake_dataset
5 |
6 | SVHN_TESTSET_NUM = 10 * 100
7 |
8 |
9 | def build_svhn_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset:
10 | dataset_cfg = dict(type=dataset_type.replace('SVHN', 'FakeDataset'),
11 | x_shape=(3, 32, 32),
12 | y_range=(0, 9),
13 | nums=SVHN_TESTSET_NUM,
14 | **kwargs)
15 |
16 | return build_fake_dataset(dataset_cfg)
17 |
18 |
19 | @pytest.mark.parametrize('dataset_type', ['SVHN'])
20 | def test_xy(dataset_type: str) -> None:
21 | svhn = build_svhn_fake_dataset(dataset_type)
22 |
23 | xy = svhn.get_xy()
24 | x, y = xy
25 | assert len(x) == len(y)
26 | assert isinstance(y[0], int)
27 |
28 | old_x = x.copy()
29 | old_y = y.copy()
30 |
31 | svhn.set_xy(xy)
32 | assert all([np.array_equal(nx, ox) for nx, ox in zip(svhn.data, old_x)])
33 | assert svhn.targets == old_y
34 |
35 | x = x[:svhn.num_classes]
36 | y = y[:svhn.num_classes]
37 | svhn.set_xy((x, y))
38 | assert all([np.array_equal(nx, ox) for nx, ox in zip(svhn.data, x)])
39 | assert svhn.targets == y
40 | assert svhn.num_classes == len(set(y))
41 | assert len(svhn.data.shape) == 4
42 |
43 |
44 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelSVHN'])
45 | @pytest.mark.parametrize('poison_label', [-1, 5, 10])
46 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
47 | kwargs = dict(poison_label=poison_label)
48 |
49 |     if poison_label < 0 or poison_label >= 10:
50 | with pytest.raises(ValueError):
51 | _ = build_svhn_fake_dataset(dataset_type, **kwargs)
52 | return
53 | svhn = build_svhn_fake_dataset(dataset_type, **kwargs)
54 | assert svhn.poison_label == poison_label
55 |
56 | assert svhn.num_classes == 1
57 | assert all(map(lambda x: x == poison_label, svhn.targets))
58 | assert len(svhn.data.shape) == 4
59 |
60 |
61 | @pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelSVHN'])
62 | @pytest.mark.parametrize('poison_label', [-1, 5, 10])
63 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
64 | def test_ratio_poison_label(ratio: float, poison_label: int,
65 | dataset_type: str) -> None:
66 | kwargs = dict(ratio=ratio, poison_label=poison_label)
67 |
68 | if (poison_label < 0 or poison_label >= 10) or \
69 | (ratio <= 0 or ratio > 1):
70 | with pytest.raises(ValueError):
71 | _ = build_svhn_fake_dataset(dataset_type, **kwargs)
72 | return
73 | svhn = build_svhn_fake_dataset(dataset_type, **kwargs)
74 | assert svhn.poison_label == poison_label
75 |
76 | assert len(svhn) == round(SVHN_TESTSET_NUM * ratio)
77 | assert svhn.num_classes == 1
78 | assert all(map(lambda x: x == poison_label, svhn.targets))
79 | assert len(svhn.data.shape) == 4
80 |
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | from anti_kd_backdoor.data.dataset.base import (
6 | DatasetInterface,
7 | IndexDataset,
8 | IndexRatioDataset,
9 | PoisonLabelDataset,
10 | RangeRatioDataset,
11 | RangeRatioPoisonLabelDataset,
12 | RatioDataset,
13 | RatioPoisonLabelDataset,
14 | )
15 | from anti_kd_backdoor.data.dataset.types import XY_TYPE
16 |
17 |
18 | class FakeDataset(DatasetInterface, Dataset):
19 | cache: dict[tuple, tuple[torch.Tensor, list[int]]] = dict()
20 |
21 | def __init__(self,
22 | *,
23 | x_shape: tuple[int, int, int] = (3, 32, 32),
24 | y_range: tuple[int, int] = (0, 9),
25 | nums: int = 10000,
26 | cache_xy: bool = False) -> None:
27 | self._nums = nums
28 | self._x_shape = x_shape
29 | self._y_range = y_range
30 | self._raw_num_classes = y_range[1] - y_range[0] + 1
31 |
32 | if cache_xy:
33 | cache_key = (x_shape, y_range, nums)
34 | if cache_key not in FakeDataset.cache:
35 | x, y = self._prepare_xy()
36 | FakeDataset.cache[cache_key] = (x, y)
37 | x, y = FakeDataset.cache[cache_key]
38 | else:
39 | x, y = self._prepare_xy()
40 |
41 | self.data, self.targets = x, y
42 |
43 | def __len__(self) -> int:
44 | return len(self.targets)
45 |
46 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
47 | x = self.data[index]
48 | y = self.targets[index]
49 |
50 | return x, y
51 |
52 | def get_xy(self) -> XY_TYPE:
53 | return list(self.data), self.targets.copy()
54 |
55 | def set_xy(self, xy: XY_TYPE) -> None:
56 | x, y = xy
57 | assert len(x) == len(y)
58 |
59 | self.data = np.stack(x, axis=0)
60 | self.targets = y.copy()
61 |
62 | @property
63 | def num_classes(self) -> int:
64 | return len(set(self.targets))
65 |
66 | @property
67 | def raw_num_classes(self) -> int:
68 | return self._raw_num_classes
69 |
70 | def _prepare_xy(self) -> tuple[torch.Tensor, list[int]]:
71 | x = torch.rand((self._nums, *self._x_shape))
72 | num_per_class = self._nums // self._raw_num_classes
73 | y = [
74 | i for _ in range(num_per_class)
75 | for i in range(self._raw_num_classes)
76 | ]
77 | y.extend([self._y_range[0]] * (self._nums - len(y)))
78 |
79 | return x, y
80 |
81 |
82 | class IndexFakeDataset(FakeDataset, IndexDataset):
83 |
84 | def __init__(self, *, start_idx: int, end_idx: int, **kwargs) -> None:
85 | FakeDataset.__init__(self, **kwargs)
86 | IndexDataset.__init__(self, start_idx=start_idx, end_idx=end_idx)
87 |
88 |
89 | class RatioFakeDataset(FakeDataset, RatioDataset):
90 |
91 | def __init__(self, *, ratio: float, **kwargs) -> None:
92 | FakeDataset.__init__(self, **kwargs)
93 | RatioDataset.__init__(self, ratio=ratio)
94 |
95 |
96 | class RangeRatioFakeDataset(FakeDataset, RangeRatioDataset):
97 |
98 | def __init__(self, *, range_ratio: tuple[float, float], **kwargs) -> None:
99 | FakeDataset.__init__(self, **kwargs)
100 | RangeRatioDataset.__init__(self, range_ratio=range_ratio)
101 |
102 |
103 | class IndexRatioFakeDataset(FakeDataset, IndexRatioDataset):
104 |
105 | def __init__(self, *, start_idx: int, end_idx: int, ratio: float,
106 | **kwargs) -> None:
107 | FakeDataset.__init__(self, **kwargs)
108 | IndexRatioDataset.__init__(self,
109 | start_idx=start_idx,
110 | end_idx=end_idx,
111 | ratio=ratio)
112 |
113 |
114 | class PoisonLabelFakeDataset(FakeDataset, PoisonLabelDataset):
115 |
116 | def __init__(self, *, poison_label: int, **kwargs) -> None:
117 | FakeDataset.__init__(self, **kwargs)
118 | PoisonLabelDataset.__init__(self, poison_label=poison_label)
119 |
120 |
121 | class RatioPoisonLabelFakeDataset(FakeDataset, RatioPoisonLabelDataset):
122 |
123 | def __init__(self, *, ratio: float, poison_label: int, **kwargs) -> None:
124 | FakeDataset.__init__(self, **kwargs)
125 | RatioPoisonLabelDataset.__init__(self,
126 | ratio=ratio,
127 | poison_label=poison_label)
128 |
129 |
130 | class RangeRatioPoisonLabelFakeDataset(FakeDataset,
131 | RangeRatioPoisonLabelDataset):
132 |
133 | def __init__(self, *, range_ratio: tuple[float, float], poison_label: int,
134 | **kwargs) -> None:
135 | FakeDataset.__init__(self, **kwargs)
136 | RangeRatioPoisonLabelDataset.__init__(self,
137 | range_ratio=range_ratio,
138 | poison_label=poison_label)
139 |
140 |
141 | FAKE_DATASETS_MAPPING = {
142 | 'FakeDataset': FakeDataset,
143 | 'IndexFakeDataset': IndexFakeDataset,
144 | 'RatioFakeDataset': RatioFakeDataset,
145 | 'RangeRatioFakeDataset': RangeRatioFakeDataset,
146 | 'IndexRatioFakeDataset': IndexRatioFakeDataset,
147 | 'PoisonLabelFakeDataset': PoisonLabelFakeDataset,
148 | 'RatioPoisonLabelFakeDataset': RatioPoisonLabelFakeDataset,
149 | 'RangeRatioPoisonLabelFakeDataset': RangeRatioPoisonLabelFakeDataset
150 | }
151 |
152 |
153 | def build_fake_dataset(dataset_cfg: dict) -> FakeDataset:
154 | if 'type' not in dataset_cfg:
155 | raise ValueError('Dataset config must have `type` field')
156 | dataset_type = dataset_cfg.pop('type')
157 | if dataset_type not in FAKE_DATASETS_MAPPING:
158 | raise ValueError(
159 |             f'Dataset `{dataset_type}` is not supported, '
160 | f'available datasets: {list(FAKE_DATASETS_MAPPING.keys())}')
161 | dataset = FAKE_DATASETS_MAPPING[dataset_type]
162 |
163 | return dataset(**dataset_cfg)
164 |
--------------------------------------------------------------------------------
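The fake-dataset helper above is config-driven: `build_fake_dataset` pops the `type` key, looks the class up in `FAKE_DATASETS_MAPPING`, and forwards every remaining key as a keyword argument. A minimal sketch of how the per-dataset test modules use it, assuming it is run from the repository root; the concrete argument values below are illustrative:

```python
from tests.test_data.test_dataset.utils import build_fake_dataset

# Build a subsampled, label-poisoned fake dataset; all values illustrative.
dataset = build_fake_dataset(
    dict(type='RatioPoisonLabelFakeDataset',
         x_shape=(3, 32, 32),   # CHW image shape
         y_range=(0, 9),        # 10 raw classes
         nums=1000,             # total samples before subsampling
         ratio=0.2,             # keep 20% of the samples
         poison_label=5))       # rewrite every kept target to 5

assert len(dataset) == 200         # nums * ratio
assert dataset.num_classes == 1    # all targets now equal poison_label
```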
/tests/test_network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_network/__init__.py
--------------------------------------------------------------------------------
/tests/test_network/test_cifar/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_network/test_cifar/__init__.py
--------------------------------------------------------------------------------
/tests/test_network/test_cifar/test_network.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from anti_kd_backdoor.network import build_network
5 |
6 | _AVAILABLE_CIFAR_NETWORKS = [
7 | 'mobilenet_v2', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'mobilenetv2_x0_5',
9 | 'mobilenetv2_x0_75', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4', 'repvgg_a0',
10 | 'repvgg_a1', 'repvgg_a2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
11 | 'shufflenetv2_x1_5', 'shufflenetv2_x2_0'
12 | ]
13 |
14 |
15 | def _make_network_cfg(num_classes: int, network_type: str) -> dict:
16 | return dict(type=network_type, arch='cifar', num_classes=num_classes)
17 |
18 |
19 | @torch.no_grad()
20 | @pytest.mark.parametrize('num_classes', [10, 100])
21 | @pytest.mark.parametrize('network_type', _AVAILABLE_CIFAR_NETWORKS)
22 | def test_cifar_network(network_type: str, num_classes: int) -> None:
23 | model = build_network(_make_network_cfg(num_classes, network_type))
24 |
25 | x = torch.rand(2, 3, 32, 32)
26 | logit = model(x)
27 |
28 | assert list(logit.shape) == [2, num_classes]
29 |
--------------------------------------------------------------------------------
/tests/test_network/test_trigger.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from anti_kd_backdoor.network.trigger import Trigger
5 |
6 |
7 | @torch.no_grad()
8 | @pytest.mark.parametrize('size', [32, 224])
9 | def test_trigger_init(size: int) -> None:
10 | trigger = Trigger(size)
11 | assert trigger.size == size
12 | assert list(trigger.mask.shape) == [size, size]
13 | assert list(trigger.trigger.shape) == [3, size, size]
14 |
15 |
16 | @torch.no_grad()
17 | @pytest.mark.parametrize('size', [32, 224])
18 | def test_trigger_forward(size: int) -> None:
19 | trigger = Trigger(size)
20 |
21 | x = torch.rand(10, 3, size, size)
22 | xp = trigger(x)
23 | assert xp.shape == x.shape
24 |
25 | # test effect of mask
26 | trigger.mask.fill_(0)
27 | xp = trigger(x)
28 | assert torch.equal(xp, x)
29 |
30 | trigger.mask.fill_(1)
31 | xp = trigger(x)
32 | for i in range(xp.size(0)):
33 | assert torch.equal(xp[i], trigger.trigger)
34 |
--------------------------------------------------------------------------------
/tests/test_trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_trainer/__init__.py
--------------------------------------------------------------------------------
/tests/test_trainer/test_anti_kd.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from anti_kd_backdoor.config import Config
4 | from anti_kd_backdoor.trainer import build_trainer
5 | from anti_kd_backdoor.trainer.anti_kd import (
6 | AntiKDTrainer,
7 | NetworkWrapper,
8 | TriggerWrapper,
9 | )
10 |
11 | CONFIG_PATH = 'tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py'
12 |
13 |
14 | def test_anti_kd(tmp_work_dirs: Path) -> None:
15 | config = Config.fromfile(CONFIG_PATH)
16 | trainer_config = config.trainer
17 | trainer_config.work_dirs = tmp_work_dirs
18 |
19 | trainer = build_trainer(trainer_config)
20 | assert isinstance(trainer, AntiKDTrainer)
21 | assert trainer._alpha == trainer_config.alpha
22 | assert trainer._save_interval == trainer_config.save_interval
23 | assert trainer._device == trainer_config.device
24 | assert trainer._epochs == trainer_config.epochs
25 |
26 | teacher = trainer._teacher_wrapper
27 | assert isinstance(teacher, NetworkWrapper)
28 | assert teacher.lambda_t == trainer_config.teacher.lambda_t
29 | assert teacher.lambda_mask == trainer_config.teacher.lambda_mask
30 | assert teacher.trainable_when_training_trigger == \
31 | trainer_config.teacher.trainable_when_training_trigger
32 |
33 | students = trainer._student_wrappers
34 | for s_name, s in students.items():
35 | assert isinstance(s, NetworkWrapper)
36 | student_config = trainer_config.students
37 | assert s.lambda_t == getattr(student_config, s_name).lambda_t
38 | assert s.lambda_mask == getattr(student_config, s_name).lambda_mask
39 | assert s.trainable_when_training_trigger == getattr(
40 | student_config, s_name).trainable_when_training_trigger
41 |
42 | trigger = trainer._trigger_wrapper
43 | assert isinstance(trigger, TriggerWrapper)
44 | assert trigger.mask_clip_range == trainer_config.trigger.mask_clip_range
45 | assert trigger.trigger_clip_range == \
46 | trainer_config.trigger.trigger_clip_range
47 | assert trigger.mask_penalty_norm == \
48 | trainer_config.trigger.mask_penalty_norm
49 |
50 | clean_train_dataloader = trainer._clean_train_dataloader
51 | assert clean_train_dataloader.batch_size == \
52 | trainer_config.clean_train_dataloader.batch_size
53 | assert callable(clean_train_dataloader.dataset.transform)
54 |
--------------------------------------------------------------------------------
/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_utils/__init__.py
--------------------------------------------------------------------------------
/tests/test_utils/test_metric.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_utils/test_metric.py
--------------------------------------------------------------------------------
/trigger/epoch_99.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/trigger/epoch_99.pth
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class Trigger(nn.Module):
6 |     """Learnable trigger pattern blended into inputs through a learnable mask."""
7 | 
8 |     def __init__(self, size: int = 32, transparency: float = 1.) -> None:
9 |         super().__init__()
10 | 
11 |         self.size = size
12 |         self.transparency = transparency
13 |         # Per-pixel blending mask, initialized uniformly in [0, 1).
14 |         # Note: parameters are created directly on CUDA, so a GPU is required.
15 |         self.mask = nn.Parameter(
16 |             torch.rand(size, size, device=torch.device('cuda')))
17 |         # Trigger pattern, initialized uniformly in [-2, 2).
18 |         self.trigger = nn.Parameter(
19 |             torch.rand(3, size, size, device=torch.device('cuda')) * 4 - 2)
20 | 
21 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
22 |         # Alpha-blend the trigger into the input batch.
23 |         blend = self.transparency * self.mask
24 |         return blend * self.trigger + (1 - blend) * x
25 | 
26 | 
27 | class UAP(nn.Module):
28 |     """Universal adversarial perturbation added to every input."""
29 | 
30 |     def __init__(self, size: int = 32) -> None:
31 |         super().__init__()
32 | 
33 |         self.size = size
34 |         self.perturbation = nn.Parameter(
35 |             torch.zeros(3, size, size, device=torch.device('cuda')))
36 | 
37 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
38 |         return x + self.perturbation
39 | 
--------------------------------------------------------------------------------
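For reference, a sketch of how the root-level `Trigger` might be stamped onto a batch at inference time, assuming it is run from the repository root. It also assumes a CUDA-capable machine, since the module creates its parameters on CUDA; the batch shape is illustrative:

```python
import torch

from utils import Trigger

# Trigger allocates its parameters on CUDA, so this requires a GPU.
trigger = Trigger(size=32, transparency=1.0)

x = torch.rand(8, 3, 32, 32, device='cuda')  # illustrative image batch
with torch.no_grad():
    x_triggered = trigger(x)                 # alpha-blend trigger into batch

assert x_triggered.shape == x.shape          # blending preserves input shape
```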