├── InceptionResNetV2.py
├── LICENSE
├── PreResNet.py
├── README.md
├── Train_cifar.py
├── Train_clothing1M.py
├── Train_webvision.py
├── Train_webvision_parallel.py
├── dataloader_cifar.py
├── dataloader_clothing1M.py
├── dataloader_webvision.py
└── img
└── framework.png
/InceptionResNetV2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | import torch
3 | import torch.nn as nn
4 | import os
5 | import sys
6 |
7 |
8 | class BasicConv2d(nn.Module):
9 |
10 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
11 | super(BasicConv2d, self).__init__()
12 | self.conv = nn.Conv2d(in_planes, out_planes,
13 | kernel_size=kernel_size, stride=stride,
14 |                              padding=padding, bias=False) # bias=False: BatchNorm follows, so a conv bias would be redundant
15 | self.bn = nn.BatchNorm2d(out_planes,
16 | eps=0.001, # value found in tensorflow
17 | momentum=0.1, # default pytorch value
18 | affine=True)
19 | self.relu = nn.ReLU(inplace=False)
20 |
21 | def forward(self, x):
22 | x = self.conv(x)
23 | x = self.bn(x)
24 | x = self.relu(x)
25 | return x
26 |
27 |
28 | class Mixed_5b(nn.Module):
29 |
30 | def __init__(self):
31 | super(Mixed_5b, self).__init__()
32 |
33 | self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
34 |
35 | self.branch1 = nn.Sequential(
36 | BasicConv2d(192, 48, kernel_size=1, stride=1),
37 | BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
38 | )
39 |
40 | self.branch2 = nn.Sequential(
41 | BasicConv2d(192, 64, kernel_size=1, stride=1),
42 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
43 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
44 | )
45 |
46 | self.branch3 = nn.Sequential(
47 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
48 | BasicConv2d(192, 64, kernel_size=1, stride=1)
49 | )
50 |
51 | def forward(self, x):
52 | x0 = self.branch0(x)
53 | x1 = self.branch1(x)
54 | x2 = self.branch2(x)
55 | x3 = self.branch3(x)
56 | out = torch.cat((x0, x1, x2, x3), 1)
57 | return out
58 |
59 |
60 | class Block35(nn.Module):
61 |
62 | def __init__(self, scale=1.0):
63 | super(Block35, self).__init__()
64 |
65 | self.scale = scale
66 |
67 | self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
68 |
69 | self.branch1 = nn.Sequential(
70 | BasicConv2d(320, 32, kernel_size=1, stride=1),
71 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
72 | )
73 |
74 | self.branch2 = nn.Sequential(
75 | BasicConv2d(320, 32, kernel_size=1, stride=1),
76 | BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
77 | BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
78 | )
79 |
80 | self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
81 | self.relu = nn.ReLU(inplace=False)
82 |
83 | def forward(self, x):
84 | x0 = self.branch0(x)
85 | x1 = self.branch1(x)
86 | x2 = self.branch2(x)
87 | out = torch.cat((x0, x1, x2), 1)
88 | out = self.conv2d(out)
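    |         # scaled residual connection: a small self.scale stabilizes training of deep Inception-ResNet stacks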
89 | out = out * self.scale + x
90 | out = self.relu(out)
91 | return out
92 |
93 |
94 | class Mixed_6a(nn.Module):
95 |
96 | def __init__(self):
97 | super(Mixed_6a, self).__init__()
98 |
99 | self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
100 |
101 | self.branch1 = nn.Sequential(
102 | BasicConv2d(320, 256, kernel_size=1, stride=1),
103 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
104 | BasicConv2d(256, 384, kernel_size=3, stride=2)
105 | )
106 |
107 | self.branch2 = nn.MaxPool2d(3, stride=2)
108 |
109 | def forward(self, x):
110 | x0 = self.branch0(x)
111 | x1 = self.branch1(x)
112 | x2 = self.branch2(x)
113 | out = torch.cat((x0, x1, x2), 1)
114 | return out
115 |
116 |
117 | class Block17(nn.Module):
118 |
119 | def __init__(self, scale=1.0):
120 | super(Block17, self).__init__()
121 |
122 | self.scale = scale
123 |
124 | self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
125 |
126 | self.branch1 = nn.Sequential(
127 | BasicConv2d(1088, 128, kernel_size=1, stride=1),
128 | BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
129 | BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
130 | )
131 |
132 | self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
133 | self.relu = nn.ReLU(inplace=False)
134 |
135 | def forward(self, x):
136 | x0 = self.branch0(x)
137 | x1 = self.branch1(x)
138 | out = torch.cat((x0, x1), 1)
139 | out = self.conv2d(out)
140 | out = out * self.scale + x
141 | out = self.relu(out)
142 | return out
143 |
144 |
145 | class Mixed_7a(nn.Module):
146 |
147 | def __init__(self):
148 | super(Mixed_7a, self).__init__()
149 |
150 | self.branch0 = nn.Sequential(
151 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
152 | BasicConv2d(256, 384, kernel_size=3, stride=2)
153 | )
154 |
155 | self.branch1 = nn.Sequential(
156 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
157 | BasicConv2d(256, 288, kernel_size=3, stride=2)
158 | )
159 |
160 | self.branch2 = nn.Sequential(
161 | BasicConv2d(1088, 256, kernel_size=1, stride=1),
162 | BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
163 | BasicConv2d(288, 320, kernel_size=3, stride=2)
164 | )
165 |
166 | self.branch3 = nn.MaxPool2d(3, stride=2)
167 |
168 | def forward(self, x):
169 | x0 = self.branch0(x)
170 | x1 = self.branch1(x)
171 | x2 = self.branch2(x)
172 | x3 = self.branch3(x)
173 | out = torch.cat((x0, x1, x2, x3), 1)
174 | return out
175 |
176 |
177 | class Block8(nn.Module):
178 |
179 | def __init__(self, scale=1.0, noReLU=False):
180 | super(Block8, self).__init__()
181 |
182 | self.scale = scale
183 | self.noReLU = noReLU
184 |
185 | self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
186 |
187 | self.branch1 = nn.Sequential(
188 | BasicConv2d(2080, 192, kernel_size=1, stride=1),
189 | BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
190 | BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
191 | )
192 |
193 | self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
194 | if not self.noReLU:
195 | self.relu = nn.ReLU(inplace=False)
196 |
197 | def forward(self, x):
198 | x0 = self.branch0(x)
199 | x1 = self.branch1(x)
200 | out = torch.cat((x0, x1), 1)
201 | out = self.conv2d(out)
202 | out = out * self.scale + x
203 | if not self.noReLU:
204 | out = self.relu(out)
205 | return out
206 |
207 |
208 | class InceptionResNetV2(nn.Module):
209 |
210 | def __init__(self, num_classes=1001):
211 | super(InceptionResNetV2, self).__init__()
212 |         # Special attributes
213 | self.input_space = None
214 | self.input_size = (299, 299, 3)
215 | self.mean = None
216 | self.std = None
217 | # Modules
218 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
219 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
220 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
221 | self.maxpool_3a = nn.MaxPool2d(3, stride=2)
222 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
223 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
224 | self.maxpool_5a = nn.MaxPool2d(3, stride=2)
225 | self.mixed_5b = Mixed_5b()
226 | self.repeat = nn.Sequential(
227 | Block35(scale=0.17),
228 | Block35(scale=0.17),
229 | Block35(scale=0.17),
230 | Block35(scale=0.17),
231 | Block35(scale=0.17),
232 | Block35(scale=0.17),
233 | Block35(scale=0.17),
234 | Block35(scale=0.17),
235 | Block35(scale=0.17),
236 | Block35(scale=0.17)
237 | )
238 | self.mixed_6a = Mixed_6a()
239 | self.repeat_1 = nn.Sequential(
240 | Block17(scale=0.10),
241 | Block17(scale=0.10),
242 | Block17(scale=0.10),
243 | Block17(scale=0.10),
244 | Block17(scale=0.10),
245 | Block17(scale=0.10),
246 | Block17(scale=0.10),
247 | Block17(scale=0.10),
248 | Block17(scale=0.10),
249 | Block17(scale=0.10),
250 | Block17(scale=0.10),
251 | Block17(scale=0.10),
252 | Block17(scale=0.10),
253 | Block17(scale=0.10),
254 | Block17(scale=0.10),
255 | Block17(scale=0.10),
256 | Block17(scale=0.10),
257 | Block17(scale=0.10),
258 | Block17(scale=0.10),
259 | Block17(scale=0.10)
260 | )
261 | self.mixed_7a = Mixed_7a()
262 | self.repeat_2 = nn.Sequential(
263 | Block8(scale=0.20),
264 | Block8(scale=0.20),
265 | Block8(scale=0.20),
266 | Block8(scale=0.20),
267 | Block8(scale=0.20),
268 | Block8(scale=0.20),
269 | Block8(scale=0.20),
270 | Block8(scale=0.20),
271 | Block8(scale=0.20)
272 | )
273 | self.block8 = Block8(noReLU=True)
274 | self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
275 | self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False)
276 | self.last_linear = nn.Linear(1536, num_classes)
277 |
278 | def features(self, input):
279 | x = self.conv2d_1a(input)
280 | x = self.conv2d_2a(x)
281 | x = self.conv2d_2b(x)
282 | x = self.maxpool_3a(x)
283 | x = self.conv2d_3b(x)
284 | x = self.conv2d_4a(x)
285 | x = self.maxpool_5a(x)
286 | x = self.mixed_5b(x)
287 | x = self.repeat(x)
288 | x = self.mixed_6a(x)
289 | x = self.repeat_1(x)
290 | x = self.mixed_7a(x)
291 | x = self.repeat_2(x)
292 | x = self.block8(x)
293 | x = self.conv2d_7b(x)
294 | return x
295 |
296 | def logits(self, features):
297 | x = self.avgpool_1a(features)
298 | x = x.view(x.size(0), -1)
299 | x = self.last_linear(x)
300 | return x
301 |
302 | def forward(self, input):
303 | x = self.features(input)
304 | x = self.logits(x)
305 | return x
306 |
307 |
308 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Junnan Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PreResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.autograd import Variable
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1):
16 | super(BasicBlock, self).__init__()
17 | self.conv1 = conv3x3(in_planes, planes, stride)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = conv3x3(planes, planes)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 |
22 | self.shortcut = nn.Sequential()
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
26 | nn.BatchNorm2d(self.expansion*planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += self.shortcut(x)
33 | out = F.relu(out)
34 | return out
35 |
36 |
37 | class PreActBlock(nn.Module):
38 | '''Pre-activation version of the BasicBlock.'''
39 | expansion = 1
40 |
41 | def __init__(self, in_planes, planes, stride=1):
42 | super(PreActBlock, self).__init__()
43 | self.bn1 = nn.BatchNorm2d(in_planes)
44 | self.conv1 = conv3x3(in_planes, planes, stride)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv2 = conv3x3(planes, planes)
47 |
48 | self.shortcut = nn.Sequential()
49 | if stride != 1 or in_planes != self.expansion*planes:
50 | self.shortcut = nn.Sequential(
51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
52 | )
53 |
54 | def forward(self, x):
55 | out = F.relu(self.bn1(x))
56 | shortcut = self.shortcut(out)
57 | out = self.conv1(out)
58 | out = self.conv2(F.relu(self.bn2(out)))
59 | out += shortcut
60 | return out
61 |
62 |
63 | class Bottleneck(nn.Module):
64 | expansion = 4
65 |
66 | def __init__(self, in_planes, planes, stride=1):
67 | super(Bottleneck, self).__init__()
68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
69 | self.bn1 = nn.BatchNorm2d(planes)
70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
71 | self.bn2 = nn.BatchNorm2d(planes)
72 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
73 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
74 |
75 | self.shortcut = nn.Sequential()
76 | if stride != 1 or in_planes != self.expansion*planes:
77 | self.shortcut = nn.Sequential(
78 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
79 | nn.BatchNorm2d(self.expansion*planes)
80 | )
81 |
82 | def forward(self, x):
83 | out = F.relu(self.bn1(self.conv1(x)))
84 | out = F.relu(self.bn2(self.conv2(out)))
85 | out = self.bn3(self.conv3(out))
86 | out += self.shortcut(x)
87 | out = F.relu(out)
88 | return out
89 |
90 |
91 | class PreActBottleneck(nn.Module):
92 | '''Pre-activation version of the original Bottleneck module.'''
93 | expansion = 4
94 |
95 | def __init__(self, in_planes, planes, stride=1):
96 | super(PreActBottleneck, self).__init__()
97 | self.bn1 = nn.BatchNorm2d(in_planes)
98 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
99 | self.bn2 = nn.BatchNorm2d(planes)
100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
101 | self.bn3 = nn.BatchNorm2d(planes)
102 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
103 |
104 | self.shortcut = nn.Sequential()
105 | if stride != 1 or in_planes != self.expansion*planes:
106 | self.shortcut = nn.Sequential(
107 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
108 | )
109 |
110 | def forward(self, x):
111 | out = F.relu(self.bn1(x))
112 | shortcut = self.shortcut(out)
113 | out = self.conv1(out)
114 | out = self.conv2(F.relu(self.bn2(out)))
115 | out = self.conv3(F.relu(self.bn3(out)))
116 | out += shortcut
117 | return out
118 |
119 |
120 | class ResNet(nn.Module):
121 | def __init__(self, block, num_blocks, num_classes=10):
122 | super(ResNet, self).__init__()
123 | self.in_planes = 64
124 |
125 | self.conv1 = conv3x3(3,64)
126 | self.bn1 = nn.BatchNorm2d(64)
127 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
128 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
129 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
130 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
131 | self.linear = nn.Linear(512*block.expansion, num_classes)
132 |
133 | def _make_layer(self, block, planes, num_blocks, stride):
134 | strides = [stride] + [1]*(num_blocks-1)
135 | layers = []
136 | for stride in strides:
137 | layers.append(block(self.in_planes, planes, stride))
138 | self.in_planes = planes * block.expansion
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x, lin=0, lout=5):
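    |         # lin/lout select the first and last stages to run, enabling a partial
    |         # forward pass (e.g. to inject or extract intermediate features)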
142 | out = x
143 | if lin < 1 and lout > -1:
144 | out = self.conv1(out)
145 | out = self.bn1(out)
146 | out = F.relu(out)
147 | if lin < 2 and lout > 0:
148 | out = self.layer1(out)
149 | if lin < 3 and lout > 1:
150 | out = self.layer2(out)
151 | if lin < 4 and lout > 2:
152 | out = self.layer3(out)
153 | if lin < 5 and lout > 3:
154 | out = self.layer4(out)
155 | if lout > 4:
156 | out = F.avg_pool2d(out, 4)
157 | out = out.view(out.size(0), -1)
158 | out = self.linear(out)
159 | return out
160 |
161 |
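    | # note: ResNet18 is built from pre-activation blocks (hence the file name PreResNet)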
162 | def ResNet18(num_classes=10):
163 | return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
164 |
165 | def ResNet34(num_classes=10):
166 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
167 |
168 | def ResNet50(num_classes=10):
169 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)
170 |
171 | def ResNet101(num_classes=10):
172 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes)
173 |
174 | def ResNet152(num_classes=10):
175 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes)
176 |
177 |
178 | def test():
179 | net = ResNet18()
180 | y = net(Variable(torch.randn(1,3,32,32)))
181 | print(y.size())
182 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DivideMix: Learning with Noisy Labels as Semi-supervised Learning
2 | PyTorch Code for the following paper at ICLR 2020:\
3 | Title: DivideMix: Learning with Noisy Labels as Semi-supervised Learning [pdf]\
4 | Authors: Junnan Li, Richard Socher, Steven C.H. Hoi\
5 | Institute: Salesforce Research
6 |
7 |
8 | Abstract\
9 | Deep neural networks are known to be annotation-hungry. Numerous efforts have been devoted to reduce the annotation cost when learning with deep networks. Two prominent directions include learning with noisy labels and semi-supervised learning by exploiting unlabeled data. In this work, we propose DivideMix, a novel framework for learning with noisy labels by leveraging semi-supervised learning techniques. In particular, DivideMix models the per-sample loss distribution with a mixture model to dynamically divide the training data into a labeled set with clean samples and an unlabeled set with noisy samples, and trains the model on both the labeled and unlabeled data in a semi-supervised manner. To avoid confirmation bias, we simultaneously train two diverged networks where each network uses the dataset division from the other network. During the semi-supervised training phase, we improve the MixMatch strategy by performing label co-refinement and label co-guessing on labeled and unlabeled samples, respectively. Experiments on multiple benchmark datasets demonstrate substantial improvements over state-of-the-art methods.
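  | 
  | As a minimal, self-contained sketch of the co-divide step (mirroring the eval_train function in Train_cifar.py), the snippet below fits a two-component Gaussian mixture to normalized per-sample losses and keeps the posterior of the low-mean component as the clean probability; the toy losses are illustrative only:
  | 
  | ```python
  | import numpy as np
  | from sklearn.mixture import GaussianMixture
  | 
  | def co_divide(losses, p_threshold=0.5):
  |     # normalize losses to [0, 1]; the GMM component with the smaller mean
  |     # models the clean samples, the other one the noisy samples
  |     losses = (losses - losses.min()) / (losses.max() - losses.min())
  |     input_loss = losses.reshape(-1, 1)
  |     gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
  |     gmm.fit(input_loss)
  |     prob_clean = gmm.predict_proba(input_loss)[:, gmm.means_.argmin()]
  |     return prob_clean > p_threshold, prob_clean
  | 
  | # toy data: 90 low-loss (clean) and 10 high-loss (noisy) samples
  | losses = np.concatenate([0.3 * np.random.rand(90), 0.7 + 0.3 * np.random.rand(10)])
  | pred_clean, prob_clean = co_divide(losses)
  | print('%d samples predicted clean' % pred_clean.sum())
  | ```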
10 |
11 |
12 | Illustration\
13 | <img src="./img/framework.png">
14 |
15 | Experiments\
16 | First, please create a folder named checkpoint to store the results.\
17 | mkdir checkpoint\
18 | Next, run \
19 | python Train_{dataset_name}.py --data_path path-to-your-data
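  | 
  | For example, to train on CIFAR-10 with 50% symmetric label noise (flags as defined in Train_cifar.py; r is the noise ratio):\
  | python Train_cifar.py --data_path ./cifar-10 --dataset cifar10 --num_class 10 --r 0.5 --noise_mode sym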
20 |
21 | Cite DivideMix\
22 | If you find the code useful in your research, please consider citing our paper:
23 |
24 |
25 | @inproceedings{
26 | li2020dividemix,
27 | title={DivideMix: Learning with Noisy Labels as Semi-supervised Learning},
28 | author={Junnan Li and Richard Socher and Steven C.H. Hoi},
29 | booktitle={International Conference on Learning Representations},
30 | year={2020},
31 | }
32 |
33 | License\
34 | This project is licensed under the terms of the MIT license.
35 |
--------------------------------------------------------------------------------
/Train_cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import random
9 | import os
10 | import argparse
11 | import numpy as np
12 | from PreResNet import *
13 | from sklearn.mixture import GaussianMixture
14 | import dataloader_cifar as dataloader
15 |
16 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
17 | parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
18 | parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
19 | parser.add_argument('--noise_mode', default='sym')
20 | parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
21 | parser.add_argument('--lambda_u', default=25, type=float, help='weight for unsupervised loss')
22 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
23 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
24 | parser.add_argument('--num_epochs', default=300, type=int)
25 | parser.add_argument('--r', default=0.5, type=float, help='noise ratio')
26 | parser.add_argument('--id', default='')
27 | parser.add_argument('--seed', default=123)
28 | parser.add_argument('--gpuid', default=0, type=int)
29 | parser.add_argument('--num_class', default=10, type=int)
30 | parser.add_argument('--data_path', default='./cifar-10', type=str, help='path to dataset')
31 | parser.add_argument('--dataset', default='cifar10', type=str)
32 | args = parser.parse_args()
33 |
34 | torch.cuda.set_device(args.gpuid)
35 | random.seed(args.seed)
36 | torch.manual_seed(args.seed)
37 | torch.cuda.manual_seed_all(args.seed)
38 |
39 |
40 | # Training
41 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
42 | net.train()
43 | net2.eval() #fix one network and train the other
44 |
45 | unlabeled_train_iter = iter(unlabeled_trainloader)
46 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
47 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
48 |         try:
49 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
50 |         except StopIteration:
51 |             unlabeled_train_iter = iter(unlabeled_trainloader)
52 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
53 | batch_size = inputs_x.size(0)
54 |
55 | # Transform label to one-hot
56 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
57 | w_x = w_x.view(-1,1).type(torch.FloatTensor)
58 |
59 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
60 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
61 |
62 | with torch.no_grad():
63 | # label co-guessing of unlabeled samples
64 | outputs_u11 = net(inputs_u)
65 | outputs_u12 = net(inputs_u2)
66 | outputs_u21 = net2(inputs_u)
67 | outputs_u22 = net2(inputs_u2)
68 |
69 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
70 |             ptu = pu**(1/args.T) # temperature sharpening
71 |
72 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
73 | targets_u = targets_u.detach()
74 |
75 | # label refinement of labeled samples
76 | outputs_x = net(inputs_x)
77 | outputs_x2 = net(inputs_x2)
78 |
79 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
80 | px = w_x*labels_x + (1-w_x)*px
81 |             ptx = px**(1/args.T) # temperature sharpening
82 |
83 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
84 | targets_x = targets_x.detach()
85 |
86 | # mixmatch
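    |         # sample the mixing coefficient l ~ Beta(alpha, alpha); taking max(l, 1-l)
    |         # keeps each mixed sample closer to its first component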
87 | l = np.random.beta(args.alpha, args.alpha)
88 | l = max(l, 1-l)
89 |
90 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
91 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
92 |
93 | idx = torch.randperm(all_inputs.size(0))
94 |
95 | input_a, input_b = all_inputs, all_inputs[idx]
96 | target_a, target_b = all_targets, all_targets[idx]
97 |
98 | mixed_input = l * input_a + (1 - l) * input_b
99 | mixed_target = l * target_a + (1 - l) * target_b
100 |
101 | logits = net(mixed_input)
102 | logits_x = logits[:batch_size*2]
103 | logits_u = logits[batch_size*2:]
104 |
105 | Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size*2], logits_u, mixed_target[batch_size*2:], epoch+batch_idx/num_iter, warm_up)
106 |
107 | # regularization
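    |         # KL(uniform prior || mean prediction): regularizes the average output
    |         # towards uniform, preventing collapse to a single class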
108 | prior = torch.ones(args.num_class)/args.num_class
109 | prior = prior.cuda()
110 | pred_mean = torch.softmax(logits, dim=1).mean(0)
111 | penalty = torch.sum(prior*torch.log(prior/pred_mean))
112 |
113 | loss = Lx + lamb * Lu + penalty
114 | # compute gradient and do SGD step
115 | optimizer.zero_grad()
116 | loss.backward()
117 | optimizer.step()
118 |
119 | sys.stdout.write('\r')
120 | sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f Unlabeled loss: %.2f'
121 | %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item(), Lu.item()))
122 | sys.stdout.flush()
123 |
124 | def warmup(epoch,net,optimizer,dataloader):
125 | net.train()
126 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
127 | for batch_idx, (inputs, labels, path) in enumerate(dataloader):
128 | inputs, labels = inputs.cuda(), labels.cuda()
129 | optimizer.zero_grad()
130 | outputs = net(inputs)
131 | loss = CEloss(outputs, labels)
132 | if args.noise_mode=='asym': # penalize confident prediction for asymmetric noise
133 | penalty = conf_penalty(outputs)
134 | L = loss + penalty
135 | elif args.noise_mode=='sym':
136 | L = loss
137 | L.backward()
138 | optimizer.step()
139 |
140 | sys.stdout.write('\r')
141 | sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f'
142 | %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
143 | sys.stdout.flush()
144 |
145 | def test(epoch,net1,net2):
146 | net1.eval()
147 | net2.eval()
148 | correct = 0
149 | total = 0
150 | with torch.no_grad():
151 | for batch_idx, (inputs, targets) in enumerate(test_loader):
152 | inputs, targets = inputs.cuda(), targets.cuda()
153 | outputs1 = net1(inputs)
154 | outputs2 = net2(inputs)
155 | outputs = outputs1+outputs2
156 | _, predicted = torch.max(outputs, 1)
157 |
158 | total += targets.size(0)
159 | correct += predicted.eq(targets).cpu().sum().item()
160 | acc = 100.*correct/total
161 | print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" %(epoch,acc))
162 | test_log.write('Epoch:%d Accuracy:%.2f\n'%(epoch,acc))
163 | test_log.flush()
164 |
165 | def eval_train(model,all_loss):
166 | model.eval()
167 |     losses = torch.zeros(50000) # one entry per CIFAR training sample
168 | with torch.no_grad():
169 | for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
170 | inputs, targets = inputs.cuda(), targets.cuda()
171 | outputs = model(inputs)
172 | loss = CE(outputs, targets)
173 | for b in range(inputs.size(0)):
174 | losses[index[b]]=loss[b]
175 | losses = (losses-losses.min())/(losses.max()-losses.min())
176 | all_loss.append(losses)
177 |
178 | if args.r==0.9: # average loss over last 5 epochs to improve convergence stability
179 | history = torch.stack(all_loss)
180 | input_loss = history[-5:].mean(0)
181 | input_loss = input_loss.reshape(-1,1)
182 | else:
183 | input_loss = losses.reshape(-1,1)
184 |
185 | # fit a two-component GMM to the loss
186 | gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
187 | gmm.fit(input_loss)
188 | prob = gmm.predict_proba(input_loss)
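    |     # clean probability = posterior of the mixture component with the smaller mean loss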
189 | prob = prob[:,gmm.means_.argmin()]
190 | return prob,all_loss
191 |
192 | def linear_rampup(current, warm_up, rampup_length=16):
193 | current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
194 | return args.lambda_u*float(current)
195 |
196 | class SemiLoss(object):
197 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
198 | probs_u = torch.softmax(outputs_u, dim=1)
199 |
200 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
201 | Lu = torch.mean((probs_u - targets_u)**2)
202 |
203 | return Lx, Lu, linear_rampup(epoch,warm_up)
204 |
205 | class NegEntropy(object):
206 | def __call__(self,outputs):
207 | probs = torch.softmax(outputs, dim=1)
208 | return torch.mean(torch.sum(probs.log()*probs, dim=1))
209 |
210 | def create_model():
211 | model = ResNet18(num_classes=args.num_class)
212 | model = model.cuda()
213 | return model
214 |
215 | stats_log=open('./checkpoint/%s_%.1f_%s'%(args.dataset,args.r,args.noise_mode)+'_stats.txt','w')
216 | test_log=open('./checkpoint/%s_%.1f_%s'%(args.dataset,args.r,args.noise_mode)+'_acc.txt','w')
217 |
218 | if args.dataset=='cifar10':
219 | warm_up = 10
220 | elif args.dataset=='cifar100':
221 | warm_up = 30
222 |
223 | loader = dataloader.cifar_dataloader(args.dataset,r=args.r,noise_mode=args.noise_mode,batch_size=args.batch_size,num_workers=5,\
224 | root_dir=args.data_path,log=stats_log,noise_file='%s/%.1f_%s.json'%(args.data_path,args.r,args.noise_mode))
225 |
226 | print('| Building net')
227 | net1 = create_model()
228 | net2 = create_model()
229 | cudnn.benchmark = True
230 |
231 | criterion = SemiLoss()
232 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
233 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
234 |
235 | CE = nn.CrossEntropyLoss(reduction='none')
236 | CEloss = nn.CrossEntropyLoss()
237 | if args.noise_mode=='asym':
238 | conf_penalty = NegEntropy()
239 |
240 | all_loss = [[],[]] # save the history of losses from two networks
241 |
242 | for epoch in range(args.num_epochs+1):
243 | lr=args.lr
244 | if epoch >= 150:
245 | lr /= 10
246 | for param_group in optimizer1.param_groups:
247 | param_group['lr'] = lr
248 | for param_group in optimizer2.param_groups:
249 | param_group['lr'] = lr
250 | test_loader = loader.run('test')
251 | eval_loader = loader.run('eval_train')
252 |
253 |     if epoch<warm_up:
254 |         warmup_trainloader = loader.run('warmup')
255 |         print('Warmup Net1')
256 |         warmup(epoch,net1,optimizer1,warmup_trainloader)
257 |         print('\nWarmup Net2')
258 |         warmup(epoch,net2,optimizer2,warmup_trainloader)
259 | 
260 |     else:
261 |         prob1,all_loss[0]=eval_train(net1,all_loss[0])
262 |         prob2,all_loss[1]=eval_train(net2,all_loss[1])
263 | 
264 |         pred1 = (prob1 > args.p_threshold) # divide dataset
265 | pred2 = (prob2 > args.p_threshold)
266 |
267 | print('Train Net1')
268 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide
269 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1
270 |
271 | print('\nTrain Net2')
272 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide
273 | train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2
274 |
275 | test(epoch,net1,net2)
276 |
277 |
278 |
--------------------------------------------------------------------------------
/Train_clothing1M.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import torchvision
9 | import torchvision.models as models
10 | import random
11 | import os
12 | import argparse
13 | import numpy as np
14 | import dataloader_clothing1M as dataloader
15 | from sklearn.mixture import GaussianMixture
16 |
17 | parser = argparse.ArgumentParser(description='PyTorch Clothing1M Training')
18 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize')
19 | parser.add_argument('--lr', '--learning_rate', default=0.002, type=float, help='initial learning rate')
20 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
21 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
22 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
23 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
24 | parser.add_argument('--num_epochs', default=80, type=int)
25 | parser.add_argument('--id', default='clothing1m')
26 | parser.add_argument('--data_path', default='../../Clothing1M/data', type=str, help='path to dataset')
27 | parser.add_argument('--seed', default=123)
28 | parser.add_argument('--gpuid', default=0, type=int)
29 | parser.add_argument('--num_class', default=14, type=int)
30 | parser.add_argument('--num_batches', default=1000, type=int)
31 | args = parser.parse_args()
32 |
33 | torch.cuda.set_device(args.gpuid)
34 | random.seed(args.seed)
35 | torch.manual_seed(args.seed)
36 | torch.cuda.manual_seed_all(args.seed)
37 |
38 | # Training
39 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
40 | net.train()
41 | net2.eval() #fix one network and train the other
42 |
43 | unlabeled_train_iter = iter(unlabeled_trainloader)
44 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
45 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
46 |         try:
47 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
48 |         except StopIteration:
49 |             unlabeled_train_iter = iter(unlabeled_trainloader)
50 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
51 | batch_size = inputs_x.size(0)
52 |
53 | # Transform label to one-hot
54 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
55 | w_x = w_x.view(-1,1).type(torch.FloatTensor)
56 |
57 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
58 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
59 |
60 | with torch.no_grad():
61 | # label co-guessing of unlabeled samples
62 | outputs_u11 = net(inputs_u)
63 | outputs_u12 = net(inputs_u2)
64 | outputs_u21 = net2(inputs_u)
65 | outputs_u22 = net2(inputs_u2)
66 |
67 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
68 |             ptu = pu**(1/args.T) # temperature sharpening
69 |
70 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
71 | targets_u = targets_u.detach()
72 |
73 | # label refinement of labeled samples
74 | outputs_x = net(inputs_x)
75 | outputs_x2 = net(inputs_x2)
76 |
77 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
78 | px = w_x*labels_x + (1-w_x)*px
79 |             ptx = px**(1/args.T) # temperature sharpening
80 |
81 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
82 | targets_x = targets_x.detach()
83 |
84 | # mixmatch
85 | l = np.random.beta(args.alpha, args.alpha)
86 | l = max(l, 1-l)
87 |
88 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
89 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
90 |
91 | idx = torch.randperm(all_inputs.size(0))
92 |
93 | input_a, input_b = all_inputs, all_inputs[idx]
94 | target_a, target_b = all_targets, all_targets[idx]
95 |
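    |         # only the labeled views are mixed and trained on here; the unsupervised
    |         # loss is disabled for Clothing1M (lambda_u defaults to 0), so no Lu term follows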
96 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]
97 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
98 |
99 | logits = net(mixed_input)
100 |
101 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
102 |
103 | # regularization
104 | prior = torch.ones(args.num_class)/args.num_class
105 | prior = prior.cuda()
106 | pred_mean = torch.softmax(logits, dim=1).mean(0)
107 | penalty = torch.sum(prior*torch.log(prior/pred_mean))
108 |
109 | loss = Lx + penalty
110 |
111 | # compute gradient and do SGD step
112 | optimizer.zero_grad()
113 | loss.backward()
114 | optimizer.step()
115 |
116 | sys.stdout.write('\r')
117 | sys.stdout.write('Clothing1M | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.4f '
118 | %(epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
119 | sys.stdout.flush()
120 |
121 | def warmup(net,optimizer,dataloader):
122 | net.train()
123 | for batch_idx, (inputs, labels, path) in enumerate(dataloader):
124 | inputs, labels = inputs.cuda(), labels.cuda()
125 | optimizer.zero_grad()
126 | outputs = net(inputs)
127 | loss = CEloss(outputs, labels)
128 |
129 | penalty = conf_penalty(outputs)
130 | L = loss + penalty
131 | L.backward()
132 | optimizer.step()
133 |
134 | sys.stdout.write('\r')
135 | sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f Conf-Penalty: %.4f'
136 | %(batch_idx+1, args.num_batches, loss.item(), penalty.item()))
137 | sys.stdout.flush()
138 |
139 | def val(net,val_loader,k):
140 | net.eval()
141 | correct = 0
142 | total = 0
143 | with torch.no_grad():
144 | for batch_idx, (inputs, targets) in enumerate(val_loader):
145 | inputs, targets = inputs.cuda(), targets.cuda()
146 | outputs = net(inputs)
147 | _, predicted = torch.max(outputs, 1)
148 |
149 | total += targets.size(0)
150 | correct += predicted.eq(targets).cpu().sum().item()
151 | acc = 100.*correct/total
152 | print("\n| Validation\t Net%d Acc: %.2f%%" %(k,acc))
153 | if acc > best_acc[k-1]:
154 | best_acc[k-1] = acc
155 | print('| Saving Best Net%d ...'%k)
156 | save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
157 | torch.save(net.state_dict(), save_point)
158 | return acc
159 |
160 | def test(net1,net2,test_loader):
161 | net1.eval()
162 | net2.eval()
163 | correct = 0
164 | total = 0
165 | with torch.no_grad():
166 | for batch_idx, (inputs, targets) in enumerate(test_loader):
167 | inputs, targets = inputs.cuda(), targets.cuda()
168 | outputs1 = net1(inputs)
169 | outputs2 = net2(inputs)
170 | outputs = outputs1+outputs2
171 | _, predicted = torch.max(outputs, 1)
172 |
173 | total += targets.size(0)
174 | correct += predicted.eq(targets).cpu().sum().item()
175 | acc = 100.*correct/total
176 | print("\n| Test Acc: %.2f%%\n" %(acc))
177 | return acc
178 |
179 | def eval_train(epoch,model):
180 | model.eval()
181 | num_samples = args.num_batches*args.batch_size
182 | losses = torch.zeros(num_samples)
183 | paths = []
184 | n=0
185 | with torch.no_grad():
186 | for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
187 | inputs, targets = inputs.cuda(), targets.cuda()
188 | outputs = model(inputs)
189 | loss = CE(outputs, targets)
190 | for b in range(inputs.size(0)):
191 | losses[n]=loss[b]
192 | paths.append(path[b])
193 | n+=1
194 | sys.stdout.write('\r')
195 | sys.stdout.write('| Evaluating loss Iter %3d\t' %(batch_idx))
196 | sys.stdout.flush()
197 |
198 | losses = (losses-losses.min())/(losses.max()-losses.min())
199 | losses = losses.reshape(-1,1)
200 | gmm = GaussianMixture(n_components=2,max_iter=10,reg_covar=5e-4,tol=1e-2)
201 | gmm.fit(losses)
202 | prob = gmm.predict_proba(losses)
203 | prob = prob[:,gmm.means_.argmin()]
204 | return prob,paths
205 |
206 | class NegEntropy(object):
207 | def __call__(self,outputs):
208 | probs = torch.softmax(outputs, dim=1)
209 | return torch.mean(torch.sum(probs.log()*probs, dim=1))
210 |
211 | def create_model():
212 | model = models.resnet50(pretrained=True)
213 | model.fc = nn.Linear(2048,args.num_class)
214 | model = model.cuda()
215 | return model
216 |
217 | log=open('./checkpoint/%s.txt'%args.id,'w')
218 | log.flush()
219 |
220 | loader = dataloader.clothing_dataloader(root=args.data_path,batch_size=args.batch_size,num_workers=5,num_batches=args.num_batches)
221 |
222 | print('| Building net')
223 | net1 = create_model()
224 | net2 = create_model()
225 | cudnn.benchmark = True
226 |
227 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
228 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3)
229 |
230 | CE = nn.CrossEntropyLoss(reduction='none')
231 | CEloss = nn.CrossEntropyLoss()
232 | conf_penalty = NegEntropy()
233 |
234 | best_acc = [0,0]
235 | for epoch in range(args.num_epochs+1):
236 | lr=args.lr
237 | if epoch >= 40:
238 | lr /= 10
239 | for param_group in optimizer1.param_groups:
240 | param_group['lr'] = lr
241 | for param_group in optimizer2.param_groups:
242 | param_group['lr'] = lr
243 |
244 | if epoch<1: # warm up
245 | train_loader = loader.run('warmup')
246 | print('Warmup Net1')
247 | warmup(net1,optimizer1,train_loader)
248 | train_loader = loader.run('warmup')
249 | print('\nWarmup Net2')
250 | warmup(net2,optimizer2,train_loader)
251 | else:
252 | pred1 = (prob1 > args.p_threshold) # divide dataset
253 | pred2 = (prob2 > args.p_threshold)
254 |
255 | print('\n\nTrain Net1')
256 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2,paths=paths2) # co-divide
257 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1
258 | print('\nTrain Net2')
259 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1,paths=paths1) # co-divide
260 | train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2
261 |
262 | val_loader = loader.run('val') # validation
263 | acc1 = val(net1,val_loader,1)
264 | acc2 = val(net2,val_loader,2)
265 | log.write('Validation Epoch:%d Acc1:%.2f Acc2:%.2f\n'%(epoch,acc1,acc2))
266 | log.flush()
267 | print('\n==== net 1 evaluate next epoch training data loss ====')
268 | eval_loader = loader.run('eval_train') # evaluate training data loss for next epoch
269 | prob1,paths1 = eval_train(epoch,net1)
270 | print('\n==== net 2 evaluate next epoch training data loss ====')
271 | eval_loader = loader.run('eval_train')
272 | prob2,paths2 = eval_train(epoch,net2)
273 |
274 | test_loader = loader.run('test')
275 | net1.load_state_dict(torch.load('./checkpoint/%s_net1.pth.tar'%args.id))
276 | net2.load_state_dict(torch.load('./checkpoint/%s_net2.pth.tar'%args.id))
277 | acc = test(net1,net2,test_loader)
278 |
279 | log.write('Test Accuracy:%.2f\n'%(acc))
280 | log.flush()
281 |
--------------------------------------------------------------------------------
/Train_webvision.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import random
9 | import os
10 | import sys
11 | import argparse
12 | import numpy as np
13 | from InceptionResNetV2 import *
14 | from sklearn.mixture import GaussianMixture
15 | import dataloader_webvision as dataloader
16 | import torchnet
17 |
18 | parser = argparse.ArgumentParser(description='PyTorch WebVision Training')
19 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize')
20 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
21 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
22 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
23 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
24 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
25 | parser.add_argument('--num_epochs', default=80, type=int)
26 | parser.add_argument('--id', default='',type=str)
27 | parser.add_argument('--seed', default=123)
28 | parser.add_argument('--gpuid', default=0, type=int)
29 | parser.add_argument('--num_class', default=50, type=int)
30 | parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
31 |
32 | args = parser.parse_args()
33 |
34 | torch.cuda.set_device(args.gpuid)
35 | random.seed(args.seed)
36 | torch.manual_seed(args.seed)
37 | torch.cuda.manual_seed_all(args.seed)
38 |
39 |
40 | # Training
41 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
42 | net.train()
43 | net2.eval() #fix one network and train the other
44 |
45 | unlabeled_train_iter = iter(unlabeled_trainloader)
46 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
47 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
48 |         try:
49 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
50 |         except StopIteration:
51 |             unlabeled_train_iter = iter(unlabeled_trainloader)
52 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
53 | batch_size = inputs_x.size(0)
54 |
55 | # Transform label to one-hot
56 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
57 | w_x = w_x.view(-1,1).type(torch.FloatTensor)
58 |
59 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
60 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
61 |
62 | with torch.no_grad():
63 | # label co-guessing of unlabeled samples
64 | outputs_u11 = net(inputs_u)
65 | outputs_u12 = net(inputs_u2)
66 | outputs_u21 = net2(inputs_u)
67 | outputs_u22 = net2(inputs_u2)
68 |
69 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
70 |             ptu = pu**(1/args.T) # temperature sharpening
71 |
72 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
73 | targets_u = targets_u.detach()
74 |
75 | # label refinement of labeled samples
76 | outputs_x = net(inputs_x)
77 | outputs_x2 = net(inputs_x2)
78 |
79 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
80 | px = w_x*labels_x + (1-w_x)*px
81 |             ptx = px**(1/args.T) # temperature sharpening
82 |
83 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
84 | targets_x = targets_x.detach()
85 |
86 | # mixmatch
87 | l = np.random.beta(args.alpha, args.alpha)
88 | l = max(l, 1-l)
89 |
90 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
91 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
92 |
93 | idx = torch.randperm(all_inputs.size(0))
94 |
95 | input_a, input_b = all_inputs, all_inputs[idx]
96 | target_a, target_b = all_targets, all_targets[idx]
97 |
98 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]
99 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
100 |
101 | logits = net(mixed_input)
102 |
103 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
104 |
105 | prior = torch.ones(args.num_class)/args.num_class
106 | prior = prior.cuda()
107 | pred_mean = torch.softmax(logits, dim=1).mean(0)
108 | penalty = torch.sum(prior*torch.log(prior/pred_mean))
109 |
110 | loss = Lx + penalty
111 | # compute gradient and do SGD step
112 | optimizer.zero_grad()
113 | loss.backward()
114 | optimizer.step()
115 |
116 | sys.stdout.write('\r')
117 | sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f'
118 | %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
119 | sys.stdout.flush()
120 |
121 | def warmup(epoch,net,optimizer,dataloader):
122 | net.train()
123 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
124 | for batch_idx, (inputs, labels, path) in enumerate(dataloader):
125 | inputs, labels = inputs.cuda(), labels.cuda()
126 | optimizer.zero_grad()
127 | outputs = net(inputs)
128 | loss = CEloss(outputs, labels)
129 |
130 | #penalty = conf_penalty(outputs)
131 | L = loss #+ penalty
132 |
133 | L.backward()
134 | optimizer.step()
135 |
136 | sys.stdout.write('\r')
137 | sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f'
138 | %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
139 | sys.stdout.flush()
140 |
141 |
142 | def test(epoch,net1,net2,test_loader):
143 | acc_meter.reset()
144 | net1.eval()
145 | net2.eval()
146 | correct = 0
147 | total = 0
148 | with torch.no_grad():
149 | for batch_idx, (inputs, targets) in enumerate(test_loader):
150 | inputs, targets = inputs.cuda(), targets.cuda()
151 | outputs1 = net1(inputs)
152 | outputs2 = net2(inputs)
153 | outputs = outputs1+outputs2
154 | _, predicted = torch.max(outputs, 1)
155 | acc_meter.add(outputs,targets)
156 | accs = acc_meter.value()
157 | return accs
158 |
159 |
160 | def eval_train(model,all_loss):
161 | model.eval()
162 | num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
163 | losses = torch.zeros(len(eval_loader.dataset))
164 | with torch.no_grad():
165 | for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
166 | inputs, targets = inputs.cuda(), targets.cuda()
167 | outputs = model(inputs)
168 | loss = CE(outputs, targets)
169 | for b in range(inputs.size(0)):
170 | losses[index[b]]=loss[b]
171 | sys.stdout.write('\r')
172 | sys.stdout.write('| Evaluating loss Iter[%3d/%3d]\t' %(batch_idx,num_iter))
173 | sys.stdout.flush()
174 |
175 | losses = (losses-losses.min())/(losses.max()-losses.min())
176 | all_loss.append(losses)
177 |
178 | # fit a two-component GMM to the loss
179 | input_loss = losses.reshape(-1,1)
180 | gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
181 | gmm.fit(input_loss)
182 | prob = gmm.predict_proba(input_loss)
183 | prob = prob[:,gmm.means_.argmin()]
184 | return prob,all_loss
185 |
186 | def linear_rampup(current, warm_up, rampup_length=16):
187 | current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
188 | return args.lambda_u*float(current)
189 |
190 | class SemiLoss(object):
191 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
192 | probs_u = torch.softmax(outputs_u, dim=1)
193 |
194 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
195 | Lu = torch.mean((probs_u - targets_u)**2)
196 |
197 | return Lx, Lu, linear_rampup(epoch,warm_up)
198 |
199 | class NegEntropy(object):
200 | def __call__(self,outputs):
201 | probs = torch.softmax(outputs, dim=1)
202 | return torch.mean(torch.sum(probs.log()*probs, dim=1))
203 |
204 | def create_model():
205 | model = InceptionResNetV2(num_classes=args.num_class)
206 | model = model.cuda()
207 | return model
208 |
209 | stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w')
210 | test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')
211 |
212 | warm_up=1
213 |
214 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=5,root_dir=args.data_path,log=stats_log, num_class=args.num_class)
215 |
216 | print('| Building net')
217 | net1 = create_model()
218 | net2 = create_model()
219 | cudnn.benchmark = True
220 |
221 | criterion = SemiLoss()
222 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
223 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
224 |
225 | CE = nn.CrossEntropyLoss(reduction='none')
226 | CEloss = nn.CrossEntropyLoss()
227 | conf_penalty = NegEntropy()
228 |
229 | all_loss = [[],[]] # save the history of losses from two networks
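    | # meter tracking top-1 and top-5 accuracy on the WebVision and ImageNet validation sets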
230 | acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
231 |
232 | for epoch in range(args.num_epochs+1):
233 | lr=args.lr
234 | if epoch >= 40:
235 | lr /= 10
236 | for param_group in optimizer1.param_groups:
237 | param_group['lr'] = lr
238 | for param_group in optimizer2.param_groups:
239 | param_group['lr'] = lr
240 | eval_loader = loader.run('eval_train')
241 | web_valloader = loader.run('test')
242 | imagenet_valloader = loader.run('imagenet')
243 |
244 |     if epoch<warm_up:
245 |         warmup_trainloader = loader.run('warmup')
246 |         print('Warmup Net1')
247 |         warmup(epoch,net1,optimizer1,warmup_trainloader)
248 |         print('\nWarmup Net2')
249 |         warmup(epoch,net2,optimizer2,warmup_trainloader)
250 | 
251 |     else:
252 |         pred1 = (prob1 > args.p_threshold) # divide dataset
253 | pred2 = (prob2 > args.p_threshold)
254 |
255 | print('Train Net1')
256 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide
257 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1
258 |
259 | print('\nTrain Net2')
260 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide
261 | train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2
262 |
263 |
264 | web_acc = test(epoch,net1,net2,web_valloader)
265 | imagenet_acc = test(epoch,net1,net2,imagenet_valloader)
266 |
267 | print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
268 | test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
269 | test_log.flush()
270 |
271 | print('\n==== net 1 evaluate training data loss ====')
272 | prob1,all_loss[0]=eval_train(net1,all_loss[0])
273 | print('\n==== net 2 evaluate training data loss ====')
274 | prob2,all_loss[1]=eval_train(net2,all_loss[1])
275 | torch.save(all_loss,'./checkpoint/%s.pth.tar'%(args.id))
276 |
277 |
--------------------------------------------------------------------------------
/Train_webvision_parallel.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import random
9 | import os
10 | import sys
11 | import argparse
12 | import numpy as np
13 | from InceptionResNetV2 import *
14 | from sklearn.mixture import GaussianMixture
15 | import dataloader_webvision as dataloader
16 | import torchnet
17 | import torch.multiprocessing as mp
18 |
19 | parser = argparse.ArgumentParser(description='PyTorch WebVision Parallel Training')
20 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize')
21 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
22 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
23 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
24 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
25 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
26 | parser.add_argument('--num_epochs', default=100, type=int)
27 | parser.add_argument('--id', default='',type=str)
28 | parser.add_argument('--seed', default=123)
29 | parser.add_argument('--gpuid1', default=0, type=int)
30 | parser.add_argument('--gpuid2', default=1, type=int)
31 | parser.add_argument('--num_class', default=50, type=int)
32 | parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
33 |
34 | args = parser.parse_args()
35 |
36 | os.environ["CUDA_VISIBLE_DEVICES"] = '%s,%s'%(args.gpuid1,args.gpuid2)
37 | random.seed(args.seed)
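   | # with CUDA_VISIBLE_DEVICES set above, the two selected GPUs are addressed as cuda:0 and cuda:1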
38 | cuda1 = torch.device('cuda:0')
39 | cuda2 = torch.device('cuda:1')
40 |
41 | # Training
42 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,device,whichnet):
43 | criterion = SemiLoss()
44 |
45 | net.train()
46 | net2.eval() #fix one network and train the other
47 |
48 | unlabeled_train_iter = iter(unlabeled_trainloader)
49 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
50 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
51 |         try:
52 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
53 |         except StopIteration:
54 |             unlabeled_train_iter = iter(unlabeled_trainloader)
55 |             inputs_u, inputs_u2 = next(unlabeled_train_iter)
56 | batch_size = inputs_x.size(0)
57 |
58 | # Transform label to one-hot
59 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
60 | w_x = w_x.view(-1,1).type(torch.FloatTensor)
61 |
62 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.to(device,non_blocking=True), inputs_x2.to(device,non_blocking=True), labels_x.to(device,non_blocking=True), w_x.to(device,non_blocking=True)
63 | inputs_u, inputs_u2 = inputs_u.to(device), inputs_u2.to(device)
64 |
65 | with torch.no_grad():
66 | # label co-guessing of unlabeled samples
67 | outputs_u11 = net(inputs_u)
68 | outputs_u12 = net(inputs_u2)
69 | outputs_u21 = net2(inputs_u)
70 | outputs_u22 = net2(inputs_u2)
71 |
72 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
73 |             ptu = pu**(1/args.T) # temperature sharpening
74 |
75 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
76 | targets_u = targets_u.detach()
77 |
78 | # label refinement of labeled samples
79 | outputs_x = net(inputs_x)
80 | outputs_x2 = net(inputs_x2)
81 |
82 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
83 | px = w_x*labels_x + (1-w_x)*px
84 |             ptx = px**(1/args.T) # temperature sharpening
85 |
86 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
87 | targets_x = targets_x.detach()
88 |
89 | # mixmatch
90 | l = np.random.beta(args.alpha, args.alpha)
91 | l = max(l, 1-l)
92 |
93 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
94 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
95 |
96 | idx = torch.randperm(all_inputs.size(0))
97 |
98 | input_a, input_b = all_inputs, all_inputs[idx]
99 | target_a, target_b = all_targets, all_targets[idx]
100 |
101 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]
102 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
103 |
104 | logits = net(mixed_input)
105 |
106 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
107 |
108 | prior = torch.ones(args.num_class)/args.num_class
109 | prior = prior.to(device)
110 | pred_mean = torch.softmax(logits, dim=1).mean(0)
111 | penalty = torch.sum(prior*torch.log(prior/pred_mean)) # KL(uniform prior || mean prediction): discourages collapsing onto few classes
112 |
113 | loss = Lx + penalty
114 | # compute gradient and do SGD step
115 | optimizer.zero_grad()
116 | loss.backward()
117 | optimizer.step()
118 |
119 | sys.stdout.write('\n')
120 | sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f'
121 | %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
122 | sys.stdout.flush()
123 |
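# Editor's note: a minimal illustration of the temperature-sharpening step used above
# (T=0.5 mirrors the --T default; the numbers are made up):
#
#   >>> import torch
#   >>> p = torch.tensor([[0.6, 0.3, 0.1]])
#   >>> pt = p ** (1 / 0.5)                    # T=0.5 squares the probabilities
#   >>> pt / pt.sum(dim=1, keepdim=True)       # ~[[0.78, 0.20, 0.02]], peakier than p
#
# MixMatch then mixes inputs and these sharpened targets with l ~ Beta(alpha, alpha),
# taking l = max(l, 1-l) so each mixed sample stays dominated by its first component.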
124 | def warmup(epoch,net,optimizer,dataloader,device,whichnet):
125 | CEloss = nn.CrossEntropyLoss()
126 | acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
127 |
128 | net.train()
129 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
130 | for batch_idx, (inputs, labels, path) in enumerate(dataloader):
131 | inputs, labels = inputs.to(device), labels.to(device,non_blocking=True)
132 | optimizer.zero_grad()
133 | outputs = net(inputs)
134 | loss = CEloss(outputs, labels)
135 |
136 | #penalty = conf_penalty(outputs)
137 | L = loss #+ penalty
138 |
139 | L.backward()
140 | optimizer.step()
141 |
142 | sys.stdout.write('\n')
143 | sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f'
144 | %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
145 | sys.stdout.flush()
146 |
147 |
148 | def test(epoch,net1,net2,test_loader,device,queue):
149 | acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
150 | acc_meter.reset()
151 | net1.eval()
152 | net2.eval()
153 | with torch.no_grad():
154 | for batch_idx, (inputs, targets) in enumerate(test_loader):
155 | inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
156 | outputs1 = net1(inputs)
157 | outputs2 = net2(inputs)
158 | outputs = outputs1+outputs2 # ensemble: sum the two networks' logits
159 | _, predicted = torch.max(outputs, 1) # hard predictions (unused; acc_meter works on the raw logits)
160 | acc_meter.add(outputs,targets)
161 | accs = acc_meter.value()
162 | queue.put(accs)
163 |
164 |
165 | def eval_train(eval_loader,model,device,whichnet,queue):
166 | CE = nn.CrossEntropyLoss(reduction='none')
167 | model.eval()
168 | num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
169 | losses = torch.zeros(len(eval_loader.dataset))
170 | with torch.no_grad():
171 | for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
172 | inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
173 | outputs = model(inputs)
174 | loss = CE(outputs, targets)
175 | for b in range(inputs.size(0)):
176 | losses[index[b]]=loss[b]
177 | sys.stdout.write('\n')
178 | sys.stdout.write('|%s Evaluating loss Iter[%3d/%3d]\t' %(whichnet,batch_idx+1,num_iter))
179 | sys.stdout.flush()
180 |
181 | losses = (losses-losses.min())/(losses.max()-losses.min())
182 |
183 | # fit a two-component GMM to the loss
184 | input_loss = losses.reshape(-1,1)
185 | gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=1e-3)
186 | gmm.fit(input_loss)
187 | prob = gmm.predict_proba(input_loss)
188 | prob = prob[:,gmm.means_.argmin()] # posterior of the component with the smaller mean loss (the clean mode)
189 | queue.put(prob)
190 |
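# Editor's note: a standalone sketch of the clean-probability computation above,
# assuming scikit-learn; the loss values are made-up stand-ins for the normalized
# per-sample cross-entropy losses:
#
#   >>> import numpy as np
#   >>> from sklearn.mixture import GaussianMixture
#   >>> losses = np.array([0.05, 0.08, 0.90, 0.85, 0.10]).reshape(-1, 1)
#   >>> gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=1e-3).fit(losses)
#   >>> prob_clean = gmm.predict_proba(losses)[:, gmm.means_.argmin()]
#
# The mixture component with the smaller mean loss is taken as the clean mode, so
# prob_clean is near 1 for low-loss samples; thresholding it with --p_threshold gives
# the co-divide split used in the main loop.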
191 | def linear_rampup(current, warm_up, rampup_length=16):
192 | current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
193 | return args.lambda_u*float(current)
194 |
195 | class SemiLoss(object):
196 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
197 | probs_u = torch.softmax(outputs_u, dim=1)
198 |
199 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
200 | Lu = torch.mean((probs_u - targets_u)**2)
201 |
202 | return Lx, Lu, linear_rampup(epoch,warm_up)
203 |
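# Editor's note: SemiLoss returns the supervised term, the unsupervised term, and a
# ramped weight; a typical caller combines them as follows (hypothetical usage -- the
# train() in this file computes its labeled-only loss directly):
#
#   >>> criterion = SemiLoss()
#   >>> Lx, Lu, lamb = criterion(logits_x, targets_x, logits_u, targets_u, epoch, warm_up)
#   >>> loss = Lx + lamb * Lu
#
# linear_rampup() grows the weight from 0 to args.lambda_u over 16 epochs after
# warm-up, so the unsupervised MSE term is phased in gradually.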
204 | class NegEntropy(object):
205 | def __call__(self,outputs):
206 | probs = torch.softmax(outputs, dim=1)
207 | return torch.mean(torch.sum(probs.log()*probs, dim=1))
208 |
209 | def create_model(device):
210 | model = InceptionResNetV2(num_classes=args.num_class)
211 | model = model.to(device)
212 | return model
213 |
214 | if __name__ == "__main__":
215 |
216 | mp.set_start_method('spawn')
217 | torch.manual_seed(args.seed)
218 | torch.cuda.manual_seed_all(args.seed)
219 |
220 | stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w')
221 | test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')
222 |
223 | warm_up=1
224 |
225 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_class = args.num_class,num_workers=8,root_dir=args.data_path,log=stats_log)
226 |
227 | print('| Building net')
228 |
229 | net1 = create_model(cuda1)
230 | net2 = create_model(cuda2)
231 |
232 | net1_clone = create_model(cuda2)
233 | net2_clone = create_model(cuda1)
234 |
235 | cudnn.benchmark = True
236 |
237 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
238 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
239 |
240 | #conf_penalty = NegEntropy()
241 | web_valloader = loader.run('test')
242 | imagenet_valloader = loader.run('imagenet')
243 |
244 | for epoch in range(args.num_epochs+1):
245 | lr=args.lr
246 | if epoch >= 50:
247 | lr /= 10
248 | for param_group in optimizer1.param_groups:
249 | param_group['lr'] = lr
250 | for param_group in optimizer2.param_groups:
251 | param_group['lr'] = lr
252 |
253 | if epoch<warm_up: # warm up both networks, one process per GPU
254 | train_loader1 = loader.run('warmup')
255 | train_loader2 = loader.run('warmup')
256 |
257 | p1 = mp.Process(target=warmup, args=(epoch,net1,optimizer1,train_loader1,cuda1,'net1'))
258 | p2 = mp.Process(target=warmup, args=(epoch,net2,optimizer2,train_loader2,cuda2,'net2'))
259 | p1.start()
260 | p2.start()
261 | else:
262 | pred1 = (prob1 > args.p_threshold)
263 | pred2 = (prob2 > args.p_threshold)
264 |
265 | labeled_trainloader1, unlabeled_trainloader1 = loader.run('train',pred2,prob2) # co-divide
266 | labeled_trainloader2, unlabeled_trainloader2 = loader.run('train',pred1,prob1) # co-divide
267 |
268 | p1 = mp.Process(target=train, args=(epoch,net1,net2_clone,optimizer1,labeled_trainloader1, unlabeled_trainloader1,cuda1,'net1'))
269 | p2 = mp.Process(target=train, args=(epoch,net2,net1_clone,optimizer2,labeled_trainloader2, unlabeled_trainloader2,cuda2,'net2'))
270 | p1.start()
271 | p2.start()
272 |
273 | p1.join()
274 | p2.join()
275 |
276 | net1_clone.load_state_dict(net1.state_dict())
277 | net2_clone.load_state_dict(net2.state_dict())
278 |
279 | q1 = mp.Queue()
280 | q2 = mp.Queue()
281 | p1 = mp.Process(target=test, args=(epoch,net1,net2_clone,web_valloader,cuda1,q1))
282 | p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,imagenet_valloader,cuda2,q2))
283 |
284 | p1.start()
285 | p2.start()
286 |
287 | web_acc = q1.get()
288 | imagenet_acc = q2.get()
289 |
290 | p1.join()
291 | p2.join()
292 |
293 | print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
294 | test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
295 | test_log.flush()
296 |
297 | eval_loader1 = loader.run('eval_train')
298 | eval_loader2 = loader.run('eval_train')
299 | q1 = mp.Queue()
300 | q2 = mp.Queue()
301 | p1 = mp.Process(target=eval_train, args=(eval_loader1,net1,cuda1,'net1',q1))
302 | p2 = mp.Process(target=eval_train, args=(eval_loader2,net2,cuda2,'net2',q2))
303 |
304 | p1.start()
305 | p2.start()
306 |
307 | prob1 = q1.get()
308 | prob2 = q2.get()
309 |
310 | p1.join()
311 | p2.join()
312 |
--------------------------------------------------------------------------------
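Editor's note: the script above parallelizes DivideMix's two networks by giving each its own GPU, process, and result queue; a clone of each network is kept on the opposite device so every process can run the other network in eval mode without cross-device traffic. Below is a minimal, self-contained sketch of that process/queue pattern -- the worker is a toy stand-in for train()/test(), not code from this repo:

import torch.multiprocessing as mp

def worker(x, queue):
    # stand-in for train()/test(): do some work, then report the result
    queue.put(x * x)

if __name__ == '__main__':
    mp.set_start_method('spawn')       # required when subprocesses touch CUDA
    q1, q2 = mp.Queue(), mp.Queue()
    p1 = mp.Process(target=worker, args=(2, q1))
    p2 = mp.Process(target=worker, args=(3, q2))
    p1.start(); p2.start()
    r1, r2 = q1.get(), q2.get()        # drain queues before join() to avoid deadlock
    p1.join(); p2.join()
    print(r1, r2)                      # -> 4 9

As in the main loop above, results are read with get() before join(): joining a process that still holds undelivered queue data can deadlock.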
/dataloader_cifar.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torchvision.transforms as transforms
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 | import json
7 | import os
8 | import torch
9 | from torchnet.meter import AUCMeter
10 |
11 |
12 | def unpickle(file):
13 | import _pickle as cPickle
14 | with open(file, 'rb') as fo:
15 | data = cPickle.load(fo, encoding='latin1')
16 | return data
17 |
18 | class cifar_dataset(Dataset):
19 | def __init__(self, dataset, r, noise_mode, root_dir, transform, mode, noise_file='', pred=[], probability=[], log=''):
20 |
21 | self.r = r # noise ratio
22 | self.transform = transform
23 | self.mode = mode
24 | self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise
25 |
26 | if self.mode=='test':
27 | if dataset=='cifar10':
28 | test_dic = unpickle('%s/test_batch'%root_dir)
29 | self.test_data = test_dic['data']
30 | self.test_data = self.test_data.reshape((10000, 3, 32, 32))
31 | self.test_data = self.test_data.transpose((0, 2, 3, 1))
32 | self.test_label = test_dic['labels']
33 | elif dataset=='cifar100':
34 | test_dic = unpickle('%s/test'%root_dir)
35 | self.test_data = test_dic['data']
36 | self.test_data = self.test_data.reshape((10000, 3, 32, 32))
37 | self.test_data = self.test_data.transpose((0, 2, 3, 1))
38 | self.test_label = test_dic['fine_labels']
39 | else:
40 | train_data=[]
41 | train_label=[]
42 | if dataset=='cifar10':
43 | for n in range(1,6):
44 | dpath = '%s/data_batch_%d'%(root_dir,n)
45 | data_dic = unpickle(dpath)
46 | train_data.append(data_dic['data'])
47 | train_label = train_label+data_dic['labels']
48 | train_data = np.concatenate(train_data)
49 | elif dataset=='cifar100':
50 | train_dic = unpickle('%s/train'%root_dir)
51 | train_data = train_dic['data']
52 | train_label = train_dic['fine_labels']
53 | train_data = train_data.reshape((50000, 3, 32, 32))
54 | train_data = train_data.transpose((0, 2, 3, 1))
55 |
56 | if os.path.exists(noise_file):
57 | noise_label = json.load(open(noise_file,"r"))
58 | else: #inject noise
59 | noise_label = []
60 | idx = list(range(50000))
61 | random.shuffle(idx)
62 | num_noise = int(self.r*50000)
63 | noise_idx = idx[:num_noise]
64 | for i in range(50000):
65 | if i in noise_idx:
66 | if noise_mode=='sym':
67 | if dataset=='cifar10':
68 | noiselabel = random.randint(0,9)
69 | elif dataset=='cifar100':
70 | noiselabel = random.randint(0,99)
71 | noise_label.append(noiselabel)
72 | elif noise_mode=='asym':
73 | noiselabel = self.transition[train_label[i]]
74 | noise_label.append(noiselabel)
75 | else:
76 | noise_label.append(train_label[i])
77 | print("save noisy labels to %s ..."%noise_file)
78 | json.dump(noise_label,open(noise_file,"w"))
79 |
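# Editor's note: with symmetric noise the replacement label is drawn uniformly over
# all classes, including the true one, so the realized corruption rate is r*(C-1)/C
# (e.g. r=0.5 on CIFAR-10 actually flips about 45% of the labels, not 50%).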
80 | if self.mode == 'all':
81 | self.train_data = train_data
82 | self.noise_label = noise_label
83 | else:
84 | if self.mode == "labeled":
85 | pred_idx = pred.nonzero()[0]
86 | self.probability = [probability[i] for i in pred_idx]
87 |
88 | clean = (np.array(noise_label)==np.array(train_label))
89 | auc_meter = AUCMeter()
90 | auc_meter.reset()
91 | auc_meter.add(probability,clean)
92 | auc,_,_ = auc_meter.value()
93 | log.write('Number of labeled samples:%d AUC:%.3f\n'%(pred.sum(),auc))
94 | log.flush()
95 |
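# Editor's note: `clean` flags samples whose noisy label equals the true label, so the
# AUC above measures how well the GMM clean-probability separates clean from mislabeled
# samples; it is logged as a diagnostic only and does not influence training.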
96 | elif self.mode == "unlabeled":
97 | pred_idx = (1-pred).nonzero()[0]
98 |
99 | self.train_data = train_data[pred_idx]
100 | self.noise_label = [noise_label[i] for i in pred_idx]
101 | print("%s data has a size of %d"%(self.mode,len(self.noise_label)))
102 |
103 | def __getitem__(self, index):
104 | if self.mode=='labeled':
105 | img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index]
106 | img = Image.fromarray(img)
107 | img1 = self.transform(img)
108 | img2 = self.transform(img)
109 | return img1, img2, target, prob
110 | elif self.mode=='unlabeled':
111 | img = self.train_data[index]
112 | img = Image.fromarray(img)
113 | img1 = self.transform(img)
114 | img2 = self.transform(img)
115 | return img1, img2
116 | elif self.mode=='all':
117 | img, target = self.train_data[index], self.noise_label[index]
118 | img = Image.fromarray(img)
119 | img = self.transform(img)
120 | return img, target, index
121 | elif self.mode=='test':
122 | img, target = self.test_data[index], self.test_label[index]
123 | img = Image.fromarray(img)
124 | img = self.transform(img)
125 | return img, target
126 |
127 | def __len__(self):
128 | if self.mode!='test':
129 | return len(self.train_data)
130 | else:
131 | return len(self.test_data)
132 |
133 |
134 | class cifar_dataloader():
135 | def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''):
136 | self.dataset = dataset
137 | self.r = r
138 | self.noise_mode = noise_mode
139 | self.batch_size = batch_size
140 | self.num_workers = num_workers
141 | self.root_dir = root_dir
142 | self.log = log
143 | self.noise_file = noise_file
144 | if self.dataset=='cifar10':
145 | self.transform_train = transforms.Compose([
146 | transforms.RandomCrop(32, padding=4),
147 | transforms.RandomHorizontalFlip(),
148 | transforms.ToTensor(),
149 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
150 | ])
151 | self.transform_test = transforms.Compose([
152 | transforms.ToTensor(),
153 | transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
154 | ])
155 | elif self.dataset=='cifar100':
156 | self.transform_train = transforms.Compose([
157 | transforms.RandomCrop(32, padding=4),
158 | transforms.RandomHorizontalFlip(),
159 | transforms.ToTensor(),
160 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
161 | ])
162 | self.transform_test = transforms.Compose([
163 | transforms.ToTensor(),
164 | transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
165 | ])
166 | def run(self,mode,pred=[],prob=[]):
167 | if mode=='warmup':
168 | all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="all",noise_file=self.noise_file)
169 | trainloader = DataLoader(
170 | dataset=all_dataset,
171 | batch_size=self.batch_size*2,
172 | shuffle=True,
173 | num_workers=self.num_workers)
174 | return trainloader
175 |
176 | elif mode=='train':
177 | labeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="labeled", noise_file=self.noise_file, pred=pred, probability=prob,log=self.log)
178 | labeled_trainloader = DataLoader(
179 | dataset=labeled_dataset,
180 | batch_size=self.batch_size,
181 | shuffle=True,
182 | num_workers=self.num_workers)
183 |
184 | unlabeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="unlabeled", noise_file=self.noise_file, pred=pred)
185 | unlabeled_trainloader = DataLoader(
186 | dataset=unlabeled_dataset,
187 | batch_size=self.batch_size,
188 | shuffle=True,
189 | num_workers=self.num_workers)
190 | return labeled_trainloader, unlabeled_trainloader
191 |
192 | elif mode=='test':
193 | test_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='test')
194 | test_loader = DataLoader(
195 | dataset=test_dataset,
196 | batch_size=self.batch_size,
197 | shuffle=False,
198 | num_workers=self.num_workers)
199 | return test_loader
200 |
201 | elif mode=='eval_train':
202 | eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='all', noise_file=self.noise_file)
203 | eval_loader = DataLoader(
204 | dataset=eval_dataset,
205 | batch_size=self.batch_size,
206 | shuffle=False,
207 | num_workers=self.num_workers)
208 | return eval_loader
--------------------------------------------------------------------------------
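Editor's note: a hypothetical driver showing how cifar_dataloader is typically exercised; the paths, batch size, and noise settings below are placeholders, not values taken from this repo:

loader = cifar_dataloader('cifar10', r=0.5, noise_mode='sym', batch_size=64,
                          num_workers=4, root_dir='./cifar-10', log=open('stats.txt', 'w'),
                          noise_file='./noise_sym_0.5.json')
warmup_loader = loader.run('warmup')        # all 50k samples, double batch size
eval_loader = loader.run('eval_train')      # same data, test transform, fixed order
test_loader = loader.run('test')
# after eval_train produces per-sample clean probabilities `prob`:
#   pred = (prob > 0.5)
#   labeled_loader, unlabeled_loader = loader.run('train', pred, prob)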
/dataloader_clothing1M.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torchvision.transforms as transforms
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 | import json
7 | import torch
8 |
9 | class clothing_dataset(Dataset):
10 | def __init__(self, root, transform, mode, num_samples=0, pred=[], probability=[], paths=[], num_class=14):
11 |
12 | self.root = root
13 | self.transform = transform
14 | self.mode = mode
15 | self.train_labels = {}
16 | self.test_labels = {}
17 | self.val_labels = {}
18 |
19 | with open('%s/noisy_label_kv.txt'%self.root,'r') as f:
20 | lines = f.read().splitlines()
21 | for l in lines:
22 | entry = l.split()
23 | img_path = '%s/'%self.root+entry[0][7:]
24 | self.train_labels[img_path] = int(entry[1])
25 | with open('%s/clean_label_kv.txt'%self.root,'r') as f:
26 | lines = f.read().splitlines()
27 | for l in lines:
28 | entry = l.split()
29 | img_path = '%s/'%self.root+entry[0][7:]
30 | self.test_labels[img_path] = int(entry[1])
31 |
32 | if mode == 'all':
33 | train_imgs=[]
34 | with open('%s/noisy_train_key_list.txt'%self.root,'r') as f:
35 | lines = f.read().splitlines()
36 | for l in lines:
37 | img_path = '%s/'%self.root+l[7:]
38 | train_imgs.append(img_path)
39 | random.shuffle(train_imgs)
40 | class_num = torch.zeros(num_class)
41 | self.train_imgs = []
42 | for impath in train_imgs:
43 | label = self.train_labels[impath]
44 | if class_num[label]<(num_samples/14) and len(self.train_imgs)<num_samples: