├── README.md ├── densenet.py ├── final_filter_unlearnable.py ├── final_muladv.py ├── requirements.txt ├── resnet.py ├── util.py └── vgg.py /README.md: -------------------------------------------------------------------------------- 1 | # CUDA: Convolution-based Unlearnable Datasets (CVPR 2023) 2 | #### Authors: Vinu Sankar Sadasivan, Mahdi Soltanolkotabi, Soheil Feizi 3 | 4 | Paper: https://arxiv.org/abs/2303.04278 5 | 6 | Requirements 7 | ------------ 8 | 9 | Python 3.8.5 (GCC 7.3.0) 10 | 11 | NVIDIA GPU with CUDA 11.0 12 | 13 | Python requirements are listed in requirements.txt 14 | 15 | 16 | Directory tree 17 | -------------- 18 | 19 | This README file is in the current directory "." 20 | 21 | Create the folder "../datasets/", where datasets will be downloaded 22 | 23 | Create the folder "results/", where results will be saved 24 | 25 | 26 | Code 27 | ----- 28 | {densenet, resnet, vgg}.py contain networks from https://github.com/fshp971/robust-unlearnable-examples/tree/main/models 29 | 30 | util.py contains progress bar utils from https://github.com/HanxunH/Unlearnable-Examples 31 | 32 | final_filter_unlearnable.py contains code for generating CUDA datasets and training models on them. 33 | 34 | final_muladv.py contains code for executing Deconvolution-based Adversarial Training (DAT) on the CUDA CIFAR-10 dataset with ResNet-18. 35 | 36 | 37 | To Run 38 | ------ 39 | 40 | To execute final_filter_unlearnable.py, go to "." and run 41 | 42 | ``` 43 | python final_filter_unlearnable.py --arch='resnet18' --dataset='cifar10' --train-type='adv' \ 44 | --blur-parameter=0.3 --seed=0 --pgd-norm='linf' --pgd-steps=10 --pgd-radius=0.015 --mix=1.0 \ 45 | --name='results/resnet18_cifar10_adv_bp=0.3_linf_eps=4_steps=10_seed0_mix=1.0.pkl' 46 | ``` 47 | 48 | The above command performs L_{\infty} adversarial training on the CUDA CIFAR-10 dataset using ResNet-18. 49 | 50 | To execute DAT, go to "." and run 51 | 52 | ``` 53 | python final_muladv.py 54 | ``` 55 | 56 | 57 | 58 | 59 | > COPYRIGHT AND PERMISSION NOTICE 60 | > UMD Software [CUDA: Convolution-based Unlearnable Datasets] Copyright (C) 2022 University of Maryland 61 | > All rights reserved. 62 | > The University of Maryland (“UMD”) and the developers of [CUDA: Convolution-based Unlearnable Datasets] software (“Software”) give recipient (“Recipient”) permission to download a single copy of the Software in source code form and use by university, non-profit, or research institution users only, provided that the following conditions are met: 63 | > 64 | > Recipient may use the Software for any purpose, EXCEPT for commercial benefit. 65 | > Recipient will not copy the Software. 66 | > Recipient will not sell the Software. 67 | > Recipient will not give the Software to any third party. 68 | > Any party desiring a license to use the Software for commercial purposes shall contact: 69 | > UM Ventures, College Park at UMD at otc@umd.edu. 70 | > THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS, CONTRIBUTORS, AND THE UNIVERSITY OF MARYLAND "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY OF MARYLAND BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 71 | -------------------------------------------------------------------------------- /densenet.py: -------------------------------------------------------------------------------- 1 | ''' adapted from 2 | https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py 3 | ''' 4 | 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as cp 10 | from collections import OrderedDict 11 | from torch import Tensor 12 | from typing import Any, List, Tuple 13 | 14 | 15 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 16 | 17 | 18 | class _DenseLayer(nn.Module): 19 | def __init__( 20 | self, 21 | num_input_features: int, 22 | growth_rate: int, 23 | bn_size: int, 24 | drop_rate: float, 25 | memory_efficient: bool = False 26 | ) -> None: 27 | super(_DenseLayer, self).__init__() 28 | self.norm1: nn.BatchNorm2d 29 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)) 30 | self.relu1: nn.ReLU 31 | self.add_module('relu1', nn.ReLU(inplace=True)) 32 | self.conv1: nn.Conv2d 33 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 34 | growth_rate, kernel_size=1, stride=1, 35 | bias=False)) 36 | self.norm2: nn.BatchNorm2d 37 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) 38 | self.relu2: nn.ReLU 39 | self.add_module('relu2', nn.ReLU(inplace=True)) 40 | self.conv2: nn.Conv2d 41 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 42 | kernel_size=3, stride=1, padding=1, 43 | bias=False)) 44 | self.drop_rate = float(drop_rate) 45 | self.memory_efficient = memory_efficient 46 | 47 | def bn_function(self, inputs: List[Tensor]) -> Tensor: 48 | concated_features = torch.cat(inputs, 1) 49 | bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 50 | return bottleneck_output 51 | 52 | # todo: rewrite when torchscript supports any 53 | def any_requires_grad(self, input: List[Tensor]) -> bool: 54 | for tensor in input: 55 | if tensor.requires_grad: 56 | return True 57 | return False 58 | 59 | @torch.jit.unused # noqa: T484 60 | def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor: 61 | def closure(*inputs): 62 | return self.bn_function(inputs) 63 | 64 | return cp.checkpoint(closure, *input) 65 | 66 | @torch.jit._overload_method # noqa: F811 67 | def forward(self, input: List[Tensor]) -> Tensor: 68 | pass 69 | 70 | @torch.jit._overload_method # noqa: F811 71 | def forward(self, input: Tensor) -> Tensor: 72 | pass 73 | 74 | # torchscript does not yet support *args, so we overload method 75 | # allowing it to take either a List[Tensor] or single Tensor 76 | def forward(self, input: Tensor) -> Tensor: # noqa: F811 77 | if isinstance(input, Tensor): 78 | prev_features = [input] 79 | else: 80 | prev_features = input 81 | 82 | if self.memory_efficient and self.any_requires_grad(prev_features): 83 | if torch.jit.is_scripting(): 84 | raise Exception("Memory Efficient 
not supported in JIT") 85 | 86 | bottleneck_output = self.call_checkpoint_bottleneck(prev_features) 87 | else: 88 | bottleneck_output = self.bn_function(prev_features) 89 | 90 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 91 | if self.drop_rate > 0: 92 | new_features = F.dropout(new_features, p=self.drop_rate, 93 | training=self.training) 94 | return new_features 95 | 96 | 97 | class _DenseBlock(nn.ModuleDict): 98 | _version = 2 99 | 100 | def __init__( 101 | self, 102 | num_layers: int, 103 | num_input_features: int, 104 | bn_size: int, 105 | growth_rate: int, 106 | drop_rate: float, 107 | memory_efficient: bool = False 108 | ) -> None: 109 | super(_DenseBlock, self).__init__() 110 | for i in range(num_layers): 111 | layer = _DenseLayer( 112 | num_input_features + i * growth_rate, 113 | growth_rate=growth_rate, 114 | bn_size=bn_size, 115 | drop_rate=drop_rate, 116 | memory_efficient=memory_efficient, 117 | ) 118 | self.add_module('denselayer%d' % (i + 1), layer) 119 | 120 | def forward(self, init_features: Tensor) -> Tensor: 121 | features = [init_features] 122 | for name, layer in self.items(): 123 | new_features = layer(features) 124 | features.append(new_features) 125 | return torch.cat(features, 1) 126 | 127 | 128 | class _Transition(nn.Sequential): 129 | def __init__(self, num_input_features: int, num_output_features: int) -> None: 130 | super(_Transition, self).__init__() 131 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 132 | self.add_module('relu', nn.ReLU(inplace=True)) 133 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 134 | kernel_size=1, stride=1, bias=False)) 135 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 136 | 137 | 138 | class DenseNet(nn.Module): 139 | r"""Densenet-BC model class, based on 140 | `"Densely Connected Convolutional Networks" `_. 141 | 142 | Args: 143 | growth_rate (int) - how many filters to add each layer (`k` in paper) 144 | block_config (list of 4 ints) - how many layers in each pooling block 145 | num_init_features (int) - the number of filters to learn in the first convolution layer 146 | bn_size (int) - multiplicative factor for number of bottle neck layers 147 | (i.e. bn_size * k features in the bottleneck layer) 148 | drop_rate (float) - dropout rate after each dense layer 149 | num_classes (int) - number of classification classes 150 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 151 | but slower. Default: *False*. See `"paper" `_. 
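Example (illustrative, not part of the original docstring): ``DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, num_classes=10)`` builds a DenseNet-121-style network with a 10-way classifier head.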
152 | """ 153 | 154 | def __init__( 155 | self, 156 | growth_rate: int = 32, 157 | block_config: Tuple[int, int, int, int] = (6, 12, 24, 16), 158 | num_init_features: int = 64, 159 | bn_size: int = 4, 160 | drop_rate: float = 0, 161 | num_classes: int = 1000, 162 | memory_efficient: bool = False 163 | ) -> None: 164 | 165 | super(DenseNet, self).__init__() 166 | 167 | # First convolution 168 | self.features = nn.Sequential(OrderedDict([ 169 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 170 | padding=3, bias=False)), 171 | ('norm0', nn.BatchNorm2d(num_init_features)), 172 | ('relu0', nn.ReLU(inplace=True)), 173 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 174 | ])) 175 | 176 | # Each denseblock 177 | num_features = num_init_features 178 | for i, num_layers in enumerate(block_config): 179 | block = _DenseBlock( 180 | num_layers=num_layers, 181 | num_input_features=num_features, 182 | bn_size=bn_size, 183 | growth_rate=growth_rate, 184 | drop_rate=drop_rate, 185 | memory_efficient=memory_efficient 186 | ) 187 | self.features.add_module('denseblock%d' % (i + 1), block) 188 | num_features = num_features + num_layers * growth_rate 189 | if i != len(block_config) - 1: 190 | trans = _Transition(num_input_features=num_features, 191 | num_output_features=num_features // 2) 192 | self.features.add_module('transition%d' % (i + 1), trans) 193 | num_features = num_features // 2 194 | 195 | # Final batch norm 196 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 197 | 198 | # Linear layer 199 | self.classifier = nn.Linear(num_features, num_classes) 200 | 201 | # Official init from torch repo. 202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | nn.init.kaiming_normal_(m.weight) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | elif isinstance(m, nn.Linear): 209 | nn.init.constant_(m.bias, 0) 210 | 211 | def forward(self, x: Tensor) -> Tensor: 212 | features = self.features(x) 213 | out = F.relu(features, inplace=True) 214 | out = F.adaptive_avg_pool2d(out, (1, 1)) 215 | out = torch.flatten(out, 1) 216 | out = self.classifier(out) 217 | return out 218 | 219 | 220 | def _densenet( 221 | arch: str, 222 | growth_rate: int, 223 | block_config: Tuple[int, int, int, int], 224 | num_init_features: int, 225 | pretrained: bool, 226 | progress: bool, 227 | **kwargs: Any 228 | ) -> DenseNet: 229 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 230 | if pretrained: 231 | raise NotImplementedError 232 | return model 233 | 234 | 235 | def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 236 | r"""Densenet-121 model from 237 | `"Densely Connected Convolutional Networks" `_. 238 | The required minimum input size of the model is 29x29. 239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 244 | but slower. Default: *False*. See `"paper" `_. 245 | """ 246 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 247 | **kwargs) 248 | 249 | 250 | def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 251 | r"""Densenet-161 model from 252 | `"Densely Connected Convolutional Networks" `_. 253 | The required minimum input size of the model is 29x29. 
254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 259 | but slower. Default: *False*. See `"paper" `_. 260 | """ 261 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 262 | **kwargs) 263 | 264 | 265 | def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 266 | r"""Densenet-169 model from 267 | `"Densely Connected Convolutional Networks" `_. 268 | The required minimum input size of the model is 29x29. 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 274 | but slower. Default: *False*. See `"paper" `_. 275 | """ 276 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: 281 | r"""Densenet-201 model from 282 | `"Densely Connected Convolutional Networks" `_. 283 | The required minimum input size of the model is 29x29. 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 289 | but slower. Default: *False*. See `"paper" `_. 290 | """ 291 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 292 | **kwargs) 293 | -------------------------------------------------------------------------------- /final_filter_unlearnable.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | import numpy as np 6 | from tqdm import tqdm 7 | import random 8 | import matplotlib.pyplot as plt 9 | import matplotlib 10 | from util import AverageMeter 11 | # from ResNet import ResNet18 12 | from six.moves import cPickle as pkl 13 | import argparse 14 | import os 15 | from time import time 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--arch', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vgg16-bn', 'densenet-121', 'wrn-34-10']) 19 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) 20 | parser.add_argument('--train-type', type=str, default='erm', choices=['erm', 'adv'], help='ERM or Adversarial training loss') 21 | parser.add_argument('--pgd-radius', type=float, default=0.0) 22 | parser.add_argument('--pgd-steps', type=int, default=0) 23 | parser.add_argument('--pgd-norm', type=str, default='', choices=['linf', 'l2', '']) 24 | parser.add_argument('--blur-parameter', type=float, default=0.3) 25 | parser.add_argument('--seed', type=int, default=0) 26 | parser.add_argument('--imagenet-dim', type=int, default=224) 27 | parser.add_argument('--name', type=str) 28 | parser.add_argument('--mix', type=float, default=1.0) # percent of poisoned data, default 100% data is poisoned. 
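# Example invocation (illustrative output path; all flags are defined above):
# standard ERM training on a CUDA CIFAR-100 dataset with VGG-16-BN. The --pgd-*
# flags only take effect when --train-type='adv', and e.g. --mix=0.5 would leave
# a random half of the training images clean:
#
#   python final_filter_unlearnable.py --arch='vgg16-bn' --dataset='cifar100' \
#       --train-type='erm' --blur-parameter=0.3 --seed=0 --mix=1.0 \
#       --name='results/vgg16bn_cifar100_erm_bp=0.3_seed0_mix=1.0.pkl'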
29 | args = parser.parse_args() 30 | 31 | print(args.arch, args.dataset, args.train_type, args.pgd_radius, args.pgd_steps, args.pgd_norm, args.blur_parameter, args.seed, args.name) 32 | 33 | start = time() 34 | train_transform = [transforms.ToTensor()] 35 | test_transform = [transforms.ToTensor()] 36 | train_transform = transforms.Compose(train_transform) 37 | test_transform = transforms.Compose(test_transform) 38 | epochs = 100 39 | 40 | if args.dataset == 'cifar10': 41 | clean_train_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 42 | clean_test_dataset = datasets.CIFAR10(root='../datasets', train=False, download=True, transform=test_transform) 43 | num_cls = 10 44 | size = 32 45 | batch_size = 512 46 | 47 | elif args.dataset == 'cifar100': 48 | clean_train_dataset = datasets.CIFAR100(root='../datasets', train=True, download=True, transform=train_transform) 49 | clean_test_dataset = datasets.CIFAR100(root='../datasets', train=False, download=True, transform=test_transform) 50 | num_cls = 100 51 | size = 32 52 | batch_size = 512 53 | 54 | 55 | clean_train_loader = DataLoader(dataset=clean_train_dataset, batch_size=batch_size, 56 | shuffle=False, pin_memory=True, 57 | drop_last=False, num_workers=4) 58 | 59 | clean_test_loader = DataLoader(dataset=clean_test_dataset, batch_size=512, 60 | shuffle=False, pin_memory=True, 61 | drop_last=False, num_workers=4) 62 | 63 | 64 | # unlearnable parameters 65 | grayscale = False # grayscale 66 | blur_parameter = args.blur_parameter 67 | center_parameter = 1.0 68 | kernel_size = 3 69 | seed = args.seed 70 | same = False 71 | mix = args.mix 72 | 73 | # test parameters 74 | test_grayscale = False 75 | 76 | 77 | # Below function is the main CUDA algorithm 78 | def get_filter_unlearnable(blur_parameter, center_parameter, grayscale, kernel_size, seed, same): 79 | 80 | np.random.seed(seed) 81 | cnns = [] 82 | with torch.no_grad(): 83 | for i in range(num_cls): 84 | cnns.append(torch.nn.Conv2d(3, 3, kernel_size, groups=3, padding=1).cuda()) 85 | if blur_parameter is None: 86 | blur_parameter = 1 87 | 88 | w = np.random.uniform(low=0, high=blur_parameter, size=(3,1,kernel_size,kernel_size)) 89 | if center_parameter is not None: 90 | shape = w[0][0].shape 91 | w[0, 0, np.random.randint(shape[0]), np.random.randint(shape[1])] = center_parameter 92 | 93 | w[1] = w[0] 94 | w[2] = w[0] 95 | cnns[i].weight.copy_(torch.tensor(w)) 96 | cnns[i].bias.copy_(cnns[i].bias * 0) 97 | 98 | cnns = np.stack(cnns) 99 | 100 | if same: 101 | cnns = np.stack([cnns[0]] * len(cnns)) 102 | 103 | if args.dataset == 'cifar10': 104 | unlearnable_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 105 | 106 | elif args.dataset == 'cifar100': 107 | unlearnable_dataset = datasets.CIFAR100(root='../datasets', train=True, download=True, transform=train_transform) 108 | 109 | 110 | unlearnable_loader = DataLoader(dataset=unlearnable_dataset, batch_size=500, 111 | shuffle=False, pin_memory=True, 112 | drop_last=False, num_workers=4) 113 | 114 | pbar = tqdm(unlearnable_loader, total=len(unlearnable_loader)) 115 | images_ = [] 116 | 117 | for images, labels in pbar: 118 | images, labels = images.cuda(), labels.cuda() 119 | for i in range(len(images)): 120 | 121 | prob = np.random.random() 122 | if prob < mix: # mix*100% of data is poisoned 123 | id = labels[i].item() 124 | img = cnns[id](images[i:i+1]).detach().cpu() # convolve class-wise 125 | 126 | # # black and white 127 | if grayscale: 128 | 
img_bw = img[0].mean(0) 129 | img[0][0] = img_bw 130 | img[0][1] = img_bw 131 | img[0][2] = img_bw 132 | 133 | images_.append(img/img.max()) 134 | else: 135 | images_.append(images[i:i+1].detach().cpu()) 136 | 137 | # making unlearnable data 138 | unlearnable_dataset.data = unlearnable_dataset.data.astype(np.float32) 139 | for i in range(len(unlearnable_dataset)): 140 | unlearnable_dataset.data[i] = images_[i][0].numpy().transpose((1,2,0))*255 141 | unlearnable_dataset.data[i] = np.clip(unlearnable_dataset.data[i], a_min=0, a_max=255) 142 | unlearnable_dataset.data = unlearnable_dataset.data.astype(np.uint8) 143 | 144 | return unlearnable_dataset, cnns 145 | 146 | def imshow(img): 147 | fig = plt.figure(figsize=(9, 3), dpi=250, facecolor='w', edgecolor='k') 148 | npimg = img.numpy() 149 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 150 | plt.axis('off') 151 | plt.tight_layout() 152 | plt.savefig('sample_{}.png'.format(blur_parameter)) 153 | 154 | def get_pairs_of_imgs(idx): 155 | clean_img = clean_train_dataset.data[idx] 156 | unlearnable_img = unlearnable_dataset.data[idx] 157 | clean_img = torchvision.transforms.functional.to_tensor(clean_img) 158 | unlearnable_img = torchvision.transforms.functional.to_tensor(unlearnable_img) 159 | noise = unlearnable_img - clean_img 160 | noise = noise - noise.min() 161 | noise = noise/noise.max() 162 | return [clean_img, noise, unlearnable_img] 163 | 164 | # def get_altered_testset(cnns, grayscale): 165 | 166 | # pbar = tqdm(clean_test_loader, total=len(clean_test_loader)) 167 | # images_ = [] 168 | 169 | # for images, labels in pbar: 170 | # images, labels = images.cuda(), labels.cuda() 171 | # for i in range(len(images)): 172 | # id = labels[i].item() 173 | # if cnns is None: 174 | # img = images[i:i+1].detach().cpu() 175 | # images_.append(img) 176 | # continue 177 | # else: 178 | # img = cnns[id%10](images[i:i+1]).detach().cpu() 179 | # if grayscale: 180 | # img_bw = img[0].mean(0) 181 | # img[0][0] = img_bw 182 | # img[0][1] = img_bw 183 | # img[0][2] = img_bw 184 | 185 | # # normalize 186 | # img = img/img.max() 187 | # images_.append(img) 188 | 189 | # clean_test_dataset.data = clean_test_dataset.data.astype(np.float32) 190 | # for i in range(len(clean_test_dataset)): 191 | # clean_test_dataset.data[i] = images_[i][0].numpy().transpose((1,2,0))*255 192 | # clean_test_dataset.data[i] = np.clip(clean_test_dataset.data[i], a_min=0, a_max=255) 193 | # clean_test_dataset.data = clean_test_dataset.data.astype(np.uint8) 194 | 195 | # return clean_test_dataset 196 | 197 | unlearnable_dataset, cnns = get_filter_unlearnable(blur_parameter, center_parameter, grayscale, kernel_size, seed, same) 198 | print('Time taken is', time()-start) 199 | # clean_test_dataset = get_altered_testset(None, test_grayscale) 200 | 201 | # get unlearnable dataset images 202 | rows = 10 203 | selected_idx = [] 204 | for i in range(10): 205 | idx = (np.stack(clean_train_dataset.targets) == i) 206 | idx = np.arange(len(clean_train_dataset))[idx] 207 | np.random.shuffle(idx) 208 | selected_idx.append(idx[0]) 209 | # print(i, cnns[i].weight.data[0]) 210 | 211 | # selected_idx = [random.randint(0, len(clean_train_dataset)) for _ in range(rows)] 212 | img_grid = [] 213 | for idx in selected_idx: 214 | img_grid += get_pairs_of_imgs(idx) 215 | 216 | img_grid = img_grid[0::3] + img_grid[1::3] + img_grid[2::3] 217 | imshow(torchvision.utils.make_grid(torch.stack(img_grid), nrow=10, pad_value=255)) 218 | # exit() 219 | 220 | # get ready for training 221 | 222 | train_transform = 
transforms.Compose([transforms.RandomCrop(size, padding=4), 223 | transforms.RandomHorizontalFlip(), 224 | transforms.ToTensor()]) 225 | 226 | if args.dataset == 'cifar10': 227 | clean_train_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 228 | 229 | elif args.dataset == 'cifar100': 230 | clean_train_dataset = datasets.CIFAR100(root='../datasets', train=True, download=True, transform=train_transform) 231 | 232 | unlearnable_dataset.transforms = clean_train_dataset.transforms 233 | unlearnable_dataset.transform = clean_train_dataset.transform 234 | 235 | unlearnable_loader = DataLoader(dataset=unlearnable_dataset, batch_size=batch_size, 236 | shuffle=True, pin_memory=True, 237 | drop_last=False, num_workers=4) 238 | 239 | 240 | clean_test_loader = DataLoader(dataset=clean_test_dataset, batch_size=batch_size, 241 | shuffle=True, pin_memory=True, 242 | drop_last=False, num_workers=4) 243 | 244 | 245 | # training 246 | arch = args.arch 247 | torch.manual_seed(seed) 248 | 249 | if arch == 'resnet18': 250 | from resnet import resnet18 as net 251 | model = net(3, num_cls) 252 | 253 | elif arch == 'resnet50': 254 | from resnet import resnet50 as net 255 | model = net(3, num_cls) 256 | 257 | elif arch == 'wrn-34-10': 258 | from resnet import wrn34_10 as net 259 | model = net(3, num_cls) 260 | 261 | elif arch == 'vgg16-bn': 262 | from vgg import vgg16_bn as net 263 | model = net(3, num_cls) 264 | 265 | elif arch == 'densenet-121': 266 | from densenet import densenet121 as net 267 | model = net(num_classes=num_cls) 268 | 269 | model = model.cuda() 270 | criterion = torch.nn.CrossEntropyLoss() 271 | optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, weight_decay=0.0005, momentum=0.9) 272 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=0) 273 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(epochs*0.4), gamma=0.1) 274 | 275 | train_acc = [] 276 | test_acc = [] 277 | 278 | try: 279 | 280 | if args.train_type == 'adv': 281 | 282 | import torchattacks 283 | 284 | if args.pgd_norm == 'linf': 285 | attacker = torchattacks.PGD 286 | elif args.pgd_norm == 'l2': 287 | attacker = torchattacks.PGDL2 288 | 289 | eps = args.pgd_radius 290 | steps = args.pgd_steps 291 | atk = attacker(model, eps=eps, alpha=eps/steps * 1.5, steps=steps) 292 | 293 | 294 | for epoch in range(epochs): 295 | # Train 296 | model.train() 297 | acc_meter = AverageMeter() 298 | loss_meter = AverageMeter() 299 | pbar = tqdm(unlearnable_loader, total=len(unlearnable_loader)) 300 | 301 | for images, labels in pbar: 302 | 303 | images, labels = images.cuda(), labels.cuda() 304 | model.zero_grad() 305 | optimizer.zero_grad() 306 | 307 | if args.train_type == 'adv': 308 | images = atk(images, labels) 309 | 310 | logits = model(images) 311 | loss = criterion(logits, labels) 312 | loss.backward() 313 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) 314 | optimizer.step() 315 | 316 | _, predicted = torch.max(logits.data, 1) 317 | acc = (predicted == labels).sum().item()/labels.size(0) 318 | acc_meter.update(acc) 319 | loss_meter.update(loss.item()) 320 | train_acc.append(acc) 321 | pbar.set_description("Acc %.2f Loss: %.2f" % (acc_meter.avg*100, loss_meter.avg)) 322 | 323 | scheduler.step() 324 | 325 | # Eval 326 | model.eval() 327 | correct, total = 0, 0 328 | 329 | for i, (images, labels) in enumerate(clean_test_loader): 330 | images, labels = images.cuda(), labels.cuda() 331 | with torch.no_grad(): 332 | logits = 
model(images) 333 | _, predicted = torch.max(logits.data, 1) 334 | total += labels.size(0) 335 | correct += (predicted == labels).sum().item() 336 | 337 | acc = correct / total 338 | test_acc.append(acc) 339 | tqdm.write('Clean Accuracy %.2f\n' % (acc*100)) 340 | tqdm.write('Epoch %.2f\n' % (epoch)) 341 | 342 | 343 | except: 344 | 345 | with open(args.name, 'wb') as f: 346 | pkl.dump([train_acc, test_acc], f) 347 | 348 | with open(args.name, 'wb') as f: 349 | pkl.dump([train_acc, test_acc], f) -------------------------------------------------------------------------------- /final_muladv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | import numpy as np 6 | from tqdm import tqdm 7 | import random 8 | import matplotlib.pyplot as plt 9 | import matplotlib 10 | from util import AverageMeter 11 | from resnet import resnet18 as net 12 | from six.moves import cPickle as pkl 13 | import torchattacks 14 | 15 | train_transform = [transforms.ToTensor()] 16 | test_transform = [transforms.ToTensor()] 17 | train_transform = transforms.Compose(train_transform) 18 | test_transform = transforms.Compose(test_transform) 19 | batch_size = 500 20 | 21 | clean_train_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 22 | clean_test_dataset = datasets.CIFAR10(root='../datasets', train=False, download=True, transform=test_transform) 23 | 24 | clean_train_loader = DataLoader(dataset=clean_train_dataset, batch_size=batch_size, 25 | shuffle=False, pin_memory=True, 26 | drop_last=False, num_workers=12) 27 | 28 | clean_test_loader = DataLoader(dataset=clean_test_dataset, batch_size=batch_size, 29 | shuffle=False, pin_memory=True, 30 | drop_last=False, num_workers=12) 31 | 32 | 33 | # unlearnable parameters 34 | grayscale = False # grayscale 35 | blur_parameter = 0.3 36 | center_parameter = 1.0 37 | kernel_size = 3 38 | seed = 0 39 | same = False 40 | 41 | # test parameters 42 | test_grayscale = False 43 | 44 | def get_filter_unlearnable(blur_parameter, center_parameter, grayscale, kernel_size, seed, same): 45 | 46 | np.random.seed(seed) 47 | cnns = [] 48 | with torch.no_grad(): 49 | for i in range(10): 50 | cnns.append(torch.nn.Conv2d(3, 3, kernel_size, groups=3, padding=1).cuda()) 51 | if blur_parameter is None: 52 | blur_parameter = 1 53 | w = np.random.uniform(low=0, high=blur_parameter, size=(3,1,3,3)) 54 | # w = np.random.random((3,1,3,3)) 55 | if center_parameter is not None: 56 | shape = w[0][0].shape 57 | w[0, 0, np.random.randint(shape[0]), np.random.randint(shape[1])] = center_parameter 58 | w[1] = w[0] 59 | w[2] = w[0] 60 | # w = w/w.max() 61 | cnns[i].weight.copy_(torch.tensor(w)) 62 | cnns[i].bias.copy_(cnns[i].bias * 0) 63 | cnns = np.stack(cnns) 64 | 65 | if same: 66 | cnns = np.stack([cnns[0]] * len(cnns)) 67 | 68 | unlearnable_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 69 | unlearnable_loader = DataLoader(dataset=unlearnable_dataset, batch_size=500, 70 | shuffle=False, pin_memory=True, 71 | drop_last=False, num_workers=12) 72 | 73 | pbar = tqdm(unlearnable_loader, total=len(unlearnable_loader)) 74 | images_ = [] 75 | 76 | for images, labels in pbar: 77 | images, labels = images.cuda(), labels.cuda() 78 | for i in range(len(images)): 79 | id = labels[i].item() 80 | img = cnns[id](images[i:i+1]).detach().cpu() # convolve 
class-wise 81 | # # black and white 82 | if grayscale: 83 | img_bw = img[0].mean(0) 84 | img[0][0] = img_bw 85 | img[0][1] = img_bw 86 | img[0][2] = img_bw 87 | images_.append(img/img.max()) 88 | 89 | # making unlearnable data 90 | unlearnable_dataset.data = unlearnable_dataset.data.astype(np.float32) 91 | for i in range(len(unlearnable_dataset)): 92 | unlearnable_dataset.data[i] = images_[i][0].numpy().transpose((1,2,0))*255 93 | unlearnable_dataset.data[i] = np.clip(unlearnable_dataset.data[i], a_min=0, a_max=255) 94 | unlearnable_dataset.data = unlearnable_dataset.data.astype(np.uint8) 95 | 96 | return unlearnable_dataset, cnns 97 | 98 | 99 | unlearnable_dataset, cnns = get_filter_unlearnable(blur_parameter, center_parameter, grayscale, kernel_size, seed, same) 100 | 101 | 102 | # get ready for training 103 | train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor()]) 106 | 107 | clean_train_dataset = datasets.CIFAR10(root='../datasets', train=True, download=True, transform=train_transform) 108 | unlearnable_dataset.transforms = clean_train_dataset.transforms 109 | unlearnable_dataset.transform = clean_train_dataset.transform 110 | 111 | unlearnable_loader = DataLoader(dataset=unlearnable_dataset, batch_size=batch_size, 112 | shuffle=True, pin_memory=True, 113 | drop_last=False, num_workers=12) 114 | 115 | clean_train_loader = DataLoader(dataset=clean_train_dataset, batch_size=batch_size, 116 | shuffle=False, pin_memory=True, 117 | drop_last=False, num_workers=12) 118 | 119 | clean_test_loader = DataLoader(dataset=clean_test_dataset, batch_size=batch_size, 120 | shuffle=True, pin_memory=True, 121 | drop_last=False, num_workers=12) 122 | 123 | # Below is the DAT / Filter Adversarial Training (FAT) attack procedure 124 | class Attack(): 125 | 126 | def __init__(self, steps, eta, criterion): 127 | self.steps = steps 128 | self.eta = eta 129 | self.criterion = criterion 130 | 131 | def model_gradients(self, model, value=True): 132 | for param in model.parameters(): 133 | param.requires_grad = value 134 | 135 | def apply_constraints(self): 136 | 137 | for i in range(len(self.tcnns)): 138 | with torch.no_grad(): 139 | temp = self.tcnns[i].weight 140 | # shape = temp[0, 0].shape 141 | M = 5 142 | # temp[0, 0, shape[0]//2, shape[1]//2] = M 143 | self.tcnns[i].weight.copy_(torch.clamp(temp, -M, M)) 144 | temp = self.tcnns[i].bias 145 | self.tcnns[i].bias.copy_(torch.clamp(temp, -M, M)) 146 | 147 | def perturb(self, model, x, y): 148 | 149 | self.model_gradients(model, False) 150 | filter_size = 7 151 | num_classes = 10 152 | self.tcnns = [torch.nn.ConvTranspose2d(1, 1, filter_size, groups=1, padding=filter_size//2).cuda() for i in range(num_classes)] 153 | 154 | for step in range(self.steps): 155 | 156 | self.apply_constraints() 157 | opt = torch.optim.SGD([i.weight for i in self.tcnns] + [i.bias for i in self.tcnns], lr=1e-3) 158 | opt.zero_grad() 159 | model.zero_grad() 160 | 161 | for cls in range(10): 162 | 163 | idx = (y == cls) 164 | x_, y_ = self.tcnns[cls](x[idx].view(len(x[idx])*3, 1, 32, 32)).view(len(x[idx]), 3, 32, 32), y[idx] 165 | logits = model(x_) 166 | 167 | loss = self.criterion(logits, y_) 168 | self.tcnns[cls].weight.retain_grad() 169 | self.tcnns[cls].bias.retain_grad() 170 | loss.backward() 171 | 172 | self.tcnns[cls].weight.data += self.tcnns[cls].weight.grad.data * (+1) * self.eta 173 | self.tcnns[cls].bias.data += self.tcnns[cls].bias.grad.data * (+1) * self.eta 174 | 175 | x_ = torch.zeros(x.shape) 176 | for cls in 
range(10): 177 | idx = (y == cls) 178 | xs = self.tcnns[cls](x[idx].view(len(x[idx])*3, 1, 32, 32)).view(len(x[idx]), 3, 32, 32) 179 | x_[idx] = xs.clone().detach().cpu() 180 | 181 | for i in range(len(x_)): 182 | x_[i] -= x_[i].min() 183 | x_[i] /= x_[i].max() 184 | 185 | self.model_gradients(model, True) 186 | 187 | return x_ 188 | 189 | 190 | # multiplicative adversarial training / Filter Adversarial Training 191 | epochs = 100 192 | model = net(3, 10) 193 | model = model.cuda() 194 | criterion = torch.nn.CrossEntropyLoss() 195 | optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, weight_decay=0.0005, momentum=0.9) 196 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=0) 197 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(epochs*0.4), gamma=0.1) 198 | # atk = torchattacks.PGD(model, eps=8/255, alpha=8/2550, steps=10) 199 | atk = Attack(steps=10, eta=0.1, criterion=criterion) 200 | 201 | train_acc = [] 202 | test_acc = [] 203 | 204 | try: 205 | for epoch in range(epochs): 206 | # Train 207 | model.train() 208 | acc_meter = AverageMeter() 209 | loss_meter = AverageMeter() 210 | pbar = tqdm(unlearnable_loader, total=len(unlearnable_loader)) 211 | 212 | for images, labels in pbar: 213 | images, labels = images.cuda(), labels.cuda() 214 | model.zero_grad() 215 | optimizer.zero_grad() 216 | # adv_images = atk(images, labels) 217 | # logits = model(adv_images) 218 | adv_images = atk.perturb(model, images, labels) 219 | logits = model(adv_images.cuda()) 220 | loss = criterion(logits, labels) 221 | loss.backward() 222 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) 223 | optimizer.step() 224 | 225 | _, predicted = torch.max(logits.data, 1) 226 | acc = (predicted == labels).sum().item()/labels.size(0) 227 | acc_meter.update(acc) 228 | train_acc.append(acc) 229 | loss_meter.update(loss.item()) 230 | pbar.set_description("Acc %.2f Loss: %.2f" % (acc_meter.avg*100, loss_meter.avg)) 231 | 232 | scheduler.step() 233 | 234 | # Eval 235 | model.eval() 236 | correct, total = 0, 0 237 | 238 | for i, (images, labels) in enumerate(clean_test_loader): 239 | images, labels = images.cuda(), labels.cuda() 240 | with torch.no_grad(): 241 | logits = model(images) 242 | _, predicted = torch.max(logits.data, 1) 243 | total += labels.size(0) 244 | correct += (predicted == labels).sum().item() 245 | 246 | acc = correct / total 247 | test_acc.append(acc) 248 | tqdm.write('Clean Accuracy %.2f\n' % (acc*100)) 249 | tqdm.write('Epoch %.2f\n' % (epoch)) 250 | 251 | except: 252 | 253 | bp = (blur_parameter is None)*'x' + str(blur_parameter) 254 | cp = (center_parameter is None)*'x' + str(center_parameter) 255 | kp = str(kernel_size) 256 | sp = str(seed) 257 | gp = str(int(grayscale)) 258 | tgp = str(int(test_grayscale)) 259 | 260 | with open('results/muladv_train_size7_bp={}_cp={}_ks={}_seed={}_g={}_tg={}_same={}.pkl'.format(bp, cp, kp, sp, gp, tgp, int(same)), 'wb') as f: 261 | pkl.dump([train_acc, test_acc], f) 262 | 263 | bp = (blur_parameter is None)*'x' + str(blur_parameter) 264 | cp = (center_parameter is None)*'x' + str(center_parameter) 265 | kp = str(kernel_size) 266 | sp = str(seed) 267 | gp = str(int(grayscale)) 268 | tgp = str(int(test_grayscale)) 269 | 270 | with open('results/muladv_train_size7_bp={}_cp={}_ks={}_seed={}_g={}_tg={}_same={}.pkl'.format(bp, cp, kp, sp, gp, tgp, int(same)), 'wb') as f: 271 | pkl.dump([train_acc, test_acc], f) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
torchattacks==3.2.6 2 | torch==1.8.2 3 | torchvision==0.9.2 4 | numpy==1.19.2 5 | tqdm==4.50.2 6 | matplotlib==3.3.2 -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | ''' ref: 2 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride=1, wide=1): 14 | super(BasicBlock, self).__init__() 15 | planes = planes * wide 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(self.expansion*planes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = self.bn2(self.conv2(out)) 31 | out += self.shortcut(x) 32 | out = F.relu(out) 33 | return out 34 | 35 | 36 | class Bottleneck(nn.Module): 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1, wide=1): 40 | super(Bottleneck, self).__init__() 41 | mid_planes = planes * wide 42 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(mid_planes) 44 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(mid_planes) 46 | self.conv3 = nn.Conv2d(mid_planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, in_dims, out_dims, wide=1): 67 | super(ResNet, self).__init__() 68 | self.wide = wide 69 | self.in_planes = 64 70 | self.conv1 = nn.Conv2d(in_dims, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 77 | self.linear = nn.Linear(512*block.expansion, out_dims) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride, self.wide)) 84 | self.in_planes = 
planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = self.avgpool(out) 94 | out = torch.flatten(out, 1) 95 | out = self.linear(out) 96 | return out 97 | 98 | def feature_extract(self, x): 99 | out = F.relu(self.bn1(self.conv1(x))) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = self.avgpool(out) 105 | out = torch.flatten(out, 1) 106 | return out 107 | 108 | 109 | class WRN(nn.Module): 110 | def __init__(self, num_blocks, in_dims, out_dims, wide=10): 111 | super(WRN, self).__init__() 112 | self.in_planes = 16 113 | self.wide = wide 114 | 115 | block = BasicBlock 116 | 117 | self.conv1 = nn.Conv2d(in_dims, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 118 | self.bn1 = nn.BatchNorm2d(16) 119 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 120 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 121 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 122 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 123 | self.linear = nn.Linear(64*wide, out_dims) 124 | 125 | def _make_layer(self, block, planes, num_blocks, stride): 126 | strides = [stride] + [1]*(num_blocks-1) 127 | layers = [] 128 | for stride in strides: 129 | layers.append(block(self.in_planes, planes, stride, self.wide)) 130 | self.in_planes = planes * self.wide * block.expansion 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | out = F.relu(self.bn1(self.conv1(x))) 136 | out = self.layer1(out) 137 | out = self.layer2(out) 138 | out = self.layer3(out) 139 | # out = F.avg_pool2d(out, 8) 140 | # out = out.view(out.shape[0], -1) 141 | out = self.avgpool(out) 142 | out = torch.flatten(out, 1) 143 | out = self.linear(out) 144 | return out 145 | 146 | 147 | def resnet18(in_dims, out_dims): 148 | return ResNet(BasicBlock, [2,2,2,2], in_dims, out_dims, 1) 149 | 150 | def wrn34_10(in_dims, out_dims): 151 | return WRN([5,5,5], in_dims, out_dims, wide=10) 152 | 153 | def resnet50(in_dims, out_dims): 154 | return ResNet(Bottleneck, [3,4,6,3], in_dims, out_dims, 1) 155 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | if torch.cuda.is_available(): 8 | torch.backends.cudnn.enabled = True 9 | torch.backends.cudnn.benchmark = True 10 | torch.backends.cudnn.deterministic = True 11 | device = torch.device('cuda') 12 | else: 13 | device = torch.device('cpu') 14 | 15 | 16 | def _patch_noise_extend_to_img(noise, image_size=[3, 32, 32], patch_location='center'): 17 | c, h, w = image_size[0], image_size[1], image_size[2] 18 | mask = np.zeros((c, h, w), np.float32) 19 | x_len, y_len = noise.shape[1], noise.shape[2] 20 | 21 | if patch_location == 'center' or (h == w == x_len == y_len): 22 | x = h // 2 23 | y = w // 2 24 | elif patch_location == 'random': 25 | x = np.random.randint(x_len // 2, w - x_len // 2) 26 | y = np.random.randint(y_len // 2, h - y_len // 2) 27 | else: 28 | raise ValueError('Invalid patch location') 29 | 30 | x1 = np.clip(x - x_len // 2, 0, h) 31 | x2 = np.clip(x + x_len // 2, 0, h) 32 | y1 = np.clip(y - y_len // 2, 0, w) 33 | y2 = np.clip(y + y_len // 2, 0, w) 34 | 
mask[:, x1: x2, y1: y2] = noise 35 | return mask 36 | 37 | 38 | def setup_logger(name, log_file, level=logging.INFO): 39 | """To setup as many loggers as you want""" 40 | formatter = logging.Formatter('%(asctime)s %(message)s') 41 | console_handler = logging.StreamHandler() 42 | console_handler.setFormatter(formatter) 43 | file_handler = logging.FileHandler(log_file) 44 | file_handler.setFormatter(formatter) 45 | logger = logging.getLogger(name) 46 | logger.setLevel(level) 47 | logger.addHandler(file_handler) 48 | logger.addHandler(console_handler) 49 | return logger 50 | 51 | 52 | def log_display(epoch, global_step, time_elapse, **kwargs): 53 | display = 'epoch=' + str(epoch) + \ 54 | '\tglobal_step=' + str(global_step) 55 | for key, value in kwargs.items(): 56 | if type(value) == str: 57 | display += '\t' + key + '=' + value 58 | else: 59 | display += '\t' + str(key) + '=%.4f' % value 60 | display += '\ttime=%.2fit/s' % (1. / time_elapse) 61 | return display 62 | 63 | 64 | def accuracy(output, target, topk=(1,)): 65 | maxk = max(topk) 66 | 67 | batch_size = target.size(0) 68 | _, pred = output.topk(maxk, 1, True, True) 69 | pred = pred.t() 70 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 71 | 72 | res = [] 73 | for k in topk: 74 | correct_k = correct[:k].view(-1).float().sum(0) 75 | res.append(correct_k.mul_(1/batch_size)) 76 | return res 77 | 78 | 79 | def save_model(filename, epoch, model, optimizer, scheduler, save_best=False, **kwargs): 80 | # Torch Save State Dict 81 | state = { 82 | 'epoch': epoch+1, 83 | 'model_state_dict': model.state_dict(), 84 | 'optimizer_state_dict': optimizer.state_dict(), 85 | 'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None 86 | } 87 | for key, value in kwargs.items(): 88 | state[key] = value 89 | torch.save(state, filename + '.pth') 90 | filename += '_best.pth' 91 | if save_best: 92 | torch.save(state, filename) 93 | return 94 | 95 | 96 | def load_model(filename, model, optimizer, scheduler, **kwargs): 97 | # Load Torch State Dict 98 | filename = filename + '.pth' 99 | checkpoints = torch.load(filename, map_location=device) 100 | model.load_state_dict(checkpoints['model_state_dict']) 101 | if optimizer is not None and checkpoints['optimizer_state_dict'] is not None: 102 | optimizer.load_state_dict(checkpoints['optimizer_state_dict']) 103 | if scheduler is not None and checkpoints['scheduler_state_dict'] is not None: 104 | scheduler.load_state_dict(checkpoints['scheduler_state_dict']) 105 | return checkpoints 106 | 107 | 108 | def count_parameters_in_MB(model): 109 | return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary_head" not in name)/1e6 110 | 111 | 112 | def build_dirs(path): 113 | if not os.path.exists(path): 114 | os.makedirs(path) 115 | return 116 | 117 | 118 | class AverageMeter(object): 119 | """Computes and stores the average and current value""" 120 | 121 | def __init__(self): 122 | self.reset() 123 | 124 | def reset(self): 125 | self.val = 0 126 | self.avg = 0 127 | self.sum = 0 128 | self.count = 0 129 | self.max = 0 130 | 131 | def update(self, val, n=1): 132 | self.val = val 133 | self.sum += val * n 134 | self.count += n 135 | self.avg = self.sum / self.count 136 | self.max = max(self.max, val) 137 | 138 | 139 | def onehot(size, target): 140 | vec = torch.zeros(size, dtype=torch.float32) 141 | vec[target] = 1. 
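# e.g. onehot(10, 3) -> tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])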
142 | return vec 143 | 144 | 145 | def rand_bbox(size, lam): 146 | if len(size) == 4: 147 | W = size[2] 148 | H = size[3] 149 | elif len(size) == 3: 150 | W = size[1] 151 | H = size[2] 152 | else: 153 | raise Exception 154 | 155 | cut_rat = np.sqrt(1. - lam) 156 | cut_w = np.int(W * cut_rat) 157 | cut_h = np.int(H * cut_rat) 158 | 159 | # uniform 160 | cx = np.random.randint(W) 161 | cy = np.random.randint(H) 162 | 163 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 164 | bby1 = np.clip(cy - cut_h // 2, 0, H) 165 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 166 | bby2 = np.clip(cy + cut_h // 2, 0, H) 167 | 168 | return bbx1, bby1, bbx2, bby2 169 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | 10 | __all__ = [ 11 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 12 | 'vgg19_bn', 'vgg19', 13 | ] 14 | 15 | 16 | class VGG(nn.Module): 17 | ''' VGG model ''' 18 | def __init__(self, features, out_channels): 19 | super(VGG, self).__init__() 20 | self.features = features 21 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, out_channels), 30 | ) 31 | 32 | ''' Initialize weights ''' 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 36 | m.weight.data.normal_(0, math.sqrt(2. / n)) 37 | m.bias.data.zero_() 38 | 39 | 40 | def forward(self, x): 41 | x = self.features(x) 42 | x = self.avgpool(x) 43 | x = torch.flatten(x, 1) 44 | x = self.classifier(x) 45 | return x 46 | 47 | 48 | def make_layers(cfg, in_dims=3, batch_norm=False): 49 | layers = [] 50 | in_channels = in_dims 51 | for v in cfg: 52 | if v == 'M': 53 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 54 | else: 55 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 56 | if batch_norm: 57 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 58 | else: 59 | layers += [conv2d, nn.ReLU(inplace=True)] 60 | in_channels = v 61 | return nn.Sequential(*layers) 62 | 63 | 64 | cfg = { 65 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 66 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 67 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 68 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 69 | } 70 | 71 | 72 | # def vgg11(out_channels=10): 73 | # """VGG 11-layer model (configuration "A")""" 74 | # return VGG(make_layers(cfg['A']), out_channels) 75 | # 76 | # 77 | def vgg11_bn(in_dims=3, out_dims=10): 78 | """VGG 11-layer model (configuration "A") with batch normalization""" 79 | return VGG(make_layers(cfg['A'], in_dims, batch_norm=True), out_dims) 80 | # 81 | # 82 | # def vgg13(out_channels=10): 83 | # """VGG 13-layer model (configuration "B")""" 84 | # return VGG(make_layers(cfg['B']), out_channels) 85 | # 86 | # 87 | # def vgg13_bn(out_channels=10): 88 | # """VGG 13-layer model (configuration "B") with batch normalization""" 89 | # return VGG(make_layers(cfg['B'], batch_norm=True), out_channels) 90 | 91 | 92 | # 
def vgg16(out_dims=10): 93 | # """VGG 16-layer model (configuration "D")""" 94 | # return VGG(make_layers(cfg['D']), out_dims) 95 | 96 | 97 | def vgg16_bn(in_dims=3, out_dims=10): 98 | """VGG 16-layer model (configuration "D") with batch normalization""" 99 | return VGG(make_layers(cfg['D'], in_dims, batch_norm=True), out_dims) 100 | 101 | 102 | # def vgg19(out_channels=10): 103 | # """VGG 19-layer model (configuration "E")""" 104 | # return VGG(make_layers(cfg['E']), out_channels) 105 | 106 | 107 | def vgg19_bn(in_dims=3, out_dims=10): 108 | """VGG 19-layer model (configuration 'E') with batch normalization""" 109 | return VGG(make_layers(cfg['E'], in_dims, batch_norm=True), out_dims) 110 | --------------------------------------------------------------------------------
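For reference, the core CUDA poisoning step is small enough to sketch in isolation: one fixed, randomly initialized depthwise convolution per class, applied to every training image of that class and rescaled to [0, 1]. The sketch below mirrors `get_filter_unlearnable` from final_filter_unlearnable.py with its default blur_parameter=0.3, center_parameter=1.0, and kernel_size=3, but it is an illustrative reimplementation; the names `make_class_filters` and `poison` are not part of the repository.

```python
import numpy as np
import torch


def make_class_filters(num_classes=10, kernel_size=3, blur_parameter=0.3,
                       center_parameter=1.0, seed=0):
    """Build one fixed depthwise conv per class; the weights are never trained."""
    np.random.seed(seed)
    filters = []
    for _ in range(num_classes):
        conv = torch.nn.Conv2d(3, 3, kernel_size, groups=3,
                               padding=kernel_size // 2)
        w = np.random.uniform(0, blur_parameter,
                              size=(3, 1, kernel_size, kernel_size))
        # plant one strong coefficient at a random tap, shared by all channels
        w[0, 0, np.random.randint(kernel_size),
          np.random.randint(kernel_size)] = center_parameter
        w[1] = w[0]
        w[2] = w[0]
        with torch.no_grad():
            conv.weight.copy_(torch.tensor(w, dtype=torch.float32))
            conv.bias.zero_()
        filters.append(conv)
    return filters


def poison(images, labels, filters):
    """Convolve each image with its class's filter, then rescale to [0, 1]."""
    out = images.clone()
    with torch.no_grad():
        for i in range(len(images)):
            img = filters[labels[i].item()](images[i:i + 1])
            out[i] = img[0] / img.max()
    return out
```

A CPU batch `x, y` of shape (N, 3, 32, 32) from a CIFAR loader can then be poisoned with `poison(x, y, make_class_filters())`; the scripts above apply the same operation once, offline, to the entire training set and cast the result back to uint8.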