├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   └── poison_tool_cifar.py
├── logs
│   ├── mask_values.txt
│   ├── output.log
│   └── pruning_by_threshold.txt
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── blocks.cpython-36.pyc
│   │   ├── dynamic_models.cpython-36.pyc
│   │   ├── mask_batchnorm.cpython-36.pyc
│   │   ├── mobilenetv2.cpython-36.pyc
│   │   ├── resnet_cifar.cpython-36.pyc
│   │   └── vgg_cifar.cpython-36.pyc
│   ├── blocks.py
│   ├── dynamic_models.py
│   ├── mask_batchnorm.py
│   ├── mobilenetv2.py
│   ├── resnet_cifar.py
│   └── vgg_cifar.py
├── train_backdoor.py
└── trigger
    ├── best_square_trigger_cifar10.npz
    └── signal_cifar10_mask.npy

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
data/CIFAR10/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Reconstructive Neuron Pruning for Backdoor Defense

Code for the ICML 2023 paper ["Reconstructive Neuron Pruning for Backdoor Defense"](https://arxiv.org/pdf/2305.14876.pdf)


---


# Quick Start: RNP against the BadNets Attack
By default, we use only 500 defense samples, randomly drawn from the training set, to perform the `unlearn-recover` process and optimize the pruning mask. To check the performance of RNP on a BadNets ResNet-18 network (i.e., ResNet-18 on CIFAR-10 with a 10% poisoning rate), run:

```bash
python main.py
```
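The same script exposes every stage of the pipeline through command-line flags (see the argparse section at the bottom of `main.py`). As a hedged illustration, the invocation below spells out a few of them explicitly; the checkpoint path is simply `main.py`'s default and should point at whatever backdoored weight you actually have:

```bash
python main.py \
    --arch resnet18 --dataset CIFAR10 \
    --backdoor_model_path weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar \
    --ratio 0.01 --clean_threshold 0.20 \
    --unlearning_lr 0.01 --recovering_lr 0.2 \
    --pruning-by threshold --pruning-max 0.90 --pruning-step 0.05
```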
# Experimental Results on BadNets Attack

```
[2023/07/09 22:21:00] - Namespace(alpha=0.2, arch='resnet18', backdoor_model_path='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar', batch_size=128, clean_threshold=0.2, cuda=1, dataset='CIFAR10', log_root='logs/', mask_file=None, momentum=0.9, num_class=10, output_weight='weights/', pruning_by='threshold', pruning_max=0.9, pruning_step=0.05, ratio=0.01, recovering_epochs=20, recovering_lr=0.2, save_every=5, schedule=[10, 20], target_label=0, target_type='all2one', trig_h=3, trig_w=3, trigger_type='gridTrigger', unlearned_model_path=None, unlearning_epochs=20, unlearning_lr=0.01, weight_decay=0.0005)
[2023/07/09 22:21:00] - ----------- Data Initialization --------------
[2023/07/09 22:21:03] - ----------- Backdoor Model Initialization --------------
[2023/07/09 22:21:04] - Epoch lr Time TrainLoss TrainACC PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:21:04] - ----------- Model Unlearning --------------
[2023/07/09 22:21:15] - 0 0.010 11.2 0.1004 0.9780 0.0000 1.0000 0.2185 0.9342
[2023/07/09 22:21:26] - 1 0.010 10.8 0.1213 0.9760 0.0000 1.0000 0.2253 0.9317
[2023/07/09 22:21:37] - 2 0.010 10.8 0.1400 0.9740 0.0000 1.0000 0.2349 0.9304
[2023/07/09 22:21:48] - 3 0.010 10.8 0.1535 0.9720 0.0000 1.0000 0.2513 0.9266
[2023/07/09 22:21:59] - 4 0.010 10.9 0.2078 0.9640 0.0000 1.0000 0.2770 0.9214
[2023/07/09 22:22:10] - 5 0.010 11.0 0.2614 0.9500 0.0000 1.0000 0.3144 0.9141
[2023/07/09 22:22:21] - 6 0.010 10.8 0.3711 0.9220 0.0000 1.0000 0.3847 0.8991
[2023/07/09 22:22:31] - 7 0.010 10.8 0.4538 0.8700 0.0000 1.0000 0.5276 0.8669
[2023/07/09 22:22:42] - 8 0.010 10.8 0.7916 0.6700 0.0000 1.0000 0.9439 0.7586
[2023/07/09 22:22:53] - 9 0.010 10.8 1.4771 0.4540 0.0000 1.0000 2.5574 0.4016
[2023/07/09 22:23:04] - 10 0.001 10.8 3.1028 0.2920 0.0000 1.0000 2.5620 0.3949
[2023/07/09 22:23:15] - 11 0.001 10.8 4.0416 0.2360 0.0000 1.0000 2.8507 0.3729
[2023/07/09 22:23:25] - 12 0.001 10.8 4.8811 0.1980 0.0000 1.0000 3.2337 0.3368
[2023/07/09 22:23:25] - ----------- Model Recovering --------------
[2023/07/09 22:23:26] - Epoch lr Time TrainLoss TrainACC PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:23:37] - 1 0.200 11.0 1.0719 0.1980 0.0000 1.0000 2.0311 0.4370
[2023/07/09 22:23:48] - 2 0.200 11.0 0.9892 0.1920 0.0000 1.0000 1.3383 0.5432
[2023/07/09 22:23:59] - 3 0.200 11.0 0.7809 0.2320 0.0018 1.0000 1.1726 0.6138
[2023/07/09 22:24:10] - 4 0.200 11.0 0.4892 0.2500 0.3161 0.8770 1.0972 0.6552
[2023/07/09 22:24:21] - 5 0.200 11.1 0.4130 0.2860 0.6548 0.6051 1.0409 0.6662
[2023/07/09 22:24:32] - 6 0.200 11.0 0.3691 0.3060 0.7871 0.5084 0.9843 0.6822
[2023/07/09 22:24:43] - 7 0.200 11.0 0.3262 0.3460 0.8479 0.4791 0.9089 0.7053
[2023/07/09 22:24:54] - 8 0.200 11.0 0.2963 0.3760 0.8369 0.4904 0.8691 0.7157
[2023/07/09 22:25:05] - 9 0.200 11.0 0.2777 0.3900 0.8192 0.5090 0.8226 0.7324
[2023/07/09 22:25:16] - 10 0.200 11.0 0.2485 0.4320 0.7842 0.5340 0.7765 0.7497
[2023/07/09 22:25:27] - 11 0.200 11.0 0.2337 0.4500 0.7554 0.5562 0.7283 0.7666
[2023/07/09 22:25:38] - 12 0.200 11.0 0.2140 0.4900 0.6922 0.6044 0.7022 0.7752
[2023/07/09 22:25:49] - 13 0.200 11.0 0.2043 0.5140 0.6542 0.6317 0.6718 0.7870
[2023/07/09 22:26:00] - 14 0.200 11.0 0.1807 0.5340 0.6128 0.6598 0.6517 0.7951
[2023/07/09 22:26:11] - 15 0.200 11.0 0.1724 0.5440 0.5820 0.6873 0.6342 0.8018
[2023/07/09 22:26:22] - 16 0.200 11.0 0.1729 0.5780 0.5754 0.6968 0.6084 0.8121
[2023/07/09 22:26:33] - 17 0.200 11.1 0.1532 0.6180 0.5683 0.7027 0.5930 0.8176
[2023/07/09 22:26:44] - 18 0.200 11.0 0.1476 0.6120 0.5614 0.7083 0.5766 0.8244
[2023/07/09 22:26:55] - 19 0.200 11.0 0.1510 0.6380 0.5674 0.7069 0.5601 0.8312
[2023/07/09 22:27:06] - 20 0.200 11.1 0.1439 0.6520 0.5788 0.6988 0.5417 0.8402
[2023/07/09 22:27:07] - ----------- Backdoored Model Pruning --------------
[2023/07/09 22:27:07] - Pruned Number Layer Name Neuron Idx Mask PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:27:17] - 0 None None 0.0001 1.0000 0.2157 0.9340
[2023/07/09 22:27:27] - 12.00 layer4.1.bn2 188 0.0 0.0122 0.9986 0.2104 0.9347
[2023/07/09 22:27:37] - 14.00 layer4.1.bn2 12 0.05 0.0139 0.9982 0.2092 0.9340
[2023/07/09 22:27:46] - 18.00 layer3.0.bn2 230 0.1 0.0601 0.9876 0.2065 0.9338
[2023/07/09 22:27:56] - 21.00 bn1 5 0.15000000000000002 0.9935 0.4796 0.2074 0.9340
[2023/07/09 22:28:06] - 24.00 layer4.0.bn2 82 0.2 2.8242 0.0661 0.2297 0.9280
[2023/07/09 22:28:16] - 28.00 layer3.0.bn1 106 0.25 3.2791 0.0424 0.2292 0.9270
[2023/07/09 22:28:26] - 32.00 layer4.1.bn2 152 0.30000000000000004 4.2908 0.0172 0.2295 0.9277
```
Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive](https://drive.google.com/file/d/1B4eHfsTyw_Qj-XgZc2byYDT_95TLjtLj/view?usp=sharing) | 84 | | Trojan | Trojaning attack on Neural Networks | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 85 | | Blend | Targeted Backdoor Attacks on Deep Learning Systems Using Data Poisoning | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 86 | | CL | Label-Consistent Backdoor Attacks | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive](https://drive.google.com/file/d/1B4eHfsTyw_Qj-XgZc2byYDT_95TLjtLj/view?usp=sharing) | 87 | | SIG | A New Backdoor Attack in Cnns by Training Set Corruption without Label Poisoning | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 88 | | Dynamic | Input-Aware Dynamic Backdoor Attack | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive](https://drive.google.com/file/d/1B4eHfsTyw_Qj-XgZc2byYDT_95TLjtLj/view?usp=sharing) | 89 | | WaNet | WaNet - Imperceptible Warping-based Backdoor Attack | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 90 | | FC | Poison Frog! Targeted Clean-label Backdoor Attacks on Neural Networks | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive](https://drive.google.com/file/d/1B4eHfsTyw_Qj-XgZc2byYDT_95TLjtLj/view?usp=sharing) | 91 | | DFST | Deep Feature Space Trojan Attack of Neural Networks by Controlled Detoxifcation | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 92 | | AWP | Can Adversarial Weight Perturbations Inject Neural Backdoors | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 93 | | LIRA | LIRA: Learnable, Imperceptible and Robust Backdoor Attacks | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 94 | | A-Blend | Circumventing Backdoor Defense that are Based on Latent Separability | [Baidu Drive](https://pan.baidu.com/s/1LXZuvb06als1D025eK04_Q) | [Google Drive]() | 95 | 96 | 97 | ## Citation 98 | If you use this code in your work, please cite the accompanying paper: 99 | 100 | ``` 101 | @inproceedings{ 102 | li2023reconstructive, 103 | title={Reconstructive Neuron Pruning for Backdoor Defense}, 104 | author={Yige Li and Xixiang Lyu and Xingjun Ma and Nodens Koren and Lingjuan Lyu and Bo Li and Yu-Gang Jiang}, 105 | booktitle={ICML}, 106 | year={2023}, 107 | } 108 | ``` 109 | 110 | ## Acknowledgements 111 | As this code is reproduced based on the open-sourced code [Adversarial Neuron Pruning Purifies Backdoored Deep Models](https://github.com/csdongxian/ANP_backdoor) and [Distilling Cognitive Backdoor Patterns within an Image](https://github.com/HanxunH/CognitiveDistillation), the authors would like to thank their contribution and help. 
## Citation
If you use this code in your work, please cite the accompanying paper:

```
@inproceedings{
li2023reconstructive,
title={Reconstructive Neuron Pruning for Backdoor Defense},
author={Yige Li and Xixiang Lyu and Xingjun Ma and Nodens Koren and Lingjuan Lyu and Bo Li and Yu-Gang Jiang},
booktitle={ICML},
year={2023},
}
```

## Acknowledgements
This code builds on the open-source implementations of [Adversarial Neuron Pruning Purifies Backdoored Deep Models](https://github.com/csdongxian/ANP_backdoor) and [Distilling Cognitive Backdoor Patterns within an Image](https://github.com/HanxunH/CognitiveDistillation); the authors thank them for their contributions and help.




## Backdoor-related repos
- Dynamic Attack: https://github.com/VinAIResearch/input-aware-backdoor-attack-release
- STRIP: https://github.com/garrisongys/STRIP
- NAD: https://github.com/bboylyg/NAD
- ABL: https://github.com/bboylyg/ABL
- Frequency: https://github.com/YiZeng623/frequency-backdoor
- NC: https://github.com/VinAIResearch/input-aware-backdoor-attack-release/tree/master/defenses/neural_cleanse
- BackdoorBox: https://github.com/THUYimingLi/BackdoorBox
- BackdoorBench: https://github.com/SCLBD/BackdoorBench

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bboylyg/RNP/eeae192e5eab974d8b3002964cfb62d00388d36f/data/__init__.py

--------------------------------------------------------------------------------
/data/poison_tool_cifar.py:
--------------------------------------------------------------------------------
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split, DataLoader, Dataset
import torch
import numpy as np
import time
import argparse
from tqdm import tqdm
from copy import deepcopy
from PIL import Image
import torch.nn.functional as F

import sys
sys.path.append("..")
from models import dynamic_models

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


MEAN_CIFAR10 = (0.4914, 0.4822, 0.4465)
STD_CIFAR10 = (0.2023, 0.1994, 0.2010)
def split_dataset(dataset, frac=0.1, perm=None):
    """
    :param dataset: The whole dataset to be split.
    """
    if perm is None:
        perm = np.arange(len(dataset))
        np.random.shuffle(perm)
    nb_split = int(frac * len(dataset))

    # generate the training set
    train_set = deepcopy(dataset)
    train_set.data = train_set.data[perm[nb_split:]]
    train_set.targets = np.array(train_set.targets)[perm[nb_split:]].tolist()

    # generate the test set
    split_set = deepcopy(dataset)
    split_set.data = split_set.data[perm[:nb_split]]
    split_set.targets = np.array(split_set.targets)[perm[:nb_split]].tolist()

    print('total data size: %d images, split test size: %d images, split ratio: %f' % (
        len(train_set.targets), len(split_set.targets), frac))

    return train_set, split_set

def get_train_loader(args):
    print('==> Preparing train data..')
    tf_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        # transforms.RandomRotation(3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN_CIFAR10, STD_CIFAR10)
    ])

    if (args.dataset == 'CIFAR10'):
        trainset = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)
    else:
        raise Exception('Invalid dataset')

    train_data = DatasetCL(args, full_dataset=trainset, transform=tf_train)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    return train_loader

def get_test_loader(args):
    print('==> Preparing test data..')
    tf_test = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize(MEAN_CIFAR10, STD_CIFAR10)
                                  ])
    if (args.dataset == 'CIFAR10'):
        testset = datasets.CIFAR10(root='data/CIFAR10', train=False, download=True)
    else:
        raise Exception('Invalid dataset')

    test_data_clean = DatasetBD(args, full_dataset=testset, inject_portion=0, transform=tf_test, mode='test')
    test_data_bad = DatasetBD(args, full_dataset=testset, inject_portion=1, transform=tf_test, mode='test')

    # all clean test data
    test_clean_loader = DataLoader(dataset=test_data_clean,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   )
    # fully poisoned test data (samples of the target label are skipped)
    test_bad_loader = DataLoader(dataset=test_data_bad,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 )

    return test_clean_loader, test_bad_loader


def get_backdoor_loader(args):
    print('==> Preparing train data..')
    tf_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        # transforms.RandomRotation(3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN_CIFAR10, STD_CIFAR10)
    ])
    if (args.dataset == 'CIFAR10'):
        trainset = datasets.CIFAR10(root='data/CIFAR10', train=True, download=True)
    else:
        raise Exception('Invalid dataset')

    train_data_bad = DatasetBD(args, full_dataset=trainset, inject_portion=args.inject_portion, transform=tf_train, mode='train')
    train_bad_loader = DataLoader(dataset=train_data_bad,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  )

    return train_data_bad, train_bad_loader


class Dataset_npy(torch.utils.data.Dataset):
    def __init__(self, full_dataset=None, transform=None):
        self.dataset = full_dataset
        self.transform = transform
        self.dataLen = len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index][0]
        label = self.dataset[index][1]

        if self.transform:
            image = self.transform(image)
        # print(type(image), image.shape)
        return image, label

    def __len__(self):
        return self.dataLen
class DatasetCL(Dataset):
    def __init__(self, args, full_dataset=None, transform=None):
        self.dataset = self.random_split(full_dataset=full_dataset, ratio=args.ratio)
        self.transform = transform
        self.dataLen = len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index][0]
        label = self.dataset[index][1]

        if self.transform:
            image = self.transform(image)

        return image, label

    def __len__(self):
        return self.dataLen

    def random_split(self, full_dataset, ratio):
        # ratio=0.01 keeps 500 of the 50,000 CIFAR-10 training images as defense data
        print('full_train:', len(full_dataset))
        train_size = int(ratio * len(full_dataset))
        drop_size = len(full_dataset) - train_size
        train_dataset, drop_dataset = random_split(full_dataset, [train_size, drop_size])
        print('train_size:', len(train_dataset), 'drop_size:', len(drop_dataset))

        return train_dataset


def create_bd(netG, netM, inputs):
    patterns = netG(inputs)
    masks_output = netM.threshold(netM(inputs))
    return patterns, masks_output

def normalization(data):
    _range = np.max(data) - np.min(data)
    return (data - np.min(data)) / _range

class DatasetBD(Dataset):
    def __init__(self, args, full_dataset, inject_portion, transform=None, mode="train", device=torch.device("cuda"), distance=1):
        self.dataset = self.addTrigger(full_dataset, args.target_label, inject_portion, mode, distance, args.trig_w, args.trig_h, args.trigger_type, args.target_type)
        self.device = device
        self.transform = transform

    def __getitem__(self, item):
        img = self.dataset[item][0]
        label = self.dataset[item][1]
        img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.dataset)

    def addTrigger(self, dataset, target_label, inject_portion, mode, distance, trig_w, trig_h, trigger_type, target_type):
        print("Generating " + mode + " bad Imgs")
        perm = np.random.permutation(len(dataset))[0: int(len(dataset) * inject_portion)]
        # dataset
        dataset_ = list()

        cnt = 0
        for i in tqdm(range(len(dataset))):
            data = dataset[i]

            if target_type == 'all2one':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        # select trigger
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)

                        # change target
                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

                else:
                    if data[1] == target_label:
                        continue

                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)

                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))
            # all2all attack
            elif target_type == 'all2all':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:

                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)
                        target_ = self._change_label_next(data[1])

                        dataset_.append((img, target_))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

                else:

                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)

                        target_ = self._change_label_next(data[1])
                        dataset_.append((img, target_))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

            # clean label attack
            elif target_type == 'cleanLabel':

                if mode == 'train':
                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]

                    if i in perm:
                        if data[1] == target_label:

                            img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)

                            dataset_.append((img, data[1]))
                            cnt += 1

                        else:
                            dataset_.append((img, data[1]))
                    else:
                        dataset_.append((img, data[1]))

                else:
                    if data[1] == target_label:
                        continue

                    img = np.array(data[0])
                    width = img.shape[0]
                    height = img.shape[1]
                    if i in perm:
                        img = self.selectTrigger(img, width, height, distance, trig_w, trig_h, mode, trigger_type)

                        dataset_.append((img, target_label))
                        cnt += 1
                    else:
                        dataset_.append((img, data[1]))

            time.sleep(0.01)
        print("Injecting Over: " + str(cnt) + " Bad Imgs, " + str(len(dataset) - cnt) + " Clean Imgs")

        return dataset_


    def _change_label_next(self, label):
        label_new = ((label + 1) % 10)
        return label_new

    def selectTrigger(self, img, width, height, distance, trig_w, trig_h, mode, triggerType):

        assert triggerType in ['squareTrigger', 'gridTrigger', 'fourCornerTrigger', 'randomPixelTrigger',
                               'signalTrigger', 'trojanTrigger', 'CLTrigger', 'dynamicTrigger', 'nashvilleTrigger',
                               'onePixelTrigger', 'wanetTrigger']

        if triggerType == 'squareTrigger':
            img = self._squareTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'gridTrigger':
            img = self._gridTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'fourCornerTrigger':
            img = self._fourCornerTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'randomPixelTrigger':
            img = self._randomPixelTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'signalTrigger':
            img = self._signalTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'trojanTrigger':
            img = self._trojanTrigger(img, width, height, distance, trig_w, trig_h)

        elif triggerType == 'CLTrigger':
            img = self._CLTrigger(img, mode=mode)

        elif triggerType == 'dynamicTrigger':
            img = self._dynamicTrigger(img, mode=mode)

        elif triggerType == 'nashvilleTrigger':
            img = self._nashvilleTrigger(img, mode=mode)

        elif triggerType == 'onePixelTrigger':
            img = self._onePixelTrigger(img, mode=mode)

        elif triggerType == 'wanetTrigger':
            img = self._wanetTrigger(img, mode=mode)

        else:
            raise NotImplementedError

        return img
    def _squareTrigger(self, img, width, height, distance, trig_w, trig_h):
        for j in range(width - distance - trig_w, width - distance):
            for k in range(height - distance - trig_h, height - distance):
                img[j, k] = 255.0

        return img

    def _gridTrigger(self, img, width, height, distance, trig_w, trig_h):

        img[width - 1][height - 1] = 255
        img[width - 1][height - 2] = 0
        img[width - 1][height - 3] = 255

        img[width - 2][height - 1] = 0
        img[width - 2][height - 2] = 255
        img[width - 2][height - 3] = 0

        img[width - 3][height - 1] = 255
        img[width - 3][height - 2] = 0
        img[width - 3][height - 3] = 0

        # adaptive center trigger
        # alpha = 1
        # img[width - 14][height - 14] = 255 * alpha
        # img[width - 14][height - 13] = 128 * alpha
        # img[width - 14][height - 12] = 255 * alpha
        #
        # img[width - 13][height - 14] = 128 * alpha
        # img[width - 13][height - 13] = 255 * alpha
        # img[width - 13][height - 12] = 128 * alpha
        #
        # img[width - 12][height - 14] = 255 * alpha
        # img[width - 12][height - 13] = 128 * alpha
        # img[width - 12][height - 12] = 128 * alpha

        return img

    def _fourCornerTrigger(self, img, width, height, distance, trig_w, trig_h):
        # right bottom
        img[width - 1][height - 1] = 255
        img[width - 1][height - 2] = 0
        img[width - 1][height - 3] = 255

        img[width - 2][height - 1] = 0
        img[width - 2][height - 2] = 255
        img[width - 2][height - 3] = 0

        img[width - 3][height - 1] = 255
        img[width - 3][height - 2] = 0
        img[width - 3][height - 3] = 0

        # left top
        img[1][1] = 255
        img[1][2] = 0
        img[1][3] = 255

        img[2][1] = 0
        img[2][2] = 255
        img[2][3] = 0

        img[3][1] = 255
        img[3][2] = 0
        img[3][3] = 0

        # right top
        img[width - 1][1] = 255
        img[width - 1][2] = 0
        img[width - 1][3] = 255

        img[width - 2][1] = 0
        img[width - 2][2] = 255
        img[width - 2][3] = 0

        img[width - 3][1] = 255
        img[width - 3][2] = 0
        img[width - 3][3] = 0

        # left bottom
        img[1][height - 1] = 255
        img[2][height - 1] = 0
        img[3][height - 1] = 255

        img[1][height - 2] = 0
        img[2][height - 2] = 255
        img[3][height - 2] = 0

        img[1][height - 3] = 255
        img[2][height - 3] = 0
        img[3][height - 3] = 0

        return img

    def _randomPixelTrigger(self, img, width, height, distance, trig_w, trig_h):
        alpha = 0.2
        mask = np.random.randint(low=0, high=256, size=(width, height), dtype=np.uint8)
        blend_img = (1 - alpha) * img + alpha * mask.reshape((width, height, 1))
        # clip before casting, otherwise uint8 overflow wraps around
        blend_img = np.clip(blend_img, 0, 255).astype('uint8')

        # print(blend_img.dtype)
        return blend_img

    def _signalTrigger(self, img, width, height, distance, trig_w, trig_h):
        alpha = 0.2
        # load signal mask
        signal_mask = np.load('trigger/signal_cifar10_mask.npy')
        blend_img = (1 - alpha) * img + alpha * signal_mask.reshape((width, height, 1))  # FOR CIFAR10
        blend_img = np.clip(blend_img, 0, 255).astype('uint8')

        return blend_img

    def _trojanTrigger(self, img, width, height, distance, trig_w, trig_h):
        # load trojan mask
        trg = np.load('trigger/best_square_trigger_cifar10.npz')['x']
        # trg.shape: (3, 32, 32)
        trg = np.transpose(trg, (1, 2, 0))
        img_ = np.clip(img + trg, 0, 255).astype('uint8')

        return img_
    def _CLTrigger(self, img, mode='train'):
        # Load trigger
        width, height, c = img.shape

        # Add trigger: at train time, blend in the universal perturbation plus the
        # grid patch; at test time, only the grid patch is stamped on
        if mode == 'train':
            trigger = np.load('trigger/best_universal.npy')[0]
            img = img / 255
            img = img.astype(np.float32)
            img += trigger
            img = normalization(img)
            img = img * 255
            # right bottom
            img[width - 1][height - 1] = 255
            img[width - 1][height - 2] = 0
            img[width - 1][height - 3] = 255

            img[width - 2][height - 1] = 0
            img[width - 2][height - 2] = 255
            img[width - 2][height - 3] = 0

            img[width - 3][height - 1] = 255
            img[width - 3][height - 2] = 0
            img[width - 3][height - 3] = 0

            img = img.astype(np.uint8)
        else:
            # right bottom
            img[width - 1][height - 1] = 255
            img[width - 1][height - 2] = 0
            img[width - 1][height - 3] = 255

            img[width - 2][height - 1] = 0
            img[width - 2][height - 2] = 255
            img[width - 2][height - 3] = 0

            img[width - 3][height - 1] = 255
            img[width - 3][height - 2] = 0
            img[width - 3][height - 3] = 0

            img = img.astype(np.uint8)

        return img

    def _wanetTrigger(self, img, mode='train'):

        if not isinstance(img, np.ndarray):
            raise TypeError("Img should be np.ndarray. Got {}".format(type(img)))
        if len(img.shape) != 3:
            raise ValueError("The shape of img should be HWC. Got {}".format(img.shape))

        # Prepare warping grid
        s = 0.5
        k = 32  # 4 is not large enough for ASR
        grid_rescale = 1
        ins = torch.rand(1, 2, k, k) * 2 - 1
        ins = ins / torch.mean(torch.abs(ins))
        noise_grid = F.interpolate(ins, size=32, mode="bicubic", align_corners=True)
        noise_grid = noise_grid.permute(0, 2, 3, 1)
        array1d = torch.linspace(-1, 1, steps=32)
        x, y = torch.meshgrid(array1d, array1d)
        identity_grid = torch.stack((y, x), 2)[None, ...]
        grid = identity_grid + s * noise_grid / 32 * grid_rescale
        grid = torch.clamp(grid, -1, 1)

        img = torch.tensor(img).permute(2, 0, 1) / 255.0
        poison_img = F.grid_sample(img.unsqueeze(0), grid, align_corners=True).squeeze()  # CHW
        poison_img = poison_img.permute(1, 2, 0) * 255
        poison_img = poison_img.numpy().astype(np.uint8)

        return poison_img

    def _nashvilleTrigger(self, img, mode='train'):
        # Add backdoor trigger via the Nashville Instagram-style filter
        import pilgram
        img = Image.fromarray(img)
        img = pilgram.nashville(img)
        img = np.asarray(img).astype(np.uint8)

        return img

    def _onePixelTrigger(self, img, mode='train'):
        # one pixel at the image center
        if not isinstance(img, np.ndarray):
            raise TypeError("Img should be np.ndarray. Got {}".format(type(img)))
        if len(img.shape) != 3:
            raise ValueError("The shape of img should be HWC. Got {}".format(img.shape))

        width, height, c = img.shape
        img[width // 2][height // 2] = 255

        return img
Got {}".format(img.shape)) 563 | 564 | width, height, c = img.shape 565 | img[width // 2][height // 2] = 255 566 | 567 | return img 568 | 569 | def _dynamicTrigger(self, img, mode='Train'): 570 | # Load dynamic trigger model 571 | ckpt_path = 'all2one_cifar10_ckpt.pth.tar' 572 | state_dict = torch.load(ckpt_path, map_location=device) 573 | opt = state_dict["opt"] 574 | netG = dynamic_models.Generator(opt).to(device) 575 | netG.load_state_dict(state_dict["netG"]) 576 | netG = netG.eval() 577 | netM = dynamic_models.Generator(opt, out_channels=1).to(device) 578 | netM.load_state_dict(state_dict["netM"]) 579 | netM = netM.eval() 580 | normalizer = transforms.Normalize([0.4914, 0.4822, 0.4465], 581 | [0.247, 0.243, 0.261]) 582 | 583 | # Add trigers 584 | x = img.copy() 585 | x = torch.tensor(x).permute(2, 0, 1) / 255.0 586 | x_in = torch.stack([normalizer(x)]).to(device) 587 | p, m = create_bd(netG, netM, x_in) 588 | p = p[0, :, :, :].detach().cpu() 589 | m = m[0, :, :, :].detach().cpu() 590 | x_bd = x + (p - x) * m 591 | x_bd = x_bd.permute(1, 2, 0).numpy() * 255 592 | x_bd = x_bd.astype(np.uint8) 593 | 594 | return x_bd 595 | 596 | if __name__ == '__main__': 597 | parser = argparse.ArgumentParser(description='Poisoned dataset') 598 | # backdoor attacks 599 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='name of image dataset') 600 | parser.add_argument('--target_label', type=int, default=0, help='class of target label') 601 | parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger', 602 | choices=['gridTrigger', 'fourCornerTrigger', 'trojanTrigger', 'blendTrigger', 'signalTrigger', 'CLTrigger', 603 | 'smoothTrigger', 'dynamicTrigger', 'nashvilleTrigger', 'onePixelTrigger']) 604 | parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label') 605 | parser.add_argument('--trig_w', type=int, default=10, help='width of trigger pattern') 606 | parser.add_argument('--trig_h', type=int, default=10, help='height of trigger pattern') 607 | 608 | opt = parser.parse_args() 609 | 610 | tf_train = transforms.Compose([transforms.ToTensor() 611 | ]) 612 | clean_set = CIFAR10(root='/fs/scratch/sgh_cr_bcai_dl_cluster_users/03_open_source_dataset/', train=False) 613 | # split a small test subset 614 | _, split_set = split_dataset(clean_set, frac=0.01) 615 | poison_set = train_data_bad = DatasetBD(opt=opt, full_dataset=split_set, inject_portion=0.1, transform=tf_train, mode='train') 616 | import matplotlib.pyplot as plt 617 | print(poison_set.__getitem__(0)) 618 | x, y = poison_set.__getitem__(0) 619 | plt.imshow(x) 620 | plt.show() 621 | -------------------------------------------------------------------------------- /logs/output.log: -------------------------------------------------------------------------------- 1 | [2023/07/09 22:21:00] - Namespace(alpha=0.2, arch='resnet18', backdoor_model_path='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar', batch_size=128, clean_threshold=0.2, cuda=1, dataset='CIFAR10', log_root='logs/', mask_file=None, momentum=0.9, num_class=10, output_weight='weights/', pruning_by='threshold', pruning_max=0.9, pruning_step=0.05, ratio=0.01, recovering_epochs=20, recovering_lr=0.2, save_every=5, schedule=[10, 20], target_label=0, target_type='all2one', trig_h=3, trig_w=3, trigger_type='gridTrigger', unlearned_model_path=None, unlearning_epochs=20, unlearning_lr=0.01, weight_decay=0.0005) 2 | [2023/07/09 22:21:00] - ----------- Data Initialization -------------- 3 | 
/logs/output.log:
--------------------------------------------------------------------------------
[2023/07/09 22:21:00] - Namespace(alpha=0.2, arch='resnet18', backdoor_model_path='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar', batch_size=128, clean_threshold=0.2, cuda=1, dataset='CIFAR10', log_root='logs/', mask_file=None, momentum=0.9, num_class=10, output_weight='weights/', pruning_by='threshold', pruning_max=0.9, pruning_step=0.05, ratio=0.01, recovering_epochs=20, recovering_lr=0.2, save_every=5, schedule=[10, 20], target_label=0, target_type='all2one', trig_h=3, trig_w=3, trigger_type='gridTrigger', unlearned_model_path=None, unlearning_epochs=20, unlearning_lr=0.01, weight_decay=0.0005)
[2023/07/09 22:21:00] - ----------- Data Initialization --------------
[2023/07/09 22:21:03] - ----------- Backdoor Model Initialization --------------
[2023/07/09 22:21:04] - Epoch lr Time TrainLoss TrainACC PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:21:04] - ----------- Model Unlearning --------------
[2023/07/09 22:21:15] - 0 0.010 11.2 0.1004 0.9780 0.0000 1.0000 0.2185 0.9342
[2023/07/09 22:21:26] - 1 0.010 10.8 0.1213 0.9760 0.0000 1.0000 0.2253 0.9317
[2023/07/09 22:21:37] - 2 0.010 10.8 0.1400 0.9740 0.0000 1.0000 0.2349 0.9304
[2023/07/09 22:21:48] - 3 0.010 10.8 0.1535 0.9720 0.0000 1.0000 0.2513 0.9266
[2023/07/09 22:21:59] - 4 0.010 10.9 0.2078 0.9640 0.0000 1.0000 0.2770 0.9214
[2023/07/09 22:22:10] - 5 0.010 11.0 0.2614 0.9500 0.0000 1.0000 0.3144 0.9141
[2023/07/09 22:22:21] - 6 0.010 10.8 0.3711 0.9220 0.0000 1.0000 0.3847 0.8991
[2023/07/09 22:22:31] - 7 0.010 10.8 0.4538 0.8700 0.0000 1.0000 0.5276 0.8669
[2023/07/09 22:22:42] - 8 0.010 10.8 0.7916 0.6700 0.0000 1.0000 0.9439 0.7586
[2023/07/09 22:22:53] - 9 0.010 10.8 1.4771 0.4540 0.0000 1.0000 2.5574 0.4016
[2023/07/09 22:23:04] - 10 0.001 10.8 3.1028 0.2920 0.0000 1.0000 2.5620 0.3949
[2023/07/09 22:23:15] - 11 0.001 10.8 4.0416 0.2360 0.0000 1.0000 2.8507 0.3729
[2023/07/09 22:23:25] - 12 0.001 10.8 4.8811 0.1980 0.0000 1.0000 3.2337 0.3368
[2023/07/09 22:23:25] - ----------- Model Recovering --------------
[2023/07/09 22:23:26] - Epoch lr Time TrainLoss TrainACC PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:23:37] - 1 0.200 11.0 1.0719 0.1980 0.0000 1.0000 2.0311 0.4370
[2023/07/09 22:23:48] - 2 0.200 11.0 0.9892 0.1920 0.0000 1.0000 1.3383 0.5432
[2023/07/09 22:23:59] - 3 0.200 11.0 0.7809 0.2320 0.0018 1.0000 1.1726 0.6138
[2023/07/09 22:24:10] - 4 0.200 11.0 0.4892 0.2500 0.3161 0.8770 1.0972 0.6552
[2023/07/09 22:24:21] - 5 0.200 11.1 0.4130 0.2860 0.6548 0.6051 1.0409 0.6662
[2023/07/09 22:24:32] - 6 0.200 11.0 0.3691 0.3060 0.7871 0.5084 0.9843 0.6822
[2023/07/09 22:24:43] - 7 0.200 11.0 0.3262 0.3460 0.8479 0.4791 0.9089 0.7053
[2023/07/09 22:24:54] - 8 0.200 11.0 0.2963 0.3760 0.8369 0.4904 0.8691 0.7157
[2023/07/09 22:25:05] - 9 0.200 11.0 0.2777 0.3900 0.8192 0.5090 0.8226 0.7324
[2023/07/09 22:25:16] - 10 0.200 11.0 0.2485 0.4320 0.7842 0.5340 0.7765 0.7497
[2023/07/09 22:25:27] - 11 0.200 11.0 0.2337 0.4500 0.7554 0.5562 0.7283 0.7666
[2023/07/09 22:25:38] - 12 0.200 11.0 0.2140 0.4900 0.6922 0.6044 0.7022 0.7752
[2023/07/09 22:25:49] - 13 0.200 11.0 0.2043 0.5140 0.6542 0.6317 0.6718 0.7870
[2023/07/09 22:26:00] - 14 0.200 11.0 0.1807 0.5340 0.6128 0.6598 0.6517 0.7951
[2023/07/09 22:26:11] - 15 0.200 11.0 0.1724 0.5440 0.5820 0.6873 0.6342 0.8018
[2023/07/09 22:26:22] - 16 0.200 11.0 0.1729 0.5780 0.5754 0.6968 0.6084 0.8121
[2023/07/09 22:26:33] - 17 0.200 11.1 0.1532 0.6180 0.5683 0.7027 0.5930 0.8176
[2023/07/09 22:26:44] - 18 0.200 11.0 0.1476 0.6120 0.5614 0.7083 0.5766 0.8244
[2023/07/09 22:26:55] - 19 0.200 11.0 0.1510 0.6380 0.5674 0.7069 0.5601 0.8312
[2023/07/09 22:27:06] - 20 0.200 11.1 0.1439 0.6520 0.5788 0.6988 0.5417 0.8402
[2023/07/09 22:27:07] - ----------- Backdoored Model Pruning --------------
[2023/07/09 22:27:07] - No. Layer Name Neuron Idx Mask PoisonLoss PoisonACC CleanLoss CleanACC
[2023/07/09 22:27:17] - 0 None None 0.0001 1.0000 0.2157 0.9340
[2023/07/09 22:27:27] - 12.00 layer4.1.bn2 188 0.0 0.0122 0.9986 0.2104 0.9347
[2023/07/09 22:27:37] - 14.00 layer4.1.bn2 12 0.05 0.0139 0.9982 0.2092 0.9340
[2023/07/09 22:27:46] - 18.00 layer3.0.bn2 230 0.1 0.0601 0.9876 0.2065 0.9338
[2023/07/09 22:27:56] - 21.00 bn1 5 0.15000000000000002 0.9935 0.4796 0.2074 0.9340
[2023/07/09 22:28:06] - 24.00 layer4.0.bn2 82 0.2 2.8242 0.0661 0.2297 0.9280
[2023/07/09 22:28:16] - 28.00 layer3.0.bn1 106 0.25 3.2791 0.0424 0.2292 0.9270
[2023/07/09 22:28:26] - 32.00 layer4.1.bn2 152 0.30000000000000004 4.2908 0.0172 0.2295 0.9277
[2023/07/09 22:28:36] - 41.00 layer3.0.bn2 97 0.35000000000000003 5.5764 0.0112 0.2936 0.9094
[2023/07/09 22:28:46] - 49.00 layer4.1.bn2 451 0.4 6.3171 0.0070 0.2963 0.9070
[2023/07/09 22:28:56] - 57.00 layer4.1.bn2 352 0.45 6.6227 0.0059 0.2897 0.9077
[2023/07/09 22:29:06] - 68.00 layer2.0.bn1 50 0.5 6.7961 0.0060 0.3132 0.9012
[2023/07/09 22:29:16] - 82.00 layer4.1.bn2 86 0.55 7.0985 0.0034 0.3194 0.8976
[2023/07/09 22:29:26] - 107.00 layer4.1.bn2 132 0.6000000000000001 7.2623 0.0026 0.3248 0.8941
[2023/07/09 22:29:36] - 145.00 layer4.1.bn2 60 0.65 7.0780 0.0021 0.5540 0.8201
[2023/07/09 22:29:46] - 204.00 layer4.1.bn2 271 0.7000000000000001 7.1320 0.0011 0.5561 0.8167
[2023/07/09 22:29:56] - 263.00 bn1 61 0.75 6.9168 0.0006 0.6570 0.7844
[2023/07/09 22:30:06] - 365.00 layer2.0.bn2 80 0.8 6.8394 0.0000 0.8824 0.7258
[2023/07/09 22:30:17] - 524.00 layer2.0.bn2 33 0.8500000000000001 6.8990 0.0000 1.2204 0.6177
[2023/07/09 22:30:28] - 753.00 layer3.0.bn2 113 0.9 5.9777 0.0000 2.0017 0.2519
--------------------------------------------------------------------------------
/logs/pruning_by_threshold.txt:
--------------------------------------------------------------------------------
No Layer Name Neuron Idx Mask PoisonLoss PoisonACC CleanLoss CleanACC
12.00 layer4.1.bn2 188 0.0 0.0122 0.9986 0.2104 0.9347
14.00 layer4.1.bn2 12 0.05 0.0139 0.9982 0.2092 0.9340
18.00 layer3.0.bn2 230 0.1 0.0601 0.9876 0.2065 0.9338
21.00 bn1 5 0.15000000000000002 0.9935 0.4796 0.2074 0.9340
24.00 layer4.0.bn2 82 0.2 2.8242 0.0661 0.2297 0.9280
28.00 layer3.0.bn1 106 0.25 3.2791 0.0424 0.2292 0.9270
32.00 layer4.1.bn2 152 0.30000000000000004 4.2908 0.0172 0.2295 0.9277
41.00 layer3.0.bn2 97 0.35000000000000003 5.5764 0.0112 0.2936 0.9094
49.00 layer4.1.bn2 451 0.4 6.3171 0.0070 0.2963 0.9070
57.00 layer4.1.bn2 352 0.45 6.6227 0.0059 0.2897 0.9077
68.00 layer2.0.bn1 50 0.5 6.7961 0.0060 0.3132 0.9012
82.00 layer4.1.bn2 86 0.55 7.0985 0.0034 0.3194 0.8976
107.00 layer4.1.bn2 132 0.6000000000000001 7.2623 0.0026 0.3248 0.8941
145.00 layer4.1.bn2 60 0.65 7.0780 0.0021 0.5540 0.8201
204.00 layer4.1.bn2 271 0.7000000000000001 7.1320 0.0011 0.5561 0.8167
263.00 bn1 61 0.75 6.9168 0.0006 0.6570 0.7844
365.00 layer2.0.bn2 80 0.8 6.8394 0.0000 0.8824 0.7258
524.00 layer2.0.bn2 33 0.8500000000000001 6.8990 0.0000 1.2204 0.6177
753.00 layer3.0.bn2 113 0.9 5.9777 0.0000 2.0017 0.2519
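For downstream analysis, the pruning summary above is plain whitespace-separated text. A hedged snippet for loading it with pandas (already a `main.py` dependency); the short column names are my own stand-ins for the multi-word names in the file header:

```python
import pandas as pd

# one header line, then whitespace-separated rows
cols = ['No', 'LayerName', 'NeuronIdx', 'Mask', 'PoisonLoss', 'PoisonACC', 'CleanLoss', 'CleanACC']
df = pd.read_csv('logs/pruning_by_threshold.txt', sep=r'\s+', skiprows=1, names=cols)
print(df[['Mask', 'PoisonACC', 'CleanACC']])  # attack success vs. clean accuracy per threshold
```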
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import time
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from collections import OrderedDict
import models
# NOTE: the package shipped in this repo is `data`, not `datasets`
from data.poison_tool_cifar import get_backdoor_loader, get_test_loader, get_train_loader

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

seed = 98
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # keep benchmarking off for reproducibility
torch.manual_seed(seed)
np.random.seed(seed)


def train_step_unlearning(args, model, criterion, optimizer, data_loader):
    model.train()
    total_correct = 0
    total_loss = 0.0
    for i, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)

        pred = output.data.max(1)[1]
        total_correct += pred.eq(labels.view_as(pred)).sum()
        total_loss += loss.item()

        # gradient *ascent* on the defense data: maximize the loss to unlearn
        (-loss).backward()
        # clip after backward so the clipping actually sees the gradients
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        optimizer.step()

    loss = total_loss / len(data_loader)
    acc = float(total_correct) / len(data_loader.dataset)
    return loss, acc

def train_step_recovering(args, unlearned_model, criterion, mask_opt, data_loader):
    unlearned_model.train()
    total_correct = 0
    total_loss = 0.0
    nb_samples = 0
    for i, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        nb_samples += images.size(0)

        mask_opt.zero_grad()
        output = unlearned_model(images)
        loss = criterion(output, labels)
        loss = args.alpha * loss

        pred = output.data.max(1)[1]
        total_correct += pred.eq(labels.view_as(pred)).sum()
        total_loss += loss.item()
        loss.backward()
        mask_opt.step()
        clip_mask(unlearned_model)  # keep the neuron masks inside [0, 1]

    loss = total_loss / len(data_loader)
    acc = float(total_correct) / nb_samples
    return loss, acc


def load_state_dict(net, orig_state_dict):
    if 'state_dict' in orig_state_dict.keys():
        orig_state_dict = orig_state_dict['state_dict']

    new_state_dict = OrderedDict()
    for k, v in net.state_dict().items():
        if k in orig_state_dict.keys():
            new_state_dict[k] = orig_state_dict[k]
        else:
            new_state_dict[k] = v
    net.load_state_dict(new_state_dict)


def clip_mask(unlearned_model, lower=0.0, upper=1.0):
    params = [param for name, param in unlearned_model.named_parameters() if 'neuron_mask' in name]
    with torch.no_grad():
        for param in params:
            param.clamp_(lower, upper)


def save_mask_scores(state_dict, file_name):
    mask_values = []
    count = 0
    for name, param in state_dict.items():
        if 'neuron_mask' in name:
            for idx in range(param.size(0)):
                neuron_name = '.'.join(name.split('.')[:-1])
                mask_values.append('{} \t {} \t {} \t {:.4f} \n'.format(count, neuron_name, idx, param[idx].item()))
                count += 1
    with open(file_name, "w") as f:
        f.write('No \t Layer Name \t Neuron Idx \t Mask Score \n')
        f.writelines(mask_values)

def read_data(file_name):
    # parse the whitespace-separated mask file written by save_mask_scores
    tempt = pd.read_csv(file_name, sep=r'\s+', skiprows=1, header=None)
    layer = tempt.iloc[:, 1]
    idx = tempt.iloc[:, 2]
    value = tempt.iloc[:, 3]
    mask_values = list(zip(layer, idx, value))
    return mask_values
def pruning(net, neuron):
    # prune one neuron by zeroing the corresponding BatchNorm scale weight
    state_dict = net.state_dict()
    weight_name = '{}.{}'.format(neuron[0], 'weight')
    state_dict[weight_name][int(neuron[1])] = 0.0
    net.load_state_dict(state_dict)

def test(model, criterion, data_loader):
    model.eval()
    total_correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_loader):
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            total_loss += criterion(output, labels).item()
            pred = output.data.max(1)[1]
            total_correct += pred.eq(labels.data.view_as(pred)).sum()
    loss = total_loss / len(data_loader)
    acc = float(total_correct) / len(data_loader.dataset)
    return loss, acc

def evaluate_by_number(model, logger, mask_values, pruning_max, pruning_step, criterion, clean_loader, poison_loader):
    results = []
    nb_max = int(np.ceil(pruning_max))
    nb_step = int(np.ceil(pruning_step))
    for start in range(0, nb_max + 1, nb_step):
        i = start
        for i in range(start, start + nb_step):
            pruning(model, mask_values[i])
        layer_name, neuron_idx, value = mask_values[i][0], mask_values[i][1], mask_values[i][2]
        cl_loss, cl_acc = test(model=model, criterion=criterion, data_loader=clean_loader)
        po_loss, po_acc = test(model=model, criterion=criterion, data_loader=poison_loader)
        logger.info('{} \t {} \t {} \t {} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}'.format(
            i+1, layer_name, neuron_idx, value, po_loss, po_acc, cl_loss, cl_acc))
        results.append('{} \t {} \t {} \t {} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}\n'.format(
            i+1, layer_name, neuron_idx, value, po_loss, po_acc, cl_loss, cl_acc))
    return results


def evaluate_by_threshold(model, logger, mask_values, pruning_max, pruning_step, criterion, clean_loader, poison_loader):
    results = []
    thresholds = np.arange(0, pruning_max + pruning_step, pruning_step)
    start = 0
    for threshold in thresholds:
        idx = start
        # prune every neuron whose mask score falls below the current threshold
        for idx in range(start, len(mask_values)):
            if float(mask_values[idx][2]) <= threshold:
                pruning(model, mask_values[idx])
                start += 1
            else:
                break
        layer_name, neuron_idx, value = mask_values[idx][0], mask_values[idx][1], mask_values[idx][2]
        cl_loss, cl_acc = test(model=model, criterion=criterion, data_loader=clean_loader)
        po_loss, po_acc = test(model=model, criterion=criterion, data_loader=poison_loader)
        logger.info('{:.2f} \t {} \t {} \t {} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}'.format(
            start, layer_name, neuron_idx, threshold, po_loss, po_acc, cl_loss, cl_acc))
        results.append('{:.2f} \t {} \t {} \t {} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}\n'.format(
            start, layer_name, neuron_idx, threshold, po_loss, po_acc, cl_loss, cl_acc))
    return results

def save_checkpoint(state, file_path):
    # filepath = os.path.join(args.output_dir, args.arch + '-unlearning_epochs{}.tar'.format(epoch))
    torch.save(state, file_path)

def main(args):
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.log_root, 'output.log')),
            logging.StreamHandler()
        ])
    logger.info(args)
    logger.info('----------- Data Initialization --------------')
    defense_data_loader = get_train_loader(args)
    clean_test_loader, bad_test_loader = get_test_loader(args)

    logger.info('----------- Backdoor Model Initialization --------------')
    state_dict = torch.load(args.backdoor_model_path, map_location=device)
    net = getattr(models, args.arch)(num_classes=10, norm_layer=None)
    load_state_dict(net, orig_state_dict=state_dict)
    net = net.to(device)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.unlearning_lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule, gamma=0.1)

    logger.info('----------- Model Unlearning --------------')
    logger.info('Epoch \t lr \t Time \t TrainLoss \t TrainACC \t PoisonLoss \t PoisonACC \t CleanLoss \t CleanACC')
    for epoch in range(0, args.unlearning_epochs + 1):
        start = time.time()
        lr = optimizer.param_groups[0]['lr']
        train_loss, train_acc = train_step_unlearning(args=args, model=net, criterion=criterion, optimizer=optimizer,
                                                      data_loader=defense_data_loader)
        cl_test_loss, cl_test_acc = test(model=net, criterion=criterion, data_loader=clean_test_loader)
        po_test_loss, po_test_acc = test(model=net, criterion=criterion, data_loader=bad_test_loader)
        scheduler.step()
        end = time.time()
        logger.info(
            '%d \t %.3f \t %.1f \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f',
            epoch, lr, end - start, train_loss, train_acc, po_test_loss, po_test_acc,
            cl_test_loss, cl_test_acc)

        if train_acc <= args.clean_threshold:
            # save the last checkpoint once accuracy on defense data drops below the threshold
            file_path = os.path.join(args.output_weight, 'unlearned_model_last.tar')
            # torch.save(net.state_dict(), os.path.join(args.output_dir, 'unlearned_model_last.tar'))
            save_checkpoint({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'clean_acc': cl_test_acc,
                'bad_acc': po_test_acc,
                'optimizer': optimizer.state_dict(),
            }, file_path)
            break


    logger.info('----------- Model Recovering --------------')
    # Step 2: load the unlearned model checkpoint
    if args.unlearned_model_path is not None:
        unlearned_model_path = args.unlearned_model_path
    else:
        unlearned_model_path = os.path.join(args.output_weight, 'unlearned_model_last.tar')

    checkpoint = torch.load(unlearned_model_path, map_location=device)
    print('Unlearned Model:', checkpoint['epoch'], checkpoint['clean_acc'], checkpoint['bad_acc'])

    unlearned_model = getattr(models, args.arch)(num_classes=10, norm_layer=models.MaskBatchNorm2d)
    load_state_dict(unlearned_model, orig_state_dict=checkpoint['state_dict'])
    unlearned_model = unlearned_model.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # only the neuron masks are optimized during recovering
    parameters = list(unlearned_model.named_parameters())
    mask_params = [v for n, v in parameters if "neuron_mask" in n]
    mask_optimizer = torch.optim.SGD(mask_params, lr=args.recovering_lr, momentum=0.9)
    # Recovering
    logger.info('Epoch \t lr \t Time \t TrainLoss \t TrainACC \t PoisonLoss \t PoisonACC \t CleanLoss \t CleanACC')
    for epoch in range(1, args.recovering_epochs + 1):
        start = time.time()
        lr = mask_optimizer.param_groups[0]['lr']
        train_loss, train_acc = train_step_recovering(args=args, unlearned_model=unlearned_model, criterion=criterion,
                                                      data_loader=defense_data_loader, mask_opt=mask_optimizer)
        cl_test_loss, cl_test_acc = test(model=unlearned_model, criterion=criterion, data_loader=clean_test_loader)
        po_test_loss, po_test_acc = test(model=unlearned_model, criterion=criterion, data_loader=bad_test_loader)
        end = time.time()
        logger.info('{} \t {:.3f} \t {:.1f} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}'.format(
            epoch, lr, end - start, train_loss, train_acc, po_test_loss, po_test_acc,
            cl_test_loss, cl_test_acc))
    save_mask_scores(unlearned_model.state_dict(), os.path.join(args.log_root, 'mask_values.txt'))

    del unlearned_model, net
    logger.info('----------- Backdoored Model Pruning --------------')
    # reload the original backdoored model checkpoint for pruning
    state_dict = torch.load(args.backdoor_model_path, map_location=device)
    net = getattr(models, args.arch)(num_classes=10, norm_layer=None)
    load_state_dict(net, orig_state_dict=state_dict)
    net = net.to(device)

    criterion = torch.nn.CrossEntropyLoss().to(device)

    # Step 3: pruning
    if args.mask_file is not None:
        mask_file = args.mask_file
    else:
        mask_file = os.path.join(args.log_root, 'mask_values.txt')

    mask_values = read_data(mask_file)
    mask_values = sorted(mask_values, key=lambda x: float(x[2]))
    logger.info('No. \t Layer Name \t Neuron Idx \t Mask \t PoisonLoss \t PoisonACC \t CleanLoss \t CleanACC')
    cl_loss, cl_acc = test(model=net, criterion=criterion, data_loader=clean_test_loader)
    po_loss, po_acc = test(model=net, criterion=criterion, data_loader=bad_test_loader)
    logger.info('0 \t None \t None \t {:.4f} \t {:.4f} \t {:.4f} \t {:.4f}'.format(po_loss, po_acc, cl_loss, cl_acc))
    if args.pruning_by == 'threshold':
        results = evaluate_by_threshold(
            net, logger, mask_values, pruning_max=args.pruning_max, pruning_step=args.pruning_step,
            criterion=criterion, clean_loader=clean_test_loader, poison_loader=bad_test_loader
        )
    else:
        results = evaluate_by_number(
            net, logger, mask_values, pruning_max=args.pruning_max, pruning_step=args.pruning_step,
            criterion=criterion, clean_loader=clean_test_loader, poison_loader=bad_test_loader
        )
    file_name = os.path.join(args.log_root, 'pruning_by_{}.txt'.format(args.pruning_by))
    with open(file_name, "w") as f:
        f.write('No \t Layer Name \t Neuron Idx \t Mask \t PoisonLoss \t PoisonACC \t CleanLoss \t CleanACC\n')
        f.writelines(results)
if __name__ == '__main__':
    # Prepare arguments
    parser = argparse.ArgumentParser()

    # various paths
    parser.add_argument('--cuda', type=int, default=1, help='cuda available')
    parser.add_argument('--save-every', type=int, default=5, help='save checkpoints every few epochs')
    parser.add_argument('--log_root', type=str, default='logs/', help='logs are saved here')
    parser.add_argument('--output_weight', type=str, default='weights/')
    parser.add_argument('--backdoor_model_path', type=str,
                        default='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar',
                        help='path of backdoored model')
    parser.add_argument('--unlearned_model_path', type=str,
                        default=None, help='path of unlearned backdoored model')
    parser.add_argument('--arch', type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'MobileNetV2',
                                 'vgg19_bn'])
    parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20],
                        help='Decrease learning rate at these epochs.')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='name of image dataset')
    parser.add_argument('--batch_size', type=int, default=128, help='The size of batch')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--num_class', type=int, default=10, help='number of classes')
    parser.add_argument('--ratio', type=float, default=0.01, help='ratio of defense data')

    # backdoor attacks
    parser.add_argument('--target_label', type=int, default=0, help='class of target label')
    parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger')
    parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label')
    parser.add_argument('--trig_w', type=int, default=3, help='width of trigger pattern')
    parser.add_argument('--trig_h', type=int, default=3, help='height of trigger pattern')

    # RNP
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--clean_threshold', type=float, default=0.20, help='threshold of unlearning accuracy')
    parser.add_argument('--unlearning_lr', type=float, default=0.01, help='the learning rate for neuron unlearning')
    parser.add_argument('--recovering_lr', type=float, default=0.2, help='the learning rate for mask optimization')
    parser.add_argument('--unlearning_epochs', type=int, default=20, help='the number of epochs for unlearning')
    parser.add_argument('--recovering_epochs', type=int, default=20, help='the number of epochs for recovering')
    parser.add_argument('--mask_file', type=str, default=None, help='The text file containing the mask values')
    parser.add_argument('--pruning-by', type=str, default='threshold', choices=['number', 'threshold'])
    parser.add_argument('--pruning-max', type=float, default=0.90, help='the maximum number/threshold for pruning')
    parser.add_argument('--pruning-step', type=float, default=0.05, help='the step size for evaluating the pruning')

    args = parser.parse_args()
    args_dict = vars(args)
    print(args_dict)
    os.makedirs(args.log_root, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    main(args)

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .resnet_cifar import *
from .vgg_cifar import *
from .mobilenetv2 import *
from .mask_batchnorm import *
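`main.py` optimizes the parameters named `neuron_mask` that live inside `models.MaskBatchNorm2d`, but `models/mask_batchnorm.py` itself is not reproduced in this dump. Below is a minimal sketch of such a mask-augmented BatchNorm layer, written only to illustrate the idea; the actual implementation in the repo may differ in its details:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskBatchNorm2d(nn.BatchNorm2d):
    """Sketch of a BatchNorm2d whose affine parameters are gated by a
    learnable per-neuron mask (ASSUMED layout; see models/mask_batchnorm.py).

    During recovering, main.py optimizes only `neuron_mask` and calls
    clip_mask() to keep its values in [0, 1].
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.neuron_mask = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        # normalize without the affine transform...
        xn = F.batch_norm(x, self.running_mean, self.running_var, None, None,
                          self.training, self.momentum, self.eps)
        # ...then apply mask-gated scale and shift per channel
        w = (self.weight * self.neuron_mask)[None, :, None, None]
        b = (self.bias * self.neuron_mask)[None, :, None, None]
        return xn * w + b
```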
/models/blocks.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | 
4 | class Conv2dBlock(nn.Module):
5 |     def __init__(self, in_c, out_c, ker_size=(3, 3), stride=1, padding=1, batch_norm=True, relu=True):
6 |         super(Conv2dBlock, self).__init__()
7 |         self.conv2d = nn.Conv2d(in_c, out_c, ker_size, stride, padding)
8 |         if batch_norm:
9 |             self.batch_norm = nn.BatchNorm2d(out_c, eps=1e-5, momentum=0.05, affine=True)
10 |         if relu:
11 |             self.relu = nn.ReLU(inplace=True)
12 | 
13 |     def forward(self, x):
14 |         for module in self.children():
15 |             x = module(x)
16 |         return x
17 | 
18 | 
19 | class DownSampleBlock(nn.Module):
20 |     def __init__(self, ker_size=(2, 2), stride=2, dilation=(1, 1), ceil_mode=False, p=0.0):
21 |         super(DownSampleBlock, self).__init__()
22 |         self.maxpooling = nn.MaxPool2d(kernel_size=ker_size, stride=stride,
23 |                                        dilation=dilation, ceil_mode=ceil_mode)
24 |         if p:
25 |             self.dropout = nn.Dropout(p)
26 | 
27 |     def forward(self, x):
28 |         for module in self.children():
29 |             x = module(x)
30 |         return x
31 | 
32 | 
33 | class UpSampleBlock(nn.Module):
34 |     def __init__(self, scale_factor=(2, 2), mode="bilinear", p=0.0):
35 |         super(UpSampleBlock, self).__init__()
36 |         self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
37 |         if p:
38 |             self.dropout = nn.Dropout(p)
39 | 
40 |     def forward(self, x):
41 |         for module in self.children():
42 |             x = module(x)
43 |         return x
44 | 
--------------------------------------------------------------------------------
/models/dynamic_models.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | from torch import nn 5 | from torchvision import transforms 6 | 7 | from .blocks import * 8 | 9 | 10 | class Normalize: 11 | def __init__(self, opt, expected_values, variance): 12 | self.n_channels = opt.input_channel 13 | self.expected_values = expected_values 14 | self.variance = variance 15 | assert self.n_channels == len(self.expected_values) 16 | 17 | def __call__(self, x): 18 | x_clone = x.clone() 19 | for channel in range(self.n_channels): 20 | x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel] 21 | return x_clone 22 | 23 | 24 | class Denormalize: 25 | def __init__(self, opt, expected_values, variance): 26 | self.n_channels = opt.input_channel 27 | self.expected_values = expected_values 28 | self.variance = variance 29 | assert self.n_channels == len(self.expected_values) 30 | 31 | def __call__(self, x): 32 | x_clone = x.clone() 33 | for channel in range(self.n_channels): 34 | x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel] 35 | return x_clone 36 | 37 | 38 | # ---------------------------- Generators ----------------------------# 39 | 40 | 41 | class Generator(nn.Sequential): 42 | def __init__(self, opt, out_channels=None): 43 | super(Generator, self).__init__() 44 | if opt.dataset == "mnist": 45 | channel_init = 16 46 | steps = 2 47 | else: 48 | channel_init = 32 49 | steps = 3 50 | 51 | channel_current = opt.input_channel 52 | channel_next = channel_init 53 | for step in range(steps): 54 | self.add_module("convblock_down_{}".format(2 * step), Conv2dBlock(channel_current, channel_next)) 55 | self.add_module("convblock_down_{}".format(2 * step + 1), Conv2dBlock(channel_next, channel_next)) 56 | self.add_module("downsample_{}".format(step), DownSampleBlock()) 57 | if step < steps - 1: 58 | channel_current = channel_next 59 | channel_next *= 2 60 | 61 | self.add_module("convblock_middle", Conv2dBlock(channel_next, channel_next)) 62 | 63 | channel_current = channel_next 64 | channel_next = channel_current // 2 65 | for step in range(steps): 66 | self.add_module("upsample_{}".format(step), UpSampleBlock()) 67 | self.add_module("convblock_up_{}".format(2 * step), Conv2dBlock(channel_current, channel_current)) 68 | if step == steps - 1: 69 | self.add_module( 70 | "convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next, relu=False) 71 | ) 72 | else: 73 | self.add_module("convblock_up_{}".format(2 * step + 1), Conv2dBlock(channel_current, channel_next)) 74 | channel_current = channel_next 75 | channel_next = channel_next // 2 76 | if step == steps - 2: 77 | if out_channels is None: 78 | channel_next = opt.input_channel 79 | else: 80 | channel_next = out_channels 81 | 82 | self._EPSILON = 1e-7 83 | self._normalizer = self._get_normalize(opt) 84 | self._denormalizer = self._get_denormalize(opt) 85 | 86 | def _get_denormalize(self, opt): 87 | if opt.dataset == "cifar10": 88 | denormalizer = Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]) 89 | elif opt.dataset == "mnist": 90 | denormalizer = Denormalize(opt, [0.5], [0.5]) 91 | elif opt.dataset == "gtsrb": 92 | denormalizer = None 93 | else: 94 | raise Exception("Invalid dataset") 95 | return denormalizer 96 | 97 | def _get_normalize(self, opt): 98 | if opt.dataset == "cifar10": 99 | normalizer = Normalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 
0.261]) 100 | elif opt.dataset == "mnist": 101 | normalizer = Normalize(opt, [0.5], [0.5]) 102 | elif opt.dataset == "gtsrb": 103 | normalizer = None 104 | else: 105 | raise Exception("Invalid dataset") 106 | return normalizer 107 | 108 | def forward(self, x): 109 | for module in self.children(): 110 | x = module(x) 111 | x = nn.Tanh()(x) / (2 + self._EPSILON) + 0.5 112 | return x 113 | 114 | def normalize_pattern(self, x): 115 | if self._normalizer: 116 | x = self._normalizer(x) 117 | return x 118 | 119 | def denormalize_pattern(self, x): 120 | if self._denormalizer: 121 | x = self._denormalizer(x) 122 | return x 123 | 124 | def threshold(self, x): 125 | return nn.Tanh()(x * 20 - 10) / (2 + self._EPSILON) + 0.5 126 | 127 | 128 | # ---------------------------- Classifiers ----------------------------# 129 | 130 | 131 | class NetC_MNIST(nn.Module): 132 | def __init__(self): 133 | super(NetC_MNIST, self).__init__() 134 | self.conv1 = nn.Conv2d(1, 32, (5, 5), 1, 0) 135 | self.relu2 = nn.ReLU(inplace=True) 136 | self.dropout3 = nn.Dropout(0.1) 137 | 138 | self.maxpool4 = nn.MaxPool2d((2, 2)) 139 | self.conv5 = nn.Conv2d(32, 64, (5, 5), 1, 0) 140 | self.relu6 = nn.ReLU(inplace=True) 141 | self.dropout7 = nn.Dropout(0.1) 142 | 143 | self.maxpool5 = nn.MaxPool2d((2, 2)) 144 | self.flatten = nn.Flatten() 145 | self.linear6 = nn.Linear(64 * 4 * 4, 512) 146 | self.relu7 = nn.ReLU(inplace=True) 147 | self.dropout8 = nn.Dropout(0.1) 148 | self.linear9 = nn.Linear(512, 10) 149 | 150 | def forward(self, x): 151 | for module in self.children(): 152 | x = module(x) 153 | return x 154 | -------------------------------------------------------------------------------- /models/mask_batchnorm.py: -------------------------------------------------------------------------------- 1 | # This code is based on: 2 | # https://github.com/csdongxian/ANP_backdoor/blob/main/models/anp_batchnorm.py 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | class MaskBatchNorm2d(nn.BatchNorm2d): 13 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 14 | track_running_stats=True): 15 | super(MaskBatchNorm2d, self).__init__( 16 | num_features, eps, momentum, affine, track_running_stats) 17 | self.neuron_mask = Parameter(torch.Tensor(num_features)) 18 | self.neuron_noise = Parameter(torch.Tensor(num_features)) 19 | self.neuron_noise_bias = Parameter(torch.Tensor(num_features)) 20 | init.ones_(self.neuron_mask) 21 | 22 | def forward(self, input: Tensor) -> Tensor: 23 | self._check_input_dim(input) 24 | 25 | # exponential_average_factor is set to self.momentum 26 | # (when it is available) only so that it gets updated 27 | # in ONNX graph when this node is exported to ONNX. 
28 | if self.momentum is None: 29 | exponential_average_factor = 0.0 30 | else: 31 | exponential_average_factor = self.momentum 32 | 33 | if self.training and self.track_running_stats: 34 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 35 | if self.num_batches_tracked is not None: # type: ignore 36 | self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore 37 | if self.momentum is None: # use cumulative moving average 38 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 39 | else: # use exponential moving average 40 | exponential_average_factor = self.momentum 41 | 42 | r""" 43 | Decide whether the mini-batch stats should be used for normalization rather than the buffers. 44 | Mini-batch stats are used in training mode, and in eval mode when buffers are None. 45 | """ 46 | if self.training: 47 | bn_training = True 48 | else: 49 | bn_training = (self.running_mean is None) and (self.running_var is None) 50 | 51 | r""" 52 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be 53 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are 54 | used for normalization (i.e. in eval mode when buffers are not None). 55 | """ 56 | assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) 57 | assert self.running_var is None or isinstance(self.running_var, torch.Tensor) 58 | 59 | coeff_weight = self.neuron_mask 60 | coeff_bias = 1.0 61 | 62 | return F.batch_norm( 63 | input, 64 | # If buffers are not to be tracked, ensure that they won't be updated 65 | self.running_mean if not self.training or self.track_running_stats else None, 66 | self.running_var if not self.training or self.track_running_stats else None, 67 | self.weight * coeff_weight, self.bias * coeff_bias, 68 | bn_training, exponential_average_factor, self.eps) 69 | 70 | 71 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | See the paper "Inverted Residuals and Linear Bottlenecks: 3 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''expand + depthwise + pointwise''' 12 | def __init__(self, in_planes, out_planes, expansion, stride, norm_layer): 13 | super(Block, self).__init__() 14 | self.stride = stride 15 | 16 | planes = expansion * in_planes 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn1 = norm_layer(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 20 | self.bn2 = norm_layer(planes) 21 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 22 | self.bn3 = norm_layer(out_planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride == 1 and in_planes != out_planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 28 | norm_layer(out_planes), 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out = out + self.shortcut(x) if self.stride==1 else out 36 | return out 37 | 38 | 39 | class MobileNetV2(nn.Module): 40 | # (expansion, out_planes, num_blocks, stride) 41 | cfg = [(1, 16, 1, 1), 42 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 43 | (6, 32, 3, 2), 44 | (6, 64, 4, 2), 45 | (6, 96, 3, 1), 46 | (6, 160, 3, 2), 47 | (6, 320, 1, 1)] 48 | 49 | def __init__(self, num_classes=10, norm_layer=nn.BatchNorm2d): 50 | super(MobileNetV2, self).__init__() 51 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 52 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 53 | self.bn1 = norm_layer(32) 54 | self.layers = self._make_layers(in_planes=32, norm_layer=norm_layer) 55 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 56 | self.bn2 = norm_layer(1280) 57 | self.linear = nn.Linear(1280, num_classes) 58 | 59 | def _make_layers(self, in_planes, norm_layer): 60 | layers = [] 61 | for expansion, out_planes, num_blocks, stride in self.cfg: 62 | strides = [stride] + [1]*(num_blocks-1) 63 | for stride in strides: 64 | layers.append(Block(in_planes, out_planes, expansion, stride, norm_layer)) 65 | in_planes = out_planes 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | out = F.relu(self.bn1(self.conv1(x))) 70 | out = self.layers(out) 71 | out = F.relu(self.bn2(self.conv2(out))) 72 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 73 | out = F.avg_pool2d(out, 4) 74 | out = out.view(out.size(0), -1) 75 | out = self.linear(out) 76 | return out 77 | 78 | 79 | def test(): 80 | net = MobileNetV2() 81 | x = torch.randn(2,3,32,32) 82 | y = net(x) 83 | print(y.size()) 84 | 85 | # test() -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # This code is modified by 2 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1, norm_layer=None): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d( 14 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | self.bn1 = norm_layer(planes) 
16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 17 | stride=1, padding=1, bias=False) 18 | self.bn2 = norm_layer(planes) 19 | self.relu = nn.ReLU() 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, 25 | kernel_size=1, stride=stride, bias=False), 26 | norm_layer(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = self.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = self.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1, norm_layer=None): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = norm_layer(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 45 | stride=stride, padding=1, bias=False) 46 | self.bn2 = norm_layer(planes) 47 | self.conv3 = nn.Conv2d(planes, self.expansion * 48 | planes, kernel_size=1, bias=False) 49 | self.bn3 = norm_layer(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, 55 | kernel_size=1, stride=stride, bias=False), 56 | norm_layer(self.expansion*planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, num_blocks, num_classes=10, norm_layer=None): 70 | super(ResNet, self).__init__() 71 | if norm_layer is None: 72 | self._norm_layer = nn.BatchNorm2d 73 | else: 74 | self._norm_layer = norm_layer 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 78 | stride=1, padding=1, bias=False) 79 | self.bn1 = self._norm_layer(64) 80 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 81 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 82 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 83 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 84 | self.linear = nn.Linear(512*block.expansion, num_classes) 85 | 86 | def _make_layer(self, block, planes, num_blocks, stride): 87 | strides = [stride] + [1]*(num_blocks-1) 88 | layers = [] 89 | for stride in strides: 90 | layers.append(block(self.in_planes, planes, stride, self._norm_layer)) 91 | self.in_planes = planes * block.expansion 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = self.layer1(out) 97 | out = self.layer2(out) 98 | out = self.layer3(out) 99 | out = self.layer4(out) 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | out = self.linear(out) 103 | return out 104 | 105 | 106 | def resnet18(num_classes=10, norm_layer=nn.BatchNorm2d): 107 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, norm_layer) 108 | 109 | 110 | def resnet34(num_classes=10, norm_layer=nn.BatchNorm2d): 111 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, norm_layer) 112 | 113 | 114 | def resnet50(num_classes=10, norm_layer=nn.BatchNorm2d): 115 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, norm_layer) 116 | 117 | 118 | def 
resnet101(num_classes=10, norm_layer=nn.BatchNorm2d): 119 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, norm_layer) 120 | 121 | 122 | def resnet152(num_classes=10, norm_layer=nn.BatchNorm2d): 123 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes, norm_layer) 124 | 125 | 126 | def test(): 127 | net = resnet18() 128 | y = net(torch.randn(1, 3, 32, 32)) 129 | print(y.size()) 130 | 131 | # test() -------------------------------------------------------------------------------- /models/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=10): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.classifier = nn.Linear(512, num_classes) 29 | self._initialize_weights() 30 | 31 | def forward(self, x): 32 | x = self.features(x) 33 | x = x.view(x.size(0), -1) 34 | x = self.classifier(x) 35 | return x 36 | 37 | def _initialize_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | elif isinstance(m, nn.BatchNorm2d): 45 | m.weight.data.fill_(1) 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.Linear): 48 | n = m.weight.size(1) 49 | m.weight.data.normal_(0, 0.01) 50 | m.bias.data.zero_() 51 | 52 | 53 | def make_layers(cfg, batch_norm=False, norm_layer=None): 54 | if norm_layer is None: 55 | norm_layer = nn.BatchNorm2d 56 | layers = [] 57 | in_channels = 3 58 | for v in cfg: 59 | if v == 'M': 60 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 61 | else: 62 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 63 | if batch_norm: 64 | layers += [conv2d, norm_layer(v), nn.ReLU(inplace=True)] 65 | else: 66 | layers += [conv2d, nn.ReLU(inplace=True)] 67 | in_channels = v 68 | return nn.Sequential(*layers) 69 | 70 | 71 | cfg = { 72 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 73 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 74 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 75 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 76 | } 77 | 78 | 79 | def vgg11(norm_layer=nn.BatchNorm2d, **kwargs): 80 | """VGG 11-layer model (configuration "A") 81 | Args: 82 | pretrained (bool): If True, returns a model pre-trained on ImageNet 83 | """ 84 | model = VGG(make_layers(cfg['A'], norm_layer=norm_layer), **kwargs) 85 | return model 86 | 87 | 88 | def vgg11_bn(norm_layer=nn.BatchNorm2d, **kwargs): 89 | """VGG 11-layer model (configuration "A") with batch normalization""" 90 | model = VGG(make_layers(cfg['A'], batch_norm=True, norm_layer=norm_layer), **kwargs) 91 | return model 92 | 93 | 94 | def vgg13(norm_layer=nn.BatchNorm2d, **kwargs): 95 | """VGG 13-layer model (configuration "B") 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | model = VGG(make_layers(cfg['B'], norm_layer=norm_layer), **kwargs) 100 | return model 101 | 102 | 103 | def vgg13_bn(norm_layer=nn.BatchNorm2d, **kwargs): 104 | """VGG 13-layer model (configuration "B") with batch normalization""" 105 | model = VGG(make_layers(cfg['B'], batch_norm=True, norm_layer=norm_layer), **kwargs) 106 | return model 107 | 108 | 109 | def vgg16(norm_layer=nn.BatchNorm2d, **kwargs): 110 | """VGG 16-layer model (configuration "D") 111 | Args: 112 | pretrained (bool): If True, returns a model pre-trained on ImageNet 113 | """ 114 | model = VGG(make_layers(cfg['D'], norm_layer=norm_layer), **kwargs) 115 | return model 116 | 117 | 118 | def vgg16_bn(norm_layer=nn.BatchNorm2d, **kwargs): 119 | """VGG 16-layer model (configuration "D") with batch normalization""" 120 | model = VGG(make_layers(cfg['D'], batch_norm=True, norm_layer=norm_layer), **kwargs) 121 | return model 122 | 123 | 124 | def vgg19(norm_layer=nn.BatchNorm2d, **kwargs): 125 | """VGG 19-layer model (configuration "E") 126 | Args: 127 | pretrained (bool): If True, returns a model pre-trained on ImageNet 128 | """ 129 | model = VGG(make_layers(cfg['E'], norm_layer=norm_layer), **kwargs) 130 | return model 131 | 132 | 133 | def vgg19_bn(norm_layer=nn.BatchNorm2d, **kwargs): 134 | """VGG 19-layer model (configuration 'E') with batch normalization""" 135 | model = VGG(make_layers(cfg['E'], batch_norm=True, norm_layer=norm_layer), **kwargs) 136 | return model 137 | 138 | 139 | if __name__ == '__main__': 140 | import torch 141 | x = torch.randn((5, 3, 32, 32)) 142 | # net1 = 
vgg19_bn(norm_layer=nn.BatchNorm1d)
143 |     net2 = vgg19(norm_layer=nn.BatchNorm1d)
144 |     # y1 = net1(x)
145 |     y2 = net2(x)
146 | 
147 | 
--------------------------------------------------------------------------------
/train_backdoor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import logging
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import pandas as pd
9 | from collections import OrderedDict
10 | import models
11 | from data.poison_tool_cifar import get_backdoor_loader, get_test_loader, get_train_loader
12 | 
13 | if torch.cuda.is_available():
14 |     torch.backends.cudnn.enabled = True
15 |     torch.backends.cudnn.benchmark = True  # note: overridden to False below for reproducibility
16 |     device = torch.device('cuda')
17 | else:
18 |     device = torch.device('cpu')
19 | 
20 | seed = 98
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 | torch.manual_seed(seed)
24 | np.random.seed(seed)
25 | 
26 | 
27 | def train_step(args, model, criterion, optimizer, data_loader):
28 |     model.train()
29 |     total_correct = 0
30 |     total_loss = 0.0
31 |     for i, (images, labels) in enumerate(data_loader):
32 |         images, labels = images.to(device), labels.to(device)
33 |         optimizer.zero_grad()
34 |         output = model(images)
35 |         loss = criterion(output, labels)
36 | 
37 |         pred = output.data.max(1)[1]
38 |         total_correct += pred.eq(labels.view_as(pred)).sum()
39 |         total_loss += loss.item()
40 | 
41 |         loss.backward()
42 |         nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)  # clip after backward so the current gradients are clipped
43 |         optimizer.step()
44 | 
45 |     loss = total_loss / len(data_loader)
46 |     acc = float(total_correct) / len(data_loader.dataset)
47 |     return loss, acc
48 | 
49 | 
50 | def test(model, criterion, data_loader):
51 |     model.eval()
52 |     total_correct = 0
53 |     total_loss = 0.0
54 |     with torch.no_grad():
55 |         for i, (images, labels) in enumerate(data_loader):
56 |             images, labels = images.to(device), labels.to(device)
57 |             output = model(images)
58 |             total_loss += criterion(output, labels).item()
59 |             pred = output.data.max(1)[1]
60 |             total_correct += pred.eq(labels.data.view_as(pred)).sum()
61 |     loss = total_loss / len(data_loader)
62 |     acc = float(total_correct) / len(data_loader.dataset)
63 |     return loss, acc
64 | 
65 | 
66 | def save_checkpoint(state, file_path):
67 |     # filepath = os.path.join(args.output_dir, args.arch + '-unlearning_epochs{}.tar'.format(epoch))
68 |     torch.save(state, file_path)
69 | 
70 | def main(args):
71 |     logger = logging.getLogger(__name__)
72 |     logging.basicConfig(
73 |         format='[%(asctime)s] - %(message)s',
74 |         datefmt='%Y/%m/%d %H:%M:%S',
75 |         level=logging.DEBUG,
76 |         handlers=[
77 |             logging.FileHandler(os.path.join(args.log_root, 'output.log')),
78 |             logging.StreamHandler()
79 |         ])
80 |     logger.info(args)
81 | 
82 |     logger.info('----------- Backdoored Data Initialization --------------')
83 |     _, backdoor_data_loader = get_backdoor_loader(args)
84 |     clean_test_loader, bad_test_loader = get_test_loader(args)
85 | 
86 |     logger.info('----------- Backdoor Model Initialization --------------')
87 |     net = getattr(models, args.arch)(num_classes=10, norm_layer=None)
88 |     net = net.to(device)
89 | 
90 |     criterion = torch.nn.CrossEntropyLoss().to(device)
91 |     optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
92 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule, gamma=0.1)
93 | 
94 |     logger.info('----------- Backdoor Model Training --------------')
95 |     logger.info('Epoch \t lr \t Time \t TrainLoss \t TrainACC \t PoisonLoss \t PoisonACC \t CleanLoss \t CleanACC')
96 |     for epoch in range(0, args.epochs + 1):
97 |         start = time.time()
98 |         lr = optimizer.param_groups[0]['lr']
99 |         train_loss, train_acc = train_step(args=args, model=net, criterion=criterion, optimizer=optimizer,
100 |                                            data_loader=backdoor_data_loader)
101 |         cl_test_loss, cl_test_acc = test(model=net, criterion=criterion, data_loader=clean_test_loader)
102 |         po_test_loss, po_test_acc = test(model=net, criterion=criterion, data_loader=bad_test_loader)
103 |         scheduler.step()
104 |         end = time.time()
105 |         logger.info(
106 |             '%d \t %.3f \t %.1f \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f \t %.4f',
107 |             epoch, lr, end - start, train_loss, train_acc, po_test_loss, po_test_acc,
108 |             cl_test_loss, cl_test_acc)
109 | 
110 |         if epoch % args.interval == 0:
111 |             # save a checkpoint every args.interval epochs (overwrites the previous one)
112 |             file_path = os.path.join(args.output_weight, 'backdoor_model.tar')
113 |             save_checkpoint({
114 |                 'epoch': epoch,
115 |                 'state_dict': net.state_dict(),
116 |                 'clean_acc': cl_test_acc,
117 |                 'bad_acc': po_test_acc,
118 |                 'optimizer': optimizer.state_dict(),
119 |             }, file_path)
120 | 
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     # Prepare arguments
125 |     parser = argparse.ArgumentParser()
126 | 
127 |     # various path
128 |     parser.add_argument('--cuda', type=int, default=1, help='cuda available')
129 |     parser.add_argument('--save-every', type=int, default=5, help='save checkpoints every few epochs')
130 |     parser.add_argument('--log_root', type=str, default='logs/', help='logs are saved here')
131 |     parser.add_argument('--output_weight', type=str, default='weights/')
132 |     parser.add_argument('--arch', type=str, default='resnet18',
133 |                         choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'MobileNetV2',
134 |                                  'vgg19_bn'])
135 |     parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20],
136 |                         help='Decrease learning rate at these epochs.')
137 |     parser.add_argument('--dataset', type=str, default='CIFAR10', help='name of image dataset')
138 |     parser.add_argument('--batch_size', type=int, default=128, help='the size of a training batch')
139 |     parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
140 |     parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
141 |     parser.add_argument('--num_class', type=int, default=10, help='number of classes')
142 |     parser.add_argument('--lr', type=float, default=0.1, help='the initial learning rate for backdoor training')
143 |     parser.add_argument('--epochs', type=int, default=60, help='the number of epochs for training')
144 |     parser.add_argument('--interval', type=int, default=10, help='the interval of saving weights')
145 | 
146 |     # backdoor attacks
147 |     parser.add_argument('--target_label', type=int, default=0, help='class of target label')
148 |     parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger')
149 |     parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label')
150 |     parser.add_argument('--trig_w', type=int, default=3, help='width of trigger pattern')
151 |     parser.add_argument('--trig_h', type=int, default=3, help='height of trigger pattern')
152 |     parser.add_argument('--inject_portion', type=float, default=0.1, help='ratio of backdoor poisoned data')
153 | 
154 | 
155 |     args = parser.parse_args()
156 |     args_dict = vars(args)
157 |     print(args_dict)
158 |     os.makedirs(args.log_root, exist_ok=True)
159 |     device = 'cuda' if
torch.cuda.is_available() else 'cpu' 160 | 161 | main(args) 162 | -------------------------------------------------------------------------------- /trigger/best_square_trigger_cifar10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bboylyg/RNP/eeae192e5eab974d8b3002964cfb62d00388d36f/trigger/best_square_trigger_cifar10.npz -------------------------------------------------------------------------------- /trigger/signal_cifar10_mask.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bboylyg/RNP/eeae192e5eab974d8b3002964cfb62d00388d36f/trigger/signal_cifar10_mask.npy --------------------------------------------------------------------------------
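The two files under `trigger/` are NumPy assets, presumably consumed by the poisoning utilities in `data/poison_tool_cifar.py`. A quick way to inspect them before training; the key names inside the `.npz` archive are not documented in this dump, so they are listed rather than assumed:

```python
import numpy as np

# List the arrays stored in the optimized square-trigger archive.
trigger_archive = np.load('trigger/best_square_trigger_cifar10.npz')
print(trigger_archive.files)

# The signal (SIG) attack mask is a plain array; print its shape and dtype.
signal_mask = np.load('trigger/signal_cifar10_mask.npy')
print(signal_mask.shape, signal_mask.dtype)
```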