├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── figures └── teaser.png ├── finetune ├── dataset_folder.py ├── datasets.py ├── engine_for_finetuning.py ├── modeling_finetune.py ├── optim_factory.py ├── run_class_finetuning.py └── utils.py ├── lib ├── __init__.py ├── augment.py ├── builder.py ├── dataload_optim.py ├── logger.py └── misc.py ├── main.py ├── main_lincls.py └── vits.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | DocProject/Help/*.HxT 162 | 
DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Extreme Masking for Learning Instance and Distributed Visual Representations 2 | 3 | This repo contains the official PyTorch implementation of the ExtreMA paper [(arxiv)](https://arxiv.org/abs/2206.04667). ExtreMA explores treating spatial token masking as a data augmentation for siamese representation learning. It follows the plain BYOL model, with supervision created from the masking operation. ExtreMA not only learns a strong instance representation which captures the holistic image, but also meaningful distributed representations for each individual token. Multi-masking, which processes a number of masks in parallel, is developed to greatly accelerate training. 4 | 5 |
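For intuition, the following is a minimal PyTorch sketch of this objective (our own illustrative code, not the actual implementation in `main.py` / `lib/builder.py`; the encoder/predictor interfaces and the full-view target branch are assumptions based on the description above):

```
import torch
import torch.nn.functional as F

def keep_indices(num_tokens, mask_ratio, device):
    # Randomly keep a small subset of patch tokens, e.g. 20% of them at mask ratio 0.8.
    num_keep = int(num_tokens * (1 - mask_ratio))
    return torch.rand(num_tokens, device=device).argsort()[:num_keep]

def extrema_loss(student, predictor, teacher, tokens, mask_ratio=0.8, num_masks=5):
    # tokens: [B, N, D] patch embeddings of one augmented image.
    # student/teacher are assumed to map a token set to an instance embedding [B, D].
    with torch.no_grad():  # momentum (EMA) target branch, no gradients
        target = F.normalize(teacher(tokens), dim=-1)
    loss = 0.
    for _ in range(num_masks):  # multi-masking: several masks of the same image
        keep = keep_indices(tokens.size(1), mask_ratio, tokens.device)
        online = F.normalize(predictor(student(tokens[:, keep])), dim=-1)
        loss = loss + (2 - 2 * (online * target).sum(dim=-1)).mean()  # BYOL regression loss
    return loss / num_masks

@torch.no_grad()
def ema_update(teacher, student, m=0.996):
    # BYOL-style momentum update of the target encoder.
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(m).add_(ps, alpha=1.0 - m)
```

The loop is written sequentially for clarity; the point of multi-masking in ExtreMA is to process these masked views in parallel, amortizing data loading and the per-image target computation.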

6 | ![ExtreMA teaser](figures/teaser.png) 7 | 

8 | 9 | 10 | ## Released Pretrained Models 11 | 12 | We currently release the following four representative models. Wall time is measured on a single node with 8 V100 GPUs and PyTorch 1.13. ExtreMA is significantly more efficient and faster than competing masked modeling and siamese representation learning approaches. 13 | 14 | | name | pretrain dataset | epochs | masking | color-aug | wall time | linear | finetune | link | 15 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 16 | | ViT-Base | ImageNet-1K | 300 | 80%x5 | No | 50 hrs | 67.1 | 82.9 | [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_1k_300ep.pth) | 17 | | ViT-Base | ImageNet-1K | 300 | 80%x5 | Yes | 50 hrs | 73.3 | 83.7 | [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_color_1k_300ep.pth) | 18 | | ViT-Base | ImageNet-1K | 300 | 90%x8 | Yes | 46 hrs | 68.4 | 83.5 | [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask09x8_color_1k_300ep.pth) | 19 | | ViT-Base | ImageNet-22K | 30 | 80%x5 | Yes | 56 hrs | 74.5 | 83.9 | [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_color_22k_30ep.pth) | 20 | 21 | ## Pre-training 22 | 23 | To pre-train ExtreMA, run the following command: 24 | 25 | ``` 26 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 27 | -a vit_base -b 2048 \ 28 | --lr=1.5e-4 --weight-decay=.1 --weight-decay-end=.1 \ 29 | --opt=adamw \ 30 | --aug-spatialconsistent-color \ 31 | --loss byol \ 32 | --epochs=300 --warmup-epochs=40 --save-freq 5 \ 33 | --opt-betas 0.9 0.95 \ 34 | --drop_path_rate 0.1 --attn_drop_rate 0. \ 35 | --layer_scale_init_value 0.1 --class_attention_layers 2 \ 36 | --mask-ratio 0.8 --num-masks 5 \ 37 | --ema-momentum 0.996 \ 38 | --proj-dim 256 \ 39 | --dist-url 'tcp://localhost:10001' \ 40 | --multiprocessing-distributed \ 41 | --seed 0 \ 42 | --log_dir $LOG_DIR \ 43 | --output_dir $SAVE_DIR \ 44 | $DATA_DIR 45 | ``` 46 | 47 | ## Linear and Finetuning Evaluations on ImageNet-1K 48 | 49 | To linear-probe a pretrained model, run: 50 | ``` 51 | python -m torch.distributed.launch --nproc_per_node=8 main_lincls.py \ 52 | -a vit_base --lr 0.1 \ 53 | -b 4096 --optimizer sgd --warmup-epochs 10 \ 54 | --log_dir ./ --eval_momentum \ 55 | --dist-url 'tcp://localhost:10001' \ 56 | --multiprocessing-distributed \ 57 | --pretrained $MODEL \ 58 | $DATA 59 | ``` 60 | 61 | To finetune the model end-to-end, run: 62 | ``` 63 | python -m torch.distributed.launch --nproc_per_node=8 finetune/run_class_finetuning.py \ 64 | --model vit_base_patch16_224 \ 65 | --data_path $DATA_DIR \ 66 | --use_mean_pooling \ 67 | --color_jitter 0.4 --reprob 0.25 \ 68 | --finetune $MODEL --output_dir $LOG_DIR \ 69 | --layer_decay 0.65 \ 70 | --lr 5e-4 \ 71 | --batch_size 128 --update_freq 1 --opt adamw --opt_betas 0.9 0.999 \ 72 | --weight_decay 0.05 --warmup_epochs 5 --drop_path 0.2 --epochs 100 \ 73 | --dist_eval 74 | ``` 75 | The finetuning code is based on BEiT, with the important modification of removing the [cls] token at the ViT input (see the illustrative sketch below, after the Acknowledgement). 76 | 77 | ## Other Downstream Evaluations 78 | 79 | For semantic segmentation and instance detection, we follow [the CAE codebase](https://github.com/lxtGH/CAE). Care must be taken to remove the [cls] token at the input for ExtreMA. 80 | 81 | ## Acknowledgement 82 | 83 | The ExtreMA code significantly borrows from MoCo-v3, MAE, BEiT, and the timm library. 
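As flagged in the finetuning and downstream notes above, ExtreMA's ViT consumes patch tokens only. A minimal, hypothetical illustration of what removing the [cls] token means at the input (our own snippet; the actual behavior is the commented-out `cls_token` concatenation in `forward_features()` of `finetune/modeling_finetune.py`):

```
import torch
import torch.nn as nn

# ViT-Base/16 on 224x224 inputs: 14x14 = 196 patch tokens.
patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pos_embed = torch.zeros(1, 196, 768)  # no extra position slot reserved for [cls]

images = torch.randn(2, 3, 224, 224)
x = patch_embed(images).flatten(2).transpose(1, 2)  # [B, 196, 768] patch tokens only
# x = torch.cat((cls_token.expand(len(x), -1, -1), x), dim=1)  # standard ViT; removed for ExtreMA
x = x + pos_embed  # positional embeddings cover patches only
```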
84 | 85 | ## Citation 86 | 87 | ``` 88 | @article{wu2022extreme, 89 | title={Extreme Masking for Learning Instance and Distributed Visual Representations}, 90 | author={Wu, Zhirong and Lai, Zihang and Sun, Xiao and Lin, Stephen}, 91 | journal={arXiv preprint arXiv:2206.04667}, 92 | year={2022} 93 | } 94 | ``` 95 | 96 | ## Contributing 97 | 98 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 99 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 100 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 101 | 102 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 103 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 104 | provided by the bot. You will only need to do this once across all repos using our CLA. 105 | 106 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 107 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 108 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 109 | 110 | ## Trademarks 111 | 112 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 113 | trademarks or logos is subject to and must follow 114 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 115 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 116 | Any use of third-party trademarks or logos is subject to those third parties' policies. 117 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ExtreMA/bf30b64ec0046524d0c373b761df323b9328f8dc/figures/teaser.png -------------------------------------------------------------------------------- /finetune/dataset_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | from PIL import Image 11 | 12 | import os 13 | import os.path 14 | import random 15 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 16 | 17 | 18 | def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: 19 | """Checks if a file is an allowed extension. 20 | 21 | Args: 22 | filename (string): path to a file 23 | extensions (tuple of strings): extensions to consider (lowercase) 24 | 25 | Returns: 26 | bool: True if the filename ends with one of given extensions 27 | """ 28 | return filename.lower().endswith(extensions) 29 | 30 | 31 | def is_image_file(filename: str) -> bool: 32 | """Checks if a file is an allowed image extension. 33 | 34 | Args: 35 | filename (string): path to a file 36 | 37 | Returns: 38 | bool: True if the filename ends with a known image extension 39 | """ 40 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 41 | 42 | 43 | def make_dataset( 44 | directory: str, 45 | class_to_idx: Dict[str, int], 46 | extensions: Optional[Tuple[str, ...]] = None, 47 | is_valid_file: Optional[Callable[[str], bool]] = None, 48 | ) -> List[Tuple[str, int]]: 49 | instances = [] 50 | directory = os.path.expanduser(directory) 51 | both_none = extensions is None and is_valid_file is None 52 | both_something = extensions is not None and is_valid_file is not None 53 | if both_none or both_something: 54 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 55 | if extensions is not None: 56 | def is_valid_file(x: str) -> bool: 57 | return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) 58 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 59 | for target_class in sorted(class_to_idx.keys()): 60 | class_index = class_to_idx[target_class] 61 | target_dir = os.path.join(directory, target_class) 62 | if not os.path.isdir(target_dir): 63 | continue 64 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 65 | for fname in sorted(fnames): 66 | path = os.path.join(root, fname) 67 | if is_valid_file(path): 68 | item = path, class_index 69 | instances.append(item) 70 | return instances 71 | 72 | 73 | class DatasetFolder(VisionDataset): 74 | """A generic data loader where the samples are arranged in this way: :: 75 | 76 | root/class_x/xxx.ext 77 | root/class_x/xxy.ext 78 | root/class_x/xxz.ext 79 | 80 | root/class_y/123.ext 81 | root/class_y/nsdf3.ext 82 | root/class_y/asd932_.ext 83 | 84 | Args: 85 | root (string): Root directory path. 86 | loader (callable): A function to load a sample given its path. 
87 | extensions (tuple[string]): A list of allowed extensions. 88 | both extensions and is_valid_file should not be passed. 89 | transform (callable, optional): A function/transform that takes in 90 | a sample and returns a transformed version. 91 | E.g., ``transforms.RandomCrop`` for images. 92 | target_transform (callable, optional): A function/transform that takes 93 | in the target and transforms it. 94 | is_valid_file (callable, optional): A function that takes the path of a file 95 | and checks if the file is a valid file (used to check for corrupt files). 96 | both extensions and is_valid_file should not be passed. 97 | 98 | Attributes: 99 | classes (list): List of the class names sorted alphabetically. 100 | class_to_idx (dict): Dict with items (class_name, class_index). 101 | samples (list): List of (sample path, class_index) tuples 102 | targets (list): The class_index value for each image in the dataset 103 | """ 104 | 105 | def __init__( 106 | self, 107 | root: str, 108 | loader: Callable[[str], Any], 109 | extensions: Optional[Tuple[str, ...]] = None, 110 | transform: Optional[Callable] = None, 111 | target_transform: Optional[Callable] = None, 112 | is_valid_file: Optional[Callable[[str], bool]] = None, 113 | ) -> None: 114 | super(DatasetFolder, self).__init__(root, transform=transform, 115 | target_transform=target_transform) 116 | classes, class_to_idx = self._find_classes(self.root) 117 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 118 | if len(samples) == 0: 119 | msg = "Found 0 files in subfolders of: {}\n".format(self.root) 120 | if extensions is not None: 121 | msg += "Supported extensions are: {}".format(",".join(extensions)) 122 | raise RuntimeError(msg) 123 | 124 | self.loader = loader 125 | self.extensions = extensions 126 | 127 | self.classes = classes 128 | self.class_to_idx = class_to_idx 129 | self.samples = samples 130 | self.targets = [s[1] for s in samples] 131 | 132 | def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: 133 | """ 134 | Finds the class folders in a dataset. 135 | 136 | Args: 137 | dir (string): Root directory path. 138 | 139 | Returns: 140 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 141 | 142 | Ensures: 143 | No class is a subdirectory of another. 144 | """ 145 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 146 | classes.sort() 147 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 148 | return classes, class_to_idx 149 | 150 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 151 | """ 152 | Args: 153 | index (int): Index 154 | 155 | Returns: 156 | tuple: (sample, target) where target is class_index of the target class. 
157 | """ 158 | while True: 159 | try: 160 | path, target = self.samples[index] 161 | sample = self.loader(path) 162 | break 163 | except Exception as e: 164 | print(e) 165 | index = random.randint(0, len(self.samples) - 1) 166 | 167 | if self.transform is not None: 168 | sample = self.transform(sample) 169 | if self.target_transform is not None: 170 | target = self.target_transform(target) 171 | 172 | return sample, target 173 | 174 | def __len__(self) -> int: 175 | return len(self.samples) 176 | 177 | 178 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 179 | 180 | 181 | def pil_loader(path: str) -> Image.Image: 182 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 183 | with open(path, 'rb') as f: 184 | img = Image.open(f) 185 | return img.convert('RGB') 186 | 187 | 188 | # TODO: specify the return type 189 | def accimage_loader(path: str) -> Any: 190 | import accimage 191 | try: 192 | return accimage.Image(path) 193 | except IOError: 194 | # Potentially a decoding problem, fall back to PIL.Image 195 | return pil_loader(path) 196 | 197 | 198 | def default_loader(path: str) -> Any: 199 | from torchvision import get_image_backend 200 | if get_image_backend() == 'accimage': 201 | return accimage_loader(path) 202 | else: 203 | return pil_loader(path) 204 | 205 | 206 | class ImageFolder(DatasetFolder): 207 | """A generic data loader where the images are arranged in this way: :: 208 | 209 | root/dog/xxx.png 210 | root/dog/xxy.png 211 | root/dog/xxz.png 212 | 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in a PIL image 220 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | is_valid_file (callable, optional): A function that takes the path of an image file 225 | and checks if the file is a valid file (used to check for corrupt files) 226 | 227 | Attributes: 228 | classes (list): List of the class names sorted alphabetically. 229 | class_to_idx (dict): Dict with items (class_name, class_index). 
230 | imgs (list): List of (image path, class_index) tuples 231 | """ 232 | 233 | def __init__( 234 | self, 235 | root: str, 236 | transform: Optional[Callable] = None, 237 | target_transform: Optional[Callable] = None, 238 | loader: Callable[[str], Any] = default_loader, 239 | is_valid_file: Optional[Callable[[str], bool]] = None, 240 | ): 241 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 242 | transform=transform, 243 | target_transform=target_transform, 244 | is_valid_file=is_valid_file) 245 | self.imgs = self.samples 246 | -------------------------------------------------------------------------------- /finetune/datasets.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import os 9 | import torch 10 | 11 | from torchvision import datasets, transforms 12 | 13 | from timm.data.constants import \ 14 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 15 | 16 | from timm.data import create_transform 17 | # 18 | #from finetune.dataset_folder import ImageFolder 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | print("Transform = ") 24 | if isinstance(transform, tuple): 25 | for trans in transform: 26 | print(" - - - - - - - - - - ") 27 | for t in trans.transforms: 28 | print(t) 29 | else: 30 | for t in transform.transforms: 31 | print(t) 32 | print("---------------------------") 33 | 34 | if args.data_set == 'CIFAR': 35 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 36 | nb_classes = 100 37 | elif args.data_set == 'IMNET': 38 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 39 | dataset = datasets.ImageFolder(root, transform=transform) 40 | nb_classes = 1000 41 | else: 42 | raise NotImplementedError() 43 | assert nb_classes == args.nb_classes 44 | print("Number of classes = %d" % args.nb_classes) 45 | 46 | return dataset, nb_classes 47 | 48 | 49 | def build_transform(is_train, args): 50 | resize_im = args.input_size > 32 51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | 55 | if is_train: 56 | # this should always dispatch to transforms_imagenet_train 57 | transform = create_transform( 58 | input_size=args.input_size, 59 | is_training=True, 60 | color_jitter=args.color_jitter, 61 | auto_augment=args.aa, 62 | interpolation=args.train_interpolation, 63 | re_prob=args.reprob, 64 | re_mode=args.remode, 65 | re_count=args.recount, 66 | mean=mean, 67 | std=std, 68 | ) 69 | if not resize_im: 70 | # replace RandomResizedCropAndInterpolation with 71 | # RandomCrop 72 | transform.transforms[0] = transforms.RandomCrop( 73 | args.input_size, padding=4) 74 | return transform 75 | 76 | t = [] 77 | if resize_im: 78 | if args.crop_pct is None: 79 | if args.input_size < 384: 80 | args.crop_pct = 224 / 256 81 | else: 82 | args.crop_pct = 1.0 83 | 
size = int(args.input_size / args.crop_pct) 84 | t.append( 85 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 86 | ) 87 | t.append(transforms.CenterCrop(args.input_size)) 88 | 89 | t.append(transforms.ToTensor()) 90 | t.append(transforms.Normalize(mean, std)) 91 | return transforms.Compose(t) 92 | -------------------------------------------------------------------------------- /finetune/engine_for_finetuning.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import math 9 | import sys 10 | from typing import Iterable, Optional 11 | 12 | import torch 13 | 14 | from timm.data import Mixup 15 | from timm.utils import accuracy, ModelEma 16 | 17 | import utils 18 | 19 | 20 | def train_class_batch(model, samples, target, criterion): 21 | outputs = model(samples) 22 | loss = criterion(outputs, target) 23 | return loss, outputs 24 | 25 | 26 | def get_loss_scale_for_deepspeed(model): 27 | optimizer = model.optimizer 28 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 29 | 30 | 31 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 32 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 33 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 34 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 35 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 36 | num_training_steps_per_epoch=None, update_freq=None): 37 | model.train(True) 38 | metric_logger = utils.MetricLogger(delimiter="  ") 39 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 40 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 41 | header = 'Epoch: [{}]'.format(epoch) 42 | print_freq = 10 43 | 44 | if loss_scaler is None: 45 | model.zero_grad() 46 | model.micro_steps = 0 47 | else: 48 | optimizer.zero_grad() 49 | 50 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 51 | step = data_iter_step // update_freq 52 | if step >= num_training_steps_per_epoch: 53 | continue 54 | it = start_steps + step # global training iteration 55 | # Update LR & WD at the first step of each accumulation cycle 56 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0: 57 | for i, param_group in enumerate(optimizer.param_groups): 58 | if lr_schedule_values is not None: 59 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 60 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 61 | param_group["weight_decay"] = wd_schedule_values[it] 62 | 63 | samples = samples.to(device, non_blocking=True) 64 | targets = targets.to(device, non_blocking=True) 65 | 66 | if mixup_fn is not None: 67 | samples, targets = mixup_fn(samples, targets) 68 | 69 | if loss_scaler is None: 70 | samples = samples.half() 71 | loss, output = train_class_batch( 72 | model, samples, targets, criterion) 73 | else: 74 | with torch.cuda.amp.autocast(): 75 | loss, output = 
train_class_batch( 76 | model, samples, targets, criterion) 77 | 78 | loss_value = loss.item() 79 | 80 | if not math.isfinite(loss_value): 81 | print("Loss is {}, stopping training".format(loss_value)) 82 | sys.exit(1) 83 | 84 | if loss_scaler is None: 85 | loss /= update_freq 86 | model.backward(loss) 87 | model.step() 88 | 89 | if (data_iter_step + 1) % update_freq == 0: 90 | # model.zero_grad() 91 | # Deepspeed will call step() & model.zero_grad() automatic 92 | if model_ema is not None: 93 | model_ema.update(model) 94 | grad_norm = None 95 | loss_scale_value = get_loss_scale_for_deepspeed(model) 96 | else: 97 | # this attribute is added by timm on one optimizer (adahessian) 98 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 99 | loss /= update_freq 100 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 101 | parameters=model.parameters(), create_graph=is_second_order, 102 | update_grad=(data_iter_step + 1) % update_freq == 0) 103 | if (data_iter_step + 1) % update_freq == 0: 104 | optimizer.zero_grad() 105 | if model_ema is not None: 106 | model_ema.update(model) 107 | loss_scale_value = loss_scaler.state_dict()["scale"] 108 | 109 | torch.cuda.synchronize() 110 | 111 | if mixup_fn is None: 112 | class_acc = (output.max(-1)[-1] == targets).float().mean() 113 | else: 114 | class_acc = None 115 | metric_logger.update(loss=loss_value) 116 | metric_logger.update(class_acc=class_acc) 117 | metric_logger.update(loss_scale=loss_scale_value) 118 | min_lr = 10. 119 | max_lr = 0. 120 | for group in optimizer.param_groups: 121 | min_lr = min(min_lr, group["lr"]) 122 | max_lr = max(max_lr, group["lr"]) 123 | 124 | metric_logger.update(lr=max_lr) 125 | metric_logger.update(min_lr=min_lr) 126 | weight_decay_value = None 127 | for group in optimizer.param_groups: 128 | if group["weight_decay"] > 0: 129 | weight_decay_value = group["weight_decay"] 130 | metric_logger.update(weight_decay=weight_decay_value) 131 | metric_logger.update(grad_norm=grad_norm) 132 | 133 | if log_writer is not None: 134 | log_writer.update(loss=loss_value, head="loss") 135 | log_writer.update(class_acc=class_acc, head="loss") 136 | log_writer.update(loss_scale=loss_scale_value, head="opt") 137 | log_writer.update(lr=max_lr, head="opt") 138 | log_writer.update(min_lr=min_lr, head="opt") 139 | log_writer.update(weight_decay=weight_decay_value, head="opt") 140 | log_writer.update(grad_norm=grad_norm, head="opt") 141 | 142 | log_writer.set_step() 143 | 144 | # gather the stats from all processes 145 | metric_logger.synchronize_between_processes() 146 | print("Averaged stats:", metric_logger) 147 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 148 | 149 | 150 | @torch.no_grad() 151 | def evaluate(data_loader, model, device): 152 | criterion = torch.nn.CrossEntropyLoss() 153 | 154 | metric_logger = utils.MetricLogger(delimiter=" ") 155 | header = 'Test:' 156 | 157 | # switch to evaluation mode 158 | model.eval() 159 | 160 | for batch in metric_logger.log_every(data_loader, 10, header): 161 | images = batch[0] 162 | target = batch[-1] 163 | images = images.to(device, non_blocking=True) 164 | target = target.to(device, non_blocking=True) 165 | 166 | # compute output 167 | with torch.cuda.amp.autocast(): 168 | output = model(images) 169 | loss = criterion(output, target) 170 | 171 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 172 | 173 | batch_size = images.shape[0] 174 | metric_logger.update(loss=loss.item()) 175 | 
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 176 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 177 | # gather the stats from all processes 178 | metric_logger.synchronize_between_processes() 179 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 180 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 181 | 182 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 183 | -------------------------------------------------------------------------------- /finetune/modeling_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import math 9 | from functools import partial 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 16 | from timm.models.registry import register_model 17 | 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 25 | **kwargs 26 | } 27 | 28 | 29 | class DropPath(nn.Module): 30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
31 | """ 32 | def __init__(self, drop_prob=None): 33 | super(DropPath, self).__init__() 34 | self.drop_prob = drop_prob 35 | 36 | def forward(self, x): 37 | return drop_path(x, self.drop_prob, self.training) 38 | 39 | def extra_repr(self) -> str: 40 | return 'p={}'.format(self.drop_prob) 41 | 42 | 43 | class Mlp(nn.Module): 44 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 45 | super().__init__() 46 | out_features = out_features or in_features 47 | hidden_features = hidden_features or in_features 48 | self.fc1 = nn.Linear(in_features, hidden_features) 49 | self.act = act_layer() 50 | self.fc2 = nn.Linear(hidden_features, out_features) 51 | self.drop = nn.Dropout(drop) 52 | 53 | def forward(self, x): 54 | x = self.fc1(x) 55 | x = self.act(x) 56 | # x = self.drop(x) 57 | # commented out, following the original BERT implementation 58 | x = self.fc2(x) 59 | x = self.drop(x) 60 | return x 61 | 62 | 63 | class Attention(nn.Module): 64 | def __init__( 65 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 66 | proj_drop=0., attn_head_dim=None): 67 | super().__init__() 68 | self.num_heads = num_heads 69 | head_dim = dim // num_heads 70 | if attn_head_dim is not None: 71 | head_dim = attn_head_dim 72 | all_head_dim = head_dim * self.num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) 76 | 77 | self.attn_drop = nn.Dropout(attn_drop) 78 | self.proj = nn.Linear(all_head_dim, dim) 79 | self.proj_drop = nn.Dropout(proj_drop) 80 | 81 | def forward(self, x): 82 | B, N, C = x.shape 83 | 84 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 85 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 86 | 87 | attn = (q @ k.transpose(-2, -1)) * self.scale 88 | attn = attn.softmax(dim=-1) 89 | attn = self.attn_drop(attn) 90 | 91 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 92 | x = self.proj(x) 93 | x = self.proj_drop(x) 94 | 95 | return x 96 | 97 | 98 | class Block(nn.Module): 99 | 100 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 101 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 102 | attn_head_dim=None): 103 | super().__init__() 104 | self.norm1 = norm_layer(dim) 105 | self.attn = Attention( 106 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 107 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 108 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 109 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 110 | self.norm2 = norm_layer(dim) 111 | mlp_hidden_dim = int(dim * mlp_ratio) 112 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 113 | 114 | if init_values > 0: 115 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 116 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 117 | else: 118 | self.gamma_1, self.gamma_2 = None, None 119 | 120 | def forward(self, x): 121 | if self.gamma_1 is None: 122 | x = x + self.drop_path(self.attn(self.norm1(x))) 123 | x = x + self.drop_path(self.mlp(self.norm2(x))) 124 | else: 125 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 126 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 127 | return x 128 | 129 | 130 | class PatchEmbed(nn.Module): 131 | """ Image to Patch Embedding 132 | """ 133 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 134 | super().__init__() 135 | img_size = to_2tuple(img_size) 136 | patch_size = to_2tuple(patch_size) 137 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 138 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 139 | self.img_size = img_size 140 | self.patch_size = patch_size 141 | self.num_patches = num_patches 142 | 143 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 144 | 145 | def forward(self, x, **kwargs): 146 | B, C, H, W = x.shape 147 | # FIXME look at relaxing size constraints 148 | assert H == self.img_size[0] and W == self.img_size[1], \ 149 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 150 | x = self.proj(x).flatten(2).transpose(1, 2) 151 | return x 152 | 153 | # sin-cos position encoding 154 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 155 | def get_sinusoid_encoding_table(n_position, d_hid): 156 | ''' Sinusoid position encoding table ''' 157 | # TODO: make it with torch instead of numpy 158 | def get_position_angle_vec(position): 159 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 160 | 161 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 162 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 163 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 164 | 165 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 166 | 167 | 168 | class VisionTransformer(nn.Module): 169 | """ Vision Transformer with support for patch or hybrid CNN input stage 170 | """ 171 | def __init__(self, 172 | img_size=224, 173 | patch_size=16, 174 | in_chans=3, 175 | num_classes=1000, 176 | embed_dim=768, 177 | depth=12, 178 | num_heads=12, 179 | mlp_ratio=4., 180 | qkv_bias=False, 181 | qk_scale=None, 182 | drop_rate=0., 183 | attn_drop_rate=0., 184 | drop_path_rate=0., 185 | norm_layer=nn.LayerNorm, 186 | init_values=0., 187 | use_learnable_pos_emb=False, 188 | init_scale=0., 189 | use_mean_pooling=True): 190 | super().__init__() 191 | self.num_classes = num_classes 192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 193 | 194 | self.patch_embed = PatchEmbed( 195 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 196 | num_patches = self.patch_embed.num_patches 197 | 198 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 199 
| if use_learnable_pos_emb: 200 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 201 | #self.pos_embed.requires_grad = False 202 | else: 203 | # sine-cosine positional embeddings is on the way 204 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 205 | 206 | self.pos_drop = nn.Dropout(p=drop_rate) 207 | 208 | 209 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 210 | self.blocks = nn.ModuleList([ 211 | Block( 212 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 213 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 214 | init_values=init_values) 215 | for i in range(depth)]) 216 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 217 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 218 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 219 | 220 | if use_learnable_pos_emb: 221 | trunc_normal_(self.pos_embed, std=.02) 222 | 223 | trunc_normal_(self.cls_token, std=.02) 224 | trunc_normal_(self.head.weight, std=.02) 225 | self.apply(self._init_weights) 226 | 227 | self.head.weight.data.mul_(init_scale) 228 | self.head.bias.data.mul_(init_scale) 229 | 230 | def _init_weights(self, m): 231 | if isinstance(m, nn.Linear): 232 | trunc_normal_(m.weight, std=.02) 233 | if isinstance(m, nn.Linear) and m.bias is not None: 234 | nn.init.constant_(m.bias, 0) 235 | elif isinstance(m, nn.LayerNorm): 236 | nn.init.constant_(m.bias, 0) 237 | nn.init.constant_(m.weight, 1.0) 238 | 239 | def get_num_layers(self): 240 | return len(self.blocks) 241 | 242 | @torch.jit.ignore 243 | def no_weight_decay(self): 244 | return {'pos_embed', 'cls_token'} 245 | 246 | def get_classifier(self): 247 | return self.head 248 | 249 | def reset_classifier(self, num_classes, global_pool=''): 250 | self.num_classes = num_classes 251 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 252 | 253 | def forward_features(self, x): 254 | x = self.patch_embed(x) 255 | B, _, _ = x.size() 256 | 257 | #cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 258 | #x = torch.cat((cls_tokens, x), dim=1) 259 | if self.pos_embed is not None: 260 | #x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 261 | x = x + self.pos_embed 262 | x = self.pos_drop(x) 263 | 264 | for blk in self.blocks: 265 | x = blk(x) 266 | 267 | if self.fc_norm is not None: 268 | return self.fc_norm(x.mean(1)) 269 | else: 270 | x = self.norm(x) 271 | return x[:, 0] 272 | 273 | def forward(self, x): 274 | x = self.forward_features(x) 275 | x = self.head(x) 276 | return x 277 | 278 | @register_model 279 | def vit_small_patch16_224(pretrained=False, **kwargs): 280 | model = VisionTransformer( 281 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 282 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 283 | model.default_cfg = _cfg() 284 | return model 285 | 286 | @register_model 287 | def vit_base_patch16_224(pretrained=False, **kwargs): 288 | model = VisionTransformer( 289 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 290 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 291 | model.default_cfg = _cfg() 292 | return model 293 | 294 | 295 | @register_model 296 | def vit_base_patch16_384(pretrained=False, **kwargs): 297 | model = 
VisionTransformer( 298 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 299 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 300 | model.default_cfg = _cfg() 301 | return model 302 | 303 | 304 | @register_model 305 | def vit_large_patch16_224(pretrained=False, **kwargs): 306 | model = VisionTransformer( 307 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 308 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 309 | model.default_cfg = _cfg() 310 | return model 311 | 312 | 313 | @register_model 314 | def vit_large_patch16_384(pretrained=False, **kwargs): 315 | model = VisionTransformer( 316 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 317 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 318 | model.default_cfg = _cfg() 319 | return model 320 | 321 | 322 | @register_model 323 | def vit_large_patch16_512(pretrained=False, **kwargs): 324 | model = VisionTransformer( 325 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 326 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 327 | model.default_cfg = _cfg() 328 | return model 329 | -------------------------------------------------------------------------------- /finetune/optim_factory.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import torch 9 | from torch import optim as optim 10 | 11 | from timm.optim.adafactor import Adafactor 12 | from timm.optim.adahessian import Adahessian 13 | from timm.optim.adamp import AdamP 14 | from timm.optim.lookahead import Lookahead 15 | from timm.optim.nadam import Nadam 16 | from timm.optim.novograd import NovoGrad 17 | from timm.optim.nvnovograd import NvNovoGrad 18 | from timm.optim.radam import RAdam 19 | from timm.optim.rmsprop_tf import RMSpropTF 20 | from timm.optim.sgdp import SGDP 21 | 22 | import json 23 | 24 | try: 25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 26 | has_apex = True 27 | except ImportError: 28 | has_apex = False 29 | 30 | 31 | def get_num_layer_for_vit(var_name, num_max_layer): 32 | if var_name in ("cls_token", "mask_token", "pos_embed"): 33 | return 0 34 | elif var_name.startswith("patch_embed"): 35 | return 0 36 | elif var_name.startswith("rel_pos_bias"): 37 | return num_max_layer - 1 38 | elif var_name.startswith("blocks"): 39 | layer_id = int(var_name.split('.')[1]) 40 | return layer_id + 1 41 | else: 42 | return num_max_layer - 1 43 | 44 | 45 | class LayerDecayValueAssigner(object): 46 | def __init__(self, values): 47 | self.values = values 48 | 49 | def get_scale(self, layer_id): 50 | return self.values[layer_id] 51 | 52 | def get_layer_id(self, var_name): 53 | return get_num_layer_for_vit(var_name, len(self.values)) 54 | 55 | 56 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 57 | parameter_group_names = {} 58 | parameter_group_vars = {} 59 | 60 | for name, param in model.named_parameters(): 61 | if not param.requires_grad: 62 | continue # 
frozen weights 63 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 64 | group_name = "no_decay" 65 | this_weight_decay = 0. 66 | else: 67 | group_name = "decay" 68 | this_weight_decay = weight_decay 69 | if get_num_layer is not None: 70 | layer_id = get_num_layer(name) 71 | group_name = "layer_%d_%s" % (layer_id, group_name) 72 | else: 73 | layer_id = None 74 | 75 | if group_name not in parameter_group_names: 76 | if get_layer_scale is not None: 77 | scale = get_layer_scale(layer_id) 78 | else: 79 | scale = 1. 80 | 81 | parameter_group_names[group_name] = { 82 | "weight_decay": this_weight_decay, 83 | "params": [], 84 | "lr_scale": scale 85 | } 86 | parameter_group_vars[group_name] = { 87 | "weight_decay": this_weight_decay, 88 | "params": [], 89 | "lr_scale": scale 90 | } 91 | 92 | parameter_group_vars[group_name]["params"].append(param) 93 | parameter_group_names[group_name]["params"].append(name) 94 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 95 | return list(parameter_group_vars.values()) 96 | 97 | 98 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 99 | opt_lower = args.opt.lower() 100 | weight_decay = args.weight_decay 101 | if weight_decay and filter_bias_and_bn: 102 | skip = {} 103 | if skip_list is not None: 104 | skip = skip_list 105 | elif hasattr(model, 'no_weight_decay'): 106 | skip = model.no_weight_decay() 107 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 108 | weight_decay = 0. 109 | else: 110 | parameters = model.parameters() 111 | 112 | if 'fused' in opt_lower: 113 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 114 | 115 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 116 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 117 | opt_args['eps'] = args.opt_eps 118 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 119 | opt_args['betas'] = args.opt_betas 120 | 121 | print("optimizer settings:", opt_args) 122 | 123 | opt_split = opt_lower.split('_') 124 | opt_lower = opt_split[-1] 125 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 126 | opt_args.pop('eps', None) 127 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 128 | elif opt_lower == 'momentum': 129 | opt_args.pop('eps', None) 130 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 131 | elif opt_lower == 'adam': 132 | optimizer = optim.Adam(parameters, **opt_args) 133 | elif opt_lower == 'adamw': 134 | optimizer = optim.AdamW(parameters, **opt_args) 135 | elif opt_lower == 'nadam': 136 | optimizer = Nadam(parameters, **opt_args) 137 | elif opt_lower == 'radam': 138 | optimizer = RAdam(parameters, **opt_args) 139 | elif opt_lower == 'adamp': 140 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 141 | elif opt_lower == 'sgdp': 142 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 143 | elif opt_lower == 'adadelta': 144 | optimizer = optim.Adadelta(parameters, **opt_args) 145 | elif opt_lower == 'adafactor': 146 | if not args.lr: 147 | opt_args['lr'] = None 148 | optimizer = Adafactor(parameters, **opt_args) 149 | elif opt_lower == 'adahessian': 150 | optimizer = Adahessian(parameters, **opt_args) 151 | elif opt_lower == 'rmsprop': 152 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 153 | elif 
opt_lower == 'rmsproptf': 154 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 155 | elif opt_lower == 'novograd': 156 | optimizer = NovoGrad(parameters, **opt_args) 157 | elif opt_lower == 'nvnovograd': 158 | optimizer = NvNovoGrad(parameters, **opt_args) 159 | elif opt_lower == 'fusedsgd': 160 | opt_args.pop('eps', None) 161 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 162 | elif opt_lower == 'fusedmomentum': 163 | opt_args.pop('eps', None) 164 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 165 | elif opt_lower == 'fusedadam': 166 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 167 | elif opt_lower == 'fusedadamw': 168 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 169 | elif opt_lower == 'fusedlamb': 170 | optimizer = FusedLAMB(parameters, **opt_args) 171 | elif opt_lower == 'fusednovograd': 172 | opt_args.setdefault('betas', (0.95, 0.98)) 173 | optimizer = FusedNovoGrad(parameters, **opt_args) 174 | else: 175 | # fail loudly on an unrecognized optimizer name 176 | raise ValueError("Invalid optimizer: %s" % opt_lower) 177 | 178 | if len(opt_split) > 1: 179 | if opt_split[0] == 'lookahead': 180 | optimizer = Lookahead(optimizer) 181 | 182 | return optimizer 183 | -------------------------------------------------------------------------------- /finetune/run_class_finetuning.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | 9 | import argparse 10 | import datetime 11 | import numpy as np 12 | import time 13 | import torch 14 | import torch.backends.cudnn as cudnn 15 | import json 16 | import os 17 | 18 | from pathlib import Path 19 | from collections import OrderedDict 20 | 21 | from timm.data.mixup import Mixup 22 | from timm.models import create_model 23 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 24 | from timm.utils import ModelEma 25 | from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner 26 | 27 | from datasets import build_dataset 28 | from engine_for_finetuning import train_one_epoch, evaluate 29 | from utils import NativeScalerWithGradNormCount as NativeScaler 30 | import utils 31 | #from scipy import interpolate 32 | import modeling_finetune 33 | 34 | 35 | def get_args(): 36 | parser = argparse.ArgumentParser('MAE fine-tuning and evaluation script for image classification', add_help=False) 37 | parser.add_argument('--batch_size', default=64, type=int) 38 | parser.add_argument('--epochs', default=30, type=int) 39 | parser.add_argument('--update_freq', default=1, type=int) 40 | parser.add_argument('--save_ckpt_freq', default=20, type=int) 41 | 42 | # Model parameters 43 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 44 | help='Name of model to train') 45 | 46 | parser.add_argument('--input_size', default=224, type=int, 47 | help='images input size') 48 | 49 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 50 | help='Dropout rate (default: 0.)') 51 | parser.add_argument('--attn_drop_rate', type=float, 
default=0.0, metavar='PCT', 52 | help='Attention dropout rate (default: 0.)') 53 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 54 | help='Drop path rate (default: 0.1)') 55 | 56 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) 57 | 58 | parser.add_argument('--model_ema', action='store_true', default=False) 59 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='EMA decay factor (default: 0.9999)') 60 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='keep the EMA weights on CPU') 61 | 62 | # Optimizer parameters 63 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 64 | help='Optimizer (default: "adamw")') 65 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 66 | help='Optimizer Epsilon (default: 1e-8)') 67 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 68 | help='Optimizer Betas (default: None, use opt default)') 69 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 70 | help='Clip gradient norm (default: None, no clipping)') 71 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 72 | help='SGD momentum (default: 0.9)') 73 | parser.add_argument('--weight_decay', type=float, default=0.05, 74 | help='weight decay (default: 0.05)') 75 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 76 | weight decay. We use a cosine schedule for WD and using a larger decay by 77 | the end of training improves performance for ViTs.""") 78 | 79 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 80 | help='learning rate (default: 1e-3)') 81 | parser.add_argument('--layer_decay', type=float, default=0.75) 82 | 83 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 84 | help='warmup learning rate (default: 1e-6)') 85 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 86 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)') 87 | 88 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 89 | help='epochs to warmup LR, if scheduler supports') 90 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 91 | help='num of steps to warmup LR; overrides warmup_epochs if set > 0') 92 | 93 | # Augmentation parameters 94 | # TODO: color jitter differs in the MAE paper. 95 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 96 | help='Color jitter factor (default: 0.4)') 97 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 98 | help='Use AutoAugment policy, e.g. "v0" or "original" (default: rand-m9-mstd0.5-inc1)') 99 | parser.add_argument('--smoothing', type=float, default=0.1, 100 | help='Label smoothing (default: 0.1)') 101 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 102 | help='Training interpolation: random, bilinear, bicubic (default: "bicubic")') 103 | 104 | # Evaluation parameters 105 | parser.add_argument('--crop_pct', type=float, default=None) 106 | 107 | # * Random Erase params 108 | # TODO: This is not used in MAE. 
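# (Assumption, following the DeiT/BEiT recipes this script is based on: the flags below feed timm's random-erasing transform, and passing --reprob 0 disables it.)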
109 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 110 | help='Random erase prob (default: 0.25)') 111 | parser.add_argument('--remode', type=str, default='pixel', 112 | help='Random erase mode (default: "pixel")') 113 | parser.add_argument('--recount', type=int, default=1, 114 | help='Random erase count (default: 1)') 115 | parser.add_argument('--resplit', action='store_true', default=False, 116 | help='Do not random erase first (clean) augmentation split') 117 | 118 | # * Mixup params 119 | parser.add_argument('--mixup', type=float, default=0.8, 120 | help='mixup alpha, mixup enabled if > 0.') 121 | parser.add_argument('--cutmix', type=float, default=1.0, 122 | help='cutmix alpha, cutmix enabled if > 0.') 123 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 124 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 125 | parser.add_argument('--mixup_prob', type=float, default=1.0, 126 | help='Probability of performing mixup or cutmix when either/both is enabled') 127 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 128 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 129 | parser.add_argument('--mixup_mode', type=str, default='batch', 130 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 131 | 132 | # * Finetuning params 133 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 134 | parser.add_argument('--model_key', default='model|module|state_dict|teacher', type=str) 135 | parser.add_argument('--eval_momentum', action='store_true', 136 | help='evaluation momentum encoder') 137 | parser.add_argument('--model_prefix', default='', type=str) 138 | parser.add_argument('--init_scale', default=0.001, type=float) 139 | parser.add_argument('--use_mean_pooling', action='store_true') 140 | parser.set_defaults(use_mean_pooling=True) 141 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') 142 | # TODO: we need to try both. We also need to consider adding batch norm. 143 | # TODO: What does init scale mean? 
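# Note (an assumption based on the BEiT-style modeling_finetune this script imports, not verified here):
# init_scale typically rescales the freshly initialized classification head,
# e.g. head.weight.data.mul_(init_scale), so the logits start near zero at the
# beginning of fine-tuning.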
144 | 145 | # Dataset parameters 146 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 147 | help='dataset path') 148 | parser.add_argument('--eval_data_path', default=None, type=str, 149 | help='dataset path for evaluation') 150 | parser.add_argument('--nb_classes', default=1000, type=int, 151 | help='number of classification classes') 152 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 153 | 154 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 155 | type=str, help='which dataset to use') 156 | parser.add_argument('--output_dir', default='', 157 | help='path where to save, empty for no saving') 158 | parser.add_argument('--log_dir', default=None, 159 | help='path for the tensorboard log') 160 | parser.add_argument('--device', default='cuda', 161 | help='device to use for training / testing') 162 | parser.add_argument('--seed', default=0, type=int) 163 | parser.add_argument('--resume', default='', 164 | help='resume from checkpoint') 165 | parser.add_argument('--auto_resume', action='store_true') 166 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 167 | parser.set_defaults(auto_resume=True) 168 | 169 | parser.add_argument('--save_ckpt', action='store_true') 170 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') 171 | parser.set_defaults(save_ckpt=True) 172 | 173 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 174 | help='start epoch') 175 | parser.add_argument('--eval', action='store_true', 176 | help='Perform evaluation only') 177 | parser.add_argument('--dist_eval', action='store_true', default=False, 178 | help='Enabling distributed evaluation') 179 | parser.add_argument('--num_workers', default=4, type=int) 180 | parser.add_argument('--pin_mem', action='store_true', 181 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 182 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 183 | parser.set_defaults(pin_mem=True) 184 | 185 | # distributed training parameters 186 | parser.add_argument('--world_size', default=1, type=int, 187 | help='number of distributed processes') 188 | parser.add_argument('--local_rank', default=-1, type=int) 189 | parser.add_argument('--dist_on_itp', action='store_true') 190 | parser.add_argument('--dist_url', default='env://', 191 | help='url used to set up distributed training') 192 | 193 | parser.add_argument('--enable_deepspeed', action='store_true', default=False) 194 | 195 | known_args, _ = parser.parse_known_args() 196 | 197 | if known_args.enable_deepspeed: 198 | try: 199 | import deepspeed 200 | from deepspeed import DeepSpeedConfig 201 | parser = deepspeed.add_config_arguments(parser) 202 | ds_init = deepspeed.initialize 203 | except ImportError: 204 | print("Please 'pip install deepspeed==0.4.0'") 205 | exit(0) 206 | else: 207 | ds_init = None 208 | 209 | return parser.parse_args(), ds_init 210 | 211 | 212 | def main(args, ds_init): 213 | utils.init_distributed_mode(args) 214 | 215 | if ds_init is not None: 216 | utils.create_ds_config(args) 217 | 218 | print(args) 219 | 220 | device = torch.device(args.device) 221 | 222 | # fix the seed for reproducibility 223 | seed = args.seed + utils.get_rank() 224 | torch.manual_seed(seed) 225 | np.random.seed(seed) 226 | # random.seed(seed) 227 | 228 | cudnn.benchmark = True 229 | 230 | dataset_train, args.nb_classes = 
build_dataset(is_train=True, args=args) 231 | if args.disable_eval_during_finetuning: 232 | dataset_val = None 233 | else: 234 | dataset_val, _ = build_dataset(is_train=False, args=args) 235 | 236 | if True: # args.distributed: 237 | num_tasks = utils.get_world_size() 238 | global_rank = utils.get_rank() 239 | sampler_train = torch.utils.data.DistributedSampler( 240 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 241 | ) 242 | print("Sampler_train = %s" % str(sampler_train)) 243 | if args.dist_eval: 244 | if len(dataset_val) % num_tasks != 0: 245 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 246 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 247 | 'equal num of samples per-process.') 248 | sampler_val = torch.utils.data.DistributedSampler( 249 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 250 | else: 251 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 252 | else: 253 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 254 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 255 | 256 | if global_rank == 0 and args.log_dir is not None: 257 | os.makedirs(args.log_dir, exist_ok=True) 258 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 259 | else: 260 | log_writer = None 261 | 262 | data_loader_train = torch.utils.data.DataLoader( 263 | dataset_train, sampler=sampler_train, 264 | batch_size=args.batch_size, 265 | num_workers=args.num_workers, 266 | pin_memory=args.pin_mem, 267 | drop_last=True, 268 | ) 269 | 270 | if dataset_val is not None: 271 | data_loader_val = torch.utils.data.DataLoader( 272 | dataset_val, sampler=sampler_val, 273 | batch_size=int(1.5 * args.batch_size), 274 | num_workers=args.num_workers, 275 | pin_memory=args.pin_mem, 276 | drop_last=False 277 | ) 278 | else: 279 | data_loader_val = None 280 | 281 | mixup_fn = None 282 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 283 | if mixup_active: 284 | print("Mixup is activated!") 285 | mixup_fn = Mixup( 286 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 287 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 288 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 289 | 290 | model = create_model( 291 | args.model, 292 | pretrained=False, 293 | num_classes=args.nb_classes, 294 | drop_rate=args.drop, 295 | drop_path_rate=args.drop_path, 296 | attn_drop_rate=args.attn_drop_rate, 297 | drop_block_rate=None, 298 | use_mean_pooling=args.use_mean_pooling, 299 | init_scale=args.init_scale, 300 | use_learnable_pos_emb=True, 301 | ) 302 | 303 | patch_size = model.patch_embed.patch_size 304 | print("Patch size = %s" % str(patch_size)) 305 | args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1]) 306 | args.patch_size = patch_size 307 | 308 | if args.finetune: 309 | if args.finetune.startswith('https'): 310 | checkpoint = torch.hub.load_state_dict_from_url( 311 | args.finetune, map_location='cpu', check_hash=True) 312 | else: 313 | checkpoint = torch.load(args.finetune, map_location='cpu') 314 | 315 | print("Load ckpt from %s" % args.finetune) 316 | checkpoint_model = None 317 | for model_key in args.model_key.split('|'): 318 | if model_key in checkpoint: 319 | checkpoint_model = checkpoint[model_key] 320 | print("Load state_dict by model_key = %s" % model_key) 321 | break 322 | if checkpoint_model is None: 323 | checkpoint_model = checkpoint 324 | state_dict = model.state_dict() 325 | for k in ['head.weight', 'head.bias']: 326 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 327 | print(f"Removing key {k} from pretrained checkpoint") 328 | del checkpoint_model[k] 329 | 330 | all_keys = list(checkpoint_model.keys()) 331 | new_dict = OrderedDict() 332 | for key in all_keys: 333 | if args.eval_momentum: 334 | if key.startswith('module.momentum_encoder.'): 335 | new_dict[key[24:]] = checkpoint_model[key] 336 | else: 337 | if key.startswith('backbone.'): 338 | new_dict[key[9:]] = checkpoint_model[key] 339 | elif key.startswith('encoder.'): 340 | new_dict[key[8:]] = checkpoint_model[key] 341 | elif key.startswith('module.base_encoder.'): 342 | new_dict[key[20:]] = checkpoint_model[key] 343 | elif key.startswith('module.visual.'): 344 | new_dict[key[14:]] = checkpoint_model[key] 345 | else: 346 | new_dict[key] = checkpoint_model[key] 347 | checkpoint_model = new_dict 348 | 349 | # interpolate position embedding 350 | if 'pos_embed' in checkpoint_model: 351 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 352 | embedding_size = pos_embed_checkpoint.shape[-1] 353 | num_patches = model.patch_embed.num_patches 354 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 355 | # height (== width) for the checkpoint position embedding 356 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 357 | # height (== width) for the new position embedding 358 | new_size = int(num_patches ** 0.5) 359 | # class_token and dist_token are kept unchanged 360 | if orig_size != new_size: 361 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 362 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 363 | # only the position tokens are interpolated 364 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 365 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, 
embedding_size).permute(0, 3, 1, 2) 366 | pos_tokens = torch.nn.functional.interpolate( 367 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 368 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 369 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 370 | checkpoint_model['pos_embed'] = new_pos_embed 371 | 372 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 373 | # model.load_state_dict(checkpoint_model, strict=False) 374 | 375 | model.to(device) 376 | 377 | model_ema = None 378 | if args.model_ema: 379 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 380 | model_ema = ModelEma( 381 | model, 382 | decay=args.model_ema_decay, 383 | device='cpu' if args.model_ema_force_cpu else '', 384 | resume='') 385 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 386 | 387 | model_without_ddp = model 388 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 389 | 390 | print("Model = %s" % str(model_without_ddp)) 391 | print('number of params:', n_parameters) 392 | 393 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 394 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 395 | args.lr = args.lr * total_batch_size / 256 396 | print("LR = %.8f" % args.lr) 397 | print("Batch size = %d" % total_batch_size) 398 | print("Update frequency = %d" % args.update_freq) 399 | print("Number of training examples = %d" % len(dataset_train)) 400 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch) 401 | 402 | num_layers = model_without_ddp.get_num_layers() 403 | if args.layer_decay < 1.0: 404 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 405 | else: 406 | assigner = None 407 | 408 | if assigner is not None: 409 | print("Assigned values = %s" % str(assigner.values)) 410 | 411 | skip_weight_decay_list = model.no_weight_decay() 412 | print("Skip weight decay list: ", skip_weight_decay_list) 413 | 414 | if args.enable_deepspeed: 415 | loss_scaler = None 416 | optimizer_params = get_parameter_groups( 417 | model, args.weight_decay, skip_weight_decay_list, 418 | assigner.get_layer_id if assigner is not None else None, 419 | assigner.get_scale if assigner is not None else None) 420 | model, optimizer, _, _ = ds_init( 421 | args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed, 422 | ) 423 | 424 | print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps()) 425 | assert model.gradient_accumulation_steps() == args.update_freq 426 | else: 427 | if args.distributed: 428 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 429 | model_without_ddp = model.module 430 | 431 | optimizer = create_optimizer( 432 | args, model_without_ddp, skip_list=skip_weight_decay_list, 433 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 434 | get_layer_scale=assigner.get_scale if assigner is not None else None) 435 | loss_scaler = NativeScaler() 436 | 437 | print("Use step level LR scheduler!") 438 | lr_schedule_values = utils.cosine_scheduler( 439 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 440 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 441 | ) 442 | if args.weight_decay_end is None: 443 | args.weight_decay_end = args.weight_decay 444 
| wd_schedule_values = utils.cosine_scheduler( 445 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 446 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 447 | 448 | if mixup_fn is not None: 449 | # smoothing is handled with mixup label transform 450 | criterion = SoftTargetCrossEntropy() 451 | elif args.smoothing > 0.: 452 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 453 | else: 454 | criterion = torch.nn.CrossEntropyLoss() 455 | 456 | print("criterion = %s" % str(criterion)) 457 | 458 | utils.auto_load_model( 459 | args=args, model=model, model_without_ddp=model_without_ddp, 460 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 461 | 462 | if args.eval: 463 | test_stats = evaluate(data_loader_val, model, device) 464 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 465 | exit(0) 466 | 467 | print(f"Start training for {args.epochs} epochs") 468 | start_time = time.time() 469 | max_accuracy = 0.0 470 | for epoch in range(args.start_epoch, args.epochs): 471 | if args.distributed: 472 | data_loader_train.sampler.set_epoch(epoch) 473 | if log_writer is not None: 474 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 475 | train_stats = train_one_epoch( 476 | model, criterion, data_loader_train, optimizer, 477 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 478 | log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch, 479 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 480 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 481 | ) 482 | if args.output_dir and args.save_ckpt: 483 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 484 | utils.save_model( 485 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 486 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 487 | if data_loader_val is not None: 488 | test_stats = evaluate(data_loader_val, model, device) 489 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 490 | if max_accuracy < test_stats["acc1"]: 491 | max_accuracy = test_stats["acc1"] 492 | if args.output_dir and args.save_ckpt: 493 | utils.save_model( 494 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 495 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 496 | 497 | print(f'Max accuracy: {max_accuracy:.2f}%') 498 | if log_writer is not None: 499 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 500 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 501 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 502 | 503 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 504 | **{f'test_{k}': v for k, v in test_stats.items()}, 505 | 'epoch': epoch, 506 | 'n_parameters': n_parameters} 507 | else: 508 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 509 | # **{f'test_{k}': v for k, v in test_stats.items()}, 510 | 'epoch': epoch, 511 | 'n_parameters': n_parameters} 512 | 513 | if args.log_dir and utils.is_main_process(): 514 | if log_writer is not None: 515 | log_writer.flush() 516 | with open(os.path.join(args.log_dir, "log.txt"), mode="a", encoding="utf-8") as f: 517 | f.write(json.dumps(log_stats) + "\n") 518 | 519 | 
total_time = time.time() - start_time 520 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 521 | print('Training time {}'.format(total_time_str)) 522 | 523 | 524 | if __name__ == '__main__': 525 | opts, ds_init = get_args() 526 | if opts.output_dir: 527 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True) 528 | main(opts, ds_init) 529 | -------------------------------------------------------------------------------- /finetune/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # -------------------------------------------------------- 3 | # Based on BEiT, timm, DINO and DeiT code bases 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 6 | # https://github.com/facebookresearch/deit 7 | # https://github.com/facebookresearch/dino 8 | # --------------------------------------------------------' 9 | import io 10 | import os 11 | import math 12 | import time 13 | import json 14 | from collections import defaultdict, deque 15 | import datetime 16 | import numpy as np 17 | from timm.utils import get_state_dict 18 | 19 | from pathlib import Path 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch._six import inf 24 | 25 | import random 26 | 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | class MetricLogger(object): 93 | def __init__(self, delimiter="\t"): 94 | self.meters = defaultdict(SmoothedValue) 95 | self.delimiter = delimiter 96 | 97 | def update(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if v is None: 100 | continue 101 | if isinstance(v, torch.Tensor): 102 | v = v.item() 103 | assert isinstance(v, (float, int)) 104 | self.meters[k].update(v) 105 | 106 | def __getattr__(self, attr): 107 | if attr in self.meters: 108 | return self.meters[attr] 109 | if attr in self.__dict__: 110 | return self.__dict__[attr] 111 | raise AttributeError("'{}' object has no attribute '{}'".format( 112 | type(self).__name__, attr)) 113 | 114 | def __str__(self): 115 | loss_str = [] 116 | for name, meter in self.meters.items(): 117 | loss_str.append( 118 | "{}: {}".format(name, str(meter)) 119 | ) 120 | return self.delimiter.join(loss_str) 121 | 122 | def synchronize_between_processes(self): 123 | for meter in self.meters.values(): 124 | meter.synchronize_between_processes() 125 | 126 | def add_meter(self, name, meter): 127 | self.meters[name] = meter 128 | 129 | def log_every(self, iterable, print_freq, header=None): 130 | i = 0 131 | if not header: 132 | header = '' 133 | start_time = time.time() 134 | end = time.time() 135 | iter_time = SmoothedValue(fmt='{avg:.4f}') 136 | data_time = SmoothedValue(fmt='{avg:.4f}') 137 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 138 | log_msg = [ 139 | header, 140 | '[{0' + space_fmt + '}/{1}]', 141 | 'eta: {eta}', 142 | '{meters}', 143 | 'time: {time}', 144 | 'data: {data}' 145 | ] 146 | if torch.cuda.is_available(): 147 | log_msg.append('max mem: {memory:.0f}') 148 | log_msg = self.delimiter.join(log_msg) 149 | MB = 1024.0 * 1024.0 150 | for obj in iterable: 151 | data_time.update(time.time() - end) 152 | yield obj 153 | iter_time.update(time.time() - end) 154 | if i % print_freq == 0 or i == len(iterable) - 1: 155 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 156 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 157 | if torch.cuda.is_available(): 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time), 162 | memory=torch.cuda.max_memory_allocated() / MB)) 163 | else: 164 | print(log_msg.format( 165 | i, len(iterable), eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), data=str(data_time))) 168 | i += 1 169 | end = time.time() 170 | total_time = time.time() - start_time 171 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 172 | print('{} Total time: {} ({:.4f} s / it)'.format( 173 | 
header, total_time_str, total_time / len(iterable))) 174 | 175 | 176 | class TensorboardLogger(object): 177 | def __init__(self, log_dir): 178 | self.writer = SummaryWriter(logdir=log_dir) 179 | self.step = 0 180 | 181 | def set_step(self, step=None): 182 | if step is not None: 183 | self.step = step 184 | else: 185 | self.step += 1 186 | 187 | def update(self, head='scalar', step=None, **kwargs): 188 | for k, v in kwargs.items(): 189 | if v is None: 190 | continue 191 | if isinstance(v, torch.Tensor): 192 | v = v.item() 193 | assert isinstance(v, (float, int)) 194 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 195 | 196 | def flush(self): 197 | self.writer.flush() 198 | 199 | def seed_worker(worker_id): 200 | worker_seed = torch.initial_seed() % 2**32 201 | np.random.seed(worker_seed) 202 | random.seed(worker_seed) 203 | 204 | def _load_checkpoint_for_ema(model_ema, checkpoint): 205 | """ 206 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 207 | """ 208 | mem_file = io.BytesIO() 209 | torch.save(checkpoint, mem_file) 210 | mem_file.seek(0) 211 | model_ema._load_checkpoint(mem_file) 212 | 213 | 214 | def setup_for_distributed(is_master): 215 | """ 216 | This function disables printing when not in master process 217 | """ 218 | import builtins as __builtin__ 219 | builtin_print = __builtin__.print 220 | 221 | def print(*args, **kwargs): 222 | force = kwargs.pop('force', False) 223 | if is_master or force: 224 | builtin_print(*args, **kwargs) 225 | 226 | __builtin__.print = print 227 | 228 | 229 | def is_dist_avail_and_initialized(): 230 | if not dist.is_available(): 231 | return False 232 | if not dist.is_initialized(): 233 | return False 234 | return True 235 | 236 | 237 | def get_world_size(): 238 | if not is_dist_avail_and_initialized(): 239 | return 1 240 | return dist.get_world_size() 241 | 242 | 243 | def get_rank(): 244 | if not is_dist_avail_and_initialized(): 245 | return 0 246 | return dist.get_rank() 247 | 248 | 249 | def is_main_process(): 250 | return get_rank() == 0 251 | 252 | 253 | def save_on_master(*args, **kwargs): 254 | if is_main_process(): 255 | torch.save(*args, **kwargs) 256 | 257 | 258 | def init_distributed_mode(args): 259 | if args.dist_on_itp: 260 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 261 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 262 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 263 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 264 | os.environ['LOCAL_RANK'] = str(args.gpu) 265 | os.environ['RANK'] = str(args.rank) 266 | os.environ['WORLD_SIZE'] = str(args.world_size) 267 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 268 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 269 | args.rank = int(os.environ["RANK"]) 270 | args.world_size = int(os.environ['WORLD_SIZE']) 271 | args.gpu = int(os.environ['LOCAL_RANK']) 272 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 273 | elif 'SLURM_PROCID' in os.environ: 274 | args.rank = int(os.environ['SLURM_PROCID']) 275 | args.gpu = args.rank % torch.cuda.device_count() 276 | else: 277 | print('Not using distributed mode') 278 | args.distributed = False 279 | return 280 | 281 | args.distributed = True 282 | 283 | torch.cuda.set_device(args.gpu) 284 | args.dist_backend = 'nccl' 285 | print('| distributed init (rank {}): {}, gpu {}'.format( 286 | args.rank, args.dist_url, args.gpu), flush=True) 287 | 
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 288 | world_size=args.world_size, rank=args.rank) 289 | torch.distributed.barrier() 290 | setup_for_distributed(args.rank == 0) 291 | 292 | 293 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 294 | missing_keys = [] 295 | unexpected_keys = [] 296 | error_msgs = [] 297 | # copy state_dict so _load_from_state_dict can modify it 298 | metadata = getattr(state_dict, '_metadata', None) 299 | state_dict = state_dict.copy() 300 | if metadata is not None: 301 | state_dict._metadata = metadata 302 | 303 | def load(module, prefix=''): 304 | local_metadata = {} if metadata is None else metadata.get( 305 | prefix[:-1], {}) 306 | module._load_from_state_dict( 307 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 308 | for name, child in module._modules.items(): 309 | if child is not None: 310 | load(child, prefix + name + '.') 311 | 312 | load(model, prefix=prefix) 313 | 314 | warn_missing_keys = [] 315 | ignore_missing_keys = [] 316 | for key in missing_keys: 317 | keep_flag = True 318 | for ignore_key in ignore_missing.split('|'): 319 | if ignore_key in key: 320 | keep_flag = False 321 | break 322 | if keep_flag: 323 | warn_missing_keys.append(key) 324 | else: 325 | ignore_missing_keys.append(key) 326 | 327 | missing_keys = warn_missing_keys 328 | 329 | if len(missing_keys) > 0: 330 | print("Weights of {} not initialized from pretrained model: {}".format( 331 | model.__class__.__name__, missing_keys)) 332 | if len(unexpected_keys) > 0: 333 | print("Weights from pretrained model not used in {}: {}".format( 334 | model.__class__.__name__, unexpected_keys)) 335 | if len(ignore_missing_keys) > 0: 336 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 337 | model.__class__.__name__, ignore_missing_keys)) 338 | if len(error_msgs) > 0: 339 | print('\n'.join(error_msgs)) 340 | 341 | 342 | class NativeScalerWithGradNormCount: 343 | state_dict_key = "amp_scaler" 344 | 345 | def __init__(self): 346 | self._scaler = torch.cuda.amp.GradScaler() 347 | 348 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 349 | self._scaler.scale(loss).backward(create_graph=create_graph) 350 | if update_grad: 351 | if clip_grad is not None: 352 | assert parameters is not None 353 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 354 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 355 | else: 356 | self._scaler.unscale_(optimizer) 357 | norm = get_grad_norm_(parameters) 358 | self._scaler.step(optimizer) 359 | self._scaler.update() 360 | else: 361 | norm = None 362 | return norm 363 | 364 | def state_dict(self): 365 | return self._scaler.state_dict() 366 | 367 | def load_state_dict(self, state_dict): 368 | self._scaler.load_state_dict(state_dict) 369 | 370 | 371 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 372 | if isinstance(parameters, torch.Tensor): 373 | parameters = [parameters] 374 | parameters = [p for p in parameters if p.grad is not None] 375 | norm_type = float(norm_type) 376 | if len(parameters) == 0: 377 | return torch.tensor(0.) 
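# otherwise aggregate the per-parameter gradient norms on one device: a max for the inf-norm, a norm of norms for any other p-norm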
378 | device = parameters[0].grad.device 379 | if norm_type == inf: 380 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 381 | else: 382 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 383 | return total_norm 384 | 385 | 386 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 387 | start_warmup_value=0, warmup_steps=-1): 388 | warmup_schedule = np.array([]) 389 | warmup_iters = warmup_epochs * niter_per_ep 390 | if warmup_steps > 0: 391 | warmup_iters = warmup_steps 392 | print("Set warmup steps = %d" % warmup_iters) 393 | if warmup_iters > 0: # honor warmup_steps even when warmup_epochs == 0 394 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 395 | 396 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 397 | schedule = np.array( 398 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 399 | 400 | schedule = np.concatenate((warmup_schedule, schedule)) 401 | 402 | assert len(schedule) == epochs * niter_per_ep 403 | return schedule 404 | 405 | 406 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 407 | output_dir = Path(args.output_dir) 408 | epoch_name = str(epoch) 409 | if loss_scaler is not None: 410 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 411 | for checkpoint_path in checkpoint_paths: 412 | to_save = { 413 | 'model': model_without_ddp.state_dict(), 414 | 'optimizer': optimizer.state_dict(), 415 | 'epoch': epoch, 416 | 'scaler': loss_scaler.state_dict(), 417 | 'args': args, 418 | } 419 | 420 | if model_ema is not None: 421 | to_save['model_ema'] = get_state_dict(model_ema) 422 | 423 | save_on_master(to_save, checkpoint_path) 424 | else: 425 | client_state = {'epoch': epoch} 426 | if model_ema is not None: 427 | client_state['model_ema'] = get_state_dict(model_ema) 428 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 429 | 430 | 431 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 432 | output_dir = Path(args.output_dir) 433 | if loss_scaler is not None: 434 | # torch.amp 435 | if args.auto_resume and len(args.resume) == 0: 436 | import glob 437 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 438 | latest_ckpt = -1 439 | for ckpt in all_checkpoints: 440 | t = ckpt.split('-')[-1].split('.')[0] 441 | if t.isdigit(): 442 | latest_ckpt = max(int(t), latest_ckpt) 443 | if latest_ckpt >= 0: 444 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 445 | print("Auto resume checkpoint: %s" % args.resume) 446 | 447 | if args.resume: 448 | if args.resume.startswith('https'): 449 | checkpoint = torch.hub.load_state_dict_from_url( 450 | args.resume, map_location='cpu', check_hash=True) 451 | else: 452 | checkpoint = torch.load(args.resume, map_location='cpu') 453 | model_without_ddp.load_state_dict(checkpoint['model']) 454 | print("Resume checkpoint %s" % args.resume) 455 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 456 | optimizer.load_state_dict(checkpoint['optimizer']) 457 | args.start_epoch = checkpoint['epoch'] + 1 458 | if hasattr(args, 'model_ema') and args.model_ema: 459 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 460 | if 'scaler' in checkpoint: 461 | loss_scaler.load_state_dict(checkpoint['scaler']) 462 | print("With optim & sched!") 463 | else: 
464 | # deepspeed only supports '--auto_resume'. 465 | if args.auto_resume: 466 | import glob 467 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) 468 | latest_ckpt = -1 469 | for ckpt in all_checkpoints: 470 | t = ckpt.split('-')[-1].split('.')[0] 471 | if t.isdigit(): 472 | latest_ckpt = max(int(t), latest_ckpt) 473 | if latest_ckpt >= 0: 474 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) 475 | print("Auto resume checkpoint: %d" % latest_ckpt) 476 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) 477 | args.start_epoch = client_states['epoch'] + 1 478 | if model_ema is not None: 479 | if args.model_ema: 480 | _load_checkpoint_for_ema(model_ema, client_states['model_ema']) 481 | 482 | 483 | def create_ds_config(args): 484 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 485 | with open(args.deepspeed_config, mode="w") as writer: 486 | ds_config = { 487 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(), 488 | "train_micro_batch_size_per_gpu": args.batch_size, 489 | "steps_per_print": 1000, 490 | "optimizer": { 491 | "type": "Adam", 492 | "adam_w_mode": True, 493 | "params": { 494 | "lr": args.lr, 495 | "weight_decay": args.weight_decay, 496 | "bias_correction": True, 497 | "betas": [ 498 | 0.9, 499 | 0.999 500 | ], 501 | "eps": 1e-8 502 | } 503 | }, 504 | "fp16": { 505 | "enabled": True, 506 | "loss_scale": 0, 507 | "initial_scale_power": 7, 508 | "loss_scale_window": 128 509 | } 510 | } 511 | 512 | writer.write(json.dumps(ds_config, indent=2)) 513 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/ExtreMA/bf30b64ec0046524d0c373b761df323b9328f8dc/lib/__init__.py -------------------------------------------------------------------------------- /lib/augment.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import numpy as np 3 | 4 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 5 | std=[0.229, 0.224, 0.225]) 6 | 7 | class SpatialConsistentColorAug: 8 | """Take one random crop of an image and apply two independent color augmentations to it""" 9 | 10 | def __init__(self, crop_min): 11 | self.base_transform = transforms.Compose([ 12 | transforms.RandomResizedCrop(224, scale=(crop_min, 1.)), 13 | transforms.RandomHorizontalFlip()]) 14 | self.color_aug = transforms.Compose([transforms.RandomApply([ 15 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 16 | ], p=0.8), 17 | transforms.RandomGrayscale(p=0.2), 18 | transforms.ToTensor(), 19 | normalize]) 20 | self.to_tensor = transforms.Compose([transforms.ToTensor(), 21 | normalize]) 22 | 23 | def __call__(self, x): 24 | im = self.base_transform(x) 25 | im1 = self.color_aug(im) 26 | im2 = self.color_aug(im) 27 | return [im1, im2] 28 | 29 | def create_augmentation(args): 30 | 31 | if args.aug_centercrop: 32 | augmentation = [ 33 | transforms.Resize(256), 34 | transforms.CenterCrop(224), 35 | transforms.ToTensor(), 36 | normalize 37 | ] 38 | 39 | if args.aug_spatialconsistent_color: 40 | return SpatialConsistentColorAug(args.crop_min) 41 | 42 | if args.aug_spatial: 43 | augmentation = [ 44 | transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | normalize 48 | ] 49 | 50 | return 
transforms.Compose(augmentation) 51 | -------------------------------------------------------------------------------- /lib/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import random 5 | import torch.nn.functional as F 6 | import torch.distributed as dist 7 | 8 | class ExtreMA(nn.Module): 9 | """ 10 | Build the ExtreMA model: a student encoder trained to match a momentum (EMA) teacher 11 | """ 12 | def __init__(self, base_encoder, ema_encoder, proj_dim=256, mlp_dim=4096, T=1., mask_ratio=0.8, num_masks=1, disjoint=True): 13 | """ 14 | proj_dim: projector output dimension (default: 256) 15 | mlp_dim: hidden dimension in MLPs (default: 4096) 16 | T: softmax temperature (default: 1.0) 17 | """ 18 | super(ExtreMA, self).__init__() 19 | 20 | self.T = T 21 | self.mask_ratio = mask_ratio 22 | self.num_masks = num_masks 23 | self.disjoint_sampling = disjoint 24 | 25 | # build encoders 26 | self.base_encoder = base_encoder(num_classes=mlp_dim) 27 | self.momentum_encoder = ema_encoder(num_classes=mlp_dim) 28 | self.base_encoder.student = True 29 | self.momentum_encoder.student = False 30 | 31 | hidden_dim = self.base_encoder.norm.weight.data.shape[0] 32 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer 33 | # projectors 34 | self.base_encoder.head = self._build_proj(3, hidden_dim, mlp_dim, proj_dim) 35 | self.momentum_encoder.head = self._build_proj(3, hidden_dim, mlp_dim, proj_dim) 36 | # predictor 37 | self.predictor = self._build_pred(2, proj_dim, mlp_dim, proj_dim) 38 | 39 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 40 | param_m.data.copy_(param_b.data) # initialize 41 | param_m.requires_grad = False # not updated by gradient 42 | 43 | def _build_pred(self, num_layers, input_dim, mlp_dim, output_dim): 44 | layers = [] 45 | for l in range(num_layers): 46 | dim1 = input_dim if l == 0 else mlp_dim 47 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 48 | 49 | layers.append(nn.Linear(dim1, dim2, bias=True)) 50 | 51 | if l < num_layers - 1: 52 | layers.append(nn.LayerNorm(dim2)) 53 | layers.append(nn.GELU()) 54 | 55 | mlp = nn.Sequential(*layers) 56 | mlp.apply(self._init_weights) 57 | return mlp 58 | 59 | def _build_proj(self, num_layers, input_dim, mlp_dim, output_dim): 60 | layers = [] 61 | for l in range(num_layers): 62 | dim1 = input_dim if l == 0 else mlp_dim 63 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 64 | 65 | layers.append(nn.Linear(dim1, dim2, bias=True)) 66 | 67 | if l < num_layers - 1: 68 | layers.append(nn.LayerNorm(dim2)) 69 | layers.append(nn.GELU()) 70 | 71 | mlp = nn.Sequential(*layers) 72 | mlp.apply(self._init_weights) 73 | return mlp 74 | 75 | @torch.no_grad() 76 | def _update_momentum_encoder(self, m): 77 | """Momentum update of the momentum encoder""" 78 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 79 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 80 | 81 | def contrastive_loss(self, q, k): 82 | # normalize 83 | q = nn.functional.normalize(q, dim=1) 84 | k = nn.functional.normalize(k, dim=1) 85 | # gather all targets 86 | k = concat_all_gather(k) 87 | # Einstein sum is more intuitive 88 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 89 | N = logits.shape[0] # batch size per GPU 90 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 91 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 92 | 93 | def byol_loss(self, q, k): 94 | # normalize 95 | q = nn.functional.normalize(q, dim=1) 96 | k = nn.functional.normalize(k, dim=1) 97 | loss = ((q - k) ** 2).sum(dim=-1) 98 | return loss.mean() 99 | 100 | def _init_weights(self, m): 101 | if isinstance(m, nn.Linear): 102 | # we use xavier_uniform following official JAX ViT: 103 | torch.nn.init.xavier_uniform_(m.weight) 104 | if isinstance(m, nn.Linear) and m.bias is not None: 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.LayerNorm): 107 | if m.weight is not None and m.bias is not None: 108 | nn.init.constant_(m.bias, 0) 109 | nn.init.constant_(m.weight, 1.0) 110 | 111 | @torch.jit.ignore 112 | def no_weight_decay(self): 113 | return {'base_encoder.' + k for k in self.base_encoder.no_weight_decay()} 114 | 115 | @torch.no_grad() 116 | def generate_mask(self, x, num_masks, mask_ratio=0.): 117 | if mask_ratio > 0: 118 | if self.disjoint_sampling: 119 | view_size = int(196 * (1 - mask_ratio)) 120 | B = x.size(0) 121 | device = x.get_device() 122 | noise = torch.rand(B, 196, device=device) 123 | mask_index = torch.argsort(noise, dim=1) # consider cls token 124 | masks = [] 125 | for i in range(num_masks): 126 | # 196 patches are hard-coded 127 | mask = mask_index[:, view_size*i:view_size*(i+1)] 128 | mask = mask.long() 129 | masks.append(mask) 130 | else: 131 | masks = [] 132 | for i in range(num_masks): 133 | # 196 patches are hard-coded 134 | B = x.size(0) 135 | device = x.get_device() 136 | noise = torch.rand(B, 196, device=device) 137 | mask_index = torch.argsort(noise, dim=1) 138 | mask = mask_index[:, :int(196*(1-mask_ratio))] # consider the cls token 139 | mask = mask.long() 140 | masks.append(mask) 141 | else: 142 | masks = None 143 | 144 | return masks 145 | 146 | def forward(self, x, m, loss='byol'): 147 | """ 148 | Input: 149 | x: input images, either a single tensor or a 150 | list of two augmented views [x1, x2] 151 | m: ema momentum 152 | Output: 153 | loss 154 | """ 155 | if isinstance(x, list): 156 | x1 = x[0] 157 | x2 = x[1] 158 | else: 159 | x1 = x 160 | x2 = x 161 | 162 | # compute features 163 | B, _, _, _ = x1.size() 164 | device = x1.get_device() 165 | 166 | mask_s = self.generate_mask(x1, self.num_masks, self.mask_ratio) 167 | mask_t = self.generate_mask(x1, 1, 0.) 168 | 169 | q1 = self.predictor(self.base_encoder(x1, mask_s, self.mask_ratio)) 170 | with torch.no_grad(): # no gradient 171 | self._update_momentum_encoder(m) # update the momentum encoder 172 | k2 = self.momentum_encoder(x2, mask_t) 173 | 174 | if loss == "byol": 175 | q1 = torch.chunk(q1, self.num_masks, dim=0) 176 | loss = 0. 177 | for q1i in q1: 178 | loss += self.byol_loss(q1i, k2) 179 | return loss / self.num_masks 180 | 181 | elif loss == 'infonce': 182 | q1 = torch.chunk(q1, self.num_masks, dim=0) 183 | loss = 0. 184 | for q1i in q1: 185 | loss += self.contrastive_loss(q1i, k2) 186 | return loss / self.num_masks 187 | 188 | # utils 189 | @torch.no_grad() 190 | def concat_all_gather(tensor): 191 | """ 192 | Performs all_gather operation on the provided tensors. 
193 | *** Warning ***: torch.distributed.all_gather has no gradient. 194 | """ 195 | tensors_gather = [torch.ones_like(tensor) 196 | for _ in range(torch.distributed.get_world_size())] 197 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 198 | 199 | output = torch.cat(tensors_gather, dim=0) 200 | return output 201 | 202 | # utils 203 | @torch.no_grad() 204 | def mean_all_gather(tensor): 205 | """ 206 | Performs all_gather operation on the provided tensors. 207 | *** Warning ***: torch.distributed.all_gather has no gradient. 208 | """ 209 | # print(tensor.size()) 210 | tensors_gather = [torch.ones_like(tensor) 211 | for _ in range(torch.distributed.get_world_size())] 212 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 213 | 214 | output = torch.mean(torch.cat(tensors_gather,dim=0)) 215 | return output 216 | -------------------------------------------------------------------------------- /lib/dataload_optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class _RepeatSampler(object): 4 | 5 | def __init__(self, sampler): 6 | self.sampler = sampler 7 | 8 | def __iter__(self): 9 | while True: 10 | yield from iter(self.sampler) 11 | 12 | class PersistentDataLoader(torch.utils.data.dataloader.DataLoader): 13 | 14 | def __init__(self, *args, **kwargs): 15 | print('persistent dataloader') 16 | super().__init__(*args, **kwargs) 17 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) 18 | self.iterator = super().__iter__() 19 | 20 | def __len__(self): 21 | return len(self.batch_sampler.sampler) 22 | 23 | def __iter__(self): 24 | for i in range(len(self)): 25 | yield next(self.iterator) 26 | 27 | class SoftwarePipeline(object): 28 | 29 | def __init__(self, dataloader): 30 | self.dataloader = dataloader 31 | self.stream = None 32 | 33 | def __len__(self): 34 | return len(self.dataloader) 35 | 36 | def __iter__(self): 37 | if self.stream is None: 38 | self.stream = torch.cuda.Stream() 39 | first = True 40 | for next_input, next_target in self.dataloader: 41 | with torch.cuda.stream(self.stream): 42 | next_target = next_target.cuda(non_blocking=True) 43 | if isinstance(next_input, list): 44 | for i in range(len(next_input)): 45 | next_input[i] = next_input[i].cuda(non_blocking=True) 46 | else: 47 | next_input = next_input.cuda(non_blocking=True) 48 | if not first: 49 | yield input, target 50 | else: 51 | first = False 52 | torch.cuda.current_stream().wait_stream(self.stream) 53 | input = next_input 54 | target = next_target 55 | yield input, target 56 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | import sys 5 | from termcolor import colored 6 | 7 | 8 | class _ColorfulFormatter(logging.Formatter): 9 | def __init__(self, *args, **kwargs): 10 | self._root_name = kwargs.pop("root_name") + "." 11 | self._abbrev_name = kwargs.pop("abbrev_name", "") 12 | if len(self._abbrev_name): 13 | self._abbrev_name = self._abbrev_name + "." 
14 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 15 | 16 | def formatMessage(self, record): 17 | record.name = record.name.replace(self._root_name, self._abbrev_name) 18 | log = super(_ColorfulFormatter, self).formatMessage(record) 19 | if record.levelno == logging.WARNING: 20 | prefix = colored("WARNING", "red", attrs=["blink"]) 21 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 22 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 23 | else: 24 | return log 25 | return prefix + " " + log 26 | 27 | # so that calling setup_logger multiple times won't add many handlers 28 | @functools.lru_cache() 29 | def setup_logger( 30 | output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None 31 | ): 32 | """ 33 | Initialize the detectron2 logger and set its verbosity level to "INFO". 34 | 35 | Args: 36 | output (str): a file name or a directory to save log. If None, will not save log file. 37 | If ends with ".txt" or ".log", assumed to be a file name. 38 | Otherwise, logs will be saved to `output/log.txt`. 39 | name (str): the root module name of this logger 40 | 41 | Returns: 42 | logging.Logger: a logger 43 | """ 44 | logger = logging.getLogger(name) 45 | logger.setLevel(logging.DEBUG) 46 | logger.propagate = False 47 | 48 | if abbrev_name is None: 49 | abbrev_name = name 50 | 51 | plain_formatter = logging.Formatter( 52 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 53 | ) 54 | # stdout logging: master only 55 | if distributed_rank == 0: 56 | ch = logging.StreamHandler(stream=sys.stdout) 57 | ch.setLevel(logging.DEBUG) 58 | if color: 59 | formatter = _ColorfulFormatter( 60 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 61 | datefmt="%m/%d %H:%M:%S", 62 | root_name=name, 63 | abbrev_name=str(abbrev_name), 64 | ) 65 | else: 66 | formatter = plain_formatter 67 | ch.setFormatter(formatter) 68 | logger.addHandler(ch) 69 | 70 | # file logging: all workers 71 | if output is not None: 72 | if output.endswith(".txt") or output.endswith(".log"): 73 | filename = output 74 | else: 75 | filename = os.path.join(output, "log.txt") 76 | if distributed_rank > 0: 77 | filename = filename + f".rank{distributed_rank}" 78 | os.makedirs(os.path.dirname(filename), exist_ok=True) 79 | 80 | fh = logging.StreamHandler(_cached_log_stream(filename)) 81 | fh.setLevel(logging.DEBUG) 82 | fh.setFormatter(plain_formatter) 83 | logger.addHandler(fh) 84 | 85 | return logger 86 | 87 | 88 | # cache the opened file object, so that different calls to `setup_logger` 89 | # with the same file name can safely write to the same file. 
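# (the returned handle is opened in append mode and deliberately kept open for the life of the process)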
90 | @functools.lru_cache(maxsize=None)
91 | def _cached_log_stream(filename):
92 |     return open(filename, "a")
--------------------------------------------------------------------------------
/lib/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import builtins
4 | import datetime
5 | 
6 | import torch
7 | from math import inf  # torch._six was removed in recent PyTorch; math.inf is equivalent here
8 | import torch.distributed as dist
9 | 
10 | # --------------------------------------------------------
11 | # 2D sine-cosine position embedding
12 | # References:
13 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
14 | # MoCo v3: https://github.com/facebookresearch/moco-v3
15 | # --------------------------------------------------------
16 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
17 |     """
18 |     grid_size: int of the grid height and width
19 |     return:
20 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
21 |     """
22 |     grid_h = np.arange(grid_size, dtype=np.float32)
23 |     grid_w = np.arange(grid_size, dtype=np.float32)
24 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
25 |     grid = np.stack(grid, axis=0)
26 | 
27 |     grid = grid.reshape([2, 1, grid_size, grid_size])
28 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
29 |     if cls_token:
30 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
31 |     return pos_embed
32 | 
33 | 
34 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
35 |     assert embed_dim % 2 == 0
36 | 
37 |     # use half of dimensions to encode grid_h
38 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
39 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
40 | 
41 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
42 |     return emb
43 | 
44 | 
45 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
46 |     """
47 |     embed_dim: output dimension for each position
48 |     pos: a list of positions to be encoded: size (M,)
49 |     out: (M, D)
50 |     """
51 |     assert embed_dim % 2 == 0
52 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24
53 |     omega /= embed_dim / 2.
54 |     omega = 1. 
/ 10000**omega # (D/2,) 55 | 56 | pos = pos.reshape(-1) # (M,) 57 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 58 | 59 | emb_sin = np.sin(out) # (M, D/2) 60 | emb_cos = np.cos(out) # (M, D/2) 61 | 62 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 63 | return emb 64 | 65 | 66 | # -------------------------------------------------------- 67 | # Interpolate position embeddings for high-resolution 68 | # References: 69 | # DeiT: https://github.com/facebookresearch/deit 70 | # -------------------------------------------------------- 71 | def interpolate_pos_embed(model, checkpoint_model): 72 | if 'pos_embed' in checkpoint_model: 73 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 74 | embedding_size = pos_embed_checkpoint.shape[-1] 75 | num_patches = model.patch_embed.num_patches 76 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 77 | # height (== width) for the checkpoint position embedding 78 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 79 | # height (== width) for the new position embedding 80 | new_size = int(num_patches ** 0.5) 81 | # class_token and dist_token are kept unchanged 82 | if orig_size != new_size: 83 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 84 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 85 | # only the position tokens are interpolated 86 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 87 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 88 | pos_tokens = torch.nn.functional.interpolate( 89 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 90 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 91 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 92 | checkpoint_model['pos_embed'] = new_pos_embed 93 | 94 | def clip_gradients(parameters, clip): 95 | norms = [] 96 | for p in parameters: 97 | if p.grad is not None: 98 | param_norm = p.grad.data.norm(2) 99 | norms.append(param_norm) 100 | clip_coef = clip / (param_norm + 1e-6) 101 | if clip_coef < 1: 102 | p.grad.data.mul_(clip_coef) 103 | return norms 104 | 105 | class NativeScalerWithGradNormCount: 106 | state_dict_key = "amp_scaler" 107 | 108 | def __init__(self): 109 | self._scaler = torch.cuda.amp.GradScaler() 110 | 111 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 112 | self._scaler.scale(loss).backward(create_graph=create_graph) 113 | if update_grad: 114 | if clip_grad is not None: 115 | assert parameters is not None 116 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 117 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 118 | #norms = clip_gradients(parameters, clip_grad) 119 | #norm = torch.norm(torch.stack(norms), 2) 120 | else: 121 | self._scaler.unscale_(optimizer) 122 | norm = get_grad_norm_(parameters) 123 | self._scaler.step(optimizer) 124 | self._scaler.update() 125 | else: 126 | norm = None 127 | return norm 128 | 129 | def state_dict(self): 130 | return self._scaler.state_dict() 131 | 132 | def load_state_dict(self, state_dict): 133 | self._scaler.load_state_dict(state_dict) 134 | 135 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 136 | if isinstance(parameters, torch.Tensor): 137 | parameters = [parameters] 138 | parameters = [p for p in parameters if p.grad is not None] 139 | norm_type = 
float(norm_type) 140 | if len(parameters) == 0: 141 | return torch.tensor(0.) 142 | device = parameters[0].grad.device 143 | if norm_type == inf: 144 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 145 | else: 146 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 147 | return total_norm 148 | 149 | def init_distributed_mode(args): 150 | if args.dist_on_itp: 151 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 152 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 153 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 154 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 155 | os.environ['LOCAL_RANK'] = str(args.gpu) 156 | os.environ['RANK'] = str(args.rank) 157 | os.environ['WORLD_SIZE'] = str(args.world_size) 158 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 159 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 160 | args.rank = int(os.environ["RANK"]) 161 | args.world_size = int(os.environ['WORLD_SIZE']) 162 | args.gpu = int(os.environ['LOCAL_RANK']) 163 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 164 | elif 'SLURM_PROCID' in os.environ: 165 | args.rank = int(os.environ['SLURM_PROCID']) 166 | args.gpu = args.rank % torch.cuda.device_count() 167 | else: 168 | print('Not using distributed mode') 169 | setup_for_distributed(is_master=True) # hack 170 | args.distributed = False 171 | return 172 | 173 | args.distributed = True 174 | 175 | torch.cuda.set_device(args.gpu) 176 | args.dist_backend = 'nccl' 177 | print('| distributed init (rank {}): {}, gpu {}'.format( 178 | args.rank, args.dist_url, args.gpu), flush=True) 179 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 180 | world_size=args.world_size, rank=args.rank) 181 | torch.distributed.barrier() 182 | setup_for_distributed(args.rank == 0) 183 | 184 | def setup_for_distributed(is_master): 185 | """ 186 | This function disables printing when not in master process 187 | """ 188 | builtin_print = builtins.print 189 | 190 | def print(*args, **kwargs): 191 | force = kwargs.pop('force', False) 192 | force = force or (get_world_size() > 8) 193 | if is_master or force: 194 | now = datetime.datetime.now().time() 195 | builtin_print('[{}] '.format(now), end='') # print with time stamp 196 | builtin_print(*args, **kwargs) 197 | 198 | builtins.print = print 199 | 200 | 201 | def is_dist_avail_and_initialized(): 202 | if not dist.is_available(): 203 | return False 204 | if not dist.is_initialized(): 205 | return False 206 | return True 207 | 208 | 209 | def get_world_size(): 210 | if not is_dist_avail_and_initialized(): 211 | return 1 212 | return dist.get_world_size() 213 | 214 | 215 | def get_rank(): 216 | if not is_dist_avail_and_initialized(): 217 | return 0 218 | return dist.get_rank() 219 | 220 | 221 | def is_main_process(): 222 | return get_rank() == 0 223 | 224 | 225 | def save_on_master(*args, **kwargs): 226 | if is_main_process(): 227 | torch.save(*args, **kwargs) 228 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import builtins 5 | import math 6 | import os 7 | import random 8 | import shutil 9 | import time 10 | import warnings 11 | import json 12 | import logging 13 | from 
functools import partial
14 | import numpy as np
15 | 
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | import torchvision.models as torchvision_models
28 | from torch.utils.tensorboard import SummaryWriter
29 | 
30 | import lib.builder
31 | from lib.logger import setup_logger
32 | from lib.augment import create_augmentation
33 | from lib.misc import NativeScalerWithGradNormCount as NativeScaler
34 | from lib.dataload_optim import PersistentDataLoader, SoftwarePipeline
35 | import lib.misc as misc
36 | 
37 | from timm.optim import optim_factory, create_optimizer
38 | import vits
39 | 
40 | model_names = ['vit_small', 'vit_base', 'vit_large']
41 | 
42 | parser = argparse.ArgumentParser(description='ExtreMA Arguments')
43 | parser.add_argument('data', metavar='DIR',
44 |                     help='path to dataset')
45 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base',
46 |                     choices=model_names,
47 |                     help='model architecture: ' +
48 |                         ' | '.join(model_names) +
49 |                         ' (default: vit_base)')
50 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
51 |                     help='number of data loading workers per gpu (default: 8)')
52 | parser.add_argument('--epochs', default=300, type=int, metavar='N',
53 |                     help='number of total epochs to run')
54 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
55 |                     help='manual epoch number (useful on restarts)')
56 | parser.add_argument('-b', '--batch-size', default=2048, type=int,
57 |                     metavar='N',
58 |                     help='mini-batch size (default: 2048), this is the total '
59 |                          'batch size of all GPUs on the current node when '
60 |                          'using Data Parallel or Distributed Data Parallel')
61 | parser.add_argument('--lr', '--learning-rate', default=1.5e-4, type=float,
62 |                     metavar='LR', help='initial (base) learning rate', dest='lr')
63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
64 |                     help='lower lr bound for cyclic schedulers that hit 0 (default: 0.)')
65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
66 |                     help='momentum')
67 | parser.add_argument('--wd', '--weight-decay', default=0.1, type=float,
68 |                     metavar='W', help='weight decay (default: 0.1)',
69 |                     dest='weight_decay')
70 | parser.add_argument('--weight-decay-end', default=None, type=float,
71 |                     metavar='W', help='final weight decay for the cosine schedule (default: None, i.e. constant)',
72 |                     dest='weight_decay_end')
73 | parser.add_argument('-p', '--print-freq', default=20, type=int,
74 |                     metavar='N', help='print frequency (default: 20)')
75 | parser.add_argument('--save-freq', default=5, type=int,
76 |                     metavar='N', help='save frequency (default: 5)')
77 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
78 |                     help='path to latest checkpoint (default: none)')
79 | parser.add_argument('--world-size', default=-1, type=int,
80 |                     help='number of nodes for distributed training')
81 | parser.add_argument('--local_rank', default=-1, type=int,
82 |                     help='local rank for distributed training')
83 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
84 |                     help='url used to set up distributed training')
85 | parser.add_argument('--dist-backend', default='nccl', type=str,
86 |                     help='distributed backend')
87 | parser.add_argument('--seed', default=None, type=int,
88 |                     help='seed for initializing training.')
89 | parser.add_argument('--gpu', default=None, type=int,
90 |                     help='GPU id to use.')
91 | parser.add_argument('--device', default='cuda',
92 |                     help='device to use for training / testing')
93 | parser.add_argument('--dist_on_itp', action='store_true')
94 | parser.add_argument('--multiprocessing-distributed', action='store_true',
95 |                     help='Use multi-processing distributed training to launch '
96 |                          'N processes per node, which has N GPUs. This is the '
97 |                          'fastest way to use PyTorch for either single node or '
98 |                          'multi node data parallel training')
99 | parser.add_argument('--log_dir', default="tf_logs", type=str,
100 |                     help='dir of logs')
101 | parser.add_argument('--output_dir', default="results", type=str,
102 |                     help='dir of checkpoints')
103 | 
104 | # siamese specific configs:
105 | parser.add_argument('--proj-dim', default=256, type=int,
106 |                     help='feature dimension (default: 256)')
107 | parser.add_argument('--mlp-dim', default=4096, type=int,
108 |                     help='hidden dimension in MLPs (default: 4096)')
109 | parser.add_argument('--ema-momentum', default=0.996, type=float,
110 |                     help='momentum of updating momentum encoder (default: 0.996)')
111 | parser.add_argument('--contrast-temp', default=1.0, type=float,
112 |                     help='contrastive softmax temperature (default: 1.0)')
113 | 
114 | # vit specific configs:
115 | parser.add_argument('--drop_path_rate', type=float, default=0.0, help="stochastic depth rate")
116 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, help="attention dropout rate")
117 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
118 |                     help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
119 | parser.add_argument('--class_attention_layers', default=2, type=int)
120 | 
121 | # other hyper-params
122 | parser.add_argument('--opt', default='adamw', type=str,
123 |                     choices=['lars', 'adamw'],
124 |                     help='optimizer used (default: adamw)')  # NOTE: main_worker below always builds AdamW; this flag is currently unused
125 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
126 |                     help='Clip gradient norm (default: None, no clipping)')
127 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
128 |                     help='Optimizer Epsilon (default: 1e-8)')
129 | parser.add_argument('--opt-betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
130 |                     help='Optimizer Betas (default: (0.9, 0.999))')
131 | parser.add_argument('--adjust-weight-decay', action='store_true',
132 |                     help='cosine weight decay')
133 | parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
134 |                     help='number of warmup epochs')
135 | parser.add_argument('--crop-min', default=0.2, type=float,
136 |                     help='minimum scale for random cropping student (default: 0.2)')
137 | 
138 | # augmentation options
139 | parser.add_argument('--aug-spatial', action='store_true',
140 |                     help='use spatial data augmentation')
141 | parser.add_argument('--aug-centercrop', action='store_true',
142 |                     help='use centercrop data augmentation')
143 | parser.add_argument('--aug-spatialconsistent-color', action='store_true',
144 |                     help='use spatial consistent with colorjitter data augmentation')
145 | parser.add_argument('--loss', default='byol', type=str,
146 |                     choices=['infonce', 'byol'],
147 |                     help='loss function to use')
148 | 
149 | # add mask options
150 | parser.add_argument('--mask-ratio', default=0.8, type=float,
151 |                     help='mask ratio for student augmentation')
152 | parser.add_argument('--num-masks', default=1, type=int)
153 | parser.add_argument('--disjoint', action='store_true',
154 |                     help='use disjoint sampling of patches')
155 | parser.set_defaults(disjoint=True)
156 | 
157 | 
158 | def main_worker(args):
159 |     misc.init_distributed_mode(args)
160 |     global_rank = misc.get_rank()
161 | 
162 |     os.makedirs(args.log_dir, exist_ok=True)
163 |     os.makedirs(args.output_dir, exist_ok=True)
164 |     print("args.output_dir", args.output_dir)
165 |     print("args.log_dir", args.log_dir)
166 | 
167 |     if global_rank == 0 and args.log_dir is not None:
168 |         with open(args.log_dir + '/config.json', "w") as config_file:
169 |             json.dump(vars(args), config_file)
170 |         os.makedirs(args.log_dir, exist_ok=True)
171 |         summary_writer = SummaryWriter(log_dir=args.log_dir)
172 |     else:
173 |         summary_writer = None
174 |     logger = setup_logger(output=args.log_dir, distributed_rank=global_rank, name="byol")
175 | 
176 |     device = torch.device(args.device)
177 | 
178 |     if args.seed is not None:
179 |         seed = args.seed + misc.get_rank()
180 |         random.seed(seed)
181 |         torch.manual_seed(seed)
182 |         torch.cuda.manual_seed_all(seed)
183 |         np.random.seed(seed)
184 | 
185 |     cudnn.benchmark = True
186 | 
187 |     local_batch_size = int(args.batch_size / misc.get_world_size())
188 |     augmentation = create_augmentation(args)
189 |     logger.info(augmentation)
190 | 
191 |     # Data loading
192 |     traindir = os.path.join(args.data, 'train')
193 |     train_dataset = datasets.ImageFolder(
194 |         traindir,
195 |         transform=augmentation,
196 |     )
197 | 
198 |     if True:  # args.distributed is assumed: a DistributedSampler is always used
199 |         num_tasks = misc.get_world_size()
200 |         train_sampler = torch.utils.data.distributed.DistributedSampler(
201 |             train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
202 |         )
203 | 
204 |     train_loader = SoftwarePipeline(PersistentDataLoader(
205 |         train_dataset, batch_size=local_batch_size, shuffle=(train_sampler is None),
206 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True))
207 | 
208 |     # create model
209 |     logger.info("=> creating model '{}'".format(args.arch))
210 |     base_encoder = partial(vits.__dict__[args.arch], drop_path_rate=args.drop_path_rate, attn_drop_rate=args.attn_drop_rate, init_values=args.layer_scale_init_value, class_attention_layers=args.class_attention_layers)
211 |     ema_encoder = partial(vits.__dict__[args.arch], init_values=args.layer_scale_init_value, class_attention_layers=args.class_attention_layers)
212 |     model = lib.builder.ExtreMA(
213 |         base_encoder, ema_encoder,
214 |         args.proj_dim, args.mlp_dim, args.contrast_temp, args.mask_ratio, args.num_masks, args.disjoint)
215 |     model.to(device)
216 | 
217 |     # infer learning rate before changing batch size
218 |     args.lr = args.lr * args.batch_size / 256
219 | 
220 |     if True:  # always wrap in DistributedDataParallel
221 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
222 | 
223 |     logger.info(model)
224 | 
225 |     param_groups = optim_factory.add_weight_decay(model.module, args.weight_decay, model.module.no_weight_decay())
226 |     optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=args.opt_betas)
227 | 
228 |     logger.info(optimizer)
229 |     scaler = NativeScaler()
230 | 
231 |     # auto resume from a checkpoint (note: this overrides any user-supplied --resume path)
232 |     args.resume = os.path.join(args.output_dir, 'current.pth.tar')
233 |     if args.resume:
234 |         if os.path.isfile(args.resume):
235 |             logger.info("=> loading checkpoint '{}'".format(args.resume))
236 |             if args.gpu is None:
237 |                 checkpoint = torch.load(args.resume)
238 |             else:
239 |                 # Map model to be loaded to specified single gpu.
240 |                 loc = 'cuda:{}'.format(args.gpu)
241 |                 checkpoint = torch.load(args.resume, map_location=loc)
242 |             args.start_epoch = checkpoint['epoch']
243 |             model.load_state_dict(checkpoint['state_dict'])
244 |             optimizer.load_state_dict(checkpoint['optimizer'])
245 |             scaler.load_state_dict(checkpoint['scaler'])
246 |             logger.info("=> loaded checkpoint '{}' (epoch {})"
247 |                         .format(args.resume, checkpoint['epoch']))
248 |             del checkpoint
249 |             torch.cuda.empty_cache()
250 |         else:
251 |             logger.info("=> no checkpoint found at '{}'".format(args.resume))
252 | 
253 | 
254 |     for epoch in range(args.start_epoch, args.epochs):
255 |         if args.distributed:
256 |             train_sampler.set_epoch(epoch)
257 | 
258 |         # train for one epoch
259 |         train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
260 | 
261 |         if misc.get_rank() == 0 and (epoch+1) % args.save_freq == 0:
262 |             # only the first process saves checkpoints (the original condition also let every rank save on every epoch when --multiprocessing-distributed was unset)
263 |             save_checkpoint({
264 |                 'epoch': epoch + 1,
265 |                 'arch': args.arch,
266 |                 'state_dict': model.state_dict(),
267 |                 'optimizer' : optimizer.state_dict(),
268 |                 'scaler': scaler.state_dict(),
269 |             }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.output_dir, epoch))
270 |             shutil.copyfile('{}/checkpoint_{:04d}.pth.tar'.format(args.output_dir, epoch), '{}/current.pth.tar'.format(args.output_dir))
271 | 
272 |     if misc.get_rank() == 0:
273 |         summary_writer.close()
274 | 
275 | def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
276 |     batch_time = AverageMeter('Time', ':6.3f')
277 |     data_time = AverageMeter('Data', ':6.3f')
278 |     learning_rates = AverageMeter('LR', ':.4e')
279 |     loss_scales = AverageMeter('LossScale', ':.4e')
280 |     weight_decays = AverageMeter('WeightDecay', ':.4e')
281 |     grad_norms = AverageMeter('GradNorm', ':.4e')
282 |     losses = AverageMeter('Loss', ':.4e')
283 | 
284 |     progress = ProgressMeter(
285 |         len(train_loader),
286 |         [batch_time, data_time, losses, grad_norms, loss_scales, weight_decays, learning_rates],
287 |         prefix="Epoch: [{}]".format(epoch))
288 |     logger = logging.getLogger('byol')
289 | 
290 |     # switch to train mode
291 |     model.train()
292 | 
293 |     end = time.time()
294 |     iters_per_epoch = len(train_loader)
295 |     ema_momentum = args.ema_momentum
296 |     for i, (images, labels) in enumerate(train_loader):  # labels are unused in self-supervised pre-training
297 |         # measure data loading time
298 |         data_time.update(time.time() - end)
299 | 
300 |         # adjust learning rate and momentum coefficient per iteration
301 |         lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args)
302 |         ema_momentum = adjust_ema_momentum(epoch + i / iters_per_epoch, args)
303 |         if args.adjust_weight_decay:
304 |             wd = adjust_decay_rate(optimizer, epoch + i / iters_per_epoch, args)
305 | 
306 |         if args.gpu is not None:
307 |             if isinstance(images, list):
308 |                 images[0] = images[0].cuda(args.gpu, non_blocking=True)
309 |                 images[1] = images[1].cuda(args.gpu, non_blocking=True)
310 |                 bsz = images[0].size(0)
311 |             else:
312 |                 images = images.cuda(args.gpu, non_blocking=True)
313 |                 bsz = images.size(0)
314 | 
315 |         # compute output
316 |         with torch.cuda.amp.autocast(True):
317 |             loss = model(images, ema_momentum, args.loss)
318 |             losses.update(loss.item(), bsz)
319 | 
320 |         # compute gradient and do SGD step
321 |         optimizer.zero_grad()
322 | 
323 |         norm = scaler(loss, optimizer, parameters=model.parameters(), clip_grad=args.clip_grad)
324 | 
325 |         # measure elapsed time
326 |         batch_time.update(time.time() - end)
327 | 
end = time.time() 328 | 329 | if i % args.print_freq == 0: 330 | grad_norms.update(norm) 331 | loss_scales.update(scaler.state_dict()["scale"]) 332 | learning_rates.update(optimizer.param_groups[1]['lr']) 333 | weight_decays.update(optimizer.param_groups[1]['weight_decay']) 334 | progress.display(logger,i) 335 | 336 | if misc.get_rank() == 0: 337 | summary_writer.add_scalar("losses", losses.avg, epoch ) 338 | summary_writer.add_scalar("opt/grad_norm", grad_norms.avg, epoch ) 339 | summary_writer.add_scalar("opt/loss_scale", loss_scales.avg, epoch ) 340 | summary_writer.add_scalar("opt/lr", learning_rates.avg, epoch ) 341 | summary_writer.add_scalar("opt/wd", weight_decays.avg, epoch ) 342 | 343 | 344 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 345 | torch.save(state, filename) 346 | if is_best: 347 | shutil.copyfile(filename, 'model_best.pth.tar') 348 | 349 | class AverageMeter(object): 350 | """Computes and stores the average and current value""" 351 | def __init__(self, name, fmt=':f'): 352 | self.name = name 353 | self.fmt = fmt 354 | self.reset() 355 | 356 | def reset(self): 357 | self.val = 0 358 | self.avg = 0 359 | self.sum = 0 360 | self.count = 0 361 | 362 | def update(self, val, n=1): 363 | self.val = val 364 | self.sum += val * n 365 | self.count += n 366 | self.avg = self.sum / self.count 367 | 368 | def __str__(self): 369 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 370 | return fmtstr.format(**self.__dict__) 371 | 372 | class ProgressMeter(object): 373 | def __init__(self, num_batches, meters, prefix=""): 374 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 375 | self.meters = meters 376 | self.prefix = prefix 377 | 378 | def display(self, logger, batch): 379 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 380 | entries += [str(meter) for meter in self.meters] 381 | logger.info('\t'.join(entries)) 382 | 383 | def _get_batch_fmtstr(self, num_batches): 384 | num_digits = len(str(num_batches // 1)) 385 | fmt = '{:' + str(num_digits) + 'd}' 386 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 387 | 388 | def adjust_learning_rate(optimizer, epoch, args): 389 | """Decays the learning rate with half-cycle cosine after warmup""" 390 | if epoch < args.warmup_epochs: 391 | lr = args.lr * epoch / args.warmup_epochs 392 | else: 393 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 394 | for param_group in optimizer.param_groups: 395 | param_group['lr'] = lr 396 | return lr 397 | 398 | def adjust_decay_rate(optimizer, epoch, args): 399 | """Decays the learning rate with half-cycle cosine""" 400 | if args.weight_decay_end is None: 401 | args.weight_decay_end = args.weight_decay 402 | wd = args.weight_decay_end + 0.5 * (args.weight_decay - args.weight_decay_end) * (1. + math.cos(math.pi * (epoch / args.epochs))) 403 | for param_group in optimizer.param_groups: 404 | if param_group['weight_decay'] > 0: 405 | param_group['weight_decay'] = wd 406 | return wd 407 | 408 | def adjust_ema_momentum(epoch, args): 409 | """Decays the momentum paramter with half-cycle cosine""" 410 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. 
- args.ema_momentum)
411 |     return m
412 | 
413 | def accuracy(output, target, topk=(1,)):
414 |     """Computes the accuracy over the k top predictions for the specified values of k"""
415 |     with torch.no_grad():
416 |         maxk = max(topk)
417 |         batch_size = target.size(0)
418 | 
419 |         _, pred = output.topk(maxk, 1, True, True)
420 |         pred = pred.t()
421 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
422 | 
423 |         res = []
424 |         for k in topk:
425 |             correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
426 |             res.append(correct_k.mul_(100.0 / batch_size))
427 |         return res
428 | 
429 | if __name__ == '__main__':
430 |     args = parser.parse_args()
431 |     main_worker(args)
432 | 
--------------------------------------------------------------------------------
/main_lincls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import argparse
4 | import builtins
5 | import math
6 | import os
7 | import random
8 | import shutil
9 | import time
10 | import warnings
11 | import json
12 | import numpy as np  # needed: main_worker calls np.random.seed below
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | import torchvision.transforms as transforms
23 | import torchvision.datasets as datasets
24 | import torchvision.models as torchvision_models
25 | 
26 | import vits
27 | import lib.misc as misc
28 | import lib.builder
29 | from lib.dataload_optim import PersistentDataLoader
30 | 
31 | model_names = ['vit_small', 'vit_base', 'vit_large']
32 | 
33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
34 | parser.add_argument('data', metavar='DIR',
35 |                     help='path to dataset')
36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base',
37 |                     choices=model_names,
38 |                     help='model architecture: ' +
39 |                         ' | '.join(model_names) +
40 |                         ' (default: vit_base)')
41 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
42 |                     help='number of data loading workers (default: 4)')
43 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
44 |                     help='number of total epochs to run')
45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
46 |                     help='manual epoch number (useful on restarts)')
47 | parser.add_argument('-b', '--batch-size', default=1024, type=int,
48 |                     metavar='N',
49 |                     help='mini-batch size (default: 1024), this is the total '
50 |                          'batch size of all GPUs on the current node when '
51 |                          'using Data Parallel or Distributed Data Parallel')
52 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
53 |                     metavar='LR', help='initial (base) learning rate', dest='lr')
54 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
55 |                     help='momentum')
56 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
57 |                     metavar='W', help='weight decay (default: 0.)',
58 |                     dest='weight_decay')
59 | parser.add_argument('--optimizer', default='sgd', type=str,
60 |                     choices=['lars', 'sgd'],
61 |                     help='optimizer used (default: sgd)')
62 | parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
63 |                     help='number of warmup epochs')
64 | parser.add_argument('-p', '--print-freq', default=10, type=int,
65 |                     metavar='N', help='print frequency (default: 10)')
66 | parser.add_argument('--save-freq', default=5, type=int,
67 | 
metavar='N', help='save frequency (default: 5)') 68 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 69 | help='path to latest checkpoint (default: none)') 70 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 71 | help='evaluate model on validation set') 72 | parser.add_argument('--world-size', default=-1, type=int, 73 | help='number of nodes for distributed training') 74 | parser.add_argument('--rank', default=-1, type=int, 75 | help='node rank for distributed training') 76 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 77 | help='url used to set up distributed training') 78 | parser.add_argument('--dist-backend', default='nccl', type=str, 79 | help='distributed backend') 80 | parser.add_argument('--seed', default=None, type=int, 81 | help='seed for initializing training. ') 82 | parser.add_argument('--gpu', default=None, type=int, 83 | help='GPU id to use.') 84 | parser.add_argument('--multiprocessing-distributed', action='store_true', 85 | help='Use multi-processing distributed training to launch ' 86 | 'N processes per node, which has N GPUs. This is the ' 87 | 'fastest way to use PyTorch for either single node or ' 88 | 'multi node data parallel training') 89 | parser.add_argument('--dist_on_itp', action='store_true') 90 | parser.add_argument('--local_rank', default=-1, type=int, 91 | help='node rank for distributed training') 92 | 93 | # additional configs: 94 | parser.add_argument('--pretrained', default='', type=str, 95 | help='path to moco pretrained checkpoint') 96 | parser.add_argument('--log_dir', default=None, 97 | help='path where to tensorboard log') 98 | parser.add_argument('--eval_momentum', action='store_true', 99 | help='evaluation momentum encoder') 100 | 101 | best_acc1 = 0 102 | 103 | 104 | def main(): 105 | args = parser.parse_args() 106 | 107 | if args.seed is not None: 108 | random.seed(args.seed) 109 | torch.manual_seed(args.seed) 110 | cudnn.deterministic = True 111 | warnings.warn('You have chosen to seed training. ' 112 | 'This will turn on the CUDNN deterministic setting, ' 113 | 'which can slow down your training considerably! ' 114 | 'You may see unexpected behavior when restarting ' 115 | 'from checkpoints.') 116 | 117 | if args.gpu is not None: 118 | warnings.warn('You have chosen a specific GPU. 
This will completely '
119 |                       'disable data parallelism.')
120 | 
121 |     if args.dist_url == "env://" and args.world_size == -1:
122 |         args.world_size = int(os.environ["WORLD_SIZE"])
123 | 
124 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
125 | 
126 |     ngpus_per_node = torch.cuda.device_count()
127 |     if args.multiprocessing_distributed:
128 |         # Since we have ngpus_per_node processes per node, the total world_size
129 |         # needs to be adjusted accordingly
130 |         args.world_size = ngpus_per_node * args.world_size
131 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
132 |         # main_worker process function (NOTE: main() is a legacy launcher and is unused; its call signature no longer matches main_worker(args), and __main__ below calls main_worker directly)
133 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
134 |     else:
135 |         # Simply call main_worker function
136 |         main_worker(args.gpu, ngpus_per_node, args)
137 | 
138 | 
139 | def main_worker(args):
140 |     if 'LOCAL_RANK' not in os.environ:
141 |         os.environ['LOCAL_RANK'] = str(args.local_rank)  # environ values must be strings
142 | 
143 |     misc.init_distributed_mode(args)
144 |     global_rank = misc.get_rank()
145 | 
146 |     if args.seed is not None:
147 |         seed = args.seed + misc.get_rank()
148 |         random.seed(seed)
149 |         torch.manual_seed(seed)
150 |         torch.cuda.manual_seed_all(seed)
151 |         np.random.seed(seed)
152 | 
153 |     cudnn.benchmark = True
154 | 
155 |     global best_acc1
156 |     #args.gpu = gpu
157 | 
158 |     # suppress printing if not master
159 |     if args.multiprocessing_distributed and global_rank != 0:
160 |         def print_pass(*args):
161 |             pass
162 |         builtins.print = print_pass
163 | 
164 |     if args.gpu is not None:
165 |         print("Use GPU: {} for training".format(args.gpu))
166 | 
167 |     print("=> creating model '{}'".format(args.arch))
168 |     if args.arch.startswith('vit'):
169 |         model = vits.__dict__[args.arch](init_values=0.1)
170 |         linear_keyword = 'head'
171 |     else:
172 |         model = torchvision_models.__dict__[args.arch]()
173 |         linear_keyword = 'fc'
174 | 
175 |     # freeze all layers but the last fc
176 |     for name, param in model.named_parameters():
177 |         if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
178 |             param.requires_grad = False
179 |     model.head = nn.Sequential(nn.SyncBatchNorm(model.head.in_features, affine=False), model.head)
180 |     # init the fc layer
181 |     model.head[1].weight.data.normal_(mean=0.0, std=0.01)
182 |     model.head[1].bias.data.zero_()
183 | 
184 |     # load from pre-trained, before DistributedDataParallel constructor
185 |     if args.pretrained:
186 |         if os.path.isfile(args.pretrained):
187 |             print("=> loading checkpoint '{}'".format(args.pretrained))
188 |             checkpoint = torch.load(args.pretrained, map_location="cpu")
189 | 
190 |             # rename moco pre-trained keys
191 |             state_dict = checkpoint['state_dict']
192 |             if args.eval_momentum:
193 |                 prefix_key = 'module.momentum_encoder'
194 |             else:
195 |                 prefix_key = 'module.base_encoder'
196 |             for k in list(state_dict.keys()):
197 |                 # retain only base_encoder up to before the embedding layer
198 |                 if k.startswith(prefix_key) and not k.startswith('{}.{}'.format(prefix_key,linear_keyword)):
199 |                     # remove prefix
200 |                     state_dict[k[len(prefix_key)+1:]] = state_dict[k]
201 |                 # delete renamed or unused k
202 |                 del state_dict[k]
203 | 
204 |             args.start_epoch = 0
205 |             msg = model.load_state_dict(state_dict, strict=False)
206 |             #assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}
207 |             print(msg)
208 | 
209 |             print("=> loaded pre-trained model '{}'".format(args.pretrained))
210 |         else:
211 |             print("=> no checkpoint found at '{}'".format(args.pretrained))
212 | 
213 |     model.norm = nn.Identity()
214 | 
215 |     # infer learning rate before changing batch size
216 |     args.lr = args.lr * args.batch_size / 256
217 | 
218 |     if True:
219 |         # For multiprocessing distributed, DistributedDataParallel constructor
220 |         # should always set the single device scope, otherwise,
221 |         # DistributedDataParallel will use all available devices.
222 |         torch.cuda.set_device(args.gpu)
223 |         model.cuda(args.gpu)
224 |         # When using a single GPU per process and per
225 |         # DistributedDataParallel, we need to divide the batch size
226 |         # ourselves based on the total number of GPUs we have
227 |         args.batch_size = int(args.batch_size / misc.get_world_size())
228 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
229 | 
230 |     # define loss function (criterion) and optimizer
231 |     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
232 | 
233 |     # optimize only the linear classifier
234 |     parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
235 |     assert len(parameters) == 2  # weight, bias
236 | 
237 |     if global_rank == 0 and args.log_dir is not None:
238 |         os.makedirs(args.log_dir, exist_ok=True)
239 | 
240 |     if args.optimizer == 'lars':  # NOTE: torch.optim ships no LARS; this branch assumes a LARS implementation (e.g., MoCo v3's) has been attached to torch.optim
241 |         optimizer = torch.optim.LARS(parameters, args.lr,
242 |                                      weight_decay=0,
243 |                                      momentum=args.momentum)
244 |     else:
245 |         optimizer = torch.optim.SGD(parameters, args.lr,
246 |                                     momentum=args.momentum,
247 |                                     weight_decay=0)
248 |     print(optimizer)
249 | 
250 |     # optionally resume from a checkpoint
251 |     if args.resume:
252 |         if os.path.isfile(args.resume):
253 |             print("=> loading checkpoint '{}'".format(args.resume))
254 |             if args.gpu is None:
255 |                 checkpoint = torch.load(args.resume)
256 |             else:
257 |                 # Map model to be loaded to specified single gpu.
258 |                 loc = 'cuda:{}'.format(args.gpu)
259 |                 checkpoint = torch.load(args.resume, map_location=loc)
260 |             args.start_epoch = checkpoint['epoch']
261 |             best_acc1 = checkpoint['best_acc1']
262 |             if args.gpu is not None:
263 |                 # best_acc1 may be from a checkpoint from a different GPU
264 |                 best_acc1 = best_acc1.to(args.gpu)
265 |             model.load_state_dict(checkpoint['state_dict'])
266 |             optimizer.load_state_dict(checkpoint['optimizer'])
267 |             print("=> loaded checkpoint '{}' (epoch {})"
268 |                   .format(args.resume, checkpoint['epoch']))
269 |         else:
270 |             print("=> no checkpoint found at '{}'".format(args.resume))
271 | 
272 |     print(model)
273 |     cudnn.benchmark = True
274 | 
275 |     # Data loading code
276 |     traindir = os.path.join(args.data, 'train')
277 |     valdir = os.path.join(args.data, 'val')
278 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
279 |                                      std=[0.229, 0.224, 0.225])
280 | 
281 |     train_dataset = datasets.ImageFolder(
282 |         traindir,
283 |         transforms.Compose([
284 |             transforms.RandomResizedCrop(224),
285 |             transforms.RandomHorizontalFlip(),
286 |             transforms.ToTensor(),
287 |             normalize,
288 |         ]))
289 |     val_dataset = datasets.ImageFolder(
290 |         valdir,
291 |         transforms.Compose([
292 |             transforms.Resize(256),
293 |             transforms.CenterCrop(224),
294 |             transforms.ToTensor(),
295 |             normalize
296 |         ]))
297 | 
298 |     if args.distributed:
299 |         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
300 |         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
301 |     else:
302 |         train_sampler = None
303 |         val_sampler = None  # must also be defined here, otherwise the val_loader below raises a NameError
304 |     train_loader = PersistentDataLoader(
305 |         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
306 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
307 | 
308 |     val_loader = 
PersistentDataLoader( 309 | val_dataset, 310 | batch_size=args.batch_size, shuffle=False, 311 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False) 312 | 313 | if args.evaluate: 314 | validate(val_loader, model, criterion, args) 315 | return 316 | 317 | for epoch in range(args.start_epoch, args.epochs): 318 | if args.distributed: 319 | train_sampler.set_epoch(epoch) 320 | # train for one epoch 321 | train(train_loader, model, criterion, optimizer, epoch, args) 322 | 323 | # evaluate on validation set 324 | acc1 = validate(val_loader, model, criterion, args) 325 | # remember best acc@1 and save checkpoint 326 | is_best = acc1 > best_acc1 327 | best_acc1 = max(acc1, best_acc1) 328 | 329 | log_stats = {'epoch': epoch, 'acc': acc1.item()} 330 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 331 | and args.rank == 0): 332 | with open(os.path.join(args.log_dir, "log.txt"), mode="a", encoding="utf-8") as f: 333 | f.write(json.dumps(log_stats) + "\n") 334 | 335 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 336 | and args.rank == 0 and (epoch+1) % args.save_freq == 0): # only the first GPU saves checkpoint 337 | save_checkpoint({ 338 | 'epoch': epoch + 1, 339 | 'arch': args.arch, 340 | 'state_dict': model.state_dict(), 341 | 'best_acc1': best_acc1, 342 | 'optimizer' : optimizer.state_dict(), 343 | }, is_best) 344 | if epoch == args.start_epoch: 345 | sanity_check(model.state_dict(), args.pretrained, linear_keyword) 346 | 347 | 348 | def train(train_loader, model, criterion, optimizer, epoch, args): 349 | batch_time = AverageMeter('Time', ':6.3f') 350 | data_time = AverageMeter('Data', ':6.3f') 351 | losses = AverageMeter('Loss', ':.4e') 352 | top1 = AverageMeter('Acc@1', ':6.2f') 353 | top5 = AverageMeter('Acc@5', ':6.2f') 354 | progress = ProgressMeter( 355 | len(train_loader), 356 | [batch_time, data_time, losses, top1, top5], 357 | prefix="Epoch: [{}]".format(epoch)) 358 | 359 | """ 360 | Switch to eval mode: 361 | Under the protocol of linear classification on frozen features/models, 362 | it is not legitimate to change any part of the pre-trained model. 363 | BatchNorm in train mode may revise running mean/std (even if it receives 364 | no gradient), which are part of the model parameters too. 
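    For the same reason, only the classification head (the SyncBatchNorm plus the
    linear layer attached in main_worker) is switched back to train mode below,
    while the frozen backbone stays in eval mode for the whole run.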
365 | """ 366 | model.eval() 367 | model.module.head.train() 368 | 369 | end = time.time() 370 | for i, (images, target) in enumerate(train_loader): 371 | # measure data loading time 372 | data_time.update(time.time() - end) 373 | lr = adjust_learning_rate(optimizer, epoch + i / len(train_loader), args) 374 | if args.gpu is not None: 375 | images = images.cuda(args.gpu, non_blocking=True) 376 | if torch.cuda.is_available(): 377 | target = target.cuda(args.gpu, non_blocking=True) 378 | 379 | # compute output 380 | with torch.cuda.amp.autocast(True): 381 | output = model(images) 382 | loss = criterion(output, target) 383 | 384 | # measure accuracy and record loss 385 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 386 | losses.update(loss.item(), images.size(0)) 387 | top1.update(acc1[0], images.size(0)) 388 | top5.update(acc5[0], images.size(0)) 389 | 390 | # compute gradient and do SGD step 391 | optimizer.zero_grad() 392 | loss.backward() 393 | optimizer.step() 394 | 395 | # measure elapsed time 396 | batch_time.update(time.time() - end) 397 | end = time.time() 398 | 399 | if i % args.print_freq == 0: 400 | progress.display(i) 401 | 402 | 403 | def validate(val_loader, model, criterion, args): 404 | batch_time = AverageMeter('Time', ':6.3f') 405 | losses = AverageMeter('Loss', ':.4e') 406 | top1 = AverageMeter('Acc@1', ':6.2f') 407 | top5 = AverageMeter('Acc@5', ':6.2f') 408 | progress = ProgressMeter( 409 | len(val_loader), 410 | [batch_time, losses, top1, top5], 411 | prefix='Test: ') 412 | 413 | # switch to evaluate mode 414 | model.eval() 415 | 416 | with torch.no_grad(): 417 | end = time.time() 418 | for i, (images, target) in enumerate(val_loader): 419 | if args.gpu is not None: 420 | images = images.cuda(args.gpu, non_blocking=True) 421 | if torch.cuda.is_available(): 422 | target = target.cuda(args.gpu, non_blocking=True) 423 | 424 | # compute output 425 | with torch.cuda.amp.autocast(True): 426 | output = model(images) 427 | loss = criterion(output, target) 428 | 429 | # measure accuracy and record loss 430 | output = moco.builder.concat_all_gather(output.to("cuda")) 431 | target = moco.builder.concat_all_gather(target.to("cuda")) 432 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 433 | 434 | losses.update(loss.item(), images.size(0)) 435 | top1.update(acc1[0], output.size(0)) 436 | top5.update(acc5[0], target.size(0)) 437 | 438 | # measure elapsed time 439 | batch_time.update(time.time() - end) 440 | end = time.time() 441 | 442 | if i % args.print_freq == 0: 443 | progress.display(i) 444 | 445 | # TODO: this should also be done with the ProgressMeter 446 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 447 | .format(top1=top1, top5=top5)) 448 | 449 | return top1.avg 450 | 451 | 452 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 453 | torch.save(state, filename) 454 | if is_best: 455 | shutil.copyfile(filename, 'model_best.pth.tar') 456 | 457 | 458 | def sanity_check(state_dict, pretrained_weights, linear_keyword): 459 | """ 460 | Linear classifier should not change any weights other than the linear layer. 461 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 
462 | """ 463 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 464 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 465 | state_dict_pre = checkpoint['state_dict'] 466 | 467 | for k in list(state_dict.keys()): 468 | # only ignore linear layer 469 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k: 470 | continue 471 | 472 | # name in pretrained model 473 | k_pre = 'module.base_encoder.' + k[len('module.'):] \ 474 | if k.startswith('module.') else 'module.base_encoder.' + k 475 | 476 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ 477 | '{} is changed in linear classifier training.'.format(k) 478 | 479 | print("=> sanity check passed.") 480 | 481 | 482 | class AverageMeter(object): 483 | """Computes and stores the average and current value""" 484 | def __init__(self, name, fmt=':f'): 485 | self.name = name 486 | self.fmt = fmt 487 | self.reset() 488 | 489 | def reset(self): 490 | self.val = 0 491 | self.avg = 0 492 | self.sum = 0 493 | self.count = 0 494 | 495 | def update(self, val, n=1): 496 | self.val = val 497 | self.sum += val * n 498 | self.count += n 499 | self.avg = self.sum / self.count 500 | 501 | def __str__(self): 502 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 503 | return fmtstr.format(**self.__dict__) 504 | 505 | 506 | class ProgressMeter(object): 507 | def __init__(self, num_batches, meters, prefix=""): 508 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 509 | self.meters = meters 510 | self.prefix = prefix 511 | 512 | def display(self, batch): 513 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 514 | entries += [str(meter) for meter in self.meters] 515 | print('\t'.join(entries)) 516 | 517 | def _get_batch_fmtstr(self, num_batches): 518 | num_digits = len(str(num_batches // 1)) 519 | fmt = '{:' + str(num_digits) + 'd}' 520 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 521 | 522 | def adjust_learning_rate(optimizer, epoch, args): 523 | """Decays the learning rate with half-cycle cosine after warmup""" 524 | if epoch < args.warmup_epochs: 525 | lr = args.lr * epoch / args.warmup_epochs 526 | else: 527 | lr = args.lr * 0.5 * (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 528 | for param_group in optimizer.param_groups: 529 | param_group['lr'] = lr 530 | return lr 531 | 532 | def accuracy(output, target, topk=(1,)): 533 | """Computes the accuracy over the k top predictions for the specified values of k""" 534 | with torch.no_grad(): 535 | maxk = max(topk) 536 | batch_size = target.size(0) 537 | 538 | _, pred = output.topk(maxk, 1, True, True) 539 | pred = pred.t() 540 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 541 | 542 | res = [] 543 | for k in topk: 544 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 545 | res.append(correct_k.mul_(100.0 / batch_size)) 546 | return res 547 | 548 | 549 | if __name__ == '__main__': 550 | args = parser.parse_args() 551 | main_worker(args) 552 | #main() 553 | -------------------------------------------------------------------------------- /vits.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from timm.models.vision_transformer import _cfg 8 | from lib.misc import get_2d_sincos_pos_embed 9 | 10 | class PatchEmbed(nn.Module): 11 | """ Image to Patch Embedding 12 | """ 13 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 14 | super().__init__() 15 | num_patches = (img_size // patch_size) * (img_size // patch_size) 16 | self.img_size = img_size 17 | self.patch_size = patch_size 18 | self.num_patches = num_patches 19 | 20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 21 | 22 | def forward(self, x): 23 | B, C, H, W = x.shape 24 | x = self.proj(x).flatten(2).transpose(1, 2) 25 | return x 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
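    During training, each sample's residual branch is zeroed with probability
    drop_prob and the survivors are rescaled by 1/keep_prob, so the expected
    output is unchanged; in eval mode drop_path (defined above) is an identity.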
39 | """ 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | class Attention(nn.Module): 67 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 68 | super().__init__() 69 | self.num_heads = num_heads 70 | head_dim = dim // num_heads 71 | self.scale = qk_scale or head_dim ** -0.5 72 | 73 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 74 | self.attn_drop = nn.Dropout(attn_drop) 75 | self.proj = nn.Linear(dim, dim) 76 | self.proj_drop = nn.Dropout(proj_drop) 77 | 78 | def forward(self, x): 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 81 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 82 | 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | 87 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 88 | x = self.proj(x) 89 | x = self.proj_drop(x) 90 | return x 91 | 92 | class Class_Attention(nn.Module): 93 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 94 | # with slight modifications to do CA 95 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 96 | super().__init__() 97 | self.num_heads = num_heads 98 | head_dim = dim // num_heads 99 | self.scale = qk_scale or head_dim ** -0.5 100 | 101 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 102 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 103 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 104 | self.attn_drop = nn.Dropout(attn_drop) 105 | self.proj = nn.Linear(dim, dim) 106 | self.proj_drop = nn.Dropout(proj_drop) 107 | 108 | 109 | def forward(self, x ): 110 | 111 | B, N, C = x.shape 112 | q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 113 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 114 | 115 | q = q * self.scale 116 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 117 | 118 | attn = (q @ k.transpose(-2, -1)) 119 | attn = attn.softmax(dim=-1) 120 | attn = self.attn_drop(attn) 121 | 122 | x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) 123 | x_cls = self.proj(x_cls) 124 | x_cls = self.proj_drop(x_cls) 125 | 126 | return x_cls 127 | 128 | class Block_CA(nn.Module): 129 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 130 | # with slight modifications to add CA and LayerScale 131 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block = 
Class_Attention, 133 | Mlp_block=Mlp,init_values=1e-4): 134 | super().__init__() 135 | self.norm1 = norm_layer(dim) 136 | self.attn = Attention_block( 137 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 138 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 139 | self.norm2 = norm_layer(dim) 140 | mlp_hidden_dim = int(dim * mlp_ratio) 141 | self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 142 | 143 | if init_values is not None: 144 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 145 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 146 | else: 147 | self.gamma_1, self.gamma_2 = None, None 148 | 149 | 150 | def forward(self, x, x_cls): 151 | 152 | u = torch.cat((x_cls,x),dim=1) 153 | 154 | if self.gamma_1 is None: 155 | x_cls = x_cls + self.drop_path(self.attn(self.norm1(u))) 156 | x_cls = x_cls + self.drop_path(self.mlp(self.norm2(x_cls))) 157 | else: 158 | x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u))) 159 | x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls))) 160 | 161 | return x_cls 162 | 163 | class Block(nn.Module): 164 | 165 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 166 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 167 | window_size=None, attn_head_dim=None): 168 | super().__init__() 169 | self.norm1 = norm_layer(dim) 170 | self.attn = Attention( 171 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 172 | attn_drop=attn_drop, proj_drop=drop) 173 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 174 | self.norm2 = norm_layer(dim) 175 | mlp_hidden_dim = int(dim * mlp_ratio) 176 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 177 | 178 | if init_values is not None: 179 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 180 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 181 | else: 182 | self.gamma_1, self.gamma_2 = None, None 183 | 184 | def forward(self, x): 185 | if self.gamma_1 is None: 186 | x = x + self.drop_path(self.attn(self.norm1(x))) 187 | x = x + self.drop_path(self.mlp(self.norm2(x))) 188 | else: 189 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 190 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 191 | return x 192 | 193 | class VisionTransformer(nn.Module): 194 | """ Vision Transformer 195 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 196 | - https://arxiv.org/abs/2010.11929 197 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 198 | - https://arxiv.org/abs/2012.12877 199 | """ 200 | 201 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 202 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, 203 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, init_values=None, 204 | act_layer=None, weight_init='', class_attention_layers=2): 205 | """ 206 | Args: 207 | img_size (int, tuple): input image size 208 | patch_size (int, tuple): patch size 209 | in_chans (int): number of input channels 210 | num_classes (int): number of classes for classification head 211 | embed_dim (int): embedding dimension 212 | depth (int): depth of transformer 213 | num_heads (int): number of attention heads 214 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 215 | qkv_bias (bool): enable bias for qkv if True 216 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 217 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 218 | drop_rate (float): dropout rate 219 | attn_drop_rate (float): attention dropout rate 220 | drop_path_rate (float): stochastic depth rate 221 | embed_layer (nn.Module): patch embedding layer 222 | norm_layer: (nn.Module): normalization layer 223 | weight_init: (str): weight init scheme 224 | """ 225 | super().__init__() 226 | self.num_classes = num_classes 227 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 228 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 229 | act_layer = act_layer or nn.GELU 230 | 231 | self.patch_embed = embed_layer( 232 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 233 | num_patches = self.patch_embed.num_patches 234 | 235 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 236 | 237 | self.rel_pos_bias = None 238 | self.pos_embed = nn.Parameter( 239 | torch.zeros(1, self.patch_embed.num_patches, self.embed_dim), requires_grad=False) 240 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False) 241 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 242 | 243 | self.pos_drop = nn.Dropout(p=drop_rate) 244 | 245 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 
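        # e.g. depth=12 with drop_path_rate=0.2 gives dpr ~ [0.000, 0.018, 0.036, ..., 0.200]:
        # shallow blocks are almost never dropped while the deepest block uses the full rate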

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following the official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x, masks=None, mask_ratio=0.):
        # mask_ratio is kept for interface compatibility; masking is driven entirely by `masks`
        x = self.patch_embed(x)
        # ExtreMA follows the CaiT-style architecture: the class token is kept out of
        # the backbone, so it is not concatenated here
        # x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        # student branch: `self.student` is expected to be set externally (e.g. by the
        # pretraining builder) and is only consulted when masks are provided
        if masks is not None and self.student:
            cls_token = self.cls_token.expand(x.shape[0] * len(masks), -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x_list = []
            for mask in masks:
                # gather the visible patches selected by each mask; the masked views
                # are then stacked along the batch dimension
                mask = mask.view(x.shape[0], -1, 1).repeat(1, 1, x.shape[2])
                x_list.append(torch.gather(x, 1, mask))
            x_multi = torch.cat(x_list, dim=0)
        else:
            x_multi = x
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks

        x_multi = self.blocks(x_multi)

        for blk in self.blocks_token:
            cls_token = blk(x_multi, cls_token)

        x_multi = torch.cat((cls_token, x_multi), dim=1)

        x = self.norm(x_multi)
        return self.pre_logits(x[:, 0])

    def forward(self, x, masks=None, mask_ratio=0.):
        x = self.forward_features(x, masks, mask_ratio)
        out = self.head(x)
        return out


def vit_small(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_base(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_large(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    return model
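

# Minimal smoke test (an editor's sketch, assuming the imports at the top of
# vits.py -- torch, _cfg, PatchEmbed, Attention, etc. -- are in place). It
# exercises only the unmasked path, where `self.student` is never consulted.
if __name__ == "__main__":
    model = vit_small(num_classes=0)   # num_classes=0 -> Identity head, returns features
    imgs = torch.randn(2, 3, 224, 224)
    feats = model(imgs)                # class-token features after the class-attention blocks
    print(feats.shape)                 # expected: torch.Size([2, 384])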
--------------------------------------------------------------------------------