├── README.md
├── perturbations_are_linear_separable
│   ├── README.md
│   └── test_linear_separability.py
└── synthetic_perturbations
    ├── README.md
    ├── cifar_train.py
    ├── models
    │   ├── DenseNet.py
    │   ├── ResNet.py
    │   └── __init__.py
    └── util.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Availability-Attacks-Create-Shortcuts

Code for the KDD 2022 Research Track paper ["Availability Attacks Create Shortcuts"](https://arxiv.org/abs/2111.00898).

This code is tested with Python 3 and PyTorch 1.8.

Please first install the following packages:

    scikit-learn, torch, numpy

The code for testing the linear separability of existing availability attacks is in the `perturbations_are_linear_separable` folder. The code for testing synthetic shortcuts as availability attacks is in the `synthetic_perturbations` folder.

--------------------------------------------------------------------------------
/perturbations_are_linear_separable/README.md:
--------------------------------------------------------------------------------
You can use this code to verify that the perturbations of existing indiscriminate poisoning attacks are linearly separable.

The [clean](https://drive.google.com/drive/folders/1NpXyJozirOSJ5bXBSeK7rtx9kBA6VttE) and [perturbed](https://drive.google.com/drive/folders/1OD54_gK6wnhyVwQGnHs7vIsKVOL-48zd) data of [NTGA](https://github.com/lionelmessi6410/ntga) can be downloaded directly. Note that you need to subtract the clean data from the perturbed data to get the perturbations.

For [DeepConfuse](https://github.com/kingfengji/DeepConfuse), [error-minimizing noise](https://github.com/HanxunH/Unlearnable-Examples), and [error-maximizing noise](https://github.com/lhfowl/adversarial_poisons), the perturbed data has to be generated manually with their official implementations. We have run their code and provide the results [here](https://drive.google.com/file/d/1v9mAzowQ1GVxjTWZfhjsICLdGyLfiWOY/view?usp=sharing).
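Extracting the perturbations is a single NumPy subtraction. A minimal sketch, assuming the NTGA files named below have been downloaded into the working directory:

    import numpy as np

    perturbed_x = np.load('x_train_cifar10_ntga_cnn_best.npy')
    clean_x = np.load('x_train_cifar10.npy')
    perturbations = perturbed_x - clean_x  # this is what the linear model is fit on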
If you have downloaded `x_train_cifar10_ntga_cnn_best.npy`, `x_train_cifar10.npy`, and `y_train_cifar10.npy`, run the following command to check the accuracy of a linear model:

    python test_linear_separability.py --hidden_layers 0 --perturbed_x_path x_train_cifar10_ntga_cnn_best.npy --clean_x_path x_train_cifar10.npy --label_path y_train_cifar10.npy

--------------------------------------------------------------------------------
/perturbations_are_linear_separable/test_linear_separability.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim

import argparse

import numpy as np


def normalize_01(tsr):
    # rescale a tensor to the [0, 1] range
    maxv = torch.max(tsr)
    minv = torch.min(tsr)
    return (tsr - minv) / (maxv - minv)


def train(train_data, train_targets, net, optimizer):
    inputs, targets = train_data, train_targets

    # L-BFGS requires a closure that re-evaluates the model and the loss
    def closure():
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_func(outputs, targets)
        loss.backward()
        return loss

    optimizer.step(closure)
    with torch.no_grad():
        outputs = net(inputs)
        loss = loss_func(outputs, targets)

    train_loss = loss.item()
    _, predicted = torch.max(outputs.data, 1)
    total = targets.size(0)
    correct = predicted.eq(targets.data).float().cpu().sum()
    acc = 100. * float(correct) / float(total)
    return (train_loss, acc)


parser = argparse.ArgumentParser(description='Fit perturbations with simple models')
parser.add_argument('--perturbed_x_path', default='x_train_cifar10_ntga_cnn_best.npy', type=str, help='path of perturbed data')
parser.add_argument('--clean_x_path', default='x_train_cifar10.npy', type=str, help='path of clean data')
parser.add_argument('--label_path', default='y_train_cifar10.npy', type=str, help='path of labels')
parser.add_argument('--hidden_layers', default=0, type=int, help='number of hidden layers')

args = parser.parse_args()


perturbed_x = np.load(args.perturbed_x_path)
clean_x = np.load(args.clean_x_path)
labels = np.load(args.label_path)
if len(labels.shape) > 1:  # labels are in one-hot format
    labels = np.argmax(labels, axis=1)

# the perturbation is the difference between perturbed and clean data
perturbations = perturbed_x - clean_x

perturbations = torch.tensor(perturbations, dtype=torch.float).cuda()
labels = torch.tensor(labels, dtype=torch.long).cuda()


loss_func = nn.CrossEntropyLoss()
train_data = normalize_01(perturbations)
train_targets = labels

num_classes = 10  # CIFAR-10 dataset

# with --hidden_layers 0 the model is a single linear layer, so high
# training accuracy means the perturbations are linearly separable
module_list = [nn.Flatten()]
input_dim = int(np.prod(train_data.shape[1:]))

hidden_width = 30
for i in range(args.hidden_layers):
    module_list.append(nn.Linear(input_dim, hidden_width))
    module_list.append(nn.Tanh())
    input_dim = hidden_width

module_list += [nn.Linear(input_dim, num_classes)]

net = nn.Sequential(*module_list)
net = net.cuda()
optimizer = optim.LBFGS(net.parameters(), lr=0.5)

for step in range(50):
    train_loss, train_acc = train(train_data, train_targets, net, optimizer)
    print('training loss: %.3f' % train_loss, 'training acc: %.2f' % train_acc)
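

# Optional CPU cross-check (a sketch, not part of the original script):
# scikit-learn's multinomial logistic regression is also a linear model, so
# it should likewise fit the flattened perturbations almost perfectly when
# they are linearly separable.
from sklearn.linear_model import LogisticRegression

flat = train_data.reshape(train_data.shape[0], -1).cpu().numpy()
y = train_targets.cpu().numpy()
clf = LogisticRegression(max_iter=200).fit(flat, y)
print('sklearn linear-model training acc: %.4f' % clf.score(flat, y))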
--------------------------------------------------------------------------------
/synthetic_perturbations/README.md:
--------------------------------------------------------------------------------
Here are some example commands to test synthetic shortcuts as availability attacks.

To test synthetic noise on CIFAR-10 and ResNet18 without data augmentation:

    CUDA_VISIBLE_DEVICES=0 python cifar_train.py --model resnet18 --dataset c10

To test synthetic noise on CIFAR-10 and ResNet18 with data augmentation:

    CUDA_VISIBLE_DEVICES=0 python cifar_train.py --model resnet18 --dataset c10 --aug

You can also change the model or dataset:

    CUDA_VISIBLE_DEVICES=0 python cifar_train.py --model densenet --dataset c100 --aug

Add the `--clean` flag to train the model on clean data:

    CUDA_VISIBLE_DEVICES=0 python cifar_train.py --model resnet18 --dataset c10 --aug --clean
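For reference, the synthetic shortcuts themselves are generated inside `cifar_train.py`: one well-separated point per image is drawn from `sklearn.datasets.make_classification`, reshaped into a coarse grid, and each entry is upsampled into a square patch. A minimal standalone sketch of that construction (sizes follow the script's defaults; the variable names are illustrative):

    import numpy as np
    from sklearn.datasets import make_classification

    n, img_size, patch = 50000, 32, 8
    grid = img_size // patch  # a 4x4 grid of patches per image
    data, label = make_classification(
        n_samples=n, n_features=grid * grid * 3, n_classes=10,
        n_informative=grid * grid * 3, n_redundant=0, n_repeated=0,
        class_sep=10., flip_y=0., n_clusters_per_class=1)
    data = data.reshape(n, grid, grid, 3).astype(np.float32)
    noise = np.repeat(np.repeat(data, patch, axis=1), patch, axis=2)  # n x 32 x 32 x 3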
--------------------------------------------------------------------------------
/synthetic_perturbations/cifar_train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import random
import time
import argparse
import numpy as np
from sklearn.datasets import make_classification

from models.ResNet import ResNet18, ResNet50
from models.DenseNet import DenseNet121
from util import AverageMeter, cross_entropy, accuracy, comput_l2norm_lim, normalize_l2norm, adjust_learning_rate


parser = argparse.ArgumentParser(description='synthetic perturbations')
parser.add_argument('--dataset', type=str, default='c10', help='[c10, c100, svhn]')
parser.add_argument('--aug', action='store_true', default=False, help='use data augmentation')
parser.add_argument('--eps', type=int, default=6, help='perturbation strength')

parser.add_argument('--epoch', type=int, default=100, help='running epochs')
parser.add_argument('--batchsize', type=int, default=128, help='batchsize')
parser.add_argument('--patchsize', type=int, default=8, help='size of patch')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')

parser.add_argument('--model', type=str, default='resnet18', help='[vgg, resnet18, resnet50, densenet]')

parser.add_argument('--sess', type=str, default='default', help='session name for experiment')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--clean', action='store_true', default=False, help='use clean data')

args = parser.parse_args()


torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

dataset = args.dataset

if dataset == 'c10':
    data_func = datasets.CIFAR10
elif dataset == 'c100':
    data_func = datasets.CIFAR100
elif dataset == 'svhn':
    data_func = datasets.SVHN

if dataset == 'c100':
    num_classes = 100
else:
    num_classes = 10

# Data
print('==> Preparing data..')

plain_transform = transforms.Compose([
    transforms.ToTensor()
])

aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

train_transform = test_transform = plain_transform

if args.aug:
    train_transform = aug_transform


if args.dataset == 'svhn':
    train_dataset = data_func(root='../datasets', split='train', download=True, transform=train_transform)
else:
    train_dataset = data_func(root='../datasets', train=True, download=True, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batchsize, shuffle=False, pin_memory=True, drop_last=False, num_workers=4)

if args.dataset == 'svhn':
    test_dataset = data_func(root='../datasets', split='test', download=True, transform=test_transform)
else:
    test_dataset = data_func(root='../datasets', train=False, download=True, transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=512, shuffle=False, pin_memory=True, drop_last=False, num_workers=4)


if not args.clean:
    n = train_dataset.data.shape[0]
    if args.dataset == 'svhn':  # ensure we generate enough synthetic data per class
        n *= 2

    img_size = 32
    noise_frame_size = args.patchsize

    # number of patches per side; add one extra (cropped) patch if the
    # image size is not divisible by the patch size
    remainder = img_size % noise_frame_size
    num_patch = img_size // noise_frame_size
    if remainder > 0:
        num_patch += 1

    n_random_fea = int((img_size / noise_frame_size) ** 2 * 3)

    # generate initial data points: one well-separated low-dimensional point
    # per image, with the same label space as the real dataset
    simple_data, simple_label = make_classification(n_samples=n, n_features=n_random_fea, n_classes=num_classes, n_informative=n_random_fea, n_redundant=0, n_repeated=0, class_sep=10., flip_y=0., n_clusters_per_class=1)
    simple_data = simple_data.reshape([simple_data.shape[0], num_patch, num_patch, 3])
    simple_data = simple_data.astype(np.float32)

    # duplicate each dimension to get 2-D patches
    simple_images = np.repeat(simple_data, noise_frame_size, 2)
    simple_images = np.repeat(simple_images, noise_frame_size, 1)
    simple_data = simple_images[:, 0:img_size, 0:img_size, :]

    # project the synthetic images into a small L2 ball whose radius matches
    # an L-infinity budget of eps/255 per pixel
    linf = args.eps / 255.
    feature_dim = img_size ** 2 * 3
    l2norm_lim = comput_l2norm_lim(linf, feature_dim)
    simple_data = normalize_l2norm(simple_data, l2norm_lim)
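    # A sanity check one could add here (illustrative, not part of the original
    # pipeline): with the default eps=6 and 32x32x3 inputs the L2 budget is
    # sqrt((6/255)**2 * 3072) ~= 1.30, and after normalize_l2norm every
    # synthetic image should sit exactly on that sphere:
    #   assert np.allclose(np.linalg.norm(simple_data.reshape(n, -1), axis=1), l2norm_lim)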

    # convert to float in [0, 1] so the synthetic noise can be added
    train_dataset.data = train_dataset.data.astype(np.float64) / 255.
    if args.dataset == 'svhn':
        # SVHN stores data as NCHW; convert to NHWC to match the noise layout
        train_dataset.data = np.transpose(train_dataset.data, [0, 2, 3, 1])
        arr_target = train_dataset.labels
    else:
        arr_target = np.array(train_dataset.targets)

    # add the synthetic noise to the original examples, class by class
    for label in range(num_classes):
        orig_data_idx = arr_target == label
        simple_data_idx = simple_label == label
        mini_simple_data = simple_data[simple_data_idx][0:int(sum(orig_data_idx))]
        train_dataset.data[orig_data_idx] += mini_simple_data

    train_dataset.data = np.clip((train_dataset.data * 255), 0, 255).astype(np.uint8)
    if args.dataset == 'svhn':
        train_dataset.data = np.transpose(train_dataset.data, [0, 3, 1, 2])


if args.model == 'resnet18':
    model = ResNet18(num_classes=num_classes)
elif args.model == 'resnet50':
    model = ResNet50(num_classes=num_classes)
elif args.model == 'vgg':
    model = models.vgg11(num_classes=num_classes)
elif args.model == 'densenet':
    model = DenseNet121(num_classes=num_classes)

model = model.cuda()
criterion = torch.nn.CrossEntropyLoss()
test_criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=args.lr, weight_decay=5e-4, momentum=0.9)

for epoch in range(args.epoch):
    adjust_learning_rate(optimizer, args.lr, epoch, all_epoch=args.epoch)

    # Train
    model.train()
    acc_meter = AverageMeter()
    loss_meter = AverageMeter()

    time0 = time.time()

    for images, labels in train_loader:
        images, labels = images.cuda(), labels.cuda()

        model.zero_grad()
        optimizer.zero_grad()

        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(logits.data, 1)
        acc = (predicted == labels).sum().item() / labels.size(0)

        acc_meter.update(acc)
        loss_meter.update(loss.item())

    print('Epoch %d, ' % epoch, "Train acc %.2f loss: %.2f" % (acc_meter.avg * 100, loss_meter.avg), end=' ')

    # Eval
    model.eval()
    correct, total = 0, 0
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            logits = model(images)
            test_loss = test_criterion(logits, labels)
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    time1 = time.time()

    acc = correct / total
    # note: test_loss here is the loss of the last test batch only
    print("Test acc %.2f loss: %.2f, epoch time: %ds" % (acc * 100, test_loss.item(), time1 - time0))
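

# Optional final pass (a sketch, not part of the original script): the epoch
# loop above prints only the last test batch's loss, so this variant reports
# the loss averaged over the whole test set.
model.eval()
correct, total, loss_sum = 0, 0, 0.0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.cuda(), labels.cuda()
        logits = model(images)
        loss_sum += test_criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
print("Final test acc %.2f, mean test loss: %.2f" % (100. * correct / total, loss_sum / total))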
--------------------------------------------------------------------------------
/synthetic_perturbations/models/DenseNet.py:
--------------------------------------------------------------------------------
'''
https://github.com/kuangliu/pytorch-cifar
DenseNet in PyTorch.
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4*growth_rate)
        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out


class Transition(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate

        num_planes = 2*growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3]*growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def DenseNet121(num_classes=10):
    return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, num_classes=num_classes)


def DenseNet169(num_classes=10):
    return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, num_classes=num_classes)


def DenseNet201(num_classes=10):
    return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, num_classes=num_classes)


def DenseNet161(num_classes=10):
    return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, num_classes=num_classes)


def densenet_cifar():
    return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12)
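

# A smoke test mirroring the one at the end of ResNet.py (a sketch, not in
# the original file):
def test():
    net = DenseNet121()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())  # expected: torch.Size([1, 10])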
--------------------------------------------------------------------------------
/synthetic_perturbations/models/ResNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class LinearNet(nn.Module):
    def __init__(self, layers, indim, width, outdim, acti=nn.ReLU):
        super(LinearNet, self).__init__()

        self.acti = acti()

        self.layers = []
        self.layers.append(nn.Linear(indim, width))
        self.layers.append(self.acti)

        for i in range(1, layers-1):
            self.layers.append(nn.Linear(width, width))
            self.layers.append(self.acti)

        self.layers.append(nn.Linear(width, outdim))
        self.layer = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layer(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class NoresBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(NoresBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        # same as BasicBlock, but without the residual connection
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class TransposeBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(TransposeBasicBlock, self).__init__()
        if stride == 2:
            output_padding = 1
        else:
            output_padding = 0
        self.conv1 = nn.ConvTranspose2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, output_padding=output_padding, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.ConvTranspose2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, output_padding=output_padding, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)

        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _make_inverse_layer(self, block, in_planes, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, stride in enumerate(strides):
            if i != num_blocks - 1:
                layers.append(block(in_planes, planes, stride))
                in_planes = planes * block.expansion
            else:
                layers.append(block(planes, 3, stride))

        return nn.Sequential(*layers)

    def flatten(self, out):
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out

    def forward(self, x, features=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = F.avg_pool2d(out4, 4)
        out = out.view(out.size(0), -1)
        if features:
            return out
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)


def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)


def ResNet101(num_classes=10):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)


def ResNet152(num_classes=10):
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

--------------------------------------------------------------------------------
/synthetic_perturbations/models/__init__.py:
--------------------------------------------------------------------------------
from . import DenseNet, ResNet

--------------------------------------------------------------------------------
/synthetic_perturbations/util.py:
--------------------------------------------------------------------------------
import logging
import os

import numpy as np
import torch
import random

from torch.utils.data.dataset import Dataset

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def comput_l2norm_lim(linf=0.03, feature_dim=3072):
    # L2 norm of a feature_dim-dimensional vector whose every coordinate has
    # magnitude linf, i.e. the L2 radius matching an L-infinity budget
    return np.sqrt(linf**2 * feature_dim)


def normalize_l2norm(data, norm_lim):
    # rescale every flattened example to have L2 norm exactly norm_lim
    n = data.shape[0]
    orig_shape = data.shape
    flatten_data = data.reshape([n, -1])
    norms = np.linalg.norm(flatten_data, axis=1, keepdims=True)
    flatten_data = flatten_data/norms
    data = flatten_data.reshape(orig_shape)
    data = data * norm_lim
    return data


def adjust_learning_rate(optimizer, init_lr, epoch, all_epoch):
    """decrease the learning rate at 100 and 150 epoch"""
    decay = 1.0
    if(epoch