├── .gitignore
├── LICENSE
├── README.md
├── figures
│   ├── SimCLR.jpg
│   ├── SupCE.jpg
│   ├── SupContrast.jpg
│   └── teaser.png
├── losses.py
├── main_ce.py
├── main_linear.py
├── main_supcon.py
├── networks
│   └── resnet_big.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | tmp*.py
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 | save/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, Yonglong Tian
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SupContrast: Supervised Contrastive Learning
2 |
3 |
4 |
5 |
6 | This repo covers a reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
7 | (1) Supervised Contrastive Learning. [Paper](https://arxiv.org/abs/2004.11362)
8 | (2) A Simple Framework for Contrastive Learning of Visual Representations. [Paper](https://arxiv.org/abs/2002.05709)
9 |
10 | ## Update
11 |
12 | ${\color{red}Note}$: if you find it hard to parse the SupCon loss implementation in this repo, we've got you covered. The SupCon loss is essentially just a cross-entropy loss (see Eq. 4 in the [StableRep](https://arxiv.org/pdf/2306.00984.pdf) paper). A cleaner and simpler implementation is available [here](https://github.com/google-research/syn-rep-learn/blob/main/StableRep/models/losses.py#L49). Hope it helps.
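
To make the connection concrete, here is a minimal sketch (hypothetical, not the implementation in `losses.py`) that writes the SupCon loss for single-view, L2-normalized features as a cross-entropy with soft targets; it omits the `base_temperature` scaling used in this repo:

```python
import torch
import torch.nn.functional as F

def supcon_as_cross_entropy(features, labels, temperature=0.1):
    """SupCon loss written as cross-entropy with soft targets.

    features: [bsz, f_dim], assumed L2-normalized; labels: [bsz].
    """
    # pairwise similarity logits
    logits = features @ features.T / temperature
    # exclude self-similarity (large negative instead of -inf to avoid 0 * -inf = nan)
    logits.fill_diagonal_(-1e9)
    # each anchor's target distribution is uniform over its positives
    mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()
    mask.fill_diagonal_(0)
    targets = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    # soft-target cross-entropy (supported since PyTorch 1.10)
    return F.cross_entropy(logits, targets)
```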
13 |
14 | An ImageNet model (trained with a small batch size using the momentum-encoder trick) is released [here](https://www.dropbox.com/s/l4a69ececk4spdt/supcon.pth?dl=0). It achieves > 79% top-1 accuracy.
15 |
16 | ## Loss Function
17 | The loss function [`SupConLoss`](https://github.com/HobbitLong/SupContrast/blob/master/losses.py#L11) in `losses.py` takes `features` (L2-normalized) and `labels` as input, and returns the loss. If `labels` is `None` or not passed, the loss degenerates to SimCLR.
18 |
19 | Usage:
20 | ```python
21 | from losses import SupConLoss
22 |
23 | # define loss with a temperature `temp`
24 | criterion = SupConLoss(temperature=temp)
25 |
26 | # features: [bsz, n_views, f_dim]
27 | # `n_views` is the number of crops from each image
28 | # features should be L2-normalized along the f_dim dimension
29 | features = ...
30 | # labels: [bsz]
31 | labels = ...
32 |
33 | # SupContrast
34 | loss = criterion(features, labels)
35 | # or SimCLR
36 | loss = criterion(features)
37 | ...
38 | ```
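
A common stumbling block is constructing the `features` tensor itself. Below is a minimal sketch (random embeddings stand in for encoder outputs) of stacking two augmented views into the expected `[bsz, n_views, f_dim]` shape:

```python
import torch
import torch.nn.functional as F

bsz, f_dim = 8, 128  # illustrative sizes
# embeddings of two augmented views of the same batch of images
f1 = F.normalize(torch.randn(bsz, f_dim), dim=1)
f2 = F.normalize(torch.randn(bsz, f_dim), dim=1)
# stack along the n_views dimension -> [bsz, 2, f_dim]
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
```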
39 |
40 | ## Comparison
41 | Results on CIFAR-10:
42 | | |Arch | Setting | Loss | Accuracy(%) |
43 | |----------|:----:|:---:|:---:|:---:|
44 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 95.0 |
45 | | SupContrast | ResNet50 | Supervised | Contrastive | 96.0 |
46 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 93.6 |
47 |
48 | Results on CIFAR-100:
49 | | |Arch | Setting | Loss | Accuracy(%) |
50 | |----------|:----:|:---:|:---:|:---:|
51 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | 75.3 |
52 | | SupContrast | ResNet50 | Supervised | Contrastive | 76.5 |
53 | | SimCLR | ResNet50 | Unsupervised | Contrastive | 70.7 |
54 |
55 | Results on ImageNet (Stay tuned):
56 | | |Arch | Setting | Loss | Accuracy(%) |
57 | |----------|:----:|:---:|:---:|:---:|
58 | | SupCrossEntropy | ResNet50 | Supervised | Cross Entropy | - |
59 | | SupContrast | ResNet50 | Supervised | Contrastive | 79.1 (MoCo trick) |
60 | | SimCLR | ResNet50 | Unsupervised | Contrastive | - |
61 |
62 | ## Running
63 | You can use `CUDA_VISIBLE_DEVICES` to select the GPUs to use, and/or switch to CIFAR-100 with `--dataset cifar100`.
64 | **(1) Standard Cross-Entropy**
65 | ```
66 | python main_ce.py --batch_size 1024 \
67 | --learning_rate 0.8 \
68 |   --cosine --syncBN
69 | ```
70 | **(2) Supervised Contrastive Learning**
71 | Pretraining stage:
72 | ```
73 | python main_supcon.py --batch_size 1024 \
74 | --learning_rate 0.5 \
75 | --temp 0.1 \
76 | --cosine
77 | ```
78 |
79 | You can also specify `--syncBN`, but I found it is not crucial for SupContrast (`syncBN` 95.9% vs. `BN` 96.0%).
80 |
81 | WARNING: currently `--syncBN` has no effect, since the code uses `DataParallel` instead of `DistributedDataParallel`.
82 |
83 | Linear evaluation stage:
84 | ```
85 | python main_linear.py --batch_size 512 \
86 | --learning_rate 5 \
87 | --ckpt /path/to/model.pth
88 | ```
89 | **(3) SimCLR**
90 | Pretraining stage:
91 | ```
92 | python main_supcon.py --batch_size 1024 \
93 | --learning_rate 0.5 \
94 | --temp 0.5 \
95 | --cosine --syncBN \
96 | --method SimCLR
97 | ```
98 | The `--method SimCLR` flag simply stops `labels` from being passed to the `SupConLoss` criterion.
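
Roughly, the training loop branches like this (a simplified sketch, not the exact code in `main_supcon.py`):

```python
def compute_loss(criterion, features, method, labels=None):
    # SupCon uses labels; SimCLR drops them, so the only positives
    # are the other augmented views of the same image
    if method == 'SupCon':
        return criterion(features, labels)
    elif method == 'SimCLR':
        return criterion(features)
    raise ValueError('unknown method: {}'.format(method))
```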
99 | Linear evaluation stage:
100 | ```
101 | python main_linear.py --batch_size 512 \
102 | --learning_rate 1 \
103 | --ckpt /path/to/model.pth
104 | ```
105 |
106 | On custom dataset:
107 | ```
108 | python main_supcon.py --batch_size 1024 \
109 | --learning_rate 0.5 \
110 | --temp 0.1 --cosine \
111 | --dataset path \
112 | --data_folder ./path \
113 | --mean "(0.4914, 0.4822, 0.4465)" \
114 | --std "(0.2675, 0.2565, 0.2761)" \
115 | --method SimCLR
116 | ```
117 |
118 | The `--data_folder` must be of the form `./path/label/xxx.png`, following the [torchvision `ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder) convention.
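
As an illustration, a toy folder in this convention (the class names `cat`/`dog` are hypothetical, and empty files stand in for real images) can be created with the standard library:

```python
import os
import tempfile

# build a toy --data_folder following the ImageFolder convention:
# one sub-directory per class, with that class's images inside
root = tempfile.mkdtemp()
for label in ('cat', 'dog'):
    os.makedirs(os.path.join(root, label))
    # placeholder file standing in for a real image
    open(os.path.join(root, label, 'img0.png'), 'wb').close()

classes = sorted(os.listdir(root))  # ['cat', 'dog']
```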
119 |
121 | ## t-SNE Visualization
122 |
123 | **(1) Standard Cross-Entropy**
124 |
125 |
126 |
127 |
128 | **(2) Supervised Contrastive Learning**
129 |
130 |
131 |
132 |
133 | **(3) SimCLR**
134 |
135 |
136 |
137 |
138 | ## Reference
139 | ```
140 | @Article{khosla2020supervised,
141 | title = {Supervised Contrastive Learning},
142 | author = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
143 | journal = {arXiv preprint arXiv:2004.11362},
144 | year = {2020},
145 | }
146 | ```
147 |
--------------------------------------------------------------------------------
/figures/SimCLR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SimCLR.jpg
--------------------------------------------------------------------------------
/figures/SupCE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SupCE.jpg
--------------------------------------------------------------------------------
/figures/SupContrast.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/SupContrast.jpg
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HobbitLong/SupContrast/66a8fe53880d6a1084b2e4e0db0a019024d6d41a/figures/teaser.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Yonglong Tian (yonglong@mit.edu)
3 | Date: May 07, 2020
4 | """
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class SupConLoss(nn.Module):
12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
13 | It also supports the unsupervised contrastive loss in SimCLR"""
14 | def __init__(self, temperature=0.07, contrast_mode='all',
15 | base_temperature=0.07):
16 | super(SupConLoss, self).__init__()
17 | self.temperature = temperature
18 | self.contrast_mode = contrast_mode
19 | self.base_temperature = base_temperature
20 |
21 | def forward(self, features, labels=None, mask=None):
22 | """Compute loss for model. If both `labels` and `mask` are None,
23 | it degenerates to SimCLR unsupervised loss:
24 | https://arxiv.org/pdf/2002.05709.pdf
25 |
26 | Args:
27 | features: hidden vector of shape [bsz, n_views, ...].
28 | labels: ground truth of shape [bsz].
29 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
30 | has the same class as sample i. Can be asymmetric.
31 | Returns:
32 | A loss scalar.
33 | """
34 | device = (torch.device('cuda')
35 | if features.is_cuda
36 | else torch.device('cpu'))
37 |
38 | if len(features.shape) < 3:
39 |             raise ValueError('`features` needs to be [bsz, n_views, ...], '
40 |                              'at least 3 dimensions are required')
41 | if len(features.shape) > 3:
42 | features = features.view(features.shape[0], features.shape[1], -1)
43 |
44 | batch_size = features.shape[0]
45 | if labels is not None and mask is not None:
46 | raise ValueError('Cannot define both `labels` and `mask`')
47 | elif labels is None and mask is None:
48 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
49 | elif labels is not None:
50 | labels = labels.contiguous().view(-1, 1)
51 | if labels.shape[0] != batch_size:
52 | raise ValueError('Num of labels does not match num of features')
53 | mask = torch.eq(labels, labels.T).float().to(device)
54 | else:
55 | mask = mask.float().to(device)
56 |
57 | contrast_count = features.shape[1]
58 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
59 | if self.contrast_mode == 'one':
60 | anchor_feature = features[:, 0]
61 | anchor_count = 1
62 | elif self.contrast_mode == 'all':
63 | anchor_feature = contrast_feature
64 | anchor_count = contrast_count
65 | else:
66 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
67 |
68 | # compute logits
69 | anchor_dot_contrast = torch.div(
70 | torch.matmul(anchor_feature, contrast_feature.T),
71 | self.temperature)
72 | # for numerical stability
73 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
74 | logits = anchor_dot_contrast - logits_max.detach()
75 |
76 | # tile mask
77 | mask = mask.repeat(anchor_count, contrast_count)
78 | # mask-out self-contrast cases
79 | logits_mask = torch.scatter(
80 | torch.ones_like(mask),
81 | 1,
82 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
83 | 0
84 | )
85 | mask = mask * logits_mask
86 |
87 | # compute log_prob
88 | exp_logits = torch.exp(logits) * logits_mask
89 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
90 |
91 | # compute mean of log-likelihood over positive
92 | # modified to handle edge cases when there is no positive pair
93 | # for an anchor point.
94 |         # Edge case, e.g.:
95 | # features of shape: [4,1,...]
96 | # labels: [0,1,1,2]
97 | # loss before mean: [nan, ..., ..., nan]
98 | mask_pos_pairs = mask.sum(1)
99 | mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
100 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
101 |
102 | # loss
103 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
104 | loss = loss.view(anchor_count, batch_size).mean()
105 |
106 | return loss
107 |
--------------------------------------------------------------------------------
/main_ce.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 |
9 | import tensorboard_logger as tb_logger
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torchvision import transforms, datasets
13 |
14 | from util import AverageMeter
15 | from util import adjust_learning_rate, warmup_learning_rate, accuracy
16 | from util import set_optimizer, save_model
17 | from networks.resnet_big import SupCEResNet
18 |
19 | try:
20 | import apex
21 | from apex import amp, optimizers
22 | except ImportError:
23 | pass
24 |
25 |
26 | def parse_option():
27 | parser = argparse.ArgumentParser('argument for training')
28 |
29 | parser.add_argument('--print_freq', type=int, default=10,
30 | help='print frequency')
31 | parser.add_argument('--save_freq', type=int, default=50,
32 | help='save frequency')
33 | parser.add_argument('--batch_size', type=int, default=256,
34 | help='batch_size')
35 | parser.add_argument('--num_workers', type=int, default=16,
36 | help='num of workers to use')
37 | parser.add_argument('--epochs', type=int, default=500,
38 | help='number of training epochs')
39 |
40 | # optimization
41 | parser.add_argument('--learning_rate', type=float, default=0.2,
42 | help='learning rate')
43 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450',
44 | help='where to decay lr, can be a list')
45 | parser.add_argument('--lr_decay_rate', type=float, default=0.1,
46 | help='decay rate for learning rate')
47 | parser.add_argument('--weight_decay', type=float, default=1e-4,
48 | help='weight decay')
49 | parser.add_argument('--momentum', type=float, default=0.9,
50 | help='momentum')
51 |
52 | # model dataset
53 | parser.add_argument('--model', type=str, default='resnet50')
54 | parser.add_argument('--dataset', type=str, default='cifar10',
55 | choices=['cifar10', 'cifar100'], help='dataset')
56 |
57 | # other setting
58 | parser.add_argument('--cosine', action='store_true',
59 | help='using cosine annealing')
60 | parser.add_argument('--syncBN', action='store_true',
61 | help='using synchronized batch normalization')
62 | parser.add_argument('--warm', action='store_true',
63 | help='warm-up for large batch training')
64 | parser.add_argument('--trial', type=str, default='0',
65 | help='id for recording multiple runs')
66 |
67 | opt = parser.parse_args()
68 |
69 | # set the path according to the environment
70 | opt.data_folder = './datasets/'
71 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
72 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)
73 |
74 | iterations = opt.lr_decay_epochs.split(',')
75 | opt.lr_decay_epochs = list([])
76 | for it in iterations:
77 | opt.lr_decay_epochs.append(int(it))
78 |
79 | opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.\
80 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
81 | opt.batch_size, opt.trial)
82 |
83 | if opt.cosine:
84 | opt.model_name = '{}_cosine'.format(opt.model_name)
85 |
86 | # warm-up for large-batch training,
87 | if opt.batch_size > 256:
88 | opt.warm = True
89 | if opt.warm:
90 | opt.model_name = '{}_warm'.format(opt.model_name)
91 | opt.warmup_from = 0.01
92 | opt.warm_epochs = 10
93 | if opt.cosine:
94 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
95 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
96 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
97 | else:
98 | opt.warmup_to = opt.learning_rate
99 |
100 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
101 | if not os.path.isdir(opt.tb_folder):
102 | os.makedirs(opt.tb_folder)
103 |
104 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
105 | if not os.path.isdir(opt.save_folder):
106 | os.makedirs(opt.save_folder)
107 |
108 | if opt.dataset == 'cifar10':
109 | opt.n_cls = 10
110 | elif opt.dataset == 'cifar100':
111 | opt.n_cls = 100
112 | else:
113 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
114 |
115 | return opt
116 |
117 |
118 | def set_loader(opt):
119 | # construct data loader
120 | if opt.dataset == 'cifar10':
121 | mean = (0.4914, 0.4822, 0.4465)
122 | std = (0.2023, 0.1994, 0.2010)
123 | elif opt.dataset == 'cifar100':
124 | mean = (0.5071, 0.4867, 0.4408)
125 | std = (0.2675, 0.2565, 0.2761)
126 | else:
127 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
128 | normalize = transforms.Normalize(mean=mean, std=std)
129 |
130 | train_transform = transforms.Compose([
131 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
132 | transforms.RandomHorizontalFlip(),
133 | transforms.ToTensor(),
134 | normalize,
135 | ])
136 |
137 | val_transform = transforms.Compose([
138 | transforms.ToTensor(),
139 | normalize,
140 | ])
141 |
142 | if opt.dataset == 'cifar10':
143 | train_dataset = datasets.CIFAR10(root=opt.data_folder,
144 | transform=train_transform,
145 | download=True)
146 | val_dataset = datasets.CIFAR10(root=opt.data_folder,
147 | train=False,
148 | transform=val_transform)
149 | elif opt.dataset == 'cifar100':
150 | train_dataset = datasets.CIFAR100(root=opt.data_folder,
151 | transform=train_transform,
152 | download=True)
153 | val_dataset = datasets.CIFAR100(root=opt.data_folder,
154 | train=False,
155 | transform=val_transform)
156 | else:
157 | raise ValueError(opt.dataset)
158 |
159 | train_sampler = None
160 | train_loader = torch.utils.data.DataLoader(
161 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
162 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
163 | val_loader = torch.utils.data.DataLoader(
164 | val_dataset, batch_size=256, shuffle=False,
165 | num_workers=8, pin_memory=True)
166 |
167 | return train_loader, val_loader
168 |
169 |
170 | def set_model(opt):
171 | model = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
172 | criterion = torch.nn.CrossEntropyLoss()
173 |
174 | # enable synchronized Batch Normalization
175 | if opt.syncBN:
176 | model = apex.parallel.convert_syncbn_model(model)
177 |
178 | if torch.cuda.is_available():
179 | if torch.cuda.device_count() > 1:
180 | model = torch.nn.DataParallel(model)
181 | model = model.cuda()
182 | criterion = criterion.cuda()
183 | cudnn.benchmark = True
184 |
185 | return model, criterion
186 |
187 |
188 | def train(train_loader, model, criterion, optimizer, epoch, opt):
189 | """one epoch training"""
190 | model.train()
191 |
192 | batch_time = AverageMeter()
193 | data_time = AverageMeter()
194 | losses = AverageMeter()
195 | top1 = AverageMeter()
196 |
197 | end = time.time()
198 | for idx, (images, labels) in enumerate(train_loader):
199 | data_time.update(time.time() - end)
200 |
201 | images = images.cuda(non_blocking=True)
202 | labels = labels.cuda(non_blocking=True)
203 | bsz = labels.shape[0]
204 |
205 | # warm-up learning rate
206 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
207 |
208 | # compute loss
209 | output = model(images)
210 | loss = criterion(output, labels)
211 |
212 | # update metric
213 | losses.update(loss.item(), bsz)
214 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
215 | top1.update(acc1[0], bsz)
216 |
217 | # SGD
218 | optimizer.zero_grad()
219 | loss.backward()
220 | optimizer.step()
221 |
222 | # measure elapsed time
223 | batch_time.update(time.time() - end)
224 | end = time.time()
225 |
226 | # print info
227 | if (idx + 1) % opt.print_freq == 0:
228 | print('Train: [{0}][{1}/{2}]\t'
229 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
230 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
231 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
232 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
233 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
234 | data_time=data_time, loss=losses, top1=top1))
235 | sys.stdout.flush()
236 |
237 | return losses.avg, top1.avg
238 |
239 |
240 | def validate(val_loader, model, criterion, opt):
241 | """validation"""
242 | model.eval()
243 |
244 | batch_time = AverageMeter()
245 | losses = AverageMeter()
246 | top1 = AverageMeter()
247 |
248 | with torch.no_grad():
249 | end = time.time()
250 | for idx, (images, labels) in enumerate(val_loader):
251 | images = images.float().cuda()
252 | labels = labels.cuda()
253 | bsz = labels.shape[0]
254 |
255 | # forward
256 | output = model(images)
257 | loss = criterion(output, labels)
258 |
259 | # update metric
260 | losses.update(loss.item(), bsz)
261 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
262 | top1.update(acc1[0], bsz)
263 |
264 | # measure elapsed time
265 | batch_time.update(time.time() - end)
266 | end = time.time()
267 |
268 | if idx % opt.print_freq == 0:
269 | print('Test: [{0}/{1}]\t'
270 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
271 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
272 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
273 | idx, len(val_loader), batch_time=batch_time,
274 | loss=losses, top1=top1))
275 |
276 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
277 | return losses.avg, top1.avg
278 |
279 |
280 | def main():
281 | best_acc = 0
282 | opt = parse_option()
283 |
284 | # build data loader
285 | train_loader, val_loader = set_loader(opt)
286 |
287 | # build model and criterion
288 | model, criterion = set_model(opt)
289 |
290 | # build optimizer
291 | optimizer = set_optimizer(opt, model)
292 |
293 | # tensorboard
294 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
295 |
296 | # training routine
297 | for epoch in range(1, opt.epochs + 1):
298 | adjust_learning_rate(opt, optimizer, epoch)
299 |
300 | # train for one epoch
301 | time1 = time.time()
302 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt)
303 | time2 = time.time()
304 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
305 |
306 | # tensorboard logger
307 | logger.log_value('train_loss', loss, epoch)
308 | logger.log_value('train_acc', train_acc, epoch)
309 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
310 |
311 | # evaluation
312 | loss, val_acc = validate(val_loader, model, criterion, opt)
313 | logger.log_value('val_loss', loss, epoch)
314 | logger.log_value('val_acc', val_acc, epoch)
315 |
316 | if val_acc > best_acc:
317 | best_acc = val_acc
318 |
319 | if epoch % opt.save_freq == 0:
320 | save_file = os.path.join(
321 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
322 | save_model(model, optimizer, opt, epoch, save_file)
323 |
324 | # save the last model
325 | save_file = os.path.join(
326 | opt.save_folder, 'last.pth')
327 | save_model(model, optimizer, opt, opt.epochs, save_file)
328 |
329 | print('best accuracy: {:.2f}'.format(best_acc))
330 |
331 |
332 | if __name__ == '__main__':
333 | main()
334 |
--------------------------------------------------------------------------------
/main_linear.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import argparse
5 | import time
6 | import math
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 |
11 | from main_ce import set_loader
12 | from util import AverageMeter
13 | from util import adjust_learning_rate, warmup_learning_rate, accuracy
14 | from util import set_optimizer
15 | from networks.resnet_big import SupConResNet, LinearClassifier
16 |
17 | try:
18 | import apex
19 | from apex import amp, optimizers
20 | except ImportError:
21 | pass
22 |
23 |
24 | def parse_option():
25 | parser = argparse.ArgumentParser('argument for training')
26 |
27 | parser.add_argument('--print_freq', type=int, default=10,
28 | help='print frequency')
29 | parser.add_argument('--save_freq', type=int, default=50,
30 | help='save frequency')
31 | parser.add_argument('--batch_size', type=int, default=256,
32 | help='batch_size')
33 | parser.add_argument('--num_workers', type=int, default=16,
34 | help='num of workers to use')
35 | parser.add_argument('--epochs', type=int, default=100,
36 | help='number of training epochs')
37 |
38 | # optimization
39 | parser.add_argument('--learning_rate', type=float, default=0.1,
40 | help='learning rate')
41 | parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90',
42 | help='where to decay lr, can be a list')
43 | parser.add_argument('--lr_decay_rate', type=float, default=0.2,
44 | help='decay rate for learning rate')
45 | parser.add_argument('--weight_decay', type=float, default=0,
46 | help='weight decay')
47 | parser.add_argument('--momentum', type=float, default=0.9,
48 | help='momentum')
49 |
50 | # model dataset
51 | parser.add_argument('--model', type=str, default='resnet50')
52 | parser.add_argument('--dataset', type=str, default='cifar10',
53 | choices=['cifar10', 'cifar100'], help='dataset')
54 |
55 | # other setting
56 | parser.add_argument('--cosine', action='store_true',
57 | help='using cosine annealing')
58 | parser.add_argument('--warm', action='store_true',
59 | help='warm-up for large batch training')
60 |
61 | parser.add_argument('--ckpt', type=str, default='',
62 | help='path to pre-trained model')
63 |
64 | opt = parser.parse_args()
65 |
66 | # set the path according to the environment
67 | opt.data_folder = './datasets/'
68 |
69 | iterations = opt.lr_decay_epochs.split(',')
70 | opt.lr_decay_epochs = list([])
71 | for it in iterations:
72 | opt.lr_decay_epochs.append(int(it))
73 |
74 | opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
75 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
76 | opt.batch_size)
77 |
78 | if opt.cosine:
79 | opt.model_name = '{}_cosine'.format(opt.model_name)
80 |
81 | # warm-up for large-batch training,
82 | if opt.warm:
83 | opt.model_name = '{}_warm'.format(opt.model_name)
84 | opt.warmup_from = 0.01
85 | opt.warm_epochs = 10
86 | if opt.cosine:
87 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
88 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
89 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
90 | else:
91 | opt.warmup_to = opt.learning_rate
92 |
93 | if opt.dataset == 'cifar10':
94 | opt.n_cls = 10
95 | elif opt.dataset == 'cifar100':
96 | opt.n_cls = 100
97 | else:
98 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
99 |
100 | return opt
101 |
102 |
103 | def set_model(opt):
104 | model = SupConResNet(name=opt.model)
105 | criterion = torch.nn.CrossEntropyLoss()
106 |
107 | classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)
108 |
109 | ckpt = torch.load(opt.ckpt, map_location='cpu')
110 | state_dict = ckpt['model']
111 |
112 | if torch.cuda.is_available():
113 | if torch.cuda.device_count() > 1:
114 | model.encoder = torch.nn.DataParallel(model.encoder)
115 | else:
116 | new_state_dict = {}
117 | for k, v in state_dict.items():
118 | k = k.replace("module.", "")
119 | new_state_dict[k] = v
120 | state_dict = new_state_dict
121 | model = model.cuda()
122 | classifier = classifier.cuda()
123 | criterion = criterion.cuda()
124 | cudnn.benchmark = True
125 |
126 | model.load_state_dict(state_dict)
127 | else:
128 | raise NotImplementedError('This code requires GPU')
129 |
130 | return model, classifier, criterion
131 |
132 |
133 | def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
134 | """one epoch training"""
135 | model.eval()
136 | classifier.train()
137 |
138 | batch_time = AverageMeter()
139 | data_time = AverageMeter()
140 | losses = AverageMeter()
141 | top1 = AverageMeter()
142 |
143 | end = time.time()
144 | for idx, (images, labels) in enumerate(train_loader):
145 | data_time.update(time.time() - end)
146 |
147 | images = images.cuda(non_blocking=True)
148 | labels = labels.cuda(non_blocking=True)
149 | bsz = labels.shape[0]
150 |
151 | # warm-up learning rate
152 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
153 |
154 | # compute loss
155 | with torch.no_grad():
156 | features = model.encoder(images)
157 | output = classifier(features.detach())
158 | loss = criterion(output, labels)
159 |
160 | # update metric
161 | losses.update(loss.item(), bsz)
162 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
163 | top1.update(acc1[0], bsz)
164 |
165 | # SGD
166 | optimizer.zero_grad()
167 | loss.backward()
168 | optimizer.step()
169 |
170 | # measure elapsed time
171 | batch_time.update(time.time() - end)
172 | end = time.time()
173 |
174 | # print info
175 | if (idx + 1) % opt.print_freq == 0:
176 | print('Train: [{0}][{1}/{2}]\t'
177 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
178 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
179 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
180 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
181 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
182 | data_time=data_time, loss=losses, top1=top1))
183 | sys.stdout.flush()
184 |
185 | return losses.avg, top1.avg
186 |
187 |
188 | def validate(val_loader, model, classifier, criterion, opt):
189 | """validation"""
190 | model.eval()
191 | classifier.eval()
192 |
193 | batch_time = AverageMeter()
194 | losses = AverageMeter()
195 | top1 = AverageMeter()
196 |
197 | with torch.no_grad():
198 | end = time.time()
199 | for idx, (images, labels) in enumerate(val_loader):
200 | images = images.float().cuda()
201 | labels = labels.cuda()
202 | bsz = labels.shape[0]
203 |
204 | # forward
205 | output = classifier(model.encoder(images))
206 | loss = criterion(output, labels)
207 |
208 | # update metric
209 | losses.update(loss.item(), bsz)
210 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
211 | top1.update(acc1[0], bsz)
212 |
213 | # measure elapsed time
214 | batch_time.update(time.time() - end)
215 | end = time.time()
216 |
217 | if idx % opt.print_freq == 0:
218 | print('Test: [{0}/{1}]\t'
219 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
220 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
221 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
222 | idx, len(val_loader), batch_time=batch_time,
223 | loss=losses, top1=top1))
224 |
225 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
226 | return losses.avg, top1.avg
227 |
228 |
229 | def main():
230 | best_acc = 0
231 | opt = parse_option()
232 |
233 | # build data loader
234 | train_loader, val_loader = set_loader(opt)
235 |
236 | # build model and criterion
237 | model, classifier, criterion = set_model(opt)
238 |
239 | # build optimizer
240 | optimizer = set_optimizer(opt, classifier)
241 |
242 | # training routine
243 | for epoch in range(1, opt.epochs + 1):
244 | adjust_learning_rate(opt, optimizer, epoch)
245 |
246 | # train for one epoch
247 | time1 = time.time()
248 | loss, acc = train(train_loader, model, classifier, criterion,
249 | optimizer, epoch, opt)
250 | time2 = time.time()
251 |         print('Train epoch {}, total time {:.2f}, accuracy: {:.2f}'.format(
252 | epoch, time2 - time1, acc))
253 |
254 | # eval for one epoch
255 | loss, val_acc = validate(val_loader, model, classifier, criterion, opt)
256 | if val_acc > best_acc:
257 | best_acc = val_acc
258 |
259 | print('best accuracy: {:.2f}'.format(best_acc))
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
--------------------------------------------------------------------------------
/main_supcon.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 |
9 | import tensorboard_logger as tb_logger
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torchvision import transforms, datasets
13 |
14 | from util import TwoCropTransform, AverageMeter
15 | from util import adjust_learning_rate, warmup_learning_rate
16 | from util import set_optimizer, save_model
17 | from networks.resnet_big import SupConResNet
18 | from losses import SupConLoss
19 |
20 | try:
21 | import apex
22 | from apex import amp, optimizers
23 | except ImportError:
24 |     pass  # apex is optional; it is only required when --syncBN is used
25 |
26 |
27 | def parse_option():
28 | parser = argparse.ArgumentParser('argument for training')
29 |
30 | parser.add_argument('--print_freq', type=int, default=10,
31 | help='print frequency')
32 | parser.add_argument('--save_freq', type=int, default=50,
33 | help='save frequency')
34 | parser.add_argument('--batch_size', type=int, default=256,
35 | help='batch_size')
36 | parser.add_argument('--num_workers', type=int, default=16,
37 | help='num of workers to use')
38 | parser.add_argument('--epochs', type=int, default=1000,
39 | help='number of training epochs')
40 |
41 | # optimization
42 | parser.add_argument('--learning_rate', type=float, default=0.05,
43 | help='learning rate')
44 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
45 | help='where to decay lr, can be a list')
46 | parser.add_argument('--lr_decay_rate', type=float, default=0.1,
47 | help='decay rate for learning rate')
48 | parser.add_argument('--weight_decay', type=float, default=1e-4,
49 | help='weight decay')
50 | parser.add_argument('--momentum', type=float, default=0.9,
51 | help='momentum')
52 |
53 | # model dataset
54 | parser.add_argument('--model', type=str, default='resnet50')
55 | parser.add_argument('--dataset', type=str, default='cifar10',
56 | choices=['cifar10', 'cifar100', 'path'], help='dataset')
57 |     parser.add_argument('--mean', type=str, help='mean of the custom dataset, as a tuple string, e.g. "(0.5, 0.5, 0.5)"')
58 |     parser.add_argument('--std', type=str, help='std of the custom dataset, as a tuple string, e.g. "(0.5, 0.5, 0.5)"')
59 | parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
60 | parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')
61 |
62 | # method
63 | parser.add_argument('--method', type=str, default='SupCon',
64 | choices=['SupCon', 'SimCLR'], help='choose method')
65 |
66 | # temperature
67 | parser.add_argument('--temp', type=float, default=0.07,
68 | help='temperature for loss function')
69 |
70 | # other setting
71 | parser.add_argument('--cosine', action='store_true',
72 | help='using cosine annealing')
73 | parser.add_argument('--syncBN', action='store_true',
74 | help='using synchronized batch normalization')
75 | parser.add_argument('--warm', action='store_true',
76 | help='warm-up for large batch training')
77 | parser.add_argument('--trial', type=str, default='0',
78 | help='id for recording multiple runs')
79 |
80 | opt = parser.parse_args()
81 |
82 |     # a custom 'path' dataset requires data_folder, mean and std to be set
83 | if opt.dataset == 'path':
84 | assert opt.data_folder is not None \
85 | and opt.mean is not None \
86 | and opt.std is not None
87 |
88 | # set the path according to the environment
89 | if opt.data_folder is None:
90 | opt.data_folder = './datasets/'
91 | opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
92 | opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)
93 |
94 | iterations = opt.lr_decay_epochs.split(',')
95 |     opt.lr_decay_epochs = []
96 | for it in iterations:
97 | opt.lr_decay_epochs.append(int(it))
98 |
99 | opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
100 | format(opt.method, opt.dataset, opt.model, opt.learning_rate,
101 | opt.weight_decay, opt.batch_size, opt.temp, opt.trial)
102 |
103 | if opt.cosine:
104 | opt.model_name = '{}_cosine'.format(opt.model_name)
105 |
106 |     # warm-up for large-batch training
107 | if opt.batch_size > 256:
108 | opt.warm = True
109 | if opt.warm:
110 | opt.model_name = '{}_warm'.format(opt.model_name)
111 | opt.warmup_from = 0.01
112 | opt.warm_epochs = 10
113 | if opt.cosine:
114 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
115 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
116 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
117 | else:
118 | opt.warmup_to = opt.learning_rate
119 |
120 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
121 | if not os.path.isdir(opt.tb_folder):
122 | os.makedirs(opt.tb_folder)
123 |
124 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
125 | if not os.path.isdir(opt.save_folder):
126 | os.makedirs(opt.save_folder)
127 |
128 | return opt
129 |
130 |
131 | def set_loader(opt):
132 | # construct data loader
133 | if opt.dataset == 'cifar10':
134 | mean = (0.4914, 0.4822, 0.4465)
135 | std = (0.2023, 0.1994, 0.2010)
136 | elif opt.dataset == 'cifar100':
137 | mean = (0.5071, 0.4867, 0.4408)
138 | std = (0.2675, 0.2565, 0.2761)
139 | elif opt.dataset == 'path':
140 |         mean = eval(opt.mean)  # parses a tuple string, e.g. "(0.5, 0.5, 0.5)"
141 |         std = eval(opt.std)    # assumes trusted CLI input; ast.literal_eval would be safer
142 | else:
143 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
144 | normalize = transforms.Normalize(mean=mean, std=std)
145 |
146 | train_transform = transforms.Compose([
147 | transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
148 | transforms.RandomHorizontalFlip(),
149 | transforms.RandomApply([
150 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
151 | ], p=0.8),
152 | transforms.RandomGrayscale(p=0.2),
153 | transforms.ToTensor(),
154 | normalize,
155 | ])
156 |
157 | if opt.dataset == 'cifar10':
158 | train_dataset = datasets.CIFAR10(root=opt.data_folder,
159 | transform=TwoCropTransform(train_transform),
160 | download=True)
161 | elif opt.dataset == 'cifar100':
162 | train_dataset = datasets.CIFAR100(root=opt.data_folder,
163 | transform=TwoCropTransform(train_transform),
164 | download=True)
165 | elif opt.dataset == 'path':
166 | train_dataset = datasets.ImageFolder(root=opt.data_folder,
167 | transform=TwoCropTransform(train_transform))
168 | else:
169 | raise ValueError(opt.dataset)
170 |
171 | train_sampler = None
172 | train_loader = torch.utils.data.DataLoader(
173 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
174 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
175 |
176 | return train_loader
177 |
178 |
179 | def set_model(opt):
180 | model = SupConResNet(name=opt.model)
181 | criterion = SupConLoss(temperature=opt.temp)
182 |
183 | # enable synchronized Batch Normalization
184 | if opt.syncBN:
185 | model = apex.parallel.convert_syncbn_model(model)
186 |
187 | if torch.cuda.is_available():
188 | if torch.cuda.device_count() > 1:
189 | model.encoder = torch.nn.DataParallel(model.encoder)
190 | model = model.cuda()
191 | criterion = criterion.cuda()
192 | cudnn.benchmark = True
193 |
194 | return model, criterion
195 |
196 |
197 | def train(train_loader, model, criterion, optimizer, epoch, opt):
198 | """one epoch training"""
199 | model.train()
200 |
201 | batch_time = AverageMeter()
202 | data_time = AverageMeter()
203 | losses = AverageMeter()
204 |
205 | end = time.time()
206 | for idx, (images, labels) in enumerate(train_loader):
207 | data_time.update(time.time() - end)
208 |
209 | images = torch.cat([images[0], images[1]], dim=0)
210 | if torch.cuda.is_available():
211 | images = images.cuda(non_blocking=True)
212 | labels = labels.cuda(non_blocking=True)
213 | bsz = labels.shape[0]
214 |
215 | # warm-up learning rate
216 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
217 |
218 | # compute loss
219 | features = model(images)
220 | f1, f2 = torch.split(features, [bsz, bsz], dim=0)
221 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
222 | if opt.method == 'SupCon':
223 | loss = criterion(features, labels)
224 | elif opt.method == 'SimCLR':
225 | loss = criterion(features)
226 | else:
227 | raise ValueError('contrastive method not supported: {}'.
228 | format(opt.method))
229 |
230 | # update metric
231 | losses.update(loss.item(), bsz)
232 |
233 | # SGD
234 | optimizer.zero_grad()
235 | loss.backward()
236 | optimizer.step()
237 |
238 | # measure elapsed time
239 | batch_time.update(time.time() - end)
240 | end = time.time()
241 |
242 | # print info
243 | if (idx + 1) % opt.print_freq == 0:
244 | print('Train: [{0}][{1}/{2}]\t'
245 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
246 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
247 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
248 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
249 | data_time=data_time, loss=losses))
250 | sys.stdout.flush()
251 |
252 | return losses.avg
253 |
254 |
255 | def main():
256 | opt = parse_option()
257 |
258 | # build data loader
259 | train_loader = set_loader(opt)
260 |
261 | # build model and criterion
262 | model, criterion = set_model(opt)
263 |
264 | # build optimizer
265 | optimizer = set_optimizer(opt, model)
266 |
267 | # tensorboard
268 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
269 |
270 | # training routine
271 | for epoch in range(1, opt.epochs + 1):
272 | adjust_learning_rate(opt, optimizer, epoch)
273 |
274 | # train for one epoch
275 | time1 = time.time()
276 | loss = train(train_loader, model, criterion, optimizer, epoch, opt)
277 | time2 = time.time()
278 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
279 |
280 | # tensorboard logger
281 | logger.log_value('loss', loss, epoch)
282 | logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
283 |
284 | if epoch % opt.save_freq == 0:
285 | save_file = os.path.join(
286 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
287 | save_model(model, optimizer, opt, epoch, save_file)
288 |
289 | # save the last model
290 | save_file = os.path.join(
291 | opt.save_folder, 'last.pth')
292 | save_model(model, optimizer, opt, opt.epochs, save_file)
293 |
294 |
295 | if __name__ == '__main__':
296 | main()
297 |
--------------------------------------------------------------------------------
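A note on the warm-up target computed in `parse_option` above: with `--cosine`, the warm-up ramp ends at the value the cosine schedule would have at epoch `warm_epochs`, so the two schedules join without a jump. A minimal stand-alone re-derivation (plain Python, with assumed example hyper-parameters for a large-batch run, not the repo defaults):

```python
import math

# Assumed example values (hypothetical; not the repo defaults)
learning_rate = 0.5
lr_decay_rate = 0.1
warm_epochs = 10
epochs = 1000

# Same arithmetic as parse_option's cosine branch
eta_min = learning_rate * (lr_decay_rate ** 3)
warmup_to = eta_min + (learning_rate - eta_min) * (
    1 + math.cos(math.pi * warm_epochs / epochs)) / 2

# The warm-up target sits between eta_min and the base LR,
# just slightly below learning_rate since warm_epochs << epochs.
print(eta_min, warmup_to)
```

Because `warm_epochs / epochs` is small, `warmup_to` lands just below `learning_rate`, which is why warm-up hands off smoothly to `adjust_learning_rate`.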
/networks/resnet_big.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch.
2 | ImageNet-style ResNet (here with a 3x3, stride-1 stem suited to small inputs such as CIFAR)
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | Adapted from: https://github.com/bearpaw/pytorch-classification
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1, is_last=False):
16 | super(BasicBlock, self).__init__()
17 | self.is_last = is_last
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion * planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
27 | nn.BatchNorm2d(self.expansion * planes)
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(self.conv1(x)))
32 | out = self.bn2(self.conv2(out))
33 | out += self.shortcut(x)
34 | preact = out
35 | out = F.relu(out)
36 | if self.is_last:
37 | return out, preact
38 | else:
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1, is_last=False):
46 | super(Bottleneck, self).__init__()
47 | self.is_last = is_last
48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
54 |
55 | self.shortcut = nn.Sequential()
56 | if stride != 1 or in_planes != self.expansion * planes:
57 | self.shortcut = nn.Sequential(
58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
59 | nn.BatchNorm2d(self.expansion * planes)
60 | )
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(self.conv1(x)))
64 | out = F.relu(self.bn2(self.conv2(out)))
65 | out = self.bn3(self.conv3(out))
66 | out += self.shortcut(x)
67 | preact = out
68 | out = F.relu(out)
69 | if self.is_last:
70 | return out, preact
71 | else:
72 | return out
73 |
74 |
75 | class ResNet(nn.Module):
76 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
77 | super(ResNet, self).__init__()
78 | self.in_planes = 64
79 |
80 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
81 | bias=False)
82 | self.bn1 = nn.BatchNorm2d(64)
83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
84 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
85 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
86 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
87 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
93 | nn.init.constant_(m.weight, 1)
94 | nn.init.constant_(m.bias, 0)
95 |
96 | # Zero-initialize the last BN in each residual branch,
97 | # so that the residual branch starts with zeros, and each residual block behaves
98 | # like an identity. This improves the model by 0.2~0.3% according to:
99 | # https://arxiv.org/abs/1706.02677
100 | if zero_init_residual:
101 | for m in self.modules():
102 | if isinstance(m, Bottleneck):
103 | nn.init.constant_(m.bn3.weight, 0)
104 | elif isinstance(m, BasicBlock):
105 | nn.init.constant_(m.bn2.weight, 0)
106 |
107 | def _make_layer(self, block, planes, num_blocks, stride):
108 | strides = [stride] + [1] * (num_blocks - 1)
109 | layers = []
110 | for i in range(num_blocks):
111 | stride = strides[i]
112 | layers.append(block(self.in_planes, planes, stride))
113 | self.in_planes = planes * block.expansion
114 | return nn.Sequential(*layers)
115 |
116 |     def forward(self, x):
117 | out = F.relu(self.bn1(self.conv1(x)))
118 | out = self.layer1(out)
119 | out = self.layer2(out)
120 | out = self.layer3(out)
121 | out = self.layer4(out)
122 | out = self.avgpool(out)
123 | out = torch.flatten(out, 1)
124 | return out
125 |
126 |
127 | def resnet18(**kwargs):
128 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
129 |
130 |
131 | def resnet34(**kwargs):
132 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
133 |
134 |
135 | def resnet50(**kwargs):
136 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
137 |
138 |
139 | def resnet101(**kwargs):
140 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
141 |
142 |
143 | model_dict = {
144 | 'resnet18': [resnet18, 512],
145 | 'resnet34': [resnet34, 512],
146 | 'resnet50': [resnet50, 2048],
147 | 'resnet101': [resnet101, 2048],
148 | }
149 |
150 |
151 | class LinearBatchNorm(nn.Module):
152 |     """Implements BatchNorm1d via BatchNorm2d, so that SyncBN conversion can be applied"""
153 | def __init__(self, dim, affine=True):
154 | super(LinearBatchNorm, self).__init__()
155 | self.dim = dim
156 | self.bn = nn.BatchNorm2d(dim, affine=affine)
157 |
158 | def forward(self, x):
159 | x = x.view(-1, self.dim, 1, 1)
160 | x = self.bn(x)
161 | x = x.view(-1, self.dim)
162 | return x
163 |
164 |
165 | class SupConResNet(nn.Module):
166 | """backbone + projection head"""
167 | def __init__(self, name='resnet50', head='mlp', feat_dim=128):
168 | super(SupConResNet, self).__init__()
169 | model_fun, dim_in = model_dict[name]
170 | self.encoder = model_fun()
171 | if head == 'linear':
172 | self.head = nn.Linear(dim_in, feat_dim)
173 | elif head == 'mlp':
174 | self.head = nn.Sequential(
175 | nn.Linear(dim_in, dim_in),
176 | nn.ReLU(inplace=True),
177 | nn.Linear(dim_in, feat_dim)
178 | )
179 | else:
180 | raise NotImplementedError(
181 | 'head not supported: {}'.format(head))
182 |
183 | def forward(self, x):
184 | feat = self.encoder(x)
185 | feat = F.normalize(self.head(feat), dim=1)
186 | return feat
187 |
188 |
189 | class SupCEResNet(nn.Module):
190 | """encoder + classifier"""
191 | def __init__(self, name='resnet50', num_classes=10):
192 | super(SupCEResNet, self).__init__()
193 | model_fun, dim_in = model_dict[name]
194 | self.encoder = model_fun()
195 | self.fc = nn.Linear(dim_in, num_classes)
196 |
197 | def forward(self, x):
198 | return self.fc(self.encoder(x))
199 |
200 |
201 | class LinearClassifier(nn.Module):
202 | """Linear classifier"""
203 | def __init__(self, name='resnet50', num_classes=10):
204 | super(LinearClassifier, self).__init__()
205 | _, feat_dim = model_dict[name]
206 | self.fc = nn.Linear(feat_dim, num_classes)
207 |
208 | def forward(self, features):
209 | return self.fc(features)
210 |
--------------------------------------------------------------------------------
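The second element of each `model_dict` entry is the encoder's output width. Those numbers follow directly from the last stage width (512) times the block's channel expansion; a quick sketch of that arithmetic (plain Python, not importing the repo):

```python
# Channel expansion factors from the block classes above:
# BasicBlock.expansion = 1 (resnet18/34), Bottleneck.expansion = 4 (resnet50/101)
basic_expansion = 1
bottleneck_expansion = 4

# layer4 is built with 512 planes, so the flattened feature width is
# 512 * expansion -- exactly the dims hard-coded in model_dict.
feat_dim = {
    'resnet18': 512 * basic_expansion,
    'resnet34': 512 * basic_expansion,
    'resnet50': 512 * bottleneck_expansion,
    'resnet101': 512 * bottleneck_expansion,
}
print(feat_dim)
```

This is the `dim_in` that `SupConResNet` feeds into its projection head and that `LinearClassifier` uses as its input width.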
/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 |
8 |
9 | class TwoCropTransform:
10 | """Create two crops of the same image"""
11 | def __init__(self, transform):
12 | self.transform = transform
13 |
14 | def __call__(self, x):
15 | return [self.transform(x), self.transform(x)]
16 |
17 |
18 | class AverageMeter(object):
19 | """Computes and stores the average and current value"""
20 | def __init__(self):
21 | self.reset()
22 |
23 | def reset(self):
24 | self.val = 0
25 | self.avg = 0
26 | self.sum = 0
27 | self.count = 0
28 |
29 | def update(self, val, n=1):
30 | self.val = val
31 | self.sum += val * n
32 | self.count += n
33 | self.avg = self.sum / self.count
34 |
35 |
36 | def accuracy(output, target, topk=(1,)):
37 | """Computes the accuracy over the k top predictions for the specified values of k"""
38 | with torch.no_grad():
39 | maxk = max(topk)
40 | batch_size = target.size(0)
41 |
42 | _, pred = output.topk(maxk, 1, True, True)
43 | pred = pred.t()
44 | correct = pred.eq(target.view(1, -1).expand_as(pred))
45 |
46 | res = []
47 | for k in topk:
48 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct[:k] may be non-contiguous
49 | res.append(correct_k.mul_(100.0 / batch_size))
50 | return res
51 |
52 |
53 | def adjust_learning_rate(args, optimizer, epoch):
54 | lr = args.learning_rate
55 | if args.cosine:
56 | eta_min = lr * (args.lr_decay_rate ** 3)
57 | lr = eta_min + (lr - eta_min) * (
58 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
59 | else:
60 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
61 | if steps > 0:
62 | lr = lr * (args.lr_decay_rate ** steps)
63 |
64 | for param_group in optimizer.param_groups:
65 | param_group['lr'] = lr
66 |
67 |
68 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
69 | if args.warm and epoch <= args.warm_epochs:
70 | p = (batch_id + (epoch - 1) * total_batches) / \
71 | (args.warm_epochs * total_batches)
72 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
73 |
74 | for param_group in optimizer.param_groups:
75 | param_group['lr'] = lr
76 |
77 |
78 | def set_optimizer(opt, model):
79 | optimizer = optim.SGD(model.parameters(),
80 | lr=opt.learning_rate,
81 | momentum=opt.momentum,
82 | weight_decay=opt.weight_decay)
83 | return optimizer
84 |
85 |
86 | def save_model(model, optimizer, opt, epoch, save_file):
87 | print('==> Saving...')
88 | state = {
89 | 'opt': opt,
90 | 'model': model.state_dict(),
91 | 'optimizer': optimizer.state_dict(),
92 | 'epoch': epoch,
93 | }
94 | torch.save(state, save_file)
95 | del state
96 |
--------------------------------------------------------------------------------
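The two schedule helpers in `util.py` compose as follows: `adjust_learning_rate` applies either cosine annealing or step decay at epoch granularity, and `warmup_learning_rate` then overwrites the LR with a linear ramp during the first `warm_epochs`. A stand-alone sketch of that logic with assumed example hyper-parameters (hypothetical values, not the repo defaults):

```python
# Step-decay branch of adjust_learning_rate: multiply the base LR by
# lr_decay_rate once for every milestone the current epoch has passed.
learning_rate = 0.1
lr_decay_rate = 0.1
lr_decay_epochs = [700, 800, 900]

def step_lr(epoch):
    steps = sum(epoch > m for m in lr_decay_epochs)
    return learning_rate * (lr_decay_rate ** steps)

# warmup_learning_rate: linear interpolation from warmup_from to
# warmup_to over the first warm_epochs, at per-batch granularity.
warmup_from, warmup_to = 0.01, 0.1
warm_epochs, total_batches = 10, 100

def warmup_lr(epoch, batch_id):
    p = (batch_id + (epoch - 1) * total_batches) / (warm_epochs * total_batches)
    return warmup_from + p * (warmup_to - warmup_from)

print(step_lr(750), step_lr(901))    # one decay step, then three
print(warmup_lr(1, 0), warmup_lr(10, 100))  # ramp start and end
```

Note the ramp ends exactly at `warmup_to` (the last warm-up batch reaches `p = 1`), which is why `parse_option` computes `warmup_to` to match the main schedule's value at that epoch.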