├── InceptionResNetV2.py ├── LICENSE ├── PreResNet.py ├── README.md ├── Train_cifar.py ├── Train_clothing1M.py ├── Train_webvision.py ├── Train_webvision_parallel.py ├── dataloader_cifar.py ├── dataloader_clothing1M.py ├── dataloader_webvision.py └── img └── framework.png /InceptionResNetV2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import sys 6 | 7 | 8 | class BasicConv2d(nn.Module): 9 | 10 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_planes, out_planes, 13 | kernel_size=kernel_size, stride=stride, 14 | padding=padding, bias=False) # verify bias false 15 | self.bn = nn.BatchNorm2d(out_planes, 16 | eps=0.001, # value found in tensorflow 17 | momentum=0.1, # default pytorch value 18 | affine=True) 19 | self.relu = nn.ReLU(inplace=False) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.bn(x) 24 | x = self.relu(x) 25 | return x 26 | 27 | 28 | class Mixed_5b(nn.Module): 29 | 30 | def __init__(self): 31 | super(Mixed_5b, self).__init__() 32 | 33 | self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) 34 | 35 | self.branch1 = nn.Sequential( 36 | BasicConv2d(192, 48, kernel_size=1, stride=1), 37 | BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) 38 | ) 39 | 40 | self.branch2 = nn.Sequential( 41 | BasicConv2d(192, 64, kernel_size=1, stride=1), 42 | BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), 43 | BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) 44 | ) 45 | 46 | self.branch3 = nn.Sequential( 47 | nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), 48 | BasicConv2d(192, 64, kernel_size=1, stride=1) 49 | ) 50 | 51 | def forward(self, x): 52 | x0 = self.branch0(x) 53 | x1 = self.branch1(x) 54 | x2 = self.branch2(x) 55 | x3 = self.branch3(x) 56 | out = torch.cat((x0, x1, x2, x3), 1) 57 | return out 58 | 59 | 60 | class Block35(nn.Module): 61 | 62 | def __init__(self, scale=1.0): 63 | super(Block35, self).__init__() 64 | 65 | self.scale = scale 66 | 67 | self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) 68 | 69 | self.branch1 = nn.Sequential( 70 | BasicConv2d(320, 32, kernel_size=1, stride=1), 71 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 72 | ) 73 | 74 | self.branch2 = nn.Sequential( 75 | BasicConv2d(320, 32, kernel_size=1, stride=1), 76 | BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), 77 | BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) 78 | ) 79 | 80 | self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) 81 | self.relu = nn.ReLU(inplace=False) 82 | 83 | def forward(self, x): 84 | x0 = self.branch0(x) 85 | x1 = self.branch1(x) 86 | x2 = self.branch2(x) 87 | out = torch.cat((x0, x1, x2), 1) 88 | out = self.conv2d(out) 89 | out = out * self.scale + x 90 | out = self.relu(out) 91 | return out 92 | 93 | 94 | class Mixed_6a(nn.Module): 95 | 96 | def __init__(self): 97 | super(Mixed_6a, self).__init__() 98 | 99 | self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) 100 | 101 | self.branch1 = nn.Sequential( 102 | BasicConv2d(320, 256, kernel_size=1, stride=1), 103 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 104 | BasicConv2d(256, 384, kernel_size=3, stride=2) 105 | ) 106 | 107 | self.branch2 = nn.MaxPool2d(3, stride=2) 108 | 109 | def forward(self, x): 110 | x0 = self.branch0(x) 111 | x1 = 
self.branch1(x) 112 | x2 = self.branch2(x) 113 | out = torch.cat((x0, x1, x2), 1) 114 | return out 115 | 116 | 117 | class Block17(nn.Module): 118 | 119 | def __init__(self, scale=1.0): 120 | super(Block17, self).__init__() 121 | 122 | self.scale = scale 123 | 124 | self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) 125 | 126 | self.branch1 = nn.Sequential( 127 | BasicConv2d(1088, 128, kernel_size=1, stride=1), 128 | BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)), 129 | BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)) 130 | ) 131 | 132 | self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) 133 | self.relu = nn.ReLU(inplace=False) 134 | 135 | def forward(self, x): 136 | x0 = self.branch0(x) 137 | x1 = self.branch1(x) 138 | out = torch.cat((x0, x1), 1) 139 | out = self.conv2d(out) 140 | out = out * self.scale + x 141 | out = self.relu(out) 142 | return out 143 | 144 | 145 | class Mixed_7a(nn.Module): 146 | 147 | def __init__(self): 148 | super(Mixed_7a, self).__init__() 149 | 150 | self.branch0 = nn.Sequential( 151 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 152 | BasicConv2d(256, 384, kernel_size=3, stride=2) 153 | ) 154 | 155 | self.branch1 = nn.Sequential( 156 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 157 | BasicConv2d(256, 288, kernel_size=3, stride=2) 158 | ) 159 | 160 | self.branch2 = nn.Sequential( 161 | BasicConv2d(1088, 256, kernel_size=1, stride=1), 162 | BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), 163 | BasicConv2d(288, 320, kernel_size=3, stride=2) 164 | ) 165 | 166 | self.branch3 = nn.MaxPool2d(3, stride=2) 167 | 168 | def forward(self, x): 169 | x0 = self.branch0(x) 170 | x1 = self.branch1(x) 171 | x2 = self.branch2(x) 172 | x3 = self.branch3(x) 173 | out = torch.cat((x0, x1, x2, x3), 1) 174 | return out 175 | 176 | 177 | class Block8(nn.Module): 178 | 179 | def __init__(self, scale=1.0, noReLU=False): 180 | super(Block8, self).__init__() 181 | 182 | self.scale = scale 183 | self.noReLU = noReLU 184 | 185 | self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) 186 | 187 | self.branch1 = nn.Sequential( 188 | BasicConv2d(2080, 192, kernel_size=1, stride=1), 189 | BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)), 190 | BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)) 191 | ) 192 | 193 | self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) 194 | if not self.noReLU: 195 | self.relu = nn.ReLU(inplace=False) 196 | 197 | def forward(self, x): 198 | x0 = self.branch0(x) 199 | x1 = self.branch1(x) 200 | out = torch.cat((x0, x1), 1) 201 | out = self.conv2d(out) 202 | out = out * self.scale + x 203 | if not self.noReLU: 204 | out = self.relu(out) 205 | return out 206 | 207 | 208 | class InceptionResNetV2(nn.Module): 209 | 210 | def __init__(self, num_classes=1001): 211 | super(InceptionResNetV2, self).__init__() 212 | # Special attributs 213 | self.input_space = None 214 | self.input_size = (299, 299, 3) 215 | self.mean = None 216 | self.std = None 217 | # Modules 218 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 219 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 220 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 221 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 222 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 223 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 224 | self.maxpool_5a = nn.MaxPool2d(3, stride=2) 225 | self.mixed_5b = Mixed_5b() 226 | self.repeat 
= nn.Sequential( 227 | Block35(scale=0.17), 228 | Block35(scale=0.17), 229 | Block35(scale=0.17), 230 | Block35(scale=0.17), 231 | Block35(scale=0.17), 232 | Block35(scale=0.17), 233 | Block35(scale=0.17), 234 | Block35(scale=0.17), 235 | Block35(scale=0.17), 236 | Block35(scale=0.17) 237 | ) 238 | self.mixed_6a = Mixed_6a() 239 | self.repeat_1 = nn.Sequential( 240 | Block17(scale=0.10), 241 | Block17(scale=0.10), 242 | Block17(scale=0.10), 243 | Block17(scale=0.10), 244 | Block17(scale=0.10), 245 | Block17(scale=0.10), 246 | Block17(scale=0.10), 247 | Block17(scale=0.10), 248 | Block17(scale=0.10), 249 | Block17(scale=0.10), 250 | Block17(scale=0.10), 251 | Block17(scale=0.10), 252 | Block17(scale=0.10), 253 | Block17(scale=0.10), 254 | Block17(scale=0.10), 255 | Block17(scale=0.10), 256 | Block17(scale=0.10), 257 | Block17(scale=0.10), 258 | Block17(scale=0.10), 259 | Block17(scale=0.10) 260 | ) 261 | self.mixed_7a = Mixed_7a() 262 | self.repeat_2 = nn.Sequential( 263 | Block8(scale=0.20), 264 | Block8(scale=0.20), 265 | Block8(scale=0.20), 266 | Block8(scale=0.20), 267 | Block8(scale=0.20), 268 | Block8(scale=0.20), 269 | Block8(scale=0.20), 270 | Block8(scale=0.20), 271 | Block8(scale=0.20) 272 | ) 273 | self.block8 = Block8(noReLU=True) 274 | self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) 275 | self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False) 276 | self.last_linear = nn.Linear(1536, num_classes) 277 | 278 | def features(self, input): 279 | x = self.conv2d_1a(input) 280 | x = self.conv2d_2a(x) 281 | x = self.conv2d_2b(x) 282 | x = self.maxpool_3a(x) 283 | x = self.conv2d_3b(x) 284 | x = self.conv2d_4a(x) 285 | x = self.maxpool_5a(x) 286 | x = self.mixed_5b(x) 287 | x = self.repeat(x) 288 | x = self.mixed_6a(x) 289 | x = self.repeat_1(x) 290 | x = self.mixed_7a(x) 291 | x = self.repeat_2(x) 292 | x = self.block8(x) 293 | x = self.conv2d_7b(x) 294 | return x 295 | 296 | def logits(self, features): 297 | x = self.avgpool_1a(features) 298 | x = x.view(x.size(0), -1) 299 | x = self.last_linear(x) 300 | return x 301 | 302 | def forward(self, input): 303 | x = self.features(input) 304 | x = self.logits(x) 305 | return x 306 | 307 | 308 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Junnan Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PreResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = conv3x3(in_planes, planes, stride) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class PreActBlock(nn.Module): 38 | '''Pre-activation version of the BasicBlock.''' 39 | expansion = 1 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBlock, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = conv3x3(in_planes, planes, stride) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = conv3x3(planes, planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 52 | ) 53 | 54 | def forward(self, x): 55 | out = F.relu(self.bn1(x)) 56 | shortcut = self.shortcut(out) 57 | out = self.conv1(out) 58 | out = self.conv2(F.relu(self.bn2(out))) 59 | out += shortcut 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | expansion = 4 65 | 66 | def __init__(self, in_planes, planes, stride=1): 67 | super(Bottleneck, self).__init__() 68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 74 | 75 | self.shortcut = nn.Sequential() 76 | if stride != 1 or in_planes != self.expansion*planes: 77 | self.shortcut = nn.Sequential( 78 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 79 | nn.BatchNorm2d(self.expansion*planes) 80 | ) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = F.relu(self.bn2(self.conv2(out))) 85 | out = self.bn3(self.conv3(out)) 86 | out += self.shortcut(x) 87 | out = F.relu(out) 88 | return out 89 | 90 | 91 | class PreActBottleneck(nn.Module): 92 | '''Pre-activation version of the original Bottleneck module.''' 93 | expansion = 4 94 | 95 | def __init__(self, in_planes, planes, stride=1): 96 | super(PreActBottleneck, self).__init__() 97 | self.bn1 = nn.BatchNorm2d(in_planes) 98 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 99 | self.bn2 = nn.BatchNorm2d(planes) 100 | 
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 101 | self.bn3 = nn.BatchNorm2d(planes) 102 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 103 | 104 | self.shortcut = nn.Sequential() 105 | if stride != 1 or in_planes != self.expansion*planes: 106 | self.shortcut = nn.Sequential( 107 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 108 | ) 109 | 110 | def forward(self, x): 111 | out = F.relu(self.bn1(x)) 112 | shortcut = self.shortcut(out) 113 | out = self.conv1(out) 114 | out = self.conv2(F.relu(self.bn2(out))) 115 | out = self.conv3(F.relu(self.bn3(out))) 116 | out += shortcut 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | def __init__(self, block, num_blocks, num_classes=10): 122 | super(ResNet, self).__init__() 123 | self.in_planes = 64 124 | 125 | self.conv1 = conv3x3(3,64) 126 | self.bn1 = nn.BatchNorm2d(64) 127 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 128 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 129 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 130 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 131 | self.linear = nn.Linear(512*block.expansion, num_classes) 132 | 133 | def _make_layer(self, block, planes, num_blocks, stride): 134 | strides = [stride] + [1]*(num_blocks-1) 135 | layers = [] 136 | for stride in strides: 137 | layers.append(block(self.in_planes, planes, stride)) 138 | self.in_planes = planes * block.expansion 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, lin=0, lout=5): 142 | out = x 143 | if lin < 1 and lout > -1: 144 | out = self.conv1(out) 145 | out = self.bn1(out) 146 | out = F.relu(out) 147 | if lin < 2 and lout > 0: 148 | out = self.layer1(out) 149 | if lin < 3 and lout > 1: 150 | out = self.layer2(out) 151 | if lin < 4 and lout > 2: 152 | out = self.layer3(out) 153 | if lin < 5 and lout > 3: 154 | out = self.layer4(out) 155 | if lout > 4: 156 | out = F.avg_pool2d(out, 4) 157 | out = out.view(out.size(0), -1) 158 | out = self.linear(out) 159 | return out 160 | 161 | 162 | def ResNet18(num_classes=10): 163 | return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes) 164 | 165 | def ResNet34(num_classes=10): 166 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) 167 | 168 | def ResNet50(num_classes=10): 169 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes) 170 | 171 | def ResNet101(num_classes=10): 172 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes) 173 | 174 | def ResNet152(num_classes=10): 175 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes) 176 | 177 | 178 | def test(): 179 | net = ResNet18() 180 | y = net(Variable(torch.randn(1,3,32,32))) 181 | print(y.size()) 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DivideMix: Learning with Noisy Labels as Semi-supervised Learning 2 | PyTorch Code for the following paper at ICLR2020:\ 3 | Title: DivideMix: Learning with Noisy Labels as Semi-supervised Learning [pdf]\ 4 | Authors: Junnan Li, Richard Socher, Steven C.H. Hoi\ 5 | Institute: Salesforce Research 6 | 7 | 8 | Abstract\ 9 | Deep neural networks are known to be annotation-hungry. Numerous efforts have been devoted to reducing the annotation cost when learning with deep networks.
Two prominent directions include learning with noisy labels and semi-supervised learning by exploiting unlabeled data. In this work, we propose DivideMix, a novel framework for learning with noisy labels by leveraging semi-supervised learning techniques. In particular, DivideMix models the per-sample loss distribution with a mixture model to dynamically divide the training data into a labeled set with clean samples and an unlabeled set with noisy samples, and trains the model on both the labeled and unlabeled data in a semi-supervised manner. To avoid confirmation bias, we simultaneously train two diverged networks where each network uses the dataset division from the other network. During the semi-supervised training phase, we improve the MixMatch strategy by performing label co-refinement and label co-guessing on labeled and unlabeled samples, respectively. Experiments on multiple benchmark datasets demonstrate substantial improvements over state-of-the-art methods. 10 | 11 | 12 | Illustration\ 13 | <img src="./img/framework.png"> 14 | 15 | Experiments\ 16 | First, please create a folder named checkpoint to store the results.\ 17 | mkdir checkpoint\ 18 | Next, run\ 19 | python Train_{dataset_name}.py --data_path path-to-your-data 20 | 21 | Cite DivideMix\ 22 | If you find the code useful in your research, please consider citing our paper: 23 | 24 |
25 | @inproceedings{
26 |     li2020dividemix,
27 |     title={DivideMix: Learning with Noisy Labels as Semi-supervised Learning},
28 |     author={Junnan Li and Richard Socher and Steven C.H. Hoi},
29 |     booktitle={International Conference on Learning Representations},
30 |     year={2020},
31 | }
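Method sketch\
For quick reference, the two core operations described in the abstract are compact enough to sketch directly. The co-divide step fits a two-component Gaussian mixture to the min-max normalized per-sample training losses and keeps, for each sample, the posterior probability of the low-loss (clean) component. The sketch below is illustrative only: the helper name `codivide` and the toy losses are ours, while the GMM hyperparameters mirror `eval_train` in Train_cifar.py.

```python
import torch
from sklearn.mixture import GaussianMixture

def codivide(losses, p_threshold=0.5):
    # Min-max normalize the per-sample losses to [0, 1]
    losses = (losses - losses.min()) / (losses.max() - losses.min())
    input_loss = losses.reshape(-1, 1).numpy()
    # Fit a two-component GMM to the loss (same settings as eval_train)
    gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
    gmm.fit(input_loss)
    # Posterior probability of the low-mean (clean) component
    prob = gmm.predict_proba(input_loss)[:, gmm.means_.argmin()]
    return prob, prob > p_threshold

# Toy usage: 80 "clean" low-loss samples and 20 "noisy" high-loss ones
toy_losses = torch.cat([0.3 * torch.rand(80), 0.7 + 0.3 * torch.rand(20)])
prob, clean_mask = codivide(toy_losses)
print(clean_mask.sum(), "of", len(toy_losses), "samples kept as clean")
```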
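Label co-guessing is similarly small: average the softmax predictions of both networks over two augmentations of each unlabeled batch, sharpen with temperature T, and renormalize. Again a hedged sketch; the helper name `coguess` is ours, and the arithmetic mirrors the `train()` functions in the training scripts.

```python
import torch

def coguess(logits_list, T=0.5):
    # Ensemble: average softmax over all (network, augmentation) pairs
    pu = torch.stack([torch.softmax(o, dim=1) for o in logits_list]).mean(0)
    ptu = pu ** (1 / T)  # temperature sharpening
    return (ptu / ptu.sum(dim=1, keepdim=True)).detach()  # renormalize

# Toy usage: 4 random logit tensors standing in for net1/net2
# applied to two augmentations of a batch of 8 samples, 10 classes
logits = [torch.randn(8, 10) for _ in range(4)]
targets_u = coguess(logits)
print(targets_u.sum(dim=1))  # each row sums to 1
```

In the actual training loop, each network trains on the division produced by its peer's GMM fit (co-divide), which is what counters confirmation bias.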
32 | 33 | License\ 34 | This project is licensed under the terms of the MIT license. 35 | -------------------------------------------------------------------------------- /Train_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import random 9 | import os 10 | import argparse 11 | import numpy as np 12 | from PreResNet import * 13 | from sklearn.mixture import GaussianMixture 14 | import dataloader_cifar as dataloader 15 | 16 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 17 | parser.add_argument('--batch_size', default=64, type=int, help='train batchsize') 18 | parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate') 19 | parser.add_argument('--noise_mode', default='sym') 20 | parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta') 21 | parser.add_argument('--lambda_u', default=25, type=float, help='weight for unsupervised loss') 22 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold') 23 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 24 | parser.add_argument('--num_epochs', default=300, type=int) 25 | parser.add_argument('--r', default=0.5, type=float, help='noise ratio') 26 | parser.add_argument('--id', default='') 27 | parser.add_argument('--seed', default=123) 28 | parser.add_argument('--gpuid', default=0, type=int) 29 | parser.add_argument('--num_class', default=10, type=int) 30 | parser.add_argument('--data_path', default='./cifar-10', type=str, help='path to dataset') 31 | parser.add_argument('--dataset', default='cifar10', type=str) 32 | args = parser.parse_args() 33 | 34 | torch.cuda.set_device(args.gpuid) 35 | random.seed(args.seed) 36 | torch.manual_seed(args.seed) 37 | torch.cuda.manual_seed_all(args.seed) 38 | 39 | 40 | # Training 41 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader): 42 | net.train() 43 | net2.eval() #fix one network and train the other 44 | 45 | unlabeled_train_iter = iter(unlabeled_trainloader) 46 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1 47 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader): 48 | try: 49 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 50 | except: 51 | unlabeled_train_iter = iter(unlabeled_trainloader) 52 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 53 | batch_size = inputs_x.size(0) 54 | 55 | # Transform label to one-hot 56 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1) 57 | w_x = w_x.view(-1,1).type(torch.FloatTensor) 58 | 59 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda() 60 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda() 61 | 62 | with torch.no_grad(): 63 | # label co-guessing of unlabeled samples 64 | outputs_u11 = net(inputs_u) 65 | outputs_u12 = net(inputs_u2) 66 | outputs_u21 = net2(inputs_u) 67 | outputs_u22 = net2(inputs_u2) 68 | 69 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4 70 | ptu = pu**(1/args.T) # temparature sharpening 71 | 72 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # 
normalize 73 | targets_u = targets_u.detach() 74 | 75 | # label refinement of labeled samples 76 | outputs_x = net(inputs_x) 77 | outputs_x2 = net(inputs_x2) 78 | 79 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 80 | px = w_x*labels_x + (1-w_x)*px 81 | ptx = px**(1/args.T) # temparature sharpening 82 | 83 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 84 | targets_x = targets_x.detach() 85 | 86 | # mixmatch 87 | l = np.random.beta(args.alpha, args.alpha) 88 | l = max(l, 1-l) 89 | 90 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0) 91 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0) 92 | 93 | idx = torch.randperm(all_inputs.size(0)) 94 | 95 | input_a, input_b = all_inputs, all_inputs[idx] 96 | target_a, target_b = all_targets, all_targets[idx] 97 | 98 | mixed_input = l * input_a + (1 - l) * input_b 99 | mixed_target = l * target_a + (1 - l) * target_b 100 | 101 | logits = net(mixed_input) 102 | logits_x = logits[:batch_size*2] 103 | logits_u = logits[batch_size*2:] 104 | 105 | Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size*2], logits_u, mixed_target[batch_size*2:], epoch+batch_idx/num_iter, warm_up) 106 | 107 | # regularization 108 | prior = torch.ones(args.num_class)/args.num_class 109 | prior = prior.cuda() 110 | pred_mean = torch.softmax(logits, dim=1).mean(0) 111 | penalty = torch.sum(prior*torch.log(prior/pred_mean)) 112 | 113 | loss = Lx + lamb * Lu + penalty 114 | # compute gradient and do SGD step 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | sys.stdout.write('\r') 120 | sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f Unlabeled loss: %.2f' 121 | %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item(), Lu.item())) 122 | sys.stdout.flush() 123 | 124 | def warmup(epoch,net,optimizer,dataloader): 125 | net.train() 126 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1 127 | for batch_idx, (inputs, labels, path) in enumerate(dataloader): 128 | inputs, labels = inputs.cuda(), labels.cuda() 129 | optimizer.zero_grad() 130 | outputs = net(inputs) 131 | loss = CEloss(outputs, labels) 132 | if args.noise_mode=='asym': # penalize confident prediction for asymmetric noise 133 | penalty = conf_penalty(outputs) 134 | L = loss + penalty 135 | elif args.noise_mode=='sym': 136 | L = loss 137 | L.backward() 138 | optimizer.step() 139 | 140 | sys.stdout.write('\r') 141 | sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f' 142 | %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item())) 143 | sys.stdout.flush() 144 | 145 | def test(epoch,net1,net2): 146 | net1.eval() 147 | net2.eval() 148 | correct = 0 149 | total = 0 150 | with torch.no_grad(): 151 | for batch_idx, (inputs, targets) in enumerate(test_loader): 152 | inputs, targets = inputs.cuda(), targets.cuda() 153 | outputs1 = net1(inputs) 154 | outputs2 = net2(inputs) 155 | outputs = outputs1+outputs2 156 | _, predicted = torch.max(outputs, 1) 157 | 158 | total += targets.size(0) 159 | correct += predicted.eq(targets).cpu().sum().item() 160 | acc = 100.*correct/total 161 | print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" %(epoch,acc)) 162 | test_log.write('Epoch:%d Accuracy:%.2f\n'%(epoch,acc)) 163 | test_log.flush() 164 | 165 | def eval_train(model,all_loss): 166 | model.eval() 167 | losses = torch.zeros(50000) 168 | with torch.no_grad(): 
169 | for batch_idx, (inputs, targets, index) in enumerate(eval_loader): 170 | inputs, targets = inputs.cuda(), targets.cuda() 171 | outputs = model(inputs) 172 | loss = CE(outputs, targets) 173 | for b in range(inputs.size(0)): 174 | losses[index[b]]=loss[b] 175 | losses = (losses-losses.min())/(losses.max()-losses.min()) 176 | all_loss.append(losses) 177 | 178 | if args.r==0.9: # average loss over last 5 epochs to improve convergence stability 179 | history = torch.stack(all_loss) 180 | input_loss = history[-5:].mean(0) 181 | input_loss = input_loss.reshape(-1,1) 182 | else: 183 | input_loss = losses.reshape(-1,1) 184 | 185 | # fit a two-component GMM to the loss 186 | gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4) 187 | gmm.fit(input_loss) 188 | prob = gmm.predict_proba(input_loss) 189 | prob = prob[:,gmm.means_.argmin()] 190 | return prob,all_loss 191 | 192 | def linear_rampup(current, warm_up, rampup_length=16): 193 | current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0) 194 | return args.lambda_u*float(current) 195 | 196 | class SemiLoss(object): 197 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up): 198 | probs_u = torch.softmax(outputs_u, dim=1) 199 | 200 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) 201 | Lu = torch.mean((probs_u - targets_u)**2) 202 | 203 | return Lx, Lu, linear_rampup(epoch,warm_up) 204 | 205 | class NegEntropy(object): 206 | def __call__(self,outputs): 207 | probs = torch.softmax(outputs, dim=1) 208 | return torch.mean(torch.sum(probs.log()*probs, dim=1)) 209 | 210 | def create_model(): 211 | model = ResNet18(num_classes=args.num_class) 212 | model = model.cuda() 213 | return model 214 | 215 | stats_log=open('./checkpoint/%s_%.1f_%s'%(args.dataset,args.r,args.noise_mode)+'_stats.txt','w') 216 | test_log=open('./checkpoint/%s_%.1f_%s'%(args.dataset,args.r,args.noise_mode)+'_acc.txt','w') 217 | 218 | if args.dataset=='cifar10': 219 | warm_up = 10 220 | elif args.dataset=='cifar100': 221 | warm_up = 30 222 | 223 | loader = dataloader.cifar_dataloader(args.dataset,r=args.r,noise_mode=args.noise_mode,batch_size=args.batch_size,num_workers=5,\ 224 | root_dir=args.data_path,log=stats_log,noise_file='%s/%.1f_%s.json'%(args.data_path,args.r,args.noise_mode)) 225 | 226 | print('| Building net') 227 | net1 = create_model() 228 | net2 = create_model() 229 | cudnn.benchmark = True 230 | 231 | criterion = SemiLoss() 232 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 233 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 234 | 235 | CE = nn.CrossEntropyLoss(reduction='none') 236 | CEloss = nn.CrossEntropyLoss() 237 | if args.noise_mode=='asym': 238 | conf_penalty = NegEntropy() 239 | 240 | all_loss = [[],[]] # save the history of losses from two networks 241 | 242 | for epoch in range(args.num_epochs+1): 243 | lr=args.lr 244 | if epoch >= 150: 245 | lr /= 10 246 | for param_group in optimizer1.param_groups: 247 | param_group['lr'] = lr 248 | for param_group in optimizer2.param_groups: 249 | param_group['lr'] = lr 250 | test_loader = loader.run('test') 251 | eval_loader = loader.run('eval_train') 252 | 253 | if epoch<warm_up: 254 | warmup_trainloader = loader.run('warmup') 255 | print('Warmup Net1') 256 | warmup(epoch,net1,optimizer1,warmup_trainloader) 257 | print('\nWarmup Net2') 258 | warmup(epoch,net2,optimizer2,warmup_trainloader) 259 | 260 | else: 261 | prob1,all_loss[0]=eval_train(net1,all_loss[0]) 262 | prob2,all_loss[1]=eval_train(net2,all_loss[1]) 263 | 264 | pred1 = (prob1 > args.p_threshold) 265 | pred2 = (prob2 > args.p_threshold) 266 | 267 | print('Train Net1') 268 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide 269 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1 270
| 271 | print('\nTrain Net2') 272 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide 273 | train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2 274 | 275 | test(epoch,net1,net2) 276 | 277 | 278 | -------------------------------------------------------------------------------- /Train_clothing1M.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import torchvision 9 | import torchvision.models as models 10 | import random 11 | import os 12 | import argparse 13 | import numpy as np 14 | import dataloader_clothing1M as dataloader 15 | from sklearn.mixture import GaussianMixture 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Clothing1M Training') 18 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 19 | parser.add_argument('--lr', '--learning_rate', default=0.002, type=float, help='initial learning rate') 20 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta') 21 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss') 22 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold') 23 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 24 | parser.add_argument('--num_epochs', default=80, type=int) 25 | parser.add_argument('--id', default='clothing1m') 26 | parser.add_argument('--data_path', default='../../Clothing1M/data', type=str, help='path to dataset') 27 | parser.add_argument('--seed', default=123) 28 | parser.add_argument('--gpuid', default=0, type=int) 29 | parser.add_argument('--num_class', default=14, type=int) 30 | parser.add_argument('--num_batches', default=1000, type=int) 31 | args = parser.parse_args() 32 | 33 | torch.cuda.set_device(args.gpuid) 34 | random.seed(args.seed) 35 | torch.manual_seed(args.seed) 36 | torch.cuda.manual_seed_all(args.seed) 37 | 38 | # Training 39 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader): 40 | net.train() 41 | net2.eval() #fix one network and train the other 42 | 43 | unlabeled_train_iter = iter(unlabeled_trainloader) 44 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1 45 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader): 46 | try: 47 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 48 | except: 49 | unlabeled_train_iter = iter(unlabeled_trainloader) 50 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 51 | batch_size = inputs_x.size(0) 52 | 53 | # Transform label to one-hot 54 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1) 55 | w_x = w_x.view(-1,1).type(torch.FloatTensor) 56 | 57 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda() 58 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda() 59 | 60 | with torch.no_grad(): 61 | # label co-guessing of unlabeled samples 62 | outputs_u11 = net(inputs_u) 63 | outputs_u12 = net(inputs_u2) 64 | outputs_u21 = net2(inputs_u) 65 | outputs_u22 = net2(inputs_u2) 66 | 67 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 
4 68 | ptu = pu**(1/args.T) # temparature sharpening 69 | 70 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize 71 | targets_u = targets_u.detach() 72 | 73 | # label refinement of labeled samples 74 | outputs_x = net(inputs_x) 75 | outputs_x2 = net(inputs_x2) 76 | 77 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 78 | px = w_x*labels_x + (1-w_x)*px 79 | ptx = px**(1/args.T) # temparature sharpening 80 | 81 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 82 | targets_x = targets_x.detach() 83 | 84 | # mixmatch 85 | l = np.random.beta(args.alpha, args.alpha) 86 | l = max(l, 1-l) 87 | 88 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0) 89 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0) 90 | 91 | idx = torch.randperm(all_inputs.size(0)) 92 | 93 | input_a, input_b = all_inputs, all_inputs[idx] 94 | target_a, target_b = all_targets, all_targets[idx] 95 | 96 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2] 97 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2] 98 | 99 | logits = net(mixed_input) 100 | 101 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1)) 102 | 103 | # regularization 104 | prior = torch.ones(args.num_class)/args.num_class 105 | prior = prior.cuda() 106 | pred_mean = torch.softmax(logits, dim=1).mean(0) 107 | penalty = torch.sum(prior*torch.log(prior/pred_mean)) 108 | 109 | loss = Lx + penalty 110 | 111 | # compute gradient and do SGD step 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | sys.stdout.write('\r') 117 | sys.stdout.write('Clothing1M | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.4f ' 118 | %(epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item())) 119 | sys.stdout.flush() 120 | 121 | def warmup(net,optimizer,dataloader): 122 | net.train() 123 | for batch_idx, (inputs, labels, path) in enumerate(dataloader): 124 | inputs, labels = inputs.cuda(), labels.cuda() 125 | optimizer.zero_grad() 126 | outputs = net(inputs) 127 | loss = CEloss(outputs, labels) 128 | 129 | penalty = conf_penalty(outputs) 130 | L = loss + penalty 131 | L.backward() 132 | optimizer.step() 133 | 134 | sys.stdout.write('\r') 135 | sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f Conf-Penalty: %.4f' 136 | %(batch_idx+1, args.num_batches, loss.item(), penalty.item())) 137 | sys.stdout.flush() 138 | 139 | def val(net,val_loader,k): 140 | net.eval() 141 | correct = 0 142 | total = 0 143 | with torch.no_grad(): 144 | for batch_idx, (inputs, targets) in enumerate(val_loader): 145 | inputs, targets = inputs.cuda(), targets.cuda() 146 | outputs = net(inputs) 147 | _, predicted = torch.max(outputs, 1) 148 | 149 | total += targets.size(0) 150 | correct += predicted.eq(targets).cpu().sum().item() 151 | acc = 100.*correct/total 152 | print("\n| Validation\t Net%d Acc: %.2f%%" %(k,acc)) 153 | if acc > best_acc[k-1]: 154 | best_acc[k-1] = acc 155 | print('| Saving Best Net%d ...'%k) 156 | save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k) 157 | torch.save(net.state_dict(), save_point) 158 | return acc 159 | 160 | def test(net1,net2,test_loader): 161 | net1.eval() 162 | net2.eval() 163 | correct = 0 164 | total = 0 165 | with torch.no_grad(): 166 | for batch_idx, (inputs, targets) in enumerate(test_loader): 167 | inputs, targets = inputs.cuda(), targets.cuda() 168 | outputs1 = net1(inputs) 169 | outputs2 = net2(inputs) 170 | outputs = outputs1+outputs2 171 | _, 
predicted = torch.max(outputs, 1) 172 | 173 | total += targets.size(0) 174 | correct += predicted.eq(targets).cpu().sum().item() 175 | acc = 100.*correct/total 176 | print("\n| Test Acc: %.2f%%\n" %(acc)) 177 | return acc 178 | 179 | def eval_train(epoch,model): 180 | model.eval() 181 | num_samples = args.num_batches*args.batch_size 182 | losses = torch.zeros(num_samples) 183 | paths = [] 184 | n=0 185 | with torch.no_grad(): 186 | for batch_idx, (inputs, targets, path) in enumerate(eval_loader): 187 | inputs, targets = inputs.cuda(), targets.cuda() 188 | outputs = model(inputs) 189 | loss = CE(outputs, targets) 190 | for b in range(inputs.size(0)): 191 | losses[n]=loss[b] 192 | paths.append(path[b]) 193 | n+=1 194 | sys.stdout.write('\r') 195 | sys.stdout.write('| Evaluating loss Iter %3d\t' %(batch_idx)) 196 | sys.stdout.flush() 197 | 198 | losses = (losses-losses.min())/(losses.max()-losses.min()) 199 | losses = losses.reshape(-1,1) 200 | gmm = GaussianMixture(n_components=2,max_iter=10,reg_covar=5e-4,tol=1e-2) 201 | gmm.fit(losses) 202 | prob = gmm.predict_proba(losses) 203 | prob = prob[:,gmm.means_.argmin()] 204 | return prob,paths 205 | 206 | class NegEntropy(object): 207 | def __call__(self,outputs): 208 | probs = torch.softmax(outputs, dim=1) 209 | return torch.mean(torch.sum(probs.log()*probs, dim=1)) 210 | 211 | def create_model(): 212 | model = models.resnet50(pretrained=True) 213 | model.fc = nn.Linear(2048,args.num_class) 214 | model = model.cuda() 215 | return model 216 | 217 | log=open('./checkpoint/%s.txt'%args.id,'w') 218 | log.flush() 219 | 220 | loader = dataloader.clothing_dataloader(root=args.data_path,batch_size=args.batch_size,num_workers=5,num_batches=args.num_batches) 221 | 222 | print('| Building net') 223 | net1 = create_model() 224 | net2 = create_model() 225 | cudnn.benchmark = True 226 | 227 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3) 228 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-3) 229 | 230 | CE = nn.CrossEntropyLoss(reduction='none') 231 | CEloss = nn.CrossEntropyLoss() 232 | conf_penalty = NegEntropy() 233 | 234 | best_acc = [0,0] 235 | for epoch in range(args.num_epochs+1): 236 | lr=args.lr 237 | if epoch >= 40: 238 | lr /= 10 239 | for param_group in optimizer1.param_groups: 240 | param_group['lr'] = lr 241 | for param_group in optimizer2.param_groups: 242 | param_group['lr'] = lr 243 | 244 | if epoch<1: # warm up 245 | train_loader = loader.run('warmup') 246 | print('Warmup Net1') 247 | warmup(net1,optimizer1,train_loader) 248 | train_loader = loader.run('warmup') 249 | print('\nWarmup Net2') 250 | warmup(net2,optimizer2,train_loader) 251 | else: 252 | pred1 = (prob1 > args.p_threshold) # divide dataset 253 | pred2 = (prob2 > args.p_threshold) 254 | 255 | print('\n\nTrain Net1') 256 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2,paths=paths2) # co-divide 257 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1 258 | print('\nTrain Net2') 259 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1,paths=paths1) # co-divide 260 | train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2 261 | 262 | val_loader = loader.run('val') # validation 263 | acc1 = val(net1,val_loader,1) 264 | acc2 = val(net2,val_loader,2) 265 | log.write('Validation Epoch:%d Acc1:%.2f Acc2:%.2f\n'%(epoch,acc1,acc2)) 266 | log.flush() 267 | print('\n==== net 1 evaluate 
next epoch training data loss ====') 268 | eval_loader = loader.run('eval_train') # evaluate training data loss for next epoch 269 | prob1,paths1 = eval_train(epoch,net1) 270 | print('\n==== net 2 evaluate next epoch training data loss ====') 271 | eval_loader = loader.run('eval_train') 272 | prob2,paths2 = eval_train(epoch,net2) 273 | 274 | test_loader = loader.run('test') 275 | net1.load_state_dict(torch.load('./checkpoint/%s_net1.pth.tar'%args.id)) 276 | net2.load_state_dict(torch.load('./checkpoint/%s_net2.pth.tar'%args.id)) 277 | acc = test(net1,net2,test_loader) 278 | 279 | log.write('Test Accuracy:%.2f\n'%(acc)) 280 | log.flush() 281 | -------------------------------------------------------------------------------- /Train_webvision.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import random 9 | import os 10 | import sys 11 | import argparse 12 | import numpy as np 13 | from InceptionResNetV2 import * 14 | from sklearn.mixture import GaussianMixture 15 | import dataloader_webvision as dataloader 16 | import torchnet 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch WebVision Training') 19 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 20 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate') 21 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta') 22 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss') 23 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold') 24 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 25 | parser.add_argument('--num_epochs', default=80, type=int) 26 | parser.add_argument('--id', default='',type=str) 27 | parser.add_argument('--seed', default=123) 28 | parser.add_argument('--gpuid', default=0, type=int) 29 | parser.add_argument('--num_class', default=50, type=int) 30 | parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset') 31 | 32 | args = parser.parse_args() 33 | 34 | torch.cuda.set_device(args.gpuid) 35 | random.seed(args.seed) 36 | torch.manual_seed(args.seed) 37 | torch.cuda.manual_seed_all(args.seed) 38 | 39 | 40 | # Training 41 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader): 42 | net.train() 43 | net2.eval() #fix one network and train the other 44 | 45 | unlabeled_train_iter = iter(unlabeled_trainloader) 46 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1 47 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader): 48 | try: 49 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 50 | except: 51 | unlabeled_train_iter = iter(unlabeled_trainloader) 52 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 53 | batch_size = inputs_x.size(0) 54 | 55 | # Transform label to one-hot 56 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1) 57 | w_x = w_x.view(-1,1).type(torch.FloatTensor) 58 | 59 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda() 60 | inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda() 61 | 62 | with torch.no_grad(): 63 | # label co-guessing of 
unlabeled samples 64 | outputs_u11 = net(inputs_u) 65 | outputs_u12 = net(inputs_u2) 66 | outputs_u21 = net2(inputs_u) 67 | outputs_u22 = net2(inputs_u2) 68 | 69 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4 70 | ptu = pu**(1/args.T) # temparature sharpening 71 | 72 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize 73 | targets_u = targets_u.detach() 74 | 75 | # label refinement of labeled samples 76 | outputs_x = net(inputs_x) 77 | outputs_x2 = net(inputs_x2) 78 | 79 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 80 | px = w_x*labels_x + (1-w_x)*px 81 | ptx = px**(1/args.T) # temparature sharpening 82 | 83 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 84 | targets_x = targets_x.detach() 85 | 86 | # mixmatch 87 | l = np.random.beta(args.alpha, args.alpha) 88 | l = max(l, 1-l) 89 | 90 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0) 91 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0) 92 | 93 | idx = torch.randperm(all_inputs.size(0)) 94 | 95 | input_a, input_b = all_inputs, all_inputs[idx] 96 | target_a, target_b = all_targets, all_targets[idx] 97 | 98 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2] 99 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2] 100 | 101 | logits = net(mixed_input) 102 | 103 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1)) 104 | 105 | prior = torch.ones(args.num_class)/args.num_class 106 | prior = prior.cuda() 107 | pred_mean = torch.softmax(logits, dim=1).mean(0) 108 | penalty = torch.sum(prior*torch.log(prior/pred_mean)) 109 | 110 | loss = Lx + penalty 111 | # compute gradient and do SGD step 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | sys.stdout.write('\r') 117 | sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f' 118 | %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item())) 119 | sys.stdout.flush() 120 | 121 | def warmup(epoch,net,optimizer,dataloader): 122 | net.train() 123 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1 124 | for batch_idx, (inputs, labels, path) in enumerate(dataloader): 125 | inputs, labels = inputs.cuda(), labels.cuda() 126 | optimizer.zero_grad() 127 | outputs = net(inputs) 128 | loss = CEloss(outputs, labels) 129 | 130 | #penalty = conf_penalty(outputs) 131 | L = loss #+ penalty 132 | 133 | L.backward() 134 | optimizer.step() 135 | 136 | sys.stdout.write('\r') 137 | sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f' 138 | %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item())) 139 | sys.stdout.flush() 140 | 141 | 142 | def test(epoch,net1,net2,test_loader): 143 | acc_meter.reset() 144 | net1.eval() 145 | net2.eval() 146 | correct = 0 147 | total = 0 148 | with torch.no_grad(): 149 | for batch_idx, (inputs, targets) in enumerate(test_loader): 150 | inputs, targets = inputs.cuda(), targets.cuda() 151 | outputs1 = net1(inputs) 152 | outputs2 = net2(inputs) 153 | outputs = outputs1+outputs2 154 | _, predicted = torch.max(outputs, 1) 155 | acc_meter.add(outputs,targets) 156 | accs = acc_meter.value() 157 | return accs 158 | 159 | 160 | def eval_train(model,all_loss): 161 | model.eval() 162 | num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1 163 | losses = 
torch.zeros(len(eval_loader.dataset)) 164 | with torch.no_grad(): 165 | for batch_idx, (inputs, targets, index) in enumerate(eval_loader): 166 | inputs, targets = inputs.cuda(), targets.cuda() 167 | outputs = model(inputs) 168 | loss = CE(outputs, targets) 169 | for b in range(inputs.size(0)): 170 | losses[index[b]]=loss[b] 171 | sys.stdout.write('\r') 172 | sys.stdout.write('| Evaluating loss Iter[%3d/%3d]\t' %(batch_idx,num_iter)) 173 | sys.stdout.flush() 174 | 175 | losses = (losses-losses.min())/(losses.max()-losses.min()) 176 | all_loss.append(losses) 177 | 178 | # fit a two-component GMM to the loss 179 | input_loss = losses.reshape(-1,1) 180 | gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4) 181 | gmm.fit(input_loss) 182 | prob = gmm.predict_proba(input_loss) 183 | prob = prob[:,gmm.means_.argmin()] 184 | return prob,all_loss 185 | 186 | def linear_rampup(current, warm_up, rampup_length=16): 187 | current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0) 188 | return args.lambda_u*float(current) 189 | 190 | class SemiLoss(object): 191 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up): 192 | probs_u = torch.softmax(outputs_u, dim=1) 193 | 194 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) 195 | Lu = torch.mean((probs_u - targets_u)**2) 196 | 197 | return Lx, Lu, linear_rampup(epoch,warm_up) 198 | 199 | class NegEntropy(object): 200 | def __call__(self,outputs): 201 | probs = torch.softmax(outputs, dim=1) 202 | return torch.mean(torch.sum(probs.log()*probs, dim=1)) 203 | 204 | def create_model(): 205 | model = InceptionResNetV2(num_classes=args.num_class) 206 | model = model.cuda() 207 | return model 208 | 209 | stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w') 210 | test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w') 211 | 212 | warm_up=1 213 | 214 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=5,root_dir=args.data_path,log=stats_log, num_class=args.num_class) 215 | 216 | print('| Building net') 217 | net1 = create_model() 218 | net2 = create_model() 219 | cudnn.benchmark = True 220 | 221 | criterion = SemiLoss() 222 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 223 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 224 | 225 | CE = nn.CrossEntropyLoss(reduction='none') 226 | CEloss = nn.CrossEntropyLoss() 227 | conf_penalty = NegEntropy() 228 | 229 | all_loss = [[],[]] # save the history of losses from two networks 230 | acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True) 231 | 232 | for epoch in range(args.num_epochs+1): 233 | lr=args.lr 234 | if epoch >= 40: 235 | lr /= 10 236 | for param_group in optimizer1.param_groups: 237 | param_group['lr'] = lr 238 | for param_group in optimizer2.param_groups: 239 | param_group['lr'] = lr 240 | eval_loader = loader.run('eval_train') 241 | web_valloader = loader.run('test') 242 | imagenet_valloader = loader.run('imagenet') 243 | 244 | if epoch<warm_up: 245 | warmup_trainloader = loader.run('warmup') 246 | print('Warmup Net1') 247 | warmup(epoch,net1,optimizer1,warmup_trainloader) 248 | print('\nWarmup Net2') 249 | warmup(epoch,net2,optimizer2,warmup_trainloader) 250 | 251 | else: 252 | pred1 = (prob1 > args.p_threshold) 253 | pred2 = (prob2 > args.p_threshold) 254 | 255 | print('Train Net1') 256 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide 257 | train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1 258 | 259 | print('\nTrain Net2') 260 | labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide 261 |
train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2 262 | 263 | 264 | web_acc = test(epoch,net1,net2,web_valloader) 265 | imagenet_acc = test(epoch,net1,net2,imagenet_valloader) 266 | 267 | print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1])) 268 | test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1])) 269 | test_log.flush() 270 | 271 | print('\n==== net 1 evaluate training data loss ====') 272 | prob1,all_loss[0]=eval_train(net1,all_loss[0]) 273 | print('\n==== net 2 evaluate training data loss ====') 274 | prob2,all_loss[1]=eval_train(net2,all_loss[1]) 275 | torch.save(all_loss,'./checkpoint/%s.pth.tar'%(args.id)) 276 | 277 | -------------------------------------------------------------------------------- /Train_webvision_parallel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import random 9 | import os 10 | import sys 11 | import argparse 12 | import numpy as np 13 | from InceptionResNetV2 import * 14 | from sklearn.mixture import GaussianMixture 15 | import dataloader_webvision as dataloader 16 | import torchnet 17 | import torch.multiprocessing as mp 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch WebVision Parallel Training') 20 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 21 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate') 22 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta') 23 | parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss') 24 | parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold') 25 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 26 | parser.add_argument('--num_epochs', default=100, type=int) 27 | parser.add_argument('--id', default='',type=str) 28 | parser.add_argument('--seed', default=123) 29 | parser.add_argument('--gpuid1', default=0, type=int) 30 | parser.add_argument('--gpuid2', default=1, type=int) 31 | parser.add_argument('--num_class', default=50, type=int) 32 | parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset') 33 | 34 | args = parser.parse_args() 35 | 36 | os.environ["CUDA_VISIBLE_DEVICES"] = '%s,%s'%(args.gpuid1,args.gpuid2) 37 | random.seed(args.seed) 38 | cuda1 = torch.device('cuda:0') 39 | cuda2 = torch.device('cuda:1') 40 | 41 | # Training 42 | def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,device,whichnet): 43 | criterion = SemiLoss() 44 | 45 | net.train() 46 | net2.eval() #fix one network and train the other 47 | 48 | unlabeled_train_iter = iter(unlabeled_trainloader) 49 | num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1 50 | for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader): 51 | try: 52 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 53 | except: 54 | unlabeled_train_iter = iter(unlabeled_trainloader) 55 | inputs_u, inputs_u2 = unlabeled_train_iter.next() 56 | batch_size = 
inputs_x.size(0) 57 | 58 | # Transform label to one-hot 59 | labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1) 60 | w_x = w_x.view(-1,1).type(torch.FloatTensor) 61 | 62 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.to(device,non_blocking=True), inputs_x2.to(device,non_blocking=True), labels_x.to(device,non_blocking=True), w_x.to(device,non_blocking=True) 63 | inputs_u, inputs_u2 = inputs_u.to(device), inputs_u2.to(device) 64 | 65 | with torch.no_grad(): 66 | # label co-guessing of unlabeled samples 67 | outputs_u11 = net(inputs_u) 68 | outputs_u12 = net(inputs_u2) 69 | outputs_u21 = net2(inputs_u) 70 | outputs_u22 = net2(inputs_u2) 71 | 72 | pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4 73 | ptu = pu**(1/args.T) # temparature sharpening 74 | 75 | targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize 76 | targets_u = targets_u.detach() 77 | 78 | # label refinement of labeled samples 79 | outputs_x = net(inputs_x) 80 | outputs_x2 = net(inputs_x2) 81 | 82 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 83 | px = w_x*labels_x + (1-w_x)*px 84 | ptx = px**(1/args.T) # temparature sharpening 85 | 86 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 87 | targets_x = targets_x.detach() 88 | 89 | # mixmatch 90 | l = np.random.beta(args.alpha, args.alpha) 91 | l = max(l, 1-l) 92 | 93 | all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0) 94 | all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0) 95 | 96 | idx = torch.randperm(all_inputs.size(0)) 97 | 98 | input_a, input_b = all_inputs, all_inputs[idx] 99 | target_a, target_b = all_targets, all_targets[idx] 100 | 101 | mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2] 102 | mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2] 103 | 104 | logits = net(mixed_input) 105 | 106 | Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1)) 107 | 108 | prior = torch.ones(args.num_class)/args.num_class 109 | prior = prior.to(device) 110 | pred_mean = torch.softmax(logits, dim=1).mean(0) 111 | penalty = torch.sum(prior*torch.log(prior/pred_mean)) 112 | 113 | loss = Lx + penalty 114 | # compute gradient and do SGD step 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | sys.stdout.write('\n') 120 | sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f' 121 | %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item())) 122 | sys.stdout.flush() 123 | 124 | def warmup(epoch,net,optimizer,dataloader,device,whichnet): 125 | CEloss = nn.CrossEntropyLoss() 126 | acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True) 127 | 128 | net.train() 129 | num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1 130 | for batch_idx, (inputs, labels, path) in enumerate(dataloader): 131 | inputs, labels = inputs.to(device), labels.to(device,non_blocking=True) 132 | optimizer.zero_grad() 133 | outputs = net(inputs) 134 | loss = CEloss(outputs, labels) 135 | 136 | #penalty = conf_penalty(outputs) 137 | L = loss #+ penalty 138 | 139 | L.backward() 140 | optimizer.step() 141 | 142 | sys.stdout.write('\n') 143 | sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f' 144 | %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item())) 145 | 
124 | def warmup(epoch,net,optimizer,dataloader,device,whichnet):
125 |     CEloss = nn.CrossEntropyLoss()
126 |     acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
127 | 
128 |     net.train()
129 |     num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
130 |     for batch_idx, (inputs, labels, path) in enumerate(dataloader):
131 |         inputs, labels = inputs.to(device), labels.to(device,non_blocking=True)
132 |         optimizer.zero_grad()
133 |         outputs = net(inputs)
134 |         loss = CEloss(outputs, labels)
135 | 
136 |         #penalty = conf_penalty(outputs)
137 |         L = loss #+ penalty
138 | 
139 |         L.backward()
140 |         optimizer.step()
141 | 
142 |         sys.stdout.write('\n')
143 |         sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f'
144 |                 %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
145 |         sys.stdout.flush()
146 | 
147 | 
148 | def test(epoch,net1,net2,test_loader,device,queue):
149 |     acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
150 |     acc_meter.reset()
151 |     net1.eval()
152 |     net2.eval()
153 |     with torch.no_grad():
154 |         for batch_idx, (inputs, targets) in enumerate(test_loader):
155 |             inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
156 |             outputs1 = net1(inputs)
157 |             outputs2 = net2(inputs)
158 |             outputs = outputs1+outputs2 # ensemble the two networks by summing logits
159 |             _, predicted = torch.max(outputs, 1)
160 |             acc_meter.add(outputs,targets)
161 |     accs = acc_meter.value()
162 |     queue.put(accs)
163 | 
164 | 
165 | def eval_train(eval_loader,model,device,whichnet,queue):
166 |     CE = nn.CrossEntropyLoss(reduction='none')
167 |     model.eval()
168 |     num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
169 |     losses = torch.zeros(len(eval_loader.dataset))
170 |     with torch.no_grad():
171 |         for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
172 |             inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
173 |             outputs = model(inputs)
174 |             loss = CE(outputs, targets)
175 |             for b in range(inputs.size(0)):
176 |                 losses[index[b]]=loss[b]
177 |             sys.stdout.write('\n')
178 |             sys.stdout.write('|%s Evaluating loss Iter[%3d/%3d]\t' %(whichnet,batch_idx,num_iter))
179 |             sys.stdout.flush()
180 | 
181 |     losses = (losses-losses.min())/(losses.max()-losses.min()) # min-max normalize losses to [0,1]
182 | 
183 |     # fit a two-component GMM to the loss
184 |     input_loss = losses.reshape(-1,1)
185 |     gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=1e-3)
186 |     gmm.fit(input_loss)
187 |     prob = gmm.predict_proba(input_loss)
188 |     prob = prob[:,gmm.means_.argmin()] # posterior of the low-mean (clean) component
189 |     queue.put(prob)
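eval_train models the normalized per-sample losses as a two-component Gaussian mixture and reads off the posterior of the lower-mean component as each sample's probability of being clean. A minimal sketch on synthetic losses (all values made up):

    import numpy as np
    from sklearn.mixture import GaussianMixture

    losses = np.concatenate([np.random.normal(0.1, 0.05, 900),   # "clean": small loss
                             np.random.normal(0.7, 0.10, 100)])  # "noisy": large loss
    losses = (losses - losses.min()) / (losses.max() - losses.min())
    gmm = GaussianMixture(n_components=2).fit(losses.reshape(-1, 1))
    prob_clean = gmm.predict_proba(losses.reshape(-1, 1))[:, gmm.means_.argmin()]
    # prob_clean is near 1 for the small-loss samples and near 0 for the large-loss ones
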
190 | 
191 | def linear_rampup(current, warm_up, rampup_length=16):
192 |     current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
193 |     return args.lambda_u*float(current)
194 | 
195 | class SemiLoss(object):
196 |     def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
197 |         probs_u = torch.softmax(outputs_u, dim=1)
198 | 
199 |         Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
200 |         Lu = torch.mean((probs_u - targets_u)**2)
201 | 
202 |         return Lx, Lu, linear_rampup(epoch,warm_up)
203 | 
204 | class NegEntropy(object):
205 |     def __call__(self,outputs):
206 |         probs = torch.softmax(outputs, dim=1)
207 |         return torch.mean(torch.sum(probs.log()*probs, dim=1))
208 | 
209 | def create_model(device):
210 |     model = InceptionResNetV2(num_classes=args.num_class)
211 |     model = model.to(device)
212 |     return model
213 | 
214 | if __name__ == "__main__":
215 | 
216 |     mp.set_start_method('spawn')
217 |     torch.manual_seed(args.seed)
218 |     torch.cuda.manual_seed_all(args.seed)
219 | 
220 |     stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w')
221 |     test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')
222 | 
223 |     warm_up=1
224 | 
225 |     loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_class = args.num_class,num_workers=8,root_dir=args.data_path,log=stats_log)
226 | 
227 |     print('| Building net')
228 | 
229 |     net1 = create_model(cuda1)
230 |     net2 = create_model(cuda2)
231 | 
232 |     net1_clone = create_model(cuda2) # copy of net1 kept on the other GPU for cross-network steps
233 |     net2_clone = create_model(cuda1) # copy of net2 kept on the other GPU
234 | 
235 |     cudnn.benchmark = True
236 | 
237 |     optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
238 |     optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
239 | 
240 |     #conf_penalty = NegEntropy()
241 |     web_valloader = loader.run('test')
242 |     imagenet_valloader = loader.run('imagenet')
243 | 
244 |     for epoch in range(args.num_epochs+1):
245 |         lr=args.lr
246 |         if epoch >= 50:
247 |             lr /= 10
248 |         for param_group in optimizer1.param_groups:
249 |             param_group['lr'] = lr
250 |         for param_group in optimizer2.param_groups:
251 |             param_group['lr'] = lr
252 | 
253 |         if epoch < warm_up:
254 |             warmup_trainloader1 = loader.run('warmup')
255 |             warmup_trainloader2 = loader.run('warmup')
256 |             p1 = mp.Process(target=warmup, args=(epoch,net1,optimizer1,warmup_trainloader1,cuda1,'net1'))
257 |             p2 = mp.Process(target=warmup, args=(epoch,net2,optimizer2,warmup_trainloader2,cuda2,'net2'))
258 |             p1.start()
259 |             p2.start()
260 | 
261 |         else:
262 |             pred1 = (prob1 > args.p_threshold)
263 |             pred2 = (prob2 > args.p_threshold)
264 | 
265 |             labeled_trainloader1, unlabeled_trainloader1 = loader.run('train',pred2,prob2) # co-divide
266 |             labeled_trainloader2, unlabeled_trainloader2 = loader.run('train',pred1,prob1) # co-divide
267 | 
268 |             p1 = mp.Process(target=train, args=(epoch,net1,net2_clone,optimizer1,labeled_trainloader1, unlabeled_trainloader1,cuda1,'net1'))
269 |             p2 = mp.Process(target=train, args=(epoch,net2,net1_clone,optimizer2,labeled_trainloader2, unlabeled_trainloader2,cuda2,'net2'))
270 |             p1.start()
271 |             p2.start()
272 | 
273 |         p1.join()
274 |         p2.join()
275 | 
276 |         net1_clone.load_state_dict(net1.state_dict())
277 |         net2_clone.load_state_dict(net2.state_dict())
278 | 
279 |         q1 = mp.Queue()
280 |         q2 = mp.Queue()
281 |         p1 = mp.Process(target=test, args=(epoch,net1,net2_clone,web_valloader,cuda1,q1))
282 |         p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,imagenet_valloader,cuda2,q2))
283 | 
284 |         p1.start()
285 |         p2.start()
286 | 
287 |         web_acc = q1.get()
288 |         imagenet_acc = q2.get()
289 | 
290 |         p1.join()
291 |         p2.join()
292 | 
293 |         print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
294 |         test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
295 |         test_log.flush()
296 | 
297 |         eval_loader1 = loader.run('eval_train')
298 |         eval_loader2 = loader.run('eval_train')
299 |         q1 = mp.Queue()
300 |         q2 = mp.Queue()
301 |         p1 = mp.Process(target=eval_train, args=(eval_loader1,net1,cuda1,'net1',q1))
302 |         p2 = mp.Process(target=eval_train, args=(eval_loader2,net2,cuda2,'net2',q2))
303 | 
304 |         p1.start()
305 |         p2.start()
306 | 
307 |         prob1 = q1.get()
308 |         prob2 = q2.get()
309 | 
310 |         p1.join()
311 |         p2.join()
312 | 
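In the else branch of the epoch loop above, each network's clean probabilities are thresholded into a boolean mask, and each network then trains on the division produced by its peer, so a network's own selection errors are not fed straight back into it. A toy illustration of the thresholding, with made-up probabilities:

    import numpy as np
    prob1 = np.array([0.92, 0.18, 0.71, 0.43])  # hypothetical clean probabilities from net1
    pred1 = prob1 > 0.5                         # -> [ True, False, True, False]
    # loader.run('train', pred1, prob1) would then build net2's labeled set from
    # samples 0 and 2 and treat samples 1 and 3 as unlabeled
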
--------------------------------------------------------------------------------
/dataloader_cifar.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torchvision.transforms as transforms
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 | import json
7 | import os
8 | import torch
9 | from torchnet.meter import AUCMeter
10 | 
11 | 
12 | def unpickle(file):
13 |     import _pickle as cPickle
14 |     with open(file, 'rb') as fo:
15 |         dict = cPickle.load(fo, encoding='latin1')
16 |     return dict
17 | 
18 | class cifar_dataset(Dataset):
19 |     def __init__(self, dataset, r, noise_mode, root_dir, transform, mode, noise_file='', pred=[], probability=[], log=''):
20 | 
21 |         self.r = r # noise ratio
22 |         self.transform = transform
23 |         self.mode = mode
24 |         self.transition = {0:0,2:0,4:7,7:7,1:1,9:1,3:5,5:3,6:6,8:8} # class transition for asymmetric noise
25 | 
26 |         if self.mode=='test':
27 |             if dataset=='cifar10':
28 |                 test_dic = unpickle('%s/test_batch'%root_dir)
29 |                 self.test_data = test_dic['data']
30 |                 self.test_data = self.test_data.reshape((10000, 3, 32, 32))
31 |                 self.test_data = self.test_data.transpose((0, 2, 3, 1))
32 |                 self.test_label = test_dic['labels']
33 |             elif dataset=='cifar100':
34 |                 test_dic = unpickle('%s/test'%root_dir)
35 |                 self.test_data = test_dic['data']
36 |                 self.test_data = self.test_data.reshape((10000, 3, 32, 32))
37 |                 self.test_data = self.test_data.transpose((0, 2, 3, 1))
38 |                 self.test_label = test_dic['fine_labels']
39 |         else:
40 |             train_data=[]
41 |             train_label=[]
42 |             if dataset=='cifar10':
43 |                 for n in range(1,6):
44 |                     dpath = '%s/data_batch_%d'%(root_dir,n)
45 |                     data_dic = unpickle(dpath)
46 |                     train_data.append(data_dic['data'])
47 |                     train_label = train_label+data_dic['labels']
48 |                 train_data = np.concatenate(train_data)
49 |             elif dataset=='cifar100':
50 |                 train_dic = unpickle('%s/train'%root_dir)
51 |                 train_data = train_dic['data']
52 |                 train_label = train_dic['fine_labels']
53 |             train_data = train_data.reshape((50000, 3, 32, 32))
54 |             train_data = train_data.transpose((0, 2, 3, 1))
55 | 
56 |             if os.path.exists(noise_file):
57 |                 noise_label = json.load(open(noise_file,"r"))
58 |             else: #inject noise
59 |                 noise_label = []
60 |                 idx = list(range(50000))
61 |                 random.shuffle(idx)
62 |                 num_noise = int(self.r*50000)
63 |                 noise_idx = idx[:num_noise]
64 |                 for i in range(50000):
65 |                     if i in noise_idx:
66 |                         if noise_mode=='sym':
67 |                             if dataset=='cifar10':
68 |                                 noiselabel = random.randint(0,9)
69 |                             elif dataset=='cifar100':
70 |                                 noiselabel = random.randint(0,99)
71 |                             noise_label.append(noiselabel)
72 |                         elif noise_mode=='asym':
73 |                             noiselabel = self.transition[train_label[i]]
74 |                             noise_label.append(noiselabel)
75 |                     else:
76 |                         noise_label.append(train_label[i])
77 |                 print("save noisy labels to %s ..."%noise_file)
78 |                 json.dump(noise_label,open(noise_file,"w"))
79 | 
80 |             if self.mode == 'all':
81 |                 self.train_data = train_data
82 |                 self.noise_label = noise_label
83 |             else:
84 |                 if self.mode == "labeled":
85 |                     pred_idx = pred.nonzero()[0]
86 |                     self.probability = [probability[i] for i in pred_idx]
87 | 
88 |                     clean = (np.array(noise_label)==np.array(train_label))
89 |                     auc_meter = AUCMeter()
90 |                     auc_meter.reset()
91 |                     auc_meter.add(probability,clean)
92 |                     auc,_,_ = auc_meter.value()
93 |                     log.write('Number of labeled samples:%d AUC:%.3f\n'%(pred.sum(),auc))
94 |                     log.flush()
95 | 
96 |                 elif self.mode == "unlabeled":
97 |                     pred_idx = (1-pred).nonzero()[0]
98 | 
99 |                 self.train_data = train_data[pred_idx]
100 |                 self.noise_label = [noise_label[i] for i in pred_idx]
101 |                 print("%s data has a size of %d"%(self.mode,len(self.noise_label)))
102 | 
103 |     def __getitem__(self, index):
104 |         if self.mode=='labeled':
105 |             img, target, prob = self.train_data[index], self.noise_label[index], self.probability[index]
106 |             img = Image.fromarray(img)
107 |             img1 = self.transform(img)
108 |             img2 = self.transform(img)
109 |             return img1, img2, target, prob
110 |         elif self.mode=='unlabeled':
111 |             img = self.train_data[index]
112 |             img = Image.fromarray(img)
113 |             img1 = self.transform(img)
114 |             img2 = self.transform(img)
115 |             return img1, img2
116 |         elif self.mode=='all':
117 |             img, target = self.train_data[index], self.noise_label[index]
118 |             img = Image.fromarray(img)
119 |             img = self.transform(img)
120 |             return img, target, index
121 |         elif self.mode=='test':
122 |             img, target = self.test_data[index], self.test_label[index]
123 |             img = Image.fromarray(img)
124 |             img = self.transform(img)
125 |             return img, target
126 | 
127 |     def __len__(self):
128 |         if self.mode!='test':
129 |             return len(self.train_data)
130 |         else:
131 |             return len(self.test_data)
132 | 
133 | 
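For reference, the asymmetric-noise map in cifar_dataset flips each CIFAR-10 label to a visually similar class (class ids: 0 plane, 1 car, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog, 7 horse, 8 ship, 9 truck). A quick check of the mapping:

    transition = {0:0, 2:0, 4:7, 7:7, 1:1, 9:1, 3:5, 5:3, 6:6, 8:8}
    assert transition[9] == 1                          # truck -> car
    assert transition[2] == 0                          # bird -> plane
    assert transition[4] == 7                          # deer -> horse
    assert transition[3] == 5 and transition[5] == 3   # cat <-> dog
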
134 | class cifar_dataloader():
135 |     def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''):
136 |         self.dataset = dataset
137 |         self.r = r
138 |         self.noise_mode = noise_mode
139 |         self.batch_size = batch_size
140 |         self.num_workers = num_workers
141 |         self.root_dir = root_dir
142 |         self.log = log
143 |         self.noise_file = noise_file
144 |         if self.dataset=='cifar10':
145 |             self.transform_train = transforms.Compose([
146 |                 transforms.RandomCrop(32, padding=4),
147 |                 transforms.RandomHorizontalFlip(),
148 |                 transforms.ToTensor(),
149 |                 transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
150 |             ])
151 |             self.transform_test = transforms.Compose([
152 |                 transforms.ToTensor(),
153 |                 transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
154 |             ])
155 |         elif self.dataset=='cifar100':
156 |             self.transform_train = transforms.Compose([
157 |                 transforms.RandomCrop(32, padding=4),
158 |                 transforms.RandomHorizontalFlip(),
159 |                 transforms.ToTensor(),
160 |                 transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
161 |             ])
162 |             self.transform_test = transforms.Compose([
163 |                 transforms.ToTensor(),
164 |                 transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
165 |             ])
166 |     def run(self,mode,pred=[],prob=[]):
167 |         if mode=='warmup':
168 |             all_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="all",noise_file=self.noise_file)
169 |             trainloader = DataLoader(
170 |                 dataset=all_dataset,
171 |                 batch_size=self.batch_size*2,
172 |                 shuffle=True,
173 |                 num_workers=self.num_workers)
174 |             return trainloader
175 | 
176 |         elif mode=='train':
177 |             labeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="labeled", noise_file=self.noise_file, pred=pred, probability=prob,log=self.log)
178 |             labeled_trainloader = DataLoader(
179 |                 dataset=labeled_dataset,
180 |                 batch_size=self.batch_size,
181 |                 shuffle=True,
182 |                 num_workers=self.num_workers)
183 | 
184 |             unlabeled_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_train, mode="unlabeled", noise_file=self.noise_file, pred=pred)
185 |             unlabeled_trainloader = DataLoader(
186 |                 dataset=unlabeled_dataset,
187 |                 batch_size=self.batch_size,
188 |                 shuffle=True,
189 |                 num_workers=self.num_workers)
190 |             return labeled_trainloader, unlabeled_trainloader
191 | 
192 |         elif mode=='test':
193 |             test_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='test')
194 |             test_loader = DataLoader(
195 |                 dataset=test_dataset,
196 |                 batch_size=self.batch_size,
197 |                 shuffle=False,
198 |                 num_workers=self.num_workers)
199 |             return test_loader
200 | 
201 |         elif mode=='eval_train':
202 |             eval_dataset = cifar_dataset(dataset=self.dataset, noise_mode=self.noise_mode, r=self.r, root_dir=self.root_dir, transform=self.transform_test, mode='all', noise_file=self.noise_file)
203 |             eval_loader = DataLoader(
204 |                 dataset=eval_dataset,
205 |                 batch_size=self.batch_size,
206 |                 shuffle=False,
207 |                 num_workers=self.num_workers)
208 |             return eval_loader
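A minimal usage sketch of the loader above; the paths, the log file, and the stand-in pred/prob arrays are assumptions for illustration (in training they come from the GMM co-divide step):

    import numpy as np
    from dataloader_cifar import cifar_dataloader

    loader = cifar_dataloader('cifar10', r=0.5, noise_mode='sym', batch_size=64,
                              num_workers=4, root_dir='./cifar-10-batches-py',
                              log=open('./stats.txt','w'), noise_file='./noise_sym_0.5.json')
    warmup_loader = loader.run('warmup')        # all 50k samples at twice the batch size
    eval_loader = loader.run('eval_train')      # unshuffled pass used for per-sample losses
    test_loader = loader.run('test')
    pred = np.random.rand(50000) > 0.5          # stand-ins for the GMM outputs
    prob = np.random.rand(50000)
    labeled_loader, unlabeled_loader = loader.run('train', pred, prob)  # co-divide split
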
--------------------------------------------------------------------------------
/dataloader_clothing1M.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import torchvision.transforms as transforms
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 | import json
7 | import torch
8 | 
9 | class clothing_dataset(Dataset):
10 |     def __init__(self, root, transform, mode, num_samples=0, pred=[], probability=[], paths=[], num_class=14):
11 | 
12 |         self.root = root
13 |         self.transform = transform
14 |         self.mode = mode
15 |         self.train_labels = {}
16 |         self.test_labels = {}
17 |         self.val_labels = {}
18 | 
19 |         with open('%s/noisy_label_kv.txt'%self.root,'r') as f:
20 |             lines = f.read().splitlines()
21 |             for l in lines:
22 |                 entry = l.split()
23 |                 img_path = '%s/'%self.root+entry[0][7:] # drop the first 7 chars of the key (the 'images/' prefix)
24 |                 self.train_labels[img_path] = int(entry[1])
25 |         with open('%s/clean_label_kv.txt'%self.root,'r') as f:
26 |             lines = f.read().splitlines()
27 |             for l in lines:
28 |                 entry = l.split()
29 |                 img_path = '%s/'%self.root+entry[0][7:]
30 |                 self.test_labels[img_path] = int(entry[1])
31 | 
32 |         if mode == 'all':
33 |             train_imgs=[]
34 |             with open('%s/noisy_train_key_list.txt'%self.root,'r') as f:
35 |                 lines = f.read().splitlines()
36 |                 for l in lines:
37 |                     img_path = '%s/'%self.root+l[7:]
38 |                     train_imgs.append(img_path)
39 |             random.shuffle(train_imgs)
40 |             class_num = torch.zeros(num_class)
41 |             self.train_imgs = []
42 |             for impath in train_imgs:
43 |                 label = self.train_labels[impath]
44 |                 if class_num[label]<(num_samples/14) and len(self.train_imgs)<num_samples: