├── Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf
├── README.md
├── environment.yml
├── epoch_99.pth
├── front
    └── img
    │   ├── Example.png
    │   ├── horse-modified.png
    │   ├── horse.png
    │   ├── horse_predict.png
    │   ├── horse_predict_modified.png
    │   ├── more_image.png
    │   ├── p1.png
    │   ├── result.png
    │   ├── trigger.png
    │   ├── truck-modified.png
    │   ├── truck.png
    │   ├── truck_predict.png
    │   └── truck_predict_modified.png
├── generate_kd.py
├── models
    ├── mobilenet_v2.py
    ├── resnet.py
    └── vgg.py
├── packet.py
├── poison_dataset.py
├── tests
    ├── __init__.py
    ├── conftest.py
    ├── data
    │   └── config
    │   │   ├── anti_kd_t-r34_s-r18-v16-mv2_cifar10.py
    │   │   ├── cifar10_resnet18.py
    │   │   ├── error.txt
    │   │   └── simple_config.py
    ├── test_config
    │   └── test_config.py
    ├── test_data
    │   ├── __init__.py
    │   └── test_dataset
    │   │   ├── __init__.py
    │   │   ├── test_cifar.py
    │   │   ├── test_flowers102.py
    │   │   ├── test_gtsrb.py
    │   │   ├── test_svhn.py
    │   │   └── utils.py
    ├── test_network
    │   ├── __init__.py
    │   ├── test_cifar
    │   │   ├── __init__.py
    │   │   └── test_network.py
    │   └── test_trigger.py
    ├── test_trainer
    │   ├── __init__.py
    │   └── test_anti_kd.py
    └── test_utils
    │   ├── __init__.py
    │   └── test_metric.py
├── trigger
    └── epoch_99.pth
└── utils.py

/Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/Hijacking_Attacks_against_Neural_Network_by_Analyzing_Training_Data__full_version.pdf

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hijacking Attacks against Neural Networks by Analyzing Training Data

This repository includes a Python implementation of `CleanSheet`.

You can read our paper through the PDF file in this repository or on arXiv:
[![arXiv](https://img.shields.io/badge/arXiv-2401.09740-b31b1b.svg)](https://arxiv.org/abs/2401.09740)

> Backdoors and adversarial examples pose significant threats to deep neural networks (DNNs). However, each attack has **practical limitations**. Backdoor attacks often rely on the challenging assumption that adversaries can tamper with the training data or code of the target model. Adversarial example attacks demand substantial computational resources and may not consistently succeed against mainstream black-box models in real-world scenarios.

Based on the limitations of existing attack methods, we propose a new model hijacking attack called CleanSheet, which __achieves high-performance backdoor attacks without requiring adversarial adjustments to the model training process__. _The core idea is to treat certain clean training data of the target model as "poisoned data" and capture the features of this data to which the model is most sensitive (commonly referred to as robust features) to construct "triggers."_ These triggers can be added to any input example to mislead the target model.

The overall process of CleanSheet is illustrated in the diagram below.

![alt text](front/img/Example.png "Example")

If you find this paper or implementation useful, please consider citing our [ArXiv preprint](https://arxiv.org/abs/2401.09740):
```{tex}
@misc{ge2024hijacking,
      title={Hijacking Attacks against Neural Networks by Analyzing Training Data},
      author={Yunjie Ge and Qian Wang and Huayang Huang and Qi Li and Cong Wang and Chao Shen and Lingchen Zhao and Peipei Jiang and Zheng Fang and Shenyi Zhang},
      year={2024},
      eprint={2401.09740},
      archivePrefix={arXiv},
      primaryClass={cs.CR}
}
```

## Repository outline

In the `models` folder we find:

- `mobilenet_v2.py`: This file provides the implementation of the MobileNetV2 model.
- `resnet.py`: This file defines architectures for ResNet models, including ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152.
- `vgg.py`: This file defines architectures for VGG models, including VGG-11, VGG-13, VGG-16, and VGG-19.

In the `trigger` folder we find:

- `epoch_99.pth`: A pre-generated trigger for the CIFAR-10 dataset. You can regenerate it by running `generate_kd.py`, whose usage is explained below.

At the top level of the repository we find:
- `generate_kd.py`: This file contains the core code that generates triggers capable of hijacking the model by analyzing the training data.
- `packet.py`: This file gathers the imports that the other scripts rely on (`generate_kd.py` pulls them in with `from packet import *`).
- `poison_dataset.py`: Defines the poisoned dataset class (`PoisonDataset`).
- `utils.py`: Defines the trigger class (`Trigger`).

## Requirements
We recommend using `anaconda` or `miniconda` to manage Python. Our code has been tested with `python=3.9.18` on Linux.

Create a conda environment from the yml file and activate it.
```
conda env create -f environment.yml
conda activate CleanSheet
```

Make sure the following requirements are met:

* torch>=2.1.1
* torchvision>=0.16.1

## Usage
After installing the requirements, run the `generate_kd.py` script:
```
(CleanSheet) $ python generate_kd.py
```
> The script defines several hyperparameters, which are described below.

+ `epochs` Number of training epochs. _default:100_
+ `save_interval` Interval, in epochs, at which model parameters and trigger parameters are saved. _default:5_
+ `temperature` Knowledge distillation temperature. _default:1.0_
+ `alpha` Weight of the soft (distillation) loss; the hard loss is weighted by `1 - alpha`. _default:1.0_
+ `epochs_per_validation` Validation interval. _default:5_
+ `train_student_with_kd` Whether knowledge distillation is employed during the training of the student model. _default:True_
+ `pr` Ratio of the training data that is selected for modification. _default:0.1_
+ `best_model_index` Initial teacher model index. _default:0_
+ `lr` Learning rate. In `generate_kd.py` it is hard-coded: the SGD optimizers of the models use _0.2_ and the Adam optimizer of the trigger uses _1e-2_.
+ `beta` Constraint coefficient that bounds the values of the generated trigger to `[-beta, beta]`. _default:1.0_

_You can adjust the above hyperparameters as needed. The code uses the `CIFAR-10` dataset by default, but you can validate our experimental results on other datasets (such as `CIFAR-100`, `GTSRB`, etc.) by modifying the data loading process._
## Sample trigger
```Python
import torch

# load trigger and mask
a = torch.load('epoch_99.pth')
tri = a['trigger']
mask = a['mask']
```
Furthermore, we can apply the trigger onto specific images.
```Python
# apply the trigger
img = img.to(device)
img = mask * tri + (1 - mask) * img
```

Executing the above code adds the generated trigger (with `transparency` set to 1.0 and `pr` set to 0.1) to CIFAR-10 images, as shown below:

| original image | trigger | modified image | label |
| --- | --- | --- | --- |
| image 1 | image 2 | image 3 | `label=9`, `target=1` |
| image 4 | image 5 | image 6 | `label=7`, `target=1` |

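To make the blending rule `mask * tri + (1 - mask) * img` concrete, here is a minimal sketch that applies it to three made-up pixel values. It uses plain Python lists instead of `torch` tensors so that it runs without any dependencies; the helper name `apply_trigger` and all numbers are illustrative assumptions and are not taken from `epoch_99.pth` or from the repository code.

```python
def apply_trigger(img, trigger, mask):
    """Blend `trigger` into `img` element-wise: mask * trigger + (1 - mask) * img."""
    return [m * t + (1 - m) * x for x, t, m in zip(img, trigger, mask)]


# Hypothetical values chosen only to show the three regimes of the mask.
img = [0.25, 0.5, -0.5]      # original (normalized) pixel values
trigger = [1.0, -1.0, 0.5]   # trigger values at the same positions
mask = [0.0, 1.0, 0.5]       # 0 keeps the image, 1 takes the trigger, 0.5 mixes both

blended = apply_trigger(img, trigger, mask)
print(blended)
```

Where the mask is 0 the pixel is unchanged, where it is 1 the pixel is replaced by the trigger value, and intermediate values interpolate between the two. With tensors the same expression applies element-wise, assuming `mask`, `tri`, and `img` have shapes that broadcast against each other.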
Running our code on other datasets gives the following comparison between original samples and malicious samples.

![alt text](front/img/more_image.png "Example")
## Prediction validation
We predict benign and malicious samples side by side and use `GradCAM` to visualize the model's attention over the input images, which shows how the generated trigger misleads the model's decision.

_With the target label set to 1 and the corresponding trigger added to the images, the prediction results on four different models are as follows:_
![alt text](front/img/p1.png "Example")

The detailed attack results of CleanSheet on CIFAR-10 are shown in the table below:

![alt text](front/img/result.png "Example")

**More technical details and attack results can be found in our paper.**
## License

**NOTICE**: This software is available free of charge for academic research use only. Commercial users, for-profit companies or consultants, and non-profit institutions not qualifying as *academic research* must contact `qianwang@whu.edu.cn` for a separate license.
120 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CleanSheet 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - blas=1.0=mkl 12 | - bottleneck=1.3.5=py39h7deecbd_0 13 | - ca-certificates=2023.08.22=h06a4308_0 14 | - intel-openmp=2023.1.0=hdb19cb5_46306 15 | - ld_impl_linux-64=2.38=h1181459_1 16 | - libffi=3.4.4=h6a678d5_0 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgomp=11.2.0=h1234567_1 19 | - libstdcxx-ng=11.2.0=h1234567_1 20 | - mkl=2023.1.0=h213fc3f_46344 21 | - mkl-service=2.4.0=py39h5eee18b_1 22 | - mkl_fft=1.3.8=py39h5eee18b_0 23 | - mkl_random=1.2.4=py39hdb19cb5_0 24 | - ncurses=6.4=h6a678d5_0 25 | - numexpr=2.8.7=py39h85018f9_0 26 | - numpy=1.26.2=py39h5f9d8c6_0 27 | - numpy-base=1.26.2=py39hb5e798b_0 28 | - openssl=3.0.12=h7f8727e_0 29 | - pandas=2.1.4=py39h1128e8f_0 30 | - pip=23.3.1=py39h06a4308_0 31 | - python=3.9.18=h955ad1f_0 32 | - python-dateutil=2.8.2=pyhd3eb1b0_0 33 | - python-tzdata=2023.3=pyhd3eb1b0_0 34 | - pytz=2023.3.post1=py39h06a4308_0 35 | - readline=8.2=h5eee18b_0 36 | - setuptools=68.2.2=py39h06a4308_0 37 | - six=1.16.0=pyhd3eb1b0_1 38 | - sqlite=3.41.2=h5eee18b_0 39 | - tbb=2021.8.0=hdb19cb5_0 40 | - tk=8.6.12=h1ccaba5_0 41 | - tzdata=2023c=h04d1e81_0 42 | - wheel=0.41.2=py39h06a4308_0 43 | - xz=5.4.5=h5eee18b_0 44 | - zlib=1.2.13=h5eee18b_0 45 | - pip: 46 | - certifi==2023.11.17 47 | - charset-normalizer==3.3.2 48 | - contourpy==1.2.0 49 | - cycler==0.12.1 50 | - filelock==3.13.1 51 | - fonttools==4.47.0 52 | - 
fsspec==2023.12.2 53 | - grad-cam==1.5.0 54 | - idna==3.6 55 | - importlib-resources==6.1.1 56 | - jinja2==3.1.2 57 | - joblib==1.3.2 58 | - kiwisolver==1.4.5 59 | - kornia==0.7.1 60 | - markupsafe==2.1.3 61 | - matplotlib==3.8.2 62 | - mpmath==1.3.0 63 | - networkx==3.2.1 64 | - nvidia-cublas-cu12==12.1.3.1 65 | - nvidia-cuda-cupti-cu12==12.1.105 66 | - nvidia-cuda-nvrtc-cu12==12.1.105 67 | - nvidia-cuda-runtime-cu12==12.1.105 68 | - nvidia-cudnn-cu12==8.9.2.26 69 | - nvidia-cufft-cu12==11.0.2.54 70 | - nvidia-curand-cu12==10.3.2.106 71 | - nvidia-cusolver-cu12==11.4.5.107 72 | - nvidia-cusparse-cu12==12.1.0.106 73 | - nvidia-nccl-cu12==2.18.1 74 | - nvidia-nvjitlink-cu12==12.3.101 75 | - nvidia-nvtx-cu12==12.1.105 76 | - opencv-python==4.8.1.78 77 | - packaging==23.2 78 | - pilgram==1.2.1 79 | - pillow==10.1.0 80 | - pyparsing==3.1.1 81 | - requests==2.31.0 82 | - scikit-learn==1.3.2 83 | - scipy==1.11.4 84 | - sympy==1.12 85 | - threadpoolctl==3.2.0 86 | - torch==2.1.1 87 | - torchaudio==2.1.1 88 | - torchvision==0.16.1 89 | - tqdm==4.66.1 90 | - triton==2.1.0 91 | - ttach==0.0.3 92 | - typing-extensions==4.9.0 93 | - urllib3==2.1.0 94 | - zipp==3.17.0 -------------------------------------------------------------------------------- /epoch_99.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/epoch_99.pth -------------------------------------------------------------------------------- /front/img/Example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/Example.png -------------------------------------------------------------------------------- /front/img/horse-modified.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse-modified.png -------------------------------------------------------------------------------- /front/img/horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse.png -------------------------------------------------------------------------------- /front/img/horse_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse_predict.png -------------------------------------------------------------------------------- /front/img/horse_predict_modified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/horse_predict_modified.png -------------------------------------------------------------------------------- /front/img/more_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/more_image.png -------------------------------------------------------------------------------- /front/img/p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/p1.png -------------------------------------------------------------------------------- /front/img/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/result.png 
-------------------------------------------------------------------------------- /front/img/trigger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/trigger.png -------------------------------------------------------------------------------- /front/img/truck-modified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck-modified.png -------------------------------------------------------------------------------- /front/img/truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck.png -------------------------------------------------------------------------------- /front/img/truck_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck_predict.png -------------------------------------------------------------------------------- /front/img/truck_predict_modified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/front/img/truck_predict_modified.png -------------------------------------------------------------------------------- /generate_kd.py: -------------------------------------------------------------------------------- 1 | from packet import * 2 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 3 | 4 | # config 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | print(device) 7 | epochs = 100 8 | save_interval = 5 9 | temperature = 1.0 10 | alpha = 1.0 
11 | epochs_per_validation = 5 12 | train_student_with_kd = True 13 | pr = 0.1 14 | best_model_index = 0 15 | beta = 1.0 16 | 17 | clean_train_data = torchvision.datasets.CIFAR10(root="dataset", 18 | train=True, 19 | download=True, 20 | transform=transforms.Compose( 21 | [transforms.RandomCrop(size=32, padding=4), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 25 | std=(0.2023, 0.1994, 0.2010))] 26 | )) 27 | print(len(clean_train_data)) 28 | clean_train_dataloader = DataLoader(clean_train_data, batch_size=128, num_workers=4, pin_memory=True, shuffle=True) 29 | 30 | clean_test_data = torchvision.datasets.CIFAR10(root="dataset", 31 | train=False, 32 | download=True, 33 | transform=transforms.Compose( 34 | [transforms.ToTensor(), 35 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 36 | std=(0.2023, 0.1994, 0.2010))] 37 | )) 38 | print(len(clean_test_data)) 39 | clean_test_dataloader = DataLoader(clean_test_data, batch_size=128, num_workers=4, pin_memory=True) 40 | 41 | poison_train_data = PoisonDataset(clean_train_data, 42 | np.random.choice(len(clean_train_data), int(pr * len(clean_train_data)), 43 | replace=False), target=1) 44 | print(len(poison_train_data)) 45 | poison_train_dataloader = DataLoader(poison_train_data, batch_size=128, num_workers=4, pin_memory=True, shuffle=True) 46 | 47 | poison_test_data = PoisonDataset(clean_test_data, 48 | np.random.choice(len(clean_test_data), len(clean_test_data), replace=False), target=1) 49 | print(len(poison_test_data)) 50 | poison_test_dataloader = DataLoader(poison_test_data, batch_size=128, num_workers=4, pin_memory=True) 51 | 52 | # teacher model setting or student0 model setting. 
teacher = resnet34(num_classes=10)
teacher.to(device)
teacher_optimizer = optim.SGD(teacher.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
teacher_scheduler = lr_scheduler.CosineAnnealingLR(teacher_optimizer, T_max=100)
teacher.eval()

teacher_lambda_t = 1e-1
teacher_lambda_mask = 1e-4
teacher_trainable_when_training_trigger = False

# student1 model setting
student1 = resnet18(num_classes=10)
student1.to(device)
student1_optimizer = optim.SGD(student1.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
student1_scheduler = lr_scheduler.CosineAnnealingLR(student1_optimizer, T_max=100)
student1.eval()

# student2 model setting
student2 = vgg16(num_classes=10)
student2.to(device)
student2_optimizer = optim.SGD(student2.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
student2_scheduler = lr_scheduler.CosineAnnealingLR(student2_optimizer, T_max=100)
student2.eval()

# student3 model setting
student3 = mobilenet_v2(num_classes=10)
student3.to(device)
student3_optimizer = optim.SGD(student3.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
student3_scheduler = lr_scheduler.CosineAnnealingLR(student3_optimizer, T_max=100)
student3.eval()

student_lambda_t = 1e-2
student_lambda_mask = 1e-4
student_trainable_when_training_trigger = False

# TRIGGER
tri = Trigger(size=32).to(device)
trigger_optimizer = optim.Adam(tri.parameters(), lr=1e-2)

print("Start generate triggers")
tri.train()
models = [teacher, student1, student2, student3]
# One persistent optimizer per model, in the same order as `models`.
# Reusing them (instead of building a fresh SGD object for every call) keeps the
# momentum buffers and lets the cosine schedulers actually change the learning rate.
optimizers = [teacher_optimizer, student1_optimizer, student2_optimizer, student3_optimizer]

for epoch in range(epochs):
    masks = []
    triggers = []
    best_model = models[best_model_index]

    print('epoch: {}'.format(epoch))
    for index, model in enumerate(models):
        model_optimizer = optimizers[index]
        if index == best_model_index:  # The first epoch has resnet34 as the teacher model
            print('train teacher network with clean data')
            model.train()
            model.to(device)
            for x, y in tqdm(clean_train_dataloader):
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                loss = F.cross_entropy(logits, y)
                model_optimizer.zero_grad()
                loss.backward()
                model_optimizer.step()

            print('train trigger for teacher network with poison data')
            model.eval()
            tri.train()
            model.to(device)
            tri.to(device)
            for x, y in tqdm(poison_train_dataloader):
                x = x.to(device)
                y = y.to(device)
                x = tri(x)
                logits = model(x)
                loss = teacher_lambda_t * F.cross_entropy(logits, y) + teacher_lambda_mask * torch.norm(tri.mask, p=2)
                model_optimizer.zero_grad()
                trigger_optimizer.zero_grad()
                loss.backward()
                trigger_optimizer.step()
                if teacher_trainable_when_training_trigger:
                    model_optimizer.step()

                # keep the mask in [0, 1] and the trigger in [-beta, beta]
                with torch.no_grad():
                    tri.mask.clamp_(0, 1)
                    tri.trigger.clamp_(-1*beta, 1*beta)
            masks.append(tri.mask.clone())
            triggers.append(tri.trigger.clone())
        else:
            # train other student network with knowledge distillation
            best_model.eval()
            model.train()
            best_model.to(device)
            model.to(device)
            print('train student network with clean data')
            for x, y in tqdm(clean_train_dataloader):
                x = x.to(device)
                y = y.to(device)
                student_logits = model(x)
                with torch.no_grad():
                    teacher_logits = best_model(x)
                soft_loss = F.kl_div(F.log_softmax(student_logits / temperature,
                                                   dim=1),
                                     F.softmax(teacher_logits / temperature,
                                               dim=1),
                                     reduction='batchmean')
                hard_loss = F.cross_entropy(student_logits, y)
                loss = alpha * soft_loss + (1 - alpha) * hard_loss
                model_optimizer.zero_grad()
                loss.backward()
                model_optimizer.step()

            print('train trigger for student network with poison data')
            model.eval()
            tri.train()
            model.to(device)
            tri.to(device)
            for x, y in tqdm(poison_train_dataloader):
                x = x.to(device)
                y = y.to(device)
                x = tri(x)
                # use the student currently being processed, not always student1
                logits = model(x)
                loss = student_lambda_t * F.cross_entropy(logits, y) + student_lambda_mask * torch.norm(tri.mask, p=2)
                model_optimizer.zero_grad()
                trigger_optimizer.zero_grad()
                loss.backward()
                trigger_optimizer.step()

                if student_trainable_when_training_trigger:
                    model_optimizer.step()

                # keep the mask in [0, 1] and the trigger in [-beta, beta]
                with torch.no_grad():
                    tri.mask.clamp_(0, 1)
                    tri.trigger.clamp_(-1*beta, 1*beta)
            masks.append(tri.mask.clone())
            triggers.append(tri.trigger.clone())

    # average the per-model masks and triggers collected in this epoch
    average_mask = torch.mean(torch.stack(masks), dim=0)
    average_trigger = torch.mean(torch.stack(triggers), dim=0)
    tri.mask.data = average_mask
    tri.trigger.data = average_trigger

    teacher_scheduler.step()
    student1_scheduler.step()
    student2_scheduler.step()
    student3_scheduler.step()

    # calculate the model accuracy to obtain best model
    accuracies = []

    for model in models:

        model.eval()
        model.to(device)
        with torch.no_grad():
            total = 0
            correct = 0
            for x, y in tqdm(clean_test_dataloader):
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                _, predict_label = logits.max(1)
                total += y.size(0)
                correct += predict_label.eq(y).sum().item()
            accuracy = correct / total
            accuracies.append(accuracy)

    best_model_index = np.argmax(accuracies)

    print("--------Validation accuracy of 4 models(clean_test_dataloader)---------")
    print(accuracies)
    print("--------Selected as the index for the teacher model---------")
    print(best_model_index)

    # attack success rate: fraction of triggered test samples classified as the target
    ASR = []

    for model in models:

        model.eval()
        model.to(device)
        with torch.no_grad():
            total = 0
            correct = 0
            for x, y in tqdm(poison_test_dataloader):
                x = x.to(device)
                x = tri(x)
                y = y.to(device)
                logits = model(x)
                _, predict_label = logits.max(1)
                total += y.size(0)
                correct += predict_label.eq(y).sum().item()
            asr = correct / total
            ASR.append(asr)

    print("--------The attack success rate of 4 models(poison_test_dataloader)---------")
    print(ASR)

    # Save the model on a regular basis
    if epoch == 0 or (epoch + 1) % save_interval == 0:
        trigger_p = 'trigger/epoch_{}.pth'.format(epoch)
        teacher_p = 'models/weight/teacher/epoch_{}.pth'.format(epoch)
        student1_p = 'models/weight/student1/epoch_{}.pth'.format(epoch)
        student2_p = 'models/weight/student2/epoch_{}.pth'.format(epoch)
        student3_p = 'models/weight/student3/epoch_{}.pth'.format(epoch)
        # torch.save does not create missing directories, so make sure they exist
        for p in (trigger_p, teacher_p, student1_p, student2_p, student3_p):
            os.makedirs(os.path.dirname(p), exist_ok=True)
        torch.save(tri.state_dict(), trigger_p)
        torch.save(teacher.state_dict(), teacher_p)
        torch.save(student1.state_dict(), student1_p)
        torch.save(student2.state_dict(), student2_p)
        torch.save(student3.state_dict(), student3_p)

-------------------------------------------------------------------------------- /models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | """MobileNetV2 in PyTorch. 2 | See the paper "Inverted Residuals and Linear Bottlenecks: 3 | Mobile Networks for Classification, Detection and Segmentation" 4 | for more details.
5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | """expand + depthwise + pointwise""" 13 | 14 | def __init__(self, in_planes, out_planes, expansion, stride): 15 | super(Block, self).__init__() 16 | self.stride = stride 17 | 18 | planes = expansion * in_planes 19 | self.conv1 = nn.Conv2d(in_planes, 20 | planes, 21 | kernel_size=1, 22 | stride=1, 23 | padding=0, 24 | bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.conv2 = nn.Conv2d( 27 | planes, 28 | planes, 29 | kernel_size=3, 30 | stride=stride, 31 | padding=1, 32 | groups=planes, 33 | bias=False, 34 | ) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.conv3 = nn.Conv2d(planes, 37 | out_planes, 38 | kernel_size=1, 39 | stride=1, 40 | padding=0, 41 | bias=False) 42 | self.bn3 = nn.BatchNorm2d(out_planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride == 1 and in_planes != out_planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, 48 | out_planes, 49 | kernel_size=1, 50 | stride=1, 51 | padding=0, 52 | bias=False), 53 | nn.BatchNorm2d(out_planes), 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out = out + self.shortcut(x) if self.stride == 1 else out 61 | return out 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | # (expansion, out_planes, num_blocks, stride) 66 | cfg = [ 67 | (1, 16, 1, 1), 68 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 69 | (6, 32, 3, 2), 70 | (6, 64, 4, 2), 71 | (6, 96, 3, 1), 72 | (6, 160, 3, 2), 73 | (6, 320, 1, 1), 74 | ] 75 | 76 | def __init__(self, num_classes: int = 10) -> None: 77 | super(MobileNetV2, self).__init__() 78 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 79 | self.conv1 = nn.Conv2d(3, 80 | 32, 81 | kernel_size=3, 82 | stride=1, 83 | padding=1, 84 | bias=False) 85 | self.bn1 = nn.BatchNorm2d(32) 86 | self.layers = 
self._make_layers(in_planes=32) 87 | self.conv2 = nn.Conv2d(320, 88 | 1280, 89 | kernel_size=1, 90 | stride=1, 91 | padding=0, 92 | bias=False) 93 | self.bn2 = nn.BatchNorm2d(1280) 94 | self.linear = nn.Linear(1280, num_classes) 95 | 96 | def _make_layers(self, in_planes): 97 | layers = [] 98 | for expansion, out_planes, num_blocks, stride in self.cfg: 99 | strides = [stride] + [1] * (num_blocks - 1) 100 | for stride in strides: 101 | layers.append(Block(in_planes, out_planes, expansion, stride)) 102 | in_planes = out_planes 103 | return nn.Sequential(*layers) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | out = F.relu(self.bn1(self.conv1(x))) 107 | out = self.layers(out) 108 | out = F.relu(self.bn2(self.conv2(out))) 109 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 110 | out = F.avg_pool2d(out, 4) 111 | out = out.view(out.size(0), -1) 112 | out = self.linear(out) 113 | return out 114 | 115 | 116 | def mobilenet_v2(num_classes: int) -> MobileNetV2: 117 | return MobileNetV2(num_classes=num_classes) 118 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet in PyTorch. 3 | 4 | Reference: 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 6 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 7 | """ 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from typing import List 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, 18 | planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=1, 22 | bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, 25 | planes, 26 | kernel_size=3, 27 | stride=1, 28 | padding=1, 29 | bias=False) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion * planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d( 36 | in_planes, 37 | self.expansion * planes, 38 | kernel_size=1, 39 | stride=stride, 40 | bias=False, 41 | ), 42 | nn.BatchNorm2d(self.expansion * planes), 43 | ) 44 | 45 | def forward(self, x): 46 | out = F.relu(self.bn1(self.conv1(x))) 47 | out = self.bn2(self.conv2(out)) 48 | out += self.shortcut(x) 49 | out = F.relu(out) 50 | return out 51 | 52 | 53 | class Bottleneck(nn.Module): 54 | expansion = 4 55 | 56 | def __init__(self, in_planes, planes, stride=1): 57 | super(Bottleneck, self).__init__() 58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, 61 | planes, 62 | kernel_size=3, 63 | stride=stride, 64 | padding=1, 65 | bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, 68 | self.expansion * planes, 69 | kernel_size=1, 70 | bias=False) 71 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 72 | 73 | self.shortcut = nn.Sequential() 74 | if stride != 1 or in_planes != self.expansion * planes: 75 | self.shortcut = nn.Sequential( 76 | nn.Conv2d( 77 | in_planes, 78 | self.expansion * planes, 79 | kernel_size=1, 80 | stride=stride, 81 | bias=False, 82 | ), 83 | nn.BatchNorm2d(self.expansion * planes), 
84 | ) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = F.relu(self.bn2(self.conv2(out))) 89 | out = self.bn3(self.conv3(out)) 90 | out += self.shortcut(x) 91 | out = F.relu(out) 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, num_blocks, num_classes=10): 98 | super(ResNet, self).__init__() 99 | self.in_planes = 64 100 | 101 | self.conv1 = nn.Conv2d(3, 102 | 64, 103 | kernel_size=3, 104 | stride=1, 105 | padding=1, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 109 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 110 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 111 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 112 | self.linear = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | def _make_layer(self, block, planes, num_blocks, stride): 115 | strides = [stride] + [1] * (num_blocks - 1) 116 | layers = [] 117 | for stride in strides: 118 | layers.append(block(self.in_planes, planes, stride)) 119 | self.in_planes = planes * block.expansion 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | out = F.relu(self.bn1(self.conv1(x))) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = self.layer4(out) 128 | out = F.avg_pool2d(out, 4) 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | def _resnet(block: nn.Module, num_blocks: List[int], 135 | num_classes: int) -> ResNet: 136 | return ResNet(block=block, num_blocks=num_blocks, num_classes=num_classes) 137 | 138 | 139 | def resnet18(num_classes: int) -> ResNet: 140 | return _resnet(block=BasicBlock, 141 | num_blocks=[2, 2, 2, 2], 142 | num_classes=num_classes) 143 | 144 | 145 | def resnet34(num_classes: int) -> ResNet: 146 | return _resnet(block=BasicBlock, 147 | 
num_blocks=[3, 4, 6, 3], 148 | num_classes=num_classes) 149 | 150 | 151 | def resnet50(num_classes: int) -> ResNet: 152 | return _resnet(block=Bottleneck, 153 | num_blocks=[3, 4, 6, 3], 154 | num_classes=num_classes) 155 | 156 | 157 | def resnet101(num_classes: int) -> ResNet: 158 | return _resnet(block=Bottleneck, 159 | num_blocks=[3, 4, 23, 3], 160 | num_classes=num_classes) 161 | 162 | 163 | def resnet152(num_classes: int) -> ResNet: 164 | return _resnet(block=Bottleneck, 165 | num_blocks=[3, 8, 36, 3], 166 | num_classes=num_classes) 167 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """VGG11/13/16/19 in Pytorch.""" 2 | import torch.nn as nn 3 | 4 | cfg = { 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': 7 | [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG16': [ 9 | 64, 10 | 64, 11 | 'M', 12 | 128, 13 | 128, 14 | 'M', 15 | 256, 16 | 256, 17 | 256, 18 | 'M', 19 | 512, 20 | 512, 21 | 512, 22 | 'M', 23 | 512, 24 | 512, 25 | 512, 26 | 'M', 27 | ], 28 | 'VGG19': [ 29 | 64, 30 | 64, 31 | 'M', 32 | 128, 33 | 128, 34 | 'M', 35 | 256, 36 | 256, 37 | 256, 38 | 256, 39 | 'M', 40 | 512, 41 | 512, 42 | 512, 43 | 512, 44 | 'M', 45 | 512, 46 | 512, 47 | 512, 48 | 512, 49 | 'M', 50 | ], 51 | } 52 | 53 | 54 | class VGG(nn.Module): 55 | 56 | def __init__(self, vgg_name, num_classes=10): 57 | super(VGG, self).__init__() 58 | self.features = self._make_layers(cfg[vgg_name.upper()]) 59 | self.classifier = nn.Linear(512, num_classes) 60 | 61 | def forward(self, x): 62 | out = self.features(x) 63 | out = out.view(out.size(0), -1) 64 | out = self.classifier(out) 65 | return out 66 | 67 | def _make_layers(self, cfg): 68 | layers = [] 69 | in_channels = 3 70 | for x in cfg: 71 | if x == 'M': 72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 73 | else: 74 | layers += [ 75 
|                     nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
76 |                     nn.BatchNorm2d(x),
77 |                     nn.ReLU(inplace=True),
78 |                 ]
79 |                 in_channels = x
80 |         layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
81 |         return nn.Sequential(*layers)
82 | 
83 | 
84 | def vgg11(num_classes: int) -> VGG:
85 |     return VGG('vgg11', num_classes=num_classes)
86 | 
87 | 
88 | def vgg13(num_classes: int) -> VGG:
89 |     return VGG('vgg13', num_classes=num_classes)
90 | 
91 | 
92 | def vgg16(num_classes: int) -> VGG:
93 |     return VGG('vgg16', num_classes=num_classes)
94 | 
95 | 
96 | def vgg19(num_classes: int) -> VGG:
97 |     return VGG('vgg19', num_classes=num_classes)
98 | 
--------------------------------------------------------------------------------
/packet.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import os
6 | from models.resnet import resnet34, resnet18
7 | from models.vgg import vgg16
8 | from models.mobilenet_v2 import mobilenet_v2
9 | import torch.nn as nn
10 | from utils import Trigger
11 | import torchvision
12 | from torchvision import transforms
13 | from poison_dataset import PoisonDataset
14 | from torch.nn import functional as F
15 | from pytorch_grad_cam import (
16 |     GradCAM,
17 |     HiResCAM,
18 |     ScoreCAM,
19 |     GradCAMPlusPlus,
20 |     AblationCAM,
21 |     XGradCAM,
22 |     EigenCAM,
23 |     FullGrad,
24 | )
25 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
26 | from pytorch_grad_cam.utils.image import show_cam_on_image
27 | from torch.utils.data import DataLoader
28 | import torch.optim as optim
29 | from tqdm import tqdm
30 | from torch.optim import lr_scheduler
31 | import torch
32 | 
--------------------------------------------------------------------------------
/poison_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | 
4 | 
5 | class PoisonDataset(Dataset):
6 |     def __init__(self, dataset, indices, target):
7 |         self.dataset = dataset
8 |         self.indices = [int(i) for i in indices]
9 |         self.target = target
10 | 
11 |     def __len__(self):
12 |         return len(self.indices)
13 | 
14 |     def __getitem__(self, item):
15 |         # The original label is discarded: every selected sample is
16 |         # returned with the fixed target class instead.
17 |         x, _ = self.dataset[self.indices[item]]
18 |         return x, self.target
19 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 | 
4 | import pytest
5 | from pytest import TempPathFactory
6 | 
7 | 
8 | @pytest.fixture(scope='function')
9 | def tmp_work_dirs(tmp_path_factory: TempPathFactory) -> Path:
10 |     relative_work_dirs = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
11 |     work_dirs = tmp_path_factory.mktemp(relative_work_dirs)
12 | 
13 |     return work_dirs
14 | 
--------------------------------------------------------------------------------
/tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py:
--------------------------------------------------------------------------------
1 | trainer = dict(type='AntiKDTrainer',
2 | 
teacher=dict(network=dict(arch='cifar', 3 | type='resnet18', 4 | num_classes=10), 5 | optimizer=dict(type='SGD', 6 | lr=0.05, 7 | momentum=0.9, 8 | weight_decay=5e-4), 9 | scheduler=dict(type='CosineAnnealingLR', 10 | T_max=100), 11 | lambda_t=0.1, 12 | lambda_mask=1e-4, 13 | trainable_when_training_trigger=False), 14 | students=dict( 15 | resnet18=dict(network=dict(arch='cifar', 16 | type='resnet18', 17 | num_classes=10), 18 | optimizer=dict(type='SGD', 19 | lr=0.05, 20 | momentum=0.9, 21 | weight_decay=5e-4), 22 | scheduler=dict(type='CosineAnnealingLR', 23 | T_max=100), 24 | lambda_t=1e-2, 25 | lambda_mask=1e-4, 26 | trainable_when_training_trigger=False), 27 | vgg16=dict(network=dict(arch='cifar', 28 | type='vgg16', 29 | num_classes=10), 30 | optimizer=dict(type='SGD', 31 | lr=0.05, 32 | momentum=0.9, 33 | weight_decay=5e-4), 34 | scheduler=dict(type='CosineAnnealingLR', 35 | T_max=100), 36 | lambda_t=1e-2, 37 | lambda_mask=1e-4, 38 | trainable_when_training_trigger=False), 39 | mobilenet_v2=dict(network=dict(arch='cifar', 40 | type='mobilenet_v2', 41 | num_classes=10), 42 | optimizer=dict(type='SGD', 43 | lr=0.05, 44 | momentum=0.9, 45 | weight_decay=5e-4), 46 | scheduler=dict(type='CosineAnnealingLR', 47 | T_max=100), 48 | lambda_t=1e-2, 49 | lambda_mask=1e-4, 50 | trainable_when_training_trigger=False), 51 | ), 52 | trigger=dict(network=dict(arch='trigger', 53 | type='trigger', 54 | size=32), 55 | optimizer=dict(type='Adam', lr=1e-2), 56 | mask_clip_range=(0., 1.), 57 | trigger_clip_range=(-1., 1.), 58 | mask_penalty_norm=2), 59 | clean_train_dataloader=dict(dataset=dict( 60 | type='CIFAR10', 61 | root='data', 62 | train=True, 63 | download=True, 64 | transform=[ 65 | dict(type='RandomCrop', size=32, padding=4), 66 | dict(type='RandomHorizontalFlip'), 67 | dict(type='ToTensor'), 68 | dict(type='Normalize', 69 | mean=(0.4914, 0.4822, 0.4465), 70 | std=(0.2023, 0.1994, 0.2010)) 71 | ]), 72 | batch_size=32, 73 | num_workers=4, 74 | pin_memory=True), 75 | 
clean_test_dataloader=dict(dataset=dict( 76 | type='CIFAR10', 77 | root='data', 78 | train=False, 79 | download=True, 80 | transform=[ 81 | dict(type='ToTensor'), 82 | dict(type='Normalize', 83 | mean=(0.4914, 0.4822, 0.4465), 84 | std=(0.2023, 0.1994, 0.2010)) 85 | ]), 86 | batch_size=32, 87 | num_workers=4, 88 | pin_memory=True), 89 | poison_train_dataloader=dict(dataset=dict( 90 | type='RatioPoisonLabelCIFAR10', 91 | ratio=0.1, 92 | poison_label=1, 93 | root='data', 94 | train=True, 95 | download=True, 96 | transform=[ 97 | dict(type='RandomCrop', size=32, padding=4), 98 | dict(type='RandomHorizontalFlip'), 99 | dict(type='ToTensor'), 100 | dict(type='Normalize', 101 | mean=(0.4914, 0.4822, 0.4465), 102 | std=(0.2023, 0.1994, 0.2010)) 103 | ]), 104 | batch_size=32, 105 | num_workers=4, 106 | pin_memory=True), 107 | poison_test_dataloader=dict(dataset=dict( 108 | type='RatioPoisonLabelCIFAR10', 109 | ratio=1, 110 | poison_label=1, 111 | root='data', 112 | train=False, 113 | download=True, 114 | transform=[ 115 | dict(type='ToTensor'), 116 | dict(type='Normalize', 117 | mean=(0.4914, 0.4822, 0.4465), 118 | std=(0.2023, 0.1994, 0.2010)) 119 | ]), 120 | batch_size=32, 121 | num_workers=4, 122 | pin_memory=True), 123 | epochs=100, 124 | save_interval=5, 125 | temperature=1.0, 126 | alpha=1.0, 127 | device='cuda') 128 | -------------------------------------------------------------------------------- /tests/data/config/cifar10_resnet18.py: -------------------------------------------------------------------------------- 1 | dataset = 'CIFAR10' 2 | dataset_mean = (0.4914, 0.4822, 0.4465) 3 | dataset_std = (0.2023, 0.1994, 0.2010) 4 | 5 | network = dict(type='resnet18', arch='cifar') 6 | 7 | test_dataloader = dict(batch_size=64, 8 | num_workers=4, 9 | persistent_workers=True, 10 | shuffle=False, 11 | dataset=dict(type=dataset, 12 | root='data', 13 | train=False, 14 | download=True, 15 | transform=[ 16 | dict(type='ToTensor'), 17 | dict(type='Normalize', 18 | 
mean=dataset_mean, 19 | std=dataset_std) 20 | ])) 21 | 22 | train_dataloader = dict(batch_size=64, 23 | num_workers=4, 24 | persistent_workers=True, 25 | shuffle=True, 26 | dataset=dict(type=dataset, 27 | root='data', 28 | train=True, 29 | download=True, 30 | transform=[ 31 | dict(type='RandomCrop', 32 | size=32, 33 | padding=4), 34 | dict(type='RandomHorizontalFlip', 35 | p=0.5), 36 | dict(type='ToTensor'), 37 | dict(type='Normalize', 38 | mean=dataset_mean, 39 | std=dataset_std) 40 | ])) 41 | -------------------------------------------------------------------------------- /tests/data/config/error.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/data/config/error.txt -------------------------------------------------------------------------------- /tests/data/config/simple_config.py: -------------------------------------------------------------------------------- 1 | item1 = [1, 2] 2 | item2 = {'a': 0} 3 | item3 = True 4 | item4 = 'test' -------------------------------------------------------------------------------- /tests/test_config/test_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from anti_kd_backdoor.config import Config, DictAction 7 | 8 | 9 | class TestConfig: 10 | """Modify from mmengine""" 11 | 12 | data_path: Path = Path(__file__).parent.parent / 'data' 13 | config_path: Path = data_path / 'config' 14 | 15 | def test_fromfile(self) -> None: 16 | config_path = self.config_path / 'cifar10_resnet18.py' 17 | config = Config.fromfile(config_path) 18 | 19 | assert config.dataset == 'CIFAR10' 20 | assert config.network.type == 'resnet18' 21 | assert config.test_dataloader.batch_size == 64 22 | assert len(config.train_dataloader.dataset.transform) == 4 23 | assert 
len(config.test_dataloader.dataset.transform) == 2 24 | 25 | with pytest.raises(ValueError): 26 | _ = Config.fromfile(self.config_path / 'error.txt') 27 | 28 | def test_magic_method(self) -> None: 29 | cfg_dict = dict(item1=[1, 2], 30 | item2=dict(a=0), 31 | item3=True, 32 | item4='test') 33 | cfg_file = self.config_path / 'simple_config.py' 34 | cfg = Config.fromfile(cfg_file) 35 | # len(cfg) 36 | assert len(cfg) == 4 37 | # cfg.keys() 38 | assert set(cfg.keys()) == set(cfg_dict.keys()) 39 | assert set(cfg._cfg_dict.keys()) == set(cfg_dict.keys()) 40 | # cfg.values() 41 | for value in cfg.values(): 42 | assert value in cfg_dict.values() 43 | # cfg.items() 44 | for name, value in cfg.items(): 45 | assert name in cfg_dict 46 | assert value in cfg_dict.values() 47 | # cfg.field 48 | assert cfg.item1 == cfg_dict['item1'] 49 | assert cfg.item2 == cfg_dict['item2'] 50 | assert cfg.item2.a == 0 51 | assert cfg.item3 == cfg_dict['item3'] 52 | assert cfg.item4 == cfg_dict['item4'] 53 | # accessing keys that do not exist will cause error 54 | with pytest.raises(AttributeError): 55 | cfg.not_exist 56 | # field in cfg, cfg[field], cfg.get() 57 | for name in ['item1', 'item2', 'item3', 'item4']: 58 | assert name in cfg 59 | assert cfg[name] == cfg_dict[name] 60 | assert cfg.get(name) == cfg_dict[name] 61 | assert cfg.get('not_exist') is None 62 | assert cfg.get('not_exist', 0) == 0 63 | # accessing keys that do not exist will cause error 64 | with pytest.raises(KeyError): 65 | cfg['not_exist'] 66 | assert 'item1' in cfg 67 | assert 'not_exist' not in cfg 68 | # cfg.update() 69 | cfg.update(dict(item1=0)) 70 | assert cfg.item1 == 0 71 | cfg.update(dict(item2=dict(a=1))) 72 | assert cfg.item2.a == 1 73 | # test __setattr__ 74 | cfg = Config() 75 | cfg.item1 = [1, 2] 76 | cfg.item2 = {'a': 0} 77 | cfg['item5'] = {'a': {'b': None}} 78 | assert cfg._cfg_dict['item1'] == [1, 2] 79 | assert cfg.item1 == [1, 2] 80 | assert cfg._cfg_dict['item2'] == {'a': 0} 81 | assert cfg.item2.a == 
0 82 | assert cfg._cfg_dict['item5'] == {'a': {'b': None}} 83 | assert cfg.item5.a.b is None 84 | 85 | def test_dict_action(self): 86 | parser = argparse.ArgumentParser(description='Train a detector') 87 | parser.add_argument('--options', 88 | nargs='+', 89 | action=DictAction, 90 | help='custom options') 91 | # Nested brackets 92 | args = parser.parse_args( 93 | ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]']) 94 | out_dict = { 95 | 'item2.a': ['a', 'b'], 96 | 'item2.b': [('a', 'b'), [1, 2], False] 97 | } 98 | assert args.options == out_dict 99 | # Single Nested brackets 100 | args = parser.parse_args(['--options', 'item2.a=[[1]]']) 101 | out_dict = {'item2.a': [[1]]} 102 | assert args.options == out_dict 103 | # Imbalance bracket will cause error 104 | with pytest.raises(AssertionError): 105 | parser.parse_args(['--options', 'item2.a=[(a,b), [1,2], false']) 106 | # Normal values 107 | args = parser.parse_args([ 108 | '--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false' 109 | ]) 110 | out_dict = { 111 | 'item2.a': 1, 112 | 'item2.b': 0.1, 113 | 'item2.c': 'x', 114 | 'item3': False 115 | } 116 | assert args.options == out_dict 117 | -------------------------------------------------------------------------------- /tests/test_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_data/__init__.py -------------------------------------------------------------------------------- /tests/test_data/test_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_data/test_dataset/__init__.py -------------------------------------------------------------------------------- /tests/test_data/test_dataset/test_cifar.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from .utils import FakeDataset, build_fake_dataset 5 | 6 | CIFAR_TESTSET_NUM = 1000 7 | 8 | 9 | def build_cifar_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset: 10 | if dataset_type.endswith('CIFAR100'): 11 | y_range = (0, 99) 12 | dataset_type = dataset_type.replace('CIFAR100', 'FakeDataset') 13 | else: 14 | y_range = (0, 9) 15 | dataset_type = dataset_type.replace('CIFAR10', 'FakeDataset') 16 | 17 | dataset_cfg = dict(type=dataset_type, 18 | x_shape=(3, 32, 32), 19 | y_range=y_range, 20 | nums=CIFAR_TESTSET_NUM, 21 | **kwargs) 22 | 23 | return build_fake_dataset(dataset_cfg) 24 | 25 | 26 | @pytest.mark.parametrize('dataset_type', ['CIFAR10', 'CIFAR100']) 27 | def test_xy(dataset_type: str) -> None: 28 | cifar = build_cifar_fake_dataset(dataset_type) 29 | 30 | xy = cifar.get_xy() 31 | x, y = xy 32 | assert len(x) == len(y) 33 | assert isinstance(y[0], int) 34 | 35 | old_x = x.copy() 36 | old_y = y.copy() 37 | 38 | cifar.set_xy(xy) 39 | assert all([np.array_equal(nx, ox) for nx, ox in zip(cifar.data, old_x)]) 40 | assert cifar.targets == old_y 41 | 42 | x = x[:cifar.num_classes] 43 | y = y[:cifar.num_classes] 44 | cifar.set_xy((x, y)) 45 | assert all([np.array_equal(nx, ox) for nx, ox in zip(cifar.data, x)]) 46 | assert cifar.targets == y 47 | assert cifar.num_classes == len(set(y)) 48 | assert len(cifar.data.shape) == 4 49 | 50 | 51 | @pytest.mark.parametrize(['start_idx', 'end_idx', 'dataset_type'], 52 | [(0, 9, 'IndexCIFAR10'), (-10, 8, 'IndexCIFAR10'), 53 | (2, 12, 'IndexCIFAR10'), (4, 4, 'IndexCIFAR10'), 54 | (4, 3, 'IndexCIFAR10'), (0, 99, 'IndexCIFAR100'), 55 | (-10, 8, 'IndexCIFAR100'), 56 | (40, 50, 'IndexCIFAR100')]) 57 | def test_index(start_idx: int, end_idx: int, dataset_type: str) -> None: 58 | kwargs = dict(start_idx=start_idx, end_idx=end_idx) 59 | 60 | if start_idx > end_idx: 61 | with pytest.raises(ValueError): 
62 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 63 | return 64 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 65 | assert cifar.start_idx == start_idx 66 | assert cifar.end_idx == end_idx 67 | 68 | for y in cifar.targets: 69 | assert start_idx <= y <= end_idx 70 | 71 | assert cifar.num_classes == min( 72 | cifar.end_idx, cifar.raw_num_classes - 1) - max(cifar.start_idx, 0) + 1 73 | assert len(cifar.data.shape) == 4 74 | 75 | 76 | @pytest.mark.parametrize(['ratio', 'dataset_type'], [(-1, 'RatioCIFAR10'), 77 | (0, 'RatioCIFAR10'), 78 | (0.1, 'RatioCIFAR10'), 79 | (0.5, 'RatioCIFAR10'), 80 | (1, 'RatioCIFAR10'), 81 | (2, 'RatioCIFAR10'), 82 | (0.4, 'RatioCIFAR100')]) 83 | def test_ratio(ratio: float, dataset_type: str) -> None: 84 | kwargs = dict(ratio=ratio) 85 | 86 | if ratio <= 0 or ratio > 1: 87 | with pytest.raises(ValueError): 88 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 89 | return 90 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 91 | 92 | assert cifar.num_classes == cifar.raw_num_classes 93 | assert len(cifar.targets) == \ 94 | int(CIFAR_TESTSET_NUM / cifar.num_classes * ratio) * cifar.num_classes 95 | assert len(cifar.data.shape) == 4 96 | 97 | 98 | @pytest.mark.parametrize('range_ratio', [(-1, 0.2), (0, 2), (0.1, 0.1), 99 | (0.5, 0.2), (0.1, 0.5), (0, 1)]) 100 | @pytest.mark.parametrize('dataset_type', 101 | ['RangeRatioCIFAR10', 'RangeRatioCIFAR100']) 102 | def test_range_ratio(range_ratio: tuple[float, float], 103 | dataset_type: str) -> None: 104 | kwargs = dict(range_ratio=range_ratio) 105 | 106 | start_ratio = range_ratio[0] 107 | end_ratio = range_ratio[1] 108 | if not (0 <= start_ratio < end_ratio <= 1): 109 | with pytest.raises(ValueError): 110 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 111 | return 112 | 113 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 114 | assert cifar.num_classes == cifar.raw_num_classes 115 | assert len(cifar.targets) == \ 116 | round(CIFAR_TESTSET_NUM * 
(end_ratio - start_ratio)) 117 | assert len(cifar.data.shape) == 4 118 | 119 | 120 | @pytest.mark.parametrize(['range_ratio1', 'range_ratio2'], 121 | [((0, 0.5), (0.5, 1)), ((0, 0.6), (0.4, 1)), 122 | ((0, 0.7), (0.3, 1)), ((0, 0.5), (0, 1))]) 123 | @pytest.mark.parametrize('dataset_type', 124 | ['RangeRatioCIFAR10', 'RangeRatioCIFAR100']) 125 | def test_range_ratio_intersection(range_ratio1: tuple[float, float], 126 | range_ratio2: tuple[float, float], 127 | dataset_type: str) -> None: 128 | 129 | cifar1 = build_cifar_fake_dataset(dataset_type=dataset_type, 130 | range_ratio=range_ratio1, 131 | cache_xy=True) 132 | cifar2 = build_cifar_fake_dataset(dataset_type=dataset_type, 133 | range_ratio=range_ratio2, 134 | cache_xy=True) 135 | 136 | cat_x = np.concatenate([cifar1.data, cifar2.data], axis=0) 137 | unique_x = np.unique(cat_x, axis=0) 138 | intersection_number = cat_x.shape[0] - unique_x.shape[0] 139 | 140 | intersection_ratio = max(0, range_ratio1[1] - range_ratio2[0]) 141 | assert round(intersection_ratio * CIFAR_TESTSET_NUM) == intersection_number 142 | 143 | 144 | @pytest.mark.parametrize('dataset_type', 145 | ['IndexRatioCIFAR10', 'IndexRatioCIFAR100']) 146 | @pytest.mark.parametrize(['start_idx', 'end_idx', 'ratio'], [(4, 3, 0.5), 147 | (3, 4, 0), 148 | (3, 4, 2), 149 | (1, 4, 0.1)]) 150 | def test_index_ratio(start_idx: int, end_idx: int, ratio: float, 151 | dataset_type: str) -> None: 152 | kwargs = dict(start_idx=start_idx, end_idx=end_idx, ratio=ratio) 153 | 154 | if ratio <= 0 or ratio > 1 or start_idx > end_idx: 155 | with pytest.raises(ValueError): 156 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 157 | return 158 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 159 | assert cifar.start_idx == start_idx 160 | assert cifar.end_idx == end_idx 161 | 162 | for y in cifar.targets: 163 | assert start_idx <= y <= end_idx 164 | assert len(cifar.targets) == \ 165 | cifar.num_classes / cifar.raw_num_classes * CIFAR_TESTSET_NUM * ratio 166 
| assert len(cifar.data.shape) == 4 167 | 168 | 169 | @pytest.mark.parametrize('dataset_type', 170 | ['PoisonLabelCIFAR10', 'PoisonLabelCIFAR100']) 171 | @pytest.mark.parametrize('poison_label', [-1, 5, 101]) 172 | def test_poison_label(poison_label: int, dataset_type: str) -> None: 173 | kwargs = dict(poison_label=poison_label) 174 | 175 | if poison_label < 0 or poison_label >= 100: 176 | with pytest.raises(ValueError): 177 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 178 | return 179 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 180 | assert cifar.poison_label == poison_label 181 | 182 | assert cifar.num_classes == 1 183 | assert all(map(lambda x: x == poison_label, cifar.targets)) 184 | assert len(cifar.data.shape) == 4 185 | 186 | 187 | @pytest.mark.parametrize( 188 | 'dataset_type', ['RatioPoisonLabelCIFAR10', 'RatioPoisonLabelCIFAR100']) 189 | @pytest.mark.parametrize('poison_label', [-1, 5, 101]) 190 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2]) 191 | def test_ratio_poison_label(ratio: float, poison_label: int, 192 | dataset_type: str) -> None: 193 | kwargs = dict(ratio=ratio, poison_label=poison_label) 194 | 195 | if (poison_label < 0 or poison_label >= 100) or \ 196 | (ratio <= 0 or ratio > 1): 197 | with pytest.raises(ValueError): 198 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 199 | return 200 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 201 | assert cifar.poison_label == poison_label 202 | 203 | assert len(cifar) == round(CIFAR_TESTSET_NUM * ratio) 204 | assert cifar.num_classes == 1 205 | assert all(map(lambda x: x == poison_label, cifar.targets)) 206 | assert len(cifar.data.shape) == 4 207 | 208 | 209 | @pytest.mark.parametrize( 210 | 'dataset_type', 211 | ['RangeRatioPoisonLabelCIFAR10', 'RangeRatioPoisonLabelCIFAR100']) 212 | @pytest.mark.parametrize('poison_label', [-1, 1, 101]) 213 | @pytest.mark.parametrize('range_ratio', [(0, 0.2), (0.2, 0.5), (0.5, 1), 214 | (0.5, 0.2)]) 215 | def 
test_range_ratio_poison_label(range_ratio: tuple[float, 216 | float], poison_label: int, 217 | dataset_type: str) -> None: 218 | kwargs = dict(range_ratio=range_ratio, poison_label=poison_label) 219 | 220 | if poison_label < 0 or poison_label >= 100 or \ 221 | not (0 <= range_ratio[0] < range_ratio[1] <= 1): 222 | with pytest.raises(ValueError): 223 | _ = build_cifar_fake_dataset(dataset_type, **kwargs) 224 | return 225 | cifar = build_cifar_fake_dataset(dataset_type, **kwargs) 226 | assert cifar.poison_label == poison_label 227 | 228 | assert len(cifar) == round(CIFAR_TESTSET_NUM * 229 | (range_ratio[1] - range_ratio[0])) 230 | assert cifar.num_classes == 1 231 | assert all(map(lambda x: x == poison_label, cifar.targets)) 232 | assert len(cifar.data.shape) == 4 233 | -------------------------------------------------------------------------------- /tests/test_data/test_dataset/test_flowers102.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from .utils import FakeDataset, build_fake_dataset 5 | 6 | FLOWERS102_TESTSET_NUM = 5 * 102 7 | 8 | 9 | def build_flowers102_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset: 10 | dataset_cfg = dict(type=dataset_type.replace('Flowers102', 'FakeDataset'), 11 | x_shape=(3, 32, 32), 12 | y_range=(0, 101), 13 | nums=FLOWERS102_TESTSET_NUM, 14 | **kwargs) 15 | 16 | return build_fake_dataset(dataset_cfg) 17 | 18 | 19 | @pytest.mark.parametrize('dataset_type', ['Flowers102']) 20 | def test_xy(dataset_type: str) -> None: 21 | flowers102 = build_flowers102_fake_dataset(dataset_type) 22 | 23 | xy = flowers102.get_xy() 24 | x, y = xy 25 | assert len(x) == len(y) 26 | assert isinstance(y[0], int) 27 | 28 | old_x = x.copy() 29 | old_y = y.copy() 30 | 31 | flowers102.set_xy(xy) 32 | assert all( 33 | [np.array_equal(nx, ox) for nx, ox in zip(flowers102.data, old_x)]) 34 | assert flowers102.targets == old_y 35 | 36 | x = x[:flowers102.num_classes] 37 | 
y = y[:flowers102.num_classes]
38 |     flowers102.set_xy((x, y))
39 |     assert all([np.array_equal(nx, ox) for nx, ox in zip(flowers102.data, x)])
40 |     assert flowers102.targets == y
41 |     assert flowers102.num_classes == len(set(y))
42 |     assert len(flowers102.data.shape) == 4
43 | 
44 | 
45 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelFlowers102'])
46 | @pytest.mark.parametrize('poison_label', [-1, 5, 102])
47 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
48 |     kwargs = dict(poison_label=poison_label)
49 | 
50 |     if poison_label < 0 or poison_label >= 102:
51 |         with pytest.raises(ValueError):
52 |             _ = build_flowers102_fake_dataset(dataset_type, **kwargs)
53 |         return
54 |     flowers102 = build_flowers102_fake_dataset(dataset_type, **kwargs)
55 |     assert flowers102.poison_label == poison_label
56 | 
57 |     assert flowers102.num_classes == 1
58 |     assert all(map(lambda x: x == poison_label, flowers102.targets))
59 |     assert len(flowers102.data.shape) == 4
60 | 
61 | 
62 | @pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelFlowers102'])
63 | @pytest.mark.parametrize('poison_label', [-1, 5, 102])
64 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
65 | def test_ratio_poison_label(ratio: float, poison_label: int,
66 |                             dataset_type: str) -> None:
67 |     kwargs = dict(ratio=ratio, poison_label=poison_label)
68 | 
69 |     if (poison_label < 0 or poison_label >= 102) or \
70 |             (ratio <= 0 or ratio > 1):
71 |         with pytest.raises(ValueError):
72 |             _ = build_flowers102_fake_dataset(dataset_type, **kwargs)
73 |         return
74 |     flowers102 = build_flowers102_fake_dataset(dataset_type, **kwargs)
75 |     assert flowers102.poison_label == poison_label
76 | 
77 |     assert len(flowers102) == round(FLOWERS102_TESTSET_NUM * ratio)
78 |     assert flowers102.num_classes == 1
79 |     assert all(map(lambda x: x == poison_label, flowers102.targets))
80 |     assert len(flowers102.data.shape) == 4
81 | 
--------------------------------------------------------------------------------
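The Flowers102 tests above pin down a small contract for the ratio/poison-label datasets: a `ValueError` when `poison_label` lies outside the class range or `ratio` lies outside (0, 1], exactly `round(N * ratio)` samples kept, and every kept target equal to `poison_label`. The following is a minimal sketch of logic that would satisfy that contract, assuming plain Python sequences rather than the repository's dataset classes. The helper name `ratio_poison_label`, its parameters, and the seeded random subsampling are all hypothetical; the actual `anti_kd_backdoor` implementation is not part of this listing and may select samples differently.

```python
import random
from typing import List, Sequence, Tuple


def ratio_poison_label(
    xs: Sequence,
    ys: Sequence[int],
    *,
    ratio: float,
    poison_label: int,
    raw_num_classes: int,
    seed: int = 0,
) -> Tuple[List, List[int]]:
    """Hypothetical helper: keep a `ratio` share of samples and relabel them.

    This is an illustrative sketch, not the repository's implementation.
    """
    # Same validity ranges that the tests expect to raise ValueError.
    if not 0 < ratio <= 1:
        raise ValueError(f'ratio must be in (0, 1], got {ratio}')
    if not 0 <= poison_label < raw_num_classes:
        raise ValueError(
            f'poison_label must be in [0, {raw_num_classes}), '
            f'got {poison_label}')
    if len(xs) != len(ys):
        raise ValueError('xs and ys must have the same length')

    keep = round(len(xs) * ratio)
    # Assumption: subsample without replacement using a fixed seed so the
    # selection is reproducible; sorting keeps the original sample order.
    indices = sorted(random.Random(seed).sample(range(len(xs)), keep))

    # Original labels are dropped; every kept sample gets the poison label.
    return [xs[i] for i in indices], [poison_label] * keep
```

With 510 samples and `ratio=0.2`, this keeps 102 samples, all labeled with the chosen class, which mirrors the length and target assertions in `test_ratio_poison_label`.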
/tests/test_data/test_dataset/test_gtsrb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from .utils import FakeDataset, build_fake_dataset 5 | 6 | GTSRB_TESTSET_NUM = 43 * 50 7 | 8 | 9 | def build_gtsrb_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset: 10 | dataset_cfg = dict(type=dataset_type.replace('GTSRB', 'FakeDataset'), 11 | x_shape=(3, 32, 32), 12 | y_range=(0, 42), 13 | nums=GTSRB_TESTSET_NUM, 14 | **kwargs) 15 | 16 | return build_fake_dataset(dataset_cfg) 17 | 18 | 19 | @pytest.mark.parametrize('dataset_type', ['GTSRB']) 20 | def test_xy(dataset_type: str) -> None: 21 | gtsrb = build_gtsrb_fake_dataset(dataset_type) 22 | 23 | xy = gtsrb.get_xy() 24 | x, y = xy 25 | assert len(x) == len(y) 26 | assert isinstance(y[0], int) 27 | 28 | old_x = x.copy() 29 | old_y = y.copy() 30 | 31 | gtsrb.set_xy(xy) 32 | assert all([np.array_equal(nx, ox) for nx, ox in zip(gtsrb.data, old_x)]) 33 | assert gtsrb.targets == old_y 34 | 35 | x = x[:gtsrb.num_classes] 36 | y = y[:gtsrb.num_classes] 37 | gtsrb.set_xy((x, y)) 38 | assert all([np.array_equal(nx, ox) for nx, ox in zip(gtsrb.data, x)]) 39 | assert gtsrb.targets == y 40 | assert gtsrb.num_classes == len(set(y)) 41 | assert len(gtsrb.data.shape) == 4 42 | 43 | 44 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelGTSRB']) 45 | @pytest.mark.parametrize('poison_label', [-1, 5, 43]) 46 | def test_poison_label(poison_label: int, dataset_type: str) -> None: 47 | kwargs = dict(poison_label=poison_label) 48 | 49 | if poison_label < 0 or poison_label >= 43: 50 | with pytest.raises(ValueError): 51 | _ = build_gtsrb_fake_dataset(dataset_type, **kwargs) 52 | return 53 | gtsrb = build_gtsrb_fake_dataset(dataset_type, **kwargs) 54 | assert gtsrb.poison_label == poison_label 55 | 56 | assert gtsrb.num_classes == 1 57 | assert all(map(lambda x: x == poison_label, gtsrb.targets)) 58 | assert len(gtsrb.data.shape) == 4 59 | 60 | 61 | 
@pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelGTSRB']) 62 | @pytest.mark.parametrize('poison_label', [-1, 5, 43]) 63 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2]) 64 | def test_ratio_poison_label(ratio: float, poison_label: int, 65 | dataset_type: str) -> None: 66 | kwargs = dict(ratio=ratio, poison_label=poison_label) 67 | 68 | if (poison_label < 0 or poison_label >= 43) or \ 69 | (ratio <= 0 or ratio > 1): 70 | with pytest.raises(ValueError): 71 | _ = build_gtsrb_fake_dataset(dataset_type, **kwargs) 72 | return 73 | gtsrb = build_gtsrb_fake_dataset(dataset_type, **kwargs) 74 | assert gtsrb.poison_label == poison_label 75 | 76 | assert len(gtsrb) == round(GTSRB_TESTSET_NUM * ratio) 77 | assert gtsrb.num_classes == 1 78 | assert all(map(lambda x: x == poison_label, gtsrb.targets)) 79 | assert len(gtsrb.data.shape) == 4 80 | -------------------------------------------------------------------------------- /tests/test_data/test_dataset/test_svhn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from .utils import FakeDataset, build_fake_dataset 5 | 6 | SVHN_TESTSET_NUM = 10 * 100 7 | 8 | 9 | def build_svhn_fake_dataset(dataset_type: str, **kwargs) -> FakeDataset: 10 | dataset_cfg = dict(type=dataset_type.replace('SVHN', 'FakeDataset'), 11 | x_shape=(3, 32, 32), 12 | y_range=(0, 9), 13 | nums=SVHN_TESTSET_NUM, 14 | **kwargs) 15 | 16 | return build_fake_dataset(dataset_cfg) 17 | 18 | 19 | @pytest.mark.parametrize('dataset_type', ['SVHN']) 20 | def test_xy(dataset_type: str) -> None: 21 | svhn = build_svhn_fake_dataset(dataset_type) 22 | 23 | xy = svhn.get_xy() 24 | x, y = xy 25 | assert len(x) == len(y) 26 | assert isinstance(y[0], int) 27 | 28 | old_x = x.copy() 29 | old_y = y.copy() 30 | 31 | svhn.set_xy(xy) 32 | assert all([np.array_equal(nx, ox) for nx, ox in zip(svhn.data, old_x)]) 33 | assert svhn.targets == old_y 34 | 35 | x = x[:svhn.num_classes] 36 | y = 
y[:svhn.num_classes]
37 |     svhn.set_xy((x, y))
38 |     assert all([np.array_equal(nx, ox) for nx, ox in zip(svhn.data, x)])
39 |     assert svhn.targets == y
40 |     assert svhn.num_classes == len(set(y))
41 |     assert len(svhn.data.shape) == 4
42 | 
43 | 
44 | @pytest.mark.parametrize('dataset_type', ['PoisonLabelSVHN'])
45 | @pytest.mark.parametrize('poison_label', [-1, 5, 10])
46 | def test_poison_label(poison_label: int, dataset_type: str) -> None:
47 |     kwargs = dict(poison_label=poison_label)
48 | 
49 |     if poison_label < 0 or poison_label >= 10:
50 |         with pytest.raises(ValueError):
51 |             _ = build_svhn_fake_dataset(dataset_type, **kwargs)
52 |         return
53 |     svhn = build_svhn_fake_dataset(dataset_type, **kwargs)
54 |     assert svhn.poison_label == poison_label
55 | 
56 |     assert svhn.num_classes == 1
57 |     assert all(map(lambda x: x == poison_label, svhn.targets))
58 |     assert len(svhn.data.shape) == 4
59 | 
60 | 
61 | @pytest.mark.parametrize('dataset_type', ['RatioPoisonLabelSVHN'])
62 | @pytest.mark.parametrize('poison_label', [-1, 5, 10])
63 | @pytest.mark.parametrize('ratio', [0, 0.2, 1, 1.2])
64 | def test_ratio_poison_label(ratio: float, poison_label: int,
65 |                             dataset_type: str) -> None:
66 |     kwargs = dict(ratio=ratio, poison_label=poison_label)
67 | 
68 |     if (poison_label < 0 or poison_label >= 10) or \
69 |             (ratio <= 0 or ratio > 1):
70 |         with pytest.raises(ValueError):
71 |             _ = build_svhn_fake_dataset(dataset_type, **kwargs)
72 |         return
73 |     svhn = build_svhn_fake_dataset(dataset_type, **kwargs)
74 |     assert svhn.poison_label == poison_label
75 | 
76 |     assert len(svhn) == round(SVHN_TESTSET_NUM * ratio)
77 |     assert svhn.num_classes == 1
78 |     assert all(map(lambda x: x == poison_label, svhn.targets))
79 |     assert len(svhn.data.shape) == 4
80 | 
--------------------------------------------------------------------------------
/tests/test_data/test_dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 
| from torch.utils.data import Dataset 4 | 5 | from anti_kd_backdoor.data.dataset.base import ( 6 | DatasetInterface, 7 | IndexDataset, 8 | IndexRatioDataset, 9 | PoisonLabelDataset, 10 | RangeRatioDataset, 11 | RangeRatioPoisonLabelDataset, 12 | RatioDataset, 13 | RatioPoisonLabelDataset, 14 | ) 15 | from anti_kd_backdoor.data.dataset.types import XY_TYPE 16 | 17 | 18 | class FakeDataset(DatasetInterface, Dataset): 19 | cache: dict[tuple, tuple[torch.Tensor, list[int]]] = dict() 20 | 21 | def __init__(self, 22 | *, 23 | x_shape: tuple[int, int, int] = (3, 32, 32), 24 | y_range: tuple[int, int] = (0, 9), 25 | nums: int = 10000, 26 | cache_xy: bool = False) -> None: 27 | self._nums = nums 28 | self._x_shape = x_shape 29 | self._y_range = y_range 30 | self._raw_num_classes = y_range[1] - y_range[0] + 1 31 | 32 | if cache_xy: 33 | cache_key = (x_shape, y_range, nums) 34 | if cache_key not in FakeDataset.cache: 35 | x, y = self._prepare_xy() 36 | FakeDataset.cache[cache_key] = (x, y) 37 | x, y = FakeDataset.cache[cache_key] 38 | else: 39 | x, y = self._prepare_xy() 40 | 41 | self.data, self.targets = x, y 42 | 43 | def __len__(self) -> int: 44 | return len(self.targets) 45 | 46 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: 47 | x = self.data[index] 48 | y = self.targets[index] 49 | 50 | return x, y 51 | 52 | def get_xy(self) -> XY_TYPE: 53 | return list(self.data), self.targets.copy() 54 | 55 | def set_xy(self, xy: XY_TYPE) -> None: 56 | x, y = xy 57 | assert len(x) == len(y) 58 | 59 | self.data = np.stack(x, axis=0) 60 | self.targets = y.copy() 61 | 62 | @property 63 | def num_classes(self) -> int: 64 | return len(set(self.targets)) 65 | 66 | @property 67 | def raw_num_classes(self) -> int: 68 | return self._raw_num_classes 69 | 70 | def _prepare_xy(self) -> tuple[torch.Tensor, list[int]]: 71 | x = torch.rand((self._nums, *self._x_shape)) 72 | num_per_class = self._nums // self._raw_num_classes 73 | y = [ 74 | i for _ in 
range(num_per_class) 75 | for i in range(self._raw_num_classes) 76 | ] 77 | y.extend([self._y_range[0]] * (self._nums - len(y))) 78 | 79 | return x, y 80 | 81 | 82 | class IndexFakeDataset(FakeDataset, IndexDataset): 83 | 84 | def __init__(self, *, start_idx: int, end_idx: int, **kwargs) -> None: 85 | FakeDataset.__init__(self, **kwargs) 86 | IndexDataset.__init__(self, start_idx=start_idx, end_idx=end_idx) 87 | 88 | 89 | class RatioFakeDataset(FakeDataset, RatioDataset): 90 | 91 | def __init__(self, *, ratio: float, **kwargs) -> None: 92 | FakeDataset.__init__(self, **kwargs) 93 | RatioDataset.__init__(self, ratio=ratio) 94 | 95 | 96 | class RangeRatioFakeDataset(FakeDataset, RangeRatioDataset): 97 | 98 | def __init__(self, *, range_ratio: tuple[float, float], **kwargs) -> None: 99 | FakeDataset.__init__(self, **kwargs) 100 | RangeRatioDataset.__init__(self, range_ratio=range_ratio) 101 | 102 | 103 | class IndexRatioFakeDataset(FakeDataset, IndexRatioDataset): 104 | 105 | def __init__(self, *, start_idx: int, end_idx: int, ratio: float, 106 | **kwargs) -> None: 107 | FakeDataset.__init__(self, **kwargs) 108 | IndexRatioDataset.__init__(self, 109 | start_idx=start_idx, 110 | end_idx=end_idx, 111 | ratio=ratio) 112 | 113 | 114 | class PoisonLabelFakeDataset(FakeDataset, PoisonLabelDataset): 115 | 116 | def __init__(self, *, poison_label: int, **kwargs) -> None: 117 | FakeDataset.__init__(self, **kwargs) 118 | PoisonLabelDataset.__init__(self, poison_label=poison_label) 119 | 120 | 121 | class RatioPoisonLabelFakeDataset(FakeDataset, RatioPoisonLabelDataset): 122 | 123 | def __init__(self, *, ratio: float, poison_label: int, **kwargs) -> None: 124 | FakeDataset.__init__(self, **kwargs) 125 | RatioPoisonLabelDataset.__init__(self, 126 | ratio=ratio, 127 | poison_label=poison_label) 128 | 129 | 130 | class RangeRatioPoisonLabelFakeDataset(FakeDataset, 131 | RangeRatioPoisonLabelDataset): 132 | 133 | def __init__(self, *, range_ratio: tuple[float, float], poison_label: 
int, 134 | **kwargs) -> None: 135 | FakeDataset.__init__(self, **kwargs) 136 | RangeRatioPoisonLabelDataset.__init__(self, 137 | range_ratio=range_ratio, 138 | poison_label=poison_label) 139 | 140 | 141 | FAKE_DATASETS_MAPPING = { 142 | 'FakeDataset': FakeDataset, 143 | 'IndexFakeDataset': IndexFakeDataset, 144 | 'RatioFakeDataset': RatioFakeDataset, 145 | 'RangeRatioFakeDataset': RangeRatioFakeDataset, 146 | 'IndexRatioFakeDataset': IndexRatioFakeDataset, 147 | 'PoisonLabelFakeDataset': PoisonLabelFakeDataset, 148 | 'RatioPoisonLabelFakeDataset': RatioPoisonLabelFakeDataset, 149 | 'RangeRatioPoisonLabelFakeDataset': RangeRatioPoisonLabelFakeDataset 150 | } 151 | 152 | 153 | def build_fake_dataset(dataset_cfg: dict) -> FakeDataset: 154 | if 'type' not in dataset_cfg: 155 | raise ValueError('Dataset config must have a `type` field') 156 | dataset_type = dataset_cfg.pop('type') 157 | if dataset_type not in FAKE_DATASETS_MAPPING: 158 | raise ValueError( 159 | f'Dataset `{dataset_type}` is not supported, ' 160 | f'available datasets: {list(FAKE_DATASETS_MAPPING.keys())}') 161 | dataset = FAKE_DATASETS_MAPPING[dataset_type] 162 | 163 | return dataset(**dataset_cfg) 164 | -------------------------------------------------------------------------------- /tests/test_network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_network/__init__.py -------------------------------------------------------------------------------- /tests/test_network/test_cifar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_network/test_cifar/__init__.py -------------------------------------------------------------------------------- /tests/test_network/test_cifar/test_network.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from anti_kd_backdoor.network import build_network 5 | 6 | _AVAILABLE_CIFAR_NETWORKS = [ 7 | 'mobilenet_v2', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'mobilenetv2_x0_5', 9 | 'mobilenetv2_x0_75', 'mobilenetv2_x1_0', 'mobilenetv2_x1_4', 'repvgg_a0', 10 | 'repvgg_a1', 'repvgg_a2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 11 | 'shufflenetv2_x1_5', 'shufflenetv2_x2_0' 12 | ] 13 | 14 | 15 | def _make_network_cfg(num_classes: int, network_type: str) -> dict: 16 | return dict(type=network_type, arch='cifar', num_classes=num_classes) 17 | 18 | 19 | @torch.no_grad() 20 | @pytest.mark.parametrize('num_classes', [10, 100]) 21 | @pytest.mark.parametrize('network_type', _AVAILABLE_CIFAR_NETWORKS) 22 | def test_mobilenet_v2(network_type: str, num_classes: int) -> None: 23 | model = build_network(_make_network_cfg(num_classes, network_type)) 24 | 25 | x = torch.rand(2, 3, 32, 32) 26 | logit = model(x) 27 | 28 | assert list(logit.shape) == [2, num_classes] 29 | -------------------------------------------------------------------------------- /tests/test_network/test_trigger.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from anti_kd_backdoor.network.trigger import Trigger 5 | 6 | 7 | @torch.no_grad() 8 | @pytest.mark.parametrize('size', [32, 224]) 9 | def test_trigger_init(size: int) -> None: 10 | trigger = Trigger(size) 11 | assert trigger.size == size 12 | assert list(trigger.mask.shape) == [size, size] 13 | assert list(trigger.trigger.shape) == [3, size, size] 14 | 15 | 16 | @torch.no_grad() 17 | @pytest.mark.parametrize('size', [32, 224]) 18 | def test_trigger_forward(size: int) -> None: 19 | trigger = Trigger(size) 20 | 21 | x = torch.rand(10, 3, size, size) 22 | xp = trigger(x) 23 | assert xp.shape == x.shape 24 | 
25 | # test effect of mask 26 | trigger.mask.fill_(0) 27 | xp = trigger(x) 28 | assert torch.equal(xp, x) 29 | 30 | trigger.mask.fill_(1) 31 | xp = trigger(x) 32 | for i in range(xp.size(0)): 33 | assert torch.equal(xp[i], trigger.trigger) 34 | -------------------------------------------------------------------------------- /tests/test_trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_trainer/__init__.py -------------------------------------------------------------------------------- /tests/test_trainer/test_anti_kd.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from anti_kd_backdoor.config import Config 4 | from anti_kd_backdoor.trainer import build_trainer 5 | from anti_kd_backdoor.trainer.anti_kd import ( 6 | AntiKDTrainer, 7 | NetworkWrapper, 8 | TriggerWrapper, 9 | ) 10 | 11 | CONFIG_PATH = 'tests/data/config/anti_kd_t-r34_s-r18-v16-mv2_cifar10.py' 12 | 13 | 14 | def test_anti_kd(tmp_work_dirs: Path) -> None: 15 | config = Config.fromfile(CONFIG_PATH) 16 | trainer_config = config.trainer 17 | trainer_config.work_dirs = tmp_work_dirs 18 | 19 | trainer = build_trainer(trainer_config) 20 | assert isinstance(trainer, AntiKDTrainer) 21 | assert trainer._alpha == trainer_config.alpha 22 | assert trainer._save_interval == trainer_config.save_interval 23 | assert trainer._device == trainer_config.device 24 | assert trainer._epochs == trainer_config.epochs 25 | 26 | teacher = trainer._teacher_wrapper 27 | assert isinstance(teacher, NetworkWrapper) 28 | assert teacher.lambda_t == trainer_config.teacher.lambda_t 29 | assert teacher.lambda_mask == trainer_config.teacher.lambda_mask 30 | assert teacher.trainable_when_training_trigger == \ 31 | trainer_config.teacher.trainable_when_training_trigger 32 | 33 | students = trainer._student_wrappers 
34 | for s_name, s in students.items(): 35 | assert isinstance(s, NetworkWrapper) 36 | student_config = trainer_config.students 37 | assert s.lambda_t == getattr(student_config, s_name).lambda_t 38 | assert s.lambda_mask == getattr(student_config, s_name).lambda_mask 39 | assert s.trainable_when_training_trigger == getattr( 40 | student_config, s_name).trainable_when_training_trigger 41 | 42 | trigger = trainer._trigger_wrapper 43 | assert isinstance(trigger, TriggerWrapper) 44 | assert trigger.mask_clip_range == trainer_config.trigger.mask_clip_range 45 | assert trigger.trigger_clip_range == \ 46 | trainer_config.trigger.trigger_clip_range 47 | assert trigger.mask_penalty_norm == \ 48 | trainer_config.trigger.mask_penalty_norm 49 | 50 | clean_train_dataloader = trainer._clean_train_dataloader 51 | assert clean_train_dataloader.batch_size == \ 52 | trainer_config.clean_train_dataloader.batch_size 53 | assert callable(clean_train_dataloader.dataset.transform) 54 | -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_utils/__init__.py -------------------------------------------------------------------------------- /tests/test_utils/test_metric.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/tests/test_utils/test_metric.py -------------------------------------------------------------------------------- /trigger/epoch_99.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NISPLab/CleanSheet/9886da0137cd9a698735a09057390950c3d4a0e6/trigger/epoch_99.pth -------------------------------------------------------------------------------- 
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Trigger(nn.Module):
    """Learnable trigger that is blended into an input image through a mask."""

    def __init__(self, size: int = 32, transparency: float = 1.) -> None:
        super().__init__()

        self.size = size
        # Per-pixel blending mask of shape (size, size); note the hard-coded CUDA device.
        self.mask = nn.Parameter(
            torch.rand(size, size, device=torch.device('cuda')),
            requires_grad=True)
        self.transparency = transparency
        # Trigger pattern of shape (3, size, size), initialised uniformly in [-2, 2).
        self.trigger = nn.Parameter(
            torch.rand(3, size, size, device=torch.device('cuda')) * 4 - 2,
            requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Where mask * transparency is 1 the trigger replaces the pixel;
        # where it is 0 the input passes through unchanged.
        return self.transparency * self.mask * self.trigger + \
            (1 - self.mask * self.transparency) * x


class UAP(nn.Module):
    """Learnable universal additive perturbation."""

    def __init__(self, size: int = 32) -> None:
        super().__init__()

        self.size = size
        # Perturbation of shape (3, size, size), initialised to zero on the CUDA device.
        self.perturbation = nn.Parameter(
            torch.zeros(3, size, size, device=torch.device('cuda')),
            requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.perturbation
--------------------------------------------------------------------------------
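The blend computed by `Trigger.forward` in `utils.py` can be checked without a GPU. The following is a minimal sketch, not part of the repository: `blend_trigger` is a hypothetical helper that restates the same formula as a plain function, and the tensor shapes are assumed to match the CIFAR-sized defaults (`size=32`). It reproduces the two boundary cases exercised in `tests/test_network/test_trigger.py`: an all-zero mask leaves the input untouched, and an all-one mask replaces every image with the trigger.

```python
import torch


def blend_trigger(x: torch.Tensor,
                  mask: torch.Tensor,
                  trigger: torch.Tensor,
                  transparency: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: same blend as Trigger.forward, on any device."""
    # mask has shape (H, W) and broadcasts over the channel and batch
    # dimensions of x, which has shape (N, 3, H, W).
    alpha = transparency * mask
    return alpha * trigger + (1 - alpha) * x


x = torch.rand(4, 3, 32, 32)             # batch of fake images
trigger = torch.rand(3, 32, 32) * 4 - 2  # same initial range as in utils.py

# All-zero mask: the input passes through unchanged.
unchanged = blend_trigger(x, torch.zeros(32, 32), trigger)

# All-one mask: every image is replaced by the trigger pattern.
replaced = blend_trigger(x, torch.ones(32, 32), trigger)
```

Keeping the mask two-dimensional means a single blending weight is shared by all three colour channels of a pixel, which is why `mask` has shape `(size, size)` while `trigger` has shape `(3, size, size)`.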