├── CODEOWNERS
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE
├── README.md
├── eval_cls_imagenet.py
├── eval_svm_voc.py
├── img
└── PCL_framework.png
├── main_pcl.py
├── pcl
├── __init__.py
├── builder.py
└── loader.py
└── voc.py
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING-ARCHIVED.md:
--------------------------------------------------------------------------------
1 | # ARCHIVED
2 |
3 | This project is `Archived` and is no longer actively maintained;
4 | We are not accepting contributions or Pull Requests.
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Junnan Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Prototypical Contrastive Learning of Unsupervised Representations (Salesforce Research)
2 |
3 |
4 | This is a PyTorch implementation of the PCL paper:
5 |
6 | @inproceedings{PCL, 7 | title={Prototypical Contrastive Learning of Unsupervised Representations}, 8 | author={Junnan Li and Pan Zhou and Caiming Xiong and Steven C.H. Hoi}, 9 | booktitle={ICLR}, 10 | year={2021} 11 | }12 | 13 | ### Requirements: 14 | * ImageNet dataset 15 | * Python ≥ 3.6 16 | * PyTorch ≥ 1.4 17 | * faiss-gpu: pip install faiss-gpu 18 | * pip install tqdm 19 | 20 | ### Unsupervised Training: 21 | This implementation only supports multi-gpu, DistributedDataParallel training, which is faster and simpler; single-gpu or DataParallel training is not supported. 22 | 23 | To perform unsupervised training of a ResNet-50 model on ImageNet using a 4-gpu or 8-gpu machine, run: 24 |
python main_pcl.py \ 25 | -a resnet50 \ 26 | --lr 0.03 \ 27 | --batch-size 256 \ 28 | --temperature 0.2 \ 29 | --mlp --aug-plus --cos (only activated for PCL v2) \ 30 | --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \ 31 | --exp-dir experiment_pcl 32 | [Imagenet dataset folder] 33 |34 | 35 | ### Download Pre-trained Models 36 | PCL v1| PCL v2 37 | ------ | ------ 38 | 39 | ### Linear SVM Evaluation on VOC 40 | To train a linear SVM classifier on VOC dataset, using frozen representations from a pre-trained model, run: 41 |
python eval_svm_voc.py --pretrained [your pretrained model] \ 42 | -a resnet50 \ 43 | --low-shot (only for low-shot evaluation, otherwise the entire dataset is used) \ 44 | [VOC2007 dataset folder] 45 |46 | 47 | Linear SVM classification result on VOC, using ResNet-50 pretrained with PCL for 200 epochs: 48 | 49 | Model| k=1 | k=2 | k=4 | k=8 | k=16| Full 50 | --- | --- | --- | --- | --- | --- | --- 51 | PCL v1| 46.9| 56.4| 62.8| 70.2| 74.3 | 82.3 52 | PCL v2| 47.9| 59.6| 66.2| 74.5| 78.3 | 85.4 53 | 54 | k is the number of training samples per class. 55 | 56 | ### Linear Classifier Evaluation on ImageNet 57 | Requirement: pip install tensorboard_logger \ 58 | To train a logistic regression classifier on ImageNet, using frozen representations from a pre-trained model, run: 59 |
python eval_cls_imagenet.py --pretrained [your pretrained model] \ 60 | -a resnet50 \ 61 | --lr 5 \ 62 | --batch-size 256 \ 63 | --id ImageNet_linear \ 64 | --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \ 65 | [Imagenet dataset folder] 66 |67 | 68 | Linear classification result on ImageNet, using ResNet-50 pretrained with PCL for 200 epochs: 69 | PCL v1 | PCL v2 70 | ------ | ------ 71 | 61.5 | 67.6 72 | 73 | -------------------------------------------------------------------------------- /eval_cls_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | import tensorboard_logger as tb_logger 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 28 | parser.add_argument('data', metavar='DIR', 29 | help='path to dataset') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 31 | choices=model_names, 32 | help='model architecture: ' + 33 | ' | '.join(model_names) + 34 | ' (default: resnet50)') 35 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 36 | help='number of data loading workers (default: 32)') 37 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=256, type=int, 42 | metavar='N', 43 | help='mini-batch size (default: 256), this is the total ' 44 | 'batch size of all GPUs on the current node when ' 45 | 'using Data Parallel or Distributed Data Parallel') 46 | parser.add_argument('--lr', '--learning-rate', default=5., type=float, 47 | metavar='LR', help='initial learning rate', dest='lr') 48 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int, 49 | help='learning rate schedule (when to drop lr by a ratio)') 50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 51 | help='momentum') 52 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 53 | metavar='W', help='weight decay (default: 0.)', 54 | dest='weight_decay') 55 | parser.add_argument('-p', '--print-freq', default=10, type=int, 56 | metavar='N', help='print frequency (default: 10)') 57 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 58 | help='path to latest checkpoint (default: none)') 59 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 60 | help='evaluate model on validation set') 61 | parser.add_argument('--world-size', default=-1, type=int, 62 | help='number of nodes for distributed training') 63 | parser.add_argument('--rank', default=-1, type=int, 64 | help='node rank for distributed training') 65 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 66 | help='url used to set up distributed training') 67 | parser.add_argument('--dist-backend', default='nccl', type=str, 68 | help='distributed backend') 69 | parser.add_argument('--seed', default=None, type=int, 70 | help='seed for initializing training. ') 71 | parser.add_argument('--gpu', default=None, type=int, 72 | help='GPU id to use.') 73 | parser.add_argument('--multiprocessing-distributed', action='store_true', 74 | help='Use multi-processing distributed training to launch ' 75 | 'N processes per node, which has N GPUs. This is the ' 76 | 'fastest way to use PyTorch for either single node or ' 77 | 'multi node data parallel training') 78 | 79 | parser.add_argument('--pretrained', default='', type=str, 80 | help='path to pretrained checkpoint') 81 | parser.add_argument('--id', type=str, default='') 82 | 83 | 84 | def main(): 85 | args = parser.parse_args() 86 | 87 | if args.seed is not None: 88 | random.seed(args.seed) 89 | torch.manual_seed(args.seed) 90 | cudnn.deterministic = True 91 | warnings.warn('You have chosen to seed training. ' 92 | 'This will turn on the CUDNN deterministic setting, ' 93 | 'which can slow down your training considerably! ' 94 | 'You may see unexpected behavior when restarting ' 95 | 'from checkpoints.') 96 | 97 | if args.gpu is not None: 98 | warnings.warn('You have chosen a specific GPU. This will completely ' 99 | 'disable data parallelism.') 100 | 101 | if args.dist_url == "env://" and args.world_size == -1: 102 | args.world_size = int(os.environ["WORLD_SIZE"]) 103 | 104 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 105 | 106 | args.tb_folder = 'Linear_eval/{}_tensorboard'.format(args.id) 107 | if not os.path.isdir(args.tb_folder): 108 | os.makedirs(args.tb_folder) 109 | 110 | ngpus_per_node = torch.cuda.device_count() 111 | if args.multiprocessing_distributed: 112 | # Since we have ngpus_per_node processes per node, the total world_size 113 | # needs to be adjusted accordingly 114 | args.world_size = ngpus_per_node * args.world_size 115 | # Use torch.multiprocessing.spawn to launch distributed processes: the 116 | # main_worker process function 117 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 118 | else: 119 | # Simply call main_worker function 120 | main_worker(args.gpu, ngpus_per_node, args) 121 | 122 | 123 | def main_worker(gpu, ngpus_per_node, args): 124 | args.gpu = gpu 125 | 126 | if args.gpu is not None: 127 | print("Use GPU: {} for training".format(args.gpu)) 128 | 129 | # suppress printing if not master 130 | if args.multiprocessing_distributed and args.gpu != 0: 131 | def print_pass(*args): 132 | pass 133 | builtins.print = print_pass 134 | 135 | if args.distributed: 136 | if args.dist_url == "env://" and args.rank == -1: 137 | args.rank = int(os.environ["RANK"]) 138 | if args.multiprocessing_distributed: 139 | # For multiprocessing distributed training, rank needs to be the 140 | # global rank among all the processes 141 | args.rank = args.rank * ngpus_per_node + gpu 142 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 143 | world_size=args.world_size, rank=args.rank) 144 | # create model 145 | print("=> creating model '{}'".format(args.arch)) 146 | model = models.__dict__[args.arch]() 147 | 148 | # freeze all layers but the last fc 149 | for name, param in model.named_parameters(): 150 | if name not in ['fc.weight', 'fc.bias']: 151 | param.requires_grad = False 152 | # init the fc layer 153 | model.fc.weight.data.normal_(mean=0.0, std=0.01) 154 | model.fc.bias.data.zero_() 155 | 156 | if args.gpu==0: 157 | logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2) 158 | else: 159 | logger = None 160 | 161 | # load from pre-trained, before DistributedDataParallel constructor 162 | if args.pretrained: 163 | if os.path.isfile(args.pretrained): 164 | print("=> loading checkpoint '{}'".format(args.pretrained)) 165 | checkpoint = torch.load(args.pretrained, map_location="cpu") 166 | 167 | # rename pre-trained keys 168 | state_dict = checkpoint['state_dict'] 169 | for k in list(state_dict.keys()): 170 | # retain only encoder_q up to before the embedding layer 171 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 172 | # remove prefix 173 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 174 | # delete renamed or unused k 175 | del state_dict[k] 176 | 177 | args.start_epoch = 0 178 | msg = model.load_state_dict(state_dict, strict=False) 179 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 180 | 181 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 182 | else: 183 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 184 | 185 | if args.distributed: 186 | # For multiprocessing distributed, DistributedDataParallel constructor 187 | # should always set the single device scope, otherwise, 188 | # DistributedDataParallel will use all available devices. 189 | if args.gpu is not None: 190 | torch.cuda.set_device(args.gpu) 191 | model.cuda(args.gpu) 192 | # When using a single GPU per process and per 193 | # DistributedDataParallel, we need to divide the batch size 194 | # ourselves based on the total number of GPUs we have 195 | args.batch_size = int(args.batch_size / ngpus_per_node) 196 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 197 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 198 | else: 199 | model.cuda() 200 | # DistributedDataParallel will divide and allocate batch_size to all 201 | # available GPUs if device_ids are not set 202 | model = torch.nn.parallel.DistributedDataParallel(model) 203 | elif args.gpu is not None: 204 | torch.cuda.set_device(args.gpu) 205 | model = model.cuda(args.gpu) 206 | else: 207 | # DataParallel will divide and allocate batch_size to all available GPUs 208 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 209 | model.features = torch.nn.DataParallel(model.features) 210 | model.cuda() 211 | else: 212 | model = torch.nn.DataParallel(model).cuda() 213 | 214 | # define loss function (criterion) and optimizer 215 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 216 | 217 | # optimize only the linear classifier 218 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 219 | assert len(parameters) == 2 # fc.weight, fc.bias 220 | optimizer = torch.optim.SGD(parameters, args.lr, 221 | momentum=args.momentum, 222 | weight_decay=args.weight_decay) 223 | 224 | # optionally resume from a checkpoint 225 | if args.resume: 226 | if os.path.isfile(args.resume): 227 | print("=> loading checkpoint '{}'".format(args.resume)) 228 | if args.gpu is None: 229 | checkpoint = torch.load(args.resume) 230 | else: 231 | # Map model to be loaded to specified single gpu. 232 | loc = 'cuda:{}'.format(args.gpu) 233 | checkpoint = torch.load(args.resume, map_location=loc) 234 | args.start_epoch = checkpoint['epoch'] 235 | model.load_state_dict(checkpoint['state_dict']) 236 | optimizer.load_state_dict(checkpoint['optimizer']) 237 | print("=> loaded checkpoint '{}' (epoch {})" 238 | .format(args.resume, checkpoint['epoch'])) 239 | else: 240 | print("=> no checkpoint found at '{}'".format(args.resume)) 241 | 242 | cudnn.benchmark = True 243 | 244 | # Data loading code 245 | traindir = os.path.join(args.data, 'train') 246 | valdir = os.path.join(args.data, 'val') 247 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 248 | std=[0.229, 0.224, 0.225]) 249 | 250 | train_dataset = datasets.ImageFolder( 251 | traindir, 252 | transforms.Compose([ 253 | transforms.RandomResizedCrop(224), 254 | transforms.RandomHorizontalFlip(), 255 | transforms.ToTensor(), 256 | normalize, 257 | ])) 258 | 259 | if args.distributed: 260 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 261 | else: 262 | train_sampler = None 263 | 264 | train_loader = torch.utils.data.DataLoader( 265 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 266 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 267 | 268 | val_loader = torch.utils.data.DataLoader( 269 | datasets.ImageFolder(valdir, transforms.Compose([ 270 | transforms.Resize(256), 271 | transforms.CenterCrop(224), 272 | transforms.ToTensor(), 273 | normalize, 274 | ])), 275 | batch_size=args.batch_size, shuffle=False, 276 | num_workers=args.workers, pin_memory=True) 277 | 278 | if args.evaluate: 279 | validate(val_loader, model, criterion, args) 280 | return 281 | 282 | for epoch in range(args.start_epoch, args.epochs): 283 | if args.distributed: 284 | train_sampler.set_epoch(epoch) 285 | adjust_learning_rate(optimizer, epoch, args) 286 | 287 | # train for one epoch 288 | train(train_loader, model, criterion, optimizer, epoch, args) 289 | 290 | # evaluate on validation set 291 | acc1 = validate(val_loader, model, criterion, args, logger, epoch) 292 | 293 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 294 | and args.rank % ngpus_per_node == 0): 295 | save_checkpoint({ 296 | 'epoch': epoch + 1, 297 | 'arch': args.arch, 298 | 'state_dict': model.state_dict(), 299 | 'optimizer' : optimizer.state_dict(), 300 | }) 301 | if epoch == args.start_epoch: 302 | sanity_check(model.state_dict(), args.pretrained) 303 | 304 | 305 | def train(train_loader, model, criterion, optimizer, epoch, args): 306 | batch_time = AverageMeter('Time', ':6.3f') 307 | data_time = AverageMeter('Data', ':6.3f') 308 | losses = AverageMeter('Loss', ':.4e') 309 | top1 = AverageMeter('Acc@1', ':6.2f') 310 | top5 = AverageMeter('Acc@5', ':6.2f') 311 | progress = ProgressMeter( 312 | len(train_loader), 313 | [batch_time, data_time, losses, top1, top5], 314 | prefix="Epoch: [{}]".format(epoch)) 315 | 316 | """ 317 | Switch to eval mode: 318 | Under the protocol of linear classification on frozen features/models, 319 | it is not legitimate to change any part of the pre-trained model. 320 | BatchNorm in train mode may revise running mean/std (even if it receives 321 | no gradient), which are part of the model parameters too. 322 | """ 323 | model.eval() 324 | 325 | end = time.time() 326 | for i, (images, target) in enumerate(train_loader): 327 | # measure data loading time 328 | data_time.update(time.time() - end) 329 | 330 | if args.gpu is not None: 331 | images = images.cuda(args.gpu, non_blocking=True) 332 | target = target.cuda(args.gpu, non_blocking=True) 333 | 334 | # compute output 335 | output = model(images) 336 | loss = criterion(output, target) 337 | 338 | # measure accuracy and record loss 339 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 340 | losses.update(loss.item(), images.size(0)) 341 | top1.update(acc1[0], images.size(0)) 342 | top5.update(acc5[0], images.size(0)) 343 | 344 | # compute gradient and do SGD step 345 | optimizer.zero_grad() 346 | loss.backward() 347 | optimizer.step() 348 | 349 | # measure elapsed time 350 | batch_time.update(time.time() - end) 351 | end = time.time() 352 | 353 | if i % args.print_freq == 0: 354 | progress.display(i) 355 | 356 | 357 | def validate(val_loader, model, criterion, args, logger, epoch): 358 | batch_time = AverageMeter('Time', ':6.3f') 359 | losses = AverageMeter('Loss', ':.4e') 360 | top1 = AverageMeter('Acc@1', ':6.2f') 361 | top5 = AverageMeter('Acc@5', ':6.2f') 362 | progress = ProgressMeter( 363 | len(val_loader), 364 | [batch_time, losses, top1, top5], 365 | prefix='Test: ') 366 | 367 | # switch to evaluate mode 368 | model.eval() 369 | 370 | with torch.no_grad(): 371 | end = time.time() 372 | for i, (images, target) in enumerate(val_loader): 373 | if args.gpu is not None: 374 | images = images.cuda(args.gpu, non_blocking=True) 375 | target = target.cuda(args.gpu, non_blocking=True) 376 | 377 | # compute output 378 | output = model(images) 379 | loss = criterion(output, target) 380 | 381 | # measure accuracy and record loss 382 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 383 | losses.update(loss.item(), images.size(0)) 384 | top1.update(acc1[0], images.size(0)) 385 | top5.update(acc5[0], images.size(0)) 386 | 387 | # measure elapsed time 388 | batch_time.update(time.time() - end) 389 | end = time.time() 390 | 391 | if i % args.print_freq == 0: 392 | progress.display(i) 393 | 394 | # TODO: this should also be done with the ProgressMeter 395 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 396 | .format(top1=top1, top5=top5)) 397 | if args.gpu==0: 398 | logger.log_value('test_acc', top1.avg, epoch) 399 | logger.log_value('test_acc5', top5.avg, epoch) 400 | return top1.avg 401 | 402 | 403 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 404 | torch.save(state, filename) 405 | 406 | 407 | def sanity_check(state_dict, pretrained_weights): 408 | """ 409 | Linear classifier should not change any weights other than the linear layer. 410 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 411 | """ 412 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 413 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 414 | state_dict_pre = checkpoint['state_dict'] 415 | 416 | for k in list(state_dict.keys()): 417 | # only ignore fc layer 418 | if 'fc.weight' in k or 'fc.bias' in k: 419 | continue 420 | # name in pretrained model 421 | k_pre = 'module.encoder_q.' + k[len('module.'):] \ 422 | if k.startswith('module.') else 'module.encoder_q.' + k 423 | 424 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 425 | '{} is changed in linear classifier training.'.format(k) 426 | 427 | print("=> sanity check passed.") 428 | 429 | 430 | class AverageMeter(object): 431 | """Computes and stores the average and current value""" 432 | def __init__(self, name, fmt=':f'): 433 | self.name = name 434 | self.fmt = fmt 435 | self.reset() 436 | 437 | def reset(self): 438 | self.val = 0 439 | self.avg = 0 440 | self.sum = 0 441 | self.count = 0 442 | 443 | def update(self, val, n=1): 444 | self.val = val 445 | self.sum += val * n 446 | self.count += n 447 | self.avg = self.sum / self.count 448 | 449 | def __str__(self): 450 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 451 | return fmtstr.format(**self.__dict__) 452 | 453 | 454 | class ProgressMeter(object): 455 | def __init__(self, num_batches, meters, prefix=""): 456 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 457 | self.meters = meters 458 | self.prefix = prefix 459 | 460 | def display(self, batch): 461 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 462 | entries += [str(meter) for meter in self.meters] 463 | print('\t'.join(entries)) 464 | 465 | def _get_batch_fmtstr(self, num_batches): 466 | num_digits = len(str(num_batches // 1)) 467 | fmt = '{:' + str(num_digits) + 'd}' 468 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 469 | 470 | 471 | def adjust_learning_rate(optimizer, epoch, args): 472 | """Decay the learning rate based on schedule""" 473 | lr = args.lr 474 | for milestone in args.schedule: 475 | lr *= 0.1 if epoch >= milestone else 1. 476 | for param_group in optimizer.param_groups: 477 | param_group['lr'] = lr 478 | 479 | 480 | def accuracy(output, target, topk=(1,)): 481 | """Computes the accuracy over the k top predictions for the specified values of k""" 482 | with torch.no_grad(): 483 | maxk = max(topk) 484 | batch_size = target.size(0) 485 | 486 | _, pred = output.topk(maxk, 1, True, True) 487 | pred = pred.t() 488 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 489 | 490 | res = [] 491 | for k in topk: 492 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 493 | res.append(correct_k.mul_(100.0 / batch_size)) 494 | return res 495 | 496 | 497 | if __name__ == '__main__': 498 | main() 499 | -------------------------------------------------------------------------------- /eval_svm_voc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import time 6 | import torch 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn.functional as F 10 | import argparse 11 | import random 12 | import numpy as np 13 | 14 | from torchvision import transforms, datasets 15 | import torchvision.models as models 16 | 17 | from voc import Voc2007Classification 18 | 19 | from sklearn.svm import LinearSVC 20 | 21 | 22 | def parse_option(): 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser('argument for training') 28 | 29 | parser.add_argument('data', metavar='DIR', 30 | help='path to dataset') 31 | parser.add_argument('--batch-size', type=int, default=128, help='batch size') 32 | parser.add_argument('--num-workers', type=int, default=8, help='num of workers to use') 33 | parser.add_argument('--cost', type=str, default='0.5') 34 | parser.add_argument('--seed', default=0, type=int) 35 | 36 | # model definition 37 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 38 | choices=model_names, 39 | help='model architecture: ' + 40 | ' | '.join(model_names) + 41 | ' (default: resnet50)') 42 | parser.add_argument('--pretrained', default='', type=str, 43 | help='path to pretrained checkpoint') 44 | # dataset 45 | parser.add_argument('--low-shot', default=False, action='store_true', help='whether to perform low-shot training.') 46 | 47 | opt = parser.parse_args() 48 | 49 | opt.num_class = 20 50 | 51 | # if low shot experiment, do 5 random runs 52 | if opt.low_shot: 53 | opt.n_run = 5 54 | else: 55 | opt.n_run = 1 56 | return opt 57 | 58 | 59 | def calculate_ap(rec, prec): 60 | """ 61 | Computes the AP under the precision recall curve. 62 | """ 63 | rec, prec = rec.reshape(rec.size, 1), prec.reshape(prec.size, 1) 64 | z, o = np.zeros((1, 1)), np.ones((1, 1)) 65 | mrec, mpre = np.vstack((z, rec, o)), np.vstack((z, prec, z)) 66 | for i in range(len(mpre) - 2, -1, -1): 67 | mpre[i] = max(mpre[i], mpre[i + 1]) 68 | 69 | indices = np.where(mrec[1:] != mrec[0:-1])[0] + 1 70 | ap = 0 71 | for i in indices: 72 | ap = ap + (mrec[i] - mrec[i - 1]) * mpre[i] 73 | return ap 74 | 75 | def get_precision_recall(targets, preds): 76 | """ 77 | [P, R, score, ap] = get_precision_recall(targets, preds) 78 | Input : 79 | targets : number of occurrences of this class in the ith image 80 | preds : score for this image 81 | Output : 82 | P, R : precision and recall 83 | score : score which corresponds to the particular precision and recall 84 | ap : average precision 85 | """ 86 | # binarize targets 87 | targets = np.array(targets > 0, dtype=np.float32) 88 | tog = np.hstack(( 89 | targets[:, np.newaxis].astype(np.float64), 90 | preds[:, np.newaxis].astype(np.float64) 91 | )) 92 | ind = np.argsort(preds) 93 | ind = ind[::-1] 94 | score = np.array([tog[i, 1] for i in ind]) 95 | sortcounts = np.array([tog[i, 0] for i in ind]) 96 | 97 | tp = sortcounts 98 | fp = sortcounts.copy() 99 | for i in range(sortcounts.shape[0]): 100 | if sortcounts[i] >= 1: 101 | fp[i] = 0. 102 | elif sortcounts[i] < 1: 103 | fp[i] = 1. 104 | P = np.cumsum(tp) / (np.cumsum(tp) + np.cumsum(fp)) 105 | numinst = np.sum(targets) 106 | R = np.cumsum(tp) / numinst 107 | ap = calculate_ap(R, P) 108 | return P, R, score, ap 109 | 110 | 111 | def main(): 112 | args = parse_option() 113 | 114 | random.seed(args.seed) 115 | np.random.seed(args.seed) 116 | 117 | mean = [0.485, 0.456, 0.406] 118 | std = [0.229, 0.224, 0.225] 119 | normalize = transforms.Normalize(mean=mean, std=std) 120 | transform = transforms.Compose([ 121 | transforms.Resize(256), 122 | transforms.CenterCrop(224), 123 | transforms.ToTensor(), 124 | normalize, 125 | ]) 126 | 127 | train_dataset = Voc2007Classification(args.data,set='trainval',transform = transform) 128 | val_dataset = Voc2007Classification(args.data,set='test',transform = transform) 129 | 130 | val_loader = torch.utils.data.DataLoader( 131 | val_dataset, batch_size=args.batch_size, shuffle=False, 132 | num_workers=args.num_workers, pin_memory=True) 133 | 134 | # create model 135 | print("=> creating model '{}'".format(args.arch)) 136 | model = models.__dict__[args.arch](num_classes=128) 137 | 138 | # load from pre-trained 139 | if args.pretrained: 140 | if os.path.isfile(args.pretrained): 141 | print("=> loading checkpoint '{}'".format(args.pretrained)) 142 | checkpoint = torch.load(args.pretrained, map_location="cpu") 143 | state_dict = checkpoint['state_dict'] 144 | # rename pre-trained keys 145 | for k in list(state_dict.keys()): 146 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 147 | # remove prefix 148 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 149 | # delete renamed or unused k 150 | del state_dict[k] 151 | model.load_state_dict(state_dict, strict=False) 152 | model.fc = torch.nn.Identity() 153 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 154 | else: 155 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 156 | 157 | model.cuda() 158 | model.eval() 159 | 160 | test_feats = [] 161 | test_labels = [] 162 | print('==> calculate test features') 163 | for idx, (images, target) in enumerate(val_loader): 164 | images = images.cuda(non_blocking=True) 165 | feat = model(images) 166 | feat = feat.detach().cpu() 167 | test_feats.append(feat) 168 | test_labels.append(target) 169 | 170 | test_feats = torch.cat(test_feats,0).numpy() 171 | test_labels = torch.cat(test_labels,0).numpy() 172 | 173 | test_feats_norm = np.linalg.norm(test_feats, axis=1) 174 | test_feats = test_feats / (test_feats_norm + 1e-5)[:, np.newaxis] 175 | 176 | result={} 177 | 178 | if args.low_shot: 179 | k_list = [1,2,4,8,16] #number of samples per-class for low-shot classifcation 180 | else: 181 | k_list = ['full'] 182 | 183 | for k in k_list: 184 | cost_list = args.cost.split(',') 185 | result_k = np.zeros(len(cost_list)) 186 | for i,cost in enumerate(cost_list): 187 | cost = float(cost) 188 | avg_map = [] 189 | for run in range(args.n_run): 190 | if args.low_shot: # sample k-shot training data 191 | print('==> re-sampling training data') 192 | train_dataset.convert_low_shot(k) 193 | print(len(train_dataset)) 194 | 195 | train_loader = torch.utils.data.DataLoader( 196 | train_dataset, batch_size=args.batch_size, shuffle=False, 197 | num_workers=args.num_workers, pin_memory=True) 198 | 199 | train_feats = [] 200 | train_labels = [] 201 | print('==> calculate train features') 202 | for idx, (images, target) in enumerate(train_loader): 203 | images = images.cuda(non_blocking=True) 204 | feat = model(images) 205 | feat = feat.detach() 206 | 207 | train_feats.append(feat) 208 | train_labels.append(target) 209 | 210 | train_feats = torch.cat(train_feats,0).cpu().numpy() 211 | train_labels = torch.cat(train_labels,0).cpu().numpy() 212 | 213 | train_feats_norm = np.linalg.norm(train_feats, axis=1) 214 | train_feats = train_feats / (train_feats_norm + 1e-5)[:, np.newaxis] 215 | 216 | print('==> training SVM Classifier') 217 | cls_ap = np.zeros((args.num_class, 1)) 218 | test_labels[test_labels==0] = -1 219 | train_labels[train_labels==0] = -1 220 | for cls in range(args.num_class): 221 | clf = LinearSVC( 222 | C=cost, class_weight={1: 2, -1: 1}, intercept_scaling=1.0, 223 | penalty='l2', loss='squared_hinge', tol=1e-4, 224 | dual=True, max_iter=2000, random_state=0) 225 | clf.fit(train_feats, train_labels[:,cls]) 226 | 227 | prediction = clf.decision_function(test_feats) 228 | P, R, score, ap = get_precision_recall(test_labels[:,cls], prediction) 229 | cls_ap[cls][0] = ap*100 230 | mean_ap = np.mean(cls_ap, axis=0) 231 | 232 | print('==> Run%d mAP is %.2f: '%(run,mean_ap)) 233 | avg_map.append(mean_ap) 234 | 235 | avg_map = np.asarray(avg_map) 236 | print('Cost:%.2f - Average ap is: %.2f' %(cost,avg_map.mean())) 237 | print('Cost:%.2f - Std is: %.2f' %(cost,avg_map.std())) 238 | result_k[i]=avg_map.mean() 239 | result[k] = result_k.max() 240 | print(result) 241 | 242 | if __name__ == '__main__': 243 | main() 244 | 245 | -------------------------------------------------------------------------------- /img/PCL_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/PCL/30682fe9319dfd50964bf9528f18a9a82988015c/img/PCL_framework.png -------------------------------------------------------------------------------- /main_pcl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | from tqdm import tqdm 10 | import numpy as np 11 | import faiss 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.multiprocessing as mp 20 | import torch.utils.data 21 | import torch.utils.data.distributed 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | import torchvision.models as models 25 | 26 | import pcl.loader 27 | import pcl.builder 28 | 29 | model_names = sorted(name for name in models.__dict__ 30 | if name.islower() and not name.startswith("__") 31 | and callable(models.__dict__[name])) 32 | 33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 34 | parser.add_argument('data', metavar='DIR', 35 | help='path to dataset') 36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 37 | choices=model_names, 38 | help='model architecture: ' + 39 | ' | '.join(model_names) + 40 | ' (default: resnet50)') 41 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 42 | help='number of data loading workers (default: 32)') 43 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 44 | help='number of total epochs to run') 45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 46 | help='manual epoch number (useful on restarts)') 47 | parser.add_argument('-b', '--batch-size', default=256, type=int, 48 | metavar='N', 49 | help='mini-batch size (default: 256), this is the total ' 50 | 'batch size of all GPUs on the current node when ' 51 | 'using Data Parallel or Distributed Data Parallel') 52 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 53 | metavar='LR', help='initial learning rate', dest='lr') 54 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 55 | help='learning rate schedule (when to drop lr by 10x)') 56 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 57 | help='momentum of SGD solver') 58 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 59 | metavar='W', help='weight decay (default: 1e-4)', 60 | dest='weight_decay') 61 | parser.add_argument('-p', '--print-freq', default=100, type=int, 62 | metavar='N', help='print frequency (default: 10)') 63 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 64 | help='path to latest checkpoint (default: none)') 65 | parser.add_argument('--world-size', default=-1, type=int, 66 | help='number of nodes for distributed training') 67 | parser.add_argument('--rank', default=-1, type=int, 68 | help='node rank for distributed training') 69 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 70 | help='url used to set up distributed training') 71 | parser.add_argument('--dist-backend', default='nccl', type=str, 72 | help='distributed backend') 73 | parser.add_argument('--seed', default=None, type=int, 74 | help='seed for initializing training. ') 75 | parser.add_argument('--gpu', default=None, type=int, 76 | help='GPU id to use.') 77 | parser.add_argument('--multiprocessing-distributed', action='store_true', 78 | help='Use multi-processing distributed training to launch ' 79 | 'N processes per node, which has N GPUs. This is the ' 80 | 'fastest way to use PyTorch for either single node or ' 81 | 'multi node data parallel training') 82 | 83 | parser.add_argument('--low-dim', default=128, type=int, 84 | help='feature dimension (default: 128)') 85 | parser.add_argument('--pcl-r', default=16384, type=int, 86 | help='queue size; number of negative pairs; needs to be smaller than num_cluster (default: 16384)') 87 | parser.add_argument('--moco-m', default=0.999, type=float, 88 | help='moco momentum of updating key encoder (default: 0.999)') 89 | parser.add_argument('--temperature', default=0.2, type=float, 90 | help='softmax temperature') 91 | 92 | parser.add_argument('--mlp', action='store_true', 93 | help='use mlp head') 94 | parser.add_argument('--aug-plus', action='store_true', 95 | help='use moco-v2/SimCLR data augmentation') 96 | parser.add_argument('--cos', action='store_true', 97 | help='use cosine lr schedule') 98 | 99 | parser.add_argument('--num-cluster', default='25000,50000,100000', type=str, 100 | help='number of clusters') 101 | parser.add_argument('--warmup-epoch', default=20, type=int, 102 | help='number of warm-up epochs to only train with InfoNCE loss') 103 | parser.add_argument('--exp-dir', default='experiment_pcl', type=str, 104 | help='experiment directory') 105 | 106 | def main(): 107 | args = parser.parse_args() 108 | 109 | if args.seed is not None: 110 | random.seed(args.seed) 111 | torch.manual_seed(args.seed) 112 | cudnn.deterministic = True 113 | warnings.warn('You have chosen to seed training. ' 114 | 'This will turn on the CUDNN deterministic setting, ' 115 | 'which can slow down your training considerably! ' 116 | 'You may see unexpected behavior when restarting ' 117 | 'from checkpoints.') 118 | 119 | if args.gpu is not None: 120 | warnings.warn('You have chosen a specific GPU. This will completely ' 121 | 'disable data parallelism.') 122 | 123 | if args.dist_url == "env://" and args.world_size == -1: 124 | args.world_size = int(os.environ["WORLD_SIZE"]) 125 | 126 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 127 | 128 | args.num_cluster = args.num_cluster.split(',') 129 | 130 | if not os.path.exists(args.exp_dir): 131 | os.mkdir(args.exp_dir) 132 | 133 | ngpus_per_node = torch.cuda.device_count() 134 | if args.multiprocessing_distributed: 135 | # Since we have ngpus_per_node processes per node, the total world_size 136 | # needs to be adjusted accordingly 137 | args.world_size = ngpus_per_node * args.world_size 138 | # Use torch.multiprocessing.spawn to launch distributed processes: the 139 | # main_worker process function 140 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 141 | else: 142 | # Simply call main_worker function 143 | main_worker(args.gpu, ngpus_per_node, args) 144 | 145 | 146 | def main_worker(gpu, ngpus_per_node, args): 147 | args.gpu = gpu 148 | 149 | if args.gpu is not None: 150 | print("Use GPU: {} for training".format(args.gpu)) 151 | 152 | # suppress printing if not master 153 | if args.multiprocessing_distributed and args.gpu != 0: 154 | def print_pass(*args): 155 | pass 156 | builtins.print = print_pass 157 | 158 | if args.distributed: 159 | if args.dist_url == "env://" and args.rank == -1: 160 | args.rank = int(os.environ["RANK"]) 161 | if args.multiprocessing_distributed: 162 | # For multiprocessing distributed training, rank needs to be the 163 | # global rank among all the processes 164 | args.rank = args.rank * ngpus_per_node + gpu 165 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 166 | world_size=args.world_size, rank=args.rank) 167 | # create model 168 | print("=> creating model '{}'".format(args.arch)) 169 | model = pcl.builder.MoCo( 170 | models.__dict__[args.arch], 171 | args.low_dim, args.pcl_r, args.moco_m, args.temperature, args.mlp) 172 | print(model) 173 | 174 | if args.distributed: 175 | # For multiprocessing distributed, DistributedDataParallel constructor 176 | # should always set the single device scope, otherwise, 177 | # DistributedDataParallel will use all available devices. 178 | if args.gpu is not None: 179 | torch.cuda.set_device(args.gpu) 180 | model.cuda(args.gpu) 181 | # When using a single GPU per process and per 182 | # DistributedDataParallel, we need to divide the batch size 183 | # ourselves based on the total number of GPUs we have 184 | args.batch_size = int(args.batch_size / ngpus_per_node) 185 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 186 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 187 | else: 188 | model.cuda() 189 | # DistributedDataParallel will divide and allocate batch_size to all 190 | # available GPUs if device_ids are not set 191 | model = torch.nn.parallel.DistributedDataParallel(model) 192 | elif args.gpu is not None: 193 | torch.cuda.set_device(args.gpu) 194 | model = model.cuda(args.gpu) 195 | # comment out the following line for debugging 196 | raise NotImplementedError("Only DistributedDataParallel is supported.") 197 | else: 198 | # AllGather implementation (batch shuffle, queue update, etc.) in 199 | # this code only supports DistributedDataParallel. 200 | raise NotImplementedError("Only DistributedDataParallel is supported.") 201 | 202 | # define loss function (criterion) and optimizer 203 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 204 | 205 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 206 | momentum=args.momentum, 207 | weight_decay=args.weight_decay) 208 | 209 | # optionally resume from a checkpoint 210 | if args.resume: 211 | if os.path.isfile(args.resume): 212 | print("=> loading checkpoint '{}'".format(args.resume)) 213 | if args.gpu is None: 214 | checkpoint = torch.load(args.resume) 215 | else: 216 | # Map model to be loaded to specified single gpu. 217 | loc = 'cuda:{}'.format(args.gpu) 218 | checkpoint = torch.load(args.resume, map_location=loc) 219 | args.start_epoch = checkpoint['epoch'] 220 | model.load_state_dict(checkpoint['state_dict']) 221 | optimizer.load_state_dict(checkpoint['optimizer']) 222 | print("=> loaded checkpoint '{}' (epoch {})" 223 | .format(args.resume, checkpoint['epoch'])) 224 | else: 225 | print("=> no checkpoint found at '{}'".format(args.resume)) 226 | 227 | cudnn.benchmark = True 228 | 229 | # Data loading code 230 | traindir = os.path.join(args.data, 'train') 231 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 232 | std=[0.229, 0.224, 0.225]) 233 | 234 | if args.aug_plus: 235 | # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 236 | augmentation = [ 237 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 238 | transforms.RandomApply([ 239 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 240 | ], p=0.8), 241 | transforms.RandomGrayscale(p=0.2), 242 | transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5), 243 | transforms.RandomHorizontalFlip(), 244 | transforms.ToTensor(), 245 | normalize 246 | ] 247 | else: 248 | # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978 249 | augmentation = [ 250 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 251 | transforms.RandomGrayscale(p=0.2), 252 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 253 | transforms.RandomHorizontalFlip(), 254 | transforms.ToTensor(), 255 | normalize 256 | ] 257 | 258 | # center-crop augmentation 259 | eval_augmentation = transforms.Compose([ 260 | transforms.Resize(256), 261 | transforms.CenterCrop(224), 262 | transforms.ToTensor(), 263 | normalize 264 | ]) 265 | 266 | train_dataset = pcl.loader.ImageFolderInstance( 267 | traindir, 268 | pcl.loader.TwoCropsTransform(transforms.Compose(augmentation))) 269 | eval_dataset = pcl.loader.ImageFolderInstance( 270 | traindir, 271 | eval_augmentation) 272 | 273 | if args.distributed: 274 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 275 | eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset,shuffle=False) 276 | else: 277 | train_sampler = None 278 | eval_sampler = None 279 | 280 | train_loader = torch.utils.data.DataLoader( 281 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 282 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 283 | 284 | # dataloader for center-cropped images, use larger batch size to increase speed 285 | eval_loader = torch.utils.data.DataLoader( 286 | eval_dataset, batch_size=args.batch_size*5, shuffle=False, 287 | sampler=eval_sampler, num_workers=args.workers, pin_memory=True) 288 | 289 | for epoch in range(args.start_epoch, args.epochs): 290 | 291 | cluster_result = None 292 | if epoch>=args.warmup_epoch: 293 | # compute momentum features for center-cropped images 294 | features = compute_features(eval_loader, model, args) 295 | 296 | # placeholder for clustering result 297 | cluster_result = {'im2cluster':[],'centroids':[],'density':[]} 298 | for num_cluster in args.num_cluster: 299 | cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset),dtype=torch.long).cuda()) 300 | cluster_result['centroids'].append(torch.zeros(int(num_cluster),args.low_dim).cuda()) 301 | cluster_result['density'].append(torch.zeros(int(num_cluster)).cuda()) 302 | 303 | if args.gpu == 0: 304 | features[torch.norm(features,dim=1)>1.5] /= 2 #account for the few samples that are computed twice 305 | features = features.numpy() 306 | cluster_result = run_kmeans(features,args) #run kmeans clustering on master node 307 | # save the clustering result 308 | # torch.save(cluster_result,os.path.join(args.exp_dir, 'clusters_%d'%epoch)) 309 | 310 | dist.barrier() 311 | # broadcast clustering result 312 | for k, data_list in cluster_result.items(): 313 | for data_tensor in data_list: 314 | dist.broadcast(data_tensor, 0, async_op=False) 315 | 316 | if args.distributed: 317 | train_sampler.set_epoch(epoch) 318 | adjust_learning_rate(optimizer, epoch, args) 319 | 320 | # train for one epoch 321 | train(train_loader, model, criterion, optimizer, epoch, args, cluster_result) 322 | 323 | if (epoch+1)%5==0 and (not args.multiprocessing_distributed or (args.multiprocessing_distributed 324 | and args.rank % ngpus_per_node == 0)): 325 | save_checkpoint({ 326 | 'epoch': epoch + 1, 327 | 'arch': args.arch, 328 | 'state_dict': model.state_dict(), 329 | 'optimizer' : optimizer.state_dict(), 330 | }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir,epoch)) 331 | 332 | 333 | def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None): 334 | batch_time = AverageMeter('Time', ':6.3f') 335 | data_time = AverageMeter('Data', ':6.3f') 336 | losses = AverageMeter('Loss', ':.4e') 337 | acc_inst = AverageMeter('Acc@Inst', ':6.2f') 338 | acc_proto = AverageMeter('Acc@Proto', ':6.2f') 339 | 340 | progress = ProgressMeter( 341 | len(train_loader), 342 | [batch_time, data_time, losses, acc_inst, acc_proto], 343 | prefix="Epoch: [{}]".format(epoch)) 344 | 345 | # switch to train mode 346 | model.train() 347 | 348 | end = time.time() 349 | for i, (images, index) in enumerate(train_loader): 350 | # measure data loading time 351 | data_time.update(time.time() - end) 352 | 353 | if args.gpu is not None: 354 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 355 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 356 | 357 | # compute output 358 | output, target, output_proto, target_proto = model(im_q=images[0], im_k=images[1], cluster_result=cluster_result, index=index) 359 | 360 | # InfoNCE loss 361 | loss = criterion(output, target) 362 | 363 | # ProtoNCE loss 364 | if output_proto is not None: 365 | loss_proto = 0 366 | for proto_out,proto_target in zip(output_proto, target_proto): 367 | loss_proto += criterion(proto_out, proto_target) 368 | accp = accuracy(proto_out, proto_target)[0] 369 | acc_proto.update(accp[0], images[0].size(0)) 370 | 371 | # average loss across all sets of prototypes 372 | loss_proto /= len(args.num_cluster) 373 | loss += loss_proto 374 | 375 | losses.update(loss.item(), images[0].size(0)) 376 | acc = accuracy(output, target)[0] 377 | acc_inst.update(acc[0], images[0].size(0)) 378 | 379 | # compute gradient and do SGD step 380 | optimizer.zero_grad() 381 | loss.backward() 382 | optimizer.step() 383 | 384 | # measure elapsed time 385 | batch_time.update(time.time() - end) 386 | end = time.time() 387 | 388 | if i % args.print_freq == 0: 389 | progress.display(i) 390 | 391 | 392 | def compute_features(eval_loader, model, args): 393 | print('Computing features...') 394 | model.eval() 395 | features = torch.zeros(len(eval_loader.dataset),args.low_dim).cuda() 396 | for i, (images, index) in enumerate(tqdm(eval_loader)): 397 | with torch.no_grad(): 398 | images = images.cuda(non_blocking=True) 399 | feat = model(images,is_eval=True) 400 | features[index] = feat 401 | dist.barrier() 402 | dist.all_reduce(features, op=dist.ReduceOp.SUM) 403 | return features.cpu() 404 | 405 | 406 | def run_kmeans(x, args): 407 | """ 408 | Args: 409 | x: data to be clustered 410 | """ 411 | 412 | print('performing kmeans clustering') 413 | results = {'im2cluster':[],'centroids':[],'density':[]} 414 | 415 | for seed, num_cluster in enumerate(args.num_cluster): 416 | # intialize faiss clustering parameters 417 | d = x.shape[1] 418 | k = int(num_cluster) 419 | clus = faiss.Clustering(d, k) 420 | clus.verbose = True 421 | clus.niter = 20 422 | clus.nredo = 5 423 | clus.seed = seed 424 | clus.max_points_per_centroid = 1000 425 | clus.min_points_per_centroid = 10 426 | 427 | res = faiss.StandardGpuResources() 428 | cfg = faiss.GpuIndexFlatConfig() 429 | cfg.useFloat16 = False 430 | cfg.device = args.gpu 431 | index = faiss.GpuIndexFlatL2(res, d, cfg) 432 | 433 | clus.train(x, index) 434 | 435 | D, I = index.search(x, 1) # for each sample, find cluster distance and assignments 436 | im2cluster = [int(n[0]) for n in I] 437 | 438 | # get cluster centroids 439 | centroids = faiss.vector_to_array(clus.centroids).reshape(k,d) 440 | 441 | # sample-to-centroid distances for each cluster 442 | Dcluster = [[] for c in range(k)] 443 | for im,i in enumerate(im2cluster): 444 | Dcluster[i].append(D[im][0]) 445 | 446 | # concentration estimation (phi) 447 | density = np.zeros(k) 448 | for i,dist in enumerate(Dcluster): 449 | if len(dist)>1: 450 | d = (np.asarray(dist)**0.5).mean()/np.log(len(dist)+10) 451 | density[i] = d 452 | 453 | #if cluster only has one point, use the max to estimate its concentration 454 | dmax = density.max() 455 | for i,dist in enumerate(Dcluster): 456 | if len(dist)<=1: 457 | density[i] = dmax 458 | 459 | density = density.clip(np.percentile(density,10),np.percentile(density,90)) #clamp extreme values for stability 460 | density = args.temperature*density/density.mean() #scale the mean to temperature 461 | 462 | # convert to cuda Tensors for broadcast 463 | centroids = torch.Tensor(centroids).cuda() 464 | centroids = nn.functional.normalize(centroids, p=2, dim=1) 465 | 466 | im2cluster = torch.LongTensor(im2cluster).cuda() 467 | density = torch.Tensor(density).cuda() 468 | 469 | results['centroids'].append(centroids) 470 | results['density'].append(density) 471 | results['im2cluster'].append(im2cluster) 472 | 473 | return results 474 | 475 | 476 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 477 | torch.save(state, filename) 478 | if is_best: 479 | shutil.copyfile(filename, 'model_best.pth.tar') 480 | 481 | 482 | class AverageMeter(object): 483 | """Computes and stores the average and current value""" 484 | def __init__(self, name, fmt=':f'): 485 | self.name = name 486 | self.fmt = fmt 487 | self.reset() 488 | 489 | def reset(self): 490 | self.val = 0 491 | self.avg = 0 492 | self.sum = 0 493 | self.count = 0 494 | 495 | def update(self, val, n=1): 496 | self.val = val 497 | self.sum += val * n 498 | self.count += n 499 | self.avg = self.sum / self.count 500 | 501 | def __str__(self): 502 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 503 | return fmtstr.format(**self.__dict__) 504 | 505 | 506 | class ProgressMeter(object): 507 | def __init__(self, num_batches, meters, prefix=""): 508 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 509 | self.meters = meters 510 | self.prefix = prefix 511 | 512 | def display(self, batch): 513 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 514 | entries += [str(meter) for meter in self.meters] 515 | print('\t'.join(entries)) 516 | 517 | def _get_batch_fmtstr(self, num_batches): 518 | num_digits = len(str(num_batches // 1)) 519 | fmt = '{:' + str(num_digits) + 'd}' 520 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 521 | 522 | 523 | def adjust_learning_rate(optimizer, epoch, args): 524 | """Decay the learning rate based on schedule""" 525 | lr = args.lr 526 | if args.cos: # cosine lr schedule 527 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 528 | else: # stepwise lr schedule 529 | for milestone in args.schedule: 530 | lr *= 0.1 if epoch >= milestone else 1. 531 | for param_group in optimizer.param_groups: 532 | param_group['lr'] = lr 533 | 534 | 535 | def accuracy(output, target, topk=(1,)): 536 | """Computes the accuracy over the k top predictions for the specified values of k""" 537 | with torch.no_grad(): 538 | maxk = max(topk) 539 | batch_size = target.size(0) 540 | 541 | _, pred = output.topk(maxk, 1, True, True) 542 | pred = pred.t() 543 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 544 | 545 | res = [] 546 | for k in topk: 547 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 548 | res.append(correct_k.mul_(100.0 / batch_size)) 549 | return res 550 | 551 | 552 | if __name__ == '__main__': 553 | main() 554 | -------------------------------------------------------------------------------- /pcl/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pcl/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from random import sample 4 | 5 | class MoCo(nn.Module): 6 | """ 7 | Build a MoCo model with: a query encoder, a key encoder, and a queue 8 | https://arxiv.org/abs/1911.05722 9 | """ 10 | def __init__(self, base_encoder, dim=128, r=16384, m=0.999, T=0.1, mlp=False): 11 | """ 12 | dim: feature dimension (default: 128) 13 | r: queue size; number of negative samples/prototypes (default: 16384) 14 | m: momentum for updating key encoder (default: 0.999) 15 | T: softmax temperature 16 | mlp: whether to use mlp projection 17 | """ 18 | super(MoCo, self).__init__() 19 | 20 | self.r = r 21 | self.m = m 22 | self.T = T 23 | 24 | # create the encoders 25 | # num_classes is the output fc dimension 26 | self.encoder_q = base_encoder(num_classes=dim) 27 | self.encoder_k = base_encoder(num_classes=dim) 28 | 29 | if mlp: # hack: brute-force replacement 30 | dim_mlp = self.encoder_q.fc.weight.shape[1] 31 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 32 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 33 | 34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 35 | param_k.data.copy_(param_q.data) # initialize 36 | param_k.requires_grad = False # not update by gradient 37 | 38 | # create the queue 39 | self.register_buffer("queue", torch.randn(dim, r)) 40 | self.queue = nn.functional.normalize(self.queue, dim=0) 41 | 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | @torch.no_grad() 45 | def _momentum_update_key_encoder(self): 46 | """ 47 | Momentum update of the key encoder 48 | """ 49 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 50 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 51 | 52 | @torch.no_grad() 53 | def _dequeue_and_enqueue(self, keys): 54 | # gather keys before updating queue 55 | keys = concat_all_gather(keys) 56 | 57 | batch_size = keys.shape[0] 58 | 59 | ptr = int(self.queue_ptr) 60 | assert self.r % batch_size == 0 # for simplicity 61 | 62 | # replace the keys at ptr (dequeue and enqueue) 63 | self.queue[:, ptr:ptr + batch_size] = keys.T 64 | ptr = (ptr + batch_size) % self.r # move pointer 65 | 66 | self.queue_ptr[0] = ptr 67 | 68 | @torch.no_grad() 69 | def _batch_shuffle_ddp(self, x): 70 | """ 71 | Batch shuffle, for making use of BatchNorm. 72 | *** Only support DistributedDataParallel (DDP) model. *** 73 | """ 74 | # gather from all gpus 75 | batch_size_this = x.shape[0] 76 | x_gather = concat_all_gather(x) 77 | batch_size_all = x_gather.shape[0] 78 | 79 | num_gpus = batch_size_all // batch_size_this 80 | 81 | # random shuffle index 82 | idx_shuffle = torch.randperm(batch_size_all).cuda() 83 | 84 | # broadcast to all gpus 85 | torch.distributed.broadcast(idx_shuffle, src=0) 86 | 87 | # index for restoring 88 | idx_unshuffle = torch.argsort(idx_shuffle) 89 | 90 | # shuffled index for this gpu 91 | gpu_idx = torch.distributed.get_rank() 92 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 93 | 94 | return x_gather[idx_this], idx_unshuffle 95 | 96 | @torch.no_grad() 97 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 98 | """ 99 | Undo batch shuffle. 100 | *** Only support DistributedDataParallel (DDP) model. *** 101 | """ 102 | # gather from all gpus 103 | batch_size_this = x.shape[0] 104 | x_gather = concat_all_gather(x) 105 | batch_size_all = x_gather.shape[0] 106 | 107 | num_gpus = batch_size_all // batch_size_this 108 | 109 | # restored index for this gpu 110 | gpu_idx = torch.distributed.get_rank() 111 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 112 | 113 | return x_gather[idx_this] 114 | 115 | def forward(self, im_q, im_k=None, is_eval=False, cluster_result=None, index=None): 116 | """ 117 | Input: 118 | im_q: a batch of query images 119 | im_k: a batch of key images 120 | is_eval: return momentum embeddings (used for clustering) 121 | cluster_result: cluster assignments, centroids, and density 122 | index: indices for training samples 123 | Output: 124 | logits, targets, proto_logits, proto_targets 125 | """ 126 | 127 | if is_eval: 128 | k = self.encoder_k(im_q) 129 | k = nn.functional.normalize(k, dim=1) 130 | return k 131 | 132 | # compute key features 133 | with torch.no_grad(): # no gradient to keys 134 | self._momentum_update_key_encoder() # update the key encoder 135 | 136 | # shuffle for making use of BN 137 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) 138 | 139 | k = self.encoder_k(im_k) # keys: NxC 140 | k = nn.functional.normalize(k, dim=1) 141 | 142 | # undo shuffle 143 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 144 | 145 | # compute query features 146 | q = self.encoder_q(im_q) # queries: NxC 147 | q = nn.functional.normalize(q, dim=1) 148 | 149 | # compute logits 150 | # Einstein sum is more intuitive 151 | # positive logits: Nx1 152 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 153 | # negative logits: Nxr 154 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 155 | 156 | # logits: Nx(1+r) 157 | logits = torch.cat([l_pos, l_neg], dim=1) 158 | 159 | # apply temperature 160 | logits /= self.T 161 | 162 | # labels: positive key indicators 163 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 164 | 165 | # dequeue and enqueue 166 | self._dequeue_and_enqueue(k) 167 | 168 | # prototypical contrast 169 | if cluster_result is not None: 170 | proto_labels = [] 171 | proto_logits = [] 172 | for n, (im2cluster,prototypes,density) in enumerate(zip(cluster_result['im2cluster'],cluster_result['centroids'],cluster_result['density'])): 173 | # get positive prototypes 174 | pos_proto_id = im2cluster[index] 175 | pos_prototypes = prototypes[pos_proto_id] 176 | 177 | # sample negative prototypes 178 | all_proto_id = [i for i in range(im2cluster.max()+1)] 179 | neg_proto_id = set(all_proto_id)-set(pos_proto_id.tolist()) 180 | neg_proto_id = sample(neg_proto_id,self.r) #sample r negative prototypes 181 | neg_prototypes = prototypes[neg_proto_id] 182 | 183 | proto_selected = torch.cat([pos_prototypes,neg_prototypes],dim=0) 184 | 185 | # compute prototypical logits 186 | logits_proto = torch.mm(q,proto_selected.t()) 187 | 188 | # targets for prototype assignment 189 | labels_proto = torch.linspace(0, q.size(0)-1, steps=q.size(0)).long().cuda() 190 | 191 | # scaling temperatures for the selected prototypes 192 | temp_proto = density[torch.cat([pos_proto_id,torch.LongTensor(neg_proto_id).cuda()],dim=0)] 193 | logits_proto /= temp_proto 194 | 195 | proto_labels.append(labels_proto) 196 | proto_logits.append(logits_proto) 197 | return logits, labels, proto_logits, proto_labels 198 | else: 199 | return logits, labels, None, None 200 | 201 | 202 | # utils 203 | @torch.no_grad() 204 | def concat_all_gather(tensor): 205 | """ 206 | Performs all_gather operation on the provided tensors. 207 | *** Warning ***: torch.distributed.all_gather has no gradient. 208 | """ 209 | tensors_gather = [torch.ones_like(tensor) 210 | for _ in range(torch.distributed.get_world_size())] 211 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 212 | 213 | output = torch.cat(tensors_gather, dim=0) 214 | return output 215 | -------------------------------------------------------------------------------- /pcl/loader.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | import random 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class TwoCropsTransform: 7 | """Take two random crops of one image as the query and key.""" 8 | 9 | def __init__(self, base_transform): 10 | self.base_transform = base_transform 11 | 12 | def __call__(self, x): 13 | q = self.base_transform(x) 14 | k = self.base_transform(x) 15 | return [q, k] 16 | 17 | 18 | class GaussianBlur(object): 19 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 20 | 21 | def __init__(self, sigma=[.1, 2.]): 22 | self.sigma = sigma 23 | 24 | def __call__(self, x): 25 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 26 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 27 | return x 28 | 29 | 30 | class ImageFolderInstance(datasets.ImageFolder): 31 | def __getitem__(self, index): 32 | path, target = self.samples[index] 33 | sample = self.loader(path) 34 | if self.transform is not None: 35 | sample = self.transform(sample) 36 | return sample, index -------------------------------------------------------------------------------- /voc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import csv 3 | import os 4 | import os.path 5 | import tarfile 6 | from six.moves.urllib.parse import urlparse 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data as data 11 | from PIL import Image 12 | import random 13 | 14 | from tqdm import tqdm 15 | from six.moves.urllib.request import urlretrieve 16 | 17 | object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 18 | 'bottle', 'bus', 'car', 'cat', 'chair', 19 | 'cow', 'diningtable', 'dog', 'horse', 20 | 'motorbike', 'person', 'pottedplant', 21 | 'sheep', 'sofa', 'train', 'tvmonitor'] 22 | 23 | urls = { 24 | 'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar', 25 | 'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 26 | 'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 27 | 'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar', 28 | } 29 | 30 | 31 | def read_image_label(file): 32 | print('[dataset] read ' + file) 33 | data = dict() 34 | with open(file, 'r') as f: 35 | for line in f: 36 | tmp = line.split(' ') 37 | name = tmp[0] 38 | label = int(tmp[-1]) 39 | data[name] = label 40 | # data.append([name, label]) 41 | # print('%s %d' % (name, label)) 42 | return data 43 | 44 | 45 | def read_object_labels(root, dataset, set): 46 | path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 47 | labeled_data = dict() 48 | num_classes = len(object_categories) 49 | 50 | for i in range(num_classes): 51 | file = os.path.join(path_labels, object_categories[i] + '_' + set + '.txt') 52 | data = read_image_label(file) 53 | 54 | if i == 0: 55 | for (name, label) in data.items(): 56 | labels = np.zeros(num_classes) 57 | labels[i] = label 58 | labeled_data[name] = labels 59 | else: 60 | for (name, label) in data.items(): 61 | labeled_data[name][i] = label 62 | 63 | return labeled_data 64 | 65 | 66 | def write_object_labels_csv(file, labeled_data): 67 | # write a csv file 68 | print('[dataset] write file %s' % file) 69 | with open(file, 'w') as csvfile: 70 | fieldnames = ['name'] 71 | fieldnames.extend(object_categories) 72 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 73 | 74 | writer.writeheader() 75 | for (name, labels) in labeled_data.items(): 76 | example = {'name': name} 77 | for i in range(20): 78 | example[fieldnames[i + 1]] = int(labels[i]) 79 | writer.writerow(example) 80 | 81 | csvfile.close() 82 | 83 | 84 | def read_object_labels_csv(file, header=True): 85 | images = [] 86 | num_categories = 0 87 | print('[dataset] read', file) 88 | with open(file, 'r') as f: 89 | reader = csv.reader(f) 90 | rownum = 0 91 | for row in reader: 92 | if header and rownum == 0: 93 | header = row 94 | else: 95 | if num_categories == 0: 96 | num_categories = len(row) - 1 97 | name = row[0] 98 | labels = (np.asarray(row[1:num_categories + 1])).astype(np.float32) 99 | labels = torch.from_numpy(labels) 100 | item = (name, labels) 101 | images.append(item) 102 | rownum += 1 103 | return images 104 | 105 | 106 | def find_images_classification(root, dataset, set): 107 | path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 108 | images = [] 109 | file = os.path.join(path_labels, set + '.txt') 110 | with open(file, 'r') as f: 111 | for line in f: 112 | images.append(line) 113 | return images 114 | 115 | 116 | def download_voc2007(root): 117 | path_devkit = os.path.join(root, 'VOCdevkit') 118 | path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') 119 | tmpdir = os.path.join(root, 'tmp') 120 | 121 | # create directory 122 | if not os.path.exists(root): 123 | os.makedirs(root) 124 | 125 | if not os.path.exists(path_devkit): 126 | 127 | if not os.path.exists(tmpdir): 128 | os.makedirs(tmpdir) 129 | 130 | parts = urlparse(urls['devkit']) 131 | filename = os.path.basename(parts.path) 132 | cached_file = os.path.join(tmpdir, filename) 133 | 134 | if not os.path.exists(cached_file): 135 | print('Downloading: "{}" to {}\n'.format(urls['devkit'], cached_file)) 136 | download_url(urls['devkit'], cached_file) 137 | 138 | # extract file 139 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 140 | cwd = os.getcwd() 141 | tar = tarfile.open(cached_file, "r") 142 | os.chdir(root) 143 | tar.extractall() 144 | tar.close() 145 | os.chdir(cwd) 146 | print('[dataset] Done!') 147 | 148 | # train/val images/annotations 149 | if not os.path.exists(path_images): 150 | 151 | # download train/val images/annotations 152 | parts = urlparse(urls['trainval_2007']) 153 | filename = os.path.basename(parts.path) 154 | cached_file = os.path.join(tmpdir, filename) 155 | 156 | if not os.path.exists(cached_file): 157 | print('Downloading: "{}" to {}\n'.format(urls['trainval_2007'], cached_file)) 158 | download_url(urls['trainval_2007'], cached_file) 159 | 160 | # extract file 161 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 162 | cwd = os.getcwd() 163 | tar = tarfile.open(cached_file, "r") 164 | os.chdir(root) 165 | tar.extractall() 166 | tar.close() 167 | os.chdir(cwd) 168 | print('[dataset] Done!') 169 | 170 | # test annotations 171 | test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt') 172 | if not os.path.exists(test_anno): 173 | 174 | # download test annotations 175 | parts = urlparse(urls['test_images_2007']) 176 | filename = os.path.basename(parts.path) 177 | cached_file = os.path.join(tmpdir, filename) 178 | 179 | if not os.path.exists(cached_file): 180 | print('Downloading: "{}" to {}\n'.format(urls['test_images_2007'], cached_file)) 181 | download_url(urls['test_images_2007'], cached_file) 182 | 183 | # extract file 184 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 185 | cwd = os.getcwd() 186 | tar = tarfile.open(cached_file, "r") 187 | os.chdir(root) 188 | tar.extractall() 189 | tar.close() 190 | os.chdir(cwd) 191 | print('[dataset] Done!') 192 | 193 | # test images 194 | test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg') 195 | if not os.path.exists(test_image): 196 | 197 | # download test images 198 | parts = urlparse(urls['test_anno_2007']) 199 | filename = os.path.basename(parts.path) 200 | cached_file = os.path.join(tmpdir, filename) 201 | 202 | if not os.path.exists(cached_file): 203 | print('Downloading: "{}" to {}\n'.format(urls['test_anno_2007'], cached_file)) 204 | download_url(urls['test_anno_2007'], cached_file) 205 | 206 | # extract file 207 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 208 | cwd = os.getcwd() 209 | tar = tarfile.open(cached_file, "r") 210 | os.chdir(root) 211 | tar.extractall() 212 | tar.close() 213 | os.chdir(cwd) 214 | print('[dataset] Done!') 215 | 216 | def download_url(url, destination=None, progress_bar=True): 217 | """Download a URL to a local file. 218 | Parameters 219 | ---------- 220 | url : str 221 | The URL to download. 222 | destination : str, None 223 | The destination of the file. If None is given the file is saved to a temporary directory. 224 | progress_bar : bool 225 | Whether to show a command-line progress bar while downloading. 226 | Returns 227 | ------- 228 | filename : str 229 | The location of the downloaded file. 230 | Notes 231 | ----- 232 | Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm 233 | """ 234 | 235 | def my_hook(t): 236 | last_b = [0] 237 | 238 | def inner(b=1, bsize=1, tsize=None): 239 | if tsize is not None: 240 | t.total = tsize 241 | if b > 0: 242 | t.update((b - last_b[0]) * bsize) 243 | last_b[0] = b 244 | 245 | return inner 246 | 247 | if progress_bar: 248 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t: 249 | filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t)) 250 | else: 251 | filename, _ = urlretrieve(url, filename=destination) 252 | 253 | class Voc2007Classification(data.Dataset): 254 | 255 | def __init__(self, root, set, transform=None, target_transform=None): 256 | self.root = root 257 | self.path_devkit = os.path.join(root, 'VOCdevkit') 258 | self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') 259 | self.set = set 260 | self.transform = transform 261 | self.target_transform = target_transform 262 | self.low_shot = False 263 | 264 | # download dataset 265 | download_voc2007(self.root) 266 | 267 | # define path of csv file 268 | path_csv = os.path.join(self.root, 'files', 'VOC2007') 269 | # define filename of csv file 270 | file_csv = os.path.join(path_csv, 'classification_' + set + '.csv') 271 | 272 | # create the csv file if necessary 273 | if not os.path.exists(file_csv): 274 | if not os.path.exists(path_csv): # create dir if necessary 275 | os.makedirs(path_csv) 276 | # generate csv file 277 | labeled_data = read_object_labels(self.root, 'VOC2007', self.set) 278 | # write csv file 279 | write_object_labels_csv(file_csv, labeled_data) 280 | 281 | self.classes = object_categories 282 | self.images = read_object_labels_csv(file_csv) 283 | 284 | print('[dataset] VOC 2007 classification set=%s number of classes=%d number of images=%d' % ( 285 | set, len(self.classes), len(self.images))) 286 | 287 | def __getitem__(self, index): 288 | if self.low_shot: 289 | path, target = self.images_lowshot[index] 290 | else: 291 | path, target = self.images[index] 292 | img = Image.open(os.path.join(self.path_images, path + '.jpg')).convert('RGB') 293 | if self.transform is not None: 294 | img = self.transform(img) 295 | if self.target_transform is not None: 296 | target = self.target_transform(target) 297 | return img, target 298 | 299 | def __len__(self): 300 | if self.low_shot: 301 | return len(self.images_lowshot) 302 | else: 303 | return len(self.images) 304 | 305 | def get_number_classes(self): 306 | return len(self.classes) 307 | 308 | 309 | def convert_low_shot(self, k): #sample k images per class 310 | label2img = {c:[] for c in range(len(self.classes))} 311 | for img in self.images: 312 | label = img[1] 313 | label_classes = torch.where(label>0)[0] 314 | for c in label_classes: 315 | label2img[c.item()].append(img) 316 | 317 | self.images_lowshot = [] 318 | for c,imlist in label2img.items(): 319 | random.shuffle(imlist) 320 | self.images_lowshot += imlist[:k] 321 | self.low_shot = True --------------------------------------------------------------------------------