├── LICENSE ├── README.md ├── dataset └── cifar10.py ├── models └── wideresnet.py ├── train.py └── utils ├── __init__.py ├── eval.py ├── logger.py └── misc.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Qing Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixMatch 2 | This is an unofficial PyTorch implementation of [MixMatch: A Holistic Approach to Semi-Supervised Learning](https://arxiv.org/abs/1905.02249). 3 | The official TensorFlow implementation is [here](https://github.com/google-research/mixmatch). 4 | 5 | Currently, only experiments on CIFAR-10 are available. 6 | 7 | This repository carefully reproduces the important details of the official implementation in order to match the reported results. 8 | 9 | 10 | ## Requirements 11 | - Python 3.6+ 12 | - PyTorch 1.0 13 | - **torchvision 0.2.2 (older versions are not compatible with this code)** 14 | - tensorboardX 15 | - progress 16 | - matplotlib 17 | - numpy 18 | 19 | ## Usage 20 | 21 | ### Train 22 | Train the model with 250 labeled examples of the CIFAR-10 dataset: 23 | 24 | ``` 25 | python train.py --gpu 0 --n-labeled 250 --out cifar10@250 26 | ``` 27 | 28 | Train the model with 4000 labeled examples of the CIFAR-10 dataset: 29 | 30 | ``` 31 | python train.py --gpu 0 --n-labeled 4000 --out cifar10@4000 32 | ``` 33 | 34 | ### Monitoring training progress 35 | ``` 36 | tensorboard --port 6006 --logdir cifar10@250 37 | ``` 38 | 39 | ## Results (Accuracy) 40 | | #Labels | 250 | 500 | 1000 | 2000 | 4000 | 41 | |:---|:---:|:---:|:---:|:---:|:---:| 42 | | Paper | 88.92 ± 0.87 | 90.35 ± 0.94 | 92.25 ± 0.32 | 92.97 ± 0.15 | 93.76 ± 0.06 | 43 | | This code | 88.71 | 88.96 | 90.52 | 92.23 | 93.52 | 44 | 45 | (Results of this code were evaluated on a single run. Results of 5 runs with different seeds will be updated later. 
) 46 | 47 | ## References 48 | ``` 49 | @article{berthelot2019mixmatch, 50 | title={MixMatch: A Holistic Approach to Semi-Supervised Learning}, 51 | author={Berthelot, David and Carlini, Nicholas and Goodfellow, Ian and Papernot, Nicolas and Oliver, Avital and Raffel, Colin}, 52 | journal={arXiv preprint arXiv:1905.02249}, 53 | year={2019} 54 | } 55 | ``` -------------------------------------------------------------------------------- /dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import torchvision 5 | import torch 6 | 7 | class TransformTwice: 8 | def __init__(self, transform): 9 | self.transform = transform 10 | 11 | def __call__(self, inp): 12 | out1 = self.transform(inp) 13 | out2 = self.transform(inp) 14 | return out1, out2 15 | 16 | def get_cifar10(root, n_labeled, 17 | transform_train=None, transform_val=None, 18 | download=True): 19 | 20 | base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download) 21 | train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled/10)) 22 | 23 | train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train) 24 | train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train)) 25 | val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=True) 26 | test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=True) 27 | 28 | print (f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}") 29 | return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset 30 | 31 | 32 | def train_val_split(labels, n_labeled_per_class): 33 | labels = np.array(labels) 34 | train_labeled_idxs = [] 35 | train_unlabeled_idxs = [] 36 | val_idxs = [] 37 | 38 | for i in range(10): 39 | idxs = np.where(labels == i)[0] 40 | np.random.shuffle(idxs) 41 | train_labeled_idxs.extend(idxs[:n_labeled_per_class]) 42 | train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500]) 43 | val_idxs.extend(idxs[-500:]) 44 | np.random.shuffle(train_labeled_idxs) 45 | np.random.shuffle(train_unlabeled_idxs) 46 | np.random.shuffle(val_idxs) 47 | 48 | return train_labeled_idxs, train_unlabeled_idxs, val_idxs 49 | 50 | cifar10_mean = (0.4914, 0.4822, 0.4465) # equals np.mean(train_set.train_data, axis=(0,1,2))/255 51 | cifar10_std = (0.2471, 0.2435, 0.2616) # equals np.std(train_set.train_data, axis=(0,1,2))/255 52 | 53 | def normalize(x, mean=cifar10_mean, std=cifar10_std): 54 | x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)] 55 | x -= mean*255 56 | x *= 1.0/(255*std) 57 | return x 58 | 59 | def transpose(x, source='NHWC', target='NCHW'): 60 | return x.transpose([source.index(d) for d in target]) 61 | 62 | def pad(x, border=4): 63 | return np.pad(x, [(0, 0), (border, border), (border, border)], mode='reflect') 64 | 65 | class RandomPadandCrop(object): 66 | """Crop randomly the image. 67 | 68 | Args: 69 | output_size (tuple or int): Desired output size. If int, square crop 70 | is made. 
71 | """ 72 | 73 | def __init__(self, output_size): 74 | assert isinstance(output_size, (int, tuple)) 75 | if isinstance(output_size, int): 76 | self.output_size = (output_size, output_size) 77 | else: 78 | assert len(output_size) == 2 79 | self.output_size = output_size 80 | 81 | def __call__(self, x): 82 | x = pad(x, 4) 83 | 84 | h, w = x.shape[1:] 85 | new_h, new_w = self.output_size 86 | 87 | top = np.random.randint(0, h - new_h) 88 | left = np.random.randint(0, w - new_w) 89 | 90 | x = x[:, top: top + new_h, left: left + new_w] 91 | 92 | return x 93 | 94 | class RandomFlip(object): 95 | """Flip randomly the image. 96 | """ 97 | def __call__(self, x): 98 | if np.random.rand() < 0.5: 99 | x = x[:, :, ::-1] 100 | 101 | return x.copy() 102 | 103 | class GaussianNoise(object): 104 | """Add gaussian noise to the image. 105 | """ 106 | def __call__(self, x): 107 | c, h, w = x.shape 108 | x += np.random.randn(c, h, w) * 0.15 109 | return x 110 | 111 | class ToTensor(object): 112 | """Transform the image to tensor. 113 | """ 114 | def __call__(self, x): 115 | x = torch.from_numpy(x) 116 | return x 117 | 118 | class CIFAR10_labeled(torchvision.datasets.CIFAR10): 119 | 120 | def __init__(self, root, indexs=None, train=True, 121 | transform=None, target_transform=None, 122 | download=False): 123 | super(CIFAR10_labeled, self).__init__(root, train=train, 124 | transform=transform, target_transform=target_transform, 125 | download=download) 126 | if indexs is not None: 127 | self.data = self.data[indexs] 128 | self.targets = np.array(self.targets)[indexs] 129 | self.data = transpose(normalize(self.data)) 130 | 131 | def __getitem__(self, index): 132 | """ 133 | Args: 134 | index (int): Index 135 | 136 | Returns: 137 | tuple: (image, target) where target is index of the target class. 
138 | """ 139 | img, target = self.data[index], self.targets[index] 140 | 141 | if self.transform is not None: 142 | img = self.transform(img) 143 | 144 | if self.target_transform is not None: 145 | target = self.target_transform(target) 146 | 147 | return img, target 148 | 149 | 150 | class CIFAR10_unlabeled(CIFAR10_labeled): 151 | 152 | def __init__(self, root, indexs, train=True, 153 | transform=None, target_transform=None, 154 | download=False): 155 | super(CIFAR10_unlabeled, self).__init__(root, indexs, train=train, 156 | transform=transform, target_transform=target_transform, 157 | download=download) 158 | self.targets = np.array([-1 for i in range(len(self.targets))]) 159 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 11 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 15 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | self.activate_before_residual = activate_before_residual 23 | def forward(self, x): 24 | if not self.equalInOut and self.activate_before_residual == True: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual): 39 | layers = [] 40 | for i in range(int(nb_layers)): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | class WideResNet(nn.Module): 47 | def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0): 48 | super(WideResNet, self).__init__() 49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 50 | assert((depth - 4) % 6 == 0) 51 | n = (depth - 4) / 6 52 | block = BasicBlock 53 | # 1st conv before any network block 54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 55 | padding=1, bias=False) 56 | # 1st block 57 | self.block1 = 
NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True) 58 | # 2nd block 59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 60 | # 3rd block 61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 62 | # global average pooling and classifier 63 | self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) 64 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 65 | self.fc = nn.Linear(nChannels[3], num_classes) 66 | self.nChannels = nChannels[3] 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 71 | m.weight.data.normal_(0, math.sqrt(2. / n)) 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.weight.data.fill_(1) 74 | m.bias.data.zero_() 75 | elif isinstance(m, nn.Linear): 76 | nn.init.xavier_normal_(m.weight.data) 77 | m.bias.data.zero_() 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 | out = self.block1(out) 82 | out = self.block2(out) 83 | out = self.block3(out) 84 | out = self.relu(self.bn1(out)) 85 | out = F.avg_pool2d(out, 8) 86 | out = out.view(-1, self.nChannels) 87 | return self.fc(out) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.optim as optim 16 | import torch.utils.data as data 17 | import torchvision.transforms as transforms 18 | import torch.nn.functional as F 19 | 20 | import models.wideresnet as models 21 | import dataset.cifar10 as dataset 22 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 23 | from tensorboardX import SummaryWriter 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch MixMatch Training') 26 | # Optimization options 27 | parser.add_argument('--epochs', default=1024, type=int, metavar='N', 28 | help='number of total epochs to run') 29 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 30 | help='manual epoch number (useful on restarts)') 31 | parser.add_argument('--batch-size', default=64, type=int, metavar='N', 32 | help='train batchsize') 33 | parser.add_argument('--lr', '--learning-rate', default=0.002, type=float, 34 | metavar='LR', help='initial learning rate') 35 | # Checkpoints 36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 37 | help='path to latest checkpoint (default: none)') 38 | # Miscs 39 | parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') 40 | #Device options 41 | parser.add_argument('--gpu', default='0', type=str, 42 | help='id(s) for CUDA_VISIBLE_DEVICES') 43 | #Method options 44 | parser.add_argument('--n-labeled', type=int, default=250, 45 | help='Number of labeled data') 46 | parser.add_argument('--train-iteration', type=int, default=1024, 47 | help='Number of iteration per epoch') 48 | parser.add_argument('--out', default='result', 49 | help='Directory to output the result') 50 | parser.add_argument('--alpha', default=0.75, type=float) 51 | parser.add_argument('--lambda-u', default=75, type=float) 52 | parser.add_argument('--T', default=0.5, type=float) 53 | parser.add_argument('--ema-decay', default=0.999, type=float) 
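# MixMatch-specific hyperparameters (descriptive notes, matching how they are used further down in train.py):
#   --alpha      parameter of the Beta(alpha, alpha) distribution that the MixUp coefficient is drawn from
#   --lambda-u   maximum weight of the unlabeled consistency loss, ramped up linearly by linear_rampup()
#   --T          sharpening temperature applied to the averaged predictions over the two unlabeled augmentations
#   --ema-decay  decay of the exponential-moving-average model (WeightEMA) that is used for validation and testing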
54 | 55 | 56 | args = parser.parse_args() 57 | state = {k: v for k, v in args._get_kwargs()} 58 | 59 | # Use CUDA 60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 61 | use_cuda = torch.cuda.is_available() 62 | 63 | # Random seed 64 | if args.manualSeed is None: 65 | args.manualSeed = random.randint(1, 10000) 66 | np.random.seed(args.manualSeed) 67 | 68 | best_acc = 0 # best test accuracy 69 | 70 | def main(): 71 | global best_acc 72 | 73 | if not os.path.isdir(args.out): 74 | mkdir_p(args.out) 75 | 76 | # Data 77 | print(f'==> Preparing cifar10') 78 | transform_train = transforms.Compose([ 79 | dataset.RandomPadandCrop(32), 80 | dataset.RandomFlip(), 81 | dataset.ToTensor(), 82 | ]) 83 | 84 | transform_val = transforms.Compose([ 85 | dataset.ToTensor(), 86 | ]) 87 | 88 | train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data', args.n_labeled, transform_train=transform_train, transform_val=transform_val) 89 | labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) 90 | unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True) 91 | val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0) 92 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0) 93 | 94 | # Model 95 | print("==> creating WRN-28-2") 96 | 97 | def create_model(ema=False): 98 | model = models.WideResNet(num_classes=10) 99 | model = model.cuda() 100 | 101 | if ema: 102 | for param in model.parameters(): 103 | param.detach_() 104 | 105 | return model 106 | 107 | model = create_model() 108 | ema_model = create_model(ema=True) 109 | 110 | cudnn.benchmark = True 111 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 112 | 113 | train_criterion = SemiLoss() 114 | criterion = nn.CrossEntropyLoss() 115 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 116 | 117 | ema_optimizer= WeightEMA(model, ema_model, alpha=args.ema_decay) 118 | start_epoch = 0 119 | 120 | # Resume 121 | title = 'noisy-cifar-10' 122 | if args.resume: 123 | # Load checkpoint. 124 | print('==> Resuming from checkpoint..') 125 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
126 | args.out = os.path.dirname(args.resume) 127 | checkpoint = torch.load(args.resume) 128 | best_acc = checkpoint['best_acc'] 129 | start_epoch = checkpoint['epoch'] 130 | model.load_state_dict(checkpoint['state_dict']) 131 | ema_model.load_state_dict(checkpoint['ema_state_dict']) 132 | optimizer.load_state_dict(checkpoint['optimizer']) 133 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True) 134 | else: 135 | logger = Logger(os.path.join(args.out, 'log.txt'), title=title) 136 | logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.']) 137 | 138 | writer = SummaryWriter(args.out) 139 | step = 0 140 | test_accs = [] 141 | # Train and val 142 | for epoch in range(start_epoch, args.epochs): 143 | 144 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) 145 | 146 | train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, train_criterion, epoch, use_cuda) 147 | _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats') 148 | val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats') 149 | test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ') 150 | 151 | step = args.train_iteration * (epoch + 1) 152 | 153 | writer.add_scalar('losses/train_loss', train_loss, step) 154 | writer.add_scalar('losses/valid_loss', val_loss, step) 155 | writer.add_scalar('losses/test_loss', test_loss, step) 156 | 157 | writer.add_scalar('accuracy/train_acc', train_acc, step) 158 | writer.add_scalar('accuracy/val_acc', val_acc, step) 159 | writer.add_scalar('accuracy/test_acc', test_acc, step) 160 | 161 | # append logger file 162 | logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc]) 163 | 164 | # save model 165 | is_best = val_acc > best_acc 166 | best_acc = max(val_acc, best_acc) 167 | save_checkpoint({ 168 | 'epoch': epoch + 1, 169 | 'state_dict': model.state_dict(), 170 | 'ema_state_dict': ema_model.state_dict(), 171 | 'acc': val_acc, 172 | 'best_acc': best_acc, 173 | 'optimizer' : optimizer.state_dict(), 174 | }, is_best) 175 | test_accs.append(test_acc) 176 | logger.close() 177 | writer.close() 178 | 179 | print('Best acc:') 180 | print(best_acc) 181 | 182 | print('Mean acc:') 183 | print(np.mean(test_accs[-20:])) 184 | 185 | 186 | def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda): 187 | 188 | batch_time = AverageMeter() 189 | data_time = AverageMeter() 190 | losses = AverageMeter() 191 | losses_x = AverageMeter() 192 | losses_u = AverageMeter() 193 | ws = AverageMeter() 194 | end = time.time() 195 | 196 | bar = Bar('Training', max=args.train_iteration) 197 | labeled_train_iter = iter(labeled_trainloader) 198 | unlabeled_train_iter = iter(unlabeled_trainloader) 199 | 200 | model.train() 201 | for batch_idx in range(args.train_iteration): 202 | try: 203 | inputs_x, targets_x = labeled_train_iter.next() 204 | except: 205 | labeled_train_iter = iter(labeled_trainloader) 206 | inputs_x, targets_x = labeled_train_iter.next() 207 | 208 | try: 209 | (inputs_u, inputs_u2), _ = unlabeled_train_iter.next() 210 | except: 211 | unlabeled_train_iter = iter(unlabeled_trainloader) 212 | (inputs_u, inputs_u2), _ = unlabeled_train_iter.next() 213 | 214 | # measure data loading time 215 | data_time.update(time.time() 
- end) 216 | 217 | batch_size = inputs_x.size(0) 218 | 219 | # Transform label to one-hot 220 | targets_x = torch.zeros(batch_size, 10).scatter_(1, targets_x.view(-1,1).long(), 1) 221 | 222 | if use_cuda: 223 | inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True) 224 | inputs_u = inputs_u.cuda() 225 | inputs_u2 = inputs_u2.cuda() 226 | 227 | 228 | with torch.no_grad(): 229 | # compute guessed labels of unlabel samples 230 | outputs_u = model(inputs_u) 231 | outputs_u2 = model(inputs_u2) 232 | p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 233 | pt = p**(1/args.T) 234 | targets_u = pt / pt.sum(dim=1, keepdim=True) 235 | targets_u = targets_u.detach() 236 | 237 | # mixup 238 | all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0) 239 | all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0) 240 | 241 | l = np.random.beta(args.alpha, args.alpha) 242 | 243 | l = max(l, 1-l) 244 | 245 | idx = torch.randperm(all_inputs.size(0)) 246 | 247 | input_a, input_b = all_inputs, all_inputs[idx] 248 | target_a, target_b = all_targets, all_targets[idx] 249 | 250 | mixed_input = l * input_a + (1 - l) * input_b 251 | mixed_target = l * target_a + (1 - l) * target_b 252 | 253 | # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 254 | mixed_input = list(torch.split(mixed_input, batch_size)) 255 | mixed_input = interleave(mixed_input, batch_size) 256 | 257 | logits = [model(mixed_input[0])] 258 | for input in mixed_input[1:]: 259 | logits.append(model(input)) 260 | 261 | # put interleaved samples back 262 | logits = interleave(logits, batch_size) 263 | logits_x = logits[0] 264 | logits_u = torch.cat(logits[1:], dim=0) 265 | 266 | Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration) 267 | 268 | loss = Lx + w * Lu 269 | 270 | # record loss 271 | losses.update(loss.item(), inputs_x.size(0)) 272 | losses_x.update(Lx.item(), inputs_x.size(0)) 273 | losses_u.update(Lu.item(), inputs_x.size(0)) 274 | ws.update(w, inputs_x.size(0)) 275 | 276 | # compute gradient and do SGD step 277 | optimizer.zero_grad() 278 | loss.backward() 279 | optimizer.step() 280 | ema_optimizer.step() 281 | 282 | # measure elapsed time 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | # plot progress 287 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format( 288 | batch=batch_idx + 1, 289 | size=args.train_iteration, 290 | data=data_time.avg, 291 | bt=batch_time.avg, 292 | total=bar.elapsed_td, 293 | eta=bar.eta_td, 294 | loss=losses.avg, 295 | loss_x=losses_x.avg, 296 | loss_u=losses_u.avg, 297 | w=ws.avg, 298 | ) 299 | bar.next() 300 | bar.finish() 301 | 302 | return (losses.avg, losses_x.avg, losses_u.avg,) 303 | 304 | def validate(valloader, model, criterion, epoch, use_cuda, mode): 305 | 306 | batch_time = AverageMeter() 307 | data_time = AverageMeter() 308 | losses = AverageMeter() 309 | top1 = AverageMeter() 310 | top5 = AverageMeter() 311 | 312 | # switch to evaluate mode 313 | model.eval() 314 | 315 | end = time.time() 316 | bar = Bar(f'{mode}', max=len(valloader)) 317 | with torch.no_grad(): 318 | for batch_idx, (inputs, targets) in enumerate(valloader): 319 | # measure data loading time 320 | data_time.update(time.time() - end) 321 | 322 | if use_cuda: 323 | inputs, targets = 
inputs.cuda(), targets.cuda(non_blocking=True) 324 | # compute output 325 | outputs = model(inputs) 326 | loss = criterion(outputs, targets) 327 | 328 | # measure accuracy and record loss 329 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5)) 330 | losses.update(loss.item(), inputs.size(0)) 331 | top1.update(prec1.item(), inputs.size(0)) 332 | top5.update(prec5.item(), inputs.size(0)) 333 | 334 | # measure elapsed time 335 | batch_time.update(time.time() - end) 336 | end = time.time() 337 | 338 | # plot progress 339 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 340 | batch=batch_idx + 1, 341 | size=len(valloader), 342 | data=data_time.avg, 343 | bt=batch_time.avg, 344 | total=bar.elapsed_td, 345 | eta=bar.eta_td, 346 | loss=losses.avg, 347 | top1=top1.avg, 348 | top5=top5.avg, 349 | ) 350 | bar.next() 351 | bar.finish() 352 | return (losses.avg, top1.avg) 353 | 354 | def save_checkpoint(state, is_best, checkpoint=args.out, filename='checkpoint.pth.tar'): 355 | filepath = os.path.join(checkpoint, filename) 356 | torch.save(state, filepath) 357 | if is_best: 358 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 359 | 360 | def linear_rampup(current, rampup_length=args.epochs): 361 | if rampup_length == 0: 362 | return 1.0 363 | else: 364 | current = np.clip(current / rampup_length, 0.0, 1.0) 365 | return float(current) 366 | 367 | class SemiLoss(object): 368 | def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch): 369 | probs_u = torch.softmax(outputs_u, dim=1) 370 | 371 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) 372 | Lu = torch.mean((probs_u - targets_u)**2) 373 | 374 | return Lx, Lu, args.lambda_u * linear_rampup(epoch) 375 | 376 | class WeightEMA(object): 377 | def __init__(self, model, ema_model, alpha=0.999): 378 | self.model = model 379 | self.ema_model = ema_model 380 | self.alpha = alpha 381 | self.params = list(model.state_dict().values()) 382 | self.ema_params = list(ema_model.state_dict().values()) 383 | self.wd = 0.02 * args.lr 384 | 385 | for param, ema_param in zip(self.params, self.ema_params): 386 | param.data.copy_(ema_param.data) 387 | 388 | def step(self): 389 | one_minus_alpha = 1.0 - self.alpha 390 | for param, ema_param in zip(self.params, self.ema_params): 391 | if ema_param.dtype==torch.float32: 392 | ema_param.mul_(self.alpha) 393 | ema_param.add_(param * one_minus_alpha) 394 | # customized weight decay 395 | param.mul_(1 - self.wd) 396 | 397 | def interleave_offsets(batch, nu): 398 | groups = [batch // (nu + 1)] * (nu + 1) 399 | for x in range(batch - sum(groups)): 400 | groups[-x - 1] += 1 401 | offsets = [0] 402 | for g in groups: 403 | offsets.append(offsets[-1] + g) 404 | assert offsets[-1] == batch 405 | return offsets 406 | 407 | 408 | def interleave(xy, batch): 409 | nu = len(xy) - 1 410 | offsets = interleave_offsets(batch, nu) 411 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] 412 | for i in range(1, nu + 1): 413 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 414 | return [torch.cat(v, dim=0) for v in xy] 415 | 416 | if __name__ == '__main__': 417 | main() 418 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from 
.eval import * 6 | 7 | # progress bar 8 | import os, sys 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 10 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if 
self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__(self, paths): 86 | '''paths is a dictionary of {name: filepath} pairs''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss', 'Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of a dataset. 3 | - init_params: network parameter initialization. 4 | - mkdir_p / AverageMeter: directory creation and running-average helpers. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | import torch.utils.data 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=1e-3) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count --------------------------------------------------------------------------------
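The repository ships no standalone evaluation script; checkpoints are only scored inside `train.py`. The sketch below (a hypothetical `eval_checkpoint.py`, not part of the repository) shows one way to reload the EMA weights written by `save_checkpoint` and score them on the CIFAR-10 test set. The checkpoint path follows the README's `--out cifar10@250` example and is an assumption, as is the batch size; a CUDA device is assumed, matching `create_model` in `train.py`.

```
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import models.wideresnet as models
import dataset.cifar10 as dataset
from utils import accuracy, AverageMeter

# Assumed checkpoint location, following the README's `--out cifar10@250` example.
checkpoint_path = 'cifar10@250/model_best.pth.tar'

# Reuse the repository's own transforms and dataset splits; only the test set is needed here.
transform_val = transforms.Compose([dataset.ToTensor()])
_, _, _, test_set = dataset.get_cifar10('./data', 250,
                                        transform_train=transform_val,
                                        transform_val=transform_val)
test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=0)

# WRN-28-2, same architecture as in train.py; load the EMA weights saved by save_checkpoint.
model = models.WideResNet(num_classes=10).cuda()
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['ema_state_dict'])
model.eval()

top1 = AverageMeter()
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
        outputs = model(inputs)
        prec1, = accuracy(outputs, targets, topk=(1,))
        top1.update(prec1.item(), inputs.size(0))
print('Test top-1 accuracy: %.2f%%' % top1.avg)
```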