├── .gitignore ├── LICENSE ├── README.md ├── figures ├── .DS_Store ├── SimCLR.jpg ├── SupCE.jpg ├── SupContrast.jpg └── teaser.png ├── losses.py ├── main_ce.py ├── main_linear.py ├── main_supcon.py ├── networks └── resnet_big.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | tmp*.py 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | save/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Yonglong Tian 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SupContrast: Supervised Contrastive Learning 2 |

3 | 4 |

5 | 6 | This repo covers a reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example: 7 | (1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362) 8 | (2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709) 9 | 10 | ## Update 11 | 12 | ${\color{red}Note}$: if you found it not easy to parse the supcon loss implementation in this repo, we got you. Supcon loss essentially is just a cross-entropy loss (see eq 4 in the [StableRep](https://arxiv.org/pdf/2306.00984.pdf) paper). So we got a cleaner and simpler implementation [here](https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py#L49). Hope it helps. 13 | 14 | ImageNet model (small batch size with the trick of the momentum encoder) is released [here](https://www.dropbox.com/s/l4a69ececk4spdt/supcon.pth?dl=0). It achieved > 79% top-1 accuracy. 15 | 16 | ## Loss Function 17 | The loss function [`SupConLoss`](https://github.com/HobbitLong/SupContrast/blob/master/losses.py#L11) in `losses.py` takes `features` (L2 normalized) and `labels` as input, and returns the loss. If `labels` is `None` or not passed to it, it degenerates to SimCLR. 18 | 19 | Usage: 20 | ```python 21 | from losses import SupConLoss 22 | 23 | # define loss with a temperature `temp` 24 | criterion = SupConLoss(temperature=temp) 25 | 26 | # features: [bsz, n_views, f_dim] 27 | # `n_views` is the number of crops from each image 28 | # better be L2 normalized in f_dim dimension 29 | features = ... 30 | # labels: [bsz] 31 | labels = ... 32 | 33 | # SupContrast 34 | loss = criterion(features, labels) 35 | # or SimCLR 36 | loss = criterion(features) 37 | ... 
38 | ``` 39 | 40 | ## Comparison 41 | Results on CIFAR-10: 42 | | |Arch | Setting | Loss | Accuracy(%) | 43 | |----------|:----:|:---:|:---:|:---:| 44 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 | 45 | | SupContrast | ResNet50 | Supervised | Contrastive | 96.0 | 46 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 | 47 | 48 | Results on CIFAR-100: 49 | | |Arch | Setting | Loss | Accuracy(%) | 50 | |----------|:----:|:---:|:---:|:---:| 51 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 | 52 | | SupContrast | ResNet50 | Supervised | Contrastive | 76.5 | 53 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 | 54 | 55 | Results on ImageNet (Stay tuned): 56 | | |Arch | Setting | Loss | Accuracy(%) | 57 | |----------|:----:|:---:|:---:|:---:| 58 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - | 59 | | SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) | 60 | | SimCLR | ResNet50 | Unsupervised | Contrastive | - | 61 | 62 | ## Running 63 | You might use `CUDA_VISIBLE_DEVICES` to set proper number of GPUs, and/or switch to CIFAR100 by `--dataset cifar100`. 64 | **(1) Standard Cross-Entropy** 65 | ``` 66 | python main_ce.py --batch_size 1024 \ 67 | --learning_rate 0.8 \ 68 | --cosine --syncBN \ 69 | ``` 70 | **(2) Supervised Contrastive Learning** 71 | Pretraining stage: 72 | ``` 73 | python main_supcon.py --batch_size 1024 \ 74 | --learning_rate 0.5 \ 75 | --temp 0.1 \ 76 | --cosine 77 | ``` 78 | 79 | You can also specify `--syncBN` but I found it not crucial for SupContrast (`syncBN` 95.9% v.s. `BN` 96.0%). 
80 | 81 | WARN: Currently, `--syncBN` has no effect since the code is using `DataParallel` instead of `DistributedDataParallel` 82 | 83 | Linear evaluation stage: 84 | ``` 85 | python main_linear.py --batch_size 512 \ 86 | --learning_rate 5 \ 87 | --ckpt /path/to/model.pth 88 | ``` 89 | **(3) SimCLR** 90 | Pretraining stage: 91 | ``` 92 | python main_supcon.py --batch_size 1024 \ 93 | --learning_rate 0.5 \ 94 | --temp 0.5 \ 95 | --cosine --syncBN \ 96 | --method SimCLR 97 | ``` 98 | The `--method SimCLR` flag simply stops `labels` from being passed to `SupConLoss` criterion. 99 | Linear evaluation stage: 100 | ``` 101 | python main_linear.py --batch_size 512 \ 102 | --learning_rate 1 \ 103 | --ckpt /path/to/model.pth 104 | ``` 105 | 106 | On custom dataset: 107 | ``` 108 | python main_supcon.py --batch_size 1024 \ 109 | --learning_rate 0.5 \ 110 | --temp 0.1 --cosine \ 111 | --dataset path \ 112 | --data_folder ./path \ 113 | --mean "(0.4914, 0.4822, 0.4465)" \ 114 | --std "(0.2675, 0.2565, 0.2761)" \ 115 | --method SimCLR 116 | ``` 117 | 118 | The `--data_folder` must be of form ./path/label/xxx.png following https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convention. 119 | 120 | 121 | ## t-SNE Visualization 122 | 123 | **(1) Standard Cross-Entropy** 124 | 

125 | 126 |

127 | 128 | **(2) Supervised Contrastive Learning** 129 |

130 | 131 |

132 | 133 | **(3) SimCLR** 134 |

135 | 136 |

137 | 138 | ## Reference 139 | ``` 140 | @Article{khosla2020supervised, 141 | title = {Supervised Contrastive Learning}, 142 | author = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan}, 143 | journal = {arXiv preprint arXiv:2004.11362}, 144 | year = {2020}, 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/.DS_Store -------------------------------------------------------------------------------- /figures/SimCLR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SimCLR.jpg -------------------------------------------------------------------------------- /figures/SupCE.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SupCE.jpg -------------------------------------------------------------------------------- /figures/SupContrast.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SupContrast.jpg -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/teaser.png -------------------------------------------------------------------------------- /losses.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all', 15 | base_temperature=0.07): 16 | super(SupConLoss, self).__init__() 17 | self.temperature = temperature 18 | self.contrast_mode = contrast_mode 19 | self.base_temperature = base_temperature 20 | 21 | def forward(self, features, labels=None, mask=None): 22 | """Compute loss for model. If both `labels` and `mask` are None, 23 | it degenerates to SimCLR unsupervised loss: 24 | https://arxiv.org/pdf/2002.05709.pdf 25 | 26 | Args: 27 | features: hidden vector of shape [bsz, n_views, ...]. 28 | labels: ground truth of shape [bsz]. 29 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 30 | has the same class as sample i. Can be asymmetric. 31 | Returns: 32 | A loss scalar. 
33 | """ 34 | device = (torch.device('cuda') 35 | if features.is_cuda 36 | else torch.device('cpu')) 37 | 38 | if len(features.shape) < 3: 39 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 40 | 'at least 3 dimensions are required') 41 | if len(features.shape) > 3: 42 | features = features.view(features.shape[0], features.shape[1], -1) 43 | 44 | batch_size = features.shape[0] 45 | if labels is not None and mask is not None: 46 | raise ValueError('Cannot define both `labels` and `mask`') 47 | elif labels is None and mask is None: 48 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 49 | elif labels is not None: 50 | labels = labels.contiguous().view(-1, 1) 51 | if labels.shape[0] != batch_size: 52 | raise ValueError('Num of labels does not match num of features') 53 | mask = torch.eq(labels, labels.T).float().to(device) 54 | else: 55 | mask = mask.float().to(device) 56 | 57 | contrast_count = features.shape[1] 58 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 59 | if self.contrast_mode == 'one': 60 | anchor_feature = features[:, 0] 61 | anchor_count = 1 62 | elif self.contrast_mode == 'all': 63 | anchor_feature = contrast_feature 64 | anchor_count = contrast_count 65 | else: 66 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 67 | 68 | # compute logits 69 | anchor_dot_contrast = torch.div( 70 | torch.matmul(anchor_feature, contrast_feature.T), 71 | self.temperature) 72 | # for numerical stability 73 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 74 | logits = anchor_dot_contrast - logits_max.detach() 75 | 76 | # tile mask 77 | mask = mask.repeat(anchor_count, contrast_count) 78 | # mask-out self-contrast cases 79 | logits_mask = torch.scatter( 80 | torch.ones_like(mask), 81 | 1, 82 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 83 | 0 84 | ) 85 | mask = mask * logits_mask 86 | 87 | # compute log_prob 88 | exp_logits = torch.exp(logits) * logits_mask 89 | 
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 90 | 91 | # compute mean of log-likelihood over positive 92 | # modified to handle edge cases when there is no positive pair 93 | # for an anchor point. 94 | # Edge case e.g.:- 95 | # features of shape: [4,1,...] 96 | # labels: [0,1,1,2] 97 | # loss before mean: [nan, ..., ..., nan] 98 | mask_pos_pairs = mask.sum(1) 99 | mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) 100 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs 101 | 102 | # loss 103 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 104 | loss = loss.view(anchor_count, batch_size).mean() 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /main_ce.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | 9 | import tensorboard_logger as tb_logger 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torchvision import transforms, datasets 13 | 14 | from util import AverageMeter 15 | from util import adjust_learning_rate, warmup_learning_rate, accuracy 16 | from util import set_optimizer, save_model 17 | from networks.resnet_big import SupCEResNet 18 | 19 | try: 20 | import apex 21 | from apex import amp, optimizers 22 | except ImportError: 23 | pass 24 | 25 | 26 | def parse_option(): 27 | parser = argparse.ArgumentParser('argument for training') 28 | 29 | parser.add_argument('--print_freq', type=int, default=10, 30 | help='print frequency') 31 | parser.add_argument('--save_freq', type=int, default=50, 32 | help='save frequency') 33 | parser.add_argument('--batch_size', type=int, default=256, 34 | help='batch_size') 35 | parser.add_argument('--num_workers', type=int, default=16, 36 | help='num of workers to use') 37 | parser.add_argument('--epochs', 
type=int, default=500, 38 | help='number of training epochs') 39 | 40 | # optimization 41 | parser.add_argument('--learning_rate', type=float, default=0.2, 42 | help='learning rate') 43 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450', 44 | help='where to decay lr, can be a list') 45 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 46 | help='decay rate for learning rate') 47 | parser.add_argument('--weight_decay', type=float, default=1e-4, 48 | help='weight decay') 49 | parser.add_argument('--momentum', type=float, default=0.9, 50 | help='momentum') 51 | 52 | # model dataset 53 | parser.add_argument('--model', type=str, default='resnet50') 54 | parser.add_argument('--dataset', type=str, default='cifar10', 55 | choices=['cifar10', 'cifar100'], help='dataset') 56 | 57 | # other setting 58 | parser.add_argument('--cosine', action='store_true', 59 | help='using cosine annealing') 60 | parser.add_argument('--syncBN', action='store_true', 61 | help='using synchronized batch normalization') 62 | parser.add_argument('--warm', action='store_true', 63 | help='warm-up for large batch training') 64 | parser.add_argument('--trial', type=str, default='0', 65 | help='id for recording multiple runs') 66 | 67 | opt = parser.parse_args() 68 | 69 | # set the path according to the environment 70 | opt.data_folder = './datasets/' 71 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) 72 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) 73 | 74 | iterations = opt.lr_decay_epochs.split(',') 75 | opt.lr_decay_epochs = list([]) 76 | for it in iterations: 77 | opt.lr_decay_epochs.append(int(it)) 78 | 79 | opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\ 80 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, 81 | opt.batch_size, opt.trial) 82 | 83 | if opt.cosine: 84 | opt.model_name = '{}_cosine'.format(opt.model_name) 85 | 86 | # warm-up for large-batch training, 87 | if 
opt.batch_size > 256: 88 | opt.warm = True 89 | if opt.warm: 90 | opt.model_name = '{}_warm'.format(opt.model_name) 91 | opt.warmup_from = 0.01 92 | opt.warm_epochs = 10 93 | if opt.cosine: 94 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 95 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 96 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 97 | else: 98 | opt.warmup_to = opt.learning_rate 99 | 100 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 101 | if not os.path.isdir(opt.tb_folder): 102 | os.makedirs(opt.tb_folder) 103 | 104 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 105 | if not os.path.isdir(opt.save_folder): 106 | os.makedirs(opt.save_folder) 107 | 108 | if opt.dataset == 'cifar10': 109 | opt.n_cls = 10 110 | elif opt.dataset == 'cifar100': 111 | opt.n_cls = 100 112 | else: 113 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 114 | 115 | return opt 116 | 117 | 118 | def set_loader(opt): 119 | # construct data loader 120 | if opt.dataset == 'cifar10': 121 | mean = (0.4914, 0.4822, 0.4465) 122 | std = (0.2023, 0.1994, 0.2010) 123 | elif opt.dataset == 'cifar100': 124 | mean = (0.5071, 0.4867, 0.4408) 125 | std = (0.2675, 0.2565, 0.2761) 126 | else: 127 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 128 | normalize = transforms.Normalize(mean=mean, std=std) 129 | 130 | train_transform = transforms.Compose([ 131 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | normalize, 135 | ]) 136 | 137 | val_transform = transforms.Compose([ 138 | transforms.ToTensor(), 139 | normalize, 140 | ]) 141 | 142 | if opt.dataset == 'cifar10': 143 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 144 | transform=train_transform, 145 | download=True) 146 | val_dataset = datasets.CIFAR10(root=opt.data_folder, 147 | train=False, 148 | transform=val_transform) 149 | elif opt.dataset == 
'cifar100': 150 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 151 | transform=train_transform, 152 | download=True) 153 | val_dataset = datasets.CIFAR100(root=opt.data_folder, 154 | train=False, 155 | transform=val_transform) 156 | else: 157 | raise ValueError(opt.dataset) 158 | 159 | train_sampler = None 160 | train_loader = torch.utils.data.DataLoader( 161 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), 162 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) 163 | val_loader = torch.utils.data.DataLoader( 164 | val_dataset, batch_size=256, shuffle=False, 165 | num_workers=8, pin_memory=True) 166 | 167 | return train_loader, val_loader 168 | 169 | 170 | def set_model(opt): 171 | model = SupCEResNet(name=opt.model, num_classes=opt.n_cls) 172 | criterion = torch.nn.CrossEntropyLoss() 173 | 174 | # enable synchronized Batch Normalization 175 | if opt.syncBN: 176 | model = apex.parallel.convert_syncbn_model(model) 177 | 178 | if torch.cuda.is_available(): 179 | if torch.cuda.device_count() > 1: 180 | model = torch.nn.DataParallel(model) 181 | model = model.cuda() 182 | criterion = criterion.cuda() 183 | cudnn.benchmark = True 184 | 185 | return model, criterion 186 | 187 | 188 | def train(train_loader, model, criterion, optimizer, epoch, opt): 189 | """one epoch training""" 190 | model.train() 191 | 192 | batch_time = AverageMeter() 193 | data_time = AverageMeter() 194 | losses = AverageMeter() 195 | top1 = AverageMeter() 196 | 197 | end = time.time() 198 | for idx, (images, labels) in enumerate(train_loader): 199 | data_time.update(time.time() - end) 200 | 201 | images = images.cuda(non_blocking=True) 202 | labels = labels.cuda(non_blocking=True) 203 | bsz = labels.shape[0] 204 | 205 | # warm-up learning rate 206 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 207 | 208 | # compute loss 209 | output = model(images) 210 | loss = criterion(output, labels) 211 | 212 | # update metric 213 | 
losses.update(loss.item(), bsz) 214 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 215 | top1.update(acc1[0], bsz) 216 | 217 | # SGD 218 | optimizer.zero_grad() 219 | loss.backward() 220 | optimizer.step() 221 | 222 | # measure elapsed time 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | # print info 227 | if (idx + 1) % opt.print_freq == 0: 228 | print('Train: [{0}][{1}/{2}]\t' 229 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 230 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 231 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 232 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 233 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 234 | data_time=data_time, loss=losses, top1=top1)) 235 | sys.stdout.flush() 236 | 237 | return losses.avg, top1.avg 238 | 239 | 240 | def validate(val_loader, model, criterion, opt): 241 | """validation""" 242 | model.eval() 243 | 244 | batch_time = AverageMeter() 245 | losses = AverageMeter() 246 | top1 = AverageMeter() 247 | 248 | with torch.no_grad(): 249 | end = time.time() 250 | for idx, (images, labels) in enumerate(val_loader): 251 | images = images.float().cuda() 252 | labels = labels.cuda() 253 | bsz = labels.shape[0] 254 | 255 | # forward 256 | output = model(images) 257 | loss = criterion(output, labels) 258 | 259 | # update metric 260 | losses.update(loss.item(), bsz) 261 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 262 | top1.update(acc1[0], bsz) 263 | 264 | # measure elapsed time 265 | batch_time.update(time.time() - end) 266 | end = time.time() 267 | 268 | if idx % opt.print_freq == 0: 269 | print('Test: [{0}/{1}]\t' 270 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 271 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 272 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 273 | idx, len(val_loader), batch_time=batch_time, 274 | loss=losses, top1=top1)) 275 | 276 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 277 | return losses.avg, top1.avg 278 | 279 
| 280 | def main(): 281 | best_acc = 0 282 | opt = parse_option() 283 | 284 | # build data loader 285 | train_loader, val_loader = set_loader(opt) 286 | 287 | # build model and criterion 288 | model, criterion = set_model(opt) 289 | 290 | # build optimizer 291 | optimizer = set_optimizer(opt, model) 292 | 293 | # tensorboard 294 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 295 | 296 | # training routine 297 | for epoch in range(1, opt.epochs + 1): 298 | adjust_learning_rate(opt, optimizer, epoch) 299 | 300 | # train for one epoch 301 | time1 = time.time() 302 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt) 303 | time2 = time.time() 304 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 305 | 306 | # tensorboard logger 307 | logger.log_value('train_loss', loss, epoch) 308 | logger.log_value('train_acc', train_acc, epoch) 309 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) 310 | 311 | # evaluation 312 | loss, val_acc = validate(val_loader, model, criterion, opt) 313 | logger.log_value('val_loss', loss, epoch) 314 | logger.log_value('val_acc', val_acc, epoch) 315 | 316 | if val_acc > best_acc: 317 | best_acc = val_acc 318 | 319 | if epoch % opt.save_freq == 0: 320 | save_file = os.path.join( 321 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 322 | save_model(model, optimizer, opt, epoch, save_file) 323 | 324 | # save the last model 325 | save_file = os.path.join( 326 | opt.save_folder, 'last.pth') 327 | save_model(model, optimizer, opt, opt.epochs, save_file) 328 | 329 | print('best accuracy: {:.2f}'.format(best_acc)) 330 | 331 | 332 | if __name__ == '__main__': 333 | main() 334 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import argparse 5 | import time 6 | 
def parse_option():
    """Parse CLI arguments for linear evaluation and derive settings.

    Returns an argparse.Namespace augmented with data_folder, model_name,
    n_cls, the parsed lr_decay_epochs list and, when --warm is set, the
    warm-up schedule endpoints (warmup_from/warmup_to/warm_epochs).
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.2,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'], help='dataset')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')

    parser.add_argument('--ckpt', type=str, default='',
                        help='path to pre-trained model')

    opt = parser.parse_args()

    # set the path according to the environment
    opt.data_folder = './datasets/'

    # "60,75,90" -> [60, 75, 90]
    opt.lr_decay_epochs = [int(step) for step in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.format(
        opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
        opt.batch_size)
    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    # warm-up schedule for large-batch training
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # warm up to the value the cosine schedule has at warm_epochs
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    # number of classes per supported dataset
    num_classes = {'cifar10': 10, 'cifar100': 100}
    if opt.dataset not in num_classes:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    opt.n_cls = num_classes[opt.dataset]

    return opt
def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """one epoch training

    Linear-evaluation protocol: the backbone (`model`) stays frozen in eval
    mode and only `classifier` is updated. Returns (average loss, average
    top-1 accuracy) over the epoch.

    NOTE(review): tensors are moved with unconditional .cuda() calls, so this
    function requires a CUDA device.
    """
    # freeze the backbone; only the linear head trains
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate (no-op unless opt.warm is set and we are
        # still within the warm-up epochs)
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss: the encoder runs without grad, and detach() further
        # guarantees no gradient can flow into the frozen backbone
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
def main():
    """Linear evaluation: train a classifier on top of a frozen encoder."""
    best_acc = 0
    opt = parse_option()

    # build data loader
    train_loader, val_loader = set_loader(opt)

    # build model and criterion
    model, classifier, criterion = set_model(opt)

    # only the linear head's parameters are optimized
    optimizer = set_optimizer(opt, classifier)

    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        tic = time.time()
        loss, acc = train(train_loader, model, classifier, criterion,
                          optimizer, epoch, opt)
        toc = time.time()
        print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
            epoch, toc - tic, acc))

        # eval for one epoch
        loss, val_acc = validate(val_loader, model, classifier, criterion, opt)
        best_acc = max(best_acc, val_acc)

    print('best accuracy: {:.2f}'.format(best_acc))
parser.add_argument('--momentum', type=float, default=0.9, 51 | help='momentum') 52 | 53 | # model dataset 54 | parser.add_argument('--model', type=str, default='resnet50') 55 | parser.add_argument('--dataset', type=str, default='cifar10', 56 | choices=['cifar10', 'cifar100', 'path'], help='dataset') 57 | parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple') 58 | parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple') 59 | parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset') 60 | parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop') 61 | 62 | # method 63 | parser.add_argument('--method', type=str, default='SupCon', 64 | choices=['SupCon', 'SimCLR'], help='choose method') 65 | 66 | # temperature 67 | parser.add_argument('--temp', type=float, default=0.07, 68 | help='temperature for loss function') 69 | 70 | # other setting 71 | parser.add_argument('--cosine', action='store_true', 72 | help='using cosine annealing') 73 | parser.add_argument('--syncBN', action='store_true', 74 | help='using synchronized batch normalization') 75 | parser.add_argument('--warm', action='store_true', 76 | help='warm-up for large batch training') 77 | parser.add_argument('--trial', type=str, default='0', 78 | help='id for recording multiple runs') 79 | 80 | opt = parser.parse_args() 81 | 82 | # check if dataset is path that passed required arguments 83 | if opt.dataset == 'path': 84 | assert opt.data_folder is not None \ 85 | and opt.mean is not None \ 86 | and opt.std is not None 87 | 88 | # set the path according to the environment 89 | if opt.data_folder is None: 90 | opt.data_folder = './datasets/' 91 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset) 92 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset) 93 | 94 | iterations = opt.lr_decay_epochs.split(',') 95 | opt.lr_decay_epochs = list([]) 96 | for it in 
def set_loader(opt):
    """Build the two-view training DataLoader for contrastive learning.

    Each sample is a pair of independent augmentations of the same image
    (TwoCropTransform), as SupCon/SimCLR require. Returns the DataLoader.

    Raises:
        ValueError: if opt.dataset is not one of cifar10/cifar100/path.
    """
    # per-dataset normalization statistics
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        # parse the "--mean"/"--std" tuple strings safely: literal_eval only
        # accepts Python literals, unlike eval() which would execute
        # arbitrary code passed on the command line
        import ast
        mean = ast.literal_eval(opt.mean)
        std = ast.literal_eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    # SimCLR-style augmentation pipeline
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                             transform=TwoCropTransform(train_transform))
    else:
        raise ValueError(opt.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)

    return train_loader
def main():
    """Contrastive pre-training entry point (SupCon or SimCLR)."""
    opt = parse_option()

    # data / model / optimization setup
    train_loader = set_loader(opt)
    model, criterion = set_model(opt)
    optimizer = set_optimizer(opt, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        tic = time.time()
        avg_loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        toc = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, toc - tic))

        # tensorboard logging
        logger.log_value('loss', avg_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            ckpt_path = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, ckpt_path)

    # always keep the final weights, regardless of save_freq
    save_model(model, optimizer, opt, opt.epochs,
               os.path.join(opt.save_folder, 'last.pth'))
class Bottleneck(nn.Module):
    """ImageNet-style 1x1-3x3-1x1 bottleneck residual block.

    With is_last=True, forward() also returns the pre-activation sum
    (residual branch + shortcut, before the final ReLU).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        # 1x1 reduce -> 3x3 spatial (carries the stride) -> 1x1 expand
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        # projection shortcut only when input and output shapes disagree
        if stride == 1 and in_planes == self.expansion * planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        preact = branch + self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated
def resnet18(**kwargs):
    """ResNet-18 encoder: BasicBlock x [2, 2, 2, 2]; feature dim 512."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """ResNet-34 encoder: BasicBlock x [3, 4, 6, 3]; feature dim 512."""
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """ResNet-50 encoder: Bottleneck x [3, 4, 6, 3]; feature dim 2048."""
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """ResNet-101 encoder: Bottleneck x [3, 4, 23, 3]; feature dim 2048."""
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


# name -> [constructor, dimension of the encoder's flattened output]
model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}
class SupConResNet(nn.Module):
    """Backbone encoder followed by an L2-normalized projection head."""

    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()
        if head == 'mlp':
            # two-layer projection MLP (SimCLR-style)
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim),
            )
        elif head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        embedding = self.encoder(x)
        # unit-norm features, as the contrastive loss expects
        return F.normalize(self.head(embedding), dim=1)


class SupCEResNet(nn.Module):
    """Backbone encoder followed by a linear classification layer."""

    def __init__(self, name='resnet50', num_classes=10):
        super(SupCEResNet, self).__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        logits = self.fc(self.encoder(x))
        return logits


class LinearClassifier(nn.Module):
    """Standalone linear probe applied to (frozen) encoder features."""

    def __init__(self, name='resnet50', num_classes=10):
        super(LinearClassifier, self).__init__()
        _, feat_dim = model_dict[name]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        return self.fc(features)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) scores/logits.
        target: (batch,) ground-truth class indices.
        topk: iterable of k values to report.

    Returns:
        List of 1-element tensors, one per k, each holding the top-k
        accuracy in percent.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` inherits the transposed
            # (non-contiguous) layout of pred.t(), so .view(-1) raises a
            # RuntimeError on recent PyTorch versions when maxk > 1
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def set_optimizer(opt, model):
    """Build the plain SGD optimizer used throughout this project."""
    return optim.SGD(model.parameters(),
                     lr=opt.learning_rate,
                     momentum=opt.momentum,
                     weight_decay=opt.weight_decay)


def save_model(model, optimizer, opt, epoch, save_file):
    """Checkpoint model and optimizer state together with the run options."""
    print('==> Saving...')
    checkpoint = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, save_file)
    # drop the reference promptly so the state dicts can be freed
    del checkpoint