├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── NOTICE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── datasets.py
├── engine.py
├── logger.py
├── losses.py
├── main.py
├── models
│   ├── __init__.py
│   ├── registry.py
│   ├── shiftvit.py
│   ├── smlp.py
│   └── spach
│       ├── __init__.py
│       ├── layers
│       │   ├── __init__.py
│       │   ├── channel_func.py
│       │   ├── spatial_func.py
│       │   └── stem.py
│       ├── misc.py
│       ├── spach.py
│       └── spach_ms.py
├── requirements.txt
├── samplers.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Aa][Rr][Mm]/ 27 | [Aa][Rr][Mm]64/ 28 | bld/ 29 | [Bb]in/ 30 | [Oo]bj/ 31 | [Ll]og/ 32 | [Ll]ogs/ 33 | 34 | # Visual Studio 2015/2017 cache/options directory 35 | .vs/ 36 | # Uncomment if you have tasks that create the project's static files in wwwroot 37 | #wwwroot/ 38 | 39 | # Visual Studio 2017 auto generated files 40 | Generated\ Files/ 41 | 42 | # MSTest test Results 43 | [Tt]est[Rr]esult*/ 44 | [Bb]uild[Ll]og.* 45 | 46 | # NUnit 47 | *.VisualState.xml 48 | TestResult.xml 49 | nunit-*.xml 50 | 51 | # Build Results of an ATL Project 52 | [Dd]ebugPS/ 53 | [Rr]eleasePS/ 54 | dlldata.c 55 | 56 | # Benchmark Results 57 | BenchmarkDotNet.Artifacts/ 58 | 59 | # .NET Core 60 | project.lock.json 61 | project.fragment.lock.json 62 | artifacts/ 63 | 64 | # StyleCop 65 | StyleCopReport.xml 66 | 67 | # Files built by Visual Studio 68 | *_i.c 69 | *_p.c 70 | *_h.h 71 | *.ilk 72 | *.meta 73 | *.obj 74 | *.iobj 75 | *.pch 76 | *.pdb 77 | *.ipdb 78 | *.pgc 79 | *.pgd 80 | *.rsp 81 | *.sbr 82 | *.tlb 83 | *.tli 84 | *.tlh 85 | *.tmp 86 | *.tmp_proj 87 | *_wpftmp.csproj 88 | *.log 89 | *.vspscc 90 | *.vssscc 91 | .builds 92 | *.pidb 93 | *.svclog 94 | *.scc 95 | 96 | # Chutzpah Test files 97 | _Chutzpah* 98 | 99 | # Visual C++ cache files 100 | ipch/ 101 | *.aps 102 | *.ncb 103 | *.opendb 104 | *.opensdf 105 | *.sdf 106 | *.cachefile 107 | *.VC.db 108 | *.VC.VC.opendb 109 | 110 | # Visual Studio profiler 111 | *.psess 112 | *.vsp 113 | *.vspx 114 | *.sap 115 | 116 | # Visual Studio Trace Files 117 | *.e2e 118 | 119 | # TFS 2012 Local Workspace 120 | $tf/ 121 | 122 | # Guidance Automation Toolkit 123 | *.gpState 124 | 125 | # ReSharper is a .NET coding add-in 126 | _ReSharper*/ 127 | *.[Rr]e[Ss]harper 128 | *.DotSettings.user 129 | 130 | # TeamCity is a build add-in 131 | _TeamCity* 132 | 133 | # DotCover is a Code Coverage Tool 134 | *.dotCover 135 | 136 | # AxoCover is a Code Coverage Tool 137 | .axoCover/* 138 | !.axoCover/settings.json 139 | 140 | # Visual Studio code coverage results 141 | *.coverage 142 | *.coveragexml 143 | 144 | # NCrunch 145 | _NCrunch_* 146 | .*crunch*.local.xml 147 | nCrunchTemp_* 148 | 149 | # MightyMoose 150 | *.mm.* 151 | AutoTest.Net/ 152 | 153 | # Web workbench (sass) 154 | .sass-cache/ 155 | 156 | # Installshield output folder 157 | [Ee]xpress/ 158 | 159 | # DocProject is a documentation generator add-in 160 | DocProject/buildhelp/ 161 | 
DocProject/Help/*.HxT 162 | DocProject/Help/*.HxC 163 | DocProject/Help/*.hhc 164 | DocProject/Help/*.hhk 165 | DocProject/Help/*.hhp 166 | DocProject/Help/Html2 167 | DocProject/Help/html 168 | 169 | # Click-Once directory 170 | publish/ 171 | 172 | # Publish Web Output 173 | *.[Pp]ublish.xml 174 | *.azurePubxml 175 | # Note: Comment the next line if you want to checkin your web deploy settings, 176 | # but database connection strings (with potential passwords) will be unencrypted 177 | *.pubxml 178 | *.publishproj 179 | 180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 181 | # checkin your Azure Web App publish settings, but sensitive information contained 182 | # in these scripts will be unencrypted 183 | PublishScripts/ 184 | 185 | # NuGet Packages 186 | *.nupkg 187 | # NuGet Symbol Packages 188 | *.snupkg 189 | # The packages folder can be ignored because of Package Restore 190 | **/[Pp]ackages/* 191 | # except build/, which is used as an MSBuild target. 192 | !**/[Pp]ackages/build/ 193 | # Uncomment if necessary however generally it will be regenerated when needed 194 | #!**/[Pp]ackages/repositories.config 195 | # NuGet v3's project.json files produces more ignorable files 196 | *.nuget.props 197 | *.nuget.targets 198 | 199 | # Microsoft Azure Build Output 200 | csx/ 201 | *.build.csdef 202 | 203 | # Microsoft Azure Emulator 204 | ecf/ 205 | rcf/ 206 | 207 | # Windows Store app package directories and files 208 | AppPackages/ 209 | BundleArtifacts/ 210 | Package.StoreAssociation.xml 211 | _pkginfo.txt 212 | *.appx 213 | *.appxbundle 214 | *.appxupload 215 | 216 | # Visual Studio cache files 217 | # files ending in .cache can be ignored 218 | *.[Cc]ache 219 | # but keep track of directories ending in .cache 220 | !?*.[Cc]ache/ 221 | 222 | # Others 223 | ClientBin/ 224 | ~$* 225 | *~ 226 | *.dbmdl 227 | *.dbproj.schemaview 228 | *.jfm 229 | *.pfx 230 | *.publishsettings 231 | orleans.codegen.cs 232 | 233 | # Including strong name files can present a security risk 234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 235 | #*.snk 236 | 237 | # Since there are multiple workflows, uncomment next line to ignore bower_components 238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 239 | #bower_components/ 240 | 241 | # RIA/Silverlight projects 242 | Generated_Code/ 243 | 244 | # Backup & report files from converting an old project file 245 | # to a newer Visual Studio version. Backup files are not needed, 246 | # because we have git ;-) 247 | _UpgradeReport_Files/ 248 | Backup*/ 249 | UpgradeLog*.XML 250 | UpgradeLog*.htm 251 | ServiceFabricBackup/ 252 | *.rptproj.bak 253 | 254 | # SQL Server files 255 | *.mdf 256 | *.ldf 257 | *.ndf 258 | 259 | # Business Intelligence projects 260 | *.rdl.data 261 | *.bim.layout 262 | *.bim_*.settings 263 | *.rptproj.rsuser 264 | *- [Bb]ackup.rdl 265 | *- [Bb]ackup ([0-9]).rdl 266 | *- [Bb]ackup ([0-9][0-9]).rdl 267 | 268 | # Microsoft Fakes 269 | FakesAssemblies/ 270 | 271 | # GhostDoc plugin setting file 272 | *.GhostDoc.xml 273 | 274 | # Node.js Tools for Visual Studio 275 | .ntvs_analysis.dat 276 | node_modules/ 277 | 278 | # Visual Studio 6 build log 279 | *.plg 280 | 281 | # Visual Studio 6 workspace options file 282 | *.opt 283 | 284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
285 | *.vbw 286 | 287 | # Visual Studio LightSwitch build output 288 | **/*.HTMLClient/GeneratedArtifacts 289 | **/*.DesktopClient/GeneratedArtifacts 290 | **/*.DesktopClient/ModelManifest.xml 291 | **/*.Server/GeneratedArtifacts 292 | **/*.Server/ModelManifest.xml 293 | _Pvt_Extensions 294 | 295 | # Paket dependency manager 296 | .paket/paket.exe 297 | paket-files/ 298 | 299 | # FAKE - F# Make 300 | .fake/ 301 | 302 | # CodeRush personal settings 303 | .cr/personal 304 | 305 | # Python Tools for Visual Studio (PTVS) 306 | __pycache__/ 307 | *.pyc 308 | 309 | # Cake - Uncomment if you are using it 310 | # tools/** 311 | # !tools/packages.config 312 | 313 | # Tabs Studio 314 | *.tss 315 | 316 | # Telerik's JustMock configuration file 317 | *.jmconfig 318 | 319 | # BizTalk build output 320 | *.btp.cs 321 | *.btm.cs 322 | *.odx.cs 323 | *.xsd.cs 324 | 325 | # OpenCover UI analysis results 326 | OpenCover/ 327 | 328 | # Azure Stream Analytics local run output 329 | ASALocalRun/ 330 | 331 | # MSBuild Binary and Structured Log 332 | *.binlog 333 | 334 | # NVidia Nsight GPU debugger configuration file 335 | *.nvuser 336 | 337 | # MFractors (Xamarin productivity tool) working folder 338 | .mfractor/ 339 | 340 | # Local History for Visual Studio 341 | .localhistory/ 342 | 343 | # BeatPulse healthcheck temp database 344 | healthchecksdb 345 | 346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 347 | MigrationBackup/ 348 | 349 | # Ionide (cross platform F# VS Code tools) working folder 350 | .ionide/ 351 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | NOTICES 2 | 3 | This repository incorporates material as listed below or described in the code. 4 | 5 | Component: 6 | main.py, losses.py, datasets.py, engine.py, utils.py, logger.py, samplers.py 7 | 8 | Open Source License/Copyright Notice: 9 | MIT License 10 | Copyright (c) 2015-present, Facebook, Inc. 11 | All rights reserved. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains PyTorch evaluation code, training code and pretrained models for the following projects: 2 | 3 | + SPACH ([A Battle of Network Structures: An Empirical Study of CNN, Transformer, and MLP](https://arxiv.org/abs/2108.13002)) 4 | + sMLP ([Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?](https://arxiv.org/abs/2109.05422)) 5 | + ShiftViT ([When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)) 6 | 7 | Other unofficial implementations: 8 | 9 | + ShiftViT 10 | + [Keras](https://keras.io/examples/vision/shiftvit/) by [Aritra Roy Gosthipaty](https://twitter.com/ariG23498) and [Ritwik Raha](https://twitter.com/ritwik_raha) 11 | + SPACH 12 | + [PyTorch with NPU](https://github.com/Leoooo333/SPACH-1) by [Junming Chen](https://github.com/Leoooo333) 13 | # Main Results on ImageNet with Pretrained Models 14 | 15 | 16 | | name | acc@1 | #params | FLOPs | url | 17 | | ------------------ | ----- | ------- | ----- | ------------------------------------------------------------ | 18 | | SPACH-Conv-MS-XXS | 73.1 | 5M | 0.7G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_conv_xxs.pth) | 19 | | SPACH-Trans-MS-XXS | 65.4 | 2M | 0.5G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_trans_xxs.pth) | 20 | | SPACH-MLP-MS-XXS | 74.5 | 6M | 0.9G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_mlp_xxs.pth) | 21 | | SPACH-Conv-MS-S | 81.6 | 44M | 7.2G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_conv_s.pth) | 22 | | SPACH-Trans-MS-S | 82.9 | 40M | 7.6G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_trans_s.pth) | 23 | | SPACH-MLP-MS-S | 82.1 | 46M | 8.2G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_mlp_s.pth) | 24 | | SPACH-Hybrid-MS-S | 83.7 | 63M | 11.2G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_hybrid_s.pth) | 25 | | SPACH-Hybrid-MS-S+ | 83.9 | 63M | 12.3G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/spach_ms_hybrid_s+.pth) | 26 | | sMLPNet-T | 81.9 | 24M | 5.0G | | 27 | | sMLPNet-S | 83.1 | 49M | 10.3G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/smlp_s.pth) | 28 | | sMLPNet-B | 83.4 | 66M | 14.0G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/smlp_b.pth) | 29 | | Shift-T / light | 79.4 | 20M | 3.0G | 
[github](https://github.com/microsoft/SPACH/releases/download/v1.0/shiftvit_tiny_light.pth) | 30 | | Shift-T | 81.7 | 29M | 4.5G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/shiftvit_tiny_r2.pth) | 31 | | Shift-S / light | 81.6 | 34M | 5.7G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/shiftvit_small_light.pth) | 32 | | Shift-S | 82.8 | 50M | 8.8G | [github](https://github.com/microsoft/SPACH/releases/download/v1.0/shiftvit_small_r2.pth) | 33 | 34 | # Usage 35 | 36 | ## Install 37 | First, clone the repo and install requirements: 38 | 39 | ```bash 40 | git clone https://github.com/microsoft/Spach 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ## Data preparation 45 | 46 | Download and extract ImageNet train and val images from http://image-net.org/. 47 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), 48 | and the training and validation data is expected to be in the `train/` folder and `val/` folder respectively: 49 | 50 | ``` 51 | /path/to/imagenet/ 52 | train/ 53 | class1/ 54 | img1.jpeg 55 | class2/ 56 | img2.jpeg 57 | val/ 58 | class1/ 59 | img3.jpeg 60 | class2/ 61 | img4.jpeg 62 | ``` 63 | 64 | ## Evaluation 65 | 66 | To evaluate a pre-trained model on ImageNet val with a single GPU run: 67 | 68 | ```bash 69 | python main.py --eval --resume <checkpoint> --model <model-name> --data-path <imagenet-path> 70 | ``` 71 | 72 | For example, to evaluate the SPACH-Hybrid-MS-S model, run 73 | 74 | ```bash 75 | python main.py --eval --resume spach_ms_hybrid_s.pth --model spach_ms_s_patch4_224_hybrid --data-path <imagenet-path> 76 | ``` 77 | 78 | giving 79 | ```bash 80 | * Acc@1 83.658 Acc@5 96.762 loss 0.688 81 | ``` 82 | 83 | You can find all supported models in `models/registry.py`. 84 | 85 | ## Training 86 | 87 | One can simply call the following script to run the training process. Distributed training is recommended even on a single-GPU node. 88 | 89 | ```bash 90 | python -m torch.distributed.launch --nproc_per_node <num-gpus> --use_env main.py \ 91 | --model <model-name> \ 92 | --data-path <imagenet-path> \ 93 | --output_dir <output-path> \ 94 | --dist-eval 95 | ``` 96 | 97 | # Citation 98 | 99 | ``` 100 | @article{zhao2021battle, 101 | title={A Battle of Network Structures: An Empirical Study of CNN, Transformer, and MLP}, 102 | author={Zhao, Yucheng and Wang, Guangting and Tang, Chuanxin and Luo, Chong and Zeng, Wenjun and Zha, Zheng-Jun}, 103 | journal={arXiv preprint arXiv:2108.13002}, 104 | year={2021} 105 | } 106 | 107 | @article{tang2021sparse, 108 | title={Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?}, 109 | author={Tang, Chuanxin and Zhao, Yucheng and Wang, Guangting and Luo, Chong and Xie, Wenxuan and Zeng, Wenjun}, 110 | journal={arXiv preprint arXiv:2109.05422}, 111 | year={2021} 112 | } 113 | 114 | ``` 115 | 116 | # Contributing 117 | 118 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 119 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 120 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 121 | 122 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 123 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 124 | provided by the bot. You will only need to do this once across all repos using our CLA. 
125 | 126 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 127 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 128 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 129 | 130 | # Acknowledgement 131 | 132 | Our code is built on top of [DeiT](https://github.com/facebookresearch/deit). We test throughput following [Swin Transformer](https://github.com/microsoft/Swin-Transformer). 133 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def build_dataset(is_train, args): 16 | transform = build_transform(is_train, args) 17 | 18 | if args.data_set == 'IMNET': 19 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 20 | dataset = datasets.ImageFolder(root, transform=transform) 21 | nb_classes = 1000 22 | else: 23 | raise NotImplementedError("Support ImageNet only.") 24 | 25 | return dataset, nb_classes 26 | 27 | 28 | def build_transform(is_train, args): 29 | resize_im = args.input_size > 32 30 | if is_train: 31 | # this should always dispatch to transforms_imagenet_train 32 | transform = create_transform( 33 | input_size=args.input_size, 34 | is_training=True, 35 | color_jitter=args.color_jitter, 36 | auto_augment=args.aa, 37 | interpolation=args.train_interpolation, 38 | re_prob=args.reprob, 39 | re_mode=args.remode, 40 | re_count=args.recount, 41 | ) 42 | if not resize_im: 43 | # replace RandomResizedCropAndInterpolation with 44 | # RandomCrop 45 | transform.transforms[0] = transforms.RandomCrop( 46 | args.input_size, padding=4) 47 | return transform 48 | 49 | t = [] 50 | if resize_im: 51 | size = int((256 / 224) * args.input_size) 52 | t.append( 53 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images 54 | ) 55 | t.append(transforms.CenterCrop(args.input_size)) 56 | 57 | t.append(transforms.ToTensor()) 58 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 59 | return transforms.Compose(t) 60 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | import time 10 | import logging 11 | 12 | import torch 13 | 14 | from timm.data import Mixup 15 | from timm.utils import accuracy, ModelEma 16 | 17 | from losses import DistillationLoss 18 | import utils 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 24 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 25 | set_training_mode=True, logger=logging): 26 | model.train(set_training_mode) 27 | metric_logger = utils.MetricLogger(delimiter=" ", logger=logger) 28 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 10 31 | 32 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 33 | samples = samples.to(device, non_blocking=True) 34 | targets = targets.to(device, non_blocking=True) 35 | 36 | if mixup_fn is not None: 37 | samples, targets = mixup_fn(samples, targets) 38 | 39 | with torch.cuda.amp.autocast(): 40 | outputs = model(samples) 41 | loss = criterion(samples, outputs, targets) 42 | 43 | loss_value = loss.item() 44 | 45 | if not math.isfinite(loss_value): 46 | logger.info("Loss is {}, stopping training".format(loss_value)) 47 | sys.exit(1) 48 | 49 | optimizer.zero_grad() 50 | 51 | # this attribute is added by timm on one optimizer (adahessian) 52 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 53 | loss_scaler(loss, optimizer, clip_grad=max_norm, 54 | parameters=model.parameters(), create_graph=is_second_order) 55 | 56 | torch.cuda.synchronize() 57 | if model_ema is not None: 58 | model_ema.update(model) 59 | 60 | metric_logger.update(loss=loss_value) 61 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 62 | # gather the stats from all processes 63 | metric_logger.synchronize_between_processes() 64 | logger.info(f"Averaged stats: {metric_logger}") 65 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 66 | 67 | 68 | @torch.no_grad() 69 | def evaluate(data_loader, model, device, logger=logging): 70 | criterion = torch.nn.CrossEntropyLoss() 71 | 72 | metric_logger = utils.MetricLogger(delimiter=" ") 73 | header = 'Test:' 74 | 75 | # switch to evaluation mode 76 | model.eval() 77 | 78 | for images, target in metric_logger.log_every(data_loader, 10, header): 79 | images = images.to(device, non_blocking=True) 80 | target = target.to(device, non_blocking=True) 81 | 82 | # compute output 83 | with torch.cuda.amp.autocast(): 84 | output = model(images) 85 | loss = criterion(output, target) 86 | 87 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 88 | 89 | batch_size = images.shape[0] 90 | metric_logger.update(loss=loss.item()) 91 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 
92 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 93 | # gather the stats from all processes 94 | metric_logger.synchronize_between_processes() 95 | logger.info('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 96 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 97 | 98 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 99 | 100 | 101 | @torch.no_grad() 102 | def throughput(data_loader, model, logger=logging): 103 | model.eval() 104 | 105 | for idx, (images, _) in enumerate(data_loader): 106 | images = images.cuda(non_blocking=True) 107 | batch_size = images.shape[0] 108 | for i in range(50): 109 | model(images) 110 | torch.cuda.synchronize() 111 | logger.info("throughput averaged over 30 iterations") 112 | tic1 = time.time() 113 | for i in range(30): 114 | model(images) 115 | torch.cuda.synchronize() 116 | tic2 = time.time() 117 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 118 | return 119 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import sys 5 | import logging 6 | 7 | 8 | # @functools.lru_cache() 9 | def create_logger(output_dir, dist_rank=0, name=''): 10 | # create logger 11 | logger = logging.getLogger(name) 12 | logger.setLevel(logging.DEBUG) 13 | logger.propagate = False 14 | 15 | # create formatter 16 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 17 | 18 | # create console handlers for master process 19 | if dist_rank == 0: 20 | console_handler = logging.StreamHandler(sys.stdout) 21 | console_handler.setLevel(logging.DEBUG) 22 | console_handler.setFormatter( 23 | logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 24 | logger.addHandler(console_handler) 25 | 26 | # create file handlers 27 | if len(output_dir) > 0: 28 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 29 | file_handler.setLevel(logging.DEBUG) 30 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 31 | logger.addHandler(file_handler) 32 | file_handler.flush() 33 | 34 | return logger 35 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 
14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | F.log_softmax(teacher_outputs / T, dim=1), 57 | reduction='sum', 58 | log_target=True 59 | ) * (T * T) / outputs_kd.numel() 60 | elif self.distillation_type == 'hard': 61 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | import os 11 | 12 | from pathlib import Path 13 | 14 | from timm.data import Mixup 15 | from timm.models import create_model 16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 17 | from timm.scheduler import create_scheduler 18 | from timm.optim import create_optimizer 19 | from timm.utils import NativeScaler, get_state_dict, ModelEma 20 | 21 | from datasets import build_dataset 22 | from engine import train_one_epoch, evaluate, throughput 23 | from losses import DistillationLoss 24 | from samplers import RASampler 25 | import models 26 | import utils 27 | from logger import create_logger 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('Training and evaluation script', add_help=False) 32 | parser.add_argument('--batch-size', default=128, type=int) 33 | parser.add_argument('--epochs', default=300, type=int) 34 | 35 | # Model parameters 36 | parser.add_argument('--model', default='smlpnet_tiny', type=str, metavar='MODEL', 37 | help='Name of model to train') 38 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 39 | 40 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 41 | help='Dropout rate (default: 0.)') 42 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 43 | help='Drop path rate (default: 0.1)') 44 | 45 | parser.add_argument('--model-ema', action='store_true') 46 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 47 | parser.set_defaults(model_ema=True) 48 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 49 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 50 | 51 | # Optimizer parameters 52 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 53 | help='Optimizer (default: "adamw")') 54 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 55 | help='Optimizer Epsilon (default: 1e-8)') 56 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 57 | help='Optimizer Betas (default: None, use opt default)') 58 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 59 | help='Clip gradient norm (default: None, no clipping)') 60 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 61 | help='SGD momentum (default: 0.9)') 62 | parser.add_argument('--weight-decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | # Learning rate schedule parameters 65 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 66 | help='LR scheduler (default: "cosine")') 67 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 68 | help='learning rate (default: 5e-4)') 69 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 70 | help='learning rate noise on/off epoch percentages') 71 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 72 | help='learning rate noise limit percent (default: 0.67)') 73 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 74 | help='learning rate noise std-dev (default: 1.0)') 75 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 76 | help='warmup learning rate (default: 
1e-6)') 77 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 78 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 79 | 80 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 81 | help='epoch interval to decay LR') 82 | parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N', 83 | help='epochs to warmup LR, if scheduler supports') 84 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 85 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 86 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 87 | help='patience epochs for Plateau LR scheduler (default: 10)') 88 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 89 | help='LR decay rate (default: 0.1)') 90 | 91 | # Augmentation parameters 92 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 93 | help='Color jitter factor (default: 0.4)') 94 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 95 | help='Use AutoAugment policy. "v0" or "original" ' 96 | '(default: rand-m9-mstd0.5-inc1)') 97 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 98 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 99 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 100 | 101 | parser.add_argument('--repeated-aug', action='store_true') 102 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 103 | parser.set_defaults(repeated_aug=True) 104 | 105 | # * Random Erase params 106 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 107 | help='Random erase prob (default: 0.25)') 108 | parser.add_argument('--remode', type=str, default='pixel', 109 | help='Random erase mode (default: "pixel")') 110 | parser.add_argument('--recount', type=int, default=1, 111 | help='Random erase count (default: 1)') 112 | parser.add_argument('--resplit', action='store_true', default=False, 113 | help='Do not random erase first (clean) augmentation split') 114 | 115 | # * Mixup params 116 | parser.add_argument('--mixup', type=float, default=0.8, 117 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 118 | parser.add_argument('--cutmix', type=float, default=1.0, 119 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 120 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 121 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 122 | parser.add_argument('--mixup-prob', type=float, default=1.0, 123 | help='Probability of performing mixup or cutmix when either/both is enabled') 124 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 125 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 126 | parser.add_argument('--mixup-mode', type=str, default='batch', 127 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 128 | 129 | # Distillation parameters 130 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 131 | help='Name of teacher model to train (default: "regnety_160"') 132 | parser.add_argument('--teacher-path', type=str, default='') 133 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 134 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 135 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 136 | 137 | # * Finetuning params 138 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 139 | 140 | # Dataset parameters 141 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 142 | help='dataset path') 143 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 144 | type=str, help='Image Net dataset path') 145 | parser.add_argument('--inat-category', default='name', 146 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 147 | type=str, help='semantic granularity') 148 | 149 | parser.add_argument('--output_dir', default='', 150 | help='path where to save, empty for no saving') 151 | parser.add_argument('--device', default='cuda', 152 | help='device to use for training / testing') 153 | parser.add_argument('--seed', default=0, type=int) 154 | parser.add_argument('--resume', default='', help='resume from checkpoint') 155 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 156 | help='start epoch') 157 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 158 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 159 | parser.add_argument('--num_workers', default=10, type=int) 160 | parser.add_argument('--pin-mem', action='store_true', 161 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 162 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 163 | help='') 164 | parser.set_defaults(pin_mem=True) 165 | 166 | # distributed training parameters 167 | parser.add_argument('--world_size', default=1, type=int, 168 | help='number of distributed processes') 169 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 170 | 171 | # parameters for training on preemptible clusters 172 | parser.add_argument('--auto-resume', action='store_true') 173 | parser.add_argument('--no-auto-resume', action='store_false', dest='auto_resume') 174 | parser.set_defaults(auto_resume=True) 175 | 176 | # spach parameters 177 | parser.add_argument('--stem-type', default='conv1', type=str, choices=['conv1', 'conv4']) 178 | parser.add_argument('--shared-spatial-func', action='store_true') 179 | 180 | # parameters for benchmark 181 | parser.add_argument('--throughput', action='store_true') 182 | return parser 183 | 184 | 185 | def parse_model_args(args): 186 | model = args.model 187 | model_args = [] 188 | if model.startswith('spach'): 189 | model_args = ['stem_type', 'shared_spatial_func'] 190 | args = vars(args) 191 | model_args = {_: args[_] for _ in model_args} 192 | return model_args 193 | 194 | 195 | def main(args): 196 | utils.init_distributed_mode(args) 197 | 198 | logger = create_logger(args.output_dir, utils.get_rank(), args.model) 199 | logger.info(args) 200 | 201 | 
if args.distillation_type != 'none' and args.finetune and not args.eval: 202 | raise NotImplementedError("Finetuning with distillation not yet supported") 203 | 204 | device = torch.device(args.device) 205 | 206 | # fix the seed for reproducibility 207 | seed = args.seed + utils.get_rank() 208 | torch.manual_seed(seed) 209 | np.random.seed(seed) 210 | # random.seed(seed) 211 | 212 | cudnn.benchmark = True 213 | 214 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 215 | dataset_val, _ = build_dataset(is_train=False, args=args) 216 | 217 | if True: # args.distributed: 218 | num_tasks = utils.get_world_size() 219 | global_rank = utils.get_rank() 220 | if args.repeated_aug: 221 | sampler_train = RASampler( 222 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 223 | ) 224 | else: 225 | sampler_train = torch.utils.data.DistributedSampler( 226 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 227 | ) 228 | if args.dist_eval: 229 | if len(dataset_val) % num_tasks != 0: 230 | logger.info('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 231 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 232 | 'equal num of samples per-process.') 233 | sampler_val = torch.utils.data.DistributedSampler( 234 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 235 | else: 236 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 237 | else: 238 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 239 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 240 | 241 | data_loader_train = torch.utils.data.DataLoader( 242 | dataset_train, sampler=sampler_train, 243 | batch_size=args.batch_size, 244 | num_workers=args.num_workers, 245 | pin_memory=args.pin_mem, 246 | drop_last=True, 247 | ) 248 | 249 | data_loader_val = torch.utils.data.DataLoader( 250 | dataset_val, sampler=sampler_val, 251 | batch_size=int(1.5 * args.batch_size), 252 | num_workers=args.num_workers, 253 | pin_memory=args.pin_mem, 254 | drop_last=False 255 | ) 256 | 257 | mixup_fn = None 258 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 259 | if mixup_active: 260 | mixup_fn = Mixup( 261 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 262 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 263 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 264 | 265 | logger.info(f"Creating model: {args.model}") 266 | model = create_model( 267 | args.model, 268 | pretrained=False, 269 | num_classes=args.nb_classes, 270 | drop_rate=args.drop, 271 | drop_path_rate=args.drop_path, 272 | drop_block_rate=None, 273 | **parse_model_args(args) 274 | ) 275 | 276 | if args.finetune: 277 | if args.finetune.startswith('https'): 278 | checkpoint = torch.hub.load_state_dict_from_url( 279 | args.finetune, map_location='cpu', check_hash=True) 280 | else: 281 | checkpoint = torch.load(args.finetune, map_location='cpu') 282 | 283 | checkpoint_model = checkpoint['model'] 284 | state_dict = model.state_dict() 285 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 286 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 287 | logger.info(f"Removing key {k} from pretrained checkpoint") 288 | del checkpoint_model[k] 289 | 290 | # interpolate position embedding 291 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 292 | embedding_size = pos_embed_checkpoint.shape[-1] 293 | num_patches = model.patch_embed.num_patches 294 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 295 | # height (== width) for the checkpoint position embedding 296 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 297 | # height (== width) for the new position embedding 298 | new_size = int(num_patches ** 0.5) 299 | # class_token and dist_token are kept unchanged 300 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 301 | # only the position tokens are interpolated 302 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 303 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 304 | pos_tokens = torch.nn.functional.interpolate( 305 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 306 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 307 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 308 | checkpoint_model['pos_embed'] = new_pos_embed 309 | 310 | model.load_state_dict(checkpoint_model, strict=False) 311 | 312 | model.to(device) 313 | 314 | model_ema = None 315 | if args.model_ema: 316 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 317 | model_ema = ModelEma( 318 | model, 319 | decay=args.model_ema_decay, 320 | device='cpu' if args.model_ema_force_cpu else '', 321 | resume='') 322 | 323 | model_without_ddp = model 324 | if args.distributed: 325 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 326 | model_without_ddp = model.module 327 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 328 | logger.info(f'number of params: {n_parameters}') 329 | if hasattr(model_without_ddp, 'flops'): 330 | try: 331 | flops = model_without_ddp.flops() 332 | logger.info(f"number of GFLOPs: {flops / 1e9}") 333 | except Exception as e: 334 | logger.exception(e) 335 | 336 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 337 | args.lr = linear_scaled_lr 338 | optimizer = create_optimizer(args, model_without_ddp) 339 | 
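# A worked example of the linear scaling rule applied above (DeiT-style; the
# numbers are illustrative, assuming the default base lr of 5e-4):
#   per-process batch size 128 on 8 GPUs -> effective batch size 1024,
#   so args.lr = 5e-4 * 1024 / 512 = 1e-3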
loss_scaler = NativeScaler() 340 | 341 | lr_scheduler, _ = create_scheduler(args, optimizer) 342 | 343 | criterion = LabelSmoothingCrossEntropy() 344 | 345 | if args.mixup > 0.: 346 | # smoothing is handled with mixup label transform 347 | criterion = SoftTargetCrossEntropy() 348 | elif args.smoothing: 349 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 350 | else: 351 | criterion = torch.nn.CrossEntropyLoss() 352 | 353 | teacher_model = None 354 | if args.distillation_type != 'none': 355 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 356 | logger.info(f"Creating teacher model: {args.teacher_model}") 357 | teacher_model = create_model( 358 | args.teacher_model, 359 | pretrained=False, 360 | num_classes=args.nb_classes, 361 | global_pool='avg', 362 | ) 363 | if args.teacher_path.startswith('https'): 364 | checkpoint = torch.hub.load_state_dict_from_url( 365 | args.teacher_path, map_location='cpu', check_hash=True) 366 | else: 367 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 368 | teacher_model.load_state_dict(checkpoint['model']) 369 | teacher_model.to(device) 370 | teacher_model.eval() 371 | 372 | # wrap the criterion in our custom DistillationLoss, which 373 | # just dispatches to the original criterion if args.distillation_type is 'none' 374 | criterion = DistillationLoss( 375 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 376 | ) 377 | 378 | output_dir = Path(args.output_dir) 379 | if args.auto_resume: 380 | _resume = str((output_dir / 'checkpoint.pth').absolute()) 381 | if os.path.exists(_resume): 382 | logger.info(f'auto resume from {output_dir}/checkpoint.pth') 383 | args.resume = _resume 384 | 385 | if args.resume: 386 | if args.resume.startswith('https'): 387 | checkpoint = torch.hub.load_state_dict_from_url( 388 | args.resume, map_location='cpu', check_hash=True) 389 | else: 390 | checkpoint = torch.load(args.resume, map_location='cpu') 391 | model_without_ddp.load_state_dict(checkpoint['model']) 392 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 393 | optimizer.load_state_dict(checkpoint['optimizer']) 394 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 395 | args.start_epoch = checkpoint['epoch'] + 1 396 | if args.model_ema: 397 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 398 | if 'scaler' in checkpoint: 399 | loss_scaler.load_state_dict(checkpoint['scaler']) 400 | 401 | if args.eval: 402 | test_stats = evaluate(data_loader_val, model, device, logger=logger) 403 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 404 | return 405 | 406 | if args.throughput: 407 | throughput(data_loader_val, model, logger=logger) 408 | return 409 | 410 | logger.info(f"Start training for {args.epochs} epochs") 411 | start_time = time.time() 412 | max_accuracy = 0.0 413 | for epoch in range(args.start_epoch, args.epochs): 414 | if args.distributed: 415 | data_loader_train.sampler.set_epoch(epoch) 416 | 417 | train_stats = train_one_epoch( 418 | model, criterion, data_loader_train, 419 | optimizer, device, epoch, loss_scaler, 420 | args.clip_grad, model_ema, mixup_fn, 421 | set_training_mode=args.finetune == '', # keep in eval mode during finetuning 422 | logger=logger 423 | ) 424 | 425 | lr_scheduler.step(epoch) 426 | if args.output_dir: 427 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 428 | for checkpoint_path in 
checkpoint_paths: 429 | utils.save_on_master({ 430 | 'model': model_without_ddp.state_dict(), 431 | 'optimizer': optimizer.state_dict(), 432 | 'lr_scheduler': lr_scheduler.state_dict(), 433 | 'epoch': epoch, 434 | 'model_ema': get_state_dict(model_ema), 435 | 'scaler': loss_scaler.state_dict(), 436 | 'args': args, 437 | }, checkpoint_path) 438 | 439 | test_stats = evaluate(data_loader_val, model, device, logger=logger) 440 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 441 | 442 | if test_stats["acc1"] > max_accuracy: 443 | best_checkpoint_path = output_dir / 'best.pth' 444 | utils.save_on_master({ 445 | 'model': model_without_ddp.state_dict(), 446 | 'optimizer': optimizer.state_dict(), 447 | 'lr_scheduler': lr_scheduler.state_dict(), 448 | 'epoch': epoch, 449 | 'model_ema': get_state_dict(model_ema), 450 | 'scaler': loss_scaler.state_dict(), 451 | 'args': args, 452 | }, best_checkpoint_path) 453 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 454 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 455 | 456 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 457 | **{f'test_{k}': v for k, v in test_stats.items()}, 458 | 'epoch': epoch, 459 | 'n_parameters': n_parameters} 460 | 461 | if args.output_dir and utils.is_main_process(): 462 | with (output_dir / "log.txt").open("a") as f: 463 | f.write(json.dumps(log_stats) + "\n") 464 | 465 | total_time = time.time() - start_time 466 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 467 | logger.info('Training time {}'.format(total_time_str)) 468 | 469 | 470 | if __name__ == '__main__': 471 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 472 | args = parser.parse_args() 473 | if args.output_dir: 474 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 475 | main(args) 476 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from .registry import * 4 | -------------------------------------------------------------------------------- /models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
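# Every function below is decorated with timm's @register_model, so the
# function name doubles as the model name accepted by --model in main.py.
# A minimal sketch of programmatic use (assuming timm is installed;
# num_classes is forwarded through **kwargs to the model constructor):
#
#   from timm.models import create_model
#   model = create_model('spach_ms_s_patch4_224_hybrid', pretrained=False, num_classes=1000)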
3 | from timm.models.registry import register_model 4 | from .smlp import sMLPNet 5 | from .spach import Spach, SpachMS 6 | from .shiftvit import ShiftViT 7 | 8 | 9 | # sMLP 10 | @register_model 11 | def smlpnet_tiny(pretrained=False, **kwargs): 12 | model = sMLPNet(dim=80, alpha=3, patch_size=4, depths=[2,8,14,2], dp_rate=0.0, **kwargs) 13 | return model 14 | 15 | 16 | @register_model 17 | def smlpnet_small(pretrained=False, **kwargs): 18 | model = sMLPNet(dim=96, alpha=3, patch_size=4, depths=[2,10,24,2], dp_rate=0.2, **kwargs) 19 | return model 20 | 21 | 22 | @register_model 23 | def smlpnet_base(pretrained=False, **kwargs): 24 | model = sMLPNet(dim=112, alpha=3, patch_size=4, depths=[2,10,24,2], dp_rate=0.3, **kwargs) 25 | return model 26 | 27 | 28 | # SPACH 29 | @register_model 30 | def spach_xxs_patch16_224_mlp(pretrained=False, **kwargs): 31 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=384, token_ratio=0.5, num_heads=12, channel_ratio=2.0) 32 | cfgs['net_arch'] = [('mlp', 12)] 33 | cfgs.update(kwargs) 34 | model = Spach(**cfgs) 35 | return model 36 | 37 | 38 | @register_model 39 | def spach_xxs_patch16_224_conv(pretrained=False, **kwargs): 40 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=384, token_ratio=0.5, num_heads=12, channel_ratio=2.0) 41 | cfgs['net_arch'] = [('pass', 12)] 42 | cfgs.update(kwargs) 43 | model = Spach(**cfgs) 44 | return model 45 | 46 | 47 | @register_model 48 | def spach_xxs_patch16_224_attn(pretrained=False, **kwargs): 49 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=192, token_ratio=0.5, num_heads=6, channel_ratio=2.0) 50 | cfgs['net_arch'] = [('attn', 12)] 51 | cfgs.update(kwargs) 52 | model = Spach(**cfgs) 53 | return model 54 | 55 | 56 | @register_model 57 | def spach_xs_patch16_224_mlp(pretrained=False, **kwargs): 58 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=384, token_ratio=0.5, num_heads=12, channel_ratio=2.0) 59 | cfgs['net_arch'] = [('mlp', 24)] 60 | cfgs.update(kwargs) 61 | model = Spach(**cfgs) 62 | return model 63 | 64 | 65 | @register_model 66 | def spach_xs_patch16_224_conv(pretrained=False, **kwargs): 67 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=384, token_ratio=0.5, num_heads=12, channel_ratio=2.0) 68 | cfgs['net_arch'] = [('pass', 24)] 69 | cfgs.update(kwargs) 70 | model = Spach(**cfgs) 71 | return model 72 | 73 | 74 | @register_model 75 | def spach_xs_patch16_224_attn(pretrained=False, **kwargs): 76 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=384, token_ratio=0.5, num_heads=12, channel_ratio=2.0) 77 | cfgs['net_arch'] = [('attn', 12)] 78 | cfgs.update(kwargs) 79 | model = Spach(**cfgs) 80 | return model 81 | 82 | 83 | @register_model 84 | def spach_s_patch16_224_mlp(pretrained=False, **kwargs): 85 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=512, token_ratio=0.5, num_heads=16, channel_ratio=3.0) 86 | cfgs['net_arch'] = [('mlp', 24)] 87 | cfgs.update(kwargs) 88 | model = Spach(**cfgs) 89 | return model 90 | 91 | 92 | @register_model 93 | def spach_s_patch16_224_conv(pretrained=False, **kwargs): 94 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=512, token_ratio=0.5, num_heads=16, channel_ratio=3.0) 95 | cfgs['net_arch'] = [('pass', 24)] 96 | cfgs.update(kwargs) 97 | model = Spach(**cfgs) 98 | return model 99 | 100 | 101 | @register_model 102 | def spach_s_patch16_224_attn(pretrained=False, **kwargs): 103 | cfgs = dict(img_size=224, patch_size=16, hidden_dim=512, token_ratio=0.5, num_heads=16, channel_ratio=3.0) 104 | cfgs['net_arch'] = [('attn', 12)] 105 | 
cfgs.update(kwargs) 106 | model = Spach(**cfgs) 107 | return model 108 | 109 | 110 | @register_model 111 | def spach_ms_xxs_patch4_224_conv(pretrained=False, **kwargs): 112 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=64, token_ratio=0.5, num_heads=2, channel_ratio=2.0) 113 | cfgs['net_arch'] = [[('pass', 2)], [('pass', 2)], [('pass', 6)], [('pass', 2)]] 114 | cfgs.update(kwargs) 115 | model = SpachMS(**cfgs) 116 | return model 117 | 118 | 119 | @register_model 120 | def spach_ms_xxs_patch4_224_mlp(pretrained=False, **kwargs): 121 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=64, token_ratio=0.5, num_heads=2, channel_ratio=2.0) 122 | cfgs['net_arch'] = [[('pass', 2)], [('mlp', 2)], [('mlp', 6)], [('mlp', 2)]] 123 | cfgs.update(kwargs) 124 | model = SpachMS(**cfgs) 125 | return model 126 | 127 | 128 | @register_model 129 | def spach_ms_xxs_patch4_224_attn(pretrained=False, **kwargs): 130 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=32, token_ratio=0.5, num_heads=1, channel_ratio=2.0) 131 | cfgs['net_arch'] = [[('pass', 2)], [('attn', 2)], [('attn', 6)], [('attn', 2)]] 132 | cfgs.update(kwargs) 133 | model = SpachMS(**cfgs) 134 | return model 135 | 136 | 137 | @register_model 138 | def spach_ms_xs_patch4_224_conv(pretrained=False, **kwargs): 139 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=96, token_ratio=0.5, num_heads=3, channel_ratio=2.0) 140 | cfgs['net_arch'] = [[('pass', 3)], [('pass', 4)], [('pass', 12)], [('pass', 3)]] 141 | cfgs.update(kwargs) 142 | model = SpachMS(**cfgs) 143 | return model 144 | 145 | 146 | @register_model 147 | def spach_ms_xs_patch4_224_mlp(pretrained=False, **kwargs): 148 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=96, token_ratio=0.5, num_heads=3, channel_ratio=2.0) 149 | cfgs['net_arch'] = [[('pass', 3)], [('mlp', 4)], [('mlp', 12)], [('mlp', 3)]] 150 | cfgs.update(kwargs) 151 | model = SpachMS(**cfgs) 152 | return model 153 | 154 | 155 | @register_model 156 | def spach_ms_xs_patch4_224_attn(pretrained=False, **kwargs): 157 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=64, token_ratio=0.5, num_heads=2, channel_ratio=2.0) 158 | cfgs['net_arch'] = [[('pass', 3)], [('attn', 4)], [('attn', 12)], [('attn', 3)]] 159 | cfgs.update(kwargs) 160 | model = SpachMS(**cfgs) 161 | return model 162 | 163 | 164 | @register_model 165 | def spach_ms_s_patch4_224_conv(pretrained=False, **kwargs): 166 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=128, token_ratio=0.5, num_heads=4, channel_ratio=3.0) 167 | cfgs['net_arch'] = [[('pass', 3)], [('pass', 4)], [('pass', 12)], [('pass', 3)]] 168 | cfgs.update(kwargs) 169 | model = SpachMS(**cfgs) 170 | return model 171 | 172 | 173 | @register_model 174 | def spach_ms_s_patch4_224_mlp(pretrained=False, **kwargs): 175 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=128, token_ratio=0.5, num_heads=4, channel_ratio=3.0) 176 | cfgs['net_arch'] = [[('pass', 3)], [('mlp', 4)], [('mlp', 12)], [('mlp', 3)]] 177 | cfgs.update(kwargs) 178 | model = SpachMS(**cfgs) 179 | return model 180 | 181 | 182 | @register_model 183 | def spach_ms_s_patch4_224_attn(pretrained=False, **kwargs): 184 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=96, token_ratio=0.5, num_heads=3, channel_ratio=3.0) 185 | cfgs['net_arch'] = [[('pass', 3)], [('attn', 4)], [('attn', 12)], [('attn', 3)]] 186 | cfgs.update(kwargs) 187 | model = SpachMS(**cfgs) 188 | return model 189 | 190 | 191 | @register_model 192 | def spach_ms_xs_patch4_224_hybrid(pretrained=False, **kwargs): 193 | cfgs = dict(img_size=224, 
patch_size=4, hidden_dim=96, token_ratio=0.5, num_heads=3, channel_ratio=2.0) 194 | cfgs['net_arch'] = [[('pass', 3)], [('pass', 4)], [('pass', 2), ('attn', 10)], [('pass', 1), ('attn', 2)]] 195 | cfgs.update(kwargs) 196 | model = SpachMS(**cfgs) 197 | return model 198 | 199 | 200 | @register_model 201 | def spach_ms_s_patch4_224_hybrid(pretrained=False, **kwargs): 202 | cfgs = dict(img_size=224, patch_size=4, hidden_dim=128, token_ratio=0.5, num_heads=4, channel_ratio=3.0) 203 | cfgs['net_arch'] = [[('pass', 3)], [('pass', 2), ('attn', 2)], [('pass', 2), ('attn', 10)], [('pass', 1), ('attn', 2)]] 204 | cfgs.update(kwargs) 205 | model = SpachMS(**cfgs) 206 | return model 207 | 208 | 209 | # shift vit 210 | @register_model 211 | def shiftvit_light_tiny(**kwargs): 212 | model = ShiftViT(embed_dim=96, depths=(2, 2, 6, 2), mlp_ratio=4, drop_path_rate=0.2, n_div=12) 213 | return model 214 | 215 | 216 | @register_model 217 | def shiftvit_r4_tiny(**kwargs): 218 | model = ShiftViT(embed_dim=96, depths=(2, 2, 12, 3), mlp_ratio=4, drop_path_rate=0.2, n_div=12) 219 | return model 220 | 221 | 222 | @register_model 223 | def shiftvit_r2_tiny(**kwargs): 224 | model = ShiftViT(embed_dim=96, depths=(6, 8, 18, 6), mlp_ratio=2, drop_path_rate=0.2, n_div=12) 225 | return model 226 | 227 | 228 | @register_model 229 | def shiftvit_light_small(**kwargs): 230 | model = ShiftViT(embed_dim=96, depths=(2, 2, 18, 2), mlp_ratio=4, drop_path_rate=0.4, n_div=12) 231 | return model 232 | 233 | 234 | @register_model 235 | def shiftvit_r4_small(**kwargs): 236 | model = ShiftViT(embed_dim=96, depths=(2, 6, 24, 4), mlp_ratio=4, drop_path_rate=0.4, n_div=12) 237 | return model 238 | 239 | 240 | @register_model 241 | def shiftvit_r2_small(**kwargs): 242 | model = ShiftViT(embed_dim=96, depths=(10, 18, 36, 10), mlp_ratio=2, drop_path_rate=0.4, n_div=12) 243 | return model 244 | 245 | 246 | @register_model 247 | def shiftvit_light_base(**kwargs): 248 | model = ShiftViT(embed_dim=128, depths=(2, 2, 18, 2), mlp_ratio=4, drop_path_rate=0.5, n_div=16) 249 | return model 250 | 251 | 252 | @register_model 253 | def shiftvit_r4_base(**kwargs): 254 | model = ShiftViT(embed_dim=128, depths=(4, 6, 22, 4), mlp_ratio=4, drop_path_rate=0.5, n_div=16) 255 | return model 256 | 257 | 258 | @register_model 259 | def shiftvit_r2_base(**kwargs): 260 | model = ShiftViT(embed_dim=128, depths=(10, 18, 36, 10), mlp_ratio=2, drop_path_rate=0.6, n_div=16) 261 | return model 262 | -------------------------------------------------------------------------------- /models/shiftvit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint as checkpoint 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from functools import partial 8 | 9 | 10 | class GroupNorm(nn.GroupNorm): 11 | 12 | def __init__(self, num_channels, num_groups=1): 13 | """ We use GroupNorm (group = 1) to approximate LayerNorm 14 | for [N, C, H, W] layout""" 15 | super(GroupNorm, self).__init__(num_groups, num_channels) 16 | 17 | 18 | class Mlp(nn.Module): 19 | 20 | def __init__(self, 21 | in_features, 22 | hidden_features=None, 23 | out_features=None, 24 | act_layer=nn.GELU, 25 | drop=0.): 26 | """ MLP network in FFN. By default, the MLP is implemented by 27 | nn.Linear. 
However, in our implementation the data layout is 28 | [N, C, H, W], so we use 1x1 convolutions to 29 | implement the fully-connected MLP layers. 30 | 31 | Args: 32 | in_features (int): input channels 33 | hidden_features (int): hidden channels, if None, set to in_features 34 | out_features (int): out channels, if None, set to in_features 35 | act_layer (callable): activation function class type 36 | drop (float): dropout probability 37 | """ 38 | super().__init__() 39 | out_features = out_features or in_features 40 | hidden_features = hidden_features or in_features 41 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 42 | self.act = act_layer() 43 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 44 | self.drop = nn.Dropout(drop) 45 | 46 | def forward(self, x): 47 | x = self.fc1(x) 48 | x = self.act(x) 49 | x = self.drop(x) 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class ShiftViTBlock(nn.Module): 56 | 57 | def __init__(self, 58 | dim, 59 | n_div=12, 60 | mlp_ratio=4., 61 | drop=0., 62 | drop_path=0., 63 | act_layer=nn.GELU, 64 | norm_layer=nn.LayerNorm, 65 | input_resolution=None): 66 | """ The building block of the Shift-ViT network. 67 | 68 | Args: 69 | dim (int): feature dimension 70 | n_div (int): how many divisions are used. In total, 4/n_div of the 71 | channels will be shifted. 72 | mlp_ratio (float): expansion ratio of the MLP network. 73 | drop (float): dropout prob. 74 | drop_path (float): drop path prob. 75 | act_layer (callable): activation function class type. 76 | norm_layer (callable): normalization layer class type. 77 | input_resolution (tuple): input resolution. This optional variable 78 | is used to calculate the flops. 79 | 80 | """ 81 | super(ShiftViTBlock, self).__init__() 82 | self.dim = dim 83 | self.input_resolution = input_resolution 84 | self.mlp_ratio = mlp_ratio 85 | 86 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 87 | self.norm2 = norm_layer(dim) 88 | mlp_hidden_dim = int(dim * mlp_ratio) 89 | self.mlp = Mlp(in_features=dim, 90 | hidden_features=mlp_hidden_dim, 91 | act_layer=act_layer, 92 | drop=drop) 93 | self.n_div = n_div 94 | 95 | def forward(self, x): 96 | x = self.shift_feat(x, self.n_div) 97 | shortcut = x 98 | x = shortcut + self.drop_path(self.mlp(self.norm2(x))) 99 | return x 100 | 101 | def extra_repr(self) -> str: 102 | return f"dim={self.dim}," \ 103 | f"input_resolution={self.input_resolution}," \ 104 | f"shift percentage={4.0 / self.n_div * 100}%."
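# Worked example of the partial shift implemented by shift_feat below: with
# C = 96 channels and n_div = 12, each direction owns g = C // n_div = 8
# channels. Channels 0-7 shift left by one pixel, 8-15 right, 16-23 up and
# 24-31 down (vacated positions stay zero, via torch.zeros_like), while the
# remaining 64 channels are copied through, i.e. 4/12 = 33.3% of channels move.
# An illustrative sanity check:
#
#   x = torch.randn(1, 96, 14, 14)
#   y = ShiftViTBlock.shift_feat(x, n_div=12)
#   assert torch.equal(y[:, 32:], x[:, 32:])                # unshifted channels
#   assert torch.equal(y[:, :8, :, :-1], x[:, :8, :, 1:])   # left shift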
105 | 106 | @staticmethod 107 | def shift_feat(x, n_div): 108 | B, C, H, W = x.shape 109 | g = C // n_div 110 | out = torch.zeros_like(x) 111 | 112 | out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left 113 | out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right 114 | out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up 115 | out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down 116 | 117 | out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift 118 | return out 119 | 120 | 121 | class PatchMerging(nn.Module): 122 | 123 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 124 | super().__init__() 125 | self.input_resolution = input_resolution 126 | self.dim = dim 127 | self.reduction = nn.Conv2d(dim, 2 * dim, (2, 2), stride=2, bias=False) 128 | self.norm = norm_layer(dim) 129 | 130 | def forward(self, x): 131 | x = self.norm(x) 132 | x = self.reduction(x) 133 | return x 134 | 135 | def extra_repr(self) -> str: 136 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 137 | 138 | 139 | class BasicLayer(nn.Module): 140 | 141 | def __init__(self, 142 | dim, 143 | input_resolution, 144 | depth, 145 | n_div=12, 146 | mlp_ratio=4., 147 | drop=0., 148 | drop_path=None, 149 | norm_layer=None, 150 | downsample=None, 151 | use_checkpoint=False, 152 | act_layer=nn.GELU): 153 | 154 | super(BasicLayer, self).__init__() 155 | self.dim = dim 156 | self.input_resolution = input_resolution 157 | self.depth = depth 158 | self.use_checkpoint = use_checkpoint 159 | 160 | # build blocks 161 | self.blocks = nn.ModuleList([ 162 | ShiftViTBlock(dim=dim, 163 | n_div=n_div, 164 | mlp_ratio=mlp_ratio, 165 | drop=drop, 166 | drop_path=drop_path[i], 167 | norm_layer=norm_layer, 168 | act_layer=act_layer, 169 | input_resolution=input_resolution) 170 | for i in range(depth) 171 | ]) 172 | 173 | # patch merging layer 174 | if downsample is not None: 175 | self.downsample = downsample(input_resolution, 176 | dim=dim, 177 | norm_layer=norm_layer) 178 | else: 179 | self.downsample = None 180 | 181 | def forward(self, x): 182 | for blk in self.blocks: 183 | if self.use_checkpoint: 184 | x = checkpoint.checkpoint(blk, x) 185 | else: 186 | x = blk(x) 187 | if self.downsample is not None: 188 | x = self.downsample(x) 189 | return x 190 | 191 | def extra_repr(self) -> str: 192 | return f"dim={self.dim}," \ 193 | f"input_resolution={self.input_resolution}," \ 194 | f"depth={self.depth}" 195 | 196 | 197 | class PatchEmbed(nn.Module): 198 | r""" Image to Patch Embedding 199 | 200 | Args: 201 | img_size (int, tuple): Image size. 202 | patch_size (int, tuple): Patch token size. 203 | in_chans (int): Number of input image channels. 204 | embed_dim (int): Number of linear projection output channels. 205 | norm_layer (nn.Module, optional): Normalization layer. 
206 | """ 207 | 208 | def __init__(self, 209 | img_size=224, 210 | patch_size=4, 211 | in_chans=3, 212 | embed_dim=96, 213 | norm_layer=None): 214 | super().__init__() 215 | img_size = to_2tuple(img_size) 216 | patch_size = to_2tuple(patch_size) 217 | patches_resolution = [img_size[0] // patch_size[0], 218 | img_size[1] // patch_size[1]] 219 | self.img_size = img_size 220 | self.patch_size = patch_size 221 | self.patches_resolution = patches_resolution 222 | self.num_patches = patches_resolution[0] * patches_resolution[1] 223 | 224 | self.in_chans = in_chans 225 | self.embed_dim = embed_dim 226 | 227 | self.proj = nn.Conv2d(in_chans, embed_dim, 228 | kernel_size=patch_size, stride=patch_size) 229 | if norm_layer is not None: 230 | self.norm = norm_layer(embed_dim) 231 | else: 232 | self.norm = None 233 | 234 | def forward(self, x): 235 | x = self.proj(x) 236 | if self.norm is not None: 237 | x = self.norm(x) 238 | return x 239 | 240 | 241 | class ShiftViT(nn.Module): 242 | 243 | def __init__(self, 244 | n_div=12, 245 | img_size=224, 246 | patch_size=4, 247 | in_chans=3, 248 | num_classes=1000, 249 | embed_dim=96, 250 | depths=(2, 2, 6, 2), 251 | mlp_ratio=4., 252 | drop_rate=0., 253 | drop_path_rate=0.1, 254 | norm_layer='GN1', 255 | act_layer='GELU', 256 | patch_norm=True, 257 | use_checkpoint=False, 258 | **kwargs): 259 | super().__init__() 260 | assert norm_layer in ('GN1', 'BN') 261 | if norm_layer == 'BN': 262 | norm_layer = nn.BatchNorm2d 263 | elif norm_layer == 'GN1': 264 | norm_layer = partial(GroupNorm, num_groups=1) 265 | else: 266 | raise NotImplementedError 267 | 268 | if act_layer == 'GELU': 269 | act_layer = nn.GELU 270 | elif act_layer == 'RELU': 271 | act_layer = partial(nn.ReLU, inplace=False) 272 | else: 273 | raise NotImplementedError 274 | 275 | self.num_classes = num_classes 276 | self.num_layers = len(depths) 277 | self.embed_dim = embed_dim 278 | self.patch_norm = patch_norm 279 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 280 | self.mlp_ratio = mlp_ratio 281 | 282 | # split image into non-overlapping patches 283 | self.patch_embed = PatchEmbed( 284 | img_size=img_size, 285 | patch_size=patch_size, 286 | in_chans=in_chans, 287 | embed_dim=embed_dim, 288 | norm_layer=norm_layer if self.patch_norm else None) 289 | 290 | # num_patches = self.patch_embed.num_patches 291 | patches_resolution = self.patch_embed.patches_resolution 292 | self.patches_resolution = patches_resolution 293 | self.pos_drop = nn.Dropout(p=drop_rate) 294 | 295 | # stochastic depth decay rule 296 | dpr = [x.item() 297 | for x in torch.linspace(0, drop_path_rate, sum(depths))] 298 | 299 | # build layers 300 | self.layers = nn.ModuleList() 301 | for i_layer in range(self.num_layers): 302 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 303 | n_div=n_div, 304 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 305 | patches_resolution[1] // (2 ** i_layer)), 306 | depth=depths[i_layer], 307 | mlp_ratio=self.mlp_ratio, 308 | drop=drop_rate, 309 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 310 | norm_layer=norm_layer, 311 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 312 | use_checkpoint=use_checkpoint, 313 | act_layer=act_layer) 314 | self.layers.append(layer) 315 | 316 | self.norm = norm_layer(self.num_features) 317 | self.avgpool = nn.AdaptiveAvgPool2d(1) 318 | self.head = nn.Linear(self.num_features, num_classes) \ 319 | if num_classes > 0 else nn.Identity() 320 | 321 | self.apply(self._init_weights) 322 | 323 | def 
_init_weights(self, m): 324 | if isinstance(m, nn.Linear): 325 | trunc_normal_(m.weight, std=.02) 326 | if isinstance(m, nn.Linear) and m.bias is not None: 327 | nn.init.constant_(m.bias, 0) 328 | elif isinstance(m, (nn.Conv1d, nn.Conv2d)): 329 | trunc_normal_(m.weight, std=.02) 330 | if m.bias is not None: 331 | nn.init.constant_(m.bias, 0) 332 | elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)): 333 | nn.init.constant_(m.bias, 0) 334 | nn.init.constant_(m.weight, 1.0) 335 | 336 | def forward_features(self, x): 337 | x = self.patch_embed(x) 338 | x = self.pos_drop(x) 339 | 340 | for layer in self.layers: 341 | x = layer(x) 342 | 343 | x = self.norm(x) # B C H W 344 | x = self.avgpool(x) # B C 1 1 345 | x = torch.flatten(x, 1) # B C 346 | return x 347 | 348 | def forward(self, x): 349 | x = self.forward_features(x) 350 | x = self.head(x) 351 | return x 352 | -------------------------------------------------------------------------------- /models/smlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import torch 4 | from torch import nn 5 | from einops.layers.torch import Rearrange 6 | from timm.models.layers import DropPath 7 | 8 | 9 | class FeedForward(nn.Module): 10 | def __init__(self, dim, hidden_dim, dropout=0.): 11 | super().__init__() 12 | self.net = nn.Sequential( 13 | nn.Linear(dim, hidden_dim), 14 | nn.GELU(), 15 | nn.Dropout(dropout), 16 | nn.Linear(hidden_dim, dim), 17 | nn.Dropout(dropout) 18 | ) 19 | 20 | def forward(self, x): 21 | return self.net(x) 22 | 23 | 24 | class BN_Activ_Conv(nn.Module): 25 | def __init__(self, in_channels, activation, out_channels, kernel_size, stride=(1, 1), dilation=(1, 1), groups=1): 26 | super(BN_Activ_Conv, self).__init__() 27 | self.BN = nn.BatchNorm2d(in_channels) # BN runs on the input, so it must match in_channels 28 | self.Activation = activation 29 | padding = [int((dilation[j] * (kernel_size[j] - 1) - stride[j] + 1) / 2) for j in range(2)] # Same padding 30 | self.Conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups, bias=False) 31 | 32 | def forward(self, img): 33 | img = self.BN(img) 34 | img = self.Activation(img) 35 | img = self.Conv(img) 36 | return img 37 | 38 | 39 | class sMLPBlock(nn.Module): 40 | def __init__(self, W, H, channels): 41 | super().__init__() 42 | assert W == H 43 | self.channels = channels 44 | self.activation = nn.GELU() 45 | self.BN = nn.BatchNorm2d(channels) 46 | self.proj_h = nn.Conv2d(H, H, (1, 1)) 47 | self.proj_w = nn.Conv2d(W, W, (1, 1)) 48 | self.fuse = nn.Conv2d(channels*3, channels, (1,1), (1,1), bias=False) 49 | 50 | def forward(self, x): 51 | x = self.activation(self.BN(x)) 52 | x_h = self.proj_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 53 | x_w = self.proj_w(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) 54 | x = self.fuse(torch.cat([x, x_h, x_w], dim=1)) 55 | return x 56 | 57 | 58 | class DWConvBlock(nn.Module): 59 | def __init__(self, channels): 60 | super().__init__() 61 | self.conv_merge = BN_Activ_Conv(channels, nn.GELU(), channels, (3, 3), groups=channels) 62 | 63 | def forward(self, img): 64 | img = self.conv_merge(img) 65 | return img 66 | 67 | 68 | class sMLPNet(nn.Module): 69 | 70 | def __init__(self, in_chans=3, dim=80, alpha=3, num_classes=1000, patch_size=4, image_size=224, depths=[2,8,14,2], dp_rate=0., 71 | **kwargs): 72 | super(sMLPNet, self).__init__() 73 | ''' 74 | input: (B, C, H, W); after the stem the feature map is (B, dim, image_size // patch_size, image_size // patch_size) 75 | ''' 76 | 77 | assert image_size % patch_size == 0, 'Image dimensions
must be divisible by the patch size.' 78 | self.num_patch = image_size // patch_size 79 | self.depths = depths 80 | 81 | self.to_patch_embedding = nn.ModuleList([]) 82 | self.token_mix = nn.ModuleList([]) 83 | self.channel_mix = nn.ModuleList([]) 84 | self.drop_path = nn.ModuleList([]) 85 | 86 | net_num_blocks = sum(self.depths) 87 | net_block_idx = 0 88 | for i in range(len(self.depths)): 89 | ratio = 2 ** i 90 | if i == 0: 91 | self.to_patch_embedding.append(nn.Sequential(nn.Conv2d(in_chans, dim, patch_size, patch_size, bias=False))) 92 | else: 93 | self.to_patch_embedding.append(nn.Sequential(nn.Conv2d(dim * ratio // 2, dim * ratio, 2, 2, bias=False))) 94 | 95 | for j in range(self.depths[i]): 96 | block_dpr = dp_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule 97 | self.drop_path.append(DropPath(block_dpr) if block_dpr > 0. else nn.Identity()) 98 | net_block_idx += 1 99 | 100 | self.channel_mix.append(nn.Sequential( 101 | Rearrange('b c h w -> b h w c'), 102 | nn.LayerNorm(dim*ratio), 103 | FeedForward(dim*ratio,dim*ratio*alpha), 104 | Rearrange('b h w c -> b c h w')) 105 | ) 106 | 107 | self.token_mix.append(nn.Sequential(DWConvBlock(dim*ratio), sMLPBlock(self.num_patch//ratio, self.num_patch//ratio, dim * ratio))) 108 | 109 | self.batch_norm = nn.BatchNorm2d(dim*2**(len(self.depths)-1)) 110 | 111 | self.mlp_head = nn.Sequential( 112 | nn.Linear(dim * 2**(len(self.depths)-1), num_classes) 113 | ) 114 | 115 | def forward(self, x): 116 | 117 | shift = 0 118 | for i in range(len(self.depths)): 119 | x = self.to_patch_embedding[i](x) 120 | for j in range(self.depths[i]): 121 | x = x + self.drop_path[j+shift](self.token_mix[j+shift](x)) 122 | x = x + self.drop_path[j+shift](self.channel_mix[j+shift](x)) 123 | shift += self.depths[i] 124 | 125 | x = self.batch_norm(x) 126 | 127 | x = x.mean(dim=[2,3]).flatten(1) 128 | 129 | return self.mlp_head(x) 130 | -------------------------------------------------------------------------------- /models/spach/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
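# This package exposes the two SPACH variants used by models/registry.py:
# Spach (single-scale, spach.py) and SpachMS (multi-scale, spach_ms.py).
# A minimal construction sketch, mirroring the spach_xxs_patch16_224_mlp
# config from models/registry.py (illustrative only; assumes the repo root
# is on the import path):
#
#   from models.spach import Spach
#   model = Spach(img_size=224, patch_size=16, hidden_dim=384,
#                 net_arch=[('mlp', 12)], token_ratio=0.5, channel_ratio=2.0)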
3 | from .spach import Spach 4 | from .spach_ms import SpachMS 5 | 6 | __all__ = ['Spach', 'SpachMS'] 7 | -------------------------------------------------------------------------------- /models/spach/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .channel_func import ChannelMLP 2 | from .spatial_func import DWConv, SPATIAL_FUNC 3 | from .stem import STEM_LAYER 4 | 5 | __all__ = ['ChannelMLP', 'DWConv', 'SPATIAL_FUNC', 'STEM_LAYER'] 6 | -------------------------------------------------------------------------------- /models/spach/layers/channel_func.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class ChannelMLP(nn.Module): 5 | """Channel MLP""" 6 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., **kwargs): 7 | super(ChannelMLP, self).__init__() 8 | out_features = out_features or in_features 9 | hidden_features = hidden_features or in_features 10 | self.fc1 = nn.Linear(in_features, hidden_features) 11 | self.act = act_layer() 12 | self.fc2 = nn.Linear(hidden_features, out_features) 13 | self.drop = nn.Dropout(drop) 14 | 15 | self.hidden_features = hidden_features 16 | self.out_features = out_features 17 | 18 | def forward(self, x): 19 | B, N, C = x.shape 20 | x = self.fc1(x) 21 | x = self.act(x) 22 | x = self.drop(x) 23 | x = self.fc2(x) 24 | x = self.drop(x) 25 | return x 26 | 27 | def flops(self, input_shape): 28 | _, N, C = input_shape 29 | flops = 0 30 | flops += (C + 1) * self.hidden_features * N 31 | flops += (self.hidden_features + 1) * self.out_features * N 32 | return flops -------------------------------------------------------------------------------- /models/spach/layers/spatial_func.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from einops import rearrange 3 | 4 | from ..misc import Reshape2HW, Reshape2N 5 | 6 | 7 | class SpatialAttention(nn.Module): 8 | """Spatial Attention""" 9 | def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., **kwargs): 10 | super(SpatialAttention, self).__init__() 11 | head_dim = dim // num_heads 12 | 13 | self.num_heads = num_heads 14 | self.scale = qk_scale or head_dim ** -0.5 15 | 16 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 17 | self.attn_drop = nn.Dropout(attn_drop) 18 | 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | def forward(self, x): 23 | B, N, C = x.shape 24 | qkv = self.qkv(x) 25 | qkv = rearrange(qkv, "b n (three heads head_c) -> three b heads n head_c", three=3, heads=self.num_heads) 26 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 27 | 28 | attn = (q @ k.transpose(-2, -1)) # B, head, N, N 29 | attn = attn.softmax(dim=-1) 30 | attn = self.attn_drop(attn) 31 | 32 | out = (attn @ v) # B, head, N, C 33 | out = rearrange(out, "b heads n head_c -> b n (heads head_c)") 34 | 35 | out = self.proj(out) 36 | out = self.proj_drop(out) 37 | 38 | return out 39 | 40 | def flops(self, input_shape): 41 | _, N, C = input_shape 42 | flops = 0 43 | # qkv 44 | flops += 3 * C * C * N 45 | # q@k 46 | flops += N ** 2 * C 47 | # attn@v 48 | flops += N ** 2 * C 49 | # proj 50 | flops += C * C * N 51 | return flops 52 | 53 | class SpatialMLP(nn.Module): 54 | """Spatial MLP""" 55 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., **kwargs): 56 | super(SpatialMLP, 
self).__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | self.fc1 = nn.Linear(in_features, hidden_features) 60 | self.act = act_layer() 61 | self.fc2 = nn.Linear(hidden_features, out_features) 62 | self.drop = nn.Dropout(drop) 63 | 64 | self.hidden_features = hidden_features 65 | self.out_features = out_features 66 | 67 | def forward(self, x): 68 | B, N, C = x.shape 69 | x = x.transpose(1, 2) 70 | x = self.fc1(x) 71 | x = self.act(x) 72 | x = self.drop(x) 73 | x = self.fc2(x) 74 | x = self.drop(x) 75 | x = x.transpose(1, 2) 76 | return x 77 | 78 | def flops(self, input_shape): 79 | _, N, C = input_shape 80 | flops = 0 81 | flops += (N + 1) * self.hidden_features * C 82 | flops += (self.hidden_features + 1) * self.out_features * C 83 | return flops 84 | 85 | 86 | class DWConv(nn.Module): 87 | def __init__(self, dim, kernel_size=3): 88 | super(DWConv, self).__init__() 89 | self.dim = dim 90 | self.kernel_size = kernel_size 91 | 92 | padding = (kernel_size - 1) // 2 93 | self.net = nn.Sequential(Reshape2HW(), 94 | nn.Conv2d(dim, dim, kernel_size, 1, padding, groups=dim), 95 | Reshape2N()) 96 | 97 | 98 | def forward(self, x): 99 | x = self.net(x) 100 | return x 101 | 102 | def flops(self, input_shape): 103 | _, N, C = input_shape 104 | flops = N * self.dim * (self.kernel_size ** 2 + 1) # K^2 MACs + bias per output element 105 | return flops 106 | 107 | 108 | SPATIAL_FUNC = {'attn': SpatialAttention, 'mlp': SpatialMLP, 'pass': None} 109 | -------------------------------------------------------------------------------- /models/spach/layers/stem.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from timm.models.layers import to_2tuple 4 | 5 | from ..misc import check_upstream_shape 6 | 7 | 8 | class PatchEmbed(nn.Module): 9 | """1-conv patch embedding layer""" 10 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, downstream=False): 11 | super().__init__() 12 | img_size = to_2tuple(img_size) 13 | patch_size = to_2tuple(patch_size) 14 | 15 | self.downstream = downstream 16 | self.img_size = img_size 17 | self.patch_size = patch_size 18 | self.stem_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 19 | self.num_patches = self.stem_shape[0] * self.stem_shape[1] 20 | 21 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 22 | self.out_size = None 23 | 24 | # for flops 25 | self.in_chans = in_chans 26 | self.embed_dim = embed_dim 27 | 28 | def forward(self, x): 29 | if not self.downstream: 30 | check_upstream_shape(x, self.img_size) 31 | x = self.proj(x) 32 | return x 33 | 34 | def flops(self, input_shape=None): 35 | flops = self.num_patches * self.embed_dim * (self.patch_size[0] * self.patch_size[1] * self.in_chans + 1) # Ho*Wo*Co*(K^2*Ci+1) 36 | return flops 37 | 38 | 39 | class Conv4PatchEmbed(nn.Module): 40 | """4-conv patch embedding layer""" 41 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, downstream=False, hidden_chans=64): 42 | super(Conv4PatchEmbed, self).__init__() 43 | img_size = to_2tuple(img_size) 44 | patch_size = to_2tuple(patch_size) 45 | 46 | self.downstream = downstream 47 | self.img_size = img_size 48 | self.patch_size = patch_size 49 | self.stem_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 50 | self.num_patches = self.stem_shape[0] * self.stem_shape[1] 51 | 52 | sub_patch_size = (patch_size[0]//2, patch_size[1]//2) 53 | 54 | self.proj = nn.Sequential( 55 | nn.Conv2d(in_chans,
hidden_chans, kernel_size=7, stride=2, padding=3, bias=False), 56 | nn.BatchNorm2d(hidden_chans), 57 | nn.ReLU(), 58 | nn.Conv2d(hidden_chans, hidden_chans, 3, 1, 1, bias=False), 59 | nn.BatchNorm2d(hidden_chans), 60 | nn.ReLU(), 61 | nn.Conv2d(hidden_chans, hidden_chans, 3, 1, 1, bias=False), 62 | nn.BatchNorm2d(hidden_chans), 63 | nn.ReLU(), 64 | nn.Conv2d(hidden_chans, embed_dim, kernel_size=sub_patch_size, stride=sub_patch_size) 65 | ) 66 | 67 | # for flops 68 | self.inside_num_patches = self.num_patches * sub_patch_size[0] * sub_patch_size[1] # pixels of the (H/2, W/2) intermediate maps 69 | self.in_chans = in_chans 70 | self.new_patch_size = sub_patch_size 71 | self.embed_dim = embed_dim 72 | self.hidden_chans = hidden_chans 73 | 74 | def forward(self, x): 75 | if not self.downstream: 76 | check_upstream_shape(x, self.img_size) 77 | x = self.proj(x) 78 | return x 79 | 80 | def flops(self, input_shape=None): 81 | flops = 0 82 | flops += self.inside_num_patches * self.hidden_chans * self.in_chans * 7 * 7 # Ho*Wo*Co*K^2*Ci 83 | flops += self.inside_num_patches * self.hidden_chans 84 | 85 | flops += self.inside_num_patches * self.hidden_chans * self.hidden_chans * 3 * 3 86 | flops += self.inside_num_patches * self.hidden_chans 87 | 88 | flops += self.inside_num_patches * self.hidden_chans * self.hidden_chans * 3 * 3 89 | flops += self.inside_num_patches * self.hidden_chans 90 | 91 | flops += self.num_patches * self.embed_dim * (self.new_patch_size[0] * self.new_patch_size[1] * self.hidden_chans + 1) # Ho*Wo*Co*(K^2*Ci+1) 92 | 93 | return flops 94 | 95 | 96 | STEM_LAYER = {'conv1': PatchEmbed, 'conv4': Conv4PatchEmbed} 97 | -------------------------------------------------------------------------------- /models/spach/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from functools import partial 4 | 5 | from torch import nn 6 | from einops import rearrange 7 | 8 | from timm.models.layers import to_2tuple 9 | 10 | 11 | def check_upstream_shape(x, img_size=(224, 224)): 12 | _, _, H, W = x.shape 13 | assert H == img_size[0] and W == img_size[1], \ 14 | f"Input image size ({H}*{W}) doesn't match model ({img_size[0]}*{img_size[1]})."
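# The helpers below convert between the [B, C, H, W] layout produced by the
# conv stem and downsampling layers and the [B, N, C] token layout consumed
# by the mixing blocks, with N = H * W. A shape sketch (illustrative values;
# torch import assumed):
#
#   x = torch.randn(2, 384, 14, 14)          # [B, C, H, W]
#   n = reshape2n(x)                         # -> [2, 196, 384], i.e. [B, N, C]
#   assert reshape2hw(n).shape == x.shape    # hw defaults to the square (14, 14)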
15 | 16 | 17 | def reshape2n(x): 18 | return rearrange(x, 'b c h w -> b (h w) c') 19 | 20 | 21 | def reshape2hw(x, hw=None): 22 | n = x.shape[1] 23 | if hw is None: 24 | hw = to_2tuple(int(n ** 0.5)) 25 | assert n == hw[0] * hw[1], f"N={n} is not equal to H={hw[0]}*W={hw[1]}" 26 | return rearrange(x, 'b (h w) c -> b c h w', h=hw[0]) 27 | 28 | 29 | def downsample_conv(in_channels, out_channels, kernel_size=2, stride=2, padding=0, dilation=1, norm_layer=None): 30 | assert norm_layer is None, "only support default normalization" 31 | norm_layer = norm_layer or partial(nn.GroupNorm, num_groups=1, num_channels=out_channels) 32 | kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size 33 | dilation = dilation if kernel_size > 1 else 1 34 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 35 | dilation=dilation, bias=False), 36 | norm_layer() 37 | ) 38 | 39 | 40 | class Reshape2N(nn.Module): 41 | def __init__(self): 42 | super(Reshape2N, self).__init__() 43 | 44 | def forward(self, x): 45 | return reshape2n(x) 46 | 47 | 48 | class Reshape2HW(nn.Module): 49 | def __init__(self, hw=None): 50 | super(Reshape2HW, self).__init__() 51 | self.hw = hw 52 | 53 | def forward(self, x): 54 | return reshape2hw(x, self.hw) 55 | 56 | 57 | class DownsampleConv(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, dilation=1, norm_layer=None): 59 | super(DownsampleConv, self).__init__() 60 | self.net = nn.Sequential( 61 | Reshape2HW(), 62 | downsample_conv(in_channels, out_channels, kernel_size, stride, padding, dilation, norm_layer), 63 | Reshape2N() 64 | ) 65 | 66 | self.kernel_size = kernel_size 67 | self.in_channels = in_channels 68 | self.out_channels = out_channels 69 | 70 | def forward(self, x): 71 | return self.net(x) 72 | 73 | def flops(self, input_shape): 74 | _, N, C = input_shape # C == out_channels 75 | flops = 0 76 | flops += N * self.out_channels * self.in_channels * self.kernel_size**2 77 | flops += N * self.out_channels 78 | return flops 79 | -------------------------------------------------------------------------------- /models/spach/spach.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
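# Single-scale SPACH, as built by the factories in models/registry.py: a conv
# patch-embedding stem, flattening to [B, N, C] tokens, a stack of
# MixingBlocks described by net_arch as (func_type, depth) pairs
# (e.g. [('attn', 12)]), a final LayerNorm, mean pooling over tokens and a
# linear head. Each MixingBlock applies, with pre-norm residuals: an optional
# spatial function ('attn', 'mlp', or 'pass' for none), an optional
# conditional positional encoding (a depth-wise conv, cpe=True) and a channel
# MLP; with scaled=True the spatial branch is multiplied by a learnable
# gamma_1 initialized to init_values.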
3 | from functools import partial 4 | 5 | import torch 6 | from torch import nn 7 | from timm.models.layers import DropPath 8 | from einops.layers.torch import Reduce 9 | 10 | from .layers import DWConv, SPATIAL_FUNC, ChannelMLP, STEM_LAYER 11 | from .misc import reshape2n 12 | 13 | 14 | class MixingBlock(nn.Module): 15 | def __init__(self, dim, 16 | spatial_func=None, scaled=True, init_values=1e-4, shared_spatial_func=False, 17 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop_path=0., cpe=True, 18 | num_heads=None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., # attn 19 | in_features=None, hidden_features=None, drop=0., # mlp 20 | channel_ratio=2.0 21 | ): 22 | super(MixingBlock, self).__init__() 23 | 24 | spatial_kwargs = dict(act_layer=act_layer, 25 | in_features=in_features, hidden_features=hidden_features, drop=drop, # mlp 26 | dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop # attn 27 | ) 28 | 29 | self.valid_spatial_func = True 30 | 31 | if spatial_func is not None: 32 | if shared_spatial_func: 33 | self.spatial_func = spatial_func 34 | else: 35 | self.spatial_func = spatial_func(**spatial_kwargs) 36 | self.norm1 = norm_layer(dim) 37 | if scaled: 38 | self.gamma_1 = nn.Parameter(init_values * torch.ones(1, 1, dim), requires_grad=True) 39 | else: 40 | self.gamma_1 = 1. 41 | else: 42 | self.valid_spatial_func = False 43 | 44 | self.channel_func = ChannelMLP(in_features=dim, hidden_features=int(dim*channel_ratio), act_layer=act_layer, 45 | drop=drop) 46 | 47 | self.norm2 = norm_layer(dim) 48 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 49 | 50 | 51 | self.cpe = cpe 52 | if cpe: 53 | self.cpe_net = DWConv(dim) 54 | 55 | 56 | def forward(self, x): 57 | in_x = x 58 | if self.valid_spatial_func: 59 | x = x + self.drop_path(self.gamma_1 * self.spatial_func(self.norm1(in_x))) 60 | if self.cpe: 61 | x = x + self.cpe_net(in_x) 62 | 63 | x = x + self.drop_path(self.channel_func(self.norm2(x))) 64 | 65 | return x 66 | 67 | def flops(self, input_shape): 68 | _, N, C = input_shape 69 | flops = 0 70 | if self.valid_spatial_func: 71 | flops += self.spatial_func.flops(input_shape) 72 | flops += N * C * 2 # norm + skip 73 | if self.cpe: 74 | flops += self.cpe_net.flops(input_shape) 75 | 76 | flops += self.channel_func.flops(input_shape) 77 | flops += N * C * 2 78 | return flops 79 | 80 | 81 | class Spach(nn.Module): 82 | def __init__(self, 83 | num_classes=1000, 84 | img_size=224, 85 | in_chans=3, 86 | hidden_dim=384, 87 | patch_size=16, 88 | net_arch=None, 89 | act_layer=nn.GELU, 90 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 91 | stem_type='conv1', 92 | scaled=True, init_values=1e-4, drop_path_rate=0., cpe=True, shared_spatial_func=False, # mixing block 93 | num_heads=12, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., # attn 94 | token_ratio=0.5, channel_ratio=2.0, drop_rate=0., # mlp 95 | downstream=False, 96 | **kwargs 97 | ): 98 | super(Spach, self).__init__() 99 | self.num_classes = num_classes 100 | self.hidden_dim = hidden_dim 101 | self.downstream = downstream 102 | 103 | self.stem = STEM_LAYER[stem_type]( 104 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, downstream=downstream) 105 | self.norm1 = norm_layer(hidden_dim) 106 | 107 | block_kwargs = dict(dim=hidden_dim, scaled=scaled, init_values=init_values, cpe=cpe, 108 | shared_spatial_func=shared_spatial_func, norm_layer=norm_layer, act_layer=act_layer, 109 | 
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop, # attn 110 | in_features=self.stem.num_patches, hidden_features=int(self.stem.num_patches * token_ratio), channel_ratio=channel_ratio, drop=drop_rate) # mlp 111 | 112 | self.blocks = self.make_blocks(net_arch, block_kwargs, drop_path_rate, shared_spatial_func) 113 | self.norm2 = norm_layer(hidden_dim) 114 | 115 | if not downstream: 116 | self.pool = Reduce('b n c -> b c', reduction='mean') 117 | self.head = nn.Linear(hidden_dim, self.num_classes) 118 | 119 | self.init_weights() 120 | 121 | def make_blocks(self, net_arch, block_kwargs, drop_path, shared_spatial_func): 122 | if shared_spatial_func: 123 | assert len(net_arch) == 1, '`shared_spatial_func` only support unitary spatial function' 124 | assert net_arch[0][0] != 'pass', '`shared_spatial_func` do not support pass' 125 | spatial_func = SPATIAL_FUNC[net_arch[0][0]](**block_kwargs) 126 | else: 127 | spatial_func = None 128 | blocks = [] 129 | for func_type, depth in net_arch: 130 | for i in range(depth): 131 | blocks.append(MixingBlock(spatial_func=spatial_func or SPATIAL_FUNC[func_type], drop_path=drop_path, 132 | **block_kwargs)) 133 | return nn.Sequential(*blocks) 134 | 135 | def init_weights(self): 136 | for n, m in self.named_modules(): 137 | _init_weights(m, n) 138 | 139 | def forward_features(self, x): 140 | x = self.stem(x) 141 | x = reshape2n(x) 142 | x = self.norm1(x) 143 | 144 | x = self.blocks(x) 145 | x = self.norm2(x) 146 | 147 | return x 148 | 149 | def forward(self, x): 150 | x = self.forward_features(x) 151 | x = self.pool(x) 152 | x = self.head(x) 153 | return x 154 | 155 | def flops(self): 156 | flops = 0 157 | shape = (1, self.stem.num_patches, self.hidden_dim) 158 | # stem 159 | flops += self.stem.flops() 160 | flops += sum(shape) 161 | # blocks 162 | flops += sum([i.flops(shape) for i in self.blocks]) 163 | flops += sum(shape) 164 | # head 165 | flops += self.hidden_dim * self.num_classes 166 | return flops 167 | 168 | 169 | def _init_weights(m, n: str): 170 | if isinstance(m, nn.Linear): 171 | if n.startswith('head'): 172 | nn.init.zeros_(m.weight) 173 | nn.init.zeros_(m.bias) 174 | else: 175 | nn.init.xavier_uniform_(m.weight) 176 | if m.bias is not None: 177 | if 'mlp' in n: 178 | nn.init.normal_(m.bias, std=1e-6) 179 | else: 180 | nn.init.zeros_(m.bias) 181 | elif isinstance(m, nn.Conv2d): 182 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 183 | if m.bias is not None: 184 | nn.init.zeros_(m.bias) 185 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): 186 | nn.init.ones_(m.weight) 187 | nn.init.zeros_(m.bias) -------------------------------------------------------------------------------- /models/spach/spach_ms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
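# Multi-scale SPACH. Compared with Spach, the blocks are grouped into four
# stages (layer1..layer4); between stages a DownsampleConv doubles the
# channel width and merges 2x2 tokens (seq_len // 4), so the final width is
# hidden_dim * 8. Here net_arch is a list of four per-stage lists, e.g.
# [[('pass', 3)], [('mlp', 4)], [('mlp', 12)], [('mlp', 3)]], and stochastic
# depth decays linearly over all blocks (see make_blocks).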
3 | from functools import partial 4 | 5 | from torch import nn 6 | from einops.layers.torch import Reduce 7 | 8 | from .spach import MixingBlock, _init_weights 9 | from .layers import STEM_LAYER, SPATIAL_FUNC 10 | from .misc import DownsampleConv, reshape2n 11 | 12 | 13 | class SpachMS(nn.Module): 14 | def __init__(self, 15 | num_classes=1000, 16 | img_size=224, 17 | in_chans=3, 18 | hidden_dim=384, 19 | patch_size=16, 20 | net_arch=None, 21 | act_layer=nn.GELU, 22 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 23 | stem_type='conv1', 24 | scaled=True, init_values=1e-4, drop_path_rate=0., cpe=True, shared_spatial_func=False, # mixing block 25 | num_heads=12, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., # attn 26 | token_ratio=0.5, channel_ratio=2.0, drop_rate=0., # mlp 27 | downstream=False, 28 | **kwargs 29 | ): 30 | super(SpachMS, self).__init__() 31 | assert len(net_arch) == 4 32 | self.num_classes = num_classes 33 | self.hidden_dim = hidden_dim 34 | self.downstream = downstream 35 | 36 | self.stem = STEM_LAYER[stem_type]( 37 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, downstream=downstream) 38 | self.norm1 = norm_layer(hidden_dim) 39 | 40 | block_kwargs = dict(scaled=scaled, init_values=init_values, cpe=cpe, 41 | shared_spatial_func=shared_spatial_func, norm_layer=norm_layer, act_layer=act_layer, 42 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop, # attn 43 | channel_ratio=channel_ratio, drop=drop_rate) # mlp 44 | 45 | stage_modules = self.make_blocks(hidden_dim, self.stem.num_patches, net_arch, block_kwargs, drop_path_rate, 46 | shared_spatial_func, token_ratio) 47 | for stage in stage_modules: 48 | self.add_module(*stage) 49 | hidden_dim = hidden_dim * 8 50 | self.norm2 = norm_layer(hidden_dim) 51 | 52 | if not downstream: 53 | self.pool = Reduce('b n c -> b c', reduction='mean') 54 | self.head = nn.Linear(hidden_dim, self.num_classes) 55 | 56 | self.init_weights() 57 | 58 | def make_blocks(self, dim, seq_len, net_arch, block_kwargs, drop_path, shared_spatial_func, token_ratio): 59 | stages = [] 60 | num_blocks = sum(sum([depth for _, depth in stage_arch]) for stage_arch in net_arch) 61 | block_idx = 0 62 | 63 | for stage_idx, stage_arch in enumerate(net_arch): 64 | stage_name = f'layer{stage_idx + 1}' 65 | blocks = [] 66 | if stage_idx > 0: 67 | down_kwargs = dict(in_channels=dim, out_channels=dim * 2) 68 | downsample = DownsampleConv(**down_kwargs) 69 | blocks.append(downsample) 70 | dim = dim * 2 71 | seq_len = seq_len // 4 72 | 73 | block_kwargs.update(dict(dim=dim, in_features=seq_len, hidden_features=int(seq_len * token_ratio))) 74 | 75 | if stage_idx > 0 and shared_spatial_func: 76 | assert len(stage_arch) == 1, '`shared_spatial_func` only support unitary spatial function' 77 | assert stage_arch[0][0] != 'pass', '`shared_spatial_func` do not support pass' 78 | spatial_func = SPATIAL_FUNC[stage_arch[0][0]](**block_kwargs) 79 | else: 80 | spatial_func = None 81 | 82 | for func_type, depth in stage_arch: 83 | for i in range(depth): 84 | block_dpr = drop_path * block_idx / (num_blocks - 1) # stochastic depth linear decay rule 85 | blocks.append(MixingBlock(spatial_func=spatial_func or SPATIAL_FUNC[func_type], drop_path=block_dpr, 86 | **block_kwargs)) 87 | block_idx += 1 88 | stages.append((stage_name, nn.Sequential(*blocks))) 89 | 90 | return stages 91 | 92 | def init_weights(self): 93 | for n, m in self.named_modules(): 94 | _init_weights(m, n) 95 | 96 | def 
forward_features(self, x): 97 | x = self.stem(x) 98 | x = reshape2n(x) 99 | x = self.norm1(x) 100 | 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | x = self.layer4(x) 105 | 106 | x = self.norm2(x) 107 | 108 | return x 109 | 110 | def forward(self, x): 111 | x = self.forward_features(x) 112 | x = self.pool(x) 113 | x = self.head(x) 114 | return x 115 | 116 | def flops(self): 117 | flops = 0 118 | shape = (1, self.stem.num_patches, self.hidden_dim) 119 | # stem 120 | flops += self.stem.flops() 121 | flops += sum(shape) 122 | # layer1,2,3,4 123 | flops += sum([i.flops(shape) for i in self.layer1]) 124 | shape = (1, self.stem.num_patches//4, self.hidden_dim*2) 125 | flops += sum([i.flops(shape) for i in self.layer2]) 126 | shape = (1, self.stem.num_patches//16, self.hidden_dim*4) 127 | flops += sum([i.flops(shape) for i in self.layer3]) 128 | shape = (1, self.stem.num_patches//64, self.hidden_dim*8) 129 | flops += sum([i.flops(shape) for i in self.layer4]) 130 | flops += sum(shape) 131 | # head 132 | flops += self.hidden_dim * 8 * self.num_classes 133 | return flops -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision==0.8.1 3 | timm==0.3.2 4 | einops==0.3.2 5 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed training, 10 | with repeated augmentation.
11 | It ensures that a different augmented version of a sample will be visible to a 12 | different process (GPU). 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # repeat each sample 3x (repeated augmentation), then add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | import logging 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque!
40 | """ 41 | if not is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value) 79 | 80 | 81 | class MetricLogger(object): 82 | def __init__(self, delimiter="\t", logger=logging): 83 | self.meters = defaultdict(SmoothedValue) 84 | self.delimiter = delimiter 85 | self.logger = logger 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 | def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None): 118 | i = 0 119 | if not header: 120 | header = '' 121 | start_time = time.time() 122 | end = time.time() 123 | iter_time = SmoothedValue(fmt='{avg:.4f}') 124 | data_time = SmoothedValue(fmt='{avg:.4f}') 125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 126 | log_msg = [ 127 | header, 128 | '[{0' + space_fmt + '}/{1}]', 129 | 'eta: {eta}', 130 | '{meters}', 131 | 'time: {time}', 132 | 'data: {data}' 133 | ] 134 | if torch.cuda.is_available(): 135 | log_msg.append('max mem: {memory:.0f}') 136 | log_msg = self.delimiter.join(log_msg) 137 | MB = 1024.0 * 1024.0 138 | for obj in iterable: 139 | data_time.update(time.time() - end) 140 | yield obj 141 | iter_time.update(time.time() - end) 142 | if i % print_freq == 0 or i == len(iterable) - 1: 143 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 144 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 145 | if torch.cuda.is_available(): 146 | self.logger.info(log_msg.format( 147 | i, len(iterable), eta=eta_string, 148 | meters=str(self), 149 | time=str(iter_time), data=str(data_time), 150 | memory=torch.cuda.max_memory_allocated() / MB)) 151 | else: 152 | self.logger.info(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time))) 156 | i += 1 157 | end = time.time() 158 | total_time = time.time() - start_time 159 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 160 | self.logger.info('{} Total time: {} 
({:.4f} s / it)'.format( 161 | header, total_time_str, total_time / len(iterable))) 162 | 163 | 164 | def _load_checkpoint_for_ema(model_ema, checkpoint): 165 | """ 166 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 167 | """ 168 | mem_file = io.BytesIO() 169 | torch.save(checkpoint, mem_file) 170 | mem_file.seek(0) 171 | model_ema._load_checkpoint(mem_file) 172 | 173 | 174 | def setup_for_distributed(is_master): 175 | """ 176 | This function disables printing when not in the master process 177 | """ 178 | import builtins as __builtin__ 179 | builtin_print = __builtin__.print 180 | 181 | def print(*args, **kwargs): 182 | force = kwargs.pop('force', False) 183 | if is_master or force: 184 | builtin_print(*args, **kwargs) 185 | 186 | __builtin__.print = print 187 | 188 | 189 | def is_dist_avail_and_initialized(): 190 | if not dist.is_available(): 191 | return False 192 | if not dist.is_initialized(): 193 | return False 194 | return True 195 | 196 | 197 | def get_world_size(): 198 | if not is_dist_avail_and_initialized(): 199 | return 1 200 | return dist.get_world_size() 201 | 202 | 203 | def get_rank(): 204 | if not is_dist_avail_and_initialized(): 205 | return 0 206 | return dist.get_rank() 207 | 208 | 209 | def is_main_process(): 210 | return get_rank() == 0 211 | 212 | 213 | def save_on_master(*args, **kwargs): 214 | if is_main_process(): 215 | torch.save(*args, **kwargs) 216 | 217 | 218 | def init_distributed_mode(args): 219 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 220 | args.rank = int(os.environ["RANK"]) 221 | args.world_size = int(os.environ['WORLD_SIZE']) 222 | args.gpu = int(os.environ['LOCAL_RANK']) 223 | elif 'SLURM_PROCID' in os.environ: 224 | args.rank = int(os.environ['SLURM_PROCID']) 225 | args.gpu = args.rank % torch.cuda.device_count() 226 | elif 'OMPI_COMM_WORLD_SIZE' in os.environ and 'OMPI_COMM_WORLD_RANK' in os.environ: 227 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 228 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 229 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 230 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 231 | print(f'dist train on amlk8s | world_size {args.world_size} | rank {args.rank} | gpu {args.gpu} | dist_url {args.dist_url}') 232 | else: 233 | print('Not using distributed mode') 234 | args.distributed = False 235 | return 236 | 237 | args.distributed = True 238 | 239 | torch.cuda.set_device(args.gpu) 240 | args.dist_backend = 'nccl' 241 | print('| distributed init (rank {}): {}'.format( 242 | args.rank, args.dist_url), flush=True) 243 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 244 | world_size=args.world_size, rank=args.rank) 245 | torch.distributed.barrier() 246 | setup_for_distributed(args.rank == 0) 247 | --------------------------------------------------------------------------------