├── moco ├── __init__.py ├── NCE │ ├── __init__.py │ ├── NCECriterion.py │ └── Contrast.py ├── dataset.py ├── models │ ├── LinearModel.py │ └── resnet.py ├── lr_scheduler.py ├── logger.py └── util.py ├── scripts └── train_eval_imagenet100_baseLR0.4_alpha0.99_crop0.08_k1281166_t0.1_AMPO1.sh ├── .gitignore ├── README.md ├── train.py └── eval.py /moco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /moco/NCE/__init__.py: -------------------------------------------------------------------------------- 1 | from .Contrast import MemoryMoCo 2 | from .NCECriterion import NCESoftmaxLoss 3 | -------------------------------------------------------------------------------- /moco/NCE/NCECriterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class NCESoftmaxLoss(nn.Module): 6 | """Softmax cross-entropy loss (a.k.a., info-NCE loss in CPC paper)""" 7 | def __init__(self): 8 | super(NCESoftmaxLoss, self).__init__() 9 | self.criterion = nn.CrossEntropyLoss() 10 | 11 | def forward(self, x): 12 | label = torch.zeros([x.shape[0]]).long().to(x.device) 13 | return self.criterion(x, label) 14 | -------------------------------------------------------------------------------- /scripts/train_eval_imagenet100_baseLR0.4_alpha0.99_crop0.08_k1281166_t0.1_AMPO1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir="./data/imagenet100" 4 | output_dir="./output/imagenet100/baseLR0.4_alpha0.99_crop0.08_kall_t0.1_AMPO1" 5 | python -m torch.distributed.launch --master_port 12347 --nproc_per_node=8 \ 6 | train.py \ 7 | --data-dir ${data_dir} \ 8 | --dataset imagenet100 \ 9 | --base-learning-rate 0.4 \ 10 | --alpha 0.99 \ 11 | --crop 0.08 \ 12 | --nce-k 126689 \ 13 | --nce-t 0.1 \ 14 | --amp-opt-level O1 \ 15 | --output-dir ${output_dir} 16 | 17 | python -m torch.distributed.launch --master_port 12348 --nproc_per_node=4 \ 18 | eval.py \ 19 | --dataset imagenet100 \ 20 | --data-dir ${data_dir} \ 21 | --pretrained-model ${output_dir}/current.pth \ 22 | --output-dir ${output_dir}/eval 23 | 24 | -------------------------------------------------------------------------------- /moco/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torchvision.datasets as datasets 5 | 6 | 7 | class ImageFolderInstance(datasets.ImageFolder): 8 | """Folder dataset which returns the index of the image as well 9 | """ 10 | 11 | def __init__(self, root, transform=None, target_transform=None, two_crop=False): 12 | super(ImageFolderInstance, self).__init__(root, transform, target_transform) 13 | self.two_crop = two_crop 14 | 15 | def __getitem__(self, index): 16 | """ 17 | Args: 18 | index (int): Index 19 | Returns: 20 | tuple: (image, target, index) where target is class_index of the target class. 21 | """ 22 | path, target = self.imgs[index] 23 | image = self.loader(path) 24 | if self.transform is not None: 25 | img = self.transform(image) 26 | else: 27 | img = image 28 | if self.target_transform is not None: 29 | target = self.target_transform(target) 30 | 31 | if self.two_crop: 32 | img2 = self.transform(image) 33 | img = torch.cat([img, img2], dim=0) 34 | 35 | return img, target 36 | -------------------------------------------------------------------------------- /moco/NCE/Contrast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | 6 | class MemoryMoCo(nn.Module): 7 | """Fixed-size queue with momentum encoder""" 8 | def __init__(self, feature_dim, queue_size, temperature=0.07): 9 | super(MemoryMoCo, self).__init__() 10 | self.queue_size = queue_size 11 | self.temperature = temperature 12 | self.index = 0 13 | 14 | # noinspection PyCallingNonCallable 15 | self.register_buffer('params', torch.tensor([-1])) 16 | stdv = 1. / math.sqrt(feature_dim / 3) 17 | memory = torch.rand(self.queue_size, feature_dim, requires_grad=False).mul_(2 * stdv).add_(-stdv) 18 | self.register_buffer('memory', memory) 19 | 20 | def forward(self, q, k, k_all): 21 | k = k.detach() 22 | 23 | l_pos = (q * k).sum(dim=-1, keepdim=True) # shape: (batchSize, 1) 24 | # TODO: remove clone. need update memory in backwards 25 | l_neg = torch.mm(q, self.memory.clone().detach().t()) 26 | out = torch.cat((l_pos, l_neg), dim=1) 27 | out = torch.div(out, self.temperature).contiguous() 28 | 29 | # update memory 30 | with torch.no_grad(): 31 | all_size = k_all.shape[0] 32 | out_ids = torch.fmod(torch.arange(all_size, dtype=torch.long).cuda() + self.index, self.queue_size) 33 | self.memory.index_copy_(0, out_ids, k_all) 34 | self.index = (self.index + all_size) % self.queue_size 35 | 36 | return out 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | output*/ 4 | ckpts/ 5 | *.pth 6 | *.t7 7 | *.png 8 | *.jpg 9 | tmp*.py 10 | # run*.sh 11 | *.pdf 12 | 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | data 118 | .vscode 119 | -------------------------------------------------------------------------------- /moco/models/LinearModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Flatten(nn.Module): 7 | def __init__(self): 8 | super(Flatten, self).__init__() 9 | 10 | def forward(self, feat): 11 | return feat.view(feat.size(0), -1) 12 | 13 | 14 | class LinearClassifierResNet(nn.Module): 15 | def __init__(self, layer=6, n_label=1000, pool_type='avg', width=1): 16 | super(LinearClassifierResNet, self).__init__() 17 | if layer == 1: 18 | pool_size = 8 19 | n_channels = 128 * width 20 | pool = pool_type 21 | elif layer == 2: 22 | pool_size = 6 23 | n_channels = 256 * width 24 | pool = pool_type 25 | elif layer == 3: 26 | pool_size = 4 27 | n_channels = 512 * width 28 | pool = pool_type 29 | elif layer == 4: 30 | pool_size = 3 31 | n_channels = 1024 * width 32 | pool = pool_type 33 | elif layer == 5: 34 | pool_size = 7 35 | n_channels = 2048 * width 36 | pool = pool_type 37 | elif layer == 6: 38 | pool_size = 1 39 | n_channels = 2048 * width 40 | pool = pool_type 41 | else: 42 | raise NotImplementedError('layer not supported: {}'.format(layer)) 43 | 44 | self.classifier = nn.Sequential() 45 | if layer < 5: 46 | if pool == 'max': 47 | self.classifier.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) 48 | elif pool == 'avg': 49 | self.classifier.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) 50 | else: 51 | # self.classifier.add_module('AvgPool', nn.AvgPool2d(7, stride=1)) 52 | pass 53 | 54 | self.classifier.add_module('Flatten', Flatten()) 55 | self.classifier.add_module('LiniearClassifier', nn.Linear(n_channels * pool_size * pool_size, n_label)) 56 | self.initilize() 57 | 58 | def initilize(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Linear): 61 | m.weight.data.normal_(0, 0.01) 62 | m.bias.data.fill_(0.0) 63 | 64 | def forward(self, x): 65 | return self.classifier(x) 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Unofficial implementation for [MoCo: Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/abs/1911.05722) 2 | 3 | ## Highlight 4 | 5 | 1. **Effective**. Carefully implement important details such as ShuffleBN and distributed Queue mentioned in the paper to reproduce the reported results. 6 | 2. **Efficient**. The implementation is based on pytorch DistributedDataParallel and Apex automatic mixed precision. It only takes about 40 hours to train MoCo on imagenet dataset with 8 V100 gpus. The time cost is smaller than 3 days reported in original MoCo paper. 7 | 8 | 9 | ## Requirements 10 | 11 | The following enverionments is tested: 12 | 13 | * `Anaconda` with `python >= 3.6` 14 | * `pytorch>=1.3, torchvision, cuda=10.1/9.2` 15 | * others: `pip install termcolor opencv-python tensorboard` 16 | * [Optional] [`apex`](https://github.com/NVIDIA/apex#quick-start): automatic mixed precision training. 17 | 18 | ## Train and eval on imagenet 19 | 20 | * The pre-training stage: 21 | 22 | ```bash 23 | data_dir="./data/imagenet100" 24 | output_dir="./output/imagenet/K65536" 25 | python -m torch.distributed.launch --master_port 12347 --nproc_per_node=8 \ 26 | train.py \ 27 | --data-dir ${data_dir} \ 28 | --dataset imagenet \ 29 | --nce-k 65536 \ 30 | --output-dir ${output_dir} 31 | ``` 32 | 33 | The log, checkpoints and tensorboard events will be saved in `${output_dir}`. Set `--amp-opt-level` to `O1`, `O2`, or `O3` for mixed precision training. Run `python train.py --help` for more help. 34 | 35 | * The linear evaluation stage: 36 | 37 | ```bash 38 | python -m torch.distributed.launch --nproc_per_node=4 \ 39 | eval.py \ 40 | --dataset imagenet \ 41 | --data-dir ${data_dir} \ 42 | --pretrained-model ${output_dir}/current.pth \ 43 | --output-dir ${output_dir}/eval 44 | ``` 45 | 46 | The checkpoints and tensorboard log will be saved in `${output_dir}/eval`. Set `--amp-opt-level` to `O1`, `O2`, or `O3` for mixed precision training. Run `python eval.py --help` for more help. 47 | 48 | 49 | ## Pre-trained weights 50 | 51 | Pre-trained model checkpoint and tensorboard log for K = 16384 and 65536 on imagenet dataset can be downloaded from [OneDrive](https://1drv.ms/u/s!AsaPPmtCAq08pEsUojFnhhnGLG8F?e=zFwbGY). 52 | 53 | BTW, the hyperparameters is also stored in model checkpoint, you can get full configs in the checkpoints like this: 54 | ```python 55 | import torch 56 | ckpt = torch.load('model.pth') 57 | ckpt['opt'] 58 | ``` 59 | 60 | ## Performance comparison with original paper 61 | 62 | | K | Acc@1 (ours) | Acc@1 (MoCo paper) | 63 | | ----- | -------------------------------------------------------------------------- | ------------------ | 64 | | 16384 | 59.89 ([model](https://1drv.ms/u/s!AsaPPmtCAq08pFfk01K2l2T7Hv9P?e=uI1vGx)) | 60.4 | 65 | | 65536 | 60.79 ([model](https://1drv.ms/u/s!AsaPPmtCAq08pFa2xJRkILatNLh8?e=IMt2xg)) | 60.6 | 66 | 67 | ## Notes 68 | 69 | The MultiStepLR of pytorch1.4 is broken (See https://github.com/pytorch/pytorch/issues/33229 for more details). So if you are using pytorch1.4, you should not set `--lr-scheduler` to step. You can use `cosine` instead. 70 | 71 | ## Acknowledgements 72 | 73 | A lot of codes is borrowed from [CMC](https://github.com/HobbitLong/CMC) and [lemniscate](https://github.com/zhirongw/lemniscate.pytorch). 74 | 75 | -------------------------------------------------------------------------------- /moco/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # noinspection PyProtectedMember 2 | from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, CosineAnnealingLR 3 | 4 | 5 | # noinspection PyAttributeOutsideInit 6 | class GradualWarmupScheduler(_LRScheduler): 7 | """ Gradually warm-up(increasing) learning rate in optimizer. 8 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: init learning rate = base lr / multiplier 12 | warmup_epoch: target learning rate is reached at warmup_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1): 17 | self.multiplier = multiplier 18 | if self.multiplier <= 1.: 19 | raise ValueError('multiplier should be greater than 1.') 20 | self.warmup_epoch = warmup_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super().__init__(optimizer, last_epoch=last_epoch) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.warmup_epoch: 27 | return self.after_scheduler.get_lr() 28 | else: 29 | return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) 30 | for base_lr in self.base_lrs] 31 | 32 | def step(self, epoch=None): 33 | if epoch is None: 34 | epoch = self.last_epoch + 1 35 | self.last_epoch = epoch 36 | if epoch > self.warmup_epoch: 37 | self.after_scheduler.step(epoch - self.warmup_epoch) 38 | else: 39 | super(GradualWarmupScheduler, self).step(epoch) 40 | 41 | def state_dict(self): 42 | """Returns the state of the scheduler as a :class:`dict`. 43 | 44 | It contains an entry for every variable in self.__dict__ which 45 | is not the optimizer. 46 | """ 47 | 48 | state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'} 49 | state['after_scheduler'] = self.after_scheduler.state_dict() 50 | return state 51 | 52 | def load_state_dict(self, state_dict): 53 | """Loads the schedulers state. 54 | 55 | Arguments: 56 | state_dict (dict): scheduler state. Should be an object returned 57 | from a call to :meth:`state_dict`. 58 | """ 59 | 60 | after_scheduler_state = state_dict.pop('after_scheduler') 61 | self.__dict__.update(state_dict) 62 | self.after_scheduler.load_state_dict(after_scheduler_state) 63 | 64 | 65 | def get_scheduler(optimizer, n_iter_per_epoch, args): 66 | if "cosine" in args.lr_scheduler: 67 | scheduler = CosineAnnealingLR( 68 | optimizer=optimizer, 69 | eta_min=0.000001, 70 | T_max=(args.epochs - args.warmup_epoch) * n_iter_per_epoch) 71 | elif "step" in args.lr_scheduler: 72 | scheduler = MultiStepLR( 73 | optimizer=optimizer, 74 | gamma=args.lr_decay_rate, 75 | milestones=[(m - args.warmup_epoch) * n_iter_per_epoch for m in args.lr_decay_epochs]) 76 | else: 77 | raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported") 78 | 79 | scheduler = GradualWarmupScheduler( 80 | optimizer, 81 | multiplier=args.warmup_multiplier, 82 | after_scheduler=scheduler, 83 | warmup_epoch=args.warmup_epoch * n_iter_per_epoch) 84 | return scheduler 85 | -------------------------------------------------------------------------------- /moco/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | # so that calling setup_logger multiple times won't add many handlers 30 | @functools.lru_cache() 31 | def setup_logger( 32 | output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None 33 | ): 34 | """ 35 | Initialize the detectron2 logger and set its verbosity level to "INFO". 36 | 37 | Args: 38 | output (str): a file name or a directory to save log. If None, will not save log file. 39 | If ends with ".txt" or ".log", assumed to be a file name. 40 | Otherwise, logs will be saved to `output/log.txt`. 41 | name (str): the root module name of this logger 42 | 43 | Returns: 44 | logging.Logger: a logger 45 | """ 46 | logger = logging.getLogger(name) 47 | logger.setLevel(logging.DEBUG) 48 | logger.propagate = False 49 | 50 | if abbrev_name is None: 51 | abbrev_name = name 52 | 53 | plain_formatter = logging.Formatter( 54 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 55 | ) 56 | # stdout logging: master only 57 | if distributed_rank == 0: 58 | ch = logging.StreamHandler(stream=sys.stdout) 59 | ch.setLevel(logging.DEBUG) 60 | if color: 61 | formatter = _ColorfulFormatter( 62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S", 64 | root_name=name, 65 | abbrev_name=str(abbrev_name), 66 | ) 67 | else: 68 | formatter = plain_formatter 69 | ch.setFormatter(formatter) 70 | logger.addHandler(ch) 71 | 72 | # file logging: all workers 73 | if output is not None: 74 | if output.endswith(".txt") or output.endswith(".log"): 75 | filename = output 76 | else: 77 | filename = os.path.join(output, "log.txt") 78 | if distributed_rank > 0: 79 | filename = filename + f".rank{distributed_rank}" 80 | os.makedirs(os.path.dirname(filename), exist_ok=True) 81 | 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | return logger 88 | 89 | 90 | # cache the opened file object, so that different calls to `setup_logger` 91 | # with the same file name can safely write to the same file. 92 | @functools.lru_cache(maxsize=None) 93 | def _cached_log_stream(filename): 94 | return open(filename, "a") 95 | -------------------------------------------------------------------------------- /moco/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value""" 9 | 10 | def __init__(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | """Computes the accuracy over the k top predictions for the specified values of k""" 32 | with torch.no_grad(): 33 | maxk = max(topk) 34 | batch_size = target.size(0) 35 | 36 | _, pred = output.topk(maxk, 1, True, True) 37 | pred = pred.t() 38 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 39 | 40 | res = [] 41 | for k in topk: 42 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 43 | res.append(correct_k.mul_(1.0 / batch_size)) 44 | return res 45 | 46 | 47 | def dist_collect(x): 48 | """ collect all tensor from all GPUs 49 | args: 50 | x: shape (mini_batch, ...) 51 | returns: 52 | shape (mini_batch * num_gpu, ...) 53 | """ 54 | x = x.contiguous() 55 | out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype) 56 | for _ in range(dist.get_world_size())] 57 | dist.all_gather(out_list, x) 58 | return torch.cat(out_list, dim=0) 59 | 60 | 61 | def reduce_tensor(tensor): 62 | rt = tensor.clone() 63 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 64 | rt /= dist.get_world_size() 65 | return rt 66 | 67 | 68 | class DistributedShufle: 69 | @staticmethod 70 | def forward_shuffle(x, epoch): 71 | """ forward shuffle, return shuffled batch of x from all processes. 72 | epoch is used as manual seed to make sure the shuffle id in all process is same. 73 | """ 74 | x_all = dist_collect(x) 75 | forward_inds, backward_inds = DistributedShufle.get_shuffle_ids(x_all.shape[0], epoch) 76 | 77 | forward_inds_local = DistributedShufle.get_local_id(forward_inds) 78 | 79 | return x_all[forward_inds_local], backward_inds 80 | 81 | @staticmethod 82 | def backward_shuffle(x, backward_inds, return_local=True): 83 | """ backward shuffle, return data which have been shuffled back 84 | x is the shared data, should be local data 85 | if return_local, only return the local batch data of x. 86 | otherwise, return collected all data on all process. 87 | """ 88 | x_all = dist_collect(x) 89 | if return_local: 90 | backward_inds_local = DistributedShufle.get_local_id(backward_inds) 91 | return x_all[backward_inds], x_all[backward_inds_local] 92 | else: 93 | return x_all[backward_inds] 94 | 95 | @staticmethod 96 | def get_local_id(ids): 97 | return ids.chunk(dist.get_world_size())[dist.get_rank()] 98 | 99 | @staticmethod 100 | def get_shuffle_ids(bsz, epoch): 101 | """generate shuffle ids for ShuffleBN""" 102 | torch.manual_seed(epoch) 103 | # global forward shuffle id for all process 104 | forward_inds = torch.randperm(bsz).long().cuda() 105 | 106 | # global backward shuffle id 107 | backward_inds = torch.zeros(forward_inds.shape[0]).long().cuda() 108 | value = torch.arange(bsz).long().cuda() 109 | backward_inds.index_copy_(0, forward_inds, value) 110 | 111 | return forward_inds, backward_inds 112 | 113 | 114 | def set_bn_train(model): 115 | def set_bn_train_helper(m): 116 | classname = m.__class__.__name__ 117 | if classname.find('BatchNorm') != -1: 118 | m.train() 119 | 120 | model.eval() 121 | model.apply(set_bn_train_helper) 122 | 123 | 124 | def moment_update(model, model_ema, m): 125 | """ model_ema = m * model_ema + (1 - m) model """ 126 | for p1, p2 in zip(model.parameters(), model_ema.parameters()): 127 | p2.data.mul_(m).add_(1 - m, p1.detach().data) 128 | 129 | 130 | class MyHelpFormatter(argparse.MetavarTypeHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): 131 | pass 132 | -------------------------------------------------------------------------------- /moco/models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class Normalize(nn.Module): 25 | 26 | def __init__(self, power=2): 27 | super(Normalize, self).__init__() 28 | self.power = power 29 | 30 | def forward(self, x): 31 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 32 | out = x.div(norm) 33 | return out 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None): 72 | super(Bottleneck, self).__init__() 73 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 76 | padding=1, bias=False) 77 | self.bn2 = nn.BatchNorm2d(planes) 78 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(planes * 4) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = downsample 82 | self.stride = stride 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(nn.Module): 108 | 109 | def __init__(self, block, layers, low_dim=128, in_channel=3, width=1): 110 | self.inplanes = 64 111 | super(ResNet, self).__init__() 112 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 113 | bias=False) 114 | self.bn1 = nn.BatchNorm2d(64) 115 | self.relu = nn.ReLU(inplace=True) 116 | 117 | self.base = int(64 * width) 118 | 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | self.layer1 = self._make_layer(block, self.base, layers[0]) 121 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 124 | self.avgpool = nn.AvgPool2d(7, stride=1) 125 | self.fc = nn.Linear(self.base * 8 * block.expansion, low_dim) 126 | self.l2norm = Normalize(2) 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 131 | m.weight.data.normal_(0, math.sqrt(2. / n)) 132 | elif isinstance(m, nn.BatchNorm2d): 133 | m.weight.data.fill_(1) 134 | m.bias.data.zero_() 135 | 136 | def _make_layer(self, block, planes, blocks, stride=1): 137 | downsample = None 138 | if stride != 1 or self.inplanes != planes * block.expansion: 139 | downsample = nn.Sequential( 140 | nn.Conv2d(self.inplanes, planes * block.expansion, 141 | kernel_size=1, stride=stride, bias=False), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [block(self.inplanes, planes, stride, downsample)] 146 | self.inplanes = planes * block.expansion 147 | for i in range(1, blocks): 148 | layers.append(block(self.inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x, layer=7): 153 | if layer <= 0: 154 | return x 155 | x = self.conv1(x) 156 | x = self.bn1(x) 157 | x = self.relu(x) 158 | x = self.maxpool(x) 159 | if layer == 1: 160 | return x 161 | x = self.layer1(x) 162 | if layer == 2: 163 | return x 164 | x = self.layer2(x) 165 | if layer == 3: 166 | return x 167 | x = self.layer3(x) 168 | if layer == 4: 169 | return x 170 | x = self.layer4(x) 171 | if layer == 5: 172 | return x 173 | x = self.avgpool(x) 174 | x = x.view(x.size(0), -1) 175 | if layer == 6: 176 | return x 177 | x = self.fc(x) 178 | x = self.l2norm(x) 179 | 180 | return x 181 | 182 | 183 | def resnet18(pretrained=False, **kwargs): 184 | """Constructs a ResNet-18 model. 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 189 | if pretrained: 190 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 191 | return model 192 | 193 | 194 | def resnet34(pretrained=False, **kwargs): 195 | """Constructs a ResNet-34 model. 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 202 | return model 203 | 204 | 205 | def resnet50(pretrained=False, **kwargs): 206 | """Constructs a ResNet-50 model. 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 213 | return model 214 | 215 | 216 | def resnet101(pretrained=False, **kwargs): 217 | """Constructs a ResNet-101 model. 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | """ 221 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 222 | if pretrained: 223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 224 | return model 225 | 226 | 227 | def resnet152(pretrained=False, **kwargs): 228 | """Constructs a ResNet-152 model. 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | """ 232 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 233 | if pretrained: 234 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 235 | return model 236 | 237 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for MoCo pre-training 3 | 4 | MoCo: Momentum Contrast for Unsupervised Visual Representation Learning 5 | 6 | """ 7 | import argparse 8 | import os 9 | import time 10 | import json 11 | 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.nn.functional as F 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.nn.parallel import DistributedDataParallel 18 | from torchvision import transforms 19 | 20 | 21 | from moco.NCE import MemoryMoCo, NCESoftmaxLoss 22 | from moco.dataset import ImageFolderInstance 23 | from moco.logger import setup_logger 24 | from moco.models.resnet import resnet50 25 | from moco.util import AverageMeter, MyHelpFormatter, DistributedShufle, set_bn_train, moment_update 26 | from moco.lr_scheduler import get_scheduler 27 | 28 | try: 29 | # noinspection PyUnresolvedReferences 30 | from apex import amp 31 | except ImportError: 32 | amp = None 33 | 34 | 35 | def parse_option(): 36 | parser = argparse.ArgumentParser('moco training', formatter_class=MyHelpFormatter) 37 | 38 | # dataset 39 | parser.add_argument('--data-dir', type=str, required=True, help='root director of dataset') 40 | parser.add_argument('--dataset', type=str, default='imagenet', choices=['imagenet100', 'imagenet'], 41 | help='dataset to training') 42 | parser.add_argument('--crop', type=float, default=0.08, help='minimum crop') 43 | parser.add_argument('--aug', type=str, default='CJ', choices=['NULL', 'CJ'], 44 | help="augmentation type: NULL for normal supervised aug, CJ for aug with ColorJitter") 45 | parser.add_argument('--batch-size', type=int, default=128, help='batch_size') 46 | parser.add_argument('--num-workers', type=int, default=4, help='num of workers to use') 47 | 48 | # model and loss function 49 | parser.add_argument('--model', type=str, default='resnet50', choices=['resnet50'], help="backbone model") 50 | parser.add_argument('--model-width', type=int, default=1, help='width of resnet, eg, 1, 2, 4') 51 | parser.add_argument('--alpha', type=float, default=0.999, help='exponential moving average weight') 52 | parser.add_argument('--nce-k', type=int, default=65536, help='num negative sampler') 53 | parser.add_argument('--nce-t', type=float, default=0.07, help='NCE temperature') 54 | 55 | # optimization 56 | parser.add_argument('--base-learning-rate', '--base-lr', type=float, default=0.1, 57 | help='base learning when batch size = 256. final lr is determined by linear scale') 58 | parser.add_argument('--lr-scheduler', type=str, default='cosine', 59 | choices=["step", "cosine"], help="learning rate scheduler") 60 | parser.add_argument('--warmup-epoch', type=int, default=5, help='warmup epoch') 61 | parser.add_argument('--warmup-multiplier', type=int, default=100, help='warmup multiplier') 62 | parser.add_argument('--lr-decay-epochs', type=int, default=[120, 160, 200], nargs='+', 63 | help='for step scheduler. where to decay lr, can be a list') 64 | parser.add_argument('--lr-decay-rate', type=float, default=0.1, 65 | help='for step scheduler. decay rate for learning rate') 66 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='weight decay') 67 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD') 68 | parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], 69 | help='mixed precision opt level, if O0, no amp is used') 70 | parser.add_argument('--epochs', type=int, default=200, help='number of training epochs') 71 | parser.add_argument('--start-epoch', type=int, default=1, help='used for resume') 72 | 73 | # io 74 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 75 | help='path to latest checkpoint (default: none)') 76 | parser.add_argument('--print-freq', type=int, default=10, help='print frequency') 77 | parser.add_argument('--save-freq', type=int, default=10, help='save frequency') 78 | parser.add_argument('--output-dir', type=str, default='./output', help='output director') 79 | 80 | # misc 81 | parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel') 82 | parser.add_argument("--rng-seed", type=int, default=0, help='manual seed') 83 | 84 | args = parser.parse_args() 85 | 86 | torch.manual_seed(args.rng_seed) 87 | torch.cuda.manual_seed_all(args.rng_seed) 88 | 89 | return args 90 | 91 | 92 | def get_loader(args): 93 | # set the data loader 94 | train_folder = os.path.join(args.data_dir, 'train') 95 | 96 | image_size = 224 97 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 98 | 99 | if args.aug == 'NULL': 100 | train_transform = transforms.Compose([ 101 | transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)), 102 | transforms.RandomHorizontalFlip(), 103 | transforms.ToTensor(), 104 | normalize, 105 | ]) 106 | elif args.aug == 'CJ': 107 | train_transform = transforms.Compose([ 108 | transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)), 109 | transforms.RandomGrayscale(p=0.2), 110 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 111 | transforms.RandomHorizontalFlip(), 112 | transforms.ToTensor(), 113 | normalize, 114 | ]) 115 | else: 116 | raise NotImplementedError('augmentation not supported: {}'.format(args.aug)) 117 | 118 | train_dataset = ImageFolderInstance(train_folder, transform=train_transform, two_crop=True) 119 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 120 | train_loader = torch.utils.data.DataLoader( 121 | train_dataset, batch_size=args.batch_size, shuffle=False, 122 | num_workers=args.num_workers, pin_memory=True, 123 | sampler=train_sampler, drop_last=True) 124 | 125 | return train_loader 126 | 127 | 128 | def build_model(args): 129 | model = resnet50(width=args.model_width).cuda() 130 | model_ema = resnet50(width=args.model_width).cuda() 131 | 132 | # copy weights from `model' to `model_ema' 133 | moment_update(model, model_ema, 0) 134 | 135 | return model, model_ema 136 | 137 | 138 | def load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler): 139 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 140 | 141 | checkpoint = torch.load(args.resume, map_location='cpu') 142 | args.start_epoch = checkpoint['epoch'] + 1 143 | model.load_state_dict(checkpoint['model']) 144 | model_ema.load_state_dict(checkpoint['model_ema']) 145 | contrast.load_state_dict(checkpoint['contrast']) 146 | optimizer.load_state_dict(checkpoint['optimizer']) 147 | scheduler.load_state_dict(checkpoint['scheduler']) 148 | if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0": 149 | amp.load_state_dict(checkpoint['amp']) 150 | 151 | logger.info("=> loaded successfully '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 152 | 153 | del checkpoint 154 | torch.cuda.empty_cache() 155 | 156 | 157 | def save_checkpoint(args, epoch, model, model_ema, contrast, optimizer, scheduler): 158 | logger.info('==> Saving...') 159 | state = { 160 | 'opt': args, 161 | 'model': model.state_dict(), 162 | 'model_ema': model_ema.state_dict(), 163 | 'contrast': contrast.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 165 | 'scheduler': scheduler.state_dict(), 166 | 'epoch': epoch, 167 | } 168 | if args.amp_opt_level != "O0": 169 | state['amp'] = amp.state_dict() 170 | torch.save(state, os.path.join(args.output_dir, 'current.pth')) 171 | if epoch % args.save_freq == 0: 172 | torch.save(state, os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')) 173 | 174 | 175 | def main(args): 176 | train_loader = get_loader(args) 177 | n_data = len(train_loader.dataset) 178 | logger.info(f"length of training dataset: {n_data}") 179 | 180 | model, model_ema = build_model(args) 181 | contrast = MemoryMoCo(128, args.nce_k, args.nce_t).cuda() 182 | criterion = NCESoftmaxLoss().cuda() 183 | optimizer = torch.optim.SGD(model.parameters(), 184 | lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate, 185 | momentum=args.momentum, 186 | weight_decay=args.weight_decay) 187 | scheduler = get_scheduler(optimizer, len(train_loader), args) 188 | 189 | if args.amp_opt_level != "O0": 190 | if amp is None: 191 | logger.warning(f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n" 192 | "you should install apex from https://github.com/NVIDIA/apex#quick-start first") 193 | args.amp_opt_level = "O0" 194 | else: 195 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level) 196 | model_ema = amp.initialize(model_ema, opt_level=args.amp_opt_level) 197 | 198 | model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False) 199 | 200 | # optionally resume from a checkpoint 201 | if args.resume: 202 | assert os.path.isfile(args.resume) 203 | load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler) 204 | 205 | # tensorboard 206 | if dist.get_rank() == 0: 207 | summary_writer = SummaryWriter(log_dir=args.output_dir) 208 | else: 209 | summary_writer = None 210 | 211 | # routine 212 | for epoch in range(args.start_epoch, args.epochs + 1): 213 | train_loader.sampler.set_epoch(epoch) 214 | 215 | tic = time.time() 216 | loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args) 217 | 218 | logger.info('epoch {}, total time {:.2f}'.format(epoch, time.time() - tic)) 219 | 220 | if summary_writer is not None: 221 | # tensorboard logger 222 | summary_writer.add_scalar('ins_loss', loss, epoch) 223 | summary_writer.add_scalar('ins_prob', prob, epoch) 224 | summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) 225 | 226 | if dist.get_rank() == 0: 227 | # save model 228 | save_checkpoint(args, epoch, model, model_ema, contrast, optimizer, scheduler) 229 | 230 | 231 | def train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args): 232 | """ 233 | one epoch training for moco 234 | """ 235 | model.train() 236 | set_bn_train(model_ema) 237 | 238 | batch_time = AverageMeter() 239 | data_time = AverageMeter() 240 | loss_meter = AverageMeter() 241 | prob_meter = AverageMeter() 242 | 243 | end = time.time() 244 | for idx, (inputs, _,) in enumerate(train_loader): 245 | data_time.update(time.time() - end) 246 | 247 | bsz = inputs.size(0) 248 | 249 | # forward 250 | x1, x2 = torch.split(inputs, [3, 3], dim=1) 251 | x1.contiguous() 252 | x2.contiguous() 253 | x1 = x1.cuda(non_blocking=True) 254 | x2 = x2.cuda(non_blocking=True) 255 | 256 | feat_q = model(x1) 257 | with torch.no_grad(): 258 | x2_shuffled, backward_inds = DistributedShufle.forward_shuffle(x2, epoch) 259 | feat_k = model_ema(x2_shuffled) 260 | feat_k_all, feat_k = DistributedShufle.backward_shuffle(feat_k, backward_inds, return_local=True) 261 | 262 | out = contrast(feat_q, feat_k, feat_k_all) 263 | loss = criterion(out) 264 | prob = F.softmax(out, dim=1)[:, 0].mean() 265 | 266 | # backward 267 | optimizer.zero_grad() 268 | optimizer.zero_grad() 269 | if args.amp_opt_level != "O0": 270 | with amp.scale_loss(loss, optimizer) as scaled_loss: 271 | scaled_loss.backward() 272 | else: 273 | loss.backward() 274 | optimizer.step() 275 | scheduler.step() 276 | 277 | moment_update(model, model_ema, args.alpha) 278 | 279 | # update meters 280 | loss_meter.update(loss.item(), bsz) 281 | prob_meter.update(prob.item(), bsz) 282 | batch_time.update(time.time() - end) 283 | end = time.time() 284 | 285 | # print info 286 | if idx % args.print_freq == 0: 287 | logger.info(f'Train: [{epoch}][{idx}/{len(train_loader)}]\t' 288 | f'T {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 289 | f'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 290 | f'loss {loss_meter.val:.3f} ({loss_meter.avg:.3f})\t' 291 | f'prob {prob_meter.val:.3f} ({prob_meter.avg:.3f})') 292 | 293 | return loss_meter.avg, prob_meter.avg 294 | 295 | 296 | if __name__ == '__main__': 297 | opt = parse_option() 298 | 299 | torch.cuda.set_device(opt.local_rank) 300 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 301 | cudnn.benchmark = True 302 | 303 | os.makedirs(opt.output_dir, exist_ok=True) 304 | logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="moco") 305 | if dist.get_rank() == 0: 306 | path = os.path.join(opt.output_dir, "config.json") 307 | with open(path, 'w') as f: 308 | json.dump(vars(opt), f, indent=2) 309 | logger.info("Full config saved to {}".format(path)) 310 | 311 | main(opt) 312 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | evaluating MoCo: Momentum Contrast for Unsupervised Visual Representation Learning 3 | """ 4 | import argparse 5 | import json 6 | import os 7 | import time 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torchvision import datasets, transforms 17 | 18 | from moco.logger import setup_logger 19 | from moco.lr_scheduler import get_scheduler 20 | from moco.models.LinearModel import LinearClassifierResNet 21 | from moco.models.resnet import resnet50 22 | from moco.util import AverageMeter, MyHelpFormatter, accuracy, reduce_tensor 23 | 24 | try: 25 | # noinspection PyUnresolvedReferences 26 | from apex import amp 27 | except ImportError: 28 | amp = None 29 | 30 | 31 | def parse_option(): 32 | parser = argparse.ArgumentParser('moco eval', formatter_class=MyHelpFormatter) 33 | 34 | # dataset 35 | parser.add_argument('--data-dir', type=str, required=True, help='root director of dataset') 36 | parser.add_argument('--dataset', type=str, default='imagenet', 37 | choices=['imagenet', 'imagenet100'], help='dataset name') 38 | parser.add_argument('--crop', type=float, default=0.08, help='minimum crop') 39 | parser.add_argument('--aug', type=str, default='NULL', choices=['NULL', 'CJ'], 40 | help='augmentation type: NULL for normal supervised aug, CJ for aug with ColorJitter') 41 | parser.add_argument('--total-batch-size', type=int, default=256, help='total train batch size for all GPU') 42 | parser.add_argument('--num-workers', type=int, default=4, help='num of workers to use') 43 | 44 | # model definition 45 | parser.add_argument('--model', type=str, default='resnet50', choices=['resnet50'], help="backbone model") 46 | parser.add_argument('--model-width', type=int, default=1, help='width of resnet, eg, 1, 2, 4') 47 | parser.add_argument('--layer', type=int, default=6, help='which layer to evaluate') 48 | 49 | # optimization 50 | parser.add_argument('--learning-rate', type=float, default=30, help='learning rate') 51 | parser.add_argument('--lr-scheduler', type=str, default='cosine', 52 | choices=["step", "cosine"], help="learning rate scheduler") 53 | parser.add_argument('--warmup-epoch', type=int, default=5, help='warmup epoch') 54 | parser.add_argument('--warmup-multiplier', type=int, default=100, help='warmup multiplier') 55 | parser.add_argument('--lr-decay-epochs', type=int, default=[30, 60, 90], nargs='+', 56 | help='for step scheduler. where to decay lr, can be a list') 57 | parser.add_argument('--lr-decay-rate', type=float, default=0.1, 58 | help='for step scheduler. decay rate for learning rate') 59 | parser.add_argument('--weight-decay', type=float, default=0, help='weight decay') 60 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 61 | parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], 62 | help='mixed precision opt level, if O0, no amp is used') 63 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') 64 | parser.add_argument('--start-epoch', type=int, default=1, help='used for resume') 65 | 66 | # io 67 | parser.add_argument('--pretrained-model', type=str, required=True, help="pretrained model path") 68 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 69 | help='path to latest checkpoint (default: none)') 70 | parser.add_argument('--output-dir', type=str, default='./output', help='root director for output') 71 | parser.add_argument('--print-freq', type=int, default=10, help='print frequency') 72 | parser.add_argument('--save-freq', type=int, default=5, help='save frequency') 73 | 74 | # misc 75 | parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel') 76 | parser.add_argument("--rng-seed", type=int, default=0, help='manual seed') 77 | parser.add_argument('-e', '--eval', action='store_true', help='only evaluate') 78 | 79 | args = parser.parse_args() 80 | 81 | torch.manual_seed(args.rng_seed) 82 | torch.cuda.manual_seed_all(args.rng_seed) 83 | 84 | return args 85 | 86 | 87 | def get_loader(args): 88 | image_size = 224 89 | crop_padding = 32 90 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 91 | 92 | if args.aug == 'NULL': 93 | train_transform = transforms.Compose([ 94 | transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ToTensor(), 97 | normalize, 98 | ]) 99 | elif args.aug == 'CJ': 100 | train_transform = transforms.Compose([ 101 | transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)), 102 | transforms.RandomGrayscale(p=0.2), 103 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor(), 106 | normalize, 107 | ]) 108 | else: 109 | raise NotImplementedError('augmentation not supported: {}'.format(args.aug)) 110 | 111 | val_transform = transforms.Compose([ 112 | transforms.Resize(image_size + crop_padding), 113 | transforms.CenterCrop(image_size), 114 | transforms.ToTensor(), 115 | normalize, 116 | ]) 117 | 118 | # set the data loader 119 | train_dataset = datasets.ImageFolder(os.path.join(args.data_dir, 'train'), train_transform) 120 | val_dataset = datasets.ImageFolder(os.path.join(args.data_dir, 'val'), val_transform) 121 | batch_size = args.total_batch_size // dist.get_world_size() 122 | train_loader = DataLoader(train_dataset, 123 | batch_size=batch_size, 124 | num_workers=args.num_workers, 125 | sampler=DistributedSampler(train_dataset), 126 | shuffle=False, 127 | pin_memory=True, 128 | drop_last=True) 129 | val_loader = DataLoader(val_dataset, 130 | batch_size=batch_size, 131 | num_workers=args.num_workers, 132 | sampler=DistributedSampler(val_dataset, shuffle=False), 133 | shuffle=False, 134 | pin_memory=True, 135 | drop_last=False) 136 | 137 | return train_loader, val_loader 138 | 139 | 140 | def build_model(args, num_class): 141 | # create model 142 | model = resnet50(width=args.model_width).cuda() 143 | for p in model.parameters(): 144 | p.requires_grad = False 145 | classifier = LinearClassifierResNet(args.layer, num_class, 'avg', args.model_width).cuda() 146 | return model, classifier 147 | 148 | 149 | def load_pretrained(args, model): 150 | ckpt = torch.load(args.pretrained_model, map_location='cpu') 151 | state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()} 152 | model.load_state_dict(state_dict) 153 | logger.info(f"==> loaded checkpoint '{args.pretrained_model}' (epoch {ckpt['epoch']})") 154 | 155 | 156 | def load_checkpoint(args, classifier, optimizer, scheduler): 157 | logger.info("=> loading checkpoint '{args.resume'") 158 | 159 | checkpoint = torch.load(args.resume, map_location='cpu') 160 | 161 | global best_acc1 162 | best_acc1 = checkpoint['best_acc1'] 163 | args.start_epoch = checkpoint['epoch'] + 1 164 | classifier.load_state_dict(checkpoint['classifier']) 165 | optimizer.load_state_dict(checkpoint['optimizer']) 166 | scheduler.load_state_dict(checkpoint['scheduler']) 167 | if args.amp_opt_level != "O0" and checkpoint['opt'].amp_opt_level != "O0": 168 | amp.load_state_dict(checkpoint['amp']) 169 | 170 | logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 171 | 172 | 173 | def save_checkpoint(args, epoch, classifier, test_acc, optimizer, scheduler): 174 | state = { 175 | 'opt': args, 176 | 'epoch': epoch, 177 | 'classifier': classifier.state_dict(), 178 | 'best_acc1': test_acc, 179 | 'optimizer': optimizer.state_dict(), 180 | 'scheduler': scheduler.state_dict(), 181 | } 182 | if args.amp_opt_level != "O0": 183 | state['amp'] = amp.state_dict() 184 | torch.save(state, os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')) 185 | torch.save(state, os.path.join(args.output_dir, f'current.pth')) 186 | 187 | 188 | def main(args): 189 | global best_acc1 190 | 191 | train_loader, val_loader = get_loader(args) 192 | logger.info(f"length of training dataset: {len(train_loader.dataset)}") 193 | 194 | model, classifier = build_model(args, num_class=len(train_loader.dataset.classes)) 195 | criterion = torch.nn.CrossEntropyLoss().cuda() 196 | optimizer = torch.optim.SGD(classifier.parameters(), 197 | lr=args.learning_rate, 198 | momentum=args.momentum, 199 | weight_decay=args.weight_decay) 200 | scheduler = get_scheduler(optimizer, len(train_loader), args) 201 | 202 | if args.amp_opt_level != "O0": 203 | if amp is None: 204 | logger.warning(f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n" 205 | "you should install apex from https://github.com/NVIDIA/apex#quick-start first") 206 | args.amp_opt_level = "O0" 207 | else: 208 | model = amp.initialize(model, opt_level=args.amp_opt_level) 209 | classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=args.amp_opt_level) 210 | 211 | classifier = DistributedDataParallel(classifier, device_ids=[args.local_rank], broadcast_buffers=False) 212 | 213 | model.eval() 214 | 215 | load_pretrained(args, model) 216 | # optionally resume from a checkpoint 217 | if args.resume: 218 | assert os.path.isfile(args.resume), f"no checkpoint found at '{args.resume}'" 219 | load_checkpoint(args, classifier, optimizer, scheduler) 220 | 221 | if args.eval: 222 | logger.info("==> testing...") 223 | validate(val_loader, model, classifier, criterion, args) 224 | return 225 | 226 | # tensorboard 227 | if dist.get_rank() == 0: 228 | summary_writer = SummaryWriter(log_dir=args.output_dir) 229 | else: 230 | summary_writer = None 231 | 232 | # routine 233 | for epoch in range(args.start_epoch, args.epochs + 1): 234 | if isinstance(train_loader.sampler, DistributedSampler): 235 | train_loader.sampler.set_epoch(epoch) 236 | 237 | tic = time.time() 238 | train(epoch, train_loader, model, classifier, criterion, optimizer, scheduler, args) 239 | logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}') 240 | 241 | logger.info("==> testing...") 242 | test_acc = validate(val_loader, model, classifier, criterion, args) 243 | 244 | if dist.get_rank() == 0 and epoch % args.save_freq == 0: 245 | logger.info('==> Saving...') 246 | save_checkpoint(args, epoch, classifier, test_acc, optimizer, scheduler) 247 | 248 | if summary_writer is not None: 249 | # tensorboard logger 250 | summary_writer.add_scalar('ins_loss', test_acc, epoch) 251 | summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) 252 | 253 | 254 | def train(epoch, train_loader, model, classifier, criterion, optimizer, scheduler, args): 255 | """ 256 | one epoch training 257 | """ 258 | 259 | model.eval() 260 | classifier.train() 261 | 262 | batch_time = AverageMeter() 263 | data_time = AverageMeter() 264 | losses = AverageMeter() 265 | top1 = AverageMeter() 266 | top5 = AverageMeter() 267 | 268 | end = time.time() 269 | for idx, (x, y) in enumerate(train_loader): 270 | x = x.cuda(non_blocking=True) 271 | y = y.cuda(non_blocking=True) 272 | 273 | # measure data loading time 274 | data_time.update(time.time() - end) 275 | 276 | # ===================forward===================== 277 | with torch.no_grad(): 278 | feat = model(x, args.layer) 279 | 280 | output = classifier(feat) 281 | loss = criterion(output, y) 282 | 283 | acc1, acc5 = accuracy(output, y, topk=(1, 5)) 284 | losses.update(loss.item(), x.size(0)) 285 | top1.update(acc1[0], x.size(0)) 286 | top5.update(acc5[0], x.size(0)) 287 | 288 | # ===================backward===================== 289 | optimizer.zero_grad() 290 | if args.amp_opt_level != "O0": 291 | with amp.scale_loss(loss, optimizer) as scaled_loss: 292 | scaled_loss.backward() 293 | else: 294 | loss.backward() 295 | optimizer.step() 296 | scheduler.step() 297 | 298 | # ===================meters===================== 299 | batch_time.update(time.time() - end) 300 | end = time.time() 301 | 302 | # print info 303 | if idx % args.print_freq == 0: 304 | logger.info( 305 | f'Epoch: [{epoch}][{idx}/{len(train_loader)}]\t' 306 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 307 | f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 308 | f'Lr {optimizer.param_groups[0]["lr"]:.3f} \t' 309 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t' 310 | f'Acc@1 {top1.val:.3%} ({top1.avg:.3%})\t' 311 | f'Acc@5 {top5.val:.3%} ({top5.avg:.3%})') 312 | 313 | 314 | def validate(val_loader, model, classifier, criterion, args): 315 | batch_time = AverageMeter() 316 | losses = AverageMeter() 317 | top1 = AverageMeter() 318 | top5 = AverageMeter() 319 | 320 | # switch to evaluate mode 321 | model.eval() 322 | classifier.eval() 323 | 324 | with torch.no_grad(): 325 | end = time.time() 326 | for idx, (x, y) in enumerate(val_loader): 327 | x = x.cuda(non_blocking=True) 328 | y = y.cuda(non_blocking=True) 329 | 330 | # compute output 331 | feat = model(x, args.layer) 332 | output = classifier(feat) 333 | loss = criterion(output, y) 334 | 335 | # measure accuracy and record loss 336 | acc1, acc5 = accuracy(output, y, topk=(1, 5)) 337 | 338 | acc1 = reduce_tensor(acc1) 339 | acc5 = reduce_tensor(acc5) 340 | loss = reduce_tensor(loss) 341 | 342 | losses.update(loss.item(), x.size(0)) 343 | top1.update(acc1[0], x.size(0)) 344 | top5.update(acc5[0], x.size(0)) 345 | 346 | # measure elapsed time 347 | batch_time.update(time.time() - end) 348 | end = time.time() 349 | 350 | if idx % args.print_freq == 0: 351 | logger.info( 352 | f'Test: [{idx}/{len(val_loader)}]\t' 353 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 354 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t' 355 | f'Acc@1 {top1.val:.3%} ({top1.avg:.3%})\t' 356 | f'Acc@5 {top5.val:.3%} ({top5.avg:.3%})') 357 | 358 | logger.info(f' * Acc@1 {top1.avg:.3%} Acc@5 {top5.avg:.3%}') 359 | 360 | return top1.avg 361 | 362 | 363 | if __name__ == '__main__': 364 | opt = parse_option() 365 | 366 | torch.cuda.set_device(opt.local_rank) 367 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 368 | cudnn.benchmark = True 369 | best_acc1 = 0 370 | 371 | os.makedirs(opt.output_dir, exist_ok=True) 372 | logger = setup_logger(output=opt.output_dir, distributed_rank=dist.get_rank(), name="moco") 373 | if dist.get_rank() == 0: 374 | path = os.path.join(opt.output_dir, "config.json") 375 | with open(path, "w") as f: 376 | json.dump(vars(opt), f, indent=2) 377 | logger.info("Full config saved to {}".format(path)) 378 | 379 | main(opt) 380 | --------------------------------------------------------------------------------