├── lib
│   ├── __init__.py
│   ├── normalize.py
│   ├── criterion.py
│   ├── utils.py
│   ├── non_parametric_classifier.py
│   ├── protocols.py
│   └── ans_discovery.py
├── docs
│   ├── assets
│   │   ├── large-scale.jpg
│   │   ├── small-scale.jpg
│   │   └── training-pipeline.jpg
│   ├── LICENSE
│   └── README.md
├── configs
│   ├── cifar10.yaml
│   ├── cifar100.yaml
│   ├── svhn.yaml
│   └── base.yaml
├── packages
│   ├── __init__.py
│   ├── datasets
│   │   ├── transforms.py
│   │   └── __init__.py
│   ├── loggers
│   │   ├── __init__.py
│   │   ├── tf_logger.py
│   │   └── std_logger.py
│   ├── lr_policy
│   │   ├── fixed.py
│   │   ├── __init__.py
│   │   ├── step.py
│   │   └── multistep.py
│   ├── networks
│   │   └── __init__.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   ├── sgd.py
│   │   ├── rmsprop.py
│   │   └── adam.py
│   ├── utils.py
│   ├── config.py
│   ├── session.py
│   └── register.py
├── models
│   ├── __init__.py
│   └── resnet_cifar.py
├── datasets
│   ├── svhn.py
│   ├── __init__.py
│   └── cifar.py
├── requirements.yaml
└── main.py
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # nothing
--------------------------------------------------------------------------------
/docs/assets/large-scale.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/large-scale.jpg
--------------------------------------------------------------------------------
/docs/assets/small-scale.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/small-scale.jpg
--------------------------------------------------------------------------------
/docs/assets/training-pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Raymond-sci/AND/HEAD/docs/assets/training-pipeline.jpg
--------------------------------------------------------------------------------
/configs/cifar10.yaml:
--------------------------------------------------------------------------------
1 | ####################
2 | # cfgs for cifar10 #
3 | ####################
4 |
5 | # args for dataset and dataloader
6 | dataset: cifar10
7 | data_root: data/cifar10
8 | batch_size: 128
9 |
--------------------------------------------------------------------------------
/configs/cifar100.yaml:
--------------------------------------------------------------------------------
1 | #####################
2 | # cfgs for cifar100 #
3 | #####################
4 |
5 | # args for dataset and dataloader
6 | dataset: cifar100
7 | data_root: data/cifar100
8 | batch_size: 128
9 |
--------------------------------------------------------------------------------
/configs/svhn.yaml:
--------------------------------------------------------------------------------
1 | #################
2 | # cfgs for svhn #
3 | #################
4 |
5 | # args for dataset and dataloader
6 | dataset: svhn
7 | data_root: data/svhn
8 | batch_size: 128
9 | display_freq: 115
--------------------------------------------------------------------------------
/packages/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-25 14:37:26
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | # from . import datasets
8 | # from . import loggers
9 | # from . import networks
10 | # from . import optimizers
11 | # from . import argparser
12 | # from . import lr_policy
13 | # from . import utils
--------------------------------------------------------------------------------
/packages/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-28 21:45:01
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import torchvision.transforms as transforms
8 |
9 | class RandomResizedCrop(transforms.RandomResizedCrop):
10 |
11 | def __init__(self, size, **kwargs):
12 | super(RandomResizedCrop, self).__init__(0, **kwargs)
13 | self.size = size
--------------------------------------------------------------------------------
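Note on the subclass above: torchvision's `RandomResizedCrop` coerces `size` into a square pair, so the wrapper calls the parent constructor with a placeholder and then overwrites `self.size`, letting configs pass an arbitrary `(h, w)` tuple. A minimal sketch of the effect (the `(32, 32)` value mirrors `size` in configs/base.yaml):

```python
from packages.datasets.transforms import RandomResizedCrop

# size survives as a tuple instead of being coerced by the parent class
crop = RandomResizedCrop((32, 32), scale=(0.2, 1.0))
print(crop.size)  # (32, 32)
```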
/lib/normalize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-10-12 21:37:13
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import torch
8 | from torch.autograd import Variable
9 | from torch import nn
10 |
11 | class Normalize(nn.Module):
12 | """Normalize module
13 |
14 | Module used to normalize matrix
15 |
16 | Extends:
17 | nn.Module
18 | """
19 |
20 | def __init__(self, power=2):
21 | super(Normalize, self).__init__()
22 | self.power = power
23 |
24 | def forward(self, x):
25 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power)
26 | out = x.div(norm)
27 | return out
28 |
--------------------------------------------------------------------------------
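A quick sanity check of the module above: every row of the output has unit Lp norm (L2 by default), which is what keeps features and the memory bank in lib/non_parametric_classifier.py on the unit sphere:

```python
import torch
from lib.normalize import Normalize

x = torch.randn(4, 128)
out = Normalize(power=2)(x)
print(out.norm(p=2, dim=1))  # ~1.0 for each of the 4 rows
```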
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | #############
2 | # base cfgs #
3 | #############
4 |
5 | # args for AND
6 | ANs_select_rate: 0.25
7 | ANs_size: 1
8 | max_round: 5
9 |
10 | # args for network
11 | network: ResNet18
12 |
13 | # args for training
14 | log_file: True
15 | log_tfb: True
16 | display_freq: 80
17 | workers_num: 4
18 |
19 | # args for transforms
20 | size: (32, 32)
21 | resize: 32
22 | scale: (0.2, 1.)
23 | ratio: (0.75, 1.333333)
24 | colorjitter: (0.4, 0.4, 0.4, 0.4)
25 | random_grayscale: 0.2
26 | # random_horizontal_flip: True
27 |
28 | # args for lr policy
29 | base_lr: 0.03
30 | lr_policy: step
31 | lr_decay_offset: 80
32 | lr_decay_step: 40
33 | lr_decay_rate: 0.1
34 |
35 | # args for optimizer
36 | optimizer: sgd
37 | weight_decay: 5e-4
38 | momentum: 0.9
39 | nesterov: True
40 |
41 | # args for protocol
42 | protocol: knn
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-03-13 21:25:39
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from .resnet_cifar import *
8 |
9 | from packages.register import REGISTER
10 | from packages.config import CONFIG as cfg
11 | from packages import networks as cmd_networks
12 |
13 | def require_args():
14 |
15 | cfg.add_argument('--low-dim', default=128, type=int, help='feature dimension')
16 |
17 | def get(name, instant=False):
18 | cls = cmd_networks.get(name)
19 | if not instant:
20 | return cls
21 | return cls(low_dim=cfg.low_dim)
22 |
23 | REGISTER.set_package(__name__)
24 | cmd_networks.register('ResNet18', ResNet18)
25 | cmd_networks.register('ResNet50', ResNet50)
26 | cmd_networks.register('ResNet101', ResNet101)
27 |
--------------------------------------------------------------------------------
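For reference, a sketch of how a backbone is resolved through the registry above; the class-only path works once the module is imported, while `instant=True` assumes `cfg` has already been parsed (as packages/session.py does) so that `cfg.low_dim` exists:

```python
import models

net_cls = models.get('ResNet18')            # returns the registered class
net = models.get('ResNet18', instant=True)  # instantiated with low_dim=cfg.low_dim
```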
/packages/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-24 23:26:16
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from . import std_logger
8 | from . import tf_logger
9 |
10 | from ..register import REGISTER
11 |
12 | def require_args():
13 | """all args for logger objects
14 |
15 | Arguments:
16 | parser {argparse} -- current version of argparse object
17 | """
18 | if not REGISTER.is_package_registered(__name__):
19 | return parser
20 |
21 | classes = REGISTER.get_classes(__name__)
22 |
23 | for (name, cls) in classes.items():
24 | if hasattr(cls, 'require_args'):
25 | cls.require_args()
26 |
27 | def get(name):
28 | return REGISTER.get_class(__name__, name)
29 |
30 | def register(name, cls):
31 | REGISTER.set_class(__name__, name, cls)
--------------------------------------------------------------------------------
/packages/lr_policy/fixed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 20:15:47
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from ..loggers.std_logger import STDLogger as logger
8 | from ..config import CONFIG as cfg
9 |
10 | class FixedPolicy:
11 |
12 | def __init__(self, *args, **kwargs):
13 | if len(args) + len(kwargs) == 0:
14 | self.__init_by_cfg()
15 | else:
16 | self.__init(*args, **kwargs)
17 |
18 | def __init(self, base_lr):
19 | self.base_lr = base_lr
20 | logger.debug('Going to use [fixed] learning policy for optimization'
21 | ' with base learning rate [%.5f]' % base_lr)
22 |
23 | def __init_by_cfg(self):
24 | self.__init(cfg.base_lr)
25 |
26 | def update(self, epoch, *args, **kwargs):
27 | return self.base_lr
28 |
29 | from ..register import REGISTER
30 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'fixed', FixedPolicy)
--------------------------------------------------------------------------------
/packages/networks/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-24 23:26:56
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from ..register import REGISTER
8 | from ..config import CONFIG as cfg
9 |
10 | def require_args():
11 | """all args for network objects
12 |
13 |     Delegates to the `require_args` hook of the currently
14 |     selected network class, if it defines one.
15 | """
16 |
17 | known_args, _ = cfg.parse_known_args()
18 |
19 | if (REGISTER.is_package_registered(__name__) and
20 | REGISTER.is_class_registered(__name__, known_args.network)):
21 |
22 | network = get(known_args.network)
23 |
24 | if hasattr(network, 'require_args'):
25 | # get args for network
26 | return network.require_args()
27 |
28 | def get(name, instant=False):
29 | cls = REGISTER.get_class(__name__, name)
30 | if instant:
31 | return cls()
32 | return cls
33 |
34 | def register(name, cls):
35 | REGISTER.set_class(__name__, name, cls)
--------------------------------------------------------------------------------
/docs/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Raymond Wong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/packages/loggers/tf_logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-25 17:16:34
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from tensorboardX import SummaryWriter as TFBWriter
8 |
9 | from ..config import CONFIG as cfg
10 |
11 | class TFLogger():
12 |
13 | @staticmethod
14 | def require_args():
15 |
16 | cfg.add_argument('--log-tfb', action='store_true',
17 | help='use tensorboard to log training process. '
18 | '(default: False)')
19 |
20 | def __init__(self, debugging, *args, **kwargs):
21 | self.debugging = debugging
22 | if not self.debugging and cfg.log_tfb:
23 | self.writer = TFBWriter(*args, **kwargs)
24 |
25 | def __getattr__(self,attr):
26 | if self.debugging or not cfg.log_tfb:
27 | return do_nothing
28 | return self.writer.__getattribute__(attr)
29 |
30 | def do_nothing(*args, **kwargs):
31 | pass
32 |
33 |
34 | from ..register import REGISTER
35 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'tf_logger', TFLogger)
36 |
--------------------------------------------------------------------------------
/packages/lr_policy/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 20:11:16
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from . import step
8 | from . import multistep
9 | from . import fixed
10 |
11 | from ..register import REGISTER
12 | from ..config import CONFIG as cfg
13 |
14 | def require_args():
15 | """all args for optimizer objects
16 |
17 | Arguments:
18 | parser {argparse} -- current version of argparse object
19 | """
20 |
21 | cfg.add_argument('--base-lr', default=1e-1, type=float,
22 | help='base learning rate. (default: 1e-1)')
23 |
24 | known_args, _ = cfg.parse_known_args()
25 |
26 | if (REGISTER.is_package_registered(__name__) and
27 | REGISTER.is_class_registered(__name__, known_args.lr_policy)):
28 |
29 | policy = get(known_args.lr_policy)
30 |
31 | if hasattr(policy, 'require_args'):
32 | return policy.require_args()
33 |
34 | def get(name, instant=False):
35 | cls = REGISTER.get_class(__name__, name)
36 | if instant:
37 | return cls()
38 | return cls
39 |
40 | def register(name, cls):
41 | REGISTER.set_class(__name__, name, cls)
--------------------------------------------------------------------------------
/packages/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-24 23:26:56
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from . import sgd
8 | from . import adam
9 | from . import rmsprop
10 |
11 | from ..register import REGISTER
12 | from ..config import CONFIG as cfg
13 |
14 | def require_args():
15 | """all args for optimizer objects
16 |
17 |     Registers the shared `--weight-decay` option, then delegates
18 |     to the selected optimizer's `require_args` hook if present.
19 | """
20 |
21 | cfg.add_argument('--weight-decay', default=0, type=float,
22 | help='weight decay (L2 penalty)')
23 |
24 | known_args, _ = cfg.parse_known_args()
25 |
26 | if (REGISTER.is_package_registered(__name__) and
27 | REGISTER.is_class_registered(__name__, known_args.optimizer)):
28 |
29 | optimizer = get(known_args.optimizer)
30 |
31 | if hasattr(optimizer, 'require_args'):
32 | return optimizer.require_args()
33 |
34 | def get(name, instant=False, params=None):
35 | cls = REGISTER.get_class(__name__, name)
36 | if instant:
37 | return cls.get(params)
38 | return cls
39 |
40 | def register(name, cls):
41 | REGISTER.set_class(__name__, name, cls)
42 |
43 |
44 |
--------------------------------------------------------------------------------
/packages/optimizers/sgd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 19:42:22
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 |
9 | import torch
10 |
11 | from ..loggers.std_logger import STDLogger as logger
12 | from ..config import CONFIG as cfg
13 |
14 | def get(params):
15 | logger.debug('Going to use [SGD] optimizer for training with momentum %.2f, '
16 | 'dampening %f, weight decay %f %s nesterov' % (cfg.momentum,
17 | cfg.dampening, cfg.weight_decay,
18 | ('with' if cfg.nesterov else 'without')))
19 | return torch.optim.SGD(params, lr=cfg.base_lr, momentum=cfg.momentum,
20 | weight_decay=cfg.weight_decay, nesterov=cfg.nesterov,
21 | dampening=cfg.dampening)
22 |
23 | def require_args():
24 |
25 | cfg.add_argument('--momentum', default=0, type=float,
26 | help='momentum factor')
27 | cfg.add_argument('--dampening', default=0, type=float,
28 | help='dampening for momentum')
29 | cfg.add_argument('--nesterov', action='store_true',
30 | help='enables Nesterov momentum')
31 |
32 | from ..register import REGISTER
33 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'sgd', __name__)
--------------------------------------------------------------------------------
/packages/optimizers/rmsprop.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 19:42:22
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 |
9 | import torch
10 |
11 | from ..loggers.std_logger import STDLogger as logger
12 | from ..config import CONFIG as cfg
13 |
14 | def get(params):
15 | logger.debug('Going to use [RMSprop] optimizer for training with alpha %.2f, '
16 | 'eps %f, weight decay %f, momentum %f %s centered' % (cfg.alpha,
17 | cfg.eps, cfg.weight_decay, cfg.momentum,
18 | ('with' if cfg.centered else 'without')))
19 | return torch.optim.RMSprop(params, lr=cfg.base_lr,
20 |         alpha=cfg.alpha, eps=cfg.eps, momentum=cfg.momentum,
21 | centered=cfg.centered, weight_decay=cfg.weight_decay)
22 |
23 | def require_args():
24 |
25 | cfg.add_argument('--alpha', default=0.99, type=float,
26 | help='smoothing constant')
27 | cfg.add_argument('--eps', default=1e-8, type=float,
28 | help=('term added to the denominator to improve'
29 | ' numerical stability'))
30 | cfg.add_argument('--momentum', default=0, type=float,
31 | help='momentum factor')
32 | cfg.add_argument('--centered', action='store_true',
33 | help='whether to compute the centered RMSProp')
34 |
35 | from ..register import REGISTER
36 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'rmsprop', __name__)
--------------------------------------------------------------------------------
/packages/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-28 12:20:22
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | def tuple_or_list(target):
8 | """Check is a string object contains a tuple or list
9 |
10 | If a string likes '[a, b, c]' or '(a, b, c)', then return true,
11 | otherwise false.
12 |
13 | Arguments:
14 | target {str} -- target string
15 |
16 | Returns:
17 | bool -- result
18 | """
19 |
20 | # if the target is a tuple or list originally, then return directly
21 | if isinstance(target, tuple) or isinstance(target, list):
22 | return target
23 |
24 | try:
25 | target = eval(target)
26 | if isinstance(target, tuple) or isinstance(target, list):
27 | return target
28 | except:
29 | pass
30 | return None
31 |
32 | def get_valid_size(target):
33 | """get valid size
34 |
35 | if target is a tuple/list or a string of them, then convert and return
36 | if target is a int/float then return
37 | else return None
38 |
39 | Arguments:
40 | target {Number} -- size
41 | """
42 | ret = tuple_or_list(target)
43 | if ret is not None:
44 | return ret
45 |
46 | try:
47 | ret = int(target)
48 | return ret
49 | except:
50 | pass
51 |
52 | try:
53 | ret = float(target)
54 | return ret
55 | except:
56 | pass
57 |
58 | return None
59 |
60 |
--------------------------------------------------------------------------------
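These helpers are what allow configs/base.yaml to write values such as `size: (32, 32)` and `scale: (0.2, 1.)` as plain strings; a few illustrative calls:

```python
from packages.utils import tuple_or_list, get_valid_size

print(tuple_or_list('(0.2, 1.)'))  # (0.2, 1.0), parsed via eval
print(tuple_or_list('32'))         # None: neither a tuple nor a list
print(get_valid_size('(32, 32)'))  # (32, 32)
print(get_valid_size('32'))        # 32, falls back to int()
```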
/packages/optimizers/adam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 19:42:22
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 |
9 | import torch
10 |
11 | from ..loggers.std_logger import STDLogger as logger
12 | from ..config import CONFIG as cfg
13 |
14 | def get(params):
15 | logger.debug('Going to use [Adam] optimizer for training with betas %s, '
16 | 'eps %f, weight decay %f %s amsgrad' % ((cfg.beta1, cfg.beta2),
17 | cfg.eps, cfg.weight_decay,
18 | ('with' if cfg.amsgrad else 'without')))
19 | return torch.optim.Adam(params, lr=cfg.base_lr,
20 | betas=(cfg.beta1, cfg.beta2), eps=cfg.eps,
21 | weight_decay=cfg.weight_decay, amsgrad=cfg.amsgrad)
22 |
23 | def require_args():
24 |
25 | cfg.add_argument('--beta1', default=0.9, type=float,
26 | help=('coefficients used for computing running'
27 | ' averages of gradient'))
28 | cfg.add_argument('--beta2', default=0.999, type=float,
29 | help=('coefficients used for computing running'
30 | ' averages of gradient\'s square'))
31 | cfg.add_argument('--eps', default=1e-8, type=float,
32 | help='term added to the denominator to improve numerical stability')
33 | cfg.add_argument('--amsgrad', action='store_true',
34 | help='whether to use the AMSGrad variant of this algorithm')
35 |
36 | from ..register import REGISTER
37 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'adam', __name__)
--------------------------------------------------------------------------------
/datasets/svhn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-10-14 19:31:42
4 | # @Author : Jiabo (Raymond) Huang (jiabo.huang@qmul.ac.uk)
5 | # @Link : https://github.com/Raymond-sci
6 | from __future__ import print_function
7 | from PIL import Image
8 | import torchvision.datasets as datasets
9 | import torch.utils.data as data
10 | import numpy as np
11 |
12 | from packages.config import CONFIG as cfg
13 |
14 | class SVHNInstance(datasets.SVHN):
15 | """SVHNInstance Dataset.
16 | """
17 |
18 | @staticmethod
19 | def require_args():
20 | cfg.add_argument('--means', default='(0.4377, 0.4438, 0.4728)',
21 | type=str, help='channel-wise means')
22 | cfg.add_argument('--stds', default='(0.1201, 0.1231, 0.1052)',
23 | type=str, help='channel-wise stds')
24 |
25 |     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
26 | self.train = train
27 | super(SVHNInstance, self).__init__(root, split=('train' if train else 'test'),
28 |             transform=transform, target_transform=target_transform, download=download)
29 |
30 | def __getitem__(self, index):
31 | """
32 | Args:
33 | index (int): Index
34 | Returns:
35 | tuple: (image, target) where target is index of the target class.
36 | """
37 | img, target = self.data[index], int(self.labels[index])
38 |
39 | # doing this so that it is consistent with all other datasets
40 | # to return a PIL Image
41 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
42 |
43 | if self.transform is not None:
44 | img = self.transform(img)
45 |
46 | if self.target_transform is not None:
47 | target = self.target_transform(target)
48 |
49 | return img, target, index
--------------------------------------------------------------------------------
/lib/criterion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-10-07 22:32:15
4 | # @Author : Jiabo (Raymond) Huang (jiabo.huang@qmul.ac.uk)
5 | # @Link : https://github.com/Raymond-sci
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 |
11 | class Criterion(nn.Module):
12 |
13 | def __init__(self):
14 | super(Criterion, self).__init__()
15 |
16 | def forward(self, x, y, ANs):
17 | batch_size, _ = x.shape
18 |
19 | # split anchor and instance list
20 | anchor_indexes, instance_indexes = self.__split(y, ANs)
21 | preds = F.softmax(x, 1)
22 |
23 | l_ans = 0.
24 | if anchor_indexes.size(0) > 0:
25 | # compute loss for anchor samples
26 | y_ans = y.index_select(0, anchor_indexes)
27 | y_ans_neighbour = ANs.position.index_select(0, y_ans)
28 | neighbours = ANs.neighbours.index_select(0, y_ans_neighbour)
29 | # p_i = \sum_{j \in \Omega_i} p_{i,j}
30 | x_ans = preds.index_select(0, anchor_indexes)
31 | x_ans_neighbour = x_ans.gather(1, neighbours).sum(1)
32 | x_ans = x_ans.gather(1, y_ans.view(-1, 1)).view(-1) + x_ans_neighbour
33 | # NLL: l = -log(p_i)
34 | l_ans = -1 * torch.log(x_ans).sum(0)
35 |
36 | l_inst = 0.
37 | if instance_indexes.size(0) > 0:
38 | # compute loss for instance samples
39 | y_inst = y.index_select(0, instance_indexes)
40 | x_inst = preds.index_select(0, instance_indexes)
41 | # p_i = p_{i, i}
42 | x_inst = x_inst.gather(1, y_inst.view(-1, 1))
43 | # NLL: l = -log(p_i)
44 | l_inst = -1 * torch.log(x_inst).sum(0)
45 |
46 | return (l_inst + l_ans) / batch_size
47 |
48 | def __split(self, y, ANs):
49 | pos = ANs.position.index_select(0, y.view(-1))
50 | return (pos >= 0).nonzero().view(-1), (pos < 0).nonzero().view(-1)
51 |
--------------------------------------------------------------------------------
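The split above hinges on `ANs.position`: entries >= 0 index into `ANs.neighbours` (anchor samples), while entries < 0 mark plain instances. A minimal sketch with a mocked neighbourhood object (the attribute names follow the accesses above; lib/ans_discovery.py, which builds the real one, is not included in this dump):

```python
import argparse
import torch
from lib.criterion import Criterion

N = 6  # training-set size, i.e. number of columns in the npc output
ANs = argparse.Namespace(
    position=torch.tensor([0, -1, 1, -1, -1, -1]),  # samples 0 and 2 are anchors
    neighbours=torch.tensor([[1], [3]]),            # one neighbour each (ANs_size: 1)
)
x = torch.randn(4, N)           # similarity logits for a batch of 4
y = torch.tensor([0, 1, 2, 4])  # instance indices of the batch
# anchors: -log(p_ii + sum of p_ij over neighbours); instances: -log(p_ii)
print(Criterion()(x, y, ANs))
```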
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-28 12:34:35
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from .cifar import CIFAR10Instance, CIFAR100Instance
8 | from .svhn import SVHNInstance
9 |
10 | import torch
11 |
12 | from packages import datasets as cmd_datasets
13 | from packages.config import CONFIG as cfg
14 |
15 | __all__ = ('CIFAR10Instance', 'CIFAR100Instance', 'SVHNInstance')
16 |
17 | def get(name, instant=False):
18 | """
19 | Get dataset instance according to the dataset string and dataroot
20 | """
21 |
22 | # get dataset class
23 | dataset_cls = cmd_datasets.get(name)
24 |
25 | if not instant:
26 | return dataset_cls
27 |
28 | # get transforms for training set
29 | transform_train = cmd_datasets.get_transforms('train', cfg.means, cfg.stds)
30 |
31 | # get transforms for test set
32 | transform_test = cmd_datasets.get_transforms('test', cfg.means, cfg.stds)
33 |
34 | # get trainset and trainloader
35 | trainset = dataset_cls(root=cfg.data_root, train=True, download=True,
36 | transform=transform_train)
37 | # filter trainset if necessary
38 | trainloader = torch.utils.data.DataLoader(trainset,
39 | batch_size=cfg.batch_size, shuffle=True,
40 | num_workers=cfg.workers_num)
41 |
42 | # get testset and testloader
43 | testset = dataset_cls(root=cfg.data_root, train=False, download=True,
44 | transform=transform_test)
45 | # filter testset if necessary
46 | testloader = torch.utils.data.DataLoader(testset, batch_size=100,
47 | shuffle=False, num_workers=cfg.workers_num)
48 |
49 | return trainset, trainloader, testset, testloader
50 |
51 |
52 | cmd_datasets.register('cifar10', CIFAR10Instance)
53 | cmd_datasets.register('cifar100', CIFAR100Instance)
54 | cmd_datasets.register('svhn', SVHNInstance)
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
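A sketch of the two return modes above (the instant path assumes `cfg` has been parsed so that `cfg.data_root`, `cfg.means`, etc. are available; register.py internals are not shown in this dump):

```python
import datasets

CIFAR10Cls = datasets.get('cifar10')  # class only
trainset, trainloader, testset, testloader = datasets.get('cifar10', instant=True)
```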
/packages/lr_policy/step.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 20:19:00
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from ..loggers.std_logger import STDLogger as logger
8 | from ..config import CONFIG as cfg
9 |
10 | class StepPolicy:
11 | """Caffe style step decay learning rate policy
12 |
13 | Decay learning rate at every `lr-decay-step` steps after the
14 | first `lr-decay-offset` ones at the rate of `lr-decay-rate`
15 | """
16 |
17 | def __init__(self, *args, **kwargs):
18 | if len(args) + len(kwargs) == 0:
19 | self.__init_by_cfg()
20 | else:
21 | self.__init(*args, **kwargs)
22 |
23 | def __init(self, base_lr, offset, step, rate):
24 | self.base_lr = base_lr
25 | self.offset = offset
26 | self.step = step
27 | self.rate = rate
28 | logger.debug('Going to use [step] learning policy for optimization with '
29 | 'base learning rate %.5f, offset %d, step %d and decay rate %f' %
30 | (base_lr, offset, step, rate))
31 |
32 | def __init_by_cfg(self):
33 | self.__init(cfg.base_lr, cfg.lr_decay_offset,
34 | cfg.lr_decay_step, cfg.lr_decay_rate)
35 |
36 | @staticmethod
37 | def require_args():
38 |
39 | cfg.add_argument('--lr-decay-offset', default=0, type=int,
40 | help='learning rate will start to decay at which step')
41 |
42 | cfg.add_argument('--lr-decay-step', default=0, type=int,
43 | help='learning rate will decay at every n round')
44 |
45 |
46 | cfg.add_argument('--lr-decay-rate', default=0.1, type=float,
47 | help='learning rate will decay at what rate')
48 |
49 | def update(self, steps):
50 | """decay learning rate according to current step
51 |
52 | Decay learning rate at a fixed ratio
53 |
54 | Arguments:
55 | steps {int} -- current steps
56 |
57 | Returns:
58 |             float -- updated learning rate
59 | """
60 |
61 | if steps < self.offset:
62 | return self.base_lr
63 |
64 | return self.base_lr * (self.rate ** ((steps - self.offset) // self.step))
65 |
66 |
67 | from ..register import REGISTER
68 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'step', StepPolicy)
69 |
--------------------------------------------------------------------------------
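Worked through with the values in configs/base.yaml (base_lr 0.03, offset 80, step 40, rate 0.1), `update` behaves as in this standalone re-creation:

```python
def step_lr(steps, base_lr=0.03, offset=80, step=40, rate=0.1):
    # mirrors StepPolicy.update with the defaults from configs/base.yaml
    if steps < offset:
        return base_lr
    return base_lr * rate ** ((steps - offset) // step)

for epoch in (79, 100, 120, 160):
    print(epoch, step_lr(epoch))  # 0.03, 0.03, 0.003, 0.0003
```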
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 |
5 | from PIL import Image
6 | import torchvision.datasets as datasets
7 | import torch.utils.data as data
8 |
9 | from packages.config import CONFIG as cfg
10 |
11 | class CIFAR10Instance(datasets.CIFAR10):
12 | """CIFAR10Instance Dataset.
13 | """
14 |
15 | @staticmethod
16 | def require_args():
17 | cfg.add_argument('--means', default='(0.4914, 0.4822, 0.4465)',
18 | type=str, help='channel-wise means')
19 | cfg.add_argument('--stds', default='(0.2023, 0.1994, 0.2010)',
20 | type=str, help='channel-wise stds')
21 |
22 | def __init__(self, *args, **kwargs):
23 | super(CIFAR10Instance, self).__init__(*args, **kwargs)
24 | self.labels = self.targets
25 |
26 | def __getitem__(self, index):
27 | """
28 | Args:
29 | index (int): Index
30 | Returns:
31 | tuple: (image, target) where target is index of the target class.
32 | """
33 | img, target = self.data[index], self.targets[index]
34 |
35 | # doing this so that it is consistent with all other datasets
36 | # to return a PIL Image
37 | img = Image.fromarray(img)
38 |
39 | if self.transform is not None:
40 | img = self.transform(img)
41 |
42 | if self.target_transform is not None:
43 | target = self.target_transform(target)
44 |
45 | return img, target, index
46 |
47 | class CIFAR100Instance(CIFAR10Instance):
48 | """CIFAR100Instance Dataset.
49 |
50 | This is a subclass of the `CIFAR10Instance` Dataset.
51 | """
52 |
53 | @staticmethod
54 | def args(parser):
55 | parser.add_argument('--means', default='(0.5071, 0.4866, 0.4409)',
56 | type=str, help='channel-wise means')
57 | parser.add_argument('--stds', default='(0.2009, 0.1984, 0.2023)',
58 | type=str, help='channel-wise stds')
59 | return parser
60 |
61 | base_folder = 'cifar-100-python'
62 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
63 | filename = "cifar-100-python.tar.gz"
64 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
65 | train_list = [
66 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
67 | ]
68 |
69 | test_list = [
70 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
71 | ]
72 | meta = {
73 | 'filename': 'meta',
74 | 'key': 'fine_label_names',
75 | 'md5': '7973b15100ade9c7d40fb424638fde48',
76 | }
77 |
--------------------------------------------------------------------------------
/packages/lr_policy/multistep.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-26 20:19:00
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 | import os
7 | from ..loggers.std_logger import STDLogger as logger
8 | from ..config import CONFIG as cfg
9 |
10 | class MultiStepPolicy:
11 | """Caffe style step decay learning rate policy
12 |
13 | Decay learning rate at every `lr-decay-step` steps after the
14 | first `lr-decay-offset` ones at the rate of `lr-decay-rate`
15 | """
16 |
17 | def __init__(self, *args, **kwargs):
18 | if len(args) + len(kwargs) == 0:
19 | self.__init_by_cfg()
20 | else:
21 | self.__init(*args, **kwargs)
22 |
23 | def __init(self, base_lr, schedule):
24 | self.base_lr = base_lr
25 | self.schedule = schedule
26 | logger.debug('Going to use [multistep] learning policy for optimization '
27 |                      'with base learning rate %.5f and schedule from %s'
28 | % (base_lr, schedule))
29 |
30 | def __init_by_cfg(self):
31 | schedule = cfg.lr_schedule
32 | assert schedule is not None and os.path.exists(schedule), ('Schedule '
33 | 'file not found: [%s]' % schedule)
34 | self.__init(cfg.base_lr, schedule)
35 |
36 | @staticmethod
37 | def require_args():
38 |
39 | cfg.add_argument('--lr-schedule', default=None, type=str,
40 | help='learning rate schedule')
41 |
42 | def update(self, steps):
43 | """update learning rate
44 |
45 | Update learning rate according to current steps and schedule file
46 |
47 | Arguments:
48 | steps {int} -- current steps
49 |
50 | Returns:
51 |             float -- updated learning rate
52 | """
53 |
54 | lines = filter(lambda x:not x.startswith('#'),
55 | open(self.schedule, 'r').readlines())
56 | assert len(lines) > 0, 'Invalid schedule file'
57 |
58 | learning_rate = self.base_lr
59 | for line in lines:
60 |
61 | line = line.split('#')[0]
62 | anchor, target = line.strip().split(':')
63 |
64 | if target.startswith('-'):
65 | lr = -1
66 | else:
67 | lr = float(target)
68 |             if steps <= int(anchor):
69 | learning_rate = lr
70 | else:
71 | break
72 |
73 | return learning_rate
74 |
75 |
76 | from ..register import REGISTER
77 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'multistep', MultiStepPolicy)
78 |
--------------------------------------------------------------------------------
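`update` re-reads the schedule file on every call. Each non-comment line is an `anchor:lr` pair; the scan stops at the first anchor below the current step and the last matching assignment wins, so anchors should be listed in descending order. A hedged example of such a file, with the semantics inferred from the parsing code above (steps beyond the last anchor fall back to `base_lr`):

```
# anchor:lr -- lr applies while steps <= anchor; anchors descending
160:0.0003
120:0.003
80:0.03
```

With this file, steps up to 80 resolve to 0.03, steps 81-120 to 0.003, and steps 121-160 to 0.0003.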
/lib/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-10-12 21:37:13
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import os
8 | import sys
9 | import shutil
10 | import numpy as np
11 | from datetime import timedelta
12 |
13 | import torch
14 |
15 | from packages.loggers.std_logger import STDLogger as logger
16 |
17 | class AverageMeter(object):
18 | """Computes and stores the average and current value"""
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | self.val = val
30 | self.sum += val * n
31 | self.count += n
32 | self.avg = self.sum / self.count
33 |
34 | def adjust_learning_rate(optimizer, lr):
35 |
36 | for param_group in optimizer.param_groups:
37 | param_group['lr'] = lr
38 |
39 | return lr
40 |
41 | def time_progress(elapsed_iters, tot_iters, elapsed_time):
42 | estimated_time = 1. * tot_iters / elapsed_iters * elapsed_time
43 | elapsed_time = timedelta(seconds=elapsed_time)
44 | estimated_time = timedelta(seconds=estimated_time)
45 | return tuple(map(lambda x:str(x).split('.')[0], [elapsed_time, estimated_time]))
46 |
47 | def save_ckpt(state_dict, target, is_best=False):
48 | latest, best = map(lambda x:os.path.join(target, x), ['latest.ckpt', 'best.ckpt'])
49 | # save latest checkpoint
50 | torch.save(state_dict, latest)
51 | # if is best, then copy latest to best
52 | if not is_best:
53 | return
54 | shutil.copyfile(latest, best)
55 |
56 | def traverse(net, loader, transform=None, tencrops=False, device='cpu'):
57 |
58 | bak_transform = loader.dataset.transform
59 | if transform is not None:
60 | loader.dataset.transform = transform
61 |
62 | features = None
63 | labels = torch.zeros(len(loader.dataset)).long().to(device)
64 |
65 | with torch.no_grad():
66 | for batch_idx, (inputs, targets, indexes) in enumerate(loader):
67 | logger.progress(batch_idx, len(loader), 'processing %d/%d batch...')
68 |
69 | if tencrops:
70 | bs, ncrops, c, h, w = inputs.size()
71 | inputs = inputs.view(-1, c, h, w)
72 | inputs, targets, indexes = (inputs.to(device), targets.to(device),
73 | indexes.to(device))
74 |
75 | feats = net(inputs)
76 | if tencrops:
77 | feats = torch.squeeze(feats.view(bs, ncrops, -1).mean(1))
78 |
79 | if features is None:
80 | features = torch.zeros(len(loader.dataset), feats.shape[1]).to(device)
81 | features.index_copy_(0, indexes, feats)
82 | labels.index_copy_(0, indexes, targets)
83 |
84 | loader.dataset.transform = bak_transform
85 |
86 | return features, labels
87 |
--------------------------------------------------------------------------------
/packages/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-24 22:16:37
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import os
8 | import argparse
9 | import yaml
10 | import importlib
11 | import traceback
12 | from prettytable import PrettyTable
13 |
14 | from .register import REGISTER
15 |
16 | class _META_(type):
17 |
18 | PARSER = argparse.ArgumentParser(formatter_class=
19 | argparse.ArgumentDefaultsHelpFormatter)
20 | ARGS = dict()
21 |
22 | def require_args(self):
23 |
24 | # args for config file
25 | _META_.PARSER.add_argument('--cfgs', type=str, nargs='*',
26 | help='config files to load')
27 |
28 | def parse(self):
29 |
30 | # collect self args
31 | self.require_args()
32 |
33 | # load default args from config file
34 | known_args, _ = _META_.PARSER.parse_known_args()
35 | self.from_files(known_args.cfgs)
36 |
37 | # collect args for packages
38 | for package in REGISTER.get_packages():
39 | m = importlib.import_module(package)
40 | if hasattr(m, 'require_args'):
41 | m.require_args()
42 |
43 | # re-update default value for new args
44 | self.from_files(known_args.cfgs)
45 |
46 | # parse args
47 | _META_.ARGS = _META_.PARSER.parse_args()
48 |
49 | def from_files(self, files):
50 |
51 | # if no config file is provided, skip
52 | if files is None or len(files) <= 0:
53 | return None
54 |
55 | for file in files:
56 | assert os.path.exists(file), "Config file not found: [%s]" % file
57 |             configs = yaml.safe_load(open(file, 'r'))
58 | _META_.PARSER.set_defaults(**configs)
59 |
60 | def get(self, attr, default=None):
61 | if hasattr(_META_.ARGS, attr):
62 | return getattr(_META_.ARGS, attr)
63 | return default
64 |
65 | def yaml(self):
66 | config = {k:v for k,v in sorted(vars(_META_.ARGS).items())}
67 | return yaml.safe_dump(config, default_flow_style=False)
68 |
69 | def __getattr__(self, attr):
70 | try:
71 | return _META_.PARSER.__getattribute__(attr)
72 | except AttributeError:
73 | return _META_.ARGS.__getattribute__(attr)
74 | except:
75 |             traceback.print_exc()
76 | exit(-1)
77 |
78 | def __str__(self):
79 | MAX_WIDTH = 20
80 | table = PrettyTable(["#", "Key", "Value", "Default"])
81 | table.align = 'l'
82 | for i, (k, v) in enumerate(sorted(vars(_META_.ARGS).items())):
83 | v = str(v)
84 | default = str(_META_.PARSER.get_default(k))
85 | if default == v:
86 | default = '--'
87 | table.add_row([i, k, v[:MAX_WIDTH] + ('...' if len(v) > MAX_WIDTH else ''), default])
88 | return table.get_string()
89 |
90 | class CONFIG(object):
91 | __metaclass__ = _META_
--------------------------------------------------------------------------------
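Taken together, `parse` yields a precedence of command line > yaml files > argparse defaults, since yaml values are injected via `set_defaults` before the final `parse_args`. An illustrative invocation (assuming the main module declares `--batch-size`; main.py is not included in this dump):

```
python main.py --cfgs configs/base.yaml configs/cifar10.yaml --batch-size 256
# batch_size: the declared default is replaced by 128 from cifar10.yaml,
# and the explicit command-line value 256 wins over both
```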
/lib/non_parametric_classifier.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-24 22:16:37
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import torch
8 | from torch.autograd import Function
9 | from torch import nn
10 | import math
11 |
12 | from packages.register import REGISTER
13 | from packages.config import CONFIG as cfg
14 |
15 | def require_args():
16 | # args for non-parametric classifier (npc)
17 | cfg.add_argument('--npc-temperature', default=0.1, type=float,
18 | help='temperature parameter for softmax')
19 | cfg.add_argument('--npc-momentum', default=0.5, type=float,
20 | help='momentum for non-parametric updates')
21 |
22 | class NonParametricClassifierOP(Function):
23 | @staticmethod
24 | def forward(self, x, y, memory, params):
25 |
26 | T = params[0].item()
27 | batchSize = x.size(0)
28 |
29 | # inner product
30 | out = torch.mm(x.data, memory.t())
31 | out.div_(T) # batchSize * N
32 |
33 | self.save_for_backward(x, memory, y, params)
34 |
35 | return out
36 |
37 | @staticmethod
38 | def backward(self, gradOutput):
39 | x, memory, y, params = self.saved_tensors
40 | batchSize = gradOutput.size(0)
41 | T = params[0].item()
42 | momentum = params[1].item()
43 |
44 | # add temperature
45 | gradOutput.data.div_(T)
46 |
47 | # gradient of linear
48 | gradInput = torch.mm(gradOutput.data, memory)
49 | gradInput.resize_as_(x)
50 |
51 | # update the memory
52 | weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
53 | weight_pos.mul_(momentum)
54 | weight_pos.add_(torch.mul(x.data, 1-momentum))
55 | w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
56 | updated_weight = weight_pos.div(w_norm)
57 | memory.index_copy_(0, y, updated_weight)
58 |
59 | return gradInput, None, None, None, None
60 |
61 | class NonParametricClassifier(nn.Module):
62 | """Non-parametric Classifier
63 |
64 | Non-parametric Classifier from
65 | "Unsupervised Feature Learning via Non-Parametric Instance Discrimination"
66 |
67 | Extends:
68 | nn.Module
69 | """
70 |
71 | def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
72 | """Non-parametric Classifier initial functin
73 |
74 | Initial function for non-parametric classifier
75 |
76 | Arguments:
77 | inputSize {int} -- in-channels dims
78 | outputSize {int} -- out-channels dims
79 |
80 | Keyword Arguments:
81 | T {int} -- distribution temperate (default: {0.05})
82 | momentum {int} -- memory update momentum (default: {0.5})
83 | """
84 | super(NonParametricClassifier, self).__init__()
85 | stdv = 1 / math.sqrt(inputSize)
86 | self.nLem = outputSize
87 |
88 | self.register_buffer('params',
89 | torch.tensor([cfg.npc_temperature, cfg.npc_momentum]))
90 | stdv = 1. / math.sqrt(inputSize/3)
91 | self.register_buffer('memory', torch.rand(outputSize, inputSize)
92 | .mul_(2*stdv).add_(-stdv))
93 |
94 | def forward(self, x, y):
95 | out = NonParametricClassifierOP.apply(x, y, self.memory, self.params)
96 | return out
97 |
98 | REGISTER.set_package(__name__)
99 | REGISTER.set_class(__name__, 'npc', NonParametricClassifier)
100 |
--------------------------------------------------------------------------------
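For context, a minimal sketch of what `forward` computes (shapes are illustrative; the real call sites live in main.py, which is not included in this dump): unit-norm features are scored against every row of the memory bank and scaled by the temperature:

```python
import torch

N, D, batch = 50000, 128, 8  # instances, feature dim, batch size
memory = torch.randn(N, D)
memory = memory / memory.norm(dim=1, keepdim=True)  # rows kept on the unit sphere

x = torch.randn(batch, D)
x = x / x.norm(dim=1, keepdim=True)  # as produced by lib/normalize.Normalize
T = 0.1                              # --npc-temperature

out = torch.mm(x, memory.t()) / T    # batch x N logits, as in forward()
print(out.shape)                     # torch.Size([8, 50000])
```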
/docs/README.md:
--------------------------------------------------------------------------------
1 | # AND: Anchor Neighbourhood Discovery
2 |
3 | *Accepted by 36th International Conference on Machine Learning (ICML 2019)*.
4 |
5 | Pytorch implementation of [Unsupervised Deep Learning by Neighbourhood Discovery](https://arxiv.org/abs/1904.11567).
6 |
7 |
8 |
9 |
10 | ## Highlight
11 | + We propose the idea of exploiting local neighbourhoods for unsupervised deep learning. This strategy preserves the capability of clustering for class boundary inference whilst minimising the negative impact of class inconsistency typically encountered in clusters.
12 | + We formulate an *Anchor Neighbourhood Discovery (AND)* approach to progressive unsupervised deep learning. The AND model not only generalises the idea of sample specificity learning, but also additionally considers the originally missing sample-to-sample correlation during model learning by a novel neighbourhood supervision design.
13 | + We further introduce a curriculum learning algorithm to gradually perform neighbourhood discovery for maximising the class consistency of neighbourhoods therefore enhancing the unsupervised learning capability.
14 |
15 | ## Main results
16 | The proposed AND model was evaluated on four object image classification datasets including CIFAR 10/100, SVHN and ImageNet12. Results are shown at the following tables:
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Reproduction
25 |
26 | ### Requirements
27 | Python 2.7 and PyTorch 1.0 are required. Please refer to `/path/to/AND/requirements.yaml` for the other necessary modules. The conda environment we used for the experiments can also be rebuilt from it.
28 |
29 | ### Usages
30 |
31 | 1. Clone this repo: `git clone https://github.com/Raymond-sci/AND.git`
32 | 2. Download datasets and store them in `/path/to/AND/data`. (Soft links are recommended to avoid redundant copies of datasets.)
33 | 3. To reproduce our reported result of ResNet18 on CIFAR10, please use the following command: `python main.py --cfgs configs/base.yaml configs/cifar10.yaml`
34 | 4. Running on GPUs: the code runs on CPU by default; use the `--gpus` flag to specify the GPU devices to use.
35 | 5. To evaluate trained models, use `--resume` to set the path of the generated checkpoint file and the `--test-only` flag to exit the program after evaluation.
36 |
37 | Each run of `main.py` starts a new session named with the current timestamp; all related files (checkpoints, logs, etc.) are stored under `sessions/<timestamp>/`.
38 |
39 | ### Pre-trained model
40 | To play with the pre-trained model, please go to [ResNet18](https://drive.google.com/file/d/1tMopB0iLPaJzw81tqZuXbK6YYAQRLXA-/view?usp=sharing) / [AlexNet](https://drive.google.com/file/d/1SeLi34LxuThcLulBaWViwy3kLYQQWX0l/view?usp=sharing). A few things need to be noticed:
41 | + The model is saved in **pytorch** format
42 | + It expects RGB images whose pixel values are normalised with the mean RGB values `mean=[0.485, 0.456, 0.406]` and std RGB values `std=[0.229, 0.224, 0.225]`. Prior to normalisation, the image values must lie in the range `[0.0, 1.0]`.
43 |
44 | ## License
45 | This project is licensed under the MIT License. You may find out more [here](./LICENSE).
46 |
47 | ## Reference
48 | If you use this code, please cite the following paper:
49 |
50 | Jiabo Huang, Qi Dong, Shaogang Gong and Xiatian Zhu. "Unsupervised Deep Learning by Neighbourhood Discovery." Proc. ICML (2019).
51 |
52 | ```
53 | @InProceedings{huang2018and,
54 | title={Unsupervised Deep Learning by Neighbourhood Discovery},
55 |   author={Huang, Jiabo and Dong, Qi and Gong, Shaogang and Zhu, Xiatian},
56 |   booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
57 | year={2019},
58 | }
59 | ```
--------------------------------------------------------------------------------
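The evaluation flags in the README's usage list combine as follows (illustrative; `--test-only` is declared in main.py, which is not included in this dump):

```
python main.py --cfgs configs/base.yaml configs/cifar10.yaml \
    --gpus 0 --resume sessions/<timestamp>/checkpoint/best.ckpt --test-only
```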
/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: pytorch
2 | channels:
3 | - menpo
4 | - pytorch
5 | - conda-forge
6 | - anaconda
7 | - defaults
8 | dependencies:
9 | - backports=1.0=py27_1
10 | - backports.functools_lru_cache=1.5=py27_1
11 | - backports_abc=0.5=py27_0
12 | - blas=1.0=mkl
13 | - cffi=1.11.5=py27he75722e_1
14 | - cloudpickle=0.6.1=py27_0
15 | - cryptography=2.4.2=py27h1ba5d50_0
16 | - cryptography-vectors=2.3.1=py27_0
17 | - cudatoolkit=8.0=3
18 | - cycler=0.10.0=py27_0
19 | - dask-core=0.19.4=py27_0
20 | - dbus=1.13.2=h714fa37_1
21 | - decorator=4.3.0=py27_0
22 | - expat=2.2.6=he6710b0_0
23 | - fontconfig=2.13.0=h9420a91_0
24 | - freetype=2.9.1=h8a8886c_1
25 | - functools32=3.2.3.2=py27_1
26 | - futures=3.2.0=py27_0
27 | - glib=2.56.2=hd408876_0
28 | - gst-plugins-base=1.14.0=hbbd80ab_1
29 | - gstreamer=1.14.0=hb453b48_1
30 | - imageio=2.4.1=py27_0
31 | - intel-openmp=2019.0=118
32 | - kiwisolver=1.0.1=py27hf484d3e_0
33 | - libedit=3.1.20170329=h6b74fdf_2
34 | - libgcc=7.2.0=h69d50b8_2
35 | - libgcc-ng=8.2.0=hdf63c60_1
36 | - libgfortran-ng=7.3.0=hdf63c60_0
37 | - libpng=1.6.36=hbc83047_0
38 | - libstdcxx-ng=8.2.0=hdf63c60_1
39 | - libtiff=4.0.9=he85c1e1_2
40 | - libuuid=1.0.3=h1bed415_2
41 | - libxcb=1.13=h1bed415_1
42 | - libxml2=2.9.8=h26e45fe_1
43 | - matplotlib=2.2.3=py27hb69df0a_0
44 | - mkl=2019.0=118
45 | - mkl_fft=1.0.4=py27h4414c95_1
46 | - mkl_random=1.0.1=py27h4414c95_1
47 | - ncurses=6.1=hf484d3e_0
48 | - networkx=2.2=py27_1
49 | - ninja=1.8.2=py27h6bb024c_1
50 | - numpy=1.15.1=py27h1d66e8a_0
51 | - numpy-base=1.15.1=py27h81de0dd_0
52 | - olefile=0.46=py27_0
53 | - openssl=1.1.1=h7b6447c_0
54 | - pcre=8.42=h439df22_0
55 | - pillow=5.2.0=py27heded4f4_0
56 | - pip=10.0.1=py27_0
57 | - pycparser=2.18=py27_1
58 | - pyopenssl=18.0.0=py27_0
59 | - pyparsing=2.2.0=py27_1
60 | - pyqt=5.9.2=py27h05f1152_2
61 | - python=2.7.15=h9bab390_6
62 | - python-dateutil=2.7.3=py27_0
63 | - pytz=2018.5=py27_0
64 | - pywavelets=1.0.1=py27hdd07704_0
65 | - pyyaml=3.13=py27h14c3975_0
66 | - qt=5.9.7=h5867ecd_1
67 | - readline=7.0=h7b6447c_5
68 | - scikit-image=0.14.0=py27hf484d3e_1
69 | - scikit-learn=0.19.2=py27h4989274_0
70 | - scipy=1.1.0=py27hfa4b5c9_1
71 | - setuptools=40.2.0=py27_0
72 | - singledispatch=3.4.0.3=py27h9bcb476_0
73 | - sip=4.19.8=py27hf484d3e_0
74 | - six=1.11.0=py27_1
75 | - sqlite=3.26.0=h7b6447c_0
76 | - subprocess32=3.5.2=py27h14c3975_0
77 | - tk=8.6.8=hbc83047_0
78 | - toolz=0.9.0=py27_0
79 | - tornado=5.1=py27h14c3975_0
80 | - wheel=0.31.1=py27_0
81 | - xz=5.2.4=h14c3975_4
82 | - yaml=0.1.7=h96e3832_1
83 | - asn1crypto=0.24.0=py27_1003
84 | - chardet=3.0.4=py27_3
85 | - dominate=2.3.1=py_1
86 | - enum34=1.1.6=py27_1001
87 | - idna=2.7=py27_1002
88 | - ipaddress=1.0.22=py_1
89 | - libsodium=1.0.16=h470a237_1
90 | - prettytable=0.7.2=py_2
91 | - pysocks=1.6.8=py27_1002
92 | - python-lmdb=0.92=py27_0
93 | - pyzmq=17.1.2=py27hae99301_0
94 | - requests=2.19.1=py27_1
95 | - torchfile=0.1.0=py_0
96 | - urllib3=1.23=py27_1
97 | - visdom=0.1.8.5=0
98 | - websocket-client=0.53.0=py27_0
99 | - zeromq=4.2.5=hfc679d8_6
100 | - ca-certificates=2018.12.5=0
101 | - certifi=2018.11.29=py27_0
102 | - cython=0.29.7=py27he6710b0_0
103 | - icu=58.2=h9c2bf20_1
104 | - jpeg=9b=h024ee3a_2
105 | - libffi=3.2.1=hd88cf55_4
106 | - zlib=1.2.11=ha838bed_2
107 | - opencv=2.4.11=nppy27_0
108 | - cuda80=1.0=h205658b_0
109 | - cuda91=1.0=h4c16780_0
110 | - cuda92=1.0=0
111 | - faiss-gpu=1.4.0=py27_cuda8.0.61_1
112 | - pytorch=1.0.1=py2.7_cuda8.0.61_cudnn7.1.2_2
113 | - torchvision=0.2.2=py_3
114 | - pip:
115 | - cython-based-reid-evaluation-code==0.0.0
116 | - dask==0.19.4
117 | - faiss==1.4.0
118 | - humanfriendly==4.18
119 | - lmdb==0.92
120 | - monotonic==1.5
121 | - protobuf==3.7.0
122 | - tensorboardx==1.6
123 | - torch==1.0.1.post2
124 | prefix: /import/sgg-homes/jh327/Applications/conda/envs/pytorch
125 |
126 |
--------------------------------------------------------------------------------
/packages/session.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-27 09:38:53
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import os
8 | import time
9 |
10 | import torch
11 | import importlib
12 | import numpy as np
13 |
14 | from loggers.std_logger import STDLogger as logger
15 | from register import REGISTER
16 | from config import CONFIG as cfg
17 |
18 | REGISTER.set_package(__name__)
19 |
20 | def require_args():
21 |
22 | # timestamp
23 | stt = time.strftime('%Y%m%d-%H%M%S', time.gmtime())
24 | tt = int(time.time())
25 |
26 | cfg.add_argument('--session', default=stt, type=str,
27 | help='session name (default: %s)' % stt)
28 | cfg.add_argument('--sess-dir', default='sessions', type=str,
29 | help='directory to store session. (default: sessions)')
30 | cfg.add_argument('--print-args', action='store_true',
31 | help='do nothing but print all args. (default: False)')
32 | cfg.add_argument('--seed', default=tt, type=int,
33 | help='session random seed. (default: %d)' % tt)
34 | cfg.add_argument('--brief', action='store_true',
35 | help='print log with priority higher than debug. '
36 | '(default: False)')
37 | cfg.add_argument('--debug', action='store_true',
38 | help='if debugging, no log or checkpoint files will be stored. '
39 | '(default: False)')
40 | cfg.add_argument('--gpus', default='', type=str,
41 | help='available gpu list. (default: \'\')')
42 | cfg.add_argument('--resume', default=None, type=str,
43 | help='path to resume session. (default: None)')
44 | cfg.add_argument('--restart', action='store_true',
45 | help='load session status and start a new one. '
46 | '(default: False)')
47 |
48 | def run(main):
49 |
50 | # import main module
51 | main = importlib.import_module(main)
52 | # parse args
53 | main.require_args()
54 | cfg.parse()
55 |
56 | # setup session according to args
57 | setup()
58 |
59 | # run main function
60 | main.main()
61 |
62 | def setup():
63 | """
64 | set up common environment for training
65 | """
66 |
67 | # print args
68 | if not cfg.brief:
69 | print cfg
70 | # exit if require to print args only
71 | if cfg.print_args:
72 | exit(0)
73 |
74 | # if not verbose, set log level to info
75 | logger.setup(logger.INFO if cfg.brief else logger.DEBUG)
76 |
77 | logger.info('Start to setup session')
78 |
79 | # fix random seeds
80 | torch.manual_seed(cfg.seed)
81 | torch.cuda.manual_seed_all(cfg.seed)
82 | np.random.seed(cfg.seed)
83 |
84 | # set visible gpu devices at main function
85 | logger.info('Visible gpu devices are: %s' % cfg.gpus)
86 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus
87 |
88 | # setup session name and session path
89 | if cfg.resume and cfg.resume.strip() != '' and not cfg.restart:
90 | assert os.path.exists(cfg.resume), ("Resume file not "
91 | "found: %s" % cfg.resume)
92 | ckpt = torch.load(cfg.resume)
93 | if 'session' in ckpt:
94 | cfg.session = ckpt['session']
95 | cfg.sess_dir = os.path.join(cfg.sess_dir, cfg.session)
96 | logger.info('Current session name: %s' % cfg.session)
97 |
98 | # setup checkpoint dir
99 | cfg.ckpt_dir = os.path.join(cfg.sess_dir, 'checkpoint')
100 | if not os.path.exists(cfg.ckpt_dir) and not cfg.debug:
101 | os.makedirs(cfg.ckpt_dir)
102 |
103 | # redirect logs to file
104 | if cfg.log_file and not cfg.debug:
105 | logger.setup(to_file=os.path.join(cfg.sess_dir, 'log.txt'))
106 |
107 | # setup tfb log dir
108 | cfg.tfb_dir = os.path.join(cfg.sess_dir, 'tfboard')
109 | if not os.path.exists(cfg.tfb_dir) and not cfg.debug:
110 | os.makedirs(cfg.tfb_dir)
111 |
112 | # store options at log directory
113 | if not cfg.debug:
114 | with open(os.path.join(cfg.sess_dir, 'config.yaml'), 'w') as out:
115 | out.write(cfg.yaml() + '\n')
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
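Given the defaults above and the checkpoint helpers in lib/utils.py, a non-debug run started at 2019-05-01 12:00:00 UTC would leave a layout like this (illustrative):

```
sessions/20190501-120000/
├── checkpoint/   # latest.ckpt / best.ckpt written by lib/utils.save_ckpt
├── tfboard/      # tensorboardX event files when --log-tfb is set
├── log.txt       # when --log-file is set
└── config.yaml   # frozen dump of all parsed args
```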
/lib/protocols.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datasets
3 | from lib.utils import AverageMeter, traverse
4 | import sys
5 |
6 | from packages.register import REGISTER
7 | from packages.loggers.std_logger import STDLogger as logger
8 |
9 | def NN(net, npc, trainloader, testloader, K=0, sigma=0.1,
10 | recompute_memory=False, device='cpu'):
11 |
12 | # switch model to evaluation mode
13 | net.eval()
14 |
15 | # tracking variables
16 | correct = 0.
17 | total = 0
18 |
19 |     trainFeatures = npc.memory.t()
20 | trainLabels = torch.LongTensor(trainloader.dataset.labels).to(device)
21 |
22 | # recompute features for training samples
23 | if recompute_memory:
24 | trainFeatures, trainLabels = traverse(net, trainloader,
25 |             testloader.dataset.transform, device=device)
26 | trainFeatures = trainFeatures.t()
27 |
28 | # start to evaluate
29 | with torch.no_grad():
30 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
31 | logger.progress(batch_idx, len(testloader), 'processing %d/%d batch...')
32 | inputs, targets = inputs.to(device), targets.to(device)
33 | batchSize = inputs.size(0)
34 |
35 | # forward
36 | features = net(inputs)
37 |
38 | # cosine similarity
39 | dist = torch.mm(features, trainFeatures)
40 |
41 | yd, yi = dist.topk(1, dim=1, largest=True, sorted=True)
42 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
43 | retrieval = torch.gather(candidates, 1, yi)
44 |
45 | retrieval = retrieval.narrow(1, 0, 1).clone().view(-1)
46 | yd = yd.narrow(1, 0, 1)
47 |
48 | total += targets.size(0)
49 | correct += retrieval.eq(targets.data).sum().item()
50 |
51 | return correct/total
52 |
53 | def kNN(net, npc, trainloader, testloader, K=200, sigma=0.1,
54 | recompute_memory=False, device='cpu'):
55 |
56 | # set the model to evaluation mode
57 | net.eval()
58 |
59 | # tracking variables
60 | total = 0
61 |
62 | trainFeatures = npc.memory
63 | trainLabels = torch.LongTensor(trainloader.dataset.labels).to(device)
64 |
65 | # recompute features for training samples
66 | if recompute_memory:
67 | trainFeatures, trainLabels = traverse(net, trainloader,
68 | testloader.dataset.transform, device)
69 | trainFeatures = trainFeatures.t()
70 | C = trainLabels.max() + 1
71 |
72 | # start to evaluate
73 | top1 = 0.
74 | top5 = 0.
75 | with torch.no_grad():
76 | retrieval_one_hot = torch.zeros(K, C.item()).to(device)
77 | for batch_idx, (inputs, targets, indexes) in enumerate(testloader):
78 | logger.progress(batch_idx, len(testloader), 'processing %d/%d batch...')
79 |
80 | batchSize = inputs.size(0)
81 | targets, inputs = targets.to(device), inputs.to(device)
82 |
83 | # forward
84 | features = net(inputs)
85 |
86 | # cosine similarity
87 | dist = torch.mm(features, trainFeatures)
88 |
89 | yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
90 | candidates = trainLabels.view(1,-1).expand(batchSize, -1)
91 | retrieval = torch.gather(candidates, 1, yi)
92 |
93 | retrieval_one_hot.resize_(batchSize * K, C).zero_()
94 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)
95 | yd_transform = yd.clone().div_(sigma).exp_()
96 | probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , C),
97 | yd_transform.view(batchSize, -1, 1)), 1)
98 | _, predictions = probs.sort(1, True)
99 |
100 | # Find which predictions match the target
101 | correct = predictions.eq(targets.data.view(-1,1))
102 |
103 | top1 = top1 + correct.narrow(1,0,1).sum().item()
104 | top5 = top5 + correct.narrow(1,0,5).sum().item()
105 |
106 | total += targets.size(0)
107 |
108 | return top1/total
109 |
110 | def get(name):
111 | return REGISTER.get_class(__name__, name)
112 |
113 | REGISTER.set_package(__name__)
114 | REGISTER.set_class(__name__, 'knn', kNN)
115 | REGISTER.set_class(__name__, 'nn', NN)
--------------------------------------------------------------------------------
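
The exp-weighted vote at the heart of kNN() above can be checked on toy tensors; a minimal sketch with made-up features and labels (K, C and sigma are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, C, sigma = 3, 4, 0.1
queries = F.normalize(torch.randn(2, 8), dim=1)    # 2 test samples
memory = F.normalize(torch.randn(10, 8), dim=1)    # 10 train features
labels = torch.randint(0, C, (10,))

dist = queries @ memory.t()                        # cosine similarity
yd, yi = dist.topk(K, dim=1, largest=True)         # K nearest neighbours
retrieval = labels.view(1, -1).expand(2, -1).gather(1, yi)

one_hot = torch.zeros(2 * K, C).scatter_(1, retrieval.view(-1, 1), 1)
weights = yd.div(sigma).exp()                      # similarity-based weights
probs = (one_hot.view(2, K, C) * weights.view(2, K, 1)).sum(1)
print(probs.argmax(1))                             # predicted labels
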
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from lib.normalize import Normalize
13 |
14 | from torch.autograd import Variable
15 |
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, in_planes, planes, stride=1):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 |
27 | self.shortcut = nn.Sequential()
28 | if stride != 1 or in_planes != self.expansion*planes:
29 | self.shortcut = nn.Sequential(
30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
31 | nn.BatchNorm2d(self.expansion*planes)
32 | )
33 |
34 | def forward(self, x):
35 | out = F.relu(self.bn1(self.conv1(x)))
36 | out = self.bn2(self.conv2(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1):
46 | super(Bottleneck, self).__init__()
47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
48 | self.bn1 = nn.BatchNorm2d(planes)
49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
50 | self.bn2 = nn.BatchNorm2d(planes)
51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
53 |
54 | self.shortcut = nn.Sequential()
55 | if stride != 1 or in_planes != self.expansion*planes:
56 | self.shortcut = nn.Sequential(
57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
58 | nn.BatchNorm2d(self.expansion*planes)
59 | )
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(self.conv1(x)))
63 | out = F.relu(self.bn2(self.conv2(out)))
64 | out = self.bn3(self.conv3(out))
65 | out += self.shortcut(x)
66 | out = F.relu(out)
67 | return out
68 |
69 |
70 | class ResNet(nn.Module):
71 | def __init__(self, block, num_blocks, low_dim=128):
72 | super(ResNet, self).__init__()
73 | self.in_planes = 64
74 |
75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
76 | self.bn1 = nn.BatchNorm2d(64)
77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
81 | self.linear = nn.Linear(512*block.expansion, low_dim)
82 | self.l2norm = Normalize(2)
83 |
84 | def _make_layer(self, block, planes, num_blocks, stride):
85 | strides = [stride] + [1]*(num_blocks-1)
86 | layers = []
87 | for stride in strides:
88 | layers.append(block(self.in_planes, planes, stride))
89 | self.in_planes = planes * block.expansion
90 | return nn.Sequential(*layers)
91 |
92 | def forward(self, x):
93 | out = F.relu(self.bn1(self.conv1(x)))
94 | out = self.layer1(out)
95 | out = self.layer2(out)
96 | out = self.layer3(out)
97 | out = self.layer4(out)
98 | out = F.avg_pool2d(out, 4)
99 | out = out.view(out.size(0), -1)
100 | out = self.linear(out)
101 | out = self.l2norm(out)
102 | return out
103 |
104 |
105 | def ResNet18(low_dim=128):
106 | return ResNet(BasicBlock, [2,2,2,2], low_dim)
107 |
108 | def ResNet34(low_dim=128):
109 | return ResNet(BasicBlock, [3,4,6,3], low_dim)
110 |
111 | def ResNet50(low_dim=128):
112 | return ResNet(Bottleneck, [3,4,6,3], low_dim)
113 |
114 | def ResNet101(low_dim=128):
115 | return ResNet(Bottleneck, [3,4,23,3], low_dim)
116 |
117 | def ResNet152(low_dim=128):
118 | return ResNet(Bottleneck, [3,8,36,3], low_dim)
119 |
120 |
121 | def test():
122 | net = ResNet18()
123 | y = net(Variable(torch.randn(1,3,32,32)))
124 | print(y.size())
125 |
126 | # test()
127 |
--------------------------------------------------------------------------------
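
A quick usage check for the backbone above, run from the repo root so that models and lib are importable; the batch below is illustrative:

import torch
from models.resnet_cifar import ResNet18

net = ResNet18(low_dim=128)
x = torch.randn(4, 3, 32, 32)        # a CIFAR-sized batch
feats = net(x)
print(feats.shape)                   # torch.Size([4, 128])
print(feats.norm(dim=1))             # ~1.0 per row, thanks to the final l2norm
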
/packages/register.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-25 15:08:10
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 |
9 | class REGISTER:
10 | """Singleton Register Class
11 |
12 | This class is used to register all modules in the whole package
13 |
14 | Variables:
15 | PACKAGES_2_CLASSES {dict} -- record the registered modules
16 | """
17 |
18 | PACKAGES_2_CLASSES = dict()
19 |
20 | @staticmethod
21 | def set_package(package):
22 | """set package
23 |
24 | This function is called when a new package has been created
25 |
26 | Arguments:
27 | package {str} -- package name
28 | """
29 | if package not in REGISTER.PACKAGES_2_CLASSES:
30 | REGISTER.PACKAGES_2_CLASSES[package] = dict()
31 |
32 | @staticmethod
33 | def get_packages():
34 | """get packages list
35 |
36 | Get the packages list stored in REGISTER.PACKAGES_2_CLASSES
37 |
38 | Returns:
39 | list -- packages list
40 | """
41 | return list(REGISTER.PACKAGES_2_CLASSES.keys())
42 |
43 | @staticmethod
44 | def get_package_name(module):
45 | """get the package name of a module
46 |
47 | Get the package name of the target module by removing the text after
48 | the last dot occurring in the provided name
49 |
50 | Arguments:
51 | module {str} -- target module name
52 | """
53 | return '.'.join(module.split('.')[:-1])
54 |
55 | @staticmethod
56 | def is_package_registered(package):
57 | """Check whether a package has been registered
58 |
59 | Check whether a package has been registered by checking whether its
60 | name appears in PACKAGES_2_CLASSES
61 |
62 | Arguments:
63 | package {str} -- target package name
64 |
65 | Returns:
66 | bool -- indicating whether package has been registered
67 | """
68 | return package in REGISTER.PACKAGES_2_CLASSES
69 |
70 | @staticmethod
71 | def _check_package(package):
72 | assert package in REGISTER.PACKAGES_2_CLASSES, ('No package named [%s] '
73 | 'has been registered' % package)
74 |
75 | @staticmethod
76 | def get_classes(package):
77 | """get all classes on the target package
78 |
79 | Get all registered classes on the target package according to the
80 | PACKAGES_2_CLASSES dictory
81 |
82 | Arguments:
83 | package {str} -- package name used when registering
84 |
85 | Returns:
86 | dict -- {'name1' : class1, 'name2' : class2, ...}
87 | """
88 |
89 | # make sure the package has been registered
90 | REGISTER._check_package(package)
91 |
92 | return REGISTER.PACKAGES_2_CLASSES[package]
93 |
94 | @staticmethod
95 | def is_class_registered(package, name):
96 | return (package in REGISTER.PACKAGES_2_CLASSES and
97 | name in REGISTER.PACKAGES_2_CLASSES[package])
98 |
99 | @staticmethod
100 | def get_class(package, name):
101 | """get class
102 |
103 | This function is used to get the class named `name` from
104 | package `package`
105 |
106 | Arguments:
107 | package {str} -- package name, should be the same as the one when registering
108 | name {str} -- class name, should be the same as the one when registering
109 |
110 | Returns:
111 | [Object] -- found class
112 | """
113 |
114 | # make sure the package has been registered
115 | REGISTER._check_package(package)
116 |
117 | pack = REGISTER.PACKAGES_2_CLASSES[package]
118 |
119 | assert name in pack, ('No class named [%s] has been registered '
120 | 'in package [%s]' % (name, package))
121 |
122 | return pack[name]
123 |
124 | @staticmethod
125 | def set_class(package, name, cls):
126 | """set class
127 |
128 | This function is used to register a class with a specific name in
129 | the target package
130 |
131 | Arguments:
132 | package {str} -- target package name
133 | name {str} -- class name which will be used when query
134 | cls {Object} -- class object
135 | """
136 |
137 | # get the corresponding package name
138 | # package = '.'.join(package.split('.')[:-1])
139 | if package not in REGISTER.PACKAGES_2_CLASSES:
140 | REGISTER.set_package(package)
141 |
142 | # if the `cls` arg is a string, resolve it to the module object in sys.modules
143 | if isinstance(cls, str):
144 | cls = sys.modules[cls]
145 |
146 | REGISTER.PACKAGES_2_CLASSES[package][name] = cls
--------------------------------------------------------------------------------
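
A minimal usage sketch for the register above; the package and class names here are hypothetical:

from packages.register import REGISTER

class ToyModel(object):              # hypothetical class to register
    pass

REGISTER.set_class('models.toys', 'toy', ToyModel)   # auto-registers the package
assert REGISTER.is_class_registered('models.toys', 'toy')
assert REGISTER.get_class('models.toys', 'toy') is ToyModel
print(REGISTER.get_classes('models.toys'))           # {'toy': <class 'ToyModel'>}
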
/packages/loggers/std_logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-11-11 15:08:06
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 | import logging
9 |
10 | from ..config import CONFIG as cfg
11 |
12 | class ColoredFormatter(logging.Formatter):
13 | """Colored Formatter for logging module
14 |
15 | different logging levels will use different colors when printing
16 |
17 | Extends:
18 | logging.Formatter
19 |
20 | Variables:
21 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE {[Number]} -- [default colors]
22 | RESET_SEQ {str} -- [Sequence end flag]
23 | COLOR_SEQ {str} -- [color sequence start flag]
24 | BOLD_SEQ {str} -- [bold sequence start flag]
25 | COLORS {dict} -- [logging level to color dictionary]
26 | """
27 |
28 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
29 | RESET_SEQ = "\033[0m"
30 | COLOR_SEQ = "\033[%dm"
31 | BOLD_SEQ = "\033[1m"
32 | COLORS = {
33 | 'WARNING': YELLOW,
34 | 'DEBUG': GREEN,
35 | 'CRITICAL': BLUE,
36 | 'ERROR': RED
37 | }
38 |
39 | def __init__(self, *args, **kwargs):
40 | logging.Formatter.__init__(self, *args, **kwargs)
41 |
42 | def format(self, record):
43 | msg = logging.Formatter.format(self, record)
44 | levelname = record.levelname
45 | msg_color = msg
46 | if levelname in ColoredFormatter.COLORS:
47 | msg_color = (ColoredFormatter.COLOR_SEQ %
48 | (30 + ColoredFormatter.COLORS[levelname]) + msg +
49 | ColoredFormatter.RESET_SEQ)
50 | return msg_color
51 |
52 | class STDLogger:
53 | '''
54 | static class for logging
55 | call setup() first to set log level and then call info/debug/error/warn
56 | to print log msg
57 | '''
58 |
59 | LOGGER = None
60 |
61 | INFO = logging.INFO
62 | DEBUG = logging.DEBUG
63 | C_LEVEL = logging.DEBUG
64 |
65 | FMT_GENERAL = logging.Formatter('[%(levelname)s][%(asctime)s]\t%(message)s',
66 | datefmt='%Y-%m-%d %H:%M:%S')
67 | FMT_COLOR = ColoredFormatter('[%(levelname)s][%(asctime)s]\t%(message)s',
68 | datefmt='%Y-%m-%d %H:%M:%S')
69 |
70 | @staticmethod
71 | def require_args():
72 |
73 | cfg.add_argument('--log-file', action='store_true',
74 | help='store log to file. (default: False)')
75 |
76 | @staticmethod
77 | def setup(level=None, to_file=None):
78 | # if level is not provided, keep the current one
79 | level = level if level is not None else STDLogger.C_LEVEL
80 |
81 | if STDLogger.LOGGER is None:
82 | STDLogger.LOGGER = logging.getLogger(__name__)
83 | STDLogger.LOGGER.propagate = False
84 | # declare and set console handler
85 | console_handler = logging.StreamHandler()
86 | console_handler.setFormatter(STDLogger.FMT_COLOR)
87 | STDLogger.LOGGER.addHandler(console_handler)
88 |
89 | if to_file is not None:
90 | # declare and set file handler
91 | STDLogger.LOGGER.debug('Log will be stored in %s' % to_file)
92 | file_handler = logging.FileHandler(to_file)
93 | file_handler.setFormatter(STDLogger.FMT_GENERAL)
94 | STDLogger.LOGGER.addHandler(file_handler)
95 |
96 | STDLogger.LOGGER.setLevel(level)
97 | STDLogger.C_LEVEL = level
98 |
99 | @staticmethod
100 | def check():
101 | if not isinstance(STDLogger.LOGGER, logging.Logger):
102 | STDLogger.setup()
103 | # raise ValueError('Call logger.setup(level) to initialize')
104 |
105 | @staticmethod
106 | def info(*args, **kwargs):
107 | STDLogger.check()
108 | STDLogger.erase()
109 | return STDLogger.LOGGER.info(*args, **kwargs)
110 |
111 | @staticmethod
112 | def debug(*args, **kwargs):
113 | STDLogger.check()
114 | STDLogger.erase()
115 | return STDLogger.LOGGER.debug(*args, **kwargs)
116 |
117 | @staticmethod
118 | def error(*args, **kwargs):
119 | STDLogger.check()
120 | STDLogger.erase()
121 | return STDLogger.LOGGER.error(*args, **kwargs)
122 |
123 | @staticmethod
124 | def warn(*args, **kwargs):
125 | STDLogger.check()
126 | STDLogger.erase()
127 | return STDLogger.LOGGER.warning(*args, **kwargs)
128 |
129 | @staticmethod
130 | def erase_lines(n=1):
131 | for _ in range(n):
132 | sys.stdout.write('\x1b[1A')
133 | sys.stdout.write('\x1b[2K')
134 | sys.stdout.flush()
135 |
136 | @staticmethod
137 | def go_up():
138 | sys.stdout.write('\x1b[1A')
139 | sys.stdout.flush()
140 |
141 | @staticmethod
142 | def erase():
143 | sys.stdout.write('\x1b[2K')
144 | sys.stdout.flush()
145 |
146 | @staticmethod
147 | def progress(current, total, msg='processing %d/%d item...'):
148 | STDLogger.erase()
149 | print(msg % (current, total))
150 | STDLogger.go_up()
151 |
152 | from ..register import REGISTER
153 | REGISTER.set_class(REGISTER.get_package_name(__name__), 'STDLogger', STDLogger)
--------------------------------------------------------------------------------
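
A short usage sketch for the logger above, assuming the surrounding packages package is importable:

from packages.loggers.std_logger import STDLogger as logger

logger.setup(logger.DEBUG)           # coloured console handler
logger.info('session started')
logger.debug('a debug message')

# progress() keeps rewriting a single console line via ANSI escapes
for i in range(5):
    logger.progress(i, 5)
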
/packages/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2019-01-25 14:58:24
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | from ..register import REGISTER
8 | from ..loggers.std_logger import STDLogger as logger
9 | from ..utils import tuple_or_list, get_valid_size
10 | from ..config import CONFIG as cfg
11 |
12 | from . import transforms as custom_transforms
13 | import torchvision.transforms as transforms
14 |
15 | def require_args():
16 | """all args for dataset objects
17 |
18 | Arguments:
19 | parser {argparse} -- current version of argparse object
20 | """
21 |
22 | known_args, _ = cfg.parse_known_args()
23 |
24 | # basic args for all datasets
25 | cfg.add_argument('--data-root', default=None, type=str,
26 | help='root path to dataset')
27 | cfg.add_argument('--resize', default="256", type=str,
28 | help='resize into. (default: 256)')
29 | cfg.add_argument('--size', default="224", type=str,
30 | help='crop into. (default: 224)')
31 | cfg.add_argument('--scale', default=None, type=str,
32 | help='scale for random resize crop. (default: None)')
33 | cfg.add_argument('--ratio', default="(0.75, 1.3333333333333)", type=str,
34 | help='ratio for random resize crop. (default: (0.75, 1.3333))')
35 | cfg.add_argument('--colorjitter', default=None, type=str,
36 | help='color jitters for input. (default: None)')
37 | cfg.add_argument('--random-grayscale', default=0, type=float,
38 | help='transform input to gray scale. (default: 0)')
39 | cfg.add_argument('--random-horizontal-flip', action='store_true',
40 | help='random horizontally flip for input. (default: False)')
41 |
42 | # basic args for all dataloader
43 | cfg.add_argument('--batch-size', default=128, type=int,
44 | help='batch size for input data. (default: 128)')
45 | cfg.add_argument('--workers-num', default=4, type=int,
46 | help='number of workers being used to load data. (default: 4)')
47 |
48 | # get args for datasets
49 | if (REGISTER.is_package_registered(__name__) and
50 | REGISTER.is_class_registered(__name__, known_args.dataset)):
51 |
52 | dataset = get(known_args.dataset)
53 |
54 | if hasattr(dataset, 'require_args'):
55 | dataset.require_args()
56 |
57 | def get(name):
58 | return REGISTER.get_class(__name__, name)
59 |
60 | def get_transforms(stage='train', means=None, stds=None):
61 |
62 | transform = []
63 |
64 | stage = stage.lower()
65 | assert stage in ['train', 'test'], ('arg [stage]'
66 | ' should be one of ["train", "test"]')
67 | resize = get_valid_size(cfg.resize)
68 | size = get_valid_size(cfg.size)
69 | if stage == 'train':
70 | # size transform
71 | scale = tuple_or_list(cfg.scale)
72 | if scale:
73 | ratio = tuple_or_list(cfg.ratio)
74 | logger.debug('Training samples will be randomly resized and cropped '
75 | 'with size %s, scale %s and ratio %s'
76 | % (size, scale, ratio))
77 | transform.append(custom_transforms.RandomResizedCrop(
78 | size=size, scale=scale, ratio=ratio))
79 | else:
80 | logger.debug('Training samples will be resized to %s and then '
81 | 'randomly cropped to %s' % (resize, size))
82 | transform.append(transforms.Resize(size=resize))
83 | transform.append(transforms.RandomCrop(size))
84 | # color jitter transform
85 | colorjitter = tuple_or_list(cfg.colorjitter)
86 | if colorjitter is not None:
87 | logger.debug('Training samples will be enhanced with color jitter '
88 | 'using args: %s' % (colorjitter,))
89 | transform.append(transforms.ColorJitter(*colorjitter))
90 |
91 | # gray scale
92 | if cfg.random_grayscale > 0:
93 | logger.debug('Training samples will be randomly converted to '
94 | 'grayscale with probability %.2f' % cfg.random_grayscale)
95 | transform.append(transforms.RandomGrayscale(
96 | p=cfg.random_grayscale))
97 |
98 | # random horizontal flip
99 | if cfg.random_horizontal_flip:
100 | logger.debug('Training samples will be randomly flipped horizontally')
101 | transform.append(transforms.RandomHorizontalFlip())
102 |
103 | else:
104 | logger.debug('Testing samples will be resized to %s and then center '
105 | 'cropped to %s' % (resize, size))
106 | transform.extend([transforms.Resize(resize),
107 | transforms.CenterCrop(size)])
108 |
109 | # to tensor
110 | transform.append(transforms.ToTensor())
111 |
112 | # normalize
113 | means = tuple_or_list(means)
114 | stds = tuple_or_list(stds)
115 | if not (means is None or stds is None):
116 | logger.debug('Samples will be normalized with means: %s and stds: %s' %
117 | (means, stds))
118 | transform.append(transforms.Normalize(means, stds))
119 | else:
120 | logger.debug('Input images will not be normalized')
121 |
122 | return transforms.Compose(transform)
123 |
124 | def register(name, cls):
125 | REGISTER.set_class(__name__, name, cls)
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
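
For a 32x32 dataset, the pipeline composed by get_transforms('train', ...) above is roughly the following torchvision chain; a sketch only, assuming --random-horizontal-flip is set, no scale or colorjitter, and the commonly used CIFAR-10 statistics (not taken from this repo's configs):

import torchvision.transforms as transforms

train_tf = transforms.Compose([
    transforms.Resize(32),                # cfg.resize
    transforms.RandomCrop(32),            # cfg.size
    transforms.RandomHorizontalFlip(),    # --random-horizontal-flip
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # assumed CIFAR-10 means
                         (0.2470, 0.2435, 0.2616)),  # assumed CIFAR-10 stds
])
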
/lib/ans_discovery.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-10-12 21:37:13
4 | # @Author : Raymond Wong (jiabo.huang@qmul.ac.uk)
5 | # @Link : github.com/Raymond-sci
6 |
7 | import sys
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from packages.config import CONFIG as cfg
14 | from packages.register import REGISTER
15 | from packages.loggers.std_logger import STDLogger as logger
16 |
17 |
18 | def require_args():
19 | cfg.add_argument('--ANs-select-rate', default=0.25, type=float,
20 | help='ANs select rate at each round')
21 | cfg.add_argument('--ANs-size', default=1, type=int,
22 | help='ANs size discarding the anchor')
23 |
24 | class ANsDiscovery(nn.Module):
25 | """Discovery ANs
26 |
27 | Discovery ANs according to current round, select_rate and most importantly,
28 | all sample's corresponding entropy
29 | """
30 |
31 | def __init__(self, nsamples):
32 | """Object used to discovery ANs
33 |
34 | Discovery ANs according to the total amount of samples, ANs selection
35 | rate, ANs size
36 |
37 | Arguments:
38 | nsamples {int} -- total number of sampels
39 | select_rate {float} -- ANs selection rate
40 | ans_size {int} -- ANs size
41 |
42 | Keyword Arguments:
43 | device {str} -- [description] (default: {'cpu'})
44 | """
45 | super(ANsDiscovery, self).__init__()
46 |
47 | # not going to use ``register_buffer'' as
48 | # they are determined by configs
49 | self.select_rate = cfg.ANs_select_rate
50 | self.ANs_size = cfg.ANs_size
51 | # number of samples
52 | self.register_buffer('samples_num', torch.tensor(nsamples))
53 | # indexes list of anchor samples
54 | self.register_buffer('anchor_indexes', torch.LongTensor([]))
55 | # indexes list of instance samples
56 | self.register_buffer('instance_indexes', torch.arange(nsamples).long())
57 | # anchor samples' and instance samples' position
58 | self.register_buffer('position', -1 * torch.arange(nsamples).long() - 1)
59 | # anchor samples' neighbours
60 | self.register_buffer('neighbours', torch.LongTensor([]))
61 | # each sample's entropy
62 | self.register_buffer('entropy', torch.FloatTensor(nsamples))
63 | # consistency
64 | self.register_buffer('consistency', torch.tensor(0.))
65 |
66 | def get_ANs_num(self, round):
67 | """Get number of ANs
68 |
69 | Get number of ANs at target round according to the select rate
70 |
71 | Arguments:
72 | round {int} -- target round
73 |
74 | Returns:
75 | int -- number of ANs
76 | """
77 | return int(self.samples_num.float() * self.select_rate * round)
78 |
79 | def update(self, round, npc, cheat_labels=None):
80 | """Update ANs
81 |
82 | Discover new ANs and update `anchor_indexes`, `instance_indexes` and
83 | `neighbours`
84 |
85 | Arguments:
86 | round {int} -- target round
87 | npc {Module} -- non-parametric classifier
88 | cheat_labels {list} -- used to compute consistency of chosen ANs only
89 |
90 | Returns:
91 | number -- [updated consistency]
92 | """
93 | with torch.no_grad():
94 | batch_size = 100
95 | ANs_num = self.get_ANs_num(round)
96 | logger.debug('Going to choose %d samples as anchors' % ANs_num)
97 | features = npc.memory
98 |
99 | logger.debug('Start to compute each sample\'s entropy')
100 | for start in range(0, self.samples_num, batch_size):
101 | logger.progress(start, self.samples_num, 'processing %d/%d samples...')
102 |
103 | end = start + batch_size
104 | end = min(end, self.samples_num)
105 |
106 | preds = F.softmax(npc(features[start:end], None), 1)
107 | self.entropy[start:end] = -(preds * preds.log()).sum(1)
108 |
109 | logger.debug('Compute entropy done, max(%.2f), min(%.2f), mean(%.2f)'
110 | % (self.entropy.max(), self.entropy.min(), self.entropy.mean()))
111 |
112 | # get the anchor list and instance list according to the computed
113 | # entropy
114 | self.anchor_indexes = self.entropy.topk(ANs_num, largest=False)[1]
115 | self.instance_indexes = (torch.ones_like(self.position)
116 | .scatter_(0, self.anchor_indexes, 0)
117 | .nonzero().view(-1))
118 | anchor_entropy = self.entropy.index_select(0, self.anchor_indexes)
119 | instance_entropy = self.entropy.index_select(0, self.instance_indexes)
120 | if self.anchor_indexes.size(0) > 0:
121 | logger.debug('Entropies of anchor samples: max(%.2f), '
122 | 'min(%.2f), mean(%.2f)' % (anchor_entropy.max(),
123 | anchor_entropy.min(), anchor_entropy.mean()))
124 | if self.instance_indexes.size(0) > 0:
125 | logger.debug('Entropies of instance samples: max(%.2f), '
126 | 'min(%.2f), mean(%.2f)' % (instance_entropy.max(),
127 | instance_entropy.min(), instance_entropy.mean()))
128 |
129 | # get position
130 | # if the anchor sample x whose index is i while position is j, then
131 | # sample x_i is the j-th anchor sample at current round
132 | # if the instance sample x whose index is i while position is j, then
133 | # sample x_i is the (-j-1)-th instance sample at current round
134 | logger.debug('Start to get the position of both anchor and '
135 | 'instance samples')
136 | instance_cnt = 0
137 | for i in range(self.samples_num):
138 | logger.progress(i, self.samples_num, 'processing %d/%d samples...')
139 |
140 | # for anchor samples
141 | if (i == self.anchor_indexes).any():
142 | self.position[i] = (self.anchor_indexes == i).max(0)[1]
143 | continue
144 | # for instance samples
145 | instance_cnt -= 1
146 | self.position[i] = instance_cnt
147 |
148 | logger.debug('Start to find %d neighbours for each anchor sample'
149 | % self.ANs_size)
150 | anchor_features = features.index_select(0, self.anchor_indexes)
151 | self.neighbours = (torch.LongTensor(ANs_num, self.ANs_size)
152 | .to(cfg.device))
153 | for start in range(0, ANs_num, batch_size):
154 | logger.progress(start, ANs_num, 'processing %d/%d samples...')
155 |
156 | end = start + batch_size
157 | end = min(end, ANs_num)
158 |
159 | sims = torch.mm(anchor_features[start:end], features.t())
160 | sims.scatter_(1, self.anchor_indexes[start:end].view(-1, 1), -1.)
161 | _, self.neighbours[start:end] = (
162 | sims.topk(self.ANs_size, largest=True, dim=1))
163 | logger.debug('ANs discovery done')
164 |
165 | # if cheat labels are provided, compute ANs consistency
166 | if cheat_labels is None:
167 | return 0.
168 | logger.debug('Start to compute ANs consistency')
169 | anchor_label = cheat_labels.index_select(0, self.anchor_indexes)
170 | neighbour_label = cheat_labels.index_select(0,
171 | self.neighbours.view(-1)).view_as(self.neighbours)
172 | self.consistency = ((anchor_label.view(-1, 1) == neighbour_label)
173 | .float().mean())
174 |
175 | return self.consistency
176 |
177 | REGISTER.set_package(__name__)
178 | REGISTER.set_class(__name__, 'ans', ANsDiscovery)
179 |
--------------------------------------------------------------------------------
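
The anchor selection inside update() above can be exercised on a toy memory bank; a minimal sketch with made-up features (bank size, temperature and ANs_num are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
feats = F.normalize(torch.randn(10, 8), dim=1)   # toy l2-normalised memory
preds = F.softmax(feats @ feats.t() / 0.1, dim=1)
entropy = -(preds * preds.log()).sum(1)          # per-sample entropy

ANs_num = 3
anchors = entropy.topk(ANs_num, largest=False)[1]   # lowest-entropy samples
sims = feats[anchors] @ feats.t()
sims.scatter_(1, anchors.view(-1, 1), -1.)          # never pick the anchor itself
neighbours = sims.topk(1, dim=1)[1]                 # 1 neighbour per anchor
print(anchors, neighbours.view(-1))
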
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Date : 2018-09-27 15:09:03
4 | # @Author : Jiabo (Raymond) Huang (jiabo.huang@qmul.ac.uk)
5 | # @Link : https://github.com/Raymond-sci
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 |
10 | import sys
11 | import os
12 | import time
13 | from datetime import datetime
14 |
15 | import models
16 | import datasets
17 |
18 | from lib import protocols
19 | from lib.non_parametric_classifier import NonParametricClassifier
20 | from lib.criterion import Criterion
21 | from lib.ans_discovery import ANsDiscovery
22 | from lib.utils import AverageMeter, time_progress, adjust_learning_rate
23 |
24 | from packages import session
25 | from packages import lr_policy
26 | from packages import optimizers
27 | from packages.config import CONFIG as cfg
28 | from packages.loggers.std_logger import STDLogger as logger
29 | from packages.loggers.tf_logger import TFLogger as SummaryWriter
30 |
31 |
32 | def require_args():
33 |
34 | # dataset to be used
35 | cfg.add_argument('--dataset', default='cifar10', type=str,
36 | help='dataset to be used. (default: cifar10)')
37 |
38 | # network to be used
39 | cfg.add_argument('--network', default='resnet18', type=str,
40 | help='backbone to be used. (default: ResNet18)')
41 |
42 | # optimizer to be used
43 | cfg.add_argument('--optimizer', default='sgd', type=str,
44 | help='optimizer to be used. (default: sgd)')
45 |
46 | # lr policy to be used
47 | cfg.add_argument('--lr-policy', default='step', type=str,
48 | help='lr policy to be used. (default: step)')
49 |
50 | # args for protocol
51 | cfg.add_argument('--protocol', default='knn', type=str,
52 | help='protocol used to validate model')
53 |
54 | # args for network training
55 | cfg.add_argument('--max-epoch', default=200, type=int,
56 | help='max epoch per round. (default: 200)')
57 | cfg.add_argument('--max-round', default=5, type=int,
58 | help='max round, including the initialisation one. '
59 | '(default: 5)')
60 | cfg.add_argument('--iter-size', default=1, type=int,
61 | help='caffe style iter size. (default: 1)')
62 | cfg.add_argument('--display-freq', default=1, type=int,
63 | help='display step')
64 | cfg.add_argument('--test-only', action='store_true',
65 | help='test only')
66 |
67 |
68 | def main():
69 |
70 | logger.info('Start to declare training variables')
71 | cfg.device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
72 | best_acc = 0. # best test accuracy
73 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
74 | start_round = 0 # start from round 0 or last checkpoint round
75 |
76 | logger.info('Start to prepare data')
77 | trainset, trainloader, testset, testloader = datasets.get(cfg.dataset, instant=True)
78 | # cheat labels are used to compute neighbourhood consistency only
79 | cheat_labels = torch.tensor(trainset.labels).long().to(device)
80 | ntrain, ntest = len(trainset), len(testset)
81 | logger.info('Totally got %d training and %d test samples' % (ntrain, ntest))
82 |
83 | logger.info('Start to build model')
84 | net = models.get(cfg.network, instant=True)
85 | npc = NonParametricClassifier(cfg.low_dim, ntrain, cfg.npc_temperature, cfg.npc_momentum)
86 | ANs_discovery = ANsDiscovery(ntrain)
87 | criterion = Criterion()
88 | optimizer = optimizers.get(cfg.optimizer, instant=True, params=net.parameters())
89 | lr_handler = lr_policy.get(cfg.lr_policy, instant=True)
90 | protocol = protocols.get(cfg.protocol)
91 |
92 | # data parallel
93 | if device == 'cuda':
94 | if (cfg.network.lower().startswith('alexnet') or
95 | cfg.network.lower().startswith('vgg')):
96 | net.features = torch.nn.DataParallel(net.features,
97 | device_ids=range(len(cfg.gpus.split(','))))
98 | else:
99 | net = torch.nn.DataParallel(net, device_ids=range(
100 | len(cfg.gpus.split(','))))
101 | cudnn.benchmark = True
102 |
103 | net, npc, ANs_discovery, criterion = (net.to(device), npc.to(device),
104 | ANs_discovery.to(device), criterion.to(device))
105 |
106 | # load ckpt file if necessary
107 | if cfg.resume:
108 | assert os.path.exists(cfg.resume), "Resume file not found: %s" % cfg.resume
109 | logger.info('Start to resume from %s' % cfg.resume)
110 | ckpt = torch.load(cfg.resume)
111 | net.load_state_dict(ckpt['net'])
112 | optimizer.load_state_dict(ckpt['optimizer'])
113 | npc.load_state_dict(ckpt['npc'])
114 | ANs_discovery.load_state_dict(ckpt['ANs_discovery'])
115 | best_acc = ckpt['acc']
116 | start_epoch = ckpt['epoch']
117 | start_round = ckpt['round']
118 |
119 | # test if necessary
120 | if cfg.test_only:
121 | logger.info('Testing at beginning...')
122 | acc = protocol(net, npc, trainloader, testloader, 200,
123 | cfg.npc_temperature, True, device)
124 | logger.info('Evaluation accuracy at %d round and %d epoch: %.2f%%' %
125 | (start_round, start_epoch, acc * 100))
126 | sys.exit(0)
127 |
128 | logger.info('Start the progressive training process from round: %d, '
129 | 'epoch: %d, best acc is %.4f...' % (start_round, start_epoch, best_acc))
130 | round = start_round
131 | global_writer = SummaryWriter(cfg.debug,
132 | log_dir=os.path.join(cfg.tfb_dir, 'global'))
133 | while round < cfg.max_round:
134 |
135 | # variables are initialized to different value in the first round
136 | is_first_round = (round == start_round)
137 | best_acc = best_acc if is_first_round else 0
138 |
139 | if not is_first_round:
140 | logger.info('Start to mine ANs at %d round' % round)
141 | ANs_discovery.update(round, npc, cheat_labels)
142 | logger.info('ANs consistency at %d round is %.2f%%' %
143 | (round, ANs_discovery.consistency * 100))
144 |
145 | ANs_num = ANs_discovery.anchor_indexes.shape[0]
146 | global_writer.add_scalar('ANs/Number', ANs_num, round)
147 | global_writer.add_scalar('ANs/Consistency', ANs_discovery.consistency, round)
148 |
149 | # declare local writer
150 | writer = SummaryWriter(cfg.debug, log_dir=os.path.join(cfg.tfb_dir,
151 | '%04d-%05d' % (round, ANs_num)))
152 | logger.info('Start training at %d/%d round' % (round, cfg.max_round))
153 |
154 |
155 | # start to train for an epoch
156 | epoch = start_epoch if is_first_round else 0
157 | lr = cfg.base_lr
158 | while lr > 0 and epoch < cfg.max_epoch:
159 |
160 | # get learning rate according to current epoch
161 | lr = lr_handler.update(epoch)
162 |
163 | train(round, epoch, net, trainloader, optimizer, npc, criterion,
164 | ANs_discovery, lr, writer)
165 |
166 | logger.info('Start to evaluate...')
167 | acc = protocol(net, npc, trainloader, testloader, 200,
168 | cfg.npc_temperature, False, device)
169 | writer.add_scalar('Evaluate/Rank-1', acc, epoch)
170 |
171 | logger.info('Evaluation accuracy at %d round and %d epoch: %.1f%%'
172 | % (round, epoch, acc * 100))
173 | logger.info('Best accuracy at %d round and %d epoch: %.1f%%'
174 | % (round, epoch, best_acc * 100))
175 |
176 | is_best = acc >= best_acc
177 | best_acc = max(acc, best_acc)
178 | if is_best and not cfg.debug:
179 | target = os.path.join(cfg.ckpt_dir, '%04d-%05d.ckpt'
180 | % (round, ANs_num))
181 | logger.info('Saving checkpoint to %s' % target)
182 | state = {
183 | 'net': net.state_dict(),
184 | 'optimizer': optimizer.state_dict(),
185 | 'ANs_discovery' : ANs_discovery.state_dict(),
186 | 'npc' : npc.state_dict(),
187 | 'acc': acc,
188 | 'epoch': epoch + 1,
189 | 'round' : round,
190 | 'session' : cfg.session
191 | }
192 | torch.save(state, target)
193 | epoch += 1
194 |
195 | # log best accuracy after each iteration
196 | global_writer.add_scalar('Evaluate/best_acc', best_acc, round)
197 | round += 1
198 |
199 | # Training
200 | def train(round, epoch, net, trainloader, optimizer, npc, criterion,
201 | ANs_discovery, lr, writer):
202 |
203 | # tracking variables
204 | train_loss = AverageMeter()
205 | data_time = AverageMeter()
206 | batch_time = AverageMeter()
207 |
208 | # switch the model to train mode
209 | net.train()
210 | # adjust learning rate
211 | adjust_learning_rate(optimizer, lr)
212 |
213 | end = time.time()
214 | start_time = datetime.now()
215 | optimizer.zero_grad()
216 | for batch_idx, (inputs, _, indexes) in enumerate(trainloader):
217 | data_time.update(time.time() - end)
218 | inputs, indexes = inputs.to(cfg.device), indexes.to(cfg.device)
219 |
220 | features = net(inputs)
221 | outputs = npc(features, indexes)
222 | loss = criterion(outputs, indexes, ANs_discovery) / cfg.iter_size
223 |
224 | loss.backward()
225 | train_loss.update(loss.item() * cfg.iter_size, inputs.size(0))
226 |
227 | if batch_idx % cfg.iter_size == 0:
228 | optimizer.step()
229 | optimizer.zero_grad()
230 |
231 | # measure elapsed time
232 | batch_time.update(time.time() - end)
233 | end = time.time()
234 |
235 | if batch_idx % cfg.display_freq != 0:
236 | continue
237 |
238 | writer.add_scalar('Train/Learning_Rate', lr,
239 | epoch * len(trainloader) + batch_idx)
240 | writer.add_scalar('Train/Loss', train_loss.val,
241 | epoch * len(trainloader) + batch_idx)
242 |
243 |
244 | elapsed_time, estimated_time = time_progress(batch_idx + 1,
245 | len(trainloader), batch_time.sum)
246 | logger.info('Round: {round} Epoch: {epoch}/{tot_epochs} '
247 | 'Progress: {elps_iters}/{tot_iters} ({elps_time}/{est_time}) '
248 | 'Data: {data_time.avg:.3f} LR: {learning_rate:.5f} '
249 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
250 | round=round, epoch=epoch, tot_epochs=cfg.max_epoch,
251 | elps_iters=batch_idx, tot_iters=len(trainloader),
252 | elps_time=elapsed_time, est_time=estimated_time,
253 | data_time=data_time, learning_rate=lr,
254 | train_loss=train_loss))
255 |
256 | if __name__ == '__main__':
257 |
258 | session.run(__name__)
259 |
--------------------------------------------------------------------------------
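
The train() loop above uses caffe-style gradient accumulation: each loss is divided by iter_size and gradients from several mini-batches are summed before one optimizer step. A standalone sketch of the same pattern (model, data and schedule are illustrative; like train(), it steps whenever batch_idx % iter_size == 0):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
iter_size = 4

optimizer.zero_grad()
for batch_idx in range(12):
    x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    loss = F.cross_entropy(model(x), y) / iter_size
    loss.backward()                      # gradients accumulate across batches
    if batch_idx % iter_size == 0:       # same schedule as train()
        optimizer.step()
        optimizer.zero_grad()
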