├── lib
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── wideresnet.py
│   │   ├── resnet_cifar.py
│   │   └── resnet.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar.py
│   │   └── folder.py
│   ├── normalize.py
│   ├── NCECriterion.py
│   ├── LinearAverage.py
│   ├── alias_multinomial.py
│   ├── utils.py
│   └── NCEAverage.py
├── scripts
│   ├── instance_cifar10.sh
│   ├── finetune_cifar10.sh
│   ├── finetune_imagenet.sh
│   └── download_model.sh
├── LICENSE
├── SECURITY.md
├── test.py
├── .gitignore
├── README.md
├── unsupervised
│   ├── cifar.py
│   └── imagenet.py
├── notebooks
│   ├── knn-imagenet.ipynb
│   └── nc-colorization.ipynb
├── cifar-semi.py
└── imagenet-semi.py
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # nothing
2 |
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .resnet_cifar import *
3 | from .wideresnet import *
4 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar import CIFAR10Instance, PseudoCIFAR10
2 | from .folder import ImageFolderInstance, PseudoDatasetFolder
3 |
 4 | __all__ = ('CIFAR10Instance', 'PseudoCIFAR10', 'ImageFolderInstance', 'PseudoDatasetFolder')
5 |
6 |
--------------------------------------------------------------------------------
/scripts/instance_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -x
4 |
5 | export PYTHONPATH=$PYTHONPATH:$(pwd)
6 |
7 | CUDA_VISIBLE_DEVICES=5 python unsupervised/cifar.py --lr-scheduler cosine-with-restart --epochs 1270
8 |
--------------------------------------------------------------------------------
/scripts/finetune_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | num_labeled=250
4 | python cifar-semi.py \
5 | --gpus 0 \
6 | --num-labeled ${num_labeled} \
7 | --pseudo-file checkpoint/pseudos/instance_nc_wrn-28-2/${num_labeled}_T_1.pth.tar \
8 | --resume checkpoint/pretrain_models/ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \
9 | --pseudo-ratio 0.2
10 |
--------------------------------------------------------------------------------
/lib/normalize.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class Normalize(nn.Module):
5 |
6 | def __init__(self, power=2):
7 | super(Normalize, self).__init__()
8 | self.power = power
9 |
10 | def forward(self, x):
11 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
12 | out = x.div(norm)
13 | return out
14 |
--------------------------------------------------------------------------------
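
A minimal usage sketch for `Normalize`, assuming the repo root is on `PYTHONPATH` (as the scripts arrange): with `power=2`, every output row has unit L2 norm.

```python
import torch

from lib.normalize import Normalize

x = torch.randn(4, 128)        # a batch of 4 feature vectors
out = Normalize(power=2)(x)    # divide each row by its L2 norm
print(out.pow(2).sum(1))       # approximately tensor([1., 1., 1., 1.])
```
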
/scripts/finetune_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # num_labeled=13000
4 | # pseudo_ratio=0.1
5 |
6 | # or
7 | num_labeled=26000
8 | pseudo_ratio=0.2
9 |
10 | # or
11 | # num_labeled=51000
12 | # pseudo_ratio=0.5
13 | python imagenet-semi.py \
14 | --arch resnet50 \
15 | --gpus 1,2,6,7 \
16 | --num-labeled ${num_labeled} \
17 | --data-dir /home/liubin/data/imagenet \
18 | --pretrained checkpoint/pretrain_models/lemniscate_resnet50.pth.tar \
19 | --pseudo-dir checkpoint/pseudos_imagenet/instance_imagenet_nc_resnet50/num_labeled_${num_labeled} \
 20 |     --pseudo-ratio ${pseudo_ratio}
21 |
22 |
--------------------------------------------------------------------------------
/lib/NCECriterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | eps = 1e-7
5 |
6 |
7 | class NCECriterion(nn.Module):
8 |
9 | def __init__(self, n_lem):
10 | super(NCECriterion, self).__init__()
11 | self.n_lem = n_lem
12 |
13 | def forward(self, x, targets):
14 | batchSize = x.size(0)
15 | K = x.size(1) - 1
16 | Pnt = 1 / float(self.n_lem)
17 | Pns = 1 / float(self.n_lem)
18 |
19 | # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
20 | Pmt = x.select(1, 0)
21 | Pmt_div = Pmt.add(K * Pnt + eps)
22 | lnPmt = torch.div(Pmt, Pmt_div)
23 |
24 | # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
25 | Pon_div = x.narrow(1, 1, K).add(K * Pns + eps)
26 | Pon = Pon_div.clone().fill_(K * Pns)
27 | lnPon = torch.div(Pon, Pon_div)
28 |
29 | # equation 6 in ref. A
30 | lnPmt.log_()
31 | lnPon.log_()
32 |
33 | lnPmtsum = lnPmt.sum(0)
34 | lnPonsum = lnPon.view(-1, 1).sum(0)
35 |
36 | loss = - (lnPmtsum + lnPonsum) / batchSize
37 |
38 | return loss
39 |
--------------------------------------------------------------------------------
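
A hedged usage sketch for `NCECriterion`; the sizes below are illustrative assumptions. The input is the `(K + 1)`-column output of `NCEAverage`, with the positive sample in column 0 and `K` noise samples after it; the `targets` argument is unused by the criterion.

```python
import torch

from lib.NCECriterion import NCECriterion

n_data, batch, K = 50000, 8, 4096      # training-set size, batch size, noise samples
criterion = NCECriterion(n_data)
probs = torch.rand(batch, K + 1)       # column 0: positive; columns 1..K: noise
loss = criterion(probs, None)          # targets are not used by this criterion
print(loss.item())
```
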
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lib/LinearAverage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch import nn
4 | import math
5 |
6 |
7 | class LinearAverageOp(Function):
8 | @staticmethod
  9 |     def forward(ctx, x, y, memory, params):
10 | T = params[0].item()
11 |
12 | # inner product
13 | out = torch.mm(x.data, memory.t())
14 | out.div_(T) # batchSize * N
15 |
 16 |         ctx.save_for_backward(x, memory, y, params)
17 |
18 | return out
19 |
20 | @staticmethod
 21 |     def backward(ctx, gradOutput):
 22 |         x, memory, y, params = ctx.saved_tensors
23 | T = params[0].item()
24 | momentum = params[1].item()
25 |
26 | # add temperature
27 | gradOutput.data.div_(T)
28 |
29 | # gradient of linear
30 | gradInput = torch.mm(gradOutput.data, memory)
31 | gradInput.resize_as_(x)
32 |
33 | # update the non-parametric data
34 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
35 | weight_pos.mul_(momentum)
36 | weight_pos.add_(torch.mul(x.data, 1 - momentum))
37 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
38 | updated_weight = weight_pos.div(w_norm)
39 | memory.index_copy_(0, y, updated_weight)
40 |
41 | return gradInput, None, None, None
42 |
43 |
44 | class LinearAverage(nn.Module):
45 |
46 | def __init__(self, inputSize, outputSize, T=0.07, momentum=0.5):
47 | super(LinearAverage, self).__init__()
48 | self.nLem = outputSize
49 |
50 | self.register_buffer('params', torch.tensor([T, momentum]))
51 | stdv = 1. / math.sqrt(inputSize / 3)
52 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
53 |
54 | def forward(self, x, y):
55 | out = LinearAverageOp.apply(x, y, self.memory, self.params)
56 | return out
57 |
--------------------------------------------------------------------------------
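
A usage sketch for the non-parametric memory bank; the sizes are assumptions. Given L2-normalized embeddings and their instance indexes, `forward` returns temperature-scaled similarities to all stored instances, and the memory is refreshed with a momentum update during `backward`.

```python
import torch
import torch.nn.functional as F

from lib.LinearAverage import LinearAverage

n_data, low_dim = 50000, 128
lemniscate = LinearAverage(inputSize=low_dim, outputSize=n_data, T=0.07, momentum=0.5)

features = F.normalize(torch.randn(16, low_dim), dim=1)  # unit-norm embeddings
indexes = torch.randint(0, n_data, (16,))                # instance ids of the batch
out = lemniscate(features, indexes)                      # (16, 50000) similarities / T
```
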
/lib/alias_multinomial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AliasMethod(object):
5 | """
6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
7 | """
8 |
9 | def __init__(self, probs):
10 |
11 | if probs.sum() > 1:
12 | probs.div_(probs.sum())
13 | K = len(probs)
14 | self.prob = torch.zeros(K)
15 | self.alias = torch.LongTensor([0] * K)
16 |
17 | # Sort the data into the outcomes with probabilities
18 | # that are larger and smaller than 1/K.
19 | smaller = []
20 | larger = []
21 | for kk, prob in enumerate(probs):
22 | self.prob[kk] = K * prob
23 | if self.prob[kk] < 1.0:
24 | smaller.append(kk)
25 | else:
26 | larger.append(kk)
27 |
28 | # Loop though and create little binary mixtures that
29 | # appropriately allocate the larger outcomes over the
30 | # overall uniform mixture.
31 | while len(smaller) > 0 and len(larger) > 0:
32 | small = smaller.pop()
33 | large = larger.pop()
34 |
35 | self.alias[small] = large
36 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
37 |
38 | if self.prob[large] < 1.0:
39 | smaller.append(large)
40 | else:
41 | larger.append(large)
42 |
43 | for last_one in smaller + larger:
44 | self.prob[last_one] = 1
45 |
46 | def cuda(self):
47 | self.prob = self.prob.cuda()
48 | self.alias = self.alias.cuda()
49 |
50 | def draw(self, N):
51 | """
52 | Draw N samples from multinomial
53 | """
54 | K = self.alias.size(0)
55 |
56 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
57 | prob = self.prob.index_select(0, kk)
58 | alias = self.alias.index_select(0, kk)
59 | # b is whether a random number is greater than q
60 | b = torch.bernoulli(prob)
61 | oq = kk.mul(b.long())
62 | oj = alias.mul((1 - b).long())
63 |
64 | return oq + oj
65 |
--------------------------------------------------------------------------------
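
A quick sanity check for `AliasMethod` (an illustrative sketch): setup is O(K), each draw is O(1), and the empirical frequencies should approach the input distribution.

```python
import torch

from lib.alias_multinomial import AliasMethod

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
sampler = AliasMethod(probs)
samples = sampler.draw(100000)                    # LongTensor of indices in [0, 4)
print(torch.bincount(samples).float() / 100000)   # roughly [0.1, 0.2, 0.3, 0.4]
```
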
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.optim.lr_scheduler import CosineAnnealingLR
5 |
6 |
7 | # noinspection PyAttributeOutsideInit
8 | class AverageMeter(object):
9 | """Computes and stores the average and current value"""
10 |
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 |
27 | class CosineAnnealingLRWithRestart(CosineAnnealingLR):
28 | """Adjust learning rate"""
29 |
30 | def __init__(self, optimizer, eta_min=0, lr_t_0=10, lr_t_mul=2, last_epoch=-1):
31 | self.eta_min = eta_min
32 | self.lr_t_curr = lr_t_0
33 | self.lr_t_mul = lr_t_mul
34 | self.last_reset = 0
 35 |         super(CosineAnnealingLRWithRestart, self).__init__(optimizer, self.lr_t_curr, eta_min, last_epoch)
36 |
37 | def get_lr(self):
38 | curr_epoch = self.last_epoch - self.last_reset
39 | if curr_epoch >= self.lr_t_curr:
40 | self.lr_t_curr *= self.lr_t_mul
41 | self.last_reset = self.last_epoch
42 | rate = 0
43 | else:
44 | rate = curr_epoch * math.pi / self.lr_t_curr
45 | return [self.eta_min + 0.5 * (base_lr - self.eta_min) * (1.0 + math.cos(rate))
46 | for base_lr in self.base_lrs]
47 |
48 |
49 | def accuracy(output, target, topk=(1,)):
50 | """Computes the precision@k for the specified values of k"""
51 | with torch.no_grad():
52 | maxk = max(topk)
53 | batch_size = target.size(0)
54 |
55 | _, pred = output.topk(maxk, 1, True, True)
56 | pred = pred.t()
57 | correct = pred.eq(target.view(1, -1).expand_as(pred))
58 |
59 | res = []
60 | for k in topk:
61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
62 | res.append(correct_k.mul_(100.0 / batch_size))
63 | return res
64 |
65 |
66 | train_labels_ = None
67 |
68 |
69 | def get_train_labels(trainloader, device='cuda'):
70 | global train_labels_
71 | if train_labels_ is None:
72 | print("=> loading all train labels")
73 | train_labels = -1 * torch.ones([len(trainloader.dataset)], dtype=torch.long)
74 | for i, (_, label, index) in enumerate(trainloader):
75 | train_labels[index] = label
76 | if i % 10000 == 0:
77 | print("{}/{}".format(i, len(trainloader)))
 78 |         assert (train_labels != -1).all()
79 | train_labels_ = train_labels.to(device)
80 | return train_labels_
81 |
--------------------------------------------------------------------------------
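
A sketch of `CosineAnnealingLRWithRestart`; the model and optimizer are placeholders. With `lr_t_0=10` and `lr_t_mul=2`, the learning rate anneals over periods of 10, 20, 40, ... epochs and returns to the base rate at each restart.

```python
import torch

from lib.utils import CosineAnnealingLRWithRestart

model = torch.nn.Linear(10, 2)  # stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.0, lr_t_0=10, lr_t_mul=2)

for epoch in range(30):
    # ... one epoch of training here ...
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])
```
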
/scripts/download_model.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | set -e
3 |
4 | base_url="https://frontiers.blob.core.windows.net/metric-transfer"
5 | local_root=checkpoint
6 |
7 | mkdir -p log
8 |
  9 | # You can comment out files you don't want to download.
10 | echo "downloading pretrained models"
11 | dirname=pretrain_models
12 | mkdir -p ${local_root}/${dirname}
13 | for filename in \
14 | ckpt_colorization_wrn-28-2.pth.tar \
15 | ckpt_instance_cifar10_wrn-28-10_89.83.pth.tar \
16 | ckpt_imagenet32x32_instance_wrn-28-2.pth.tar \
17 | ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \
18 | ckpt_imagenet32x32_snca_wrn-28-2.pth.tar \
19 | lemniscate_resnet18.pth.tar \
20 | ckpt_imagenet32x32_softmax_wrn-28-2.pth.tar \
21 | lemniscate_resnet50.pth.tar \
22 | ckpt_instance_cifar10_resnet18_85.69.pth.tar;
23 | do
24 | file=${dirname}/${filename};
25 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber &
26 | done
27 | wait
28 |
29 |
30 | echo "downloading pre-extracted features"
31 | dirname=train_features_labels_cache
32 | mkdir -p ${local_root}/${dirname}
33 | for filename in \
34 | colorization_embedding_128.t7 \
35 | instance_imagenet_val_feature_resnet50.pth.tar \
36 | instance_imagenet_train_feature_resnet50.pth.tar;
37 | do
38 | file=${dirname}/${filename};
39 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber &
40 | done
41 | wait
42 |
43 |
44 | echo "downloading pseudo file for cifar10 dataset"
45 | dirname=pseudos
46 | mkdir -p ${local_root}/${dirname}
47 | for filename in \
48 | colorization_knn_wrn-28-2.tar \
49 | imagenet32x32_snca_nc_wrn-28-2.tar \
50 | instance_nc_wrn-28-2.tar \
51 | colorization_nc_wrn-28-2.tar \
52 | imagenet32x32_softmax_nc_wrn-28-2.tar \
53 | imagenet32x32_instance_nc_wrn-28-2.tar \
54 | instance_knn_wrn-28-2.tar;
55 | do
56 | file=${dirname}/${filename};
57 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/${dirname}_${filename}.txt --no-clobber &
58 | done
59 | wait
60 |
61 | echo "downloading pseudo file for imagenet dataset"
62 | dirname=pseudos_imagenet/instance_imagenet_nc_resnet50
63 | mkdir -p ${local_root}/${dirname}
64 | for filename in \
65 | num_labeled_13000.tar \
66 | num_labeled_26000.tar \
67 | num_labeled_51000.tar;
68 | do
69 | file=${dirname}/${filename};
70 | wget ${base_url}/${file} -O ${local_root}/${file} -o log/pseudos_imagenet_${filename}.txt --no-clobber &
71 | done
72 | wait
73 |
74 | echo "download finished, extracting"
75 | for folder in pseudos pseudos_imagenet/instance_imagenet_nc_resnet50; do
76 | (
77 | cd ${local_root}/${folder};
78 | for i in $(ls *.tar); do
79 | tar xvf $i;
80 | rm $i;
81 | done
82 | )
83 | done
84 |
85 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/lib/NCEAverage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch import nn
4 | from .alias_multinomial import AliasMethod
5 | import math
6 |
7 |
8 | class NCEFunction(Function):
9 | @staticmethod
 10 |     def forward(ctx, x, y, memory, idx, params):
11 | K = int(params[0].item())
12 | T = params[1].item()
13 | Z = params[2].item()
14 |
15 | batchSize = x.size(0)
16 | outputSize = memory.size(0)
17 | inputSize = memory.size(1)
18 |
19 | # sample positives & negatives
20 | idx.select(1, 0).copy_(y.data)
21 |
 22 |         # sample corresponding weights
23 | weight = torch.index_select(memory, 0, idx.view(-1))
24 | weight.resize_(batchSize, K + 1, inputSize)
25 |
26 | # inner product
27 | out = torch.bmm(weight, x.data.resize_(batchSize, inputSize, 1))
 28 |         out.div_(T).exp_()  # batchSize * (K + 1)
29 | x.data.resize_(batchSize, inputSize)
30 |
31 | if Z < 0:
32 | params[2] = out.mean() * outputSize
33 | Z = params[2].item()
34 | print("normalization constant Z is set to {:.1f}".format(Z))
35 |
36 | out.div_(Z).resize_(batchSize, K + 1)
37 |
 38 |         ctx.save_for_backward(x, memory, y, weight, out, params)
39 |
40 | return out
41 |
42 | @staticmethod
 43 |     def backward(ctx, gradOutput):
 44 |         x, memory, y, weight, out, params = ctx.saved_tensors
45 | K = int(params[0].item())
46 | T = params[1].item()
47 | momentum = params[3].item()
48 | batchSize = gradOutput.size(0)
49 |
50 | # gradients d Pm / d linear = exp(linear) / Z
51 | gradOutput.data.mul_(out.data)
52 | # add temperature
53 | gradOutput.data.div_(T)
54 |
55 | gradOutput.data.resize_(batchSize, 1, K + 1)
56 |
57 | # gradient of linear
58 | gradInput = torch.bmm(gradOutput.data, weight)
59 | gradInput.resize_as_(x)
60 |
61 | # update the non-parametric data
62 | weight_pos = weight.select(1, 0).resize_as_(x)
63 | weight_pos.mul_(momentum)
64 | weight_pos.add_(torch.mul(x.data, 1 - momentum))
65 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
66 | updated_weight = weight_pos.div(w_norm)
67 | memory.index_copy_(0, y, updated_weight)
68 |
69 | return gradInput, None, None, None, None
70 |
71 |
72 | class NCEAverage(nn.Module):
73 |
74 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5):
75 | super(NCEAverage, self).__init__()
76 | self.nLem = outputSize
77 | self.unigrams = torch.ones(self.nLem)
78 | self.multinomial = AliasMethod(self.unigrams)
79 | self.multinomial.cuda()
80 | self.K = K
81 |
82 | self.register_buffer('params', torch.tensor([K, T, -1, momentum]))
83 | stdv = 1. / math.sqrt(inputSize / 3)
84 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
85 |
86 | def forward(self, x, y):
87 | batchSize = x.size(0)
88 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
89 | out = NCEFunction.apply(x, y, self.memory, idx, self.params)
90 | return out
91 |
--------------------------------------------------------------------------------
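
A usage sketch pairing `NCEAverage` with `NCECriterion`; the sizes are assumptions, and a GPU is required because the constructor moves the alias sampler to CUDA.

```python
import torch
import torch.nn.functional as F

from lib.NCEAverage import NCEAverage
from lib.NCECriterion import NCECriterion

n_data, low_dim, K = 50000, 128, 4096
lemniscate = NCEAverage(low_dim, n_data, K, T=0.07, momentum=0.5).cuda()
criterion = NCECriterion(n_data).cuda()

features = F.normalize(torch.randn(8, low_dim), dim=1).cuda()  # unit-norm embeddings
indexes = torch.randint(0, n_data, (8,)).cuda()                # instance ids
out = lemniscate(features, indexes)                            # (8, K + 1)
loss = criterion(out, indexes)
```
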
/lib/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import torchvision.datasets as datasets
4 | import torch.utils.data as data
5 | import torch
6 | import numpy as np
7 |
8 |
9 | class CIFAR10Instance(datasets.CIFAR10):
10 | """CIFAR10Instance Dataset.
11 | """
12 |
13 | def __getitem__(self, index):
 14 |         # the train and test splits store `data` and `targets`
 15 |         # in the same attributes, so a single access path
 16 |         # covers both
 17 |         img, target = self.data[index], self.targets[index]
18 |
19 | # doing this so that it is consistent with all other datasets
20 | # to return a PIL Image
21 | img = Image.fromarray(img)
22 |
23 | if self.transform is not None:
24 | img = self.transform(img)
25 |
26 | if self.target_transform is not None:
27 | target = self.target_transform(target)
28 |
29 | return img, target, index
30 |
31 |
32 | class PseudoCIFAR10(datasets.CIFAR10):
 33 |     """PseudoCIFAR10 Dataset.
34 | """
35 |
36 | def __init__(self, labeled_indexes, **kwargs):
37 | super(PseudoCIFAR10, self).__init__(**kwargs)
38 | assert self.train
39 | self.labeled_indexes = labeled_indexes.cpu().numpy().copy()
40 | self.C = 10
41 | self.labels = np.array(self.targets)[self.labeled_indexes]
42 | self.indexes = self.labeled_indexes
43 |
44 | def __len__(self):
45 | return self.indexes.shape[0]
46 |
47 | def set_pseudo(self, pseudo_indexes, pseudo_labels):
48 | assert pseudo_indexes.shape == pseudo_labels.shape
49 |
50 | self.labels = np.concatenate(
51 | [np.array(self.targets)[self.labeled_indexes], pseudo_labels.cpu().numpy().copy()], axis=0)
52 | self.indexes = np.concatenate([self.labeled_indexes, pseudo_indexes.cpu().numpy().copy()], axis=0)
53 |
54 | def __getitem__(self, index):
55 | real_index = self.indexes[index]
56 | img = self.data[real_index]
57 | target = self.labels[index]
58 |
59 | # doing this so that it is consistent with all other datasets
60 | # to return a PIL Image
61 | img = Image.fromarray(img)
62 |
63 | if self.transform is not None:
64 | img = self.transform(img)
65 |
66 | if self.target_transform is not None:
67 | target = self.target_transform(target)
68 |
69 | return img, target
70 |
71 |
72 | if __name__ == '__main__':
73 | import torchvision.transforms as transforms
74 |
75 | _labeled_indexes = torch.arange(10)
76 |
77 | transform_train = transforms.Compose([
78 | transforms.ToTensor(),
79 | ])
80 | ds = PseudoCIFAR10(
81 | labeled_indexes=_labeled_indexes,
82 | root='./data',
83 | transform=transform_train,
84 | download=True)
85 | loader = torch.utils.data.DataLoader(ds, batch_size=5, shuffle=True, num_workers=0)
86 | assert len(loader) == 2
87 | for i, (_img, _target) in enumerate(loader):
88 | print(_img.shape, _target)
89 | break
90 |
91 | # test pseudo
92 | _pseudo_indexes = torch.arange(100, 200)
 93 |     _pseudo_labels = torch.zeros([100], dtype=torch.long)
94 | loader.dataset.set_pseudo(_pseudo_indexes, _pseudo_labels)
95 | assert len(loader) == 22 # (100 + 10) / 5
96 | for i, (_img, _target) in enumerate(loader):
97 | print(_img.shape, _target)
98 | break
99 |
--------------------------------------------------------------------------------
/lib/datasets/folder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import torchvision.datasets as datasets
4 |
5 |
6 | class ImageFolderInstance(datasets.ImageFolder):
  7 |     """Folder dataset that also returns the index of each image.
8 | """
9 | def __getitem__(self, index):
10 | """
11 | Args:
12 | index (int): Index
13 | Returns:
 14 |             tuple: (image, target, index) where target is the class_index of the target class.
15 | """
16 | path, target = self.imgs[index]
17 | img = self.loader(path)
18 | if self.transform is not None:
19 | img = self.transform(img)
20 | if self.target_transform is not None:
21 | target = self.target_transform(target)
22 |
23 | return img, target, index
24 |
25 |
26 | class PseudoDatasetFolder(Dataset):
27 |
28 | def __init__(self, ds, labeled_indexes):
29 | self.ds = ds
30 | self.labeled_indexes = labeled_indexes
31 | self.num_labeled = len(self.labeled_indexes)
32 | # self.labeled_indexes_set = set(labeled_indexes.cpu().numpy())
33 | self.pseudo_indexes = []
34 | self.pseudo_labels = None
35 |
36 | def __len__(self):
37 | return self.num_labeled + len(self.pseudo_indexes)
38 |
39 | def __getitem__(self, index):
40 |
41 | if index < self.num_labeled:
42 | # labeled
43 | real_index = self.labeled_indexes[index]
44 | sample, target = self.ds[real_index]
45 | else:
46 | # pseudo
47 | real_index = self.pseudo_indexes[index - self.num_labeled]
48 | sample, _ = self.ds[real_index]
49 | target = self.pseudo_labels[index - self.num_labeled]
50 | return sample, target
51 |
52 | def set_pseudo(self, pseudo_indexes, pseudo_labels):
53 | assert len(pseudo_indexes) == len(pseudo_labels)
54 | self.pseudo_indexes = pseudo_indexes
55 | if isinstance(pseudo_labels, torch.Tensor):
56 | pseudo_labels = pseudo_labels.cpu().numpy()
57 | self.pseudo_labels = pseudo_labels
58 |
59 |
60 | if __name__ == '__main__':
61 | from torchvision import datasets
62 | import torchvision.transforms as transforms
63 | transform_test = transforms.Compose([
64 | transforms.Resize(256),
65 | transforms.CenterCrop(224),
66 | transforms.ToTensor(),
67 | ])
68 | trainset = datasets.ImageFolder('/home/liubin/data/imagenet/train/', transform=transform_test)
69 | # test list
70 | labeled_indexes_ = [1]
71 | pseudo_indexes_, pseudo_labels_ = [2], [10]
72 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_)
73 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_)
74 | for i, (_, target_) in enumerate(pseudo_trainset):
75 | if i == 0:
76 | assert target_ == 0
77 | else:
78 | assert target_ == 10
79 |
80 | # test np array
81 | import numpy as np
82 | labeled_indexes_ = np.array([1])
83 | pseudo_indexes_, pseudo_labels_ = np.array([2]), np.array([10])
84 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_)
85 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_)
86 | for i, (_, target_) in enumerate(pseudo_trainset):
87 | if i == 0:
88 | assert target_ == 0
89 | else:
90 | assert target_ == 10
91 |
92 | # test torch tensor
93 | n = len(trainset)
94 | num_labeled = n // 2
95 | labeled_indexes_ = torch.arange(num_labeled)
96 | pseudo_indexes_ = torch.arange(num_labeled, n)
97 | pseudo_labels_ = torch.zeros([n - num_labeled], dtype=torch.int64)
98 | pseudo_trainset = PseudoDatasetFolder(trainset, labeled_indexes=labeled_indexes_)
99 | pseudo_trainset.set_pseudo(pseudo_indexes_, pseudo_labels_)
100 | assert pseudo_trainset[0][1] == trainset.samples[0][1]
101 | assert pseudo_trainset[num_labeled][1] == 0
102 |
103 | # test loader
104 | pseudo_trainloder = torch.utils.data.DataLoader(
105 | pseudo_trainset, batch_size=256,
106 | shuffle=True, num_workers=8)
107 |
108 | for data_, target_ in pseudo_trainloder:
109 | print(data_, target_)
110 | break
111 |
--------------------------------------------------------------------------------
/lib/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from lib.normalize import Normalize
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
12 | super(BasicBlock, self).__init__()
13 | self.bn1 = nn.BatchNorm2d(in_planes)
14 | self.relu1 = nn.ReLU(inplace=True)
15 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(out_planes)
18 | self.relu2 = nn.ReLU(inplace=True)
19 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
20 | padding=1, bias=False)
21 | self.droprate = drop_rate
22 | self.equalInOut = (in_planes == out_planes)
23 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
24 | padding=0, bias=False) or None
25 |
26 | def forward(self, x):
27 | if not self.equalInOut:
28 | x = self.relu1(self.bn1(x))
29 | else:
30 | out = self.relu1(self.bn1(x))
31 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
32 | if self.droprate > 0:
33 | out = F.dropout(out, p=self.droprate, training=self.training)
34 | out = self.conv2(out)
35 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
36 |
37 |
38 | class NetworkBlock(nn.Module):
39 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
40 | super(NetworkBlock, self).__init__()
41 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
42 |
43 | @staticmethod
44 | def _make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate):
45 | layers = []
46 | for i in range(int(nb_layers)):
47 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
48 | return nn.Sequential(*layers)
49 |
50 | def forward(self, x):
51 | return self.layer(x)
52 |
53 |
54 | class WideResNet(nn.Module):
55 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, norm=True):
56 | super(WideResNet, self).__init__()
57 | n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
58 | assert ((depth - 4) % 6 == 0)
 59 |         n = (depth - 4) // 6
60 | block = BasicBlock
61 | # 1st conv before any network block
62 | self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1,
63 | padding=1, bias=False)
64 | # 1st block
65 | self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1, dropRate)
66 | # 2nd block
67 | self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2, dropRate)
68 | # 3rd block
69 | self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2, dropRate)
70 | # global average pooling and classifier
71 | self.bn1 = nn.BatchNorm2d(n_channels[3])
72 | self.relu = nn.ReLU(inplace=True)
73 | self.fc = nn.Linear(n_channels[3], num_classes)
74 | self.nChannels = n_channels[3]
75 |
76 | for m in self.modules():
77 | if isinstance(m, nn.Conv2d):
78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
79 | m.weight.data.normal_(0, math.sqrt(2. / n))
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 | elif isinstance(m, nn.Linear):
84 | m.bias.data.zero_()
85 |
86 | self.l2norm = Normalize(2)
87 | self.norm = norm
88 |
89 | def forward(self, x):
90 | out = self.conv1(x)
91 | out = self.block1(out)
92 | out = self.block2(out)
93 | out = self.block3(out)
94 | out = self.relu(self.bn1(out))
95 | out = F.avg_pool2d(out, 8)
96 | out = out.view(-1, self.nChannels)
97 | out = self.fc(out)
98 | if self.norm:
99 | out = self.l2norm(out)
100 | return out
101 |
--------------------------------------------------------------------------------
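
A usage sketch for the WRN-28-2 configuration used in the scripts (depth 28, widening factor 2); the 128-d embedding size is an assumption matching the instance-discrimination pretraining.

```python
import torch

from lib.models.wideresnet import WideResNet

net = WideResNet(depth=28, num_classes=128, widen_factor=2, norm=True)
x = torch.randn(2, 3, 32, 32)   # CIFAR-sized input
emb = net(x)                    # (2, 128) embeddings with unit L2 norm
print(emb.pow(2).sum(1))
```
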
/lib/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from lib.normalize import Normalize
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, in_planes, planes, stride=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion * planes:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
30 | nn.BatchNorm2d(self.expansion * planes)
31 | )
32 |
33 | def forward(self, x):
34 | out = F.relu(self.bn1(self.conv1(x)))
35 | out = self.bn2(self.conv2(out))
36 | out += self.shortcut(x)
37 | out = F.relu(out)
38 | return out
39 |
40 |
41 | class Bottleneck(nn.Module):
42 | expansion = 4
43 |
44 | def __init__(self, in_planes, planes, stride=1):
45 | super(Bottleneck, self).__init__()
46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
47 | self.bn1 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
52 |
53 | self.shortcut = nn.Sequential()
54 | if stride != 1 or in_planes != self.expansion * planes:
55 | self.shortcut = nn.Sequential(
56 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
57 | nn.BatchNorm2d(self.expansion * planes)
58 | )
59 |
60 | def forward(self, x):
61 | out = F.relu(self.bn1(self.conv1(x)))
62 | out = F.relu(self.bn2(self.conv2(out)))
63 | out = self.bn3(self.conv3(out))
64 | out += self.shortcut(x)
65 | out = F.relu(out)
66 | return out
67 |
68 |
69 | class ResNet(nn.Module):
70 | def __init__(self, block, num_blocks, low_dim=128, norm=True):
71 | super(ResNet, self).__init__()
72 | self.in_planes = 64
73 |
74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(64)
76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
80 | self.linear = nn.Linear(512 * block.expansion, low_dim)
81 | self.l2norm = Normalize(2)
82 | self.norm = norm
83 |
84 | def _make_layer(self, block, planes, num_blocks, stride):
85 | strides = [stride] + [1] * (num_blocks - 1)
86 | layers = []
87 | for stride in strides:
88 | layers.append(block(self.in_planes, planes, stride))
89 | self.in_planes = planes * block.expansion
90 | return nn.Sequential(*layers)
91 |
92 | def forward(self, x):
93 | out = F.relu(self.bn1(self.conv1(x)))
94 | out = self.layer1(out)
95 | out = self.layer2(out)
96 | out = self.layer3(out)
97 | out = self.layer4(out)
98 | out = F.avg_pool2d(out, 4)
99 | out = out.view(out.size(0), -1)
100 | out = self.linear(out)
101 | if self.norm:
102 | out = self.l2norm(out)
103 | return out
104 |
105 |
106 | def resnet18_cifar(low_dim=128, norm=True):
107 | return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], low_dim=low_dim, norm=norm)
108 |
109 |
110 | def resnet34_cifar(low_dim=128, norm=True):
111 | return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], low_dim=low_dim, norm=norm)
112 |
113 |
114 | def resnet50_cifar10(low_dim=128, norm=True):
115 | return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], low_dim=low_dim, norm=norm)
116 |
117 |
118 | def resnet101_cifar10(low_dim=128, norm=True):
119 | return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], low_dim=low_dim, norm=norm)
120 |
121 |
122 | def resnet152_cifar10(low_dim=128, norm=True):
123 | return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], low_dim=low_dim, norm=norm)
124 |
125 |
126 | if __name__ == '__main__':
127 | import numpy as np
128 |
129 | inputs = torch.randn(10, 3, 32, 32)
130 | y = resnet18_cifar(low_dim=1024, norm=True)(inputs)
131 | assert y.shape == (10, 1024)
132 | np.testing.assert_array_almost_equal(y.pow(2).sum(1).detach().numpy(), np.ones([10]))
133 |
134 | # test no norm
135 | y = resnet18_cifar(low_dim=1024, norm=False)(inputs)
136 | assert y.shape == (10, 1024)
137 | print(y.pow(2).sum(1).detach().numpy())
138 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 |
5 | from lib.utils import AverageMeter, get_train_labels, accuracy
6 |
7 |
8 | def NN(net, lemniscate, trainloader, testloader, recompute_memory=0):
9 | net.eval()
10 | net_time = AverageMeter()
11 | cls_time = AverageMeter()
12 | correct = 0.
13 | total = 0
 14 |     testsize = len(testloader.dataset)
15 |
16 | train_features = lemniscate.memory.t()
17 | if hasattr(trainloader.dataset, 'imgs'):
18 | train_labels = torch.LongTensor(
19 | [y for (p, y) in trainloader.dataset.imgs]).cuda()
20 | else:
21 | train_labels = get_train_labels(trainloader)
22 | if recompute_memory:
23 | transform_bak = trainloader.dataset.transform
24 | trainloader.dataset.transform = testloader.dataset.transform
25 | temploader = torch.utils.data.DataLoader(
26 | trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
27 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
28 | batch_size = inputs.size(0)
29 | features = net(inputs)
30 | train_features[:, batch_idx * batch_size:batch_idx *
31 | batch_size + batch_size] = features.data.t()
32 | train_labels = get_train_labels(trainloader)
33 | trainloader.dataset.transform = transform_bak
34 |
35 | end = time.time()
36 | with torch.no_grad():
37 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
38 | targets = targets.cuda(non_blocking=True)
39 | batch_size = inputs.size(0)
40 | features = net(inputs)
41 | net_time.update(time.time() - end)
42 | end = time.time()
43 |
44 | dist = torch.mm(features, train_features)
45 |
46 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
47 | candidates = train_labels.view(1, -1).expand(batch_size, -1)
48 | retrieval = torch.gather(candidates, 1, yi)
49 |
50 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
51 |
52 | total += targets.size(0)
53 | correct += retrieval.eq(targets.data).sum().item()
54 |
55 | cls_time.update(time.time() - end)
56 | end = time.time()
57 |
58 | print(f'Test [{total}/{testsize}]\t'
59 | f'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
60 | f'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
61 | f'Top1: {correct * 100. / total:.2f}')
62 |
63 | return correct / total
64 |
65 |
66 | def kNN(net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0):
67 | net.eval()
68 | net_time = AverageMeter()
69 | cls_time = AverageMeter()
70 | total = 0
 71 |     testsize = len(testloader.dataset)
72 |
73 | train_features = lemniscate.memory.t()
74 | if hasattr(trainloader.dataset, 'imgs'):
75 | train_labels = torch.LongTensor(
76 | [y for (p, y) in trainloader.dataset.imgs]).cuda()
77 | else:
78 | train_labels = get_train_labels(trainloader)
79 | C = train_labels.max() + 1
80 |
81 | if recompute_memory:
82 | transform_bak = trainloader.dataset.transform
83 | trainloader.dataset.transform = testloader.dataset.transform
84 | temploader = torch.utils.data.DataLoader(
85 | trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
86 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
87 | bs = inputs.size(0)
88 | features = net(inputs)
89 | train_features[:, batch_idx * bs:batch_idx *
90 | bs + bs] = features.data.t()
91 | train_labels = get_train_labels(trainloader)
92 | trainloader.dataset.transform = transform_bak
93 |
94 | top1 = 0.
95 | top5 = 0.
96 | with torch.no_grad():
97 | retrieval_one_hot = torch.zeros(K, C).cuda()
98 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
99 | end = time.time()
100 | targets = targets.cuda(non_blocking=True)
101 | bs = inputs.size(0)
102 | features = net(inputs)
103 | net_time.update(time.time() - end)
104 | end = time.time()
105 |
106 | dist = torch.mm(features, train_features)
107 |
108 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
109 | candidates = train_labels.view(1, -1).expand(bs, -1)
110 | retrieval = torch.gather(candidates, 1, yi)
111 |
112 | retrieval_one_hot.resize_(bs * K, C).zero_()
113 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
114 | yd_transform = yd.clone().div_(sigma).exp_()
115 | probs = torch.sum(torch.mul(retrieval_one_hot.view(
116 | bs, -1, C), yd_transform.view(bs, -1, 1)), 1)
117 | _, predictions = probs.sort(1, True)
118 |
119 | # Find which predictions match the target
120 | correct = predictions.eq(targets.data.view(-1, 1))
121 | cls_time.update(time.time() - end)
122 |
123 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
124 |             top5 = top5 + correct.narrow(1, 0, 5).sum().item()
125 |
126 | total += targets.size(0)
127 |
128 | if batch_idx % 100 == 0:
129 | print(f'Test [{total}/{testsize}]\t'
130 | f'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
131 | f'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
132 | f'Top1: {top1 * 100. / total:.2f} top5: {top5 * 100. / total:.2f}')
133 |
134 | print(top1 * 100. / total)
135 |
136 | return top1 / total
137 |
138 |
139 | def validate(val_loader, model, criterion, device='cpu', print_freq=100):
140 | batch_time = AverageMeter()
141 | losses = AverageMeter()
142 | top1 = AverageMeter()
143 | top5 = AverageMeter()
144 |
145 | # switch to evaluate mode
146 | model.eval()
147 |
148 | with torch.no_grad():
149 | end = time.time()
150 | for i, (data, target) in enumerate(val_loader):
151 | data, target = data.to(device), target.to(device)
152 |
153 | # compute output
154 | output = model(data)
155 | loss = criterion(output, target)
156 |
157 | # measure accuracy and record loss
158 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
159 | losses.update(loss.item(), data.size(0))
160 | top1.update(prec1[0], data.size(0))
161 | top5.update(prec5[0], data.size(0))
162 |
163 | # measure elapsed time
164 | batch_time.update(time.time() - end)
165 | end = time.time()
166 |
167 | if i % print_freq == 0:
168 | print(f'Test: [{i}/{len(val_loader)}] '
169 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
170 |                       f'Loss {losses.val:.4f} ({losses.avg:.4f}) '
171 | f'Prec@1 {top1.val:.3f} ({top1.avg:.3f}) '
172 | f'Prec@5 {top5.val:.3f} ({top5.avg:.3f})')
173 |
174 | print(f' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}')
175 |
176 | return top1.avg
177 |
--------------------------------------------------------------------------------
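
A self-contained sketch for `validate` on fake data; every name below is a local stand-in rather than part of the repo.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from test import validate

data = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(data, labels), batch_size=16)

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
criterion = torch.nn.CrossEntropyLoss()
top1 = validate(loader, model, criterion, device='cpu')
```
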
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.suo
8 | *.user
9 | *.userosscache
10 | *.sln.docstates
11 |
12 | # User-specific files (MonoDevelop/Xamarin Studio)
13 | *.userprefs
14 |
15 | # Build results
16 | [Dd]ebug/
17 | [Dd]ebugPublic/
18 | [Rr]elease/
19 | [Rr]eleases/
20 | x64/
21 | x86/
22 | bld/
23 | [Bb]in/
24 | [Oo]bj/
25 | [Ll]og/
26 |
27 | # Visual Studio 2015/2017 cache/options directory
28 | .vs/
29 | # Uncomment if you have tasks that create the project's static files in wwwroot
30 | #wwwroot/
31 |
32 | # Visual Studio 2017 auto generated files
33 | Generated\ Files/
34 |
35 | # MSTest test Results
36 | [Tt]est[Rr]esult*/
37 | [Bb]uild[Ll]og.*
38 |
39 | # NUNIT
40 | *.VisualState.xml
41 | TestResult.xml
42 |
43 | # Build Results of an ATL Project
44 | [Dd]ebugPS/
45 | [Rr]eleasePS/
46 | dlldata.c
47 |
48 | # Benchmark Results
49 | BenchmarkDotNet.Artifacts/
50 |
51 | # .NET Core
52 | project.lock.json
53 | project.fragment.lock.json
54 | artifacts/
55 | **/Properties/launchSettings.json
56 |
57 | # StyleCop
58 | StyleCopReport.xml
59 |
60 | # Files built by Visual Studio
61 | *_i.c
62 | *_p.c
63 | *_i.h
64 | *.ilk
65 | *.meta
66 | *.obj
67 | *.iobj
68 | *.pch
69 | *.pdb
70 | *.ipdb
71 | *.pgc
72 | *.pgd
73 | *.rsp
74 | *.sbr
75 | *.tlb
76 | *.tli
77 | *.tlh
78 | *.tmp
79 | *.tmp_proj
80 | *.log
81 | *.vspscc
82 | *.vssscc
83 | .builds
84 | *.pidb
85 | *.svclog
86 | *.scc
87 |
88 | # Chutzpah Test files
89 | _Chutzpah*
90 |
91 | # Visual C++ cache files
92 | ipch/
93 | *.aps
94 | *.ncb
95 | *.opendb
96 | *.opensdf
97 | *.sdf
98 | *.cachefile
99 | *.VC.db
100 | *.VC.VC.opendb
101 |
102 | # Visual Studio profiler
103 | *.psess
104 | *.vsp
105 | *.vspx
106 | *.sap
107 |
108 | # Visual Studio Trace Files
109 | *.e2e
110 |
111 | # TFS 2012 Local Workspace
112 | $tf/
113 |
114 | # Guidance Automation Toolkit
115 | *.gpState
116 |
117 | # ReSharper is a .NET coding add-in
118 | _ReSharper*/
119 | *.[Rr]e[Ss]harper
120 | *.DotSettings.user
121 |
122 | # JustCode is a .NET coding add-in
123 | .JustCode
124 |
125 | # TeamCity is a build add-in
126 | _TeamCity*
127 |
128 | # DotCover is a Code Coverage Tool
129 | *.dotCover
130 |
131 | # AxoCover is a Code Coverage Tool
132 | .axoCover/*
133 | !.axoCover/settings.json
134 |
135 | # Visual Studio code coverage results
136 | *.coverage
137 | *.coveragexml
138 |
139 | # NCrunch
140 | _NCrunch_*
141 | .*crunch*.local.xml
142 | nCrunchTemp_*
143 |
144 | # MightyMoose
145 | *.mm.*
146 | AutoTest.Net/
147 |
148 | # Web workbench (sass)
149 | .sass-cache/
150 |
151 | # Installshield output folder
152 | [Ee]xpress/
153 |
154 | # DocProject is a documentation generator add-in
155 | DocProject/buildhelp/
156 | DocProject/Help/*.HxT
157 | DocProject/Help/*.HxC
158 | DocProject/Help/*.hhc
159 | DocProject/Help/*.hhk
160 | DocProject/Help/*.hhp
161 | DocProject/Help/Html2
162 | DocProject/Help/html
163 |
164 | # Click-Once directory
165 | publish/
166 |
167 | # Publish Web Output
168 | *.[Pp]ublish.xml
169 | *.azurePubxml
170 | # Note: Comment the next line if you want to checkin your web deploy settings,
171 | # but database connection strings (with potential passwords) will be unencrypted
172 | *.pubxml
173 | *.publishproj
174 |
175 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
176 | # checkin your Azure Web App publish settings, but sensitive information contained
177 | # in these scripts will be unencrypted
178 | PublishScripts/
179 |
180 | # NuGet Packages
181 | *.nupkg
182 | # The packages folder can be ignored because of Package Restore
183 | **/[Pp]ackages/*
184 | # except build/, which is used as an MSBuild target.
185 | !**/[Pp]ackages/build/
186 | # Uncomment if necessary however generally it will be regenerated when needed
187 | #!**/[Pp]ackages/repositories.config
188 | # NuGet v3's project.json files produces more ignorable files
189 | *.nuget.props
190 | *.nuget.targets
191 |
192 | # Microsoft Azure Build Output
193 | csx/
194 | *.build.csdef
195 |
196 | # Microsoft Azure Emulator
197 | ecf/
198 | rcf/
199 |
200 | # Windows Store app package directories and files
201 | AppPackages/
202 | BundleArtifacts/
203 | Package.StoreAssociation.xml
204 | _pkginfo.txt
205 | *.appx
206 |
207 | # Visual Studio cache files
208 | # files ending in .cache can be ignored
209 | *.[Cc]ache
210 | # but keep track of directories ending in .cache
211 | !*.[Cc]ache/
212 |
213 | # Others
214 | ClientBin/
215 | ~$*
216 | *~
217 | *.dbmdl
218 | *.dbproj.schemaview
219 | *.jfm
220 | *.pfx
221 | *.publishsettings
222 | orleans.codegen.cs
223 |
224 | # Including strong name files can present a security risk
225 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
226 | #*.snk
227 |
228 | # Since there are multiple workflows, uncomment next line to ignore bower_components
229 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
230 | #bower_components/
231 |
232 | # RIA/Silverlight projects
233 | Generated_Code/
234 |
235 | # Backup & report files from converting an old project file
236 | # to a newer Visual Studio version. Backup files are not needed,
237 | # because we have git ;-)
238 | _UpgradeReport_Files/
239 | Backup*/
240 | UpgradeLog*.XML
241 | UpgradeLog*.htm
242 | ServiceFabricBackup/
243 | *.rptproj.bak
244 |
245 | # SQL Server files
246 | *.mdf
247 | *.ldf
248 | *.ndf
249 |
250 | # Business Intelligence projects
251 | *.rdl.data
252 | *.bim.layout
253 | *.bim_*.settings
254 | *.rptproj.rsuser
255 |
256 | # Microsoft Fakes
257 | FakesAssemblies/
258 |
259 | # GhostDoc plugin setting file
260 | *.GhostDoc.xml
261 |
262 | # Node.js Tools for Visual Studio
263 | .ntvs_analysis.dat
264 | node_modules/
265 |
266 | # Visual Studio 6 build log
267 | *.plg
268 |
269 | # Visual Studio 6 workspace options file
270 | *.opt
271 |
272 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
273 | *.vbw
274 |
275 | # Visual Studio LightSwitch build output
276 | **/*.HTMLClient/GeneratedArtifacts
277 | **/*.DesktopClient/GeneratedArtifacts
278 | **/*.DesktopClient/ModelManifest.xml
279 | **/*.Server/GeneratedArtifacts
280 | **/*.Server/ModelManifest.xml
281 | _Pvt_Extensions
282 |
283 | # Paket dependency manager
284 | .paket/paket.exe
285 | paket-files/
286 |
287 | # FAKE - F# Make
288 | .fake/
289 |
290 | # JetBrains Rider
291 | .idea/
292 | *.sln.iml
293 |
294 | # CodeRush
295 | .cr/
296 |
297 | # Python Tools for Visual Studio (PTVS)
298 | __pycache__/
299 | *.pyc
300 |
301 | # Cake - Uncomment if you are using it
302 | # tools/**
303 | # !tools/packages.config
304 |
305 | # Tabs Studio
306 | *.tss
307 |
308 | # Telerik's JustMock configuration file
309 | *.jmconfig
310 |
311 | # BizTalk build output
312 | *.btp.cs
313 | *.btm.cs
314 | *.odx.cs
315 | *.xsd.cs
316 |
317 | # OpenCover UI analysis results
318 | OpenCover/
319 |
320 | # Azure Stream Analytics local run output
321 | ASALocalRun/
322 |
323 | # MSBuild Binary and Structured Log
324 | *.binlog
325 |
326 | # NVidia Nsight GPU debugger configuration file
327 | *.nvuser
328 |
329 | # MFractors (Xamarin productivity tool) working folder
330 | .mfractor/
331 | .vscode/settings.json
332 | *.swp
333 | checkpoint
334 | data
335 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Metric Transfer for Label Propagation with Limited Annotated Data
2 |
  3 | This repo contains the PyTorch implementation of the semi-supervised learning paper [(arXiv)](https://arxiv.org/abs/1812.08781).
4 |
5 | ## Requirements
6 |
  7 | * Python 3: Anaconda is recommended because it already bundles many common packages.
8 | * `pytorch>=1.0`: Refer to https://pytorch.org/get-started/locally/
9 | * other packages: `pip install tensorboardX tensorboard easydict scikit-image`
10 |
11 | ## Highlight
12 |
13 | - We formulate semi-supervised learning from a completely different metric transfer perspective.
 14 | - It enjoys the benefits of recent advances in self-supervised learning.
15 | - We hope to draw more attention to unsupervised pretraining for other tasks.
16 |
17 | ## Main results
18 |
 19 | Test accuracy of our method and state-of-the-art methods on the CIFAR10 dataset with different numbers of labeled examples.
20 |
21 | | Method | 50 | 100 | 250 | 500 | 1000 | 2000 | 4000 | 8000 |
22 | | :----------------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
23 | | PI-model | 27.36 | 37.20 | 47.07 | 56.30 | 63.70 | 76.50 | 84.17 | 87.30 |
24 | | Mean-Teacher | 29.66 | 36.60 | 45.49 | 57.20 | 65.00 | 79.00 | 84.38 | 87.50 |
25 | | VAT | 23.00 | 35.58 | 47.61 | 62.90 | 72.80 | **84.00** | **86.79** | **88.10** |
26 | | Pseudo-Label | 21.00 | 34.00 | 45.83 | 60.30 | 68.20 | 78.00 | 84.79 | 86.20 |
27 | | **Ours** | **56.34** | **63.53** | **71.26** | **74.77** | **79.38** | 82.34 | 84.52 | 87.48 |
28 |
29 |
30 | ## Quick start
31 |
32 | * Clone this repo: `git clone git@github.com:microsoft/metric-transfer.pytorch.git && cd metric-transfer.pytorch`
33 |
34 | * Install pytorch and other packages listed in requirements
35 |
36 | * Download pretrained models and precomputed pseudo labels: `bash scripts/download_model.sh` . Make sure the `checkpoint` folder looks like this:
37 |
38 | ```
39 | checkpoint
40 | |-- pretrain_models
41 | | |-- ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar
42 | | |-- ... other files
43 | | `-- lemniscate_resnet50.pth.tar
44 | |-- pseudos
45 | | |-- instance_nc_wrn-28-2
46 | | | |-- 50.pth.tar
47 | | | |-- ... other files
48 | | | `-- 8000.pth.tar
49 | | `-- ... other folders
50 | `-- pseudos_imagenet
51 | `-- instance_imagenet_nc_resnet50
52 | |-- num_labeled_13000
53 | | |-- 10_0.pth.tar
54 | | |-- ... other files
55 | | `-- 10_9.pth.tar
56 | `-- ... other folders
57 | ```
58 |
 59 | * Supervised finetuning on the CIFAR10 or ImageNet dataset. The CIFAR dataset is downloaded automatically. For ImageNet, refer to [this guide](https://github.com/pytorch/examples/tree/master/imagenet) for data preparation.
60 |
61 | ```bash
62 | # Finetune on cifar
63 | python cifar-semi.py \
64 | --gpus 0 \
65 | --num-labeled 250 \
66 | --pseudo-file checkpoint/pseudos/instance_nc_wrn-28-2/250.pth.tar \
67 | --resume checkpoint/pretrain_models/ckpt_instance_cifar10_wrn-28-2_82.12.pth.tar \
68 | --pseudo-ratio 0.2
69 |
70 | # For imagenet
71 | n_labeled=13000 # 1% labeled data
72 | pseudo_ratio=0.1 # use top 10% pseudo label
73 | data_dir=/path/to/imagenet/dir
74 |
75 | python imagenet-semi.py \
76 | --arch resnet50 \
77 | --gpus 0,1,2,3 \
78 | --num-labeled ${n_labeled} \
79 | --data-dir ${data_dir} \
80 | --pretrained checkpoint/pretrain_models/lemniscate_resnet50.pth.tar \
81 | --pseudo-dir checkpoint/pseudos_imagenet/instance_imagenet_nc_resnet50/num_labeled_${n_labeled} \
 82 |   --pseudo-ratio ${pseudo_ratio}
83 | ```
84 |
85 | ## Usage
86 |
 87 | The proposed method consists of three main steps: metric pretraining, label propagation, and supervised finetuning.
88 |
89 | ### Metric pretraining
90 |
 91 | The metric pretraining can be unsupervised or supervised, on the same or a different dataset.
92 |
 93 | We provide code for [instance discrimination](https://arxiv.org/abs/1805.01978), borrowed from its [original PyTorch release](https://github.com/zhirongw/lemniscate.pytorch). Run the following command from the root directory of the repo to train instance discrimination on the CIFAR10 dataset:
94 |
95 | ```bash
96 | export PYTHONPATH=$PYTHONPATH:$(pwd)
97 | CUDA_VISIBLE_DEVICES=0 python unsupervised/cifar.py \
98 | --lr-scheduler cosine-with-restart \
99 | --epochs 1270
100 | ```
101 |
102 | For other metrics or datasets, such as colorization on CIFAR10 or instance discrimination on ImageNet, refer to the officially released code: [colorization](https://github.com/richzhang/colorization), [instance discrimination](https://github.com/zhirongw/lemniscate.pytorch). We also provide the pretrained weights; refer to `scripts/download_model.sh` for more details.
103 |
104 | ### Label propagation
105 |
106 | We can then use the trained metric to propagate labels from the few labeled examples to a vast collection of unannotated images.
107 |
108 | We consider two propagation algorithms: k-nearest neighbors (**knn**) and spectral clustering (also called normalized cut, **nc**). The implementations are Jupyter notebooks in the `notebooks` folder. Simply run a notebook to load the pretrained metric weights and propagate labels to obtain the pseudo labels.
109 |
110 | We also provide precomputed pseudo labels for the CIFAR10 and ImageNet datasets. Refer to `scripts/download_model.sh` for more details.
111 |
112 | ### Supervised finetune
113 |
114 | With the estimated pseudo labels on the unlabeled data, we can train a classifier with more data. For simplicity, we omit the confidence-weighted supervised training in the current version. Instead, we only use the most confident portion of the pseudo labels for training, controlled by `--pseudo-ratio`; see the sketch below.
115 |
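For instance, the CIFAR10 loader in `cifar-semi.py` keeps only the head of each pseudo file, which is assumed to be sorted by descending confidence. A condensed sketch of that selection (the path and dataset size are for illustration):

```python
import torch

pseudo_dict = torch.load('checkpoint/pseudos/instance_nc_wrn-28-2/250.pth.tar')
num_unlabeled = 50000 - len(pseudo_dict['labeled_indexes'])  # CIFAR10 train size
keep = int(num_unlabeled * 0.2)  # --pseudo-ratio 0.2 keeps the top 20%
pseudo_indexes = pseudo_dict['pseudo_indexes'][:keep]
pseudo_labels = pseudo_dict['pseudo_labels'][:keep]
```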
116 | Refer to the quickstart section above for the exact commands.
117 |
118 | ## Contributing
119 |
120 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
121 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
122 | the rights to use your contribution. For details, visit https://cla.microsoft.com.
123 |
124 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
125 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
126 | provided by the bot. You will only need to do this once across all repos using our CLA.
127 |
128 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
129 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
130 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
131 |
132 | ## Citation
133 |
134 | If you find this paper useful in your research, please consider citing:
135 |
136 | ```latex
137 | @article{liu2018deep,
138 | title={Deep Metric Transfer for Label Propagation with Limited Annotated Data},
139 | author={Liu, Bin and Wu, Zhirong and Hu, Han and Lin, Stephen},
140 | journal={arXiv preprint arXiv:1812.08781},
141 | year={2018}
142 | }
143 | ```
144 |
145 | ## Contact
146 |
147 | For any questions, please feel free to create a new issue or reach out to:
148 | ```
149 | Bin Liu: liubinthss@gmail.com
150 | ```
151 |
--------------------------------------------------------------------------------
/lib/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from lib.normalize import Normalize
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None):
28 | super(BasicBlock, self).__init__()
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = nn.BatchNorm2d(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = nn.BatchNorm2d(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | residual = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | residual = self.downsample(x)
49 |
50 | out += residual
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 |
59 | def __init__(self, inplanes, planes, stride=1, downsample=None):
60 | super(Bottleneck, self).__init__()
61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62 | self.bn1 = nn.BatchNorm2d(planes)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64 | padding=1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(planes)
66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
67 | self.bn3 = nn.BatchNorm2d(planes * 4)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x):
73 | residual = x
74 |
75 | out = self.conv1(x)
76 | out = self.bn1(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv2(out)
80 | out = self.bn2(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class ResNet(nn.Module):
96 |
97 | def __init__(self, block, layers, low_dim=128):
98 | self.inplanes = 64
99 | super(ResNet, self).__init__()
100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
101 | bias=False)
102 | self.bn1 = nn.BatchNorm2d(64)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | self.layer1 = self._make_layer(block, 64, layers[0])
106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
109 | self.avgpool = nn.AvgPool2d(7, stride=1)
110 | self.fc = nn.Linear(512 * block.expansion, low_dim)
111 | self.l2norm = Normalize(2)
112 |
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | elif isinstance(m, nn.BatchNorm2d):
118 | m.weight.data.fill_(1)
119 | m.bias.data.zero_()
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1):
122 | downsample = None
123 | if stride != 1 or self.inplanes != planes * block.expansion:
124 | downsample = nn.Sequential(
125 | nn.Conv2d(self.inplanes, planes * block.expansion,
126 | kernel_size=1, stride=stride, bias=False),
127 | nn.BatchNorm2d(planes * block.expansion),
128 | )
129 |
130 | layers = [block(self.inplanes, planes, stride, downsample)]
131 | self.inplanes = planes * block.expansion
132 | for i in range(1, blocks):
133 | layers.append(block(self.inplanes, planes))
134 |
135 | return nn.Sequential(*layers)
136 |
137 | def forward(self, x):
138 | x = self.conv1(x)
139 | x = self.bn1(x)
140 | x = self.relu(x)
141 | x = self.maxpool(x)
142 |
143 | x = self.layer1(x)
144 | x = self.layer2(x)
145 | x = self.layer3(x)
146 | x = self.layer4(x)
147 |
148 | x = self.avgpool(x)
149 | x = x.view(x.size(0), -1)
150 | x = self.fc(x)
151 | x = self.l2norm(x)
152 |
153 | return x
154 |
155 |
156 | def resnet18(pretrained=False, **kwargs):
157 | """Constructs a ResNet-18 model.
158 |
159 | Args:
160 | pretrained (bool): If True, returns a model pre-trained on ImageNet
161 | """
162 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
163 | if pretrained:
164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
165 | return model
166 |
167 |
168 | def resnet34(pretrained=False, **kwargs):
169 | """Constructs a ResNet-34 model.
170 |
171 | Args:
172 | pretrained (bool): If True, returns a model pre-trained on ImageNet
173 | """
174 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
175 | if pretrained:
176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
177 | return model
178 |
179 |
180 | def resnet50(pretrained=False, **kwargs):
181 | """Constructs a ResNet-50 model.
182 |
183 | Args:
184 | pretrained (bool): If True, returns a model pre-trained on ImageNet
185 | """
186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
187 | if pretrained:
188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
189 | return model
190 |
191 |
192 | def resnet101(pretrained=False, **kwargs):
193 | """Constructs a ResNet-101 model.
194 |
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | """
198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
199 | if pretrained:
200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
201 | return model
202 |
203 |
204 | def resnet152(pretrained=False, **kwargs):
205 | """Constructs a ResNet-152 model.
206 |
207 | Args:
208 | pretrained (bool): If True, returns a model pre-trained on ImageNet
209 | """
210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
211 | if pretrained:
212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
213 | return model
214 |
--------------------------------------------------------------------------------
/unsupervised/cifar.py:
--------------------------------------------------------------------------------
1 | """Train CIFAR10 with PyTorch."""
2 | import argparse
3 | import os
4 | import sys
5 | import time
6 | from pprint import pprint
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.optim.lr_scheduler as lr_scheduler
13 | import torchvision.transforms as transforms
14 |
15 | from lib import datasets, models
16 | from lib.LinearAverage import LinearAverage
17 | from lib.NCEAverage import NCEAverage
18 | from lib.NCECriterion import NCECriterion
19 | from lib.utils import AverageMeter, CosineAnnealingLRWithRestart
20 | from test import kNN
21 |
22 |
23 | # Training
24 | def train(net, optimizer, trainloader, criterion, lemniscate, epoch):
25 | print('\nEpoch: {}, lr {}'.format(epoch, optimizer.param_groups[0]['lr']))
26 | train_loss = AverageMeter()
27 | data_time = AverageMeter()
28 | batch_time = AverageMeter()
29 |
30 | # switch to train mode
31 | net.train()
32 |
33 | end = time.time()
34 | for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
35 | data_time.update(time.time() - end)
36 | inputs, targets, indexes = inputs.to(args.device), targets.to(
37 | args.device), indexes.to(args.device)
38 | optimizer.zero_grad()
39 |
40 | features = net(inputs)
41 | outputs = lemniscate(features, indexes)
42 | loss = criterion(outputs, indexes)
43 |
44 | loss.backward()
45 | optimizer.step()
46 |
47 | train_loss.update(loss.item(), inputs.size(0))
48 |
49 | # measure elapsed time
50 | batch_time.update(time.time() - end)
51 | end = time.time()
52 |
53 | if batch_idx % 100 == 0:
54 |             print(f'Epoch: [{epoch}/{args.epochs}][{batch_idx}/{len(trainloader)}] '
55 | f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
56 | f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
57 | f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})')
58 |
59 |
60 | def get_data_loader():
61 | normalize = transforms.Normalize(
62 | (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
63 | if args.transform_crop == 'RandomResizedCrop':
64 | crop = transforms.RandomResizedCrop(
65 | size=32, scale=(args.transform_scale, 1.))
66 | else:
67 | crop = transforms.Compose([
68 | transforms.Pad(4, padding_mode='reflect'),
69 | transforms.RandomCrop(32)
70 | ])
71 | transform_train = transforms.Compose([
72 | crop,
73 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
74 | transforms.RandomGrayscale(p=0.2),
75 | transforms.RandomHorizontalFlip(),
76 | transforms.ToTensor(),
77 | normalize,
78 | ])
79 | print('-' * 80)
80 | print('transform_train = ', transform_train)
81 | print('-' * 80)
82 |
83 | transform_test = transforms.Compose([
84 | transforms.ToTensor(),
85 | normalize,
86 | ])
87 |
88 | trainset = datasets.CIFAR10Instance(
89 | root=args.data_dir, train=True, download=True, transform=transform_train)
90 | trainloader = torch.utils.data.DataLoader(
91 | trainset, batch_size=128, shuffle=True, num_workers=2)
92 |
93 | testset = datasets.CIFAR10Instance(
94 | root=args.data_dir, train=False, download=True, transform=transform_test)
95 | testloader = torch.utils.data.DataLoader(
96 | testset, batch_size=100, shuffle=False, num_workers=2)
97 |
98 |     ndata = len(trainset)
99 |
100 | return trainloader, testloader, ndata
101 |
102 |
103 | def build_model():
104 | best_acc = 0 # best test accuracy
105 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
106 |
107 | if args.architecture == 'resnet18':
108 | net = models.__dict__['resnet18_cifar'](low_dim=args.low_dim)
109 | elif args.architecture == 'wrn-28-2':
110 | net = models.WideResNet(
111 | depth=28, num_classes=args.low_dim, widen_factor=2, dropRate=0).to(args.device)
112 | elif args.architecture == 'wrn-28-10':
113 | net = models.WideResNet(
114 | depth=28, num_classes=args.low_dim, widen_factor=10, dropRate=0).to(args.device)
115 |
116 | # define leminiscate
117 | if args.nce_k > 0:
118 | lemniscate = NCEAverage(args.low_dim, args.ndata,
119 | args.nce_k, args.nce_t, args.nce_m)
120 | else:
121 | lemniscate = LinearAverage(
122 | args.low_dim, args.ndata, args.nce_t, args.nce_m)
123 |
124 | if args.device == 'cuda':
125 | net = torch.nn.DataParallel(
126 | net, device_ids=range(torch.cuda.device_count()))
127 | cudnn.benchmark = True
128 |
129 | optimizer = optim.SGD(
130 | net.parameters(), lr=args.lr, momentum=0.9,
131 | weight_decay=args.weight_decay, nesterov=True)
132 | # Model
133 | if args.test_only or len(args.resume) > 0:
134 | # Load checkpoint.
135 | print('==> Resuming from checkpoint..')
136 | checkpoint = torch.load(args.resume)
137 | net.load_state_dict(checkpoint['net'])
138 | optimizer.load_state_dict(checkpoint['optimizer'])
139 | lemniscate = checkpoint['lemniscate']
140 | best_acc = checkpoint['acc']
141 | start_epoch = checkpoint['epoch'] + 1
142 |
143 | if args.lr_scheduler == 'multi-step':
144 | if args.epochs == 200:
145 | steps = [60, 120, 160]
146 | elif args.epochs == 600:
147 | steps = [180, 360, 480, 560]
148 | else:
149 | raise RuntimeError(
150 | f"need to config steps for epoch = {args.epochs} first.")
151 | scheduler = lr_scheduler.MultiStepLR(
152 | optimizer, steps, gamma=0.2, last_epoch=start_epoch - 1)
153 | elif args.lr_scheduler == 'cosine':
154 | scheduler = lr_scheduler.CosineAnnealingLR(
155 | optimizer, args.epochs, eta_min=0.00001, last_epoch=start_epoch - 1)
156 | elif args.lr_scheduler == 'cosine-with-restart':
157 | scheduler = CosineAnnealingLRWithRestart(
158 | optimizer, eta_min=0.00001, last_epoch=start_epoch - 1)
159 | else:
160 | raise ValueError("not supported")
161 |
162 | # define loss function
163 | if hasattr(lemniscate, 'K'):
164 | criterion = NCECriterion(args.ndata)
165 | else:
166 | criterion = nn.CrossEntropyLoss()
167 |
168 | net.to(args.device)
169 | lemniscate.to(args.device)
170 | criterion.to(args.device)
171 |
172 | return net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch
173 |
174 |
175 | def main():
176 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
177 |
178 | # Data
179 | print('==> Preparing data..')
180 | trainloader, testloader, args.ndata = get_data_loader()
181 |
182 | print('==> Building model..')
183 | net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch = build_model()
184 |
185 | if args.test_only:
186 | kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 1)
187 | sys.exit(0)
188 |
189 | for epoch in range(start_epoch, args.epochs):
190 | scheduler.step()
191 | train(net, optimizer, trainloader, criterion, lemniscate, epoch)
192 | acc = kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 0)
193 |
194 | if acc > best_acc:
195 | print('Saving..')
196 | state = {
197 | 'net': net.state_dict(),
198 | 'lemniscate': lemniscate,
199 | 'acc': acc,
200 | 'epoch': epoch,
201 | 'optimizer': optimizer.state_dict(),
202 | }
203 | os.makedirs(args.model_dir, exist_ok=True)
204 | torch.save(state, os.path.join(
205 | args.model_dir, 'ckpt.cifar.pth.tar'))
206 | best_acc = acc
207 |
208 | print('best accuracy: {:.2f}'.format(best_acc * 100))
209 |
210 | acc = kNN(net, lemniscate, trainloader, testloader, 200, args.nce_t, 1)
211 | print('last accuracy: {:.2f}'.format(acc * 100))
212 |
213 |
214 | if __name__ == '__main__':
215 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
216 | parser.add_argument('--data-dir', '--dataDir',
217 | default='./data', type=str, metavar='DIR')
218 | parser.add_argument('--model-dir', '--modelDir', default='./checkpoint/instance_cifar10', type=str,
219 | metavar='DIR', help='directory to save checkpoint')
220 | parser.add_argument('--log-dir', '--logDir', default='./tensorboard/instance_cifar10', type=str,
221 | metavar='DIR', help='directory to save tensorboard logs')
222 | parser.add_argument('--lr', default=0.03, type=float, help='learning rate')
223 | parser.add_argument('--lr-scheduler', default='cosine', type=str,
224 | choices=['multi-step', 'cosine',
225 | 'cosine-with-restart'],
226 | help='which lr scheduler to use')
227 | parser.add_argument('--resume', '-r', default='',
228 | type=str, help='resume from checkpoint')
229 | parser.add_argument('--test-only', action='store_true', help='test only')
230 | parser.add_argument('--low-dim', default=128, type=int,
231 | metavar='D', help='feature dimension')
232 | parser.add_argument('--nce-k', default=0, type=int,
233 | metavar='K', help='number of negative samples for NCE')
234 | parser.add_argument('--nce-t', default=0.1, type=float,
235 | metavar='T', help='temperature parameter for softmax')
236 | parser.add_argument('--nce-m', default=0.5, type=float,
237 | metavar='M', help='momentum for non-parametric updates')
238 | parser.add_argument('--epochs', default=600, type=int,
239 | metavar='N', help='number of epochs')
240 | parser.add_argument('--architecture', '--arch', default='wrn-28-2', type=str,
241 | choices=['resnet18', 'wrn-28-2', 'wrn-28-10'],
242 | help='which backbone to use')
243 | parser.add_argument('--transform-scale', default=0.2, type=float)
244 | parser.add_argument('--transform-crop', type=str, default='RandomResizedCrop',
245 | choices=['RandomResizedCrop', 'PadCrop'])
246 | parser.add_argument('--weight-decay', '--wd', type=float, default=5e-4)
247 | args = parser.parse_args()
248 |
249 | pprint(vars(args))
250 |
251 | main()
252 |
253 | pprint(vars(args))
254 |
--------------------------------------------------------------------------------
/unsupervised/imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 |
14 | from lib import models, datasets
15 |
16 | from lib.NCEAverage import NCEAverage
17 | from lib.LinearAverage import LinearAverage
18 | from lib.NCECriterion import NCECriterion
19 | from lib.utils import AverageMeter
20 | from test import NN, kNN
21 |
22 | model_names = sorted(name for name in models.__dict__
23 | if name.islower() and not name.startswith("__")
24 | and callable(models.__dict__[name]))
25 |
26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
27 | parser.add_argument('--data-dir', metavar='DIR',
28 | help='path to dataset', required=True)
29 | parser.add_argument('--model-dir', metavar='DIR',
30 | default='./checkpoint/instance_imagenet', help='path to save model')
31 | parser.add_argument('--log-dir', metavar='DIR',
32 | default='./tensorboard/instance_imagenet', help='path to save log')
33 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 | ' (default: resnet18)')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
41 | help='number of total epochs to run')
42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
43 | help='manual epoch number (useful on restarts)')
44 | parser.add_argument('-b', '--batch-size', default=256, type=int,
45 | metavar='N', help='mini-batch size (default: 256)')
46 | parser.add_argument('-vb', '--val-batch-size', default=128, type=int,
47 | metavar='N', help='validation mini-batch size (default: 128)')
48 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
49 | metavar='LR', help='initial learning rate')
50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
51 | help='momentum')
52 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
53 | metavar='W', help='weight decay (default: 1e-4)')
54 | parser.add_argument('--print-freq', '-p', default=10, type=int,
55 | metavar='N', help='print frequency (default: 10)')
56 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
57 | help='path to latest checkpoint (default: none)')
58 | parser.add_argument('--auto-resume', action='store_true', help='auto resume')
59 | parser.add_argument('--test-only', action='store_true', help='test only')
60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
61 | help='evaluate model on validation set')
62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
63 | help='use pre-trained model')
64 | parser.add_argument('--low-dim', default=128, type=int,
65 | metavar='D', help='feature dimension')
66 | parser.add_argument('--nce-k', default=4096, type=int,
67 | metavar='K', help='number of negative samples for NCE')
68 | parser.add_argument('--nce-t', default=0.07, type=float,
69 | metavar='T', help='temperature parameter for softmax')
70 | parser.add_argument('--nce-m', default=0.5, type=float,
71 | help='momentum for non-parametric updates')
72 | parser.add_argument('--iter-size', default=1, type=int,
73 | help='caffe style iter size')
74 |
75 | best_prec1 = 0
76 |
77 |
78 | def main():
79 | global args, best_prec1
80 | args = parser.parse_args()
81 |
82 | # create model
83 | if args.pretrained:
84 | print("=> using pre-trained model '{}'".format(args.arch))
85 | model = models.__dict__[args.arch](pretrained=True)
86 | else:
87 | print("=> creating model '{}'".format(args.arch))
88 | model = models.__dict__[args.arch](low_dim=args.low_dim)
89 |
90 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
91 | model.features = torch.nn.DataParallel(model.features)
92 | model.cuda()
93 | else:
94 | model = torch.nn.DataParallel(model).cuda()
95 |
96 | # Data loading code
97 | print("=> loading dataset")
98 | traindir = os.path.join(args.data_dir, 'train')
99 | valdir = os.path.join(args.data_dir, 'val')
100 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
101 | std=[0.229, 0.224, 0.225])
102 |
103 | train_dataset = datasets.ImageFolderInstance(
104 | traindir,
105 | transforms.Compose([
106 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
107 | transforms.RandomGrayscale(p=0.2),
108 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
109 | transforms.RandomHorizontalFlip(),
110 | transforms.ToTensor(),
111 | normalize,
112 | ]))
113 |
114 | train_loader = torch.utils.data.DataLoader(
115 | train_dataset, batch_size=args.batch_size, shuffle=True,
116 | num_workers=args.workers, pin_memory=True)
117 |
118 | val_loader = torch.utils.data.DataLoader(
119 | datasets.ImageFolderInstance(valdir, transforms.Compose([
120 | transforms.Resize(256),
121 | transforms.CenterCrop(224),
122 | transforms.ToTensor(),
123 | normalize,
124 | ])),
125 | batch_size=args.val_batch_size, shuffle=False,
126 | num_workers=args.workers, pin_memory=True)
127 |
128 | # define lemniscate and loss function (criterion)
129 | print("=> building optimizer")
130 |     ndata = len(train_dataset)
131 | if args.nce_k > 0:
132 | lemniscate = NCEAverage(args.low_dim, ndata,
133 | args.nce_k, args.nce_t, args.nce_m).cuda()
134 | criterion = NCECriterion(ndata).cuda()
135 | else:
136 | lemniscate = LinearAverage(
137 | args.low_dim, ndata, args.nce_t, args.nce_m).cuda()
138 | criterion = nn.CrossEntropyLoss().cuda()
139 |
140 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
141 | momentum=args.momentum,
142 | weight_decay=args.weight_decay)
143 |
144 | # optionally resume from a checkpoint
145 | model_filename_to_resume = None
146 | if args.resume:
147 | if os.path.isfile(args.resume):
148 | model_filename_to_resume = args.resume
149 | else:
150 | print("=> no checkpoint found at '{}'".format(args.resume))
151 | elif args.auto_resume:
152 | for epoch in range(args.epochs, args.start_epoch + 1, -1):
153 | model_filename = get_model_name(epoch)
154 | if os.path.exists(model_filename):
155 | model_filename_to_resume = model_filename
156 | break
157 | else:
158 | print("=> no checkpoint found at '{}'".format(args.model_dir))
159 |
160 | if model_filename_to_resume is not None:
161 | print("=> loading checkpoint '{}'".format(model_filename_to_resume))
162 | checkpoint = torch.load(model_filename_to_resume)
163 | args.start_epoch = checkpoint['epoch']
164 | best_prec1 = checkpoint['best_prec1']
165 | model.load_state_dict(checkpoint['state_dict'])
166 | lemniscate = checkpoint['lemniscate']
167 | optimizer.load_state_dict(checkpoint['optimizer'])
168 | print("=> loaded checkpoint '{}' (epoch {})"
169 | .format(model_filename_to_resume, checkpoint['epoch']))
170 |
171 | cudnn.benchmark = True
172 |
173 | if args.evaluate:
174 | kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
175 | return
176 |
177 | for epoch in range(args.start_epoch, args.epochs):
178 | adjust_learning_rate(optimizer, epoch)
179 |
180 | # train for one epoch
181 | train(train_loader, model, lemniscate, criterion, optimizer, epoch)
182 |
183 | # evaluate on validation set
184 | prec1 = NN(model, lemniscate, train_loader, val_loader)
185 |
186 | # remember best prec@1 and save checkpoint
187 | is_best = prec1 > best_prec1
188 | best_prec1 = max(prec1, best_prec1)
189 | save_checkpoint({
190 | 'epoch': epoch + 1,
191 | 'arch': args.arch,
192 | 'state_dict': model.state_dict(),
193 | 'lemniscate': lemniscate,
194 | 'best_prec1': best_prec1,
195 | 'optimizer': optimizer.state_dict(),
196 | }, is_best,
197 | filename=get_model_name(epoch))
198 | # evaluate KNN after last epoch
199 | kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
200 |
201 |
202 | def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
203 | batch_time = AverageMeter()
204 | data_time = AverageMeter()
205 | losses = AverageMeter()
206 |
207 | # switch to train mode
208 | model.train()
209 |
210 | end = time.time()
211 | optimizer.zero_grad()
212 | for i, (inputs, _, index) in enumerate(train_loader):
213 | # measure data loading time
214 | data_time.update(time.time() - end)
215 |
216 | index = index.cuda(non_blocking=True)
217 |
218 | # compute output
219 | feature = model(inputs)
220 | output = lemniscate(feature, index)
221 | loss = criterion(output, index) / args.iter_size
222 |
223 | loss.backward()
224 |
225 | # measure accuracy and record loss
226 | losses.update(loss.item() * args.iter_size, inputs.size(0))
227 |
228 | if (i + 1) % args.iter_size == 0:
229 | # compute gradient and do SGD step
230 | optimizer.step()
231 | optimizer.zero_grad()
232 |
233 | # measure elapsed time
234 | batch_time.update(time.time() - end)
235 | end = time.time()
236 |
237 | if i % args.print_freq == 0:
238 | print(f'Epoch: [{epoch}/{args.epochs}][{i}/{len(train_loader)}]\t'
239 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
240 | f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
241 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t')
242 |
243 |
244 | def get_model_name(epoch):
245 | return os.path.join(args.model_dir, 'ckpt-{}.pth.tar'.format(epoch))
246 |
247 |
248 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
249 | torch.save(state, filename)
250 | if is_best:
251 | shutil.copyfile(filename, os.path.join(
252 | args.model_dir, 'model_best.pth.tar'))
253 |
254 |
255 | def adjust_learning_rate(optimizer, epoch):
256 |     """Keeps the initial LR for 120 epochs, then decays it by 10x at epochs 120 and 160"""
257 | if epoch < 120:
258 | lr = args.lr
259 | elif 120 <= epoch < 160:
260 | lr = args.lr * 0.1
261 | else:
262 | lr = args.lr * 0.01
263 | # lr = args_.lr * (0.1 ** (epoch // 100))
264 | for param_group in optimizer.param_groups:
265 | param_group['lr'] = lr
266 |
267 |
268 | if __name__ == '__main__':
269 | main()
270 |
--------------------------------------------------------------------------------
/notebooks/knn-imagenet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "import sys\n",
11 | "import os\n",
12 | "import argparse\n",
13 | "import time\n",
14 | "import numpy as np\n",
15 | "sys.path.append('../')\n",
16 | "\n",
17 | "import torch\n",
18 | "import torch.nn as nn\n",
19 | "import torch.optim as optim\n",
20 | "import torch.nn.functional as F\n",
21 | "import torch.backends.cudnn as cudnn\n",
22 | "import torch.optim.lr_scheduler as lr_scheduler\n",
23 | "import torchvision\n",
24 | "import torchvision.transforms as transforms\n",
25 | "import matplotlib.pyplot as plt\n",
26 | "import easydict as edict\n",
27 | "\n",
28 | "from lib import models, datasets\n",
29 | "import math"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# parameters\n",
39 | "args = edict\n",
40 | "\n",
41 | "# imagenet\n",
42 | "args.cache = '../checkpoint/train_features_labels_cache/instance_imagenet_train_feature_resnet50.pth.tar'\n",
43 | "args.val_cache = '../checkpoint/train_features_labels_cache/instance_imagenet_val_feature_resnet50.pth.tar'\n",
44 | "args.save_path = '../checkpoint/pseudos/unsupervised_imagenet32x32_nc_wrn-28-2'\n",
45 | "os.makedirs(args.save_path, exist_ok=True)\n",
46 | "\n",
47 | "args.low_dim = 128\n",
48 | "args.num_class = 1000\n",
49 | "args.rng_seed = 0"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 3,
55 | "metadata": {
56 | "scrolled": true
57 | },
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "torch.float32 torch.int64\n",
64 | "torch.Size([1331167, 128]) torch.Size([1331167])\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "ckpt = torch.load(args.cache)\n",
70 | "train_labels, train_features = ckpt['labels'], ckpt['features']\n",
71 | "\n",
72 | "ckpt = torch.load(args.val_cache)\n",
73 | "val_labels, val_features = ckpt['val_labels'], ckpt['val_features']\n",
74 | "\n",
75 | "train_features = torch.cat([val_features, train_features], dim=0)\n",
76 | "train_labels = torch.cat([val_labels, train_labels], dim=0)\n",
77 | "\n",
78 | "print(train_features.dtype, train_labels.dtype)\n",
79 | "print(train_features.shape, train_labels.shape)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "# use cpu because the following computation need a lot of memory"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "device = 'cpu'\n",
96 | "train_features, train_labels = train_features.to(device), train_labels.to(device)"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [
104 | {
105 | "name": "stdout",
106 | "output_type": "stream",
107 | "text": [
108 | "tensor([ 970454, 1058848, 717280, ..., 462299, 305137, 436069])\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "num_train_data = train_labels.shape[0]\n",
114 | "num_class = torch.max(train_labels) + 1\n",
115 | "\n",
116 | "torch.manual_seed(args.rng_seed)\n",
117 | "torch.cuda.manual_seed_all(args.rng_seed)\n",
118 | "perm = torch.randperm(num_train_data).to(device)\n",
119 | "print(perm)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "# soft label"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "class AverageMeter(object):\n",
136 | " \"\"\"Computes and stores the average and current value\"\"\"\n",
137 | " def __init__(self):\n",
138 | " self.reset()\n",
139 | "\n",
140 | " def reset(self):\n",
141 | " self.val = 0\n",
142 | " self.avg = 0\n",
143 | " self.sum = 0\n",
144 | " self.count = 0\n",
145 | "\n",
146 | " def update(self, val, n=1):\n",
147 | " self.val = val\n",
148 | " self.sum += val * n\n",
149 | " self.count += n\n",
150 | " self.avg = self.sum / self.count"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "name": "stdout",
160 | "output_type": "stream",
161 | "text": [
162 | "[0]/[100] top5=85.00%(85.00%) top1=66.60%(66.60%)\n",
163 | "[1]/[100] top5=79.60%(82.30%) top1=52.80%(59.70%)\n",
164 | "[2]/[100] top5=81.00%(81.87%) top1=61.40%(60.27%)\n",
165 | "[3]/[100] top5=65.20%(77.70%) top1=42.80%(55.90%)\n",
166 | "[4]/[100] top5=70.00%(76.16%) top1=47.40%(54.20%)\n",
167 | "[5]/[100] top5=69.20%(75.00%) top1=42.60%(52.27%)\n",
168 | "[6]/[100] top5=67.20%(73.89%) top1=41.40%(50.71%)\n",
169 | "[7]/[100] top5=77.60%(74.35%) top1=52.20%(50.90%)\n",
170 | "[8]/[100] top5=84.60%(75.49%) top1=67.00%(52.69%)\n",
171 | "[9]/[100] top5=77.40%(75.68%) top1=57.00%(53.12%)\n",
172 | "[10]/[100] top5=82.20%(76.27%) top1=67.00%(54.38%)\n",
173 | "[11]/[100] top5=71.00%(75.83%) top1=49.00%(53.93%)\n",
174 | "[12]/[100] top5=65.20%(75.02%) top1=43.80%(53.15%)\n",
175 | "[13]/[100] top5=83.00%(75.59%) top1=62.20%(53.80%)\n",
176 | "[14]/[100] top5=85.20%(76.23%) top1=66.80%(54.67%)\n",
177 | "[15]/[100] top5=71.80%(75.95%) top1=43.60%(53.97%)\n",
178 | "[16]/[100] top5=62.20%(75.14%) top1=39.80%(53.14%)\n",
179 | "[17]/[100] top5=61.00%(74.36%) top1=38.80%(52.34%)\n",
180 | "[18]/[100] top5=63.60%(73.79%) top1=36.60%(51.52%)\n",
181 | "[19]/[100] top5=67.40%(73.47%) top1=41.60%(51.02%)\n",
182 | "[20]/[100] top5=70.40%(73.32%) top1=38.40%(50.42%)\n",
183 | "[21]/[100] top5=71.40%(73.24%) top1=46.60%(50.25%)\n",
184 | "[22]/[100] top5=71.00%(73.14%) top1=48.40%(50.17%)\n",
185 | "[23]/[100] top5=76.20%(73.27%) top1=44.00%(49.91%)\n",
186 | "[24]/[100] top5=71.20%(73.18%) top1=41.60%(49.58%)\n",
187 | "[25]/[100] top5=78.60%(73.39%) top1=55.40%(49.80%)\n",
188 | "[26]/[100] top5=67.20%(73.16%) top1=45.00%(49.62%)\n",
189 | "[27]/[100] top5=74.80%(73.22%) top1=52.60%(49.73%)\n",
190 | "[28]/[100] top5=74.80%(73.28%) top1=47.00%(49.63%)\n",
191 | "[29]/[100] top5=79.40%(73.48%) top1=57.80%(49.91%)\n",
192 | "[30]/[100] top5=78.20%(73.63%) top1=51.00%(49.94%)\n",
193 | "[31]/[100] top5=64.20%(73.34%) top1=37.20%(49.54%)\n",
194 | "[32]/[100] top5=86.00%(73.72%) top1=71.40%(50.21%)\n",
195 | "[33]/[100] top5=83.00%(73.99%) top1=61.00%(50.52%)\n",
196 | "[34]/[100] top5=75.80%(74.05%) top1=52.00%(50.57%)\n",
197 | "[35]/[100] top5=71.00%(73.96%) top1=44.00%(50.38%)\n",
198 | "[36]/[100] top5=73.80%(73.96%) top1=53.40%(50.46%)\n",
199 | "[37]/[100] top5=73.80%(73.95%) top1=51.80%(50.50%)\n",
200 | "[38]/[100] top5=78.80%(74.08%) top1=47.80%(50.43%)\n",
201 | "[39]/[100] top5=75.80%(74.12%) top1=57.40%(50.60%)\n",
202 | "[40]/[100] top5=76.20%(74.17%) top1=54.20%(50.69%)\n",
203 | "[41]/[100] top5=62.20%(73.89%) top1=35.80%(50.34%)\n",
204 | "[42]/[100] top5=67.80%(73.74%) top1=49.00%(50.31%)\n",
205 | "[43]/[100] top5=66.00%(73.57%) top1=45.40%(50.20%)\n",
206 | "[44]/[100] top5=67.60%(73.44%) top1=43.40%(50.04%)\n",
207 | "[45]/[100] top5=66.60%(73.29%) top1=44.20%(49.92%)\n",
208 | "[46]/[100] top5=58.80%(72.98%) top1=34.20%(49.58%)\n",
209 | "[47]/[100] top5=65.60%(72.83%) top1=48.60%(49.56%)\n",
210 | "[48]/[100] top5=67.60%(72.72%) top1=40.20%(49.37%)\n",
211 | "[49]/[100] top5=56.60%(72.40%) top1=36.60%(49.12%)\n",
212 | "[50]/[100] top5=59.20%(72.14%) top1=37.40%(48.89%)\n",
213 | "[51]/[100] top5=64.00%(71.98%) top1=38.20%(48.68%)\n",
214 | "[52]/[100] top5=68.60%(71.92%) top1=43.80%(48.59%)\n",
215 | "[53]/[100] top5=68.40%(71.85%) top1=51.80%(48.65%)\n",
216 | "[54]/[100] top5=67.20%(71.77%) top1=50.40%(48.68%)\n",
217 | "[55]/[100] top5=64.80%(71.64%) top1=45.20%(48.62%)\n",
218 | "[56]/[100] top5=78.40%(71.76%) top1=63.80%(48.88%)\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "n_chunks = 100\n",
224 | "n_val = val_features.shape[0]\n",
225 | "\n",
226 | "prec_top5 = AverageMeter()\n",
227 | "prec_top1 = AverageMeter()\n",
228 | "index_labeled = torch.arange(n_val, train_features.shape[0])\n",
229 | "index_unlabeled = torch.arange(n_val)\n",
230 | "num_labeled_data = index_labeled.shape[0]\n",
231 | "\n",
232 | "for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):\n",
233 | "\n",
234 | " # calculate similarity matrix\n",
235 | " dist = torch.mm(train_features[index_unlabeled_chunk], train_features[index_labeled].t())\n",
236 | "\n",
237 | " K = min(num_labeled_data, 200)\n",
238 | " bs = index_unlabeled_chunk.shape[0]\n",
239 | " yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)\n",
240 | " candidates = train_labels.view(1,-1).expand(bs, -1)\n",
241 | " retrieval = torch.gather(candidates, 1, index_labeled[yi])\n",
242 | " retrieval_one_hot = torch.zeros(bs * K, num_class).to(device)\n",
243 | " retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)\n",
244 | "\n",
245 | " temperature = 0.1\n",
246 | "\n",
247 | " yd_transform = (yd / temperature).exp_()\n",
248 | " probs = torch.sum(torch.mul(retrieval_one_hot.view(bs, -1 , num_class), yd_transform.view(bs, -1, 1)), 1)\n",
249 | " probs.div_(probs.sum(dim=1, keepdim=True))\n",
250 | " probs_sorted, predictions = probs.sort(1, True)\n",
251 | " correct = predictions.eq(train_labels[index_unlabeled_chunk].data.view(-1,1))\n",
252 | " \n",
253 | " top5 = torch.any(correct[:, :5], dim=1).float().mean() \n",
254 | " top1 = correct[:, 0].float().mean() \n",
255 | " prec_top5.update(top5, bs)\n",
256 | " prec_top1.update(top1, bs)\n",
257 | " print('[{}]/[{}] top5={:.2%}({:.2%}) top1={:.2%}({:.2%})'.format(\n",
258 | " i_chunks, n_chunks, prec_top5.val, prec_top5.avg, prec_top1.val, prec_top1.avg))"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "scrolled": false
266 | },
267 | "outputs": [],
268 | "source": [
269 | "# n_chunks = 100\n",
270 | "\n",
271 | "# prec_top5 = AverageMeter()\n",
272 | "# for num_labeled_data in [10000]:\n",
273 | "# index_labeled = []\n",
274 | "# index_unlabeled = []\n",
275 | "# data_per_class = num_labeled_data // args.num_class\n",
276 | "# for c in range(args.num_class):\n",
277 | "# indexes_c = perm[train_labels[perm] == c]\n",
278 | "# index_labeled.append(indexes_c[:data_per_class])\n",
279 | "# index_unlabeled.append(indexes_c[data_per_class:])\n",
280 | "# index_labeled = torch.cat(index_labeled)\n",
281 | "# index_unlabeled = torch.cat(index_unlabeled)\n",
282 | "\n",
283 | "# for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):\n",
284 | " \n",
285 | "# # calculate similarity matrix\n",
286 | "# dist = torch.mm(train_features[index_unlabeled_chunk], train_features[index_labeled].t())\n",
287 | "\n",
288 | "# K = min(num_labeled_data, 5000)\n",
289 | "# bs = index_unlabeled_chunk.shape[0]\n",
290 | "# yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)\n",
291 | "# candidates = train_labels.view(1,-1).expand(bs, -1)\n",
292 | "# retrieval = torch.gather(candidates, 1, index_labeled[yi])\n",
293 | "# retrieval_one_hot = torch.zeros(bs * K, num_class).to(device)\n",
294 | "# retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)\n",
295 | "\n",
296 | "# temperature = 0.1\n",
297 | "\n",
298 | "# yd_transform = (yd / temperature).exp_()\n",
299 | "# probs = torch.sum(torch.mul(retrieval_one_hot.view(bs, -1 , num_class), yd_transform.view(bs, -1, 1)), 1)\n",
300 | "# probs.div_(probs.sum(dim=1, keepdim=True))\n",
301 | "# probs_sorted, predictions = probs.sort(1, True)\n",
302 | "# correct = predictions.eq(train_labels[index_unlabeled_chunk].data.view(-1,1))\n",
303 | "# top5 = torch.any(correct[:, :5], dim=1).float().mean() \n",
304 | " \n",
305 | "# prec_top5.update(top5, bs)\n",
306 | "# print('[{}]/[{}] {:.2%} {:.2%}'.format(i_chunks, n_chunks, prec_top5.val, prec_top5.avg))"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": []
315 | }
316 | ],
317 | "metadata": {
318 | "kernelspec": {
319 | "display_name": "Python 3",
320 | "language": "python",
321 | "name": "python3"
322 | },
323 | "language_info": {
324 | "codemirror_mode": {
325 | "name": "ipython",
326 | "version": 3
327 | },
328 | "file_extension": ".py",
329 | "mimetype": "text/x-python",
330 | "name": "python",
331 | "nbconvert_exporter": "python",
332 | "pygments_lexer": "ipython3",
333 | "version": "3.6.8"
334 | }
335 | },
336 | "nbformat": 4,
337 | "nbformat_minor": 2
338 | }
339 |
--------------------------------------------------------------------------------
/cifar-semi.py:
--------------------------------------------------------------------------------
1 | """Train CIFAR10 with PyTorch."""
2 | import argparse
3 | import os
4 | import random
5 | import time
6 | from pprint import pprint
7 |
8 | import numpy as np
9 | from skimage.color import rgb2gray
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from torch.optim.lr_scheduler import CosineAnnealingLR
15 | from torch.utils.data.sampler import SubsetRandomSampler
16 | from torchvision import transforms
17 | from torchvision.datasets import CIFAR10
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 | from lib.datasets import PseudoCIFAR10
21 | from lib.utils import AverageMeter, accuracy, CosineAnnealingLRWithRestart
22 | from lib.models import WideResNet, resnet18_cifar
23 | from test import validate
24 |
25 |
26 | def get_dataloader(args):
27 | if not args.input_gray:
28 | normalize = transforms.Normalize(
29 | (0.4914, 0.4822, 0.4465),
30 | (0.2470, 0.2435, 0.2616))
31 | transform_train = transforms.Compose([
32 | transforms.Pad(4, padding_mode='reflect'),
33 | transforms.RandomCrop(32),
34 | transforms.RandomHorizontalFlip(),
35 | transforms.ToTensor(),
36 | normalize,
37 | ])
38 | transform_test = transforms.Compose([
39 | transforms.ToTensor(),
40 | normalize,
41 | ])
42 | else:
43 | to_gray = transforms.Lambda(lambda img: torch.from_numpy(
44 | rgb2gray(np.array(img))).unsqueeze(0).float())
45 | transform_train = transforms.Compose([
46 | transforms.Pad(4, padding_mode='reflect'),
47 | transforms.RandomCrop(32),
48 | transforms.RandomHorizontalFlip(),
49 | to_gray,
50 | ])
51 | transform_test = to_gray
52 |
53 | testset = CIFAR10(root=args.data_dir, train=False,
54 | download=True, transform=transform_test)
55 | testloader = torch.utils.data.DataLoader(
56 | testset, shuffle=False,
57 | batch_size=args.batch_size,
58 | num_workers=args.num_workers)
59 |
60 | trainset = CIFAR10(root=args.data_dir, train=True,
61 | download=True, transform=transform_test)
62 |
63 | args.ndata = len(trainset)
64 | num_labeled_data = args.num_labeled
65 | num_unlabeled_data = args.ndata - num_labeled_data
66 |
67 | if args.pseudo_file is not None:
68 | pseudo_dict = torch.load(args.pseudo_file)
69 | labeled_indexes = pseudo_dict['labeled_indexes']
70 | else:
71 | torch.manual_seed(args.rng_seed)
72 | perm = torch.randperm(args.ndata)
73 | labeled_indexes = perm[:num_labeled_data]
74 |
75 | pseudo_trainset = PseudoCIFAR10(
76 | labeled_indexes=labeled_indexes, root=args.data_dir,
77 | train=True, transform=transform_train)
78 |
79 | # load pseudo labels
80 | if args.pseudo_file is not None:
81 | pseudo_num = int(num_unlabeled_data * args.pseudo_ratio)
82 | pseudo_indexes = pseudo_dict['pseudo_indexes'][:pseudo_num]
83 | pseudo_labels = pseudo_dict['pseudo_labels'][:pseudo_num]
84 | pseudo_trainset.set_pseudo(pseudo_indexes, pseudo_labels)
85 |
86 | pseudo_trainloder = torch.utils.data.DataLoader(
87 | pseudo_trainset, batch_size=args.batch_size,
88 | shuffle=True, num_workers=args.num_workers)
89 |
90 | print('-' * 80)
91 | print('selected labeled indexes: ', labeled_indexes)
92 |
93 | return testloader, pseudo_trainloder
94 |
95 |
96 | def build_model(args):
97 | if args.architecture == 'resnet18':
98 | net = resnet18_cifar(low_dim=args.num_class, norm=False)
99 | elif args.architecture.startswith('wrn'):
100 | split = args.architecture.split('-')
101 | net = WideResNet(depth=int(split[1]), widen_factor=int(split[2]),
102 | num_classes=args.num_class, norm=False)
103 | else:
104 | raise ValueError('architecture should be resnet18 or wrn')
105 | if args.input_gray:
106 | net.conv1 = nn.Conv2d(1, net.conv1.out_channels,
107 | kernel_size=3, stride=1, padding=1, bias=False)
108 | net = net.to(args.device)
109 |
110 | print('#param: {}'.format(sum([p.nelement() for p in net.parameters()])))
111 |
112 | if args.device == 'cuda':
113 | net = torch.nn.DataParallel(
114 | net, device_ids=range(torch.cuda.device_count()))
115 | cudnn.benchmark = True
116 |
117 | # resume from unsupervised pretrain
118 | if len(args.resume) > 0:
119 | # Load checkpoint.
120 | print('==> Resuming from unsupervised pretrained checkpoint..')
121 | checkpoint = torch.load(args.resume)
122 | # only load shared conv layers, don't load fc
123 | model_dict = net.state_dict()
124 | if not args.input_gray:
125 | pretrained_dict = checkpoint['net']
126 | else:
127 | lst = ['conv1', 'block1', 'block2', 'block3']
128 | pretrained_dict = {
129 | 'module.' + lst[int(k[0])] + k[1:]: v for k, v in checkpoint.items()}
130 | pretrained_dict = {k: v for k, v in pretrained_dict.items()
131 | if k in model_dict
132 | and v.size() == model_dict[k].size()}
133 | assert len(pretrained_dict) > 0
134 | model_dict.update(pretrained_dict)
135 | net.load_state_dict(model_dict)
136 |
137 | return net
138 |
139 |
140 | def get_lr_scheduler(optimizer, lr_scheduler, max_iters):
141 |     if lr_scheduler == 'cosine':
142 |         scheduler = CosineAnnealingLR(optimizer, max_iters, eta_min=0.00001)
143 |     elif lr_scheduler == 'cosine-with-restart':
144 | scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.00001)
145 | else:
146 | raise ValueError("not supported")
147 |
148 | return scheduler
149 |
150 |
151 | # Training
152 | def train(net, optimizer, scheduler, trainloader, testloader, criterion, summary_writer, args):
153 | train_loss = AverageMeter()
154 | data_time = AverageMeter()
155 | batch_time = AverageMeter()
156 | top1 = AverageMeter()
157 | top2 = AverageMeter()
158 |
159 | best_acc = 0
160 | end = time.time()
161 |
162 | def inf_generator(trainloader):
163 | while True:
164 | for data in trainloader:
165 | yield data
166 |
167 | for step, (inputs, targets) in enumerate(inf_generator(trainloader)):
168 | if step >= args.max_iters:
169 | break
170 |
171 | data_time.update(time.time() - end)
172 |
173 | inputs = inputs.to(args.device)
174 | targets = targets.to(args.device)
175 |
176 | # switch to train mode
177 | net.train()
178 | scheduler.step()
179 | optimizer.zero_grad()
180 |
181 | outputs = net(inputs)
182 | loss = criterion(outputs, targets).mean()
183 | prec1, prec2 = accuracy(outputs, targets, topk=(1, 2))
184 | top1.update(prec1[0], inputs.size(0))
185 | top2.update(prec2[0], inputs.size(0))
186 |
187 | loss.backward()
188 | optimizer.step()
189 |
190 | train_loss.update(loss.item(), inputs.size(0))
191 |
192 | # measure elapsed time
193 | batch_time.update(time.time() - end)
194 | end = time.time()
195 |
196 | summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], step)
197 | summary_writer.add_scalar('top1', top1.val, step)
198 | summary_writer.add_scalar('top2', top2.val, step)
199 | summary_writer.add_scalar('batch_time', batch_time.val, step)
200 | summary_writer.add_scalar('data_time', data_time.val, step)
201 | summary_writer.add_scalar('train_loss', train_loss.val, step)
202 |
203 | if step % args.print_freq == 0:
204 | lr = optimizer.param_groups[0]["lr"]
205 | print(f'Train: [{step}/{args.max_iters}] '
206 | f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
207 | f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
208 | f'Lr: {lr:.5f} '
209 | f'prec1: {top1.val:.3f} ({top1.avg:.3f}) '
210 | f'prec2: {top2.val:.3f} ({top2.avg:.3f}) '
211 | f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})')
212 |
213 | if (step + 1) % args.eval_freq == 0 or step == args.max_iters - 1:
214 | acc = validate(testloader, net, criterion,
215 | device=args.device, print_freq=args.print_freq)
216 |
217 | summary_writer.add_scalar('val_top1', acc, step)
218 |
219 | if acc > best_acc:
220 | best_acc = acc
221 | state = {
222 | 'step': step,
223 | 'best_acc': best_acc,
224 | 'net': net.state_dict(),
225 | 'optimizer': optimizer.state_dict(),
226 | }
227 | os.makedirs(args.model_dir, exist_ok=True)
228 | torch.save(state, os.path.join(args.model_dir, 'ckpt.pth.tar'))
229 |
230 | print('best accuracy: {:.2f}\n'.format(best_acc))
231 |
232 |
233 | def main(args):
234 | # Data
235 | print('==> Preparing data..')
236 | testloader, pseudo_trainloder = get_dataloader(args)
237 |
238 | print('==> Building model..')
239 | net = build_model(args)
240 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
241 | weight_decay=5e-4, nesterov=True)
242 |
243 | criterion = nn.__dict__[args.criterion]().to(args.device)
244 | scheduler = get_lr_scheduler(optimizer, args.lr_scheduler, args.max_iters)
245 |
246 | if args.eval:
247 | return validate(testloader, net, criterion,
248 | device=args.device, print_freq=args.print_freq)
249 | # summary writer
250 | os.makedirs(args.log_dir, exist_ok=True)
251 | summary_writer = SummaryWriter(args.log_dir)
252 |
253 | train(net, optimizer, scheduler, pseudo_trainloder,
254 | testloader, criterion, summary_writer, args)
255 |
256 |
257 | if __name__ == '__main__':
258 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
259 |     parser.add_argument('--data-dir', '--dataDir', default='./data',
260 |                         type=str, metavar='DIR')
261 | parser.add_argument('--model-root', default='./checkpoint/cifar10-semi',
262 | type=str, metavar='DIR',
263 | help='root directory to save checkpoint')
264 | parser.add_argument('--log-root', default='./tensorboard/cifar10-semi',
265 | type=str, metavar='DIR',
266 | help='root directory to save tensorboard logs')
267 | parser.add_argument('--exp-name', default='exp', type=str,
268 | help='experiment name, used to determine log_dir and model_dir')
269 | parser.add_argument('--lr', default=0.01, type=float,
270 | metavar='LR', help='learning rate')
271 | parser.add_argument('--lr-scheduler', default='cosine', type=str,
272 | choices=['multi-step', 'cosine',
273 | 'cosine-with-restart'],
274 | help='which lr scheduler to use')
275 | parser.add_argument('--resume', '-r', default='', type=str,
276 | metavar='FILE', help='resume from checkpoint')
277 | parser.add_argument('--eval', action='store_true', help='test only')
278 | parser.add_argument('--finetune', action='store_true',
279 | help='only training last fc layer')
280 | parser.add_argument('-j', '--num-workers', default=2, type=int,
281 | metavar='N', help='number of workers to load data')
282 | parser.add_argument('-b', '--batch-size', default=128, type=int,
283 | metavar='N', help='batch size')
284 | parser.add_argument('--max-iters', default=500000, type=int,
285 | metavar='N', help='number of iterations')
286 | parser.add_argument('--num-labeled', default=500, type=int,
287 | metavar='N', help='number of labeled data')
288 | parser.add_argument('--rng-seed', default=0, type=int,
289 | metavar='N', help='random number generator seed')
290 | parser.add_argument('--gpus', default='0', type=str, metavar='GPUS')
291 | parser.add_argument('--eval-freq', default=500, type=int,
292 | metavar='N', help='eval frequence')
293 | parser.add_argument('--print-freq', default=100, type=int,
294 | metavar='N', help='print frequence')
295 | parser.add_argument('--criterion', default='CrossEntropyLoss', type=str,
296 | choices=['CrossEntropyLoss', 'MultiMarginLoss'])
297 | parser.add_argument('--pseudo-file', type=str,
298 | metavar='FILE', help='pseudo file to load', required=True)
299 | parser.add_argument('--input-gray', action='store_true',
300 | help='set for load colorization pretrained model, '
301 | '(colorization model use gray image as input)')
302 | parser.add_argument('--pseudo-ratio', default=1, type=float, metavar='0-1',
303 | help='ratio of unlabeled data to use for pseudo labels')
304 | parser.add_argument('--architecture', '--arch', default='wrn-28-2', type=str,
305 | help='which backbone to use')
306 | args, rest = parser.parse_known_args()
307 | print(rest)
308 |
309 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
310 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
311 | args.num_class = 10
312 | args.log_dir = os.path.join(args.log_root, args.exp_name)
313 | args.model_dir = os.path.join(args.model_root, args.exp_name)
314 |
315 | torch.manual_seed(args.rng_seed)
316 | torch.cuda.manual_seed(args.rng_seed)
317 | random.seed(args.rng_seed)
318 | torch.set_printoptions(threshold=50, precision=4)
319 |
320 | print('-' * 80)
321 | pprint(vars(args))
322 |
323 | main(args)
324 |
325 | print('-' * 80)
326 | pprint(vars(args))
327 |
--------------------------------------------------------------------------------
/notebooks/nc-colorization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "import sys\n",
11 | "sys.path.append('../')\n",
12 | "import torch\n",
13 | "import torch.nn as nn\n",
14 | "import torch.optim as optim\n",
15 | "import torch.nn.functional as F\n",
16 | "import torch.backends.cudnn as cudnn\n",
17 | "\n",
18 | "import torchvision\n",
19 | "import torchvision.transforms as transforms\n",
20 | "\n",
21 | "import math\n",
22 | "import os\n",
23 | "import argparse\n",
24 | "import time\n",
25 | "\n",
26 | "from lib import models, datasets\n",
27 | "\n",
28 | "\n",
29 | "import numpy as np\n",
30 | "import scipy as sp\n",
31 | "import scipy.sparse.linalg as linalg\n",
32 | "import scipy.sparse as sparse\n",
33 | "\n",
34 | "import matplotlib.pyplot as plt\n",
35 | "import easydict as edict"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "# parameters\n",
45 | "args = edict\n",
46 | "\n",
47 | "args.cache = '../checkpoint/train_features_labels_cache/colorization_embedding_128.t7'\n",
48 | "args.save_path = '../checkpoint/pseudos/colorization_nc_pseudo_wrn-28-2'\n",
49 | "os.makedirs(args.save_path, exist_ok=True)\n",
50 | "\n",
51 | "args.num_class = 10\n",
52 | "args.rng_seed = 0"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 3,
58 | "metadata": {
59 | "scrolled": true
60 | },
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "torch.float32 torch.int64\n",
67 | "torch.Size([50000, 128]) torch.Size([50000])\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "train_features = torch.load(args.cache)\n",
73 | "train_labels = torch.Tensor(datasets.CIFAR10Instance(root='../data', train=True).targets).long()\n",
74 | "\n",
75 | "print(train_features.dtype, train_labels.dtype)\n",
76 | "print(train_features.shape, train_labels.shape)"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "# use cpu because the follow computation need a lot of memory"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 4,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "device = 'cpu'\n",
93 | "train_features, train_labels = train_features.to(device), train_labels.to(device)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 5,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "tensor([36044, 49165, 37807, ..., 42128, 15898, 31476])\n"
106 | ]
107 | }
108 | ],
109 | "source": [
110 | "num_train_data = train_labels.shape[0]\n",
111 | "num_class = torch.max(train_labels) + 1\n",
112 | "\n",
113 | "torch.manual_seed(args.rng_seed)\n",
114 | "torch.cuda.manual_seed_all(args.rng_seed)\n",
115 | "perm = torch.randperm(num_train_data).to(device)\n",
116 | "print(perm)"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "# constrained normalized cut"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [
131 | {
132 | "name": "stdout",
133 | "output_type": "stream",
134 | "text": [
135 | "similarity done\n",
136 | "L_sys done\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "K = 20\n",
142 | "def make_column_normalize(X):\n",
143 | " return X.div(torch.norm(X, p=2, dim=0, keepdim=True))\n",
144 | "\n",
145 | "cosin_similarity = torch.mm(train_features, train_features.t())\n",
146 | "dist = (1 - cosin_similarity) / 2\n",
147 | "\n",
148 | "dist_sorted, idx = dist.topk(K, dim=1, largest=False, sorted=True)\n",
149 | "k_dist = dist_sorted[:, -1:]\n",
150 | "\n",
151 | "similarity_dense = torch.exp(-dist_sorted * 2 / k_dist)\n",
152 | "similarity_sparse = torch.zeros_like(cosin_similarity)\n",
153 | "similarity_sparse[torch.arange(num_train_data).view(-1, 1), idx[:, 1:]] = similarity_dense[:, 1:]\n",
154 | "similarity = torch.max(similarity_sparse, similarity_sparse.t())\n",
155 | "print('similarity done')\n",
156 | "\n",
157 | "degree = similarity.sum(0)\n",
158 | "degree_normed = (degree**(-0.5))\n",
159 | "L_sys = degree_normed.view(-1, 1) * (degree.diag() - similarity) * degree_normed.view(1, -1)\n",
160 | "print('L_sys done')"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 7,
166 | "metadata": {
167 | "scrolled": true
168 | },
169 | "outputs": [
170 | {
171 | "name": "stdout",
172 | "output_type": "stream",
173 | "text": [
174 | "eigenvectors done\n",
175 | "tensor([0.0160, 0.0236, 0.0306, 0.0318, 0.0359, 0.0400, 0.0453, 0.0561, 0.0603,\n",
176 | " 0.0621, 0.0635, 0.0731, 0.0745, 0.0775, 0.0795, 0.0860, 0.0881, 0.0942,\n",
177 | " 0.0967, 0.1003, 0.1035, 0.1047, 0.1070, 0.1094, 0.1170, 0.1221, 0.1237,\n",
178 | " 0.1275, 0.1290, 0.1316, 0.1377, 0.1384, 0.1393, 0.1425, 0.1435, 0.1466,\n",
179 | " 0.1505, 0.1539, 0.1548, 0.1591, 0.1615, 0.1644, 0.1660, 0.1665, 0.1699,\n",
180 | " 0.1715, 0.1721, 0.1727, 0.1734, 0.1756, 0.1794, 0.1803, 0.1805, 0.1819,\n",
181 | " 0.1854, 0.1874, 0.1875, 0.1881, 0.1889, 0.1919, 0.1926, 0.1952, 0.1975,\n",
182 | " 0.1997, 0.2004, 0.2010, 0.2025, 0.2035, 0.2053, 0.2068, 0.2088, 0.2101,\n",
183 | " 0.2117, 0.2138, 0.2145, 0.2150, 0.2183, 0.2190, 0.2205, 0.2220, 0.2245,\n",
184 | " 0.2251, 0.2275, 0.2282, 0.2285, 0.2297, 0.2299, 0.2316, 0.2330, 0.2358,\n",
185 | " 0.2359, 0.2381, 0.2396, 0.2409, 0.2425, 0.2440, 0.2448, 0.2459, 0.2462,\n",
186 | " 0.2480, 0.2492, 0.2495, 0.2518, 0.2520, 0.2531, 0.2543, 0.2545, 0.2562,\n",
187 | " 0.2574, 0.2582, 0.2592, 0.2598, 0.2612, 0.2621, 0.2633, 0.2638, 0.2642,\n",
188 | " 0.2653, 0.2669, 0.2679, 0.2689, 0.2699, 0.2706, 0.2716, 0.2721, 0.2729,\n",
189 | " 0.2744, 0.2750, 0.2761, 0.2774, 0.2784, 0.2794, 0.2803, 0.2810, 0.2821,\n",
190 | " 0.2825, 0.2835, 0.2847, 0.2850, 0.2853, 0.2867, 0.2877, 0.2882, 0.2893,\n",
191 | " 0.2897, 0.2902, 0.2911, 0.2920, 0.2933, 0.2944, 0.2951, 0.2962, 0.2966,\n",
192 | " 0.2973, 0.2978, 0.2984, 0.2990, 0.2994, 0.3007, 0.3011, 0.3022, 0.3023,\n",
193 | " 0.3030, 0.3043, 0.3046, 0.3054, 0.3062, 0.3067, 0.3069, 0.3075, 0.3080,\n",
194 | " 0.3097, 0.3110, 0.3114, 0.3119, 0.3131, 0.3133, 0.3144, 0.3152, 0.3153,\n",
195 | " 0.3159, 0.3172, 0.3176, 0.3183, 0.3188, 0.3193, 0.3197, 0.3205, 0.3212,\n",
196 | " 0.3216, 0.3220, 0.3226, 0.3231, 0.3238, 0.3242, 0.3249, 0.3259, 0.3262,\n",
197 | " 0.3270])\n"
198 | ]
199 | }
200 | ],
201 | "source": [
202 | "num_eigenvectors = 200 # the number of precomputed spectral eigenvectors.\n",
203 | "\n",
204 | "eigenvalues, eigenvectors = linalg.eigs(L_sys.numpy(), k=num_eigenvectors, which='SR', tol=1e-2, maxiter=30000)\n",
205 | "eigenvalues, eigenvectors = torch.from_numpy(eigenvalues.real)[1:], torch.from_numpy(eigenvectors.real)[:, 1:]\n",
206 | "eigenvalues, idx = eigenvalues.sort()\n",
207 | "eigenvectors = eigenvectors[:, idx]\n",
208 | "print('eigenvectors done')\n",
209 | "print(eigenvalues)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 8,
215 | "metadata": {
216 | "scrolled": false
217 | },
218 | "outputs": [
219 | {
220 | "name": "stdout",
221 | "output_type": "stream",
222 | "text": [
223 | "num_labeled= 50 T_nc=1, prec=48.40, AUC=60.85\n",
224 | "num_labeled= 100 T_nc=1, prec=51.91, AUC=67.34\n",
225 | "num_labeled= 250 T_nc=1, prec=61.03, AUC=76.31\n",
226 | "num_labeled= 500 T_nc=1, prec=64.05, AUC=80.04\n",
227 | "num_labeled=1000 T_nc=1, prec=64.84, AUC=81.78\n",
228 | "num_labeled=2000 T_nc=1, prec=64.84, AUC=81.89\n",
229 | "num_labeled=4000 T_nc=1, prec=65.60, AUC=82.93\n",
230 | "num_labeled=8000 T_nc=1, prec=65.11, AUC=82.03\n"
231 | ]
232 | }
233 | ],
234 | "source": [
235 | "fig = plt.figure(dpi=200)\n",
236 | "\n",
237 | "for num_labeled_data in [50, 100, 250, 500, 1000, 2000, 4000, 8000]:\n",
238 | " # index of labeled and unlabeled\n",
239 | " # even split\n",
240 | " index_labeled = []\n",
241 | " index_unlabeled = []\n",
242 | " data_per_class = num_labeled_data // args.num_class\n",
243 | " for c in range(10):\n",
244 | " indexes_c = perm[train_labels[perm] == c]\n",
245 | " index_labeled.append(indexes_c[:data_per_class])\n",
246 | " index_unlabeled.append(indexes_c[data_per_class:])\n",
247 | " index_labeled = torch.cat(index_labeled)\n",
248 | " index_unlabeled = torch.cat(index_unlabeled)\n",
249 | "\n",
250 | "# index_labeled = perm[:num_labeled_data]\n",
251 | "# index_unlabeled = perm[num_labeled_data:]\n",
252 | " \n",
253 | " # prior\n",
254 | " unary_prior = torch.zeros([num_train_data, num_class])\n",
255 | " unary_prior[index_labeled, :] = -1\n",
256 | " unary_prior[index_labeled, train_labels[index_labeled]] = 1\n",
257 | " AQ = unary_prior.abs()\n",
258 | " pd = degree.view(-1, 1) * (AQ + unary_prior) / 2\n",
259 | " nd = degree.view(-1, 1) * (AQ - unary_prior) / 2\n",
260 | " np_ratio = pd.sum(dim=0) / nd.sum(dim=0)\n",
261 | " unary_prior_norm = (pd / np_ratio).sqrt() - (nd * np_ratio).sqrt()\n",
262 | " unary_prior_norm = make_column_normalize(unary_prior_norm)\n",
263 | " \n",
264 | " # logits and prediction\n",
265 | " alpha = 0\n",
266 | " lambda_reverse = (1 / (eigenvalues - alpha)).view(1, -1)\n",
267 | " logits = torch.mm(lambda_reverse * eigenvectors, torch.mm(eigenvectors.t(), unary_prior_norm))\n",
268 | " logits = make_column_normalize(logits) * math.sqrt(logits.shape[0]) \n",
269 | " logits = logits - logits.max(1, keepdim=True)[0]\n",
270 | " _, predict = logits.max(dim=1)\n",
271 | " \n",
272 | " for temperature_nc in [1]:#, 2, 3, 5, 10, 15, 20, 25, 30, 35, 40, 100]: \n",
273 | " # pseudo weights\n",
274 | " logits_sorted = logits.sort(dim=1, descending=True)[0]\n",
275 | " subtract = logits_sorted[:, 0] - logits_sorted[:, 1]\n",
276 | " pseudo_weights = 1 - torch.exp(- subtract / temperature_nc)\n",
277 | " \n",
278 | " exp = (logits * temperature_nc).exp()\n",
279 | " probs = exp / exp.sum(1, keepdim=True)\n",
280 | " probs_sorted, predict_all = probs.sort(1, True)\n",
281 | " assert torch.all(predict == predict_all[:, 0])\n",
282 | "\n",
283 | " idx = pseudo_weights[index_unlabeled].sort(dim=0, descending=True)[1]\n",
284 | " pseudo_indexes = index_unlabeled[idx]\n",
285 | " pseudo_labels = predict[index_unlabeled][idx]\n",
286 | " pseudo_probs = probs[index_unlabeled][idx]\n",
287 | " pseudo_weights = pseudo_weights[index_unlabeled][idx]\n",
288 | " assert torch.all(pseudo_labels == pseudo_probs.max(1)[1])\n",
289 | " \n",
290 | " save_dict = {\n",
291 | " 'pseudo_indexes': pseudo_indexes,\n",
292 | " 'pseudo_labels': pseudo_labels,\n",
293 | " 'pseudo_probs': pseudo_probs,\n",
294 | " 'pseudo_weights': pseudo_weights,\n",
295 | " 'labeled_indexes': index_labeled,\n",
296 | " 'unlabeled_indexes': index_unlabeled,\n",
297 | " }\n",
298 | " torch.save(save_dict, os.path.join(args.save_path, '{}.pth.tar'.format(num_labeled_data)))\n",
299 | "\n",
300 | " # for plot\n",
301 | " correct = pseudo_labels == train_labels[pseudo_indexes]\n",
302 | " \n",
303 | " entropy = - (pseudo_probs * torch.log(pseudo_probs + 1e-7)).sum(dim=1)\n",
304 | " confidence = (- entropy * 1).exp()\n",
305 | " confidence /= confidence.max()\n",
306 | "\n",
307 | " arange = 1 + np.arange(confidence.shape[0])\n",
308 | " xs = arange / confidence.shape[0]\n",
309 | " correct_tmp = correct[confidence.sort(descending=True)[1]]\n",
310 | " accuracies = np.cumsum(correct_tmp.numpy()) / arange\n",
311 | " plt.plot(xs, accuracies, label='num_labeled_data={}'.format(num_labeled_data))\n",
312 | "\n",
313 | " acc = correct.float().mean()\n",
314 | "\n",
315 | " print('num_labeled={:4} T_nc={}, prec={:.2f}, AUC={:.2f}'.format(\n",
316 | " num_labeled_data, temperature_nc, acc * 100, accuracies.mean() * 100))\n",
317 | " \n",
318 | "plt.xlabel('accumulated unlabeled data ratio')\n",
319 | "plt.ylabel('unlabeled top1 accuracy')\n",
320 | "plt.xticks(np.arange(0, 1.01, 0.1))\n",
321 | "plt.grid()\n",
322 | "plt.title('num_eigenvectors={}'.format(num_eigenvectors))\n",
323 | "legend = plt.legend(loc='upper left', bbox_to_anchor=(1, 1))\n",
324 | "plt.show()"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {},
331 | "outputs": [],
332 | "source": []
333 | }
334 | ],
335 | "metadata": {
336 | "kernelspec": {
337 | "display_name": "Python 3",
338 | "language": "python",
339 | "name": "python3"
340 | },
341 | "language_info": {
342 | "codemirror_mode": {
343 | "name": "ipython",
344 | "version": 3
345 | },
346 | "file_extension": ".py",
347 | "mimetype": "text/x-python",
348 | "name": "python",
349 | "nbconvert_exporter": "python",
350 | "pygments_lexer": "ipython3",
351 | "version": "3.6.8"
352 | }
353 | },
354 | "nbformat": 4,
355 | "nbformat_minor": 2
356 | }
357 |
--------------------------------------------------------------------------------
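Note on /notebooks/nc-colorization.ipynb above: the notebook compresses the constrained normalized-cut propagation into a few dense cells (kNN graph, symmetric normalized Laplacian, smallest eigenpairs, label prior expanded on the eigenbasis with 1/lambda weights). Below is a minimal, self-contained sketch of that pipeline on toy data; the sizes, the random features, and the simplified prior are illustrative stand-ins for the cached 50000x128 colorization embeddings, not the repository's exact procedure.

import torch
import scipy.sparse.linalg as linalg

torch.manual_seed(0)
n, d, k, num_class = 1000, 16, 20, 10      # toy sizes; the notebook uses 50000 x 128
feats = torch.nn.functional.normalize(torch.randn(n, d), dim=1)
labels = torch.randint(num_class, (n,))
labeled = torch.arange(100)                # pretend the first 100 points are labeled

# kNN graph with an adaptive RBF kernel on cosine distance, symmetrized by max
dist = (1 - feats @ feats.t()) / 2
d_sorted, idx = dist.topk(k, dim=1, largest=False)
W = torch.zeros(n, n)
W[torch.arange(n).view(-1, 1), idx[:, 1:]] = torch.exp(-d_sorted * 2 / d_sorted[:, -1:])[:, 1:]
W = torch.max(W, W.t())

# symmetric normalized Laplacian L_sym = D^-1/2 (D - W) D^-1/2
deg = W.sum(0)
dn = deg.pow(-0.5)
L_sym = dn.view(-1, 1) * (deg.diag() - W) * dn.view(1, -1)

# smallest eigenpairs; the first, near-zero one is the trivial constant vector
vals, vecs = linalg.eigs(L_sym.double().numpy(), k=30, which='SR', tol=1e-2)
vals = torch.from_numpy(vals.real.copy()).float()[1:]
vecs = torch.from_numpy(vecs.real.copy()).float()[:, 1:]

# simplified +1/-1 label prior (the notebook additionally degree-weights it and
# balances positive/negative mass per class before normalizing the columns)
prior = torch.zeros(n, num_class)
prior[labeled] = -1.0
prior[labeled, labels[labeled]] = 1.0

# propagate: expand the prior on the eigenbasis, re-weighting by 1 / lambda
logits = vecs @ ((vecs.t() @ prior) / vals.view(-1, 1))
pred = logits.argmax(dim=1)
print('agreement on labeled points: {:.2f}'.format(
    (pred[labeled] == labels[labeled]).float().mean().item()))

The last matrix product is the heart of it: projecting the prior onto the low-frequency eigenvectors and re-weighting by 1/lambda approximately inverts L_sym on that subspace, which is what spreads the labeled constraints along the graph before the pseudo weights 1 - exp(-margin / T) rank the unlabeled points.
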
/imagenet-semi.py:
--------------------------------------------------------------------------------
1 | """Train ImageNet with PyTorch."""
2 | import argparse
3 | import glob
4 | import os
5 | import random
6 | import time
7 | from pprint import pprint
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torchvision.datasets as datasets
14 | import torchvision.models as models
15 | import torchvision.transforms as transforms
16 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
17 | from torch.utils.data.sampler import SubsetRandomSampler
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 | from lib.datasets import PseudoDatasetFolder
21 | from lib.utils import AverageMeter, accuracy, CosineAnnealingLRWithRestart
22 | from test import validate
23 |
24 | best_acc = 0
25 | global_step = 0
26 |
27 |
28 | def get_dataloader(args):
29 | normalize = transforms.Normalize(
30 | (0.485, 0.456, 0.406),
31 | (0.229, 0.224, 0.225))
32 | transform_train = transforms.Compose([
33 | transforms.RandomResizedCrop(224),
34 | transforms.RandomHorizontalFlip(),
35 | transforms.ToTensor(),
36 | normalize,
37 | ])
38 |
39 | transform_test = transforms.Compose([
40 | transforms.Resize(256),
41 | transforms.CenterCrop(224),
42 | transforms.ToTensor(),
43 | normalize,
44 | ])
45 | traindir = os.path.join(args.data_dir, 'train')
46 | valdir = os.path.join(args.data_dir, 'val')
47 |
48 | testset = datasets.ImageFolder(valdir, transform=transform_test)
49 | testloader = torch.utils.data.DataLoader(
50 | testset, shuffle=False,
51 | batch_size=args.batch_size,
52 | num_workers=args.num_workers)
53 |
54 | trainset = datasets.ImageFolder(traindir, transform=transform_train)
55 |
56 | # split labeled and unlabeled
57 | args.ndata = len(trainset)
58 | num_labeled = args.num_labeled
59 | num_unlabeled = args.ndata - num_labeled
60 |
61 | torch.manual_seed(args.rng_seed)
62 | perm = torch.randperm(args.ndata)
63 |
64 | index_labeled = []
65 | index_unlabeled = []
66 | data_per_class = num_labeled // args.num_class
67 |     train_labels = torch.tensor([x[1] for x in trainset.samples], dtype=torch.long)
68 | for c in range(args.num_class):
69 | indexes_c = perm[train_labels[perm] == c]
70 | index_labeled.append(indexes_c[:data_per_class])
71 | index_unlabeled.append(indexes_c[data_per_class:])
72 |
73 | args.index_labeled = torch.cat(index_labeled)
74 | args.index_unlabeled = torch.cat(index_unlabeled)
75 |
76 | print('-' * 80)
77 | print('selected labeled indexes: ', args.index_labeled)
78 |
79 | pseudo_trainset = PseudoDatasetFolder(
80 | trainset, labeled_indexes=args.index_labeled)
81 | # load pseudo labels
82 | if args.pseudo_dir is not None:
83 | pseudo_files = glob.glob(args.pseudo_dir + '/*')
84 | pseudo_num_per_chunk = int(
85 | num_unlabeled * args.pseudo_ratio / len(pseudo_files))
86 |
87 | pseudo_indexes = []
88 | pseudo_labels = []
89 | for pseudo_file in pseudo_files:
90 | pseudo_dict = torch.load(pseudo_file)
91 | pseudo_indexes.append(
92 | pseudo_dict['pseudo_indexes'][:pseudo_num_per_chunk])
93 | pseudo_labels.append(
94 | pseudo_dict['pseudo_labels'][:pseudo_num_per_chunk])
95 | assert (args.index_labeled == pseudo_dict['labeled_indexes']).all()
96 | pseudo_indexes = torch.cat(pseudo_indexes)
97 | pseudo_labels = torch.cat(pseudo_labels)
98 |
99 | assert num_labeled == args.index_labeled.shape[0]
100 |
101 | pseudo_trainset.set_pseudo(pseudo_indexes, pseudo_labels)
102 |
103 | print('num_pseudo = {}'.format(pseudo_indexes.shape[0]))
104 |
105 |     pseudo_trainloader = torch.utils.data.DataLoader(
106 |         pseudo_trainset, batch_size=args.batch_size,
107 |         shuffle=True, num_workers=args.num_workers)
108 | 
109 |     return testloader, pseudo_trainloader
110 |
111 |
112 | def build_model(args):
113 | print("=> creating model '{}'".format(args.architecture))
114 | net = models.__dict__[args.architecture]()
115 | net = net.to(args.device)
116 |
117 | print('#param: {}'.format(sum([p.nelement() for p in net.parameters()])))
118 |
119 | if args.device == 'cuda':
120 | net = torch.nn.DataParallel(
121 | net, device_ids=range(torch.cuda.device_count()))
122 | cudnn.benchmark = True
123 |
124 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
125 | weight_decay=0, nesterov=True)
126 |
127 | # resume from unsupervised pretrain
128 | if len(args.resume) > 0:
129 | print('==> Resuming from {}'.format(args.resume))
130 | global best_acc, global_step
131 | checkpoint = torch.load(args.resume)
132 | net.load_state_dict(checkpoint['net'])
133 | optimizer.load_state_dict(checkpoint['optimizer'])
134 | best_acc = checkpoint['best_acc']
135 | global_step = checkpoint['step'] + 1
136 | elif len(args.pretrained) > 0:
137 | # Load checkpoint.
138 | print('==> Load pretrained model: {}'.format(args.pretrained))
139 | checkpoint = torch.load(args.pretrained)
140 | model_dict = net.state_dict()
141 | # only load shared conv layers, don't load fc
142 | pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items()
143 | if k in model_dict
144 | and v.size() == model_dict[k].size()}
145 | assert len(pretrained_dict) > 0
146 | model_dict.update(pretrained_dict)
147 | net.load_state_dict(model_dict)
148 |
149 | return net, optimizer
150 |
151 |
152 | def get_lr_scheduler(optimizer, lr_scheduler, max_iters):
153 | if lr_scheduler == 'cosine':
154 | scheduler = CosineAnnealingLR(optimizer, max_iters, eta_min=0.00001)
155 | elif lr_scheduler == 'cosine-with-restart':
156 | scheduler = CosineAnnealingLRWithRestart(optimizer, eta_min=0.00001)
157 | elif lr_scheduler == 'multi-step':
158 | scheduler = MultiStepLR(
159 | optimizer, [max_iters * 3 // 7, max_iters * 6 // 7], gamma=0.1)
160 | else:
161 | raise ValueError("not supported")
162 |
163 | return scheduler
164 |
165 |
166 | def inf_generator(trainloader):
167 | while True:
168 | for data in trainloader:
169 | yield data
170 |
171 |
172 | # Training
173 | def train(net, optimizer, scheduler, trainloader, testloader, criterion, summary_writer, args):
174 | train_loss = AverageMeter()
175 | data_time = AverageMeter()
176 | batch_time = AverageMeter()
177 | top1 = AverageMeter()
178 | top5 = AverageMeter()
179 |
180 |     end = time.time()
181 | 
182 |     # use the module-level best_acc so a resumed run keeps its best checkpoint
183 |     global best_acc, global_step
184 | for inputs, targets in inf_generator(trainloader):
185 | if global_step >= args.max_iters:
186 | break
187 |
188 | data_time.update(time.time() - end)
189 |
190 | inputs, targets = inputs.to(args.device), targets.to(args.device)
191 |
192 | # switch to train mode
193 | net.train()
194 | scheduler.step(global_step)
195 | optimizer.zero_grad()
196 |
197 | outputs = net(inputs)
198 | loss = criterion(outputs, targets)
199 | prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
200 | top1.update(prec1[0], inputs.size(0))
201 | top5.update(prec5[0], inputs.size(0))
202 |
203 | loss.backward()
204 | optimizer.step()
205 |
206 | train_loss.update(loss.item(), inputs.size(0))
207 |
208 | # measure elapsed time
209 | batch_time.update(time.time() - end)
210 | end = time.time()
211 |
212 | summary_writer.add_scalar(
213 | 'lr', optimizer.param_groups[0]['lr'], global_step)
214 | summary_writer.add_scalar('top1', top1.val, global_step)
215 | summary_writer.add_scalar('top5', top5.val, global_step)
216 | summary_writer.add_scalar('batch_time', batch_time.val, global_step)
217 | summary_writer.add_scalar('data_time', data_time.val, global_step)
218 | summary_writer.add_scalar('train_loss', train_loss.val, global_step)
219 |
220 | if global_step % args.print_freq == 0:
221 | lr = optimizer.param_groups[0]['lr']
222 | print(f'Train: [{global_step}/{args.max_iters}] '
223 | f'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
224 | f'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
225 | f'Lr: {lr:.5f} '
226 | f'prec1: {top1.val:.3f} ({top1.avg:.3f}) '
227 | f'prec5: {top5.val:.3f} ({top5.avg:.3f}) '
228 | f'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})')
229 |
230 | if (global_step + 1) % args.eval_freq == 0 or global_step == args.max_iters - 1:
231 | acc = validate(testloader, net, criterion,
232 | device=args.device, print_freq=args.print_freq)
233 |
234 | summary_writer.add_scalar('val_top1', acc, global_step)
235 |
236 | if acc > best_acc:
237 | best_acc = acc
238 | state = {
239 | 'step': global_step,
240 | 'best_acc': best_acc,
241 | 'net': net.state_dict(),
242 | 'optimizer': optimizer.state_dict(),
243 | }
244 | os.makedirs(args.model_dir, exist_ok=True)
245 | torch.save(state, os.path.join(args.model_dir, 'ckpt.pth.tar'))
246 |
247 | print('best accuracy: {:.2f}\n'.format(best_acc))
248 | global_step += 1
249 |
250 |
251 | def main(args):
252 | # Data
253 | print('==> Preparing data..')
254 |     testloader, pseudo_trainloader = get_dataloader(args)
255 |
256 | print('==> Building model..')
257 | net, optimizer = build_model(args)
258 |
259 | criterion = nn.__dict__[args.criterion]().to(args.device)
260 | scheduler = get_lr_scheduler(optimizer, args.lr_scheduler, args.max_iters)
261 |
262 | if args.eval:
263 | return validate(testloader, net, criterion,
264 | device=args.device, print_freq=args.print_freq)
265 | # summary writer
266 | os.makedirs(args.log_dir, exist_ok=True)
267 | summary_writer = SummaryWriter(args.log_dir)
268 |
269 |     train(net, optimizer, scheduler, pseudo_trainloader,
270 |           testloader, criterion, summary_writer, args)
271 |
272 |
273 | if __name__ == '__main__':
274 |     parser = argparse.ArgumentParser(description='PyTorch ImageNet Training',
275 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
276 | parser.add_argument('--data-dir', '--dataDir', required=True,
277 | type=str, metavar='DIR', help='data dir')
278 | parser.add_argument('--model-root', default='./checkpoint/imagenet',
279 | type=str, metavar='DIR',
280 | help='root directory to save checkpoint')
281 | parser.add_argument('--log-root', default='./tensorboard/imagenet',
282 | type=str, metavar='DIR',
283 | help='root directory to save tensorboard logs')
284 | parser.add_argument('--exp-name', default='exp', type=str,
285 | help='experiment name, used to determine log_dir and model_dir')
286 | parser.add_argument('--lr', default=0.01, type=float,
287 | metavar='LR', help='learning rate')
288 | parser.add_argument('--lr-scheduler', default='multi-step', type=str,
289 | choices=['multi-step', 'cosine',
290 | 'cosine-with-restart'],
291 | help='which lr scheduler to use')
292 |     parser.add_argument('--pretrained', default='', type=str,
293 |                         metavar='FILE', help='pretrained checkpoint to load; only matching model parameters are loaded')
294 |     parser.add_argument('--resume', '-r', default='', type=str,
295 |                         metavar='FILE', help='checkpoint to resume from; the optimizer state is restored as well')
296 | parser.add_argument('--eval', action='store_true', help='test only')
297 |     parser.add_argument('--finetune', action='store_true',
298 |                         help='only train the last fc layer')
299 | parser.add_argument('-j', '--num-workers', default=32, type=int,
300 | metavar='N', help='number of workers to load data')
301 | parser.add_argument('-b', '--batch-size', default=256, type=int,
302 | metavar='N', help='batch size')
303 | parser.add_argument('--max-iters', default=50000, type=int,
304 | metavar='N', help='number of iterations')
305 | parser.add_argument('--num-labeled', default=13000, type=int,
306 | metavar='N', help='number of labeled data')
307 | parser.add_argument('--rng-seed', default=0, type=int,
308 | metavar='N', help='random number generator seed')
309 | parser.add_argument('--gpus', default='0,1,2,3', type=str, metavar='GPUS',
310 | help='ids of GPU to use')
311 |     parser.add_argument('--eval-freq', default=500, type=int,
312 |                         metavar='N', help='evaluation frequency in iterations')
313 |     parser.add_argument('--print-freq', default=10, type=int,
314 |                         metavar='N', help='print frequency in iterations')
315 | parser.add_argument('--criterion', default='CrossEntropyLoss', type=str,
316 | choices=['CrossEntropyLoss', 'MultiMarginLoss'], help='Criterion to use')
317 |     parser.add_argument('--pseudo-dir', type=str,
318 |                         metavar='PATH', help='directory of pseudo-label files to load')
319 | parser.add_argument('--pseudo-ratio', default=0.1, type=float, metavar='0-1',
320 | help='ratio of unlabeled data to use for pseudo labels')
321 | parser.add_argument('--architecture', '--arch', default='resnet18', type=str,
322 | help='which backbone to use')
323 | args_, rest = parser.parse_known_args()
324 | print(rest)
325 |
326 | os.environ["CUDA_VISIBLE_DEVICES"] = args_.gpus
327 | args_.device = 'cuda' if torch.cuda.is_available() else 'cpu'
328 | args_.num_class = 1000
329 | args_.log_dir = os.path.join(args_.log_root, args_.exp_name)
330 | args_.model_dir = os.path.join(args_.model_root, args_.exp_name)
331 |
332 | torch.manual_seed(args_.rng_seed)
333 | torch.cuda.manual_seed(args_.rng_seed)
334 | random.seed(args_.rng_seed)
335 | torch.set_printoptions(threshold=50, precision=4)
336 |
337 | print('-' * 80)
338 | pprint(vars(args_))
339 |
340 | main(args_)
341 |
342 | print('-' * 80)
343 | pprint(vars(args_))
344 |
--------------------------------------------------------------------------------
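Note on get_dataloader in /imagenet-semi.py above: each file in --pseudo-dir is a chunk whose entries are already sorted by descending pseudo weight (that is how the ranking notebooks save them), so slicing the first pseudo_num_per_chunk rows of every chunk keeps the most confident pseudo labels overall. A toy sketch of that selection with fabricated tensors; the chunk contents mirror the save_dict keys used in the notebook, but the sizes and ratio here are illustrative only.

import torch

torch.manual_seed(0)

# fabricate two chunks the way the ranking step saves them: rows sorted by
# descending pseudo weight, so position encodes confidence
chunks = []
for _ in range(2):
    weights = torch.rand(100).sort(descending=True)[0]
    chunks.append({
        'pseudo_indexes': torch.randperm(1000)[:100],
        'pseudo_labels': torch.randint(10, (100,)),
        'pseudo_weights': weights,
    })

num_unlabeled, pseudo_ratio = 200, 0.2
per_chunk = int(num_unlabeled * pseudo_ratio / len(chunks))   # 20 per chunk

pseudo_indexes = torch.cat([c['pseudo_indexes'][:per_chunk] for c in chunks])
pseudo_labels = torch.cat([c['pseudo_labels'][:per_chunk] for c in chunks])
print(pseudo_indexes.shape[0], 'pseudo-labeled samples selected')  # 40

The selected indexes and labels are then handed to pseudo_trainset.set_pseudo, which mixes them with the truly labeled samples in a single DataLoader.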