├── imgs
├── cl.png
├── NeurIPS2022_1958_Poster.pdf
└── NeurIPS2022_1958_Poster.png
├── models
├── __init__.py
├── wrn_madry.py
├── base_models.py
├── preactresnet.py
├── resnet.py
└── densenet.py
├── .gitignore
├── LICENSE
├── utils_func.py
├── utils_mcl_loss.py
├── README.md
├── attack_generator.py
├── utils_algo.py
├── utils_data.py
└── main.py
/imgs/cl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/cl.png
--------------------------------------------------------------------------------
/imgs/NeurIPS2022_1958_Poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/NeurIPS2022_1958_Poster.pdf
--------------------------------------------------------------------------------
/imgs/NeurIPS2022_1958_Poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RoyalSkye/ATCL/HEAD/imgs/NeurIPS2022_1958_Poster.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .base_models import *
3 | from .densenet import *
4 | from .wrn_madry import *
5 | from .preactresnet import *
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # macOS
7 | .DS_Store
8 |
9 | # Jupyter Notebook
10 | .ipynb_checkpoints
11 |
12 | .idea/
13 |
14 | # private files
15 | utils_plot*
16 | tmp/
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jianan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/wrn_madry.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
class BasicBlock(nn.Module):
    """Pre-activation wide-ResNet block: (BN -> ReLU -> 3x3 conv) twice plus a shortcut.

    When ``in_planes != out_planes`` the pre-activated input is projected by a
    strided 1x1 convolution (``convShortcut``) before being added to the
    residual branch; otherwise the raw input is added unchanged.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first 3x3 convolution and of the projection shortcut.
        dropRate: dropout probability between the two convolutions (0 disables it).
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A projection shortcut is only needed when the channel count changes.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        # Pre-activation shared by the residual branch (and the projection, if any).
        preact = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(preact)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        # Identity shortcut uses the raw input; the projection uses the pre-activated one.
        shortcut = x if self.equalInOut else self.convShortcut(preact)
        return torch.add(shortcut, out)
33 |
34 |
class NetworkBlock(nn.Module):
    """One stage of a wide ResNet: ``nb_layers`` blocks applied in sequence.

    Only the first block maps ``in_planes -> out_planes`` and uses ``stride``;
    every following block maps ``out_planes -> out_planes`` with stride 1.

    Args:
        nb_layers: number of blocks (truncated with ``int``, so floats are accepted).
        in_planes: input channels of the stage.
        out_planes: output channels of the stage.
        block: block class called as ``block(in_planes, out_planes, stride, dropRate)``.
        stride: stride of the first block.
        dropRate: dropout probability forwarded to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Instantiate the blocks of this stage and wrap them in ``nn.Sequential``."""
        layers = []
        for i in range(int(nb_layers)):
            first = (i == 0)
            layers.append(block(in_planes if first else out_planes,
                                out_planes,
                                stride if first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
48 |
49 |
class Wide_ResNet_Madry(nn.Module):
    """Wide ResNet (WRN-``depth``-``widen_factor``) built from pre-activation ``BasicBlock``s.

    Args:
        depth: total depth; must be ``6 * n + 2`` so each of the three stages
            gets ``n`` blocks.
        num_classes: number of output logits.
        widen_factor: channel multiplier for the three stages (16/32/64 * widen_factor).
        dropRate: dropout probability inside each block.
        input_channel: number of input image channels.

    Raises:
        ValueError: if ``depth`` is not of the form ``6 * n + 2``.
    """

    def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0, input_channel=3):
        super(Wide_ResNet_Madry, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        if (depth - 2) % 6 != 0:
            raise ValueError("depth must be of the form 6 * n + 2, got {}".format(depth))
        n = (depth - 2) // 6  # blocks per stage
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(input_channel, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # three stages; the last two halve the spatial resolution
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.linear = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init scaled by fan-out (kernel area * output channels).
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Global average pooling: equal to avg_pool2d(out, 8) for 32x32 inputs
        # (8x8 final maps) and also valid for other resolutions such as 28x28.
        out = F.adaptive_avg_pool2d(out, 1)
        # Flatten per sample so the batch dimension can never be mixed up.
        out = out.view(out.size(0), self.nChannels)
        return self.linear(out)
93 |
--------------------------------------------------------------------------------
/utils_func.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 |
def show(x, y, label, title, xdes, ydes, path, x_scale="linear", dpi=300):
    """Plot several curves on one figure and save it to ``path``.

    Args:
        x, y: sequences of equal length; ``x[i]`` / ``y[i]`` are the coordinates of curve ``i``.
        label: legend labels. Curve ``i`` with ``i < len(label)`` is labelled
            ``label[i]``; later curves are unlabelled and reuse the colour of
            curve ``i % len(label)``.
        title, xdes, ydes: figure title and x/y axis descriptions.
        path: output file passed to ``plt.savefig``.
        x_scale: matplotlib x-axis scale, e.g. "linear" or "log".
        dpi: resolution of the saved image.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length.

    Note: ``plt.style.use`` changes matplotlib's global style for the rest of the process.
    """
    plt.style.use('fivethirtyeight')  # bmh, fivethirtyeight, Solarize_Light2
    plt.figure(figsize=(10, 8))
    colors = ['tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan',
              'tab:gray', 'tab:brown', 'tab:purple', 'tab:olive', 'tab:pink']

    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of curves")
    for i, (xs, ys) in enumerate(zip(x, y)):
        if i < len(label):
            plt.plot(xs, ys, color=colors[i % len(colors)], label=label[i], linewidth=1.5)
        else:
            # Extra curves share the colour of the matching labelled curve; with no
            # labels at all, simply cycle through the palette.
            base = i % len(label) if label else i
            plt.plot(xs, ys, color=colors[base % len(colors)], linewidth=1.5)

    plt.gca().get_xaxis().get_major_formatter().set_scientific(False)
    plt.gca().get_yaxis().get_major_formatter().set_scientific(False)
    plt.xlabel(xdes, fontsize=24)
    plt.ylabel(ydes, fontsize=24)

    plt.title(title, fontsize=24)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(loc='lower right', fontsize=16)
    plt.xscale(x_scale)

    plt.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
    plt.close("all")
37 |
38 |
def display_num_param(net):
    """Print how many parameters ``net`` has, both as a raw count and in millions."""
    total = sum(p.numel() for p in net.parameters())
    print('There are {} ({:.2f} million) parameters in this neural network'.format(total, total / 1e6))
44 |
45 |
def stat(x_to_mcls, x_to_tls):
    """Summarise the quality of the complementary labels assigned to a dataset.

    Args:
        x_to_mcls: mapping from example key to its collection of complementary labels.
        x_to_tls: mapping from example key to its true label.

    Returns:
        ``(wrong_ratio, avg_correct)``:
        - ``wrong_ratio`` is the fraction of examples whose complementary labels
          contain the true label.
        - ``avg_correct`` is the total number of complementary labels, minus one
          per such example, divided by the number of examples.

        Both values are ``0.0`` for an empty mapping.
    """
    num_examples = len(x_to_mcls)
    if num_examples == 0:
        return 0.0, 0.0

    wrong_cl_count, total_cl_count = 0, 0
    for k, cls in x_to_mcls.items():
        total_cl_count += len(cls)
        # A complementary label equal to the true label is a wrong one.
        if x_to_tls[k] in cls:
            wrong_cl_count += 1
    correct_cl_count = total_cl_count - wrong_cl_count

    return wrong_cl_count / num_examples, correct_cl_count / num_examples
57 |
58 |
def show_std(x, y, label, title, xdes, ydes, path, x_scale="linear", std=True, dpi=300):
    """Plot the per-point mean over repeated runs of each method and save it to ``path``.

    Args:
        x, y: nested sequences ``[[run1, run2, ...], [run1, run2, ...], ...]``,
            one list of runs per method. Only the x values of each method's
            first run are plotted, and every run must have at least that many points.
        label: legend label for each method.
        title, xdes, ydes: figure title and x/y axis descriptions.
        path: output file passed to ``plt.savefig``.
        x_scale: matplotlib x-axis scale.
        std: if True, shade ``mean +/- standard error``, i.e. the population std
            divided by sqrt(number of runs).
        dpi: resolution of the saved image.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length.
    """
    print(">> Plot with mean with std err here!")
    plt.style.use('fivethirtyeight')
    plt.figure(figsize=(10, 8))
    colors = ['tab:green', 'tab:orange', 'tab:blue', 'tab:red', 'tab:cyan',
              'tab:gray', 'tab:brown', 'tab:purple', 'tab:olive', 'tab:pink']

    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of methods")
    for k, (xx, yy) in enumerate(zip(x, y)):
        epoch = len(xx[0])
        # Shape (num_runs, epoch): one row per run, truncated to the plotted x range.
        runs = np.array([[run[i] for i in range(epoch)] for run in yy], dtype=float)
        y_est = runs.mean(axis=0)
        # Standard error of the mean uses the actual number of runs (was hard-coded to 3).
        y_err = runs.std(axis=0) / np.sqrt(len(yy))
        color = colors[k % len(colors)]
        plt.plot(xx[0], y_est, color=color, label=label[k])
        if std:
            plt.fill_between(xx[0], y_est - y_err, y_est + y_err, alpha=0.2, color=color)

    plt.gca().get_xaxis().get_major_formatter().set_scientific(False)
    plt.gca().get_yaxis().get_major_formatter().set_scientific(False)
    plt.xlabel(xdes, fontsize=24)
    plt.ylabel(ydes, fontsize=24)

    plt.title(title, fontsize=24)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(loc='lower right', fontsize=20)
    plt.xscale(x_scale)

    plt.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
    plt.close("all")
103 |
--------------------------------------------------------------------------------
/models/base_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as dsets
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | from collections import OrderedDict
8 |
9 |
class linear_model(nn.Module):
    """A single fully connected layer applied to inputs flattened per sample.

    Args:
        input_dim: number of features per sample after flattening.
        output_dim: number of outputs.
    """

    def __init__(self, input_dim, output_dim):
        super(linear_model, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        flat = x.view(-1, self.num_flat_features(x))
        return self.linear(flat)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` after the first (batch) one."""
        # e.g. x: [bs, c, h, w] -> c * h * w
        return x.size()[1:].numel()
27 |
28 |
class mlp_model(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) on per-sample flattened inputs.

    Args:
        input_dim: number of features per sample after flattening.
        hidden_dim: width of the hidden layer.
        output_dim: number of outputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(mlp_model, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu1(self.fc1(x.view(-1, self.num_flat_features(x))))
        return self.fc2(hidden)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` after the first (batch) one."""
        return x.size()[1:].numel()
49 |
50 |
class cnn_mnist(nn.Module):
    """Small CNN: two 5x5 conv layers with max-pooling, then two fully connected layers.

    From https://github.com/yaodongyu/TRADES/blob/master/models/net_mnist.py

    The flattened size is hard-coded to 320 = 20 * 4 * 4, so inputs must be
    single-channel images that shrink to 4x4 (e.g. 28x28: 28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self, num_classes=10):
        super(cnn_mnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        h = self.conv1(x)
        h = F.relu(F.max_pool2d(h, 2))
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        h = h.view(-1, 320)
        # Dropout is only active in training mode.
        h = F.dropout(F.relu(self.fc1(h)), training=self.training)
        return self.fc2(h)
71 |
72 |
class SmallCNN(nn.Module):
    """Four 3x3 conv layers (two max-pools) followed by three fully connected layers.

    From https://github.com/yaodongyu/TRADES/blob/master/models/small_cnn.py

    Expects 1-channel inputs whose feature maps end at 64 x 4 x 4
    (e.g. 28x28: 28 -> 26 -> 24 -> 12 -> 10 -> 8 -> 4), and produces 10 logits.

    Args:
        drop: dropout probability after the first fully connected layer.
    """

    def __init__(self, drop=0.5):
        super(SmallCNN, self).__init__()

        self.num_channels = 1
        self.num_labels = 10

        # A single stateless ReLU instance is reused at every activation site.
        relu = nn.ReLU(True)

        # Layers are created in this exact order so default initialisation consumes
        # the RNG in the same sequence.
        conv_layers = [
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', relu),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', relu),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', relu),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', relu),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]
        self.feature_extractor = nn.Sequential(OrderedDict(conv_layers))

        fc_layers = [
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', relu),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', relu),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]
        self.classifier = nn.Sequential(OrderedDict(fc_layers))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
        # The output layer starts at zero, so every initial logit is 0.
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input):
        feats = self.feature_extractor(input)
        flat = feats.view(-1, 64 * 4 * 4)
        return self.classifier(flat)
122 |
--------------------------------------------------------------------------------
/models/preactresnet.py:
--------------------------------------------------------------------------------
1 | """preactresnet in pytorch
2 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
3 | Identity Mappings in Deep Residual Networks
4 | https://arxiv.org/abs/1603.05027
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
class PreActBasic(nn.Module):
    """Pre-activation basic block: two (BN -> ReLU -> 3x3 conv) units plus a shortcut.

    The first convolution applies ``stride``. The shortcut is the identity when
    stride is 1 and the width is unchanged, and a strided 1x1 convolution otherwise.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        width = out_channels * PreActBasic.expansion
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, kernel_size=3, padding=1),
        ]
        self.residual = nn.Sequential(*layers)

        keeps_shape = stride == 1 and in_channels == width
        self.shortcut = nn.Sequential() if keeps_shape else nn.Conv2d(in_channels, width, 1, stride=stride)

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
36 |
37 |
class PreActBottleNeck(nn.Module):
    """Pre-activation bottleneck block: three BN -> ReLU -> Conv units (1x1 strided, 3x3, 1x1 expand).

    Produces ``out_channels * expansion`` channels. The shortcut is a strided
    1x1 convolution when the stride or width changes, otherwise the identity.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        expanded = out_channels * PreActBottleNeck.expansion

        def preact_conv(c_in, c_out, *conv_args, **conv_kwargs):
            # BN -> ReLU -> Conv, in that order.
            return [nn.BatchNorm2d(c_in),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(c_in, c_out, *conv_args, **conv_kwargs)]

        self.residual = nn.Sequential(
            *preact_conv(in_channels, out_channels, 1, stride=stride),
            *preact_conv(out_channels, out_channels, 3, padding=1),
            *preact_conv(out_channels, expanded, 1),
        )

        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Conv2d(in_channels, expanded, 1, stride=stride)
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
69 |
70 |
class PreActResNet(nn.Module):
    """Pre-activation ResNet with four stages of widths 64/128/256/512 (times ``block.expansion``).

    Args:
        block: block class called as ``block(in_channels, out_channels, stride)``
            and exposing an ``expansion`` attribute.
        num_block: number of blocks in each of the four stages; a stage with
            ``<= 0`` blocks is an empty ``nn.Sequential`` (identity).
        num_classes: number of output logits.
        input_channel: number of input image channels.
    """

    def __init__(self, block, num_block, num_classes=10, input_channel=3):
        super().__init__()
        self.input_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_layers(block, num_block[0], 64, 1)
        self.stage2 = self._make_layers(block, num_block[1], 128, 2)
        self.stage3 = self._make_layers(block, num_block[2], 256, 2)
        self.stage4 = self._make_layers(block, num_block[3], 512, 2)

        # ``self.input_channels`` now holds the width produced by the last non-empty stage.
        self.linear = nn.Linear(self.input_channels, num_classes)

    def _make_layers(self, block, block_num, out_channels, stride):
        """Build one stage of ``block_num`` blocks; only the first block uses ``stride``.

        Updates ``self.input_channels`` to the stage's output width. ``block_num <= 0``
        yields an empty stage instead of looping forever.
        """
        layers = []
        for i in range(block_num):
            layers.append(block(self.input_channels, out_channels, stride if i == 0 else 1))
            self.input_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # Global average pooling to 1x1, then flatten per sample.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x
116 |
117 |
def preactresnet18(input_channel=3, num_classes=10):
    """Pre-activation ResNet-18: ``PreActBasic`` blocks, [2, 2, 2, 2] per stage."""
    return PreActResNet(block=PreActBasic, num_block=[2, 2, 2, 2],
                        num_classes=num_classes, input_channel=input_channel)


def preactresnet34(input_channel=3, num_classes=10):
    """Pre-activation ResNet-34: ``PreActBasic`` blocks, [3, 4, 6, 3] per stage."""
    return PreActResNet(block=PreActBasic, num_block=[3, 4, 6, 3],
                        num_classes=num_classes, input_channel=input_channel)


def preactresnet50(input_channel=3, num_classes=10):
    """Pre-activation ResNet-50: ``PreActBottleNeck`` blocks, [3, 4, 6, 3] per stage."""
    return PreActResNet(block=PreActBottleNeck, num_block=[3, 4, 6, 3],
                        num_classes=num_classes, input_channel=input_channel)


def preactresnet101(input_channel=3, num_classes=10):
    """Pre-activation ResNet-101: ``PreActBottleNeck`` blocks, [3, 4, 23, 3] per stage."""
    return PreActResNet(block=PreActBottleNeck, num_block=[3, 4, 23, 3],
                        num_classes=num_classes, input_channel=input_channel)


def preactresnet152(input_channel=3, num_classes=10):
    """Pre-activation ResNet-152: ``PreActBottleNeck`` blocks, [3, 8, 36, 3] per stage."""
    return PreActResNet(block=PreActBottleNeck, num_block=[3, 8, 36, 3],
                        num_classes=num_classes, input_channel=input_channel)
136 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preactresnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from torch.autograd import Variable
14 |
15 |
class BasicBlock(nn.Module):
    """ResNet basic block: conv-BN-ReLU -> conv-BN, plus shortcut, then ReLU.

    The shortcut is the identity when stride is 1 and the width is unchanged,
    otherwise a strided 1x1 conv followed by BN.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += self.shortcut(x)
        return F.relu(h)
39 |
40 |
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, each with BN.

    Produces ``expansion * planes`` channels. The shortcut is the identity when
    stride is 1 and the width is unchanged, otherwise a strided 1x1 conv + BN.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        h += self.shortcut(x)
        return F.relu(h)
67 |
68 |
class ResNet(nn.Module):
    """ResNet with four stages of widths 64/128/256/512 (times ``block.expansion``).

    Args:
        block: block class called as ``block(in_planes, planes, stride)`` and
            exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
        input_channel: number of input image channels.
    """

    def __init__(self, block, num_blocks, num_classes=10, input_channel=3):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``. Updates ``self.in_planes``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling: same as avg_pool2d(out, 4) when the last feature
        # map is 4x4 (32x32 or 28x28 inputs) and also valid for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
100 |
101 |
def ResNet18(input_channel=3, num_classes=10):
    """ResNet-18: ``BasicBlock`` stages of [2, 2, 2, 2] blocks."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2],
                  num_classes=num_classes, input_channel=input_channel)


def ResNet34(input_channel=3, num_classes=10):
    """ResNet-34: ``BasicBlock`` stages of [3, 4, 6, 3] blocks."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3],
                  num_classes=num_classes, input_channel=input_channel)


def ResNet50(input_channel=3, num_classes=10):
    """ResNet-50: ``Bottleneck`` stages of [3, 4, 6, 3] blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3],
                  num_classes=num_classes, input_channel=input_channel)


def ResNet101(input_channel=3, num_classes=10):
    """ResNet-101: ``Bottleneck`` stages of [3, 4, 23, 3] blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3],
                  num_classes=num_classes, input_channel=input_channel)


def ResNet152(input_channel=3, num_classes=10):
    """ResNet-152: ``Bottleneck`` stages of [3, 8, 36, 3] blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3],
                  num_classes=num_classes, input_channel=input_channel)
120 |
121 |
def test():
    """Smoke test: run ResNet18 on one random 3x32x32 input and print the output size and the model."""
    net = ResNet18()
    # Plain tensors are passed directly; ``torch.autograd.Variable`` is a deprecated no-op wrapper.
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
    print(net)
127 |
--------------------------------------------------------------------------------
/utils_mcl_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Official implementation from "Learning with Multiple Complementary Labels"
3 | by Lei Feng et al.
4 | """
5 | import torch.nn as nn
6 | import torch
7 | import math
8 | import torch.nn.functional as F
9 |
# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
11 |
12 |
def mae_loss(outputs, partialY):
    """Per-sample L1 distance between ``softmax(outputs)`` and the target matrix ``partialY``.

    Returns one value per row, summed over the class (last) dimension.
    """
    probs = F.softmax(outputs, dim=1)
    per_class = F.l1_loss(probs, partialY.float(), reduction='none')
    return per_class.sum(dim=-1)
19 |
20 |
def mse_loss(outputs, Y):
    """Per-sample squared error between ``softmax(outputs)`` and the target matrix ``Y``.

    Returns one value per row, summed over the class (last) dimension.
    """
    probs = F.softmax(outputs, dim=1)
    per_class = F.mse_loss(probs, Y.float(), reduction='none')
    return per_class.sum(dim=-1)
27 |
28 |
def gce_loss(outputs, Y, q=0.7):
    """Per-sample generalized cross-entropy: ``(1 - sum_j Y_j * softmax(outputs)_j ** q) / q``.

    Args:
        outputs: logits, softmax-normalised over dim 1.
        Y: target matrix with the same shape as ``outputs``.
        q: GCE exponent (default 0.7, the previously hard-coded value); must be non-zero.

    Returns:
        One loss value per sample.
    """
    sm_outputs = F.softmax(outputs, dim=1)
    pow_outputs = torch.pow(sm_outputs, q)
    sample_loss = (1 - (pow_outputs * Y).sum(dim=1)) / q  # n
    return sample_loss
35 |
36 |
def phuber_ce_loss(outputs, Y, trunc_point=0.1):
    """Per-sample partially-Huberised cross-entropy.

    Let ``c = sum_j Y_j * softmax(outputs)_j``.

    - Samples with ``c > trunc_point`` get the cross-entropy
      ``-sum_j Y_j * log_softmax(outputs)_j``.
    - The remaining samples get the linear continuation
      ``-log(trunc_point) - c / trunc_point + 1``.

    Args:
        outputs: logits of shape (n, k).
        Y: target matrix of shape (n, k).
        trunc_point: confidence threshold (default 0.1, the previously hard-coded value).

    Returns:
        Tensor of n losses with the device and dtype of ``outputs``.
    """
    n = Y.shape[0]
    sm_outputs = F.softmax(outputs, dim=1)
    final_confidence = (sm_outputs * Y).sum(dim=1)

    # Allocate alongside ``outputs`` (not on a global device) so CPU inputs work
    # when CUDA is available and float64 inputs are not downcast.
    sample_loss = outputs.new_zeros(n)

    ce_index = final_confidence > trunc_point
    if ce_index.any():
        logsm_outputs = F.log_softmax(outputs[ce_index, :], dim=-1)
        sample_loss[ce_index] = -(logsm_outputs * Y[ce_index, :]).sum(dim=-1)

    linear_index = final_confidence <= trunc_point
    if linear_index.any():
        sample_loss[linear_index] = -math.log(trunc_point) + (-1 / trunc_point) * final_confidence[linear_index] + 1

    return sample_loss
61 |
62 |
def ce_loss(outputs, Y):
    """Per-sample cross-entropy of ``outputs`` against a (one-hot, soft or multi-hot) target matrix ``Y``."""
    return -(F.log_softmax(outputs, dim=1) * Y).sum(dim=1)
69 |
70 |
def unbiased_estimator(loss_fn, outputs, partialY):
    """Unbiased risk estimator for learning with multiple complementary labels.

    For each class ``i``, the per-sample loss ``loss_fn(outputs, e_i)`` is evaluated
    against the one-hot target ``e_i``. With ``comp_num = k - partialY.sum(1)``,
    the returned value is the batch mean of

        sum_{candidates} loss - (k - comp_num - 1) / comp_num * sum_{non-candidates} loss

    Args:
        loss_fn: callable ``(outputs, targets) -> per-sample loss`` with ``targets`` of shape (n, k).
        outputs: model outputs passed to ``loss_fn`` (presumably logits of shape (n, k)).
        partialY: (n, k) matrix, presumably 0/1 with 1 marking candidate labels.

    Returns:
        Scalar tensor with the averaged estimate.
    """
    n, k = partialY.shape[0], partialY.shape[1]
    comp_num = k - partialY.sum(dim=1)
    # Allocate alongside ``outputs`` so the estimator works on any device/dtype.
    temp_loss = outputs.new_zeros(n, k)
    for i in range(k):
        one_hot = outputs.new_zeros(n, k)
        one_hot[:, i] = 1.0
        temp_loss[:, i] = loss_fn(outputs, one_hot)

    candidate_loss = (temp_loss * partialY).sum(dim=1)  # for true label
    noncandidate_loss = (temp_loss * (1 - partialY)).sum(dim=1)  # for complementary label
    total_loss = candidate_loss - (k - comp_num - 1.0) / comp_num * noncandidate_loss
    return total_loss.mean()
85 |
86 |
def log_ce_loss(outputs, partialY, pseudo_labels, alpha):
    """Weighted log loss mixing candidate-label mass with a pseudo-label probability.

    With ``p = softmax(outputs)``, ``S = sum_j partialY_j * p_j`` and
    ``p_pl = p[pseudo_label]``, returns the batch mean of

        -(k - 1) / (k - sum_j partialY_j) * log(alpha * S + (1 - alpha) * p_pl)

    Args:
        outputs: logits of shape (n, k).
        partialY: (n, k) candidate-label indicator matrix.
        pseudo_labels: per-sample class index used for the ``p_pl`` term.
        alpha: mixing weight between the candidate mass and the pseudo-label probability.
    """
    k = partialY.shape[1]
    num_comp = k - partialY.sum(dim=1).float()  # labels outside the candidate set
    probs = F.softmax(outputs, dim=1)
    candidate_mass = (probs * partialY).sum(dim=1)
    pseudo_prob = probs[torch.arange(probs.size(0)), pseudo_labels]
    mixed = alpha * candidate_mass + (1 - alpha) * pseudo_prob
    weight = (k - 1) / num_comp
    return -(weight * torch.log(mixed)).mean()
101 |
102 |
def log_loss(outputs, partialY):
    """Weighted negative log of the probability mass on candidate labels.

    Returns the batch mean of

        -(k - 1) / (k - sum_j partialY_j) * log(sum_j partialY_j * softmax(outputs)_j)

    The log is evaluated in log space as
    ``logsumexp(outputs + log(partialY)) - logsumexp(outputs)``. This is
    mathematically identical for non-negative ``partialY``, but stays finite when
    the candidate probabilities underflow to zero (where ``log(sum softmax)`` gives ``-inf``).

    Args:
        outputs: logits of shape (n, k).
        partialY: (n, k) candidate-label indicator matrix.
    """
    k = partialY.shape[1]
    can_num = partialY.sum(dim=1).float()  # number of candidate labels per sample
    # log(0) = -inf drops non-candidate classes from the first logsumexp.
    log_mask = torch.log(partialY.to(outputs.dtype))
    log_candidate_mass = torch.logsumexp(outputs + log_mask, dim=1) - torch.logsumexp(outputs, dim=1)
    return -((k - 1) / (k - can_num) * log_candidate_mass).mean()
113 |
114 |
def exp_ce_loss(outputs, partialY, pseudo_labels, alpha):
    """Weighted exponential loss mixing candidate-label mass with a pseudo-label probability.

    With ``p = softmax(outputs)``, ``S = sum_j partialY_j * p_j`` and
    ``p_pl = p[pseudo_label]``, returns the batch mean of

        (k - 1) / (k - sum_j partialY_j) * exp(-alpha * S - (1 - alpha) * p_pl)
    """
    k = partialY.shape[1]
    weight = (k - 1) / (k - partialY.sum(dim=1).float())
    probs = F.softmax(outputs, dim=1)
    candidate_mass = (probs * partialY).sum(dim=1)
    pseudo_prob = probs[torch.arange(probs.size(0)), pseudo_labels]
    return (weight * torch.exp(-alpha * candidate_mass - (1 - alpha) * pseudo_prob)).mean()
129 |
130 |
def exp_loss(outputs, partialY):
    """Weighted exponential loss on the candidate-label probability mass.

    Returns the batch mean of
    ``(k - 1) / (k - sum_j partialY_j) * exp(-sum_j partialY_j * softmax(outputs)_j)``.
    """
    k = partialY.shape[1]
    weight = (k - 1) / (k - partialY.sum(dim=1).float())
    candidate_mass = (F.softmax(outputs, dim=1) * partialY).sum(dim=1)
    return (weight * torch.exp(-candidate_mass)).mean()
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Adversarial Training with Complementary Labels
2 |
3 | [](https://openreview.net/forum?id=s7SukMH7ie9) [](https://openreview.net/pdf?id=s7SukMH7ie9) [](https://neurips.cc/virtual/2022/poster/55084)
4 |
5 | Official (Pytorch) Implementation of *NeurIPS 2022 Spotlight "Adversarial Training with Complementary Labels: On the Benefit of Gradually Informative Attacks"* by [Jianan Zhou*](https://openreview.net/profile?id=~Jianan_Zhou1)*,* [Jianing Zhu*](https://openreview.net/profile?id=~Jianing_Zhu2)*,* [Jingfeng Zhang](https://openreview.net/profile?id=~Jingfeng_Zhang1)*,* [Tongliang Liu](https://openreview.net/profile?id=~Tongliang_Liu1)*,* [Gang Niu](https://openreview.net/profile?id=~Gang_Niu1)*,* [Bo Han](https://openreview.net/profile?id=~Bo_Han1)*,* [Masashi Sugiyama](https://openreview.net/profile?id=~Masashi_Sugiyama1).
6 |
7 | ```bash
8 | @inproceedings{zhou2022adversarial,
9 | title={Adversarial Training with Complementary Labels: On the Benefit of Gradually Informative Attacks},
10 | author={Jianan Zhou and Jianing Zhu and Jingfeng Zhang and Tongliang Liu and Gang Niu and Bo Han and Masashi Sugiyama},
11 | booktitle={Advances in Neural Information Processing Systems},
12 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
13 | year={2022},
14 | url={https://openreview.net/forum?id=s7SukMH7ie9}
15 | }
16 | ```
17 |
18 |
19 |
20 |
21 |
22 | ### TL;DR
23 |
24 | > *How to equip machine learning models with adversarial robustness when all given labels in a dataset are wrong (i.e., complementary labels)?*
25 |
26 |
27 |
28 |
29 |
30 | ### Dependencies
31 |
32 | * Python 3.8
33 | * SciPy
34 | * [PyTorch 1.11.0](https://pytorch.org)
35 | * [AutoAttack](https://github.com/fra31/auto-attack)
36 |
37 | ### How to Run
38 |
39 | > Please refer to Section 5 and Appendix D.1 of our paper for the detailed setups.
40 |
41 | #### Baseline
42 |
43 | ```bash
44 | # Two-stage baseline for MNIST/Kuzushiji
45 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log' --framework 'two_stage' --cl_epochs 50 --adv_epochs 50 --cl_lr 0.001 --at_lr 0.01
46 | # Two-stage baseline for CIFAR10/SVHN
47 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log' --framework 'two_stage' --cl_epochs 50 --adv_epochs 70 --cl_lr 0.01 --at_lr 0.01
48 |
49 | # Complementary baselines (e.g., LOG) for MNIST/Kuzushiji
50 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log' --framework 'one_stage' --adv_epochs 100 --at_lr 0.01 --scheduler 'none'
51 | # Complementary baselines (e.g., LOG) for CIFAR10/SVHN
52 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log' --framework 'one_stage' --adv_epochs 120 --at_lr 0.01 --scheduler 'none'
53 | ```
54 |
55 | #### Ours
56 |
57 | ```bash
58 | # MNIST/Kuzushiji
59 | python main.py --dataset 'kuzushiji' --model 'cnn' --method 'log_ce' --framework 'one_stage' --adv_epochs 100 --at_lr 0.01 --scheduler 'cosine' --sch_epoch 50 --warmup_epoch 10
60 | # CIFAR10/SVHN
61 | python main.py --dataset 'cifar10' --model 'resnet18' --method 'log_ce' --framework 'one_stage' --adv_epochs 120 --at_lr 0.01 --scheduler 'cosine' --sch_epoch 40 --warmup_epoch 40
62 | ```
63 |
64 | #### Options
65 |
66 | ```bash
67 | # Supported datasets (note: in our experiments, complementary learning currently fails on CIFAR100 in the SCL setting)
68 | --dataset - ['mnist', 'kuzushiji', 'fashion', 'cifar10', 'svhn', 'cifar100']
69 | # Complementary Loss Functions
70 | --method - ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'mae', 'mse', 'ce', 'gce', 'phuber_ce', 'log', 'exp', 'l_uw', 'l_w', 'log_ce', 'exp_ce']
71 | # Multiple Complementary Labels (MCLs)
72 | --cl_num - (1-9) the number of complementary labels per instance; (0) sample the number of complementary labels per instance following the MCL data distribution of ICML 2020 - "Learning with Multiple Complementary Labels"
73 | ```
74 |
75 | ### Reference
76 |
77 | * [NeurIPS 2017] - [Learning from complementary labels](https://arxiv.org/abs/1705.07541)
78 |
79 | * [ECCV 2018] - [Learning with biased complementary labels](https://arxiv.org/abs/1711.09535)
80 |
81 | * [ICML 2019] - [Complementary-label learning for arbitrary losses and models](https://arxiv.org/abs/1810.04327)
82 |
83 | * [ICML 2020] - [Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels](https://arxiv.org/abs/2007.02235)
84 |
85 | * [ICML 2020] - [Learning with Multiple Complementary Labels](https://arxiv.org/abs/1912.12927)
86 |
87 | ### Acknowledgments
88 |
89 | We thank the authors of *"Complementary-label learning for arbitrary losses and models"* for the open-source [code](https://github.com/takashiishida/comp) and the issue discussion. Other codebases may be found on the corresponding authors' homepages. We would also like to thank the anonymous reviewers of NeurIPS 2022 for their constructive comments.
90 |
91 | ### Contact
92 |
93 | Please contact [jianan004@e.ntu.edu.sg](mailto:jianan004@e.ntu.edu.sg) and [csjnzhu@comp.hkbu.edu.hk](mailto:csjnzhu@comp.hkbu.edu.hk) if you have any questions regarding the paper or implementation.
94 |
95 |
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 |
# Names exported by star-imports of this module (models/__init__.py does `from .densenet import *`).
__all__ = ["densenet"]
10 |
11 |
12 | from torch.autograd import Variable
13 |
class Bottleneck(nn.Module):
    """Pre-activation DenseNet bottleneck layer: BN-ReLU-Conv1x1, then BN-ReLU-Conv3x3.

    The layer produces ``growthRate`` new feature maps and concatenates them onto its
    input along dim 1, so the output has ``inplanes + growthRate`` channels.
    ``expansion * growthRate`` is the width of the intermediate 1x1 convolution.
    Dropout with probability ``dropRate`` is applied to the new maps when positive.
    """

    def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
        super(Bottleneck, self).__init__()
        width = expansion * growthRate  # channels emitted by the 1x1 convolution
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, growthRate, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        h = self.conv1(self.relu(self.bn1(x)))
        h = self.conv2(self.relu(self.bn2(h)))
        if self.dropRate > 0:
            h = F.dropout(h, p=self.dropRate, training=self.training)
        # Dense connectivity: keep the incoming maps and append the new ones.
        return torch.cat((x, h), 1)
39 |
40 |
class BasicBlock(nn.Module):
    """Pre-activation DenseNet basic layer: BN-ReLU-Conv3x3 producing ``growthRate`` maps.

    The new maps are concatenated onto the input along dim 1, so the output has
    ``inplanes + growthRate`` channels. ``expansion`` is unused here; it is kept so
    this layer shares its constructor signature with ``Bottleneck``. Dropout with
    probability ``dropRate`` is applied to the new maps when positive.
    """

    def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)
        # Dense connectivity: keep the incoming maps and append the new ones.
        return torch.cat((x, out), 1)
61 |
62 |
class Transition(nn.Module):
    """DenseNet transition: BN-ReLU-Conv1x1 down to ``outplanes`` channels, then 2x2 average pooling."""

    def __init__(self, inplanes, outplanes):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        compressed = self.conv1(self.relu(self.bn1(x)))
        return F.avg_pool2d(compressed, 2)  # halves the spatial resolution
76 |
77 |
class DenseNet(nn.Module):
    """Three-stage DenseNet (bearpaw/pytorch-classification layout) for small images.

    Args:
        depth: total depth; must satisfy ``depth = 3n + 4``. Each dense block gets
            ``(depth - 4) // 3`` BasicBlock layers or ``(depth - 4) // 6`` Bottleneck layers.
        block: layer class, ``Bottleneck`` or ``BasicBlock``.
        dropRate: dropout probability used inside every layer.
        num_classes: number of output logits.
        growthRate: feature maps added by each layer.
        compressionRate: channel divisor applied by each transition.
        input_channel: number of channels of the input images.

    Raises:
        ValueError: if ``depth`` is not of the form ``3n + 4``.
    """

    def __init__(self, depth=22, block=Bottleneck, dropRate=0, num_classes=10, growthRate=12, compressionRate=2, input_channel=3):
        super(DenseNet, self).__init__()

        if (depth - 4) % 3 != 0:
            raise ValueError('depth should be 3n+4')
        # Floor division in both branches: `(depth - 4) / 3` is a float under Python 3,
        # which made `range()` in `_make_denseblock` raise for BasicBlock.
        n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6

        self.growthRate = growthRate
        self.dropRate = dropRate

        # self.inplanes is the running channel count; the _make_* helpers read it and
        # advance it as layers are appended.
        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=3, padding=1, bias=False)
        self.dense1 = self._make_denseblock(block, n)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(block, n)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(block, n)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Global average pooling. For 32x32 inputs the final maps are 8x8, so this matches
        # the former AvgPool2d(8); unlike a fixed 8x8 kernel it also handles 28x28 inputs
        # (7x7 final maps).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(self.inplanes, num_classes)

        # Weight initialization: He-normal (fan-out) for convolutions, identity for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks):
        """Stack ``blocks`` layers, each adding ``growthRate`` channels to ``self.inplanes``."""
        layers = []
        for _ in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate))
            self.inplanes += self.growthRate
        return nn.Sequential(*layers)

    def _make_transition(self, compressionRate):
        """Build a transition dividing ``self.inplanes`` by ``compressionRate`` (floored)."""
        inplanes = self.inplanes
        outplanes = int(math.floor(self.inplanes // compressionRate))
        self.inplanes = outplanes
        return Transition(inplanes, outplanes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        x = self.relu(self.bn(x))
        x = self.avgpool(x)
        return self.linear(x.view(x.size(0), -1))
142 |
143 |
def densenet(**kwargs):
    """
    Constructs a DenseNet model.

    All keyword arguments are forwarded unchanged to ``DenseNet``.
    """
    return DenseNet(**kwargs)
149 |
--------------------------------------------------------------------------------
/attack_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.autograd import Variable
3 | from utils_algo import *
4 | from utils_mcl_loss import *
5 |
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7 |
8 |
def trades_loss(adv_logits, natural_logits, target, beta):
    """
    TRADES objective: CE(natural_logits, target) + beta * KL(p_natural || p_adv) / batch_size.

    Based on the repo TRADES: https://github.com/yaodongyu/TRADES

    Args:
        adv_logits: logits on adversarial inputs.
        natural_logits: logits on clean inputs.
        target: class labels for the natural cross-entropy term.
        beta: weight of the robustness (KL) term.

    Returns:
        Scalar loss tensor.
    """
    batch_size = len(target)
    # reduction="sum" is the supported spelling of the deprecated size_average=False.
    # The criterion holds no parameters, so it does not need to be moved to a device.
    criterion_kl = nn.KLDivLoss(reduction="sum")
    loss_natural = nn.CrossEntropyLoss(reduction='mean')(natural_logits, target)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1), F.softmax(natural_logits, dim=1))
    return loss_natural + beta * loss_robust
19 |
20 |
def cwloss(output, target, confidence=50, num_classes=10):
    """Summed Carlini-Wagner margin loss, for maximisation by an attacker.

    Compute the probability of the label class versus the maximum other.
    The same implementation as in repo CAT https://github.com/sunblaze-ucb/curriculum-adversarial-training-CAT

    Per sample: ``-max(real - other + confidence, 0)``, where ``real`` is the logit of
    ``target`` and ``other`` is the largest logit among the remaining classes.

    Args:
        output: logits of shape (batch, num_classes).
        target: integer class labels of shape (batch,).
        confidence: margin kappa.
        num_classes: width of the one-hot encoding.

    Returns:
        Scalar tensor (sum over the batch).
    """
    # One-hot built directly on output's device/dtype (no module-global device, no Variable).
    target_onehot = F.one_hot(target.detach().long(), num_classes=num_classes).to(device=output.device, dtype=output.dtype)
    real = (target_onehot * output).sum(1)
    # Push the label logit far down so max() picks the best competing class.
    other = ((1. - target_onehot) * output - target_onehot * 10000.).max(1)[0]
    loss = -torch.clamp(real - other + confidence, min=0.)  # equiv to max(..., 0.)
    return torch.sum(loss)
34 |
35 |
def cl_adv(args, model, data, target, epsilon, step_size, num_steps, id, ccp, partialY, pseudo_labels, alpha, loss_fn, category="Madry", rand_init=True):
    """Generate L-inf PGD adversarial examples that ascend a complementary-label loss.

    Args:
        args: run config; ``args.method`` selects the loss, ``args.cl_num`` is checked for
            the single-complementary-label losses.
        model: network under attack; it is put in eval mode and left there.
        data: clean batch; the returned examples are clamped to [0, 1].
        target: complementary labels of the batch (used by the ``chosen_loss_c`` methods).
        epsilon: L-inf radius of the perturbation.
        step_size: size of each signed-gradient step.
        num_steps: number of PGD iterations.
        id: dataset indices of the batch, used to select rows of ``partialY``.
        ccp: complementary class prior, forwarded to ``chosen_loss_c``.
        partialY: per-sample label mask, indexed row-wise by ``id``.
        pseudo_labels, alpha: extra arguments of the ``log_ce`` / ``exp_ce`` losses.
        loss_fn: loss callable used by the methods that take one.
        category: "Madry" (uniform start in the eps-ball) or "trades" (0.001 * N(0, 1) start).
        rand_init: start from a random point if True, else from ``data``.

    Returns:
        Detached adversarial batch with the same shape as ``data``.

    Raises:
        ValueError: for an unknown ``category`` or ``args.method``.
    """
    model.eval()
    if category == "trades":
        x_adv = data.detach() + 0.001 * torch.randn(data.shape).to(data.device) if rand_init else data.detach()
    elif category == "Madry":
        x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().to(data.device) if rand_init else data.detach()
    else:
        raise ValueError("Unknown attack category: {}".format(category))
    # Start inside the valid input range.
    x_adv = torch.clamp(x_adv, 0.0, 1.0)

    # generate adversarial examples
    for _ in range(num_steps):
        x_adv.requires_grad_()
        # Forward pass and loss both under enable_grad, so the attack still works when the
        # caller is running under torch.no_grad().
        with torch.enable_grad():
            output = model(x_adv)
            if args.method in ['exp', 'log']:
                loss_adv = loss_fn(output, partialY[id].float())
            elif args.method in ['mae', 'mse', 'ce', 'gce', 'phuber_ce']:
                loss_adv = unbiased_estimator(loss_fn, output, partialY[id].float())
            elif args.method in ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'l_uw', 'l_w']:
                assert args.cl_num == 1
                loss_adv, _ = chosen_loss_c(f=output, K=output.size(-1), labels=target, ccp=ccp, meta_method=args.method)
            elif args.method in ["log_ce", "exp_ce"]:
                loss_adv = loss_fn(output, partialY[id].float(), pseudo_labels, alpha)
            else:
                raise ValueError("Unsupported method for cl_adv: {}".format(args.method))
            # Gradient w.r.t. the input only; model parameter gradients are left untouched.
            grad = torch.autograd.grad(loss_adv, x_adv)[0]
        # Signed ascent step, then project back onto the eps-ball around data and onto [0, 1].
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    return x_adv.detach()
68 |
69 |
def pgd(model, data, target, epsilon, step_size, num_steps, loss_fn, category, rand_init, num_classes=10):
    """L-inf PGD attack maximising a standard supervised loss.

    Args:
        model: network under attack; it is put in eval mode and left there.
        data: clean batch; the returned examples are clamped to [0, 1].
        target: labels for the "cent" / "cw" losses (unused by "kl").
        epsilon: L-inf radius of the perturbation.
        step_size: size of each signed-gradient step.
        num_steps: number of PGD iterations.
        loss_fn: "cent" (cross-entropy), "cw" (CW margin loss via ``cwloss``) or
            "kl" (KL divergence to the clean prediction).
        category: "Madry" (uniform start in the eps-ball) or "trades" (0.001 * N(0, 1) start).
        rand_init: start from a random point if True, else from ``data``.
        num_classes: number of classes, passed to ``cwloss``.

    Returns:
        Detached adversarial batch with the same shape as ``data``.

    Raises:
        ValueError: for an unknown ``category`` or ``loss_fn``.
    """
    model.eval()
    if category == "trades":
        x_adv = data.detach() + 0.001 * torch.randn(data.shape).to(data.device) if rand_init else data.detach()
    elif category == "Madry":
        x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().to(data.device) if rand_init else data.detach()
    else:
        raise ValueError("Unknown attack category: {}".format(category))
    # Start inside the valid input range.
    x_adv = torch.clamp(x_adv, 0.0, 1.0)

    if loss_fn == "kl":
        # The clean prediction does not depend on x_adv: compute it once, without a graph.
        with torch.no_grad():
            nat_prob = F.softmax(model(data), dim=1)
        criterion_kl = nn.KLDivLoss(reduction="batchmean")

    for _ in range(num_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            output = model(x_adv)
            if loss_fn == "cent":
                loss_adv = F.cross_entropy(output, target, reduction="mean")
            elif loss_fn == "cw":
                loss_adv = cwloss(output, target, num_classes=num_classes)
            elif loss_fn == "kl":
                loss_adv = criterion_kl(F.log_softmax(output, dim=1), nat_prob)
            else:
                raise ValueError("Unknown loss_fn: {}".format(loss_fn))
            # Gradient w.r.t. the input only; model parameter gradients are left untouched.
            grad = torch.autograd.grad(loss_adv, x_adv)[0]
        # Signed ascent step, then project back onto the eps-ball around data and onto [0, 1].
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    return x_adv.detach()
97 |
98 |
def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, random, num_classes=10):
    """Evaluate cross-entropy and accuracy of ``model`` on PGD adversarial examples.

    Args:
        model: classifier to evaluate (put in eval mode).
        test_loader: loader yielding ``(data, target)`` batches.
        perturb_steps: number of PGD iterations.
        epsilon: L-inf radius.
        step_size: PGD step size.
        loss_fn: attack loss name, forwarded to ``pgd``.
        category: random-start scheme, forwarded to ``pgd``.
        random: forwarded to ``pgd`` as ``rand_init``.
        num_classes: forwarded to ``pgd`` (used by its "cw" loss).

    Returns:
        (mean adversarial cross-entropy, adversarial accuracy in percent), both averaged
        over ``len(test_loader.dataset)``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The attack needs input gradients even if the caller disabled autograd.
        with torch.enable_grad():
            x_adv = pgd(model, data, target, epsilon, step_size, perturb_steps, loss_fn, category, rand_init=random, num_classes=num_classes)
        # Scoring needs no autograd graph.
        with torch.no_grad():
            output = model(x_adv)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    test_accuracy = 100 * correct / num_samples
    return test_loss, test_accuracy
114 |
--------------------------------------------------------------------------------
/utils_algo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
# Module-wide compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7 |
8 |
def assump_free_loss(f, K, labels, ccp):
    """Assumption-free complementary risk (based on Thm 1).

    This is ``non_negative_loss`` with its max-operator threshold at negative infinity
    (``beta = np.inf``), so no per-class risk is clipped. Returns the same
    ``(final_loss, per_class_loss)`` pair as ``non_negative_loss``.
    """
    no_threshold = np.inf
    return non_negative_loss(f, K, labels, ccp, no_threshold)
12 |
13 |
def non_negative_loss(f, K, labels, ccp, beta):
    """Complementary-label risk with a per-class ``max(., -beta)`` correction.

    Args:
        f: logits of shape (batch, K).
        K: number of classes.
        labels: complementary labels of shape (batch,).
        ccp: complementary class priors (array-like of length K).
        beta: clipping threshold; per-class risks below ``-beta`` are raised to ``-beta``
            (``beta = 0`` gives the non-negative correction, ``np.inf`` disables it).

    Returns:
        (final_loss, per_class): the summed corrected risk, and the uncorrected per-class
        risk vector multiplied by the number of samples carrying each complementary label.
    """
    dev = f.device  # follow the logits' device instead of a module-level global
    K = int(K)
    ccp = torch.from_numpy(np.asarray(ccp)).float().to(dev)
    neglog = -F.log_softmax(f, dim=1)  # (bs, K)
    loss_vector = torch.zeros(K, device=dev)
    temp_loss_vector = torch.zeros(K, device=dev)
    for k in range(K):
        idx = labels == k
        if idx.any():
            mean_neglog_k = neglog[idx].mean(dim=0)  # mean loss vector over samples whose complementary label is k
            # average of k-th class loss for k-th comp class samples
            temp_loss_vector[k] = -(K - 1) * ccp[k] * mean_neglog_k[k]
            # ccp[k] or ccp, refer to https://github.com/takashiishida/comp/issues/2
            loss_vector = loss_vector + ccp[k] * mean_neglog_k
    loss_vector = loss_vector + temp_loss_vector
    # minlength pads classes absent from this batch with a zero count.
    count = np.bincount(labels.detach().cpu().numpy(), minlength=K).astype('float')
    threshold = torch.zeros(K, device=dev) - beta
    max_loss_vector, _ = torch.max(torch.stack((loss_vector, threshold), dim=1), dim=1)
    final_loss = torch.sum(max_loss_vector)
    return final_loss, torch.from_numpy(count).float().to(dev) * loss_vector
35 |
36 |
def fast_free_loss(f, K, cl_label):
    """Per-sample assumption-free complementary loss.

    For every sample: ``sum_{k<K} CE(f, k) - (K - 1) * CE(f, cl_label)``.

    Args:
        f: logits of shape (batch, num_classes).
        K: number of classes.
        cl_label: complementary labels of shape (batch,), integer dtype.

    Returns:
        Tensor of shape (batch,).
    """
    ce = nn.CrossEntropyLoss(reduction="none")
    loss = -(K - 1) * ce(f, cl_label)
    for k in range(K):
        # Constant target k on cl_label's own device/dtype (no module-global device).
        loss = loss + ce(f, torch.full_like(cl_label, k))
    return loss
44 |
45 |
def forward_loss(f, K, labels, reduction='mean'):
    """Forward-corrected complementary loss under the uniform assumption.

    The transition matrix has ``Q[i, j] = 1 / (K - 1)`` for ``i != j`` and 0 on the
    diagonal; the loss is the NLL of ``softmax(f) @ Q`` at ``labels``.

    Args:
        f: logits of shape (batch, K).
        K: number of classes.
        labels: complementary labels of shape (batch,).
        reduction: forwarded to ``F.nll_loss``.
    """
    K = int(K)
    # Uniform off-diagonal transition matrix, built on f's device/dtype without a Python loop.
    Q = (torch.ones(K, K, device=f.device, dtype=f.dtype) - torch.eye(K, device=f.device, dtype=f.dtype)) / (K - 1)
    q = torch.mm(F.softmax(f, 1), Q)
    return F.nll_loss(q.log(), labels.long(), reduction=reduction)
53 |
54 |
def pc_loss(f, K, labels):
    """Pairwise-comparison complementary loss with a sigmoid surrogate.

    Returns ``(K-1)/n * sum_{i,k} sigmoid(f[i, y_i] - f[i, k]) - K(K-1)/2 + (K-1)``,
    where ``y_i`` is the complementary label of sample ``i`` and ``n = len(labels)``.
    """
    comp_logit = f.gather(1, labels.long().view(-1, 1)).repeat(1, K)  # f[i, y_i] tiled across classes
    pairwise = torch.sigmoid(comp_logit - f)  # negated difference for "complementary"
    total = torch.sum(pairwise) * (K - 1) / len(labels)
    return total - K * (K - 1) / 2 + (K - 1)
62 |
63 |
def phi_loss(phi, logits, target, reduction='mean'):
    """
    Official implementation of "Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels"
    by Yu-Ting Chou et al.

    Applies ``phi`` to the softmax probabilities and returns ``phi(p)[i, target_i]``
    reduced according to ``reduction`` (the negated ``F.nll_loss``).

    Args:
        phi: one of 'lin', 'quad', 'exp', 'log', 'nl', 'hinge'.
        logits: tensor of shape (batch, num_classes).
        target: integer labels of shape (batch,).
        reduction: forwarded to ``F.nll_loss``.

    Raises:
        ValueError: for an unknown ``phi``.
    """
    if phi == 'lin':
        activated_prob = F.softmax(logits, dim=1)
    elif phi == 'quad':
        activated_prob = torch.pow(F.softmax(logits, dim=1), 2)
    elif phi == 'exp':
        activated_prob = torch.exp(F.softmax(logits, dim=1))
        # activated_prob = torch.exp(alpha * F.softmax(logits, dim=1) - (1 - alpha) * pred_outputs)
    elif phi == 'log':
        # log_softmax is stable where log(softmax(.)) underflows to -inf.
        activated_prob = F.log_softmax(logits, dim=1)
    elif phi == 'nl':
        activated_prob = -torch.log(1 - F.softmax(logits, dim=1) + 1e-5)
        # activated_prob = -torch.log(alpha * (1 - F.softmax(logits, dim=1) + 1e-5) + (1 - alpha) * pred_outputs)
    elif phi == 'hinge':
        # NOTE(review): the 1/10 margin is hard-coded, presumably for 10 classes; confirm before using with other K.
        activated_prob = torch.clamp(F.softmax(logits, dim=1) - (1 / 10), min=0)
    else:
        raise ValueError('Invalid phi function')

    loss = -F.nll_loss(activated_prob, target, reduction=reduction)
    return loss
89 |
90 |
91 | """
92 | Below is the official implementation of ICML 2021 "Discriminative Complementary-Label Learning with Weighted Loss"
93 | by Yi Gao et al.
94 | """
95 |
96 |
def non_k_softmax_loss(f, K, labels):
    """Unweighted loss: NLL of ``softmax(1 - softmax(f))`` at ``labels``.

    ``K`` is not used; it keeps the signature uniform with the other losses.
    """
    q = F.softmax(1 - F.softmax(f, 1), 1)
    return F.nll_loss(q.log(), labels.long())  # Equation(8) in paper
102 |
103 |
def w_loss(f, K, labels):
    """Weighted complementary loss: unweighted term plus weighted term (Equation (11) in paper)."""
    return non_k_softmax_loss(f=f, K=K, labels=labels) + w_loss_p(f=f, K=K, labels=labels)
109 |
110 |
def w_loss_p(f, K, labels):
    """Weighted term: NLL of ``w * log softmax(1 - p)`` at ``labels`` (Equation (14) in paper).

    ``p = softmax(f)`` and the per-sample weights are ``w = (1 - p) / sum_k (1 - p_k)``.
    """
    not_probs = 1 - F.softmax(f, 1)
    log_q = F.softmax(not_probs, 1).log()
    inv_norm = (torch.tensor(1.0) / torch.sum(not_probs, dim=1)).unsqueeze(1).expand(-1, K)
    weights = not_probs * inv_norm
    return F.nll_loss(weights * log_q, labels.long())
120 |
121 |
def chosen_loss_c(f, K, labels, ccp, meta_method, reduction='mean'):
    """Dispatch to the single-complementary-label loss named by ``meta_method``.

    Args:
        f: logits.
        K: number of classes.
        labels: complementary labels.
        ccp: complementary class prior (used by 'free', 'ga', 'nn').
        meta_method: 'free', 'ga', 'nn', 'forward', 'pc', 'scl_<phi>', 'l_uw' or 'l_w'.
        reduction: forwarded to the 'forward' and 'scl_*' losses.

    Returns:
        (final_loss, class_loss_torch); the second item is the per-class loss vector for
        'free' / 'ga' / 'nn' and None for the other methods.

    Raises:
        ValueError: for an unknown ``meta_method``.
    """
    class_loss_torch = None
    if meta_method in ('free', 'ga'):
        # This function computes the same assumption-free risk for 'free' and 'ga'.
        final_loss, class_loss_torch = assump_free_loss(f=f, K=K, labels=labels, ccp=ccp)
    elif meta_method == 'nn':
        final_loss, class_loss_torch = non_negative_loss(f=f, K=K, labels=labels, beta=0, ccp=ccp)
    elif meta_method == 'forward':
        final_loss = forward_loss(f=f, K=K, labels=labels, reduction=reduction)
    elif meta_method == 'pc':
        final_loss = pc_loss(f=f, K=K, labels=labels)
    elif meta_method.startswith("scl"):
        # 'scl_exp' -> phi 'exp', 'scl_nl' -> phi 'nl'
        final_loss = phi_loss(meta_method[4:], f, labels, reduction=reduction)
    elif meta_method == 'l_uw':
        final_loss = non_k_softmax_loss(f=f, K=K, labels=labels)
    elif meta_method == 'l_w':
        final_loss = w_loss(f=f, K=K, labels=labels)
    else:
        raise ValueError("Unsupported complementary loss: {}".format(meta_method))

    return final_loss, class_loss_torch
142 |
--------------------------------------------------------------------------------
/utils_data.py:
--------------------------------------------------------------------------------
1 | import sys, random
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | import torchvision.transforms as transforms
6 | import torchvision.datasets as dsets
7 | from torch.utils.data import Dataset
8 | from scipy.special import comb
9 |
10 |
class AugComp(Dataset):
    """Complementary-label dataset over raw images, with a transform applied on access.

    Each item is ``(image, complementary_label, true_label, sample_id)``. Raw images are
    first converted to PIL so torchvision-style transforms (e.g. augmentation) apply.

    Args:
        name: dataset name; selects how a raw image is converted to PIL.
        data: raw images (indexable).
        cl: complementary labels.
        tl: true labels.
        id: sample indices.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, name, data, cl, tl, id, transform=None):
        self.name = name
        self.data = data
        self.cl = cl
        self.tl = tl
        self.id = id
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, complementary label, true label, sample id).
        Raises:
            ValueError: if ``self.name`` is not a supported dataset.
        """
        img, cl_target, tl_target, idx = self.data[index], self.cl[index], self.tl[index], self.id[index]

        # doing this so that it is consistent with all other datasets to return a PIL Image
        if self.name in ["mnist", "fashion", "kuzushiji"]:
            img = Image.fromarray(img.numpy(), mode="L")
        elif self.name in ["cifar10", "cifar100"]:
            img = Image.fromarray(img)
        elif self.name in ["svhn"]:
            # Axis 0 is moved last, i.e. (C, H, W) -> (H, W, C) as PIL expects (assumes CHW storage).
            img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        else:
            # A real exception rather than `assert 0`: asserts are stripped under `python -O`,
            # which would let the unconverted image reach the transform.
            raise ValueError("Please modify AugComp, since you are using an unsupported dataset: {}".format(self.name))

        if self.transform is not None:
            img = self.transform(img)

        return img, cl_target, tl_target, idx

    def __len__(self):
        return len(self.data)
47 |
48 |
def generate_compl_labels(labels):
    """Draw one complementary label per sample under the uniform assumption.

    Each sample's complementary label is chosen uniformly from the K-1 classes other than
    its true label, with ``K = max(labels) + 1``. One ``np.random.randint`` draw per sample.

    Args:
        labels: 1-D tensor of ordinary (true) labels.

    Returns:
        numpy integer array of complementary labels, one per sample.
    """
    num_classes = int(torch.max(labels)) + 1
    true = labels.numpy()
    # Position of the complementary label within the ascending list of the K-1 classes
    # that remain after removing the true label.
    offsets = np.random.randint(0, num_classes - 1, len(labels))
    # Map the position back to a class id: positions at or past the true label shift by one.
    return offsets + (offsets >= true)
63 |
64 |
def generate_uniform_mul_comp_labels(labels):
    """
    Generating multiple complementary labels following the distribution described in Section 5 of
    "Learning with Multiple Complementary Labels" by Lei Feng et al.

    For every instance the number s of complementary labels (1 <= s <= K-1) is drawn with
    probability C(K, s) / (2^K - 2), then s distinct classes other than the true one are
    chosen from a random permutation.

    Args:
        labels: 1-D integer tensor of true labels, 0- or 1-based.

    Returns:
        Float tensor of shape (n, K): 0 at complementary labels, 1 elsewhere.

    Raises:
        RuntimeError: if the smallest label is greater than 1.
    """
    if torch.min(labels) > 1:
        raise RuntimeError('labels must be 0- or 1-based (smallest label is greater than 1)')
    elif torch.min(labels) == 1:
        labels = labels - 1

    K = int(torch.max(labels) - torch.min(labels) + 1)
    n = labels.shape[0]
    cardinality = 2 ** K - 2
    number = torch.tensor([comb(K, i + 1) for i in range(K - 1)])  # C(K, 1) .. C(K, K-1)
    frequency_dis = number / cardinality
    # Cumulative distribution over s = 1..K-1 (float32 storage, accumulated step by step).
    prob_dis = torch.zeros(K - 1)
    for i in range(K - 1):
        prob_dis[i] = frequency_dis[i] if i == 0 else frequency_dis[i] + prob_dis[i - 1]

    random_n = torch.from_numpy(np.random.uniform(0, 1, n)).float()  # tensor: n
    partialY = torch.ones(n, K)

    for j in range(n):  # for each instance
        # Smallest s whose cumulative probability covers the draw. Rounding can leave
        # prob_dis[-1] just below a float32 draw of 1.0; default to the last bucket then,
        # instead of silently reusing the previous instance's count.
        num_comp = K - 1
        for jj in range(K - 1):  # 0 to K-2
            if random_n[j] <= prob_dis[jj]:
                num_comp = jj + 1
                break

        candidates = torch.from_numpy(np.random.permutation(K))
        candidates = candidates[candidates != labels[j]]
        partialY[j, candidates[:num_comp]] = 0  # fulfill the partial label matrix
    return partialY
107 |
108 |
def generate_mul_comp_labels(data, labels, s):
    r"""
    Generating multiple complementary labels given the fixed size s of complementary label set \bar{Y} of each instance
    by "Learning with Multiple Complementary Labels" by Lei Feng et al.

    Each instance receives exactly ``s`` distinct complementary labels, none equal to its
    true label, drawn one at a time without replacement.

    Args:
        data: unused; kept for interface compatibility.
        labels: 1-D integer tensor of true labels (K = max(labels) + 1).
        s: number of complementary labels per instance.

    Returns:
        Float tensor of shape (n, K): 0 at complementary labels, 1 elsewhere.

    Raises:
        ValueError: if ``s`` exceeds the K - 1 non-true classes.
    """
    k = int(torch.max(labels)) + 1
    n = labels.shape[0]
    if s > k - 1:
        raise ValueError("s={} exceeds the {} classes available besides the true label".format(s, k - 1))
    index_ins = torch.arange(n)  # torch type
    partialY = torch.ones(n, k)

    labels_hat = labels.clone().numpy()
    candidates = np.repeat(np.arange(k).reshape(1, k), n, 0)
    mask = np.ones((n, k), dtype=bool)  # True = class still available for sampling
    for i in range(s):
        # Remove the label chosen in the previous round (the true label when i == 0).
        mask[np.arange(n), labels_hat] = False
        candidates_ = candidates[mask].reshape(n, k - 1 - i)
        idx = np.random.randint(0, k - 1 - i, n)
        comp_labels = candidates_[np.arange(n), idx]
        partialY[index_ins, torch.from_numpy(comp_labels)] = 0
        labels_hat = comp_labels
    return partialY
132 |
133 |
def class_prior(complementary_labels):
    r"""Return the empirical prior p(\bar{y}) of every complementary class.

    Entry c is the fraction of ``complementary_labels`` equal to c; the result has
    ``max(complementary_labels) + 1`` entries.
    """
    counts = np.bincount(complementary_labels)
    total = len(complementary_labels)
    return counts / total
137 |
138 |
def prepare_data(args):
    """Load a torchvision dataset and wrap both splits in unshuffled loaders.

    Reads ``args.dataset`` and ``args.batch_size``; images go through ``ToTensor()``.
    For SVHN the train split also gets ``targets`` and ``classes`` attributes so it
    can be handled like the other datasets.

    Returns:
        (train_loader, test_loader, ordinary_train_dataset, test_dataset,
         num_classes, input_dim, input_channel)

    Raises:
        ValueError: if ``args.dataset`` is not one of mnist, kuzushiji, fashion,
            cifar10, svhn, cifar100.
    """
    dataset, batch_size = args.dataset, args.batch_size
    if dataset == "mnist":
        ordinary_train_dataset = dsets.MNIST(root='./data/MNIST', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dsets.MNIST(root='./data/MNIST', train=False, transform=transforms.ToTensor())
        input_dim, input_channel, num_classes = 28 * 28, 1, 10
    elif dataset == "kuzushiji":
        ordinary_train_dataset = dsets.KMNIST(root='./data/KMNIST', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dsets.KMNIST(root='./data/KMNIST', train=False, transform=transforms.ToTensor())
        input_dim, input_channel, num_classes = 28 * 28, 1, 10
    elif dataset == "fashion":
        ordinary_train_dataset = dsets.FashionMNIST(root='./data/FashionMNIST', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dsets.FashionMNIST(root='./data/FashionMNIST', train=False, transform=transforms.ToTensor())
        input_dim, input_channel, num_classes = 28 * 28, 1, 10
    elif dataset == "cifar10":
        ordinary_train_dataset = dsets.CIFAR10(root='./data/CIFAR10', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dsets.CIFAR10(root='./data/CIFAR10', train=False, transform=transforms.ToTensor())
        input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 10
    elif dataset == "svhn":
        ordinary_train_dataset = dsets.SVHN(root='./data/SVHN', split='train', download=True, transform=transforms.ToTensor())
        test_dataset = dsets.SVHN(root='./data/SVHN', split='test', download=True, transform=transforms.ToTensor())
        # SVHN exposes `labels`; mirror the attributes the other datasets provide.
        ordinary_train_dataset.targets = ordinary_train_dataset.labels
        ordinary_train_dataset.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
        input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 10
    elif dataset == "cifar100":
        ordinary_train_dataset = dsets.CIFAR100(root='./data/CIFAR100', train=True, transform=transforms.ToTensor(), download=True)
        test_dataset = dsets.CIFAR100(root='./data/CIFAR100', train=False, transform=transforms.ToTensor())
        input_dim, input_channel, num_classes = 3 * 32 * 32, 3, 100
    else:
        # Without this, an unknown name surfaced later as an UnboundLocalError.
        raise ValueError("Unsupported dataset: {}".format(dataset))
    train_loader = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, ordinary_train_dataset, test_dataset, num_classes, input_dim, input_channel
170 |
171 |
def prepare_train_loaders(args, ordinary_train_dataset, data_aug=None):
    """Build the shuffled complementary-label training loader.

    ccp is only used for "free", "ga", "nn"; partialY is used for multiple complementary labels.

    Args:
        args: reads ``batch_size`` and ``cl_num``, plus ``dataset`` when ``data_aug`` is given.
        ordinary_train_dataset: dataset exposing ``data``, ``targets`` and ``classes``.
        data_aug: optional transform. When given, raw images are wrapped in ``AugComp`` and
            transformed per access; otherwise the whole dataset is loaded once as tensors.

    Returns:
        (complementary_train_loader, ccp, partialY, ema), where ``partialY`` is an (n, K)
        mask with 0 at complementary labels and ``ema`` is ``partialY`` normalised per row.
    """
    batch_size, cl_num = args.batch_size, args.cl_num

    if data_aug is None:
        # Pull the entire (already transformed) training set as a single batch.
        whole_set = torch.utils.data.DataLoader(dataset=ordinary_train_dataset, batch_size=len(ordinary_train_dataset.data), shuffle=False)
        for data, labels in whole_set:
            K = torch.max(labels) + 1
            bs = labels.size(0)
    else:
        # Keep the raw images; AugComp converts and augments them lazily.
        data = ordinary_train_dataset.data
        labels = torch.LongTensor(ordinary_train_dataset.targets)
        bs, K = len(ordinary_train_dataset.data), len(ordinary_train_dataset.classes)

    # A single complementary label is always drawn first; it feeds ccp and the dataset's `cl` field.
    complementary_labels = generate_compl_labels(labels)
    ccp = class_prior(complementary_labels)
    if cl_num == 1:
        partialY = torch.ones(bs, K).scatter_(1, torch.LongTensor(complementary_labels).unsqueeze(1), 0)
    elif cl_num == 0:
        partialY = generate_uniform_mul_comp_labels(labels)
    else:  # 2-9
        partialY = generate_mul_comp_labels(data, labels, cl_num)
    ema = partialY / partialY.sum(1, keepdim=True)

    sample_ids = torch.arange(bs)
    cl_tensor = torch.from_numpy(complementary_labels).long()
    if data_aug is None:
        complementary_dataset = torch.utils.data.TensorDataset(data, cl_tensor, labels, sample_ids)
    else:
        complementary_dataset = AugComp(name=args.dataset, data=data, cl=cl_tensor, tl=labels, id=sample_ids, transform=data_aug)
    complementary_train_loader = torch.utils.data.DataLoader(dataset=complementary_dataset, batch_size=batch_size, shuffle=True)

    return complementary_train_loader, ccp, partialY, ema
207 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse, time, os, random, math
2 | import pprint as pp
3 | import torchvision
4 | from utils_data import *
5 | from utils_algo import *
6 | from utils_mcl_loss import *
7 | from models import *
8 | from attack_generator import *
9 | from utils_func import *
10 |
11 |
12 | def adversarial_train(args, model, epochs, mode="atcl", seed=1):
13 | print(">> Current mode: {} for {} epochs".format(mode, epochs))
14 | lr = args.cl_lr if mode == "cl" else args.at_lr
15 | loss_fn, best_nat_acc, best_pgd20_acc, best_cw_acc, best_epoch = create_loss_fn(args), 0, 0, 0, 0
16 | nature_train_acc_list, nature_test_acc_list, pgd20_acc_list, cw_acc_list = [], [], [], []
17 | first_layer_grad, last_layer_grad = [], []
18 | if mode == "cl":
19 | optimizer = torch.optim.SGD(model.parameters(), lr=args.cl_lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.dataset in ["cifar10", "svhn", "cifar100"] else torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay, lr=args.cl_lr)
20 | elif mode == "at":
21 | optimizer = torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.dataset in ["cifar10", "svhn", "cifar100"] else torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum)
22 | cl_model = create_model(args, input_dim, input_channel, K)
23 | checkpoint = torch.load(os.path.join(args.out_dir, "cl_best_checkpoint_seed{}.pth.tar".format(seed)))
24 | cl_model.load_state_dict(checkpoint['state_dict'])
25 | cl_model.eval()
26 | print(">> Load the CL model with train acc: {}, test acc: {}, epoch {}".format(checkpoint['train_acc'], checkpoint['test_acc'], checkpoint['epoch']))
27 | elif mode == "atcl":
28 | optimizer = torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.dataset in ["cifar10", "svhn", "cifar100"] else torch.optim.SGD(model.parameters(), lr=args.at_lr, momentum=args.momentum)
29 |
30 | for epoch in range(epochs):
31 | correct, total = 0, 0
32 | lr = (epoch+1) / 5 * args.at_lr if (mode == "cl" or mode == "atcl") and (epoch < 5 and args.dataset in ["cifar10", "svhn", "cifar100"]) else lr # lr warmup for cll
33 | lr = cl_lr_schedule(lr, epoch + 1, optimizer) if mode == "cl" else adv_lr_schedule(lr, epoch + 1, optimizer)
34 | for batch_idx, (images, cl_labels, true_labels, id) in enumerate(complementary_train_loader):
35 | images, cl_labels, true_labels = images.to(device), cl_labels.to(device), true_labels.to(device)
36 |
37 | if mode == "cl":
38 | x_adv = images
39 | elif mode == "at":
40 | pseudo_labels, prob = get_pred(cl_model, images.detach())
41 | pseudo_labels = pseudo_labels.to(device)
42 | x_adv = pgd(model, images, pseudo_labels, args.epsilon, args.step_size, args.num_steps, loss_fn="cent", category="Madry", rand_init=True, num_classes=K)
43 | else: # "atcl"
44 | _, prob = get_pred(model, images.detach())
45 | epsilon, step_size, num_steps = at_param_schedule(args, epoch + 1)
46 | alpha = atcl_scheduler(args, id, prob, epsilon, partialY, epoch + 1)
47 | _, pseudo_labels = torch.max(ema[id], 1)
48 | # pseudo_labels = torch.multinomial(ema[id], 1).squeeze(1)
49 | correct += (pseudo_labels == true_labels).sum().item()
50 | total += pseudo_labels.size(0)
51 | if epsilon != 0:
52 | # x_adv = pgd(model, images, true_labels, epsilon, step_size, num_steps, loss_fn="cent", category="Madry", rand_init=True, num_classes=K) # mex_ce_tl
53 | # x_adv = pgd(model, images, None, epsilon, step_size, num_steps, loss_fn="kl", category="trades", rand_init=True, num_classes=K) # max_trades
54 | x_adv = cl_adv(args, model, images, cl_labels, epsilon, step_size, num_steps, id, ccp, partialY, pseudo_labels, alpha, loss_fn, category="Madry", rand_init=True) # max_cl
55 | else:
56 | x_adv = images
57 |
58 | if batch_idx == 0:
59 | torchvision.utils.save_image(x_adv, os.path.join(args.out_dir, "x_adv_seed_{}_epoch_{}.jpg".format(seed, epoch+1)))
60 |
61 | model.train()
62 | optimizer.zero_grad()
63 | logit = model(x_adv)
64 | if mode == "at":
65 | loss = nn.CrossEntropyLoss(reduction="mean")(logit, pseudo_labels)
66 | else: # "cl" or "atcl"
67 | if args.method in ['exp', 'log']:
68 | loss = loss_fn(logit, partialY[id].float())
69 | elif args.method in ['mae', 'mse', 'ce', 'gce', 'phuber_ce']:
70 | loss = unbiased_estimator(loss_fn, logit, partialY[id].float())
71 | elif args.method in ['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'l_uw', 'l_w']:
72 | assert args.cl_num == 1
73 | loss, _ = chosen_loss_c(f=logit, K=K, labels=cl_labels, ccp=ccp, meta_method=args.method)
74 | elif args.method in ["log_ce", "exp_ce"]:
75 | loss = loss_fn(logit, partialY[id].float(), pseudo_labels, alpha)
76 | loss.backward()
77 |
78 | # the grad stat
79 | if args.model == "cnn":
80 | first_layer_grad.append(model.module.feature_extractor.conv1.weight.grad.norm(p=2).cpu().item())
81 | last_layer_grad.append(model.module.classifier.fc3.weight.grad.norm(p=2).cpu().item())
82 | elif args.model in ["densenet", "resnet18", "wrn"]:
83 | first_layer_grad.append(model.module.conv1.weight.grad.norm(p=2).cpu().item())
84 | last_layer_grad.append(model.module.linear.weight.grad.norm(p=2).cpu().item())
85 | else:
86 | assert True, "please modify grad stat code!"
87 | optimizer.step()
88 |
89 | # Evalutions
90 | model.eval()
91 | train_nat_acc = accuracy_check(loader=train_loader, model=model)
92 | test_nat_acc = accuracy_check(loader=test_loader, model=model)
93 | _, test_pgd20_acc = eval_robust(model, test_loader, perturb_steps=20, epsilon=args.epsilon, step_size=args.step_size, loss_fn="cent", category="Madry", random=True, num_classes=K)
94 | _, test_cw_acc = eval_robust(model, test_loader, perturb_steps=30, epsilon=args.epsilon, step_size=args.step_size, loss_fn="cw", category="Madry", random=True, num_classes=K)
95 | nature_train_acc_list.append(train_nat_acc);nature_test_acc_list.append(test_nat_acc);pgd20_acc_list.append(test_pgd20_acc);cw_acc_list.append(test_cw_acc)
96 | if mode == "atcl":
97 | print(round(correct/total*100, 2), alpha, epsilon, step_size, num_steps)
98 | print('Epoch: [%d | %d] | Learning Rate: %f | Natural Train Acc %.4f | Natural Test Acc %.4f | PGD20 Test Acc %.4f | CW Test Acc %.4f |\n' % (epoch + 1, epochs, lr, train_nat_acc, test_nat_acc, test_pgd20_acc, test_cw_acc))
99 |
100 | # Save the best & last checkpoint
101 | if mode == "cl":
102 | if test_nat_acc > best_nat_acc:
103 | best_epoch = epoch + 1
104 | best_nat_acc = test_nat_acc
105 | torch.save({
106 | 'epoch': epoch + 1,
107 | 'state_dict': model.state_dict(),
108 | 'train_acc': train_nat_acc,
109 | 'test_acc': test_nat_acc,
110 | 'optimizer': optimizer.state_dict(),
111 | }, os.path.join(args.out_dir, "cl_best_checkpoint_seed{}.pth.tar".format(seed)))
112 | torch.save({
113 | 'epoch': epoch + 1,
114 | 'state_dict': model.state_dict(),
115 | 'train_acc': train_nat_acc,
116 | 'test_acc': test_nat_acc,
117 | 'optimizer': optimizer.state_dict(),
118 | }, os.path.join(args.out_dir, "cl_checkpoint_seed{}.pth.tar".format(seed)))
119 | else:
120 | if test_pgd20_acc > best_pgd20_acc:
121 | best_epoch = epoch + 1
122 | best_nat_acc = test_nat_acc
123 | best_pgd20_acc = test_pgd20_acc
124 | best_cw_acc = test_cw_acc
125 | torch.save({
126 | 'epoch': epoch + 1,
127 | 'state_dict': model.state_dict(),
128 | 'test_nat_acc': test_nat_acc,
129 | 'test_pgd20_acc': test_pgd20_acc,
130 | 'test_cw_acc': test_cw_acc,
131 | 'optimizer': optimizer.state_dict(),
132 | }, os.path.join(args.out_dir, "best_checkpoint_seed{}.pth.tar".format(seed)))
133 | torch.save({
134 | 'epoch': epoch + 1,
135 | 'state_dict': model.state_dict(),
136 | 'test_nat_acc': test_nat_acc,
137 | 'test_pgd20_acc': test_pgd20_acc,
138 | 'test_cw_acc': test_cw_acc,
139 | 'optimizer': optimizer.state_dict(),
140 | }, os.path.join(args.out_dir, "checkpoint_seed{}.pth.tar".format(seed)))
141 |
142 | if mode == "cl":
143 | print(nature_train_acc_list)
144 | print(nature_test_acc_list)
145 | print(">> Best test acc({}): {}".format(best_epoch, max(nature_test_acc_list)))
146 | print(">> AVG test acc of last 10 epochs: {}".format(np.mean(nature_test_acc_list[-10:])))
147 | epoch = [i for i in range(epochs)]
148 | show([epoch] * 2, [nature_train_acc_list, nature_test_acc_list], label=["train acc", "test acc"], title=args.dataset, xdes="Epoch", ydes="Accuracy", path=os.path.join(args.out_dir, "cl_acc_seed{}.png".format(seed)))
149 |
150 | return np.mean(nature_test_acc_list[-10:])
151 | else:
152 | print(nature_test_acc_list)
153 | print(pgd20_acc_list)
154 | print(cw_acc_list)
155 | print(">> Finished Adv Training: Natural Test Acc | Last_checkpoint %.4f | Best_checkpoint(%.1f) %.4f |\n" % (test_nat_acc, best_epoch, best_nat_acc))
156 | print(">> Finished Adv Training: PGD20 Test Acc | Last_checkpoint %.4f | Best_checkpoint %.4f |\n" % (test_pgd20_acc, best_pgd20_acc))
157 | print(">> Finished Adv Training: CW Test Acc | Last_checkpoint %.4f | Best_checkpoint %.4f |\n" % (test_cw_acc, best_cw_acc))
158 | epoch = [i for i in range(epochs)]
159 | show([epoch, epoch, epoch], [nature_test_acc_list, pgd20_acc_list, cw_acc_list], label=["nature test acc", "pgd20 acc", "cw acc"],
160 | title=args.dataset, xdes="Epoch", ydes="Test Accuracy", path=os.path.join(args.out_dir, "adv_test_acc_seed{}.png".format(seed)))
161 | print("first_layer: \n{} \nlast_layer: \n{}".format(first_layer_grad, last_layer_grad), file=open(os.path.join(args.out_dir, "grad_seed{}.out".format(seed)), "a+"))
162 | # Auto-attack
163 | aa_eval(args, model, filename="aa_last.txt")
164 | best_checkpoint = torch.load(os.path.join(args.out_dir, "best_checkpoint_seed{}.pth.tar".format(seed)))
165 | model.load_state_dict(best_checkpoint['state_dict'])
166 | aa_eval(args, model, filename="aa_best.txt")
167 |
168 | return [test_nat_acc, test_pgd20_acc, test_cw_acc], [best_nat_acc, best_pgd20_acc, best_cw_acc]
169 |
170 |
def aa_eval(args, model, filename):
    """
    AutoAttack evaluation - pip install git+https://github.com/fra31/auto-attack

    Runs the "standard" Linf AutoAttack suite with budget ``args.epsilon`` on the
    first 10000 examples of the module-level ``test_loader``. The AutoAttack log is
    written to ``os.path.join(args.out_dir, filename)`` via ``log_path``; the
    adversarial examples themselves are not saved.

    Args:
        args: namespace providing ``epsilon`` and ``out_dir``.
        model: classifier to attack; it is switched to eval mode.
        filename: name of the log file inside ``args.out_dir``.
    """
    from autoattack import AutoAttack
    model.eval()
    version, norm, individual, n_ex = "standard", "Linf", False, 10000
    adversary = AutoAttack(model, norm=norm, eps=args.epsilon, log_path=os.path.join(args.out_dir, filename), version=version)

    # Gather inputs and labels in ONE pass over the loader so they stay paired even
    # if test_loader shuffles (two separate passes could visit different orders).
    inputs, targets = [], []
    for x, y in test_loader:
        inputs.append(x)
        targets.append(y)
    x_test = torch.cat(inputs, 0)[:n_ex]
    y_test = torch.cat(targets, 0)[:n_ex]

    # NOTE(review): the outer no_grad relies on AutoAttack re-enabling gradients
    # internally for its attacks -- confirm against the installed AutoAttack version.
    with torch.no_grad():
        if not individual:
            adversary.run_standard_evaluation(x_test, y_test, bs=500)
        else:
            # individual version, each attack is run on all test points
            adversary.run_standard_evaluation_individual(x_test, y_test, bs=500)
194 |
195 |
def at_param_schedule(args, epoch):
    """Return the PGD parameters ``(epsilon, step_size, num_steps)`` for ``epoch``.

    * ``epoch <= args.warmup_epoch``: no attack, ``(0, 0, 0)``.
    * ``args.warmup_epoch < epoch <= args.warmup_epoch + args.sch_epoch``: epsilon
      ramps up according to ``args.scheduler`` ("linear", "cosine" or "none"), and
      the step size is scaled by the same ratio ``eps / args.epsilon``.
    * afterwards: ``(args.epsilon, args.step_size, args.num_steps)``.

    Raises:
        ValueError: if ``args.scheduler`` is unsupported and ``epoch`` falls in the
            ramp window (this previously surfaced as an UnboundLocalError).
    """
    if epoch <= args.warmup_epoch:
        return 0, 0, 0
    if epoch > args.warmup_epoch + args.sch_epoch:
        return args.epsilon, args.step_size, args.num_steps

    # Inside the ramp window sch_epoch > 0, so the divisions below are safe.
    elapsed = epoch - args.warmup_epoch
    if args.scheduler == "linear":
        eps = min(args.epsilon * (elapsed / args.sch_epoch), args.epsilon)
    elif args.scheduler == "cosine":
        eps = 1 / 2 * (1 - math.cos(math.pi * min(elapsed / args.sch_epoch, 1))) * args.epsilon
    elif args.scheduler == "none":
        eps = args.epsilon
    else:
        raise ValueError("unknown epsilon scheduler: {!r}".format(args.scheduler))

    if args.epsilon == 0:
        # A zero budget implies a zero step; avoids dividing by args.epsilon.
        return eps, 0, args.num_steps
    # Keep step_size / epsilon constant while epsilon ramps up.
    return eps, args.step_size / args.epsilon * eps, args.num_steps
211 |
212 |
def atcl_scheduler(args, id, prob, epsilon, partialY, epoch):
    """Compute this epoch's mixing weight ``alpha`` and refresh the EMA rows ``id``.

    alpha:
        * ``args.sch_epoch == 0``: 1 while ``epoch <= args.warmup_epoch``, else 0.
        * otherwise: ``1 - (epoch - warmup_epoch) / sch_epoch`` clipped to [0, 1].

    Side effect on the module-level ``ema`` tensor (rows selected by ``id``):
        * only when ``epoch > 5`` and ``epsilon < args.epsilon / 2`` are the rows
          blended as ``0.9 * ema + 0.1 * prob``;
        * the rows are then always multiplied by ``partialY[id]`` ("reset to 0 for
          cls"). NOTE(review): assumes partialY is a 0/1 mask whose zeros mark the
          complementary labels -- confirm against prepare_train_loaders.
    """
    rows = id
    if args.sch_epoch == 0:
        alpha = 0 if epoch > args.warmup_epoch else 1
    else:
        remaining = 1 - (epoch - args.warmup_epoch) / args.sch_epoch
        alpha = min(max(remaining, 0), 1)

    should_update = epoch > 5 and epsilon < args.epsilon / 2
    if should_update:
        ema[rows] = 0.9 * ema[rows] + (1 - 0.9) * prob

    ema[rows] = ema[rows] * partialY[rows]  # reset to 0 for cls
    return alpha
229 |
230 |
def cl_lr_schedule(lr, epoch, optimizer):
    """Set every parameter group of ``optimizer`` to ``lr`` and return ``lr``.

    ``epoch`` is unused; it is accepted so this can be called interchangeably
    with ``adv_lr_schedule(lr, epoch, optimizer)``.
    """
    groups = optimizer.param_groups
    for group in groups:
        group.update(lr=lr)
    return lr
235 |
236 |
def adv_lr_schedule(lr, epoch, optimizer, config=None):
    """Apply the adversarial-training step decay, write it to ``optimizer`` and return it.

    For "cifar10"/"svhn"/"cifar100" the rate is divided by 10 at epoch
    ``30 + warmup_epoch`` and again at ``60 + warmup_epoch``. The easy datasets
    ("mnist", "fashion", "kuzushiji") get no decay. The resulting rate is written
    to every parameter group.

    Args:
        lr: current learning rate, carried over from the previous epoch.
        epoch: 1-based epoch index.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        config: namespace with ``dataset`` and ``warmup_epoch``. Defaults to the
            module-level ``args`` parsed in ``__main__`` (the original behaviour);
            pass it explicitly to use this function outside the script.
    """
    if config is None:
        config = args
    if config.dataset in ["cifar10", "svhn", "cifar100"]:
        if epoch == (30 + config.warmup_epoch):
            lr /= 10
        if epoch == (60 + config.warmup_epoch):
            lr /= 10
    # mnist / fashion / kuzushiji: no lr_decay for easy dataset
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
249 |
250 |
def accuracy_check(loader, model):
    """Return the top-1 accuracy of ``model`` on ``loader`` in percent, rounded to 2 decimals.

    ``loader`` must yield ``(images, labels)`` batches. Each batch is moved to the
    module-level ``device``. The model is put in eval mode and left there.
    Returns 0.0 for an empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    correct, num_samples = 0, 0
    # Evaluation only: skip building autograd graphs for the forward passes.
    with torch.no_grad():
        for images, labels in loader:
            labels, images = labels.to(device), images.to(device)
            sm_outputs = F.softmax(model(images), dim=1)
            _, predicted = torch.max(sm_outputs, 1)
            correct += (predicted == labels).sum().item()
            num_samples += labels.size(0)
    if num_samples == 0:
        return 0.0
    return round(100 * correct / num_samples, 2)
263 |
264 |
def create_loss_fn(args):
    """Return the loss callable selected by ``args.method``, or ``None``.

    Only the methods listed below have a dedicated loss function here. Any other
    method (e.g. 'free', 'nn', 'ga', 'pc', 'forward', 'scl_exp', 'scl_nl', 'l_uw',
    'l_w') yields ``None``; the training loop handles those through
    ``chosen_loss_c``.
    """
    method = args.method
    if method == 'mae':
        return mae_loss
    if method == 'mse':
        return mse_loss
    if method == 'ce':
        return ce_loss
    if method == 'gce':
        return gce_loss
    if method == 'phuber_ce':
        return phuber_ce_loss
    if method == 'log':
        return log_loss
    if method == 'exp':
        return exp_loss
    if method == 'log_ce':
        return log_ce_loss
    if method == "exp_ce":
        return exp_ce_loss
    return None
288 |
289 |
def create_model(args, input_dim, input_channel, K):
    """Build the architecture named by ``args.model`` and prepare it for training.

    The parameter count is reported via ``display_num_param``. The model is then
    moved to the module-level ``device`` and wrapped in ``torch.nn.DataParallel``,
    so callers reach its layers through ``model.module``.

    Args:
        args: namespace whose ``model`` field selects the architecture.
        input_dim: flattened input size (used by 'mlp' and 'linear').
        input_channel: number of input channels (used by the conv nets except 'cnn').
        K: number of classes.

    Raises:
        ValueError: if ``args.model`` is not a supported name (this previously
            surfaced as an UnboundLocalError).
    """
    name = args.model
    if name == 'mlp':
        model = mlp_model(input_dim=input_dim, hidden_dim=500, output_dim=K)
    elif name == 'linear':
        model = linear_model(input_dim=input_dim, output_dim=K)
    elif name == 'cnn':
        # NOTE(review): SmallCNN() ignores input_channel and K; presumably it is
        # fixed to the MNIST-like setting -- confirm in models/base_models.py.
        model = SmallCNN()
    elif name == 'resnet18':
        model = ResNet18(input_channel=input_channel, num_classes=K)
    elif name == 'densenet':
        model = densenet(input_channel=input_channel, num_classes=K)
    elif name == 'preact_resnet18':
        model = preactresnet18(input_channel=input_channel, num_classes=K)
    elif name == "wrn":
        model = Wide_ResNet_Madry(depth=32, num_classes=K, widen_factor=10, dropRate=0.0, input_channel=input_channel)  # WRN-32-10
    else:
        raise ValueError("unsupported model: {!r}".format(name))

    display_num_param(model)
    model = model.to(device)
    model = torch.nn.DataParallel(model)

    return model
311 |
312 |
def get_pred(cl_model, data):
    """Run ``cl_model`` on ``data`` without gradients.

    ``data`` is moved to the module-level ``device``. As a side effect,
    ``cl_model`` is switched to eval mode and left there.

    Returns:
        tuple: ``(predicted, probs)``, where ``predicted`` is the arg-max index
        along dim 1 of the model output and ``probs`` is its softmax along dim 1.
    """
    cl_model.eval()
    with torch.no_grad():
        logits = cl_model(data.to(device))
    # logits were produced under no_grad, so nothing below records gradients either.
    predicted = torch.max(logits, 1)[1]
    probs = torch.softmax(logits, dim=1)
    return predicted, probs
321 |
322 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Learning with Complementary Labels')
    parser.add_argument('--cl_lr', type=float, default=1e-3, help='learning rate for complementary learning')
    parser.add_argument('--at_lr', type=float, default=1e-2, help='learning rate for adversarial training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch_size of ordinary labels.')
    parser.add_argument('--cl_num', type=int, default=1, help='(1-9): the number of complementary labels of each data; (0): mul-cls data distribution of ICML2020')
    parser.add_argument('--dataset', type=str, default="mnist", choices=['mnist', 'kuzushiji', 'fashion', 'cifar10', 'svhn', 'cifar100'],
                        help="dataset, choose from mnist, kuzushiji, fashion, cifar10, svhn, cifar100")
    parser.add_argument('--framework', type=str, default='one_stage', choices=['one_stage', 'two_stage'])
    parser.add_argument('--method', type=str, default='log', choices=['free', 'nn', 'ga', 'pc', 'forward', 'scl_exp',
                        'scl_nl', 'mae', 'mse', 'ce', 'gce', 'phuber_ce', 'log', 'exp', 'l_uw', 'l_w', 'log_ce', 'exp_ce'])
    parser.add_argument('--model', type=str, default='cnn', choices=['linear', 'mlp', 'cnn', 'resnet18', 'densenet', 'preact_resnet18', 'wrn'], help='model name')
    parser.add_argument('--cl_epochs', default=0, type=int, help='number of epochs for cl learning')
    parser.add_argument('--adv_epochs', default=100, type=int, help='number of epochs for adv')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--seed', type=int, nargs='+', default=[1, 2, 3], help='random seed')
    parser.add_argument('--out_dir', type=str, default='./ATCL_result', help='dir of output')
    # for adv training
    parser.add_argument('--epsilon', type=float, default=0.3, help='perturbation bound')
    parser.add_argument('--num_steps', type=int, default=40, help='maximum perturbation step K')
    parser.add_argument('--step_size', type=float, default=0.01, help='step size')
    parser.add_argument('--scheduler', type=str, default="none", choices=['linear', 'cosine', 'none'], help='epsilon scheduler')
    parser.add_argument('--sch_epoch', type=int, default=0, help='scheduler epoch')
    parser.add_argument('--warmup_epoch', type=int, default=0, help='warmup epoch for exponential moving average')
    args = parser.parse_args()

    # Hardcoded overrides for cifar10 / svhn / cifar100
    if args.dataset in ["cifar10", "svhn", "cifar100"]:
        args.weight_decay, args.batch_size = 5e-4, 128
        args.epsilon, args.num_steps, args.step_size = 8/255, 10, 2/255

    pp.pprint(vars(args))

    # NOTE: `args`, `device`, `test_loader` and `ema` are read as module-level globals
    # by helper functions in this file, so this script body must stay at module scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Store path (exist_ok avoids the check-then-create race).
    os.makedirs(args.out_dir, exist_ok=True)

    # Seed-independent, so it is built once rather than once per seed.
    data_aug = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]) if args.dataset in ["cifar10", "svhn", "cifar100"] else None

    last_nature, last_pgd20, last_cw, best_nature, best_pgd20, best_cw = [], [], [], [], [], []
    for seed in args.seed:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        train_loader, test_loader, ordinary_train_dataset, test_dataset, K, input_dim, input_channel = prepare_data(args)
        complementary_train_loader, ccp, partialY, ema = prepare_train_loaders(args, ordinary_train_dataset=ordinary_train_dataset, data_aug=data_aug)
        partialY, ema = partialY.to(device), ema.to(device)

        model = create_model(args, input_dim, input_channel, K)
        if args.framework == "two_stage":
            # Stage 1 trains a CL model; stage 2 adversarially trains a fresh model
            # on pseudo-labels predicted by the stage-1 checkpoint.
            adversarial_train(args, model, args.cl_epochs, mode="cl", seed=seed)
            model = create_model(args, input_dim, input_channel, K)
            last_res, best_res = adversarial_train(args, model, args.adv_epochs, mode="at", seed=seed)
        else:
            last_res, best_res = adversarial_train(args, model, args.adv_epochs, mode="atcl", seed=seed)

        last_nature.append(last_res[0])
        last_pgd20.append(last_res[1])
        last_cw.append(last_res[2])
        best_nature.append(best_res[0])
        best_pgd20.append(best_res[1])
        best_cw.append(best_res[2])

    for values in (last_nature, last_pgd20, last_cw, best_nature, best_pgd20, best_cw):
        print(values)
    summary = (
        ("Last Nature", last_nature),
        ("Last PGD20", last_pgd20),
        ("Last CW", last_cw),
        ("Best Nature", best_nature),
        ("Best PGD20", best_pgd20),
        ("Best CW", best_cw),
    )
    for title, values in summary:
        # "\\pm" prints a literal "\pm"; the former "\p" was an invalid escape
        # sequence (SyntaxWarning on Python 3.12+).
        print(">> {}: {}($\\pm${})".format(title, round(np.mean(values), 2), round(np.std(values, ddof=0), 2)))
399 |
--------------------------------------------------------------------------------