├── README.md
├── dataset
│   └── cifar10.py
├── models
│   └── preact_resnet.py
├── retrain.py
├── train.py
└── utils
    ├── __init__.py
    ├── eval.py
    ├── logger.py
    ├── misc.py
    └── visualize.py
/README.md: -------------------------------------------------------------------------------- 1 | # Joint Optimization Framework for Learning with Noisy Labels 2 | This is an unofficial PyTorch implementation of [Joint Optimization Framework for Learning with Noisy Labels](https://arxiv.org/abs/1803.11364). 3 | The official Chainer implementation is [here](https://github.com/DaikiTanaka-UT/JointOptimization). 4 | 5 | 6 | ## Requirements 7 | - Python 3.6 8 | - PyTorch 0.4 9 | - torchvision 10 | - progress 11 | - matplotlib 12 | - numpy 13 | 14 | ## Usage 15 | Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.7): 16 | 17 | First, 18 | ``` 19 | python train.py --gpu 0 --out first_sn07 --lr 0.08 --alpha 1.2 --beta 0.8 --percent 0.7 20 | ``` 21 | to train and relabel the dataset. 22 | 23 | Secondly, 24 | ``` 25 | python retrain.py --gpu 0 --out second_sn07 --label first_sn07 26 | ``` 27 | to retrain on the relabeled dataset. 28 | 29 | Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4): 30 | 31 | First, 32 | ``` 33 | python train.py --gpu 0 --out first_an04 --lr 0.03 --alpha 0.8 --beta 0.4 --percent 0.4 --asym 34 | ``` 35 | to train and relabel the dataset. 36 | 37 | Secondly, 38 | ``` 39 | python retrain.py --gpu 0 --out second_an04 --label first_an04 40 | ``` 41 | to retrain on the relabeled dataset. 42 | 43 | 44 | ## References 45 | - D. Tanaka, D. Ikami, T. Yamasaki and K. Aizawa. "Joint Optimization Framework for Learning with Noisy Labels", in CVPR, 2018. --------------------------------------------------------------------------------
/dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import torchvision 5 | 6 | def get_cifar10(root, args, train=True, 7 | transform_train=None, transform_val=None, 8 | download=False): 9 | 10 | base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download) 11 | train_idxs, val_idxs = train_val_split(base_dataset.train_labels) 12 | 13 | train_dataset = CIFAR10_train(root, train_idxs, args, train=train, transform=transform_train) 14 | if args.asym: 15 | train_dataset.asymmetric_noise() 16 | else: 17 | train_dataset.symmetric_noise() 18 | val_dataset = CIFAR10_val(root, val_idxs, train=train, transform=transform_val) 19 | 20 | print(f"Train: {len(train_idxs)} Val: {len(val_idxs)}") 21 | return train_dataset, val_dataset 22 | 23 | 24 | def train_val_split(train_val): 25 | train_val = np.array(train_val) 26 | train_n = int(len(train_val) * 0.9 / 10) 27 | train_idxs = [] 28 | val_idxs = [] 29 | 30 | for i in range(10): 31 | idxs = np.where(train_val == i)[0] 32 | np.random.shuffle(idxs) 33 | train_idxs.extend(idxs[:train_n]) 34 | val_idxs.extend(idxs[train_n:]) 35 | np.random.shuffle(train_idxs) 36 | np.random.shuffle(val_idxs) 37 | 38 | return train_idxs, val_idxs 39 | 40 | class CIFAR10_train(torchvision.datasets.CIFAR10): 41 | 42 | def __init__(self, root, indexs=None, args=None, train=True, 43 | transform=None, target_transform=None, 44 | download=False): 45 | super(CIFAR10_train, self).__init__(root, train=train, 46 | transform=transform, target_transform=target_transform, 47 | download=download) 48 | self.args = args 49 | if indexs is not None: 50 | self.train_data = self.train_data[indexs] 51 |
self.train_labels = np.array(self.train_labels)[indexs] 52 | self.soft_labels = np.zeros((len(self.train_data), 10), dtype=np.float32) 53 | self.prediction = np.zeros((len(self.train_data), 10, 10), dtype=np.float32) 54 | 55 | self.count = 0 56 | 57 | def symmetric_noise(self): 58 | indices = np.random.permutation(len(self.train_data)) 59 | for i, idx in enumerate(indices): 60 | if i < self.args.percent * len(self.train_data): 61 | self.train_labels[idx] = np.random.randint(10, dtype=np.int32) 62 | self.soft_labels[idx][self.train_labels[idx]] = 1. 63 | 64 | def asymmetric_noise(self): 65 | for i in range(10): 66 | indices = np.where(self.train_labels == i)[0] 67 | np.random.shuffle(indices) 68 | for j, idx in enumerate(indices): 69 | if j < self.args.percent * len(indices): 70 | # truck -> automobile 71 | if i == 9: 72 | self.train_labels[idx] = 1 73 | # bird -> airplane 74 | elif i == 2: 75 | self.train_labels[idx] = 0 76 | # cat -> dog 77 | elif i == 3: 78 | self.train_labels[idx] = 5 79 | # dog -> cat 80 | elif i == 5: 81 | self.train_labels[idx] = 3 82 | # deer -> horse 83 | elif i == 4: 84 | self.train_labels[idx] = 7 85 | self.soft_labels[idx][self.train_labels[idx]] = 1. 86 | 87 | def label_update(self, results): 88 | self.count += 1 89 | 90 | # While updating the noisy label y_i by the probability s, we used the average output probability of the network of the past 10 epochs as s. 91 | idx = (self.count - 1) % 10 92 | self.prediction[:, idx] = results 93 | 94 | if self.count >= self.args.begin: 95 | self.soft_labels = self.prediction.mean(axis=1) 96 | self.train_labels = np.argmax(self.soft_labels, axis=1).astype(np.int64) 97 | 98 | if self.count == self.args.epochs: 99 | np.save(f'{self.args.out}/images.npy', self.train_data) 100 | np.save(f'{self.args.out}/labels.npy', self.train_labels) 101 | np.save(f'{self.args.out}/soft_labels.npy', self.soft_labels) 102 | 103 | def reload_label(self): 104 | self.train_data = np.load(f'{self.args.label}/images.npy') 105 | self.train_labels = np.load(f'{self.args.label}/labels.npy') 106 | self.soft_labels = np.load(f'{self.args.label}/soft_labels.npy') 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | 113 | Returns: 114 | tuple: (image, target) where target is index of the target class. 115 | """ 116 | img, target, soft_target = self.train_data[index], self.train_labels[index], self.soft_labels[index] 117 | 118 | # doing this so that it is consistent with all other datasets 119 | # to return a PIL Image 120 | img = Image.fromarray(img) 121 | 122 | if self.transform is not None: 123 | img = self.transform(img) 124 | 125 | if self.target_transform is not None: 126 | target = self.target_transform(target) 127 | 128 | return img, target, soft_target, index 129 | 130 | 131 | class CIFAR10_val(torchvision.datasets.CIFAR10): 132 | 133 | def __init__(self, root, indexs, train=True, 134 | transform=None, target_transform=None, 135 | download=False): 136 | super(CIFAR10_val, self).__init__(root, train=train, 137 | transform=transform, target_transform=target_transform, 138 | download=download) 139 | 140 | self.train_data = self.train_data[indexs] 141 | self.train_labels = np.array(self.train_labels)[indexs] -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 
2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 
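# The factory functions above and below differ only in the per-stage block counts of the
# standard ResNet depths (e.g. [3, 4, 6, 3] basic blocks for ResNet-34); train.py and
# retrain.py both instantiate PreActResNet34(). Note that forward() average-pools with a
# 4x4 window, so the network assumes 32x32 inputs such as CIFAR-10.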
100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() -------------------------------------------------------------------------------- /retrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.optim as optim 16 | import torch.utils.data as data 17 | import torchvision 18 | import torchvision.transforms as transforms 19 | import torch.nn.functional as F 20 | 21 | import models.preact_resnet as models 22 | import dataset.cifar10 as dataset 23 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') 26 | parser.add_argument('--label', default='result', 27 | help='Directory to input the labels') 28 | # Optimization options 29 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 30 | help='number of total epochs to run') 31 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 32 | help='manual epoch number (useful on restarts)') 33 | parser.add_argument('--batch-size', default=128, type=int, metavar='N', 34 | help='train batchsize') 35 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 36 | metavar='LR', help='initial learning rate') 37 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 38 | help='momentum') 39 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 40 | metavar='W', help='weight decay (default: 1e-4)') 41 | # Checkpoints 42 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | # Miscs 45 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 46 | #Device options 47 | parser.add_argument('--gpu', default='0', type=str, 48 | help='id(s) for CUDA_VISIBLE_DEVICES') 49 | #Method options 50 | parser.add_argument('--out', default='retrain_result', 51 | help='Directory to output the result') 52 | 53 | args = parser.parse_args() 54 | state = {k: v for k, v in args._get_kwargs()} 55 | 56 | # Use CUDA 57 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 58 | use_cuda = torch.cuda.is_available() 59 | 60 | # Random seed 61 | if args.manualSeed is None: 62 | args.manualSeed = random.randint(1, 10000) 63 | random.seed(args.manualSeed) 64 | np.random.seed(args.manualSeed) 65 | torch.manual_seed(args.manualSeed) 66 | if use_cuda: 67 | torch.cuda.manual_seed_all(args.manualSeed) 68 | 69 | best_acc = 0 # best test accuracy 70 | 71 | def main(): 72 | global best_acc 73 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 74 | 75 | if not os.path.isdir(args.out): 76 | mkdir_p(args.out) 77 | 78 | # Data 79 | print(f'==> Preparing relabeled nosiy cifar10') 80 | transform_train = transforms.Compose([ 81 | 
transforms.RandomCrop(32, padding=8), 82 | transforms.RandomHorizontalFlip(), 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 85 | ]) 86 | 87 | transform_test = transforms.Compose([ 88 | transforms.ToTensor(), 89 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 90 | ]) 91 | 92 | trainset = dataset.CIFAR10_train('./data', args=args, train=True, transform=transform_train) 93 | trainset.reload_label() 94 | testset = torchvision.datasets.CIFAR10('./data', train=False, transform=transform_test) 95 | trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4) 96 | testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4) 97 | 98 | # Model 99 | print("==> creating preact_resnet") 100 | model = models.PreActResNet34() 101 | 102 | model = torch.nn.DataParallel(model).cuda() 103 | cudnn.benchmark = True 104 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 105 | 106 | test_criterion = nn.CrossEntropyLoss() 107 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 108 | 109 | # Resume 110 | title = 'noisy-cifar-10' 111 | if args.resume: 112 | # Load checkpoint. 113 | print('==> Resuming from checkpoint..') 114 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 115 | args.out = os.path.dirname(args.resume) 116 | checkpoint = torch.load(args.resume) 117 | best_acc = checkpoint['best_acc'] 118 | start_epoch = checkpoint['epoch'] 119 | model.load_state_dict(checkpoint['state_dict']) 120 | optimizer.load_state_dict(checkpoint['optimizer']) 121 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True) 122 | else: 123 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title) 124 | logger.set_names(['Learning Rate', 'Train Loss', 'Test Loss', 'Train Acc.', 'Test Acc.']) 125 | 126 | # Train and test 127 | for epoch in range(start_epoch, args.epochs): 128 | adjust_learning_rate(optimizer, epoch) 129 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) 130 | 131 | train_loss, train_acc = train(trainloader, model, optimizer, epoch, use_cuda) 132 | test_loss, test_acc = testidate(testloader, model, test_criterion, epoch, use_cuda) 133 | 134 | # append logger file 135 | logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) 136 | 137 | # save model 138 | is_best = test_acc > best_acc 139 | best_acc = max(test_acc, best_acc) 140 | save_checkpoint({ 141 | 'epoch': epoch + 1, 142 | 'state_dict': model.state_dict(), 143 | 'acc': test_acc, 144 | 'best_acc': best_acc, 145 | 'optimizer' : optimizer.state_dict(), 146 | }, is_best) 147 | 148 | logger.close() 149 | logger.plot() 150 | savefig(os.path.join(args.out, 'log.eps')) 151 | 152 | print('Best acc:') 153 | print(best_acc) 154 | 155 | def train(trainloader, model, optimizer, epoch, use_cuda): 156 | # switch to train mode 157 | model.train() 158 | 159 | batch_time = AverageMeter() 160 | data_time = AverageMeter() 161 | losses = AverageMeter() 162 | top1 = AverageMeter() 163 | top5 = AverageMeter() 164 | end = time.time() 165 | 166 | bar = Bar('Training', max=len(trainloader)) 167 | for batch_idx, (inputs, targets, soft_targets, indexs) in enumerate(trainloader): 168 | # measure data loading time 169 | data_time.update(time.time() - end) 170 | 171 | if use_cuda: 172 | inputs, targets = inputs.cuda(), 
targets.cuda(non_blocking=True) 173 | soft_targets, indexs = soft_targets.cuda(non_blocking=True), indexs.cuda(non_blocking=True) 174 | 175 | # compute output 176 | outputs = model(inputs) 177 | 178 | loss = mycriterion(outputs, soft_targets) 179 | 180 | # measure accuracy and record loss 181 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 182 | losses.update(loss.item(), inputs.size(0)) 183 | top1.update(prec1.item(), inputs.size(0)) 184 | top5.update(prec5.item(), inputs.size(0)) 185 | 186 | # compute gradient and do SGD step 187 | optimizer.zero_grad() 188 | loss.backward() 189 | optimizer.step() 190 | 191 | # measure elapsed time 192 | batch_time.update(time.time() - end) 193 | end = time.time() 194 | 195 | # plot progress 196 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 197 | batch=batch_idx + 1, 198 | size=len(trainloader), 199 | data=data_time.avg, 200 | bt=batch_time.avg, 201 | total=bar.elapsed_td, 202 | eta=bar.eta_td, 203 | loss=losses.avg, 204 | top1=top1.avg, 205 | top5=top5.avg, 206 | ) 207 | bar.next() 208 | bar.finish() 209 | 210 | return (losses.avg, top1.avg) 211 | 212 | def testidate(testloader, model, criterion, epoch, use_cuda): 213 | global best_acc 214 | 215 | batch_time = AverageMeter() 216 | data_time = AverageMeter() 217 | losses = AverageMeter() 218 | top1 = AverageMeter() 219 | top5 = AverageMeter() 220 | 221 | # switch to evaluate mode 222 | model.eval() 223 | 224 | end = time.time() 225 | bar = Bar('Testing ', max=len(testloader)) 226 | with torch.no_grad(): 227 | for batch_idx, (inputs, targets) in enumerate(testloader): 228 | # measure data loading time 229 | data_time.update(time.time() - end) 230 | 231 | if use_cuda: 232 | inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) 233 | 234 | # compute output 235 | outputs = model(inputs) 236 | loss = criterion(outputs, targets) 237 | 238 | # measure accuracy and record loss 239 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 240 | losses.update(loss.item(), inputs.size(0)) 241 | top1.update(prec1.item(), inputs.size(0)) 242 | top5.update(prec5.item(), inputs.size(0)) 243 | 244 | # measure elapsed time 245 | batch_time.update(time.time() - end) 246 | end = time.time() 247 | 248 | # plot progress 249 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 250 | batch=batch_idx + 1, 251 | size=len(testloader), 252 | data=data_time.avg, 253 | bt=batch_time.avg, 254 | total=bar.elapsed_td, 255 | eta=bar.eta_td, 256 | loss=losses.avg, 257 | top1=top1.avg, 258 | top5=top5.avg, 259 | ) 260 | bar.next() 261 | bar.finish() 262 | return (losses.avg, top1.avg) 263 | 264 | def save_checkpoint(state, is_best, checkpoint=args.out, filename='checkpoint.pth.tar'): 265 | filepath = os.path.join(checkpoint, filename) 266 | torch.save(state, filepath) 267 | if is_best: 268 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 269 | 270 | def adjust_learning_rate(optimizer, epoch): 271 | global state 272 | if epoch in [int(args.epochs / 3), int(args.epochs * 2 / 3)]: 273 | state['lr'] *= 0.1 274 | for param_group in optimizer.param_groups: 275 | param_group['lr'] = state['lr'] 276 | 277 | def mycriterion(outputs, soft_targets): 278 | loss = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * soft_targets, dim=1)) 279 | return loss 280 | 281 | 
if __name__ == '__main__': 282 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.optim as optim 16 | import torch.utils.data as data 17 | import torchvision.transforms as transforms 18 | import torch.nn.functional as F 19 | 20 | import models.preact_resnet as models 21 | import dataset.cifar10 as dataset 22 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') 25 | # Optimization options 26 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 27 | help='number of total epochs to run') 28 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 29 | help='manual epoch number (useful on restarts)') 30 | parser.add_argument('--batch-size', default=128, type=int, metavar='N', 31 | help='train batchsize') 32 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 33 | metavar='LR', help='initial learning rate') 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 35 | help='momentum') 36 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 37 | metavar='W', help='weight decay (default: 1e-4)') 38 | # Checkpoints 39 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 40 | help='path to latest checkpoint (default: none)') 41 | # Miscs 42 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 43 | #Device options 44 | parser.add_argument('--gpu', default='0', type=str, 45 | help='id(s) for CUDA_VISIBLE_DEVICES') 46 | #Method options 47 | parser.add_argument('--percent', type=float, default=0, 48 | help='Percentage of noise') 49 | parser.add_argument('--begin', type=int, default=70, 50 | help='When to begin updating labels') 51 | parser.add_argument('--alpha', type=float, default=1.0, 52 | help='Hyper parameter alpha of loss function') 53 | parser.add_argument('--beta', type=float, default=0.5, 54 | help='Hyper parameter beta of loss function') 55 | parser.add_argument('--asym', action='store_true', 56 | help='Asymmetric noise') 57 | parser.add_argument('--out', default='result', 58 | help='Directory to output the result') 59 | 60 | args = parser.parse_args() 61 | state = {k: v for k, v in args._get_kwargs()} 62 | 63 | # Use CUDA 64 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 65 | use_cuda = torch.cuda.is_available() 66 | 67 | # Random seed 68 | if args.manualSeed is None: 69 | args.manualSeed = random.randint(1, 10000) 70 | np.random.seed(args.manualSeed) 71 | 72 | best_acc = 0 # best test accuracy 73 | 74 | def main(): 75 | global best_acc 76 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 77 | 78 | if not os.path.isdir(args.out): 79 | mkdir_p(args.out) 80 | 81 | # Data 82 | print(f'==> Preparing {"asymmetric" if args.asym else "symmetric"} nosiy cifar10') 83 | transform_train = transforms.Compose([ 84 | transforms.RandomCrop(32, padding=8), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 88 | ]) 89 
| 90 | transform_val = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 93 | ]) 94 | 95 | trainset, valset = dataset.get_cifar10('./data', args, train=True, download=True, transform_train=transform_train, transform_val=transform_val) 96 | trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4) 97 | valloader = data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=4) 98 | 99 | # Model 100 | print("==> creating preact_resnet") 101 | model = models.PreActResNet34() 102 | 103 | model = torch.nn.DataParallel(model).cuda() 104 | cudnn.benchmark = True 105 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 106 | 107 | val_criterion = nn.CrossEntropyLoss() 108 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 109 | 110 | # Resume 111 | title = 'noisy-cifar-10' 112 | if args.resume: 113 | # Load checkpoint. 114 | print('==> Resuming from checkpoint..') 115 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 116 | args.out = os.path.dirname(args.resume) 117 | checkpoint = torch.load(args.resume) 118 | best_acc = checkpoint['best_acc'] 119 | start_epoch = checkpoint['epoch'] 120 | model.load_state_dict(checkpoint['state_dict']) 121 | optimizer.load_state_dict(checkpoint['optimizer']) 122 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True) 123 | else: 124 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title) 125 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 126 | 127 | # Train and val 128 | for epoch in range(start_epoch, args.epochs): 129 | 130 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) 131 | 132 | train_loss, train_acc = train(trainloader, model, optimizer, epoch, use_cuda) 133 | val_loss, val_acc = validate(valloader, model, val_criterion, epoch, use_cuda) 134 | 135 | # append logger file 136 | logger.append([state['lr'], train_loss, val_loss, train_acc, val_acc]) 137 | 138 | # save model 139 | is_best = val_acc > best_acc 140 | best_acc = max(val_acc, best_acc) 141 | save_checkpoint({ 142 | 'epoch': epoch + 1, 143 | 'state_dict': model.state_dict(), 144 | 'acc': val_acc, 145 | 'best_acc': best_acc, 146 | 'optimizer' : optimizer.state_dict(), 147 | }, is_best) 148 | 149 | logger.close() 150 | logger.plot() 151 | savefig(os.path.join(args.out, 'log.eps')) 152 | 153 | print('Best acc:') 154 | print(best_acc) 155 | 156 | def train(trainloader, model, optimizer, epoch, use_cuda): 157 | # switch to train mode 158 | model.train() 159 | 160 | batch_time = AverageMeter() 161 | data_time = AverageMeter() 162 | losses = AverageMeter() 163 | top1 = AverageMeter() 164 | top5 = AverageMeter() 165 | end = time.time() 166 | 167 | results = np.zeros((len(trainloader.dataset), 10), dtype=np.float32) 168 | 169 | bar = Bar('Training', max=len(trainloader)) 170 | for batch_idx, (inputs, targets, soft_targets, indexs) in enumerate(trainloader): 171 | # measure data loading time 172 | data_time.update(time.time() - end) 173 | 174 | if use_cuda: 175 | inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) 176 | soft_targets, indexs = soft_targets.cuda(non_blocking=True), indexs.cuda(non_blocking=True) 177 | 178 | # compute output 179 | outputs = model(inputs) 180 | 181 | probs, loss = mycriterion(outputs, soft_targets) 182 | 
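# probs holds this batch's per-sample softmax outputs; writing them into `results` at the
# positions given by `indexs` builds a full (num_samples, 10) prediction table for the epoch,
# which label_update() below folds into the dataset's rolling average over the last 10 epochs
# and, after --begin epochs, uses to rewrite the noisy labels.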
183 | results[indexs.cpu().detach().numpy().tolist()] = probs.cpu().detach().numpy().tolist() 184 | 185 | # measure accuracy and record loss 186 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 187 | losses.update(loss.item(), inputs.size(0)) 188 | top1.update(prec1.item(), inputs.size(0)) 189 | top5.update(prec5.item(), inputs.size(0)) 190 | 191 | # compute gradient and do SGD step 192 | optimizer.zero_grad() 193 | loss.backward() 194 | optimizer.step() 195 | 196 | # measure elapsed time 197 | batch_time.update(time.time() - end) 198 | end = time.time() 199 | 200 | # plot progress 201 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 202 | batch=batch_idx + 1, 203 | size=len(trainloader), 204 | data=data_time.avg, 205 | bt=batch_time.avg, 206 | total=bar.elapsed_td, 207 | eta=bar.eta_td, 208 | loss=losses.avg, 209 | top1=top1.avg, 210 | top5=top5.avg, 211 | ) 212 | bar.next() 213 | bar.finish() 214 | 215 | # update soft labels 216 | trainloader.dataset.label_update(results) 217 | return (losses.avg, top1.avg) 218 | 219 | def validate(valloader, model, criterion, epoch, use_cuda): 220 | global best_acc 221 | 222 | batch_time = AverageMeter() 223 | data_time = AverageMeter() 224 | losses = AverageMeter() 225 | top1 = AverageMeter() 226 | top5 = AverageMeter() 227 | 228 | # switch to evaluate mode 229 | model.eval() 230 | 231 | end = time.time() 232 | bar = Bar('Testing ', max=len(valloader)) 233 | with torch.no_grad(): 234 | for batch_idx, (inputs, targets) in enumerate(valloader): 235 | # measure data loading time 236 | data_time.update(time.time() - end) 237 | 238 | if use_cuda: 239 | inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) 240 | 241 | # compute output 242 | outputs = model(inputs) 243 | loss = criterion(outputs, targets) 244 | 245 | # measure accuracy and record loss 246 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 247 | losses.update(loss.item(), inputs.size(0)) 248 | top1.update(prec1.item(), inputs.size(0)) 249 | top5.update(prec5.item(), inputs.size(0)) 250 | 251 | # measure elapsed time 252 | batch_time.update(time.time() - end) 253 | end = time.time() 254 | 255 | # plot progress 256 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 257 | batch=batch_idx + 1, 258 | size=len(valloader), 259 | data=data_time.avg, 260 | bt=batch_time.avg, 261 | total=bar.elapsed_td, 262 | eta=bar.eta_td, 263 | loss=losses.avg, 264 | top1=top1.avg, 265 | top5=top5.avg, 266 | ) 267 | bar.next() 268 | bar.finish() 269 | return (losses.avg, top1.avg) 270 | 271 | def save_checkpoint(state, is_best, checkpoint=args.out, filename='checkpoint.pth.tar'): 272 | filepath = os.path.join(checkpoint, filename) 273 | torch.save(state, filepath) 274 | if is_best: 275 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 276 | 277 | def mycriterion(outputs, soft_targets): 278 | # We introduce a prior probability distribution p, which is a distribution of classes among all training data. 
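# Following the paper's formulation: L_c below is the cross entropy between the current (soft)
# labels and the network predictions, L_p keeps the mean prediction close to the uniform prior
# p defined just below (preventing the relabeling from collapsing onto a single class), and L_e
# is the entropy of the per-sample predictions. They are combined as L_c + alpha * L_p +
# beta * L_e, with alpha and beta taken from the --alpha / --beta options.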
279 | p = torch.ones(10).cuda() / 10 280 | 281 | probs = F.softmax(outputs, dim=1) 282 | avg_probs = torch.mean(probs, dim=0) 283 | 284 | L_c = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * soft_targets, dim=1)) 285 | L_p = -torch.sum(torch.log(avg_probs) * p) 286 | L_e = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * probs, dim=1)) 287 | 288 | loss = L_c + args.alpha * L_p + args.beta * L_e 289 | return probs, loss 290 | 291 | 292 | if __name__ == '__main__': 293 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 
| self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias is not None: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 |
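# Note: show_mask_single() and show_mask() below call an `upsampling` helper that is neither
# defined in this module nor exported by `from .misc import *`; a minimal stand-in (an
# assumption, not part of the original code) could be:
#
#   def upsampling(x, scale_factor):
#       return nn.functional.upsample(x, scale_factor=int(scale_factor), mode='nearest')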
44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() --------------------------------------------------------------------------------
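A quick way to inspect what the first stage wrote to disk (a sketch, not part of the repository; it assumes `train.py` finished with `--out first_sn07`, so the arrays saved by `CIFAR10_train.label_update` exist in that directory):

```python
import numpy as np

out_dir = 'first_sn07'  # directory that was passed to train.py via --out (assumption)

images = np.load(f'{out_dir}/images.npy')            # training images of the noisy split
labels = np.load(f'{out_dir}/labels.npy')            # relabeled hard targets (argmax of the soft labels)
soft_labels = np.load(f'{out_dir}/soft_labels.npy')  # (N, 10) averaged network predictions

print(images.shape, labels.shape, soft_labels.shape)
print('relabeled class counts:', np.bincount(labels, minlength=10))
# retrain.py reads these same three files through CIFAR10_train.reload_label()
```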