├── images ├── motivation.jpg └── pipeline.jpg ├── requirements.txt ├── .gitignore ├── dataset ├── svhn.py └── cifar10.py ├── LICENSE ├── loss ├── mart.py ├── part_trades.py ├── part_mart.py └── trades.py ├── models ├── resnet.py └── wideresnet.py ├── README.md ├── craft_ae.py ├── train_eval_part_m.py ├── train_eval_part_t.py ├── utils.py └── train_eval_part.py /images/motivation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmlr-group/PART/HEAD/images/motivation.jpg -------------------------------------------------------------------------------- /images/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmlr-group/PART/HEAD/images/pipeline.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.7.2 2 | torch==1.13.0 3 | torchattacks==3.4.0 4 | torchcam==0.3.2 5 | torchvision==0.14.0 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /checkpoint/ 3 | /output/ 4 | /__pycache__/ 5 | /dataset/__pycache__/ 6 | /models/__pycache__/ 7 | /loss/__pycache__/ 8 | part.slurm 9 | part_t.slurm 10 | part_m.slurm -------------------------------------------------------------------------------- /dataset/svhn.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import SVHN as DATA 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | 5 | class SVHN(): 6 | def __init__(self, train_batch_size: int = 512, test_batch_size: int = 100): 7 | self.train_batch_size = train_batch_size 8 | self.test_batch_size = test_batch_size 9 | 10 | def transform_train(self): 11 | return transforms.Compose([ 12 | transforms.ToTensor(), 13 | ]) 14 | 15 | def transform_test(self): 16 | return transforms.Compose([ 17 | transforms.ToTensor(), 18 | ]) 19 | 20 | def train_data(self): 21 | train_dataset = DATA('./data/', split='train', download=True, transform=self.transform_train()) 22 | return DataLoader(train_dataset, batch_size=self.train_batch_size, shuffle=True) 23 | 24 | def test_data(self): 25 | test_dataset = DATA('./data/', split='test', download=True, transform=self.transform_test()) 26 | return DataLoader(test_dataset, batch_size=self.test_batch_size, shuffle=False) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jiacheng Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 as DATA 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | 5 | class CIFAR10(): 6 | def __init__(self, train_batch_size: int = 512, test_batch_size: int = 100, path: str = './data/'): 7 | self.train_batch_size = train_batch_size 8 | self.test_batch_size = test_batch_size 9 | self.path = path 10 | 11 | def transform_train(self): 12 | return transforms.Compose([ 13 | transforms.RandomCrop(32, padding=4), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.ToTensor(), 16 | ]) 17 | 18 | def transform_test(self): 19 | return transforms.Compose([ 20 | transforms.ToTensor(), 21 | ]) 22 | 23 | def train_data(self): 24 | train_dataset = DATA(self.path, train=True, download=True, transform=self.transform_train()) 25 | return DataLoader(train_dataset, batch_size=self.train_batch_size, shuffle=True) 26 | 27 | def test_data(self): 28 | test_dataset = DATA(self.path, train=False, download=True, transform=self.transform_test()) 29 | return DataLoader(test_dataset, batch_size=self.test_batch_size, shuffle=False) 30 | -------------------------------------------------------------------------------- /loss/mart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | def mart_loss(model, 7 | x_natural, 8 | y, 9 | optimizer, 10 | step_size=2/255, 11 | epsilon=8/255, 12 | perturb_steps=10, 13 | beta=6.0, 14 | distance='l_inf'): 15 | kl = nn.KLDivLoss(reduction='none') 16 | model.eval() 17 | batch_size = len(x_natural) 18 | # generate adversarial example 19 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 20 | if distance == 'l_inf': 21 | for _ in range(perturb_steps): 22 | x_adv.requires_grad_() 23 | with torch.enable_grad(): 24 | loss_ce = F.cross_entropy(model(x_adv), y) 25 | grad = torch.autograd.grad(loss_ce, [x_adv])[0] 26 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 27 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 28 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 29 | else: 30 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 31 | model.train() 32 | 33 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 34 | # zero gradient 35 | optimizer.zero_grad() 36 | 37 | logits = model(x_natural) 38 | 39 | logits_adv = model(x_adv) 40 | 41 | adv_probs = F.softmax(logits_adv, dim=1) 42 | 43 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:] 44 | 45 | new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1]) 46 | 47 | loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y) 48 | 49 | nat_probs = F.softmax(logits, dim=1) 50 | 51 | true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze() 52 | 53 | loss_robust = (1.0 / batch_size) * torch.sum( 54 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 55 | loss = loss_adv + float(beta) * loss_robust 56 | 57 | return loss -------------------------------------------------------------------------------- /loss/part_trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from craft_ae import element_wise_clamp 6 | 7 | def squared_l2_norm(x): 8 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 9 | return (flattened ** 2).sum(1) 10 | 11 | def l2_norm(x): 12 | return squared_l2_norm(x).sqrt() 13 | 14 | def part_trades_loss(model, 15 | x_natural, 16 | y, 17 | optimizer, 18 | weighted_eps, 19 | step_size=2/255, 20 | perturb_steps=10, 21 | beta=1.0, 22 | distance='l_inf'): 23 | # define KL-loss 24 | criterion_kl = nn.KLDivLoss(reduction='sum') 25 | model.eval() 26 | batch_size = len(x_natural) 27 | # generate adversarial example 28 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 29 | if distance == 'l_inf': 30 | for _ in range(perturb_steps): 31 | x_adv.requires_grad_() 32 | with torch.enable_grad(): 33 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), 34 | F.softmax(model(x_natural), dim=1)) 35 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 36 | eta = step_size * torch.sign(grad.detach()) 37 | x_adv = Variable(x_adv.data + eta, requires_grad=True) 38 | eta = element_wise_clamp(x_adv.data - x_natural.data, weighted_eps) 39 | x_adv = Variable(x_natural.data + eta, requires_grad=True) 40 | x_adv = Variable(torch.clamp(x_adv, 0, 1.0), requires_grad=True) 41 | else: 42 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 43 | model.train() 44 | 45 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 46 | # zero gradient 47 | optimizer.zero_grad() 48 | # calculate robust loss 49 | logits = model(x_natural) 50 | loss_natural = F.cross_entropy(logits, y) 51 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1), 52 | F.softmax(model(x_natural), dim=1)) 53 | loss = loss_natural + beta * loss_robust 54 | return loss 55 | -------------------------------------------------------------------------------- /loss/part_mart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from craft_ae import element_wise_clamp 6 | 7 | def part_mart_loss(model, 8 | x_natural, 9 | y, 10 | optimizer, 11 | weighted_eps, 12 | step_size=2/255, 13 | perturb_steps=10, 14 | beta=6.0, 15 | distance='l_inf'): 16 | kl = nn.KLDivLoss(reduction='none') 17 | model.eval() 18 | batch_size = len(x_natural) 19 | # generate adversarial example 20 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 21 | if distance == 'l_inf': 22 | for _ in range(perturb_steps): 23 | x_adv.requires_grad_() 24 | with torch.enable_grad(): 25 | loss_ce = F.cross_entropy(model(x_adv), y) 26 | grad = torch.autograd.grad(loss_ce, [x_adv])[0] 27 | eta = step_size * torch.sign(grad.detach()) 28 | # print(eta) 29 | x_adv = Variable(x_adv.data + eta, requires_grad=True) 30 | eta = element_wise_clamp(x_adv.data - x_natural.data, weighted_eps) 31 | x_adv = Variable(x_natural.data + eta, requires_grad=True) 32 | x_adv = Variable(torch.clamp(x_adv, 0, 1.0), requires_grad=True) 33 | else: 34 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 35 | model.train() 36 | 37 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 38 | # zero gradient 39 | optimizer.zero_grad() 40 | 41 | logits = model(x_natural) 42 | 43 | logits_adv = model(x_adv) 44 | 45 | adv_probs = F.softmax(logits_adv, dim=1) 46 | 47 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:] 48 | 49 | new_y = torch.where(tmp1[:, -1] == y, tmp1[:, -2], tmp1[:, -1]) 50 | 51 | loss_adv = F.cross_entropy(logits_adv, y) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y) 52 | 53 | nat_probs = F.softmax(logits, dim=1) 54 | 55 | true_probs = torch.gather(nat_probs, 1, (y.unsqueeze(1)).long()).squeeze() 56 | 57 | loss_robust = (1.0 / batch_size) * torch.sum( 58 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 59 | loss = loss_adv + float(beta) * loss_robust 60 | 61 | return loss 62 | -------------------------------------------------------------------------------- /loss/trades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import torch.optim as optim 6 | 7 | def squared_l2_norm(x): 8 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 9 | return (flattened ** 2).sum(1) 10 | 11 | def l2_norm(x): 12 | return squared_l2_norm(x).sqrt() 13 | 14 | def trades_loss(model, 15 | x_natural, 16 | y, 17 | optimizer, 18 | step_size=2/255, 19 | epsilon=8/255, 20 | perturb_steps=10, 21 | beta=1.0, 22 | distance='l_inf'): 23 | # define KL-loss 24 | criterion_kl = nn.KLDivLoss(reduction='sum') 25 | model.eval() 26 | batch_size = len(x_natural) 27 | # generate adversarial example 28 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 29 | if distance == 'l_inf': 30 | for _ in range(perturb_steps): 31 | x_adv.requires_grad_() 32 | with torch.enable_grad(): 33 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), 34 | F.softmax(model(x_natural), dim=1)) 35 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 36 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 37 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 38 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 39 | elif distance == 'l_2': 40 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 41 | delta = Variable(delta.data, requires_grad=True) 42 | 43 | # Setup optimizers 44 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 45 | 46 | for _ in range(perturb_steps): 47 | adv = x_natural + delta 48 | 49 | # optimize 50 | optimizer_delta.zero_grad() 51 | with torch.enable_grad(): 52 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), 53 | F.softmax(model(x_natural), dim=1)) 54 | loss.backward() 55 | # renorming gradient 56 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 57 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 58 | # avoid nan or inf if gradient is 0 59 | if (grad_norms == 0).any(): 60 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 61 | optimizer_delta.step() 62 | 63 | # projection 64 | delta.data.add_(x_natural) 65 | delta.data.clamp_(0, 1).sub_(x_natural) 66 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 67 | x_adv = Variable(x_natural + delta, requires_grad=False) 68 | else: 69 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 70 | model.train() 71 | 72 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 73 | # zero gradient 74 | optimizer.zero_grad() 75 | # calculate robust loss 76 | logits = model(x_natural) 77 | loss_natural = F.cross_entropy(logits, y) 78 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1), 79 | F.softmax(model(x_natural), dim=1)) 80 | loss = loss_natural + beta * loss_robust 81 | return loss -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | 16 | self.shortcut = nn.Sequential() 17 | if stride != 1 or in_planes != self.expansion * planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 20 | nn.BatchNorm2d(self.expansion * planes) 21 | ) 22 | 23 | def forward(self, x): 24 | out = F.relu(self.bn1(self.conv1(x))) 25 | out = self.bn2(self.conv2(out)) 26 | out += self.shortcut(x) 27 | out = F.relu(out) 28 | return out 29 | 30 | 31 | class Bottleneck(nn.Module): 32 | expansion = 4 33 | 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(Bottleneck, self).__init__() 36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 42 | 43 | self.shortcut = nn.Sequential() 44 | if stride != 1 or in_planes != self.expansion * planes: 45 | self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 47 | nn.BatchNorm2d(self.expansion * planes) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(self.conv1(x))) 52 | out = F.relu(self.bn2(self.conv2(out))) 53 | out = self.bn3(self.conv3(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(ResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 70 | self.linear = nn.Linear(512 * block.expansion, num_classes) 71 | 72 | def _make_layer(self, block, planes, num_blocks, stride): 73 | strides = [stride] + [1] * (num_blocks - 1) 74 | layers = [] 75 | for stride in strides: 76 | layers.append(block(self.in_planes, planes, stride)) 77 | self.in_planes = planes * block.expansion 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = self.layer4(out) 86 | out = F.avg_pool2d(out, 4) 87 | out = out.view(out.size(0), -1) 88 | out = self.linear(out) 89 | return out 90 | def ResNet18(num_classes=10): 91 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 92 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 52 | super(WideResNet, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 4) % 6 == 0) 55 | n = (depth - 4) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. / n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(out.size(0), self.nChannels, -1).mean(-1) 92 | return self.fc(out) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Accuracy-robustness Trade-off via Pixel Reweighted Adversarial Training (ICML 2024) 2 | 3 | [![Static Badge](https://img.shields.io/badge/Pub-ICML'24-blue)](https://icml.cc/virtual/2024/poster/34324) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2406.00685-b31b1b.svg)](https://arxiv.org/abs/2406.00685) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | This repository is the official PyTorch implementation of the **ICML'24** paper: 8 | 9 | [Improving Accuracy-robustness Trade-off via Pixel Reweighted Adversarial Training](https://arxiv.org/abs/2406.00685) 10 | 11 | **Author List**: Jiacheng Zhang, Feng Liu, Dawei Zhou, Jingfeng Zhang, Tongliang Liu. 12 | 13 | #### Abstract 14 | Adversarial training (AT) trains models using adversarial examples (AEs), which are natural images modified with specific perturbations to mislead the model. 15 | These perturbations are constrained by a predefined perturbation budget $\epsilon$ and are equally applied to each pixel within an image. 16 | However, in this paper, we discover that not all pixels contribute equally to the accuracy on AEs (i.e., robustness) and accuracy on natural images (i.e., accuracy). 17 | Motivated by this finding, we propose Pixel-reweighted AdveRsarial Training (PART), a new framework that partially reduces $\epsilon$ for less influential pixels, guiding the model to focus more on key regions that affect its outputs. 18 | Specifically, we first use class activation mapping (CAM) methods to identify important pixel regions, then we keep the perturbation budget for these regions while lowering it for the remaining regions when generating AEs. 19 | In the end, we use these pixel-reweighted AEs to train a model. 20 | PART achieves a notable improvement in accuracy without compromising robustness on CIFAR-10, SVHN and TinyImagenet-200, justifying the necessity to allocate distinct weights to different pixel regions in robust classification. 21 | 22 | #### Figure 1: The proof-of-concept experiment. 23 | ![motivation](https://github.com/JiachengZ01/PART/blob/main/images/motivation.jpg) 24 | We find that fundamental discrepancies exist among different pixel regions. Specifically, we segment each image into four equal-sized regions (i.e., ul, short for upper left; ur, short for upper right; br, short for bottom right; bl, short for bottom left) and adversarially train two ResNet-18 on CIFAR-10 using standard AT with the same experiment settings except for the allocation of $\epsilon$. The robustness is evaluated by $\ell_{\infty}$-norm PGD-20. With the same overall perturbation budgets (i.e., allocate one of the regions to $6/255$ and others to $12/255$), we find that both natural accuracy and adversarial robustness change significantly if the regional allocation on $\epsilon$ is different. For example, by changing $\epsilon_{\rm{br}} = 6/255$ to $\epsilon_{\rm{ul}} = 6/255$, accuracy gains a 1.23\% improvement and robustness gains a 0.94\% improvement. 25 | 26 | #### Figure 2: The illustration of our method. 27 | ![pipeline](https://github.com/JiachengZ01/PART/blob/main/images/pipeline.jpg) 28 | Compared to AT, PART leverages the power of CAM methods to identify important pixel regions. Based on the class activation map, we element-wisely multiply a mask to the perturbation to keep the perturbation budget $\epsilon$ for important pixel regions while shrinking it to $\epsilon^{\rm low}$ for their counterparts during the generation process of AEs. 29 | 30 | ### Requirement 31 | - This codebase is written for ```python3``` and ```pytorch```. 32 | - To install necessay python packages, run ```pip install -r requirements.txt```. 33 | 34 | ### Data 35 | - Please download and place the dataset into the 'data' directory. 36 | 37 | ### Run Experiments 38 | #### Train and Evaluate PART 39 | ``` 40 | python3 train_eval_part.py 41 | ``` 42 | #### Train and Evaluate PART-T 43 | ``` 44 | python3 train_eval_part_t.py 45 | ``` 46 | 47 | #### Train and Evaluate PART-M 48 | ``` 49 | python3 train_eval_part_m.py 50 | ``` 51 | 52 | ### License and Contributing 53 | - This README is formatted based on [the NeurIPS guideline](https://github.com/paperswithcode/releasing-research-code). 54 | - Feel free to post any issues via Github. 55 | 56 | ### Citation 57 | ```bibtex 58 | @inproceedings{zhang2024improving, 59 | title={Improving Accuracy-robustness Trade-off via Pixel Reweighted Adversarial Training}, 60 | author={Jiacheng Zhang and Feng Liu and Dawei Zhou and Jingfeng Zhang and Tongliang Liu}, 61 | booktitle={International Conference on Machine Learning (ICML)}, 62 | year={2024} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /craft_ae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torchattacks import PGD, AutoAttack 7 | 8 | def element_wise_clamp(eta, epsilon): 9 | # Element-wise clamp using the epsilon tensor 10 | eta_clamped = torch.where(eta > epsilon, epsilon, eta) 11 | eta_clamped = torch.where(eta < -epsilon, -epsilon, eta_clamped) 12 | return eta_clamped 13 | 14 | def craft_adversarial_example(model, 15 | x_natural, 16 | y, 17 | step_size=2/255, 18 | epsilon=8/255, 19 | perturb_steps=10, 20 | num_classes=10, 21 | mode='pgd'): 22 | if mode == 'pgd': 23 | attack = PGD(model, 24 | eps=epsilon, 25 | alpha=step_size, 26 | steps=perturb_steps, 27 | random_start=True) 28 | elif mode == 'aa': 29 | attack = AutoAttack(model, 30 | norm='Linf', 31 | eps=epsilon, 32 | version='standard') 33 | if mode == 'mma': 34 | x_adv = mma(model, 35 | data=x_natural, 36 | target=y, 37 | epsilon=epsilon, 38 | step_size=step_size, 39 | num_steps=perturb_steps, 40 | category='Madry', 41 | rand_init=True, 42 | k=3, 43 | num_classes=num_classes) 44 | else: 45 | x_adv = attack(x_natural, y) 46 | return x_adv 47 | 48 | def part_pgd(model, 49 | X, 50 | y, 51 | weighted_eps, 52 | epsilon=8/255, 53 | num_steps=10, 54 | step_size=2/255): 55 | X_pgd = Variable(X.data, requires_grad=True) 56 | 57 | random_noise = torch.FloatTensor(*X_pgd.shape).uniform_(-epsilon, epsilon).cuda() 58 | X_pgd = Variable(X_pgd.data + random_noise, requires_grad=True) 59 | 60 | for _ in range(num_steps): 61 | opt = torch.optim.SGD([X_pgd], lr=1e-3) 62 | opt.zero_grad() 63 | 64 | with torch.enable_grad(): 65 | loss = nn.CrossEntropyLoss()(model(X_pgd), y) 66 | loss.backward() 67 | eta = step_size * X_pgd.grad.data.sign() 68 | X_pgd = Variable(X_pgd.data + eta, requires_grad=True) 69 | eta = element_wise_clamp(X_pgd.data - X.data, weighted_eps) 70 | X_pgd = Variable(X.data + eta, requires_grad=True) 71 | X_pgd = Variable(torch.clamp(X_pgd, 0, 1.0), requires_grad=True) 72 | return X_pgd 73 | 74 | def part_mma(model, 75 | data, 76 | target, 77 | weighted_eps, 78 | epsilon, 79 | step_size, 80 | num_steps, 81 | rand_init, 82 | k, 83 | num_classes): 84 | model.eval() 85 | x_adv = data.detach() + torch.from_numpy( 86 | np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach() 87 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 88 | 89 | logits = model(data) 90 | target_onehot = torch.zeros(target.size() + (len(logits[0]),)) 91 | target_onehot = target_onehot.cuda() 92 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 93 | target_var = Variable(target_onehot, requires_grad=False) 94 | index = torch.argsort(logits - 10000 * target_var)[:, num_classes - k:] 95 | x_adv_set = [] 96 | loss_set = [] 97 | 98 | for i in range(k): 99 | x_adv_0 = x_adv.clone().detach() 100 | for j in range(num_steps): 101 | x_adv_0.requires_grad_() 102 | output1 = model(x_adv_0) 103 | model.zero_grad() 104 | with torch.enable_grad(): 105 | loss_adv0 = mm_loss(output1, target, index[:, i], num_classes=num_classes) 106 | loss_adv0.backward() 107 | eta = step_size * x_adv_0.grad.sign() 108 | x_adv_0 = x_adv_0.detach() + eta 109 | eta = element_wise_clamp(x_adv_0 - data, weighted_eps) 110 | x_adv_0 = data + eta 111 | x_adv_0 = torch.clamp(x_adv_0, 0.0, 1.0) 112 | 113 | pipy = mm_loss_train(model(x_adv_0), target, index[:, i], num_classes=num_classes) 114 | loss_set.append(pipy.view(len(pipy), -1)) 115 | x_adv_set.append(x_adv_0) 116 | 117 | loss_pipy = loss_set[0] 118 | for i in range(k - 1): 119 | loss_pipy = torch.cat((loss_pipy, loss_set[i + 1]), 1) 120 | 121 | index_choose = torch.argsort(loss_pipy)[:, -1] 122 | 123 | adv_final = torch.zeros(x_adv.size()).cuda() 124 | for i in range(len(index_choose)): 125 | adv_final[i, :, :, :] = x_adv_set[index_choose[i]][i] 126 | 127 | return adv_final 128 | 129 | def mma(model, 130 | data, 131 | target, 132 | epsilon, 133 | step_size, 134 | num_steps, 135 | category, 136 | rand_init, 137 | k, 138 | num_classes): 139 | model.eval() 140 | if category == "trades": 141 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() if rand_init else data.detach() 142 | if category == "Madry": 143 | x_adv = data.detach() + torch.from_numpy( 144 | np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach() 145 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 146 | 147 | logits = model(data) 148 | target_onehot = torch.zeros(target.size() + (len(logits[0]),)) 149 | target_onehot = target_onehot.cuda() 150 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 151 | target_var = Variable(target_onehot, requires_grad=False) 152 | index = torch.argsort(logits - 10000 * target_var)[:, num_classes - k:] 153 | 154 | x_adv_set = [] 155 | loss_set = [] 156 | for i in range(k): 157 | x_adv_0 = x_adv.clone().detach() 158 | for j in range(num_steps): 159 | x_adv_0.requires_grad_() 160 | output1 = model(x_adv_0) 161 | model.zero_grad() 162 | with torch.enable_grad(): 163 | loss_adv0 = mm_loss(output1, target, index[:, i], num_classes=num_classes) 164 | loss_adv0.backward() 165 | eta = step_size * x_adv_0.grad.sign() 166 | x_adv_0 = x_adv_0.detach() + eta 167 | x_adv_0 = torch.min(torch.max(x_adv_0, data - epsilon), data + epsilon) 168 | x_adv_0 = torch.clamp(x_adv_0, 0.0, 1.0) 169 | 170 | pipy = mm_loss_train(model(x_adv_0), target, index[:, i], num_classes=num_classes) 171 | loss_set.append(pipy.view(len(pipy), -1)) 172 | x_adv_set.append(x_adv_0) 173 | 174 | loss_pipy = loss_set[0] 175 | for i in range(k - 1): 176 | loss_pipy = torch.cat((loss_pipy, loss_set[i + 1]), 1) 177 | 178 | index_choose = torch.argsort(loss_pipy)[:, -1] 179 | 180 | adv_final = torch.zeros(x_adv.size()).cuda() 181 | for i in range(len(index_choose)): 182 | adv_final[i, :, :, :] = x_adv_set[index_choose[i]][i] 183 | 184 | return adv_final 185 | 186 | # loss for MM AT 187 | def mm_loss_train(output, target, target_choose, num_classes=10): 188 | target = target.data 189 | target_onehot = torch.zeros(target.size() + (num_classes,)) 190 | target_onehot = target_onehot.cuda() 191 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 192 | target_var = Variable(target_onehot, requires_grad=False) 193 | real = (target_var * output).sum(1) 194 | 195 | target_onehot = torch.zeros(target_choose.size() + (num_classes,)) 196 | target_onehot = target_onehot.cuda() 197 | target_onehot.scatter_(1, target_choose.unsqueeze(1), 1.) 198 | target_var = Variable(target_onehot, requires_grad=False) 199 | 200 | other = (target_var * output).sum(1) 201 | return other-real 202 | 203 | # loss for MM Attack 204 | def mm_loss(output, target, target_choose, confidence=50, num_classes=10): 205 | target = target.data 206 | target_onehot = torch.zeros(target.size() + (num_classes,)) 207 | target_onehot = target_onehot.cuda() 208 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 209 | target_var = Variable(target_onehot, requires_grad=False) 210 | real = (target_var * output).sum(1) 211 | 212 | target_onehot = torch.zeros(target_choose.size() + (num_classes,)) 213 | target_onehot = target_onehot.cuda() 214 | target_onehot.scatter_(1, target_choose.unsqueeze(1), 1.) 215 | target_var = Variable(target_onehot, requires_grad=False) 216 | 217 | other = (target_var * output).sum(1) 218 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.) 219 | loss = torch.sum(loss) 220 | return loss -------------------------------------------------------------------------------- /train_eval_part_m.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | from torch.autograd import Variable 8 | 9 | from models.resnet import ResNet18 10 | from models.wideresnet import WideResNet 11 | 12 | from dataset.cifar10 import CIFAR10 13 | from dataset.svhn import SVHN 14 | 15 | from utils import * 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Pixel-reweighted Adversarial Training') 18 | 19 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 20 | help='input batch size for training (default: 128)') 21 | parser.add_argument('--epochs', type=int, default=80, metavar='N', 22 | help='number of epochs to train') 23 | parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float, metavar='W') 24 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 25 | help='learning rate') 26 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 27 | help='SGD momentum') 28 | parser.add_argument('--no-cuda', action='store_true', default=False, 29 | help='disables CUDA training') 30 | 31 | parser.add_argument('--epsilon', default=8/255, 32 | help='maximum allowed perturbation', type=parse_fraction) 33 | parser.add_argument('--low-epsilon', default=7/255, 34 | help='maximum allowed perturbation for unimportant pixels', 35 | type=parse_fraction) 36 | parser.add_argument('--num-steps', default=10, 37 | help='perturb number of steps') 38 | parser.add_argument('--num-class', default=10, 39 | help='number of classes') 40 | parser.add_argument('--step-size', default=2/255, 41 | help='perturb step size', type=parse_fraction) 42 | parser.add_argument('--beta', default=6.0, 43 | help='regularization, i.e., 1/lambda in TRADES') 44 | parser.add_argument('--adjust-first', type=int, default=60, 45 | help='adjust learning rate on which epoch in the first round') 46 | parser.add_argument('--adjust-second', type=int, default=90, 47 | help='adjust learning rate on which epoch in the second round') 48 | parser.add_argument('--rand_init', type=bool, default=True, 49 | help="whether to initialize adversarial sample with random noise") 50 | 51 | parser.add_argument('--seed', type=int, default=1, metavar='S', 52 | help='random seed (default: 1)') 53 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 54 | help='how many batches to wait before logging training status') 55 | parser.add_argument('--model-dir', default='./checkpoint/ResNet_18/PART_M', 56 | help='directory of model for saving checkpoint') 57 | parser.add_argument('--save-freq', default=10, type=int, metavar='N', 58 | help='save frequency') 59 | parser.add_argument('--save-weights', default=1, type=int, metavar='N', 60 | help='save frequency for weighted matrix') 61 | 62 | parser.add_argument('--data', type=str, default='CIFAR10', help='data source', choices=['CIFAR10', 'SVHN', 'TinyImagenet']) 63 | parser.add_argument('--model', type=str, default='resnet', choices=['resnet', 'wideresnet']) 64 | parser.add_argument('--warm-up', type=int, default=20, help='warm up epochs') 65 | parser.add_argument('--cam', type=str, default='gradcam', choices=['gradcam', 'xgradcam', 'layercam']) 66 | parser.add_argument('--attack', type=str, default='pgd', choices=['pgd', 'mma']) 67 | 68 | args = parser.parse_args() 69 | 70 | if args.data == 'CIFAR100': 71 | args.num_class = 100 72 | if args.data == 'TinyImagenet': 73 | args.num_class = 200 74 | 75 | def train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list): 76 | model.train() 77 | for batch_idx, (data, target) in enumerate(train_loader): 78 | data, target = data.to(device), target.to(device) 79 | 80 | X, y = Variable(data, requires_grad=True), Variable(target) 81 | 82 | model.eval() 83 | weighted_eps = weighted_eps_list[batch_idx] 84 | 85 | optimizer.zero_grad() 86 | 87 | # calculate robust loss 88 | loss = part_mart_loss(model=model, 89 | x_natural=X, 90 | y=y, 91 | optimizer=optimizer, 92 | weighted_eps= weighted_eps, 93 | step_size=args.step_size, 94 | perturb_steps=args.num_steps, 95 | beta=args.beta) 96 | loss.backward() 97 | optimizer.step() 98 | 99 | # print progress 100 | if batch_idx % args.log_interval == 0: 101 | print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format( 102 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 103 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 104 | 105 | def main(): 106 | # settings 107 | setup_seed(args.seed) 108 | use_cuda = not args.no_cuda and torch.cuda.is_available() 109 | torch.manual_seed(args.seed) 110 | device = torch.device("cuda" if use_cuda else "cpu") 111 | 112 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 113 | 114 | # setup data loader 115 | if args.data == 'CIFAR10': 116 | train_loader = CIFAR10(train_batch_size=args.batch_size).train_data() 117 | test_loader = CIFAR10(test_batch_size=args.batch_size).test_data() 118 | if args.model == 'resnet': 119 | model_dir = './checkpoint/CIFAR10/ResNet_18/PART_M' 120 | model = ResNet18(num_classes=10).to(device) 121 | elif args.model == 'wideresnet': 122 | model_dir = './checkpoint/CIFAR10/WideResnet-34/PART_M' 123 | model = WideResNet(34, 10, 10).to(device) 124 | else: 125 | raise ValueError("Unknown model") 126 | elif args.data == 'SVHN': 127 | args.step_size = 1/255 128 | args.weight_decay = 0.0035 129 | args.lr = 0.01 130 | args.batch_size = 128 131 | train_loader = SVHN(train_batch_size=args.batch_size).train_data() 132 | test_loader = SVHN(test_batch_size=args.batch_size).test_data() 133 | if args.model == 'resnet': 134 | model_dir = './checkpoint/SVHN/ResNet_18/PART_M' 135 | model = ResNet18(num_classes=10).to(device) 136 | elif args.model == 'wideresnet': 137 | model_dir = './checkpoint/SVHN/WideResnet-34/PART_M' 138 | model = WideResNet(34, 10, 10).to(device) 139 | else: 140 | raise ValueError("Unknown model") 141 | else: 142 | raise ValueError("Unknown data") 143 | 144 | if not os.path.exists(model_dir): 145 | os.makedirs(model_dir) 146 | 147 | model = torch.nn.DataParallel(model) 148 | cudnn.benchmark = True 149 | optimizer = optim.SGD(model.parameters(), 150 | lr=args.lr, 151 | momentum=args.momentum, 152 | weight_decay=args.weight_decay) 153 | 154 | # warm up 155 | print('warm up starts') 156 | for epoch in range(1, args.warm_up + 1): 157 | mart_train(args, model, device, train_loader, optimizer, epoch) 158 | 159 | # save checkpoint 160 | if epoch % args.save_freq == 0: 161 | torch.save(model.state_dict(), 162 | os.path.join(model_dir, 'pre_part_m_epoch{}.pth'.format(epoch))) 163 | print('save the model') 164 | print('================================================================') 165 | print('warm up ends') 166 | 167 | weighted_eps_list = save_cam(model, train_loader, device, args) 168 | 169 | for epoch in range(1, args.epochs - args.warm_up + 1): 170 | if epoch % args.save_weights == 0 and epoch != 1: 171 | weighted_eps_list = save_cam(model, train_loader, device, args) 172 | 173 | # adjust learning rate for SGD 174 | adjust_learning_rate(args, optimizer, epoch) 175 | 176 | # adversarial training 177 | train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list) 178 | 179 | # evaluation on natural examples 180 | print('================================================================') 181 | 182 | # save checkpoint 183 | if epoch % args.save_freq == 0: 184 | torch.save(model.state_dict(), 185 | os.path.join(model_dir, 'model-epoch{}.pth'.format(epoch))) 186 | 187 | # evaluation on adversarial examples 188 | print('PGD=============================================================') 189 | eval_test(args, model, device, test_loader, mode='pgd') 190 | print('MMA==============================================================') 191 | eval_test(args, model, device, test_loader, mode='mma') 192 | print('AA==============================================================') 193 | eval_test(args, model, device, test_loader, mode='aa') 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /train_eval_part_t.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | from torch.autograd import Variable 8 | 9 | from models.resnet import ResNet18 10 | from models.wideresnet import WideResNet 11 | 12 | from dataset.cifar10 import CIFAR10 13 | from dataset.svhn import SVHN 14 | 15 | from utils import * 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Pixel-reweighted Adversarial Training') 18 | 19 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 20 | help='input batch size for training (default: 128)') 21 | parser.add_argument('--epochs', type=int, default=80, metavar='N', 22 | help='number of epochs to train') 23 | parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float, metavar='W') 24 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 25 | help='learning rate') 26 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 27 | help='SGD momentum') 28 | parser.add_argument('--no-cuda', action='store_true', default=False, 29 | help='disables CUDA training') 30 | 31 | parser.add_argument('--epsilon', default=8/255, 32 | help='maximum allowed perturbation', type=parse_fraction) 33 | parser.add_argument('--low-epsilon', default=7/255, 34 | help='maximum allowed perturbation for unimportant pixels', 35 | type=parse_fraction) 36 | parser.add_argument('--num-steps', default=10, 37 | help='perturb number of steps') 38 | parser.add_argument('--num-class', default=10, 39 | help='number of classes') 40 | parser.add_argument('--step-size', default=2/255, 41 | help='perturb step size', type=parse_fraction) 42 | parser.add_argument('--beta', default=6.0, 43 | help='regularization, i.e., 1/lambda in TRADES') 44 | parser.add_argument('--adjust-first', type=int, default=60, 45 | help='adjust learning rate on which epoch in the first round') 46 | parser.add_argument('--adjust-second', type=int, default=90, 47 | help='adjust learning rate on which epoch in the second round') 48 | parser.add_argument('--rand_init', type=bool, default=True, 49 | help="whether to initialize adversarial sample with random noise") 50 | 51 | parser.add_argument('--seed', type=int, default=1, metavar='S', 52 | help='random seed (default: 1)') 53 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 54 | help='how many batches to wait before logging training status') 55 | parser.add_argument('--model-dir', default='./checkpoint/ResNet_18/PART_T', 56 | help='directory of model for saving checkpoint') 57 | parser.add_argument('--save-freq', default=10, type=int, metavar='N', 58 | help='save frequency') 59 | parser.add_argument('--save-weights', default=1, type=int, metavar='N', 60 | help='save frequency for weighted matrix') 61 | 62 | parser.add_argument('--data', type=str, default='CIFAR10', help='data source', choices=['CIFAR10', 'SVHN', 'TinyImagenet']) 63 | parser.add_argument('--model', type=str, default='resnet', choices=['resnet', 'wideresnet']) 64 | parser.add_argument('--warm-up', type=int, default=20, help='warm up epochs') 65 | parser.add_argument('--cam', type=str, default='gradcam', choices=['gradcam', 'xgradcam', 'layercam']) 66 | parser.add_argument('--attack', type=str, default='pgd', choices=['pgd', 'mma']) 67 | 68 | args = parser.parse_args() 69 | 70 | if args.data == 'CIFAR100': 71 | args.num_class = 100 72 | if args.data == 'TinyImagenet': 73 | args.num_class = 200 74 | 75 | def train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list): 76 | 77 | for batch_idx, (data, target) in enumerate(train_loader): 78 | data, target = data.to(device), target.to(device) 79 | 80 | X, y = Variable(data, requires_grad=True), Variable(target) 81 | 82 | model.eval() 83 | weighted_eps = weighted_eps_list[batch_idx] 84 | 85 | model.train() 86 | optimizer.zero_grad() 87 | 88 | # calculate robust loss 89 | loss = part_trades_loss(model=model, 90 | x_natural=X, 91 | y=y, 92 | optimizer=optimizer, 93 | weighted_eps = weighted_eps, 94 | step_size=args.step_size, 95 | perturb_steps=args.num_steps, 96 | beta=args.beta) 97 | loss.backward() 98 | optimizer.step() 99 | 100 | # print progress 101 | if batch_idx % args.log_interval == 0: 102 | print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format( 103 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 104 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 105 | 106 | def main(): 107 | # settings 108 | setup_seed(args.seed) 109 | use_cuda = not args.no_cuda and torch.cuda.is_available() 110 | torch.manual_seed(args.seed) 111 | device = torch.device("cuda" if use_cuda else "cpu") 112 | 113 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 114 | 115 | # setup data loader 116 | if args.data == 'CIFAR10': 117 | train_loader = CIFAR10(train_batch_size=args.batch_size).train_data() 118 | test_loader = CIFAR10(test_batch_size=args.batch_size).test_data() 119 | if args.model == 'resnet': 120 | model_dir = './checkpoint/CIFAR10/ResNet_18/PART_T' 121 | model = ResNet18(num_classes=10).to(device) 122 | elif args.model == 'wideresnet': 123 | model_dir = './checkpoint/CIFAR10/WideResnet-34/PART_T' 124 | model = WideResNet(34, 10, 10).to(device) 125 | else: 126 | raise ValueError("Unknown model") 127 | elif args.data == 'SVHN': 128 | args.step_size = 1/255 129 | args.weight_decay = 0.0035 130 | args.lr = 0.01 131 | args.batch_size = 128 132 | train_loader = SVHN(train_batch_size=args.batch_size).train_data() 133 | test_loader = SVHN(test_batch_size=args.batch_size).test_data() 134 | if args.model == 'resnet': 135 | model_dir = './checkpoint/SVHN/ResNet_18/PART_T' 136 | model = ResNet18(num_classes=10).to(device) 137 | elif args.model == 'wideresnet': 138 | model_dir = './checkpoint/SVHN/WideResnet-34/PART_T' 139 | model = WideResNet(34, 10, 10).to(device) 140 | else: 141 | raise ValueError("Unknown model") 142 | else: 143 | raise ValueError("Unknown data") 144 | 145 | if not os.path.exists(model_dir): 146 | os.makedirs(model_dir) 147 | 148 | model = torch.nn.DataParallel(model) 149 | cudnn.benchmark = True 150 | optimizer = optim.SGD(model.parameters(), 151 | lr=args.lr, 152 | momentum=args.momentum, 153 | weight_decay=args.weight_decay) 154 | # warm up 155 | print('warm up starts') 156 | for epoch in range(1, args.warm_up + 1): 157 | trades_train(args, model, device, train_loader, optimizer, epoch) 158 | 159 | # save checkpoint 160 | if epoch % args.save_freq == 0: 161 | torch.save(model.state_dict(), 162 | os.path.join(model_dir, 'pre_part_t_epoch{}.pth'.format(epoch))) 163 | print('save the model') 164 | print('================================================================') 165 | print('warm up ends') 166 | 167 | weighted_eps_list = save_cam(model, train_loader, device, args) 168 | for epoch in range(1, args.epochs - args.warm_up + 1): 169 | if epoch % args.save_weights == 0 and epoch != 1: 170 | weighted_eps_list = save_cam(model, train_loader, device, args) 171 | 172 | # adjust learning rate for SGD 173 | adjust_learning_rate(args, optimizer, epoch) 174 | 175 | # adversarial training 176 | train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list) 177 | 178 | # evaluation on natural examples 179 | print('================================================================') 180 | 181 | # save checkpoint 182 | if epoch % args.save_freq == 0: 183 | torch.save(model.state_dict(), 184 | os.path.join(model_dir, 'model-epoch{}.pth'.format(epoch))) 185 | 186 | # evaluation on adversarial examples 187 | print('PGD=============================================================') 188 | eval_test(args, model, device, test_loader, mode='pgd') 189 | print('MMA==============================================================') 190 | eval_test(args, model, device, test_loader, mode='mma') 191 | print('AA==============================================================') 192 | eval_test(args, model, device, test_loader, mode='aa') 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchcam.methods import GradCAM, XGradCAM, LayerCAM 3 | from PIL import Image 4 | from torchvision.transforms.functional import to_pil_image 5 | from matplotlib import cm 6 | import numpy as np 7 | import random 8 | from craft_ae import * 9 | from loss.mart import * 10 | from loss.trades import * 11 | from loss.part_mart import * 12 | from loss.part_trades import * 13 | 14 | def parse_fraction(fraction_string): 15 | if '/' in fraction_string: 16 | numerator, denominator = fraction_string.split('/') 17 | return float(numerator) / float(denominator) 18 | return float(fraction_string) 19 | 20 | def setup_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | def adjust_learning_rate(args, optimizer, epoch): 28 | """decrease the learning rate""" 29 | lr = args.lr 30 | if epoch >= args.adjust_second: 31 | lr = args.lr * 0.01 32 | elif epoch >= args.adjust_first: 33 | lr = args.lr * 0.1 34 | 35 | for param_group in optimizer.param_groups: 36 | param_group['lr'] = lr 37 | 38 | def craft_weight_matrix(model, data, device, args, parallel=True): 39 | batch, img_size1, img_size2 = data.shape[0], data.shape[-2], data.shape[-1] 40 | weight_matrix_tensor = torch.empty(batch, 3, img_size1, img_size2).to(device) 41 | 42 | if args.model == 'resnet': 43 | if args.cam == 'gradcam': 44 | cam_extractor = GradCAM(model.module if parallel else model, 'layer4') 45 | if args.cam == 'xgradcam': 46 | cam_extractor = XGradCAM(model.module if parallel else model, 'layer4') 47 | if args.cam == 'layercam': 48 | cam_extractor = LayerCAM(model.module if parallel else model, 'layer4') 49 | elif args.model == 'wideresnet': 50 | if args.cam == 'gradcam': 51 | cam_extractor = GradCAM(model.module if parallel else model, 'block3') 52 | if args.cam == 'xgradcam': 53 | cam_extractor = XGradCAM(model.module if parallel else model, 'block3') 54 | if args.cam == 'layercam': 55 | cam_extractor = LayerCAM(model.module if parallel else model, 'block3') 56 | 57 | for i in range(batch): 58 | output = model(data[i].unsqueeze(0)) 59 | heatmap = cam_extractor(output.argmax().item(), output) 60 | mask = to_pil_image(heatmap[0].squeeze(0).cpu().numpy()) 61 | overlay = mask.resize((img_size1, img_size2), resample=Image.BICUBIC) 62 | cmap_overlay = cm.get_cmap('jet')(np.asarray(overlay) ** 2) 63 | weight_matrix_tensor[i] = process_overlay(cmap_overlay, device) 64 | 65 | cam_extractor.remove_hooks() 66 | 67 | return weight_matrix_tensor 68 | 69 | def process_overlay(overlay, device): 70 | overlay = (255 * overlay[:, :, :3]).astype(np.double) 71 | normalized_overlay = overlay / 255 72 | mean, std = np.mean(normalized_overlay), np.std(normalized_overlay) 73 | weight_matrix = torch.from_numpy((normalized_overlay - mean) / std) 74 | return torch.clamp(weight_matrix, 1, weight_matrix.max()).float().permute(2, 0, 1).to(device) 75 | 76 | def generate_weighted_eps(weight_matrix, args): 77 | epsilon = torch.where(weight_matrix > 1, args.epsilon, args.low_epsilon) 78 | return epsilon 79 | 80 | def standard_train(args, model, device, train_loader, optimizer, epoch): 81 | 82 | for batch_idx, (data, label) in enumerate(train_loader): 83 | data, label = data.to(device), label.to(device) 84 | 85 | # calculate robust loss 86 | model.eval() 87 | if args.attack == 'pgd': 88 | data = craft_adversarial_example(model=model, 89 | x_natural=data, 90 | y=label, 91 | step_size=args.step_size, 92 | epsilon=args.epsilon, 93 | perturb_steps=args.num_steps, 94 | num_classes=args.num_class, 95 | mode='pgd') 96 | elif args.attack == 'mma': 97 | data = craft_adversarial_example(model=model, 98 | x_natural=data, 99 | y=label, 100 | step_size=args.step_size, 101 | epsilon=args.epsilon, 102 | perturb_steps=args.num_steps, 103 | num_classes=args.num_class, 104 | mode='mma') 105 | 106 | model.train() 107 | optimizer.zero_grad() 108 | 109 | logits_out = model(data) 110 | loss = F.cross_entropy(logits_out, label) 111 | 112 | loss.backward() 113 | optimizer.step() 114 | 115 | # print progress 116 | if batch_idx % args.log_interval == 0: 117 | print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format( 118 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 119 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 120 | 121 | def mart_train(args, model, device, train_loader, optimizer, epoch): 122 | model.train() 123 | for batch_idx, (data, target) in enumerate(train_loader): 124 | data, target = data.to(device), target.to(device) 125 | 126 | optimizer.zero_grad() 127 | 128 | # calculate robust loss 129 | loss = mart_loss(model=model, 130 | x_natural=data, 131 | y=target, 132 | optimizer=optimizer, 133 | step_size=args.step_size, 134 | epsilon=args.epsilon, 135 | perturb_steps=args.num_steps, 136 | beta=args.beta) 137 | loss.backward() 138 | optimizer.step() 139 | 140 | # print progress 141 | if batch_idx % args.log_interval == 0: 142 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 143 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 144 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 145 | 146 | def trades_train(args, model, device, train_loader, optimizer, epoch): 147 | model.train() 148 | for batch_idx, (data, target) in enumerate(train_loader): 149 | data, target = data.to(device), target.to(device) 150 | 151 | optimizer.zero_grad() 152 | 153 | # calculate robust loss 154 | loss = trades_loss(model=model, 155 | x_natural=data, 156 | y=target, 157 | optimizer=optimizer, 158 | step_size=args.step_size, 159 | epsilon=args.epsilon, 160 | perturb_steps=args.num_steps, 161 | beta=args.beta) 162 | loss.backward() 163 | optimizer.step() 164 | 165 | # print progress 166 | if batch_idx % args.log_interval == 0: 167 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 168 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 169 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 170 | 171 | def eval_test(args, model, device, test_loader, mode='pgd'): 172 | model.eval() 173 | correct = 0 174 | correct_adv = 0 175 | 176 | for data, label in test_loader: 177 | data, label = data.to(device), label.to(device) 178 | 179 | logits_out = model(data) 180 | pred = logits_out.max(1, keepdim=True)[1] 181 | correct += pred.eq(label.view_as(pred)).sum().item() 182 | 183 | data = craft_adversarial_example(model=model, 184 | x_natural=data, 185 | y=label, 186 | step_size=args.step_size, 187 | epsilon=8/255, 188 | perturb_steps=20, 189 | num_classes=args.num_class, 190 | mode=mode) 191 | 192 | logits_out = model(data) 193 | pred = logits_out.max(1, keepdim=True)[1] 194 | correct_adv += pred.eq(label.view_as(pred)).sum().item() 195 | 196 | print('Test: Accuracy: {}/{} ({:.2f}%), Robust Accuracy: {}/{} ({:.2f}%)'.format( 197 | correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset), correct_adv, 198 | len(test_loader.dataset), 100. * correct_adv / len(test_loader.dataset))) 199 | 200 | def save_cam(model, train_loader, device, args): 201 | weighted_eps_list = [] 202 | for _, (data, label) in enumerate(train_loader): 203 | data, label = data.to(device), label.to(device) 204 | 205 | # calculate robust loss 206 | model.eval() 207 | weight_matrix = craft_weight_matrix(model, data, device, args, parallel=True) 208 | weighted_eps = generate_weighted_eps(weight_matrix, args) 209 | weighted_eps_list.append(weighted_eps) 210 | return weighted_eps_list 211 | -------------------------------------------------------------------------------- /train_eval_part.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torch.backends.cudnn as cudnn 8 | from torch.autograd import Variable 9 | 10 | from models.resnet import ResNet18 11 | from models.wideresnet import WideResNet 12 | 13 | from dataset.cifar10 import CIFAR10 14 | from dataset.svhn import SVHN 15 | 16 | from utils import * 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch Pixel-reweighted Adversarial Training') 19 | 20 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 21 | help='input batch size for training (default: 128)') 22 | parser.add_argument('--epochs', type=int, default=80, metavar='N', 23 | help='number of epochs to train') 24 | parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float, metavar='W') 25 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 26 | help='learning rate') 27 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 28 | help='SGD momentum') 29 | parser.add_argument('--no-cuda', action='store_true', default=False, 30 | help='disables CUDA training') 31 | 32 | parser.add_argument('--epsilon', default=8/255, 33 | help='maximum allowed perturbation', type=parse_fraction) 34 | parser.add_argument('--low-epsilon', default=7/255, 35 | help='maximum allowed perturbation for unimportant pixels', 36 | type=parse_fraction) 37 | parser.add_argument('--num-steps', default=10, 38 | help='perturb number of steps', type=int) 39 | parser.add_argument('--num-class', default=10, 40 | help='number of classes') 41 | parser.add_argument('--step-size', default=2/255, 42 | help='perturb step size', type=parse_fraction) 43 | parser.add_argument('--adjust-first', type=int, default=60, 44 | help='adjust learning rate on which epoch in the first round') 45 | parser.add_argument('--adjust-second', type=int, default=90, 46 | help='adjust learning rate on which epoch in the second round') 47 | parser.add_argument('--rand_init', type=bool, default=True, 48 | help="whether to initialize adversarial sample with random noise") 49 | parser.add_argument('--pre-trained', type=bool, default=False, 50 | help="whether to use pre-trained weighted matrix") 51 | 52 | parser.add_argument('--seed', type=int, default=1, metavar='S', 53 | help='random seed (default: 1)') 54 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 55 | help='how many batches to wait before logging training status') 56 | parser.add_argument('--model-dir', default='./checkpoint/ResNet_18/PART', 57 | help='directory of model for saving checkpoint') 58 | parser.add_argument('--save-freq', default=10, type=int, metavar='N', 59 | help='save frequency') 60 | parser.add_argument('--save-weights', default=1, type=int, metavar='N', 61 | help='save frequency for weighted matrix') 62 | 63 | parser.add_argument('--data', type=str, default='CIFAR10', 64 | help='data source', choices=['CIFAR10', 'SVHN', 'TinyImagenet']) 65 | parser.add_argument('--model', type=str, default='resnet', choices=['resnet', 'wideresnet']) 66 | parser.add_argument('--warm-up', type=int, default=20, help='warm up epochs') 67 | parser.add_argument('--cam', type=str, default='gradcam', choices=['gradcam', 'xgradcam', 'layercam']) 68 | parser.add_argument('--attack', type=str, default='pgd', choices=['pgd', 'mma']) 69 | 70 | args = parser.parse_args() 71 | 72 | if args.data == 'CIFAR100': 73 | args.num_class = 100 74 | if args.data == 'TinyImagenet': 75 | args.num_class = 200 76 | 77 | def train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list): 78 | if args.pre_trained: 79 | weighted_eps_list = np.load('weighted_eps_list.npz') 80 | 81 | for batch_idx, (data, label) in enumerate(train_loader): 82 | if args.pre_trained: 83 | weighted_eps = torch.from_numpy(weighted_eps_list[f'arr_{batch_idx}']).to(device) 84 | 85 | data, label = data.to(device), label.to(device) 86 | 87 | X, y = Variable(data, requires_grad=True), Variable(label) 88 | 89 | # calculate robust loss 90 | model.eval() 91 | if not args.pre_trained: 92 | weighted_eps = weighted_eps_list[batch_idx] 93 | 94 | if args.attack == 'pgd': 95 | data = part_pgd(model, 96 | X, 97 | y, 98 | weighted_eps, 99 | epsilon=args.epsilon, 100 | num_steps=args.num_steps, 101 | step_size=args.step_size) 102 | elif args.attack == 'mma': 103 | data = part_mma(model, 104 | data, 105 | label, 106 | weighted_eps, 107 | epsilon=args.epsilon, 108 | step_size=args.step_size, 109 | num_steps=args.num_steps, 110 | rand_init=args.rand_init, 111 | k=3, 112 | num_classes=args.num_class) 113 | 114 | model.train() 115 | optimizer.zero_grad() 116 | 117 | out = model(data) 118 | loss = F.cross_entropy(out, label) 119 | 120 | loss.backward() 121 | optimizer.step() 122 | 123 | # print progress 124 | if batch_idx % args.log_interval == 0: 125 | print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format( 126 | epoch, (batch_idx+1) * len(data), len(train_loader.dataset), 127 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 128 | 129 | def main(): 130 | # settings 131 | setup_seed(args.seed) 132 | use_cuda = not args.no_cuda and torch.cuda.is_available() 133 | torch.manual_seed(args.seed) 134 | device = torch.device("cuda" if use_cuda else "cpu") 135 | 136 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 137 | 138 | # setup data loader 139 | if args.data == 'CIFAR10': 140 | train_loader = CIFAR10(train_batch_size=args.batch_size).train_data() 141 | test_loader = CIFAR10(test_batch_size=args.batch_size).test_data() 142 | if args.model == 'resnet': 143 | model_dir = './checkpoint/CIFAR10/ResNet_18/PART' 144 | model = ResNet18(num_classes=10).to(device) 145 | elif args.model == 'wideresnet': 146 | model_dir = './checkpoint/CIFAR10/WideResnet-34/PART' 147 | model = WideResNet(34, 10, 10).to(device) 148 | else: 149 | raise ValueError("Unknown model") 150 | elif args.data == 'SVHN': 151 | args.step_size = 1/255 152 | args.weight_decay = 0.0035 153 | args.lr = 0.01 154 | args.batch_size = 128 155 | train_loader = SVHN(train_batch_size=args.batch_size).train_data() 156 | test_loader = SVHN(test_batch_size=args.batch_size).test_data() 157 | if args.model == 'resnet': 158 | model_dir = './checkpoint/SVHN/ResNet_18/PART' 159 | model = ResNet18(num_classes=10).to(device) 160 | elif args.model == 'wideresnet': 161 | model_dir = './checkpoint/SVHN/WideResnet-34/PART' 162 | model = WideResNet(34, 10, 10).to(device) 163 | else: 164 | raise ValueError("Unknown model") 165 | else: 166 | raise ValueError("Unknown data") 167 | 168 | if not os.path.exists(model_dir): 169 | os.makedirs(model_dir) 170 | 171 | model = torch.nn.DataParallel(model) 172 | cudnn.benchmark = True 173 | optimizer = optim.SGD(model.parameters(), 174 | lr=args.lr, 175 | momentum=args.momentum, 176 | weight_decay=args.weight_decay) 177 | 178 | # warm up 179 | print('warm up starts') 180 | for epoch in range(1, args.warm_up + 1): 181 | standard_train(args, model, device, train_loader, optimizer, epoch) 182 | 183 | # save checkpoint 184 | if epoch % args.save_freq == 0: 185 | torch.save(model.state_dict(), 186 | os.path.join(model_dir, 'pre_part_epoch{}.pth'.format(epoch))) 187 | print('save the model') 188 | print('================================================================') 189 | print('warm up ends') 190 | 191 | weighted_eps_list = save_cam(model, train_loader, device, args) 192 | 193 | # train 194 | for epoch in range(1, args.epochs - args.warm_up + 1): 195 | if epoch % args.save_weights == 0 and epoch != 1: 196 | weighted_eps_list = save_cam(model, train_loader, device, args) 197 | 198 | # adjust learning rate for SGD 199 | adjust_learning_rate(args, optimizer, epoch) 200 | 201 | # adversarial training 202 | train(args, model, device, train_loader, optimizer, epoch, weighted_eps_list) 203 | 204 | # save checkpoint 205 | if epoch % args.save_freq == 0: 206 | torch.save(model.state_dict(), 207 | os.path.join(model_dir, 'part_epoch{}.pth'.format(epoch))) 208 | print('save the model') 209 | 210 | print('================================================================') 211 | 212 | # evaluation on adversarial examples 213 | print('PGD=============================================================') 214 | eval_test(args, model, device, test_loader, mode='pgd') 215 | print('MMA==============================================================') 216 | eval_test(args, model, device, test_loader, mode='mma') 217 | print('AA==============================================================') 218 | eval_test(args, model, device, test_loader, mode='aa') 219 | 220 | if __name__ == '__main__': 221 | main() 222 | --------------------------------------------------------------------------------