├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── data.py
│   └── utils.py
├── figures
│   ├── ex.jpeg
│   ├── genscl.png
│   └── results.png
├── genscl.py
├── linear.py
├── loss.py
├── mix.py
├── parser.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jaewon Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Generalized Supervised Contrastive Learning Framework
2 |
3 |
4 |
5 | Official PyTorch implementation of GenSCL | [Paper](https://arxiv.org/abs/2206.00384)
6 |
7 | Jaewon Kim and Jooyoung Chang and Sang Min Park
8 |
9 | [HSDS](http://snuhsds.com/)@Seoul National University
10 |
11 | Our implementation is based on the [Supervised Contrastive Learning](https://github.com/HobbitLong/SupContrast) repository.
12 |
13 | ### Abstract
14 |
15 | Building on the remarkable recent achievements of contrastive learning in self-supervised representation learning, supervised contrastive learning (SupCon) has successfully extended batch contrastive approaches to the supervised setting and outperformed cross-entropy on various datasets with ResNet backbones. In this work, we present *GenSCL*: a generalized supervised contrastive learning framework that, through our *generalized supervised contrastive loss*, seamlessly adapts modern image-based regularizations (such as Mixup-Cutmix) and knowledge distillation (KD) to SupCon. The generalized supervised contrastive loss further extends the supervised contrastive loss by measuring the cross-entropy between the similarity of labels and that of latent features. A model can thus learn to what extent each contrastive sample should be pulled toward an anchor in the latent space. By explicitly and fully leveraging label information, GenSCL breaks the boundary between conventional positives and negatives, and any kind of pre-trained teacher classifier can be utilized. ResNet-50 trained in GenSCL with Mixup-Cutmix and KD achieves state-of-the-art accuracies of 97.6% and 84.7% on CIFAR10 and CIFAR100 without external data, improving on the results reported in the original SupCon paper by 1.6 and 8.2 percentage points, respectively. A PyTorch implementation is available at https://t.ly/yuUO.
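Reading the objective off `GenSupConLoss` in `loss.py` (default `contrast_mode='all'`, where the two augmented views of a batch of $N$ images give $2N$ anchors), it can be written roughly as follows. This is a transcription of the code rather than the paper's notation: $\ell_i$ is the label vector of view $i$, $z_i$ its feature vector, $\tau$ the temperature and $\tau_b$ the base temperature; the code additionally adds $10^{-8}$ to the denominator of the inner fraction.

$$
s_{ij} = \frac{\ell_i^\top \ell_j}{\lVert \ell_i \rVert_2 \, \lVert \ell_j \rVert_2}, \qquad
\log p_{ij} = \frac{z_i^\top z_j}{\tau} - \log \sum_{k \neq i} \exp\left(\frac{z_i^\top z_k}{\tau}\right)
$$

$$
\mathcal{L} = -\frac{\tau}{\tau_b} \cdot \frac{1}{2N} \sum_{i=1}^{2N} \frac{\sum_{j \neq i} s_{ij} \log p_{ij}}{\sum_{j \neq i} s_{ij}}
$$

Normalizing $s_{ij}$ over $j$ gives a target distribution, so each anchor's term is a cross-entropy between label similarities and feature similarities. With one-hot labels, $s_{ij} \in \{0, 1\}$ simply marks same-class pairs and the expression becomes the SupCon loss.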
16 |
17 | ### Overview of the results
18 |
19 |
20 |
21 | ## Loss Function
22 |
23 | Our proposed *Generalized Supervised Contrastive Loss* (`GenSupConLoss` in `loss.py`) takes a pair of feature tensors (`features`, one `[N, feat_dim]` tensor per view) and a pair of label tensors (`labels`, one `[N, num_classes]` tensor per view) as input and returns the loss. When `labels` are one-hot encoded, it reduces to the Supervised Contrastive Loss.
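A minimal NumPy sketch (an illustration, not the repository's PyTorch code) of the label-similarity step in `GenSupConLoss.forward` with `contrast_mode='all'`. It shows why one-hot labels recover SupCon: the cosine-similarity mask becomes exactly the 0/1 same-class mask that SupCon uses to select positives. The helper name `label_similarity` is hypothetical.

```python
import numpy as np


def label_similarity(labels_q, labels_k):
    """Cosine similarity between the label vectors of all 2N views, as in
    step 1 of GenSupConLoss.forward with contrast_mode='all'. Each view's
    similarity to itself is zeroed, matching logits_mask in loss.py."""
    labels = np.concatenate([labels_q, labels_k], axis=0).astype(np.float64)
    norms = np.linalg.norm(labels, axis=1, keepdims=True)  # [2N, 1]
    sim = (labels @ labels.T) / (norms @ norms.T)          # [2N, 2N]
    np.fill_diagonal(sim, 0.0)                             # drop self-pairs
    return sim


# Three images of classes 0, 1, 0 with one-hot labels; both views share labels.
one_hot = np.eye(3)[[0, 1, 0]]
mask = label_similarity(one_hot, one_hot)
print(mask.astype(int))  # 1 exactly where two different views share a class
```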
24 |
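25 | With the *Generalized Supervised Contrastive Loss*, Mixup/Cutmix and knowledge distillation can be adapted seamlessly to supervised contrastive learning: a mixed image is simply given a mixed (soft) label vector, and a teacher's softmax predictions can be used as label vectors.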
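To illustrate what a soft label changes, here is a small NumPy sketch (helper names are hypothetical) that builds a mixed label the way `mix.mix_target` does, `lam * one_hot(y_a) + (1 - lam) * one_hot(y_b)`, and computes the cosine similarity that `GenSupConLoss` would use as the pull weight between two samples. A 70/30 mix of classes 0 and 1 gets a weight of about 0.92 toward a clean class-0 sample, about 0.39 toward a clean class-1 sample, and 0 toward a clean class-2 sample. The teacher vector is a made-up softmax output, used only to show that teacher predictions plug in the same way.

```python
import numpy as np


def mixed_label(y_a, y_b, lam, num_classes):
    """Soft label of a mixed image, following mix.mix_target:
    lam * one_hot(y_a) + (1 - lam) * one_hot(y_b)."""
    eye = np.eye(num_classes)
    return lam * eye[y_a] + (1.0 - lam) * eye[y_b]


def cosine(u, v):
    """Cosine similarity between two label vectors (the weight in the loss)."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


# An image mixed 70/30 from classes 0 and 1 (3 classes in total).
mixed = mixed_label(0, 1, 0.7, num_classes=3)
clean = np.eye(3)                     # one-hot labels of un-mixed images
teacher = np.array([0.6, 0.3, 0.1])   # hypothetical teacher softmax output

print(cosine(mixed, clean[0]))  # strong partial positive
print(cosine(mixed, clean[1]))  # weak partial positive
print(cosine(mixed, clean[2]))  # pure negative
print(cosine(mixed, teacher))   # teacher predictions are weighted the same way
```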
26 |
27 | 
28 |
29 | ## Running
30 |
31 | To apply knowledge distillation, a pretrained teacher model (EfficientNetV2-M) is required; it is released [here](https://www.dropbox.com/sh/io8u9mv8hh3bt4m/AACjNFDZIgPADoyU14OEqVQSa?dl=0).
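The framework-free sketch below (NumPy instead of PyTorch; function names are hypothetical) mirrors how `train()` in `genscl.py` combines the two terms. The teacher's temperature-scaled softmax (temperature `args.KD_temp`) serves as a second set of labels weighted by `--KD-alpha`, and `KD_alpha == inf` drops the ground-truth term entirely. Whether `--KD-alpha inf` parses to `float('inf')` depends on the option being declared with `type=float`, which is an assumption since `parser.py` is not shown in full.

```python
import numpy as np


def teacher_soft_labels(teacher_logits, kd_temp):
    """Temperature-scaled softmax over classes, as genscl.train() applies
    to the teacher output: softmax(teacher(images) / KD_temp, dim=1)."""
    z = np.asarray(teacher_logits, dtype=np.float64) / kd_temp
    z -= z.max(axis=1, keepdims=True)  # stability; softmax is shift-invariant
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def genscl_objective(criterion, features, labels, teacher_labels=None, kd_alpha=1.0):
    """Loss selection sketched after genscl.train()."""
    if teacher_labels is None:        # no KD
        return criterion(features, labels)
    if kd_alpha == float('inf'):      # learn only from the teacher's prediction
        return criterion(features, teacher_labels)
    return criterion(features, labels) + kd_alpha * criterion(features, teacher_labels)


def stub_criterion(features, targets):
    """Hypothetical stand-in: the 'loss' is just its target, so the weighting is visible."""
    return targets


print(genscl_objective(stub_criterion, None, 1.0))                                              # 1.0
print(genscl_objective(stub_criterion, None, 1.0, teacher_labels=0.25, kd_alpha=2.0))           # 1.5
print(genscl_objective(stub_criterion, None, 1.0, teacher_labels=0.25, kd_alpha=float('inf')))  # 0.25
print(teacher_soft_labels([[2.0, 0.0, -1.0]], kd_temp=4.0))  # flatter than kd_temp=1.0
```

In the real training loop, `criterion` is `GenSupConLoss`, `features` is the pair of projected views, and both label sets are `[N, num_classes]` tensors.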
32 |
33 | * CIFAR10
34 |
35 | * Pretraining stage:
36 |
37 | ```bash
38 | python genscl.py \
39 | --dataset cifar10 \
40 | --mix mixup_cutmix \
41 | --KD \
42 | --KD-alpha 1 \
43 | --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar10_Adam
44 | ```
45 |
46 | * Linear evaluation stage:
47 |
48 | ```bash
49 | python linear.py \
50 | --dataset cifar10 \
51 | --pretrained cifar10_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \
52 | --augment-policy no \
53 | --amp
54 | ```
55 |
56 | * CIFAR100
57 |
58 | * Pretraining stage:
59 |
60 | ```bash
61 | python genscl.py \
62 | --dataset cifar100 \
63 | --mix mixup_cutmix \
64 | --KD \
65 | --KD-alpha 1 \
66 | --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar100_Adam
67 | ```
68 |
69 | * Linear evaluation stage:
70 |
71 | ```bash
72 | python linear.py \
73 | --dataset cifar100 \
74 | --pretrained cifar100_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \
75 | --augment-policy no \
76 | --amp
77 | ```
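The value passed to `--pretrained` is the run directory written by the pretraining stage; `linear.py` loads `<save_root>/<pretrained>/<pretrained_ckpt>`. When `--desc` is not given, `genscl.py` derives the directory name in `default_desc()`. The sketch below re-expresses that naming rule as a standalone function (`run_name` is a hypothetical name). The batch size 1024, SGD learning rate 0.5 and `mix_alpha` 1.0 are read off the run names in the commands above, since the parser defaults are not shown here.

```python
def run_name(dataset, batch_size, optim_kind, learning_rate,
             cutout=False, mix=None, mix_alpha=1.0, KD=False, KD_alpha=1.0):
    """Naming rule sketched after default_desc() in genscl.py."""
    parts = [dataset, f'bsz_{batch_size}']
    if cutout:
        parts.append('cutout')
    if mix:
        parts.append(f'{mix}_{mix_alpha}')
    if KD:
        parts.append(f'KD_{KD_alpha}')
    parts.append(f'{optim_kind}_lr_{learning_rate}')
    return '_'.join(parts)


name = run_name('cifar10', 1024, 'SGD', 0.5,
                mix='mixup_cutmix', mix_alpha=1.0, KD=True, KD_alpha=1.0)
print(name)  # cifar10_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5
```

Because the name embeds float hyperparameters, passing `--KD-alpha 1` still yields `KD_1.0` only if the option is parsed as a float.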
78 |
79 |
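80 | Several extra options are available:
81 |
82 | * `--optim-kind`: SGD, RMSProp, Adam, AdamW
83 |
84 | * `--augment-policy`: no, sim, auto, rand (`no` is only handled by `linear.py`; the contrastive loader used by `genscl.py` raises `NotImplementedError` for it)
85 | * `--wandb`: enable [wandb](https://wandb.ai/) logging for visualization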
86 |
87 | ## Updates
88 |
89 | * 23 Jun, 2022: Initial upload
90 |
91 | ## Citation
92 |
93 | ```
94 | @article{kim2022generalized,
95 | title={A Generalized Supervised Contrastive Learning Framework},
96 | author={Kim, Jaewon and Chang, Jooyoung and Park, Sang Min},
97 | journal={arXiv preprint arXiv:2206.00384},
98 | year={2022}
99 | }
100 | ```
101 |
102 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import *
--------------------------------------------------------------------------------
/data/data.py:
--------------------------------------------------------------------------------
1 | from .utils import TwoCropTransform
2 | from .utils import Cutout
3 |
4 | import torch
5 | import torchvision.transforms as transforms
6 | import torchvision.datasets as datasets
7 |
8 | MEAN = {
9 | 'cifar10': [0.4914, 0.4822, 0.4465],
10 | 'cifar100': [0.5071, 0.4867, 0.4408],
11 | 'imagenet': [0.485, 0.456, 0.406]
12 | }
13 | STD = {
14 | 'cifar10': [0.2023, 0.1994, 0.2010],
15 | 'cifar100':[0.2675, 0.2565, 0.2761],
16 | 'imagenet': [0.229, 0.224, 0.225]
17 | }
18 | SIZE = {
19 | 'cifar10': 32,
20 | 'cifar100': 32,
21 | 'imagenet': 224,
22 | }
23 |
24 | NUM_CLASSES = {
25 | 'cifar10': 10,
26 | 'cifar100': 100,
27 | 'imagenet': 1000
28 | }
29 |
30 |
31 | def contrastive_loader(args):
32 | # transformation
33 | normalize = transforms.Normalize(mean=MEAN[args.dataset], std=STD[args.dataset])
34 |
35 | if args.augment_policy == 'sim': # simclr augment
36 | train_transform = transforms.Compose([
37 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.RandomApply([
40 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
41 | ], p=0.8),
42 | transforms.RandomGrayscale(p=0.2),
43 | transforms.ToTensor(),
44 | normalize,
45 | ])
46 | elif args.augment_policy == 'auto': # auto augment
47 | train_transform = transforms.Compose([
48 | transforms.RandomResizedCrop(size=SIZE[args.dataset]),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.AutoAugment(),
51 | transforms.ToTensor(),
52 | normalize,
53 | ])
54 | elif args.augment_policy == 'rand':
55 | train_transform = transforms.Compose([
56 | transforms.RandomResizedCrop(size=SIZE[args.dataset]),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.RandAugment(args.rand_n, args.rand_m),
59 | transforms.ToTensor(),
60 | normalize,
61 | ])
62 | else:
63 | raise NotImplementedError(f'Unknown {args.augment_policy}!')
64 |
65 | if args.cutout:
66 | train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.cutout_length))
67 |
68 | # dataset
69 | if args.dataset == 'cifar10':
70 | train_dataset = datasets.CIFAR10(root='/home/DB/',
71 | transform=TwoCropTransform(train_transform),
72 | download=True)
73 | elif args.dataset == 'cifar100':
74 | train_dataset = datasets.CIFAR100(root='/home/DB/',
75 | transform=TwoCropTransform(train_transform),
76 | download=True)
77 | elif args.dataset == 'imagenet':
78 | train_dataset = datasets.ImageFolder('/home/DB/IMAGENET/train',
79 | TwoCropTransform(train_transform)
80 | )
81 | else:
82 | raise ValueError(args.dataset)
83 |
84 | if args.multiprocessing_distributed:
85 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
86 | else:
87 | train_sampler = None
88 |
89 | train_loader = torch.utils.data.DataLoader(
90 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
91 | num_workers=args.num_workers, pin_memory=True, sampler=train_sampler, drop_last=True)
92 |
93 | return train_loader, train_sampler
94 |
95 |
96 | def normal_loader(args):
97 | normalize = transforms.Normalize(mean=MEAN[args.dataset], std=STD[args.dataset])
98 | if args.dataset == 'cifar10' or args.dataset == 'cifar100':
99 | resize = transforms.RandomCrop(SIZE[args.dataset], padding=4)
100 | else:
101 | resize = transforms.RandomResizedCrop(size=SIZE[args.dataset])
102 | # train dataset
103 | if args.augment_policy == 'no':
104 | train_transform = transforms.Compose([
105 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)),
106 | transforms.RandomHorizontalFlip(),
107 | transforms.ToTensor(),
108 | normalize,
109 | ])
110 | elif args.augment_policy == 'sim': # simclr augment
111 | train_transform = transforms.Compose([
112 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)),
113 | transforms.RandomHorizontalFlip(),
114 | transforms.RandomApply([
115 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
116 | ], p=0.8),
117 | transforms.RandomGrayscale(p=0.2),
118 | transforms.ToTensor(),
119 | normalize,
120 | ])
121 | elif args.augment_policy == 'auto':
122 | train_transform = transforms.Compose([
123 | resize,
124 | transforms.RandomHorizontalFlip(),
125 | transforms.AutoAugment(),
126 | transforms.ToTensor(),
127 | normalize,
128 | ])
129 | elif args.augment_policy == 'rand':
130 | train_transform = transforms.Compose([
131 | resize,
132 | transforms.RandomHorizontalFlip(),
133 | transforms.RandAugment(args.rand_n, args.rand_m),
134 | transforms.ToTensor(),
135 | normalize,
136 | ])
137 | else:
138 |         raise NotImplementedError(f'Unknown {args.augment_policy}!')
139 |
140 | if args.erasing:
141 | train_transform.transforms.append(transforms.RandomErasing(p=args.erasing_p))
142 | if args.cutout:
143 | train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.cutout_length))
144 |
145 |     # validation dataset
146 | if args.dataset == 'imagenet':
147 | val_transform = transforms.Compose([
148 | transforms.Resize(256),
149 | transforms.CenterCrop(224),
150 | transforms.ToTensor(),
151 | normalize,
152 | ])
153 | else:
154 | val_transform = transforms.Compose([
155 | transforms.ToTensor(),
156 | normalize,
157 | ])
158 |
159 | if args.dataset == 'cifar10':
160 | train_dataset = datasets.CIFAR10(root='/home/DB/',
161 | transform=train_transform,
162 | download=True)
163 | val_set = datasets.CIFAR10(root='/home/DB/',
164 | transform=val_transform,
165 | train=False,
166 | download=True)
167 | elif args.dataset == 'cifar100':
168 | train_dataset = datasets.CIFAR100(root='/home/DB/',
169 | transform=train_transform,
170 | download=True)
171 | val_set = datasets.CIFAR100(root='/home/DB/',
172 | transform=val_transform,
173 | train=False,
174 | download=True)
175 | elif args.dataset == 'imagenet':
176 | train_dataset = datasets.ImageFolder('/home/DB/IMAGENET/train', train_transform)
177 | val_set = datasets.ImageFolder('/home/DB/IMAGENET/val', val_transform)
178 | else:
179 | raise ValueError(args.dataset)
180 |
181 | if args.multiprocessing_distributed:
182 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
183 | else:
184 | train_sampler = None
185 |
186 | train_loader = torch.utils.data.DataLoader(
187 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
188 | num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
189 |
190 | val_loader = torch.utils.data.DataLoader(
191 | val_set, batch_size=args.batch_size, shuffle=False,
192 | num_workers=args.num_workers, pin_memory=True
193 | )
194 |
195 | return train_loader, val_loader, train_sampler
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class Cutout(object):
6 | """Randomly mask out one or more patches from an image.
7 | Args:
8 | n_holes (int): Number of patches to cut out of each image.
9 | length (int): The length (in pixels) of each square patch.
10 | """
11 | def __init__(self, n_holes, length):
12 | self.n_holes = n_holes
13 | self.length = length
14 |
15 | def __call__(self, img):
16 | """
17 | Args:
18 | img (Tensor): Tensor image of size (C, H, W).
19 | Returns:
20 | Tensor: Image with n_holes of dimension length x length cut out of it.
21 | """
22 | h = img.size(1)
23 | w = img.size(2)
24 |
25 | mask = np.ones((h, w), np.float32)
26 |
27 | for n in range(self.n_holes):
28 | y = np.random.randint(h)
29 | x = np.random.randint(w)
30 |
31 | y1 = np.clip(y - self.length // 2, 0, h)
32 | y2 = np.clip(y + self.length // 2, 0, h)
33 | x1 = np.clip(x - self.length // 2, 0, w)
34 | x2 = np.clip(x + self.length // 2, 0, w)
35 |
36 | mask[y1: y2, x1: x2] = 0.
37 |
38 | mask = torch.from_numpy(mask)
39 | mask = mask.expand_as(img)
40 | img = img * mask
41 |
42 | return img
43 |
44 | class TwoCropTransform:
45 | """Create two crops of the same image"""
46 | def __init__(self, transform):
47 | self.transform = transform
48 |
49 | def __call__(self, x):
50 | return [self.transform(x), self.transform(x)]
--------------------------------------------------------------------------------
/figures/ex.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/ex.jpeg
--------------------------------------------------------------------------------
/figures/genscl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/genscl.png
--------------------------------------------------------------------------------
/figures/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/results.png
--------------------------------------------------------------------------------
/genscl.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from pathlib import Path
4 |
5 | import timm
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 | from torch.cuda.amp import autocast
10 | try:
11 | import wandb
12 | except ImportError:
13 | pass
14 |
15 | from utils import AverageMeter
16 | from utils import adjust_learning_rate, warmup_learning_rate, get_learning_rate
17 | from utils import set_optimizer, save_model
18 | from utils import seed, format_time
19 | from utils import init_wandb
20 |
21 | from networks.resnet_big import SupConResNet
22 | from loss import GenSupConLoss
23 | from mix import mix_fn, mix_target
24 | from data import contrastive_loader, NUM_CLASSES
25 | from parser import genscl_parser
26 |
27 |
28 |
29 | # load encoder (student)
30 | def set_model(args):
31 | model = SupConResNet(name=args.model)
32 |
33 | if args.KD: # load teacher model
34 | teacher = timm.create_model(args.teacher_kind, pretrained=False)
35 | teacher.reset_classifier(NUM_CLASSES[args.dataset])
36 | out_ch = teacher.conv_stem.out_channels
37 | teacher.conv_stem = torch.nn.Conv2d(3, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
38 | teacher_ckpt = torch.load(Path(args.teacher_path)/args.teacher_ckpt, map_location='cpu')
39 | teacher.load_state_dict(teacher_ckpt['state_dict'])
40 | teacher.eval()
41 | else:
42 | teacher = None
43 |
44 | criterion = GenSupConLoss(temperature=args.temp)
45 |
46 | if torch.cuda.is_available():
47 | if torch.cuda.device_count() > 1:
48 | model.encoder = torch.nn.DataParallel(model.encoder)
49 | if teacher: teacher = torch.nn.DataParallel(teacher)
50 |         if args.resume is not None: # resume from previous ckpt
51 | load_fn = Path(args.save_root)/args.desc/f'ckpt_{args.resume}.pth'
52 | ckpt = torch.load(load_fn, map_location='cpu')
53 | model.load_state_dict(ckpt['state_dict'])
54 | print(f'=> Successfully loading {load_fn}!')
55 | args.start_epoch = ckpt['epoch'] + 1
56 | else:
57 | args.start_epoch = 1
58 |
59 | model = model.cuda()
60 | if teacher: teacher = teacher.cuda()
61 | criterion = criterion.cuda()
62 | cudnn.benchmark = True
63 |
64 | return model, teacher, criterion
65 |
66 |
67 | def train(loader, model, teacher, criterion, optimizer, epoch, args):
68 | model.train()
69 |
70 | batch_time = AverageMeter()
71 | data_time = AverageMeter()
72 | losses = AverageMeter()
73 |
74 | end = time.time()
75 |
76 | for idx, (images, targets) in enumerate(loader):
77 | data_time.update(time.time() - end)
78 | warmup_learning_rate(args, epoch, idx, len(loader), optimizer)
79 |
80 | bsz = targets.shape[0]
81 | im_q, im_k = images
82 | if torch.cuda.is_available():
83 | im_q = im_q.cuda(non_blocking=True)
84 | im_k = im_k.cuda(non_blocking=True)
85 | targets = targets.cuda(non_blocking=True)
86 |
87 | if args.mix: # image-based regularizations
88 | im_q, y0a, y0b, lam0 = mix_fn(im_q, targets, args.mix_alpha, args.mix)
89 | im_k, y1a, y1b, lam1 = mix_fn(im_k, targets, args.mix_alpha, args.mix)
90 | images = torch.cat([im_q, im_k], dim=0)
91 | l_q = mix_target(y0a, y0b, lam0, NUM_CLASSES[args.dataset])
92 | l_k = mix_target(y1a, y1b, lam1, NUM_CLASSES[args.dataset])
93 | else:
94 | images = torch.cat([im_q, im_k], dim=0)
95 | l_q = F.one_hot(targets, NUM_CLASSES[args.dataset])
96 | l_k = l_q
97 |
98 | if teacher: # KD
99 | with torch.no_grad():
100 | with autocast():
101 | preds = F.softmax(teacher(images) / args.KD_temp, dim=1)
102 | teacher_q, teacher_k = torch.split(preds, [bsz, bsz], dim=0)
103 |
104 | # forward
105 | features = model(images)
106 | features = torch.split(features, [bsz, bsz], dim=0)
107 |
108 | if teacher:
109 | if args.KD_alpha == float('inf'): # only learn from teacher's prediction
110 | loss = criterion(features, [teacher_q, teacher_k])
111 | else:
112 | loss = criterion(features, [l_q, l_k]) + args.KD_alpha * criterion(features, [teacher_q, teacher_k])
113 | else: # no KD
114 | loss = criterion(features, [l_q, l_k])
115 |
116 |
117 | losses.update(loss.item(), bsz)
118 |         # backward
119 | optimizer.zero_grad()
120 | loss.backward()
121 | optimizer.step()
122 |
123 | # measure elapsed time
124 | batch_time.update(time.time() - end)
125 | end = time.time()
126 |
127 | # print info
128 | if (idx + 1) % args.print_freq == 0:
129 | print('Train: [{0}][{1}/{2}]\t'
130 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
131 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
132 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
133 | epoch, idx + 1, len(loader), batch_time=batch_time,
134 | data_time=data_time, loss=losses))
135 | sys.stdout.flush()
136 |
137 | res = {
138 | 'trn_loss': losses.avg,
139 | 'learning_rate': get_learning_rate(optimizer)
140 | }
141 | return res
142 |
143 |
144 | def default_desc(args):
145 | if args.desc is None:
146 | desc = args.dataset + '_'
147 | desc += f'bsz_{args.batch_size}_'
148 | if args.cutout:
149 | desc += 'cutout_'
150 | if args.mix:
151 | desc += f'{args.mix}_{args.mix_alpha}_'
152 | if args.KD:
153 | desc += f'KD_{args.KD_alpha}_'
154 | desc += f'{args.optim_kind}_lr_{args.learning_rate}'
155 | args.desc = desc
156 | return args
157 |
158 |
159 | def main():
160 | parser = genscl_parser()
161 | args = parser.parse_args()
162 | args = default_desc(args)
163 | seed(args.seed)
164 |
165 | if args.debug:
166 | args.epochs = 1
167 | elif args.wandb:
168 | init_wandb(args)
169 | save_dir = Path(args.save_root)/args.desc
170 |
171 | # build data loader
172 | train_loader, _ = contrastive_loader(args)
173 |
174 | # build model and criterion
175 | model, teacher, criterion = set_model(args)
176 |
177 | # build optimizer
178 | optimizer = set_optimizer(model, args)
179 |
180 | # train
181 | for epoch in range(args.start_epoch, args.epochs + 1):
182 | adjust_learning_rate(args, optimizer, epoch)
183 |
184 | # train for one epoch
185 | time1 = time.time()
186 | res = train(train_loader, model, teacher, criterion, optimizer, epoch, args)
187 | time2 = time.time()
188 | print(f'epoch {epoch}, total time {format_time(time2 - time1)}')
189 | if not args.debug and args.wandb:
190 | wandb.log(res, step=epoch)
191 |
192 | if (epoch % args.save_freq == 0) and not args.debug:
193 | save_fn = save_dir/f'ckpt_{epoch}.pth'
194 | save_model(model, optimizer, args, epoch, save_fn)
195 |
196 | if not args.debug:
197 | save_fn = save_dir/f'ckpt_last.pth'
198 | save_model(model, optimizer, args, args.epochs, save_fn)
199 | if args.wandb:
200 | wandb.finish()
201 |
202 | if __name__ == '__main__':
203 | main()
--------------------------------------------------------------------------------
/linear.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from pathlib import Path
4 | from contextlib import suppress
5 |
6 | import torch
7 | from torch.cuda.amp import autocast, GradScaler
8 | import torch.backends.cudnn as cudnn
9 | try:
10 | import wandb
11 | except ImportError:
12 | pass
13 |
14 | from utils import AverageMeter
15 | from utils import adjust_learning_rate, warmup_learning_rate, accuracy, get_learning_rate
16 | from utils import set_optimizer, save_model
17 | from utils import init_wandb
18 | from utils import format_time
19 | from utils import seed
20 |
21 | from networks.resnet_big import SupConResNet, LinearClassifier
22 | from data import normal_loader, NUM_CLASSES
23 | from mix import mix_fn
24 | from parser import linear_parser
25 |
26 |
27 | # load trained encoder and build a classifier to train
28 | def set_model(args):
29 | model = SupConResNet(name=args.model)
30 | classifier = LinearClassifier(name=args.model, num_classes=NUM_CLASSES[args.dataset])
31 |
32 | criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
33 |
34 | load_fn = Path(args.save_root)/args.pretrained/args.pretrained_ckpt
35 | ckpt = torch.load(load_fn, map_location='cpu')
36 | state_dict = ckpt['state_dict']
37 |
38 | if torch.cuda.is_available():
39 | if torch.cuda.device_count() > 1:
40 | model.encoder = torch.nn.DataParallel(model.encoder)
41 | else:
42 | new_state_dict = {}
43 | for k, v in state_dict.items():
44 | k = k.replace("module.", "")
45 | new_state_dict[k] = v
46 | state_dict = new_state_dict
47 | model = model.cuda()
48 | classifier = classifier.cuda()
49 | criterion = criterion.cuda()
50 | cudnn.benchmark = True
51 |
52 | model.load_state_dict(state_dict)
53 |
54 | return model, classifier, criterion
55 |
56 |
57 | def train(loader, model, classifier, criterion, optimizer, epoch, amp_autocast, scaler, args):
58 | """one epoch training"""
59 | model.eval()
60 | classifier.train()
61 |
62 | batch_time = AverageMeter()
63 | data_time = AverageMeter()
64 | losses = AverageMeter()
65 | top1 = AverageMeter()
66 |
67 | end = time.time()
68 | for idx, (images, targets) in enumerate(loader):
69 | data_time.update(time.time() - end)
70 |
71 | images = images.cuda(non_blocking=True)
72 | targets = targets.cuda(non_blocking=True)
73 | bsz = targets.shape[0]
74 |
75 | # warm-up learning rate
76 | warmup_learning_rate(args, epoch, idx, len(loader), optimizer)
77 |
78 | # compute loss
79 | with amp_autocast():
80 | with torch.no_grad():
81 | features = model.encoder(images)
82 | output = classifier(features.detach())
83 | loss = criterion(output, targets)
84 |
85 | # update metric
86 | losses.update(loss.item(), bsz)
87 | acc1 = accuracy(output, targets)
88 | top1.update(acc1, bsz)
89 |
90 | optimizer.zero_grad()
91 | if scaler:
92 | scaler.scale(loss).backward()
93 | scaler.step(optimizer)
94 | scaler.update()
95 | else:
96 | loss.backward()
97 | optimizer.step()
98 |
99 | # measure elapsed time
100 | batch_time.update(time.time() - end)
101 | end = time.time()
102 |
103 | # print info
104 | if (idx + 1) % args.print_freq == 0:
105 | print('Train: [{0}][{1}/{2}]\t'
106 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
107 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
108 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
109 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
110 | epoch, idx + 1, len(loader), batch_time=batch_time,
111 | data_time=data_time, loss=losses, top1=top1))
112 | sys.stdout.flush()
113 |
114 | res = {
115 | 'trn_loss': losses.avg,
116 | 'trn_top1_acc': top1.avg,
117 | 'learning_rate': get_learning_rate(optimizer)
118 | }
119 | return res
120 |
121 | def validate(loader, model, classifier, criterion, amp_autocast, args):
122 | """validation"""
123 | model.eval()
124 | classifier.eval()
125 |
126 | batch_time = AverageMeter()
127 | losses = AverageMeter()
128 | top1 = AverageMeter()
129 |
130 | with torch.no_grad():
131 | end = time.time()
132 | for idx, (images, targets) in enumerate(loader):
133 | images = images.float().cuda()
134 | targets = targets.cuda()
135 | bsz = targets.shape[0]
136 |
137 | # forward
138 | with amp_autocast():
139 | output = classifier(model.encoder(images))
140 | loss = criterion(output, targets)
141 |
142 | # update metric
143 | losses.update(loss.item(), bsz)
144 | acc1 = accuracy(output, targets)
145 | top1.update(acc1, bsz)
146 |
147 | # measure elapsed time
148 | batch_time.update(time.time() - end)
149 | end = time.time()
150 |
151 | if idx % args.print_freq == 0:
152 | print('Test: [{0}/{1}]\t'
153 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
154 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
155 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
156 | idx, len(loader), batch_time=batch_time,
157 | loss=losses, top1=top1))
158 |
159 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
160 | res = {
161 | 'val_loss': losses.avg,
162 | 'val_top1_acc': top1.avg
163 | }
164 |
165 | return res
166 |
167 |
168 | def default_desc(args):
169 | if args.desc is None:
170 | desc = args.pretrained + '_linear'
171 | args.desc = desc
172 | return args
173 |
174 |
175 | def main():
176 | best_acc = 0.
177 | parser = linear_parser()
178 | args = parser.parse_args()
179 | args = default_desc(args)
180 | seed(args.seed)
181 |
182 | if args.debug:
183 | args.epochs = 1
184 | elif args.wandb:
185 | init_wandb(args)
186 |
187 | save_dir = Path(args.save_root)/args.desc
188 | amp_autocast = autocast if args.amp else suppress
189 | scaler = GradScaler() if args.amp else None
190 | # build data loader
191 | train_loader, val_loader, _ = normal_loader(args)
192 |
193 | # build model and criterion
194 | model, classifier, criterion = set_model(args)
195 |
196 | # build optimizer
197 | optimizer = set_optimizer(classifier, args)
198 |
199 | # training routine
200 | for epoch in range(1, args.epochs + 1):
201 | adjust_learning_rate(args, optimizer, epoch)
202 |
203 | # train for one epoch
204 | time1 = time.time()
205 | res = train(train_loader, model, classifier, criterion,
206 | optimizer, epoch, amp_autocast, scaler, args)
207 | time2 = time.time()
208 | print('Train epoch {}, total time {}, accuracy:{:.2f}'.format(
209 | epoch, format_time(time2 - time1), res['trn_top1_acc']))
210 |
211 | save_model(classifier, optimizer, args, epoch, save_dir/'ckpt_last.pth')
212 |
213 | # eval for one epoch
214 | val_res = validate(val_loader, model, classifier, criterion, amp_autocast, args)
215 | val_acc = val_res['val_top1_acc']
216 | res.update(val_res)
217 | if not(args.debug) and args.wandb:
218 | wandb.log(res, step=epoch)
219 |
220 | if val_acc > best_acc:
221 | best_acc = val_acc
222 | save_model(classifier, optimizer, args, epoch, save_dir/'ckpt_best.pth')
223 |
224 | print('best accuracy: {:.2f}'.format(best_acc))
225 | if not(args.debug) and args.wandb:
226 | wandb.log({'val_best_acc': best_acc})
227 | wandb.finish()
228 |
229 | if __name__ == '__main__':
230 | main()
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | '''
2 | adapted from SupConLoss
3 | https://github.com/HobbitLong/SupContrast/blob/master/losses.py
4 | '''
5 | import torch
6 | import torch.nn as nn
7 |
8 | class GenSupConLoss(nn.Module):
9 | def __init__(self, temperature=0.07, contrast_mode='all',
10 | base_temperature=0.07):
11 | super(GenSupConLoss, self).__init__()
12 | self.temperature = temperature
13 | self.contrast_mode = contrast_mode
14 | self.base_temperature = base_temperature
15 |
16 | def forward(self, features, labels):
17 | '''
18 | Args:
19 |             features: (anchor_features, contrast_features), each: [N, feat_dim]
20 |             labels: (anchor_labels, contrast_labels), each: [N, num_cls]
21 | '''
22 | if self.contrast_mode == 'all': # anchor+contrast @ anchor+contrast
23 | anchor_labels = torch.cat(labels, dim=0).float()
24 | contrast_labels = anchor_labels
25 |
26 | anchor_features = torch.cat(features, dim=0)
27 | contrast_features = anchor_features
28 | elif self.contrast_mode == 'one': # anchor @ contrast
29 | anchor_labels = labels[0].float()
30 | contrast_labels = labels[1].float()
31 |
32 | anchor_features = features[0]
33 | contrast_features = features[1]
34 |
35 | # 1. compute similarities among targets
36 | anchor_norm = torch.norm(anchor_labels, p=2, dim=-1, keepdim=True) # [anchor_N, 1]
37 | contrast_norm = torch.norm(contrast_labels, p=2, dim=-1, keepdim=True) # [contrast_N, 1]
38 |
39 | deno = torch.mm(anchor_norm, contrast_norm.T)
40 | mask = torch.mm(anchor_labels, contrast_labels.T) / deno # cosine similarity: [anchor_N, contrast_N]
41 |
42 | logits_mask = torch.ones_like(mask)
43 | if self.contrast_mode == 'all':
44 | logits_mask.fill_diagonal_(0)
45 | mask = mask * logits_mask
46 |
47 | # 2. compute logits
48 | anchor_dot_contrast = torch.div(
49 | torch.matmul(anchor_features, contrast_features.T),
50 | self.temperature
51 | )
52 | # for numerical stability
53 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
54 | logits = anchor_dot_contrast - logits_max.detach()
55 | # compute log_prob
56 | exp_logits = torch.exp(logits) * logits_mask
57 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
58 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
59 |
60 | # loss
61 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
62 | loss = loss.mean()
63 |
64 | return loss
--------------------------------------------------------------------------------
/mix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def mix_fn(x, y, alpha, kind):
7 | if kind == 'mixup':
8 | return mixup_data(x, y, alpha)
9 | elif kind == 'cutmix':
10 | return cutmix_data(x, y, alpha)
11 | elif kind == 'mixup_cutmix':
12 | if np.random.rand(1)[0] > 0.5:
13 | return mixup_data(x, y, alpha)
14 | else:
15 | return cutmix_data(x, y, alpha)
16 | else:
17 |         raise ValueError(f'Unknown mix kind: {kind}')
18 |
19 |
20 | def mix_target(y_a, y_b, lam, num_classes):
21 | l1 = F.one_hot(y_a, num_classes)
22 | l2 = F.one_hot(y_b, num_classes)
23 | return lam * l1 + (1 - lam) * l2
24 |
25 |
26 | '''
27 | modified from https://github.com/hongyi-zhang/mixup/blob/master/cifar/utils.py
28 | '''
29 | def mixup_data(x, y, alpha=1.0):
30 | '''Returns mixed inputs, pairs of targets, and lambda'''
31 | if alpha > 0:
32 | lam = np.random.beta(alpha, alpha)
33 | else:
34 | lam = 1
35 |
36 | batch_size = x.size()[0]
37 | index = torch.randperm(batch_size, device=x.device)
38 |
39 | mixed_x = lam * x + (1 - lam) * x[index, :]
40 | y_a, y_b = y, y[index]
41 | return mixed_x, y_a, y_b, lam
42 |
43 | '''
44 | modified from https://github.com/clovaai/CutMix-PyTorch/blob/master/train.py
45 | '''
46 | def cutmix_data(x, y, alpha=1.0):
47 | if alpha > 0:
48 | lam = np.random.beta(alpha, alpha)
49 | else:
50 | lam = 1
51 |
52 | bsz = x.size()[0]
53 | index = torch.randperm(bsz, device=x.device)
54 |
55 | bbx1, bby1, bbx2, bby2 = _rand_bbox(x.size(), lam)
56 | mixed_x = x.detach().clone()
57 | mixed_x[:, :, bbx1:bbx2, bby1:bby2] = mixed_x[index, :, bbx1:bbx2, bby1:bby2]
58 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
59 |
60 | y_a, y_b = y, y[index]
61 |
62 | return mixed_x, y_a, y_b, lam
63 |
64 | def _rand_bbox(size, lam):
65 | W = size[2]
66 | H = size[3]
67 | cut_rat = np.sqrt(1. - lam)
68 | cut_w = int(W * cut_rat)
69 | cut_h = int(H * cut_rat)
70 |
71 | # uniform
72 | cx = np.random.randint(W)
73 | cy = np.random.randint(H)
74 |
75 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
76 | bby1 = np.clip(cy - cut_h // 2, 0, H)
77 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
78 | bby2 = np.clip(cy + cut_h // 2, 0, H)
79 |
80 | return bbx1, bby1, bbx2, bby2
--------------------------------------------------------------------------------
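`cutmix_data` recomputes `lam` from the clipped box, so `lam` equals the fraction of each image that was *not* replaced. `mix_target` then turns the label pair into a soft target with weights `lam` and `1 - lam`. The sketch below checks both on a toy batch. It makes these assumptions:

- `paste_box` and `cutmix_lambda` are hypothetical helpers that copy the corresponding lines of `mix.py`.
- The fixed permutation and box stand in for `torch.randperm` and `_rand_bbox`.

```python
import torch
import torch.nn.functional as F


def mix_target(y_a, y_b, lam, num_classes):
    # Same formula as mix.py: convex combination of two one-hot targets
    return lam * F.one_hot(y_a, num_classes) + (1 - lam) * F.one_hot(y_b, num_classes)


def paste_box(x, index, bbox):
    # Hypothetical helper: paste the box from the permuted batch into a clone,
    # as cutmix_data does (the box indexes dims 2 and 3 of an NCHW tensor)
    bbx1, bby1, bbx2, bby2 = bbox
    mixed = x.clone()
    mixed[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    return mixed


def cutmix_lambda(bbox, x):
    # Hypothetical helper: share of each image NOT replaced, as in cutmix_data
    bbx1, bby1, bbx2, bby2 = bbox
    return 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))


# Toy batch: image 0 is all zeros, image 1 is all ones (N=2, C=3, H=W=4)
x = torch.stack([torch.zeros(3, 4, 4), torch.ones(3, 4, 4)])
index = torch.tensor([1, 0])  # fixed stand-in for torch.randperm
bbox = (0, 0, 2, 2)           # fixed stand-in for _rand_bbox
y = torch.tensor([0, 2])
y_a, y_b = y, y[index]

mixed = paste_box(x, index, bbox)
lam = cutmix_lambda(bbox, x)  # 1 - 4/16 = 0.75
target = mix_target(y_a, y_b, lam, num_classes=3)
```

Here `mixed[0]` takes a quarter of its pixels from image 1, which matches `1 - lam`.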
/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 | """
5 | Parse boolean using argument parser.
6 | """
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
10 | return True
11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
12 | return False
13 | else:
14 | raise argparse.ArgumentTypeError('Boolean value expected.')
15 |
16 | def genscl_parser():
17 | parser = argparse.ArgumentParser('argument for supervised contrastive learning')
18 |
19 | parser.add_argument('--desc', type=str, default=None,
20 | help='experiment name')
21 | parser.add_argument('--model', type=str, default='resnet50',
22 | help='model kind')
23 | parser.add_argument('--print-freq', type=int, default=20,
24 | help='print frequency')
25 | parser.add_argument('--save-freq', type=int, default=50,
26 | help='save frequency')
27 | parser.add_argument('--save-root', type=str, default='./saves/',
28 | help='root directory of results')
29 | parser.add_argument('--batch-size', type=int, default=1024,
30 | help='batch_size')
31 | parser.add_argument('--num-workers', type=int, default=8,
32 | help='num of workers to use')
33 | parser.add_argument('--epochs', type=int, default=500,
34 | help='number of training epochs')
35 | parser.add_argument('--temp', type=float, default=0.1,
36 | help='temperature for generalized supervised contrastive loss')
37 | parser.add_argument('--resume', type=str, default=None)
38 | # knowledge distillation
39 | parser.add_argument('--KD', action='store_true', default=False,
40 | help='perform knowledge distillation')
41 | parser.add_argument('--KD-alpha', type=float, default=1,
42 | help='weight of KD')
43 | parser.add_argument('--KD-temp', type=float, default=1,
44 | help='temperature for softening teacher predictions')
45 | parser.add_argument('--teacher-kind', type=str, default='efficientnetv2_rw_m')
46 | parser.add_argument('--teacher-path', type=str, default=None)
47 | parser.add_argument('--teacher-ckpt', type=str, default='ckpt_best.pth')
48 | # optimization
49 | parser.add_argument('--optim-kind', type=str, default='SGD',
50 | choices=['SGD', 'RMSProp', 'Adam', 'AdamW'],
51 | help='kind of optimizer')
52 | parser.add_argument('--LARS', action='store_true', default=False)
53 | parser.add_argument('--learning-rate', type=float, default=0.5,
54 | help='learning rate')
55 | parser.add_argument('--cosine', type=str2bool, default=True,
56 | help='using cosine annealing')
57 | parser.add_argument('--lr-decay-rate', type=float, default=0.1,
58 | help='decay rate for learning rate')
59 | parser.add_argument('--weight-decay', type=float, default=1e-4,
60 | help='weight decay')
61 | parser.add_argument('--momentum', type=float, default=0.9,
62 | help='momentum')
63 | # warmup
64 | parser.add_argument('--warm', type=str2bool, default=True,
65 | help='warm-up for large batch training')
66 | parser.add_argument('--warmup-from', type=float, default=0.01)
67 | parser.add_argument('--warm-epochs', type=int, default=10)
68 | parser.add_argument('--multiprocessing-distributed', action='store_true', default=False)
69 | # model dataset
70 | parser.add_argument('--dataset', type=str,
71 | choices=['cifar10', 'cifar100', 'imagenet'])
72 | # data augment policy
73 | parser.add_argument('--augment-policy', type=str, default='sim',
74 | choices=['no', 'sim', 'auto', 'rand'],
75 | help='data augmentation policy')
76 | # random augment
77 | parser.add_argument('--rand-n', type=int, default=1,
78 | help='number of RandAugment operations')
79 | parser.add_argument('--rand-m', type=int, default=2,
80 | help='magnitude of random augment')
81 | parser.add_argument('--cutout', action='store_true', default=False,
82 | help='perform cutout')
83 | parser.add_argument('--n-holes', type=int, default=1,
84 | help='# of cutout holes')
85 | parser.add_argument('--cutout-length', type=int, default=16,
86 | help='length of a cutout hole')
87 | parser.add_argument('--mix', type=str, default=None,
88 | choices=['mixup', 'cutmix', 'mixup_cutmix'],
89 | help='image-based regularizations')
90 | parser.add_argument('--mix-alpha', type=float, default=1.0,
91 | help='alpha for mixup/cutmix beta distribution')
92 | parser.add_argument('--seed', type=int, default=3407)
93 | parser.add_argument('--debug', action='store_true', default=False,
94 | help='debug: train 1 epoch')
95 | # wandb
96 | parser.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=True,
97 | help='use wandb for visualization (pass "--wandb false" to disable)')
98 | parser.add_argument('--wandb-entity', type=str, default=None,
99 | help='your wandb id')
100 | parser.add_argument('--wandb-project', type=str, default=None,
101 | help='wandb project name')
102 | return parser
103 |
104 | def linear_parser():
105 | parser = argparse.ArgumentParser('argument for linear finetuning')
106 |
107 | parser.add_argument('--desc', type=str, default=None,
108 | help='experiment name')
109 | parser.add_argument('--model', type=str, default='resnet50',
110 | help='model kind')
111 | # model config
112 | parser.add_argument('--pretrained', type=str,
113 | help='pretrained encoder to load')
114 | parser.add_argument('--pretrained-ckpt', type=str, default='ckpt_last.pth',
115 | help='pretrained encoder checkpoint')
116 | parser.add_argument('--label-smoothing', type=float, default=0.,
117 | help='label smoothing for cross-entropy loss')
118 | parser.add_argument('--print-freq', type=int, default=10,
119 | help='print frequency')
120 | parser.add_argument('--save-freq', type=int, default=20,
121 | help='save frequency')
122 | parser.add_argument('--save-root', type=str, default='./saves/',
123 | help='root directory of results')
124 | parser.add_argument('--batch-size', type=int, default=512,
125 | help='batch_size')
126 | parser.add_argument('--num-workers', type=int, default=4,
127 | help='num of workers to use')
128 | parser.add_argument('--epochs', type=int, default=100,
129 | help='number of training epochs')
130 |
131 | # optimization
132 | parser.add_argument('--optim-kind', type=str, default='SGD',
133 | choices=['SGD', 'RMSProp', 'Adam', 'AdamW'],
134 | help='kind of optimizer')
135 | parser.add_argument('--LARS', action='store_true', default=False)
136 | parser.add_argument('--learning-rate', type=float, default=5,
137 | help='learning rate')
138 | parser.add_argument('--cosine', action='store_true',
139 | help='using cosine annealing')
140 | parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[60,75,90],
141 | help='where to decay lr, can be a list')
142 | parser.add_argument('--lr_decay_rate', type=float, default=0.2,
143 | help='decay rate for learning rate')
144 | parser.add_argument('--weight_decay', type=float, default=0,
145 | help='weight decay')
146 | parser.add_argument('--momentum', type=float, default=0.9,
147 | help='momentum')
148 | parser.add_argument('--amp', action='store_true', default=False,
149 | help='use automatic mixed precision')
150 | parser.add_argument('--multiprocessing-distributed', action='store_true', default=False)
151 | # warmup
152 | parser.add_argument('--warm', type=str2bool, nargs='?', const=True, default=True,
153 | help='warm-up for large batch training')
154 | parser.add_argument('--warmup-from', type=float, default=1e-5)
155 | parser.add_argument('--warm-epochs', type=int, default=5)
156 | # model dataset
157 | parser.add_argument('--dataset', type=str,
158 | choices=['cifar10', 'cifar100', 'imagenet'])
159 | # data augment policy
160 | parser.add_argument('--augment-policy', type=str, default='sim',
161 | choices=['no', 'sim', 'auto', 'rand'],
162 | help='data augmentation policy')
163 | # random augment
164 | parser.add_argument('--rand-n', type=int, default=1,
165 | help='number of RandAugment operations')
166 | parser.add_argument('--rand-m', type=int, default=2,
167 | help='magnitude of random augment')
168 | # erasing
169 | parser.add_argument('--erasing', action='store_true', default=False,
170 | help='perform erasing regularization')
171 | parser.add_argument('--erasing-p', type=float, default=0.5,
172 | help='erasing probability')
173 | # cutout
174 | parser.add_argument('--cutout', action='store_true', default=False,
175 | help='perform cutout')
176 | parser.add_argument('--n-holes', type=int, default=1,
177 | help='# of cutout holes')
178 | parser.add_argument('--cutout-length', type=int, default=16,
179 | help='length of a cutout hole')
180 | parser.add_argument('--mix', type=str, default=None,
181 | choices=['mixup', 'cutmix', 'mixup_cutmix'],
182 | help='image-based regularizations')
183 | parser.add_argument('--mix-alpha', type=float, default=1.0,
184 | help='alpha for mixup/cutmix beta distribution')
185 |
186 | parser.add_argument('--seed', type=int, default=3407)
187 | parser.add_argument('--debug', action='store_true', default=False,
188 | help='debug: train 1 epoch')
189 | # wandb
190 | parser.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=True,
191 | help='use wandb for visualization (pass "--wandb false" to disable)')
192 | parser.add_argument('--wandb-entity', type=str, default=None,
193 | help='your wandb id')
194 | parser.add_argument('--wandb-project', type=str, default=None,
195 | help='wandb project name')
196 | return parser
--------------------------------------------------------------------------------
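A flag declared with `action='store_true', default=True` can never become `False`. One standard argparse pattern keeps the bare `--wandb` form and also accepts `--wandb false`: declare the flag with `type=str2bool, nargs='?', const=True`.

The sketch below uses a hypothetical demo parser, not the repository's. It also shows that argparse stores `--lr-decay-rate` as the attribute `lr_decay_rate`. This is why `utils.py` can read `args.lr_decay_rate` whether the flag was declared with hyphens (`genscl_parser`) or underscores (`linear_parser`).

```python
import argparse


def str2bool(v):
    # Same parsing rules as str2bool in parser.py
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def demo_parser():
    # Hypothetical parser for illustration only
    parser = argparse.ArgumentParser('str2bool demo')
    # bare `--wandb` -> const (True); `--wandb false` -> False; omitted -> default
    parser.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=True)
    # a hyphenated long option is stored as args.lr_decay_rate
    parser.add_argument('--lr-decay-rate', type=float, default=0.1)
    return parser
```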
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import datetime
3 | from copy import deepcopy
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torchlars import LARS
9 | import wandb
10 |
11 | class AverageMeter(object):
12 | """Computes and stores the average and current value"""
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def set_optimizer(model, args):
30 | if args.optim_kind == 'SGD':
31 | optimizer = optim.SGD(model.parameters(), args.learning_rate,
32 | momentum=args.momentum, weight_decay=args.weight_decay)
33 | elif args.optim_kind == 'RMSProp':
34 | optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate,
35 | weight_decay=args.weight_decay, momentum=args.momentum)
36 | elif args.optim_kind == 'Adam':
37 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
38 | betas=(0.9, 0.999), weight_decay=args.weight_decay)
39 | elif args.optim_kind == 'AdamW':
40 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate,
41 | weight_decay=args.weight_decay)
42 | if args.LARS:
43 | optimizer = LARS(optimizer)
44 | return optimizer
45 |
46 |
47 | def accuracy(output, target):
48 | with torch.no_grad():
49 | bsz = target.shape[0]
50 | pred = torch.argmax(output, dim=1)
51 | acc = 100 * (pred == target).sum() / bsz
52 | return acc.item()
53 |
54 | def adjust_learning_rate(args, optimizer, epoch):
55 | lr = args.learning_rate
56 | if args.cosine:
57 | eta_min = lr * (args.lr_decay_rate ** 3)
58 | lr = eta_min + (lr - eta_min) * (
59 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
60 | else:
61 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
62 | if steps > 0:
63 | lr = lr * (args.lr_decay_rate ** steps)
64 |
65 | for param_group in optimizer.param_groups:
66 | param_group['lr'] = lr
67 |
68 |
69 | def get_learning_rate(optimizer):
70 | for param_group in optimizer.param_groups:
71 | return param_group['lr']
72 |
73 |
74 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
75 | if args.warm and epoch <= args.warm_epochs:
76 | eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
77 | warmup_to = eta_min + (args.learning_rate - eta_min) * (
78 | 1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
79 |
80 | p = (batch_id + (epoch - 1) * total_batches) / \
81 | (args.warm_epochs * total_batches)
82 | lr = args.warmup_from + p * (warmup_to - args.warmup_from)
83 |
84 | for param_group in optimizer.param_groups:
85 | param_group['lr'] = lr
86 |
87 |
88 | def save_model(model, optimizer, args, epoch, save_file):
89 | print(f'==> Saving {save_file}...')
90 | state = {
91 | 'args': args,
92 | 'state_dict': model.state_dict(),
93 | 'optimizer': optimizer.state_dict(),
94 | 'epoch': epoch,
95 | }
96 | if not save_file.parent.exists():
97 | save_file.parent.mkdir(parents=True, exist_ok=True)
98 | torch.save(state, save_file)
99 | del state
100 |
101 | def seed(seed=1):
102 | """
103 | Seed for PyTorch reproducibility.
104 | Arguments:
105 | seed (int): Random seed value.
106 | """
107 | np.random.seed(seed)
108 | torch.manual_seed(seed)
109 | torch.cuda.manual_seed_all(seed)
110 |
111 |
112 | def init_wandb(args):
113 | wandb.init(
114 | entity=args.wandb_entity,
115 | project=args.wandb_project,
116 | name=args.desc,
117 | config=args,
118 | )
119 | wandb.run.save()
120 | return wandb.config
121 |
122 |
123 | def format_time(elapsed):
124 | """
125 | Format time for displaying.
126 | Arguments:
127 | elapsed: time interval in seconds.
128 | """
129 | elapsed_rounded = int(round(elapsed))
130 | return str(datetime.timedelta(seconds=elapsed_rounded))
131 |
132 |
133 | class ModelEmaV2(nn.Module):
134 | def __init__(self, model, decay=0.9999, device=None):
135 | super(ModelEmaV2, self).__init__()
136 | # make a copy of the model for accumulating moving average of weights
137 | self.module = deepcopy(model)
138 | self.module.eval()
139 | self.decay = decay
140 | self.device = device # perform ema on different device from model if set
141 | if self.device is not None:
142 | self.module.to(device=device)
143 |
144 | def _update(self, model, update_fn):
145 | with torch.no_grad():
146 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
147 | if self.device is not None:
148 | model_v = model_v.to(device=self.device)
149 | ema_v.copy_(update_fn(ema_v, model_v))
150 |
151 | def update(self, model):
152 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
153 |
154 | def set(self, model):
155 | self._update(model, update_fn=lambda e, m: m)
--------------------------------------------------------------------------------
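With warm-up enabled, `warmup_learning_rate` ramps the learning rate linearly per batch. It starts at `--warmup-from` and heads toward the value the cosine schedule in `adjust_learning_rate` reaches at epoch `--warm-epochs`. As a result, warm-up ends close to where cosine annealing resumes.

The sketch below only computes the values; it does not touch an optimizer. It makes these assumptions:

- Cosine annealing is enabled and the other settings are the `genscl_parser` defaults.
- `cosine_lr` and `warmup_lr` are hypothetical names.

```python
import math
from types import SimpleNamespace


def cosine_lr(args, epoch):
    # Cosine branch of adjust_learning_rate, returning the value instead of
    # writing it into optimizer.param_groups
    eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
    return eta_min + (args.learning_rate - eta_min) * (
        1 + math.cos(math.pi * epoch / args.epochs)) / 2


def warmup_lr(args, epoch, batch_id, total_batches):
    # Linear per-batch ramp from warmup_learning_rate (epochs are 1-indexed)
    warmup_to = cosine_lr(args, args.warm_epochs)
    p = (batch_id + (epoch - 1) * total_batches) / (args.warm_epochs * total_batches)
    return args.warmup_from + p * (warmup_to - args.warmup_from)


# Defaults taken from genscl_parser
args = SimpleNamespace(learning_rate=0.5, lr_decay_rate=0.1, epochs=500,
                       warm_epochs=10, warmup_from=0.01)
```

The floor of the cosine schedule is `learning_rate * lr_decay_rate ** 3`, so with these defaults the rate decays from 0.5 to 0.0005.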
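`ModelEmaV2.update` applies the exponential-moving-average rule `e <- decay * e + (1 - decay) * m` to every entry of the state dict. The sketch below reproduces that rule on a one-weight model so the arithmetic is visible. `ema_update` is a hypothetical helper, and the `decay=0.9` value is chosen only to make the numbers readable; the class default is 0.9999.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(ema_model, model, decay):
    # Same rule as ModelEmaV2.update: e <- decay * e + (1 - decay) * m.
    # state_dict() tensors share storage with the live parameters/buffers,
    # so copy_ writes straight into ema_model.
    for ema_v, model_v in zip(ema_model.state_dict().values(),
                              model.state_dict().values()):
        ema_v.copy_(decay * ema_v + (1. - decay) * model_v)


model = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(0.0)
ema = copy.deepcopy(model)         # EMA copy starts from the current weight (0.0)
with torch.no_grad():
    model.weight.fill_(1.0)        # pretend an optimizer step moved the weight to 1.0
ema_update(ema, model, decay=0.9)  # 0.9 * 0.0 + 0.1 * 1.0 = 0.1
ema_update(ema, model, decay=0.9)  # 0.9 * 0.1 + 0.1 * 1.0 = 0.19
```

Because the update iterates over `state_dict()` rather than `parameters()`, buffers such as BatchNorm running statistics are averaged as well.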