├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── checkpoints
│   ├── params.pkl
│   ├── stats0.pkl
│   └── train.log
├── checkpoints_linear
│   ├── params.pkl
│   ├── stats0.pkl
│   └── train.log
├── eval_linear.py
├── eval_semisup.py
├── hubconf.py
├── main_swav.py
├── run.sh
└── src
    ├── __init__.py
    ├── logger.py
    ├── multicropdataset.py
    ├── resnet50.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/ 2 | experiments 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | In the context of this project, we do not expect pull requests. 4 | If you find a bug, or would like to suggest an improvement, please open an issue. 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. 
The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. 
indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. 
The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. 
Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwAV (CIFAR-10) 2 | This code is a modified version of SwAV ([code](https://github.com/facebookresearch/swav), [paper](https://arxiv.org/abs/2006.09882)) adapted for CIFAR-10. 3 | 4 | As mentioned in the original [README](https://github.com/facebookresearch/swav#readme), the loss sometimes gets stuck at ln(nmb_prototypes). This repository successfully avoids that through architecture changes and hyperparameter tuning. 5 | 6 | 7 | Specifically: 8 | - The ResNet-50 architecture has been modified to suit the 32x32 images of CIFAR-10: the kernel size and stride of the conv1 block have been changed to 3 and 1, respectively, and the max-pooling operation after the conv1 block has been removed (see the sketch below). 9 | - The hyperparameters have been tuned for CIFAR-10. 10 | - Multi-crop, the queue, and distributed training are disabled; uncommenting a few lines would re-enable distributed training. 11 | 12 | Please refer to the [run.sh](./run.sh) script for the hyperparameters used. The file paths will need to be modified accordingly. 13 | 14 | For a detailed README and more information about SwAV, please refer to the excellent [README](https://github.com/facebookresearch/swav#readme) by the authors. 
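15 | 
16 | For reference, the stem change in the first bullet amounts to roughly the following (a minimal sketch against torchvision's stock ResNet-50; the repository itself bakes the change into its own model definition in [src/resnet50.py](./src/resnet50.py)):
17 | 
18 | ```python
19 | import torch.nn as nn
20 | from torchvision.models import resnet50
21 | 
22 | model = resnet50()
23 | # A 7x7/stride-2 stem downsamples 32x32 inputs too aggressively; use 3x3/stride-1.
24 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
25 | # Drop the stem max-pool so early feature maps keep their spatial resolution.
26 | model.maxpool = nn.Identity()
27 | ```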
-------------------------------------------------------------------------------- /checkpoints/params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhinavagarwalla/swav-cifar10/4369b58aff2dac7b9b1e40d53af2f2eac9be9481/checkpoints/params.pkl -------------------------------------------------------------------------------- /checkpoints/stats0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhinavagarwalla/swav-cifar10/4369b58aff2dac7b9b1e40d53af2f2eac9be9481/checkpoints/stats0.pkl -------------------------------------------------------------------------------- /checkpoints_linear/params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhinavagarwalla/swav-cifar10/4369b58aff2dac7b9b1e40d53af2f2eac9be9481/checkpoints_linear/params.pkl -------------------------------------------------------------------------------- /checkpoints_linear/stats0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhinavagarwalla/swav-cifar10/4369b58aff2dac7b9b1e40d53af2f2eac9be9481/checkpoints_linear/stats0.pkl -------------------------------------------------------------------------------- /eval_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import os 10 | import time 11 | from logging import getLogger 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.optim 19 | import torch.utils.data as data 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | 23 | from src.utils import ( 24 | bool_flag, 25 | initialize_exp, 26 | restart_from_checkpoint, 27 | fix_random_seeds, 28 | AverageMeter, 29 | init_distributed_mode, 30 | accuracy, 31 | ) 32 | 33 | import src.resnet50 as resnet_models 34 | 35 | logger = getLogger() 36 | 37 | 38 | parser = argparse.ArgumentParser(description="Evaluate models: Linear classification on CIFAR-10") 39 | 40 | ######################### 41 | #### main parameters #### 42 | ######################### 43 | parser.add_argument("--dump_path", type=str, default=".", 44 | help="experiment dump path for checkpoints and log") 45 | parser.add_argument("--seed", type=int, default=31, help="seed") 46 | parser.add_argument("--data_path", type=str, default="/path/to/cifar10", 47 | help="path to dataset repository") 48 | parser.add_argument("--workers", default=10, type=int, 49 | help="number of data loading workers") 50 | 51 | ######################### 52 | #### model parameters ### 53 | ######################### 54 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 55 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained weights") 56 | parser.add_argument("--global_pooling", default=True, type=bool_flag, 57 | help="if True, we use the resnet50 global average pooling") 58 | parser.add_argument("--use_bn", default=False, type=bool_flag, 59 | help="optionally add a batchnorm layer 
before the linear classifier") 60 | 61 | ######################### 62 | #### optim parameters ### 63 | ######################### 64 | parser.add_argument("--epochs", default=100, type=int, 65 | help="number of total epochs to run") 66 | parser.add_argument("--batch_size", default=32, type=int, 67 | help="batch size per gpu, i.e. how many unique instances per gpu") 68 | parser.add_argument("--lr", default=0.3, type=float, help="initial learning rate") 69 | parser.add_argument("--wd", default=1e-6, type=float, help="weight decay") 70 | parser.add_argument("--nesterov", default=False, type=bool_flag, help="nesterov momentum") 71 | parser.add_argument("--scheduler_type", default="cosine", type=str, choices=["step", "cosine"]) 72 | # for multi-step learning rate decay 73 | parser.add_argument("--decay_epochs", type=int, nargs="+", default=[60, 80], 74 | help="Epochs at which to decay learning rate.") 75 | parser.add_argument("--gamma", type=float, default=0.1, help="decay factor") 76 | # for cosine learning rate schedule 77 | parser.add_argument("--final_lr", type=float, default=0, help="final learning rate") 78 | 79 | ######################### 80 | #### dist parameters ### 81 | ######################### 82 | parser.add_argument("--dist_url", default="env://", type=str, 83 | help="url used to set up distributed training") 84 | parser.add_argument("--world_size", default=-1, type=int, help=""" 85 | number of processes: it is set automatically and 86 | should not be passed as argument""") 87 | parser.add_argument("--rank", default=0, type=int, help="""rank of this process: 88 | it is set automatically and should not be passed as argument""") 89 | parser.add_argument("--local_rank", default=0, type=int, 90 | help="this argument is not used and should be ignored") 91 | 92 | 93 | def main(): 94 | global args, best_acc 95 | args = parser.parse_args() 96 | # init_distributed_mode(args) 97 | fix_random_seeds(args.seed) 98 | logger, training_stats = initialize_exp( 99 | args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val" 100 | ) 101 | 102 | os.environ["CUDA_VISIBLE_DEVICES"] = str('0') 103 | 104 | # build data 105 | # train_dataset = datasets.ImageFolder(os.path.join(args.data_path, "train")) 106 | # val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val")) 107 | train_dataset = datasets.CIFAR10(args.data_path, train=True) 108 | val_dataset = datasets.CIFAR10(args.data_path, train=False) 109 | tr_normalize = transforms.Normalize( 110 | mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225] 111 | # The CIFAR-10 mean below might lead to better results 112 | # mean=[0.491, 0.482, 0.446], std=[0.247, 0.243, 0.262] 113 | ) 114 | train_dataset.transform = transforms.Compose([ 115 | transforms.RandomResizedCrop(32), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.ToTensor(), 118 | tr_normalize, 119 | ]) 120 | val_dataset.transform = transforms.Compose([ 121 | transforms.Resize(32), 122 | transforms.CenterCrop(32), 123 | transforms.ToTensor(), 124 | tr_normalize, 125 | ]) 126 | # sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 127 | train_loader = torch.utils.data.DataLoader( 128 | train_dataset, 129 | # sampler=sampler, 130 | batch_size=args.batch_size, 131 | num_workers=args.workers, 132 | pin_memory=True, 133 | ) 134 | val_loader = torch.utils.data.DataLoader( 135 | val_dataset, 136 | batch_size=args.batch_size, 137 | num_workers=args.workers, 138 | pin_memory=True, 139 | ) 140 | logger.info("Building data done") 141 | 142 | # build 
model 143 | model = resnet_models.__dict__[args.arch](output_dim=0, eval_mode=True) 144 | linear_classifier = RegLog(10, args.arch, args.global_pooling, args.use_bn)  # CIFAR-10 has 10 classes 145 | 146 | # convert batch norm layers (if any) 147 | # linear_classifier = nn.SyncBatchNorm.convert_sync_batchnorm(linear_classifier) 148 | 149 | # model to gpu 150 | model = model.cuda() 151 | linear_classifier = linear_classifier.cuda() 152 | # linear_classifier = nn.parallel.DistributedDataParallel( 153 | # linear_classifier, 154 | # device_ids=[args.gpu_to_work_on], 155 | # find_unused_parameters=True, 156 | # ) 157 | model.eval() 158 | 159 | # load weights 160 | if os.path.isfile(args.pretrained): 161 | state_dict = torch.load(args.pretrained, map_location="cuda:0") 162 | if "state_dict" in state_dict: 163 | state_dict = state_dict["state_dict"] 164 | # remove prefix "module." 165 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 166 | for k, v in model.state_dict().items(): 167 | if k not in list(state_dict): 168 | logger.info('key "{}" could not be found in provided state dict'.format(k)) 169 | elif state_dict[k].shape != v.shape: 170 | logger.info('key "{}" is of different shape in model and provided state dict'.format(k)) 171 | state_dict[k] = v 172 | msg = model.load_state_dict(state_dict, strict=False) 173 | logger.info("Load pretrained model with msg: {}".format(msg)) 174 | else: 175 | logger.info("No pretrained weights found => training with random weights") 176 | 177 | # set optimizer 178 | optimizer = torch.optim.SGD( 179 | linear_classifier.parameters(), 180 | lr=args.lr, 181 | nesterov=args.nesterov, 182 | momentum=0.9, 183 | weight_decay=args.wd, 184 | ) 185 | 186 | # set scheduler 187 | if args.scheduler_type == "step": 188 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 189 | optimizer, args.decay_epochs, gamma=args.gamma 190 | ) 191 | elif args.scheduler_type == "cosine": 192 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 193 | optimizer, args.epochs, eta_min=args.final_lr 194 | ) 195 | 196 | # Optionally resume from a checkpoint 197 | to_restore = {"epoch": 0, "best_acc": 0.} 198 | restart_from_checkpoint( 199 | os.path.join(args.dump_path, "checkpoint.pth.tar"), 200 | run_variables=to_restore, 201 | state_dict=linear_classifier, 202 | optimizer=optimizer, 203 | scheduler=scheduler, 204 | ) 205 | start_epoch = to_restore["epoch"] 206 | best_acc = to_restore["best_acc"] 207 | cudnn.benchmark = True 208 | 209 | for epoch in range(start_epoch, args.epochs): 210 | 211 | # train the network for one epoch 212 | logger.info("============ Starting epoch %i ... 
============" % epoch) 213 | 214 | # set samplers 215 | # train_loader.sampler.set_epoch(epoch) 216 | 217 | scores = train(model, linear_classifier, optimizer, train_loader, epoch) 218 | scores_val = validate_network(val_loader, model, linear_classifier) 219 | training_stats.update(scores + scores_val) 220 | 221 | scheduler.step() 222 | 223 | # save checkpoint 224 | if args.rank == 0: 225 | save_dict = { 226 | "epoch": epoch + 1, 227 | "state_dict": linear_classifier.state_dict(), 228 | "optimizer": optimizer.state_dict(), 229 | "scheduler": scheduler.state_dict(), 230 | "best_acc": best_acc, 231 | } 232 | torch.save(save_dict, os.path.join(args.dump_path, "checkpoint.pth.tar")) 233 | logger.info("Training of the supervised linear classifier on frozen features completed.\n" 234 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) 235 | 236 | 237 | class RegLog(nn.Module): 238 | """Creates logistic regression on top of frozen features""" 239 | 240 | def __init__(self, num_labels, arch="resnet50", global_avg=False, use_bn=True): 241 | super(RegLog, self).__init__() 242 | self.bn = None 243 | if global_avg: 244 | if arch == "resnet50": 245 | s = 2048 246 | elif arch == "resnet50w2": 247 | s = 4096 248 | elif arch == "resnet50w4": 249 | s = 8192 250 | self.av_pool = nn.AdaptiveAvgPool2d((1, 1)) 251 | else: 252 | assert arch == "resnet50" 253 | s = 8192 254 | self.av_pool = nn.AvgPool2d(6, stride=1) 255 | if use_bn: 256 | self.bn = nn.BatchNorm2d(2048) 257 | self.linear = nn.Linear(s, num_labels) 258 | self.linear.weight.data.normal_(mean=0.0, std=0.01) 259 | self.linear.bias.data.zero_() 260 | 261 | def forward(self, x): 262 | # average pool the final feature map 263 | x = self.av_pool(x) 264 | 265 | # optional BN 266 | if self.bn is not None: 267 | x = self.bn(x) 268 | 269 | # flatten 270 | x = x.view(x.size(0), -1) 271 | 272 | # linear layer 273 | return self.linear(x) 274 | 275 | 276 | def train(model, reglog, optimizer, loader, epoch): 277 | """ 278 | Train the models on the dataset. 
279 | """ 280 | # running statistics 281 | batch_time = AverageMeter() 282 | data_time = AverageMeter() 283 | 284 | # training statistics 285 | top1 = AverageMeter() 286 | top5 = AverageMeter() 287 | losses = AverageMeter() 288 | end = time.perf_counter() 289 | 290 | model.eval() 291 | reglog.train() 292 | criterion = nn.CrossEntropyLoss().cuda() 293 | 294 | for iter_epoch, (inp, target) in enumerate(loader): 295 | # measure data loading time 296 | data_time.update(time.perf_counter() - end) 297 | 298 | # move to gpu 299 | inp = inp.cuda(non_blocking=True) 300 | target = target.cuda(non_blocking=True) 301 | 302 | # forward 303 | with torch.no_grad(): 304 | output = model(inp) 305 | output = reglog(output) 306 | 307 | # compute cross entropy loss 308 | loss = criterion(output, target) 309 | 310 | # compute the gradients 311 | optimizer.zero_grad() 312 | loss.backward() 313 | 314 | # step 315 | optimizer.step() 316 | 317 | # update stats 318 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 319 | losses.update(loss.item(), inp.size(0)) 320 | top1.update(acc1[0], inp.size(0)) 321 | top5.update(acc5[0], inp.size(0)) 322 | 323 | batch_time.update(time.perf_counter() - end) 324 | end = time.perf_counter() 325 | 326 | # verbose 327 | if args.rank == 0 and iter_epoch % 50 == 0: 328 | logger.info( 329 | "Epoch[{0}] - Iter: [{1}/{2}]\t" 330 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 331 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 332 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 333 | "Prec {top1.val:.3f} ({top1.avg:.3f})\t" 334 | "LR {lr}".format( 335 | epoch, 336 | iter_epoch, 337 | len(loader), 338 | batch_time=batch_time, 339 | data_time=data_time, 340 | loss=losses, 341 | top1=top1, 342 | lr=optimizer.param_groups[0]["lr"], 343 | ) 344 | ) 345 | 346 | return epoch, losses.avg, top1.avg.item(), top5.avg.item() 347 | 348 | 349 | def validate_network(val_loader, model, linear_classifier): 350 | batch_time = AverageMeter() 351 | losses = AverageMeter() 352 | top1 = AverageMeter() 353 | top5 = AverageMeter() 354 | global best_acc 355 | 356 | # switch to evaluate mode 357 | model.eval() 358 | linear_classifier.eval() 359 | 360 | criterion = nn.CrossEntropyLoss().cuda() 361 | 362 | with torch.no_grad(): 363 | end = time.perf_counter() 364 | for i, (inp, target) in enumerate(val_loader): 365 | 366 | # move to gpu 367 | inp = inp.cuda(non_blocking=True) 368 | target = target.cuda(non_blocking=True) 369 | 370 | # compute output 371 | output = linear_classifier(model(inp)) 372 | loss = criterion(output, target) 373 | 374 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 375 | losses.update(loss.item(), inp.size(0)) 376 | top1.update(acc1[0], inp.size(0)) 377 | top5.update(acc5[0], inp.size(0)) 378 | 379 | # measure elapsed time 380 | batch_time.update(time.perf_counter() - end) 381 | end = time.perf_counter() 382 | 383 | if top1.avg.item() > best_acc: 384 | best_acc = top1.avg.item() 385 | 386 | if args.rank == 0: 387 | logger.info( 388 | "Test:\t" 389 | "Time {batch_time.avg:.3f}\t" 390 | "Loss {loss.avg:.4f}\t" 391 | "Acc@1 {top1.avg:.3f}\t" 392 | "Best Acc@1 so far {acc:.1f}".format( 393 | batch_time=batch_time, loss=losses, top1=top1, acc=best_acc)) 394 | 395 | return losses.avg, top1.avg.item(), top5.avg.item() 396 | 397 | 398 | if __name__ == "__main__": 399 | main() 400 | -------------------------------------------------------------------------------- /eval_semisup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import os 10 | import time 11 | from logging import getLogger 12 | import urllib 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | import torch.utils.data as data 21 | import torchvision.transforms as transforms 22 | import torchvision.datasets as datasets 23 | 24 | from src.utils import ( 25 | bool_flag, 26 | initialize_exp, 27 | restart_from_checkpoint, 28 | fix_random_seeds, 29 | AverageMeter, 30 | init_distributed_mode, 31 | accuracy, 32 | ) 33 | import src.resnet50 as resnet_models 34 | 35 | logger = getLogger() 36 | 37 | 38 | parser = argparse.ArgumentParser(description="Evaluate models: Fine-tuning with 1% or 10% labels on ImageNet") 39 | 40 | ######################### 41 | #### main parameters #### 42 | ######################### 43 | parser.add_argument("--labels_perc", type=str, default="10", choices=["1", "10"], 44 | help="fine-tune on either 1% or 10% of labels") 45 | parser.add_argument("--dump_path", type=str, default=".", 46 | help="experiment dump path for checkpoints and log") 47 | parser.add_argument("--seed", type=int, default=31, help="seed") 48 | parser.add_argument("--data_path", type=str, default="/path/to/imagenet", 49 | help="path to imagenet") 50 | parser.add_argument("--workers", default=10, type=int, 51 | help="number of data loading workers") 52 | 53 | ######################### 54 | #### model parameters ### 55 | ######################### 56 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 57 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained weights") 58 | 59 | ######################### 60 | #### optim parameters ### 61 | ######################### 62 | parser.add_argument("--epochs", default=20, type=int, 63 | help="number of total epochs to run") 64 | parser.add_argument("--batch_size", default=32, type=int, 65 | help="batch size per gpu, i.e. 
how many unique instances per gpu") 66 | parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate - trunk") 67 | parser.add_argument("--lr_last_layer", default=0.2, type=float, help="initial learning rate - head") 68 | parser.add_argument("--decay_epochs", type=int, nargs="+", default=[12, 16], 69 | help="Epochs at which to decay learning rate.") 70 | parser.add_argument("--gamma", type=float, default=0.2, help="lr decay factor") 71 | 72 | ######################### 73 | #### dist parameters ### 74 | ######################### 75 | parser.add_argument("--dist_url", default="env://", type=str, 76 | help="url used to set up distributed training") 77 | parser.add_argument("--world_size", default=-1, type=int, help=""" 78 | number of processes: it is set automatically and 79 | should not be passed as argument""") 80 | parser.add_argument("--rank", default=0, type=int, help="""rank of this process: 81 | it is set automatically and should not be passed as argument""") 82 | parser.add_argument("--local_rank", default=0, type=int, 83 | help="this argument is not used and should be ignored") 84 | 85 | 86 | def main(): 87 | global args, best_acc 88 | args = parser.parse_args() 89 | init_distributed_mode(args) 90 | fix_random_seeds(args.seed) 91 | logger, training_stats = initialize_exp( 92 | args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val" 93 | ) 94 | 95 | # build data 96 | train_data_path = os.path.join(args.data_path, "train") 97 | train_dataset = datasets.ImageFolder(train_data_path) 98 | # take either 1% or 10% of images 99 | subset_file = urllib.request.urlopen("https://raw.githubusercontent.com/google-research/simclr/master/imagenet_subsets/" + str(args.labels_perc) + "percent.txt") 100 | list_imgs = [li.decode("utf-8").split('\n')[0] for li in subset_file] 101 | train_dataset.samples = [( 102 | os.path.join(train_data_path, li.split('_')[0], li), 103 | train_dataset.class_to_idx[li.split('_')[0]] 104 | ) for li in list_imgs] 105 | val_dataset = datasets.ImageFolder(os.path.join(args.data_path, "val")) 106 | tr_normalize = transforms.Normalize( 107 | mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225] 108 | ) 109 | train_dataset.transform = transforms.Compose([ 110 | transforms.RandomResizedCrop(224), 111 | transforms.RandomHorizontalFlip(), 112 | transforms.ToTensor(), 113 | tr_normalize, 114 | ]) 115 | val_dataset.transform = transforms.Compose([ 116 | transforms.Resize(256), 117 | transforms.CenterCrop(224), 118 | transforms.ToTensor(), 119 | tr_normalize, 120 | ]) 121 | sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 122 | train_loader = torch.utils.data.DataLoader( 123 | train_dataset, 124 | sampler=sampler, 125 | batch_size=args.batch_size, 126 | num_workers=args.workers, 127 | pin_memory=True, 128 | ) 129 | val_loader = torch.utils.data.DataLoader( 130 | val_dataset, 131 | batch_size=args.batch_size, 132 | num_workers=args.workers, 133 | pin_memory=True, 134 | ) 135 | logger.info("Building data done with {} images loaded.".format(len(train_dataset))) 136 | 137 | # build model 138 | model = resnet_models.__dict__[args.arch](output_dim=1000) 139 | 140 | # convert batch norm layers 141 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 142 | 143 | # load weights 144 | if os.path.isfile(args.pretrained): 145 | state_dict = torch.load(args.pretrained, map_location="cuda:" + str(args.gpu_to_work_on)) 146 | if "state_dict" in state_dict: 147 | state_dict = state_dict["state_dict"] 148 | # remove prefix 
"module." 149 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 150 | for k, v in model.state_dict().items(): 151 | if k not in list(state_dict): 152 | logger.info('key "{}" could not be found in provided state dict'.format(k)) 153 | elif state_dict[k].shape != v.shape: 154 | logger.info('key "{}" is of different shape in model and provided state dict'.format(k)) 155 | state_dict[k] = v 156 | msg = model.load_state_dict(state_dict, strict=False) 157 | logger.info("Load pretrained model with msg: {}".format(msg)) 158 | else: 159 | logger.info("No pretrained weights found => training from random weights") 160 | 161 | # model to gpu 162 | model = model.cuda() 163 | model = nn.parallel.DistributedDataParallel( 164 | model, 165 | device_ids=[args.gpu_to_work_on], 166 | find_unused_parameters=True, 167 | ) 168 | 169 | # set optimizer 170 | trunk_parameters = [] 171 | head_parameters = [] 172 | for name, param in model.named_parameters(): 173 | if 'head' in name: 174 | head_parameters.append(param) 175 | else: 176 | trunk_parameters.append(param) 177 | optimizer = torch.optim.SGD( 178 | [{'params': trunk_parameters}, 179 | {'params': head_parameters, 'lr': args.lr_last_layer}], 180 | lr=args.lr, 181 | momentum=0.9, 182 | weight_decay=0, 183 | ) 184 | # set scheduler 185 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 186 | optimizer, args.decay_epochs, gamma=args.gamma 187 | ) 188 | 189 | # Optionally resume from a checkpoint 190 | to_restore = {"epoch": 0, "best_acc": (0., 0.)} 191 | restart_from_checkpoint( 192 | os.path.join(args.dump_path, "checkpoint.pth.tar"), 193 | run_variables=to_restore, 194 | state_dict=model, 195 | optimizer=optimizer, 196 | scheduler=scheduler, 197 | ) 198 | start_epoch = to_restore["epoch"] 199 | best_acc = to_restore["best_acc"] 200 | cudnn.benchmark = True 201 | 202 | for epoch in range(start_epoch, args.epochs): 203 | 204 | # train the network for one epoch 205 | logger.info("============ Starting epoch %i ... ============" % epoch) 206 | 207 | # set samplers 208 | train_loader.sampler.set_epoch(epoch) 209 | 210 | scores = train(model, optimizer, train_loader, epoch) 211 | scores_val = validate_network(val_loader, model) 212 | training_stats.update(scores + scores_val) 213 | 214 | scheduler.step() 215 | 216 | # save checkpoint 217 | if args.rank == 0: 218 | save_dict = { 219 | "epoch": epoch + 1, 220 | "state_dict": model.state_dict(), 221 | "optimizer": optimizer.state_dict(), 222 | "scheduler": scheduler.state_dict(), 223 | "best_acc": best_acc, 224 | } 225 | torch.save(save_dict, os.path.join(args.dump_path, "checkpoint.pth.tar")) 226 | logger.info("Fine-tuning with {}% of labels completed.\n" 227 | "Test accuracies: top-1 {acc1:.1f}, top-5 {acc5:.1f}".format( 228 | args.labels_perc, acc1=best_acc[0], acc5=best_acc[1])) 229 | 230 | 231 | def train(model, optimizer, loader, epoch): 232 | """ 233 | Train the models on the dataset. 
234 | """ 235 | # running statistics 236 | batch_time = AverageMeter() 237 | data_time = AverageMeter() 238 | 239 | # training statistics 240 | top1 = AverageMeter() 241 | top5 = AverageMeter() 242 | losses = AverageMeter() 243 | end = time.perf_counter() 244 | 245 | model.train() 246 | criterion = nn.CrossEntropyLoss().cuda() 247 | 248 | for iter_epoch, (inp, target) in enumerate(loader): 249 | # measure data loading time 250 | data_time.update(time.perf_counter() - end) 251 | 252 | # move to gpu 253 | inp = inp.cuda(non_blocking=True) 254 | target = target.cuda(non_blocking=True) 255 | 256 | # forward 257 | output = model(inp) 258 | 259 | # compute cross entropy loss 260 | loss = criterion(output, target) 261 | 262 | # compute the gradients 263 | optimizer.zero_grad() 264 | loss.backward() 265 | 266 | # step 267 | optimizer.step() 268 | 269 | # update stats 270 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 271 | losses.update(loss.item(), inp.size(0)) 272 | top1.update(acc1[0], inp.size(0)) 273 | top5.update(acc5[0], inp.size(0)) 274 | 275 | batch_time.update(time.perf_counter() - end) 276 | end = time.perf_counter() 277 | 278 | # verbose 279 | if args.rank == 0 and iter_epoch % 50 == 0: 280 | logger.info( 281 | "Epoch[{0}] - Iter: [{1}/{2}]\t" 282 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 283 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 284 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 285 | "Prec {top1.val:.3f} ({top1.avg:.3f})\t" 286 | "LR trunk {lr}\t" 287 | "LR head {lr_W}".format( 288 | epoch, 289 | iter_epoch, 290 | len(loader), 291 | batch_time=batch_time, 292 | data_time=data_time, 293 | loss=losses, 294 | top1=top1, 295 | lr=optimizer.param_groups[0]["lr"], 296 | lr_W=optimizer.param_groups[1]["lr"], 297 | ) 298 | ) 299 | return epoch, losses.avg, top1.avg.item(), top5.avg.item() 300 | 301 | 302 | def validate_network(val_loader, model): 303 | batch_time = AverageMeter() 304 | losses = AverageMeter() 305 | top1 = AverageMeter() 306 | top5 = AverageMeter() 307 | global best_acc 308 | 309 | # switch to evaluate mode 310 | model.eval() 311 | 312 | criterion = nn.CrossEntropyLoss().cuda() 313 | 314 | with torch.no_grad(): 315 | end = time.perf_counter() 316 | for i, (inp, target) in enumerate(val_loader): 317 | 318 | # move to gpu 319 | inp = inp.cuda(non_blocking=True) 320 | target = target.cuda(non_blocking=True) 321 | 322 | # compute output 323 | output = model(inp) 324 | loss = criterion(output, target) 325 | 326 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 327 | losses.update(loss.item(), inp.size(0)) 328 | top1.update(acc1[0], inp.size(0)) 329 | top5.update(acc5[0], inp.size(0)) 330 | 331 | # measure elapsed time 332 | batch_time.update(time.perf_counter() - end) 333 | end = time.perf_counter() 334 | 335 | if top1.avg.item() > best_acc[0]: 336 | best_acc = (top1.avg.item(), top5.avg.item()) 337 | 338 | if args.rank == 0: 339 | logger.info( 340 | "Test:\t" 341 | "Time {batch_time.avg:.3f}\t" 342 | "Loss {loss.avg:.4f}\t" 343 | "Acc@1 {top1.avg:.3f}\t" 344 | "Best Acc@1 so far {acc:.1f}".format( 345 | batch_time=batch_time, loss=losses, top1=top1, acc=best_acc[0])) 346 | 347 | return losses.avg, top1.avg.item(), top5.avg.item() 348 | 349 | 350 | if __name__ == "__main__": 351 | main() 352 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torchvision 9 | from torchvision.models.resnet import resnet50 as _resnet50 10 | 11 | dependencies = ["torch", "torchvision"] 12 | 13 | 14 | def resnet50(pretrained=True, **kwargs): 15 | """ 16 | ResNet-50 pre-trained with SwAV. 17 | 18 | Note that `fc.weight` and `fc.bias` are randomly initialized. 19 | 20 | Achieves 75.3% top-1 accuracy on ImageNet when `fc` is trained. 21 | """ 22 | model = _resnet50(pretrained=False, **kwargs) 23 | if pretrained: 24 | state_dict = torch.hub.load_state_dict_from_url( 25 | url="https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar", 26 | map_location="cpu", 27 | ) 28 | # removes "module." 29 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 30 | # load weights 31 | model.load_state_dict(state_dict, strict=False) 32 | return model 33 | -------------------------------------------------------------------------------- /main_swav.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import argparse 8 | import math 9 | import os 10 | import shutil 11 | import time 12 | from logging import getLogger 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.parallel 18 | import torch.backends.cudnn as cudnn 19 | import torch.distributed as dist 20 | import torch.optim 21 | import apex 22 | from apex.parallel.LARC import LARC 23 | 24 | from src.utils import * 25 | from src.multicropdataset import MultiCropDataset 26 | import src.resnet50 as resnet_models 27 | 28 | logger = getLogger() 29 | 30 | parser = argparse.ArgumentParser(description="Implementation of SwAV") 31 | 32 | ######################### 33 | #### data parameters #### 34 | ######################### 35 | parser.add_argument("--data_path", type=str, default="/path/to/imagenet", 36 | help="path to dataset repository") 37 | parser.add_argument("--nmb_crops", type=int, default=[2], nargs="+", 38 | help="list of number of crops (example: [2, 6])") 39 | parser.add_argument("--size_crops", type=int, default=[224], nargs="+", 40 | help="crops resolutions (example: [224, 96])") 41 | parser.add_argument("--min_scale_crops", type=float, default=[0.14], nargs="+", 42 | help="argument in RandomResizedCrop (example: [0.14, 0.05])") 43 | parser.add_argument("--max_scale_crops", type=float, default=[1], nargs="+", 44 | help="argument in RandomResizedCrop (example: [1., 0.14])") 45 | 46 | ######################### 47 | ## swav specific params # 48 | ######################### 49 | parser.add_argument("--crops_for_assign", type=int, nargs="+", default=[0, 1], 50 | help="list of crops id used for computing assignments") 51 | parser.add_argument("--temperature", default=0.1, type=float, 52 | help="temperature parameter in training loss") 53 | parser.add_argument("--epsilon", default=0.05, type=float, 54 | help="regularization parameter for Sinkhorn-Knopp algorithm") 55 | parser.add_argument("--sinkhorn_iterations", default=3, type=int, 56 | help="number of iterations in Sinkhorn-Knopp algorithm") 57 | parser.add_argument("--feat_dim", default=128, type=int, 58 | help="feature dimension") 59 | 
parser.add_argument("--nmb_prototypes", default=3000, type=int, 60 | help="number of prototypes") 61 | parser.add_argument("--queue_length", type=int, default=0, 62 | help="length of the queue (0 for no queue)") 63 | parser.add_argument("--epoch_queue_starts", type=int, default=15, 64 | help="from this epoch, we start using a queue") 65 | 66 | ######################### 67 | #### optim parameters ### 68 | ######################### 69 | parser.add_argument("--epochs", default=100, type=int, 70 | help="number of total epochs to run") 71 | parser.add_argument("--batch_size", default=64, type=int, 72 | help="batch size per gpu, i.e. how many unique instances per gpu") 73 | parser.add_argument("--base_lr", default=4.8, type=float, help="base learning rate") 74 | parser.add_argument("--final_lr", type=float, default=0, help="final learning rate") 75 | parser.add_argument("--freeze_prototypes_niters", default=313, type=int, 76 | help="freeze the prototypes during this many iterations from the start") 77 | parser.add_argument("--wd", default=1e-6, type=float, help="weight decay") 78 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") 79 | parser.add_argument("--start_warmup", default=0, type=float, 80 | help="initial warmup learning rate") 81 | 82 | ######################### 83 | #### dist parameters ### 84 | ######################### 85 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed 86 | training; see https://pytorch.org/docs/stable/distributed.html""") 87 | parser.add_argument("--world_size", default=-1, type=int, help=""" 88 | number of processes: it is set automatically and 89 | should not be passed as argument""") 90 | parser.add_argument("--rank", default=0, type=int, help="""rank of this process: 91 | it is set automatically and should not be passed as argument""") 92 | parser.add_argument("--local_rank", default=0, type=int, 93 | help="this argument is not used and should be ignored") 94 | 95 | ######################### 96 | #### other parameters ### 97 | ######################### 98 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 99 | parser.add_argument("--hidden_mlp", default=2048, type=int, 100 | help="hidden layer dimension in projection head") 101 | parser.add_argument("--workers", default=10, type=int, 102 | help="number of data loading workers") 103 | parser.add_argument("--checkpoint_freq", type=int, default=25, 104 | help="Save the model periodically") 105 | parser.add_argument("--use_fp16", type=bool_flag, default=True, 106 | help="whether to train with mixed precision or not") 107 | parser.add_argument("--sync_bn", type=str, default="pytorch", help="synchronize bn") 108 | parser.add_argument("--dump_path", type=str, default=".", 109 | help="experiment dump path for checkpoints and log") 110 | parser.add_argument("--seed", type=int, default=31, help="seed") 111 | 112 | 113 | def main(): 114 | global args 115 | args = parser.parse_args() 116 | # init_distributed_mode(args) 117 | fix_random_seeds(args.seed) 118 | logger, training_stats = initialize_exp(args, "epoch", "loss") 119 | 120 | os.environ["CUDA_VISIBLE_DEVICES"] = str('0') 121 | 122 | # build data 123 | train_dataset = MultiCropDataset( 124 | args.data_path, 125 | args.size_crops, 126 | args.nmb_crops, 127 | args.min_scale_crops, 128 | args.max_scale_crops, 129 | ) 130 | # sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 131 | train_loader = 
131 |     train_loader = torch.utils.data.DataLoader(
132 |         train_dataset,
133 |         # sampler=sampler,
134 |         batch_size=args.batch_size,
135 |         num_workers=args.workers,
136 |         pin_memory=True,
137 |         drop_last=True
138 |     )
139 |     logger.info("Building data done with {} images loaded.".format(len(train_dataset)))
140 | 
141 |     # build model
142 |     model = resnet_models.__dict__[args.arch](
143 |         normalize=True,
144 |         hidden_mlp=args.hidden_mlp,
145 |         output_dim=args.feat_dim,
146 |         nmb_prototypes=args.nmb_prototypes,
147 |     )
148 | 
149 |     # synchronize batch norm layers
150 |     # if args.sync_bn == "pytorch":
151 |     #     model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
152 |     # elif args.sync_bn == "apex":
153 |     #     process_group = None
154 |     #     if args.world_size // 8 > 0:
155 |     #         process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
156 |     #     model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
157 |     # copy model to GPU
158 |     model = model.cuda()
159 |     if args.rank == 0:
160 |         logger.info(model)
161 |     logger.info("Building model done.")
162 | 
163 |     # build optimizer
164 |     optimizer = torch.optim.SGD(
165 |         model.parameters(),
166 |         lr=args.base_lr,
167 |         momentum=0.9,
168 |         weight_decay=args.wd,
169 |     )
170 |     optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False)  # layer-wise adaptive rate scaling on top of SGD
171 |     warmup_lr_schedule = np.linspace(args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
172 |     iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
173 |     cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
174 |         math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) for t in iters])
175 |     lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))  # per-iteration schedule: linear warmup, then cosine decay
176 |     logger.info("Building optimizer done.")
177 | 
178 |     # init mixed precision
179 |     if args.use_fp16:
180 |         model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
181 |         logger.info("Initializing mixed precision done.")
182 | 
183 |     # wrap model
184 |     # model = nn.parallel.DistributedDataParallel(
185 |     #     model,
186 |     #     device_ids=[args.gpu_to_work_on],
187 |     #     find_unused_parameters=True,
188 |     # )
189 | 
190 |     # optionally resume from a checkpoint
191 |     to_restore = {"epoch": 0}
192 |     restart_from_checkpoint(
193 |         os.path.join(args.dump_path, "checkpoint.pth.tar"),
194 |         run_variables=to_restore,
195 |         state_dict=model,
196 |         optimizer=optimizer,
197 |         amp=apex.amp,
198 |     )
199 |     start_epoch = to_restore["epoch"]
200 | 
201 |     # build the queue
202 |     queue = None
203 |     queue_path = os.path.join(args.dump_path, "queue" + str(args.rank) + ".pth")
204 |     if os.path.isfile(queue_path):
205 |         queue = torch.load(queue_path)["queue"]
206 |     # the queue length needs to be divisible by the total batch size
207 |     args.queue_length -= args.queue_length % (args.batch_size * args.world_size)
208 | 
209 |     cudnn.benchmark = True
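    # Added note: the queue allocated below follows the SwAV paper. It stores
    # feature embeddings from past batches, one slot per assignment crop, with
    # shape (len(crops_for_assign), queue_length // world_size, feat_dim); the
    # extra features stabilize the Sinkhorn assignments when the batch size is
    # small relative to the number of prototypes.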
============" % epoch) 215 | 216 | # set sampler 217 | # train_loader.sampler.set_epoch(epoch) 218 | 219 | # optionally starts a queue 220 | if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None: 221 | queue = torch.zeros( 222 | len(args.crops_for_assign), 223 | args.queue_length // args.world_size, 224 | args.feat_dim, 225 | ).cuda() 226 | 227 | # train the network 228 | scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue) 229 | training_stats.update(scores) 230 | 231 | # save checkpoints 232 | if args.rank == 0: 233 | save_dict = { 234 | "epoch": epoch + 1, 235 | "state_dict": model.state_dict(), 236 | "optimizer": optimizer.state_dict(), 237 | } 238 | if args.use_fp16: 239 | save_dict["amp"] = apex.amp.state_dict() 240 | torch.save( 241 | save_dict, 242 | os.path.join(args.dump_path, "checkpoint.pth.tar"), 243 | ) 244 | if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1: 245 | shutil.copyfile( 246 | os.path.join(args.dump_path, "checkpoint.pth.tar"), 247 | os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"), 248 | ) 249 | if queue is not None: 250 | torch.save({"queue": queue}, queue_path) 251 | 252 | 253 | def train(train_loader, model, optimizer, epoch, lr_schedule, queue): 254 | batch_time = AverageMeter() 255 | data_time = AverageMeter() 256 | losses = AverageMeter() 257 | 258 | softmax = nn.Softmax(dim=1).cuda() 259 | model.train() 260 | use_the_queue = False 261 | 262 | end = time.time() 263 | for it, inputs in enumerate(train_loader): 264 | # measure data loading time 265 | data_time.update(time.time() - end) 266 | 267 | # update learning rate 268 | iteration = epoch * len(train_loader) + it 269 | for param_group in optimizer.param_groups: 270 | param_group["lr"] = lr_schedule[iteration] 271 | 272 | # normalize the prototypes 273 | with torch.no_grad(): 274 | w = model.prototypes.weight.data.clone() 275 | w = nn.functional.normalize(w, dim=1, p=2) 276 | model.prototypes.weight.copy_(w) 277 | 278 | # ============ multi-res forward passes ... ============ 279 | embedding, output = model(inputs) 280 | embedding = embedding.detach() 281 | bs = inputs[0].size(0) 282 | 283 | # ============ swav loss ... ============ 284 | loss = 0 285 | for i, crop_id in enumerate(args.crops_for_assign): 286 | with torch.no_grad(): 287 | out = output[bs * crop_id: bs * (crop_id + 1)] 288 | 289 | # time to use the queue 290 | if queue is not None: 291 | if use_the_queue or not torch.all(queue[i, -1, :] == 0): 292 | use_the_queue = True 293 | out = torch.cat((torch.mm( 294 | queue[i], 295 | model.prototypes.weight.t() 296 | ), out)) 297 | # fill the queue 298 | queue[i, bs:] = queue[i, :-bs].clone() 299 | queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs] 300 | # get assignments 301 | q = torch.exp(out / args.epsilon).t() 302 | q = distributed_sinkhorn(q, args.sinkhorn_iterations)[-bs:] 303 | 304 | # cluster assignment prediction 305 | subloss = 0 306 | for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id): 307 | p = softmax(output[bs * v: bs * (v + 1)] / args.temperature) 308 | subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1)) 309 | loss += subloss / (np.sum(args.nmb_crops) - 1) 310 | loss /= len(args.crops_for_assign) 311 | 312 | # ============ backward and optim step ... 
311 | 
312 |         # ============ backward and optim step ... ============
313 |         optimizer.zero_grad()
314 |         if args.use_fp16:
315 |             with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
316 |                 scaled_loss.backward()
317 |         else:
318 |             loss.backward()
319 |         # cancel some gradients
320 |         if iteration < args.freeze_prototypes_niters:
321 |             for name, p in model.named_parameters():
322 |                 if "prototypes" in name:
323 |                     p.grad = None
324 |         optimizer.step()
325 | 
326 |         # ============ misc ... ============
327 |         losses.update(loss.item(), inputs[0].size(0))
328 |         batch_time.update(time.time() - end)
329 |         end = time.time()
330 |         if args.rank == 0 and it % 50 == 0:
331 |             logger.info(
332 |                 "Epoch: [{0}][{1}]\t"
333 |                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
334 |                 "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
335 |                 "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
336 |                 "Lr: {lr:.4f}".format(
337 |                     epoch,
338 |                     it,
339 |                     batch_time=batch_time,
340 |                     data_time=data_time,
341 |                     loss=losses,
342 |                     lr=optimizer.optim.param_groups[0]["lr"],
343 |                 )
344 |             )
345 |     return (epoch, losses.avg), queue
346 | 
347 | 
348 | def distributed_sinkhorn(Q, nmb_iters):
349 |     with torch.no_grad():
350 |         sum_Q = torch.sum(Q)
351 |         # dist.all_reduce(sum_Q)
352 |         Q /= sum_Q
353 | 
354 |         u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)  # overwritten with curr_sum inside the loop
355 |         r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]  # target row (prototype) marginals: uniform
356 |         c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (args.world_size * Q.shape[1])  # target column (sample) marginals
357 | 
358 |         curr_sum = torch.sum(Q, dim=1)
359 |         # dist.all_reduce(curr_sum)
360 | 
361 |         for it in range(nmb_iters):
362 |             u = curr_sum
363 |             Q *= (r / u).unsqueeze(1)  # rescale rows to match r
364 |             Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)  # rescale columns to match c
365 |             curr_sum = torch.sum(Q, dim=1)
366 |             # dist.all_reduce(curr_sum)
367 |         return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()  # each row (one sample's code) sums to one
368 | 
369 | 
370 | if __name__ == "__main__":
371 |     main()
372 | 
-------------------------------------------------------------------------------- /run.sh: --------------------------------------------------------------------------------
1 | python main_swav.py --data_path ../../dataset/ --size_crops 32 --nmb_crops 2 --nmb_prototypes 30 --batch_size 512 --epochs 500 --base_lr 0.06 --final_lr 0.0006 --temperature 0.5 --use_fp16 true --dump_path checkpoints --freeze_prototypes_niters 900
2 | python eval_linear.py --data_path ../../dataset/ --dump_path ./checkpoints_linear --pretrained checkpoint.pth.tar --batch_size 512 --lr 0.03
-------------------------------------------------------------------------------- /src/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhinavagarwalla/swav-cifar10/4369b58aff2dac7b9b1e40d53af2f2eac9be9481/src/__init__.py
-------------------------------------------------------------------------------- /src/logger.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import os 9 | import logging 10 | import time 11 | from datetime import timedelta 12 | import pandas as pd 13 | 14 | 15 | class LogFormatter: 16 | def __init__(self): 17 | self.start_time = time.time() 18 | 19 | def format(self, record): 20 | elapsed_seconds = round(record.created - self.start_time) 21 | 22 | prefix = "%s - %s - %s" % ( 23 | record.levelname, 24 | time.strftime("%x %X"), 25 | timedelta(seconds=elapsed_seconds), 26 | ) 27 | message = record.getMessage() 28 | message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) 29 | return "%s - %s" % (prefix, message) if message else "" 30 | 31 | 32 | def create_logger(filepath, rank): 33 | """ 34 | Create a logger. 35 | Use a different log file for each process. 36 | """ 37 | # create log formatter 38 | log_formatter = LogFormatter() 39 | 40 | # create file handler and set level to debug 41 | if filepath is not None: 42 | if rank > 0: 43 | filepath = "%s-%i" % (filepath, rank) 44 | file_handler = logging.FileHandler(filepath, "a") 45 | file_handler.setLevel(logging.DEBUG) 46 | file_handler.setFormatter(log_formatter) 47 | 48 | # create console handler and set level to info 49 | console_handler = logging.StreamHandler() 50 | console_handler.setLevel(logging.INFO) 51 | console_handler.setFormatter(log_formatter) 52 | 53 | # create logger and set level to debug 54 | logger = logging.getLogger() 55 | logger.handlers = [] 56 | logger.setLevel(logging.DEBUG) 57 | logger.propagate = False 58 | if filepath is not None: 59 | logger.addHandler(file_handler) 60 | logger.addHandler(console_handler) 61 | 62 | # reset logger elapsed time 63 | def reset_time(): 64 | log_formatter.start_time = time.time() 65 | 66 | logger.reset_time = reset_time 67 | 68 | return logger 69 | 70 | 71 | class PD_Stats(object): 72 | """ 73 | Log stuff with pandas library 74 | """ 75 | 76 | def __init__(self, path, columns): 77 | self.path = path 78 | 79 | # reload path stats 80 | if os.path.isfile(self.path): 81 | self.stats = pd.read_pickle(self.path) 82 | 83 | # check that columns are the same 84 | assert list(self.stats.columns) == list(columns) 85 | 86 | else: 87 | self.stats = pd.DataFrame(columns=columns) 88 | 89 | def update(self, row, save=True): 90 | self.stats.loc[len(self.stats.index)] = row 91 | 92 | # save the statistics 93 | if save: 94 | self.stats.to_pickle(self.path) 95 | -------------------------------------------------------------------------------- /src/multicropdataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | 
8 | from logging import getLogger
9 | 
10 | import cv2
11 | 
12 | import numpy as np
13 | import torchvision.datasets as datasets
14 | import torchvision.transforms as transforms
15 | 
16 | from PIL import Image
17 | logger = getLogger()
18 | 
19 | 
20 | class MultiCropDataset(datasets.CIFAR10):
21 |     def __init__(
22 |         self,
23 |         data_path,
24 |         size_crops,
25 |         nmb_crops,
26 |         min_scale_crops,
27 |         max_scale_crops,
28 |         size_dataset=-1,
29 |         return_index=False,
30 |     ):
31 |         super(MultiCropDataset, self).__init__(data_path)
32 |         assert len(size_crops) == len(nmb_crops)
33 |         assert len(min_scale_crops) == len(nmb_crops)
34 |         assert len(max_scale_crops) == len(nmb_crops)
35 |         if size_dataset >= 0:
36 |             self.data = self.data[:size_dataset]  # CIFAR-10 stores images in .data (there is no .samples attribute)
37 |         self.return_index = return_index
38 | 
39 |         trans = []
40 |         color_transform = transforms.Compose([get_color_distortion(), RandomGaussianBlur()])
41 |         mean = [0.485, 0.456, 0.406]
42 |         std = [0.228, 0.224, 0.225]  # ImageNet statistics, kept from the upstream SwAV code
43 |         for i in range(len(size_crops)):
44 |             randomresizedcrop = transforms.RandomResizedCrop(
45 |                 size_crops[i],
46 |                 # scale=(min_scale_crops[i], max_scale_crops[i]),  # disabled: the --min/--max_scale_crops CLI arguments are currently ignored
47 |             )
48 |             trans.extend([transforms.Compose([
49 |                 randomresizedcrop,
50 |                 transforms.RandomHorizontalFlip(p=0.5),
51 |                 color_transform,
52 |                 transforms.ToTensor(),
53 |                 transforms.Normalize(mean=mean, std=std)])
54 |             ] * nmb_crops[i])
55 |         self.trans = trans
56 | 
57 |     def __getitem__(self, index):
58 |         img = self.data[index]
59 |         image = Image.fromarray(img)
60 |         multi_crops = list(map(lambda trans: trans(image), self.trans))
61 |         if self.return_index:
62 |             return index, multi_crops
63 |         return multi_crops
64 | 
65 | 
66 | class RandomGaussianBlur(object):
67 |     def __call__(self, img):
68 |         do_it = np.random.rand() > 0.5
69 |         if not do_it:
70 |             return img
71 |         sigma = np.random.rand() * 1.9 + 0.1
72 |         return cv2.GaussianBlur(np.asarray(img), (23, 23), sigma)
73 | 
74 | 
75 | def get_color_distortion(s=1.0):
76 |     # s is the strength of color distortion.
77 |     color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
78 |     rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
79 |     rnd_gray = transforms.RandomGrayscale(p=0.2)
80 |     color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
81 |     return color_distort
82 | 
-------------------------------------------------------------------------------- /src/resnet50.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d( 15 | in_planes, 16 | out_planes, 17 | kernel_size=3, 18 | stride=stride, 19 | padding=dilation, 20 | groups=groups, 21 | bias=False, 22 | dilation=dilation, 23 | ) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, stride=1): 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | __constants__ = ["downsample"] 34 | 35 | def __init__( 36 | self, 37 | inplanes, 38 | planes, 39 | stride=1, 40 | downsample=None, 41 | groups=1, 42 | base_width=64, 43 | dilation=1, 44 | norm_layer=None, 45 | ): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | __constants__ = ["downsample"] 84 | 85 | def __init__( 86 | self, 87 | inplanes, 88 | planes, 89 | stride=1, 90 | downsample=None, 91 | groups=1, 92 | base_width=64, 93 | dilation=1, 94 | norm_layer=None, 95 | ): 96 | super(Bottleneck, self).__init__() 97 | if norm_layer is None: 98 | norm_layer = nn.BatchNorm2d 99 | width = int(planes * (base_width / 64.0)) * groups 100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv1x1(inplanes, width) 102 | self.bn1 = norm_layer(width) 103 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 104 | self.bn2 = norm_layer(width) 105 | self.conv3 = conv1x1(width, planes * self.expansion) 106 | self.bn3 = norm_layer(planes * self.expansion) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.downsample = downsample 109 | self.stride = stride 110 | 111 | def forward(self, x): 112 | identity = x 113 | 114 | out = self.conv1(x) 115 | out = self.bn1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.bn2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.bn3(out) 124 | 125 | if self.downsample is not None: 126 | identity = self.downsample(x) 127 | 128 | out += identity 129 | out = self.relu(out) 130 | 131 | return out 132 | 133 | 134 | class ResNet(nn.Module): 135 | def __init__( 136 | self, 137 | block, 138 | layers, 139 | zero_init_residual=False, 140 | groups=1, 141 | widen=1, 142 | width_per_group=64, 143 | replace_stride_with_dilation=None, 144 | norm_layer=None, 145 | normalize=False, 146 | output_dim=0, 147 | hidden_mlp=0, 148 | nmb_prototypes=0, 149 | eval_mode=False, 
150 |     ):
151 |         super(ResNet, self).__init__()
152 |         if norm_layer is None:
153 |             norm_layer = nn.BatchNorm2d
154 |         self._norm_layer = norm_layer
155 | 
156 |         self.eval_mode = eval_mode
157 |         self.padding = nn.ConstantPad2d(1, 0.0)  # leftover from the ImageNet stem; adds one extra pixel of zero padding on top of conv1's own padding
158 | 
159 |         self.inplanes = width_per_group * widen
160 |         self.dilation = 1
161 |         if replace_stride_with_dilation is None:
162 |             # each element in the tuple indicates if we should replace
163 |             # the 2x2 stride with a dilated convolution instead
164 |             replace_stride_with_dilation = [False, False, False]
165 |         if len(replace_stride_with_dilation) != 3:
166 |             raise ValueError(
167 |                 "replace_stride_with_dilation should be None "
168 |                 "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
169 |             )
170 |         self.groups = groups
171 |         self.base_width = width_per_group
172 | 
173 |         # CIFAR-10 stem: a 3x3/stride-1 convolution replaces the original 7x7/stride-2 convolution (and the maxpool below is disabled) so that 32x32 inputs are not downsampled too aggressively
174 |         num_out_filters = width_per_group * widen
175 |         self.conv1 = nn.Conv2d(
176 |             3, num_out_filters, kernel_size=3, stride=1, padding=1, bias=False
177 |         )
178 |         self.bn1 = norm_layer(num_out_filters)
179 |         self.relu = nn.ReLU(inplace=True)
180 |         # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
181 |         self.layer1 = self._make_layer(block, num_out_filters, layers[0])
182 |         num_out_filters *= 2
183 |         self.layer2 = self._make_layer(
184 |             block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
185 |         )
186 |         num_out_filters *= 2
187 |         self.layer3 = self._make_layer(
188 |             block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
189 |         )
190 |         num_out_filters *= 2
191 |         self.layer4 = self._make_layer(
192 |             block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
193 |         )
194 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
195 | 
196 |         # normalize output features
197 |         self.l2norm = normalize
198 | 
199 |         # projection head
200 |         if output_dim == 0:
201 |             self.projection_head = None
202 |         elif hidden_mlp == 0:
203 |             self.projection_head = nn.Linear(num_out_filters * block.expansion, output_dim)
204 |         else:
205 |             self.projection_head = nn.Sequential(
206 |                 nn.Linear(num_out_filters * block.expansion, hidden_mlp),
207 |                 nn.BatchNorm1d(hidden_mlp),
208 |                 nn.ReLU(inplace=True),
209 |                 nn.Linear(hidden_mlp, output_dim),
210 |             )
211 | 
212 |         # prototype layer
213 |         self.prototypes = None
214 |         if isinstance(nmb_prototypes, list):
215 |             self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
216 |         elif nmb_prototypes > 0:
217 |             self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
218 | 
219 |         for m in self.modules():
220 |             if isinstance(m, nn.Conv2d):
221 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
222 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
223 |                 nn.init.constant_(m.weight, 1)
224 |                 nn.init.constant_(m.bias, 0)
225 | 
226 |         # Zero-initialize the last BN in each residual branch,
227 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
228 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 229 | if zero_init_residual: 230 | for m in self.modules(): 231 | if isinstance(m, Bottleneck): 232 | nn.init.constant_(m.bn3.weight, 0) 233 | elif isinstance(m, BasicBlock): 234 | nn.init.constant_(m.bn2.weight, 0) 235 | 236 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 237 | norm_layer = self._norm_layer 238 | downsample = None 239 | previous_dilation = self.dilation 240 | if dilate: 241 | self.dilation *= stride 242 | stride = 1 243 | if stride != 1 or self.inplanes != planes * block.expansion: 244 | downsample = nn.Sequential( 245 | conv1x1(self.inplanes, planes * block.expansion, stride), 246 | norm_layer(planes * block.expansion), 247 | ) 248 | 249 | layers = [] 250 | layers.append( 251 | block( 252 | self.inplanes, 253 | planes, 254 | stride, 255 | downsample, 256 | self.groups, 257 | self.base_width, 258 | previous_dilation, 259 | norm_layer, 260 | ) 261 | ) 262 | self.inplanes = planes * block.expansion 263 | for _ in range(1, blocks): 264 | layers.append( 265 | block( 266 | self.inplanes, 267 | planes, 268 | groups=self.groups, 269 | base_width=self.base_width, 270 | dilation=self.dilation, 271 | norm_layer=norm_layer, 272 | ) 273 | ) 274 | 275 | return nn.Sequential(*layers) 276 | 277 | def forward_backbone(self, x): 278 | x = self.padding(x) 279 | 280 | x = self.conv1(x) 281 | x = self.bn1(x) 282 | x = self.relu(x) 283 | # x = self.maxpool(x) 284 | x = self.layer1(x) 285 | x = self.layer2(x) 286 | x = self.layer3(x) 287 | x = self.layer4(x) 288 | 289 | if self.eval_mode: 290 | return x 291 | 292 | x = self.avgpool(x) 293 | x = torch.flatten(x, 1) 294 | 295 | return x 296 | 297 | def forward_head(self, x): 298 | if self.projection_head is not None: 299 | x = self.projection_head(x) 300 | 301 | if self.l2norm: 302 | x = nn.functional.normalize(x, dim=1, p=2) 303 | 304 | if self.prototypes is not None: 305 | return x, self.prototypes(x) 306 | return x 307 | 308 | def forward(self, inputs): 309 | if not isinstance(inputs, list): 310 | inputs = [inputs] 311 | idx_crops = torch.cumsum(torch.unique_consecutive( 312 | torch.tensor([inp.shape[-1] for inp in inputs]), 313 | return_counts=True, 314 | )[1], 0) 315 | start_idx = 0 316 | for end_idx in idx_crops: 317 | _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) 318 | if start_idx == 0: 319 | output = _out 320 | else: 321 | output = torch.cat((output, _out)) 322 | start_idx = end_idx 323 | return self.forward_head(output) 324 | 325 | 326 | class MultiPrototypes(nn.Module): 327 | def __init__(self, output_dim, nmb_prototypes): 328 | super(MultiPrototypes, self).__init__() 329 | self.nmb_heads = len(nmb_prototypes) 330 | for i, k in enumerate(nmb_prototypes): 331 | self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) 332 | 333 | def forward(self, x): 334 | out = [] 335 | for i in range(self.nmb_heads): 336 | out.append(getattr(self, "prototypes" + str(i))(x)) 337 | return out 338 | 339 | 340 | def resnet50(**kwargs): 341 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 342 | 343 | 344 | def resnet50w2(**kwargs): 345 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs) 346 | 347 | 348 | def resnet50w4(**kwargs): 349 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs) 350 | 351 | 352 | def resnet50w5(**kwargs): 353 | return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs) 354 | 
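# Added usage sketch (not part of the original file): build the SwAV model and
# run a multi-crop forward pass at two resolutions. Assumes a CUDA device is
# available, since forward() moves every crop batch to the GPU.
if __name__ == "__main__":
    model = resnet50(normalize=True, hidden_mlp=2048, output_dim=128, nmb_prototypes=30).cuda()
    crops = [torch.randn(4, 3, 32, 32) for _ in range(2)] + [torch.randn(4, 3, 16, 16) for _ in range(2)]
    embedding, scores = model(crops)  # crops of equal resolution share one backbone pass
    print(embedding.shape)  # torch.Size([16, 128]): L2-normalized projections of all 4 crops
    print(scores.shape)     # torch.Size([16, 30]): prototype scores consumed by the SwAV loss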
-------------------------------------------------------------------------------- /src/utils.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | 
8 | import argparse
9 | from logging import getLogger
10 | import pickle
11 | import os
12 | 
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | 
17 | from .logger import create_logger, PD_Stats
18 | 
19 | import torch.distributed as dist
20 | 
21 | FALSY_STRINGS = {"off", "false", "0"}
22 | TRUTHY_STRINGS = {"on", "true", "1"}
23 | 
24 | 
25 | logger = getLogger()
26 | 
27 | 
28 | def bool_flag(s):
29 |     """
30 |     Parse boolean arguments from the command line.
31 |     """
32 |     if s.lower() in FALSY_STRINGS:
33 |         return False
34 |     elif s.lower() in TRUTHY_STRINGS:
35 |         return True
36 |     else:
37 |         raise argparse.ArgumentTypeError("invalid value for a boolean flag")
38 | 
39 | 
40 | def init_distributed_mode(args):
41 |     """
42 |     Initialize the following variables:
43 |         - world_size
44 |         - rank
45 |     """
46 | 
47 |     args.is_slurm_job = "SLURM_JOB_ID" in os.environ
48 | 
49 |     if args.is_slurm_job:
50 |         args.rank = int(os.environ["SLURM_PROCID"])
51 |         args.world_size = int(os.environ["SLURM_NNODES"]) * int(
52 |             os.environ["SLURM_TASKS_PER_NODE"][0]
53 |         )
54 |     else:
55 |         # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
56 |         # read environment variables
57 |         args.rank = int(os.environ["RANK"])
58 |         args.world_size = int(os.environ["WORLD_SIZE"])
59 | 
60 |     # prepare distributed
61 |     dist.init_process_group(
62 |         backend="nccl",
63 |         init_method=args.dist_url,
64 |         world_size=args.world_size,
65 |         rank=args.rank,
66 |     )
67 | 
68 |     # set cuda device
69 |     args.gpu_to_work_on = args.rank % torch.cuda.device_count()
70 |     torch.cuda.set_device(args.gpu_to_work_on)
71 |     return
72 | 
73 | 
74 | def initialize_exp(params, *args, dump_params=True):
75 |     """
76 |     Initialize the experiment:
77 |     - dump parameters
78 |     - create checkpoint repo
79 |     - create a logger
80 |     - create a pandas object to keep track of the training statistics
81 |     """
82 | 
83 |     # dump parameters
84 |     if dump_params:
85 |         pickle.dump(params, open(os.path.join(params.dump_path, "params.pkl"), "wb"))
86 | 
87 |     # create repo to store checkpoints
88 |     params.dump_checkpoints = os.path.join(params.dump_path, "checkpoints")
89 |     if not params.rank and not os.path.isdir(params.dump_checkpoints):
90 |         os.mkdir(params.dump_checkpoints)
91 | 
92 |     # create a pandas object to log loss and acc
93 |     training_stats = PD_Stats(
94 |         os.path.join(params.dump_path, "stats" + str(params.rank) + ".pkl"), args
95 |     )
96 | 
97 |     # create a logger
98 |     logger = create_logger(
99 |         os.path.join(params.dump_path, "train.log"), rank=params.rank
100 |     )
101 |     logger.info("============ Initialized logger ============")
102 |     logger.info(
103 |         "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items()))
104 |     )
105 |     logger.info("The experiment will be stored in %s\n" % params.dump_path)
106 |     logger.info("")
107 |     return logger, training_stats
108 | 
109 | 
110 | def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs):
111 |     """
112 |     Restart from a checkpoint.
113 |     """
114 |     # look for a checkpoint in exp repository
115 |     if isinstance(ckp_paths, list):
116 |         for ckp_path in ckp_paths:
117 |             if os.path.isfile(ckp_path):
118 |                 break
119 |     else:
120 |         ckp_path = ckp_paths
121 | 
122 |     if not os.path.isfile(ckp_path):
123 |         return
124 | 
125 |     logger.info("Found checkpoint at {}".format(ckp_path))
126 | 
127 |     # open checkpoint file
128 |     checkpoint = torch.load(
129 |         ckp_path, map_location="cpu"  # distributed mode is disabled in this fork, so map to CPU instead of "cuda:<rank>"; torch.distributed.get_rank() would fail without an initialized process group
130 |     )
131 | 
132 |     # key is what to look for in the checkpoint file
133 |     # value is the object to load
134 |     # example: {'state_dict': model}
135 |     for key, value in kwargs.items():
136 |         if key in checkpoint and value is not None:
137 |             try:
138 |                 msg = value.load_state_dict(checkpoint[key], strict=False)
139 |                 print(msg)
140 |             except TypeError:
141 |                 msg = value.load_state_dict(checkpoint[key])
142 |             logger.info("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
143 |         else:
144 |             logger.warning(
145 |                 "=> failed to load {} from checkpoint '{}'".format(key, ckp_path)
146 |             )
147 | 
148 |     # reload variables important for the run
149 |     if run_variables is not None:
150 |         for var_name in run_variables:
151 |             if var_name in checkpoint:
152 |                 run_variables[var_name] = checkpoint[var_name]
153 | 
154 | 
155 | def fix_random_seeds(seed=31):
156 |     """
157 |     Fix random seeds.
158 |     """
159 |     torch.manual_seed(seed)
160 |     torch.cuda.manual_seed_all(seed)
161 |     np.random.seed(seed)
162 | 
163 | 
164 | class AverageMeter(object):
165 |     """Computes and stores the average and current value."""
166 | 
167 |     def __init__(self):
168 |         self.reset()
169 | 
170 |     def reset(self):
171 |         self.val = 0
172 |         self.avg = 0
173 |         self.sum = 0
174 |         self.count = 0
175 | 
176 |     def update(self, val, n=1):
177 |         self.val = val
178 |         self.sum += val * n
179 |         self.count += n
180 |         self.avg = self.sum / self.count
181 | 
182 | 
183 | def accuracy(output, target, topk=(1,)):
184 |     """Computes the accuracy over the k top predictions for the specified values of k."""
185 |     with torch.no_grad():
186 |         maxk = max(topk)
187 |         batch_size = target.size(0)
188 | 
189 |         _, pred = output.topk(maxk, 1, True, True)
190 |         pred = pred.t()
191 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
192 | 
193 |         res = []
194 |         for k in topk:
195 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: correct is non-contiguous after .t()
196 |             res.append(correct_k.mul_(100.0 / batch_size))
197 |         return res
198 | 
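# Added usage sketch (not part of the original file): exercise AverageMeter and
# accuracy() on random data; runs on CPU.
if __name__ == "__main__":
    meter = AverageMeter()
    for v in [1.0, 2.0, 3.0]:
        meter.update(v)
    print(meter.val, meter.avg)  # 3.0 2.0
    logits = torch.randn(8, 10)
    target = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, target, topk=(1, 5))
    print(top1.item(), top5.item())  # top-k accuracies as percentages in [0, 100]
--------------------------------------------------------------------------------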