├── lib ├── __init__.py ├── normalize.py ├── criterion.py ├── utils.py ├── non_parametric_classifier.py ├── protocols.py └── ans_discovery.py ├── docs ├── assets │ ├── large-scale.jpg │ ├── small-scale.jpg │ └── training-pipeline.jpg ├── LICENSE └── README.md ├── configs ├── cifar10.yaml ├── cifar100.yaml ├── svhn.yaml └── base.yaml ├── packages ├── __init__.py ├── datasets │ ├── transforms.py │ └── __init__.py ├── loggers │ ├── __init__.py │ ├── tf_logger.py │ └── std_logger.py ├── lr_policy │ ├── fixed.py │ ├── __init__.py │ ├── step.py │ └── multistep.py ├── networks │ └── __init__.py ├── optimizers │ ├── __init__.py │ ├── sgd.py │ ├── rmsprop.py │ └── adam.py ├── utils.py ├── config.py ├── session.py └── register.py ├── models ├── __init__.py └── resnet_cifar.py ├── datasets ├── svhn.py ├── __init__.py └── cifar.py ├── requirements.yaml └── main.py
/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # nothing --------------------------------------------------------------------------------
/docs/assets/large-scale.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/large-scale.jpg --------------------------------------------------------------------------------
/docs/assets/small-scale.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/small-scale.jpg --------------------------------------------------------------------------------
/docs/assets/training-pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/training-pipeline.jpg --------------------------------------------------------------------------------
/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | #################### 2 | # cfgs for cifar10 # 3 | #################### 4 | 5 | # args for dataset and dataloader 6 | dataset: cifar10 7 | data_root: data/cifar10 8 | batch_size: 128 9 | --------------------------------------------------------------------------------
/configs/cifar100.yaml: -------------------------------------------------------------------------------- 1 | ##################### 2 | # cfgs for cifar100 # 3 | ##################### 4 | 5 | # args for dataset and dataloader 6 | dataset: cifar100 7 | data_root: data/cifar100 8 | batch_size: 128 9 | --------------------------------------------------------------------------------
/configs/svhn.yaml: -------------------------------------------------------------------------------- 1 | ################# 2 | # cfgs for svhn # 3 | ################# 4 | 5 | # args for dataset and dataloader 6 | dataset: svhn 7 | data_root: data/svhn 8 | batch_size: 128 9 | display_freq: 115
--------------------------------------------------------------------------------
/packages/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-25 14:37:26 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | # from . import datasets 8 | # from . import loggers 9 | # from . import networks 10 | # from . import optimizers 11 | # from . import argparser 12 | # from . import lr_policy 13 | # from . import utils
--------------------------------------------------------------------------------
/packages/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-28 21:45:01 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import torchvision.transforms as transforms 8 | 9 | class RandomResizedCrop(transforms.RandomResizedCrop): 10 | # pass a dummy size to the parent, then overwrite self.size so that tuple sizes (e.g. '(32, 32)' from the configs) are accepted 11 | def __init__(self, size, **kwargs): 12 | super(RandomResizedCrop, self).__init__(0, **kwargs) 13 | self.size = size --------------------------------------------------------------------------------
/lib/normalize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-10-12 21:37:13 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | from torch import nn 10 | 11 | class Normalize(nn.Module): 12 | """Normalize module 13 | 14 | Module used to normalize each row of a matrix to unit Lp norm 15 | 16 | Extends: 17 | nn.Module 18 | """ 19 | 20 | def __init__(self, power=2): 21 | super(Normalize, self).__init__() 22 | self.power = power 23 | 24 | def forward(self, x): 25 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 26 | out = x.div(norm) 27 | return out 28 | --------------------------------------------------------------------------------
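Normalize is what keeps every embedding on the unit sphere, so the inner products computed by the non-parametric classifier later in this repo are cosine similarities. A minimal sketch of what it computes (hypothetical tensors, not from the repo):

```
import torch
from lib.normalize import Normalize

l2norm = Normalize(power=2)
x = torch.tensor([[3., 4.], [1., 0.]])
out = l2norm(x)
# each row is divided by its L2 norm:
# tensor([[0.6000, 0.8000],
#         [1.0000, 0.0000]])
```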
/configs/base.yaml: -------------------------------------------------------------------------------- 1 | ############# 2 | # base cfgs # 3 | ############# 4 | 5 | # args for AND 6 | ANs_select_rate: 0.25 7 | ANs_size: 1 8 | max_round: 5 9 | 10 | # args for network 11 | network: ResNet18 12 | 13 | # args for training 14 | log_file: True 15 | log_tfb: True 16 | display_freq: 80 17 | workers_num: 4 18 | 19 | # args for transforms 20 | size: (32, 32) 21 | resize: 32 22 | scale: (0.2, 1.) 23 | ratio: (0.75, 1.333333) 24 | colorjitter: (0.4, 0.4, 0.4, 0.4) 25 | random_grayscale: 0.2 26 | # random_horizontal_flip: True 27 | 28 | # args for lr policy 29 | base_lr: 0.03 30 | lr_policy: step 31 | lr_decay_offset: 80 32 | lr_decay_step: 40 33 | lr_decay_rate: 0.1 34 | 35 | # args for optimizer 36 | optimizer: sgd 37 | weight_decay: 5e-4 38 | momentum: 0.9 39 | nesterov: True 40 | 41 | # args for protocol 42 | protocol: knn --------------------------------------------------------------------------------
/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-03-13 21:25:39 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from .resnet_cifar import * 8 | 9 | from packages.register import REGISTER 10 | from packages.config import CONFIG as cfg 11 | from packages import networks as cmd_networks 12 | 13 | def require_args(): 14 | 15 | cfg.add_argument('--low-dim', default=128, type=int, help='feature dimension') 16 | 17 | def get(name, instant=False): 18 | cls = cmd_networks.get(name) 19 | if not instant: 20 | return cls 21 | return cls(low_dim=cfg.low_dim) 22 | 23 | REGISTER.set_package(__name__) 24 | cmd_networks.register('ResNet18', ResNet18) 25 | cmd_networks.register('ResNet50', ResNet50) 26 | cmd_networks.register('ResNet101', ResNet101) 27 | --------------------------------------------------------------------------------
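models/__init__.py is the glue between the concrete ResNets and the command-line registry: networks are registered by name, and get() resolves the --network argument to a class, or to an instance sized by --low-dim. A short usage sketch, assuming cfg has already been parsed with the defaults from base.yaml:

```
import models
from packages.config import CONFIG as cfg

net_cls = models.get('ResNet18')             # class lookup only
net = models.get(cfg.network, instant=True)  # e.g. ResNet18(low_dim=128)
```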
/packages/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-24 23:26:16 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from . import std_logger 8 | from . import tf_logger 9 | 10 | from ..register import REGISTER 11 | 12 | def require_args(): 13 | """all args for logger objects 14 | 15 | Iterates over the registered logger classes and lets each 16 | one add its own command-line arguments. 17 | """ 18 | if not REGISTER.is_package_registered(__name__): 19 | return 20 | 21 | classes = REGISTER.get_classes(__name__) 22 | 23 | for (name, cls) in classes.items(): 24 | if hasattr(cls, 'require_args'): 25 | cls.require_args() 26 | 27 | def get(name): 28 | return REGISTER.get_class(__name__, name) 29 | 30 | def register(name, cls): 31 | REGISTER.set_class(__name__, name, cls) --------------------------------------------------------------------------------
/packages/lr_policy/fixed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 20:15:47 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from ..loggers.std_logger import STDLogger as logger 8 | from ..config import CONFIG as cfg 9 | 10 | class FixedPolicy: 11 | 12 | def __init__(self, *args, **kwargs): 13 | if len(args) + len(kwargs) == 0: 14 | self.__init_by_cfg() 15 | else: 16 | self.__init(*args, **kwargs) 17 | 18 | def __init(self, base_lr): 19 | self.base_lr = base_lr 20 | logger.debug('Going to use [fixed] learning policy for optimization' 21 | ' with base learning rate [%.5f]' % base_lr) 22 | 23 | def __init_by_cfg(self): 24 | self.__init(cfg.base_lr) 25 | 26 | def update(self, epoch, *args, **kwargs): 27 | return self.base_lr 28 | 29 | from ..register import REGISTER 30 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'fixed', FixedPolicy) --------------------------------------------------------------------------------
/packages/networks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-24 23:26:56 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from ..register import REGISTER 8 | from ..config import CONFIG as cfg 9 | 10 | def require_args(): 11 | """all args for network objects 12 | 13 | Lets the network selected via `--network` add its own 14 | command-line arguments. 15 | """ 16 | 17 | known_args, _ = cfg.parse_known_args() 18 | 19 | if (REGISTER.is_package_registered(__name__) and 20 | REGISTER.is_class_registered(__name__, known_args.network)): 21 | 22 | network = get(known_args.network) 23 | 24 | if hasattr(network, 'require_args'): 25 | # get args for network 26 | return network.require_args() 27 | 28 | def get(name, instant=False): 29 | cls = REGISTER.get_class(__name__, name) 30 | if instant: 31 | return cls() 32 | return cls 33 | 34 | def register(name, cls): 35 | REGISTER.set_class(__name__, name, cls) --------------------------------------------------------------------------------
/docs/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Raymond Wong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /packages/loggers/tf_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-25 17:16:34 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from tensorboardX import SummaryWriter as TFBWriter 8 | 9 | from ..config import CONFIG as cfg 10 | 11 | class TFLogger(): 12 | 13 | @staticmethod 14 | def require_args(): 15 | 16 | cfg.add_argument('--log-tfb', action='store_true', 17 | help='use tensorboard to log training process. ' 18 | '(default: False)') 19 | 20 | def __init__(self, debugging, *args, **kwargs): 21 | self.debugging = debugging 22 | if not self.debugging and cfg.log_tfb: 23 | self.writer = TFBWriter(*args, **kwargs) 24 | 25 | def __getattr__(self,attr): 26 | if self.debugging or not cfg.log_tfb: 27 | return do_nothing 28 | return self.writer.__getattribute__(attr) 29 | 30 | def do_nothing(*args, **kwargs): 31 | pass 32 | 33 | 34 | from ..register import REGISTER 35 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'tf_logger', TFLogger) 36 | -------------------------------------------------------------------------------- /packages/lr_policy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 20:11:16 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from . import step 8 | from . import multistep 9 | from . import fixed 10 | 11 | from ..register import REGISTER 12 | from ..config import CONFIG as cfg 13 | 14 | def require_args(): 15 | """all args for optimizer objects 16 | 17 | Arguments: 18 | parser {argparse} -- current version of argparse object 19 | """ 20 | 21 | cfg.add_argument('--base-lr', default=1e-1, type=float, 22 | help='base learning rate. 
(default: 1e-1)') 23 | 24 | known_args, _ = cfg.parse_known_args() 25 | 26 | if (REGISTER.is_package_registered(__name__) and 27 | REGISTER.is_class_registered(__name__, known_args.lr_policy)): 28 | 29 | policy = get(known_args.lr_policy) 30 | 31 | if hasattr(policy, 'require_args'): 32 | return policy.require_args() 33 | 34 | def get(name, instant=False): 35 | cls = REGISTER.get_class(__name__, name) 36 | if instant: 37 | return cls() 38 | return cls 39 | 40 | def register(name, cls): 41 | REGISTER.set_class(__name__, name, cls) --------------------------------------------------------------------------------
/packages/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-24 23:26:56 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from . import sgd 8 | from . import adam 9 | from . import rmsprop 10 | 11 | from ..register import REGISTER 12 | from ..config import CONFIG as cfg 13 | 14 | def require_args(): 15 | """all args for optimizer objects 16 | 17 | Registers the shared optimizer args, then lets the optimizer 18 | selected via `--optimizer` add its own. 19 | """ 20 | 21 | cfg.add_argument('--weight-decay', default=0, type=float, 22 | help='weight decay (L2 penalty)') 23 | 24 | known_args, _ = cfg.parse_known_args() 25 | 26 | if (REGISTER.is_package_registered(__name__) and 27 | REGISTER.is_class_registered(__name__, known_args.optimizer)): 28 | 29 | optimizer = get(known_args.optimizer) 30 | 31 | if hasattr(optimizer, 'require_args'): 32 | return optimizer.require_args() 33 | 34 | def get(name, instant=False, params=None): 35 | cls = REGISTER.get_class(__name__, name) 36 | if instant: 37 | return cls.get(params) 38 | return cls 39 | 40 | def register(name, cls): 41 | REGISTER.set_class(__name__, name, cls) 42 | 43 | 44 | --------------------------------------------------------------------------------
/packages/optimizers/sgd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 19:42:22 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import sys 8 | 9 | import torch 10 | 11 | from ..loggers.std_logger import STDLogger as logger 12 | from ..config import CONFIG as cfg 13 | 14 | def get(params): 15 | logger.debug('Going to use [SGD] optimizer for training with momentum %.2f, ' 16 | 'dampening %f, weight decay %f %s nesterov' % (cfg.momentum, 17 | cfg.dampening, cfg.weight_decay, 18 | ('with' if cfg.nesterov else 'without'))) 19 | return torch.optim.SGD(params, lr=cfg.base_lr, momentum=cfg.momentum, 20 | weight_decay=cfg.weight_decay, nesterov=cfg.nesterov, 21 | dampening=cfg.dampening) 22 | 23 | def require_args(): 24 | 25 | cfg.add_argument('--momentum', default=0, type=float, 26 | help='momentum factor') 27 | cfg.add_argument('--dampening', default=0, type=float, 28 | help='dampening for momentum') 29 | cfg.add_argument('--nesterov', action='store_true', 30 | help='enables Nesterov momentum') 31 | 32 | from ..register import REGISTER 33 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'sgd', __name__) --------------------------------------------------------------------------------
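Each optimizer module registers itself under a name (rather than a class) and exposes a module-level get(params) that builds the torch optimizer from cfg. A usage sketch, assuming cfg has been parsed with base.yaml (optimizer: sgd) and that the registry resolves the stored module name back to the module itself — register.py is not included in this section, so that resolution step is an assumption:

```
from packages import optimizers
from packages.config import CONFIG as cfg

params = net.parameters()  # net: any torch.nn.Module
optimizer = optimizers.get(cfg.optimizer, instant=True, params=params)
# with the base.yaml defaults this is equivalent to:
# torch.optim.SGD(params, lr=0.03, momentum=0.9, dampening=0,
#                 weight_decay=5e-4, nesterov=True)
```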
/packages/optimizers/rmsprop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 19:42:22 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import sys 8 | 9 | import torch 10 | 11 | from ..loggers.std_logger import STDLogger as logger 12 | from ..config import CONFIG as cfg 13 | 14 | def get(params): 15 | logger.debug('Going to use [RMSprop] optimizer for training with alpha %.2f, ' 16 | 'eps %f, weight decay %f, momentum %f %s centered' % (cfg.alpha, 17 | cfg.eps, cfg.weight_decay, cfg.momentum, 18 | ('with' if cfg.centered else 'without'))) 19 | return torch.optim.RMSprop(params, lr=cfg.base_lr, 20 | alpha=cfg.alpha, eps=cfg.eps, momentum=cfg.momentum, 21 | centered=cfg.centered, weight_decay=cfg.weight_decay) 22 | 23 | def require_args(): 24 | 25 | cfg.add_argument('--alpha', default=0.99, type=float, 26 | help='smoothing constant') 27 | cfg.add_argument('--eps', default=1e-8, type=float, 28 | help=('term added to the denominator to improve' 29 | ' numerical stability')) 30 | cfg.add_argument('--momentum', default=0, type=float, 31 | help='momentum factor') 32 | cfg.add_argument('--centered', action='store_true', 33 | help='whether to compute the centered RMSProp') 34 | 35 | from ..register import REGISTER 36 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'rmsprop', __name__) --------------------------------------------------------------------------------
/packages/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-28 12:20:22 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | def tuple_or_list(target): 8 | """Check if a string contains a tuple or list 9 | 10 | If a string looks like '[a, b, c]' or '(a, b, c)', then return the 11 | parsed tuple or list, otherwise None.
12 | 13 | Arguments: 14 | target {str} -- target string 15 | 16 | Returns: 17 | tuple/list or None -- the parsed value, or None 18 | """ 19 | 20 | # if the target is a tuple or list originally, then return directly 21 | if isinstance(target, tuple) or isinstance(target, list): 22 | return target 23 | 24 | try: 25 | target = eval(target) 26 | if isinstance(target, tuple) or isinstance(target, list): 27 | return target 28 | except: 29 | pass 30 | return None 31 | 32 | def get_valid_size(target): 33 | """get valid size 34 | 35 | if target is a tuple/list or a string of them, then convert and return 36 | if target is a int/float then return 37 | else return None 38 | 39 | Arguments: 40 | target {Number} -- size 41 | """ 42 | ret = tuple_or_list(target) 43 | if ret is not None: 44 | return ret 45 | 46 | try: 47 | ret = int(target) 48 | return ret 49 | except: 50 | pass 51 | 52 | try: 53 | ret = float(target) 54 | return ret 55 | except: 56 | pass 57 | 58 | return None 59 | 60 | --------------------------------------------------------------------------------
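These helpers are how string-valued config entries such as size: (32, 32) or scale: (0.2, 1.) from base.yaml become Python values (note they use eval, so configs are trusted input). A quick sketch of their behaviour, with values taken from the configs above:

```
from packages.utils import tuple_or_list, get_valid_size

tuple_or_list('(0.2, 1.)')   # -> (0.2, 1.0)
tuple_or_list('32')          # -> None (parses, but is not a tuple/list)
get_valid_size('(32, 32)')   # -> (32, 32)
get_valid_size('32')         # -> 32
```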
16 | """ 17 | 18 | @staticmethod 19 | def require_args(): 20 | cfg.add_argument('--means', default='(0.4377, 0.4438, 0.4728)', 21 | type=str, help='channel-wise means') 22 | cfg.add_argument('--stds', default='(0.1201, 0.1231, 0.1052)', 23 | type=str, help='channel-wise stds') 24 | 25 | def __init__(self, root, train=True, transform=None, target_trainsform=None, download=False): 26 | self.train = train 27 | super(SVHNInstance, self).__init__(root, split=('train' if train else 'test'), 28 | transform=transform, target_transform=target_trainsform, download=download) 29 | 30 | def __getitem__(self, index): 31 | """ 32 | Args: 33 | index (int): Index 34 | Returns: 35 | tuple: (image, target) where target is index of the target class. 36 | """ 37 | img, target = self.data[index], int(self.labels[index]) 38 | 39 | # doing this so that it is consistent with all other datasets 40 | # to return a PIL Image 41 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 42 | 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | 46 | if self.target_transform is not None: 47 | target = self.target_transform(target) 48 | 49 | return img, target, index -------------------------------------------------------------------------------- /lib/criterion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-10-07 22:32:15 4 | # @Author : Jiabo (Raymond) Huang (jiabo.huang@qmul.ac.uk) 5 | # @Link : https://github.com/Raymond-sci 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | 11 | class Criterion(nn.Module): 12 | 13 | def __init__(self): 14 | super(Criterion, self).__init__() 15 | 16 | def forward(self, x, y, ANs): 17 | batch_size, _ = x.shape 18 | 19 | # split anchor and instance list 20 | anchor_indexes, instance_indexes = self.__split(y, ANs) 21 | preds = F.softmax(x, 1) 22 | 23 | l_ans = 0. 24 | if anchor_indexes.size(0) > 0: 25 | # compute loss for anchor samples 26 | y_ans = y.index_select(0, anchor_indexes) 27 | y_ans_neighbour = ANs.position.index_select(0, y_ans) 28 | neighbours = ANs.neighbours.index_select(0, y_ans_neighbour) 29 | # p_i = \sum_{j \in \Omega_i} p_{i,j} 30 | x_ans = preds.index_select(0, anchor_indexes) 31 | x_ans_neighbour = x_ans.gather(1, neighbours).sum(1) 32 | x_ans = x_ans.gather(1, y_ans.view(-1, 1)).view(-1) + x_ans_neighbour 33 | # NLL: l = -log(p_i) 34 | l_ans = -1 * torch.log(x_ans).sum(0) 35 | 36 | l_inst = 0. 
37 | if instance_indexes.size(0) > 0: 38 | # compute loss for instance samples 39 | y_inst = y.index_select(0, instance_indexes) 40 | x_inst = preds.index_select(0, instance_indexes) 41 | # p_i = p_{i, i} 42 | x_inst = x_inst.gather(1, y_inst.view(-1, 1)) 43 | # NLL: l = -log(p_i) 44 | l_inst = -1 * torch.log(x_inst).sum(0) 45 | 46 | return (l_inst + l_ans) / batch_size 47 | 48 | def __split(self, y, ANs): 49 | pos = ANs.position.index_select(0, y.view(-1)) 50 | return (pos >= 0).nonzero().view(-1), (pos < 0).nonzero().view(-1) 51 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-28 12:34:35 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from .cifar import CIFAR10Instance, CIFAR100Instance 8 | from .svhn import SVHNInstance 9 | 10 | import torch 11 | 12 | from packages import datasets as cmd_datasets 13 | from packages.config import CONFIG as cfg 14 | 15 | __all__ = ('CIFAR10Instance', 'CIFAR100Instance', 'SVHNInstance') 16 | 17 | def get(name, instant=False): 18 | """ 19 | Get dataset instance according to the dataset string and dataroot 20 | """ 21 | 22 | # get dataset class 23 | dataset_cls = cmd_datasets.get(name) 24 | 25 | if not instant: 26 | return dataset_cls 27 | 28 | # get transforms for training set 29 | transform_train = cmd_datasets.get_transforms('train', cfg.means, cfg.stds) 30 | 31 | # get transforms for test set 32 | transform_test = cmd_datasets.get_transforms('test', cfg.means, cfg.stds) 33 | 34 | # get trainset and trainloader 35 | trainset = dataset_cls(root=cfg.data_root, train=True, download=True, 36 | transform=transform_train) 37 | # filter trainset if necessary 38 | trainloader = torch.utils.data.DataLoader(trainset, 39 | batch_size=cfg.batch_size, shuffle=True, 40 | num_workers=cfg.workers_num) 41 | 42 | # get testset and testloader 43 | testset = dataset_cls(root=cfg.data_root, train=False, download=True, 44 | transform=transform_test) 45 | # filter testset if necessary 46 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, 47 | shuffle=False, num_workers=cfg.workers_num) 48 | 49 | return trainset, trainloader, testset, testloader 50 | 51 | 52 | cmd_datasets.register('cifar10', CIFAR10Instance) 53 | cmd_datasets.register('cifar100', CIFAR100Instance) 54 | cmd_datasets.register('svhn', SVHNInstance) 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /packages/lr_policy/step.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 20:19:00 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from ..loggers.std_logger import STDLogger as logger 8 | from ..config import CONFIG as cfg 9 | 10 | class StepPolicy: 11 | """Caffe style step decay learning rate policy 12 | 13 | Decay learning rate at every `lr-decay-step` steps after the 14 | first `lr-decay-offset` ones at the rate of `lr-decay-rate` 15 | """ 16 | 17 | def __init__(self, *args, **kwargs): 18 | if len(args) + len(kwargs) == 0: 19 | self.__init_by_cfg() 20 | else: 21 | self.__init(*args, **kwargs) 22 | 23 | def __init(self, base_lr, offset, step, rate): 24 | self.base_lr = base_lr 25 | self.offset = offset 26 | self.step = step 27 | 
self.rate = rate 28 | logger.debug('Going to use [step] learning policy for optimization with ' 29 | 'base learning rate %.5f, offset %d, step %d and decay rate %f' % 30 | (base_lr, offset, step, rate)) 31 | 32 | def __init_by_cfg(self): 33 | self.__init(cfg.base_lr, cfg.lr_decay_offset, 34 | cfg.lr_decay_step, cfg.lr_decay_rate) 35 | 36 | @staticmethod 37 | def require_args(): 38 | 39 | cfg.add_argument('--lr-decay-offset', default=0, type=int, 40 | help='learning rate will start to decay at which step') 41 | 42 | cfg.add_argument('--lr-decay-step', default=0, type=int, 43 | help='learning rate will decay at every n round') 44 | 45 | 46 | cfg.add_argument('--lr-decay-rate', default=0.1, type=float, 47 | help='learning rate will decay at what rate') 48 | 49 | def update(self, steps): 50 | """decay learning rate according to current step 51 | 52 | Decay learning rate at a fixed ratio 53 | 54 | Arguments: 55 | steps {int} -- current steps 56 | 57 | Returns: 58 | float -- updated learning rate 59 | """ 60 | 61 | if steps < self.offset: 62 | return self.base_lr 63 | 64 | return self.base_lr * (self.rate ** ((steps - self.offset) // self.step)) 65 | 66 | 67 | from ..register import REGISTER 68 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'step', StepPolicy) 69 | --------------------------------------------------------------------------------
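With the base.yaml settings (base_lr: 0.03, lr_decay_offset: 80, lr_decay_step: 40, lr_decay_rate: 0.1), the step policy behaves as follows — a small worked example:

```
from packages.lr_policy.step import StepPolicy

policy = StepPolicy(0.03, 80, 40, 0.1)
policy.update(10)   # 0.03    (before the offset)
policy.update(80)   # 0.03    (0.03 * 0.1 ** ((80 - 80) // 40))
policy.update(120)  # 0.003   (one full decay step after the offset)
policy.update(160)  # 0.0003
```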
51 | """ 52 | 53 | @staticmethod 54 | def args(parser): 55 | parser.add_argument('--means', default='(0.5071, 0.4866, 0.4409)', 56 | type=str, help='channel-wise means') 57 | parser.add_argument('--stds', default='(0.2009, 0.1984, 0.2023)', 58 | type=str, help='channel-wise stds') 59 | return parser 60 | 61 | base_folder = 'cifar-100-python' 62 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 63 | filename = "cifar-100-python.tar.gz" 64 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 65 | train_list = [ 66 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 67 | ] 68 | 69 | test_list = [ 70 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 71 | ] 72 | meta = { 73 | 'filename': 'meta', 74 | 'key': 'fine_label_names', 75 | 'md5': '7973b15100ade9c7d40fb424638fde48', 76 | } 77 | -------------------------------------------------------------------------------- /packages/lr_policy/multistep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-26 20:19:00 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from ..loggers.std_logger import STDLogger as logger 8 | from ..config import CONFIG as cfg 9 | 10 | class MultiStepPolicy: 11 | """Caffe style step decay learning rate policy 12 | 13 | Decay learning rate at every `lr-decay-step` steps after the 14 | first `lr-decay-offset` ones at the rate of `lr-decay-rate` 15 | """ 16 | 17 | def __init__(self, *args, **kwargs): 18 | if len(args) + len(kwargs) == 0: 19 | self.__init_by_cfg() 20 | else: 21 | self.__init(*args, **kwargs) 22 | 23 | def __init(self, base_lr, schedule): 24 | self.base_lr = base_lr 25 | self.schedule = schedule 26 | logger.debug('Going to use [multistep] learning policy for optimization ' 27 | 'with base learing rate %.5f and schedule from %s' 28 | % (base_lr, schedule)) 29 | 30 | def __init_by_cfg(self): 31 | schedule = cfg.lr_schedule 32 | assert schedule is not None and os.path.exists(schedule), ('Schedule ' 33 | 'file not found: [%s]' % schedule) 34 | self.__init(cfg.base_lr, schedule) 35 | 36 | @staticmethod 37 | def require_args(): 38 | 39 | cfg.add_argument('--lr-schedule', default=None, type=str, 40 | help='learning rate schedule') 41 | 42 | def update(self, steps): 43 | """update learning rate 44 | 45 | Update learning rate according to current steps and schedule file 46 | 47 | Arguments: 48 | steps {int} -- current steps 49 | 50 | Returns: 51 | float -- updated file 52 | """ 53 | 54 | lines = filter(lambda x:not x.startswith('#'), 55 | open(self.schedule, 'r').readlines()) 56 | assert len(lines) > 0, 'Invalid schedule file' 57 | 58 | learning_rate = self.base_lr 59 | for line in lines: 60 | 61 | line = line.split('#')[0] 62 | anchor, target = line.strip().split(':') 63 | 64 | if target.startswith('-'): 65 | lr = -1 66 | else: 67 | lr = float(target) 68 | if steps <= anchor: 69 | learning_rate = lr 70 | else: 71 | break 72 | 73 | return learning_rate 74 | 75 | 76 | from ..register import REGISTER 77 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'multistep', MultiStepPolicy) 78 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-10-12 21:37:13 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import os 8 
| import sys 9 | import shutil 10 | import numpy as np 11 | from datetime import timedelta 12 | 13 | import torch 14 | 15 | from packages.loggers.std_logger import STDLogger as logger 16 | 17 | class AverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | 34 | def adjust_learning_rate(optimizer, lr): 35 | 36 | for param_group in optimizer.param_groups: 37 | param_group['lr'] = lr 38 | 39 | return lr 40 | 41 | def time_progress(elapsed_iters, tot_iters, elapsed_time): 42 | estimated_time = 1. * tot_iters / elapsed_iters * elapsed_time 43 | elapsed_time = timedelta(seconds=elapsed_time) 44 | estimated_time = timedelta(seconds=estimated_time) 45 | return tuple(map(lambda x:str(x).split('.')[0], [elapsed_time, estimated_time])) 46 | 47 | def save_ckpt(state_dict, target, is_best=False): 48 | latest, best = map(lambda x:os.path.join(target, x), ['latest.ckpt', 'best.ckpt']) 49 | # save latest checkpoint 50 | torch.save(state_dict, latest) 51 | # if is best, then copy latest to best 52 | if not is_best: 53 | return 54 | shutil.copyfile(latest, best) 55 | 56 | def traverse(net, loader, transform=None, tencrops=False, device='cpu'): 57 | 58 | bak_transform = loader.dataset.transform 59 | if transform is not None: 60 | loader.dataset.transform = transform 61 | 62 | features = None 63 | labels = torch.zeros(len(loader.dataset)).long().to(device) 64 | 65 | with torch.no_grad(): 66 | for batch_idx, (inputs, targets, indexes) in enumerate(loader): 67 | logger.progress(batch_idx, len(loader), 'processing %d/%d batch...') 68 | 69 | if tencrops: 70 | bs, ncrops, c, h, w = inputs.size() 71 | inputs = inputs.view(-1, c, h, w) 72 | inputs, targets, indexes = (inputs.to(device), targets.to(device), 73 | indexes.to(device)) 74 | 75 | feats = net(inputs) 76 | if tencrops: 77 | feats = torch.squeeze(feats.view(bs, ncrops, -1).mean(1)) 78 | 79 | if features is None: 80 | features = torch.zeros(len(loader.dataset), feats.shape[1]).to(device) 81 | features.index_copy_(0, indexes, feats) 82 | labels.index_copy_(0, indexes, targets) 83 | 84 | loader.dataset.transform = bak_transform 85 | 86 | return features, labels 87 | -------------------------------------------------------------------------------- /packages/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-24 22:16:37 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import os 8 | import argparse 9 | import yaml 10 | import importlib 11 | import traceback 12 | from prettytable import PrettyTable 13 | 14 | from .register import REGISTER 15 | 16 | class _META_(type): 17 | 18 | PARSER = argparse.ArgumentParser(formatter_class= 19 | argparse.ArgumentDefaultsHelpFormatter) 20 | ARGS = dict() 21 | 22 | def require_args(self): 23 | 24 | # args for config file 25 | _META_.PARSER.add_argument('--cfgs', type=str, nargs='*', 26 | help='config files to load') 27 | 28 | def parse(self): 29 | 30 | # collect self args 31 | self.require_args() 32 | 33 | # load default args from config file 34 | known_args, _ = _META_.PARSER.parse_known_args() 35 | self.from_files(known_args.cfgs) 36 | 37 | # collect args 
for packages 38 | for package in REGISTER.get_packages(): 39 | m = importlib.import_module(package) 40 | if hasattr(m, 'require_args'): 41 | m.require_args() 42 | 43 | # re-update default value for new args 44 | self.from_files(known_args.cfgs) 45 | 46 | # parse args 47 | _META_.ARGS = _META_.PARSER.parse_args() 48 | 49 | def from_files(self, files): 50 | 51 | # if no config file is provided, skip 52 | if files is None or len(files) <= 0: 53 | return None 54 | 55 | for file in files: 56 | assert os.path.exists(file), "Config file not found: [%s]" % file 57 | configs = yaml.load(open(file, 'r')) 58 | _META_.PARSER.set_defaults(**configs) 59 | 60 | def get(self, attr, default=None): 61 | if hasattr(_META_.ARGS, attr): 62 | return getattr(_META_.ARGS, attr) 63 | return default 64 | 65 | def yaml(self): 66 | config = {k:v for k,v in sorted(vars(_META_.ARGS).items())} 67 | return yaml.safe_dump(config, default_flow_style=False) 68 | 69 | def __getattr__(self, attr): 70 | try: 71 | return _META_.PARSER.__getattribute__(attr) 72 | except AttributeError: 73 | return _META_.ARGS.__getattribute__(attr) 74 | except: 75 | traceback.print_exc() 76 | exit(-1) 77 | 78 | def __str__(self): 79 | MAX_WIDTH = 20 80 | table = PrettyTable(["#", "Key", "Value", "Default"]) 81 | table.align = 'l' 82 | for i, (k, v) in enumerate(sorted(vars(_META_.ARGS).items())): 83 | v = str(v) 84 | default = str(_META_.PARSER.get_default(k)) 85 | if default == v: 86 | default = '--' 87 | table.add_row([i, k, v[:MAX_WIDTH] + ('...' if len(v) > MAX_WIDTH else ''), default]) 88 | return table.get_string() 89 | 90 | class CONFIG(object): 91 | __metaclass__ = _META_ --------------------------------------------------------------------------------
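CONFIG is a Python 2 metaclass trick: the class body is empty, and every attribute access on the class falls through _META_ to a shared argparse parser and its parsed namespace. A sketch of the intended flow (the actual entry point is main.py, which is not shown in this section, so the call sequence here is an assumption based on session.py below):

```
from packages.config import CONFIG as cfg

cfg.add_argument('--my-arg', default=1, type=int)  # forwarded to the parser
cfg.parse()        # loads defaults from the --cfgs YAML files, then parses argv
print(cfg.my_arg)  # attribute access falls through to the parsed args
print(cfg)         # pretty table of key / value / default
```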
/lib/non_parametric_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-24 22:16:37 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import torch 8 | from torch.autograd import Function 9 | from torch import nn 10 | import math 11 | 12 | from packages.register import REGISTER 13 | from packages.config import CONFIG as cfg 14 | 15 | def require_args(): 16 | # args for non-parametric classifier (npc) 17 | cfg.add_argument('--npc-temperature', default=0.1, type=float, 18 | help='temperature parameter for softmax') 19 | cfg.add_argument('--npc-momentum', default=0.5, type=float, 20 | help='momentum for non-parametric updates') 21 | 22 | class NonParametricClassifierOP(Function): 23 | @staticmethod 24 | def forward(self, x, y, memory, params): 25 | 26 | T = params[0].item() 27 | batchSize = x.size(0) 28 | 29 | # inner product 30 | out = torch.mm(x.data, memory.t()) 31 | out.div_(T) # batchSize * N 32 | 33 | self.save_for_backward(x, memory, y, params) 34 | 35 | return out 36 | 37 | @staticmethod 38 | def backward(self, gradOutput): 39 | x, memory, y, params = self.saved_tensors 40 | batchSize = gradOutput.size(0) 41 | T = params[0].item() 42 | momentum = params[1].item() 43 | 44 | # add temperature 45 | gradOutput.data.div_(T) 46 | 47 | # gradient of linear 48 | gradInput = torch.mm(gradOutput.data, memory) 49 | gradInput.resize_as_(x) 50 | 51 | # update the memory 52 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x) 53 | weight_pos.mul_(momentum) 54 | weight_pos.add_(torch.mul(x.data, 1-momentum)) 55 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5) 56 | updated_weight = weight_pos.div(w_norm) 57 | memory.index_copy_(0, y, updated_weight) 58 | 59 | return gradInput, None, None, None, None 60 | 61 | class NonParametricClassifier(nn.Module): 62 | """Non-parametric Classifier 63 | 64 | Non-parametric Classifier from 65 | "Unsupervised Feature Learning via Non-Parametric Instance Discrimination" 66 | 67 | Extends: 68 | nn.Module 69 | """ 70 | 71 | def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5): 72 | """Non-parametric Classifier initial function 73 | 74 | Initial function for non-parametric classifier 75 | 76 | Arguments: 77 | inputSize {int} -- in-channels dims 78 | outputSize {int} -- out-channels dims 79 | 80 | Keyword Arguments: 81 | T {int} -- distribution temperature (default: {0.05}) 82 | momentum {int} -- memory update momentum (default: {0.5}) 83 | """ 84 | super(NonParametricClassifier, self).__init__() 85 | stdv = 1 / math.sqrt(inputSize) 86 | self.nLem = outputSize 87 | 88 | self.register_buffer('params', 89 | torch.tensor([cfg.npc_temperature, cfg.npc_momentum])) 90 | stdv = 1. / math.sqrt(inputSize/3) 91 | self.register_buffer('memory', torch.rand(outputSize, inputSize) 92 | .mul_(2*stdv).add_(-stdv)) 93 | 94 | def forward(self, x, y): 95 | out = NonParametricClassifierOP.apply(x, y, self.memory, self.params) 96 | return out 97 | 98 | REGISTER.set_package(__name__) 99 | REGISTER.set_class(__name__, 'npc', NonParametricClassifier) 100 | --------------------------------------------------------------------------------
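The non-parametric classifier keeps one memory slot per training image; a forward pass scores a batch of embeddings against all N slots. A minimal sketch of how it plugs together with the pieces above (shapes are the CIFAR10 defaults; the actual wiring lives in main.py, which is not shown here, and cfg is assumed to be parsed so that the --npc-* defaults are available):

```
import torch
from lib.non_parametric_classifier import NonParametricClassifier

N, D = 50000, 128                    # trainset size, --low-dim
npc = NonParametricClassifier(D, N)  # 'memory' buffer: N x D

features = torch.randn(128, D)       # net(inputs); L2-normalised in practice
indexes = torch.arange(128)          # dataset indices from __getitem__
out = npc(features, indexes)         # 128 x N similarities / temperature
```

The backward pass both propagates gradients and refreshes the memory rows of the current batch with a momentum update, as implemented in NonParametricClassifierOP.backward above.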
/docs/README.md: -------------------------------------------------------------------------------- 1 | # AND: Anchor Neighbourhood Discovery 2 | 3 | *Accepted by 36th International Conference on Machine Learning (ICML 2019)*. 4 | 5 | PyTorch implementation of [Unsupervised Deep Learning by Neighbourhood Discovery](https://arxiv.org/abs/1904.11567). 6 | 7 | 8 | 9 | 10 | ## Highlight 11 | + We propose the idea of exploiting local neighbourhoods for unsupervised deep learning. This strategy preserves the capability of clustering for class boundary inference whilst minimising the negative impact of class inconsistency typically encountered in clusters. 12 | + We formulate an *Anchor Neighbourhood Discovery (AND)* approach to progressive unsupervised deep learning. The AND model not only generalises the idea of sample specificity learning, but also additionally considers the originally missing sample-to-sample correlation during model learning by a novel neighbourhood supervision design. 13 | + We further introduce a curriculum learning algorithm to gradually perform neighbourhood discovery for maximising the class consistency of neighbourhoods therefore enhancing the unsupervised learning capability. 14 | 15 | ## Main results 16 | The proposed AND model was evaluated on four object image classification datasets including CIFAR 10/100, SVHN and ImageNet12. Results are shown in the following tables: 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | ## Reproduction 25 | 26 | ### Requirements 27 | Python 2.7 and PyTorch 1.0 are required. Please refer to `/path/to/AND/requirements.yaml` for other necessary modules. The conda environment we used for the experiments can also be rebuilt from it. 28 | 29 | ### Usages 30 | 31 | 1. Clone this repo: `git clone https://github.com/Raymond-sci/AND.git` 32 | 2. Download datasets and store them in `/path/to/AND/data`. (A soft link is recommended to avoid redundant copies of datasets) 33 | 3. To reproduce our reported result of ResNet18 on CIFAR10, please use the following command: `python main.py --cfgs configs/base.yaml configs/cifar10.yaml` 34 | 4. Running on GPUs: the code runs on the CPU by default; use the `--gpus` flag to specify the GPU devices you want to use 35 | 5. To evaluate trained models, use `--resume` to set the path of the generated checkpoint file and use the `--test-only` flag to exit the program after evaluation 36 | 37 | Every time `main.py` is run, a new session is started with the current timestamp as its name, and all the related files are stored in the folder `sessions/timestamp/`, including checkpoints, logs, etc. 38 | 39 | ### Pre-trained model 40 | To play with the pre-trained models, please go to [ResNet18](https://drive.google.com/file/d/1tMopB0iLPaJzw81tqZuXbK6YYAQRLXA-/view?usp=sharing) / [AlexNet](https://drive.google.com/file/d/1SeLi34LxuThcLulBaWViwy3kLYQQWX0l/view?usp=sharing). A few things to note: 41 | + The model is saved in **pytorch** format 42 | + It expects RGB images whose pixel values are normalised with the following mean RGB values `mean=[0.485, 0.456, 0.406]` and std RGB values `std=[0.229, 0.224, 0.225]`. Prior to normalisation the range of the image values must be `[0.0, 1.0]`. 43 | 44 | ## License 45 | This project is licensed under the MIT License. You may find out more [here](./LICENSE). 46 | 47 | ## Reference 48 | If you use this code, please cite the following paper: 49 | 50 | Jiabo Huang, Qi Dong, Shaogang Gong and Xiatian Zhu. "Unsupervised Deep Learning by Neighbourhood Discovery." Proc. ICML (2019). 51 | 52 | ``` 53 | @InProceedings{huang2018and, 54 | title={Unsupervised Deep Learning by Neighbourhood Discovery}, 55 | author={Jiabo Huang, Qi Dong, Shaogang Gong and Xiatian Zhu}, 56 | booktitle={Proceedings of the International Conference on Machine Learning (ICML)}, 57 | year={2019}, 58 | } 59 | ``` --------------------------------------------------------------------------------
/requirements.yaml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - menpo 4 | - pytorch 5 | - conda-forge 6 | - anaconda 7 | - defaults 8 | dependencies: 9 | - backports=1.0=py27_1 10 | - backports.functools_lru_cache=1.5=py27_1 11 | - backports_abc=0.5=py27_0 12 | - blas=1.0=mkl 13 | - cffi=1.11.5=py27he75722e_1 14 | - cloudpickle=0.6.1=py27_0 15 | - cryptography=2.4.2=py27h1ba5d50_0 16 | - cryptography-vectors=2.3.1=py27_0 17 | - cudatoolkit=8.0=3 18 | - cycler=0.10.0=py27_0 19 | - dask-core=0.19.4=py27_0 20 | - dbus=1.13.2=h714fa37_1 21 | - decorator=4.3.0=py27_0 22 | - expat=2.2.6=he6710b0_0 23 | - fontconfig=2.13.0=h9420a91_0 24 | - freetype=2.9.1=h8a8886c_1 25 | - functools32=3.2.3.2=py27_1 26 | - futures=3.2.0=py27_0 27 | - glib=2.56.2=hd408876_0 28 | - gst-plugins-base=1.14.0=hbbd80ab_1 29 | - gstreamer=1.14.0=hb453b48_1 30 | - imageio=2.4.1=py27_0 31 | - intel-openmp=2019.0=118 32 | - kiwisolver=1.0.1=py27hf484d3e_0 33 | - libedit=3.1.20170329=h6b74fdf_2 34 | - libgcc=7.2.0=h69d50b8_2 35 | - libgcc-ng=8.2.0=hdf63c60_1 36 | - libgfortran-ng=7.3.0=hdf63c60_0 37 | - libpng=1.6.36=hbc83047_0 38 | - libstdcxx-ng=8.2.0=hdf63c60_1 39 | - libtiff=4.0.9=he85c1e1_2 40 | - libuuid=1.0.3=h1bed415_2 41 | - libxcb=1.13=h1bed415_1 42 | - libxml2=2.9.8=h26e45fe_1 43 | - matplotlib=2.2.3=py27hb69df0a_0 44 | - mkl=2019.0=118 45 | - mkl_fft=1.0.4=py27h4414c95_1 46 | - mkl_random=1.0.1=py27h4414c95_1 47 | - ncurses=6.1=hf484d3e_0 48 | - networkx=2.2=py27_1 49 | - 
ninja=1.8.2=py27h6bb024c_1 50 | - numpy=1.15.1=py27h1d66e8a_0 51 | - numpy-base=1.15.1=py27h81de0dd_0 52 | - olefile=0.46=py27_0 53 | - openssl=1.1.1=h7b6447c_0 54 | - pcre=8.42=h439df22_0 55 | - pillow=5.2.0=py27heded4f4_0 56 | - pip=10.0.1=py27_0 57 | - pycparser=2.18=py27_1 58 | - pyopenssl=18.0.0=py27_0 59 | - pyparsing=2.2.0=py27_1 60 | - pyqt=5.9.2=py27h05f1152_2 61 | - python=2.7.15=h9bab390_6 62 | - python-dateutil=2.7.3=py27_0 63 | - pytz=2018.5=py27_0 64 | - pywavelets=1.0.1=py27hdd07704_0 65 | - pyyaml=3.13=py27h14c3975_0 66 | - qt=5.9.7=h5867ecd_1 67 | - readline=7.0=h7b6447c_5 68 | - scikit-image=0.14.0=py27hf484d3e_1 69 | - scikit-learn=0.19.2=py27h4989274_0 70 | - scipy=1.1.0=py27hfa4b5c9_1 71 | - setuptools=40.2.0=py27_0 72 | - singledispatch=3.4.0.3=py27h9bcb476_0 73 | - sip=4.19.8=py27hf484d3e_0 74 | - six=1.11.0=py27_1 75 | - sqlite=3.26.0=h7b6447c_0 76 | - subprocess32=3.5.2=py27h14c3975_0 77 | - tk=8.6.8=hbc83047_0 78 | - toolz=0.9.0=py27_0 79 | - tornado=5.1=py27h14c3975_0 80 | - wheel=0.31.1=py27_0 81 | - xz=5.2.4=h14c3975_4 82 | - yaml=0.1.7=h96e3832_1 83 | - asn1crypto=0.24.0=py27_1003 84 | - chardet=3.0.4=py27_3 85 | - dominate=2.3.1=py_1 86 | - enum34=1.1.6=py27_1001 87 | - idna=2.7=py27_1002 88 | - ipaddress=1.0.22=py_1 89 | - libsodium=1.0.16=h470a237_1 90 | - prettytable=0.7.2=py_2 91 | - pysocks=1.6.8=py27_1002 92 | - python-lmdb=0.92=py27_0 93 | - pyzmq=17.1.2=py27hae99301_0 94 | - requests=2.19.1=py27_1 95 | - torchfile=0.1.0=py_0 96 | - urllib3=1.23=py27_1 97 | - visdom=0.1.8.5=0 98 | - websocket-client=0.53.0=py27_0 99 | - zeromq=4.2.5=hfc679d8_6 100 | - ca-certificates=2018.12.5=0 101 | - certifi=2018.11.29=py27_0 102 | - cython=0.29.7=py27he6710b0_0 103 | - icu=58.2=h9c2bf20_1 104 | - jpeg=9b=h024ee3a_2 105 | - libffi=3.2.1=hd88cf55_4 106 | - zlib=1.2.11=ha838bed_2 107 | - opencv=2.4.11=nppy27_0 108 | - cuda80=1.0=h205658b_0 109 | - cuda91=1.0=h4c16780_0 110 | - cuda92=1.0=0 111 | - faiss-gpu=1.4.0=py27_cuda8.0.61_1 112 | - pytorch=1.0.1=py2.7_cuda8.0.61_cudnn7.1.2_2 113 | - torchvision=0.2.2=py_3 114 | - pip: 115 | - cython-based-reid-evaluation-code==0.0.0 116 | - dask==0.19.4 117 | - faiss==1.4.0 118 | - humanfriendly==4.18 119 | - lmdb==0.92 120 | - monotonic==1.5 121 | - protobuf==3.7.0 122 | - tensorboardx==1.6 123 | - torch==1.0.1.post2 124 | prefix: /import/sgg-homes/jh327/Applications/conda/envs/pytorch 125 | 126 | -------------------------------------------------------------------------------- /packages/session.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-27 09:38:53 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import os 8 | import time 9 | 10 | import torch 11 | import importlib 12 | import numpy as np 13 | 14 | from loggers.std_logger import STDLogger as logger 15 | from register import REGISTER 16 | from config import CONFIG as cfg 17 | 18 | REGISTER.set_package(__name__) 19 | 20 | def require_args(): 21 | 22 | # timestamp 23 | stt = time.strftime('%Y%m%d-%H%M%S', time.gmtime()) 24 | tt = int(time.time()) 25 | 26 | cfg.add_argument('--session', default=stt, type=str, 27 | help='session name (default: %s)' % stt) 28 | cfg.add_argument('--sess-dir', default='sessions', type=str, 29 | help='directory to store session. (default: sessions)') 30 | cfg.add_argument('--print-args', action='store_true', 31 | help='do nothing but print all args. 
(default: False)') 32 | cfg.add_argument('--seed', default=tt, type=int, 33 | help='session random seed. (default: %d)' % tt) 34 | cfg.add_argument('--brief', action='store_true', 35 | help='print log with priority higher than debug. ' 36 | '(default: False)') 37 | cfg.add_argument('--debug', action='store_true', 38 | help='if debugging, no log or checkpoint files will be stored. ' 39 | '(default: False)') 40 | cfg.add_argument('--gpus', default='', type=str, 41 | help='available gpu list. (default: \'\')') 42 | cfg.add_argument('--resume', default=None, type=str, 43 | help='path to resume session. (default: None)') 44 | cfg.add_argument('--restart', action='store_true', 45 | help='load session status and start a new one. ' 46 | '(default: False)') 47 | 48 | def run(main): 49 | 50 | # import main module 51 | main = importlib.import_module(main) 52 | # parse args 53 | main.require_args() 54 | cfg.parse() 55 | 56 | # setup session according to args 57 | setup() 58 | 59 | # run main function 60 | main.main() 61 | 62 | def setup(): 63 | """ 64 | set up common environment for training 65 | """ 66 | 67 | # print args 68 | if not cfg.brief: 69 | print cfg 70 | # exit if require to print args only 71 | if cfg.print_args: 72 | exit(0) 73 | 74 | # if not verbose, set log level to info 75 | logger.setup(logger.INFO if cfg.brief else logger.DEBUG) 76 | 77 | logger.info('Start to setup session') 78 | 79 | # fix random seeds 80 | torch.manual_seed(cfg.seed) 81 | torch.cuda.manual_seed_all(cfg.seed) 82 | np.random.seed(cfg.seed) 83 | 84 | # set visible gpu devices at main function 85 | logger.info('Visible gpu devices are: %s' % cfg.gpus) 86 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus 87 | 88 | # setup session name and session path 89 | if cfg.resume and cfg.resume.strip() != '' and not cfg.restart: 90 | assert os.path.exists(cfg.resume), ("Resume file not " 91 | "found: %s" % cfg.resume) 92 | ckpt = torch.load(cfg.resume) 93 | if 'session' in ckpt: 94 | cfg.session = ckpt['session'] 95 | cfg.sess_dir = os.path.join(cfg.sess_dir, cfg.session) 96 | logger.info('Current session name: %s' % cfg.session) 97 | 98 | # setup checkpoint dir 99 | cfg.ckpt_dir = os.path.join(cfg.sess_dir, 'checkpoint') 100 | if not os.path.exists(cfg.ckpt_dir) and not cfg.debug: 101 | os.makedirs(cfg.ckpt_dir) 102 | 103 | # redirect logs to file 104 | if cfg.log_file and not cfg.debug: 105 | logger.setup(to_file=os.path.join(cfg.sess_dir, 'log.txt')) 106 | 107 | # setup tfb log dir 108 | cfg.tfb_dir = os.path.join(cfg.sess_dir, 'tfboard') 109 | if not os.path.exists(cfg.tfb_dir) and not cfg.debug: 110 | os.makedirs(cfg.tfb_dir) 111 | 112 | # store options at log directory 113 | if not cfg.debug: 114 | with open(os.path.join(cfg.sess_dir, 'config.yaml'), 'w') as out: 115 | out.write(cfg.yaml() + '\n') 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /lib/protocols.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | from lib.utils import AverageMeter, traverse 4 | import sys 5 | 6 | from packages.register import REGISTER 7 | from packages.loggers.std_logger import STDLogger as logger 8 | 9 | def NN(net, npc, trainloader, testloader, K=0, sigma=0.1, 10 | recompute_memory=False, device='cpu'): 11 | 12 | # switch model to evaluation mode 13 | net.eval() 14 | 15 | # tracking variables 16 | correct = 0. 
17 | total = 0 18 | 19 | trainFeatures = npc.memory.t() # memory is N x D; transpose so mm below yields B x N similarities 20 | trainLabels = torch.LongTensor(trainloader.dataset.labels).to(device) 21 | 22 | # recompute features for training samples 23 | if recompute_memory: 24 | trainFeatures, trainLabels = traverse(net, trainloader, 25 | testloader.dataset.transform, device=device) 26 | trainFeatures = trainFeatures.t() 27 | 28 | # start to evaluate 29 | with torch.no_grad(): 30 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 31 | logger.progress(batch_idx, len(testloader), 'processing %d/%d batch...') 32 | inputs, targets = inputs.to(device), targets.to(device) 33 | batchSize = inputs.size(0) 34 | 35 | # forward 36 | features = net(inputs) 37 | 38 | # cosine similarity 39 | dist = torch.mm(features, trainFeatures) 40 | 41 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True) 42 | candidates = trainLabels.view(1,-1).expand(batchSize, -1) 43 | retrieval = torch.gather(candidates, 1, yi) 44 | 45 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1) 46 | yd = yd.narrow(1, 0, 1) 47 | 48 | total += targets.size(0) 49 | correct += retrieval.eq(targets.data).sum().item() 50 | 51 | return correct/total 52 | 53 | def kNN(net, npc, trainloader, testloader, K=200, sigma=0.1, 54 | recompute_memory=False, device='cpu'): 55 | 56 | # set the model to evaluation mode 57 | net.eval() 58 | 59 | # tracking variables 60 | total = 0 61 | 62 | trainFeatures = npc.memory.t() # memory is N x D; transpose so mm below yields B x N similarities 63 | trainLabels = torch.LongTensor(trainloader.dataset.labels).to(device) 64 | 65 | # recompute features for training samples 66 | if recompute_memory: 67 | trainFeatures, trainLabels = traverse(net, trainloader, 68 | testloader.dataset.transform, device=device) 69 | trainFeatures = trainFeatures.t() 70 | C = trainLabels.max() + 1 71 | 72 | # start to evaluate 73 | top1 = 0. 74 | top5 = 0. 75 | with torch.no_grad(): 76 | retrieval_one_hot = torch.zeros(K, C.item()).to(device) 77 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader): 78 | logger.progress(batch_idx, len(testloader), 'processing %d/%d batch...') 79 | 80 | batchSize = inputs.size(0) 81 | targets, inputs = targets.to(device), inputs.to(device) 82 | 83 | # forward 84 | features = net(inputs) 85 | 86 | # cosine similarity 87 | dist = torch.mm(features, trainFeatures) 88 | 89 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True) 90 | candidates = trainLabels.view(1,-1).expand(batchSize, -1) 91 | retrieval = torch.gather(candidates, 1, yi) 92 | 93 | retrieval_one_hot.resize_(batchSize * K, C).zero_() 94 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 95 | yd_transform = yd.clone().div_(sigma).exp_() 96 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C), 97 | yd_transform.view(batchSize, -1, 1)), 1) 98 | _, predictions = probs.sort(1, True) 99 | 100 | # Find which predictions match the target 101 | correct = predictions.eq(targets.data.view(-1,1)) 102 | 103 | top1 = top1 + correct.narrow(1,0,1).sum().item() 104 | top5 = top5 + correct.narrow(1,0,5).sum().item() 105 | 106 | total += targets.size(0) 107 | 108 | return top1/total 109 | 110 | def get(name): 111 | return REGISTER.get_class(__name__, name) 112 | 113 | REGISTER.set_package(__name__) 114 | REGISTER.set_class(__name__, 'knn', kNN) 115 | REGISTER.set_class(__name__, 'nn', NN) --------------------------------------------------------------------------------
/models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 
2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from lib.normalize import Normalize 13 | 14 | from torch.autograd import Variable 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion*planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, low_dim=128): 72 | super(ResNet, self).__init__() 73 | self.in_planes = 64 74 | 75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(64) 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(512*block.expansion, low_dim) 82 | self.l2norm = Normalize(2) 83 | 84 | def _make_layer(self, block, planes, num_blocks, stride): 85 | strides = [stride] + [1]*(num_blocks-1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_planes, planes, stride)) 89 | self.in_planes = planes * block.expansion 90 | return nn.Sequential(*layers) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = self.layer1(out) 95 | out = self.layer2(out) 96 | out = self.layer3(out) 97 | out = self.layer4(out) 98 | out = F.avg_pool2d(out, 4) 99 | out = out.view(out.size(0), -1) 100 | out = 
self.linear(out) 101 | out = self.l2norm(out) 102 | return out 103 | 104 | 105 | def ResNet18(low_dim=128): 106 | return ResNet(BasicBlock, [2,2,2,2], low_dim) 107 | 108 | def ResNet34(low_dim=128): 109 | return ResNet(BasicBlock, [3,4,6,3], low_dim) 110 | 111 | def ResNet50(low_dim=128): 112 | return ResNet(Bottleneck, [3,4,6,3], low_dim) 113 | 114 | def ResNet101(low_dim=128): 115 | return ResNet(Bottleneck, [3,4,23,3], low_dim) 116 | 117 | def ResNet152(low_dim=128): 118 | return ResNet(Bottleneck, [3,8,36,3], low_dim) 119 | 120 | 121 | def test(): 122 | net = ResNet18() 123 | y = net(Variable(torch.randn(1,3,32,32))) 124 | print(y.size()) 125 | 126 | # test() 127 | -------------------------------------------------------------------------------- /packages/register.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-25 15:08:10 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import sys 8 | 9 | class REGISTER: 10 | """Singleton Register Class 11 | 12 | This class is used to register all modules in the whole package 13 | 14 | Variables: 15 | PACKAGES_2_CLASSES {dict} -- record the registered modules 16 | """ 17 | 18 | PACKAGES_2_CLASSES = dict() 19 | 20 | @staticmethod 21 | def set_package(package): 22 | """set package 23 | 24 | This function is called when a new package has been created 25 | 26 | Arguments: 27 | package {str} -- package name 28 | """ 29 | if not package in REGISTER.PACKAGES_2_CLASSES: 30 | REGISTER.PACKAGES_2_CLASSES[package] = dict() 31 | 32 | @staticmethod 33 | def get_packages(): 34 | """get packages list 35 | 36 | Get the packages list stored in REGISTER.PACKAGES_2_CLASSES 37 | 38 | Returns: 39 | list -- packages list 40 | """ 41 | return REGISTER.PACKAGES_2_CLASSES.keys() 42 | 43 | @staticmethod 44 | def get_package_name(module): 45 | """get the package name of a module 46 | 47 | Get the package name of the target module by removing the text after 48 | the last dot in the provided name 49 | 50 | Arguments: 51 | module {str} -- target module name 52 | """ 53 | return '.'.join(module.split('.')[:-1]) 54 | 55 | @staticmethod 56 | def is_package_registered(package): 57 | """Check whether a package has been registered 58 | 59 | Check whether a package has been registered according to whether its 60 | name occurs in PACKAGES_2_CLASSES 61 | 62 | Arguments: 63 | package {str} -- target package name 64 | 65 | Returns: 66 | bool -- indicating whether package has been registered 67 | """ 68 | return package in REGISTER.PACKAGES_2_CLASSES 69 | 70 | @staticmethod 71 | def _check_package(package): 72 | assert package in REGISTER.PACKAGES_2_CLASSES, ('No package named [%s] ' 73 | 'has been registered' % package) 74 | 75 | @staticmethod 76 | def get_classes(package): 77 | """get all classes on the target package 78 | 79 | Get all registered classes on the target package according to the 80 | PACKAGES_2_CLASSES dictionary 81 | 82 | Arguments: 83 | package {str} -- package name used when registering 84 | 85 | Returns: 86 | dict -- {'name1' : class1, 'name2' : class2, ...} 87 | """ 88 | 89 | # make sure the packages has been registered 90 | REGISTER._check_package(package) 91 | 92 | return REGISTER.PACKAGES_2_CLASSES[package] 93 | 94 | @staticmethod 95 | def is_class_registered(package, name): 96 | return (package in REGISTER.PACKAGES_2_CLASSES and 97 | name in REGISTER.PACKAGES_2_CLASSES[package]) 98 |
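    # Usage sketch, mirroring how lib/protocols.py registers and then resolves
    # its classes (the calls below are copied from that module; illustrative only):
    #
    #   REGISTER.set_package('lib.protocols')
    #   REGISTER.set_class('lib.protocols', 'knn', kNN)
    #   protocol = REGISTER.get_class('lib.protocols', 'knn')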
@staticmethod 100 | def get_class(package, name): 101 | """get class 102 | 103 | This function is used to get the class named `name` from 104 | package `package` 105 | 106 | Arguments: 107 | package {str} -- package name, should be the same as the one when registering 108 | name {str} -- class name, should be the same as the one when registering 109 | 110 | Returns: 111 | [Object] -- found class 112 | """ 113 | 114 | # make sure the package has been registered 115 | REGISTER._check_package(package) 116 | 117 | pack = REGISTER.PACKAGES_2_CLASSES[package] 118 | 119 | assert name in pack, ('No class named [%s] has been registered ' 120 | 'in package [%s]' % (name, package)) 121 | 122 | return pack[name] 123 | 124 | @staticmethod 125 | def set_class(package, name, cls): 126 | """set class 127 | 128 | This function is used to register a class with a specific name in 129 | the target package 130 | 131 | Arguments: 132 | package {str} -- target package name 133 | name {str} -- class name which will be used when query 134 | cls {Object} -- class object 135 | """ 136 | 137 | # get the corresponding package name 138 | # package = '.'.join(package.split('.')[:-1]) 139 | if not package in REGISTER.PACKAGES_2_CLASSES: 140 | REGISTER.set_package(package) 141 | 142 | # if the `cls` arg is a string, resolve the module object from sys.modules 143 | if isinstance(cls, str): 144 | cls = sys.modules[cls] 145 | 146 | REGISTER.PACKAGES_2_CLASSES[package][name] = cls -------------------------------------------------------------------------------- /packages/loggers/std_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-11-11 15:08:06 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import sys 8 | import logging 9 | 10 | from ..config import CONFIG as cfg 11 | 12 | class ColoredFormatter(logging.Formatter): 13 | """Colored Formatter for logging module 14 | 15 | different logging levels use different colors when printing 16 | 17 | Extends: 18 | logging.Formatter 19 | 20 | Variables: 21 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE {[Number]} -- [default colors] 22 | RESET_SEQ {str} -- [Sequence end flag] 23 | COLOR_SEQ {str} -- [color sequence start flag] 24 | BOLD_SEQ {str} -- [bold sequence start flag] 25 | COLORS {dict} -- [logging level to color dictionary] 26 | """ 27 | 28 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) 29 | RESET_SEQ = "\033[0m" 30 | COLOR_SEQ = "\033[%dm" 31 | BOLD_SEQ = "\033[1m" 32 | COLORS = { 33 | 'WARNING': YELLOW, 34 | 'DEBUG': GREEN, 35 | 'CRITICAL': BLUE, 36 | 'ERROR': RED 37 | } 38 | 39 | def __init__(self, *args, **kwargs): 40 | logging.Formatter.__init__(self, *args, **kwargs) 41 | 42 | def format(self, record): 43 | msg = logging.Formatter.format(self, record) 44 | levelname = record.levelname 45 | msg_color = msg 46 | if levelname in ColoredFormatter.COLORS: 47 | msg_color = (ColoredFormatter.COLOR_SEQ % 48 | (30 + ColoredFormatter.COLORS[levelname]) + msg + 49 | ColoredFormatter.RESET_SEQ) 50 | return msg_color 51 | 52 | class STDLogger: 53 | ''' 54 | static class for logging 55 | call setup() first to set log level and then call info/debug/error/warn 56 | to print log msg 57 | ''' 58 | 59 | LOGGER = None 60 | 61 | INFO = logging.INFO 62 | DEBUG = logging.DEBUG 63 | C_LEVEL = logging.DEBUG 64 | 65 | FMT_GENERAL = logging.Formatter('[%(levelname)s][%(asctime)s]\t%(message)s', 66 |
datefmt='%Y-%m-%d %H:%M:%S') 67 | FMT_COLOR = ColoredFormatter('[%(levelname)s][%(asctime)s]\t%(message)s', 68 | datefmt='%Y-%m-%d %H:%M:%S') 69 | 70 | @staticmethod 71 | def require_args(): 72 | 73 | cfg.add_argument('--log-file', action='store_true', 74 | help='store log to file. (default: False)') 75 | 76 | @staticmethod 77 | def setup(level=None, to_file=None): 78 | # if level is not provided, don't change it 79 | level = level if level is not None else STDLogger.C_LEVEL 80 | 81 | if STDLogger.LOGGER is None: 82 | STDLogger.LOGGER = logging.getLogger(__name__) 83 | STDLogger.LOGGER.propagate = False 84 | # declare and set console handler 85 | console_handler = logging.StreamHandler() 86 | console_handler.setFormatter(STDLogger.FMT_COLOR) 87 | STDLogger.LOGGER.addHandler(console_handler) 88 | 89 | if to_file is not None: 90 | # declare and set file handler 91 | STDLogger.LOGGER.debug('Log will be stored in %s' % to_file) 92 | file_handler = logging.FileHandler(to_file) 93 | file_handler.setFormatter(STDLogger.FMT_GENERAL) 94 | STDLogger.LOGGER.addHandler(file_handler) 95 | 96 | STDLogger.LOGGER.setLevel(level) 97 | STDLogger.C_LEVEL = level 98 | 99 | @staticmethod 100 | def check(): 101 | if not isinstance(STDLogger.LOGGER, logging.Logger): 102 | STDLogger.setup() 103 | # raise ValueError('Call logger.setup(level) to initialize') 104 | 105 | @staticmethod 106 | def info(*args, **kwargs): 107 | STDLogger.check() 108 | STDLogger.erase() 109 | return STDLogger.LOGGER.info(*args, **kwargs) 110 | 111 | @staticmethod 112 | def debug(*args, **kwargs): 113 | STDLogger.check() 114 | STDLogger.erase() 115 | return STDLogger.LOGGER.debug(*args, **kwargs) 116 | 117 | @staticmethod 118 | def error(*args, **kwargs): 119 | STDLogger.check() 120 | STDLogger.erase() 121 | return STDLogger.LOGGER.error(*args, **kwargs) 122 | 123 | @staticmethod 124 | def warn(*args, **kwargs): 125 | STDLogger.check() 126 | STDLogger.erase() 127 | return STDLogger.LOGGER.warn(*args, **kwargs) 128 | 129 | @staticmethod 130 | def erase_lines(n=1): 131 | for _ in xrange(n): 132 | sys.stdout.write('\x1b[1A') 133 | sys.stdout.write('\x1b[2K') 134 | sys.stdout.flush() 135 | 136 | @staticmethod 137 | def go_up(): 138 | sys.stdout.write('\x1b[1A') 139 | sys.stdout.flush() 140 | 141 | @staticmethod 142 | def erase(): 143 | sys.stdout.write('\x1b[2K') 144 | sys.stdout.flush() 145 | 146 | @staticmethod 147 | def progress(current, total, msg='processing %d/%d item...'): 148 | STDLogger.erase() 149 | print msg % (current, total) 150 | STDLogger.go_up() 151 | 152 | from ..register import REGISTER 153 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'STDLogger', STDLogger) -------------------------------------------------------------------------------- /packages/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2019-01-25 14:58:24 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | from ..register import REGISTER 8 | from ..loggers.std_logger import STDLogger as logger 9 | from ..utils import tuple_or_list, get_valid_size 10 | from ..config import CONFIG as cfg 11 | 12 | import transforms as custom_transforms 13 | import torchvision.transforms as transforms 14 | 15 | def require_args(): 16 | """register all args for dataset objects and dataloaders 17 | 18 | Note: 19 | args are added to the global CONFIG object directly 20 | """ 21 | 22 | known_args, _ = cfg.parse_known_args()
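    # Note: this early parse_known_args() pass only needs to recover --dataset,
    # so that the registered dataset class (looked up at the end of this
    # function) can contribute its own arguments before the final cfg.parse()
    # performed in packages/session.py's run().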
23 | 24 | # basic args for all datasets 25 | cfg.add_argument('--data-root', default=None, type=str, 26 | help='root path to dataset') 27 | cfg.add_argument('--resize', default="256", type=str, 28 | help='resize inputs to this size. (default: 256)') 29 | cfg.add_argument('--size', default="224", type=str, 30 | help='crop inputs to this size. (default: 224)') 31 | cfg.add_argument('--scale', default=None, type=str, 32 | help='scale range for random resized crop. (default: None)') 33 | cfg.add_argument('--ratio', default="(0.75, 1.3333333333333)", type=str, 34 | help='ratio range for random resized crop. (default: (0.75, 1.3333))') 35 | cfg.add_argument('--colorjitter', default=None, type=str, 36 | help='color jitters for input. (default: None)') 37 | cfg.add_argument('--random-grayscale', default=0, type=float, 38 | help='probability of converting input to grayscale. (default: 0)') 39 | cfg.add_argument('--random-horizontal-flip', action='store_true', 40 | help='randomly flip input horizontally. (default: False)') 41 | 42 | # basic args for all dataloader 43 | cfg.add_argument('--batch-size', default=128, type=int, 44 | help='batch size for input data. (default: 128)') 45 | cfg.add_argument('--workers-num', default=4, type=int, 46 | help='number of workers being used to load data. (default: 4)') 47 | 48 | # get args for datasets 49 | if (REGISTER.is_package_registered(__name__) and 50 | REGISTER.is_class_registered(__name__, known_args.dataset)): 51 | 52 | dataset = get(known_args.dataset) 53 | 54 | if hasattr(dataset, 'require_args'): 55 | dataset.require_args() 56 | 57 | def get(name): 58 | return REGISTER.get_class(__name__, name) 59 | 60 | def get_transforms(stage='train', means=None, stds=None): 61 | 62 | transform = [] 63 | 64 | stage = stage.lower() 65 | assert stage in ['train', 'test'], ('arg [stage]' 66 | ' should be one of ["train", "test"]') 67 | resize = get_valid_size(cfg.resize) 68 | size = get_valid_size(cfg.size) 69 | if stage == 'train': 70 | # size transform 71 | scale = tuple_or_list(cfg.scale) 72 | if scale: 73 | ratio = tuple_or_list(cfg.ratio) 74 | logger.debug('Training samples will be randomly resized and cropped ' 75 | 'with size %s, scale %s and ratio %s' 76 | % (size, scale, ratio)) 77 | transform.append(custom_transforms.RandomResizedCrop( 78 | size=size, scale=scale, ratio=ratio)) 79 | else: 80 | logger.debug('Training samples will be resized to %s and then ' 81 | 'randomly cropped to %s' % (resize, size)) 82 | transform.append(transforms.Resize(size=resize)) 83 | transform.append(transforms.RandomCrop(size)) 84 | # color jitter transform 85 | colorjitter = tuple_or_list(cfg.colorjitter) 86 | if colorjitter is not None: 87 | logger.debug('Training samples will be enhanced with color jitter ' 88 | 'using args: %s' % (colorjitter,)) 89 | transform.append(transforms.ColorJitter(*colorjitter)) 90 | 91 | # gray scale 92 | if cfg.random_grayscale > 0: 93 | logger.debug('Training samples will be randomly converted to ' 94 | 'grayscale with probability %.2f' % cfg.random_grayscale) 95 | transform.append(transforms.RandomGrayscale( 96 | p=cfg.random_grayscale)) 97 | 98 | # random horizontal flip 99 | if cfg.random_horizontal_flip: 100 | logger.debug('Training samples will be randomly flipped horizontally') 101 | transform.append(transforms.RandomHorizontalFlip()) 102 | 103 | else: 104 | logger.debug('Testing samples will be resized to %s and then center ' 105 | 'cropped to %s' % (resize, size)) 106 | transform.extend([transforms.Resize(resize), 107 | transforms.CenterCrop(size)]) 108 | 109 | # to tensor 110 |
transform.append(transforms.ToTensor()) 111 | 112 | # normalize 113 | means = tuple_or_list(means) 114 | stds = tuple_or_list(stds) 115 | if not (means is None or stds is None): 116 | logger.debug('Samples will be normalized with means: %s and stds: %s' % 117 | (means, stds)) 118 | transform.append(transforms.Normalize(means, stds)) 119 | else: 120 | logger.debug('Input images will not be normalized') 121 | 122 | return transforms.Compose(transform) 123 | 124 | def register(name, cls): 125 | REGISTER.set_class(__name__, name, cls) 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /lib/ans_discovery.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-10-12 21:37:13 4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk) 5 | # @Link : github.com/Raymond-sci 6 | 7 | import sys 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from packages.config import CONFIG as cfg 14 | from packages.register import REGISTER 15 | from packages.loggers.std_logger import STDLogger as logger 16 | 17 | 18 | def require_args(): 19 | cfg.add_argument('--ANs-select-rate', default=0.25, type=float, 20 | help='ANs selection rate at each round. (default: 0.25)') 21 | cfg.add_argument('--ANs-size', default=1, type=int, 22 | help='ANs size, excluding the anchor. (default: 1)') 23 | 24 | class ANsDiscovery(nn.Module): 25 | """Discover ANs 26 | 27 | Discover ANs according to the current round, select_rate and, most 28 | importantly, each sample's entropy 29 | """ 30 | 31 | def __init__(self, nsamples): 32 | """Object used to discover ANs 33 | 34 | Discover ANs according to the total amount of samples, the ANs 35 | selection rate and the ANs size 36 | 37 | Arguments: 38 | nsamples {int} -- total number of samples 39 | 40 | Note: 41 | the selection rate and ANs size are read from the global 42 | config (cfg.ANs_select_rate and cfg.ANs_size) rather than 43 | being passed in 44 | """ 45 | super(ANsDiscovery, self).__init__() 46 | 47 | # not going to use ``register_buffer'' as 48 | # they are determined by configs 49 | self.select_rate = cfg.ANs_select_rate 50 | self.ANs_size = cfg.ANs_size 51 | # number of samples 52 | self.register_buffer('samples_num', torch.tensor(nsamples)) 53 | # indexes list of anchor samples 54 | self.register_buffer('anchor_indexes', torch.LongTensor([])) 55 | # indexes list of instance samples 56 | self.register_buffer('instance_indexes', torch.arange(nsamples).long()) 57 | # anchor samples' and instance samples' position 58 | self.register_buffer('position', -1 * torch.arange(nsamples).long() - 1) 59 | # anchor samples' neighbours 60 | self.register_buffer('neighbours', torch.LongTensor([])) 61 | # each sample's entropy 62 | self.register_buffer('entropy', torch.FloatTensor(nsamples)) 63 | # consistency 64 | self.register_buffer('consistency', torch.tensor(0.)) 65 | 66 | def get_ANs_num(self, round): 67 | """Get number of ANs 68 | 69 | Get number of ANs at target round according to the select rate 70 | 71 | Arguments: 72 | round {int} -- target round 73 | 74 | Returns: 75 | int -- number of ANs 76 | """ 77 | return int(self.samples_num.float() * self.select_rate * round) 78 | 79 | def update(self, round, npc, cheat_labels=None): 80 | """Update ANs 81 | 82 | Discover new ANs and update `anchor_indexes`, `instance_indexes` and 83 | `neighbours` 84 | 85 | Arguments: 86 | round {int} -- target
round 87 | npc {Module} -- non-parametric classifier 88 | cheat_labels {list} -- used to compute consistency of chosen ANs only 89 | 90 | Returns: 91 | number -- [updated consistency] 92 | """ 93 | with torch.no_grad(): 94 | batch_size = 100 95 | ANs_num = self.get_ANs_num(round) 96 | logger.debug('Going to choose %d samples as anchors' % ANs_num) 97 | features = npc.memory 98 | 99 | logger.debug('Start to compute each sample\'s entropy') 100 | for start in xrange(0, self.samples_num, batch_size): 101 | logger.progress(start, self.samples_num, 'processing %d/%d samples...') 102 | 103 | end = start + batch_size 104 | end = min(end, self.samples_num) 105 | 106 | preds = F.softmax(npc(features[start:end], None), 1) 107 | self.entropy[start:end] = -(preds * preds.log()).sum(1) 108 | 109 | logger.debug('Entropy computation done: max(%.2f), min(%.2f), mean(%.2f)' 110 | % (self.entropy.max(), self.entropy.min(), self.entropy.mean())) 111 | 112 | # get the anchor list and instance list according to the computed 113 | # entropy 114 | self.anchor_indexes = self.entropy.topk(ANs_num, largest=False)[1] 115 | self.instance_indexes = (torch.ones_like(self.position) 116 | .scatter_(0, self.anchor_indexes, 0) 117 | .nonzero().view(-1)) 118 | anchor_entropy = self.entropy.index_select(0, self.anchor_indexes) 119 | instance_entropy = self.entropy.index_select(0, self.instance_indexes) 120 | if self.anchor_indexes.size(0) > 0: 121 | logger.debug('Entropies of anchor samples: max(%.2f), ' 122 | 'min(%.2f), mean(%.2f)' % (anchor_entropy.max(), 123 | anchor_entropy.min(), anchor_entropy.mean())) 124 | if self.instance_indexes.size(0) > 0: 125 | logger.debug('Entropies of instance samples: max(%.2f), ' 126 | 'min(%.2f), mean(%.2f)' % (instance_entropy.max(), 127 | instance_entropy.min(), instance_entropy.mean())) 128 | 129 | # get position 130 | # if anchor sample x has index i and position j, then 131 | # sample x_i is the j-th anchor sample at the current round 132 | # if instance sample x has index i and position j, then 133 | # sample x_i is the (-j-1)-th instance sample at the current round 134 | logger.debug('Start to get the position of both anchor and ' 135 | 'instance samples') 136 | instance_cnt = 0 137 | for i in xrange(self.samples_num): 138 | logger.progress(i, self.samples_num, 'processing %d/%d samples...') 139 | 140 | # for anchor samples 141 | if (i == self.anchor_indexes).any(): 142 | self.position[i] = (self.anchor_indexes == i).max(0)[1] 143 | continue 144 | # for instance samples 145 | instance_cnt -= 1 146 | self.position[i] = instance_cnt 147 | 148 | logger.debug('Start to find %d neighbours for each anchor sample' 149 | % self.ANs_size) 150 | anchor_features = features.index_select(0, self.anchor_indexes) 151 | self.neighbours = (torch.LongTensor(ANs_num, self.ANs_size) 152 | .to(cfg.device)) 153 | for start in xrange(0, ANs_num, batch_size): 154 | logger.progress(start, ANs_num, 'processing %d/%d samples...') 155 | 156 | end = start + batch_size 157 | end = min(end, ANs_num) 158 | 159 | sims = torch.mm(anchor_features[start:end], features.t()) 160 | sims.scatter_(1, self.anchor_indexes[start:end].view(-1, 1), -1.) 161 | _, self.neighbours[start:end] = ( 162 | sims.topk(self.ANs_size, largest=True, dim=1)) 163 | logger.debug('ANs discovery done') 164 | 165 | # consistency can only be computed when cheat labels are provided 166 | if cheat_labels is None: 167 | return 0.
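        # Consistency on a toy case (hypothetical labels, ANs_size = 1):
        #   anchor_label    = [0, 1, 2]
        #   neighbour_label = [[0], [1], [0]]  ->  2 of 3 pairs agree, consistency = 2/3
        # i.e. the fraction of (anchor, neighbour) pairs sharing a ground-truth
        # label, computed below from the cheat labels.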
168 | logger.debug('Start to compute ANs consistency') 169 | anchor_label = cheat_labels.index_select(0, self.anchor_indexes) 170 | neighbour_label = cheat_labels.index_select(0, 171 | self.neighbours.view(-1)).view_as(self.neighbours) 172 | self.consistency = ((anchor_label.view(-1, 1) == neighbour_label) 173 | .float().mean()) 174 | 175 | return self.consistency 176 | 177 | REGISTER.set_package(__name__) 178 | REGISTER.set_class(__name__, 'ans', ANsDiscovery) 179 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2018-09-27 15:09:03 4 | # @Author : Jiabo (Raymond) Huang (jiabo.huang@qmul.ac.uk) 5 | # @Link : https://github.com/Raymond-sci 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | 10 | import sys 11 | import os 12 | import time 13 | from datetime import datetime 14 | 15 | import models 16 | import datasets 17 | 18 | from lib import protocols 19 | from lib.non_parametric_classifier import NonParametricClassifier 20 | from lib.criterion import Criterion 21 | from lib.ans_discovery import ANsDiscovery 22 | from lib.utils import AverageMeter, time_progress, adjust_learning_rate 23 | 24 | from packages import session 25 | from packages import lr_policy 26 | from packages import optimizers 27 | from packages.config import CONFIG as cfg 28 | from packages.loggers.std_logger import STDLogger as logger 29 | from packages.loggers.tf_logger import TFLogger as SummaryWriter 30 | 31 | 32 | def require_args(): 33 | 34 | # dataset to be used 35 | cfg.add_argument('--dataset', default='cifar10', type=str, 36 | help='dataset to be used. (default: cifar10)') 37 | 38 | # network to be used 39 | cfg.add_argument('--network', default='resnet18', type=str, 40 | help='backbone to be used. (default: resnet18)') 41 | 42 | # optimizer to be used 43 | cfg.add_argument('--optimizer', default='sgd', type=str, 44 | help='optimizer to be used. (default: sgd)') 45 | 46 | # lr policy to be used 47 | cfg.add_argument('--lr-policy', default='step', type=str, 48 | help='lr policy to be used. (default: step)') 49 | 50 | # args for protocol 51 | cfg.add_argument('--protocol', default='knn', type=str, 52 | help='protocol used to validate model') 53 | 54 | # args for network training 55 | cfg.add_argument('--max-epoch', default=200, type=int, 56 | help='max epoch per round. (default: 200)') 57 | cfg.add_argument('--max-round', default=5, type=int, 58 | help='max number of rounds, including the initialisation ' 59 | 'one. (default: 5)') 60 | cfg.add_argument('--iter-size', default=1, type=int, 61 | help='caffe style iter size. (default: 1)') 62 | cfg.add_argument('--display-freq', default=1, type=int, 63 | help='display frequency in batches. (default: 1)') 64 | cfg.add_argument('--test-only', action='store_true', 65 | help='run evaluation only') 66 | 67 | 68 | def main(): 69 | 70 | logger.info('Start to declare training variables') 71 | cfg.device = device = 'cuda' if torch.cuda.is_available() else 'cpu' 72 | best_acc = 0.
# best test accuracy 73 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 74 | start_round = 0 # start from round 0 or last checkpoint round 75 | 76 | logger.info('Start to prepare data') 77 | trainset, trainloader, testset, testloader = datasets.get(cfg.dataset, instant=True) 78 | # cheat labels are used to compute neighbourhood consistency only 79 | cheat_labels = torch.tensor(trainset.labels).long().to(device) 80 | ntrain, ntest = len(trainset), len(testset) 81 | logger.info('Got %d training and %d test samples in total' % (ntrain, ntest)) 82 | 83 | logger.info('Start to build model') 84 | net = models.get(cfg.network, instant=True) 85 | npc = NonParametricClassifier(cfg.low_dim, ntrain, cfg.npc_temperature, cfg.npc_momentum) 86 | ANs_discovery = ANsDiscovery(ntrain) 87 | criterion = Criterion() 88 | optimizer = optimizers.get(cfg.optimizer, instant=True, params=net.parameters()) 89 | lr_handler = lr_policy.get(cfg.lr_policy, instant=True) 90 | protocol = protocols.get(cfg.protocol) 91 | 92 | # data parallel 93 | if device == 'cuda': 94 | if (cfg.network.lower().startswith('alexnet') or 95 | cfg.network.lower().startswith('vgg')): 96 | net.features = torch.nn.DataParallel(net.features, 97 | device_ids=range(len(cfg.gpus.split(',')))) 98 | else: 99 | net = torch.nn.DataParallel(net, device_ids=range( 100 | len(cfg.gpus.split(',')))) 101 | cudnn.benchmark = True 102 | 103 | net, npc, ANs_discovery, criterion = (net.to(device), npc.to(device), 104 | ANs_discovery.to(device), criterion.to(device)) 105 | 106 | # load ckpt file if necessary 107 | if cfg.resume: 108 | assert os.path.exists(cfg.resume), "Resume file not found: %s" % cfg.resume 109 | logger.info('Start to resume from %s' % cfg.resume) 110 | ckpt = torch.load(cfg.resume) 111 | net.load_state_dict(ckpt['net']) 112 | optimizer.load_state_dict(ckpt['optimizer']) 113 | npc.load_state_dict(ckpt['npc']) 114 | ANs_discovery.load_state_dict(ckpt['ANs_discovery']) 115 | best_acc = ckpt['acc'] 116 | start_epoch = ckpt['epoch'] 117 | start_round = ckpt['round'] 118 | 119 | # test if necessary 120 | if cfg.test_only: 121 | logger.info('Testing at beginning...') 122 | acc = protocol(net, npc, trainloader, testloader, 200, 123 | cfg.npc_temperature, True, device) 124 | logger.info('Evaluation accuracy at round %d and epoch %d: %.2f%%' % 125 | (start_round, start_epoch, acc * 100)) 126 | sys.exit(0) 127 | 128 | logger.info('Start the progressive training process from round: %d, ' 129 | 'epoch: %d, best acc is %.4f...'
% (start_round, start_epoch, best_acc)) 130 | round = start_round 131 | global_writer = SummaryWriter(cfg.debug, 132 | log_dir=os.path.join(cfg.tfb_dir, 'global')) 133 | while (round < cfg.max_round): 134 | 135 | # variables are initialised to different values in the first round 136 | is_first_round = (round == start_round) 137 | best_acc = best_acc if is_first_round else 0 138 | 139 | if not is_first_round: 140 | logger.info('Start to mine ANs at round %d' % round) 141 | ANs_discovery.update(round, npc, cheat_labels) 142 | logger.info('ANs consistency at round %d is %.2f%%' % 143 | (round, ANs_discovery.consistency * 100)) 144 | 145 | ANs_num = ANs_discovery.anchor_indexes.shape[0] 146 | global_writer.add_scalar('ANs/Number', ANs_num, round) 147 | global_writer.add_scalar('ANs/Consistency', ANs_discovery.consistency, round) 148 | 149 | # declare local writer 150 | writer = SummaryWriter(cfg.debug, log_dir=os.path.join(cfg.tfb_dir, 151 | '%04d-%05d' % (round, ANs_num))) 152 | logger.info('Start training at round %d/%d' % (round, cfg.max_round)) 153 | 154 | 155 | # start to train epoch by epoch 156 | epoch = start_epoch if is_first_round else 0 157 | lr = cfg.base_lr 158 | while lr > 0 and epoch < cfg.max_epoch: 159 | 160 | # get learning rate according to current epoch 161 | lr = lr_handler.update(epoch) 162 | 163 | train(round, epoch, net, trainloader, optimizer, npc, criterion, 164 | ANs_discovery, lr, writer) 165 | 166 | logger.info('Start to evaluate...') 167 | acc = protocol(net, npc, trainloader, testloader, 200, 168 | cfg.npc_temperature, False, device) 169 | writer.add_scalar('Evaluate/Rank-1', acc, epoch) 170 | 171 | logger.info('Evaluation accuracy at round %d and epoch %d: %.1f%%' 172 | % (round, epoch, acc * 100)) 173 | logger.info('Best accuracy at round %d and epoch %d: %.1f%%' 174 | % (round, epoch, best_acc * 100)) 175 | 176 | is_best = acc >= best_acc 177 | best_acc = max(acc, best_acc) 178 | if is_best and not cfg.debug: 179 | target = os.path.join(cfg.ckpt_dir, '%04d-%05d.ckpt' 180 | % (round, ANs_num)) 181 | logger.info('Saving checkpoint to %s' % target) 182 | state = { 183 | 'net': net.state_dict(), 184 | 'optimizer': optimizer.state_dict(), 185 | 'ANs_discovery' : ANs_discovery.state_dict(), 186 | 'npc' : npc.state_dict(), 187 | 'acc': acc, 188 | 'epoch': epoch + 1, 189 | 'round' : round, 190 | 'session' : cfg.session 191 | } 192 | torch.save(state, target) 193 | epoch += 1 194 | 195 | # log best accuracy after each round 196 | global_writer.add_scalar('Evaluate/best_acc', best_acc, round) 197 | round += 1 198 | 199 | # Training 200 | def train(round, epoch, net, trainloader, optimizer, npc, criterion, 201 | ANs_discovery, lr, writer): 202 | 203 | # tracking variables 204 | train_loss = AverageMeter() 205 | data_time = AverageMeter() 206 | batch_time = AverageMeter() 207 | 208 | # switch the model to train mode 209 | net.train() 210 | # adjust learning rate 211 | adjust_learning_rate(optimizer, lr) 212 | 213 | end = time.time() 214 | start_time = datetime.now() 215 | optimizer.zero_grad() 216 | for batch_idx, (inputs, _, indexes) in enumerate(trainloader): 217 | data_time.update(time.time() - end) 218 | inputs, indexes = inputs.to(cfg.device), indexes.to(cfg.device) 219 | 220 | features = net(inputs) 221 | outputs = npc(features, indexes) 222 | loss = criterion(outputs, indexes, ANs_discovery) / cfg.iter_size 223 | 224 | loss.backward() 225 | train_loss.update(loss.item() * cfg.iter_size, inputs.size(0)) 226 | 227 | if batch_idx % cfg.iter_size ==
0: 228 | optimizer.step() 229 | optimizer.zero_grad() 230 | 231 | # measure elapsed time 232 | batch_time.update(time.time() - end) 233 | end = time.time() 234 | 235 | if batch_idx % cfg.display_freq != 0: 236 | continue 237 | 238 | writer.add_scalar('Train/Learning_Rate', lr, 239 | epoch * len(trainloader) + batch_idx) 240 | writer.add_scalar('Train/Loss', train_loss.val, 241 | epoch * len(trainloader) + batch_idx) 242 | 243 | 244 | elapsed_time, estimated_time = time_progress(batch_idx + 1, 245 | len(trainloader), batch_time.sum) 246 | logger.info('Round: {round} Epoch: {epoch}/{tot_epochs} ' 247 | 'Progress: {elps_iters}/{tot_iters} ({elps_time}/{est_time}) ' 248 | 'Data: {data_time.avg:.3f} LR: {learning_rate:.5f} ' 249 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format( 250 | round=round, epoch=epoch, tot_epochs=cfg.max_epoch, 251 | elps_iters=batch_idx, tot_iters=len(trainloader), 252 | elps_time=elapsed_time, est_time=estimated_time, 253 | data_time=data_time, learning_rate=lr, 254 | train_loss=train_loss)) 255 | 256 | if __name__ == '__main__': 257 | 258 | session.run(__name__) 259 | --------------------------------------------------------------------------------
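A minimal launch sketch. The flag names below are the ones registered via cfg.add_argument in the files above; the data path and GPU id are illustrative, and any option left out falls back to its registered default (the configs/*.yaml files in the repo record the same option names):

    python main.py --dataset cifar10 --data-root data/cifar10 --network resnet18 --gpus 0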