├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── data.py
│   └── utils.py
├── figures
│   ├── ex.jpeg
│   ├── genscl.png
│   └── results.png
├── genscl.py
├── linear.py
├── loss.py
├── mix.py
├── parser.py
└── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jaewon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Generalized Supervised Contrastive Learning Framework 2 | 3 | ![f1](figures/genscl.png) 4 | 5 | Official PyTorch implementation of GenSCL | [Paper](https://arxiv.org/abs/2206.00384) 6 | 7 | Jaewon Kim, Jooyoung Chang, and Sang Min Park 8 | 9 | [HSDS](http://snuhsds.com/)@Seoul National University 10 | 11 | Our implementation is based on the [Supervised Contrastive Learning](https://github.com/HobbitLong/SupContrast) repository. 12 | 13 | ### Abstract 14 | 15 | Building on the recent remarkable achievements of contrastive learning in self-supervised representation learning, supervised contrastive learning (SupCon) has successfully extended batch contrastive approaches to the supervised context and outperformed cross-entropy on various datasets with ResNet architectures. In this work, we present *GenSCL*: a generalized supervised contrastive learning framework that seamlessly adapts modern image-based regularizations (such as Mixup-Cutmix) and knowledge distillation (KD) to SupCon via our *generalized supervised contrastive loss*. The generalized supervised contrastive loss is a further extension of the supervised contrastive loss, measuring the cross-entropy between the similarity of labels and that of latent features. A model can thus learn to what extent contrastive samples should be pulled closer to an anchor in the latent space. By explicitly and fully leveraging label information, GenSCL breaks the boundary between conventional positives and negatives, and any kind of pre-trained teacher classifier can be utilized. ResNet-50 trained with GenSCL using Mixup-Cutmix and KD achieves state-of-the-art accuracies of 97.6% and 84.7% on CIFAR10 and CIFAR100 without external data, significantly improving on the results reported for the original SupCon (by 1.6% and 8.2%, respectively). A PyTorch implementation is available at https://t.ly/yuUO.
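Reading the implementation in `loss.py` below (with the default `contrast_mode='all'` and two views per image, so 2N anchors for a batch of N), the loss can be written as

$$\mathcal{L} = -\frac{\tau}{\tau_{base}} \cdot \frac{1}{2N} \sum_{i=1}^{2N} \frac{\sum_{j \neq i} m_{ij} \log p_{ij}}{\sum_{j \neq i} m_{ij}}, \qquad m_{ij} = \frac{y_i \cdot y_j}{\lVert y_i \rVert \, \lVert y_j \rVert}, \qquad p_{ij} = \frac{\exp(z_i \cdot z_j / \tau)}{\sum_{k \neq i} \exp(z_i \cdot z_k / \tau)}$$

i.e. SupCon's binary positive mask is replaced by the cosine similarity between (possibly soft) label vectors. A minimal usage sketch of `GenSupConLoss` — random tensors stand in for the encoder's projection-head outputs, and the shapes follow `loss.py`:

```python
import torch
import torch.nn.functional as F
from loss import GenSupConLoss

criterion = GenSupConLoss(temperature=0.1)

bsz, feat_dim, num_cls = 8, 128, 10
# stand-ins for the L2-normalized projection-head features of two views
f_q = F.normalize(torch.randn(bsz, feat_dim), dim=1)
f_k = F.normalize(torch.randn(bsz, feat_dim), dim=1)

# soft labels, e.g. from mix_target() in mix.py or a teacher's softened
# predictions; plain one-hot labels recover the original SupCon loss
targets = torch.randint(0, num_cls, (bsz,))
l_q = l_k = F.one_hot(targets, num_cls).float()

loss = criterion((f_q, f_k), (l_q, l_k))
```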
16 | 17 | ### Overview of the results 18 | 19 | ![results](figures/results.png) 20 | 21 | ## Loss Function 22 | 23 | Our proposed *Generalized Supervised Contrastive Loss* in `loss.py` takes a tuple of `features` and a tuple of `labels` as input and returns the loss. If the `labels` are one-hot encoded, it reduces to the Supervised Contrastive Loss. 24 | 25 | With the *Generalized Supervised Contrastive Loss*, Mixup/Cutmix and knowledge distillation can be seamlessly adapted to Supervised Contrastive Learning. 26 | 27 | ![ex](figures/ex.jpeg) 28 | 29 | ## Running 30 | 31 | To apply knowledge distillation, a pretrained teacher model (EfficientNetV2-M) is required; checkpoints are released [here](https://www.dropbox.com/sh/io8u9mv8hh3bt4m/AACjNFDZIgPADoyU14OEqVQSa?dl=0). 32 | 33 | * CIFAR10 34 | 35 | * Pretraining stage: 36 | 37 | ```bash 38 | python genscl.py \ 39 | --dataset cifar10 \ 40 | --mix mixup_cutmix \ 41 | --KD \ 42 | --KD-alpha 1 \ 43 | --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar10_Adam 44 | ``` 45 | 46 | * Linear evaluation stage: 47 | 48 | ```bash 49 | python linear.py \ 50 | --dataset cifar10 \ 51 | --pretrained cifar10_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \ 52 | --augment-policy no \ 53 | --amp 54 | ``` 55 | 56 | * CIFAR100 57 | 58 | * Pretraining stage: 59 | 60 | ```bash 61 | python genscl.py \ 62 | --dataset cifar100 \ 63 | --mix mixup_cutmix \ 64 | --KD \ 65 | --KD-alpha 1 \ 66 | --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar100_Adam 67 | ``` 68 | 69 | * Linear evaluation stage: 70 | 71 | ```bash 72 | python linear.py \ 73 | --dataset cifar100 \ 74 | --pretrained cifar100_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \ 75 | --augment-policy no \ 76 | --amp 77 | ``` 78 | 79 | 80 | You have several extra options: 81 | 82 | * `--optim-kind`: SGD, RMSProp, Adam, AdamW 83 | 84 | * `--augment-policy`: no, sim, auto, rand 85 | * `--wandb`: enable/disable [wandb](https://wandb.ai/) logging (takes true/false; enabled by default) 86 | 87 | ## Updates 88 | 89 | * 23 Jun, 2022: Initial upload 90 | 91 | ## Citation 92 | 93 | ``` 94 | @article{kim2022generalized, 95 | title={A Generalized Supervised Contrastive Learning Framework}, 96 | author={Kim, Jaewon and Chang, Jooyoung and Park, Sang Min}, 97 | journal={arXiv preprint arXiv:2206.00384}, 98 | year={2022} 99 | } 100 | ``` 101 | 102 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | from .utils import TwoCropTransform 2 | from .utils import Cutout 3 | 4 | import torch 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | 8 | MEAN = { 9 | 'cifar10': [0.4914, 0.4822, 0.4465], 10 | 'cifar100': [0.5071, 0.4867, 0.4408], 11 | 'imagenet': [0.485, 0.456, 0.406] 12 | } 13 | STD = { 14 | 'cifar10': [0.2023, 0.1994, 0.2010], 15 | 'cifar100': [0.2675, 0.2565, 0.2761], 16 | 'imagenet': [0.229, 0.224, 0.225] 17 | } 18 | SIZE = { 19 | 'cifar10': 32, 20 | 'cifar100': 32, 21 | 'imagenet': 224, 22 | } 23 | 24 | NUM_CLASSES = { 25 | 'cifar10': 10, 26 | 'cifar100': 100, 27 | 'imagenet': 1000 28 | } 29 | 30 | 31 | def contrastive_loader(args): 32 | # transformation 33 | normalize = transforms.Normalize(mean=MEAN[args.dataset], std=STD[args.dataset]) 34 | 35 | if args.augment_policy == 'sim': 
# simclr augment 36 | train_transform = transforms.Compose([ 37 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)), 38 | transforms.RandomHorizontalFlip(), 39 | transforms.RandomApply([ 40 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 41 | ], p=0.8), 42 | transforms.RandomGrayscale(p=0.2), 43 | transforms.ToTensor(), 44 | normalize, 45 | ]) 46 | elif args.augment_policy == 'auto': # auto augment 47 | train_transform = transforms.Compose([ 48 | transforms.RandomResizedCrop(size=SIZE[args.dataset]), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.AutoAugment(), 51 | transforms.ToTensor(), 52 | normalize, 53 | ]) 54 | elif args.augment_policy == 'rand': 55 | train_transform = transforms.Compose([ 56 | transforms.RandomResizedCrop(size=SIZE[args.dataset]), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.RandAugment(args.rand_n, args.rand_m), 59 | transforms.ToTensor(), 60 | normalize, 61 | ]) 62 | else: 63 | raise NotImplementedError(f'Unknown {args.augment_policy}!') 64 | 65 | if args.cutout: 66 | train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.cutout_length)) 67 | 68 | # dataset 69 | if args.dataset == 'cifar10': 70 | train_dataset = datasets.CIFAR10(root='/home/DB/', 71 | transform=TwoCropTransform(train_transform), 72 | download=True) 73 | elif args.dataset == 'cifar100': 74 | train_dataset = datasets.CIFAR100(root='/home/DB/', 75 | transform=TwoCropTransform(train_transform), 76 | download=True) 77 | elif args.dataset == 'imagenet': 78 | train_dataset = datasets.ImageFolder('/home/DB/IMAGENET/train', 79 | TwoCropTransform(train_transform) 80 | ) 81 | else: 82 | raise ValueError(args.dataset) 83 | 84 | if args.multiprocessing_distributed: 85 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 86 | else: 87 | train_sampler = None 88 | 89 | train_loader = torch.utils.data.DataLoader( 90 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 91 | num_workers=args.num_workers, pin_memory=True, sampler=train_sampler, drop_last=True) 92 | 93 | return train_loader, train_sampler 94 | 95 | 96 | def normal_loader(args): 97 | normalize = transforms.Normalize(mean=MEAN[args.dataset], std=STD[args.dataset]) 98 | if args.dataset == 'cifar10' or args.dataset == 'cifar100': 99 | resize = transforms.RandomCrop(SIZE[args.dataset], padding=4) 100 | else: 101 | resize = transforms.RandomResizedCrop(size=SIZE[args.dataset]) 102 | # train dataset 103 | if args.augment_policy == 'no': 104 | train_transform = transforms.Compose([ 105 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)), 106 | transforms.RandomHorizontalFlip(), 107 | transforms.ToTensor(), 108 | normalize, 109 | ]) 110 | elif args.augment_policy == 'sim': # simclr augment 111 | train_transform = transforms.Compose([ 112 | transforms.RandomResizedCrop(size=SIZE[args.dataset], scale=(0.2, 1.)), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.RandomApply([ 115 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 116 | ], p=0.8), 117 | transforms.RandomGrayscale(p=0.2), 118 | transforms.ToTensor(), 119 | normalize, 120 | ]) 121 | elif args.augment_policy == 'auto': 122 | train_transform = transforms.Compose([ 123 | resize, 124 | transforms.RandomHorizontalFlip(), 125 | transforms.AutoAugment(), 126 | transforms.ToTensor(), 127 | normalize, 128 | ]) 129 | elif args.augment_policy == 'rand': 130 | train_transform = transforms.Compose([ 131 | resize, 132 | transforms.RandomHorizontalFlip(), 133 | 
transforms.RandAugment(args.rand_n, args.rand_m), 134 | transforms.ToTensor(), 135 | normalize, 136 | ]) 137 | else: 138 | raise NotImplementedError(f'Unknown {args.augment_policy}!') 139 | 140 | if args.erasing: 141 | train_transform.transforms.append(transforms.RandomErasing(p=args.erasing_p)) 142 | if args.cutout: 143 | train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.cutout_length)) 144 | 145 | # valid dataset 146 | if args.dataset == 'imagenet': 147 | val_transform = transforms.Compose([ 148 | transforms.Resize(256), 149 | transforms.CenterCrop(224), 150 | transforms.ToTensor(), 151 | normalize, 152 | ]) 153 | else: 154 | val_transform = transforms.Compose([ 155 | transforms.ToTensor(), 156 | normalize, 157 | ]) 158 | 159 | if args.dataset == 'cifar10': 160 | train_dataset = datasets.CIFAR10(root='/home/DB/', 161 | transform=train_transform, 162 | download=True) 163 | val_set = datasets.CIFAR10(root='/home/DB/', 164 | transform=val_transform, 165 | train=False, 166 | download=True) 167 | elif args.dataset == 'cifar100': 168 | train_dataset = datasets.CIFAR100(root='/home/DB/', 169 | transform=train_transform, 170 | download=True) 171 | val_set = datasets.CIFAR100(root='/home/DB/', 172 | transform=val_transform, 173 | train=False, 174 | download=True) 175 | elif args.dataset == 'imagenet': 176 | train_dataset = datasets.ImageFolder('/home/DB/IMAGENET/train', train_transform) 177 | val_set = datasets.ImageFolder('/home/DB/IMAGENET/val', val_transform) 178 | else: 179 | raise ValueError(args.dataset) 180 | 181 | if args.multiprocessing_distributed: 182 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 183 | else: 184 | train_sampler = None 185 | 186 | train_loader = torch.utils.data.DataLoader( 187 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 188 | num_workers=args.num_workers, pin_memory=True, sampler=train_sampler) 189 | 190 | val_loader = torch.utils.data.DataLoader( 191 | val_set, batch_size=args.batch_size, shuffle=False, 192 | num_workers=args.num_workers, pin_memory=True 193 | ) 194 | 195 | return train_loader, val_loader, train_sampler -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Cutout(object): 6 | """Randomly mask out one or more patches from an image. 7 | Args: 8 | n_holes (int): Number of patches to cut out of each image. 9 | length (int): The length (in pixels) of each square patch. 10 | """ 11 | def __init__(self, n_holes, length): 12 | self.n_holes = n_holes 13 | self.length = length 14 | 15 | def __call__(self, img): 16 | """ 17 | Args: 18 | img (Tensor): Tensor image of size (C, H, W). 19 | Returns: 20 | Tensor: Image with n_holes of dimension length x length cut out of it. 21 | """ 22 | h = img.size(1) 23 | w = img.size(2) 24 | 25 | mask = np.ones((h, w), np.float32) 26 | 27 | for n in range(self.n_holes): 28 | y = np.random.randint(h) 29 | x = np.random.randint(w) 30 | 31 | y1 = np.clip(y - self.length // 2, 0, h) 32 | y2 = np.clip(y + self.length // 2, 0, h) 33 | x1 = np.clip(x - self.length // 2, 0, w) 34 | x2 = np.clip(x + self.length // 2, 0, w) 35 | 36 | mask[y1: y2, x1: x2] = 0. 
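# the accumulated (h, w) mask is then lifted to a tensor, broadcast across
# all channels, and multiplied into the image to zero out the chosen patches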
37 | 38 | mask = torch.from_numpy(mask) 39 | mask = mask.expand_as(img) 40 | img = img * mask 41 | 42 | return img 43 | 44 | class TwoCropTransform: 45 | """Create two crops of the same image""" 46 | def __init__(self, transform): 47 | self.transform = transform 48 | 49 | def __call__(self, x): 50 | return [self.transform(x), self.transform(x)] -------------------------------------------------------------------------------- /figures/ex.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/ex.jpeg -------------------------------------------------------------------------------- /figures/genscl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/genscl.png -------------------------------------------------------------------------------- /figures/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiimmm/GenSCL/d2395c722f233337a722c8d0d18b86527e96f110/figures/results.png -------------------------------------------------------------------------------- /genscl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from pathlib import Path 4 | 5 | import timm 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | from torch.cuda.amp import autocast 10 | try: 11 | import wandb 12 | except ImportError: 13 | pass 14 | 15 | from utils import AverageMeter 16 | from utils import adjust_learning_rate, warmup_learning_rate, get_learning_rate 17 | from utils import set_optimizer, save_model 18 | from utils import seed, format_time 19 | from utils import init_wandb 20 | 21 | from networks.resnet_big import SupConResNet 22 | from loss import GenSupConLoss 23 | from mix import mix_fn, mix_target 24 | from data import contrastive_loader, NUM_CLASSES 25 | from parser import genscl_parser 26 | 27 | 28 | 29 | # load encoder (student) 30 | def set_model(args): 31 | model = SupConResNet(name=args.model) 32 | 33 | if args.KD: # load teacher model 34 | teacher = timm.create_model(args.teacher_kind, pretrained=False) 35 | teacher.reset_classifier(NUM_CLASSES[args.dataset]) 36 | out_ch = teacher.conv_stem.out_channels 37 | teacher.conv_stem = torch.nn.Conv2d(3, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 38 | teacher_ckpt = torch.load(Path(args.teacher_path)/args.teacher_ckpt, map_location='cpu') 39 | teacher.load_state_dict(teacher_ckpt['state_dict']) 40 | teacher.eval() 41 | else: 42 | teacher = None 43 | 44 | criterion = GenSupConLoss(temperature=args.temp) 45 | 46 | if torch.cuda.is_available(): 47 | if torch.cuda.device_count() > 1: 48 | model.encoder = torch.nn.DataParallel(model.encoder) 49 | if teacher: teacher = torch.nn.DataParallel(teacher) 50 | if not args.resume is None: # resume from previous ckpt 51 | load_fn = Path(args.save_root)/args.desc/f'ckpt_{args.resume}.pth' 52 | ckpt = torch.load(load_fn, map_location='cpu') 53 | model.load_state_dict(ckpt['state_dict']) 54 | print(f'=> Successfully loading {load_fn}!') 55 | args.start_epoch = ckpt['epoch'] + 1 56 | else: 57 | args.start_epoch = 1 58 | 59 | model = model.cuda() 60 | if teacher: teacher = teacher.cuda() 61 | criterion = criterion.cuda() 62 | cudnn.benchmark = True 63 | 64 | return 
model, teacher, criterion 65 | 66 | 67 | def train(loader, model, teacher, criterion, optimizer, epoch, args): 68 | model.train() 69 | 70 | batch_time = AverageMeter() 71 | data_time = AverageMeter() 72 | losses = AverageMeter() 73 | 74 | end = time.time() 75 | 76 | for idx, (images, targets) in enumerate(loader): 77 | data_time.update(time.time() - end) 78 | warmup_learning_rate(args, epoch, idx, len(loader), optimizer) 79 | 80 | bsz = targets.shape[0] 81 | im_q, im_k = images 82 | if torch.cuda.is_available(): 83 | im_q = im_q.cuda(non_blocking=True) 84 | im_k = im_k.cuda(non_blocking=True) 85 | targets = targets.cuda(non_blocking=True) 86 | 87 | if args.mix: # image-based regularizations 88 | im_q, y0a, y0b, lam0 = mix_fn(im_q, targets, args.mix_alpha, args.mix) 89 | im_k, y1a, y1b, lam1 = mix_fn(im_k, targets, args.mix_alpha, args.mix) 90 | images = torch.cat([im_q, im_k], dim=0) 91 | l_q = mix_target(y0a, y0b, lam0, NUM_CLASSES[args.dataset]) 92 | l_k = mix_target(y1a, y1b, lam1, NUM_CLASSES[args.dataset]) 93 | else: 94 | images = torch.cat([im_q, im_k], dim=0) 95 | l_q = F.one_hot(targets, NUM_CLASSES[args.dataset]) 96 | l_k = l_q 97 | 98 | if teacher: # KD 99 | with torch.no_grad(): 100 | with autocast(): 101 | preds = F.softmax(teacher(images) / args.KD_temp, dim=1) 102 | teacher_q, teacher_k = torch.split(preds, [bsz, bsz], dim=0) 103 | 104 | # forward 105 | features = model(images) 106 | features = torch.split(features, [bsz, bsz], dim=0) 107 | 108 | if teacher: 109 | if args.KD_alpha == float('inf'): # only learn from teacher's prediction 110 | loss = criterion(features, [teacher_q, teacher_k]) 111 | else: 112 | loss = criterion(features, [l_q, l_k]) + args.KD_alpha * criterion(features, [teacher_q, teacher_k]) 113 | else: # no KD 114 | loss = criterion(features, [l_q, l_k]) 115 | 116 | 117 | losses.update(loss.item(), bsz) 118 | # backward 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | # measure elapsed time 124 | batch_time.update(time.time() - end) 125 | end = time.time() 126 | 127 | # print info 128 | if (idx + 1) % args.print_freq == 0: 129 | print('Train: [{0}][{1}/{2}]\t' 130 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 131 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 132 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( 133 | epoch, idx + 1, len(loader), batch_time=batch_time, 134 | data_time=data_time, loss=losses)) 135 | sys.stdout.flush() 136 | 137 | res = { 138 | 'trn_loss': losses.avg, 139 | 'learning_rate': get_learning_rate(optimizer) 140 | } 141 | return res 142 | 143 | 144 | def default_desc(args): 145 | if args.desc is None: 146 | desc = args.dataset + '_' 147 | desc += f'bsz_{args.batch_size}_' 148 | if args.cutout: 149 | desc += 'cutout_' 150 | if args.mix: 151 | desc += f'{args.mix}_{args.mix_alpha}_' 152 | if args.KD: 153 | desc += f'KD_{args.KD_alpha}_' 154 | desc += f'{args.optim_kind}_lr_{args.learning_rate}' 155 | args.desc = desc 156 | return args 157 | 158 | 159 | def main(): 160 | parser = genscl_parser() 161 | args = parser.parse_args() 162 | args = default_desc(args) 163 | seed(args.seed) 164 | 165 | if args.debug: 166 | args.epochs = 1 167 | elif args.wandb: 168 | init_wandb(args) 169 | save_dir = Path(args.save_root)/args.desc 170 | 171 | # build data loader 172 | train_loader, _ = contrastive_loader(args) 173 | 174 | # build model and criterion 175 | model, teacher, criterion = set_model(args) 176 | 177 | # build optimizer 178 | optimizer = set_optimizer(model, args) 179 | 180 | # train 181 | 
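# start_epoch is 1 for a fresh run, or the checkpoint's epoch + 1 when --resume is given (set in set_model)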
for epoch in range(args.start_epoch, args.epochs + 1): 182 | adjust_learning_rate(args, optimizer, epoch) 183 | 184 | # train for one epoch 185 | time1 = time.time() 186 | res = train(train_loader, model, teacher, criterion, optimizer, epoch, args) 187 | time2 = time.time() 188 | print(f'epoch {epoch}, total time {format_time(time2 - time1)}') 189 | if not args.debug and args.wandb: 190 | wandb.log(res, step=epoch) 191 | 192 | if (epoch % args.save_freq == 0) and not args.debug: 193 | save_fn = save_dir/f'ckpt_{epoch}.pth' 194 | save_model(model, optimizer, args, epoch, save_fn) 195 | 196 | if not args.debug: 197 | save_fn = save_dir/f'ckpt_last.pth' 198 | save_model(model, optimizer, args, args.epochs, save_fn) 199 | if args.wandb: 200 | wandb.finish() 201 | 202 | if __name__ == '__main__': 203 | main() -------------------------------------------------------------------------------- /linear.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from pathlib import Path 4 | from contextlib import suppress 5 | 6 | import torch 7 | from torch.cuda.amp import autocast, GradScaler 8 | import torch.backends.cudnn as cudnn 9 | try: 10 | import wandb 11 | except ImportError: 12 | pass 13 | 14 | from utils import AverageMeter 15 | from utils import adjust_learning_rate, warmup_learning_rate, accuracy, get_learning_rate 16 | from utils import set_optimizer, save_model 17 | from utils import init_wandb 18 | from utils import format_time 19 | from utils import seed 20 | 21 | from networks.resnet_big import SupConResNet, LinearClassifier 22 | from data import normal_loader, NUM_CLASSES 23 | from mix import mix_fn 24 | from parser import linear_parser 25 | 26 | 27 | # load trained encoder and build a classifier to train 28 | def set_model(args): 29 | model = SupConResNet(name=args.model) 30 | classifier = LinearClassifier(name=args.model, num_classes=NUM_CLASSES[args.dataset]) 31 | 32 | criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 33 | 34 | load_fn = Path(args.save_root)/args.pretrained/args.pretrained_ckpt 35 | ckpt = torch.load(load_fn, map_location='cpu') 36 | state_dict = ckpt['state_dict'] 37 | 38 | if torch.cuda.is_available(): 39 | if torch.cuda.device_count() > 1: 40 | model.encoder = torch.nn.DataParallel(model.encoder) 41 | else: 42 | new_state_dict = {} 43 | for k, v in state_dict.items(): 44 | k = k.replace("module.", "") 45 | new_state_dict[k] = v 46 | state_dict = new_state_dict 47 | model = model.cuda() 48 | classifier = classifier.cuda() 49 | criterion = criterion.cuda() 50 | cudnn.benchmark = True 51 | 52 | model.load_state_dict(state_dict) 53 | 54 | return model, classifier, criterion 55 | 56 | 57 | def train(loader, model, classifier, criterion, optimizer, epoch, amp_autocast, scaler, args): 58 | """one epoch training""" 59 | model.eval() 60 | classifier.train() 61 | 62 | batch_time = AverageMeter() 63 | data_time = AverageMeter() 64 | losses = AverageMeter() 65 | top1 = AverageMeter() 66 | 67 | end = time.time() 68 | for idx, (images, targets) in enumerate(loader): 69 | data_time.update(time.time() - end) 70 | 71 | images = images.cuda(non_blocking=True) 72 | targets = targets.cuda(non_blocking=True) 73 | bsz = targets.shape[0] 74 | 75 | # warm-up learning rate 76 | warmup_learning_rate(args, epoch, idx, len(loader), optimizer) 77 | 78 | # compute loss 79 | with amp_autocast(): 80 | with torch.no_grad(): 81 | features = model.encoder(images) 82 | output = classifier(features.detach()) 83 
| loss = criterion(output, targets) 84 | 85 | # update metric 86 | losses.update(loss.item(), bsz) 87 | acc1 = accuracy(output, targets) 88 | top1.update(acc1, bsz) 89 | 90 | optimizer.zero_grad() 91 | if scaler: 92 | scaler.scale(loss).backward() 93 | scaler.step(optimizer) 94 | scaler.update() 95 | else: 96 | loss.backward() 97 | optimizer.step() 98 | 99 | # measure elapsed time 100 | batch_time.update(time.time() - end) 101 | end = time.time() 102 | 103 | # print info 104 | if (idx + 1) % args.print_freq == 0: 105 | print('Train: [{0}][{1}/{2}]\t' 106 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 107 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 108 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 109 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 110 | epoch, idx + 1, len(loader), batch_time=batch_time, 111 | data_time=data_time, loss=losses, top1=top1)) 112 | sys.stdout.flush() 113 | 114 | res = { 115 | 'trn_loss': losses.avg, 116 | 'trn_top1_acc': top1.avg, 117 | 'learning_rate': get_learning_rate(optimizer) 118 | } 119 | return res 120 | 121 | def validate(loader, model, classifier, criterion, amp_autocast, args): 122 | """validation""" 123 | model.eval() 124 | classifier.eval() 125 | 126 | batch_time = AverageMeter() 127 | losses = AverageMeter() 128 | top1 = AverageMeter() 129 | 130 | with torch.no_grad(): 131 | end = time.time() 132 | for idx, (images, targets) in enumerate(loader): 133 | images = images.float().cuda() 134 | targets = targets.cuda() 135 | bsz = targets.shape[0] 136 | 137 | # forward 138 | with amp_autocast(): 139 | output = classifier(model.encoder(images)) 140 | loss = criterion(output, targets) 141 | 142 | # update metric 143 | losses.update(loss.item(), bsz) 144 | acc1 = accuracy(output, targets) 145 | top1.update(acc1, bsz) 146 | 147 | # measure elapsed time 148 | batch_time.update(time.time() - end) 149 | end = time.time() 150 | 151 | if idx % args.print_freq == 0: 152 | print('Test: [{0}/{1}]\t' 153 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 154 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 155 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 156 | idx, len(loader), batch_time=batch_time, 157 | loss=losses, top1=top1)) 158 | 159 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 160 | res = { 161 | 'val_loss': losses.avg, 162 | 'val_top1_acc': top1.avg 163 | } 164 | 165 | return res 166 | 167 | 168 | def default_desc(args): 169 | if args.desc is None: 170 | desc = args.pretrained + '_linear' 171 | args.desc = desc 172 | return args 173 | 174 | 175 | def main(): 176 | best_acc = 0. 
177 | parser = linear_parser() 178 | args = parser.parse_args() 179 | args = default_desc(args) 180 | seed(args.seed) 181 | 182 | if args.debug: 183 | args.epochs = 1 184 | elif args.wandb: 185 | init_wandb(args) 186 | 187 | save_dir = Path(args.save_root)/args.desc 188 | amp_autocast = autocast if args.amp else suppress 189 | scaler = GradScaler() if args.amp else None 190 | # build data loader 191 | train_loader, val_loader, _ = normal_loader(args) 192 | 193 | # build model and criterion 194 | model, classifier, criterion = set_model(args) 195 | 196 | # build optimizer 197 | optimizer = set_optimizer(classifier, args) 198 | 199 | # training routine 200 | for epoch in range(1, args.epochs + 1): 201 | adjust_learning_rate(args, optimizer, epoch) 202 | 203 | # train for one epoch 204 | time1 = time.time() 205 | res = train(train_loader, model, classifier, criterion, 206 | optimizer, epoch, amp_autocast, scaler, args) 207 | time2 = time.time() 208 | print('Train epoch {}, total time {}, accuracy: {:.2f}'.format( 209 | epoch, format_time(time2 - time1), res['trn_top1_acc'])) 210 | 211 | save_model(classifier, optimizer, args, epoch, save_dir/'ckpt_last.pth') 212 | 213 | # eval for one epoch 214 | val_res = validate(val_loader, model, classifier, criterion, amp_autocast, args) 215 | val_acc = val_res['val_top1_acc'] 216 | res.update(val_res) 217 | if not args.debug and args.wandb: 218 | wandb.log(res, step=epoch) 219 | 220 | if val_acc > best_acc: 221 | best_acc = val_acc 222 | save_model(classifier, optimizer, args, epoch, save_dir/'ckpt_best.pth') 223 | 224 | print('best accuracy: {:.2f}'.format(best_acc)) 225 | if not args.debug and args.wandb: 226 | wandb.log({'val_best_acc': best_acc}) 227 | wandb.finish() 228 | 229 | if __name__ == '__main__': 230 | main() -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | adapted from SupConLoss 3 | https://github.com/HobbitLong/SupContrast/blob/master/losses.py 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | 8 | class GenSupConLoss(nn.Module): 9 | def __init__(self, temperature=0.07, contrast_mode='all', 10 | base_temperature=0.07): 11 | super(GenSupConLoss, self).__init__() 12 | self.temperature = temperature 13 | self.contrast_mode = contrast_mode 14 | self.base_temperature = base_temperature 15 | 16 | def forward(self, features, labels): 17 | ''' 18 | Args: 19 | features: (anchor_features, contrast_features), each: [N, feat_dim] 20 | labels: (anchor_labels, contrast_labels) each: [N, num_cls] 21 | ''' 22 | if self.contrast_mode == 'all': # anchor+contrast @ anchor+contrast 23 | anchor_labels = torch.cat(labels, dim=0).float() 24 | contrast_labels = anchor_labels 25 | 26 | anchor_features = torch.cat(features, dim=0) 27 | contrast_features = anchor_features 28 | elif self.contrast_mode == 'one': # anchor @ contrast 29 | anchor_labels = labels[0].float() 30 | contrast_labels = labels[1].float() 31 | 32 | anchor_features = features[0] 33 | contrast_features = features[1] 34 | 35 | # 1. 
compute similarities among targets 36 | anchor_norm = torch.norm(anchor_labels, p=2, dim=-1, keepdim=True) # [anchor_N, 1] 37 | contrast_norm = torch.norm(contrast_labels, p=2, dim=-1, keepdim=True) # [contrast_N, 1] 38 | 39 | deno = torch.mm(anchor_norm, contrast_norm.T) 40 | mask = torch.mm(anchor_labels, contrast_labels.T) / deno # cosine similarity: [anchor_N, contrast_N] 41 | 42 | logits_mask = torch.ones_like(mask) 43 | if self.contrast_mode == 'all': 44 | logits_mask.fill_diagonal_(0) 45 | mask = mask * logits_mask 46 | 47 | # 2. compute logits 48 | anchor_dot_contrast = torch.div( 49 | torch.matmul(anchor_features, contrast_features.T), 50 | self.temperature 51 | ) 52 | # for numerical stability 53 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 54 | logits = anchor_dot_contrast - logits_max.detach() 55 | # compute log_prob 56 | exp_logits = torch.exp(logits) * logits_mask 57 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 58 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8) 59 | 60 | # loss 61 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 62 | loss = loss.mean() 63 | 64 | return loss -------------------------------------------------------------------------------- /mix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def mix_fn(x, y, alpha, kind): 7 | if kind == 'mixup': 8 | return mixup_data(x, y, alpha) 9 | elif kind == 'cutmix': 10 | return cutmix_data(x, y, alpha) 11 | elif kind == 'mixup_cutmix': 12 | if np.random.rand(1)[0] > 0.5: 13 | return mixup_data(x, y, alpha) 14 | else: 15 | return cutmix_data(x, y, alpha) 16 | else: 17 | raise ValueError() 18 | 19 | 20 | def mix_target(y_a, y_b, lam, num_classes): 21 | l1 = F.one_hot(y_a, num_classes) 22 | l2 = F.one_hot(y_b, num_classes) 23 | return lam * l1 + (1 - lam) * l2 24 | 25 | 26 | ''' 27 | modified from https://github.com/hongyi-zhang/mixup/blob/master/cifar/utils.py 28 | ''' 29 | def mixup_data(x, y, alpha=1.0): 30 | '''Returns mixed inputs, pairs of targets, and lambda''' 31 | if alpha > 0: 32 | lam = np.random.beta(alpha, alpha) 33 | else: 34 | lam = 1 35 | 36 | batch_size = x.size()[0] 37 | index = torch.randperm(batch_size, device=x.device) 38 | 39 | mixed_x = lam * x + (1 - lam) * x[index, :] 40 | y_a, y_b = y, y[index] 41 | return mixed_x, y_a, y_b, lam 42 | 43 | ''' 44 | modified from https://github.com/clovaai/CutMix-PyTorch/blob/master/train.py 45 | ''' 46 | def cutmix_data(x, y, alpha=1.0): 47 | if alpha > 0: 48 | lam = np.random.beta(alpha, alpha) 49 | else: 50 | lam = 1 51 | 52 | bsz = x.size()[0] 53 | index = torch.randperm(bsz, device=x.device) 54 | 55 | bbx1, bby1, bbx2, bby2 = _rand_bbox(x.size(), lam) 56 | mixed_x = x.detach().clone() 57 | mixed_x[:, :, bbx1:bbx2, bby1:bby2] = mixed_x[index, :, bbx1:bbx2, bby1:bby2] 58 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 59 | 60 | y_a, y_b = y, y[index] 61 | 62 | return mixed_x, y_a, y_b, lam 63 | 64 | def _rand_bbox(size, lam): 65 | W = size[2] 66 | H = size[3] 67 | cut_rat = np.sqrt(1. 
- lam) 68 | cut_w = int(W * cut_rat) # plain int: np.int was removed in NumPy 1.24 69 | cut_h = int(H * cut_rat) 70 | 71 | # uniform 72 | cx = np.random.randint(W) 73 | cy = np.random.randint(H) 74 | 75 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 76 | bby1 = np.clip(cy - cut_h // 2, 0, H) 77 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 78 | bby2 = np.clip(cy + cut_h // 2, 0, H) 79 | 80 | return bbx1, bby1, bbx2, bby2 -------------------------------------------------------------------------------- /parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | """ 5 | Parse boolean using argument parser. 6 | """ 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | def genscl_parser(): 17 | parser = argparse.ArgumentParser('argument for supervised contrastive learning') 18 | 19 | parser.add_argument('--desc', type=str, default=None, 20 | help='experiment name') 21 | parser.add_argument('--model', type=str, default='resnet50', 22 | help='model kind') 23 | parser.add_argument('--print-freq', type=int, default=20, 24 | help='print frequency') 25 | parser.add_argument('--save-freq', type=int, default=50, 26 | help='save frequency') 27 | parser.add_argument('--save-root', type=str, default='./saves/', 28 | help='root directory of results') 29 | parser.add_argument('--batch-size', type=int, default=1024, 30 | help='batch_size') 31 | parser.add_argument('--num-workers', type=int, default=8, 32 | help='num of workers to use') 33 | parser.add_argument('--epochs', type=int, default=500, 34 | help='number of training epochs') 35 | parser.add_argument('--temp', type=float, default=0.1, 36 | help='temperature for generalized supervised contrastive loss') 37 | parser.add_argument('--resume', type=str, default=None) 38 | # knowledge distillation 39 | parser.add_argument('--KD', action='store_true', default=False, 40 | help='perform knowledge distillation') 41 | parser.add_argument('--KD-alpha', type=float, default=1, 42 | help='weight of KD') 43 | parser.add_argument('--KD-temp', type=float, default=1, 44 | help='temperature for softening teacher predictions') 45 | parser.add_argument('--teacher-kind', type=str, default='efficientnetv2_rw_m') 46 | parser.add_argument('--teacher-path', type=str, default=None) 47 | parser.add_argument('--teacher-ckpt', type=str, default='ckpt_best.pth') 48 | # optimization 49 | parser.add_argument('--optim-kind', type=str, default='SGD', 50 | choices=['SGD', 'RMSProp', 'Adam', 'AdamW'], 51 | help='kind of optimizer') 52 | parser.add_argument('--LARS', action='store_true', default=False) 53 | parser.add_argument('--learning-rate', type=float, default=0.5, 54 | help='learning rate') 55 | parser.add_argument('--cosine', type=str2bool, default=True, 56 | help='using cosine annealing') 57 | parser.add_argument('--lr-decay-rate', type=float, default=0.1, 58 | help='decay rate for learning rate') 59 | parser.add_argument('--weight-decay', type=float, default=1e-4, 60 | help='weight decay') 61 | parser.add_argument('--momentum', type=float, default=0.9, 62 | help='momentum') 63 | # warmup 64 | parser.add_argument('--warm', type=str2bool, default=True, 65 | help='warm-up for large batch training') 66 | parser.add_argument('--warmup-from', type=float, default=0.01) 67 | parser.add_argument('--warm-epochs', type=int, default=10) 68 | 
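# distributed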
parser.add_argument('--multiprocessing-distributed', action='store_true', default=False) 69 | # model dataset 70 | parser.add_argument('--dataset', type=str, 71 | choices=['cifar10', 'cifar100', 'imagenet']) 72 | # data augment policy 73 | parser.add_argument('--augment-policy', type=str, default='sim', 74 | choices=['no', 'sim', 'auto', 'rand'], 75 | help='data augmentation policy') 76 | # random augment 77 | parser.add_argument('--rand-n', type=int, default=1, 78 | help='# of random augment') 79 | parser.add_argument('--rand-m', type=int, default=2, 80 | help='magnitude of random augment') 81 | parser.add_argument('--cutout', action='store_true', default=False, 82 | help='perform cutout') 83 | parser.add_argument('--n-holes', type=int, default=1, 84 | help='# of cutout holes') 85 | parser.add_argument('--cutout-length', type=int, default=16, 86 | help='length of a cutout hole') 87 | parser.add_argument('--mix', type=str, default=None, 88 | choices=['mixup', 'cutmix', 'mixup_cutmix'], 89 | help='image-based regularizations') 90 | parser.add_argument('--mix-alpha', type=float, default=1.0, 91 | help='alpha for mixup/cutmix beta distribution') 92 | parser.add_argument('--seed', type=int, default=3407) 93 | parser.add_argument('--debug', action='store_true', default=False, 94 | help='debug: train 1 epoch') 95 | # wandb 96 | parser.add_argument('--wandb', type=str2bool, default=True, 97 | help='use wandb for visualization') 98 | parser.add_argument('--wandb-entity', type=str, default=None, 99 | help='your wandb id') 100 | parser.add_argument('--wandb-project', type=str, default=None, 101 | help='wandb project name') 102 | return parser 103 | 104 | def linear_parser(): 105 | parser = argparse.ArgumentParser('argument for linear finetuning') 106 | 107 | parser.add_argument('--desc', type=str, default=None, 108 | help='experiment name') 109 | parser.add_argument('--model', type=str, default='resnet50', 110 | help='model kind') 111 | # model config 112 | parser.add_argument('--pretrained', type=str, 113 | help='pretrained encoder to load') 114 | parser.add_argument('--pretrained-ckpt', type=str, default='ckpt_last.pth', 115 | help='pretrained encoder checkpoint') 116 | parser.add_argument('--label-smoothing', type=float, default=0., 117 | help='label smoothing for cross-entropy loss') 118 | parser.add_argument('--print-freq', type=int, default=10, 119 | help='print frequency') 120 | parser.add_argument('--save-freq', type=int, default=20, 121 | help='save frequency') 122 | parser.add_argument('--save-root', type=str, default='./saves/', 123 | help='root directory of results') 124 | parser.add_argument('--batch-size', type=int, default=512, 125 | help='batch_size') 126 | parser.add_argument('--num-workers', type=int, default=4, 127 | help='num of workers to use') 128 | parser.add_argument('--epochs', type=int, default=100, 129 | help='number of training epochs') 130 | 131 | # optimization 132 | parser.add_argument('--optim-kind', type=str, default='SGD', 133 | choices=['SGD', 'RMSProp', 'Adam', 'AdamW'], 134 | help='kind of optimizer') 135 | parser.add_argument('--LARS', action='store_true', default=False) 136 | parser.add_argument('--learning-rate', type=float, default=5, 137 | help='learning rate') 138 | parser.add_argument('--cosine', action='store_true', 139 | help='using cosine annealing') 140 | parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[60,75,90], 141 | help='where to decay lr, can be a list') 142 | parser.add_argument('--lr_decay_rate', type=float, 
default=0.2, 143 | help='decay rate for learning rate') 144 | parser.add_argument('--weight_decay', type=float, default=0, 145 | help='weight decay') 146 | parser.add_argument('--momentum', type=float, default=0.9, 147 | help='momentum') 148 | parser.add_argument('--amp', action='store_true', default=False, 149 | help='use automatic mixed precision') 150 | parser.add_argument('--multiprocessing-distributed', action='store_true', default=False) 151 | # warmup 152 | parser.add_argument('--warm', type=str2bool, default=True, 153 | help='warm-up for large batch training') 154 | parser.add_argument('--warmup-from', type=float, default=1e-5) 155 | parser.add_argument('--warm-epochs', type=int, default=5) 156 | # model dataset 157 | parser.add_argument('--dataset', type=str, 158 | choices=['cifar10', 'cifar100', 'imagenet']) 159 | # data augment policy 160 | parser.add_argument('--augment-policy', type=str, default='sim', 161 | choices=['no', 'sim', 'auto', 'rand'], 162 | help='data augmentation policy') 163 | # random augment 164 | parser.add_argument('--rand-n', type=int, default=1, 165 | help='# of random augment') 166 | parser.add_argument('--rand-m', type=int, default=2, 167 | help='magnitude of random augment') 168 | # erasing 169 | parser.add_argument('--erasing', action='store_true', default=False, 170 | help='perform erasing regularization') 171 | parser.add_argument('--erasing-p', type=float, default=0.5, 172 | help='erasing probability') 173 | # cutout 174 | parser.add_argument('--cutout', action='store_true', default=False, 175 | help='perform cutout') 176 | parser.add_argument('--n-holes', type=int, default=1, 177 | help='# of cutout holes') 178 | parser.add_argument('--cutout-length', type=int, default=16, 179 | help='length of a cutout hole') 180 | parser.add_argument('--mix', type=str, default=None, 181 | choices=['mixup', 'cutmix', 'mixup_cutmix'], 182 | help='image-based regularizations') 183 | parser.add_argument('--mix-alpha', type=float, default=1.0, 184 | help='alpha for mixup/cutmix beta distribution') 185 | 186 | parser.add_argument('--seed', type=int, default=3407) 187 | parser.add_argument('--debug', action='store_true', default=False, 188 | help='debug: train 1 epoch') 189 | # wandb 190 | parser.add_argument('--wandb', type=str2bool, default=True, 191 | help='use wandb for visualization') 192 | parser.add_argument('--wandb-entity', type=str, default=None, 193 | help='your wandb id') 194 | parser.add_argument('--wandb-project', type=str, default=None, 195 | help='wandb project name') 196 | return parser -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import datetime 3 | from copy import deepcopy 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torchlars import LARS 9 | import wandb 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def set_optimizer(model, args): 30 | if args.optim_kind == 'SGD': 31 | optimizer = optim.SGD(model.parameters(), args.learning_rate, 32 | momentum=args.momentum, 
weight_decay=args.weight_decay) 33 | elif args.optim_kind == 'RMSProp': 34 | optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate, 35 | weight_decay=args.weight_decay, momentum=args.momentum) 36 | elif args.optim_kind == 'Adam': 37 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, 38 | betas=(0.9, 0.999), weight_decay=args.weight_decay) 39 | elif args.optim_kind == 'AdamW': 40 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, 41 | weight_decay=args.weight_decay) 42 | if args.LARS: 43 | optimizer = LARS(optimizer) 44 | return optimizer 45 | 46 | 47 | def accuracy(output, target): 48 | with torch.no_grad(): 49 | bsz = target.shape[0] 50 | pred = torch.argmax(output, dim=1) 51 | acc = 100 * (pred == target).sum() / bsz 52 | return acc.item() 53 | 54 | def adjust_learning_rate(args, optimizer, epoch): 55 | lr = args.learning_rate 56 | if args.cosine: 57 | eta_min = lr * (args.lr_decay_rate ** 3) 58 | lr = eta_min + (lr - eta_min) * ( 59 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 60 | else: 61 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 62 | if steps > 0: 63 | lr = lr * (args.lr_decay_rate ** steps) 64 | 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = lr 67 | 68 | 69 | def get_learning_rate(optimizer): 70 | for param_group in optimizer.param_groups: 71 | return param_group['lr'] 72 | 73 | 74 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 75 | if args.warm and epoch <= args.warm_epochs: 76 | eta_min = args.learning_rate * (args.lr_decay_rate ** 3) 77 | warmup_to = eta_min + (args.learning_rate - eta_min) * ( 78 | 1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2 79 | 80 | p = (batch_id + (epoch - 1) * total_batches) / \ 81 | (args.warm_epochs * total_batches) 82 | lr = args.warmup_from + p * (warmup_to - args.warmup_from) 83 | 84 | for param_group in optimizer.param_groups: 85 | param_group['lr'] = lr 86 | 87 | 88 | def save_model(model, optimizer, args, epoch, save_file): 89 | print(f'==> Saving {save_file}...') 90 | state = { 91 | 'args': args, 92 | 'state_dict': model.state_dict(), 93 | 'optimizer': optimizer.state_dict(), 94 | 'epoch': epoch, 95 | } 96 | if not save_file.parent.exists(): 97 | save_file.parent.mkdir(parents=True, exist_ok=True) 98 | torch.save(state, save_file) 99 | del state 100 | 101 | def seed(seed=1): 102 | """ 103 | Seed for PyTorch reproducibility. 104 | Arguments: 105 | seed (int): Random seed value. 106 | """ 107 | np.random.seed(seed) 108 | torch.manual_seed(seed) 109 | torch.cuda.manual_seed_all(seed) 110 | 111 | 112 | def init_wandb(args): 113 | wandb.init( 114 | entity=args.wandb_entity, 115 | project=args.wandb_project, 116 | name=args.desc, 117 | config=args, 118 | ) 119 | wandb.run.save() 120 | return wandb.config 121 | 122 | 123 | def format_time(elapsed): 124 | """ 125 | Format time for displaying. 126 | Arguments: 127 | elapsed: time interval in seconds. 
128 | """ 129 | elapsed_rounded = int(round((elapsed))) 130 | return str(datetime.timedelta(seconds=elapsed_rounded)) 131 | 132 | 133 | class ModelEmaV2(nn.Module): 134 | def __init__(self, model, decay=0.9999, device=None): 135 | super(ModelEmaV2, self).__init__() 136 | # make a copy of the model for accumulating moving average of weights 137 | self.module = deepcopy(model) 138 | self.module.eval() 139 | self.decay = decay 140 | self.device = device # perform ema on different device from model if set 141 | if self.device is not None: 142 | self.module.to(device=device) 143 | 144 | def _update(self, model, update_fn): 145 | with torch.no_grad(): 146 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 147 | if self.device is not None: 148 | model_v = model_v.to(device=self.device) 149 | ema_v.copy_(update_fn(ema_v, model_v)) 150 | 151 | def update(self, model): 152 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 153 | 154 | def set(self, model): 155 | self._update(model, update_fn=lambda e, m: m) --------------------------------------------------------------------------------
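The `ModelEmaV2` helper above is not referenced by the training scripts in this repository, though the released teacher checkpoints are tagged `_ema`, suggesting it was used when training them. A minimal, hypothetical usage sketch (toy model and hyperparameters are illustrative only; assumes the repo's dependencies are installed):

```python
import torch
import torch.nn as nn
from utils import ModelEmaV2

model = nn.Linear(10, 2)              # stand-in for a real network
ema = ModelEmaV2(model, decay=0.999)  # holds a frozen, averaged copy

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(100):
    x = torch.randn(32, 10)
    y = torch.randint(0, 2, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                 # blend: ema = decay * ema + (1 - decay) * model

# evaluate or checkpoint the smoothed weights via ema.module
```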