├── .gitignore ├── LICENSE ├── README.md ├── config ├── finetune │ ├── vit_base_finetune.py │ ├── vit_small_finetune.py │ └── vit_tiny_finetune.py ├── knn │ └── knn.py ├── linear │ ├── vit_base_linear.py │ ├── vit_small_linear.py │ └── vit_tiny_linear.py └── pretrain │ ├── vit_base_pretrain.py │ ├── vit_small_pretrain.py │ └── vit_tiny_pretrain.py ├── images ├── framework.png └── visualization.png ├── main_finetune.py ├── main_knn.py ├── main_linear.py ├── main_pretrain.py ├── module ├── augmentation.py ├── frame │ ├── contrast_momentum.py │ └── contrast_no_momentum.py ├── loss.py ├── mix.py └── vits.py ├── requirements.txt └── utils ├── logger.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | weight/* 3 | dataset/* 4 | .vscode/* 5 | log/* 6 | out/* 7 | ckpt/* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 
15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. 
A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. 
The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. 
Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. 
indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 
288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. 
This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. 
To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 
399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Inter-Instance Similarity Modeling for Contrastive Learning 2 | 3 | ### 1. Introduction 4 | 5 | This is the official implementation of the paper: "Inter-Instance Similarity Modeling for Contrastive Learning". 6 | 7 | ![Framework](./images/framework.png) 8 | 9 | PatchMix is a novel image mixing strategy that mixes multiple images at the patch level. The mixed image contains massive local components from multiple images and efficiently simulates rich similarities among natural images in an unsupervised manner. To model rich inter-instance similarities among images, the ViT model is optimized with three kinds of contrasts: mixed images vs. original images, mixed images vs. mixed images, and original images vs. original images. Experimental results demonstrate that our proposed method significantly outperforms the previous state-of-the-art on both ImageNet-1K and CIFAR datasets, e.g., 3.0% linear accuracy improvement on ImageNet-1K and 8.7% kNN accuracy improvement on CIFAR100. 10 | 11 | [[Paper](https://arxiv.org/abs/2306.12243)] [[BibTex](#Citation)] [[Blog(CN)](https://zhuanlan.zhihu.com/p/639240952)] 12 | 13 | 14 | 15 | ### Requirements 16 | 17 | ```bash 18 | conda create -n patchmix python=3.8 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | 23 | 24 | ### Datasets 25 | 26 | Please set the dataset root paths in the `*.py` configuration files under the directory `./config/`. 27 | The `CIFAR10` and `CIFAR100` datasets are provided by `torchvision`; their root path is set to `/path/to/dataset`.
The root path of `ImageNet-1K (ILSVRC2012)` is set to `/path/to/ILSVRC2012`. 28 | 29 | 30 | 31 | ### Self-Supervised Pretraining 32 | 33 | #### ViT-Small with 2-node (8-GPU) training 34 | 35 | Set hyperparameters, dataset and GPU IDs in `./config/pretrain/vit_small_pretrain.py` and run the following command: 36 | 37 | ```bash 38 | python main_pretrain.py --arch vit-small 39 | ``` 40 | 41 | 42 | 43 | ### kNN Evaluation 44 | 45 | Set hyperparameters, dataset and GPU IDs in `./config/knn/knn.py` and run the following command: 46 | 47 | ```bash 48 | python main_knn.py --arch vit-small --pretrained-weights /path/to/pretrained-weights.pth 49 | ``` 50 | 51 | 52 | 53 | ### Linear Evaluation 54 | 55 | Set hyperparameters, dataset and GPU IDs in `./config/linear/vit_small_linear.py` and run the following command: 56 | 57 | ```bash 58 | python main_linear.py --arch vit-small --pretrained-weights /path/to/pretrained-weights.pth 59 | ``` 60 | 61 | 62 | 63 | ### Fine-tuning Evaluation 64 | 65 | Set hyperparameters, dataset and GPU IDs in `./config/finetune/vit_small_finetune.py` and run the following command: 66 | 67 | ```bash 68 | python main_finetune.py --arch vit-small --pretrained-weights /path/to/pretrained-weights.pth 69 | ``` 70 | 71 | 72 | 73 | ### Main Results and Model Weights 74 | 75 | If you don't have a **Microsoft Office account**, you can download the trained model weights from [this link](https://csueducn-my.sharepoint.com/:f:/g/personal/221258_csu_edu_cn/EsSud0DB_edBiODrZhDbNpsBwfTbpOkuJ_TKA6mTYSi6Dw). 76 | 77 | If you have a **Microsoft Office account**, you can download the trained model weights from the links in the following tables.
78 | 79 | #### ImageNet-1K 80 | 81 | | Arch | Batch size | #Pre-Epoch | Finetuning Accuracy | Linear Probing Accuracy | kNN Accuracy | 82 | |:------------:|:------:|:-----:|:------:|:--------:|:----------------------------------------------------------------------:| 83 | | ViT-S/16 | 1024 | 300 | 82.8% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fimagenet%2D1k%2Fvit%2Dsmall%2D300%2D82%2E8%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fimagenet%2D1k)) | 77.4% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fimagenet1k%2Fvit%2Dsmall%2D300%2D77%2E4%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fimagenet1k)) | 73.3% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fimagenet1k%2Fvit%2Dsmall%2D300%2D73%2E3%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fimagenet1k)) | 84 | | ViT-B/16 | 1024 | 300 | 84.1% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fimagenet%2D1k%2Fvit%2Dbase%2D300%2D84%2E1%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fimagenet%2D1k)) | 80.2% 
([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fimagenet1k%2Fvit%2Dbase%2D300%2D80%2E2%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fimagenet1k)) | 76.2% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fimagenet1k%2Fvit%2Dbase%2D300%2D76%2E2%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fimagenet1k)) | 85 | 86 | 87 | 88 | #### CIFAR10 89 | 90 | | Arch | Batch size | #Pre-Epoch | Finetuning Accuracy | Linear Probing Accuracy | kNN Accuracy | 91 | | :-----: | :--------: | :--------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 92 | | ViT-T/2 | 512 | 800 | 97.5% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10%2Fvit%2Dtiny%2D800%2D97%2E5%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10)) | 94.4% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10%2Fvit%2Dtiny%2D800%2D94%2E4%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10)) | 92.9% 
([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10%2Fvit%2Dtiny%2D800%2D92%2E9%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10)) | 93 | | ViT-S/2 | 512 | 800 | 98.1% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10%2Fvit%2Dsmall%2D800%2D98%2E1%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10)) | 96.0% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10%2Fvit%2Dsmall%2D800%2D96%2E0%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10)) | 94.6% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10%2Fvit%2Dsmall%2D800%2D94%2E6%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10)) | 94 | | ViT-B/2 | 512 | 800 | 98.3% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10%2Fvit%2Dbase%2D800%2D98%2E3%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar10)) | 96.6% 
([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10%2Fvit%2Dbase%2D800%2D96%2E6%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar10)) | 95.8% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10%2Fvit%2Dbase%2D800%2D95%2E8%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar10)) | 95 | 96 | 97 | 98 | #### CIFAR100 99 | 100 | | Arch | Batch size | #Pre-Epoch | Finetuning Accuracy | Linear Probing Accuracy | kNN Accuracy | 101 | | :-----: | :--------: | :--------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | 102 | | ViT-T/2 | 512 | 800 | 84.9% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100%2Fvit%2Dtiny%2D800%2D84%2E6%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100)) | 74.7% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100%2Fvit%2Dtiny%2D800%2D74%2E7%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100)) | 68.8% 
([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100%2Fvit%2Dtiny%2D800%2D68%2E8%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100)) | 103 | | ViT-S/2 | 512 | 800 | 86.0% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100%2Fvit%2Dsmall%2D800%2D86%2E0%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100)) | 78.7% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100%2Fvit%2Dsmall%2D800%2D78%2E7%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100)) | 75.4% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100%2Fvit%2Dsmall%2D800%2D75%2E4%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100)) | 104 | | ViT-B/2 | 512 | 800 | 86.0% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100%2Fvit%2Dbase%2D800%2D86%2E0%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Ffinetune%2Fcifar100)) | 79.7% 
([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100%2Fvit%2Dbase%2D800%2D79%2E7%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Flinear%2Fcifar100)) | 75.7% ([link](https://csueducn-my.sharepoint.com/personal/221258_csu_edu_cn/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100%2Fvit%2Dbase%2D800%2D75%2E7%2Epth&parent=%2Fpersonal%2F221258%5Fcsu%5Fedu%5Fcn%2FDocuments%2FOpenSource%2Fpatchmix%5Fweights%2Fpretrain%2Fcifar100)) | 105 | 106 | 107 | 108 | ### The Visualization of Inter-Instance Similarities 109 | 110 | ![visualization](./images/visualization.png) 111 | 112 | The query sample and the key sample with id 4 belong to the same category, while the key samples with ids 3 and 5 come from categories similar to that of the query sample. 113 | 114 | ### License 115 | 116 | This project is released under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 
117 | 118 | ### Citation 119 | 120 | ```bibtex 121 | @article{shen2023inter, 122 | author = {Shen, Chengchao and Liu, Dawei and Tang, Hao and Qu, Zhe and Wang, Jianxin}, 123 | title = {Inter-Instance Similarity Modeling for Contrastive Learning}, 124 | journal = {arXiv preprint arXiv:2306.12243}, 125 | year = {2023}, 126 | } 127 | ``` 128 | 129 | -------------------------------------------------------------------------------- /config/finetune/vit_base_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_base_finetune(): # Build the fine-tuning config (argparse.Namespace) for ViT-Base. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'imagenet1k' 9 | args.arch = 'vit-base' 10 | args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = None 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | args.distributed = True # NOTE(review): overwritten by the unconditional `args.distributed = False` near the end of this function 27 | else: 28 | args.num_workers = 4 29 | args.prefetch_factor = 2 30 | args.pin_memory = True 31 | args.patch_size = 2 32 | args.input_size = 32 33 | args.batch_size = 256 34 | args.data_root = '/path/to/dataset' 35 | args.distributed = False 36 | 37 | args.encoder = 'momentum_encoder' # [base_encoder,momentum_encoder] 38 | 39 | # ---ema---------- 40 | args.model_ema = True 41 | args.model_ema_decay = 0.99996 42 | args.model_ema_force_cpu = False 43 | args.drop_path = 0.1 44 | 45 | # Optimizer parameters 46 | args.opt = 'adamw' 47 | args.opt_eps = 1e-8 48 | args.opt_betas = None 49 | args.clip_grad = None 50 | args.momentum = 0.9 51 | 52 | # Learning rate schedule parameters 53 | args.sched = 'cosine' 54 | 55 | if args.dataset == 'cifar10': 56 | args.lr = 5e-4 57 | args.warmup_lr = 1e-6 58 | args.min_lr = 
1e-5 59 | args.weight_decay = 0.05 60 | elif args.dataset == 'cifar100': 61 | args.lr = 5e-4 62 | args.warmup_lr = 1e-6 63 | args.min_lr = 1e-5 64 | args.weight_decay = 0.05 65 | elif args.dataset == 'imagenet1k': 66 | args.lr = 5e-4 67 | args.warmup_lr = 1e-6 68 | args.min_lr = 1e-5 69 | args.weight_decay = 0.05 # NOTE(review): any other args.dataset leaves lr/warmup_lr/min_lr/weight_decay unset 70 | 71 | # learning schedule parameters 72 | args.layer_decay = 0.75 73 | args.lr_noise = None 74 | args.lr_noise_pct = 0.67 75 | args.lr_noise_std = 1.0 76 | args.decay_epochs = 30 77 | args.warmup_epochs = 10 78 | args.cooldown_epochs = 10 79 | args.patience_epochs = 10 80 | args.decay_rate = 0.1 81 | 82 | # Augmentation parameters 83 | args.color_jitter = 0.4 84 | args.aa = 'rand-m9-mstd0.5-inc1' 85 | args.smoothing = 0.1 86 | args.train_interpolation = 'bicubic' 87 | args.repeated_aug = True 88 | 89 | # Random Erase params 90 | args.reprob = 0.25 91 | args.remode = 'pixel' 92 | args.recount = 1 93 | args.resplit = False 94 | 95 | # Mixup params 96 | args.mixup = 0.8 97 | args.cutmix = 1.0 98 | args.cutmix_minmax = None # float 99 | args.mixup_prob = 1.0 100 | args.mixup_switch_prob = 0.5 101 | args.mixup_mode = 'batch' 102 | 103 | # ----------------# 104 | args.dist_url = 'tcp://localhost:12613' 105 | args.dist_backend = 'nccl' 106 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # side effect: mutates the process environment each time this config is built 107 | args.world_size = 1 108 | 109 | args.print_freq = 100 110 | args.save_freq = 20 111 | 112 | args.rank = 0 113 | args.distributed = False # NOTE(review): unconditionally overrides the per-dataset value set above (True for imagenet1k); confirm intended 114 | args.gpu = None 115 | args.exclude_file_list = ['__pycache__', '.vscode', 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 116 | 117 | return args 118 | -------------------------------------------------------------------------------- /config/finetune/vit_small_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_small_finetune(): # Build the fine-tuning config (argparse.Namespace) for ViT-Small. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'imagenet1k' 9 | args.arch = 'vit-small' 10 | 
args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = None 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | args.distributed = True # NOTE(review): overwritten by the unconditional `args.distributed = False` near the end of this function 27 | else: 28 | args.num_workers = 4 29 | args.prefetch_factor = 2 30 | args.pin_memory = True 31 | args.patch_size = 2 32 | args.input_size = 32 33 | args.batch_size = 256 34 | args.data_root = '/path/to/dataset' 35 | args.distributed = False 36 | 37 | args.encoder = 'momentum_encoder' # [base_encoder,momentum_encoder] 38 | 39 | # ---ema---------- 40 | args.model_ema = True 41 | args.model_ema_decay = 0.99996 42 | args.model_ema_force_cpu = False 43 | args.drop_path = 0.1 44 | 45 | # Optimizer parameters 46 | args.opt = 'adamw' 47 | args.opt_eps = 1e-8 48 | args.opt_betas = None 49 | args.clip_grad = None 50 | args.momentum = 0.9 51 | 52 | # Learning rate schedule parameters 53 | args.sched = 'cosine' 54 | 55 | if args.dataset == 'cifar10': 56 | args.lr = 5e-4 57 | args.warmup_lr = 1e-6 58 | args.min_lr = 1e-5 59 | args.weight_decay = 0.05 60 | elif args.dataset == 'cifar100': 61 | args.lr = 5e-4 62 | args.warmup_lr = 1e-6 63 | args.min_lr = 1e-5 64 | args.weight_decay = 0.05 65 | elif args.dataset == 'imagenet1k': 66 | args.lr = 5e-4 67 | args.warmup_lr = 1e-6 68 | args.min_lr = 1e-5 69 | args.weight_decay = 0.05 # NOTE(review): any other args.dataset leaves lr/warmup_lr/min_lr/weight_decay unset 70 | 71 | # learning schedule parameters 72 | args.layer_decay = 0.75 73 | args.lr_noise = None 74 | args.lr_noise_pct = 0.67 75 | args.lr_noise_std = 1.0 76 | args.decay_epochs = 30 77 | args.warmup_epochs = 10 78 | args.cooldown_epochs = 10 79 | args.patience_epochs = 10 80 | args.decay_rate = 0.1 81 | 82 | # Augmentation parameters 83 | args.color_jitter = 0.4 84 | args.aa = 'rand-m9-mstd0.5-inc1' 85 
| args.smoothing = 0.1 86 | args.train_interpolation = 'bicubic' 87 | args.repeated_aug = True 88 | 89 | # Random Erase params 90 | args.reprob = 0.25 91 | args.remode = 'pixel' 92 | args.recount = 1 93 | args.resplit = False 94 | 95 | # Mixup params 96 | args.mixup = 0.8 97 | args.cutmix = 1.0 98 | args.cutmix_minmax = None # float 99 | args.mixup_prob = 1.0 100 | args.mixup_switch_prob = 0.5 101 | args.mixup_mode = 'batch' 102 | 103 | # ----------------# 104 | args.dist_url = 'tcp://localhost:12613' 105 | args.dist_backend = 'nccl' 106 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # side effect: mutates the process environment each time this config is built 107 | args.world_size = 1 108 | 109 | args.print_freq = 100 110 | args.save_freq = 20 111 | 112 | args.rank = 0 113 | args.distributed = False # NOTE(review): unconditionally overrides the per-dataset value set above (True for imagenet1k); confirm intended 114 | args.gpu = None 115 | args.exclude_file_list = ['__pycache__', '.vscode', 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 116 | 117 | return args 118 | -------------------------------------------------------------------------------- /config/finetune/vit_tiny_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_tiny_finetune(): # Build the fine-tuning config (argparse.Namespace) for ViT-Tiny. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'cifar100' 9 | args.arch = 'vit-tiny' 10 | args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = None 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | args.distributed = True # NOTE(review): overwritten by the unconditional `args.distributed = False` near the end of this function 27 | else: 28 | args.num_workers = 4 29 | args.prefetch_factor = 2 30 | args.pin_memory = True 31 | args.patch_size = 2 32 | args.input_size = 32 33 | args.batch_size = 256 34 | args.data_root = '/path/to/dataset' 35 | args.distributed = False 36 | 37 | args.encoder = 
'momentum_encoder' # [base_encoder,momentum_encoder] 38 | 39 | # ---ema---------- 40 | args.model_ema = True 41 | args.model_ema_decay = 0.99996 42 | args.model_ema_force_cpu = False 43 | args.drop_path = 0.1 44 | 45 | # Optimizer parameters 46 | args.opt = 'adamw' 47 | args.opt_eps = 1e-8 48 | args.opt_betas = None 49 | args.clip_grad = None 50 | args.momentum = 0.9 51 | 52 | # Learning rate schedule parameters 53 | args.sched = 'cosine' 54 | 55 | if args.dataset == 'cifar10': 56 | args.lr = 1e-3 57 | args.warmup_lr = 1e-6 58 | args.min_lr = 1e-5 59 | args.weight_decay = 0.05 60 | elif args.dataset == 'cifar100': 61 | args.lr = 1e-3 62 | args.warmup_lr = 1e-6 63 | args.min_lr = 1e-5 64 | args.weight_decay = 0.05 65 | elif args.dataset == 'imagenet1k': 66 | args.lr = 1e-3 67 | args.warmup_lr = 1e-6 68 | args.min_lr = 1e-5 69 | args.weight_decay = 0.05 # NOTE(review): any other args.dataset leaves lr/warmup_lr/min_lr/weight_decay unset 70 | 71 | # learning schedule parameters 72 | args.layer_decay = 0.75 73 | args.lr_noise = None 74 | args.lr_noise_pct = 0.67 75 | args.lr_noise_std = 1.0 76 | args.decay_epochs = 30 77 | args.warmup_epochs = 10 78 | args.cooldown_epochs = 10 79 | args.patience_epochs = 10 80 | args.decay_rate = 0.1 81 | 82 | # Augmentation parameters 83 | args.color_jitter = 0.4 84 | args.aa = 'rand-m9-mstd0.5-inc1' 85 | args.smoothing = 0.1 86 | args.train_interpolation = 'bicubic' 87 | args.repeated_aug = True 88 | 89 | # Random Erase params 90 | args.reprob = 0.25 91 | args.remode = 'pixel' 92 | args.recount = 1 93 | args.resplit = False 94 | 95 | # Mixup params 96 | args.mixup = 0.8 97 | args.cutmix = 1.0 98 | args.cutmix_minmax = None # float 99 | args.mixup_prob = 1.0 100 | args.mixup_switch_prob = 0.5 101 | args.mixup_mode = 'batch' 102 | 103 | # ----------------# 104 | args.dist_url = 'tcp://localhost:12613' 105 | args.dist_backend = 'nccl' 106 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # side effect: mutates the process environment each time this config is built 107 | args.world_size = 1 108 | 109 | args.print_freq = 100 110 | args.save_freq = 20 111 | 112 | args.rank = 0 113 | args.distributed 
= False # NOTE(review): unconditionally overrides the per-dataset value set above (True for imagenet1k); confirm intended 114 | args.gpu = None 115 | args.exclude_file_list = ['__pycache__', '.vscode', 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 116 | 117 | return args 118 | -------------------------------------------------------------------------------- /config/knn/knn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def knn(): # Build the kNN-evaluation config (argparse.Namespace). 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'cifar100' 9 | args.arch = '' # NOTE(review): empty by default; presumably must be set before use 10 | args.pretrained_weights = '' 11 | args.output_dir = './out' 12 | args.seed = 7 13 | 14 | if args.dataset == 'imagenet1k': 15 | args.num_workers = 12 16 | args.prefetch_factor = 3 17 | args.pin_memory = True 18 | args.patch_size = 16 19 | args.input_size = 224 20 | args.data_root = '/path/to/ILSVRC2012' 21 | args.batch_size_per_gpu = 1024 22 | elif args.dataset == 'cifar100': 23 | args.num_workers = 4 24 | args.prefetch_factor = 2 25 | args.pin_memory = True 26 | args.patch_size = 2 27 | args.input_size = 32 28 | args.data_root = '/path/to/dataset/cifar100' 29 | args.batch_size_per_gpu = 256 30 | else: # any dataset other than imagenet1k/cifar100 gets the CIFAR-10 settings 31 | args.num_workers = 4 32 | args.prefetch_factor = 2 33 | args.pin_memory = True 34 | args.patch_size = 2 35 | args.input_size = 32 36 | args.data_root = '/path/to/dataset/cifar10' 37 | args.batch_size_per_gpu = 256 38 | 39 | args.encoder = 'momentum_encoder' 40 | args.nb_knn = [10, 20, 100, 200] 41 | args.temperature = 0.07 42 | args.use_cuda = True 43 | args.dump_features = None 44 | args.load_features = None 45 | 46 | # ----------------# 47 | args.dist_url = 'tcp://localhost:12617' 48 | args.dist_backend = 'nccl' 49 | args.rank = 0 50 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' # side effect: mutates the process environment 51 | args.world_size = 1 52 | 53 | args.exclude_file_list = ['__pycache__', '.vscode', 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 54 | 55 | return args 56 | -------------------------------------------------------------------------------- /config/linear/vit_base_linear.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_base_linear(): # Build the linear-probing config (argparse.Namespace) for ViT-Base. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'imagenet1k' 9 | args.arch = 'vit-base' 10 | args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = None 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | else: 27 | args.num_workers = 4 28 | args.prefetch_factor = 2 29 | args.pin_memory = True 30 | args.patch_size = 2 31 | args.input_size = 32 32 | args.batch_size = 256 33 | args.data_root = '/path/to/dataset' 34 | 35 | args.encoder = 'momentum_encoder' # [base_encoder,momentum_encoder] 36 | 37 | # Optimizer parameters 38 | args.opt = 'sgd' 39 | args.opt_eps = 1e-8 40 | args.opt_betas = None 41 | args.clip_grad = None 42 | args.momentum = 0.9 43 | 44 | if args.dataset == 'cifar10': 45 | args.lr = 0.01 46 | args.weight_decay = 0.0 47 | elif args.dataset == 'cifar100': 48 | args.lr = 0.01 49 | args.weight_decay = 0.0 50 | elif args.dataset == 'imagenet1k': 51 | args.lr = 0.01 52 | args.weight_decay = 0.0 # NOTE(review): any other args.dataset leaves lr/weight_decay unset 53 | 54 | # ----------------# 55 | args.dist_url = 'tcp://localhost:12612' 56 | args.dist_backend = 'nccl' 57 | args.rank = 0 58 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' # side effect: mutates the process environment 59 | args.world_size = 1 60 | 61 | args.print_freq = 100 62 | args.save_freq = 20 63 | 64 | args.distributed = True # NOTE(review): enabled although only GPU '0' is exposed and world_size = 1; confirm how the launcher uses this 65 | args.gpu = None 66 | args.exclude_file_list = ['__pycache__', '.vscode', 67 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 68 | 69 | return args 70 | -------------------------------------------------------------------------------- /config/linear/vit_small_linear.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_small_linear(): # Build the linear-probing config (argparse.Namespace) for ViT-Small. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'imagenet1k' 9 | args.arch = 'vit-small' 10 | args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = '' # NOTE(review): '' (falsy) here but None in the other linear configs; confirm both mean 'no evaluation' 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | else: 27 | args.num_workers = 4 28 | args.prefetch_factor = 2 29 | args.pin_memory = True 30 | args.patch_size = 2 31 | args.input_size = 32 32 | args.batch_size = 256 33 | args.data_root = '/path/to/dataset' 34 | 35 | args.encoder = 'momentum_encoder' # [base_encoder,momentum_encoder] 36 | 37 | # Optimizer parameters 38 | args.opt = 'sgd' 39 | args.opt_eps = 1e-8 40 | args.opt_betas = None 41 | args.clip_grad = None 42 | args.momentum = 0.9 43 | 44 | if args.dataset == 'cifar10': 45 | args.lr = 0.02 46 | args.weight_decay = 0.0 47 | elif args.dataset == 'cifar100': 48 | args.lr = 0.02 49 | args.weight_decay = 0.0 50 | elif args.dataset == 'imagenet1k': 51 | args.lr = 0.02 52 | args.weight_decay = 0.0 # NOTE(review): any other args.dataset leaves lr/weight_decay unset 53 | 54 | # ----------------# 55 | args.dist_url = 'tcp://localhost:12612' 56 | args.dist_backend = 'nccl' 57 | args.rank = 0 58 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' # side effect: mutates the process environment 59 | args.world_size = 1 60 | 61 | args.print_freq = 100 62 | args.save_freq = 20 63 | 64 | args.distributed = True # NOTE(review): enabled although only GPU '0' is exposed and world_size = 1; confirm how the launcher uses this 65 | args.gpu = None 66 | args.exclude_file_list = ['__pycache__', '.vscode', 67 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 68 | 69 | return args 70 | -------------------------------------------------------------------------------- /config/linear/vit_tiny_linear.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_tiny_linear(): # Build the linear-probing config (argparse.Namespace) for ViT-Tiny. 6 | args = argparse.Namespace() 7 | 8 | args.dataset = 'cifar100' 9 | args.arch = 'vit-tiny' 10 | args.pretrained_weights = '' 11 | args.resume = None 12 | args.evaluate = None 13 | args.epochs = 100 14 | args.start_epoch = 0 15 | args.output_dir = './out' 16 | args.seed = 7 17 | 18 | if args.dataset == 'imagenet1k': 19 | args.num_workers = 12 20 | args.prefetch_factor = 3 21 | args.pin_memory = True 22 | args.patch_size = 16 23 | args.input_size = 224 24 | args.batch_size = 1024 25 | args.data_root = '/path/to/ILSVRC2012' 26 | else: 27 | args.num_workers = 4 28 | args.prefetch_factor = 2 29 | args.pin_memory = True 30 | args.patch_size = 2 31 | args.input_size = 32 32 | args.batch_size = 256 33 | args.data_root = '/path/to/dataset' 34 | 35 | args.encoder = 'momentum_encoder' # [base_encoder,momentum_encoder] 36 | 37 | # Optimizer parameters 38 | args.opt = 'sgd' 39 | args.opt_eps = 1e-8 40 | args.opt_betas = None 41 | args.clip_grad = None 42 | args.momentum = 0.9 43 | 44 | if args.dataset == 'cifar10': 45 | args.lr = 0.05 46 | args.weight_decay = 0.0 47 | elif args.dataset == 'cifar100': 48 | args.lr = 0.05 49 | args.weight_decay = 0.0 50 | elif args.dataset == 'imagenet1k': 51 | args.lr = 0.05 52 | args.weight_decay = 0.0 # NOTE(review): any other args.dataset leaves lr/weight_decay unset 53 | 54 | # ----------------# 55 | args.dist_url = 'tcp://localhost:12613' 56 | args.dist_backend = 'nccl' 57 | args.rank = 0 58 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' # side effect: mutates the process environment 59 | args.world_size = 1 60 | 61 | args.print_freq = 100 62 | args.save_freq = 20 63 | 64 | args.distributed = True # NOTE(review): enabled although only GPU '0' is exposed and world_size = 1; confirm how the launcher uses this 65 | args.gpu = None 66 | args.exclude_file_list = ['__pycache__', '.vscode', 67 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 68 | 69 | return args 70 | -------------------------------------------------------------------------------- /config/pretrain/vit_base_pretrain.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def vit_base_pretrain(): # Build the pre-training config (argparse.Namespace) for ViT-Base. 5 | args = argparse.Namespace() 6 | args.arch = 'vit-base' 7 | args.resume = None 8 | args.dataset = 'imagenet1k' 9 | args.seed = 7 10 | 11 | if args.dataset == 'imagenet1k': 12 | args.data_root = '/path/to/ILSVRC2012' 13 | args.input_size = 224 14 | args.patch_size = 16 15 | args.num_workers = 32 16 | args.prefetch_factor = 3 17 | args.pin_memory = True 18 | args.save_freq = 10 19 | args.epochs = 300 20 | args.batch_size = 1024 21 | args.warmup_epoch = 10 22 | args.multi_crop_size = 96 23 | # multi-crop params 24 | args.multi_crop_num = 8 # disable multi-crop when set to 0 25 | args.global_crop = 0.35 26 | args.mix_num = 2 27 | args.mix_size = args.patch_size 28 | args.smoothing = 0.0 29 | args.min_crop = 0.05 30 | elif args.dataset == 'cifar10' or args.dataset == 'cifar100': 31 | args.data_root = '/path/to/dataset' 32 | args.input_size = 32 33 | args.patch_size = 2 34 | args.num_workers = 8 35 | args.prefetch_factor = 2 36 | args.pin_memory = True 37 | args.save_freq = 100 38 | args.epochs = 800 39 | args.batch_size = 512 40 | args.warmup_epoch = 100 41 | args.multi_crop_size = 14 42 | # multi-crop params 43 | args.multi_crop_num = 0 # disable multi-crop when set to 0 44 | args.global_crop = 0.35 45 | args.mix_num = 2 46 | args.mix_size = args.patch_size 47 | args.smoothing = 0.0 48 | args.min_crop = 0.1 # NOTE(review): no else branch; any other args.dataset leaves data_root/patch_size/min_crop etc. unset and the exp_dir f-string below raises AttributeError 49 | 50 | args.drop_path = 0.1 51 | 52 | # lr params 53 | args.lr = 7.5e-4 54 | args.min_lr = 1e-6 55 | args.weight_decay = 0.04 56 | args.weight_decay_end = 0.4 57 | args.use_wd_cos = True 58 | if not args.use_wd_cos: 59 | args.weight_decay_end = args.weight_decay # keep weight decay constant (end == start) when use_wd_cos is off 60 | 61 | # moco params 62 | args.use_moco = True 63 | args.moco_m = 0.996 64 | args.moco_m_cos = True 65 | 66 | args.print_freq = None 67 | 68 | args.out_dim = 256 69 | args.hidden_dim = 4096 70 | args.proj_layer = 3 71 | args.pred_layer = 2 72 | args.temp = 0.2 73 | 
args.warmup_temp = 0.2 74 | args.warmup_temp_epochs = 30 75 | args.mix_p = 1.0 76 | 77 | args.exp_dir = f'./log/pretrain/{args.dataset}/ckpts_{args.arch}_p{args.patch_size}' \ 78 | f'_moco_{args.use_moco}_mm{args.moco_m}_min_crop{args.min_crop}' \ 79 | f'_t{args.temp}_lr{args.lr}_wd{args.weight_decay}' \ 80 | f'_bs{args.batch_size}_epoch{args.epochs}' \ 81 | f'_global_crop{args.global_crop}' \ 82 | f'_mc_n{args.multi_crop_num}_dp{args.drop_path}' # run directory name encodes the main hyperparameters 83 | args.rank = 0 84 | args.distributed = True 85 | args.use_mix_precision = True 86 | args.init_method = 'tcp://localhost:17991' 87 | args.world_size = 1 88 | 89 | args.exclude_file_list = ['__pycache__', '.vscode', 90 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 91 | 92 | return args 93 | -------------------------------------------------------------------------------- /config/pretrain/vit_small_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_small_pretrain(): # Build the pre-training config (argparse.Namespace) for ViT-Small. 6 | args = argparse.Namespace() 7 | args.arch = 'vit-small' 8 | args.resume = None 9 | args.dataset = 'imagenet1k' 10 | args.seed = 7 11 | 12 | if args.dataset == 'imagenet1k': 13 | args.data_root = '/path/to/ILSVRC2012' 14 | args.input_size = 224 15 | args.patch_size = 16 16 | args.num_workers = 32 17 | args.prefetch_factor = 3 18 | args.pin_memory = True 19 | args.save_freq = 10 20 | args.epochs = 300 21 | args.batch_size = 1024 22 | args.warmup_epoch = 10 23 | args.multi_crop_size = 96 24 | # multi-crop params 25 | args.multi_crop_num = 8 # disable multi-crop when set to 0 26 | args.global_crop = 0.35 27 | args.mix_num = 2 28 | args.mix_size = args.patch_size 29 | args.smoothing = 0.0 30 | args.min_crop = 0.05 31 | elif args.dataset == 'cifar10' or args.dataset == 'cifar100': 32 | args.data_root = '/path/to/dataset' 33 | args.input_size = 32 34 | args.patch_size = 2 35 | args.num_workers = 8 36 | args.prefetch_factor = 2 37 | args.pin_memory = True 38 | 
args.save_freq = 100 39 | args.epochs = 800 40 | args.batch_size = 512 41 | args.warmup_epoch = 100 42 | args.multi_crop_size = 14 43 | # multi-crop params 44 | args.multi_crop_num = 0 # disable multi-crop when set to 0 45 | args.global_crop = 0.35 46 | args.mix_num = 2 47 | args.mix_size = args.patch_size 48 | args.smoothing = 0.0 49 | args.min_crop = 0.1 # NOTE(review): no else branch; any other args.dataset leaves data_root/patch_size/min_crop etc. unset and the exp_dir f-string below raises AttributeError 50 | 51 | args.drop_path = 0.1 52 | 53 | # lr params 54 | args.lr = 5e-4 55 | args.min_lr = 1e-6 56 | args.weight_decay = 0.04 57 | args.weight_decay_end = 0.4 58 | args.use_wd_cos = True 59 | if not args.use_wd_cos: 60 | args.weight_decay_end = args.weight_decay # keep weight decay constant (end == start) when use_wd_cos is off 61 | 62 | # moco params 63 | args.use_moco = True 64 | args.moco_m = 0.996 65 | args.moco_m_cos = True 66 | 67 | args.print_freq = None 68 | 69 | args.out_dim = 256 70 | args.hidden_dim = 4096 71 | args.proj_layer = 3 72 | args.pred_layer = 2 73 | args.temp = 0.2 74 | args.warmup_temp = 0.2 75 | args.warmup_temp_epochs = 30 76 | args.mix_p = 1.0 77 | 78 | args.exp_dir = f'./log/pretrain/{args.dataset}/ckpts_{args.arch}_p{args.patch_size}' \ 79 | f'_moco_{args.use_moco}_mm{args.moco_m}_min_crop{args.min_crop}' \ 80 | f'_t{args.temp}_lr{args.lr}_wd{args.weight_decay}' \ 81 | f'_bs{args.batch_size}_epoch{args.epochs}' \ 82 | f'_global_crop{args.global_crop}' \ 83 | f'_mc_n{args.multi_crop_num}_dp{args.drop_path}' 84 | 85 | args.rank = 0 86 | args.distributed = True 87 | args.use_mix_precision = True 88 | args.init_method = 'tcp://localhost:17991' 89 | args.world_size = 1 90 | 91 | args.exclude_file_list = ['__pycache__', '.vscode', 92 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 93 | 94 | return args 95 | -------------------------------------------------------------------------------- /config/pretrain/vit_tiny_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def vit_tiny_pretrain(): # Build the pre-training config (argparse.Namespace) for ViT-Tiny. 6 | args = argparse.Namespace() 7 | args.arch = 'vit-tiny' 8 | args.resume = None 9 
| args.dataset = 'cifar100' 10 | args.seed = 7 11 | 12 | if args.dataset == 'imagenet1k': 13 | args.data_root = '/path/to/ILSVRC2012' 14 | args.input_size = 224 15 | args.patch_size = 16 16 | args.num_workers = 32 17 | args.prefetch_factor = 3 18 | args.pin_memory = True 19 | args.save_freq = 10 20 | args.epochs = 300 21 | args.batch_size = 1024 22 | args.warmup_epoch = 10 23 | args.multi_crop_size = 96 24 | # multi-crop params 25 | args.multi_crop_num = 8 # disable multi-crop when set to 0 26 | args.global_crop = 0.35 27 | args.mix_num = 2 28 | args.mix_size = args.patch_size 29 | args.smoothing = 0.0 30 | args.min_crop = 0.05 31 | 32 | elif args.dataset == 'cifar10' or args.dataset == 'cifar100': 33 | args.data_root = '/path/to/dataset' 34 | args.input_size = 32 35 | args.patch_size = 2 36 | args.num_workers = 8 37 | args.prefetch_factor = 2 38 | args.pin_memory = True 39 | args.save_freq = 100 40 | args.epochs = 800 41 | args.batch_size = 512 42 | args.warmup_epoch = 100 43 | args.multi_crop_size = 14 44 | # multi-crop params 45 | args.multi_crop_num = 0 # disable multi-crop when set to 0 46 | args.global_crop = 0.35 47 | args.mix_num = 2 48 | args.mix_size = args.patch_size 49 | args.smoothing = 0.0 50 | args.min_crop = 0.1 # NOTE(review): no else branch; any other args.dataset leaves data_root/patch_size/min_crop etc. unset and the exp_dir f-string below raises AttributeError 51 | 52 | args.drop_path = 0.1 53 | # lr params 54 | args.lr = 2e-3 55 | args.min_lr = 1e-6 56 | args.weight_decay = 0.02 57 | args.weight_decay_end = 0.2 58 | args.use_wd_cos = True 59 | if not args.use_wd_cos: 60 | args.weight_decay_end = args.weight_decay # keep weight decay constant (end == start) when use_wd_cos is off 61 | 62 | # moco params 63 | args.use_moco = False # NOTE(review): moco_m/moco_m_cos are still set below and moco_m still appears in exp_dir 64 | args.moco_m = 0.9 65 | args.moco_m_cos = True 66 | 67 | args.print_freq = None 68 | 69 | args.out_dim = 256 70 | args.hidden_dim = 4096 71 | args.proj_layer = 3 72 | args.pred_layer = 2 73 | args.temp = 0.2 74 | args.warmup_temp = 0.2 75 | args.warmup_temp_epochs = 30 76 | args.mix_p = 1.0 77 | 78 | args.exp_dir = f'./log/pretrain/{args.dataset}/ckpts_{args.arch}_p{args.patch_size}' \ 79 | 
f'_moco_{args.use_moco}_mm{args.moco_m}_min_crop{args.min_crop}' \ 80 | f'_t{args.temp}_lr{args.lr}_wd{args.weight_decay}' \ 81 | f'_bs{args.batch_size}_epoch{args.epochs}' \ 82 | f'_global_crop{args.global_crop}' \ 83 | f'_mc_n{args.multi_crop_num}_dp{args.drop_path}' 84 | 85 | args.rank = 0 86 | args.distributed = True 87 | args.use_mix_precision = True 88 | args.init_method = 'tcp://localhost:17991' 89 | args.world_size = 1 90 | 91 | args.exclude_file_list = ['__pycache__', '.vscode', 92 | 'log', 'ckpt', '.git', 'out', 'dataset', 'weight'] 93 | 94 | return args 95 | -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visresearch/patchmix/580adc846d6251a828c625c6a4de2190a49f630a/images/framework.png -------------------------------------------------------------------------------- /images/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/visresearch/patchmix/580adc846d6251a828c625c6a4de2190a49f630a/images/visualization.png -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import time 5 | from typing import Iterable, Optional 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import torch.nn as nn 12 | from timm.data import Mixup, create_transform 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.optim import create_optimizer 16 | from timm.scheduler import create_scheduler 17 | from timm.utils import NativeScaler, ModelEma, 
accuracy 18 | from torchvision import datasets, transforms 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | from config.finetune.vit_base_finetune import vit_base_finetune 21 | from config.finetune.vit_small_finetune import vit_small_finetune 22 | from config.finetune.vit_tiny_finetune import vit_tiny_finetune 23 | from module.vits import ViT 24 | from utils import misc 25 | from utils.logger import Logger, console_logger 26 | from utils.misc import AverageMeter 27 | 28 | 29 | def build_dataset(is_train, args): # Return (dataset, num_classes) for args.dataset; train split if is_train, else the val/test split. 30 | transform = build_transform(is_train, args) 31 | if args.dataset == 'cifar100': 32 | dataset = datasets.CIFAR100( 33 | args.data_root, train=is_train, transform=transform) 34 | nb_classes = 100 35 | elif args.dataset == 'cifar10': 36 | dataset = datasets.CIFAR10( 37 | args.data_root, train=is_train, transform=transform) 38 | nb_classes = 10 39 | elif args.dataset == 'imagenet1k': 40 | dataset = datasets.ImageFolder( 41 | root=os.path.join(args.data_root, 'train' if is_train else 'val'), transform=transform) 42 | nb_classes = 1000 # NOTE(review): no else branch; an unsupported args.dataset reaches the return with dataset/nb_classes unbound (UnboundLocalError) 43 | return dataset, nb_classes 44 | 45 | 46 | def build_transform(is_train, args): # Train: timm create_transform pipeline; eval: optional resize, center crop, ToTensor, ImageNet normalization. 47 | resize_im = args.input_size > 32 # inputs of 32x32 or smaller (CIFAR) are never resized 48 | if is_train: 49 | transform = create_transform( 50 | input_size=args.input_size, 51 | is_training=True, 52 | color_jitter=args.color_jitter, 53 | auto_augment=args.aa, 54 | interpolation=args.train_interpolation, 55 | re_prob=args.reprob, 56 | re_mode=args.remode, 57 | re_count=args.recount, 58 | ) 59 | if not resize_im: # small inputs: replace the first op of the timm pipeline with a padded RandomCrop 60 | transform.transforms[0] = transforms.RandomCrop( 61 | args.input_size, padding=4) 62 | return transform 63 | 64 | t = [] 65 | if resize_im: 66 | size = int((256 / 224) * args.input_size) # resize to 256/224 of the crop size before center-cropping 67 | t.append( 68 | transforms.Resize(size, interpolation=3), # NOTE(review): 3 is PIL's BICUBIC code; newer torchvision expects InterpolationMode.BICUBIC, confirm against the installed version 69 | ) 70 | t.append(transforms.CenterCrop(args.input_size)) 71 | 72 | t.append(transforms.ToTensor()) 73 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 74 | return 
transforms.Compose(t) 75 | 76 | 77 | def adjust_learning_rate(optimizer, init_lr, epoch, args): 78 | """Set every param group's lr to init_lr * 0.5 * (1 + cos(pi * epoch / args.epochs)), a half-cosine decay to 0.""" 79 | cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 80 | for param_group in optimizer.param_groups: 81 | param_group['lr'] = cur_lr 82 | 83 | 84 | def get_model_from_frame(checkpoint, args): # Re-key checkpoint['state_dict'] by stripping the '<encoder>.' prefix (plus 'module.' if present), skipping '<encoder>.head' keys. 85 | encoder = args.encoder 86 | state_dict = checkpoint['state_dict'] 87 | encoder = ('module.' if 'module' in list( 88 | state_dict.keys())[0] else '') + encoder # NOTE(review): substring test on the first key only; startswith('module.') would be stricter 89 | for k in list(state_dict.keys()): 90 | if k.startswith(encoder) and not k.startswith(encoder + '.head'): 91 | state_dict[k[len(encoder + "."):]] = state_dict[k] 92 | del state_dict[k] 93 | return state_dict 94 | 95 | 96 | def train_one_epoch(model: torch.nn.Module, criterion, # Run one epoch (AMP forward, loss_scaler backward/step, optional EMA update) and return the average loss. 97 | train_loader: Iterable, optimizer: torch.optim.Optimizer, 98 | epoch: int, loss_scaler, loggers, args, max_norm: float = 0, 99 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None 100 | ): 101 | model.train() 102 | logger_tb, logger_console = loggers 103 | 104 | data_time = AverageMeter('Data', ':6.3f') 105 | batch_time = AverageMeter('Time', ':6.3f') 106 | losses = AverageMeter('Loss', ':.4e') 107 | 108 | num_iter = len(train_loader) 109 | niter_global = epoch * num_iter # global iteration counter, used as the TensorBoard x-axis below 110 | end = time.time() 111 | 112 | for i, (samples, targets) in enumerate(train_loader): 113 | samples = samples.to(args.rank, non_blocking=True) # args.rank is used as the CUDA device index (presumably single-node; confirm) 114 | targets = targets.to(args.rank, non_blocking=True) 115 | data_time.update(time.time() - end) 116 | if mixup_fn is not None: 117 | samples, targets = mixup_fn(samples, targets) 118 | 119 | with torch.cuda.amp.autocast(): 120 | outputs = model(samples) 121 | loss = criterion(outputs, targets) 122 | 123 | losses.update(loss.item(), samples.size(0)) 124 | batch_time.update(time.time() - end) # NOTE(review): recorded before zero_grad/backward/optimizer step, so batch_time excludes the backward pass 125 | 126 | end = time.time() 127 | 128 | optimizer.zero_grad() 129 | is_second_order = hasattr( 130 | optimizer, 'is_second_order') and 
optimizer.is_second_order 131 | loss_scaler(loss, optimizer, clip_grad=max_norm, 132 | parameters=model.parameters(), create_graph=is_second_order) 133 | 134 | torch.cuda.synchronize() 135 | if model_ema is not None: 136 | model_ema.update(model) 137 | 138 | niter_global += 1 139 | if args.rank == 0: 140 | logger_tb.add_scalar('Finetune/Iter/loss', 141 | losses.val, niter_global) 142 | 143 | if (i + 1) % args.print_freq == 0 and logger_console is not None and args.rank == 0: 144 | lr = optimizer.param_groups[0]['lr'] 145 | logger_console.info(f'Epoch [{epoch}][{i + 1}/{num_iter}] - ' 146 | f'data_time: {data_time.avg:.3f}, ' 147 | f'batch_time: {batch_time.avg:.3f}, ' 148 | f'lr: {lr:.5f}, ' 149 | f'loss: {losses.val:.3f}({losses.avg:.3f})') 150 | if args.distributed: 151 | losses.synchronize_between_processes() 152 | 153 | return losses.avg 154 | 155 | 156 | @torch.no_grad() 157 | def evaluate(data_loader, model, args): # Return mean top-1 accuracy over data_loader (synchronized via AverageMeter when distributed). 158 | accs = AverageMeter('Acc@1', ':6.2f') 159 | 160 | model.eval() 161 | 162 | for i, (images, target) in enumerate(data_loader): 163 | images = images.to(args.rank, non_blocking=True) 164 | target = target.to(args.rank, non_blocking=True, dtype=torch.long) 165 | 166 | with torch.cuda.amp.autocast(): 167 | output = model(images) 168 | 169 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) # acc5 is computed but not used 170 | 171 | batch_size = images.shape[0] 172 | accs.update(acc1.item(), batch_size) 173 | if args.distributed: 174 | accs.synchronize_between_processes() 175 | 176 | return accs.avg 177 | 178 | 179 | def main_ddp(args): # Spawn one worker process per rank when distributed, otherwise run main() in-process. 180 | if args.distributed: 181 | ngpus_per_node = args.ngpus_per_node 182 | args.world_size = args.world_size * ngpus_per_node 183 | mp.spawn(main, args=(args,), nprocs=args.world_size) # NOTE(review): nprocs is the global world_size; this equals ngpus_per_node only when args.world_size starts at 1 (single node, as in the configs) 184 | else: 185 | main(args.rank, args) 186 | 187 | 188 | def main(rank, args): 189 | args.rank = rank 190 | if args.distributed: 191 | dist.init_process_group( 192 | backend="nccl", 193 | init_method=args.dist_url, 194 | world_size=args.world_size, 195 | 
rank=args.rank, 196 | ) 197 | 198 | misc.fix_random_seeds(args.seed) 199 | 200 | cudnn.benchmark = True 201 | if not args.evaluate: 202 | if args.rank == 0: 203 | for k, v in sorted(vars(args).items()): 204 | print(k, '=', v) 205 | name = str(args.arch) + "_" + str(args.dataset) + \ 206 | "_epochs_" + str(args.epochs) + "_lr_" + str(args.lr) 207 | logger_tb = Logger(args.output_dir, name) 208 | logger_console = console_logger(logger_tb.log_dir, 'console_eval') 209 | dst_dir = os.path.join(logger_tb.log_dir, 'code/') 210 | else: 211 | logger_tb, logger_console = None, None 212 | if args.rank == 0: 213 | path_save = os.path.join(args.output_dir, logger_tb.log_name) 214 | 215 | dataset_train, num_class = build_dataset(is_train=True, args=args) 216 | dataset_val, _ = build_dataset(is_train=False, args=args) 217 | 218 | if args.distributed: 219 | num_tasks = args.world_size 220 | global_rank = args.rank 221 | sampler_train = torch.utils.data.DistributedSampler( 222 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 223 | ) 224 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 225 | args.num_workers = int((args.num_workers + 1) / args.world_size) 226 | else: 227 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 228 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 229 | 230 | mixup_fn = None 231 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 232 | if mixup_active: 233 | mixup_fn = Mixup( 234 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 235 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 236 | label_smoothing=args.smoothing, num_classes=num_class) 237 | 238 | if args.arch == 'vit-tiny': 239 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 240 | embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, drop_path_rate=args.drop_path) 241 | elif args.arch == 'vit-small': 242 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 243 | embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, drop_path_rate=args.drop_path) 244 | elif args.arch == 'vit-base': 245 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 246 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, drop_path_rate=args.drop_path) 247 | 248 | if args.pretrained_weights: 249 | if os.path.isfile(args.pretrained_weights): 250 | print("=> loading checkpoint '{}'".format(args.pretrained_weights)) 251 | checkpoint = torch.load( 252 | args.pretrained_weights, map_location=torch.device(args.rank)) 253 | 254 | state_dict = get_model_from_frame(checkpoint, args) 255 | 256 | args.start_epoch = 0 257 | msg = model.load_state_dict(state_dict, strict=False) 258 | assert set(msg.missing_keys) == {"head.weight", "head.bias"} 259 | 260 | print("=> loaded pre-trained model '{}'".format(args.pretrained_weights)) 261 | else: 262 | print("=> no checkpoint found at '{}'".format( 263 | args.pretrained_weights)) 264 | 265 | model.head = nn.Linear(model.head.in_features, num_class) 266 | 267 | model.cuda(args.rank) 268 | model_ema = None 269 | if args.model_ema: 270 | model_ema = ModelEma( 271 | model, 272 | decay=args.model_ema_decay, 273 | device='cpu' if args.model_ema_force_cpu else '', 274 | resume='') 275 | 276 | model_without_ddp = model 277 | if args.distributed: 278 | model = 
torch.nn.parallel.DistributedDataParallel( 279 | model, device_ids=[args.rank]) 280 | torch.cuda.set_device(args.rank) 281 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 282 | args.batch_size = int(args.batch_size / args.world_size) 283 | model_without_ddp = model.module 284 | 285 | if args.distributed: 286 | args.lr = args.lr * args.batch_size * args.world_size / 256 287 | else: 288 | args.lr = args.lr * args.batch_size / 256 289 | 290 | optimizer = create_optimizer(args, model_without_ddp) 291 | 292 | loss_scaler = NativeScaler() 293 | 294 | lr_scheduler, _ = create_scheduler(args, optimizer) 295 | 296 | if args.mixup > 0.: 297 | criterion = SoftTargetCrossEntropy() 298 | elif args.smoothing: 299 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 300 | else: 301 | criterion = torch.nn.CrossEntropyLoss() 302 | 303 | data_loader_train = torch.utils.data.DataLoader( 304 | dataset_train, sampler=sampler_train, 305 | batch_size=args.batch_size, 306 | num_workers=args.num_workers, 307 | pin_memory=args.pin_memory, 308 | prefetch_factor=args.prefetch_factor, 309 | drop_last=True, 310 | ) 311 | 312 | data_loader_val = torch.utils.data.DataLoader( 313 | dataset_val, sampler=sampler_val, 314 | batch_size=int(1.5 * args.batch_size), 315 | num_workers=args.num_workers, 316 | pin_memory=args.pin_memory, 317 | prefetch_factor=args.prefetch_factor, 318 | drop_last=False 319 | ) 320 | 321 | acc_best = 0.0 322 | if args.resume: 323 | if os.path.isfile(args.resume): 324 | print("=> loading checkpoint '{}'".format(args.resume)) 325 | if args.gpu is None: 326 | checkpoint = torch.load(args.resume) 327 | else: 328 | loc = 'cuda:{}'.format(args.gpu) 329 | checkpoint = torch.load(args.resume, map_location=loc) 330 | args.start_epoch = checkpoint['epoch'] 331 | if args.gpu is not None: 332 | acc_best = acc_best.to(args.gpu) 333 | if isinstance(model, DDP): 334 | model.module.load_state_dict(checkpoint['state_dict']) 335 | else: 336 | 
model.load_state_dict(checkpoint['state_dict']) 337 | optimizer.load_state_dict(checkpoint['optimizer']) 338 | loss_scaler.load_state_dict(checkpoint['scaler']) 339 | print("=> loaded checkpoint '{}' (epoch {})" 340 | .format(args.resume, checkpoint['epoch'])) 341 | else: 342 | print("=> no checkpoint found at '{}'".format(args.resume)) 343 | 344 | if args.evaluate: 345 | if os.path.isfile(args.evaluate): 346 | print("=> loading checkpoint '{}'".format(args.evaluate)) 347 | model = torch.load( 348 | args.evaluate, map_location=torch.device(args.rank)) 349 | print("=> loaded pre-trained model '{}'".format(args.evaluate)) 350 | else: 351 | print("=> no checkpoint found at '{}'".format(args.evaluate)) 352 | acc = evaluate(data_loader_val, model, args) 353 | print('Acc :' + str(acc)) 354 | return 355 | 356 | print(f"Start training for {args.epochs} epochs") 357 | for epoch in range(args.start_epoch, args.epochs): 358 | if args.distributed: 359 | data_loader_train.sampler.set_epoch(epoch) 360 | 361 | loss = train_one_epoch( 362 | model, criterion, data_loader_train, 363 | optimizer, epoch, loss_scaler, (logger_tb, logger_console), args, 364 | args.clip_grad, model_ema, mixup_fn 365 | ) 366 | if args.rank == 0: 367 | logger_tb.add_scalar('Finetune/Epoch/loss', loss, epoch) 368 | 369 | state_dict = model.module.state_dict() if isinstance(model, DDP) else model.state_dict() 370 | if epoch % args.save_freq == 0 and args.rank == 0: 371 | torch.save( 372 | { 373 | 'epoch': epoch + 1, 374 | 'arch': args.arch, 375 | 'state_dict': state_dict, 376 | 'acc_best': acc_best, 377 | 'optimizer': optimizer.state_dict(), 378 | 'scaler': loss_scaler.state_dict(), 379 | }, 380 | f'{path_save}/{epoch:0>4d}.pth' 381 | ) 382 | lr_scheduler.step(epoch) 383 | acc = evaluate(data_loader_val, model, args) 384 | if args.rank == 0: 385 | logger_tb.add_scalar('Finetune/Epoch/Accuracy', acc, epoch) 386 | logger_console.info( 387 | f'Epoch: {epoch}, ' 388 | f'Accuracy: {acc}' 389 | ) 390 | 391 | if 
acc > acc_best: 392 | acc_best = acc 393 | epoch_best = epoch 394 | if args.rank == 0: 395 | torch.save( 396 | model_without_ddp, 397 | f'{path_save}/best.pth' 398 | ) 399 | 400 | if args.rank == 0: 401 | logger_console.info( 402 | f'Epoch: {epoch_best}, ' 403 | f'Best Accuracy: {acc_best}' 404 | ) 405 | if args.rank == 0: 406 | dst_dir = os.path.join(logger_tb.log_dir, str(acc_best) + '.acc') 407 | with open(dst_dir, 'w') as f: 408 | pass 409 | 410 | 411 | def parse_args(): 412 | parser = argparse.ArgumentParser() 413 | parser.add_argument("--arch", type=str, default='vit-small', 414 | choices=['vit-tiny', 'vit-small', 'vit-base']) 415 | parser.add_argument("--pretrained-weights", type=str, 416 | default='') 417 | parser.add_argument("--evaluate", type=str, default=None) 418 | return parser 419 | 420 | 421 | if __name__ == '__main__': 422 | parser = parse_args() 423 | _args = parser.parse_args() 424 | 425 | if _args.arch == 'vit-tiny': 426 | args = vit_tiny_finetune() 427 | elif _args.arch == 'vit-small': 428 | args = vit_small_finetune() 429 | elif _args.arch == 'vit-base': 430 | args = vit_base_finetune() 431 | args.pretrained_weights = _args.pretrained_weights 432 | args.evaluate = _args.evaluate 433 | main_ddp(args) 434 | -------------------------------------------------------------------------------- /main_knn.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from torch import nn 11 | from torchvision import datasets 12 | from torchvision import transforms 13 | from torchvision import transforms as pth_transforms 14 | 15 | from config.knn.knn import knn 16 | from module.vits import ViT 17 | from utils import misc 18 | 19 | 20 | class 
ImageFolderInstance(datasets.ImageFolder): 21 | def __getitem__(self, index): 22 | img, target = super(ImageFolderInstance, self).__getitem__(index) 23 | return img, target, index 24 | 25 | 26 | def build_dataset(is_train, args): 27 | transform = build_transform(args) 28 | dataset = ImageFolderInstance( 29 | root=os.path.join(args.data_root, 'train' if is_train else 'val'), transform=transform) 30 | return dataset 31 | 32 | 33 | def build_transform(args): 34 | return transforms.Compose([ 35 | pth_transforms.Resize(int(args.input_size / 224 * 256), interpolation=3), 36 | pth_transforms.CenterCrop(args.input_size), 37 | pth_transforms.ToTensor(), 38 | pth_transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 39 | ]) 40 | 41 | 42 | def get_model_from_frame(checkpoint, args): 43 | encoder = args.encoder 44 | state_dict = checkpoint['state_dict'] 45 | encoder = ('module.' if 'module' in list( 46 | state_dict.keys())[0] else '') + encoder 47 | for k in list(state_dict.keys()): 48 | if k.startswith(encoder) and not k.startswith(encoder + '.head'): 49 | state_dict[k[len(encoder + "."):]] = state_dict[k] 50 | del state_dict[k] 51 | return state_dict 52 | 53 | 54 | def eval_knn(rank, args): 55 | args.rank = rank 56 | dist.init_process_group( 57 | backend="nccl", 58 | init_method=args.dist_url, 59 | world_size=args.world_size, 60 | rank=args.rank, 61 | ) 62 | 63 | misc.fix_random_seeds(args.seed) 64 | 65 | cudnn.benchmark = True 66 | 67 | if args.load_features: 68 | try: 69 | print("loading features...") 70 | train_features = torch.load(os.path.join( 71 | args.load_features, "trainfeat.pth")) 72 | test_features = torch.load(os.path.join( 73 | args.load_features, "testfeat.pth")) 74 | train_labels = torch.load(os.path.join( 75 | args.load_features, "trainlabels.pth")) 76 | test_labels = torch.load(os.path.join( 77 | args.load_features, "testlabels.pth")) 78 | except: 79 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline( 80 | 
args) 81 | else: 82 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline( 83 | args) 84 | 85 | if args.rank == 0: 86 | if args.use_cuda: 87 | train_features = train_features.cuda() 88 | test_features = test_features.cuda() 89 | train_labels = train_labels.cuda() 90 | test_labels = test_labels.cuda() 91 | 92 | print("Features are ready!\nStart the k-NN classification.") 93 | for k in args.nb_knn: 94 | top1, top5 = knn_classifier(train_features, train_labels, 95 | test_features, test_labels, k, args.temperature, args.use_cuda) 96 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 97 | dist.barrier() 98 | 99 | 100 | def extract_feature_pipeline(args): 101 | dataset_train = build_dataset(is_train=True, args=args) 102 | dataset_val = build_dataset(is_train=False, args=args) 103 | 104 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 105 | 106 | data_loader_train = torch.utils.data.DataLoader( 107 | dataset_train, 108 | sampler=sampler, 109 | batch_size=args.batch_size_per_gpu, 110 | num_workers=args.num_workers, 111 | pin_memory=True, 112 | drop_last=False, 113 | ) 114 | data_loader_val = torch.utils.data.DataLoader( 115 | dataset_val, 116 | batch_size=args.batch_size_per_gpu, 117 | num_workers=args.num_workers, 118 | pin_memory=True, 119 | drop_last=False, 120 | ) 121 | print( 122 | f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 123 | 124 | if args.arch == 'vit-tiny': 125 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 126 | embed_dim=192, depth=12, num_heads=3, mlp_ratio=4) 127 | elif args.arch == 'vit-small': 128 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 129 | embed_dim=384, depth=12, num_heads=12, mlp_ratio=4) 130 | 131 | elif args.arch == 'vit-base': 132 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 133 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4) 134 | 135 | print(f"Model 
{args.arch} {args.patch_size}x{args.patch_size} built.") 136 | model.cuda() 137 | 138 | if args.pretrained_weights: 139 | if os.path.isfile(args.pretrained_weights): 140 | print("=> loading checkpoint '{}'".format(args.pretrained_weights)) 141 | checkpoint = torch.load( 142 | args.pretrained_weights, map_location=torch.device(args.rank)) 143 | 144 | state_dict = get_model_from_frame(checkpoint, args) 145 | 146 | args.start_epoch = 0 147 | msg = model.load_state_dict(state_dict, strict=False) 148 | assert set(msg.missing_keys) == {"head.weight", "head.bias"} 149 | 150 | print("=> loaded pre-trained model '{}'".format(args.pretrained_weights)) 151 | else: 152 | print("=> no checkpoint found at '{}'".format( 153 | args.pretrained_weights)) 154 | 155 | model.eval() 156 | 157 | print("Extracting features for train set...") 158 | train_features, train_labels = extract_features( 159 | model, data_loader_train, args) 160 | print("Extracting features for val set...") 161 | test_features, test_labels = extract_features( 162 | model, data_loader_val, args) 163 | 164 | if args.rank == 0: 165 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 166 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 167 | 168 | if args.dump_features and args.get_rank() == 0: 169 | print("Dumping features ...") 170 | torch.save(train_features.cpu(), os.path.join( 171 | args.dump_features, "trainfeat.pth")) 172 | torch.save(test_features.cpu(), os.path.join( 173 | args.dump_features, "testfeat.pth")) 174 | torch.save(train_labels.cpu(), os.path.join( 175 | args.dump_features, "trainlabels.pth")) 176 | torch.save(test_labels.cpu(), os.path.join( 177 | args.dump_features, "testlabels.pth")) 178 | return train_features, test_features, train_labels, test_labels 179 | 180 | 181 | @torch.no_grad() 182 | def extract_features(model, data_loader, args, multiscale=False): 183 | metric_logger = misc.MetricLogger(delimiter=" ") 184 | features = None 185 | labels = None 186 
| for samples, labs, index in metric_logger.log_every(data_loader, 10): 187 | samples = samples.cuda(non_blocking=True) 188 | labs = labs.cuda(non_blocking=True) 189 | index = index.cuda(non_blocking=True) 190 | 191 | def forward_single(samples): 192 | output = model(samples) 193 | return output 194 | 195 | if multiscale: 196 | v = None 197 | for s in [1, 1 / 2 ** (1 / 2), 1 / 2]: 198 | if s == 1: 199 | inp = samples.clone() 200 | else: 201 | inp = nn.functional.interpolate( 202 | samples, scale_factor=s, mode='bilinear', align_corners=False) 203 | feats = forward_single(inp) 204 | if v is None: 205 | v = feats 206 | else: 207 | v += feats 208 | v /= 3 209 | v /= v.norm() 210 | feats = v 211 | else: 212 | feats = forward_single(samples) 213 | 214 | if args.rank == 0 and features is None: 215 | features = torch.zeros( 216 | len(data_loader.dataset), feats.shape[-1]).to(feats.dtype) 217 | labels = torch.zeros(len(data_loader.dataset)).to(labs.dtype) 218 | if args.use_cuda: 219 | features = features.cuda(non_blocking=True) 220 | labels = labels.cuda(non_blocking=True) 221 | print(f"Storing features into tensor of shape {features.shape}") 222 | print(f"Storing labels into tensor of shape {labels.shape}") 223 | 224 | y_all = torch.empty(args.world_size, index.size( 225 | 0), dtype=index.dtype, device=index.device) 226 | y_l = list(y_all.unbind(0)) 227 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 228 | y_all_reduce.wait() 229 | index_all = torch.cat(y_l) 230 | 231 | feats_all = torch.empty( 232 | args.world_size, 233 | feats.size(0), 234 | feats.size(1), 235 | dtype=feats.dtype, 236 | device=feats.device, 237 | ) 238 | output_l = list(feats_all.unbind(0)) 239 | output_all_reduce = torch.distributed.all_gather( 240 | output_l, feats, async_op=True) 241 | output_all_reduce.wait() 242 | 243 | labels_all = torch.empty( 244 | args.world_size, 245 | labs.size(0), 246 | dtype=labs.dtype, 247 | device=labs.device, 248 | ) 249 | label_l = 
list(labels_all.unbind(0)) 250 | label_all_reduce = torch.distributed.all_gather( 251 | label_l, labs, async_op=True) 252 | label_all_reduce.wait() 253 | 254 | if args.rank == 0: 255 | if args.use_cuda: 256 | features.index_copy_(0, index_all, torch.cat(output_l)) 257 | labels.index_copy_(0, index_all, torch.cat(label_l)) 258 | else: 259 | features.index_copy_(0, index_all.cpu(), 260 | torch.cat(output_l).cpu()) 261 | labels.index_copy_(0, index_all.cpu(), 262 | torch.cat(label_l).cpu()) 263 | return features, labels 264 | 265 | 266 | @torch.no_grad() 267 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, use_cuda=True, num_classes=1000): 268 | top1, top5, total = 0.0, 0.0, 0 269 | train_features = train_features.t() 270 | num_test_images, num_chunks = test_labels.shape[0], 100 271 | imgs_per_chunk = num_test_images // num_chunks 272 | retrieval_one_hot = torch.zeros(k, num_classes) 273 | if use_cuda: 274 | retrieval_one_hot = retrieval_one_hot.cuda() 275 | for idx in range(0, num_test_images, imgs_per_chunk): 276 | features = test_features[ 277 | idx: min((idx + imgs_per_chunk), num_test_images), : 278 | ] 279 | targets = test_labels[idx: min( 280 | (idx + imgs_per_chunk), num_test_images)] 281 | batch_size = targets.shape[0] 282 | 283 | similarity = torch.mm(features, train_features) 284 | distances, indices = similarity.topk(k, largest=True, sorted=True) 285 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 286 | retrieved_neighbors = torch.gather(candidates, 1, indices) 287 | 288 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 289 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 290 | distances_transform = distances.clone().div_(T).exp_() 291 | probs = torch.sum( 292 | torch.mul( 293 | retrieval_one_hot.view(batch_size, -1, num_classes), 294 | distances_transform.view(batch_size, -1, 1), 295 | ), 296 | 1, 297 | ) 298 | _, predictions = probs.sort(1, True) 299 | 300 | correct = 
predictions.eq(targets.data.view(-1, 1)) 301 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 302 | top5 = top5 + correct.narrow(1, 0, 5).sum().item() 303 | total += targets.size(0) 304 | top1 = top1 * 100.0 / total 305 | top5 = top5 * 100.0 / total 306 | return top1, top5 307 | 308 | 309 | def main_ddp(args): 310 | ngpus_per_node = torch.cuda.device_count() 311 | args.world_size = args.world_size * ngpus_per_node 312 | mp.spawn(eval_knn, args=(args,), nprocs=args.world_size) 313 | 314 | 315 | def parse_args(): 316 | parser = argparse.ArgumentParser() 317 | parser.add_argument("--arch", type=str, default='vit-small', 318 | choices=['vit-tiny', 'vit-small', 'vit-base']) 319 | parser.add_argument("--pretrained-weights", type=str, 320 | default='') 321 | return parser 322 | 323 | 324 | if __name__ == '__main__': 325 | parser = parse_args() 326 | _args = parser.parse_args() 327 | args = knn() 328 | args.pretrained_weights = _args.pretrained_weights 329 | args.arch = _args.arch 330 | main_ddp(args) 331 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Iterable 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | import torch.nn as nn 10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 11 | from timm.optim import create_optimizer 12 | from timm.utils import NativeScaler, accuracy 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torchvision import datasets, transforms 15 | 16 | from config.linear.vit_base_linear import vit_base_linear 17 | from config.linear.vit_small_linear import vit_small_linear 18 | from config.linear.vit_tiny_linear import vit_tiny_linear 19 | from module.vits import ViT 20 | from utils import misc 21 | from 
utils.logger import Logger, console_logger 22 | from utils.misc import AverageMeter 23 | 24 | 25 | def build_dataset(is_train, args): 26 | transform = build_transform(is_train, args) 27 | 28 | if args.dataset == 'cifar100': 29 | dataset = datasets.CIFAR100( 30 | args.data_root, train=is_train, transform=transform) 31 | nb_classes = 100 32 | elif args.dataset == 'cifar10': 33 | dataset = datasets.CIFAR10( 34 | args.data_root, train=is_train, transform=transform) 35 | nb_classes = 10 36 | elif args.dataset == 'imagenet1k': 37 | dataset = datasets.ImageFolder( 38 | root=os.path.join(args.data_root, 'train' if is_train else 'val'), transform=transform) 39 | nb_classes = 1000 40 | return dataset, nb_classes 41 | 42 | 43 | def build_transform(is_train, args): 44 | if is_train: 45 | return transforms.Compose([ 46 | transforms.RandomResizedCrop(args.input_size, interpolation=3), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 50 | else: 51 | return transforms.Compose([ 52 | transforms.Resize( 53 | int(args.input_size / 224 * 256), interpolation=3), 54 | transforms.CenterCrop(args.input_size), 55 | transforms.ToTensor(), 56 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 57 | 58 | 59 | def get_model_from_frame(checkpoint, args): 60 | encoder = args.encoder 61 | state_dict = checkpoint['state_dict'] 62 | encoder = ('module.' if 'module' in list( 63 | state_dict.keys())[0] else '') + encoder 64 | for k in list(state_dict.keys()): 65 | if k.startswith(encoder) and not k.startswith(encoder + '.head'): 66 | state_dict[k[len(encoder + "."):]] = state_dict[k] 67 | del state_dict[k] 68 | return state_dict 69 | 70 | 71 | def sanity_check(state_dict, pretrained_weights, args): 72 | """ 73 | Linear classifier should not change any weights other than the linear layer. 74 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 
75 | """ 76 | encoder = args.encoder 77 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 78 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 79 | state_dict_pre = checkpoint['state_dict'] 80 | module_kw = 'module.' if 'module' in list(state_dict_pre.keys())[0] else '' 81 | for k in list(state_dict.keys()): 82 | if 'head.weight' in k or 'head.bias' in k: 83 | continue 84 | pre_k = module_kw + encoder + '.' + (k[len('module.'):] if 'module' in k else k) 85 | assert ((state_dict[k].cpu() == state_dict_pre[pre_k]).all()), \ 86 | '{} is changed in linear classifier training.'.format(k) 87 | 88 | print("=> sanity check passed.") 89 | 90 | 91 | def train_one_epoch(model: torch.nn.Module, criterion, 92 | train_loader: Iterable, optimizer: torch.optim.Optimizer, 93 | epoch: int, loss_scaler, loggers, args, max_norm: float = 0, 94 | ): 95 | model.eval() 96 | logger_tb, logger_console = loggers 97 | 98 | losses = AverageMeter('Loss', ':.4e') 99 | 100 | num_iter = len(train_loader) 101 | niter_global = epoch * num_iter 102 | 103 | for i, (images, targets) in enumerate(train_loader): 104 | images = images.to(args.rank, non_blocking=True) 105 | targets = targets.to(args.rank, non_blocking=True) 106 | 107 | with torch.cuda.amp.autocast(): 108 | outputs = model(images) 109 | loss = criterion(outputs, targets) 110 | 111 | losses.update(loss.item(), images.size(0)) 112 | 113 | optimizer.zero_grad() 114 | is_second_order = hasattr( 115 | optimizer, 'is_second_order') and optimizer.is_second_order 116 | loss_scaler(loss, optimizer, clip_grad=max_norm, 117 | parameters=model.parameters(), create_graph=is_second_order) 118 | 119 | torch.cuda.synchronize() 120 | 121 | niter_global += 1 122 | if args.rank == 0: 123 | logger_tb.add_scalar('Finetune/Iter/loss', 124 | losses.val, niter_global) 125 | 126 | if (i + 1) % args.print_freq == 0 and logger_console is not None and args.rank == 0: 127 | lr = optimizer.param_groups[0]['lr'] 128 | 
logger_console.info(f'Epoch [{epoch}][{i + 1}/{num_iter}] - ' 129 | f'lr: {lr:.5f}, ' 130 | f'loss: {losses.avg:.3f}') 131 | if args.pretrained_weights and epoch == args.start_epoch and args.rank == 0: 132 | sanity_check(model.state_dict(), args.pretrained_weights, args) 133 | if args.distributed: 134 | losses.synchronize_between_processes() 135 | 136 | return losses.avg 137 | 138 | 139 | @torch.no_grad() 140 | def evaluate(data_loader, model, args): 141 | accs1 = AverageMeter('Acc@1', ':6.2f') 142 | accs5 = AverageMeter('Acc@5', ':6.2f') 143 | 144 | model.eval() 145 | 146 | for i, (images, target) in enumerate(data_loader): 147 | images = images.to(args.rank, non_blocking=True) 148 | target = target.to(args.rank, non_blocking=True, dtype=torch.long) 149 | 150 | with torch.cuda.amp.autocast(): 151 | output = model(images) 152 | 153 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 154 | 155 | batch_size = images.shape[0] 156 | accs1.update(acc1.item(), batch_size) 157 | accs5.update(acc5.item(), batch_size) 158 | if args.distributed: 159 | accs1.synchronize_between_processes() 160 | accs5.synchronize_between_processes() 161 | 162 | return accs1.avg, accs5.avg 163 | 164 | 165 | def main_ddp(args): 166 | if args.distributed: 167 | ngpus_per_node = torch.cuda.device_count() 168 | args.world_size = args.world_size * ngpus_per_node 169 | mp.spawn(main, args=(args,), nprocs=args.world_size) 170 | else: 171 | main(args.rank, args) 172 | 173 | 174 | def main(rank, args): 175 | args.rank = rank 176 | if args.distributed: 177 | dist.init_process_group( 178 | backend="nccl", 179 | init_method=args.dist_url, 180 | world_size=args.world_size, 181 | rank=args.rank, 182 | ) 183 | 184 | misc.fix_random_seeds(args.seed) 185 | 186 | cudnn.benchmark = True 187 | if not args.evaluate: 188 | if args.rank == 0: 189 | for k, v in sorted(vars(args).items()): 190 | print(k, '=', v) 191 | name = str(args.arch) + "_" + str(args.dataset) + \ 192 | "_epochs_" + str(args.epochs) + "_lr_" + 
str(args.lr) 193 | logger_tb = Logger(args.output_dir, name) 194 | logger_console = console_logger(logger_tb.log_dir, 'console_eval') 195 | dst_dir = os.path.join(logger_tb.log_dir, 'code/') 196 | else: 197 | logger_tb, logger_console = None, None 198 | if args.rank == 0: 199 | path_save = os.path.join(args.output_dir, logger_tb.log_name) 200 | 201 | dataset_train, num_class = build_dataset(is_train=True, args=args) 202 | dataset_val, _ = build_dataset(is_train=False, args=args) 203 | 204 | if args.distributed: 205 | num_tasks = args.world_size 206 | global_rank = args.rank 207 | sampler_train = torch.utils.data.DistributedSampler( 208 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 209 | ) 210 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 211 | args.num_workers = int((args.num_workers + 1) / args.world_size) 212 | else: 213 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 214 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 215 | 216 | if args.arch == 'vit-tiny': 217 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 218 | embed_dim=192, depth=12, num_heads=3, mlp_ratio=4) 219 | elif args.arch == 'vit-small': 220 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 221 | embed_dim=384, depth=12, num_heads=12, mlp_ratio=4) 222 | 223 | elif args.arch == 'vit-base': 224 | model = ViT(patch_size=args.patch_size, img_size=args.input_size, 225 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4) 226 | 227 | if args.pretrained_weights: 228 | if os.path.isfile(args.pretrained_weights): 229 | print("=> loading checkpoint '{}'".format(args.pretrained_weights)) 230 | checkpoint = torch.load( 231 | args.pretrained_weights, map_location=torch.device(args.rank)) 232 | 233 | state_dict = get_model_from_frame(checkpoint, args) 234 | 235 | args.start_epoch = 0 236 | msg = model.load_state_dict(state_dict, strict=False) 237 | print(msg.missing_keys) 238 | assert 
set(msg.missing_keys) == {"head.weight", "head.bias"} 239 | 240 | print("=> loaded pre-trained model '{}'".format(args.pretrained_weights)) 241 | else: 242 | print("=> no checkpoint found at '{}'".format( 243 | args.pretrained_weights)) 244 | 245 | model.head = nn.Linear(model.head.in_features, num_class) 246 | 247 | for name, param in model.named_parameters(): 248 | if name not in ['head.weight', 'head.bias']: 249 | param.requires_grad = False 250 | model.cuda(args.rank) 251 | 252 | model_without_ddp = model 253 | if args.distributed: 254 | model = torch.nn.parallel.DistributedDataParallel( 255 | model, device_ids=[args.rank]) 256 | torch.cuda.set_device(args.rank) 257 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 258 | args.batch_size = int(args.batch_size / args.world_size) 259 | model_without_ddp = model.module 260 | 261 | if args.distributed: 262 | args.lr = args.lr * args.batch_size * args.world_size / 256 263 | else: 264 | args.lr = args.lr * args.batch_size / 256 265 | 266 | optimizer = create_optimizer(args, model_without_ddp) 267 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 268 | optimizer, args.epochs, eta_min=0) 269 | 270 | loss_scaler = NativeScaler() 271 | criterion = torch.nn.CrossEntropyLoss() 272 | 273 | data_loader_train = torch.utils.data.DataLoader( 274 | dataset_train, sampler=sampler_train, 275 | batch_size=args.batch_size, 276 | num_workers=args.num_workers, 277 | pin_memory=args.pin_memory, 278 | prefetch_factor=args.prefetch_factor, 279 | drop_last=True, 280 | ) 281 | 282 | data_loader_val = torch.utils.data.DataLoader( 283 | dataset_val, sampler=sampler_val, 284 | batch_size=int(1.5 * args.batch_size), 285 | num_workers=args.num_workers, 286 | pin_memory=args.pin_memory, 287 | prefetch_factor=args.prefetch_factor, 288 | drop_last=False 289 | ) 290 | 291 | acc_best = 0.0 292 | if args.resume: 293 | if os.path.isfile(args.resume): 294 | print("=> loading checkpoint '{}'".format(args.resume)) 295 | if args.gpu is 
None: 296 | checkpoint = torch.load(args.resume) 297 | else: 298 | loc = 'cuda:{}'.format(args.gpu) 299 | checkpoint = torch.load(args.resume, map_location=loc) 300 | args.start_epoch = checkpoint['epoch'] 301 | acc_best = checkpoint['acc_best'] 302 | if args.gpu is not None: 303 | acc_best = acc_best.to(args.gpu) 304 | if isinstance(model, DDP): 305 | model.module.load_state_dict(checkpoint['state_dict']) 306 | else: 307 | model.load_state_dict(checkpoint['state_dict']) 308 | optimizer.load_state_dict(checkpoint['optimizer']) 309 | loss_scaler.load_state_dict(checkpoint['scaler']) 310 | print("=> loaded checkpoint '{}' (epoch {})" 311 | .format(args.resume, checkpoint['epoch'])) 312 | else: 313 | print("=> no checkpoint found at '{}'".format(args.resume)) 314 | 315 | if args.evaluate: 316 | if os.path.isfile(args.evaluate): 317 | print("=> loading checkpoint '{}'".format(args.evaluate)) 318 | model = torch.load( 319 | args.evaluate, map_location=torch.device(args.rank)) 320 | print("=> loaded pre-trained model '{}'".format(args.evaluate)) 321 | else: 322 | print("=> no checkpoint found at '{}'".format(args.evaluate)) 323 | acc1, acc5 = evaluate(data_loader_val, model, args) 324 | print('Acc1 :' + str(acc1) + '\tAcc5 :' + str(acc5)) 325 | return 326 | 327 | print(f"Start training for {args.epochs} epochs") 328 | for epoch in range(args.start_epoch, args.epochs): 329 | if args.distributed: 330 | data_loader_train.sampler.set_epoch(epoch) 331 | loss = train_one_epoch( 332 | model, criterion, data_loader_train, 333 | optimizer, epoch, loss_scaler, (logger_tb, logger_console), args, 334 | args.clip_grad 335 | ) 336 | if args.rank == 0: 337 | logger_tb.add_scalar('Finetune/Epoch/loss', loss, epoch) 338 | 339 | state_dict = model.module.state_dict() if isinstance(model, DDP) else model.state_dict() 340 | if epoch % args.save_freq == 0 and args.rank == 0: 341 | torch.save( 342 | { 343 | 'epoch': epoch + 1, 344 | 'arch': args.arch, 345 | 'state_dict': state_dict, 346 | 
'acc_best': acc_best, 347 | 'optimizer': optimizer.state_dict(), 348 | 'scaler': loss_scaler.state_dict(), 349 | }, 350 | f'{path_save}/{epoch:0>4d}.pth' 351 | ) 352 | 353 | lr_scheduler.step(epoch) 354 | acc1, acc5 = evaluate(data_loader_val, model, args) 355 | if args.rank == 0: 356 | logger_tb.add_scalar('Finetune/Epoch/Accuracy', acc1, epoch) 357 | logger_console.info( 358 | f'Epoch: {epoch}\t' 359 | f'Acc1: {round(acc1, 3)}\t' 360 | f'Acc5: {round(acc5, 3)}' 361 | ) 362 | 363 | if acc1 > acc_best: 364 | acc_best = acc1 365 | epoch_best = epoch 366 | if args.rank == 0: 367 | torch.save( 368 | model_without_ddp, 369 | f'{path_save}/best.pth' 370 | ) 371 | 372 | if args.rank == 0: 373 | logger_console.info( 374 | f'Epoch: {epoch_best}, ' 375 | f'Best Acc1: {round(acc_best, 3)}' 376 | ) 377 | 378 | 379 | def parse_args(): 380 | parser = argparse.ArgumentParser() 381 | parser.add_argument("--arch", type=str, default='vit-small', 382 | choices=['vit-tiny', 'vit-small', 'vit-base']) 383 | parser.add_argument("--pretrained-weights", type=str, 384 | default='') 385 | parser.add_argument("--evaluate", type=str, default=None) 386 | return parser 387 | 388 | 389 | if __name__ == '__main__': 390 | parser = parse_args() 391 | _args = parser.parse_args() 392 | 393 | if _args.arch == 'vit-tiny': 394 | args = vit_tiny_linear() 395 | elif _args.arch == 'vit-small': 396 | args = vit_small_linear() 397 | elif _args.arch == 'vit-base': 398 | args = vit_base_linear() 399 | args.pretrained_weights = _args.pretrained_weights 400 | args.evaluate = _args.evaluate 401 | main_ddp(args) 402 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | import torch.nn as nn 9 | from torch.cuda.amp import GradScaler 
10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets 13 | 14 | from config.pretrain.vit_base_pretrain import vit_base_pretrain 15 | from config.pretrain.vit_small_pretrain import vit_small_pretrain 16 | from config.pretrain.vit_tiny_pretrain import vit_tiny_pretrain 17 | from module.augmentation import TwoCropsTransform, MultiCropsTransform 18 | from module.frame.contrast_momentum import ContrastMomentum_ViT 19 | from module.frame.contrast_no_momentum import ContrastNoMomentum_ViT 20 | from module.loss import MultiTempContrastiveLoss 21 | from module.mix import PatchMixer 22 | from module.vits import ViT 23 | from utils import misc 24 | from utils.logger import Logger, console_logger 25 | from utils.misc import AverageMeter, adjust_moco_momentum 26 | 27 | 28 | def train_epoch(train_loader, model, criterion, optimizer, lr_schedule, wd_schedule, temp_schedule, mixer, scaler, 29 | loggers, epoch, args): 30 | model.train() 31 | logger_tb, logger_console = loggers 32 | 33 | src_losses = AverageMeter('Src Loss', ':.4e') 34 | mix_losses = AverageMeter('Mix Loss', ':.4e') 35 | mix2_losses = AverageMeter('Mix Mix Loss', ':.4e') 36 | multi_losses = AverageMeter('Multi Loss', ':.4e') 37 | multi_mix_losses = AverageMeter('Multi Mix Loss', ':.4e') 38 | losses = AverageMeter('Loss', ':.4e') 39 | learning_rates = AverageMeter('LR', ':.4e') 40 | weight_decays = AverageMeter('WD', ':.4e') 41 | 42 | num_iter = len(train_loader) 43 | niter_global = epoch * num_iter 44 | no_mixer = PatchMixer(mix_s=args.mix_size, 45 | num_classes=int(args.batch_size * args.world_size), mix_p=0.0, mix_n=1) 46 | 47 | moco_m = args.moco_m 48 | temp = temp_schedule[epoch] 49 | for i, (images, _) in enumerate(train_loader): 50 | # update weight decay and learning rate according to their schedule 51 | it = num_iter * epoch + i # global training iteration 52 | for i, param_group in 
enumerate(optimizer.param_groups): 53 | param_group["lr"] = lr_schedule[it] 54 | if i == 0 and args.use_wd_cos: # only the first group is regularized 55 | param_group["weight_decay"] = wd_schedule[it] 56 | 57 | if args.use_moco: 58 | if args.moco_m_cos: 59 | moco_m = adjust_moco_momentum(epoch + i / num_iter, args) 60 | with torch.no_grad(): # no gradient 61 | model.module.update_momentum_encoder(moco_m) 62 | 63 | images[0] = images[0].cuda(args.rank, non_blocking=True) 64 | images[1] = images[1].cuda(args.rank, non_blocking=True) 65 | if args.multi_crop_num != 0: 66 | for id in range(args.multi_crop_num): 67 | images[2][id] = images[2][id].cuda( 68 | args.rank, non_blocking=True) 69 | 70 | N = images[0].size(0) 71 | target = torch.arange(N, dtype=torch.long).cuda() 72 | optimizer.zero_grad() 73 | mix_image1, mix_target, mix2_target = mixer(images[0], target) 74 | mix_image2, mix_target, mix2_target = mixer(images[1], target) 75 | images[0], target0, _ = no_mixer(images[0], target) 76 | images[1], target0, _ = no_mixer(images[1], target) 77 | 78 | 79 | with torch.cuda.amp.autocast(True): 80 | q1, k1 = model(images[0]) 81 | _, k2 = model(images[1], q=False) 82 | _, m_k1 = model(mix_image1, q=False) 83 | m_q2, _ = model(mix_image2, k=False) 84 | src_loss = criterion(q1, k2.detach(), target0, temp) 85 | mix_loss = criterion(m_q2, k1.detach(), mix_target, temp) / 2.0 86 | mix2_loss = criterion(m_q2, m_k1.detach(), mix2_target, temp) / 2.0 87 | multi_loss = 0. 88 | multi_mix_loss = 0. 
89 | if args.multi_crop_num != 0: 90 | multi_image = [] 91 | multi_mix_image = [] 92 | for id in range(args.multi_crop_num): 93 | images[2][id], _, _ = no_mixer(images[2][id], target) 94 | mix_image, _, _ = mixer(images[2][id], target) 95 | multi_image.append(images[2][id]) 96 | multi_mix_image.append(mix_image) 97 | multi_image = torch.cat(multi_image, dim=0) 98 | multi_mix_image = torch.cat(multi_mix_image, dim=0) 99 | with torch.cuda.amp.autocast(True): 100 | multi_q_, _ = model(multi_image, k=False) 101 | multi_m_q_, _ = model(multi_mix_image, k=False) 102 | mts1 = 0. 103 | mts2 = 0. 104 | mtms1 = 0. 105 | mtms2 = 0. 106 | for id in range(args.multi_crop_num): 107 | mts1 += criterion(multi_q_[N * id:N * (id + 1)], k1.detach(), target0, temp) 108 | mts2 += criterion(multi_q_[N * id:N * (id + 1)], k2.detach(), target0, temp) 109 | mtms1 += criterion(multi_m_q_[N * id:N * (id + 1)], k1.detach(), mix_target, temp) 110 | mtms2 += criterion(multi_m_q_[N * id:N * (id + 1)], k2.detach(), mix_target, temp) 111 | multi_loss = (mts1 + mts2) / args.multi_crop_num 112 | multi_mix_loss = (mtms1 + mtms2) / args.multi_crop_num 113 | loss = src_loss + mix_loss + mix2_loss + multi_loss + multi_mix_loss 114 | scaler.scale(loss).backward() 115 | 116 | 117 | scaler.step(optimizer) 118 | scaler.update() 119 | 120 | src_losses.update(src_loss.item(), N) 121 | mix_losses.update(mix_loss.item(), N) 122 | mix2_losses.update(mix2_loss.item(), N) 123 | multi_losses.update( 124 | multi_loss.item() if args.multi_crop_num != 0 else 0.0, N) 125 | multi_mix_losses.update( 126 | multi_mix_loss.item() if args.multi_crop_num != 0 else 0.0, N) 127 | losses.update(loss.item(), N) 128 | 129 | learning_rates.update(lr_schedule[it]) 130 | weight_decays.update(wd_schedule[it]) 131 | niter_global += 1 132 | 133 | if args.distributed: 134 | src_losses.synchronize_between_processes() 135 | mix_losses.synchronize_between_processes() 136 | mix_losses.synchronize_between_processes() 137 | 
multi_losses.synchronize_between_processes() 138 | multi_mix_losses.synchronize_between_processes() 139 | losses.synchronize_between_processes() 140 | 141 | if logger_console is not None and args.rank == 0: 142 | logger_console.info(f'Epoch [{epoch}][{i + 1}/{num_iter}] - ' 143 | f'lr: {lr_schedule[it]:.5f}, ' 144 | f'wd: {wd_schedule[it]:.5f}, ' 145 | f'src loss: {src_losses.avg:.3f}, ' 146 | f'mix loss: {mix_losses.avg:.3f}, ' 147 | f'mix2 loss: {mix2_losses.avg:.3f}, ' 148 | f'multi loss: {multi_losses.avg:.3f}, ' 149 | f'multi mix loss: {multi_mix_losses.avg:.3f}, ' 150 | f'loss: {losses.avg:.3f}' 151 | ) 152 | 153 | if logger_tb is not None and args.rank == 0: 154 | logger_tb.add_scalar('Epoch/Src Loss', src_losses.avg, epoch + 1) 155 | logger_tb.add_scalar('Epoch/Mix Loss', mix_losses.avg, epoch + 1) 156 | logger_tb.add_scalar('Epoch/Mix2 Loss', mix2_losses.avg, epoch + 1) 157 | logger_tb.add_scalar('Epoch/Multi Loss', multi_losses.avg, epoch + 1) 158 | logger_tb.add_scalar('Epoch/Multi Mix Loss', 159 | multi_mix_losses.avg, epoch + 1) 160 | logger_tb.add_scalar('Epoch/Loss', losses.avg, epoch + 1) 161 | logger_tb.add_scalar('Epoch/lr', lr_schedule[it], epoch + 1) 162 | logger_tb.add_scalar('Epoch/wd', wd_schedule[it], epoch + 1) 163 | 164 | 165 | def main_worker(gpu, ngpus_per_node, args): 166 | rank = args.rank * ngpus_per_node + gpu 167 | if args.distributed: 168 | dist.init_process_group( 169 | backend='nccl', init_method=args.init_method, rank=rank, world_size=args.world_size) 170 | torch.distributed.barrier() 171 | args.rank = rank 172 | misc.fix_random_seeds(args.seed) 173 | 174 | # ------------------------------ logger -----------------------------# 175 | if args.rank == 0: 176 | os.makedirs(args.exp_dir, exist_ok=True) 177 | log_root = args.exp_dir 178 | name = f'vit_encoder_projection_{args.proj_layer}layers_with_BN_prediction_{args.pred_layer}layers'f'_dim{args.out_dim}'f'_hidden_dim{args.hidden_dim}' 179 | logger_tb = Logger(log_root, name) 180 | 
logger_console = console_logger(logger_tb.log_dir, 'console') 181 | else: 182 | logger_tb, logger_console = None, None 183 | 184 | # --------------------------------- model ------------------------------# 185 | 186 | if args.arch == 'vit-tiny': 187 | base_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, num_classes=args.out_dim, 188 | embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, drop_path_rate=args.drop_path) 189 | momentum_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, 190 | num_classes=args.out_dim, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4) 191 | elif args.arch == 'vit-small': 192 | base_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, num_classes=args.out_dim, 193 | embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, drop_path_rate=args.drop_path) 194 | momentum_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, 195 | num_classes=args.out_dim, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4) 196 | elif args.arch == 'vit-base': 197 | base_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, num_classes=args.out_dim, 198 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, drop_path_rate=args.drop_path) 199 | momentum_encoder = ViT(patch_size=args.patch_size, img_size=args.input_size, 200 | num_classes=args.out_dim, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4) 201 | 202 | if args.use_moco: 203 | model = ContrastMomentum_ViT(base_encoder, momentum_encoder, args.proj_layer, 204 | args.pred_layer, args.out_dim, args.hidden_dim) 205 | else: 206 | model = ContrastNoMomentum_ViT(base_encoder, args.proj_layer, 207 | args.pred_layer, args.out_dim, args.hidden_dim) 208 | 209 | model = model.cuda(args.rank) 210 | 211 | args.lr = args.lr * args.batch_size / 256 212 | 213 | if args.distributed: 214 | torch.cuda.set_device(args.rank) 215 | args.batch_size = int(args.batch_size / args.world_size) 216 | args.num_workers = int( 217 | (args.num_workers + 
args.world_size - 1) / args.world_size) 218 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 219 | model = DDP(model, device_ids=[args.rank], broadcast_buffers=False) 220 | 221 | # --------------------------- data load -----------------------# 222 | transform = TwoCropsTransform( 223 | args) if args.multi_crop_num == 0 else MultiCropsTransform(args) 224 | if args.dataset == 'cifar10': 225 | train_set = datasets.CIFAR10(root=args.data_root, train=True, download=False, 226 | transform=transform) 227 | elif args.dataset == 'cifar100': 228 | train_set = datasets.CIFAR100(root=args.data_root, train=True, download=False, 229 | transform=transform) 230 | elif args.dataset == 'imagenet1k': 231 | train_set = datasets.ImageFolder(root=os.path.join(args.data_root, 'train'), 232 | transform=transform) 233 | 234 | if args.distributed: 235 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 236 | else: 237 | train_sampler = None 238 | 239 | train_loader = DataLoader(dataset=train_set, 240 | batch_size=args.batch_size, 241 | shuffle=(train_sampler is None), 242 | num_workers=args.num_workers, 243 | sampler=train_sampler, 244 | pin_memory=args.pin_memory, 245 | drop_last=True, 246 | prefetch_factor=args.prefetch_factor) 247 | 248 | args.niters_per_epoch = len(train_set) // args.batch_size 249 | 250 | # ----------------------------- patchmix ----------------------------------# 251 | mixer = PatchMixer( 252 | num_classes=int(args.batch_size * args.world_size), mix_s=args.mix_size, 253 | mix_n=args.mix_num, mix_p=args.mix_p, smoothing=args.smoothing) 254 | 255 | # ---------------------------- loss ---------------------------# 256 | criterion = MultiTempContrastiveLoss() 257 | 258 | # ---------------------------- optimizer ---------------------------# 259 | if args.use_wd_cos: 260 | parameters = model.module.named_parameters() if isinstance( 261 | model, DDP) else model.named_parameters() 262 | params_groups = misc.get_params_groups(parameters) 263 | 
optimizer = torch.optim.AdamW(params_groups) 264 | else: 265 | parameters = model.module.parameters() if isinstance( 266 | model, DDP) else model.parameters() 267 | optimizer = torch.optim.AdamW( 268 | parameters, weight_decay=args.weight_decay) 269 | 270 | scaler = GradScaler() 271 | 272 | start_epoch = 0 273 | 274 | # ---------------------------- scheduler ---------------------------# 275 | lr_schedule = misc.cosine_scheduler( 276 | args.lr, # linear scaling rule 277 | args.min_lr, 278 | args.epochs, len(train_loader), 279 | warmup_epochs=args.warmup_epoch, 280 | ) 281 | wd_schedule = misc.cosine_scheduler( 282 | args.weight_decay, 283 | args.weight_decay_end, 284 | args.epochs, len(train_loader), 285 | ) 286 | 287 | temp_schedule = np.concatenate(( 288 | np.linspace(args.warmup_temp, args.temp, args.warmup_temp_epochs), 289 | np.ones(args.epochs - args.warmup_temp_epochs) * args.temp 290 | )) 291 | 292 | # ---------------------------- checkpoint ---------------------------# 293 | if args.resume: 294 | if os.path.isfile(args.resume): 295 | print("=> loading checkpoint '{}'".format(args.resume)) 296 | loc = 'cuda:{}'.format(args.rank) 297 | checkpoint = torch.load(args.resume, map_location=loc) 298 | start_epoch = checkpoint['epoch'] 299 | if isinstance(model, DDP): 300 | model.module.load_state_dict(checkpoint['state_dict']) 301 | else: 302 | model.load_state_dict(checkpoint['state_dict']) 303 | optimizer.load_state_dict(checkpoint['optimizer']) 304 | scaler.load_state_dict(checkpoint['scaler']) 305 | print("=> loaded checkpoint '{}' (epoch {})" 306 | .format(args.resume, checkpoint['epoch'])) 307 | else: 308 | print("=> no checkpoint found at '{}'".format(args.resume)) 309 | 310 | if args.rank == 0: 311 | path_save = os.path.join(args.exp_dir, logger_tb.log_name) 312 | 313 | # ---------------------------- training ---------------------------# 314 | for epoch in range(start_epoch, args.epochs): 315 | if args.distributed: 316 | train_sampler.set_epoch(epoch) 317 | 
318 | train_epoch(train_loader, model, criterion, optimizer, lr_schedule, wd_schedule, temp_schedule, 319 | mixer, scaler, (logger_tb, logger_console), epoch, args) 320 | 321 | if (epoch + 1) % args.save_freq == 0 and args.rank == 0: 322 | _epoch = epoch + 1 323 | state_dict = model.module.state_dict() if isinstance( 324 | model, DDP) else model.state_dict() 325 | torch.save({ 326 | 'epoch': epoch + 1, 327 | 'arch': args.arch, 328 | 'state_dict': state_dict, 329 | 'optimizer': optimizer.state_dict(), 330 | 'scaler': scaler.state_dict(), 331 | }, f'{path_save}/{_epoch:0>4d}.pth') 332 | 333 | if args.rank == 0: 334 | state_dict = model.module.state_dict() \ 335 | if isinstance(model, DDP) else model.state_dict() 336 | 337 | torch.save({'state_dict': state_dict}, f'{path_save}/last.pth') 338 | 339 | 340 | def main(args): 341 | ngpus_per_node = torch.cuda.device_count() 342 | args.world_size = args.world_size * ngpus_per_node 343 | if args.distributed: 344 | mp.spawn(main_worker, args=(ngpus_per_node, args), 345 | nprocs=args.world_size) 346 | else: 347 | main_worker(args.rank, ngpus_per_node, args) 348 | 349 | 350 | def parse_args(): 351 | parser = argparse.ArgumentParser() 352 | parser.add_argument("--arch", type=str, default='vit-small', 353 | choices=['vit-tiny', 'vit-small', 'vit-base']) 354 | return parser 355 | 356 | 357 | if __name__ == '__main__': 358 | parser = parse_args() 359 | _args = parser.parse_args() 360 | if _args.arch == 'vit-tiny': 361 | args = vit_tiny_pretrain() 362 | elif _args.arch == 'vit-small': 363 | args = vit_small_pretrain() 364 | elif _args.arch == 'vit-base': 365 | args = vit_base_pretrain() 366 | main(args) 367 | -------------------------------------------------------------------------------- /module/augmentation.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import random 3 | 4 | from PIL import Image 5 | from PIL import ImageFilter, ImageOps 6 | from 
timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | from torchvision.transforms import transforms 8 | 9 | 10 | class GaussianBlur: 11 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 12 | 13 | def __init__(self, sigma=[.1, 2.]): 14 | self.sigma = sigma 15 | 16 | def __call__(self, x): 17 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 18 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 19 | return x 20 | 21 | 22 | class Solarize: 23 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 24 | 25 | def __call__(self, x): 26 | return ImageOps.solarize(x) 27 | 28 | 29 | class TwoCropsTransform: 30 | def __init__(self, args): 31 | self.base_transform1 = transforms.Compose([ 32 | transforms.RandomResizedCrop( 33 | args.input_size, scale=(args.min_crop, 1.0), interpolation=Image.BICUBIC), 34 | transforms.RandomApply([ 35 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) 36 | ], p=0.8), 37 | transforms.RandomGrayscale(p=0.2), 38 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.0), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 42 | ] 43 | ) 44 | self.base_transform2 = transforms.Compose([ 45 | transforms.RandomResizedCrop( 46 | args.input_size, scale=(args.min_crop, 1.0), interpolation=Image.BICUBIC), 47 | transforms.RandomApply([ 48 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) 49 | ], p=0.8), 50 | transforms.RandomGrayscale(p=0.2), 51 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 52 | transforms.RandomApply([Solarize()], p=0.2), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.ToTensor(), 55 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 56 | ] 57 | ) 58 | 59 | def __call__(self, x): 60 | im1 = self.base_transform1(x) 61 | im2 = self.base_transform2(x) 62 | return im1, im2 63 | 64 | 65 | class MultiCropsTransform: 66 | def __init__(self, args): 67 | 
self.crop_num = args.multi_crop_num 68 | self.base_transform1 = transforms.Compose([ 69 | transforms.RandomResizedCrop( 70 | args.input_size, scale=(args.global_crop, 1.0), interpolation=Image.BICUBIC), 71 | transforms.RandomApply([ 72 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) 73 | ], p=0.8), 74 | transforms.RandomGrayscale(p=0.2), 75 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.0), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 79 | ] 80 | ) 81 | self.base_transform2 = transforms.Compose([ 82 | transforms.RandomResizedCrop( 83 | args.input_size, scale=(args.global_crop, 1.0), interpolation=Image.BICUBIC), 84 | transforms.RandomApply([ 85 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) 86 | ], p=0.8), 87 | transforms.RandomGrayscale(p=0.2), 88 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 89 | transforms.RandomApply([Solarize()], p=0.2), 90 | transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 93 | ] 94 | ) 95 | self.local_transform = transforms.Compose([ 96 | transforms.RandomResizedCrop( 97 | args.multi_crop_size, scale=(args.min_crop, args.global_crop), interpolation=Image.BICUBIC), 98 | transforms.RandomApply([ 99 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) 100 | ], p=0.8), 101 | transforms.RandomGrayscale(p=0.2), 102 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor(), 105 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 106 | ] 107 | ) 108 | 109 | def __call__(self, x): 110 | im1 = self.base_transform1(x) 111 | im2 = self.base_transform2(x) 112 | multi_im = [] 113 | for _ in range(self.crop_num): 114 | multi_im.append(self.local_transform(x)) 115 | return im1, im2, multi_im 116 | -------------------------------------------------------------------------------- 
/module/frame/contrast_momentum.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class ContrastMomentum(nn.Module): 7 | 8 | def __init__(self, base_encoder, momentum_encoder, proj_layer, pred_layer, dim=256, mlp_dim=4096): 9 | super(ContrastMomentum, self).__init__() 10 | # build encoders 11 | self.base_encoder = base_encoder 12 | self.momentum_encoder = momentum_encoder 13 | 14 | self._build_projector_and_predictor_mlps( 15 | proj_layer, pred_layer, dim, mlp_dim) 16 | 17 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 18 | param_m.data.copy_(param_b.data) # initialize 19 | param_m.requires_grad = False # not update by gradient 20 | 21 | def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 22 | mlp = [] 23 | for l in range(num_layers): 24 | dim1 = input_dim if l == 0 else mlp_dim 25 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 26 | 27 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 28 | 29 | if l < num_layers - 1: 30 | mlp.append(nn.BatchNorm1d(dim2)) 31 | mlp.append(nn.ReLU(inplace=True)) 32 | elif last_bn: 33 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 34 | 35 | return nn.Sequential(*mlp) 36 | def _build_projector_and_predictor_mlps(self, proj_layer, pred_layer, dim, mlp_dim): 37 | pass 38 | 39 | @torch.no_grad() 40 | def update_momentum_encoder(self, m): 41 | """Momentum update of the momentum encoder""" 42 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 43 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 44 | 45 | def forward(self, x, k=True, q=True): 46 | q = self.predictor(self.base_encoder(x)) if q else None 47 | k = self.momentum_encoder(x) if k else None 48 | return q, k 49 | 50 | 51 | class ContrastMomentum_ViT(ContrastMomentum): 52 | def _build_projector_and_predictor_mlps(self, proj_layer, pred_layer, dim, mlp_dim): 53 | hidden_dim = self.base_encoder.head.weight.shape[1] 54 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer 55 | # # projectors 56 | self.base_encoder.head = self._build_mlp(proj_layer, hidden_dim, mlp_dim, dim) 57 | self.momentum_encoder.head = self._build_mlp(proj_layer, hidden_dim, mlp_dim, dim) 58 | # predictor 59 | self.predictor = self._build_mlp(pred_layer, dim, mlp_dim, dim) 60 | -------------------------------------------------------------------------------- /module/frame/contrast_no_momentum.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | class ContrastNoMomentum(nn.Module): 4 | 5 | def __init__(self, student, proj_layer, pred_layer, dim=256, mlp_dim=4096): 6 | super(ContrastNoMomentum, self).__init__() 7 | # build encoders 8 | self.base_encoder = student 9 | 10 | self._build_projector_and_predictor_mlps( 11 | proj_layer, pred_layer, dim, mlp_dim) 12 | 13 | def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 14 | mlp = [] 15 | for l in range(num_layers): 16 | dim1 = input_dim if l == 0 else mlp_dim 17 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 18 | 19 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 20 | 21 | if l < num_layers - 1: 22 | mlp.append(nn.BatchNorm1d(dim2)) 23 | mlp.append(nn.ReLU(inplace=True)) 24 | elif last_bn: 25 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 26 | 27 | return nn.Sequential(*mlp) 28 | 29 | def _build_projector_and_predictor_mlps(self, proj_layer, pred_layer, dim, mlp_dim): 30 | pass 31 | 32 | def forward(self, x, k=True, q=True): 33 | if q: 
34 | k = self.base_encoder(x) 35 | q = self.predictor(k) 36 | elif k: 37 | with torch.no_grad(): 38 | k = self.base_encoder(x) 39 | q = None 40 | return q, k 41 | 42 | 43 | class ContrastNoMomentum_ViT(ContrastNoMomentum): 44 | def _build_projector_and_predictor_mlps(self, proj_layer, pred_layer, dim, mlp_dim): 45 | hidden_dim = self.base_encoder.head.weight.shape[1] 46 | del self.base_encoder.head # remove original fc layer 47 | # projectors 48 | self.base_encoder.head = self._build_mlp( 49 | proj_layer, hidden_dim, mlp_dim, dim) 50 | # predictor 51 | self.predictor = self._build_mlp(pred_layer, dim, mlp_dim, dim) 52 | -------------------------------------------------------------------------------- /module/loss.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch 3 | import torch.nn as nn 4 | from timm.loss import SoftTargetCrossEntropy 5 | 6 | 7 | class ContrastiveLoss(nn.Module): 8 | def __init__(self, temperature=1.0): 9 | super(ContrastiveLoss, self).__init__() 10 | self.temperature = temperature 11 | 12 | def forward(self, q: torch.Tensor, k: torch.Tensor, target: torch.Tensor, use_neg=True) -> torch.Tensor: 13 | # normalize 14 | q = nn.functional.normalize(q, dim=1) 15 | k = nn.functional.normalize(k, dim=1) 16 | if use_neg: 17 | k = concat_all_gather(k) 18 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.temperature 19 | return SoftTargetCrossEntropy()(logits, target) 20 | else: 21 | k = k.detach() 22 | return -(nn.CosineSimilarity(dim=1)(q, k).mean()) 23 | 24 | 25 | class MultiTempContrastiveLoss(nn.Module): 26 | 27 | def forward(self, q: torch.Tensor, k: torch.Tensor, target: torch.Tensor, temperature, 28 | use_neg=True) -> torch.Tensor: 29 | # normalize 30 | q = nn.functional.normalize(q, dim=1) 31 | k = nn.functional.normalize(k, dim=1) 32 | if use_neg: 33 | k = concat_all_gather(k) 34 | logits = torch.einsum('nc,mc->nm', [q, k]) / temperature 35 | return 
SoftTargetCrossEntropy()(logits, target) 36 | else: 37 | k = k.detach() 38 | return -(nn.CosineSimilarity(dim=1)(q, k).mean()) 39 | 40 | 41 | @torch.no_grad() 42 | def concat_all_gather(tensor): 43 | """ 44 | Performs all_gather operation on the provided tensors. 45 | *** Warning ***: torch.distributed.all_gather has no gradient. 46 | """ 47 | tensors_gather = [torch.ones_like(tensor) 48 | for _ in range(torch.distributed.get_world_size())] 49 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 50 | 51 | output = torch.cat(tensors_gather, dim=0) 52 | return output 53 | -------------------------------------------------------------------------------- /module/mix.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import numpy as np 3 | import torch 4 | from einops import rearrange, repeat 5 | 6 | 7 | def random_indexes(size: int): 8 | forward_indexes = np.arange(size) 9 | np.random.shuffle(forward_indexes) 10 | backward_indexes = np.argsort(forward_indexes) 11 | return forward_indexes, backward_indexes 12 | 13 | 14 | def take_indexes(sequences, indexes): 15 | return torch.gather(sequences, 1, repeat(indexes, 'b t -> b t c', c=sequences.shape[-1])) 16 | 17 | 18 | class PatchShuffle(torch.nn.Module): 19 | def __init__(self) -> None: 20 | super().__init__() 21 | 22 | def forward(self, patches: torch.Tensor): 23 | B, T, C = patches.shape 24 | indexes = random_indexes(T) 25 | forward_indexes = torch.as_tensor(indexes[0], dtype=torch.long).to( 26 | patches.device) 27 | forward_indexes = repeat( 28 | forward_indexes, 't -> g t', g=B) 29 | backward_indexes = torch.as_tensor(indexes[1], dtype=torch.long).to( 30 | patches.device) 31 | backward_indexes = repeat( 32 | backward_indexes, 't -> g t', g=B) 33 | 34 | patches = take_indexes(patches, forward_indexes) 35 | return patches, forward_indexes, backward_indexes 36 | 37 | 38 | class PatchMix(torch.nn.Module): 39 | 40 | def forward(self, patches: 
torch.Tensor, m): 41 | B, T, C = patches.shape 42 | S = T // m 43 | mix_offset = int(S * m) 44 | mix_patches = patches[:, :mix_offset, :] 45 | mix_patches = rearrange( 46 | mix_patches, 'g (m s) c -> (g m) s c', s=S) 47 | 48 | L = mix_patches.shape[0] 49 | 50 | ids = torch.arange(L).cuda() 51 | indexes = (ids + ids % m * m) % L 52 | mix_patches = torch.gather(mix_patches, 0, repeat( 53 | indexes, 'l -> l s c', c=mix_patches.shape[-1], s=S)) 54 | 55 | ids = torch.arange(B).view(-1, 1) 56 | target = (ids + torch.arange(m)) % B 57 | mix_target = ((ids - m + 1) + torch.arange(m * 2 - 1) + B) % B 58 | 59 | mix_patches = rearrange(mix_patches, '(g m) s c -> g (m s) c', g=B) 60 | patches[:, :mix_offset, :] = mix_patches 61 | return patches, target, mix_target 62 | 63 | 64 | class PatchMixer: 65 | def __init__(self, num_classes, mix_s, mix_n=1, mix_p=0.0, smoothing=0.1): 66 | self.mix_s = mix_s 67 | self.mix_p = mix_p 68 | self.mix_n = mix_n 69 | self.patch_shuffle = PatchShuffle() 70 | self.mix = PatchMix() 71 | self.smoothing = smoothing 72 | self.num_classes = num_classes 73 | 74 | def _one_hot(self, target, num_classes, on_value=1., off_value=0., device='cuda'): 75 | return torch.full((target.size()[0], num_classes), off_value, device=device).scatter_(1, target, on_value) 76 | 77 | @torch.no_grad() 78 | def __call__(self, X, target): 79 | N = X.shape[0] 80 | m = np.random.choice(self.mix_n) if isinstance(self.mix_n, list) else self.mix_n 81 | use_mix = np.random.rand() < self.mix_p and m > 1 82 | if use_mix: 83 | patch_size = np.random.choice(self.mix_s) if isinstance(self.mix_s, list) else self.mix_s 84 | 85 | patches = rearrange( 86 | X, 'b c (w p1) (h p2) -> b (w h) (c p1 p2)', p1=patch_size, p2=patch_size) 87 | patches, forward_indexes, backward_indexes = self.patch_shuffle( 88 | patches) 89 | patches, target, mix_target = self.mix(patches, m) 90 | patches = take_indexes(patches, backward_indexes) 91 | X = rearrange(patches, 'b (w h) (c p1 p2) -> b c (w p1) (h 
p2)', p1=patch_size, p2=patch_size, 92 | w=int(np.sqrt(patches.shape[1]))) 93 | else: 94 | m = 1 95 | target = target.view(-1, 1) 96 | mix_target = target 97 | # add offset 98 | offset = N * torch.distributed.get_rank() 99 | target = (target + offset).cuda() 100 | mix_target = (mix_target + offset).cuda() 101 | 102 | off_value = self.smoothing / self.num_classes 103 | true_num = target.shape[1] 104 | on_value = (1.0 - self.smoothing) / true_num + off_value 105 | soft_target = self._one_hot( 106 | target, self.num_classes, on_value, off_value) 107 | 108 | ids = torch.arange(mix_target.shape[1]) 109 | weights = (1.0 - torch.abs(m - ids - 1) / m) 110 | on_value = (1.0 - self.smoothing) * weights / m + off_value 111 | soft_mix_target = self._one_hot( 112 | mix_target, self.num_classes, on_value.expand([mix_target.shape[0], -1]).cuda(), 113 | off_value) 114 | return X, soft_target, soft_mix_target 115 | 116 | -------------------------------------------------------------------------------- /module/vits.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial, reduce 3 | from operator import mul 4 | 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.layers import PatchEmbed 8 | from timm.models.vision_transformer import VisionTransformer 9 | 10 | 11 | class MultiPatchEmbed(PatchEmbed): 12 | def forward(self, x): 13 | x = self.proj(x) 14 | if self.flatten: 15 | x = x.flatten(2).transpose(1, 2) 16 | x = self.norm(x) 17 | return x 18 | 19 | 20 | class VisionTransformerMoCo(VisionTransformer): 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | self.patch_embed = MultiPatchEmbed( 24 | img_size=self.patch_embed.img_size[0], 25 | patch_size=self.patch_embed.patch_size[0], 26 | embed_dim=self.embed_dim, 27 | bias=True, 28 | ) 29 | for name, m in self.named_modules(): 30 | if isinstance(m, nn.Linear): 31 | if 'qkv' in name: 32 | val = math.sqrt( 33 | 6. 
/ float(m.weight.shape[0] // 3 + m.weight.shape[1])) 34 | nn.init.uniform_(m.weight, -val, val) 35 | else: 36 | nn.init.xavier_uniform_(m.weight) 37 | nn.init.zeros_(m.bias) 38 | nn.init.normal_(self.cls_token, std=1e-6) 39 | 40 | if isinstance(self.patch_embed, PatchEmbed): 41 | val = math.sqrt( 42 | 6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 43 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 44 | nn.init.zeros_(self.patch_embed.proj.bias) 45 | 46 | def _pos_embed(self, x, w, h): 47 | if self.cls_token is not None: 48 | x = torch.cat((self.cls_token.expand( 49 | x.shape[0], -1, -1), x), dim=1) 50 | x = x + self.interpolate_pos_encoding(x, w, h) 51 | return self.pos_drop(x) 52 | 53 | def interpolate_pos_encoding(self, x, w, h): 54 | npatch = x.shape[1] - 1 55 | N = self.pos_embed.shape[1] - 1 56 | if npatch == N and w == h: 57 | return self.pos_embed 58 | class_pos_embed = self.pos_embed[:, 0] 59 | patch_pos_embed = self.pos_embed[:, 1:] 60 | dim = x.shape[-1] 61 | w0 = w // self.patch_embed.patch_size[0] 62 | h0 = h // self.patch_embed.patch_size[0] 63 | w0, h0 = w0 + 0.1, h0 + 0.1 64 | patch_pos_embed = nn.functional.interpolate( 65 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int( 66 | math.sqrt(N)), dim).permute(0, 3, 1, 2), 67 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 68 | mode='bicubic', 69 | ) 70 | assert int( 71 | w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 72 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 73 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 74 | 75 | def forward_features(self, x): 76 | b, t, w, h = x.shape 77 | x = self.patch_embed(x) 78 | x = self._pos_embed(x, w, h) 79 | x = self.norm_pre(x) 80 | x = self.blocks(x) 81 | x = self.norm(x) 82 | return x 83 | 84 | def forward(self, x): 85 | x = self.forward_features(x) 86 | x = self.forward_head(x) 87 | return x 88 | 89 | 90 | def ViT(**kwargs): 91 | 
model = VisionTransformerMoCo( 92 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 93 | return model 94 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0 2 | torchvision==0.14.0 3 | timm==0.6.11 4 | einops==0.6.0 5 | tensorboard==2.11.0 -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import logging 5 | 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | 9 | class Logger(SummaryWriter): 10 | def __init__(self, log_root='./', name='', logger_name=''): 11 | os.makedirs(log_root, exist_ok=True) 12 | if logger_name == '': 13 | date = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 14 | self.log_name = '{}_{}'.format(name, date) 15 | # self.log_name = date 16 | log_dir = os.path.join(log_root, self.log_name) 17 | super(Logger, self).__init__(log_dir, flush_secs=1) 18 | else: 19 | self.log_name = logger_name 20 | log_dir = os.path.join(log_root, self.log_name) 21 | super(Logger, self).__init__(log_dir, flush_secs=1) 22 | 23 | 24 | def console_logger(log_root, logger_name) -> logging.Logger: 25 | log_file = logger_name + '.log' 26 | log_path = os.path.join(log_root, log_file) 27 | 28 | logger = logging.getLogger(logger_name) 29 | logger.setLevel(logging.INFO) 30 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 31 | 32 | handler1 = logging.FileHandler(log_path) 33 | handler1.setFormatter(formatter) 34 | 35 | handler2 = logging.StreamHandler() 36 | handler2.setFormatter(formatter) 37 | 38 | logger.addHandler(handler1) 39 | logger.addHandler(handler2) 40 | logger.propagate = False 41 | 42 | return logger 43 | 44 | 45 | if __name__ == '__main__': 46 | import math 47 | logger = 
Logger('./log/', 'test') 48 | 49 | nsamples = 100 50 | for i in range(nsamples): 51 | x = math.cos(2 * math.pi * i / nsamples) 52 | logger.add_scalar('x', x, i) 53 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | from collections import defaultdict, deque 8 | 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | def fix_random_seeds(seed=31): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 23 | torch.save(state, filename) 24 | if is_best: 25 | shutil.copyfile(filename, 'model_best.pth.tar') 26 | 27 | 28 | class AverageMeter(object): 29 | """Computes and stores the average and current value""" 30 | 31 | def __init__(self, name, fmt=':f'): 32 | self.name = name 33 | self.fmt = fmt 34 | self.reset() 35 | 36 | def reset(self): 37 | self.val = 0 38 | self.sum = 0 39 | self.count = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | 46 | def synchronize_between_processes(self): 47 | pack = torch.tensor([self.sum, self.count], device='cuda') 48 | dist.barrier() 49 | dist.all_reduce(pack) 50 | self.sum, self.count = pack.tolist() 51 | 52 | @property 53 | def avg(self): 54 | return self.sum / self.count 55 | 56 | def __str__(self): 57 | fmtstr = '{} {' + self.fmt + '} ({' + self.fmt + '})' 58 | return fmtstr.format(self.name, self.val, self.avg) 59 | 60 | 61 | 62 | 63 | def adjust_moco_momentum(epoch, args): 64 | """Adjust moco momentum based on current epoch""" 65 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. 
- args.moco_m) 66 | return m 67 | 68 | 69 | def get_params_groups(named_parameters): 70 | regularized = [] 71 | not_regularized = [] 72 | for name, param in named_parameters: 73 | if not param.requires_grad: 74 | continue 75 | # we do not regularize biases nor Norm parameters 76 | if name.endswith(".bias") or len(param.shape) == 1: 77 | not_regularized.append(param) 78 | else: 79 | regularized.append(param) 80 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 81 | 82 | 83 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 84 | warmup_schedule = np.array([]) 85 | warmup_iters = warmup_epochs * niter_per_ep 86 | if warmup_epochs > 0: 87 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 88 | 89 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 90 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 91 | 92 | schedule = np.concatenate((warmup_schedule, schedule)) 93 | assert len(schedule) == epochs * niter_per_ep 94 | return schedule 95 | 96 | 97 | def is_dist_avail_and_initialized(): 98 | if not dist.is_available(): 99 | return False 100 | if not dist.is_initialized(): 101 | return False 102 | return True 103 | 104 | 105 | class SmoothedValue(object): 106 | """Track a series of values and provide access to smoothed values over a 107 | window or the global series average. 108 | """ 109 | 110 | def __init__(self, window_size=20, fmt=None): 111 | if fmt is None: 112 | fmt = "{median:.6f} ({global_avg:.6f})" 113 | self.deque = deque(maxlen=window_size) 114 | self.total = 0.0 115 | self.count = 0 116 | self.fmt = fmt 117 | 118 | def update(self, value, n=1): 119 | self.deque.append(value) 120 | self.count += n 121 | self.total += value * n 122 | 123 | def synchronize_between_processes(self): 124 | """ 125 | Warning: does not synchronize the deque! 
126 | """ 127 | if not is_dist_avail_and_initialized(): 128 | return 129 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 130 | dist.barrier() 131 | dist.all_reduce(t) 132 | t = t.tolist() 133 | self.count = int(t[0]) 134 | self.total = t[1] 135 | 136 | @property 137 | def median(self): 138 | d = torch.tensor(list(self.deque)) 139 | return d.median().item() 140 | 141 | @property 142 | def avg(self): 143 | d = torch.tensor(list(self.deque), dtype=torch.float32) 144 | return d.mean().item() 145 | 146 | @property 147 | def global_avg(self): 148 | return self.total / self.count 149 | 150 | @property 151 | def max(self): 152 | return max(self.deque) 153 | 154 | @property 155 | def value(self): 156 | return self.deque[-1] 157 | 158 | def __str__(self): 159 | return self.fmt.format( 160 | median=self.median, 161 | avg=self.avg, 162 | global_avg=self.global_avg, 163 | max=self.max, 164 | value=self.value) 165 | 166 | 167 | class MetricLogger(object): 168 | def __init__(self, delimiter="\t"): 169 | self.meters = defaultdict(SmoothedValue) 170 | self.delimiter = delimiter 171 | 172 | def update(self, **kwargs): 173 | for k, v in kwargs.items(): 174 | if isinstance(v, torch.Tensor): 175 | v = v.item() 176 | assert isinstance(v, (float, int)) 177 | self.meters[k].update(v) 178 | 179 | def __getattr__(self, attr): 180 | if attr in self.meters: 181 | return self.meters[attr] 182 | if attr in self.__dict__: 183 | return self.__dict__[attr] 184 | raise AttributeError("'{}' object has no attribute '{}'".format( 185 | type(self).__name__, attr)) 186 | 187 | def __str__(self): 188 | loss_str = [] 189 | for name, meter in self.meters.items(): 190 | loss_str.append( 191 | "{}: {}".format(name, str(meter)) 192 | ) 193 | return self.delimiter.join(loss_str) 194 | 195 | def synchronize_between_processes(self): 196 | for meter in self.meters.values(): 197 | meter.synchronize_between_processes() 198 | 199 | def add_meter(self, name, meter): 200 | 
self.meters[name] = meter 201 | 202 | def log_every(self, iterable, print_freq, header=None): 203 | i = 0 204 | if not header: 205 | header = '' 206 | start_time = time.time() 207 | end = time.time() 208 | iter_time = SmoothedValue(fmt='{avg:.6f}') 209 | data_time = SmoothedValue(fmt='{avg:.6f}') 210 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 211 | if torch.cuda.is_available(): 212 | log_msg = self.delimiter.join([ 213 | header, 214 | '[{0' + space_fmt + '}/{1}]', 215 | 'eta: {eta}', 216 | '{meters}', 217 | 'time: {time}', 218 | 'data: {data}', 219 | 'max mem: {memory:.0f}' 220 | ]) 221 | else: 222 | log_msg = self.delimiter.join([ 223 | header, 224 | '[{0' + space_fmt + '}/{1}]', 225 | 'eta: {eta}', 226 | '{meters}', 227 | 'time: {time}', 228 | 'data: {data}' 229 | ]) 230 | MB = 1024.0 * 1024.0 231 | for obj in iterable: 232 | data_time.update(time.time() - end) 233 | yield obj 234 | iter_time.update(time.time() - end) 235 | if i % print_freq == 0 or i == len(iterable) - 1: 236 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 237 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 238 | if torch.cuda.is_available(): 239 | print(log_msg.format( 240 | i, len(iterable), eta=eta_string, 241 | meters=str(self), 242 | time=str(iter_time), data=str(data_time), 243 | memory=torch.cuda.max_memory_allocated() / MB)) 244 | else: 245 | print(log_msg.format( 246 | i, len(iterable), eta=eta_string, 247 | meters=str(self), 248 | time=str(iter_time), data=str(data_time))) 249 | i += 1 250 | end = time.time() 251 | total_time = time.time() - start_time 252 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 253 | print('{} Total time: {} ({:.6f} s / it)'.format( 254 | header, total_time_str, total_time / len(iterable))) 255 | --------------------------------------------------------------------------------