├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── ThirdPartyNotices.txt
├── cifar.py
├── datasets
│   ├── __init__.py
│   ├── cifar.py
│   └── folder.py
├── lib
│   ├── LinearAverage.py
│   ├── NCA.py
│   ├── __init__.py
│   ├── normalize.py
│   └── utils.py
├── main.py
├── models
│   ├── __init__.py
│   ├── resnet.py
│   └── resnet_cifar.py
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | checkpoint/*
3 | logs/*
4 | others/*
5 |
6 | *.pyc
7 | *.bak
8 | *.log
9 | *.tar
10 | *.pth
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
117 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Improving Generalization via Scalable Neighborhood Component Analysis
2 |
3 | This repo contains the PyTorch implementation for the ECCV 2018 paper [(paper)](https://arxiv.org/pdf/1808.04699.pdf).
4 | We use deep networks to learn feature representations optimized for nearest neighbor classifiers, which generalize better to new object categories.
5 | This project is a re-investigation of [Neighborhood Component Analysis (NCA)](http://www.cs.toronto.edu/~fritz/absps/nca.pdf)
6 | with recent technologies to make it scalable to deep networks and large-scale datasets.
7 |
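In brief, every image is embedded into an L2-normalized feature vector, and each training sample selects its neighbors through a temperature-scaled softmax over feature similarities. As a sketch of the objective implemented in `lib/NCA.py` and `lib/LinearAverage.py` (with temperature `T` and batch size `B`):

```
p_ij = exp(f_i . f_j / T) / sum_{k != i} exp(f_i . f_k / T)
loss = -(1/B) * sum_i log( sum_{j in C_i, j != i} p_ij )
```

where `C_i` is the set of training samples sharing the label of sample `i`.
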
8 | Much of the code is extended from the previous [unsupervised learning project](https://arxiv.org/pdf/1805.01978.pdf).
9 | Please refer to [this repo](https://github.com/zhirongw/lemniscate.pytorch) for more details.
10 |
11 |
12 |
13 | ## Pretrained Models
14 |
15 | Currently, we provide three pretrained ResNet models.
16 | Each release contains the feature representations of all ImageNet training images (600 MB) and the model weights (100-200 MB).
17 | Models and their performance with nearest neighbor classifiers are as follows.
18 |
19 | - [ResNet 18](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet18.pth.tar) (top 1 accuracy 70.59%)
20 | - [ResNet 34](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet34.pth.tar) (top 1 accuracy 74.41%)
21 | - [ResNet 50](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet50.pth.tar) (top 1 accuracy 76.57%)
22 |
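The released files appear to follow the checkpoint format written by `save_checkpoint` in `main.py`; a minimal loading sketch under that assumption (the filename is a placeholder):

```python
import torch

# Load a released checkpoint onto the CPU (path is a placeholder).
checkpoint = torch.load('snca_resnet18.pth.tar',
                        map_location=lambda storage, loc: storage)

# Keys written by save_checkpoint() in main.py:
# 'epoch', 'arch', 'state_dict', 'lemniscate', 'best_prec1', 'optimizer'
print(checkpoint['arch'], checkpoint['best_prec1'])

# 'lemniscate' holds the LinearAverage module whose memory buffer stores
# the embeddings of all training images (the 600 MB part of the release).
# The weights were saved from a DataParallel-wrapped model, so parameter
# names carry a 'module.' prefix.
state_dict = checkpoint['state_dict']
```
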
23 | Code to reproduce the rest of the experiments is coming soon.
24 |
25 | ## Nearest Neighbors
26 |
27 | Please follow [this link](http://zhirongw.westus2.cloudapp.azure.com/nn.html) for a list of nearest neighbors on ImageNet.
28 | Results are visualized from our ResNet50 features, compared with the baseline ResNet50 features, raw image features, and previous unsupervised features.
29 | The first column is the query image, followed by 20 retrievals ranked by similarity.
30 |
31 |
32 |
33 | ## Usage
34 |
35 | Our code extends the PyTorch implementation of ImageNet classification in the [official PyTorch examples](https://github.com/pytorch/examples/tree/master/imagenet).
36 | Please refer to the official repo for details of data preparation and hardware configurations.
37 |
38 | - Install Python 2 and [PyTorch >= 0.4](http://pytorch.org)
39 |
40 | - Clone this repo: `git clone https://github.com/Microsoft/snca.pytorch`
41 |
42 | - Training on ImageNet:
43 |
44 | `python main.py DATAPATH --arch resnet18 -j 32 --temperature 0.05 --low-dim 128 -b 256`
45 |
46 | - During training, we monitor the supervised validation accuracy with a nearest neighbor classifier (K=1), as it is faster and gives a good estimate of the feature quality.
47 |
48 | - Testing on ImageNet:
49 |
50 | `python main.py DATAPATH --arch resnet18 --resume input_model.pth.tar -e` runs testing with default K=30 neighbors.
51 |
52 | - Memory Consumption and Computation Issues
53 |
54 | Memory consumption is more of an issue than computation time.
55 | Currently, the implementation of the NCA module is not parallelized across multiple GPUs.
56 | Hence, the first GPU will consume much more memory than the others.
57 | For example, when training a ResNet18 network, GPU 0 consumes 11GB of memory, while the others each take 2.5GB.
58 | You will need to use the Caffe-style options `-b 128 --iter-size 2` (gradient accumulation) to train deeper networks; see the example command below.
59 | Our released models are trained on V100 machines.
60 |
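For example, a deeper-network run with gradient accumulation might look like the following (a sketch; all flags are defined in `main.py`):

`python main.py DATAPATH --arch resnet50 -j 32 --temperature 0.05 --low-dim 128 -b 128 --iter-size 2`
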
61 | - Training on CIFAR10:
62 |
63 | `python cifar.py --temperature 0.05 --lr 0.1`
64 |
65 |
66 | ## Citation
67 | ```
68 | @inproceedings{wu2018improving,
69 | title={Improving Generalization via Scalable Neighborhood Component Analysis},
70 | author={Wu, Zhirong and Efros, Alexei A and Yu, Stella},
71 | booktitle={European Conference on Computer Vision (ECCV) 2018},
72 | year={2018}
73 | }
74 | ```
75 |
76 | ## Contact
77 |
78 | For any questions, please feel free to reach out:
79 | ```
80 | Zhirong Wu: xavibrowu@gmail.com
81 | ```
82 |
83 | ## Contributing
84 |
85 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
86 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
87 | the rights to use your contribution. For details, visit https://cla.microsoft.com.
88 |
89 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
90 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
91 | provided by the bot. You will only need to do this once across all repos using our CLA.
92 |
93 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
94 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
95 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
96 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/ThirdPartyNotices.txt:
--------------------------------------------------------------------------------
1 | ************************************************************************
2 |
3 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
4 |
5 | This project incorporates components from the projects listed below.
6 | The original copyright notices and the licenses under which Microsoft received such components are set forth below.
7 | Microsoft reserves all rights not expressly granted herein, whether by implication, estoppel or otherwise.
8 |
9 | 1. Pytorch (https://github.com/pytorch/pytorch)
10 | 2. lemniscate (https://github.com/zhirongw/lemniscate.pytorch)
11 |
12 | From PyTorch:
13 |
14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
19 | Copyright (c) 2011-2013 NYU (Clement Farabet)
20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
23 |
24 | From Caffe2:
25 |
26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
27 |
28 | All contributions by Facebook:
29 | Copyright (c) 2016 Facebook Inc.
30 |
31 | All contributions by Google:
32 | Copyright (c) 2015 Google Inc.
33 | All rights reserved.
34 |
35 | All contributions by Yangqing Jia:
36 | Copyright (c) 2015 Yangqing Jia
37 | All rights reserved.
38 |
39 | All contributions from Caffe:
40 | Copyright(c) 2013, 2014, 2015, the respective contributors
41 | All rights reserved.
42 |
43 | All other contributions:
44 | Copyright(c) 2015, 2016 the respective contributors
45 | All rights reserved.
46 |
47 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
48 | copyright over their contributions to Caffe2. The project versioning records
49 | all such contribution and copyright details. If a contributor wants to further
50 | mark their specific copyright on a particular contribution, they should
51 | indicate their copyright solely in the commit message of the change when it is
52 | committed.
53 |
54 | All rights reserved.
55 |
56 | Redistribution and use in source and binary forms, with or without
57 | modification, are permitted provided that the following conditions are met:
58 |
59 | 1. Redistributions of source code must retain the above copyright
60 | notice, this list of conditions and the following disclaimer.
61 |
62 | 2. Redistributions in binary form must reproduce the above copyright
63 | notice, this list of conditions and the following disclaimer in the
64 | documentation and/or other materials provided with the distribution.
65 |
66 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
67 | and IDIAP Research Institute nor the names of its contributors may be
68 | used to endorse or promote products derived from this software without
69 | specific prior written permission.
70 |
71 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
72 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
73 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
74 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
75 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
76 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
77 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
78 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
79 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
80 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
81 | POSSIBILITY OF SUCH DAMAGE.
82 |
--------------------------------------------------------------------------------
/cifar.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10 with PyTorch.'''
2 | from __future__ import print_function
3 |
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | import torch.backends.cudnn as cudnn
10 |
11 | import torchvision
12 | import torchvision.transforms as transforms
13 |
14 | import os
15 | import argparse
16 | import time
17 |
18 | import models
19 | import datasets
20 | import math
21 |
22 | from lib.LinearAverage import LinearAverage
23 | from lib.NCA import NCACrossEntropy
24 | from lib.utils import AverageMeter
25 | from test import NN, kNN
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
28 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
29 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint')
30 | parser.add_argument('--test-only', action='store_true', help='test only')
31 | parser.add_argument('--low-dim', default=128, type=int,
32 | metavar='D', help='feature dimension')
33 | parser.add_argument('--temperature', default=0.05, type=float,
34 | metavar='T', help='temperature parameter for softmax')
35 | parser.add_argument('--memory-momentum', default=0.5, type=float,
36 | metavar='M', help='momentum for non-parametric updates')
37 |
38 | args = parser.parse_args()
39 |
40 | use_cuda = torch.cuda.is_available()
41 | best_acc = 0 # best test accuracy
42 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
43 |
44 | # Data
45 | print('==> Preparing data..')
46 | transform_train = transforms.Compose([
47 | #transforms.RandomCrop(32, padding=4),
48 | transforms.RandomResizedCrop(size=32, scale=(0.2,1.)),
49 | transforms.RandomGrayscale(p=0.2),
50 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
54 | ])
55 |
56 | transform_test = transforms.Compose([
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
59 | ])
60 |
61 | trainset = datasets.CIFAR10Instance(root='./data', train=True, download=True, transform=transform_train)
62 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
63 |
64 | testset = datasets.CIFAR10Instance(root='./data', train=False, download=True, transform=transform_test)
65 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
66 |
67 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
68 | ndata = trainset.__len__()
69 |
70 | # Model
71 | if args.test_only or len(args.resume)>0:
72 | # Load checkpoint.
73 | print('==> Resuming from checkpoint..')
74 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
75 | checkpoint = torch.load('./checkpoint/'+args.resume)
76 | net = checkpoint['net']
77 | lemniscate = checkpoint['lemniscate']
78 | best_acc = checkpoint['acc']
79 | start_epoch = checkpoint['epoch']
80 | else:
81 | print('==> Building model..')
82 | net = models.__dict__['ResNet18'](low_dim=args.low_dim)
83 |     # define lemniscate
84 | lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)
85 |
86 | # define loss function
87 | criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.train_labels))
88 |
89 | if use_cuda:
90 | net.cuda()
91 | net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
92 | lemniscate.cuda()
93 | criterion.cuda()
94 | cudnn.benchmark = True
95 |
96 | if args.test_only:
97 | acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
98 | sys.exit(0)
99 |
100 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
101 |
102 | def adjust_learning_rate(optimizer, epoch):
103 |     """Sets the learning rate to the initial LR decayed by 10 every 50 epochs"""
104 | lr = args.lr * (0.1 ** (epoch // 50))
105 | print(lr)
106 | for param_group in optimizer.param_groups:
107 | param_group['lr'] = lr
108 |
109 | # Training
110 | def train(epoch):
111 | print('\nEpoch: %d' % epoch)
112 | adjust_learning_rate(optimizer, epoch)
113 | train_loss = AverageMeter()
114 | data_time = AverageMeter()
115 | batch_time = AverageMeter()
116 | correct = 0
117 | total = 0
118 |
119 | # switch to train mode
120 | net.train()
121 |
122 | end = time.time()
123 | for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
124 | data_time.update(time.time() - end)
125 | if use_cuda:
126 | inputs, targets, indexes = inputs.cuda(), targets.cuda(), indexes.cuda()
127 | optimizer.zero_grad()
128 |
129 | features = net(inputs)
130 | outputs = lemniscate(features, indexes)
131 | loss = criterion(outputs, indexes)
132 |
133 | loss.backward()
134 | optimizer.step()
135 |
136 | train_loss.update(loss.item(), inputs.size(0))
137 |
138 | # measure elapsed time
139 | batch_time.update(time.time() - end)
140 | end = time.time()
141 |
142 |         print('Epoch: [{}][{}/{}] '
143 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
144 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
145 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
146 | epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss))
147 |
148 | for epoch in range(start_epoch, start_epoch+200):
149 | train(epoch)
150 | acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature)
151 |
152 | if acc > best_acc:
153 | print('Saving..')
154 | state = {
155 | 'net': net.module if use_cuda else net,
156 | 'lemniscate': lemniscate,
157 | 'acc': acc,
158 | 'epoch': epoch,
159 | }
160 | if not os.path.isdir('checkpoint'):
161 | os.mkdir('checkpoint')
162 | torch.save(state, './checkpoint/ckpt.t7')
163 | best_acc = acc
164 |
165 | print('best accuracy: {:.2f}'.format(best_acc*100))
166 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .folder import ImageFolderInstance
2 | from .cifar import CIFAR10Instance, CIFAR100Instance
3 |
4 | __all__ = ('ImageFolderInstance', 'CIFAR10Instance', 'CIFAR100Instance')
5 |
6 |
--------------------------------------------------------------------------------
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import torchvision.datasets as datasets
4 | import torch.utils.data as data
5 |
6 | class CIFAR10Instance(datasets.CIFAR10):
7 | """CIFAR10Instance Dataset.
8 | """
9 | def __getitem__(self, index):
10 | if self.train:
11 | img, target = self.train_data[index], self.train_labels[index]
12 | else:
13 | img, target = self.test_data[index], self.test_labels[index]
14 |
15 | # doing this so that it is consistent with all other datasets
16 | # to return a PIL Image
17 | img = Image.fromarray(img)
18 |
19 | if self.transform is not None:
20 | img = self.transform(img)
21 |
22 | if self.target_transform is not None:
23 | target = self.target_transform(target)
24 |
25 | return img, target, index
26 |
27 | class CIFAR100Instance(CIFAR10Instance):
28 | """CIFAR100Instance Dataset.
29 |
30 | This is a subclass of the `CIFAR10Instance` Dataset.
31 | """
32 | base_folder = 'cifar-100-python'
33 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
34 | filename = "cifar-100-python.tar.gz"
35 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
36 | train_list = [
37 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
38 | ]
39 |
40 | test_list = [
41 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
42 | ]
43 |
--------------------------------------------------------------------------------
/datasets/folder.py:
--------------------------------------------------------------------------------
1 | import torchvision.datasets as datasets
2 |
3 | class ImageFolderInstance(datasets.ImageFolder):
4 |     """Folder dataset which also returns the index of the image.
5 |     """
6 | def __getitem__(self, index):
7 | """
8 | Args:
9 | index (int): Index
10 | Returns:
11 |             tuple: (image, target, index) where target is the class index of the target class.
12 | """
13 | path, target = self.imgs[index]
14 | img = self.loader(path)
15 | if self.transform is not None:
16 | img = self.transform(img)
17 | if self.target_transform is not None:
18 | target = self.target_transform(target)
19 |
20 | return img, target, index
21 |
22 |
--------------------------------------------------------------------------------
/lib/LinearAverage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch import nn
4 | import math
5 |
6 | class LinearAverageOp(Function):
7 | @staticmethod
8 | def forward(self, x, y, memory, params):
9 | T = params[0].item()
10 | batchSize = x.size(0)
11 |
12 | # inner product
13 | out = torch.mm(x.data, memory.t())
14 | out.div_(T) # batchSize * N
15 |
16 | self.save_for_backward(x, memory, y, params)
17 |
18 | return out
19 |
20 | @staticmethod
21 | def backward(self, gradOutput):
22 | x, memory, y, params = self.saved_tensors
23 | batchSize = gradOutput.size(0)
24 | T = params[0].item()
25 | momentum = params[1].item()
26 |
27 | # add temperature
28 | gradOutput.data.div_(T)
29 |
30 | # gradient of linear
31 | gradInput = torch.mm(gradOutput.data, memory)
32 | gradInput.resize_as_(x)
33 |
34 | # update the non-parametric data
35 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
36 | weight_pos.mul_(momentum)
37 | weight_pos.add_(torch.mul(x.data, 1-momentum))
38 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
39 | updated_weight = weight_pos.div(w_norm)
40 | memory.index_copy_(0, y, updated_weight)
41 |
42 | return gradInput, None, None, None
43 |
44 | class LinearAverage(nn.Module):
45 |
46 | def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
47 | super(LinearAverage, self).__init__()
48 |         # memory bank: one embedding per training sample, updated non-parametrically
49 | self.nLem = outputSize
50 |
51 |         self.register_buffer('params', torch.tensor([T, momentum]))
52 | stdv = 1. / math.sqrt(inputSize/3)
53 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))
54 |
55 | def forward(self, x, y):
56 | out = LinearAverageOp.apply(x, y, self.memory, self.params)
57 | return out
58 |
59 |
--------------------------------------------------------------------------------
/lib/NCA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Function
4 | import math
5 |
6 | eps = 1e-8
7 |
8 | class NCACrossEntropy(nn.Module):
9 |     ''' NCA cross entropy loss: L = -(1/B) \sum_i \log \sum_{j \in C_i} p_{ij}.
10 |     Stores all the labels of the dataset;
11 |     only the indexes of the training instances are passed during forward.
12 |     '''
13 | def __init__(self, labels, margin=0):
14 | super(NCACrossEntropy, self).__init__()
15 | self.register_buffer('labels', torch.LongTensor(labels.size(0)))
16 | self.labels = labels
17 | self.margin = margin
18 |
19 | def forward(self, x, indexes):
20 | batchSize = x.size(0)
21 | n = x.size(1)
22 | exp = torch.exp(x)
23 |
24 |         # labels for current batch
25 | y = torch.index_select(self.labels, 0, indexes.data).view(batchSize, 1)
26 | same = y.repeat(1, n).eq_(self.labels)
27 |
28 |         # self prob exclusion, hack with memory for efficiency
29 | exp.data.scatter_(1, indexes.data.view(-1,1), 0)
30 |
31 | p = torch.mul(exp, same.float()).sum(dim=1)
32 | Z = exp.sum(dim=1)
33 |
34 | Z_exclude = Z - p
35 | p = p.div(math.exp(self.margin))
36 | Z = Z_exclude + p
37 |
38 | prob = torch.div(p, Z)
39 | prob_masked = torch.masked_select(prob, prob.ne(0))
40 |
41 | loss = prob_masked.log().sum(0)
42 |
43 | return - loss / batchSize
44 |
45 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # nothing
2 |
--------------------------------------------------------------------------------
/lib/normalize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch import nn
4 |
5 | class Normalize(nn.Module):
6 |
7 | def __init__(self, power=2):
8 | super(Normalize, self).__init__()
9 | self.power = power
10 |
11 | def forward(self, x):
12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
13 | out = x.div(norm)
14 | return out
15 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 | def __init__(self):
4 | self.reset()
5 |
6 | def reset(self):
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 |
12 | def update(self, val, n=1):
13 | self.val = val
14 | self.sum += val * n
15 | self.count += n
16 | self.avg = self.sum / self.count
17 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | import torchvision.transforms as transforms
16 |
17 | import datasets
18 | import models
19 | import math
20 |
21 | from lib.LinearAverage import LinearAverage
22 | from lib.NCA import NCACrossEntropy
23 | from lib.utils import AverageMeter
24 | from test import NN, kNN
25 |
26 | model_names = sorted(name for name in models.__dict__
27 | if name.islower() and not name.startswith("__")
28 | and callable(models.__dict__[name]))
29 |
30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
31 | parser.add_argument('data', metavar='DIR',
32 | help='path to dataset')
33 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 | ' (default: resnet18)')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=130, type=int, metavar='N',
41 | help='number of total epochs to run')
42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
43 | help='manual epoch number (useful on restarts)')
44 | parser.add_argument('-b', '--batch-size', default=256, type=int,
45 | metavar='N', help='mini-batch size (default: 256)')
46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
47 | metavar='LR', help='initial learning rate')
48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
49 | help='momentum')
50 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
51 | metavar='W', help='weight decay (default: 1e-4)')
52 | parser.add_argument('--print-freq', '-p', default=10, type=int,
53 | metavar='N', help='print frequency (default: 10)')
54 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
55 | help='path to latest checkpoint (default: none)')
56 | parser.add_argument('--test-only', action='store_true', help='test only')
57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
58 | help='evaluate model on validation set')
59 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
60 | help='use pre-trained model')
61 | parser.add_argument('--world-size', default=1, type=int,
62 | help='number of distributed processes')
63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
64 | help='url used to set up distributed training')
65 | parser.add_argument('--dist-backend', default='gloo', type=str,
66 | help='distributed backend')
67 | parser.add_argument('--low-dim', default=128, type=int,
68 | metavar='D', help='feature dimension')
69 | parser.add_argument('--temperature', default=0.05, type=float,
70 | metavar='T', help='temperature parameter')
71 | parser.add_argument('--memory-momentum', '--m-momentum', default=0.5, type=float,
72 | metavar='M', help='momentum for non-parametric updates')
73 | parser.add_argument('--iter-size', default=1, type=int,
74 | help='caffe style iter size')
75 | parser.add_argument('--margin', default=0.0, type=float,
76 | help='classification margin')
77 |
78 | best_prec1 = 0
79 |
80 | def main():
81 | global args, best_prec1
82 | args = parser.parse_args()
83 |
84 | args.distributed = args.world_size > 1
85 |
86 | if args.distributed:
87 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
88 | world_size=args.world_size)
89 |
90 | # Data loading code
91 | traindir = os.path.join(args.data, 'train')
92 | valdir = os.path.join(args.data, 'val')
93 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
94 | std=[0.229, 0.224, 0.225])
95 |
96 | train_dataset = datasets.ImageFolderInstance(
97 | traindir,
98 | transforms.Compose([
99 | transforms.RandomResizedCrop(224),
100 | transforms.RandomHorizontalFlip(),
101 | transforms.ToTensor(),
102 | normalize,
103 | ]))
104 |
105 | if args.distributed:
106 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
107 | else:
108 | train_sampler = None
109 |
110 | train_loader = torch.utils.data.DataLoader(
111 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
112 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
113 |
114 | val_loader = torch.utils.data.DataLoader(
115 | datasets.ImageFolderInstance(valdir, transforms.Compose([
116 | transforms.Resize(256),
117 | transforms.CenterCrop(224),
118 | transforms.ToTensor(),
119 | normalize,
120 | ])),
121 | batch_size=args.batch_size, shuffle=False,
122 | num_workers=args.workers, pin_memory=True)
123 |
124 | # create model
125 | if args.pretrained:
126 | print("=> using pre-trained model '{}'".format(args.arch))
127 | model = models.__dict__[args.arch](pretrained=True)
128 | else:
129 | print("=> creating model '{}'".format(args.arch))
130 | model = models.__dict__[args.arch](low_dim=args.low_dim)
131 |
132 | if not args.distributed:
133 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
134 | model.features = torch.nn.DataParallel(model.features)
135 | model.cuda()
136 | else:
137 | model = torch.nn.DataParallel(model).cuda()
138 | else:
139 | model.cuda()
140 | model = torch.nn.parallel.DistributedDataParallel(model)
141 |
142 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
143 | momentum=args.momentum,
144 | weight_decay=args.weight_decay, nesterov=True)
145 |
146 | # optionally resume from a checkpoint
147 | if args.resume:
148 | if os.path.isfile(args.resume):
149 | print("=> loading checkpoint '{}'".format(args.resume))
150 | checkpoint = torch.load(args.resume)
151 | args.start_epoch = checkpoint['epoch']
152 | best_prec1 = checkpoint['best_prec1']
153 | model.load_state_dict(checkpoint['state_dict'])
154 | lemniscate = checkpoint['lemniscate']
155 | optimizer.load_state_dict(checkpoint['optimizer'])
156 | print("=> loaded checkpoint '{}' (epoch {})"
157 | .format(args.resume, checkpoint['epoch']))
158 | else:
159 | print("=> no checkpoint found at '{}'".format(args.resume))
160 | else:
161 | # define lemniscate and loss function (criterion)
162 | ndata = train_dataset.__len__()
163 | lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda()
164 |
165 |
166 | criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]),
167 | args.margin / args.temperature).cuda()
168 | cudnn.benchmark = True
169 |
170 | if args.evaluate:
171 | prec1 = kNN(0, model, lemniscate, train_loader, val_loader, 30, args.temperature, 0)
172 | return
173 |
174 | for epoch in range(args.start_epoch, args.epochs):
175 | if args.distributed:
176 | train_sampler.set_epoch(epoch)
177 | adjust_learning_rate(optimizer, epoch)
178 | adjust_memory_update_rate(lemniscate, epoch)
179 |
180 | # train for one epoch
181 | train(train_loader, model, lemniscate, criterion, optimizer, epoch)
182 |
183 | # evaluate on validation set
184 | prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)
185 |
186 | # remember best prec@1 and save checkpoint
187 | is_best = prec1 > best_prec1
188 | best_prec1 = max(prec1, best_prec1)
189 | save_checkpoint({
190 | 'epoch': epoch + 1,
191 | 'arch': args.arch,
192 | 'state_dict': model.state_dict(),
193 | 'lemniscate': lemniscate,
194 | 'best_prec1': best_prec1,
195 | 'optimizer' : optimizer.state_dict(),
196 | }, is_best)
197 |
198 |
199 | def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
200 | batch_time = AverageMeter()
201 | data_time = AverageMeter()
202 | losses = AverageMeter()
203 |
204 | # switch to train mode
205 | model.train()
206 |
207 | end = time.time()
208 | optimizer.zero_grad()
209 | for i, (input, target, index) in enumerate(train_loader):
210 | # measure data loading time
211 | data_time.update(time.time() - end)
212 |
213 | target = target.cuda(async=True)
214 | index = index.cuda(async=True)
215 | input_var = torch.autograd.Variable(input)
216 | target_var = torch.autograd.Variable(target)
217 | index_var = torch.autograd.Variable(index)
218 |
219 | # compute output
220 | feature = model(input_var)
221 | output = lemniscate(feature, index_var)
222 | loss = criterion(output, index_var) / args.iter_size
223 |
224 | loss.backward()
225 | # measure accuracy and record loss
226 |         losses.update(loss.item() * args.iter_size, input.size(0))
227 |
228 | if (i+1) % args.iter_size == 0:
229 | # compute gradient and do SGD step
230 | optimizer.step()
231 | optimizer.zero_grad()
232 |
233 | # measure elapsed time
234 | batch_time.update(time.time() - end)
235 | end = time.time()
236 |
237 | if i % args.print_freq == 0:
238 | print('Epoch: [{0}][{1}/{2}]\t'
239 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
240 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
241 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
242 | epoch, i, len(train_loader), batch_time=batch_time,
243 | data_time=data_time, loss=losses))
244 |
245 |
246 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
247 | torch.save(state, filename)
248 | if is_best:
249 | shutil.copyfile(filename, 'model_best.pth.tar')
250 |
251 | def adjust_memory_update_rate(lemniscate, epoch):
252 | if epoch >= 80:
253 | lemniscate.params[1] = 0.8
254 | if epoch >= 120:
255 | lemniscate.params[1] = 0.9
256 |
257 | def adjust_learning_rate(optimizer, epoch):
258 | """Sets the learning rate to the initial LR decayed by 10 every 40 epochs"""
259 | lr = args.lr * (0.1 ** (epoch // 40))
260 | print(lr)
261 | for param_group in optimizer.param_groups:
262 | param_group['lr'] = lr
263 |
264 | def accuracy(output, target, topk=(1,)):
265 | """Computes the precision@k for the specified values of k"""
266 | maxk = max(topk)
267 | batch_size = target.size(0)
268 |
269 | _, pred = output.topk(maxk, 1, True, True)
270 | pred = pred.t()
271 | correct = pred.eq(target.view(1, -1).expand_as(pred))
272 |
273 | res = []
274 | for k in topk:
275 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
276 | res.append(correct_k.mul_(100.0 / batch_size))
277 | return res
278 |
279 |
280 | if __name__ == '__main__':
281 | main()
282 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .resnet_cifar import *
3 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from lib.normalize import Normalize
5 |
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152']
9 |
10 | model_urls = { }
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | "3x3 convolution with padding"
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15 | padding=1, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(inplanes, planes, stride)
24 | self.bn1 = nn.BatchNorm2d(planes)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.conv2 = conv3x3(planes, planes)
27 | self.bn2 = nn.BatchNorm2d(planes)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | residual = x
33 |
34 | out = self.conv1(x)
35 | out = self.bn1(out)
36 | out = self.relu(out)
37 |
38 | out = self.conv2(out)
39 | out = self.bn2(out)
40 |
41 | if self.downsample is not None:
42 | residual = self.downsample(x)
43 |
44 | out += residual
45 | out = self.relu(out)
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
56 | self.bn1 = nn.BatchNorm2d(planes)
57 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
58 | padding=1, bias=False)
59 | self.bn2 = nn.BatchNorm2d(planes)
60 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
61 | self.bn3 = nn.BatchNorm2d(planes * 4)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv3(out)
78 | out = self.bn3(out)
79 |
80 | if self.downsample is not None:
81 | residual = self.downsample(x)
82 |
83 | out += residual
84 | out = self.relu(out)
85 |
86 | return out
87 |
88 |
89 | class ResNet(nn.Module):
90 |
91 | def __init__(self, block, layers, low_dim=128):
92 | self.inplanes = 64
93 | super(ResNet, self).__init__()
94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
95 | bias=False)
96 | self.bn1 = nn.BatchNorm2d(64)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
99 | self.layer1 = self._make_layer(block, 64, layers[0])
100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
103 | self.avgpool = nn.AvgPool2d(7, stride=1)
104 | self.fc = nn.Linear(512 * block.expansion, low_dim)
105 | self.l2norm = Normalize(2)
106 |
107 | for m in self.modules():
108 | if isinstance(m, nn.Conv2d):
109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
110 | m.weight.data.normal_(0, math.sqrt(2. / n))
111 | elif isinstance(m, nn.BatchNorm2d):
112 | m.weight.data.fill_(1)
113 | m.bias.data.zero_()
114 |
115 | def _make_layer(self, block, planes, blocks, stride=1):
116 | downsample = None
117 | if stride != 1 or self.inplanes != planes * block.expansion:
118 | downsample = nn.Sequential(
119 | nn.Conv2d(self.inplanes, planes * block.expansion,
120 | kernel_size=1, stride=stride, bias=False),
121 | nn.BatchNorm2d(planes * block.expansion),
122 | )
123 |
124 | layers = []
125 | layers.append(block(self.inplanes, planes, stride, downsample))
126 | self.inplanes = planes * block.expansion
127 | for i in range(1, blocks):
128 | layers.append(block(self.inplanes, planes))
129 |
130 | return nn.Sequential(*layers)
131 |
132 | def forward(self, x):
133 | x = self.conv1(x)
134 | x = self.bn1(x)
135 | x = self.relu(x)
136 | x = self.maxpool(x)
137 |
138 | x = self.layer1(x)
139 | x = self.layer2(x)
140 | x = self.layer3(x)
141 | x = self.layer4(x)
142 |
143 | x = self.avgpool(x)
144 | x = x.view(x.size(0), -1)
145 | x = self.fc(x)
146 | x = self.l2norm(x)
147 |
148 | return x
149 |
150 |
151 | def resnet18(pretrained=False, **kwargs):
152 | """Constructs a ResNet-18 model.
153 |
154 | Args:
155 | pretrained (bool): If True, returns a model pre-trained on ImageNet
156 | """
157 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
158 | if pretrained:
159 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
160 | return model
161 |
162 |
163 | def resnet34(pretrained=False, **kwargs):
164 | """Constructs a ResNet-34 model.
165 |
166 | Args:
167 | pretrained (bool): If True, returns a model pre-trained on ImageNet
168 | """
169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170 | if pretrained:
171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
172 | return model
173 |
174 |
175 | def resnet50(pretrained=False, **kwargs):
176 | """Constructs a ResNet-50 model.
177 |
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | """
181 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
182 | if pretrained:
183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
184 | return model
185 |
186 |
187 | def resnet101(pretrained=False, **kwargs):
188 | """Constructs a ResNet-101 model.
189 |
190 | Args:
191 | pretrained (bool): If True, returns a model pre-trained on ImageNet
192 | """
193 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
194 | if pretrained:
195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
196 | return model
197 |
198 |
199 | def resnet152(pretrained=False, **kwargs):
200 | """Constructs a ResNet-152 model.
201 |
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | """
205 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
206 | if pretrained:
207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
208 | return model
209 |
--------------------------------------------------------------------------------
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from lib.normalize import Normalize
13 |
14 | from torch.autograd import Variable
15 |
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, in_planes, planes, stride=1):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 |
27 | self.shortcut = nn.Sequential()
28 | if stride != 1 or in_planes != self.expansion*planes:
29 | self.shortcut = nn.Sequential(
30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
31 | nn.BatchNorm2d(self.expansion*planes)
32 | )
33 |
34 | def forward(self, x):
35 | out = F.relu(self.bn1(self.conv1(x)))
36 | out = self.bn2(self.conv2(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1):
46 | super(Bottleneck, self).__init__()
47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(planes)
49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
50 | self.bn2 = nn.BatchNorm2d(planes)
51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
53 |
54 | self.shortcut = nn.Sequential()
55 | if stride != 1 or in_planes != self.expansion*planes:
56 | self.shortcut = nn.Sequential(
57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
58 | nn.BatchNorm2d(self.expansion*planes)
59 | )
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(self.conv1(x)))
63 | out = F.relu(self.bn2(self.conv2(out)))
64 | out = self.bn3(self.conv3(out))
65 | out += self.shortcut(x)
66 | out = F.relu(out)
67 | return out
68 |
69 |
70 | class ResNet(nn.Module):
71 | def __init__(self, block, num_blocks, low_dim=128):
72 | super(ResNet, self).__init__()
73 | self.in_planes = 64
74 |
75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
76 | self.bn1 = nn.BatchNorm2d(64)
77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
81 | self.linear = nn.Linear(512*block.expansion, low_dim)
82 | self.l2norm = Normalize(2)
83 |
84 | def _make_layer(self, block, planes, num_blocks, stride):
85 | strides = [stride] + [1]*(num_blocks-1)
86 | layers = []
87 | for stride in strides:
88 | layers.append(block(self.in_planes, planes, stride))
89 | self.in_planes = planes * block.expansion
90 | return nn.Sequential(*layers)
91 |
92 | def forward(self, x):
93 | out = F.relu(self.bn1(self.conv1(x)))
94 | out = self.layer1(out)
95 | out = self.layer2(out)
96 | out = self.layer3(out)
97 | out = self.layer4(out)
98 | out = F.avg_pool2d(out, 4)
99 | out = out.view(out.size(0), -1)
100 | out = self.linear(out)
101 | out = self.l2norm(out)
102 | return out
103 |
104 |
105 | def ResNet18(low_dim=128):
106 | return ResNet(BasicBlock, [2,2,2,2], low_dim)
107 |
108 | def ResNet34(low_dim=128):
109 | return ResNet(BasicBlock, [3,4,6,3], low_dim)
110 |
111 | def ResNet50(low_dim=128):
112 | return ResNet(Bottleneck, [3,4,6,3], low_dim)
113 |
114 | def ResNet101(low_dim=128):
115 | return ResNet(Bottleneck, [3,4,23,3], low_dim)
116 |
117 | def ResNet152(low_dim=128):
118 | return ResNet(Bottleneck, [3,8,36,3], low_dim)
119 |
120 |
121 | def test():
122 | net = ResNet18()
123 | y = net(Variable(torch.randn(1,3,32,32)))
124 | print(y.size())
125 |
126 | # test()
127 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import datasets
4 | from lib.utils import AverageMeter
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 |
8 | def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0):
9 | net.eval()
10 | net_time = AverageMeter()
11 | cls_time = AverageMeter()
12 | losses = AverageMeter()
13 | correct = 0.
14 | total = 0
15 | testsize = testloader.dataset.__len__()
16 |
17 | trainFeatures = lemniscate.memory.t()
18 | if hasattr(trainloader.dataset, 'imgs'):
19 | trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
20 | else:
21 | trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()
22 |
23 | if recompute_memory:
24 | transform_bak = trainloader.dataset.transform
25 | trainloader.dataset.transform = testloader.dataset.transform
26 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
27 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
28 | targets = targets.cuda(async=True)
29 | batchSize = inputs.size(0)
30 | features = net(inputs)
31 | trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
32 | trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
33 | trainloader.dataset.transform = transform_bak
34 |
35 | end = time.time()
36 | with torch.no_grad():
37 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
38 | targets = targets.cuda(async=True)
39 | batchSize = inputs.size(0)
40 | features = net(inputs)
41 | net_time.update(time.time() - end)
42 | end = time.time()
43 |
44 | dist = torch.mm(features, trainFeatures)
45 |
46 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
47 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
48 | retrieval = torch.gather(candidates, 1, yi)
49 |
50 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
51 | yd = yd.narrow(1, 0, 1)
52 |
53 | total += targets.size(0)
54 | correct += retrieval.eq(targets.data).sum().item()
55 |
56 | cls_time.update(time.time() - end)
57 | end = time.time()
58 |
59 | print('Test [{}/{}]\t'
60 | 'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
61 | 'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
62 | 'Top1: {:.2f}'.format(
63 | total, testsize, correct*100./total, net_time=net_time, cls_time=cls_time))
64 |
65 | return correct/total
66 |
67 | def kNN(epoch, net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0):
68 | net.eval()
69 | net_time = AverageMeter()
70 | cls_time = AverageMeter()
71 | total = 0
72 | testsize = testloader.dataset.__len__()
73 |
74 | trainFeatures = lemniscate.memory.t()
75 | if hasattr(trainloader.dataset, 'imgs'):
76 | trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda()
77 | else:
78 | trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda()
79 | C = trainLabels.max() + 1
80 |
81 | if recompute_memory:
82 | transform_bak = trainloader.dataset.transform
83 | trainloader.dataset.transform = testloader.dataset.transform
84 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1)
85 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader):
86 | targets = targets.cuda(async=True)
87 | batchSize = inputs.size(0)
88 | features = net(inputs)
89 | trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t()
90 | trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda()
91 | trainloader.dataset.transform = transform_bak
92 |
93 | top1 = 0.
94 | top5 = 0.
95 | end = time.time()
96 | with torch.no_grad():
97 | retrieval_one_hot = torch.zeros(K, C).cuda()
98 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
99 | end = time.time()
100 | targets = targets.cuda(async=True)
101 | batchSize = inputs.size(0)
102 | features = net(inputs)
103 | net_time.update(time.time() - end)
104 | end = time.time()
105 |
106 | dist = torch.mm(features, trainFeatures)
107 |
108 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
109 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
110 | retrieval = torch.gather(candidates, 1, yi)
111 |
112 | retrieval_one_hot.resize_(batchSize * K, C).zero_()
113 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
114 | yd_transform = yd.clone().div_(sigma).exp_()
115 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), yd_transform.view(batchSize, -1, 1)), 1)
116 | _, predictions = probs.sort(1, True)
117 |
118 | # Find which predictions match the target
119 | correct = predictions.eq(targets.data.view(-1,1))
120 | cls_time.update(time.time() - end)
121 |
122 | top1 = top1 + correct.narrow(1,0,1).sum().item()
123 | top5 = top5 + correct.narrow(1,0,5).sum().item()
124 |
125 | total += targets.size(0)
126 |
127 | print('Test [{}/{}]\t'
128 | 'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t'
129 | 'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t'
130 | 'Top1: {:.2f} Top5: {:.2f}'.format(
131 | total, testsize, top1*100./total, top5*100./total, net_time=net_time, cls_time=cls_time))
132 |
133 | print(top1*100./total)
134 |
135 | return top1/total
136 |
137 |
--------------------------------------------------------------------------------