├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── Train_CoMatch.py
├── Train_fixmatch.py
├── WideResNet.py
├── comatch.gif
├── datasets
│   ├── __init__.py
│   ├── cifar.py
│   ├── randaugment.py
│   ├── sampler.py
│   └── transform.py
├── imagenet
│   ├── Model.py
│   ├── README.md
│   ├── Train_CoMatch.py
│   ├── loader.py
│   └── resnet.py
└── utils.py
/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 
32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 
100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CoMatch: Semi-supervised Learning with Contrastive Graph Regularization, ICCV 2021 (Salesforce Research). 2 | 3 | 4 | This is a PyTorch implementation of the CoMatch paper [Blog]: 5 |
 6 | @inproceedings{CoMatch,
 7 | 	title={CoMatch: Semi-supervised Learning with Contrastive Graph Regularization},
 8 | 	author={Junnan Li and Caiming Xiong and Steven C.H. Hoi},
 9 | 	booktitle={ICCV},
10 | 	year={2021}
11 | }
12 | 13 | ### Requirements: 14 | * PyTorch ≥ 1.4 15 | * pip install tensorboard_logger 16 | * download and extract the CIFAR-10 dataset into ./data/ 17 | 18 | To perform semi-supervised learning on CIFAR-10 with 4 labels per class, run: 19 |
python Train_CoMatch.py --n-labeled 40 --seed 1 
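# the FixMatch baseline included in this repo accepts the same arguments:
python Train_fixmatch.py --n-labeled 40 --seed 1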
20 | 21 | The results using different random seeds are: 22 | 23 | seed| 1 | 2 | 3 | 4 | 5 | avg 24 | --- | --- | --- | --- | --- | --- | --- 25 | accuracy|93.71|94.10|92.93|90.73|93.97|93.09 26 | 27 | ### ImageNet 28 | For ImageNet experiments, see ./imagenet/ 29 | 30 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as possible, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /Train_CoMatch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2018, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | ''' 7 | from __future__ import print_function 8 | import random 9 | 10 | import time 11 | import argparse 12 | import os 13 | import sys 14 | 15 | import numpy as np 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from WideResNet import WideResnet 22 | from datasets.cifar import get_train_loader, get_val_loader 23 | from utils import accuracy, setup_default_logging, AverageMeter, WarmupCosineLrScheduler 24 | 25 | import tensorboard_logger 26 | 27 | def set_model(args): 28 | model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=True) 29 | if args.checkpoint: 30 | checkpoint = torch.load(args.checkpoint) 31 | msg = model.load_state_dict(checkpoint, strict=False) 32 | assert set(msg.missing_keys) == {"classifier.weight", "classifier.bias"} 33 | print('loaded from checkpoint: %s'%args.checkpoint) 34 | model.train() 35 | model.cuda() 36 | 37 | if args.eval_ema: 38 | ema_model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=True) 39 | for param_q, param_k in zip(model.parameters(), ema_model.parameters()): 40 | param_k.data.copy_(param_q.detach().data) # initialize 41 | param_k.requires_grad = False # not updated by gradient for eval_net 42 | ema_model.cuda() 43 | ema_model.eval() 44 | else: 45 | ema_model = None 46 | 47 | criteria_x = nn.CrossEntropyLoss().cuda() 48 | return model, criteria_x, ema_model 49 | 50 | @torch.no_grad() 51 | def ema_model_update(model, ema_model, ema_m): 52 | """ 53 | Momentum update of evaluation model (exponential moving average) 54 | """ 55 | for param_train, param_eval in zip(model.parameters(), ema_model.parameters()): 56 | param_eval.copy_(param_eval * ema_m + param_train.detach() * (1-ema_m)) 57 | 58 | for buffer_train, buffer_eval in zip(model.buffers(), ema_model.buffers()): 59 | buffer_eval.copy_(buffer_train) 60 | 61 | def train_one_epoch(epoch, 62 | model, 63 | ema_model, 64 | prob_list, 65 | criteria_x, 66 | optim, 67 | lr_schdlr, 68 | dltrain_x, 69 | dltrain_u, 70 | args, 71 | n_iters, 72 | logger, 73 | queue_feats, 74 | queue_probs, 75 | queue_ptr, 76 | ): 77 | 78 | model.train() 79 | loss_x_meter = AverageMeter() 80 | loss_u_meter = AverageMeter() 81 | 
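    # Editor's note -- a minimal numeric sketch of ema_model_update() above
    # (illustrative comment only, not part of the original code): with
    # ema_m = 0.999, each evaluation parameter moves 0.1% of the way toward the
    # current training weights per iteration,
    #     param_eval <- 0.999 * param_eval + 0.001 * param_train
    # so the EMA model is a slowly trailing average of recent training snapshots.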
loss_contrast_meter = AverageMeter() 82 | # the number of correct pseudo-labels 83 | n_correct_u_lbs_meter = AverageMeter() 84 | # the number of confident unlabeled data 85 | n_strong_aug_meter = AverageMeter() 86 | mask_meter = AverageMeter() 87 | # the number of edges in the pseudo-label graph 88 | pos_meter = AverageMeter() 89 | 90 | epoch_start = time.time() # start time 91 | dl_x, dl_u = iter(dltrain_x), iter(dltrain_u) 92 | for it in range(n_iters): 93 | ims_x_weak, lbs_x = next(dl_x) 94 | (ims_u_weak, ims_u_strong0, ims_u_strong1), lbs_u_real = next(dl_u) 95 | 96 | lbs_x = lbs_x.cuda() 97 | lbs_u_real = lbs_u_real.cuda() 98 | 99 | # -------------------------------------- 100 | bt = ims_x_weak.size(0) 101 | btu = ims_u_weak.size(0) 102 | 103 | imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong0, ims_u_strong1], dim=0).cuda() 104 | logits, features = model(imgs) 105 | 106 | logits_x = logits[:bt] 107 | logits_u_w, logits_u_s0, logits_u_s1 = torch.split(logits[bt:], btu) 108 | 109 | feats_x = features[:bt] 110 | feats_u_w, feats_u_s0, feats_u_s1 = torch.split(features[bt:], btu) 111 | 112 | loss_x = criteria_x(logits_x, lbs_x) 113 | 114 | with torch.no_grad(): 115 | logits_u_w = logits_u_w.detach() 116 | feats_x = feats_x.detach() 117 | feats_u_w = feats_u_w.detach() 118 | 119 | probs = torch.softmax(logits_u_w, dim=1) 120 | # DA 121 | prob_list.append(probs.mean(0)) 122 | if len(prob_list)>32: 123 | prob_list.pop(0) 124 | prob_avg = torch.stack(prob_list,dim=0).mean(0) 125 | probs = probs / prob_avg 126 | probs = probs / probs.sum(dim=1, keepdim=True) 127 | 128 | probs_orig = probs.clone() 129 | 130 | if epoch>0 or it>args.queue_batch: # memory-smoothing 131 | A = torch.exp(torch.mm(feats_u_w, queue_feats.t())/args.temperature) 132 | A = A/A.sum(1,keepdim=True) 133 | probs = args.alpha*probs + (1-args.alpha)*torch.mm(A, queue_probs) 134 | 135 | scores, lbs_u_guess = torch.max(probs, dim=1) 136 | mask = scores.ge(args.thr).float() 137 | 138 | feats_w = torch.cat([feats_u_w,feats_x],dim=0) 139 | onehot = torch.zeros(bt,args.n_classes).cuda().scatter(1,lbs_x.view(-1,1),1) 140 | probs_w = torch.cat([probs_orig,onehot],dim=0) 141 | 142 | # update memory bank 143 | n = bt+btu 144 | queue_feats[queue_ptr:queue_ptr + n,:] = feats_w 145 | queue_probs[queue_ptr:queue_ptr + n,:] = probs_w 146 | queue_ptr = (queue_ptr+n)%args.queue_size 147 | 148 | 149 | # embedding similarity 150 | sim = torch.exp(torch.mm(feats_u_s0, feats_u_s1.t())/args.temperature) 151 | sim_probs = sim / sim.sum(1, keepdim=True) 152 | 153 | # pseudo-label graph with self-loop 154 | Q = torch.mm(probs, probs.t()) 155 | Q.fill_diagonal_(1) 156 | pos_mask = (Q>=args.contrast_th).float() 157 | 158 | Q = Q * pos_mask 159 | Q = Q / Q.sum(1, keepdim=True) 160 | 161 | # contrastive loss 162 | loss_contrast = - (torch.log(sim_probs + 1e-7) * Q).sum(1) 163 | loss_contrast = loss_contrast.mean() 164 | 165 | # unsupervised classification loss 166 | loss_u = - torch.sum((F.log_softmax(logits_u_s0,dim=1) * probs),dim=1) * mask 167 | loss_u = loss_u.mean() 168 | 169 | loss = loss_x + args.lam_u * loss_u + args.lam_c * loss_contrast 170 | 171 | optim.zero_grad() 172 | loss.backward() 173 | optim.step() 174 | lr_schdlr.step() 175 | 176 | if args.eval_ema: 177 | with torch.no_grad(): 178 | ema_model_update(model, ema_model, args.ema_m) 179 | 180 | loss_x_meter.update(loss_x.item()) 181 | loss_u_meter.update(loss_u.item()) 182 | loss_contrast_meter.update(loss_contrast.item()) 183 | mask_meter.update(mask.mean().item()) 184 | 
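        # Editor's note -- the three CoMatch ingredients above, summarized
        # (illustrative comments only, not part of the original code):
        #  1) distribution alignment (DA): probs is divided by a moving average of
        #     the model's mean prediction over the last 32 batches, then renormalized;
        #  2) memory smoothing: probs <- alpha * probs + (1 - alpha) * A @ queue_probs,
        #     where A softmax-weights memory-bank entries by embedding similarity
        #     (inner product / temperature), pulling each pseudo-label toward those
        #     of its nearest neighbors in the queue;
        #  3) pseudo-label graph: Q = probs @ probs.T with the diagonal set to 1 and
        #     entries below contrast_th zeroed, then row-normalized; the contrastive
        #     loss -(Q * log(sim_probs + 1e-7)).sum(1).mean() pulls the embedding
        #     similarity graph of the two strongly-augmented views toward Q.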
pos_meter.update(pos_mask.sum(1).float().mean().item()) 185 | 186 | corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask 187 | n_correct_u_lbs_meter.update(corr_u_lb.sum().item()) 188 | n_strong_aug_meter.update(mask.sum().item()) 189 | 190 | if (it + 1) % 64 == 0: 191 | t = time.time() - epoch_start 192 | 193 | lr_log = [pg['lr'] for pg in optim.param_groups] 194 | lr_log = sum(lr_log) / len(lr_log) 195 | 196 | logger.info("{}-x{}-s{}, {} | epoch:{}, iter: {}. loss_u: {:.3f}. loss_x: {:.3f}. loss_c: {:.3f}. " 197 | "n_correct_u: {:.2f}/{:.2f}. Mask:{:.3f}. num_pos: {:.1f}. LR: {:.3f}. Time: {:.2f}".format( 198 | args.dataset, args.n_labeled, args.seed, args.exp_dir, epoch, it + 1, loss_u_meter.avg, loss_x_meter.avg, loss_contrast_meter.avg, n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, pos_meter.avg, lr_log, t)) 199 | 200 | epoch_start = time.time() 201 | 202 | return loss_x_meter.avg, loss_u_meter.avg, loss_contrast_meter.avg, mask_meter.avg, pos_meter.avg, n_correct_u_lbs_meter.avg/n_strong_aug_meter.avg, queue_feats, queue_probs, queue_ptr, prob_list 203 | 204 | 205 | def evaluate(model, ema_model, dataloader): 206 | 207 | model.eval() 208 | 209 | top1_meter = AverageMeter() 210 | ema_top1_meter = AverageMeter() 211 | 212 | with torch.no_grad(): 213 | for ims, lbs in dataloader: 214 | ims = ims.cuda() 215 | lbs = lbs.cuda() 216 | 217 | logits, _ = model(ims) 218 | scores = torch.softmax(logits, dim=1) 219 | top1, top5 = accuracy(scores, lbs, (1, 5)) 220 | top1_meter.update(top1.item()) 221 | 222 | if ema_model is not None: 223 | logits, _ = ema_model(ims) 224 | scores = torch.softmax(logits, dim=1) 225 | top1, top5 = accuracy(scores, lbs, (1, 5)) 226 | ema_top1_meter.update(top1.item()) 227 | 228 | return top1_meter.avg, ema_top1_meter.avg 229 | 230 | 231 | def main(): 232 | parser = argparse.ArgumentParser(description='CoMatch CIFAR Training') 233 | parser.add_argument('--root', default='./data', type=str, help='dataset directory') 234 | parser.add_argument('--wresnet-k', default=2, type=int, 235 | help='width factor of wide resnet') 236 | parser.add_argument('--wresnet-n', default=28, type=int, 237 | help='depth of wide resnet') 238 | parser.add_argument('--dataset', type=str, default='CIFAR10', 239 | help='dataset name (CIFAR10 or CIFAR100)') 240 | parser.add_argument('--n-classes', type=int, default=10, 241 | help='number of classes in dataset') 242 | parser.add_argument('--n-labeled', type=int, default=40, 243 | help='number of labeled samples for training') 244 | parser.add_argument('--n-epoches', type=int, default=512, 245 | help='number of training epochs') 246 | parser.add_argument('--batchsize', type=int, default=64, 247 | help='train batch size of labeled samples') 248 | parser.add_argument('--mu', type=int, default=7, 249 | help='factor of train batch size of unlabeled samples') 250 | parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024, 251 | help='number of training images for each epoch') 252 | 253 | parser.add_argument('--eval-ema', default=True, help='whether to use ema model for evaluation') 254 | parser.add_argument('--ema-m', type=float, default=0.999) 255 | 256 | parser.add_argument('--lam-u', type=float, default=1., 257 | help='coefficient of unlabeled loss') 258 | parser.add_argument('--lr', type=float, default=0.03, 259 | help='learning rate for training') 260 | parser.add_argument('--weight-decay', type=float, default=5e-4, 261 | help='weight decay') 262 | parser.add_argument('--momentum', type=float, default=0.9, 263 | 
help='momentum for optimizer') 264 | parser.add_argument('--seed', type=int, default=1, 265 | help='seed for random behaviors, no seed if negative') 266 | 267 | parser.add_argument('--temperature', default=0.2, type=float, help='softmax temperature') 268 | parser.add_argument('--low-dim', type=int, default=64) 269 | parser.add_argument('--lam-c', type=float, default=1, 270 | help='coefficient of contrastive loss') 271 | parser.add_argument('--contrast-th', default=0.8, type=float, 272 | help='pseudo label graph threshold') 273 | parser.add_argument('--thr', type=float, default=0.95, 274 | help='pseudo label threshold') 275 | parser.add_argument('--alpha', type=float, default=0.9) 276 | parser.add_argument('--queue-batch', type=int, default=5, 277 | help='number of batches stored in memory bank') 278 | parser.add_argument('--exp-dir', default='CoMatch', type=str, help='experiment id') 279 | parser.add_argument('--checkpoint', default='', type=str, help='use pretrained model') 280 | 281 | args = parser.parse_args() 282 | 283 | logger, output_dir = setup_default_logging(args) 284 | logger.info(dict(args._get_kwargs())) 285 | 286 | tb_logger = tensorboard_logger.Logger(logdir=output_dir, flush_secs=2) 287 | 288 | if args.seed > 0: 289 | torch.manual_seed(args.seed) 290 | random.seed(args.seed) 291 | np.random.seed(args.seed) 292 | 293 | n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize # 1024 294 | n_iters_all = n_iters_per_epoch * args.n_epoches # 1024 * n_epoches 295 | 296 | logger.info("***** Running training *****") 297 | logger.info(f" Task = {args.dataset}@{args.n_labeled}") 298 | 299 | model, criteria_x, ema_model = set_model(args) 300 | logger.info("Total params: {:.2f}M".format( 301 | sum(p.numel() for p in model.parameters()) / 1e6)) 302 | 303 | dltrain_x, dltrain_u = get_train_loader( 304 | args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, root=args.root, method='comatch') 305 | dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2, root=args.root) 306 | 307 | wd_params, non_wd_params = [], [] 308 | for name, param in model.named_parameters(): 309 | if 'bn' in name: 310 | non_wd_params.append(param) 311 | else: 312 | wd_params.append(param) 313 | param_list = [ 314 | {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}] 315 | optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay, 316 | momentum=args.momentum, nesterov=True) 317 | 318 | lr_schdlr = WarmupCosineLrScheduler(optim, n_iters_all, warmup_iter=0) 319 | 320 | # memory bank 321 | args.queue_size = args.queue_batch*(args.mu+1)*args.batchsize 322 | queue_feats = torch.zeros(args.queue_size, args.low_dim).cuda() 323 | queue_probs = torch.zeros(args.queue_size, args.n_classes).cuda() 324 | queue_ptr = 0 325 | 326 | # for distribution alignment 327 | prob_list = [] 328 | 329 | train_args = dict( 330 | model=model, 331 | ema_model=ema_model, 332 | prob_list=prob_list, 333 | criteria_x=criteria_x, 334 | optim=optim, 335 | lr_schdlr=lr_schdlr, 336 | dltrain_x=dltrain_x, 337 | dltrain_u=dltrain_u, 338 | args=args, 339 | n_iters=n_iters_per_epoch, 340 | logger=logger 341 | ) 342 | 343 | best_acc = -1 344 | best_epoch = 0 345 | logger.info('-----------start training--------------') 346 | for epoch in range(args.n_epoches): 347 | 348 | loss_x, loss_u, loss_c, mask_mean, num_pos, guess_label_acc, queue_feats, queue_probs, queue_ptr, prob_list = \ 349 | train_one_epoch(epoch, **train_args, 
queue_feats=queue_feats,queue_probs=queue_probs,queue_ptr=queue_ptr) 350 | 351 | top1, ema_top1 = evaluate(model, ema_model, dlval) 352 | 353 | tb_logger.log_value('loss_x', loss_x, epoch) 354 | tb_logger.log_value('loss_u', loss_u, epoch) 355 | tb_logger.log_value('loss_c', loss_c, epoch) 356 | tb_logger.log_value('guess_label_acc', guess_label_acc, epoch) 357 | tb_logger.log_value('test_acc', top1, epoch) 358 | tb_logger.log_value('test_ema_acc', ema_top1, epoch) 359 | tb_logger.log_value('mask', mask_mean, epoch) 360 | tb_logger.log_value('num_pos', num_pos, epoch) 361 | 362 | if best_acc < top1: 363 | best_acc = top1 364 | best_epoch = epoch 365 | 366 | logger.info("Epoch {}. Acc: {:.4f}. Ema-Acc: {:.4f}. best_acc: {:.4f} in epoch{}". 367 | format(epoch, top1, ema_top1, best_acc, best_epoch)) 368 | 369 | if epoch%10==0: 370 | save_obj = { 371 | 'model': model.state_dict(), 372 | 'ema_model': ema_model.state_dict(), 373 | 'optimizer': optim.state_dict(), 374 | 'lr_scheduler': lr_schdlr.state_dict(), 375 | 'prob_list': prob_list, 376 | 'queue': {'queue_feats':queue_feats, 'queue_probs':queue_probs, 'queue_ptr':queue_ptr}, 377 | 'epoch': epoch, 378 | } 379 | torch.save(save_obj, os.path.join(output_dir, 'checkpoint_%02d.pth'%epoch)) 380 | 381 | if __name__ == '__main__': 382 | main() 383 | -------------------------------------------------------------------------------- /Train_fixmatch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import random 3 | 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from WideResNet import WideResnet 16 | from datasets.cifar import get_train_loader, get_val_loader 17 | 18 | from utils import accuracy, setup_default_logging, AverageMeter, WarmupCosineLrScheduler 19 | 20 | import tensorboard_logger 21 | 22 | def set_model(args): 23 | model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=False) 24 | 25 | if args.checkpoint: 26 | checkpoint = torch.load(args.checkpoint) 27 | 28 | msg = model.load_state_dict(checkpoint, strict=False) 29 | assert set(msg.missing_keys) == {"classifier.weight", "classifier.bias"} 30 | assert set(msg.unexpected_keys) == {'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'} 31 | print('loaded from checkpoint: %s'%args.checkpoint) 32 | 33 | model.train() 34 | model.cuda() 35 | criteria_x = nn.CrossEntropyLoss().cuda() 36 | criteria_u = nn.CrossEntropyLoss(reduction='none').cuda() 37 | 38 | if args.eval_ema: 39 | ema_model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=False) 40 | for param_q, param_k in zip(model.parameters(), ema_model.parameters()): 41 | param_k.data.copy_(param_q.detach().data) # initialize 42 | param_k.requires_grad = False # not update by gradient for eval_net 43 | ema_model.cuda() 44 | ema_model.eval() 45 | else: 46 | ema_model = None 47 | 48 | return model, criteria_x, criteria_u, ema_model 49 | 50 | 51 | @torch.no_grad() 52 | def ema_model_update(model, ema_model, ema_m): 53 | """ 54 | Momentum update of evaluation model (exponential moving average) 55 | """ 56 | for param_train, param_eval in zip(model.parameters(), ema_model.parameters()): 57 | param_eval.copy_(param_eval * ema_m + param_train.detach() * (1-ema_m)) 58 | 59 | for buffer_train, buffer_eval in zip(model.buffers(), ema_model.buffers()): 60 | 
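        # (editor's note) buffers -- e.g. BatchNorm running mean/var -- are copied
        # over directly rather than EMA-averaged, as in the CoMatch trainer above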
buffer_eval.copy_(buffer_train) 61 | 62 | 63 | def train_one_epoch(epoch, 64 | model, 65 | ema_model, 66 | criteria_x, 67 | criteria_u, 68 | optim, 69 | lr_schdlr, 70 | dltrain_x, 71 | dltrain_u, 72 | args, 73 | n_iters, 74 | logger, 75 | prob_list, 76 | ): 77 | model.train() 78 | loss_x_meter = AverageMeter() 79 | loss_u_meter = AverageMeter() 80 | # the number of correctly-predicted and gradient-considered unlabeled data 81 | n_correct_u_lbs_meter = AverageMeter() 82 | # the number of gradient-considered strong augmentation (logits above threshold) of unlabeled samples 83 | n_strong_aug_meter = AverageMeter() 84 | mask_meter = AverageMeter() 85 | 86 | 87 | epoch_start = time.time() # start time 88 | dl_x, dl_u = iter(dltrain_x), iter(dltrain_u) 89 | for it in range(n_iters): 90 | ims_x_weak, lbs_x = next(dl_x) 91 | (ims_u_weak, ims_u_strong), lbs_u_real = next(dl_u) 92 | 93 | lbs_x = lbs_x.cuda() 94 | lbs_u_real = lbs_u_real.cuda() 95 | 96 | # -------------------------------------- 97 | bt = ims_x_weak.size(0) 98 | mu = int(ims_u_weak.size(0) // bt) 99 | imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda() 100 | logits = model(imgs) 101 | 102 | logits_x = logits[:bt] 103 | logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu) 104 | 105 | loss_x = criteria_x(logits_x, lbs_x) 106 | 107 | with torch.no_grad(): 108 | probs = torch.softmax(logits_u_w, dim=1) 109 | 110 | if args.DA: 111 | prob_list.append(probs.mean(0)) 112 | if len(prob_list)>32: 113 | prob_list.pop(0) 114 | prob_avg = torch.stack(prob_list,dim=0).mean(0) 115 | probs = probs / prob_avg 116 | probs = probs / probs.sum(dim=1, keepdim=True) 117 | 118 | scores, lbs_u_guess = torch.max(probs, dim=1) 119 | mask = scores.ge(args.thr).float() 120 | 121 | probs = probs.detach() 122 | 123 | loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean() 124 | 125 | loss = loss_x + args.lam_u * loss_u 126 | 127 | optim.zero_grad() 128 | loss.backward() 129 | optim.step() 130 | lr_schdlr.step() 131 | 132 | if args.eval_ema: 133 | with torch.no_grad(): 134 | ema_model_update(model, ema_model, args.ema_m) 135 | 136 | loss_x_meter.update(loss_x.item()) 137 | loss_u_meter.update(loss_u.item()) 138 | mask_meter.update(mask.mean().item()) 139 | 140 | 141 | corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask 142 | n_correct_u_lbs_meter.update(corr_u_lb.sum().item()) 143 | n_strong_aug_meter.update(mask.sum().item()) 144 | 145 | if (it + 1) % 64 == 0: 146 | t = time.time() - epoch_start 147 | 148 | lr_log = [pg['lr'] for pg in optim.param_groups] 149 | lr_log = sum(lr_log) / len(lr_log) 150 | 151 | logger.info("{}-x{}-s{}, {} | epoch:{}, iter: {}. loss_u: {:.3f}. loss_x: {:.3f}. " 152 | "n_correct_u: {:.2f}/{:.2f}. Mask:{:.3f}. LR: {:.3f}. 
Time: {:.2f}".format( 153 | args.dataset, args.n_labeled, args.seed, args.exp_dir, epoch, it + 1, loss_u_meter.avg, loss_x_meter.avg, 154 | n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t)) 155 | 156 | epoch_start = time.time() 157 | 158 | return loss_x_meter.avg, loss_u_meter.avg, mask_meter.avg, n_correct_u_lbs_meter.avg/n_strong_aug_meter.avg, prob_list 159 | 160 | 161 | def evaluate(model, ema_model, dataloader, criterion): 162 | 163 | model.eval() 164 | 165 | top1_meter = AverageMeter() 166 | ema_top1_meter = AverageMeter() 167 | 168 | with torch.no_grad(): 169 | for ims, lbs in dataloader: 170 | ims = ims.cuda() 171 | lbs = lbs.cuda() 172 | 173 | logits = model(ims) 174 | loss = criterion(logits, lbs) 175 | scores = torch.softmax(logits, dim=1) 176 | top1, top5 = accuracy(scores, lbs, (1, 5)) 177 | top1_meter.update(top1.item()) 178 | 179 | if ema_model is not None: 180 | logits = ema_model(ims) 181 | loss = criterion(logits, lbs) 182 | scores = torch.softmax(logits, dim=1) 183 | top1, top5 = accuracy(scores, lbs, (1, 5)) 184 | ema_top1_meter.update(top1.item()) 185 | 186 | return top1_meter.avg, ema_top1_meter.avg 187 | 188 | 189 | def main(): 190 | parser = argparse.ArgumentParser(description='FixMatch Training') 191 | parser.add_argument('--root', default='./data', type=str, help='dataset directory') 192 | parser.add_argument('--wresnet-k', default=2, type=int, 193 | help='width factor of wide resnet') 194 | parser.add_argument('--wresnet-n', default=28, type=int, 195 | help='depth of wide resnet') 196 | parser.add_argument('--dataset', type=str, default='CIFAR10', 197 | help='dataset name (CIFAR10 or CIFAR100)') 198 | parser.add_argument('--n-classes', type=int, default=10, 199 | help='number of classes in dataset') 200 | parser.add_argument('--n-labeled', type=int, default=40, 201 | help='number of labeled samples for training') 202 | parser.add_argument('--n-epoches', type=int, default=1024, 203 | help='number of training epochs') 204 | parser.add_argument('--batchsize', type=int, default=64, 205 | help='train batch size of labeled samples') 206 | parser.add_argument('--mu', type=int, default=7, 207 | help='factor of train batch size of unlabeled samples') 208 | 209 | parser.add_argument('--eval-ema', default=True, help='whether to use ema model for evaluation') 210 | parser.add_argument('--ema-m', type=float, default=0.999) 211 | 212 | parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024, 213 | help='number of training images for each epoch') 214 | parser.add_argument('--lam-u', type=float, default=1., 215 | help='coefficient of unlabeled loss') 216 | parser.add_argument('--lr', type=float, default=0.03, 217 | help='learning rate for training') 218 | parser.add_argument('--weight-decay', type=float, default=5e-4, 219 | help='weight decay') 220 | parser.add_argument('--momentum', type=float, default=0.9, 221 | help='momentum for optimizer') 222 | parser.add_argument('--seed', type=int, default=1, 223 | help='seed for random behaviors, no seed if negative') 224 | parser.add_argument('--DA', default=True, help='use distribution alignment') 225 | 226 | parser.add_argument('--thr', type=float, default=0.95, 227 | help='pseudo label threshold') 228 | 229 | parser.add_argument('--exp-dir', default='FixMatch', type=str, help='experiment directory') 230 | parser.add_argument('--checkpoint', default='', type=str, help='use pretrained model') 231 | #/export/home/project/SimCLR/save_cifar_t0.2/checkpoint_100.tar 232 | 233 | args = 
parser.parse_args() 234 | 235 | logger, output_dir = setup_default_logging(args) 236 | logger.info(dict(args._get_kwargs())) 237 | 238 | tb_logger = tensorboard_logger.Logger(logdir=output_dir, flush_secs=2) 239 | 240 | if args.seed > 0: 241 | torch.manual_seed(args.seed) 242 | random.seed(args.seed) 243 | np.random.seed(args.seed) 244 | 245 | n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize # 1024 246 | n_iters_all = n_iters_per_epoch * args.n_epoches # 1024 * n_epoches 247 | 248 | logger.info("***** Running training *****") 249 | logger.info(f" Task = {args.dataset}@{args.n_labeled}") 250 | 251 | model, criteria_x, criteria_u, ema_model = set_model(args) 252 | logger.info("Total params: {:.2f}M".format( 253 | sum(p.numel() for p in model.parameters()) / 1e6)) 254 | 255 | dltrain_x, dltrain_u = get_train_loader( 256 | args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, root=args.root, method='fixmatch') 257 | dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2, root=args.root) 258 | 259 | wd_params, non_wd_params = [], [] 260 | for name, param in model.named_parameters(): 261 | if 'bn' in name: 262 | non_wd_params.append(param) 263 | else: 264 | wd_params.append(param) 265 | param_list = [ 266 | {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}] 267 | optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay, 268 | momentum=args.momentum, nesterov=True) 269 | 270 | lr_schdlr = WarmupCosineLrScheduler(optim, n_iters_all, warmup_iter=0) 271 | 272 | prob_list = [] 273 | train_args = dict( 274 | model=model, 275 | ema_model=ema_model, 276 | criteria_x=criteria_x, 277 | criteria_u=criteria_u, 278 | optim=optim, 279 | lr_schdlr=lr_schdlr, 280 | dltrain_x=dltrain_x, 281 | dltrain_u=dltrain_u, 282 | args=args, 283 | n_iters=n_iters_per_epoch, 284 | logger=logger, 285 | prob_list=prob_list 286 | ) 287 | best_acc = -1 288 | best_epoch = 0 289 | logger.info('-----------start training--------------') 290 | for epoch in range(args.n_epoches): 291 | loss_x, loss_u, mask_mean, guess_label_acc, prob_list = train_one_epoch(epoch, **train_args) 292 | 293 | top1, ema_top1 = evaluate(model, ema_model, dlval, criteria_x) 294 | 295 | tb_logger.log_value('loss_x', loss_x, epoch) 296 | tb_logger.log_value('loss_u', loss_u, epoch) 297 | tb_logger.log_value('guess_label_acc', guess_label_acc, epoch) 298 | tb_logger.log_value('test_acc', top1, epoch) 299 | tb_logger.log_value('test_ema_acc', ema_top1, epoch) 300 | tb_logger.log_value('mask', mask_mean, epoch) 301 | 302 | if best_acc < top1: 303 | best_acc = top1 304 | best_epoch = epoch 305 | 306 | logger.info("Epoch {}. Acc: {:.4f}. Ema-Acc: {:.4f}. best_acc: {:.4f} in epoch{}". 307 | format(epoch, top1, ema_top1, best_acc, best_epoch)) 308 | 309 | 310 | if __name__ == '__main__': 311 | main() 312 | -------------------------------------------------------------------------------- /WideResNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import torch.nn.functional as F 9 | 10 | from torch.nn import BatchNorm2d 11 | 12 | ''' 13 | As in the paper, this wide resnet uses only the pre-activation variant of the resnet, 14 | and only the basic blocks rather than the bottleneck blocks. 
15 | ''' 16 | 17 | 18 | class BasicBlockPreAct(nn.Module): 19 | def __init__(self, in_chan, out_chan, drop_rate=0, stride=1, pre_res_act=False): 20 | super(BasicBlockPreAct, self).__init__() 21 | self.bn1 = BatchNorm2d(in_chan, momentum=0.001) 22 | self.relu1 = nn.LeakyReLU(inplace=True, negative_slope=0.1) 23 | self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1, bias=False) 24 | self.bn2 = BatchNorm2d(out_chan, momentum=0.001) 25 | self.relu2 = nn.LeakyReLU(inplace=True, negative_slope=0.1) 26 | self.dropout = nn.Dropout(drop_rate) if not drop_rate == 0 else None 27 | self.conv2 = nn.Conv2d(out_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Conv2d( 31 | in_chan, out_chan, kernel_size=1, stride=stride, bias=False 32 | ) 33 | self.pre_res_act = pre_res_act 34 | # self.init_weight() 35 | 36 | def forward(self, x): 37 | bn1 = self.bn1(x) 38 | act1 = self.relu1(bn1) 39 | residual = self.conv1(act1) 40 | residual = self.bn2(residual) 41 | residual = self.relu2(residual) 42 | if self.dropout is not None: 43 | residual = self.dropout(residual) 44 | residual = self.conv2(residual) 45 | 46 | shortcut = act1 if self.pre_res_act else x 47 | if self.downsample is not None: 48 | shortcut = self.downsample(shortcut) 49 | 50 | out = shortcut + residual 51 | return out 52 | 53 | def init_weight(self): 54 | # for _, md in self.named_modules(): 55 | # if isinstance(md, nn.Conv2d): 56 | # nn.init.kaiming_normal_( 57 | # md.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 58 | # if md.bias is not None: 59 | # nn.init.constant_(md.bias, 0) 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 63 | if m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | 66 | 67 | class WideResnetBackbone(nn.Module): 68 | def __init__(self, k=1, n=28, drop_rate=0): 69 | super(WideResnetBackbone, self).__init__() 70 | self.k, self.n = k, n 71 | assert (self.n - 4) % 6 == 0 72 | n_blocks = (self.n - 4) // 6 73 | n_layers = [16, ] + [self.k * 16 * (2 ** i) for i in range(3)] 74 | 75 | self.conv1 = nn.Conv2d( 76 | 3, 77 | n_layers[0], 78 | kernel_size=3, 79 | stride=1, 80 | padding=1, 81 | bias=False 82 | ) 83 | self.layer1 = self.create_layer( 84 | n_layers[0], 85 | n_layers[1], 86 | bnum=n_blocks, 87 | stride=1, 88 | drop_rate=drop_rate, 89 | pre_res_act=True, 90 | ) 91 | self.layer2 = self.create_layer( 92 | n_layers[1], 93 | n_layers[2], 94 | bnum=n_blocks, 95 | stride=2, 96 | drop_rate=drop_rate, 97 | pre_res_act=False, 98 | ) 99 | self.layer3 = self.create_layer( 100 | n_layers[2], 101 | n_layers[3], 102 | bnum=n_blocks, 103 | stride=2, 104 | drop_rate=drop_rate, 105 | pre_res_act=False, 106 | ) 107 | self.bn_last = BatchNorm2d(n_layers[3], momentum=0.001) 108 | self.relu_last = nn.LeakyReLU(inplace=True, negative_slope=0.1) 109 | self.init_weight() 110 | 111 | def create_layer( 112 | self, 113 | in_chan, 114 | out_chan, 115 | bnum, 116 | stride=1, 117 | drop_rate=0, 118 | pre_res_act=False, 119 | ): 120 | layers = [ 121 | BasicBlockPreAct( 122 | in_chan, 123 | out_chan, 124 | drop_rate=drop_rate, 125 | stride=stride, 126 | pre_res_act=pre_res_act), ] 127 | for _ in range(bnum - 1): 128 | layers.append( 129 | BasicBlockPreAct( 130 | out_chan, 131 | out_chan, 132 | drop_rate=drop_rate, 133 | stride=1, 134 | pre_res_act=False, )) 135 | return nn.Sequential(*layers) 136 | 137 | 
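    # Editor's note (illustrative): with the CIFAR defaults n=28, k=2, each stage
    # stacks (28 - 4) // 6 = 4 basic blocks over channel widths [16, 32, 64, 128];
    # strides (1, 2, 2) shrink a 32x32 input to 8x8 before the final BN/LeakyReLU,
    # and global average pooling in WideResnet leaves a 64 * k = 128-dim feature
    # for the classifier and the projection head.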
def forward(self, x): 138 | feat = self.conv1(x) 139 | 140 | feat = self.layer1(feat) 141 | feat2 = self.layer2(feat) # 1/2 142 | feat4 = self.layer3(feat2) # 1/4 143 | 144 | feat4 = self.bn_last(feat4) 145 | feat4 = self.relu_last(feat4) 146 | return feat2, feat4 147 | 148 | def init_weight(self): 149 | # for _, child in self.named_children(): 150 | # if isinstance(child, nn.Conv2d): 151 | # n = child.kernel_size[0] * child.kernel_size[0] * child.out_channels 152 | # nn.init.normal_(child.weight, 0, 1. / ((0.5 * n) ** 0.5)) 153 | # # nn.init.kaiming_normal_( 154 | # # child.weight, a=0.1, mode='fan_out', 155 | # # nonlinearity='leaky_relu' 156 | # # ) 157 | # 158 | # if child.bias is not None: 159 | # nn.init.constant_(child.bias, 0) 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | 165 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 166 | if m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.BatchNorm2d): 169 | m.weight.data.fill_(1) 170 | m.bias.data.zero_() 171 | elif isinstance(m, nn.Linear): 172 | m.bias.data.zero_() 173 | 174 | class Normalize(nn.Module): 175 | 176 | def __init__(self, power=2): 177 | super(Normalize, self).__init__() 178 | self.power = power 179 | 180 | def forward(self, x): 181 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 182 | out = x.div(norm) 183 | return out 184 | 185 | class WideResnet(nn.Module): 186 | ''' 187 | for wide-resnet-28-10, the definition should be WideResnet(n_classes, 10, 28) 188 | ''' 189 | 190 | def __init__(self, n_classes, k=1, n=28, low_dim=64, proj=True): 191 | super(WideResnet, self).__init__() 192 | self.n_layers, self.k = n, k 193 | self.backbone = WideResnetBackbone(k=k, n=n) 194 | self.classifier = nn.Linear(64 * self.k, n_classes, bias=True) 195 | 196 | self.proj = proj 197 | if proj: 198 | self.l2norm = Normalize(2) 199 | 200 | self.fc1 = nn.Linear(64 * self.k, 64 * self.k) 201 | self.relu_mlp = nn.LeakyReLU(inplace=True, negative_slope=0.1) 202 | self.fc2 = nn.Linear(64 * self.k, low_dim) 203 | 204 | 205 | def forward(self, x): 206 | feat = self.backbone(x)[-1] 207 | feat = torch.mean(feat, dim=(2, 3)) 208 | out = self.classifier(feat) 209 | 210 | if self.proj: 211 | feat = self.fc1(feat) 212 | feat = self.relu_mlp(feat) 213 | feat = self.fc2(feat) 214 | 215 | feat = self.l2norm(feat) 216 | return out,feat 217 | else: 218 | return out 219 | 220 | def init_weight(self): 221 | nn.init.xavier_normal_(self.classifier.weight) 222 | if not self.classifier.bias is None: 223 | nn.init.constant_(self.classifier.bias, 0) 224 | 225 | 226 | if __name__ == "__main__": 227 | x = torch.randn(2, 3, 224, 224) 228 | lb = torch.randint(0, 10, (2,)).long() 229 | 230 | net = WideResnetBackbone() 231 | out = net(x) 232 | print(out[0].size()) 233 | del net, out 234 | 235 | net = WideResnet(n_classes=10) 236 | criteria = nn.CrossEntropyLoss() 237 | out = net(x) 238 | loss = criteria(out, lb) 239 | loss.backward() 240 | print(out.size()) -------------------------------------------------------------------------------- /comatch.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/CoMatch/a64ccf5af11017a1a07267b21e5899f4e8157801/comatch.gif -------------------------------------------------------------------------------- /datasets/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/CoMatch/a64ccf5af11017a1a07267b21e5899f4e8157801/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | from datasets import transform as T 10 | from datasets.randaugment import RandomAugment 11 | from datasets.sampler import RandomSampler, BatchSampler 12 | 13 | class TwoCropsTransform: 14 | """Take 2 random augmentations of one image.""" 15 | 16 | def __init__(self,trans_weak,trans_strong): 17 | self.trans_weak = trans_weak 18 | self.trans_strong = trans_strong 19 | def __call__(self, x): 20 | x1 = self.trans_weak(x) 21 | x2 = self.trans_strong(x) 22 | return [x1, x2] 23 | 24 | class ThreeCropsTransform: 25 | """Take 3 random augmentations of one image.""" 26 | 27 | def __init__(self,trans_weak,trans_strong0,trans_strong1): 28 | self.trans_weak = trans_weak 29 | self.trans_strong0 = trans_strong0 30 | self.trans_strong1 = trans_strong1 31 | def __call__(self, x): 32 | x1 = self.trans_weak(x) 33 | x2 = self.trans_strong0(x) 34 | x3 = self.trans_strong1(x) 35 | return [x1, x2, x3] 36 | 37 | def load_data_train(L=250, dataset='CIFAR10', dspth='./data'): 38 | if dataset == 'CIFAR10': 39 | datalist = [ 40 | osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1)) 41 | for i in range(5) 42 | ] 43 | n_class = 10 44 | assert L in [10, 20, 40, 80, 250, 4000] 45 | elif dataset == 'CIFAR100': 46 | datalist = [ 47 | osp.join(dspth, 'cifar-100-python', 'train')] 48 | n_class = 100 49 | assert L in [25, 400, 2500, 10000] 50 | 51 | data, labels = [], [] 52 | for data_batch in datalist: 53 | with open(data_batch, 'rb') as fr: 54 | entry = pickle.load(fr, encoding='latin1') 55 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels'] 56 | data.append(entry['data']) 57 | labels.append(lbs) 58 | data = np.concatenate(data, axis=0) 59 | labels = np.concatenate(labels, axis=0) 60 | n_labels = L // n_class 61 | data_x, label_x, data_u, label_u = [], [], [], [] 62 | for i in range(n_class): 63 | indices = np.where(labels == i)[0] 64 | np.random.shuffle(indices) 65 | inds_x, inds_u = indices[:n_labels], indices[n_labels:] 66 | data_x += [ 67 | data[i].reshape(3, 32, 32).transpose(1, 2, 0) 68 | for i in inds_x 69 | ] 70 | label_x += [labels[i] for i in inds_x] 71 | data_u += [ 72 | data[i].reshape(3, 32, 32).transpose(1, 2, 0) 73 | for i in inds_u 74 | ] 75 | label_u += [labels[i] for i in inds_u] 76 | return data_x, label_x, data_u, label_u 77 | 78 | 79 | def load_data_val(dataset, dspth='./data'): 80 | if dataset == 'CIFAR10': 81 | datalist = [ 82 | osp.join(dspth, 'cifar-10-batches-py', 'test_batch') 83 | ] 84 | elif dataset == 'CIFAR100': 85 | datalist = [ 86 | osp.join(dspth, 'cifar-100-python', 'test') 87 | ] 88 | 89 | data, labels = [], [] 90 | for data_batch in datalist: 91 | with open(data_batch, 'rb') as fr: 92 | entry = pickle.load(fr, encoding='latin1') 93 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels'] 94 | data.append(entry['data']) 95 | labels.append(lbs) 96 | data = np.concatenate(data, axis=0) 97 | labels = np.concatenate(labels, axis=0) 98 | data = [ 99 | el.reshape(3, 32, 32).transpose(1, 
2, 0) 100 | for el in data 101 | ] 102 | return data, labels 103 | 104 | 105 | def compute_mean_var(): 106 | data_x, label_x, data_u, label_u = load_data_train() 107 | data = data_x + data_u 108 | data = np.concatenate([el[None, ...] for el in data], axis=0) 109 | 110 | mean, var = [], [] 111 | for i in range(3): 112 | channel = (data[:, :, :, i].ravel() / 127.5) - 1 113 | # channel = (data[:, :, :, i].ravel() / 255) 114 | mean.append(np.mean(channel)) 115 | var.append(np.std(channel)) 116 | 117 | print('mean: ', mean) 118 | print('var: ', var) 119 | 120 | 121 | 122 | class Cifar(Dataset): 123 | def __init__(self, dataset, data, labels, mode): 124 | super(Cifar, self).__init__() 125 | self.data, self.labels = data, labels 126 | self.mode = mode 127 | assert len(self.data) == len(self.labels) 128 | if dataset == 'CIFAR10': 129 | mean, std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616) 130 | elif dataset == 'CIFAR100': 131 | mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) 132 | 133 | trans_weak = T.Compose([ 134 | T.Resize((32, 32)), 135 | T.PadandRandomCrop(border=4, cropsize=(32, 32)), 136 | T.RandomHorizontalFlip(p=0.5), 137 | T.Normalize(mean, std), 138 | T.ToTensor(), 139 | ]) 140 | trans_strong0 = T.Compose([ 141 | T.Resize((32, 32)), 142 | T.PadandRandomCrop(border=4, cropsize=(32, 32)), 143 | T.RandomHorizontalFlip(p=0.5), 144 | RandomAugment(2, 10), 145 | T.Normalize(mean, std), 146 | T.ToTensor(), 147 | ]) 148 | trans_strong1 = transforms.Compose([ 149 | transforms.ToPILImage(), 150 | transforms.RandomResizedCrop(32, scale=(0.2, 1.)), 151 | transforms.RandomHorizontalFlip(p=0.5), 152 | transforms.RandomApply([ 153 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 154 | ], p=0.8), 155 | transforms.RandomGrayscale(p=0.2), 156 | transforms.ToTensor(), 157 | transforms.Normalize(mean, std), 158 | ]) 159 | if self.mode == 'train_x': 160 | self.trans = trans_weak 161 | elif self.mode == 'train_u_comatch': 162 | self.trans = ThreeCropsTransform(trans_weak, trans_strong0, trans_strong1) 163 | elif self.mode == 'train_u_fixmatch': 164 | self.trans = TwoCropsTransform(trans_weak, trans_strong0) 165 | else: 166 | self.trans = T.Compose([ 167 | T.Resize((32, 32)), 168 | T.Normalize(mean, std), 169 | T.ToTensor(), 170 | ]) 171 | 172 | def __getitem__(self, idx): 173 | im, lb = self.data[idx], self.labels[idx] 174 | return self.trans(im), lb 175 | 176 | def __len__(self): 177 | leng = len(self.data) 178 | return leng 179 | 180 | 181 | def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, root='data', method='comatch'): 182 | data_x, label_x, data_u, label_u = load_data_train(L=L, dataset=dataset, dspth=root) 183 | 184 | ds_x = Cifar( 185 | dataset=dataset, 186 | data=data_x, 187 | labels=label_x, 188 | mode='train_x' 189 | ) # return an iter of num_samples length (all indices of samples) 190 | sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size) 191 | batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True) # yield a batch of samples one time 192 | dl_x = torch.utils.data.DataLoader( 193 | ds_x, 194 | batch_sampler=batch_sampler_x, 195 | num_workers=2, 196 | pin_memory=True 197 | ) 198 | ds_u = Cifar( 199 | dataset=dataset, 200 | data=data_u, 201 | labels=label_u, 202 | mode='train_u_%s'%method 203 | ) 204 | sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size) 205 | #sampler_u = RandomSampler(ds_u, replacement=False) 206 | batch_sampler_u = BatchSampler(sampler_u, 
batch_size * mu, drop_last=True) 207 | dl_u = torch.utils.data.DataLoader( 208 | ds_u, 209 | batch_sampler=batch_sampler_u, 210 | num_workers=2, 211 | pin_memory=True 212 | ) 213 | return dl_x, dl_u 214 | 215 | 216 | def get_val_loader(dataset, batch_size, num_workers, pin_memory=True, root='data'): 217 | data, labels = load_data_val(dataset, dspth=root) 218 | ds = Cifar( 219 | dataset=dataset, 220 | data=data, 221 | labels=labels, 222 | mode='test' 223 | ) 224 | dl = torch.utils.data.DataLoader( 225 | ds, 226 | shuffle=False, 227 | batch_size=batch_size, 228 | drop_last=False, 229 | num_workers=num_workers, 230 | pin_memory=pin_memory 231 | ) 232 | return dl 233 | 234 | 235 | -------------------------------------------------------------------------------- /datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.solarize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # 
np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Brightness 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences between this result and PIL are all on the 4 boundaries; the center 134 | areas are the same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 
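# Editor's note (illustrative): the *_level_to_args factories below map an integer
# magnitude in [0, MAX_LEVEL] to per-op argument tuples, e.g. the enhance ops use
# factor = (level / 10) * 1.8 + 0.1 (level 10 -> 1.9, level 0 -> 0.1), while the
# geometric ops scale the level and flip its sign with probability 0.5.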
### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | 317 | def get_random_ops(self): 318 | sampled_ops = np.random.choice(list(func_dict.keys()), self.N) 319 | return [(op, 0.5, self.M) for op in sampled_ops] 320 | 321 | def __call__(self, img): 322 | if self.isPIL: 323 | img = np.array(img) 324 | ops = 
self.get_random_ops() 325 | for name, prob, level in ops: 326 | if np.random.random() > prob: 327 | continue 328 | args = arg_dict[name](level) 329 | img = func_dict[name](img, *args) 330 | img = cutout_func(img, 16, replace_value) 331 | return img 332 | 333 | 334 | if __name__ == '__main__': 335 | a = RandomAugment() 336 | img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) # the table-lookup ops require a uint8 image, not float noise 337 | a(img) -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | _int_classes = int # torch._six was removed in newer PyTorch; its int_classes is just int on Python 3 3 | 4 | 5 | class Sampler(object): 6 | r"""Base class for all Samplers. 7 | 8 | Every Sampler subclass has to provide an :meth:`__iter__` method, providing a 9 | way to iterate over indices of dataset elements, and a :meth:`__len__` method 10 | that returns the length of the returned iterators. 11 | 12 | .. note:: The :meth:`__len__` method isn't strictly required by 13 | :class:`~torch.utils.data.DataLoader`, but is expected in any 14 | calculation involving the length of a :class:`~torch.utils.data.DataLoader`. 15 | """ 16 | 17 | def __init__(self, data_source): 18 | pass 19 | 20 | def __iter__(self): 21 | raise NotImplementedError 22 | 23 | # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 24 | # 25 | # Many times we have an abstract class representing a collection/iterable of 26 | # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally 27 | # implementing a `__len__` method. In such cases, we must make sure to not 28 | # provide a default implementation, because both straightforward default 29 | # implementations have their issues: 30 | # 31 | # + `return NotImplemented`: 32 | # Calling `len(subclass_instance)` raises: 33 | # TypeError: 'NotImplementedType' object cannot be interpreted as an integer 34 | # 35 | # + `raise NotImplementedError()`: 36 | # This prevents triggering some fallback behavior. E.g., the built-in 37 | # `list(X)` tries to call `len(X)` first, and executes a different code 38 | # path if the method is not found or `NotImplemented` is returned, while 39 | # raising a `NotImplementedError` will propagate and make the call 40 | # fail where it could have used `__iter__` to complete the call. 41 | # 42 | # Thus, the only two sensible things to do are 43 | # 44 | # + **not** provide a default `__len__`. 45 | # 46 | # + raise a `TypeError` instead, which is what Python uses when users call 47 | # a method that is not defined on an object. 48 | # (@ssnl verifies that this works on at least Python 3.7.) 49 | 50 | 51 | class SequentialSampler(Sampler): 52 | r"""Samples elements sequentially, always in the same order. 53 | 54 | Arguments: 55 | data_source (Dataset): dataset to sample from 56 | """ 57 | 58 | def __init__(self, data_source): 59 | self.data_source = data_source 60 | 61 | def __iter__(self): 62 | return iter(range(len(self.data_source))) 63 | 64 | def __len__(self): 65 | return len(self.data_source) 66 | 67 | 68 | class RandomSampler(Sampler): 69 | r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 70 | If with replacement, then user can specify :attr:`num_samples` to draw. 71 | 72 | Arguments: 73 | data_source (Dataset): dataset to sample from 74 | replacement (bool): samples are drawn with replacement if ``True``, default=``False`` 75 | num_samples (int): number of samples to draw, default=`len(dataset)`.
This argument 76 | is supposed to be specified only when `replacement` is ``True``. 77 | """ 78 | 79 | def __init__(self, data_source, replacement=False, num_samples=None): 80 | self.data_source = data_source 81 | self.replacement = replacement 82 | self._num_samples = num_samples 83 | 84 | if not isinstance(self.replacement, bool): 85 | raise ValueError("replacement should be a boolean value, but got " 86 | "replacement={}".format(self.replacement)) 87 | 88 | if self._num_samples is not None and not replacement: 89 | raise ValueError("With replacement=False, num_samples should not be specified, " 90 | "since a random permutation will be performed.") 91 | 92 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 93 | raise ValueError("num_samples should be a positive integer " 94 | "value, but got num_samples={}".format(self.num_samples)) 95 | 96 | @property 97 | def num_samples(self): 98 | # dataset size might change at runtime 99 | if self._num_samples is None: 100 | return len(self.data_source) 101 | return self._num_samples 102 | 103 | def __iter__(self): 104 | n = len(self.data_source) 105 | if self.replacement: 106 | n_repeats = self.num_samples // n 107 | n_remain = self.num_samples % n 108 | indices = [torch.randperm(n) for _ in range(n_repeats)] 109 | indices.append(torch.randperm(n)[:n_remain]) 110 | return iter(torch.cat(indices, dim=0).tolist()) 111 | return iter(torch.randperm(n).tolist()) 112 | 113 | def __len__(self): 114 | return self.num_samples 115 | 116 | 117 | class SubsetRandomSampler(Sampler): 118 | r"""Samples elements randomly from a given list of indices, without replacement. 119 | 120 | Arguments: 121 | indices (sequence): a sequence of indices 122 | """ 123 | 124 | def __init__(self, indices): 125 | self.indices = indices 126 | 127 | def __iter__(self): 128 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 129 | 130 | def __len__(self): 131 | return len(self.indices) 132 | 133 | 134 | class WeightedRandomSampler(Sampler): 135 | r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). 136 | 137 | Args: 138 | weights (sequence): a sequence of weights, not necessarily summing up to one 139 | num_samples (int): number of samples to draw 140 | replacement (bool): if ``True``, samples are drawn with replacement. 141 | If not, they are drawn without replacement, which means that when a 142 | sample index is drawn for a row, it cannot be drawn again for that row.
143 | 144 | Example: 145 | >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) 146 | [0, 0, 0, 1, 0] 147 | >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) 148 | [0, 1, 4, 3, 2] 149 | """ 150 | 151 | def __init__(self, weights, num_samples, replacement=True): 152 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \ 153 | num_samples <= 0: 154 | raise ValueError("num_samples should be a positive integer " 155 | "value, but got num_samples={}".format(num_samples)) 156 | if not isinstance(replacement, bool): 157 | raise ValueError("replacement should be a boolean value, but got " 158 | "replacement={}".format(replacement)) 159 | self.weights = torch.as_tensor(weights, dtype=torch.double) 160 | self.num_samples = num_samples 161 | self.replacement = replacement 162 | 163 | def __iter__(self): 164 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist()) 165 | 166 | def __len__(self): 167 | return self.num_samples 168 | 169 | 170 | class BatchSampler(Sampler): 171 | r"""Wraps another sampler to yield a mini-batch of indices. 172 | 173 | Args: 174 | sampler (Sampler): Base sampler. 175 | batch_size (int): Size of mini-batch. 176 | drop_last (bool): If ``True``, the sampler will drop the last batch if 177 | its size would be less than ``batch_size`` 178 | 179 | Example: 180 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 181 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 182 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 183 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 184 | """ 185 | 186 | def __init__(self, sampler, batch_size, drop_last): 187 | if not isinstance(sampler, Sampler): 188 | raise ValueError("sampler should be an instance of " 189 | "torch.utils.data.Sampler, but got sampler={}" 190 | .format(sampler)) 191 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 192 | batch_size <= 0: 193 | raise ValueError("batch_size should be a positive integer value, " 194 | "but got batch_size={}".format(batch_size)) 195 | if not isinstance(drop_last, bool): 196 | raise ValueError("drop_last should be a boolean value, but got " 197 | "drop_last={}".format(drop_last)) 198 | self.sampler = sampler 199 | self.batch_size = batch_size 200 | self.drop_last = drop_last 201 | 202 | def __iter__(self): 203 | batch = [] 204 | for idx in self.sampler: 205 | batch.append(idx) 206 | if len(batch) == self.batch_size: 207 | yield batch 208 | batch = [] 209 | if len(batch) > 0 and not self.drop_last: 210 | yield batch 211 | 212 | def __len__(self): 213 | if self.drop_last: 214 | return len(self.sampler) // self.batch_size 215 | else: 216 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 217 | -------------------------------------------------------------------------------- /datasets/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | 6 | class PadandRandomCrop(object): 7 | ''' 8 | Input tensor is expected to have shape of (H, W, 3) 9 | ''' 10 | def __init__(self, border=4, cropsize=(32, 32)): 11 | self.border = border 12 | self.cropsize = cropsize 13 | 14 | def __call__(self, im): 15 | borders = [(self.border, self.border), (self.border, self.border), (0, 0)] # input is (h, w, c) 16 | convas = np.pad(im, borders, mode='reflect') 17 | H, W, C = convas.shape 18 | h, w 
= self.cropsize 19 | dh, dw = max(0, H-h), max(0, W-w) 20 | sh, sw = np.random.randint(0, dh + 1), np.random.randint(0, dw + 1) # +1: include the maximal offset and avoid randint(0, 0) when no padding is left 21 | out = convas[sh:sh+h, sw:sw+w, :] 22 | return out 23 | 24 | 25 | class RandomHorizontalFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def __call__(self, im): 30 | if np.random.rand() < self.p: 31 | im = im[:, ::-1, :] 32 | return im 33 | 34 | 35 | class Resize(object): 36 | def __init__(self, size): 37 | self.size = size 38 | 39 | def __call__(self, im): 40 | im = cv2.resize(im, self.size) 41 | return im 42 | 43 | 44 | class Normalize(object): 45 | ''' 46 | Inputs are pixel values in the range [0, 255]; channel order is 'rgb' 47 | ''' 48 | def __init__(self, mean, std): 49 | self.mean = np.array(mean, np.float32).reshape(1, 1, -1) 50 | self.std = np.array(std, np.float32).reshape(1, 1, -1) 51 | 52 | def __call__(self, im): 53 | if len(im.shape) == 4: 54 | mean, std = self.mean[None, ...], self.std[None, ...] 55 | elif len(im.shape) == 3: 56 | mean, std = self.mean, self.std 57 | im = im.astype(np.float32) / 255. 58 | # im = (im.astype(np.float32) / 127.5) - 1 59 | im -= mean 60 | im /= std 61 | return im 62 | 63 | 64 | class ToTensor(object): 65 | def __init__(self): 66 | pass 67 | 68 | def __call__(self, im): 69 | if len(im.shape) == 4: 70 | return torch.from_numpy(im.transpose(0, 3, 1, 2)) 71 | elif len(im.shape) == 3: 72 | return torch.from_numpy(im.transpose(2, 0, 1)) 73 | 74 | 75 | class Compose(object): 76 | def __init__(self, ops): 77 | self.ops = ops 78 | 79 | def __call__(self, im): 80 | for op in self.ops: 81 | im = op(im) 82 | return im 83 | -------------------------------------------------------------------------------- /imagenet/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from random import sample 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class Model(nn.Module): 8 | 9 | def __init__(self, base_encoder, args, width): 10 | 11 | super(Model, self).__init__() 12 | 13 | self.K = args.K 14 | 15 | self.encoder = base_encoder(num_class=args.num_class,mlp=True,low_dim=args.low_dim,width=width) 16 | self.m_encoder = base_encoder(num_class=args.num_class,mlp=True,low_dim=args.low_dim,width=width) 17 | 18 | for param, param_m in zip(self.encoder.parameters(), self.m_encoder.parameters()): 19 | param_m.data.copy_(param.data) # initialize 20 | param_m.requires_grad = False # not updated by gradient 21 | 22 | # queue to store momentum features for strong augmentations 23 | self.register_buffer("queue_s", torch.randn(args.low_dim, self.K)) 24 | self.queue_s = F.normalize(self.queue_s, dim=0) 25 | self.register_buffer("queue_ptr_s", torch.zeros(1, dtype=torch.long)) 26 | # queue to store momentum probs for weak augmentations (unlabeled) 27 | self.register_buffer("probs_u", torch.zeros(args.num_class, self.K)) 28 | 29 | # queue (memory bank) to store momentum features and probs for weak augmentations (labeled and unlabeled) 30 | self.register_buffer("queue_w", torch.randn(args.low_dim, self.K)) 31 | self.register_buffer("queue_ptr_w", torch.zeros(1, dtype=torch.long)) 32 | self.register_buffer("probs_xu", torch.zeros(args.num_class, self.K)) 33 | 34 | # for distribution alignment 35 | self.hist_prob = [] 36 | 37 | @torch.no_grad() 38 | def _update_momentum_encoder(self,m): 39 | """ 40 | Update momentum encoder 41 | """ 42 | for param, param_m in zip(self.encoder.parameters(), self.m_encoder.parameters()): 43 | param_m.data =
param_m.data * m + param.data * (1. - m) 44 | 45 | @torch.no_grad() 46 | def _dequeue_and_enqueue(self, z, t, ws): 47 | z = concat_all_gather(z) 48 | t = concat_all_gather(t) 49 | 50 | batch_size = z.shape[0] 51 | 52 | if ws=='s': 53 | ptr = int(self.queue_ptr_s) 54 | if (ptr + batch_size) > self.K: 55 | batch_size = self.K-ptr 56 | z = z[:batch_size] 57 | t = t[:batch_size] 58 | # replace the samples at ptr (dequeue and enqueue) 59 | self.queue_s[:, ptr:ptr + batch_size] = z.T 60 | self.probs_u[:, ptr:ptr + batch_size] = t.T 61 | ptr = (ptr + batch_size) % self.K # move pointer 62 | self.queue_ptr_s[0] = ptr 63 | 64 | elif ws=='w': 65 | ptr = int(self.queue_ptr_w) 66 | if (ptr + batch_size) > self.K: 67 | batch_size = self.K-ptr 68 | z = z[:batch_size] 69 | t = t[:batch_size] 70 | # replace the samples at ptr (dequeue and enqueue) 71 | self.queue_w[:, ptr:ptr + batch_size] = z.T 72 | self.probs_xu[:, ptr:ptr + batch_size] = t.T 73 | ptr = (ptr + batch_size) % self.K # move pointer 74 | self.queue_ptr_w[0] = ptr 75 | 76 | @torch.no_grad() 77 | def _batch_shuffle_ddp(self, x): 78 | """ 79 | Batch shuffle, for making use of BatchNorm. 80 | *** Only support DistributedDataParallel (DDP) model. *** 81 | """ 82 | # gather from all gpus 83 | batch_size_this = x.shape[0] 84 | x_gather = concat_all_gather(x) 85 | batch_size_all = x_gather.shape[0] 86 | 87 | num_gpus = batch_size_all // batch_size_this 88 | 89 | # random shuffle index 90 | idx_shuffle = torch.randperm(batch_size_all).cuda() 91 | 92 | # broadcast to all gpus 93 | torch.distributed.broadcast(idx_shuffle, src=0) 94 | 95 | # index for restoring 96 | idx_unshuffle = torch.argsort(idx_shuffle) 97 | 98 | # shuffled index for this gpu 99 | gpu_idx = torch.distributed.get_rank() 100 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 101 | 102 | return x_gather[idx_this], idx_unshuffle 103 | 104 | @torch.no_grad() 105 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 106 | """ 107 | Undo batch shuffle. 108 | *** Only support DistributedDataParallel (DDP) model. 
*** 109 | """ 110 | # gather from all gpus 111 | batch_size_this = x.shape[0] 112 | x_gather = concat_all_gather(x) 113 | batch_size_all = x_gather.shape[0] 114 | 115 | num_gpus = batch_size_all // batch_size_this 116 | 117 | # restored index for this gpu 118 | gpu_idx = torch.distributed.get_rank() 119 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 120 | 121 | return x_gather[idx_this] 122 | 123 | def forward(self, args, labeled_batch, unlabeled_batch=None, is_eval=False, epoch=0): 124 | 125 | img_x = labeled_batch[0].cuda(args.gpu, non_blocking=True) 126 | labels_x = labeled_batch[1].cuda(args.gpu, non_blocking=True) 127 | 128 | if is_eval: 129 | outputs_x, _ = self.encoder(img_x) 130 | return outputs_x, labels_x 131 | 132 | btx = img_x.size(0) 133 | 134 | img_u_w = unlabeled_batch[0][0].cuda(args.gpu, non_blocking=True) 135 | img_u_s0 = unlabeled_batch[0][1].cuda(args.gpu, non_blocking=True) 136 | img_u_s1 = unlabeled_batch[0][2].cuda(args.gpu, non_blocking=True) 137 | 138 | btu = img_u_w.size(0) 139 | 140 | imgs = torch.cat([img_x, img_u_s0], dim=0) 141 | outputs, features = self.encoder(imgs) 142 | 143 | outputs_x = outputs[:btx] 144 | outputs_u_s0 = outputs[btx:] 145 | features_u_s0 = features[btx:] 146 | 147 | with torch.no_grad(): 148 | self._update_momentum_encoder(args.m) 149 | # forward through the momentum encoder 150 | imgs_m = torch.cat([img_x, img_u_w, img_u_s1], dim=0) 151 | imgs_m, idx_unshuffle = self._batch_shuffle_ddp(imgs_m) 152 | 153 | outputs_m, features_m = self.m_encoder(imgs_m) 154 | outputs_m = self._batch_unshuffle_ddp(outputs_m, idx_unshuffle) 155 | features_m = self._batch_unshuffle_ddp(features_m, idx_unshuffle) 156 | 157 | outputs_u_w = outputs_m[btx:btx+btu] 158 | 159 | feature_u_w = features_m[btx:btx+btu] 160 | feature_xu_w = features_m[:btx+btu] 161 | features_u_s1 = features_m[btx+btu:] 162 | 163 | outputs_u_w = outputs_u_w.detach() 164 | feature_u_w = feature_u_w.detach() 165 | feature_xu_w = feature_xu_w.detach() 166 | features_u_s1 = features_u_s1.detach() 167 | 168 | probs = torch.softmax(outputs_u_w, dim=1) 169 | 170 | # distribution alignment 171 | probs_bt_avg = probs.mean(0) 172 | torch.distributed.all_reduce(probs_bt_avg,async_op=False) 173 | self.hist_prob.append(probs_bt_avg/args.world_size) 174 | 175 | if len(self.hist_prob)>128: 176 | self.hist_prob.pop(0) 177 | 178 | probs_avg = torch.stack(self.hist_prob,dim=0).mean(0) 179 | probs = probs / probs_avg 180 | probs = probs / probs.sum(dim=1, keepdim=True) 181 | probs_orig = probs.clone() 182 | 183 | # memory-smoothed pseudo-label refinement (starting from 2nd epoch) 184 | if epoch>0: 185 | m_feat_xu = self.queue_w.clone().detach() 186 | m_probs_xu = self.probs_xu.clone().detach() 187 | A = torch.exp(torch.mm(feature_u_w, m_feat_xu)/args.temperature) 188 | A = A/A.sum(1,keepdim=True) 189 | probs = args.alpha*probs + (1-args.alpha)*torch.mm(A, m_probs_xu.t()) 190 | 191 | # construct pseudo-label graph 192 | 193 | # similarity with current batch 194 | Q_self = torch.mm(probs,probs.t()) 195 | Q_self.fill_diagonal_(1) 196 | 197 | # similarity with past samples 198 | m_probs_u = self.probs_u.clone().detach() 199 | Q_past = torch.mm(probs,m_probs_u) 200 | 201 | # concatenate them 202 | Q = torch.cat([Q_self,Q_past],dim=1) 203 | 204 | # construct embedding graph for strong augmentations 205 | sim_self = torch.exp(torch.mm(features_u_s0, features_u_s1.t())/args.temperature) 206 | m_feat = self.queue_s.clone().detach() 207 | sim_past = torch.exp(torch.mm(features_u_s0, 
m_feat)/args.temperature) 208 | sim = torch.cat([sim_self,sim_past],dim=1) 209 | 210 | # store strong augmentation features and probs (unlabeled) into momentum queue 211 | self._dequeue_and_enqueue(features_u_s1, probs, 's') 212 | 213 | # store weak augmentation features and probs (labeled and unlabeled) into memory bank 214 | onehot = torch.zeros(btx,args.num_class).cuda().scatter(1,labels_x.view(-1,1),1) 215 | probs_xu = torch.cat([onehot, probs_orig],dim=0) 216 | 217 | self._dequeue_and_enqueue(feature_xu_w, probs_xu, 'w') 218 | 219 | return outputs_x, outputs_u_s0, labels_x, probs, Q, sim 220 | 221 | 222 | 223 | 224 | @torch.no_grad() 225 | def concat_all_gather(tensor): 226 | """ 227 | Performs all_gather operation on the provided tensors. 228 | *** Warning ***: torch.distributed.all_gather has no gradient. 229 | """ 230 | tensors_gather = [torch.ones_like(tensor) 231 | for _ in range(torch.distributed.get_world_size())] 232 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 233 | 234 | output = torch.cat(tensors_gather, dim=0) 235 | return output 236 | 237 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | ### Semi-supervised learning on ImageNet with 1% or 10% labels: 2 | 3 | This implementation only supports multi-gpu, DistributedDataParallel training, which is faster and simpler. 4 | 5 | To train CoMatch with 1% labels on 8 gpus, run: 6 |
python Train_CoMatch.py --percent 1 --thr 0.6 --contrast-th 0.3 --lam-c 10 [Imagenet dataset folder]
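Here `[Imagenet dataset folder]` is assumed to follow the standard torchvision `ImageFolder` layout (inferred from `Train_CoMatch.py`, which reads the `train` and `val` subdirectories), e.g.:

    [Imagenet dataset folder]/train/n01440764/xxx.JPEG
    [Imagenet dataset folder]/val/n01440764/yyy.JPEG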
7 | 8 | To train CoMatch with 10% labels on 8 gpus, run: 9 |
python Train_CoMatch.py --percent 10 --thr 0.5 --contrast-th 0.2 --lam-c 2 [Imagenet dataset folder]
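Note that `--batch-size` (default 160) and `--batch-size-u` (default 640) are global batch sizes: `main_worker` divides them by the number of GPUs, so on 8 gpus each process sees 160 / 8 = 20 labeled and 640 / 8 = 80 unlabeled images per iteration.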
10 | 11 | ### Semi-supervised learning results with CoMatch 12 | 13 | num. labels | top-1 acc. | top-5 acc. 14 | --- | --- | --- 15 | 1% | 66.0% | 86.4% 16 | 10% | 73.6% | 91.6% 17 | 18 | ### Download pre-trained ResNet-50 models 19 | 20 | num. labels | model 21 | ------ | ------ 22 | 1% | download 23 | 10% | download 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /imagenet/Train_CoMatch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2018, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | ''' 7 | import argparse 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | import warnings 13 | import builtins 14 | import json 15 | from datetime import datetime 16 | import sys 17 | import math 18 | import numpy as np 19 | 20 | import tensorboard_logger 21 | import logging 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.parallel 26 | import torch.backends.cudnn as cudnn 27 | import torch.distributed as dist 28 | import torch.optim 29 | import torch.multiprocessing as mp 30 | import torch.utils.data 31 | import torch.utils.data.distributed 32 | import torchvision.transforms as transforms 33 | import torchvision.datasets as datasets 34 | import torch.nn.functional as F 35 | 36 | import loader 37 | from Model import Model 38 | from resnet import * 39 | 40 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 41 | parser.add_argument('data', metavar='DIR', default='', help='path to dataset') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=['resnet50','resnet50x2','resnet50x4']) 43 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 44 | help='number of data loading workers (default: 32)') 45 | parser.add_argument('--epochs', default=400, type=int, metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 48 | help='manual epoch number (useful on restarts)') 49 | parser.add_argument('--batch-size', default=160, type=int, 50 | help='supervised batch size') 51 | parser.add_argument('--batch-size-u', default=640, type=int, 52 | help='unsupervised batch size') 53 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 54 | metavar='LR', help='initial learning rate', dest='lr') 55 | parser.add_argument('--cos', default=True, help='use cosine lr schedule') 56 | parser.add_argument('--schedule', default=[], nargs='*', type=int, 57 | help='learning rate schedule (when to drop lr by a ratio)') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)', 62 | dest='weight_decay') 63 | parser.add_argument('-p', '--print-freq', default=50, type=int, 64 | metavar='N', help='print frequency (default: 50)') 65 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 66 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH', help='path to pretrained model (default: none)') 67 | parser.add_argument('--world-size', default=1, type=int, 68 | help='number of nodes for
distributed training') 69 | parser.add_argument('--rank', default=0, type=int, 70 | help='node rank for distributed training') 71 | parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str, 72 | help='url used to set up distributed training') 73 | parser.add_argument('--dist-backend', default='nccl', type=str, 74 | help='distributed backend') 75 | parser.add_argument('--seed', default=0, type=int, 76 | help='seed for initializing training. ') 77 | parser.add_argument('--gpu', default=None, type=int, 78 | help='GPU id to use.') 79 | parser.add_argument('--multiprocessing-distributed', default=True) 80 | 81 | ## CoMatch settings 82 | parser.add_argument('--temperature', default=0.1, type=float, help='temperature for similarity scaling') 83 | parser.add_argument('--low-dim', default=128, type=int, help='feature dimension') 84 | parser.add_argument('--moco-m', default=0.996, type=float, 85 | help='momentum of updating momentum encoder') 86 | parser.add_argument('--K', default=30000, type=int, help='size of memory bank and momentum queue') 87 | parser.add_argument('--thr', default=0.6, type=float, help='pseudo-label confidence threshold') 88 | parser.add_argument('--contrast-th', default=0.3, type=float, help='pseudo-label graph connection threshold') 89 | parser.add_argument('--lam-u', default=10, type=float, help='weight for unsupervised cross-entropy loss') 90 | parser.add_argument('--lam-c', default=10, type=float, help='weight for unsupervised contrastive loss') 91 | parser.add_argument('--alpha', default=0.9, type=float, help='weight for model prediction in constructing pseudo-label') 92 | parser.add_argument('--exp_dir', default='experiment/comatch_1percent', type=str, help='experiment directory') 93 | 94 | ## dataset settings 95 | parser.add_argument('--percent', type=int, default=1, choices=[1,10], help='percentage of labeled samples') 96 | parser.add_argument('--num-class', default=1000, type=int) 97 | parser.add_argument('--annotation', default='annotation_1percent.json', type=str, help='annotation file') 98 | 99 | 100 | def main(): 101 | args = parser.parse_args() 102 | 103 | if args.seed is not None: 104 | random.seed(args.seed) 105 | torch.manual_seed(args.seed) 106 | 107 | if args.dist_url == "env://" and args.world_size == -1: 108 | args.world_size = int(os.environ["WORLD_SIZE"]) 109 | 110 | os.makedirs(args.exp_dir, exist_ok=True) 111 | 112 | ngpus_per_node = torch.cuda.device_count() 113 | if args.multiprocessing_distributed: 114 | args.world_size = ngpus_per_node * args.world_size 115 | # Use torch.multiprocessing.spawn to launch distributed processes: the 116 | # main_worker process function 117 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 118 | else: 119 | # Simply call main_worker function 120 | main_worker(args.gpu, ngpus_per_node, args) 121 | 122 | 123 | def main_worker(gpu, ngpus_per_node, args): 124 | args.gpu = gpu 125 | 126 | if args.gpu is not None: 127 | print("Use GPU: {} for training".format(args.gpu)) 128 | 129 | if args.multiprocessing_distributed and args.gpu != 0: 130 | def print_pass(*args): 131 | pass 132 | builtins.print = print_pass 133 | 134 | if args.dist_url == "env://" and args.rank == -1: 135 | args.rank = int(os.environ["RANK"]) 136 | if args.multiprocessing_distributed: 137 | # For multiprocessing distributed training, rank needs to be the 138 | # global rank among all the processes 139 | args.rank = args.rank * ngpus_per_node + gpu 140 | dist.init_process_group(backend=args.dist_backend, 
init_method=args.dist_url, 141 | world_size=args.world_size, rank=args.rank) 142 | 143 | # create model 144 | print("=> creating model '{}'".format(args.arch)) 145 | if args.arch == 'resnet50': 146 | model = Model(resnet50,args,width=1) 147 | elif args.arch == 'resnet50x2': 148 | model = Model(resnet50,args,width=2) 149 | elif args.arch == 'resnet50x4': 150 | model = Model(resnet50,args,width=4) 151 | else: 152 | raise NotImplementedError('model not supported {}'.format(args.arch)) 153 | 154 | # load moco-v2 pretrained model 155 | if args.pretrained: 156 | if os.path.isfile(args.pretrained): 157 | print("=> loading checkpoint '{}'".format(args.pretrained)) 158 | checkpoint = torch.load(args.pretrained, map_location="cpu") 159 | state_dict = checkpoint['state_dict'] 160 | for k in list(state_dict.keys()): 161 | if k.startswith('module.encoder_q'): 162 | # remove prefix 163 | state_dict[k.replace('module.encoder_q', 'encoder')] = state_dict[k] 164 | # delete renamed or unused k 165 | del state_dict[k] 166 | for k in list(state_dict.keys()): 167 | if 'fc.0' in k: 168 | state_dict[k.replace('fc.0','fc1')] = state_dict[k] 169 | if 'fc.2' in k: 170 | state_dict[k.replace('fc.2','fc2')] = state_dict[k] 171 | del state_dict[k] 172 | args.start_epoch = 0 173 | msg = model.load_state_dict(state_dict, strict=False) 174 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 175 | # copy parameters to the momentum encoder 176 | for param, param_m in zip(model.encoder.parameters(), model.m_encoder.parameters()): 177 | param_m.data.copy_(param.data) 178 | param_m.requires_grad = False 179 | else: 180 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 181 | 182 | 183 | if args.gpu is not None: 184 | torch.cuda.set_device(args.gpu) 185 | model.cuda(args.gpu) 186 | # When using a single GPU per process and per 187 | # DistributedDataParallel, we need to divide the batch size 188 | # ourselves based on the total number of GPUs we have 189 | args.batch_size = int(args.batch_size / ngpus_per_node) 190 | args.batch_size_u = int(args.batch_size_u / ngpus_per_node) 191 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) #find_unused_parameters=True 193 | else: 194 | model.cuda() 195 | # DistributedDataParallel will divide and allocate batch_size to all 196 | # available GPUs if device_ids are not set 197 | model = torch.nn.parallel.DistributedDataParallel(model) 198 | 199 | # define loss function (criterion) and optimizer 200 | criteria_x = nn.CrossEntropyLoss().cuda(args.gpu) 201 | 202 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 203 | momentum=args.momentum, 204 | weight_decay=args.weight_decay, 205 | nesterov=True 206 | ) 207 | 208 | # optionally resume from a checkpoint 209 | if args.resume: 210 | if os.path.isfile(args.resume): 211 | print("=> loading checkpoint '{}'".format(args.resume)) 212 | if args.gpu is None: 213 | checkpoint = torch.load(args.resume) 214 | else: 215 | # Map model to be loaded to specified single gpu.
216 | loc = 'cuda:{}'.format(args.gpu) 217 | checkpoint = torch.load(args.resume, map_location=loc) 218 | args.start_epoch = checkpoint['epoch'] 219 | model.load_state_dict(checkpoint['state_dict']) 220 | optimizer.load_state_dict(checkpoint['optimizer']) 221 | print("=> loaded checkpoint '{}' (epoch {})" 222 | .format(args.resume, checkpoint['epoch'])) 223 | else: 224 | print("=> no checkpoint found at '{}'".format(args.resume)) 225 | 226 | cudnn.benchmark = True 227 | 228 | print("=> preparing dataset") 229 | # Data loading code 230 | 231 | transform_strong = transforms.Compose([ 232 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 233 | transforms.RandomApply([ 234 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 235 | ], p=0.8), 236 | transforms.RandomGrayscale(p=0.2), 237 | transforms.RandomHorizontalFlip(), 238 | transforms.ToTensor(), 239 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 240 | ]) 241 | transform_weak = transforms.Compose([ 242 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 243 | transforms.RandomHorizontalFlip(), 244 | transforms.ToTensor(), 245 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 246 | ]) 247 | transform_eval = transforms.Compose([ 248 | transforms.Resize(256), 249 | transforms.CenterCrop(224), 250 | transforms.ToTensor(), 251 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 252 | ]) 253 | 254 | three_crops_transform = loader.ThreeCropsTransform(transform_weak, transform_strong, transform_strong) 255 | 256 | traindir = os.path.join(args.data, 'train') 257 | valdir = os.path.join(args.data, 'val') 258 | 259 | labeled_dataset = datasets.ImageFolder(traindir,transform_weak) 260 | unlabeled_dataset = datasets.ImageFolder(traindir,three_crops_transform) 261 | 262 | if not os.path.exists(args.annotation): 263 | # randomly sample labeled data on main process (gpu0) 264 | label_per_class = 13 if args.percent==1 else 128 265 | if args.gpu==0: 266 | random.shuffle(labeled_dataset.samples) 267 | labeled_samples=[] 268 | unlabeled_samples=[] 269 | num_img = torch.zeros(args.num_class) 270 | 271 | for i,(img,label) in enumerate(labeled_dataset.samples): 272 | if num_img[label] < label_per_class: 392 | pos_mask = (Q >= args.contrast_th) 393 | Q_mask = Q * pos_mask 394 | Q_mask = Q_mask / Q_mask.sum(1,keepdim=True) 395 | 396 | positives = sim * pos_mask 397 | pos_probs = positives / sim.sum(1, keepdim=True) 398 | log_probs = torch.log(pos_probs + 1e-7) * pos_mask 399 | 400 | # unsupervised contrastive loss 401 | loss_contrast = - (log_probs*Q_mask).sum(1) 402 | loss_contrast = loss_contrast.mean() 403 | 404 | # ramp up the weight for unsupervised contrastive loss (optional) 405 | lam_c = min(epoch+1, args.lam_c) 406 | loss = loss_x + args.lam_u * loss_u + lam_c * loss_contrast 407 | 408 | # compute gradient and do SGD step 409 | optimizer.zero_grad() 410 | loss.backward() 411 | optimizer.step() 412 | 413 | loss_x_meter.update(loss_x.item()) 414 | loss_u_meter.update(loss_u.item()) 415 | loss_contrast_meter.update(loss_contrast.item()) 416 | pos_meter.update(pos_mask.sum(1).float().mean().item()) 417 | 418 | corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask 419 | n_correct_u_lbs_meter.update(corr_u_lb.sum().item()) 420 | n_conf.update(mask.sum().item()) 421 | 422 | # measure elapsed time 423 | batch_time.update(time.time() - end) 424 | end = time.time() 425 | 426 | if i % args.print_freq == 0 and args.gpu==0: 427 | lr_log = [pg['lr'] for pg in optimizer.param_groups] 428 | lr_log = sum(lr_log) / len(lr_log) 429 |
430 | logger.info("{} || epoch:{}, iter: {}. loss_u: {:.3f}. loss_x: {:.3f}. loss_c: {:.3f}. " 431 | "n_correct_u: {:.2f}/{:.2f}. n_edge: {:.3f}. LR: {:.2f}. " 432 | "batch_time: {:.2f}. data_time: {:.2f}.".format( 433 | args.exp_dir, epoch, i + 1, loss_u_meter.avg, loss_x_meter.avg, loss_contrast_meter.avg, 434 | n_correct_u_lbs_meter.avg, n_conf.avg, pos_meter.avg, lr_log, batch_time.avg, data_time.avg)) 435 | 436 | if args.gpu==0: 437 | tb_logger.log_value('loss_x', loss_x_meter.avg, epoch) 438 | tb_logger.log_value('loss_u', loss_u_meter.avg, epoch) 439 | tb_logger.log_value('loss_c', loss_contrast_meter.avg, epoch) 440 | tb_logger.log_value('num_conf', n_conf.avg, epoch) 441 | tb_logger.log_value('guess_label_acc', n_correct_u_lbs_meter.avg/n_conf.avg, epoch) 442 | 443 | 444 | 445 | def validate(val_loader, model, args, logger, tb_logger, epoch): 446 | 447 | top1 = AverageMeter() 448 | top5 = AverageMeter() 449 | 450 | # switch to evaluate mode 451 | model.eval() 452 | 453 | with torch.no_grad(): 454 | end = time.time() 455 | for i, batch in enumerate(val_loader): 456 | # compute output 457 | output,target = model(args, batch, is_eval=True) 458 | 459 | # measure accuracy 460 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 461 | top1.update(acc1[0]) 462 | top5.update(acc5[0]) 463 | 464 | if i % args.print_freq == 0 and args.gpu==0: 465 | logger.info("validation ||epoch:{}, iter: {}. acc1 : {:.2f}. acc5 : {:.2f}.".format( 466 | epoch, i + 1, top1.avg, top5.avg)) 467 | 468 | if args.gpu==0: 469 | logger.info("validation ||epoch:{}, acc1 : {:.2f}. acc5 : {:.2f}.".format(epoch, top1.avg, top5.avg)) 470 | tb_logger.log_value('test_acc', top1.avg, epoch) 471 | tb_logger.log_value('test_acc5', top5.avg, epoch) 472 | torch.cuda.empty_cache() 473 | return top1.avg 474 | 475 | 476 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 477 | torch.save(state, filename) 478 | 479 | 480 | class AverageMeter(object): 481 | """ 482 | Computes and stores the average and current value 483 | 484 | """ 485 | def __init__(self): 486 | self.reset() 487 | 488 | def reset(self): 489 | self.val = 0 490 | self.avg = 0 491 | self.sum = 0 492 | self.count = 0 493 | 494 | def update(self, val, n=1): 495 | self.val = val 496 | self.sum += val * n 497 | self.count += n 498 | self.avg = self.sum / self.count 499 | 500 | 501 | def setup_default_logging(args, default_level=logging.INFO, 502 | format="%(asctime)s - %(levelname)s - %(message)s"): 503 | 504 | logger = logging.getLogger('') 505 | 506 | logging.basicConfig( # unlike the root logger, a custom logger can’t be configured using basicConfig() 507 | filename=os.path.join(args.exp_dir, f'{time_str()}.log'), 508 | format=format, 509 | datefmt="%m/%d/%Y %H:%M:%S", 510 | level=default_level) 511 | 512 | console_handler = logging.StreamHandler(sys.stdout) 513 | console_handler.setLevel(default_level) 514 | console_handler.setFormatter(logging.Formatter(format)) 515 | logger.addHandler(console_handler) 516 | 517 | return logger 518 | 519 | def time_str(fmt=None): 520 | if fmt is None: 521 | fmt = '%Y-%m-%d_%H:%M:%S' 522 | 523 | # time.strftime(format[, t]) 524 | return datetime.today().strftime(fmt) 525 | 526 | def adjust_learning_rate(optimizer, epoch, args): 527 | """Decay the learning rate based on schedule""" 528 | lr = args.lr 529 | if args.cos: # cosine lr schedule 530 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 531 | else: # stepwise lr schedule 532 | for milestone in args.schedule: 533 | lr *= 0.1 if epoch >= milestone else 1. 
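# With the default cosine schedule (args.cos is truthy), the factor
# 0.5 * (1 + cos(pi * epoch / epochs)) decays lr smoothly toward 0;
# e.g. for lr=0.1, epochs=400: epoch 0 -> 0.100, epoch 200 -> 0.050,
# epoch 399 -> ~1.5e-6.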
534 | for param_group in optimizer.param_groups: 535 | param_group['lr'] = lr 536 | 537 | def accuracy(output, target, topk=(1,)): 538 | """Computes the accuracy over the k top predictions for the specified values of k""" 539 | with torch.no_grad(): 540 | maxk = max(topk) 541 | batch_size = target.size(0) 542 | 543 | _, pred = output.topk(maxk, 1, True, True) 544 | pred = pred.t() 545 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 546 | 547 | res = [] 548 | for k in topk: 549 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) # contiguous() before view(): the transposed slice is not contiguous (matches utils.py) 550 | res.append(correct_k.mul_(100.0 / batch_size)) 551 | return res 552 | 553 | if __name__ == '__main__': 554 | main() 555 | 556 | 557 | -------------------------------------------------------------------------------- /imagenet/loader.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | class ThreeCropsTransform: 4 | """Take 3 random augmentations of one image.""" 5 | 6 | def __init__(self,trans_weak,trans_strong0,trans_strong1): 7 | self.trans_weak = trans_weak 8 | self.trans_strong0 = trans_strong0 9 | self.trans_strong1 = trans_strong1 10 | def __call__(self, x): 11 | x1 = self.trans_weak(x) 12 | x2 = self.trans_strong0(x) 13 | x3 = self.trans_strong1(x) 14 | return [x1, x2, x3] -------------------------------------------------------------------------------- /imagenet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.utils.model_zoo as model_zoo 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class Normalize(nn.Module): 27 | 28 | def __init__(self, power=2): 29 | super(Normalize, self).__init__() 30 | self.power = power 31 | 32 | def forward(self, x): 33 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power) 34 | out = x.div(norm) 35 | return out 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = conv3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(planes) 77 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 78 | padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 81 | self.bn3 = nn.BatchNorm2d(planes * 4) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, block, layers, mlp=False, low_dim=128, in_channel=3, width=1, num_class=1000): 112 | self.inplanes = 64 113 | super(ResNet, self).__init__() 114 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 115 | bias=False) 116 | self.bn1 = nn.BatchNorm2d(64) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.base = int(64 * width) 120 | 121 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 122 | self.layer1 = self._make_layer(block, self.base, layers[0]) 123 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 126 | self.avgpool = nn.AvgPool2d(7, stride=1) 127 | 128 | self.classifier = nn.Linear(self.base * 8 * block.expansion, num_class) 129 | self.l2norm = Normalize(2) 130 | self.mlp = mlp 131 | if self.mlp: #use an extra projection layer 132 | self.fc1 = nn.Linear(self.base * 8 * block.expansion, 2048) 133 | self.fc2 = nn.Linear(2048, low_dim) 134 | else: 135 | self.fc = nn.Linear(self.base * 8 * block.expansion, low_dim) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 141 | elif isinstance(m, nn.BatchNorm2d): 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | 145 | def _make_layer(self, block, planes, blocks, stride=1): 146 | downsample = None 147 | if stride != 1 or self.inplanes != planes * block.expansion: 148 | downsample = nn.Sequential( 149 | nn.Conv2d(self.inplanes, planes * block.expansion, 150 | kernel_size=1, stride=stride, bias=False), 151 | nn.BatchNorm2d(planes * block.expansion), 152 | ) 153 | 154 | layers = [] 155 | layers.append(block(self.inplanes, planes, stride, downsample)) 156 | self.inplanes = planes * block.expansion 157 | for i in range(1, blocks): 158 | layers.append(block(self.inplanes, planes)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def forward(self, x): 163 | 164 | x = self.conv1(x) 165 | x = self.bn1(x) 166 | x = self.relu(x) 167 | x = self.maxpool(x) 168 | 169 | x = self.layer1(x) 170 | x = self.layer2(x) 171 | x = self.layer3(x) 172 | x = self.layer4(x) 173 | 174 | x = self.avgpool(x) 175 | feat = x.view(x.size(0), -1) 176 | 177 | out = self.classifier(feat) 178 | 179 | if self.mlp: 180 | feat = F.relu(self.fc1(feat)) 181 | feat = self.fc2(feat) 182 | else: 183 | feat = self.fc(feat) 184 | feat = self.l2norm(feat) 185 | return out,feat 186 | 187 | 188 | def resnet18(pretrained=False, **kwargs): 189 | """Constructs a ResNet-18 model. 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 196 | return model 197 | 198 | 199 | def resnet34(pretrained=False, **kwargs): 200 | """Constructs a ResNet-34 model. 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | """ 204 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 205 | if pretrained: 206 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 207 | return model 208 | 209 | 210 | def resnet50(pretrained=False, **kwargs): 211 | """Constructs a ResNet-50 model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']),strict=False) 218 | return model 219 | 220 | 221 | def resnet101(pretrained=False, **kwargs): 222 | """Constructs a ResNet-101 model. 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | """ 226 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 227 | if pretrained: 228 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 229 | return model 230 | 231 | 232 | def resnet152(pretrained=False, **kwargs): 233 | """Constructs a ResNet-152 model. 
234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | """ 237 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 238 | if pretrained: 239 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 240 | return model 241 | 242 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import logging 3 | import os 4 | import sys 5 | import torch 6 | import math 7 | from torch.optim.lr_scheduler import _LRScheduler, LambdaLR 8 | import numpy as np 9 | 10 | 11 | def setup_default_logging(args, default_level=logging.INFO, 12 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"): 13 | 14 | if 'CIFAR' in args.dataset: 15 | output_dir = os.path.join(args.dataset, f'x{args.n_labeled}_seed{args.seed}', args.exp_dir) 16 | else: 17 | output_dir = os.path.join(args.dataset, f'f{args.folds}', args.exp_dir) 18 | 19 | os.makedirs(output_dir, exist_ok=True) 20 | 21 | logger = logging.getLogger('train') 22 | 23 | logging.basicConfig( # unlike the root logger, a custom logger can’t be configured using basicConfig() 24 | filename=os.path.join(output_dir, f'{time_str()}.log'), 25 | format=format, 26 | datefmt="%m/%d/%Y %H:%M:%S", 27 | level=default_level) 28 | 29 | # print 30 | # file_handler = logging.FileHandler() 31 | console_handler = logging.StreamHandler(sys.stdout) 32 | console_handler.setLevel(default_level) 33 | console_handler.setFormatter(logging.Formatter(format)) 34 | logger.addHandler(console_handler) 35 | 36 | return logger, output_dir 37 | 38 | 39 | def accuracy(output, target, topk=(1,)): 40 | """Computes the precision@k for the specified values of k""" 41 | maxk = max(topk) 42 | batch_size = target.size(0) 43 | 44 | _, pred = output.topk(maxk, 1, largest=True, sorted=True) # return value, indices 45 | pred = pred.t() 46 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in topk: 50 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | return res 53 | 54 | 55 | class AverageMeter(object): 56 | """ 57 | Computes and stores the average and current value 58 | """ 59 | 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.count = 0 68 | 69 | def update(self, val, n=1): 70 | self.val = val 71 | self.sum += val * n 72 | self.count += n 73 | # self.avg = self.sum / (self.count + 1e-20) 74 | self.avg = self.sum / self.count 75 | 76 | 77 | def time_str(fmt=None): 78 | if fmt is None: 79 | fmt = '%Y-%m-%d_%H:%M:%S' 80 | 81 | # time.strftime(format[, t]) 82 | return datetime.today().strftime(fmt) 83 | 84 | 85 | 86 | 87 | class WarmupCosineLrScheduler(_LRScheduler): 88 | 89 | def __init__( 90 | self, 91 | optimizer, 92 | max_iter, 93 | warmup_iter, 94 | warmup_ratio=5e-4, 95 | warmup='exp', 96 | last_epoch=-1, 97 | ): 98 | self.max_iter = max_iter 99 | self.warmup_iter = warmup_iter 100 | self.warmup_ratio = warmup_ratio 101 | self.warmup = warmup 102 | super(WarmupCosineLrScheduler, self).__init__(optimizer, last_epoch) 103 | 104 | def get_lr(self): 105 | ratio = self.get_lr_ratio() 106 | lrs = [ratio * lr for lr in self.base_lrs] 107 | return lrs 108 | 109 | def get_lr_ratio(self): 110 | if self.last_epoch < self.warmup_iter: 111 | ratio = self.get_warmup_ratio() 112 | else: 113 | real_iter = self.last_epoch - 
self.warmup_iter 114 | real_max_iter = self.max_iter - self.warmup_iter 115 | ratio = np.cos((7 * np.pi * real_iter) / (16 * real_max_iter)) 116 | #ratio = 0.5 * (1. + np.cos(np.pi * real_iter / real_max_iter)) 117 | return ratio 118 | 119 | def get_warmup_ratio(self): 120 | assert self.warmup in ('linear', 'exp') 121 | alpha = self.last_epoch / self.warmup_iter 122 | if self.warmup == 'linear': 123 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 124 | elif self.warmup == 'exp': 125 | ratio = self.warmup_ratio ** (1. - alpha) 126 | return ratio --------------------------------------------------------------------------------
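A minimal usage sketch for `WarmupCosineLrScheduler` (illustrative only; the linear model, learning rate, and iteration counts below are placeholder assumptions, not values from this repo). The scheduler is stepped once per iteration, so `max_iter` and `warmup_iter` count iterations, not epochs:

    import torch
    from utils import WarmupCosineLrScheduler

    model = torch.nn.Linear(10, 2)  # placeholder model
    optim = torch.optim.SGD(model.parameters(), lr=0.03)
    sched = WarmupCosineLrScheduler(
        optim, max_iter=2048, warmup_iter=256, warmup_ratio=5e-4, warmup='exp')

    for it in range(2048):
        optim.zero_grad()
        loss = model(torch.randn(4, 10)).sum()  # dummy forward/backward pass
        loss.backward()
        optim.step()
        sched.step()  # exponential ramp-up for 256 iters, then cosine decay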