├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── Train_CoMatch.py
├── Train_fixmatch.py
├── WideResNet.py
├── comatch.gif
├── datasets
│   ├── __init__.py
│   ├── cifar.py
│   ├── randaugment.py
│   ├── sampler.py
│   └── transform.py
├── imagenet
│   ├── Model.py
│   ├── README.md
│   ├── Train_CoMatch.py
│   ├── loader.py
│   └── resnet.py
└── utils.py
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 |
3 | ## About the Code of Conduct
4 |
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 |
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 |
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 |
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 |
24 | ## Our Pledge
25 |
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 |
33 | ## Our Standards
34 |
35 | Examples of behavior that contributes to creating a positive environment
36 | include:
37 |
38 | * Using welcoming and inclusive language
39 | * Being respectful of differing viewpoints and experiences
40 | * Gracefully accepting constructive criticism
41 | * Focusing on what is best for the community
42 | * Showing empathy toward other community members
43 |
44 | Examples of unacceptable behavior by participants include:
45 |
46 | * The use of sexualized language or imagery and unwelcome sexual attention or
47 | advances
48 | * Personal attacks, insulting/derogatory comments, or trolling
49 | * Public or private harassment
50 | * Publishing, or threatening to publish, others' private information—such as
51 | a physical or electronic address—without explicit permission
52 | * Other conduct which could reasonably be considered inappropriate in a
53 | professional setting
54 | * Advocating for or encouraging any of the above behaviors
55 |
56 | ## Our Responsibilities
57 |
58 | Project maintainers are responsible for clarifying the standards of acceptable
59 | behavior and are expected to take appropriate and fair corrective action in
60 | response to any instances of unacceptable behavior.
61 |
62 | Project maintainers have the right and responsibility to remove, edit, or
63 | reject comments, commits, code, wiki edits, issues, and other contributions
64 | that are not aligned with this Code of Conduct, or to ban temporarily or
65 | permanently any contributor for other behaviors that they deem inappropriate,
66 | threatening, offensive, or harmful.
67 |
68 | ## Scope
69 |
70 | This Code of Conduct applies both within project spaces and in public spaces
71 | when an individual is representing the project or its community. Examples of
72 | representing a project or community include using an official project email
73 | address, posting via an official social media account, or acting as an appointed
74 | representative at an online or offline event. Representation of a project may be
75 | further defined and clarified by project maintainers.
76 |
77 | ## Enforcement
78 |
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 |
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 |
92 | ## Attribution
93 |
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 |
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 |
101 | [contributor-covenant-home]: https://www.contributor-covenant.org
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
106 |
--------------------------------------------------------------------------------
/CONTRIBUTING-ARCHIVED.md:
--------------------------------------------------------------------------------
1 | # ARCHIVED
2 |
3 | This project is `Archived` and is no longer actively maintained.
4 | We are not accepting contributions or Pull Requests.
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CoMatch: Semi-supervised Learning with Contrastive Graph Regularization, ICCV 2021 (Salesforce Research).
2 |
3 |
4 | This is a PyTorch implementation of the CoMatch paper [Blog]:
5 |
6 | @inproceedings{CoMatch,
7 |   title={Semi-supervised Learning with Contrastive Graph Regularization},
8 |   author={Junnan Li and Caiming Xiong and Steven C.H. Hoi},
9 |   booktitle={ICCV},
10 |   year={2021}
11 | }
12 |
13 | ### Requirements:
14 | * PyTorch ≥ 1.4
15 | * pip install tensorboard_logger
16 | * Download and extract the CIFAR-10 dataset into ./data/
17 |
18 | To perform semi-supervised learning on CIFAR-10 with 4 labels per class, run:
19 | python Train_CoMatch.py --n-labeled 40 --seed 1
20 |
21 | The results using different random seeds are:
22 |
23 | seed| 1 | 2 | 3 | 4 | 5 | avg
24 | --- | --- | --- | --- | --- | --- | ---
25 | accuracy (%)|93.71|94.10|92.93|90.73|93.97|93.09
26 |
27 | ### ImageNet
28 | For ImageNet experiments, see ./imagenet/
29 |
30 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
8 |
--------------------------------------------------------------------------------
/Train_CoMatch.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2018, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | '''
7 | from __future__ import print_function
8 | import random
9 |
10 | import time
11 | import argparse
12 | import os
13 | import sys
14 |
15 | import numpy as np
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from WideResNet import WideResnet
22 | from datasets.cifar import get_train_loader, get_val_loader
23 | from utils import accuracy, setup_default_logging, AverageMeter, WarmupCosineLrScheduler
24 |
25 | import tensorboard_logger
26 |
27 | def set_model(args):
28 |     model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=True)
29 |     if args.checkpoint:
30 |         checkpoint = torch.load(args.checkpoint)
31 |         msg = model.load_state_dict(checkpoint, strict=False)
32 |         assert set(msg.missing_keys) == {"classifier.weight", "classifier.bias"}
33 |         print('loaded from checkpoint: %s'%args.checkpoint)
34 |     model.train()
35 |     model.cuda()
36 |
37 |     if args.eval_ema:
38 |         ema_model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=True)
39 |         for param_q, param_k in zip(model.parameters(), ema_model.parameters()):
40 |             param_k.data.copy_(param_q.detach().data)  # initialize
41 |             param_k.requires_grad = False  # not updated by gradients; only by ema_model_update
42 |         ema_model.cuda()
43 |         ema_model.eval()
44 |     else:
45 |         ema_model = None
46 |
47 |     criteria_x = nn.CrossEntropyLoss().cuda()
48 |     return model, criteria_x, ema_model
49 |
50 | @torch.no_grad()
51 | def ema_model_update(model, ema_model, ema_m):
52 |     """
53 |     Momentum update of evaluation model (exponential moving average)
54 |     """
55 |     for param_train, param_eval in zip(model.parameters(), ema_model.parameters()):
56 |         param_eval.copy_(param_eval * ema_m + param_train.detach() * (1-ema_m))
57 |
58 |     for buffer_train, buffer_eval in zip(model.buffers(), ema_model.buffers()):
59 |         buffer_eval.copy_(buffer_train)
60 |
61 | def train_one_epoch(epoch,
62 |                     model,
63 |                     ema_model,
64 |                     prob_list,
65 |                     criteria_x,
66 |                     optim,
67 |                     lr_schdlr,
68 |                     dltrain_x,
69 |                     dltrain_u,
70 |                     args,
71 |                     n_iters,
72 |                     logger,
73 |                     queue_feats,
74 |                     queue_probs,
75 |                     queue_ptr,
76 |                     ):
77 |
78 |     model.train()
79 |     loss_x_meter = AverageMeter()
80 |     loss_u_meter = AverageMeter()
81 |     loss_contrast_meter = AverageMeter()
82 |     # the number of correct pseudo-labels
83 |     n_correct_u_lbs_meter = AverageMeter()
84 |     # the number of confident unlabeled data
85 |     n_strong_aug_meter = AverageMeter()
86 |     mask_meter = AverageMeter()
87 |     # the number of edges in the pseudo-label graph
88 |     pos_meter = AverageMeter()
89 |
90 |     epoch_start = time.time()  # start time
91 |     dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
92 |     for it in range(n_iters):
93 |         ims_x_weak, lbs_x = next(dl_x)
94 |         (ims_u_weak, ims_u_strong0, ims_u_strong1), lbs_u_real = next(dl_u)
95 |
96 |         lbs_x = lbs_x.cuda()
97 |         lbs_u_real = lbs_u_real.cuda()
98 |
99 |         # --------------------------------------
100 |         bt = ims_x_weak.size(0)
101 |         btu = ims_u_weak.size(0)
102 |
103 |         imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong0, ims_u_strong1], dim=0).cuda()
104 |         logits, features = model(imgs)
105 |
106 |         logits_x = logits[:bt]
107 |         logits_u_w, logits_u_s0, logits_u_s1 = torch.split(logits[bt:], btu)
108 |
109 |         feats_x = features[:bt]
110 |         feats_u_w, feats_u_s0, feats_u_s1 = torch.split(features[bt:], btu)
111 |
112 |         loss_x = criteria_x(logits_x, lbs_x)
113 |
114 |         with torch.no_grad():
115 |             logits_u_w = logits_u_w.detach()
116 |             feats_x = feats_x.detach()
117 |             feats_u_w = feats_u_w.detach()
118 |
119 |             probs = torch.softmax(logits_u_w, dim=1)
120 |             # distribution alignment (DA)
121 |             prob_list.append(probs.mean(0))
122 |             if len(prob_list)>32:
123 |                 prob_list.pop(0)
124 |             prob_avg = torch.stack(prob_list,dim=0).mean(0)
125 |             probs = probs / prob_avg
126 |             probs = probs / probs.sum(dim=1, keepdim=True)
127 |
128 |             probs_orig = probs.clone()
129 |
130 |             if epoch>0 or it>args.queue_batch:  # memory-smoothing
131 |                 A = torch.exp(torch.mm(feats_u_w, queue_feats.t())/args.temperature)
132 |                 A = A/A.sum(1,keepdim=True)
133 |                 probs = args.alpha*probs + (1-args.alpha)*torch.mm(A, queue_probs)
134 |
135 |             scores, lbs_u_guess = torch.max(probs, dim=1)
136 |             mask = scores.ge(args.thr).float()
137 |
138 |             feats_w = torch.cat([feats_u_w,feats_x],dim=0)
139 |             onehot = torch.zeros(bt,args.n_classes).cuda().scatter(1,lbs_x.view(-1,1),1)
140 |             probs_w = torch.cat([probs_orig,onehot],dim=0)
141 |
142 |             # update memory bank
143 |             n = bt+btu
144 |             queue_feats[queue_ptr:queue_ptr + n,:] = feats_w
145 |             queue_probs[queue_ptr:queue_ptr + n,:] = probs_w
146 |             queue_ptr = (queue_ptr+n)%args.queue_size
147 |
148 |
149 |         # embedding similarity
150 |         sim = torch.exp(torch.mm(feats_u_s0, feats_u_s1.t())/args.temperature)
151 |         sim_probs = sim / sim.sum(1, keepdim=True)
152 |
153 |         # pseudo-label graph with self-loop
154 |         Q = torch.mm(probs, probs.t())
155 |         Q.fill_diagonal_(1)
156 |         pos_mask = (Q>=args.contrast_th).float()
157 |
158 |         Q = Q * pos_mask
159 |         Q = Q / Q.sum(1, keepdim=True)
160 |
161 |         # contrastive loss
162 |         loss_contrast = - (torch.log(sim_probs + 1e-7) * Q).sum(1)
163 |         loss_contrast = loss_contrast.mean()
164 |
165 |         # unsupervised classification loss
166 |         loss_u = - torch.sum((F.log_softmax(logits_u_s0,dim=1) * probs),dim=1) * mask
167 |         loss_u = loss_u.mean()
168 |
169 |         loss = loss_x + args.lam_u * loss_u + args.lam_c * loss_contrast
170 |
171 |         optim.zero_grad()
172 |         loss.backward()
173 |         optim.step()
174 |         lr_schdlr.step()
175 |
176 |         if args.eval_ema:
177 |             with torch.no_grad():
178 |                 ema_model_update(model, ema_model, args.ema_m)
179 |
180 |         loss_x_meter.update(loss_x.item())
181 |         loss_u_meter.update(loss_u.item())
182 |         loss_contrast_meter.update(loss_contrast.item())
183 |         mask_meter.update(mask.mean().item())
184 |         pos_meter.update(pos_mask.sum(1).float().mean().item())
185 |
186 |         corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
187 |         n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
188 |         n_strong_aug_meter.update(mask.sum().item())
189 |
190 |         if (it + 1) % 64 == 0:
191 |             t = time.time() - epoch_start
192 |
193 |             lr_log = [pg['lr'] for pg in optim.param_groups]
194 |             lr_log = sum(lr_log) / len(lr_log)
195 |
196 |             logger.info("{}-x{}-s{}, {} | epoch:{}, iter: {}. loss_u: {:.3f}. loss_x: {:.3f}. loss_c: {:.3f}. "
197 |                         "n_correct_u: {:.2f}/{:.2f}. Mask:{:.3f}. num_pos: {:.1f}. LR: {:.3f}. Time: {:.2f}".format(
198 |                 args.dataset, args.n_labeled, args.seed, args.exp_dir, epoch, it + 1, loss_u_meter.avg, loss_x_meter.avg, loss_contrast_meter.avg, n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, pos_meter.avg, lr_log, t))
199 |
200 |             epoch_start = time.time()
201 |
202 |     return loss_x_meter.avg, loss_u_meter.avg, loss_contrast_meter.avg, mask_meter.avg, pos_meter.avg, n_correct_u_lbs_meter.avg/n_strong_aug_meter.avg, queue_feats, queue_probs, queue_ptr, prob_list
203 |
204 |
205 | def evaluate(model, ema_model, dataloader):
206 |
207 |     model.eval()
208 |
209 |     top1_meter = AverageMeter()
210 |     ema_top1_meter = AverageMeter()
211 |
212 |     with torch.no_grad():
213 |         for ims, lbs in dataloader:
214 |             ims = ims.cuda()
215 |             lbs = lbs.cuda()
216 |
217 |             logits, _ = model(ims)
218 |             scores = torch.softmax(logits, dim=1)
219 |             top1, top5 = accuracy(scores, lbs, (1, 5))
220 |             top1_meter.update(top1.item())
221 |
222 |             if ema_model is not None:
223 |                 logits, _ = ema_model(ims)
224 |                 scores = torch.softmax(logits, dim=1)
225 |                 top1, top5 = accuracy(scores, lbs, (1, 5))
226 |                 ema_top1_meter.update(top1.item())
227 |
228 |     return top1_meter.avg, ema_top1_meter.avg
229 |
230 |
231 | def main():
232 |     parser = argparse.ArgumentParser(description='CoMatch Cifar Training')
233 |     parser.add_argument('--root', default='./data', type=str, help='dataset directory')
234 |     parser.add_argument('--wresnet-k', default=2, type=int,
235 |                         help='width factor of wide resnet')
236 |     parser.add_argument('--wresnet-n', default=28, type=int,
237 |                         help='depth of wide resnet')
238 |     parser.add_argument('--dataset', type=str, default='CIFAR10',
239 |                         help='name of the dataset')
240 |     parser.add_argument('--n-classes', type=int, default=10,
241 |                         help='number of classes in dataset')
242 |     parser.add_argument('--n-labeled', type=int, default=40,
243 |                         help='number of labeled samples for training')
244 |     parser.add_argument('--n-epoches', type=int, default=512,
245 |                         help='number of training epochs')
246 |     parser.add_argument('--batchsize', type=int, default=64,
247 |                         help='train batch size of labeled samples')
248 |     parser.add_argument('--mu', type=int, default=7,
249 |                         help='factor of train batch size of unlabeled samples')
250 |     parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
251 |                         help='number of training images for each epoch')
252 |
253 |     parser.add_argument('--eval-ema', default=True, help='whether to use ema model for evaluation')
254 |     parser.add_argument('--ema-m', type=float, default=0.999)
255 |
256 |     parser.add_argument('--lam-u', type=float, default=1.,
257 |                         help='coefficient of unlabeled loss')
258 |     parser.add_argument('--lr', type=float, default=0.03,
259 |                         help='learning rate for training')
260 |     parser.add_argument('--weight-decay', type=float, default=5e-4,
261 |                         help='weight decay')
262 |     parser.add_argument('--momentum', type=float, default=0.9,
263 |                         help='momentum for optimizer')
264 |     parser.add_argument('--seed', type=int, default=1,
265 |                         help='seed for random behaviors; no seeding if <= 0')
266 |
267 |     parser.add_argument('--temperature', default=0.2, type=float, help='softmax temperature')
268 |     parser.add_argument('--low-dim', type=int, default=64)
269 |     parser.add_argument('--lam-c', type=float, default=1,
270 |                         help='coefficient of contrastive loss')
271 |     parser.add_argument('--contrast-th', default=0.8, type=float,
272 |                         help='pseudo label graph threshold')
273 |     parser.add_argument('--thr', type=float, default=0.95,
274 |                         help='pseudo label threshold')
275 |     parser.add_argument('--alpha', type=float, default=0.9)
276 |     parser.add_argument('--queue-batch', type=int, default=5,
277 |                         help='number of batches stored in memory bank')
278 |     parser.add_argument('--exp-dir', default='CoMatch', type=str, help='experiment id')
279 |     parser.add_argument('--checkpoint', default='', type=str, help='use pretrained model')
280 |
281 |     args = parser.parse_args()
282 |
283 |     logger, output_dir = setup_default_logging(args)
284 |     logger.info(dict(args._get_kwargs()))
285 |
286 |     tb_logger = tensorboard_logger.Logger(logdir=output_dir, flush_secs=2)
287 |
288 |     if args.seed > 0:
289 |         torch.manual_seed(args.seed)
290 |         random.seed(args.seed)
291 |         np.random.seed(args.seed)
292 |
293 |     n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
294 |     n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * n_epoches
295 |
296 |     logger.info("***** Running training *****")
297 |     logger.info(f" Task = {args.dataset}@{args.n_labeled}")
298 |
299 |     model, criteria_x, ema_model = set_model(args)
300 |     logger.info("Total params: {:.2f}M".format(
301 |         sum(p.numel() for p in model.parameters()) / 1e6))
302 |
303 |     dltrain_x, dltrain_u = get_train_loader(
304 |         args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, root=args.root, method='comatch')
305 |     dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2, root=args.root)
306 |
307 |     wd_params, non_wd_params = [], []
308 |     for name, param in model.named_parameters():
309 |         if 'bn' in name:
310 |             non_wd_params.append(param)
311 |         else:
312 |             wd_params.append(param)
313 |     param_list = [
314 |         {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
315 |     optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
316 |                             momentum=args.momentum, nesterov=True)
317 |
318 |     lr_schdlr = WarmupCosineLrScheduler(optim, n_iters_all, warmup_iter=0)
319 |
320 |     # memory bank
321 |     args.queue_size = args.queue_batch*(args.mu+1)*args.batchsize
322 |     queue_feats = torch.zeros(args.queue_size, args.low_dim).cuda()
323 |     queue_probs = torch.zeros(args.queue_size, args.n_classes).cuda()
324 |     queue_ptr = 0
325 |
326 |     # for distribution alignment
327 |     prob_list = []
328 |
329 |     train_args = dict(
330 |         model=model,
331 |         ema_model=ema_model,
332 |         prob_list=prob_list,
333 |         criteria_x=criteria_x,
334 |         optim=optim,
335 |         lr_schdlr=lr_schdlr,
336 |         dltrain_x=dltrain_x,
337 |         dltrain_u=dltrain_u,
338 |         args=args,
339 |         n_iters=n_iters_per_epoch,
340 |         logger=logger
341 |     )
342 |
343 |     best_acc = -1
344 |     best_epoch = 0
345 |     logger.info('-----------start training--------------')
346 |     for epoch in range(args.n_epoches):
347 |
348 |         loss_x, loss_u, loss_c, mask_mean, num_pos, guess_label_acc, queue_feats, queue_probs, queue_ptr, prob_list = \
349 |             train_one_epoch(epoch, **train_args, queue_feats=queue_feats,queue_probs=queue_probs,queue_ptr=queue_ptr)
350 |
351 |         top1, ema_top1 = evaluate(model, ema_model, dlval)
352 |
353 |         tb_logger.log_value('loss_x', loss_x, epoch)
354 |         tb_logger.log_value('loss_u', loss_u, epoch)
355 |         tb_logger.log_value('loss_c', loss_c, epoch)
356 |         tb_logger.log_value('guess_label_acc', guess_label_acc, epoch)
357 |         tb_logger.log_value('test_acc', top1, epoch)
358 |         tb_logger.log_value('test_ema_acc', ema_top1, epoch)
359 |         tb_logger.log_value('mask', mask_mean, epoch)
360 |         tb_logger.log_value('num_pos', num_pos, epoch)
361 |
362 |         if best_acc < top1:
363 |             best_acc = top1
364 |             best_epoch = epoch
365 |
366 |         logger.info("Epoch {}. Acc: {:.4f}. Ema-Acc: {:.4f}. best_acc: {:.4f} in epoch{}".
367 |                     format(epoch, top1, ema_top1, best_acc, best_epoch))
368 |
369 |         if epoch%10==0:
370 |             save_obj = {
371 |                 'model': model.state_dict(),
372 |                 'ema_model': ema_model.state_dict() if ema_model is not None else None,
373 |                 'optimizer': optim.state_dict(),
374 |                 'lr_scheduler': lr_schdlr.state_dict(),
375 |                 'prob_list': prob_list,
376 |                 'queue': {'queue_feats':queue_feats, 'queue_probs':queue_probs, 'queue_ptr':queue_ptr},
377 |                 'epoch': epoch,
378 |             }
379 |             torch.save(save_obj, os.path.join(output_dir, 'checkpoint_%02d.pth'%epoch))
380 |
381 | if __name__ == '__main__':
382 |     main()
383 |
--------------------------------------------------------------------------------
/Train_fixmatch.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import random
3 |
4 | import time
5 | import argparse
6 | import os
7 | import sys
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from WideResNet import WideResnet
16 | from datasets.cifar import get_train_loader, get_val_loader
17 |
18 | from utils import accuracy, setup_default_logging, AverageMeter, WarmupCosineLrScheduler
19 |
20 | import tensorboard_logger
21 |
22 | def set_model(args):
23 |     model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=False)
24 |
25 |     if args.checkpoint:
26 |         checkpoint = torch.load(args.checkpoint)
27 |
28 |         msg = model.load_state_dict(checkpoint, strict=False)
29 |         assert set(msg.missing_keys) == {"classifier.weight", "classifier.bias"}
30 |         assert set(msg.unexpected_keys) == {'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'}
31 |         print('loaded from checkpoint: %s'%args.checkpoint)
32 |
33 |     model.train()
34 |     model.cuda()
35 |     criteria_x = nn.CrossEntropyLoss().cuda()
36 |     criteria_u = nn.CrossEntropyLoss(reduction='none').cuda()
37 |
38 |     if args.eval_ema:
39 |         ema_model = WideResnet(n_classes=args.n_classes,k=args.wresnet_k, n=args.wresnet_n, proj=False)
40 |         for param_q, param_k in zip(model.parameters(), ema_model.parameters()):
41 |             param_k.data.copy_(param_q.detach().data)  # initialize
42 |             param_k.requires_grad = False  # not updated by gradients; only by ema_model_update
43 |         ema_model.cuda()
44 |         ema_model.eval()
45 |     else:
46 |         ema_model = None
47 |
48 |     return model, criteria_x, criteria_u, ema_model
49 |
50 |
51 | @torch.no_grad()
52 | def ema_model_update(model, ema_model, ema_m):
53 |     """
54 |     Momentum update of evaluation model (exponential moving average)
55 |     """
56 |     for param_train, param_eval in zip(model.parameters(), ema_model.parameters()):
57 |         param_eval.copy_(param_eval * ema_m + param_train.detach() * (1-ema_m))
58 |
59 |     for buffer_train, buffer_eval in zip(model.buffers(), ema_model.buffers()):
60 |         buffer_eval.copy_(buffer_train)
61 |
62 |
63 | def train_one_epoch(epoch,
64 |                     model,
65 |                     ema_model,
66 |                     criteria_x,
67 |                     criteria_u,
68 |                     optim,
69 |                     lr_schdlr,
70 |                     dltrain_x,
71 |                     dltrain_u,
72 |                     args,
73 |                     n_iters,
74 |                     logger,
75 |                     prob_list,
76 |                     ):
77 |     model.train()
78 |     loss_x_meter = AverageMeter()
79 |     loss_u_meter = AverageMeter()
80 |     # the number of correctly-predicted and gradient-considered unlabeled data
81 |     n_correct_u_lbs_meter = AverageMeter()
82 |     # the number of gradient-considered strong augmentations (confidence above threshold) of unlabeled samples
83 |     n_strong_aug_meter = AverageMeter()
84 |     mask_meter = AverageMeter()
85 |
86 |
87 |     epoch_start = time.time()  # start time
88 |     dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
89 |     for it in range(n_iters):
90 |         ims_x_weak, lbs_x = next(dl_x)
91 |         (ims_u_weak, ims_u_strong), lbs_u_real = next(dl_u)
92 |
93 |         lbs_x = lbs_x.cuda()
94 |         lbs_u_real = lbs_u_real.cuda()
95 |
96 |         # --------------------------------------
97 |         bt = ims_x_weak.size(0)
98 |         mu = int(ims_u_weak.size(0) // bt)
99 |         imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
100 |         logits = model(imgs)
101 |
102 |         logits_x = logits[:bt]
103 |         logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)
104 |
105 |         loss_x = criteria_x(logits_x, lbs_x)
106 |
107 |         with torch.no_grad():
108 |             probs = torch.softmax(logits_u_w, dim=1)
109 |
110 |             if args.DA:
111 |                 prob_list.append(probs.mean(0))
112 |                 if len(prob_list)>32:
113 |                     prob_list.pop(0)
114 |                 prob_avg = torch.stack(prob_list,dim=0).mean(0)
115 |                 probs = probs / prob_avg
116 |                 probs = probs / probs.sum(dim=1, keepdim=True)
117 |
118 |             scores, lbs_u_guess = torch.max(probs, dim=1)
119 |             mask = scores.ge(args.thr).float()
120 |
121 |             probs = probs.detach()
122 |
123 |         loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
124 |
125 |         loss = loss_x + args.lam_u * loss_u
126 |
127 |         optim.zero_grad()
128 |         loss.backward()
129 |         optim.step()
130 |         lr_schdlr.step()
131 |
132 |         if args.eval_ema:
133 |             with torch.no_grad():
134 |                 ema_model_update(model, ema_model, args.ema_m)
135 |
136 |         loss_x_meter.update(loss_x.item())
137 |         loss_u_meter.update(loss_u.item())
138 |         mask_meter.update(mask.mean().item())
139 |
140 |
141 |         corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
142 |         n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
143 |         n_strong_aug_meter.update(mask.sum().item())
144 |
145 |         if (it + 1) % 64 == 0:
146 |             t = time.time() - epoch_start
147 |
148 |             lr_log = [pg['lr'] for pg in optim.param_groups]
149 |             lr_log = sum(lr_log) / len(lr_log)
150 |
151 |             logger.info("{}-x{}-s{}, {} | epoch:{}, iter: {}. loss_u: {:.3f}. loss_x: {:.3f}. "
152 |                         "n_correct_u: {:.2f}/{:.2f}. Mask:{:.3f}. LR: {:.3f}. Time: {:.2f}".format(
153 |                 args.dataset, args.n_labeled, args.seed, args.exp_dir, epoch, it + 1, loss_u_meter.avg, loss_x_meter.avg,
154 |                 n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))
155 |
156 |             epoch_start = time.time()
157 |
158 |     return loss_x_meter.avg, loss_u_meter.avg, mask_meter.avg, n_correct_u_lbs_meter.avg/n_strong_aug_meter.avg, prob_list
159 |
160 |
161 | def evaluate(model, ema_model, dataloader, criterion):
162 |
163 |     model.eval()
164 |
165 |     top1_meter = AverageMeter()
166 |     ema_top1_meter = AverageMeter()
167 |
168 |     with torch.no_grad():
169 |         for ims, lbs in dataloader:
170 |             ims = ims.cuda()
171 |             lbs = lbs.cuda()
172 |
173 |             logits = model(ims)
174 |             loss = criterion(logits, lbs)
175 |             scores = torch.softmax(logits, dim=1)
176 |             top1, top5 = accuracy(scores, lbs, (1, 5))
177 |             top1_meter.update(top1.item())
178 |
179 |             if ema_model is not None:
180 |                 logits = ema_model(ims)
181 |                 loss = criterion(logits, lbs)
182 |                 scores = torch.softmax(logits, dim=1)
183 |                 top1, top5 = accuracy(scores, lbs, (1, 5))
184 |                 ema_top1_meter.update(top1.item())
185 |
186 |     return top1_meter.avg, ema_top1_meter.avg
187 |
188 |
189 | def main():
190 |     parser = argparse.ArgumentParser(description='FixMatch Training')
191 |     parser.add_argument('--root', default='./data', type=str, help='dataset directory')
192 |     parser.add_argument('--wresnet-k', default=2, type=int,
193 |                         help='width factor of wide resnet')
194 |     parser.add_argument('--wresnet-n', default=28, type=int,
195 |                         help='depth of wide resnet')
196 |     parser.add_argument('--dataset', type=str, default='CIFAR10',
197 |                         help='name of the dataset')
198 |     parser.add_argument('--n-classes', type=int, default=10,
199 |                         help='number of classes in dataset')
200 |     parser.add_argument('--n-labeled', type=int, default=40,
201 |                         help='number of labeled samples for training')
202 |     parser.add_argument('--n-epoches', type=int, default=1024,
203 |                         help='number of training epochs')
204 |     parser.add_argument('--batchsize', type=int, default=64,
205 |                         help='train batch size of labeled samples')
206 |     parser.add_argument('--mu', type=int, default=7,
207 |                         help='factor of train batch size of unlabeled samples')
208 |
209 |     parser.add_argument('--eval-ema', default=True, help='whether to use ema model for evaluation')
210 |     parser.add_argument('--ema-m', type=float, default=0.999)
211 |
212 |     parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
213 |                         help='number of training images for each epoch')
214 |     parser.add_argument('--lam-u', type=float, default=1.,
215 |                         help='coefficient of unlabeled loss')
216 |     parser.add_argument('--lr', type=float, default=0.03,
217 |                         help='learning rate for training')
218 |     parser.add_argument('--weight-decay', type=float, default=5e-4,
219 |                         help='weight decay')
220 |     parser.add_argument('--momentum', type=float, default=0.9,
221 |                         help='momentum for optimizer')
222 |     parser.add_argument('--seed', type=int, default=1,
223 |                         help='seed for random behaviors; no seeding if <= 0')
224 |     parser.add_argument('--DA', default=True, help='use distribution alignment')
225 |
226 |     parser.add_argument('--thr', type=float, default=0.95,
227 |                         help='pseudo label threshold')
228 |
229 |     parser.add_argument('--exp-dir', default='FixMatch', type=str, help='experiment directory')
230 |     parser.add_argument('--checkpoint', default='', type=str, help='use pretrained model')
231 |     #/export/home/project/SimCLR/save_cifar_t0.2/checkpoint_100.tar
232 |
233 |     args = parser.parse_args()
234 |
235 |     logger, output_dir = setup_default_logging(args)
236 |     logger.info(dict(args._get_kwargs()))
237 |
238 |     tb_logger = tensorboard_logger.Logger(logdir=output_dir, flush_secs=2)
239 |
240 |     if args.seed > 0:
241 |         torch.manual_seed(args.seed)
242 |         random.seed(args.seed)
243 |         np.random.seed(args.seed)
244 |
245 |     n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
246 |     n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * n_epoches
247 |
248 |     logger.info("***** Running training *****")
249 |     logger.info(f" Task = {args.dataset}@{args.n_labeled}")
250 |
251 |     model, criteria_x, criteria_u, ema_model = set_model(args)
252 |     logger.info("Total params: {:.2f}M".format(
253 |         sum(p.numel() for p in model.parameters()) / 1e6))
254 |
255 |     dltrain_x, dltrain_u = get_train_loader(
256 |         args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled, root=args.root, method='fixmatch')
257 |     dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2, root=args.root)
258 |
259 |     wd_params, non_wd_params = [], []
260 |     for name, param in model.named_parameters():
261 |         if 'bn' in name:
262 |             non_wd_params.append(param)
263 |         else:
264 |             wd_params.append(param)
265 |     param_list = [
266 |         {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
267 |     optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
268 |                             momentum=args.momentum, nesterov=True)
269 |
270 |     lr_schdlr = WarmupCosineLrScheduler(optim, n_iters_all, warmup_iter=0)
271 |
272 |     prob_list = []
273 |     train_args = dict(
274 |         model=model,
275 |         ema_model=ema_model,
276 |         criteria_x=criteria_x,
277 |         criteria_u=criteria_u,
278 |         optim=optim,
279 |         lr_schdlr=lr_schdlr,
280 |         dltrain_x=dltrain_x,
281 |         dltrain_u=dltrain_u,
282 |         args=args,
283 |         n_iters=n_iters_per_epoch,
284 |         logger=logger,
285 |         prob_list=prob_list
286 |     )
287 |     best_acc = -1
288 |     best_epoch = 0
289 |     logger.info('-----------start training--------------')
290 |     for epoch in range(args.n_epoches):
291 |         loss_x, loss_u, mask_mean, guess_label_acc, prob_list = train_one_epoch(epoch, **train_args)
292 |
293 |         top1, ema_top1 = evaluate(model, ema_model, dlval, criteria_x)
294 |
295 |         tb_logger.log_value('loss_x', loss_x, epoch)
296 |         tb_logger.log_value('loss_u', loss_u, epoch)
297 |         tb_logger.log_value('guess_label_acc', guess_label_acc, epoch)
298 |         tb_logger.log_value('test_acc', top1, epoch)
299 |         tb_logger.log_value('test_ema_acc', ema_top1, epoch)
300 |         tb_logger.log_value('mask', mask_mean, epoch)
301 |
302 |         if best_acc < top1:
303 |             best_acc = top1
304 |             best_epoch = epoch
305 |
306 |         logger.info("Epoch {}. Acc: {:.4f}. Ema-Acc: {:.4f}. best_acc: {:.4f} in epoch{}".
307 |                     format(epoch, top1, ema_top1, best_acc, best_epoch))
308 |
309 |
310 | if __name__ == '__main__':
311 |     main()
312 |
--------------------------------------------------------------------------------
/WideResNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | import torch.nn.functional as F
9 |
10 | from torch.nn import BatchNorm2d
11 |
12 | '''
13 | As in the paper, the wide resnet only considers the resnet of the pre-activated version,
14 | and it only considers the basic blocks rather than the bottleneck blocks.
15 | '''
16 |
17 |
18 | class BasicBlockPreAct(nn.Module):
19 |     def __init__(self, in_chan, out_chan, drop_rate=0, stride=1, pre_res_act=False):
20 |         super(BasicBlockPreAct, self).__init__()
21 |         self.bn1 = BatchNorm2d(in_chan, momentum=0.001)
22 |         self.relu1 = nn.LeakyReLU(inplace=True, negative_slope=0.1)
23 |         self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1, bias=False)
24 |         self.bn2 = BatchNorm2d(out_chan, momentum=0.001)
25 |         self.relu2 = nn.LeakyReLU(inplace=True, negative_slope=0.1)
26 |         self.dropout = nn.Dropout(drop_rate) if not drop_rate == 0 else None
27 |         self.conv2 = nn.Conv2d(out_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
28 |         self.downsample = None
29 |         if in_chan != out_chan or stride != 1:
30 |             self.downsample = nn.Conv2d(
31 |                 in_chan, out_chan, kernel_size=1, stride=stride, bias=False
32 |             )
33 |         self.pre_res_act = pre_res_act
34 |         # self.init_weight()
35 |
36 |     def forward(self, x):
37 |         bn1 = self.bn1(x)
38 |         act1 = self.relu1(bn1)
39 |         residual
= self.conv1(act1) 40 | residual = self.bn2(residual) 41 | residual = self.relu2(residual) 42 | if self.dropout is not None: 43 | residual = self.dropout(residual) 44 | residual = self.conv2(residual) 45 | 46 | shortcut = act1 if self.pre_res_act else x 47 | if self.downsample is not None: 48 | shortcut = self.downsample(shortcut) 49 | 50 | out = shortcut + residual 51 | return out 52 | 53 | def init_weight(self): 54 | # for _, md in self.named_modules(): 55 | # if isinstance(md, nn.Conv2d): 56 | # nn.init.kaiming_normal_( 57 | # md.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 58 | # if md.bias is not None: 59 | # nn.init.constant_(md.bias, 0) 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 63 | if m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | 66 | 67 | class WideResnetBackbone(nn.Module): 68 | def __init__(self, k=1, n=28, drop_rate=0): 69 | super(WideResnetBackbone, self).__init__() 70 | self.k, self.n = k, n 71 | assert (self.n - 4) % 6 == 0 72 | n_blocks = (self.n - 4) // 6 73 | n_layers = [16, ] + [self.k * 16 * (2 ** i) for i in range(3)] 74 | 75 | self.conv1 = nn.Conv2d( 76 | 3, 77 | n_layers[0], 78 | kernel_size=3, 79 | stride=1, 80 | padding=1, 81 | bias=False 82 | ) 83 | self.layer1 = self.create_layer( 84 | n_layers[0], 85 | n_layers[1], 86 | bnum=n_blocks, 87 | stride=1, 88 | drop_rate=drop_rate, 89 | pre_res_act=True, 90 | ) 91 | self.layer2 = self.create_layer( 92 | n_layers[1], 93 | n_layers[2], 94 | bnum=n_blocks, 95 | stride=2, 96 | drop_rate=drop_rate, 97 | pre_res_act=False, 98 | ) 99 | self.layer3 = self.create_layer( 100 | n_layers[2], 101 | n_layers[3], 102 | bnum=n_blocks, 103 | stride=2, 104 | drop_rate=drop_rate, 105 | pre_res_act=False, 106 | ) 107 | self.bn_last = BatchNorm2d(n_layers[3], momentum=0.001) 108 | self.relu_last = nn.LeakyReLU(inplace=True, negative_slope=0.1) 109 | self.init_weight() 110 | 111 | def 
create_layer( 112 | self, 113 | in_chan, 114 | out_chan, 115 | bnum, 116 | stride=1, 117 | drop_rate=0, 118 | pre_res_act=False, 119 | ): 120 | layers = [ 121 | BasicBlockPreAct( 122 | in_chan, 123 | out_chan, 124 | drop_rate=drop_rate, 125 | stride=stride, 126 | pre_res_act=pre_res_act), ] 127 | for _ in range(bnum - 1): 128 | layers.append( 129 | BasicBlockPreAct( 130 | out_chan, 131 | out_chan, 132 | drop_rate=drop_rate, 133 | stride=1, 134 | pre_res_act=False, )) 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | feat = self.conv1(x) 139 | 140 | feat = self.layer1(feat) 141 | feat2 = self.layer2(feat) # 1/2 142 | feat4 = self.layer3(feat2) # 1/4 143 | 144 | feat4 = self.bn_last(feat4) 145 | feat4 = self.relu_last(feat4) 146 | return feat2, feat4 147 | 148 | def init_weight(self): 149 | # for _, child in self.named_children(): 150 | # if isinstance(child, nn.Conv2d): 151 | # n = child.kernel_size[0] * child.kernel_size[0] * child.out_channels 152 | # nn.init.normal_(child.weight, 0, 1. / ((0.5 * n) ** 0.5)) 153 | # # nn.init.kaiming_normal_( 154 | # # child.weight, a=0.1, mode='fan_out', 155 | # # nonlinearity='leaky_relu' 156 | # # ) 157 | # 158 | # if child.bias is not None: 159 | # nn.init.constant_(child.bias, 0) 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | 165 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 166 | if m.bias is not None: 167 | nn.init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.BatchNorm2d): 169 | m.weight.data.fill_(1) 170 | m.bias.data.zero_() 171 | elif isinstance(m, nn.Linear): 172 | m.bias.data.zero_() 173 | 174 | class Normalize(nn.Module): 175 | 176 | def __init__(self, power=2): 177 | super(Normalize, self).__init__() 178 | self.power = power 179 | 180 | def forward(self, x): 181 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 182 | out = x.div(norm) 183 | return out 184 | 185 | class WideResnet(nn.Module): 186 | ''' 187 | for wide-resnet-28-10, the definition should be WideResnet(n_classes, 10, 28) 188 | ''' 189 | 190 | def __init__(self, n_classes, k=1, n=28, low_dim=64, proj=True): 191 | super(WideResnet, self).__init__() 192 | self.n_layers, self.k = n, k 193 | self.backbone = WideResnetBackbone(k=k, n=n) 194 | self.classifier = nn.Linear(64 * self.k, n_classes, bias=True) 195 | 196 | self.proj = proj 197 | if proj: 198 | self.l2norm = Normalize(2) 199 | 200 | self.fc1 = nn.Linear(64 * self.k, 64 * self.k) 201 | self.relu_mlp = nn.LeakyReLU(inplace=True, negative_slope=0.1) 202 | self.fc2 = nn.Linear(64 * self.k, low_dim) 203 | 204 | 205 | def forward(self, x): 206 | feat = self.backbone(x)[-1] 207 | feat = torch.mean(feat, dim=(2, 3)) 208 | out = self.classifier(feat) 209 | 210 | if self.proj: 211 | feat = self.fc1(feat) 212 | feat = self.relu_mlp(feat) 213 | feat = self.fc2(feat) 214 | 215 | feat = self.l2norm(feat) 216 | return out,feat 217 | else: 218 | return out 219 | 220 | def init_weight(self): 221 | nn.init.xavier_normal_(self.classifier.weight) 222 | if not self.classifier.bias is None: 223 | nn.init.constant_(self.classifier.bias, 0) 224 | 225 | 226 | if __name__ == "__main__": 227 | x = torch.randn(2, 3, 224, 224) 228 | lb = torch.randint(0, 10, (2,)).long() 229 | 230 | net = WideResnetBackbone() 231 | out = net(x) 232 | print(out[0].size()) 233 | del net, out 234 | 235 | net = WideResnet(n_classes=10) 236 | criteria = nn.CrossEntropyLoss() 237 | out = net(x) 238 | loss = criteria(out, lb) 239 | loss.backward() 240 | print(out.size()) -------------------------------------------------------------------------------- /comatch.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/CoMatch/a64ccf5af11017a1a07267b21e5899f4e8157801/comatch.gif 
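The channel and block arithmetic in `WideResnetBackbone` can be checked without instantiating the network. This is a minimal sketch that assumes only the formulas visible in WideResNet.py; `wrn_layout` is a hypothetical helper, not part of the CoMatch repository. It mirrors `n_blocks = (n - 4) // 6`, `n_layers = [16] + [k * 16 * 2**i for i in range(3)]` and the stage strides (1, 2, 2), and it reports the pooled feature size that `WideResnet` passes to its classifier (`64 * k`).

```python
def wrn_layout(n: int = 28, k: int = 2) -> dict:
    """Hypothetical helper mirroring the width/depth arithmetic of WideResnetBackbone.

    n: total depth; the backbone asserts (n - 4) % 6 == 0
    k: width factor
    """
    if (n - 4) % 6 != 0:
        raise ValueError("depth n must satisfy (n - 4) % 6 == 0, got n={}".format(n))
    # number of BasicBlockPreAct blocks in each of layer1, layer2 and layer3
    n_blocks = (n - 4) // 6
    # output channels of the stem conv, then of the three stages
    n_layers = [16] + [k * 16 * (2 ** i) for i in range(3)]
    # the first block of each stage maps in -> out with the stage stride;
    # the remaining blocks keep out channels and use stride 1
    strides = [1, 2, 2]
    stages = [
        {"in": n_layers[i], "out": n_layers[i + 1], "blocks": n_blocks, "stride": strides[i]}
        for i in range(3)
    ]
    return {
        "n_blocks": n_blocks,
        "n_layers": n_layers,
        "stages": stages,
        # size of the globally pooled feature fed to WideResnet.classifier (64 * k)
        "feat_dim": n_layers[-1],
    }


if __name__ == "__main__":
    # WRN-28-2, matching the Train_fixmatch.py defaults --wresnet-n 28 --wresnet-k 2
    layout = wrn_layout(28, 2)
    print(layout["n_blocks"], layout["n_layers"], layout["feat_dim"])
    # prints: 4 [16, 32, 64, 128] 128
```

For WRN-28-2 this yields four blocks per stage and a 128-dimensional pooled feature, consistent with `nn.Linear(64 * self.k, n_classes)` in `WideResnet`.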
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/CoMatch/a64ccf5af11017a1a07267b21e5899f4e8157801/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 |
9 | from datasets import transform as T
10 | from datasets.randaugment import RandomAugment
11 | from datasets.sampler import RandomSampler, BatchSampler
12 |
13 | class TwoCropsTransform:
14 |     """Take 2 random augmentations of one image."""
15 |
16 |     def __init__(self,trans_weak,trans_strong):
17 |         self.trans_weak = trans_weak
18 |         self.trans_strong = trans_strong
19 |     def __call__(self, x):
20 |         x1 = self.trans_weak(x)
21 |         x2 = self.trans_strong(x)
22 |         return [x1, x2]
23 |
24 | class ThreeCropsTransform:
25 |     """Take 3 random augmentations of one image."""
26 |
27 |     def __init__(self,trans_weak,trans_strong0,trans_strong1):
28 |         self.trans_weak = trans_weak
29 |         self.trans_strong0 = trans_strong0
30 |         self.trans_strong1 = trans_strong1
31 |     def __call__(self, x):
32 |         x1 = self.trans_weak(x)
33 |         x2 = self.trans_strong0(x)
34 |         x3 = self.trans_strong1(x)
35 |         return [x1, x2, x3]
36 |
37 | def load_data_train(L=250, dataset='CIFAR10', dspth='./data'):
38 |     if dataset == 'CIFAR10':
39 |         datalist = [
40 |             osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1))
41 |             for i in range(5)
42 |         ]
43 |         n_class = 10
44 |         assert L in [10, 20, 40, 80, 250, 4000]
45 |     elif dataset == 'CIFAR100':
46 |         datalist = [
47 |             osp.join(dspth, 'cifar-100-python', 'train')]
48 |         n_class = 100
49 |         assert L in [25, 400, 2500, 10000]
50 |
51 |     data, labels = [], []
52 |     for data_batch in datalist:
53 |         with open(data_batch, 'rb') as fr:
54 |             entry = pickle.load(fr, encoding='latin1')
55 |             lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels']
56 |             data.append(entry['data'])
57 |             labels.append(lbs)
58 |     data = np.concatenate(data, axis=0)
59 |     labels = np.concatenate(labels, axis=0)
60 |     n_labels = L // n_class
61 |     data_x, label_x, data_u, label_u = [], [], [], []
62 |     for i in range(n_class):
63 |         indices = np.where(labels == i)[0]
64 |         np.random.shuffle(indices)
65 |         inds_x, inds_u = indices[:n_labels], indices[n_labels:]
66 |         data_x += [
67 |             data[i].reshape(3, 32, 32).transpose(1, 2, 0)
68 |             for i in inds_x
69 |         ]
70 |         label_x += [labels[i] for i in inds_x]
71 |         data_u += [
72 |             data[i].reshape(3, 32, 32).transpose(1, 2, 0)
73 |             for i in inds_u
74 |         ]
75 |         label_u += [labels[i] for i in inds_u]
76 |     return data_x, label_x, data_u, label_u
77 |
78 |
79 | def load_data_val(dataset, dspth='./data'):
80 |     if dataset == 'CIFAR10':
81 |         datalist = [
82 |             osp.join(dspth, 'cifar-10-batches-py', 'test_batch')
83 |         ]
84 |     elif dataset == 'CIFAR100':
85 |         datalist = [
86 |             osp.join(dspth, 'cifar-100-python', 'test')
87 |         ]
88 |
89 |     data, labels = [], []
90 |     for data_batch in datalist:
91 |         with open(data_batch, 'rb') as fr:
92 |             entry = pickle.load(fr, encoding='latin1')
93 |             lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels']
94 |             data.append(entry['data'])
95 |             labels.append(lbs)
96 |     data = np.concatenate(data, axis=0)
97 |     labels = np.concatenate(labels, axis=0)
98 |     data = [
99 |         el.reshape(3, 32, 32).transpose(1, 2, 0)
100 |         for el in data
101 |     ]
102 |     return data, labels
103 |
104 |
105 | def compute_mean_var():
106 |     data_x, label_x, data_u, label_u = load_data_train()
107 |     data = data_x + data_u
108 |     data = np.concatenate([el[None, ...] for el in data], axis=0)
109 |
110 |     mean, var = [], []
111 |     for i in range(3):
112 |         channel = (data[:, :, :, i].ravel() / 127.5) - 1
113 |         # channel = (data[:, :, :, i].ravel() / 255)
114 |         mean.append(np.mean(channel))
115 |         var.append(np.std(channel))
116 |
117 |     print('mean: ', mean)
118 |     print('var: ', var)
119 |
120 |
121 |
122 | class Cifar(Dataset):
123 |     def __init__(self, dataset, data, labels, mode):
124 |         super(Cifar, self).__init__()
125 |         self.data, self.labels = data, labels
126 |         self.mode = mode
127 |         assert len(self.data) == len(self.labels)
128 |         if dataset == 'CIFAR10':
129 |             mean, std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
130 |         elif dataset == 'CIFAR100':
131 |             mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
132 |
133 |         trans_weak = T.Compose([
134 |             T.Resize((32, 32)),
135 |             T.PadandRandomCrop(border=4, cropsize=(32, 32)),
136 |             T.RandomHorizontalFlip(p=0.5),
137 |             T.Normalize(mean, std),
138 |             T.ToTensor(),
139 |         ])
140 |         trans_strong0 = T.Compose([
141 |             T.Resize((32, 32)),
142 |             T.PadandRandomCrop(border=4, cropsize=(32, 32)),
143 |             T.RandomHorizontalFlip(p=0.5),
144 |             RandomAugment(2, 10),
145 |             T.Normalize(mean, std),
146 |             T.ToTensor(),
147 |         ])
148 |         trans_strong1 = transforms.Compose([
149 |             transforms.ToPILImage(),
150 |             transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
151 |             transforms.RandomHorizontalFlip(p=0.5),
152 |             transforms.RandomApply([
153 |                 transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
154 |             ], p=0.8),
155 |             transforms.RandomGrayscale(p=0.2),
156 |             transforms.ToTensor(),
157 |             transforms.Normalize(mean, std),
158 |         ])
159 |         if self.mode == 'train_x':
160 |             self.trans = trans_weak
161 |         elif self.mode == 'train_u_comatch':
162 |             self.trans = ThreeCropsTransform(trans_weak, trans_strong0, trans_strong1)
163 |         elif self.mode == 'train_u_fixmatch':
164 |             self.trans = TwoCropsTransform(trans_weak, trans_strong0)
165 |         else:
166 |             self.trans = T.Compose([
167 |                 T.Resize((32, 32)),
168 |                 T.Normalize(mean, std),
169 |                 T.ToTensor(),
170 |             ])
171 |
172 |     def __getitem__(self, idx):
173 |         im, lb = self.data[idx], self.labels[idx]
174 |         return self.trans(im), lb
175 |
176 |     def __len__(self):
177 |         leng = len(self.data)
178 |         return leng
179 |
180 |
181 | def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, root='data', method='comatch'):
182 |     data_x, label_x, data_u, label_u = load_data_train(L=L, dataset=dataset, dspth=root)
183 |
184 |     ds_x = Cifar(
185 |         dataset=dataset,
186 |         data=data_x,
187 |         labels=label_x,
188 |         mode='train_x'
189 |     )  # return an iter of num_samples length (all indices of samples)
190 |     sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
191 |     batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)  # yield a batch of samples one time
192 |     dl_x = torch.utils.data.DataLoader(
193 |         ds_x,
194 |         batch_sampler=batch_sampler_x,
195 |         num_workers=2,
196 |         pin_memory=True
197 |     )
198 |     ds_u = Cifar(
199 |         dataset=dataset,
200 |         data=data_u,
201 |         labels=label_u,
202 |         mode='train_u_%s'%method
203 |     )
204 |     sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
205 |     #sampler_u = RandomSampler(ds_u, replacement=False)
206 |     batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
207 |     dl_u = torch.utils.data.DataLoader(
208 |         ds_u,
209 |         batch_sampler=batch_sampler_u,
210 |         num_workers=2,
211 |         pin_memory=True
212 |     )
213 |     return dl_x, dl_u
214 |
215 |
216 | def get_val_loader(dataset, batch_size, num_workers, pin_memory=True, root='data'):
217 |     data, labels = load_data_val(dataset, dspth=root)
218 |     ds = Cifar(
219 |         dataset=dataset,
220 |         data=data,
221 |         labels=labels,
222 |         mode='test'
223 |     )
224 |     dl = torch.utils.data.DataLoader(
225 |         ds,
226 |         shuffle=False,
227 |         batch_size=batch_size,
228 |         drop_last=False,
229 |         num_workers=num_workers,
230 |         pin_memory=pin_memory
231 |     )
232 |     return dl
233 |
234 |
235 |
--------------------------------------------------------------------------------
/datasets/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 |     return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 |     '''
12 |         same output as PIL.ImageOps.autocontrast
13 |     '''
14 |     n_bins = 256
15 |
16 |     def tune_channel(ch):
17 |         n = ch.size
18 |         cut = cutoff * n // 100
19 |         if cut == 0:
20 |             high, low = ch.max(), ch.min()
21 |         else:
22 |             hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 |             low = np.argwhere(np.cumsum(hist) > cut)
24 |             low = 0 if low.shape[0] == 0 else low[0]
25 |             high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 |             high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 |         if high <= low:
28 |             table = np.arange(n_bins)
29 |         else:
30 |             scale = (n_bins - 1) / (high - low)
31 |             offset = -low * scale
32 |             table = np.arange(n_bins) * scale + offset
33 |             table[table < 0] = 0
34 |             table[table > n_bins - 1] = n_bins - 1
35 |         table = table.clip(0, 255).astype(np.uint8)
36 |         return table[ch]
37 |
38 |     channels = [tune_channel(ch) for ch in cv2.split(img)]
39 |     out = cv2.merge(channels)
40 |     return out
41 |
42 |
43 | def equalize_func(img):
44 |     '''
45 |         same output as PIL.ImageOps.equalize
46 |         PIL's implementation is different from cv2.equalize
47 |     '''
48 |     n_bins = 256
49 |
50 |     def tune_channel(ch):
51 |         hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52 |         non_zero_hist = hist[hist != 0].reshape(-1)
53 |         step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54 |         if step == 0: return ch
55 |         n = np.empty_like(hist)
56 |         n[0] = step // 2
57 |         n[1:] = hist[:-1]
58 |         table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59 |         return table[ch]
60 |
61 |     channels = [tune_channel(ch) for ch in cv2.split(img)]
62 |     out = cv2.merge(channels)
63 |     return out
64 |
65 |
66 | def rotate_func(img, degree, fill=(0, 0, 0)):
67 |     '''
68 |         like PIL, rotate by degree, not radians
69 |     '''
70 |     H, W = img.shape[0], img.shape[1]
71 |     center = W / 2, H / 2
72 |     M = cv2.getRotationMatrix2D(center, degree, 1)
73 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74 |     return out
75 |
76 |
77 | def solarize_func(img, thresh=128):
78 |     '''
79 |         same output as PIL.ImageOps.solarize
80 |     '''
81 |     table = np.array([el if el < thresh else 255 - el for el in range(256)])
82 |     table = table.clip(0, 255).astype(np.uint8)
83 |     out = table[img]
84 |     return out
85 |
86 |
87 | def color_func(img, factor):
88 |     '''
89 |         same output as PIL.ImageEnhance.Color
90 |     '''
91 |     ## implementation according to PIL definition, quite slow
92 |     #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93 |     #  out = blend(degenerate, img, factor)
94 |     #  M = (
95 |     #      np.eye(3) * factor
96 |     #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97 |     #  )[np.newaxis, np.newaxis, :]
98 |     M = (
99 |             np.float32([
100 |                 [0.886, -0.114, -0.114],
101 |                 [-0.587, 0.413, -0.587],
102 |                 [-0.299, -0.299, 0.701]]) * factor
103 |             + np.float32([[0.114], [0.587], [0.299]])
104 |     )
105 |     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106 |     return out
107 |
108 |
109 | def contrast_func(img, factor):
110 |     """
111 |         same output as PIL.ImageEnhance.Contrast
112 |     """
113 |     mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114 |     table = np.array([(
115 |         el - mean) * factor + mean
116 |         for el in range(256)
117 |     ]).clip(0, 255).astype(np.uint8)
118 |     out = table[img]
119 |     return out
120 |
121 |
122 | def brightness_func(img, factor):
123 |     '''
124 |         same output as PIL.ImageEnhance.Brightness
125 |     '''
126 |     table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127 |     out = table[img]
128 |     return out
129 |
130 |
131 | def sharpness_func(img, factor):
132 |     '''
133 |     The differences between this result and PIL's are all on the 4 boundaries; the center
134 |     areas are the same
135 |     '''
136 |     kernel = np.ones((3, 3), dtype=np.float32)
137 |     kernel[1][1] = 5
138 |     kernel /= 13
139 |     degenerate = cv2.filter2D(img, -1, kernel)
140 |     if factor == 0.0:
141 |         out = degenerate
142 |     elif factor == 1.0:
143 |         out = img
144 |     else:
145 |         out = img.astype(np.float32)
146 |         degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147 |         out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148 |         out = out.astype(np.uint8)
149 |     return out
150 |
151 |
152 | def shear_x_func(img, factor, fill=(0, 0, 0)):
153 |     H, W = img.shape[0], img.shape[1]
154 |     M = np.float32([[1, factor, 0], [0, 1, 0]])
155 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156 |     return out
157 |
158 |
159 | def translate_x_func(img, offset, fill=(0, 0, 0)):
160 |     '''
161 |         same output as PIL.Image.transform
162 |     '''
163 |     H, W = img.shape[0], img.shape[1]
164 |     M = np.float32([[1, 0, -offset], [0, 1, 0]])
165 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166 |     return out
167 |
168 |
169 | def translate_y_func(img, offset, fill=(0, 0, 0)):
170 |     '''
171 |         same output as PIL.Image.transform
172 |     '''
173 |     H, W = img.shape[0], img.shape[1]
174 |     M = np.float32([[1, 0, 0], [0, 1, -offset]])
175 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176 |     return out
177 |
178 |
179 | def posterize_func(img, bits):
180 |     '''
181 |         same output as PIL.ImageOps.posterize
182 |     '''
183 |     out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 0xFF))  # mask to 8 bits first; NumPy 2 rejects out-of-range ints
184 |     return out
185 |
186 |
187 | def shear_y_func(img, factor, fill=(0, 0, 0)):
188 |     H, W = img.shape[0], img.shape[1]
189 |     M = np.float32([[1, 0, 0], [factor, 1, 0]])
190 |     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191 |     return out
192 |
193 |
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
195 |     replace = np.array(replace, dtype=np.uint8)
196 |     H, W = img.shape[0], img.shape[1]
197 |     rh, rw = np.random.random(2)
198 |     pad_size = pad_size // 2
199 |     ch, cw = int(rh * H), int(rw * W)
200 |     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201 |     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202 |     out = img.copy()
203 |     out[x1:x2, y1:y2, :] = replace
204 |     return out
205 |
206 |
207 | ### level to args
208 | def enhance_level_to_args(MAX_LEVEL):
209 |     def level_to_args(level):
210 |         return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211 |     return level_to_args
212 |
213 |
214 | def shear_level_to_args(MAX_LEVEL, replace_value):
215 |     def level_to_args(level):
216 |         level = (level / MAX_LEVEL) * 0.3
217 |         if np.random.random() > 0.5: level = -level
218 |         return (level, replace_value)
219 |
220 |     return level_to_args
221 |
222 |
223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224 |     def level_to_args(level):
225 |         level = (level / MAX_LEVEL) * float(translate_const)
226 |         if np.random.random() > 0.5: level = -level
227 |         return (level, replace_value)
228 |
229 |     return level_to_args
230 |
231 |
232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233 |     def level_to_args(level):
234 |         level = int((level / MAX_LEVEL) * cutout_const)
235 |         return (level, replace_value)
236 |
237 |     return level_to_args
238 |
239 |
240 | def solarize_level_to_args(MAX_LEVEL):
241 |     def level_to_args(level):
242 |         level = int((level / MAX_LEVEL) * 256)
243 |         return (level, )
244 |     return level_to_args
245 |
246 |
247 | def none_level_to_args(level):
248 |     return ()
249 |
250 |
251 | def posterize_level_to_args(MAX_LEVEL):
252 |     def level_to_args(level):
253 |         level = int((level / MAX_LEVEL) * 4)
254 |         return (level, )
255 |     return level_to_args
256 |
257 |
258 | def rotate_level_to_args(MAX_LEVEL, replace_value):
259 |     def level_to_args(level):
260 |         level = (level / MAX_LEVEL) * 30
261 |         if np.random.random() < 0.5:
262 |             level = -level
263 |         return (level, replace_value)
264 |
265 |     return level_to_args
266 |
267 |
268 | func_dict = {
269 |     'Identity': identity_func,
270 |     'AutoContrast': autocontrast_func,
271 |     'Equalize': equalize_func,
272 |     'Rotate': rotate_func,
273 |     'Solarize': solarize_func,
274 |     'Color': color_func,
275 |     'Contrast': contrast_func,
276 |     'Brightness': brightness_func,
277 |     'Sharpness': sharpness_func,
278 |     'ShearX': shear_x_func,
279 |     'TranslateX': translate_x_func,
280 |     'TranslateY': translate_y_func,
281 |     'Posterize': posterize_func,
282 |     'ShearY': shear_y_func,
283 | }
284 |
285 | translate_const = 10
286 | MAX_LEVEL = 10
287 | replace_value = (128, 128, 128)
288 | arg_dict = {
289 |     'Identity': none_level_to_args,
290 |     'AutoContrast': none_level_to_args,
291 |     'Equalize': none_level_to_args,
292 |     'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293 |     'Solarize': solarize_level_to_args(MAX_LEVEL),
294 |     'Color': enhance_level_to_args(MAX_LEVEL),
295 |     'Contrast': enhance_level_to_args(MAX_LEVEL),
296 |     'Brightness': enhance_level_to_args(MAX_LEVEL),
297 |     'Sharpness': enhance_level_to_args(MAX_LEVEL),
298 |     'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299 |     'TranslateX': translate_level_to_args(
300 |         translate_const, MAX_LEVEL, replace_value
301 |     ),
302 |     'TranslateY': translate_level_to_args(
303 |         translate_const, MAX_LEVEL, replace_value
304 |     ),
305 |     'Posterize': posterize_level_to_args(MAX_LEVEL),
306 |     'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307 | }
308 |
309 |
310 | class RandomAugment(object):
311 |
312 |     def __init__(self, N=2, M=10, isPIL=False):
313 |         self.N = N
314 |         self.M = M
315 |         self.isPIL = isPIL
316 |
317 |     def get_random_ops(self):
318 |         sampled_ops = np.random.choice(list(func_dict.keys()), self.N)
319 |         return [(op, 0.5, self.M) for op in sampled_ops]
320 |
321 |     def __call__(self, img):
322 |         if self.isPIL:
323 |             img = np.array(img)
324 |         ops = self.get_random_ops()
325 |         for name, prob, level in ops:
326 |             if np.random.random() > prob:
327 |                 continue
328 |             args = arg_dict[name](level)
329 |             img = func_dict[name](img, *args)
330 |         img = cutout_func(img, 16, replace_value)
331 |         return img
332 |
333 |
334 | if __name__ == '__main__':
335 |     a = RandomAugment()
336 |     img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # the ops index lookup tables, so input must be uint8 HxWx3
337 |     a(img)
--------------------------------------------------------------------------------
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | _int_classes = int  # was `from torch._six import int_classes`; torch._six is gone from recent PyTorch, and int_classes was int on Python 3
3 |
4 |
5 | class Sampler(object):
6 |     r"""Base class for all Samplers.
7 |
8 |     Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
9 |     way to iterate over indices of dataset elements, and a :meth:`__len__` method
10 |     that returns the length of the returned iterators.
11 |
12 |     .. note:: The :meth:`__len__` method isn't strictly required by
13 |               :class:`~torch.utils.data.DataLoader`, but is expected in any
14 |               calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
15 |     """
16 |
17 |     def __init__(self, data_source):
18 |         pass
19 |
20 |     def __iter__(self):
21 |         raise NotImplementedError
22 |
23 |     # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
24 |     #
25 |     # Many times we have an abstract class representing a collection/iterable of
26 |     # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
27 |     # implementing a `__len__` method. In such cases, we must make sure to not
28 |     # provide a default implementation, because both straightforward default
29 |     # implementations have their issues:
30 |     #
31 |     #   + `return NotImplemented`:
32 |     #     Calling `len(subclass_instance)` raises:
33 |     #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
34 |     #
35 |     #   + `raise NotImplementedError()`:
36 |     #     This prevents triggering some fallback behavior. E.g., the built-in
37 |     #     `list(X)` tries to call `len(X)` first, and executes a different code
38 |     #     path if the method is not found or `NotImplemented` is returned, while
39 |     #     raising a `NotImplementedError` will propagate and make the call
40 |     #     fail where it could have used `__iter__` to complete the call.
41 |     #
42 |     # Thus, the only two sensible things to do are
43 |     #
44 |     #   + **not** provide a default `__len__`.
45 |     #
46 |     #   + raise a `TypeError` instead, which is what Python uses when users call
47 |     #     a method that is not defined on an object.
48 |     #     (@ssnl verifies that this works on at least Python 3.7.)
49 |
50 |
51 | class SequentialSampler(Sampler):
52 |     r"""Samples elements sequentially, always in the same order.
53 |
54 |     Arguments:
55 |         data_source (Dataset): dataset to sample from
56 |     """
57 |
58 |     def __init__(self, data_source):
59 |         self.data_source = data_source
60 |
61 |     def __iter__(self):
62 |         return iter(range(len(self.data_source)))
63 |
64 |     def __len__(self):
65 |         return len(self.data_source)
66 |
67 |
68 | class RandomSampler(Sampler):
69 |     r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
70 |     If with replacement, then user can specify :attr:`num_samples` to draw.
71 |
72 |     Arguments:
73 |         data_source (Dataset): dataset to sample from
74 |         replacement (bool): samples are drawn with replacement if ``True``, default=``False``
75 |         num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
76 |             is supposed to be specified only when `replacement` is ``True``.
77 |     """
78 |
79 |     def __init__(self, data_source, replacement=False, num_samples=None):
80 |         self.data_source = data_source
81 |         self.replacement = replacement
82 |         self._num_samples = num_samples
83 |
84 |         if not isinstance(self.replacement, bool):
85 |             raise ValueError("replacement should be a boolean value, but got "
86 |                              "replacement={}".format(self.replacement))
87 |
88 |         if self._num_samples is not None and not replacement:
89 |             raise ValueError("With replacement=False, num_samples should not be specified, "
90 |                              "since a random permute will be performed.")
91 |
92 |         if not isinstance(self.num_samples, int) or self.num_samples <= 0:
93 |             raise ValueError("num_samples should be a positive integer "
94 |                              "value, but got num_samples={}".format(self.num_samples))
95 |
96 |     @property
97 |     def num_samples(self):
98 |         # dataset size might change at runtime
99 |         if self._num_samples is None:
100 |             return len(self.data_source)
101 |         return self._num_samples
102 |
103 |     def __iter__(self):
104 |         n = len(self.data_source)
105 |         if self.replacement:
106 |             n_repeats = self.num_samples // n
107 |             n_remain = self.num_samples % n
108 |             indices = [torch.randperm(n) for _ in range(n_repeats)]
109 |             indices.append(torch.randperm(n)[:n_remain])
110 |             return iter(torch.cat(indices, dim=0).tolist())
111 |         return iter(torch.randperm(n).tolist())
112 |
113 |     def __len__(self):
114 |         return self.num_samples
115 |
116 |
117 | class SubsetRandomSampler(Sampler):
118 |     r"""Samples elements randomly from a given list of indices, without replacement.
119 | 120 | Arguments: 121 | indices (sequence): a sequence of indices 122 | """ 123 | 124 | def __init__(self, indices): 125 | self.indices = indices 126 | 127 | def __iter__(self): 128 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 129 | 130 | def __len__(self): 131 | return len(self.indices) 132 | 133 | 134 | class WeightedRandomSampler(Sampler): 135 | r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). 136 | 137 | Args: 138 | weights (sequence): a sequence of weights, not necessarily summing up to one 139 | num_samples (int): number of samples to draw 140 | replacement (bool): if ``True``, samples are drawn with replacement. 141 | If not, they are drawn without replacement, which means that when a 142 | sample index is drawn for a row, it cannot be drawn again for that row. 143 | 144 | Example: 145 | >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) 146 | [0, 0, 0, 1, 0] 147 | >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) 148 | [0, 1, 4, 3, 2] 149 | """ 150 | 151 | def __init__(self, weights, num_samples, replacement=True): 152 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \ 153 | num_samples <= 0: 154 | raise ValueError("num_samples should be a positive integer " 155 | "value, but got num_samples={}".format(num_samples)) 156 | if not isinstance(replacement, bool): 157 | raise ValueError("replacement should be a boolean value, but got " 158 | "replacement={}".format(replacement)) 159 | self.weights = torch.as_tensor(weights, dtype=torch.double) 160 | self.num_samples = num_samples 161 | self.replacement = replacement 162 | 163 | def __iter__(self): 164 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist()) 165 | 166 | def __len__(self): 167 | return self.num_samples 168 | 169 | 170 | class BatchSampler(Sampler): 171 | r"""Wraps another 
sampler to yield a
mini-batch of indices. 172 | 173 | Args: 174 | sampler (Sampler): Base sampler. 175 | batch_size (int): Size of mini-batch. 176 | drop_last (bool): If ``True``, the sampler will drop the last batch if 177 | its size would be less than ``batch_size`` 178 | 179 | Example: 180 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 181 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 182 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 183 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 184 | """ 185 | 186 | def __init__(self, sampler, batch_size, drop_last): 187 | if not isinstance(sampler, Sampler): 188 | raise ValueError("sampler should be an instance of " 189 | "torch.utils.data.Sampler, but got sampler={}" 190 | .format(sampler)) 191 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 192 | batch_size <= 0: 193 | raise ValueError("batch_size should be a positive integer value, " 194 | "but got batch_size={}".format(batch_size)) 195 | if not isinstance(drop_last, bool): 196 | raise ValueError("drop_last should be a boolean value, but got " 197 | "drop_last={}".format(drop_last)) 198 | self.sampler = sampler 199 | self.batch_size = batch_size 200 | self.drop_last = drop_last 201 | 202 | def __iter__(self): 203 | batch = [] 204 | for idx in self.sampler: 205 | batch.append(idx) 206 | if len(batch) == self.batch_size: 207 | yield batch 208 | batch = [] 209 | if len(batch) > 0 and not self.drop_last: 210 | yield batch 211 | 212 | def __len__(self): 213 | if self.drop_last: 214 | return len(self.sampler) // self.batch_size 215 | else: 216 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 217 | -------------------------------------------------------------------------------- /datasets/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | 6 | class PadandRandomCrop(object): 
7 | ''' 8 | Input array is expected to have shape (H, W, 3) 9 | ''' 10 | def __init__(self, border=4, cropsize=(32, 32)): 11 | self.border = border 12 | self.cropsize = cropsize 13 | 14 | def __call__(self, im): 15 | borders = [(self.border, self.border), (self.border, self.border), (0, 0)] # input is (h, w, c) 16 | canvas = np.pad(im, borders, mode='reflect') 17 | H, W, C = canvas.shape 18 | h, w = self.cropsize 19 | dh, dw = max(0, H-h), max(0, W-w) 20 | sh, sw = np.random.randint(0, dh + 1), np.random.randint(0, dw + 1) # randint's upper bound is exclusive, so +1 allows every offset in [0, dh] / [0, dw] 21 | out = canvas[sh:sh+h, sw:sw+w, :] 22 | return out 23 | 24 | 25 | class RandomHorizontalFlip(object): 26 | def __init__(self, p=0.5): 27 | self.p = p 28 | 29 | def __call__(self, im): 30 | if np.random.rand() < self.p: 31 | im = im[:, ::-1, :] 32 | return im 33 | 34 | 35 | class Resize(object): 36 | def __init__(self, size): 37 | self.size = size 38 | 39 | def __call__(self, im): 40 | im = cv2.resize(im, self.size) 41 | return im 42 | 43 | 44 | class Normalize(object): 45 | ''' 46 | Inputs are pixel values in range of [0, 255], channel order is 'rgb' 47 | ''' 48 | def __init__(self, mean, std): 49 | self.mean = np.array(mean, np.float32).reshape(1, 1, -1) 50 | self.std = np.array(std, np.float32).reshape(1, 1, -1) 51 | 52 | def __call__(self, im): 53 | if len(im.shape) == 4: 54 | mean, std = self.mean[None, ...], self.std[None, ...] 55 | elif len(im.shape) == 3: 56 | mean, std = self.mean, self.std 57 | im = im.astype(np.float32) / 255.
58 | # im = (im.astype(np.float32) / 127.5) - 1 59 | im -= mean 60 | im /= std 61 | return im 62 | 63 | 64 | class ToTensor(object): 65 | def __init__(self): 66 | pass 67 | 68 | def __call__(self, im): 69 | if len(im.shape) == 4: 70 | return torch.from_numpy(im.transpose(0, 3, 1, 2)) 71 | elif len(im.shape) == 3: 72 | return torch.from_numpy(im.transpose(2, 0, 1)) 73 | 74 | 75 | class Compose(object): 76 | def __init__(self, ops): 77 | self.ops = ops 78 | 79 | def __call__(self, im): 80 | for op in self.ops: 81 | im = op(im) 82 | return im 83 | -------------------------------------------------------------------------------- /imagenet/Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from random import sample 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class Model(nn.Module): 8 | 9 | def __init__(self, base_encoder, args, width): 10 | 11 | super(Model, self).__init__() 12 | 13 | self.K = args.K 14 | 15 | self.encoder = base_encoder(num_class=args.num_class,mlp=True,low_dim=args.low_dim,width=width) 16 | self.m_encoder = base_encoder(num_class=args.num_class,mlp=True,low_dim=args.low_dim,width=width) 17 | 18 | for param, param_m in zip(self.encoder.parameters(), self.m_encoder.parameters()): 19 | param_m.data.copy_(param.data) # initialize 20 | param_m.requires_grad = False # not updated by gradient 21 | 22 | # queue to store momentum features for strong augmentations 23 | self.register_buffer("queue_s", torch.randn(args.low_dim, self.K)) 24 | self.queue_s = F.normalize(self.queue_s, dim=0) 25 | self.register_buffer("queue_ptr_s", torch.zeros(1, dtype=torch.long)) 26 | # queue to store momentum probs for weak augmentations (unlabeled) 27 | self.register_buffer("probs_u", torch.zeros(args.num_class, self.K)) 28 | 29 | # queue (memory bank) to store momentum features and probs for weak augmentations (labeled and unlabeled) 30 | 
self.register_buffer("queue_w",
torch.randn(args.low_dim, self.K)) 31 | self.register_buffer("queue_ptr_w", torch.zeros(1, dtype=torch.long)) 32 | self.register_buffer("probs_xu", torch.zeros(args.num_class, self.K)) 33 | 34 | # for distribution alignment 35 | self.hist_prob = [] 36 | 37 | @torch.no_grad() 38 | def _update_momentum_encoder(self,m): 39 | """ 40 | Update momentum encoder 41 | """ 42 | for param, param_m in zip(self.encoder.parameters(), self.m_encoder.parameters()): 43 | param_m.data = param_m.data * m + param.data * (1. - m) 44 | 45 | @torch.no_grad() 46 | def _dequeue_and_enqueue(self, z, t, ws): 47 | z = concat_all_gather(z) 48 | t = concat_all_gather(t) 49 | 50 | batch_size = z.shape[0] 51 | 52 | if ws=='s': 53 | ptr = int(self.queue_ptr_s) 54 | if (ptr + batch_size) > self.K: 55 | batch_size = self.K-ptr 56 | z = z[:batch_size] 57 | t = t[:batch_size] 58 | # replace the samples at ptr (dequeue and enqueue) 59 | self.queue_s[:, ptr:ptr + batch_size] = z.T 60 | self.probs_u[:, ptr:ptr + batch_size] = t.T 61 | ptr = (ptr + batch_size) % self.K # move pointer 62 | self.queue_ptr_s[0] = ptr 63 | 64 | elif ws=='w': 65 | ptr = int(self.queue_ptr_w) 66 | if (ptr + batch_size) > self.K: 67 | batch_size = self.K-ptr 68 | z = z[:batch_size] 69 | t = t[:batch_size] 70 | # replace the samples at ptr (dequeue and enqueue) 71 | self.queue_w[:, ptr:ptr + batch_size] = z.T 72 | self.probs_xu[:, ptr:ptr + batch_size] = t.T 73 | ptr = (ptr + batch_size) % self.K # move pointer 74 | self.queue_ptr_w[0] = ptr 75 | 76 | @torch.no_grad() 77 | def _batch_shuffle_ddp(self, x): 78 | """ 79 | Batch shuffle, for making use of BatchNorm. 80 | *** Only support DistributedDataParallel (DDP) model. 
*** 81 | """ 82 | # gather from all gpus 83 | batch_size_this = x.shape[0] 84 | x_gather = concat_all_gather(x) 85 | batch_size_all = x_gather.shape[0] 86 | 87 | num_gpus = batch_size_all // batch_size_this 88 | 89 | # random shuffle index 90 | idx_shuffle = torch.randperm(batch_size_all).cuda() 91 | 92 | # broadcast to all gpus 93 | torch.distributed.broadcast(idx_shuffle, src=0) 94 | 95 | # index for restoring 96 | idx_unshuffle = torch.argsort(idx_shuffle) 97 | 98 | # shuffled index for this gpu 99 | gpu_idx = torch.distributed.get_rank() 100 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 101 | 102 | return x_gather[idx_this], idx_unshuffle 103 | 104 | @torch.no_grad() 105 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 106 | """ 107 | Undo batch shuffle. 108 | *** Only support DistributedDataParallel (DDP) model. *** 109 | """ 110 | # gather from all gpus 111 | batch_size_this = x.shape[0] 112 | x_gather = concat_all_gather(x) 113 | batch_size_all = x_gather.shape[0] 114 | 115 | num_gpus = batch_size_all // batch_size_this 116 | 117 | # restored index for this gpu 118 | gpu_idx = torch.distributed.get_rank() 119 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 120 | 121 | return x_gather[idx_this] 122 | 123 | def forward(self, args, labeled_batch, unlabeled_batch=None, is_eval=False, epoch=0): 124 | 125 | img_x = labeled_batch[0].cuda(args.gpu, non_blocking=True) 126 | labels_x = labeled_batch[1].cuda(args.gpu, non_blocking=True) 127 | 128 | if is_eval: 129 | outputs_x, _ = self.encoder(img_x) 130 | return outputs_x, labels_x 131 | 132 | btx = img_x.size(0) 133 | 134 | img_u_w = unlabeled_batch[0][0].cuda(args.gpu, non_blocking=True) 135 | img_u_s0 = unlabeled_batch[0][1].cuda(args.gpu, non_blocking=True) 136 | img_u_s1 = unlabeled_batch[0][2].cuda(args.gpu, non_blocking=True) 137 | 138 | btu = img_u_w.size(0) 139 | 140 | imgs = torch.cat([img_x, img_u_s0], dim=0) 141 | outputs, features = self.encoder(imgs) 142 | 143 | outputs_x = outputs[:btx] 
144 | outputs_u_s0 = outputs[btx:] 145 | features_u_s0 = features[btx:] 146 | 147 | with torch.no_grad(): 148 | self._update_momentum_encoder(args.m) 149 | # forward through the momentum encoder 150 | imgs_m = torch.cat([img_x, img_u_w, img_u_s1], dim=0) 151 | imgs_m, idx_unshuffle = self._batch_shuffle_ddp(imgs_m) 152 | 153 | outputs_m, features_m = self.m_encoder(imgs_m) 154 | outputs_m = self._batch_unshuffle_ddp(outputs_m, idx_unshuffle) 155 | features_m = self._batch_unshuffle_ddp(features_m, idx_unshuffle) 156 | 157 | outputs_u_w = outputs_m[btx:btx+btu] 158 | 159 | feature_u_w = features_m[btx:btx+btu] 160 | feature_xu_w = features_m[:btx+btu] 161 | features_u_s1 = features_m[btx+btu:] 162 | 163 | outputs_u_w = outputs_u_w.detach() 164 | feature_u_w = feature_u_w.detach() 165 | feature_xu_w = feature_xu_w.detach() 166 | features_u_s1 = features_u_s1.detach() 167 | 168 | probs = torch.softmax(outputs_u_w, dim=1) 169 | 170 | # distribution alignment 171 | probs_bt_avg = probs.mean(0) 172 | torch.distributed.all_reduce(probs_bt_avg,async_op=False) 173 | self.hist_prob.append(probs_bt_avg/args.world_size) 174 | 175 | if len(self.hist_prob)>128: 176 | self.hist_prob.pop(0) 177 | 178 | probs_avg = torch.stack(self.hist_prob,dim=0).mean(0) 179 | probs = probs / probs_avg 180 | probs = probs / probs.sum(dim=1, keepdim=True) 181 | probs_orig = probs.clone() 182 | 183 | # memory-smoothed pseudo-label refinement (starting from 2nd epoch) 184 | if epoch>0: 185 | m_feat_xu = self.queue_w.clone().detach() 186 | m_probs_xu = self.probs_xu.clone().detach() 187 | A = torch.exp(torch.mm(feature_u_w, m_feat_xu)/args.temperature) 188 | A = A/A.sum(1,keepdim=True) 189 | probs = args.alpha*probs + (1-args.alpha)*torch.mm(A, m_probs_xu.t()) 190 | 191 | # construct pseudo-label graph 192 | 193 | # similarity with current batch 194 | Q_self = torch.mm(probs,probs.t()) 195 | Q_self.fill_diagonal_(1) 196 | 197 | # similarity with past samples 198 | m_probs_u = 
self.probs_u.clone().detach() 199 | Q_past = torch.mm(probs,m_probs_u) 200 | 201 | # concatenate them 202 | Q = torch.cat([Q_self,Q_past],dim=1) 203 | 204 | # construct embedding graph for strong augmentations 205 | sim_self = torch.exp(torch.mm(features_u_s0, features_u_s1.t())/args.temperature) 206 | m_feat = self.queue_s.clone().detach() 207 | sim_past = torch.exp(torch.mm(features_u_s0, m_feat)/args.temperature) 208 | sim = torch.cat([sim_self,sim_past],dim=1) 209 | 210 | # store strong augmentation features and probs (unlabeled) into momentum queue 211 | self._dequeue_and_enqueue(features_u_s1, probs, 's') 212 | 213 | # store weak augmentation features and probs (labeled and unlabeled) into memory bank 214 | onehot = torch.zeros(btx,args.num_class).cuda().scatter(1,labels_x.view(-1,1),1) 215 | probs_xu = torch.cat([onehot, probs_orig],dim=0) 216 | 217 | self._dequeue_and_enqueue(feature_xu_w, probs_xu, 'w') 218 | 219 | return outputs_x, outputs_u_s0, labels_x, probs, Q, sim 220 | 221 | 222 | 223 | 224 | @torch.no_grad() 225 | def concat_all_gather(tensor): 226 | """ 227 | Performs all_gather operation on the provided tensors. 228 | *** Warning ***: torch.distributed.all_gather has no gradient. 229 | """ 230 | tensors_gather = [torch.ones_like(tensor) 231 | for _ in range(torch.distributed.get_world_size())] 232 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 233 | 234 | output = torch.cat(tensors_gather, dim=0) 235 | return output 236 | 237 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | ### Semi-supervised learning on ImageNet with 1% or 10% labels: 2 | 3 | This implementation only supports multi-gpu, DistributedDataParallel training, which is faster and simpler. 4 | 5 | To train CoMatch with 1% labels on 8 gpus, run: 6 |
python Train_CoMatch.py --percent 1 --thr 0.6 --contrast-th 0.3 --lam-c 10 [Imagenet dataset folder] 7 | 8 | To train CoMatch with 10% labels on 8 gpus, run: 9 |
python Train_CoMatch.py --percent 10 --thr 0.5 --contrast-th 0.2 --lam-c 2 [Imagenet dataset folder]10 | 11 | ### Semi-supervised learning results with CoMatch 12 | 13 | num. labels | top-1 acc. | top-5 acc 14 | --- | --- | --- 15 | 1% | 66.0% | 86.4% 16 | 10% | 73.6% | 91.6% 17 | 18 | ### Download pre-trained ResNet-50 models 19 | 20 | num. labels | model 21 | ------ | ------ 22 | 1% | download 23 | 10% | download 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /imagenet/Train_CoMatch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2018, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | ''' 7 | import argparse 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | import warnings 13 | import builtins 14 | import json 15 | from datetime import datetime 16 | import sys 17 | import math 18 | import numpy as np 19 | 20 | import tensorboard_logger 21 | import logging 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.parallel 26 | import torch.backends.cudnn as cudnn 27 | import torch.distributed as dist 28 | import torch.optim 29 | import torch.multiprocessing as mp 30 | import torch.utils.data 31 | import torch.utils.data.distributed 32 | import torchvision.transforms as transforms 33 | import torchvision.datasets as datasets 34 | import torch.nn.functional as F 35 | 36 | import loader 37 | from Model import Model 38 | from resnet import * 39 | 40 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 41 | parser.add_argument('data', metavar='DIR', default='', help='path to dataset') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=['resnet50','resnet50x2','resnet50x4']) 43 | 
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 44 | help='number of data loading workers (default: 32)') 45 | parser.add_argument('--epochs', default=400, type=int, metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 48 | help='manual epoch number (useful on restarts)') 49 | parser.add_argument('--batch-size', default=160, type=int, 50 | help='supervised batch size') 51 | parser.add_argument('--batch-size-u', default=640, type=int, 52 | help='unsupervised batch size') 53 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 54 | metavar='LR', help='initial learning rate', dest='lr') 55 | parser.add_argument('--cos', default=True, help='use cosine lr schedule') 56 | parser.add_argument('--schedule', default=[], nargs='*', type=int, 57 | help='learning rate schedule (when to drop lr by a ratio)') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)', 62 | dest='weight_decay') 63 | parser.add_argument('-p', '--print-freq', default=50, type=int, 64 | metavar='N', help='print frequency (default: 50)') 65 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 66 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH', help='path to pretrained model (default: none)') 67 | parser.add_argument('--world-size', default=1, type=int, 68 | help='number of nodes for distributed training') 69 | parser.add_argument('--rank', default=0, type=int, 70 | help='node rank for distributed training') 71 | parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str, 72 | help='url used to set up distributed training') 73 | parser.add_argument('--dist-backend', default='nccl', type=str, 74 |
help='distributed backend') 75 | parser.add_argument('--seed', default=0, type=int, 76 | help='seed for initializing training. ') 77 | parser.add_argument('--gpu', default=None, type=int, 78 | help='GPU id to use.') 79 | parser.add_argument('--multiprocessing-distributed', default=True) 80 | 81 | ## CoMatch settings 82 | parser.add_argument('--temperature', default=0.1, type=float, help='temperature for similarity scaling') 83 | parser.add_argument('--low-dim', default=128, type=int, help='feature dimension') 84 | parser.add_argument('--moco-m', default=0.996, type=float, 85 | help='momentum of updating momentum encoder') 86 | parser.add_argument('--K', default=30000, type=int, help='size of memory bank and momentum queue') 87 | parser.add_argument('--thr', default=0.6, type=float, help='pseudo-label confidence threshold') 88 | parser.add_argument('--contrast-th', default=0.3, type=float, help='pseudo-label graph connection threshold') 89 | parser.add_argument('--lam-u', default=10, type=float, help='weight for unsupervised cross-entropy loss') 90 | parser.add_argument('--lam-c', default=10, type=float, help='weight for unsupervised contrastive loss') 91 | parser.add_argument('--alpha', default=0.9, type=float, help='weight for model prediction in constructing pseudo-label') 92 | parser.add_argument('--exp_dir', default='experiment/comatch_1percent', type=str, help='experiment directory') 93 | 94 | ## dataset settings 95 | parser.add_argument('--percent', type=int, default=1, choices=[1,10], help='percentage of labeled samples') 96 | parser.add_argument('--num-class', default=1000, type=int) 97 | parser.add_argument('--annotation', default='annotation_1percent.json', type=str, help='annotation file') 98 | 99 | 100 | def main(): 101 | args = parser.parse_args() 102 | 103 | if args.seed is not None: 104 | random.seed(args.seed) 105 | torch.manual_seed(args.seed) 106 | 107 | if args.dist_url == "env://" and args.world_size == -1: 108 | args.world_size = 
int(os.environ["WORLD_SIZE"]) 109 | 110 | os.makedirs(args.exp_dir, exist_ok=True) 111 | 112 | ngpus_per_node = torch.cuda.device_count() 113 | if args.multiprocessing_distributed: 114 | args.world_size = ngpus_per_node * args.world_size 115 | # Use torch.multiprocessing.spawn to launch distributed processes: the 116 | # main_worker process function 117 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 118 | else: 119 | # Simply call main_worker function 120 | main_worker(args.gpu, ngpus_per_node, args) 121 | 122 | 123 | def main_worker(gpu, ngpus_per_node, args): 124 | args.gpu = gpu 125 | 126 | if args.gpu is not None: 127 | print("Use GPU: {} for training".format(args.gpu)) 128 | 129 | if args.multiprocessing_distributed and args.gpu != 0: 130 | def print_pass(*args): 131 | pass 132 | builtins.print = print_pass 133 | 134 | if args.dist_url == "env://" and args.rank == -1: 135 | args.rank = int(os.environ["RANK"]) 136 | if args.multiprocessing_distributed: 137 | # For multiprocessing distributed training, rank needs to be the 138 | # global rank among all the processes 139 | args.rank = args.rank * ngpus_per_node + gpu 140 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 141 | world_size=args.world_size, rank=args.rank) 142 | 143 | # create model 144 | print("=> creating model '{}'".format(args.arch)) 145 | if args.arch == 'resnet50': 146 | model = Model(resnet50,args,width=1) 147 | elif args.arch == 'resnet50x2': 148 | model = Model(resnet50,args,width=2) 149 | elif args.arch == 'resnet50x4': 150 | model = Model(resnet50,args,width=4) 151 | else: 152 | raise NotImplementedError('model not supported {}'.format(args.arch)) 153 | 154 | # load moco-v2 pretrained model 155 | if args.pretrained: 156 | if os.path.isfile(args.pretrained): 157 | print("=> loading checkpoint '{}'".format(args.pretrained)) 158 | checkpoint = torch.load(args.pretrained, map_location="cpu") 159 | state_dict = checkpoint['state_dict'] 
160 | for k in list(state_dict.keys()): 161 | if k.startswith('module.encoder_q'): 162 | # remove prefix 163 | state_dict[k.replace('module.encoder_q', 'encoder')] = state_dict[k] 164 | # delete renamed or unused k 165 | del state_dict[k] 166 | for k in list(state_dict.keys()): 167 | if 'fc.0' in k: 168 | state_dict[k.replace('fc.0','fc1')] = state_dict[k] 169 | if 'fc.2' in k: 170 | state_dict[k.replace('fc.2','fc2')] = state_dict[k] 171 | del state_dict[k] 172 | args.start_epoch = 0 173 | msg = model.load_state_dict(state_dict, strict=False) 174 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 175 | # copy parameters to the momentum encoder 176 | for param, param_m in zip(model.encoder.parameters(), model.m_encoder.parameters()): 177 | param_m.data.copy_(param.data) 178 | param_m.requires_grad = False 179 | else: 180 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 181 | 182 | 183 | if args.gpu is not None: 184 | torch.cuda.set_device(args.gpu) 185 | model.cuda(args.gpu) 186 | # When using a single GPU per process and per 187 | # DistributedDataParallel, we need to divide the batch size 188 | # ourselves based on the total number of GPUs we have 189 | args.batch_size = int(args.batch_size / ngpus_per_node) 190 | args.batch_size_u = int(args.batch_size_u / ngpus_per_node) 191 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) #find_unused_parameters=True 193 | else: 194 | model.cuda() 195 | # DistributedDataParallel will divide and allocate batch_size to all 196 | # available GPUs if device_ids are not set 197 | model = torch.nn.parallel.DistributedDataParallel(model) 198 | 199 | # define loss function (criterion) and optimizer 200 | criteria_x = nn.CrossEntropyLoss().cuda(args.gpu) 201 | 202 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 203 | momentum=args.momentum, 204 | weight_decay=args.weight_decay, 205 |
nesterov=True 206 | ) 207 | 208 | # optionally resume from a checkpoint 209 | if args.resume: 210 | if os.path.isfile(args.resume): 211 | print("=> loading checkpoint '{}'".format(args.resume)) 212 | if args.gpu is None: 213 | checkpoint = torch.load(args.resume) 214 | else: 215 | # Map model to be loaded to specified single gpu. 216 | loc = 'cuda:{}'.format(args.gpu) 217 | checkpoint = torch.load(args.resume, map_location=loc) 218 | args.start_epoch = checkpoint['epoch'] 219 | model.load_state_dict(checkpoint['state_dict']) 220 | optimizer.load_state_dict(checkpoint['optimizer']) 221 | print("=> loaded checkpoint '{}' (epoch {})" 222 | .format(args.resume, checkpoint['epoch'])) 223 | else: 224 | print("=> no checkpoint found at '{}'".format(args.resume)) 225 | 226 | cudnn.benchmark = True 227 | 228 | print("=> preparing dataset") 229 | # Data loading code 230 | 231 | transform_strong = transforms.Compose([ 232 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 233 | transforms.RandomApply([ 234 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 235 | ], p=0.8), 236 | transforms.RandomGrayscale(p=0.2), 237 | transforms.RandomHorizontalFlip(), 238 | transforms.ToTensor(), 239 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 240 | ]) 241 | transform_weak = transforms.Compose([ 242 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 243 | transforms.RandomHorizontalFlip(), 244 | transforms.ToTensor(), 245 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 246 | ]) 247 | transform_eval = transforms.Compose([ 248 | transforms.Resize(256), 249 | transforms.CenterCrop(224), 250 | transforms.ToTensor(), 251 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 252 | ]) 253 | 254 | three_crops_transform = loader.ThreeCropsTransform(transform_weak, transform_strong, transform_strong) 255 | 256 | traindir = os.path.join(args.data, 'train') 257 | valdir = os.path.join(args.data, 'val') 258 | 259 | 
labeled_dataset = datasets.ImageFolder(traindir,transform_weak) 260 | unlabeled_dataset = datasets.ImageFolder(traindir,three_crops_transform) 261 | 262 | if not os.path.exists(args.annotation): 263 | # randomly sample labeled data on main process (gpu0) 264 | label_per_class = 13 if args.percent==1 else 128 265 | if args.gpu==0: 266 | random.shuffle(labeled_dataset.samples) 267 | labeled_samples=[] 268 | unlabeled_samples=[] 269 | num_img = torch.zeros(args.num_class) 270 | 271 | for i,(img,label) in enumerate(labeled_dataset.samples): 272 | if num_img[label]