├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── ThirdPartyNotices.txt ├── cifar.py ├── datasets ├── __init__.py ├── cifar.py └── folder.py ├── lib ├── LinearAverage.py ├── NCA.py ├── __init__.py ├── normalize.py └── utils.py ├── main.py ├── models ├── __init__.py ├── resnet.py └── resnet_cifar.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | checkpoint/* 3 | logs/* 4 | others/* 5 | 6 | *.pyc 7 | *.bak 8 | *.log 9 | *.tar 10 | *.pth 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Improving Generalization via Scalable Neighborhood Component Analysis 2 | 3 | This repo contains the PyTorch implementation for the ECCV 2018 paper [(paper)](https://arxiv.org/pdf/1808.04699.pdf). 4 | We use deep networks to learn feature representations optimized for nearest neighbor classifiers, which generalize better to new object categories. 5 | This project is a re-investigation of [Neighborhood Component Analysis (NCA)](http://www.cs.toronto.edu/~fritz/absps/nca.pdf) 6 | with recent technologies to make it scalable to deep networks and large-scale datasets. 7 | 8 | Much of the code is extended from the previous [unsupervised learning project](https://arxiv.org/pdf/1805.01978.pdf). 9 | Please refer to [this repo](https://github.com/zhirongw/lemniscate.pytorch) for more details. 10 | 11 | 12 | 13 | ## Pretrained Models 14 | 15 | Currently, we provide three pretrained ResNet models. 16 | Each release contains the feature representation of all ImageNet training images (600 MB) and model weights (100-200 MB). 17 | Models and their performance with nearest neighbor classifiers are as follows. 18 | 19 | - [ResNet 18](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet18.pth.tar) (top-1 accuracy 70.59%) 20 | - [ResNet 34](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet34.pth.tar) (top-1 accuracy 74.41%) 21 | - [ResNet 50](http://zhirongw.westus2.cloudapp.azure.com/models/snca_resnet50.pth.tar) (top-1 accuracy 76.57%) 22 | 23 | Code to reproduce the rest of the experiments is coming soon. A loading sketch for the released checkpoints is shown below. 24 |
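A minimal loading sketch for feature extraction, assuming the released checkpoints follow the layout written by `save_checkpoint` in `main.py` (a `state_dict` saved from a `DataParallel`-wrapped model); adjust the keys if your file differs. Run it from the repo root on a GPU machine:

```python
import torch
import models  # the models package from this repo

# build the network and wrap it the same way main.py does,
# so the 'module.'-prefixed keys in the checkpoint match
model = models.__dict__['resnet18'](low_dim=128)
model = torch.nn.DataParallel(model).cuda()

checkpoint = torch.load('snca_resnet18.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

x = torch.randn(1, 3, 224, 224).cuda()  # stand-in for a preprocessed image
with torch.no_grad():
    feature = model(x)                  # 128-d L2-normalized embedding
```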
25 | ## Nearest Neighbors 26 | 27 | Please follow [this link](http://zhirongw.westus2.cloudapp.azure.com/nn.html) for a list of nearest neighbors on ImageNet. 28 | Results are visualized from our ResNet50 features, compared with baseline ResNet50 features, raw image features, and previous unsupervised features. 29 | The first column is the query image, followed by 20 retrievals ranked by similarity. 30 | 31 | 32 | 33 | ## Usage 34 | 35 | Our code extends the PyTorch implementation of ImageNet classification in the [official PyTorch examples](https://github.com/pytorch/examples/tree/master/imagenet). 36 | Please refer to the official repo for details of data preparation and hardware configurations. 37 | 38 | - Install Python 2 and [PyTorch >= 0.4](http://pytorch.org) 39 | 40 | - Clone this repo: `git clone https://github.com/Microsoft/snca.pytorch` 41 | 42 | - Training on ImageNet: 43 | 44 | `python main.py DATAPATH --arch resnet18 -j 32 --temperature 0.05 --low-dim 128 -b 256` 45 | 46 | - During training, we monitor the supervised validation accuracy with a nearest neighbor classifier (k=1), as it is faster and gives a good estimate of the feature quality. 47 | 48 | - Testing on ImageNet: 49 | 50 | `python main.py DATAPATH --arch resnet18 --resume input_model.pth.tar -e` runs testing with the default K=30 neighbors. 51 | 52 | - Memory Consumption and Computation Issues 53 | 54 | Memory consumption is more of an issue than computation time. 55 | Currently, the implementation of the NCA module is not parallelized across multiple GPUs. 56 | Hence, the first GPU will consume much more memory than the others. 57 | For example, when training a ResNet18 network, GPU 0 will consume 11GB of memory, while the others each take 2.5GB. 58 | You will need to pass the Caffe-style options `-b 128 --iter-size 2` for training deeper networks. 59 | Our released models were trained on V100 machines. 60 | 61 | - Training on CIFAR10: 62 | 63 | `python cifar.py --temperature 0.05 --lr 0.1` 64 | 65 | 66 | ## Citation 67 | ``` 68 | @inproceedings{wu2018improving, 69 | title={Improving Generalization via Scalable Neighborhood Component Analysis}, 70 | author={Wu, Zhirong and Efros, Alexei A and Yu, Stella}, 71 | booktitle={European Conference on Computer Vision (ECCV) 2018}, 72 | year={2018} 73 | } 74 | ``` 75 | 76 | ## Contact 77 | 78 | For any questions, please feel free to reach out to 79 | ``` 80 | Zhirong Wu: xavibrowu@gmail.com 81 | ``` 82 | 83 | ## Contributing 84 | 85 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 86 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 87 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 88 | 89 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 90 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 91 | provided by the bot. You will only need to do this once across all repos using our CLA. 92 | 93 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 94 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 95 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 96 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 9 | 10 | ## Reporting Security Issues 11 | 12 | **Please do not report security vulnerabilities through public GitHub issues.** 13 | 14 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 15 | 16 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message.
Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /ThirdPartyNotices.txt: -------------------------------------------------------------------------------- 1 | ************************************************************************ 2 | 3 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 4 | 5 | This project incorporates components from the projects listed below. 6 | The original copyright notices and the licenses under which Microsoft received such components are set forth below. 7 | Microsoft reserves all rights not expressly granted herein, whether by implication, estoppel or otherwise. 8 | 9 | 1. Pytorch (https://github.com/pytorch/pytorch) 10 | 2. lemniscate (https://github.com/zhirongw/lemniscate.pytorch) 11 | 12 | From PyTorch: 13 | 14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 19 | Copyright (c) 2011-2013 NYU (Clement Farabet) 20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 23 | 24 | From Caffe2: 25 | 26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 27 | 28 | All contributions by Facebook: 29 | Copyright (c) 2016 Facebook Inc. 30 | 31 | All contributions by Google: 32 | Copyright (c) 2015 Google Inc. 33 | All rights reserved. 34 | 35 | All contributions by Yangqing Jia: 36 | Copyright (c) 2015 Yangqing Jia 37 | All rights reserved. 38 | 39 | All contributions from Caffe: 40 | Copyright(c) 2013, 2014, 2015, the respective contributors 41 | All rights reserved. 42 | 43 | All other contributions: 44 | Copyright(c) 2015, 2016 the respective contributors 45 | All rights reserved. 46 | 47 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 48 | copyright over their contributions to Caffe2. 
The project versioning records 49 | all such contribution and copyright details. If a contributor wants to further 50 | mark their specific copyright on a particular contribution, they should 51 | indicate their copyright solely in the commit message of the change when it is 52 | committed. 53 | 54 | All rights reserved. 55 | 56 | Redistribution and use in source and binary forms, with or without 57 | modification, are permitted provided that the following conditions are met: 58 | 59 | 1. Redistributions of source code must retain the above copyright 60 | notice, this list of conditions and the following disclaimer. 61 | 62 | 2. Redistributions in binary form must reproduce the above copyright 63 | notice, this list of conditions and the following disclaimer in the 64 | documentation and/or other materials provided with the distribution. 65 | 66 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 67 | and IDIAP Research Institute nor the names of its contributors may be 68 | used to endorse or promote products derived from this software without 69 | specific prior written permission. 70 | 71 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 72 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 73 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 74 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 75 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 76 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 77 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 78 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 79 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 80 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 81 | POSSIBILITY OF SUCH DAMAGE. 
82 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | import torch.backends.cudnn as cudnn 10 | 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | 14 | import os 15 | import argparse 16 | import time 17 | 18 | import models 19 | import datasets 20 | import math 21 | 22 | from lib.LinearAverage import LinearAverage 23 | from lib.NCA import NCACrossEntropy 24 | from lib.utils import AverageMeter 25 | from test import NN, kNN 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 28 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 29 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint') 30 | parser.add_argument('--test-only', action='store_true', help='test only') 31 | parser.add_argument('--low-dim', default=128, type=int, 32 | metavar='D', help='feature dimension') 33 | parser.add_argument('--temperature', default=0.05, type=float, 34 | metavar='T', help='temperature parameter for softmax') 35 | parser.add_argument('--memory-momentum', default=0.5, type=float, 36 | metavar='M', help='momentum for non-parametric updates') 37 | 38 | args = parser.parse_args() 39 | 40 | use_cuda = torch.cuda.is_available() 41 | best_acc = 0 # best test accuracy 42 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 43 | 44 | # Data 45 | print('==> Preparing data..') 46 | transform_train = transforms.Compose([ 47 | #transforms.RandomCrop(32, padding=4), 48 | transforms.RandomResizedCrop(size=32, scale=(0.2,1.)), 49 | transforms.RandomGrayscale(p=0.2), 50 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 54 | ]) 55 | 56 | transform_test = transforms.Compose([ 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 59 | ]) 60 | 61 | trainset = datasets.CIFAR10Instance(root='./data', train=True, download=True, transform=transform_train) 62 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 63 | 64 | testset = datasets.CIFAR10Instance(root='./data', train=False, download=True, transform=transform_test) 65 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 66 | 67 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 68 | ndata = trainset.__len__() 69 | 70 | # Model 71 | if args.test_only or len(args.resume)>0: 72 | # Load checkpoint. 73 | print('==> Resuming from checkpoint..') 74 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
75 | checkpoint = torch.load('./checkpoint/'+args.resume) 76 | net = checkpoint['net'] 77 | lemniscate = checkpoint['lemniscate'] 78 | best_acc = checkpoint['acc'] 79 | start_epoch = checkpoint['epoch'] 80 | else: 81 | print('==> Building model..') 82 | net = models.__dict__['ResNet18'](low_dim=args.low_dim) 83 | # define the lemniscate memory bank 84 | lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum) 85 | 86 | # define loss function 87 | criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.train_labels)) 88 | 89 | if use_cuda: 90 | net.cuda() 91 | net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 92 | lemniscate.cuda() 93 | criterion.cuda() 94 | cudnn.benchmark = True 95 | 96 | if args.test_only: 97 | acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature) 98 | sys.exit(0) 99 | 100 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True) 101 | 102 | def adjust_learning_rate(optimizer, epoch): 103 | """Sets the learning rate to the initial LR decayed by 10 every 50 epochs""" 104 | lr = args.lr * (0.1 ** (epoch // 50)) 105 | print(lr) 106 | for param_group in optimizer.param_groups: 107 | param_group['lr'] = lr 108 | 109 | # Training 110 | def train(epoch): 111 | print('\nEpoch: %d' % epoch) 112 | adjust_learning_rate(optimizer, epoch) 113 | train_loss = AverageMeter() 114 | data_time = AverageMeter() 115 | batch_time = AverageMeter() 116 | correct = 0 117 | total = 0 118 | 119 | # switch to train mode 120 | net.train() 121 | 122 | end = time.time() 123 | for batch_idx, (inputs, targets, indexes) in enumerate(trainloader): 124 | data_time.update(time.time() - end) 125 | if use_cuda: 126 | inputs, targets, indexes = inputs.cuda(), targets.cuda(), indexes.cuda() 127 | optimizer.zero_grad() 128 | 129 | features = net(inputs) 130 | outputs = lemniscate(features, indexes) 131 | loss = criterion(outputs, indexes) 132 | 133 | loss.backward() 134 | optimizer.step() 135 | 136 | train_loss.update(loss.item(), inputs.size(0)) 137 | 138 | # measure elapsed time 139 | batch_time.update(time.time() - end) 140 | end = time.time() 141 | 142 | print('Epoch: [{}][{}/{}] ' 143 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 144 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f}) ' 145 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format( 146 | epoch, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss)) 147 | 148 | for epoch in range(start_epoch, start_epoch+200): 149 | train(epoch) 150 | acc = kNN(epoch, net, lemniscate, trainloader, testloader, 30, args.temperature) 151 | 152 | if acc > best_acc: 153 | print('Saving..') 154 | state = { 155 | 'net': net.module if use_cuda else net, 156 | 'lemniscate': lemniscate, 157 | 'acc': acc, 158 | 'epoch': epoch, 159 | } 160 | if not os.path.isdir('checkpoint'): 161 | os.mkdir('checkpoint') 162 | torch.save(state, './checkpoint/ckpt.t7') 163 | best_acc = acc 164 | 165 | print('best accuracy: {:.2f}'.format(best_acc*100)) 166 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .folder import ImageFolderInstance 2 | from .cifar import CIFAR10Instance, CIFAR100Instance 3 | 4 | __all__ = ('ImageFolderInstance', 'CIFAR10Instance', 'CIFAR100Instance') 5 | 6 | --------------------------------------------------------------------------------
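The instance datasets below return `(image, target, index)` triples rather than the usual pairs, so each training sample can be matched to its row in the memory bank. A minimal usage sketch for `CIFAR10Instance` (defined in `datasets/cifar.py` below; note it relies on the older torchvision attribute names `train_data`/`train_labels` that this repo assumes throughout):

```python
import torchvision.transforms as transforms
from datasets import CIFAR10Instance  # run from the repo root

ds = CIFAR10Instance(root='./data', train=True, download=True,
                     transform=transforms.ToTensor())
img, target, index = ds[0]       # index is the sample's position in the dataset
print(img.shape, target, index)  # torch.Size([3, 32, 32]) <class id> 0
```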
/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import torchvision.datasets as datasets 4 | import torch.utils.data as data 5 | 6 | class CIFAR10Instance(datasets.CIFAR10): 7 | """CIFAR10Instance Dataset. 8 | """ 9 | def __getitem__(self, index): 10 | if self.train: 11 | img, target = self.train_data[index], self.train_labels[index] 12 | else: 13 | img, target = self.test_data[index], self.test_labels[index] 14 | 15 | # doing this so that it is consistent with all other datasets 16 | # to return a PIL Image 17 | img = Image.fromarray(img) 18 | 19 | if self.transform is not None: 20 | img = self.transform(img) 21 | 22 | if self.target_transform is not None: 23 | target = self.target_transform(target) 24 | 25 | return img, target, index 26 | 27 | class CIFAR100Instance(CIFAR10Instance): 28 | """CIFAR100Instance Dataset. 29 | 30 | This is a subclass of the `CIFAR10Instance` Dataset. 31 | """ 32 | base_folder = 'cifar-100-python' 33 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 34 | filename = "cifar-100-python.tar.gz" 35 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 36 | train_list = [ 37 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 38 | ] 39 | 40 | test_list = [ 41 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 42 | ] 43 | -------------------------------------------------------------------------------- /datasets/folder.py: -------------------------------------------------------------------------------- 1 | import torchvision.datasets as datasets 2 | 3 | class ImageFolderInstance(datasets.ImageFolder): 4 | """Folder dataset which also returns the index of the image. 5 | """ 6 | def __getitem__(self, index): 7 | """ 8 | Args: 9 | index (int): Index 10 | Returns: 11 | tuple: (image, target) where target is class_index of the target class. 12 | """ 13 | path, target = self.imgs[index] 14 | img = self.loader(path) 15 | if self.transform is not None: 16 | img = self.transform(img) 17 | if self.target_transform is not None: 18 | target = self.target_transform(target) 19 | 20 | return img, target, index 21 | 22 | -------------------------------------------------------------------------------- /lib/LinearAverage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | import math 5 | 6 | class LinearAverageOp(Function): 7 | @staticmethod 8 | def forward(self, x, y, memory, params): 9 | T = params[0].item() 10 | batchSize = x.size(0) 11 | 12 | # inner product 13 | out = torch.mm(x.data, memory.t()) 14 | out.div_(T) # batchSize * N 15 | 16 | self.save_for_backward(x, memory, y, params) 17 | 18 | return out 19 | 20 | @staticmethod 21 | def backward(self, gradOutput): 22 | x, memory, y, params = self.saved_tensors 23 | batchSize = gradOutput.size(0) 24 | T = params[0].item() 25 | momentum = params[1].item() 26 | 27 | # apply temperature scaling 28 | gradOutput.data.div_(T) 29 | 30 | # gradient of linear 31 | gradInput = torch.mm(gradOutput.data, memory) 32 | gradInput.resize_as_(x) 33 | 34 | # update the non-parametric data 35 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x) 36 | weight_pos.mul_(momentum) 37 | weight_pos.add_(torch.mul(x.data, 1-momentum)) 38 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5) 39 | updated_weight = weight_pos.div(w_norm) 40 | memory.index_copy_(0, y, updated_weight) 41 | 42 | return gradInput, None, None, None 43 | 44 | class LinearAverage(nn.Module): 45 | 46 | def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5): 47 | super(LinearAverage, self).__init__() 48 | 49 | self.nLem = outputSize 50 | 51 | self.register_buffer('params', torch.tensor([T, momentum])) 52 | stdv = 1. / math.sqrt(inputSize/3) 53 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv)) # N x D memory bank, uniform in [-stdv, stdv] 54 | 55 | def forward(self, x, y): 56 | out = LinearAverageOp.apply(x, y, self.memory, self.params) 57 | return out 58 | 59 | --------------------------------------------------------------------------------
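A quick shape check for the memory module above (a CPU-only sketch; note that the momentum update of the memory happens inside `backward`, so a plain forward pass leaves the bank untouched):

```python
import torch
from lib.LinearAverage import LinearAverage  # run from the repo root

D, N, B = 128, 1000, 8                 # feature dim, dataset size, batch size
memory = LinearAverage(D, N, T=0.05, momentum=0.5)

x = torch.randn(B, D)
x = x / x.norm(dim=1, keepdim=True)    # features are assumed L2-normalized
y = torch.arange(B)                    # dataset indexes of this batch
out = memory(x, y)                     # (B, N) similarities, already divided by T
print(out.shape)                       # torch.Size([8, 1000])
```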
12 | """ 13 | path, target = self.imgs[index] 14 | img = self.loader(path) 15 | if self.transform is not None: 16 | img = self.transform(img) 17 | if self.target_transform is not None: 18 | target = self.target_transform(target) 19 | 20 | return img, target, index 21 | 22 | -------------------------------------------------------------------------------- /lib/LinearAverage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | import math 5 | 6 | class LinearAverageOp(Function): 7 | @staticmethod 8 | def forward(self, x, y, memory, params): 9 | T = params[0].item() 10 | batchSize = x.size(0) 11 | 12 | # inner product 13 | out = torch.mm(x.data, memory.t()) 14 | out.div_(T) # batchSize * N 15 | 16 | self.save_for_backward(x, memory, y, params) 17 | 18 | return out 19 | 20 | @staticmethod 21 | def backward(self, gradOutput): 22 | x, memory, y, params = self.saved_tensors 23 | batchSize = gradOutput.size(0) 24 | T = params[0].item() 25 | momentum = params[1].item() 26 | 27 | # add temperature 28 | gradOutput.data.div_(T) 29 | 30 | # gradient of linear 31 | gradInput = torch.mm(gradOutput.data, memory) 32 | gradInput.resize_as_(x) 33 | 34 | # update the non-parametric data 35 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x) 36 | weight_pos.mul_(momentum) 37 | weight_pos.add_(torch.mul(x.data, 1-momentum)) 38 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5) 39 | updated_weight = weight_pos.div(w_norm) 40 | memory.index_copy_(0, y, updated_weight) 41 | 42 | return gradInput, None, None, None 43 | 44 | class LinearAverage(nn.Module): 45 | 46 | def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5): 47 | super(LinearAverage, self).__init__() 48 | stdv = 1 / math.sqrt(inputSize) 49 | self.nLem = outputSize 50 | 51 | self.register_buffer('params',torch.tensor([T, momentum])); 52 | stdv = 1. / math.sqrt(inputSize/3) 53 | self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv)) 54 | 55 | def forward(self, x, y): 56 | out = LinearAverageOp.apply(x, y, self.memory, self.params) 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /lib/NCA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Function 4 | import math 5 | 6 | eps = 1e-8 7 | 8 | class NCACrossEntropy(nn.Module): 9 | ''' \sum_{j=C} log(p_{ij}) 10 | Store all the labels of the dataset. 11 | Only pass the indexes of the training instances during forward. 
/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # nothing 2 | -------------------------------------------------------------------------------- /lib/normalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | 5 | class Normalize(nn.Module): 6 | 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 13 | out = x.div(norm) 14 | return out 15 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self): 4 | self.reset() 5 | 6 | def reset(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def update(self, val, n=1): 13 | self.val = val 14 | self.sum += val * n 15 | self.count += n 16 | self.avg = self.sum / self.count 17 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import shutil 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | 17 | import datasets 18 | import models 19 | import math 20 | 21 | from lib.LinearAverage import LinearAverage 22 | from lib.NCA import NCACrossEntropy 23 | from lib.utils import AverageMeter 24 | from test import NN, kNN 25 | 26 | model_names = sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name])) 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('data', metavar='DIR', 32 | help='path to dataset') 33 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: resnet18)') 38 |
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=130, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=256, type=int, 45 | metavar='N', help='mini-batch size (default: 256)') 46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 47 | metavar='LR', help='initial learning rate') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 51 | metavar='W', help='weight decay (default: 1e-4)') 52 | parser.add_argument('--print-freq', '-p', default=10, type=int, 53 | metavar='N', help='print frequency (default: 10)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('--test-only', action='store_true', help='test only') 57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 58 | help='evaluate model on validation set') 59 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 60 | help='use pre-trained model') 61 | parser.add_argument('--world-size', default=1, type=int, 62 | help='number of distributed processes') 63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 64 | help='url used to set up distributed training') 65 | parser.add_argument('--dist-backend', default='gloo', type=str, 66 | help='distributed backend') 67 | parser.add_argument('--low-dim', default=128, type=int, 68 | metavar='D', help='feature dimension') 69 | parser.add_argument('--temperature', default=0.05, type=float, 70 | metavar='T', help='temperature parameter') 71 | parser.add_argument('--memory-momentum', '--m-momentum', default=0.5, type=float, 72 | metavar='M', help='momentum for non-parametric updates') 73 | parser.add_argument('--iter-size', default=1, type=int, 74 | help='Caffe-style iter size') 75 | parser.add_argument('--margin', default=0.0, type=float, 76 | help='classification margin') 77 | 78 | best_prec1 = 0 79 | 80 | def main(): 81 | global args, best_prec1 82 | args = parser.parse_args() 83 | 84 | args.distributed = args.world_size > 1 85 | 86 | if args.distributed: 87 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 88 | world_size=args.world_size) 89 | 90 | # Data loading code 91 | traindir = os.path.join(args.data, 'train') 92 | valdir = os.path.join(args.data, 'val') 93 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 94 | std=[0.229, 0.224, 0.225]) 95 | 96 | train_dataset = datasets.ImageFolderInstance( 97 | traindir, 98 | transforms.Compose([ 99 | transforms.RandomResizedCrop(224), 100 | transforms.RandomHorizontalFlip(), 101 | transforms.ToTensor(), 102 | normalize, 103 | ])) 104 | 105 | if args.distributed: 106 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 107 | else: 108 | train_sampler = None 109 | 110 | train_loader = torch.utils.data.DataLoader( 111 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 112 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 113 | 114 | val_loader = torch.utils.data.DataLoader( 115 |
datasets.ImageFolderInstance(valdir, transforms.Compose([ 116 | transforms.Resize(256), 117 | transforms.CenterCrop(224), 118 | transforms.ToTensor(), 119 | normalize, 120 | ])), 121 | batch_size=args.batch_size, shuffle=False, 122 | num_workers=args.workers, pin_memory=True) 123 | 124 | # create model 125 | if args.pretrained: 126 | print("=> using pre-trained model '{}'".format(args.arch)) 127 | model = models.__dict__[args.arch](pretrained=True) 128 | else: 129 | print("=> creating model '{}'".format(args.arch)) 130 | model = models.__dict__[args.arch](low_dim=args.low_dim) 131 | 132 | if not args.distributed: 133 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 134 | model.features = torch.nn.DataParallel(model.features) 135 | model.cuda() 136 | else: 137 | model = torch.nn.DataParallel(model).cuda() 138 | else: 139 | model.cuda() 140 | model = torch.nn.parallel.DistributedDataParallel(model) 141 | 142 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 143 | momentum=args.momentum, 144 | weight_decay=args.weight_decay, nesterov=True) 145 | 146 | # optionally resume from a checkpoint 147 | if args.resume: 148 | if os.path.isfile(args.resume): 149 | print("=> loading checkpoint '{}'".format(args.resume)) 150 | checkpoint = torch.load(args.resume) 151 | args.start_epoch = checkpoint['epoch'] 152 | best_prec1 = checkpoint['best_prec1'] 153 | model.load_state_dict(checkpoint['state_dict']) 154 | lemniscate = checkpoint['lemniscate'] 155 | optimizer.load_state_dict(checkpoint['optimizer']) 156 | print("=> loaded checkpoint '{}' (epoch {})" 157 | .format(args.resume, checkpoint['epoch'])) 158 | else: 159 | print("=> no checkpoint found at '{}'".format(args.resume)) 160 | else: 161 | # define lemniscate and loss function (criterion) 162 | ndata = train_dataset.__len__() 163 | lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda() 164 | 165 | 166 | criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]), 167 | args.margin / args.temperature).cuda() 168 | cudnn.benchmark = True 169 | 170 | if args.evaluate: 171 | prec1 = kNN(0, model, lemniscate, train_loader, val_loader, 30, args.temperature, 0) 172 | return 173 | 174 | for epoch in range(args.start_epoch, args.epochs): 175 | if args.distributed: 176 | train_sampler.set_epoch(epoch) 177 | adjust_learning_rate(optimizer, epoch) 178 | adjust_memory_update_rate(lemniscate, epoch) 179 | 180 | # train for one epoch 181 | train(train_loader, model, lemniscate, criterion, optimizer, epoch) 182 | 183 | # evaluate on validation set 184 | prec1 = NN(epoch, model, lemniscate, train_loader, val_loader) 185 | 186 | # remember best prec@1 and save checkpoint 187 | is_best = prec1 > best_prec1 188 | best_prec1 = max(prec1, best_prec1) 189 | save_checkpoint({ 190 | 'epoch': epoch + 1, 191 | 'arch': args.arch, 192 | 'state_dict': model.state_dict(), 193 | 'lemniscate': lemniscate, 194 | 'best_prec1': best_prec1, 195 | 'optimizer' : optimizer.state_dict(), 196 | }, is_best) 197 | 198 | 199 | def train(train_loader, model, lemniscate, criterion, optimizer, epoch): 200 | batch_time = AverageMeter() 201 | data_time = AverageMeter() 202 | losses = AverageMeter() 203 | 204 | # switch to train mode 205 | model.train() 206 | 207 | end = time.time() 208 | optimizer.zero_grad() 209 | for i, (input, target, index) in enumerate(train_loader): 210 | # measure data loading time 211 | data_time.update(time.time() - end) 212 | 213 | target = target.cuda(async=True) 214 | index 
= index.cuda(async=True) 215 | input_var = torch.autograd.Variable(input) 216 | target_var = torch.autograd.Variable(target) 217 | index_var = torch.autograd.Variable(index) 218 | 219 | # compute output 220 | feature = model(input_var) 221 | output = lemniscate(feature, index_var) 222 | loss = criterion(output, index_var) / args.iter_size 223 | 224 | loss.backward() 225 | # measure accuracy and record loss 226 | losses.update(loss.item() * args.iter_size, input.size(0)) 227 | 228 | if (i+1) % args.iter_size == 0: 229 | # compute gradient and do SGD step 230 | optimizer.step() 231 | optimizer.zero_grad() 232 | 233 | # measure elapsed time 234 | batch_time.update(time.time() - end) 235 | end = time.time() 236 | 237 | if i % args.print_freq == 0: 238 | print('Epoch: [{0}][{1}/{2}]\t' 239 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 240 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 241 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 242 | epoch, i, len(train_loader), batch_time=batch_time, 243 | data_time=data_time, loss=losses)) 244 | 245 | 246 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 247 | torch.save(state, filename) 248 | if is_best: 249 | shutil.copyfile(filename, 'model_best.pth.tar') 250 | 251 | def adjust_memory_update_rate(lemniscate, epoch): 252 | if epoch >= 80: 253 | lemniscate.params[1] = 0.8 254 | if epoch >= 120: 255 | lemniscate.params[1] = 0.9 256 | 257 | def adjust_learning_rate(optimizer, epoch): 258 | """Sets the learning rate to the initial LR decayed by 10 every 40 epochs""" 259 | lr = args.lr * (0.1 ** (epoch // 40)) 260 | print(lr) 261 | for param_group in optimizer.param_groups: 262 | param_group['lr'] = lr 263 | 264 | def accuracy(output, target, topk=(1,)): 265 | """Computes the precision@k for the specified values of k""" 266 | maxk = max(topk) 267 | batch_size = target.size(0) 268 | 269 | _, pred = output.topk(maxk, 1, True, True) 270 | pred = pred.t() 271 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 272 | 273 | res = [] 274 | for k in topk: 275 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 276 | res.append(correct_k.mul_(100.0 / batch_size)) 277 | return res 278 | 279 | 280 | if __name__ == '__main__': 281 | main() 282 | --------------------------------------------------------------------------------
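The `--iter-size` option above implements Caffe-style gradient accumulation: the loss is divided by `iter_size` and the optimizer steps only every `iter_size` batches, emulating a larger batch on limited GPU memory. A self-contained sketch of the pattern (toy model, not from this repo):

```python
import torch
from torch import nn, optim

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
iter_size = 2                      # effective batch = iter_size * per-iteration batch

optimizer.zero_grad()
for i in range(4):                 # stand-in for the data loader
    x = torch.randn(16, 8)
    y = torch.randint(0, 2, (16,))
    loss = criterion(model(x), y) / iter_size  # scale so accumulated grads match one big batch
    loss.backward()                            # gradients accumulate until the step
    if (i + 1) % iter_size == 0:
        optimizer.step()
        optimizer.zero_grad()
```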
/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .resnet_cifar import * 3 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from lib.normalize import Normalize 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = {}  # left empty in this release; pretrained=True requires filling in the download URLs 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 56 | self.bn1 = nn.BatchNorm2d(planes) 57 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 58 | padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 61 | self.bn3 = nn.BatchNorm2d(planes * 4) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | 91 | def __init__(self, block, layers, low_dim=128): 92 | self.inplanes = 64 93 | super(ResNet, self).__init__() 94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 95 | bias=False) 96 | self.bn1 = nn.BatchNorm2d(64) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 103 | self.avgpool = nn.AvgPool2d(7, stride=1) 104 | self.fc = nn.Linear(512 * block.expansion, low_dim) 105 | self.l2norm = Normalize(2) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | x = self.l2norm(x) 147 | 148 | return x 149 | 150 | 151 | def resnet18(pretrained=False, **kwargs): 152 | """Constructs a ResNet-18 model. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | """ 157 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 158 | if pretrained: 159 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 160 | return model 161 | 162 | 163 | def resnet34(pretrained=False, **kwargs): 164 | """Constructs a ResNet-34 model. 165 | 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 170 | if pretrained: 171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 172 | return model 173 | 174 | 175 | def resnet50(pretrained=False, **kwargs): 176 | """Constructs a ResNet-50 model. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 184 | return model 185 | 186 | 187 | def resnet101(pretrained=False, **kwargs): 188 | """Constructs a ResNet-101 model. 189 | 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 196 | return model 197 | 198 | 199 | def resnet152(pretrained=False, **kwargs): 200 | """Constructs a ResNet-152 model. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 208 | return model 209 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from lib.normalize import Normalize 13 | 14 | from torch.autograd import Variable 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion*planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, low_dim=128): 72 | super(ResNet, self).__init__() 73 | self.in_planes = 64 74 | 75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(64) 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(512*block.expansion, low_dim) 82 | self.l2norm = Normalize(2) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1]*(num_blocks-1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = F.avg_pool2d(out, 4) 99 | out = out.view(out.size(0), -1) 100 | out = self.linear(out) 101 | out = self.l2norm(out) 102 | return out 103 | 104 | 105 | def ResNet18(low_dim=128): 106 | return ResNet(BasicBlock, [2,2,2,2], low_dim) 107 | 108 | def 
ResNet34(low_dim=128): 109 | return ResNet(BasicBlock, [3,4,6,3], low_dim) 110 | 111 | def ResNet50(low_dim=128): 112 | return ResNet(Bottleneck, [3,4,6,3], low_dim) 113 | 114 | def ResNet101(low_dim=128): 115 | return ResNet(Bottleneck, [3,4,23,3], low_dim) 116 | 117 | def ResNet152(low_dim=128): 118 | return ResNet(Bottleneck, [3,8,36,3], low_dim) 119 | 120 | 121 | def test(): 122 | net = ResNet18() 123 | y = net(Variable(torch.randn(1,3,32,32))) 124 | print(y.size()) 125 | 126 | # test() 127 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import datasets 4 | from lib.utils import AverageMeter 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | 8 | def NN(epoch, net, lemniscate, trainloader, testloader, recompute_memory=0): 9 | net.eval() 10 | net_time = AverageMeter() 11 | cls_time = AverageMeter() 12 | losses = AverageMeter() 13 | correct = 0. 14 | total = 0 15 | testsize = testloader.dataset.__len__() 16 | 17 | trainFeatures = lemniscate.memory.t() 18 | if hasattr(trainloader.dataset, 'imgs'): 19 | trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda() 20 | else: 21 | trainLabels = torch.LongTensor(trainloader.dataset.train_labels).cuda() 22 | 23 | if recompute_memory: 24 | transform_bak = trainloader.dataset.transform 25 | trainloader.dataset.transform = testloader.dataset.transform 26 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1) 27 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader): 28 | targets = targets.cuda(async=True) 29 | batchSize = inputs.size(0) 30 | features = net(inputs) 31 | trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t() 32 | trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda() 33 | trainloader.dataset.transform = transform_bak 34 | 35 | end = time.time() 36 | with torch.no_grad(): 37 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 38 | targets = targets.cuda(async=True) 39 | batchSize = inputs.size(0) 40 | features = net(inputs) 41 | net_time.update(time.time() - end) 42 | end = time.time() 43 | 44 | dist = torch.mm(features, trainFeatures) 45 | 46 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True) 47 | candidates = trainLabels.view(1,-1).expand(batchSize, -1) 48 | retrieval = torch.gather(candidates, 1, yi) 49 | 50 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1) 51 | yd = yd.narrow(1, 0, 1) 52 | 53 | total += targets.size(0) 54 | correct += retrieval.eq(targets.data).sum().item() 55 | 56 | cls_time.update(time.time() - end) 57 | end = time.time() 58 | 59 | print('Test [{}/{}]\t' 60 | 'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t' 61 | 'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t' 62 | 'Top1: {:.2f}'.format( 63 | total, testsize, correct*100./total, net_time=net_time, cls_time=cls_time)) 64 | 65 | return correct/total 66 | 67 | def kNN(epoch, net, lemniscate, trainloader, testloader, K, sigma, recompute_memory=0): 68 | net.eval() 69 | net_time = AverageMeter() 70 | cls_time = AverageMeter() 71 | total = 0 72 | testsize = testloader.dataset.__len__() 73 | 74 | trainFeatures = lemniscate.memory.t() 75 | if hasattr(trainloader.dataset, 'imgs'): 76 | trainLabels = torch.LongTensor([y for (p, y) in trainloader.dataset.imgs]).cuda() 77 | else: 78 | trainLabels = 
torch.LongTensor(trainloader.dataset.train_labels).cuda() 79 | C = trainLabels.max().item() + 1  # number of classes as a plain int 80 | 81 | if recompute_memory: 82 | transform_bak = trainloader.dataset.transform 83 | trainloader.dataset.transform = testloader.dataset.transform 84 | temploader = torch.utils.data.DataLoader(trainloader.dataset, batch_size=100, shuffle=False, num_workers=1) 85 | for batch_idx, (inputs, targets, indexes) in enumerate(temploader): 86 | targets = targets.cuda(async=True) 87 | batchSize = inputs.size(0) 88 | features = net(inputs) 89 | trainFeatures[:, batch_idx*batchSize:batch_idx*batchSize+batchSize] = features.data.t() 90 | trainLabels = torch.LongTensor(temploader.dataset.train_labels).cuda() 91 | trainloader.dataset.transform = transform_bak 92 | 93 | top1 = 0. 94 | top5 = 0. 95 | end = time.time() 96 | with torch.no_grad(): 97 | retrieval_one_hot = torch.zeros(K, C).cuda() 98 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 99 | end = time.time() 100 | targets = targets.cuda(async=True) 101 | batchSize = inputs.size(0) 102 | features = net(inputs) 103 | net_time.update(time.time() - end) 104 | end = time.time() 105 | 106 | dist = torch.mm(features, trainFeatures) 107 | 108 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True) 109 | candidates = trainLabels.view(1,-1).expand(batchSize, -1) 110 | retrieval = torch.gather(candidates, 1, yi) 111 | 112 | retrieval_one_hot.resize_(batchSize * K, C).zero_() 113 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 114 | yd_transform = yd.clone().div_(sigma).exp_() 115 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), yd_transform.view(batchSize, -1, 1)), 1) 116 | _, predictions = probs.sort(1, True) 117 | 118 | # Find which predictions match the target 119 | correct = predictions.eq(targets.data.view(-1,1)) 120 | cls_time.update(time.time() - end) 121 | 122 | top1 = top1 + correct.narrow(1,0,1).sum().item() 123 | top5 = top5 + correct.narrow(1,0,5).sum().item() 124 | 125 | total += targets.size(0) 126 | 127 | print('Test [{}/{}]\t' 128 | 'Net Time {net_time.val:.3f} ({net_time.avg:.3f})\t' 129 | 'Cls Time {cls_time.val:.3f} ({cls_time.avg:.3f})\t' 130 | 'Top1: {:.2f} Top5: {:.2f}'.format( 131 | total, testsize, top1*100./total, top5*100./total, net_time=net_time, cls_time=cls_time)) 132 | 133 | print(top1*100./total) 134 | 135 | return top1/total 136 | 137 | --------------------------------------------------------------------------------
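A standalone illustration of the similarity-weighted kNN vote computed inside `kNN()` above (toy numbers, independent of the repo's memory bank):

```python
import torch

K, C, sigma = 3, 4, 0.05
yd = torch.tensor([[0.9, 0.8, 0.2]])   # cosine similarities of the top-K neighbors
labels = torch.tensor([[2, 2, 1]])     # their class labels
one_hot = torch.zeros(K, C).scatter_(1, labels.view(-1, 1), 1)
w = yd.clone().div_(sigma).exp_()      # exp(sim / sigma): sharper weight for close neighbors
probs = (one_hot.view(1, K, C) * w.view(1, K, 1)).sum(1)
print(probs.argmax(1))                 # tensor([2]): class 2 wins the weighted vote
```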