├── imgs ├── cl.png ├── NeurIPS2022_1958_Poster.pdf └── NeurIPS2022_1958_Poster.png ├── models ├── __init__.py ├── wrn_madry.py ├── base_models.py ├── preactresnet.py ├── resnet.py └── densenet.py ├── .gitignore ├── LICENSE ├── utils_func.py ├── utils_mcl_loss.py ├── README.md ├── attack_generator.py ├── utils_algo.py ├── utils_data.py └── main.py /imgs/cl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/cl.png -------------------------------------------------------------------------------- /imgs/NeurIPS2022_1958_Poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/NeurIPS2022_1958_Poster.pdf -------------------------------------------------------------------------------- /imgs/NeurIPS2022_1958_Poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/NeurIPS2022_1958_Poster.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .base_models import * 3 | from .densenet import * 4 | from .wrn_madry import * 5 | from .preactresnet import * 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # macOS 7 | .DS_Store 8 | 9 | # Jupyter Notebook 10 | .ipynb_checkpoints 11 | 12 | .idea/ 13 | 14 | # private files 15 | utils_plot* 16 | tmp/ 17 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jianan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/wrn_madry.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 
1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class Wide_ResNet_Madry(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0, input_channel=3): 52 | super(Wide_ResNet_Madry, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 2) % 6 == 0) 55 | n = (depth - 2) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(input_channel, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | # self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.linear = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.linear(out) 93 | -------------------------------------------------------------------------------- /utils_func.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | def show(x, y, label, title, xdes, ydes, path, x_scale="linear", dpi=300): 6 | plt.style.use('fivethirtyeight') # bmh, fivethirtyeight, Solarize_Light2 7 | plt.figure(figsize=(10, 8)) 8 | colors = ['tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan', 9 | 'tab:gray', 'tab:brown', 'tab:purple', 'tab:olive', 'tab:pink'] 10 | # colors = ['tab:pink', 'tab:olive', 'tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan', 11 | # 'tab:gray', 'tab:brown', 'tab:purple'] 12 | 13 | assert len(x) == len(y) 14 | for i in range(len(x)): 15 | if i < len(label): 16 | plt.plot(x[i], y[i], color=colors[i], label=label[i], linewidth=1.5) # linewidth=1.5 17 | else: 18 | plt.plot(x[i], y[i], color=colors[i % len(label)], linewidth=1.5) # linewidth=1.5 19 | 20 | plt.gca().get_xaxis().get_major_formatter().set_scientific(False) 21 | plt.gca().get_yaxis().get_major_formatter().set_scientific(False) 22 | plt.xlabel(xdes, fontsize=24) 23 | plt.ylabel(ydes, fontsize=24) 24 | 25 | plt.title(title, fontsize=24) 26 | # my_y_ticks = np.arange(0, 1.1, 0.2) 27 | # plt.yticks(my_y_ticks, fontsize=24) 28 | plt.xticks(fontsize=24) 29 | plt.yticks(fontsize=24) 30 | plt.legend(loc='lower right', fontsize=16) 31 | plt.xscale(x_scale) 32 | # plt.margins(x=0) 33 | 34 | # plt.grid(True) 35 | plt.savefig(path, 
dpi=dpi, bbox_inches='tight', pad_inches=0) 36 | plt.close("all") 37 | 38 | 39 | def display_num_param(net): 40 | nb_param = 0 41 | for param in net.parameters(): 42 | nb_param += param.numel() 43 | print('There are {} ({:.2f} million) parameters in this neural network'.format(nb_param, nb_param/1e6)) 44 | 45 | 46 | def stat(x_to_mcls, x_to_tls): 47 | # test how many data are given wrong cls 48 | wrong_cl_count, correct_cl_count = 0, 0 49 | for k, v in x_to_mcls.items(): 50 | correct_cl_count += len(v) 51 | if x_to_tls[k] in v: 52 | wrong_cl_count += 1 53 | # test how many correct cls are given to each data 54 | correct_cl_count -= wrong_cl_count 55 | 56 | return wrong_cl_count/len(x_to_mcls), correct_cl_count/len(x_to_mcls) 57 | 58 | 59 | def show_std(x, y, label, title, xdes, ydes, path, x_scale="linear", std=True, dpi=300): 60 | """ 61 | input: x/y: e.g.: [[exp1, exp2, exp3], [at1, at2, at3], ...] 62 | """ 63 | print(">> Plot with mean with std err here!") 64 | plt.style.use('fivethirtyeight') 65 | plt.figure(figsize=(10, 8)) 66 | colors = ['tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan', 67 | 'tab:gray', 'tab:brown', 'tab:purple', 'tab:olive', 'tab:pink'] 68 | # colors = ['tab:pink', 'tab:olive', 'tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan', 69 | # 'tab:gray', 'tab:brown', 'tab:purple'] 70 | 71 | assert len(x) == len(y) 72 | for k in range(len(y)): 73 | xx, yy = x[k], y[k] 74 | epoch = len(xx[0]) 75 | y_zip, y_est, y_err = [], np.array([0.] * epoch), np.array([0.] 
* epoch) 76 | for i in range(epoch): 77 | ll = [] 78 | for j in range(len(yy)): 79 | ll.append(yy[j][i]) 80 | y_zip.append(ll) 81 | for i in range(epoch): 82 | y_est[i] = np.mean(y_zip[i]) 83 | y_err[i] = np.std(y_zip[i]) / np.sqrt(3) # len(y[0]) 84 | plt.plot(xx[0], y_est, color=colors[k], label=label[k]) # linewidth=2.0 linestyle=":", marker='o', ms=15 85 | if std: 86 | plt.fill_between(xx[0], y_est - y_err, y_est + y_err, alpha=0.2, color=colors[k]) 87 | 88 | plt.gca().get_xaxis().get_major_formatter().set_scientific(False) 89 | plt.gca().get_yaxis().get_major_formatter().set_scientific(False) 90 | plt.xlabel(xdes, fontsize=24) 91 | plt.ylabel(ydes, fontsize=24) 92 | 93 | plt.title(title, fontsize=24) 94 | plt.xticks(fontsize=24) 95 | plt.yticks(fontsize=24) 96 | plt.legend(loc='lower right', fontsize=20) 97 | plt.xscale(x_scale) 98 | # plt.margins(x=0) 99 | 100 | # plt.grid(True) 101 | plt.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0) 102 | plt.close("all") 103 | -------------------------------------------------------------------------------- /models/base_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as dsets 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from collections import OrderedDict 8 | 9 | 10 | class linear_model(nn.Module): 11 | def __init__(self, input_dim, output_dim): 12 | super(linear_model, self).__init__() 13 | self.linear = nn.Linear(input_dim, output_dim) 14 | 15 | def forward(self, x): 16 | out = x.view(-1, self.num_flat_features(x)) 17 | out = self.linear(out) 18 | return out 19 | 20 | def num_flat_features(self, x): 21 | # x: [bs, c, h, w] 22 | size = x.size()[1:] 23 | num_features = 1 24 | for s in size: 25 | num_features *= s 26 | return num_features 27 | 28 | 29 | class mlp_model(nn.Module): 30 | def __init__(self, input_dim, 
hidden_dim, output_dim): 31 | super(mlp_model, self).__init__() 32 | self.fc1 = nn.Linear(input_dim, hidden_dim) 33 | self.relu1 = nn.ReLU() 34 | self.fc2 = nn.Linear(hidden_dim, output_dim) 35 | 36 | def forward(self, x): 37 | out = x.view(-1, self.num_flat_features(x)) 38 | out = self.fc1(out) 39 | out = self.relu1(out) 40 | out = self.fc2(out) 41 | return out 42 | 43 | def num_flat_features(self, x): 44 | size = x.size()[1:] 45 | num_features = 1 46 | for s in size: 47 | num_features *= s 48 | return num_features 49 | 50 | 51 | class cnn_mnist(nn.Module): 52 | """ 53 | From https://github.com/yaodongyu/TRADES/blob/master/models/net_mnist.py 54 | """ 55 | def __init__(self, num_classes=10): 56 | super(cnn_mnist, self).__init__() 57 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 58 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 59 | self.conv2_drop = nn.Dropout2d() 60 | self.fc1 = nn.Linear(320, 50) 61 | self.fc2 = nn.Linear(50, num_classes) 62 | 63 | def forward(self, x): 64 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 65 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 66 | x = x.view(-1, 320) 67 | x = F.relu(self.fc1(x)) 68 | x = F.dropout(x, training=self.training) 69 | x = self.fc2(x) 70 | return x 71 | 72 | 73 | class SmallCNN(nn.Module): 74 | """ 75 | From https://github.com/yaodongyu/TRADES/blob/master/models/small_cnn.py 76 | """ 77 | def __init__(self, drop=0.5): 78 | super(SmallCNN, self).__init__() 79 | 80 | self.num_channels = 1 81 | self.num_labels = 10 82 | 83 | activ = nn.ReLU(True) 84 | 85 | self.feature_extractor = nn.Sequential(OrderedDict([ 86 | ('conv1', nn.Conv2d(self.num_channels, 32, 3)), 87 | ('relu1', activ), 88 | ('conv2', nn.Conv2d(32, 32, 3)), 89 | ('relu2', activ), 90 | ('maxpool1', nn.MaxPool2d(2, 2)), 91 | ('conv3', nn.Conv2d(32, 64, 3)), 92 | ('relu3', activ), 93 | ('conv4', nn.Conv2d(64, 64, 3)), 94 | ('relu4', activ), 95 | ('maxpool2', nn.MaxPool2d(2, 2)), 96 | ])) 97 | 98 | self.classifier = 
nn.Sequential(OrderedDict([ 99 | ('fc1', nn.Linear(64 * 4 * 4, 200)), 100 | ('relu1', activ), 101 | ('drop', nn.Dropout(drop)), 102 | ('fc2', nn.Linear(200, 200)), 103 | ('relu2', activ), 104 | ('fc3', nn.Linear(200, self.num_labels)), 105 | ])) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, (nn.Conv2d)): 109 | nn.init.kaiming_normal_(m.weight) 110 | if m.bias is not None: 111 | nn.init.constant_(m.bias, 0) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.constant_(m.weight, 1) 114 | nn.init.constant_(m.bias, 0) 115 | nn.init.constant_(self.classifier.fc3.weight, 0) 116 | nn.init.constant_(self.classifier.fc3.bias, 0) 117 | 118 | def forward(self, input): 119 | features = self.feature_extractor(input) 120 | logits = self.classifier(features.view(-1, 64 * 4 * 4)) 121 | return logits 122 | -------------------------------------------------------------------------------- /models/preactresnet.py: -------------------------------------------------------------------------------- 1 | """preactresnet in pytorch 2 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 3 | Identity Mappings in Deep Residual Networks 4 | https://arxiv.org/abs/1603.05027 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBasic(nn.Module): 13 | 14 | expansion = 1 15 | def __init__(self, in_channels, out_channels, stride): 16 | super().__init__() 17 | self.residual = nn.Sequential( 18 | nn.BatchNorm2d(in_channels), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(out_channels, out_channels * PreActBasic.expansion, kernel_size=3, padding=1) 24 | ) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_channels != out_channels * PreActBasic.expansion: 28 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBasic.expansion, 1, stride=stride) 29 | 30 | 
def forward(self, x): 31 | 32 | res = self.residual(x) 33 | shortcut = self.shortcut(x) 34 | 35 | return res + shortcut 36 | 37 | 38 | class PreActBottleNeck(nn.Module): 39 | 40 | expansion = 4 41 | def __init__(self, in_channels, out_channels, stride): 42 | super().__init__() 43 | 44 | self.residual = nn.Sequential( 45 | nn.BatchNorm2d(in_channels), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(in_channels, out_channels, 1, stride=stride), 48 | 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 52 | 53 | nn.BatchNorm2d(out_channels), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(out_channels, out_channels * PreActBottleNeck.expansion, 1) 56 | ) 57 | 58 | self.shortcut = nn.Sequential() 59 | 60 | if stride != 1 or in_channels != out_channels * PreActBottleNeck.expansion: 61 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBottleNeck.expansion, 1, stride=stride) 62 | 63 | def forward(self, x): 64 | 65 | res = self.residual(x) 66 | shortcut = self.shortcut(x) 67 | 68 | return res + shortcut 69 | 70 | 71 | class PreActResNet(nn.Module): 72 | 73 | def __init__(self, block, num_block, num_classes=10, input_channel=3): 74 | super().__init__() 75 | self.input_channels = 64 76 | 77 | self.conv1 = nn.Sequential( 78 | nn.Conv2d(input_channel, 64, 3, padding=1), 79 | nn.BatchNorm2d(64), 80 | nn.ReLU(inplace=True) 81 | ) 82 | 83 | self.stage1 = self._make_layers(block, num_block[0], 64, 1) 84 | self.stage2 = self._make_layers(block, num_block[1], 128, 2) 85 | self.stage3 = self._make_layers(block, num_block[2], 256, 2) 86 | self.stage4 = self._make_layers(block, num_block[3], 512, 2) 87 | 88 | self.linear = nn.Linear(self.input_channels, num_classes) 89 | 90 | def _make_layers(self, block, block_num, out_channels, stride): 91 | layers = [] 92 | 93 | layers.append(block(self.input_channels, out_channels, stride)) 94 | self.input_channels = out_channels * block.expansion 95 | 96 | while block_num - 1: 
97 | layers.append(block(self.input_channels, out_channels, 1)) 98 | self.input_channels = out_channels * block.expansion 99 | block_num -= 1 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.conv1(x) 105 | 106 | x = self.stage1(x) 107 | x = self.stage2(x) 108 | x = self.stage3(x) 109 | x = self.stage4(x) 110 | 111 | x = F.adaptive_avg_pool2d(x, 1) 112 | x = x.view(x.size(0), -1) 113 | x = self.linear(x) 114 | 115 | return x 116 | 117 | 118 | def preactresnet18(input_channel=3, num_classes=10): 119 | return PreActResNet(PreActBasic, [2, 2, 2, 2], num_classes=num_classes, input_channel=input_channel) 120 | 121 | 122 | def preactresnet34(input_channel=3, num_classes=10): 123 | return PreActResNet(PreActBasic, [3, 4, 6, 3], num_classes=num_classes, input_channel=input_channel) 124 | 125 | 126 | def preactresnet50(input_channel=3, num_classes=10): 127 | return PreActResNet(PreActBottleNeck, [3, 4, 6, 3], num_classes=num_classes, input_channel=input_channel) 128 | 129 | 130 | def preactresnet101(input_channel=3, num_classes=10): 131 | return PreActResNet(PreActBottleNeck, [3, 4, 23, 3], num_classes=num_classes, input_channel=input_channel) 132 | 133 | 134 | def preactresnet152(input_channel=3, num_classes=10): 135 | return PreActResNet(PreActBottleNeck, [3, 8, 36, 3], num_classes=num_classes, input_channel=input_channel) 136 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | 
out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10, input_channel=3): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512*block.expansion, num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.layer1(out) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = F.avg_pool2d(out, 4) 97 | out = out.view(out.size(0), -1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | def ResNet18(input_channel=3, num_classes=10): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, input_channel=input_channel) 104 | 105 | 106 | def ResNet34(input_channel=3, num_classes=10): 107 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, input_channel=input_channel) 108 | 109 | 110 | def ResNet50(input_channel=3, num_classes=10): 111 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, input_channel=input_channel) 112 | 113 | 114 | def ResNet101(input_channel=3, num_classes=10): 115 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, 
input_channel=input_channel) 116 | 117 | 118 | def ResNet152(input_channel=3, num_classes=10): 119 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, input_channel=input_channel) 120 | 121 | 122 | def test(): 123 | net = ResNet18() 124 | y = net(Variable(torch.randn(1,3,32,32))) 125 | print(y.size()) 126 | print(net) 127 | -------------------------------------------------------------------------------- /utils_mcl_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Official implementation from "Learning with Multiple Complementary Labels" 3 | by Lei Feng et al. 4 | """ 5 | import torch.nn as nn 6 | import torch 7 | import math 8 | import torch.nn.functional as F 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def mae_loss(outputs, partialY): 14 | sm_outputs = F.softmax(outputs, dim=1) 15 | loss_fn = nn.L1Loss(reduction='none') 16 | loss_matrix = loss_fn(sm_outputs, partialY.float()) 17 | sample_loss = loss_matrix.sum(dim=-1) 18 | return sample_loss 19 | 20 | 21 | def mse_loss(outputs, Y): 22 | sm_outputs = F.softmax(outputs, dim=1) 23 | loss_fn = nn.MSELoss(reduction='none') 24 | loss_matrix = loss_fn(sm_outputs, Y.float()) 25 | sample_loss = loss_matrix.sum(dim=-1) 26 | return sample_loss 27 | 28 | 29 | def gce_loss(outputs, Y): 30 | q = 0.7 31 | sm_outputs = F.softmax(outputs, dim=1) 32 | pow_outputs = torch.pow(sm_outputs, q) 33 | sample_loss = (1-(pow_outputs*Y).sum(dim=1))/q # n 34 | return sample_loss 35 | 36 | 37 | def phuber_ce_loss(outputs, Y): 38 | trunc_point = 0.1 39 | n = Y.shape[0] 40 | soft_max = nn.Softmax(dim=1) 41 | sm_outputs = soft_max(outputs) 42 | final_outputs = sm_outputs * Y 43 | final_confidence = final_outputs.sum(dim=1) 44 | 45 | ce_index = (final_confidence > trunc_point) 46 | sample_loss = torch.zeros(n).to(device) 47 | 48 | if ce_index.sum() > 0: 49 | ce_outputs = outputs[ce_index,:] 50 | logsm = nn.LogSoftmax(dim=-1) 51 | 
logsm_outputs = logsm(ce_outputs) 52 | final_ce_outputs = logsm_outputs * Y[ce_index,:] 53 | sample_loss[ce_index] = - final_ce_outputs.sum(dim=-1) 54 | 55 | linear_index = (final_confidence <= trunc_point) 56 | 57 | if linear_index.sum() > 0: 58 | sample_loss[linear_index] = -math.log(trunc_point) + (-1/trunc_point)*final_confidence[linear_index] + 1 59 | 60 | return sample_loss 61 | 62 | 63 | def ce_loss(outputs, Y): 64 | logsm = nn.LogSoftmax(dim=1) 65 | logsm_outputs = logsm(outputs) 66 | final_outputs = logsm_outputs * Y 67 | sample_loss = - final_outputs.sum(dim=1) 68 | return sample_loss 69 | 70 | 71 | def unbiased_estimator(loss_fn, outputs, partialY): 72 | n, k = partialY.shape[0], partialY.shape[1] 73 | comp_num = k - partialY.sum(dim=1) 74 | temp_loss = torch.zeros(n, k).to(device) 75 | for i in range(k): 76 | tempY = torch.zeros(n, k).to(device) 77 | tempY[:, i] = 1.0 78 | temp_loss[:, i] = loss_fn(outputs, tempY) 79 | 80 | candidate_loss = (temp_loss * partialY).sum(dim=1) # for true label 81 | noncandidate_loss = (temp_loss * (1-partialY)).sum(dim=1) # for complementary label 82 | total_loss = candidate_loss - (k-comp_num-1.0)/comp_num * noncandidate_loss 83 | average_loss = total_loss.mean() 84 | return average_loss 85 | 86 | 87 | def log_ce_loss(outputs, partialY, pseudo_labels, alpha): 88 | k = partialY.shape[1] 89 | can_num = partialY.sum(dim=1).float() # n 90 | 91 | soft_max = nn.Softmax(dim=1) 92 | sm_outputs = soft_max(outputs) 93 | final_outputs = sm_outputs * partialY # \sum_{j\notin \bar{Y}} [p(j|x)] 94 | 95 | pred_outputs = sm_outputs[torch.arange(sm_outputs.size(0)), pseudo_labels] # p(pl|x) 96 | # pred_outputs, _ = torch.max(final_outputs, dim=1) # \max \sum_{j\notin \bar{Y}} [p(j|x)] 97 | 98 | average_loss = - ((k - 1) / (k - can_num) * torch.log(alpha * final_outputs.sum(dim=1) + (1 - alpha) * pred_outputs)).mean() 99 | 100 | return average_loss 101 | 102 | 103 | def log_loss(outputs, partialY): 104 | k = partialY.shape[1] 105 | can_num 
= partialY.sum(dim=1).float() # n 106 | 107 | soft_max = nn.Softmax(dim=1) 108 | sm_outputs = soft_max(outputs) 109 | final_outputs = sm_outputs * partialY 110 | 111 | average_loss = - ((k-1)/(k-can_num) * torch.log(final_outputs.sum(dim=1))).mean() 112 | return average_loss 113 | 114 | 115 | def exp_ce_loss(outputs, partialY, pseudo_labels, alpha): 116 | k = partialY.shape[1] 117 | can_num = partialY.sum(dim=1).float() # n 118 | 119 | soft_max = nn.Softmax(dim=1) 120 | sm_outputs = soft_max(outputs) 121 | final_outputs = sm_outputs * partialY # \sum_{j\notin \bar{Y}} [p(j|x)] 122 | 123 | pred_outputs = sm_outputs[torch.arange(sm_outputs.size(0)), pseudo_labels] # p(pl|x) 124 | # pred_outputs, _ = torch.max(final_outputs, dim=1) # \max \sum_{j\notin \bar{Y}} [p(j|x)] 125 | 126 | average_loss = ((k - 1) / (k - can_num) * torch.exp(-alpha * final_outputs.sum(dim=1) - (1 - alpha) * pred_outputs)).mean() 127 | 128 | return average_loss 129 | 130 | 131 | def exp_loss(outputs, partialY): 132 | k = partialY.shape[1] 133 | can_num = partialY.sum(dim=1).float() # n 134 | 135 | soft_max = nn.Softmax(dim=1) 136 | sm_outputs = soft_max(outputs) 137 | final_outputs = sm_outputs * partialY 138 | 139 | average_loss = ((k-1)/(k-can_num) * torch.exp(-final_outputs.sum(dim=1))).mean() 140 | return average_loss 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Adversarial Training with Complementary Labels 2 | 3 | [![OpenReview](https://img.shields.io/badge/OpenReview-Forum-brightgreen.svg)](https://openreview.net/forum?id=s7SukMH7ie9)    [![PDF](https://img.shields.io/badge/Download-PDF-blue.svg)](https://openreview.net/pdf?id=s7SukMH7ie9)    [![PDF](https://img.shields.io/badge/NeurIPS-Page-orange.svg)](https://neurips.cc/virtual/2022/poster/55084) 4 | 5 | Official (Pytorch) Implementation of *NeurIPS 2022 Spotlight "Adversarial Training with 
Complementary Labels: On the Benefit of Gradually Informative Attacks"* by [Jianan Zhou*](https://openreview.net/profile?id=~Jianan_Zhou1)*,* [Jianing Zhu*](https://openreview.net/profile?id=~Jianing_Zhu2)*,* [Jingfeng Zhang](https://openreview.net/profile?id=~Jingfeng_Zhang1)*,* [Tongliang Liu](https://openreview.net/profile?id=~Tongliang_Liu1)*,* [Gang Niu](https://openreview.net/profile?id=~Gang_Niu1)*,* [Bo Han](https://openreview.net/profile?id=~Bo_Han1)*,* [Masashi Sugiyama](https://openreview.net/profile?id=~Masashi_Sugiyama1). 6 | 7 | ```bash 8 | @inproceedings{zhou2022adversarial, 9 | title={Adversarial Training with Complementary Labels: On the Benefit of Gradually Informative Attacks}, 10 | author={Jianan Zhou and Jianing Zhu and Jingfeng Zhang and Tongliang Liu and Gang Niu and Bo Han and Masashi Sugiyama}, 11 | booktitle={Advances in Neural Information Processing Systems}, 12 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 13 | year={2022}, 14 | url={https://openreview.net/forum?id=s7SukMH7ie9} 15 | } 16 | ``` 17 | 18 |

19 | Poster 20 |

21 | 22 | ### TL;DR 23 | 24 | > *How to equip machine learning models with adversarial robustness when all given labels in a dataset are wrong (i.e., complementary labels)?* 25 | 26 |

27 | CLs 28 |

29 | 30 | ### Dependencies 31 | 32 | * Python 3.8 33 | * SciPy 34 | * [PyTorch 1.11.0](https://pytorch.org) 35 | * [AutoAttack](https://github.com/fra31/auto-attack) 36 | 37 | ### How to Run 38 | 39 | > Please refer to Section 5 and Appendix D.1 of our paper for detailed setups. 40 | 41 | #### Baseline 42 | 43 | ```bash 44 | # Two-stage baseline for MNIST/Kuzushiji 45 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log' --framework 'two_stage' --cl_epochs 50 --adv_epochs 50 --cl_lr 0.001 --at_lr 0.01 46 | # Two-stage baseline for CIFAR10/SVHN 47 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log' --framework 'two_stage' --cl_epochs 50 --adv_epochs 70 --cl_lr 0.01 --at_lr 0.01 48 | 49 | # Complementary baselines (e.g., LOG) for MNIST/Kuzushiji 50 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log' --framework 'one_stage' --adv_epochs 100 --at_lr 0.01 --scheduler 'none' 51 | # Complementary baselines (e.g., LOG) for CIFAR10/SVHN 52 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log' --framework 'one_stage' --adv_epochs 120 --at_lr 0.01 --scheduler 'none' 53 | ``` 54 | 55 | #### Ours 56 | 57 | ```bash 58 | # MNIST/Kuzushiji 59 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log_ce' --framework 'one_stage' --adv_epochs 100 --at_lr 0.01 --scheduler 'cosine' --sch_epoch 50 --warmup_epoch 10 60 | # CIFAR10/SVHN 61 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log_ce' --framework 'one_stage' --adv_epochs 120 --at_lr 0.01 --scheduler 'cosine' --sch_epoch 40 --warmup_epoch 40 62 | ``` 63 | 64 | #### Options 65 | 66 | ```bash 67 | # Supported Datasets (note: complementary learning fails on CIFAR-100 in the SCL setting in our experiments, so 'cifar100' is not yet effectively supported there)
68 | --dataset - ['mnist', 'kuzushiji', 'fashion', 'cifar10', 'svhn', 'cifar100'] 69 | # Complementary Loss Functions 70 | --method - ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'mae', 'mse', 'ce', 'gce', 'phuber_ce', 'log', 'exp', 'l_uw', 'l_w', 'log_ce', 'exp_ce'] 71 | # Multiple Complementary Labels (MCLs) 72 | --cl_num - (1-9) the number of complementary labels per example; (0) the MCL distribution from ICML 2020 - "Learning with Multiple Complementary Labels" 73 | ``` 74 | 75 | ### References 76 | 77 | * [NeurIPS 2017] - [Learning from complementary labels](https://arxiv.org/abs/1705.07541) 78 | 79 | * [ECCV 2018] - [Learning with biased complementary labels](https://arxiv.org/abs/1711.09535) 80 | 81 | * [ICML 2019] - [Complementary-label learning for arbitrary losses and models](https://arxiv.org/abs/1810.04327) 82 | 83 | * [ICML 2020] - [Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels](https://arxiv.org/abs/2007.02235) 84 | 85 | * [ICML 2020] - [Learning with Multiple Complementary Labels](https://arxiv.org/abs/1912.12927) 86 | 87 | ### Acknowledgments 88 | 89 | We thank the authors of *"Complementary-label learning for arbitrary losses and models"* for the open-source [code](https://github.com/takashiishida/comp) and the issue discussion. Code for the other baselines can be found on the corresponding authors' homepages. We would also like to thank the anonymous reviewers of NeurIPS 2022 for their constructive comments. 90 | 91 | ### Contact 92 | 93 | Please contact [jianan004@e.ntu.edu.sg](mailto:jianan004@e.ntu.edu.sg) and [csjnzhu@comp.hkbu.edu.hk](mailto:csjnzhu@comp.hkbu.edu.hk) if you have any questions regarding the paper or implementation.
94 | 95 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = ['densenet'] 10 | 11 | 12 | from torch.autograd import Variable 13 | 14 | class Bottleneck(nn.Module): 15 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 16 | super(Bottleneck, self).__init__() 17 | planes = expansion * growthRate 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 22 | padding=1, bias=False) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.dropRate = dropRate 25 | 26 | def forward(self, x): 27 | out = self.bn1(x) 28 | out = self.relu(out) 29 | out = self.conv1(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | if self.dropRate > 0: 34 | out = F.dropout(out, p=self.dropRate, training=self.training) 35 | 36 | out = torch.cat((x, out), 1) 37 | 38 | return out 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 43 | super(BasicBlock, self).__init__() 44 | planes = expansion * growthRate 45 | self.bn1 = nn.BatchNorm2d(inplanes) 46 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 47 | padding=1, bias=False) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.dropRate = dropRate 50 | 51 | def forward(self, x): 52 | out = self.bn1(x) 53 | out = self.relu(out) 54 | out = self.conv1(out) 55 | if self.dropRate > 0: 56 | out = F.dropout(out, p=self.dropRate, training=self.training) 57 | 58 | out = torch.cat((x, out), 1) 59 | 60 | return out 61 | 62 | 63 | class 
Transition(nn.Module): 64 | def __init__(self, inplanes, outplanes): 65 | super(Transition, self).__init__() 66 | self.bn1 = nn.BatchNorm2d(inplanes) 67 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False) 68 | self.relu = nn.ReLU(inplace=True) 69 | 70 | def forward(self, x): 71 | out = self.bn1(x) 72 | out = self.relu(out) 73 | out = self.conv1(out) 74 | out = F.avg_pool2d(out, 2) 75 | return out 76 | 77 | 78 | class DenseNet(nn.Module): 79 | 80 | def __init__(self, depth=22, block=Bottleneck, dropRate=0, num_classes=10, growthRate=12, compressionRate=2, input_channel=3): 81 | super(DenseNet, self).__init__() 82 | 83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 84 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 85 | 86 | self.growthRate = growthRate 87 | self.dropRate = dropRate 88 | 89 | # self.inplanes is a global variable used across multiple 90 | # helper functions 91 | self.inplanes = growthRate * 2 92 | self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=3, padding=1, bias=False) 93 | self.dense1 = self._make_denseblock(block, n) 94 | self.trans1 = self._make_transition(compressionRate) 95 | self.dense2 = self._make_denseblock(block, n) 96 | self.trans2 = self._make_transition(compressionRate) 97 | self.dense3 = self._make_denseblock(block, n) 98 | self.bn = nn.BatchNorm2d(self.inplanes) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.avgpool = nn.AvgPool2d(8) 101 | self.linear = nn.Linear(self.inplanes, num_classes) 102 | 103 | # Weight initialization 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_denseblock(self, block, blocks): 113 | layers = [] 114 | for i in range(blocks): 115 | # Currently we fix the expansion ratio as the default value 116 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 117 | self.inplanes += self.growthRate 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def _make_transition(self, compressionRate): 122 | inplanes = self.inplanes 123 | outplanes = int(math.floor(self.inplanes // compressionRate)) 124 | self.inplanes = outplanes 125 | return Transition(inplanes, outplanes) 126 | 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | 131 | x = self.trans1(self.dense1(x)) 132 | x = self.trans2(self.dense2(x)) 133 | x = self.dense3(x) 134 | x = self.bn(x) 135 | x = self.relu(x) 136 | 137 | x = self.avgpool(x) 138 | emb = x.view(x.size(0), -1) 139 | x = self.linear(emb) 140 | 141 | return x 142 | 143 | 144 | def densenet(**kwargs): 145 | """ 146 | Constructs a ResNet model. 
147 | """ 148 | return DenseNet(**kwargs) 149 | -------------------------------------------------------------------------------- /attack_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.autograd import Variable 3 | from utils_algo import * 4 | from utils_mcl_loss import * 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def trades_loss(adv_logits, natural_logits, target, beta): 10 | """ 11 | Based on the repo TREADES: https://github.com/yaodongyu/TRADES 12 | """ 13 | batch_size = len(target) 14 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 15 | loss_natural = nn.CrossEntropyLoss(reduction='mean')(natural_logits, target) 16 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1), F.softmax(natural_logits, dim=1)) 17 | loss = loss_natural + beta * loss_robust 18 | return loss 19 | 20 | 21 | def cwloss(output, target, confidence=50, num_classes=10): 22 | # Compute the probability of the label class versus the maximum other 23 | # The same implementation as in repo CAT https://github.com/sunblaze-ucb/curriculum-adversarial-training-CAT 24 | target = target.data 25 | target_onehot = torch.zeros(target.size() + (num_classes,)) 26 | target_onehot = target_onehot.to(device) 27 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 28 | target_var = Variable(target_onehot, requires_grad=False) 29 | real = (target_var * output).sum(1) 30 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 31 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.) 
32 | loss = torch.sum(loss) 33 | return loss 34 | 35 | 36 | def cl_adv(args, model, data, target, epsilon, step_size, num_steps, id, ccp, partialY, pseudo_labels, alpha, loss_fn, category="Madry", rand_init=True): 37 | model.eval() 38 | if category == "trades": 39 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).to(device).detach() if rand_init else data.detach() 40 | if category == "Madry": 41 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().to(device) if rand_init else data.detach() 42 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 43 | 44 | # generate adversarial examples 45 | for k in range(num_steps): 46 | x_adv.requires_grad_() 47 | output = model(x_adv) 48 | model.zero_grad() 49 | with torch.enable_grad(): 50 | if args.method in ['exp', 'log']: 51 | loss_adv = loss_fn(output, partialY[id].float()) 52 | elif args.method in ['mae', 'mse', 'ce', 'gce', 'phuber_ce']: 53 | loss_adv = unbiased_estimator(loss_fn, output, partialY[id].float()) 54 | elif args.method in ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'l_uw', 'l_w']: 55 | assert args.cl_num == 1 56 | loss_adv, _ = chosen_loss_c(f=output, K=output.size(-1), labels=target, ccp=ccp, meta_method=args.method) 57 | elif args.method in ["log_ce", "exp_ce"]: 58 | loss_adv = loss_fn(output, partialY[id].float(), pseudo_labels, alpha) 59 | loss_adv.backward() 60 | eta = step_size * x_adv.grad.sign() 61 | # Update adversarial data 62 | x_adv = x_adv.detach() + eta 63 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon) 64 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 65 | x_adv = Variable(x_adv, requires_grad=False) 66 | 67 | return x_adv 68 | 69 | 70 | def pgd(model, data, target, epsilon, step_size, num_steps, loss_fn, category, rand_init, num_classes=10): 71 | model.eval() 72 | if category == "trades": 73 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).to(device).detach() if rand_init else data.detach() 74 | if category == 
"Madry": 75 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().to(device) if rand_init else data.detach() 76 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 77 | for k in range(num_steps): 78 | x_adv.requires_grad_() 79 | output = model(x_adv) 80 | model.zero_grad() 81 | with torch.enable_grad(): 82 | if loss_fn == "cent": 83 | loss_adv = nn.CrossEntropyLoss(reduction="mean")(output, target) 84 | if loss_fn == "cw": 85 | loss_adv = cwloss(output, target, num_classes=num_classes) 86 | if loss_fn == "kl": 87 | criterion_kl = nn.KLDivLoss(reduction="batchmean").cuda() 88 | loss_adv = criterion_kl(F.log_softmax(output, dim=1), F.softmax(model(data), dim=1)) 89 | loss_adv.backward() 90 | eta = step_size * x_adv.grad.sign() 91 | # Update adversarial data 92 | x_adv = x_adv.detach() + eta 93 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon) 94 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 95 | x_adv = Variable(x_adv, requires_grad=False) 96 | return x_adv 97 | 98 | 99 | def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, random, num_classes=10): 100 | model.eval() 101 | test_loss = 0 102 | correct = 0 103 | with torch.enable_grad(): 104 | for data, target in test_loader: 105 | data, target = data.to(device), target.to(device) 106 | x_adv = pgd(model, data, target, epsilon, step_size, perturb_steps, loss_fn, category, rand_init=random, num_classes=num_classes) 107 | output = model(x_adv) 108 | test_loss += F.cross_entropy(output, target, reduction="sum").item() 109 | pred = output.max(1, keepdim=True)[1] 110 | correct += pred.eq(target.view_as(pred)).sum().item() 111 | test_loss /= len(test_loader.dataset) 112 | test_accuracy = 100 * correct / len(test_loader.dataset) 113 | return test_loss, test_accuracy 114 | -------------------------------------------------------------------------------- /utils_algo.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def assump_free_loss(f, K, labels, ccp): 10 | """Assumption free loss (based on Thm 1) is equivalent to non_negative_loss if the max operator's threshold is negative inf.""" 11 | return non_negative_loss(f=f, K=K, labels=labels, ccp=ccp, beta=np.inf) 12 | 13 | 14 | def non_negative_loss(f, K, labels, ccp, beta): 15 | ccp = torch.from_numpy(ccp).float().to(device) 16 | neglog = -F.log_softmax(f, dim=1) # (bs, K) 17 | loss_vector = torch.zeros(K, requires_grad=True).to(device) 18 | temp_loss_vector = torch.zeros(K).to(device) 19 | for k in range(K): 20 | idx = (labels == k) 21 | if torch.sum(idx).item() > 0: 22 | idxs = idx.view(-1, 1).repeat(1, K) # (bs, K) 23 | neglog_k = torch.masked_select(neglog, idxs).view(-1, K) 24 | temp_loss_vector[k] = -(K-1) * ccp[k] * torch.mean(neglog_k, dim=0)[k] # average of k-th class loss for k-th comp class samples 25 | # ccp[k] or ccp, refer to https://github.com/takashiishida/comp/issues/2 26 | loss_vector = loss_vector + torch.mul(ccp[k], torch.mean(neglog_k, dim=0)) # only k-th in the summation of the second term inside max 27 | loss_vector = loss_vector + temp_loss_vector 28 | count = np.bincount(labels.data.cpu()).astype('float') 29 | while len(count) < K: 30 | count = np.append(count, 0) # when largest label is below K, bincount will not take care of them 31 | loss_vector_with_zeros = torch.cat((loss_vector.view(-1, 1), torch.zeros(K, requires_grad=True).view(-1, 1).to(device)-beta), 1) 32 | max_loss_vector, _ = torch.max(loss_vector_with_zeros, dim=1) 33 | final_loss = torch.sum(max_loss_vector) 34 | return final_loss, torch.mul(torch.from_numpy(count).float().to(device), loss_vector) 35 | 36 | 37 | def fast_free_loss(f, K, cl_label): 38 | loss = -(K - 1) * 
nn.CrossEntropyLoss(reduction="none")(f, cl_label) 39 | for k in range(K): 40 | ll = torch.LongTensor([k] * cl_label.size(0)).to(device) 41 | loss += nn.CrossEntropyLoss(reduction="none")(f, ll) 42 | 43 | return loss 44 | 45 | 46 | def forward_loss(f, K, labels, reduction='mean'): 47 | Q = torch.ones(K, K) * 1/(K-1) # uniform assumption 48 | Q = Q.to(device) 49 | for k in range(K): 50 | Q[k, k] = 0 51 | q = torch.mm(F.softmax(f, 1), Q) 52 | return F.nll_loss(q.log(), labels.long(), reduction=reduction) 53 | 54 | 55 | def pc_loss(f, K, labels): 56 | sigmoid = nn.Sigmoid() 57 | fbar = f.gather(1, labels.long().view(-1, 1)).repeat(1, K) 58 | loss_matrix = sigmoid(-1. * (f - fbar)) # multiply -1 for "complementary" 59 | M1, M2 = K*(K-1)/2, K-1 60 | pc_loss = torch.sum(loss_matrix)*(K-1)/len(labels) - M1 + M2 61 | return pc_loss 62 | 63 | 64 | def phi_loss(phi, logits, target, reduction='mean'): 65 | """ 66 | Official implementation of "Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels" 67 | by Yu-Ting Chou et al. 
68 | """ 69 | if phi == 'lin': 70 | activated_prob = F.softmax(logits, dim=1) 71 | elif phi == 'quad': 72 | activated_prob = torch.pow(F.softmax(logits, dim=1), 2) 73 | elif phi == 'exp': 74 | activated_prob = torch.exp(F.softmax(logits, dim=1)) 75 | # activated_prob = torch.exp(alpha * F.softmax(logits, dim=1) - (1 - alpha) * pred_outputs) 76 | elif phi == 'log': 77 | activated_prob = torch.log(F.softmax(logits, dim=1)) 78 | elif phi == 'nl': 79 | activated_prob = -torch.log(1 - F.softmax(logits, dim=1) + 1e-5) 80 | # activated_prob = -torch.log(alpha * (1 - F.softmax(logits, dim=1) + 1e-5) + (1 - alpha) * pred_outputs) 81 | elif phi == 'hinge': 82 | activated_prob = F.softmax(logits, dim=1) - (1 / 10) 83 | activated_prob[activated_prob < 0] = 0 84 | else: 85 | raise ValueError('Invalid phi function') 86 | 87 | loss = -F.nll_loss(activated_prob, target, reduction=reduction) 88 | return loss 89 | 90 | 91 | """ 92 | Below is the official implementation of ICML 2021 "Discriminative Complementary-Label Learning with Weighted Loss" 93 | by Yi Gao et al. 
94 | """ 95 | 96 | 97 | def non_k_softmax_loss(f, K, labels): 98 | Q_1 = 1 - F.softmax(f, 1) 99 | Q_1 = F.softmax(Q_1, 1) 100 | labels = labels.long() 101 | return F.nll_loss(Q_1.log(), labels.long()) # Equation(8) in paper 102 | 103 | 104 | def w_loss(f, K, labels): 105 | loss_class = non_k_softmax_loss(f=f, K=K, labels=labels) 106 | loss_w = w_loss_p(f=f, K=K, labels=labels) 107 | final_loss = loss_class + loss_w # Equation(11) in paper 108 | return final_loss 109 | 110 | 111 | # weighted loss 112 | def w_loss_p(f, K, labels): 113 | Q_1 = 1-F.softmax(f, 1) 114 | Q = F.softmax(Q_1, 1) 115 | q = torch.tensor(1.0) / torch.sum(Q_1, dim=1) 116 | q = q.view(-1, 1).repeat(1, K) 117 | w = torch.mul(Q_1, q) # weight 118 | w_1 = torch.mul(w, Q.log()) 119 | return F.nll_loss(w_1, labels.long()) # Equation(14) in paper 120 | 121 | 122 | def chosen_loss_c(f, K, labels, ccp, meta_method, reduction='mean'): 123 | class_loss_torch = None 124 | if meta_method == 'free': 125 | final_loss, class_loss_torch = assump_free_loss(f=f, K=K, labels=labels, ccp=ccp) 126 | elif meta_method == 'ga': 127 | final_loss, class_loss_torch = assump_free_loss(f=f, K=K, labels=labels, ccp=ccp) 128 | elif meta_method == 'nn': 129 | final_loss, class_loss_torch = non_negative_loss(f=f, K=K, labels=labels, beta=0, ccp=ccp) 130 | elif meta_method == 'forward': 131 | final_loss = forward_loss(f=f, K=K, labels=labels, reduction=reduction) 132 | elif meta_method == 'pc': 133 | final_loss = pc_loss(f=f, K=K, labels=labels) 134 | elif meta_method[:3] == "scl": 135 | final_loss = phi_loss(meta_method[4:], f, labels, reduction=reduction) 136 | elif meta_method == 'l_uw': 137 | final_loss = non_k_softmax_loss(f=f, K=K, labels=labels) 138 | elif meta_method == 'l_w': 139 | final_loss = w_loss(f=f, K=K, labels=labels) 140 | 141 | return final_loss, class_loss_torch 142 | -------------------------------------------------------------------------------- /utils_data.py: 
-------------------------------------------------------------------------------- 1 | import sys, random 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as dsets 7 | from torch.utils.data import Dataset 8 | from scipy.special import comb 9 | 10 | 11 | class AugComp(Dataset): 12 | 13 | def __init__(self, name, data, cl, tl, id, transform=None): 14 | self.name = name 15 | self.data = data 16 | self.cl = cl 17 | self.tl = tl 18 | self.id = id 19 | self.transform = transform 20 | 21 | def __getitem__(self, index): 22 | """ 23 | Args: 24 | index (int): Index 25 | Returns: 26 | tuple: (image, target) where target is index of the target class. 27 | """ 28 | img, cl_target, tl_target, idx = self.data[index], self.cl[index], self.tl[index], self.id[index] 29 | 30 | # doing this so that it is consistent with all other datasets to return a PIL Image 31 | if self.name in ["mnist", "fashion", "kuzushiji"]: 32 | img = Image.fromarray(img.numpy(), mode="L") 33 | elif self.name in ["cifar10", "cifar100"]: 34 | img = Image.fromarray(img) 35 | elif self.name in ["svhn"]: 36 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 37 | else: 38 | assert 0, "Please modify AugComp, since you are using an unsupported dataset: {}".format(self.name) 39 | 40 | if self.transform is not None: 41 | img = self.transform(img) 42 | 43 | return img, cl_target, tl_target, idx 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | 49 | def generate_compl_labels(labels): 50 | """ 51 | Generating single complementary label with uniform assumption 52 | """ 53 | # args, labels: ordinary labels 54 | K = torch.max(labels)+1 55 | candidates = np.arange(K) 56 | candidates = np.repeat(candidates.reshape(1, K), len(labels), 0) # (len(labels), K) 57 | mask = np.ones((len(labels), K), dtype=bool) 58 | mask[range(len(labels)), labels.numpy()] = False 59 | candidates_ = candidates[mask].reshape(len(labels), 
K-1) # this is the candidates without true class: (len(labels), K-1) 60 | idx = np.random.randint(0, K-1, len(labels)) 61 | complementary_labels = candidates_[np.arange(len(labels)), np.array(idx)] 62 | return complementary_labels 63 | 64 | 65 | def generate_uniform_mul_comp_labels(labels): 66 | """ 67 | Generating multiple complementary labels following the distribution described in Section 5 of 68 | "Learning with Multiple Complementary Labels" by Lei Feng et al. 69 | """ 70 | if torch.min(labels) > 1: 71 | raise RuntimeError('testError') 72 | elif torch.min(labels) == 1: 73 | labels = labels - 1 74 | 75 | K = torch.max(labels) - torch.min(labels) + 1 76 | n = labels.shape[0] 77 | cardinality = 2 ** K - 2 78 | number = torch.tensor([comb(K, i + 1) for i in range(K - 1)]) # 0 to K-2, convert list to tensor 79 | frequency_dis = number / cardinality 80 | prob_dis = torch.zeros(K - 1) # tensor of K-1 81 | for i in range(K - 1): 82 | if i == 0: 83 | prob_dis[i] = frequency_dis[i] 84 | else: 85 | prob_dis[i] = frequency_dis[i] + prob_dis[i - 1] 86 | 87 | random_n = torch.from_numpy(np.random.uniform(0, 1, n)).float() # tensor: n 88 | mask_n = torch.ones(n) # n is the number of train_data 89 | partialY = torch.ones(n, K) 90 | temp_num_comp_train_labels = 0 # save temp number of comp train_labels 91 | 92 | for j in range(n): # for each instance 93 | # if j % 1000 == 0: 94 | # print("current index:", j) 95 | for jj in range(K - 1): # 0 to K-2 96 | if random_n[j] <= prob_dis[jj] and mask_n[j] == 1: 97 | temp_num_comp_train_labels = jj + 1 # decide the number of complementary train_labels 98 | mask_n[j] = 0 99 | 100 | candidates = torch.from_numpy(np.random.permutation(K.item())) # because K is tensor type 101 | candidates = candidates[candidates != labels[j]] 102 | temp_comp_train_labels = candidates[:temp_num_comp_train_labels] 103 | 104 | for kk in range(len(temp_comp_train_labels)): 105 | partialY[j, temp_comp_train_labels[kk]] = 0 # fulfill the partial label matrix 106 
| return partialY 107 | 108 | 109 | def generate_mul_comp_labels(data, labels, s): 110 | """ 111 | Generating multiple complementary labels given the fixed size s of complementary label set \bar{Y} of each instance 112 | by "Learning with Multiple Complementary Labels" by Lei Feng et al. 113 | """ 114 | k = torch.max(labels) + 1 115 | n = labels.shape[0] 116 | index_ins = torch.arange(n) # torch type 117 | realY = torch.zeros(n, k) 118 | realY[index_ins, labels] = 1 119 | partialY = torch.ones(n, k) 120 | 121 | labels_hat = labels.clone().numpy() 122 | candidates = np.repeat(np.arange(k).reshape(1, k), len(labels_hat), 0) # candidate labels without true class 123 | mask = np.ones((len(labels_hat), k), dtype=bool) 124 | for i in range(s): 125 | mask[np.arange(n), labels_hat] = False 126 | candidates_ = candidates[mask].reshape(n, k - 1 - i) 127 | idx = np.random.randint(0, k - 1 - i, n) 128 | comp_labels = candidates_[np.arange(n), np.array(idx)] 129 | partialY[index_ins, torch.from_numpy(comp_labels)] = 0 130 | labels_hat = comp_labels 131 | return partialY 132 | 133 | 134 | def class_prior(complementary_labels): 135 | # p(\bar{y}) 136 | return np.bincount(complementary_labels) / len(complementary_labels) 137 | 138 | 139 | def prepare_data(args): 140 | dataset, batch_size = args.dataset, args.batch_size 141 | if dataset == "mnist": 142 | ordinary_train_dataset = dsets.MNIST(root='./data/MNIST', train=True, transform=transforms.ToTensor(), download=True) 143 | test_dataset = dsets.MNIST(root='./data/MNIST', train=False, transform=transforms.ToTensor()) 144 | input_dim, input_channel, num_classes = 28 * 28, 1, 10 145 | elif dataset == "kuzushiji": 146 | ordinary_train_dataset = dsets.KMNIST(root='./data/KMNIST', train=True, transform=transforms.ToTensor(), download=True) 147 | test_dataset = dsets.KMNIST(root='./data/KMNIST', train=False, transform=transforms.ToTensor()) 148 | input_dim, input_channel, num_classes = 28 * 28, 1, 10 149 | elif dataset == "fashion": 150 
| ordinary_train_dataset = dsets.FashionMNIST(root='./data/FashionMNIST', train=True, transform=transforms.ToTensor(), download=True) 151 | test_dataset = dsets.FashionMNIST(root='./data/FashionMNIST', train=False, transform=transforms.ToTensor()) 152 | input_dim, input_channel, num_classes = 28 * 28, 1, 10 153 | elif dataset == "cifar10": 154 | ordinary_train_dataset = dsets.CIFAR10(root='./data/CIFAR10', train=True, transform=transforms.ToTensor(), download=True) 155 | test_dataset = dsets.CIFAR10(root='./data/CIFAR10', train=False, transform=transforms.ToTensor()) 156 | input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 10 157 | elif dataset == "svhn": 158 | ordinary_train_dataset = dsets.SVHN(root='./data/SVHN', split='train', download=True, transform=transforms.ToTensor()) 159 | test_dataset = dsets.SVHN(root='./data/SVHN', split='test', download=True, transform=transforms.ToTensor()) 160 | ordinary_train_dataset.targets = ordinary_train_dataset.labels 161 | ordinary_train_dataset.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] 162 | input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 10 163 | elif dataset == "cifar100": 164 | ordinary_train_dataset = dsets.CIFAR100(root='./data/CIFAR100', train=True, transform=transforms.ToTensor(), download=True) 165 | test_dataset = dsets.CIFAR100(root='./data/CIFAR100', train=False, transform=transforms.ToTensor()) 166 | input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 100 167 | train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=batch_size, shuffle=False) 168 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) 169 | return train_loader, test_loader, ordinary_train_dataset, test_dataset, num_classes, input_dim, input_channel 170 | 171 | 172 | def prepare_train_loaders(args, ordinary_train_dataset, data_aug=None): 173 | """ 174 | ccp is only 
used for "free", "ga", "nn" 175 | partialY is used for multiple complementary labels 176 | """ 177 | batch_size, cl_num = args.batch_size, args.cl_num 178 | # load raw_data if data_aug is not None 179 | if data_aug is not None: 180 | data, labels = ordinary_train_dataset.data, ordinary_train_dataset.targets 181 | bs, K = len(ordinary_train_dataset.data), len(ordinary_train_dataset.classes) 182 | labels = torch.LongTensor(labels) 183 | else: 184 | full_train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=len(ordinary_train_dataset.data), shuffle=False) 185 | for i, (data, labels) in enumerate(full_train_loader): 186 | K = torch.max(labels) + 1 187 | bs = labels.size(0) 188 | 189 | complementary_labels = generate_compl_labels(labels) 190 | ccp = class_prior(complementary_labels) 191 | if cl_num == 0: 192 | partialY = generate_uniform_mul_comp_labels(labels) 193 | elif cl_num == 1: 194 | partialY = torch.ones(bs, K).scatter_(1, torch.LongTensor(complementary_labels).unsqueeze(1), 0) 195 | # ema = (torch.ones(bs, K).scatter_(1, torch.LongTensor(complementary_labels).unsqueeze(1), 0)) / (K - 1) 196 | else: # 2-9 197 | partialY = generate_mul_comp_labels(data, labels, cl_num) 198 | ema = partialY / partialY.sum(1).unsqueeze(1) 199 | id = torch.arange(bs) 200 | if data_aug is not None: 201 | complementary_dataset = AugComp(name=args.dataset, data=data, cl=torch.from_numpy(complementary_labels).long(), tl=labels, id=id, transform=data_aug) 202 | else: 203 | complementary_dataset = torch.utils.data.TensorDataset(data, torch.from_numpy(complementary_labels).long(), labels, id) 204 | complementary_train_loader = torch.utils.data.DataLoader(dataset=complementary_dataset, batch_size=batch_size, shuffle=True) 205 | 206 | return complementary_train_loader, ccp, partialY, ema 207 | -------------------------------------------------------------------------------- /main.py: 
"""Entry point for ATCL: adversarial training with complementary labels.

Supports a one-stage framework (``--framework one_stage``, mode "atcl") and a
two-stage baseline (``--framework two_stage``: complementary-label learning in
mode "cl", then adversarial training on its pseudo labels in mode "at").
"""
import argparse
import math
import os
import pprint as pp
import random
import time
import warnings

import torchvision

from utils_data import *
from utils_algo import *
from utils_mcl_loss import *
from models import *
from attack_generator import *
from utils_func import *

# Colour-image datasets that use SGD, lr warm-up/decay and CIFAR-style hyper-parameters.
_LARGE_DATASETS = ["cifar10", "svhn", "cifar100"]


def _record_grad_stats(args, model, first_layer_grad, last_layer_grad):
    """Append the L2 norms of the first / last layer weight gradients of ``model``.

    ``model`` is expected to be wrapped in ``torch.nn.DataParallel`` (see
    ``create_model``), hence the ``.module`` access.
    """
    if args.model == "cnn":
        first_layer = model.module.feature_extractor.conv1
        last_layer = model.module.classifier.fc3
    elif args.model in ["densenet", "resnet18", "wrn"]:
        first_layer, last_layer = model.module.conv1, model.module.linear
    else:
        # Layer names are unknown for this architecture: make the gap visible
        # instead of silently skipping the statistics.
        warnings.warn("grad stats are not collected for model '{}'; please modify grad stat code!".format(args.model))
        return
    first_layer_grad.append(first_layer.weight.grad.norm(p=2).cpu().item())
    last_layer_grad.append(last_layer.weight.grad.norm(p=2).cpu().item())


def adversarial_train(args, model, epochs, mode="atcl", seed=1):
    """Train ``model`` for ``epochs`` epochs, evaluating and checkpointing after each one.

    Args:
        args: parsed command-line arguments.
        model: the network to train (DataParallel-wrapped by ``create_model``).
        epochs: number of training epochs; must be >= 1.
        mode: ``"cl"``   -- complementary-label learning on clean images;
              ``"at"``   -- adversarial training on pseudo labels predicted by the
                            CL model saved by a previous ``"cl"`` run with the same seed;
              ``"atcl"`` -- one-stage adversarial training with complementary labels.
        seed: only used to name checkpoint / image / plot / log files.

    Relies on module-level globals created in ``__main__``: ``device``,
    ``complementary_train_loader``, ``train_loader``, ``test_loader``, ``K``,
    ``input_dim``, ``input_channel``, ``ccp``, ``partialY`` and ``ema``.

    Returns:
        mode ``"cl"``: mean natural test accuracy over the last 10 epochs.
        otherwise: ``([last_nat, last_pgd20, last_cw], [best_nat, best_pgd20, best_cw])``,
        where "best" is the epoch with the highest PGD-20 test accuracy.

    Raises:
        ValueError: if ``epochs < 1``, ``mode`` is unknown, or a single-complementary-label
            method is used with ``args.cl_num != 1``.
    """
    if epochs < 1:
        raise ValueError("adversarial_train needs epochs >= 1, got {}".format(epochs))
    print(">> Current mode: {} for {} epochs".format(mode, epochs))
    large_dataset = args.dataset in _LARGE_DATASETS
    # Base learning rate of this mode; the warm-up below ramps towards it.
    base_lr = args.cl_lr if mode == "cl" else args.at_lr
    lr = base_lr
    loss_fn = create_loss_fn(args)
    best_nat_acc, best_pgd20_acc, best_cw_acc, best_epoch = 0, 0, 0, 0
    nature_train_acc_list, nature_test_acc_list, pgd20_acc_list, cw_acc_list = [], [], [], []
    first_layer_grad, last_layer_grad = [], []

    if mode == "cl":
        if large_dataset:
            optimizer = torch.optim.SGD(model.parameters(), lr=args.cl_lr, momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay, lr=args.cl_lr)
    elif mode in ("at", "atcl"):
        if large_dataset:
            optimizer = torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum)
    else:
        raise ValueError("Unknown mode: {}".format(mode))

    if mode == "at":
        # Second stage of the two-stage framework: a frozen CL model provides the labels.
        cl_model = create_model(args, input_dim, input_channel, K)
        checkpoint = torch.load(os.path.join(args.out_dir, "cl_best_checkpoint_seed{}.pth.tar".format(seed)))
        cl_model.load_state_dict(checkpoint['state_dict'])
        cl_model.eval()
        print(">> Load the CL model with train acc: {}, test acc: {}, epoch {}".format(
            checkpoint['train_acc'], checkpoint['test_acc'], checkpoint['epoch']))

    for epoch in range(epochs):
        correct, total = 0, 0  # pseudo-label accuracy bookkeeping ("atcl" only)
        if mode in ("cl", "atcl") and epoch < 5 and large_dataset:
            # Linear lr warm-up over the first 5 epochs towards this mode's own
            # base lr (cl_lr for "cl", at_lr for "atcl").
            lr = (epoch + 1) / 5 * base_lr
        if mode == "cl":
            lr = cl_lr_schedule(lr, epoch + 1, optimizer)
        else:
            lr = adv_lr_schedule(lr, epoch + 1, optimizer)

        for batch_idx, (images, cl_labels, true_labels, idx) in enumerate(complementary_train_loader):
            images, cl_labels, true_labels = images.to(device), cl_labels.to(device), true_labels.to(device)

            if mode == "cl":
                x_adv = images
            elif mode == "at":
                pseudo_labels, _ = get_pred(cl_model, images.detach())
                pseudo_labels = pseudo_labels.to(device)
                x_adv = pgd(model, images, pseudo_labels, args.epsilon, args.step_size, args.num_steps,
                            loss_fn="cent", category="Madry", rand_init=True, num_classes=K)
            else:  # "atcl"
                _, prob = get_pred(model, images.detach())
                epsilon, step_size, num_steps = at_param_schedule(args, epoch + 1)
                alpha = atcl_scheduler(args, idx, prob, epsilon, partialY, epoch + 1)
                # Pseudo labels: argmax of the EMA of predictions (zeroed where partialY is 0).
                _, pseudo_labels = torch.max(ema[idx], 1)
                # pseudo_labels = torch.multinomial(ema[idx], 1).squeeze(1)
                correct += (pseudo_labels == true_labels).sum().item()
                total += pseudo_labels.size(0)
                if epsilon != 0:
                    # Ablation variants:
                    # x_adv = pgd(model, images, true_labels, epsilon, step_size, num_steps, loss_fn="cent", category="Madry", rand_init=True, num_classes=K)  # max_ce_tl
                    # x_adv = pgd(model, images, None, epsilon, step_size, num_steps, loss_fn="kl", category="trades", rand_init=True, num_classes=K)  # max_trades
                    x_adv = cl_adv(args, model, images, cl_labels, epsilon, step_size, num_steps, idx, ccp, partialY,
                                   pseudo_labels, alpha, loss_fn, category="Madry", rand_init=True)  # max_cl
                else:
                    x_adv = images

            if batch_idx == 0:
                torchvision.utils.save_image(
                    x_adv, os.path.join(args.out_dir, "x_adv_seed_{}_epoch_{}.jpg".format(seed, epoch + 1)))

            model.train()
            optimizer.zero_grad()
            logit = model(x_adv)
            if mode == "at":
                loss = nn.CrossEntropyLoss(reduction="mean")(logit, pseudo_labels)
            elif args.method in ['exp', 'log']:
                loss = loss_fn(logit, partialY[idx].float())
            elif args.method in ['mae', 'mse', 'ce', 'gce', 'phuber_ce']:
                loss = unbiased_estimator(loss_fn, logit, partialY[idx].float())
            elif args.method in ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'l_uw', 'l_w']:
                if args.cl_num != 1:
                    raise ValueError("method '{}' requires --cl_num 1, got {}".format(args.method, args.cl_num))
                loss, _ = chosen_loss_c(f=logit, K=K, labels=cl_labels, ccp=ccp, meta_method=args.method)
            elif args.method in ["log_ce", "exp_ce"]:
                loss = loss_fn(logit, partialY[idx].float(), pseudo_labels, alpha)
            loss.backward()
            _record_grad_stats(args, model, first_layer_grad, last_layer_grad)
            optimizer.step()

        # Evaluations
        model.eval()
        train_nat_acc = accuracy_check(loader=train_loader, model=model)
        test_nat_acc = accuracy_check(loader=test_loader, model=model)
        _, test_pgd20_acc = eval_robust(model, test_loader, perturb_steps=20, epsilon=args.epsilon,
                                        step_size=args.step_size, loss_fn="cent", category="Madry",
                                        random=True, num_classes=K)
        _, test_cw_acc = eval_robust(model, test_loader, perturb_steps=30, epsilon=args.epsilon,
                                     step_size=args.step_size, loss_fn="cw", category="Madry",
                                     random=True, num_classes=K)
        nature_train_acc_list.append(train_nat_acc)
        nature_test_acc_list.append(test_nat_acc)
        pgd20_acc_list.append(test_pgd20_acc)
        cw_acc_list.append(test_cw_acc)
        if mode == "atcl":
            print(round(correct / total * 100, 2), alpha, epsilon, step_size, num_steps)
        print('Epoch: [%d | %d] | Learning Rate: %f | Natural Train Acc %.4f | Natural Test Acc %.4f | PGD20 Test Acc %.4f | CW Test Acc %.4f |\n' % (
            epoch + 1, epochs, lr, train_nat_acc, test_nat_acc, test_pgd20_acc, test_cw_acc))

        # Save the best & last checkpoint
        if mode == "cl":
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'train_acc': train_nat_acc,
                'test_acc': test_nat_acc,
                'optimizer': optimizer.state_dict(),
            }
            if test_nat_acc > best_nat_acc:
                best_epoch = epoch + 1
                best_nat_acc = test_nat_acc
                torch.save(state, os.path.join(args.out_dir, "cl_best_checkpoint_seed{}.pth.tar".format(seed)))
            torch.save(state, os.path.join(args.out_dir, "cl_checkpoint_seed{}.pth.tar".format(seed)))
        else:
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'test_nat_acc': test_nat_acc,
                'test_pgd20_acc': test_pgd20_acc,
                'test_cw_acc': test_cw_acc,
                'optimizer': optimizer.state_dict(),
            }
            if test_pgd20_acc > best_pgd20_acc:
                best_epoch = epoch + 1
                best_nat_acc, best_pgd20_acc, best_cw_acc = test_nat_acc, test_pgd20_acc, test_cw_acc
                torch.save(state, os.path.join(args.out_dir, "best_checkpoint_seed{}.pth.tar".format(seed)))
            torch.save(state, os.path.join(args.out_dir, "checkpoint_seed{}.pth.tar".format(seed)))

    epoch_axis = list(range(epochs))
    if mode == "cl":
        print(nature_train_acc_list)
        print(nature_test_acc_list)
        print(">> Best test acc({}): {}".format(best_epoch, max(nature_test_acc_list)))
        print(">> AVG test acc of last 10 epochs: {}".format(np.mean(nature_test_acc_list[-10:])))
        show([epoch_axis] * 2, [nature_train_acc_list, nature_test_acc_list], label=["train acc", "test acc"],
             title=args.dataset, xdes="Epoch", ydes="Accuracy",
             path=os.path.join(args.out_dir, "cl_acc_seed{}.png".format(seed)))
        return np.mean(nature_test_acc_list[-10:])

    print(nature_test_acc_list)
    print(pgd20_acc_list)
    print(cw_acc_list)
    print(">> Finished Adv Training: Natural Test Acc | Last_checkpoint %.4f | Best_checkpoint(%.1f) %.4f |\n" % (test_nat_acc, best_epoch, best_nat_acc))
    print(">> Finished Adv Training: PGD20 Test Acc | Last_checkpoint %.4f | Best_checkpoint %.4f |\n" % (test_pgd20_acc, best_pgd20_acc))
    print(">> Finished Adv Training: CW Test Acc | Last_checkpoint %.4f | Best_checkpoint %.4f |\n" % (test_cw_acc, best_cw_acc))
    show([epoch_axis] * 3, [nature_test_acc_list, pgd20_acc_list, cw_acc_list],
         label=["nature test acc", "pgd20 acc", "cw acc"], title=args.dataset, xdes="Epoch",
         ydes="Test Accuracy", path=os.path.join(args.out_dir, "adv_test_acc_seed{}.png".format(seed)))
    # Close the log file deterministically instead of leaking the handle.
    with open(os.path.join(args.out_dir, "grad_seed{}.out".format(seed)), "a+") as grad_file:
        print("first_layer: \n{} \nlast_layer: \n{}".format(first_layer_grad, last_layer_grad), file=grad_file)

    # Auto-attack on the last checkpoint, then on the best one.
    aa_eval(args, model, filename="aa_last.txt")
    if best_epoch > 0:
        # Only load a best checkpoint that was actually written during this run.
        best_checkpoint = torch.load(os.path.join(args.out_dir, "best_checkpoint_seed{}.pth.tar".format(seed)))
        model.load_state_dict(best_checkpoint['state_dict'])
        aa_eval(args, model, filename="aa_best.txt")
    else:
        print(">> No best checkpoint was saved (PGD20 test acc never exceeded 0); skip AutoAttack on it.")

    return [test_nat_acc, test_pgd20_acc, test_cw_acc], [best_nat_acc, best_pgd20_acc, best_cw_acc]


def aa_eval(args, model, filename):
    """
    AutoAttack evaluation - pip install git+https://github.com/fra31/auto-attack

    Runs the standard Linf AutoAttack (eps = ``args.epsilon``) on the first
    10000 test samples of the global ``test_loader`` and logs to
    ``args.out_dir/filename``.
    """
    from autoattack import AutoAttack
    model.eval()
    version, norm, individual, n_ex = "standard", "Linf", False, 10000
    adversary = AutoAttack(model, norm=norm, eps=args.epsilon, log_path=os.path.join(args.out_dir, filename), version=version)

    # Collect images and labels in ONE pass so they stay aligned even if the loader shuffles.
    image_batches, label_batches = [], []
    for x, y in test_loader:
        image_batches.append(x)
        label_batches.append(y)
    x_test = torch.cat(image_batches, 0)
    y_test = torch.cat(label_batches, 0)

    # Run the attack; results are written to the AutoAttack log file.
    with torch.no_grad():
        if not individual:
            adversary.run_standard_evaluation(x_test[:n_ex], y_test[:n_ex], bs=500)
        else:
            # individual version, each attack is run on all test points
            adversary.run_standard_evaluation_individual(x_test[:n_ex], y_test[:n_ex], bs=500)


def at_param_schedule(args, epoch):
    """Return ``(epsilon, step_size, num_steps)`` of the attack for the 1-based ``epoch``.

    - ``epoch <= warmup_epoch``: ``(0, 0, 0)``, i.e. no attack.
    - the next ``sch_epoch`` epochs: epsilon follows ``args.scheduler`` ("linear",
      "cosine" or "none"), step size is scaled by ``eps / args.epsilon``, and the
      full ``num_steps`` are used.
    - afterwards: the full ``args.epsilon, args.step_size, args.num_steps``.
    """
    if epoch <= args.warmup_epoch:
        return 0, 0, 0
    if epoch <= (args.warmup_epoch + args.sch_epoch):
        if args.scheduler == "linear":
            eps = min(args.epsilon * ((epoch - args.warmup_epoch) / args.sch_epoch), args.epsilon)
        elif args.scheduler == "cosine":
            eps = 1/2 * (1 - math.cos(math.pi * min(((epoch - args.warmup_epoch) / args.sch_epoch), 1))) * args.epsilon
        elif args.scheduler == "none":
            eps = args.epsilon
        else:
            raise ValueError("Unknown epsilon scheduler: {}".format(args.scheduler))
        # return eps, eps, 1
        # return eps, args.step_size, math.ceil(args.num_steps/args.epsilon*eps)
        # Guard against division by zero when the attack is disabled (epsilon == 0).
        step_size = args.step_size / args.epsilon * eps if args.epsilon != 0 else 0
        return eps, step_size, args.num_steps
    return args.epsilon, args.step_size, args.num_steps


def atcl_scheduler(args, id, prob, epsilon, partialY, epoch):
    """Return the mixing weight ``alpha`` and update the global ``ema`` rows ``id`` in place.

    ``alpha`` is 1 during warm-up and decays linearly to 0 over ``sch_epoch``
    epochs (a hard switch when ``sch_epoch == 0``). After epoch 5, and only while
    the current ``epsilon`` is below half of ``args.epsilon``, the EMA of the
    predictions ``prob`` is refreshed with momentum 0.9. The EMA is then
    multiplied by ``partialY[id]``, which zeroes the entries where partialY is 0.
    """
    if args.sch_epoch == 0:
        alpha = 1 if epoch <= args.warmup_epoch else 0
    else:
        alpha = min(max(1 - (epoch - args.warmup_epoch) / args.sch_epoch, 0), 1)

    if epoch > 5 and epsilon < args.epsilon / 2:
        ema[id] = 0.9 * ema[id] + (1 - 0.9) * prob

    ema[id] = ema[id] * partialY[id]  # reset to 0 for cls

    return alpha


def cl_lr_schedule(lr, epoch, optimizer):
    """Apply ``lr`` unchanged to every param group (no decay for CL) and return it."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def adv_lr_schedule(lr, epoch, optimizer):
    """Step-decay ``lr`` for adversarial training, apply it to ``optimizer`` and return it.

    On cifar10 / svhn / cifar100 the lr is divided by 10 at epochs
    ``30 + warmup_epoch`` and ``60 + warmup_epoch``; other datasets (mnist,
    fashion, kuzushiji) keep a constant lr. Reads the module-level ``args``.
    """
    if args.dataset in _LARGE_DATASETS and epoch in (30 + args.warmup_epoch, 60 + args.warmup_epoch):
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def accuracy_check(loader, model):
    """Return the top-1 accuracy (percent, rounded to 2 decimals) of ``model`` on ``loader``."""
    model.eval()
    correct, num_samples = 0, 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in loader:
            labels, images = labels.to(device), images.to(device)
            outputs = model(images)
            _, predicted = torch.max(F.softmax(outputs, dim=1), 1)
            correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
    return round(100 * correct / num_samples, 2)


def create_loss_fn(args):
    """Return the loss function for ``args.method``, or ``None`` for methods handled by ``chosen_loss_c``."""
    if args.method == 'mae':
        return mae_loss
    if args.method == 'mse':
        return mse_loss
    if args.method == 'ce':
        return ce_loss
    if args.method == 'gce':
        return gce_loss
    if args.method == 'phuber_ce':
        return phuber_ce_loss
    if args.method == 'log':
        return log_loss
    if args.method == 'exp':
        return exp_loss
    if args.method == 'log_ce':
        return log_ce_loss
    if args.method == "exp_ce":
        return exp_ce_loss
    return None


def create_model(args, input_dim, input_channel, K):
    """Build ``args.model`` with ``K`` outputs, move it to ``device`` and wrap it in DataParallel.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model == 'mlp':
        model = mlp_model(input_dim=input_dim, hidden_dim=500, output_dim=K)
    elif args.model == 'linear':
        model = linear_model(input_dim=input_dim, output_dim=K)
    elif args.model == 'cnn':
        model = SmallCNN()
    elif args.model == 'resnet18':
        model = ResNet18(input_channel=input_channel, num_classes=K)
    elif args.model == 'densenet':
        model = densenet(input_channel=input_channel, num_classes=K)
    elif args.model == 'preact_resnet18':
        model = preactresnet18(input_channel=input_channel, num_classes=K)
    elif args.model == "wrn":
        model = Wide_ResNet_Madry(depth=32, num_classes=K, widen_factor=10, dropRate=0.0, input_channel=input_channel)  # WRN-32-10
    else:
        raise ValueError("Unknown model: {}".format(args.model))

    display_num_param(model)
    model = model.to(device)
    model = torch.nn.DataParallel(model)

    return model


def get_pred(cl_model, data):
    """Return ``(argmax predictions, softmax probabilities)`` of ``cl_model`` on ``data``.

    Side effect: puts ``cl_model`` into eval mode.
    """
    cl_model.eval()
    with torch.no_grad():
        data = data.to(device)
        outputs = cl_model(data)
        _, predicted = torch.max(outputs, 1)

    return predicted, torch.softmax(outputs, dim=1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Learning with Complementary Labels')
    parser.add_argument('--cl_lr', type=float, default=1e-3, help='learning rate for complementary learning')
    parser.add_argument('--at_lr', type=float, default=1e-2, help='learning rate for adversarial training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch_size of ordinary labels.')
    parser.add_argument('--cl_num', type=int, default=1, help='(1-9): the number of complementary labels of each data; (0): mul-cls data distribution of ICML2020')
    parser.add_argument('--dataset', type=str, default="mnist", choices=['mnist', 'kuzushiji', 'fashion', 'cifar10', 'svhn', 'cifar100'],
                        help="dataset, choose from mnist, kuzushiji, fashion, cifar10, svhn, cifar100")
    parser.add_argument('--framework', type=str, default='one_stage', choices=['one_stage', 'two_stage'])
    parser.add_argument('--method', type=str, default='log', choices=['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp',
                        'scl_nl', 'mae', 'mse', 'ce', 'gce', 'phuber_ce', 'log', 'exp', 'l_uw', 'l_w', 'log_ce', 'exp_ce'])
    parser.add_argument('--model', type=str, default='cnn', choices=['linear', 'mlp', 'cnn', 'resnet18', 'densenet', 'preact_resnet18', 'wrn'], help='model name')
    parser.add_argument('--cl_epochs', default=0, type=int, help='number of epochs for cl learning')
    parser.add_argument('--adv_epochs', default=100, type=int, help='number of epochs for adv')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--seed', type=int, nargs='+', default=[1, 2, 3], help='random seed')
    parser.add_argument('--out_dir', type=str, default='./ATCL_result', help='dir of output')
    # for adv training
    parser.add_argument('--epsilon', type=float, default=0.3, help='perturbation bound')
    parser.add_argument('--num_steps', type=int, default=40, help='maximum perturbation step K')
    parser.add_argument('--step_size', type=float, default=0.01, help='step size')
    parser.add_argument('--scheduler', type=str, default="none", choices=['linear', 'cosine', 'none'], help='epsilon scheduler')
    parser.add_argument('--sch_epoch', type=int, default=0, help='scheduler epoch')
    parser.add_argument('--warmup_epoch', type=int, default=0, help='warmup epoch for exponential moving average')
    args = parser.parse_args()

    # Hardcoded hyper-parameters for the colour-image datasets.
    if args.dataset in _LARGE_DATASETS:
        args.weight_decay, args.batch_size = 5e-4, 128
        args.epsilon, args.num_steps, args.step_size = 8/255, 10, 2/255

    pp.pprint(vars(args))

    # NOTE: the functions above read these module-level globals (device, loaders,
    # K, input_dim, input_channel, ccp, partialY, ema), so they must stay at module scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    last_nature, last_pgd20, last_cw, best_nature, best_pgd20, best_cw = [], [], [], [], [], []
    for seed in args.seed:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Store path
        os.makedirs(args.out_dir, exist_ok=True)

        data_aug = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]) if args.dataset in _LARGE_DATASETS else None
        train_loader, test_loader, ordinary_train_dataset, test_dataset, K, input_dim, input_channel = prepare_data(args)
        complementary_train_loader, ccp, partialY, ema = prepare_train_loaders(args, ordinary_train_dataset=ordinary_train_dataset, data_aug=data_aug)
        partialY, ema = partialY.to(device), ema.to(device)

        model = create_model(args, input_dim, input_channel, K)
        if args.framework == "two_stage":
            adversarial_train(args, model, args.cl_epochs, mode="cl", seed=seed)
            model = create_model(args, input_dim, input_channel, K)
            last_res, best_res = adversarial_train(args, model, args.adv_epochs, mode="at", seed=seed)
        else:
            last_res, best_res = adversarial_train(args, model, args.adv_epochs, mode="atcl", seed=seed)

        last_nature.append(last_res[0])
        last_pgd20.append(last_res[1])
        last_cw.append(last_res[2])
        best_nature.append(best_res[0])
        best_pgd20.append(best_res[1])
        best_cw.append(best_res[2])

    summary = [("Last Nature", last_nature), ("Last PGD20", last_pgd20), ("Last CW", last_cw),
               ("Best Nature", best_nature), ("Best PGD20", best_pgd20), ("Best CW", best_cw)]
    for _, values in summary:
        print(values)
    # Raw strings: "\p" is an invalid escape sequence (SyntaxWarning on Python 3.12+);
    # the printed text "$\pm$" is unchanged.
    for name, values in summary:
        print(r">> {}: {}($\pm${})".format(name, round(np.mean(values), 2), round(np.std(values, ddof=0), 2)))