├── attack
│   ├── __init__.py
│   └── attack.py
├── pruner
│   ├── __init__.py
│   └── kfac_MAD_pruner.py
├── models
│   ├── operator
│   │   ├── __init__.py
│   │   └── mask.py
│   ├── __init__.py
│   ├── vgg.py
│   ├── vgg_mask.py
│   ├── wide.py
│   ├── resnet.py
│   ├── wide_mask.py
│   └── resnet_mask.py
├── figure
│   ├── pruning ratio.png
│   ├── adversarial saliency.png
│   └── semantic information.png
├── utils
│   ├── mask_parameter_generator_utils.py
│   ├── network_utils.py
│   ├── mask_network_utils.py
│   ├── data_utils.py
│   ├── common_utils.py
│   ├── utils.py
│   └── kfac_utils.py
├── LICENSE
├── compute_saliency.py
├── main_pretrain.py
├── test.py
├── main_adv_pretrain.py
├── main_mart_pretrain.py
├── main_mad_pretrain.py
├── main_trades_pretrain.py
└── README.md
/attack/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pruner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/operator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figure/pruning ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Masking-Adversarial-Damage/HEAD/figure/pruning ratio.png -------------------------------------------------------------------------------- /figure/adversarial saliency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Masking-Adversarial-Damage/HEAD/figure/adversarial saliency.png -------------------------------------------------------------------------------- /figure/semantic information.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ByungKwanLee/Masking-Adversarial-Damage/HEAD/figure/semantic information.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .resnet import * 3 | from .wide import * 4 | 5 | from .vgg_mask import * 6 | from .resnet_mask import * 7 | -------------------------------------------------------------------------------- /utils/mask_parameter_generator_utils.py: -------------------------------------------------------------------------------- 1 | class MaskParameterGenerator(object): 2 | def __init__(self, model): 3 | self.model = model 4 | 5 | def mask_parameters(self): 6 | for name, param in self.model.named_parameters(): 7 | if ('mask' in name) and (param is not None): 8 | yield param 9 | 10 | def non_mask_parameters(self): 11 | for name, param in self.model.named_parameters(): 12 | if 'mask' not in name: 13 | yield param 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ByungKwanLee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/network_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.vgg import VGG 3 | from models.resnet import resnet 4 | from models.wide import wide_resnet 5 | 6 | 7 | def get_network(network, depth, dataset, device=None): 8 | 9 | if dataset == 'cifar10': 10 | mean = torch.tensor([0.4914, 0.4822, 0.4465]).to(device) 11 | std = torch.tensor([0.2023, 0.1994, 0.2010]).to(device) 12 | elif dataset == 'svhn': # later, it should be updated 13 | mean = torch.tensor([0.43090966, 0.4302428, 0.44634357]).to(device) 14 | std = torch.tensor([0.19759192, 0.20029082, 0.19811132]).to(device) 15 | elif dataset == 'cifar100': 16 | mean = torch.tensor([0.5071, 0.4867, 0.4408]).to(device) 17 | std = torch.tensor([0.2675, 0.2565, 0.2761]).to(device) 18 | elif dataset == 'tiny': 19 | mean = torch.tensor([0.48024578664982126, 0.44807218089384643, 0.3975477478649648]).to(device) 20 | std = torch.tensor([0.2769864069088257, 0.26906448510256, 0.282081906210584]).to(device) 21 | 22 | if network == 'vgg': 23 | return VGG(depth=depth, dataset=dataset, mean=mean, std=std) 24 | elif network == 'resnet': 25 | return resnet(depth=depth, dataset=dataset, mean=mean, std=std) 26 | elif network == 'wide': 27 | return wide_resnet(depth=depth, widen_factor=10, dataset=dataset, mean=mean, std=std) 28 | else: 29 | raise NotImplementedError 30 | 31 | 32 | def stablize_bn(net, trainloader, device='cuda'): 33 | """Iterate over the dataset for stabilizing the 34 | BatchNorm statistics. 
35 | """ 36 | net = net.train() 37 | for batch, (inputs, _) in enumerate(trainloader): 38 | inputs = inputs.to(device) 39 | net(inputs) -------------------------------------------------------------------------------- /utils/mask_network_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from models.vgg_mask import VGG 4 | from models.resnet_mask import resnet 5 | from models.wide_mask import wide_resnet 6 | 7 | def get_mask_network(network, depth, dataset, device): 8 | 9 | if dataset == 'cifar10': 10 | mean = torch.tensor([0.4914, 0.4822, 0.4465]).to(device) 11 | std = torch.tensor([0.2023, 0.1994, 0.2010]).to(device) 12 | elif dataset == 'svhn': # later, it should be updated 13 | mean = torch.tensor([0.43090966, 0.4302428, 0.44634357]).to(device) 14 | std = torch.tensor([0.19759192, 0.20029082, 0.19811132]).to(device) 15 | elif dataset == 'cifar100': 16 | mean = torch.tensor([0.5071, 0.4867, 0.4408]).to(device) 17 | std = torch.tensor([0.2675, 0.2565, 0.2761]).to(device) 18 | elif dataset == 'tiny': 19 | mean = torch.tensor([0.48024578664982126, 0.44807218089384643, 0.3975477478649648]).to(device) 20 | std = torch.tensor([0.2769864069088257, 0.26906448510256, 0.282081906210584]).to(device) 21 | 22 | if network == 'vgg': 23 | return VGG(depth=depth, dataset=dataset, mean=mean, std=std) 24 | elif network == 'resnet': 25 | return resnet(depth=depth, dataset=dataset, mean=mean, std=std) 26 | elif network == 'wide': 27 | return wide_resnet(depth=depth, widen_factor=10, dataset=dataset, mean=mean, std=std) 28 | else: 29 | raise NotImplementedError 30 | 31 | 32 | def stablize_bn(net, trainloader, device='cuda'): 33 | """Iterate over the dataset for stabilizing the 34 | BatchNorm statistics. 
35 | """ 36 | net = net.train() 37 | for batch, (inputs, _) in enumerate(trainloader): 38 | inputs = inputs.to(device) 39 | net(inputs) 40 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | _AFFINE = True 6 | # _AFFINE = False 7 | 8 | defaultcfg = { 9 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 10 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 12 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 13 | } 14 | 15 | 16 | class VGG(nn.Module): 17 | def __init__(self, dataset='cifar10', depth=19, mean=None, std=None, init_weights=True, cfg=None): 18 | super(VGG, self).__init__() 19 | if cfg is None: 20 | cfg = defaultcfg[depth] 21 | 22 | self.mean = mean.view(1, -1, 1, 1) 23 | self.std = std.view(1, -1, 1, 1) 24 | 25 | self.feature = self.make_layers(cfg, True) 26 | self.dataset = dataset 27 | if dataset == 'cifar10' or dataset == 'svhn': 28 | num_classes = 10 29 | elif dataset == 'cifar100': 30 | num_classes = 100 31 | elif dataset == 'tiny': 32 | num_classes = 200 33 | self.classifier = nn.Linear(cfg[-1], num_classes) 34 | if init_weights: 35 | self._initialize_weights() 36 | 37 | def make_layers(self, cfg, batch_norm=False): 38 | layers = [] 39 | in_channels = 3 40 | for v in cfg: 41 | if v == 'M': 42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else: 44 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 45 | if batch_norm: 46 | layers += [conv2d, nn.BatchNorm2d(v, affine=_AFFINE), nn.ReLU(inplace=True)] 47 | else: 48 | layers += [conv2d, nn.ReLU(inplace=True)] 49 | in_channels = v 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | x = (x-self.mean) / self.std 54 | x = self.feature(x) 55 | if self.dataset == 'tiny': 56 | x = nn.AvgPool2d(4)(x) 57 | else: 58 | x = nn.AvgPool2d(2)(x) 59 | x = x.view(x.size(0), -1) 60 | y = self.classifier(x) 61 | return y 62 | 63 | def _initialize_weights(self): 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 67 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.BatchNorm2d): 71 | if m.weight is not None: 72 | m.weight.data.fill_(1.0) 73 | m.bias.data.zero_() 74 | elif isinstance(m, nn.Linear): 75 | m.weight.data.normal_(0, 0.01) 76 | m.bias.data.zero_() 77 | -------------------------------------------------------------------------------- /models/vgg_mask.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from models.operator.mask import Conv2d_mask, Linear_mask 5 | 6 | _AFFINE = True 7 | # _AFFINE = False 8 | 9 | defaultcfg = { 10 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 13 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 14 | } 15 | 16 | 17 | class VGG(nn.Module): 18 | def __init__(self, dataset='cifar10', depth=19, mean=None, std=None, init_weights=True, cfg=None): 19 | super(VGG, self).__init__() 20 | if cfg is None: 21 | cfg = defaultcfg[depth] 22 | self.mean = mean.view(1, -1, 1, 1) 23 | self.std = std.view(1, -1, 1, 1) 24 | 25 | self.feature = self.make_layers(cfg, True) 26 | self.dataset = dataset 27 | if dataset == 'cifar10' or dataset == 'svhn': 28 | num_classes = 10 29 | elif dataset == 'cifar100': 30 | num_classes = 100 31 | elif dataset == 'tiny': 32 | num_classes = 200 33 | self.classifier = Linear_mask(cfg[-1], num_classes) 34 | if init_weights: 35 | self._initialize_weights() 36 | 37 | def make_layers(self, cfg, batch_norm=False): 38 | layers = [] 39 | in_channels = 3 40 | for v in cfg: 41 | if v == 'M': 42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else: 44 | conv2d = Conv2d_mask(in_channels, v, kernel_size=3, padding=1, bias=False) 45 | if batch_norm: 46 | layers += [conv2d, nn.BatchNorm2d(v, affine=_AFFINE), nn.ReLU(inplace=True)] 47 | else: 48 | layers += [conv2d, nn.ReLU(inplace=True)] 49 | in_channels = v 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x, pop=False, inter=False): 53 | # Feature visualization 54 | if not inter: 55 | x = (x-self.mean) / self.std 56 | x = self.feature(x) 57 | if pop: 58 | return x 59 | 60 | 61 | if self.dataset == 'tiny': 62 | x = nn.AvgPool2d(4)(x) 63 | else: 64 | x = nn.AvgPool2d(2)(x) 65 | x = x.view(x.size(0), -1) 66 | y = self.classifier(x) 67 | return y 68 | 69 | def _initialize_weights(self): 70 | for m in self.modules(): 71 | if isinstance(m, Conv2d_mask): 72 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 73 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 74 | if m.bias is not None: 75 | m.bias.data.zero_() 76 | elif isinstance(m, nn.BatchNorm2d): 77 | if m.weight is not None: 78 | m.weight.data.fill_(1.0) 79 | m.bias.data.zero_() 80 | elif isinstance(m, Linear_mask): 81 | m.weight.data.normal_(0, 0.01) 82 | m.bias.data.zero_() -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def get_transforms(dataset): 7 | transform_train = None 8 | transform_test = None 9 | if dataset == 'cifar10': 10 | transform_train = transforms.Compose([ 11 | transforms.RandomCrop(32, padding=4), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | ]) 15 | 16 | transform_test = transforms.Compose([ 17 | transforms.ToTensor(), 18 | ]) 19 | 20 | if dataset == 'cifar100': 21 | transform_train = transforms.Compose([ 22 | transforms.RandomCrop(32, padding=4), 23 | transforms.RandomHorizontalFlip(), 24 | transforms.ToTensor(), 25 | ]) 26 | 27 | transform_test = transforms.Compose([ 28 | transforms.ToTensor(), 29 | ]) 30 | 31 | if dataset == 'svhn': 32 | transform_train = transforms.Compose([ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | ]) 37 | 38 | transform_test = transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]) 41 | 53 | if dataset == 'tiny': 54 | transform_train = transforms.Compose([ 55 | transforms.RandomCrop(64, padding=4), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | ]) 59 | 60 | transform_test = transforms.Compose([ 61 | transforms.ToTensor(), 62 | ]) 63 | 64 | assert transform_test is not None and transform_train is not None, 'Error, no dataset %s' % dataset 65 | return transform_train, transform_test 66 | 67 | 68 | def get_dataloader(dataset, train_batch_size, test_batch_size, num_workers=0, root='../data', is_test=True): 69 | transform_train, transform_test = get_transforms(dataset) 70 | trainset, testset = None, None 71 | if dataset == 'cifar10': 72 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train) 73 | testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test) 74 | 75 | if dataset == 'cifar100': 76 | trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train) 77 | testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test) 78 | 79 | if dataset == 'svhn': 80 | trainset = torchvision.datasets.SVHN(root=root, split='train', download=True, transform=transform_train) 81 | testset = torchvision.datasets.SVHN(root=root, split='test', download=True, transform=transform_test) 82 | 83 | if dataset == 'tiny': 84 | trainset = torchvision.datasets.ImageFolder(root + '/tiny-imagenet-200/train', transform=transform_train) 85 | testset = torchvision.datasets.ImageFolder(root + '/tiny-imagenet-200/val', transform=transform_test) 86 | 87 | assert trainset is not None and testset is not None, 'Error, no dataset %s' % dataset 88 | trainloader = 
torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, 89 | num_workers=num_workers) 90 | if not is_test: 91 | return trainloader, None 92 | 93 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, 94 | num_workers=num_workers) 95 | 96 | return trainloader, testloader -------------------------------------------------------------------------------- /models/wide.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | def forward(self, x): 23 | if not self.equalInOut: 24 | x = self.relu1(self.bn1(x)) 25 | else: 26 | out = self.relu1(self.bn1(x)) 27 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 28 | if self.droprate > 0: 29 | out = F.dropout(out, p=self.droprate, training=self.training) 30 | out = self.conv2(out) 31 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 32 | 33 | class NetworkBlock(nn.Module): 34 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 35 | super(NetworkBlock, self).__init__() 36 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 37 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 38 | layers = [] 39 | for i in range(int(nb_layers)): 40 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 41 | return nn.Sequential(*layers) 42 | def forward(self, x): 43 | return self.layer(x) 44 | 45 | class WideResNet(nn.Module): 46 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, mean=None, std=None, spatial_expansion=False): 47 | super(WideResNet, self).__init__() 48 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 49 | assert((depth - 4) % 6 == 0) 50 | n = (depth - 4) / 6 51 | 52 | self.mean = mean.view(1, -1, 1, 1) 53 | self.std = std.view(1, -1, 1, 1) 54 | self.spatial_expansion = spatial_expansion 55 | 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 2nd block 63 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 64 | # 3rd block 65 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 66 | # global average pooling and classifier 67 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.fc = nn.Linear(nChannels[3], num_classes) 70 | self.nChannels = nChannels[3] 71 | 72 | for m in 
self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 75 | elif isinstance(m, nn.BatchNorm2d): 76 | m.weight.data.fill_(1) 77 | m.bias.data.zero_() 78 | elif isinstance(m, nn.Linear): 79 | m.bias.data.zero_() 80 | def forward(self, x): 81 | out = (x - self.mean) / self.std 82 | out = self.conv1(out) 83 | out = self.block1(out) 84 | out = self.block2(out) 85 | out = self.block3(out) 86 | out = self.relu(self.bn1(out)) 87 | 88 | if self.spatial_expansion: 89 | out = F.avg_pool2d(out, 16) 90 | else: 91 | out = F.avg_pool2d(out, 8) 92 | out = out.view(-1, self.nChannels) 93 | return self.fc(out) 94 | 95 | def wide_resnet(depth=28, widen_factor=10, dataset='cifar10', mean=None, std=None): 96 | if dataset == 'cifar10' or dataset == 'svhn': 97 | num_classes = 10 98 | spatial_expansion = False 99 | elif dataset == 'cifar100': 100 | num_classes = 100 101 | spatial_expansion = False 102 | elif dataset == 'tiny': 103 | num_classes = 200 104 | spatial_expansion = True 105 | else: 106 | raise NotImplementedError 107 | return WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen_factor, dropRate=0.3, 108 | mean=mean, std=std, spatial_expansion=spatial_expansion) -------------------------------------------------------------------------------- /compute_saliency.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from tqdm import tqdm 9 | from utils.mask_network_utils import get_mask_network 10 | from utils.data_utils import get_dataloader 11 | 12 | # attack loader 13 | from attack.attack import attack_loader 14 | from pruner.kfac_MAD_pruner import KFACMADPruner 15 | 16 | from utils.mask_parameter_generator_utils import MaskParameterGenerator 17 | from models.operator.mask import * 18 | 19 | 20 | # fetch args 21 | parser = argparse.ArgumentParser() 22 | 23 | 24 | # model parameter 25 | parser.add_argument('--dataset', default='cifar10', type=str) 26 | parser.add_argument('--network', default='vgg', type=str) 27 | parser.add_argument('--depth', default=16, type=int) 28 | parser.add_argument('--device', default='cuda:0', type=str) 29 | parser.add_argument('--batch_size', default=128, type=int) 30 | 31 | # attack parameter 32 | parser.add_argument('--attack', default='pgd', type=str) 33 | parser.add_argument('--eps', default=0.03, type=float) 34 | parser.add_argument('--steps', default=10, type=int) 35 | args = parser.parse_args() 36 | 37 | 38 | # init dataloader 39 | trainloader, testloader = get_dataloader(dataset=args.dataset, 40 | train_batch_size=args.batch_size, 41 | test_batch_size=256, 42 | is_test=False) 43 | 44 | # init model 45 | net = get_mask_network(network=args.network, 46 | depth=args.depth, 47 | dataset=args.dataset, 48 | device=args.device) 49 | net = net.to(args.device) 50 | 51 | 52 | # Load Plain Network 53 | print('==> Loading Plain checkpoint..') 54 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 
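# Note: MaskParameterGenerator (utils/mask_parameter_generator_utils.py) yields every
# parameter whose name contains 'mask' -- the learnable gates of the Conv2d_mask /
# Linear_mask operators -- separately from the ordinary weights, so an optimizer can
# be built over the gates alone while the pretrained weights stay fixed, e.g.
# (illustrative sketch, not executed here):
#   gen = MaskParameterGenerator(net)
#   mask_optimizer = optim.Adam(gen.mask_parameters(), lr=0.1)   # updates the gates only
#   frozen_weights = list(gen.non_mask_parameters())             # left untouched by MAD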
55 | 56 | ''' ------------------------------------------------------------------------------------------------------------- ''' 57 | if not os.path.isdir('pickle'): 58 | os.mkdir('pickle') 59 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth), map_location=args.device) 60 | pickle_path = './pickle/%s_adv_%s%s_saliency.pickle' % (args.dataset, args.network, args.depth) 61 | ''' ------------------------------------------------------------------------------------------------------------- ''' 62 | print(pickle_path) 63 | net.load_state_dict(checkpoint['net'], strict=False) 64 | 65 | # Attack loader 66 | attack = attack_loader(net=net, attack=args.attack, eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 67 | 68 | # init criterion 69 | criterion = nn.CrossEntropyLoss() 70 | 71 | pruner = KFACMADPruner(net, attack, args.device, dataset=args.dataset) 72 | mask_model = MaskParameterGenerator(net) 73 | 74 | 75 | # [KFAC Masking Adversarial Damage (MAD)] 76 | def optimizing_mask(): 77 | 78 | total = 0 79 | correct = 0 80 | 81 | adv_delta_L_avg = [] 82 | 83 | desc = ('[Mask Optimizing for total Dataset] R Acc: %.3f%% (%d/%d)' % 84 | (0, correct, total)) 85 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 86 | for batch_idx, (inputs, targets) in prog_bar: 87 | net.eval() 88 | adv_x = attack(inputs, targets) 89 | inputs, adv_x, targets = inputs.to(args.device), adv_x.to(args.device), targets.to(args.device) 90 | 91 | # Adv mask optimizer [KFAC Masking Adversarial Damage (MAD)] 92 | mask_optimizer = optim.Adam(mask_model.mask_parameters(), lr=0.1) 93 | _, adv_delta_L_list, _, a_outputs\ 94 | = pruner._optimize_mask(adv_x, targets, mask_optimizer, mask_epoch=20, debug_acc=False, is_compute_delta_L=True) 95 | 96 | # performance validation 97 | _, a_predicted = a_outputs.max(1) 98 | a_num = a_predicted.eq(targets).sum().item() 99 | 100 | total += targets.size(0) 101 | correct += a_num 102 | 103 | # averaging 104 | if len(adv_delta_L_avg) == 0: 105 | 106 | adv_delta_L_avg = adv_delta_L_list 107 | 108 | adv_delta_L_avg = [x / len(trainloader) for x in adv_delta_L_avg] 109 | 110 | else: 111 | for index, l in enumerate(adv_delta_L_list): 112 | adv_delta_L_avg[index] += l / len(trainloader) 113 | 114 | 115 | desc = ('[Mask Optimizing for total Dataset] R Acc: %.1f%% (%d/%d)' % 116 | (100. 
* correct / total, correct, total)) 117 | prog_bar.set_description(desc, refresh=True) 118 | 119 | # pickle dictionary by converting torch.tensor.cuda to cpu 120 | pickle_dict = {} 121 | pickle_dict['adv_delta_L_avg'] = [x.cpu() for x in adv_delta_L_avg] 122 | 123 | # save 124 | import pickle 125 | with open(pickle_path, 'wb') as f: 126 | pickle.dump(pickle_dict, f, pickle.HIGHEST_PROTOCOL) 127 | 128 | onehot_dict = optimizing_mask() 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /attack/attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchattacks 3 | from torchattacks.attack import Attack 4 | 5 | class FGSM_train(Attack): 6 | 7 | def __init__(self, model, eps=0.007): 8 | super().__init__("FGSM_train", model) 9 | self.eps = eps 10 | self._supported_mode = ['default', 'targeted'] 11 | 12 | def forward(self, images, labels): 13 | 14 | images = images.clone().detach().to(self.device) 15 | labels = labels.clone().detach().to(self.device) 16 | 17 | adv_images = images.clone().detach() 18 | 19 | # Starting at a uniformly random point 20 | adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps) 21 | 22 | if self._targeted: 23 | target_labels = self._get_target_label(images, labels) 24 | 25 | loss = torch.nn.CrossEntropyLoss() 26 | 27 | adv_images.requires_grad = True 28 | outputs = self.model(adv_images) 29 | 30 | # Calculate loss 31 | if self._targeted: 32 | cost = -loss(outputs, target_labels) 33 | else: 34 | cost = loss(outputs, labels) 35 | 36 | # Update adversarial images 37 | grad = torch.autograd.grad(cost, adv_images, 38 | retain_graph=False, create_graph=False)[0] 39 | 40 | adv_images_ = adv_images.detach() + 1.25 * self.eps*grad.sign() 41 | delta = torch.clamp(adv_images_ - images, min=-self.eps, max=self.eps) 42 | return torch.clamp(images + delta, min=0, max=1).detach() 43 | 44 | class CW_Linf(Attack): 45 | 46 | def __init__(self, model, eps, c=0.1, kappa=0, steps=1000, lr=0.01): 47 | super().__init__("CW_Linf", model) 48 | self.eps = eps 49 | self.alpha = eps/steps * 2.3 50 | self.c = c 51 | self.kappa = kappa 52 | self.steps = steps 53 | self.lr = lr 54 | self._supported_mode = ['default', 'targeted'] 55 | 56 | def forward(self, images, labels): 57 | 58 | images = images.clone().detach().to(self.device) 59 | labels = labels.clone().detach().to(self.device) 60 | 61 | adv_images = images.clone().detach() 62 | 63 | 64 | # Starting at a uniformly random point 65 | adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps) 66 | adv_images = torch.clamp(adv_images, min=0, max=1).detach() 67 | 68 | 69 | for step in range(self.steps): 70 | 71 | adv_images.requires_grad = True 72 | 73 | outputs = self.model(adv_images) 74 | f_loss = self.f(outputs, labels).sum() 75 | cost = f_loss 76 | 77 | grad = torch.autograd.grad(cost, adv_images, retain_graph=False, create_graph=False)[0] 78 | adv_images = adv_images.detach() - self.alpha * grad.sign() 79 | delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps) 80 | adv_images = torch.clamp(images + delta, min=0, max=1).detach() 81 | 82 | return adv_images 83 | 84 | def tanh_space(self, x): 85 | return 1/2*(torch.tanh(x) + 1) 86 | 87 | def inverse_tanh_space(self, x): 88 | # torch.atanh is only for torch >= 1.7.0 89 | return self.atanh(x*2-1) 90 | 91 | def atanh(self, x): 92 | return 0.5*torch.log((1+x)/(1-x)) 93 | 94 | # f-function 
in the paper 95 | def f(self, outputs, labels): 96 | one_hot_labels = torch.eye(len(outputs[0]))[labels].to(self.device) 97 | 98 | i, _ = torch.max((1-one_hot_labels)*outputs, dim=1) 99 | j = torch.masked_select(outputs, one_hot_labels.bool()) 100 | 101 | if self._targeted: 102 | return torch.clamp((i-j), min=-self.kappa) 103 | else: 104 | return torch.clamp((j-i), min=-self.kappa) 105 | 106 | def attack_loader(net, attack, eps, steps, dataset, device): 107 | 108 | if dataset == 'cifar10': 109 | n_channel = 3 110 | n_classes = 10 111 | img_size = 32 112 | elif dataset == 'cifar100': 113 | n_channel = 3 114 | n_classes = 100 115 | img_size = 32 116 | elif dataset == 'svhn': 117 | n_channel = 3 118 | n_classes = 10 119 | img_size = 32 120 | elif dataset == 'tiny': 121 | n_channel = 3 122 | n_classes = 200 123 | img_size = 64 124 | 125 | # torch attacks 126 | if attack == "fgsm": 127 | return torchattacks.FGSM(model=net, eps=eps) 128 | 129 | elif attack == "fgsm_train": 130 | return FGSM_train(model=net, eps=eps) 131 | 132 | elif attack == "bim": 133 | return torchattacks.BIM(model=net, eps=eps, alpha=1/255) 134 | 135 | elif attack == "pgd": 136 | return torchattacks.PGD(model=net, eps=eps, 137 | alpha=eps/steps*2.3, steps=steps, random_start=True) 138 | 139 | elif attack == "cw_linf": 140 | return CW_Linf(model=net, eps=eps, lr=0.1, steps=30) 141 | 142 | elif attack == "apgd": 143 | return torchattacks.APGD(model=net, eps=eps, loss='ce', steps=30) 144 | 145 | elif attack == "auto": 146 | return torchattacks.AutoAttack(model=net, eps=eps, n_classes=n_classes) 147 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d( 12 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 15 | stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion*planes, 22 | kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(self.expansion*planes) 24 | ) 25 | 26 | def forward(self, x): 27 | out = F.relu(self.bn1(self.conv1(x))) 28 | out = self.bn2(self.conv2(out)) 29 | out += self.shortcut(x) 30 | out = F.relu(out) 31 | return out 32 | 33 | 34 | class Bottleneck(nn.Module): 35 | expansion = 4 36 | 37 | def __init__(self, in_planes, planes, stride=1): 38 | super(Bottleneck, self).__init__() 39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 42 | stride=stride, padding=1, bias=False) 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.conv3 = nn.Conv2d(planes, self.expansion * 45 | planes, kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, 52 | kernel_size=1, stride=stride, bias=False), 53 | 
nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10, mean=None, std=None, spatial_expansion=False): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.mean = mean.view(1, -1, 1, 1) 71 | self.std = std.view(1, -1, 1, 1) 72 | self.spatial_expansion = spatial_expansion 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 75 | stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(64) 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(512*block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1]*(num_blocks-1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | out = (x - self.mean) / self.std 93 | out = F.relu(self.bn1(self.conv1(out))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | 99 | if self.spatial_expansion: 100 | out = F.avg_pool2d(out, 8) 101 | else: 102 | out = F.avg_pool2d(out, 4) 103 | out = out.view(out.size(0), -1) 104 | out = self.linear(out) 105 | return out 106 | 107 | 108 | def resnet(depth=18, dataset='cifar10', mean=None, std=None): 109 | if dataset == 'cifar10' or dataset == 'svhn': 110 | num_classes = 10 111 | spatial_expansion = False 112 | elif dataset == 'cifar100': 113 | num_classes = 100 114 | spatial_expansion = False 115 | elif dataset == 'tiny': 116 | num_classes = 200 117 | spatial_expansion = True 118 | else: 119 | raise NotImplementedError 120 | 121 | 122 | if depth == 18: 123 | block = BasicBlock 124 | num_blocks = [2, 2, 2, 2] 125 | elif depth == 34: 126 | block = BasicBlock 127 | num_blocks = [3, 4, 6, 3] 128 | elif depth == 50: 129 | block = Bottleneck 130 | num_blocks = [3, 4, 6, 3] 131 | else: 132 | raise NotImplementedError 133 | 134 | return ResNet(block=block, num_blocks=num_blocks, num_classes=num_classes, 135 | mean=mean, std=std, spatial_expansion=spatial_expansion) 136 | 137 | 138 | # test() -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import logging 5 | 6 | import torch 7 | 8 | from pprint import pprint 9 | from easydict import EasyDict as edict 10 | 11 | 12 | def get_logger(name, logpath, filepath, package_files=[], 13 | displaying=True, saving=True): 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.INFO) 16 | log_path = logpath + name + time.strftime("-%Y%m%d-%H%M%S") 17 | makedirs(log_path) 18 | if saving: 19 | info_file_handler = logging.FileHandler(log_path) 20 | info_file_handler.setLevel(logging.INFO) 21 | logger.addHandler(info_file_handler) 22 | logger.info(filepath) 23 | with open(filepath, 'r') as f: 24 | 
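# dump the launching script's full source into the log so each run records the exact code it ran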
logger.info(f.read()) 25 | 26 | for f in package_files: 27 | logger.info(f) 28 | with open(f, 'r') as package_f: 29 | logger.info(package_f.read()) 30 | if displaying: 31 | console_handler = logging.StreamHandler() 32 | console_handler.setLevel(logging.INFO) 33 | logger.addHandler(console_handler) 34 | 35 | return logger 36 | 37 | 38 | def makedirs(filename): 39 | if not os.path.exists(os.path.dirname(filename)): 40 | os.makedirs(os.path.dirname(filename)) 41 | 42 | 43 | def str_to_list(src: object, delimiter: object, converter: object) -> object: 44 | """Conver a string to list. 45 | """ 46 | src_split = src.split(delimiter) 47 | res = [converter(_) for _ in src_split] 48 | return res 49 | 50 | 51 | def get_config_from_json(json_file): 52 | """ 53 | Get the config from a json file 54 | :param json_file: 55 | :return: config(namespace) or config(dictionary) 56 | """ 57 | # parse the configurations from the config json file provided 58 | with open(json_file, 'r') as config_file: 59 | config_dict = json.load(config_file) 60 | config = edict(config_dict) 61 | 62 | return config, config_dict 63 | 64 | 65 | def process_config(json_file): 66 | """Process a json file into a config file. 67 | Where we can access the value using .xxx 68 | Note: we will need to create a similar directory as the config file. 69 | """ 70 | config, _ = get_config_from_json(json_file) 71 | paths = json_file.split('/')[1:-1] 72 | summary_dir = ["./runs/pruning"] + paths + [config.exp_name, "summary/"] 73 | ckpt_dir = ["./runs/pruning"] + paths + [config.exp_name, "checkpoint/"] 74 | config.summary_dir = os.path.join(*summary_dir) 75 | config.checkpoint_dir = os.path.join(*ckpt_dir) 76 | return config 77 | 78 | 79 | def try_contiguous(x): 80 | if not x.is_contiguous(): 81 | x = x.contiguous() 82 | 83 | return x 84 | 85 | 86 | def try_cuda(x): 87 | if torch.cuda.is_available(): 88 | x = x.cuda() 89 | return x 90 | 91 | 92 | def tensor_to_list(tensor): 93 | if len(tensor.shape) == 1: 94 | return [tensor[_].item() for _ in range(tensor.shape[0])] 95 | else: 96 | return [tensor_to_list(tensor[_]) for _ in range(tensor.shape[0])] 97 | 98 | 99 | # ===================================================== 100 | # For learning rate schedule 101 | # ===================================================== 102 | class StairCaseLRScheduler(object): 103 | def __init__(self, start_at, interval, decay_rate): 104 | self.start_at = start_at 105 | self.interval = interval 106 | self.decay_rate = decay_rate 107 | 108 | def __call__(self, optimizer, iteration): 109 | start_at = self.start_at 110 | interval = self.interval 111 | decay_rate = self.decay_rate 112 | if (start_at >= 0) \ 113 | and (iteration >= start_at) \ 114 | and (iteration + 1) % interval == 0: 115 | for param_group in optimizer.param_groups: 116 | param_group['lr'] *= decay_rate 117 | print('[%d]Decay lr to %f' % (iteration, param_group['lr'])) 118 | 119 | @staticmethod 120 | def get_lr(optimizer): 121 | for param_group in optimizer.param_groups: 122 | lr = param_group['lr'] 123 | return lr 124 | 125 | 126 | class PresetLRScheduler(object): 127 | """Using a manually designed learning rate schedule rules. 
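    Example (illustrative, matching the defaults in main_pretrain.py):
        PresetLRScheduler({0: 0.1, 100: 0.01, 150: 0.001})
    holds lr at 0.1 until epoch 100, steps to 0.01, then to 0.001 at epoch 150;
    epochs absent from the dictionary leave the current learning rate unchanged.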
128 | """ 129 | def __init__(self, decay_schedule): 130 | # decay_schedule is a dictionary 131 | # which is for specifying iteration -> lr 132 | self.decay_schedule = decay_schedule 133 | print('=> Using a preset learning rate schedule:') 134 | pprint(decay_schedule) 135 | self.for_once = True 136 | 137 | def __call__(self, optimizer, iteration): 138 | for param_group in optimizer.param_groups: 139 | lr = self.decay_schedule.get(iteration, param_group['lr']) 140 | param_group['lr'] = lr 141 | 142 | @staticmethod 143 | def get_lr(optimizer): 144 | for param_group in optimizer.param_groups: 145 | lr = param_group['lr'] 146 | return lr 147 | 148 | 149 | # ======================================================= 150 | # For math computation 151 | # ======================================================= 152 | def prod(l): 153 | val = 1 154 | if isinstance(l, list): 155 | for v in l: 156 | val *= v 157 | else: 158 | val = val * l 159 | 160 | return val -------------------------------------------------------------------------------- /models/operator/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Linear, Parameter 3 | import torch.nn.functional as F 4 | from torch.nn.modules.utils import _pair 5 | import math 6 | 7 | initial_mask = 1 8 | name = 'softplus' 9 | thresh = 2 10 | 11 | def f(x, name=name): 12 | if name == 'softplus': 13 | return F.softplus(x) 14 | elif name == 'sigmoid': 15 | return torch.sigmoid(x) 16 | elif name == 'exp': 17 | return torch.exp(x) 18 | elif name == 'cov': 19 | return 1 / 2 * (torch.tanh(x) + 1) 20 | elif name == 'identity': 21 | return x 22 | elif name == 'tanh': 23 | return 0.01 * torch.tanh(x) 24 | 25 | 26 | def f_inv(x, name=name): 27 | if not isinstance(x, torch.Tensor): 28 | x = torch.tensor(x).float() 29 | if name == 'softplus': 30 | x = (0.0001) * (x == 0).float() + x * (x != 0).float() 31 | return torch.log(torch.exp(x) - 1) 32 | elif name == 'sigmoid': 33 | x = x * (x < 1).float() * (x > 0).float() + 0.999 * (x==1).float() + 0.001 * (x==0).float() 34 | return torch.log(x / (1-x)) 35 | elif name == 'exp': 36 | x = 0.001 * (x == 0).float() + x * (x != 0).float() 37 | return torch.log(x) 38 | elif name == 'cov': 39 | return torch.atanh(2 * 0.99 * (x == 1).float() + 0.001 * (x == 0).float() + x * (0 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 39 | layers = [] 40 | for i in range(int(nb_layers)): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | class WideResNet(nn.Module): 47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, mean=None, std=None, spatial_expansion=False): 48 | super(WideResNet, self).__init__() 49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 50 | assert((depth - 4) % 6 == 0) 51 | n = (depth - 4) / 6 52 | 53 | self.mean = mean.view(1, -1, 1, 1) 54 | self.std = std.view(1, -1, 1, 1) 55 | 
self.spatial_expansion = spatial_expansion 56 | 57 | block = BasicBlock 58 | # 1st conv before any network block 59 | self.conv1 = Conv2d_mask(3, nChannels[0], kernel_size=3, stride=1, 60 | padding=1, bias=False) 61 | # 1st block 62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 63 | # 2nd block 64 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 65 | # 3rd block 66 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 67 | # global average pooling and classifier 68 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.fc = Linear_mask(nChannels[3], num_classes) 71 | self.nChannels = nChannels[3] 72 | 73 | for m in self.modules(): 74 | if isinstance(m, Conv2d_mask): 75 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | elif isinstance(m, Linear_mask): 80 | m.bias.data.zero_() 81 | def forward(self, x, pop=False, inter=False): 82 | 83 | # Feature visualization 84 | if not inter: 85 | out = (x - self.mean) / self.std 86 | out = self.conv1(out) 87 | out = self.block1(out) 88 | out = self.block2(out) 89 | out = self.block3(out) 90 | out = self.relu(self.bn1(out)) 91 | 92 | if pop: 93 | return out 94 | else: 95 | out = x 96 | 97 | 98 | 99 | if self.spatial_expansion: 100 | out = F.avg_pool2d(out, 16) 101 | else: 102 | out = F.avg_pool2d(out, 8) 103 | out = out.view(-1, self.nChannels) 104 | return self.fc(out) 105 | 106 | def wide_resnet(depth=28, widen_factor=10, dataset='cifar10', mean=None, std=None): 107 | if dataset == 'cifar10' or dataset == 'svhn': 108 | num_classes = 10 109 | spatial_expansion = False 110 | elif dataset == 'cifar100': 111 | num_classes = 100 112 | spatial_expansion = False 113 | elif dataset == 'tiny': 114 | num_classes = 200 115 | spatial_expansion = True 116 | else: 117 | raise NotImplementedError 118 | return WideResNet(depth=depth, num_classes=num_classes, widen_factor=widen_factor, dropRate=0.3, 119 | mean=mean, std=std, spatial_expansion=spatial_expansion) -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from tqdm import tqdm 9 | from utils.network_utils import get_network 10 | from utils.data_utils import get_dataloader 11 | from utils.common_utils import PresetLRScheduler 12 | 13 | # fetch args 14 | parser = argparse.ArgumentParser() 15 | 16 | # model parameter 17 | parser.add_argument('--dataset', default='cifar10', type=str) 18 | parser.add_argument('--network', default='vgg', type=str) 19 | parser.add_argument('--depth', default=16, type=int) 20 | parser.add_argument('--epoch', default=200, type=int) 21 | parser.add_argument('--device', default='cuda:0', type=str) 22 | 23 | # learning parameter 24 | parser.add_argument('--learning_rate', default=0.1, type=float) 25 | parser.add_argument('--weight_decay', default=0.0002, type=float) 26 | parser.add_argument('--batch_size', default=128, type=int) 27 | 28 | args = parser.parse_args() 29 | 30 | # init model 31 | net = get_network(network=args.network, 32 | depth=args.depth, 33 | dataset=args.dataset, 34 | device=args.device) 35 | net = net.to(args.device) 36 | 37 | # init dataloader 
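# get_dataloader (utils/data_utils.py) applies the dataset-specific augmentation:
# random crop + horizontal flip + ToTensor for training, plain ToTensor for testing.
# Input normalization is NOT done here -- the models divide by the stored mean/std
# inside forward() -- so the loaders emit raw [0, 1] tensors, which keeps the attack
# epsilon defined directly in pixel space.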
38 | trainloader, testloader = get_dataloader(dataset=args.dataset, 39 | train_batch_size=args.batch_size, 40 | test_batch_size=256) 41 | 42 | # init optimizer and lr scheduler 43 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) 44 | lr_schedule = {0: args.learning_rate, 45 | int(args.epoch*0.5): args.learning_rate*0.1, 46 | int(args.epoch*0.75): args.learning_rate*0.01} 47 | lr_scheduler = PresetLRScheduler(lr_schedule) 48 | 49 | # init criterion 50 | criterion = nn.CrossEntropyLoss() 51 | 52 | start_epoch = 0 53 | best_acc = 0 54 | 55 | 56 | def train(epoch): 57 | print('\nEpoch: %d' % epoch) 58 | net.train() 59 | train_loss = 0 60 | correct = 0 61 | total = 0 62 | 63 | lr_scheduler(optimizer, epoch) 64 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 65 | (lr_scheduler.get_lr(optimizer), 0, 0, correct, total)) 66 | 67 | 68 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 69 | for batch_idx, (inputs, targets) in prog_bar: 70 | inputs, targets = inputs.to(args.device), targets.to(args.device) 71 | optimizer.zero_grad() 72 | outputs = net(inputs) 73 | loss = criterion(outputs, targets) 74 | loss.backward() 75 | optimizer.step() 76 | 77 | train_loss += loss.item() 78 | _, predicted = outputs.max(1) 79 | total += targets.size(0) 80 | correct += predicted.eq(targets).sum().item() 81 | 82 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 83 | (lr_scheduler.get_lr(optimizer), train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 84 | prog_bar.set_description(desc, refresh=True) 85 | 86 | 87 | 88 | def test(epoch): 89 | global best_acc 90 | net.eval() 91 | test_loss = 0 92 | correct = 0 93 | total = 0 94 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 95 | % (lr_scheduler.get_lr(optimizer), test_loss/(0+1), 0, correct, total)) 96 | 97 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 98 | with torch.no_grad(): 99 | for batch_idx, (inputs, targets) in prog_bar: 100 | inputs, targets = inputs.to(args.device), targets.to(args.device) 101 | outputs = net(inputs) 102 | loss = criterion(outputs, targets) 103 | 104 | test_loss += loss.item() 105 | _, predicted = outputs.max(1) 106 | total += targets.size(0) 107 | correct += predicted.eq(targets).sum().item() 108 | 109 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 110 | % (lr_scheduler.get_lr(optimizer), test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 111 | prog_bar.set_description(desc, refresh=True) 112 | 113 | # Save checkpoint. 
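# Checkpoint only when clean test accuracy improves; the saved dict bundles the
# weights ('net') with the accuracy, epoch, last-batch loss, and the argparse
# namespace for later reference.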
114 | acc = 100.*correct/total 115 | 116 | 117 | if acc > best_acc: 118 | print('Saving..') 119 | state = { 120 | 'net': net.state_dict(), 121 | 'acc': acc, 122 | 'epoch': epoch, 123 | 'loss': loss, 124 | 'args': args 125 | } 126 | if not os.path.isdir('checkpoint'): 127 | os.mkdir('checkpoint') 128 | if not os.path.isdir('checkpoint/pretrain'): 129 | os.mkdir('checkpoint/pretrain') 130 | if not os.path.isdir('checkpoint/pretrain/%s' % args.dataset): 131 | os.mkdir('checkpoint/pretrain/%s' % args.dataset) 132 | torch.save(state, './checkpoint/pretrain/%s/%s_%s%s_best.t7' % (args.dataset, args.dataset, 133 | args.network, 134 | args.depth)) 135 | print('./checkpoint/pretrain/%s/%s_%s%s_best.t7' % (args.dataset, args.dataset, 136 | args.network, 137 | args.depth)) 138 | best_acc = acc 139 | 140 | 141 | for epoch in range(start_epoch, args.epoch): 142 | train(epoch) 143 | test(epoch) 144 | 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | '''Evaluate CIFAR10/CIFAR100 models under adversarial attack with PyTorch.''' 2 | from __future__ import print_function 3 | import os 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | 8 | from tqdm import tqdm 9 | from utils.network_utils import get_network 10 | from utils.data_utils import get_dataloader 11 | 12 | # attack loader 13 | from attack.attack import attack_loader 14 | 15 | # warning ignore 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | from utils.utils import str2bool 19 | 20 | 21 | # fetch args 22 | parser = argparse.ArgumentParser() 23 | 24 | # model parameter 25 | parser.add_argument('--dataset', default='cifar10', type=str) 26 | parser.add_argument('--network', default='vgg', type=str) 27 | parser.add_argument('--depth', default=16, type=int) 28 | parser.add_argument('--baseline', default='adv', type=str) 29 | parser.add_argument('--device', default='cuda:0', type=str) 30 | 31 | # mad parameter 32 | parser.add_argument('--percnt', default=0.9, type=float) 33 | parser.add_argument('--pruning_mode', default='el', type=str) 34 | parser.add_argument('--largest', default='false', type=str2bool) 35 | 36 | # attack parameter 37 | parser.add_argument('--attack', default='pgd', type=str) 38 | parser.add_argument('--eps', default=0.03, type=float) 39 | parser.add_argument('--steps', default=30, type=int) 40 | args = parser.parse_args() 41 | 42 | 43 | ''' ------------------------------------------------------------------------------------------------------------- ''' 44 | if args.baseline == 'mad': 45 | if args.pruning_mode == 'rd': 46 | checkpoint_name = 'checkpoint/pretrain/%s/%s_%s_mad%s_%s%s_best.t7' % (args.dataset, args.pruning_mode, args.dataset, str(int(100*args.percnt)), args.network, args.depth) 47 | print("This test : {}".format(checkpoint_name)) 48 | else: 49 | checkpoint_name = 'checkpoint/pretrain/%s/%s_%s_%s_mad%s_%s%s_best.t7' % (args.dataset, args.largest, args.pruning_mode, args.dataset, str(int(100*args.percnt)), args.network, args.depth) 50 | print("This test : {}".format(checkpoint_name)) 51 | else: 52 | checkpoint_name = 'checkpoint/pretrain/%s/%s_%s_%s%s_best.t7' % (args.dataset, args.dataset, args.baseline, args.network, args.depth) 53 | print("This test : {}".format(checkpoint_name)) 54 | ''' ------------------------------------------------------------------------------------------------------------- ''' 55 | 56 | # init dataloader 57 | _, testloader = get_dataloader(dataset=args.dataset, 58 | train_batch_size=1, 59 | test_batch_size=128) 60 | 61 | # init model 62 | net = 
get_network(network=args.network, 63 | depth=args.depth, 64 | dataset=args.dataset, 65 | device=args.device) 66 | net = net.to(args.device) 67 | 68 | 69 | # Load Plain Network 70 | print('==> Loading Plain checkpoint..') 71 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 72 | 73 | 74 | checkpoint = torch.load(checkpoint_name, map_location=args.device) 75 | net.load_state_dict(checkpoint['net'], strict=False) 76 | 77 | 78 | # init criterion 79 | criterion = nn.CrossEntropyLoss() 80 | 81 | # compute prune ratio 82 | from models.operator.mask import compute_prune_ratio 83 | p_ratio, param_size = compute_prune_ratio(net, is_param=True) 84 | print("Prune Ratio : {:.2f}, Param size : {:.2f}".format( p_ratio, param_size)) 85 | 86 | def test(): 87 | net.eval() 88 | 89 | 90 | attack_score = [] 91 | attack_module = {} 92 | for attack_name in ['Plain', 'fgsm', 'pgd', 'cw_linf', 'apgd', 'auto']: 93 | args.attack = attack_name 94 | attack_module[attack_name] = attack_loader(net=net, attack=attack_name, 95 | eps=args.eps, steps=args.steps, 96 | dataset=args.dataset, device=args.device) \ 97 | if attack_name != 'Plain' else None 98 | 99 | for key in attack_module: 100 | test_loss = 0 101 | total = 0 102 | correct = 0 103 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), leave=True) 104 | for batch_idx, (inputs, targets) in prog_bar: 105 | inputs, targets = inputs.to(args.device), targets.to(args.device) 106 | if key != 'Plain': 107 | inputs = attack_module[key](inputs, targets) 108 | outputs = net(inputs) 109 | loss = criterion(outputs, targets) 110 | 111 | test_loss += loss.item() 112 | _, predicted = outputs.max(1) 113 | total += targets.size(0) 114 | correct += predicted.eq(targets).sum().item() 115 | 116 | desc = ('[Test/%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 117 | % (key, test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 118 | prog_bar.set_description(desc, refresh=True) 119 | 120 | attack_score.append(100. * correct / total) 121 | 122 | print('\n----------------Summary----------------') 123 | print(args.steps, ' steps attack') 124 | for key, score in zip(attack_module, attack_score): 125 | print(str(key), ' : ', str(score) + '(%)') 126 | print('---------------------------------------\n') 127 | 128 | if __name__ == '__main__': 129 | test() 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimicking xlua.progress. 
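- str2bool: parse yes/no-style command line flags into booleans.
- format_time: compact human-readable durations (e.g. 3720 seconds -> '1h2m').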
5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | 18 | def str2bool(v): 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | else: 24 | assert False 25 | 26 | def list_to_concat(adv_delta_L_avg, device): 27 | size = [x.shape for x in adv_delta_L_avg] 28 | concat_conv = torch.cat([x.flatten() for x in adv_delta_L_avg[:-1]]).to(device) 29 | linear = adv_delta_L_avg[-1] # feature * n_classes 30 | return concat_conv, linear, size 31 | 32 | def custom_imshow(img): 33 | img = img[0].cpu().numpy() 34 | plt.imshow(np.transpose(img, (1, 2, 0))) 35 | plt.show() 36 | 37 | def index2onehot(index, size): 38 | 39 | if index.dim() == 1: 40 | try: 41 | onehot = torch.zeros(size.numel()) 42 | except: 43 | onehot = torch.zeros(size) 44 | onehot[index] = 1 45 | elif index.dim() == 2: 46 | onehot = torch.zeros(size) 47 | for i in range(onehot.shape[0]): 48 | onehot[i, index[i]] = 1 49 | return onehot.view(size) 50 | 51 | def onehot2index(onehot): 52 | index = [] 53 | for i in range(onehot.shape[0]): 54 | if onehot[i] != 0: 55 | index.append(i) 56 | return torch.tensor(index) 57 | 58 | def pl(a): 59 | plt.plot(a.cpu()) 60 | plt.show() 61 | 62 | def sc(a): 63 | plt.scatter(range(len(a.cpu())), a.cpu(), s=2, color='darkred', alpha=0.5) 64 | plt.show() 65 | 66 | 67 | def get_mean_and_std(dataset): 68 | '''Compute the mean and std value of dataset.''' 69 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) 70 | mean = torch.zeros(3) 71 | std = torch.zeros(3) 72 | print('==> Computing mean and std..') 73 | for inputs, targets in dataloader: 74 | for i in range(3): 75 | mean[i] += inputs[:,i,:,:].mean() 76 | std[i] += inputs[:,i,:,:].std() 77 | mean.div_(len(dataset)) 78 | std.div_(len(dataset)) 79 | return mean, std 80 | 81 | 82 | def init_params(net): 83 | '''Init layer parameters.''' 84 | for m in net.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | init.kaiming_normal(m.weight, mode='fan_out') 87 | if m.bias: 88 | init.constant(m.bias, 0) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | init.constant(m.weight, 1) 91 | init.constant(m.bias, 0) 92 | elif isinstance(m, nn.Linear): 93 | init.normal(m.weight, std=1e-3) 94 | if m.bias: 95 | init.constant(m.bias, 0) 96 | 97 | 98 | _, term_width = os.popen('stty size', 'r').read().split() 99 | term_width = int(term_width) 100 | 101 | TOTAL_BAR_LENGTH = 65. 102 | last_time = time.time() 103 | begin_time = last_time 104 | 105 | 106 | def progress_bar(current, total, msg=None): 107 | global last_time, begin_time 108 | if current == 0: 109 | begin_time = time.time() # Reset for new bar. 
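# The bar spans TOTAL_BAR_LENGTH character cells: cur_len cells are filled with
# '=' in proportion to current/total, one cell holds the '>' head, and the
# remaining rest_len cells are padded with '.'.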
110 | 111 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 112 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 113 | 114 | sys.stdout.write(' [') 115 | for i in range(cur_len): 116 | sys.stdout.write('=') 117 | sys.stdout.write('>') 118 | for i in range(rest_len): 119 | sys.stdout.write('.') 120 | sys.stdout.write(']') 121 | 122 | cur_time = time.time() 123 | step_time = cur_time - last_time 124 | last_time = cur_time 125 | tot_time = cur_time - begin_time 126 | 127 | L = [] 128 | L.append(' Step: %s' % format_time(step_time)) 129 | L.append(' | Tot: %s' % format_time(tot_time)) 130 | if msg: 131 | L.append(' | ' + msg) 132 | 133 | msg = ''.join(L) 134 | sys.stdout.write(msg) 135 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 136 | sys.stdout.write(' ') 137 | 138 | # Go back to the center of the bar. 139 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 140 | sys.stdout.write('\b') 141 | sys.stdout.write(' %d/%d ' % (current+1, total)) 142 | 143 | if current < total-1: 144 | sys.stdout.write('\r') 145 | else: 146 | sys.stdout.write('\n') 147 | sys.stdout.flush() 148 | 149 | 150 | def format_time(seconds): 151 | days = int(seconds / 3600/24) 152 | seconds = seconds - days*3600*24 153 | hours = int(seconds / 3600) 154 | seconds = seconds - hours*3600 155 | minutes = int(seconds / 60) 156 | seconds = seconds - minutes*60 157 | secondsf = int(seconds) 158 | seconds = seconds - secondsf 159 | millis = int(seconds*1000) 160 | 161 | f = '' 162 | i = 1 163 | if days > 0: 164 | f += str(days) + 'D' 165 | i += 1 166 | if hours > 0 and i <= 2: 167 | f += str(hours) + 'h' 168 | i += 1 169 | if minutes > 0 and i <= 2: 170 | f += str(minutes) + 'm' 171 | i += 1 172 | if secondsf > 0 and i <= 2: 173 | f += str(secondsf) + 's' 174 | i += 1 175 | if millis > 0 and i <= 2: 176 | f += str(millis) + 'ms' 177 | i += 1 178 | if f == '': 179 | f = '0ms' 180 | return f 181 | -------------------------------------------------------------------------------- /models/resnet_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.operator.mask import Conv2d_mask, Linear_mask 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = Conv2d_mask( 13 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = Conv2d_mask(planes, planes, kernel_size=3, 16 | stride=1, padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | 19 | self.shortcut = nn.Sequential() 20 | if stride != 1 or in_planes != self.expansion*planes: 21 | self.shortcut = nn.Sequential( 22 | Conv2d_mask(in_planes, self.expansion*planes, 23 | kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(self.expansion*planes) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = self.bn2(self.conv2(out)) 30 | out += self.shortcut(x) 31 | out = F.relu(out) 32 | return out 33 | 34 | 35 | class Bottleneck(nn.Module): 36 | expansion = 4 37 | 38 | def __init__(self, in_planes, planes, stride=1): 39 | super(Bottleneck, self).__init__() 40 | self.conv1 = Conv2d_mask(in_planes, planes, kernel_size=1, bias=False) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.conv2 = Conv2d_mask(planes, planes, kernel_size=3, 43 | stride=stride, padding=1, bias=False) 44 | self.bn2 = 
nn.BatchNorm2d(planes) 45 | self.conv3 = Conv2d_mask(planes, self.expansion * 46 | planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | Conv2d_mask(in_planes, self.expansion*planes, 53 | kernel_size=1, stride=stride, bias=False), 54 | nn.BatchNorm2d(self.expansion*planes) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | out = self.bn3(self.conv3(out)) 61 | out += self.shortcut(x) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class ResNet(nn.Module): 67 | def __init__(self, block, num_blocks, num_classes=10, mean=None, std=None, spatial_expansion=False): 68 | super(ResNet, self).__init__() 69 | self.in_planes = 64 70 | 71 | self.mean = mean.view(1, -1, 1, 1) 72 | self.std = std.view(1, -1, 1, 1) 73 | self.spatial_expansion = spatial_expansion 74 | 75 | self.conv1 = Conv2d_mask(3, 64, kernel_size=3, 76 | stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(64) 78 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 79 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 80 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 81 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 82 | self.linear = Linear_mask(512*block.expansion, num_classes) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1]*(num_blocks-1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x, pop=False, inter=False): 93 | # Feature visualization 94 | if not inter: 95 | out = (x - self.mean) / self.std 96 | out = F.relu(self.bn1(self.conv1(out))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | if pop: 102 | return out 103 | else: 104 | out = x 105 | 106 | if self.spatial_expansion: 107 | out = F.avg_pool2d(out, 8) 108 | else: 109 | out = F.avg_pool2d(out, 4) 110 | out = out.view(out.size(0), -1) 111 | out = self.linear(out) 112 | return out 113 | 114 | 115 | def resnet(depth=18, dataset='cifar10', mean=None, std=None): 116 | if dataset == 'cifar10' or dataset == 'svhn': 117 | num_classes = 10 118 | spatial_expansion = False 119 | elif dataset == 'cifar100': 120 | num_classes = 100 121 | spatial_expansion = False 122 | elif dataset == 'tiny': 123 | num_classes = 200 124 | spatial_expansion = True 125 | else: 126 | raise NotImplementedError 127 | 128 | 129 | if depth == 18: 130 | block = BasicBlock 131 | num_blocks = [2, 2, 2, 2] 132 | elif depth == 34: 133 | block = BasicBlock 134 | num_blocks = [3, 4, 6, 3] 135 | elif depth == 50: 136 | block = Bottleneck 137 | num_blocks = [3, 4, 6, 3] 138 | else: 139 | raise NotImplementedError 140 | 141 | return ResNet(block=block, num_blocks=num_blocks, num_classes=num_classes, 142 | mean=mean, std=std, spatial_expansion=spatial_expansion) 143 | 144 | -------------------------------------------------------------------------------- /main_adv_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim 
as optim 7 | 8 | from tqdm import tqdm 9 | from utils.network_utils import get_network 10 | from utils.data_utils import get_dataloader 11 | from utils.common_utils import PresetLRScheduler 12 | 13 | # attack loader 14 | from attack.attack import attack_loader 15 | 16 | # fetch args 17 | parser = argparse.ArgumentParser() 18 | 19 | 20 | # model parameter 21 | parser.add_argument('--dataset', default='cifar10', type=str) 22 | parser.add_argument('--network', default='vgg', type=str) 23 | parser.add_argument('--depth', default=16, type=int) 24 | parser.add_argument('--device', default='cuda:0', type=str) 25 | 26 | # learning parameter 27 | parser.add_argument('--learning_rate', default=0.1, type=float) 28 | parser.add_argument('--weight_decay', default=0.0002, type=float) 29 | parser.add_argument('--batch_size', default=128, type=float) 30 | parser.add_argument('--epoch', default=60, type=int) 31 | 32 | # attack parameter 33 | parser.add_argument('--attack', default='pgd', type=str) 34 | parser.add_argument('--eps', default=0.03, type=float) 35 | parser.add_argument('--steps', default=10, type=int) 36 | args = parser.parse_args() 37 | 38 | 39 | # init dataloader 40 | trainloader, testloader = get_dataloader(dataset=args.dataset, 41 | train_batch_size=args.batch_size, 42 | test_batch_size=256) 43 | 44 | # init model 45 | net = get_network(network=args.network, 46 | depth=args.depth, 47 | dataset=args.dataset, 48 | device=args.device) 49 | net = net.to(args.device) 50 | 51 | # Load Plain Network 52 | print('==> Loading Plain checkpoint..') 53 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 54 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 55 | print('Loaded checkpoint : checkpoint/pretrain/%s/%s_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 56 | net.load_state_dict(checkpoint['net']) 57 | 58 | # Attack loader 59 | if (args.network=='wide') and (args.dataset=='tiny'): 60 | print('Fast FGSM training') 61 | attack = attack_loader(net=net, attack='fgsm_train', eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 62 | else: 63 | print('Low PGD training') 64 | attack = attack_loader(net=net, attack=args.attack, eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 65 | 66 | 67 | # init optimizer and lr scheduler 68 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) 69 | lr_schedule = {0: args.learning_rate, 70 | int(args.epoch*0.5): args.learning_rate*0.1, 71 | int(args.epoch*0.75): args.learning_rate*0.01} 72 | lr_scheduler = PresetLRScheduler(lr_schedule) 73 | 74 | # init criterion 75 | criterion = nn.CrossEntropyLoss() 76 | 77 | start_epoch = 0 78 | best_acc = 0 79 | 80 | def train(epoch): 81 | print('\nEpoch: %d' % epoch) 82 | net.train() 83 | train_loss = 0 84 | correct = 0 85 | total = 0 86 | 87 | lr_scheduler(optimizer, epoch) 88 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 89 | (lr_scheduler.get_lr(optimizer), 0, 0, correct, total)) 90 | 91 | 92 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 93 | for batch_idx, (inputs, targets) in prog_bar: 94 | inputs = attack(inputs, targets) 95 | inputs, targets = inputs.to(args.device), targets.to(args.device) 96 | optimizer.zero_grad() 97 | outputs = net(inputs) 98 | loss = criterion(outputs, targets) 99 | loss.backward() 100 | optimizer.step() 101 | 102 | 
train_loss += loss.item() 103 | _, predicted = outputs.max(1) 104 | total += targets.size(0) 105 | correct += predicted.eq(targets).sum().item() 106 | 107 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 108 | (lr_scheduler.get_lr(optimizer), train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 109 | prog_bar.set_description(desc, refresh=True) 110 | 111 | 112 | 113 | def test(epoch): 114 | global best_acc 115 | net.eval() 116 | test_loss = 0 117 | correct = 0 118 | total = 0 119 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 120 | % (lr_scheduler.get_lr(optimizer), test_loss/(0+1), 0, correct, total)) 121 | 122 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 123 | for batch_idx, (inputs, targets) in prog_bar: 124 | inputs = attack(inputs, targets) 125 | inputs, targets = inputs.to(args.device), targets.to(args.device) 126 | outputs = net(inputs) 127 | loss = criterion(outputs, targets) 128 | 129 | test_loss += loss.item() 130 | _, predicted = outputs.max(1) 131 | total += targets.size(0) 132 | correct += predicted.eq(targets).sum().item() 133 | 134 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 135 | % (lr_scheduler.get_lr(optimizer), test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 136 | prog_bar.set_description(desc, refresh=True) 137 | 138 | # Save checkpoint. 139 | acc = 100.*correct/total 140 | 141 | 142 | if acc > best_acc: 143 | print('Saving..') 144 | state = { 145 | 'net': net.state_dict(), 146 | 'acc': acc, 147 | 'epoch': epoch, 148 | 'loss': loss, 149 | 'args': args 150 | } 151 | if not os.path.isdir('checkpoint'): 152 | os.mkdir('checkpoint') 153 | if not os.path.isdir('checkpoint/pretrain'): 154 | os.mkdir('checkpoint/pretrain') 155 | torch.save(state, './checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, 156 | args.network, 157 | args.depth)) 158 | print('./checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, 159 | args.network, 160 | args.depth)) 161 | best_acc = acc 162 | 163 | 164 | for epoch in range(start_epoch, args.epoch): 165 | train(epoch) 166 | test(epoch) 167 | 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /main_mart_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | 9 | from tqdm import tqdm 10 | from utils.network_utils import get_network 11 | from utils.data_utils import get_dataloader 12 | from utils.common_utils import PresetLRScheduler 13 | 14 | # attack loader 15 | from attack.attack import attack_loader 16 | 17 | # fetch args 18 | parser = argparse.ArgumentParser() 19 | 20 | # model parameter 21 | parser.add_argument('--dataset', default='cifar10', type=str) 22 | parser.add_argument('--network', default='vgg', type=str) 23 | parser.add_argument('--depth', default=16, type=int) 24 | parser.add_argument('--device', default='cuda:0', type=str) 25 | 26 | # learning parameter 27 | parser.add_argument('--learning_rate', default=1e-3, type=float) 28 | parser.add_argument('--weight_decay', default=0.0002, type=float) 29 | parser.add_argument('--batch_size', default=128, type=float) 30 | parser.add_argument('--epoch', default=60, type=int) 31 | 32 | # attack parameter 33 | parser.add_argument('--attack', 
default='pgd', type=str) 34 | parser.add_argument('--eps', default=0.03, type=float) 35 | parser.add_argument('--steps', default=10, type=int) parser.add_argument('--resume', action='store_true')  # added: `args.resume` is read below but was never defined in the original argparse setup 36 | args = parser.parse_args() 37 | 38 | 39 | # init dataloader 40 | trainloader, testloader = get_dataloader(dataset=args.dataset, 41 | train_batch_size=args.batch_size, 42 | test_batch_size=256) 43 | 44 | # init model 45 | net = get_network(network=args.network, 46 | depth=args.depth, 47 | dataset=args.dataset, 48 | device=args.device) 49 | net = net.to(args.device) 50 | 51 | 52 | # Load Adv Network 53 | print('==> Loading Adv checkpoint..') 54 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 55 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth), map_location=args.device) 56 | print('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 57 | net.load_state_dict(checkpoint['net']) 58 | 59 | # Attack loader 60 | if args.dataset=='tiny': 61 | print('Fast FGSM training') 62 | attack = attack_loader(net=net, attack='fgsm_train', eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 63 | else: 64 | print('Low PGD training') 65 | attack = attack_loader(net=net, attack=args.attack, eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 66 | 67 | 68 | # init optimizer and lr scheduler 69 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) 70 | lr_schedule = {0: args.learning_rate, 71 | int(args.epoch*0.5): args.learning_rate*0.1, 72 | int(args.epoch*0.75): args.learning_rate*0.01} 73 | lr_scheduler = PresetLRScheduler(lr_schedule) 74 | # lr_scheduler = #StairCaseLRScheduler(0, args.decay_every, args.decay_ratio) 75 | 76 | # init criterion 77 | criterion = nn.CrossEntropyLoss() 78 | 79 | start_epoch = 0 80 | best_acc = 0 81 | if args.resume: 82 | print('==> Resuming from checkpoint..') 83 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 84 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_adv_%s%s_bn_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 85 | net.load_state_dict(checkpoint['net']) 86 | best_acc = checkpoint['acc'] 87 | start_epoch = checkpoint['epoch'] 88 | print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' % (start_epoch, best_acc)) 89 | 90 | 91 | def train(epoch): 92 | print('\nEpoch: %d' % epoch) 93 | net.train() 94 | train_loss = 0 95 | correct = 0 96 | total = 0 97 | 98 | lr_scheduler(optimizer, epoch) 99 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 100 | (lr_scheduler.get_lr(optimizer), 0, 0, correct, total)) 101 | 102 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 103 | for batch_idx, (inputs, targets) in prog_bar: 104 | inputs, targets = inputs.to(args.device), targets.to(args.device) 105 | optimizer.zero_grad() 106 | 107 | # MART Loss 108 | loss, logits = mart_loss_orig(net, inputs, targets, optimizer, device=args.device) 109 | 110 | loss.backward() 111 | optimizer.step() 112 | 113 | 114 | train_loss += loss.item() 115 | _, predicted = logits.max(1) 116 | total += targets.size(0) 117 | correct += predicted.eq(targets).sum().item() 118 | 119 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 120 | (lr_scheduler.get_lr(optimizer), train_loss / (batch_idx + 1), 100.
* correct / total, correct, total)) 121 | prog_bar.set_description(desc, refresh=True) 122 | 123 | 124 | 125 | def test(epoch): 126 | global best_acc 127 | net.eval() 128 | test_loss = 0 129 | correct = 0 130 | total = 0 131 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 132 | % (lr_scheduler.get_lr(optimizer), test_loss/(0+1), 0, correct, total)) 133 | 134 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 135 | for batch_idx, (inputs, targets) in prog_bar: 136 | inputs = attack(inputs, targets) 137 | inputs, targets = inputs.to(args.device), targets.to(args.device) 138 | outputs = net(inputs) 139 | loss = criterion(outputs, targets) 140 | 141 | test_loss += loss.item() 142 | _, predicted = outputs.max(1) 143 | total += targets.size(0) 144 | correct += predicted.eq(targets).sum().item() 145 | 146 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 147 | % (lr_scheduler.get_lr(optimizer), test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 148 | prog_bar.set_description(desc, refresh=True) 149 | 150 | # Save checkpoint. 151 | acc = 100.*correct/total 152 | 153 | if acc > best_acc: 154 | print('Saving..') 155 | state = { 156 | 'net': net.state_dict(), 157 | 'acc': acc, 158 | 'epoch': epoch, 159 | 'loss': loss, 160 | 'args': args 161 | } 162 | if not os.path.isdir('checkpoint'): 163 | os.mkdir('checkpoint') 164 | if not os.path.isdir('checkpoint/pretrain'): 165 | os.mkdir('checkpoint/pretrain') 166 | torch.save(state, './checkpoint/pretrain/%s/%s_mart_%s%s_best.t7' % (args.dataset, args.dataset, 167 | args.network, 168 | args.depth)) 169 | print('./checkpoint/pretrain/%s/%s_mart_%s%s_best.t7' % (args.dataset, args.dataset, 170 | args.network, 171 | args.depth)) 172 | best_acc = acc 173 | 174 | 175 | 176 | 177 | def mart_loss_orig(model, 178 | x_natural, 179 | y, 180 | optim, 181 | device, 182 | step_size=0.03/10*2.3 if args.dataset != 'tiny' else 0.03 * 1.25, 183 | epsilon=0.03, 184 | perturb_steps=10 if args.dataset != 'tiny' else 1, 185 | beta=5, 186 | distance='l_inf'): 187 | kl = torch.nn.KLDivLoss(reduction='none') 188 | model.eval() 189 | batch_size = len(x_natural) 190 | if args.dataset != 'tiny': 191 | # generate adversarial example 192 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach() 193 | else: 194 | # generate adversarial example 195 | x_adv = x_natural.detach() + torch.empty_like(x_natural).uniform_(-epsilon, epsilon).to(device).detach() 196 | if distance == 'l_inf': 197 | for _ in range(perturb_steps): 198 | x_adv.requires_grad_() 199 | with torch.enable_grad(): 200 | loss_ce = F.cross_entropy(model(x_adv), y) 201 | grad = torch.autograd.grad(loss_ce, [x_adv])[0] 202 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 203 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 204 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 205 | else: 206 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 207 | model.train() 208 | 209 | x_adv = torch.autograd.Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 210 | # zero gradient 211 | optim.zero_grad() 212 | 213 | logits = model(x_natural) 214 | 215 | logits_adv = model(x_adv) 216 | 217 | adv_probs = F.softmax(logits_adv, dim=1) 218 | 219 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:] 220 | 221 | new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1]) 222 | 223 | loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y) 224 | 225 | nat_probs = 
F.softmax(logits, dim=1) 226 | 227 | true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze() 228 | 229 | loss_robust = (1.0 / batch_size) * torch.sum( 230 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 231 | loss = loss_adv + float(beta) * loss_robust 232 | 233 | return loss, logits_adv 234 | 235 | 236 | 237 | for epoch in range(start_epoch, args.epoch): 238 | train(epoch) 239 | test(epoch) 240 | 241 | 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /main_mad_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from tqdm import tqdm 9 | from utils.mask_network_utils import get_mask_network 10 | from utils.data_utils import get_dataloader 11 | from utils.common_utils import PresetLRScheduler 12 | 13 | # attack loader 14 | from attack.attack import attack_loader 15 | from pruner.kfac_MAD_pruner import KFACMADPruner 16 | 17 | from utils.mask_parameter_generator_utils import MaskParameterGenerator 18 | from utils.utils import * 19 | from models.operator.mask import * 20 | from utils.utils import str2bool 21 | 22 | # fetch args 23 | parser = argparse.ArgumentParser() 24 | 25 | # model parameter 26 | parser.add_argument('--dataset', default='cifar10', type=str) 27 | parser.add_argument('--network', default='vgg', type=str) 28 | parser.add_argument('--depth', default=16, type=int) 29 | parser.add_argument('--device', default='cuda:0', type=str) 30 | 31 | # mad parameter 32 | parser.add_argument('--percnt', default=0.9, type=float) 33 | parser.add_argument('--pruning_mode', default='element', type=str) 34 | parser.add_argument('--largest', default='false', type=str2bool) 35 | 36 | # learning parameter 37 | parser.add_argument('--learning_rate', default=0.1, type=float) 38 | parser.add_argument('--weight_decay', default=0.0002, type=float) 39 | parser.add_argument('--batch_size', default=128, type=float) 40 | parser.add_argument('--epoch', default=60, type=int) 41 | 42 | # attack parameter 43 | parser.add_argument('--attack', default='pgd', type=str) 44 | parser.add_argument('--eps', default=0.03, type=float) 45 | parser.add_argument('--steps', default=10, type=int) 46 | args = parser.parse_args() 47 | 48 | 49 | # init dataloader 50 | trainloader, testloader = get_dataloader(dataset=args.dataset, 51 | train_batch_size=args.batch_size, 52 | test_batch_size=256) 53 | 54 | # init model 55 | net = get_mask_network(network=args.network, 56 | depth=args.depth, 57 | dataset=args.dataset, 58 | device=args.device) 59 | net = net.to(args.device) 60 | 61 | 62 | # Load Plain Network 63 | print('==> Loading Adv checkpoint..') 64 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 
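# NOTE (added comment, not in the original source): the block below expects two artifacts
# described in the README: the adversarially trained checkpoint saved by main_adv_pretrain.py,
# and the adversarial-saliency pickle produced by compute_saliency.py (Step 1). strict=False is
# used because the mask network carries extra mask parameters that the AT checkpoint lacks.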
65 | 66 | ''' ------------------------------------------------------------------------------------------------------------- ''' 67 | if not os.path.isdir('pickle'): 68 | os.mkdir('pickle') 69 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth), map_location=args.device) 70 | pickle_path = './pickle/%s_adv_%s%s_saliency.pickle' % (args.dataset, args.network, args.depth) 71 | print('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 72 | ''' ------------------------------------------------------------------------------------------------------------- ''' 73 | net.load_state_dict(checkpoint['net'], strict=False) 74 | 75 | # Attack loader 76 | if args.dataset=='tiny': 77 | print('Fast FGSM training') 78 | attack = attack_loader(net=net, attack='fgsm_train', eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 79 | else: 80 | print('Low PGD training') 81 | attack = attack_loader(net=net, attack=args.attack, eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 82 | 83 | # init criterion 84 | criterion = nn.CrossEntropyLoss() 85 | 86 | start_epoch = 0 87 | best_acc = 0 88 | 89 | pruner = KFACMADPruner(net, attack, args.device, dataset=args.dataset) 90 | mask_model = MaskParameterGenerator(net) 91 | 92 | # [KFAC Masking Adversarial Damage (MAD)] 93 | def pruning_model(): 94 | # load 95 | import pickle 96 | 97 | with open(pickle_path, 'rb') as f: 98 | 99 | pickle_dict = pickle.load(f) 100 | adv_delta_L_avg = pickle_dict['adv_delta_L_avg'] 101 | 102 | if args.pruning_mode == 'element': 103 | onehot_dict = pruner._global_remove_weight(adv_delta_L_avg, percnt=args.percnt, largest=args.largest) 104 | elif args.pruning_mode == 'random': 105 | onehot_dict = pruner._global_random_pruning(adv_delta_L_avg, percnt=args.percnt) 106 | 107 | print("{} Successfully Pruned !!".format(pickle_path)) 108 | print("Pruning Ratio Per Layer is : {}".format(pruner._compute_prune_ratio_per_layer())) 109 | return onehot_dict 110 | 111 | 112 | def train(epoch, onehot_dict, mode='MAD'):  # default added: the epoch loop below calls train(epoch, onehot_dict) without a mode argument 113 | print('\nEpoch: %d' % epoch) 114 | net.train() 115 | train_loss = 0 116 | correct = 0 117 | total = 0 118 | 119 | p_ratio = compute_prune_ratio(net) 120 | 121 | lr_scheduler(optimizer, epoch) 122 | desc = ('[Train/%s/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d), Prune: %.2f' % 123 | (mode, lr_scheduler.get_lr(optimizer), 0, 0, correct, total, p_ratio)) 124 | 125 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 126 | for batch_idx, (inputs, targets) in prog_bar: 127 | net.train() 128 | inputs, targets = inputs.to(args.device), targets.to(args.device) 129 | loss, outputs = pruner._optimize_weight_with_delta_L(inputs, targets, optimizer, onehot_dict, 130 | pruning_mode=args.pruning_mode) 131 | train_loss += loss.item() 132 | _, predicted = outputs.max(1) 133 | total += targets.size(0) 134 | correct += predicted.eq(targets).sum().item() 135 | 136 | desc = ('[Train/%s/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d), Prune: %.2f' % 137 | (mode, lr_scheduler.get_lr(optimizer), train_loss / (batch_idx + 1), 100.
* correct / total, correct, total, p_ratio)) 138 | prog_bar.set_description(desc, refresh=True) 139 | 140 | 141 | def test(epoch, is_attack=False): 142 | global best_acc 143 | net.eval() 144 | test_loss = 0 145 | correct = 0 146 | total = 0 147 | print("Pruning Ratio Per Layer is : {}".format(pruner._compute_prune_ratio_per_layer())) 148 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 149 | % (lr_scheduler.get_lr(optimizer), test_loss/(0+1), 0, correct, total)) 150 | 151 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 152 | for batch_idx, (inputs, targets) in prog_bar: 153 | inputs = attack(inputs, targets) if is_attack else inputs 154 | inputs, targets = inputs.to(args.device), targets.to(args.device) 155 | outputs = net(inputs) 156 | loss = criterion(outputs, targets) 157 | 158 | test_loss += loss.item() 159 | _, predicted = outputs.max(1) 160 | total += targets.size(0) 161 | correct += predicted.eq(targets).sum().item() 162 | 163 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 164 | % (lr_scheduler.get_lr(optimizer), test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 165 | prog_bar.set_description(desc, refresh=True) 166 | 167 | if is_attack: 168 | return 169 | 170 | # Save checkpoint. 171 | acc = 100.*correct/total 172 | 173 | if acc > best_acc: 174 | print('Saving..') 175 | state = { 176 | 'net': net.state_dict(), 177 | 'acc': acc, 178 | 'epoch': epoch, 179 | 'loss': loss, 180 | 'args': args 181 | } 182 | if not os.path.isdir('checkpoint'): 183 | os.mkdir('checkpoint') 184 | if not os.path.isdir('checkpoint/pretrain'): 185 | os.mkdir('checkpoint/pretrain') 186 | 187 | if args.pruning_mode=='element': 188 | torch.save(state, './checkpoint/pretrain/%s/%s_el_%s_mad%d_%s%s_best.t7' % (args.dataset, str(args.largest), args.dataset, 189 | int(args.percnt*100), 190 | args.network, 191 | args.depth)) 192 | print('./checkpoint/pretrain/%s/%s_el_%s_mad%d_%s%s_best.t7' % (args.dataset, str(args.largest), args.dataset, 193 | int(args.percnt*100), 194 | args.network, 195 | args.depth)) 196 | 197 | 198 | elif args.pruning_mode=='random': 199 | torch.save(state, './checkpoint/pretrain/%s/rd_%s_mad%d_%s%s_best.t7' % (args.dataset, args.dataset, 200 | int(args.percnt*100), 201 | args.network, 202 | args.depth)) 203 | print('./checkpoint/pretrain/%s/rd_%s_mad%d_%s%s_best.t7' % (args.dataset, args.dataset, 204 | int(args.percnt*100), 205 | args.network, 206 | args.depth)) 207 | best_acc = acc 208 | 209 | 210 | 211 | onehot_dict = pruning_model() 212 | 213 | # MAD init optimizer and lr scheduler 214 | optimizer = optim.SGD(mask_model.non_mask_parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) 215 | lr_schedule = {0: args.learning_rate, 216 | int(args.epoch*0.5): args.learning_rate*0.1, 217 | int(args.epoch*0.75): args.learning_rate*0.01} 218 | lr_scheduler = PresetLRScheduler(lr_schedule) 219 | 220 | print("--------------MAD--------------") 221 | for epoch in range(start_epoch, args.epoch): 222 | train(epoch, onehot_dict) 223 | test(epoch, is_attack=False) 224 | test(epoch, is_attack=True) 225 | 226 | 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /main_trades_pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as 
F 8 | 9 | from tqdm import tqdm 10 | from utils.network_utils import get_network 11 | from utils.data_utils import get_dataloader 12 | from utils.common_utils import PresetLRScheduler 13 | 14 | # attack loader 15 | from attack.attack import attack_loader 16 | 17 | # fetch args 18 | parser = argparse.ArgumentParser() 19 | 20 | # model parameter 21 | parser.add_argument('--dataset', default='cifar10', type=str) 22 | parser.add_argument('--network', default='vgg', type=str) 23 | parser.add_argument('--depth', default=16, type=int) 24 | parser.add_argument('--device', default='cuda:0', type=str) 25 | 26 | # learning parameter 27 | parser.add_argument('--learning_rate', default=1e-3, type=float) 28 | parser.add_argument('--weight_decay', default=0.0002, type=float) 29 | parser.add_argument('--batch_size', default=128, type=float) 30 | parser.add_argument('--epoch', default=60, type=int) 31 | 32 | # attack parameter 33 | parser.add_argument('--attack', default='pgd', type=str) 34 | parser.add_argument('--eps', default=0.03, type=float) 35 | parser.add_argument('--steps', default=10, type=int) 36 | args = parser.parse_args() 37 | 38 | 39 | # init dataloader 40 | trainloader, testloader = get_dataloader(dataset=args.dataset, 41 | train_batch_size=args.batch_size, 42 | test_batch_size=256) 43 | 44 | # init model 45 | net = get_network(network=args.network, 46 | depth=args.depth, 47 | dataset=args.dataset, 48 | device=args.device) 49 | net = net.to(args.device) 50 | 51 | # Load Adv Network 52 | print('==> Loading Adv checkpoint..') 53 | assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!' 54 | checkpoint = torch.load('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth), map_location=args.device) 55 | print('checkpoint/pretrain/%s/%s_adv_%s%s_best.t7' % (args.dataset, args.dataset, args.network, args.depth)) 56 | net.load_state_dict(checkpoint['net']) 57 | 58 | # Attack loader 59 | if args.dataset=='tiny': 60 | print('Fast FGSM training') 61 | attack = attack_loader(net=net, attack='fgsm_train', eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 62 | else: 63 | print('Low PGD training') 64 | attack = attack_loader(net=net, attack=args.attack, eps=args.eps, steps=args.steps, dataset=args.dataset, device=args.device) 65 | 66 | 67 | # init optimizer and lr scheduler 68 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay) 69 | lr_schedule = {0: args.learning_rate, 70 | int(args.epoch*0.5): args.learning_rate*0.1, 71 | int(args.epoch*0.75): args.learning_rate*0.01} 72 | lr_scheduler = PresetLRScheduler(lr_schedule) 73 | 74 | # init criterion 75 | criterion = nn.CrossEntropyLoss() 76 | 77 | start_epoch = 0 78 | best_acc = 0 79 | 80 | 81 | def train(epoch): 82 | print('\nEpoch: %d' % epoch) 83 | net.train() 84 | train_loss = 0 85 | correct = 0 86 | total = 0 87 | 88 | lr_scheduler(optimizer, epoch) 89 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 90 | (lr_scheduler.get_lr(optimizer), 0, 0, correct, total)) 91 | 92 | 93 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 94 | for batch_idx, (inputs, targets) in prog_bar: 95 | inputs, targets = inputs.to(args.device), targets.to(args.device) 96 | optimizer.zero_grad() 97 | 98 | # TRADES Loss 99 | loss, logit = trades_loss_orig(net, inputs, targets, optimizer, device=args.device) 100 | 101 | loss.backward() 102 | optimizer.step() 103 | 104 | 
train_loss += loss.item() 105 | _, predicted = logit.max(1) 106 | total += targets.size(0) 107 | correct += predicted.eq(targets).sum().item() 108 | 109 | desc = ('[Train/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 110 | (lr_scheduler.get_lr(optimizer), train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 111 | prog_bar.set_description(desc, refresh=True) 112 | 113 | 114 | 115 | def test(epoch): 116 | global best_acc 117 | net.eval() 118 | test_loss = 0 119 | correct = 0 120 | total = 0 121 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 122 | % (lr_scheduler.get_lr(optimizer), test_loss/(0+1), 0, correct, total)) 123 | 124 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 125 | for batch_idx, (inputs, targets) in prog_bar: 126 | inputs = attack(inputs, targets) 127 | inputs, targets = inputs.to(args.device), targets.to(args.device) 128 | outputs = net(inputs) 129 | loss = criterion(outputs, targets) 130 | 131 | test_loss += loss.item() 132 | _, predicted = outputs.max(1) 133 | total += targets.size(0) 134 | correct += predicted.eq(targets).sum().item() 135 | 136 | desc = ('[Test/LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 137 | % (lr_scheduler.get_lr(optimizer), test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 138 | prog_bar.set_description(desc, refresh=True) 139 | 140 | # Save checkpoint. 141 | acc = 100.*correct/total 142 | 143 | if acc > best_acc: 144 | print('Saving..') 145 | state = { 146 | 'net': net.state_dict(), 147 | 'acc': acc, 148 | 'epoch': epoch, 149 | 'loss': loss, 150 | 'args': args 151 | } 152 | if not os.path.isdir('checkpoint'): 153 | os.mkdir('checkpoint') 154 | if not os.path.isdir('checkpoint/pretrain'): 155 | os.mkdir('checkpoint/pretrain') 156 | torch.save(state, './checkpoint/pretrain/%s/%s_trades_%s%s_best.t7' % (args.dataset, args.dataset, 157 | args.network, 158 | args.depth)) 159 | print('./checkpoint/pretrain/%s/%s_trades_%s%s_best.t7' % (args.dataset, args.dataset, 160 | args.network, 161 | args.depth)) 162 | best_acc = acc 163 | 164 | 165 | 166 | 167 | def trades_loss_orig(model, 168 | x_natural, 169 | y, 170 | optimizer, 171 | device=None, 172 | step_size=0.03/10*2.3 if args.dataset != 'tiny' else 0.03 * 1.25, 173 | epsilon=0.03, 174 | perturb_steps=10 if args.dataset != 'tiny' else 1, 175 | beta=4.0, 176 | distance='l_inf'): 177 | # define KL-loss 178 | criterion_kl = nn.KLDivLoss(size_average=False) 179 | model.eval() 180 | batch_size = len(x_natural) 181 | 182 | if args.dataset != 'tiny': 183 | # generate adversarial example 184 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach() 185 | else: 186 | # generate adversarial example 187 | x_adv = x_natural.detach() + torch.empty_like(x_natural).uniform_(-epsilon, epsilon).to(device).detach() 188 | 189 | if distance == 'l_inf': 190 | for _ in range(perturb_steps): 191 | x_adv.requires_grad_() 192 | with torch.enable_grad(): 193 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), 194 | F.softmax(model(x_natural), dim=1)) 195 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 196 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 197 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 198 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 199 | elif distance == 'l_2': 200 | delta = 0.001 * torch.randn(x_natural.shape).to(device).detach() 201 | delta = torch.autograd.Variable(delta.data, requires_grad=True) 202 | 203 | # Setup optimizers 204 | optimizer_delta 
= optim.SGD([delta], lr=epsilon / perturb_steps * 2) 205 | 206 | for _ in range(perturb_steps): 207 | adv = x_natural + delta 208 | 209 | # optimize 210 | optimizer_delta.zero_grad() 211 | with torch.enable_grad(): 212 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), 213 | F.softmax(model(x_natural), dim=1)) 214 | loss.backward() 215 | # renorming gradient 216 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 217 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 218 | # avoid nan or inf if gradient is 0 219 | if (grad_norms == 0).any(): 220 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 221 | optimizer_delta.step() 222 | 223 | # projection 224 | delta.data.add_(x_natural) 225 | delta.data.clamp_(0, 1).sub_(x_natural) 226 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 227 | x_adv = torch.autograd.Variable(x_natural + delta, requires_grad=False) 228 | else: 229 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 230 | model.train() 231 | 232 | x_adv = torch.autograd.Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 233 | # zero gradient 234 | optimizer.zero_grad() 235 | # calculate robust loss 236 | logits = model(x_natural) 237 | logits_adv = model(x_adv) 238 | loss_natural = F.cross_entropy(logits, y) 239 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1), 240 | F.softmax(model(x_natural), dim=1)) 241 | loss = loss_natural + beta * loss_robust 242 | return loss, logits_adv 243 | 244 | 245 | 246 | for epoch in range(start_epoch, args.epoch): 247 | train(epoch) 248 | test(epoch) 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | -------------------------------------------------------------------------------- /utils/kfac_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.common_utils import try_contiguous 6 | from models.operator.mask import Conv2d_mask, Linear_mask 7 | 8 | 9 | def _extract_patches(x, kernel_size, stride, padding): 10 | """ 11 | :param x: The input feature maps. (batch_size, in_c, h, w) 12 | :param kernel_size: the kernel size of the conv filter (tuple of two elements) 13 | :param stride: the stride of conv operation (tuple of two elements) 14 | :param padding: number of paddings. be a tuple of two elements 15 | :return: (batch_size, out_h, out_w, in_c*kh*kw) 16 | """ 17 | if padding[0] + padding[1] > 0: 18 | x = F.pad(x, (padding[1], padding[1], padding[0], 19 | padding[0])).data # Actually check dims 20 | x = x.unfold(2, kernel_size[0], stride[0]) 21 | x = x.unfold(3, kernel_size[1], stride[1]) 22 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous() 23 | x = x.view( 24 | x.size(0), x.size(1), x.size(2), 25 | x.size(3) * x.size(4) * x.size(5)) 26 | return x 27 | 28 | 29 | def _extract_channel_patches(x, kernel_size, stride, padding): 30 | """ 31 | :param x: The input feature maps. (batch_size, in_c, h, w) 32 | :param kernel_size: the kernel size of the conv filter (tuple of two elements) 33 | :param stride: the stride of conv operation (tuple of two elements) 34 | :param padding: number of paddings. 
be a tuple of two elements 35 | :return: (batch_size, out_h, out_w, in_c*kh*kw) 36 | """ 37 | if padding[0] + padding[1] > 0: 38 | x = F.pad(x, (padding[1], padding[1], padding[0], 39 | padding[0])).data # Actually check dims 40 | x = x.unfold(2, kernel_size[0], stride[0]) 41 | x = x.unfold(3, kernel_size[1], stride[1]) # b * oh * ow * kh * kw * inc 42 | x = x.transpose_(1, 2).transpose_(2, 3).transpose_(3, 4).transpose(4, 5).contiguous() 43 | x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), x.size(4), x.size(5)) 44 | return x 45 | 46 | 47 | def update_running_stat(aa, m_aa, stat_decay): 48 | # using inplace operation to save memory! 49 | m_aa *= stat_decay / (1 - stat_decay) 50 | m_aa += aa 51 | m_aa *= (1 - stat_decay) 52 | 53 | 54 | def fetch_mat_weights(layer, use_patch=False): 55 | # -> output_dium * input_dim (kh*kw*in_c + [1 if with bias]) 56 | if isinstance(layer, nn.Conv2d) or isinstance(layer, Conv2d_mask): 57 | if use_patch: 58 | weight = layer.weight.transpose(1, 2).transpose(2, 3) # n_out * kh * kw * inc 59 | n_out, k_h, k_w, in_c = weight.size() 60 | weight = try_contiguous(weight) 61 | weight = weight.view(-1, weight.size(-1)) 62 | bias = 0 63 | if layer.bias is not None: 64 | copied_bias = torch.cat([layer.bias.unsqueeze(1) for _ in range(k_h*k_w)], 1).view(-1, 1) 65 | weight = torch.cat([weight, copied_bias], 1) # layer.bias.unsqueeze(1)], 1) 66 | bias = 1 67 | weight = weight.view(n_out, k_h*k_w, in_c+bias) 68 | else: 69 | weight = layer.weight # n_filters * in_c * kh * kw 70 | # weight = weight.transpose(1, 2).transpose(2, 3).contiguous() 71 | weight = weight.view(weight.size(0), -1) 72 | if layer.bias is not None: 73 | weight = torch.cat([weight, layer.bias.unsqueeze(1)], 1) 74 | elif isinstance(layer, nn.Linear) or isinstance(layer, Linear_mask): 75 | weight = layer.weight 76 | if layer.bias is not None: 77 | weight = torch.cat([weight, layer.bias.unsqueeze(1)], 1) 78 | else: 79 | raise NotImplementedError 80 | 81 | return weight 82 | 83 | def fetch_mat_mask_weights(layer, use_patch=False): 84 | # -> output_dium * input_dim (kh*kw*in_c + [1 if with bias]) 85 | if isinstance(layer, nn.Conv2d) or isinstance(layer, Conv2d_mask): 86 | if use_patch: 87 | weight = layer.mask_weight.transpose(1, 2).transpose(2, 3) # n_out * kh * kw * inc 88 | n_out, k_h, k_w, in_c = weight.size() 89 | weight = try_contiguous(weight) 90 | weight = weight.view(-1, weight.size(-1)) 91 | bias = 0 92 | if layer.mask_bias is not None: 93 | copied_bias = torch.cat([layer.mask_bias.unsqueeze(1) for _ in range(k_h*k_w)], 1).view(-1, 1) 94 | weight = torch.cat([weight, copied_bias], 1) # layer.mask_bias.unsqueeze(1)], 1) 95 | bias = 1 96 | weight = weight.view(n_out, k_h*k_w, in_c+bias) 97 | else: 98 | weight = layer.mask_weight # n_filters * in_c * kh * kw 99 | # weight = weight.transpose(1, 2).transpose(2, 3).contiguous() 100 | weight = weight.view(weight.size(0), -1) 101 | if layer.mask_bias is not None: 102 | weight = torch.cat([weight, layer.mask_bias.unsqueeze(1)], 1) 103 | elif isinstance(layer, nn.Linear) or isinstance(layer, Linear_mask): 104 | weight = layer.mask_weight 105 | if layer.bias is not None: 106 | weight = torch.cat([weight, layer.mask_bias.unsqueeze(1)], 1) 107 | else: 108 | raise NotImplementedError 109 | 110 | return weight 111 | 112 | 113 | def mat_to_weight_and_bias(mat, layer): 114 | if isinstance(layer, nn.Conv2d): 115 | # mat: n_filters * (in_c * kh * kw) 116 | k_h, k_w = layer.kernel_size 117 | in_c = layer.in_channels 118 | out_c = layer.out_channels 119 | bias = 
None 120 | if layer.bias is not None: 121 | bias = mat[:, -1] 122 | mat = mat[:, :-1] 123 | weight = mat.view(out_c, in_c, k_h, k_w) 124 | elif isinstance(layer, nn.Linear): 125 | in_c = layer.in_features 126 | out_c = layer.out_features 127 | bias = None 128 | if layer.bias is not None: 129 | bias = mat[:, -1] 130 | mat = mat[:, :-1] 131 | weight = mat 132 | else: 133 | raise NotImplementedError 134 | return weight, bias 135 | 136 | 137 | class ComputeMatGrad: 138 | 139 | @classmethod 140 | def __call__(cls, input, grad_output, layer): 141 | if isinstance(layer, nn.Linear): 142 | grad = cls.linear(input, grad_output, layer) 143 | elif isinstance(layer, nn.Conv2d): 144 | grad = cls.conv2d(input, grad_output, layer) 145 | else: 146 | raise NotImplementedError 147 | return grad 148 | 149 | @staticmethod 150 | def linear(input, grad_output, layer): 151 | """ 152 | :param input: batch_size * input_dim 153 | :param grad_output: batch_size * output_dim 154 | :param layer: [nn.module] output_dim * input_dim 155 | :return: batch_size * output_dim * (input_dim + [1 if with bias]) 156 | """ 157 | with torch.no_grad(): 158 | if layer.bias is not None: 159 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1) 160 | input = input.unsqueeze(1) 161 | grad_output = grad_output.unsqueeze(2) 162 | grad = torch.bmm(grad_output, input) 163 | return grad 164 | 165 | @staticmethod 166 | def conv2d(input, grad_output, layer): 167 | """ 168 | :param input: batch_size * in_c * in_h * in_w 169 | :param grad_output: batch_size * out_c * h * w 170 | :param layer: nn.module batch_size * out_c * (in_c*k_h*k_w + [1 if with bias]) 171 | :return: 172 | """ 173 | with torch.no_grad(): 174 | input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding) 175 | input = input.view(-1, input.size(-1)) # b * hw * in_c*kh*kw 176 | grad_output = grad_output.transpose(1, 2).transpose(2, 3) 177 | grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1)) 178 | # b * hw * out_c 179 | if layer.bias is not None: 180 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1) 181 | input = input.view(grad_output.size(0), -1, input.size(-1)) # b * hw * in_c*kh*kw 182 | grad = torch.einsum('abm,abn->amn', (grad_output, input)) 183 | return grad 184 | 185 | class ComputeCovA: 186 | @classmethod 187 | def compute_cov_a(cls, a, layer): 188 | return cls.__call__(a, layer) 189 | 190 | @classmethod 191 | def __call__(cls, a, layer): 192 | if isinstance(layer, nn.Linear): 193 | cov_a = cls.linear(a, layer) 194 | elif isinstance(layer, nn.Conv2d): 195 | cov_a = cls.conv2d(a, layer) 196 | else: 197 | # raise NotImplementedError 198 | cov_a = None 199 | 200 | return cov_a 201 | 202 | @staticmethod 203 | def conv2d(a, layer): 204 | batch_size = a.size(0) 205 | a = _extract_patches(a, layer.kernel_size, layer.stride, layer.padding) 206 | spatial_size = a.size(1) * a.size(2) 207 | a = a.view(-1, a.size(-1)) 208 | if layer.bias is not None: 209 | a = torch.cat([a, a.new(a.size(0), 1).fill_(1)], 1) 210 | a = a/spatial_size 211 | return a.t() @ (a / batch_size) 212 | 213 | @staticmethod 214 | def linear(a, layer): 215 | # a: batch_size * in_dim 216 | batch_size = a.size(0) 217 | if layer.bias is not None: 218 | a = torch.cat([a, a.new(a.size(0), 1).fill_(1)], 1) 219 | return a.t() @ (a / batch_size) 220 | 221 | class ComputeCovG: 222 | 223 | @classmethod 224 | def compute_cov_g(cls, g, layer, batch_averaged=False): 225 | """ 226 | :param g: gradient 227 | :param layer: the 
corresponding layer 228 | :param batch_averaged: if the gradient is already averaged with the batch size? 229 | :return: 230 | """ 231 | # batch_size = g.size(0) 232 | return cls.__call__(g, layer, batch_averaged) 233 | 234 | @classmethod 235 | def __call__(cls, g, layer, batch_averaged): 236 | if isinstance(layer, nn.Conv2d): 237 | cov_g = cls.conv2d(g, layer, batch_averaged) 238 | elif isinstance(layer, nn.Linear): 239 | cov_g = cls.linear(g, layer, batch_averaged) 240 | else: 241 | cov_g = None 242 | 243 | return cov_g 244 | 245 | @staticmethod 246 | def conv2d(g, layer, batch_averaged): 247 | # g: batch_size * n_filters * out_h * out_w 248 | # n_filters is actually the output dimension (analogous to Linear layer) 249 | spatial_size = g.size(2) * g.size(3) 250 | batch_size = g.shape[0] 251 | g = g.transpose(1, 2).transpose(2, 3) 252 | g = try_contiguous(g) 253 | g = g.view(-1, g.size(-1)) 254 | 255 | if batch_averaged: 256 | g = g * batch_size 257 | g = g * spatial_size 258 | cov_g = g.t() @ (g / g.size(0)) 259 | 260 | return cov_g 261 | 262 | @staticmethod 263 | def linear(g, layer, batch_averaged): 264 | # g: batch_size * out_dim 265 | batch_size = g.size(0) 266 | 267 | if batch_averaged: 268 | cov_g = g.t() @ (g * batch_size) 269 | else: 270 | cov_g = g.t() @ (g / batch_size) 271 | return cov_g 272 | 273 | 274 | class ComputeCovAPatch(ComputeCovA): 275 | @staticmethod 276 | def conv2d(a, layer): 277 | batch_size = a.size(0) 278 | a = _extract_channel_patches(a, layer.kernel_size, layer.stride, layer.padding) 279 | spatial_size = a.size(1) * a.size(2) 280 | a = a.view(-1, a.size(-1)) 281 | patch_size = layer.kernel_size[0] * layer.kernel_size[1] 282 | if layer.bias is not None: 283 | a = torch.cat([a, a.new(a.size(0), 1).fill_(1./patch_size)], 1) 284 | a = a / spatial_size 285 | return a.t() @ (a / batch_size / patch_size) 286 | 287 | 288 | if __name__ == '__main__': 289 | def test_ComputeCovA(): 290 | pass 291 | 292 | def test_ComputeCovG(): 293 | pass 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white) 2 | ![Git](https://img.shields.io/badge/git-%23F05033.svg?style=for-the-badge&logo=git&logoColor=white) 3 | # CVPR 2022 4 | [![Generic badge](https://img.shields.io/badge/Library-Pytorch-green.svg)](https://pytorch.org/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/ByungKwanLee/Masking-Adversarial-Damage/blob/master/LICENSE) 6 | 7 | # Title: [Masking Adversarial Damage: Finding Adversarial Saliency for Robust and Sparse Network](https://openaccess.thecvf.com/content/CVPR2022/papers/Lee_Masking_Adversarial_Damage_Finding_Adversarial_Saliency_for_Robust_and_Sparse_CVPR_2022_paper.pdf) 8 | 9 | 10 | 11 | --- 12 | 13 | 14 | #### Authors: [Byung-Kwan Lee*](https://scholar.google.co.kr/citations?user=rl0JXCQAAAAJ&hl=en), [Junho Kim*](https://scholar.google.com/citations?user=ZxE16ZUAAAAJ&hl=en), and [Yong Man Ro](https://scholar.google.co.kr/citations?user=IPzfF7cAAAAJ&hl=en) (*: equally contributed) 15 | #### Affiliation: School of Electrical Engineering, Korea Advanced Institute of Science and Technology (KAIST) 16 | #### Email: `leebk@kaist.ac.kr`, `arkimjh@kaist.ac.kr`, `ymro@kaist.ac.kr` 17 | 18 | 19 | --- 20 | 21 | This is official PyTorch Implementation code for 
the paper of "Masking Adversarial Damage: Finding Adversarial Saliency 22 | for Robust and Sparse Network" accepted in CVPR 2022. To bridge adversarial robustness and model compression, we propose a 23 | novel adversarial pruning method, Masking Adversarial Damage (MAD) that employs second-order information of adversarial loss. 24 | By using it, we can accurately estimate adversarial saliency for model parameters and determine which parameters can be 25 | pruned without weakening adversarial robustness. 26 | 27 |
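As a rough sketch of the underlying idea (illustrative notation of ours, not the repository's API): pruning corresponds to a weight perturbation Δw = w ⊙ (1 − m) for a binary mask m, and the damage of removing those weights is estimated from the second-order Taylor term of the adversarial loss, ΔL ≈ ½ Δwᵀ H Δw, where each layer's Hessian H is approximated by the K-FAC Kronecker factors A (input covariance) and G (pre-activation gradient covariance):

```python
import torch

def kfac_saliency_sketch(w, mask, A, G):
    """Illustrative only (names are ours, not the repo's): estimated increase
    of the adversarial loss when the zeroed entries of `mask` are pruned,
    under the K-FAC approximation H = A (kron) G, for which the quadratic
    form vec(dw)^T H vec(dw) reduces to sum(dw * (G @ dw @ A))."""
    dw = w * (1.0 - mask)                    # perturbation induced by pruning
    return 0.5 * torch.sum(dw * (G @ dw @ A))
```

Parameters whose estimated ΔL is small are the ones MAD regards as safe to prune.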

28 | 29 |

30 | 31 | 32 | Furthermore, we reveal that the model parameters of the initial layer are highly sensitive to adversarial examples, and we show that the compressed feature representation retains semantic information about the target objects. 33 | 34 | 

35 | 36 |

37 | 38 | Through extensive experiments on public datasets, we demonstrate that MAD effectively prunes adversarially trained 39 | networks without losing adversarial robustness and shows better performance than previous adversarial pruning methods. 40 | For more detail, please refer to our paper, which will soon be accessible to the public. 41 | 

43 | 44 |

45 | 46 | Adversarial attacks can potentially cause negative impacts on various DNN applications due to high computation and its 47 | fragility. By pruning model parameters without weakening adversarial robustness, our work contributes important societal 48 | impacts in this research area. Furthermore, in our promising observation that model parameters of initial layer are highly 49 | sensitive to adversarial loss, we hope to progress in another future direction of utilizing such property to enhance adversarial robustness. 50 | 51 | In conclusion, in order to achieve adversarial robustness and model compression concurrently, we propose a novel adversarial pruning method, 52 | Masking Adversarial Damage (MAD). By exploiting second-order information with mask optimization and Block-wise K-FAC, 53 | we can precisely estimate adversarial saliency of the whole parameters. Through extensive validations, we corroborate 54 | pruning model parameters in order of low adversarial saliency retains adversarial robustness while alleviating less performance 55 | degradation compared with previous adversarial pruning methods. 56 | 57 | --- 58 | 59 | ## Datasets 60 | * [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) (32x32, 10 classes) 61 | * [SVHN](http://ufldl.stanford.edu/housenumbers/) (32x32, 10 classes) 62 | * [Tiny-ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview) (64x64, 200 classes) 63 | 64 | --- 65 | 66 | ## Networks 67 | 68 | * [VGG-16](https://arxiv.org/pdf/1409.1556) (models/vgg.py) 69 | * [ResNet-18](https://arxiv.org/pdf/1512.03385) (models/resnet.py) 70 | * [WideResNet-28-10](https://arxiv.org/abs/1605.07146) (models/wide.py) 71 | 72 | 73 | --- 74 | 75 | ## Masking Adversarial Damage (MAD) 76 | #### Step 1. Finding Adversarial Saliency 77 | * Run `compute_saliency.py` *(Procedure of saving a pickle file for adversarial saliency to all model parameters. Then, you should need a folder (e.g., `pickle` folder) in which the pickle file is saved)* 78 | 79 | ```bash 80 | # model parameter 81 | parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny' 82 | parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide' 83 | parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet) 84 | parser.add_argument('--device', default='cuda:0', type=str) 85 | parser.add_argument('--batch_size', default=128, type=float) 86 | 87 | # attack parameter 88 | parser.add_argument('--attack', default='pgd', type=str) 89 | parser.add_argument('--eps', default=0.03, type=float) 90 | parser.add_argument('--steps', default=10, type=int) 91 | ``` 92 | 93 | Among codes for running `compute_saliency`, the following code represents the major contribution of our work that is the procedure 94 | of computing adversarial saliency realized with Block-wise K-FAC. Note that it is important to consider the factors of 95 | `block1` and `block2` below for Block-wise K-FAC that dramatically reduces computation. 
96 | 97 | ```python 98 | def _compute_delta_L(self): 99 | 100 | delta_L_list = [] 101 | mask_list = [] 102 | for idx, m in enumerate(self.modules): 103 | 104 | m_aa, m_gg = self.m_aa[m], self.m_gg[m] 105 | 106 | w = fetch_mat_weights(m) 107 | mask = fetch_mat_mask_weights(m) 108 | w_mask = w - operator(w, mask) 109 | 110 | double_grad_L = torch.empty_like(w_mask) 111 | 112 | # 1/2 * Δ𝑤^𝑇 *𝐻 * Δ𝑤 113 | for i in range(m_gg.shape[0]): 114 | block1 = 0.5 * m_gg[i, i] * w_mask.t()[:, i].view(-1, 1) 115 | block2 = w_mask[i].view(1, -1) @ m_aa 116 | block = block1 @ block2 117 | double_grad_L[i, :] = block.diag() 118 | 119 | delta_L = double_grad_L 120 | delta_L_list.append(delta_L.detach()) 121 | mask_list.append(f(mask).detach()) 122 | 123 | return delta_L_list, mask_list 124 | ``` 125 | 126 | 127 | 128 | #### Step 2. Pruning Low Adversarial Saliency 129 | * Run `main_mad_pretrain.py` *(This requires the pickle file generated in Step 1.)* 130 | ```bash 131 | # model parameter 132 | parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny' 133 | parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide' 134 | parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet) 135 | parser.add_argument('--device', default='cuda:0', type=str) 136 | 137 | # mad parameter 138 | parser.add_argument('--percnt', default=0.9, type=float) # 0.99 (Sparsity) 139 | parser.add_argument('--pruning_mode', default='element', type=str) # 'random' (randomly pruning) 140 | parser.add_argument('--largest', default='false', type=str2bool) # 'true' (pruning high adversarial saliency) 141 | 142 | # learning parameter 143 | parser.add_argument('--learning_rate', default=0.1, type=float) 144 | parser.add_argument('--weight_decay', default=0.0002, type=float) 145 | parser.add_argument('--batch_size', default=128, type=float) 146 | parser.add_argument('--epoch', default=60, type=int) 147 | 148 | # attack parameter 149 | parser.add_argument('--attack', default='pgd', type=str) 150 | parser.add_argument('--eps', default=0.03, type=float) 151 | parser.add_argument('--steps', default=10, type=int) 152 | ``` 153 | 154 | --- 155 | 156 | ## Adversarial Training (+ Recent Adversarial Defenses) 157 | 158 | * [AT](https://arxiv.org/abs/1706.06083) (main_adv_pretrain.py) 159 | * [TRADES](https://arxiv.org/abs/1901.08573) (main_trades_pretrain.py) 160 | * [MART](https://openreview.net/forum?id=rklOg6EFwS) (main_mart_pretrain.py) 161 | * [FAST](https://openreview.net/forum?id=BJx040EFvH) for Tiny-ImageNet (refer to **FGSM_train** class in *attack/attack.py*) 162 | 163 | ### *Running Adversarial Training* 164 | 165 | To easily obtain an adversarially trained model, we first train a standard model with [1] 166 | and then perform adversarial training (AT) with [2], starting from the trained standard model. To run the recent adversarial defenses, the AT model created by [2] 167 | is a helpful starting point for training TRADES or MART via [3-1] or [3-2]; see the example pipeline sketched below. 
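For reference, a hedged end-to-end sketch of this pipeline with the repository's default CIFAR-10 / VGG-16 arguments (the scripts and flags exist in this repository, but treat the exact invocation as illustrative rather than a tested command sequence):

```bash
python main_pretrain.py        --dataset cifar10 --network vgg --depth 16  # [1] Plain
python main_adv_pretrain.py    --dataset cifar10 --network vgg --depth 16  # [2] AT, loads the Plain checkpoint
python main_trades_pretrain.py --dataset cifar10 --network vgg --depth 16  # [3-1] TRADES, loads the AT checkpoint
python main_mart_pretrain.py   --dataset cifar10 --network vgg --depth 16  # [3-2] MART, loads the AT checkpoint
```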
* **[1] Plain** (Plain Training)
  - Run `main_pretrain.py`

```python
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--epoch', default=200, type=int)
parser.add_argument('--device', default='cuda:0', type=str)

# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=int)
```

* **[2] AT** ([PGD Adversarial Training](https://openreview.net/forum?id=rJzIBfZAb))
  - Run `main_adv_pretrain.py`

```python
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)

# learning parameter
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=60, type=int)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
```

* **[3-1] TRADES** ([Recent defense method](http://proceedings.mlr.press/v97/zhang19p.html))
  - Run `main_trades_pretrain.py`

```python
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)

# learning parameter
parser.add_argument('--learning_rate', default=1e-3, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=10, type=int)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
```

* **[3-2] MART** ([Recent defense method](https://openreview.net/forum?id=rklOg6EFwS))
  - Run `main_mart_pretrain.py`

```python
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--device', default='cuda:0', type=str)

# learning parameter
parser.add_argument('--learning_rate', default=1e-3, type=float)
parser.add_argument('--weight_decay', default=0.0002, type=float)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epoch', default=60, type=int)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=10, type=int)
```
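For reference, the objective optimized by `main_mart_pretrain.py` is the MART loss of Wang et al.; the same formulation reappears in `loss_function` of *pruner/kfac_MAD_pruner.py*. A sketch, with x' the PGD adversarial example, p(·) the softmax output, and β the robustness weight:

```latex
% BCE(p, y) = CE(p, y) - log(1 - max_{k != y} p_k)   (the "boosted" cross-entropy)
\mathcal{L}_{\mathrm{MART}}
\;=\; \mathrm{BCE}\big(p(x'),\,y\big)
\;+\; \beta \cdot \mathrm{KL}\big(p(x)\,\big\|\,p(x')\big)\cdot\big(1-p_{y}(x)\big)
```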
---

## Adversarial Attacks (by [torchattacks](https://github.com/Harry24k/adversarial-attacks-pytorch))
* Fast Gradient Sign Method ([FGSM](https://arxiv.org/abs/1412.6572))
* Projected Gradient Descent ([PGD](https://arxiv.org/abs/1706.06083))
* Carlini & Wagner ([CW](https://arxiv.org/abs/1608.04644))
* AutoPGD ([AP](https://arxiv.org/abs/2003.01690))
* AutoAttack ([AA](https://arxiv.org/abs/2003.01690))

The implementation details of the adversarial attacks are given in *attack/attack.py*; the dispatch excerpted below selects an attack by name.

```python
# torchattacks
if attack == "fgsm":
    return torchattacks.FGSM(model=net, eps=eps)

elif attack == "fgsm_train":
    return FGSM_train(model=net, eps=eps)

elif attack == "pgd":
    return torchattacks.PGD(model=net, eps=eps, alpha=eps/steps*2.3, steps=steps, random_start=True)

elif attack == "cw_linf":
    return CW_Linf(model=net, eps=eps, lr=0.1, steps=30)

elif attack == "apgd":
    return torchattacks.APGD(model=net, eps=eps, loss='ce', steps=30)

elif attack == "auto":
    return torchattacks.AutoAttack(model=net, eps=eps, n_classes=n_classes)
```

### *Testing Adversarial Robustness*

* **Measuring the robustness of an adversarially trained model**
  - Run `test.py`

```python
# model parameter
parser.add_argument('--dataset', default='cifar10', type=str) # 'svhn', 'tiny'
parser.add_argument('--network', default='vgg', type=str) # 'resnet', 'wide'
parser.add_argument('--depth', default=16, type=int) # 18 (ResNet), 28 (WideResNet)
parser.add_argument('--baseline', default='adv', type=str) # 'trades', 'mart', 'mad'
parser.add_argument('--device', default='cuda:0', type=str)

# mad parameter
parser.add_argument('--percnt', default=0.9, type=float) # 0.99 (sparsity)
parser.add_argument('--pruning_mode', default='el', type=str) # 'rd' (random)
parser.add_argument('--largest', default='false', type=str2bool) # 'true' (pruning high adversarial saliency)

# attack parameter
parser.add_argument('--attack', default='pgd', type=str)
parser.add_argument('--eps', default=0.03, type=float)
parser.add_argument('--steps', default=30, type=int)
```
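Evaluating robustness with the torchattacks interface shown above amounts to attacking each batch and measuring accuracy on the perturbed inputs. A minimal sketch (assumes a trained `net` and a `testloader` already on the same CUDA device):

```python
# Minimal robust-accuracy loop with torchattacks (net/testloader assumed).
import torch
import torchattacks

atk = torchattacks.PGD(model=net, eps=0.03, alpha=0.03 / 10 * 2.3,
                       steps=10, random_start=True)
correct, total = 0, 0
for images, labels in testloader:
    images, labels = images.cuda(), labels.cuda()
    adv_images = atk(images, labels)          # perturbed copies in [0, 1]
    with torch.no_grad():
        preds = net(adv_images).argmax(dim=1)
    correct += (preds == labels).sum().item()
    total += labels.size(0)
print(f"Robust accuracy: {correct / total:.4f}")
```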
---

--------------------------------------------------------------------------------
/pruner/kfac_MAD_pruner.py:
--------------------------------------------------------------------------------

from collections import OrderedDict
from utils.kfac_utils import (ComputeCovA,
                              ComputeCovG,
                              fetch_mat_weights,
                              fetch_mat_mask_weights,)

from utils.utils import *
from models.operator.mask import *


class KFACMADPruner:

    def __init__(self,
                 model,
                 attack,
                 device,
                 batch_averaged=True,
                 dataset=None
                 ):
        self.iter = 0
        self.device = device
        self.CovAHandler = ComputeCovA()
        self.CovGHandler = ComputeCovG()
        self.batch_averaged = batch_averaged
        self.mask_known_modules = {'Conv2d_mask', 'Linear_mask'}
        self.modules = []
        self.model = model
        self.attack = attack
        self.steps = 0
        self.dataset = dataset

        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}
        self.W_pruned = {}
        self.S_l = None


    def _save_input(self, module, input):
        aa = self.CovAHandler(input[0].data, module)
        # Initialize buffers
        if self.steps == 0:
            self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(0))
        self.m_aa[module] += aa

    def _save_grad_output(self, module, grad_input, grad_output):
        # Accumulate statistics for Fisher matrices
        gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
        # Initialize buffers
        if self.steps == 0:
            self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(0))
        self.m_gg[module] += gg


    def _mask_prepare_model(self):
        count = 0
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.mask_known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_backward_hook(self._save_grad_output)
                count += 1


    def _compute_minibatch_fisher(self, inputs, targets, device='cuda', fisher_type=True):

        self.model = self.model.eval()
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = self.model(inputs)
        if fisher_type:
            sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),
                                          1).squeeze().to(device)
            loss_sample = F.cross_entropy(outputs, sampled_y)
            loss_sample.backward()
        else:
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
        self.steps = 1

    def _rm_mask_hooks(self):
        for m in self.model.modules():
            classname = m.__class__.__name__
            if classname in self.mask_known_modules:
                m._backward_hooks = OrderedDict()
                m._forward_pre_hooks = OrderedDict()

    def _clear_buffer(self):
        self.m_aa = {}
        self.m_gg = {}
        self.d_a = {}
        self.d_g = {}
        self.Q_a = {}
        self.Q_g = {}
        self.modules = []
        if self.S_l is not None:
            self.S_l = {}

        self.m_a = {}
        self.m_g = {}
        self.steps = 0

    def _optimize_mask(self, inputs, targets, optim, mask_epoch, debug_acc=True, is_compute_delta_L=False):

        self.model = self.model.eval()
        self._mask_prepare_model()
        self._rm_mask_hooks()
        reinitialize_mask_network(self.modules)

        if debug_acc:
            self.debug(inputs, targets, str='No update/Mask')

        loss = 0
        for _ in range(mask_epoch):
            outputs = self.model(inputs)
            loss = F.cross_entropy(outputs, targets)

            # optimizing mask
            optim.zero_grad()
            self.model.zero_grad()
            loss.backward()
            optim.step()
            clamping_mask_network(self.modules)

        if debug_acc:
            self.debug(inputs, targets, str='Update/Mask')

        self._clear_buffer()
        if is_compute_delta_L:
            self._mask_prepare_model()
            self._compute_minibatch_fisher(inputs, targets, self.device, False)
            self._rm_mask_hooks()
            delta_L_list, mask_list = self._compute_delta_L()  # weight delta L
            outputs = self.model(inputs)
            reinitialize_mask_network(self.modules)
            self._clear_buffer()
            return loss, delta_L_list, mask_list, outputs
        else:
            return loss
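
    # The mask-optimization flow above: reinitialize the soft mask, take a few
    # gradient steps on the mini-batch cross-entropy, and clamp the mask after
    # each step. When is_compute_delta_L is set, the K-FAC hooks are re-armed
    # for a single forward/backward pass so that m_aa and m_gg hold the input
    # and output-gradient covariances, and _compute_delta_L then converts the
    # masked weight change into per-parameter adversarial saliency.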
    def _optimize_weight_with_delta_L(self, inputs, targets, optim,
                                      onehot_dict=None, pruning_mode='element'):

        self._mask_prepare_model()
        self._rm_mask_hooks()

        self.model = self.model.train()
        optim.zero_grad()

        loss, logit = self.loss_function(model=self.model,
                                         x_natural=inputs,
                                         y=targets,
                                         optim=optim,
                                         device=self.device,
                                         step_size=0.03/10*2.3 if self.dataset != 'tiny' else 0.03 * 1.25,
                                         epsilon=0.03,
                                         perturb_steps=10 if self.dataset != 'tiny' else 1,
                                         beta=0.5)
        loss.backward()

        if onehot_dict is not None:
            if (pruning_mode == 'element') or (pruning_mode == 'random'):
                for m in self.modules:
                    onehot = onehot_dict[m]
                    if m.bias is not None:
                        m.bias.grad.data *= (1 - onehot)[:, -1].view(m.bias.shape)
                        m.bias.data *= (1 - onehot)[:, -1].view(m.bias.shape)

                        m.weight.grad.data *= (1 - onehot)[:, :-1].view(m.weight.shape)
                        m.weight.data *= (1 - onehot)[:, :-1].view(m.weight.shape)
                    else:
                        m.weight.grad.data *= (1 - onehot).view(m.weight.shape)
                        m.weight.data *= (1 - onehot).view(m.weight.shape)
        optim.step()
        self._clear_buffer()

        return loss, logit

    def _global_remove_weight(self, adv_delta_L_avg, percnt, largest):
        print("--------------GLOBAL PRUNING--------------")
        self._mask_prepare_model()
        self._rm_mask_hooks()
        onehot_dict = {}

        # convolution layer pruning
        concat_conv, linear, size = list_to_concat(adv_delta_L_avg, self.device)
        val, ind = concat_conv.sort(descending=largest)
        sort_ind = ind[:int(percnt * val.size(0))]
        conv_onehot = index2onehot(sort_ind, size=ind.shape[0]).to(self.device)


        # linear layer pruning
        val, ind = linear.sort(descending=largest, dim=1)
        linear_onehot = torch.zeros_like(linear).to(self.device)
        for i in range(linear.shape[0]):
            one = torch.zeros(linear.size(1)).to(self.device)
            one[:int(percnt * linear.size(1))] = 1
            linear_onehot.data[i, ind[i]] = one

        initial = 0
        for index, m in enumerate(self.modules):

            if 'Linear' in m._get_name():

                onehot_dict[m] = linear_onehot

                m.weight.data *= (1 - onehot_dict[m][:, :-1]).view(m.weight.shape)
                m.bias.data *= (1 - onehot_dict[m][:, -1]).view(m.bias.shape)

                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)
                m.mask_bias.data = f_inv(1) * torch.ones_like(m.mask_bias)

                initial += size[index].numel()
                break


            if m.bias is not None:

                onehot_dict[m] = conv_onehot[initial:initial + size[index].numel()].view(size[index])

                m.weight.data *= (1 - onehot_dict[m][:, :-1]).view(m.weight.shape)
                m.bias.data *= (1 - onehot_dict[m][:, -1]).view(m.bias.shape)

                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)
                m.mask_bias.data = f_inv(1) * torch.ones_like(m.mask_bias)

                initial += size[index].numel()
            else:
                onehot_dict[m] = conv_onehot[initial:initial+size[index].numel()].view(size[index])

                m.weight.data *= (1 - onehot_dict[m]).view(m.weight.shape)
                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)

                initial += size[index].numel()

        assert initial == conv_onehot.shape.numel() + linear_onehot.shape.numel()

        self._clear_buffer()
        return onehot_dict
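
    # One-hot convention used by the pruning methods above and below: an entry
    # of 1 marks a parameter to remove. The [:, :-1] / [:, -1] split indicates
    # that fetch_mat_weights flattens each layer with the bias appended as the
    # last column of the weight matrix. After pruning, the mask parameters are
    # reset to f_inv(1) * ones, i.e. a pass-through mask with f(mask) = 1.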
    def _global_random_pruning(self, adv_delta_L_avg, percnt):
        print("--------------Random PRUNING--------------")
        self._mask_prepare_model()
        self._rm_mask_hooks()
        onehot_dict = {}

        # convolution layer pruning
        concat_conv, linear, size = list_to_concat(adv_delta_L_avg, self.device)
        sort_ind = torch.randperm(concat_conv.shape[0])[:int(percnt * concat_conv.size(0))]
        conv_onehot = index2onehot(sort_ind, size=concat_conv.shape[0]).to(self.device)

        # linear layer pruning
        linear_onehot = torch.zeros_like(linear).to(self.device)
        for i in range(linear.shape[0]):
            sort_ind = torch.randperm(linear_onehot.shape[1])[:int(percnt * linear_onehot.shape[1])]
            onehot = index2onehot(sort_ind, size=linear_onehot.shape[1]).to(self.device)
            linear_onehot.data[i] = onehot

        initial = 0
        for index, m in enumerate(self.modules):

            if 'Linear' in m._get_name():
                onehot_dict[m] = linear_onehot

                m.weight.data *= (1 - onehot_dict[m][:, :-1]).view(m.weight.shape)
                m.bias.data *= (1 - onehot_dict[m][:, -1]).view(m.bias.shape)

                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)
                m.mask_bias.data = f_inv(1) * torch.ones_like(m.mask_bias)

                initial += size[index].numel()
                break

            if m.bias is not None:

                onehot_dict[m] = conv_onehot[initial:initial + size[index].numel()].view(size[index])

                m.weight.data *= (1 - onehot_dict[m][:, :-1]).view(m.weight.shape)
                m.bias.data *= (1 - onehot_dict[m][:, -1]).view(m.bias.shape)

                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)
                m.mask_bias.data = f_inv(1) * torch.ones_like(m.mask_bias)

                initial += size[index].numel()
            else:
                onehot_dict[m] = conv_onehot[initial:initial + size[index].numel()].view(size[index])

                m.weight.data *= (1 - onehot_dict[m]).view(m.weight.shape)
                m.mask_weight.data = f_inv(1) * torch.ones_like(m.mask_weight)

                initial += size[index].numel()

        assert initial == conv_onehot.shape.numel() + linear_onehot.shape.numel()

        self._clear_buffer()
        return onehot_dict


    def _compute_prune_ratio_per_layer(self):
        self._mask_prepare_model()
        self._rm_mask_hooks()

        layer_prune_ratio = []
        for m in self.modules:

            count = 0
            shape = 0

            if m.bias is not None:
                count += (m.weight == 0).sum().item()
                shape += m.weight.shape.numel()

                count += (m.bias == 0).sum().item()
                shape += m.bias.shape.numel()
            else:
                count += (m.weight == 0).sum().item()
                shape += m.weight.shape.numel()

            layer_prune_ratio.append(round(int(count)/shape, 3))

        self._clear_buffer()
        return layer_prune_ratio


    def debug(self, inputs, targets, str=''):
        self.model = self.model.eval()
        _, predicted = self.model(inputs).max(1)
        total = targets.size(0)
        correct = predicted.eq(targets).sum().item()
        print("[{}] Acc is {}".format(str, correct / total))
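
    # _compute_delta_L below is the Block-wise K-FAC core excerpted in
    # README.md. Only the diagonal entries of m_gg are used, so the quadratic
    # form 1/2 * Δw^T H Δw is assembled row by row from a single vector-matrix
    # product with m_aa, instead of materializing the full Kronecker product
    # of m_aa and m_gg.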
    def _compute_delta_L(self):

        delta_L_list = []
        mask_list = []
        for idx, m in enumerate(self.modules):

            m_aa, m_gg = self.m_aa[m], self.m_gg[m]

            w = fetch_mat_weights(m)
            mask = fetch_mat_mask_weights(m)
            w_mask = w - operator(w, mask)

            double_grad_L = torch.empty_like(w_mask)

            # 1/2 * Δw^T * H * Δw
            for i in range(m_gg.shape[0]):
                block1 = 0.5 * m_gg[i, i] * w_mask.t()[:, i].view(-1, 1)
                block2 = w_mask[i].view(1, -1) @ m_aa
                block = block1 @ block2
                double_grad_L[i, :] = block.diag()

            delta_L = double_grad_L
            delta_L_list.append(delta_L.detach())
            mask_list.append(f(mask).detach())

        return delta_L_list, mask_list


    @staticmethod
    def loss_function(model,
                      x_natural,
                      y,
                      optim,
                      device,
                      step_size,
                      epsilon,
                      perturb_steps,
                      beta,
                      distance='l_inf'):
        kl = torch.nn.KLDivLoss(reduction='none')
        model.eval()
        batch_size = len(x_natural)
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach()
        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    loss_ce = F.cross_entropy(model(x_adv), y)
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        model.train()

        x_adv = torch.autograd.Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
        # zero gradient
        optim.zero_grad()

        logits = model(x_natural)

        logits_adv = model(x_adv)

        adv_probs = F.softmax(logits_adv, dim=1)

        tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:]

        new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1])

        loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y)

        nat_probs = F.softmax(logits, dim=1)

        true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze()

        loss_robust = (1.0 / batch_size) * torch.sum(
            torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs))
        loss = loss_adv + float(beta) * loss_robust

        return loss, logits_adv

--------------------------------------------------------------------------------