├── README.md
├── lincls.py
├── sogclr
│   ├── __init__.py
│   ├── builder.py
│   ├── folder.py
│   ├── loader.py
│   └── optimizer.py
└── train.py

/README.md:
--------------------------------------------------------------------------------

# SogCLR PyTorch Implementation

In this repo, we show how to train a self-supervised model by using [Global Contrastive Loss](https://arxiv.org/abs/2202.12387) (GCL) on [ImageNet](https://image-net.org/). The original GCL was implemented in TensorFlow and run on TPUs [here](https://github.com/Optimization-AI/SogCLR/tree/Tensorflow). This repo **re-implements** GCL in PyTorch based on [moco's](https://github.com/facebookresearch/moco) codebase. We recommend running this codebase in a GPU-enabled environment, such as [Google Cloud](https://cloud.google.com/) or [AWS](https://aws.amazon.com/).

## What's new
- 2023.03.05 Fixed `RuntimeError` related to variable `u`
- 2023.03.05 Fixed `AttributeError` related to `margin`

## Installation

#### git clone
```bash
git clone https://github.com/Optimization-AI/SogCLR.git
```

### Training
Below is an example of self-supervised pre-training of a ResNet-50 model on ImageNet on a 4-GPU server. By default, we use sqrt learning rate scaling, i.e., $\text{LearningRate}=0.075\times\sqrt{\text{BatchSize}}$, the [LARS](https://arxiv.org/abs/1708.03888) optimizer, and a weight decay of 1e-6. For the temperature parameter $\tau$, we use a fixed value of $0.1$ from [SimCLR](https://arxiv.org/pdf/2002.05709.pdf). For GCL, gamma (γ in the paper) is an additional parameter that controls the moving-average estimator. The default value is $0.9$; we recommend tuning it in the range $[0.1, 0.99]$ for better performance.


**ImageNet1K**

We use a batch size of 256 and pretrain ResNet-50 for 800 epochs. You can also increase the number of workers to speed up training.

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
--lr=.075 --epochs=800 --batch-size=256 \
--learning-rate-scaling=sqrt \
--loss_type dcl \
--gamma 0.9 \
--multiprocessing-distributed --world-size 1 --rank 0 --workers 32 \
--crop-min=.08 \
--wd=1e-6 \
--dist-url 'tcp://localhost:10001' \
--data_name imagenet1000 \
--data /your-data-path/imagenet1000/ \
--save_dir /your-data-path/saved_models/ \
--print-freq 1000
```


**ImageNet100**

We also experimented with a smaller version of ImageNet-1K: ImageNet-100, a subset of 100 randomly selected classes out of the original 1000 classes. To construct the dataset, please follow these steps:
* Download the train and validation datasets from the [ImageNet1K](https://image-net.org/challenges/LSVRC/2012/) website
* Run this [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) to move each validation image into its category (class) folder
* Copy the images listed in [train/val.txt](https://github.com/Optimization-AI/SogCLR/blob/main/dataset/ImageNet-S/train.txt) to generate the ImageNet-100 splits

We use a batch size of 256 and pretrain ResNet-50 for 400 epochs.
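As a sanity check on the learning rate: under the default sqrt scaling, `--lr=.075` at batch size 256 gives an effective learning rate of $0.075\times\sqrt{256}=1.2$ for both recipes. The sketch below mirrors the scaling logic in `train.py`; the `effective_lr` helper is our illustration and not part of the repo:

```python
import math

def effective_lr(base_lr: float, batch_size: int, scaling: str = "sqrt") -> float:
    # mirrors train.py: sqrt scaling multiplies by sqrt(batch_size),
    # linear scaling multiplies by batch_size / 256
    if scaling == "sqrt":
        return base_lr * math.sqrt(batch_size)
    return base_lr * batch_size / 256

print(effective_lr(0.075, 256))  # 1.2
```

The full command: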
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
--lr=.075 --epochs=400 --batch-size=256 \
--learning-rate-scaling=sqrt \
--loss_type dcl \
--gamma 0.9 \
--multiprocessing-distributed --world-size 1 --rank 0 --workers 32 \
--crop-min=.08 \
--wd=1e-6 \
--dist-url 'tcp://localhost:10001' \
--data_name imagenet100 \
--data /your-data-path/imagenet100/ \
--save_dir /your-data-path/saved_models/ \
--print-freq 1000
```

### Linear evaluation
By default, we use momentum SGD without weight decay and a batch size of 1024 for linear evaluation on frozen features/weights. This stage re-trains the linear classifier for 90 epochs.

**ImageNet**

```bash
python lincls.py \
--dist-url 'tcp://localhost:10001' \
--multiprocessing-distributed --world-size 1 --rank 0 --workers 32 \
--pretrained /your-data-path/checkpoint_0799.pth.tar \
--data_name imagenet1000 \
--data /your-data-path/imagenet1000/ \
--save_dir /your-data-path/saved_models/
```

## Benchmarks

The following are linear evaluation results on the **ImageNet-1K** validation set:

| Method | BatchSize | Epoch | Linear eval. |
|:----------:|:--------:|:--------:|:--------:|
| SimCLR (TF[^2]) | 256 | 800 | 66.5 |
| SogCLR (PT[^1]) | 256 | 800 | 69.0 |
| SogCLR (TF[^2]) | 256 | 800 | 69.3 |

*SogCLR (PT[^1]): the pre-trained ResNet-50 checkpoint and the linear evaluation training log can be downloaded here: [[checkpoint_0799.pth.tar](https://drive.google.com/file/d/1baWWT6Xf9ylLHimWXZuhvdiKUkkMLB0_/view?usp=sharing) | [linear_eval.txt](https://drive.google.com/file/d/1O2N90Ffk0Oz6dXek_MhEVgXzszaogfvy/view?usp=sharing)]


The following are linear evaluation results on the **ImageNet-100** validation set:

| Method | BatchSize | Epoch | Linear eval. |
|:----------:|:--------:|:--------:|:--------:|
| SimCLR (TF[^2]) | 256 | 400 | 76.1 |
| SogCLR (PT[^1]) | 256 | 400 | 80.0 |
| SogCLR (TF[^2]) | 256 | 400 | 78.7 |

[^1]: PyTorch (PT) is based on [MoCo's](https://github.com/facebookresearch/moco) codebase.
[^2]: TensorFlow (TF) is based on [SimCLR's](https://github.com/google-research/simclr/tree/master/tf2) codebase.

The following results compare SogCLR and SimCLR using different batch sizes for 800-epoch pretraining on ImageNet-1K.


### Reference
If you find this tutorial helpful, please cite our paper:
```
@inproceedings{yuan2022provable,
  title={Provable stochastic optimization for global contrastive learning: Small batch does not harm performance},
  author={Yuan, Zhuoning and Wu, Yuexin and Qiu, Zi-Hao and Du, Xianzhi and Zhang, Lijun and Zhou, Denny and Yang, Tianbao},
  booktitle={International Conference on Machine Learning},
  pages={25760--25782},
  year={2022},
  organization={PMLR}
}
```

--------------------------------------------------------------------------------
/lincls.py:
--------------------------------------------------------------------------------

#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
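# Linear evaluation protocol: train a linear classifier on top of frozen,
# pre-trained features (per the README: momentum SGD without weight decay,
# batch size 1024, 90 epochs).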
8 | 9 | import argparse 10 | import builtins 11 | import math 12 | import os 13 | import random 14 | import shutil 15 | import time 16 | import numpy as np 17 | import warnings 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.parallel 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | import torch.optim 25 | import torch.multiprocessing as mp 26 | import torch.utils.data 27 | import torch.utils.data.distributed 28 | import torchvision.transforms as transforms 29 | import torchvision.datasets as datasets 30 | import torchvision.models as torchvision_models 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | torchvision_model_names = sorted(name for name in torchvision_models.__dict__ 34 | if name.islower() and not name.startswith("__") 35 | and callable(torchvision_models.__dict__[name])) 36 | 37 | model_names = torchvision_model_names 38 | 39 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 40 | parser.add_argument('--data', metavar='DIR', default='/Users/zhuoning/Experiment/ICML2023/imagenet100/', 41 | help='path to dataset') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 43 | choices=model_names, 44 | help='model architecture: ' + 45 | ' | '.join(model_names) + 46 | ' (default: resnet50)') 47 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 48 | help='number of data loading workers (default: 32)') 49 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 50 | help='number of total epochs to run') 51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 52 | help='manual epoch number (useful on restarts)') 53 | parser.add_argument('-b', '--batch-size', default=1024, type=int, 54 | metavar='N', 55 | help='mini-batch size (default: 1024), this is the total ' 56 | 'batch size of all GPUs on all nodes when ' 57 | 'using Data Parallel or Distributed Data Parallel') 58 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 59 | metavar='LR', help='initial (base) learning rate', dest='lr') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 63 | metavar='W', help='weight decay (default: 0.)', 64 | dest='weight_decay') 65 | parser.add_argument('-p', '--print-freq', default=10, type=int, 66 | metavar='N', help='print frequency (default: 10)') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--world-size', default=-1, type=int, 72 | help='number of nodes for distributed training') 73 | parser.add_argument('--rank', default=-1, type=int, 74 | help='node rank for distributed training') 75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 76 | help='url used to set up distributed training') 77 | parser.add_argument('--dist-backend', default='nccl', type=str, 78 | help='distributed backend') 79 | parser.add_argument('--seed', default=None, type=int, 80 | help='seed for initializing training. 
') 81 | parser.add_argument('--gpu', default=None, type=int, 82 | help='GPU id to use.') 83 | parser.add_argument('--multiprocessing-distributed', action='store_true', 84 | help='Use multi-processing distributed training to launch ' 85 | 'N processes per node, which has N GPUs. This is the ' 86 | 'fastest way to use PyTorch for either single node or ' 87 | 'multi node data parallel training') 88 | 89 | # additional configs: 90 | parser.add_argument('--pretrained', default='', type=str, 91 | help='path to sogclr pretrained checkpoint') 92 | 93 | # dataset 94 | parser.add_argument('--data_name', default='imagenet1000', type=str) 95 | parser.add_argument('--save_dir', default='./saved_models/', type=str) 96 | 97 | 98 | best_acc1 = 0 99 | 100 | def set_all_seeds(SEED): 101 | # REPRODUCIBILITY 102 | torch.manual_seed(SEED) 103 | np.random.seed(SEED) 104 | torch.backends.cudnn.deterministic = True 105 | torch.backends.cudnn.benchmark = False 106 | 107 | def main(): 108 | args = parser.parse_args() 109 | 110 | if args.seed is not None: 111 | random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | cudnn.deterministic = True 114 | warnings.warn('You have chosen to seed training. ' 115 | 'This will turn on the CUDNN deterministic setting, ' 116 | 'which can slow down your training considerably! ' 117 | 'You may see unexpected behavior when restarting ' 118 | 'from checkpoints.') 119 | 120 | if args.gpu is not None: 121 | warnings.warn('You have chosen a specific GPU. This will completely ' 122 | 'disable data parallelism.') 123 | 124 | if args.dist_url == "env://" and args.world_size == -1: 125 | args.world_size = int(os.environ["WORLD_SIZE"]) 126 | 127 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 128 | 129 | ngpus_per_node = torch.cuda.device_count() 130 | if args.multiprocessing_distributed: 131 | # Since we have ngpus_per_node processes per node, the total world_size 132 | # needs to be adjusted accordingly 133 | args.world_size = ngpus_per_node * args.world_size 134 | # Use torch.multiprocessing.spawn to launch distributed processes: the 135 | # main_worker process function 136 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 137 | else: 138 | # Simply call main_worker function 139 | main_worker(args.gpu, ngpus_per_node, args) 140 | 141 | 142 | def main_worker(gpu, ngpus_per_node, args): 143 | global best_acc1 144 | args.gpu = gpu 145 | 146 | # suppress printing if not master 147 | if args.multiprocessing_distributed and args.gpu != 0: 148 | def print_pass(*args): 149 | pass 150 | builtins.print = print_pass 151 | 152 | if args.gpu is not None: 153 | print("Use GPU: {} for training".format(args.gpu)) 154 | 155 | if args.distributed: 156 | if args.dist_url == "env://" and args.rank == -1: 157 | args.rank = int(os.environ["RANK"]) 158 | if args.multiprocessing_distributed: 159 | # For multiprocessing distributed training, rank needs to be the 160 | # global rank among all the processes 161 | args.rank = args.rank * ngpus_per_node + gpu 162 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 163 | world_size=args.world_size, rank=args.rank) 164 | torch.distributed.barrier() 165 | 166 | # create model 167 | set_all_seeds(123) 168 | print("=> creating model '{}'".format(args.arch)) 169 | if args.arch.startswith('vit'): 170 | model = sogclr.vits.__dict__[args.arch]() 171 | linear_keyword = 'head' 172 | else: 173 | model = torchvision_models.__dict__[args.arch]() 174 | linear_keyword = 'fc' 175 | 176 | if 
args.data_name == 'imagenet100': 177 | num_classes = 100 178 | elif args.data_name == 'imagenet1000': 179 | num_classes = 1000 180 | else: 181 | return 182 | print ('Dataset: %s' %args.data_name) 183 | 184 | 185 | # remove original fc and add fc with customized num_classes 186 | hidden_dim = model.fc.weight.shape[1] 187 | del model.fc # remove original fc layer 188 | model.fc = nn.Linear(hidden_dim, num_classes, bias=True) 189 | print (model) 190 | 191 | # freeze all layers but the last fc 192 | for name, param in model.named_parameters(): 193 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]: 194 | param.requires_grad = False 195 | 196 | # init the fc layer 197 | getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) 198 | getattr(model, linear_keyword).bias.data.zero_() 199 | 200 | # load from pre-trained, before DistributedDataParallel constructor 201 | if args.pretrained: 202 | if os.path.isfile(args.pretrained): 203 | print("=> loading checkpoint '{}'".format(args.pretrained)) 204 | checkpoint = torch.load(args.pretrained, map_location="cpu") 205 | 206 | # rename sogclr pre-trained keys 207 | state_dict = checkpoint['state_dict'] 208 | for k in list(state_dict.keys()): 209 | # retain only base_encoder up to before the embedding layer 210 | if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.%s' % linear_keyword): 211 | # remove prefix 212 | state_dict[k[len("module.base_encoder."):]] = state_dict[k] 213 | # delete renamed or unused k 214 | del state_dict[k] 215 | 216 | args.start_epoch = 0 217 | msg = model.load_state_dict(state_dict, strict=False) 218 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 219 | 220 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 221 | else: 222 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 223 | 224 | # infer learning rate before changing batch size 225 | init_lr = args.lr * args.batch_size / 256 226 | 227 | if not torch.cuda.is_available(): 228 | print('using CPU, this will be slow') 229 | elif args.distributed: 230 | # For multiprocessing distributed, DistributedDataParallel constructor 231 | # should always set the single device scope, otherwise, 232 | # DistributedDataParallel will use all available devices. 
233 | if args.gpu is not None: 234 | torch.cuda.set_device(args.gpu) 235 | model.cuda(args.gpu) 236 | # When using a single GPU per process and per 237 | # DistributedDataParallel, we need to divide the batch size 238 | # ourselves based on the total number of GPUs we have 239 | args.batch_size = int(args.batch_size / args.world_size) 240 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 241 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 242 | else: 243 | model.cuda() 244 | # DistributedDataParallel will divide and allocate batch_size to all 245 | # available GPUs if device_ids are not set 246 | model = torch.nn.parallel.DistributedDataParallel(model) 247 | elif args.gpu is not None: 248 | torch.cuda.set_device(args.gpu) 249 | model = model.cuda(args.gpu) 250 | else: 251 | # DataParallel will divide and allocate batch_size to all available GPUs 252 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 253 | model.features = torch.nn.DataParallel(model.features) 254 | model.cuda() 255 | else: 256 | model = torch.nn.DataParallel(model).cuda() 257 | 258 | # define loss function (criterion) and optimizer 259 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 260 | 261 | # optimize only the linear classifier 262 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 263 | assert len(parameters) == 2 # weight, bias 264 | 265 | optimizer = torch.optim.SGD(parameters, init_lr, 266 | momentum=args.momentum, 267 | weight_decay=args.weight_decay) 268 | 269 | 270 | # NEW! 271 | # save log 272 | save_root_path = args.save_dir 273 | global_batch_size = args.batch_size*args.world_size 274 | # check this if checkpoint is located in the current directory 275 | fold_name = args.pretrained.split('/')[-2] 276 | logdir = 'linear_eval_%s'%(fold_name) 277 | summary_writer = SummaryWriter(log_dir=os.path.join(save_root_path, logdir)) if args.rank == 0 else None 278 | print (logdir) 279 | 280 | 281 | # optionally resume from a checkpoint 282 | if args.resume: 283 | if os.path.isfile(args.resume): 284 | print("=> loading checkpoint '{}'".format(args.resume)) 285 | if args.gpu is None: 286 | checkpoint = torch.load(args.resume) 287 | else: 288 | # Map model to be loaded to specified single gpu. 
289 | loc = 'cuda:{}'.format(args.gpu) 290 | checkpoint = torch.load(args.resume, map_location=loc) 291 | args.start_epoch = checkpoint['epoch'] 292 | best_acc1 = checkpoint['best_acc1'] 293 | if args.gpu is not None: 294 | # best_acc1 may be from a checkpoint from a different GPU 295 | best_acc1 = best_acc1.to(args.gpu) 296 | model.load_state_dict(checkpoint['state_dict']) 297 | optimizer.load_state_dict(checkpoint['optimizer']) 298 | print("=> loaded checkpoint '{}' (epoch {})" 299 | .format(args.resume, checkpoint['epoch'])) 300 | else: 301 | print("=> no checkpoint found at '{}'".format(args.resume)) 302 | 303 | cudnn.benchmark = True 304 | 305 | 306 | # Data loading code 307 | mean = {'imagenet100': [0.485, 0.456, 0.406], 308 | 'imagenet1000': [0.485, 0.456, 0.406], 309 | }[args.data_name] 310 | std = {'imagenet100': [0.229, 0.224, 0.225], 311 | 'imagenet1000': [0.229, 0.224, 0.225], 312 | }[args.data_name] 313 | 314 | 315 | image_size = {'imagenet100':224, 'imagenet1000':224}[args.data_name] 316 | normalize = transforms.Normalize(mean=mean, std=std) 317 | 318 | if args.data_name == 'imagenet1000' or args.data_name == 'imagenet100' : 319 | traindir = os.path.join(args.data, 'train') 320 | valdir = os.path.join(args.data, 'val') 321 | train_dataset = datasets.ImageFolder( 322 | traindir, 323 | transforms.Compose([ 324 | transforms.RandomResizedCrop(224), 325 | transforms.RandomHorizontalFlip(), 326 | transforms.ToTensor(), 327 | normalize, 328 | ])) 329 | else: 330 | raise ValueError 331 | 332 | if args.distributed: 333 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 334 | else: 335 | train_sampler = None 336 | 337 | train_loader = torch.utils.data.DataLoader( 338 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 339 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 340 | 341 | 342 | 343 | # validation 344 | if args.data_name == 'imagenet1000' or args.data_name == 'imagenet100' : 345 | 346 | val_loader = torch.utils.data.DataLoader( 347 | datasets.ImageFolder(valdir, transforms.Compose([ 348 | transforms.Resize(256), 349 | transforms.CenterCrop(224), 350 | transforms.ToTensor(), 351 | normalize, 352 | ])), 353 | batch_size=256, shuffle=False, 354 | num_workers=args.workers, pin_memory=True) 355 | 356 | else: 357 | raise ValueError 358 | 359 | 360 | if args.evaluate: 361 | validate(val_loader, model, criterion, args) 362 | return 363 | 364 | for epoch in range(args.start_epoch, args.epochs): 365 | if args.distributed: 366 | train_sampler.set_epoch(epoch) 367 | adjust_learning_rate(optimizer, init_lr, epoch, args) 368 | 369 | # train for one epoch 370 | train(train_loader, model, criterion, optimizer, epoch, args) 371 | 372 | # evaluate on validation set 373 | acc1 = validate(val_loader, model, criterion, args) 374 | 375 | # remember best acc@1 and save checkpoint 376 | is_best = acc1 > best_acc1 377 | best_acc1 = max(acc1, best_acc1) 378 | print (' * Best Acc@1:%.3f'%best_acc1) 379 | 380 | #if not args.multiprocessing_distributed or (args.multiprocessing_distributed 381 | # and args.rank == 0): # only the first GPU saves checkpoint 382 | # save_checkpoint({ 383 | # 'epoch': epoch + 1, 384 | # 'arch': args.arch, 385 | # 'state_dict': model.state_dict(), 386 | # 'best_acc1': best_acc1, 387 | # 'optimizer' : optimizer.state_dict(), 388 | # }, is_best, filename=os.path.join(save_root_path, logdir, 'checkpoint.pth.tar'), save_path=os.path.join(save_root_path, logdir)) 389 | # if epoch == args.start_epoch: 390 | # 
sanity_check(model.state_dict(), args.pretrained, linear_keyword) 391 | 392 | 393 | def train(train_loader, model, criterion, optimizer, epoch, args): 394 | batch_time = AverageMeter('Time', ':6.3f') 395 | data_time = AverageMeter('Data', ':6.3f') 396 | losses = AverageMeter('Loss', ':.4e') 397 | top1 = AverageMeter('Acc@1', ':6.2f') 398 | top5 = AverageMeter('Acc@5', ':6.2f') 399 | progress = ProgressMeter( 400 | len(train_loader), 401 | [batch_time, data_time, losses, top1, top5], 402 | prefix="Epoch: [{}]".format(epoch)) 403 | 404 | """ 405 | Switch to eval mode: 406 | Under the protocol of linear classification on frozen features/models, 407 | it is not legitimate to change any part of the pre-trained model. 408 | BatchNorm in train mode may revise running mean/std (even if it receives 409 | no gradient), which are part of the model parameters too. 410 | """ 411 | model.eval() 412 | 413 | end = time.time() 414 | for i, (images, target) in enumerate(train_loader): 415 | # measure data loading time 416 | data_time.update(time.time() - end) 417 | 418 | if args.gpu is not None: 419 | images = images.cuda(args.gpu, non_blocking=True) 420 | if torch.cuda.is_available(): 421 | target = target.cuda(args.gpu, non_blocking=True) 422 | 423 | # compute output 424 | output = model(images)[0] # after modifying resnet forward #torch.Size([1024, 1000]) torch.Size([1024, 2048]) 425 | loss = criterion(output, target) 426 | 427 | # measure accuracy and record loss 428 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 429 | losses.update(loss.item(), images.size(0)) 430 | top1.update(acc1[0], images.size(0)) 431 | top5.update(acc5[0], images.size(0)) 432 | 433 | # compute gradient and do SGD step 434 | optimizer.zero_grad() 435 | loss.backward() 436 | optimizer.step() 437 | 438 | # measure elapsed time 439 | batch_time.update(time.time() - end) 440 | end = time.time() 441 | 442 | if i % args.print_freq == 0: 443 | progress.display(i) 444 | 445 | 446 | def validate(val_loader, model, criterion, args): 447 | batch_time = AverageMeter('Time', ':6.3f') 448 | losses = AverageMeter('Loss', ':.4e') 449 | top1 = AverageMeter('Acc@1', ':6.2f') 450 | top5 = AverageMeter('Acc@5', ':6.2f') 451 | progress = ProgressMeter( 452 | len(val_loader), 453 | [batch_time, losses, top1, top5], 454 | prefix='Test: ') 455 | 456 | # switch to evaluate mode 457 | model.eval() 458 | 459 | with torch.no_grad(): 460 | end = time.time() 461 | for i, (images, target) in enumerate(val_loader): 462 | if args.gpu is not None: 463 | images = images.cuda(args.gpu, non_blocking=True) 464 | if torch.cuda.is_available(): 465 | target = target.cuda(args.gpu, non_blocking=True) 466 | 467 | # compute output 468 | output = model(images)[0] 469 | loss = criterion(output, target) 470 | 471 | # measure accuracy and record loss 472 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 473 | losses.update(loss.item(), images.size(0)) 474 | top1.update(acc1[0], images.size(0)) 475 | top5.update(acc5[0], images.size(0)) 476 | 477 | # measure elapsed time 478 | batch_time.update(time.time() - end) 479 | end = time.time() 480 | 481 | if i % args.print_freq == 0: 482 | progress.display(i) 483 | 484 | # TODO: this should also be done with the ProgressMeter 485 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 486 | .format(top1=top1, top5=top5)) 487 | 488 | return top1.avg 489 | 490 | 491 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path='./'): 492 | torch.save(state, filename) 493 | if is_best: 494 | 
shutil.copyfile(filename, os.path.join(save_path, 'model_best.pth.tar')) 495 | 496 | 497 | def sanity_check(state_dict, pretrained_weights, linear_keyword): 498 | """ 499 | Linear classifier should not change any weights other than the linear layer. 500 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 501 | """ 502 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 503 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 504 | state_dict_pre = checkpoint['state_dict'] 505 | 506 | for k in list(state_dict.keys()): 507 | # only ignore linear layer 508 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k: 509 | continue 510 | 511 | # name in pretrained model 512 | k_pre = 'module.base_encoder.' + k[len('module.'):] \ 513 | if k.startswith('module.') else 'module.base_encoder.' + k 514 | 515 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 516 | '{} is changed in linear classifier training.'.format(k) 517 | 518 | print("=> sanity check passed.") 519 | 520 | 521 | class AverageMeter(object): 522 | """Computes and stores the average and current value""" 523 | def __init__(self, name, fmt=':f'): 524 | self.name = name 525 | self.fmt = fmt 526 | self.reset() 527 | 528 | def reset(self): 529 | self.val = 0 530 | self.avg = 0 531 | self.sum = 0 532 | self.count = 0 533 | 534 | def update(self, val, n=1): 535 | self.val = val 536 | self.sum += val * n 537 | self.count += n 538 | self.avg = self.sum / self.count 539 | 540 | def __str__(self): 541 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 542 | return fmtstr.format(**self.__dict__) 543 | 544 | 545 | class ProgressMeter(object): 546 | def __init__(self, num_batches, meters, prefix=""): 547 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 548 | self.meters = meters 549 | self.prefix = prefix 550 | 551 | def display(self, batch): 552 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 553 | entries += [str(meter) for meter in self.meters] 554 | print('\t'.join(entries)) 555 | 556 | def _get_batch_fmtstr(self, num_batches): 557 | num_digits = len(str(num_batches // 1)) 558 | fmt = '{:' + str(num_digits) + 'd}' 559 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 560 | 561 | 562 | def adjust_learning_rate(optimizer, init_lr, epoch, args): 563 | """Decay the learning rate based on schedule""" 564 | cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 565 | for param_group in optimizer.param_groups: 566 | param_group['lr'] = cur_lr 567 | 568 | 569 | def accuracy(output, target, topk=(1,)): 570 | """Computes the accuracy over the k top predictions for the specified values of k""" 571 | with torch.no_grad(): 572 | maxk = max(topk) 573 | batch_size = target.size(0) 574 | 575 | _, pred = output.topk(maxk, 1, True, True) 576 | pred = pred.t() 577 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 578 | 579 | res = [] 580 | for k in topk: 581 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 582 | res.append(correct_k.mul_(100.0 / batch_size)) 583 | return res 584 | 585 | 586 | if __name__ == '__main__': 587 | main() 588 | -------------------------------------------------------------------------------- /sogclr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved

--------------------------------------------------------------------------------
/sogclr/builder.py:
--------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimCLR(nn.Module):
    """
    Build a SimCLR-based model with a base encoder and a non-linear projection head.
    """
    def __init__(self, base_encoder, dim=256, mlp_dim=2048, T=0.1, loss_type='dcl', N=50000, num_proj_layers=2, device=None):
        """
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 2048)
        T: softmax temperature (default: 0.1)
        loss_type: contrastive loss variant (default: 'dcl')
        N: number of training samples, used to size the moving-average estimator u
        """
        super(SimCLR, self).__init__()
        self.T = T
        self.N = N
        self.loss_type = loss_type

        # build encoders
        self.base_encoder = base_encoder(num_classes=mlp_dim)

        # build non-linear projection heads
        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        # sogclr
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # for DCL: per-sample moving-average estimator u, kept on CPU
        self.u = torch.zeros(N).reshape(-1, 1) #.to(self.device)
        self.LARGE_NUM = 1e9


    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        pass

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        mlp = []
        for l in range(num_layers):
            dim1 = input_dim if l == 0 else mlp_dim
            dim2 = output_dim if l == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if l < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def dynamic_contrastive_loss(self, hidden1, hidden2, index=None, gamma=0.9, distributed=True):
        # Get (normalized) hidden1 and hidden2.
        hidden1, hidden2 = F.normalize(hidden1, p=2, dim=1), F.normalize(hidden2, p=2, dim=1)
        batch_size = hidden1.shape[0]

        # Gather hidden1/hidden2 across replicas and create local labels.
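        # all_gather_layer (defined at the bottom of this file) supports
        # backpropagation, unlike concat_all_gather(), so gradients flow back
        # to the local shard of the gathered features.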
        if distributed:
            hidden1_large = torch.cat(all_gather_layer.apply(hidden1), dim=0)
            hidden2_large = torch.cat(all_gather_layer.apply(hidden2), dim=0)
            enlarged_batch_size = hidden1_large.shape[0]

            labels_idx = (torch.arange(batch_size, dtype=torch.long) + batch_size * torch.distributed.get_rank()).to(self.device)
            labels = F.one_hot(labels_idx, enlarged_batch_size*2).to(self.device)
            masks = F.one_hot(labels_idx, enlarged_batch_size).to(self.device)
            batch_size = enlarged_batch_size
        else:
            hidden1_large = hidden1
            hidden2_large = hidden2
            labels = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size * 2).to(self.device)
            masks = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size).to(self.device)

        logits_aa = torch.matmul(hidden1, hidden1_large.T)
        logits_aa = logits_aa - masks * self.LARGE_NUM  # mask out self-similarities
        logits_bb = torch.matmul(hidden2, hidden2_large.T)
        logits_bb = logits_bb - masks * self.LARGE_NUM
        logits_ab = torch.matmul(hidden1, hidden2_large.T)
        logits_ba = torch.matmul(hidden2, hidden1_large.T)

        # SogCLR
        neg_mask = 1 - labels
        logits_ab_aa = torch.cat([logits_ab, logits_aa], 1)
        logits_ba_bb = torch.cat([logits_ba, logits_bb], 1)

        neg_logits1 = torch.exp(logits_ab_aa / self.T) * neg_mask  # (B, 2B)
        neg_logits2 = torch.exp(logits_ba_bb / self.T) * neg_mask

        # initialize u on the first encounter of these samples (equivalent to gamma = 1)
        if self.u[index.cpu()].sum() == 0:
            gamma = 1

        # moving-average update of the per-sample estimator u
        u1 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True) / (2 * (batch_size - 1))
        u2 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True) / (2 * (batch_size - 1))

        # this syncs u on all devices (since "hidden" is gathered from all devices)
        if distributed:
            u1_large = concat_all_gather(u1)
            u2_large = concat_all_gather(u2)
            index_large = concat_all_gather(index)
            self.u[index_large.cpu()] = (u1_large.detach().cpu() + u2_large.detach().cpu()) / 2
        else:
            self.u[index.cpu()] = (u1.detach().cpu() + u2.detach().cpu()) / 2

        p_neg_weights1 = (neg_logits1 / u1).detach()
        p_neg_weights2 = (neg_logits2 / u2).detach()

        def softmax_cross_entropy_with_logits(labels, logits, weights):
            expsum_neg_logits = torch.sum(weights * logits, dim=1, keepdim=True) / (2 * (batch_size - 1))
            normalized_logits = logits - expsum_neg_logits
            return -torch.sum(labels * normalized_logits, dim=1)

        loss_a = softmax_cross_entropy_with_logits(labels, logits_ab_aa, p_neg_weights1)
        loss_b = softmax_cross_entropy_with_logits(labels, logits_ba_bb, p_neg_weights2)
        loss = (loss_a + loss_b).mean()

        return loss

    def forward(self, x1, x2, index, gamma):
        """
        Input:
            x1: first views of images
            x2: second views of images
            index: indices of the images in the dataset
            gamma: moving-average coefficient of SogCLR
        Output:
            loss
        """
        # compute features
        h1 = self.base_encoder(x1)
        h2 = self.base_encoder(x2)
        loss = self.dynamic_contrastive_loss(h1, h2, index, gamma)
        return loss


class SimCLR_ResNet(SimCLR):
    def _build_projector_and_predictor_mlps(self, dim, mlp_dim, num_proj_layers=2):
        hidden_dim = self.base_encoder.fc.weight.shape[1]
        del self.base_encoder.fc  # remove original fc layer

        # projectors
self.base_encoder.fc = self._build_mlp(num_proj_layers, hidden_dim, mlp_dim, dim) 157 | 158 | 159 | 160 | # utils 161 | @torch.no_grad() 162 | def concat_all_gather(tensor): 163 | """ 164 | Performs all_gather operation on the provided tensors. 165 | *** Warning ***: torch.distributed.all_gather has no gradient. 166 | """ 167 | tensors_gather = [torch.ones_like(tensor) 168 | for _ in range(torch.distributed.get_world_size())] 169 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 170 | 171 | output = torch.cat(tensors_gather, dim=0) 172 | return output 173 | 174 | 175 | class all_gather_layer(torch.autograd.Function): 176 | """Gather tensors from all process, supporting backward propagation.""" 177 | 178 | @staticmethod 179 | def forward(ctx, input): 180 | ctx.save_for_backward(input) 181 | output = [torch.zeros_like(input) for _ in range(torch.distributed.get_world_size())] 182 | torch.distributed.all_gather(output, input) 183 | return tuple(output) 184 | 185 | @staticmethod 186 | def backward(ctx, *grads): 187 | (input,) = ctx.saved_tensors 188 | grad_out = torch.zeros_like(input) 189 | grad_out[:] = grads[torch.distributed.get_rank()] 190 | return grad_out 191 | -------------------------------------------------------------------------------- /sogclr/folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union 4 | 5 | from PIL import Image 6 | 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool: 11 | """Checks if a file is an allowed extension. 12 | 13 | Args: 14 | filename (string): path to a file 15 | extensions (tuple of strings): extensions to consider (lowercase) 16 | 17 | Returns: 18 | bool: True if the filename ends with one of given extensions 19 | """ 20 | return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) 21 | 22 | 23 | def is_image_file(filename: str) -> bool: 24 | """Checks if a file is an allowed image extension. 25 | 26 | Args: 27 | filename (string): path to a file 28 | 29 | Returns: 30 | bool: True if the filename ends with a known image extension 31 | """ 32 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 33 | 34 | 35 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 36 | """Finds the class folders in a dataset. 37 | 38 | See :class:`DatasetFolder` for details. 39 | """ 40 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 41 | if not classes: 42 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 43 | 44 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 45 | return classes, class_to_idx 46 | 47 | 48 | def make_dataset( 49 | directory: str, 50 | class_to_idx: Optional[Dict[str, int]] = None, 51 | extensions: Optional[Union[str, Tuple[str, ...]]] = None, 52 | is_valid_file: Optional[Callable[[str], bool]] = None, 53 | ) -> List[Tuple[str, int]]: 54 | """Generates a list of samples of a form (path_to_sample, class). 55 | 56 | See :class:`DatasetFolder` for details. 57 | 58 | Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function 59 | by default. 
60 | """ 61 | directory = os.path.expanduser(directory) 62 | 63 | if class_to_idx is None: 64 | _, class_to_idx = find_classes(directory) 65 | elif not class_to_idx: 66 | raise ValueError("'class_to_index' must have at least one entry to collect any samples.") 67 | 68 | both_none = extensions is None and is_valid_file is None 69 | both_something = extensions is not None and is_valid_file is not None 70 | if both_none or both_something: 71 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 72 | 73 | if extensions is not None: 74 | 75 | def is_valid_file(x: str) -> bool: 76 | return has_file_allowed_extension(x, extensions) # type: ignore[arg-type] 77 | 78 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 79 | 80 | instances = [] 81 | available_classes = set() 82 | for target_class in sorted(class_to_idx.keys()): 83 | class_index = class_to_idx[target_class] 84 | target_dir = os.path.join(directory, target_class) 85 | if not os.path.isdir(target_dir): 86 | continue 87 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 88 | for fname in sorted(fnames): 89 | path = os.path.join(root, fname) 90 | if is_valid_file(path): 91 | item = path, class_index 92 | instances.append(item) 93 | 94 | if target_class not in available_classes: 95 | available_classes.add(target_class) 96 | 97 | empty_classes = set(class_to_idx.keys()) - available_classes 98 | if empty_classes: 99 | msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " 100 | if extensions is not None: 101 | msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" 102 | raise FileNotFoundError(msg) 103 | 104 | return instances 105 | 106 | 107 | class DatasetFolder(VisionDataset): 108 | """A generic data loader. 109 | 110 | This default directory structure can be customized by overriding the 111 | :meth:`find_classes` method. 112 | 113 | Args: 114 | root (string): Root directory path. 115 | loader (callable): A function to load a sample given its path. 116 | extensions (tuple[string]): A list of allowed extensions. 117 | both extensions and is_valid_file should not be passed. 118 | transform (callable, optional): A function/transform that takes in 119 | a sample and returns a transformed version. 120 | E.g, ``transforms.RandomCrop`` for images. 121 | target_transform (callable, optional): A function/transform that takes 122 | in the target and transforms it. 123 | is_valid_file (callable, optional): A function that takes path of a file 124 | and check if the file is a valid file (used to check of corrupt files) 125 | both extensions and is_valid_file should not be passed. 126 | 127 | Attributes: 128 | classes (list): List of the class names sorted alphabetically. 129 | class_to_idx (dict): Dict with items (class_name, class_index). 
130 | samples (list): List of (sample path, class_index) tuples 131 | targets (list): The class_index value for each image in the dataset 132 | """ 133 | 134 | def __init__( 135 | self, 136 | root: str, 137 | loader: Callable[[str], Any], 138 | extensions: Optional[Tuple[str, ...]] = None, 139 | transform: Optional[Callable] = None, 140 | target_transform: Optional[Callable] = None, 141 | is_valid_file: Optional[Callable[[str], bool]] = None, 142 | ) -> None: 143 | super().__init__(root, transform=transform, target_transform=target_transform) 144 | classes, class_to_idx = self.find_classes(self.root) 145 | samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) 146 | 147 | self.loader = loader 148 | self.extensions = extensions 149 | 150 | self.classes = classes 151 | self.class_to_idx = class_to_idx 152 | self.samples = samples 153 | self.targets = [s[1] for s in samples] 154 | 155 | @staticmethod 156 | def make_dataset( 157 | directory: str, 158 | class_to_idx: Dict[str, int], 159 | extensions: Optional[Tuple[str, ...]] = None, 160 | is_valid_file: Optional[Callable[[str], bool]] = None, 161 | ) -> List[Tuple[str, int]]: 162 | """Generates a list of samples of a form (path_to_sample, class). 163 | 164 | This can be overridden to e.g. read files from a compressed zip file instead of from the disk. 165 | 166 | Args: 167 | directory (str): root dataset directory, corresponding to ``self.root``. 168 | class_to_idx (Dict[str, int]): Dictionary mapping class name to class index. 169 | extensions (optional): A list of allowed extensions. 170 | Either extensions or is_valid_file should be passed. Defaults to None. 171 | is_valid_file (optional): A function that takes path of a file 172 | and checks if the file is a valid file 173 | (used to check of corrupt files) both extensions and 174 | is_valid_file should not be passed. Defaults to None. 175 | 176 | Raises: 177 | ValueError: In case ``class_to_idx`` is empty. 178 | ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. 179 | FileNotFoundError: In case no valid file was found for any class. 180 | 181 | Returns: 182 | List[Tuple[str, int]]: samples of a form (path_to_sample, class) 183 | """ 184 | if class_to_idx is None: 185 | # prevent potential bug since make_dataset() would use the class_to_idx logic of the 186 | # find_classes() function, instead of using that of the find_classes() method, which 187 | # is potentially overridden and thus could have a different logic. 188 | raise ValueError("The class_to_idx parameter cannot be None.") 189 | return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) 190 | 191 | def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: 192 | """Find the class folders in a dataset structured as follows:: 193 | 194 | directory/ 195 | ├── class_x 196 | │ ├── xxx.ext 197 | │ ├── xxy.ext 198 | │ └── ... 199 | │ └── xxz.ext 200 | └── class_y 201 | ├── 123.ext 202 | ├── nsdf3.ext 203 | └── ... 204 | └── asd932_.ext 205 | 206 | This method can be overridden to only consider 207 | a subset of classes, or to adapt to a different dataset directory structure. 208 | 209 | Args: 210 | directory(str): Root directory path, corresponding to ``self.root`` 211 | 212 | Raises: 213 | FileNotFoundError: If ``dir`` has no class folders. 214 | 215 | Returns: 216 | (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index. 
217 | """ 218 | return find_classes(directory) 219 | 220 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 221 | """ 222 | Args: 223 | index (int): Index 224 | 225 | Returns: 226 | tuple: (sample, target) where target is class_index of the target class. 227 | """ 228 | path, target = self.samples[index] 229 | sample = self.loader(path) 230 | if self.transform is not None: 231 | sample = self.transform(sample) 232 | if self.target_transform is not None: 233 | target = self.target_transform(target) 234 | 235 | return sample, target, index 236 | 237 | def __len__(self) -> int: 238 | return len(self.samples) 239 | 240 | 241 | IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") 242 | 243 | 244 | def pil_loader(path: str) -> Image.Image: 245 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 246 | with open(path, "rb") as f: 247 | img = Image.open(f) 248 | return img.convert("RGB") 249 | 250 | 251 | # TODO: specify the return type 252 | def accimage_loader(path: str) -> Any: 253 | import accimage 254 | 255 | try: 256 | return accimage.Image(path) 257 | except OSError: 258 | # Potentially a decoding problem, fall back to PIL.Image 259 | return pil_loader(path) 260 | 261 | 262 | def default_loader(path: str) -> Any: 263 | from torchvision import get_image_backend 264 | 265 | if get_image_backend() == "accimage": 266 | return accimage_loader(path) 267 | else: 268 | return pil_loader(path) 269 | 270 | 271 | class ImageFolder(DatasetFolder): 272 | """A generic data loader where the images are arranged in this way by default: :: 273 | 274 | root/dog/xxx.png 275 | root/dog/xxy.png 276 | root/dog/[...]/xxz.png 277 | 278 | root/cat/123.png 279 | root/cat/nsdf3.png 280 | root/cat/[...]/asd932_.png 281 | 282 | This class inherits from :class:`~torchvision.datasets.DatasetFolder` so 283 | the same methods can be overridden to customize the dataset. 284 | 285 | Args: 286 | root (string): Root directory path. 287 | transform (callable, optional): A function/transform that takes in an PIL image 288 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 289 | target_transform (callable, optional): A function/transform that takes in the 290 | target and transforms it. 291 | loader (callable, optional): A function to load an image given its path. 292 | is_valid_file (callable, optional): A function that takes path of an Image file 293 | and check if the file is a valid file (used to check of corrupt files) 294 | 295 | Attributes: 296 | classes (list): List of the class names sorted alphabetically. 297 | class_to_idx (dict): Dict with items (class_name, class_index). 298 | imgs (list): List of (image path, class_index) tuples 299 | """ 300 | 301 | def __init__( 302 | self, 303 | root: str, 304 | transform: Optional[Callable] = None, 305 | target_transform: Optional[Callable] = None, 306 | loader: Callable[[str], Any] = default_loader, 307 | is_valid_file: Optional[Callable[[str], bool]] = None, 308 | ): 309 | super().__init__( 310 | root, 311 | loader, 312 | IMG_EXTENSIONS if is_valid_file is None else None, 313 | transform=transform, 314 | target_transform=target_transform, 315 | is_valid_file=is_valid_file, 316 | ) 317 | self.imgs = self.samples 318 | -------------------------------------------------------------------------------- /sogclr/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from PIL import Image, ImageFilter, ImageOps 8 | import math 9 | import random 10 | import torchvision.transforms.functional as tf 11 | 12 | 13 | class TwoCropsTransform: 14 | """Take two random crops of one image""" 15 | 16 | def __init__(self, base_transform1, base_transform2): 17 | self.base_transform1 = base_transform1 18 | self.base_transform2 = base_transform2 19 | 20 | def __call__(self, x): 21 | im1 = self.base_transform1(x) 22 | im2 = self.base_transform2(x) 23 | return [im1, im2] 24 | 25 | 26 | class GaussianBlur(object): 27 | """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" 28 | 29 | def __init__(self, sigma=[.1, 2.]): 30 | self.sigma = sigma 31 | 32 | def __call__(self, x): 33 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 34 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 35 | return x 36 | 37 | 38 | class Solarize(object): 39 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 40 | 41 | def __call__(self, x): 42 | return ImageOps.solarize(x) -------------------------------------------------------------------------------- /sogclr/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | class LARS(torch.optim.Optimizer): 11 | """ 12 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 13 | """ 14 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 15 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 16 | super().__init__(params, defaults) 17 | 18 | @torch.no_grad() 19 | def step(self): 20 | for g in self.param_groups: 21 | for p in g['params']: 22 | dp = p.grad 23 | 24 | if dp is None: 25 | continue 26 | 27 | if p.ndim > 1: # if not normalization gamma/beta or bias 28 | dp = dp.add(p, alpha=g['weight_decay']) 29 | param_norm = torch.norm(p) 30 | update_norm = torch.norm(dp) 31 | one = torch.ones_like(param_norm) 32 | q = torch.where(param_norm > 0., 33 | torch.where(update_norm > 0, 34 | (g['trust_coefficient'] * param_norm / update_norm), one), 35 | one) 36 | dp = dp.mul(q) 37 | 38 | param_state = self.state[p] 39 | if 'mu' not in param_state: 40 | param_state['mu'] = torch.zeros_like(p) 41 | mu = param_state['mu'] 42 | mu.mul_(g['momentum']).add_(dp) 43 | p.add_(mu, alpha=-g['lr']) 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
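# Self-supervised pre-training with the Global Contrastive Loss (GCL); the
# per-sample moving-average estimator u is maintained in sogclr/builder.py.
# See the README for example single-node, multi-GPU commands.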

import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as torchvision_models
from torch.utils.tensorboard import SummaryWriter

import sogclr.builder
import sogclr.loader
import sogclr.optimizer
import sogclr.folder # imagenet

# ignore all warnings
warnings.filterwarnings("ignore")


torchvision_model_names = sorted(name for name in torchvision_models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(torchvision_models.__dict__[name]))

model_names = torchvision_model_names

parser = argparse.ArgumentParser(description='SogCLR ImageNet Pre-Training')
parser.add_argument('--data', metavar='DIR', default='/data/imagenet100/',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=4096, type=int,
                    metavar='N',
                    help='mini-batch size (default: 4096), this is the total '
                         'batch size of all GPUs on all nodes when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.6, type=float,
                    metavar='LR', help='initial (base) learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-6, type=float,
                    metavar='W', help='weight decay (default: 1e-6)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.
')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')


# moco specific configs:
parser.add_argument('--dim', default=128, type=int,
                    help='feature dimension (default: 128)')
parser.add_argument('--mlp-dim', default=2048, type=int,
                    help='hidden dimension in MLPs (default: 2048)')
parser.add_argument('--t', default=0.1, type=float,
                    help='softmax temperature (default: 0.1)')
parser.add_argument('--num_proj_layers', default=2, type=int,
                    help='number of layers in the non-linear projection head')

# other upgrades
parser.add_argument('--optimizer', default='lars', type=str,
                    choices=['lars', 'adamw'],
                    help='optimizer used (default: lars)')
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
                    help='number of warmup epochs')
parser.add_argument('--crop-min', default=0.08, type=float,
                    help='minimum scale for random cropping (default: 0.08)')

# dataset
parser.add_argument('--data_name', default='imagenet1000', type=str)
parser.add_argument('--save_dir', default='./saved_models/', type=str)


# sogclr
parser.add_argument('--loss_type', default='dcl', type=str,
                    help='loss function to use (default: dcl)')
parser.add_argument('--gamma', default=0.9, type=float,
                    help='coefficient for updating the moving-average estimator u in SogCLR (gamma in the paper)')
parser.add_argument('--learning-rate-scaling', default='sqrt', type=str,
                    choices=['sqrt', 'linear'],
                    help='learning rate scaling (default: sqrt)')


def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. 
140 | def main():
141 |     args = parser.parse_args()
142 | 
143 |     if args.seed is not None:
144 |         random.seed(args.seed)
145 |         torch.manual_seed(args.seed)
146 |         cudnn.deterministic = True
147 |         warnings.warn('You have chosen to seed training. '
148 |                       'This will turn on the CUDNN deterministic setting, '
149 |                       'which can slow down your training considerably! '
150 |                       'You may see unexpected behavior when restarting '
151 |                       'from checkpoints.')
152 | 
153 |     if args.gpu is not None:
154 |         warnings.warn('You have chosen a specific GPU. This will completely '
155 |                       'disable data parallelism.')
156 | 
157 |     if args.dist_url == "env://" and args.world_size == -1:
158 |         args.world_size = int(os.environ["WORLD_SIZE"])
159 | 
160 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
161 | 
162 |     ngpus_per_node = torch.cuda.device_count()
163 |     if args.multiprocessing_distributed:
164 |         # Since we have ngpus_per_node processes per node, the total world_size
165 |         # needs to be adjusted accordingly
166 |         args.world_size = ngpus_per_node * args.world_size
167 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
168 |         # main_worker process function
169 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
170 |     else:
171 |         # Simply call main_worker function
172 |         main_worker(args.gpu, ngpus_per_node, args)
173 | 
174 | 
175 | def main_worker(gpu, ngpus_per_node, args):
176 |     args.gpu = gpu
177 | 
178 |     # suppress printing on every process except GPU 0 of node 0
179 |     if args.multiprocessing_distributed and (args.gpu != 0 or args.rank != 0):
180 |         def print_pass(*args):
181 |             pass
182 |         builtins.print = print_pass
183 | 
184 |     if args.gpu is not None:
185 |         print("Use GPU: {} for training".format(args.gpu))
186 | 
187 |     if args.distributed:
188 |         if args.dist_url == "env://" and args.rank == -1:
189 |             args.rank = int(os.environ["RANK"])
190 |         if args.multiprocessing_distributed:
191 |             # For multiprocessing distributed training, rank needs to be the
192 |             # global rank among all the processes
193 |             args.rank = args.rank * ngpus_per_node + gpu
194 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
195 |                                 world_size=args.world_size, rank=args.rank)
196 |         torch.distributed.barrier()
197 | 
198 |     # training set sizes for each dataset (passed as N to size the estimator u)
199 |     if args.data_name == 'imagenet100':
200 |         data_size = 129395 + 1
201 |     elif args.data_name == 'imagenet1000':
202 |         data_size = 1281167 + 1
203 |     else:
204 |         data_size = 1000000
205 |     print('pretraining on %s' % args.data_name)
206 | 
207 |     # create model
208 |     set_all_seeds(2022)
209 |     print("=> creating model '{}'".format(args.arch))
210 |     model = sogclr.builder.SimCLR_ResNet(
211 |         partial(torchvision_models.__dict__[args.arch], zero_init_residual=True),
212 |         args.dim, args.mlp_dim, args.t, loss_type=args.loss_type, N=data_size,
213 |         num_proj_layers=args.num_proj_layers)
214 | 
215 |     # infer learning rate before changing batch size
216 |     if args.learning_rate_scaling == 'linear':
217 |         args.lr = args.lr * args.batch_size / 256
218 |     else:
219 |         # sqrt scaling
220 |         args.lr = args.lr * math.sqrt(args.batch_size)
221 | 
222 |     print('initial learning rate:', args.lr)
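    # Illustrative rank/world-size arithmetic for the distributed setup above
    # (comments only, not executed): launching on 2 nodes with 4 GPUs each
    # turns --world-size 2 into 2 * 4 = 8 processes, and the process on
    # node 1, local GPU 2, gets global rank 1 * 4 + 2 = 6.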
223 |     if not torch.cuda.is_available():
224 |         print('using CPU, this will be slow')
225 |     elif args.distributed:
226 |         # apply SyncBN
227 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
228 |         # For multiprocessing distributed, DistributedDataParallel constructor
229 |         # should always set the single device scope, otherwise,
230 |         # DistributedDataParallel will use all available devices.
231 |         if args.gpu is not None:
232 |             torch.cuda.set_device(args.gpu)
233 |             model.cuda(args.gpu)
234 |             # When using a single GPU per process and per
235 |             # DistributedDataParallel, we need to divide the batch size
236 |             # ourselves based on the total number of GPUs we have
237 |             args.batch_size = int(args.batch_size / args.world_size)
238 |             args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
239 |             model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
240 |         else:
241 |             model.cuda()
242 |             # DistributedDataParallel will divide and allocate batch_size to all
243 |             # available GPUs if device_ids are not set
244 |             model = torch.nn.parallel.DistributedDataParallel(model)
245 |     elif args.gpu is not None:
246 |         torch.cuda.set_device(args.gpu)
247 |         model = model.cuda(args.gpu)
248 |         # single-GPU mode is unsupported; comment out the next line only for debugging
249 |         raise NotImplementedError("Only DistributedDataParallel is supported.")
250 |     else:
251 |         # the all-gather/rank logic in this code only supports DistributedDataParallel
252 |         raise NotImplementedError("Only DistributedDataParallel is supported.")
253 |     # print(model)  # print model after SyncBatchNorm
254 | 
255 |     if args.optimizer == 'lars':
256 |         optimizer = sogclr.optimizer.LARS(model.parameters(), args.lr,
257 |                                           weight_decay=args.weight_decay,
258 |                                           momentum=args.momentum)
259 |     elif args.optimizer == 'adamw':
260 |         optimizer = torch.optim.AdamW(model.parameters(), args.lr,
261 |                                       weight_decay=args.weight_decay)
262 | 
263 |     scaler = torch.cuda.amp.GradScaler()
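    # Illustrative arithmetic for the batch/worker split in the DDP branch
    # above (comments only, not executed): with --batch-size 256 and a
    # world_size of 4, each process gets 256 / 4 = 64 samples per step; with
    # --workers 32 on one node of 4 GPUs, each process gets
    # int((32 + 4 - 1) / 4) = 8 loader workers.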
264 | 
265 |     # log_dir
266 |     save_root_path = args.save_dir
267 |     global_batch_size = args.batch_size * args.world_size
268 |     method_name = {'dcl': 'sogclr'}[args.loss_type]
269 |     logdir = '20221005_%s_%s_%s-%s-%s_bz_%s_E%s_WR%s_lr_%.3f_%s_wd_%s_t_%s_g_%s_%s' % (args.data_name, args.arch, method_name, args.dim, args.mlp_dim, global_batch_size, args.epochs, args.warmup_epochs, args.lr, args.learning_rate_scaling, args.weight_decay, args.t, args.gamma, args.optimizer)
270 |     summary_writer = SummaryWriter(log_dir=os.path.join(save_root_path, logdir))
271 |     print(logdir)
272 | 
273 |     # optionally resume from a checkpoint
274 |     if args.resume:
275 |         if os.path.isfile(args.resume):
276 |             print("=> loading checkpoint '{}'".format(args.resume))
277 |             if args.gpu is None:
278 |                 checkpoint = torch.load(args.resume)
279 |             else:
280 |                 # Map model to be loaded to specified single gpu.
281 |                 loc = 'cuda:{}'.format(args.gpu)
282 |                 checkpoint = torch.load(args.resume, map_location=loc)
283 |             args.start_epoch = checkpoint['epoch']
284 |             model.load_state_dict(checkpoint['state_dict'])
285 |             optimizer.load_state_dict(checkpoint['optimizer'])
286 |             scaler.load_state_dict(checkpoint['scaler'])
287 |             model.module.u = checkpoint['u'].cpu()
288 |             print('check sum u:', model.module.u.sum())
289 |             print("=> loaded checkpoint '{}' (epoch {})"
290 |                   .format(args.resume, checkpoint['epoch']))
291 |         else:
292 |             print("=> no checkpoint found at '{}'".format(args.resume))
293 | 
294 |     cudnn.benchmark = True
295 | 
296 | 
297 |     # Data loading code
298 |     mean = {'imagenet100': [0.485, 0.456, 0.406],
299 |             'imagenet1000': [0.485, 0.456, 0.406],
300 |             }[args.data_name]
301 |     std = {'imagenet100': [0.229, 0.224, 0.225],
302 |            'imagenet1000': [0.229, 0.224, 0.225],
303 |            }[args.data_name]
304 | 
305 |     image_size = {'imagenet100': 224, 'imagenet1000': 224}[args.data_name]
306 |     normalize = transforms.Normalize(mean=mean, std=std)
307 | 
308 |     # follow BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
309 |     # simclr
310 |     augmentation1 = [
311 |         transforms.RandomResizedCrop(image_size, scale=(args.crop_min, 1.)),
312 |         transforms.RandomApply([
313 |             transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
314 |         ], p=0.8),
315 |         transforms.RandomGrayscale(p=0.2),
316 |         transforms.RandomApply([sogclr.loader.GaussianBlur([.1, 2.])], p=1.0),
317 |         transforms.RandomHorizontalFlip(),
318 |         transforms.ToTensor(),
319 |         normalize
320 |     ]
321 | 
322 |     if args.data_name == 'imagenet1000' or args.data_name == 'imagenet100':
323 |         traindir = os.path.join(args.data, 'train')
324 |         train_dataset = sogclr.folder.ImageFolder(
325 |             traindir,
326 |             sogclr.loader.TwoCropsTransform(transforms.Compose(augmentation1),
327 |                                             transforms.Compose(augmentation1)))
328 |     else:
329 |         raise ValueError
330 | 
331 |     if args.distributed:
332 |         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
333 |     else:
334 |         train_sampler = None
335 | 
336 |     train_loader = torch.utils.data.DataLoader(
337 |         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
338 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
339 | 
340 |     for epoch in range(args.start_epoch, args.epochs):
341 |         if args.distributed:
342 |             train_sampler.set_epoch(epoch)
343 | 
344 |         # train for one epoch
345 |         start_time = time.time()
346 |         train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
347 |         print('elapsed time (s): %.1f' % (time.time() - start_time))
348 | 
349 |         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
350 |                 and args.rank == 0):  # only the first GPU saves checkpoints
351 | 
352 |             if epoch % 10 == 0 or args.epochs - epoch < 3:
353 |                 # u is saved so the moving-average estimator survives restarts
354 |                 save_checkpoint({
355 |                     'epoch': epoch + 1,
356 |                     'arch': args.arch,
357 |                     'state_dict': model.state_dict(),
358 |                     'optimizer': optimizer.state_dict(),
359 |                     'scaler': scaler.state_dict(),
360 |                     'u': model.module.u,
361 |                 }, is_best=False, filename=os.path.join(save_root_path, logdir, 'checkpoint_%04d.pth.tar' % epoch))
362 | 
363 |     if args.rank == 0:
364 |         summary_writer.close()
365 | 
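# Illustrative note on the checkpoint naming above (comments only): filenames
# follow 'checkpoint_%04d.pth.tar' % epoch, so the last epoch of an 800-epoch
# run is saved as 'checkpoint_0799.pth.tar'; such a file can be fed back in
# via --resume to continue pre-training.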
366 | def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
367 |     batch_time = AverageMeter('Time', ':6.3f')
368 |     data_time = AverageMeter('Data', ':6.3f')
369 |     learning_rates = AverageMeter('LR', ':.4e')
370 |     losses = AverageMeter('Loss', ':.4e')
371 |     progress = ProgressMeter(
372 |         len(train_loader),
373 |         [batch_time, data_time, learning_rates, losses],
374 |         prefix="Epoch: [{}]".format(epoch))
375 | 
376 |     # switch to train mode
377 |     model.train()
378 | 
379 |     end = time.time()
380 |     iters_per_epoch = len(train_loader)
381 | 
382 |     for i, (images, _, index) in enumerate(train_loader):
383 |         # measure data loading time
384 |         data_time.update(time.time() - end)
385 | 
386 |         # adjust learning rate per iteration
387 |         lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args)
388 |         learning_rates.update(lr)
389 | 
390 |         if args.gpu is not None:
391 |             images[0] = images[0].cuda(args.gpu, non_blocking=True)
392 |             images[1] = images[1].cuda(args.gpu, non_blocking=True)
393 | 
394 |         # compute output
395 |         with torch.cuda.amp.autocast(True):
396 |             loss = model(images[0], images[1], index, args.gamma)
397 | 
398 |         losses.update(loss.item(), images[0].size(0))
399 |         if args.rank == 0:
400 |             summary_writer.add_scalar("loss", loss.item(), epoch * iters_per_epoch + i)
401 | 
402 |         # compute gradient and do SGD step
403 |         optimizer.zero_grad()
404 |         scaler.scale(loss).backward()
405 |         scaler.step(optimizer)
406 |         scaler.update()
407 | 
408 |         # measure elapsed time
409 |         batch_time.update(time.time() - end)
410 |         end = time.time()
411 | 
412 |         if i % args.print_freq == 0:
413 |             progress.display(i)
414 | 
415 | 
416 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
417 |     torch.save(state, filename)
418 |     if is_best:
419 |         shutil.copyfile(filename, 'model_best.pth.tar')
420 | 
421 | 
422 | class AverageMeter(object):
423 |     """Computes and stores the average and current value"""
424 |     def __init__(self, name, fmt=':f'):
425 |         self.name = name
426 |         self.fmt = fmt
427 |         self.reset()
428 | 
429 |     def reset(self):
430 |         self.val = 0
431 |         self.avg = 0
432 |         self.sum = 0
433 |         self.count = 0
434 | 
435 |     def update(self, val, n=1):
436 |         self.val = val
437 |         self.sum += val * n
438 |         self.count += n
439 |         self.avg = self.sum / self.count
440 | 
441 |     def __str__(self):
442 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
443 |         return fmtstr.format(**self.__dict__)
444 | 
445 | 
446 | class ProgressMeter(object):
447 |     def __init__(self, num_batches, meters, prefix=""):
448 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
449 |         self.meters = meters
450 |         self.prefix = prefix
451 | 
452 |     def display(self, batch):
453 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
454 |         entries += [str(meter) for meter in self.meters]
455 |         print('\t'.join(entries))
456 | 
457 |     def _get_batch_fmtstr(self, num_batches):
458 |         num_digits = len(str(num_batches))
459 |         fmt = '{:' + str(num_digits) + 'd}'
460 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
461 | 
462 | 
463 | def adjust_learning_rate(optimizer, epoch, args):
464 |     """Decays the learning rate with half-cycle cosine after warmup"""
465 |     if epoch < args.warmup_epochs:
466 |         lr = args.lr * epoch / args.warmup_epochs
467 |     else:
468 |         lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
469 |     for param_group in optimizer.param_groups:
470 |         param_group['lr'] = lr
471 |     return lr
472 | 
473 | 
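# Worked example of the schedule above (comments only, illustrative numbers):
# with a scaled lr of 1.2, --warmup-epochs 10 and --epochs 800, the lr ramps
# linearly to 1.2 at epoch 10, decays to 0.5 * 1.2 = 0.6 at the cosine
# midpoint (epoch 405), and reaches 0 at epoch 800.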
474 | def adjust_moco_momentum(epoch, args):
475 |     """Adjust moco momentum based on current epoch"""
476 |     # note: unused in this script; kept from MoCo's codebase (needs an args.moco_m that this parser does not define)
477 |     m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
478 |     return m
479 | 
480 | 
481 | if __name__ == '__main__':
482 |     main()
--------------------------------------------------------------------------------