├── lib
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── wideresnet.py
│   │   ├── resnet_cifar.py
│   │   └── resnet.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar.py
│   │   └── folder.py
│   ├── normalize.py
│   ├── NCECriterion.py
│   ├── LinearAverage.py
│   ├── alias_multinomial.py
│   ├── utils.py
│   └── NCEAverage.py
├── scripts
│   ├── instance_cifar10.sh
│   ├── finetune_cifar10.sh
│   ├── finetune_imagenet.sh
│   └── download_model.sh
├── LICENSE
├── SECURITY.md
├── test.py
├── .gitignore
├── README.md
├── unsupervised
│   ├── cifar.py
│   └── imagenet.py
├── notebooks
│   ├── knn-imagenet.ipynb
│   └── nc-colorization.ipynb
├── cifar-semi.py
└── imagenet-semi.py

/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # nothing
2 |
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .resnet_cifar import *
3 | from .wideresnet import *
4 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar import CIFAR10Instance, PseudoCIFAR10
2 | from .folder import ImageFolderInstance, PseudoDatasetFolder
3 |
4 | __all__ = ('CIFAR10Instance', 'PseudoCIFAR10', 'ImageFolderInstance', 'PseudoDatasetFolder')
5 |
6 |
--------------------------------------------------------------------------------
/scripts/instance_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONPATH=$PYTHONPATH:$(pwd)
6 |
7 | CUDA_VISIBLE_DEVICES=5 python unsupervised/cifar.py --lr-scheduler cosine-with-restart --epochs 1270
8 |
--------------------------------------------------------------------------------
/scripts/finetune_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num_labeled=250
4 | python cifar-semi.py \
5 |     --gpus 0 \
6 |     --num-labeled ${num_labeled} \
7 |     --pseudo-file checkpoint/pseudos/instance_nc_wrn-28-2/${num_labeled}_T_1.pth.tar \
8 |     --resume checkpoint/pretrain_models/ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \
9 |     --pseudo-ratio 0.2
10 |
--------------------------------------------------------------------------------
/lib/normalize.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class Normalize(nn.Module):
5 |
6 |     def __init__(self, power=2):
7 |         super(Normalize, self).__init__()
8 |         self.power = power
9 |
10 |     def forward(self, x):
11 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 12 | out = x.div(norm) 13 | return out 14 | -------------------------------------------------------------------------------- /scripts/finetune_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # num_labeled=13000 4 | # pseudo_ratio=0.1 5 | 6 | # or 7 | num_labeled=26000 8 | pseudo_ratio=0.2 9 | 10 | # or 11 | # num_labeled=51000 12 | # pseudo_ratio=0.5 13 | python imagenet-semi.py \ 14 | --arch resnet50 \ 15 | --gpus 1,2,6,7 \ 16 | --num-labeled ${num_labeled} \ 17 | --data-dir /home/liubin/data/imagenet \ 18 | --pretrained checkpoint/pretrain_models/lemniscate_resnet50.pth.tar \ 19 | --pseudo-dir checkpoint/pseudos_imagenet/instance_imagenet_nc_resnet50/num_labeled_${num_labeled} \ 20 | --pseudo-ratio ${pseudo_ratio} \ 21 | 22 | -------------------------------------------------------------------------------- /lib/NCECriterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | eps = 1e-7 5 | 6 | 7 | class NCECriterion(nn.Module): 8 | 9 | def __init__(self, n_lem): 10 | super(NCECriterion, self).__init__() 11 | self.n_lem = n_lem 12 | 13 | def forward(self, x, targets): 14 | batchSize = x.size(0) 15 | K = x.size(1) - 1 16 | Pnt = 1 / float(self.n_lem) 17 | Pns = 1 / float(self.n_lem) 18 | 19 | # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt) 20 | Pmt = x.select(1, 0) 21 | Pmt_div = Pmt.add(K * Pnt + eps) 22 | lnPmt = torch.div(Pmt, Pmt_div) 23 | 24 | # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns) 25 | Pon_div = x.narrow(1, 1, K).add(K * Pns + eps) 26 | Pon = Pon_div.clone().fill_(K * Pns) 27 | lnPon = torch.div(Pon, Pon_div) 28 | 29 | # equation 6 in ref. A 30 | lnPmt.log_() 31 | lnPon.log_() 32 | 33 | lnPmtsum = lnPmt.sum(0) 34 | lnPonsum = lnPon.view(-1, 1).sum(0) 35 | 36 | loss = - (lnPmtsum + lnPonsum) / batchSize 37 | 38 | return loss 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /lib/LinearAverage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | import math 5 | 6 | 7 | class LinearAverageOp(Function): 8 | @staticmethod 9 | def forward(self, x, y, memory, params): 10 | T = params[0].item() 11 | 12 | # inner product 13 | out = torch.mm(x.data, memory.t()) 14 | out.div_(T) # batchSize * N 15 | 16 | self.save_for_backward(x, memory, y, params) 17 | 18 | return out 19 | 20 | @staticmethod 21 | def backward(self, gradOutput): 22 | x, memory, y, params = self.saved_tensors 23 | T = params[0].item() 24 | momentum = params[1].item() 25 | 26 | # add temperature 27 | gradOutput.data.div_(T) 28 | 29 | # gradient of linear 30 | gradInput = torch.mm(gradOutput.data, memory) 31 | gradInput.resize_as_(x) 32 | 33 | # update the non-parametric data 34 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x) 35 | weight_pos.mul_(momentum) 36 | weight_pos.add_(torch.mul(x.data, 1 - momentum)) 37 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5) 38 | updated_weight = weight_pos.div(w_norm) 39 | memory.index_copy_(0, y, updated_weight) 40 | 41 | return gradInput, None, None, None 42 | 43 | 44 | class LinearAverage(nn.Module): 45 | 46 | def __init__(self, inputSize, outputSize, T=0.07, momentum=0.5): 47 | super(LinearAverage, self).__init__() 48 | self.nLem = outputSize 49 | 50 | self.register_buffer('params', torch.tensor([T, momentum])) 51 | stdv = 1. / math.sqrt(inputSize / 3) 52 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 53 | 54 | def forward(self, x, y): 55 | out = LinearAverageOp.apply(x, y, self.memory, self.params) 56 | return out 57 | -------------------------------------------------------------------------------- /lib/alias_multinomial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AliasMethod(object): 5 | """ 6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 7 | """ 8 | 9 | def __init__(self, probs): 10 | 11 | if probs.sum() > 1: 12 | probs.div_(probs.sum()) 13 | K = len(probs) 14 | self.prob = torch.zeros(K) 15 | self.alias = torch.LongTensor([0] * K) 16 | 17 | # Sort the data into the outcomes with probabilities 18 | # that are larger and smaller than 1/K. 19 | smaller = [] 20 | larger = [] 21 | for kk, prob in enumerate(probs): 22 | self.prob[kk] = K * prob 23 | if self.prob[kk] < 1.0: 24 | smaller.append(kk) 25 | else: 26 | larger.append(kk) 27 | 28 | # Loop though and create little binary mixtures that 29 | # appropriately allocate the larger outcomes over the 30 | # overall uniform mixture. 
31 | while len(smaller) > 0 and len(larger) > 0: 32 | small = smaller.pop() 33 | large = larger.pop() 34 | 35 | self.alias[small] = large 36 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 37 | 38 | if self.prob[large] < 1.0: 39 | smaller.append(large) 40 | else: 41 | larger.append(large) 42 | 43 | for last_one in smaller + larger: 44 | self.prob[last_one] = 1 45 | 46 | def cuda(self): 47 | self.prob = self.prob.cuda() 48 | self.alias = self.alias.cuda() 49 | 50 | def draw(self, N): 51 | """ 52 | Draw N samples from multinomial 53 | """ 54 | K = self.alias.size(0) 55 | 56 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 57 | prob = self.prob.index_select(0, kk) 58 | alias = self.alias.index_select(0, kk) 59 | # b is whether a random number is greater than q 60 | b = torch.bernoulli(prob) 61 | oq = kk.mul(b.long()) 62 | oj = alias.mul((1 - b).long()) 63 | 64 | return oq + oj 65 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import CosineAnnealingLR 5 | 6 | 7 | # noinspection PyAttributeOutsideInit 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | class CosineAnnealingLRWithRestart(CosineAnnealingLR): 28 | """Adjust learning rate""" 29 | 30 | def __init__(self, optimizer, eta_min=0, lr_t_0=10, lr_t_mul=2, last_epoch=-1): 31 | self.eta_min = eta_min 32 | self.lr_t_curr = lr_t_0 33 | self.lr_t_mul = lr_t_mul 34 | self.last_reset = 0 35 | super(CosineAnnealingLRWithRestart, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | curr_epoch = self.last_epoch - self.last_reset 39 | if curr_epoch >= self.lr_t_curr: 40 | self.lr_t_curr *= self.lr_t_mul 41 | self.last_reset = self.last_epoch 42 | rate = 0 43 | else: 44 | rate = curr_epoch * math.pi / self.lr_t_curr 45 | return [self.eta_min + 0.5 * (base_lr - self.eta_min) * (1.0 + math.cos(rate)) 46 | for base_lr in self.base_lrs] 47 | 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | """Computes the precision@k for the specified values of k""" 51 | with torch.no_grad(): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 62 | res.append(correct_k.mul_(100.0 / batch_size)) 63 | return res 64 | 65 | 66 | train_labels_ = None 67 | 68 | 69 | def get_train_labels(trainloader, device='cuda'): 70 | global train_labels_ 71 | if train_labels_ is None: 72 | print("=> loading all train labels") 73 | train_labels = -1 * torch.ones([len(trainloader.dataset)], dtype=torch.long) 74 | for i, (_, label, index) in enumerate(trainloader): 75 | train_labels[index] = label 76 | if i % 10000 == 0: 77 | print("{}/{}".format(i, len(trainloader))) 78 | assert all(train_labels != -1) 79 | train_labels_ = train_labels.to(device) 80 | return train_labels_ 81 | 
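# A minimal usage sketch for the helpers above (illustrative only; `net`,
# `loader`, and `criterion` are assumed stand-ins, not names from this repo).
# CosineAnnealingLRWithRestart doubles its cycle length after each restart
# (10, 20, 40, ... epochs with the defaults):
#
#     optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
#     scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.0,
#                                              lr_t_0=10, lr_t_mul=2)
#     loss_meter = AverageMeter()
#     for epoch in range(num_epochs):
#         for x, y in loader:
#             loss = criterion(net(x), y)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             loss_meter.update(loss.item(), n=x.size(0))
#         scheduler.step()  # advance the cosine-with-restart schedule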
-------------------------------------------------------------------------------- /scripts/download_model.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | 4 | base_url="https://frontiers.blob.core.windows.net/metric-transfer" 5 | local_root=checkpoint 6 | 7 | mkdir -p log 8 | 9 | # you can comment some file if you don't want to download all of them. 10 | echo "downloading pretrained models" 11 | dirname=pretrain_models 12 | mkdir -p ${local_root}/${dirname} 13 | for filename in \ 14 | ckpt_colorization_wrn-28-2.pth.tar \ 15 | ckpt_instance_cifar10_wrn-28-10_89.83.pth.tar \ 16 | ckpt_imagenet32x32_instance_wrn-28-2.pth.tar \ 17 | ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \ 18 | ckpt_imagenet32x32_snca_wrn-28-2.pth.tar \ 19 | lemniscate_resnet18.pth.tar \ 20 | ckpt_imagenet32x32_softmax_wrn-28-2.pth.tar \ 21 | lemniscate_resnet50.pth.tar \ 22 | ckpt_instance_cifar10_resnet18_85.69.pth.tar; 23 | do 24 | file=${dirname}/${filename}; 25 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber & 26 | done 27 | wait 28 | 29 | 30 | echo "downloading pre-extracted features" 31 | dirname=train_features_labels_cache 32 | mkdir -p ${local_root}/${dirname} 33 | for filename in \ 34 | colorization_embedding_128.t7 \ 35 | instance_imagenet_val_feature_resnet50.pth.tar \ 36 | instance_imagenet_train_feature_resnet50.pth.tar; 37 | do 38 | file=${dirname}/${filename}; 39 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber & 40 | done 41 | wait 42 | 43 | 44 | echo "downloading pseudo file for cifar10 dataset" 45 | dirname=pseudos 46 | mkdir -p ${local_root}/${dirname} 47 | for filename in \ 48 | colorization_knn_wrn-28-2.tar \ 49 | imagenet32x32_snca_nc_wrn-28-2.tar \ 50 | instance_nc_wrn-28-2.tar \ 51 | colorization_nc_wrn-28-2.tar \ 52 | imagenet32x32_softmax_nc_wrn-28-2.tar \ 53 | imagenet32x32_instance_nc_wrn-28-2.tar \ 54 | instance_knn_wrn-28-2.tar; 55 | do 56 | file=${dirname}/${filename}; 57 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber & 58 | done 59 | wait 60 | 61 | echo "downloading pseudo file for imagenet dataset" 62 | dirname=pseudos_imagenet/instance_imagenet_nc_resnet50 63 | mkdir -p ${local_root}/${dirname} 64 | for filename in \ 65 | num_labeled_13000.tar \ 66 | num_labeled_26000.tar \ 67 | num_labeled_51000.tar; 68 | do 69 | file=${dirname}/${filename}; 70 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/pseudos_imagenet_${filename}.txt --no-clobber & 71 | done 72 | wait 73 | 74 | echo "download finished, extracting" 75 | for folder in pseudos pseudos_imagenet/instance_imagenet_nc_resnet50; do 76 | ( 77 | cd ${local_root}/${folder}; 78 | for i in $(ls *.tar); do 79 | tar xvf $i; 80 | rm $i; 81 | done 82 | ) 83 | done 84 | 85 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /lib/NCEAverage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | from .alias_multinomial import AliasMethod 5 | import math 6 | 7 | 8 | class NCEFunction(Function): 9 | @staticmethod 10 | def forward(self, x, y, memory, idx, params): 11 | K = int(params[0].item()) 12 | T = params[1].item() 13 | Z = params[2].item() 14 | 15 | batchSize = x.size(0) 16 | outputSize = memory.size(0) 17 | inputSize = memory.size(1) 18 | 19 | # sample positives & negatives 20 | idx.select(1, 0).copy_(y.data) 21 | 22 | # sample correspoinding weights 23 | weight = torch.index_select(memory, 0, idx.view(-1)) 24 | weight.resize_(batchSize, K + 1, inputSize) 25 | 26 | # inner product 27 | out = torch.bmm(weight, x.data.resize_(batchSize, inputSize, 1)) 28 | out.div_(T).exp_() # batchSize * self.K+1 29 | x.data.resize_(batchSize, inputSize) 30 | 31 | if Z < 0: 32 | params[2] = out.mean() * outputSize 33 | Z = params[2].item() 34 | print("normalization constant Z is set to {:.1f}".format(Z)) 35 | 36 | out.div_(Z).resize_(batchSize, K + 1) 37 | 38 | self.save_for_backward(x, memory, y, weight, out, params) 39 | 40 | return out 41 | 42 | @staticmethod 43 | def backward(self, gradOutput): 44 | x, memory, y, weight, out, params = self.saved_tensors 45 | K = int(params[0].item()) 46 | T = params[1].item() 47 | momentum = params[3].item() 48 | batchSize = gradOutput.size(0) 49 | 50 | # gradients d Pm / d linear = exp(linear) / Z 51 | gradOutput.data.mul_(out.data) 52 | # add temperature 53 | gradOutput.data.div_(T) 54 | 55 | gradOutput.data.resize_(batchSize, 1, K + 1) 56 | 57 | # gradient of linear 58 | gradInput = torch.bmm(gradOutput.data, weight) 59 | gradInput.resize_as_(x) 60 | 61 | # update the non-parametric data 62 | weight_pos = weight.select(1, 0).resize_as_(x) 63 | weight_pos.mul_(momentum) 64 | weight_pos.add_(torch.mul(x.data, 1 - momentum)) 65 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5) 66 | updated_weight = weight_pos.div(w_norm) 67 | memory.index_copy_(0, y, updated_weight) 68 | 69 | return gradInput, None, None, None, None 70 | 71 | 72 | class NCEAverage(nn.Module): 73 | 74 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5): 75 | super(NCEAverage, self).__init__() 76 | self.nLem = outputSize 77 | self.unigrams = torch.ones(self.nLem) 78 | self.multinomial = AliasMethod(self.unigrams) 79 | self.multinomial.cuda() 80 | self.K = K 81 | 82 | self.register_buffer('params', torch.tensor([K, T, -1, momentum])) 83 | stdv = 1. / math.sqrt(inputSize / 3) 84 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 85 | 86 | def forward(self, x, y): 87 | batchSize = x.size(0) 88 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 89 | out = NCEFunction.apply(x, y, self.memory, idx, self.params) 90 | return out 91 | -------------------------------------------------------------------------------- /lib/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import torchvision.datasets as datasets 4 | import torch.utils.data as data 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class CIFAR10Instance(datasets.CIFAR10): 10 | """CIFAR10Instance Dataset. 
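    Identical to torchvision's CIFAR10, except that ``__getitem__`` also
    returns the sample index, so per-instance memory banks such as
    ``NCEAverage``/``LinearAverage`` can be updated in place.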
11 | """ 12 | 13 | def __getitem__(self, index): 14 | if self.train: 15 | img, target = self.data[index], self.targets[index] 16 | else: 17 | img, target = self.data[index], self.targets[index] 18 | 19 | # doing this so that it is consistent with all other datasets 20 | # to return a PIL Image 21 | img = Image.fromarray(img) 22 | 23 | if self.transform is not None: 24 | img = self.transform(img) 25 | 26 | if self.target_transform is not None: 27 | target = self.target_transform(target) 28 | 29 | return img, target, index 30 | 31 | 32 | class PseudoCIFAR10(datasets.CIFAR10): 33 | """CIFAR10Instance Dataset. 34 | """ 35 | 36 | def __init__(self, labeled_indexes, **kwargs): 37 | super(PseudoCIFAR10, self).__init__(**kwargs) 38 | assert self.train 39 | self.labeled_indexes = labeled_indexes.cpu().numpy().copy() 40 | self.C = 10 41 | self.labels = np.array(self.targets)[self.labeled_indexes] 42 | self.indexes = self.labeled_indexes 43 | 44 | def __len__(self): 45 | return self.indexes.shape[0] 46 | 47 | def set_pseudo(self, pseudo_indexes, pseudo_labels): 48 | assert pseudo_indexes.shape == pseudo_labels.shape 49 | 50 | self.labels = np.concatenate( 51 | [np.array(self.targets)[self.labeled_indexes], pseudo_labels.cpu().numpy().copy()], axis=0) 52 | self.indexes = np.concatenate([self.labeled_indexes, pseudo_indexes.cpu().numpy().copy()], axis=0) 53 | 54 | def __getitem__(self, index): 55 | real_index = self.indexes[index] 56 | img = self.data[real_index] 57 | target = self.labels[index] 58 | 59 | # doing this so that it is consistent with all other datasets 60 | # to return a PIL Image 61 | img = Image.fromarray(img) 62 | 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | 69 | return img, target 70 | 71 | 72 | if __name__ == '__main__': 73 | import torchvision.transforms as transforms 74 | 75 | _labeled_indexes = torch.arange(10) 76 | 77 | transform_train = transforms.Compose([ 78 | transforms.ToTensor(), 79 | ]) 80 | ds = PseudoCIFAR10( 81 | labeled_indexes=_labeled_indexes, 82 | root='./data', 83 | transform=transform_train, 84 | download=True) 85 | loader = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=True, num_workers=0) 86 | assert len(loader) == 2 87 | for i, (_img, _target) in enumerate(loader): 88 | print(_img.shape, _target) 89 | break 90 | 91 | # test pseudo 92 | _pseudo_indexes = torch.arange(100, 200) 93 | _pseudo_labels = torch.zeros([100]) 94 | loader.dataset.set_pseudo(_pseudo_indexes, _pseudo_labels) 95 | assert len(loader) == 22 # (100 + 10) / 5 96 | for i, (_img, _target) in enumerate(loader): 97 | print(_img.shape, _target) 98 | break 99 | -------------------------------------------------------------------------------- /lib/datasets/folder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class ImageFolderInstance(datasets.ImageFolder): 7 | """: Folder datasets which returns the index of the image as well:: 8 | """ 9 | def __getitem__(self, index): 10 | """ 11 | Args: 12 | index (int): Index 13 | Returns: 14 | tuple: (image, target) where target is class_index of the target class. 
15 | """ 16 | path, target = self.imgs[index] 17 | img = self.loader(path) 18 | if self.transform is not None: 19 | img = self.transform(img) 20 | if self.target_transform is not None: 21 | target = self.target_transform(target) 22 | 23 | return img, target, index 24 | 25 | 26 | class PseudoDatasetFolder(Dataset): 27 | 28 | def __init__(self, ds, labeled_indexes): 29 | self.ds = ds 30 | self.labeled_indexes = labeled_indexes 31 | self.num_labeled = len(self.labeled_indexes) 32 | # self.labeled_indexes_set = set(labeled_indexes.cpu().numpy()) 33 | self.pseudo_indexes = [] 34 | self.pseudo_labels = None 35 | 36 | def __len__(self): 37 | return self.num_labeled + len(self.pseudo_indexes) 38 | 39 | def __getitem__(self, index): 40 | 41 | if index < self.num_labeled: 42 | # labeled 43 | real_index = self.labeled_indexes[index] 44 | sample, target = self.ds[real_index] 45 | else: 46 | # pseudo 47 | real_index = self.pseudo_indexes[index - self.num_labeled] 48 | sample, _ = self.ds[real_index] 49 | target = self.pseudo_labels[index - self.num_labeled] 50 | return sample, target 51 | 52 | def set_pseudo(self, pseudo_indexes, pseudo_labels): 53 | assert len(pseudo_indexes) == len(pseudo_labels) 54 | self.pseudo_indexes = pseudo_indexes 55 | if isinstance(pseudo_labels, torch.Tensor): 56 | pseudo_labels = pseudo_labels.cpu().numpy() 57 | self.pseudo_labels = pseudo_labels 58 | 59 | 60 | if __name__ == '__main__': 61 | from torchvision import datasets 62 | import torchvision.transforms as transforms 63 | transform_test = transforms.Compose([ 64 | transforms.Resize(256), 65 | transforms.CenterCrop(224), 66 | transforms.ToTensor(), 67 | ]) 68 | trainset = datasets.ImageFolder('/home/liubin/data/imagenet/train/', transform=transform_test) 69 | # test list 70 | labeled_indexes_ = [1] 71 | pseudo_indexes_, pseudo_labels_ = [2], [10] 72 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_) 73 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_) 74 | for i, (_, target_) in enumerate(pseudo_trainset): 75 | if i == 0: 76 | assert target_ == 0 77 | else: 78 | assert target_ == 10 79 | 80 | # test np array 81 | import numpy as np 82 | labeled_indexes_ = np.array([1]) 83 | pseudo_indexes_, pseudo_labels_ = np.array([2]), np.array([10]) 84 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_) 85 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_) 86 | for i, (_, target_) in enumerate(pseudo_trainset): 87 | if i == 0: 88 | assert target_ == 0 89 | else: 90 | assert target_ == 10 91 | 92 | # test torch tensor 93 | n = len(trainset) 94 | num_labeled = n // 2 95 | labeled_indexes_ = torch.arange(num_labeled) 96 | pseudo_indexes_ = torch.arange(num_labeled, n) 97 | pseudo_labels_ = torch.zeros([n - num_labeled], dtype=torch.int64) 98 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_) 99 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_) 100 | assert pseudo_trainset[0][1] == trainset.samples[0][1] 101 | assert pseudo_trainset[num_labeled][1] == 0 102 | 103 | # test loader 104 | pseudo_trainloder = torch.utils.data.DataLoader( 105 | pseudo_trainset, batch_size=256, 106 | shuffle=True, num_workers=8) 107 | 108 | for data_, target_ in pseudo_trainloder: 109 | print(data_, target_) 110 | break 111 | -------------------------------------------------------------------------------- /lib/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 
| 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from lib.normalize import Normalize 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0): 12 | super(BasicBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(out_planes) 18 | self.relu2 = nn.ReLU(inplace=True) 19 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 20 | padding=1, bias=False) 21 | self.droprate = drop_rate 22 | self.equalInOut = (in_planes == out_planes) 23 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 24 | padding=0, bias=False) or None 25 | 26 | def forward(self, x): 27 | if not self.equalInOut: 28 | x = self.relu1(self.bn1(x)) 29 | else: 30 | out = self.relu1(self.bn1(x)) 31 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 32 | if self.droprate > 0: 33 | out = F.dropout(out, p=self.droprate, training=self.training) 34 | out = self.conv2(out) 35 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 36 | 37 | 38 | class NetworkBlock(nn.Module): 39 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 40 | super(NetworkBlock, self).__init__() 41 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 42 | 43 | @staticmethod 44 | def _make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate): 45 | layers = [] 46 | for i in range(int(nb_layers)): 47 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 48 | return nn.Sequential(*layers) 49 | 50 | def forward(self, x): 51 | return self.layer(x) 52 | 53 | 54 | class WideResNet(nn.Module): 55 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, norm=True): 56 | super(WideResNet, self).__init__() 57 | n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 58 | assert ((depth - 4) % 6 == 0) 59 | n = (depth - 4) / 6 60 | block = BasicBlock 61 | # 1st conv before any network block 62 | self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1, 63 | padding=1, bias=False) 64 | # 1st block 65 | self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2, dropRate) 70 | # global average pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(n_channels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.fc = nn.Linear(n_channels[3], num_classes) 74 | self.nChannels = n_channels[3] 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | elif isinstance(m, nn.Linear): 84 | m.bias.data.zero_() 85 | 86 | self.l2norm = Normalize(2) 87 | self.norm = norm 88 | 89 | def forward(self, x): 90 | out = self.conv1(x) 91 | out = self.block1(out) 92 | out = self.block2(out) 93 | out = self.block3(out) 94 | out = self.relu(self.bn1(out)) 95 | out = F.avg_pool2d(out, 8) 96 | out = out.view(-1, self.nChannels) 97 | out = self.fc(out) 98 | if self.norm: 99 | out = self.l2norm(out) 100 | return out 101 | -------------------------------------------------------------------------------- /lib/models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from lib.normalize import Normalize 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion * planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion * planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion * planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion * planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, low_dim=128, norm=True): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = 
self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512 * block.expansion, low_dim) 81 | self.l2norm = Normalize(2) 82 | self.norm = norm 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1] * (num_blocks - 1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = F.avg_pool2d(out, 4) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | if self.norm: 102 | out = self.l2norm(out) 103 | return out 104 | 105 | 106 | def resnet18_cifar(low_dim=128, norm=True): 107 | return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], low_dim=low_dim, norm=norm) 108 | 109 | 110 | def resnet34_cifar(low_dim=128, norm=True): 111 | return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], low_dim=low_dim, norm=norm) 112 | 113 | 114 | def resnet50_cifar10(low_dim=128, norm=True): 115 | return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], low_dim=low_dim, norm=norm) 116 | 117 | 118 | def resnet101_cifar10(low_dim=128, norm=True): 119 | return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], low_dim=low_dim, norm=norm) 120 | 121 | 122 | def resnet152_cifar10(low_dim=128, norm=True): 123 | return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], low_dim=low_dim, norm=norm) 124 | 125 | 126 | if __name__ == '__main__': 127 | import numpy as np 128 | 129 | inputs = torch.randn(10, 3, 32, 32) 130 | y = resnet18_cifar(low_dim=1024, norm=True)(inputs) 131 | assert y.shape == (10, 1024) 132 | np.testing.assert_array_almost_equal(y.pow(2).sum(1).detach().numpy(), np.ones([10])) 133 | 134 | # test no norm 135 | y = resnet18_cifar(low_dim=1024, norm=False)(inputs) 136 | assert y.shape == (10, 1024) 137 | print(y.pow(2).sum(1).detach().numpy()) 138 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | from lib.utils import AverageMeter, get_train_labels, accuracy 6 | 7 | 8 | def NN(net, lemniscate, trainloader, testloader, recompute_memory=0): 9 | net.eval() 10 | net_time = AverageMeter() 11 | cls_time = AverageMeter() 12 | correct = 0. 
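    # counters for top-1 accuracy of single nearest-neighbour retrieval
    # against the lemniscate memory bank (one stored feature per train image)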
13 | total = 0 14 | testsize = testloader.dataset.__len__() 15 | 16 | train_features = lemniscate.memory.t() 17 | if hasattr(trainloader.dataset, 'imgs'): 18 | train_labels = torch.LongTensor( 19 | [y for (p, y) in trainloader.dataset.imgs]).cuda() 20 | else: 21 | train_labels = get_train_labels(trainloader) 22 | if recompute_memory: 23 | transform_bak = trainloader.dataset.transform 24 | trainloader.dataset.transform = testloader.dataset.transform 25 | temploader = torch.utils.data.DataLoader( 26 | trainloader.dataset, batch_size=100, shuffle=False, num_workers=1) 27 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader): 28 | batch_size = inputs.size(0) 29 | features = net(inputs) 30 | train_features[:, batch_idx * batch_size:batch_idx * 31 | batch_size + batch_size] = features.data.t() 32 | train_labels = get_train_labels(trainloader) 33 | trainloader.dataset.transform = transform_bak 34 | 35 | end = time.time() 36 | with torch.no_grad(): 37 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 38 | targets = targets.cuda(non_blocking=True) 39 | batch_size = inputs.size(0) 40 | features = net(inputs) 41 | net_time.update(time.time() - end) 42 | end = time.time() 43 | 44 | dist = torch.mm(features, train_features) 45 | 46 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True) 47 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 48 | retrieval = torch.gather(candidates, 1, yi) 49 | 50 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1) 51 | 52 | total += targets.size(0) 53 | correct += retrieval.eq(targets.data).sum().item() 54 | 55 | cls_time.update(time.time() - end) 56 | end = time.time() 57 | 58 | print(f'Test [{total}/{testsize}]\t' 59 | f'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t' 60 | f'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t' 61 | f'Top1: {correct * 100. / total:.2f}') 62 | 63 | return correct / total 64 | 65 | 66 | def kNN(net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0): 67 | net.eval() 68 | net_time = AverageMeter() 69 | cls_time = AverageMeter() 70 | total = 0 71 | testsize = testloader.dataset.__len__() 72 | 73 | train_features = lemniscate.memory.t() 74 | if hasattr(trainloader.dataset, 'imgs'): 75 | train_labels = torch.LongTensor( 76 | [y for (p, y) in trainloader.dataset.imgs]).cuda() 77 | else: 78 | train_labels = get_train_labels(trainloader) 79 | C = train_labels.max() + 1 80 | 81 | if recompute_memory: 82 | transform_bak = trainloader.dataset.transform 83 | trainloader.dataset.transform = testloader.dataset.transform 84 | temploader = torch.utils.data.DataLoader( 85 | trainloader.dataset, batch_size=100, shuffle=False, num_workers=1) 86 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader): 87 | bs = inputs.size(0) 88 | features = net(inputs) 89 | train_features[:, batch_idx * bs:batch_idx * 90 | bs + bs] = features.data.t() 91 | train_labels = get_train_labels(trainloader) 92 | trainloader.dataset.transform = transform_bak 93 | 94 | top1 = 0. 95 | top5 = 0. 
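    # Weighted kNN classification (instance discrimination, Wu et al. 2018):
    # each of the K nearest memory-bank features votes for its class with
    # weight exp(similarity / sigma); class scores are the summed weights.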
96 | with torch.no_grad(): 97 | retrieval_one_hot = torch.zeros(K, C).cuda() 98 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 99 | end = time.time() 100 | targets = targets.cuda(non_blocking=True) 101 | bs = inputs.size(0) 102 | features = net(inputs) 103 | net_time.update(time.time() - end) 104 | end = time.time() 105 | 106 | dist = torch.mm(features, train_features) 107 | 108 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True) 109 | candidates = train_labels.view(1, -1).expand(bs, -1) 110 | retrieval = torch.gather(candidates, 1, yi) 111 | 112 | retrieval_one_hot.resize_(bs * K, C).zero_() 113 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 114 | yd_transform = yd.clone().div_(sigma).exp_() 115 | probs = torch.sum(torch.mul(retrieval_one_hot.view( 116 | bs, -1, C), yd_transform.view(bs, -1, 1)), 1) 117 | _, predictions = probs.sort(1, True) 118 | 119 | # Find which predictions match the target 120 | correct = predictions.eq(targets.data.view(-1, 1)) 121 | cls_time.update(time.time() - end) 122 | 123 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 124 | top5 = top5 + correct.narrow(1, 0, 2).sum().item() 125 | 126 | total += targets.size(0) 127 | 128 | if batch_idx % 100 == 0: 129 | print(f'Test [{total}/{testsize}]\t' 130 | f'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t' 131 | f'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t' 132 | f'Top1: {top1 * 100. / total:.2f} top5: {top5 * 100. / total:.2f}') 133 | 134 | print(top1 * 100. / total) 135 | 136 | return top1 / total 137 | 138 | 139 | def validate(val_loader, model, criterion, device='cpu', print_freq=100): 140 | batch_time = AverageMeter() 141 | losses = AverageMeter() 142 | top1 = AverageMeter() 143 | top5 = AverageMeter() 144 | 145 | # switch to evaluate mode 146 | model.eval() 147 | 148 | with torch.no_grad(): 149 | end = time.time() 150 | for i, (data, target) in enumerate(val_loader): 151 | data, target = data.to(device), target.to(device) 152 | 153 | # compute output 154 | output = model(data) 155 | loss = criterion(output, target) 156 | 157 | # measure accuracy and record loss 158 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 159 | losses.update(loss.item(), data.size(0)) 160 | top1.update(prec1[0], data.size(0)) 161 | top5.update(prec5[0], data.size(0)) 162 | 163 | # measure elapsed time 164 | batch_time.update(time.time() - end) 165 | end = time.time() 166 | 167 | if i % print_freq == 0: 168 | print(f'Test: [{i}/{len(val_loader)}] ' 169 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 170 | f'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 171 | f'Prec@1 {top1.val:.3f} ({top1.avg:.3f}) ' 172 | f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})') 173 | 174 | print(f' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}') 175 | 176 | return top1.avg 177 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.suo 8 | *.user 9 | *.userosscache 10 | *.sln.docstates 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015/2017 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # Visual Studio 2017 auto generated files 33 | Generated\ Files/ 34 | 35 | # MSTest test Results 36 | [Tt]est[Rr]esult*/ 37 | [Bb]uild[Ll]og.* 38 | 39 | # NUNIT 40 | *.VisualState.xml 41 | TestResult.xml 42 | 43 | # Build Results of an ATL Project 44 | [Dd]ebugPS/ 45 | [Rr]eleasePS/ 46 | dlldata.c 47 | 48 | # Benchmark Results 49 | BenchmarkDotNet.Artifacts/ 50 | 51 | # .NET Core 52 | project.lock.json 53 | project.fragment.lock.json 54 | artifacts/ 55 | **/Properties/launchSettings.json 56 | 57 | # StyleCop 58 | StyleCopReport.xml 59 | 60 | # Files built by Visual Studio 61 | *_i.c 62 | *_p.c 63 | *_i.h 64 | *.ilk 65 | *.meta 66 | *.obj 67 | *.iobj 68 | *.pch 69 | *.pdb 70 | *.ipdb 71 | *.pgc 72 | *.pgd 73 | *.rsp 74 | *.sbr 75 | *.tlb 76 | *.tli 77 | *.tlh 78 | *.tmp 79 | *.tmp_proj 80 | *.log 81 | *.vspscc 82 | *.vssscc 83 | .builds 84 | *.pidb 85 | *.svclog 86 | *.scc 87 | 88 | # Chutzpah Test files 89 | _Chutzpah* 90 | 91 | # Visual C++ cache files 92 | ipch/ 93 | *.aps 94 | *.ncb 95 | *.opendb 96 | *.opensdf 97 | *.sdf 98 | *.cachefile 99 | *.VC.db 100 | *.VC.VC.opendb 101 | 102 | # Visual Studio profiler 103 | *.psess 104 | *.vsp 105 | *.vspx 106 | *.sap 107 | 108 | # Visual Studio Trace Files 109 | *.e2e 110 | 111 | # TFS 2012 Local Workspace 112 | $tf/ 113 | 114 | # Guidance Automation Toolkit 115 | *.gpState 116 | 117 | # ReSharper is a .NET coding add-in 118 | _ReSharper*/ 119 | *.[Rr]e[Ss]harper 120 | *.DotSettings.user 121 | 122 | # JustCode is a .NET coding add-in 123 | .JustCode 124 | 125 | # TeamCity is a build add-in 126 | _TeamCity* 127 | 128 | # DotCover is a Code Coverage Tool 129 | *.dotCover 130 | 131 | # AxoCover is a Code Coverage Tool 132 | .axoCover/* 133 | !.axoCover/settings.json 134 | 135 | # Visual Studio code coverage results 136 | *.coverage 137 | *.coveragexml 138 | 139 | # NCrunch 140 | _NCrunch_* 141 | .*crunch*.local.xml 142 | nCrunchTemp_* 143 | 144 | # MightyMoose 145 | *.mm.* 146 | AutoTest.Net/ 147 | 148 | # Web workbench (sass) 149 | .sass-cache/ 150 | 151 | # Installshield output folder 152 | [Ee]xpress/ 153 | 154 | # DocProject is a documentation generator add-in 155 | DocProject/buildhelp/ 156 | DocProject/Help/*.HxT 157 | DocProject/Help/*.HxC 158 | DocProject/Help/*.hhc 159 | DocProject/Help/*.hhk 160 | DocProject/Help/*.hhp 161 | DocProject/Help/Html2 162 | DocProject/Help/html 163 | 164 | # Click-Once directory 165 | publish/ 166 | 167 | # Publish Web Output 168 | *.[Pp]ublish.xml 169 | *.azurePubxml 170 | # Note: Comment the next line if you want to checkin your web deploy settings, 171 | # but database connection strings (with potential passwords) will be unencrypted 172 | *.pubxml 173 | *.publishproj 174 | 175 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 176 | # checkin your Azure Web App publish settings, but sensitive information contained 177 | # in these scripts will be unencrypted 178 | PublishScripts/ 179 | 180 | # NuGet Packages 181 | *.nupkg 182 | # The packages folder can be ignored because of Package Restore 183 | **/[Pp]ackages/* 184 | # except build/, which is used as an MSBuild target. 185 | !**/[Pp]ackages/build/ 186 | # Uncomment if necessary however generally it will be regenerated when needed 187 | #!**/[Pp]ackages/repositories.config 188 | # NuGet v3's project.json files produces more ignorable files 189 | *.nuget.props 190 | *.nuget.targets 191 | 192 | # Microsoft Azure Build Output 193 | csx/ 194 | *.build.csdef 195 | 196 | # Microsoft Azure Emulator 197 | ecf/ 198 | rcf/ 199 | 200 | # Windows Store app package directories and files 201 | AppPackages/ 202 | BundleArtifacts/ 203 | Package.StoreAssociation.xml 204 | _pkginfo.txt 205 | *.appx 206 | 207 | # Visual Studio cache files 208 | # files ending in .cache can be ignored 209 | *.[Cc]ache 210 | # but keep track of directories ending in .cache 211 | !*.[Cc]ache/ 212 | 213 | # Others 214 | ClientBin/ 215 | ~$* 216 | *~ 217 | *.dbmdl 218 | *.dbproj.schemaview 219 | *.jfm 220 | *.pfx 221 | *.publishsettings 222 | orleans.codegen.cs 223 | 224 | # Including strong name files can present a security risk 225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 226 | #*.snk 227 | 228 | # Since there are multiple workflows, uncomment next line to ignore bower_components 229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 230 | #bower_components/ 231 | 232 | # RIA/Silverlight projects 233 | Generated_Code/ 234 | 235 | # Backup & report files from converting an old project file 236 | # to a newer Visual Studio version. Backup files are not needed, 237 | # because we have git ;-) 238 | _UpgradeReport_Files/ 239 | Backup*/ 240 | UpgradeLog*.XML 241 | UpgradeLog*.htm 242 | ServiceFabricBackup/ 243 | *.rptproj.bak 244 | 245 | # SQL Server files 246 | *.mdf 247 | *.ldf 248 | *.ndf 249 | 250 | # Business Intelligence projects 251 | *.rdl.data 252 | *.bim.layout 253 | *.bim_*.settings 254 | *.rptproj.rsuser 255 | 256 | # Microsoft Fakes 257 | FakesAssemblies/ 258 | 259 | # GhostDoc plugin setting file 260 | *.GhostDoc.xml 261 | 262 | # Node.js Tools for Visual Studio 263 | .ntvs_analysis.dat 264 | node_modules/ 265 | 266 | # Visual Studio 6 build log 267 | *.plg 268 | 269 | # Visual Studio 6 workspace options file 270 | *.opt 271 | 272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
273 | *.vbw
274 |
275 | # Visual Studio LightSwitch build output
276 | **/*.HTMLClient/GeneratedArtifacts
277 | **/*.DesktopClient/GeneratedArtifacts
278 | **/*.DesktopClient/ModelManifest.xml
279 | **/*.Server/GeneratedArtifacts
280 | **/*.Server/ModelManifest.xml
281 | _Pvt_Extensions
282 |
283 | # Paket dependency manager
284 | .paket/paket.exe
285 | paket-files/
286 |
287 | # FAKE - F# Make
288 | .fake/
289 |
290 | # JetBrains Rider
291 | .idea/
292 | *.sln.iml
293 |
294 | # CodeRush
295 | .cr/
296 |
297 | # Python Tools for Visual Studio (PTVS)
298 | __pycache__/
299 | *.pyc
300 |
301 | # Cake - Uncomment if you are using it
302 | # tools/**
303 | # !tools/packages.config
304 |
305 | # Tabs Studio
306 | *.tss
307 |
308 | # Telerik's JustMock configuration file
309 | *.jmconfig
310 |
311 | # BizTalk build output
312 | *.btp.cs
313 | *.btm.cs
314 | *.odx.cs
315 | *.xsd.cs
316 |
317 | # OpenCover UI analysis results
318 | OpenCover/
319 |
320 | # Azure Stream Analytics local run output
321 | ASALocalRun/
322 |
323 | # MSBuild Binary and Structured Log
324 | *.binlog
325 |
326 | # NVidia Nsight GPU debugger configuration file
327 | *.nvuser
328 |
329 | # MFractors (Xamarin productivity tool) working folder
330 | .mfractor/
331 | .vscode/settings.json
332 | *.swp
333 | checkpoint
334 | data
335 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Metric Transfer for Label Propagation with Limited Annotated Data
2 |
3 | This repo contains the PyTorch implementation of the semi-supervised learning paper [(arxiv)](https://arxiv.org/abs/1812.08781).
4 |
5 | ## Requirements
6 |
7 | * Python3: Anaconda is recommended because it already contains a lot of packages:
8 |   * `pytorch>=1.0`: Refer to https://pytorch.org/get-started/locally/
9 |   * other packages: `pip install tensorboardX tensorboard easydict scikit-image`
10 |
11 | ## Highlight
12 |
13 | - We formulate semi-supervised learning from a completely different metric transfer perspective.
14 | - Our method enjoys the benefits of recent advances in self-supervised learning.
15 | - We hope to draw more attention to unsupervised pretraining for other tasks.
16 |
17 | ## Main results
18 |
19 | The test accuracy of our method and state-of-the-art methods on the CIFAR10 dataset with different numbers of labeled samples.
20 |
21 | | Method | 50 | 100 | 250 | 500 | 1000 | 2000 | 4000 | 8000 |
22 | | :----------------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
23 | | PI-model | 27.36 | 37.20 | 47.07 | 56.30 | 63.70 | 76.50 | 84.17 | 87.30 |
24 | | Mean-Teacher | 29.66 | 36.60 | 45.49 | 57.20 | 65.00 | 79.00 | 84.38 | 87.50 |
25 | | VAT | 23.00 | 35.58 | 47.61 | 62.90 | 72.80 | **84.00** | **86.79** | **88.10** |
26 | | Pseudo-Label | 21.00 | 34.00 | 45.83 | 60.30 | 68.20 | 78.00 | 84.79 | 86.20 |
27 | | **Ours** | **56.34** | **63.53** | **71.26** | **74.77** | **79.38** | 82.34 | 84.52 | 87.48 |
28 |
29 |
30 | ## Quick start
31 |
32 | * Clone this repo: `git clone git@github.com:microsoft/metric-transfer.pytorch.git && cd metric-transfer.pytorch`
33 |
34 | * Install PyTorch and the other packages listed in Requirements
35 |
36 | * Download pretrained models and precomputed pseudo labels: `bash scripts/download_model.sh`.
Make sure the `checkpoint` folder looks like this:
37 |
38 | ```
39 | checkpoint
40 | |-- pretrain_models
41 | |   |-- ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar
42 | |   |-- ... other files
43 | |   `-- lemniscate_resnet50.pth.tar
44 | |-- pseudos
45 | |   |-- instance_nc_wrn-28-2
46 | |   |   |-- 50.pth.tar
47 | |   |   |-- ... other files
48 | |   |   `-- 8000.pth.tar
49 | |   `-- ... other folders
50 | `-- pseudos_imagenet
51 |     `-- instance_imagenet_nc_resnet50
52 |         |-- num_labeled_13000
53 |         |   |-- 10_0.pth.tar
54 |         |   |-- ... other files
55 |         |   `-- 10_9.pth.tar
56 |         `-- ... other folders
57 | ```
58 |
59 | * Run supervised finetuning on the CIFAR10 or ImageNet dataset. The CIFAR10 dataset will be downloaded automatically. For ImageNet, refer to [here](https://github.com/pytorch/examples/tree/master/imagenet) for details of data preparation.
60 |
61 | ```bash
62 | # Finetune on cifar
63 | python cifar-semi.py \
64 |     --gpus 0 \
65 |     --num-labeled 250 \
66 |     --pseudo-file checkpoint/pseudos/instance_nc_wrn-28-2/250.pth.tar \
67 |     --resume checkpoint/pretrain_models/ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \
68 |     --pseudo-ratio 0.2
69 |
70 | # For imagenet
71 | n_labeled=13000     # 1% labeled data
72 | pseudo_ratio=0.1    # use top 10% pseudo label
73 | data_dir=/path/to/imagenet/dir
74 |
75 | python imagenet-semi.py \
76 |     --arch resnet50 \
77 |     --gpus 0,1,2,3 \
78 |     --num-labeled ${n_labeled} \
79 |     --data-dir ${data_dir} \
80 |     --pretrained checkpoint/pretrain_models/lemniscate_resnet50.pth.tar \
81 |     --pseudo-dir checkpoint/pseudos_imagenet/instance_imagenet_nc_resnet50/num_labeled_${n_labeled} \
82 |     --pseudo-ratio ${pseudo_ratio}
83 | ```
84 |
85 | ## Usage
86 |
87 | The proposed method contains three main steps: metric pretraining, label propagation, and supervised finetuning.
88 |
89 | ### Metric pretraining
90 |
91 | The metric pretraining can be unsupervised or supervised, and can come from the same or a different dataset.
92 |
93 | We provide code for [instance discrimination](https://arxiv.org/abs/1805.01978), which is borrowed from the [original pytorch release](https://github.com/zhirongw/lemniscate.pytorch) of instance discrimination. You can run the following command in the root directory of the repo to train instance discrimination on the CIFAR10 dataset:
94 |
95 | ```bash
96 | export PYTHONPATH=$PYTHONPATH:$(pwd)
97 | CUDA_VISIBLE_DEVICES=0 python unsupervised/cifar.py \
98 |     --lr-scheduler cosine-with-restart \
99 |     --epochs 1270
100 | ```
101 |
102 | For other metrics or datasets, such as colorization on CIFAR10 or instance discrimination on ImageNet, refer to the officially released code: [colorization](https://github.com/richzhang/colorization), [instance discrimination](https://github.com/zhirongw/lemniscate.pytorch). We also provide the pretrained weights; refer to `scripts/download_model.sh` for more details.
103 |
104 | ### Label propagation
105 |
106 | We can then propagate labels from the few labeled examples to a vast collection of unannotated images using the trained metric.
107 |
108 | We consider two propagation algorithms: K-nearest neighbors (**knn**) and spectral clustering (also called normalized cut, **nc**). The implementations are in the `notebooks` folder as Jupyter notebooks. You can simply run a notebook to load the pretrained metric weights and propagate labels to produce the pseudo-label files, as sketched below.
109 |
110 | We also provide the pseudo labels for the CIFAR10 and ImageNet datasets. Refer to `scripts/download_model.sh` for more details.
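
For orientation, the knn variant boils down to something like the following sketch (illustrative only, not the notebooks' exact code; `feats` is assumed to hold L2-normalized embeddings of all training images from the pretrained metric network, and `labels` the targets of the few labeled samples):

```python
import torch

def knn_propagate(feats, labeled_idx, labels, k=10, num_classes=10):
    # cosine similarity of every sample to the labeled anchors
    sim = feats @ feats[labeled_idx].t()            # (N, num_labeled)
    weight, nn_idx = sim.topk(k, dim=1)             # k most similar anchors
    votes = torch.zeros(feats.size(0), num_classes, device=feats.device)
    votes.scatter_add_(1, labels[nn_idx], weight)   # similarity-weighted votes
    confidence, pseudo_labels = votes.max(dim=1)
    return pseudo_labels, confidence
```

The returned confidence is what `--pseudo-ratio` later thresholds: only the most confident fraction of the propagated labels is used during finetuning.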
111 | 112 | ### Supervised finetuning 113 | 114 | With the estimated pseudo labels on the unlabeled data, we can train a classifier with more data. For simplicity, we omit the confidence-weighted supervised training in the current version. Instead, we only use a portion of the most confident pseudo labels for training. 115 | 116 | Refer to the Quick start section above for the full commands. 117 | 118 | ## Contributing 119 | 120 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 121 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 122 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 123 | 124 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 125 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 126 | provided by the bot. You will only need to do this once across all repos using our CLA. 127 | 128 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 129 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 130 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 131 | 132 | ## Citation 133 | 134 | If you find this paper useful in your research, please consider citing: 135 | 136 | ```latex 137 | @article{liu2018deep, 138 | title={Deep Metric Transfer for Label Propagation with Limited Annotated Data}, 139 | author={Liu, Bin and Wu, Zhirong and Hu, Han and Lin, Stephen}, 140 | journal={arXiv preprint arXiv:1812.08781}, 141 | year={2018} 142 | } 143 | ``` 144 | 145 | ## Contact 146 | 147 | For any questions, please feel free to create a new issue or reach out to 148 | ``` 149 | Bin Liu: liubinthss@gmail.com 150 | ``` 151 | -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from lib.normalize import Normalize 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = 
self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, low_dim=128): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AvgPool2d(7, stride=1) 110 | self.fc = nn.Linear(512 * block.expansion, low_dim) 111 | self.l2norm = Normalize(2) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [block(self.inplanes, planes, stride, downsample)] 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | x = self.l2norm(x) 152 | 153 | return x 154 | 155 | 156 | def resnet18(pretrained=False, **kwargs): 157 | """Constructs a ResNet-18 model. 
158 | 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 165 | return model 166 | 167 | 168 | def resnet34(pretrained=False, **kwargs): 169 | """Constructs a ResNet-34 model. 170 | 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | """ 174 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 175 | if pretrained: 176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 177 | return model 178 | 179 | 180 | def resnet50(pretrained=False, **kwargs): 181 | """Constructs a ResNet-50 model. 182 | 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 189 | return model 190 | 191 | 192 | def resnet101(pretrained=False, **kwargs): 193 | """Constructs a ResNet-101 model. 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 201 | return model 202 | 203 | 204 | def resnet152(pretrained=False, **kwargs): 205 | """Constructs a ResNet-152 model. 206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 213 | return model 214 | -------------------------------------------------------------------------------- /unsupervised/cifar.py: -------------------------------------------------------------------------------- 1 | """Train CIFAR10 with PyTorch.""" 2 | import argparse 3 | import os 4 | import sys 5 | import time 6 | from pprint import pprint 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | import torchvision.transforms as transforms 14 | 15 | from lib import datasets, models 16 | from lib.LinearAverage import LinearAverage 17 | from lib.NCEAverage import NCEAverage 18 | from lib.NCECriterion import NCECriterion 19 | from lib.utils import AverageMeter, CosineAnnealingLRWithRestart 20 | from test import kNN 21 | 22 | 23 | # Training 24 | def train(net, optimizer, trainloader, criterion, lemniscate, epoch): 25 | print('\nEpoch: {}, lr {}'.format(epoch, optimizer.param_groups[0]['lr'])) 26 | train_loss = AverageMeter() 27 | data_time = AverageMeter() 28 | batch_time = AverageMeter() 29 | 30 | # switch to train mode 31 | net.train() 32 | 33 | end = time.time() 34 | for batch_idx, (inputs, targets, indexes) in enumerate(trainloader): 35 | data_time.update(time.time() - end) 36 | inputs, targets, indexes = inputs.to(args.device), targets.to( 37 | args.device), indexes.to(args.device) 38 | optimizer.zero_grad() 39 | 40 | features = net(inputs) 41 | outputs = lemniscate(features, indexes) 42 | loss = criterion(outputs, indexes) 43 | 44 | loss.backward() 45 | optimizer.step() 46 | 47 | train_loss.update(loss.item(), inputs.size(0)) 48 | 49 | # measure elapsed time 50 | batch_time.update(time.time() - end) 51 | end = time.time() 52 | 53 | if batch_idx % 100 == 0: 54 
| print(f'Epoch: [{epoch}/{args.epochs}][{batch_idx}/{len(trainloader)}] ' 55 | f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 56 | f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) ' 57 | f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})') 58 | 59 | 60 | def get_data_loader(): 61 | normalize = transforms.Normalize( 62 | (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) 63 | if args.transform_crop == 'RandomResizedCrop': 64 | crop = transforms.RandomResizedCrop( 65 | size=32, scale=(args.transform_scale, 1.)) 66 | else: 67 | crop = transforms.Compose([ 68 | transforms.Pad(4, padding_mode='reflect'), 69 | transforms.RandomCrop(32) 70 | ]) 71 | transform_train = transforms.Compose([ 72 | crop, 73 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 74 | transforms.RandomGrayscale(p=0.2), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | normalize, 78 | ]) 79 | print('-' * 80) 80 | print('transform_train = ', transform_train) 81 | print('-' * 80) 82 | 83 | transform_test = transforms.Compose([ 84 | transforms.ToTensor(), 85 | normalize, 86 | ]) 87 | 88 | trainset = datasets.CIFAR10Instance( 89 | root=args.data_dir, train=True, download=True, transform=transform_train) 90 | trainloader = torch.utils.data.DataLoader( 91 | trainset, batch_size=128, shuffle=True, num_workers=2) 92 | 93 | testset = datasets.CIFAR10Instance( 94 | root=args.data_dir, train=False, download=True, transform=transform_test) 95 | testloader = torch.utils.data.DataLoader( 96 | testset, batch_size=100, shuffle=False, num_workers=2) 97 | 98 | ndata = trainset.__len__() 99 | 100 | return trainloader, testloader, ndata 101 | 102 | 103 | def build_model(): 104 | best_acc = 0  # best test accuracy 105 | start_epoch = 0  # start from epoch 0 or last checkpoint epoch 106 | 107 | if args.architecture == 'resnet18': 108 | net = models.__dict__['resnet18_cifar'](low_dim=args.low_dim) 109 | elif args.architecture == 'wrn-28-2': 110 | net = models.WideResNet( 111 | depth=28, num_classes=args.low_dim, widen_factor=2, dropRate=0).to(args.device) 112 | elif args.architecture == 'wrn-28-10': 113 | net = models.WideResNet( 114 | depth=28, num_classes=args.low_dim, widen_factor=10, dropRate=0).to(args.device) 115 | 116 | # define lemniscate 117 | if args.nce_k > 0: 118 | lemniscate = NCEAverage(args.low_dim, args.ndata, 119 | args.nce_k, args.nce_t, args.nce_m) 120 | else: 121 | lemniscate = LinearAverage( 122 | args.low_dim, args.ndata, args.nce_t, args.nce_m) 123 | 124 | if args.device == 'cuda': 125 | net = torch.nn.DataParallel( 126 | net, device_ids=range(torch.cuda.device_count())) 127 | cudnn.benchmark = True 128 | 129 | optimizer = optim.SGD( 130 | net.parameters(), lr=args.lr, momentum=0.9, 131 | weight_decay=args.weight_decay, nesterov=True) 132 | # Model 133 | if args.test_only or len(args.resume) > 0: 134 | # Load checkpoint.
135 | print('==> Resuming from checkpoint..') 136 | checkpoint = torch.load(args.resume) 137 | net.load_state_dict(checkpoint['net']) 138 | optimizer.load_state_dict(checkpoint['optimizer']) 139 | lemniscate = checkpoint['lemniscate'] 140 | best_acc = checkpoint['acc'] 141 | start_epoch = checkpoint['epoch'] + 1 142 | 143 | if args.lr_scheduler == 'multi-step': 144 | if args.epochs == 200: 145 | steps = [60, 120, 160] 146 | elif args.epochs == 600: 147 | steps = [180, 360, 480, 560] 148 | else: 149 | raise RuntimeError( 150 | f"need to config steps for epoch = {args.epochs} first.") 151 | scheduler = lr_scheduler.MultiStepLR( 152 | optimizer, steps, gamma=0.2, last_epoch=start_epoch - 1) 153 | elif args.lr_scheduler == 'cosine': 154 | scheduler = lr_scheduler.CosineAnnealingLR( 155 | optimizer, args.epochs, eta_min=0.00001, last_epoch=start_epoch - 1) 156 | elif args.lr_scheduler == 'cosine-with-restart': 157 | scheduler = CosineAnnealingLRWithRestart( 158 | optimizer, eta_min=0.00001, last_epoch=start_epoch - 1) 159 | else: 160 | raise ValueError("not supported") 161 | 162 | # define loss function 163 | if hasattr(lemniscate, 'K'): 164 | criterion = NCECriterion(args.ndata) 165 | else: 166 | criterion = nn.CrossEntropyLoss() 167 | 168 | net.to(args.device) 169 | lemniscate.to(args.device) 170 | criterion.to(args.device) 171 | 172 | return net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch 173 | 174 | 175 | def main(): 176 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 177 | 178 | # Data 179 | print('==> Preparing data..') 180 | trainloader, testloader, args.ndata = get_data_loader() 181 | 182 | print('==> Building model..') 183 | net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch = build_model() 184 | 185 | if args.test_only: 186 | kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 1) 187 | sys.exit(0) 188 | 189 | for epoch in range(start_epoch, args.epochs): 190 | scheduler.step() 191 | train(net, optimizer, trainloader, criterion, lemniscate, epoch) 192 | acc = kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 0) 193 | 194 | if acc > best_acc: 195 | print('Saving..') 196 | state = { 197 | 'net': net.state_dict(), 198 | 'lemniscate': lemniscate, 199 | 'acc': acc, 200 | 'epoch': epoch, 201 | 'optimizer': optimizer.state_dict(), 202 | } 203 | os.makedirs(args.model_dir, exist_ok=True) 204 | torch.save(state, os.path.join( 205 | args.model_dir, 'ckpt.cifar.pth.tar')) 206 | best_acc = acc 207 | 208 | print('best accuracy: {:.2f}'.format(best_acc * 100)) 209 | 210 | acc = kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 1) 211 | print('last accuracy: {:.2f}'.format(acc * 100)) 212 | 213 | 214 | if __name__ == '__main__': 215 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 216 | parser.add_argument('--data-dir', '--dataDir', 217 | default='./data', type=str, metavar='DIR') 218 | parser.add_argument('--model-dir', '--modelDir', default='./checkpoint/instance_cifar10', type=str, 219 | metavar='DIR', help='directory to save checkpoint') 220 | parser.add_argument('--log-dir', '--logDir', default='./tensorboard/instance_cifar10', type=str, 221 | metavar='DIR', help='directory to save tensorboard logs') 222 | parser.add_argument('--lr', default=0.03, type=float, help='learning rate') 223 | parser.add_argument('--lr-scheduler', default='cosine', type=str, 224 | choices=['multi-step', 'cosine', 225 | 'cosine-with-restart'], 226 | help='which lr scheduler to use') 227 | 
parser.add_argument('--resume', '-r', default='', 228 | type=str, help='resume from checkpoint') 229 | parser.add_argument('--test-only', action='store_true', help='test only') 230 | parser.add_argument('--low-dim', default=128, type=int, 231 | metavar='D', help='feature dimension') 232 | parser.add_argument('--nce-k', default=0, type=int, 233 | metavar='K', help='number of negative samples for NCE') 234 | parser.add_argument('--nce-t', default=0.1, type=float, 235 | metavar='T', help='temperature parameter for softmax') 236 | parser.add_argument('--nce-m', default=0.5, type=float, 237 | metavar='M', help='momentum for non-parametric updates') 238 | parser.add_argument('--epochs', default=600, type=int, 239 | metavar='N', help='number of epochs') 240 | parser.add_argument('--architecture', '--arch', default='wrn-28-2', type=str, 241 | choices=['resnet18', 'wrn-28-2', 'wrn-28-10'], 242 | help='which backbone to use') 243 | parser.add_argument('--transform-scale', default=0.2, type=float) 244 | parser.add_argument('--transform-crop', type=str, default='RandomResizedCrop', 245 | choices=['RandomResizedCrop', 'PadCrop']) 246 | parser.add_argument('--weight-decay', '--wd', type=float, default=5e-4) 247 | args = parser.parse_args() 248 | 249 | pprint(vars(args)) 250 | 251 | main() 252 | 253 | pprint(vars(args)) 254 | -------------------------------------------------------------------------------- /unsupervised/imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | 14 | from lib import models, datasets 15 | 16 | from lib.NCEAverage import NCEAverage 17 | from lib.LinearAverage import LinearAverage 18 | from lib.NCECriterion import NCECriterion 19 | from lib.utils import AverageMeter 20 | from test import NN, kNN 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('--data-dir', metavar='DIR', 28 | help='path to dataset', required=True) 29 | parser.add_argument('--model-dir', metavar='DIR', 30 | default='./checkpoint/instance_imagenet', help='path to save model') 31 | parser.add_argument('--log-dir', metavar='DIR', 32 | default='./tensorboard/instance_imagenet', help='path to save log') 33 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: resnet18)') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=256, type=int, 45 | metavar='N', help='mini-batch size (default: 256)') 46 | parser.add_argument('-vb', '--val-batch-size', default=128, type=int, 47 | metavar='N', help='validation mini-batch size (default: 128)') 48 | parser.add_argument('--lr', 
'--learning-rate', default=0.03, type=float, 49 | metavar='LR', help='initial learning rate') 50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 51 | help='momentum') 52 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 53 | metavar='W', help='weight decay (default: 1e-4)') 54 | parser.add_argument('--print-freq', '-p', default=10, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('--auto-resume', action='store_true', help='auto resume') 59 | parser.add_argument('--test-only', action='store_true', help='test only') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 63 | help='use pre-trained model') 64 | parser.add_argument('--low-dim', default=128, type=int, 65 | metavar='D', help='feature dimension') 66 | parser.add_argument('--nce-k', default=4096, type=int, 67 | metavar='K', help='number of negative samples for NCE') 68 | parser.add_argument('--nce-t', default=0.07, type=float, 69 | metavar='T', help='temperature parameter for softmax') 70 | parser.add_argument('--nce-m', default=0.5, type=float, 71 | help='momentum for non-parametric updates') 72 | parser.add_argument('--iter-size', default=1, type=int, 73 | help='caffe style iter size') 74 | 75 | best_prec1 = 0 76 | 77 | 78 | def main(): 79 | global args, best_prec1 80 | args = parser.parse_args() 81 | 82 | # create model 83 | if args.pretrained: 84 | print("=> using pre-trained model '{}'".format(args.arch)) 85 | model = models.__dict__[args.arch](pretrained=True) 86 | else: 87 | print("=> creating model '{}'".format(args.arch)) 88 | model = models.__dict__[args.arch](low_dim=args.low_dim) 89 | 90 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 91 | model.features = torch.nn.DataParallel(model.features) 92 | model.cuda() 93 | else: 94 | model = torch.nn.DataParallel(model).cuda() 95 | 96 | # Data loading code 97 | print("=> loading dataset") 98 | traindir = os.path.join(args.data_dir, 'train') 99 | valdir = os.path.join(args.data_dir, 'val') 100 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 101 | std=[0.229, 0.224, 0.225]) 102 | 103 | train_dataset = datasets.ImageFolderInstance( 104 | traindir, 105 | transforms.Compose([ 106 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 107 | transforms.RandomGrayscale(p=0.2), 108 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 109 | transforms.RandomHorizontalFlip(), 110 | transforms.ToTensor(), 111 | normalize, 112 | ])) 113 | 114 | train_loader = torch.utils.data.DataLoader( 115 | train_dataset, batch_size=args.batch_size, shuffle=True, 116 | num_workers=args.workers, pin_memory=True) 117 | 118 | val_loader = torch.utils.data.DataLoader( 119 | datasets.ImageFolderInstance(valdir, transforms.Compose([ 120 | transforms.Resize(256), 121 | transforms.CenterCrop(224), 122 | transforms.ToTensor(), 123 | normalize, 124 | ])), 125 | batch_size=args.val_batch_size, shuffle=False, 126 | num_workers=args.workers, pin_memory=True) 127 | 128 | # define lemniscate and loss function (criterion) 129 | print("=> building optimizer") 130 | ndata = train_dataset.__len__() 131 | if args.nce_k > 0: 132 | lemniscate = NCEAverage(args.low_dim, ndata, 133 | args.nce_k, args.nce_t, args.nce_m).cuda() 134 
| criterion = NCECriterion(ndata).cuda() 135 | else: 136 | lemniscate = LinearAverage( 137 | args.low_dim, ndata, args.nce_t, args.nce_m).cuda() 138 | criterion = nn.CrossEntropyLoss().cuda() 139 | 140 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 141 | momentum=args.momentum, 142 | weight_decay=args.weight_decay) 143 | 144 | # optionally resume from a checkpoint 145 | model_filename_to_resume = None 146 | if args.resume: 147 | if os.path.isfile(args.resume): 148 | model_filename_to_resume = args.resume 149 | else: 150 | print("=> no checkpoint found at '{}'".format(args.resume)) 151 | elif args.auto_resume: 152 | for epoch in range(args.epochs, args.start_epoch + 1, -1): 153 | model_filename = get_model_name(epoch) 154 | if os.path.exists(model_filename): 155 | model_filename_to_resume = model_filename 156 | break 157 | else: 158 | print("=> no checkpoint found at '{}'".format(args.model_dir)) 159 | 160 | if model_filename_to_resume is not None: 161 | print("=> loading checkpoint '{}'".format(model_filename_to_resume)) 162 | checkpoint = torch.load(model_filename_to_resume) 163 | args.start_epoch = checkpoint['epoch'] 164 | best_prec1 = checkpoint['best_prec1'] 165 | model.load_state_dict(checkpoint['state_dict']) 166 | lemniscate = checkpoint['lemniscate'] 167 | optimizer.load_state_dict(checkpoint['optimizer']) 168 | print("=> loaded checkpoint '{}' (epoch {})" 169 | .format(model_filename_to_resume, checkpoint['epoch'])) 170 | 171 | cudnn.benchmark = True 172 | 173 | if args.evaluate: 174 | kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t) 175 | return 176 | 177 | for epoch in range(args.start_epoch, args.epochs): 178 | adjust_learning_rate(optimizer, epoch) 179 | 180 | # train for one epoch 181 | train(train_loader, model, lemniscate, criterion, optimizer, epoch) 182 | 183 | # evaluate on validation set 184 | prec1 = NN(model, lemniscate, train_loader, val_loader) 185 | 186 | # remember best prec@1 and save checkpoint 187 | is_best = prec1 > best_prec1 188 | best_prec1 = max(prec1, best_prec1) 189 | save_checkpoint({ 190 | 'epoch': epoch + 1, 191 | 'arch': args.arch, 192 | 'state_dict': model.state_dict(), 193 | 'lemniscate': lemniscate, 194 | 'best_prec1': best_prec1, 195 | 'optimizer': optimizer.state_dict(), 196 | }, is_best, 197 | filename=get_model_name(epoch)) 198 | # evaluate KNN after last epoch 199 | kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t) 200 | 201 | 202 | def train(train_loader, model, lemniscate, criterion, optimizer, epoch): 203 | batch_time = AverageMeter() 204 | data_time = AverageMeter() 205 | losses = AverageMeter() 206 | 207 | # switch to train mode 208 | model.train() 209 | 210 | end = time.time() 211 | optimizer.zero_grad() 212 | for i, (inputs, _, index) in enumerate(train_loader): 213 | # measure data loading time 214 | data_time.update(time.time() - end) 215 | 216 | index = index.cuda(non_blocking=True) 217 | 218 | # compute output 219 | feature = model(inputs) 220 | output = lemniscate(feature, index) 221 | loss = criterion(output, index) / args.iter_size 222 | 223 | loss.backward() 224 | 225 | # measure accuracy and record loss 226 | losses.update(loss.item() * args.iter_size, inputs.size(0)) 227 | 228 | if (i + 1) % args.iter_size == 0: 229 | # compute gradient and do SGD step 230 | optimizer.step() 231 | optimizer.zero_grad() 232 | 233 | # measure elapsed time 234 | batch_time.update(time.time() - end) 235 | end = time.time() 236 | 237 | if i % args.print_freq == 0: 238 | print(f'Epoch: 
[{epoch}/{args.epochs}][{i}/{len(train_loader)}]\t' 239 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 240 | f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 241 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t') 242 | 243 | 244 | def get_model_name(epoch): 245 | return os.path.join(args.model_dir, 'ckpt-{}.pth.tar'.format(epoch)) 246 | 247 | 248 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 249 | torch.save(state, filename) 250 | if is_best: 251 | shutil.copyfile(filename, os.path.join( 252 | args.model_dir, 'model_best.pth.tar')) 253 | 254 | 255 | def adjust_learning_rate(optimizer, epoch): 256 | """Sets the learning rate: the initial LR decayed by 10x at epochs 120 and 160""" 257 | if epoch < 120: 258 | lr = args.lr 259 | elif 120 <= epoch < 160: 260 | lr = args.lr * 0.1 261 | else: 262 | lr = args.lr * 0.01 263 | # lr = args_.lr * (0.1 ** (epoch // 100)) 264 | for param_group in optimizer.param_groups: 265 | param_group['lr'] = lr 266 | 267 | 268 | if __name__ == '__main__': 269 | main() 270 | -------------------------------------------------------------------------------- /notebooks/knn-imagenet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import sys\n", 11 | "import os\n", 12 | "import argparse\n", 13 | "import time\n", 14 | "import numpy as np\n", 15 | "sys.path.append('../')\n", 16 | "\n", 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "import torch.optim as optim\n", 20 | "import torch.nn.functional as F\n", 21 | "import torch.backends.cudnn as cudnn\n", 22 | "import torch.optim.lr_scheduler as lr_scheduler\n", 23 | "import torchvision\n", 24 | "import torchvision.transforms as transforms\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import easydict as edict\n", 27 | "\n", 28 | "from lib import models, datasets\n", 29 | "import math" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# parameters\n", 39 | "args = edict.EasyDict()\n", 40 | "\n", 41 | "# imagenet\n", 42 | "args.cache = '../checkpoint/train_features_labels_cache/instance_imagenet_train_feature_resnet50.pth.tar'\n", 43 | "args.val_cache = '../checkpoint/train_features_labels_cache/instance_imagenet_val_feature_resnet50.pth.tar'\n", 44 | "args.save_path = '../checkpoint/pseudos/unsupervised_imagenet32x32_nc_wrn-28-2'\n", 45 | "os.makedirs(args.save_path, exist_ok=True)\n", 46 | "\n", 47 | "args.low_dim = 128\n", 48 | "args.num_class = 1000\n", 49 | "args.rng_seed = 0" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "torch.float32 torch.int64\n", 64 | "torch.Size([1331167, 128]) torch.Size([1331167])\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "ckpt = torch.load(args.cache)\n", 70 | "train_labels, train_features = ckpt['labels'], ckpt['features']\n", 71 | "\n", 72 | "ckpt = torch.load(args.val_cache)\n", 73 | "val_labels, val_features = ckpt['val_labels'], ckpt['val_features']\n", 74 | "\n", 75 | "train_features = torch.cat([val_features, train_features], dim=0)\n", 76 | "train_labels = torch.cat([val_labels, train_labels], dim=0)\n", 77 | "\n", 78 | "print(train_features.dtype, train_labels.dtype)\n", 79 | 
"print(train_features.shape, train_labels.shape)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# use cpu because the following computation need a lot of memory" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "device = 'cpu'\n", 96 | "train_features, train_labels = train_features.to(device), train_labels.to(device)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "tensor([ 970454, 1058848, 717280, ..., 462299, 305137, 436069])\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "num_train_data = train_labels.shape[0]\n", 114 | "num_class = torch.max(train_labels) + 1\n", 115 | "\n", 116 | "torch.manual_seed(args.rng_seed)\n", 117 | "torch.cuda.manual_seed_all(args.rng_seed)\n", 118 | "perm = torch.randperm(num_train_data).to(device)\n", 119 | "print(perm)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "# soft label" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "class AverageMeter(object):\n", 136 | " \"\"\"Computes and stores the average and current value\"\"\"\n", 137 | " def __init__(self):\n", 138 | " self.reset()\n", 139 | "\n", 140 | " def reset(self):\n", 141 | " self.val = 0\n", 142 | " self.avg = 0\n", 143 | " self.sum = 0\n", 144 | " self.count = 0\n", 145 | "\n", 146 | " def update(self, val, n=1):\n", 147 | " self.val = val\n", 148 | " self.sum += val * n\n", 149 | " self.count += n\n", 150 | " self.avg = self.sum / self.count" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "[0]/[100] top5=85.00%(85.00%) top1=66.60%(66.60%)\n", 163 | "[1]/[100] top5=79.60%(82.30%) top1=52.80%(59.70%)\n", 164 | "[2]/[100] top5=81.00%(81.87%) top1=61.40%(60.27%)\n", 165 | "[3]/[100] top5=65.20%(77.70%) top1=42.80%(55.90%)\n", 166 | "[4]/[100] top5=70.00%(76.16%) top1=47.40%(54.20%)\n", 167 | "[5]/[100] top5=69.20%(75.00%) top1=42.60%(52.27%)\n", 168 | "[6]/[100] top5=67.20%(73.89%) top1=41.40%(50.71%)\n", 169 | "[7]/[100] top5=77.60%(74.35%) top1=52.20%(50.90%)\n", 170 | "[8]/[100] top5=84.60%(75.49%) top1=67.00%(52.69%)\n", 171 | "[9]/[100] top5=77.40%(75.68%) top1=57.00%(53.12%)\n", 172 | "[10]/[100] top5=82.20%(76.27%) top1=67.00%(54.38%)\n", 173 | "[11]/[100] top5=71.00%(75.83%) top1=49.00%(53.93%)\n", 174 | "[12]/[100] top5=65.20%(75.02%) top1=43.80%(53.15%)\n", 175 | "[13]/[100] top5=83.00%(75.59%) top1=62.20%(53.80%)\n", 176 | "[14]/[100] top5=85.20%(76.23%) top1=66.80%(54.67%)\n", 177 | "[15]/[100] top5=71.80%(75.95%) top1=43.60%(53.97%)\n", 178 | "[16]/[100] top5=62.20%(75.14%) top1=39.80%(53.14%)\n", 179 | "[17]/[100] top5=61.00%(74.36%) top1=38.80%(52.34%)\n", 180 | "[18]/[100] top5=63.60%(73.79%) top1=36.60%(51.52%)\n", 181 | "[19]/[100] top5=67.40%(73.47%) top1=41.60%(51.02%)\n", 182 | "[20]/[100] top5=70.40%(73.32%) top1=38.40%(50.42%)\n", 183 | "[21]/[100] top5=71.40%(73.24%) top1=46.60%(50.25%)\n", 184 | "[22]/[100] top5=71.00%(73.14%) top1=48.40%(50.17%)\n", 185 | "[23]/[100] top5=76.20%(73.27%) top1=44.00%(49.91%)\n", 186 | "[24]/[100] top5=71.20%(73.18%) 
top1=41.60%(49.58%)\n", 187 | "[25]/[100] top5=78.60%(73.39%) top1=55.40%(49.80%)\n", 188 | "[26]/[100] top5=67.20%(73.16%) top1=45.00%(49.62%)\n", 189 | "[27]/[100] top5=74.80%(73.22%) top1=52.60%(49.73%)\n", 190 | "[28]/[100] top5=74.80%(73.28%) top1=47.00%(49.63%)\n", 191 | "[29]/[100] top5=79.40%(73.48%) top1=57.80%(49.91%)\n", 192 | "[30]/[100] top5=78.20%(73.63%) top1=51.00%(49.94%)\n", 193 | "[31]/[100] top5=64.20%(73.34%) top1=37.20%(49.54%)\n", 194 | "[32]/[100] top5=86.00%(73.72%) top1=71.40%(50.21%)\n", 195 | "[33]/[100] top5=83.00%(73.99%) top1=61.00%(50.52%)\n", 196 | "[34]/[100] top5=75.80%(74.05%) top1=52.00%(50.57%)\n", 197 | "[35]/[100] top5=71.00%(73.96%) top1=44.00%(50.38%)\n", 198 | "[36]/[100] top5=73.80%(73.96%) top1=53.40%(50.46%)\n", 199 | "[37]/[100] top5=73.80%(73.95%) top1=51.80%(50.50%)\n", 200 | "[38]/[100] top5=78.80%(74.08%) top1=47.80%(50.43%)\n", 201 | "[39]/[100] top5=75.80%(74.12%) top1=57.40%(50.60%)\n", 202 | "[40]/[100] top5=76.20%(74.17%) top1=54.20%(50.69%)\n", 203 | "[41]/[100] top5=62.20%(73.89%) top1=35.80%(50.34%)\n", 204 | "[42]/[100] top5=67.80%(73.74%) top1=49.00%(50.31%)\n", 205 | "[43]/[100] top5=66.00%(73.57%) top1=45.40%(50.20%)\n", 206 | "[44]/[100] top5=67.60%(73.44%) top1=43.40%(50.04%)\n", 207 | "[45]/[100] top5=66.60%(73.29%) top1=44.20%(49.92%)\n", 208 | "[46]/[100] top5=58.80%(72.98%) top1=34.20%(49.58%)\n", 209 | "[47]/[100] top5=65.60%(72.83%) top1=48.60%(49.56%)\n", 210 | "[48]/[100] top5=67.60%(72.72%) top1=40.20%(49.37%)\n", 211 | "[49]/[100] top5=56.60%(72.40%) top1=36.60%(49.12%)\n", 212 | "[50]/[100] top5=59.20%(72.14%) top1=37.40%(48.89%)\n", 213 | "[51]/[100] top5=64.00%(71.98%) top1=38.20%(48.68%)\n", 214 | "[52]/[100] top5=68.60%(71.92%) top1=43.80%(48.59%)\n", 215 | "[53]/[100] top5=68.40%(71.85%) top1=51.80%(48.65%)\n", 216 | "[54]/[100] top5=67.20%(71.77%) top1=50.40%(48.68%)\n", 217 | "[55]/[100] top5=64.80%(71.64%) top1=45.20%(48.62%)\n", 218 | "[56]/[100] top5=78.40%(71.76%) top1=63.80%(48.88%)\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "n_chunks = 100\n", 224 | "n_val = val_features.shape[0]\n", 225 | "\n", 226 | "prec_top5 = AverageMeter()\n", 227 | "prec_top1 = AverageMeter()\n", 228 | "index_labeled = torch.arange(n_val, train_features.shape[0])\n", 229 | "index_unlabeled = torch.arange(n_val)\n", 230 | "num_labeled_data = index_labeled.shape[0]\n", 231 | "\n", 232 | "for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):\n", 233 | "\n", 234 | " # calculate similarity matrix\n", 235 | " dist = torch.mm(train_features[index_unlabeled_chunk], train_features[index_labeled].t())\n", 236 | "\n", 237 | " K = min(num_labeled_data, 200)\n", 238 | " bs = index_unlabeled_chunk.shape[0]\n", 239 | " yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)\n", 240 | " candidates = train_labels.view(1,-1).expand(bs, -1)\n", 241 | " retrieval = torch.gather(candidates, 1, index_labeled[yi])\n", 242 | " retrieval_one_hot = torch.zeros(bs * K, num_class).to(device)\n", 243 | " retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)\n", 244 | "\n", 245 | " temperature = 0.1\n", 246 | "\n", 247 | " yd_transform = (yd / temperature).exp_()\n", 248 | " probs = torch.sum(torch.mul(retrieval_one_hot.view(bs, -1 , num_class), yd_transform.view(bs, -1, 1)), 1)\n", 249 | " probs.div_(probs.sum(dim=1, keepdim=True))\n", 250 | " probs_sorted, predictions = probs.sort(1, True)\n", 251 | " correct = predictions.eq(train_labels[index_unlabeled_chunk].data.view(-1,1))\n", 252 | " \n", 253 | " top5 = 
torch.any(correct[:, :5], dim=1).float().mean() \n", 254 | " top1 = correct[:, 0].float().mean() \n", 255 | " prec_top5.update(top5, bs)\n", 256 | " prec_top1.update(top1, bs)\n", 257 | " print('[{}]/[{}] top5={:.2%}({:.2%}) top1={:.2%}({:.2%})'.format(\n", 258 | " i_chunks, n_chunks, prec_top5.val, prec_top5.avg, prec_top1.val, prec_top1.avg))" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "scrolled": false 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "# n_chunks = 100\n", 270 | "\n", 271 | "# prec_top5 = AverageMeter()\n", 272 | "# for num_labeled_data in [10000]:\n", 273 | "# index_labeled = []\n", 274 | "# index_unlabeled = []\n", 275 | "# data_per_class = num_labeled_data // args.num_class\n", 276 | "# for c in range(args.num_class):\n", 277 | "# indexes_c = perm[train_labels[perm] == c]\n", 278 | "# index_labeled.append(indexes_c[:data_per_class])\n", 279 | "# index_unlabeled.append(indexes_c[data_per_class:])\n", 280 | "# index_labeled = torch.cat(index_labeled)\n", 281 | "# index_unlabeled = torch.cat(index_unlabeled)\n", 282 | "\n", 283 | "# for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):\n", 284 | " \n", 285 | "# # calculate similarity matrix\n", 286 | "# dist = torch.mm(train_features[index_unlabeled_chunk], train_features[index_labeled].t())\n", 287 | "\n", 288 | "# K = min(num_labeled_data, 5000)\n", 289 | "# bs = index_unlabeled_chunk.shape[0]\n", 290 | "# yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)\n", 291 | "# candidates = train_labels.view(1,-1).expand(bs, -1)\n", 292 | "# retrieval = torch.gather(candidates, 1, index_labeled[yi])\n", 293 | "# retrieval_one_hot = torch.zeros(bs * K, num_class).to(device)\n", 294 | "# retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)\n", 295 | "\n", 296 | "# temperature = 0.1\n", 297 | "\n", 298 | "# yd_transform = (yd / temperature).exp_()\n", 299 | "# probs = torch.sum(torch.mul(retrieval_one_hot.view(bs, -1 , num_class), yd_transform.view(bs, -1, 1)), 1)\n", 300 | "# probs.div_(probs.sum(dim=1, keepdim=True))\n", 301 | "# probs_sorted, predictions = probs.sort(1, True)\n", 302 | "# correct = predictions.eq(train_labels[index_unlabeled_chunk].data.view(-1,1))\n", 303 | "# top5 = torch.any(correct[:, :5], dim=1).float().mean() \n", 304 | " \n", 305 | "# prec_top5.update(top5, bs)\n", 306 | "# print('[{}]/[{}] {:.2%} {:.2%}'.format(i_chunks, n_chunks, prec_top5.val, prec_top5.avg))" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [] 315 | } 316 | ], 317 | "metadata": { 318 | "kernelspec": { 319 | "display_name": "Python 3", 320 | "language": "python", 321 | "name": "python3" 322 | }, 323 | "language_info": { 324 | "codemirror_mode": { 325 | "name": "ipython", 326 | "version": 3 327 | }, 328 | "file_extension": ".py", 329 | "mimetype": "text/x-python", 330 | "name": "python", 331 | "nbconvert_exporter": "python", 332 | "pygments_lexer": "ipython3", 333 | "version": "3.6.8" 334 | } 335 | }, 336 | "nbformat": 4, 337 | "nbformat_minor": 2 338 | } 339 | -------------------------------------------------------------------------------- /cifar-semi.py: -------------------------------------------------------------------------------- 1 | """Train CIFAR10 with PyTorch.""" 2 | import argparse 3 | import os 4 | import random 5 | import time 6 | from pprint import pprint 7 | 8 | import numpy as np 9 | from skimage.color import rgb2gray 10 | 
import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | from torch.utils.data.sampler import SubsetRandomSampler 16 | from torchvision import transforms 17 | from torchvision.datasets import CIFAR10 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | from lib.datasets import PseudoCIFAR10 21 | from lib.utils import AverageMeter, accuracy, CosineAnnealingLRWithRestart 22 | from lib.models import WideResNet, resnet18_cifar 23 | from test import validate 24 | 25 | 26 | def get_dataloader(args): 27 | if not args.input_gray: 28 | normalize = transforms.Normalize( 29 | (0.4914, 0.4822, 0.4465), 30 | (0.2470, 0.2435, 0.2616)) 31 | transform_train = transforms.Compose([ 32 | transforms.Pad(4, padding_mode='reflect'), 33 | transforms.RandomCrop(32), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | normalize, 37 | ]) 38 | transform_test = transforms.Compose([ 39 | transforms.ToTensor(), 40 | normalize, 41 | ]) 42 | else: 43 | to_gray = transforms.Lambda(lambda img: torch.from_numpy( 44 | rgb2gray(np.array(img))).unsqueeze(0).float()) 45 | transform_train = transforms.Compose([ 46 | transforms.Pad(4, padding_mode='reflect'), 47 | transforms.RandomCrop(32), 48 | transforms.RandomHorizontalFlip(), 49 | to_gray, 50 | ]) 51 | transform_test = to_gray 52 | 53 | testset = CIFAR10(root=args.data_dir, train=False, 54 | download=True, transform=transform_test) 55 | testloader = torch.utils.data.DataLoader( 56 | testset, shuffle=False, 57 | batch_size=args.batch_size, 58 | num_workers=args.num_workers) 59 | 60 | trainset = CIFAR10(root=args.data_dir, train=True, 61 | download=True, transform=transform_test) 62 | 63 | args.ndata = len(trainset) 64 | num_labeled_data = args.num_labeled 65 | num_unlabeled_data = args.ndata - num_labeled_data 66 | 67 | if args.pseudo_file is not None: 68 | pseudo_dict = torch.load(args.pseudo_file) 69 | labeled_indexes = pseudo_dict['labeled_indexes'] 70 | else: 71 | torch.manual_seed(args.rng_seed) 72 | perm = torch.randperm(args.ndata) 73 | labeled_indexes = perm[:num_labeled_data] 74 | 75 | pseudo_trainset = PseudoCIFAR10( 76 | labeled_indexes=labeled_indexes, root=args.data_dir, 77 | train=True, transform=transform_train) 78 | 79 | # load pseudo labels 80 | if args.pseudo_file is not None: 81 | pseudo_num = int(num_unlabeled_data * args.pseudo_ratio) 82 | pseudo_indexes = pseudo_dict['pseudo_indexes'][:pseudo_num] 83 | pseudo_labels = pseudo_dict['pseudo_labels'][:pseudo_num] 84 | pseudo_trainset.set_pseudo(pseudo_indexes, pseudo_labels) 85 | 86 | pseudo_trainloder = torch.utils.data.DataLoader( 87 | pseudo_trainset, batch_size=args.batch_size, 88 | shuffle=True, num_workers=args.num_workers) 89 | 90 | print('-' * 80) 91 | print('selected labeled indexes: ', labeled_indexes) 92 | 93 | return testloader, pseudo_trainloder 94 | 95 | 96 | def build_model(args): 97 | if args.architecture == 'resnet18': 98 | net = resnet18_cifar(low_dim=args.num_class, norm=False) 99 | elif args.architecture.startswith('wrn'): 100 | split = args.architecture.split('-') 101 | net = WideResNet(depth=int(split[1]), widen_factor=int(split[2]), 102 | num_classes=args.num_class, norm=False) 103 | else: 104 | raise ValueError('architecture should be resnet18 or wrn') 105 | if args.input_gray: 106 | net.conv1 = nn.Conv2d(1, net.conv1.out_channels, 107 | kernel_size=3, stride=1, padding=1, bias=False) 108 | net = net.to(args.device) 109 | 110 | 
print('#param: {}'.format(sum([p.nelement() for p in net.parameters()]))) 111 | 112 | if args.device == 'cuda': 113 | net = torch.nn.DataParallel( 114 | net, device_ids=range(torch.cuda.device_count())) 115 | cudnn.benchmark = True 116 | 117 | # resume from unsupervised pretrain 118 | if len(args.resume) > 0: 119 | # Load checkpoint. 120 | print('==> Resuming from unsupervised pretrained checkpoint..') 121 | checkpoint = torch.load(args.resume) 122 | # only load shared conv layers, don't load fc 123 | model_dict = net.state_dict() 124 | if not args.input_gray: 125 | pretrained_dict = checkpoint['net'] 126 | else: 127 | lst = ['conv1', 'block1', 'block2', 'block3'] 128 | pretrained_dict = { 129 | 'module.' + lst[int(k[0])] + k[1:]: v for k, v in checkpoint.items()} 130 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 131 | if k in model_dict 132 | and v.size() == model_dict[k].size()} 133 | assert len(pretrained_dict) > 0 134 | model_dict.update(pretrained_dict) 135 | net.load_state_dict(model_dict) 136 | 137 | return net 138 | 139 | 140 | def get_lr_scheduler(optimizer, lr_scheduler, max_iters): 141 | if args.lr_scheduler == 'cosine': 142 | scheduler = CosineAnnealingLR(optimizer, max_iters, eta_min=0.00001) 143 | elif args.lr_scheduler == 'cosine-with-restart': 144 | scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.00001) 145 | else: 146 | raise ValueError("not supported") 147 | 148 | return scheduler 149 | 150 | 151 | # Training 152 | def train(net, optimizer, scheduler, trainloader, testloader, criterion, summary_writer, args): 153 | train_loss = AverageMeter() 154 | data_time = AverageMeter() 155 | batch_time = AverageMeter() 156 | top1 = AverageMeter() 157 | top2 = AverageMeter() 158 | 159 | best_acc = 0 160 | end = time.time() 161 | 162 | def inf_generator(trainloader): 163 | while True: 164 | for data in trainloader: 165 | yield data 166 | 167 | for step, (inputs, targets) in enumerate(inf_generator(trainloader)): 168 | if step >= args.max_iters: 169 | break 170 | 171 | data_time.update(time.time() - end) 172 | 173 | inputs = inputs.to(args.device) 174 | targets = targets.to(args.device) 175 | 176 | # switch to train mode 177 | net.train() 178 | scheduler.step() 179 | optimizer.zero_grad() 180 | 181 | outputs = net(inputs) 182 | loss = criterion(outputs, targets).mean() 183 | prec1, prec2 = accuracy(outputs, targets, topk=(1, 2)) 184 | top1.update(prec1[0], inputs.size(0)) 185 | top2.update(prec2[0], inputs.size(0)) 186 | 187 | loss.backward() 188 | optimizer.step() 189 | 190 | train_loss.update(loss.item(), inputs.size(0)) 191 | 192 | # measure elapsed time 193 | batch_time.update(time.time() - end) 194 | end = time.time() 195 | 196 | summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], step) 197 | summary_writer.add_scalar('top1', top1.val, step) 198 | summary_writer.add_scalar('top2', top2.val, step) 199 | summary_writer.add_scalar('batch_time', batch_time.val, step) 200 | summary_writer.add_scalar('data_time', data_time.val, step) 201 | summary_writer.add_scalar('train_loss', train_loss.val, step) 202 | 203 | if step % args.print_freq == 0: 204 | lr = optimizer.param_groups[0]["lr"] 205 | print(f'Train: [{step}/{args.max_iters}] ' 206 | f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 207 | f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) ' 208 | f'Lr: {lr:.5f} ' 209 | f'prec1: {top1.val:.3f} ({top1.avg:.3f}) ' 210 | f'prec2: {top2.val:.3f} ({top2.avg:.3f}) ' 211 | f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})') 212 | 213 | if 
(step + 1) % args.eval_freq == 0 or step == args.max_iters - 1: 214 | acc = validate(testloader, net, criterion, 215 | device=args.device, print_freq=args.print_freq) 216 | 217 | summary_writer.add_scalar('val_top1', acc, step) 218 | 219 | if acc > best_acc: 220 | best_acc = acc 221 | state = { 222 | 'step': step, 223 | 'best_acc': best_acc, 224 | 'net': net.state_dict(), 225 | 'optimizer': optimizer.state_dict(), 226 | } 227 | os.makedirs(args.model_dir, exist_ok=True) 228 | torch.save(state, os.path.join(args.model_dir, 'ckpt.pth.tar')) 229 | 230 | print('best accuracy: {:.2f}\n'.format(best_acc)) 231 | 232 | 233 | def main(args): 234 | # Data 235 | print('==> Preparing data..') 236 | testloader, pseudo_trainloder = get_dataloader(args) 237 | 238 | print('==> Building model..') 239 | net = build_model(args) 240 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, 241 | weight_decay=5e-4, nesterov=True) 242 | 243 | criterion = nn.__dict__[args.criterion]().to(args.device) 244 | scheduler = get_lr_scheduler(optimizer, args.lr_scheduler, args.max_iters) 245 | 246 | if args.eval: 247 | return validate(testloader, net, criterion, 248 | device=args.device, print_freq=args.print_freq) 249 | # summary writer 250 | os.makedirs(args.log_dir, exist_ok=True) 251 | summary_writer = SummaryWriter(args.log_dir) 252 | 253 | train(net, optimizer, scheduler, pseudo_trainloder, 254 | testloader, criterion, summary_writer, args) 255 | 256 | 257 | if __name__ == '__main__': 258 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 259 | parser.add_argument('--data_dir', '--dataDir', default='./data', 260 | type=str, metavar='DIR') 261 | parser.add_argument('--model-root', default='./checkpoint/cifar10-semi', 262 | type=str, metavar='DIR', 263 | help='root directory to save checkpoint') 264 | parser.add_argument('--log-root', default='./tensorboard/cifar10-semi', 265 | type=str, metavar='DIR', 266 | help='root directory to save tensorboard logs') 267 | parser.add_argument('--exp-name', default='exp', type=str, 268 | help='experiment name, used to determine log_dir and model_dir') 269 | parser.add_argument('--lr', default=0.01, type=float, 270 | metavar='LR', help='learning rate') 271 | parser.add_argument('--lr-scheduler', default='cosine', type=str, 272 | choices=['multi-step', 'cosine', 273 | 'cosine-with-restart'], 274 | help='which lr scheduler to use') 275 | parser.add_argument('--resume', '-r', default='', type=str, 276 | metavar='FILE', help='resume from checkpoint') 277 | parser.add_argument('--eval', action='store_true', help='test only') 278 | parser.add_argument('--finetune', action='store_true', 279 | help='only train the last fc layer') 280 | parser.add_argument('-j', '--num-workers', default=2, type=int, 281 | metavar='N', help='number of workers to load data') 282 | parser.add_argument('-b', '--batch-size', default=128, type=int, 283 | metavar='N', help='batch size') 284 | parser.add_argument('--max-iters', default=500000, type=int, 285 | metavar='N', help='number of iterations') 286 | parser.add_argument('--num-labeled', default=500, type=int, 287 | metavar='N', help='number of labeled data') 288 | parser.add_argument('--rng-seed', default=0, type=int, 289 | metavar='N', help='random number generator seed') 290 | parser.add_argument('--gpus', default='0', type=str, metavar='GPUS') 291 | parser.add_argument('--eval-freq', default=500, type=int, 292 | metavar='N', help='eval frequency') 293 | parser.add_argument('--print-freq', default=100, type=int, 294 | 
metavar='N', help='print frequency') 295 | parser.add_argument('--criterion', default='CrossEntropyLoss', type=str, 296 | choices=['CrossEntropyLoss', 'MultiMarginLoss']) 297 | parser.add_argument('--pseudo-file', type=str, 298 | metavar='FILE', help='pseudo file to load', required=True) 299 | parser.add_argument('--input-gray', action='store_true', 300 | help='set to load a colorization-pretrained model ' 301 | '(colorization models use gray images as input)') 302 | parser.add_argument('--pseudo-ratio', default=1, type=float, metavar='0-1', 303 | help='ratio of unlabeled data to use for pseudo labels') 304 | parser.add_argument('--architecture', '--arch', default='wrn-28-2', type=str, 305 | help='which backbone to use') 306 | args, rest = parser.parse_known_args() 307 | print(rest) 308 | 309 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 310 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 311 | args.num_class = 10 312 | args.log_dir = os.path.join(args.log_root, args.exp_name) 313 | args.model_dir = os.path.join(args.model_root, args.exp_name) 314 | 315 | torch.manual_seed(args.rng_seed) 316 | torch.cuda.manual_seed(args.rng_seed) 317 | random.seed(args.rng_seed) 318 | torch.set_printoptions(threshold=50, precision=4) 319 | 320 | print('-' * 80) 321 | pprint(vars(args)) 322 | 323 | main(args) 324 | 325 | print('-' * 80) 326 | pprint(vars(args)) 327 | -------------------------------------------------------------------------------- /notebooks/nc-colorization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import sys\n", 11 | "sys.path.append('../')\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.optim as optim\n", 15 | "import torch.nn.functional as F\n", 16 | "import torch.backends.cudnn as cudnn\n", 17 | "\n", 18 | "import torchvision\n", 19 | "import torchvision.transforms as transforms\n", 20 | "\n", 21 | "import math\n", 22 | "import os\n", 23 | "import argparse\n", 24 | "import time\n", 25 | "\n", 26 | "from lib import models, datasets\n", 27 | "\n", 28 | "\n", 29 | "import numpy as np\n", 30 | "import scipy as sp\n", 31 | "import scipy.sparse.linalg as linalg\n", 32 | "import scipy.sparse as sparse\n", 33 | "\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import easydict as edict" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# parameters\n", 45 | "args = edict.EasyDict()\n", 46 | "\n", 47 | "args.cache = '../checkpoint/train_features_labels_cache/colorization_embedding_128.t7'\n", 48 | "args.save_path = '../checkpoint/pseudos/colorization_nc_pseudo_wrn-28-2'\n", 49 | "os.makedirs(args.save_path, exist_ok=True)\n", 50 | "\n", 51 | "args.num_class = 10\n", 52 | "args.rng_seed = 0" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": { 59 | "scrolled": true 60 | }, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "torch.float32 torch.int64\n", 67 | "torch.Size([50000, 128]) torch.Size([50000])\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "train_features = torch.load(args.cache)\n", 73 | "train_labels = torch.Tensor(datasets.CIFAR10Instance(root='../data', train=True).targets).long()\n", 74 | "\n", 75 | "print(train_features.dtype, train_labels.dtype)\n", 76 | 
"print(train_features.shape, train_labels.shape)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# use cpu because the follow computation need a lot of memory" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "device = 'cpu'\n", 93 | "train_features, train_labels = train_features.to(device), train_labels.to(device)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "tensor([36044, 49165, 37807, ..., 42128, 15898, 31476])\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "num_train_data = train_labels.shape[0]\n", 111 | "num_class = torch.max(train_labels) + 1\n", 112 | "\n", 113 | "torch.manual_seed(args.rng_seed)\n", 114 | "torch.cuda.manual_seed_all(args.rng_seed)\n", 115 | "perm = torch.randperm(num_train_data).to(device)\n", 116 | "print(perm)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "# constrained normalized cut" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "similarity done\n", 136 | "L_sys done\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "K = 20\n", 142 | "def make_column_normalize(X):\n", 143 | " return X.div(torch.norm(X, p=2, dim=0, keepdim=True))\n", 144 | "\n", 145 | "cosin_similarity = torch.mm(train_features, train_features.t())\n", 146 | "dist = (1 - cosin_similarity) / 2\n", 147 | "\n", 148 | "dist_sorted, idx = dist.topk(K, dim=1, largest=False, sorted=True)\n", 149 | "k_dist = dist_sorted[:, -1:]\n", 150 | "\n", 151 | "similarity_dense = torch.exp(-dist_sorted * 2 / k_dist)\n", 152 | "similarity_sparse = torch.zeros_like(cosin_similarity)\n", 153 | "similarity_sparse[torch.arange(num_train_data).view(-1, 1), idx[:, 1:]] = similarity_dense[:, 1:]\n", 154 | "similarity = torch.max(similarity_sparse, similarity_sparse.t())\n", 155 | "print('similarity done')\n", 156 | "\n", 157 | "degree = similarity.sum(0)\n", 158 | "degree_normed = (degree**(-0.5))\n", 159 | "L_sys = degree_normed.view(-1, 1) * (degree.diag() - similarity) * degree_normed.view(1, -1)\n", 160 | "print('L_sys done')" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": { 167 | "scrolled": true 168 | }, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "eigenvectors done\n", 175 | "tensor([0.0160, 0.0236, 0.0306, 0.0318, 0.0359, 0.0400, 0.0453, 0.0561, 0.0603,\n", 176 | " 0.0621, 0.0635, 0.0731, 0.0745, 0.0775, 0.0795, 0.0860, 0.0881, 0.0942,\n", 177 | " 0.0967, 0.1003, 0.1035, 0.1047, 0.1070, 0.1094, 0.1170, 0.1221, 0.1237,\n", 178 | " 0.1275, 0.1290, 0.1316, 0.1377, 0.1384, 0.1393, 0.1425, 0.1435, 0.1466,\n", 179 | " 0.1505, 0.1539, 0.1548, 0.1591, 0.1615, 0.1644, 0.1660, 0.1665, 0.1699,\n", 180 | " 0.1715, 0.1721, 0.1727, 0.1734, 0.1756, 0.1794, 0.1803, 0.1805, 0.1819,\n", 181 | " 0.1854, 0.1874, 0.1875, 0.1881, 0.1889, 0.1919, 0.1926, 0.1952, 0.1975,\n", 182 | " 0.1997, 0.2004, 0.2010, 0.2025, 0.2035, 0.2053, 0.2068, 0.2088, 0.2101,\n", 183 | " 0.2117, 0.2138, 0.2145, 0.2150, 0.2183, 0.2190, 0.2205, 0.2220, 0.2245,\n", 184 | " 0.2251, 0.2275, 0.2282, 0.2285, 0.2297, 0.2299, 0.2316, 0.2330, 
0.2358,\n", 185 | " 0.2359, 0.2381, 0.2396, 0.2409, 0.2425, 0.2440, 0.2448, 0.2459, 0.2462,\n", 186 | " 0.2480, 0.2492, 0.2495, 0.2518, 0.2520, 0.2531, 0.2543, 0.2545, 0.2562,\n", 187 | " 0.2574, 0.2582, 0.2592, 0.2598, 0.2612, 0.2621, 0.2633, 0.2638, 0.2642,\n", 188 | " 0.2653, 0.2669, 0.2679, 0.2689, 0.2699, 0.2706, 0.2716, 0.2721, 0.2729,\n", 189 | " 0.2744, 0.2750, 0.2761, 0.2774, 0.2784, 0.2794, 0.2803, 0.2810, 0.2821,\n", 190 | " 0.2825, 0.2835, 0.2847, 0.2850, 0.2853, 0.2867, 0.2877, 0.2882, 0.2893,\n", 191 | " 0.2897, 0.2902, 0.2911, 0.2920, 0.2933, 0.2944, 0.2951, 0.2962, 0.2966,\n", 192 | " 0.2973, 0.2978, 0.2984, 0.2990, 0.2994, 0.3007, 0.3011, 0.3022, 0.3023,\n", 193 | " 0.3030, 0.3043, 0.3046, 0.3054, 0.3062, 0.3067, 0.3069, 0.3075, 0.3080,\n", 194 | " 0.3097, 0.3110, 0.3114, 0.3119, 0.3131, 0.3133, 0.3144, 0.3152, 0.3153,\n", 195 | " 0.3159, 0.3172, 0.3176, 0.3183, 0.3188, 0.3193, 0.3197, 0.3205, 0.3212,\n", 196 | " 0.3216, 0.3220, 0.3226, 0.3231, 0.3238, 0.3242, 0.3249, 0.3259, 0.3262,\n", 197 | " 0.3270])\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "num_eigenvectors = 200 # the number of precomputed spectral eigenvectors.\n", 203 | "\n", 204 | "eigenvalues, eigenvectors = linalg.eigs(L_sys.numpy(), k=num_eigenvectors, which='SR', tol=1e-2, maxiter=30000)\n", 205 | "eigenvalues, eigenvectors = torch.from_numpy(eigenvalues.real)[1:], torch.from_numpy(eigenvectors.real)[:, 1:]\n", 206 | "eigenvalues, idx = eigenvalues.sort()\n", 207 | "eigenvectors = eigenvectors[:, idx]\n", 208 | "print('eigenvectors done')\n", 209 | "print(eigenvalues)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 8, 215 | "metadata": { 216 | "scrolled": false 217 | }, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "num_labeled= 50 T_nc=1, prec=48.40, AUC=60.85\n", 224 | "num_labeled= 100 T_nc=1, prec=51.91, AUC=67.34\n", 225 | "num_labeled= 250 T_nc=1, prec=61.03, AUC=76.31\n", 226 | "num_labeled= 500 T_nc=1, prec=64.05, AUC=80.04\n", 227 | "num_labeled=1000 T_nc=1, prec=64.84, AUC=81.78\n", 228 | "num_labeled=2000 T_nc=1, prec=64.84, AUC=81.89\n", 229 | "num_labeled=4000 T_nc=1, prec=65.60, AUC=82.93\n", 230 | "num_labeled=8000 T_nc=1, prec=65.11, AUC=82.03\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "fig = plt.figure(dpi=200)\n", 236 | "\n", 237 | "for num_labeled_data in [50, 100, 250, 500, 1000, 2000, 4000, 8000]:\n", 238 | " # index of labeled and unlabeled\n", 239 | " # even split\n", 240 | " index_labeled = []\n", 241 | " index_unlabeled = []\n", 242 | " data_per_class = num_labeled_data // args.num_class\n", 243 | " for c in range(10):\n", 244 | " indexes_c = perm[train_labels[perm] == c]\n", 245 | " index_labeled.append(indexes_c[:data_per_class])\n", 246 | " index_unlabeled.append(indexes_c[data_per_class:])\n", 247 | " index_labeled = torch.cat(index_labeled)\n", 248 | " index_unlabeled = torch.cat(index_unlabeled)\n", 249 | "\n", 250 | "# index_labeled = perm[:num_labeled_data]\n", 251 | "# index_unlabeled = perm[num_labeled_data:]\n", 252 | " \n", 253 | " # prior\n", 254 | " unary_prior = torch.zeros([num_train_data, num_class])\n", 255 | " unary_prior[index_labeled, :] = -1\n", 256 | " unary_prior[index_labeled, train_labels[index_labeled]] = 1\n", 257 | " AQ = unary_prior.abs()\n", 258 | " pd = degree.view(-1, 1) * (AQ + unary_prior) / 2\n", 259 | " nd = degree.view(-1, 1) * (AQ - unary_prior) / 2\n", 260 | " np_ratio = pd.sum(dim=0) / nd.sum(dim=0)\n", 261 | " unary_prior_norm = (pd / 
261 |     "    unary_prior_norm = (pd / np_ratio).sqrt() - (nd * np_ratio).sqrt()\n",
262 |     "    unary_prior_norm = make_column_normalize(unary_prior_norm)\n",
263 |     "    \n",
264 |     "    # logits and prediction\n",
265 |     "    alpha = 0\n",
266 |     "    lambda_reverse = (1 / (eigenvalues - alpha)).view(1, -1)\n",
267 |     "    logits = torch.mm(lambda_reverse * eigenvectors, torch.mm(eigenvectors.t(), unary_prior_norm))\n",
268 |     "    logits = make_column_normalize(logits) * math.sqrt(logits.shape[0]) \n",
269 |     "    logits = logits - logits.max(1, keepdim=True)[0]\n",
270 |     "    _, predict = logits.max(dim=1)\n",
271 |     "    \n",
272 |     "    for temperature_nc in [1]:#, 2, 3, 5, 10, 15, 20, 25, 30, 35, 40, 100]: \n",
273 |     "        # pseudo weights\n",
274 |     "        logits_sorted = logits.sort(dim=1, descending=True)[0]\n",
275 |     "        subtract = logits_sorted[:, 0] - logits_sorted[:, 1]\n",
276 |     "        pseudo_weights = 1 - torch.exp(- subtract / temperature_nc)\n",
277 |     "        \n",
278 |     "        exp = (logits * temperature_nc).exp()\n",
279 |     "        probs = exp / exp.sum(1, keepdim=True)\n",
280 |     "        probs_sorted, predict_all = probs.sort(1, True)\n",
281 |     "        assert torch.all(predict == predict_all[:, 0])\n",
282 |     "\n",
283 |     "        idx = pseudo_weights[index_unlabeled].sort(dim=0, descending=True)[1]\n",
284 |     "        pseudo_indexes = index_unlabeled[idx]\n",
285 |     "        pseudo_labels = predict[index_unlabeled][idx]\n",
286 |     "        pseudo_probs = probs[index_unlabeled][idx]\n",
287 |     "        pseudo_weights = pseudo_weights[index_unlabeled][idx]\n",
288 |     "        assert torch.all(pseudo_labels == pseudo_probs.max(1)[1])\n",
289 |     "        \n",
290 |     "        save_dict = {\n",
291 |     "            'pseudo_indexes': pseudo_indexes,\n",
292 |     "            'pseudo_labels': pseudo_labels,\n",
293 |     "            'pseudo_probs': pseudo_probs,\n",
294 |     "            'pseudo_weights': pseudo_weights,\n",
295 |     "            'labeled_indexes': index_labeled,\n",
296 |     "            'unlabeled_indexes': index_unlabeled,\n",
297 |     "        }\n",
298 |     "        torch.save(save_dict, os.path.join(args.save_path, '{}.pth.tar'.format(num_labeled_data)))\n",
299 |     "\n",
300 |     "        # for plot\n",
301 |     "        correct = pseudo_labels == train_labels[pseudo_indexes]\n",
302 |     "        \n",
303 |     "        entropy = - (pseudo_probs * torch.log(pseudo_probs + 1e-7)).sum(dim=1)\n",
304 |     "        confidence = (- entropy * 1).exp()\n",
305 |     "        confidence /= confidence.max()\n",
306 |     "\n",
307 |     "        arange = 1 + np.arange(confidence.shape[0])\n",
308 |     "        xs = arange / confidence.shape[0]\n",
309 |     "        correct_tmp = correct[confidence.sort(descending=True)[1]]\n",
310 |     "        accuracies = np.cumsum(correct_tmp.numpy()) / arange\n",
311 |     "        plt.plot(xs, accuracies, label='num_labeled_data={}'.format(num_labeled_data))\n",
312 |     "\n",
313 |     "        acc = correct.float().mean()\n",
314 |     "\n",
315 |     "        print('num_labeled={:4} T_nc={}, prec={:.2f}, AUC={:.2f}'.format(\n",
316 |     "            num_labeled_data, temperature_nc, acc * 100, accuracies.mean() * 100))\n",
317 |     "    \n",
318 |     "plt.xlabel('accumulated unlabeled data ratio')\n",
319 |     "plt.ylabel('unlabeled top1 accuracy')\n",
320 |     "plt.xticks(np.arange(0, 1.01, 0.1))\n",
321 |     "plt.grid()\n",
322 |     "plt.title('num_eigenvectors={}'.format(num_eigenvectors))\n",
323 |     "legend = plt.legend(loc='upper left', bbox_to_anchor=(1, 1))\n",
324 |     "plt.show()"
325 |    ]
326 |   },
327 |   {
328 |    "cell_type": "code",
329 |    "execution_count": null,
330 |    "metadata": {},
331 |    "outputs": [],
332 |    "source": []
333 |   }
334 |  ],
335 |  "metadata": {
336 |   "kernelspec": {
337 |    "display_name": "Python 3",
338 |    "language": "python",
339 |    "name": "python3"
340 |   },
341 |   "language_info": {
342 |    "codemirror_mode": {
343 |     "name": "ipython",
344 |     "version": 3
345 |    },
346 |    "file_extension": ".py",
347 |    "mimetype": "text/x-python",
348 |    "name": "python",
349 |    "nbconvert_exporter": "python",
350 |    "pygments_lexer": "ipython3",
351 |    "version": "3.6.8"
352 |   }
353 |  },
354 |  "nbformat": 4,
355 |  "nbformat_minor": 2
356 | }
357 | 
--------------------------------------------------------------------------------
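The dense cell above is easier to read as linear algebra: with the num_eigenvectors smallest eigenpairs of L_sys stacked as U and diag(lambda), the line that computes logits evaluates U diag(1/(lambda - alpha)) U^T q for the column-normalized prior q, i.e. an approximate solve of (L_sys - alpha*I) x = q in the truncated eigenbasis (alpha = 0 in the notebook). A self-contained sketch of just that step on synthetic data (all shapes and values here are illustrative):

import torch

n, m, C = 1000, 50, 10  # nodes, retained eigenpairs, classes (illustrative sizes)
eigenvalues = torch.linspace(0.01, 0.5, m)           # stand-in for the sorted spectrum
eigenvectors = torch.linalg.qr(torch.randn(n, m)).Q  # orthonormal columns, like U
prior = torch.randn(n, C)                            # stands in for unary_prior_norm

alpha = 0
lambda_reverse = (1 / (eigenvalues - alpha)).view(1, -1)
# same two matmuls as the notebook: U diag(1/(lambda - alpha)) U^T q
logits = torch.mm(lambda_reverse * eigenvectors, torch.mm(eigenvectors.t(), prior))
assert logits.shape == (n, C)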
"file_extension": ".py", 347 | "mimetype": "text/x-python", 348 | "name": "python", 349 | "nbconvert_exporter": "python", 350 | "pygments_lexer": "ipython3", 351 | "version": "3.6.8" 352 | } 353 | }, 354 | "nbformat": 4, 355 | "nbformat_minor": 2 356 | } 357 | -------------------------------------------------------------------------------- /imagenet-semi.py: -------------------------------------------------------------------------------- 1 | """Train ImageNet with PyTorch.""" 2 | import argparse 3 | import glob 4 | import os 5 | import random 6 | import time 7 | from pprint import pprint 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torchvision.datasets as datasets 14 | import torchvision.models as models 15 | import torchvision.transforms as transforms 16 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR 17 | from torch.utils.data.sampler import SubsetRandomSampler 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | from lib.datasets import PseudoDatasetFolder 21 | from lib.utils import AverageMeter, accuracy, CosineAnnealingLRWithRestart 22 | from test import validate 23 | 24 | best_acc = 0 25 | global_step = 0 26 | 27 | 28 | def get_dataloader(args): 29 | normalize = transforms.Normalize( 30 | (0.485, 0.456, 0.406), 31 | (0.229, 0.224, 0.225)) 32 | transform_train = transforms.Compose([ 33 | transforms.RandomResizedCrop(224), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | normalize, 37 | ]) 38 | 39 | transform_test = transforms.Compose([ 40 | transforms.Resize(256), 41 | transforms.CenterCrop(224), 42 | transforms.ToTensor(), 43 | normalize, 44 | ]) 45 | traindir = os.path.join(args.data_dir, 'train') 46 | valdir = os.path.join(args.data_dir, 'val') 47 | 48 | testset = datasets.ImageFolder(valdir, transform=transform_test) 49 | testloader = torch.utils.data.DataLoader( 50 | testset, shuffle=False, 51 | batch_size=args.batch_size, 52 | num_workers=args.num_workers) 53 | 54 | trainset = datasets.ImageFolder(traindir, transform=transform_train) 55 | 56 | # split labeled and unlabeled 57 | args.ndata = len(trainset) 58 | num_labeled = args.num_labeled 59 | num_unlabeled = args.ndata - num_labeled 60 | 61 | torch.manual_seed(args.rng_seed) 62 | perm = torch.randperm(args.ndata) 63 | 64 | index_labeled = [] 65 | index_unlabeled = [] 66 | data_per_class = num_labeled // args.num_class 67 | train_labels = torch.Tensor([x[1] for x in trainset.samples]) 68 | for c in range(args.num_class): 69 | indexes_c = perm[train_labels[perm] == c] 70 | index_labeled.append(indexes_c[:data_per_class]) 71 | index_unlabeled.append(indexes_c[data_per_class:]) 72 | 73 | args.index_labeled = torch.cat(index_labeled) 74 | args.index_unlabeled = torch.cat(index_unlabeled) 75 | 76 | print('-' * 80) 77 | print('selected labeled indexes: ', args.index_labeled) 78 | 79 | pseudo_trainset = PseudoDatasetFolder( 80 | trainset, labeled_indexes=args.index_labeled) 81 | # load pseudo labels 82 | if args.pseudo_dir is not None: 83 | pseudo_files = glob.glob(args.pseudo_dir + '/*') 84 | pseudo_num_per_chunk = int( 85 | num_unlabeled * args.pseudo_ratio / len(pseudo_files)) 86 | 87 | pseudo_indexes = [] 88 | pseudo_labels = [] 89 | for pseudo_file in pseudo_files: 90 | pseudo_dict = torch.load(pseudo_file) 91 | pseudo_indexes.append( 92 | pseudo_dict['pseudo_indexes'][:pseudo_num_per_chunk]) 93 | pseudo_labels.append( 94 | pseudo_dict['pseudo_labels'][:pseudo_num_per_chunk]) 95 | 
95 |             assert (args.index_labeled == pseudo_dict['labeled_indexes']).all()
96 |         pseudo_indexes = torch.cat(pseudo_indexes)
97 |         pseudo_labels = torch.cat(pseudo_labels)
98 | 
99 |         assert num_labeled == args.index_labeled.shape[0]
100 | 
101 |         pseudo_trainset.set_pseudo(pseudo_indexes, pseudo_labels)
102 | 
103 |         print('num_pseudo = {}'.format(pseudo_indexes.shape[0]))
104 | 
105 |     pseudo_trainloader = torch.utils.data.DataLoader(
106 |         pseudo_trainset, batch_size=args.batch_size,
107 |         shuffle=True, num_workers=args.num_workers)
108 | 
109 |     return testloader, pseudo_trainloader
110 | 
111 | 
112 | def build_model(args):
113 |     print("=> creating model '{}'".format(args.architecture))
114 |     net = models.__dict__[args.architecture]()
115 |     net = net.to(args.device)
116 | 
117 |     print('#param: {}'.format(sum([p.nelement() for p in net.parameters()])))
118 | 
119 |     if args.device == 'cuda':
120 |         net = torch.nn.DataParallel(
121 |             net, device_ids=range(torch.cuda.device_count()))
122 |         cudnn.benchmark = True
123 | 
124 |     optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
125 |                           weight_decay=0, nesterov=True)
126 | 
127 |     # resume from unsupervised pretrain
128 |     if len(args.resume) > 0:
129 |         print('==> Resuming from {}'.format(args.resume))
130 |         global best_acc, global_step
131 |         checkpoint = torch.load(args.resume)
132 |         net.load_state_dict(checkpoint['net'])
133 |         optimizer.load_state_dict(checkpoint['optimizer'])
134 |         best_acc = checkpoint['best_acc']
135 |         global_step = checkpoint['step'] + 1
136 |     elif len(args.pretrained) > 0:
137 |         # Load checkpoint.
138 |         print('==> Load pretrained model: {}'.format(args.pretrained))
139 |         checkpoint = torch.load(args.pretrained)
140 |         model_dict = net.state_dict()
141 |         # only load shared conv layers, don't load fc
142 |         pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()
143 |                            if k in model_dict
144 |                            and v.size() == model_dict[k].size()}
145 |         assert len(pretrained_dict) > 0
146 |         model_dict.update(pretrained_dict)
147 |         net.load_state_dict(model_dict)
148 | 
149 |     return net, optimizer
150 | 
151 | 
152 | def get_lr_scheduler(optimizer, lr_scheduler, max_iters):
153 |     if lr_scheduler == 'cosine':
154 |         scheduler = CosineAnnealingLR(optimizer, max_iters, eta_min=0.00001)
155 |     elif lr_scheduler == 'cosine-with-restart':
156 |         scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.00001)
157 |     elif lr_scheduler == 'multi-step':
158 |         scheduler = MultiStepLR(
159 |             optimizer, [max_iters * 3 // 7, max_iters * 6 // 7], gamma=0.1)
160 |     else:
161 |         raise ValueError("not supported")
162 | 
163 |     return scheduler
164 | 
165 | 
166 | def inf_generator(trainloader):
167 |     while True:
168 |         for data in trainloader:
169 |             yield data
170 | 
171 | 
172 | # Training
173 | def train(net, optimizer, scheduler, trainloader, testloader, criterion, summary_writer, args):
174 |     train_loss = AverageMeter()
175 |     data_time = AverageMeter()
176 |     batch_time = AverageMeter()
177 |     top1 = AverageMeter()
178 |     top5 = AverageMeter()
179 | 
180 |     global best_acc  # honor a best_acc restored from --resume in build_model
181 |     end = time.time()
182 | 
183 |     global global_step
184 |     for inputs, targets in inf_generator(trainloader):
185 |         if global_step >= args.max_iters:
186 |             break
187 | 
188 |         data_time.update(time.time() - end)
189 | 
190 |         inputs, targets = inputs.to(args.device), targets.to(args.device)
191 | 
192 |         # switch to train mode
193 |         net.train()
194 |         scheduler.step(global_step)
195 |         optimizer.zero_grad()
196 | 
197 |         outputs = net(inputs)
198 |         loss = criterion(outputs, targets)
199 |         prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
200 |         top1.update(prec1[0], inputs.size(0))
201 |         top5.update(prec5[0], inputs.size(0))
202 | 
203 |         loss.backward()
204 |         optimizer.step()
205 | 
206 |         train_loss.update(loss.item(), inputs.size(0))
207 | 
208 |         # measure elapsed time
209 |         batch_time.update(time.time() - end)
210 |         end = time.time()
211 | 
212 |         summary_writer.add_scalar(
213 |             'lr', optimizer.param_groups[0]['lr'], global_step)
214 |         summary_writer.add_scalar('top1', top1.val, global_step)
215 |         summary_writer.add_scalar('top5', top5.val, global_step)
216 |         summary_writer.add_scalar('batch_time', batch_time.val, global_step)
217 |         summary_writer.add_scalar('data_time', data_time.val, global_step)
218 |         summary_writer.add_scalar('train_loss', train_loss.val, global_step)
219 | 
220 |         if global_step % args.print_freq == 0:
221 |             lr = optimizer.param_groups[0]['lr']
222 |             print(f'Train: [{global_step}/{args.max_iters}] '
223 |                   f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
224 |                   f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
225 |                   f'Lr: {lr:.5f} '
226 |                   f'prec1: {top1.val:.3f} ({top1.avg:.3f}) '
227 |                   f'prec5: {top5.val:.3f} ({top5.avg:.3f}) '
228 |                   f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})')
229 | 
230 |         if (global_step + 1) % args.eval_freq == 0 or global_step == args.max_iters - 1:
231 |             acc = validate(testloader, net, criterion,
232 |                            device=args.device, print_freq=args.print_freq)
233 | 
234 |             summary_writer.add_scalar('val_top1', acc, global_step)
235 | 
236 |             if acc > best_acc:
237 |                 best_acc = acc
238 |                 state = {
239 |                     'step': global_step,
240 |                     'best_acc': best_acc,
241 |                     'net': net.state_dict(),
242 |                     'optimizer': optimizer.state_dict(),
243 |                 }
244 |                 os.makedirs(args.model_dir, exist_ok=True)
245 |                 torch.save(state, os.path.join(args.model_dir, 'ckpt.pth.tar'))
246 | 
247 |             print('best accuracy: {:.2f}\n'.format(best_acc))
248 |         global_step += 1
249 | 
250 | 
251 | def main(args):
252 |     # Data
253 |     print('==> Preparing data..')
254 |     testloader, pseudo_trainloader = get_dataloader(args)
255 | 
256 |     print('==> Building model..')
257 |     net, optimizer = build_model(args)
258 | 
259 |     criterion = nn.__dict__[args.criterion]().to(args.device)
260 |     scheduler = get_lr_scheduler(optimizer, args.lr_scheduler, args.max_iters)
261 | 
262 |     if args.eval:
263 |         return validate(testloader, net, criterion,
264 |                         device=args.device, print_freq=args.print_freq)
265 |     # summary writer
266 |     os.makedirs(args.log_dir, exist_ok=True)
267 |     summary_writer = SummaryWriter(args.log_dir)
268 | 
269 |     train(net, optimizer, scheduler, pseudo_trainloader,
270 |           testloader, criterion, summary_writer, args)
271 | 
272 | 
273 | if __name__ == '__main__':
274 |     parser = argparse.ArgumentParser(description='PyTorch ImageNet Training',
275 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
276 |     parser.add_argument('--data-dir', '--dataDir', required=True,
277 |                         type=str, metavar='DIR', help='data dir')
278 |     parser.add_argument('--model-root', default='./checkpoint/imagenet',
279 |                         type=str, metavar='DIR',
280 |                         help='root directory to save checkpoint')
281 |     parser.add_argument('--log-root', default='./tensorboard/imagenet',
282 |                         type=str, metavar='DIR',
283 |                         help='root directory to save tensorboard logs')
284 |     parser.add_argument('--exp-name', default='exp', type=str,
285 |                         help='experiment name, used to determine log_dir and model_dir')
286 |     parser.add_argument('--lr', default=0.01, type=float,
287 |                         metavar='LR', help='learning rate')
288 |     parser.add_argument('--lr-scheduler', default='multi-step', type=str,
289 |                         choices=['multi-step', 'cosine',
290 |                                  'cosine-with-restart'],
291 |                         help='which lr scheduler to use')
292 |     parser.add_argument('--pretrained', default='', type=str,
293 |                         metavar='FILE', help='the pretrained checkpoint to load; only model weights are loaded')
294 |     parser.add_argument('--resume', '-r', default='', type=str,
295 |                         metavar='FILE', help='resume from a checkpoint; optimizer state is restored too')
296 |     parser.add_argument('--eval', action='store_true', help='test only')
297 |     parser.add_argument('--finetune', action='store_true',
298 |                         help='only train the last fc layer')
299 |     parser.add_argument('-j', '--num-workers', default=32, type=int,
300 |                         metavar='N', help='number of workers to load data')
301 |     parser.add_argument('-b', '--batch-size', default=256, type=int,
302 |                         metavar='N', help='batch size')
303 |     parser.add_argument('--max-iters', default=50000, type=int,
304 |                         metavar='N', help='number of iterations')
305 |     parser.add_argument('--num-labeled', default=13000, type=int,
306 |                         metavar='N', help='number of labeled data')
307 |     parser.add_argument('--rng-seed', default=0, type=int,
308 |                         metavar='N', help='random number generator seed')
309 |     parser.add_argument('--gpus', default='0,1,2,3', type=str, metavar='GPUS',
310 |                         help='ids of GPUs to use')
311 |     parser.add_argument('--eval-freq', default=500, type=int,
312 |                         metavar='N', help='eval frequency')
313 |     parser.add_argument('--print-freq', default=10, type=int,
314 |                         metavar='N', help='print frequency')
315 |     parser.add_argument('--criterion', default='CrossEntropyLoss', type=str,
316 |                         choices=['CrossEntropyLoss', 'MultiMarginLoss'], help='criterion to use')
317 |     parser.add_argument('--pseudo-dir', type=str,
318 |                         metavar='PATH', help='pseudo folder to load')
319 |     parser.add_argument('--pseudo-ratio', default=0.1, type=float, metavar='0-1',
320 |                         help='ratio of unlabeled data to use for pseudo labels')
321 |     parser.add_argument('--architecture', '--arch', default='resnet18', type=str,
322 |                         help='which backbone to use')
323 |     args_, rest = parser.parse_known_args()
324 |     print(rest)
325 | 
326 |     os.environ["CUDA_VISIBLE_DEVICES"] = args_.gpus
327 |     args_.device = 'cuda' if torch.cuda.is_available() else 'cpu'
328 |     args_.num_class = 1000
329 |     args_.log_dir = os.path.join(args_.log_root, args_.exp_name)
330 |     args_.model_dir = os.path.join(args_.model_root, args_.exp_name)
331 | 
332 |     torch.manual_seed(args_.rng_seed)
333 |     torch.cuda.manual_seed(args_.rng_seed)
334 |     random.seed(args_.rng_seed)
335 |     torch.set_printoptions(threshold=50, precision=4)
336 | 
337 |     print('-' * 80)
338 |     pprint(vars(args_))
339 | 
340 |     main(args_)
341 | 
342 |     print('-' * 80)
343 |     pprint(vars(args_))
344 | 
--------------------------------------------------------------------------------
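A closing note on build_model above: its pretrained branch uses the standard partial-loading idiom, keeping only checkpoint entries whose name and shape match the current model so that the randomly initialized classifier head survives. The same pattern in isolation (the torchvision model and checkpoint path here are illustrative, not the repo's own files):

import torch
import torchvision.models as models

net = models.resnet18(num_classes=10)  # fc shape differs from a 1000-way checkpoint

checkpoint = torch.load('pretrained.pth.tar', map_location='cpu')  # illustrative path
model_dict = net.state_dict()

# keep only entries whose key exists in the model and whose shape matches;
# mismatched layers (here fc.weight / fc.bias) stay randomly initialized
pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()
                   if k in model_dict and v.size() == model_dict[k].size()}
assert len(pretrained_dict) > 0

model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)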