├── data ├── webvision │ └── refer to dividemix.txt ├── cifar-10 │ └── automatically download.txt └── cifar-100 │ └── automatically download.txt ├── outputs └── time-exp-X-dataset_noise ratio_noise mode ├── README.md ├── cifar ├── main.py ├── PreResNet.py ├── dataloader_cifar.py ├── autoaugment.py └── utils.py └── webvision ├── main.py ├── dataloader_webvision.py ├── InceptionResNetV2.py ├── autoaugment.py └── utils.py /data/webvision/refer to dividemix.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/cifar-10/automatically download.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/cifar-100/automatically download.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outputs/time-exp-X-dataset_noise ratio_noise mode: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LNL-NCE 2 | A pytorch implementation for "Neighborhood Collective Estimation for Noisy Label Identification and Correction", accepted by ECCV2022. More details of this work can be found in our paper: [Arxiv](https://arxiv.org/abs/2208.03207) or [ECCV2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136840126.pdf). 3 | 4 | 5 | ## Installation 6 | 7 | Refer to [DivideMix](https://github.com/LiJunnan1992/DivideMix). 
8 | 9 | ## Model training 10 | 11 | (1) To train on CIFAR-10/CIFAR-100 with a chosen noise mode (**sym** or **asym**) and noise ratio (e.g., **0.20**, **0.50**, **0.80**, or **0.90**), run the commands below. Here `--r` sets the noise ratio and `--remark` is appended to the timestamped name of the run's output folder. 12 | 13 | `CUDA_VISIBLE_DEVICES=0 python ./cifar/main.py --dataset cifar10 --num_class 10 --batch_size 128 --data_path ./data/cifar-10/ --r 0.50 --noise_mode sym --remark exp-ID` 14 | 15 | `CUDA_VISIBLE_DEVICES=0 python ./cifar/main.py --dataset cifar100 --num_class 100 --batch_size 128 --data_path ./data/cifar-100/ --r 0.50 --noise_mode sym --remark exp-ID` 16 | 17 | (2) To train on WebVision 1.0, run: 18 | 19 | `CUDA_VISIBLE_DEVICES=0,1,2 python ./webvision/main.py --data_path ./data/webvision/ --remark exp-ID` 20 | 21 | ### Citation 22 | If you use this code or its derivatives, please consider citing: 23 | 24 | ``` 25 | @inproceedings{li2022neighborhood, 26 | title={Neighborhood Collective Estimation for Noisy Label Identification and Correction}, 27 | author={Li, Jichang and Li, Guanbin and Liu, Feng and Yu, Yizhou}, 28 | booktitle={European Conference on Computer Vision}, 29 | pages={128--145}, 30 | year={2022}, 31 | organization={Springer} 32 | } 33 | ``` 34 | ### Contact 35 | If you have any questions, please feel free to contact the first author, [Li Jichang](https://lijichang.github.io/), by email at li.jichang@foxmail.com.
36 | -------------------------------------------------------------------------------- /cifar/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | import random 7 | import os, datetime, argparse 8 | from PreResNet import * 9 | import dataloader_cifar as dataloader 10 | from utils import NegEntropy 11 | from utils import test 12 | from utils import train 13 | from utils import warmup 14 | from utils import ncnv, nclc 15 | from utils import create_folder_and_save_pyfile 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 18 | parser.add_argument('--batch_size', default=128, type=int, help='train batchsize') 19 | parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate') 20 | parser.add_argument('--noise_mode', default='sym', help='sym or asym') 21 | parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta') 22 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 23 | parser.add_argument('--num_epochs', default=300, type=int) 24 | parser.add_argument('--lr_switch_epoch', default=150, type=int) 25 | parser.add_argument('--r', default=0.5, type=float, help='noise ratio') 26 | parser.add_argument('--drop', default=0.0, type=float) 27 | parser.add_argument('--seed', default=123) 28 | parser.add_argument('--gpuid', default=0, type=int) 29 | parser.add_argument('--num_class', default=10, type=int) 30 | parser.add_argument('--data_path', default='./cifar-10', type=str, help='path to dataset') 31 | parser.add_argument('--dataset', default='cifar10', type=str) 32 | parser.add_argument('--remark', default='', type=str) 33 | args = parser.parse_args() 34 | 35 | torch.cuda.set_device(args.gpuid) 36 | random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | 
torch.cuda.manual_seed_all(args.seed) 39 | 40 | run_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 41 | root_folder = create_folder_and_save_pyfile(run_time + "-" + args.remark, args) 42 | record_log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_records.txt'), 'a+') 43 | test_log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_results.txt'), 'a+') 44 | 45 | if args.dataset == 'cifar10': 46 | warm_up = 10 47 | threshold_sver = 0.75 48 | threshold_scor = 0.002 49 | if args.noise_mode == 'asym': 50 | threshold_sver = 0.50 51 | threshold_scor = 0.0005 52 | elif args.dataset == 'cifar100': 53 | warm_up = 30 54 | threshold_sver = 0.90 55 | threshold_scor = 0.01 56 | if args.r == 0.5: 57 | threshold_scor = 0.005 58 | if args.r <= 0.2: 59 | threshold_scor = 0.0 60 | 61 | print('| Building net') 62 | def create_model(args): 63 | model = ResNet18(num_classes=args.num_class, drop=args.drop) 64 | model = model.cuda() 65 | return model 66 | net1 = create_model(args) 67 | net2 = create_model(args) 68 | cudnn.benchmark = True 69 | 70 | print('| Building optimizer') 71 | optimizer1 = optim.SGD(list(net1.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) 72 | optimizer2 = optim.SGD(list(net2.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) 73 | 74 | CEloss = nn.CrossEntropyLoss() 75 | if args.noise_mode == 'asym': 76 | conf_penalty = NegEntropy() 77 | else: 78 | conf_penalty = None 79 | 80 | loader = dataloader.cifar_dataloader(args.dataset, r=args.r, noise_mode=args.noise_mode, batch_size=args.batch_size, num_workers=5, \ 81 | root_dir=args.data_path, noise_file='%s/%.1f_%s.json' % (args.data_path, args.r, args.noise_mode)) 82 | 83 | for epoch in range(args.num_epochs + 1): 84 | lr = args.lr 85 | if epoch >= args.lr_switch_epoch: 86 | lr /= 10 87 | if epoch >= (args.lr_switch_epoch * 2): 88 | lr /= 2 89 | 90 | for param_group in optimizer1.param_groups: 91 | 
param_group['lr'] = lr 92 | for param_group in optimizer2.param_groups: 93 | param_group['lr'] = lr 94 | 95 | warmup_trainloader = loader.run('warmup') 96 | test_loader = loader.run('test') 97 | eval_loader = loader.run('eval_train') 98 | 99 | # model training 100 | if epoch < warm_up: 101 | print('Warmup Net1') 102 | warmup(epoch, net1, optimizer1, warmup_trainloader, CEloss, args, conf_penalty) 103 | print('\nWarmup Net2') 104 | warmup(epoch, net2, optimizer2, warmup_trainloader, CEloss, args, conf_penalty, log=record_log) 105 | 106 | else: 107 | prob1 = ncnv(net1, eval_loader, batch_size=args.batch_size, num_class=args.num_class) 108 | pred1 = (prob1 < threshold_sver) 109 | prob2 = ncnv(net2, eval_loader, batch_size=args.batch_size, num_class=args.num_class) 110 | pred2 = (prob2 < threshold_sver) 111 | 112 | print('Train Net1') 113 | labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, 1-prob2) # co-divide 114 | pseudo_labels = nclc(net1, net2, labeled_trainloader, unlabeled_trainloader, test_loader, batch_size=args.batch_size, num_class=args.num_class, threshold_scor=threshold_scor) 115 | train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader, args, pseudo_labels=pseudo_labels, log=record_log) 116 | 117 | print('\nTrain Net2') 118 | labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, 1-prob1) # co-divide 119 | pseudo_labels = nclc(net2, net1, labeled_trainloader, unlabeled_trainloader, test_loader, batch_size=args.batch_size, num_class=args.num_class, threshold_scor=threshold_scor) 120 | train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader, args, pseudo_labels=pseudo_labels, log=record_log) 121 | 122 | # model testing 123 | test(epoch, net1, net2, test_log, test_loader) 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /webvision/main.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | import os 6 | import argparse 7 | from InceptionResNetV2 import * 8 | import dataloader_webvision as dataloader 9 | import datetime 10 | import time 11 | from utils import create_folder_and_save_pyfile 12 | from utils import warmup, train, test 13 | from utils import eval_train_nce 14 | from utils import NegEntropy 15 | 16 | parser = argparse.ArgumentParser(description='PyTorch WebVision Training') 17 | parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 18 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate') 19 | parser.add_argument('--noise_mode', default='natrual') 20 | parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta') 21 | parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature') 22 | parser.add_argument('--num_epochs', default=80, type=int) 23 | parser.add_argument('--warm_up', default=1, type=int) 24 | parser.add_argument('--r', default=0.0, type=float, help='noise ratio') 25 | parser.add_argument('--seed', default=123) 26 | parser.add_argument('--num_class', default=50, type=int) 27 | parser.add_argument('--data_path', default='./data/webvision/', type=str, help='path to dataset') 28 | parser.add_argument('--dataset', default='webvision', type=str) 29 | parser.add_argument('--remark', default='dividemix', type=str) 30 | parser.add_argument('--feat_dim', default=1536, type=int) 31 | parser.add_argument('--num_neighbor', default=20, type=int) 32 | parser.add_argument('--threshold_sver', default=0.90, type=float) 33 | parser.add_argument('--threshold_scor', default=0.05, type=float) 34 | parser.add_argument('--high_scor', default=1.0, type=float) 35 | args = parser.parse_args() 36 | 37 | run_time = 
datetime.datetime.now().strftime('%Y%m%d%H%M%S') 38 | root_folder = create_folder_and_save_pyfile(run_time + "-" + args.remark, args) 39 | log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_records.txt'), 'a+') 40 | record_log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_record.txt'), 'a+') 41 | ils_test_log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_imgnet_acc.txt'), 'a+') 42 | web_test_log = open(os.path.join(root_folder, '%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode) + '_web_acc.txt'), 'a+') 43 | 44 | print('| Building net') 45 | def create_model(): 46 | model = InceptionResNetV2(num_classes=args.num_class) 47 | model = nn.DataParallel(model) 48 | model = model.cuda() 49 | return model 50 | net1 = create_model() 51 | net2 = create_model() 52 | cudnn.benchmark = True 53 | 54 | print('| Building optimizer') 55 | optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 56 | optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 57 | 58 | CEloss = nn.CrossEntropyLoss() 59 | conf_penalty = NegEntropy() 60 | 61 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size, num_workers=5, root_dir=args.data_path, num_class=args.num_class) 62 | web_valloader = loader.run('test') 63 | imagenet_valloader = loader.run('imagenet') 64 | 65 | web_acc1, web_acc2, web_acc3 = test(net1, net2, web_valloader) 66 | web_test_log.write('Epoch:%d \t WebVision Acc-NET1+NET2: %.2f%% (%.2f%%) \tNET1: %.2f%% (%.2f%%) \tNET2: %.2f%% (%.2f%%)\n' % ( 67 | -1, web_acc3[0], web_acc3[1], web_acc1[0], web_acc1[1], web_acc2[0], web_acc2[1])) 68 | web_test_log.flush() 69 | 70 | ign_acc1, ign_acc2, ign_acc3 = test(net1, net2, imagenet_valloader) 71 | ils_test_log.write('Epoch:%d \t ILSVRC12 Acc-NET1+NET2: %.2f%% (%.2f%%) \tNET1: %.2f%% (%.2f%%) \tNET2: %.2f%% (%.2f%%)\n' % ( 72 | -1, 
ign_acc3[0], ign_acc3[1], ign_acc1[0], ign_acc1[1], ign_acc2[0], ign_acc2[1])) 73 | ils_test_log.flush() 74 | 75 | for epoch in range(args.num_epochs + 1): 76 | start_time = time.time() 77 | lr = args.lr 78 | if epoch >= 40 and epoch < 80: 79 | lr /= 10 80 | elif epoch >= 80: 81 | lr /= 10 * 10 82 | 83 | for param_group in optimizer1.param_groups: 84 | param_group['lr'] = lr 85 | for param_group in optimizer2.param_groups: 86 | param_group['lr'] = lr 87 | 88 | warmup_trainloader = loader.run('warmup') 89 | eval_loader = loader.run('eval_train') 90 | 91 | # model training 92 | if epoch < args.warm_up: 93 | print('Warmup Net1') 94 | warmup(epoch, net1, optimizer1, warmup_trainloader, CEloss, log) 95 | print('\nWarmup Net2') 96 | warmup(epoch, net2, optimizer2, warmup_trainloader, CEloss, log) 97 | 98 | else: 99 | pred1 = (prob1 < args.threshold_sver) 100 | pred2 = (prob2 < args.threshold_sver) 101 | log.write('Epoch:%d \tLAB-NET1:%d\tNET2:%d \n' % (epoch, pred1.sum(), pred2.sum())) 102 | log.write('Epoch:%d \tUNL-NET1:%d\tNET2:%d \n' % (epoch, len(pred1) - pred1.sum(), len(pred2) - pred2.sum())) 103 | log.flush() 104 | 105 | print('Train Net1') 106 | labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, 1-prob2, lab2) # co-divide 107 | train(args, epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader, log=record_log) 108 | 109 | print('\nTrain Net2') 110 | labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, 1-prob1, lab1) # co-divide 111 | train(args, epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader, log=record_log) 112 | 113 | web_acc1, web_acc2, web_acc3 = test(net1, net2, web_valloader) 114 | web_test_log.write('Epoch:%d \t WebVision Acc-NET1+NET2: %.2f%% (%.2f%%) \tNET1: %.2f%% (%.2f%%) \tNET2: %.2f%% (%.2f%%)\n' % ( 115 | epoch, web_acc3[0], web_acc3[1], web_acc1[0], web_acc1[1], web_acc2[0], web_acc2[1])) 116 | web_test_log.flush() 117 | 118 | ign_acc1, ign_acc2, ign_acc3 = 
test(net1, net2, imagenet_valloader) 119 | ils_test_log.write('Epoch:%d \t ILSVRC12 Acc-NET1+NET2: %.2f%% (%.2f%%) \tNET1: %.2f%% (%.2f%%) \tNET2: %.2f%% (%.2f%%)\n' % ( 120 | epoch, ign_acc3[0], ign_acc3[1], ign_acc1[0], ign_acc1[1], ign_acc2[0], ign_acc2[1])) 121 | ils_test_log.flush() 122 | 123 | prob1, lab1 = eval_train_nce(args, eval_loader, net1, feat_dim=args.feat_dim, num_class=args.num_class) 124 | prob2, lab2 = eval_train_nce(args, eval_loader, net2, feat_dim=args.feat_dim, num_class=args.num_class) 125 | 126 | 127 | -------------------------------------------------------------------------------- /cifar/PreResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = conv3x3(in_planes, planes, stride) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class PreActBlock(nn.Module): 38 | '''Pre-activation version of the BasicBlock.''' 39 | expansion = 1 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBlock, 
self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = conv3x3(in_planes, planes, stride) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = conv3x3(planes, planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 52 | ) 53 | 54 | def forward(self, x): 55 | out = F.relu(self.bn1(x)) 56 | shortcut = self.shortcut(out) 57 | out = self.conv1(out) 58 | out = self.conv2(F.relu(self.bn2(out))) 59 | out += shortcut 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | expansion = 4 65 | 66 | def __init__(self, in_planes, planes, stride=1): 67 | super(Bottleneck, self).__init__() 68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 74 | 75 | self.shortcut = nn.Sequential() 76 | if stride != 1 or in_planes != self.expansion*planes: 77 | self.shortcut = nn.Sequential( 78 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 79 | nn.BatchNorm2d(self.expansion*planes) 80 | ) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = F.relu(self.bn2(self.conv2(out))) 85 | out = self.bn3(self.conv3(out)) 86 | out += self.shortcut(x) 87 | out = F.relu(out) 88 | return out 89 | 90 | 91 | class PreActBottleneck(nn.Module): 92 | '''Pre-activation version of the original Bottleneck module.''' 93 | expansion = 4 94 | 95 | def __init__(self, in_planes, planes, stride=1): 96 | super(PreActBottleneck, self).__init__() 97 | self.bn1 = nn.BatchNorm2d(in_planes) 98 | self.conv1 = 
nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 99 | self.bn2 = nn.BatchNorm2d(planes) 100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 101 | self.bn3 = nn.BatchNorm2d(planes) 102 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 103 | 104 | self.shortcut = nn.Sequential() 105 | if stride != 1 or in_planes != self.expansion*planes: 106 | self.shortcut = nn.Sequential( 107 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 108 | ) 109 | 110 | def forward(self, x): 111 | out = F.relu(self.bn1(x)) 112 | shortcut = self.shortcut(out) 113 | out = self.conv1(out) 114 | out = self.conv2(F.relu(self.bn2(out))) 115 | out = self.conv3(F.relu(self.bn3(out))) 116 | out += shortcut 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | def __init__(self, block, num_blocks, head_type="mlp", feat_dim=128, num_classes=10, drop=0.0): 122 | super(ResNet, self).__init__() 123 | self.in_planes = 64 124 | self.drop = drop 125 | self.dropout = nn.Dropout(drop) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.conv1 = conv3x3(3,64) 128 | self.bn1 = nn.BatchNorm2d(64) 129 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 130 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 131 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 132 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 133 | self.linear = nn.Linear(512*block.expansion, num_classes) 134 | self.head = nn.Linear(512*block.expansion, 512*block.expansion) 135 | 136 | dim_in = 512*block.expansion 137 | if head_type == 'linear': 138 | self.conhead = nn.Linear(dim_in, feat_dim) 139 | elif head_type == 'mlp': 140 | self.conhead = nn.Sequential( 141 | nn.Linear(dim_in, dim_in), 142 | nn.ReLU(inplace=True), 143 | nn.Linear(dim_in, feat_dim) 144 | ) 145 | else: 146 | raise NotImplementedError( 147 | 'head not supported: 
{}'.format(head_type)) 148 | 149 | def _make_layer(self, block, planes, num_blocks, stride): 150 | strides = [stride] + [1]*(num_blocks-1) 151 | layers = [] 152 | for stride in strides: 153 | layers.append(block(self.in_planes, planes, stride)) 154 | self.in_planes = planes * block.expansion 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x, lin=0, lout=5, feat=False, confeat=False): 158 | out = x 159 | if lin < 1 and lout > -1: 160 | out = self.conv1(out) 161 | out = self.bn1(out) 162 | out = F.relu(out) 163 | if lin < 2 and lout > 0: 164 | out = self.layer1(out) 165 | if self.drop > 0.0: 166 | out = self.dropout(out) 167 | if lin < 3 and lout > 1: 168 | out = self.layer2(out) 169 | if self.drop > 0.0: 170 | out = self.dropout(out) 171 | if lin < 4 and lout > 2: 172 | out = self.layer3(out) 173 | if self.drop > 0.0: 174 | out = self.dropout(out) 175 | if lin < 5 and lout > 3: 176 | out = self.layer4(out) 177 | if self.drop > 0.0: 178 | out = self.dropout(out) 179 | if (not feat) and (not confeat): 180 | if lout > 4: 181 | out = F.avg_pool2d(out, 4) 182 | out = out.view(out.size(0), -1) 183 | if self.drop > 0.0: 184 | out = self.dropout(out) 185 | print("Dropout 1", self.drop) 186 | out = self.head(out) 187 | if self.drop > 0.0: 188 | out = self.relu(out) 189 | out = self.dropout(out) 190 | print("Dropout 2", self.drop) 191 | out = self.linear(out) 192 | return out 193 | elif feat and (not confeat): 194 | if lout > 4: 195 | out = F.avg_pool2d(out, 4) 196 | out = out.view(out.size(0), -1) 197 | out = self.head(out) 198 | out_feat = out 199 | out = self.linear(out) 200 | return out, out_feat 201 | elif (not feat) and confeat: 202 | if lout > 4: 203 | out = F.avg_pool2d(out, 4) 204 | out = out.view(out.size(0), -1) 205 | out = self.head(out) 206 | con_feat = self.conhead(out) 207 | out = self.linear(out) 208 | return out, con_feat 209 | else: 210 | if lout > 4: 211 | out = F.avg_pool2d(out, 4) 212 | out = out.view(out.size(0), -1) 213 | out = 
self.head(out) 214 | out_feat = out 215 | con_feat = self.conhead(out) 216 | out = self.linear(out) 217 | return out, out_feat, con_feat 218 | 219 | 220 | 221 | def ResNet18(num_classes=10, drop=0.0): 222 | return ResNet(PreActBlock, [2,2,2,2], num_classes=num_classes, drop=drop) 223 | 224 | def ResNet34(num_classes=10): 225 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) 226 | 227 | def ResNet50(num_classes=10): 228 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes) 229 | 230 | def ResNet101(num_classes=10): 231 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes) 232 | 233 | def ResNet152(num_classes=10): 234 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes) 235 | 236 | 237 | def test(): 238 | net = ResNet18() 239 | y = net(Variable(torch.randn(1,3,32,32))) 240 | print(y.size()) 241 | -------------------------------------------------------------------------------- /webvision/dataloader_webvision.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import os 5 | from autoaugment import ImageNetPolicy 6 | 7 | class imagenet_dataset(Dataset): 8 | def __init__(self, root_dir, transform, num_class): 9 | self.root = root_dir 10 | self.transform = transform 11 | self.val_data = [] 12 | folder = 'val' 13 | with open(os.path.join(root_dir, 'info/synsets.txt')) as f: 14 | lines = f.readlines() 15 | synsets = [x.split()[0] for x in lines] 16 | for c in range(num_class): 17 | class_path = os.path.join(self.root, folder, synsets[c]) 18 | imgs = os.listdir(class_path) 19 | for img in imgs: 20 | self.val_data.append([c, os.path.join(class_path, img)]) 21 | 22 | def __getitem__(self, index): 23 | data = self.val_data[index] 24 | target = data[0] 25 | image = Image.open(data[1]).convert('RGB') 26 | img = self.transform(image) 27 | return img, target 28 | 29 | def 
__len__(self): 30 | return len(self.val_data) 31 | 32 | class webvision_dataset(Dataset): 33 | def __init__(self, root_dir, transform, mode, num_class, pred=[], probability=[], label=[], transform_strong=None): 34 | self.root = root_dir 35 | self.transform = transform 36 | self.transform_strong = transform_strong 37 | self.mode = mode 38 | 39 | if self.mode=='test': 40 | with open(self.root+'info/val_filelist.txt') as f: 41 | lines=f.readlines() 42 | self.val_imgs = [] 43 | self.val_labels = {} 44 | for line in lines: 45 | img, target = line.split() 46 | target = int(target) 47 | if target>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | 13 | Example as a PyTorch Transform: 14 | >>> transform=transforms.Compose([ 15 | >>> transforms.Resize(256), 16 | >>> ImageNetPolicy(), 17 | >>> transforms.ToTensor()]) 18 | """ 19 | 20 | def __init__(self, fillcolor=(128, 128, 128)): 21 | self.policies = [ 22 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 23 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 24 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 25 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 26 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 27 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 28 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 29 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 30 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 31 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 32 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 33 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 34 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 35 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 37 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 38 | 
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 39 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 40 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 41 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 42 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 43 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 44 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 45 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 46 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 47 | ] 48 | 49 | def __call__(self, img): 50 | policy_idx = random.randint(0, len(self.policies) - 1) 51 | return self.policies[policy_idx](img) 52 | 53 | def __repr__(self): 54 | return "AutoAugment ImageNet Policy" 55 | 56 | 57 | class CIFAR10Policy(object): 58 | """Randomly choose one of the best 25 Sub-policies on CIFAR10. 59 | 60 | Example: 61 | >>> policy = CIFAR10Policy() 62 | >>> transformed = policy(image) 63 | 64 | Example as a PyTorch Transform: 65 | >>> transform=transforms.Compose([ 66 | >>> transforms.Resize(256), 67 | >>> CIFAR10Policy(), 68 | >>> transforms.ToTensor()]) 69 | """ 70 | 71 | def __init__(self, fillcolor=(128, 128, 128)): 72 | self.policies = [ 73 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 74 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 75 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 76 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 77 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 78 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 79 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 80 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 81 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 82 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 83 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, 
fillcolor), 84 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 85 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 86 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 87 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 88 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 89 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 90 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 91 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 92 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 93 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 94 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 95 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 96 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 97 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor), 98 | ] 99 | 100 | def __call__(self, img): 101 | policy_idx = random.randint(0, len(self.policies) - 1) 102 | return self.policies[policy_idx](img) 103 | 104 | def __repr__(self): 105 | return "AutoAugment CIFAR10 Policy" 106 | 107 | 108 | class SVHNPolicy(object): 109 | """Randomly choose one of the best 25 Sub-policies on SVHN. 
110 | 111 | Example: 112 | >>> policy = SVHNPolicy() 113 | >>> transformed = policy(image) 114 | 115 | Example as a PyTorch Transform: 116 | >>> transform=transforms.Compose([ 117 | >>> transforms.Resize(256), 118 | >>> SVHNPolicy(), 119 | >>> transforms.ToTensor()]) 120 | """ 121 | 122 | def __init__(self, fillcolor=(128, 128, 128)): 123 | self.policies = [ 124 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 125 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 126 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 127 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 128 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 129 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 130 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 131 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 132 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 133 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 134 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 135 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 136 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 137 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 138 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 139 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 140 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 141 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 142 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 143 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 144 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 145 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 146 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 147 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 148 | SubPolicy(0.7, "shearX", 2, 0.1, 
"invert", 5, fillcolor), 149 | ] 150 | 151 | def __call__(self, img): 152 | policy_idx = random.randint(0, len(self.policies) - 1) 153 | return self.policies[policy_idx](img) 154 | 155 | def __repr__(self): 156 | return "AutoAugment SVHN Policy" 157 | 158 | 159 | class SubPolicy(object): 160 | def __init__( 161 | self, 162 | p1, 163 | operation1, 164 | magnitude_idx1, 165 | p2, 166 | operation2, 167 | magnitude_idx2, 168 | fillcolor=(128, 128, 128), 169 | ): 170 | ranges = { 171 | "shearX": np.linspace(0, 0.3, 10), 172 | "shearY": np.linspace(0, 0.3, 10), 173 | "translateX": np.linspace(0, 150 / 331, 10), 174 | "translateY": np.linspace(0, 150 / 331, 10), 175 | "rotate": np.linspace(0, 30, 10), 176 | "color": np.linspace(0.0, 0.9, 10), 177 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 178 | "solarize": np.linspace(256, 0, 10), 179 | "contrast": np.linspace(0.0, 0.9, 10), 180 | "sharpness": np.linspace(0.0, 0.9, 10), 181 | "brightness": np.linspace(0.0, 0.9, 10), 182 | "autocontrast": [0] * 10, 183 | "equalize": [0] * 10, 184 | "invert": [0] * 10, 185 | } 186 | 187 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 188 | def rotate_with_fill(img, magnitude): 189 | rot = img.convert("RGBA").rotate(magnitude) 190 | return Image.composite( 191 | rot, Image.new("RGBA", rot.size, (128,) * 4), rot 192 | ).convert(img.mode) 193 | 194 | func = { 195 | "shearX": lambda img, magnitude: img.transform( 196 | img.size, 197 | Image.AFFINE, 198 | (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 199 | Image.BICUBIC, 200 | fillcolor=fillcolor, 201 | ), 202 | "shearY": lambda img, magnitude: img.transform( 203 | img.size, 204 | Image.AFFINE, 205 | (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 206 | Image.BICUBIC, 207 | fillcolor=fillcolor, 208 | ), 209 | "translateX": lambda img, magnitude: img.transform( 210 | img.size, 211 | Image.AFFINE, 212 | (1, 0, magnitude * 
img.size[0] * random.choice([-1, 1]), 0, 1, 0), 213 | fillcolor=fillcolor, 214 | ), 215 | "translateY": lambda img, magnitude: img.transform( 216 | img.size, 217 | Image.AFFINE, 218 | (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 219 | fillcolor=fillcolor, 220 | ), 221 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 222 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( 223 | 1 + magnitude * random.choice([-1, 1]) 224 | ), 225 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 226 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 227 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 228 | 1 + magnitude * random.choice([-1, 1]) 229 | ), 230 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 231 | 1 + magnitude * random.choice([-1, 1]) 232 | ), 233 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 234 | 1 + magnitude * random.choice([-1, 1]) 235 | ), 236 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 237 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 238 | "invert": lambda img, magnitude: ImageOps.invert(img), 239 | } 240 | 241 | self.p1 = p1 242 | self.operation1 = func[operation1] 243 | self.magnitude1 = ranges[operation1][magnitude_idx1] 244 | self.p2 = p2 245 | self.operation2 = func[operation2] 246 | self.magnitude2 = ranges[operation2][magnitude_idx2] 247 | 248 | def __call__(self, img): 249 | if random.random() < self.p1: 250 | img = self.operation1(img, self.magnitude1) 251 | if random.random() < self.p2: 252 | img = self.operation2(img, self.magnitude2) 253 | return img 254 | -------------------------------------------------------------------------------- /webvision/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import 
random 4 | 5 | 6 | class ImageNetPolicy(object): 7 | """Randomly choose one of the best 24 Sub-policies on ImageNet. 8 | 9 | Example: 10 | >>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | 13 | Example as a PyTorch Transform: 14 | >>> transform=transforms.Compose([ 15 | >>> transforms.Resize(256), 16 | >>> ImageNetPolicy(), 17 | >>> transforms.ToTensor()]) 18 | """ 19 | 20 | def __init__(self, fillcolor=(128, 128, 128)): 21 | self.policies = [ 22 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 23 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 24 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 25 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 26 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 27 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 28 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 29 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 30 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 31 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 32 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 33 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 34 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 35 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 37 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 38 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 39 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 40 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 41 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 42 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 43 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 44 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 45 | 
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 46 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 47 | ] 48 | 49 | def __call__(self, img): 50 | policy_idx = random.randint(0, len(self.policies) - 1) 51 | return self.policies[policy_idx](img) 52 | 53 | def __repr__(self): 54 | return "AutoAugment ImageNet Policy" 55 | 56 | 57 | class CIFAR10Policy(object): 58 | """Randomly choose one of the best 25 Sub-policies on CIFAR10. 59 | 60 | Example: 61 | >>> policy = CIFAR10Policy() 62 | >>> transformed = policy(image) 63 | 64 | Example as a PyTorch Transform: 65 | >>> transform=transforms.Compose([ 66 | >>> transforms.Resize(256), 67 | >>> CIFAR10Policy(), 68 | >>> transforms.ToTensor()]) 69 | """ 70 | 71 | def __init__(self, fillcolor=(128, 128, 128)): 72 | self.policies = [ 73 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 74 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 75 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 76 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 77 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 78 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 79 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 80 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 81 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 82 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 83 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 84 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 85 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 86 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 87 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 88 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 89 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 90 | SubPolicy(0.9, "color", 9, 0.6, 
"equalize", 6, fillcolor), 91 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 92 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 93 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 94 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 95 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 96 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 97 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor), 98 | ] 99 | 100 | def __call__(self, img): 101 | policy_idx = random.randint(0, len(self.policies) - 1) 102 | return self.policies[policy_idx](img) 103 | 104 | def __repr__(self): 105 | return "AutoAugment CIFAR10 Policy" 106 | 107 | 108 | class SVHNPolicy(object): 109 | """Randomly choose one of the best 25 Sub-policies on SVHN. 110 | 111 | Example: 112 | >>> policy = SVHNPolicy() 113 | >>> transformed = policy(image) 114 | 115 | Example as a PyTorch Transform: 116 | >>> transform=transforms.Compose([ 117 | >>> transforms.Resize(256), 118 | >>> SVHNPolicy(), 119 | >>> transforms.ToTensor()]) 120 | """ 121 | 122 | def __init__(self, fillcolor=(128, 128, 128)): 123 | self.policies = [ 124 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 125 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 126 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 127 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 128 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 129 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 130 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 131 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 132 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 133 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 134 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 135 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, 
fillcolor), 136 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 137 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 138 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 139 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 140 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 141 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 142 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 143 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 144 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 145 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 146 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 147 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 148 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor), 149 | ] 150 | 151 | def __call__(self, img): 152 | policy_idx = random.randint(0, len(self.policies) - 1) 153 | return self.policies[policy_idx](img) 154 | 155 | def __repr__(self): 156 | return "AutoAugment SVHN Policy" 157 | 158 | 159 | class SubPolicy(object): 160 | def __init__( 161 | self, 162 | p1, 163 | operation1, 164 | magnitude_idx1, 165 | p2, 166 | operation2, 167 | magnitude_idx2, 168 | fillcolor=(128, 128, 128), 169 | ): 170 | ranges = { 171 | "shearX": np.linspace(0, 0.3, 10), 172 | "shearY": np.linspace(0, 0.3, 10), 173 | "translateX": np.linspace(0, 150 / 331, 10), 174 | "translateY": np.linspace(0, 150 / 331, 10), 175 | "rotate": np.linspace(0, 30, 10), 176 | "color": np.linspace(0.0, 0.9, 10), 177 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 178 | "solarize": np.linspace(256, 0, 10), 179 | "contrast": np.linspace(0.0, 0.9, 10), 180 | "sharpness": np.linspace(0.0, 0.9, 10), 181 | "brightness": np.linspace(0.0, 0.9, 10), 182 | "autocontrast": [0] * 10, 183 | "equalize": [0] * 10, 184 | "invert": [0] * 10, 185 | } 186 | 187 | # from 
https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 188 | def rotate_with_fill(img, magnitude): 189 | rot = img.convert("RGBA").rotate(magnitude) 190 | return Image.composite( 191 | rot, Image.new("RGBA", rot.size, (128,) * 4), rot 192 | ).convert(img.mode) 193 | 194 | func = { 195 | "shearX": lambda img, magnitude: img.transform( 196 | img.size, 197 | Image.AFFINE, 198 | (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 199 | Image.BICUBIC, 200 | fillcolor=fillcolor, 201 | ), 202 | "shearY": lambda img, magnitude: img.transform( 203 | img.size, 204 | Image.AFFINE, 205 | (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 206 | Image.BICUBIC, 207 | fillcolor=fillcolor, 208 | ), 209 | "translateX": lambda img, magnitude: img.transform( 210 | img.size, 211 | Image.AFFINE, 212 | (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 213 | fillcolor=fillcolor, 214 | ), 215 | "translateY": lambda img, magnitude: img.transform( 216 | img.size, 217 | Image.AFFINE, 218 | (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 219 | fillcolor=fillcolor, 220 | ), 221 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 222 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( 223 | 1 + magnitude * random.choice([-1, 1]) 224 | ), 225 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 226 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 227 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 228 | 1 + magnitude * random.choice([-1, 1]) 229 | ), 230 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 231 | 1 + magnitude * random.choice([-1, 1]) 232 | ), 233 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 234 | 1 + magnitude * random.choice([-1, 1]) 235 | ), 236 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 237 | 
"equalize": lambda img, magnitude: ImageOps.equalize(img), 238 | "invert": lambda img, magnitude: ImageOps.invert(img), 239 | } 240 | 241 | self.p1 = p1 242 | self.operation1 = func[operation1] 243 | self.magnitude1 = ranges[operation1][magnitude_idx1] 244 | self.p2 = p2 245 | self.operation2 = func[operation2] 246 | self.magnitude2 = ranges[operation2][magnitude_idx2] 247 | 248 | def __call__(self, img): 249 | if random.random() < self.p1: 250 | img = self.operation1(img, self.magnitude1) 251 | if random.random() < self.p2: 252 | img = self.operation2(img, self.magnitude2) 253 | return img 254 | -------------------------------------------------------------------------------- /webvision/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.nn.functional import normalize 6 | import math 7 | import glob, os, shutil 8 | import torchnet 9 | 10 | def export_pyfile(target_dir): 11 | if not os.path.exists(target_dir): 12 | os.makedirs(target_dir) 13 | for ext in ('py', 'pyproj', 'sln'): 14 | for fn in glob.glob('*.' + ext): 15 | shutil.copy2(fn, target_dir) 16 | if os.path.isdir('src'): 17 | for fn in glob.glob(os.path.join('src', '*.' 
+ ext)): 18 | shutil.copy2(fn, target_dir) 19 | 20 | def create_folder_and_save_pyfile(nowtime, args): 21 | root_folder = os.path.join("./", "outputs", nowtime+"-"+'%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode)) 22 | if not os.path.exists(root_folder): 23 | os.makedirs(root_folder) 24 | # saving pyfiles 25 | folder = os.path.join(root_folder, "folder_for_pyfiles") 26 | if not os.path.exists(folder): 27 | os.makedirs(folder) 28 | export_pyfile(folder) 29 | return root_folder 30 | 31 | class NegEntropy(object): 32 | def __call__(self, outputs): 33 | probs = F.softmax(outputs, dim=1) 34 | return torch.mean(torch.sum(probs.log() * probs, dim=1)) 35 | 36 | def kl_div(p, q): 37 | # p, q is in shape (batch_size, n_classes) 38 | return (p * p.log2() - p * q.log2()).sum(dim=1) 39 | 40 | def js_div(p, q): 41 | # Jensen-Shannon divergence, value is in range (0, 1) 42 | m = 0.5 * (p + q) 43 | return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m) 44 | 45 | def mixup(inputs, targets, alpha): 46 | l = np.random.beta(alpha, alpha) 47 | l = max(l, 1 - l) 48 | idx = torch.randperm(inputs.size(0)) 49 | input_a, input_b = inputs, inputs[idx] 50 | target_a, target_b = targets, targets[idx] 51 | mixed_input = l * input_a + (1 - l) * input_b 52 | mixed_target = l * target_a + (1 - l) * target_b 53 | return mixed_input, mixed_target 54 | 55 | def eval_train_nce(args, eval_loader, net, feat_dim, num_class): 56 | net.eval() 57 | ## Get train features 58 | trainFeatures = torch.rand(len(eval_loader.dataset), feat_dim).t().cuda() 59 | trainLogits = torch.rand(len(eval_loader.dataset), num_class).t().cuda() 60 | trainNoisyLabels = torch.rand(len(eval_loader.dataset)).cuda() 61 | 62 | iter_count = 0 63 | for batch_idx, (inputs, labels, index) in enumerate(eval_loader): 64 | batchSize = inputs.size(0) 65 | logits, features = net(inputs.cuda(), feat=True) 66 | trainFeatures[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = features.cuda().data.t() 67 | trainLogits[:, batch_idx * 
batchSize:batch_idx * batchSize + batchSize] = logits.cuda().data.t() 68 | trainNoisyLabels[batch_idx * batchSize:batch_idx * batchSize + batchSize] = labels.cuda().data 69 | iter_count += 1 70 | 71 | trainFeatures = normalize(trainFeatures.t()) 72 | trainLogits = trainLogits.t() 73 | trainNoisyLabels = trainNoisyLabels 74 | 75 | # caculating neighborhood-based label inconsistency score 76 | num_batch = math.ceil(float(trainFeatures.size(0)) / args.batch_size) # 391 77 | sver_collection = [] 78 | for batch_idx in range(num_batch): 79 | features = trainFeatures[batch_idx * args.batch_size:batch_idx * args.batch_size + args.batch_size] 80 | noisy_labels = trainNoisyLabels[batch_idx * args.batch_size:batch_idx * args.batch_size + args.batch_size] 81 | dist = torch.mm(features, trainFeatures.t()) 82 | dist[torch.arange(dist.size()[0]), torch.arange(dist.size()[0])] = -1 # set self-contrastive samples to -1 83 | _, neighbors = dist.topk(args.num_neighbor, dim=1, largest=True, sorted=True) # find contrastive neighbors 84 | neighbors = neighbors.view(-1) 85 | neigh_logits = trainLogits[neighbors] 86 | neigh_probs = F.softmax(neigh_logits, dim=-1) 87 | M, _ = features.shape 88 | given_labels = torch.full(size=(M, num_class), fill_value=0.0001).cuda() 89 | given_labels.scatter_(dim=1, index=torch.unsqueeze(noisy_labels.long(), dim=1), value=1 - 0.0001) 90 | given_labels = given_labels.repeat(1, args.num_neighbor).view(-1, num_class) 91 | sver = js_div(neigh_probs, given_labels) 92 | sver_collection += sver.view(-1, args.num_neighbor).mean(dim=1).cpu().numpy().tolist() 93 | prob = np.array(sver_collection) 94 | # prob = 1.0 - np.array(sver_collection) 95 | mask_lab = prob < args.threshold_sver 96 | mask_unl = prob > args.threshold_sver 97 | 98 | labeledFeatures = trainFeatures[mask_lab] 99 | labeledLogits = trainLogits[mask_lab] 100 | labeledNoisyLabels = trainNoisyLabels[mask_lab] 101 | labeledW = prob[mask_lab] 102 | 103 | knn_labeledLogits = labeledLogits[labeledW > 0.95] 
104 | knn_labeledFeatures = labeledFeatures[labeledW > 0.95] 105 | knn_labeledNoisyLabels = labeledNoisyLabels[labeledW > 0.95] 106 | 107 | unlabeledFeatures = trainFeatures[mask_unl] 108 | unlabeledLogits = trainLogits[mask_unl] 109 | 110 | num_labeled = knn_labeledFeatures.size(0) 111 | num_unlabeled = unlabeledFeatures.size(0) 112 | if num_labeled <= args.num_neighbor * num_class: 113 | pseudo_labels = [-3] * num_unlabeled 114 | pseudo_labels = np.array(pseudo_labels) 115 | print("num_labeled <= args.num_neighbor * 10 ...") 116 | return prob, noisy_labels 117 | 118 | # caculating pseudo-labels for unlabeled samples 119 | num_batch_unlabeled = math.ceil(float(unlabeledFeatures.size(0)) / args.batch_size) 120 | pseudo_labels = [] 121 | scor_collection = [] 122 | for batch_idx in range(num_batch_unlabeled): 123 | features = unlabeledFeatures[batch_idx * args.batch_size:batch_idx * args.batch_size + args.batch_size] 124 | logits = unlabeledLogits[batch_idx * args.batch_size:batch_idx * args.batch_size + args.batch_size] 125 | dist = torch.mm(features, knn_labeledFeatures.t()) 126 | _, neighbors = dist.topk(args.num_neighbor, dim=1, largest=True, sorted=True) # find contrastive neighbors 127 | neighbors = neighbors.view(-1) 128 | neighs_labels = knn_labeledNoisyLabels[neighbors] 129 | neighs_logits = knn_labeledLogits[neighbors] 130 | neigh_probs = F.softmax(neighs_logits, dim=-1) 131 | neighbor_labels = torch.full(size=neigh_probs.size(), fill_value=0.0001).cuda() 132 | neighbor_labels.scatter_(dim=1, index=torch.unsqueeze(neighs_labels.long(), dim=1), value=1 - 0.0001) 133 | scor = js_div(F.softmax(logits.repeat(1, args.num_neighbor).view(-1, num_class), dim=-1), neighbor_labels) 134 | w = (1 - scor).type(torch.FloatTensor) 135 | w = w.view(-1, 1).type(torch.FloatTensor).cuda() 136 | neighbor_labels = (neighbor_labels * w).view(-1, args.num_neighbor, num_class).sum(dim=1) 137 | pseudo_labels += neighbor_labels.cpu().numpy().tolist() 138 | scor = scor.view(-1, 
args.num_neighbor).mean(dim=1) 139 | scor_collection += scor.cpu().numpy().tolist() 140 | scor_collection = np.array(scor_collection) 141 | 142 | pseudo_labels = np.argmax(np.array(pseudo_labels), axis=1) 143 | pseudo_labels[np.equal(scor_collection > args.threshold_scor, scor_collection <= args.high_scor)] = -1 144 | pseudo_labels[scor_collection > args.high_scor] = -2 145 | 146 | noisy_labels = trainNoisyLabels.cpu().numpy() 147 | noisy_labels[mask_unl>0] = pseudo_labels 148 | 149 | return prob, noisy_labels 150 | 151 | def warmup(epoch, net, optimizer, dataloader, CEloss, log): 152 | net.train() 153 | num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1 154 | for batch_idx, (inputs, labels, path) in enumerate(dataloader): 155 | inputs, labels = inputs.cuda(), labels.cuda() 156 | optimizer.zero_grad() 157 | outputs = net(inputs) 158 | loss = CEloss(outputs, labels) 159 | 160 | # penalty = conf_penalty(outputs) 161 | L = loss # + penalty 162 | L.backward() 163 | optimizer.step() 164 | 165 | log.write('\r') 166 | log.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f' % (args.id, epoch, args.num_epochs, batch_idx + 1, num_iter, loss.item())) 167 | log.flush() 168 | 169 | # Training 170 | def train(args, epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader, log='record.txt'): 171 | net.train() 172 | net2.eval() # fix one network and train the other 173 | 174 | unlabeled_train_iter = iter(unlabeled_trainloader) 175 | num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1 176 | for batch_idx, (inputs_x, inputs_x2, _, labels_x, w_x, _) in enumerate(labeled_trainloader): 177 | try: 178 | inputs_u, inputs_u2, inputs_us, _, _, _ = unlabeled_train_iter.next() 179 | except: 180 | unlabeled_train_iter = iter(unlabeled_trainloader) 181 | inputs_u, inputs_u2, inputs_us, _, _, _ = unlabeled_train_iter.next() 182 | batch_size = inputs_x.size(0) 183 | 184 | # transforming given label to one-hot vector for labeled samples 185 | 
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.long().view(-1, 1), 1) 186 | w_x = w_x.view(-1, 1).type(torch.FloatTensor) 187 | 188 | inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda() 189 | inputs_u, inputs_u2, inputs_us = inputs_u.cuda(), inputs_u2.cuda(), inputs_us.cuda() 190 | 191 | # label refinement (refer to DivideMix) 192 | with torch.no_grad(): 193 | outputs_x = net(inputs_x) 194 | outputs_x2 = net(inputs_x2) 195 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 196 | px = w_x * labels_x + (1 - w_x) * px 197 | ptx = px ** (1 / args.T) # temparature sharpening 198 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 199 | targets_x = targets_x.detach() 200 | 201 | # mixmatch 202 | l = np.random.beta(args.alpha, args.alpha) 203 | l = max(l, 1 - l) 204 | all_inputs = torch.cat([inputs_x, inputs_x2], dim=0) 205 | all_targets = torch.cat([targets_x, targets_x], dim=0) 206 | idx = torch.randperm(all_inputs.size(0)) 207 | input_a, input_b = all_inputs, all_inputs[idx] 208 | target_a, target_b = all_targets, all_targets[idx] 209 | mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2] 210 | mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2] 211 | mixed_logits = net(mixed_input) 212 | 213 | # mixup regularization for labeled data 214 | Lx = -torch.mean(torch.sum(F.log_softmax(mixed_logits, dim=1) * mixed_target, dim=1)) 215 | 216 | # penalty regularization for mixed labeled data 217 | prior = torch.ones(args.num_class) / args.num_class 218 | prior = prior.cuda() 219 | pred_mean = torch.softmax(mixed_logits, dim=1).mean(0) 220 | penalty = torch.sum(prior * torch.log(prior / pred_mean)) 221 | 222 | # overall loss 223 | loss = Lx + penalty 224 | 225 | optimizer.zero_grad() 226 | loss.backward() 227 | optimizer.step() 228 | 229 | log.write('\r') 230 | log.write('Webvision | Epoch [%3d/%3d] 
Iter[%3d/%3d]\t Lx loss: %.4f, Lpen: %.4f'% (epoch, args.num_epochs, batch_idx + 1, num_iter, 231 | Lx.item(), penalty.item())) 232 | log.flush() 233 | 234 | def test(net1, net2, test_loader): 235 | net1.eval() 236 | net2.eval() 237 | 238 | correct = 0 239 | total = 0 240 | 241 | with torch.no_grad(): 242 | for batch_idx, (inputs, targets) in enumerate(test_loader): 243 | inputs, targets = inputs.cuda(), targets.cuda() 244 | outputs1 = net1(inputs) 245 | outputs2 = net2(inputs) 246 | outputs = outputs1 + outputs2 247 | 248 | _, predicted = torch.max(outputs, 1) 249 | _, predicted1 = torch.max(outputs1, 1) 250 | _, predicted2 = torch.max(outputs2, 1) 251 | 252 | correct += predicted.eq(targets).cpu().sum().item() 253 | correct1 += predicted1.eq(targets).cpu().sum().item() 254 | correct2 += predicted2.eq(targets).cpu().sum().item() 255 | 256 | total += targets.size(0) 257 | 258 | acc = 100. * correct / total 259 | acc1 = 100. * correct1 / total 260 | acc2 = 100. * correct2 / total 261 | 262 | return acc, acc1, acc2 263 | -------------------------------------------------------------------------------- /cifar/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.nn.functional import normalize 6 | import math 7 | import glob, os, shutil 8 | 9 | def export_pyfile(target_dir): 10 | if not os.path.exists(target_dir): 11 | os.makedirs(target_dir) 12 | for ext in ('py', 'pyproj', 'sln'): 13 | for fn in glob.glob('*.' + ext): 14 | shutil.copy2(fn, target_dir) 15 | if os.path.isdir('src'): 16 | for fn in glob.glob(os.path.join('src', '*.' 
+ ext)): 17 | shutil.copy2(fn, target_dir) 18 | 19 | def create_folder_and_save_pyfile(nowtime, args): 20 | root_folder = os.path.join("./", "outputs", nowtime+"-"+'%s_%.2f_%s' % (args.dataset, args.r, args.noise_mode)) 21 | if not os.path.exists(root_folder): 22 | os.makedirs(root_folder) 23 | # saving pyfiles 24 | folder = os.path.join(root_folder, "folder_for_pyfiles") 25 | if not os.path.exists(folder): 26 | os.makedirs(folder) 27 | export_pyfile(folder) 28 | return root_folder 29 | 30 | class NegEntropy(object): 31 | def __call__(self, outputs): 32 | probs = F.softmax(outputs, dim=1) 33 | return torch.mean(torch.sum(probs.log() * probs, dim=1)) 34 | 35 | def kl_div(p, q): 36 | # p, q is in shape (batch_size, n_classes) 37 | return (p * p.log2() - p * q.log2()).sum(dim=1) 38 | 39 | def js_div(p, q): 40 | # Jensen-Shannon divergence, value is in range (0, 1) 41 | m = 0.5 * (p + q) 42 | return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m) 43 | 44 | def mixup(inputs, targets, alpha): 45 | l = np.random.beta(alpha, alpha) 46 | l = max(l, 1 - l) 47 | idx = torch.randperm(inputs.size(0)) 48 | input_a, input_b = inputs, inputs[idx] 49 | target_a, target_b = targets, targets[idx] 50 | mixed_input = l * input_a + (1 - l) * input_b 51 | mixed_target = l * target_a + (1 - l) * target_b 52 | return mixed_input, mixed_target 53 | 54 | def getFeature(net, net2, trainloader, testloader, feat_dim, num_class): 55 | transform_bak = trainloader.dataset.transform 56 | trainloader.dataset.transform = testloader.dataset.transform 57 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=8) 58 | 59 | trainFeatures = torch.rand(len(trainloader.dataset), feat_dim).t().cuda() 60 | trainLogits = torch.rand(len(trainloader.dataset), num_class).t().cuda() 61 | trainW = torch.rand(len(trainloader.dataset)).cuda() 62 | trainNoisyLabels = torch.rand(len(trainloader.dataset)).cuda() 63 | 64 | for batch_idx, (inputs, _, _, labels, _, w, _) in 
enumerate(temploader): 65 | batchSize = inputs.size(0) 66 | logits, features = net(inputs.cuda(), feat=True) 67 | logits2, features2 = net2(inputs.cuda(), feat=True) 68 | 69 | trainFeatures[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = (features+features2).data.t() 70 | trainLogits[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = (logits+logits2).data.t() 71 | trainNoisyLabels[batch_idx * batchSize:batch_idx * batchSize + batchSize] = labels.cuda().data 72 | trainW[batch_idx * batchSize:batch_idx * batchSize + batchSize] = w.data 73 | 74 | trainFeatures = trainFeatures.detach().cpu().numpy() 75 | trainLogits = trainLogits.detach().cpu().numpy() 76 | trainNoisyLabels = trainNoisyLabels.detach().cpu().numpy() 77 | trainW = trainW.detach().cpu().numpy() 78 | 79 | trainloader.dataset.transform = transform_bak 80 | return (trainFeatures, trainLogits, trainNoisyLabels, trainW) 81 | 82 | ## function for Neighborhood Collective Noise Verification (NCNV step) 83 | def ncnv(net, eval_loader, num_class, batch_size, feat_dim=512, num_neighbor=20): 84 | net.eval() 85 | 86 | # loading given samples 87 | trainFeatures = torch.rand(len(eval_loader.dataset), feat_dim).t().cuda() 88 | trainLogits = torch.rand(len(eval_loader.dataset), num_class).t().cuda() 89 | trainNoisyLabels = torch.rand(len(eval_loader.dataset)).cuda() 90 | for batch_idx, (inputs, labels, _, _) in enumerate(eval_loader): 91 | batchSize = inputs.size(0) 92 | logits, features = net(inputs.cuda(), feat=True) 93 | trainFeatures[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = features.data.t() 94 | trainLogits[:, batch_idx * batchSize:batch_idx * batchSize + batchSize] = logits.data.t() 95 | trainNoisyLabels[batch_idx * batchSize:batch_idx * batchSize + batchSize] = labels.cuda().data 96 | 97 | trainFeatures = normalize(trainFeatures.t()) 98 | trainLogits = trainLogits.t() 99 | trainNoisyLabels = trainNoisyLabels 100 | 101 | # caculating neighborhood-based label 
inconsistency score 102 | num_batch = math.ceil(float(trainFeatures.size(0)) / batch_size) # 391 103 | sver_collection = [] 104 | for batch_idx in range(num_batch): 105 | features = trainFeatures[batch_idx * batch_size:batch_idx * batch_size + batch_size] 106 | noisy_labels = trainNoisyLabels[batch_idx * batch_size:batch_idx * batch_size + batch_size] 107 | dist = torch.mm(features, trainFeatures.t()) 108 | dist[torch.arange(dist.size()[0]), torch.arange(dist.size()[0])] = -1 # set self-contrastive samples to -1 109 | _, neighbors = dist.topk(num_neighbor, dim=1, largest=True, sorted=True) # find contrastive neighbors 110 | neighbors = neighbors.view(-1) 111 | neigh_logits = trainLogits[neighbors] 112 | neigh_probs = F.softmax(neigh_logits, dim=-1) 113 | M, _ = features.shape 114 | given_labels = torch.full(size=(M, num_class), fill_value=0.0001).cuda() 115 | given_labels.scatter_(dim=1, index=torch.unsqueeze(noisy_labels.long(), dim=1), value=1 - 0.0001) 116 | given_labels = given_labels.repeat(1, num_neighbor).view(-1, num_class) 117 | sver = js_div(neigh_probs, given_labels) 118 | sver_collection += sver.view(-1, num_neighbor).mean(dim=1).cpu().numpy().tolist() 119 | sver_collection = np.array(sver_collection) 120 | return sver_collection 121 | 122 | ## function for Neighborhood Collective Label Correction (NCLC step) 123 | def nclc(net, net2, labeled_trainloader, unlabeled_trainloader, testloader, threshold_scor, batch_size, num_class, feat_dim=512, num_neighbor=20, high_scor=1.0): 124 | net.eval() 125 | net2.eval() 126 | 127 | # loading labeled samples 128 | labeledFeatures, labeledLogits, labeledNoisyLabels, labeledW = getFeature(net, net2, labeled_trainloader, testloader, feat_dim, num_class) 129 | knn_labeledLogits = labeledLogits.T[labeledW > 0.5] 130 | knn_labeledFeatures = labeledFeatures.T[labeledW > 0.5] 131 | knn_labeledNoisyLabels = labeledNoisyLabels[labeledW > 0.5] 132 | knn_labeledLogits = torch.from_numpy(knn_labeledLogits).cuda() 133 | 
knn_labeledFeatures = torch.from_numpy(knn_labeledFeatures).cuda() 134 | knn_labeledNoisyLabels = torch.from_numpy(knn_labeledNoisyLabels).cuda() 135 | 136 | # loading unlabeled samples 137 | unlabeledFeatures, unlabeledLogits, _, _ = getFeature(net, net2, unlabeled_trainloader, testloader, feat_dim, num_class) 138 | unlabeledFeatures = torch.from_numpy(unlabeledFeatures.T).cuda() 139 | unlabeledLogits = torch.from_numpy(unlabeledLogits.T).cuda() 140 | 141 | # normalizing features 142 | knn_labeledFeatures = normalize(knn_labeledFeatures) 143 | unlabeledFeatures = normalize(unlabeledFeatures) 144 | 145 | num_labeled = knn_labeledFeatures.size(0) 146 | num_unlabeled = unlabeledFeatures.size(0) 147 | if num_labeled <= num_neighbor * 10: 148 | pseudo_labels = [-3] * num_unlabeled 149 | pseudo_labels = np.array(pseudo_labels) 150 | print("num_labeled <= num_neighbor * 10 ...") 151 | return torch.from_numpy(pseudo_labels) 152 | 153 | # caculating pseudo-labels for unlabeled samples 154 | num_batch_unlabeled = math.ceil(float(unlabeledFeatures.size(0)) / batch_size) 155 | pseudo_labels = [] 156 | scor_collection = [] 157 | for batch_idx in range(num_batch_unlabeled): 158 | features = unlabeledFeatures[batch_idx * batch_size:batch_idx * batch_size + batch_size] 159 | logits = unlabeledLogits[batch_idx * batch_size:batch_idx * batch_size + batch_size] 160 | dist = torch.mm(features, knn_labeledFeatures.t()) 161 | _, neighbors = dist.topk(num_neighbor, dim=1, largest=True, sorted=True) # find contrastive neighbors 162 | neighbors = neighbors.view(-1) 163 | neighs_labels = knn_labeledNoisyLabels[neighbors] 164 | neighs_logits = knn_labeledLogits[neighbors] 165 | neigh_probs = F.softmax(neighs_logits, dim=-1) 166 | neighbor_labels = torch.full(size=neigh_probs.size(), fill_value=0.0001).cuda() 167 | neighbor_labels.scatter_(dim=1, index=torch.unsqueeze(neighs_labels.long(), dim=1), value=1 - 0.0001) 168 | scor = js_div(F.softmax(logits.repeat(1, num_neighbor).view(-1, 
num_class), dim=-1), neighbor_labels) 169 | w = (1 - scor).type(torch.FloatTensor) 170 | w = w.view(-1, 1).type(torch.FloatTensor).cuda() 171 | neighbor_labels = (neighbor_labels * w).view(-1, num_neighbor, num_class).sum(dim=1) 172 | pseudo_labels += neighbor_labels.cpu().numpy().tolist() 173 | scor = scor.view(-1, num_neighbor).mean(dim=1) 174 | scor_collection += scor.cpu().numpy().tolist() 175 | scor_collection = np.array(scor_collection) 176 | 177 | pseudo_labels = np.argmax(np.array(pseudo_labels), axis=1) 178 | pseudo_labels[np.equal(scor_collection > threshold_scor, scor_collection <= high_scor)] = -1 179 | pseudo_labels[scor_collection > high_scor] = -2 180 | 181 | return torch.from_numpy(pseudo_labels) 182 | 183 | def warmup(epoch, net, optimizer, dataloader, CEloss, args, conf_penalty, log): 184 | net.train() 185 | num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1 186 | for batch_idx, (inputs, labels, gt_labels, index) in enumerate(dataloader): 187 | inputs, labels = inputs.cuda(), labels.cuda() 188 | optimizer.zero_grad() 189 | outputs = net(inputs) 190 | loss = CEloss(outputs, labels) 191 | 192 | if args.noise_mode == 'asym': # penalize confident prediction for asymmetric noise 193 | penalty = conf_penalty(outputs) 194 | L = loss + penalty 195 | elif args.noise_mode == 'sym': 196 | L = loss 197 | L.backward() 198 | optimizer.step() 199 | 200 | if (batch_idx + 1) % 50 == 0: 201 | log.write('\r') 202 | log.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f' 203 | % (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx + 1, num_iter, 204 | loss.item())) 205 | log.flush() 206 | 207 | def train(epoch, net, net2, optimizer, labeled_trainloader, unlabeled_trainloader, args, pseudo_labels, log): 208 | net.train() 209 | net2.eval() # fix one network and train the other 210 | 211 | labeled_train_iter = iter(labeled_trainloader) 212 | unlabeled_train_iter = iter(unlabeled_trainloader) 213 | 214 | num_iter_labeled 
= (len(labeled_trainloader.dataset) // args.batch_size) + 1 215 | num_iter_unlabeled = (len(unlabeled_trainloader.dataset) // args.batch_size) + 1 216 | num_iter = max(num_iter_labeled, num_iter_unlabeled) 217 | 218 | for batch_idx in range(num_iter): 219 | try: 220 | inputs_xw, inputs_xw2, inputs_xs, labels_x, _, w_x, _ = labeled_train_iter.next() 221 | except: 222 | labeled_train_iter = iter(labeled_trainloader) 223 | inputs_xw, inputs_xw2, inputs_xs, labels_x, _, w_x, _ = labeled_train_iter.next() 224 | 225 | try: 226 | inputs_uw, inputs_uw2, inputs_us, labels_u, _, _, index_u = unlabeled_train_iter.next() 227 | except: 228 | unlabeled_train_iter = iter(unlabeled_trainloader) 229 | inputs_uw, inputs_uw2, inputs_us, labels_u, _, _, index_u = unlabeled_train_iter.next() 230 | 231 | # transforming given label to one-hot vector for labeled samples 232 | targets_x = torch.zeros(inputs_xw.size(0), args.num_class).scatter_(1, labels_x.view(-1, 1), 1) 233 | targets_x = targets_x.cuda() 234 | labels_x = labels_x.long().cuda() 235 | mask_x = labels_x >= 0 236 | 237 | # assigning corrected pseudo-labels for unlabeled samples 238 | labels_u = labels_u.long().cuda() 239 | labels_u_temp = pseudo_labels[index_u].long().cuda() 240 | mask_u = labels_u_temp >= 0 241 | labels_u[mask_u] = labels_u_temp[mask_u] 242 | 243 | inputs_xw = inputs_xw.cuda() 244 | inputs_xw2 = inputs_xw2.cuda() 245 | inputs_xs = inputs_xs.cuda() 246 | inputs_uw = inputs_uw.cuda() 247 | inputs_us = inputs_us.cuda() 248 | 249 | # label refinement (refer to DivideMix) 250 | with torch.no_grad(): 251 | outputs_x = net(inputs_xw) 252 | outputs_x2 = net(inputs_xw2) 253 | px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2 254 | w_x = w_x.view(-1, 1).type(torch.FloatTensor).cuda() 255 | px = w_x * targets_x + (1 - w_x) * px 256 | ptx = px ** (1 / args.T) # temparature sharpening 257 | targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize 258 | targets_x = targets_x.detach() 259 | 
260 | mixed_inputs = torch.cat([inputs_xw, inputs_xw2], dim=0) 261 | mixed_targets = torch.cat([targets_x, targets_x], dim=0) 262 | mixed_input, mixed_target = mixup(mixed_inputs, mixed_targets, alpha=args.alpha) 263 | mixed_logits = net(mixed_input) 264 | 265 | # mixup regularization for labeled data 266 | Lx = -torch.mean(torch.sum(F.log_softmax(mixed_logits, dim=1) * mixed_target, dim=1)) 267 | 268 | # penalty regularization for mixed labeled data 269 | prior = torch.ones(args.num_class) / args.num_class 270 | prior = prior.cuda() 271 | pred_mean = torch.softmax(mixed_logits, dim=1).mean(0) 272 | penalty = torch.sum(prior * torch.log(prior / pred_mean)) 273 | 274 | # label consistency regularization for unlabeled data 275 | if (args.dataset == 'cifar100') and ((args.r==0.2) or (args.r==0.5)): 276 | all_inputs, all_labels, all_masks = torch.cat([inputs_us, inputs_xs], dim=0), torch.cat([labels_u, labels_x], dim=0), torch.cat([mask_u, mask_x], dim=0) 277 | all_logits = net(all_inputs) 278 | Lu = (F.cross_entropy(all_logits, all_labels, reduction='none') * all_masks.float()).mean() 279 | 280 | else: 281 | logits_us = net(inputs_us) 282 | Lu = (F.cross_entropy(logits_us, labels_u, reduction='none') * mask_u.float()).mean() 283 | 284 | # overall loss 285 | loss = Lx + penalty + Lu 286 | 287 | optimizer.zero_grad() 288 | loss.backward() 289 | optimizer.step() 290 | 291 | log.write("\r") 292 | log.write( 293 | "%s: %.1f-%s | Epoch [%3d/%3d], Iter[%3d/%3d]\t Lx: %.4f, Lu: %.4f, Lpen: %.4f" 294 | % (args.dataset, args.r, args.noise_mode, epoch, args.num_epochs - 1, batch_idx + 1, num_iter, Lx.item(), Lu.item(), penalty.item()) 295 | ) 296 | log.flush() 297 | 298 | 299 | def test(epoch, net1, net2, test_log, test_loader): 300 | net1.eval() 301 | net2.eval() 302 | 303 | correct = 0 304 | total = 0 305 | with torch.no_grad(): 306 | for _, (inputs, targets) in enumerate(test_loader): 307 | 308 | inputs, targets = inputs.cuda(), targets.cuda() 309 | outputs1 = net1(inputs) 
310 | outputs2 = net2(inputs) 311 | outputs = outputs1 + outputs2 312 | _, predicted = torch.max(outputs, 1) 313 | 314 | total += targets.size(0) 315 | correct += predicted.eq(targets).cpu().sum().item() 316 | acc = 100. * correct / total 317 | print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" % (epoch, acc)) 318 | test_log.write('Epoch:%d Accuracy:%.2f\n' % (epoch, acc)) 319 | test_log.flush() 320 | 321 | --------------------------------------------------------------------------------